File size: 6,685 Bytes
56808e1
 
 
11a7703
 
56808e1
 
 
 
 
 
 
058b7cd
 
11a7703
56808e1
 
 
 
 
 
 
 
 
 
 
 
 
 
058b7cd
 
 
 
 
 
 
56808e1
 
 
 
284dfa9
56808e1
 
 
 
284dfa9
56808e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b46033
 
11a7703
56808e1
11a7703
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb8adc6
 
11a7703
 
 
 
 
 
 
 
 
 
cb8adc6
11a7703
 
 
cb8adc6
 
56808e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import argparse
import json
import os
import time
import traceback

from fastapi import FastAPI
from pydantic import BaseModel

from app.graph import build_graph
from app.schemas import ClinicalBrief
from langgraph.types import Command
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse


class ChatRequest(BaseModel):
    # Request body accepted by POST /chat.

    # Conversation identifier; the handlers use it as the LangGraph thread_id
    # when reading and writing the session checkpoint.
    session_id: str
    # Text of the user's message for this turn.
    message: str


class ChatResponse(BaseModel):
    # Response body returned by POST /chat.

    # Latest assistant message, or a notice text when the session is already
    # complete or the request failed.
    reply: str
    # Frontend stage taken from the checkpoint's "frontend_stage" value
    # ("intake" when nothing is stored, "done" once the intake is complete).
    state: str
    # Clinical brief read from the checkpoint's "clinical_brief" key; None
    # when no brief is available.
    brief: ClinicalBrief | None = None


# Application object that the route decorators below register handlers on.
app = FastAPI(title="Clinical Intake Agent")

# Serve the contents of app/static under the /static URL prefix. The directory
# is given as a relative path, so it depends on the process's working directory.
app.mount("/static", StaticFiles(directory="app/static"), name="static")

@app.get("/")
async def root():
    # Landing page: answer GET / with the static index.html file.
    index_page = FileResponse("app/static/index.html")
    return index_page


# Build the graph once, at import time. The helpers, the /chat handler and the
# CLI loop all use this module-level `graph`; `checkpointer` is returned with
# it but is not referenced by the code visible in this file.
graph, checkpointer = build_graph()


def get_current_node(session_id: str) -> str:
    """Return the frontend stage stored in the session's checkpoint.

    Args:
        session_id: Conversation id, used as the LangGraph ``thread_id``.

    Returns:
        The checkpoint's ``"frontend_stage"`` value, or ``"intake"`` when the
        session has no stored state, the key is absent, or the lookup fails.
    """
    config = {"configurable": {"thread_id": session_id}}
    try:
        snapshot = graph.get_state(config)
        if snapshot and snapshot.values:
            return snapshot.values.get("frontend_stage", "intake")
    except Exception:
        # Best-effort read: keep the "intake" fallback, but report the failure
        # instead of swallowing it so checkpoint problems are visible.
        print(f"[{time.time():.3f}] [API] get_current_node failed for {session_id}")
        print(traceback.format_exc())
    return "intake"


def get_last_reply(session_id: str) -> str:
    """Return the most recent assistant message stored for the session.

    Messages are read from the checkpoint's ``"messages"`` list and are
    accessed with ``.get``, i.e. they are expected to be dict-like.

    Args:
        session_id: Conversation id, used as the LangGraph ``thread_id``.

    Returns:
        The ``"content"`` of the last message whose ``"role"`` is
        ``"assistant"`` (``""`` if that message has no content), or ``""``
        when there is no stored state, no assistant message, or the lookup
        fails.
    """
    config = {"configurable": {"thread_id": session_id}}
    try:
        snapshot = graph.get_state(config)
        if snapshot and snapshot.values:
            # Walk backwards so the first match is the newest assistant turn.
            for msg in reversed(snapshot.values.get("messages", [])):
                if msg.get("role") == "assistant":
                    return msg.get("content", "")
    except Exception:
        # Best-effort read: keep the empty-string fallback, but report the
        # failure instead of swallowing it.
        print(f"[{time.time():.3f}] [API] get_last_reply failed for {session_id}")
        print(traceback.format_exc())
    return ""


def get_brief(session_id: str) -> dict | None:
    """Get clinical brief from checkpoint."""
    config = {"configurable": {"thread_id": session_id}}
    try:
        snapshot = graph.get_state(config)
        if snapshot and snapshot.values:
            return snapshot.values.get("clinical_brief")
    except Exception:
        pass
    return None


@app.get("/health")
async def health():
    # Liveness probe. Also reports whether the MOCK_LLM environment variable
    # is set to "true" (compared case-insensitively; unset counts as false).
    raw_flag = os.environ.get("MOCK_LLM", "false")
    return {"status": "ok", "mock_mode": raw_flag.lower() == "true"}


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    # Handle one conversation turn for the session named by request.session_id.
    #
    # Flow: read the session checkpoint; if its stage is already "done", return
    # the stored reply and brief without running the graph; otherwise feed the
    # user's message to the graph (resuming when the snapshot reports pending
    # `next` nodes, starting a fresh invoke when it does not) and answer with
    # the latest assistant reply, stage and brief read back from the checkpoint.
    #
    # Failures are not raised to the client: any exception is printed with its
    # traceback and reported as a normal ChatResponse carrying the error text.
    #
    # NOTE(review): graph.invoke is a synchronous call made directly inside
    # this async handler — confirm that blocking here is acceptable.
    # NOTE(review): the first 80 characters of the user's message and 60 of the
    # reply are printed to stdout — confirm that is acceptable for this data.
    t0 = time.time()
    print(f"\n[{t0:.3f}] [API] -> POST /chat received for {request.session_id}")
    print(f"[{t0:.3f}] [API]    Message: '{request.message[:80]}'")
    config = {"configurable": {"thread_id": request.session_id}}
    # Last stage read from the checkpoint. Kept outside the try so the error
    # path can report it instead of pretending the session is back at "intake".
    current_stage = "intake"

    try:
        # Get current checkpoint state
        snapshot = graph.get_state(config)
        has_state = bool(snapshot and snapshot.values)
        has_next = bool(snapshot.next) if snapshot else False
        print(f"[{time.time():.3f}] [API] Snapshot: has_state={has_state}, has_next={has_next}, next={snapshot.next if snapshot else 'N/A'}")

        # Guard: if session is already complete, don't re-invoke the graph
        if has_state:
            current_stage = snapshot.values.get("frontend_stage", "intake")
        print(f"[{time.time():.3f}] [API] Current stage: {current_stage}")

        if current_stage == "done":
            print(f"[{time.time():.3f}] [API] Session already complete. Returning existing brief.")
            reply = get_last_reply(request.session_id)
            brief_dict = get_brief(request.session_id)
            return ChatResponse(
                reply=reply or "Your intake is already complete. Please start a new session.",
                state="done",
                brief=brief_dict
            )

        # The same single-message payload is used on both paths below.
        user_turn = {"messages": [{"role": "user", "content": request.message}]}

        # Check if graph is interrupted and waiting for input
        t_start_graph = time.time()
        if has_next:
            print(f"[{time.time():.3f}] [API] Resuming graph from interrupt (next={snapshot.next})...")
            # Record the user's message in the checkpoint, then continue the
            # paused run by invoking with no new input.
            graph.update_state(config, user_turn)
            graph.invoke(None, config=config)
        else:
            print(f"[{time.time():.3f}] [API] Starting new graph invoke...")
            graph.invoke(user_turn, config=config)
        print(f"[{time.time():.3f}] [API] <- Graph invoke returned in {time.time() - t_start_graph:.2f}s")

        # The invoke return value is not used; the response is assembled from
        # what the run persisted to the checkpoint.
        current_node = get_current_node(request.session_id)
        reply = get_last_reply(request.session_id)
        brief_dict = get_brief(request.session_id)

        total_t = time.time() - t0
        print(f"[{time.time():.3f}] [API] Chat completed in {total_t:.2f}s. Reply='{reply[:60]}' Stage={current_node}")

        return ChatResponse(reply=reply, state=current_node, brief=brief_dict)

    except Exception as e:
        tb = traceback.format_exc()
        print(f"[{time.time():.3f}] [API] *** EXCEPTION in /chat ***")
        print(tb)
        # Report the last stage read from the checkpoint. The isinstance guard
        # keeps this error response itself from failing validation if the
        # stored stage was not a string.
        return ChatResponse(
            reply=f"Server error: {type(e).__name__}: {str(e)[:200]}",
            state=current_stage if isinstance(current_stage, str) else "intake",
            brief=None
        )


def run_cli():
    """Run an interactive intake conversation on stdin/stdout.

    Each non-empty input line is sent to the graph under the fixed thread id
    ``"cli_session"`` and the latest assistant reply is printed. The loop ends
    on EOF or Ctrl-C at the prompt, or once the stage is ``"done"`` and a
    clinical brief is available, in which case the brief is printed as JSON.
    """
    print("=" * 60)
    print("Clinical Intake Agent - CLI Mode")
    print("=" * 60)
    print("Type your responses. The intake will end when complete.\n")

    session_id = "cli_session"
    config = {"configurable": {"thread_id": session_id}}

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C at the prompt: leave the loop quietly.
            break

        if not user_input:
            continue

        # Only the new user message is passed; earlier turns live in the
        # checkpoint for this thread.
        user_turn = {"messages": [{"role": "user", "content": user_input}]}

        # Same dispatch as the /chat handler: when the snapshot reports
        # pending `next` nodes, record the message and continue the paused
        # run; otherwise start an invoke with the message as input.
        snapshot = graph.get_state(config)
        if snapshot and snapshot.next:
            graph.update_state(config, user_turn)
            graph.invoke(None, config=config)
        else:
            graph.invoke(user_turn, config=config)

        current_node = get_current_node(session_id)
        reply = get_last_reply(session_id)
        brief = get_brief(session_id)

        print(f"\nAgent: {reply}\n")

        if current_node == "done" and brief:
            print("=" * 60)
            print("CLINICAL INTAKE COMPLETE")
            print("=" * 60)
            print(json.dumps(brief, indent=2))
            break

if __name__ == "__main__":
    # Script entry point: run the terminal loop with --cli, otherwise serve the
    # FastAPI app over HTTP. --host/--port default to the previously
    # hard-coded 0.0.0.0:7860, so an invocation without them is unchanged.
    parser = argparse.ArgumentParser(description="Clinical Intake Agent")
    parser.add_argument("--cli", action="store_true", help="Run in CLI mode")
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Interface the HTTP server binds to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port the HTTP server listens on (default: 7860)",
    )
    args = parser.parse_args()

    if args.cli:
        run_cli()
    else:
        # Imported here, so CLI mode never imports uvicorn.
        import uvicorn
        uvicorn.run(app, host=args.host, port=args.port)