chiluka-tts / hub.py
Seemanth's picture
Add Chiluka TTS models (Hindi-English + Telugu)
13f85be verified
"""
Hub utilities for downloading and managing Chiluka TTS models.
Supports:
- HuggingFace Hub integration
- Automatic model downloading
- Local caching
- Multiple model variants
"""
import os
import shutil
from pathlib import Path
from typing import Optional, Union
# Hugging Face Hub repository that models are fetched from by default.
DEFAULT_HF_REPO = "Seemanth/chiluka-tts"

# Default on-disk location for downloaded model files: ~/.cache/chiluka
CACHE_DIR = Path.home().joinpath(".cache", "chiluka")
# --------------------------------------------
# Model registry
# --------------------------------------------
# One entry per model variant. "config" and "checkpoint" are file paths
# relative to the repository root; "languages" and "description" are the
# user-facing metadata for the variant.
MODEL_REGISTRY = {
    "telugu": {
        "config": "configs/config_ft.yml",
        "checkpoint": "checkpoints/epoch_2nd_00017.pth",
        "languages": ["te", "en"],
        "description": "Telugu + English single-speaker TTS",
    },
    "hindi_english": {
        "config": "configs/config_hindi_english.yml",
        "checkpoint": "checkpoints/epoch_2nd_00029.pth",
        "languages": ["hi", "en"],
        "description": "Hindi + English multi-speaker TTS (5 speakers)",
    },
}

# Variant used when the caller does not name one.
DEFAULT_MODEL = "hindi_english"
# Pretrained sub-model files shared by every model variant, keyed by role.
# Values are file paths under the repository's "pretrained/" directory.
PRETRAINED_FILES = {
    "asr_config": "pretrained/ASR/config.yml",
    "asr_model": "pretrained/ASR/epoch_00080.pth",
    "f0_model": "pretrained/JDC/bst.t7",
    "plbert_config": "pretrained/PLBERT/config.yml",
    "plbert_model": "pretrained/PLBERT/step_1000000.t7",
}
def list_models() -> dict:
    """
    List all available model variants.

    Only the user-facing metadata is exposed; the config and checkpoint
    paths stored in the registry are left out.

    Returns:
        Dictionary mapping each model name to a dict with its ``languages``
        (a fresh list, safe for the caller to modify) and ``description``.

    Example:
        >>> from chiluka import hub
        >>> hub.list_models()
        {'telugu': {...}, 'hindi_english': {...}}
    """
    return {
        name: {
            # Copy the list so that callers mutating the result cannot
            # corrupt the module-level MODEL_REGISTRY.
            "languages": list(info["languages"]),
            "description": info["description"],
        }
        for name, info in MODEL_REGISTRY.items()
    }
def get_cache_dir() -> Path:
    """
    Get the cache directory for Chiluka models, creating it if needed.

    The ``CHILUKA_CACHE`` environment variable overrides the default
    ``CACHE_DIR``. A leading ``~`` in the override is expanded to the user's
    home directory. An empty value is treated as unset, because an empty
    path would resolve to the current working directory.

    Returns:
        Path to the (existing) cache directory.
    """
    override = os.environ.get("CHILUKA_CACHE")
    if override:
        # Without expanduser() a value such as "~/models" would create a
        # directory literally named "~" under the current directory.
        cache_dir = Path(override).expanduser()
    else:
        cache_dir = Path(CACHE_DIR)
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
def is_model_cached(repo_id: str = DEFAULT_HF_REPO) -> bool:
    """
    Report whether a usable copy of ``repo_id`` is present in the local cache.

    A repo counts as cached when its cache folder exists, every file listed
    in ``PRETRAINED_FILES`` is present, and at least one variant in
    ``MODEL_REGISTRY`` has both its config and its checkpoint on disk.

    Args:
        repo_id: HuggingFace Hub repository ID; ``/`` is replaced by ``_``
            to form the cache folder name.

    Returns:
        True if the cached copy passes the checks above, False otherwise.
    """
    repo_root = get_cache_dir() / repo_id.replace("/", "_")
    if not repo_root.exists():
        return False

    # Every shared pretrained sub-model file must be present.
    shared_present = all(
        (repo_root / rel_path).exists() for rel_path in PRETRAINED_FILES.values()
    )
    if not shared_present:
        return False

    # One complete variant (config + checkpoint) is enough.
    return any(
        (repo_root / spec["config"]).exists()
        and (repo_root / spec["checkpoint"]).exists()
        for spec in MODEL_REGISTRY.values()
    )
def download_from_hf(
    repo_id: str = DEFAULT_HF_REPO,
    revision: str = "main",
    force_download: bool = False,
    token: Optional[str] = None,
) -> Path:
    """
    Download model files from HuggingFace Hub.

    If the repo is already cached and ``force_download`` is False, the cached
    directory is returned without touching the network and without requiring
    ``huggingface_hub`` to be installed.

    NOTE: the cache check looks only at ``repo_id``; a cached copy is reused
    even if it was downloaded from a different ``revision``. Pass
    ``force_download=True`` to refresh it.

    Args:
        repo_id: HuggingFace Hub repository ID (e.g., 'Seemanth/chiluka-tts')
        revision: Git revision to download (branch, tag, or commit hash)
        force_download: If True, re-download even if cached
        token: HuggingFace API token for private repos

    Returns:
        Path to the downloaded model directory

    Raises:
        ImportError: If a download is needed and ``huggingface_hub`` is not
            installed.
    """
    cache_path = get_cache_dir() / repo_id.replace("/", "_")

    if not force_download and is_model_cached(repo_id):
        print(f"Using cached model from {cache_path}")
        return cache_path

    # Imported lazily so that cached models work without huggingface_hub.
    try:
        from huggingface_hub import snapshot_download
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for downloading models. "
            "Install with: pip install huggingface_hub"
        ) from err

    print(f"Downloading model from HuggingFace Hub: {repo_id}...")
    downloaded_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        cache_dir=get_cache_dir() / "hf_cache",
        token=token,
        local_dir=cache_path,
        # NOTE(review): this flag appears to be deprecated in recent
        # huggingface_hub releases -- confirm against the pinned version.
        local_dir_use_symlinks=False,
        # Forward the flag; otherwise files already present in local_dir are
        # kept and "force" would not actually re-fetch anything.
        force_download=force_download,
    )
    print(f"Model downloaded to {cache_path}")
    return Path(downloaded_path)
def get_model_paths(
    model: str = DEFAULT_MODEL,
    repo_id: str = DEFAULT_HF_REPO,
) -> dict:
    """
    Resolve the on-disk file locations for one model variant.

    The variant name is validated first; the repository is then fetched
    (or taken from the cache) via ``download_from_hf``.

    Args:
        model: Model variant name ('telugu', 'hindi_english')
        repo_id: HuggingFace Hub repository ID

    Returns:
        Dictionary with string paths under the keys ``config_path``,
        ``checkpoint_path`` and ``pretrained_dir``.

    Raises:
        ValueError: If ``model`` is not a key of ``MODEL_REGISTRY``.
    """
    # Reject unknown variants before any download is attempted.
    if model not in MODEL_REGISTRY:
        available = ", ".join(MODEL_REGISTRY)
        raise ValueError(
            f"Unknown model '{model}'. Available models: {available}"
        )

    repo_root = download_from_hf(repo_id)
    spec = MODEL_REGISTRY[model]

    config_file = repo_root / spec["config"]
    checkpoint_file = repo_root / spec["checkpoint"]
    pretrained_root = repo_root / "pretrained"
    return {
        "config_path": str(config_file),
        "checkpoint_path": str(checkpoint_file),
        "pretrained_dir": str(pretrained_root),
    }
def clear_cache(repo_id: Optional[str] = None):
    """
    Delete cached model files from disk.

    Args:
        repo_id: Repository whose cached copy should be removed. When
            omitted (or otherwise falsy), the whole Chiluka cache
            directory is removed instead.
    """
    cache_root = get_cache_dir()

    # No repo given: wipe the whole cache directory.
    if not repo_id:
        if cache_root.exists():
            shutil.rmtree(cache_root)
            print("Cleared entire Chiluka cache")
        return

    # Otherwise remove only that repo's folder, if it is there.
    repo_cache = cache_root / repo_id.replace("/", "_")
    if repo_cache.exists():
        shutil.rmtree(repo_cache)
        print(f"Cleared cache for {repo_id}")
def push_to_hub(
    local_dir: str,
    repo_id: str,
    token: Optional[str] = None,
    private: bool = False,
    commit_message: str = "Upload Chiluka TTS model",
):
    """
    Push a local model to HuggingFace Hub.

    The target repository is created first if possible; a failure at that
    step is reported and the upload is still attempted.

    Args:
        local_dir: Local directory containing model files
        repo_id: Target HuggingFace Hub repository ID
        token: HuggingFace API token (or set HF_TOKEN env var)
        private: Whether to create a private repository
        commit_message: Commit message for the upload

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.

    Example:
        >>> push_to_hub(
        ...     local_dir="./chiluka",
        ...     repo_id="Seemanth/chiluka-tts",
        ...     private=False
        ... )
    """
    try:
        from huggingface_hub import HfApi, create_repo
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for pushing models. "
            "Install with: pip install huggingface_hub"
        ) from err

    api = HfApi(token=token)

    # Create repo if it doesn't exist. Deliberately best-effort: report the
    # problem and let the upload below surface any real failure.
    try:
        create_repo(repo_id, private=private, token=token, exist_ok=True)
    except Exception as e:
        print(f"Note: {e}")

    # Upload folder
    print(f"Uploading to {repo_id}...")
    api.upload_folder(
        folder_path=local_dir,
        repo_id=repo_id,
        commit_message=commit_message,
        # Patterns are matched against each file's relative path, so a bare
        # directory name does not cover the files inside it; the "/*"
        # variants exclude the directory contents as well.
        ignore_patterns=[
            "*.pyc",
            "__pycache__",
            "*__pycache__/*",
            "*.egg-info",
            "*.egg-info/*",
            ".git",
        ],
    )
    print(f"Model uploaded to: https://huggingface.co/{repo_id}")
def create_model_card(repo_id: str, save_path: Optional[str] = None) -> str:
    """
    Generate a model card (README.md) for HuggingFace Hub.

    The "Available Models" table is built from ``MODEL_REGISTRY``; the
    install and PyTorch Hub snippets use the owner part of ``repo_id``
    (everything before the first ``/``).

    Args:
        repo_id: Repository ID for the model
        save_path: If provided, save the model card to this path
            (written as UTF-8)

    Returns:
        Model card content as string
    """
    owner = repo_id.split("/")[0]

    # Build model table: one markdown row per registered variant.
    rows = []
    for name, info in MODEL_REGISTRY.items():
        langs = ", ".join(info["languages"])
        rows.append(f"| `{name}` | {info['description']} | {langs} |\n")
    model_rows = "".join(rows)

    model_card = f"""---
language:
- en
- te
- hi
license: mit
library_name: chiluka
tags:
- text-to-speech
- tts
- styletts2
- voice-cloning
- multi-language
---
# Chiluka TTS
Chiluka (చిలుక - Telugu for "parrot") is a lightweight Text-to-Speech model based on StyleTTS2.
## Available Models
| Model | Description | Languages |
|-------|-------------|-----------|
{model_rows}
## Installation
```bash
pip install chiluka
```
Or install from source:
```bash
pip install git+https://github.com/{owner}/chiluka.git
```
## Usage
### Hindi + English (default)
```python
from chiluka import Chiluka
tts = Chiluka.from_pretrained()
wav = tts.synthesize(
text="Hello, world!",
reference_audio="reference.wav",
language="en"
)
tts.save_wav(wav, "output.wav")
```
### Telugu
```python
tts = Chiluka.from_pretrained(model="telugu")
wav = tts.synthesize(
text="నమస్కారం",
reference_audio="reference.wav",
language="te"
)
```
### PyTorch Hub
```python
import torch
tts = torch.hub.load('{owner}/chiluka', 'chiluka')
tts = torch.hub.load('{owner}/chiluka', 'chiluka_telugu')
```
## License
MIT License
## Citation
Based on StyleTTS2 by Yinghao Aaron Li et al.
"""
    if save_path:
        # The card contains Telugu text; an explicit UTF-8 encoding avoids
        # UnicodeEncodeError on platforms whose default encoding is not UTF-8.
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(model_card)
        print(f"Model card saved to {save_path}")
    return model_card