mahmoud611 commited on
Commit
435a22c
·
verified ·
1 Parent(s): 2519750

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +17 -3
inference.py CHANGED
@@ -309,10 +309,24 @@ def _load_cnn_model():
309
 
310
  weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
311
  "weights", "cnn_heart_classifier.pt")
 
 
312
  if not os.path.exists(weights_path):
313
- print("CNN weights not found — CNN disabled", flush=True)
314
- _cnn_available = False
315
- return False
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  model = HeartSoundCNN()
318
  model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))
 
309
 
310
  weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
311
  "weights", "cnn_heart_classifier.pt")
312
+
313
+ # If weights not found locally, try downloading from HF repo
314
  if not os.path.exists(weights_path):
315
+ try:
316
+ from huggingface_hub import hf_hub_download
317
+ print("Downloading CNN weights from HF...", flush=True)
318
+ os.makedirs(os.path.dirname(weights_path), exist_ok=True)
319
+ hf_hub_download(
320
+ repo_id="mahmoud611/cardioscreen-api",
321
+ filename="weights/cnn_heart_classifier.pt",
322
+ repo_type="space",
323
+ local_dir=os.path.dirname(os.path.dirname(weights_path)),
324
+ )
325
+ print("CNN weights downloaded ✓", flush=True)
326
+ except Exception as dl_err:
327
+ print(f"CNN weights not found and download failed: {dl_err}", flush=True)
328
+ _cnn_available = False
329
+ return False
330
 
331
  model = HeartSoundCNN()
332
  model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))