Spaces:
Running
Running
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| import os | |
| # Import your agent directly (same repo) | |
| from nivra_agent import nivra_chat | |
app = FastAPI(
    title="Nivra AI Healthcare Assistant API",
    description="🩺 India-first AI Healthcare Assistant with ClinicalBERT + Groq",
    version="1.0.0",
)

# CORS for the Flutter app.
# Allowed origins come from the comma-separated ALLOWED_ORIGINS environment
# variable (e.g. "https://app.example.com,https://admin.example.com"), so the
# API can be locked to the app's domain in production without a code change.
# When the variable is unset or blank, every origin is allowed, as before.
_allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    # The CORS spec forbids a wildcard origin together with credentials.
    # Starlette's CORSMiddleware works around that by echoing back whatever
    # Origin the browser sent, which would let any website make
    # credentialed requests. Only allow credentials for an explicit list.
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
class SymptomInput(BaseModel):
    """Request body for the text-based symptom diagnosis endpoint.

    Every field is optional; an empty request validates to no symptoms,
    language "en", and unknown age/gender.
    """

    # Free-text symptom descriptions; an empty list when omitted.
    symptoms: List[str] = list()
    # Language code, "en" by default (presumably ISO 639-1 — confirm with the client).
    language: str = "en"
    # Patient age in years; None when not supplied.
    age: Optional[int] = None
    # Patient gender as free text; None when not supplied.
    gender: Optional[str] = None
class DiagnosisResponse(BaseModel):
    """Response body returned by the text-based diagnosis endpoint.

    Only ``diagnosis`` is required; every other field has a default.
    """

    # Text produced by the Nivra agent.
    diagnosis: str
    # Confidence value; defaults to a fixed 0.85 rather than a computed score.
    confidence: float = 0.85
    # Follow-up advice shown to the user; empty by default.
    recommendations: str = ""
    # Urgency label, "low" by default. The type does not restrict the value;
    # the text endpoint uses "low", "medium" or "critical".
    urgency: str = "low"
    # Optional link for spoken output; None when absent.
    audio_url: Optional[str] = None
    # Success flag; True by default.
    success: bool = True
# Keyword lists for the urgency heuristic, checked in priority order.
_CRITICAL_KEYWORDS = ("critical", "emergency", "severe")
_MEDIUM_KEYWORDS = ("consult doctor", "see specialist")


def _classify_urgency(diagnosis: str) -> str:
    """Return "critical", "medium" or "low" from keywords in *diagnosis*.

    This is plain case-insensitive substring matching, so a phrase such as
    "not severe" still counts as critical. Treat it as a coarse heuristic.
    """
    text = diagnosis.lower()
    if any(word in text for word in _CRITICAL_KEYWORDS):
        return "critical"
    if any(word in text for word in _MEDIUM_KEYWORDS):
        return "medium"
    return "low"


def _build_text_prompt(
    age: Optional[int], gender: Optional[str], symptoms: List[str]
) -> str:
    """Render patient details and symptoms into the prompt sent to the agent."""
    # Compare against None explicitly. An age of 0 (an infant) is valid and
    # must not be reported as "unknown".
    age_text = age if age is not None else "unknown"
    return "Patient age: {} {}, symptoms: {}".format(
        age_text,
        gender or "unknown",
        ", ".join(symptoms),
    )


# NOTE(review): no route decorator is visible in this copy of the file,
# although root() advertises "/diagnose/text". Confirm this handler is
# registered on `app`.
async def diagnose_text_symptoms(input: SymptomInput):
    """
    Main App endpoint - Text-based symptom diagnosis.

    Builds a prompt from the patient's age, gender and symptoms, and passes it
    to the Nivra agent (``nivra_chat``). The reply is wrapped in a
    ``DiagnosisResponse`` whose urgency comes from keyword matching.

    Args:
        input: Validated request body.

    Returns:
        DiagnosisResponse with a fixed confidence of 0.85 and a fixed
        recommendation string.

    Raises:
        HTTPException: 500 with "Diagnosis failed: <error>" if anything fails.
    """
    # Imported here so this block needs no change to the file's import list.
    from fastapi.concurrency import run_in_threadpool

    try:
        prompt = _build_text_prompt(input.age, input.gender, input.symptoms)
        # nivra_chat is synchronous (its result is used directly, not awaited).
        # Calling it inline would block the event loop and stall every other
        # request, including the health check, until the model answers.
        # Assumes nivra_chat is safe to call from worker threads. TODO: confirm
        # if the agent keeps shared conversation state.
        diagnosis = await run_in_threadpool(nivra_chat, prompt)
        return DiagnosisResponse(
            diagnosis=diagnosis,
            confidence=0.85,
            recommendations="Follow the guidance above. Consult a doctor if symptoms worsen.",
            urgency=_classify_urgency(diagnosis),
            audio_url=f"https://huggingface.co/spaces/nivra/tts/{input.language}",  # TTS endpoint
            success=True,
        )
    except Exception as e:
        # Keep the existing error contract (HTTP 500 with the message). Chain
        # the cause so the original traceback is preserved on the exception.
        raise HTTPException(
            status_code=500,
            detail=f"Diagnosis failed: {str(e)}",
        ) from e
# NOTE(review): no route decorator is visible in this copy of the file,
# although root() advertises "/diagnose/image". Confirm this handler is
# registered on `app`.
async def diagnose_image_symptoms(
    file: UploadFile = File(...),
    age: Optional[int] = None,
    gender: Optional[str] = None,
):
    """
    Image-based diagnosis endpoint.

    Stores the upload in a private temporary file and passes its path, plus
    optional patient details, to the Nivra agent (``nivra_chat``). The file is
    deleted once the agent has answered. The agent presumably loads the image
    from that path via image_symptom_tool.py (TODO: confirm).

    Args:
        file: Uploaded image.
        age: Optional patient age in years (0 is accepted).
        gender: Optional patient gender as free text.

    Returns:
        dict with "diagnosis" (agent output), "type" ("image_analysis") and
        "success" (True).

    Raises:
        HTTPException: 500 with the error message if anything fails.
    """
    # Imported here so this block needs no change to the file's import list.
    import tempfile

    from fastapi.concurrency import run_in_threadpool

    image_path = None
    try:
        contents = await file.read()
        # Never build the path from the client-supplied filename: a name like
        # "../../app/main.py" would escape the temp dir, and two uploads named
        # "photo.jpg" would overwrite each other. Keep only the extension so
        # anything that inspects the suffix still sees e.g. ".jpg".
        suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
        fd, image_path = tempfile.mkstemp(suffix=suffix)
        with os.fdopen(fd, "wb") as f:
            f.write(contents)

        prompt = f"Patient image analysis: {image_path}"
        details = []
        if age is not None:  # 0 is a valid age; None means "not given"
            details.append(f"{age}yo")
        if gender:
            details.append(gender)
        if details:
            prompt += "\nPatient: " + " ".join(details)

        # nivra_chat is synchronous; run it off the event loop so other
        # requests keep being served. Assumes it is safe to call from worker
        # threads. TODO: confirm.
        diagnosis = await run_in_threadpool(nivra_chat, prompt)
        return {
            "diagnosis": diagnosis,
            "type": "image_analysis",
            "success": True,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The image is only needed while the agent runs. Remove it so patient
        # photos do not accumulate on disk. This is best effort: a failed
        # cleanup must not replace the real response or error.
        if image_path is not None:
            try:
                os.remove(image_path)
            except OSError:
                pass
async def root():
    """Return basic API metadata: a banner message, the version and key endpoint paths."""
    endpoint_paths = {
        "text_diagnosis": "/diagnose/text",
        "image_diagnosis": "/diagnose/image",
        "health_check": "/health",
        "docs": "/docs",
    }
    info = {"message": "🩺 Nivra AI Healthcare API", "version": "1.0.0"}
    info["endpoints"] = endpoint_paths
    return info
async def health_check():
    """Liveness payload for monitoring.

    Every value is a constant. The handler does not probe the agent or the
    listed models, so "healthy" only means the app is serving requests.
    """
    model_names = ["ClinicalBERT", "Groq LLM", "Indic Parler-TTS"]
    payload = dict(status="healthy", agent="nivra_chat loaded")
    payload["models"] = model_names
    return payload
# Environment info (useful for debugging on HF Spaces)
async def system_info():
    """Report Hugging Face Space identifiers read from environment variables.

    Unset variables fall back to "unknown", except the host, which falls
    back to "localhost".
    """
    lookups = (
        ("space_author", "SPACE_AUTHOR_NAME", "unknown"),
        ("space_repo", "SPACE_REPO_NAME", "unknown"),
        ("space_id", "SPACE_ID", "unknown"),
        ("host", "SPACE_HOST", "localhost"),
    )
    return {key: os.environ.get(var, fallback) for key, var, fallback in lookups}