Naneet commited on
Commit
b1b46b2
·
verified ·
1 Parent(s): 6ec31c6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +57 -0
  2. ecg.pth +3 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import gradio as gr
4
+ import numpy as np
5
+
6
+ # Define the Simple1DCNN model
7
+ class Simple1DCNN(nn.Module):
8
+ def __init__(self, input_channels=1, num_classes=5, sequence_length=500):
9
+ super(Simple1DCNN, self).__init__()
10
+ self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=5, stride=1, padding=2)
11
+ self.relu = nn.ReLU()
12
+ self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
13
+ self.fc = nn.Linear(16 * (sequence_length // 2), num_classes)
14
+
15
+ def forward(self, x):
16
+ x = self.pool(self.relu(self.conv1(x)))
17
+ x = x.view(x.size(0), -1)
18
+ return self.fc(x)
19
+
20
+ # Load model
21
+ model_path = "ecg.pth" # Adjust if necessary
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ sequence_length = 500 # Adjust based on training
24
+ model = Simple1DCNN(input_channels=1, num_classes=5, sequence_length=sequence_length).to(device)
25
+ model.load_state_dict(torch.load(model_path, map_location=device))
26
+ model.eval()
27
+
28
+ # Preprocessing function
29
+ def preprocess_ecg(data):
30
+ """Convert input ECG data to PyTorch tensor and prepare for inference."""
31
+ ecg = np.array(data).astype(np.float32) # Ensure NumPy array format
32
+ ecg = torch.from_numpy(ecg).unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, seq_length)
33
+ return ecg
34
+
35
+ # Prediction function
36
+ def predict(ecg_data):
37
+ ecg_tensor = preprocess_ecg(ecg_data)
38
+ with torch.no_grad():
39
+ output = model(ecg_tensor)
40
+ predicted_class = int(output.argmax(dim=1).item())
41
+
42
+ # Define class labels
43
+ class_labels = {0: "Normal", 1: "AFib", 2: "PVC", 3: "ST", 4: "Other"}
44
+ return class_labels.get(predicted_class, "Unknown")
45
+
46
+ # Create Gradio interface
47
+ app = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Textbox(lines=2, placeholder="Enter ECG values separated by commas"),
50
+ outputs="label",
51
+ title="ECG Classification",
52
+ description="Predicts ECG conditions based on input signal.",
53
+ )
54
+
55
+ # Launch app
56
+ if __name__ == "__main__":
57
+ app.launch()
ecg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fb18b9a827a58328596e3d9cf55bd0293afdc3193bc9a05301a6a7808447aad
3
+ size 6138064