# NOTE: the lines that preceded this file ("Spaces:", "Sleeping", file size,
# commit hash, line-number gutter) were Hugging Face Spaces file-viewer chrome
# captured alongside the source, not part of the program.
from __future__ import annotations
import uuid
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from schemas import AgentRunRequest, AgentRunResponse, Message
from memory_mongo import memory_store # MongoDB-backed memory
from graph import build_graph
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from fastapi.responses import StreamingResponse
import json
import time
from fastapi.encoders import jsonable_encoder
# FastAPI application object; title/version are surfaced in the OpenAPI docs.
app = FastAPI(title="PharmAI Navigator (Agentic)", version="0.1.0")
# CORS (HF Spaces + your Node proxy)
# NOTE(review): wildcard allow_origins combined with allow_credentials=True is
# not honored by browsers (the CORS spec forbids "*" with credentials, and
# Starlette's middleware will not echo the origin in that combination) —
# confirm whether credentialed cross-origin requests are actually required,
# and if so list the concrete origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Compile graph once at startup so every request reuses the same compiled
# LangGraph instead of rebuilding it per call.
GRAPH = build_graph()
@app.get("/health")
def health():
    """Liveness probe that also reports MongoDB connectivity and session count."""
    mongo_state = "connected"
    sessions = 0
    try:
        sessions = memory_store.get_session_count()
    except Exception as exc:
        # Surface the failure in the payload instead of failing the probe.
        mongo_state = f"error: {str(exc)}"
    return {
        "status": "ok",
        "mongodb": mongo_state,
        "active_sessions": sessions,
    }
@app.get("/session/{session_id}/history")
def get_session_history(session_id: str):
    """Get chat history for a session (for testing)."""
    history = memory_store.get(session_id)

    def _preview(text: str) -> str:
        # Truncate long messages so the response payload stays small.
        return text[:100] + "..." if len(text) > 100 else text

    return {
        "session_id": session_id,
        "message_count": len(history),
        "messages": [
            {"role": msg.role, "content": _preview(msg.content)} for msg in history
        ],
    }
@app.delete("/session/{session_id}")
def clear_session(session_id: str):
    """Drop every stored message for the given session (testing helper)."""
    memory_store.clear(session_id)
    return {"session_id": session_id, "status": "cleared"}
@app.post("/admin/cleanup-sessions")
def cleanup_old_sessions(days: int = 7):
    """
    Admin endpoint to manually cleanup old sessions.
    (TTL index handles this automatically if configured)
    """
    try:
        removed = memory_store.cleanup_old_sessions(days=days)
    except Exception as exc:
        # Map any storage failure to an HTTP 500 with the underlying message.
        raise HTTPException(status_code=500, detail=str(exc))
    return {"status": "ok", "deleted_sessions": removed, "days": days}
@app.post("/test/echo")
def test_echo(req: AgentRunRequest):
    """
    Lightweight test endpoint - no LLM calls, just tests memory.
    Echoes back the query and shows session history.
    """
    # Reuse the caller's session id, or mint a fresh one.
    session_id = req.session_id if req.session_id else str(uuid.uuid4())

    # Snapshot history BEFORE appending, so the count reflects prior turns only.
    prior_messages = memory_store.get(session_id)
    memory_store.append(session_id, role="user", content=req.query)

    fake_response = f"Echo: {req.query} (Session has {len(prior_messages)} prior messages)"
    memory_store.append(session_id, role="assistant", content=fake_response)

    return {
        "session_id": session_id,
        "decision_brief": fake_response,
        "prior_message_count": len(prior_messages),
        "current_message_count": len(memory_store.get(session_id)),
        "citations": [],
        "metadata": {"test_mode": True},
    }
@app.post("/run", response_model=AgentRunResponse)
def run_agent(req: AgentRunRequest):
    """Run the agent graph synchronously for one user query, with session memory."""
    # 1) session handling
    session_id = req.session_id or str(uuid.uuid4())

    # 2) load prior history (for chat continuity) and convert each stored
    #    message to the matching LangChain message class; unknown roles are
    #    skipped, matching the original if/elif chain.
    prior = memory_store.get(session_id)
    role_to_cls = {
        "user": HumanMessage,
        "assistant": AIMessage,
        "system": SystemMessage,
    }
    messages = [
        role_to_cls[m.role](content=m.content) for m in prior if m.role in role_to_cls
    ]

    # 3) persist this user query to memory (pre-run), then add it to the
    #    LangChain message list handed to the graph.
    memory_store.append(session_id, role="user", content=req.query)
    messages.append(HumanMessage(content=req.query))

    # 4) run graph (Mode A synchronous)
    try:
        final_state = GRAPH.invoke(
            {
                "session_id": session_id,
                "user_query": req.query,
                "messages": messages,
            }
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Agent run failed: {str(exc)}")

    # Prefer "decision_brief", fall back to "final_decision", else empty string.
    decision_brief = final_state.get("decision_brief") or final_state.get("final_decision") or ""
    citations = final_state.get("citations") or []

    # 5) save assistant response to memory (post-run); skip blank output.
    if decision_brief.strip():
        memory_store.append(session_id, role="assistant", content=decision_brief)

    return AgentRunResponse(
        session_id=session_id,
        decision_brief=decision_brief,
        confidence_score=final_state.get("confidence_score"),
        citations=citations,
        metadata={
            "has_prior_messages": len(prior) > 0,
        },
    )
# (end of file — trailing "|" gutter character from the file viewer removed)