Spaces:
Sleeping
Sleeping
Upload inference.py with huggingface_hub
Browse files- inference.py +17 -3
inference.py
CHANGED
|
@@ -309,10 +309,24 @@ def _load_cnn_model():
|
|
| 309 |
|
| 310 |
weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
| 311 |
"weights", "cnn_heart_classifier.pt")
|
|
|
|
|
|
|
| 312 |
if not os.path.exists(weights_path):
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
model = HeartSoundCNN()
|
| 318 |
model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))
|
|
|
|
| 309 |
|
| 310 |
weights_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
| 311 |
"weights", "cnn_heart_classifier.pt")
|
| 312 |
+
|
| 313 |
+
# If weights not found locally, try downloading from HF repo
|
| 314 |
if not os.path.exists(weights_path):
|
| 315 |
+
try:
|
| 316 |
+
from huggingface_hub import hf_hub_download
|
| 317 |
+
print("Downloading CNN weights from HF...", flush=True)
|
| 318 |
+
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
|
| 319 |
+
hf_hub_download(
|
| 320 |
+
repo_id="mahmoud611/cardioscreen-api",
|
| 321 |
+
filename="weights/cnn_heart_classifier.pt",
|
| 322 |
+
repo_type="space",
|
| 323 |
+
local_dir=os.path.dirname(os.path.dirname(weights_path)),
|
| 324 |
+
)
|
| 325 |
+
print("CNN weights downloaded ✓", flush=True)
|
| 326 |
+
except Exception as dl_err:
|
| 327 |
+
print(f"CNN weights not found and download failed: {dl_err}", flush=True)
|
| 328 |
+
_cnn_available = False
|
| 329 |
+
return False
|
| 330 |
|
| 331 |
model = HeartSoundCNN()
|
| 332 |
model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))
|