sam-brause commited on
Commit
85c4bdf
·
1 Parent(s): 71e8593

Updated handler to download model from Hugging Face

Browse files
Files changed (1) hide show
  1. handler.py +22 -6
handler.py CHANGED
@@ -1,16 +1,32 @@
1
  import os
2
- import subprocess
3
  import torch
 
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
  import json
7
 
8
- # Fetch model from Git LFS if not downloaded
9
- MODEL_PATH = "model_scripted_efficientnet.pt"
 
 
10
 
11
- if not os.path.exists(MODEL_PATH):
12
- print("Model file not found! Fetching from Git LFS...")
13
- subprocess.run(["git", "lfs", "pull"], check=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Load Model
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
1
  import os
 
2
  import torch
3
+ import requests
4
  import torchvision.transforms as transforms
5
  from PIL import Image
6
  import json
7
 
8
+ # Hugging Face model URL
9
+ HF_REPO = "your-username/your-model-name"
10
+ MODEL_FILENAME = "model_scripted_efficientnet.pt"
11
+ MODEL_PATH = f"/repository/{MODEL_FILENAME}"
12
 
13
+ def download_model():
14
+ """Download model file from Hugging Face if it does not exist."""
15
+ if not os.path.exists(MODEL_PATH):
16
+ print("Model file not found! Downloading from Hugging Face...")
17
+ model_url = f"https://huggingface.co/{HF_REPO}/resolve/main/{MODEL_FILENAME}"
18
+ response = requests.get(model_url, stream=True)
19
+
20
+ if response.status_code == 200:
21
+ with open(MODEL_PATH, "wb") as f:
22
+ for chunk in response.iter_content(chunk_size=8192):
23
+ f.write(chunk)
24
+ print("Model downloaded successfully!")
25
+ else:
26
+ raise RuntimeError(f"Failed to download model, status code: {response.status_code}")
27
+
28
+ # Ensure model is downloaded before loading
29
+ download_model()
30
 
31
  # Load Model
32
  device = "cuda" if torch.cuda.is_available() else "cpu"