Ajinkya commited on
Commit
16aa079
·
1 Parent(s): 61b0d5a

Use hf_hub to download model at runtime; remove local checkpoint from Space app

Browse files
Files changed (3) hide show
  1. app.py +28 -4
  2. model/iter_160000.pth +0 -3
  3. requirements.txt +1 -0
app.py CHANGED
@@ -4,10 +4,12 @@ import torch
4
  import gradio as gr
5
 
6
  from mmseg.apis import init_model, inference_model
 
7
 
8
- # Paths to model assets inside the repo
9
- CONFIG_PATH = "model/segformer_mit-b5_8xb1-160k_pre-cityscapes_seaicergb0-1024x1024.py"
10
- CHECKPOINT_PATH = "model/iter_160000.pth"
 
11
 
12
  _model = None
13
  _palette = None
@@ -17,7 +19,29 @@ def _get_model():
17
  global _model, _palette
18
  if _model is None:
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
- _model = init_model(CONFIG_PATH, CHECKPOINT_PATH, device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  meta = getattr(_model, "dataset_meta", None) or {}
22
  _palette = meta.get("palette", None)
23
  if _palette is None:
 
4
  import gradio as gr
5
 
6
  from mmseg.apis import init_model, inference_model
7
+ from huggingface_hub import hf_hub_download
8
 
9
+ # Source model repository on the Hub (avoid storing large files in the Space)
10
+ REPO_ID = "triton7777/SeaIce"
11
+ CONFIG_FILENAME = "segformer_mit-b5_8xb1-160k_pre-cityscapes_seaicergb0-1024x1024.py"
12
+ CHECKPOINT_FILENAME = "iter_160000.pth"
13
 
14
  _model = None
15
  _palette = None
 
19
  global _model, _palette
20
  if _model is None:
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ # Prefer local files if present (for local dev), otherwise download from Hub
23
+ local_cfg = os.path.join("model", CONFIG_FILENAME)
24
+ local_ckpt = os.path.join("model", CHECKPOINT_FILENAME)
25
+
26
+ if os.path.exists(local_cfg) and os.path.exists(local_ckpt):
27
+ cfg_path = local_cfg
28
+ ckpt_path = local_ckpt
29
+ else:
30
+ def _download_any(name: str):
31
+ candidates = [name, os.path.join("model", name)]
32
+ last_err = None
33
+ for cand in candidates:
34
+ try:
35
+ return hf_hub_download(repo_id=REPO_ID, filename=cand)
36
+ except Exception as e: # pragma: no cover (runtime fallback)
37
+ last_err = e
38
+ continue
39
+ raise last_err
40
+
41
+ cfg_path = _download_any(CONFIG_FILENAME)
42
+ ckpt_path = _download_any(CHECKPOINT_FILENAME)
43
+
44
+ _model = init_model(cfg_path, ckpt_path, device=device)
45
  meta = getattr(_model, "dataset_meta", None) or {}
46
  _palette = meta.get("palette", None)
47
  if _palette is None:
model/iter_160000.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9782394b4c725ad13283db293d2fc232633a34ea1312c00ab8c1868379bb24d7
3
- size 1005871349
 
 
 
 
requirements.txt CHANGED
@@ -13,3 +13,4 @@ opencv-python-headless==4.8.1.78
13
  numpy==1.24.4
14
  gradio==4.12.0
15
  pillow>=9.4.0
 
 
13
  numpy==1.24.4
14
  gradio==4.12.0
15
  pillow>=9.4.0
16
+ huggingface_hub>=0.20.0