File size: 5,326 Bytes
1de0a51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from __future__ import annotations
import uuid
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from schemas import AgentRunRequest, AgentRunResponse, Message
from memory_mongo import memory_store  # MongoDB-backed memory
from graph import build_graph
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from fastapi.responses import StreamingResponse
import json
import time
from fastapi.encoders import jsonable_encoder

# Application instance; title/version appear in the auto-generated OpenAPI docs.
app = FastAPI(title="PharmAI Navigator (Agentic)", version="0.1.0")

# CORS (HF Spaces + your Node proxy)
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers per the CORS spec (credentialed requests may not use a
# wildcard origin) — confirm whether credentials are actually needed, or pin
# the origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Compile graph once at startup
# Compiling is comparatively expensive; the compiled graph is stateless per
# invocation, so a single module-level instance is shared by all requests.
GRAPH = build_graph()


@app.get("/health")
def health():
    """Health check reporting MongoDB connectivity and the active session count.

    Always returns HTTP 200 with ``status: ok``; a MongoDB failure is surfaced
    in the ``mongodb`` field rather than as an error response.
    """
    sessions = 0
    try:
        sessions = memory_store.get_session_count()
        mongo_state = "connected"
    except Exception as exc:
        # Report the failure in the payload instead of raising.
        mongo_state = f"error: {str(exc)}"

    return {
        "status": "ok",
        "mongodb": mongo_state,
        "active_sessions": sessions,
    }


@app.get("/session/{session_id}/history")
def get_session_history(session_id: str):
    """Get chat history for a session (for testing)."""
    history = memory_store.get(session_id)

    def _preview(text: str) -> str:
        # Truncate long message bodies so the response stays readable.
        if len(text) > 100:
            return text[:100] + "..."
        return text

    return {
        "session_id": session_id,
        "message_count": len(history),
        "messages": [
            {"role": msg.role, "content": _preview(msg.content)}
            for msg in history
        ],
    }


@app.delete("/session/{session_id}")
def clear_session(session_id: str):
    """Delete all stored history for *session_id* (testing helper)."""
    # Delegate to the store; returns a small confirmation payload.
    memory_store.clear(session_id)
    return {"session_id": session_id, "status": "cleared"}


@app.post("/admin/cleanup-sessions")
def cleanup_old_sessions(days: int = 7):
    """Manually delete sessions older than *days* days.

    Admin endpoint; a MongoDB TTL index handles expiry automatically if
    configured, so this exists as a manual fallback.

    Args:
        days: Age threshold in days (query parameter, default 7).

    Returns:
        Dict with the number of deleted sessions and the threshold used.

    Raises:
        HTTPException: 500 if the memory store operation fails.
    """
    try:
        deleted = memory_store.cleanup_old_sessions(days=days)
    except Exception as e:
        # Chain the original exception so the real cause survives in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "ok",
        "deleted_sessions": deleted,
        "days": days,
    }


@app.post("/test/echo")
def test_echo(req: AgentRunRequest):
    """Memory-only smoke test: echoes the query back without any LLM calls.

    Exercises session creation, history reads, and message appends.
    """
    sid = req.session_id or str(uuid.uuid4())

    # Snapshot history BEFORE appending, so the count reflects prior turns.
    prior = memory_store.get(sid)

    # Record the incoming user turn.
    memory_store.append(sid, role="user", content=req.query)

    # Build and record a canned assistant reply.
    fake_response = f"Echo: {req.query} (Session has {len(prior)} prior messages)"
    memory_store.append(sid, role="assistant", content=fake_response)

    return {
        "session_id": sid,
        "decision_brief": fake_response,
        "prior_message_count": len(prior),
        "current_message_count": len(memory_store.get(sid)),
        "citations": [],
        "metadata": {"test_mode": True},
    }


# Maps stored memory roles onto their LangChain message classes; roles not
# listed here (if any exist in the store) are silently skipped, matching the
# previous if/elif behavior.
_ROLE_TO_MESSAGE = {
    "user": HumanMessage,
    "assistant": AIMessage,
    "system": SystemMessage,
}


@app.post("/run", response_model=AgentRunResponse)
def run_agent(req: AgentRunRequest):
    """Run the agent graph synchronously for one user query (Mode A).

    Loads the session's prior history for chat continuity, records the new
    user turn, invokes the compiled graph, and persists the assistant's
    decision brief back to memory.

    Args:
        req: Request carrying ``query`` and an optional ``session_id``.

    Returns:
        AgentRunResponse with the decision brief, citations, and metadata.

    Raises:
        HTTPException: 500 if the graph invocation fails.
    """
    # 1) session handling: reuse the caller's session or mint a fresh one.
    session_id = req.session_id or str(uuid.uuid4())

    # 2) load prior history and convert to LangChain messages, which is the
    #    shape LangGraph's MessagesState expects in state["messages"].
    prior = memory_store.get(session_id)
    messages = [
        _ROLE_TO_MESSAGE[m.role](content=m.content)
        for m in prior
        if m.role in _ROLE_TO_MESSAGE
    ]

    # 3) append this user query to memory (pre-run) and to the message list.
    memory_store.append(session_id, role="user", content=req.query)
    messages.append(HumanMessage(content=req.query))

    # 4) run graph (Mode A synchronous)
    try:
        final_state = GRAPH.invoke(
            {
                "session_id": session_id,
                "user_query": req.query,
                "messages": messages,
            }
        )
    except Exception as e:
        # Chain the cause so the underlying graph error survives in logs.
        raise HTTPException(status_code=500, detail=f"Agent run failed: {str(e)}") from e

    # The graph may expose its answer under either key; fall back to "".
    decision_brief = final_state.get("decision_brief") or final_state.get("final_decision") or ""
    citations = final_state.get("citations") or []

    # 5) save assistant response to memory (post-run); skip blank briefs so
    #    empty turns don't pollute the history.
    if decision_brief.strip():
        memory_store.append(session_id, role="assistant", content=decision_brief)

    return AgentRunResponse(
        session_id=session_id,
        decision_brief=decision_brief,
        confidence_score=final_state.get("confidence_score"),
        citations=citations,
        metadata={
            "has_prior_messages": len(prior) > 0,
        },
    )