niol08 commited on
Commit
08caee5
·
verified ·
1 Parent(s): e5bea5b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from fastapi import FastAPI, UploadFile, File
3
+ import numpy as np
4
+ from keras.models import load_model
5
+ from graph import zeropad, zeropad_output_shape
6
+ from config import get_config
7
+
8
+ app = FastAPI(
9
+ title="ECG Classification Backend",
10
+ description="REST API for ECG heartbeat classification",
11
+ version="1.1.0"
12
+ )
13
+
14
+ # ======= Load the model at startup =======
15
+ config = get_config()
16
+ MODEL_PATH = "MLII-latest.keras"
17
+
18
+ print("🔹 Loading model:", MODEL_PATH)
19
+ model = load_model(
20
+ MODEL_PATH,
21
+ custom_objects={
22
+ "zeropad": zeropad,
23
+ "zeropad_output_shape": zeropad_output_shape
24
+ },
25
+ compile=False
26
+ )
27
+
28
+ # ECG class mappings
29
+ CLASSES = ["N", "V", "/", "A", "F", "~"]
30
+ CLASS_NAMES = {
31
+ "N": "Normal sinus beat",
32
+ "V": "Premature Ventricular Contraction (PVC)",
33
+ "/": "Paced beat (Pacemaker)",
34
+ "A": "Atrial Premature Beat",
35
+ "F": "Fusion of Ventricular & Normal Beat",
36
+ "~": "Unclassifiable / Noise"
37
+ }
38
+
39
+ @app.get("/")
40
+ async def root():
41
+ return {"message": "✅ ECG Inference API is running successfully!"}
42
+
43
+
44
+ @app.post("/predict-ecg/")
45
+ async def predict_ecg(file: UploadFile = File(...)):
46
+ """
47
+ Accepts a CSV or TXT file containing ECG signal samples.
48
+ Each value should be a single float per line.
49
+ """
50
+ content = await file.read()
51
+ text = content.decode("utf-8").strip().splitlines()
52
+
53
+ # Parse numeric values
54
+ try:
55
+ data = np.array([float(x.strip()) for x in text if x.strip() != ""])
56
+ except Exception:
57
+ return {"error": "Invalid file format. Please upload numeric ECG values only."}
58
+
59
+ # Normalize signal length (model expects 256 samples)
60
+ max_len = 256
61
+ if len(data) > max_len:
62
+ data = data[:max_len]
63
+ elif len(data) < max_len:
64
+ data = np.pad(data, (0, max_len - len(data)))
65
+
66
+ data = data.reshape(1, max_len, 1)
67
+
68
+ # Run model inference
69
+ preds = model.predict(data, verbose=0)
70
+ label_idx = int(np.argmax(preds))
71
+ confidence = float(np.max(preds))
72
+ label = CLASSES[label_idx]
73
+ description = CLASS_NAMES[label]
74
+
75
+ return {
76
+ "label": label,
77
+ "description": description,
78
+ "confidence": round(confidence, 4),
79
+ "samples_used": len(data[0])
80
+ }