from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import json
import asyncio
from pydantic import BaseModel
from typing import Dict, Any, Optional, List
from contextlib import asynccontextmanager

from services.qdrant import start_qdrant_docker, stop_qdrant_docker, get_qdrant_client
from services.neo4j import start_neo4j_docker, stop_neo4j_docker, get_neo4j_driver
from utils.config import settings

# Import the compiled LangGraph workflow (the agent pipeline).
from core.graph_workflow import app as graph_app

print("🚀 Nexus Lex Backend: Initializing...")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Startup: when no cloud endpoint is configured for Qdrant/Neo4j, spin up
    local Docker containers instead. Shutdown: stop those containers.
    """
    if settings.QDRANT_ENDPOINT:
        print("📍 [STARTUP] Cloud Qdrant target detected.")
    else:
        print("⚠️ [STARTUP] No Qdrant endpoint found. Attempting local Docker lifecycle...")
        start_qdrant_docker()

    if settings.NEO4J_URI:
        print("📍 [STARTUP] Cloud Neo4j target detected.")
    else:
        print("⚠️ [STARTUP] No Neo4j URI found. Attempting local Docker lifecycle...")
        start_neo4j_docker()

    print("✅ [STARTUP] Initialization sequence complete. App is live.")
    yield
    # Shutdown: stop containers (the stop helpers are expected to no-op when
    # cloud targets were used — confirm against services.* implementations).
    print("Stopping Qdrant container...")
    stop_qdrant_docker()
    print("Stopping Neo4j container...")
    stop_neo4j_docker()


app = FastAPI(lifespan=lifespan)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers under the CORS spec; this works for local development
# only. In production use specific origins like ["http://localhost:5173"].
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global in-memory cache for static/slow dashboard data.
# Each key is None until its endpoint is first hit; /cache/clear resets all.
CACHE = {
    "graph_explore": None,
    "stats": None,
    "timeline": None
}


class SearchRequest(BaseModel):
    """Request payload for POST /search."""
    query: str
    limit: int = 5
    filters: Optional[Dict[str, Any]] = None


@app.get("/")
def read_root():
    """Liveness check for the API root."""
    return {"message": "Constitution Agent API (LangGraph Edition) is running"}


@app.get("/health")
def health_check():
    """Probe Qdrant and Neo4j connectivity and report per-service status.

    Returns "connected" per service, or the error string; overall status
    flips to "unhealthy" if either probe fails.
    """
    health_status = {"status": "healthy", "qdrant": "unknown", "neo4j": "unknown"}
    try:
        client = get_qdrant_client()
        client.get_collections()  # raises if Qdrant is unreachable
        health_status["qdrant"] = "connected"
    except Exception as e:
        health_status["qdrant"] = str(e)
        health_status["status"] = "unhealthy"
    try:
        driver = get_neo4j_driver()
        with driver.session() as session:
            session.run("RETURN 1")
        health_status["neo4j"] = "connected"
    except Exception as e:
        health_status["neo4j"] = str(e)
        health_status["status"] = "unhealthy"
    return health_status


@app.post("/cache/clear")
def clear_api_cache():
    """Wipes the in-memory cache for all dashboard data."""
    # Mutating (not rebinding) the module-level dict — no `global` needed.
    for key in CACHE:
        CACHE[key] = None
    return {"status": "Cache Cleared"}


def _node_trace_message(node_name: str, state_update: Dict[str, Any]) -> str:
    """Build the human-readable trace line for one workflow node update."""
    # Generic fallback, overridden with node-specific context where available.
    trace_msg = f">> {node_name.replace('_', ' ').title()} sequence complete."
    if node_name == "classify":
        # `or {}` guards against an explicitly-None classification value.
        entities = (state_update.get("classification") or {}).get("entities", {})
        arts = entities.get("articles", [])
        if arts:
            trace_msg = f">> Classified: Targeting Articles {', '.join(map(str, arts))}"
    elif node_name == "graph_plan":
        results = state_update.get("graph_results", [])
        trace_msg = f">> Neo4j: Retrieved {len(results)} legal relationships."
    elif node_name == "fetch_vector":
        chunks = state_update.get("raw_chunks", [])
        trace_msg = f">> Qdrant: Retrieved {len(chunks)} semantic text blocks."
    elif node_name == "reason":
        trace_msg = ">> Legal Reasoner: Synthesizing final opinion..."
    return trace_msg


@app.post("/search")
async def search_amendments(request: SearchRequest):
    """Run the LangGraph workflow for a query, streaming NDJSON events.

    Emits {"type": "trace"} events as each node completes, then one
    {"type": "result"} event with the accumulated final state, or a single
    {"type": "error"} event on failure.
    """
    async def event_generator():
        try:
            initial_state = {"query": request.query, "retry_count": 0}
            final_state = {}
            # stream_mode="updates" yields the incremental changes from each
            # node, which we fold into a local tracker.
            async for event in graph_app.astream(initial_state, stream_mode="updates"):
                for node_name, state_update in event.items():
                    final_state.update(state_update)
                    # Yield a trace event for the UI terminal.
                    trace_msg = _node_trace_message(node_name, state_update)
                    yield json.dumps({"type": "trace", "message": trace_msg}) + "\n"
                    # Small sleep to ensure the UI feels sequential and "live".
                    await asyncio.sleep(0.1)

            # LangGraph's "updates" stream never hands us the whole state in
            # one event, so format the accumulated final_state ourselves.
            # `or {}` keeps us safe when a node wrote an explicit None —
            # dict.get's default only covers the *missing* key case.
            draft = final_state.get("draft_answer") or {}
            critique = final_state.get("critique") or {}
            result_payload = {
                "query": request.query,
                "answer": draft.get("answer"),
                "constitutional_status": draft.get("constitutional_status"),
                "confidence": critique.get("final_confidence"),
                "sources": draft.get("sources"),
                "quality_grade": critique.get("quality_grade"),
                "execution_trace": final_state.get("trace", []),
                "graph_nodes": final_state.get("graph_results", []),
                # NOTE(review): traces read "raw_chunks" but the payload reads
                # "retrieved_chunks" — presumably distinct workflow keys;
                # verify against core.graph_workflow.
                "vector_chunks": [
                    {"id": c["id"], "text": c["text"], "metadata": c["metadata"]}
                    for c in final_state.get("retrieved_chunks", [])
                ]
            }
            yield json.dumps({"type": "result", "payload": result_payload}) + "\n"
        except Exception as e:
            # Boundary handler: surface the failure to the client as an event.
            print(f"Streaming error: {e}")
            yield json.dumps({"type": "error", "message": str(e)}) + "\n"

    return StreamingResponse(event_generator(), media_type="application/x-ndjson")
def _graph_node(record, side: str):
    """Build one visualization node dict from a query record.

    `side` is "source" or "target"; returns (node_id, node_dict).
    Display name prefers .number, then .id, then the internal node id —
    explicit None checks so a falsy 0 is not skipped.
    """
    node_id = str(record[f"{side}_id"])
    name = record[f"{side}_num"]
    if name is None:
        name = record[f"{side}_name"]
    if name is None:
        name = node_id
    label = record[f"{side}_label"]
    return node_id, {
        "id": node_id,
        "label": label,
        "name": str(name),
        "type": label,
        "year": record[f"{side}_year"]
    }


@app.get("/graph/explore")
def get_graph_exploration(limit: int = 1000):
    """Returns a subset of nodes and edges for the 3D visualization.

    Fix: the original accepted `limit` but hard-coded `LIMIT 2000` in the
    Cypher — it is now passed as a query parameter and honored. The cache
    is keyed by limit so different limits don't serve each other's data;
    /cache/clear still wipes the whole slot.
    """
    cache = CACHE["graph_explore"] or {}
    if limit in cache:
        return cache[limit]

    driver = get_neo4j_driver()
    nodes = []
    links = []
    query = """
    MATCH (n)-[r]->(m)
    RETURN id(n) as source_id, labels(n)[0] as source_label,
           n.number as source_num, n.id as source_name,
           toInteger(n.year) as source_year,
           type(r) as rel_type,
           id(m) as target_id, labels(m)[0] as target_label,
           m.number as target_num, m.id as target_name,
           toInteger(m.year) as target_year
    ORDER BY source_label ASC, source_year ASC
    LIMIT $limit
    """
    try:
        with driver.session() as session:
            result = session.run(query, limit=limit)
            seen_nodes = set()
            for record in result:
                for side in ("source", "target"):
                    node_id, node = _graph_node(record, side)
                    if node_id not in seen_nodes:
                        nodes.append(node)
                        seen_nodes.add(node_id)
                links.append({
                    "source": str(record["source_id"]),
                    "target": str(record["target_id"]),
                    "type": record["rel_type"]
                })

        result_data = {"nodes": nodes, "links": links}
        cache[limit] = result_data
        CACHE["graph_explore"] = cache
        return result_data
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Neo4j Error: {str(e)}")


@app.get("/stats")
def get_stats():
    """Returns database statistics for the dashboard. Result is cached."""
    if CACHE["stats"]:
        return CACHE["stats"]

    driver = get_neo4j_driver()
    stats = {}
    # (key, Cypher) pairs: per-label node counts plus total relationships.
    count_queries = [
        ("articles", "MATCH (a:Article) RETURN count(a) as n"),
        ("amendments", "MATCH (am:Amendment) RETURN count(am) as n"),
        ("clauses", "MATCH (c:Clause) RETURN count(c) as n"),
        ("total_connections", "MATCH ()-[r]->() RETURN count(r) as n"),
    ]
    try:
        with driver.session() as session:
            for key, cypher in count_queries:
                stats[key] = session.run(cypher).single()["n"]
        CACHE["stats"] = stats
        return stats
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/timeline")
def get_timeline():
    """Returns amendment activity by year. Result is cached."""
    if CACHE["timeline"]:
        return CACHE["timeline"]

    driver = get_neo4j_driver()
    try:
        with driver.session() as session:
            res = session.run("""
                MATCH (am:Amendment)
                WHERE am.year IS NOT NULL
                RETURN am.year as year, count(am) as count
                ORDER BY year ASC
            """)
            data = [record.data() for record in res]
        CACHE["timeline"] = data
        return data
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))