File size: 5,791 Bytes
89a8916 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
"""
Model Downloader for SYSPIN TTS Models
Downloads models from Hugging Face Hub
"""
import os
import logging
from pathlib import Path
from typing import Optional, List
from huggingface_hub import hf_hub_download, snapshot_download
from tqdm import tqdm
from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelDownloader:
    """Downloads and manages SYSPIN TTS models from Hugging Face.

    Every voice lives in its own sub-directory of ``models_dir`` named after
    its key in ``LANGUAGE_CONFIGS``.
    """

    def __init__(self, models_dir: str = MODELS_DIR):
        """Remember the models root and create it if it does not exist.

        Args:
            models_dir: Root directory that holds one sub-directory per voice.
        """
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)

    def download_model(self, voice_key: str, force: bool = False) -> Path:
        """Download a specific voice model.

        Args:
            voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male', 'bn_female')
            force: Re-download even if the model files already exist locally.

        Returns:
            Path to downloaded model directory

        Raises:
            ValueError: If ``voice_key`` is not a key of LANGUAGE_CONFIGS.
            Exception: Whatever ``snapshot_download`` raises is logged and
                re-raised unchanged.
        """
        if voice_key not in LANGUAGE_CONFIGS:
            raise ValueError(
                f"Unknown voice: {voice_key}. Available: {list(LANGUAGE_CONFIGS.keys())}"
            )

        config = LANGUAGE_CONFIGS[voice_key]
        model_dir = self.models_dir / voice_key

        # The voice counts as present only when both the model file and the
        # chars file are on disk.
        model_path = model_dir / config.model_filename
        chars_path = model_dir / config.chars_filename
        if not force and model_path.exists() and chars_path.exists():
            logger.info("Model %s already downloaded at %s", voice_key, model_dir)
            return model_dir

        logger.info("Downloading %s from %s...", voice_key, config.hf_model_id)
        model_dir.mkdir(parents=True, exist_ok=True)

        try:
            snapshot_download(
                repo_id=config.hf_model_id,
                local_dir=str(model_dir),
                # NOTE(review): recent huggingface_hub releases deprecate this
                # argument -- confirm against the pinned version before
                # removing it.
                local_dir_use_symlinks=False,
                allow_patterns=["*.pt", "*.pth", "*.txt", "*.py", "*.json"],
                # Forward ``force`` so a forced call really re-fetches the
                # files instead of reusing what is already on disk.
                force_download=force,
            )
        except Exception as e:
            logger.error("Failed to download %s: %s", voice_key, e)
            raise

        logger.info("Successfully downloaded %s to %s", voice_key, model_dir)
        return model_dir

    def _download_many(self, voice_keys, force: bool) -> List[Path]:
        """Download each voice in ``voice_keys``, skipping the ones that fail.

        Failures are logged as warnings and do not stop the remaining
        downloads; only successfully downloaded directories are returned.
        """
        downloaded = []
        for voice_key in voice_keys:
            try:
                downloaded.append(self.download_model(voice_key, force=force))
            except Exception as e:
                logger.warning("Failed to download %s: %s", voice_key, e)
        return downloaded

    def download_all_models(self, force: bool = False) -> List[Path]:
        """Download all available models, showing a progress bar.

        Args:
            force: Passed through to ``download_model`` for every voice.

        Returns:
            Directories of the voices that downloaded successfully.
        """
        progress = tqdm(LANGUAGE_CONFIGS.keys(), desc="Downloading models")
        return self._download_many(progress, force=force)

    def download_language(self, lang_code: str, force: bool = False) -> List[Path]:
        """Download all voices for a specific language.

        Args:
            lang_code: Language code compared against each config's ``code``.
            force: Passed through to ``download_model`` for every voice.

        Returns:
            Directories of the voices that downloaded successfully; empty
            when no configured voice has this language code.
        """
        voice_keys = [
            voice_key
            for voice_key, config in LANGUAGE_CONFIGS.items()
            if config.code == lang_code
        ]
        if not voice_keys:
            # Make a mistyped language code visible instead of silently
            # doing nothing.
            logger.warning("No voices configured for language %r", lang_code)
        return self._download_many(voice_keys, force=force)

    def get_model_path(self, voice_key: str) -> Optional[Path]:
        """Get path to a downloaded model.

        Returns:
            The voice's directory when its model file exists on disk, or
            None when the voice is unknown or its model file is missing.
        """
        if voice_key not in LANGUAGE_CONFIGS:
            return None
        config = LANGUAGE_CONFIGS[voice_key]
        model_path = self.models_dir / voice_key / config.model_filename
        return model_path.parent if model_path.exists() else None

    def list_downloaded_models(self) -> List[str]:
        """List the keys of all voices whose model file exists on disk."""
        return [
            voice_key
            for voice_key, config in LANGUAGE_CONFIGS.items()
            if (self.models_dir / voice_key / config.model_filename).exists()
        ]

    def get_model_size(self, voice_key: str) -> Optional[int]:
        """Get size of downloaded model in bytes.

        Only regular files directly inside the voice directory are counted;
        sub-directories are not descended into.

        Returns:
            Total size in bytes, or None when the voice is not downloaded.
        """
        model_path = self.get_model_path(voice_key)
        if not model_path:
            return None
        return sum(f.stat().st_size for f in model_path.iterdir() if f.is_file())
def download_models_cli(argv: Optional[List[str]] = None) -> None:
    """CLI entry point for downloading models.

    Args:
        argv: Command-line arguments to parse. When None (the default),
            argparse reads them from ``sys.argv``.

    Unknown ``--voice`` or ``--lang`` values are reported through
    ``parser.error`` (usage message and a non-zero exit) rather than a
    traceback or a silent no-op.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Download SYSPIN TTS models")
    parser.add_argument(
        "--voice", type=str, help="Specific voice to download (e.g., hi_male)"
    )
    parser.add_argument(
        "--lang", type=str, help="Download all voices for a language (e.g., hi)"
    )
    parser.add_argument("--all", action="store_true", help="Download all models")
    parser.add_argument("--list", action="store_true", help="List available models")
    parser.add_argument("--force", action="store_true", help="Force re-download")
    args = parser.parse_args(argv)

    # Show the downloader's INFO progress messages. basicConfig does nothing
    # if the root logger was already configured elsewhere.
    logging.basicConfig(level=logging.INFO)

    downloader = ModelDownloader()

    if args.list:
        print("Available voices:")
        for key, config in LANGUAGE_CONFIGS.items():
            downloaded = "✓" if downloader.get_model_path(key) else " "
            print(f" [{downloaded}] {key}: {config.name} ({config.code})")
        return

    if args.voice:
        if args.voice not in LANGUAGE_CONFIGS:
            parser.error(
                f"Unknown voice: {args.voice}. "
                f"Available: {list(LANGUAGE_CONFIGS.keys())}"
            )
        downloader.download_model(args.voice, force=args.force)
    elif args.lang:
        # Distinct language codes, in configuration order.
        known_langs = list(
            dict.fromkeys(config.code for config in LANGUAGE_CONFIGS.values())
        )
        if args.lang not in known_langs:
            parser.error(f"Unknown language: {args.lang}. Available: {known_langs}")
        downloader.download_language(args.lang, force=args.force)
    elif args.all:
        downloader.download_all_models(force=args.force)
    else:
        parser.print_help()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    download_models_cli()
|