lukiod commited on
Commit
28dbb9e
·
verified ·
1 Parent(s): ea66f77

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import tensorflow as tf
5
+ import wfdb
6
+ import tempfile
7
+ import os
8
+ from scipy.signal import resample
9
+ import matplotlib.pyplot as plt
10
+ # Custom activation functions
11
+ def sin_activation(x):
12
+ return tf.math.sin(x)
13
+
14
+ def cos_activation(x):
15
+ return tf.math.cos(x)
16
+
17
+ # Load model with custom objects
18
+ @st.cache_resource
19
+ def load_model():
20
+ return tf.keras.models.load_model(
21
+ "model.keras", # Use the .keras format instead of .h5
22
+ custom_objects={
23
+ 'sin': sin_activation,
24
+ 'cos': cos_activation,
25
+ 'gelu': tf.keras.activations.gelu
26
+ }
27
+ )
28
+
29
+ model = load_model()
30
+
31
+ # AAMI class map
32
+ class_map = {
33
+ 0: "Normal",
34
+ 1: "Supraventricular Ectopic (SVEB)",
35
+ 2: "Ventricular Ectopic (VEB)",
36
+ 3: "Fusion Beat",
37
+ 4: "Unknown"
38
+ }
39
+
40
+ def extract_beats(record, annotation, window_size=257):
41
+ beats = []
42
+ r_locs = annotation.sample
43
+ signal = record.p_signal[:, 0] # Using first channel
44
+
45
+ for r in r_locs:
46
+ start = max(0, r - window_size//2)
47
+ end = min(len(signal), r + window_size//2 + 1)
48
+
49
+ if end - start == window_size:
50
+ beat = signal[start:end]
51
+ beats.append(beat)
52
+
53
+ return np.array(beats)
54
+
55
+ st.title("ECG Arrhythmia Classification")
56
+ st.write("Upload MIT-BIH record files (.dat, .hea, .atr)")
57
+
58
+ uploaded_files = st.file_uploader(
59
+ "Choose files",
60
+ type=["dat", "hea", "atr"],
61
+ accept_multiple_files=True
62
+ )
63
+
64
+ if uploaded_files:
65
+ with tempfile.TemporaryDirectory() as tmpdir:
66
+ # Save uploaded files
67
+ for f in uploaded_files:
68
+ file_path = os.path.join(tmpdir, f.name)
69
+ with open(file_path, "wb") as f_out:
70
+ f_out.write(f.getbuffer())
71
+
72
+ # Find base record name
73
+ base_names = {os.path.splitext(f.name)[0] for f in uploaded_files}
74
+ common_base = list(base_names)[0] # Get first base name
75
+
76
+ try:
77
+ # Read record
78
+ record = wfdb.rdrecord(os.path.join(tmpdir, common_base))
79
+ annotation = wfdb.rdann(os.path.join(tmpdir, common_base), 'atr')
80
+
81
+ # Process beats
82
+ beats = extract_beats(record, annotation)
83
+ if len(beats) == 0:
84
+ st.error("No valid beats found in the record")
85
+ st.stop()
86
+
87
+ # Preprocess and predict
88
+ beats = beats.reshape((-1, 257, 1)).astype(np.float32)
89
+ predictions = model.predict(beats)
90
+ predicted_classes = np.argmax(predictions, axis=1)
91
+
92
+ # Display results
93
+ st.subheader("Classification Results")
94
+ results = pd.DataFrame({
95
+ "Beat Index": range(len(beats)),
96
+ "Predicted Class": [class_map[c] for c in predicted_classes],
97
+ "Confidence": np.max(predictions, axis=1)
98
+ })
99
+
100
+ st.dataframe(results)
101
+
102
+ # Add visualization
103
+ st.subheader("Sample ECG Beat")
104
+ fig, ax = plt.subplots()
105
+ ax.plot(beats[0].flatten())
106
+ st.pyplot(fig)
107
+
108
+ except Exception as e:
109
+ st.error(f"Error processing files: {str(e)}")