medintake-ai / app /main.py
priyansh-saxena1
fix: debug logging + real error messages + ROS guidance
11a7703
import argparse
import json
import os
import time
import traceback
from fastapi import FastAPI
from pydantic import BaseModel
from app.graph import build_graph
from app.schemas import ClinicalBrief
from langgraph.types import Command
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
class ChatRequest(BaseModel):
    """Request body accepted by ``POST /chat``."""

    # Conversation identifier; used as the graph checkpoint ``thread_id``.
    session_id: str
    # Text of the user's turn; forwarded to the graph as a "user" message.
    message: str
class ChatResponse(BaseModel):
    """Response body returned by ``POST /chat``."""

    # Latest assistant reply for the session, or an error description.
    reply: str
    # Frontend stage read from the checkpoint (e.g. "intake", "done").
    state: str
    # Clinical brief stored in the checkpoint, when one is available.
    brief: ClinicalBrief | None = None
# ASGI application object; the route decorators in this module register on it.
app = FastAPI(title="Clinical Intake Agent")
# Expose the files in app/static under the /static URL prefix.
# NOTE(review): the directory is a relative path, so this presumably expects
# the process to be started from the repository root - confirm.
app.mount("/static", StaticFiles(directory="app/static"), name="static")
@app.get("/")
async def root():
    """Serve the frontend's index page from the static directory."""
    index_page = "app/static/index.html"
    return FileResponse(index_page)
# Build the graph once, at import time. `graph` is the object every helper,
# the /chat handler and the CLI in this module read from and invoke;
# `checkpointer` is not referenced again in the visible part of this module.
graph, checkpointer = build_graph()
def get_current_node(session_id: str) -> str:
    """Return the frontend stage recorded in a session's checkpoint.

    Args:
        session_id: Conversation identifier, used as the checkpoint thread id.

    Returns:
        The ``frontend_stage`` value from the checkpointed state, or
        ``"intake"`` when the session has no state, the key is absent, or
        the checkpoint could not be read.
    """
    config = {"configurable": {"thread_id": session_id}}
    try:
        snapshot = graph.get_state(config)
        if snapshot and snapshot.values:
            return snapshot.values.get("frontend_stage", "intake")
    except Exception:
        # Still best-effort (callers always get a stage back), but leave a
        # trace so a failing checkpoint read is not mistaken for a new session.
        print(f"[{time.time():.3f}] [API] get_current_node failed for {session_id}; defaulting to 'intake'")
        print(traceback.format_exc())
    return "intake"
def get_last_reply(session_id: str) -> str:
    """Return the most recent assistant message stored for a session.

    Args:
        session_id: Conversation identifier, used as the checkpoint thread id.

    Returns:
        The ``content`` of the last checkpointed message whose ``role`` is
        ``"assistant"``, or ``""`` when there is no such message, no state,
        or the checkpoint could not be read.
    """
    config = {"configurable": {"thread_id": session_id}}
    try:
        snapshot = graph.get_state(config)
        if snapshot and snapshot.values:
            # `or []` guards against a "messages" key explicitly set to None.
            history = snapshot.values.get("messages") or []
            # Walk newest-to-oldest so the first assistant hit is the latest.
            for msg in reversed(history):
                if msg.get("role") == "assistant":
                    # Coerce a missing/None content to "" so the declared
                    # `str` return type holds for callers that slice it.
                    return msg.get("content") or ""
    except Exception:
        # Still best-effort (callers always get a string back), but leave a
        # trace so an empty reply caused by a read failure can be diagnosed.
        print(f"[{time.time():.3f}] [API] get_last_reply failed for {session_id}; returning empty reply")
        print(traceback.format_exc())
    return ""
def get_brief(session_id: str) -> dict | None:
    """Return the clinical brief stored in a session's checkpoint.

    Args:
        session_id: Conversation identifier, used as the checkpoint thread id.

    Returns:
        Whatever is stored under ``clinical_brief`` in the checkpointed
        state, or ``None`` when the session has no state, no brief has been
        stored, or the checkpoint could not be read.
    """
    config = {"configurable": {"thread_id": session_id}}
    try:
        snapshot = graph.get_state(config)
        if snapshot and snapshot.values:
            return snapshot.values.get("clinical_brief")
    except Exception:
        # Still best-effort (a missing brief is a normal outcome), but leave
        # a trace so a read failure is distinguishable from "no brief yet".
        print(f"[{time.time():.3f}] [API] get_brief failed for {session_id}; returning no brief")
        print(traceback.format_exc())
    return None
@app.get("/health")
async def health():
    """Report liveness and whether the MOCK_LLM environment flag is on.

    ``mock_mode`` is true only when ``MOCK_LLM`` equals "true"
    (case-insensitive); an unset variable counts as "false".
    """
    raw_flag = os.environ.get("MOCK_LLM", "false")
    return {"status": "ok", "mock_mode": raw_flag.lower() == "true"}
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Process one user turn for a session and return the agent's reply.

    Flow:
      1. Read the session's checkpoint.
      2. If the recorded stage is "done", return the stored reply and brief
         without invoking the graph again.
      3. If the checkpoint has pending next nodes, append the user message
         with ``update_state`` and resume with ``invoke(None)``; otherwise
         start a run with the user message as input.
      4. Read the stage, last assistant reply and brief back from the
         checkpoint and return them.

    Args:
        request: Session id and user message.

    Returns:
        A ``ChatResponse``. Unexpected errors are caught, printed with their
        traceback, and reported in ``reply``; in that case ``state`` is the
        last stage read from the checkpoint ("intake" if none was read).
    """
    t0 = time.time()
    print(f"\n[{t0:.3f}] [API] -> POST /chat received for {request.session_id}")
    # NOTE(review): this writes the first 80 characters of the user's message
    # to stdout; confirm that is acceptable for the data this service handles.
    print(f"[{t0:.3f}] [API] Message: '{request.message[:80]}'")
    config = {"configurable": {"thread_id": request.session_id}}
    # Last stage read from the checkpoint. Kept outside the try so the error
    # path can report it instead of resetting an in-progress session.
    current_stage = "intake"
    try:
        # Get current checkpoint state
        snapshot = graph.get_state(config)
        has_state = bool(snapshot and snapshot.values)
        has_next = bool(snapshot.next) if snapshot else False
        print(f"[{time.time():.3f}] [API] Snapshot: has_state={has_state}, has_next={has_next}, next={snapshot.next if snapshot else 'N/A'}")

        # Guard: if session is already complete, don't re-invoke the graph
        if has_state:
            current_stage = snapshot.values.get("frontend_stage", "intake")
        print(f"[{time.time():.3f}] [API] Current stage: {current_stage}")
        if current_stage == "done":
            print(f"[{time.time():.3f}] [API] Session already complete. Returning existing brief.")
            reply = get_last_reply(request.session_id)
            brief_dict = get_brief(request.session_id)
            return ChatResponse(
                reply=reply or "Your intake is already complete. Please start a new session.",
                state="done",
                brief=brief_dict,
            )

        # NOTE(review): graph.invoke is a synchronous call inside an async
        # handler, so the event loop is blocked while the graph runs.
        # Consider a threadpool or an async invoke once the checkpointer's
        # threading/async support has been confirmed.
        t_start_graph = time.time()
        user_turn = {"messages": [{"role": "user", "content": request.message}]}
        if has_next:
            # Pending next nodes: record the user message in the checkpoint,
            # then resume the run by invoking with no new input.
            print(f"[{time.time():.3f}] [API] Resuming graph from interrupt (next={snapshot.next})...")
            graph.update_state(config, user_turn)
            graph.invoke(None, config=config)
        else:
            print(f"[{time.time():.3f}] [API] Starting new graph invoke...")
            graph.invoke(user_turn, config=config)
        print(f"[{time.time():.3f}] [API] <- Graph invoke returned in {time.time() - t_start_graph:.2f}s")

        # The invoke result is not used; everything returned to the client is
        # re-read from the checkpoint so it reflects the persisted state.
        current_node = get_current_node(request.session_id)
        reply = get_last_reply(request.session_id)
        brief_dict = get_brief(request.session_id)
        total_t = time.time() - t0
        print(f"[{time.time():.3f}] [API] Chat completed in {total_t:.2f}s. Reply='{reply[:60]}' Stage={current_node}")
        return ChatResponse(reply=reply, state=current_node, brief=brief_dict)
    except Exception as e:
        tb = traceback.format_exc()
        print(f"[{time.time():.3f}] [API] *** EXCEPTION in /chat ***")
        print(tb)
        return ChatResponse(
            reply=f"Server error: {type(e).__name__}: {str(e)[:200]}",
            # Report the last known stage; `or` covers a stored None/empty
            # stage, which would otherwise fail response validation.
            state=current_stage or "intake",
            brief=None,
        )
def run_cli():
    """Run an interactive terminal session against the intake graph.

    Reads user turns from stdin, invokes the graph under the fixed thread id
    "cli_session", and prints the latest assistant reply after each turn.
    The loop ends on EOF, on Ctrl+C at the prompt, or once the stage is
    "done" and a brief exists, in which case the brief is printed as JSON.
    """
    print("=" * 60)
    print("Clinical Intake Agent - CLI Mode")
    print("=" * 60)
    print("Type your responses. The intake will end when complete.\n")
    session_id = "cli_session"
    # The thread id never changes during a CLI run, so build the config once.
    config = {"configurable": {"thread_id": session_id}}
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Treat Ctrl+C at the prompt like end-of-input: leave quietly.
            break
        if not user_input:
            continue
        # Only the new user turn is passed, whether or not the session already
        # has history; the graph's message reducer is expected to append it
        # to the checkpointed messages.
        graph.invoke(
            {"messages": [{"role": "user", "content": user_input}]},
            config=config,
        )
        # Read results back from the checkpoint rather than the invoke return.
        current_node = get_current_node(session_id)
        reply = get_last_reply(session_id)
        brief = get_brief(session_id)
        print(f"\nAgent: {reply}\n")
        if current_node == "done" and brief:
            print("=" * 60)
            print("CLINICAL INTAKE COMPLETE")
            print("=" * 60)
            print(json.dumps(brief, indent=2))
            break
if __name__ == "__main__":
    # Script entry point: `--cli` selects the terminal loop; otherwise the
    # FastAPI app is served with uvicorn on all interfaces, port 7860.
    cli_parser = argparse.ArgumentParser(description="Clinical Intake Agent")
    cli_parser.add_argument("--cli", action="store_true", help="Run in CLI mode")
    options = cli_parser.parse_args()
    if not options.cli:
        # Imported lazily so CLI mode does not require uvicorn.
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)
    else:
        run_cli()