| from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks |
| from pydantic import BaseModel |
| from typing import List, Optional |
| from ..schemas.research import ( |
| ResearchQueryRequest, ResearchQueryResponse, SessionStatusResponse, |
| ) |
| from ..schemas.council import AgentConfig, CouncilConfigRequest, DEFAULT_AGENTS |
| from ..core.auth import get_current_user |
| from ..core.firebase import get_db |
| from ..workers.pipeline import run_research_pipeline |
| import uuid |
| from datetime import datetime, timezone |
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for ``POST /research/{session_id}/chat``.

    Attributes are documented inline below.
    """

    # The user's question for the chat agent.
    message: str
    # Prior conversation turns; send_chat_message forwards this (or [] when
    # it is None) to answer_research_question unchanged.
    history: Optional[List[dict]] = []
|
|
|
|
| def _snake_to_camel(s: str) -> str: |
| parts = s.split("_") |
| return parts[0] + "".join(p.capitalize() for p in parts[1:]) |
|
|
|
|
def _to_camel(obj):
    """Recursively convert dict keys from snake_case to camelCase.

    Dicts get every key renamed with ``_snake_to_camel`` and their values
    converted recursively. Lists are rebuilt with each element converted.
    Anything else, including tuples and scalars, is returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            _snake_to_camel(key): _to_camel(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return list(map(_to_camel, obj))
    return obj
|
|
# All endpoints in this module are mounted under /research and grouped under
# the "research" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/research", tags=["research"])
|
|
|
|
@router.post("/query", response_model=ResearchQueryResponse)
async def submit_query(
    body: ResearchQueryRequest,
    background_tasks: BackgroundTasks,
    user: dict = Depends(get_current_user),
):
    """Start a new research session for ``body.query`` and return its id.

    A fresh UUID4 session id is generated, and
    ``run_research_pipeline(session_id, body.query, uid)`` is scheduled as a
    FastAPI background task, so the pipeline runs after this response is sent.
    Progress can then be read through the ``/{session_id}/...`` endpoints.

    Returns:
        ResearchQueryResponse carrying the new ``session_id``.
    """
    session_id = str(uuid.uuid4())
    # Log only the first 8 characters of the session id and the user id.
    print(f"[API] /research/query called — session {session_id[:8]} user {user['uid'][:8]}")
    background_tasks.add_task(run_research_pipeline, session_id, body.query, user["uid"])
    return ResearchQueryResponse(session_id=session_id)
|
|
|
|
@router.get("/{session_id}/status", response_model=SessionStatusResponse)
async def get_status(
    session_id: str,
    user: dict = Depends(get_current_user),
):
    """Report a research session's status and paper count to its owner.

    Raises:
        HTTPException: 404 if no ``research_sessions`` document exists,
            403 if its ``userId`` differs from the caller's uid.
    """
    snapshot = get_db().collection("research_sessions").document(session_id).get()
    if not snapshot.exists:
        raise HTTPException(status_code=404, detail="Session not found")

    session = snapshot.to_dict()
    if session.get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    # "status" is required (a KeyError if absent); "paperCount" may be missing.
    status_value = session["status"]
    paper_total = session.get("paperCount")
    return SessionStatusResponse(
        session_id=session_id,
        status=status_value,
        paper_count=paper_total,
    )
|
|
|
|
@router.get("/{session_id}/papers")
async def get_papers(session_id: str, user: dict = Depends(get_current_user)):
    """Return every stored paper for ``session_id`` with camelCase keys.

    Response shape is ``{"papers": [...]}``. Each item is the Firestore document
    data merged over ``{"id": <document id>}``, so a stored ``id`` field wins,
    with keys converted by ``_to_camel``.

    Raises:
        HTTPException: 403 if the session document exists and its ``userId`` is
            not the caller's uid.
    """
    db = get_db()

    # Ownership check, using the same rule as get_status. A missing session
    # document is let through (presumably the background pipeline has not
    # written it yet), so early polling keeps its previous behavior.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    docs = db.collection("papers").where("sessionId", "==", session_id).stream()
    papers = [_to_camel({"id": d.id, **d.to_dict()}) for d in docs]
    return {"papers": papers}
|
|
|
|
@router.get("/{session_id}/report")
async def get_report(session_id: str, user: dict = Depends(get_current_user)):
    """Return one report document for ``session_id`` with camelCase keys.

    Reports are looked up by a ``session_id`` field first and then by
    ``sessionId``, because both key styles are queried here. At most one
    document is read.

    Raises:
        HTTPException: 403 if the session document exists and belongs to another
            user; 404 if no report document matches yet.
    """
    db = get_db()

    # Ownership check; a not-yet-written session document is let through.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    reports = db.collection("reports")
    docs = list(reports.where("session_id", "==", session_id).limit(1).stream())
    if not docs:
        docs = list(reports.where("sessionId", "==", session_id).limit(1).stream())
    if not docs:
        raise HTTPException(status_code=404, detail="Report not ready yet")

    first = docs[0]
    report = {"id": first.id, **first.to_dict()}
    return {"report": _to_camel(report)}
|
|
|
|
@router.post("/{session_id}/regenerate-report")
async def regenerate_report(session_id: str, background_tasks: BackgroundTasks, user: dict = Depends(get_current_user)):
    """Re-run report generation for an existing completed session (fixes static fallback reports).

    After verifying ownership, this schedules a background task that:

    1. Rebuilds ``Paper`` models from the session's ``papers`` documents,
       skipping any that fail validation.
    2. Calls ``generate_report``.
    3. Saves the result under ``reports/<report.id>``.
    4. Deletes any other report documents for the session.

    Step 4 is needed because ``get_report`` reads a single match without any
    ordering by recency, so a leftover report could keep shadowing the new one.

    Raises:
        HTTPException: 404 if the session does not exist, 403 if it belongs to
            another user.
    """
    db = get_db()
    doc = db.collection("research_sessions").document(session_id).get()
    if not doc.exists:
        raise HTTPException(status_code=404, detail="Session not found")
    data = doc.to_dict()
    if data.get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    query = data.get("query", "")

    def _first_set(raw: dict, *keys):
        """Return the first value under ``keys`` that is not None (0 and 0.0 count as set)."""
        for key in keys:
            value = raw.get(key)
            if value is not None:
                return value
        return None

    async def _regen():
        from ..services.reporting.report_generator import generate_report
        from ..schemas.research import Paper, PaperExtraction

        # Paper documents may use camelCase or snake_case keys, so both are tried.
        papers = []
        for pd in db.collection("papers").where("sessionId", "==", session_id).stream():
            raw = pd.to_dict()
            try:
                ext_raw = raw.get("extraction") or {}
                extraction = PaperExtraction(**ext_raw) if ext_raw else None
                papers.append(Paper(
                    id=raw.get("id", pd.id),
                    session_id=session_id,
                    external_source=raw.get("externalSource") or raw.get("external_source") or "arxiv",
                    source_paper_id=raw.get("sourcePaperId") or raw.get("source_paper_id") or "",
                    title=raw.get("title", ""),
                    authors=raw.get("authors", []),
                    year=raw.get("year"),
                    abstract=raw.get("abstract", ""),
                    # _first_set (not `or`) so a stored 0 / 0.0 is kept instead of becoming None.
                    citation_count=_first_set(raw, "citationCount", "citation_count"),
                    relevance_score=_first_set(raw, "relevanceScore", "relevance_score"),
                    extraction=extraction,
                ))
            except Exception as e:
                print(f"[Regen] Skipping paper: {e}")

        print(f"[Regen] Regenerating report for session {session_id[:8]} with {len(papers)} papers")
        report = await generate_report(query, papers, session_id)
        db.collection("reports").document(report.id).set(report.model_dump())
        print(f"[Regen] Report saved: {report.id}")

        # Remove older report documents for this session only after the new one
        # is saved, so get_report (limit(1), no recency ordering) serves the new one.
        # Best-effort: a failure here leaves the new report in place.
        try:
            for field in ("session_id", "sessionId"):
                for old in db.collection("reports").where(field, "==", session_id).stream():
                    if old.id != report.id:
                        old.reference.delete()
        except Exception as e:
            print(f"[Regen] Could not remove stale reports: {e}")

    background_tasks.add_task(_regen)
    return {"status": "regenerating", "message": "Report is being regenerated. Refresh in ~30 seconds."}
|
|
|
|
@router.get("/{session_id}/graph")
async def get_graph(session_id: str, user: dict = Depends(get_current_user)):
    """Return the stored ``graphs/<session_id>`` document as ``{"graph": ...}``.

    Raises:
        HTTPException: 403 if the session document exists and belongs to another
            user; 404 if the graph document does not exist yet.
    """
    db = get_db()

    # Ownership check; a not-yet-written session document is let through.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    doc = db.collection("graphs").document(session_id).get()
    if not doc.exists:
        raise HTTPException(status_code=404, detail="Graph not ready yet")
    return {"graph": doc.to_dict()}
|
|
|
|
@router.get("/{session_id}/neo4j-graph")
async def get_neo4j_graph(session_id: str, user: dict = Depends(get_current_user)):
    """Returns the knowledge graph directly from Neo4j AuraDB.

    Raises:
        HTTPException: 403 if the session document exists and belongs to another
            user; 503 when ``get_graph_for_session`` raises ``RuntimeError``;
            500 for any other failure while importing or querying.
    """
    db = get_db()

    # Ownership check, kept outside the try so the 403 is not caught by the
    # generic handler below and turned into a 500.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    try:
        from ..services.graph.neo4j_writer import get_graph_for_session
        data = get_graph_for_session(session_id)
    except RuntimeError as e:
        raise HTTPException(status_code=503, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Neo4j query failed: {e}") from e
    return {"graph": data}
|
|
|
|
@router.get("/{session_id}/contradictions")
async def get_contradictions(session_id: str, user: dict = Depends(get_current_user)):
    """Return all contradiction documents for ``session_id`` with camelCase keys.

    Documents are matched on a ``session_id`` field first. Only if none match
    is ``sessionId`` tried. An empty list is returned when neither matches.

    Raises:
        HTTPException: 403 if the session document exists and belongs to another user.
    """
    db = get_db()

    # Ownership check; a not-yet-written session document is let through.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    contradictions = db.collection("contradictions")
    docs = list(contradictions.where("session_id", "==", session_id).stream())
    if not docs:
        docs = list(contradictions.where("sessionId", "==", session_id).stream())
    return {"contradictions": [_to_camel({"id": d.id, **d.to_dict()}) for d in docs]}
|
|
|
|
| |
|
|
@router.get("/{session_id}/council")
async def get_council(session_id: str, user: dict = Depends(get_current_user)):
    """Return the council debate state for ``session_id`` as ``{"council": ...}``.

    When no ``council_sessions`` document exists, an "idle" placeholder is
    returned instead of a 404. It uses ``DEFAULT_AGENTS``, empty rounds, and no
    supervisor.

    Raises:
        HTTPException: 403 if the research session document exists and belongs to
            another user.
    """
    db = get_db()

    # Ownership check; a not-yet-written session document is let through.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    doc = db.collection("council_sessions").document(session_id).get()
    if not doc.exists:
        return {
            "council": {
                "session_id": session_id,
                "status": "idle",
                "agents": [a.model_dump() for a in DEFAULT_AGENTS],
                "round_1": [],
                "round_2": [],
                "round_3": [],
                "supervisor": None,
            }
        }
    return {"council": doc.to_dict()}
|
|
|
|
@router.post("/{session_id}/council/configure")
async def configure_council(
    session_id: str,
    body: CouncilConfigRequest,
    user: dict = Depends(get_current_user),
):
    """Store an agent line-up for the council and mark it "configured".

    This is allowed only while the council is "idle" (including when no
    document exists yet) or already "configured". The council document is
    overwritten with the new agents, empty rounds, no supervisor, and a
    UTC ``created_at`` timestamp.

    Raises:
        HTTPException: 403 if the research session document exists and belongs to
            another user; 409 if the council is past the configurable states.
    """
    db = get_db()

    # Ownership check; a not-yet-written session document is let through.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    council_ref = db.collection("council_sessions").document(session_id)
    doc = council_ref.get()
    current = doc.to_dict() if doc.exists else {}
    current_status = current.get("status", "idle")

    if current_status not in ("idle", "configured"):
        raise HTTPException(status_code=409, detail="Council has already started — cannot reconfigure.")

    agents = [a.model_dump() for a in body.agents]
    council_ref.set({
        "session_id": session_id,
        "status": "configured",
        "agents": agents,
        "round_1": [],
        "round_2": [],
        "round_3": [],
        "supervisor": None,
        "created_at": datetime.now(timezone.utc).isoformat(),
    })
    return {"status": "configured", "agents": agents}
|
|
|
|
@router.post("/{session_id}/council/start")
async def start_council(
    session_id: str,
    background_tasks: BackgroundTasks,
    user: dict = Depends(get_current_user),
):
    """Manually trigger the council debate (if user wants to start it before the pipeline reaches that stage).

    The steps are:

    1. Rebuild ``Paper`` models from the session's stored papers, skipping
       any that fail validation.
    2. Load the configured agents, falling back to ``DEFAULT_AGENTS``.
    3. Schedule ``run_debate_council`` as a background task with
       ``settings.council_rounds`` rounds.

    If the council is neither "idle" nor "configured", nothing is scheduled and
    its current status is returned.

    Raises:
        HTTPException: 403 if the research session document exists and belongs to
            another user; 404 if no council document exists.
    """
    db = get_db()

    # Ownership check; a not-yet-written session document is let through.
    # The same snapshot also supplies the research query below.
    session_doc = db.collection("research_sessions").document(session_id).get()
    session_data = session_doc.to_dict() or {}
    if session_doc.exists and session_data.get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    doc = db.collection("council_sessions").document(session_id).get()
    if not doc.exists:
        raise HTTPException(status_code=404, detail="No council session found. Run the research pipeline first.")

    cd = doc.to_dict()
    if cd.get("status") not in ("idle", "configured"):
        return {"status": cd.get("status"), "message": "Council is already running or complete."}

    from ..schemas.research import Paper, PaperExtraction

    def _first_set(raw: dict, *keys):
        """Return the first value under ``keys`` that is not None (0 and 0.0 count as set)."""
        for key in keys:
            value = raw.get(key)
            if value is not None:
                return value
        return None

    # Paper documents may use camelCase or snake_case keys, so both are tried.
    papers = []
    for pd in db.collection("papers").where("sessionId", "==", session_id).stream():
        raw = pd.to_dict()
        try:
            ext_raw = raw.get("extraction") or {}
            extraction = PaperExtraction(**ext_raw) if ext_raw else None
            papers.append(Paper(
                id=raw.get("id", pd.id),
                session_id=session_id,
                external_source=raw.get("externalSource") or raw.get("external_source") or "arxiv",
                source_paper_id=raw.get("sourcePaperId") or raw.get("source_paper_id") or "",
                title=raw.get("title", ""),
                authors=raw.get("authors", []),
                year=raw.get("year"),
                abstract=raw.get("abstract", ""),
                # _first_set (not `or`) so a stored 0 / 0.0 is kept instead of becoming None.
                citation_count=_first_set(raw, "citationCount", "citation_count"),
                relevance_score=_first_set(raw, "relevanceScore", "relevance_score"),
                extraction=extraction,
            ))
        except Exception as e:
            print(f"[Council start] Skipping paper: {e}")

    agents = [AgentConfig(**a) for a in cd.get("agents", [])] or DEFAULT_AGENTS
    query = session_data.get("query", "")

    from ..core.config import get_settings
    from ..services.council.council_runner import run_debate_council

    settings = get_settings()

    async def _run_council():
        await run_debate_council(session_id, query, papers, agents, db, num_rounds=settings.council_rounds)

    background_tasks.add_task(_run_council)
    return {"status": "started", "message": "Council debate is running in the background."}
|
|
|
|
| |
|
|
@router.get("/{session_id}/chat")
async def get_chat_history(session_id: str, user: dict = Depends(get_current_user)):
    """Return the session's stored chat messages ordered by ``created_at``.

    Each message is the raw Firestore data plus its document ``id``. Unlike
    the other read endpoints, keys are not camelCased.

    Raises:
        HTTPException: 403 if the session document exists and belongs to another user.
    """
    db = get_db()

    # Ownership check; a not-yet-written session document is let through.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    docs = (
        db.collection("chat_messages")
        .where("session_id", "==", session_id)
        .order_by("created_at")
        .stream()
    )
    messages = [{"id": d.id, **d.to_dict()} for d in docs]
    return {"messages": messages}
|
|
|
|
@router.post("/{session_id}/chat")
async def send_chat_message(
    session_id: str,
    body: ChatRequest,
    user: dict = Depends(get_current_user),
):
    """Answer a question about the session's papers and persist the exchange.

    The steps are:

    1. If the session has no stored papers, return a fixed "please wait"
       answer without calling the agent or persisting anything.
    2. Otherwise, pass the raw paper dicts and ``body.history`` to
       ``answer_research_question``. Agent errors are reported as the answer
       text rather than raised.
    3. Store both the question and the answer in ``chat_messages``
       (best-effort; a storage failure is only logged).

    Raises:
        HTTPException: 403 if the session document exists and belongs to another user.
    """
    from ..services.chat.chat_agent import answer_research_question

    db = get_db()

    # Ownership check; a not-yet-written session document is let through.
    session_doc = db.collection("research_sessions").document(session_id).get()
    if session_doc.exists and (session_doc.to_dict() or {}).get("userId") != user["uid"]:
        raise HTTPException(status_code=403, detail="Forbidden")

    paper_docs = db.collection("papers").where("sessionId", "==", session_id).stream()
    papers = [pd.to_dict() for pd in paper_docs]

    if not papers:
        return {"answer": "No papers found for this session yet. Please wait for the pipeline to complete."}

    # The question is stamped before the agent call and the answer after it.
    # get_chat_history sorts by created_at, so a shared timestamp would leave
    # the question/answer order arbitrary.
    asked_at = datetime.now(timezone.utc).isoformat()
    try:
        answer = await answer_research_question(
            question=body.message,
            papers=papers,
            history=body.history or [],
        )
    except Exception as e:
        print(f"[Chat] Error: {e}")
        answer = f"An error occurred while processing your question: {e}"
    answered_at = datetime.now(timezone.utc).isoformat()

    try:
        messages = db.collection("chat_messages")
        messages.document(str(uuid.uuid4())).set({
            "session_id": session_id,
            "role": "user",
            "content": body.message,
            "created_at": asked_at,
        })
        messages.document(str(uuid.uuid4())).set({
            "session_id": session_id,
            "role": "assistant",
            "content": answer,
            "created_at": answered_at,
        })
    except Exception as e:
        print(f"[Chat] Could not persist message: {e}")

    return {"answer": answer}
|
|