| """Model management utilities.""" |
| from __future__ import annotations |
|
|
| import os |
| import shutil |
| import tempfile |
| import zipfile |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from pathlib import Path |
|
|
| from .config import ( |
| BUILTIN_MODELS, |
| MODELS_DIR, |
| logger, |
| ) |
|
|
|
|
def list_models() -> list[str]:
    """Return the names of usable voice models, sorted alphabetically.

    A model is any direct sub-directory of ``MODELS_DIR`` that contains at
    least one ``*.pth`` file. When ``MODELS_DIR`` does not exist, an empty
    list is returned.
    """
    if not MODELS_DIR.exists():
        return []
    names: list[str] = []
    for entry in MODELS_DIR.iterdir():
        # A directory only counts as a model once it holds a weights file.
        if entry.is_dir() and any(entry.glob("*.pth")):
            names.append(entry.name)
    names.sort()
    return names
|
|
|
|
def pth_and_index(name: str) -> tuple[str, str]:
    """Get model .pth and optional .index paths.

    Args:
        name: Name of the model directory under ``MODELS_DIR``.

    Returns:
        ``(pth_path, index_path)`` as strings. ``index_path`` is ``""`` when
        the model directory contains no ``.index`` file. When several files
        match, the alphabetically first one is chosen.

    Raises:
        FileNotFoundError: If the model directory contains no ``.pth`` file
            (this includes the case where the directory does not exist).
    """
    model_dir = MODELS_DIR / name

    def _candidates(pattern: str) -> list[Path]:
        # pathlib's "*" also matches dotfiles, so macOS AppleDouble metadata
        # ("._model.pth") would otherwise be picked up as a model file.
        # Path.glob order is filesystem-dependent; sort for a stable choice.
        return sorted(
            p for p in model_dir.glob(pattern) if not p.name.startswith("._")
        )

    pths = _candidates("*.pth")
    idxs = _candidates("*.index")
    if not pths:
        raise FileNotFoundError(f"No .pth file found in model '{name}'")
    return str(pths[0]), str(idxs[0]) if idxs else ""
|
|
|
|
def _is_macos_metadata(member: str) -> bool:
    """Return True for ZIP entries macOS adds as metadata (not real content).

    Covers everything under a ``__MACOSX/`` folder and ``._<name>``
    AppleDouble files. These reuse the original file names (e.g.
    ``._model.pth``) but contain no model data.
    """
    parts = member.replace("\\", "/").split("/")
    return "__MACOSX" in parts or parts[-1].startswith("._")


def extract_zip(zip_path: str | Path, dest_name: str) -> None:
    """Extract a model ZIP, flattening nested .pth/.index files.

    The archive is extracted into ``MODELS_DIR / dest_name``. That directory
    is created together with any missing parents, so this works on a fresh
    install where ``MODELS_DIR`` does not exist yet. macOS metadata entries
    are skipped. Every ``.pth`` and ``.index`` file is then moved to the top
    level of the directory so that a non-recursive glob finds it. Name
    clashes follow ``shutil.move`` semantics.

    Args:
        zip_path: Path of the ZIP archive to extract.
        dest_name: Model directory name under ``MODELS_DIR``.
    """
    dest = MODELS_DIR / dest_name
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zf:
        wanted = [n for n in zf.namelist() if not _is_macos_metadata(n)]
        zf.extractall(dest, members=wanted)
    # Materialize the matches first: files are moved while we iterate.
    for nested in [*dest.rglob("*.pth"), *dest.rglob("*.index")]:
        target = dest / nested.name
        if nested != target:
            shutil.move(str(nested), str(target))
|
|
|
|
def download_file(url: str, dest: Path) -> None:
    """Download a single file if not already present.

    The body is streamed into a ``.tmp`` file in ``dest``'s directory and
    then moved into place with ``os.replace``. As a result, ``dest`` only
    appears once the download is complete. On any failure, the temporary
    file is removed and the exception propagates.

    Args:
        url: Source URL.
        dest: Final location of the file. Missing parent directories are
            created.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        OSError: On filesystem errors. Network errors raised by ``requests``
            also propagate.
    """
    if dest.exists():
        return
    dest.parent.mkdir(parents=True, exist_ok=True)
    logger.info("Downloading %s …", dest.name)
    import requests

    tmp_path: str | None = None
    try:
        # The response context manager releases the streamed connection
        # even if reading fails part-way.
        with requests.get(url, stream=True, timeout=300) as r:
            r.raise_for_status()
            with tempfile.NamedTemporaryFile(
                delete=False, dir=dest.parent, suffix=".tmp"
            ) as tmp:
                tmp_path = tmp.name
                for chunk in r.iter_content(8192):
                    tmp.write(chunk)
        os.replace(tmp_path, dest)
        tmp_path = None  # now owned by dest; nothing left to clean up
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best effort only: the original error is already propagating.
                pass
    logger.info("%s ready.", dest.name)
|
|
|
|
def download_model_entry(model: dict) -> str:
    """Download a single built-in model ZIP. Returns model name.

    ``model`` must provide ``"name"`` (the directory under ``MODELS_DIR``)
    and ``"url"`` (the ZIP location). Nothing is downloaded when the model
    directory already contains a ``.pth`` file. The temporary ZIP is always
    deleted, whether the download and extraction succeed or fail.

    Raises:
        KeyError: If ``model`` lacks ``"name"`` or ``"url"``.
        requests.HTTPError: If the server responds with an error status.
        zipfile.BadZipFile: If the downloaded file is not a valid ZIP.
    """
    import requests
    name = model["name"]
    dest = MODELS_DIR / name
    if dest.exists() and any(dest.glob("*.pth")):
        logger.info("Model already present: %s", name)
        return name
    logger.info("Downloading model: %s …", name)
    tmp_path: str | None = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
            tmp_path = tmp.name
            # Close the streamed response even if reading fails part-way.
            with requests.get(model["url"], stream=True, timeout=300) as r:
                r.raise_for_status()
                for chunk in r.iter_content(8192):
                    tmp.write(chunk)
        # The temp file is closed here, so it can be reopened on any platform.
        extract_zip(tmp_path, name)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best effort: do not mask the original error, if any.
                pass
    logger.info("Model ready: %s", name)
    return name
|
|
|
|
def startup_downloads() -> str:
    """Fetch every startup asset concurrently and return the default model.

    On a pool of 8 worker threads, this schedules:

    * the ``rmvpe.pt`` and ``fcpe.pt`` predictors into ``MODELS_DIR/predictors``,
    * the ContentVec ``pytorch_model.bin`` and ``config.json`` into
      ``MODELS_DIR/embedders/contentvec``,
    * every entry of ``BUILTIN_MODELS`` via ``download_model_entry``.

    A failing task is logged as a warning and does not stop the others.

    Returns:
        ``BUILTIN_MODELS[0]["name"]``, used as the default selection.
    """
    import requests  # noqa: F401 -- unused here; surfaces a missing dependency immediately

    predictor_base = "https://huggingface.co/JackismyShephard/ultimate-rvc/resolve/main/Resources/predictors"
    embedder_base = "https://huggingface.co/JackismyShephard/ultimate-rvc/resolve/main/Resources/embedders"
    predictors_dir = MODELS_DIR / "predictors"
    contentvec_dir = MODELS_DIR / "embedders" / "contentvec"

    plain_files = [
        (f"{predictor_base}/{fname}", predictors_dir / fname)
        for fname in ("rmvpe.pt", "fcpe.pt")
    ] + [
        (f"{embedder_base}/contentvec/{fname}", contentvec_dir / fname)
        for fname in ("pytorch_model.bin", "config.json")
    ]

    # Maps each future to the label used in failure messages.
    labels = {}
    with ThreadPoolExecutor(max_workers=8) as executor:
        for url, target in plain_files:
            fut = executor.submit(download_file, url, target)
            labels[fut] = target.name
        for entry in BUILTIN_MODELS:
            fut = executor.submit(download_model_entry, entry)
            labels[fut] = entry["name"]

        for done in as_completed(labels):
            try:
                done.result()
            except Exception as err:
                logger.warning("Download failed (%s): %s", labels[done], err)

    return BUILTIN_MODELS[0]["name"]
|
|