"""
Chat Routes - Patient-level chat with image analysis tools
"""
import asyncio
import json
import threading
from typing import Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse
from data.case_store import get_case_store
from backend.services.chat_service import get_chat_service
router = APIRouter()
@router.get("/{patient_id}/chat")
def get_chat_history(patient_id: str):
    """Return the patient-level chat history.

    Raises:
        HTTPException: 404 if no patient exists for ``patient_id``.
    """
    store = get_case_store()
    patient = store.get_patient(patient_id)
    if not patient:
        raise HTTPException(status_code=404, detail="Patient not found")
    return {"messages": store.get_patient_chat_history(patient_id)}
@router.delete("/{patient_id}/chat")
def clear_chat(patient_id: str):
    """Delete all patient-level chat history for ``patient_id``.

    Raises:
        HTTPException: 404 if no patient exists for ``patient_id``.
    """
    store = get_case_store()
    patient = store.get_patient(patient_id)
    if not patient:
        raise HTTPException(status_code=404, detail="Patient not found")
    store.clear_patient_chat_history(patient_id)
    return {"success": True}
@router.post("/{patient_id}/chat")
async def post_chat_message(
    patient_id: str,
    content: str = Form(""),
    image: Optional[UploadFile] = File(None),
):
    """Send a chat message, optionally with an image — SSE streaming response.

    The sync ML generator runs in a background thread so it never blocks the
    event loop. Events flow through an asyncio.Queue, so each SSE event is
    flushed to the browser the moment it is produced (spinner shows instantly).

    Args:
        patient_id: Patient whose chat the message belongs to.
        content: The text of the chat message (may be empty when only an
            image is sent).
        image: Optional uploaded image to include in the analysis.

    Raises:
        HTTPException: 404 if no patient exists for ``patient_id``.
    """
    store = get_case_store()
    if not store.get_patient(patient_id):
        raise HTTPException(status_code=404, detail="Patient not found")

    image_bytes = None
    if image and image.filename:
        image_bytes = await image.read()

    chat_service = get_chat_service()

    async def generate():
        # FIX: get_running_loop() is the correct call inside a coroutine.
        # get_event_loop() is deprecated for this use since Python 3.10 and
        # scheduled for removal; inside a running loop both return the same
        # loop, so behavior is unchanged.
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
        _SENTINEL = object()

        def run_sync():
            """Drain the sync generator in a worker thread, forwarding each
            event to the event loop via a thread-safe queue put."""
            try:
                for event in chat_service.stream_chat(patient_id, content, image_bytes):
                    loop.call_soon_threadsafe(queue.put_nowait, event)
            except Exception as e:
                # Surface worker failures to the client as an SSE error event
                # instead of letting them die silently in the thread.
                loop.call_soon_threadsafe(
                    queue.put_nowait,
                    {"type": "error", "message": str(e)},
                )
            finally:
                # Always unblock the consumer, even after a failure.
                loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL)

        # NOTE(review): daemon=True means the worker is abandoned (not
        # cancelled) if the client disconnects mid-stream — it keeps running
        # until stream_chat is exhausted. Confirm that is acceptable for the
        # workloads chat_service produces.
        thread = threading.Thread(target=run_sync, daemon=True)
        thread.start()

        while True:
            event = await queue.get()
            if event is _SENTINEL:
                break
            yield f"data: {json.dumps(event)}\n\n"
        # Explicit terminal event so the frontend knows the stream is complete.
        yield f"data: {json.dumps({'type': 'done'})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|