Spaces:
Running on Zero
Running on Zero
File size: 5,154 Bytes
5befce1 fe74e7d 5befce1 fe74e7d 5befce1 fe74e7d 5befce1 fe74e7d 5befce1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | """
Download and verify the RDBT weights (from Civitai) plus any optional Hugging Face Hub files listed in config.ARTIFACTS.
(Civitai RDBT + retries / min-size / .part streaming — parity with the original Comfy-based Space.)
"""
from __future__ import annotations
import os
import shutil
import sys
import time
import requests
from huggingface_hub import hf_hub_download
from src import config
from src.errors import UserFacingError
def _ok_size(path: str) -> bool:
    """Report whether *path* is a regular file that meets its minimum size.

    The minimum is looked up in ``config.MIN_SIZES`` under the file's basename;
    names without an entry fall back to 1_000_000 bytes. Missing paths and
    non-regular files (e.g. directories) yield ``False``.
    """
    if not os.path.isfile(path):
        return False
    size_bytes = os.path.getsize(path)
    floor = config.MIN_SIZES.get(os.path.basename(path), 1_000_000)
    return size_bytes >= floor
def _copy_into_place(src: str, dest: str) -> None:
    """Copy ``src`` to ``dest`` through a sibling ``dest + ".part"`` file.

    ``dest`` only ever appears via ``os.replace`` of a fully written copy, so a
    failed or interrupted copy cannot leave a truncated ``dest`` that a later
    size check might accept. On failure the ``.part`` file is removed
    (best effort) and the original error is re-raised.
    """
    part_path = dest + ".part"
    try:
        shutil.copy2(src, part_path)
        os.replace(part_path, dest)
    except BaseException:
        # Clean up on any failure (including KeyboardInterrupt), then propagate.
        if os.path.isfile(part_path):
            try:
                os.remove(part_path)
            except OSError:
                pass
        raise


def _download_one(repo_id: str, repo_file: str, dest: str) -> None:
    """Fetch ``repo_file`` from the Hub model repo ``repo_id`` and place it at ``dest``.

    The parent directory of ``dest`` is created if needed. An existing ``dest``
    that already passes ``_ok_size`` is kept and nothing is downloaded.
    Otherwise the file is obtained with ``hf_hub_download`` and copied to
    ``dest`` atomically (via a ``.part`` file + ``os.replace``), then its size is
    verified. Failed attempts are retried up to ``config.MAX_RETRIES`` times,
    sleeping ``min(config.BACKOFF_CAP_S, 2**attempt)`` seconds between them.

    Raises:
        Exception: whatever the final failed attempt raised (Hub/network errors,
            ``OSError`` from the copy, or ``RuntimeError`` when the copied file
            is below its minimum size).
    """
    dest_dir = os.path.dirname(dest)
    os.makedirs(dest_dir, mode=0o755, exist_ok=True)
    if _ok_size(dest):
        print(f"[bootstrap] skip (exists): {dest}", flush=True)
        return
    for attempt in range(1, config.MAX_RETRIES + 1):
        try:
            # Drop any undersized leftover so each attempt starts clean.
            if os.path.isfile(dest):
                os.remove(dest)
            print(
                f"[bootstrap] {repo_id} {repo_file} -> {dest} (attempt {attempt}/{config.MAX_RETRIES})",
                flush=True,
            )
            # Returns a local path to the file (inside the Hub cache).
            cached = hf_hub_download(
                repo_id=repo_id,
                filename=repo_file,
                repo_type="model",
            )
            _copy_into_place(cached, dest)
            if not _ok_size(dest):
                raise RuntimeError(f"file too small after copy: {dest}")
            print(f"[bootstrap] ok: {dest}", flush=True)
            return
        except Exception as e:
            print(f"[bootstrap] error: {e}", file=sys.stderr, flush=True)
            if attempt >= config.MAX_RETRIES:
                raise
            delay = min(config.BACKOFF_CAP_S, 2**attempt)
            print(f"[bootstrap] retry in {delay}s...", flush=True)
            time.sleep(delay)
def _expected_length(headers) -> int | None:
    """Return the response's declared body size, or ``None`` if it can't be trusted.

    ``None`` is returned when there is no usable ``Content-Length`` or when the
    body is content-encoded: ``iter_content`` yields decoded bytes, so the
    decoded count would legitimately differ from the header value.
    """
    encoding = (headers.get("Content-Encoding") or "identity").strip().lower()
    if encoding != "identity":
        return None
    raw = headers.get("Content-Length")
    if raw is None:
        return None
    try:
        length = int(raw)
    except ValueError:
        return None
    return length if length >= 0 else None


def _download_url(url: str, dest: str) -> None:
    """Stream ``url`` to ``dest`` over HTTP(S), retrying on failure.

    The parent directory of ``dest`` is created if needed, and an existing
    ``dest`` that already passes ``_ok_size`` is kept as-is. When the
    ``CIVITAI_TOKEN`` environment variable is non-blank it is sent as a Bearer
    token in the ``Authorization`` header. The body is streamed in 1 MiB chunks
    to ``dest + ".part"``, with a progress line every 256 MiB.

    The ``.part`` file is moved onto ``dest`` (``os.replace``) only after the
    stream ends and, when the response declares an un-encoded
    ``Content-Length``, the bytes written match it. This keeps an early-closed
    connection from publishing a truncated file that still clears the
    min-size floor. Attempts are retried up to ``config.MAX_RETRIES`` times,
    sleeping ``min(config.BACKOFF_CAP_S, 2**attempt)`` seconds between them.

    Raises:
        Exception: whatever the final failed attempt raised (``requests``
            errors, ``OSError``, or ``RuntimeError`` for an incomplete or
            undersized download).
    """
    dest_dir = os.path.dirname(dest)
    os.makedirs(dest_dir, mode=0o755, exist_ok=True)
    if _ok_size(dest):
        print(f"[bootstrap] skip (exists): {dest}", flush=True)
        return
    part_path = dest + ".part"
    headers: dict[str, str] = {}
    token = os.environ.get("CIVITAI_TOKEN", "").strip()
    if token:
        headers["Authorization"] = f"Bearer {token}"
    chunk_size = 1024 * 1024
    progress_interval = 256 * 1024 * 1024
    for attempt in range(1, config.MAX_RETRIES + 1):
        try:
            # Each attempt starts from scratch: no stale dest and no partial file.
            if os.path.isfile(dest):
                os.remove(dest)
            if os.path.isfile(part_path):
                os.remove(part_path)
            print(
                f"[bootstrap] {url} -> {dest} (attempt {attempt}/{config.MAX_RETRIES})",
                flush=True,
            )
            written = 0
            expected = None
            with requests.get(
                url,
                stream=True,
                allow_redirects=True,
                timeout=(10, 600),  # (connect, read) seconds
                headers=headers or None,
            ) as r:
                r.raise_for_status()
                expected = _expected_length(r.headers)
                next_log = progress_interval
                with open(part_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        if not chunk:  # skip empty keep-alive chunks
                            continue
                        f.write(chunk)
                        written += len(chunk)
                        if written >= next_log:
                            mb = written // (1024 * 1024)
                            print(
                                f"[bootstrap] ... {mb} MiB downloaded",
                                flush=True,
                            )
                            next_log += progress_interval
            if expected is not None and written != expected:
                raise RuntimeError(
                    f"incomplete download: got {written} of {expected} bytes: {dest}"
                )
            os.replace(part_path, dest)
            if not _ok_size(dest):
                raise RuntimeError(f"file too small after download: {dest}")
            print(f"[bootstrap] ok: {dest}", flush=True)
            return
        except Exception as e:
            print(f"[bootstrap] error: {e}", file=sys.stderr, flush=True)
            if os.path.isfile(part_path):
                try:
                    os.remove(part_path)
                except OSError:
                    pass
            if attempt >= config.MAX_RETRIES:
                raise
            delay = min(config.BACKOFF_CAP_S, 2**attempt)
            print(f"[bootstrap] retry in {delay}s...", flush=True)
            time.sleep(delay)
def bootstrap_model_artifacts() -> None:
    """Ensure every model artifact exists under ``config.model_artifacts_root()``.

    Each optional Hub entry in ``config.ARTIFACTS`` (repo id, path in repo,
    destination relative to the root) is fetched first. The RDBT checkpoint
    described by ``config.CIVITAI_RDBT`` (URL, relative destination) is then
    downloaded from Civitai, unless ``config.skip_civitai()`` is true.
    """
    base = config.model_artifacts_root()
    for repo_id, hub_path, rel in config.ARTIFACTS:
        _download_one(repo_id, hub_path, os.path.join(base, rel))
    if not config.skip_civitai():
        civitai_url, civitai_rel = config.CIVITAI_RDBT
        _download_url(civitai_url, os.path.join(base, civitai_rel))
    else:
        print("[bootstrap] SKIP_CIVITAI=1: skipping Civitai downloads.", flush=True)
    print("[bootstrap] all model artifacts ready.", flush=True)
|