Naneet commited on
Commit
c8ca3e3
·
verified ·
1 Parent(s): ffb1547

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -7
app.py CHANGED
@@ -3,19 +3,33 @@ import torch.nn as nn
3
  import gradio as gr
4
  import numpy as np
5
 
6
- # Define the Simple1DCNN model
7
  class Simple1DCNN(nn.Module):
8
- def __init__(self, input_channels=1, num_classes=5, sequence_length=500):
9
  super(Simple1DCNN, self).__init__()
10
- self.conv1 = nn.Conv1d(input_channels, 16, kernel_size=5, stride=1, padding=2)
11
  self.relu = nn.ReLU()
12
- self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
13
- self.fc = nn.Linear(16 * (sequence_length // 2), num_classes)
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def forward(self, x):
16
  x = self.pool(self.relu(self.conv1(x)))
17
- x = x.view(x.size(0), -1)
18
- return self.fc(x)
 
 
 
19
 
20
  # Load model
21
  model_path = "ecg.pth" # Adjust if necessary
 
3
  import gradio as gr
4
  import numpy as np
5
 
 
6
  class Simple1DCNN(nn.Module):
7
+ def __init__(self, input_channels, num_classes, sequence_length):
8
  super(Simple1DCNN, self).__init__()
9
+ self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)
10
  self.relu = nn.ReLU()
11
+ self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
12
+ self.pool = nn.MaxPool1d(kernel_size=2)
13
+
14
+ # Compute the output size after convolutions and pooling
15
+ self._to_linear = self._compute_flattened_size(input_channels, sequence_length)
16
+
17
+ self.fc1 = nn.Linear(self._to_linear, 256)
18
+ self.fc2 = nn.Linear(256, num_classes)
19
+
20
+ def _compute_flattened_size(self, input_channels, sequence_length):
21
+ x = torch.randn(1, input_channels, sequence_length) # Dummy tensor
22
+ x = self.pool(self.relu(self.conv1(x)))
23
+ x = self.pool(self.relu(self.conv2(x)))
24
+ return x.numel() # Total number of features after conv and pooling
25
 
26
  def forward(self, x):
27
  x = self.pool(self.relu(self.conv1(x)))
28
+ x = self.pool(self.relu(self.conv2(x)))
29
+ x = x.view(x.shape[0], -1) # Flatten
30
+ x = self.relu(self.fc1(x))
31
+ x = self.fc2(x)
32
+ return x
33
 
34
  # Load model
35
  model_path = "ecg.pth" # Adjust if necessary