Spaces:
Sleeping
Sleeping
| """ | |
| Progress tracking for model downloads using Server-Sent Events. | |
| """ | |
import asyncio
import json
import logging
import threading
import time
from datetime import datetime
from typing import Optional, Callable, Dict, List

from fastapi.responses import StreamingResponse
class ProgressManager:
    """Track download progress per model and fan updates out to SSE subscribers.

    Producers (typically download code running in background threads, e.g.
    via ``asyncio.to_thread``) call :meth:`update_progress`,
    :meth:`mark_complete` and :meth:`mark_error`. Consumers iterate
    :meth:`subscribe`, an async generator that yields Server-Sent-Event
    formatted strings.

    Thread-safety: all shared dictionaries (progress, listeners, throttle
    state) are only read or mutated while holding ``_lock``. Subscriber
    queues are ``asyncio.Queue`` objects, which are not thread-safe, so they
    are only touched from the event loop that created them; other threads
    hand updates over with ``loop.call_soon_threadsafe``.
    """

    # Throttle settings to prevent overwhelming SSE clients
    THROTTLE_INTERVAL_SECONDS = 0.5  # Minimum time between updates
    THROTTLE_PROGRESS_DELTA = 1.0  # Minimum progress change (%) to force update

    # Statuses that mean "a download is still running" / "the download ended".
    _ACTIVE_STATUSES = ("downloading", "extracting")
    _TERMINAL_STATUSES = ("complete", "error")

    # Per-subscriber buffer size; when exceeded the oldest update is evicted.
    _QUEUE_MAXSIZE = 10

    _log = logging.getLogger(__name__)

    def __init__(self):
        self._progress: Dict[str, Dict] = {}
        self._listeners: Dict[str, List[asyncio.Queue]] = {}
        self._lock = threading.Lock()  # Guards every dict on this instance
        self._main_loop: Optional[asyncio.AbstractEventLoop] = None
        self._last_notify_time: Dict[str, float] = {}  # Last notification time per model
        self._last_notify_progress: Dict[str, float] = {}  # Last notified progress per model

    def _set_main_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the main event loop for thread-safe operations."""
        self._main_loop = loop

    def _enqueue_latest(self, model_name: str, queue: asyncio.Queue, data: Dict):
        """Put ``data`` on ``queue``, evicting the oldest entry if it is full.

        Must run on the event loop that owns ``queue``. Progress updates
        supersede each other, so dropping the oldest one loses nothing
        important, whereas dropping the newest could lose the final
        ``complete``/``error`` event and leave the SSE stream open forever.
        """
        if queue.full():
            try:
                queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            else:
                self._log.warning("Queue full for %s, dropping oldest update", model_name)
        queue.put_nowait(data)

    def _notify_listeners_threadsafe(self, model_name: str, progress_data: Dict):
        """Deliver ``progress_data`` to every subscriber of ``model_name``.

        Safe to call from the event loop thread or from a background thread.
        Each subscriber receives its own shallow copy of the dict. Must not
        be called while holding ``_lock`` (it acquires the lock itself).
        """
        # Snapshot under the lock: subscribe() mutates this list from the
        # event loop thread while producers may be running in other threads.
        with self._lock:
            listeners = list(self._listeners.get(model_name, ()))
        if not listeners:
            return

        main_loop = self._main_loop
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            # Not in an async context (running in a background thread)
            running_loop = None

        # Touch queues directly only when we are on the loop that owns them.
        # A *different* running loop (e.g. asyncio.run in a worker thread)
        # must go through call_soon_threadsafe like any other thread.
        on_owner_loop = running_loop is not None and (
            main_loop is None or running_loop is main_loop
        )

        for queue in listeners:
            payload = progress_data.copy()
            try:
                if on_owner_loop:
                    self._enqueue_latest(model_name, queue, payload)
                elif main_loop is not None and main_loop.is_running():
                    main_loop.call_soon_threadsafe(
                        self._enqueue_latest, model_name, queue, payload
                    )
                else:
                    self._log.debug(
                        "No main loop available for %s, skipping notification", model_name
                    )
            except Exception as e:
                # Best effort: one broken listener must not block the others
                # (e.g. the loop was closed between the check and the call).
                self._log.warning("Error notifying listener for %s: %s", model_name, e)

    def update_progress(
        self,
        model_name: str,
        current: int,
        total: int,
        filename: Optional[str] = None,
        status: str = "downloading",
    ):
        """
        Update progress for a model download.

        Thread-safe: can be called from background threads.

        The internal state is always updated. Listener notifications are
        throttled: they are sent at most every THROTTLE_INTERVAL_SECONDS, or
        when progress changes by at least THROTTLE_PROGRESS_DELTA percent.
        Updates whose status is "complete" or "error" are never throttled.

        Args:
            model_name: Name of the model (e.g., "qwen-tts-1.7B", "whisper-base")
            current: Current bytes downloaded
            total: Total bytes to download
            filename: Current file being downloaded
            status: Status string (downloading, extracting, complete, error)
        """
        # Calculate progress percentage, clamped to 0-100 range.
        # This prevents crazy percentages from edge cases like:
        # - current > total temporarily during aggregation
        # - mixing file-count progress with byte-count progress
        if total > 0:
            progress_pct = min(100.0, max(0.0, current / total * 100))
        else:
            progress_pct = 0.0

        progress_data = {
            "model_name": model_name,
            "current": current,
            "total": total,
            "progress": progress_pct,
            "filename": filename,
            "status": status,
            "timestamp": datetime.now().isoformat(),
        }

        # Monotonic clock: the throttle measures an interval, so it must not
        # be affected by wall-clock adjustments.
        now = time.monotonic()

        with self._lock:
            # Always update internal state, even when the notification is throttled.
            self._progress[model_name] = progress_data

            last_time = self._last_notify_time.get(model_name)
            last_progress = self._last_notify_progress.get(model_name, -100)
            should_notify = (
                status in self._TERMINAL_STATUSES
                or last_time is None
                or now - last_time >= self.THROTTLE_INTERVAL_SECONDS
                or abs(progress_pct - last_progress) >= self.THROTTLE_PROGRESS_DELTA
            )
            if not should_notify:
                return  # Skip this update (throttled)

            # Update throttle tracking
            self._last_notify_time[model_name] = now
            self._last_notify_progress[model_name] = progress_pct
            listener_count = len(self._listeners.get(model_name, ()))
            # Copy while still holding the lock: mark_complete/mark_error
            # mutate the stored dict in place.
            snapshot = progress_data.copy()

        if listener_count > 0:
            self._log.debug(
                "Notifying %d listeners for %s: %.1f%% (%s)",
                listener_count, model_name, progress_pct, filename,
            )
            self._notify_listeners_threadsafe(model_name, snapshot)
        else:
            self._log.debug(
                "No listeners for %s, progress update stored: %.1f%%",
                model_name, progress_pct,
            )

    def get_progress(self, model_name: str) -> Optional[Dict]:
        """Return a copy of the latest progress for a model, or None. Thread-safe."""
        with self._lock:
            progress = self._progress.get(model_name)
            return progress.copy() if progress else None

    def get_all_active(self) -> List[Dict]:
        """Return copies of all entries whose status is 'downloading' or 'extracting'. Thread-safe."""
        with self._lock:
            return [
                progress.copy()
                for progress in self._progress.values()
                if progress.get("status", "") in self._ACTIVE_STATUSES
            ]

    def create_progress_callback(self, model_name: str, filename: Optional[str] = None):
        """
        Create a progress callback function for HuggingFace downloads.

        Args:
            model_name: Name of the model
            filename: Fallback filename used when the callback payload has none

        Returns:
            Callback taking a dict; payloads lacking "total" or "current"
            are ignored, others are forwarded to update_progress with
            status "downloading".
        """

        def callback(progress: Dict):
            """HuggingFace Hub progress callback."""
            if "total" not in progress or "current" not in progress:
                return
            self.update_progress(
                model_name=model_name,
                current=progress["current"],
                total=progress["total"],
                filename=progress.get("filename", filename),
                status="downloading",
            )

        return callback

    async def subscribe(self, model_name: str):
        """
        Subscribe to progress updates for a model.

        Async generator yielding Server-Sent Events: ``data: <json>\\n\\n``
        for progress updates and ``: heartbeat\\n\\n`` after each second
        without an update. The stream ends after a "complete" or "error"
        update has been yielded. The subscription is always removed when the
        generator finishes, is closed, or is cancelled.
        """
        # Remember the loop that owns the subscriber queues so background
        # threads can schedule deliveries onto it.
        self._main_loop = asyncio.get_running_loop()

        queue = asyncio.Queue(maxsize=self._QUEUE_MAXSIZE)

        # Register and read the initial state in one critical section so no
        # update can slip in between the two steps.
        with self._lock:
            listeners = self._listeners.setdefault(model_name, [])
            listeners.append(queue)
            listener_total = len(listeners)
            initial_progress = self._progress.get(model_name)
            if initial_progress:
                initial_progress = initial_progress.copy()

        self._log.info(
            "SSE client subscribed to %s, total listeners: %d", model_name, listener_total
        )

        try:
            if initial_progress:
                status = initial_progress.get("status")
                # Only send initial progress if download is actually in progress.
                # Don't send old 'complete' or 'error' status from previous downloads.
                if status in self._ACTIVE_STATUSES:
                    self._log.info("Sending initial progress for %s: %s", model_name, status)
                    yield f"data: {json.dumps(initial_progress)}\n\n"
                else:
                    self._log.info(
                        "Skipping initial progress for %s (status: %s)", model_name, status
                    )
            else:
                self._log.info("No initial progress available for %s", model_name)

            # Stream updates
            while True:
                try:
                    # Wait for update with timeout
                    progress = await asyncio.wait_for(queue.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    # Send heartbeat so proxies/clients keep the connection open
                    yield ": heartbeat\n\n"
                    continue

                self._log.debug(
                    "Sending progress update for %s: %s - %.1f%%",
                    model_name, progress.get("status"), progress.get("progress", 0),
                )
                yield f"data: {json.dumps(progress)}\n\n"

                # Stop if complete or error
                if progress.get("status") in self._TERMINAL_STATUSES:
                    self._log.info(
                        "Download %s for %s, closing SSE connection",
                        progress.get("status"), model_name,
                    )
                    break
        except (BrokenPipeError, ConnectionResetError):
            self._log.debug("SSE client disconnected from %s", model_name)
        except asyncio.CancelledError:
            # Cancellation must propagate to the caller; only log it here.
            self._log.debug("SSE client disconnected from %s", model_name)
            raise
        finally:
            # Remove from listeners
            with self._lock:
                listeners = self._listeners.get(model_name)
                if listeners is not None:
                    if queue in listeners:
                        listeners.remove(queue)
                    if not listeners:
                        del self._listeners[model_name]
                remaining = len(self._listeners.get(model_name, ()))
            self._log.info(
                "SSE client unsubscribed from %s, remaining listeners: %d",
                model_name, remaining,
            )

    def mark_complete(self, model_name: str):
        """Mark a model download as complete and notify listeners. Thread-safe.

        Does nothing (beyond logging a warning) when no progress has been
        recorded for ``model_name``.
        """
        with self._lock:
            entry = self._progress.get(model_name)
            if entry is None:
                snapshot = None
            else:
                entry["status"] = "complete"
                entry["progress"] = 100.0
                snapshot = entry.copy()

        if snapshot is None:
            self._log.warning("Cannot mark %s as complete: not found in progress", model_name)
            return

        self._log.info("Marked %s as complete", model_name)
        # Notify listeners (thread-safe); done outside the lock.
        self._notify_listeners_threadsafe(model_name, snapshot)

    def mark_error(self, model_name: str, error: str):
        """Mark a model download as failed and notify listeners. Thread-safe.

        If no progress has been recorded for ``model_name`` yet, a new entry
        with zeroed counters is created to carry the error.
        """
        with self._lock:
            entry = self._progress.get(model_name)
            if entry is not None:
                entry["status"] = "error"
                entry["error"] = error
            else:
                # Create new progress entry for error
                entry = {
                    "model_name": model_name,
                    "current": 0,
                    "total": 0,
                    "progress": 0,
                    "filename": None,
                    "status": "error",
                    "error": error,
                    "timestamp": datetime.now().isoformat(),
                }
                self._progress[model_name] = entry
            snapshot = entry.copy()

        self._log.error("Marked %s as error: %s", model_name, error)
        # Notify listeners (thread-safe); done outside the lock.
        self._notify_listeners_threadsafe(model_name, snapshot)


# Global progress manager instance, created lazily by get_progress_manager()
_progress_manager: Optional[ProgressManager] = None
# Guards first-time creation so concurrent callers share one instance
_progress_manager_lock = threading.Lock()


def get_progress_manager() -> ProgressManager:
    """Get or create the global progress manager.

    Safe to call from multiple threads: the instance is created at most
    once. Without the lock, two threads racing on the first call could each
    build their own manager, leaving subscribers attached to one instance
    and download progress reported to another.
    """
    global _progress_manager
    if _progress_manager is None:
        with _progress_manager_lock:
            # Re-check under the lock: another thread may have created the
            # instance while this one was waiting.
            if _progress_manager is None:
                _progress_manager = ProgressManager()
    return _progress_manager