"""
Hub utilities for downloading and managing Chiluka TTS models.

Supports:
- HuggingFace Hub integration
- Automatic model downloading
- Local caching
- Multiple model variants
"""

import os
import shutil
from pathlib import Path
from typing import Optional, Union

# Default HuggingFace Hub repository
DEFAULT_HF_REPO = "Seemanth/chiluka-tts"

# Cache directory for downloaded models (overridable via CHILUKA_CACHE env var)
CACHE_DIR = Path.home() / ".cache" / "chiluka"

# ============================================
# Model Registry
# ============================================
# Maps model names to their config + checkpoint paths
# relative to the repo root.
MODEL_REGISTRY = {
    "telugu": {
        "config": "configs/config_ft.yml",
        "checkpoint": "checkpoints/epoch_2nd_00017.pth",
        "languages": ["te", "en"],
        "description": "Telugu + English single-speaker TTS",
    },
    "hindi_english": {
        "config": "configs/config_hindi_english.yml",
        "checkpoint": "checkpoints/epoch_2nd_00029.pth",
        "languages": ["hi", "en"],
        "description": "Hindi + English multi-speaker TTS (5 speakers)",
    },
}

DEFAULT_MODEL = "hindi_english"

# Shared pretrained sub-models (same across all variants)
PRETRAINED_FILES = {
    "asr_config": "pretrained/ASR/config.yml",
    "asr_model": "pretrained/ASR/epoch_00080.pth",
    "f0_model": "pretrained/JDC/bst.t7",
    "plbert_config": "pretrained/PLBERT/config.yml",
    "plbert_model": "pretrained/PLBERT/step_1000000.t7",
}


def list_models() -> dict:
    """
    List all available model variants.

    Returns:
        Dictionary mapping model names to their info
        (``languages`` and ``description`` keys only).

    Example:
        >>> from chiluka import hub
        >>> hub.list_models()
        {'telugu': {...}, 'hindi_english': {...}}
    """
    return {
        name: {
            "languages": info["languages"],
            "description": info["description"],
        }
        for name, info in MODEL_REGISTRY.items()
    }


def get_cache_dir() -> Path:
    """
    Get the cache directory for Chiluka models, creating it if needed.

    The ``CHILUKA_CACHE`` environment variable overrides the default
    location (``~/.cache/chiluka``).
    """
    cache_dir = Path(os.environ.get("CHILUKA_CACHE", CACHE_DIR))
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir


def is_model_cached(repo_id: str = DEFAULT_HF_REPO) -> bool:
    """
    Check if a model is already cached locally.

    A repo counts as cached when all shared pretrained sub-model files
    exist AND at least one registered variant has both its config and
    checkpoint present.

    Args:
        repo_id: HuggingFace Hub repository ID.

    Returns:
        True if the cached copy is usable, False otherwise.
    """
    cache_path = get_cache_dir() / repo_id.replace("/", "_")
    if not cache_path.exists():
        return False

    # All shared pretrained files must be present.
    for file_path in PRETRAINED_FILES.values():
        if not (cache_path / file_path).exists():
            return False

    # At least one complete model variant (config + checkpoint) must exist.
    for model_info in MODEL_REGISTRY.values():
        config_exists = (cache_path / model_info["config"]).exists()
        checkpoint_exists = (cache_path / model_info["checkpoint"]).exists()
        if config_exists and checkpoint_exists:
            return True

    return False


def download_from_hf(
    repo_id: str = DEFAULT_HF_REPO,
    revision: str = "main",
    force_download: bool = False,
    token: Optional[str] = None,
) -> Path:
    """
    Download model files from HuggingFace Hub.

    Args:
        repo_id: HuggingFace Hub repository ID (e.g., 'Seemanth/chiluka-tts')
        revision: Git revision to download (branch, tag, or commit hash)
        force_download: If True, re-download even if cached
        token: HuggingFace API token for private repos

    Returns:
        Path to the downloaded model directory

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        # `from None` hides the unhelpful inner traceback — the message
        # already tells the user exactly what to install.
        raise ImportError(
            "huggingface_hub is required for downloading models. "
            "Install with: pip install huggingface_hub"
        ) from None

    cache_path = get_cache_dir() / repo_id.replace("/", "_")

    # Reuse the local copy unless the caller forces a re-download.
    if is_model_cached(repo_id) and not force_download:
        print(f"Using cached model from {cache_path}")
        return cache_path

    print(f"Downloading model from HuggingFace Hub: {repo_id}...")
    downloaded_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        cache_dir=get_cache_dir() / "hf_cache",
        token=token,
        local_dir=cache_path,
        # Real files (not symlinks) so the cache dir is self-contained.
        # NOTE: deprecated no-op in recent huggingface_hub; kept for
        # compatibility with older versions.
        local_dir_use_symlinks=False,
    )
    print(f"Model downloaded to {cache_path}")
    return Path(downloaded_path)


def get_model_paths(
    model: str = DEFAULT_MODEL,
    repo_id: str = DEFAULT_HF_REPO,
) -> dict:
    """
    Get paths to all model files after downloading.

    Args:
        model: Model variant name ('telugu', 'hindi_english')
        repo_id: HuggingFace Hub repository ID

    Returns:
        Dictionary with 'config_path', 'checkpoint_path', and
        'pretrained_dir' string paths.

    Raises:
        ValueError: If ``model`` is not a registered variant name.
    """
    if model not in MODEL_REGISTRY:
        available = ", ".join(MODEL_REGISTRY.keys())
        raise ValueError(
            f"Unknown model '{model}'. Available models: {available}"
        )

    model_dir = download_from_hf(repo_id)
    model_info = MODEL_REGISTRY[model]

    return {
        "config_path": str(model_dir / model_info["config"]),
        "checkpoint_path": str(model_dir / model_info["checkpoint"]),
        "pretrained_dir": str(model_dir / "pretrained"),
    }


def clear_cache(repo_id: Optional[str] = None):
    """
    Clear cached models.

    Args:
        repo_id: If specified, only clear cache for this repo.
                 If None, clear entire cache.
    """
    cache_dir = get_cache_dir()

    if repo_id:
        cache_path = cache_dir / repo_id.replace("/", "_")
        if cache_path.exists():
            shutil.rmtree(cache_path)
            print(f"Cleared cache for {repo_id}")
    else:
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
            print("Cleared entire Chiluka cache")


def push_to_hub(
    local_dir: str,
    repo_id: str,
    token: Optional[str] = None,
    private: bool = False,
    commit_message: str = "Upload Chiluka TTS model",
):
    """
    Push a local model to HuggingFace Hub.

    Args:
        local_dir: Local directory containing model files
        repo_id: Target HuggingFace Hub repository ID
        token: HuggingFace API token (or set HF_TOKEN env var)
        private: Whether to create a private repository
        commit_message: Commit message for the upload

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.

    Example:
        >>> push_to_hub(
        ...     local_dir="./chiluka",
        ...     repo_id="Seemanth/chiluka-tts",
        ...     private=False
        ... )
    """
    try:
        from huggingface_hub import HfApi, create_repo
    except ImportError:
        raise ImportError(
            "huggingface_hub is required for pushing models. "
            "Install with: pip install huggingface_hub"
        ) from None

    api = HfApi(token=token)

    # Create repo if it doesn't exist; best-effort, so a failure here
    # (e.g. insufficient permissions on an existing repo) is reported
    # but does not abort the upload attempt.
    try:
        create_repo(repo_id, private=private, token=token, exist_ok=True)
    except Exception as e:
        print(f"Note: {e}")

    # Upload folder, skipping build/VCS artifacts.
    print(f"Uploading to {repo_id}...")
    api.upload_folder(
        folder_path=local_dir,
        repo_id=repo_id,
        commit_message=commit_message,
        ignore_patterns=["*.pyc", "__pycache__", "*.egg-info", ".git"],
    )
    print(f"Model uploaded to: https://huggingface.co/{repo_id}")


def create_model_card(repo_id: str, save_path: Optional[str] = None) -> str:
    """
    Generate a model card (README.md) for HuggingFace Hub.

    Args:
        repo_id: Repository ID for the model
        save_path: If provided, save the model card to this path

    Returns:
        Model card content as string
    """
    owner = repo_id.split("/")[0]

    # Build one markdown table row per registered model variant.
    model_rows = ""
    for name, info in MODEL_REGISTRY.items():
        langs = ", ".join(info["languages"])
        model_rows += f"| `{name}` | {info['description']} | {langs} |\n"

    model_card = f"""---
language:
- en
- te
- hi
license: mit
library_name: chiluka
tags:
- text-to-speech
- tts
- styletts2
- voice-cloning
- multi-language
---

# Chiluka TTS

Chiluka (చిలుక - Telugu for "parrot") is a lightweight Text-to-Speech model based on StyleTTS2.

## Available Models

| Model | Description | Languages |
|-------|-------------|-----------|
{model_rows}

## Installation

```bash
pip install chiluka
```

Or install from source:

```bash
pip install git+https://github.com/{owner}/chiluka.git
```

## Usage

### Hindi + English (default)

```python
from chiluka import Chiluka

tts = Chiluka.from_pretrained()
wav = tts.synthesize(
    text="Hello, world!",
    reference_audio="reference.wav",
    language="en"
)
tts.save_wav(wav, "output.wav")
```

### Telugu

```python
tts = Chiluka.from_pretrained(model="telugu")
wav = tts.synthesize(
    text="నమస్కారం",
    reference_audio="reference.wav",
    language="te"
)
```

### PyTorch Hub

```python
import torch

tts = torch.hub.load('{owner}/chiluka', 'chiluka')
tts = torch.hub.load('{owner}/chiluka', 'chiluka_telugu')
```

## License

MIT License

## Citation

Based on StyleTTS2 by Yinghao Aaron Li et al.
"""

    if save_path:
        # The card contains Telugu script — explicit UTF-8 avoids a
        # UnicodeEncodeError on platforms with a non-UTF-8 default locale.
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(model_card)
        print(f"Model card saved to {save_path}")

    return model_card