"""AF Detector API.

FastAPI service that screens for atrial fibrillation (AF) using remote
photoplethysmography (rPPG). A pulse trace is taken from the green channel of
a face video, or received pre-extracted from the frontend. Its beat-to-beat
(RR) intervals are measured, and time-domain HRV metrics are turned into an
AF risk score.
"""
import os
import tempfile

import cv2
import numpy as np
import mediapipe as mp
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from scipy import signal
from scipy.signal import find_peaks

print("[Init] Loading MediaPipe...", flush=True)
mp_face_mesh = mp.solutions.face_mesh
# One FaceMesh instance is shared by all requests. static_image_mode=False
# selects MediaPipe's video (tracking) mode.
# NOTE(review): tracking state presumably carries over between consecutive
# process() calls, including across different uploads - confirm.
face_mesh = mp_face_mesh.FaceMesh(
    static_image_mode=False,
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)
print("[Init] MediaPipe OK", flush=True)

app = FastAPI(title="AF Detector API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── Signal Processing ─────────────────────────────────────

# FaceMesh landmark indices that span the ROI: forehead (10, 151),
# left cheek (234, 93) and right cheek (454, 323).
ROI_LANDMARKS = (10, 151, 234, 93, 454, 323)

# Videos recorded faster than this are decimated towards it before analysis.
TARGET_FPS = 30

# RR intervals outside (MIN_RR_MS, MAX_RR_MS) are discarded as aberrant.
# That range is roughly 30-200 BPM.
MIN_RR_MS = 300
MAX_RR_MS = 2000


def _face_roi_green_mean(frame):
    """Return the mean green value over the face ROI of one BGR frame.

    Returns None when no face is detected or the ROI box is empty.
    """
    result = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if not result.multi_face_landmarks:
        return None
    lm = result.multi_face_landmarks[0].landmark
    h, w = frame.shape[:2]
    # Scale landmark coordinates by the frame size to get pixel positions.
    xs = [int(lm[i].x * w) for i in ROI_LANDMARKS]
    ys = [int(lm[i].y * h) for i in ROI_LANDMARKS]
    # The ROI is the axis-aligned bounding box of those points, clipped to
    # the frame. It therefore likely also covers the eyes and nose between
    # the forehead and cheeks.
    x1, x2 = max(0, min(xs)), min(w, max(xs))
    y1, y2 = max(0, min(ys)), min(h, max(ys))
    if x2 <= x1 or y2 <= y1:
        return None
    # Index 1 is green in OpenCV's BGR layout. Green is the channel most
    # sensitive to the pulse.
    return np.mean(frame[y1:y2, x1:x2, 1])


def extract_rppg_signal(frames, fps=30):
    """Extract the raw rPPG (green-channel) trace from video frames.

    For each frame, the face is located with MediaPipe FaceMesh. The green
    channel is then averaged over a box spanned by forehead and cheek
    landmarks.

    Args:
        frames: Iterable of BGR frames, as returned by
            ``cv2.VideoCapture.read``. It may be a list or a generator and is
            consumed exactly once.
        fps: Unused; kept for API compatibility.

    Returns:
        A ``(green_signal, face_ratio)`` tuple. ``green_signal`` is a 1-D
        array with exactly one sample per input frame. ``face_ratio`` is the
        fraction of frames in which a usable face ROI was found (0 when
        ``frames`` is empty).
    """
    green_signal = []
    valid_frames = 0
    total_frames = 0
    for frame in frames:
        total_frames += 1
        g_mean = _face_roi_green_mean(frame)
        if g_mean is None:
            # No face, or an empty ROI: fall back to the whole frame. This
            # keeps one sample per frame, so the time axis (and the RR
            # intervals derived from it) is not compressed.
            g_mean = np.mean(frame[:, :, 1])
        else:
            valid_frames += 1
        green_signal.append(g_mean)

    face_ratio = valid_frames / total_frames if total_frames else 0
    return np.array(green_signal), face_ratio


def bandpass_filter(signal_data, fs, lowcut=0.7, highcut=4.0, order=4):
    """Apply a zero-phase Butterworth band-pass filter.

    The defaults keep 0.7-4.0 Hz, i.e. 42-240 BPM.

    Args:
        signal_data: 1-D signal to filter.
        fs: Sampling rate in Hz.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz.
        order: Butterworth order. filtfilt applies the filter forward and
            backward.

    Returns:
        The filtered signal, with the same length as ``signal_data``.

    Raises:
        ValueError: If ``fs`` is too low for ``lowcut`` (lowcut at or above
            Nyquist), or if the signal is too short for filtfilt's padding.
    """
    nyq = fs / 2.0
    low = lowcut / nyq
    # butter() requires normalised edges strictly inside (0, 1). At low
    # sampling rates (fs <= 2 * highcut), clamp the upper edge just below
    # Nyquist instead of failing.
    high = min(highcut / nyq, 0.99)
    if not 0 < low < high:
        raise ValueError(
            f"Sampling rate {fs} Hz is too low for a {lowcut} Hz low cut-off"
        )
    b, a = signal.butter(order, [low, high], btype="band")
    return signal.filtfilt(b, a, signal_data)


def detect_peaks_rr(filtered_signal, fps):
    """Detect heartbeat peaks in the band-passed rPPG signal.

    Args:
        filtered_signal: Band-passed 1-D signal.
        fps: Sampling rate of ``filtered_signal`` in Hz.

    Returns:
        Sample indices of the detected peaks. The caller derives RR intervals
        from consecutive indices.
    """
    # Peaks must be at least 350 ms apart, which caps detection at about
    # 171 BPM. The distance is at least 1 sample, because find_peaks rejects
    # distance < 1 (which int() would produce at very low fps).
    min_distance = max(1, int(fps * 0.35))
    # Ignore small bumps: a peak must reach 0.3 standard deviations.
    threshold = np.std(filtered_signal) * 0.3
    peaks, _ = find_peaks(filtered_signal, distance=min_distance, height=threshold)
    return peaks


def _filter_green_signal(green_signal, fps):
    """Detrend, z-score normalise and band-pass a raw green-channel trace."""
    # Remove the slow linear drift (e.g. from illumination) first.
    detrended = signal.detrend(green_signal)
    # Normalise to zero mean and unit variance. The epsilon avoids 0/0 on a
    # flat input.
    normalized = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
    return bandpass_filter(normalized, fps)


def _peaks_to_rr_ms(peaks, fps):
    """Convert peak sample indices to RR intervals in milliseconds.

    Only intervals with MIN_RR_MS < rr < MAX_RR_MS are kept.
    """
    # NOTE: peak times are whole samples, so each interval is quantised to
    # 1000 / fps ms (~33 ms at 30 fps). That is coarse next to the 50 ms
    # pNN50 threshold.
    rr = np.diff(peaks) / fps * 1000
    return [float(x) for x in rr if MIN_RR_MS < x < MAX_RR_MS]


def compute_hrv_metrics(rr_intervals_ms):
    """Compute the time-domain HRV metrics used to screen for AF.

    Args:
        rr_intervals_ms: Sequence of RR intervals in milliseconds.

    Returns:
        A dict with ``bpm`` (clipped to 30-250), ``mean_rr``, ``sdnn``,
        ``rmssd``, ``pnn50``, ``cv``, ``irr_index`` and ``rr_count``.
        Returns None when fewer than 5 intervals are given or their mean is
        not positive.
    """
    if len(rr_intervals_ms) < 5:
        return None

    rr = np.asarray(rr_intervals_ms, dtype=float)
    mean_rr = np.mean(rr)
    if not mean_rr > 0:
        return None
    successive = np.diff(rr)

    sdnn = np.std(rr)                           # overall variability
    rmssd = np.sqrt(np.mean(successive ** 2))   # short-term variability
    # pNN50: successive differences > 50 ms, as a percentage of the number
    # of RR intervals.
    pnn50 = np.sum(np.abs(successive) > 50) / len(rr) * 100

    bpm = round(60000 / mean_rr)

    # Coefficient of variation (%), a key irregularity marker for AF.
    cv = (sdnn / mean_rr) * 100

    # Irregularity index: mean absolute successive difference as a
    # percentage of the mean RR, capped at 100. This is not an entropy
    # measure.
    irr_index = round(min(100, (np.mean(np.abs(successive)) / mean_rr) * 100))

    return {
        "bpm": int(np.clip(bpm, 30, 250)),
        "mean_rr": round(float(mean_rr), 1),
        "sdnn": round(float(sdnn), 1),
        "rmssd": round(float(rmssd), 1),
        "pnn50": round(float(pnn50), 1),
        "cv": round(float(cv), 2),
        "irr_index": irr_index,
        "rr_count": len(rr_intervals_ms),
    }


def compute_af_score(metrics):
    """Compute an AF risk score (0-100) from HRV metrics.

    Clinical rationale: in AF there is no regular P wave, so RR intervals
    become irregular. This shows up as:
      - high RMSSD,
      - high CV (> 10 %),
      - high pNN50,
      - no stable rhythm.
    Each criterion adds points, and the total is capped at 100.

    Args:
        metrics: Dict as returned by ``compute_hrv_metrics``.

    Returns:
        A dict with:
          - ``af_score``;
          - ``result`` (NORMAL / IRREGULAR / AF_SUSPECTED);
          - ``label``;
          - ``risk`` (LOW / MODERATE / HIGH);
          - ``reasons`` (human-readable strings, in French).
    """
    score = 0
    reasons = []

    bpm = metrics["bpm"]
    rmssd = metrics["rmssd"]
    cv = metrics["cv"]
    pnn50 = metrics["pnn50"]
    irr = metrics["irr_index"]

    # Abnormal heart rate.
    if bpm < 50:
        score += 15
        reasons.append(f"Bradycardie ({bpm} BPM)")
    elif bpm > 100:
        score += 20
        reasons.append(f"Tachycardie ({bpm} BPM)")

    # RMSSD: high short-term variability indicates irregularity.
    if rmssd > 100:
        score += 30
        reasons.append(f"RMSSD très élevé ({rmssd}ms)")
    elif rmssd > 60:
        score += 18
        reasons.append(f"RMSSD élevé ({rmssd}ms)")
    elif rmssd > 40:
        score += 8

    # Coefficient of variation.
    if cv > 15:
        score += 25
        reasons.append(f"Variabilité RR critique (CV={cv}%)")
    elif cv > 10:
        score += 15
        reasons.append(f"Variabilité RR élevée (CV={cv}%)")
    elif cv > 6:
        score += 6

    # pNN50.
    if pnn50 > 40:
        score += 15
        reasons.append(f"pNN50 élevé ({pnn50}%)")
    elif pnn50 > 20:
        score += 8

    # Irregularity index.
    if irr > 25:
        score += 10
        reasons.append(f"Irrégularité marquée ({irr}%)")

    score = int(min(100, score))

    # Map the score to a classification.
    if score < 25:
        result, label, risk = "NORMAL", "Normal Sinus Rhythm", "LOW"
    elif score < 50:
        result, label, risk = "IRREGULAR", "Irregular Pattern Detected", "MODERATE"
    else:
        result, label, risk = "AF_SUSPECTED", "Atrial Fibrillation Suspected", "HIGH"

    return {
        "af_score": score,
        "result": result,
        "label": label,
        "risk": risk,
        "reasons": reasons,
    }


# ── API Endpoints ─────────────────────────────────────────

@app.get("/health")
def health():
    """Report that the service is up."""
    return {"status": "ok", "service": "AF Detector API v1.0"}


def _iter_frames(cap, step):
    """Yield every ``step``-th frame (0, step, 2*step, ...) read from ``cap``."""
    index = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            return
        # Decimate on the frame's position in the video, not on how many
        # frames have been kept so far.
        if index % step == 0:
            yield frame
        index += 1


def _analyze_video_path(path, fps):
    """Run the rPPG -> RR -> HRV -> AF pipeline on a video file on disk.

    Args:
        path: Path of the video file.
        fps: Frame rate to assume when the container does not report a
            usable one.

    Returns:
        The response dict for ``/analyze/``: results, or an ``"error"`` key.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            return {"error": "Cannot open video file"}

        real_fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN); use the caller's value instead.
        if not (np.isfinite(real_fps) and real_fps > 0):
            real_fps = fps
        if not (np.isfinite(real_fps) and real_fps > 0):
            return {"error": "Invalid fps."}

        step = max(1, int(real_fps / TARGET_FPS))
        # Sampling rate of the frames actually kept. Every time-based step
        # below must use this rate, not the container rate.
        analysis_fps = real_fps / step

        # Frames are streamed into the extractor rather than buffered.
        # Buffering 30 s of 1080p at 30 fps would take roughly 5-6 GB of RAM.
        green_signal, face_ratio = extract_rppg_signal(
            _iter_frames(cap, step), analysis_fps
        )
    finally:
        cap.release()

    n_frames = len(green_signal)
    if n_frames < 60:
        return {"error": "Video too short. Minimum 60 frames required.",
                "frames": n_frames}

    print(f"[Analyze] Frames: {n_frames} | FPS: {real_fps:.1f} "
          f"(analysed at {analysis_fps:.1f})", flush=True)

    # ── rPPG extraction ──────────────────────────────
    print(f"[Analyze] Face detected: {face_ratio:.1%}", flush=True)
    if face_ratio < 0.3:
        return {"error": "Face not detected in most frames. Ensure good lighting.",
                "face_ratio": face_ratio}

    # ── Detrending + filtering + peak detection ──────
    filtered = _filter_green_signal(green_signal, analysis_fps)
    peaks = detect_peaks_rr(filtered, analysis_fps)
    print(f"[Analyze] Peaks detected: {len(peaks)}", flush=True)
    if len(peaks) < 8:
        return {"error": "Signal too noisy. Stay still and ensure good lighting.",
                "peaks": len(peaks)}

    # ── RR intervals (ms) ────────────────────────────
    rr_intervals = _peaks_to_rr_ms(peaks, analysis_fps)
    if len(rr_intervals) < 5:
        return {"error": "Not enough valid beats detected."}

    # ── HRV metrics ──────────────────────────────────
    metrics = compute_hrv_metrics(rr_intervals)
    if not metrics:
        return {"error": "Cannot compute HRV metrics."}

    # ── AF Score ─────────────────────────────────────
    af_data = compute_af_score(metrics)
    print(f"[Analyze] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | "
          f"CV={metrics['cv']} | AF_Score={af_data['af_score']}", flush=True)

    return {
        "success": True,
        "face_ratio": round(face_ratio, 2),
        "frames": n_frames,
        "fps": round(real_fps, 1),
        "peaks_count": len(peaks),
        "rr_intervals": [round(rr, 1) for rr in rr_intervals[-30:]],  # last 30
        **metrics,
        **af_data,
        "disclaimer": "Experimental AI tool. Not a medical diagnosis. Consult a cardiologist.",
    }


@app.post("/analyze/")
async def analyze_video(video_file: UploadFile = File(...), fps: float = 30.0):
    """Analyse an uploaded face video (about 30 s) for AF via rPPG.

    Args:
        video_file: Uploaded video file.
        fps: Frame rate to assume when the container does not report one.

    Returns:
        HRV metrics, the AF score and a disclaimer on success. On failure,
        a dict with an ``"error"`` key (still HTTP 200).
    """
    # The upload is copied to a temporary file so cv2.VideoCapture can open
    # it by path.
    suffix = os.path.splitext(video_file.filename or "")[-1] or ".mp4"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(await video_file.read())
        # NOTE(review): the analysis is synchronous CPU work inside an async
        # endpoint. It blocks the event loop while it runs, which also
        # serialises use of the shared face_mesh.
        return _analyze_video_path(tmp_path, fps)
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}
    finally:
        # Runs on every path, including the early error returns.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)


@app.post("/analyze_frames/")
async def analyze_frames_json(data: dict):
    """Analyse a green-channel trace that the frontend has already extracted.

    This is a faster alternative to ``/analyze/``: no video needs to be
    encoded, uploaded or decoded.

    Args:
        data: JSON body with:
            - ``green_signal``: list of per-frame green means, at least
              60 samples;
            - ``fps``: optional sampling rate in Hz (default 30).

    Returns:
        RR intervals, HRV metrics and the AF score on success. On failure,
        a dict with an ``"error"`` key.
    """
    try:
        green_signal = np.asarray(data.get("green_signal", []), dtype=float)
        fps = float(data.get("fps", 30.0))
    except (TypeError, ValueError):
        return {"error": "Invalid green_signal or fps."}
    if green_signal.ndim != 1 or not np.all(np.isfinite(green_signal)):
        return {"error": "green_signal must be a flat list of numbers."}
    if not (np.isfinite(fps) and fps > 0):
        return {"error": "Invalid fps."}

    if len(green_signal) < 60:
        return {"error": "Signal too short. Minimum 60 samples required."}

    try:
        filtered = _filter_green_signal(green_signal, fps)
        peaks = detect_peaks_rr(filtered, fps)

        if len(peaks) < 5:
            return {"error": "Signal too noisy. Stay still.", "peaks": len(peaks)}

        rr_intervals = _peaks_to_rr_ms(peaks, fps)
        # compute_hrv_metrics needs at least 5 intervals. Previously, exactly
        # 4 intervals got past this check and crashed on a None result.
        if len(rr_intervals) < 5:
            return {"error": "Not enough valid beats."}

        metrics = compute_hrv_metrics(rr_intervals)
        if not metrics:
            return {"error": "Cannot compute HRV metrics."}
        af_data = compute_af_score(metrics)

        print(f"[Frames] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | "
              f"AF={af_data['af_score']}", flush=True)

        return {
            "success": True,
            "rr_intervals": [round(rr, 1) for rr in rr_intervals],
            **metrics,
            **af_data,
        }
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}