|
|
""" |
|
|
Hub utilities for downloading and managing Chiluka TTS models. |
|
|
|
|
|
Supports: |
|
|
- HuggingFace Hub integration |
|
|
- Automatic model downloading |
|
|
- Local caching |
|
|
- Multiple model variants |
|
|
""" |
|
|
|
|
|
import os |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from typing import Optional, Union |
|
|
|
|
|
|
|
|
# HuggingFace Hub repository used when no ``repo_id`` is given.
DEFAULT_HF_REPO = "Seemanth/chiluka-tts"

# Default root for downloaded snapshots (``~/.cache/chiluka``); the
# CHILUKA_CACHE environment variable takes precedence when set.
CACHE_DIR = Path.home().joinpath(".cache", "chiluka")

# Model variants published in the Hub repository. ``config`` and
# ``checkpoint`` are paths relative to the downloaded snapshot root;
# ``languages`` and ``description`` are the user-facing metadata that
# list_models() exposes.
MODEL_REGISTRY = {
    "telugu": dict(
        config="configs/config_ft.yml",
        checkpoint="checkpoints/epoch_2nd_00017.pth",
        languages=["te", "en"],
        description="Telugu + English single-speaker TTS",
    ),
    "hindi_english": dict(
        config="configs/config_hindi_english.yml",
        checkpoint="checkpoints/epoch_2nd_00029.pth",
        languages=["hi", "en"],
        description="Hindi + English multi-speaker TTS (5 speakers)",
    ),
}

# Variant used when callers do not name one explicitly.
DEFAULT_MODEL = "hindi_english"

# Shared pretrained components (paths relative to the snapshot root) that must
# all be present before any variant is considered cached.
PRETRAINED_FILES = dict(
    asr_config="pretrained/ASR/config.yml",
    asr_model="pretrained/ASR/epoch_00080.pth",
    f0_model="pretrained/JDC/bst.t7",
    plbert_config="pretrained/PLBERT/config.yml",
    plbert_model="pretrained/PLBERT/step_1000000.t7",
)
|
|
|
|
|
|
|
|
def list_models() -> dict:
    """
    List all available model variants.

    Only the user-facing metadata (``languages`` and ``description``) is
    exposed; the internal config/checkpoint paths in ``MODEL_REGISTRY`` are
    omitted. The returned structures are fresh copies, so callers may modify
    them without affecting the registry.

    Returns:
        Dictionary mapping each model name to a dict with keys
        ``"languages"`` (list of language codes) and ``"description"``.

    Example:
        >>> from chiluka import hub
        >>> hub.list_models()
        {'telugu': {...}, 'hindi_english': {...}}
    """
    return {
        name: {
            # Copy the list: handing out the registry's own list would let a
            # caller's .append()/.remove() silently alter MODEL_REGISTRY.
            "languages": list(info["languages"]),
            "description": info["description"],
        }
        for name, info in MODEL_REGISTRY.items()
    }
|
|
|
|
|
|
|
|
def get_cache_dir() -> Path:
    """
    Get the cache directory for Chiluka models, creating it if needed.

    The ``CHILUKA_CACHE`` environment variable overrides the default
    ``CACHE_DIR``. A leading ``~`` in the override is expanded to the user's
    home directory, and an empty value is treated as unset.

    Returns:
        Path to the (now existing) cache directory.
    """
    override = os.environ.get("CHILUKA_CACHE")
    # An empty value would otherwise become Path("") == Path("."), making the
    # current working directory the cache (and the target of clear_cache()'s
    # shutil.rmtree).
    cache_dir = Path(override).expanduser() if override else CACHE_DIR
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
|
|
|
|
|
|
|
|
def is_model_cached(repo_id: str = DEFAULT_HF_REPO) -> bool:
    """
    Report whether a usable snapshot of ``repo_id`` is already on disk.

    A snapshot counts as cached when every file in ``PRETRAINED_FILES`` exists
    under the repo's cache folder AND at least one variant in
    ``MODEL_REGISTRY`` has both its config and its checkpoint present. A
    single complete variant is enough; this does not verify every variant.

    Args:
        repo_id: HuggingFace Hub repository ID. ``/`` is mapped to ``_`` to
            form the cache folder name.

    Returns:
        True if the cached snapshot looks usable, False otherwise.
    """
    root = get_cache_dir() / repo_id.replace("/", "_")
    if not root.exists():
        return False

    # The shared pretrained files are required regardless of variant.
    if not all((root / rel).exists() for rel in PRETRAINED_FILES.values()):
        return False

    return any(
        (root / entry["config"]).exists() and (root / entry["checkpoint"]).exists()
        for entry in MODEL_REGISTRY.values()
    )
|
|
|
|
|
|
|
|
def download_from_hf(
    repo_id: str = DEFAULT_HF_REPO,
    revision: str = "main",
    force_download: bool = False,
    token: Optional[str] = None,
) -> Path:
    """
    Download model files from HuggingFace Hub.

    If ``is_model_cached(repo_id)`` reports a usable local snapshot and
    ``force_download`` is False, the cached directory is returned without
    contacting the Hub. In that case ``revision`` and ``token`` are not used.

    Args:
        repo_id: HuggingFace Hub repository ID (e.g., 'Seemanth/chiluka-tts')
        revision: Git revision to download (branch, tag, or commit hash)
        force_download: If True, re-download even if cached. The flag is also
            forwarded to ``snapshot_download`` so files already present in
            the local directory are fetched again rather than reused.
        token: HuggingFace API token for private repos

    Returns:
        Path to the downloaded model directory

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for downloading models. "
            "Install with: pip install huggingface_hub"
        ) from err

    cache_dir = get_cache_dir()
    cache_path = cache_dir / repo_id.replace("/", "_")

    # Skip the (filesystem) cache check entirely when a refresh is requested.
    if not force_download and is_model_cached(repo_id):
        print(f"Using cached model from {cache_path}")
        return cache_path

    print(f"Downloading model from HuggingFace Hub: {repo_id}...")

    downloaded_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        cache_dir=cache_dir / "hf_cache",
        token=token,
        local_dir=cache_path,
        # Request real file copies in local_dir rather than symlinks.
        # NOTE(review): deprecated in newer huggingface_hub releases; confirm
        # against the pinned version before removing.
        local_dir_use_symlinks=False,
        # Without this, snapshot_download reuses files already in local_dir,
        # so force_download=True would not actually refresh anything.
        force_download=force_download,
    )

    print(f"Model downloaded to {cache_path}")
    return Path(downloaded_path)
|
|
|
|
|
|
|
|
def get_model_paths(
    model: str = DEFAULT_MODEL,
    repo_id: str = DEFAULT_HF_REPO,
) -> dict:
    """
    Get paths to all model files after downloading.

    Args:
        model: Model variant name ('telugu', 'hindi_english')
        repo_id: HuggingFace Hub repository ID

    Returns:
        Dictionary with string paths under the keys ``"config_path"``,
        ``"checkpoint_path"`` and ``"pretrained_dir"``.

    Raises:
        ValueError: If ``model`` is not a key of ``MODEL_REGISTRY``.
        FileNotFoundError: If the variant's config or checkpoint is still
            missing after a forced re-download (e.g. the repository does not
            contain that variant).
    """
    if model not in MODEL_REGISTRY:
        available = ", ".join(MODEL_REGISTRY.keys())
        raise ValueError(
            f"Unknown model '{model}'. Available models: {available}"
        )

    model_info = MODEL_REGISTRY[model]
    model_dir = download_from_hf(repo_id)

    # download_from_hf skips the Hub when *any* variant is cached, so the
    # requested variant can be absent from an otherwise complete cache.
    # Refresh the snapshot once before giving up.
    required = (model_info["config"], model_info["checkpoint"])
    if not all((model_dir / rel).exists() for rel in required):
        model_dir = download_from_hf(repo_id, force_download=True)
        missing = [rel for rel in required if not (model_dir / rel).exists()]
        if missing:
            raise FileNotFoundError(
                f"Model '{model}' is missing {', '.join(missing)} in "
                f"{model_dir}; repository '{repo_id}' may not provide it."
            )

    return {
        "config_path": str(model_dir / model_info["config"]),
        "checkpoint_path": str(model_dir / model_info["checkpoint"]),
        "pretrained_dir": str(model_dir / "pretrained"),
    }
|
|
|
|
|
|
|
|
def clear_cache(repo_id: Optional[str] = None):
    """
    Clear cached models.

    Args:
        repo_id: If specified, only clear cache for this repo. This removes
            both its local snapshot folder and its entry in the ``hf_cache``
            download cache that ``download_from_hf`` passes to
            ``snapshot_download``. If None, clear entire cache.
    """
    cache_dir = get_cache_dir()

    if repo_id:
        targets = (
            cache_dir / repo_id.replace("/", "_"),
            # huggingface_hub stores a model repo's blobs/snapshots under
            # "models--{owner}--{name}" inside its cache_dir.
            cache_dir / "hf_cache" / ("models--" + repo_id.replace("/", "--")),
        )
        removed = False
        for target in targets:
            if target.exists():
                shutil.rmtree(target)
                removed = True
        if removed:
            print(f"Cleared cache for {repo_id}")
    else:
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
            print("Cleared entire Chiluka cache")
|
|
|
|
|
|
|
|
def push_to_hub(
    local_dir: str,
    repo_id: str,
    token: Optional[str] = None,
    private: bool = False,
    commit_message: str = "Upload Chiluka TTS model",
):
    """
    Push a local model to HuggingFace Hub.

    Creates the repository if needed (``exist_ok=True``), then uploads the
    contents of ``local_dir`` in one commit. Python bytecode, ``__pycache__``
    and ``*.egg-info`` directories, and ``.git`` metadata are skipped.

    Args:
        local_dir: Local directory containing model files
        repo_id: Target HuggingFace Hub repository ID
        token: HuggingFace API token (or set HF_TOKEN env var)
        private: Whether to create a private repository
        commit_message: Commit message for the upload

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.

    Example:
        >>> push_to_hub(
        ...     local_dir="./chiluka",
        ...     repo_id="Seemanth/chiluka-tts",
        ...     private=False
        ... )
    """
    try:
        from huggingface_hub import HfApi, create_repo
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for pushing models. "
            "Install with: pip install huggingface_hub"
        ) from err

    api = HfApi(token=token)

    # Repo creation is best-effort: any failure is only reported, and the
    # upload below is attempted regardless.
    try:
        create_repo(repo_id, private=private, token=token, exist_ok=True)
    except Exception as e:
        print(f"Note: {e}")

    # Ignore patterns are fnmatch-ed against each file's relative path, so a
    # bare directory name ("__pycache__", ".git") only matches a path that is
    # exactly that name, not the files inside it. The "/*" variants exclude
    # the directory contents (fnmatch's "*" also matches "/").
    ignore_patterns = [
        "*.pyc",
        "__pycache__",
        "*.egg-info",
        ".git",
        "__pycache__/*",
        "*/__pycache__/*",
        "*.egg-info/*",
        ".git/*",
    ]

    print(f"Uploading to {repo_id}...")
    api.upload_folder(
        folder_path=local_dir,
        repo_id=repo_id,
        commit_message=commit_message,
        ignore_patterns=ignore_patterns,
    )

    print(f"Model uploaded to: https://huggingface.co/{repo_id}")
|
|
|
|
|
|
|
|
def create_model_card(repo_id: str, save_path: Optional[str] = None) -> str:
    """
    Generate a model card (README.md) for HuggingFace Hub.

    The card has YAML front matter, a Markdown table with one row per entry
    in ``MODEL_REGISTRY``, and usage examples that point at the GitHub owner
    taken from ``repo_id`` (the part before the first ``/``).

    Args:
        repo_id: Repository ID for the model
        save_path: If provided, save the model card to this path (UTF-8)

    Returns:
        Model card content as string
    """
    owner = repo_id.split("/")[0]

    rows = []
    for name, info in MODEL_REGISTRY.items():
        langs = ", ".join(info["languages"])
        rows.append(f"| `{name}` | {info['description']} | {langs} |")
    # Table rows must directly follow the header/separator lines; a blank
    # line in between ends a Markdown table.
    model_rows = "\n".join(rows)

    model_card = f"""---
language:
- en
- te
- hi
license: mit
library_name: chiluka
tags:
- text-to-speech
- tts
- styletts2
- voice-cloning
- multi-language
---

# Chiluka TTS

Chiluka (చిలుక - Telugu for "parrot") is a lightweight Text-to-Speech model based on StyleTTS2.

## Available Models

| Model | Description | Languages |
|-------|-------------|-----------|
{model_rows}

## Installation

```bash
pip install chiluka
```

Or install from source:

```bash
pip install git+https://github.com/{owner}/chiluka.git
```

## Usage

### Hindi + English (default)

```python
from chiluka import Chiluka

tts = Chiluka.from_pretrained()

wav = tts.synthesize(
    text="Hello, world!",
    reference_audio="reference.wav",
    language="en"
)
tts.save_wav(wav, "output.wav")
```

### Telugu

```python
tts = Chiluka.from_pretrained(model="telugu")

wav = tts.synthesize(
    text="నమస్కారం",
    reference_audio="reference.wav",
    language="te"
)
```

### PyTorch Hub

```python
import torch

tts = torch.hub.load('{owner}/chiluka', 'chiluka')
tts = torch.hub.load('{owner}/chiluka', 'chiluka_telugu')
```

## License

MIT License

## Citation

Based on StyleTTS2 by Yinghao Aaron Li et al.
"""

    if save_path:
        # The card contains Telugu text; an explicit encoding avoids
        # UnicodeEncodeError on platforms whose default encoding isn't UTF-8.
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(model_card)
        print(f"Model card saved to {save_path}")

    return model_card
|
|
|