Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
|
| 3 |
from pydantic import BaseModel
|
| 4 |
from typing import Dict, Any, Optional, List
|
| 5 |
from contextlib import asynccontextmanager
|
|
@@ -9,16 +12,24 @@ from utils.config import settings
|
|
| 9 |
# Import the LangGraph app
|
| 10 |
from core.graph_workflow import app as graph_app
|
| 11 |
|
|
|
|
|
|
|
| 12 |
@asynccontextmanager
|
| 13 |
async def lifespan(app: FastAPI):
|
| 14 |
# Startup: Ensure containers are running
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
|
|
|
| 22 |
yield
|
| 23 |
|
| 24 |
# Shutdown: Stop containers
|
|
@@ -86,30 +97,65 @@ def clear_api_cache():
|
|
| 86 |
|
| 87 |
@app.post("/search")
|
| 88 |
async def search_amendments(request: SearchRequest):
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
@app.get("/graph/explore")
|
| 115 |
def get_graph_exploration(limit: int = 1000):
|
|
|
|
| 1 |
from fastapi import FastAPI, HTTPException
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from fastapi.responses import StreamingResponse
|
| 4 |
+
import json
|
| 5 |
+
import asyncio
|
| 6 |
from pydantic import BaseModel
|
| 7 |
from typing import Dict, Any, Optional, List
|
| 8 |
from contextlib import asynccontextmanager
|
|
|
|
| 12 |
# Import the LangGraph app
|
| 13 |
from core.graph_workflow import app as graph_app
|
| 14 |
|
| 15 |
+
print("🚀 Nexus Lex Backend: Initializing...")
|
| 16 |
+
|
| 17 |
@asynccontextmanager
|
| 18 |
async def lifespan(app: FastAPI):
|
| 19 |
# Startup: Ensure containers are running
|
| 20 |
+
if settings.QDRANT_ENDPOINT:
|
| 21 |
+
print("📍 [STARTUP] Cloud Qdrant target detected.")
|
| 22 |
+
else:
|
| 23 |
+
print("⚠️ [STARTUP] No Qdrant endpoint found. Attempting local Docker lifecycle...")
|
| 24 |
+
start_qdrant_docker()
|
| 25 |
+
|
| 26 |
+
if settings.NEO4J_URI:
|
| 27 |
+
print("📍 [STARTUP] Cloud Neo4j target detected.")
|
| 28 |
+
else:
|
| 29 |
+
print("⚠️ [STARTUP] No Neo4j URI found. Attempting local Docker lifecycle...")
|
| 30 |
+
start_neo4j_docker()
|
| 31 |
|
| 32 |
+
print("✅ [STARTUP] Initialization sequence complete. App is live.")
|
| 33 |
yield
|
| 34 |
|
| 35 |
# Shutdown: Stop containers
|
|
|
|
| 97 |
|
| 98 |
@app.post("/search")
|
| 99 |
async def search_amendments(request: SearchRequest):
|
| 100 |
+
async def event_generator():
|
| 101 |
+
try:
|
| 102 |
+
initial_state = {"query": request.query, "retry_count": 0}
|
| 103 |
+
final_state = {}
|
| 104 |
+
|
| 105 |
+
# Use astream to get updates after every node
|
| 106 |
+
# stream_mode="updates" gives us the incremental changes from each node
|
| 107 |
+
async for event in graph_app.astream(initial_state, stream_mode="updates"):
|
| 108 |
+
for node_name, state_update in event.items():
|
| 109 |
+
# Update local state tracker
|
| 110 |
+
final_state.update(state_update)
|
| 111 |
+
|
| 112 |
+
# Yield a trace event for the UI terminal
|
| 113 |
+
trace_msg = f">> {node_name.replace('_', ' ').title()} sequence complete."
|
| 114 |
+
|
| 115 |
+
# Add descriptive context if available
|
| 116 |
+
if node_name == "classify":
|
| 117 |
+
entities = state_update.get("classification", {}).get("entities", {})
|
| 118 |
+
arts = entities.get("articles", [])
|
| 119 |
+
if arts: trace_msg = f">> Classified: Targeting Articles {', '.join(map(str, arts))}"
|
| 120 |
+
elif node_name == "graph_plan":
|
| 121 |
+
results = state_update.get("graph_results", [])
|
| 122 |
+
trace_msg = f">> Neo4j: Retrieved {len(results)} legal relationships."
|
| 123 |
+
elif node_name == "fetch_vector":
|
| 124 |
+
chunks = state_update.get("raw_chunks", [])
|
| 125 |
+
trace_msg = f">> Qdrant: Retrieved {len(chunks)} semantic text blocks."
|
| 126 |
+
elif node_name == "reason":
|
| 127 |
+
trace_msg = ">> Legal Reasoner: Synthesizing final opinion..."
|
| 128 |
+
|
| 129 |
+
yield json.dumps({"type": "trace", "message": trace_msg}) + "\n"
|
| 130 |
+
# Small sleep to ensure the UI feels sequential and "live"
|
| 131 |
+
await asyncio.sleep(0.1)
|
| 132 |
+
|
| 133 |
+
# We need to ensure we have the full final state for the result
|
| 134 |
+
# LangGraph's astream "updates" doesn't necessarily give the final dictionary in one go
|
| 135 |
+
# So we format the accumulated final_state
|
| 136 |
+
|
| 137 |
+
result_payload = {
|
| 138 |
+
"query": request.query,
|
| 139 |
+
"answer": final_state.get("draft_answer", {}).get("answer"),
|
| 140 |
+
"constitutional_status": final_state.get("draft_answer", {}).get("constitutional_status"),
|
| 141 |
+
"confidence": final_state.get("critique", {}).get("final_confidence"),
|
| 142 |
+
"sources": final_state.get("draft_answer", {}).get("sources"),
|
| 143 |
+
"quality_grade": final_state.get("critique", {}).get("quality_grade"),
|
| 144 |
+
"execution_trace": final_state.get("trace", []),
|
| 145 |
+
"graph_nodes": final_state.get("graph_results", []),
|
| 146 |
+
"vector_chunks": [
|
| 147 |
+
{"id": c["id"], "text": c["text"], "metadata": c["metadata"]}
|
| 148 |
+
for c in final_state.get("retrieved_chunks", [])
|
| 149 |
+
]
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
yield json.dumps({"type": "result", "payload": result_payload}) + "\n"
|
| 153 |
+
|
| 154 |
+
except Exception as e:
|
| 155 |
+
print(f"Streaming error: {e}")
|
| 156 |
+
yield json.dumps({"type": "error", "message": str(e)}) + "\n"
|
| 157 |
+
|
| 158 |
+
return StreamingResponse(event_generator(), media_type="application/x-ndjson")
|
| 159 |
|
| 160 |
@app.get("/graph/explore")
|
| 161 |
def get_graph_exploration(limit: int = 1000):
|