Spaces:
Running on Zero
Running on Zero
| """ | |
| Download and verify RDBT weights; optional Hub files from config.ARTIFACTS. | |
| (Civitai RDBT + retries / min-size / .part streaming — parity with the original Comfy-based Space.) | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import shutil | |
| import sys | |
| import time | |
| import requests | |
| from huggingface_hub import hf_hub_download | |
| from src import config | |
| from src.errors import UserFacingError | |
| def _ok_size(path: str) -> bool: | |
| name = os.path.basename(path) | |
| if not os.path.isfile(path): | |
| return False | |
| sz = os.path.getsize(path) | |
| return sz >= config.MIN_SIZES.get(name, 1_000_000) | |
def _download_one(repo_id: str, repo_file: str, dest: str) -> None:
    """Fetch one file from a Hugging Face model repo and place it at *dest*.

    Skips work when *dest* already passes ``_ok_size``. Otherwise the file is
    resolved with ``hf_hub_download`` (which goes through the HF cache), copied
    to ``dest + ".part"`` and renamed onto *dest*, so an interrupted copy never
    leaves a truncated file under the final name. Failed attempts are retried
    up to ``config.MAX_RETRIES`` times (at least once) with exponential backoff
    capped at ``config.BACKOFF_CAP_S`` seconds.

    Args:
        repo_id: Hub repository id of the model repo.
        repo_file: Path of the file inside the repo.
        dest: Final local path for the file.

    Raises:
        Exception: whatever the final attempt raised, once retries are exhausted.
    """
    dest_dir = os.path.dirname(dest)
    if dest_dir:  # os.makedirs("") raises, so only create a real parent dir
        os.makedirs(dest_dir, mode=0o755, exist_ok=True)
    if _ok_size(dest):
        print(f"[bootstrap] skip (exists): {dest}", flush=True)
        return
    part_path = dest + ".part"
    # Always make at least one attempt: with MAX_RETRIES <= 0 the loop would
    # otherwise not run and the function would return without a file or error.
    max_attempts = max(1, config.MAX_RETRIES)
    for attempt in range(1, max_attempts + 1):
        try:
            # Start each attempt from a clean slate (undersized or partial files).
            for stale in (dest, part_path):
                if os.path.isfile(stale):
                    os.remove(stale)
            print(
                f"[bootstrap] {repo_id} {repo_file} -> {dest} (attempt {attempt}/{max_attempts})",
                flush=True,
            )
            cached = hf_hub_download(
                repo_id=repo_id,
                filename=repo_file,
                repo_type="model",
            )
            # Copy beside the destination, then rename into place: os.replace is
            # atomic, so a crash mid-copy cannot leave a truncated file at `dest`
            # that a later _ok_size() check might accept.
            shutil.copy2(cached, part_path)
            os.replace(part_path, dest)
            if not _ok_size(dest):
                raise RuntimeError(f"file too small after copy: {dest}")
            print(f"[bootstrap] ok: {dest}", flush=True)
            return
        except Exception as e:
            print(f"[bootstrap] error: {e}", file=sys.stderr, flush=True)
            try:
                os.remove(part_path)
            except OSError:
                # Best-effort: the .part file may not exist; any leftover is
                # removed again at the start of the next attempt.
                pass
            if attempt >= max_attempts:
                raise
            delay = min(config.BACKOFF_CAP_S, 2**attempt)
            print(f"[bootstrap] retry in {delay}s...", flush=True)
            time.sleep(delay)
def _expected_body_length(resp) -> int | None:
    """Return the declared body size of *resp*, or None when it cannot be trusted.

    ``iter_content`` yields decoded bytes, so ``Content-Length`` only matches the
    number of bytes written when no content encoding (or ``identity``) is used.
    """
    encoding = resp.headers.get("Content-Encoding", "").strip().lower()
    if encoding not in ("", "identity"):
        return None
    raw = resp.headers.get("Content-Length")
    if raw is None:
        return None
    try:
        return int(raw)
    except ValueError:
        return None


def _stream_to_file(resp, path: str) -> int:
    """Write the streamed body of *resp* to *path* and return the bytes written.

    Reads 1 MiB chunks and logs progress roughly every 256 MiB.
    """
    chunk_size = 1024 * 1024
    progress_interval = 256 * 1024 * 1024
    written = 0
    next_log = progress_interval
    with open(path, "wb") as f:
        for chunk in resp.iter_content(chunk_size=chunk_size):
            if not chunk:  # empty chunks carry no data
                continue
            f.write(chunk)
            written += len(chunk)
            if written >= next_log:
                mb = written // (1024 * 1024)
                print(
                    f"[bootstrap] ... {mb} MiB downloaded",
                    flush=True,
                )
                next_log += progress_interval
    return written


def _download_url(url: str, dest: str) -> None:
    """Stream *url* to *dest*, retrying on failure.

    Skips work when *dest* already passes ``_ok_size``. Otherwise the body is
    streamed into ``dest + ".part"`` and renamed onto *dest* only after the
    transfer is complete (and, when the server declares it, the byte count
    matches ``Content-Length``). A non-blank ``CIVITAI_TOKEN`` environment
    variable is sent as a ``Bearer`` token. Attempts are retried up to
    ``config.MAX_RETRIES`` times (at least once) with exponential backoff
    capped at ``config.BACKOFF_CAP_S`` seconds.

    Raises:
        Exception: whatever the final attempt raised, e.g. a ``requests`` HTTP or
            connection error, or ``RuntimeError`` for a truncated/undersized file.
    """
    dest_dir = os.path.dirname(dest)
    if dest_dir:  # os.makedirs("") raises, so only create a real parent dir
        os.makedirs(dest_dir, mode=0o755, exist_ok=True)
    if _ok_size(dest):
        print(f"[bootstrap] skip (exists): {dest}", flush=True)
        return
    part_path = dest + ".part"
    headers: dict[str, str] = {}
    token = os.environ.get("CIVITAI_TOKEN", "").strip()
    if token:
        headers["Authorization"] = f"Bearer {token}"
    # Always make at least one attempt: with MAX_RETRIES <= 0 the loop would
    # otherwise not run and the function would return without a file or error.
    max_attempts = max(1, config.MAX_RETRIES)
    for attempt in range(1, max_attempts + 1):
        try:
            # Start each attempt from a clean slate (undersized or partial files).
            for stale in (dest, part_path):
                if os.path.isfile(stale):
                    os.remove(stale)
            print(
                f"[bootstrap] {url} -> {dest} (attempt {attempt}/{max_attempts})",
                flush=True,
            )
            with requests.get(
                url,
                stream=True,
                allow_redirects=True,
                timeout=(10, 600),  # (connect, read) timeouts in seconds
                headers=headers or None,
            ) as r:
                r.raise_for_status()
                expected = _expected_body_length(r)
                written = _stream_to_file(r, part_path)
            # A connection that closes early is not always reported as an error;
            # reject short bodies before the file gets its final name.
            if expected is not None and written != expected:
                raise RuntimeError(
                    f"incomplete download: got {written} of {expected} bytes for {dest}"
                )
            os.replace(part_path, dest)
            if not _ok_size(dest):
                raise RuntimeError(f"file too small after download: {dest}")
            print(f"[bootstrap] ok: {dest}", flush=True)
            return
        except Exception as e:
            print(f"[bootstrap] error: {e}", file=sys.stderr, flush=True)
            try:
                os.remove(part_path)
            except OSError:
                # Best-effort: the .part file may not exist; any leftover is
                # removed again at the start of the next attempt.
                pass
            if attempt >= max_attempts:
                raise
            delay = min(config.BACKOFF_CAP_S, 2**attempt)
            print(f"[bootstrap] retry in {delay}s...", flush=True)
            time.sleep(delay)
def bootstrap_model_artifacts() -> None:
    """Make sure every model file this app needs exists under the artifacts root.

    First fetches each ``(repo_id, hub_path, rel)`` entry of ``config.ARTIFACTS``
    from the Hugging Face Hub, then, unless ``config.skip_civitai()`` is true,
    the RDBT checkpoint described by the ``(url, rel)`` pair
    ``config.CIVITAI_RDBT``. Every ``rel`` is joined onto
    ``config.model_artifacts_root()``.
    """
    artifacts_root = config.model_artifacts_root()

    for hub_repo, file_in_repo, rel_path in config.ARTIFACTS:
        _download_one(hub_repo, file_in_repo, os.path.join(artifacts_root, rel_path))

    if not config.skip_civitai():
        rdbt_url, rdbt_rel = config.CIVITAI_RDBT
        _download_url(rdbt_url, os.path.join(artifacts_root, rdbt_rel))
    else:
        print("[bootstrap] SKIP_CIVITAI=1: skipping Civitai downloads.", flush=True)

    print("[bootstrap] all model artifacts ready.", flush=True)