#!/usr/bin/env python3
"""
Test script to observe exactly how HuggingFace reports download progress
for each TTS model. Doesn't load models — just downloads and tracks tqdm.
Usage:
backend/venv/bin/python scripts/test_download_progress.py qwen
backend/venv/bin/python scripts/test_download_progress.py luxtts
backend/venv/bin/python scripts/test_download_progress.py chatterbox
Add --delete to clear cache first and force a real download:
backend/venv/bin/python scripts/test_download_progress.py chatterbox --delete
"""
import os
import shutil
import sys
import time
import threading
from pathlib import Path
from contextlib import contextmanager
# ─── Configuration ────────────────────────────────────────────────────────────
# Registry of the models this script can exercise, keyed by the CLI argument.
# Each entry carries:
#   repo_id        - Hugging Face Hub repository to download
#   method         - label printed by main() naming the download API
#   description    - human-readable label printed by main()
#   allow_patterns - (optional) file names passed to snapshot_download
# Insertion order matters: main() joins the keys to build its usage string.
MODELS = {
    "qwen": {
        "repo_id": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
        "method": "from_pretrained",
        "description": "Qwen TTS 1.7B (uses transformers from_pretrained)",
    },
    "luxtts": {
        "repo_id": "YatharthS/LuxTTS",
        "method": "snapshot_download",
        "description": "LuxTTS (uses snapshot_download)",
    },
    "chatterbox": {
        "repo_id": "ResembleAI/chatterbox",
        "method": "snapshot_download",
        # Only these files are requested from the repository.
        "allow_patterns": [
            "ve.pt",
            "t3_mtl23ls_v2.safetensors",
            "s3gen.pt",
            "grapheme_mtl_merged_expanded_v1.json",
            "conds.pt",
            "Cangjie5_TC.json",
        ],
        "description": "Chatterbox Multilingual (uses snapshot_download with allow_patterns)",
    },
}
# ─── Progress tracking (mirrors our HFProgressTracker) ────────────────────────
class ProgressSpy:
    """Intercept tqdm progress bars and record every event they report.

    While :meth:`patch` is active, ``tqdm.tqdm`` (and the references to it held
    by already-imported ``tqdm*`` / ``*huggingface*`` modules) are swapped for a
    logging subclass. Each bar creation, update and close is appended to
    ``self.events`` and echoed to stdout; :meth:`summary` prints an aggregate.

    Attributes:
        events: list of dicts ``{"time": "<sec>s", "type": <event>, ...}`` in
            the order they were logged.
    """

    # Attribute names under which a module may hold a reference to tqdm.
    _TQDM_ATTR_NAMES = ("tqdm", "base_tqdm", "old_tqdm")

    # Keyword arguments forwarded to the tqdm constructor; any other keyword
    # passed by a caller is dropped before the first construction attempt.
    _TQDM_KWARGS = frozenset({
        'iterable', 'desc', 'total', 'leave', 'file', 'ncols',
        'mininterval', 'maxinterval', 'miniters', 'ascii', 'disable',
        'unit', 'unit_scale', 'dynamic_ncols', 'smoothing',
        'bar_format', 'initial', 'position', 'postfix',
        'unit_divisor', 'write_bytes', 'lock_args', 'nrows',
        'colour', 'color', 'delay', 'gui', 'disable_default', 'pos',
    })

    def __init__(self):
        # Re-entrant: a bar finalised by the garbage collector can call
        # close() -> _log() on a thread that is already inside _log().
        self._lock = threading.RLock()
        self.events = []  # List of dicts: {time, type, ...}
        self._original_tqdm_class = None
        self._original_tqdm_auto = None
        self._patched_modules = {}  # "module.attr" -> (module, attr, original)
        self._hf_tqdm_original_update = None
        self._start_time = None

    def _elapsed(self):
        """Return seconds since :meth:`patch` was entered (0 if never)."""
        if self._start_time is None:
            return 0
        return time.monotonic() - self._start_time

    def _log(self, event_type, **kwargs):
        """Record one event in ``self.events`` and print it immediately.

        ``current`` / ``total`` values above one million are printed in MB;
        the stored event always keeps the raw values.
        """
        entry = {"time": f"{self._elapsed():.1f}s", "type": event_type, **kwargs}
        parts = [f"[{entry['time']:>7s}] {event_type:>10s}"]
        for k, v in kwargs.items():
            if k in ("current", "total") and isinstance(v, (int, float)) and v > 1_000_000:
                parts.append(f"{k}={v / 1_000_000:.1f}MB")
            else:
                parts.append(f"{k}={v}")
        # Progress callbacks may arrive from several download threads; keep the
        # event list and the live output in the same order.
        with self._lock:
            self.events.append(entry)
            print(" ".join(parts), flush=True)

    def _create_tracked_tqdm_class(self):
        """Build a subclass of the original tqdm class that logs to this spy."""
        spy = self
        original_tqdm = self._original_tqdm_class
        allowed_kwargs = self._TQDM_KWARGS

        class SpyTqdm(original_tqdm):
            def __init__(self, *args, **kwargs):
                # The description comes from desc= or a leading string argument.
                desc = kwargs.get("desc", "")
                if not desc and args and isinstance(args[0], str):
                    desc = args[0]
                # Use the part before the first ":" as the file name.
                filename = desc.partition(":")[0].strip() if desc else ""

                # Filter out non-standard kwargs
                filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_kwargs}
                try:
                    super().__init__(*args, **filtered_kwargs)
                except TypeError:
                    # Fall back to the caller's exact keywords.
                    super().__init__(*args, **kwargs)

                self._spy_filename = filename or "unknown"
                total = getattr(self, "total", None)
                spy._log(
                    "INIT",
                    filename=self._spy_filename,
                    total=total or 0,
                    unit=kwargs.get("unit", "?"),
                    unit_scale=kwargs.get("unit_scale", False),
                    disable=kwargs.get("disable", False),
                )

            def update(self, n=1):
                result = super().update(n)
                current = getattr(self, "n", 0)
                total = getattr(self, "total", 0)
                spy._log(
                    "UPDATE",
                    filename=getattr(self, "_spy_filename", "unknown"),
                    n=n,
                    current=current,
                    total=total or 0,
                    pct=f"{100 * current / total:.1f}%" if total else "?",
                )
                return result

            def close(self):
                # close() can run on an instance whose __init__ never finished,
                # so the file name may not have been set yet.
                spy._log("CLOSE", filename=getattr(self, "_spy_filename", "unknown"))
                return super().close()

        return SpyTqdm

    def _patch_loaded_modules(self, tracked_tqdm):
        """Swap tqdm references held by already-imported modules.

        Only modules whose name contains ``huggingface`` or starts with
        ``tqdm`` are inspected. Every replaced reference is remembered in
        ``self._patched_modules`` so it can be restored. Returns the number of
        references replaced.
        """
        patched = 0
        for module_name in list(sys.modules):
            if "huggingface" not in module_name and not module_name.startswith("tqdm"):
                continue
            try:
                module = sys.modules.get(module_name)
                for attr_name in self._TQDM_ATTR_NAMES:
                    if not hasattr(module, attr_name):
                        continue
                    attr = getattr(module, attr_name)
                    is_tqdm_class = (
                        attr is self._original_tqdm_class
                        or (
                            self._original_tqdm_auto is not None
                            and attr is self._original_tqdm_auto
                        )
                        or (
                            hasattr(attr, "__name__")
                            and attr.__name__ == "tqdm"
                            and hasattr(attr, "update")
                        )
                    )
                    if is_tqdm_class:
                        key = f"{module_name}.{attr_name}"
                        self._patched_modules[key] = (module, attr_name, attr)
                        setattr(module, attr_name, tracked_tqdm)
                        patched += 1
            except (AttributeError, TypeError):
                # Best effort: some modules reject attribute access/assignment.
                pass
        return patched

    def _patch_hf_update(self):
        """Wrap ``update`` on huggingface_hub's tqdm class to log HF_UPDATE.

        Returns the class whose ``update`` was replaced, or ``None`` when
        huggingface_hub (or the expected attribute) is unavailable.
        """
        try:
            from huggingface_hub.utils import tqdm as hf_tqdm_module
            hf_tqdm_class = hf_tqdm_module.tqdm
            original_update = hf_tqdm_class.update
        except (ImportError, AttributeError):
            return None

        spy = self

        def patched_update(tqdm_self, n=1):
            # Captured in the closure so a late call still works after restore.
            result = original_update(tqdm_self, n)
            desc = getattr(tqdm_self, 'desc', '') or ''
            current = getattr(tqdm_self, 'n', 0)
            total = getattr(tqdm_self, 'total', 0) or 0
            spy._log(
                "HF_UPDATE",
                desc=desc,
                current=current,
                total=total,
                pct=f"{100 * current / total:.1f}%" if total else "?",
            )
            return result

        self._hf_tqdm_original_update = original_update
        hf_tqdm_class.update = patched_update
        return hf_tqdm_class

    @contextmanager
    def patch(self):
        """Context manager that patches tqdm globally for the ``with`` body.

        Intended to mirror the project's HFProgressTracker (not visible in this
        file). If tqdm is not importable the body runs unpatched. Everything
        that was replaced is restored on exit, including when patching itself
        or the body raises.
        """
        self._start_time = time.monotonic()
        try:
            import tqdm as tqdm_module
        except ImportError:
            yield
            return

        self._original_tqdm_class = tqdm_module.tqdm
        self._original_tqdm_auto = None
        tracked_tqdm = self._create_tracked_tqdm_class()
        hf_patched_class = None

        try:
            # Patch tqdm.tqdm
            tqdm_module.tqdm = tracked_tqdm

            # Patch tqdm.auto.tqdm
            if hasattr(tqdm_module, "auto") and hasattr(tqdm_module.auto, "tqdm"):
                self._original_tqdm_auto = tqdm_module.auto.tqdm
                tqdm_module.auto.tqdm = tracked_tqdm

            # Patch references already bound inside loaded modules
            patched_count = self._patch_loaded_modules(tracked_tqdm)

            # Monkey-patch HF's tqdm.update
            hf_patched_class = self._patch_hf_update()
            if hf_patched_class is not None:
                patched_count += 1

            print(f"\n=== Patched {patched_count} tqdm references ===\n", flush=True)
            yield
        finally:
            # Restore everything
            tqdm_module.tqdm = self._original_tqdm_class
            if self._original_tqdm_auto is not None:
                tqdm_module.auto.tqdm = self._original_tqdm_auto
            for module, attr_name, original in self._patched_modules.values():
                try:
                    setattr(module, attr_name, original)
                except (AttributeError, TypeError):
                    pass
            # Forget restored references so a later patch() starts clean.
            self._patched_modules.clear()
            if hf_patched_class is not None:
                # Restore on the exact class that was patched; re-importing the
                # name here could resolve to a different object.
                hf_patched_class.update = self._hf_tqdm_original_update
                self._hf_tqdm_original_update = None

    def summary(self):
        """Print bar count, update count and per-file maxima for all events."""
        print("\n" + "=" * 70)
        print("SUMMARY")
        print("=" * 70)

        with self._lock:
            events = list(self.events)

        inits = [e for e in events if e["type"] == "INIT"]
        updates = [e for e in events if e["type"] in ("UPDATE", "HF_UPDATE")]

        print(f"\ntqdm bars created: {len(inits)}")
        for e in inits:
            print(f" - {e.get('filename', '?'):40s} total={e.get('total', '?')}")

        print(f"\nTotal update calls: {len(updates)}")

        # Group updates by filename; HF_UPDATE events carry "desc" instead, and
        # an empty description is reported as "unknown" rather than "".
        by_file = {}
        for e in updates:
            fn = e.get("filename") or e.get("desc") or "unknown"
            by_file.setdefault(fn, []).append(e)

        for fn, evts in by_file.items():
            max_current = max(e.get("current", 0) for e in evts)
            max_total = max(e.get("total", 0) for e in evts)
            print(f"\n {fn}:")
            print(f" updates: {len(evts)}")
            print(f" max current: {max_current:,}")
            print(f" max total: {max_total:,}")
            if max_total > 0 and max_current > 0:
                print(f" final pct: {100 * max_current / max_total:.1f}%")
            else:
                print(" final pct: NO PROGRESS REPORTED")
# ─── Delete cache ─────────────────────────────────────────────────────────────
def delete_cache(repo_id: str):
    """Delete the local hub cache folder for ``repo_id`` if it exists.

    The folder is ``<HF_HUB_CACHE>/models--<repo_id with "/" replaced by "--">``.
    Prints what was deleted, or that nothing was found.
    """
    # Imported lazily so the module loads without huggingface_hub installed.
    from huggingface_hub import constants as hf_constants

    folder_name = "models--" + repo_id.replace("/", "--")
    target = Path(hf_constants.HF_HUB_CACHE) / folder_name

    if not target.exists():
        print(f"No cache found at {target}")
        return

    print(f"Deleting cache: {target}")
    shutil.rmtree(target)
    print("Deleted.")
# ─── Download functions ───────────────────────────────────────────────────────
def download_qwen(spy: ProgressSpy):
    """Download the Qwen repository snapshot while ``spy`` is patching tqdm.

    Meant to mirror how pytorch_backend.py downloads Qwen (per the original
    note; that file is not visible here).
    """
    # NOTE(review): AutoModel is never referenced below. The import is kept
    # because dropping it would change which modules are already loaded when
    # spy.patch() scans sys.modules — confirm whether that is intentional.
    from transformers import AutoModel

    repo = MODELS["qwen"]["repo_id"]
    print(f"Downloading {repo} via AutoModel.from_pretrained...")

    with spy.patch():
        # This is what Qwen3TTSModel.from_pretrained does under the hood
        from huggingface_hub import snapshot_download

        snapshot_download(repo)
def download_luxtts(spy: ProgressSpy):
    """Download the LuxTTS repository snapshot while ``spy`` is patching tqdm.

    Meant to mirror how luxtts_backend.py downloads LuxTTS (per the original
    note; that file is not visible here).
    """
    from huggingface_hub import snapshot_download

    repo = MODELS["luxtts"]["repo_id"]
    print(f"Downloading {repo} via snapshot_download...")

    with spy.patch():
        snapshot_download(repo)
def download_chatterbox(spy: ProgressSpy):
    """Download the selected Chatterbox files while ``spy`` is patching tqdm.

    Only the files listed under ``allow_patterns`` in the ``chatterbox`` entry
    of ``MODELS`` are requested, from the ``main`` revision. The ``HF_TOKEN``
    environment variable is passed as the token (``None`` when unset).
    Meant to mirror how chatterbox_backend.py downloads Chatterbox (per the
    original note; that file is not visible here).
    """
    from huggingface_hub import snapshot_download

    cfg = MODELS["chatterbox"]
    repo = cfg["repo_id"]
    print(f"Downloading {repo} via snapshot_download with allow_patterns...")

    request = {
        "repo_id": repo,
        "repo_type": "model",
        "revision": "main",
        "allow_patterns": cfg["allow_patterns"],
        "token": os.getenv("HF_TOKEN"),
    }
    with spy.patch():
        snapshot_download(**request)
# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: ``<script> <model> [--delete]``.

    Prints a banner for the chosen model, optionally deletes its hub cache,
    runs the matching download function under a ``ProgressSpy`` and prints the
    spy's summary. Exits with status 1 when the model argument is missing or
    unknown. A failed download is reported and the summary is still printed.
    """
    argv = sys.argv
    if len(argv) < 2 or argv[1] not in MODELS:
        print(f"Usage: {argv[0]} <{'|'.join(MODELS.keys())}> [--delete]")
        sys.exit(1)

    model_key = argv[1]
    cfg = MODELS[model_key]
    rule = "=" * 70

    print(f"\n{rule}")
    print(f"Testing download progress for: {cfg['description']}")
    print(f"Repo: {cfg['repo_id']}")
    print(f"Method: {cfg['method']}")
    print(f"{rule}\n")

    # --delete may appear anywhere on the command line.
    if "--delete" in argv:
        delete_cache(cfg["repo_id"])
        print()

    spy = ProgressSpy()
    downloaders = {
        "qwen": download_qwen,
        "luxtts": download_luxtts,
        "chatterbox": download_chatterbox,
    }

    try:
        downloaders[model_key](spy)
    except Exception as e:
        # Top-level boundary: report the failure but still show what was seen.
        print(f"\n!!! Download failed: {e}")

    spy.summary()


if __name__ == "__main__":
    main()
|