File size: 4,680 Bytes
e8332fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# data.py
import os, time, logging, requests
from typing import Any, Dict, List, Optional, Tuple

# Configure root logging at import time so the INFO-level progress messages
# emitted below (429 backoff notices, build completion) are visible.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger("replicate-catalog")

# The API token is mandatory: fail fast at import time instead of on the
# first HTTP request.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    raise RuntimeError("Missing token: set Space secret REPLICATE_API_TOKEN.")

BASE = "https://api.replicate.com/v1"  # Replicate REST API root
HEADERS = {"Authorization": f"Token {REPLICATE_API_TOKEN}", "Accept": "application/json"}
TIMEOUT = 30  # seconds; passed as `timeout=` to every requests.get call
BASE_DELAY_SEC = 0.15  # pause before every request, and the initial backoff delay

# ---- Process-wide catalog cache (module-level mutable state) ----
# Pages fetched so far, in API order: [{"results": [...], "next": str|None}, ...]
_pages: List[Dict[str, Any]] = []
# URL of the next page to fetch; becomes falsy once the API reports no "next".
_next_cursor_url: Optional[str] = f"{BASE}/models"
_BUILDING: bool = False  # True only while build_all_pages() is running
_BUILD_COMPLETE: bool = False  # set once build_all_pages() exhausts the cursor

def get_state() -> Dict[str, Any]:
    """Summarise the catalog cache for status displays.

    Keys, in order: ``pages`` (number of fetched pages), ``more`` (whether a
    next-page cursor is still pending), ``building`` and ``complete`` (the
    full-build flags).
    """
    state = dict(
        pages=len(_pages),
        more=bool(_next_cursor_url),
        building=_BUILDING,
        complete=_BUILD_COMPLETE,
    )
    return state

def get_pages() -> List[Dict[str, Any]]:
    """Return the live cache list itself (not a copy), so callers share it."""
    cached = _pages
    return cached

def is_complete() -> bool:
    """Report whether a full catalog build has run to the last page."""
    finished = _BUILD_COMPLETE
    return finished

def _get_with_backoff(url: str, max_attempts: int = 6) -> Dict[str, Any]:
    """GET ``url`` with the Replicate auth headers and return the parsed JSON body.

    Every attempt is preceded by a sleep: ``BASE_DELAY_SEC`` before the first
    try (spacing out consecutive page fetches), then the current backoff delay.
    Retried with exponential backoff capped at 8 s:

    * HTTP 429 - a numeric ``Retry-After`` header raises the wait to at least
      that many seconds (HTTP-date values fall back to doubling);
    * HTTP 502/503/504 - transient gateway/availability errors;
    * ``requests.ConnectionError`` / ``requests.Timeout`` - network blips.

    Args:
        url: absolute URL to fetch.
        max_attempts: total number of requests to try before giving up.

    Returns:
        The decoded JSON payload of the first HTTP 200 response.

    Raises:
        RuntimeError: on any non-retryable status, on a transient status at the
            final attempt, or when every attempt was rate-limited (429).
        requests.ConnectionError, requests.Timeout: if the final attempt fails
            at the network level.
    """
    transient_statuses = (502, 503, 504)
    delay = BASE_DELAY_SEC
    for attempt in range(1, max_attempts + 1):
        time.sleep(delay)
        final = attempt == max_attempts
        try:
            r = requests.get(url, headers=HEADERS, timeout=TIMEOUT)
        except (requests.ConnectionError, requests.Timeout) as exc:
            if final:
                raise
            delay = min(delay * 2, 8.0)
            log.info("%s; backoff %.2fs (attempt %d)", type(exc).__name__, delay, attempt)
            continue
        if r.status_code == 200:
            return r.json()
        if r.status_code == 429:
            # Only the delay-seconds form of Retry-After is honoured here.
            ra = r.headers.get("Retry-After")
            delay = max(delay, float(ra)) if ra and ra.isdigit() else min(delay * 2, 8.0)
            log.info("429; backoff %.2fs (attempt %d)", delay, attempt)
            continue
        if r.status_code in transient_statuses and not final:
            delay = min(delay * 2, 8.0)
            log.info("%d; backoff %.2fs (attempt %d)", r.status_code, delay, attempt)
            continue
        raise RuntimeError(f"Replicate error {r.status_code}: {r.text[:400]}")
    # Only reachable when the last attempt was a 429 (or max_attempts <= 0).
    raise RuntimeError("Too many retries after 429 responses.")

def ensure_page_loaded(target_index: int) -> None:
    """Fetch pages in order until ``_pages[target_index]`` exists.

    Stops early once no next-page cursor remains, so the requested index may
    still be out of range when this returns. Each fetched page is stored as
    ``{"results": [...], "next": <cursor>}`` and the cursor is advanced.
    """
    global _next_cursor_url
    while True:
        if len(_pages) > target_index or not _next_cursor_url:
            break
        body = _get_with_backoff(_next_cursor_url)
        cursor = body.get("next")
        _pages.append({"results": body.get("results", []) or [], "next": cursor})
        _next_cursor_url = cursor

def build_all_pages() -> None:
    """Fetch every remaining catalog page into the module cache.

    Returns immediately if a build already completed or is currently marked
    as running. ``_BUILDING`` is cleared in ``finally`` even if a fetch
    raises; ``_BUILD_COMPLETE`` is set only after the cursor is exhausted.
    """
    global _BUILDING, _BUILD_COMPLETE, _next_cursor_url
    if _BUILDING or _BUILD_COMPLETE:
        return
    _BUILDING = True
    try:
        # Seed the first page through the lazy loader when nothing is cached.
        if _next_cursor_url and not _pages:
            ensure_page_loaded(0)
        while _next_cursor_url:
            body = _get_with_backoff(_next_cursor_url)
            cursor = body.get("next")
            _pages.append({"results": body.get("results", []) or [], "next": cursor})
            _next_cursor_url = cursor
        _BUILD_COMPLETE = True
        log.info("Catalog build complete. Total pages: %d", len(_pages))
    finally:
        _BUILDING = False

def model_id(m: Dict[str, Any]) -> str:
    """Return the ``owner/name`` slug for a model record.

    The owner is taken from ``owner``, falling back to ``username``. Missing
    parts count as empty, and leading/trailing slashes are trimmed, so a
    record with only a name yields just the name.
    """
    who = m.get("owner") or m.get("username") or ""
    what = m.get("name") or ""
    slug = "{}/{}".format(who, what)
    return slug.strip("/")

# Canonical SPDX-style labels, keyed by the normalised (lower-cased,
# "_" -> "-") last path segment of a license URL.
_LICENSE_LABELS: Dict[str, str] = {
    "apache-2.0": "Apache-2.0", "apache 2.0": "Apache-2.0",
    "mit": "MIT",
    "gpl-3.0": "GPL-3.0", "gpl 3.0": "GPL-3.0",
    "agpl-3.0": "AGPL-3.0", "agpl 3.0": "AGPL-3.0",
    "bsd-3-clause": "BSD-3-Clause", "bsd 3 clause": "BSD-3-Clause",
    "cc-by-4.0": "CC-BY-4.0", "cc by 4.0": "CC-BY-4.0",
    "cc-by-nc-4.0": "CC-BY-NC-4.0", "cc by nc 4.0": "CC-BY-NC-4.0",
}


def license_label_and_url(m: Dict[str, Any]) -> Tuple[str, Optional[str]]:
    """Derive a short display label and an optional link for a model's license.

    Resolution order:

    1. ``license_url`` present: the label comes from the URL's last path
       segment. The segment is percent-decoded and any ``.html``/``.md``
       suffix is dropped. It is then mapped through ``_LICENSE_LABELS``, or
       used as-is when unknown. If no segment can be extracted, or the URL
       cannot be parsed, the ``license`` text is used, else ``"License"``.
    2. Otherwise, the stripped ``license`` text, with no URL.
    3. Otherwise, ``("Unknown", None)``.

    Returns:
        ``(label, url)`` where ``url`` is ``license_url`` or ``None``.
    """
    import re
    import urllib.parse

    url = m.get("license_url")
    raw_txt = m.get("license")
    txt = raw_txt.strip() if isinstance(raw_txt, str) else ""

    if not url:
        return (txt, None) if txt else ("Unknown", None)

    fallback = txt or "License"
    try:
        last = urllib.parse.urlparse(url).path.rstrip("/").split("/")[-1]
    except (ValueError, TypeError, AttributeError):
        # Malformed URL (e.g. bad IPv6 netloc): keep the link, skip the label.
        return (fallback, url)
    # Percent-decode so e.g. "Apache%202.0" can match the space-separated keys.
    last = re.sub(r"(?i)\.(html|md)$", "", urllib.parse.unquote(last))
    key = last.replace("_", "-").lower()
    return (_LICENSE_LABELS.get(key, last or fallback), url)

def search_models(query: str) -> List[Dict[str, Any]]:
    """Case-insensitive substring search over every cached model.

    A model matches when the stripped, lower-cased query occurs in the
    lower-cased space-joined text of its ``owner/name`` id, description, and
    tags (or categories when tags are empty). Results keep cache order and
    are de-duplicated by model id. A blank query returns an empty list.
    """
    needle = (query or "").strip().lower()
    if needle == "":
        return []
    matches: List[Dict[str, Any]] = []
    seen_ids = set()
    for record in (m for page in _pages for m in page.get("results", [])):
        rid = model_id(record)
        labels = record.get("tags") or record.get("categories") or []
        haystack = " ".join(
            (rid, record.get("description") or "", " ".join(map(str, labels)))
        ).lower()
        if needle in haystack and rid not in seen_ids:
            matches.append(record)
            seen_ids.add(rid)
    return matches

def flatten_models() -> List[Dict[str, Any]]:
    """Return a new list holding every cached page's ``results``, in order.

    Pages whose ``results`` is missing or ``None`` contribute nothing.
    """
    return [
        model
        for page in _pages
        for model in (page.get("results", []) or [])
    ]