davidquarel commited on
Commit
82662a0
verified
1 Parent(s): a173546

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -11,7 +11,10 @@ from vidshow import vidshow_dual_from_png_bytes, vidshow_from_png_bytes
11
  ACTION_PROBS_FILES = ("action_probs.tar.gz", "action_probs.zip")
12
  ACTION_PROBS_SUFFIXES = (".tar.gz", ".zip")
13
 
14
- JAXGMG_COLLECTION_SLUG = "davidquarel/jaxgmg-69528c4f23d35d7de3eecfa1"
 
 
 
15
  DEFAULT_REPO = "davidquarel/jaxgmg_3phase_seed"
16
  LOCAL_BASE_PATH = Path(__file__).parent.parent
17
  USE_LOCAL = os.environ.get("JAXGMG_USE_LOCAL", "").lower() in ("1", "true", "yes")
@@ -28,14 +31,20 @@ def get_hf_api() -> HfApi:
28
 
29
  @lru_cache(maxsize=1)
30
  def get_collection_repos() -> tuple[str, ...]:
31
- """Fetch all model repos from the jaxgmg collection."""
32
- try:
33
- collection = get_hf_api().get_collection(JAXGMG_COLLECTION_SLUG)
34
- repos = [item.item_id for item in collection.items if item.item_type == "model"]
35
- return tuple(repos) if repos else (DEFAULT_REPO,)
36
- except Exception as e:
37
- print(f"Warning: Could not fetch collection: {e}")
38
- return (DEFAULT_REPO,)
 
 
 
 
 
 
39
 
40
 
41
  def get_local_path(repo_id: str) -> Path:
@@ -366,7 +375,8 @@ def create_demo():
366
 
367
  gr.Markdown(
368
  f"<small>Source: {'Local' if USE_LOCAL else 'HuggingFace'} 路 "
369
- f"[Collection](https://huggingface.co/collections/davidquarel/jaxgmg)</small>",
 
370
  )
371
 
372
  repo_dropdown.change(
 
11
  ACTION_PROBS_FILES = ("action_probs.tar.gz", "action_probs.zip")
12
  ACTION_PROBS_SUFFIXES = (".tar.gz", ".zip")
13
 
14
+ JAXGMG_COLLECTION_SLUGS = (
15
+ "davidquarel/jaxgmg-69528c4f23d35d7de3eecfa1",
16
+ "timaeus/project-rl2-69bd40fed77a62a6eba28fc9",
17
+ )
18
  DEFAULT_REPO = "davidquarel/jaxgmg_3phase_seed"
19
  LOCAL_BASE_PATH = Path(__file__).parent.parent
20
  USE_LOCAL = os.environ.get("JAXGMG_USE_LOCAL", "").lower() in ("1", "true", "yes")
 
31
 
32
  @lru_cache(maxsize=1)
33
  def get_collection_repos() -> tuple[str, ...]:
34
+ """Fetch all model repos from all configured collections (deduplicated)."""
35
+ seen: set[str] = set()
36
+ repos: list[str] = []
37
+ api = get_hf_api()
38
+ for slug in JAXGMG_COLLECTION_SLUGS:
39
+ try:
40
+ collection = api.get_collection(slug)
41
+ for item in collection.items:
42
+ if item.item_type == "model" and item.item_id not in seen:
43
+ seen.add(item.item_id)
44
+ repos.append(item.item_id)
45
+ except Exception as e:
46
+ print(f"Warning: Could not fetch collection {slug}: {e}")
47
+ return tuple(repos) if repos else (DEFAULT_REPO,)
48
 
49
 
50
  def get_local_path(repo_id: str) -> Path:
 
375
 
376
  gr.Markdown(
377
  f"<small>Source: {'Local' if USE_LOCAL else 'HuggingFace'} 路 "
378
+ f"[davidquarel/jaxgmg](https://huggingface.co/collections/davidquarel/jaxgmg)"
379
+ f"[timaeus/project-rl2](https://huggingface.co/collections/timaeus/project-rl2)</small>",
380
  )
381
 
382
  repo_dropdown.change(