"""
Model Manager
=============

Manages model loading states and provides a robust interface for ensuring
models are loaded and validated before generation.

States:
    - UNLOADED: No model loaded
    - LOADING: Model is being loaded
    - READY: Model loaded and validated
    - ERROR: Model failed to load
"""

import logging
import threading
import time
from typing import Optional, Callable, Tuple
from enum import Enum

from PIL import Image

from .backend_router import BackendRouter, BackendType
from .character_service import CharacterSheetService

logger = logging.getLogger(__name__)


class ModelState(Enum):
    """Model loading states."""

    UNLOADED = "unloaded"
    LOADING = "loading"
    READY = "ready"
    ERROR = "error"


class ModelManager:
    """
    Manages model loading lifecycle with state tracking.

    Ensures models are fully loaded and validated before allowing generation.
    Provides progress callbacks for UI updates during loading.

    ``self._lock`` is a plain (non-reentrant) ``threading.Lock``. It prevents
    two ``load_model`` calls from both entering the LOADING state and guards
    the state transitions made by ``load_model`` and ``unload``. The
    read-only properties below do not take the lock.
    """

    def __init__(self):
        self._state = ModelState.UNLOADED
        self._current_backend: Optional[BackendType] = None
        self._service: Optional[CharacterSheetService] = None
        self._error_message: Optional[str] = None
        self._loading_progress: float = 0.0
        self._loading_message: str = ""
        self._lock = threading.Lock()
        # Set by request_cancel(); polled at checkpoints inside load_model().
        self._cancel_requested = False

    @property
    def state(self) -> ModelState:
        """Current model state."""
        return self._state

    @property
    def is_ready(self) -> bool:
        """Check if model is ready for generation."""
        return self._state == ModelState.READY

    @property
    def is_loading(self) -> bool:
        """Check if model is currently loading."""
        return self._state == ModelState.LOADING

    @property
    def error_message(self) -> Optional[str]:
        """Get error message if in error state."""
        return self._error_message

    @property
    def loading_progress(self) -> float:
        """Get loading progress (0.0 to 1.0)."""
        return self._loading_progress

    @property
    def loading_message(self) -> str:
        """Get current loading status message."""
        return self._loading_message

    @property
    def current_backend(self) -> Optional[BackendType]:
        """Get the backend of the currently held service, or None if none is held."""
        return self._current_backend

    @property
    def service(self) -> Optional[CharacterSheetService]:
        """Get the character sheet service (only valid when ready)."""
        if self._state != ModelState.READY:
            return None
        return self._service

    def get_status_display(self) -> Tuple[str, str]:
        """
        Get status message and color for UI display.

        Returns:
            Tuple of (message, color) where color is a CSS hex color string.
        """
        if self._state == ModelState.UNLOADED:
            return "No model loaded", "#888888"
        elif self._state == ModelState.LOADING:
            pct = int(self._loading_progress * 100)
            return f"Loading... {pct}% - {self._loading_message}", "#FFA500"
        elif self._state == ModelState.READY:
            backend_name = BackendRouter.BACKEND_NAMES.get(
                self._current_backend, str(self._current_backend)
            )
            return f"Ready: {backend_name}", "#00AA00"
        elif self._state == ModelState.ERROR:
            return f"Error: {self._error_message}", "#FF0000"
        return "Unknown state", "#888888"

    def request_cancel(self):
        """
        Request cancellation of the current loading operation.

        ``load_model`` only checks the flag at two points: after the
        previous-model unload step, and after the new service has been
        constructed. Work already in progress, such as the service
        constructor, is not interrupted. The flag is reset at the start of
        every ``load_model`` call.
        """
        self._cancel_requested = True
        logger.info("Model loading cancellation requested")

    def load_model(
        self,
        backend: BackendType,
        api_key: Optional[str] = None,
        steps: int = 4,
        guidance: float = 1.0,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> bool:
        """
        Load a model with progress tracking.

        Runs synchronously in the caller's thread: unload the previous
        service if it targets a different backend, create a new
        ``CharacterSheetService``, apply default parameters, then validate it.

        Args:
            backend: Backend type to load.
            api_key: API key for cloud backends.
            steps: Default steps for generation (applied only if the client
                exposes ``default_steps``).
            guidance: Default guidance scale (applied only if the client
                exposes ``default_guidance``).
            progress_callback: Callback for progress updates
                ``(progress, message)``.

        Returns:
            True if the model loaded and validated successfully.

            False in three cases:
            - another load was already in progress;
            - the load was cancelled (state becomes UNLOADED);
            - the load failed (state becomes ERROR and ``error_message``
              is set).
        """
        with self._lock:
            if self._state == ModelState.LOADING:
                logger.warning("Model is already loading, ignoring request")
                return False
            self._state = ModelState.LOADING
            self._loading_progress = 0.0
            self._loading_message = "Initializing..."
            self._error_message = None
            self._cancel_requested = False

        def update_progress(progress: float, message: str) -> None:
            self._loading_progress = progress
            self._loading_message = message
            if progress_callback:
                progress_callback(progress, message)

        try:
            # Step 1: Unload previous model if different backend.
            # Same-backend reloads keep the old service until it is replaced.
            update_progress(0.05, "Checking current model...")
            if self._service is not None and self._current_backend != backend:
                update_progress(0.1, "Unloading previous model...")
                self._release_service("Error unloading previous model")

            if self._cancel_requested:
                return self._finish_cancelled()

            # Step 2: Create service and load model
            backend_name = BackendRouter.BACKEND_NAMES.get(backend, str(backend))
            update_progress(0.15, f"Loading {backend_name}...")
            logger.info("Creating CharacterSheetService for %s", backend.value)

            # For local models, this will load the model
            # For cloud backends, this just validates the API key
            self._service = CharacterSheetService(api_key=api_key, backend=backend)

            if self._cancel_requested:
                # Release what the constructor just loaded, rather than only
                # dropping the reference.
                return self._finish_cancelled()

            update_progress(0.7, "Model loaded, configuring...")

            # Step 3: Configure default parameters
            client = self._service.client
            if hasattr(client, 'default_steps'):
                client.default_steps = steps
            if hasattr(client, 'default_guidance'):
                client.default_guidance = guidance

            update_progress(0.8, "Validating model...")

            # Step 4: Validate model is actually working
            is_valid, error = self._validate_model()
            if not is_valid:
                raise RuntimeError(f"Model validation failed: {error}")

            update_progress(1.0, "Ready!")

            # Success!
            with self._lock:
                self._current_backend = backend
                self._state = ModelState.READY
                self._loading_progress = 1.0
                self._loading_message = "Ready"

            logger.info("Model %s loaded and validated successfully", backend.value)
            return True

        except Exception as e:
            # Some exceptions stringify to "", which would otherwise show a
            # bare "Error: " in the UI; fall back to the exception type name.
            error_msg = str(e) or type(e).__name__
            # getattr guards against a non-enum `backend`, whose missing
            # `.value` would otherwise raise again inside this handler.
            logger.exception(
                "Failed to load model %s: %s",
                getattr(backend, "value", backend),
                error_msg,
            )
            with self._lock:
                # Release (not just drop) any service held at this point so
                # its local models are unloaded as in unload().
                self._release_service("Error unloading failed model")
                self._state = ModelState.ERROR
                self._error_message = self._simplify_error(error_msg)
                self._loading_progress = 0.0
            if progress_callback:
                progress_callback(0.0, f"Error: {self._error_message}")
            return False

    def _release_service(self, warning_prefix: str) -> None:
        """
        Unload the held service's local models (best effort) and forget it.

        Clears ``_current_backend`` as well, so it never names a backend
        whose service has been discarded. Does not take ``self._lock`` (the
        lock is not reentrant). Callers that need the change to be atomic
        with other state updates must already hold it.

        Args:
            warning_prefix: Text logged before the exception if unloading
                fails.
        """
        service = self._service
        self._service = None
        self._current_backend = None
        if service is None:
            return
        try:
            if hasattr(service, 'router'):
                service.router.unload_local_models()
        except Exception as e:
            # Best effort: a failed unload must not block the state change.
            logger.warning("%s: %s", warning_prefix, e)

    def _finish_cancelled(self) -> bool:
        """Release any partially loaded service, move to UNLOADED, return False."""
        with self._lock:
            self._release_service("Error unloading cancelled model")
            self._state = ModelState.UNLOADED
            self._loading_progress = 0.0
            self._loading_message = ""
        logger.info("Model loading cancelled")
        return False

    def _validate_model(self) -> Tuple[bool, Optional[str]]:
        """
        Validate that the model is actually working.

        Each check runs only if the client exposes the corresponding
        attribute:
        - ``is_healthy()`` must return truthy;
        - ``_loaded`` must be truthy;
        - ``pipe`` must not be None.

        Returns:
            Tuple of (is_valid, error_message). ``error_message`` is None
            when valid.
        """
        if self._service is None:
            return False, "Service not initialized"

        try:
            client = self._service.client

            # Check if client has health check method
            if hasattr(client, 'is_healthy'):
                if not client.is_healthy():
                    return False, "Client health check failed"

            # For local models, check pipeline is loaded
            if hasattr(client, '_loaded'):
                if not client._loaded:
                    return False, "Model pipeline not loaded"

            # For FLUX models, verify the pipe exists
            if hasattr(client, 'pipe'):
                if client.pipe is None:
                    return False, "Model pipeline is None"

            return True, None

        except Exception as e:
            return False, str(e)

    def _simplify_error(self, error: str) -> str:
        """
        Map a technical error message to a short user-facing message.

        Checks run in order and the first case-insensitive substring match
        wins. Unmatched messages longer than 100 characters are truncated to
        97 characters plus "...".
        """
        error_lower = error.lower()

        # "out of memory" also covers "cuda out of memory".
        if "out of memory" in error_lower:
            return "Not enough GPU memory. Try a smaller model or close other applications."

        if "api key" in error_lower:
            return "Invalid or missing API key."

        if "connection" in error_lower or "network" in error_lower:
            return "Network connection error. Check your internet connection."

        if "not found" in error_lower and "model" in error_lower:
            return "Model files not found. The model may need to be downloaded."

        # ModuleNotFoundError text ("No module named 'x'") does not contain
        # the word "import", so match it explicitly.
        if "import" in error_lower or "no module named" in error_lower:
            return "Missing dependencies. Some required packages are not installed."

        if "meta tensor" in error_lower:
            return "Model loading failed (meta tensor error). Try restarting the application."

        # Truncate long errors
        if len(error) > 100:
            return error[:97] + "..."

        return error

    def unload(self):
        """
        Unload the current model and reset to the UNLOADED state.

        Unloading the service's local models is best effort: failures are
        logged, not raised. This does not stop a ``load_model`` call that is
        already running (that call sets its own final state when it
        finishes); use ``request_cancel()`` for that.
        """
        with self._lock:
            self._release_service("Error during unload")
            self._state = ModelState.UNLOADED
            self._error_message = None
            self._loading_progress = 0.0
            self._loading_message = ""
        logger.info("Model unloaded")


# Global singleton for model management
_model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """
    Get the global ModelManager instance, creating it on first use.

    NOTE(review): creation is not lock-protected, so concurrent first calls
    from different threads could each build an instance.
    """
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager