File size: 3,815 Bytes
f2c113d c800712 f2c113d c800712 f2c113d c800712 f2c113d c800712 f2c113d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
"""
WebSocket endpoint for real-time agent step streaming.
The frontend connects here to see each agent step as it happens:
- Step started (with tool name)
- Step completed (with output summary)
- Step failed (with error)
- Final report ready
"""
from __future__ import annotations
import asyncio
import json
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from app.agent.orchestrator import Orchestrator
from app.models.schemas import CaseSubmission
from app.services.medgemma import MedGemmaService
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which this module's WebSocket route ("/agent") is registered.
router = APIRouter()
async def _send_best_effort(websocket: WebSocket, payload: dict) -> None:
    """Send *payload* as JSON, logging instead of raising if delivery fails.

    Used for messages that are purely informational (progress, error
    reports): if the client has already gone away there is nothing useful
    left to do, so the failure is recorded at debug level and swallowed.
    """
    try:
        await websocket.send_json(payload)
    except Exception:
        logger.debug(
            "Could not deliver %r message to WebSocket client",
            payload.get("type"),
            exc_info=True,
        )


@router.websocket("/agent")
async def agent_websocket(websocket: WebSocket):
    """
    WebSocket endpoint for real-time agent pipeline execution.

    Protocol:
        Client sends: one JSON object with patient case data
                      (CaseSubmission format)
        Server sends: JSON messages for each step update and final report

    Message types (in the order they may be emitted):
        - {"type": "ack", "message": "..."}
        - {"type": "warming_up", "message": "...", "elapsed_seconds": N}
        - {"type": "model_ready", "message": "..."}
        - {"type": "step_update", "step": {...}}
        - {"type": "report", "report": {...}}
        - {"type": "complete", "case_id": "..."}
        - {"type": "error", "message": "..."}

    The socket is always closed by the server when the handler finishes,
    whether the run succeeded, failed, or the client disconnected.
    """
    await websocket.accept()
    try:
        # Receive and validate the case submission.
        raw = await websocket.receive_text()
        data = json.loads(raw)
        if not isinstance(data, dict):
            # Valid JSON but not an object: report it clearly instead of
            # letting ``**data`` fail with an opaque TypeError.
            await _send_best_effort(websocket, {
                "type": "error",
                "message": "Case payload must be a JSON object",
            })
            return
        case = CaseSubmission(**data)

        # Send acknowledgment
        await websocket.send_json({
            "type": "ack",
            "message": "Case received. Checking model readiness...",
        })

        # ── Readiness gate: wait for MedGemma to be warm ──
        medgemma = MedGemmaService()

        async def _send_warming(elapsed: float, message: str):
            """Stream warm-up progress to client."""
            # Best-effort: the client may have disconnected while waiting.
            await _send_best_effort(websocket, {
                "type": "warming_up",
                "message": message,
                "elapsed_seconds": int(elapsed),
            })

        ready = await medgemma.wait_until_ready(on_waiting=_send_warming)
        if not ready:
            logger.warning("MedGemma was not ready in time; aborting agent run")
            await websocket.send_json({
                "type": "error",
                "message": (
                    "MedGemma model did not become ready within the timeout. "
                    "The endpoint may be starting up — please try again in a minute."
                ),
            })
            return

        await websocket.send_json({
            "type": "model_ready",
            "message": "MedGemma is ready. Starting agent pipeline...",
        })

        # Run the orchestrator and stream updates
        orchestrator = Orchestrator()
        async for step in orchestrator.run(case):
            await websocket.send_json({
                "type": "step_update",
                "step": step.model_dump(mode="json"),
            })

        # Send final report (skipped when the orchestrator produced none)
        report = orchestrator.get_result()
        if report:
            await websocket.send_json({
                "type": "report",
                "report": report.model_dump(mode="json"),
            })

        # Send completion
        await websocket.send_json({
            "type": "complete",
            "case_id": orchestrator.state.case_id if orchestrator.state else "unknown",
        })
    except WebSocketDisconnect:
        logger.info("WebSocket client disconnected before the agent run finished")
    except json.JSONDecodeError:
        # Guarded send: the client may already be gone, and an exception
        # raised inside this handler would escape the endpoint.
        await _send_best_effort(websocket, {
            "type": "error",
            "message": "Invalid JSON received",
        })
    except Exception as e:
        # Record the full traceback server-side; the client only gets str(e).
        logger.exception("Agent WebSocket pipeline failed")
        await _send_best_effort(websocket, {
            "type": "error",
            "message": str(e),
        })
    finally:
        try:
            await websocket.close()
        except Exception:
            # Closing an already-closed/disconnected socket raises; that is
            # expected here and not worth surfacing beyond debug level.
            logger.debug("WebSocket close failed (already closed?)", exc_info=True)
|