Tumo505 commited on
Commit
6fcc7e5
·
1 Parent(s): 1dc306c

Simplify model loading - use direct CNN instead of PreTrainedModel wrapper

Browse files
Files changed (1) hide show
  1. app.py +46 -7
app.py CHANGED
@@ -6,15 +6,16 @@ Deploy to Hugging Face Spaces
6
 
7
  import gradio as gr
8
  import torch
 
9
  import numpy as np
10
  import plotly.graph_objects as go
11
- from transformers import AutoModel, AutoConfig
12
  import tempfile
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
  # Constants
17
- MODEL_ID = "Tumo505/SSL-ECG-Classificcation"
18
  CLASS_LABELS = ["NORM", "MI", "STTC", "HYP", "CD"]
19
  CLASS_COLORS = {
20
  "NORM": "#90EE90",
@@ -24,16 +25,55 @@ CLASS_COLORS = {
24
  "CD": "#A29BFE"
25
  }
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # Load model
28
  model = None
29
  try:
30
  print("Loading model from Hub...")
31
- model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
32
  model.to(device)
33
  model.eval()
34
- print("Model loaded successfully")
35
  except Exception as e:
36
- print(f"Error loading model: {e}")
 
 
37
 
38
  def predict_ecg(file_obj):
39
  """Main prediction function"""
@@ -95,8 +135,7 @@ def predict_ecg(file_obj):
95
 
96
  # Predict
97
  with torch.no_grad():
98
- output = model(x)
99
- logits = output["logits"][0].cpu().numpy()
100
  probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
101
 
102
  # Get prediction
 
6
 
7
  import gradio as gr
8
  import torch
9
+ import torch.nn as nn
10
  import numpy as np
11
  import plotly.graph_objects as go
12
+ from huggingface_hub import hf_hub_download
13
  import tempfile
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
  # Constants
18
+ REPO_ID = "Tumo505/SSL-ECG-Classificcation"
19
  CLASS_LABELS = ["NORM", "MI", "STTC", "HYP", "CD"]
20
  CLASS_COLORS = {
21
  "NORM": "#90EE90",
 
25
  "CD": "#A29BFE"
26
  }
27
 
28
+ # Define model architecture (1D CNN)
29
+ class ECGClassifier(nn.Module):
30
+ def __init__(self, num_classes=5, num_leads=12, output_size=128):
31
+ super().__init__()
32
+ self.encoder = nn.Sequential(
33
+ nn.Conv1d(num_leads, 32, kernel_size=7, padding=3),
34
+ nn.BatchNorm1d(32),
35
+ nn.ReLU(),
36
+ nn.MaxPool1d(2),
37
+ nn.Conv1d(32, 64, kernel_size=5, padding=2),
38
+ nn.BatchNorm1d(64),
39
+ nn.ReLU(),
40
+ nn.MaxPool1d(2),
41
+ nn.Conv1d(64, 128, kernel_size=3, padding=1),
42
+ nn.BatchNorm1d(128),
43
+ nn.ReLU(),
44
+ nn.AdaptiveAvgPool1d(1),
45
+ nn.Flatten(),
46
+ nn.Linear(128, output_size),
47
+ )
48
+ self.classifier = nn.Linear(output_size, num_classes)
49
+
50
+ def forward(self, x):
51
+ embeddings = self.encoder(x)
52
+ logits = self.classifier(embeddings)
53
+ return logits
54
+
55
  # Load model
56
  model = None
57
  try:
58
  print("Loading model from Hub...")
59
+ model = ECGClassifier(num_classes=len(CLASS_LABELS), num_leads=12, output_size=128)
60
+
61
+ # Download weights from Hub
62
+ weights_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
63
+
64
+ # Load safetensors
65
+ from safetensors.torch import load_file
66
+ state_dict = load_file(weights_path)
67
+
68
+ # Load weights into model
69
+ model.load_state_dict(state_dict, strict=False)
70
  model.to(device)
71
  model.eval()
72
+ print("Model loaded successfully")
73
  except Exception as e:
74
+ print(f"Error loading model: {e}")
75
+ import traceback
76
+ traceback.print_exc()
77
 
78
  def predict_ecg(file_obj):
79
  """Main prediction function"""
 
135
 
136
  # Predict
137
  with torch.no_grad():
138
+ logits = model(x)[0].cpu().numpy()
 
139
  probs = torch.softmax(torch.tensor(logits), dim=0).numpy()
140
 
141
  # Get prediction