Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test script to observe exactly how HuggingFace reports download progress | |
| for each TTS model. Doesn't load models — just downloads and tracks tqdm. | |
| Usage: | |
| backend/venv/bin/python scripts/test_download_progress.py qwen | |
| backend/venv/bin/python scripts/test_download_progress.py luxtts | |
| backend/venv/bin/python scripts/test_download_progress.py chatterbox | |
| Add --delete to clear cache first and force a real download: | |
| backend/venv/bin/python scripts/test_download_progress.py chatterbox --delete | |
| """ | |
| import os | |
| import shutil | |
| import sys | |
| import time | |
| import threading | |
| from pathlib import Path | |
| from contextlib import contextmanager | |
| # ─── Configuration ──────────────────────────────────────────────────────────── | |
# Registry of the models this script can exercise, keyed by the name given on
# the command line. Each entry carries the Hub repo id, a label for the
# download method, a human-readable description, and (for chatterbox only) the
# file patterns passed to snapshot_download as ``allow_patterns``.
MODELS = {
    "qwen": dict(
        repo_id="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
        method="from_pretrained",
        description="Qwen TTS 1.7B (uses transformers from_pretrained)",
    ),
    "luxtts": dict(
        repo_id="YatharthS/LuxTTS",
        method="snapshot_download",
        description="LuxTTS (uses snapshot_download)",
    ),
    "chatterbox": dict(
        repo_id="ResembleAI/chatterbox",
        method="snapshot_download",
        allow_patterns=[
            "ve.pt",
            "t3_mtl23ls_v2.safetensors",
            "s3gen.pt",
            "grapheme_mtl_merged_expanded_v1.json",
            "conds.pt",
            "Cangjie5_TC.json",
        ],
        description="Chatterbox Multilingual (uses snapshot_download with allow_patterns)",
    ),
}
| # ─── Progress tracking (mirrors our HFProgressTracker) ──────────────────────── | |
class ProgressSpy:
    """Intercepts tqdm to see exactly what HF reports.

    While :meth:`patch` is active, the tqdm classes reachable through the
    ``tqdm`` package and through already-imported ``huggingface*`` / ``tqdm*``
    modules are replaced by a logging subclass. Each bar creation, update and
    close is appended to ``events`` and printed live; :meth:`summary` prints
    an aggregate report afterwards.
    """

    # Keyword arguments passed through to tqdm's constructor on the first
    # attempt; anything else is filtered out.
    _TQDM_KWARGS = frozenset({
        'iterable', 'desc', 'total', 'leave', 'file', 'ncols',
        'mininterval', 'maxinterval', 'miniters', 'ascii', 'disable',
        'unit', 'unit_scale', 'dynamic_ncols', 'smoothing',
        'bar_format', 'initial', 'position', 'postfix',
        'unit_divisor', 'write_bytes', 'lock_args', 'nrows',
        'colour', 'color', 'delay', 'gui', 'disable_default', 'pos',
    })

    # Attribute names under which a module may hold a reference to a tqdm class.
    _TQDM_ATTR_NAMES = ('tqdm', 'base_tqdm', 'old_tqdm')

    def __init__(self):
        # Serialises event recording and live printing; progress callbacks
        # may arrive from several threads at once.
        self._lock = threading.Lock()
        self.events = []  # List of dicts: {time, type, ...}
        self._original_tqdm_class = None
        self._original_tqdm_auto = None
        # "module.attr" -> (module, attr_name, original attribute)
        self._patched_modules = {}
        self._hf_tqdm_original_update = None
        # The exact class whose ``update`` was replaced, so the restore step
        # targets the same object even after module attributes are swapped back.
        self._hf_tqdm_patched_class = None
        self._start_time = None

    def _elapsed(self):
        """Return seconds since :meth:`patch` started, or 0 if it never did."""
        if self._start_time is None:
            return 0
        return time.time() - self._start_time

    def _log(self, event_type, **kwargs):
        """Record one event and print it immediately.

        Args:
            event_type: Short label such as ``"INIT"`` or ``"UPDATE"``.
            **kwargs: Extra fields stored on the event and echoed as
                ``key=value``. ``current`` / ``total`` values above one
                million are printed in MB.
        """
        entry = {"time": f"{self._elapsed():.1f}s", "type": event_type, **kwargs}
        parts = [f"[{entry['time']:>7s}] {event_type:>10s}"]
        for k, v in kwargs.items():
            if k in ("current", "total") and isinstance(v, (int, float)) and v > 1_000_000:
                parts.append(f"{k}={v / 1_000_000:.1f}MB")
            else:
                parts.append(f"{k}={v}")
        # Hold the lock so concurrent callers cannot interleave their output
        # or reorder the event list relative to what was printed.
        with self._lock:
            self.events.append(entry)
            print(" ".join(parts), flush=True)

    def _create_tracked_tqdm_class(self):
        """Build a subclass of the original tqdm class that logs to this spy."""
        spy = self
        original_tqdm = self._original_tqdm_class
        allowed_kwargs = self._TQDM_KWARGS

        class SpyTqdm(original_tqdm):
            def __init__(self, *args, **kwargs):
                desc = kwargs.get("desc", "")
                if not desc and args and isinstance(args[0], str):
                    desc = args[0]
                # The label is the part of the description before the first colon.
                filename = desc.split(":")[0].strip() if desc else ""
                # Set before the base constructor runs so close()/update()
                # can always find it, even if construction fails part-way.
                self._spy_filename = filename or "unknown"

                # Filter out non-standard kwargs
                filtered_kwargs = {k: v for k, v in kwargs.items() if k in allowed_kwargs}
                try:
                    super().__init__(*args, **filtered_kwargs)
                except TypeError:
                    # Fall back to the unfiltered call if the filtered one is rejected.
                    super().__init__(*args, **kwargs)

                total = getattr(self, "total", None)
                spy._log(
                    "INIT",
                    filename=self._spy_filename,
                    total=total or 0,
                    unit=kwargs.get("unit", "?"),
                    unit_scale=kwargs.get("unit_scale", False),
                    disable=kwargs.get("disable", False),
                )

            def update(self, n=1):
                result = super().update(n)
                current = getattr(self, "n", 0)
                total = getattr(self, "total", 0)
                spy._log(
                    "UPDATE",
                    filename=getattr(self, "_spy_filename", "unknown"),
                    n=n,
                    current=current,
                    total=total or 0,
                    pct=f"{100 * current / total:.1f}%" if total else "?",
                )
                return result

            def close(self):
                spy._log("CLOSE", filename=getattr(self, "_spy_filename", "unknown"))
                return super().close()

        return SpyTqdm

    def _is_tqdm_class(self, attr):
        """Return True if *attr* looks like a tqdm class that should be replaced."""
        if attr is self._original_tqdm_class:
            return True
        if self._original_tqdm_auto is not None and attr is self._original_tqdm_auto:
            return True
        # Catches subclasses that are also named "tqdm" and expose update().
        return (
            hasattr(attr, "__name__")
            and attr.__name__ == "tqdm"
            and hasattr(attr, "update")
        )

    def _patch_loaded_modules(self, tracked_tqdm):
        """Swap tqdm references in loaded huggingface*/tqdm* modules; return the count."""
        patched = 0
        for module_name in list(sys.modules):
            if "huggingface" not in module_name and not module_name.startswith("tqdm"):
                continue
            try:
                module = sys.modules.get(module_name)
                for attr_name in self._TQDM_ATTR_NAMES:
                    if not hasattr(module, attr_name):
                        continue
                    attr = getattr(module, attr_name)
                    if self._is_tqdm_class(attr):
                        key = f"{module_name}.{attr_name}"
                        self._patched_modules[key] = (module, attr_name, attr)
                        setattr(module, attr_name, tracked_tqdm)
                        patched += 1
            except (AttributeError, TypeError):
                # Best effort: some modules reject attribute access/assignment.
                pass
        return patched

    def _patch_hf_update(self):
        """Wrap ``update`` on HF's tqdm class; return 1 if patched, else 0."""
        try:
            from huggingface_hub.utils import tqdm as hf_tqdm_module
            if not hasattr(hf_tqdm_module, 'tqdm'):
                return 0
            hf_tqdm_class = hf_tqdm_module.tqdm
            original_update = hf_tqdm_class.update
            spy = self

            def patched_update(tqdm_self, n=1):
                # Call through a captured local so the wrapper keeps working
                # even after the spy's bookkeeping has been reset.
                result = original_update(tqdm_self, n)
                desc = getattr(tqdm_self, 'desc', '') or ''
                current = getattr(tqdm_self, 'n', 0)
                total = getattr(tqdm_self, 'total', 0) or 0
                spy._log(
                    "HF_UPDATE",
                    desc=desc,
                    current=current,
                    total=total,
                    pct=f"{100 * current / total:.1f}%" if total else "?",
                )
                return result

            hf_tqdm_class.update = patched_update
        except (ImportError, AttributeError):
            return 0
        self._hf_tqdm_original_update = original_update
        self._hf_tqdm_patched_class = hf_tqdm_class
        return 1

    def _restore(self, tqdm_module):
        """Undo everything :meth:`patch` changed."""
        tqdm_module.tqdm = self._original_tqdm_class
        if self._original_tqdm_auto:
            tqdm_module.auto.tqdm = self._original_tqdm_auto
        for module, attr_name, original in self._patched_modules.values():
            try:
                setattr(module, attr_name, original)
            except (AttributeError, TypeError):
                pass
        # Drop stale entries so a later patch() starts from a clean slate.
        self._patched_modules.clear()
        if self._hf_tqdm_patched_class is not None:
            # Restore onto the very class that was patched; looking it up
            # again by name could resolve to a different object by now.
            self._hf_tqdm_patched_class.update = self._hf_tqdm_original_update
            self._hf_tqdm_patched_class = None
            self._hf_tqdm_original_update = None

    @contextmanager
    def patch(self):
        """Context manager that patches tqdm globally — same as HFProgressTracker.

        Starts the elapsed-time clock, installs the logging tqdm subclass,
        yields to the caller, and restores the originals on exit, including
        when the body raises. If ``tqdm`` cannot be imported, nothing is
        patched and the body simply runs.
        """
        self._start_time = time.time()
        try:
            import tqdm as tqdm_module
        except ImportError:
            yield
            return
        self._original_tqdm_class = tqdm_module.tqdm
        tracked_tqdm = self._create_tracked_tqdm_class()

        # Patch tqdm.tqdm
        tqdm_module.tqdm = tracked_tqdm

        # Patch tqdm.auto.tqdm
        self._original_tqdm_auto = None
        if hasattr(tqdm_module, "auto") and hasattr(tqdm_module.auto, "tqdm"):
            self._original_tqdm_auto = tqdm_module.auto.tqdm
            tqdm_module.auto.tqdm = tracked_tqdm

        # Patch in sys.modules, then monkey-patch HF's tqdm.update
        # (same as HFProgressTracker)
        patched_count = self._patch_loaded_modules(tracked_tqdm)
        patched_count += self._patch_hf_update()

        print(f"\n=== Patched {patched_count} tqdm references ===\n", flush=True)
        try:
            yield
        finally:
            self._restore(tqdm_module)

    def summary(self):
        """Print how many bars were created and per-file update statistics."""
        print("\n" + "=" * 70)
        print("SUMMARY")
        print("=" * 70)
        inits = [e for e in self.events if e["type"] == "INIT"]
        updates = [e for e in self.events if e["type"] in ("UPDATE", "HF_UPDATE")]
        print(f"\ntqdm bars created: {len(inits)}")
        for e in inits:
            print(f" - {e.get('filename', '?'):40s} total={e.get('total', '?')}")
        print(f"\nTotal update calls: {len(updates)}")

        # Group updates by filename (HF_UPDATE events carry "desc" instead)
        by_file = {}
        for e in updates:
            fn = e.get("filename") or e.get("desc", "unknown")
            by_file.setdefault(fn, []).append(e)

        for fn, evts in by_file.items():
            max_current = max(e.get("current", 0) for e in evts)
            max_total = max(e.get("total", 0) for e in evts)
            print(f"\n {fn}:")
            print(f" updates: {len(evts)}")
            print(f" max current: {max_current:,}")
            print(f" max total: {max_total:,}")
            if max_total > 0 and max_current > 0:
                print(f" final pct: {100 * max_current / max_total:.1f}%")
            else:
                print(" final pct: NO PROGRESS REPORTED")
| # ─── Delete cache ───────────────────────────────────────────────────────────── | |
def delete_cache(repo_id: str):
    """Remove the local Hub cache folder for *repo_id*, if there is one.

    The folder is looked up under ``huggingface_hub.constants.HF_HUB_CACHE``
    as ``models--<repo_id with "/" replaced by "--">``. Prints what was done.
    """
    from huggingface_hub import constants as hf_constants

    folder_name = "models--" + repo_id.replace("/", "--")
    repo_cache = Path(hf_constants.HF_HUB_CACHE) / folder_name

    if not repo_cache.exists():
        print(f"No cache found at {repo_cache}")
        return

    print(f"Deleting cache: {repo_cache}")
    shutil.rmtree(repo_cache)
    print("Deleted.")
| # ─── Download functions ─────────────────────────────────────────────────────── | |
def download_qwen(spy: ProgressSpy):
    """Download the Qwen repo with tqdm patched by *spy*.

    Described in the original as mirroring how pytorch_backend.py downloads
    Qwen. The files are fetched with ``snapshot_download``; no model is loaded.
    """
    # NOTE(review): AutoModel is imported but never used. The import is kept
    # because it loads transformers before patching starts — confirm whether
    # that ordering is intentional before removing it.
    from transformers import AutoModel

    target_repo = MODELS["qwen"]["repo_id"]
    print(f"Downloading {target_repo} via AutoModel.from_pretrained...")

    with spy.patch():
        # This is what Qwen3TTSModel.from_pretrained does under the hood
        from huggingface_hub import snapshot_download

        snapshot_download(target_repo)
def download_luxtts(spy: ProgressSpy):
    """Download the LuxTTS repo via ``snapshot_download`` with tqdm patched by *spy*.

    Described in the original as mirroring how luxtts_backend.py downloads LuxTTS.
    """
    from huggingface_hub import snapshot_download

    target_repo = MODELS["luxtts"]["repo_id"]
    print(f"Downloading {target_repo} via snapshot_download...")

    with spy.patch():
        snapshot_download(target_repo)
def download_chatterbox(spy: ProgressSpy):
    """Download the selected Chatterbox files with tqdm patched by *spy*.

    Described in the original as mirroring how chatterbox_backend.py downloads
    Chatterbox. Only files matching the configured ``allow_patterns`` are
    requested, from the ``main`` revision, using the ``HF_TOKEN`` environment
    variable as the token when it is set.
    """
    from huggingface_hub import snapshot_download

    model_cfg = MODELS["chatterbox"]
    print(f"Downloading {model_cfg['repo_id']} via snapshot_download with allow_patterns...")

    request = {
        "repo_id": model_cfg["repo_id"],
        "repo_type": "model",
        "revision": "main",
        "allow_patterns": model_cfg["allow_patterns"],
        "token": os.getenv("HF_TOKEN"),
    }
    with spy.patch():
        snapshot_download(**request)
| # ─── Main ───────────────────────────────────────────────────────────────────── | |
def main():
    """Command-line entry point.

    Expects a model key from ``MODELS`` as the first argument and an optional
    ``--delete`` flag anywhere on the command line. Prints a banner, optionally
    clears the cached repo, runs the matching download under a fresh
    ``ProgressSpy``, and always prints the summary — even if the download
    raised. Exits with status 1 on a missing or unknown model key.
    """
    argv = sys.argv
    if len(argv) < 2 or argv[1] not in MODELS:
        print(f"Usage: {argv[0]} <{'|'.join(MODELS.keys())}> [--delete]")
        sys.exit(1)

    model_key = argv[1]
    cfg = MODELS[model_key]
    rule = "=" * 70

    print(f"\n{rule}")
    print(f"Testing download progress for: {cfg['description']}")
    print(f"Repo: {cfg['repo_id']}")
    print(f"Method: {cfg['method']}")
    print(f"{rule}\n")

    if "--delete" in argv:
        delete_cache(cfg["repo_id"])
        print()

    spy = ProgressSpy()
    downloaders = {
        "qwen": download_qwen,
        "luxtts": download_luxtts,
        "chatterbox": download_chatterbox,
    }
    run_download = downloaders[model_key]
    try:
        run_download(spy)
    except Exception as e:
        # Report the failure but still show whatever progress was captured.
        print(f"\n!!! Download failed: {e}")
    spy.summary()


if __name__ == "__main__":
    main()