"""Model management utilities."""

from __future__ import annotations

import contextlib
import os
import shutil
import tempfile
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import IO

from .config import (
    BUILTIN_MODELS,
    MODELS_DIR,
    logger,
)


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


def _has_weights(model_dir: Path) -> bool:
    """Return True if *model_dir* directly contains at least one ``.pth`` file."""
    return any(model_dir.glob("*.pth"))


def _sorted_candidates(directory: Path, pattern: str) -> list[Path]:
    """Return files in *directory* matching *pattern* in a deterministic order.

    Files are sorted by name, with macOS ``._*`` AppleDouble residue ranked
    after every real file so it is only ever chosen as a last resort.
    """
    return sorted(
        directory.glob(pattern),
        key=lambda p: (p.name.startswith("._"), p.name),
    )


def _is_macos_metadata(rel_path: Path) -> bool:
    """Return True for macOS archive residue (``__MACOSX/`` entries or ``._*`` files)."""
    return "__MACOSX" in rel_path.parts or rel_path.name.startswith("._")


def _remove_quietly(path: str | Path) -> None:
    """Delete *path*, ignoring OS errors (best-effort temp-file cleanup)."""
    with contextlib.suppress(OSError):
        os.unlink(path)


def _stream_to_file(url: str, fileobj: IO[bytes]) -> None:
    """Stream the HTTP response body for *url* into the open binary *fileobj*.

    ``requests`` is imported inside the function rather than at module level.
    Raises ``requests.HTTPError`` when the server answers with an error status.
    """
    import requests

    # Using the response as a context manager closes it (releasing the
    # connection) even if raise_for_status() or a write fails part-way.
    with requests.get(url, stream=True, timeout=300) as response:
        response.raise_for_status()
        for chunk in response.iter_content(8192):
            fileobj.write(chunk)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def list_models() -> list[str]:
    """List available voice models.

    A model is any direct subdirectory of ``MODELS_DIR`` that contains at least
    one ``.pth`` file. Returns the directory names sorted alphabetically, or an
    empty list when ``MODELS_DIR`` does not exist.
    """
    if not MODELS_DIR.exists():
        return []
    return sorted(
        p.name for p in MODELS_DIR.iterdir() if p.is_dir() and _has_weights(p)
    )


def pth_and_index(name: str) -> tuple[str, str]:
    """Get the ``.pth`` weights path and optional ``.index`` path for model *name*.

    When several candidates exist, the choice is deterministic: files are
    sorted by name, with macOS ``._*`` residue ranked last. The second element
    is ``""`` when the model has no ``.index`` file.

    Raises:
        FileNotFoundError: if ``MODELS_DIR / name`` contains no ``.pth`` file.
    """
    d = MODELS_DIR / name
    pths = _sorted_candidates(d, "*.pth")
    idxs = _sorted_candidates(d, "*.index")
    if not pths:
        raise FileNotFoundError(f"No .pth file found in model '{name}'")
    return str(pths[0]), str(idxs[0]) if idxs else ""


def extract_zip(zip_path: str | Path, dest_name: str) -> None:
    """Extract a model ZIP into ``MODELS_DIR / dest_name``, flattening weights.

    Every ``.pth``/``.index`` file extracted into a subdirectory is moved up to
    the top of the model directory (a same-named file already there may be
    replaced), so the non-recursive globs used by ``list_models`` and
    ``pth_and_index`` can find it. macOS archive residue (``__MACOSX/`` entries
    and ``._*`` files) is left where it was extracted instead of being flattened.
    """
    dest = MODELS_DIR / dest_name
    # parents=True: MODELS_DIR itself may not exist yet (list_models allows for that).
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(dest)
    # Collect all matches before moving anything so relocations can't disturb the walk.
    nested_files = [*dest.rglob("*.pth"), *dest.rglob("*.index")]
    for nested in nested_files:
        if _is_macos_metadata(nested.relative_to(dest)):
            continue
        target = dest / nested.name
        if nested != target:
            shutil.move(str(nested), str(target))


def download_file(url: str, dest: Path) -> None:
    """Download a single file from *url* to *dest* unless *dest* already exists.

    The body is streamed into a ``.tmp`` file in the destination directory and
    then moved into place with ``os.replace``, so *dest* only ever appears once
    the download has completed. On any failure the temporary file is removed
    and the exception is re-raised.
    """
    if dest.exists():
        return
    dest.parent.mkdir(parents=True, exist_ok=True)
    logger.info("Downloading %s …", dest.name)
    tmp = tempfile.NamedTemporaryFile(delete=False, dir=dest.parent, suffix=".tmp")
    try:
        # Close the temp file before os.replace (required on Windows).
        with tmp:
            _stream_to_file(url, tmp)
        os.replace(tmp.name, dest)
    except BaseException:
        _remove_quietly(tmp.name)
        raise
    logger.info("%s ready.", dest.name)


def download_model_entry(model: dict) -> str:
    """Download and extract a single built-in model ZIP.

    *model* is read for its ``"name"`` and ``"url"`` keys. Nothing is
    downloaded when ``MODELS_DIR / name`` already contains a ``.pth`` file.
    The ZIP is staged in a temporary file that is always deleted afterwards,
    even when the download or the extraction fails (the exception propagates).

    Returns the model name.
    """
    name = model["name"]
    dest = MODELS_DIR / name
    if dest.exists() and _has_weights(dest):
        logger.info("Model already present: %s", name)
        return name

    logger.info("Downloading model: %s …", name)
    tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    try:
        with tmp:
            _stream_to_file(model["url"], tmp)
        extract_zip(tmp.name, name)
    finally:
        _remove_quietly(tmp.name)

    if _has_weights(dest):
        logger.info("Model ready: %s", name)
    else:
        logger.warning("Model %s: archive contained no .pth file", name)
    return name


def startup_downloads() -> str:
    """
    Download all required assets in parallel at startup.

    Fetches the pitch predictors, the contentvec embedder files and every entry
    of ``BUILTIN_MODELS`` concurrently. Individual failures are logged as
    warnings and do not stop the remaining downloads.

    Returns the name of the first built-in model as the default selection, or
    ``""`` when ``BUILTIN_MODELS`` is empty.
    """
    predictor_base = "https://huggingface.co/JackismyShephard/ultimate-rvc/resolve/main/Resources/predictors"
    embedder_base = "https://huggingface.co/JackismyShephard/ultimate-rvc/resolve/main/Resources/embedders"
    predictors_dir = MODELS_DIR / "predictors"
    embedders_dir = MODELS_DIR / "embedders"
    file_tasks = [
        (f"{predictor_base}/rmvpe.pt", predictors_dir / "rmvpe.pt"),
        (f"{predictor_base}/fcpe.pt", predictors_dir / "fcpe.pt"),
        (
            f"{embedder_base}/contentvec/pytorch_model.bin",
            embedders_dir / "contentvec" / "pytorch_model.bin",
        ),
        (
            f"{embedder_base}/contentvec/config.json",
            embedders_dir / "contentvec" / "config.json",
        ),
    ]

    with ThreadPoolExecutor(max_workers=8) as pool:
        # Map each future to a human-readable label for failure messages.
        labels = {
            pool.submit(download_file, url, dest): dest.name
            for url, dest in file_tasks
        }
        labels.update(
            {pool.submit(download_model_entry, m): m["name"] for m in BUILTIN_MODELS}
        )
        for future in as_completed(labels):
            try:
                future.result()
            except Exception as exc:
                # Best-effort: one failed asset must not abort the others.
                logger.warning("Download failed (%s): %s", labels[future], exc)

    return BUILTIN_MODELS[0]["name"] if BUILTIN_MODELS else ""