"""Flask front-end that runs LEGAL_GRAPH and streams its progress as Server-Sent Events."""

import concurrent.futures
import json
import logging
import queue
import threading
import time

from flask import Flask, Response, render_template, request, stream_with_context
from pydantic import ValidationError

from core import (
    is_deterministic,
    set_event_queue,
    clear_event_queue,
    push_event,
    register_child_thread,
)
from legal_graph import LEGAL_GRAPH
from legal_schemas import ConversationState, UserRequest

app = Flask(__name__)
logger = logging.getLogger(__name__)

# How long the SSE generator waits for a pipeline event before sending a
# keep-alive ping, so browsers/proxies do not time out an idle stream.
_PING_INTERVAL_SECONDS = 15

# ---------------------------------------------------------------------------
# Worker-thread tracking
#
# ThreadPoolExecutor.submit is a single process-wide attribute. Because the
# server is threaded and each /analyze request runs its own pipeline thread,
# several pipelines can be active at once. Patching and restoring submit per
# run is therefore racy: overlapping runs can restore each other's wrapper and
# leave a stale one installed for good. Instead, one hook is installed once,
# and it decides per call which pipeline (if any) the submitting thread
# belongs to.
# ---------------------------------------------------------------------------

# Per-thread record of the pipeline root thread id. The attribute is unset or
# None on threads that do not belong to any pipeline.
_pipeline_ctx = threading.local()
_submit_hook_lock = threading.Lock()
# The un-hooked ThreadPoolExecutor.submit, captured by _install_submit_hook.
_original_submit = None


def _tracking_submit(self, fn, *args, **kwargs):
    """``ThreadPoolExecutor.submit`` replacement that ties pool tasks to a pipeline.

    If the calling thread belongs to a pipeline (the pipeline thread itself, or
    a pool task spawned from it, transitively), *fn* is wrapped so the worker
    that runs it first calls ``register_child_thread`` with that pipeline's
    root thread id. Calls from any other thread pass straight through to the
    original ``submit`` unchanged.
    """
    root_tid = getattr(_pipeline_ctx, "root_tid", None)
    if root_tid is None:
        return _original_submit(self, fn, *args, **kwargs)

    def _wrapped(*a, **kw):
        # Pool workers are reused across tasks (and across pipelines), so
        # restore the worker's previous association once this task ends.
        previous = getattr(_pipeline_ctx, "root_tid", None)
        # Setting it here lets submits made from inside fn inherit the root.
        _pipeline_ctx.root_tid = root_tid
        try:
            register_child_thread(root_tid)
            return fn(*a, **kw)
        finally:
            _pipeline_ctx.root_tid = previous

    return _original_submit(self, _wrapped, *args, **kwargs)


def _install_submit_hook() -> None:
    """Install ``_tracking_submit`` on ThreadPoolExecutor exactly once (thread-safe)."""
    global _original_submit
    with _submit_hook_lock:
        if _original_submit is None:
            _original_submit = concurrent.futures.ThreadPoolExecutor.submit
            concurrent.futures.ThreadPoolExecutor.submit = _tracking_submit


def _build_result(final: ConversationState) -> dict:
    """Flatten the final pipeline state into the payload of the ``result`` event."""
    classification = final.document_classification
    return {
        "intent": final.intent.model_dump() if final.intent else None,
        "document": {
            "type": classification.document_type,
            # Presumably a 0-1 score; reported as a percentage with one decimal.
            "confidence": round(classification.confidence * 100, 1),
        } if classification else None,
        "explanation": final.explanation.model_dump() if final.explanation else None,
        "critic": final.critic_review.model_dump() if final.critic_review else None,
        "action_plan": (
            [s.model_dump() for s in final.action_plan.steps] if final.action_plan else []
        ),
        "risk": final.risk_assessment.model_dump() if final.risk_assessment else None,
        "case_saved": final.case_saved,
        "reflection_count": final.reflection_count,
        # RAG context is passed through as-is so the client can show references.
        "references": final.rag_context,
    }


def _run_pipeline(text: str, country: str, q: queue.Queue) -> None:
    """Run the LangGraph pipeline and stream events into the queue.

    Meant to run on a dedicated thread. It registers *q* via
    ``set_event_queue`` so ``push_event`` calls made during the run can reach
    it (assumed from the core API names; confirm in ``core``). It then puts a
    ``{"type": "result", ...}`` or ``{"type": "error", ...}`` message on *q*.
    A final ``None`` sentinel is always put last, whether the run succeeded or
    failed.

    Args:
        text: User's free-text request.
        country: Jurisdiction hint; an empty value becomes ``"not specified"``.
        q: Queue consumed by the SSE generator in :func:`analyze`.
    """
    _install_submit_hook()
    # Mark this thread as a pipeline root so pool tasks it spawns are attributed to it.
    _pipeline_ctx.root_tid = threading.get_ident()
    set_event_queue(q)
    try:
        request_obj = UserRequest(text=text, country=country or "not specified")
        state = ConversationState(user_request=request_obj)
        push_event("pipeline_start", {"status": "initializing"})

        raw = LEGAL_GRAPH.invoke(state)
        # invoke() may hand back a plain dict of state fields instead of the model.
        final = ConversationState(**raw) if isinstance(raw, dict) else raw

        q.put({"type": "result", "data": _build_result(final)})
    except Exception as exc:
        # Top-level boundary: record the traceback server-side and report the
        # failure to the client instead of letting the thread die silently.
        logger.exception("Legal pipeline failed")
        q.put({"type": "error", "data": {"message": str(exc)}})
    finally:
        q.put(None)
        _pipeline_ctx.root_tid = None
        clear_event_queue()


def _format_sse(item: dict) -> str:
    """Render one queued ``{"type": ..., "data": ...}`` message as an SSE frame.

    If the payload cannot be JSON-encoded, an ``error`` frame is returned
    instead of raising. An exception here would abort the HTTP stream
    mid-flight, and a browser EventSource client would then try to reconnect,
    re-requesting /analyze and re-running the whole pipeline.
    """
    try:
        payload = json.dumps(item["data"], ensure_ascii=False)
    except (TypeError, ValueError) as exc:
        logger.exception("Could not serialize %r event", item["type"])
        message = {"message": f"Could not serialize '{item['type']}' event: {exc}"}
        return f"event: error\ndata: {json.dumps(message, ensure_ascii=False)}\n\n"
    return f"event: {item['type']}\ndata: {payload}\n\n"


@app.route("/")
def index():
    """Serve the single-page UI."""
    return render_template("index.html")


@app.route("/analyze", methods=["GET"])
def analyze():
    """Start a pipeline run and stream its events to the client as SSE.

    Query parameters:
        text: The user's request (defaults to ``""``).
        country: Jurisdiction hint (defaults to ``"not specified"``).

    The stream ends with an ``event: done`` frame once the pipeline thread
    posts its ``None`` sentinel. While the pipeline is quiet, ``event: ping``
    frames are sent every ``_PING_INTERVAL_SECONDS`` seconds.
    """
    text = request.args.get("text", "")
    country = request.args.get("country", "not specified")
    q = queue.Queue()
    threading.Thread(target=_run_pipeline, args=(text, country, q), daemon=True).start()

    def generate():
        while True:
            try:
                item = q.get(timeout=_PING_INTERVAL_SECONDS)
            except queue.Empty:
                # Keep-alive so browsers/proxies do not drop an idle stream.
                yield "event: ping\ndata: {}\n\n"
                continue
            if item is None:
                yield "event: done\ndata: {}\n\n"
                break
            yield _format_sse(item)

    return Response(
        stream_with_context(generate()),
        mimetype="text/event-stream",
        # Disable client caching and nginx-style proxy buffering so events arrive promptly.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


if __name__ == "__main__":
    # threaded=True lets concurrent /analyze streams be served at the same time.
    app.run(debug=False, host="0.0.0.0", port=7860, threaded=True)