"""
Model Manager
=============
Manages model loading states and provides a robust interface for
ensuring models are loaded and validated before generation.
States:
- UNLOADED: No model loaded
- LOADING: Model is being loaded
- READY: Model loaded and validated
- ERROR: Model failed to load
"""
import logging
import threading
import time
from typing import Optional, Callable, Tuple
from enum import Enum
from PIL import Image
from .backend_router import BackendRouter, BackendType
from .character_service import CharacterSheetService
logger = logging.getLogger(__name__)
class ModelState(Enum):
    """Lifecycle states a managed model can be in."""

    UNLOADED = "unloaded"  # no model loaded
    LOADING = "loading"    # load in progress
    READY = "ready"        # loaded and validated, safe to generate
    ERROR = "error"        # load attempt failed
class ModelManager:
    """
    Manages model loading lifecycle with state tracking.

    Ensures models are fully loaded and validated before allowing generation.
    Provides progress callbacks for UI updates during loading.

    Thread-safety: all mutations of shared state (_state, _service,
    _loading_progress, ...) are guarded by ``self._lock``. Property reads are
    deliberately lock-free — each is a single attribute read, which is atomic
    in CPython. The progress callback is always invoked *outside* the lock so
    a callback that queries this manager cannot deadlock.
    """

    def __init__(self):
        self._state = ModelState.UNLOADED
        self._current_backend: Optional[BackendType] = None
        self._service: Optional[CharacterSheetService] = None
        self._error_message: Optional[str] = None
        self._loading_progress: float = 0.0
        self._loading_message: str = ""
        self._lock = threading.Lock()
        self._cancel_requested = False

    @property
    def state(self) -> ModelState:
        """Current model state."""
        return self._state

    @property
    def is_ready(self) -> bool:
        """Check if model is ready for generation."""
        return self._state == ModelState.READY

    @property
    def is_loading(self) -> bool:
        """Check if model is currently loading."""
        return self._state == ModelState.LOADING

    @property
    def error_message(self) -> Optional[str]:
        """Get error message if in error state."""
        return self._error_message

    @property
    def loading_progress(self) -> float:
        """Get loading progress (0.0 to 1.0)."""
        return self._loading_progress

    @property
    def loading_message(self) -> str:
        """Get current loading status message."""
        return self._loading_message

    @property
    def current_backend(self) -> Optional[BackendType]:
        """Get currently loaded backend."""
        return self._current_backend

    @property
    def service(self) -> Optional[CharacterSheetService]:
        """Get the character sheet service (only valid when ready)."""
        if self._state != ModelState.READY:
            return None
        return self._service

    def get_status_display(self) -> Tuple[str, str]:
        """
        Get status message and color for UI display.

        Returns:
            Tuple of (message, color) where color is a CSS color string
        """
        if self._state == ModelState.UNLOADED:
            return "No model loaded", "#888888"
        elif self._state == ModelState.LOADING:
            pct = int(self._loading_progress * 100)
            return f"Loading... {pct}% - {self._loading_message}", "#FFA500"
        elif self._state == ModelState.READY:
            backend_name = BackendRouter.BACKEND_NAMES.get(
                self._current_backend,
                str(self._current_backend)
            )
            return f"Ready: {backend_name}", "#00AA00"
        elif self._state == ModelState.ERROR:
            return f"Error: {self._error_message}", "#FF0000"
        return "Unknown state", "#888888"

    def request_cancel(self):
        """Request cancellation of current loading operation.

        The flag is polled between loading steps in load_model(); an
        in-flight step cannot be interrupted mid-way.
        """
        self._cancel_requested = True
        logger.info("Model loading cancellation requested")

    def load_model(
        self,
        backend: BackendType,
        api_key: Optional[str] = None,
        steps: int = 4,
        guidance: float = 1.0,
        progress_callback: Optional[Callable[[float, str], None]] = None
    ) -> bool:
        """
        Load a model with progress tracking.

        Args:
            backend: Backend type to load
            api_key: API key for cloud backends
            steps: Default steps for generation
            guidance: Default guidance scale
            progress_callback: Callback for progress updates (progress, message)

        Returns:
            True if model loaded successfully
        """
        with self._lock:
            if self._state == ModelState.LOADING:
                logger.warning("Model is already loading, ignoring request")
                return False
            self._state = ModelState.LOADING
            self._loading_progress = 0.0
            self._loading_message = "Initializing..."
            self._error_message = None
            self._cancel_requested = False

        def update_progress(progress: float, message: str):
            # Mutate shared state under the lock; fire the callback outside
            # it so a callback that reads this manager cannot deadlock.
            with self._lock:
                self._loading_progress = progress
                self._loading_message = message
            if progress_callback:
                progress_callback(progress, message)

        try:
            # Step 1: Unload previous model if different backend
            update_progress(0.05, "Checking current model...")
            if self._service and self._current_backend != backend:
                update_progress(0.1, "Unloading previous model...")
                try:
                    if hasattr(self._service, 'router'):
                        self._service.router.unload_local_models()
                except Exception as e:
                    # Best-effort: a failed unload should not block the new load.
                    logger.warning(f"Error unloading previous model: {e}")
                self._service = None
            if self._cancel_requested:
                # Roll back under the lock, consistent with other transitions.
                with self._lock:
                    self._state = ModelState.UNLOADED
                return False

            # Step 2: Create service and load model
            backend_name = BackendRouter.BACKEND_NAMES.get(backend, str(backend))
            update_progress(0.15, f"Loading {backend_name}...")
            logger.info(f"Creating CharacterSheetService for {backend.value}")
            # For local models, this will load the model
            # For cloud backends, this just validates the API key
            self._service = CharacterSheetService(
                api_key=api_key,
                backend=backend
            )
            if self._cancel_requested:
                with self._lock:
                    self._state = ModelState.UNLOADED
                    self._service = None
                return False
            update_progress(0.7, "Model loaded, configuring...")

            # Step 3: Configure default parameters
            if hasattr(self._service.client, 'default_steps'):
                self._service.client.default_steps = steps
            if hasattr(self._service.client, 'default_guidance'):
                self._service.client.default_guidance = guidance
            update_progress(0.8, "Validating model...")

            # Step 4: Validate model is actually working
            is_valid, error = self._validate_model()
            if not is_valid:
                raise RuntimeError(f"Model validation failed: {error}")
            update_progress(1.0, "Ready!")

            # Success!
            with self._lock:
                self._current_backend = backend
                self._state = ModelState.READY
                self._loading_progress = 1.0
                self._loading_message = "Ready"
            logger.info(f"Model {backend.value} loaded and validated successfully")
            return True
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Failed to load model {backend.value}: {error_msg}", exc_info=True)
            with self._lock:
                self._state = ModelState.ERROR
                self._error_message = self._simplify_error(error_msg)
                self._service = None
            if progress_callback:
                progress_callback(0.0, f"Error: {self._error_message}")
            return False

    def _validate_model(self) -> Tuple[bool, Optional[str]]:
        """
        Validate that the model is actually working.

        For local models, checks that the pipeline is loaded.
        For cloud backends, does a minimal health check.

        Returns:
            Tuple of (is_valid, error_message)
        """
        if self._service is None:
            return False, "Service not initialized"
        try:
            client = self._service.client
            # Check if client has health check method
            if hasattr(client, 'is_healthy'):
                if not client.is_healthy():
                    return False, "Client health check failed"
            # For local models, check pipeline is loaded
            if hasattr(client, '_loaded'):
                if not client._loaded:
                    return False, "Model pipeline not loaded"
            # For FLUX models, verify the pipe exists
            if hasattr(client, 'pipe'):
                if client.pipe is None:
                    return False, "Model pipeline is None"
            return True, None
        except Exception as e:
            return False, str(e)

    def _simplify_error(self, error: str) -> str:
        """Simplify technical error messages for user display."""
        error_lower = error.lower()
        if "cuda out of memory" in error_lower or "out of memory" in error_lower:
            return "Not enough GPU memory. Try a smaller model or close other applications."
        if "api key" in error_lower:
            return "Invalid or missing API key."
        if "connection" in error_lower or "network" in error_lower:
            return "Network connection error. Check your internet connection."
        if "not found" in error_lower and "model" in error_lower:
            return "Model files not found. The model may need to be downloaded."
        if "import" in error_lower:
            return "Missing dependencies. Some required packages are not installed."
        if "meta tensor" in error_lower:
            return "Model loading failed (meta tensor error). Try restarting the application."
        # Truncate long errors
        if len(error) > 100:
            return error[:97] + "..."
        return error

    def unload(self):
        """Unload the current model and reset all state to UNLOADED."""
        with self._lock:
            if self._service:
                try:
                    if hasattr(self._service, 'router'):
                        self._service.router.unload_local_models()
                except Exception as e:
                    logger.warning(f"Error during unload: {e}")
            self._service = None
            self._state = ModelState.UNLOADED
            self._current_backend = None
            self._error_message = None
            self._loading_progress = 0.0
            self._loading_message = ""
            # Clear any stale cancel request so it cannot leak into later loads.
            self._cancel_requested = False
        logger.info("Model unloaded")
# Global singleton for model management
_model_manager: Optional[ModelManager] = None
# Guards lazy construction of the singleton; without it two threads racing
# past the None check could each build (and use) a separate ModelManager.
_model_manager_lock = threading.Lock()


def get_model_manager() -> ModelManager:
    """Get the global ModelManager instance (thread-safe lazy init)."""
    global _model_manager
    if _model_manager is None:  # fast path: already built, skip the lock
        with _model_manager_lock:
            # Double-checked: another thread may have built it meanwhile.
            if _model_manager is None:
                _model_manager = ModelManager()
    return _model_manager