Spaces:
Sleeping
Sleeping
File size: 12,494 Bytes
bf3a89f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 | import os
import cv2
import numpy as np
import mediapipe as mp
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from scipy import signal
from scipy.signal import find_peaks
import tempfile
# MediaPipe Face Mesh: one tracker built at import time and reused by
# extract_rppg_signal for every request.
print("[Init] Loading MediaPipe...", flush=True)
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,                # only the first face is used downstream
    static_image_mode=False,        # per MediaPipe docs: treat input as a video stream
    refine_landmarks=True,
    min_tracking_confidence=0.5,
    min_detection_confidence=0.5,
)
print("[Init] MediaPipe OK", flush=True)
app = FastAPI(title="AF Detector API")
# Fully open CORS: any origin, method and header may call this API.
# NOTE(review): fine for a public demo; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ── Signal Processing ─────────────────────────────────────
# Face Mesh landmark indices whose bounding box forms the ROI
# (labelled forehead / left cheek / right cheek in the original code).
_ROI_LANDMARKS = (10, 151, 234, 93, 454, 323)


def _face_roi_box(frame, w, h):
    """Return the clipped ROI box (x1, y1, x2, y2) for a BGR frame, or None.

    None means no face was found or the clipped box is empty.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = face_mesh.process(rgb)
    if not result.multi_face_landmarks:
        return None
    lm = result.multi_face_landmarks[0].landmark
    xs = [int(lm[i].x * w) for i in _ROI_LANDMARKS]
    ys = [int(lm[i].y * h) for i in _ROI_LANDMARKS]
    # NOTE(review): the axis-aligned box spanning these points likely also
    # covers eyes/nose, not just forehead and cheeks -- confirm intended ROI.
    x1, x2 = max(0, min(xs)), min(w, max(xs))
    y1, y2 = max(0, min(ys)), min(h, max(ys))
    if x2 > x1 and y2 > y1:
        return x1, y1, x2, y2
    return None


def extract_rppg_signal(frames, fps=30):
    """Extract a raw rPPG trace (mean green intensity per frame) from video frames.

    For each BGR frame, the module-level MediaPipe ``face_mesh`` locates the
    face and the green channel (index 1) is averaged over the landmark ROI box.
    When no usable ROI is found, the whole-frame green mean is used instead.
    Either way every frame yields exactly one sample, so sample index maps
    linearly to time (callers convert peak spacing to ms by dividing by fps).

    Args:
        frames: sequence of BGR images.
        fps: unused; kept for signature compatibility.

    Returns:
        (green_signal, face_ratio): 1-D array with len(frames) samples, and the
        fraction of frames with a valid face ROI (0 for empty input).
    """
    green_signal = []
    valid_frames = 0
    for frame in frames:
        h, w = frame.shape[:2]
        box = _face_roi_box(frame, w, h)
        if box is not None:
            x1, y1, x2, y2 = box
            green_signal.append(np.mean(frame[y1:y2, x1:x2, 1]))
            valid_frames += 1
        else:
            # Fallback keeps the trace aligned with time (previously a
            # degenerate ROI dropped the sample entirely).
            # NOTE(review): switching between ROI and whole-frame means
            # introduces level steps in the trace.
            green_signal.append(np.mean(frame[:, :, 1]))
    face_ratio = valid_frames / len(frames) if len(frames) else 0
    return np.array(green_signal), face_ratio
def bandpass_filter(signal_data, fs, lowcut=0.7, highcut=4.0, order=4):
    """Zero-phase Butterworth band-pass filter.

    With the defaults the pass band is 0.7-4.0 Hz, i.e. 42-240 BPM.

    Args:
        signal_data: 1-D array of samples.
        fs: sampling rate in Hz (must be > 0).
        lowcut: lower band edge in Hz.
        highcut: upper band edge in Hz; clamped just below Nyquist (fs / 2)
            so low frame rates do not make the filter design fail.
        order: Butterworth prototype order.

    Returns:
        Filtered array with the same length as ``signal_data``.

    Raises:
        ValueError: if ``fs`` is not positive or the resulting band is empty.
    """
    if fs <= 0:
        raise ValueError(f"Sampling rate must be positive, got {fs}")
    nyq = fs / 2.0
    low = lowcut / nyq
    # butter() requires normalized edges strictly inside (0, 1).
    high = min(highcut / nyq, 0.99)
    if not 0 < low < high:
        raise ValueError(f"Empty pass-band {lowcut}-{highcut} Hz at fs={fs} Hz")
    # Second-order sections avoid the numerical sensitivity of (b, a)
    # coefficients for band-pass designs.
    sos = signal.butter(order, [low, high], btype="band", output="sos")
    return signal.sosfiltfilt(sos, signal_data)
def detect_peaks_rr(filtered_signal, fps, min_interval_s=0.35, height_factor=0.3):
    """Locate heartbeat peaks in a band-passed rPPG trace.

    Despite the name, this returns peak sample indices; callers derive RR
    intervals from successive index differences divided by ``fps``.

    Args:
        filtered_signal: 1-D band-passed signal.
        fps: sampling rate in Hz.
        min_interval_s: minimum spacing between peaks in seconds
            (0.35 s caps detection at ~171 BPM).
        height_factor: peaks must exceed ``height_factor * std(signal)``.

    Returns:
        Integer ndarray of peak indices in increasing order.
    """
    # find_peaks rejects distance < 1, which int() yields for fps below ~2.9.
    min_distance = max(1, int(fps * min_interval_s))
    threshold = np.std(filtered_signal) * height_factor
    peaks, _ = find_peaks(filtered_signal, distance=min_distance, height=threshold)
    return peaks
def compute_hrv_metrics(rr_intervals_ms):
    """Compute time-domain HRV metrics from RR intervals (milliseconds).

    Returns:
        None if fewer than 5 intervals are given, otherwise a dict with:
        bpm (int, clipped to 30-250), mean_rr, sdnn (population std, ddof=0),
        rmssd, pnn50 (% of successive differences > 50 ms, divided by the
        number of intervals), cv (SDNN / mean RR, in %), irr_index (int,
        mean |successive difference| as % of mean RR, capped at 100) and
        rr_count.
    """
    if len(rr_intervals_ms) < 5:
        return None
    rr = np.asarray(rr_intervals_ms, dtype=float)
    successive = np.diff(rr)
    abs_successive = np.abs(successive)

    mean_rr = rr.mean()
    sdnn = rr.std()  # overall variability
    rmssd = np.sqrt(np.mean(successive ** 2))  # short-term variability
    # Denominator is len(rr), not the number of differences (len(rr) - 1).
    pnn50 = np.count_nonzero(abs_successive > 50) / len(rr) * 100
    bpm = round(60000 / mean_rr)
    cv = sdnn / mean_rr * 100
    # Irregularity index: mean absolute RR change relative to mean RR
    # (a simple ratio, not an entropy measure). Cast to int for JSON.
    irr_index = int(round(min(100.0, abs_successive.mean() / mean_rr * 100)))

    return {
        "bpm": int(np.clip(bpm, 30, 250)),
        "mean_rr": round(float(mean_rr), 1),
        "sdnn": round(float(sdnn), 1),
        "rmssd": round(float(rmssd), 1),
        "pnn50": round(float(pnn50), 1),
        "cv": round(float(cv), 2),
        "irr_index": irr_index,
        "rr_count": len(rr_intervals_ms),
    }
def compute_af_score(metrics):
    """Heuristic atrial-fibrillation risk score (0-100) from HRV metrics.

    Rationale from the original notes: in AF, the absence of regular P waves
    makes RR intervals irregular. This shows up as high RMSSD, high CV
    (> 10 %), high pNN50 and an irregular pattern. Points are added per
    criterion, the total is capped at 100 and mapped to a class:
    < 25 NORMAL/LOW, < 50 IRREGULAR/MODERATE, otherwise AF_SUSPECTED/HIGH.
    NOTE(review): thresholds are hard-coded heuristics; nothing here shows
    clinical validation.

    Args:
        metrics: dict with "bpm", "rmssd", "cv", "pnn50" and "irr_index"
            (as produced by compute_hrv_metrics).

    Returns:
        dict with "af_score", "result", "label", "risk" and "reasons"
        (messages for the criteria that contributed most).
    """
    score = 0
    reasons = []
    bpm = metrics["bpm"]
    rmssd = metrics["rmssd"]
    cv = metrics["cv"]
    pnn50 = metrics["pnn50"]
    irr = metrics["irr_index"]

    # Abnormal heart rate
    if bpm < 50:
        score += 15
        reasons.append(f"Bradycardie ({bpm} BPM)")
    elif bpm > 100:
        score += 20
        reasons.append(f"Tachycardie ({bpm} BPM)")

    # RMSSD: high short-term variability suggests irregular rhythm
    if rmssd > 100:
        score += 30
        reasons.append(f"RMSSD très élevé ({rmssd}ms)")
    elif rmssd > 60:
        score += 18
        reasons.append(f"RMSSD élevé ({rmssd}ms)")
    elif rmssd > 40:
        score += 8

    # Coefficient of variation of RR intervals
    if cv > 15:
        score += 25
        reasons.append(f"Variabilité RR critique (CV={cv}%)")
    elif cv > 10:
        score += 15
        reasons.append(f"Variabilité RR élevée (CV={cv}%)")
    elif cv > 6:
        score += 6

    # pNN50
    if pnn50 > 40:
        score += 15
        reasons.append(f"pNN50 élevé ({pnn50}%)")
    elif pnn50 > 20:
        score += 8

    # Irregularity index
    if irr > 25:
        score += 10
        reasons.append(f"Irrégularité marquée ({irr}%)")

    score = int(min(100, score))

    # Classification
    if score < 25:
        result, label, risk = "NORMAL", "Normal Sinus Rhythm", "LOW"
    elif score < 50:
        result, label, risk = "IRREGULAR", "Irregular Pattern Detected", "MODERATE"
    else:
        result, label, risk = "AF_SUSPECTED", "Atrial Fibrillation Suspected", "HIGH"

    return {
        "af_score": score,
        "result": result,
        "label": label,
        "risk": risk,
        "reasons": reasons,
    }
# ── API Endpoints ─────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe: report that the service is up, with its name/version."""
    payload = {"status": "ok", "service": "AF Detector API v1.0"}
    return payload
def _read_video_frames(path, fallback_fps, target_fps=30.0):
    """Decode a video and keep every ``step``-th frame so the kept rate is near target_fps.

    Returns (frames, sample_fps), where sample_fps is the rate of the KEPT
    frames -- the rate all downstream timing must use -- or (None, None) if
    the file cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            return None, None
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        # Containers may report 0, NaN or placeholder rates (e.g. a 1000 Hz
        # timebase in some WebM files); fall back to the caller's value then.
        if not np.isfinite(source_fps) or not 0 < source_fps <= 240:
            source_fps = fallback_fps
        step = max(1, int(source_fps / target_fps))
        frames = []
        index = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # Decimate on the decoded-frame counter, not on len(frames).
            if index % step == 0:
                frames.append(frame)
            index += 1
        return frames, source_fps / step
    finally:
        cap.release()


@app.post("/analyze/")
async def analyze_video(video_file: UploadFile = File(...), fps: float = 30.0):
    """Analyse an uploaded face video (~30 s) for AF via rPPG.

    Pipeline:
        decode and decimate frames -> green-channel trace over a face ROI ->
        detrend, z-normalise, band-pass -> peak detection -> RR intervals ->
        HRV metrics -> AF score.

    Args:
        video_file: uploaded video; it is written to a temporary file which
            cv2.VideoCapture then opens. The file is always deleted afterwards.
        fps: frame rate to assume when the container reports no usable one.

    Returns:
        On success, a dict with "success": True, analysis metadata ("fps" is
        the rate of the analysed, decimated frames), the HRV metrics, the AF
        score fields and a disclaimer. On failure, a dict with an "error" key.
    """
    # NOTE(review): decoding and face tracking run synchronously inside this
    # async handler, blocking the event loop for the whole request, and all
    # kept frames are held in memory at full resolution.
    suffix = os.path.splitext(video_file.filename or "")[-1] or ".mp4"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(await video_file.read())

        frames, sample_fps = _read_video_frames(tmp_path, fps)
        if frames is None:
            return {"error": "Cannot open video file"}
        if len(frames) < 60:
            return {"error": "Video too short. Minimum 30 seconds required.", "frames": len(frames)}
        print(f"[Analyze] Frames: {len(frames)} | FPS: {sample_fps:.1f}", flush=True)

        # ── rPPG extraction ──────────────────────────────
        green_signal, face_ratio = extract_rppg_signal(frames, sample_fps)
        print(f"[Analyze] Face detected: {face_ratio:.1%}", flush=True)
        if face_ratio < 0.3:
            return {"error": "Face not detected in most frames. Ensure good lighting.", "face_ratio": face_ratio}

        # ── Detrend + normalise + band-pass ──────────────
        # Remove slow drift (e.g. illumination changes)
        detrended = signal.detrend(green_signal)
        detrended = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
        filtered = bandpass_filter(detrended, sample_fps)

        # ── Peak detection ───────────────────────────────
        peaks = detect_peaks_rr(filtered, sample_fps)
        print(f"[Analyze] Peaks detected: {len(peaks)}", flush=True)
        if len(peaks) < 8:
            return {"error": "Signal too noisy. Stay still and ensure good lighting.", "peaks": len(peaks)}

        # ── RR intervals (ms), dropping implausible ones ──
        rr_intervals = [
            rr for rr in (np.diff(peaks) / sample_fps * 1000.0).tolist()
            if 300 < rr < 2000
        ]
        if len(rr_intervals) < 5:
            return {"error": "Not enough valid beats detected."}

        # ── HRV metrics + AF score ───────────────────────
        metrics = compute_hrv_metrics(rr_intervals)
        if not metrics:
            return {"error": "Cannot compute HRV metrics."}
        af_data = compute_af_score(metrics)
        print(f"[Analyze] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | CV={metrics['cv']} | AF_Score={af_data['af_score']}", flush=True)

        return {
            "success": True,
            "face_ratio": round(face_ratio, 2),
            "frames": len(frames),
            "fps": round(sample_fps, 1),
            "peaks_count": len(peaks),
            "rr_intervals": [round(rr, 1) for rr in rr_intervals[-30:]],  # last 30
            **metrics,
            **af_data,
            "disclaimer": "Experimental AI tool. Not a medical diagnosis. Consult a cardiologist."
        }
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}
    finally:
        # Runs on every path, including early returns.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
@app.post("/analyze_frames/")
async def analyze_frames_json(data: dict):
    """Analyse a green-channel trace computed client-side (no video upload).

    Faster than /analyze/ because no video has to be encoded and decoded.
    The body is read as {"green_signal": [numbers...], "fps": number}, with
    "fps" defaulting to 30.0.

    Returns:
        On success, "success": True, all RR intervals, the HRV metrics and
        the AF score fields. Otherwise a dict with an "error" key.
    """
    try:
        green_signal = np.asarray(data.get("green_signal", []), dtype=float)
        fps = float(data.get("fps", 30.0))
        if green_signal.ndim != 1:
            return {"error": "green_signal must be a flat list of numbers."}
        if len(green_signal) < 60:
            return {"error": "Signal too short. Minimum 60 samples required."}
        if not np.isfinite(fps) or fps <= 0:
            return {"error": "fps must be a positive number."}

        detrended = signal.detrend(green_signal)
        detrended = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
        filtered = bandpass_filter(detrended, fps)
        peaks = detect_peaks_rr(filtered, fps)
        if len(peaks) < 5:
            return {"error": "Signal too noisy. Stay still.", "peaks": len(peaks)}

        rr_intervals = [
            rr for rr in (np.diff(peaks) / fps * 1000.0).tolist()
            if 300 < rr < 2000
        ]
        # compute_hrv_metrics returns None below 5 intervals, so require 5.
        if len(rr_intervals) < 5:
            return {"error": "Not enough valid beats."}
        metrics = compute_hrv_metrics(rr_intervals)
        if not metrics:
            return {"error": "Cannot compute HRV metrics."}

        af_data = compute_af_score(metrics)
        print(f"[Frames] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | AF={af_data['af_score']}", flush=True)
        return {
            "success": True,
            "rr_intervals": [round(rr, 1) for rr in rr_intervals],
            **metrics,
            **af_data,
        }
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}
|