Spaces:
Sleeping
Sleeping
Added main app files
Browse files- .dockerignore +33 -0
- Dockerfile +30 -0
- app.py +164 -0
- graph.py +479 -0
- memory.py +91 -0
- memory_mongo.py +260 -0
- requirements.txt +32 -0
- schemas.py +98 -0
- tools.py +288 -0
.dockerignore
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
venv/
|
| 9 |
+
ENV/
|
| 10 |
+
|
| 11 |
+
# Testing
|
| 12 |
+
.pytest_cache
|
| 13 |
+
*.log
|
| 14 |
+
|
| 15 |
+
# IDE
|
| 16 |
+
.vscode
|
| 17 |
+
.idea
|
| 18 |
+
*.swp
|
| 19 |
+
*.swo
|
| 20 |
+
|
| 21 |
+
# Test files
|
| 22 |
+
test_*.py
|
| 23 |
+
test_*.ps1
|
| 24 |
+
*_test.py
|
| 25 |
+
fix_indexes.py
|
| 26 |
+
|
| 27 |
+
# Documentation
|
| 28 |
+
*.md
|
| 29 |
+
MEMORY_COMPARISON.md
|
| 30 |
+
|
| 31 |
+
# Output files
|
| 32 |
+
*.png
|
| 33 |
+
output.txt
|
Dockerfile
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use Python 3.11 slim image
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
# Set working directory
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Install system dependencies for graphviz (optional but recommended)
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
graphviz \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
# Copy requirements first for better caching
|
| 13 |
+
COPY requirements.txt .
|
| 14 |
+
|
| 15 |
+
# Install Python dependencies
|
| 16 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 17 |
+
pip install --no-cache-dir -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy application code
|
| 20 |
+
COPY . .
|
| 21 |
+
|
| 22 |
+
# Expose port 7860 (HuggingFace Spaces default)
|
| 23 |
+
EXPOSE 7860
|
| 24 |
+
|
| 25 |
+
# Set environment variables
|
| 26 |
+
ENV PYTHONUNBUFFERED=1
|
| 27 |
+
ENV PORT=7860
|
| 28 |
+
|
| 29 |
+
# Run the FastAPI application
|
| 30 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import uuid
|
| 3 |
+
from fastapi import FastAPI, HTTPException
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from schemas import AgentRunRequest, AgentRunResponse, Message
|
| 6 |
+
from memory_mongo import memory_store # MongoDB-backed memory
|
| 7 |
+
from graph import build_graph
|
| 8 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 9 |
+
from fastapi.responses import StreamingResponse
|
| 10 |
+
import json
|
| 11 |
+
import time
|
| 12 |
+
from fastapi.encoders import jsonable_encoder
|
| 13 |
+
|
| 14 |
+
app = FastAPI(title="PharmAI Navigator (Agentic)", version="0.1.0")
|
| 15 |
+
|
| 16 |
+
# CORS (HF Spaces + your Node proxy)
|
| 17 |
+
app.add_middleware(
|
| 18 |
+
CORSMiddleware,
|
| 19 |
+
allow_origins=["*"],
|
| 20 |
+
allow_credentials=True,
|
| 21 |
+
allow_methods=["*"],
|
| 22 |
+
allow_headers=["*"],
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Compile graph once at startup
|
| 26 |
+
GRAPH = build_graph()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@app.get("/health")
def health():
    """Health check with MongoDB status."""
    session_count = 0
    mongo_status = "connected"
    try:
        session_count = memory_store.get_session_count()
    except Exception as e:
        # Surface the Mongo failure in the payload instead of failing the probe.
        mongo_status = f"error: {str(e)}"

    return {
        "status": "ok",
        "mongodb": mongo_status,
        "active_sessions": session_count,
    }
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@app.get("/session/{session_id}/history")
def get_session_history(session_id: str):
    """Get chat history for a session (for testing)."""
    history = memory_store.get(session_id)

    def _preview(text: str) -> str:
        # Truncate long messages so the response payload stays small.
        return text[:100] + "..." if len(text) > 100 else text

    return {
        "session_id": session_id,
        "message_count": len(history),
        "messages": [{"role": msg.role, "content": _preview(msg.content)} for msg in history],
    }
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@app.delete("/session/{session_id}")
def clear_session(session_id: str):
    """Clear a session's history (for testing)."""
    memory_store.clear(session_id)
    result = {"session_id": session_id, "status": "cleared"}
    return result
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@app.post("/admin/cleanup-sessions")
def cleanup_old_sessions(days: int = 7):
    """
    Admin endpoint to manually cleanup old sessions.
    (TTL index handles this automatically if configured)
    """
    try:
        deleted_count = memory_store.cleanup_old_sessions(days=days)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "status": "ok",
        "deleted_sessions": deleted_count,
        "days": days,
    }
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@app.post("/test/echo")
def test_echo(req: AgentRunRequest):
    """
    Lightweight test endpoint - no LLM calls, just tests memory.
    Echoes back the query and shows session history.
    """
    session_id = req.session_id or str(uuid.uuid4())

    # History as it stood BEFORE this request was recorded.
    prior = memory_store.get(session_id)
    memory_store.append(session_id, role="user", content=req.query)

    fake_response = f"Echo: {req.query} (Session has {len(prior)} prior messages)"
    memory_store.append(session_id, role="assistant", content=fake_response)

    return {
        "session_id": session_id,
        "decision_brief": fake_response,
        "prior_message_count": len(prior),
        "current_message_count": len(memory_store.get(session_id)),
        "citations": [],
        "metadata": {"test_mode": True},
    }
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@app.post("/run", response_model=AgentRunResponse)
def run_agent(req: AgentRunRequest):
    """Run the agent graph synchronously for one user query (Mode A)."""
    # 1) session handling
    session_id = req.session_id or str(uuid.uuid4())

    # 2) load prior history (for chat continuity) and convert the stored
    #    role/content records into LangChain messages for MessagesState.
    prior = memory_store.get(session_id)
    role_to_cls = {
        "user": HumanMessage,
        "assistant": AIMessage,
        "system": SystemMessage,
    }
    messages = [
        role_to_cls[m.role](content=m.content)
        for m in prior
        if m.role in role_to_cls
    ]

    # 3) persist this user query (pre-run), then add it to the LLM context
    memory_store.append(session_id, role="user", content=req.query)
    messages.append(HumanMessage(content=req.query))

    # 4) run graph (Mode A synchronous)
    try:
        final_state = GRAPH.invoke(
            {
                "session_id": session_id,
                "user_query": req.query,
                "messages": messages,
            }
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Agent run failed: {str(e)}")

    decision_brief = final_state.get("decision_brief") or final_state.get("final_decision") or ""
    citations = final_state.get("citations") or []

    # 5) save assistant response to memory (post-run); empty briefs are not stored
    if decision_brief.strip():
        memory_store.append(session_id, role="assistant", content=decision_brief)

    return AgentRunResponse(
        session_id=session_id,
        decision_brief=decision_brief,
        confidence_score=final_state.get("confidence_score"),
        citations=citations,
        metadata={
            "has_prior_messages": len(prior) > 0,
        },
    )
|
graph.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
from langchain_anthropic import ChatAnthropic
|
| 10 |
+
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
|
| 11 |
+
from langchain_core.tools import tool
|
| 12 |
+
|
| 13 |
+
from langgraph.graph import StateGraph, START, END, MessagesState
|
| 14 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
| 15 |
+
|
| 16 |
+
from tools import (
|
| 17 |
+
tavily_search,
|
| 18 |
+
stub_evidence,
|
| 19 |
+
classify_query,
|
| 20 |
+
extract_entities,
|
| 21 |
+
normalize_evidence,
|
| 22 |
+
generate_graph_dot,
|
| 23 |
+
clinicaltrials_search,
|
| 24 |
+
render_dot_to_png_base64
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Load environment variables
|
| 28 |
+
load_dotenv()
|
| 29 |
+
|
| 30 |
+
# -----------------------------
|
| 31 |
+
# LangChain Tool Wrappers
|
| 32 |
+
# -----------------------------
|
| 33 |
+
# NOTE: each wrapper's docstring doubles as the LLM-facing tool description
# (the @tool decorator exposes it in the tool schema), so the docstrings are
# left untouched; commentary lives in # comments only.
@tool("web_search")
def web_search_tool(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Web search using Tavily. Returns a list of evidence dicts."""
    ev = tavily_search(query=query, max_results=max_results)
    # Evidence objects are pydantic models; serialize for the tool protocol.
    return [e.model_dump() for e in ev]


@tool("stub_evidence")
def stub_evidence_tool(query: str) -> List[Dict[str, Any]]:
    """Deterministic fallback evidence tool (offline/demo)."""
    ev = stub_evidence(query=query)
    return [e.model_dump() for e in ev]

@tool("classify_query")
def classify_query_tool(query: str) -> Dict[str, Any]:
    """Classify query to decide which tools are needed."""
    return classify_query(query)


@tool("extract_entities")
def extract_entities_tool(query: str) -> Dict[str, Optional[str]]:
    """Extract drug and indication from query."""
    return extract_entities(query)


@tool("normalize_evidence")
def normalize_evidence_tool(evidence: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Dedupe and clean evidence."""
    return normalize_evidence(evidence)


@tool("generate_graph_dot")
def generate_graph_dot_tool(
    title: str,
    nodes: List[Dict[str, str]],
    edges: List[Dict[str, str]],
    rankdir: str = "LR",
) -> str:
    """
    Generate Graphviz DOT.
    IMPORTANT: Use this tool instead of writing DOT directly.
    """
    return generate_graph_dot(
        title=title,
        nodes=nodes,
        edges=edges,
        rankdir=rankdir,
    )

@tool("clinicaltrials_search")
def clinicaltrials_search_tool(drug: str, indication: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Search ClinicalTrials.gov (Tavily-based MVP)."""
    ev = clinicaltrials_search(drug=drug, indication=indication, max_results=max_results)
    return [e.model_dump() for e in ev]

@tool("render_dot_to_png_base64")
def render_dot_to_png_base64_tool(dot: str) -> Dict[str, Any]:
    """Render DOT to PNG (base64). Optional dependency on graphviz."""
    return render_dot_to_png_base64(dot)

# Full tool belt bound to the model in _build_model(); order is cosmetic.
TOOLS = [
    web_search_tool,
    stub_evidence_tool,
    classify_query_tool,
    extract_entities_tool,
    normalize_evidence_tool,
    generate_graph_dot_tool,
    clinicaltrials_search_tool,
    render_dot_to_png_base64_tool
]
|
| 103 |
+
|
| 104 |
+
# -----------------------------
|
| 105 |
+
# LangGraph State
|
| 106 |
+
# -----------------------------
|
| 107 |
+
class PharmAIState(MessagesState):
    """Graph state shared by all nodes (extends MessagesState's `messages`)."""
    session_id: Optional[str]  # caller-provided chat session key, if any
    user_query: str  # raw query text for this run
    decision_brief: str  # final synthesized answer returned to the caller
    citations: List[str]  # tool-derived URLs (single source of truth)
    confidence_score: float  # NOTE(review): set by nodes not shown here — confirm producer
    tool_loops: int # safety counter
    diagram_png_base64: Optional[str]  # <-- add  (set by capture_diagram from render tool output)
    diagram_dot: Optional[str]  # <-- optional  (set by capture_diagram from generate_graph_dot)
    intent: str # "simple" | "diligence" | "diagram"
|
| 117 |
+
|
| 118 |
+
# -----------------------------
|
| 119 |
+
# Guardrails + Prompts
|
| 120 |
+
# -----------------------------
|
| 121 |
+
SYSTEM_PROMPT = """You are PharmAI Navigator, an evidence-grounded diligence assistant for drug/asset evaluation.
|
| 122 |
+
|
| 123 |
+
Your job:
|
| 124 |
+
Turn a query like "Assess {Drug} for {Indication}" into a decision-grade brief OR structured output.
|
| 125 |
+
|
| 126 |
+
CRITICAL TOOL USAGE RULES:
|
| 127 |
+
- If the user asks for a diagram, flow, architecture, graph, visualization, or Graphviz:
|
| 128 |
+
→ You MUST call `generate_graph_dot`.
|
| 129 |
+
→ You MUST NOT write Graphviz DOT directly in your response.
|
| 130 |
+
→ If the user asks for an image/PNG, call `render_dot_to_png_base64` AFTER you get DOT.
|
| 131 |
+
- If the user asks for trials / phases / NCT IDs / endpoints:
|
| 132 |
+
→ Prefer calling `extract_entities` then `clinicaltrials_search`.
|
| 133 |
+
- If the user asks for factual claims (approvals, safety, pricing, patents, market):
|
| 134 |
+
→ Prefer calling `web_search`.
|
| 135 |
+
|
| 136 |
+
Guardrails (STRICT):
|
| 137 |
+
- Do NOT invent specific facts (approval dates, trial names, endpoints, statistics, patent expiry).
|
| 138 |
+
- Any concrete number/date/claim MUST be supported by tool evidence.
|
| 139 |
+
- If evidence is insufficient, clearly list Evidence Gaps.
|
| 140 |
+
- Be concise, structured, and decision-oriented.
|
| 141 |
+
- Avoid medical advice; present as diligence/analysis.
|
| 142 |
+
|
| 143 |
+
Simple Query Rule (CRITICAL):
|
| 144 |
+
- If the user asks a simple definitional question ("what is", "define", "explain") and you can answer without external verification, do NOT call tools and respond directly.
|
| 145 |
+
- Only use tools when you need current/specific data (trials, approvals, patents, market data).
|
| 146 |
+
|
| 147 |
+
Citations policy:
|
| 148 |
+
- The final response's "Citations" section is handled by the system.
|
| 149 |
+
- Do NOT create your own citation list.
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
FINAL_PROMPT = """Write the FINAL decision brief with these sections:
|
| 153 |
+
|
| 154 |
+
1) Executive Recommendation (1–2 lines)
|
| 155 |
+
2) Scientific Rationale (bullets)
|
| 156 |
+
3) Clinical Evidence Snapshot (bullets)
|
| 157 |
+
4) IP / Exclusivity Quick View (bullets)
|
| 158 |
+
5) Market / SoC Snapshot (bullets)
|
| 159 |
+
6) Key Risks + Next Actions (bullets)
|
| 160 |
+
|
| 161 |
+
Rules:
|
| 162 |
+
- If evidence is insufficient, include "Evidence Gaps" with bullets.
|
| 163 |
+
- Do NOT add a citations section yourself; the system will append it.
|
| 164 |
+
Return plain text only.
|
| 165 |
+
"""
|
| 166 |
+
|
| 167 |
+
# Placeholder detection to avoid wasting tokens on "Drug X / Indication Y"
|
| 168 |
+
PLACEHOLDER_PATTERNS = [
|
| 169 |
+
r"\bdrug\s*x\b",
|
| 170 |
+
r"\bindication\s*y\b",
|
| 171 |
+
r"\bdrug\s*name\b",
|
| 172 |
+
r"\bindication\s*name\b",
|
| 173 |
+
]
|
| 174 |
+
def _looks_like_placeholder(q: str) -> bool:
|
| 175 |
+
ql = (q or "").strip().lower()
|
| 176 |
+
return any(re.search(p, ql) for p in PLACEHOLDER_PATTERNS)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _build_model() -> ChatAnthropic:
    """Construct the Claude chat model with all PharmAI tools bound."""
    # Model name is overridable via env so releases can be swapped without code changes.
    model_name = os.getenv("ANTHROPIC_MODEL", "claude-3-7-sonnet-latest")
    base = ChatAnthropic(
        model=model_name,
        temperature=0.2,
        max_tokens=10000,
        timeout=120,
        streaming=False,
        stop=None,
    )
    return base.bind_tools(TOOLS)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Safety cap to avoid endless tool loops
|
| 192 |
+
MAX_TOOL_LOOPS = int(os.getenv("MAX_TOOL_LOOPS", "4"))
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def llm_call(state: PharmAIState) -> Dict[str, Any]:
    """
    Run one Claude turn with tool schemas attached.

    Returns the new message(s) to merge into state["messages"].
    """
    llm = _build_model()
    history: List[BaseMessage] = state["messages"]

    # Ensure the guardrail system prompt leads the conversation exactly once.
    has_system = bool(history) and isinstance(history[0], SystemMessage)
    if not has_system:
        history = [SystemMessage(content=SYSTEM_PROMPT)] + history

    # Safety valve: after MAX_TOOL_LOOPS tool rounds, force final synthesis.
    if state.get("tool_loops", 0) >= MAX_TOOL_LOOPS:
        history = history + [
            HumanMessage(
                content=(
                    "Stop calling tools now. Proceed to final synthesis using what you already have. "
                    "If evidence is insufficient, clearly list Evidence Gaps."
                )
            )
        ]

    return {"messages": [llm.invoke(history)]}
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# -----------------------------
|
| 222 |
+
# Citations extraction (tool-only)
|
| 223 |
+
# -----------------------------
|
| 224 |
+
def _clean_url(u: str) -> str:
|
| 225 |
+
return u.strip().strip("),.]}\"'")
|
| 226 |
+
|
| 227 |
+
def _extract_citations_from_messages(messages: List[BaseMessage]) -> List[str]:
    """
    Tool-only citation extraction (single source of truth):
    - ONLY reads ToolMessage contents (actual tool outputs).
    - If tool output is JSON (list/dict), pull `source` fields.
    - Fallback: regex URL extraction from tool text, used ONLY when the
      content is not parseable JSON.

    Returns de-duplicated URLs in first-seen order; clearly truncated
    URLs (< 12 chars) are dropped.
    """
    citations: List[str] = []
    url_re = re.compile(r"https?://[^\s\]\)\}\",']+")

    def _add_source(candidate: Any) -> None:
        # Only accept string URLs with an explicit http(s) scheme.
        if isinstance(candidate, str) and candidate.startswith(("http://", "https://")):
            citations.append(_clean_url(candidate))

    for m in messages:
        if not isinstance(m, ToolMessage):
            continue

        content = getattr(m, "content", None)
        # Non-string content (e.g. structured blocks) is skipped, matching
        # the prior behavior; only string payloads are inspected.
        if not content or not isinstance(content, str):
            continue

        try:
            parsed = json.loads(content)
        except Exception:
            parsed = None

        if isinstance(parsed, list):
            for item in parsed:
                if isinstance(item, dict):
                    _add_source(item.get("source"))
        elif isinstance(parsed, dict):
            _add_source(parsed.get("source"))
        else:
            # FIX: previously this regex pass ran unconditionally, even after a
            # successful JSON parse — duplicating `source` URLs and scooping up
            # incidental URLs embedded in other fields. It is now a true
            # fallback for non-JSON tool output, as documented.
            for u in url_re.findall(content):
                citations.append(_clean_url(u))

    # De-duplicate while preserving first-seen order.
    seen = set()
    out: List[str] = []
    for c in citations:
        # drop clearly broken/truncated URLs
        if len(c) < 12 or c in seen:
            continue
        seen.add(c)
        out.append(c)
    return out
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _append_citations_section(brief_text: str, citations: List[str]) -> str:
|
| 280 |
+
"""
|
| 281 |
+
Enforces "single source of truth":
|
| 282 |
+
- Removes any existing 'Citations' section the model may have produced
|
| 283 |
+
- Appends citations derived from tool outputs only
|
| 284 |
+
"""
|
| 285 |
+
text = (brief_text or "").strip()
|
| 286 |
+
|
| 287 |
+
# Remove any model-generated citations section (best-effort)
|
| 288 |
+
# (handles '## Citations' or 'Citations' headers)
|
| 289 |
+
text = re.split(r"\n#{1,3}\s*Citations\s*\n|\nCitations\s*\n", text, maxsplit=1)[0].rstrip()
|
| 290 |
+
|
| 291 |
+
if citations:
|
| 292 |
+
lines = ["", "## Citations"]
|
| 293 |
+
for i, c in enumerate(citations, 1):
|
| 294 |
+
lines.append(f"{i}. {c}")
|
| 295 |
+
text = text + "\n" + "\n".join(lines)
|
| 296 |
+
else:
|
| 297 |
+
text = text + "\n\n## Citations\n- (No external sources retrieved.)"
|
| 298 |
+
|
| 299 |
+
return text
|
| 300 |
+
|
| 301 |
+
def capture_diagram(state: PharmAIState) -> Dict[str, Any]:
    """Copy artifacts from the most recent tool output into graph state.

    Only the last ToolMessage is inspected; diagram tools are assumed to be
    the final tool call in a diagram-producing turn.
    """
    # Most recent tool output, or None if no tool has run yet.
    last_tool = next(
        (m for m in reversed(state["messages"]) if isinstance(m, ToolMessage)),
        None,
    )
    if last_tool is None:
        return {}

    tool_name = getattr(last_tool, "name", "") or ""
    content = getattr(last_tool, "content", "")

    # render tool returns the base64 PNG string directly
    if tool_name == "render_dot_to_png_base64":
        return {"diagram_png_base64": content}
    # generate_graph_dot returns the DOT source string
    if tool_name == "generate_graph_dot":
        return {"diagram_dot": content}

    return {}
|
| 324 |
+
|
| 325 |
+
def route_after_tools(state: PharmAIState) -> str:
    """Decide whether the tool loop should continue after capture_diagram."""
    # A rendered PNG means the diagram request is fulfilled — stop the graph.
    diagram_done = bool(state.get("diagram_png_base64"))
    return END if diagram_done else "bump_tool_loop"
|
| 330 |
+
|
| 331 |
+
def preprocess(state: PharmAIState) -> Dict[str, Any]:
    """Classify the user query into an intent before the first LLM call.

    Returns {"intent": ...} where intent is one of:
    - "diagram":  query asks for a diagram/graph/visualization
    - "simple":   short definitional question answerable without tools
    - "diligence": everything else (full evidence-gathering flow)
    """
    q = (state.get("user_query") or "").strip().lower()

    # FIX: use word-boundary matching. Plain substring tests misrouted real
    # queries — "withdrawal" contains "draw", "endotoxin" contains "dot" —
    # sending diligence questions down the diagram path.
    if re.search(r"\b(diagram|flowchart|architecture|graphviz|dot|draw)\b", q):
        return {"intent": "diagram"}

    # Short definitional questions skip tools entirely (see SYSTEM_PROMPT).
    if re.match(r"^(what is|define|explain)\b", q) and len(q) < 120:
        return {"intent": "simple"}

    return {"intent": "diligence"}
|
| 341 |
+
|
| 342 |
+
def route_after_llm(state: PharmAIState):
    """Pick the next node after an LLM turn: end_simple, tools, or synthesize."""
    # Simple definitional queries never enter the tool/synthesis pipeline.
    if state.get("intent") == "simple":
        return "end_simple"

    # The model requested tools iff its last message carries tool_calls.
    latest = state["messages"][-1]
    wants_tools = bool(getattr(latest, "tool_calls", None))
    return "tools" if wants_tools else "synthesize"
|
| 353 |
+
|
| 354 |
+
def end_simple(state: PharmAIState) -> Dict[str, Any]:
    """Finalize a simple (no-tools) turn: the last assistant message is the answer."""
    final_msg = state["messages"][-1]
    raw = getattr(final_msg, "content", "")
    # Anthropic content can be a list of blocks; coerce non-strings via str().
    text = raw if isinstance(raw, str) else str(raw)
    return {"decision_brief": text, "citations": []}
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# -----------------------------
|
| 362 |
+
# Final Synthesis Node
|
| 363 |
+
# -----------------------------
|
| 364 |
+
def synthesize(state: PharmAIState) -> Dict[str, Any]:
    """Produce the final decision brief, with citations taken only from tool outputs."""
    # Fast guardrail: placeholder queries get a canned response — no token burn.
    uq = state.get("user_query", "")
    if _looks_like_placeholder(uq):
        brief = (
            "# FINAL DECISION BRIEF\n\n"
            "I need the **actual drug name** and **specific indication** to perform diligence.\n\n"
            "## Evidence Gaps\n"
            "- Drug name (e.g., semaglutide)\n"
            "- Indication (e.g., obesity)\n"
            "- Trial/program context (if any)\n"
        )
        return {
            "decision_brief": _append_citations_section(brief, []),
            "citations": [],
            "messages": [HumanMessage(content="(placeholder query detected; returned guardrail response)")],
        }

    # Ask the model to write the structured final brief.
    llm = _build_model()
    convo: List[BaseMessage] = state["messages"] + [HumanMessage(content=FINAL_PROMPT)]
    resp = llm.invoke(convo)

    # Citations come exclusively from tool outputs, never from the model.
    tool_citations = _extract_citations_from_messages(state["messages"])
    raw = resp.content
    brief_text = raw if isinstance(raw, str) else str(raw)
    brief_text = _append_citations_section(brief_text, tool_citations)

    return {
        "decision_brief": brief_text,
        "citations": tool_citations,
        "messages": [resp],
    }
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
# -----------------------------
|
| 400 |
+
# Build + Compile Graph
|
| 401 |
+
# -----------------------------
|
| 402 |
+
def build_graph():
    """
    Graph with preprocessing and smart routing.

    Topology:
        START -> preprocess -> llm_call
        llm_call --(tool calls)--> tools -> capture_diagram
                                   -> END (diagram PNG ready) | bump_tool_loop -> llm_call
        llm_call --(simple intent)--> end_simple -> END
        llm_call --(otherwise)------> synthesize -> END

    Returns the compiled LangGraph app, ready for `.invoke(...)`.
    """
    g = StateGraph(PharmAIState)

    g.add_node("preprocess", preprocess)
    g.add_node("llm_call", llm_call)
    g.add_node("tools", ToolNode(TOOLS))
    g.add_node("capture_diagram", capture_diagram)
    # Inline node: increments the safety counter that llm_call checks
    # against MAX_TOOL_LOOPS to break runaway tool loops.
    g.add_node("bump_tool_loop", lambda s: {"tool_loops": s.get("tool_loops", 0) + 1})
    g.add_node("synthesize", synthesize)
    g.add_node("end_simple", end_simple)

    g.add_edge(START, "preprocess")
    g.add_edge("preprocess", "llm_call")

    # After LLM: route based on intent and tool calls
    g.add_conditional_edges(
        "llm_call",
        route_after_llm,
        {
            "tools": "tools",
            "synthesize": "synthesize",
            "end_simple": "end_simple",
        },
    )

    # After tools: capture diagram data
    g.add_edge("tools", "capture_diagram")

    # After capture: check if we should stop (diagram complete) or continue
    g.add_conditional_edges(
        "capture_diagram",
        route_after_tools,
        {
            END: END, # Stop if diagram is complete
            "bump_tool_loop": "bump_tool_loop", # Continue otherwise
        },
    )

    g.add_edge("bump_tool_loop", "llm_call")
    g.add_edge("end_simple", END)
    g.add_edge("synthesize", END)

    return g.compile()
|
| 448 |
+
|
| 449 |
+
# -----------------------------
|
| 450 |
+
# Test execution
|
| 451 |
+
# -----------------------------
|
| 452 |
+
if __name__ == "__main__":
    # Smoke test: compile the graph and run one query end-to-end.
    print("Building PharmAI Navigator graph...")
    graph = build_graph()
    print("Graph compiled successfully!")

    # Test query designed to trigger generate_graph_dot tool
    #test_query = "Assess semaglutide for obesity"
    #test_query = "Assess donanemab for early Alzheimer’s disease. Retrieve key clinical trials, summarize efficacy and safety outcomes, normalize the evidence, and generate a system architecture graph showing how PharmAI Navigator evaluates this asset."
    #test_query = "Create a DOT graph showing the relationship between Drug, Indication, Clinical Trials, FDA Approval, and Market Launch and render it as png"
    test_query = "What is pembrolizumab?"
    print(f"\nRunning test query: {test_query}")

    # Invoke with a fresh state; tool_loops starts at 0 for the safety cap.
    result = graph.invoke({
        "messages": [HumanMessage(content=test_query)],
        "user_query": test_query,
        "tool_loops": 0,
    })

    print("\n" + "=" * 60)
    print("OUTPUT:")
    print("=" * 60)
    print(result.get("decision_brief", "No output"))

    print("\n" + "=" * 60)
    print("CITATIONS (tool-only):")
    print("=" * 60)
    for i, citation in enumerate(result.get("citations", []), 1):
        print(f"{i}. {citation}")
|
memory.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations

import os
import threading
import time
from dataclasses import dataclass
from typing import Dict, List, Optional

from schemas import Message
|
| 7 |
+
|
| 8 |
+
@dataclass
class SessionMemory:
    """In-memory chat history for a single session."""
    # Ordered oldest-to-newest; trimmed by MemoryStore to its max_messages.
    messages: List[Message]
    # time.time() of the last read/write; used for TTL-based eviction.
    updated_at: float
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MemoryStore:
    """
    Simple thread-safe in-memory session store.

    - Maps session_id -> SessionMemory (a list of Message objects).
    - Trims each session to `max_messages` so memory stays bounded.
    - Optionally evicts idle sessions after `ttl_seconds` (lazy GC,
      performed on access while the lock is held).
    """

    def __init__(self, max_messages: int = 30, ttl_seconds: Optional[int] = None):
        self.max_messages = max_messages
        self.ttl_seconds = ttl_seconds
        self._lock = threading.Lock()
        self._store: Dict[str, SessionMemory] = {}

    def _now(self) -> float:
        # Single clock source, easy to stub in tests.
        return time.time()

    def _ensure_locked(self, session_id: str) -> SessionMemory:
        """Fetch-or-create the session entry. Caller must hold the lock."""
        entry = self._store.get(session_id)
        if entry is None:
            entry = SessionMemory(messages=[], updated_at=self._now())
            self._store[session_id] = entry
        return entry

    def get(self, session_id: str) -> List[Message]:
        """Get messages for a session (returns empty list if new session)."""
        if not session_id:
            return []
        with self._lock:
            self._gc_locked()
            # Return a shallow copy so callers cannot mutate stored history.
            return list(self._ensure_locked(session_id).messages)

    def append(self, session_id: str, role: str, content: str) -> None:
        """Append a message and enforce trimming."""
        if not session_id:
            return
        with self._lock:
            self._gc_locked()
            entry = self._ensure_locked(session_id)
            entry.messages.append(Message(role=role, content=content))
            entry.updated_at = self._now()

            # Drop the oldest messages, keeping the most recent max_messages.
            excess = len(entry.messages) - self.max_messages
            if excess > 0:
                entry.messages = entry.messages[excess:]

    def set_messages(self, session_id: str, messages: List[Message]) -> None:
        """Replace session history entirely (rarely needed, but handy)."""
        if not session_id:
            return
        with self._lock:
            self._store[session_id] = SessionMemory(
                messages=messages[-self.max_messages:],
                updated_at=self._now(),
            )

    def clear(self, session_id: str) -> None:
        """Clear a single session (unknown ids are ignored)."""
        if not session_id:
            return
        with self._lock:
            self._store.pop(session_id, None)

    def _gc_locked(self) -> None:
        """TTL cleanup (only runs if ttl_seconds is configured)."""
        if not self.ttl_seconds:
            return
        cutoff = self._now() - self.ttl_seconds
        stale = [sid for sid, entry in self._store.items() if entry.updated_at < cutoff]
        for sid in stale:
            self._store.pop(sid, None)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Global singleton (simple for HF Spaces demo).
# Reads limits from the environment:
#   MAX_SESSION_MESSAGES - per-session history cap (default 30)
#   SESSION_TTL_SECONDS  - idle-session TTL; 0/unset disables TTL eviction
# Uses a plain `os.getenv` (see module imports) instead of the previous
# inline `__import__("os")` hack, which was unidiomatic and harder to read.
memory_store = MemoryStore(
    max_messages=int(os.getenv("MAX_SESSION_MESSAGES", "30")),
    ttl_seconds=int(os.getenv("SESSION_TTL_SECONDS", "0")) or None,
)
|
memory_mongo.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MongoDB-backed session memory store.
|
| 3 |
+
Replaces in-memory storage with persistent MongoDB storage.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
import os
|
| 8 |
+
from datetime import datetime, timedelta
|
| 9 |
+
from pymongo import MongoClient, ASCENDING
|
| 10 |
+
from pymongo.errors import ConnectionFailure, OperationFailure
|
| 11 |
+
from schemas import Message
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
#load env vars
|
| 15 |
+
load_dotenv()
|
| 16 |
+
|
| 17 |
+
class MongoMemoryStore:
    """
    MongoDB-backed session memory store.

    Persistent replacement for the in-memory store, with the same public
    surface (get / append / set_messages / clear) plus maintenance helpers.

    Document schema (one document per session):
    {
        "_id": "session_id",
        "messages": [
            {"role": "user", "content": "..."},
            {"role": "assistant", "content": "..."}
        ],
        "updated_at": datetime,
        "created_at": datetime
    }
    """

    def __init__(
        self,
        mongo_uri: Optional[str] = None,
        database_name: str = "pharmai",
        collection_name: str = "sessions",
        max_messages: int = 30,
        ttl_seconds: Optional[int] = None,
    ):
        # Per-session cap; older messages are trimmed after each append.
        self.max_messages = max_messages
        # When set, a MongoDB TTL index expires idle sessions automatically.
        self.ttl_seconds = ttl_seconds

        # Get MongoDB URI from env or parameter
        self.mongo_uri = mongo_uri or os.getenv("MONGO_URI")
        if not self.mongo_uri:
            raise ValueError("MONGO_URI not found in environment variables")

        # Connect to MongoDB (5s server-selection timeout so startup fails fast
        # instead of hanging when the cluster is unreachable).
        try:
            self.client = MongoClient(self.mongo_uri, serverSelectionTimeoutMS=5000)
            # Test connection
            self.client.admin.command('ping')
            print(f"✅ MongoDB connected: {database_name}.{collection_name}")
        except ConnectionFailure as e:
            raise ConnectionError(f"Failed to connect to MongoDB: {e}")

        self.db = self.client[database_name]
        self.collection = self.db[collection_name]

        # Create indexes
        self._create_indexes()

    def _create_indexes(self) -> None:
        """Create indexes for performance and TTL (idempotent; warns on failure)."""
        try:
            # Get existing indexes
            existing_indexes = self.collection.index_information()

            # TTL index - automatically delete old sessions
            if self.ttl_seconds:
                # Check if any TTL index already exists
                ttl_exists = any(
                    idx.get("expireAfterSeconds") is not None
                    for idx in existing_indexes.values()
                )

                if not ttl_exists:
                    # Drop the basic updated_at index if it exists (without TTL);
                    # MongoDB cannot convert a plain index into a TTL one.
                    if "updated_at_1" in existing_indexes:
                        self.collection.drop_index("updated_at_1")

                    # Create TTL index. NOTE: this store writes naive
                    # datetime.utcnow() values, which MongoDB treats as UTC,
                    # so TTL expiry keys off UTC timestamps.
                    self.collection.create_index(
                        [("updated_at", ASCENDING)],
                        expireAfterSeconds=self.ttl_seconds,
                        name="session_ttl"
                    )
                    print(f"✅ Created TTL index (expires after {self.ttl_seconds}s)")
            else:
                # Just a regular index on updated_at (no TTL)
                if "updated_at_1" not in existing_indexes and "session_ttl" not in existing_indexes:
                    self.collection.create_index([("updated_at", ASCENDING)])
                    print("✅ Created updated_at index")

        except OperationFailure as e:
            # Index creation failed, but continue anyway — the store still
            # works without indexes, just slower / without auto-expiry.
            print(f"⚠️ Index creation warning: {e}")
            pass

    def get(self, session_id: str) -> List[Message]:
        """Get messages for a session (empty list for unknown/blank ids or on error)."""
        if not session_id:
            return []

        try:
            doc = self.collection.find_one({"_id": session_id})
            if not doc:
                return []

            # Convert stored dicts back into Message objects, with defaults
            # guarding against partially-written documents.
            messages: List[Message] = []
            for msg in doc.get("messages", []):
                messages.append(Message(
                    role=msg.get("role", "user"),
                    content=msg.get("content", "")
                ))

            return messages
        except OperationFailure as e:
            # Best-effort store: log and degrade to "no history".
            print(f"Error getting session {session_id}: {e}")
            return []

    def append(self, session_id: str, role: str, content: str) -> None:
        """Append a message to a session, creating the session on first write."""
        if not session_id:
            return

        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # datetime.now(timezone.utc) is the modern equivalent — confirm the
        # TTL index still behaves as expected before switching.
        now = datetime.utcnow()
        message = {"role": role, "content": content}

        try:
            # Try to update existing session
            result = self.collection.update_one(
                {"_id": session_id},
                {
                    "$push": {"messages": message},
                    "$set": {"updated_at": now}
                }
            )

            # If session doesn't exist, create it
            if result.matched_count == 0:
                self.collection.insert_one({
                    "_id": session_id,
                    "messages": [message],
                    "created_at": now,
                    "updated_at": now
                })

            # Trim old messages if needed.
            # NOTE(review): push + trim are two separate operations, so a
            # concurrent appender can briefly observe an over-long list.
            self._trim_messages(session_id)

        except OperationFailure as e:
            print(f"Error appending to session {session_id}: {e}")

    def _trim_messages(self, session_id: str) -> None:
        """Keep only the most recent max_messages for the session."""
        try:
            doc = self.collection.find_one({"_id": session_id})
            if not doc:
                return

            messages = doc.get("messages", [])
            if len(messages) > self.max_messages:
                # Keep only the most recent messages
                trimmed = messages[-self.max_messages:]
                self.collection.update_one(
                    {"_id": session_id},
                    {"$set": {"messages": trimmed}}
                )
        except OperationFailure as e:
            print(f"Error trimming session {session_id}: {e}")

    def set_messages(self, session_id: str, messages: List[Message]) -> None:
        """Replace session history entirely (upserts the session document)."""
        if not session_id:
            return

        now = datetime.utcnow()
        message_dicts = [{"role": m.role, "content": m.content} for m in messages]

        # Keep only most recent messages
        if len(message_dicts) > self.max_messages:
            message_dicts = message_dicts[-self.max_messages:]

        try:
            self.collection.update_one(
                {"_id": session_id},
                {
                    "$set": {
                        "messages": message_dicts,
                        "updated_at": now
                    },
                    # created_at only written when the upsert inserts.
                    "$setOnInsert": {"created_at": now}
                },
                upsert=True
            )
        except OperationFailure as e:
            print(f"Error setting messages for session {session_id}: {e}")

    def clear(self, session_id: str) -> None:
        """Clear a single session (no-op for blank/unknown ids)."""
        if not session_id:
            return

        try:
            self.collection.delete_one({"_id": session_id})
        except OperationFailure as e:
            print(f"Error clearing session {session_id}: {e}")

    def cleanup_old_sessions(self, days: int = 7) -> int:
        """
        Manually cleanup sessions older than X days; returns the delete count.
        (The TTL index handles this automatically if configured.)
        """
        cutoff = datetime.utcnow() - timedelta(days=days)
        try:
            result = self.collection.delete_many({"updated_at": {"$lt": cutoff}})
            return result.deleted_count
        except OperationFailure as e:
            print(f"Error cleaning up old sessions: {e}")
            return 0

    def get_session_count(self) -> int:
        """Get total number of active sessions (0 on query failure)."""
        try:
            return self.collection.count_documents({})
        except OperationFailure:
            return 0

    def close(self) -> None:
        """Close the MongoDB connection."""
        if self.client:
            self.client.close()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Create global singleton
def create_memory_store() -> MongoMemoryStore:
    """
    Factory for the session store.

    Prefers the MongoDB-backed store; when MongoDB is unavailable or
    unconfigured, falls back to the in-memory MemoryStore with the same
    environment-driven limits.
    """
    max_msgs = int(os.getenv("MAX_SESSION_MESSAGES", "30"))
    ttl = int(os.getenv("SESSION_TTL_SECONDS", "0")) or None

    try:
        # Try MongoDB first
        return MongoMemoryStore(max_messages=max_msgs, ttl_seconds=ttl)
    except (ValueError, ConnectionError) as e:
        print(f"⚠️ MongoDB not available: {e}")
        print("⚠️ Falling back to in-memory storage")

        # Fallback to in-memory (imported lazily to avoid a hard dependency)
        from memory import MemoryStore
        return MemoryStore(max_messages=max_msgs, ttl_seconds=ttl)


# Global instance
memory_store = create_memory_store()
|
requirements.txt
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI and server
|
| 2 |
+
fastapi
|
| 3 |
+
uvicorn
|
| 4 |
+
python-dotenv
|
| 5 |
+
|
| 6 |
+
# LangChain and LangGraph
|
| 7 |
+
langchain
|
| 8 |
+
langchain-anthropic
|
| 9 |
+
langchain-core
|
| 10 |
+
langchain-community
|
| 11 |
+
langgraph
|
| 12 |
+
langgraph-checkpoint
|
| 13 |
+
langsmith
|
| 14 |
+
|
| 15 |
+
# Tools and utilities
|
| 16 |
+
tavily-python
|
| 17 |
+
pydantic
|
| 18 |
+
pydantic-settings
|
| 19 |
+
|
| 20 |
+
# MongoDB
|
| 21 |
+
pymongo
|
| 22 |
+
motor
|
| 23 |
+
|
| 24 |
+
# HTTP client
|
| 25 |
+
httpx
|
| 26 |
+
aiohttp
|
| 27 |
+
|
| 28 |
+
# Graph rendering (optional)
|
| 29 |
+
graphviz
|
| 30 |
+
|
| 31 |
+
# Other dependencies
|
| 32 |
+
python-multipart
|
schemas.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# schemas.py
|
| 2 |
+
from typing import List, Optional, Dict, Any
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
from enum import Enum
|
| 5 |
+
|
| 6 |
+
# Core Enums
class AgentType(str, Enum):
    """Closed set of agent roles; used as keys for AgentState.agent_outputs."""
    PLANNER = "planner"
    SCIENTIFIC = "scientific"
    PATENT = "patent"
    MARKET = "market"
    SUPPLY = "supply"
    SYNTHESIS = "synthesis"
|
| 14 |
+
|
| 15 |
+
class EvidenceType(str, Enum):
    """Categories for EvidenceItem.type (str-valued so it JSON-serializes cleanly)."""
    LITERATURE = "literature"          # web/literature results (see tools.tavily_search)
    CLINICAL_TRIAL = "clinical_trial"  # ClinicalTrials.gov hits (see tools.clinicaltrials_search)
    PATENT = "patent"
    MARKET = "market"
    OTHER = "other"                    # fallbacks, errors, stub evidence
|
| 21 |
+
|
| 22 |
+
# API Schemas (FastAPI I/O)
class AgentRunRequest(BaseModel):
    """
    Incoming request from the Node.js backend or a direct API call.
    """
    # Reuse the same session_id across calls to continue a conversation.
    session_id: Optional[str] = Field(
        default=None,
        description="Optional session ID to maintain conversation state"
    )
    # Required free-text query.
    query: str = Field(
        ...,
        description="User query, e.g. 'Drug X for Indication Y'"
    )
|
| 35 |
+
|
| 36 |
+
class AgentRunResponse(BaseModel):
    """
    Final response returned by the agent system.
    """
    # Echoed back so stateless clients can correlate responses.
    session_id: Optional[str]
    # The synthesized answer text.
    decision_brief: str
    confidence_score: Optional[float] = Field(
        default=None,
        description="Optional overall confidence score (0–1)"
    )
    citations: Optional[List[str]] = Field(
        default=None,
        description="List of citation identifiers or URLs"
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Extra debug or trace metadata"
    )
|
| 54 |
+
|
| 55 |
+
# Internal Agent State
class Message(BaseModel):
    """
    Canonical message format passed between agents and stored in session memory.
    """
    role: str  # system | user | assistant | tool
    content: str  # plain-text message body
|
| 62 |
+
|
| 63 |
+
class EvidenceItem(BaseModel):
    """
    A single piece of evidence produced by tools or agents.
    """
    type: EvidenceType  # evidence category (literature, clinical_trial, ...)
    source: str  # URL or tool identifier the evidence came from
    summary: str  # short human-readable summary (tools pre-trim to a few hundred chars)
    confidence: Optional[float] = None  # heuristic 0–1 confidence, if known
    raw: Optional[Dict[str, Any]] = None  # original payload; dropped by tools.normalize_evidence
|
| 72 |
+
|
| 73 |
+
class AgentOutput(BaseModel):
    """
    Output produced by a single agent.
    """
    agent: AgentType  # which agent produced this output
    text: str  # the agent's narrative output
    evidence: Optional[List[EvidenceItem]] = None  # supporting evidence, if any
|
| 80 |
+
|
| 81 |
+
class AgentState(BaseModel):
    """
    LangGraph state object.
    This is what flows between graph nodes.
    """
    session_id: Optional[str]  # None for one-shot (stateless) runs
    user_query: str  # original query that started this run

    # Rolling conversation history for this run.
    messages: List[Message] = Field(default_factory=list)

    agent_outputs: Dict[AgentType, AgentOutput] = Field(
        default_factory=dict,
        description="Outputs from each agent"
    )

    # Presumably filled in by the synthesis stage — confirm against graph.py.
    final_decision: Optional[str] = None

    confidence_score: Optional[float] = None  # overall 0–1 confidence, if computed
|
tools.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Any, Optional
|
| 2 |
+
import os
|
| 3 |
+
import uuid
|
| 4 |
+
import re
|
| 5 |
+
import base64
|
| 6 |
+
from schemas import EvidenceItem, EvidenceType
|
| 7 |
+
|
| 8 |
+
# Helper functions
def _etype(name: str, default: EvidenceType) -> EvidenceType:
    """Return EvidenceType.<name> if it exists, else default (prevents breaking)."""
    # getattr-based lookup so a missing enum member degrades gracefully
    # instead of raising AttributeError (keeps older schemas compatible).
    return getattr(EvidenceType, name, default)
|
| 12 |
+
|
| 13 |
+
def _short(s: str, n: int = 700) -> str:
|
| 14 |
+
return (s or "")[:n]
|
| 15 |
+
|
| 16 |
+
def _is_url(s: str) -> bool:
|
| 17 |
+
return isinstance(s, str) and s.startswith(("http://", "https://"))
|
| 18 |
+
|
| 19 |
+
# Tool 1: Tavily Web Search (existing, unchanged)
def tavily_search(query: str, max_results: int = 5) -> List[EvidenceItem]:
    """
    Uses the Tavily API to perform a web search.

    Args:
        query: Free-text search query.
        max_results: Maximum number of results to request from Tavily.

    Returns:
        A list of EvidenceItem objects. Never raises: a missing API key or
        any runtime failure is reported as a single zero-confidence
        EvidenceItem instead of an exception.
    """
    api_key = os.getenv("TAVILY_API_KEY")

    # Degrade gracefully when the key is absent (local dev, CI, demos).
    if not api_key:
        return [
            EvidenceItem(
                type=EvidenceType.OTHER,
                source="tavily_disabled",
                summary="Tavily API key not configured; search skipped.",
                confidence=0.0,
            )
        ]

    try:
        # Imported lazily so the module still loads without tavily installed.
        from tavily import TavilyClient

        client = TavilyClient(api_key=api_key)
        results = client.search(
            query=query,
            max_results=max_results,
            include_raw_content=False,
        )

        evidence: List[EvidenceItem] = []

        for r in results.get("results", []):
            evidence.append(
                EvidenceItem(
                    type=EvidenceType.LITERATURE,
                    source=r.get("url", "unknown"),
                    summary=r.get("content", "")[:500],  # trim noisy snippets
                    confidence=0.6,  # heuristic default for web results
                    raw=r,
                )
            )

        return evidence

    except Exception as e:
        # Broad catch is deliberate: search is best-effort and must never
        # crash the agent loop; the failure is surfaced as evidence.
        return [
            EvidenceItem(
                type=EvidenceType.OTHER,
                source="tavily_error",
                summary=f"Tavily search failed: {str(e)}",
                confidence=0.0,
            )
        ]
|
| 71 |
+
|
| 72 |
+
# Tool 2: Stub Evidence Generator (existing, unchanged)
def stub_evidence(query: str) -> List[EvidenceItem]:
    """
    Fallback tool producing placeholder evidence.

    Useful for demos, offline mode, or testing agent logic without any
    external API calls.

    Returns:
        A single low-confidence EvidenceItem echoing the query.
    """
    return [
        EvidenceItem(
            type=EvidenceType.OTHER,
            source="stub_tool",
            summary=f"Stub evidence generated for query: '{query}'. "
            f"This indicates where real retrieval will plug in.",
            confidence=0.2,  # intentionally low: placeholder evidence
            raw={
                # NOTE(review): uuid4 makes each call's raw payload unique,
                # so output is not strictly deterministic despite being a stub.
                "id": str(uuid.uuid4()),
                "note": "Replace with real retrieval later",
            },
        )
    ]
|
| 91 |
+
|
| 92 |
+
# Tool 3: Query Classifier (planner helper)
def classify_query(query: str) -> Dict[str, Any]:
    """
    Lightweight keyword classifier that helps the agent decide
    which tools (if any) are required.

    Keywords are matched as word-*prefixes* anchored at a word boundary
    (so "arch" matches "architecture" but no longer "search", and "draw"
    no longer matches "withdrawal"). The previous raw-substring matching
    produced false positives that routed ordinary queries down the
    diagram-generation path.

    Args:
        query: Raw user query (None-safe).

    Returns:
        Dict of boolean flags: needs_graph, needs_clinical_trials,
        needs_web_search, needs_entity_extraction.
    """
    q = (query or "").lower()

    def _hit(keywords: List[str]) -> bool:
        # \b-anchored prefix match: keyword must start at a word boundary.
        return any(re.search(r"\b" + re.escape(k), q) for k in keywords)

    needs_graph = _hit(["diagram", "graph", "graphviz", "dot", "flow", "architecture", "arch", "draw"])
    needs_trials = _hit(["trial", "clinical", "phase", "nct", "primary endpoint", "secondary endpoint"])
    needs_facts = _hit(["fda", "approval", "label", "patent", "exclusivity", "pricing", "aria", "safety", "market"])
    # "for " keeps its trailing space on purpose: it targets "X for Y" phrasing.
    needs_entities = any(k in q for k in ["evaluate", "assess", "analyze", "repurpose", "for "])
    return {
        "needs_graph": needs_graph,
        "needs_clinical_trials": needs_trials,
        "needs_web_search": needs_facts or needs_trials,
        "needs_entity_extraction": needs_entities,
    }
|
| 109 |
+
|
| 110 |
+
# Tool 4: Entity Extraction (Drug / Indication)
def extract_entities(query: str) -> Dict[str, Optional[str]]:
    """
    Minimal entity extractor for MVP.

    Recognizes the pattern "<evaluate|assess|analyze> <drug> for <indication>"
    (case-insensitive). Returns {"drug": ..., "indication": ...}; both values
    are None when the pattern is absent.
    """
    text = (query or "").strip()
    pattern = r"(evaluate|assess|analyze)\s+(?P<drug>.+?)\s+for\s+(?P<indication>.+)"
    match = re.search(pattern, text, re.IGNORECASE)
    if not match:
        return {"drug": None, "indication": None}
    return {
        "drug": match.group("drug").strip(),
        "indication": match.group("indication").strip(),
    }
|
| 127 |
+
|
| 128 |
+
# Tool 5: Evidence Normalizer (dedupe + cleanup)
def normalize_evidence(evidence: List[EvidenceItem]) -> List[EvidenceItem]:
    """
    Deduplicate evidence by source and trim noisy content.

    Keeps the first item seen per source, caps each summary at 800
    characters, and drops the heavy `raw` payloads.
    """
    seen_sources = set()
    deduped: List[EvidenceItem] = []

    for item in evidence:
        if item.source in seen_sources:
            continue
        seen_sources.add(item.source)

        deduped.append(
            EvidenceItem(
                type=item.type,
                source=item.source,
                summary=(item.summary or "")[:800],
                confidence=item.confidence,
                raw=None,  # drop heavy payloads
            )
        )

    return deduped
|
| 152 |
+
|
| 153 |
+
# Tool 6: Graph Generation (Graphviz DOT only)
def generate_graph_dot(
    title: str,
    nodes: List[Dict[str, str]],
    edges: List[Dict[str, str]],
    rankdir: str = "LR",
) -> str:
    """
    Generates Graphviz DOT code from node/edge dicts.

    IMPORTANT: LLM must call this tool; never output DOT directly.

    Args:
        title: Graph title, rendered inside an HTML-like label.
        nodes: Dicts with "id" and optional "label". Entries without an "id"
            are now skipped — previously an entry missing BOTH "id" and
            "label" crashed with AttributeError (None.replace).
        edges: Dicts with "from", "to", and optional "label"; entries missing
            either endpoint are skipped.
        rankdir: Graphviz layout direction ("LR", "TB", ...).

    Returns:
        Complete DOT source as a single string.
    """
    # The title is embedded in an HTML-like label, so HTML-special characters
    # must be escaped or Graphviz rejects the source; quotes are mapped to
    # apostrophes as before.
    safe_title = (
        (title or "PharmAI Graph")
        .replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "'")
    )

    lines = [
        "digraph G {",
        f" rankdir={rankdir};",
        ' labelloc="t";',
        ' labeljust="c";',
        f' label=<<B><FONT POINT-SIZE="28">{safe_title}</FONT></B>>;',
        " node [shape=box, style=rounded];",
        "",
    ]

    for n in nodes or []:
        nid = n.get("id")
        if not nid:
            # Skip malformed node entries instead of crashing on a None label.
            continue
        lbl = (n.get("label") or nid).replace('"', "'")
        lines.append(f' {nid} [label="{lbl}"];')

    lines.append("")

    for e in edges or []:
        src = e.get("from")
        tgt = e.get("to")
        lbl = e.get("label")
        if src and tgt:
            if lbl:
                lines.append(f' {src} -> {tgt} [label="{lbl}"];')
            else:
                lines.append(f" {src} -> {tgt};")

    lines.append("}")
    return "\n".join(lines)
|
| 196 |
+
|
| 197 |
+
# Tool 7: ClinicalTrials search (lightweight, Tavily-based)
def clinicaltrials_search(drug: str, indication: str, max_results: int = 5) -> List[EvidenceItem]:
    """
    MVP approach:
    - Uses Tavily with a site: filter to target ClinicalTrials.gov / NCT IDs
    - Returns EvidenceItems for trial links + snippets

    Args:
        drug: Drug name (required).
        indication: Disease/indication (required).
        max_results: Cap on Tavily results.

    Returns:
        ClinicalTrials.gov-filtered evidence when any hits match; otherwise
        the unfiltered Tavily results re-tagged as clinical-trial evidence.
        Missing inputs yield a single zero-confidence error item.
    """
    drug = (drug or "").strip()
    indication = (indication or "").strip()

    # Both entities are required to build a meaningful site-restricted query.
    if not drug or not indication:
        return [
            EvidenceItem(
                type=EvidenceType.OTHER,
                source="clinicaltrials_search_invalid_input",
                summary="Missing drug or indication for clinical trials search.",
                confidence=0.0,
            )
        ]

    # "NCT" in the query nudges results toward trial-registry pages.
    query = f'site:clinicaltrials.gov ("{drug}") ("{indication}") NCT'
    ev = tavily_search(query=query, max_results=max_results)

    # Fall back to LITERATURE typing if the enum lacks CLINICAL_TRIAL.
    trial_type = _etype("CLINICAL_TRIAL", EvidenceType.LITERATURE)

    out: List[EvidenceItem] = []
    for e in ev:
        # only keep plausible CT.gov results if possible
        if _is_url(e.source) and "clinicaltrials.gov" in e.source:
            out.append(
                EvidenceItem(
                    type=trial_type,
                    source=e.source,
                    summary=e.summary,
                    # Boost registry hits to at least 0.55 confidence.
                    confidence=max(0.55, float(e.confidence or 0.55)),
                    raw=e.raw,
                )
            )

    if out:
        return out

    # fallback: return whatever Tavily gave (still structured)
    return [
        EvidenceItem(
            type=trial_type,
            source=e.source,
            summary=e.summary,
            confidence=float(e.confidence or 0.4),
            raw=e.raw,
        )
        for e in ev
    ]
|
| 250 |
+
|
| 251 |
+
# Tool 8: DOT -> PNG
def render_dot_to_png_base64(dot: str) -> Dict[str, Any]:
    """
    Render Graphviz DOT source to a PNG and return it base64-encoded.

    Returns {"ok": True, "png_base64": ...} on success, or
    {"ok": False, "error": ...} when the input is empty or the optional
    `graphviz` package / system binaries are unavailable.
    """
    dot_source = (dot or "").strip()
    if not dot_source:
        return {"ok": False, "error": "Empty DOT string"}

    try:
        from graphviz import Source  # optional dependency

        rendered = Source(dot_source).pipe(format="png")
        encoded = base64.b64encode(rendered).decode("utf-8")
    except Exception as e:
        return {
            "ok": False,
            "error": f"DOT->PNG render failed. Ensure `graphviz` Python package and system binaries are installed. Details: {str(e)}",
        }
    return {"ok": True, "png_base64": encoded}
|
| 274 |
+
|
| 275 |
+
# Tool Registry (extended, backward compatible)
# Maps the tool names exposed to the planner/LLM onto their callables.
TOOL_REGISTRY: Dict[str, Any] = {
    # existing
    "web_search": tavily_search,
    "stub_evidence": stub_evidence,

    # new
    "classify_query": classify_query,
    "extract_entities": extract_entities,
    "normalize_evidence": normalize_evidence,
    "generate_graph_dot": generate_graph_dot,
    "clinicaltrials_search": clinicaltrials_search,
    "render_dot_to_png_base64": render_dot_to_png_base64
}
|