JSCPPProgrammer's picture
Switch backend to diffusers (diffusers-anima); RDBT via from_single_file
fe74e7d verified
"""
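Download and verify model weights: the RDBT checkpoint from Civitai plus any optional
Hugging Face Hub files listed in config.ARTIFACTS. Downloads retry with capped exponential
backoff and reject files below a minimum size; the Civitai download streams into a `.part`
file first. This mirrors the behavior of the original ComfyUI-based Space.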
"""
from __future__ import annotations
import os
import shutil
import sys
import time
import requests
from huggingface_hub import hf_hub_download
from src import config
from src.errors import UserFacingError
def _ok_size(path: str) -> bool:
name = os.path.basename(path)
if not os.path.isfile(path):
return False
sz = os.path.getsize(path)
return sz >= config.MIN_SIZES.get(name, 1_000_000)
def _download_one(repo_id: str, repo_file: str, dest: str) -> None:
    """Fetch one file from the Hugging Face Hub and place it at *dest*.

    The file is first downloaded into the local Hub cache by ``hf_hub_download``
    and then copied out. The copy is written to ``dest + ".part"`` and moved
    into place with ``os.replace``. An interrupted copy therefore never leaves
    a truncated file at *dest* that a later run could accept as complete.

    If *dest* already passes ``_ok_size`` nothing is downloaded. Failed attempts
    are retried up to ``config.MAX_RETRIES`` times. Between attempts the
    function sleeps ``min(config.BACKOFF_CAP_S, 2**attempt)`` seconds.

    Args:
        repo_id: Hub model repository id.
        repo_file: Path of the file inside that repository.
        dest: Local destination path.

    Raises:
        Exception: whatever the final attempt raised, re-raised unchanged.
    """
    dest_dir = os.path.dirname(dest)
    if dest_dir:  # os.makedirs("") raises when dest is a bare filename
        os.makedirs(dest_dir, mode=0o755, exist_ok=True)
    if _ok_size(dest):
        print(f"[bootstrap] skip (exists): {dest}", flush=True)
        return
    part_path = dest + ".part"
    for attempt in range(1, config.MAX_RETRIES + 1):
        try:
            # Drop leftovers from earlier runs: an undersized dest or a stale .part.
            for stale in (dest, part_path):
                if os.path.isfile(stale):
                    os.remove(stale)
            print(
                f"[bootstrap] {repo_id} {repo_file} -> {dest} (attempt {attempt}/{config.MAX_RETRIES})",
                flush=True,
            )
            cached = hf_hub_download(
                repo_id=repo_id,
                filename=repo_file,
                repo_type="model",
            )
            shutil.copy2(cached, part_path)
            # Only a fully copied file ever appears under the final name.
            os.replace(part_path, dest)
            # Size check uses dest (not part_path) so MIN_SIZES is keyed by the real name.
            if not _ok_size(dest):
                raise RuntimeError(f"file too small after copy: {dest}")
            print(f"[bootstrap] ok: {dest}", flush=True)
            return
        except Exception as e:
            print(f"[bootstrap] error: {e}", file=sys.stderr, flush=True)
            if os.path.isfile(part_path):
                try:
                    os.remove(part_path)
                except OSError:
                    pass  # best-effort cleanup; the next attempt removes it again
            if attempt >= config.MAX_RETRIES:
                raise
            delay = min(config.BACKOFF_CAP_S, 2**attempt)
            print(f"[bootstrap] retry in {delay}s...", flush=True)
            time.sleep(delay)
def _expected_content_length(response) -> int | None:
    """Return the byte count the streamed body should have, or None if unknown.

    Only a body sent without content-coding can be compared byte-for-byte with
    Content-Length. ``iter_content`` decodes gzip/deflate, so the decoded byte
    count would not match the header in that case.
    """
    encoding = response.headers.get("Content-Encoding", "").strip().lower()
    length_header = response.headers.get("Content-Length", "").strip()
    if encoding in ("", "identity") and length_header.isdigit():
        return int(length_header)
    return None


def _download_url(url: str, dest: str) -> None:
    """Stream *url* to *dest* via a ``.part`` file, with retries.

    If *dest* already passes ``_ok_size`` nothing is downloaded. Otherwise the
    body is streamed in 1 MiB chunks to ``dest + ".part"``, with progress logged
    every 256 MiB. When the server sent an uncoded Content-Length, the number of
    bytes written must match it. This catches truncated transfers, which
    urllib3 < 2 does not reject on its own. The file is then moved onto *dest*
    with ``os.replace`` and checked against its minimum size.

    If the ``CIVITAI_TOKEN`` environment variable is non-empty, it is sent as a
    Bearer token. Failed attempts are retried up to ``config.MAX_RETRIES``
    times, sleeping ``min(config.BACKOFF_CAP_S, 2**attempt)`` seconds between
    attempts.

    Raises:
        Exception: whatever the final attempt raised (HTTP error, I/O error,
            incomplete or undersized download), re-raised unchanged.
    """
    dest_dir = os.path.dirname(dest)
    if dest_dir:  # os.makedirs("") raises when dest is a bare filename
        os.makedirs(dest_dir, mode=0o755, exist_ok=True)
    if _ok_size(dest):
        print(f"[bootstrap] skip (exists): {dest}", flush=True)
        return
    part_path = dest + ".part"
    headers: dict[str, str] = {}
    token = os.environ.get("CIVITAI_TOKEN", "").strip()
    if token:
        headers["Authorization"] = f"Bearer {token}"
    chunk_size = 1024 * 1024
    progress_interval = 256 * 1024 * 1024
    for attempt in range(1, config.MAX_RETRIES + 1):
        try:
            if os.path.isfile(dest):
                os.remove(dest)
            if os.path.isfile(part_path):
                os.remove(part_path)
            print(
                f"[bootstrap] {url} -> {dest} (attempt {attempt}/{config.MAX_RETRIES})",
                flush=True,
            )
            with requests.get(
                url,
                stream=True,
                allow_redirects=True,
                timeout=(10, 600),
                headers=headers or None,
            ) as r:
                r.raise_for_status()
                expected = _expected_content_length(r)
                written = 0
                next_log = progress_interval
                with open(part_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        if chunk:
                            f.write(chunk)
                            written += len(chunk)
                            if written >= next_log:
                                mb = written // (1024 * 1024)
                                print(
                                    f"[bootstrap] ... {mb} MiB downloaded",
                                    flush=True,
                                )
                                next_log += progress_interval
            # A connection closed early can still exceed MIN_SIZES; reject it here.
            if expected is not None and written != expected:
                raise RuntimeError(
                    f"incomplete download: got {written} of {expected} bytes for {dest}"
                )
            os.replace(part_path, dest)
            if not _ok_size(dest):
                raise RuntimeError(f"file too small after download: {dest}")
            print(f"[bootstrap] ok: {dest}", flush=True)
            return
        except Exception as e:
            print(f"[bootstrap] error: {e}", file=sys.stderr, flush=True)
            if os.path.isfile(part_path):
                try:
                    os.remove(part_path)
                except OSError:
                    pass  # best-effort cleanup; the next attempt removes it again
            if attempt >= config.MAX_RETRIES:
                raise
            delay = min(config.BACKOFF_CAP_S, 2**attempt)
            print(f"[bootstrap] retry in {delay}s...", flush=True)
            time.sleep(delay)
def bootstrap_model_artifacts() -> None:
    """Ensure every model file is present under ``config.model_artifacts_root()``.

    Hub files listed in ``config.ARTIFACTS`` (``(repo_id, path_in_repo,
    relative_dest)`` triples) are fetched first, in order. The RDBT checkpoint
    described by ``config.CIVITAI_RDBT`` (``(url, relative_dest)``) is fetched
    next, unless ``config.skip_civitai()`` is true. Files that already exist
    with an acceptable size are skipped by the download helpers.
    """
    artifacts_root = config.model_artifacts_root()

    for hub_repo, file_in_repo, relative_path in config.ARTIFACTS:
        _download_one(hub_repo, file_in_repo, os.path.join(artifacts_root, relative_path))

    if not config.skip_civitai():
        civitai_url, civitai_rel = config.CIVITAI_RDBT
        _download_url(civitai_url, os.path.join(artifacts_root, civitai_rel))
    else:
        print("[bootstrap] SKIP_CIVITAI=1: skipping Civitai downloads.", flush=True)

    print("[bootstrap] all model artifacts ready.", flush=True)