Spaces:
Sleeping
Sleeping
File size: 2,629 Bytes
016cbea c8ca3e3 016cbea c8ca3e3 016cbea c8ca3e3 016cbea c8ca3e3 016cbea ffb1547 016cbea dfc2121 ffb1547 016cbea b43f258 016cbea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
class Simple1DCNN(nn.Module):
    """Small 1-D CNN classifier.

    Architecture: two (Conv1d -> ReLU -> MaxPool1d(2)) stages followed by two
    linear layers (``fc1`` -> ReLU -> ``fc2``).

    Args:
        input_channels: ``in_channels`` of the first convolution.
        num_classes: Number of raw class scores produced by ``fc2``.
        sequence_length: Input length the network is built for. It fixes the
            input size of ``fc1``, so inputs of any other length fail in
            ``forward``.
    """

    def __init__(self, input_channels, num_classes, sequence_length):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2)
        # Compute the output size after convolutions and pooling
        self._to_linear = self._compute_flattened_size(input_channels, sequence_length)
        self.fc1 = nn.Linear(self._to_linear, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def _features(self, x):
        """Apply the shared conv -> ReLU -> pool stages (used by the size probe and forward)."""
        x = self.pool(self.relu(self.conv1(x)))
        return self.pool(self.relu(self.conv2(x)))

    def _compute_flattened_size(self, input_channels, sequence_length):
        """Return the per-sample feature count that ``_features`` produces for this input shape."""
        # Only the output shape matters. Zeros avoid drawing from the global RNG,
        # which would otherwise shift the random init of fc1/fc2 created right
        # after this call. no_grad skips building an autograd graph.
        with torch.no_grad():
            dummy = torch.zeros(1, input_channels, sequence_length)
            return self._features(dummy).numel()

    def forward(self, x):
        """Map a batch ``(N, input_channels, sequence_length)`` to ``(N, num_classes)`` raw scores."""
        x = self._features(x)
        x = torch.flatten(x, start_dim=1)  # keep the batch dim, flatten channels x length
        x = self.relu(self.fc1(x))
        return self.fc2(x)
# Load model
sequence_length = 187  # Must match the input length used in training; adjust accordingly
model_path = "ecg.pth"  # Adjust if necessary
# NOTE(review): `device` is computed but never used below. The weights are mapped
# to CPU and the model is never moved, so inference currently runs on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = Simple1DCNN(
    input_channels=1,
    num_classes=5,
    sequence_length=sequence_length,
)
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.train(False)  # equivalent to model.eval(): switch to evaluation mode
# Preprocessing function
def preprocess_ecg(data):
    """Parse ECG text into a float32 tensor shaped for the model.

    Values may be separated by commas and/or whitespace, including the newlines
    a multi-line textbox produces. Empty fields, such as a trailing comma, are
    ignored.

    Args:
        data: Text containing the numeric ECG samples.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 1, n)`` (batch of one, one
        channel), where ``n`` is the number of values parsed.

    Raises:
        ValueError: if ``data`` is None or contains no values, if a token is
            not a number, or if any value is NaN or infinite.
    """
    if data is None:
        raise ValueError("No ECG values provided.")
    # Treat commas like whitespace; str.split() with no argument drops empty tokens.
    tokens = str(data).replace(",", " ").split()
    if not tokens:
        raise ValueError("No ECG values provided.")
    ecg = np.array([float(tok) for tok in tokens], dtype=np.float32)
    # NaN/inf would silently produce a meaningless arg-max downstream.
    if not np.all(np.isfinite(ecg)):
        raise ValueError("ECG values must be finite numbers.")
    return torch.from_numpy(ecg).unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_length)
# Prediction function
# Class index -> display label (index order must match the trained model's outputs).
CLASS_LABELS = {0: "Normal", 1: "AFib", 2: "PVC", 3: "ST", 4: "Other"}


def predict(ecg_data):
    """Classify one ECG sample given as text and return its class label.

    Args:
        ecg_data: Text of numeric samples, as accepted by ``preprocess_ecg``.

    Returns:
        The ``CLASS_LABELS`` entry for the highest-scoring class, or
        ``"Unknown"`` if that index has no entry.

    Raises:
        gr.Error: if the text cannot be parsed, or if it does not contain exactly
            ``sequence_length`` values (the length the model was built for).
            Gradio shows this message to the user.
    """
    try:
        ecg_tensor = preprocess_ecg(ecg_data)
    except ValueError as err:
        raise gr.Error(f"Could not parse ECG values: {err}") from err
    n_samples = ecg_tensor.shape[-1]
    # fc1 is sized for exactly `sequence_length` inputs. Check up front instead
    # of surfacing an opaque matrix-shape error from inside the model.
    if n_samples != sequence_length:
        raise gr.Error(f"Expected {sequence_length} ECG values, got {n_samples}.")
    # Run on whatever device the model's weights live on.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        output = model(ecg_tensor.to(model_device))
    predicted_class = int(output.argmax(dim=1).item())
    return CLASS_LABELS.get(predicted_class, "Unknown")
# Create Gradio interface: a text box of ECG values in, a label out.
app = gr.Interface(
    predict,
    gr.Textbox(placeholder="Enter ECG values separated by commas", lines=2),
    "label",
    description="Predicts ECG conditions based on input signal.",
    title="ECG Classification",
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    app.launch()
|