# CharacterForgePro / src / model_manager.py
# ghmk — "Deploy full Character Sheet Pro with HF auth" (commit da23dfe)
"""
Model Manager
=============
Manages model loading states and provides a robust interface for
ensuring models are loaded and validated before generation.
States:
- UNLOADED: No model loaded
- LOADING: Model is being loaded
- READY: Model loaded and validated
- ERROR: Model failed to load
"""
import logging
import threading
import time
from typing import Optional, Callable, Tuple
from enum import Enum
from PIL import Image
from .backend_router import BackendRouter, BackendType
from .character_service import CharacterSheetService
logger = logging.getLogger(__name__)
class ModelState(Enum):
    """Lifecycle states a managed model can be in."""
    UNLOADED = "unloaded"  # no model loaded yet, or explicitly unloaded
    LOADING = "loading"    # a load is currently in progress
    READY = "ready"        # loaded and validated; generation is allowed
    ERROR = "error"        # the most recent load attempt failed
class ModelManager:
    """
    Manages model loading lifecycle with state tracking.

    Ensures models are fully loaded and validated before allowing generation.
    Provides progress callbacks for UI updates during loading.

    Concurrency: all state transitions happen under ``self._lock``. The lock
    is deliberately NOT held during the slow model load itself, so UI threads
    can keep polling the status properties. The lock is non-reentrant — never
    acquire it while already holding it.
    """

    def __init__(self):
        self._state = ModelState.UNLOADED
        self._current_backend: Optional[BackendType] = None
        self._service: Optional[CharacterSheetService] = None
        self._error_message: Optional[str] = None
        self._loading_progress: float = 0.0  # fraction in [0.0, 1.0]
        self._loading_message: str = ""
        self._lock = threading.Lock()
        self._cancel_requested = False  # polled cooperatively by load_model()

    @property
    def state(self) -> ModelState:
        """Current model state."""
        return self._state

    @property
    def is_ready(self) -> bool:
        """Check if model is ready for generation."""
        return self._state == ModelState.READY

    @property
    def is_loading(self) -> bool:
        """Check if model is currently loading."""
        return self._state == ModelState.LOADING

    @property
    def error_message(self) -> Optional[str]:
        """Get error message if in error state."""
        return self._error_message

    @property
    def loading_progress(self) -> float:
        """Get loading progress (0.0 to 1.0)."""
        return self._loading_progress

    @property
    def loading_message(self) -> str:
        """Get current loading status message."""
        return self._loading_message

    @property
    def current_backend(self) -> Optional[BackendType]:
        """Get currently loaded backend."""
        return self._current_backend

    @property
    def service(self) -> Optional[CharacterSheetService]:
        """Get the character sheet service (only valid when ready)."""
        if self._state != ModelState.READY:
            return None
        return self._service

    def get_status_display(self) -> Tuple[str, str]:
        """
        Get status message and color for UI display.

        Returns:
            Tuple of (message, color) where color is a CSS color string
        """
        if self._state == ModelState.UNLOADED:
            return "No model loaded", "#888888"
        elif self._state == ModelState.LOADING:
            pct = int(self._loading_progress * 100)
            return f"Loading... {pct}% - {self._loading_message}", "#FFA500"
        elif self._state == ModelState.READY:
            backend_name = BackendRouter.BACKEND_NAMES.get(
                self._current_backend,
                str(self._current_backend)
            )
            return f"Ready: {backend_name}", "#00AA00"
        elif self._state == ModelState.ERROR:
            return f"Error: {self._error_message}", "#FF0000"
        return "Unknown state", "#888888"

    def request_cancel(self):
        """Request cancellation of current loading operation."""
        self._cancel_requested = True
        logger.info("Model loading cancellation requested")

    def _release_local_models(self, service: CharacterSheetService, log_prefix: str) -> None:
        """Best-effort release of any local models held by *service*.

        Failures are logged (prefixed with *log_prefix*) but never propagated,
        so a broken unload cannot block loading a new model.
        """
        try:
            if hasattr(service, 'router'):
                service.router.unload_local_models()
        except Exception as e:
            logger.warning(f"{log_prefix}: {e}")

    def load_model(
        self,
        backend: BackendType,
        api_key: Optional[str] = None,
        steps: int = 4,
        guidance: float = 1.0,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> bool:
        """
        Load a model with progress tracking.

        Args:
            backend: Backend type to load
            api_key: API key for cloud backends
            steps: Default steps for generation
            guidance: Default guidance scale
            progress_callback: Callback for progress updates (progress, message)

        Returns:
            True if model loaded successfully; False on error or cancellation.
        """
        with self._lock:
            if self._state == ModelState.LOADING:
                logger.warning("Model is already loading, ignoring request")
                return False
            # Reset the bookkeeping for this attempt while holding the lock.
            self._state = ModelState.LOADING
            self._loading_progress = 0.0
            self._loading_message = "Initializing..."
            self._error_message = None
            self._cancel_requested = False

        def update_progress(progress: float, message: str):
            # Mirror progress into instance state so the status properties
            # reflect it even when no callback was supplied.
            self._loading_progress = progress
            self._loading_message = message
            if progress_callback:
                progress_callback(progress, message)

        try:
            # Step 1: Unload previous model if different backend
            update_progress(0.05, "Checking current model...")
            if self._service and self._current_backend != backend:
                update_progress(0.1, "Unloading previous model...")
                self._release_local_models(self._service, "Error unloading previous model")
                self._service = None
            if self._cancel_requested:
                # Fix: transition under the lock, consistent with all other
                # state changes.
                with self._lock:
                    self._state = ModelState.UNLOADED
                return False

            # Step 2: Create service and load model
            backend_name = BackendRouter.BACKEND_NAMES.get(backend, str(backend))
            update_progress(0.15, f"Loading {backend_name}...")
            logger.info(f"Creating CharacterSheetService for {backend.value}")
            # For local models, this will load the model
            # For cloud backends, this just validates the API key
            self._service = CharacterSheetService(
                api_key=api_key,
                backend=backend
            )
            if self._cancel_requested:
                # Fix: release the freshly-loaded local model instead of
                # leaking it (GPU memory) when cancelling after the load.
                self._release_local_models(self._service, "Error during unload")
                self._service = None
                with self._lock:
                    self._state = ModelState.UNLOADED
                return False

            update_progress(0.7, "Model loaded, configuring...")
            # Step 3: Configure default parameters where the client exposes
            # them (cloud clients may not have these attributes).
            if hasattr(self._service.client, 'default_steps'):
                self._service.client.default_steps = steps
            if hasattr(self._service.client, 'default_guidance'):
                self._service.client.default_guidance = guidance

            update_progress(0.8, "Validating model...")
            # Step 4: Validate model is actually working
            is_valid, error = self._validate_model()
            if not is_valid:
                raise RuntimeError(f"Model validation failed: {error}")

            update_progress(1.0, "Ready!")
            # Success!
            with self._lock:
                self._current_backend = backend
                self._state = ModelState.READY
                self._loading_progress = 1.0
                self._loading_message = "Ready"
            logger.info(f"Model {backend.value} loaded and validated successfully")
            return True
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load model {backend.value}: {error_msg}", exc_info=True)
            with self._lock:
                self._state = ModelState.ERROR
                self._error_message = self._simplify_error(error_msg)
                self._service = None
            if progress_callback:
                progress_callback(0.0, f"Error: {self._error_message}")
            return False

    def _validate_model(self) -> Tuple[bool, Optional[str]]:
        """
        Validate that the model is actually working.

        For local models, checks that the pipeline is loaded.
        For cloud backends, does a minimal health check.

        Returns:
            Tuple of (is_valid, error_message)
        """
        if self._service is None:
            return False, "Service not initialized"
        try:
            client = self._service.client
            # Check if client has health check method
            if hasattr(client, 'is_healthy'):
                if not client.is_healthy():
                    return False, "Client health check failed"
            # For local models, check pipeline is loaded
            if hasattr(client, '_loaded'):
                if not client._loaded:
                    return False, "Model pipeline not loaded"
            # For FLUX models, verify the pipe exists
            if hasattr(client, 'pipe'):
                if client.pipe is None:
                    return False, "Model pipeline is None"
            return True, None
        except Exception as e:
            return False, str(e)

    def _simplify_error(self, error: str) -> str:
        """Simplify technical error messages for user display.

        Maps known failure signatures to actionable messages and truncates
        anything else that would overflow the status display.
        """
        error_lower = error.lower()
        if "cuda out of memory" in error_lower or "out of memory" in error_lower:
            return "Not enough GPU memory. Try a smaller model or close other applications."
        if "api key" in error_lower:
            return "Invalid or missing API key."
        if "connection" in error_lower or "network" in error_lower:
            return "Network connection error. Check your internet connection."
        if "not found" in error_lower and "model" in error_lower:
            return "Model files not found. The model may need to be downloaded."
        if "import" in error_lower:
            return "Missing dependencies. Some required packages are not installed."
        if "meta tensor" in error_lower:
            return "Model loading failed (meta tensor error). Try restarting the application."
        # Truncate long errors
        if len(error) > 100:
            return error[:97] + "..."
        return error

    def unload(self):
        """Unload the current model and reset all state to UNLOADED."""
        with self._lock:
            if self._service:
                self._release_local_models(self._service, "Error during unload")
                self._service = None
            self._state = ModelState.UNLOADED
            self._current_backend = None
            self._error_message = None
            self._loading_progress = 0.0
            self._loading_message = ""
            logger.info("Model unloaded")
# Global singleton for model management
_model_manager: Optional[ModelManager] = None
# Guards lazy creation of the singleton; this module is used across threads
# (ModelManager itself holds a threading.Lock), so creation must be too.
_model_manager_lock = threading.Lock()


def get_model_manager() -> ModelManager:
    """Get the global ModelManager instance.

    Thread-safe lazy initialization: concurrent first calls are serialized
    by ``_model_manager_lock`` so exactly one instance is ever created.
    """
    global _model_manager
    if _model_manager is None:  # fast path once initialized: no locking
        with _model_manager_lock:
            if _model_manager is None:  # re-check under the lock
                _model_manager = ModelManager()
    return _model_manager