| """ | |
| Model Manager | |
| ============= | |
| Manages model loading states and provides a robust interface for | |
| ensuring models are loaded and validated before generation. | |
| States: | |
| - UNLOADED: No model loaded | |
| - LOADING: Model is being loaded | |
| - READY: Model loaded and validated | |
| - ERROR: Model failed to load | |
| """ | |

import logging
import threading
from typing import Optional, Callable, Tuple
from enum import Enum

from .backend_router import BackendRouter, BackendType
from .character_service import CharacterSheetService

logger = logging.getLogger(__name__)


class ModelState(Enum):
    """Model loading states."""
    UNLOADED = "unloaded"
    LOADING = "loading"
    READY = "ready"
    ERROR = "error"


class ModelManager:
    """
    Manages the model loading lifecycle with state tracking.

    Ensures models are fully loaded and validated before allowing generation.
    Provides progress callbacks for UI updates during loading.
    """

    def __init__(self):
        self._state = ModelState.UNLOADED
        self._current_backend: Optional[BackendType] = None
        self._service: Optional[CharacterSheetService] = None
        self._error_message: Optional[str] = None
        self._loading_progress: float = 0.0
        self._loading_message: str = ""
        self._lock = threading.Lock()
        self._cancel_requested = False

    def state(self) -> ModelState:
        """Current model state."""
        return self._state

    def is_ready(self) -> bool:
        """Check if the model is ready for generation."""
        return self._state == ModelState.READY

    def is_loading(self) -> bool:
        """Check if a model is currently loading."""
        return self._state == ModelState.LOADING

    def error_message(self) -> Optional[str]:
        """Get the error message if in the error state."""
        return self._error_message

    def loading_progress(self) -> float:
        """Get loading progress (0.0 to 1.0)."""
        return self._loading_progress

    def loading_message(self) -> str:
        """Get the current loading status message."""
        return self._loading_message

    def current_backend(self) -> Optional[BackendType]:
        """Get the currently loaded backend."""
        return self._current_backend

    def service(self) -> Optional[CharacterSheetService]:
        """Get the character sheet service (only valid when ready)."""
        if self._state != ModelState.READY:
            return None
        return self._service

    def get_status_display(self) -> Tuple[str, str]:
        """
        Get a status message and color for UI display.

        Returns:
            Tuple of (message, color) where color is a CSS color string
        """
        if self._state == ModelState.UNLOADED:
            return "No model loaded", "#888888"
        elif self._state == ModelState.LOADING:
            pct = int(self._loading_progress * 100)
            return f"Loading... {pct}% - {self._loading_message}", "#FFA500"
        elif self._state == ModelState.READY:
            backend_name = BackendRouter.BACKEND_NAMES.get(
                self._current_backend,
                str(self._current_backend)
            )
            return f"Ready: {backend_name}", "#00AA00"
        elif self._state == ModelState.ERROR:
            return f"Error: {self._error_message}", "#FF0000"
        return "Unknown state", "#888888"

    def request_cancel(self):
        """Request cancellation of the current loading operation."""
        self._cancel_requested = True
        logger.info("Model loading cancellation requested")
    def load_model(
        self,
        backend: BackendType,
        api_key: Optional[str] = None,
        steps: int = 4,
        guidance: float = 1.0,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> bool:
        """
        Load a model with progress tracking.

        Args:
            backend: Backend type to load
            api_key: API key for cloud backends
            steps: Default steps for generation
            guidance: Default guidance scale
            progress_callback: Callback for progress updates (progress, message)

        Returns:
            True if the model loaded successfully
        """
        with self._lock:
            if self._state == ModelState.LOADING:
                logger.warning("Model is already loading, ignoring request")
                return False
            self._state = ModelState.LOADING
            self._loading_progress = 0.0
            self._loading_message = "Initializing..."
            self._error_message = None
            self._cancel_requested = False

        def update_progress(progress: float, message: str):
            self._loading_progress = progress
            self._loading_message = message
            if progress_callback:
                progress_callback(progress, message)
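
        # The hard-coded progress values below are coarse phase milestones,
        # not measured progress; most of the real wait happens inside the
        # CharacterSheetService constructor between 0.15 and 0.7.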
        try:
            # Step 1: Unload the previous model if switching backends
            update_progress(0.05, "Checking current model...")
            if self._service and self._current_backend != backend:
                update_progress(0.1, "Unloading previous model...")
                try:
                    if hasattr(self._service, 'router'):
                        self._service.router.unload_local_models()
                except Exception as e:
                    logger.warning(f"Error unloading previous model: {e}")
                self._service = None

            if self._cancel_requested:
                self._state = ModelState.UNLOADED
                return False

            # Step 2: Create the service and load the model
            backend_name = BackendRouter.BACKEND_NAMES.get(backend, str(backend))
            update_progress(0.15, f"Loading {backend_name}...")
            logger.info(f"Creating CharacterSheetService for {backend.value}")

            # For local models, this will load the model.
            # For cloud backends, this just validates the API key.
            self._service = CharacterSheetService(
                api_key=api_key,
                backend=backend
            )

            if self._cancel_requested:
                self._state = ModelState.UNLOADED
                self._service = None
                return False

            update_progress(0.7, "Model loaded, configuring...")

            # Step 3: Configure default parameters
            if hasattr(self._service.client, 'default_steps'):
                self._service.client.default_steps = steps
            if hasattr(self._service.client, 'default_guidance'):
                self._service.client.default_guidance = guidance

            update_progress(0.8, "Validating model...")

            # Step 4: Validate that the model is actually working
            is_valid, error = self._validate_model()
            if not is_valid:
                raise RuntimeError(f"Model validation failed: {error}")

            update_progress(1.0, "Ready!")

            # Success!
            with self._lock:
                self._current_backend = backend
                self._state = ModelState.READY
                self._loading_progress = 1.0
                self._loading_message = "Ready"

            logger.info(f"Model {backend.value} loaded and validated successfully")
            return True

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load model {backend.value}: {error_msg}", exc_info=True)
            with self._lock:
                self._state = ModelState.ERROR
                self._error_message = self._simplify_error(error_msg)
                self._service = None
            if progress_callback:
                progress_callback(0.0, f"Error: {self._error_message}")
            return False
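
    # Note: load_model() blocks until loading completes, so UI callers
    # typically run it in a worker thread. An illustrative sketch (the
    # manager, backend, and on_progress names are placeholders):
    #
    #     threading.Thread(
    #         target=manager.load_model,
    #         args=(backend,),
    #         kwargs={"progress_callback": on_progress},
    #         daemon=True,
    #     ).start()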

    def _validate_model(self) -> Tuple[bool, Optional[str]]:
        """
        Validate that the model is actually working.

        For local models, checks that the pipeline is loaded.
        For cloud backends, does a minimal health check.

        Returns:
            Tuple of (is_valid, error_message)
        """
        if self._service is None:
            return False, "Service not initialized"
        try:
            client = self._service.client

            # Use the client's health check if it provides one
            if hasattr(client, 'is_healthy'):
                if not client.is_healthy():
                    return False, "Client health check failed"

            # For local models, check that the pipeline is loaded
            if hasattr(client, '_loaded'):
                if not client._loaded:
                    return False, "Model pipeline not loaded"

            # For FLUX models, verify the pipe exists
            if hasattr(client, 'pipe'):
                if client.pipe is None:
                    return False, "Model pipeline is None"

            return True, None
        except Exception as e:
            return False, str(e)

    def _simplify_error(self, error: str) -> str:
        """Simplify technical error messages for user display."""
        error_lower = error.lower()
        if "cuda out of memory" in error_lower or "out of memory" in error_lower:
            return "Not enough GPU memory. Try a smaller model or close other applications."
        if "api key" in error_lower:
            return "Invalid or missing API key."
        if "connection" in error_lower or "network" in error_lower:
            return "Network connection error. Check your internet connection."
        if "not found" in error_lower and "model" in error_lower:
            return "Model files not found. The model may need to be downloaded."
        if "import" in error_lower:
            return "Missing dependencies. Some required packages are not installed."
        if "meta tensor" in error_lower:
            return "Model loading failed (meta tensor error). Try restarting the application."
        # Truncate long errors
        if len(error) > 100:
            return error[:97] + "..."
        return error

    def unload(self):
        """Unload the current model."""
        with self._lock:
            if self._service:
                try:
                    if hasattr(self._service, 'router'):
                        self._service.router.unload_local_models()
                except Exception as e:
                    logger.warning(f"Error during unload: {e}")
                self._service = None
            self._state = ModelState.UNLOADED
            self._current_backend = None
            self._error_message = None
            self._loading_progress = 0.0
            self._loading_message = ""
            logger.info("Model unloaded")
# Global singleton for model management
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Get the global ModelManager instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager
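

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed). The BackendType member
# FLUX_SCHNELL and the on_progress callback are assumptions; substitute a
# real member of backend_router.BackendType and your own UI hook.
#
#     manager = get_model_manager()
#
#     def on_progress(progress: float, message: str) -> None:
#         print(f"[{progress:4.0%}] {message}")
#
#     if manager.load_model(BackendType.FLUX_SCHNELL,
#                           progress_callback=on_progress):
#         service = manager.service()
#         ...  # generate with service, then manager.unload()
#     else:
#         message, color = manager.get_status_display()
# ---------------------------------------------------------------------------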