wi-lab commited on
Commit
80c177a
·
1 Parent(s): 515eeef

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +27 -10
app.py CHANGED
@@ -203,21 +203,37 @@ def load_predictor() -> MoEPredictor:
203
  )
204
 
205
  # Ensure expert checkpoints are resolvable in the Space (paths inside ckpt are absolute)
206
- def ensure_expert(name: str) -> Path:
207
  """Return a local path to the expert checkpoint, downloading if needed."""
 
 
208
  local_candidates = [
209
- REPO_ROOT / "experts" / name,
210
- REPO_ROOT / name,
 
 
211
  ]
212
  for cand in local_candidates:
213
  if cand.exists():
214
  return cand
215
- # Download from model repo
216
- downloaded = hf_hub.hf_hub_download(
217
- repo_id="wi-lab/lwm-spectro",
218
- filename=f"experts/{name}",
219
- )
220
- return Path(downloaded)
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  # Rewrite expert paths into a temp checkpoint so MoEPredictor loads cleanly
223
  import torch # local import to keep top import list compact
@@ -231,7 +247,8 @@ def load_predictor() -> MoEPredictor:
231
  if not ckpt_field:
232
  continue
233
  fname = Path(ckpt_field).name
234
- local_path = ensure_expert(fname)
 
235
  if str(local_path) != ckpt_field:
236
  expert["checkpoint"] = str(local_path)
237
  patched = True
 
203
  )
204
 
205
  # Ensure expert checkpoints are resolvable in the Space (paths inside ckpt are absolute)
206
+ def ensure_expert(name: str, comm: str) -> Path:
207
  """Return a local path to the expert checkpoint, downloading if needed."""
208
+ fname = Path(name).name
209
+ comm_tag = comm.replace("/", "_")
210
  local_candidates = [
211
+ REPO_ROOT / "experts" / fname,
212
+ REPO_ROOT / fname,
213
+ REPO_ROOT / "experts" / f"{comm_tag}_expert.pth",
214
+ REPO_ROOT / f"{comm_tag}_expert.pth",
215
  ]
216
  for cand in local_candidates:
217
  if cand.exists():
218
  return cand
219
+ # Download from model repo with multiple filename guesses
220
+ download_candidates = [
221
+ f"experts/{fname}",
222
+ f"experts/{comm_tag}_expert.pth",
223
+ fname,
224
+ ]
225
+ last_err = None
226
+ for rel in download_candidates:
227
+ try:
228
+ downloaded = hf_hub.hf_hub_download(
229
+ repo_id="wi-lab/lwm-spectro",
230
+ filename=rel,
231
+ )
232
+ return Path(downloaded)
233
+ except Exception as exc: # pragma: no cover - network/permissions issues
234
+ last_err = exc
235
+ continue
236
+ raise RuntimeError(f"Could not resolve expert checkpoint for {comm} ({fname}): {last_err}")
237
 
238
  # Rewrite expert paths into a temp checkpoint so MoEPredictor loads cleanly
239
  import torch # local import to keep top import list compact
 
247
  if not ckpt_field:
248
  continue
249
  fname = Path(ckpt_field).name
250
+ comm = expert.get("comm", "unknown")
251
+ local_path = ensure_expert(fname, comm)
252
  if str(local_path) != ckpt_field:
253
  expert["checkpoint"] = str(local_path)
254
  patched = True