Stroke-ia commited on
Commit
907373e
·
verified ·
1 Parent(s): b1cd8f4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from fastapi import FastAPI, UploadFile, File
5
+ from fastapi.staticfiles import StaticFiles
6
+ from scipy import signal
7
+ import soundfile as sf
8
+ from datetime import datetime
9
+ import threading, time
10
+
11
+ from team_code import base_model # ton architecture
12
+
13
+ # ----------------------------
14
+ # CONFIG
15
+ # ----------------------------
16
+ SIG_LEN = 32256
17
+ N_FEATURES = 1
18
+ MODEL_PATH = "pretrained_model.h5"
19
+ OUTPUT_DIR = "/tmp/audio_results"
20
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
21
+
22
+ BASE_URL = "https://heart-murmur-api.hf.space" # adapte selon ton déploiement
23
+
24
+ # ----------------------------
25
+ # CHARGEMENT DU MODÈLE
26
+ # ----------------------------
27
+ print("[INFO] Chargement du modèle TensorFlow...")
28
+ model = base_model(SIG_LEN, N_FEATURES)
29
+ model.load_weights(MODEL_PATH)
30
+ model.compile() # juste pour être sûr
31
+ print("[INFO] Modèle chargé ✅")
32
+
33
+ # ----------------------------
34
+ # FASTAPI
35
+ # ----------------------------
36
+ app = FastAPI(title="Heart Murmur Detection API")
37
+ app.mount("/files", StaticFiles(directory=OUTPUT_DIR), name="files")
38
+
39
+ # ----------------------------
40
+ # FONCTION DE PRÉTRAITEMENT
41
+ # ----------------------------
42
+ def preprocess_audio(file_path):
43
+ data, sr = sf.read(file_path)
44
+ if data.ndim > 1:
45
+ data = np.mean(data, axis=1)
46
+ resampled = signal.resample(data, SIG_LEN)
47
+ return np.expand_dims(resampled, axis=(0, 2)) # (1, sig_len, 1)
48
+
49
+ # ----------------------------
50
+ # ENDPOINT PRINCIPAL
51
+ # ----------------------------
52
+ @app.post("/predict/")
53
+ async def predict_murmur(audio_file: UploadFile = File(...)):
54
+ """
55
+ Upload un fichier audio (.wav, .mp3) → renvoie diagnostic + probabilité
56
+ """
57
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
58
+ tmp_path = os.path.join(OUTPUT_DIR, f"{timestamp}_{audio_file.filename}")
59
+ with open(tmp_path, "wb") as f:
60
+ f.write(await audio_file.read())
61
+
62
+ try:
63
+ x = preprocess_audio(tmp_path)
64
+ pred = float(model.predict(x)[0][0])
65
+ label = "Abnormal" if pred > 0.5 else "Normal"
66
+ prob = round(pred, 3)
67
+
68
+ report_path = os.path.join(OUTPUT_DIR, f"report_{timestamp}.txt")
69
+ with open(report_path, "w") as f:
70
+ f.write(f"Result: {label}\nProbability: {prob}\n")
71
+
72
+ return {
73
+ "diagnosis": label,
74
+ "probability": prob,
75
+ "rapport_url": f"{BASE_URL}/files/{os.path.basename(report_path)}",
76
+ "message": "✅ Analyse audio terminée."
77
+ }
78
+
79
+ except Exception as e:
80
+ return {"error": str(e)}
81
+
82
+ finally:
83
+ if os.path.exists(tmp_path):
84
+ os.remove(tmp_path)
85
+
86
+ # ----------------------------
87
+ # AUTO-CLEANUP
88
+ # ----------------------------
89
+ def auto_cleanup(interval_minutes=10):
90
+ while True:
91
+ time.sleep(interval_minutes * 60)
92
+ for file in os.listdir(OUTPUT_DIR):
93
+ try:
94
+ os.remove(os.path.join(OUTPUT_DIR, file))
95
+ print(f"[CLEANUP] Fichier supprimé : {file}")
96
+ except Exception as e:
97
+ print(f"[CLEANUP] Erreur : {e}")
98
+
99
+ threading.Thread(target=auto_cleanup, daemon=True).start()