# AFDetector / app.py
# Stroke-ia's picture
# Rename af_app.py to app.py
# 5a6d836 verified
import os
import cv2
import numpy as np
import mediapipe as mp
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from scipy import signal
from scipy.signal import find_peaks
import tempfile
print("[Init] Loading MediaPipe...", flush=True)
mp_face_mesh = mp.solutions.face_mesh
# A single process-wide Face Mesh tracker, created once at import time and
# reused by every request: tracking (video) mode, at most one face, refined
# landmarks, 0.5 detection / tracking confidence.
# NOTE(review): this instance is shared by all requests; confirm MediaPipe's
# thread-safety before processing requests concurrently.
face_mesh = mp_face_mesh.FaceMesh(
    max_num_faces=1,
    static_image_mode=False,
    refine_landmarks=True,
    min_tracking_confidence=0.5,
    min_detection_confidence=0.5,
)
print("[Init] MediaPipe OK", flush=True)

# HTTP application. CORS is fully open: any origin, any method, any header.
app = FastAPI(title="AF Detector API")
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ── Signal Processing ─────────────────────────────────────
def extract_rppg_signal(frames, fps=30):
    """Extract a raw rPPG trace (mean green intensity per frame) from video frames.

    For each frame, the module-level MediaPipe ``face_mesh`` locates the face
    and the mean of the green channel is taken over the bounding box of six
    landmarks (forehead and both cheeks, chosen as vessel-rich areas). When no
    face is found, or the landmark box is degenerate, the mean green value of
    the whole frame is used instead. Exactly one sample is therefore produced
    per frame, which keeps the trace aligned with the frame time base that
    callers use to convert sample indices into time.

    Args:
        frames: sequence of BGR images (they are converted with
            ``cv2.COLOR_BGR2RGB`` before being passed to MediaPipe).
        fps: unused; kept for API compatibility.

    Returns:
        ``(signal, face_ratio)``: a 1-D ``np.ndarray`` with one value per frame,
        and the fraction of frames for which the facial ROI (rather than the
        whole-frame fallback) was used (0 for an empty input).
    """
    # Forehead (10, 151), left cheek (234, 93), right cheek (454, 323).
    # NOTE(review): the bounding box of these points presumably also covers
    # the eyes / upper nose, not only skin — verify whether a tighter ROI helps.
    roi_landmarks = (10, 151, 234, 93, 454, 323)

    green_signal = []
    valid_frames = 0
    for frame in frames:
        h, w = frame.shape[:2]
        result = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        g_mean = None
        if result.multi_face_landmarks:
            lm = result.multi_face_landmarks[0].landmark
            xs = [int(lm[i].x * w) for i in roi_landmarks]
            ys = [int(lm[i].y * h) for i in roi_landmarks]
            # Clip the box to the image; landmarks can fall outside the frame.
            x1, x2 = max(0, min(xs)), min(w, max(xs))
            y1, y2 = max(0, min(ys)), min(h, max(ys))
            if x2 > x1 and y2 > y1:
                # Channel 1 of a BGR image is green, the channel most
                # sensitive to blood-volume pulsations.
                g_mean = np.mean(frame[y1:y2, x1:x2, 1])
                valid_frames += 1

        if g_mean is None:
            # No usable face ROI: fall back to the whole frame so that no
            # sample is dropped (a dropped sample would shift every later
            # beat in time).
            # NOTE(review): mixing whole-frame and ROI values can introduce
            # step discontinuities in the trace.
            g_mean = np.mean(frame[:, :, 1])
        green_signal.append(g_mean)

    face_ratio = valid_frames / len(frames) if len(frames) else 0
    return np.array(green_signal), face_ratio
def bandpass_filter(signal_data, fs, lowcut=0.7, highcut=4.0, order=4):
    """Zero-phase Butterworth band-pass filter.

    The default pass band is 0.7 Hz (42 BPM) to 4.0 Hz (240 BPM). The filter
    is applied forward and backward with ``filtfilt``, so it adds no phase
    shift (peak positions are preserved).

    When ``highcut`` is at or above the Nyquist frequency (e.g. a camera
    delivering 8 fps or less), it is lowered to 90% of Nyquist so the filter
    can still be designed. At higher sampling rates the parameters are used
    unchanged.

    Args:
        signal_data: 1-D signal to filter.
        fs: sampling rate in Hz.
        lowcut: lower cut-off in Hz.
        highcut: upper cut-off in Hz.
        order: Butterworth order.

    Returns:
        The filtered signal, same length as ``signal_data``.

    Raises:
        ValueError: if ``fs`` is not positive or the (possibly lowered) pass
            band is empty. SciPy also raises ``ValueError`` when the signal is
            too short for ``filtfilt``'s edge padding.
    """
    if fs <= 0:
        raise ValueError(f"Sampling rate must be positive, got {fs}")
    nyq = fs / 2.0
    if highcut >= nyq:
        # Cut-offs must be strictly below Nyquist for a digital filter design.
        highcut = 0.9 * nyq
    if not 0 < lowcut < highcut:
        raise ValueError(
            f"Empty pass band: lowcut={lowcut} Hz, highcut={highcut} Hz at fs={fs} Hz"
        )
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype='band')
    return signal.filtfilt(b, a, signal_data)
def detect_peaks_rr(filtered_signal, fps, min_interval_s=0.35, height_factor=0.3):
    """Locate heartbeat peaks in a band-passed rPPG signal.

    Callers turn the returned peak positions into RR intervals.

    Args:
        filtered_signal: 1-D band-passed signal.
        fps: sampling rate of ``filtered_signal`` in Hz.
        min_interval_s: minimum spacing between peaks in seconds. The default
            of 0.35 s caps detection at about 171 BPM.
        height_factor: a peak must exceed ``height_factor * std(signal)``.

    Returns:
        ``np.ndarray`` of sample indices of the detected peaks.
    """
    # find_peaks requires distance >= 1; int(fps * 0.35) is 0 below ~2.9 fps.
    # NOTE(review): the 0.35 s spacing (~171 BPM) is tighter than the 4 Hz
    # (240 BPM) band-pass upper edge used upstream.
    min_distance = max(1, int(fps * min_interval_s))
    threshold = np.std(filtered_signal) * height_factor
    peaks, _ = find_peaks(filtered_signal, distance=min_distance, height=threshold)
    return peaks
def compute_hrv_metrics(rr_intervals_ms):
    """Compute time-domain HRV statistics from a series of RR intervals.

    Args:
        rr_intervals_ms: sequence of RR intervals in milliseconds.

    Returns:
        ``None`` when fewer than 5 intervals are given. Otherwise a dict with:
          - ``bpm``: 60000 / mean RR, rounded and clipped to [30, 250];
          - ``mean_rr``: mean RR (ms);
          - ``sdnn``: population standard deviation of RR (ms);
          - ``rmssd``: root mean square of successive differences (ms);
          - ``pnn50``: successive differences > 50 ms, as a percentage of
            the number of intervals;
          - ``cv``: ``sdnn / mean_rr`` in percent;
          - ``irr_index``: mean absolute successive difference as a percentage
            of the mean RR, capped at 100 (a simple irregularity ratio, not an
            entropy measure);
          - ``rr_count``: number of intervals received.
        Float fields are rounded to 1 decimal (``cv`` to 2).
    """
    if len(rr_intervals_ms) < 5:
        return None

    intervals = np.array(rr_intervals_ms)
    successive = np.diff(intervals)
    abs_successive = np.abs(successive)

    mean_rr = np.mean(intervals)
    sdnn = np.std(intervals)
    rmssd = np.sqrt(np.mean(successive ** 2))
    pnn50 = np.sum(abs_successive > 50) / len(intervals) * 100
    cv = (sdnn / mean_rr) * 100
    irr_index = round(min(100, (np.mean(abs_successive) / mean_rr) * 100))
    bpm = round(60000 / mean_rr)

    metrics = {}
    metrics["bpm"] = int(np.clip(bpm, 30, 250))
    metrics["mean_rr"] = round(float(mean_rr), 1)
    metrics["sdnn"] = round(float(sdnn), 1)
    metrics["rmssd"] = round(float(rmssd), 1)
    metrics["pnn50"] = round(float(pnn50), 1)
    metrics["cv"] = round(float(cv), 2)
    metrics["irr_index"] = irr_index
    metrics["rr_count"] = len(rr_intervals_ms)
    return metrics
def compute_af_score(metrics):
    """Heuristic atrial-fibrillation risk score (0-100) from HRV metrics.

    Points are added for each criterion, and the total is capped at 100:
      - heart rate: < 50 BPM (+15, bradycardia) or > 100 BPM (+20, tachycardia)
      - RMSSD: > 100 ms (+30), > 60 ms (+18), > 40 ms (+8)
      - CV: > 15 % (+25), > 10 % (+15), > 6 % (+6)
      - pNN50: > 40 % (+15), > 20 % (+8)
      - irregularity index: > 25 % (+10)

    Rationale: AF lacks organised atrial activity (no regular P wave on an
    ECG), which shows up as irregular RR intervals. High short-term
    variability (RMSSD, pNN50), high overall variability (CV) and an
    irregular pattern therefore all raise the score.

    Args:
        metrics: dict as produced by ``compute_hrv_metrics``. Reads the keys
            ``bpm``, ``rmssd``, ``cv``, ``pnn50`` and ``irr_index``.

    Returns:
        A dict with the following keys:
          - ``af_score``: int, 0-100.
          - ``result``: ``NORMAL`` (score < 25), ``IRREGULAR`` (< 50) or
            ``AF_SUSPECTED`` (>= 50).
          - ``label``: English label matching ``result``.
          - ``risk``: ``LOW``, ``MODERATE`` or ``HIGH``.
          - ``reasons``: user-facing French strings for the criteria that
            fired. The lowest tiers add points without a reason.
    """
    score = 0
    reasons = []
    bpm = metrics["bpm"]
    rmssd = metrics["rmssd"]
    cv = metrics["cv"]
    pnn50 = metrics["pnn50"]
    irr = metrics["irr_index"]

    # Abnormal heart rate
    if bpm < 50:
        score += 15
        reasons.append(f"Bradycardie ({bpm} BPM)")
    elif bpm > 100:
        score += 20
        reasons.append(f"Tachycardie ({bpm} BPM)")

    # RMSSD: high short-term variability suggests irregular beats
    if rmssd > 100:
        score += 30
        reasons.append(f"RMSSD très élevé ({rmssd}ms)")
    elif rmssd > 60:
        score += 18
        reasons.append(f"RMSSD élevé ({rmssd}ms)")
    elif rmssd > 40:
        score += 8

    # Coefficient of variation of RR intervals
    if cv > 15:
        score += 25
        reasons.append(f"Variabilité RR critique (CV={cv}%)")
    elif cv > 10:
        score += 15
        reasons.append(f"Variabilité RR élevée (CV={cv}%)")
    elif cv > 6:
        score += 6

    # pNN50
    if pnn50 > 40:
        score += 15
        reasons.append(f"pNN50 élevé ({pnn50}%)")
    elif pnn50 > 20:
        score += 8

    # Irregularity index
    if irr > 25:
        score += 10
        reasons.append(f"Irrégularité marquée ({irr}%)")

    score = int(min(100, score))

    # Classification
    if score < 25:
        result, label, risk = "NORMAL", "Normal Sinus Rhythm", "LOW"
    elif score < 50:
        result, label, risk = "IRREGULAR", "Irregular Pattern Detected", "MODERATE"
    else:
        result, label, risk = "AF_SUSPECTED", "Atrial Fibrillation Suspected", "HIGH"

    return {
        "af_score": score,
        "result": result,
        "label": label,
        "risk": risk,
        "reasons": reasons,
    }
# ── API Endpoints ─────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe: report that the service is running and name it."""
    payload = {"status": "ok"}
    payload["service"] = "AF Detector API v1.0"
    return payload
def _read_video_frames(path, fallback_fps, target_fps=30.0):
    """Decode the video at ``path`` and keep roughly ``target_fps`` frames per second.

    Returns:
        ``None`` if OpenCV cannot open the file. Otherwise a tuple
        ``(frames, source_fps, sample_fps)``:
          - ``frames``: list of kept BGR frames.
          - ``source_fps``: the container's frame rate, or ``fallback_fps``
            when OpenCV reports none.
          - ``sample_fps``: the effective rate of ``frames`` after
            decimation. Downstream timing must use this rate.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            return None
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 or NaN; use the caller-supplied rate then.
        if not (np.isfinite(source_fps) and source_fps > 0):
            source_fps = fallback_fps
        step = max(1, int(source_fps / target_fps))
        frames = []
        frame_index = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # Decimate on the decoded-frame index. Testing len(frames) here
            # would stop appending after the first frame whenever step > 1.
            if frame_index % step == 0:
                frames.append(frame)
            frame_index += 1
        return frames, source_fps, source_fps / step
    finally:
        cap.release()


@app.post("/analyze/")
async def analyze_video(video_file: UploadFile = File(...), fps: float = 30.0):
    """Analyse a face video (nominally 30 s) for atrial fibrillation via rPPG.

    Pipeline:
      1. Decode the video and decimate it to about 30 fps.
      2. Extract a green-channel rPPG trace.
      3. Detrend and normalise the trace.
      4. Band-pass it to 0.7-4 Hz.
      5. Detect peaks and derive RR intervals.
      6. Compute HRV metrics and the AF score.

    Args:
        video_file: uploaded video.
        fps: fallback frame rate, used only when the container reports none.

    Returns:
        On success, a dict with ``success=True``, the HRV metrics, the AF score
        fields and a disclaimer. Otherwise a dict with an ``"error"`` key.
    """
    # NOTE(review): this handler is async but performs CPU-bound decoding and
    # MediaPipe inference inline, blocking the event loop for the whole run.
    # NOTE(review): all kept frames are held in memory at once.
    suffix = os.path.splitext(video_file.filename or "")[-1] or ".mp4"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(await video_file.read())
        tmp_path = tmp.name

    try:
        decoded = _read_video_frames(tmp_path, fps)
        if decoded is None:
            return {"error": "Cannot open video file"}
        frames, real_fps, sample_fps = decoded
        if not sample_fps > 0:
            return {"error": "Invalid frame rate."}

        n_frames = len(frames)
        # NOTE(review): the message says 30 s, but the check is 60 frames
        # (~2 s at 30 fps).
        if n_frames < 60:
            return {"error": "Video too short. Minimum 30 seconds required.", "frames": n_frames}
        print(f"[Analyze] Frames: {n_frames} | FPS: {real_fps:.1f} | Sample FPS: {sample_fps:.1f}", flush=True)

        # ── rPPG extraction ──────────────────────────────
        green_signal, face_ratio = extract_rppg_signal(frames, sample_fps)
        del frames  # decoded frames can be large; free them early
        print(f"[Analyze] Face detected: {face_ratio:.1%}", flush=True)
        if face_ratio < 0.3:
            return {"error": "Face not detected in most frames. Ensure good lighting.", "face_ratio": face_ratio}

        # ── Detrending + filtering ───────────────────────
        # Remove slow trends (illumination drift), then z-normalise.
        detrended = signal.detrend(green_signal)
        detrended = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
        filtered = bandpass_filter(detrended, sample_fps)

        # ── Peak detection ───────────────────────────────
        peaks = detect_peaks_rr(filtered, sample_fps)
        print(f"[Analyze] Peaks detected: {len(peaks)}", flush=True)
        if len(peaks) < 8:
            return {"error": "Signal too noisy. Stay still and ensure good lighting.", "peaks": len(peaks)}

        # ── RR intervals (ms), using the post-decimation rate ──
        rr_intervals = (np.diff(peaks) / sample_fps * 1000).tolist()
        # Drop physiologically implausible intervals (< 300 ms or > 2000 ms).
        rr_intervals = [rr for rr in rr_intervals if 300 < rr < 2000]
        if len(rr_intervals) < 5:
            return {"error": "Not enough valid beats detected."}

        # ── HRV metrics ──────────────────────────────────
        metrics = compute_hrv_metrics(rr_intervals)
        if not metrics:
            return {"error": "Cannot compute HRV metrics."}

        # ── AF Score ─────────────────────────────────────
        af_data = compute_af_score(metrics)
        print(f"[Analyze] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | CV={metrics['cv']} | AF_Score={af_data['af_score']}", flush=True)
        return {
            "success": True,
            "face_ratio": round(face_ratio, 2),
            "frames": n_frames,
            "fps": round(real_fps, 1),
            "peaks_count": len(peaks),
            "rr_intervals": [round(rr, 1) for rr in rr_intervals[-30:]],  # last 30
            **metrics,
            **af_data,
            "disclaimer": "Experimental AI tool. Not a medical diagnosis. Consult a cardiologist."
        }
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}
    finally:
        # Always remove the temp file, including on early returns and errors.
        try:
            os.remove(tmp_path)
        except OSError as err:
            print(f"[Warn] Could not remove temp file {tmp_path}: {err}", flush=True)
@app.post("/analyze_frames/")
async def analyze_frames_json(data: dict):
    """Analyse a green-channel rPPG trace computed by the frontend.

    This is a faster alternative to ``/analyze/``: no video needs to be
    encoded or decoded.

    Expected payload (presumably produced by the frontend; confirm with the
    client): ``{"green_signal": [float, ...], "fps": float}``. ``fps``
    defaults to 30.

    Returns:
        On success, a dict with ``success=True``, the RR intervals, the HRV
        metrics and the AF score fields. Otherwise a dict with an
        ``"error"`` key.
    """
    # Validate the payload up front so malformed input yields an error dict
    # instead of an unhandled exception (HTTP 500).
    try:
        green_signal = np.asarray(data.get("green_signal", []), dtype=float)
        fps = float(data.get("fps", 30.0))
    except (TypeError, ValueError):
        return {"error": "Invalid payload: 'green_signal' must be a list of numbers and 'fps' a number."}
    if green_signal.ndim != 1:
        return {"error": "Invalid payload: 'green_signal' must be a flat list of numbers."}
    if len(green_signal) < 60:
        return {"error": "Signal too short. Minimum 60 samples required."}
    if not (np.isfinite(fps) and fps > 0):
        return {"error": "Invalid payload: 'fps' must be a positive number."}
    if not np.all(np.isfinite(green_signal)):
        return {"error": "Invalid payload: 'green_signal' contains non-finite values."}

    try:
        detrended = signal.detrend(green_signal)
        detrended = (detrended - np.mean(detrended)) / (np.std(detrended) + 1e-8)
        filtered = bandpass_filter(detrended, fps)
        peaks = detect_peaks_rr(filtered, fps)
        if len(peaks) < 5:
            return {"error": "Signal too noisy. Stay still.", "peaks": len(peaks)}

        rr_intervals = (np.diff(peaks) / fps * 1000).tolist()
        # Drop physiologically implausible intervals (< 300 ms or > 2000 ms).
        rr_intervals = [rr for rr in rr_intervals if 300 < rr < 2000]
        # compute_hrv_metrics returns None for fewer than 5 intervals, so
        # require 5 here and guard against None as well.
        if len(rr_intervals) < 5:
            return {"error": "Not enough valid beats."}
        metrics = compute_hrv_metrics(rr_intervals)
        if not metrics:
            return {"error": "Not enough valid beats."}

        af_data = compute_af_score(metrics)
        print(f"[Frames] BPM={metrics['bpm']} | RMSSD={metrics['rmssd']} | AF={af_data['af_score']}", flush=True)
        return {
            "success": True,
            "rr_intervals": [round(rr, 1) for rr in rr_intervals],
            **metrics,
            **af_data,
        }
    except Exception as e:
        print(f"[Error] {e}", flush=True)
        return {"error": str(e)}