Spaces:
Running
Running
| # data.py | |
| import os, time, logging, requests | |
| from typing import Any, Dict, List, Optional, Tuple | |
# Configure root logging once at import: INFO level, "LEVEL message" lines.
logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)
log = logging.getLogger("replicate-catalog")

# The API token is mandatory; importing this module fails when it is absent.
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    raise RuntimeError("Missing token: set Space secret REPLICATE_API_TOKEN.")

BASE = "https://api.replicate.com/v1"
HEADERS = {
    "Authorization": f"Token {REPLICATE_API_TOKEN}",
    "Accept": "application/json",
}
TIMEOUT = 30  # handed to requests.get(timeout=...)
BASE_DELAY_SEC = 0.15  # initial pause slept before every request

# Module-level catalog cache, filled page by page from the models listing.
_pages: List[Dict[str, Any]] = []  # [{"results":[...], "next": str|None}, ...]
_next_cursor_url: Optional[str] = BASE + "/models"  # None/empty once exhausted
_BUILDING = False  # True while build_all_pages() is running
_BUILD_COMPLETE = False  # True after build_all_pages() walked every page
def get_state() -> Dict[str, Any]:
    """Return a snapshot of the catalog cache's progress.

    Keys:
        pages: number of pages cached so far.
        more: whether a next-page cursor URL is still pending.
        building: current value of the in-progress flag.
        complete: current value of the build-complete flag.
    """
    snapshot: Dict[str, Any] = {}
    snapshot["pages"] = len(_pages)
    snapshot["more"] = bool(_next_cursor_url)
    snapshot["building"] = _BUILDING
    snapshot["complete"] = _BUILD_COMPLETE
    return snapshot
def get_pages() -> List[Dict[str, Any]]:
    """Return the cached page list.

    The module's own list object is returned (not a copy), so mutations by
    the caller are visible to every other function in this module.
    """
    cached_pages = _pages
    return cached_pages
def is_complete() -> bool:
    """Return the current value of the build-complete flag."""
    finished = _BUILD_COMPLETE
    return finished
def _get_with_backoff(url: str) -> Dict[str, Any]:
    """GET ``url`` and return its decoded JSON body, retrying on HTTP 429.

    Up to six attempts are made. The current delay is slept before *every*
    attempt, including the first one.

    Args:
        url: Absolute URL to request with the module's auth headers.

    Returns:
        The JSON-decoded response body of the first 200 response.

    Raises:
        RuntimeError: on any status other than 200 or 429, or when all six
            attempts were answered with 429.
    """
    wait = BASE_DELAY_SEC
    for attempt_no in range(1, 7):
        time.sleep(wait)
        resp = requests.get(url, headers=HEADERS, timeout=TIMEOUT)
        status = resp.status_code

        if status == 200:
            return resp.json()
        if status != 429:
            raise RuntimeError(f"Replicate error {status}: {resp.text[:400]}")

        # Rate limited: use the server's hint when it is a plain integer,
        # otherwise double the delay, capped at 8 seconds.
        retry_after = resp.headers.get("Retry-After")
        if retry_after and retry_after.isdigit():
            wait = max(wait, float(retry_after))
        else:
            wait = min(wait * 2, 8.0)
        log.info("429; backoff %.2fs (attempt %d)", wait, attempt_no)

    raise RuntimeError("Too many retries after 429 responses.")
def ensure_page_loaded(target_index: int) -> None:
    """Fetch pages in order until page ``target_index`` is cached.

    Stops early when there is no next-page cursor left. Each fetched page is
    appended to the cache and the cursor is advanced to the payload's
    ``next`` value.

    Args:
        target_index: Zero-based index of the page that should be available.
    """
    global _next_cursor_url
    while True:
        already_cached = len(_pages) > target_index
        if already_cached or not _next_cursor_url:
            return
        data = _get_with_backoff(_next_cursor_url)
        following = data.get("next")
        _pages.append({"results": data.get("results") or [], "next": following})
        _next_cursor_url = following
def build_all_pages() -> None:
    """Walk the paginated listing to its end and cache every page.

    Returns immediately when a build has already completed or the
    in-progress flag is set. The in-progress flag is raised for the duration
    of the walk and always cleared afterwards; the complete flag is set only
    when the walk finishes without an exception.

    NOTE(review): the flag check is not guarded by a lock here - confirm
    whether callers can invoke this concurrently.
    """
    global _BUILDING, _BUILD_COMPLETE, _next_cursor_url
    if _BUILDING or _BUILD_COMPLETE:
        return

    _BUILDING = True
    try:
        # Make sure the first page is present before following cursors.
        if _next_cursor_url and not _pages:
            ensure_page_loaded(0)

        while _next_cursor_url:
            data = _get_with_backoff(_next_cursor_url)
            following = data.get("next")
            _pages.append({"results": data.get("results") or [], "next": following})
            _next_cursor_url = following

        _BUILD_COMPLETE = True
        log.info("Catalog build complete. Total pages: %d", len(_pages))
    finally:
        _BUILDING = False
def model_id(m: Dict[str, Any]) -> str:
    """Build the ``owner/name`` identifier for a model record.

    The owner is taken from ``"owner"``, falling back to ``"username"``.
    Missing or falsy parts become empty strings, and any leading or trailing
    slashes are stripped from the result.
    """
    who = m.get("owner")
    if not who:
        who = m.get("username") or ""
    what = m.get("name") or ""
    combined = "{}/{}".format(who, what)
    return combined.strip("/")
# Canonical display labels, keyed by lower-cased, hyphen-separated slug.
_LICENSE_LABELS: Dict[str, str] = {
    "apache-2.0": "Apache-2.0",
    "mit": "MIT",
    "gpl-3.0": "GPL-3.0",
    "agpl-3.0": "AGPL-3.0",
    "bsd-3-clause": "BSD-3-Clause",
    "cc-by-4.0": "CC-BY-4.0",
    "cc-by-nc-4.0": "CC-BY-NC-4.0",
}


def license_label_and_url(m: Dict[str, Any]) -> Tuple[str, Optional[str]]:
    """Derive a display label (and link, if any) for a model's license.

    When ``m["license_url"]`` is set, the label comes from the last path
    segment of that URL: it is percent-decoded, a trailing ``.html``/``.md``
    is dropped, and well-known identifiers are mapped to their canonical
    spelling. The ``"license"`` text is ignored in that case. Without a URL,
    the stripped ``"license"`` text is used, falling back to ``"Unknown"``.

    Args:
        m: Model record; only ``"license_url"`` and ``"license"`` are read.

    Returns:
        ``(label, url)`` where ``url`` is the original ``license_url`` value
        or ``None`` when the record has no license URL.
    """
    import re
    import urllib.parse

    url = m.get("license_url")
    if url:
        try:
            segment = urllib.parse.urlparse(url).path.rstrip("/").split("/")[-1]
            # Decode escapes such as %20 so they neither leak into the label
            # nor defeat the lookup below.
            segment = urllib.parse.unquote(segment).strip()
            segment = re.sub(r"(?i)\.(html|md)$", "", segment)
            # Treat "_" and " " like "-" so one table entry covers all forms.
            slug = segment.replace("_", "-").replace(" ", "-").lower()
            label = _LICENSE_LABELS.get(slug, segment or "License")
        except (AttributeError, TypeError, ValueError):
            # Best effort: a non-string or malformed URL still gets a label.
            label = "License"
        return (label, url)

    text = (m.get("license") or "").strip()
    if text:
        return (text, None)
    return ("Unknown", None)
def search_models(query: str) -> List[Dict[str, Any]]:
    """Return cached models whose text contains ``query`` (case-insensitive).

    The searched text is the model id, its description, and its tags (or
    categories when there are no tags), joined by spaces. Only pages already
    in the cache are searched. Each model id is reported at most once, in
    first-seen order. A blank or missing query yields an empty list.
    """
    needle = (query or "").strip().lower()
    if not needle:
        return []

    matches: List[Dict[str, Any]] = []
    seen_ids = set()
    for page in _pages:
        for model in page.get("results", []):
            ident = model_id(model)
            description = model.get("description") or ""
            labels = model.get("tags") or model.get("categories") or []
            label_text = " ".join(str(label) for label in labels)
            haystack = " ".join([ident, description, label_text]).lower()
            if needle not in haystack or ident in seen_ids:
                continue
            matches.append(model)
            seen_ids.add(ident)
    return matches
def flatten_models() -> List[Dict[str, Any]]:
    """Return every cached model as one flat list, in page order.

    Pages whose ``"results"`` entry is missing or falsy contribute nothing.
    """
    return [
        model
        for page in _pages
        for model in (page.get("results", []) or [])
    ]