Spaces:
Sleeping
Sleeping
| import os | |
| import cv2 | |
| import numpy as np | |
| import mediapipe as mp | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from scipy import signal | |
| from scipy.signal import find_peaks | |
| import tempfile | |
print("[Init] Loading MediaPipe...", flush=True)

# One module-level FaceMesh instance, shared by every call to
# extract_rppg_signal. Options: video mode, a single face, refined landmarks,
# and 0.5 detection/tracking confidence thresholds.
mp_face_mesh = mp.solutions.face_mesh
_FACE_MESH_SETTINGS = {
    "static_image_mode": False,
    "max_num_faces": 1,
    "refine_landmarks": True,
    "min_detection_confidence": 0.5,
    "min_tracking_confidence": 0.5,
}
face_mesh = mp_face_mesh.FaceMesh(**_FACE_MESH_SETTINGS)

print("[Init] MediaPipe OK", flush=True)
app = FastAPI(title="AF Detector API")

# Fully permissive CORS: any origin, any method and any header may call the API.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# ── Signal Processing ─────────────────────────────────────
def extract_rppg_signal(frames, fps=30):
    """Extract a raw rPPG trace (mean green intensity per frame) from video frames.

    For every frame, MediaPipe Face Mesh is run and the green channel is
    averaged over the axis-aligned bounding box of six landmarks (two on the
    forehead, two on each cheek). Because it is a bounding box, it also
    contains whatever lies between those landmarks.

    When no face is found, or the clipped box is empty, the whole-frame green
    mean is used instead. Either way the output has exactly ONE sample per
    input frame. This matters because callers convert sample indices to time
    with the frame rate, so a dropped sample would shift every later beat.

    Args:
        frames: sequence of BGR images (OpenCV channel order; index 1 is green).
        fps: frame rate of ``frames``. Not used by this function; kept for
            interface compatibility.

    Returns:
        (signal, face_ratio): a 1-D numpy array with one value per frame, and
        the fraction of frames in which a usable face ROI was found
        (0 when ``frames`` is empty).
    """
    # Face-mesh landmark indices: forehead (10, 151), left cheek (234, 93),
    # right cheek (454, 323).
    roi_landmarks = (10, 151, 234, 93, 454, 323)

    green_signal = []
    valid_frames = 0
    for frame in frames:
        h, w = frame.shape[:2]
        g_mean = None

        result = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if result.multi_face_landmarks:
            lm = result.multi_face_landmarks[0].landmark
            xs = [int(lm[i].x * w) for i in roi_landmarks]
            ys = [int(lm[i].y * h) for i in roi_landmarks]
            # Landmarks can fall outside the image; clip the box to the frame.
            x1, x2 = max(0, min(xs)), min(w, max(xs))
            y1, y2 = max(0, min(ys)), min(h, max(ys))
            if x2 > x1 and y2 > y1:
                # Green channel: the most pulse-sensitive of the three.
                g_mean = np.mean(frame[y1:y2, x1:x2, 1])
                valid_frames += 1

        if g_mean is None:
            # No usable face ROI: fall back to the whole frame so the trace
            # stays uniformly sampled.
            g_mean = np.mean(frame[:, :, 1])
        green_signal.append(g_mean)

    face_ratio = valid_frames / len(frames) if len(frames) else 0
    return np.array(green_signal), face_ratio
def bandpass_filter(signal_data, fs, lowcut=0.7, highcut=4.0, order=4):
    """Zero-phase Butterworth band-pass filter.

    With the defaults the pass band is 0.7-4.0 Hz, i.e. 42-240 BPM.

    Args:
        signal_data: 1-D array of samples.
        fs: sampling rate in Hz.
        lowcut: lower pass-band edge in Hz.
        highcut: upper pass-band edge in Hz. It is clamped to 95% of the
            Nyquist frequency (fs / 2), so low sampling rates still get a
            valid (narrower) filter instead of an exception.
        order: Butterworth order passed to ``scipy.signal.butter``.

    Returns:
        The filtered array, same length as ``signal_data``.

    Raises:
        ValueError: if no valid band remains, i.e. ``lowcut`` is not strictly
            between 0 and the clamped upper edge (e.g. ``fs`` is too small).
    """
    nyq = fs / 2.0
    # butter() requires 0 < Wn < 1 (normalised), so keep the top edge under Nyquist.
    high_hz = min(highcut, 0.95 * nyq)
    if not 0 < lowcut < high_hz:
        raise ValueError(
            f"Cannot band-pass {lowcut}-{highcut} Hz at fs={fs} Hz "
            f"(Nyquist {nyq} Hz)."
        )
    # Second-order sections avoid the numerical problems of the (b, a) form
    # for band-pass designs; sosfiltfilt keeps the filtering zero-phase.
    sos = signal.butter(order, [lowcut, high_hz], btype="band", fs=fs, output="sos")
    return signal.sosfiltfilt(sos, signal_data)
def detect_peaks_rr(filtered_signal, fps):
    """Return the sample indices of heartbeat peaks in a band-passed rPPG trace.

    A peak must be at least 350 ms after the previous one, which caps
    detection at about 171 BPM. It must also exceed 0.3 x the signal's
    standard deviation.

    Args:
        filtered_signal: 1-D signal, typically the output of ``bandpass_filter``.
        fps: sampling rate of the signal in Hz.

    Returns:
        numpy array of peak indices in ascending order.
    """
    # find_peaks rejects distance < 1, which int(fps * 0.35) yields for fps < ~2.9.
    min_distance = max(1, int(fps * 0.35))
    threshold = np.std(filtered_signal) * 0.3
    peaks, _ = find_peaks(filtered_signal, distance=min_distance, height=threshold)
    return peaks
def compute_hrv_metrics(rr_intervals_ms):
    """Summarise RR intervals with time-domain HRV metrics used for AF screening.

    Args:
        rr_intervals_ms: sequence of RR intervals in milliseconds.

    Returns:
        None if fewer than 5 intervals are given. Otherwise a dict with:
          - bpm: round(60000 / mean RR), clipped to 30-250, as int
          - mean_rr: mean RR (ms)
          - sdnn: population standard deviation of RR (ms)
          - rmssd: root mean square of successive differences (ms)
          - pnn50: successive differences > 50 ms, as a percentage of the
            number of intervals
          - cv: sdnn / mean_rr, in percent
          - irr_index: mean |successive difference| as a percentage of mean RR,
            capped at 100 and rounded (a simple ratio, not an entropy measure)
          - rr_count: number of intervals supplied
    """
    if len(rr_intervals_ms) < 5:
        return None

    intervals = np.array(rr_intervals_ms)
    successive = np.diff(intervals)
    abs_successive = np.abs(successive)

    mean_rr = np.mean(intervals)
    sdnn = np.std(intervals)
    rmssd = np.sqrt(np.mean(successive ** 2))
    pnn50 = np.sum(abs_successive > 50) / len(intervals) * 100
    cv = (sdnn / mean_rr) * 100
    irr_index = round(min(100, (np.mean(abs_successive) / mean_rr) * 100))
    bpm = round(60000 / mean_rr)

    return {
        "bpm": int(np.clip(bpm, 30, 250)),
        "mean_rr": round(float(mean_rr), 1),
        "sdnn": round(float(sdnn), 1),
        "rmssd": round(float(rmssd), 1),
        "pnn50": round(float(pnn50), 1),
        "cv": round(float(cv), 2),
        "irr_index": irr_index,
        "rr_count": len(rr_intervals_ms),
    }
def compute_af_score(metrics):
    """Turn HRV metrics into a 0-100 atrial-fibrillation (AF) risk score.

    Points are added for an abnormal heart rate and for high beat-to-beat
    variability (RMSSD, CV, pNN50, irregularity index). The total is capped at
    100 and mapped to a class:
      - score < 25  -> NORMAL / LOW
      - score < 50  -> IRREGULAR / MODERATE
      - otherwise   -> AF_SUSPECTED / HIGH

    NOTE(review): all thresholds and weights are hard-coded heuristics; this
    file does not establish their clinical validity.

    Args:
        metrics: dict with at least "bpm", "rmssd", "cv", "pnn50" and
            "irr_index", as produced by ``compute_hrv_metrics``.

    Returns:
        dict with "af_score" (int), "result", "label", "risk" and "reasons"
        (a list of French, human-readable explanations for the larger
        contributions to the score).
    """
    bpm = metrics["bpm"]
    rmssd = metrics["rmssd"]
    cv = metrics["cv"]
    pnn50 = metrics["pnn50"]
    irr = metrics["irr_index"]

    score = 0
    reasons = []

    # Heart rate outside the 50-100 BPM range.
    if bpm < 50:
        score += 15
        reasons.append(f"Bradycardie ({bpm} BPM)")
    elif bpm > 100:
        score += 20
        reasons.append(f"Tachycardie ({bpm} BPM)")

    # RMSSD: short-term variability; high values indicate irregular RR intervals.
    if rmssd > 100:
        score += 30
        reasons.append(f"RMSSD très élevé ({rmssd}ms)")
    elif rmssd > 60:
        score += 18
        reasons.append(f"RMSSD élevé ({rmssd}ms)")
    elif rmssd > 40:
        score += 8

    # Coefficient of variation of the RR intervals (%).
    if cv > 15:
        score += 25
        reasons.append(f"Variabilité RR critique (CV={cv}%)")
    elif cv > 10:
        score += 15
        reasons.append(f"Variabilité RR élevée (CV={cv}%)")
    elif cv > 6:
        score += 6

    # pNN50 (%).
    if pnn50 > 40:
        score += 15
        reasons.append(f"pNN50 élevé ({pnn50}%)")
    elif pnn50 > 20:
        score += 8

    # Irregularity index (%).
    if irr > 25:
        score += 10
        reasons.append(f"Irrégularité marquée ({irr}%)")

    score = int(min(100, score))

    if score < 25:
        result, label, risk = "NORMAL", "Normal Sinus Rhythm", "LOW"
    elif score < 50:
        result, label, risk = "IRREGULAR", "Irregular Pattern Detected", "MODERATE"
    else:
        result, label, risk = "AF_SUSPECTED", "Atrial Fibrillation Suspected", "HIGH"

    return {
        "af_score": score,
        "result": result,
        "label": label,
        "risk": risk,
        "reasons": reasons,
    }
# ── API Endpoints ─────────────────────────────────────────
def health():
    """Return a static status payload identifying the service and its version.

    NOTE(review): no route decorator is visible on this function here; confirm
    it is registered on ``app``.
    """
    payload = dict(status="ok", service="AF Detector API v1.0")
    return payload
def _read_video_frames(path, fallback_fps, target_fps=30.0):
    """Decode a video file into BGR frames, keeping roughly ``target_fps`` per second.

    Args:
        path: path of a video file OpenCV can open.
        fallback_fps: frame rate to assume when the container reports an
            unusable one (0, negative or NaN).
        target_fps: every k-th decoded frame is kept, with
            k = max(1, int(source_fps / target_fps)).

    Returns:
        None if OpenCV cannot open ``path``. Otherwise ``(frames, sample_fps)``,
        where ``sample_fps = source_fps / k`` is the rate of the KEPT frames;
        all later timing must use it.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            return None

        source_fps = cap.get(cv2.CAP_PROP_FPS)
        if not (np.isfinite(source_fps) and source_fps > 0):
            source_fps = fallback_fps
        step = max(1, int(source_fps / target_fps))

        frames = []
        index = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # Subsample on the decoded-frame index, not len(frames). Testing
            # len(frames) stalls after the first kept frame whenever step > 1.
            if index % step == 0:
                frames.append(frame)
            index += 1
        return frames, source_fps / step
    finally:
        cap.release()


async def analyze_video(video_file: UploadFile = File(...), fps: float = 30.0):
    """Analyse an uploaded face video for signs of atrial fibrillation via rPPG.

    Pipeline:
      1. Decode the video, subsampled to about 30 fps.
      2. Extract a green-channel rPPG trace from the face.
      3. Detrend, normalise and band-pass the trace.
      4. Detect peaks and convert peak spacing to RR intervals.
      5. Compute HRV metrics and an AF risk score from the RR intervals.

    NOTE(review): no route decorator is visible on this function here; confirm
    it is registered on ``app``.

    Args:
        video_file: uploaded video of the user's face.
        fps: frame rate to assume when the container does not report a usable one.

    Returns:
        On success, a dict with "success": True, capture statistics ("frames",
        "fps" = rate of the analysed frames, "face_ratio", "peaks_count"), the
        last 30 RR intervals in ms, the HRV metrics, the AF score fields and a
        disclaimer. Processing failures are returned as ``{"error": ...}``
        (sometimes with a diagnostic field) rather than raised.
    """
    # OpenCV decodes from a path, so spool the upload to a temporary file.
    suffix = os.path.splitext(video_file.filename or "")[-1] or ".mp4"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(await video_file.read())
        tmp_path = tmp.name

    try:
        decoded = _read_video_frames(tmp_path, fps)
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}
    finally:
        # Delete the upload on every path, including "cannot open".
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    if decoded is None:
        return {"error": "Cannot open video file"}
    frames, sample_fps = decoded

    try:
        # NOTE(review): the check is 60 frames (~2 s at 30 fps) while the
        # message says 30 s; confirm which one is intended.
        if len(frames) < 60:
            return {"error": "Video too short. Minimum 30 seconds required.", "frames": len(frames)}
        print(f"[Analyze] Frames: {len(frames)} | FPS: {sample_fps:.1f}", flush=True)

        # rPPG extraction
        green_signal, face_ratio = extract_rppg_signal(frames, sample_fps)
        print(f"[Analyze] Face detected: {face_ratio:.1%}", flush=True)
        if face_ratio < 0.3:
            return {"error": "Face not detected in most frames. Ensure good lighting.", "face_ratio": face_ratio}

        # Remove slow (illumination) drift, standardise, then keep 0.7-4 Hz.
        detrended = signal.detrend(green_signal)
        detrended = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
        filtered = bandpass_filter(detrended, sample_fps)

        # Peak detection
        peaks = detect_peaks_rr(filtered, sample_fps)
        print(f"[Analyze] Peaks detected: {len(peaks)}", flush=True)
        if len(peaks) < 8:
            return {"error": "Signal too noisy. Stay still and ensure good lighting.", "peaks": len(peaks)}

        # Peak spacing (samples) -> RR intervals (ms). Keep only intervals in
        # 300-2000 ms, i.e. 30-200 BPM.
        rr_intervals = [rr for rr in np.diff(peaks) / sample_fps * 1000 if 300 < rr < 2000]
        if len(rr_intervals) < 5:
            return {"error": "Not enough valid beats detected."}

        # HRV metrics and AF score
        metrics = compute_hrv_metrics(rr_intervals)
        if not metrics:
            return {"error": "Cannot compute HRV metrics."}
        af_data = compute_af_score(metrics)
        print(f"[Analyze] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | CV={metrics['cv']} | AF_Score={af_data['af_score']}", flush=True)

        return {
            "success": True,
            "face_ratio": round(face_ratio, 2),
            "frames": len(frames),
            # Rate of the analysed (possibly subsampled) frames used for all timing.
            "fps": round(sample_fps, 1),
            "peaks_count": len(peaks),
            "rr_intervals": [round(rr, 1) for rr in rr_intervals[-30:]],  # last 30
            **metrics,
            **af_data,
            "disclaimer": "Experimental AI tool. Not a medical diagnosis. Consult a cardiologist."
        }
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}
async def analyze_frames_json(data: dict):
    """Analyse a green-channel trace computed client-side (no video upload).

    This is faster than uploading a video because nothing needs to be encoded
    or decoded server-side.

    Args:
        data: JSON body with "green_signal" (list of numbers, at least 60
            samples) and optional "fps" (sampling rate in Hz, default 30.0).

    Returns:
        On success, ``{"success": True, "rr_intervals": [...], **hrv_metrics,
        **af_score_fields}``. Otherwise ``{"error": ...}`` (sometimes with a
        diagnostic field); invalid input is reported the same way instead of
        raising.
    """
    try:
        # Parse inside the try so malformed input yields an error payload, not a 500.
        green_signal = np.asarray(data.get("green_signal", []), dtype=float)
        fps = float(data.get("fps", 30.0))
        if len(green_signal) < 60:
            return {"error": "Signal too short. Minimum 60 samples required."}

        detrended = signal.detrend(green_signal)
        detrended = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
        filtered = bandpass_filter(detrended, fps)
        peaks = detect_peaks_rr(filtered, fps)
        if len(peaks) < 5:
            return {"error": "Signal too noisy. Stay still.", "peaks": len(peaks)}

        # Peak spacing (samples) -> RR intervals (ms), keeping 300-2000 ms.
        rr_intervals = [rr for rr in np.diff(peaks) / fps * 1000 if 300 < rr < 2000]
        # compute_hrv_metrics needs at least 5 intervals and returns None below
        # that, which would otherwise crash compute_af_score.
        if len(rr_intervals) < 5:
            return {"error": "Not enough valid beats."}

        metrics = compute_hrv_metrics(rr_intervals)
        af_data = compute_af_score(metrics)
        print(f"[Frames] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | AF={af_data['af_score']}", flush=True)
        return {
            "success": True,
            "rr_intervals": [round(rr, 1) for rr in rr_intervals],
            **metrics,
            **af_data,
        }
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}