Stroke-ia commited on
Commit
bf3a89f
Β·
verified Β·
1 Parent(s): 712ab11

Upload af_app.py

Browse files
Files changed (1) hide show
  1. af_app.py +374 -0
af_app.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import mediapipe as mp
5
+ from fastapi import FastAPI, UploadFile, File
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from scipy import signal
8
+ from scipy.signal import find_peaks
9
+ import tempfile
10
+
11
+ print("[Init] Loading MediaPipe...", flush=True)
12
+ mp_face_mesh = mp.solutions.face_mesh
13
+ face_mesh = mp_face_mesh.FaceMesh(
14
+ static_image_mode=False,
15
+ max_num_faces=1,
16
+ refine_landmarks=True,
17
+ min_detection_confidence=0.5,
18
+ min_tracking_confidence=0.5
19
+ )
20
+ print("[Init] MediaPipe OK", flush=True)
21
+
22
+ app = FastAPI(title="AF Detector API")
23
+
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"],
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ # ── Signal Processing ─────────────────────────────────────
32
+
33
+ def extract_rppg_signal(frames, fps=30):
34
+ """
35
+ Extrait le signal rPPG depuis les frames vidΓ©o.
36
+ Utilise le canal vert de la ROI du visage (MediaPipe).
37
+ """
38
+ green_signal = []
39
+ valid_frames = 0
40
+
41
+ for frame in frames:
42
+ rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
43
+ result = face_mesh.process(rgb)
44
+
45
+ if result.multi_face_landmarks:
46
+ lm = result.multi_face_landmarks[0].landmark
47
+ h, w = frame.shape[:2]
48
+
49
+ # ROI = joues + front (zones riches en vaisseaux)
50
+ roi_points = [
51
+ # Front
52
+ (int(lm[10].x * w), int(lm[10].y * h)),
53
+ (int(lm[151].x * w), int(lm[151].y * h)),
54
+ # Joue gauche
55
+ (int(lm[234].x * w), int(lm[234].y * h)),
56
+ (int(lm[93].x * w), int(lm[93].y * h)),
57
+ # Joue droite
58
+ (int(lm[454].x * w), int(lm[454].y * h)),
59
+ (int(lm[323].x * w), int(lm[323].y * h)),
60
+ ]
61
+
62
+ # Bounding box de la ROI
63
+ xs = [p[0] for p in roi_points]
64
+ ys = [p[1] for p in roi_points]
65
+ x1, x2 = max(0, min(xs)), min(w, max(xs))
66
+ y1, y2 = max(0, min(ys)), min(h, max(ys))
67
+
68
+ if x2 > x1 and y2 > y1:
69
+ roi = frame[y1:y2, x1:x2]
70
+ # Canal vert (le plus sensible aux pulsations)
71
+ g_mean = np.mean(roi[:, :, 1])
72
+ green_signal.append(g_mean)
73
+ valid_frames += 1
74
+ else:
75
+ # Pas de visage — utiliser frame entière comme fallback
76
+ g_mean = np.mean(frame[:, :, 1])
77
+ green_signal.append(g_mean)
78
+
79
+ face_ratio = valid_frames / len(frames) if frames else 0
80
+ return np.array(green_signal), face_ratio
81
+
82
+
83
+ def bandpass_filter(signal_data, fs, lowcut=0.7, highcut=4.0, order=4):
84
+ """
85
+ Filtre passe-bande Butterworth.
86
+ 0.7 Hz = 42 BPM (min)
87
+ 4.0 Hz = 240 BPM (max)
88
+ """
89
+ nyq = fs / 2.0
90
+ low = lowcut / nyq
91
+ high = highcut / nyq
92
+ b, a = signal.butter(order, [low, high], btype='band')
93
+ return signal.filtfilt(b, a, signal_data)
94
+
95
+
96
+ def detect_peaks_rr(filtered_signal, fps):
97
+ """
98
+ DΓ©tecte les pics du signal cardiaque β†’ intervalles RR.
99
+ """
100
+ min_distance = int(fps * 0.35) # min 350ms entre pics (max ~170 BPM)
101
+ threshold = np.std(filtered_signal) * 0.3
102
+
103
+ peaks, properties = find_peaks(
104
+ filtered_signal,
105
+ distance=min_distance,
106
+ height=threshold
107
+ )
108
+ return peaks
109
+
110
+
111
+ def compute_hrv_metrics(rr_intervals_ms):
112
+ """
113
+ Calcule les mΓ©triques HRV classiques utilisΓ©es pour dΓ©tecter la FA.
114
+ """
115
+ if len(rr_intervals_ms) < 5:
116
+ return None
117
+
118
+ rr = np.array(rr_intervals_ms)
119
+
120
+ # MΓ©triques temporelles
121
+ mean_rr = np.mean(rr)
122
+ sdnn = np.std(rr) # variabilitΓ© globale
123
+ rmssd = np.sqrt(np.mean(np.diff(rr)**2)) # variabilitΓ© court terme
124
+ pnn50 = np.sum(np.abs(np.diff(rr)) > 50) / len(rr) * 100 # % diff > 50ms
125
+
126
+ # BPM
127
+ bpm = round(60000 / mean_rr)
128
+
129
+ # Coefficient de variation (CV) β€” clΓ© pour FA
130
+ cv = (sdnn / mean_rr) * 100
131
+
132
+ # Irregularity index β€” entropie approchΓ©e
133
+ diffs = np.abs(np.diff(rr))
134
+ irr_index = round(min(100, (np.mean(diffs) / mean_rr) * 100))
135
+
136
+ return {
137
+ "bpm": int(np.clip(bpm, 30, 250)),
138
+ "mean_rr": round(float(mean_rr), 1),
139
+ "sdnn": round(float(sdnn), 1),
140
+ "rmssd": round(float(rmssd), 1),
141
+ "pnn50": round(float(pnn50), 1),
142
+ "cv": round(float(cv), 2),
143
+ "irr_index": irr_index,
144
+ "rr_count": len(rr_intervals_ms),
145
+ }
146
+
147
+
148
+ def compute_af_score(metrics):
149
+ """
150
+ Score de risque FA (0-100) basΓ© sur les mΓ©triques HRV.
151
+
152
+ Critères cliniques FA :
153
+ - Absence de onde P régulière → RR irréguliers
154
+ - RMSSD Γ©levΓ©
155
+ - CV Γ©levΓ© (>10%)
156
+ - pNN50 Γ©levΓ©
157
+ - Pattern d'irrΓ©gularitΓ© sans rythme
158
+ """
159
+ score = 0
160
+ reasons = []
161
+
162
+ bpm = metrics["bpm"]
163
+ rmssd = metrics["rmssd"]
164
+ cv = metrics["cv"]
165
+ pnn50 = metrics["pnn50"]
166
+ irr = metrics["irr_index"]
167
+ sdnn = metrics["sdnn"]
168
+
169
+ # BPM anormal
170
+ if bpm < 50:
171
+ score += 15; reasons.append(f"Bradycardie ({bpm} BPM)")
172
+ elif bpm > 100:
173
+ score += 20; reasons.append(f"Tachycardie ({bpm} BPM)")
174
+
175
+ # RMSSD β€” variabilitΓ© Γ©levΓ©e = irrΓ©gularitΓ©
176
+ if rmssd > 100:
177
+ score += 30; reasons.append(f"RMSSD très élevé ({rmssd}ms)")
178
+ elif rmssd > 60:
179
+ score += 18; reasons.append(f"RMSSD Γ©levΓ© ({rmssd}ms)")
180
+ elif rmssd > 40:
181
+ score += 8
182
+
183
+ # CV β€” coefficient de variation
184
+ if cv > 15:
185
+ score += 25; reasons.append(f"VariabilitΓ© RR critique (CV={cv}%)")
186
+ elif cv > 10:
187
+ score += 15; reasons.append(f"VariabilitΓ© RR Γ©levΓ©e (CV={cv}%)")
188
+ elif cv > 6:
189
+ score += 6
190
+
191
+ # pNN50
192
+ if pnn50 > 40:
193
+ score += 15; reasons.append(f"pNN50 Γ©levΓ© ({pnn50}%)")
194
+ elif pnn50 > 20:
195
+ score += 8
196
+
197
+ # Irregularity index
198
+ if irr > 25:
199
+ score += 10; reasons.append(f"IrrΓ©gularitΓ© marquΓ©e ({irr}%)")
200
+
201
+ score = int(min(100, score))
202
+
203
+ # Classification
204
+ if score < 25:
205
+ result = "NORMAL"
206
+ label = "Normal Sinus Rhythm"
207
+ risk = "LOW"
208
+ elif score < 50:
209
+ result = "IRREGULAR"
210
+ label = "Irregular Pattern Detected"
211
+ risk = "MODERATE"
212
+ else:
213
+ result = "AF_SUSPECTED"
214
+ label = "Atrial Fibrillation Suspected"
215
+ risk = "HIGH"
216
+
217
+ return {
218
+ "af_score": score,
219
+ "result": result,
220
+ "label": label,
221
+ "risk": risk,
222
+ "reasons": reasons,
223
+ }
224
+
225
+
226
+ # ── API Endpoints ─────────────────────────────────────────
227
+
228
+ @app.get("/health")
229
+ def health():
230
+ return {"status": "ok", "service": "AF Detector API v1.0"}
231
+
232
+
233
+ @app.post("/analyze/")
234
+ async def analyze_video(video_file: UploadFile = File(...), fps: float = 30.0):
235
+ """
236
+ Analyse une vidΓ©o de 30s pour dΓ©tecter la FA via rPPG.
237
+ Retourne les mΓ©triques HRV et le score AF.
238
+ """
239
+ # Sauvegarder la vidΓ©o temporairement
240
+ suffix = os.path.splitext(video_file.filename)[-1] or ".mp4"
241
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
242
+ tmp.write(await video_file.read())
243
+ tmp_path = tmp.name
244
+
245
+ try:
246
+ # Lire la vidΓ©o
247
+ cap = cv2.VideoCapture(tmp_path)
248
+ if not cap.isOpened():
249
+ return {"error": "Cannot open video file"}
250
+
251
+ real_fps = cap.get(cv2.CAP_PROP_FPS) or fps
252
+ frames = []
253
+
254
+ while True:
255
+ ret, frame = cap.read()
256
+ if not ret: break
257
+ # Subsample si fps > 30
258
+ if len(frames) % max(1, int(real_fps / 30)) == 0:
259
+ frames.append(frame)
260
+
261
+ cap.release()
262
+ os.remove(tmp_path)
263
+
264
+ if len(frames) < 60:
265
+ return {"error": "Video too short. Minimum 30 seconds required.", "frames": len(frames)}
266
+
267
+ print(f"[Analyze] Frames: {len(frames)} | FPS: {real_fps:.1f}", flush=True)
268
+
269
+ # ── rPPG extraction ──────────────────────────────
270
+ green_signal, face_ratio = extract_rppg_signal(frames, real_fps)
271
+ print(f"[Analyze] Face detected: {face_ratio:.1%}", flush=True)
272
+
273
+ if face_ratio < 0.3:
274
+ return {"error": "Face not detected in most frames. Ensure good lighting.", "face_ratio": face_ratio}
275
+
276
+ # ── DΓ©trending + filtrage ────────────────────────
277
+ # Supprimer la tendance lente (illumination)
278
+ detrended = signal.detrend(green_signal)
279
+
280
+ # Normaliser
281
+ detrended = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
282
+
283
+ # Bandpass 0.7-4Hz
284
+ filtered = bandpass_filter(detrended, real_fps)
285
+
286
+ # ── Peak detection ───────────────────────────────
287
+ peaks = detect_peaks_rr(filtered, real_fps)
288
+ print(f"[Analyze] Peaks detected: {len(peaks)}", flush=True)
289
+
290
+ if len(peaks) < 8:
291
+ return {"error": "Signal too noisy. Stay still and ensure good lighting.", "peaks": len(peaks)}
292
+
293
+ # ── RR intervals (ms) ────────────────────────────
294
+ rr_intervals = [(peaks[i] - peaks[i-1]) / real_fps * 1000
295
+ for i in range(1, len(peaks))]
296
+
297
+ # Filtrer les RR aberrants (< 300ms ou > 2000ms)
298
+ rr_intervals = [rr for rr in rr_intervals if 300 < rr < 2000]
299
+
300
+ if len(rr_intervals) < 5:
301
+ return {"error": "Not enough valid beats detected."}
302
+
303
+ # ── HRV metrics ──────────────────────────────────
304
+ metrics = compute_hrv_metrics(rr_intervals)
305
+ if not metrics:
306
+ return {"error": "Cannot compute HRV metrics."}
307
+
308
+ # ── AF Score ─────────────────────────────────────
309
+ af_data = compute_af_score(metrics)
310
+
311
+ print(f"[Analyze] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | CV={metrics['cv']} | AF_Score={af_data['af_score']}", flush=True)
312
+
313
+ return {
314
+ "success": True,
315
+ "face_ratio": round(face_ratio, 2),
316
+ "frames": len(frames),
317
+ "fps": round(real_fps, 1),
318
+ "peaks_count": len(peaks),
319
+ "rr_intervals": [round(rr, 1) for rr in rr_intervals[-30:]], # derniers 30
320
+ **metrics,
321
+ **af_data,
322
+ "disclaimer": "Experimental AI tool. Not a medical diagnosis. Consult a cardiologist."
323
+ }
324
+
325
+ except Exception as e:
326
+ if os.path.exists(tmp_path):
327
+ os.remove(tmp_path)
328
+ print(f"[Error] {e}", flush=True)
329
+ return {"error": str(e)}
330
+
331
+
332
+ @app.post("/analyze_frames/")
333
+ async def analyze_frames_json(data: dict):
334
+ """
335
+ Alternative : reΓ§oit le signal vert directement depuis le frontend.
336
+ Plus rapide β€” pas besoin d'encoder/dΓ©coder la vidΓ©o.
337
+ """
338
+ green_signal = np.array(data.get("green_signal", []))
339
+ fps = float(data.get("fps", 30.0))
340
+
341
+ if len(green_signal) < 60:
342
+ return {"error": "Signal too short. Minimum 60 samples required."}
343
+
344
+ try:
345
+ detrended = signal.detrend(green_signal)
346
+ detrended = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
347
+ filtered = bandpass_filter(detrended, fps)
348
+ peaks = detect_peaks_rr(filtered, fps)
349
+
350
+ if len(peaks) < 5:
351
+ return {"error": "Signal too noisy. Stay still.", "peaks": len(peaks)}
352
+
353
+ rr_intervals = [(peaks[i] - peaks[i-1]) / fps * 1000
354
+ for i in range(1, len(peaks))]
355
+ rr_intervals = [rr for rr in rr_intervals if 300 < rr < 2000]
356
+
357
+ if len(rr_intervals) < 4:
358
+ return {"error": "Not enough valid beats."}
359
+
360
+ metrics = compute_hrv_metrics(rr_intervals)
361
+ af_data = compute_af_score(metrics)
362
+
363
+ print(f"[Frames] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | AF={af_data['af_score']}", flush=True)
364
+
365
+ return {
366
+ "success": True,
367
+ "rr_intervals": [round(rr, 1) for rr in rr_intervals],
368
+ **metrics,
369
+ **af_data,
370
+ }
371
+
372
+ except Exception as e:
373
+ print(f"[Error] {e}", flush=True)
374
+ return {"error": str(e)}