|
|
""" |
|
|
REST API for case submission and result retrieval. |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import asyncio
import logging
import uuid
from typing import Dict
|
|
|
|
|
from fastapi import APIRouter, HTTPException |
|
|
|
|
|
from app.agent.orchestrator import Orchestrator |
|
|
from app.models.schemas import ( |
|
|
AgentState, |
|
|
CaseResponse, |
|
|
CaseResult, |
|
|
CaseSubmission, |
|
|
) |
|
|
|
|
|
# Router for the case endpoints defined below; presumably mounted by the app
# under a prefix such as /api/cases (see the submit_case docstring) — confirm.
router = APIRouter()

# In-memory registry of submitted cases: case ID -> the Orchestrator running
# (or having run) that case. Entries live only for the lifetime of this
# process; nothing in this module ever removes them.
_cases: dict[str, Orchestrator] = {}
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# The event loop keeps only weak references to tasks, so a fire-and-forget
# task created with asyncio.create_task() can be garbage-collected before it
# finishes. Strong references are held here until each pipeline completes.
_pipeline_tasks: set[asyncio.Task] = set()

# How long submit_case waits for the orchestrator to publish its state (and
# with it the orchestrator's own case ID) before falling back to a locally
# generated ID, and how often it re-checks while waiting.
_STATE_WAIT_TIMEOUT_S = 1.0
_STATE_POLL_INTERVAL_S = 0.01


def _on_pipeline_done(task: asyncio.Task) -> None:
    """Release the strong reference to *task* and log any exception it raised.

    Retrieving the exception here also prevents the event loop's
    "Task exception was never retrieved" warning for failed pipelines.
    """
    _pipeline_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Agent pipeline failed", exc_info=exc)


@router.post("/submit", response_model=CaseResponse)
async def submit_case(case: CaseSubmission):
    """
    Submit a patient case for analysis.

    The agent pipeline runs asynchronously. Use the WebSocket endpoint
    or poll /api/cases/{case_id} for real-time updates.

    The returned case ID is the orchestrator's own ID when its state becomes
    available within a short timeout; otherwise a locally generated fallback
    ID is returned. Either way the case is registered under the returned ID.
    """
    orchestrator = Orchestrator()

    # Fallback ID, used only if the orchestrator has not created its state
    # by the time we answer the request.
    case_id = str(uuid.uuid4())[:8]

    async def _run_pipeline() -> None:
        try:
            # Drain the step stream; progress is observed via orchestrator.state.
            async for _step in orchestrator.run(case):
                pass
        finally:
            # Register under the orchestrator's own ID even when the pipeline
            # raised, so partial state remains retrievable.
            if orchestrator.state:
                _cases[orchestrator.state.case_id] = orchestrator

    task = asyncio.create_task(_run_pipeline())
    _pipeline_tasks.add(task)
    task.add_done_callback(_on_pipeline_done)

    # Wait until the orchestrator publishes its state, the pipeline ends
    # (e.g. it failed early), or the timeout elapses. This replaces a fixed
    # sleep that raced against slower orchestrator start-up.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + _STATE_WAIT_TIMEOUT_S
    while not orchestrator.state and not task.done() and loop.time() < deadline:
        await asyncio.sleep(_STATE_POLL_INTERVAL_S)

    actual_id = orchestrator.state.case_id if orchestrator.state else case_id
    _cases[actual_id] = orchestrator

    return CaseResponse(
        case_id=actual_id,
        status="running",
        message="Agent pipeline started. Connect to WebSocket for real-time updates.",
    )
|
|
|
|
|
|
|
|
@router.get("/{case_id}", response_model=CaseResult)
async def get_case(case_id: str):
    """Return the current agent state for a case plus its report.

    The report is whatever the case's orchestrator returns from get_result().

    Raises:
        HTTPException: 404 when no orchestrator is registered under
            *case_id*, or when it has not produced any state yet.
    """
    orch = _cases.get(case_id)
    state = orch.state if orch else None
    if not state:
        raise HTTPException(status_code=404, detail=f"Case {case_id} not found")

    report = orch.get_result()
    return CaseResult(case_id=case_id, state=state, report=report)
|
|
|
|
|
|
|
|
@router.get("/", response_model=list[str])
async def list_cases():
    """Return every case ID currently registered, in insertion order.

    Note: one submission can be registered under two IDs (a fallback ID and
    the orchestrator's own ID), in which case both appear here.
    """
    return [*_cases]
|
|
|