Meshyboi committed on
Commit
a695743
·
verified ·
1 Parent(s): 0cd3dc5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +76 -30
main.py CHANGED
@@ -1,5 +1,8 @@
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
3
  from pydantic import BaseModel
4
  from typing import Dict, Any, Optional, List
5
  from contextlib import asynccontextmanager
@@ -9,16 +12,24 @@ from utils.config import settings
9
  # Import the LangGraph app
10
  from core.graph_workflow import app as graph_app
11
 
 
 
12
  @asynccontextmanager
13
  async def lifespan(app: FastAPI):
14
  # Startup: Ensure containers are running
15
- print("Starting Qdrant container...")
16
- start_qdrant_docker()
17
- print("Starting Neo4j container...")
18
- start_neo4j_docker()
19
-
20
- # Pre-loading models happens inside the services instantiated in graph_workflow
 
 
 
 
 
21
 
 
22
  yield
23
 
24
  # Shutdown: Stop containers
@@ -86,30 +97,65 @@ def clear_api_cache():
86
 
87
@app.post("/search")
async def search_amendments(request: SearchRequest):
    """Execute the LangGraph workflow for *request.query* and return the result.

    Any failure during graph execution is surfaced as an HTTP 500 whose
    detail carries the original error message.
    """
    try:
        # Run the workflow to completion from a fresh initial state.
        state = await graph_app.ainvoke({"query": request.query, "retry_count": 0})

        draft = state["draft_answer"]
        critique = state["critique"]

        # Shape the final graph state into the API response payload.
        return {
            "query": state["query"],
            "answer": draft.get("answer"),
            "constitutional_status": draft.get("constitutional_status"),
            "confidence": critique.get("final_confidence"),
            "sources": draft.get("sources"),
            "quality_grade": critique.get("quality_grade"),
            "execution_trace": state["trace"],
            # PASS RAW DATA FOR UI HIGHLIGHTING
            "graph_nodes": state.get("graph_results", []),
            "vector_chunks": [
                {"id": chunk["id"], "text": chunk["text"], "metadata": chunk["metadata"]}
                for chunk in state.get("retrieved_chunks", [])
            ],
        }
    except Exception as e:
        print(f"Graph execution error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  @app.get("/graph/explore")
115
  def get_graph_exploration(limit: int = 1000):
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse
4
+ import json
5
+ import asyncio
6
  from pydantic import BaseModel
7
  from typing import Dict, Any, Optional, List
8
  from contextlib import asynccontextmanager
 
12
  # Import the LangGraph app
13
  from core.graph_workflow import app as graph_app
14
 
15
# Emitted at import time so deployment logs show the backend module is loading.
print("🚀 Nexus Lex Backend: Initializing...")
17
  @asynccontextmanager
18
  async def lifespan(app: FastAPI):
19
  # Startup: Ensure containers are running
20
+ if settings.QDRANT_ENDPOINT:
21
+ print("📍 [STARTUP] Cloud Qdrant target detected.")
22
+ else:
23
+ print("⚠️ [STARTUP] No Qdrant endpoint found. Attempting local Docker lifecycle...")
24
+ start_qdrant_docker()
25
+
26
+ if settings.NEO4J_URI:
27
+ print("📍 [STARTUP] Cloud Neo4j target detected.")
28
+ else:
29
+ print("⚠️ [STARTUP] No Neo4j URI found. Attempting local Docker lifecycle...")
30
+ start_neo4j_docker()
31
 
32
+ print("✅ [STARTUP] Initialization sequence complete. App is live.")
33
  yield
34
 
35
  # Shutdown: Stop containers
 
97
 
98
@app.post("/search")
async def search_amendments(request: SearchRequest):
    """Stream the LangGraph workflow for *request.query* as newline-delimited JSON.

    Emits one ``{"type": "trace", ...}`` event per completed graph node so the
    UI terminal can render live progress, then a single ``{"type": "result",
    "payload": ...}`` event with the fully formatted answer.  Failures are
    reported in-stream as a ``{"type": "error", ...}`` event (the HTTP status
    is already committed once streaming has begun).
    """

    async def event_generator():
        try:
            initial_state = {"query": request.query, "retry_count": 0}
            # Accumulates each node's incremental update into the final state.
            final_state = {}

            # stream_mode="updates" yields, per node, only the state delta that
            # node produced — not the whole state.
            async for event in graph_app.astream(initial_state, stream_mode="updates"):
                for node_name, state_update in event.items():
                    # A node may return no update (None); normalize so the
                    # .get() calls and final_state.update() below stay safe.
                    state_update = state_update or {}
                    final_state.update(state_update)

                    # Default trace line for the UI terminal.
                    trace_msg = f">> {node_name.replace('_', ' ').title()} sequence complete."

                    # Add descriptive context for the nodes we know about.
                    if node_name == "classify":
                        # `or {}` guards a node that stored an explicit None.
                        entities = (state_update.get("classification") or {}).get("entities", {})
                        arts = entities.get("articles", [])
                        if arts:
                            trace_msg = f">> Classified: Targeting Articles {', '.join(map(str, arts))}"
                    elif node_name == "graph_plan":
                        results = state_update.get("graph_results") or []
                        trace_msg = f">> Neo4j: Retrieved {len(results)} legal relationships."
                    elif node_name == "fetch_vector":
                        chunks = state_update.get("raw_chunks") or []
                        trace_msg = f">> Qdrant: Retrieved {len(chunks)} semantic text blocks."
                    elif node_name == "reason":
                        trace_msg = ">> Legal Reasoner: Synthesizing final opinion..."

                    yield json.dumps({"type": "trace", "message": trace_msg}) + "\n"
                    # Small sleep to ensure the UI feels sequential and "live".
                    await asyncio.sleep(0.1)

            # astream("updates") never emits the complete state in one event,
            # so the result is formatted from the accumulated final_state.
            # `or {}` protects against keys present with an explicit None value
            # (dict.get's default only applies when the key is missing).
            draft = final_state.get("draft_answer") or {}
            critique = final_state.get("critique") or {}

            result_payload = {
                "query": request.query,
                "answer": draft.get("answer"),
                "constitutional_status": draft.get("constitutional_status"),
                "confidence": critique.get("final_confidence"),
                "sources": draft.get("sources"),
                "quality_grade": critique.get("quality_grade"),
                "execution_trace": final_state.get("trace", []),
                # PASS RAW DATA FOR UI HIGHLIGHTING
                "graph_nodes": final_state.get("graph_results", []),
                "vector_chunks": [
                    {"id": c["id"], "text": c["text"], "metadata": c["metadata"]}
                    for c in final_state.get("retrieved_chunks") or []
                ],
            }

            yield json.dumps({"type": "result", "payload": result_payload}) + "\n"

        except Exception as e:
            # The response is already streaming; surface the failure in-band.
            print(f"Streaming error: {e}")
            yield json.dumps({"type": "error", "message": str(e)}) + "\n"

    return StreamingResponse(event_generator(), media_type="application/x-ndjson")
159
 
160
  @app.get("/graph/explore")
161
  def get_graph_exploration(limit: int = 1000):