""" Download and verify RDBT weights; optional Hub files from config.ARTIFACTS. (Civitai RDBT + retries / min-size / .part streaming — parity with the original Comfy-based Space.) """ from __future__ import annotations import os import shutil import sys import time import requests from huggingface_hub import hf_hub_download from src import config from src.errors import UserFacingError def _ok_size(path: str) -> bool: name = os.path.basename(path) if not os.path.isfile(path): return False sz = os.path.getsize(path) return sz >= config.MIN_SIZES.get(name, 1_000_000) def _download_one(repo_id: str, repo_file: str, dest: str) -> None: dest_dir = os.path.dirname(dest) os.makedirs(dest_dir, mode=0o755, exist_ok=True) if _ok_size(dest): print(f"[bootstrap] skip (exists): {dest}", flush=True) return for attempt in range(1, config.MAX_RETRIES + 1): try: if os.path.isfile(dest): os.remove(dest) print( f"[bootstrap] {repo_id} {repo_file} -> {dest} (attempt {attempt}/{config.MAX_RETRIES})", flush=True, ) cached = hf_hub_download( repo_id=repo_id, filename=repo_file, repo_type="model", ) shutil.copy2(cached, dest) if not _ok_size(dest): raise RuntimeError(f"file too small after copy: {dest}") print(f"[bootstrap] ok: {dest}", flush=True) return except Exception as e: print(f"[bootstrap] error: {e}", file=sys.stderr, flush=True) if attempt >= config.MAX_RETRIES: raise delay = min(config.BACKOFF_CAP_S, 2**attempt) print(f"[bootstrap] retry in {delay}s...", flush=True) time.sleep(delay) def _download_url(url: str, dest: str) -> None: dest_dir = os.path.dirname(dest) os.makedirs(dest_dir, mode=0o755, exist_ok=True) if _ok_size(dest): print(f"[bootstrap] skip (exists): {dest}", flush=True) return part_path = dest + ".part" headers: dict[str, str] = {} token = os.environ.get("CIVITAI_TOKEN", "").strip() if token: headers["Authorization"] = f"Bearer {token}" chunk_size = 1024 * 1024 progress_interval = 256 * 1024 * 1024 for attempt in range(1, config.MAX_RETRIES + 1): try: if os.path.isfile(dest): os.remove(dest) if os.path.isfile(part_path): os.remove(part_path) print( f"[bootstrap] {url} -> {dest} (attempt {attempt}/{config.MAX_RETRIES})", flush=True, ) with requests.get( url, stream=True, allow_redirects=True, timeout=(10, 600), headers=headers or None, ) as r: r.raise_for_status() written = 0 next_log = progress_interval with open(part_path, "wb") as f: for chunk in r.iter_content(chunk_size=chunk_size): if chunk: f.write(chunk) written += len(chunk) if written >= next_log: mb = written // (1024 * 1024) print( f"[bootstrap] ... {mb} MiB downloaded", flush=True, ) next_log += progress_interval os.replace(part_path, dest) if not _ok_size(dest): raise RuntimeError(f"file too small after download: {dest}") print(f"[bootstrap] ok: {dest}", flush=True) return except Exception as e: print(f"[bootstrap] error: {e}", file=sys.stderr, flush=True) if os.path.isfile(part_path): try: os.remove(part_path) except OSError: pass if attempt >= config.MAX_RETRIES: raise delay = min(config.BACKOFF_CAP_S, 2**attempt) print(f"[bootstrap] retry in {delay}s...", flush=True) time.sleep(delay) def bootstrap_model_artifacts() -> None: """Download RDBT (Civitai) and any optional config.ARTIFACTS into model_artifacts_root().""" root = config.model_artifacts_root() for repo_id, hub_path, rel in config.ARTIFACTS: dest = os.path.join(root, rel) _download_one(repo_id, hub_path, dest) if config.skip_civitai(): print("[bootstrap] SKIP_CIVITAI=1: skipping Civitai downloads.", flush=True) else: url, rel = config.CIVITAI_RDBT dest = os.path.join(root, rel) _download_url(url, dest) print("[bootstrap] all model artifacts ready.", flush=True)