Naneet commited on
Commit
016cbea
·
verified ·
1 Parent(s): 03b4270

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -57
app.py CHANGED
@@ -1,57 +1,57 @@
1
- import torch
2
- import torch.nn as nn
3
- import gradio as gr
4
- import numpy as np
5
-
6
- # Define the Simple1DCNN model
7
- class Simple1DCNN(nn.Module):
8
- def __init__(self, input_channels=1, num_classes=5, sequence_length=500):
9
- super(Simple1DCNN, self).__init__()
10
- self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=5, stride=1, padding=2)
11
- self.relu = nn.ReLU()
12
- self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
13
- self.fc = nn.Linear(16 * (sequence_length // 2), num_classes)
14
-
15
- def forward(self, x):
16
- x = self.pool(self.relu(self.conv1(x)))
17
- x = x.view(x.size(0), -1)
18
- return self.fc(x)
19
-
20
- # Load model
21
- model_path = "ecg.pth" # Adjust if necessary
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
- sequence_length = 500 # Adjust based on training
24
- model = Simple1DCNN(input_channels=1, num_classes=5, sequence_length=sequence_length).to(device)
25
- model.load_state_dict(torch.load(model_path, map_location=device))
26
- model.eval()
27
-
28
- # Preprocessing function
29
- def preprocess_ecg(data):
30
- """Convert input ECG data to PyTorch tensor and prepare for inference."""
31
- ecg = np.array(data).astype(np.float32) # Ensure NumPy array format
32
- ecg = torch.from_numpy(ecg).unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, seq_length)
33
- return ecg
34
-
35
- # Prediction function
36
- def predict(ecg_data):
37
- ecg_tensor = preprocess_ecg(ecg_data)
38
- with torch.no_grad():
39
- output = model(ecg_tensor)
40
- predicted_class = int(output.argmax(dim=1).item())
41
-
42
- # Define class labels
43
- class_labels = {0: "Normal", 1: "AFib", 2: "PVC", 3: "ST", 4: "Other"}
44
- return class_labels.get(predicted_class, "Unknown")
45
-
46
- # Create Gradio interface
47
- app = gr.Interface(
48
- fn=predict,
49
- inputs=gr.Textbox(lines=2, placeholder="Enter ECG values separated by commas"),
50
- outputs="label",
51
- title="ECG Classification",
52
- description="Predicts ECG conditions based on input signal.",
53
- )
54
-
55
- # Launch app
56
- if __name__ == "__main__":
57
- app.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import gradio as gr
4
+ import numpy as np
5
+
6
+ # Define the Simple1DCNN model
7
+ class Simple1DCNN(nn.Module):
8
+ def __init__(self, input_channels=1, num_classes=5, sequence_length=500):
9
+ super(Simple1DCNN, self).__init__()
10
+ self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=5, stride=1, padding=2)
11
+ self.relu = nn.ReLU()
12
+ self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
13
+ self.fc = nn.Linear(16 * (sequence_length // 2), num_classes)
14
+
15
+ def forward(self, x):
16
+ x = self.pool(self.relu(self.conv1(x)))
17
+ x = x.view(x.size(0), -1)
18
+ return self.fc(x)
19
+
20
+ # Load model
21
+ model_path = "ecg.pth" # Adjust if necessary
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ sequence_length = 187 # Adjust based on training
24
+ model = Simple1DCNN(input_channels=1, num_classes=5, sequence_length=sequence_length).to(device)
25
+ model.load_state_dict(torch.load(model_path, map_location=device))
26
+ model.eval()
27
+
28
+ # Preprocessing function
29
+ def preprocess_ecg(data):
30
+ """Convert input ECG data to PyTorch tensor and prepare for inference."""
31
+ ecg = np.array(data).astype(np.float32) # Ensure NumPy array format
32
+ ecg = torch.from_numpy(ecg).unsqueeze(0).unsqueeze(0).to(device) # Shape: (1, 1, seq_length)
33
+ return ecg
34
+
35
+ # Prediction function
36
+ def predict(ecg_data):
37
+ ecg_tensor = preprocess_ecg(ecg_data)
38
+ with torch.no_grad():
39
+ output = model(ecg_tensor)
40
+ predicted_class = int(output.argmax(dim=1).item())
41
+
42
+ # Define class labels
43
+ class_labels = {0: "Normal", 1: "AFib", 2: "PVC", 3: "ST", 4: "Other"}
44
+ return class_labels.get(predicted_class, "Unknown")
45
+
46
+ # Create Gradio interface
47
+ app = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Textbox(lines=2, placeholder="Enter ECG values separated by commas"),
50
+ outputs="label",
51
+ title="ECG Classification",
52
+ description="Predicts ECG conditions based on input signal.",
53
+ )
54
+
55
+ # Launch app
56
+ if __name__ == "__main__":
57
+ app.launch()