RVC-CH / lib /models.py
ozipoetra
fix: correct predictor and embedder download paths
2782216
"""Model management utilities."""
from __future__ import annotations
import os
import shutil
import tempfile
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from .config import (
BUILTIN_MODELS,
MODELS_DIR,
logger,
)
def list_models() -> list[str]:
    """Return the names of installed voice models, sorted alphabetically.

    A model is any direct sub-folder of ``MODELS_DIR`` that contains at least
    one ``.pth`` file. When ``MODELS_DIR`` has not been created yet, an empty
    list is returned.
    """
    if not MODELS_DIR.exists():
        return []
    names: list[str] = []
    for entry in MODELS_DIR.iterdir():
        # Only folders are inspected; the glob runs only for directories.
        has_weights = entry.is_dir() and bool(list(entry.glob("*.pth")))
        if has_weights:
            names.append(entry.name)
    names.sort()
    return names
def pth_and_index(name: str) -> tuple[str, str]:
    """Resolve the weight file and optional index file of a voice model.

    Args:
        name: Folder name of the model under ``MODELS_DIR``.

    Returns:
        ``(pth_path, index_path)`` as strings. ``index_path`` is ``""`` when
        the folder contains no ``.index`` file.

    Raises:
        FileNotFoundError: If the folder holds no usable ``.pth`` file
            (including when the folder does not exist).
    """
    model_dir = MODELS_DIR / name

    def _candidates(pattern: str) -> list[Path]:
        # glob() order is filesystem-dependent: sort so the same file is picked
        # on every run. macOS "._*" metadata files share the real extensions
        # but are not model data, so they are never returned.
        return sorted(
            p for p in model_dir.glob(pattern) if not p.name.startswith("._")
        )

    pths = _candidates("*.pth")
    idxs = _candidates("*.index")
    if not pths:
        raise FileNotFoundError(f"No .pth file found in model '{name}'")
    return str(pths[0]), (str(idxs[0]) if idxs else "")
def extract_zip(zip_path: str | Path, dest_name: str) -> None:
    """Extract a model ZIP into ``MODELS_DIR / dest_name``, flattening weights.

    Every ``.pth`` and ``.index`` file found anywhere in the extracted tree is
    moved to the top level of the destination folder, so callers can locate
    it with a non-recursive glob.

    Archives created on macOS commonly carry ``__MACOSX/`` entries and
    ``._``-prefixed metadata files that reuse the real files' extensions.
    Those members are not extracted; otherwise they would be flattened next
    to the real weights and could be picked up as the model.

    Args:
        zip_path: Path to the ZIP archive.
        dest_name: Name of the model folder to create under ``MODELS_DIR``.
    """
    dest = MODELS_DIR / dest_name
    # parents=True: MODELS_DIR itself may not exist yet on a fresh install.
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zf:
        members = [m for m in zf.namelist() if not _is_macos_metadata(m)]
        zf.extractall(dest, members=members)
    # Materialize both globs before moving anything, so the tree is not
    # mutated while it is being walked.
    for nested in list(dest.rglob("*.pth")) + list(dest.rglob("*.index")):
        target = dest / nested.name
        if nested != target:
            shutil.move(str(nested), str(target))


def _is_macos_metadata(member: str) -> bool:
    """Return True for ZIP members that are macOS archive metadata, not content."""
    parts = member.replace("\\", "/").split("/")
    return "__MACOSX" in parts or parts[-1].startswith("._")
def download_file(url: str, dest: Path) -> None:
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The response body is streamed into a temporary ``.tmp`` file in
    ``dest``'s directory and then moved into place with ``os.replace``, so
    ``dest`` only appears once the download has completed. If anything fails,
    the partial temporary file is removed and the exception propagates.

    Args:
        url: Source URL.
        dest: Final file location; missing parent directories are created.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Other ``requests`` exceptions and ``OSError`` propagate unchanged.
    """
    if dest.exists():
        return
    dest.parent.mkdir(parents=True, exist_ok=True)
    logger.info("Downloading %s …", dest.name)
    import requests

    tmp_path: str | None = None
    try:
        # Using the response as a context manager releases the streamed
        # connection even when raise_for_status() or a chunk read fails.
        with requests.get(url, stream=True, timeout=300) as r:
            r.raise_for_status()
            with tempfile.NamedTemporaryFile(
                delete=False, dir=dest.parent, suffix=".tmp"
            ) as tmp:
                tmp_path = tmp.name
                for chunk in r.iter_content(8192):
                    tmp.write(chunk)
        os.replace(tmp_path, dest)
        tmp_path = None  # ownership moved to dest; nothing left to clean up
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
    logger.info("%s ready.", dest.name)
def download_model_entry(model: dict) -> str:
    """Download and unpack one built-in voice model.

    Nothing is downloaded when ``MODELS_DIR / model["name"]`` already holds a
    ``.pth`` file. Otherwise the ZIP at ``model["url"]`` is streamed to a
    temporary file and unpacked with ``extract_zip``. The temporary file is
    deleted afterwards, whether or not the download and extraction succeeded.

    Args:
        model: Mapping with at least ``"name"`` (target folder name) and
            ``"url"`` (ZIP download URL).

    Returns:
        The model name.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        Other ``requests`` exceptions, ``OSError`` and ``zipfile.BadZipFile``
        propagate unchanged.
    """
    import requests

    name = model["name"]
    dest = MODELS_DIR / name
    if dest.exists() and list(dest.glob("*.pth")):
        logger.info("Model already present: %s", name)
        return name
    logger.info("Downloading model: %s …", name)
    tmp_path: str | None = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
            tmp_path = tmp.name
            # Context-managed response: the streamed connection is released
            # even if the status check or a chunk read raises.
            with requests.get(model["url"], stream=True, timeout=300) as r:
                r.raise_for_status()
                for chunk in r.iter_content(8192):
                    tmp.write(chunk)
        # The temp file is closed at this point, so it can be reopened for
        # reading by extract_zip on every platform.
        extract_zip(tmp_path, name)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
    logger.info("Model ready: %s", name)
    return name
def startup_downloads() -> str:
    """Download all required assets in parallel at startup.

    Fetches these items concurrently and independently:

    * the predictor weights ``rmvpe.pt`` and ``fcpe.pt`` into
      ``MODELS_DIR/predictors``;
    * the ContentVec embedder files into ``MODELS_DIR/embedders/contentvec``;
    * every entry of ``BUILTIN_MODELS``.

    A failure in one task is logged as a warning and does not stop the others.

    Returns:
        Name of the first entry in ``BUILTIN_MODELS``, for use as the default
        selection. It is returned even if that model's download failed.
    """
    # Not used directly. Importing it here means a missing dependency raises
    # ImportError before any work is scheduled, instead of being logged once
    # per task below.
    import requests  # noqa: F401

    predictor_base = "https://huggingface.co/JackismyShephard/ultimate-rvc/resolve/main/Resources/predictors"
    embedder_base = "https://huggingface.co/JackismyShephard/ultimate-rvc/resolve/main/Resources/embedders"
    predictors_dir = MODELS_DIR / "predictors"
    embedders_dir = MODELS_DIR / "embedders"
    file_tasks = [
        (f"{predictor_base}/rmvpe.pt", predictors_dir / "rmvpe.pt"),
        (f"{predictor_base}/fcpe.pt", predictors_dir / "fcpe.pt"),
        (f"{embedder_base}/contentvec/pytorch_model.bin", embedders_dir / "contentvec" / "pytorch_model.bin"),
        (f"{embedder_base}/contentvec/config.json", embedders_dir / "contentvec" / "config.json"),
    ]

    # Create the shared root before fanning out, so no task depends on another
    # concurrently running task having created it first.
    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    with ThreadPoolExecutor(max_workers=8) as pool:
        # One map from future -> human-readable label, used for failure logs.
        futures = {
            pool.submit(download_file, url, dest): dest.name
            for url, dest in file_tasks
        }
        futures.update(
            {pool.submit(download_model_entry, m): m["name"] for m in BUILTIN_MODELS}
        )
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                # Best-effort startup: report and continue with other assets.
                logger.warning("Download failed (%s): %s", futures[future], exc)
    return BUILTIN_MODELS[0]["name"]