# (Hugging Face Spaces page artifact — "Spaces: Sleeping" — removed; the module begins at the docstring below.)
| """ | |
| ElephMind Medical AI Backend | |
| ============================ | |
| Production-ready FastAPI backend for medical image analysis using SigLIP. | |
| Author: ElephMind Team | |
| Version: 2.0.0 (Cleaned & Secured) | |
| """ | |
| import sys | |
| import os | |
| import uuid | |
| import asyncio | |
| import time | |
| import logging | |
| # --- DOTENV SUPPORT (MUST BE FIRST) --- | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except ImportError: | |
| pass | |
| from enum import Enum | |
| from typing import Dict, List, Optional, Any, Tuple | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from datetime import datetime | |
| from contextlib import asynccontextmanager | |
| import uvicorn | |
| import base64 | |
| import cv2 | |
| import numpy as np | |
| from pytorch_grad_cam import GradCAMPlusPlus | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from localization import localize_result | |
| import torch | |
| import torch.nn as nn | |
| # Local modules | |
| import database | |
| import storage_manager # NEW: Physical storage layout | |
| from database import JobStatus | |
| from storage import get_storage_provider | |
| import encryption | |
| import database | |
| # algorithms imported directly above | |
| import math | |
| from collections import deque | |
| from dataclasses import dataclass, field | |
| from PIL import Image | |
| import io | |
# --- GRADCAM UTILS FOR SIGLIP/ViT ---
# The ViT reshape-transform class and its helper were deduplicated out of this
# section; see the "GRAD-CAM UTILITIES" section later in this file (refactored
# into explainability.py).
| # --- AUTH IMPORTS --- | |
| from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |
| from fastapi import Depends, status, Request | |
| from datetime import datetime, timedelta | |
| from jose import JWTError, jwt | |
| import bcrypt | |
| # --- DOTENV (Moved to top) --- | |
| # ========================================================================= | |
| # LOGGING CONFIGURATION | |
| # ========================================================================= | |
# Root logging: INFO level, timestamped records, emitted to stdout
# (container/Spaces friendly — logs are captured from the stream).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)]
)
# Module-wide logger used throughout this backend.
logger = logging.getLogger("ElephMind-Backend")
| # ========================================================================= | |
| # 7 INTELLIGENCE ALGORITHMS (Merged from algorithms.py) | |
| # ========================================================================= | |
| # 1. IMAGE QUALITY ASSESSMENT | |
def detect_blur(image: np.ndarray) -> float:
    """Score image sharpness via the variance of the Laplacian.

    Higher score = sharper image.
    Returns a value clamped to [0.0 (very blurry), 1.0 (very sharp)].
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    variance = cv2.Laplacian(gray, cv2.CV_64F).var()
    # Empirical normalisation constant (500) tuned for medical imagery.
    return min(1.0, variance / 500.0)
def assess_image_quality(image: np.ndarray) -> Dict[str, Any]:
    """Compute a 0-100 quality score from sharpness, contrast and resolution."""
    metrics: List[Dict[str, Any]] = []
    score = 0

    # Sharpness (Laplacian variance normalised to [0, 1] by detect_blur).
    sharpness = detect_blur(image)
    metrics.append({"metric": "Netteté", "value": int(sharpness * 100)})
    if sharpness > 0.6:
        score += 40
    elif sharpness > 0.3:
        score += 20

    # Contrast: grey-level standard deviation.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    contrast = float(gray.std())
    metrics.append({"metric": "Contraste", "value": int(min(100, contrast * 2))})
    if contrast > 40:
        score += 30

    # Resolution: pixel count relative to one megapixel.
    height, width = image.shape[:2]
    pixels = height * width
    metrics.append({"metric": "Résolution", "value": int(min(100, pixels / (1024 * 1024) * 100))})
    if pixels > 512 * 512:
        score += 30

    return {"quality_score": min(100, score), "metrics": metrics}
| # 2. CONFIDENCE CALIBRATION | |
def calibrate_confidence(raw_stats: List[float], labels: List[str]) -> float:
    """Turn raw probability scores into a 0-100 confidence percentage.

    Currently returns the top raw score scaled to a percentage; ``labels`` is
    accepted (unused) so per-class calibration can be added without changing
    call sites. Returns 0.0 for an empty input.
    """
    if not raw_stats:
        return 0.0
    return float(round(max(raw_stats) * 100, 2))
| # 3. CLINICAL PRIORITY SCORING | |
def calculate_priority_score(predictions: List[Dict], domain: str) -> str:
    """Map the top prediction to a triage priority level.

    Returns "Élevée" for high-probability critical findings, "Moyenne" for
    warning-level findings, "Normale" otherwise (including empty input).
    """
    if not predictions:
        return "Normale"

    label = predictions[0]["label"].lower()
    prob = predictions[0]["probability"]

    # Keyword lists drive the severity tiers (matched case-insensitively).
    critical = ("malignant", "cancer", "carcinoma", "pneumonia",
                "pneumothorax", "fracture", "grade 4")
    warning = ("grade 2", "grade 3", "effusion", "edema", "abnormal")

    if prob > 50 and any(term in label for term in critical):
        return "Élevée"
    if prob > 40 and any(term in label for term in warning):
        return "Moyenne"
    return "Normale"
| # 4. AUTOMATIC REPORT GENERATION | |
def generate_clinical_report(analysis_result: Dict[str, Any], patient_info: Optional[Dict] = None) -> str:
    """Build a deterministic French text report from an analysis result.

    Args:
        analysis_result: dict with optional "domain" ({"label": ...}),
            "specific" (ranked predictions, each with "label"/"probability")
            and "priority" keys.
        patient_info: optional dict; its "id" is included when provided.

    Returns:
        The formatted report, or a fallback message when there are no
        specific findings.

    Fixes: drops the unused ``enumerate`` index and builds the report with
    a list + join instead of repeated string concatenation.
    """
    domain = analysis_result.get("domain", {}).get("label", "Unknown")
    specifics = analysis_result.get("specific", [])
    if not specifics:
        return "Analyse non concluante."

    top_finding = specifics[0]
    lines = [
        f"RAPPORT D'ANALYSE AUTOMATISÉE - {domain.upper()}",
        f"Date: {datetime.now().strftime('%d/%m/%Y %H:%M')}",
    ]
    if patient_info:
        lines.append(f"Patient ID: {patient_info.get('id', 'N/A')}")
    lines.append("-" * 40)
    lines.append(f"Observation Principale: {top_finding['label']}")
    lines.append(f"Confiance IA: {top_finding['probability']}%")
    priority = analysis_result.get("priority", "Normale")
    # Trailing "\n" keeps the blank line after the priority, matching the
    # original layout exactly.
    lines.append(f"Priorité de Triage: {priority.upper()}\n")
    lines.append("Détails Techniques:")
    # Secondary findings: up to three runner-up predictions.
    for det in specifics[1:4]:
        lines.append(f"- {det['label']}: {det['probability']}%")
    return "\n".join(lines) + "\n"
| # 5. SIMILAR CASE DETECTION (Vector DB Mockup) | |
@dataclass
class CaseRecord:
    """A stored analysis case used for similarity search.

    BUG FIX: the original was a plain class holding only annotations, so
    ``CaseRecord(case_id, embedding, ...)`` in ``SimilarCaseDatabase.add_case``
    raised ``TypeError: CaseRecord() takes no arguments``. ``@dataclass``
    (already imported at the top of this file) generates the positional
    constructor those call sites expect.
    """
    id: str                  # unique case identifier
    embedding: np.ndarray    # image embedding used for cosine similarity
    diagnosis: str           # top predicted label at analysis time
    domain: str              # medical domain (e.g. 'Thoracic')
    probability: float       # confidence of the stored diagnosis
    username: str            # owner; enforces per-user isolation
class SimilarCaseDatabase:
    """In-memory, per-user case index with cosine-similarity lookup."""

    def __init__(self):
        # Oldest-first list of stored cases (bounded in add_case).
        self.cases: List[CaseRecord] = []

    def add_case(self, case_id: str, embedding: np.ndarray, diagnosis: str, domain: str, probability: float, username: str):
        """Append a case, evicting the oldest once more than 1000 are stored."""
        self.cases.append(CaseRecord(case_id, embedding, diagnosis, domain, probability, username))
        if len(self.cases) > 1000:
            del self.cases[0]

    def find_similar(self, query_embedding: np.ndarray, username: str, top_k: int = 3, same_domain_only: bool = True, query_domain: str = None) -> List[Dict]:
        """Return up to ``top_k`` of the caller's own cases ranked by cosine similarity."""
        if not self.cases:
            return []

        query_norm = np.linalg.norm(query_embedding)
        ranked = []
        for candidate in self.cases:
            # Strict isolation: never surface another user's cases.
            if candidate.username != username:
                continue
            if same_domain_only and query_domain and candidate.domain != query_domain:
                continue
            candidate_norm = np.linalg.norm(candidate.embedding)
            if query_norm > 0 and candidate_norm > 0:
                cosine = np.dot(query_embedding, candidate.embedding) / (query_norm * candidate_norm)
            else:
                cosine = 0
            ranked.append((cosine, candidate))

        ranked.sort(key=lambda pair: pair[0], reverse=True)
        return [
            {
                "case_id": rec.id,
                "diagnosis": rec.diagnosis,
                "similarity": round(float(score * 100), 1),
            }
            for score, rec in ranked[:top_k]
        ]
# Global in-memory similar-case index (process-local; resets on restart,
# bounded to 1000 cases by SimilarCaseDatabase.add_case).
similar_case_db = SimilarCaseDatabase()
def find_similar_cases(embedding: np.ndarray, domain: str, username: str, top_k: int = 5) -> Dict[str, Any]:
    """Query the global case index for this user's most similar prior cases."""
    matches = similar_case_db.find_similar(
        query_embedding=embedding,
        username=username,
        top_k=top_k,
        same_domain_only=True,
        query_domain=domain,
    )
    message = (
        f"Trouvé {len(matches)} cas similaires" if matches
        else "Aucun cas similaire trouvé"
    )
    return {
        "similar_cases": matches,
        "cases_searched": len(similar_case_db.cases),
        "message": message,
    }
def store_case_for_similarity(case_id: str, embedding: np.ndarray, diagnosis: str, domain: str, probability: float, username: str):
    """Index a completed case (per-user) so future similarity searches can find it."""
    similar_case_db.add_case(
        case_id=case_id,
        embedding=embedding,
        diagnosis=diagnosis,
        domain=domain,
        probability=probability,
        username=username,
    )
| # 6. ADAPTIVE PREPROCESSING | |
def estimate_noise_level(image: np.ndarray) -> float:
    """Estimate noise sigma from the Laplacian response using a robust MAD estimator."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)
    # Median absolute deviation scaled to sigma (0.6745 = MAD-to-sigma factor
    # for a normal distribution).
    return float(np.median(np.abs(laplacian)) / 0.6745)
def apply_clahe(image: np.ndarray, clip_limit: float = 2.0, grid_size: int = 8) -> np.ndarray:
    """Apply CLAHE (contrast-limited adaptive histogram equalisation).

    Colour images are equalised on the L channel in LAB space so hue is
    preserved; greyscale images are equalised directly.
    """
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(grid_size, grid_size))
    if image.ndim != 3:
        return clahe.apply(image)
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    lab[:, :, 0] = clahe.apply(lab[:, :, 0])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
def gamma_correction(image: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Apply gamma correction for brightness adjustment.

    gamma < 1 brightens, gamma > 1 darkens (the LUT exponent is 1/gamma).

    Fixes: raises an explicit ValueError for non-positive gamma (the original
    failed with ZeroDivisionError), and builds the 256-entry LUT with a
    vectorised NumPy expression instead of a per-value Python loop.
    """
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")
    inv_gamma = 1.0 / gamma
    # np.arange(256) / 255.0 reproduces the original i/255.0 values exactly.
    table = ((np.arange(256) / 255.0) ** inv_gamma * 255).astype("uint8")
    return cv2.LUT(image, table)
def bilateral_denoise(image: np.ndarray, d: int = 9, sigma_color: int = 75, sigma_space: int = 75) -> np.ndarray:
    """Apply an edge-preserving bilateral filter.

    Args:
        image: BGR or grayscale uint8 image.
        d: pixel neighbourhood diameter.
        sigma_color: filter sigma in colour space (larger = stronger smoothing
            across similar colours).
        sigma_space: filter sigma in coordinate space.
    """
    return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
def adaptive_preprocessing(image_bytes: bytes) -> Tuple[Image.Image, Dict[str, Any]]:
    """
    Apply intelligent preprocessing based on image analysis.
    Returns processed image and a log of transformations applied.

    Pipeline: decode -> analyse (contrast span, mean brightness, noise sigma)
    -> conditionally apply CLAHE, gamma correction, bilateral denoising and a
    black-level crush (greyscale inputs only) -> min/max normalise -> convert
    to an RGB PIL image. Raises ValueError when the bytes cannot be decoded.
    """
    # Decode image (IMREAD_UNCHANGED keeps greyscale inputs single-channel).
    nparr = np.frombuffer(image_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError("Could not decode image")
    transformations = []
    original_stats = {
        "mean_brightness": float(np.mean(img)),
        "std_dev": float(np.std(img))
    }
    # Convert to grayscale for analysis
    if len(img.shape) == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        gray = img
    # Analyze histogram: a narrow occupied intensity range (< 150 levels)
    # indicates low contrast.
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).flatten()
    non_zero = np.where(hist > 0)[0]
    is_low_contrast = bool(len(non_zero) > 0 and (non_zero[-1] - non_zero[0]) < 150)
    is_dark = bool(np.mean(gray) < 60)
    is_bright = bool(np.mean(gray) > 200)
    noise_level = float(estimate_noise_level(gray))
    # Apply adaptive corrections
    processed = img.copy()
    # 1. Low contrast - Apply CLAHE
    if is_low_contrast:
        processed = apply_clahe(processed, clip_limit=2.5)
        transformations.append({
            "type": "CLAHE",
            "reason": "Faible contraste détecté",
            "params": {"clip_limit": 2.5}
        })
    # 2. Dark image - Gamma correction (gamma < 1 brightens)
    if is_dark:
        processed = gamma_correction(processed, gamma=0.6)
        transformations.append({
            "type": "Gamma Correction",
            "reason": "Image trop sombre",
            "params": {"gamma": 0.6}
        })
    # 3. Overexposed - Inverse gamma (gamma > 1 darkens)
    if is_bright:
        processed = gamma_correction(processed, gamma=1.6)
        transformations.append({
            "type": "Gamma Correction",
            "reason": "Image surexposée",
            "params": {"gamma": 1.6}
        })
    # 4. Noisy - Bilateral filter (edge-preserving; threshold sigma > 15)
    if noise_level > 15:
        processed = bilateral_denoise(processed)
        transformations.append({
            "type": "Bilateral Denoise",
            "reason": f"Bruit détecté (σ={noise_level:.1f})",
            "params": {"d": 9, "sigma": 75}
        })
    # 5. Black level correction for X-rays (crush blacks) — only applied to
    # single-channel images, which are assumed to be radiographs.
    if len(processed.shape) == 2 or (len(processed.shape) == 3 and processed.shape[2] == 1):
        _, processed = cv2.threshold(processed, 15, 255, cv2.THRESH_TOZERO)
        transformations.append({
            "type": "Black Level Crush",
            "reason": "Correction niveau noir (X-ray)",
            "params": {"threshold": 15}
        })
    # Final normalization: stretch intensities to 0-255 (skipped for flat images
    # to avoid division by zero).
    min_val, max_val = processed.min(), processed.max()
    if max_val > min_val:
        processed = ((processed - min_val) / (max_val - min_val) * 255).astype(np.uint8)
        transformations.append({
            "type": "Normalization",
            "reason": "Normalisation finale",
            "params": {"min": float(min_val), "max": float(max_val)}
        })
    # Convert to PIL Image (downstream model pipeline expects RGB)
    if len(processed.shape) == 2:
        pil_image = Image.fromarray(processed).convert("RGB")
    else:
        pil_image = Image.fromarray(cv2.cvtColor(processed, cv2.COLOR_BGR2RGB))
    preprocessing_log = {
        "original_stats": original_stats,
        "analysis": {
            "low_contrast": is_low_contrast,
            "dark": is_dark,
            "bright": is_bright,
            "noise_level": round(noise_level, 2)
        },
        "transformations_applied": transformations,
        "transformation_count": len(transformations)
    }
    return pil_image, preprocessing_log
| # 7. ENHANCE ANALYSIS RESULT (PIPELINE) | |
def enhance_analysis_result(
    base_result: Dict[str, Any],
    image_array: np.ndarray = None,
    embedding: np.ndarray = None,
    case_id: str = None,
    patient_info: Dict = None,
    username: str = None
) -> Dict[str, Any]:
    """Augment a base prediction dict with the intelligence-algorithm outputs.

    Adds, when inputs allow: image quality, calibrated confidence, triage
    priority and per-user similar cases; also indexes the current case for
    future similarity searches. Report generation intentionally happens later
    in predict(), after localization/translation.
    """
    enhanced = base_result.copy()
    specifics = enhanced.get("specific") or []

    # 1. Image quality assessment (requires the raw image array).
    if image_array is not None:
        enhanced["image_quality"] = assess_image_quality(image_array)

    if specifics:
        # 2. Confidence calibration from the specific-label probabilities.
        probs = [pred["probability"] / 100 for pred in specifics]
        names = [pred["label"] for pred in specifics]
        enhanced["confidence"] = calibrate_confidence(probs, labels=names)
        # 3. Clinical triage priority from the ranked findings.
        domain_label = enhanced.get("domain", {}).get("label", "Unknown")
        enhanced["priority"] = calculate_priority_score(specifics, domain_label)

    # 4. Similar-case retrieval, strictly scoped to this user.
    if embedding is not None and "domain" in enhanced and username:
        domain_label = enhanced["domain"].get("label", "Unknown")
        enhanced["similar_cases"] = find_similar_cases(embedding, domain_label, username)
        # Index this case so later analyses can find it.
        if case_id and enhanced["specific"]:
            top = enhanced["specific"][0]
            store_case_for_similarity(
                case_id=case_id,
                embedding=embedding,
                diagnosis=top["label"],
                domain=domain_label,
                probability=top["probability"],
                username=username,
            )
    return enhanced
# --- Legacy model-directory resolution.
# NOTE(review): MODEL_DIR computed here is overwritten to None further down
# (before get_model_path() resolves it at startup) — confirm this block is
# still needed.
BASE_MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
NESTED_DIR = os.path.join(BASE_MODELS_DIR, "oeil d'elephant")
MODEL_DIR = NESTED_DIR if os.path.exists(NESTED_DIR) else BASE_MODELS_DIR
# Environment Detection
ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
IS_PRODUCTION = ENVIRONMENT == "production"
# Security Configuration - JWT Secret Key (ENFORCED in production)
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
if not SECRET_KEY:
    if IS_PRODUCTION:
        # Fail fast: a generated secret in production would silently
        # invalidate every issued token on each restart.
        logger.critical("🔴 FATAL ERROR: JWT_SECRET_KEY must be set in production environment")
        logger.critical("Generate one with: python -c 'import secrets; print(secrets.token_hex(32))'")
        sys.exit(1)  # Fail-fast in production
    else:
        # Development fallback: random per-process secret, clearly marked.
        from secrets import token_hex
        SECRET_KEY = "dev_insecure_key_" + token_hex(16)
        logger.warning("⚠️ WARNING: Using development JWT secret. DO NOT use in production!")
ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("JWT_EXPIRE_MINUTES", "60"))
logger.info(f"🌍 Environment: {ENVIRONMENT}")
logger.info(f"✅ JWT SECRET_KEY: {'SET (secure)' if 'dev_insecure' not in SECRET_KEY else 'DEVELOPMENT MODE'}")
# CORS Configuration: comma-separated origin allow-list (defaults cover local dev ports).
CORS_ORIGINS_STR = os.getenv("CORS_ORIGINS", "http://localhost:5173,http://127.0.0.1:5173,http://localhost:5174,http://127.0.0.1:5174,http://localhost:5175,http://127.0.0.1:5175,http://localhost:5176,http://127.0.0.1:5176")
CORS_ORIGINS = [origin.strip() for origin in CORS_ORIGINS_STR.split(",")]
# Concurrency Control: semaphore bounds simultaneous analysis requests.
MAX_CONCURRENT_USERS = int(os.getenv("MAX_CONCURRENT_USERS", "200"))
concurrency_semaphore = asyncio.Semaphore(MAX_CONCURRENT_USERS)
| # ========================================================================= | |
| # MODEL PATH CONFIGURATION (HuggingFace Hub or Local) | |
| # ========================================================================= | |
def get_model_path():
    """Resolve the model directory: env var override, local checkout, or HF Hub download."""
    # 1) Explicit override via the MODEL_DIR environment variable.
    env_path = os.getenv("MODEL_DIR")
    if env_path and os.path.exists(env_path):
        logger.info(f"Using model from environment: {env_path}")
        return env_path

    # 2) Local development checkout next to this file.
    local_path = os.path.join(os.path.dirname(__file__), "models", "oeil d'elephant")
    if os.path.exists(local_path):
        logger.info(f"Using local model: {local_path}")
        return local_path

    # 3) Fall back to downloading from the HuggingFace Hub (production/cloud).
    try:
        from huggingface_hub import snapshot_download
        logger.info("Downloading model from HuggingFace Hub...")
        hub_path = snapshot_download(
            repo_id="issoufzousko07/medsigclip-model",
            repo_type="model"
        )
        logger.info(f"Model downloaded to: {hub_path}")
        return hub_path
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        raise RuntimeError(f"Model not found locally and failed to download: {e}")
# Resolved at startup via get_model_path(); overrides the legacy value above.
MODEL_DIR = None  # Will be set at startup
# OAuth2 password-bearer scheme; the token endpoint is mounted at /token.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
| # ========================================================================= | |
| # MEDICAL DOMAINS CONFIGURATION | |
| # ========================================================================= | |
# Zero-shot prompt configuration per medical domain. 'domain_prompt' is used
# for stage-1 domain routing; 'specific_labels' are the candidate findings
# scored within the selected domain. Orthopedics uses a two-stage scheme
# (view triage first, then knee-OA severity grading).
MEDICAL_DOMAINS = {
    'Thoracic': {
        'domain_prompt': 'Chest X-Ray Analysis',
        'specific_labels': [
            'Diffuse interstitial opacities or ground-glass pattern (Viral/Atypical Pneumonia)',
            'Focal alveolar consolidation with air bronchograms (Bacterial Pneumonia)',
            'Perfectly clear lungs, sharp costophrenic angles, no pathology',
            'Pneumothorax (Lung collapse)',
            'Pleural Effusion (Fluid)',
            'Cardiomegaly (Enlarged heart)',
            'Pulmonary Edema',
            'Lung Nodule or Mass',
            'Atelectasis (Lung collapse)'
        ]
    },
    'Dermatology': {
        'domain_prompt': 'Dermatoscopic analysis of a pigmented or non-pigmented skin lesion',
        'specific_labels': [
            'A healthy skin area without lesion',
            'A benign nevus (mole) regular, symmetrical and homogeneous',
            'A seborrheic keratosis (benign warty lesion)',
            'A malignant melanoma with asymmetry, irregular borders and multiple colors',
            'A basal cell carcinoma (pearly or ulcerated lesion)',
            'A squamous cell carcinoma (crusty or budding lesion)',
            'A non-specific inflammatory skin lesion'
        ]
    },
    'Histology': {
        'domain_prompt': 'Microscopic analysis of a histological section (H&E stain)',
        'specific_labels': [
            'Healthy breast tissue with preserved lobular architecture',
            'Healthy prostatic tissue with regular glands',
            'Invasive ductal carcinoma of the breast (Disorganized cells)',
            'Prostate adenocarcinoma (Gland fusion)',
            'Cervical dysplasia or intraepithelial neoplasia',
            'Colon cancer tumor tissue',
            'Lung cancer tumor tissue',
            'Adipose tissue (Fat) or connective stroma',
            'Preparation artifact or empty area'
        ]
    },
    'Ophthalmology': {
        'domain_prompt': 'Fundus photography (Retina)',
        'specific_labels': [
            'Normal retina, healthy macula and optic disc',
            'Diabetic retinopathy (hemorrhages, exudates, aneurysms)',
            'Glaucoma (optic disc cupping)',
            'Macular degeneration (drusen or atrophy)'
        ]
    },
    'Orthopedics': {
        'domain_prompt': 'Bone X-Ray (Musculoskeletal)',
        # Stage 1 rejects non-knee views as out-of-distribution before grading.
        'stage_1_triage': {
            'prompt': 'Anatomical region identification',
            'labels': [
                'Other x-ray view (Chest, Hand, Foot, Pediatric) - OUT OF DISTRIBUTION',
                'A knee x-ray view (Knee Joint)'
            ]
        },
        # Stage 2: Kellgren-Lawrence-style severity labels for confirmed knee views.
        'stage_2_diagnosis': {
            'prompt': 'Knee Osteoarthritis Severity Assessment',
            'labels': [
                'Severe osteoarthritis with bone-on-bone contact and large osteophytes (Grade 4)',
                'Moderate osteoarthritis with definite joint space narrowing (Grade 2-3)',
                'Normal knee joint with preserved joint space and no osteophytes (Grade 0-1)',
                'Total knee arthroplasty (TKA) with metallic implant',
                'Acute knee fracture or dislocation'
            ]
        }
    }
}
| # ========================================================================= | |
| # PYDANTIC MODELS | |
| # ========================================================================= | |
class JobStatus(str, Enum):
    """Lifecycle states of an analysis job.

    NOTE(review): this local enum shadows the ``JobStatus`` imported from
    ``database`` near the top of the file — confirm both define identical
    values, or drop one of the two definitions.
    """
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
class Job(BaseModel):
    """API/in-memory representation of one analysis job."""
    id: str
    status: JobStatus
    result: Optional[Dict[str, Any]] = None   # final analysis payload (when completed)
    error: Optional[str] = None               # failure message (when failed)
    created_at: float                         # unix timestamp of job creation
    storage_path: Optional[str] = None        # where the source image was persisted
    encrypted_user: Optional[str] = None      # encrypted owner reference
    username: Optional[str] = None  # For registry logging
    file_type: Optional[str] = None  # DICOM, PNG, JPEG
    start_time_ms: Optional[float] = None  # For computation time
class Token(BaseModel):
    """OAuth2 token response body."""
    access_token: str
    token_type: str
class TokenData(BaseModel):
    """Claims extracted from a decoded JWT (the subject username)."""
    username: Optional[str] = None
class User(BaseModel):
    """Public user profile (no secrets)."""
    username: str
    email: Optional[str] = None
class UserInDB(User):
    """User record as stored in the database — hashes only, never plaintext."""
    hashed_password: str
    security_question: str
    security_answer: str  # stored hashed (the admin seed hashes it with bcrypt)
class UserRegister(BaseModel):
    """Payload for account registration."""
    username: str
    password: str
    email: Optional[str] = None
    security_question: str
    security_answer: str
class UserResetPassword(BaseModel):
    """Payload for security-question-based password reset."""
    username: str
    security_answer: str
    new_password: str
class FeedbackModel(BaseModel):
    """User feedback submission (numeric rating plus free-text comment)."""
    username: str
    rating: int
    comment: str
| # ========================================================================= | |
| # GLOBAL STATE | |
| # ========================================================================= | |
# NOTE(review): the previous comment said this dict was "REMOVED: Now using
# SQLite persistence", yet it still exists — confirm whether it is dead code.
jobs: Dict[str, Job] = {}
# Storage backend selected by STORAGE_MODE (defaults to local filesystem).
storage_provider = get_storage_provider(os.getenv("STORAGE_MODE", "LOCAL"))
# Initialize Database (creates tables on first run).
database.init_db()
# --- SEED DEFAULT USER ---
# Ensure an admin user exists for immediate login.
# SECURITY NOTE(review): this ships a well-known default credential
# (admin / secret) and logs it — force a password change or disable this
# seeding in production.
try:
    if not database.get_user_by_username("admin"):
        logging.info("👤 Creating default admin user...")
        # Hash "secret"
        admin_pw = bcrypt.hashpw(b"secret", bcrypt.gensalt()).decode('utf-8')
        security_ans = bcrypt.hashpw(b"admin", bcrypt.gensalt()).decode('utf-8')  # Answer: admin
        database.create_user({
            "username": "admin",
            "hashed_password": admin_pw,
            "email": "admin@elephmind.com",
            "security_question": "Who is the admin?",
            "security_answer": security_ans
        })
        logging.info("✅ Default Admin Created: admin / secret")
except Exception as e:
    # Best-effort: a seeding failure must not prevent application startup.
    logging.error(f"Failed to seed admin user: {e}")
| # ========================================================================= | |
| # AUTHENTICATION HELPERS | |
| # ========================================================================= | |
# Password hashing context: new hashes use argon2 (first scheme); legacy
# bcrypt hashes still verify and are flagged deprecated for rehash.
from passlib.context import CryptContext
pwd_context = CryptContext(schemes=["argon2", "bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash (argon2 or bcrypt, per pwd_context)."""
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
    """Hash a password with the preferred scheme (argon2, per pwd_context)."""
    return pwd_context.hash(password)
def get_user(db, username: str) -> Optional[UserInDB]:
    """Look up a user by name; returns None when absent (``db`` is unused, kept for signature compatibility)."""
    record = database.get_user_by_username(username)
    return UserInDB(**record) if record else None
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: claims to encode (callers pass {"sub": username}).
        expires_delta: optional lifetime; defaults to 15 minutes.

    Fix: uses timezone-aware UTC instead of the deprecated, naive
    ``datetime.utcnow()``, so the ``exp`` claim is unambiguous.
    """
    # Local import: the file-level datetime import does not include timezone.
    from datetime import timezone

    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB:
    """FastAPI dependency: decode the bearer JWT and load the matching user.

    Raises HTTP 401 when the token is invalid, lacks a subject claim, or the
    subject no longer exists in the database.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized
    username: str = payload.get("sub")
    if username is None:
        raise unauthorized
    token_data = TokenData(username=username)
    user = get_user(None, username=token_data.username)
    if user is None:
        raise unauthorized
    return user
async def get_current_active_user(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
    """Dependency returning the authenticated user.

    Placeholder for a future "active/disabled account" check; currently a
    pass-through of get_current_user.
    """
    # e.g. if not current_user.is_active: raise HTTPException(400, "Inactive user")
    return current_user
| # ========================================================================= | |
| # GRAD-CAM UTILITIES (Moved to explainability.py) | |
| # ========================================================================= | |
| # (Refactored to separate module for medical grade validation) | |
| # ========================================================================= | |
| # MODEL WRAPPER | |
| # ========================================================================= | |
| class MedSigClipWrapper: | |
| """Wrapper for the SigLIP model with medical domain inference.""" | |
    def __init__(self, model_path: str):
        """Record the model path; actual weights are loaded lazily via load()."""
        self.model_path = model_path
        self.processor = None   # transformers AutoProcessor once loaded
        self.model = None       # transformers AutoModel once loaded
        self.loaded = False     # True only after a successful load()
        self.load_error = None  # human-readable reason when loading failed
    def load(self):
        """Load the SigLIP processor and model from ``self.model_path``.

        Never raises: failures are recorded in ``self.load_error`` and the
        wrapper stays in the not-loaded state so predict() can report why.
        """
        logger.info(f"Initiating model load from: {self.model_path}")
        if not os.path.exists(self.model_path):
            self.load_error = f"Model directory not found: {self.model_path}"
            logger.critical(self.load_error)
            return
        try:
            # Imported lazily so the app can start even without these deps.
            from transformers import AutoProcessor, AutoModel
            import torch
            self.processor = AutoProcessor.from_pretrained(self.model_path, local_files_only=True)
            self.model = AutoModel.from_pretrained(self.model_path, local_files_only=True)
            self.model.eval()
            # Calibrate logit scale for a better probability distribution.
            if hasattr(self.model, 'logit_scale'):
                with torch.no_grad():
                    self.model.logit_scale.data.fill_(3.80666)  # ln(45)
            self.loaded = True
            logger.info("✅ MedSigClip Model Loaded Successfully (448x448 SigLIP architecture)")
        except Exception as e:
            self.load_error = f"Exception during load: {str(e)}"
            logger.error(f"Failed to load model: {str(e)}")
| def predict(self, image_bytes: bytes, username: str = None) -> Dict[str, Any]: | |
| """Run hierarchical inference using SigLIP Zero-Shot.""" | |
        # (Stale code-editing notes removed. The `pass` statement below is a
        # harmless leftover placeholder; the real implementation continues
        # immediately after it and can be cleaned up in a later pass.)
| pass # Placeholder, will use multiple chunks below | |
| if not self.loaded: | |
| msg = "MedSigClip Model is NOT loaded. Cannot perform inference." | |
| if self.load_error: | |
| msg += f" Reason: {self.load_error}" | |
| raise RuntimeError(msg) | |
| logger.info("Starting inference pipeline...") | |
| start_time = time.time() | |
| try: | |
| from PIL import Image | |
| import io | |
| import torch | |
| import pydicom | |
| # Image preprocessing functions | |
| def process_dicom(file_bytes: bytes) -> Tuple[Image.Image, Dict[str, Any]]: | |
| """Convert DICOM bytes to PIL Image with tags.""" | |
| ds = pydicom.dcmread(io.BytesIO(file_bytes)) | |
| img = ds.pixel_array.astype(np.float32) | |
| # Extract Metadata | |
| metadata = { | |
| "patient_id": str(ds.get("PatientID", "N/A")), | |
| "patient_name": str(ds.get("PatientName", "N/A")), | |
| "birth_date": str(ds.get("PatientBirthDate", "")), | |
| "study_date": str(ds.get("StudyDate", "")), | |
| "modality": str(ds.get("Modality", "UNKNOWN")) | |
| } | |
| if hasattr(ds, 'PhotometricInterpretation') and ds.PhotometricInterpretation == "MONOCHROME1": | |
| img = img.max() - img | |
| # Lung Window: WL=-600, WW=1500 | |
| wl, ww = -600, 1500 | |
| min_val, max_val = wl - ww/2, wl + ww/2 | |
| img = np.clip(img, min_val, max_val) | |
| img = (img - min_val) / (max_val - min_val) | |
| img = (img * 255).astype(np.uint8) | |
| return Image.fromarray(img).convert("RGB"), metadata | |
| def process_standard_image(image_bytes: bytes) -> Image.Image: | |
| """Process standard images (PNG/JPG) - SIMPLIFIED like Colab. | |
| Just load the image as RGB without aggressive preprocessing.""" | |
| nparr = np.frombuffer(image_bytes, np.uint8) | |
| img_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | |
| if img_cv is None: | |
| raise ValueError("Could not decode image") | |
| # Convert BGR to RGB (OpenCV uses BGR) | |
| img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB) | |
| return Image.fromarray(img_rgb) | |
| # Detect image format | |
| header = image_bytes[:32] | |
| is_png = header.startswith(b'\x89PNG\r\n\x1a\n') | |
| is_jpeg = header.startswith(b'\xff\xd8\xff') | |
| image = None | |
| dicom_metadata = None | |
| if is_png or is_jpeg: | |
| try: | |
| image = process_standard_image(image_bytes) | |
| logger.info(f"Processed as {'PNG' if is_png else 'JPEG'}") | |
| except Exception as e: | |
| raise ValueError(f"Corrupt Image File: {str(e)}") | |
| if image is None: | |
| try: | |
| image, dicom_metadata = process_dicom(image_bytes) | |
| logger.info("Processed as DICOM") | |
| except Exception: | |
| try: | |
| image = process_standard_image(image_bytes) | |
| except Exception as e: | |
| raise ValueError(f"Unknown image format: {str(e)}") | |
| # ========================================================= | |
| # ADAPTIVE PREPROCESSING - DISABLED to match Colab behavior | |
| # The model was trained on raw images, not preprocessed ones | |
| # ========================================================= | |
| preprocessing_log = {"message": "Preprocessing disabled for accuracy", "transformation_count": 0} | |
| # NOTE: Uncomment below to re-enable if needed | |
| # try: | |
| # import io as io_module | |
| # buffer = io_module.BytesIO() | |
| # image.save(buffer, format='PNG') | |
| # image_bytes_for_preprocessing = buffer.getvalue() | |
| # image, preprocessing_log = adaptive_preprocessing(image_bytes_for_preprocessing) | |
| # logger.info(f"🔧 Adaptive preprocessing applied: {preprocessing_log.get('transformation_count', 0)} transformations") | |
| # except Exception as e_preproc: | |
| # logger.warning(f"Adaptive preprocessing skipped: {e_preproc}") | |
| # STEP 1: DOMAIN IDENTIFICATION | |
| domain_keys = list(MEDICAL_DOMAINS.keys()) | |
| domain_prompts = [d['domain_prompt'] for d in MEDICAL_DOMAINS.values()] | |
| inputs_domain = self.processor( | |
| text=domain_prompts, | |
| images=image, | |
| padding="max_length", | |
| return_tensors="pt" | |
| ) | |
| with torch.no_grad(): | |
| outputs_domain = self.model(**inputs_domain) | |
| probs_domain = torch.softmax(outputs_domain.logits_per_image, dim=1)[0] | |
| best_domain_idx = torch.argmax(probs_domain).item() | |
| best_domain_key = domain_keys[best_domain_idx] | |
| best_domain_prob = float(probs_domain[best_domain_idx] * 100) | |
| logger.info(f"Identified Domain: {best_domain_key} ({best_domain_prob:.2f}%)") | |
| # STEP 2: SPECIFIC ANALYSIS | |
| domain_config = MEDICAL_DOMAINS[best_domain_key] | |
| specific_results = [] | |
| if 'stage_1_triage' in domain_config: | |
| # Hierarchical Logic (e.g., Orthopedics) | |
| logger.info(f"Engaging Level 2 Hierarchical Logic for: {best_domain_key}") | |
| triage_labels = domain_config['stage_1_triage']['labels'] | |
| inputs_triage = self.processor(text=triage_labels, images=image, padding="max_length", return_tensors="pt") | |
| with torch.no_grad(): | |
| out_triage = self.model(**inputs_triage) | |
| probs_triage = torch.softmax(out_triage.logits_per_image, dim=1)[0] | |
| prob_abnormal = float(probs_triage[-1]) | |
| prob_normal = 1.0 - prob_abnormal | |
| logger.info(f"Triage: Normal={prob_normal*100:.2f}%, Abnormal={prob_abnormal*100:.2f}%") | |
| if prob_abnormal > prob_normal: | |
| logger.info("Running Stage 2 Diagnosis...") | |
| diag_labels = domain_config['stage_2_diagnosis']['labels'] | |
| inputs_diag = self.processor(text=diag_labels, images=image, padding="max_length", return_tensors="pt") | |
| with torch.no_grad(): | |
| out_diag = self.model(**inputs_diag) | |
| probs_diag = torch.softmax(out_diag.logits_per_image, dim=1)[0] | |
| for i, label in enumerate(diag_labels): | |
| specific_results.append({ | |
| "label": label, | |
| "probability": round(float(probs_diag[i] * 100), 2) | |
| }) | |
| else: | |
| logger.info("Triage indicates Normal/Healthy. Skipping Stage 2.") | |
| else: | |
| # Flat Mode (Thoracic, Dermato, etc.) | |
| specific_labels_raw = domain_config['specific_labels'] | |
| inputs_specific = self.processor( | |
| text=specific_labels_raw, | |
| images=image, | |
| padding="max_length", | |
| return_tensors="pt" | |
| ) | |
| with torch.no_grad(): | |
| outputs_specific = self.model(**inputs_specific) | |
| probs_specific = torch.softmax(outputs_specific.logits_per_image, dim=1)[0] | |
| for i, label in enumerate(specific_labels_raw): | |
| specific_results.append({ | |
| "label": label, | |
| "probability": round(float(probs_specific[i] * 100), 2) | |
| }) | |
| specific_results.sort(key=lambda x: x['probability'], reverse=True) | |
| # STEP 3: HEATMAP GENERATION (Grad-CAM++ x MedSegCLIP) | |
| heatmap_base64 = None | |
| original_base64 = None | |
| try: | |
| if specific_results: | |
| top_label_text = specific_results[0]['label'] | |
| logger.info(f"Generating Medical Explanation for: {top_label_text}") | |
| # FIX: Initialize container for enhanced metadata | |
| enhanced_result = {} | |
| # --- QUALITY CONTROL GATE (Gate 1 & 2) --- | |
| # Added per user request: Verify quality before deep explanation/analysis | |
| # Ideally this should be even earlier, but performing it here ensures we have the image object ready. | |
| from quality_control import QualityControlEngine | |
| qc_engine = QualityControlEngine() | |
| qc_result = qc_engine.run_quality_check(image) | |
| enhanced_result['image_quality'] = { | |
| "quality_score": qc_result['quality_score'], | |
| "passed": qc_result['passed'], | |
| "reasons": qc_result['reasons'], | |
| "metrics": qc_result['metrics'] | |
| } | |
| if not qc_result['passed']: | |
| logger.warning(f"⛔ Quality Control Failed: {qc_result['reasons']}") | |
| # STRICT REJECTION: Override diagnosis and clear predictions | |
| enhanced_result['diagnosis'] = "Analyse Refusée (Qualité Insuffisante)" | |
| enhanced_result['confidence'] = 0.0 | |
| enhanced_result['specific'] = [] # Clear predictions | |
| enhanced_result['quality_failure_reasons'] = qc_result['reasons'] | |
| enhanced_result['image_quality'] = { | |
| "quality_score": qc_result['quality_score'], | |
| "passed": False, | |
| "reasons": qc_result['reasons'], | |
| "metrics": qc_result['metrics'] | |
| } | |
| # FIX: Construct localized_result explicitly as it is not defined yet | |
| rejection_result = { | |
| "domain": { | |
| "label": best_domain_key, | |
| "description": MEDICAL_DOMAINS[best_domain_key]['domain_prompt'], | |
| "probability": round(best_domain_prob, 2) | |
| }, | |
| "specific": [], | |
| "heatmap": None, | |
| "original_image": None, # Will be handled by frontend fallback or we can encode it here if mostly wanted | |
| "preprocessing": preprocessing_log, | |
| "explainability": { | |
| "method": "QC Rejection", | |
| "reliability": 0.0 | |
| } | |
| } | |
| # We merge enhanced info | |
| rejection_result.update(enhanced_result) | |
| # Encode Original Image even on Rejection for Context | |
| try: | |
| img_tensor = np.array(image).astype(np.float32) / 255.0 | |
| original_uint8 = (img_tensor * 255).astype(np.uint8) | |
| _, buffer_orig = cv2.imencode('.png', cv2.cvtColor(original_uint8, cv2.COLOR_RGB2BGR)) | |
| rejection_result["original_image"] = base64.b64encode(buffer_orig).decode('utf-8') | |
| except: | |
| pass | |
| return rejection_result | |
| # If QC Passed, Proceed to Explanation | |
| import explainability | |
| engine = explainability.ExplainabilityEngine(self) | |
| # Define Anatomical Context based on Domain | |
| anatomical_context = "body part" # Default | |
| if best_domain_key == 'Thoracic': anatomical_context = "lung parenchyma" | |
| elif best_domain_key == 'Orthopedics': anatomical_context = "bone structure" | |
| elif best_domain_key == 'Dermatology': anatomical_context = "skin lesion" | |
| elif best_domain_key == 'Ophthalmology': anatomical_context = "retina" | |
| explanation = engine.explain( | |
| image=image, | |
| target_text=top_label_text, | |
| anatomical_context=anatomical_context | |
| ) | |
| if explanation['heatmap_array'] is not None: | |
| # Encode Visualization | |
| vis_img = explanation['heatmap_array'] | |
| _, buffer = cv2.imencode('.png', cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)) | |
| heatmap_base64 = base64.b64encode(buffer).decode('utf-8') | |
| # Original Image (Normalized for consistency) | |
| img_tensor = np.array(image).astype(np.float32) / 255.0 | |
| original_uint8 = (img_tensor * 255).astype(np.uint8) | |
| _, buffer_orig = cv2.imencode('.png', cv2.cvtColor(original_uint8, cv2.COLOR_RGB2BGR)) | |
| original_base64 = base64.b64encode(buffer_orig).decode('utf-8') | |
| reliability = explanation.get("reliability_score", 0) | |
| logger.info(f"✅ Explanation Generated. Reliability: {reliability} ({explanation.get('confidence_label')})") | |
| else: | |
| logger.warning("Could not generate explainability map.") | |
| except Exception as e_cam: | |
| import traceback | |
| logger.error(f"Explainability Pipeline Failed: {traceback.format_exc()}") | |
| # FINAL RESULT (Base) | |
| result_json = { | |
| "domain": { | |
| "label": best_domain_key, | |
| "description": MEDICAL_DOMAINS[best_domain_key]['domain_prompt'], | |
| "probability": round(best_domain_prob, 2) | |
| }, | |
| "specific": specific_results, | |
| "heatmap": heatmap_base64, | |
| "original_image": original_base64, | |
| "preprocessing": preprocessing_log, | |
| "explainability": { # NEW METADATA | |
| "method": "Grad-CAM++ x MedSegCLIP (Proxy)", | |
| "anatomical_context": anatomical_context if 'anatomical_context' in locals() else "Unknown", | |
| "reliability": explanation.get("reliability_score") if 'explanation' in locals() else 0 | |
| } | |
| } | |
| # ========================================================= | |
| # APPLY 7 INTELLIGENCE ALGORITHMS | |
| # ========================================================= | |
| logger.info("🧠 Applying Intelligence Algorithms...") | |
| # Convert PIL image to numpy for quality assessment | |
| image_array = np.array(image) | |
| # Get image embedding for similar case detection | |
| # OPT PERFORMANCE: Disabled to reduce inference time (<60s) | |
| image_embedding = None | |
| # try: | |
| # with torch.no_grad(): | |
| # img_inputs = self.processor(images=image, return_tensors="pt") | |
| # image_embedding = self.model.get_image_features(**img_inputs) | |
| # image_embedding = image_embedding.cpu().numpy().flatten() | |
| # except Exception as e_emb: | |
| # logger.warning(f"Could not extract embedding: {e_emb}") | |
| # image_embedding = None | |
| # Enhance result with all algorithms | |
| enhanced_result = enhance_analysis_result( | |
| base_result=result_json, | |
| image_array=image_array, | |
| embedding=image_embedding, | |
| case_id=str(uuid.uuid4()), | |
| patient_info=None, | |
| username=username | |
| ) | |
| # --- LOCALIZATION (Translate to French) --- | |
| localized_result = localize_result(enhanced_result) | |
| # --- GENERATE REPORT (After Localization) --- | |
| # Now the labels in localized_result['specific'] are in French | |
| localized_result["report"] = generate_clinical_report( | |
| localized_result, | |
| patient_info=None | |
| ) | |
| # --- MAP TO FRONTEND EXPECTATIONS --- | |
| # frontend expects: diagnosis, confidence, productions, quality_metrics, etc. | |
| # 1. Diagnosis | |
| top_finding = enhanced_result['specific'][0] if enhanced_result['specific'] else {"label": "Inconnu", "probability": 0} | |
| enhanced_result['diagnosis'] = top_finding['label'] | |
| # 2. Confidence & Calibrated | |
| enhanced_result['calibrated_confidence'] = enhanced_result.get('confidence', top_finding['probability']) | |
| enhanced_result['confidence'] = top_finding['probability'] | |
| # 3. Processing Time (Real Measurement) | |
| enhanced_result['processing_time'] = round(time.time() - start_time, 3) | |
| # 4. Predictions (Alias for specific) | |
| enhanced_result['predictions'] = [ | |
| {"name": item['label'], "probability": item['probability']} | |
| for item in enhanced_result['specific'] | |
| ] | |
| # 5. Quality Metrics (Flatten structure) | |
| if 'image_quality' in enhanced_result: | |
| enhanced_result['quality_score'] = enhanced_result['image_quality']['quality_score'] | |
| enhanced_result['quality_metrics'] = enhanced_result['image_quality']['metrics'] | |
| # 6. Priority | |
| # If priority is a dict (from new algo), extract just the level/score for simple display, or keep object | |
| # Frontend expects string 'priority' sometimes, or maybe object. Let's provide string for badge. | |
| if isinstance(enhanced_result.get('priority'), str): | |
| pass | |
| elif isinstance(enhanced_result.get('priority'), dict): | |
| # Flatten for frontend simple badge | |
| enhanced_result['priority'] = enhanced_result['priority'].get('level', 'Normale') | |
| # 7. DICOM Metadata (if available) | |
| if dicom_metadata: | |
| enhanced_result['patient_metadata'] = dicom_metadata | |
| logger.info("✅ Intelligence Algorithms applied successfully") | |
| logger.info("✅ Intelligence Algorithms applied successfully") | |
| return localized_result | |
| except Exception as e: | |
| logger.error(f"Inference Error: {str(e)}") | |
| raise e | |
| # ========================================================================= | |
| # GLOBAL MODEL INSTANCE | |
| # ========================================================================= | |
| model_wrapper: Optional[MedSigClipWrapper] = None | |
| # ========================================================================= | |
| # FASTAPI LIFECYCLE | |
| # ========================================================================= | |
| async def lifespan(app: FastAPI): | |
| global model_wrapper, MODEL_DIR # CRITICAL: Use global variables | |
| database.init_db() | |
| database.init_analysis_registry() | |
| # Get model path (downloads from HuggingFace Hub if needed) | |
| MODEL_DIR = get_model_path() | |
| model_wrapper = MedSigClipWrapper(MODEL_DIR) | |
| model_wrapper.load() | |
| logger.info("ElephMind Backend Started") | |
| yield | |
| logger.info("ElephMind Backend Shutting Down") | |
app = FastAPI(
    lifespan=lifespan,
    title="ElephMind Medical AI API",
    version="2.0.0",
    description="Medical image analysis powered by SigLIP"
)
# CORS Middleware with configurable origins
# NOTE(review): wildcard `allow_origins=["*"]` combined with
# `allow_credentials=True` is disallowed by the CORS spec (browsers refuse
# to send credentials to a wildcard origin) and is insecure for a medical
# app — replace with an explicit, configurable origin allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins to fix "Failed to fetch" for user
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.responses import JSONResponse | |
| from fastapi.encoders import jsonable_encoder | |
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """
    Return a sanitized 422 response for request-validation failures.

    Raw binary inputs (e.g. multipart/form-data posted to a JSON endpoint)
    crash jsonable_encoder with UnicodeDecodeError, so they are replaced by
    a placeholder; overly long string inputs are truncated before logging.
    """
    sanitized = []
    for err in exc.errors():
        entry = err.copy()
        # Strip raw binary input which causes jsonable_encoder to crash.
        if 'input' in entry:
            raw = entry['input']
            if isinstance(raw, (bytes, bytearray)):
                entry['input'] = "<binary_data_stripped>"
            elif isinstance(raw, str) and len(raw) > 200:
                entry['input'] = raw[:200] + "..."  # Truncate long inputs
        sanitized.append(entry)
    logger.warning(f"Validation Error on {request.url.path}: {sanitized}")
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={"detail": jsonable_encoder(sanitized)},
    )
async def limit_concurrency(request: Request, call_next):
    """Throttle requests through a global semaphore (MAX_CONCURRENT_USERS).

    Health probes and CORS preflight (OPTIONS) bypass the limiter so they
    never queue behind long-running analysis traffic.
    """
    exempt = request.url.path == "/health" or request.method == "OPTIONS"
    if exempt:
        return await call_next(request)
    if concurrency_semaphore.locked():
        logger.warning(f"Concurrency limit ({MAX_CONCURRENT_USERS}) reached. Request queued.")
    async with concurrency_semaphore:
        return await call_next(request)
| # ========================================================================= | |
| # BACKGROUND WORKER | |
| # ========================================================================= | |
| # ========================================================================= | |
| # BACKGROUND WORKER (Decoupled) | |
| # ========================================================================= | |
async def process_analysis_job(job_id: str, image_id: str, username: str):
    """
    Background worker: load the uploaded image from disk by ID, run model
    inference in a thread executor, then persist the result and log it to
    the analysis registry. Shares no in-memory state with the API handlers.

    Args:
        job_id: Primary key of the job row created by the analyze endpoint.
        image_id: Storage identifier returned by the upload endpoint.
        username: Owner of the image/job (used for per-user storage isolation).
    """
    import functools
    # RESILIENCE: re-read the job from DB so a restarted worker can resume.
    job = database.get_job(job_id)
    if not job:
        logger.error(f"❌ Job {job_id} not found DB")
        return
    logger.info(f"Worker processing Job {job_id} (Image: {image_id})")
    database.update_job_status(job_id, JobStatus.PROCESSING.value)
    start_time = time.time()
    try:
        if not model_wrapper:
            raise RuntimeError("Model wrapper not initialized.")
        # LOAD IMAGE FROM DISK (physical read); the returned path is unused here.
        image_bytes, _file_path = storage_manager.load_image(username, image_id)
        # FIX: get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated in that context since Python 3.10.
        loop = asyncio.get_running_loop()
        # Run blocking inference off the event loop; pass username for isolation.
        result = await loop.run_in_executor(
            None, functools.partial(model_wrapper.predict, image_bytes, username=username)
        )
        computation_time_ms = int((time.time() - start_time) * 1000)
        database.update_job_status(job_id, JobStatus.COMPLETED.value, result=result)
        # Log to registry (REAL DATA) — extract the top finding exactly once.
        if username and result:
            findings = result.get('specific') or []
            top = findings[0] if findings else {}
            database.log_analysis(
                username=username,
                domain=result.get('domain', {}).get('label', 'Unknown'),
                top_diagnosis=top.get('label', 'Unknown'),
                confidence=top.get('probability', 0),
                priority=result.get('priority', 'Normale'),
                computation_time_ms=computation_time_ms,
                file_type='SavedImage'
            )
            logger.info(f"✅ Job {job_id} logged to registry")
        logger.info(f"✅ Job {job_id} completed in {computation_time_ms}ms")
    except Exception as e:
        logger.error(f"❌ Job {job_id} failed: {str(e)}")
        database.update_job_status(job_id, JobStatus.FAILED.value, error=str(e))
| # ========================================================================= | |
| # API ENDPOINTS | |
| # ========================================================================= | |
| # --- Authentication --- | |
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Validate credentials and issue a short-lived JWT bearer token."""
    user = database.get_user_by_username(form_data.username)
    credentials_ok = bool(user) and verify_password(form_data.password, user['hashed_password'])
    if not credentials_ok:
        # Deliberately vague message: do not reveal which part was wrong.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    token = create_access_token(
        data={"sub": user['username']},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    )
    return {"access_token": token, "token_type": "bearer"}
class AnalysisRequest(BaseModel):
    # Request body for the analysis step: references an image previously
    # uploaded via the upload endpoint (decoupled upload/analyze flow).
    image_id: str               # storage ID returned by upload_image
    domain: str = "Triage"      # requested analysis domain
    priority: str = "Normale"   # requested priority label (French UI)
async def register_user(user: UserRegister):
    """Create a new account.

    The security answer is normalized (trimmed, lower-cased) and hashed
    before storage so recovery is case/whitespace insensitive and the
    plaintext answer never hits the database.
    """
    record = {
        "username": user.username,
        "hashed_password": get_password_hash(user.password),
        "email": user.email,
        "security_question": user.security_question,
        "security_answer": get_password_hash(user.security_answer.strip().lower()),
    }
    if not database.create_user(record):
        raise HTTPException(status_code=400, detail="Username already exists")
    return {"message": "User created successfully"}
async def get_security_question(username: str):
    """Return the stored security question used for password recovery."""
    record = database.get_user_by_username(username)
    if not record:
        raise HTTPException(status_code=404, detail="User not found")
    return {"question": record['security_question']}
async def reset_password(data: UserResetPassword):
    """Reset a password after verifying the (hashed) security answer."""
    user = database.get_user_by_username(data.username)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    # Same normalization as at registration, compared against the stored hash.
    answer_ok = verify_password(data.security_answer.strip().lower(), user['security_answer'])
    if not answer_ok:
        raise HTTPException(status_code=400, detail="Incorrect security answer")
    database.update_password(data.username, get_password_hash(data.new_password))
    return {"message": "Password reset successfully"}
# --- Dashboard Analytics (REAL DATA ONLY) ---
async def get_dashboard_statistics(current_user: User = Depends(get_current_user)):
    """
    Real dashboard statistics for the authenticated user.
    Zeros when no analyses have been performed — nothing is fabricated.
    """
    payload = dict(database.get_dashboard_stats(current_user.username))
    payload["recent_analyses"] = database.get_recent_analyses(current_user.username, limit=10)
    return payload
async def submit_feedback(feedback: FeedbackModel):
    """Persist a user feedback entry (rating plus free-text comment)."""
    database.add_feedback(feedback.username, feedback.rating, feedback.comment)
    return {"message": "Feedback received"}
| # --- Medical Analysis --- | |
| # --- Analysis Flow (Async Job Architecture) --- | |
| # Local modules | |
| import database | |
| import storage_manager | |
| import dicom_processor # NEW: Medical Validation | |
| from database import JobStatus | |
| from storage import get_storage_provider | |
| # ... | |
async def upload_image(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_active_user)
):
    """
    Step 1: Upload an image to physical storage.

    - Validates DICOM compliance when the payload carries the DICM magic.
    - Anonymizes patient data (PHI) before anything is written to disk.
    - Returns an ``image_id`` to be passed to the analysis endpoint.

    Raises:
        HTTPException 400: DICOM failed compliance validation.
        HTTPException 500: unexpected storage/processing failure.
    """
    try:
        content = await file.read()
        # DICOM files carry the ASCII marker 'DICM' at byte offset 128.
        # FIX: '>= 132' — a payload of exactly 132 bytes still contains the
        # full magic at content[128:132]; the previous '>' missed that case.
        is_dicom = len(content) >= 132 and content[128:132] == b'DICM'
        if is_dicom:
            logger.info(f"DICOM File detected for user {current_user.username}. Validating...")
            try:
                # Validate & anonymize; only sanitized bytes reach storage.
                safe_content, metadata = dicom_processor.process_dicom_upload(content, current_user.username)
                content = safe_content
                logger.info("✅ DICOM Validated and Anonymized.")
            except ValueError as ve:
                logger.error(f"❌ DICOM Rejected: {ve}")
                raise HTTPException(status_code=400, detail=f"Conformité DICOM refusée: {str(ve)}")
        # Save to per-user disk storage.
        image_id = storage_manager.save_image(
            username=current_user.username,
            file_bytes=content,
            filename_hint=file.filename if not is_dicom else "anon.dcm"
        )
        return {
            "image_id": image_id,
            "status": "UPLOADED",
            "message": "Image secured & sanitized. Ready for analysis."
        }
    except HTTPException:
        # FIX: bare re-raise preserves the original traceback and context.
        raise
    except Exception as e:
        logger.error(f"Upload failed: {e}")
        raise HTTPException(status_code=500, detail=f"Upload Error: {str(e)}")
async def analyze_image(
    request: AnalysisRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_active_user)
):
    """
    Step 2: Create an analysis job for a previously uploaded image_id.
    Decoupled from upload; heavy inference runs in a background worker.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 404: image_id has no file on disk for this user.
    """
    if not model_wrapper or not model_wrapper.loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    # Verify the image physically exists before queueing any work.
    # FIX: the original bound the path to `_` and then read it back, which
    # defeats the throwaway-name convention; use a real name.
    try:
        image_path = storage_manager.get_image_absolute_path(current_user.username, request.image_id)
        if not image_path:
            raise FileNotFoundError(request.image_id)
    except Exception:
        raise HTTPException(status_code=404, detail="Image ID not found. Upload first.")
    task_id = str(uuid.uuid4())
    # Persist the PENDING job so a crash/restart cannot lose it.
    job_data = {
        'id': task_id,
        'status': JobStatus.PENDING.value,
        'created_at': time.time(),
        'result': None,
        'error': None,
        'storage_path': request.image_id,  # Link to storage
        'username': current_user.username,
        'file_type': 'Unknown'
    }
    database.create_job(job_data)
    # Enqueue worker with IDs only — no image bytes cross the boundary.
    background_tasks.add_task(process_analysis_job, task_id, request.image_id, current_user.username)
    return {
        "task_id": task_id,
        "status": "queued",
        "image_id": request.image_id
    }
async def get_current_job(current_user: User = Depends(get_current_active_user)):
    """
    Latest job for the user, used to restore the UI after a page refresh.

    Jobs older than 24 hours are reported as expired (404) so a fresh
    login does not resurrect stale results.
    """
    job = database.get_latest_job(current_user.username)
    if not job:
        raise HTTPException(status_code=404, detail="No active job")
    created_at = job.get('created_at', 0)
    age_seconds = time.time() - created_at
    if age_seconds > 86400:  # 24 hours
        raise HTTPException(status_code=404, detail="Job expired")
    return {
        "task_id": job['id'],
        "status": job['status'],
        "result": job['result'],
        "error": job.get('error'),
        "created_at": created_at,
        "image_id": job.get('storage_path'),
    }
async def get_result(task_id: str, current_user: User = Depends(get_current_user)):
    """
    Fetch an analysis job by task ID (authenticated).

    Ownership is enforced inside the SQL query: a job belonging to another
    user comes back as None and is surfaced as 404, so the endpoint leaks
    nothing about other users' jobs.
    """
    job = database.get_job(task_id, username=current_user.username)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found or access denied")
    logger.info(f"Polling Job {task_id}: Status={job.get('status')}")
    return job
def health_check():
    """Liveness probe: reports service status and model readiness."""
    is_loaded = False
    if model_wrapper:
        is_loaded = model_wrapper.loaded
    return {
        "status": "running",
        "model_loaded": is_loaded,
        "version": "2.0.0",
    }
async def root():
    """Send visitors hitting the API root to the interactive docs page."""
    from fastapi.responses import RedirectResponse
    return RedirectResponse(url="/docs")
# --- DASHBOARD ENDPOINTS ---
async def get_dashboard_stats_endpoint(current_user: User = Depends(get_current_user)):
    """Dashboard statistics plus the five most recent analyses.

    NOTE(review): a second function with this exact name is defined later
    in this file and rebinds the name at import time — deduplicate.
    """
    try:
        payload = dict(database.get_dashboard_stats(current_user.username))
        payload["recent_analyses"] = database.get_recent_analyses(current_user.username, limit=5)
        return payload
    except Exception as e:
        logger.error(f"Error fetching dashboard stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
| # ========================================================================= | |
| # PATIENT API (New for Migration) | |
| # ========================================================================= | |
class PatientCreate(BaseModel):
    # Payload for creating a patient record owned by the current user.
    patient_id: str                  # caller-supplied external identifier
    first_name: str = ""
    last_name: str = ""
    birth_date: str = ""             # free-form string; format not enforced here
    photo: Optional[str] = None      # Base64 or URL
class PatientUpdate(BaseModel):
    # Partial-update payload: only fields the client explicitly sets are
    # applied (the endpoint uses .dict(exclude_unset=True)).
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    birth_date: Optional[str] = None
    photo: Optional[str] = None
async def create_patient_endpoint(patient: PatientCreate, current_user: User = Depends(get_current_user)):
    """Create a patient record owned by the authenticated user."""
    new_id = database.create_patient(
        owner_username=current_user.username,
        patient_id=patient.patient_id,
        first_name=patient.first_name,
        last_name=patient.last_name,
        birth_date=patient.birth_date,
        photo=patient.photo,
    )
    if not new_id:
        raise HTTPException(status_code=400, detail="Could not create patient (ID might exist)")
    return {"id": new_id, "message": "Patient created"}
async def get_patients_endpoint(current_user: User = Depends(get_current_user)):
    """List every patient belonging to the authenticated user."""
    patients = database.get_patients_by_user(current_user.username)
    return patients
async def update_patient_endpoint(patient_id: int, updates: PatientUpdate, current_user: User = Depends(get_current_user)):
    """Apply a partial update to one of the user's patients."""
    # Only fields the client actually sent are forwarded to the DB layer.
    changes = updates.dict(exclude_unset=True)
    if not changes:
        raise HTTPException(status_code=400, detail="No data to update")
    if not database.update_patient(current_user.username, patient_id, changes):
        raise HTTPException(status_code=404, detail="Patient not found")
    return {"message": "Patient updated"}
async def delete_patient_endpoint(patient_id: int, current_user: User = Depends(get_current_user)):
    """Delete one of the user's patients (ownership enforced in DB layer)."""
    deleted = database.delete_patient(current_user.username, patient_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Patient not found")
    return {"message": "Patient deleted"}
async def get_dashboard_stats_endpoint(current_user: User = Depends(get_current_user)):
    """Dashboard statistics plus the ten most recent analyses.

    NOTE(review): duplicates the name of an earlier endpoint function in
    this file; this later definition wins at import time — deduplicate.
    """
    stats = database.get_dashboard_stats(current_user.username)
    stats["recent_analyses"] = database.get_recent_analyses(current_user.username, limit=10)
    return stats
| # ========================================================================= | |
| # MAIN ENTRY POINT | |
| # ========================================================================= | |
| if __name__ == "__main__": | |
| # Initialize DB tables including registry | |
| database.init_db() | |
| database.init_analysis_registry() | |
| host = os.getenv("SERVER_HOST", "0.0.0.0") | |
| port = int(os.getenv("SERVER_PORT", "8022")) | |
| uvicorn.run(app, host=host, port=port) | |