| """ |
| Model Downloader for SYSPIN TTS Models |
| Downloads models from Hugging Face Hub |
| """ |
|
|
| import os |
| import logging |
| from pathlib import Path |
| from typing import Optional, List |
| from huggingface_hub import hf_hub_download, snapshot_download |
| from tqdm import tqdm |
|
|
| from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class ModelDownloader:
    """Download and manage SYSPIN TTS voice models from Hugging Face.

    Every voice is stored in its own sub-directory of ``models_dir``; the
    sub-directory is named after the voice's key in ``LANGUAGE_CONFIGS``.
    """

    def __init__(self, models_dir: str = MODELS_DIR):
        """Remember the models root and create it if it does not exist.

        Args:
            models_dir: Root directory holding one sub-directory per voice.
        """
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)

    def download_model(self, voice_key: str, force: bool = False) -> Path:
        """Download a specific voice model.

        Args:
            voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male', 'bn_female').
            force: Re-download even if the model files already exist locally.

        Returns:
            Path to the model directory (``models_dir / voice_key``).

        Raises:
            ValueError: If ``voice_key`` is not a key of LANGUAGE_CONFIGS.
            Exception: Any error raised by ``snapshot_download`` is logged
                and re-raised unchanged.
        """
        if voice_key not in LANGUAGE_CONFIGS:
            raise ValueError(
                f"Unknown voice: {voice_key}. Available: {list(LANGUAGE_CONFIGS.keys())}"
            )

        config = LANGUAGE_CONFIGS[voice_key]
        model_dir = self.models_dir / voice_key
        # The files whose presence marks a voice as already downloaded.
        required_files = (
            model_dir / config.model_filename,
            model_dir / config.chars_filename,
        )

        if not force and all(path.exists() for path in required_files):
            logger.info(f"Model {voice_key} already downloaded at {model_dir}")
            return model_dir

        logger.info(f"Downloading {voice_key} from {config.hf_model_id}...")
        model_dir.mkdir(parents=True, exist_ok=True)

        try:
            snapshot_download(
                repo_id=config.hf_model_id,
                local_dir=str(model_dir),
                local_dir_use_symlinks=False,
                allow_patterns=["*.pt", "*.pth", "*.txt", "*.py", "*.json"],
                # Forward ``force`` so the hub client really fetches the files
                # again instead of reusing copies it already has; otherwise
                # ``force`` would only bypass the existence check above.
                force_download=force,
            )
        except Exception as e:
            logger.error(f"Failed to download {voice_key}: {e}")
            raise

        # The snapshot can succeed without containing the configured file
        # names; surface that instead of silently reporting success.
        missing = [path.name for path in required_files if not path.exists()]
        if missing:
            logger.warning(
                f"Downloaded {voice_key} to {model_dir}, but expected files "
                f"are missing: {missing}"
            )
        else:
            logger.info(f"Successfully downloaded {voice_key} to {model_dir}")

        return model_dir

    def _download_each(self, voice_keys, force: bool) -> List[Path]:
        """Download every voice in ``voice_keys``, skipping the ones that fail.

        Failures are logged as warnings and do not abort the remaining
        downloads (deliberate best-effort behaviour).
        """
        downloaded = []
        for voice_key in voice_keys:
            try:
                downloaded.append(self.download_model(voice_key, force=force))
            except Exception as e:
                logger.warning(f"Failed to download {voice_key}: {e}")
        return downloaded

    def download_all_models(self, force: bool = False) -> List[Path]:
        """Download all available models.

        Args:
            force: Re-download voices that already exist locally.

        Returns:
            Model directories of the voices that downloaded successfully;
            failed voices are logged and omitted.
        """
        progress = tqdm(LANGUAGE_CONFIGS, desc="Downloading models")
        return self._download_each(progress, force)

    def download_language(self, lang_code: str, force: bool = False) -> List[Path]:
        """Download all voices for a specific language.

        Args:
            lang_code: Language code compared against each config's ``code``.
            force: Re-download voices that already exist locally.

        Returns:
            Model directories of the voices that downloaded successfully;
            empty when no configured voice has ``lang_code``.
        """
        matching = [
            voice_key
            for voice_key, config in LANGUAGE_CONFIGS.items()
            if config.code == lang_code
        ]
        if not matching:
            # Make a mistyped language code visible instead of a silent no-op.
            logger.warning(f"No voices configured for language code: {lang_code}")
        return self._download_each(matching, force)

    def get_model_path(self, voice_key: str) -> Optional[Path]:
        """Get path to a downloaded model.

        Returns:
            The voice's directory when its model file exists, otherwise
            ``None`` (also for keys that are not in LANGUAGE_CONFIGS).
        """
        config = LANGUAGE_CONFIGS.get(voice_key)
        if config is None:
            return None

        model_dir = self.models_dir / voice_key
        if (model_dir / config.model_filename).exists():
            return model_dir
        return None

    def list_downloaded_models(self) -> List[str]:
        """List the keys of all voices whose model file exists locally."""
        return [
            voice_key
            for voice_key in LANGUAGE_CONFIGS
            if self.get_model_path(voice_key) is not None
        ]

    def get_model_size(self, voice_key: str) -> Optional[int]:
        """Get size of downloaded model in bytes.

        Only regular files directly inside the voice directory are counted.

        Returns:
            Total size in bytes, or ``None`` when the voice is not downloaded.
        """
        model_dir = self.get_model_path(voice_key)
        if model_dir is None:
            return None

        return sum(
            entry.stat().st_size for entry in model_dir.iterdir() if entry.is_file()
        )
|
|
|
|
def download_models_cli():
    """CLI entry point for downloading models.

    Options are handled in this order of precedence: ``--list`` prints the
    known voices and returns; otherwise ``--voice``, then ``--lang``, then
    ``--all`` selects what to download. With none of them the help text is
    printed. ``--force`` is passed through to the download calls.

    An unknown ``--voice`` or ``--lang`` value ends the program through
    ``parser.error`` (usage message, exit status 2).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Download SYSPIN TTS models")
    parser.add_argument(
        "--voice", type=str, help="Specific voice to download (e.g., hi_male)"
    )
    parser.add_argument(
        "--lang", type=str, help="Download all voices for a language (e.g., hi)"
    )
    parser.add_argument("--all", action="store_true", help="Download all models")
    parser.add_argument("--list", action="store_true", help="List available models")
    parser.add_argument("--force", action="store_true", help="Force re-download")

    args = parser.parse_args()

    # Without a configured handler the module's INFO messages ("already
    # downloaded", "Successfully downloaded") never reach the terminal.
    # basicConfig is a no-op when the root logger already has handlers, and
    # only this module's logger is lowered to INFO so other libraries keep
    # their default verbosity.
    logging.basicConfig(format="%(levelname)s: %(message)s")
    logger.setLevel(logging.INFO)

    downloader = ModelDownloader()

    if args.list:
        print("Available voices:")
        for key, config in LANGUAGE_CONFIGS.items():
            downloaded = "✓" if downloader.get_model_path(key) else " "
            print(f"  [{downloaded}] {key}: {config.name} ({config.code})")
        return

    if args.voice:
        # Report a mistyped voice as a usage error rather than a traceback.
        if args.voice not in LANGUAGE_CONFIGS:
            parser.error(
                f"Unknown voice: {args.voice}. "
                f"Available: {', '.join(LANGUAGE_CONFIGS)}"
            )
        downloader.download_model(args.voice, force=args.force)
    elif args.lang:
        known_codes = sorted({config.code for config in LANGUAGE_CONFIGS.values()})
        # A language without any configured voice would otherwise exit
        # silently after doing nothing.
        if args.lang not in known_codes:
            parser.error(
                f"Unknown language: {args.lang}. "
                f"Available: {', '.join(known_codes)}"
            )
        downloader.download_language(args.lang, force=args.force)
    elif args.all:
        downloader.download_all_models(force=args.force)
    else:
        parser.print_help()
|
|
|
|
if __name__ == "__main__":
    # Start the CLI only when the module is executed as a script, so that
    # importing it never parses arguments or triggers downloads.
    download_models_cli()
|
|