"""Gradio app that classifies a comma-separated ECG signal with a small 1D CNN."""

import gradio as gr
import numpy as np
import torch
import torch.nn as nn


class Simple1DCNN(nn.Module):
    """Two Conv1d -> ReLU -> MaxPool1d stages followed by two linear layers.

    Args:
        input_channels: Number of channels in the input signal.
        num_classes: Number of output logits produced by the last layer.
        sequence_length: Number of samples per input signal; it fixes the
            input width of the first fully connected layer.
    """

    def __init__(self, input_channels: int, num_classes: int, sequence_length: int) -> None:
        super().__init__()
        # NOTE: attribute names are part of the checkpoint's state-dict keys;
        # renaming any of them breaks `load_state_dict`.
        self.conv1 = nn.Conv1d(
            in_channels=input_channels, out_channels=64, kernel_size=3, padding=1
        )
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2)

        # The flattened feature count depends on `sequence_length`, so measure
        # it by pushing a dummy signal through the convolutional stages.
        self._to_linear = self._compute_flattened_size(input_channels, sequence_length)

        self.fc1 = nn.Linear(self._to_linear, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def _compute_flattened_size(self, input_channels: int, sequence_length: int) -> int:
        """Return the number of features per sample after both conv/pool stages."""
        # No gradients are needed for a shape probe.
        with torch.no_grad():
            x = torch.randn(1, input_channels, sequence_length)  # dummy batch of one
            x = self.pool(self.relu(self.conv1(x)))
            x = self.pool(self.relu(self.conv2(x)))
        return x.numel()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch shaped (batch, input_channels, sequence_length)."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.flatten(start_dim=1)  # keep the batch axis, flatten the rest
        x = self.relu(self.fc1(x))
        return self.fc2(x)


# Load model
model_path = "ecg.pth"  # Adjust if necessary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sequence_length = 187  # Adjust based on training

# Maps the model's output index to the label shown in the UI.
CLASS_LABELS = {0: "Normal", 1: "AFib", 2: "PVC", 3: "ST", 4: "Other"}

model = Simple1DCNN(input_channels=1, num_classes=5, sequence_length=sequence_length)
# `weights_only=True` restricts unpickling to tensors/containers, which is all
# a state dict needs and avoids executing arbitrary pickled code.
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
model.to(device)
model.eval()


def preprocess_ecg(data: str) -> torch.Tensor:
    """Parse comma-separated ECG samples into a (1, 1, sequence_length) float32 tensor.

    Whitespace around values and empty fields (e.g. a trailing comma) are
    ignored.

    Args:
        data: Text containing the samples separated by commas.

    Returns:
        A CPU tensor of shape (1, 1, sequence_length) and dtype float32.

    Raises:
        ValueError: If the text is empty, contains a non-numeric or non-finite
            value, or does not hold exactly ``sequence_length`` samples.
    """
    if data is None or not data.strip():
        raise ValueError("No ECG values provided; enter numbers separated by commas.")

    tokens = [token.strip() for token in data.split(",")]
    tokens = [token for token in tokens if token]
    try:
        ecg = np.array([float(token) for token in tokens], dtype=np.float32)
    except ValueError as exc:
        raise ValueError(f"ECG values must be numeric: {exc}") from exc

    # The first linear layer is sized for exactly `sequence_length` samples;
    # any other length would fail deep inside the model with a shape error.
    if ecg.size != sequence_length:
        raise ValueError(
            f"Expected {sequence_length} ECG values but received {ecg.size}."
        )
    # NaN/inf inputs would still produce an argmax and hence a bogus label.
    if not np.all(np.isfinite(ecg)):
        raise ValueError("ECG values must be finite numbers (no NaN or infinity).")

    return torch.from_numpy(ecg).unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_length)


def predict(ecg_data: str) -> str:
    """Classify one ECG signal and return its label.

    Args:
        ecg_data: Comma-separated ECG samples as entered in the text box.

    Returns:
        The label for the highest-scoring class, or "Unknown" if the index
        has no entry in ``CLASS_LABELS``.

    Raises:
        gr.Error: If the input cannot be parsed; the message is shown in the UI.
    """
    try:
        ecg_tensor = preprocess_ecg(ecg_data)
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc

    with torch.no_grad():
        output = model(ecg_tensor.to(device))
    predicted_class = int(output.argmax(dim=1).item())
    return CLASS_LABELS.get(predicted_class, "Unknown")


# Create Gradio interface
app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Enter ECG values separated by commas"),
    outputs="label",
    title="ECG Classification",
    description="Predicts ECG conditions based on input signal.",
)

# Launch app
if __name__ == "__main__":
    app.launch()