Spaces:
Sleeping
Sleeping
| """ | |
| Download progress tracker for Hugging Face models. | |
| Tracks real-time download progress in bytes. | |
| """ | |
| import threading | |
| import time | |
| from typing import Dict, Optional | |
| from dataclasses import dataclass, field | |
@dataclass
class DownloadProgress:
    """Track download progress for a single file.

    Fields are raw byte counters and ``time.time()`` timestamps; derived
    values (percentage, completion, elapsed time) are read-only properties.
    """

    filename: str
    total_bytes: int = 0  # 0 means the file size is not known yet
    downloaded_bytes: int = 0
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    speed_bytes_per_sec: float = 0.0

    @property
    def percentage(self) -> float:
        """Download percentage clamped to 100; 0.0 while the size is unknown."""
        if self.total_bytes <= 0:
            return 0.0
        return min(100.0, (self.downloaded_bytes / self.total_bytes) * 100.0)

    @property
    def is_complete(self) -> bool:
        """True once a known, non-zero size has been fully downloaded."""
        return self.total_bytes > 0 and self.downloaded_bytes >= self.total_bytes

    @property
    def elapsed_time(self) -> float:
        """Seconds since ``started_at``; frozen at ``completed_at`` once that is set.

        Returns 0.0 if the download has not started.
        """
        if self.started_at is None:
            return 0.0
        # Compare against None explicitly so a 0.0 timestamp is not treated as unset.
        end_time = self.completed_at if self.completed_at is not None else time.time()
        return end_time - self.started_at
@dataclass
class ModelDownloadProgress:
    """Track overall download progress for a model.

    Holds one ``DownloadProgress`` per file, keyed by filename, and exposes
    aggregate read-only properties over them. Timestamps are ``time.time()``
    values.
    """

    model_path: str
    files: Dict[str, "DownloadProgress"] = field(default_factory=dict)
    started_at: Optional[float] = None
    completed_at: Optional[float] = None

    def update_file(self, filename: str, downloaded: int, total: int):
        """Record the latest byte counts for ``filename``.

        Args:
            filename: File key; an entry is created the first time it is seen.
            downloaded: Cumulative bytes downloaded so far for this file.
            total: Expected file size in bytes; 0 (or None) means unknown.
        """
        now = time.time()
        file_progress = self.files.get(filename)
        if file_progress is None:
            file_progress = DownloadProgress(filename=filename, started_at=now)
            self.files[filename] = file_progress
        # Start the model clock on the first update, even if the file entry
        # was registered by someone else beforehand.
        if self.started_at is None:
            self.started_at = now

        total = total or 0
        file_progress.downloaded_bytes = downloaded
        file_progress.total_bytes = total

        # Average speed since this file was first seen (not instantaneous).
        if file_progress.started_at is not None:
            elapsed = now - file_progress.started_at
            if elapsed > 0:
                file_progress.speed_bytes_per_sec = downloaded / elapsed

        # Stamp completion only once, so repeated updates after the file has
        # finished do not push its completion time forward.
        if total > 0 and downloaded >= total and file_progress.completed_at is None:
            file_progress.completed_at = now

    def complete_file(self, filename: str):
        """Stamp ``filename`` as finished now; unknown filenames are ignored."""
        file_progress = self.files.get(filename)
        if file_progress is not None:
            file_progress.completed_at = time.time()

    @property
    def total_bytes(self) -> int:
        """Sum of the known sizes of all tracked files."""
        return sum(f.total_bytes for f in self.files.values())

    @property
    def downloaded_bytes(self) -> int:
        """Sum of the bytes downloaded so far across all tracked files."""
        return sum(f.downloaded_bytes for f in self.files.values())

    @property
    def percentage(self) -> float:
        """Overall percentage: byte-weighted, or by file count while sizes are unknown."""
        total = self.total_bytes
        if total <= 0:
            # No sizes known yet: report the fraction of finished files.
            # A file with unknown size can never satisfy ``is_complete``
            # (it requires total_bytes > 0), so also count files that were
            # explicitly stamped with ``completed_at`` (e.g. by complete_file).
            if not self.files:
                return 0.0
            completed = sum(
                1
                for f in self.files.values()
                if f.is_complete or f.completed_at is not None
            )
            return (completed / len(self.files)) * 100.0
        return min(100.0, (self.downloaded_bytes / total) * 100.0)

    @property
    def is_complete(self) -> bool:
        """True when at least one file is tracked and every file is complete."""
        if not self.files:
            return False
        return all(f.is_complete for f in self.files.values())

    @property
    def speed_bytes_per_sec(self) -> float:
        """Combined speed of the files that are still downloading.

        Finished files are excluded: their speed is a frozen average that is
        no longer updated, so adding it would inflate the current speed.
        """
        return sum(
            (
                f.speed_bytes_per_sec
                for f in self.files.values()
                if f.started_at is not None and f.completed_at is None
            ),
            0.0,
        )

    @property
    def elapsed_time(self) -> float:
        """Seconds since the model's first update; frozen at ``completed_at`` if set."""
        if self.started_at is None:
            return 0.0
        end_time = self.completed_at if self.completed_at is not None else time.time()
        return end_time - self.started_at

    def to_dict(self) -> Dict:
        """Return a JSON-serializable snapshot of the model and its files."""
        speed = self.speed_bytes_per_sec
        return {
            "model_path": self.model_path,
            "total_bytes": self.total_bytes,
            "downloaded_bytes": self.downloaded_bytes,
            "percentage": round(self.percentage, 2),
            "speed_bytes_per_sec": round(speed, 2),
            "speed_mb_per_sec": round(speed / (1024 * 1024), 2),
            "elapsed_time": round(self.elapsed_time, 2),
            "is_complete": self.is_complete,
            "files_count": len(self.files),
            "files_completed": sum(1 for f in self.files.values() if f.is_complete),
            "files": {
                name: {
                    "filename": f.filename,
                    "total_bytes": f.total_bytes,
                    "downloaded_bytes": f.downloaded_bytes,
                    "percentage": round(f.percentage, 2),
                    "speed_mb_per_sec": round(f.speed_bytes_per_sec / (1024 * 1024), 2),
                    "is_complete": f.is_complete,
                }
                for name, f in self.files.items()
            },
        }
class ProgressTracker:
    """Thread-safe progress tracker for multiple models, keyed by model path.

    Every read and mutation that goes through this class happens under one
    re-entrant lock. Download threads can therefore report progress while
    another thread serializes a snapshot (``get_all`` / ``get_model_progress``)
    without the files dicts changing mid-iteration.
    NOTE(review): objects returned by ``get``/``get_or_create`` can still be
    mutated by callers outside the lock.
    """

    def __init__(self):
        self._progress: Dict[str, "ModelDownloadProgress"] = {}
        # Re-entrant so locked methods can call get_or_create() while holding it.
        self._lock = threading.RLock()

    def get_or_create(self, model_path: str) -> "ModelDownloadProgress":
        """Return the tracker for ``model_path``, creating an empty one if needed."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is None:
                progress = ModelDownloadProgress(model_path=model_path)
                self._progress[model_path] = progress
            return progress

    def get(self, model_path: str) -> Optional["ModelDownloadProgress"]:
        """Return the tracker for ``model_path``, or None if it was never created."""
        with self._lock:
            return self._progress.get(model_path)

    def update(self, model_path: str, filename: str, downloaded: int, total: int):
        """Record cumulative byte progress for one file of a model."""
        with self._lock:
            self.get_or_create(model_path).update_file(filename, downloaded, total)

    def complete_file(self, model_path: str, filename: str):
        """Mark a file as complete; no-op if the model is not tracked."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is not None:
                progress.complete_file(filename)

    def complete_model(self, model_path: str):
        """Stamp the whole model download as complete; no-op if not tracked."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is not None:
                progress.completed_at = time.time()

    def get_all(self) -> Dict[str, Dict]:
        """Return ``{model_path: progress_dict}`` for every tracked model."""
        with self._lock:
            return {path: prog.to_dict() for path, prog in self._progress.items()}

    def get_model_progress(self, model_path: str) -> Optional[Dict]:
        """Return the progress dict for ``model_path``, or None if not tracked."""
        with self._lock:
            progress = self._progress.get(model_path)
            return progress.to_dict() if progress is not None else None
# Single module-level tracker shared by every caller of get_progress_tracker().
_global_tracker = ProgressTracker()


def get_progress_tracker() -> ProgressTracker:
    """Return the shared module-level :class:`ProgressTracker`.

    The same instance is returned on every call, so progress recorded by
    download callbacks is visible to later queries.
    """
    shared_tracker = _global_tracker
    return shared_tracker
def create_progress_callback(model_path: str):
    """
    Create a tqdm-like class that forwards progress to the global tracker.

    Usage:
        from huggingface_hub import snapshot_download
        callback = create_progress_callback("Qwen/Qwen2.5-32B-Instruct")
        snapshot_download(repo_id=model_path, resume_download=True,
                          tqdm_class=callback)

    Args:
        model_path: Key under which progress is recorded in the tracker.

    Returns:
        A class (not an instance) to pass where a ``tqdm`` class is expected.
        Each instance reports cumulative progress for the file named by the
        last word of its description.

    NOTE(review): only a subset of the tqdm API is mimicked (constructor
    kwargs, update, set_description, close, iteration, context manager);
    verify it covers what the installed huggingface_hub version calls.
    """
    tracker = get_progress_tracker()

    class ProgressCallback:
        """Minimal tqdm stand-in that reports to the progress tracker."""

        def __init__(self, *args, **kwargs):
            # Keep the raw tqdm arguments for compatibility/introspection.
            self.tqdm_args = args
            self.tqdm_kwargs = kwargs
            # tqdm's first positional parameter is the iterable to wrap.
            self.iterable = args[0] if args else kwargs.get("iterable")
            # Cumulative counter and expected size, as tqdm exposes them.
            self.n = kwargs.get("initial", 0) or 0
            self.total = kwargs.get("total") or 0
            self.current_file = None
            self._closed = False
            self.set_description(kwargs.get("desc"))

        def __call__(self, *args, **kwargs):
            # Kept for backward compatibility; progress flows through update().
            pass

        def __iter__(self):
            # Mirror tqdm: yield each item, count it, and close when done.
            try:
                for item in self.iterable:
                    yield item
                    self.update(1)
            finally:
                self.close()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            self.close()
            return False

        def update(self, n: int = 1):
            """Advance the counter by ``n`` and report the cumulative total."""
            # tqdm's ``n`` argument is an increment; the tracker expects totals.
            self.n += n
            if self.current_file:
                tracker.update(model_path, self.current_file, self.n, self.total)

        def set_description(self, desc=None, refresh: bool = True):
            """Set description; its last whitespace-separated word is the filename."""
            if not desc:
                return
            if " " in desc:
                parts = desc.split()
                if parts:  # whitespace-only descriptions carry no filename
                    self.current_file = parts[-1]
            else:
                self.current_file = desc

        def close(self):
            """Close the bar and mark its file complete in the tracker (once)."""
            if self._closed:
                return
            self._closed = True
            if self.current_file:
                tracker.complete_file(model_path, self.current_file)

    return ProgressCallback
def create_hf_progress_callback(model_path: str):
    """
    Create a progress callback compatible with huggingface_hub.

    The returned function takes a tqdm-like bar and records its cumulative
    ``n`` / ``total`` in the global tracker. The filename is the last
    whitespace-separated word of the bar's ``desc``. When the bar has no
    description, the most recently seen filename is reused.

    Args:
        model_path: Key under which progress is recorded in the tracker.

    Returns:
        A callable ``progress_callback(tqdm_bar)``.
    """
    tracker = get_progress_tracker()
    current_file = [None]  # mutable cell shared across calls of the closure

    def progress_callback(tqdm_bar):
        """Forward one progress snapshot from ``tqdm_bar`` to the tracker."""
        desc = getattr(tqdm_bar, "desc", None)
        if desc:
            parts = desc.split() if " " in desc else [desc]
            if parts:  # whitespace-only descriptions carry no filename
                current_file[0] = parts[-1]
        if current_file[0]:
            downloaded = getattr(tqdm_bar, "n", 0) or 0
            # tqdm uses total=None for unknown sizes; the tracker expects a number.
            total = getattr(tqdm_bar, "total", None) or 0
            # tracker.update creates the file entry (with its start time) and
            # starts the model clock itself, under the tracker's lock.
            tracker.update(model_path, current_file[0], downloaded, total)

    return progress_callback