Hackathon-IA-VisualNovel / scripts /download_models.py
WillHbx's picture
Add STT fixes, SDXL-Lightning backend, voice UX and rembg sprite transparency
407855c
Raw
History Blame Contribute Delete
2.99 kB
"""Pull model weights into ./models for the real (non-mock) config.
uv run python scripts/download_models.py # everything for the beefy config
uv run python scripts/download_models.py --llm # just the GGUF LLM
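uv run python scripts/download_models.py --image # the configured image model (+ style LoRA if set, + rembg)
uv run python scripts/download_models.py --whisper # just the whisper model
uv run python scripts/download_models.py --rembg # just the rembg background-removal model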
Needs `huggingface_hub` (comes with gradio). Diffusers/whisper weights are also fetched
lazily on first real use, so this script is mostly to warm the cache / work offline later.
"""
from __future__ import annotations
import argparse
from visualnovel import config
def fetch_llm() -> None:
    """Download the configured GGUF LLM file into the models directory.

    The repo id, filename, destination directory and token are all read from
    ``config``. The local path returned by ``hf_hub_download`` is printed.
    """
    # Function-scope import: only needed when this fetch actually runs.
    from huggingface_hub import hf_hub_download

    repo = config.LLM_GGUF_REPO
    gguf_name = config.LLM_GGUF_FILE
    print(f"↓ GGUF LLM: {repo}/{gguf_name}")

    download_args = {
        "repo_id": repo,
        "filename": gguf_name,
        "local_dir": str(config.MODELS_DIR),
        "token": config.HF_TOKEN,
    }
    local_path = hf_hub_download(**download_args)
    print(f" -> {local_path}")
def fetch_image() -> None:
    """Snapshot the configured image model, plus the style LoRA when one is set.

    The image model repo is downloaded into ``MODELS_DIR / "image"``. If
    ``config.IMAGE_LORA`` is truthy, that repo is downloaded into
    ``MODELS_DIR / "lora"`` as well. Both downloads pass ``config.HF_TOKEN``.
    """
    # Function-scope import: only needed when this fetch actually runs.
    from huggingface_hub import snapshot_download

    def _snapshot(repo_id, subdir: str) -> None:
        # Both downloads share the same options; only repo and target differ.
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(config.MODELS_DIR / subdir),
            local_dir_use_symlinks=False,
            token=config.HF_TOKEN,
        )

    print(f"↓ image model: {config.IMAGE_MODEL}")
    _snapshot(config.IMAGE_MODEL, "image")

    lora_repo = config.IMAGE_LORA
    if not lora_repo:
        # No LoRA configured: nothing more to fetch.
        return

    print(f"↓ style LoRA: {lora_repo}")
    _snapshot(lora_repo, "lora")
def fetch_rembg() -> None:
    """Pre-fetch the rembg ``u2net`` background-removal model.

    Best-effort: if an ``ImportError`` is raised (e.g. rembg is not
    installed), a skip hint is printed instead of failing.
    """
    skip_hint = " (skip rembg: install the `image` extra to pre-fetch)"
    try:
        from rembg import new_session as open_session  # noqa: PLC0415

        print("↓ rembg: u2net background-removal model")
        # Opening a session downloads ~168 MB on first call, cached in ~/.u2net.
        open_session("u2net")
        print(" -> cached")
    except ImportError:
        print(skip_hint)
def fetch_whisper() -> None:
    """Pre-warm the faster-whisper model of size ``config.WHISPER_SIZE``.

    faster-whisper downloads weights on first use, so constructing the model
    once here (on CPU, int8 compute) fetches them ahead of time. Best-effort:
    on ``ImportError`` a skip hint is printed instead of failing.
    """
    skip_hint = " (skip whisper: install the `stt` extra to pre-fetch)"
    try:
        from faster_whisper import WhisperModel as _Model

        size = config.WHISPER_SIZE
        print(f"↓ whisper: {size}")
        # The instance is discarded; only the download side effect matters.
        _Model(size, device="cpu", compute_type="int8")
    except ImportError:
        print(skip_hint)
def main() -> None:
    """Parse CLI flags and pre-fetch the selected model weights.

    With no flags, everything is fetched. Otherwise only the selected groups
    are fetched, in the order LLM, image model, rembg, whisper. ``--image``
    also fetches the rembg model; ``--rembg`` fetches the rembg model alone.
    """
    ap = argparse.ArgumentParser(
        description="Pre-fetch model weights into the models directory."
    )
    ap.add_argument("--llm", action="store_true", help="fetch the GGUF LLM")
    ap.add_argument(
        "--image",
        action="store_true",
        help="fetch the image model, the style LoRA (if set) and the rembg model",
    )
    ap.add_argument("--whisper", action="store_true", help="fetch the whisper model")
    ap.add_argument(
        "--rembg",
        action="store_true",
        help="fetch the rembg background-removal model",
    )
    args = ap.parse_args()

    # No explicit selection means "fetch everything".
    do_all = not (args.llm or args.image or args.whisper or args.rembg)

    config.MODELS_DIR.mkdir(parents=True, exist_ok=True)

    if args.llm or do_all:
        fetch_llm()
    if args.image or do_all:
        fetch_image()
    # Single condition so `--image --rembg` does not fetch rembg twice.
    if args.image or args.rembg or do_all:
        fetch_rembg()
    if args.whisper or do_all:
        fetch_whisper()
    print("done.")


# Run the downloader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()