File size: 3,650 Bytes
327bfe3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import concurrent.futures
import json
import logging
import queue
import threading
import time

from flask import Flask, Response, render_template, request, stream_with_context
from pydantic import ValidationError

from core import (is_deterministic, set_event_queue, clear_event_queue,
                  push_event, register_child_thread)
from legal_graph import LEGAL_GRAPH
from legal_schemas import ConversationState, UserRequest

# The WSGI application object; the routes below are registered on it.
app = Flask(import_name=__name__)

_log = logging.getLogger(__name__)

# Per-thread record of which pipeline "root" thread the current thread works
# for. It is set on the pipeline's own thread in _run_pipeline and copied onto
# executor worker threads by the patched ``submit``; threads that are not part
# of any pipeline never have it set.
_pipeline_root = threading.local()
_submit_patch_lock = threading.Lock()
_submit_patched = False


def _current_root_tid():
    """Return the pipeline root thread id bound to the calling thread, or None."""
    return getattr(_pipeline_root, "tid", None)


def _install_submit_patch() -> None:
    """Wrap ``ThreadPoolExecutor.submit`` once per process to tag pipeline workers.

    A task submitted from a thread that belongs to a pipeline calls
    ``register_child_thread(root_tid)`` on its worker thread before running.
    The worker also inherits the root id, so tasks it submits in turn are
    tagged as well. Submissions from any other thread are forwarded unchanged
    to the original ``submit``.

    The wrapper is installed once and never swapped back out. Patching and
    restoring the class attribute per request raced when requests overlap
    (the server runs with ``threaded=True``): one request could capture
    another's wrapper as "original" and reinstall it afterwards.
    """
    global _submit_patched
    with _submit_patch_lock:
        if _submit_patched:
            return
        orig_submit = concurrent.futures.ThreadPoolExecutor.submit

        def _patched_submit(self, fn, *args, **kwargs):
            root_tid = _current_root_tid()
            if root_tid is None:
                # Not submitted from a pipeline thread: behave exactly like stock submit.
                return orig_submit(self, fn, *args, **kwargs)

            def _wrapped(*a, **kw):
                # Pool threads are reused, so bind the root only for this task.
                previous = _current_root_tid()
                _pipeline_root.tid = root_tid
                try:
                    register_child_thread(root_tid)
                    return fn(*a, **kw)
                finally:
                    _pipeline_root.tid = previous

            return orig_submit(self, _wrapped, *args, **kwargs)

        concurrent.futures.ThreadPoolExecutor.submit = _patched_submit
        _submit_patched = True


def _build_result(final: "ConversationState") -> dict:
    """Serialise the final pipeline state into the ``result`` event payload.

    Optional sections that the pipeline did not produce become ``None``;
    ``action_plan`` becomes an empty list in that case.
    """
    classification = final.document_classification
    return {
        "intent": final.intent.model_dump() if final.intent else None,
        "document": {
            "type": classification.document_type,
            # Stored as a fraction; sent as a percentage with one decimal place.
            "confidence": round(classification.confidence * 100, 1),
        } if classification else None,
        "explanation": final.explanation.model_dump() if final.explanation else None,
        "critic": final.critic_review.model_dump() if final.critic_review else None,
        "action_plan": [s.model_dump() for s in final.action_plan.steps] if final.action_plan else [],
        "risk": final.risk_assessment.model_dump() if final.risk_assessment else None,
        "case_saved": final.case_saved,
        "reflection_count": final.reflection_count,
        "references": final.rag_context,  # RAG context passed through for the client
    }


def _run_pipeline(text: str, country: str, q: queue.Queue) -> None:
    """Run the LangGraph pipeline and stream events into the queue.

    Intended to run on its own thread. The function does the following:

    - Binds ``q`` via ``set_event_queue`` so events pushed by this thread land
      in ``q`` (presumably ``push_event`` writes there; confirm in ``core``).
    - Tags executor workers spawned on this pipeline's behalf via
      ``register_child_thread``, so their events are presumably routed to
      ``q`` as well.
    - Puts exactly one ``{"type": "result", ...}`` or
      ``{"type": "error", ...}`` item on ``q``.
    - Always puts a final ``None`` sentinel on ``q``.

    Args:
        text: The user's request text.
        country: Jurisdiction hint; empty values become ``"not specified"``.
        q: Queue consumed by the SSE response generator.
    """
    root_tid = threading.get_ident()
    set_event_queue(q)
    _install_submit_patch()
    _pipeline_root.tid = root_tid

    try:
        request_obj = UserRequest(text=text, country=country or "not specified")
        state = ConversationState(user_request=request_obj)
        push_event("pipeline_start", {"status": "initializing"})

        raw = LEGAL_GRAPH.invoke(state)
        # The graph may hand back a plain dict of state fields; normalise it.
        final = ConversationState(**raw) if isinstance(raw, dict) else raw
        q.put({"type": "result", "data": _build_result(final)})

    except Exception as exc:
        # Top-level boundary of a background thread: keep the traceback in the
        # server log and report the failure to the client instead of dying silently.
        _log.exception("Legal analysis pipeline failed")
        q.put({"type": "error", "data": {"message": str(exc)}})
    finally:
        q.put(None)  # sentinel: tells the stream consumer no more items follow
        _pipeline_root.tid = None
        clear_event_queue()

@app.route("/")
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = "index.html"
    return render_template(page)

@app.route("/analyze", methods=["GET"])
def analyze():
    """Launch the analysis pipeline for this request and stream its output as SSE.

    Query parameters are ``text`` (default ``""``) and ``country`` (default
    ``"not specified"``). The pipeline runs on a daemon thread that feeds a
    queue, and the response streams as follows:

    - Each queued item is forwarded as an ``event: <type>`` frame with its
      ``data`` serialised as JSON.
    - A ``ping`` frame is emitted after 15 seconds without items.
    - A ``done`` frame ends the stream once the ``None`` sentinel arrives.
    """
    user_text = request.args.get("text", "")
    user_country = request.args.get("country", "not specified")

    events = queue.Queue()
    worker = threading.Thread(
        target=_run_pipeline, args=(user_text, user_country, events), daemon=True
    )
    worker.start()

    def _frame(event_name, payload):
        # One Server-Sent Events frame: event line, data line, blank-line terminator.
        return f"event: {event_name}\ndata: {payload}\n\n"

    def event_stream():
        while True:
            try:
                message = events.get(timeout=15)
            except queue.Empty:
                # Keep-alive so a quiet stream is not timed out by the browser.
                yield _frame("ping", "{}")
                continue
            if message is None:
                # Sentinel from the pipeline thread: nothing more will arrive.
                yield _frame("done", "{}")
                return
            event_name = message["type"]
            payload = json.dumps(message["data"], ensure_ascii=False)
            yield _frame(event_name, payload)

    return Response(
        stream_with_context(event_stream()),
        mimetype="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )

if __name__ == "__main__":
    # threaded=True lets the development server handle several /analyze
    # streams concurrently instead of one request at a time.
    server_options = {"host": "0.0.0.0", "port": 7860, "debug": False, "threaded": True}
    app.run(**server_options)