Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import json | |
| import base64 | |
| import logging | |
| from dotenv import load_dotenv # type: ignore | |
| import httpx # type: ignore | |
| from fastapi import FastAPI, File, Form, UploadFile, HTTPException, Depends # type: ignore | |
| from fastapi.middleware.cors import CORSMiddleware # type: ignore | |
| from fastapi.responses import StreamingResponse # type: ignore | |
| from pydantic import BaseModel, EmailStr # type: ignore | |
| from PIL import Image # type: ignore | |
| from sqlalchemy.orm import Session | |
| from database import init_db, get_db, User, hash_password, verify_password | |
| from earcp_orchestrator import get_ensemble | |
| load_dotenv() | |
# -- Logging -------------------------------------------------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(levelname)s: %(message)s")
logger = logging.getLogger("hemo")

# -- Config --------------------------------------------------------------------
# Hugging Face API token; empty string means unauthenticated requests.
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Chat/LLM model -- served via an HF inference provider (supports open LLMs).
MEDGEMMA_MODEL = os.getenv("HF_MEDGEMMA_MODEL", "meta-llama/Llama-3.2-3B-Instruct")
# ASR model used by call_whisper.
WHISPER_MODEL = os.getenv("HF_WHISPER_MODEL", "openai/whisper-large-v3")
# Vision model used by call_llava_description.
LLAVA_MODEL = os.getenv("HF_LLAVA_MODEL", "llava-hf/llava-1.5-7b-hf")
# V1 router: OpenAI-compatible conversational (chat-completions) endpoint.
HF_CHAT_URL = "https://router.huggingface.co/v1/chat/completions"
# hf-inference route: used for ASR (Whisper) and other CPU-capable tasks.
HF_ROUTER_BASE = "https://router.huggingface.co/hf-inference/models"
# Direct inference API (fallback / vision).
HF_INFER_BASE = "https://api-inference.huggingface.co"
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}

# -- App -----------------------------------------------------------------------
app = FastAPI(title="Hemo AI Backend", version="3.0.0")
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is very
# permissive -- confirm this is intended for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create database tables at import time.
init_db()

# Lightweight ensemble orchestrator shared by all endpoints below.
ensemble = get_ensemble()
logger.info("Ensemble orchestrator initialized.")
async def root():
    """Landing payload for the API root: status, version, docs location."""
    # NOTE(review): no route decorator is visible here (presumably @app.get("/"));
    # the decorators appear to have been lost in this copy -- confirm upstream.
    info = {
        "status": "online",
        "message": "Hemo AI Backend is running successfully.",
        "version": "3.0.0",
        "api_docs": "/docs",
    }
    return info
| # ββ Pydantic schemas ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class ChatRequest(BaseModel):
    """Body for text endpoints: the current message plus prior turns."""
    message: str
    history: list[dict] = []  # prior turns as {"role", "content"} dicts; pydantic copies defaults per instance
class UserSignup(BaseModel):
    """Registration payload; EmailStr validates the e-mail format."""
    username: str
    email: EmailStr
    password: str  # plaintext in transit; hashed via hash_password() in signup()
class UserLogin(BaseModel):
    """Login payload checked against the stored password hash."""
    username: str
    password: str
class AuthResponse(BaseModel):
    """Response shape for signup/login."""
    message: str
    username: str
    token: str = "demo-token"  # placeholder until JWT is fully set up
class ChatResponse(BaseModel):
    """Response shape for the plain text chat endpoint."""
    response: str
class AudioResponse(BaseModel):
    """Response shape for the audio endpoint: ASR text plus the LLM answer."""
    transcription: str
    ai_response: str
    history: list[dict] = []  # updated conversation history including this turn
class VisionResponse(BaseModel):
    """Response shape for the vision endpoint."""
    visual_description: str  # raw description from the vision model ("" on failure)
    medical_analysis: str    # LLM analysis based on that description
    analysis: str            # combined, display-ready text
class MultimodalResponse(BaseModel):
    """Response shape for the unified multimodal endpoint."""
    response: str
    transcription: str | None = None        # present only when audio was supplied
    visual_description: str | None = None   # present only when an image was supplied
    earcp_weights: dict = {}                # live ensemble model weights
    history: list[dict] = []                # updated conversation history
| # ββ System prompt ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def make_system_prompt() -> str:
    """Build the (French) system prompt framing the model as Dr. Hemo."""
    segments = [
        "Tu es Dr. Hemo, un assistant mΓ©dical bienveillant et expert en drΓ©panocytose (anΓ©mie falciforme) ",
        "et en santΓ© gΓ©nΓ©rale. ",
        "RΓ©ponds TOUJOURS dans la mΓͺme langue que l'utilisateur (franΓ§ais, anglais, Γ©wΓ©, etc.). ",
        "Sois clair, empathique et prΓ©cis. Structure tes rΓ©ponses si nΓ©cessaire (listes, Γ©tapes). ",
        "Rappelle toujours Γ l'utilisateur de consulter un professionnel de santΓ© pour tout diagnostic.",
    ]
    return "".join(segments)
| # ββ Core AI helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
async def call_medgemma(prompt: str, history: list[dict] | None = None) -> str:
    """Ask the chat LLM for a single, non-streaming answer.

    Args:
        prompt: The user's message for this turn.
        history: Prior turns as {"role", "content"} dicts; only the last 10
            are forwarded to bound payload size.

    Returns:
        The assistant's reply, or a French fallback message when the remote
        service fails or returns an unexpected payload.
    """
    if history is None:
        history = []
    messages = [{"role": "system", "content": make_system_prompt()}]
    messages.extend(history[-10:])
    messages.append({"role": "user", "content": prompt})
    payload = {
        "model": MEDGEMMA_MODEL,
        "messages": messages,
        # Fix: the OpenAI-compatible chat-completions endpoint expects
        # "max_tokens" (as call_medgemma_stream already sends); the previous
        # "max_new_tokens" key is not part of that schema.
        "max_tokens": 700,
        "temperature": 0.65,
    }
    logger.info(f"β MedGemma/Gemma-3: {prompt[:80]!r}")
    async with httpx.AsyncClient(timeout=90.0) as client:
        resp = await client.post(HF_CHAT_URL, json=payload, headers=HEADERS)
    if resp.status_code != 200:
        logger.error(f"MedGemma/Gemma-3 {resp.status_code}: {resp.text[:200]}")
        return "Je suis dΓ©solΓ©, le service IA est temporairement indisponible. Veuillez rΓ©essayer."
    try:
        data = resp.json()
        return data["choices"][0]["message"]["content"].strip()
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # Robustness: a 200 with an unexpected body no longer raises a 500.
        logger.error(f"MedGemma/Gemma-3 malformed response: {e}")
        return "Je suis dΓ©solΓ©, le service IA est temporairement indisponible. Veuillez rΓ©essayer."
async def call_medgemma_stream(prompt: str, history: list[dict] | None = None):
    """Stream the chat LLM's answer as Server-Sent Events.

    Yields "data: {json}\\n\\n" strings:
      {"delta": str, "done": False}            -- one per content chunk
      {"delta": "", "done": True, "full": str} -- final event with full text
      {"error": ...}                           -- if the upstream call fails
    """
    if history is None:
        history = []
    system = make_system_prompt()
    messages = [{"role": "system", "content": system}]
    # Forward only the last 10 turns to bound payload size.
    for turn in history[-10:]:
        messages.append(turn)
    messages.append({"role": "user", "content": prompt})
    payload = {
        "model": MEDGEMMA_MODEL,
        "messages": messages,
        "max_tokens": 700,
        "temperature": 0.65,
        "stream": True,
    }
    url = HF_CHAT_URL
    full_text = ""
    async with httpx.AsyncClient(timeout=120.0) as client:
        async with client.stream("POST", url, json=payload, headers=HEADERS) as resp:
            if resp.status_code != 200:
                # In streaming mode the body must be read explicitly.
                error_body = await resp.aread()
                logger.error(f"MedGemma stream {resp.status_code}: {error_body[:200]}")
                yield f"data: {json.dumps({'error': 'Service IA indisponible'})}\n\n"
                return
            async for line in resp.aiter_lines():
                # Upstream emits OpenAI-style SSE lines: "data: {json}".
                if not line or not line.startswith("data: "):
                    continue
                chunk = line[6:]
                if chunk.strip() == "[DONE]":
                    break
                try:
                    obj = json.loads(chunk)
                    delta = obj["choices"][0]["delta"].get("content", "")
                    if delta:
                        full_text += delta
                        yield f"data: {json.dumps({'delta': delta, 'done': False})}\n\n"
                except Exception:
                    # Best-effort: skip malformed chunks rather than abort the stream.
                    pass
    # Final event carries the fully assembled text for the client to store.
    yield f"data: {json.dumps({'delta': '', 'done': True, 'full': full_text})}\n\n"
async def call_llava_description(image_bytes: bytes) -> str:
    """Get a factual visual description of a medical image from LLaVA.

    Returns "" on any failure so callers can fall back to text-only analysis.
    """
    try:
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # Downscale so the longest side is <= 768 px to bound payload size.
        max_side = 768
        w, h = img.size
        if max(w, h) > max_side:
            scale = max_side / max(w, h)
            img = img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
        # Re-encode as JPEG and embed as a base64 data URL in the chat payload.
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=80)
        img_b64 = base64.b64encode(buf.getvalue()).decode()
        img_url = f"data:image/jpeg;base64,{img_b64}"
        description_prompt = (
            "Describe this medical image in detail. Include: "
            "what type of image it is (X-ray, photo, lab result, skin lesion, etc.), "
            "visible structures, any abnormalities, colors, patterns, and any text visible. "
            "Be factual and exhaustive. Do NOT give a diagnosis β only describe what you see."
        )
        payload = {
            "model": LLAVA_MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": img_url}},
                        {"type": "text", "text": description_prompt},
                    ],
                }
            ],
            "max_new_tokens": 400,
            "temperature": 0.2,
        }
        url = f"{HF_INFER_BASE}/v1/chat/completions"
        logger.info("β LLaVA: requesting visual description")
        async with httpx.AsyncClient(timeout=90.0) as client:
            resp = await client.post(url, json=payload, headers=HEADERS)
        if resp.status_code != 200:
            # Non-fatal: caller builds a text-only prompt instead.
            logger.warning(f"LLaVA failed {resp.status_code}: {resp.text[:200]}. Falling back.")
            return ""
        data = resp.json()
        description = data["choices"][0]["message"]["content"].strip()
        logger.info(f"LLaVA description ({len(description)} chars)")
        return description
    except Exception as e:
        # Broad catch is deliberate: vision is best-effort, never fatal.
        logger.warning(f"LLaVA exception: {e}. Falling back to text-only.")
        return ""
async def call_multimodal_analysis(
    image_bytes: bytes,
    user_question: str,
    history: list[dict] | None = None,
) -> tuple[str, str]:
    """Describe an image with LLaVA, then have the chat LLM analyze it.

    Returns a (visual_description, medical_analysis) pair; the description
    is "" when the vision model was unavailable.
    """
    turns = [] if history is None else history
    visual_description = await call_llava_description(image_bytes)
    if visual_description:
        # Vision succeeded: ground the LLM on the description.
        analysis_prompt = (
            f"[DESCRIPTION VISUELLE DE L'IMAGE]\n{visual_description}\n\n"
            f"[QUESTION DE L'UTILISATEUR]\n{user_question}\n\n"
            "En tant que Dr. Hemo, donne une analyse mΓ©dicale approfondie de cette image "
            "en te basant sur la description visuelle ci-dessus. "
            "Identifie les Γ©lΓ©ments mΓ©dicaux pertinents, explique ce qu'ils signifient "
            "dans le contexte de la drΓ©panocytose ou de la santΓ© gΓ©nΓ©rale, et donne des conseils pratiques. "
            "Structure ta rΓ©ponse clairement."
        )
    else:
        # Vision failed: answer text-only and ask the user to describe the image.
        analysis_prompt = (
            f"[Image mΓ©dicale soumise β analyse visuelle indisponible]\n"
            f"[QUESTION DE L'UTILISATEUR]\n{user_question}\n\n"
            "RΓ©ponds Γ la question mΓ©dicale de l'utilisateur mΓͺme sans accΓ¨s Γ l'image. "
            "Demande-lui de dΓ©crire ce qu'il voit si possible."
        )
    medical_analysis = await call_medgemma(analysis_prompt, turns)
    return visual_description, medical_analysis
async def call_whisper(audio_bytes: bytes) -> str:
    """Transcribe raw audio bytes with Whisper via the HF router.

    Raises HTTPException(502) when the upstream call does not return 200.
    """
    endpoint = f"{HF_ROUTER_BASE}/{WHISPER_MODEL}"
    logger.info(f"β Whisper: {len(audio_bytes)} bytes")
    request_headers = dict(HEADERS)
    request_headers["Content-Type"] = "audio/webm"
    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(endpoint, content=audio_bytes, headers=request_headers)
    if resp.status_code != 200:
        logger.error(f"Whisper {resp.status_code}: {resp.text[:200]}")
        raise HTTPException(status_code=502, detail=f"Whisper transcription failed: {resp.status_code}")
    text = resp.json().get("text", "").strip()
    logger.info(f"Whisper result: {text!r}")
    return text
async def synthesize_tts(text: str, lang: str = "fr") -> bytes:
    """Render *text* to MP3 bytes with gTTS; return b"" on any failure."""
    try:
        from gtts import gTTS  # type: ignore

        # Cap input length to keep synthesis fast.
        speech = gTTS(text=text[:500], lang=lang, slow=False)
        out = io.BytesIO()
        speech.write_to_fp(out)
        logger.info(f"TTS synthesized ({len(text)} chars)")
        return out.getvalue()
    except Exception as e:
        # Best-effort by design: callers treat empty bytes as "TTS unavailable".
        logger.warning(f"TTS failed: {e}")
        return b""
| # ββ Endpoints βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
async def signup(user_data: UserSignup, db: Session = Depends(get_db)):
    """Create a new user account; rejects duplicate usernames with a 400."""
    # NOTE(review): no route decorator visible (presumably @app.post on an auth
    # path) -- confirm against the deployed routing.
    existing = db.query(User).filter(User.username == user_data.username).first()
    if existing is not None:
        raise HTTPException(status_code=400, detail="Username already exists")
    new_user = User(
        username=user_data.username,
        email=user_data.email,
        hashed_password=hash_password(user_data.password),
    )
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return {"message": "Success", "username": new_user.username, "token": "signup-token"}
async def login(user_data: UserLogin, db: Session = Depends(get_db)):
    """Verify credentials; 401 on unknown user or bad password."""
    account = db.query(User).filter(User.username == user_data.username).first()
    # Short-circuit: verify_password is only reached for an existing account.
    authenticated = account is not None and verify_password(user_data.password, account.hashed_password)
    if not authenticated:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return {"message": "Logged in", "username": account.username, "token": "login-token"}
async def health():
    """Service status plus configured model names and EARCP diagnostics."""
    orchestrator = get_ensemble()
    model_names = {
        "chat": MEDGEMMA_MODEL,
        "vision": LLAVA_MODEL,
        "whisper": WHISPER_MODEL,
    }
    return {
        "status": "ok",
        "service": "Hemo AI Backend",
        "version": "3.0.0",
        "models": model_names,
        "earcp": orchestrator.get_diagnostics(),
    }
async def chat(req: ChatRequest):
    """Standard text chat with conversation history."""
    logger.info(f"Chat: {req.message[:80]!r}")
    answer = await call_medgemma(req.message, req.history)
    return {"response": answer}
async def chat_stream(message: str, history_json: str = "[]"):
    """SSE streaming chat endpoint; history arrives JSON-encoded."""
    try:
        turns = json.loads(history_json)
    except Exception:
        # Malformed history degrades to an empty conversation, never a 500.
        turns = []
    logger.info(f"Chat stream: {message[:80]!r}")
    # Disable proxy/client buffering so SSE chunks flush immediately.
    sse_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        call_medgemma_stream(message, turns),
        media_type="text/event-stream",
        headers=sse_headers,
    )
async def audio_query(
    file: UploadFile = File(...),
    history_json: str = Form(default="[]"),
):
    """Voice pipeline: transcribe the upload with Whisper, then answer it."""
    logger.info(f"Audio: {file.filename}")
    audio_bytes = await file.read()
    try:
        turns: list[dict] = json.loads(history_json)
    except Exception:
        turns = []
    transcription = await call_whisper(audio_bytes)
    if not transcription:
        raise HTTPException(status_code=422, detail="Empty transcription β please speak clearly.")
    ai_response = await call_medgemma(transcription, turns)
    new_turns = [
        {"role": "user", "content": transcription},
        {"role": "assistant", "content": ai_response},
    ]
    return {
        "transcription": transcription,
        "ai_response": ai_response,
        "history": turns + new_turns,
    }
async def vision_query(
    file: UploadFile = File(...),
    prompt: str = Form(default="Analyse cette image mΓ©dicale."),
    history_json: str = Form(default="[]"),
):
    """Image analysis: vision description plus medical analysis, combined."""
    logger.info(f"Vision: {file.filename}, prompt={prompt[:60]!r}")
    image_bytes = await file.read()
    try:
        turns: list[dict] = json.loads(history_json)
    except Exception:
        turns = []
    visual_description, medical_analysis = await call_multimodal_analysis(
        image_bytes, prompt, turns
    )
    if visual_description:
        combined = f"**Analyse visuelle :**\n{visual_description}\n\n**Analyse mΓ©dicale :**\n{medical_analysis}"
    else:
        # Vision unavailable: the medical analysis stands alone.
        combined = medical_analysis
    return {
        "visual_description": visual_description,
        "medical_analysis": medical_analysis,
        "analysis": combined,
    }
async def text_to_speech(req: ChatRequest):
    """Convert text to speech with gTTS; returns base64 MP3 bytes in JSON."""
    text = req.message
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty.")
    # Crude language sniff: an English stopword among the first ten words
    # selects the English voice, otherwise French.
    english_words = {"the", "is", "are", "this", "that", "what", "how", "when", "where"}
    leading_words = set(text.lower().split()[:10])
    lang = "en" if leading_words & english_words else "fr"
    audio_bytes = await synthesize_tts(text, lang=lang)
    if not audio_bytes:
        raise HTTPException(status_code=500, detail="TTS synthesis failed.")
    return {
        "audio_b64": base64.b64encode(audio_bytes).decode(),
        "format": "mp3",
        "lang": lang,
    }
async def multimodal_unified(
    text: str = Form(default=""),
    history_json: str = Form(default="[]"),
    tts: str = Form(default="false"),
    image: UploadFile | None = File(default=None),
    audio: UploadFile | None = File(default=None),
):
    """Unified multimodal endpoint orchestrated by EARCP.

    Accepts any combination of:
      - text: a typed message
      - image: an uploaded image file (vision expert, then the chat LLM)
      - audio: a recorded audio file (ASR expert, then the chat LLM)
      - tts: "true" to also return the response as base64 MP3

    Returns:
      - response: the final text response from Dr. Hemo
      - transcription: if audio was provided
      - visual_description: if image was provided
      - earcp_weights: live model weights from the ensemble
      - history: updated conversation history
      - audio_b64: (optional) TTS audio if tts=true
    """
    try:
        history: list[dict] = json.loads(history_json)
    except Exception:
        history = []

    transcription: str | None = None
    visual_description: str | None = None

    # -- Audio: persist to a temp file for the ensemble's ASR expert ---------
    if audio is not None:
        logger.info("Multimodal: Audio input detected")
        audio_bytes = await audio.read()
        import tempfile
        tmp = tempfile.NamedTemporaryFile(suffix=".webm", delete=False, mode="wb")
        try:
            tmp.write(audio_bytes)
            tmp.close()  # close before handing the path to the orchestrator
            audio_out = ensemble.process_audio(tmp.name)
            transcription = audio_out.get("transcription")
        finally:
            # Fix: the temp file previously leaked when process_audio raised,
            # because unlink was not in a finally block.
            os.unlink(tmp.name)

    # -- Image: hand base64 to the ensemble's vision expert -------------------
    if image is not None:
        logger.info("Multimodal: Image input detected")
        image_bytes = await image.read()
        image_b64 = base64.b64encode(image_bytes).decode()
        vision_out = ensemble.process_vision(image_b64, text)
        visual_description = vision_out.get("visual_description")

    # -- Final response from the chat LLM --------------------------------------
    logger.info("Executing Hemo multimodal orchestration...")
    prompt = text or (transcription if transcription else "Analyse cette image.")
    if visual_description:
        prompt = f"Description visuelle: {visual_description}\nUtilisateur: {prompt}"
    ai_response = await call_medgemma(prompt, history)
    earcp_weights = ensemble.get_weights()
    logger.info(f"Unified model responded. EARCP weights: {earcp_weights}")

    # -- Update conversation history -------------------------------------------
    user_content = text.strip()
    if transcription:
        user_content = transcription if not user_content else f"{user_content}\n{transcription}"
    if visual_description:
        user_content += f"\n[Image analysΓ©e β {len(visual_description)} chars de description visuelle]"
    updated_history = history + [
        {"role": "user", "content": user_content or "Analyse cette image mΓ©dicale."},
        {"role": "assistant", "content": ai_response},
    ]

    # -- Optional TTS ----------------------------------------------------------
    audio_b64_res = None
    if tts.lower() == "true":
        tts_bytes = await synthesize_tts(ai_response)
        if tts_bytes:
            audio_b64_res = base64.b64encode(tts_bytes).decode()

    result = {
        "response": ai_response,
        "transcription": transcription,
        "visual_description": visual_description,
        "earcp_weights": earcp_weights,
        "history": updated_history,
    }
    if audio_b64_res:
        result["audio_b64"] = audio_b64_res
    return result
async def analyze_file(
    file: UploadFile = File(...),
    prompt: str = Form(default=""),
):
    """Analyse a medical document: images go through the vision pipeline,
    everything else gets a text-only LLM summary."""
    content_type = file.content_type or ""
    file_bytes = await file.read()
    if "image" not in content_type:
        # Non-image upload: only the filename/prompt inform the analysis.
        question = (
            prompt
            or f"Document mΓ©dical : {file.filename}. "
            "Fournis un rΓ©sumΓ© dΓ©taillΓ© et des conseils liΓ©s Γ la drΓ©panocytose."
        )
        summary = await call_medgemma(question)
        return {"summary": summary, "visual_description": "", "filename": file.filename}
    question = prompt or "Analysez ce document mΓ©dical et fournissez un rΓ©sumΓ© clair avec les points clΓ©s."
    visual_desc, medical = await call_multimodal_analysis(file_bytes, question)
    if visual_desc:
        summary = f"**Description :** {visual_desc}\n\n**Analyse :** {medical}"
    else:
        summary = medical
    return {"summary": summary, "visual_description": visual_desc, "filename": file.filename}
async def earcp_weights():
    """Return current EARCP model weights (for UI polling)."""
    return get_ensemble().get_diagnostics()