TunisianCoder committed on
Commit
dc21fc5
·
verified ·
1 Parent(s): 32584ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -66,18 +66,22 @@ class SleepStageCNN(nn.Module):
66
  # ────────────────────────────────────────────────────────────────
67
  # Load Model at startup
68
  # ────────────────────────────────────────────────────────────────
69
-
70
  device = torch.device("cpu")
71
  model = SleepStageCNN(n_channels=1, n_classes=6)
72
-
73
  if os.path.exists(MODEL_PATH):
74
  checkpoint = torch.load(
75
  MODEL_PATH, map_location=device, weights_only=False
76
  )
77
  if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
78
- model.load_state_dict(checkpoint["model_state_dict"])
79
  else:
80
- model.load_state_dict(checkpoint)
 
 
 
 
 
 
81
  model.eval().to(device)
82
  print(f"✅ Model loaded from {MODEL_PATH}")
83
  else:
@@ -85,8 +89,6 @@ else:
85
  f"Model file not found at {MODEL_PATH}. "
86
  "Upload sleep_stage_cnn.pth to this Space."
87
  )
88
-
89
-
90
  # ────────────────────────────────────────────────────────────────
91
  # Inference Function
92
  # ────────────────────────────────────────────────────────────────
 
66
  # ────────────────────────────────────────────────────────────────
67
  # Load Model at startup
68
  # ────────────────────────────────────────────────────────────────
 
69
  device = torch.device("cpu")
70
  model = SleepStageCNN(n_channels=1, n_classes=6)
 
71
  if os.path.exists(MODEL_PATH):
72
  checkpoint = torch.load(
73
  MODEL_PATH, map_location=device, weights_only=False
74
  )
75
  if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
76
+ state_dict = checkpoint["model_state_dict"]
77
  else:
78
+ state_dict = checkpoint
79
+
80
+ # Remap bare Sequential keys (e.g. "0.weight") β†’ "network.0.weight"
81
+ if any(k.split(".")[0].isdigit() for k in state_dict.keys()):
82
+ state_dict = {"network." + k: v for k, v in state_dict.items()}
83
+
84
+ model.load_state_dict(state_dict)
85
  model.eval().to(device)
86
  print(f"✅ Model loaded from {MODEL_PATH}")
87
  else:
 
89
  f"Model file not found at {MODEL_PATH}. "
90
  "Upload sleep_stage_cnn.pth to this Space."
91
  )
 
 
92
  # ────────────────────────────────────────────────────────────────
93
  # Inference Function
94
  # ────────────────────────────────────────────────────────────────