wi-lab commited on
Commit
515eeef
·
1 Parent(s): 0205171

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py CHANGED
@@ -202,6 +202,44 @@ def load_predictor() -> MoEPredictor:
202
  hf_hub.hf_hub_download(repo_id="wi-lab/lwm-spectro", filename="moe_checkpoint.pth")
203
  )
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  _predictor = MoEPredictor.from_checkpoint(ckpt_path)
206
  return _predictor
207
 
 
202
  hf_hub.hf_hub_download(repo_id="wi-lab/lwm-spectro", filename="moe_checkpoint.pth")
203
  )
204
 
205
+ # Ensure expert checkpoints are resolvable in the Space (paths inside ckpt are absolute)
206
+ def ensure_expert(name: str) -> Path:
207
+ """Return a local path to the expert checkpoint, downloading if needed."""
208
+ local_candidates = [
209
+ REPO_ROOT / "experts" / name,
210
+ REPO_ROOT / name,
211
+ ]
212
+ for cand in local_candidates:
213
+ if cand.exists():
214
+ return cand
215
+ # Download from model repo
216
+ downloaded = hf_hub.hf_hub_download(
217
+ repo_id="wi-lab/lwm-spectro",
218
+ filename=f"experts/{name}",
219
+ )
220
+ return Path(downloaded)
221
+
222
+ # Rewrite expert paths into a temp checkpoint so MoEPredictor loads cleanly
223
+ import torch # local import to keep top import list compact
224
+
225
+ raw_ckpt = torch.load(ckpt_path, map_location="cpu")
226
+ experts = raw_ckpt.get("experts", [])
227
+ if experts:
228
+ patched = False
229
+ for expert in experts:
230
+ ckpt_field = expert.get("checkpoint")
231
+ if not ckpt_field:
232
+ continue
233
+ fname = Path(ckpt_field).name
234
+ local_path = ensure_expert(fname)
235
+ if str(local_path) != ckpt_field:
236
+ expert["checkpoint"] = str(local_path)
237
+ patched = True
238
+ if patched:
239
+ tmp_path = Path("/tmp/moe_checkpoint_patched.pth")
240
+ torch.save(raw_ckpt, tmp_path)
241
+ ckpt_path = tmp_path
242
+
243
  _predictor = MoEPredictor.from_checkpoint(ckpt_path)
244
  return _predictor
245