# Naneet's picture
# Update app.py
# b43f258 verified
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
class Simple1DCNN(nn.Module):
    """Small 1-D CNN classifier: two conv -> ReLU -> max-pool stages, then two linear layers.

    Args:
        input_channels: Number of channels in the input signal (dim 1 of the input).
        num_classes: Number of logits produced per sample.
        sequence_length: Length of the input signal (last dim). ``fc1`` is sized from
            it, so ``forward`` only accepts inputs of exactly this length.
    """

    def __init__(self, input_channels, num_classes, sequence_length):
        super().__init__()
        # kernel_size=3 with padding=1 keeps the sequence length unchanged through each conv.
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        # Each pooling step halves the length (floor division).
        self.pool = nn.MaxPool1d(kernel_size=2)
        # Per-sample size of the flattened feature vector that fc1 consumes.
        self._to_linear = self._compute_flattened_size(input_channels, sequence_length)
        self.fc1 = nn.Linear(self._to_linear, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def _features(self, x):
        """Apply both conv -> ReLU -> max-pool stages (shared by sizing and forward)."""
        x = self.pool(self.relu(self.conv1(x)))
        return self.pool(self.relu(self.conv2(x)))

    def _compute_flattened_size(self, input_channels, sequence_length):
        """Return the number of features per sample after the conv/pool stages.

        A single dummy sample is pushed through the feature extractor. Only its shape
        matters, so zeros are used (not random values, which would consume the global
        RNG). Gradients are disabled so no autograd graph is built during construction.
        """
        with torch.no_grad():
            dummy = torch.zeros(1, input_channels, sequence_length)
            return self._features(dummy).numel()

    def forward(self, x):
        """Map ``x`` of shape (batch, input_channels, sequence_length) to (batch, num_classes) logits."""
        x = torch.flatten(self._features(x), start_dim=1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
# Load model
# NOTE(review): the checkpoint path is resolved relative to the current working directory.
model_path = "ecg.pth"  # Adjust if necessary
# Determines fc1's input size, so it must match the value the checkpoint was trained with.
sequence_length = 187  # Adjust based on training
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): `device` is not used to place the model; it is loaded onto and kept on the CPU.

model = Simple1DCNN(
    sequence_length=sequence_length,
    num_classes=5,
    input_channels=1,
)
# torch.load unpickles the file, which can execute arbitrary code; only load trusted
# checkpoints (newer torch versions offer weights_only=True to restrict this).
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# Switch to evaluation mode for inference.
model.eval()
# Preprocessing function
def preprocess_ecg(data):
    """Parse a text string of ECG samples into a model-ready tensor.

    Samples may be separated by commas and/or whitespace (including newlines), so
    pasted multi-line input and trailing separators are accepted.

    Args:
        data: String of numeric samples, e.g. ``"0.1, 0.2, 0.3"``.

    Returns:
        A float32 tensor of shape ``(1, 1, n)``, where ``n`` is the number of samples.

    Raises:
        ValueError: If no samples are given, a token is not a number, or any sample
            is NaN or infinite.
    """
    # Treat commas as whitespace so "1,2", "1, 2", "1\n2" and "1,2," all parse.
    tokens = data.replace(",", " ").split()
    if not tokens:
        raise ValueError("No ECG values provided.")
    # Raises ValueError for any non-numeric token.
    ecg = np.array(tokens, dtype=np.float32)
    # NaN/inf would propagate through the network and make the argmax meaningless.
    if not np.isfinite(ecg).all():
        raise ValueError("ECG values must be finite numbers (no NaN or inf).")
    return torch.from_numpy(ecg).reshape(1, 1, -1)  # (batch=1, channels=1, seq_length)
# Prediction function
# Maps the model's output index to a human-readable class name.
CLASS_LABELS = {0: "Normal", 1: "AFib", 2: "PVC", 3: "ST", 4: "Other"}


def predict(ecg_data):
    """Classify an ECG text input and return the predicted label name.

    Args:
        ecg_data: Text from the Gradio textbox (numeric samples separated by commas).

    Returns:
        A name from ``CLASS_LABELS``, or ``"Unknown"`` for an index not in it.

    Raises:
        gr.Error: If the input cannot be parsed or its length differs from
            ``sequence_length``; Gradio shows the message to the user.
    """
    try:
        ecg_tensor = preprocess_ecg(ecg_data)
    except ValueError as err:
        raise gr.Error(f"Invalid ECG input: {err}") from err
    # fc1 is sized for exactly `sequence_length` samples; any other length would
    # otherwise fail deep inside the model with an opaque shape-mismatch error.
    n_samples = ecg_tensor.shape[-1]
    if n_samples != sequence_length:
        raise gr.Error(f"Expected {sequence_length} ECG values, got {n_samples}.")
    # Keep the input on the same device as the model weights.
    ecg_tensor = ecg_tensor.to(next(model.parameters()).device)
    with torch.no_grad():
        output = model(ecg_tensor)
    predicted_class = int(output.argmax(dim=1).item())
    return CLASS_LABELS.get(predicted_class, "Unknown")
# Create Gradio interface: one free-text input box, one label output, wired to `predict`.
app = gr.Interface(
    predict,
    inputs=gr.Textbox(placeholder="Enter ECG values separated by commas", lines=2),
    outputs="label",
    description="Predicts ECG conditions based on input signal.",
    title="ECG Classification",
)

# Launch app: start the web server only when run as a script, not on import.
if __name__ == "__main__":
    app.launch()