Spaces:
Running
Running
File size: 4,680 Bytes
# data.py
import os, time, logging, requests
from typing import Any, Dict, List, Optional, Tuple
# Configure logging at import time; every record is rendered as "<LEVEL> <message>".
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger("replicate-catalog")

# The API token is read from the environment; importing this module fails
# with RuntimeError when it is unset or empty.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    raise RuntimeError("Missing token: set Space secret REPLICATE_API_TOKEN.")

# Root URL of the API and the headers sent with every request.
BASE = "https://api.replicate.com/v1"
HEADERS = {"Authorization": f"Token {REPLICATE_API_TOKEN}", "Accept": "application/json"}
# Passed as the `timeout` argument of each requests.get call.
TIMEOUT = 30
# Initial pause slept before each request; grows when a 429 is received.
BASE_DELAY_SEC = 0.15

# In-memory catalog cache: one entry per fetched page, in fetch order.
_pages: List[Dict[str, Any]] = [] # [{"results":[...], "next": str|None}, ...]
# URL of the next page to fetch; becomes falsy once the API reports no "next".
_next_cursor_url: Optional[str] = f"{BASE}/models"
# True while build_all_pages() is running.
_BUILDING = False
# True once build_all_pages() has fetched every page successfully.
_BUILD_COMPLETE = False
def get_state() -> Dict[str, Any]:
    """Return a snapshot of the catalog cache.

    Returns:
        A dict with the number of cached pages ("pages"), whether another
        page URL is still pending ("more"), and the current values of the
        build-in-progress ("building") and build-finished ("complete") flags.
    """
    page_count = len(_pages)
    has_more = bool(_next_cursor_url)
    return dict(
        pages=page_count,
        more=has_more,
        building=_BUILDING,
        complete=_BUILD_COMPLETE,
    )
def get_pages() -> List[Dict[str, Any]]:
    """Return the cached pages.

    The returned list is the module's own cache object, not a copy, so
    mutations made by the caller are visible to every other function here.
    """
    cached_pages = _pages
    return cached_pages
def is_complete() -> bool:
    """Return the current value of the build-complete flag."""
    finished = _BUILD_COMPLETE
    return finished
def _get_with_backoff(url: str) -> Dict[str, Any]:
    """GET ``url`` and return its decoded JSON body, retrying on HTTP 429.

    Up to six requests are made. Each one is preceded by a sleep that starts
    at ``BASE_DELAY_SEC``. After a 429 the sleep is raised to the integer
    ``Retry-After`` header value when one is present (never lowered), and is
    otherwise doubled with an 8-second ceiling.

    Args:
        url: Absolute URL to request.

    Returns:
        The parsed JSON payload of the first 200 response.

    Raises:
        RuntimeError: On any status other than 200 or 429, or when all six
            attempts were answered with 429.
    """
    max_attempts = 6
    wait = BASE_DELAY_SEC
    attempt_no = 0
    while attempt_no < max_attempts:
        attempt_no += 1
        # The pause also applies to the very first request.
        time.sleep(wait)
        resp = requests.get(url, headers=HEADERS, timeout=TIMEOUT)
        status = resp.status_code
        if status == 200:
            return resp.json()
        if status != 429:
            raise RuntimeError(f"Replicate error {resp.status_code}: {resp.text[:400]}")

        # Rate limited: work out how long to wait before the next attempt.
        retry_after = resp.headers.get("Retry-After")
        if retry_after and retry_after.isdigit():
            wait = max(wait, float(retry_after))
        else:
            wait = min(wait * 2, 8.0)
        log.info("429; backoff %.2fs (attempt %d)", wait, attempt_no)
    raise RuntimeError("Too many retries after 429 responses.")
def ensure_page_loaded(target_index: int) -> None:
    """Fetch pages until the cache holds index ``target_index``.

    Stops early when no further page URL is available, so the cache may
    still be shorter than ``target_index + 1`` afterwards.

    Args:
        target_index: Zero-based index of the page that should be cached.
    """
    global _next_cursor_url
    while True:
        already_cached = len(_pages) > target_index
        if already_cached or not _next_cursor_url:
            return
        data = _get_with_backoff(_next_cursor_url)
        following = data.get("next")
        entry = {"results": data.get("results", []) or [], "next": following}
        _pages.append(entry)
        _next_cursor_url = following
def build_all_pages() -> None:
    """Fetch every remaining catalog page into the cache.

    Does nothing when a build is already running or has already finished.
    The build-complete flag is set only after the last page was fetched
    without error; the building flag is always cleared on exit.
    """
    global _BUILDING, _BUILD_COMPLETE
    if _BUILDING or _BUILD_COMPLETE:
        return

    _BUILDING = True
    try:
        # Asking for index len(_pages) loads exactly one more page per pass.
        while _next_cursor_url:
            ensure_page_loaded(len(_pages))
        _BUILD_COMPLETE = True
        log.info("Catalog build complete. Total pages: %d", len(_pages))
    finally:
        _BUILDING = False
def model_id(m: Dict[str, Any]) -> str:
    """Build the ``owner/name`` identifier of a model record.

    The owner comes from the "owner" key, falling back to "username".
    Missing or falsy parts are treated as empty, and leading or trailing
    slashes are stripped from the result.

    Args:
        m: Model record.

    Returns:
        The identifier, or "" when neither part is present.
    """
    owner_part = m.get("owner")
    if not owner_part:
        owner_part = m.get("username") or ""
    name_part = m.get("name") or ""
    combined = "{}/{}".format(owner_part, name_part)
    return combined.strip("/")
# Canonical display names for license identifiers found in license URLs.
# Keys are lower-case with underscores already converted to hyphens.
_LICENSE_LABELS: Dict[str, str] = {
    "apache-2.0": "Apache-2.0", "apache 2.0": "Apache-2.0",
    "mit": "MIT",
    "gpl-3.0": "GPL-3.0", "gpl 3.0": "GPL-3.0",
    "agpl-3.0": "AGPL-3.0", "agpl 3.0": "AGPL-3.0",
    "bsd-3-clause": "BSD-3-Clause", "bsd 3 clause": "BSD-3-Clause",
    "cc-by-4.0": "CC-BY-4.0", "cc by 4.0": "CC-BY-4.0",
    "cc-by-nc-4.0": "CC-BY-NC-4.0", "cc by nc 4.0": "CC-BY-NC-4.0",
}


def license_label_and_url(m: Dict[str, Any]) -> Tuple[str, Optional[str]]:
    """Derive a display label and link for a model's license.

    When the record has a "license_url", the label is taken from the last
    path segment of that URL: the segment is percent-decoded, a trailing
    ".html" or ".md" is removed, and well-known identifiers are mapped to
    their canonical spelling. Without a URL the stripped "license" text is
    used.

    Args:
        m: Model record.

    Returns:
        ``(label, url)`` when a license URL is present, where the label is
        "License" if nothing better can be derived; ``(text, None)`` when
        only license text is present; ``("Unknown", None)`` otherwise.
    """
    import re
    import urllib.parse as up

    url = m.get("license_url")
    if url:
        label = "License"
        # Non-string URLs cannot be parsed; keep the generic label for them.
        if isinstance(url, str):
            try:
                path = up.urlparse(url).path
            except ValueError:
                # urlparse rejects some malformed URLs (e.g. a bad IPv6 host).
                path = ""
            # Decode percent-escapes so "Apache%202.0" matches "apache 2.0"
            # and fallback labels are shown in readable form.
            last = up.unquote(path.rstrip("/").split("/")[-1])
            last = re.sub(r"(?i)\.(html|md)$", "", last)
            key = last.replace("_", "-").lower()
            label = _LICENSE_LABELS.get(key, last or "License")
        return (label, url)

    txt = (m.get("license") or "").strip()
    if txt:
        return (txt, None)
    return ("Unknown", None)
def search_models(query: str) -> List[Dict[str, Any]]:
    """Return cached models whose text contains ``query``.

    Matching is a case-insensitive substring test against the model id, its
    description and its tags (or categories when there are no tags). Each
    model id is returned at most once, in cache order.

    Args:
        query: Search text; blank or None yields an empty result.

    Returns:
        The matching model records.
    """
    needle = (query or "").strip().lower()
    if not needle:
        return []

    matches: List[Dict[str, Any]] = []
    emitted = set()
    for page in _pages:
        for model in page.get("results", []):
            ident = model_id(model)
            description = model.get("description") or ""
            labels = model.get("tags") or model.get("categories") or []
            label_text = " ".join([str(label) for label in labels])
            haystack = " ".join([ident, description, label_text]).lower()
            # Skip non-matches and ids that were already returned.
            if needle not in haystack or ident in emitted:
                continue
            matches.append(model)
            emitted.add(ident)
    return matches
def flatten_models() -> List[Dict[str, Any]]:
    """Return every cached model record as one flat list, in page order."""
    flat: List[Dict[str, Any]] = [
        model
        for page in _pages
        for model in (page.get("results", []) or [])
    ]
    return flat