Spaces:
Runtime error
Runtime error
import asyncio
import base64
import io
import re
from datetime import datetime
from typing import List, Optional

from bson import ObjectId
from bson.errors import InvalidId
from docx import Document
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from analysis import analyze_patient_report
from auth import get_current_user
from utils import clean_text_response
from voice import extract_text_from_pdf, recognize_speech, text_to_speech
# Request body for the streaming chat endpoint; patient_id is optional.
class ChatRequest(BaseModel):
    """Request body for the streaming chat endpoint.

    ``temperature`` and ``max_new_tokens`` use the same bounds as the
    voice-chat endpoint's query parameters (0.1-1.0 and 50-1024). This stops
    a client from requesting an unbounded generation length or a
    non-positive sampling temperature. ``None`` is still accepted for both.
    """

    message: str
    # Prior turns, appended verbatim to the model conversation. Presumably
    # {"role": ..., "content": ...} dicts — TODO confirm against the client.
    history: Optional[List[dict]] = None
    # Not read by any handler visible in this file; kept for API compatibility.
    format: Optional[str] = "clean"
    temperature: Optional[float] = Field(0.7, ge=0.1, le=1.0)
    max_new_tokens: Optional[int] = Field(512, ge=50, le=1024)
    # Optional patient link stored with the chat log (None when omitted).
    patient_id: Optional[str] = None
class VoiceOutputRequest(BaseModel):
    """Request body for text-to-speech synthesis.

    ``return_format`` equal to ``"base64"`` selects a JSON response. Any
    other value makes the synthesis handler in this file stream raw MP3
    bytes.
    """

    # Text to be spoken.
    text: str
    # Language/locale tag handed to text_to_speech.
    language: str = "en-US"
    # Passed through to text_to_speech as its third argument.
    slow: bool = False
    # "base64" -> JSON {"audio": ...}; anything else -> MP3 stream.
    return_format: str = "mp3"
class RiskLevel(BaseModel):
    """A string ``level`` label, a float ``score`` and optional ``factors``.

    Not referenced by any handler visible in this file.
    """

    level: str
    score: float
    # Optional list of contributing factors; omitted -> None.
    factors: Optional[List[str]] = None
def create_router(agent, logger, patients_collection, analysis_collection, users_collection):
    """Build an APIRouter whose handler coroutines close over the given dependencies.

    Args:
        agent: Model wrapper. Handlers read ``chat_prompt``, ``tokenizer``,
            ``device`` and ``model`` from it and call ``agent.chat(...)``.
        logger: Logger used for audit and error messages.
        patients_collection: Awaitable collection of patient documents keyed by
            ``fhir_id`` (Motor-like API assumed from the awaited calls).
        analysis_collection: Awaitable collection holding report analyses and
            chat logs (chat logs carry ``chat_type``).
        users_collection: Accepted for interface compatibility; unused here.

    Returns:
        The ``APIRouter`` instance.

    NOTE(review): no ``@router.<method>(...)`` decorators are visible in this
    source, so as shown none of the handlers below are registered on
    ``router``. If they were lost, restore them.
    NOTE(review): model generation, speech recognition/synthesis and PDF
    extraction are called synchronously inside ``async`` handlers and will
    block the event loop while they run. Consider offloading them to a
    thread pool.
    """
    router = APIRouter()

    async def status(current_user: dict = Depends(get_current_user)):
        """Report service status, UTC timestamp, version and feature list."""
        logger.info(f"Status endpoint accessed by {current_user['email']}")
        return {
            "status": "running",
            "timestamp": datetime.utcnow().isoformat(),
            "version": "2.6.0",
            "features": ["chat", "voice-input", "voice-output", "patient-analysis", "report-upload"],
        }

    async def get_patient_analysis_results(
        name: Optional[str] = Query(None),
        current_user: dict = Depends(get_current_user),
    ):
        """Return up to 100 newest analyses, each enriched with ``full_name``.

        ``name`` is matched as a case-insensitive literal substring of the
        patient's ``full_name``. Analyses whose patient cannot be found,
        including those with no ``patient_id``, are skipped.

        Raises:
            HTTPException(500): on any unexpected database error.
        """
        logger.info(f"Fetching analysis results by {current_user['email']}")
        try:
            query = {}
            if name:
                # Escape user input: a raw pattern allows catastrophic
                # backtracking and crashes on invalid regexes such as "(".
                name_regex = re.compile(re.escape(name), re.IGNORECASE)
                matching_patients = await patients_collection.find({"full_name": name_regex}).to_list(length=None)
                patient_ids = [p["fhir_id"] for p in matching_patients if "fhir_id" in p]
                if not patient_ids:
                    return []
                query = {"patient_id": {"$in": patient_ids}}
            analyses = await analysis_collection.find(query).sort("timestamp", -1).to_list(length=100)

            # Resolve all referenced patients in one query instead of one
            # find_one per analysis. None ids are excluded, so such entries
            # are skipped below rather than matched against {"fhir_id": None}.
            referenced_ids = list({a.get("patient_id") for a in analyses} - {None})
            names_by_id = {}
            if referenced_ids:
                patients = await patients_collection.find({"fhir_id": {"$in": referenced_ids}}).to_list(length=None)
                names_by_id = {p.get("fhir_id"): p.get("full_name", "Unknown") for p in patients}

            enriched_results = []
            for analysis in analyses:
                pid = analysis.get("patient_id")
                if pid not in names_by_id:
                    continue  # Skip if patient no longer exists
                analysis["full_name"] = names_by_id[pid]
                analysis["_id"] = str(analysis["_id"])
                enriched_results.append(analysis)
            return enriched_results
        except Exception as e:
            logger.error(f"Error fetching analysis results: {e}")
            raise HTTPException(status_code=500, detail="Failed to retrieve analysis results")

    async def chat_stream_endpoint(
        request: ChatRequest,
        current_user: dict = Depends(get_current_user),
    ):
        """Generate a full reply, log it, then stream it word by word as plain text.

        Errors raised inside the stream are logged and emitted to the client
        as a final ``⚠️ Error: ...`` chunk, because headers are already sent.
        """
        logger.info(f"Chat stream initiated by {current_user['email']}")

        async def token_stream():
            try:
                conversation = [{"role": "system", "content": agent.chat_prompt}]
                if request.history:
                    # NOTE(review): client-supplied turns are passed through
                    # unchanged, including any "system"-role entries.
                    conversation.extend(request.history)
                conversation.append({"role": "user", "content": request.message})
                input_ids = agent.tokenizer.apply_chat_template(
                    conversation, add_generation_prompt=True, return_tensors="pt"
                ).to(agent.device)
                output = agent.model.generate(
                    input_ids,
                    do_sample=True,
                    temperature=request.temperature,
                    max_new_tokens=request.max_new_tokens,
                    pad_token_id=agent.tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                )
                # Decode only the newly generated tokens (skip the prompt).
                text = agent.tokenizer.decode(output["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
                cleaned_text = clean_text_response(text)

                # Store the chat session before streaming.
                chat_entry = {
                    "user_id": current_user["email"],
                    "patient_id": request.patient_id,  # None when not supplied
                    "message": request.message,
                    "response": cleaned_text,
                    "chat_type": "chat",
                    "timestamp": datetime.utcnow(),
                    "temperature": request.temperature,
                    "max_new_tokens": request.max_new_tokens,
                }
                result = await analysis_collection.insert_one(chat_entry)

                streamed_chunks = []
                for word in cleaned_text.split():
                    chunk = word + " "
                    streamed_chunks.append(chunk)
                    yield chunk
                    await asyncio.sleep(0.05)

                # Replace the stored response with exactly what was streamed
                # (whitespace-collapsed). If the generator is closed early this
                # does not run, and the full cleaned_text inserted above remains.
                await analysis_collection.update_one(
                    {"_id": result.inserted_id},
                    {"$set": {"response": "".join(streamed_chunks)}},
                )
            except Exception as e:
                logger.error(f"Streaming error: {e}")
                yield f"⚠️ Error: {e}"

        return StreamingResponse(token_stream(), media_type="text/plain")

    async def get_chats(current_user: dict = Depends(get_current_user)):
        """Return the caller's 100 newest text chats (newest first).

        Raises:
            HTTPException(500): on any unexpected database error.
        """
        logger.info(f"Fetching chats for {current_user['email']}")
        try:
            chats = await analysis_collection.find(
                {"user_id": current_user["email"], "chat_type": "chat"}
            ).sort("timestamp", -1).to_list(length=100)
            return [
                {
                    "id": str(chat["_id"]),
                    "title": chat.get("message", "Untitled Chat")[:30],  # First 30 chars of message as title
                    "timestamp": chat["timestamp"].isoformat(),
                    "message": chat["message"],
                    "response": chat["response"],
                }
                for chat in chats
            ]
        except Exception as e:
            logger.error(f"Error fetching chats: {e}")
            raise HTTPException(status_code=500, detail="Failed to retrieve chats")

    async def transcribe_voice(
        audio: UploadFile = File(...),
        language: str = Query("en-US", description="Language code for speech recognition"),
        current_user: dict = Depends(get_current_user),
    ):
        """Transcribe an uploaded audio file and return ``{"text": ...}``.

        Raises:
            HTTPException(400): if the filename is missing or lacks a supported
                audio extension.
            HTTPException(500): on any recognition failure.
        """
        logger.info(f"Voice transcription initiated by {current_user['email']}")
        try:
            # Validate before reading so unsupported uploads are not buffered.
            # A missing filename is treated as unsupported instead of crashing.
            filename = (audio.filename or "").lower()
            if not filename.endswith(('.wav', '.mp3', '.ogg', '.flac')):
                raise HTTPException(status_code=400, detail="Unsupported audio format")
            audio_data = await audio.read()
            text = recognize_speech(audio_data, language)
            return {"text": text}
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice transcription: {e}")
            raise HTTPException(status_code=500, detail="Error processing voice input")

    async def synthesize_voice(
        request: VoiceOutputRequest,
        current_user: dict = Depends(get_current_user),
    ):
        """Synthesize speech.

        Returns base64 JSON when ``return_format == "base64"``; otherwise an
        MP3 attachment stream.

        Raises:
            HTTPException(500): on any synthesis failure.
        """
        logger.info(f"Voice synthesis initiated by {current_user['email']}")
        try:
            audio_data = text_to_speech(request.text, request.language, request.slow)
            if request.return_format == "base64":
                return {"audio": base64.b64encode(audio_data).decode('utf-8')}
            return StreamingResponse(
                io.BytesIO(audio_data),
                media_type="audio/mpeg",
                headers={"Content-Disposition": "attachment; filename=speech.mp3"},
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice synthesis: {e}")
            raise HTTPException(status_code=500, detail="Error generating voice output")

    async def voice_chat_endpoint(
        audio: UploadFile = File(...),
        language: str = Query("en-US", description="Language code for speech recognition"),
        temperature: float = Query(0.7, ge=0.1, le=1.0),
        max_new_tokens: int = Query(512, ge=50, le=1024),
        current_user: dict = Depends(get_current_user),
    ):
        """Transcribe speech, answer it with the agent, log the exchange, and return MP3 audio.

        Raises:
            HTTPException(500): on any failure in recognition, chat, synthesis
                or storage.
        """
        logger.info(f"Voice chat initiated by {current_user['email']}")
        try:
            input_audio = await audio.read()
            user_message = recognize_speech(input_audio, language)
            chat_response = agent.chat(
                message=user_message,
                history=[],
                temperature=temperature,
                max_new_tokens=max_new_tokens,
            )
            # NOTE(review): only the primary subtag ("en") is passed here,
            # whereas synthesize_voice passes the full tag ("en-US"). Confirm
            # which form text_to_speech expects.
            response_audio = text_to_speech(chat_response, language.split('-')[0])

            # Store the voice chat in the database.
            chat_entry = {
                "user_id": current_user["email"],
                "patient_id": None,
                "message": user_message,
                "response": chat_response,
                "chat_type": "voice_chat",
                "timestamp": datetime.utcnow(),
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
            }
            await analysis_collection.insert_one(chat_entry)

            return StreamingResponse(
                io.BytesIO(response_audio),
                media_type="audio/mpeg",
                headers={"Content-Disposition": "attachment; filename=response.mp3"},
            )
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in voice chat: {e}")
            raise HTTPException(status_code=500, detail="Error processing voice chat")

    async def analyze_clinical_report(
        file: UploadFile = File(...),
        patient_id: Optional[str] = Form(None),
        temperature: float = Form(0.5),
        max_new_tokens: int = Form(1024),
        current_user: dict = Depends(get_current_user),
    ):
        """Extract text from a PDF/TXT/DOCX upload and run ``analyze_patient_report`` on it.

        ``temperature`` and ``max_new_tokens`` are accepted but not forwarded
        by the visible code.

        Raises:
            HTTPException(400): unsupported content type, non-UTF-8 text file,
                or extracted text shorter than 50 characters.
            HTTPException(500): any other failure; the message includes the
                error text.
        """
        logger.info(f"Report analysis initiated by {current_user['email']}")
        try:
            content_type = file.content_type
            allowed_types = [
                'application/pdf',
                'text/plain',
                'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
            ]
            if content_type not in allowed_types:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported file type: {content_type}. Supported types: PDF, TXT, DOCX",
                )
            file_content = await file.read()
            if content_type == 'application/pdf':
                text = extract_text_from_pdf(file_content)
            elif content_type == 'text/plain':
                try:
                    text = file_content.decode('utf-8')
                except UnicodeDecodeError:
                    # Client error, not a server fault.
                    raise HTTPException(status_code=400, detail="Text file is not valid UTF-8") from None
            elif content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
                doc = Document(io.BytesIO(file_content))
                text = "\n".join([para.text for para in doc.paragraphs])
            else:
                raise HTTPException(status_code=400, detail="Unsupported file type")
            text = clean_text_response(text)
            if len(text.strip()) < 50:
                raise HTTPException(
                    status_code=400,
                    detail="Extracted text is too short (minimum 50 characters required)",
                )
            analysis = await analyze_patient_report(
                patient_id=patient_id,
                report_content=text,
                file_type=content_type,
                file_content=file_content,
            )
            # Make the result JSON-safe.
            if "_id" in analysis and isinstance(analysis["_id"], ObjectId):
                analysis["_id"] = str(analysis["_id"])
            if "timestamp" in analysis and isinstance(analysis["timestamp"], datetime):
                analysis["timestamp"] = analysis["timestamp"].isoformat()
            return JSONResponse(content=jsonable_encoder({
                "status": "success",
                "analysis": analysis,
                "patient_id": patient_id,
                "file_type": content_type,
                "file_size": len(file_content),
            }))
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error in report analysis: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to analyze report: {str(e)}",
            )

    async def delete_patient(
        patient_id: str,
        current_user: dict = Depends(get_current_user),
    ):
        """Delete a patient and all analysis/chat documents with that ``patient_id``.

        Only the patient's ``created_by`` user or an admin may delete.

        Raises:
            HTTPException(404): patient not found.
            HTTPException(403): caller is neither creator nor admin.
            HTTPException(500): any other failure.
        """
        logger.info(f"Patient deletion initiated by {current_user['email']} for patient {patient_id}")
        try:
            patient = await patients_collection.find_one({"fhir_id": patient_id})
            if not patient:
                raise HTTPException(status_code=404, detail="Patient not found")
            if patient.get("created_by") != current_user["email"] and not current_user.get("is_admin", False):
                raise HTTPException(status_code=403, detail="Not authorized to delete this patient")
            # Remove dependent analyses and chats first, then the patient.
            await analysis_collection.delete_many({"patient_id": patient_id})
            logger.info(f"Deleted analyses and chats for patient {patient_id}")
            await patients_collection.delete_one({"fhir_id": patient_id})
            logger.info(f"Patient {patient_id} deleted successfully")
            return {"status": "success", "message": f"Patient {patient_id} and associated analyses/chats deleted"}
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Error deleting patient {patient_id}: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to delete patient: {str(e)}")

    return router