seemanthraju commited on
Commit
60fee7c
·
1 Parent(s): 7be9079

Add torch.hub and HuggingFace Hub support

Browse files
Files changed (6) hide show
  1. README_HF.md +92 -0
  2. chiluka/__init__.py +37 -1
  3. chiluka/hub.py +347 -0
  4. chiluka/inference.py +61 -0
  5. hubconf.py +84 -0
  6. setup.py +1 -0
README_HF.md ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ - te
5
+ - hi
6
+ license: mit
7
+ library_name: chiluka
8
+ tags:
9
+ - text-to-speech
10
+ - tts
11
+ - styletts2
12
+ - voice-cloning
13
+ ---
14
+
15
+ # Chiluka TTS
16
+
17
+ Chiluka (చిలుక - Telugu for "parrot") is a lightweight Text-to-Speech model based on StyleTTS2.
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install chiluka
23
+ ```
24
+
25
+ Or install from source:
26
+
27
+ ```bash
28
+ pip install git+https://github.com/Seemanth/chiluka.git
29
+ ```
30
+
31
+ ## Usage
32
+
33
+ ### Quick Start (Auto-download)
34
+
35
+ ```python
36
+ from chiluka import Chiluka
37
+
38
+ # Automatically downloads model weights
39
+ tts = Chiluka.from_pretrained()
40
+
41
+ # Generate speech
42
+ wav = tts.synthesize(
43
+ text="Hello, world!",
44
+ reference_audio="path/to/reference.wav",
45
+ language="en"
46
+ )
47
+
48
+ # Save output
49
+ tts.save_wav(wav, "output.wav")
50
+ ```
51
+
52
+ ### PyTorch Hub
53
+
54
+ ```python
55
+ import torch
56
+
57
+ tts = torch.hub.load('Seemanth/chiluka', 'chiluka')
58
+ wav = tts.synthesize("Hello!", "reference.wav", language="en")
59
+ ```
60
+
61
+ ### HuggingFace Hub
62
+
63
+ ```python
64
+ from chiluka import Chiluka
65
+
66
+ tts = Chiluka.from_pretrained("Seemanth/chiluka-tts")
67
+ ```
68
+
69
+ ## Parameters
70
+
71
+ - `text`: Input text to synthesize
72
+ - `reference_audio`: Path to reference audio for style transfer
73
+ - `language`: Language code ('en', 'te', 'hi', etc.)
74
+ - `alpha`: Acoustic style mixing (0-1, default 0.3)
75
+ - `beta`: Prosodic style mixing (0-1, default 0.7)
76
+ - `diffusion_steps`: Quality vs speed tradeoff (default 5)
77
+
78
+ ## Supported Languages
79
+
80
+ Uses espeak-ng phonemizer. Common languages:
81
+ - English: `en-us`, `en-gb`
82
+ - Telugu: `te`
83
+ - Hindi: `hi`
84
+ - Tamil: `ta`
85
+
86
+ ## License
87
+
88
+ MIT License
89
+
90
+ ## Citation
91
+
92
+ Based on StyleTTS2 by Yinghao Aaron Li et al.
chiluka/__init__.py CHANGED
@@ -1,9 +1,45 @@
1
  """
2
  Chiluka - A lightweight TTS inference package based on StyleTTS2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  """
4
 
5
  __version__ = "0.1.0"
6
 
7
  from .inference import Chiluka
 
 
 
 
 
 
 
 
8
 
9
- __all__ = ["Chiluka"]
 
 
 
 
 
 
 
 
 
1
  """
2
  Chiluka - A lightweight TTS inference package based on StyleTTS2
3
+
4
+ Usage:
5
+ # Local weights (if you have them)
6
+ from chiluka import Chiluka
7
+ tts = Chiluka()
8
+
9
+ # Auto-download from HuggingFace Hub (recommended)
10
+ from chiluka import Chiluka
11
+ tts = Chiluka.from_pretrained()
12
+
13
+ # From specific HuggingFace repo
14
+ tts = Chiluka.from_pretrained("username/model-name")
15
+
16
+ # Generate speech
17
+ wav = tts.synthesize(
18
+ text="Hello, world!",
19
+ reference_audio="reference.wav",
20
+ language="en"
21
+ )
22
+ tts.save_wav(wav, "output.wav")
23
  """
24
 
25
  __version__ = "0.1.0"
26
 
27
  from .inference import Chiluka
28
+ from .hub import (
29
+ download_from_hf,
30
+ push_to_hub,
31
+ clear_cache,
32
+ get_cache_dir,
33
+ create_model_card,
34
+ DEFAULT_HF_REPO,
35
+ )
36
 
37
+ __all__ = [
38
+ "Chiluka",
39
+ "download_from_hf",
40
+ "push_to_hub",
41
+ "clear_cache",
42
+ "get_cache_dir",
43
+ "create_model_card",
44
+ "DEFAULT_HF_REPO",
45
+ ]
chiluka/hub.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hub utilities for downloading and managing Chiluka TTS models.
3
+
4
+ Supports:
5
+ - HuggingFace Hub integration
6
+ - Automatic model downloading
7
+ - Local caching
8
+ """
9
+
10
+ import os
11
+ import shutil
12
+ from pathlib import Path
13
+ from typing import Optional, Union
14
+
15
+ # Default HuggingFace Hub repository
16
+ DEFAULT_HF_REPO = "yourusername/chiluka-tts" # TODO: Update with your actual repo
17
+
18
+ # Cache directory for downloaded models
19
+ CACHE_DIR = Path.home() / ".cache" / "chiluka"
20
+
21
+ # Required model files
22
+ REQUIRED_FILES = {
23
+ "checkpoint": "checkpoints/epoch_2nd_00017.pth",
24
+ "config": "configs/config_ft.yml",
25
+ "asr_config": "pretrained/ASR/config.yml",
26
+ "asr_model": "pretrained/ASR/epoch_00080.pth",
27
+ "f0_model": "pretrained/JDC/bst.t7",
28
+ "plbert_config": "pretrained/PLBERT/config.yml",
29
+ "plbert_model": "pretrained/PLBERT/step_1000000.t7",
30
+ }
31
+
32
+
33
+ def get_cache_dir() -> Path:
34
+ """Get the cache directory for Chiluka models."""
35
+ cache_dir = Path(os.environ.get("CHILUKA_CACHE", CACHE_DIR))
36
+ cache_dir.mkdir(parents=True, exist_ok=True)
37
+ return cache_dir
38
+
39
+
40
+ def is_model_cached(repo_id: str = DEFAULT_HF_REPO) -> bool:
41
+ """Check if a model is already cached locally."""
42
+ cache_path = get_cache_dir() / repo_id.replace("/", "_")
43
+ if not cache_path.exists():
44
+ return False
45
+
46
+ # Check if all required files exist
47
+ for file_path in REQUIRED_FILES.values():
48
+ if not (cache_path / file_path).exists():
49
+ return False
50
+ return True
51
+
52
+
53
+ def download_from_hf(
54
+ repo_id: str = DEFAULT_HF_REPO,
55
+ revision: str = "main",
56
+ force_download: bool = False,
57
+ token: Optional[str] = None,
58
+ ) -> Path:
59
+ """
60
+ Download model files from HuggingFace Hub.
61
+
62
+ Args:
63
+ repo_id: HuggingFace Hub repository ID (e.g., 'username/model-name')
64
+ revision: Git revision to download (branch, tag, or commit hash)
65
+ force_download: If True, re-download even if cached
66
+ token: HuggingFace API token for private repos
67
+
68
+ Returns:
69
+ Path to the downloaded model directory
70
+
71
+ Example:
72
+ >>> model_path = download_from_hf("yourusername/chiluka-tts")
73
+ >>> print(model_path)
74
+ /home/user/.cache/chiluka/yourusername_chiluka-tts
75
+ """
76
+ try:
77
+ from huggingface_hub import snapshot_download, hf_hub_download
78
+ except ImportError:
79
+ raise ImportError(
80
+ "huggingface_hub is required for downloading models. "
81
+ "Install with: pip install huggingface_hub"
82
+ )
83
+
84
+ cache_path = get_cache_dir() / repo_id.replace("/", "_")
85
+
86
+ if is_model_cached(repo_id) and not force_download:
87
+ print(f"Using cached model from {cache_path}")
88
+ return cache_path
89
+
90
+ print(f"Downloading model from HuggingFace Hub: {repo_id}...")
91
+
92
+ # Download entire repository
93
+ downloaded_path = snapshot_download(
94
+ repo_id=repo_id,
95
+ revision=revision,
96
+ cache_dir=get_cache_dir() / "hf_cache",
97
+ token=token,
98
+ local_dir=cache_path,
99
+ local_dir_use_symlinks=False,
100
+ )
101
+
102
+ print(f"Model downloaded to {cache_path}")
103
+ return Path(downloaded_path)
104
+
105
+
106
+ def download_from_url(
107
+ url: str,
108
+ filename: str,
109
+ force_download: bool = False,
110
+ ) -> Path:
111
+ """
112
+ Download a single file from a URL.
113
+
114
+ Args:
115
+ url: URL to download from
116
+ filename: Local filename to save as
117
+ force_download: If True, re-download even if exists
118
+
119
+ Returns:
120
+ Path to the downloaded file
121
+ """
122
+ import urllib.request
123
+
124
+ cache_dir = get_cache_dir() / "downloads"
125
+ cache_dir.mkdir(parents=True, exist_ok=True)
126
+ local_path = cache_dir / filename
127
+
128
+ if local_path.exists() and not force_download:
129
+ print(f"Using cached file: {local_path}")
130
+ return local_path
131
+
132
+ print(f"Downloading {filename}...")
133
+
134
+ # Download with progress
135
+ def _progress_hook(count, block_size, total_size):
136
+ percent = int(count * block_size * 100 / total_size)
137
+ print(f"\rDownloading: {percent}%", end="", flush=True)
138
+
139
+ urllib.request.urlretrieve(url, local_path, reporthook=_progress_hook)
140
+ print() # New line after progress
141
+
142
+ return local_path
143
+
144
+
145
+ def get_model_paths(repo_id: str = DEFAULT_HF_REPO) -> dict:
146
+ """
147
+ Get paths to all model files after downloading.
148
+
149
+ Args:
150
+ repo_id: HuggingFace Hub repository ID
151
+
152
+ Returns:
153
+ Dictionary with paths to config, checkpoint, and pretrained directory
154
+ """
155
+ model_dir = download_from_hf(repo_id)
156
+
157
+ return {
158
+ "config_path": str(model_dir / "configs" / "config_ft.yml"),
159
+ "checkpoint_path": str(model_dir / "checkpoints" / "epoch_2nd_00017.pth"),
160
+ "pretrained_dir": str(model_dir / "pretrained"),
161
+ }
162
+
163
+
164
+ def clear_cache(repo_id: Optional[str] = None):
165
+ """
166
+ Clear cached models.
167
+
168
+ Args:
169
+ repo_id: If specified, only clear cache for this repo.
170
+ If None, clear entire cache.
171
+ """
172
+ cache_dir = get_cache_dir()
173
+
174
+ if repo_id:
175
+ cache_path = cache_dir / repo_id.replace("/", "_")
176
+ if cache_path.exists():
177
+ shutil.rmtree(cache_path)
178
+ print(f"Cleared cache for {repo_id}")
179
+ else:
180
+ if cache_dir.exists():
181
+ shutil.rmtree(cache_dir)
182
+ print("Cleared entire Chiluka cache")
183
+
184
+
185
+ def push_to_hub(
186
+ local_dir: str,
187
+ repo_id: str,
188
+ token: Optional[str] = None,
189
+ private: bool = False,
190
+ commit_message: str = "Upload Chiluka TTS model",
191
+ ):
192
+ """
193
+ Push a local model to HuggingFace Hub.
194
+
195
+ Args:
196
+ local_dir: Local directory containing model files
197
+ repo_id: Target HuggingFace Hub repository ID
198
+ token: HuggingFace API token (or set HF_TOKEN env var)
199
+ private: Whether to create a private repository
200
+ commit_message: Commit message for the upload
201
+
202
+ Example:
203
+ >>> push_to_hub(
204
+ ... local_dir="./chiluka",
205
+ ... repo_id="myusername/my-chiluka-model",
206
+ ... private=False
207
+ ... )
208
+ """
209
+ try:
210
+ from huggingface_hub import HfApi, create_repo
211
+ except ImportError:
212
+ raise ImportError(
213
+ "huggingface_hub is required for pushing models. "
214
+ "Install with: pip install huggingface_hub"
215
+ )
216
+
217
+ api = HfApi(token=token)
218
+
219
+ # Create repo if it doesn't exist
220
+ try:
221
+ create_repo(repo_id, private=private, token=token, exist_ok=True)
222
+ except Exception as e:
223
+ print(f"Note: {e}")
224
+
225
+ # Upload folder
226
+ print(f"Uploading to {repo_id}...")
227
+ api.upload_folder(
228
+ folder_path=local_dir,
229
+ repo_id=repo_id,
230
+ commit_message=commit_message,
231
+ ignore_patterns=["*.pyc", "__pycache__", "*.egg-info", ".git"],
232
+ )
233
+
234
+ print(f"Model uploaded to: https://huggingface.co/{repo_id}")
235
+
236
+
237
+ def create_model_card(repo_id: str, save_path: Optional[str] = None) -> str:
238
+ """
239
+ Generate a model card (README.md) for HuggingFace Hub.
240
+
241
+ Args:
242
+ repo_id: Repository ID for the model
243
+ save_path: If provided, save the model card to this path
244
+
245
+ Returns:
246
+ Model card content as string
247
+ """
248
+ model_card = f"""---
249
+ language:
250
+ - en
251
+ - te
252
+ - hi
253
+ license: mit
254
+ library_name: chiluka
255
+ tags:
256
+ - text-to-speech
257
+ - tts
258
+ - styletts2
259
+ - voice-cloning
260
+ ---
261
+
262
+ # Chiluka TTS
263
+
264
+ Chiluka (చిలుక - Telugu for "parrot") is a lightweight Text-to-Speech model based on StyleTTS2.
265
+
266
+ ## Installation
267
+
268
+ ```bash
269
+ pip install chiluka
270
+ ```
271
+
272
+ Or install from source:
273
+
274
+ ```bash
275
+ pip install git+https://github.com/{repo_id.split('/')[0]}/chiluka.git
276
+ ```
277
+
278
+ ## Usage
279
+
280
+ ### Quick Start (Auto-download)
281
+
282
+ ```python
283
+ from chiluka import Chiluka
284
+
285
+ # Automatically downloads model weights
286
+ tts = Chiluka.from_pretrained()
287
+
288
+ # Generate speech
289
+ wav = tts.synthesize(
290
+ text="Hello, world!",
291
+ reference_audio="path/to/reference.wav",
292
+ language="en"
293
+ )
294
+
295
+ # Save output
296
+ tts.save_wav(wav, "output.wav")
297
+ ```
298
+
299
+ ### PyTorch Hub
300
+
301
+ ```python
302
+ import torch
303
+
304
+ tts = torch.hub.load('{repo_id.split('/')[0]}/chiluka', 'chiluka')
305
+ wav = tts.synthesize("Hello!", "reference.wav", language="en")
306
+ ```
307
+
308
+ ### HuggingFace Hub
309
+
310
+ ```python
311
+ from chiluka import Chiluka
312
+
313
+ tts = Chiluka.from_pretrained("{repo_id}")
314
+ ```
315
+
316
+ ## Parameters
317
+
318
+ - `text`: Input text to synthesize
319
+ - `reference_audio`: Path to reference audio for style transfer
320
+ - `language`: Language code ('en', 'te', 'hi', etc.)
321
+ - `alpha`: Acoustic style mixing (0-1, default 0.3)
322
+ - `beta`: Prosodic style mixing (0-1, default 0.7)
323
+ - `diffusion_steps`: Quality vs speed tradeoff (default 5)
324
+
325
+ ## Supported Languages
326
+
327
+ Uses espeak-ng phonemizer. Common languages:
328
+ - English: `en-us`, `en-gb`
329
+ - Telugu: `te`
330
+ - Hindi: `hi`
331
+ - Tamil: `ta`
332
+
333
+ ## License
334
+
335
+ MIT License
336
+
337
+ ## Citation
338
+
339
+ Based on StyleTTS2 by Yinghao Aaron Li et al.
340
+ """
341
+
342
+ if save_path:
343
+ with open(save_path, "w") as f:
344
+ f.write(model_card)
345
+ print(f"Model card saved to {save_path}")
346
+
347
+ return model_card
chiluka/inference.py CHANGED
@@ -152,6 +152,67 @@ class Chiluka:
152
 
153
  print("✓ Chiluka TTS initialized successfully!")
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  def _verify_pretrained_models(self, asr_path, f0_path, plbert_dir):
156
  """Verify all pretrained models exist."""
157
  missing = []
 
152
 
153
  print("✓ Chiluka TTS initialized successfully!")
154
 
155
+ @classmethod
156
+ def from_pretrained(
157
+ cls,
158
+ repo_id: str = None,
159
+ device: Optional[str] = None,
160
+ force_download: bool = False,
161
+ token: Optional[str] = None,
162
+ **kwargs,
163
+ ) -> "Chiluka":
164
+ """
165
+ Load Chiluka TTS from HuggingFace Hub or with auto-downloaded weights.
166
+
167
+ This is the recommended way to load Chiluka when you don't have local weights.
168
+ Weights are automatically downloaded and cached on first use.
169
+
170
+ Args:
171
+ repo_id: HuggingFace Hub repository ID (e.g., 'username/chiluka-tts').
172
+ If None, uses the default repository.
173
+ device: Device to use ('cuda' or 'cpu'). Auto-detects if None.
174
+ force_download: If True, re-download even if cached.
175
+ token: HuggingFace API token for private repositories.
176
+ **kwargs: Additional arguments passed to Chiluka constructor.
177
+
178
+ Returns:
179
+ Initialized Chiluka TTS model ready for inference.
180
+
181
+ Examples:
182
+ # Default repository (auto-download)
183
+ >>> tts = Chiluka.from_pretrained()
184
+
185
+ # Specific repository
186
+ >>> tts = Chiluka.from_pretrained("myuser/my-chiluka-model")
187
+
188
+ # Force re-download
189
+ >>> tts = Chiluka.from_pretrained(force_download=True)
190
+
191
+ # Private repository
192
+ >>> tts = Chiluka.from_pretrained("myuser/private-model", token="hf_xxx")
193
+ """
194
+ from .hub import download_from_hf, get_model_paths, DEFAULT_HF_REPO
195
+
196
+ repo_id = repo_id or DEFAULT_HF_REPO
197
+
198
+ # Download model files (or use cache)
199
+ model_dir = download_from_hf(
200
+ repo_id=repo_id,
201
+ force_download=force_download,
202
+ token=token,
203
+ )
204
+
205
+ # Get paths to model files
206
+ paths = get_model_paths(repo_id)
207
+
208
+ return cls(
209
+ config_path=paths["config_path"],
210
+ checkpoint_path=paths["checkpoint_path"],
211
+ pretrained_dir=paths["pretrained_dir"],
212
+ device=device,
213
+ **kwargs,
214
+ )
215
+
216
  def _verify_pretrained_models(self, asr_path, f0_path, plbert_dir):
217
  """Verify all pretrained models exist."""
218
  missing = []
hubconf.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Hub configuration for Chiluka TTS.
3
+
4
+ Usage:
5
+ import torch
6
+
7
+ # Load the model
8
+ tts = torch.hub.load('yourusername/chiluka', 'chiluka')
9
+
10
+ # Or with force reload
11
+ tts = torch.hub.load('yourusername/chiluka', 'chiluka', force_reload=True)
12
+
13
+ # Generate speech
14
+ wav = tts.synthesize(
15
+ text="Hello, world!",
16
+ reference_audio="path/to/reference.wav",
17
+ language="en"
18
+ )
19
+ """
20
+
21
+ dependencies = [
22
+ 'torch',
23
+ 'torchaudio',
24
+ 'transformers',
25
+ 'librosa',
26
+ 'phonemizer',
27
+ 'nltk',
28
+ 'PyYAML',
29
+ 'munch',
30
+ 'einops',
31
+ 'einops-exts',
32
+ 'numpy',
33
+ 'scipy',
34
+ 'huggingface_hub',
35
+ ]
36
+
37
+
38
+ def chiluka(pretrained: bool = True, device: str = None, **kwargs):
39
+ """
40
+ Load Chiluka TTS model.
41
+
42
+ Args:
43
+ pretrained: If True, downloads pretrained weights from HuggingFace Hub.
44
+ If False, returns uninitialized model (requires manual weight loading).
45
+ device: Device to use ('cuda' or 'cpu'). Auto-detects if None.
46
+ **kwargs: Additional arguments passed to Chiluka constructor.
47
+
48
+ Returns:
49
+ Chiluka: Initialized TTS model ready for inference.
50
+
51
+ Example:
52
+ >>> import torch
53
+ >>> tts = torch.hub.load('yourusername/chiluka', 'chiluka')
54
+ >>> wav = tts.synthesize("Hello!", "reference.wav", language="en")
55
+ """
56
+ from chiluka import Chiluka
57
+
58
+ if pretrained:
59
+ # Use from_pretrained to auto-download weights
60
+ return Chiluka.from_pretrained(device=device, **kwargs)
61
+ else:
62
+ # Return model expecting local weights
63
+ return Chiluka(device=device, **kwargs)
64
+
65
+
66
+ def chiluka_from_hf(repo_id: str = "yourusername/chiluka-tts", device: str = None, **kwargs):
67
+ """
68
+ Load Chiluka TTS from a specific HuggingFace Hub repository.
69
+
70
+ Args:
71
+ repo_id: HuggingFace Hub repository ID (e.g., 'username/model-name')
72
+ device: Device to use ('cuda' or 'cpu'). Auto-detects if None.
73
+ **kwargs: Additional arguments passed to Chiluka constructor.
74
+
75
+ Returns:
76
+ Chiluka: Initialized TTS model ready for inference.
77
+
78
+ Example:
79
+ >>> import torch
80
+ >>> tts = torch.hub.load('yourusername/chiluka', 'chiluka_from_hf',
81
+ ... repo_id='myuser/my-custom-chiluka')
82
+ """
83
+ from chiluka import Chiluka
84
+ return Chiluka.from_pretrained(repo_id=repo_id, device=device, **kwargs)
setup.py CHANGED
@@ -43,6 +43,7 @@ setup(
43
  "einops-exts>=0.0.4",
44
  "numpy>=1.21.0",
45
  "scipy>=1.7.0",
 
46
  ],
47
  extras_require={
48
  "playback": ["pyaudio>=0.2.11"],
 
43
  "einops-exts>=0.0.4",
44
  "numpy>=1.21.0",
45
  "scipy>=1.7.0",
46
+ "huggingface_hub>=0.16.0",
47
  ],
48
  extras_require={
49
  "playback": ["pyaudio>=0.2.11"],