"""
Download progress tracker for Hugging Face models.

Tracks real-time download progress in bytes.
"""

import threading
import time
from typing import Dict, Optional
from dataclasses import dataclass, field


@dataclass
class DownloadProgress:
    """Track download progress for a single file."""

    filename: str
    total_bytes: int = 0  # 0 means "size not known (yet)"
    downloaded_bytes: int = 0
    started_at: Optional[float] = None  # time.time() when first seen
    completed_at: Optional[float] = None  # time.time() when marked complete
    speed_bytes_per_sec: float = 0.0  # average speed since started_at

    @property
    def percentage(self) -> float:
        """Return the download percentage, capped at 100; 0.0 while the size is unknown."""
        if self.total_bytes == 0:
            return 0.0
        return min(100.0, (self.downloaded_bytes / self.total_bytes) * 100.0)

    @property
    def is_complete(self) -> bool:
        """Return True once a known, non-zero size has been fully downloaded."""
        return self.total_bytes > 0 and self.downloaded_bytes >= self.total_bytes

    @property
    def elapsed_time(self) -> float:
        """Return seconds since start (frozen at completion); 0.0 if never started."""
        if self.started_at is None:
            return 0.0
        end_time = self.completed_at if self.completed_at is not None else time.time()
        return end_time - self.started_at


@dataclass
class ModelDownloadProgress:
    """Track overall download progress for a model (a collection of files)."""

    model_path: str
    files: Dict[str, DownloadProgress] = field(default_factory=dict)
    started_at: Optional[float] = None
    completed_at: Optional[float] = None

    def update_file(self, filename: str, downloaded: Optional[int], total: Optional[int]):
        """Record the cumulative progress of one file.

        Args:
            filename: Key of the file inside ``files``; created on first use.
            downloaded: Cumulative bytes downloaded so far (``None`` -> 0).
            total: Expected size in bytes. tqdm reports an unknown size as
                ``None``; it is stored as 0 so aggregation never sees ``None``.
        """
        now = time.time()
        # Set even when the file entry already exists: an entry may have been
        # created elsewhere before the first update for this model.
        if self.started_at is None:
            self.started_at = now

        file_progress = self.files.get(filename)
        if file_progress is None:
            file_progress = DownloadProgress(filename=filename, started_at=now)
            self.files[filename] = file_progress

        downloaded = downloaded or 0
        total = total or 0
        file_progress.downloaded_bytes = downloaded
        file_progress.total_bytes = total

        # Average speed since this file was first seen.
        if file_progress.started_at is not None:
            elapsed = now - file_progress.started_at
            if elapsed > 0:
                file_progress.speed_bytes_per_sec = downloaded / elapsed

        # Keep the first completion timestamp; later updates must not move it.
        if total > 0 and downloaded >= total and file_progress.completed_at is None:
            file_progress.completed_at = now

    def complete_file(self, filename: str):
        """Mark a known file as complete (first call wins; unknown names are ignored)."""
        file_progress = self.files.get(filename)
        if file_progress is not None and file_progress.completed_at is None:
            file_progress.completed_at = time.time()

    @property
    def total_bytes(self) -> int:
        """Return total expected bytes across all files."""
        return sum(f.total_bytes for f in self.files.values())

    @property
    def downloaded_bytes(self) -> int:
        """Return downloaded bytes across all files."""
        return sum(f.downloaded_bytes for f in self.files.values())

    @property
    def percentage(self) -> float:
        """Return overall download percentage, capped at 100.

        When no file size is known, fall back to the fraction of files that
        were explicitly marked complete (e.g. via ``complete_file``).
        Byte-based ``is_complete`` cannot be true in that case.
        """
        total = self.total_bytes
        if total == 0:
            if not self.files:
                return 0.0
            completed = sum(1 for f in self.files.values() if f.completed_at is not None)
            return (completed / len(self.files)) * 100.0
        return min(100.0, (self.downloaded_bytes / total) * 100.0)

    @property
    def is_complete(self) -> bool:
        """Return True when at least one file is tracked and all are fully downloaded."""
        if not self.files:
            return False
        return all(f.is_complete for f in self.files.values())

    @property
    def speed_bytes_per_sec(self) -> float:
        """Return the summed average speed of files still in flight.

        Completed files are excluded: their speed is frozen at its final value,
        and including it would inflate the overall speed.
        """
        return sum(
            f.speed_bytes_per_sec
            for f in self.files.values()
            if f.started_at is not None and f.completed_at is None
        )

    @property
    def elapsed_time(self) -> float:
        """Return seconds since the model download started (frozen at completion)."""
        if self.started_at is None:
            return 0.0
        end_time = self.completed_at if self.completed_at is not None else time.time()
        return end_time - self.started_at

    def to_dict(self) -> Dict:
        """Convert to a dictionary for JSON serialization."""
        mb = 1024 * 1024
        speed = self.speed_bytes_per_sec
        return {
            "model_path": self.model_path,
            "total_bytes": self.total_bytes,
            "downloaded_bytes": self.downloaded_bytes,
            "percentage": round(self.percentage, 2),
            "speed_bytes_per_sec": round(speed, 2),
            "speed_mb_per_sec": round(speed / mb, 2),
            "elapsed_time": round(self.elapsed_time, 2),
            "is_complete": self.is_complete,
            "files_count": len(self.files),
            "files_completed": sum(1 for f in self.files.values() if f.is_complete),
            "files": {
                name: {
                    "filename": f.filename,
                    "total_bytes": f.total_bytes,
                    "downloaded_bytes": f.downloaded_bytes,
                    "percentage": round(f.percentage, 2),
                    "speed_mb_per_sec": round(f.speed_bytes_per_sec / mb, 2),
                    "is_complete": f.is_complete,
                }
                for name, f in self.files.items()
            },
        }


class ProgressTracker:
    """Thread-safe progress tracker for multiple models.

    Every mutation made through this class, and every ``to_dict`` snapshot,
    runs under one re-entrant lock. A reader calling ``get_all`` therefore
    cannot see ``files`` change size mid-iteration. Objects returned by
    ``get`` / ``get_or_create`` are live: mutating them directly bypasses
    the lock.
    """

    def __init__(self):
        self._progress: Dict[str, ModelDownloadProgress] = {}
        # Re-entrant so update() can call get_or_create() while holding it.
        self._lock = threading.RLock()

    def get_or_create(self, model_path: str) -> ModelDownloadProgress:
        """Return the progress record for ``model_path``, creating it if needed."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is None:
                progress = ModelDownloadProgress(model_path=model_path)
                self._progress[model_path] = progress
            return progress

    def get(self, model_path: str) -> Optional[ModelDownloadProgress]:
        """Return the progress record for ``model_path``, or None if untracked."""
        with self._lock:
            return self._progress.get(model_path)

    def update(self, model_path: str, filename: str, downloaded: int, total: Optional[int]):
        """Record cumulative progress for one file of a model (``total=None`` -> unknown)."""
        with self._lock:
            self.get_or_create(model_path).update_file(filename, downloaded, total)

    def complete_file(self, model_path: str, filename: str):
        """Mark a file of an already-tracked model as complete."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is not None:
                progress.complete_file(filename)

    def complete_model(self, model_path: str):
        """Stamp the completion time of an already-tracked model."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is not None:
                progress.completed_at = time.time()

    def get_all(self) -> Dict[str, Dict]:
        """Return a serialized snapshot of every tracked model."""
        with self._lock:
            return {path: prog.to_dict() for path, prog in self._progress.items()}

    def get_model_progress(self, model_path: str) -> Optional[Dict]:
        """Return a serialized snapshot of one model, or None if untracked."""
        with self._lock:
            progress = self._progress.get(model_path)
            return progress.to_dict() if progress is not None else None


# Global progress tracker instance
_global_tracker = ProgressTracker()


def get_progress_tracker() -> ProgressTracker:
    """Return the process-wide progress tracker instance."""
    return _global_tracker


def _filename_from_desc(desc: Optional[str]) -> Optional[str]:
    """Return the last whitespace-separated word of a tqdm description.

    For example, ``"Downloading model.safetensors"`` gives
    ``"model.safetensors"``. Returns None for empty or whitespace-only input
    instead of raising IndexError.
    """
    if not desc:
        return None
    parts = desc.split()
    return parts[-1] if parts else None


def create_progress_callback(model_path: str):
    """
    Create a tqdm-like class that reports progress to the global tracker.

    Usage:
        from huggingface_hub import snapshot_download
        callback = create_progress_callback("Qwen/Qwen2.5-32B-Instruct")
        snapshot_download(repo_id="Qwen/Qwen2.5-32B-Instruct", tqdm_class=callback)

    The returned class mimics the subset of tqdm used here:
    - construction with ``iterable`` / ``desc`` / ``total`` (keyword or positional);
    - ``update``, ``set_description`` and ``close``;
    - iteration and use as a context manager.

    The last word of the description is used as the tracked file name.

    NOTE(review): huggingface_hub documents that ``tqdm_class`` drives the
    overall "Fetching N files" bar and is not passed to individual file
    downloads. Per-file byte progress may therefore not arrive through this
    path, and other tqdm methods may be expected — confirm against the
    installed version.
    """
    tracker = get_progress_tracker()

    class ProgressCallback:
        """Minimal tqdm stand-in that forwards cumulative progress to the tracker."""

        def __init__(self, *args, **kwargs):
            # Raw tqdm arguments, kept for callers that inspect them.
            self.tqdm_args = args
            self.tqdm_kwargs = kwargs
            self.current_file = None

            def _arg(index, name):
                # tqdm's positional order starts (iterable, desc, total, ...).
                if name in kwargs:
                    return kwargs[name]
                return args[index] if len(args) > index else None

            self.iterable = _arg(0, "iterable")
            self.total = _arg(2, "total")
            if self.total is None and self.iterable is not None:
                try:
                    self.total = len(self.iterable)
                except TypeError:
                    pass  # unsized iterable (e.g. a generator): size stays unknown
            # Cumulative count, like tqdm's ``n``.
            self.n = kwargs.get("initial") or 0
            self.set_description(_arg(1, "desc"))

        def __call__(self, *args, **kwargs):
            # Intentionally a no-op; kept for backward compatibility.
            pass

        def __iter__(self):
            """Yield from the wrapped iterable, counting one unit per item, then close."""
            if self.iterable is None:
                return
            try:
                for item in self.iterable:
                    yield item
                    self.update(1)
            finally:
                self.close()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            self.close()
            return False  # never suppress exceptions

        def update(self, n: int = 1):
            """Advance by ``n`` units and report the cumulative count to the tracker."""
            self.n += n
            if self.current_file:
                tracker.update(model_path, self.current_file, self.n, self.total or 0)

        def set_description(self, desc: Optional[str] = None, refresh: bool = True):
            """Take the file name from the last word of ``desc``.

            ``refresh`` is accepted only for tqdm signature compatibility.
            """
            filename = _filename_from_desc(desc)
            if filename:
                self.current_file = filename

        def close(self):
            """Mark the current file as complete in the tracker."""
            if self.current_file:
                tracker.complete_file(model_path, self.current_file)

    return ProgressCallback


def create_hf_progress_callback(model_path: str):
    """
    Create a progress callback compatible with huggingface_hub.

    Returns a function that accepts a tqdm-like bar. It reads the bar's
    ``desc``, ``n`` and ``total`` attributes and forwards them to the global
    tracker. The last word of ``desc`` is the file name; it is remembered, so
    later calls without a description keep updating the same file.
    """
    tracker = get_progress_tracker()
    current_file: Optional[str] = None

    def progress_callback(tqdm_bar):
        """Forward the bar's cumulative progress for the current file."""
        nonlocal current_file
        filename = _filename_from_desc(getattr(tqdm_bar, "desc", None))
        if filename:
            current_file = filename
        if current_file:
            # tqdm uses None for an unknown total; treat it (and a missing n) as 0.
            downloaded = getattr(tqdm_bar, "n", 0) or 0
            total = getattr(tqdm_bar, "total", 0) or 0
            # Routed through the tracker so the entry is created under the lock
            # and the model's started_at is set.
            tracker.update(model_path, current_file, downloaded, total)

    return progress_callback