Spaces:
Runtime error
Runtime error
| """Model management utilities.""" | |
| from __future__ import annotations | |
| import os | |
| import shutil | |
| import tempfile | |
| import zipfile | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pathlib import Path | |
| from .config import ( | |
| BUILTIN_MODELS, | |
| MODELS_DIR, | |
| logger, | |
| ) | |
def list_models() -> list[str]:
    """List available voice models.

    A model is any folder directly under ``MODELS_DIR`` that contains at least
    one ``*.pth`` file. This excludes asset folders such as ``predictors`` and
    ``embedders``, which hold ``.pt`` / ``.bin`` / ``.json`` files.

    Returns:
        Sorted folder names, or an empty list if ``MODELS_DIR`` is not an
        existing directory.
    """
    # is_dir() rather than exists(): a stray regular file at MODELS_DIR would
    # otherwise make iterdir() raise NotADirectoryError.
    if not MODELS_DIR.is_dir():
        return []
    # any() stops at the first match instead of listing every .pth file.
    return sorted(
        entry.name
        for entry in MODELS_DIR.iterdir()
        if entry.is_dir() and any(entry.glob("*.pth"))
    )
def pth_and_index(name: str) -> tuple[str, str]:
    """Get model .pth and optional .index paths.

    Args:
        name: Model folder name under ``MODELS_DIR``.

    Returns:
        ``(pth_path, index_path)`` as strings. ``index_path`` is ``""`` when
        the folder has no ``.index`` file. When several candidates exist, the
        lexicographically first one is returned, so the choice is stable
        across runs.

    Raises:
        FileNotFoundError: If the folder contains no usable ``.pth`` file
            (this includes a missing folder).
    """
    model_dir = MODELS_DIR / name

    def _candidates(pattern: str) -> list[Path]:
        # Skip macOS AppleDouble metadata ("._<file>"). pathlib's "*" matches
        # dot-files, so these would otherwise compete with the real files.
        return sorted(
            p for p in model_dir.glob(pattern) if not p.name.startswith("._")
        )

    pths = _candidates("*.pth")
    idxs = _candidates("*.index")
    if not pths:
        raise FileNotFoundError(f"No .pth file found in model '{name}'")
    return str(pths[0]), (str(idxs[0]) if idxs else "")
def extract_zip(zip_path: str | Path, dest_name: str) -> None:
    """Extract a model ZIP, flattening nested .pth/.index files.

    The archive is extracted into ``MODELS_DIR / dest_name``. Any ``.pth`` or
    ``.index`` file found in a sub-folder is then moved up to the top level of
    that folder, so a non-recursive ``glob("*.pth")`` finds it.

    Args:
        zip_path: Path to the ZIP archive.
        dest_name: Name of the model folder to create under ``MODELS_DIR``.

    Raises:
        zipfile.BadZipFile: If ``zip_path`` is not a valid ZIP archive.
    """
    dest = MODELS_DIR / dest_name
    # parents=True: MODELS_DIR itself may not exist yet on a fresh install.
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(dest)
    # ZIPs created on macOS carry a __MACOSX/ tree of "._<name>" AppleDouble
    # metadata files. They match *.pth / *.index but are not model files, so
    # drop them before flattening instead of moving them next to the weights.
    shutil.rmtree(dest / "__MACOSX", ignore_errors=True)
    # Materialize the list first, because files are moved while iterating.
    nested_files = [*dest.rglob("*.pth"), *dest.rglob("*.index")]
    for nested in nested_files:
        if nested.name.startswith("._"):
            continue
        target = dest / nested.name
        if nested != target:
            shutil.move(str(nested), str(target))
def download_file(url: str, dest: Path) -> None:
    """Download a single file if not already present.

    The body is streamed into a ``.tmp`` file in ``dest``'s folder, then moved
    onto ``dest`` with ``os.replace``. An interrupted download therefore never
    appears at ``dest``, and ``dest.exists()`` means the file is complete.

    Args:
        url: Source URL.
        dest: Final file path. Parent folders are created as needed.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection errors or timeouts.
    """
    if dest.exists():
        return
    dest.parent.mkdir(parents=True, exist_ok=True)
    logger.info("Downloading %s …", dest.name)
    import requests

    # The response is a context manager, so the streamed connection is
    # released even if reading the body fails part-way.
    with requests.get(url, stream=True, timeout=300) as r:
        r.raise_for_status()
        tmp = tempfile.NamedTemporaryFile(delete=False, dir=dest.parent, suffix=".tmp")
        try:
            with tmp:
                for chunk in r.iter_content(8192):
                    tmp.write(chunk)
            os.replace(tmp.name, dest)
        except BaseException:
            # delete=False means nobody else cleans this up. Remove the partial
            # .tmp file (after it is closed, for Windows), then re-raise.
            try:
                os.unlink(tmp.name)
            except OSError:
                pass
            raise
    logger.info("%s ready.", dest.name)
def download_model_entry(model: dict) -> str:
    """Download a single built-in model ZIP. Returns model name.

    The download is skipped when ``MODELS_DIR / model["name"]`` already holds
    a ``.pth`` file. Otherwise the ZIP is streamed to a temporary file and
    extracted with ``extract_zip``.

    Args:
        model: Mapping with at least a ``"name"`` key (model folder name under
            ``MODELS_DIR``) and a ``"url"`` key (ZIP download URL).

    Returns:
        ``model["name"]``.

    Raises:
        requests.HTTPError / requests.RequestException: On download failure.
        zipfile.BadZipFile: If the downloaded file is not a valid ZIP.
    """
    import requests

    name = model["name"]
    dest = MODELS_DIR / name
    if dest.exists() and any(dest.glob("*.pth")):
        logger.info("Model already present: %s", name)
        return name
    logger.info("Downloading model: %s …", name)
    existed_before = dest.exists()
    tmp_path = None
    try:
        # Close the streamed response, and only create the temp file after
        # the status check, so an HTTP error does not leave an empty file.
        with requests.get(model["url"], stream=True, timeout=300) as r:
            r.raise_for_status()
            with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
                tmp_path = tmp.name
                for chunk in r.iter_content(8192):
                    tmp.write(chunk)
        extract_zip(tmp_path, name)
    except BaseException:
        # A half-extracted folder could hold a truncated .pth that would pass
        # the "already present" check on the next run. Remove it, but only if
        # this call created the folder.
        if not existed_before:
            shutil.rmtree(dest, ignore_errors=True)
        raise
    finally:
        # delete=False: the temp ZIP must be removed explicitly on every path.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
    logger.info("Model ready: %s", name)
    return name
def startup_downloads() -> str:
    """
    Download all required assets in parallel at startup.

    Fetches the rmvpe/fcpe predictor weights, the contentvec embedder files
    and every entry of BUILTIN_MODELS on a thread pool. A failed download is
    logged as a warning and does not stop the others.

    Returns the name of the first built-in model (in BUILTIN_MODELS order)
    that downloaded successfully or was already present, to be used as the
    default selection. If none succeeded, it falls back to
    BUILTIN_MODELS[0]["name"].
    """
    # Not used directly. Importing here makes a missing `requests` package
    # raise ImportError once, instead of one logged warning per worker task.
    import requests  # noqa: F401

    hf_resources = "https://huggingface.co/JackismyShephard/ultimate-rvc/resolve/main/Resources"
    predictor_base = f"{hf_resources}/predictors"
    embedder_base = f"{hf_resources}/embedders"
    predictors_dir = MODELS_DIR / "predictors"
    contentvec_dir = MODELS_DIR / "embedders" / "contentvec"
    file_tasks = [
        (f"{predictor_base}/rmvpe.pt", predictors_dir / "rmvpe.pt"),
        (f"{predictor_base}/fcpe.pt", predictors_dir / "fcpe.pt"),
        (f"{embedder_base}/contentvec/pytorch_model.bin", contentvec_dir / "pytorch_model.bin"),
        (f"{embedder_base}/contentvec/config.json", contentvec_dir / "config.json"),
    ]
    ready_models: set[str] = set()
    with ThreadPoolExecutor(max_workers=8) as pool:
        file_futures = {pool.submit(download_file, url, dest): dest.name for url, dest in file_tasks}
        model_futures = {pool.submit(download_model_entry, m): m["name"] for m in BUILTIN_MODELS}
        all_futures = {**file_futures, **model_futures}
        for future in as_completed(all_futures):
            label = all_futures[future]
            try:
                future.result()
            except Exception as exc:
                logger.warning("Download failed (%s): %s", label, exc)
            else:
                if future in model_futures:
                    ready_models.add(label)
    # Prefer a default that is actually usable, keeping the configured order.
    for model in BUILTIN_MODELS:
        if model["name"] in ready_models:
            return model["name"]
    return BUILTIN_MODELS[0]["name"]