arittrabag committed on
Commit
1de0a51
·
verified ·
1 Parent(s): 1ed42fa

Added main app files

Browse files
Files changed (9) hide show
  1. .dockerignore +33 -0
  2. Dockerfile +30 -0
  3. app.py +164 -0
  4. graph.py +479 -0
  5. memory.py +91 -0
  6. memory_mongo.py +260 -0
  7. requirements.txt +32 -0
  8. schemas.py +98 -0
  9. tools.py +288 -0
.dockerignore ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ venv/
9
+ ENV/
10
+
11
+ # Testing
12
+ .pytest_cache
13
+ *.log
14
+
15
+ # IDE
16
+ .vscode
17
+ .idea
18
+ *.swp
19
+ *.swo
20
+
21
+ # Test files
22
+ test_*.py
23
+ test_*.ps1
24
+ *_test.py
25
+ fix_indexes.py
26
+
27
+ # Documentation
28
+ *.md
29
+ MEMORY_COMPARISON.md
30
+
31
+ # Output files
32
+ *.png
33
+ output.txt
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use Python 3.11 slim image
2
+ FROM python:3.11-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies for graphviz (optional but recommended)
8
+ RUN apt-get update && apt-get install -y \
9
+ graphviz \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Copy requirements first for better caching
13
+ COPY requirements.txt .
14
+
15
+ # Install Python dependencies
16
+ RUN pip install --no-cache-dir --upgrade pip && \
17
+ pip install --no-cache-dir -r requirements.txt
18
+
19
+ # Copy application code
20
+ COPY . .
21
+
22
+ # Expose port 7860 (HuggingFace Spaces default)
23
+ EXPOSE 7860
24
+
25
+ # Set environment variables
26
+ ENV PYTHONUNBUFFERED=1
27
+ ENV PORT=7860
28
+
29
+ # Run the FastAPI application
30
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import uuid
3
+ from fastapi import FastAPI, HTTPException
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from schemas import AgentRunRequest, AgentRunResponse, Message
6
+ from memory_mongo import memory_store # MongoDB-backed memory
7
+ from graph import build_graph
8
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
9
+ from fastapi.responses import StreamingResponse
10
+ import json
11
+ import time
12
+ from fastapi.encoders import jsonable_encoder
13
+
14
+ app = FastAPI(title="PharmAI Navigator (Agentic)", version="0.1.0")
15
+
16
+ # CORS (HF Spaces + your Node proxy)
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"],
20
+ allow_credentials=True,
21
+ allow_methods=["*"],
22
+ allow_headers=["*"],
23
+ )
24
+
25
+ # Compile graph once at startup
26
+ GRAPH = build_graph()
27
+
28
+
29
@app.get("/health")
def health():
    """Liveness probe that also reports MongoDB connectivity and active session count."""
    mongo_state = "connected"
    active = 0
    try:
        active = memory_store.get_session_count()
    except Exception as exc:  # report the failure in the payload rather than 500ing
        mongo_state = f"error: {str(exc)}"
    return {
        "status": "ok",
        "mongodb": mongo_state,
        "active_sessions": active
    }
45
+
46
+
47
@app.get("/session/{session_id}/history")
def get_session_history(session_id: str):
    """Inspect a session's stored chat history, truncating long messages (testing aid)."""
    history = memory_store.get(session_id)
    previews = []
    for msg in history:
        body = msg.content
        if len(body) > 100:
            # keep the payload small for eyeballing
            body = body[:100] + "..."
        previews.append({"role": msg.role, "content": body})
    return {
        "session_id": session_id,
        "message_count": len(history),
        "messages": previews,
    }
56
+
57
+
58
@app.delete("/session/{session_id}")
def clear_session(session_id: str):
    """Drop all stored history for one session (testing aid)."""
    memory_store.clear(session_id)
    return {"session_id": session_id, "status": "cleared"}
63
+
64
+
65
@app.post("/admin/cleanup-sessions")
def cleanup_old_sessions(days: int = 7):
    """
    Manually purge sessions older than *days*.

    Normally redundant: the Mongo TTL index (when configured) expires
    sessions automatically.
    """
    try:
        removed = memory_store.cleanup_old_sessions(days=days)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {
        "status": "ok",
        "deleted_sessions": removed,
        "days": days
    }
80
+
81
+
82
@app.post("/test/echo")
def test_echo(req: AgentRunRequest):
    """
    Memory smoke test: no LLM calls.

    Persists the query, fabricates an echo reply, and reports how many
    messages the session held before and after this exchange.
    """
    sid = req.session_id or str(uuid.uuid4())

    history_before = memory_store.get(sid)

    memory_store.append(sid, role="user", content=req.query)

    reply = f"Echo: {req.query} (Session has {len(history_before)} prior messages)"

    memory_store.append(sid, role="assistant", content=reply)

    return {
        "session_id": sid,
        "decision_brief": reply,
        "prior_message_count": len(history_before),
        "current_message_count": len(memory_store.get(sid)),
        "citations": [],
        "metadata": {"test_mode": True}
    }
110
+
111
+
112
@app.post("/run", response_model=AgentRunResponse)
def run_agent(req: AgentRunRequest):
    """
    Main agent endpoint: run the compiled LangGraph pipeline synchronously.

    Flow: resolve session -> load prior history -> persist the user turn ->
    invoke GRAPH -> persist the assistant turn -> respond with the brief.
    """
    # 1) session handling
    session_id = req.session_id or str(uuid.uuid4())

    # 2) load prior history (for chat continuity)
    prior = memory_store.get(session_id)

    # Convert to LangChain message dict format for LangGraph MessagesState
    # LangGraph expects state["messages"] as list of LC messages; we pass dict-like messages.
    messages = []
    for m in prior:
        if m.role == "user":
            messages.append(HumanMessage(content=m.content))
        elif m.role == "assistant":
            messages.append(AIMessage(content=m.content))
        elif m.role == "system":
            messages.append(SystemMessage(content=m.content))

    # 3) append this user query to memory (pre-run)
    # NOTE: persisted before the run, so a failed run still records the question.
    memory_store.append(session_id, role="user", content=req.query)

    # Append new user query as LangChain message
    messages = messages + [HumanMessage(content=req.query)]

    # 4) run graph (Mode A synchronous)
    try:
        final_state = GRAPH.invoke(
            {
                "session_id": session_id,
                "user_query": req.query,
                "messages": messages,
            }
        )
    except Exception as e:
        # Surface any graph failure as a 500 carrying the underlying message
        raise HTTPException(status_code=500, detail=f"Agent run failed: {str(e)}")

    # NOTE(review): "final_decision" fallback key is not produced by graph.py as
    # seen here — presumably a legacy state key; confirm before removing.
    decision_brief = final_state.get("decision_brief") or final_state.get("final_decision") or ""
    citations = final_state.get("citations") or []

    # 5) save assistant response to memory (post-run); empty briefs are not stored
    if decision_brief.strip():
        memory_store.append(session_id, role="assistant", content=decision_brief)

    return AgentRunResponse(
        session_id=session_id,
        decision_brief=decision_brief,
        confidence_score=final_state.get("confidence_score"),
        citations=citations,
        metadata={
            "has_prior_messages": len(prior) > 0,
        },
    )
graph.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import json
5
+ import re
6
+ from typing import Any, Dict, List, Optional
7
+ from dotenv import load_dotenv
8
+
9
+ from langchain_anthropic import ChatAnthropic
10
+ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
11
+ from langchain_core.tools import tool
12
+
13
+ from langgraph.graph import StateGraph, START, END, MessagesState
14
+ from langgraph.prebuilt import ToolNode, tools_condition
15
+
16
+ from tools import (
17
+ tavily_search,
18
+ stub_evidence,
19
+ classify_query,
20
+ extract_entities,
21
+ normalize_evidence,
22
+ generate_graph_dot,
23
+ clinicaltrials_search,
24
+ render_dot_to_png_base64
25
+ )
26
+
27
+ # Load environment variables
28
+ load_dotenv()
29
+
30
+ # -----------------------------
31
+ # LangChain Tool Wrappers
32
+ # -----------------------------
33
@tool("web_search")
def web_search_tool(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Web search using Tavily. Returns a list of evidence dicts."""
    ev = tavily_search(query=query, max_results=max_results)
    # Pydantic models -> plain dicts so the tool runtime can serialize them
    return [e.model_dump() for e in ev]


@tool("stub_evidence")
def stub_evidence_tool(query: str) -> List[Dict[str, Any]]:
    """Deterministic fallback evidence tool (offline/demo)."""
    ev = stub_evidence(query=query)
    return [e.model_dump() for e in ev]


@tool("classify_query")
def classify_query_tool(query: str) -> Dict[str, Any]:
    """Classify query to decide which tools are needed."""
    return classify_query(query)


@tool("extract_entities")
def extract_entities_tool(query: str) -> Dict[str, Optional[str]]:
    """Extract drug and indication from query."""
    return extract_entities(query)


@tool("normalize_evidence")
def normalize_evidence_tool(evidence: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Dedupe and clean evidence."""
    return normalize_evidence(evidence)


@tool("generate_graph_dot")
def generate_graph_dot_tool(
    title: str,
    nodes: List[Dict[str, str]],
    edges: List[Dict[str, str]],
    rankdir: str = "LR",
) -> str:
    """
    Generate Graphviz DOT.
    IMPORTANT: Use this tool instead of writing DOT directly.
    """
    return generate_graph_dot(
        title=title,
        nodes=nodes,
        edges=edges,
        rankdir=rankdir,
    )


@tool("clinicaltrials_search")
def clinicaltrials_search_tool(drug: str, indication: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Search ClinicalTrials.gov (Tavily-based MVP)."""
    ev = clinicaltrials_search(drug=drug, indication=indication, max_results=max_results)
    return [e.model_dump() for e in ev]


@tool("render_dot_to_png_base64")
def render_dot_to_png_base64_tool(dot: str) -> Dict[str, Any]:
    """Render DOT to PNG (base64). Optional dependency on graphviz."""
    return render_dot_to_png_base64(dot)


# All tools exposed to the LLM via bind_tools() in _build_model()
TOOLS = [
    web_search_tool,
    stub_evidence_tool,
    classify_query_tool,
    extract_entities_tool,
    normalize_evidence_tool,
    generate_graph_dot_tool,
    clinicaltrials_search_tool,
    render_dot_to_png_base64_tool
]
103
+
104
+ # -----------------------------
105
+ # LangGraph State
106
+ # -----------------------------
107
class PharmAIState(MessagesState):
    # Per-run agent state carried through the LangGraph pipeline.
    # MessagesState supplies the accumulating "messages" list.
    session_id: Optional[str]  # caller-provided or generated session key
    user_query: str  # raw user query for this run
    decision_brief: str  # final synthesized answer text
    citations: List[str]  # tool-derived source URLs
    confidence_score: float  # NOTE(review): no node here sets this — confirm producer
    tool_loops: int  # safety counter, bumped by the "bump_tool_loop" node
    diagram_png_base64: Optional[str]  # set by capture_diagram after render_dot_to_png_base64
    diagram_dot: Optional[str]  # set by capture_diagram after generate_graph_dot
    intent: str  # "simple" | "diligence" | "diagram"
117
+
118
+ # -----------------------------
119
+ # Guardrails + Prompts
120
+ # -----------------------------
121
+ SYSTEM_PROMPT = """You are PharmAI Navigator, an evidence-grounded diligence assistant for drug/asset evaluation.
122
+
123
+ Your job:
124
+ Turn a query like "Assess {Drug} for {Indication}" into a decision-grade brief OR structured output.
125
+
126
+ CRITICAL TOOL USAGE RULES:
127
+ - If the user asks for a diagram, flow, architecture, graph, visualization, or Graphviz:
128
+ → You MUST call `generate_graph_dot`.
129
+ → You MUST NOT write Graphviz DOT directly in your response.
130
+ → If the user asks for an image/PNG, call `render_dot_to_png_base64` AFTER you get DOT.
131
+ - If the user asks for trials / phases / NCT IDs / endpoints:
132
+ → Prefer calling `extract_entities` then `clinicaltrials_search`.
133
+ - If the user asks for factual claims (approvals, safety, pricing, patents, market):
134
+ → Prefer calling `web_search`.
135
+
136
+ Guardrails (STRICT):
137
+ - Do NOT invent specific facts (approval dates, trial names, endpoints, statistics, patent expiry).
138
+ - Any concrete number/date/claim MUST be supported by tool evidence.
139
+ - If evidence is insufficient, clearly list Evidence Gaps.
140
+ - Be concise, structured, and decision-oriented.
141
+ - Avoid medical advice; present as diligence/analysis.
142
+
143
+ Simple Query Rule (CRITICAL):
144
+ - If the user asks a simple definitional question ("what is", "define", "explain") and you can answer without external verification, do NOT call tools and respond directly.
145
+ - Only use tools when you need current/specific data (trials, approvals, patents, market data).
146
+
147
+ Citations policy:
148
+ - The final response's "Citations" section is handled by the system.
149
+ - Do NOT create your own citation list.
150
+ """
151
+
152
+ FINAL_PROMPT = """Write the FINAL decision brief with these sections:
153
+
154
+ 1) Executive Recommendation (1–2 lines)
155
+ 2) Scientific Rationale (bullets)
156
+ 3) Clinical Evidence Snapshot (bullets)
157
+ 4) IP / Exclusivity Quick View (bullets)
158
+ 5) Market / SoC Snapshot (bullets)
159
+ 6) Key Risks + Next Actions (bullets)
160
+
161
+ Rules:
162
+ - If evidence is insufficient, include "Evidence Gaps" with bullets.
163
+ - Do NOT add a citations section yourself; the system will append it.
164
+ Return plain text only.
165
+ """
166
+
167
+ # Placeholder detection to avoid wasting tokens on "Drug X / Indication Y"
168
+ PLACEHOLDER_PATTERNS = [
169
+ r"\bdrug\s*x\b",
170
+ r"\bindication\s*y\b",
171
+ r"\bdrug\s*name\b",
172
+ r"\bindication\s*name\b",
173
+ ]
174
+ def _looks_like_placeholder(q: str) -> bool:
175
+ ql = (q or "").strip().lower()
176
+ return any(re.search(p, ql) for p in PLACEHOLDER_PATTERNS)
177
+
178
+
179
def _build_model() -> "ChatAnthropic":
    """Construct the Claude chat model (name from ANTHROPIC_MODEL env) with all tools bound."""
    chosen = os.getenv("ANTHROPIC_MODEL", "claude-3-7-sonnet-latest")
    base = ChatAnthropic(
        model=chosen,
        temperature=0.2,
        max_tokens=10000,
        timeout=120,
        streaming=False,
        stop=None,
    )
    return base.bind_tools(TOOLS)
189
+
190
+
191
# Safety cap to avoid endless tool loops
MAX_TOOL_LOOPS = int(os.getenv("MAX_TOOL_LOOPS", "4"))


def llm_call(state: PharmAIState) -> Dict[str, Any]:
    """
    Calls Claude with tool schemas attached.
    Returns new messages to append into state["messages"].
    """
    llm = _build_model()
    messages: List[BaseMessage] = state["messages"]

    # Ensure the guardrail system prompt leads the conversation exactly once
    if not messages or not isinstance(messages[0], SystemMessage):
        messages = [SystemMessage(content=SYSTEM_PROMPT)] + messages

    tool_loops = state.get("tool_loops", 0)
    if tool_loops >= MAX_TOOL_LOOPS:
        # Stop tool-calling loop and force synthesis
        stop_msg = HumanMessage(
            content=(
                "Stop calling tools now. Proceed to final synthesis using what you already have. "
                "If evidence is insufficient, clearly list Evidence Gaps."
            )
        )
        messages = messages + [stop_msg]

    resp = llm.invoke(messages)
    # Only the new response is returned; MessagesState merges it into history
    return {"messages": [resp]}
219
+
220
+
221
+ # -----------------------------
222
+ # Citations extraction (tool-only)
223
+ # -----------------------------
224
+ def _clean_url(u: str) -> str:
225
+ return u.strip().strip("),.]}\"'")
226
+
227
def _extract_citations_from_messages(messages: List[BaseMessage]) -> List[str]:
    """
    Tool-only citation extraction (single source of truth):
    - ONLY reads ToolMessage contents (actual tool outputs).
    - If tool output is JSON (list/dict), pull `source` fields.
    - Fallback: regex URL extraction from tool text.
    """
    citations: List[str] = []
    url_re = re.compile(r"https?://[^\s\]\)\}\",']+")

    for m in messages:
        if not isinstance(m, ToolMessage):
            continue

        content = getattr(m, "content", None)
        if not content:
            continue

        if isinstance(content, str):
            parsed = None
            try:
                parsed = json.loads(content)
            except Exception:
                parsed = None  # non-JSON tool text; the regex scan below still applies

            if isinstance(parsed, list):
                for item in parsed:
                    if isinstance(item, dict):
                        src = item.get("source")
                        if isinstance(src, str) and src.startswith(("http://", "https://")):
                            citations.append(_clean_url(src))
            elif isinstance(parsed, dict):
                src = parsed.get("source")
                if isinstance(src, str) and src.startswith(("http://", "https://")):
                    citations.append(_clean_url(src))

            # NOTE(review): the regex scan runs even when JSON parsed successfully,
            # so `source` URLs are collected twice (dedup below absorbs it) and
            # non-`source` URLs inside the JSON are also captured — confirm intended.
            for u in url_re.findall(content):
                citations.append(_clean_url(u))

    # De-duplicate, preserving first-seen order
    seen = set()
    out = []
    for c in citations:
        # drop clearly broken/truncated URLs
        if len(c) < 12:
            continue
        if c not in seen:
            seen.add(c)
            out.append(c)
    return out
277
+
278
+
279
+ def _append_citations_section(brief_text: str, citations: List[str]) -> str:
280
+ """
281
+ Enforces "single source of truth":
282
+ - Removes any existing 'Citations' section the model may have produced
283
+ - Appends citations derived from tool outputs only
284
+ """
285
+ text = (brief_text or "").strip()
286
+
287
+ # Remove any model-generated citations section (best-effort)
288
+ # (handles '## Citations' or 'Citations' headers)
289
+ text = re.split(r"\n#{1,3}\s*Citations\s*\n|\nCitations\s*\n", text, maxsplit=1)[0].rstrip()
290
+
291
+ if citations:
292
+ lines = ["", "## Citations"]
293
+ for i, c in enumerate(citations, 1):
294
+ lines.append(f"{i}. {c}")
295
+ text = text + "\n" + "\n".join(lines)
296
+ else:
297
+ text = text + "\n\n## Citations\n- (No external sources retrieved.)"
298
+
299
+ return text
300
+
301
def capture_diagram(state: PharmAIState) -> Dict[str, Any]:
    """Lift the most recent diagram-related tool output into dedicated state keys."""
    # Find the last ToolMessage (most recent tool output)
    last_tool = None
    for m in reversed(state["messages"]):
        if isinstance(m, ToolMessage):
            last_tool = m
            break

    if not last_tool:
        return {}

    tool_name = getattr(last_tool, "name", "") or ""
    content = getattr(last_tool, "content", "")

    # If your render tool returns base64 string directly
    if tool_name == "render_dot_to_png_base64":
        # NOTE(review): the wrapper is annotated to return a Dict, so `content` here
        # is likely its serialized form rather than bare base64 — confirm consumers.
        return {"diagram_png_base64": content}

    # If your generate_graph_dot returns dot string
    if tool_name == "generate_graph_dot":
        return {"diagram_dot": content}

    # Any other tool: nothing diagram-related to capture
    return {}
324
+
325
def route_after_tools(state: "PharmAIState") -> str:
    """After capturing tool output: finish if the PNG artifact exists, else keep looping."""
    return END if state.get("diagram_png_base64") else "bump_tool_loop"
330
+
331
def preprocess(state: "PharmAIState") -> "Dict[str, Any]":
    """Classify the raw query into an intent: 'diagram', 'simple', or 'diligence'."""
    query = (state.get("user_query") or "").strip().lower()

    diagram_markers = ["diagram", "flowchart", "architecture", "graphviz", "dot", "draw"]
    if any(marker in query for marker in diagram_markers):
        return {"intent": "diagram"}

    # Short definitional questions skip tools entirely
    if re.match(r"^(what is|define|explain)\b", query) and len(query) < 120:
        return {"intent": "simple"}

    return {"intent": "diligence"}
341
+
342
def route_after_llm(state: "PharmAIState"):
    """Route after an LLM turn: simple queries end, pending tool calls loop, else synthesize."""
    if state.get("intent") == "simple":
        return "end_simple"

    newest = state["messages"][-1]
    return "tools" if getattr(newest, "tool_calls", None) else "synthesize"
353
+
354
def end_simple(state: "PharmAIState") -> "Dict[str, Any]":
    """Finish a tool-free turn: expose the last assistant message as the final brief."""
    newest = state["messages"][-1]
    raw = getattr(newest, "content", "")
    text = raw if isinstance(raw, str) else str(raw)
    return {"decision_brief": text, "citations": []}
359
+
360
+
361
+ # -----------------------------
362
+ # Final Synthesis Node
363
+ # -----------------------------
364
def synthesize(state: PharmAIState) -> Dict[str, Any]:
    """
    Final node: produce the decision brief and attach tool-only citations.

    Short-circuits on placeholder queries ("Drug X") so no tokens are spent;
    otherwise asks the model to write the structured brief and appends the
    system-managed Citations section.
    """
    # Fast guardrail: placeholders -> short response without tool burn
    uq = state.get("user_query", "")
    if _looks_like_placeholder(uq):
        brief = (
            "# FINAL DECISION BRIEF\n\n"
            "I need the **actual drug name** and **specific indication** to perform diligence.\n\n"
            "## Evidence Gaps\n"
            "- Drug name (e.g., semaglutide)\n"
            "- Indication (e.g., obesity)\n"
            "- Trial/program context (if any)\n"
        )
        return {
            "decision_brief": _append_citations_section(brief, []),
            "citations": [],
            "messages": [HumanMessage(content="(placeholder query detected; returned guardrail response)")],
        }

    llm = _build_model()
    messages: List[BaseMessage] = state["messages"]
    # The synthesis instruction is appended as the final user turn
    messages = messages + [HumanMessage(content=FINAL_PROMPT)]

    resp = llm.invoke(messages)

    # Citations come exclusively from prior tool outputs, never from the model
    tool_citations = _extract_citations_from_messages(state["messages"])
    brief_text = resp.content if isinstance(resp.content, str) else str(resp.content)
    brief_text = _append_citations_section(brief_text, tool_citations)

    return {
        "decision_brief": brief_text,
        "citations": tool_citations,
        "messages": [resp],
    }
397
+
398
+
399
+ # -----------------------------
400
+ # Build + Compile Graph
401
+ # -----------------------------
402
def build_graph():
    """
    Graph with preprocessing and smart routing.

    Topology:
      START -> preprocess -> llm_call
      llm_call -> tools | synthesize | end_simple      (route_after_llm)
      tools -> capture_diagram -> END | bump_tool_loop (route_after_tools)
      bump_tool_loop -> llm_call                       (capped via MAX_TOOL_LOOPS in llm_call)
      synthesize / end_simple -> END
    """
    g = StateGraph(PharmAIState)

    g.add_node("preprocess", preprocess)
    g.add_node("llm_call", llm_call)
    g.add_node("tools", ToolNode(TOOLS))
    g.add_node("capture_diagram", capture_diagram)
    # Increment the safety counter each time control returns from tools
    g.add_node("bump_tool_loop", lambda s: {"tool_loops": s.get("tool_loops", 0) + 1})
    g.add_node("synthesize", synthesize)
    g.add_node("end_simple", end_simple)

    g.add_edge(START, "preprocess")
    g.add_edge("preprocess", "llm_call")

    # After LLM: route based on intent and tool calls
    g.add_conditional_edges(
        "llm_call",
        route_after_llm,
        {
            "tools": "tools",
            "synthesize": "synthesize",
            "end_simple": "end_simple",
        },
    )

    # After tools: capture diagram data
    g.add_edge("tools", "capture_diagram")

    # After capture: check if we should stop (diagram complete) or continue
    g.add_conditional_edges(
        "capture_diagram",
        route_after_tools,
        {
            END: END,  # Stop if diagram is complete
            "bump_tool_loop": "bump_tool_loop",  # Continue otherwise
        },
    )

    g.add_edge("bump_tool_loop", "llm_call")
    g.add_edge("end_simple", END)
    g.add_edge("synthesize", END)

    return g.compile()
448
+
449
+ # -----------------------------
450
+ # Test execution
451
+ # -----------------------------
452
if __name__ == "__main__":
    # Manual smoke test: build the graph and run one query end-to-end.
    print("Building PharmAI Navigator graph...")
    graph = build_graph()
    print("Graph compiled successfully!")

    # Alternate queries (commented) exercise the diligence and diagram paths
    #test_query = "Assess semaglutide for obesity"
    #test_query = "Assess donanemab for early Alzheimer’s disease. Retrieve key clinical trials, summarize efficacy and safety outcomes, normalize the evidence, and generate a system architecture graph showing how PharmAI Navigator evaluates this asset."
    #test_query = "Create a DOT graph showing the relationship between Drug, Indication, Clinical Trials, FDA Approval, and Market Launch and render it as png"
    test_query = "What is pembrolizumab?"
    print(f"\nRunning test query: {test_query}")

    result = graph.invoke({
        "messages": [HumanMessage(content=test_query)],
        "user_query": test_query,
        "tool_loops": 0,
    })

    print("\n" + "=" * 60)
    print("OUTPUT:")
    print("=" * 60)
    print(result.get("decision_brief", "No output"))

    print("\n" + "=" * 60)
    print("CITATIONS (tool-only):")
    print("=" * 60)
    for i, citation in enumerate(result.get("citations", []), 1):
        print(f"{i}. {citation}")
memory.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Dict, List, Optional
4
+ import time
5
+ import threading
6
+ from schemas import Message
7
+
8
@dataclass
class SessionMemory:
    """In-memory chat history for a single session."""
    messages: List[Message]  # ordered oldest -> newest
    updated_at: float  # time.time() of last mutation; drives TTL cleanup
13
+
14
+
15
class MemoryStore:
    """
    Thread-safe, process-local chat-history store.

    Maps session_id -> SessionMemory, trims each session to ``max_messages``,
    and (when ``ttl_seconds`` is set) garbage-collects sessions idle past the TTL.
    """

    def __init__(self, max_messages: int = 30, ttl_seconds: Optional[int] = None):
        self.max_messages = max_messages
        self.ttl_seconds = ttl_seconds
        self._lock = threading.Lock()
        self._store: "Dict[str, SessionMemory]" = {}

    def _now(self) -> float:
        return time.time()

    def _ensure_locked(self, session_id: str) -> "SessionMemory":
        # Create-on-demand; caller must hold self._lock.
        if session_id not in self._store:
            self._store[session_id] = SessionMemory(messages=[], updated_at=self._now())
        return self._store[session_id]

    def get(self, session_id: str) -> "List[Message]":
        """Return a copy of the session's messages ([] for a falsy session_id)."""
        if not session_id:
            return []
        with self._lock:
            self._gc_locked()
            entry = self._ensure_locked(session_id)
            return list(entry.messages)

    def append(self, session_id: str, role: str, content: str) -> None:
        """Add one message, refresh the timestamp, and trim to the cap."""
        if not session_id:
            return
        with self._lock:
            self._gc_locked()
            entry = self._ensure_locked(session_id)

            entry.messages.append(Message(role=role, content=content))
            entry.updated_at = self._now()

            # Drop the oldest messages, keeping the most recent max_messages
            if len(entry.messages) > self.max_messages:
                entry.messages = entry.messages[len(entry.messages) - self.max_messages:]

    def set_messages(self, session_id: str, messages: "List[Message]") -> None:
        """Replace session history entirely, keeping only the newest entries."""
        if not session_id:
            return
        with self._lock:
            self._store[session_id] = SessionMemory(
                messages=messages[-self.max_messages:],
                updated_at=self._now(),
            )

    def clear(self, session_id: str) -> None:
        """Forget a single session entirely."""
        if not session_id:
            return
        with self._lock:
            self._store.pop(session_id, None)

    def _gc_locked(self) -> None:
        """Drop sessions idle past the TTL (no-op without one); caller must hold self._lock."""
        if not self.ttl_seconds:
            return
        cutoff = self._now() - self.ttl_seconds
        for sid in [s for s, mem in self._store.items() if mem.updated_at < cutoff]:
            self._store.pop(sid, None)
85
+
86
+
87
import os

# Global singleton (simple for HF Spaces demo); limits configured via env vars.
# Replaced the `__import__("os")` hack with a normal import.
memory_store = MemoryStore(
    max_messages=int(os.getenv("MAX_SESSION_MESSAGES", "30")),
    # A TTL of 0 (the default) disables time-based cleanup entirely
    ttl_seconds=int(os.getenv("SESSION_TTL_SECONDS", "0")) or None,
)
memory_mongo.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MongoDB-backed session memory store.
3
+ Replaces in-memory storage with persistent MongoDB storage.
4
+ """
5
+ from __future__ import annotations
6
+ from typing import List, Optional
7
+ import os
8
+ from datetime import datetime, timedelta
9
+ from pymongo import MongoClient, ASCENDING
10
+ from pymongo.errors import ConnectionFailure, OperationFailure
11
+ from schemas import Message
12
+ from dotenv import load_dotenv
13
+
14
+ #load env vars
15
+ load_dotenv()
16
+
17
+ class MongoMemoryStore:
18
+ """
19
+ MongoDB-backed session memory store.
20
+
21
+ Schema:
22
+ {
23
+ "_id": "session_id",
24
+ "messages": [
25
+ {"role": "user", "content": "..."},
26
+ {"role": "assistant", "content": "..."}
27
+ ],
28
+ "updated_at": datetime,
29
+ "created_at": datetime
30
+ }
31
+ """
32
+
33
    def __init__(
        self,
        mongo_uri: Optional[str] = None,
        database_name: str = "pharmai",
        collection_name: str = "sessions",
        max_messages: int = 30,
        ttl_seconds: Optional[int] = None,
    ):
        """
        Connect to MongoDB and prepare the sessions collection.

        Args:
            mongo_uri: connection string; falls back to the MONGO_URI env var.
            database_name: target database name.
            collection_name: target collection name.
            max_messages: per-session history cap enforced on append.
            ttl_seconds: when set, sessions auto-expire via a Mongo TTL index.

        Raises:
            ValueError: when no Mongo URI is available.
            ConnectionError: when the initial ping fails.
        """
        self.max_messages = max_messages
        self.ttl_seconds = ttl_seconds

        # Get MongoDB URI from env or parameter
        self.mongo_uri = mongo_uri or os.getenv("MONGO_URI")
        if not self.mongo_uri:
            raise ValueError("MONGO_URI not found in environment variables")

        # Connect to MongoDB (fail fast: 5s server-selection timeout)
        try:
            self.client = MongoClient(self.mongo_uri, serverSelectionTimeoutMS=5000)
            # Test connection
            self.client.admin.command('ping')
            print(f"✅ MongoDB connected: {database_name}.{collection_name}")
        except ConnectionFailure as e:
            raise ConnectionError(f"Failed to connect to MongoDB: {e}")

        self.db = self.client[database_name]
        self.collection = self.db[collection_name]

        # Create indexes
        self._create_indexes()
63
+
64
    def _create_indexes(self):
        """Create indexes for performance and TTL.

        Idempotent: inspects existing indexes first so repeated startups do
        not fail on conflicting definitions. With ``ttl_seconds`` set, a TTL
        index on ``updated_at`` lets MongoDB expire stale sessions
        automatically; otherwise a plain ``updated_at`` index supports
        cleanup queries. Failures are logged but never fatal.
        """
        try:
            # Get existing indexes
            existing_indexes = self.collection.index_information()

            # TTL index - automatically delete old sessions
            if self.ttl_seconds:
                # Check if any TTL index exists (any index carrying an
                # expireAfterSeconds option)
                ttl_exists = any(
                    idx.get("expireAfterSeconds") is not None
                    for idx in existing_indexes.values()
                )

                if not ttl_exists:
                    # Drop the basic updated_at index if it exists (without TTL);
                    # an index on the same key would conflict with the TTL one.
                    if "updated_at_1" in existing_indexes:
                        self.collection.drop_index("updated_at_1")

                    # Create TTL index
                    self.collection.create_index(
                        [("updated_at", ASCENDING)],
                        expireAfterSeconds=self.ttl_seconds,
                        name="session_ttl"
                    )
                    print(f"✅ Created TTL index (expires after {self.ttl_seconds}s)")
            else:
                # Just a regular index on updated_at (no TTL)
                if "updated_at_1" not in existing_indexes and "session_ttl" not in existing_indexes:
                    self.collection.create_index([("updated_at", ASCENDING)])
                    print("✅ Created updated_at index")

        except OperationFailure as e:
            # Index creation failed, but continue anyway — indexes are an
            # optimization; the store still works without them.
            print(f"⚠️ Index creation warning: {e}")
            pass
100
+
101
+ def get(self, session_id: str) -> List[Message]:
102
+ """Get messages for a session."""
103
+ if not session_id:
104
+ return []
105
+
106
+ try:
107
+ doc = self.collection.find_one({"_id": session_id})
108
+ if not doc:
109
+ return []
110
+
111
+ # Convert dict messages to Message objects
112
+ messages = []
113
+ for msg in doc.get("messages", []):
114
+ messages.append(Message(
115
+ role=msg.get("role", "user"),
116
+ content=msg.get("content", "")
117
+ ))
118
+
119
+ return messages
120
+ except OperationFailure as e:
121
+ print(f"Error getting session {session_id}: {e}")
122
+ return []
123
+
124
+ def append(self, session_id: str, role: str, content: str) -> None:
125
+ """Append a message to a session."""
126
+ if not session_id:
127
+ return
128
+
129
+ now = datetime.utcnow()
130
+ message = {"role": role, "content": content}
131
+
132
+ try:
133
+ # Try to update existing session
134
+ result = self.collection.update_one(
135
+ {"_id": session_id},
136
+ {
137
+ "$push": {"messages": message},
138
+ "$set": {"updated_at": now}
139
+ }
140
+ )
141
+
142
+ # If session doesn't exist, create it
143
+ if result.matched_count == 0:
144
+ self.collection.insert_one({
145
+ "_id": session_id,
146
+ "messages": [message],
147
+ "created_at": now,
148
+ "updated_at": now
149
+ })
150
+
151
+ # Trim old messages if needed
152
+ self._trim_messages(session_id)
153
+
154
+ except OperationFailure as e:
155
+ print(f"Error appending to session {session_id}: {e}")
156
+
157
+ def _trim_messages(self, session_id: str) -> None:
158
+ """Keep only the most recent max_messages."""
159
+ try:
160
+ doc = self.collection.find_one({"_id": session_id})
161
+ if not doc:
162
+ return
163
+
164
+ messages = doc.get("messages", [])
165
+ if len(messages) > self.max_messages:
166
+ # Keep only the most recent messages
167
+ trimmed = messages[-self.max_messages:]
168
+ self.collection.update_one(
169
+ {"_id": session_id},
170
+ {"$set": {"messages": trimmed}}
171
+ )
172
+ except OperationFailure as e:
173
+ print(f"Error trimming session {session_id}: {e}")
174
+
175
+ def set_messages(self, session_id: str, messages: List[Message]) -> None:
176
+ """Replace session history entirely."""
177
+ if not session_id:
178
+ return
179
+
180
+ now = datetime.utcnow()
181
+ message_dicts = [{"role": m.role, "content": m.content} for m in messages]
182
+
183
+ # Keep only most recent messages
184
+ if len(message_dicts) > self.max_messages:
185
+ message_dicts = message_dicts[-self.max_messages:]
186
+
187
+ try:
188
+ self.collection.update_one(
189
+ {"_id": session_id},
190
+ {
191
+ "$set": {
192
+ "messages": message_dicts,
193
+ "updated_at": now
194
+ },
195
+ "$setOnInsert": {"created_at": now}
196
+ },
197
+ upsert=True
198
+ )
199
+ except OperationFailure as e:
200
+ print(f"Error setting messages for session {session_id}: {e}")
201
+
202
+ def clear(self, session_id: str) -> None:
203
+ """Clear a single session."""
204
+ if not session_id:
205
+ return
206
+
207
+ try:
208
+ self.collection.delete_one({"_id": session_id})
209
+ except OperationFailure as e:
210
+ print(f"Error clearing session {session_id}: {e}")
211
+
212
+ def cleanup_old_sessions(self, days: int = 7) -> int:
213
+ """
214
+ Manually cleanup sessions older than X days.
215
+ (TTL index handles this automatically if configured)
216
+ """
217
+ cutoff = datetime.utcnow() - timedelta(days=days)
218
+ try:
219
+ result = self.collection.delete_many({"updated_at": {"$lt": cutoff}})
220
+ return result.deleted_count
221
+ except OperationFailure as e:
222
+ print(f"Error cleaning up old sessions: {e}")
223
+ return 0
224
+
225
+ def get_session_count(self) -> int:
226
+ """Get total number of active sessions."""
227
+ try:
228
+ return self.collection.count_documents({})
229
+ except OperationFailure:
230
+ return 0
231
+
232
+ def close(self):
233
+ """Close MongoDB connection."""
234
+ if self.client:
235
+ self.client.close()
236
+
237
+
238
# Create global singleton
def create_memory_store() -> MongoMemoryStore:
    """Factory: prefer MongoDB-backed storage, fall back to in-memory.

    Reads MAX_SESSION_MESSAGES and SESSION_TTL_SECONDS from the
    environment; a TTL of 0 (or unset) disables expiry. Note the fallback
    returns a `memory.MemoryStore`, which is duck-type compatible with
    the annotated MongoMemoryStore return type.
    """
    max_msgs = int(os.getenv("MAX_SESSION_MESSAGES", "30"))
    ttl = int(os.getenv("SESSION_TTL_SECONDS", "0")) or None

    try:
        # Try MongoDB first
        return MongoMemoryStore(max_messages=max_msgs, ttl_seconds=ttl)
    except (ValueError, ConnectionError) as e:
        print(f"⚠️ MongoDB not available: {e}")
        print("⚠️ Falling back to in-memory storage")

    # Fallback to in-memory storage with the same limits.
    from memory import MemoryStore
    return MemoryStore(max_messages=max_msgs, ttl_seconds=ttl)
257
+
258
+
259
# Global instance
# Module-level singleton used by the rest of the app; created at import
# time, so importing this module connects to MongoDB (or falls back).
memory_store = create_memory_store()
requirements.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FastAPI and server
2
+ fastapi
3
+ uvicorn
4
+ python-dotenv
5
+
6
+ # LangChain and LangGraph
7
+ langchain
8
+ langchain-anthropic
9
+ langchain-core
10
+ langchain-community
11
+ langgraph
12
+ langgraph-checkpoint
13
+ langsmith
14
+
15
+ # Tools and utilities
16
+ tavily-python
17
+ pydantic
18
+ pydantic-settings
19
+
20
+ # MongoDB
21
+ pymongo
22
+ motor
23
+
24
+ # HTTP client
25
+ httpx
26
+ aiohttp
27
+
28
+ # Graph rendering (optional)
29
+ graphviz
30
+
31
+ # Other dependencies
32
+ python-multipart
schemas.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # schemas.py
2
+ from typing import List, Optional, Dict, Any
3
+ from pydantic import BaseModel, Field
4
+ from enum import Enum
5
+
6
+ # Core Enums
7
class AgentType(str, Enum):
    """Identifier for each agent role in the multi-agent pipeline.

    Inherits from str so values serialize directly in JSON / dict keys.
    """
    PLANNER = "planner"
    SCIENTIFIC = "scientific"
    PATENT = "patent"
    MARKET = "market"
    SUPPLY = "supply"
    SYNTHESIS = "synthesis"
14
+
15
class EvidenceType(str, Enum):
    """Category of a piece of evidence produced by tools or agents.

    Inherits from str so values serialize directly in JSON.
    """
    LITERATURE = "literature"
    CLINICAL_TRIAL = "clinical_trial"
    PATENT = "patent"
    MARKET = "market"
    OTHER = "other"
21
+
22
+ # API Schemas (FastAPI I/O)
23
class AgentRunRequest(BaseModel):
    """
    Incoming request from Node.js backend or direct API call.

    `session_id` is optional; when supplied, prior conversation history
    for that session is loaded and updated. `query` is required.
    """
    session_id: Optional[str] = Field(
        default=None,
        description="Optional session ID to maintain conversation state"
    )
    query: str = Field(
        ...,
        description="User query, e.g. 'Drug X for Indication Y'"
    )
35
+
36
class AgentRunResponse(BaseModel):
    """
    Final response returned by the agent system.

    `decision_brief` is the synthesized answer; the remaining fields are
    optional enrichments (confidence, citations, trace metadata).
    """
    session_id: Optional[str]
    decision_brief: str
    confidence_score: Optional[float] = Field(
        default=None,
        description="Optional overall confidence score (0–1)"
    )
    citations: Optional[List[str]] = Field(
        default=None,
        description="List of citation identifiers or URLs"
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Extra debug or trace metadata"
    )
54
+
55
+ # Internal Agent State
56
class Message(BaseModel):
    """
    Canonical message format passed between agents.
    """
    role: str  # system | user | assistant | tool
    content: str  # plain-text message body
62
+
63
class EvidenceItem(BaseModel):
    """
    A single piece of evidence produced by tools or agents.
    """
    type: EvidenceType  # category of evidence (literature, trial, ...)
    source: str  # identifier or URL of where the evidence came from
    summary: str  # short human-readable digest
    confidence: Optional[float] = None  # 0–1 when a tool can estimate it
    raw: Optional[Dict[str, Any]] = None  # original tool payload, if kept
72
+
73
class AgentOutput(BaseModel):
    """
    Output produced by a single agent.
    """
    agent: AgentType  # which agent produced this output
    text: str  # the agent's narrative answer
    evidence: Optional[List[EvidenceItem]] = None  # supporting evidence, if any
80
+
81
class AgentState(BaseModel):
    """
    LangGraph state object.
    This is what flows between graph nodes.

    Nodes read/extend `messages` and write into `agent_outputs`;
    the synthesis step fills `final_decision` and `confidence_score`.
    """
    session_id: Optional[str]
    user_query: str

    # Running conversation shared across nodes.
    messages: List[Message] = Field(default_factory=list)

    agent_outputs: Dict[AgentType, AgentOutput] = Field(
        default_factory=dict,
        description="Outputs from each agent"
    )

    # Populated by the synthesis/final node.
    final_decision: Optional[str] = None

    confidence_score: Optional[float] = None
tools.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Optional
2
+ import os
3
+ import uuid
4
+ import re
5
+ import base64
6
+ from schemas import EvidenceItem, EvidenceType
7
+
8
+ #Helper Functions
9
+ def _etype(name: str, default: EvidenceType) -> EvidenceType:
10
+ """Return EvidenceType.<name> if it exists, else default (prevents breaking)."""
11
+ return getattr(EvidenceType, name, default)
12
+
13
+ def _short(s: str, n: int = 700) -> str:
14
+ return (s or "")[:n]
15
+
16
+ def _is_url(s: str) -> bool:
17
+ return isinstance(s, str) and s.startswith(("http://", "https://"))
18
+
19
+ # Tool 1: Tavily Web Search (existing, unchanged)
20
+ def tavily_search(query: str, max_results: int = 5) -> List[EvidenceItem]:
21
+ """
22
+ Uses Tavily API to perform web search.
23
+ Returns structured evidence items.
24
+ """
25
+ api_key = os.getenv("TAVILY_API_KEY")
26
+
27
+ if not api_key:
28
+ return [
29
+ EvidenceItem(
30
+ type=EvidenceType.OTHER,
31
+ source="tavily_disabled",
32
+ summary="Tavily API key not configured; search skipped.",
33
+ confidence=0.0,
34
+ )
35
+ ]
36
+
37
+ try:
38
+ from tavily import TavilyClient
39
+
40
+ client = TavilyClient(api_key=api_key)
41
+ results = client.search(
42
+ query=query,
43
+ max_results=max_results,
44
+ include_raw_content=False,
45
+ )
46
+
47
+ evidence: List[EvidenceItem] = []
48
+
49
+ for r in results.get("results", []):
50
+ evidence.append(
51
+ EvidenceItem(
52
+ type=EvidenceType.LITERATURE,
53
+ source=r.get("url", "unknown"),
54
+ summary=r.get("content", "")[:500],
55
+ confidence=0.6,
56
+ raw=r,
57
+ )
58
+ )
59
+
60
+ return evidence
61
+
62
+ except Exception as e:
63
+ return [
64
+ EvidenceItem(
65
+ type=EvidenceType.OTHER,
66
+ source="tavily_error",
67
+ summary=f"Tavily search failed: {str(e)}",
68
+ confidence=0.0,
69
+ )
70
+ ]
71
+
72
+ # Tool 2: Stub Evidence Generator (existing, unchanged)
73
+ def stub_evidence(query: str) -> List[EvidenceItem]:
74
+ """
75
+ Deterministic fallback tool.
76
+ Useful for demos, offline mode, or testing agent logic.
77
+ """
78
+ return [
79
+ EvidenceItem(
80
+ type=EvidenceType.OTHER,
81
+ source="stub_tool",
82
+ summary=f"Stub evidence generated for query: '{query}'. "
83
+ f"This indicates where real retrieval will plug in.",
84
+ confidence=0.2,
85
+ raw={
86
+ "id": str(uuid.uuid4()),
87
+ "note": "Replace with real retrieval later",
88
+ },
89
+ )
90
+ ]
91
+
92
+ # Tool 3: Query Classifier (planner helper)
93
+ def classify_query(query: str) -> Dict[str, Any]:
94
+ """
95
+ Lightweight classifier to help the agent decide
96
+ which tools (if any) are required.
97
+ """
98
+ q = (query or "").lower()
99
+ needs_graph = any(k in q for k in ["diagram", "graph", "graphviz", "dot", "flow", "architecture", "arch", "draw"])
100
+ needs_trials = any(k in q for k in ["trial", "clinical", "phase", "nct", "primary endpoint", "secondary endpoint"])
101
+ needs_facts = any(k in q for k in ["fda", "approval", "label", "patent", "exclusivity", "pricing", "aria", "safety", "market"])
102
+ needs_entities = any(k in q for k in ["evaluate", "assess", "analyze", "repurpose", "for "])
103
+ return {
104
+ "needs_graph": needs_graph,
105
+ "needs_clinical_trials": needs_trials,
106
+ "needs_web_search": needs_facts or needs_trials,
107
+ "needs_entity_extraction": needs_entities,
108
+ }
109
+
110
+ # Tool 4: Entity Extraction (Drug / Indication)
111
+ def extract_entities(query: str) -> Dict[str, Optional[str]]:
112
+ """
113
+ Minimal entity extractor for MVP.
114
+ """
115
+ text = (query or "").strip()
116
+ m = re.search(
117
+ r"(evaluate|assess|analyze)\s+(?P<drug>.+?)\s+for\s+(?P<indication>.+)",
118
+ text,
119
+ re.IGNORECASE,
120
+ )
121
+ if m:
122
+ return {
123
+ "drug": m.group("drug").strip(),
124
+ "indication": m.group("indication").strip(),
125
+ }
126
+ return {"drug": None, "indication": None}
127
+
128
+ # Tool 5: Evidence Normalizer (dedupe + cleanup)
129
+ def normalize_evidence(evidence: List[EvidenceItem]) -> List[EvidenceItem]:
130
+ """
131
+ Deduplicates evidence by source and trims noisy content.
132
+ """
133
+ seen = set()
134
+ cleaned: List[EvidenceItem] = []
135
+
136
+ for e in evidence:
137
+ if e.source in seen:
138
+ continue
139
+ seen.add(e.source)
140
+
141
+ cleaned.append(
142
+ EvidenceItem(
143
+ type=e.type,
144
+ source=e.source,
145
+ summary=(e.summary or "")[:800],
146
+ confidence=e.confidence,
147
+ raw=None, # drop heavy payloads
148
+ )
149
+ )
150
+
151
+ return cleaned
152
+
153
+ # Tool 6: Graph Generation (Graphviz DOT only)
154
+ def generate_graph_dot(
155
+ title: str,
156
+ nodes: List[Dict[str, str]],
157
+ edges: List[Dict[str, str]],
158
+ rankdir: str = "LR",
159
+ ) -> str:
160
+ """
161
+ Generates Graphviz DOT code.
162
+ IMPORTANT: LLM must call this tool; never output DOT directly.
163
+ """
164
+ safe_title = (title or "PharmAI Graph").replace('"', "'")
165
+
166
+ lines = [
167
+ "digraph G {",
168
+ f" rankdir={rankdir};",
169
+ ' labelloc="t";',
170
+ ' labeljust="c";',
171
+ f' label=<<B><FONT POINT-SIZE="28">{safe_title}</FONT></B>>;',
172
+ " node [shape=box, style=rounded];",
173
+ "",
174
+ ]
175
+
176
+ for n in nodes or []:
177
+ nid = n.get("id")
178
+ lbl = (n.get("label") or nid).replace('"', "'")
179
+ if nid:
180
+ lines.append(f' {nid} [label="{lbl}"];')
181
+
182
+ lines.append("")
183
+
184
+ for e in edges or []:
185
+ src = e.get("from")
186
+ tgt = e.get("to")
187
+ lbl = e.get("label")
188
+ if src and tgt:
189
+ if lbl:
190
+ lines.append(f' {src} -> {tgt} [label="{lbl}"];')
191
+ else:
192
+ lines.append(f" {src} -> {tgt};")
193
+
194
+ lines.append("}")
195
+ return "\n".join(lines)
196
+
197
+ #Tool 7: ClinicalTrials search (lightweight, Tavily-based)
198
+ def clinicaltrials_search(drug: str, indication: str, max_results: int = 5) -> List[EvidenceItem]:
199
+ """
200
+ MVP approach:
201
+ - Uses Tavily to target ClinicalTrials.gov / NCT IDs
202
+ - Returns EvidenceItems for trial links + snippets
203
+ """
204
+ drug = (drug or "").strip()
205
+ indication = (indication or "").strip()
206
+
207
+ if not drug or not indication:
208
+ return [
209
+ EvidenceItem(
210
+ type=EvidenceType.OTHER,
211
+ source="clinicaltrials_search_invalid_input",
212
+ summary="Missing drug or indication for clinical trials search.",
213
+ confidence=0.0,
214
+ )
215
+ ]
216
+
217
+ query = f'site:clinicaltrials.gov ("{drug}") ("{indication}") NCT'
218
+ ev = tavily_search(query=query, max_results=max_results)
219
+
220
+ trial_type = _etype("CLINICAL_TRIAL", EvidenceType.LITERATURE)
221
+
222
+ out: List[EvidenceItem] = []
223
+ for e in ev:
224
+ # only keep plausible CT.gov results if possible
225
+ if _is_url(e.source) and "clinicaltrials.gov" in e.source:
226
+ out.append(
227
+ EvidenceItem(
228
+ type=trial_type,
229
+ source=e.source,
230
+ summary=e.summary,
231
+ confidence=max(0.55, float(e.confidence or 0.55)),
232
+ raw=e.raw,
233
+ )
234
+ )
235
+
236
+ if out:
237
+ return out
238
+
239
+ # fallback: return whatever Tavily gave (still structured)
240
+ return [
241
+ EvidenceItem(
242
+ type=trial_type,
243
+ source=e.source,
244
+ summary=e.summary,
245
+ confidence=float(e.confidence or 0.4),
246
+ raw=e.raw,
247
+ )
248
+ for e in ev
249
+ ]
250
+
251
+ #Tool 8 : DOT -> PNG
252
+ def render_dot_to_png_base64(dot: str) -> Dict[str, Any]:
253
+ """
254
+ Converts DOT to PNG and returns base64 string.
255
+ - Uses python 'graphviz' package if available.
256
+ - If graphviz isn't installed in the environment, returns an error payload.
257
+ """
258
+ dot = (dot or "").strip()
259
+ if not dot:
260
+ return {"ok": False, "error": "Empty DOT string"}
261
+
262
+ try:
263
+ from graphviz import Source # optional dependency
264
+
265
+ src = Source(dot)
266
+ png_bytes = src.pipe(format="png")
267
+ b64 = base64.b64encode(png_bytes).decode("utf-8")
268
+ return {"ok": True, "png_base64": b64}
269
+ except Exception as e:
270
+ return {
271
+ "ok": False,
272
+ "error": f"DOT->PNG render failed. Ensure `graphviz` Python package and system binaries are installed. Details: {str(e)}",
273
+ }
274
+
275
+ # Tool Registry (extended, backward compatible)
276
+ TOOL_REGISTRY: Dict[str, Any] = {
277
+ # existing
278
+ "web_search": tavily_search,
279
+ "stub_evidence": stub_evidence,
280
+
281
+ # new
282
+ "classify_query": classify_query,
283
+ "extract_entities": extract_entities,
284
+ "normalize_evidence": normalize_evidence,
285
+ "generate_graph_dot": generate_graph_dot,
286
+ "clinicaltrials_search": clinicaltrials_search,
287
+ "render_dot_to_png_base64": render_dot_to_png_base64
288
+ }