elephmind-api / main.py
zousko-stark's picture
Upload folder using huggingface_hub
6a8cc1a verified
raw
history blame
70.2 kB
"""
ElephMind Medical AI Backend
============================
Production-ready FastAPI backend for medical image analysis using SigLIP.
Author: ElephMind Team
Version: 2.0.0 (Cleaned & Secured)
"""
import sys
import os
import uuid
import asyncio
import time
import logging
# --- DOTENV SUPPORT (MUST BE FIRST) ---
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
from enum import Enum
from typing import Dict, List, Optional, Any, Tuple
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from datetime import datetime
from contextlib import asynccontextmanager
import uvicorn
import base64
import cv2
import numpy as np
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from localization import localize_result
import torch
import torch.nn as nn
# Local modules
import database
import storage_manager # NEW: Physical storage layout
from database import JobStatus
from storage import get_storage_provider
import encryption
import database
# NOTE: the intelligence algorithms (formerly algorithms.py) are defined inline further below, not imported
import math
from collections import deque
from dataclasses import dataclass, field
from PIL import Image
import io
# --- GRADCAM UTILS FOR SIGLIP/ViT ---
# Class moved to Line 781 (Deduplication)
# Function moved to Line 798 (Deduplication)
# --- AUTH IMPORTS ---
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi import Depends, status, Request
from datetime import datetime, timedelta
from jose import JWTError, jwt
import bcrypt
# --- DOTENV (Moved to top) ---
# =========================================================================
# LOGGING CONFIGURATION
# =========================================================================
# Configure the root logger: INFO level, timestamped format, and a single
# handler that writes to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)]
)
# Named logger used throughout this module.
logger = logging.getLogger("ElephMind-Backend")
# =========================================================================
# 7 INTELLIGENCE ALGORITHMS (Merged from algorithms.py)
# =========================================================================
# 1. IMAGE QUALITY ASSESSMENT
def detect_blur(image: np.ndarray) -> float:
    """
    Score image sharpness from the variance of its Laplacian.

    Inputs with three dimensions are converted from BGR to grayscale first;
    anything else is used as-is. The variance is divided by 500 and capped
    at 1.0, so the score runs from 0.0 (very blurry) to 1.0 (very sharp).
    """
    has_channels = len(image.shape) == 3
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if has_channels else image
    variance = cv2.Laplacian(gray, cv2.CV_64F).var()
    # Normalize to 0-1 (empirical thresholds for medical images)
    return min(1.0, variance / 500.0)
def assess_image_quality(image: np.ndarray) -> Dict[str, Any]:
    """
    Compute a 0-100 quality score plus per-metric values for an image.

    Points are awarded for sharpness (40 above 0.6, 20 above 0.3),
    contrast (30 when the grayscale standard deviation exceeds 40) and
    resolution (30 when the image has more than 512*512 pixels).

    Returns a dict with "quality_score" and a "metrics" list of
    {"metric", "value"} entries, each value on a 0-100 scale.
    """
    # Sharpness
    sharpness = detect_blur(image)
    if sharpness > 0.6:
        sharpness_points = 40
    elif sharpness > 0.3:
        sharpness_points = 20
    else:
        sharpness_points = 0

    # Contrast, measured on a grayscale version of the image
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
    contrast = float(gray.std())
    contrast_points = 30 if contrast > 40 else 0

    # Resolution
    height, width = image.shape[:2]
    pixel_count = height * width
    resolution_points = 30 if pixel_count > 512 * 512 else 0

    metrics = [
        {"metric": "Netteté", "value": int(sharpness * 100)},
        {"metric": "Contraste", "value": int(min(100, contrast * 2))},
        {"metric": "Résolution", "value": int(min(100, pixel_count / (1024 * 1024) * 100))},
    ]
    total = sharpness_points + contrast_points + resolution_points
    return {
        "quality_score": min(100, total),
        "metrics": metrics
    }
# 2. CONFIDENCE CALIBRATION
def calibrate_confidence(raw_stats: List[float], labels: List[str]) -> float:
    """
    Turn raw scores into a single confidence percentage.

    Returns the largest value of ``raw_stats`` scaled to a percentage and
    rounded to two decimals, or 0.0 when ``raw_stats`` is empty.
    ``labels`` is accepted but not used by the current implementation.
    """
    if raw_stats:
        return float(round(max(raw_stats) * 100, 2))
    return 0.0
# 3. CLINICAL PRIORITY SCORING
def calculate_priority_score(predictions: List[Dict], domain: str) -> str:
    """
    Derive a triage priority from the first prediction in the list.

    Only ``predictions[0]`` is inspected: its lower-cased label is searched
    for keywords and its probability compared to a per-tier threshold.
    Returns "Élevée", "Moyenne" or "Normale". ``domain`` is accepted but
    not used by the current implementation.
    """
    if not predictions:
        return "Normale"

    first = predictions[0]
    label_text = first["label"].lower()
    probability = first["probability"]

    # (keywords, probability threshold, resulting priority), highest tier first
    tiers = (
        (
            ["malignant", "cancer", "carcinoma", "pneumonia", "pneumothorax", "fracture", "grade 4"],
            50,
            "Élevée",
        ),
        (
            ["grade 2", "grade 3", "effusion", "edema", "abnormal"],
            40,
            "Moyenne",
        ),
    )
    for keywords, threshold, priority in tiers:
        if any(keyword in label_text for keyword in keywords) and probability > threshold:
            return priority
    return "Normale"
# 4. AUTOMATIC REPORT GENERATION
def generate_clinical_report(analysis_result: Dict[str, Any], patient_info: Optional[Dict] = None) -> str:
    """
    Render a plain-text report from an analysis result using fixed templates.

    The first entry of ``analysis_result["specific"]`` is reported as the main
    observation and up to three following entries as technical details.
    Returns "Analyse non concluante." when there are no specific findings.
    A "Patient ID" line is included only when ``patient_info`` is truthy.
    """
    domain = analysis_result.get("domain", {}).get("label", "Unknown")
    findings = analysis_result.get("specific", [])
    if not findings:
        return "Analyse non concluante."

    main_finding = findings[0]
    lines = [
        f"RAPPORT D'ANALYSE AUTOMATISÉE - {domain.upper()}",
        f"Date: {datetime.now().strftime('%d/%m/%Y %H:%M')}",
    ]
    if patient_info:
        lines.append(f"Patient ID: {patient_info.get('id', 'N/A')}")
    lines.append("-" * 40)
    lines.append(f"Observation Principale: {main_finding['label']}")
    lines.append(f"Confiance IA: {main_finding['probability']}%")
    priority = analysis_result.get("priority", "Normale")
    lines.append(f"Priorité de Triage: {priority.upper()}")
    # Empty entry yields the blank line between the summary and the details
    lines.append("")
    lines.append("Détails Techniques:")
    for detail in findings[1:4]:
        lines.append(f"- {detail['label']}: {detail['probability']}%")
    # Every line, including the last, is newline-terminated
    return "\n".join(lines) + "\n"
# 5. SIMILAR CASE DETECTION (Vector DB Mockup)
@dataclass
class CaseRecord:
    """One stored analysis case, as kept by SimilarCaseDatabase."""
    # NOTE: field order matters; SimilarCaseDatabase.add_case builds records positionally.
    id: str  # Case identifier, returned as "case_id" in search results
    embedding: np.ndarray  # Vector compared by cosine similarity
    diagnosis: str  # Diagnosis label returned in search results
    domain: str  # Medical domain label, used to filter searches
    probability: float  # Probability of the stored diagnosis
    username: str # Added for isolation
class SimilarCaseDatabase:
    """
    In-memory store of analysed cases, searchable by cosine similarity.

    Cases are kept in insertion order in ``self.cases``. Once more than
    ``MAX_CASES`` are stored, the oldest one is dropped.
    """

    # Maximum number of cases retained before the oldest is evicted.
    MAX_CASES = 1000

    def __init__(self):
        self.cases: List[CaseRecord] = []

    def add_case(self, case_id: str, embedding: np.ndarray, diagnosis: str, domain: str, probability: float, username: str):
        """Append a case and evict the oldest one if the store is over capacity."""
        self.cases.append(CaseRecord(case_id, embedding, diagnosis, domain, probability, username))
        # Keep manageable size
        if len(self.cases) > self.MAX_CASES:
            self.cases.pop(0)

    def find_similar(self, query_embedding: np.ndarray, username: str, top_k: int = 3, same_domain_only: bool = True, query_domain: Optional[str] = None) -> List[Dict]:
        """
        Return the ``top_k`` stored cases most similar to ``query_embedding``.

        Only cases belonging to ``username`` are considered. When
        ``same_domain_only`` is true and ``query_domain`` is given, cases from
        other domains are skipped as well. Similarity is the cosine similarity
        expressed as a percentage rounded to one decimal; it is 0 when either
        vector has zero norm.

        Returns a list of {"case_id", "diagnosis", "similarity"} dicts sorted
        by decreasing similarity.
        """
        if not self.cases:
            return []

        # The query norm does not depend on the stored case: compute it once.
        query_norm = np.linalg.norm(query_embedding)

        scores = []
        for case in self.cases:
            # STRICT ISOLATION: Only compare with own cases
            if case.username != username:
                continue
            if same_domain_only and query_domain and case.domain != query_domain:
                continue
            # Cosine similarity
            dot_product = np.dot(query_embedding, case.embedding)
            case_norm = np.linalg.norm(case.embedding)
            if query_norm > 0 and case_norm > 0:
                similarity = dot_product / (query_norm * case_norm)
            else:
                similarity = 0
            scores.append((similarity, case))

        scores.sort(key=lambda x: x[0], reverse=True)
        return [
            {
                "case_id": c.id,
                "diagnosis": c.diagnosis,
                "similarity": round(float(s * 100), 1)
            }
            for s, c in scores[:top_k]
        ]
# Global instance used by find_similar_cases() and store_case_for_similarity().
# Cases are held in a plain in-memory list on this object.
similar_case_db = SimilarCaseDatabase()
def find_similar_cases(embedding: np.ndarray, domain: str, username: str, top_k: int = 5) -> Dict[str, Any]:
    """
    Search the global case store for cases resembling ``embedding``.

    The search is restricted to cases owned by ``username`` and to the given
    ``domain``. Returns the matches, the total number of cases currently in
    the store (all users), and a human-readable message.
    """
    matches = similar_case_db.find_similar(
        embedding,
        username,
        top_k=top_k,
        same_domain_only=True,
        query_domain=domain,
    )
    searched = len(similar_case_db.cases)
    if matches:
        message = f"Trouvé {len(matches)} cas similaires"
    else:
        message = "Aucun cas similaire trouvé"
    return {
        "similar_cases": matches,
        "cases_searched": searched,
        "message": message
    }
def store_case_for_similarity(case_id: str, embedding: np.ndarray, diagnosis: str, domain: str, probability: float, username: str):
    """Record a case in the global store for future similarity searches, tagged with its owner."""
    similar_case_db.add_case(case_id, embedding, diagnosis, domain, probability, username)
# 6. ADAPTIVE PREPROCESSING
def estimate_noise_level(image: np.ndarray) -> float:
    """
    Estimate image noise as a sigma value derived from the Laplacian.

    The median absolute Laplacian response is divided by 0.6745. Inputs with
    three dimensions are converted from BGR to grayscale first.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image
    # Use robust median absolute deviation
    response = np.abs(cv2.Laplacian(gray, cv2.CV_64F))
    return float(np.median(response) / 0.6745)
def apply_clahe(image: np.ndarray, clip_limit: float = 2.0, grid_size: int = 8) -> np.ndarray:
    """
    Apply Contrast Limited Adaptive Histogram Equalization (CLAHE).

    Arrays that are not three-dimensional are equalized directly. Three-
    dimensional arrays are treated as BGR: they are converted to LAB, only
    the L channel is equalized, and the result is converted back to BGR.
    """
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(grid_size, grid_size))
    if len(image.shape) != 3:
        return equalizer.apply(image)

    # Convert to LAB and apply to L channel
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    lab[:, :, 0] = equalizer.apply(lab[:, :, 0])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
def gamma_correction(image: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """
    Adjust brightness with a gamma curve applied through a 256-entry lookup table.

    Each level ``i`` maps to ``((i / 255) ** (1 / gamma)) * 255``, truncated
    to uint8.
    """
    exponent = 1.0 / gamma
    levels = []
    for level in np.arange(0, 256):
        levels.append(((level / 255.0) ** exponent) * 255)
    lookup = np.array(levels).astype("uint8")
    return cv2.LUT(image, lookup)
def bilateral_denoise(image: np.ndarray, d: int = 9, sigma_color: int = 75, sigma_space: int = 75) -> np.ndarray:
    """
    Apply bilateral filter for edge-preserving denoising.

    Thin wrapper: the arguments are forwarded unchanged, in order, to
    ``cv2.bilateralFilter(image, d, sigma_color, sigma_space)``.
    """
    return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
def adaptive_preprocessing(image_bytes: bytes) -> Tuple[Image.Image, Dict[str, Any]]:
    """
    Apply intelligent preprocessing based on image analysis.

    The image is decoded, analysed (histogram spread, mean brightness, noise
    estimate) and then conditionally corrected in a fixed order: CLAHE,
    gamma correction, bilateral denoising, black-level crush (grayscale only)
    and a final min-max normalization.

    Returns:
        A tuple ``(pil_image, preprocessing_log)`` where ``pil_image`` is an
        RGB PIL image and ``preprocessing_log`` records the original stats,
        the analysis flags and every transformation applied.

    Raises:
        ValueError: if the bytes cannot be decoded as an image.
    """
    # Decode image
    nparr = np.frombuffer(image_bytes, np.uint8)
    # IMREAD_UNCHANGED decodes the file as stored.
    # NOTE(review): later steps use BGR conversions and cv2.LUT — confirm
    # that inputs with an alpha channel or 16-bit depth are handled as intended.
    img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError("Could not decode image")
    transformations = []
    # Stats are taken over the whole decoded array, before any correction.
    original_stats = {
        "mean_brightness": float(np.mean(img)),
        "std_dev": float(np.std(img))
    }
    # Convert to grayscale for analysis
    if len(img.shape) == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        gray = img
    # Analyze histogram
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).flatten()
    non_zero = np.where(hist > 0)[0]
    # Low contrast: the occupied intensity range spans fewer than 150 levels.
    is_low_contrast = bool(len(non_zero) > 0 and (non_zero[-1] - non_zero[0]) < 150)
    is_dark = bool(np.mean(gray) < 60)
    is_bright = bool(np.mean(gray) > 200)
    noise_level = float(estimate_noise_level(gray))
    # Apply adaptive corrections
    # All decisions above come from the original image; corrections are
    # applied cumulatively to a copy.
    processed = img.copy()
    # 1. Low contrast - Apply CLAHE
    if is_low_contrast:
        processed = apply_clahe(processed, clip_limit=2.5)
        transformations.append({
            "type": "CLAHE",
            "reason": "Faible contraste détecté",
            "params": {"clip_limit": 2.5}
        })
    # 2. Dark image - Gamma correction
    if is_dark:
        processed = gamma_correction(processed, gamma=0.6)
        transformations.append({
            "type": "Gamma Correction",
            "reason": "Image trop sombre",
            "params": {"gamma": 0.6}
        })
    # 3. Overexposed - Inverse gamma
    if is_bright:
        processed = gamma_correction(processed, gamma=1.6)
        transformations.append({
            "type": "Gamma Correction",
            "reason": "Image surexposée",
            "params": {"gamma": 1.6}
        })
    # 4. Noisy - Bilateral filter
    if noise_level > 15:
        processed = bilateral_denoise(processed)
        transformations.append({
            "type": "Bilateral Denoise",
            "reason": f"Bruit détecté (σ={noise_level:.1f})",
            "params": {"d": 9, "sigma": 75}
        })
    # 5. Black level correction for X-rays (crush blacks)
    # Applied to every single-channel image: values at or below 15 become 0.
    if len(processed.shape) == 2 or (len(processed.shape) == 3 and processed.shape[2] == 1):
        _, processed = cv2.threshold(processed, 15, 255, cv2.THRESH_TOZERO)
        transformations.append({
            "type": "Black Level Crush",
            "reason": "Correction niveau noir (X-ray)",
            "params": {"threshold": 15}
        })
    # Final normalization
    # Stretch to the full 0-255 range; skipped for constant images to avoid
    # dividing by zero.
    min_val, max_val = processed.min(), processed.max()
    if max_val > min_val:
        processed = ((processed - min_val) / (max_val - min_val) * 255).astype(np.uint8)
        transformations.append({
            "type": "Normalization",
            "reason": "Normalisation finale",
            "params": {"min": float(min_val), "max": float(max_val)}
        })
    # Convert to PIL Image
    if len(processed.shape) == 2:
        pil_image = Image.fromarray(processed).convert("RGB")
    else:
        pil_image = Image.fromarray(cv2.cvtColor(processed, cv2.COLOR_BGR2RGB))
    preprocessing_log = {
        "original_stats": original_stats,
        "analysis": {
            "low_contrast": is_low_contrast,
            "dark": is_dark,
            "bright": is_bright,
            "noise_level": round(noise_level, 2)
        },
        "transformations_applied": transformations,
        "transformation_count": len(transformations)
    }
    return pil_image, preprocessing_log
# 7. ENHANCE ANALYSIS RESULT (PIPELINE)
def enhance_analysis_result(
    base_result: Dict[str, Any],
    image_array: Optional[np.ndarray] = None,
    embedding: Optional[np.ndarray] = None,
    case_id: Optional[str] = None,
    patient_info: Optional[Dict] = None,
    username: Optional[str] = None
) -> Dict[str, Any]:
    """
    Enrich a base analysis result with the post-processing algorithms.

    Works on a shallow copy of ``base_result`` and adds, when the required
    inputs are present:

    - "image_quality": when ``image_array`` is given;
    - "confidence" and "priority": when the result has non-empty "specific"
      predictions;
    - "similar_cases": when ``embedding`` and ``username`` are given and the
      result has a "domain". The case is also stored for future searches
      when ``case_id`` is given and there is at least one prediction.

    Report generation is intentionally not done here (see the note at the end
    of the function). ``patient_info`` is accepted for interface
    compatibility but not used.

    Returns:
        The enriched copy; ``base_result`` itself is not modified at the top
        level.
    """
    enhanced = base_result.copy()
    # Read once with .get() so a result without a "specific" key is treated
    # like one with no predictions instead of raising KeyError further down.
    specifics = enhanced.get("specific")

    # 1. Image Quality (if image provided)
    if image_array is not None:
        enhanced["image_quality"] = assess_image_quality(image_array)

    if specifics:
        # 2. Confidence Calibration
        raw_probs = [p["probability"] / 100 for p in specifics]
        labels = [p["label"] for p in specifics]
        enhanced["confidence"] = calibrate_confidence(raw_probs, labels=labels)

        # 3. Priority Scoring
        domain = enhanced.get("domain", {}).get("label", "Unknown")
        enhanced["priority"] = calculate_priority_score(specifics, domain)

    # 4. Similar Cases (if embedding provided AND username provided)
    if embedding is not None and "domain" in enhanced and username:
        domain = enhanced["domain"].get("label", "Unknown")
        enhanced["similar_cases"] = find_similar_cases(embedding, domain, username)
        # Store this case for future searches
        if case_id and specifics:
            top_pred = specifics[0]
            store_case_for_similarity(
                case_id=case_id,
                embedding=embedding,
                diagnosis=top_pred["label"],
                domain=domain,
                probability=top_pred["probability"],
                username=username
            )

    # 5. Generate Report - REMOVED HERE
    # Moved to predict() method to ensure it runs AFTER localization (Translation)
    return enhanced
# Model directory resolution: prefer the nested "oeil d'elephant" folder
# under ./models when it exists, otherwise ./models itself.
BASE_MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
NESTED_DIR = os.path.join(BASE_MODELS_DIR, "oeil d'elephant")
MODEL_DIR = NESTED_DIR if os.path.exists(NESTED_DIR) else BASE_MODELS_DIR
# Environment Detection
ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
IS_PRODUCTION = ENVIRONMENT == "production"
# Security Configuration - JWT Secret Key (ENFORCED in production)
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
if not SECRET_KEY:
    if IS_PRODUCTION:
        logger.critical("🔴 FATAL ERROR: JWT_SECRET_KEY must be set in production environment")
        logger.critical("Generate one with: python -c 'import secrets; print(secrets.token_hex(32))'")
        sys.exit(1) # Fail-fast in production
    else:
        # Development fallback with warning
        # A fresh random suffix is generated on every import, so tokens signed
        # with this fallback key do not survive a process restart.
        from secrets import token_hex
        SECRET_KEY = "dev_insecure_key_" + token_hex(16)
        logger.warning("⚠️ WARNING: Using development JWT secret. DO NOT use in production!")
ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("JWT_EXPIRE_MINUTES", "60"))
logger.info(f"🌍 Environment: {ENVIRONMENT}")
# Only reports whether the key is the development fallback; the key itself is never logged.
logger.info(f"✅ JWT SECRET_KEY: {'SET (secure)' if 'dev_insecure' not in SECRET_KEY else 'DEVELOPMENT MODE'}")
# CORS Configuration
# Comma-separated list of allowed origins; defaults to local dev servers on ports 5173-5176.
CORS_ORIGINS_STR = os.getenv("CORS_ORIGINS", "http://localhost:5173,http://127.0.0.1:5173,http://localhost:5174,http://127.0.0.1:5174,http://localhost:5175,http://127.0.0.1:5175,http://localhost:5176,http://127.0.0.1:5176")
CORS_ORIGINS = [origin.strip() for origin in CORS_ORIGINS_STR.split(",")]
# Concurrency Control
MAX_CONCURRENT_USERS = int(os.getenv("MAX_CONCURRENT_USERS", "200"))
# NOTE(review): the semaphore is created at import time, before any event
# loop is running — confirm this is safe on the Python version deployed.
concurrency_semaphore = asyncio.Semaphore(MAX_CONCURRENT_USERS)
# =========================================================================
# MODEL PATH CONFIGURATION (HuggingFace Hub or Local)
# =========================================================================
def get_model_path():
    """
    Resolve the directory containing the model files.

    Lookup order:
      1. the ``MODEL_DIR`` environment variable, if it points to an existing path;
      2. the local ``models/oeil d'elephant`` folder next to this file;
      3. a snapshot downloaded from the HuggingFace Hub.

    Returns:
        The path to the model directory.

    Raises:
        RuntimeError: if the model is not available locally and the download
            (or the import of ``huggingface_hub``) fails. The original error
            is attached as the cause.
    """
    # Check environment variable first
    env_path = os.getenv("MODEL_DIR")
    if env_path:
        if os.path.exists(env_path):
            logger.info(f"Using model from environment: {env_path}")
            return env_path
        # A configured but missing path is most likely a deployment mistake:
        # say so instead of silently falling through.
        logger.warning(f"MODEL_DIR is set to '{env_path}' but the path does not exist; trying other locations")

    # Check local path (development)
    local_path = os.path.join(os.path.dirname(__file__), "models", "oeil d'elephant")
    if os.path.exists(local_path):
        logger.info(f"Using local model: {local_path}")
        return local_path

    # Download from HuggingFace Hub (production/cloud)
    try:
        from huggingface_hub import snapshot_download
        logger.info("Downloading model from HuggingFace Hub...")
        hub_path = snapshot_download(
            repo_id="issoufzousko07/medsigclip-model",
            repo_type="model"
        )
        logger.info(f"Model downloaded to: {hub_path}")
        return hub_path
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        # Chain explicitly so the traceback shows the download failure as the cause.
        raise RuntimeError(f"Model not found locally and failed to download: {e}") from e
# NOTE(review): this rebinding discards the MODEL_DIR value computed from
# BASE_MODELS_DIR/NESTED_DIR earlier in this module — confirm which of the
# two definitions is meant to win.
MODEL_DIR = None # Will be set at startup
# OAuth2 Scheme
# Bearer tokens are obtained from the "token" endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# =========================================================================
# MEDICAL DOMAINS CONFIGURATION
# =========================================================================
# Zero-shot prompt configuration, keyed by medical domain.
# Each domain has a 'domain_prompt' used for domain identification. Most
# domains are "flat" and list their candidate findings in 'specific_labels'.
# A domain that instead defines 'stage_1_triage' / 'stage_2_diagnosis' is
# handled hierarchically by MedSigClipWrapper.predict().
# The label strings are fed to the model as text prompts: changing their
# wording changes inference results.
MEDICAL_DOMAINS = {
    'Thoracic': {
        'domain_prompt': 'Chest X-Ray Analysis',
        'specific_labels': [
            'Diffuse interstitial opacities or ground-glass pattern (Viral/Atypical Pneumonia)',
            'Focal alveolar consolidation with air bronchograms (Bacterial Pneumonia)',
            'Perfectly clear lungs, sharp costophrenic angles, no pathology',
            'Pneumothorax (Lung collapse)',
            'Pleural Effusion (Fluid)',
            'Cardiomegaly (Enlarged heart)',
            'Pulmonary Edema',
            'Lung Nodule or Mass',
            'Atelectasis (Lung collapse)'
        ]
    },
    'Dermatology': {
        'domain_prompt': 'Dermatoscopic analysis of a pigmented or non-pigmented skin lesion',
        'specific_labels': [
            'A healthy skin area without lesion',
            'A benign nevus (mole) regular, symmetrical and homogeneous',
            'A seborrheic keratosis (benign warty lesion)',
            'A malignant melanoma with asymmetry, irregular borders and multiple colors',
            'A basal cell carcinoma (pearly or ulcerated lesion)',
            'A squamous cell carcinoma (crusty or budding lesion)',
            'A non-specific inflammatory skin lesion'
        ]
    },
    'Histology': {
        'domain_prompt': 'Microscopic analysis of a histological section (H&E stain)',
        'specific_labels': [
            'Healthy breast tissue with preserved lobular architecture',
            'Healthy prostatic tissue with regular glands',
            'Invasive ductal carcinoma of the breast (Disorganized cells)',
            'Prostate adenocarcinoma (Gland fusion)',
            'Cervical dysplasia or intraepithelial neoplasia',
            'Colon cancer tumor tissue',
            'Lung cancer tumor tissue',
            'Adipose tissue (Fat) or connective stroma',
            'Preparation artifact or empty area'
        ]
    },
    'Ophthalmology': {
        'domain_prompt': 'Fundus photography (Retina)',
        'specific_labels': [
            'Normal retina, healthy macula and optic disc',
            'Diabetic retinopathy (hemorrhages, exudates, aneurysms)',
            'Glaucoma (optic disc cupping)',
            'Macular degeneration (drusen or atrophy)'
        ]
    },
    # Hierarchical domain: no 'specific_labels'; triage first, then diagnosis.
    'Orthopedics': {
        'domain_prompt': 'Bone X-Ray (Musculoskeletal)',
        # NOTE(review): predict() reads the probability of the LAST triage
        # label, so the order of these two labels is significant.
        'stage_1_triage': {
            'prompt': 'Anatomical region identification',
            'labels': [
                'Other x-ray view (Chest, Hand, Foot, Pediatric) - OUT OF DISTRIBUTION',
                'A knee x-ray view (Knee Joint)'
            ]
        },
        'stage_2_diagnosis': {
            'prompt': 'Knee Osteoarthritis Severity Assessment',
            'labels': [
                'Severe osteoarthritis with bone-on-bone contact and large osteophytes (Grade 4)',
                'Moderate osteoarthritis with definite joint space narrowing (Grade 2-3)',
                'Normal knee joint with preserved joint space and no osteophytes (Grade 0-1)',
                'Total knee arthroplasty (TKA) with metallic implant',
                'Acute knee fracture or dislocation'
            ]
        }
    }
}
# =========================================================================
# PYDANTIC MODELS
# =========================================================================
# Lifecycle states of an analysis job; members are also plain strings.
# NOTE(review): this definition shadows the JobStatus imported from the
# `database` module at the top of the file — confirm the two stay in sync
# or drop one of them.
class JobStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
# An analysis job together with its outcome and bookkeeping metadata.
class Job(BaseModel):
    id: str
    status: JobStatus
    result: Optional[Dict[str, Any]] = None  # Presumably set once the job completes — verify against the worker
    error: Optional[str] = None  # Presumably set when the job fails — verify against the worker
    created_at: float  # NOTE(review): looks like a Unix timestamp — confirm
    storage_path: Optional[str] = None
    encrypted_user: Optional[str] = None
    username: Optional[str] = None # For registry logging
    file_type: Optional[str] = None # DICOM, PNG, JPEG
    start_time_ms: Optional[float] = None # For computation time
# Access-token response body: the token string plus its type.
class Token(BaseModel):
    access_token: str
    token_type: str
# Identity extracted from a decoded JWT (built from the "sub" claim in get_current_user).
class TokenData(BaseModel):
    username: Optional[str] = None
# Base user fields, without any credential material.
class User(BaseModel):
    username: str
    email: Optional[str] = None
# User record as loaded from the database by get_user().
class UserInDB(User):
    hashed_password: str
    security_question: str
    security_answer: str  # Stored hashed for the seeded admin user; presumably hashed for all users — verify
# Request body for registering a new user.
class UserRegister(BaseModel):
    username: str
    password: str
    email: Optional[str] = None
    security_question: str
    security_answer: str
# Request body for resetting a password using the security answer.
class UserResetPassword(BaseModel):
    username: str
    security_answer: str
    new_password: str
# User feedback submission.
class FeedbackModel(BaseModel):
    username: str
    rating: int  # NOTE(review): no range validation here — confirm the expected scale
    comment: str
# =========================================================================
# GLOBAL STATE
# =========================================================================
# In-memory job registry, kept as a module-level dict.
# NOTE(review): an earlier comment said this was removed in favour of SQLite
# persistence, yet the dict is still defined — confirm whether it is still used.
jobs: Dict[str, Job] = {}
# Storage backend selected by the STORAGE_MODE environment variable (default "LOCAL").
storage_provider = get_storage_provider(os.getenv("STORAGE_MODE", "LOCAL"))
# Initialize Database
database.init_db()
# --- SEED DEFAULT USER ---
# Ensure admin user exists for immediate login.
# Seeding is best-effort: a failure is logged and startup continues.
try:
    if not database.get_user_by_username("admin"):
        logger.info("👤 Creating default admin user...")
        # Hash the built-in default password and security answer
        admin_pw = bcrypt.hashpw(b"secret", bcrypt.gensalt()).decode('utf-8')
        security_ans = bcrypt.hashpw(b"admin", bcrypt.gensalt()).decode('utf-8') # Answer: admin
        database.create_user({
            "username": "admin",
            "hashed_password": admin_pw,
            "email": "admin@elephmind.com",
            "security_question": "Who is the admin?",
            "security_answer": security_ans
        })
        # Credentials are deliberately not written to the log.
        logger.warning("✅ Default admin user created with the built-in default password - change it before exposing this service")
except Exception:
    # logger.exception records the traceback along with the message.
    logger.exception("Failed to seed admin user")
# =========================================================================
# AUTHENTICATION HELPERS
# =========================================================================
from passlib.context import CryptContext
# Shared password-hashing context accepting argon2 and bcrypt hashes.
# NOTE(review): per passlib's documented behaviour the first scheme listed
# (argon2) is used for new hashes, and deprecated="auto" marks the others as
# deprecated — confirm against the installed passlib version.
pwd_context = CryptContext(schemes=["argon2", "bcrypt"], deprecated="auto")
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash using the shared passlib context."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password: str) -> str:
    """Hash a password with the shared passlib context and return the encoded hash."""
    hashed = pwd_context.hash(password)
    return hashed
def get_user(db, username: str) -> Optional[UserInDB]:
    """
    Look up a user by username.

    ``db`` is accepted for signature compatibility but not used; the lookup
    always goes through the ``database`` module. Returns ``None`` when no
    matching record is found.
    """
    record = database.get_user_by_username(username)
    if not record:
        return None
    return UserInDB(**record)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        data: Claims to embed in the token; the dict is copied, not mutated.
        expires_delta: Lifetime of the token. Defaults to 15 minutes when
            omitted or falsy.

    Returns:
        The encoded JWT, signed with SECRET_KEY using ALGORITHM.
    """
    # Local import: the file's top-level datetime imports do not include timezone.
    from datetime import timezone

    to_encode = data.copy()
    # Timezone-aware UTC "now" replaces the deprecated, naive datetime.utcnow().
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB:
    """
    FastAPI dependency resolving the bearer token to a user record.

    Raises a 401 HTTPException when the token cannot be decoded, carries no
    "sub" claim, or names a user that does not exist.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    token_data = TokenData(username=subject)

    found = get_user(None, username=token_data.username)
    if found is None:
        raise unauthorized
    return found
async def get_current_active_user(current_user: UserInDB = Depends(get_current_user)) -> UserInDB:
    """
    Dependency to get current active user.

    Currently a pass-through: the user resolved by ``get_current_user`` is
    returned unchanged, with no "active" check performed.
    """
    # Logic to check if active could be added here
    # if not current_user.is_active: raise ...
    return current_user
# =========================================================================
# GRAD-CAM UTILITIES (Moved to explainability.py)
# =========================================================================
# (Refactored to separate module for medical grade validation)
# =========================================================================
# MODEL WRAPPER
# =========================================================================
class MedSigClipWrapper:
"""Wrapper for the SigLIP model with medical domain inference."""
def __init__(self, model_path: str):
self.model_path = model_path
self.processor = None
self.model = None
self.loaded = False
self.load_error = None
def load(self):
    """
    Load the processor and model from ``self.model_path``.

    Never raises: on failure the reason is stored in ``self.load_error`` and
    ``self.loaded`` stays False. On success ``self.loaded`` becomes True.
    """
    logger.info(f"Initiating model load from: {self.model_path}")

    path_exists = os.path.exists(self.model_path)
    if not path_exists:
        self.load_error = f"Model directory not found: {self.model_path}"
        logger.critical(self.load_error)
        return

    try:
        from transformers import AutoProcessor, AutoModel
        import torch

        # local_files_only=True: read from disk only
        self.processor = AutoProcessor.from_pretrained(self.model_path, local_files_only=True)
        self.model = AutoModel.from_pretrained(self.model_path, local_files_only=True)
        self.model.eval()

        # Calibrate logit scale for better probability distribution
        has_logit_scale = hasattr(self.model, 'logit_scale')
        if has_logit_scale:
            with torch.no_grad():
                self.model.logit_scale.data.fill_(3.80666) # ln(45)

        self.loaded = True
        logger.info("✅ MedSigClip Model Loaded Successfully (448x448 SigLIP architecture)")
    except Exception as exc:
        reason = str(exc)
        self.load_error = f"Exception during load: {reason}"
        logger.error(f"Failed to load model: {reason}")
def predict(self, image_bytes: bytes, username: str = None) -> Dict[str, Any]:
"""Run hierarchical inference using SigLIP Zero-Shot."""
# ... (rest of function until line 1094) ...
# I need to match the indentation and context.
# Since I can't see "inside" the dots in a replace, I have to be careful.
# It's better to update just the definition line and the call to enhance_analysis_result.
pass # Placeholder, will use multiple chunks below
if not self.loaded:
msg = "MedSigClip Model is NOT loaded. Cannot perform inference."
if self.load_error:
msg += f" Reason: {self.load_error}"
raise RuntimeError(msg)
logger.info("Starting inference pipeline...")
start_time = time.time()
try:
from PIL import Image
import io
import torch
import pydicom
# Image preprocessing functions
def process_dicom(file_bytes: bytes) -> Tuple[Image.Image, Dict[str, Any]]:
"""Convert DICOM bytes to PIL Image with tags."""
ds = pydicom.dcmread(io.BytesIO(file_bytes))
img = ds.pixel_array.astype(np.float32)
# Extract Metadata
metadata = {
"patient_id": str(ds.get("PatientID", "N/A")),
"patient_name": str(ds.get("PatientName", "N/A")),
"birth_date": str(ds.get("PatientBirthDate", "")),
"study_date": str(ds.get("StudyDate", "")),
"modality": str(ds.get("Modality", "UNKNOWN"))
}
if hasattr(ds, 'PhotometricInterpretation') and ds.PhotometricInterpretation == "MONOCHROME1":
img = img.max() - img
# Lung Window: WL=-600, WW=1500
wl, ww = -600, 1500
min_val, max_val = wl - ww/2, wl + ww/2
img = np.clip(img, min_val, max_val)
img = (img - min_val) / (max_val - min_val)
img = (img * 255).astype(np.uint8)
return Image.fromarray(img).convert("RGB"), metadata
def process_standard_image(image_bytes: bytes) -> Image.Image:
"""Process standard images (PNG/JPG) - SIMPLIFIED like Colab.
Just load the image as RGB without aggressive preprocessing."""
nparr = np.frombuffer(image_bytes, np.uint8)
img_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img_cv is None:
raise ValueError("Could not decode image")
# Convert BGR to RGB (OpenCV uses BGR)
img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
return Image.fromarray(img_rgb)
# Detect image format
header = image_bytes[:32]
is_png = header.startswith(b'\x89PNG\r\n\x1a\n')
is_jpeg = header.startswith(b'\xff\xd8\xff')
image = None
dicom_metadata = None
if is_png or is_jpeg:
try:
image = process_standard_image(image_bytes)
logger.info(f"Processed as {'PNG' if is_png else 'JPEG'}")
except Exception as e:
raise ValueError(f"Corrupt Image File: {str(e)}")
if image is None:
try:
image, dicom_metadata = process_dicom(image_bytes)
logger.info("Processed as DICOM")
except Exception:
try:
image = process_standard_image(image_bytes)
except Exception as e:
raise ValueError(f"Unknown image format: {str(e)}")
# =========================================================
# ADAPTIVE PREPROCESSING - DISABLED to match Colab behavior
# The model was trained on raw images, not preprocessed ones
# =========================================================
preprocessing_log = {"message": "Preprocessing disabled for accuracy", "transformation_count": 0}
# NOTE: Uncomment below to re-enable if needed
# try:
# import io as io_module
# buffer = io_module.BytesIO()
# image.save(buffer, format='PNG')
# image_bytes_for_preprocessing = buffer.getvalue()
# image, preprocessing_log = adaptive_preprocessing(image_bytes_for_preprocessing)
# logger.info(f"🔧 Adaptive preprocessing applied: {preprocessing_log.get('transformation_count', 0)} transformations")
# except Exception as e_preproc:
# logger.warning(f"Adaptive preprocessing skipped: {e_preproc}")
# STEP 1: DOMAIN IDENTIFICATION
domain_keys = list(MEDICAL_DOMAINS.keys())
domain_prompts = [d['domain_prompt'] for d in MEDICAL_DOMAINS.values()]
inputs_domain = self.processor(
text=domain_prompts,
images=image,
padding="max_length",
return_tensors="pt"
)
with torch.no_grad():
outputs_domain = self.model(**inputs_domain)
probs_domain = torch.softmax(outputs_domain.logits_per_image, dim=1)[0]
best_domain_idx = torch.argmax(probs_domain).item()
best_domain_key = domain_keys[best_domain_idx]
best_domain_prob = float(probs_domain[best_domain_idx] * 100)
logger.info(f"Identified Domain: {best_domain_key} ({best_domain_prob:.2f}%)")
# STEP 2: SPECIFIC ANALYSIS
domain_config = MEDICAL_DOMAINS[best_domain_key]
specific_results = []
if 'stage_1_triage' in domain_config:
# Hierarchical Logic (e.g., Orthopedics)
logger.info(f"Engaging Level 2 Hierarchical Logic for: {best_domain_key}")
triage_labels = domain_config['stage_1_triage']['labels']
inputs_triage = self.processor(text=triage_labels, images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
out_triage = self.model(**inputs_triage)
probs_triage = torch.softmax(out_triage.logits_per_image, dim=1)[0]
prob_abnormal = float(probs_triage[-1])
prob_normal = 1.0 - prob_abnormal
logger.info(f"Triage: Normal={prob_normal*100:.2f}%, Abnormal={prob_abnormal*100:.2f}%")
if prob_abnormal > prob_normal:
logger.info("Running Stage 2 Diagnosis...")
diag_labels = domain_config['stage_2_diagnosis']['labels']
inputs_diag = self.processor(text=diag_labels, images=image, padding="max_length", return_tensors="pt")
with torch.no_grad():
out_diag = self.model(**inputs_diag)
probs_diag = torch.softmax(out_diag.logits_per_image, dim=1)[0]
for i, label in enumerate(diag_labels):
specific_results.append({
"label": label,
"probability": round(float(probs_diag[i] * 100), 2)
})
else:
logger.info("Triage indicates Normal/Healthy. Skipping Stage 2.")
else:
# Flat Mode (Thoracic, Dermato, etc.)
specific_labels_raw = domain_config['specific_labels']
inputs_specific = self.processor(
text=specific_labels_raw,
images=image,
padding="max_length",
return_tensors="pt"
)
with torch.no_grad():
outputs_specific = self.model(**inputs_specific)
probs_specific = torch.softmax(outputs_specific.logits_per_image, dim=1)[0]
for i, label in enumerate(specific_labels_raw):
specific_results.append({
"label": label,
"probability": round(float(probs_specific[i] * 100), 2)
})
specific_results.sort(key=lambda x: x['probability'], reverse=True)
# STEP 3: HEATMAP GENERATION (Grad-CAM++ x MedSegCLIP)
heatmap_base64 = None
original_base64 = None
try:
if specific_results:
top_label_text = specific_results[0]['label']
logger.info(f"Generating Medical Explanation for: {top_label_text}")
# FIX: Initialize container for enhanced metadata
enhanced_result = {}
# --- QUALITY CONTROL GATE (Gate 1 & 2) ---
# Added per user request: Verify quality before deep explanation/analysis
# Ideally this should be even earlier, but performing it here ensures we have the image object ready.
from quality_control import QualityControlEngine
qc_engine = QualityControlEngine()
qc_result = qc_engine.run_quality_check(image)
enhanced_result['image_quality'] = {
"quality_score": qc_result['quality_score'],
"passed": qc_result['passed'],
"reasons": qc_result['reasons'],
"metrics": qc_result['metrics']
}
if not qc_result['passed']:
logger.warning(f"⛔ Quality Control Failed: {qc_result['reasons']}")
# STRICT REJECTION: Override diagnosis and clear predictions
enhanced_result['diagnosis'] = "Analyse Refusée (Qualité Insuffisante)"
enhanced_result['confidence'] = 0.0
enhanced_result['specific'] = [] # Clear predictions
enhanced_result['quality_failure_reasons'] = qc_result['reasons']
enhanced_result['image_quality'] = {
"quality_score": qc_result['quality_score'],
"passed": False,
"reasons": qc_result['reasons'],
"metrics": qc_result['metrics']
}
# FIX: Construct localized_result explicitly as it is not defined yet
rejection_result = {
"domain": {
"label": best_domain_key,
"description": MEDICAL_DOMAINS[best_domain_key]['domain_prompt'],
"probability": round(best_domain_prob, 2)
},
"specific": [],
"heatmap": None,
"original_image": None, # Will be handled by frontend fallback or we can encode it here if mostly wanted
"preprocessing": preprocessing_log,
"explainability": {
"method": "QC Rejection",
"reliability": 0.0
}
}
# We merge enhanced info
rejection_result.update(enhanced_result)
# Encode Original Image even on Rejection for Context
try:
img_tensor = np.array(image).astype(np.float32) / 255.0
original_uint8 = (img_tensor * 255).astype(np.uint8)
_, buffer_orig = cv2.imencode('.png', cv2.cvtColor(original_uint8, cv2.COLOR_RGB2BGR))
rejection_result["original_image"] = base64.b64encode(buffer_orig).decode('utf-8')
except:
pass
return rejection_result
# If QC Passed, Proceed to Explanation
import explainability
engine = explainability.ExplainabilityEngine(self)
# Define Anatomical Context based on Domain
anatomical_context = "body part" # Default
if best_domain_key == 'Thoracic': anatomical_context = "lung parenchyma"
elif best_domain_key == 'Orthopedics': anatomical_context = "bone structure"
elif best_domain_key == 'Dermatology': anatomical_context = "skin lesion"
elif best_domain_key == 'Ophthalmology': anatomical_context = "retina"
explanation = engine.explain(
image=image,
target_text=top_label_text,
anatomical_context=anatomical_context
)
if explanation['heatmap_array'] is not None:
# Encode Visualization
vis_img = explanation['heatmap_array']
_, buffer = cv2.imencode('.png', cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR))
heatmap_base64 = base64.b64encode(buffer).decode('utf-8')
# Original Image (Normalized for consistency)
img_tensor = np.array(image).astype(np.float32) / 255.0
original_uint8 = (img_tensor * 255).astype(np.uint8)
_, buffer_orig = cv2.imencode('.png', cv2.cvtColor(original_uint8, cv2.COLOR_RGB2BGR))
original_base64 = base64.b64encode(buffer_orig).decode('utf-8')
reliability = explanation.get("reliability_score", 0)
logger.info(f"✅ Explanation Generated. Reliability: {reliability} ({explanation.get('confidence_label')})")
else:
logger.warning("Could not generate explainability map.")
except Exception as e_cam:
import traceback
logger.error(f"Explainability Pipeline Failed: {traceback.format_exc()}")
# FINAL RESULT (Base)
result_json = {
"domain": {
"label": best_domain_key,
"description": MEDICAL_DOMAINS[best_domain_key]['domain_prompt'],
"probability": round(best_domain_prob, 2)
},
"specific": specific_results,
"heatmap": heatmap_base64,
"original_image": original_base64,
"preprocessing": preprocessing_log,
"explainability": { # NEW METADATA
"method": "Grad-CAM++ x MedSegCLIP (Proxy)",
"anatomical_context": anatomical_context if 'anatomical_context' in locals() else "Unknown",
"reliability": explanation.get("reliability_score") if 'explanation' in locals() else 0
}
}
# =========================================================
# APPLY 7 INTELLIGENCE ALGORITHMS
# =========================================================
logger.info("🧠 Applying Intelligence Algorithms...")
# Convert PIL image to numpy for quality assessment
image_array = np.array(image)
# Get image embedding for similar case detection
# OPT PERFORMANCE: Disabled to reduce inference time (<60s)
image_embedding = None
# try:
# with torch.no_grad():
# img_inputs = self.processor(images=image, return_tensors="pt")
# image_embedding = self.model.get_image_features(**img_inputs)
# image_embedding = image_embedding.cpu().numpy().flatten()
# except Exception as e_emb:
# logger.warning(f"Could not extract embedding: {e_emb}")
# image_embedding = None
# Enhance result with all algorithms
enhanced_result = enhance_analysis_result(
base_result=result_json,
image_array=image_array,
embedding=image_embedding,
case_id=str(uuid.uuid4()),
patient_info=None,
username=username
)
# --- LOCALIZATION (Translate to French) ---
localized_result = localize_result(enhanced_result)
# --- GENERATE REPORT (After Localization) ---
# Now the labels in localized_result['specific'] are in French
localized_result["report"] = generate_clinical_report(
localized_result,
patient_info=None
)
# --- MAP TO FRONTEND EXPECTATIONS ---
# frontend expects: diagnosis, confidence, productions, quality_metrics, etc.
# 1. Diagnosis
top_finding = enhanced_result['specific'][0] if enhanced_result['specific'] else {"label": "Inconnu", "probability": 0}
enhanced_result['diagnosis'] = top_finding['label']
# 2. Confidence & Calibrated
enhanced_result['calibrated_confidence'] = enhanced_result.get('confidence', top_finding['probability'])
enhanced_result['confidence'] = top_finding['probability']
# 3. Processing Time (Real Measurement)
enhanced_result['processing_time'] = round(time.time() - start_time, 3)
# 4. Predictions (Alias for specific)
enhanced_result['predictions'] = [
{"name": item['label'], "probability": item['probability']}
for item in enhanced_result['specific']
]
# 5. Quality Metrics (Flatten structure)
if 'image_quality' in enhanced_result:
enhanced_result['quality_score'] = enhanced_result['image_quality']['quality_score']
enhanced_result['quality_metrics'] = enhanced_result['image_quality']['metrics']
# 6. Priority
# If priority is a dict (from new algo), extract just the level/score for simple display, or keep object
# Frontend expects string 'priority' sometimes, or maybe object. Let's provide string for badge.
if isinstance(enhanced_result.get('priority'), str):
pass
elif isinstance(enhanced_result.get('priority'), dict):
# Flatten for frontend simple badge
enhanced_result['priority'] = enhanced_result['priority'].get('level', 'Normale')
# 7. DICOM Metadata (if available)
if dicom_metadata:
enhanced_result['patient_metadata'] = dicom_metadata
logger.info("✅ Intelligence Algorithms applied successfully")
logger.info("✅ Intelligence Algorithms applied successfully")
return localized_result
except Exception as e:
logger.error(f"Inference Error: {str(e)}")
raise e
# =========================================================================
# GLOBAL MODEL INSTANCE
# =========================================================================
# Process-wide model handle. It stays None until lifespan() constructs the
# wrapper and calls load() on it during application startup.
model_wrapper: Optional[MedSigClipWrapper] = None
# =========================================================================
# FASTAPI LIFECYCLE
# =========================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan hook: prepare persistence and the model, then serve.

    Startup (before ``yield``): initialise the database and the analysis
    registry, resolve the model directory and load the model into the
    module-level ``model_wrapper``. Shutdown (after ``yield``): only logs.
    """
    # Both names are rebound at module level so request handlers see them.
    global model_wrapper, MODEL_DIR

    # Persistence first: handlers and the worker rely on it.
    database.init_db()
    database.init_analysis_registry()

    # Resolve the model location (per the original note, get_model_path()
    # downloads from HuggingFace Hub if needed), then build and load the wrapper.
    MODEL_DIR = get_model_path()
    model_wrapper = MedSigClipWrapper(MODEL_DIR)
    model_wrapper.load()

    logger.info("ElephMind Backend Started")
    yield
    logger.info("ElephMind Backend Shutting Down")
app = FastAPI(
    lifespan=lifespan,
    title="ElephMind Medical AI API",
    version="2.0.0",
    description="Medical image analysis powered by SigLIP"
)

# CORS middleware with configurable origins.
# CORS_ALLOWED_ORIGINS is a comma-separated list of allowed origins, e.g.
# "https://app.example.com,https://admin.example.com". When the variable is
# unset or empty, the previous allow-all behaviour ("*") is kept so existing
# deployments keep working unchanged.
# NOTE(security): FastAPI's CORS documentation states that wildcard origins
# must not be combined with allow_credentials=True; set CORS_ALLOWED_ORIGINS
# explicitly in production.
_cors_allowed_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOWED_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """
    Return a 422 response for request-validation failures with sanitised details.

    Each error entry is shallow-copied and its ``input`` field is made safe:
    raw ``bytes``/``bytearray`` values are replaced by a placeholder (they make
    ``jsonable_encoder`` fail, e.g. when multipart/form-data is sent to a JSON
    endpoint) and strings longer than 200 characters are truncated. The
    sanitised list is both logged and returned under ``detail``.
    """
    max_input_chars = 200

    def _sanitise(entry):
        # Work on a copy so the exception's own error list is left untouched.
        cleaned = entry.copy()
        raw_input = cleaned.get('input')
        if isinstance(raw_input, (bytes, bytearray)):
            cleaned['input'] = "<binary_data_stripped>"
        elif isinstance(raw_input, str) and len(raw_input) > max_input_chars:
            cleaned['input'] = raw_input[:max_input_chars] + "..."
        return cleaned

    clean_errors = [_sanitise(entry) for entry in exc.errors()]
    logger.warning(f"Validation Error on {request.url.path}: {clean_errors}")

    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content={"detail": jsonable_encoder(clean_errors)},
    )
@app.middleware("http")
async def limit_concurrency(request: Request, call_next):
    """
    HTTP middleware that runs requests while holding ``concurrency_semaphore``.

    Health checks (``/health``) and ``OPTIONS`` requests bypass the semaphore.
    When the semaphore is already fully taken, a warning is logged and the
    request waits for a slot before being passed down the stack.
    """
    exempt = request.method == "OPTIONS" or request.url.path == "/health"
    if not exempt:
        if concurrency_semaphore.locked():
            logger.warning(f"Concurrency limit ({MAX_CONCURRENT_USERS}) reached. Request queued.")
        async with concurrency_semaphore:
            return await call_next(request)
    return await call_next(request)
# =========================================================================
# BACKGROUND WORKER (Decoupled)
# =========================================================================
async def process_analysis_job(job_id: str, image_id: str, username: str):
    """
    Background worker: load a stored image by ID, run inference, persist the result.

    The job row is moved to PROCESSING, then to COMPLETED (with the result) or
    FAILED (with the error text). Model inference runs in the default executor
    so the event loop is not blocked. Once the job has been stored as
    COMPLETED, a failure while writing the analysis-registry entry is logged
    but no longer flips the job to FAILED.

    Args:
        job_id: Identifier of the job row to update.
        image_id: Identifier of the stored image to analyse.
        username: Owner of the image; also forwarded to ``predict``.
    """
    import functools

    # RESILIENCE: the job must already exist in the DB, otherwise nothing to do.
    job = database.get_job(job_id)
    if not job:
        logger.error(f"❌ Job {job_id} not found DB")
        return

    logger.info(f"Worker processing Job {job_id} (Image: {image_id})")
    database.update_job_status(job_id, JobStatus.PROCESSING.value)
    start_time = time.time()

    try:
        if not model_wrapper:
            raise RuntimeError("Model wrapper not initialized.")

        # Physical read: the worker receives only IDs, never image bytes.
        image_bytes, _file_path = storage_manager.load_image(username, image_id)

        # get_running_loop() is the supported call inside a coroutine.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            functools.partial(model_wrapper.predict, image_bytes, username=username),
        )

        computation_time_ms = int((time.time() - start_time) * 1000)
        database.update_job_status(job_id, JobStatus.COMPLETED.value, result=result)
    except Exception as e:
        # logger.exception keeps the traceback alongside the original message.
        logger.exception(f"❌ Job {job_id} failed: {str(e)}")
        database.update_job_status(job_id, JobStatus.FAILED.value, error=str(e))
        return

    # Registry logging is best-effort: the result is already persisted, so an
    # error here must not overwrite the COMPLETED status.
    if username and result:
        try:
            findings = result.get('specific') or [{}]
            top_finding = findings[0]
            priority = result.get('priority', 'Normale')
            # predict() may emit priority as a dict; store only its level,
            # mirroring the flattening done for the frontend.
            if isinstance(priority, dict):
                priority = priority.get('level', 'Normale')
            database.log_analysis(
                username=username,
                domain=result.get('domain', {}).get('label', 'Unknown'),
                top_diagnosis=top_finding.get('label', 'Unknown'),
                confidence=top_finding.get('probability', 0),
                priority=priority,
                computation_time_ms=computation_time_ms,
                file_type='SavedImage'
            )
            logger.info(f"✅ Job {job_id} logged to registry")
        except Exception:
            logger.exception(f"Job {job_id} completed but registry logging failed")

    logger.info(f"✅ Job {job_id} completed in {computation_time_ms}ms")
# =========================================================================
# API ENDPOINTS
# =========================================================================
# --- Authentication ---
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """
    Authenticate a user from form credentials and issue a JWT bearer token.

    Raises:
        HTTPException: 401 when the user is unknown or the password does not
            verify against the stored hash.
    """
    account = database.get_user_by_username(form_data.username)

    # Only check the password when the account exists (short-circuit as before).
    authenticated = False
    if account:
        authenticated = verify_password(form_data.password, account['hashed_password'])

    if not authenticated:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token_lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": account['username']},
        expires_delta=token_lifetime
    )
    return {"access_token": access_token, "token_type": "bearer"}
class AnalysisRequest(BaseModel):
    """Request body accepted by POST /analyze."""

    # Identifier of a previously uploaded image (as returned by /upload).
    image_id: str
    # Optional hints; both default when omitted by the client.
    domain: str = "Triage"
    priority: str = "Normale"
@app.post("/register", status_code=status.HTTP_201_CREATED)
async def register_user(user: UserRegister):
    """
    Create a new user account.

    The password and the security answer are both stored hashed; the answer is
    stripped and lower-cased before hashing so later comparison is insensitive
    to case and surrounding whitespace.

    Raises:
        HTTPException: 400 when ``database.create_user`` reports failure.
    """
    password_hash = get_password_hash(user.password)
    answer_hash = get_password_hash(user.security_answer.strip().lower())

    created = database.create_user({
        "username": user.username,
        "hashed_password": password_hash,
        "email": user.email,
        "security_question": user.security_question,
        "security_answer": answer_hash
    })
    if created:
        return {"message": "User created successfully"}
    raise HTTPException(status_code=400, detail="Username already exists")
@app.get("/recover/{username}")
async def get_security_question(username: str):
    """
    Return the stored security question for ``username`` (password recovery).

    Raises:
        HTTPException: 404 when no such user exists.
    """
    account = database.get_user_by_username(username)
    if account:
        return {"question": account['security_question']}
    raise HTTPException(status_code=404, detail="User not found")
@app.post("/recover/reset")
async def reset_password(data: UserResetPassword):
    """
    Reset a user's password after checking the security answer.

    The submitted answer is stripped and lower-cased, then verified against the
    stored hash.

    Raises:
        HTTPException: 404 when the user does not exist, 400 when the security
            answer does not verify.
    """
    account = database.get_user_by_username(data.username)
    if not account:
        raise HTTPException(status_code=404, detail="User not found")

    submitted_answer = data.security_answer.strip().lower()
    answer_matches = verify_password(submitted_answer, account['security_answer'])
    if not answer_matches:
        raise HTTPException(status_code=400, detail="Incorrect security answer")

    database.update_password(data.username, get_password_hash(data.new_password))
    return {"message": "Password reset successfully"}
# --- Dashboard Analytics (REAL DATA ONLY) ---
@app.get("/api/dashboard/stats")
async def get_dashboard_statistics(current_user: User = Depends(get_current_user)):
    """
    Return dashboard statistics for the authenticated user.

    The response is the stats mapping from the database plus a
    ``recent_analyses`` entry holding up to 10 recent analyses. Only stored
    data is returned; nothing is fabricated here.
    """
    owner = current_user.username
    stats = database.get_dashboard_stats(owner)
    recent = database.get_recent_analyses(owner, limit=10)

    # Copy so the mapping returned by the database layer is not mutated.
    payload = {**stats}
    payload["recent_analyses"] = recent
    return payload
@app.post("/feedback")
async def submit_feedback(feedback: FeedbackModel):
    """
    Store a piece of user feedback (username, rating, comment).

    NOTE(review): this route declares no authentication dependency, so the
    username is taken from the request body as-is — confirm this is intended.
    """
    author, rating, comment = feedback.username, feedback.rating, feedback.comment
    database.add_feedback(author, rating, comment)
    return {"message": "Feedback received"}
# --- Medical Analysis ---
# --- Analysis Flow (Async Job Architecture) ---
# Local modules
import database
import storage_manager
import dicom_processor # NEW: Medical Validation
from database import JobStatus
from storage import get_storage_provider
# ...
@app.post("/upload")
async def upload_image(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_active_user)
):
    """
    Step 1 of the analysis flow: store an uploaded image and return its ID.

    Files carrying the DICOM magic bytes are first passed through
    ``dicom_processor.process_dicom_upload`` (validation and anonymisation, per
    the original notes) and the returned content is what gets stored, under a
    neutral filename hint. Other files are stored as received.

    Raises:
        HTTPException: 400 when the DICOM processor rejects the file with a
            ``ValueError``; 500 for any other failure.
    """
    try:
        content = await file.read()

        # DICOM files carry the marker "DICM" at byte offset 128.
        is_dicom = len(content) > 132 and content[128:132] == b'DICM'

        if is_dicom:
            logger.info(f"DICOM File detected for user {current_user.username}. Validating...")
            try:
                # From here on only the processor's returned content is used.
                content, _metadata = dicom_processor.process_dicom_upload(content, current_user.username)
                logger.info("✅ DICOM Validated and Anonymized.")
            except ValueError as ve:
                logger.error(f"❌ DICOM Rejected: {ve}")
                raise HTTPException(status_code=400, detail=f"Conformité DICOM refusée: {str(ve)}")

        # The client's filename is not reused for DICOM uploads.
        stored_name = "anon.dcm" if is_dicom else file.filename
        image_id = storage_manager.save_image(
            username=current_user.username,
            file_bytes=content,
            filename_hint=stored_name
        )
        return {
            "image_id": image_id,
            "status": "UPLOADED",
            "message": "Image secured & sanitized. Ready for analysis."
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Upload failed: {e}")
        raise HTTPException(status_code=500, detail=f"Upload Error: {str(e)}")
@app.post("/analyze", status_code=status.HTTP_202_ACCEPTED)
async def analyze_image(
    request: AnalysisRequest,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_active_user)
):
    """
    Step 2 of the analysis flow: queue an analysis job for an uploaded image.

    A PENDING job row is persisted and ``process_analysis_job`` is scheduled as
    a background task with the job ID, image ID and username (no image bytes).

    Raises:
        HTTPException: 503 when the model is not loaded; 404 when the image
            cannot be resolved for this user.
    """
    if not (model_wrapper and model_wrapper.loaded):
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Any lookup failure, or an empty path, is reported as "not found".
    try:
        image_path = storage_manager.get_image_absolute_path(current_user.username, request.image_id)
    except Exception:
        image_path = None
    if not image_path:
        raise HTTPException(status_code=404, detail="Image ID not found. Upload first.")

    task_id = str(uuid.uuid4())

    # Persist the job before enqueueing so the worker can find it.
    database.create_job({
        'id': task_id,
        'status': JobStatus.PENDING.value,
        'created_at': time.time(),
        'result': None,
        'error': None,
        'storage_path': request.image_id,  # Link to storage
        'username': current_user.username,
        'file_type': 'Unknown'
    })

    background_tasks.add_task(process_analysis_job, task_id, request.image_id, current_user.username)

    return {
        "task_id": task_id,
        "status": "queued",
        "image_id": request.image_id
    }
@app.get("/job/current")
async def get_current_job(current_user: User = Depends(get_current_active_user)):
    """
    Return the user's latest job so the UI can restore its state after a refresh.

    Raises:
        HTTPException: 404 when the user has no job, or when the latest job was
            created more than 24 hours ago.
    """
    max_age_seconds = 86400  # 24 hours

    job = database.get_latest_job(current_user.username)
    if not job:
        raise HTTPException(status_code=404, detail="No active job")

    # A missing timestamp counts as epoch 0, i.e. always expired.
    created_at = job.get('created_at', 0)
    job_age = time.time() - created_at
    if job_age > max_age_seconds:
        raise HTTPException(status_code=404, detail="Job expired")

    return {
        "task_id": job['id'],
        "status": job['status'],
        "result": job['result'],
        "error": job.get('error'),
        "created_at": created_at,
        "image_id": job.get('storage_path')
    }
@app.get("/result/{task_id}")
async def get_result(task_id: str, current_user: User = Depends(get_current_user)):
    """
    Return the stored job (status and result) for ``task_id``.

    The lookup is scoped to the authenticated user's username, so a job owned
    by someone else is indistinguishable from a missing one.

    Raises:
        HTTPException: 404 when no job is returned for this user.
    """
    job = database.get_job(task_id, username=current_user.username)
    if job:
        logger.info(f"Polling Job {task_id}: Status={job.get('status')}")
        return job

    # Missing and foreign jobs share one response on purpose.
    raise HTTPException(status_code=404, detail="Job not found or access denied")
@app.get("/health")
def health_check():
    """Report that the service is running and whether the model is loaded."""
    if model_wrapper:
        loaded = model_wrapper.loaded
    else:
        # No wrapper yet (startup not finished): report the model as not loaded.
        loaded = False

    return {
        "status": "running",
        "model_loaded": loaded,
        "version": "2.0.0"
    }
@app.get("/", include_in_schema=False)
async def root():
    """Send visitors of the bare root URL to the interactive API docs."""
    from fastapi.responses import RedirectResponse

    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
# --- DASHBOARD ENDPOINTS ---
@app.get("/api/dashboard/stats")
async def get_dashboard_stats_endpoint(current_user: User = Depends(get_current_user)):
    """
    Return dashboard statistics plus up to 5 recent analyses for the user.

    Any failure is logged and converted into a 500 response carrying the
    error text.
    """
    # NOTE(review): GET /api/dashboard/stats is already registered earlier in
    # this file (get_dashboard_statistics). Routes are matched in registration
    # order, so this handler is presumably never dispatched — consider removing
    # the duplicate registration.
    owner = current_user.username
    try:
        stats = database.get_dashboard_stats(owner)
        recent = database.get_recent_analyses(owner, limit=5)

        payload = {**stats}
        payload["recent_analyses"] = recent
        return payload
    except Exception as e:
        logger.error(f"Error fetching dashboard stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# =========================================================================
# PATIENT API (New for Migration)
# =========================================================================
class PatientCreate(BaseModel):
    """Request body for creating a patient (POST /patients)."""

    # Caller-supplied patient identifier; the only required field.
    patient_id: str
    # Descriptive fields default to empty strings when omitted.
    first_name: str = ""
    last_name: str = ""
    birth_date: str = ""
    # Base64 or URL
    photo: Optional[str] = None
class PatientUpdate(BaseModel):
    """Partial-update body for a patient (PUT /patients/{patient_id})."""

    # Every field is optional: the update handler forwards only the fields
    # the client actually sent.
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    birth_date: Optional[str] = None
    photo: Optional[str] = None
@app.post("/patients", status_code=status.HTTP_201_CREATED)
async def create_patient_endpoint(patient: PatientCreate, current_user: User = Depends(get_current_user)):
    """
    Create a patient owned by the authenticated user.

    Raises:
        HTTPException: 400 when the database layer returns no ID.
    """
    new_id = database.create_patient(
        owner_username=current_user.username,
        patient_id=patient.patient_id,
        first_name=patient.first_name,
        last_name=patient.last_name,
        birth_date=patient.birth_date,
        photo=patient.photo
    )
    if new_id:
        return {"id": new_id, "message": "Patient created"}
    raise HTTPException(status_code=400, detail="Could not create patient (ID might exist)")
@app.get("/patients")
async def get_patients_endpoint(current_user: User = Depends(get_current_user)):
    """Return the patients stored for the authenticated user."""
    owner = current_user.username
    return database.get_patients_by_user(owner)
@app.put("/patients/{patient_id}")
async def update_patient_endpoint(patient_id: int, updates: PatientUpdate, current_user: User = Depends(get_current_user)):
    """
    Apply a partial update to one of the authenticated user's patients.

    Only fields explicitly present in the request body are forwarded.

    Raises:
        HTTPException: 400 when the body sets no field; 404 when the database
            layer reports that nothing was updated.
    """
    # NOTE(review): .dict() is the Pydantic v1 spelling — confirm the pinned
    # Pydantic version before changing it.
    changes = updates.dict(exclude_unset=True)
    if not changes:
        raise HTTPException(status_code=400, detail="No data to update")

    updated = database.update_patient(current_user.username, patient_id, changes)
    if updated:
        return {"message": "Patient updated"}
    raise HTTPException(status_code=404, detail="Patient not found")
@app.delete("/patients/{patient_id}")
async def delete_patient_endpoint(patient_id: int, current_user: User = Depends(get_current_user)):
    """
    Delete one of the authenticated user's patients.

    Raises:
        HTTPException: 404 when the database layer reports nothing was deleted.
    """
    deleted = database.delete_patient(current_user.username, patient_id)
    if deleted:
        return {"message": "Patient deleted"}
    raise HTTPException(status_code=404, detail="Patient not found")
@app.get("/api/dashboard/stats")
async def get_dashboard_stats_endpoint(current_user: User = Depends(get_current_user)):
    """
    Return dashboard statistics with up to 10 recent analyses attached.

    The ``recent_analyses`` entry is written into the stats mapping returned by
    the database layer, and that same object is returned.
    """
    # NOTE(review): this is a further registration of GET /api/dashboard/stats
    # and it reuses the name of an earlier handler in this file. Routes are
    # matched in registration order, so this handler is presumably never
    # dispatched — consider removing the duplicates.
    owner = current_user.username
    stats = database.get_dashboard_stats(owner)
    stats["recent_analyses"] = database.get_recent_analyses(owner, limit=10)
    return stats
# =========================================================================
# MAIN ENTRY POINT
# =========================================================================
if __name__ == "__main__":
    # Initialize DB tables including registry before the server starts.
    database.init_db()
    database.init_analysis_registry()

    # Bind address and port come from the environment, with local defaults.
    bind_host = os.getenv("SERVER_HOST", "0.0.0.0")
    bind_port = int(os.getenv("SERVER_PORT", "8022"))
    uvicorn.run(app, host=bind_host, port=bind_port)