File size: 2,629 Bytes
016cbea
 
 
 
 
 
c8ca3e3
016cbea
c8ca3e3
016cbea
c8ca3e3
 
 
 
 
 
 
 
 
 
 
 
 
 
016cbea
 
 
c8ca3e3
 
 
 
 
016cbea
 
 
 
 
ffb1547
 
016cbea
 
 
 
 
dfc2121
ffb1547
016cbea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b43f258
016cbea
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
import gradio as gr
import numpy as np

class Simple1DCNN(nn.Module):
    """Small 1D convolutional classifier.

    Two ``Conv1d -> ReLU -> MaxPool1d(2)`` stages (64 then 128 channels,
    kernel size 3, padding 1) feed a 256-unit hidden linear layer and a final
    linear layer producing ``num_classes`` scores.

    Args:
        input_channels: Number of channels in the input signal.
        num_classes: Number of output scores produced per sample.
        sequence_length: Signal length used to size the first linear layer.
    """

    def __init__(self, input_channels, num_classes, sequence_length):
        super().__init__()

        # Convolutional feature extractor; padding=1 with kernel_size=3 keeps
        # the length unchanged, so only the pooling layers shorten the signal.
        self.conv1 = nn.Conv1d(input_channels, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(2)

        # The width of the flattened feature map is measured by pushing a
        # dummy signal through the layers defined above.
        self._to_linear = self._compute_flattened_size(input_channels, sequence_length)

        # Dense classification head.
        self.fc1 = nn.Linear(self._to_linear, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def _conv_stack(self, signal):
        """Apply both conv -> ReLU -> pool stages and return the feature map."""
        for conv in (self.conv1, self.conv2):
            signal = self.pool(self.relu(conv(signal)))
        return signal

    def _compute_flattened_size(self, input_channels, sequence_length):
        """Return the element count of the feature map for one dummy sample."""
        probe = torch.randn(1, input_channels, sequence_length)
        return self._conv_stack(probe).numel()

    def forward(self, x):
        """Return the class scores for a batch shaped (batch, channels, length)."""
        features = self._conv_stack(x)
        flat = features.view(features.size(0), -1)  # one row per sample
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

# Load model
# Builds the network and restores its trained weights once at import time so
# that every prediction request reuses the same module-level `model`.
model_path = "ecg.pth"  # Adjust if necessary
# NOTE(review): `device` is selected here but nothing in the visible code moves
# the model or the inputs onto it; the weights are mapped to CPU below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sequence_length = 187  # Adjust based on training
model = Simple1DCNN(input_channels=1, num_classes=5, sequence_length=sequence_length)
# NOTE(review): torch.load deserializes the checkpoint file (possibly via pickle,
# depending on the installed PyTorch's `weights_only` default) - only load
# checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # switch the module to evaluation mode for inference

# Preprocessing function
def preprocess_ecg(data):
    """Convert a textual ECG recording into a model-ready tensor.

    Samples may be separated by commas and/or any whitespace (including
    newlines from a multi-line textbox); empty fields such as those produced
    by a trailing comma are ignored.

    Args:
        data: String of numeric samples, e.g. ``"0.1, 0.2, 0.3"``. ``None`` is
            treated like an empty string.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 1, n_samples)``.

    Raises:
        ValueError: If no samples are present or a token is not a number.
    """
    # Turning commas into spaces lets str.split() handle every separator and
    # drop empty fields in a single pass.
    tokens = (data or "").replace(",", " ").split()
    if not tokens:
        raise ValueError("No ECG samples provided; enter numbers separated by commas.")

    try:
        samples = np.array(tokens, dtype=np.float32)
    except ValueError as err:
        raise ValueError(f"ECG input contains a non-numeric value: {err}") from err

    # Add batch and channel axes: (n_samples,) -> (1, 1, n_samples).
    return torch.from_numpy(samples).unsqueeze(0).unsqueeze(0)

# Prediction function
def predict(ecg_data):
    """Classify an ECG string and return the name of the top-scoring class.

    Args:
        ecg_data: Raw text holding the ECG samples, as accepted by
            ``preprocess_ecg``.

    Returns:
        The label for the highest-scoring class index, or ``"Unknown"`` when
        that index has no entry in the label table.
    """
    # Class index -> human-readable label.
    label_names = {0: "Normal", 1: "AFib", 2: "PVC", 3: "ST", 4: "Other"}

    signal = preprocess_ecg(ecg_data)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        scores = model(signal)

    best_index = int(torch.argmax(scores, dim=1).item())
    return label_names.get(best_index, "Unknown")

# Create Gradio interface
# Wires `predict` to a simple web form: one free-text input, one label output.
app = gr.Interface(
    fn=predict,  # called with the textbox contents on each submission
    inputs=gr.Textbox(lines=2, placeholder="Enter ECG values separated by commas"),
    outputs="label",  # string shortcut for Gradio's Label output component
    title="ECG Classification",
    description="Predicts ECG conditions based on input signal.",
)


# Launch app
# Only start the web server when this file is run as a script, not on import.
if __name__ == "__main__":
    app.launch()