Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| import json | |
| import asyncio | |
| from pydantic import BaseModel | |
| from typing import Dict, Any, Optional, List | |
| from contextlib import asynccontextmanager | |
| from services.qdrant import start_qdrant_docker, stop_qdrant_docker, get_qdrant_client | |
| from services.neo4j import start_neo4j_docker, stop_neo4j_docker, get_neo4j_driver | |
| from utils.config import settings | |
| # Import the LangGraph app | |
| from core.graph_workflow import app as graph_app | |
# Printed once at import time, before the FastAPI app object below is created.
print("🚀 Nexus Lex Backend: Initializing...")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage local service containers for the lifetime of the FastAPI app.

    On startup, each backing store is handled according to configuration:
    if ``settings.QDRANT_ENDPOINT`` / ``settings.NEO4J_URI`` is set, a cloud
    target is assumed and nothing is started. Otherwise the corresponding
    local Docker container is started via ``start_*_docker()``.

    On shutdown (or if startup fails part-way), only the containers this
    function actually started are stopped. A configured cloud target never
    triggers a local ``stop_*_docker()`` call.
    """
    started_qdrant = False
    started_neo4j = False
    try:
        # Startup: ensure containers are running when no remote target is configured.
        if settings.QDRANT_ENDPOINT:
            print("📍 [STARTUP] Cloud Qdrant target detected.")
        else:
            print("⚠️ [STARTUP] No Qdrant endpoint found. Attempting local Docker lifecycle...")
            start_qdrant_docker()
            started_qdrant = True

        if settings.NEO4J_URI:
            print("📍 [STARTUP] Cloud Neo4j target detected.")
        else:
            print("⚠️ [STARTUP] No Neo4j URI found. Attempting local Docker lifecycle...")
            start_neo4j_docker()
            started_neo4j = True

        print("✅ [STARTUP] Initialization sequence complete. App is live.")
        yield
    finally:
        # Shutdown: stop only what we started. The finally also covers a crash
        # while serving and a failure between the two start calls.
        if started_qdrant:
            print("Stopping Qdrant container...")
            stop_qdrant_docker()
        if started_neo4j:
            print("Stopping Neo4j container...")
            stop_neo4j_docker()
app = FastAPI(lifespan=lifespan)

# CORS: accept every origin, method and header.
# The wildcard origin is intended for development; in production use specific
# origins such as ["http://localhost:5173"].
# NOTE(review): the Starlette docs discourage combining allow_origins=["*"] with
# allow_credentials=True — revisit before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Process-wide in-memory cache for static/slow dashboard payloads.
# Every slot starts as None, meaning "not computed yet".
CACHE = dict.fromkeys(("graph_explore", "stats", "timeline"))
class SearchRequest(BaseModel):
    """Request body accepted by ``search_amendments``."""

    # Search text; forwarded to the LangGraph workflow as the "query" state key.
    query: str
    # NOTE(review): `limit` and `filters` are accepted but not read by
    # search_amendments() in this file — confirm whether the graph should receive them.
    limit: int = 5
    filters: Optional[Dict[str, Any]] = None
def read_root():
    """Return a fixed message confirming the API process is up.

    NOTE(review): no route decorator is visible for this handler in this file;
    confirm how it is registered on ``app``.
    """
    banner = "Constitution Agent API (LangGraph Edition) is running"
    return dict(message=banner)
def health_check():
    """Probe Qdrant and Neo4j and report per-service connectivity.

    Returns:
        A dict with keys ``status``, ``qdrant`` and ``neo4j``. Each service
        entry is ``"connected"`` when its probe succeeds, or the exception
        text when it fails. ``status`` is ``"healthy"`` unless any probe
        failed, in which case it is ``"unhealthy"``. This function never
        raises; failures are reported in the returned dict.
    """
    health_status = {"status": "healthy", "qdrant": "unknown", "neo4j": "unknown"}

    try:
        # Listing collections forces a real round-trip to Qdrant.
        get_qdrant_client().get_collections()
        health_status["qdrant"] = "connected"
    except Exception as e:
        health_status["qdrant"] = str(e)
        health_status["status"] = "unhealthy"

    try:
        driver = get_neo4j_driver()
        with driver.session() as session:
            # consume() drains the result so the probe waits for the server's full reply.
            session.run("RETURN 1").consume()
        health_status["neo4j"] = "connected"
    except Exception as e:
        health_status["neo4j"] = str(e)
        health_status["status"] = "unhealthy"

    return health_status
def clear_api_cache():
    """Reset every slot of the in-memory dashboard cache to None.

    The next call to each cached endpoint recomputes its data.
    """
    # Mutating the module-level dict in place needs no `global` declaration.
    CACHE.update(dict.fromkeys(CACHE))
    return {"status": "Cache Cleared"}
async def search_amendments(request: SearchRequest):
    """Run the LangGraph workflow for ``request.query`` and stream progress as NDJSON.

    Each line of the response body is one JSON object:

    * ``{"type": "trace", "message": str}``: emitted after each graph node update;
    * ``{"type": "result", "payload": dict}``: emitted once, after the graph finishes;
    * ``{"type": "error", "message": str}``: emitted instead of the result if anything raises.

    Only ``request.query`` is passed to the graph; ``limit`` and ``filters``
    are not read here.
    """
    async def event_generator():
        try:
            initial_state = {"query": request.query, "retry_count": 0}
            final_state = {}
            # stream_mode="updates" yields {node_name: partial_state} after each node,
            # so the overall state is rebuilt locally by merging those partial updates.
            # NOTE(review): this is a shallow merge — if the graph state uses reducers
            # (e.g. list-appending channels), a merged key holds only the last update.
            async for event in graph_app.astream(initial_state, stream_mode="updates"):
                for node_name, state_update in event.items():
                    # A node may produce no (or a non-dict) update; treat it as empty
                    # rather than crashing the merge below.
                    if not isinstance(state_update, dict):
                        state_update = {}
                    final_state.update(state_update)

                    trace_msg = _describe_node_update(node_name, state_update)
                    yield json.dumps({"type": "trace", "message": trace_msg}) + "\n"
                    # Small pause so the UI renders trace lines sequentially ("live").
                    await asyncio.sleep(0.1)

            result_payload = _build_search_payload(request.query, final_state)
            yield json.dumps({"type": "result", "payload": result_payload}) + "\n"
        except Exception as e:
            # Report failures in-band as a final NDJSON line for the client's stream reader.
            print(f"Streaming error: {e}")
            yield json.dumps({"type": "error", "message": str(e)}) + "\n"

    return StreamingResponse(event_generator(), media_type="application/x-ndjson")


def _describe_node_update(node_name, update):
    """Build the human-readable trace line shown after graph node ``node_name`` reports ``update``."""
    if node_name == "classify":
        # `or {}` / `or []` also cover keys that are present but set to None.
        entities = (update.get("classification") or {}).get("entities") or {}
        arts = entities.get("articles") or []
        if arts:
            return f">> Classified: Targeting Articles {', '.join(map(str, arts))}"
    elif node_name == "graph_plan":
        results = update.get("graph_results") or []
        return f">> Neo4j: Retrieved {len(results)} legal relationships."
    elif node_name == "fetch_vector":
        chunks = update.get("raw_chunks") or []
        return f">> Qdrant: Retrieved {len(chunks)} semantic text blocks."
    elif node_name == "reason":
        return ">> Legal Reasoner: Synthesizing final opinion..."
    # Generic fallback, e.g. "some_node" -> ">> Some Node sequence complete."
    return f">> {node_name.replace('_', ' ').title()} sequence complete."


def _build_search_payload(query, final_state):
    """Assemble the final "result" payload from the merged graph state."""
    # `or {}` guards against keys that exist but hold None.
    draft = final_state.get("draft_answer") or {}
    critique = final_state.get("critique") or {}
    return {
        "query": query,
        "answer": draft.get("answer"),
        "constitutional_status": draft.get("constitutional_status"),
        "confidence": critique.get("final_confidence"),
        "sources": draft.get("sources"),
        "quality_grade": critique.get("quality_grade"),
        "execution_trace": final_state.get("trace", []),
        "graph_nodes": final_state.get("graph_results", []),
        "vector_chunks": [
            {"id": c["id"], "text": c["text"], "metadata": c["metadata"]}
            for c in final_state.get("retrieved_chunks") or []
        ],
    }
def get_graph_exploration(limit: int = 1000):
    """Return a subset of nodes and edges for the 3D visualization.

    Runs one Cypher query over ``(n)-[r]->(m)`` relationships, ordered by
    source label then source year, and keeps at most ``limit`` of them.
    Previously the parameter was ignored and ``LIMIT 2000`` was hard-coded.

    Returns:
        ``{"nodes": [...], "links": [...]}``. Each node has ``id`` (Neo4j
        internal id as a string), ``label``/``type`` (its first label),
        ``name`` (``number``, else ``id`` property, else the internal id) and
        ``year``. Each link has ``source``, ``target`` and ``type``.

    The result for the most recently requested ``limit`` is cached in
    ``CACHE["graph_explore"]``.

    Raises:
        HTTPException: 500 if acquiring the driver or running the query fails.
    """
    limit = max(int(limit), 0)  # Cypher LIMIT rejects negative values.

    # The slot holds a (limit, data) pair, so a request for a different limit
    # is never answered from a cache built for another one.
    cached = CACHE["graph_explore"]
    if isinstance(cached, tuple) and cached[0] == limit:
        return cached[1]

    query = """
    MATCH (n)-[r]->(m)
    RETURN
        id(n) as source_id, labels(n)[0] as source_label, n.number as source_num, n.id as source_name, toInteger(n.year) as source_year,
        type(r) as rel_type,
        id(m) as target_id, labels(m)[0] as target_label, m.number as target_num, m.id as target_name, toInteger(m.year) as target_year
    ORDER BY source_label ASC, source_year ASC
    LIMIT $limit
    """
    nodes = []
    links = []
    seen_nodes = set()
    try:
        driver = get_neo4j_driver()
        with driver.session() as session:
            # Pass the limit as a query parameter instead of formatting it into the text.
            for record in session.run(query, limit=limit):
                s_id = _add_graph_node(record, "source", nodes, seen_nodes)
                t_id = _add_graph_node(record, "target", nodes, seen_nodes)
                links.append({
                    "source": s_id,
                    "target": t_id,
                    "type": record["rel_type"],
                })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Neo4j Error: {str(e)}") from e

    result_data = {"nodes": nodes, "links": links}
    CACHE["graph_explore"] = (limit, result_data)
    return result_data


def _add_graph_node(record, prefix, nodes, seen_nodes):
    """Append the ``prefix`` ("source"/"target") endpoint of ``record`` to ``nodes`` once; return its id."""
    node_id = str(record[f"{prefix}_id"])
    if node_id not in seen_nodes:
        # Explicit None checks so a legitimate falsy value such as 0 is kept as the name.
        name = record[f"{prefix}_num"]
        if name is None:
            name = record[f"{prefix}_name"]
        if name is None:
            name = node_id
        nodes.append({
            "id": node_id,
            "label": record[f"{prefix}_label"],
            "name": str(name),
            "type": record[f"{prefix}_label"],
            "year": record[f"{prefix}_year"],
        })
        seen_nodes.add(node_id)
    return node_id
def get_stats():
    """Return node and relationship counts for the dashboard. Result is cached.

    Returns:
        A dict with keys ``articles``, ``amendments`` and ``clauses`` (counts
        of nodes with those labels) and ``total_connections`` (count of all
        relationships). The dict is stored in ``CACHE["stats"]`` on success.

    Raises:
        HTTPException: 500 if any count query fails.
    """
    if CACHE["stats"]:
        return CACHE["stats"]

    # (key in the returned dict, Cypher query, column that query returns)
    count_queries = (
        ("articles", "MATCH (a:Article) RETURN count(a) as articles", "articles"),
        ("amendments", "MATCH (am:Amendment) RETURN count(am) as amendments", "amendments"),
        ("clauses", "MATCH (c:Clause) RETURN count(c) as clauses", "clauses"),
        ("total_connections", "MATCH ()-[r]->() RETURN count(r) as connections", "connections"),
    )

    driver = get_neo4j_driver()
    try:
        with driver.session() as session:
            stats = {
                key: session.run(cypher).single()[column]
                for key, cypher, column in count_queries
            }
        CACHE["stats"] = stats
        return stats
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def get_timeline():
    """Return amendment activity by year, oldest first. Result is cached.

    Returns:
        A list of ``{"year": ..., "count": int}`` dicts, one per distinct
        non-null ``Amendment.year``, ordered by year ascending. The list is
        stored in ``CACHE["timeline"]`` on success.

    Raises:
        HTTPException: 500 if acquiring the driver or running the query fails.
    """
    cached = CACHE["timeline"]
    # Compare against None: an empty timeline ([]) is a valid result and should
    # be served from the cache rather than re-queried on every call.
    if cached is not None:
        return cached

    try:
        driver = get_neo4j_driver()
        with driver.session() as session:
            res = session.run("""
            MATCH (am:Amendment)
            WHERE am.year IS NOT NULL
            RETURN am.year as year, count(am) as count
            ORDER BY year ASC
            """)
            data = [record.data() for record in res]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    CACHE["timeline"] = data
    return data