"""
MediGuard AI — Ask Router
Free-form medical Q&A powered by the agentic RAG pipeline.
Supports both synchronous and SSE streaming responses.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
import uuid
from collections.abc import AsyncGenerator
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from src.schemas.schemas import AskRequest, AskResponse, FeedbackRequest, FeedbackResponse
logger = logging.getLogger(__name__)
router = APIRouter(tags=["ask"])
@router.post("/ask", response_model=AskResponse)
async def ask_medical_question(body: AskRequest, request: Request):
    """Answer a free-form medical question via agentic RAG.

    The RAG service is read from ``request.app.state.rag_service``. Its
    ``ask`` method is synchronous, so it is run in the event loop's default
    executor; calling it inline would stall every other request served by
    this worker until the pipeline finished.

    Args:
        body: Request payload; ``question``, ``biomarkers`` and
            ``patient_context`` are forwarded to the RAG service.
        request: Incoming request, used only to reach application state.

    Returns:
        An ``AskResponse`` with the answer, guardrail score, document counts
        and the processing time in milliseconds.

    Raises:
        HTTPException: 503 when no RAG service is configured, 500 when the
            pipeline raises.
    """
    rag_service = getattr(request.app.state, "rag_service", None)
    if rag_service is None:
        raise HTTPException(status_code=503, detail="RAG service unavailable")

    request_id = f"req_{uuid.uuid4().hex[:12]}"
    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump if the system clock is adjusted mid-request.
    started = time.perf_counter()
    try:
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: rag_service.ask(
                query=body.question,
                biomarkers=body.biomarkers,
                patient_context=body.patient_context or "",
            ),
        )
    except Exception as exc:
        logger.exception("Agentic RAG failed: %s", exc)
        raise HTTPException(status_code=500, detail=f"RAG pipeline error: {exc}") from exc
    elapsed_ms = (time.perf_counter() - started) * 1000

    # `or` fallbacks cover keys that are present but explicitly None, which a
    # plain .get(key, default) would pass straight through to len().
    return AskResponse(
        status="success",
        request_id=request_id,
        question=body.question,
        answer=result.get("final_answer") or "",
        guardrail_score=result.get("guardrail_score"),
        documents_retrieved=len(result.get("retrieved_documents") or []),
        documents_relevant=len(result.get("relevant_documents") or []),
        processing_time_ms=round(elapsed_ms, 1),
    )
# ---------------------------------------------------------------------------
# SSE Streaming Endpoint
# ---------------------------------------------------------------------------
def _sse_event(event: str, payload: dict) -> str:
    """Serialize one Server-Sent Event frame whose data line is JSON."""
    return f"event: {event}\ndata: {json.dumps(payload)}\n\n"


def _iter_answer_chunks(answer: str, words_per_chunk: int = 3):
    """Yield slices of *answer* that concatenate back to it exactly.

    Splitting on a single space (rather than on any whitespace) keeps
    newlines, tabs and repeated spaces inside the pieces, so markdown
    paragraphs and bullet lists survive the simulated streaming.
    """
    pieces = answer.split(" ")
    for start in range(0, len(pieces), words_per_chunk):
        chunk = " ".join(pieces[start : start + words_per_chunk])
        if start + words_per_chunk < len(pieces):
            # Restore the separator consumed by split() between chunks.
            chunk += " "
        yield chunk


async def _stream_rag_response(
    rag_service,
    question: str,
    biomarkers: dict | None,
    patient_context: str,
    request_id: str,
) -> AsyncGenerator[str, None]:
    """Generate Server-Sent Events for a streamed RAG response.

    The pipeline itself is not incremental: ``rag_service.ask`` runs to
    completion in the default executor, and the finished answer is then
    emitted in small chunks to give the client a streaming effect. The
    concatenation of all ``token`` texts equals the original answer.

    Event types:
    - status: Pipeline stage updates (``guardrail``, ``generating``)
    - metadata: Document counts and guardrail score
    - token: A chunk of the answer text
    - done: Final completion signal with request id and timing
    - error: Error information (emitted instead of ``done`` on failure)

    Args:
        rag_service: Object exposing a synchronous
            ``ask(query=..., biomarkers=..., patient_context=...)`` method
            that returns a dict-like result.
        question: The user's question.
        biomarkers: Optional biomarker mapping forwarded to the service.
        patient_context: Free-text context forwarded to the service.
        request_id: Identifier echoed in the ``done`` and ``error`` events.

    Yields:
        Fully formatted SSE frames, each terminated by a blank line.
    """
    # Monotonic clock: immune to wall-clock adjustments during the request.
    started = time.perf_counter()
    try:
        yield _sse_event("status", {"stage": "guardrail", "message": "Validating query..."})
        await asyncio.sleep(0)  # Allow event loop to flush

        # The service call is blocking; keep it off the event loop thread.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: rag_service.ask(
                query=question,
                biomarkers=biomarkers,
                patient_context=patient_context,
            ),
        )

        # `or []` guards against keys that are present but set to None.
        yield _sse_event(
            "metadata",
            {
                "documents_retrieved": len(result.get("retrieved_documents") or []),
                "documents_relevant": len(result.get("relevant_documents") or []),
                "guardrail_score": result.get("guardrail_score"),
            },
        )
        await asyncio.sleep(0)

        answer = result.get("final_answer") or ""
        if answer:
            yield _sse_event("status", {"stage": "generating", "message": "Generating response..."})
            for chunk in _iter_answer_chunks(answer):
                yield _sse_event("token", {"text": chunk})
                await asyncio.sleep(0.02)  # Small delay for visual streaming effect

        elapsed_ms = (time.perf_counter() - started) * 1000
        yield _sse_event(
            "done",
            {
                "request_id": request_id,
                "processing_time_ms": round(elapsed_ms, 1),
                "status": "success",
            },
        )
    except Exception as exc:
        # Headers are already sent once streaming starts, so failures are
        # reported in-band as an `error` event rather than an HTTP status.
        logger.exception("Streaming RAG failed: %s", exc)
        yield _sse_event("error", {"error": str(exc), "request_id": request_id})
@router.post("/ask/stream")
async def ask_medical_question_stream(body: AskRequest, request: Request):
    """
    Stream a medical Q&A response via Server-Sent Events (SSE).

    Events:
    - `status`: Pipeline stage updates (`guardrail`, `generating`)
    - `metadata`: Retrieval statistics (`documents_retrieved`,
      `documents_relevant`, `guardrail_score`)
    - `token`: A chunk of the answer text for incremental display
    - `done`: Completion signal with `request_id` and timing info
    - `error`: Error details if something fails

    The browser `EventSource` API only issues GET requests and cannot send a
    body, so this POST endpoint has to be consumed with `fetch` and a stream
    reader.

    Example client code (JavaScript):
    ```javascript
    const response = await fetch('/ask/stream', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ question: 'What causes high glucose?' })
    });
    const reader = response.body
      .pipeThrough(new TextDecoderStream())
      .getReader();
    for (;;) {
      const { value, done } = await reader.read();
      if (done) break;
      // `value` contains one or more SSE frames ("event: ..." line,
      // "data: {...}" line, blank line). Parse each frame and append the
      // text of `token` events with `textContent`, not `innerHTML`.
    }
    ```

    Raises a 503 `HTTPException` when no RAG service is configured on the
    application state. The generated request id is also returned in the
    `X-Request-ID` response header.
    """
    rag_service = getattr(request.app.state, "rag_service", None)
    if rag_service is None:
        raise HTTPException(status_code=503, detail="RAG service unavailable")

    request_id = f"req_{uuid.uuid4().hex[:12]}"
    event_stream = _stream_rag_response(
        rag_service,
        body.question,
        body.biomarkers,
        body.patient_context or "",
        request_id,
    )
    return StreamingResponse(
        event_stream,
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Request-ID": request_id,
            # Ask nginx-style reverse proxies not to buffer the response;
            # otherwise events can be held back and delivered in one batch.
            "X-Accel-Buffering": "no",
        },
    )
@router.post("/feedback", response_model=FeedbackResponse)
async def submit_feedback(body: FeedbackRequest, request: Request):
    """Record user feedback for an analysis or RAG response.

    Looks up an optional ``tracer`` on ``request.app.state``. When one is
    present and truthy, the feedback is forwarded through ``tracer.score``
    under the name ``user-feedback``, keyed by the submitted request id.
    When no tracer is available the feedback is not recorded anywhere, but
    the call still succeeds.

    Returns:
        A ``FeedbackResponse`` echoing the submitted ``request_id``.
    """
    feedback_sink = getattr(request.app.state, "tracer", None)
    target_id = body.request_id
    if feedback_sink:
        feedback_sink.score(
            trace_id=target_id,
            name="user-feedback",
            value=body.score,
            comment=body.comment,
        )
    return FeedbackResponse(request_id=target_id)
|