| """ | |
| Memory management module for SeedVR2 | |
| Handles VRAM usage, cache management, and memory optimization | |
| Extracted from: seedvr2.py (lines 373-405, 607-626, 1016-1044) | |
| """ | |
| import torch | |
| import gc | |
| import sys | |
| import time | |
| import psutil | |
| import platform | |
| from typing import Tuple, Dict, Any, Optional, List, Union | |
| def _device_str(device: Union[torch.device, str]) -> str: | |
| """Normalized uppercase device string for comparison and logging. MPS variants → 'MPS'.""" | |
| s = str(device).upper() | |
| return 'MPS' if s.startswith('MPS') else s | |
| def is_mps_available() -> bool: | |
| """Check if MPS (Apple Metal) backend is available.""" | |
| return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() | |
| def is_cuda_available() -> bool: | |
| """Check if CUDA backend is available.""" | |
| return torch.cuda.is_available() | |
| def get_gpu_backend() -> str: | |
| """Get the active GPU backend type. | |
| Returns: | |
| 'cuda': NVIDIA CUDA | |
| 'mps': Apple Metal Performance Shaders | |
| 'cpu': No GPU backend available | |
| """ | |
| if is_cuda_available(): | |
| return 'cuda' | |
| if is_mps_available(): | |
| return 'mps' | |
| return 'cpu' | |
| def get_device_list(include_none: bool = False, include_cpu: bool = False) -> List[str]: | |
| """ | |
| Get list of available compute devices for SeedVR2 | |
| Args: | |
| include_none: If True, prepend "none" to the device list (for offload options) | |
| include_cpu: If True, include "cpu" in the device list (for offload options only) | |
| Note: On MPS-only systems, "cpu" is automatically excluded since | |
| unified memory architecture makes CPU offloading meaningless | |
| Returns: | |
| List of device strings (e.g., ["cuda:0", "cuda:1"] or ["none", "cpu", "cuda:0", "cuda:1"]) | |
| """ | |
| devs = [] | |
| has_cuda = False | |
| has_mps = False | |
| try: | |
| if is_cuda_available(): | |
| devs += [f"cuda:{i}" for i in range(torch.cuda.device_count())] | |
| has_cuda = True | |
| except Exception: | |
| pass | |
| try: | |
| if is_mps_available(): | |
| devs.append("mps") # MPS doesn't use device indices | |
| has_mps = True | |
| except Exception: | |
| pass | |
| # Build result list with optional prefixes | |
| result = [] | |
| if include_none: | |
| result.append("none") | |
| # Only include "cpu" option if: | |
| # 1. It was requested (include_cpu=True), AND | |
| # 2. Either CUDA is available OR MPS is not the only option | |
| # Rationale: On MPS-only systems with unified memory architecture, | |
| # CPU offloading is semantically meaningless as CPU and GPU share the same memory pool | |
| if include_cpu and (has_cuda or not has_mps): | |
| result.append("cpu") | |
| result.extend(devs) | |
    return result
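# Illustrative usage sketch (commented out so module import stays side-effect free).
# It assumes the caller just wants the first non-offload device and falls back to CPU;
# the variable names below are hypothetical, not part of this module's API.
#
#   devices = get_device_list(include_none=True, include_cpu=True)
#   main_device = next((d for d in devices if d not in ("none", "cpu")), "cpu")
#   print(f"Selected compute device: {main_device}")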
| def get_basic_vram_info(device: Optional[torch.device] = None) -> Dict[str, Any]: | |
| """ | |
| Get basic VRAM availability info (free and total memory). | |
| Used for capacity planning and initial checks. | |
| Args: | |
| device: Optional device to query. If None, uses cuda:0 | |
| Returns: | |
| dict: {"free_gb": float, "total_gb": float} or {"error": str} | |
| """ | |
| try: | |
| if is_cuda_available(): | |
| if device is None: | |
| device = torch.device("cuda:0") | |
| elif not isinstance(device, torch.device): | |
| device = torch.device(device) | |
| free_memory, total_memory = torch.cuda.mem_get_info(device) | |
| elif is_mps_available(): | |
| # MPS doesn't support per-device queries or mem_get_info | |
| # Use system memory as proxy | |
| mem = psutil.virtual_memory() | |
| free_memory = mem.total - mem.used | |
| total_memory = mem.total | |
| else: | |
| return {"error": "No GPU backend available (CUDA/MPS)"} | |
| return { | |
| "free_gb": free_memory / (1024**3), | |
| "total_gb": total_memory / (1024**3) | |
| } | |
| except Exception as e: | |
| return {"error": f"Failed to get memory info: {str(e)}"} | |
| # Initial VRAM check at module load | |
| vram_info = get_basic_vram_info(device=None) | |
| if "error" not in vram_info: | |
| backend = "MPS" if is_mps_available() else "CUDA" | |
| print(f"📊 Initial {backend} memory: {vram_info['free_gb']:.2f}GB free / {vram_info['total_gb']:.2f}GB total") | |
| else: | |
| print(f"⚠️ Memory check failed: {vram_info['error']} - No available backend!") | |
| def get_vram_usage(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]: | |
| """ | |
| Get current VRAM usage metrics for monitoring. | |
| Used for tracking memory consumption during processing. | |
| Args: | |
| device: Optional device to query. If None, uses cuda:0 | |
| debug: Optional debug instance for logging | |
| Returns: | |
| tuple: (allocated_gb, reserved_gb, peak_allocated_gb, peak_reserved_gb) | |
| Returns (0, 0, 0, 0) if no GPU available | |
| """ | |
| try: | |
| if is_cuda_available(): | |
| if device is None: | |
| device = torch.device("cuda:0") | |
| elif not isinstance(device, torch.device): | |
| device = torch.device(device) | |
| allocated = torch.cuda.memory_allocated(device) / (1024**3) | |
| reserved = torch.cuda.memory_reserved(device) / (1024**3) | |
| peak_allocated = torch.cuda.max_memory_allocated(device) / (1024**3) | |
| peak_reserved = torch.cuda.max_memory_reserved(device) / (1024**3) | |
| return allocated, reserved, peak_allocated, peak_reserved | |
| elif is_mps_available(): | |
| # MPS doesn't support per-device queries - uses global memory tracking | |
| allocated = torch.mps.current_allocated_memory() / (1024**3) | |
| reserved = torch.mps.driver_allocated_memory() / (1024**3) | |
| # MPS doesn't track peak separately | |
| return allocated, reserved, allocated, reserved | |
| except Exception as e: | |
| if debug: | |
| debug.log(f"Failed to get VRAM usage: {e}", level="WARNING", category="memory", force=True) | |
| return 0.0, 0.0, 0.0, 0.0 | |
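# Illustrative monitoring sketch (commented out; assumes a CUDA device and no Debug
# instance). Shows how the four returned values might be formatted into a log line.
#
#   alloc, reserved, peak_alloc, peak_reserved = get_vram_usage(torch.device("cuda:0"))
#   print(f"VRAM allocated {alloc:.2f}GB (peak {peak_alloc:.2f}GB), "
#         f"reserved {reserved:.2f}GB (peak {peak_reserved:.2f}GB)")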
| def get_ram_usage(debug: Optional['Debug'] = None) -> Tuple[float, float, float, float]: | |
| """ | |
| Get current RAM usage metrics for the current process. | |
| Provides accurate tracking of process-specific memory consumption. | |
| Args: | |
| debug: Optional debug instance for logging | |
| Returns: | |
| tuple: (process_gb, available_gb, total_gb, used_by_others_gb) | |
| Returns (0, 0, 0, 0) if psutil not available or on error | |
| """ | |
| try: | |
| if not psutil: | |
| return 0.0, 0.0, 0.0, 0.0 | |
| # Get current process memory | |
| process = psutil.Process() | |
| process_memory = process.memory_info() | |
| process_gb = process_memory.rss / (1024**3) | |
| # Get system memory | |
| sys_memory = psutil.virtual_memory() | |
| total_gb = sys_memory.total / (1024**3) | |
| available_gb = sys_memory.available / (1024**3) | |
        # Memory used by other processes = total memory in use minus this process
| total_used_gb = total_gb - available_gb # Total memory used by ALL processes | |
| used_by_others_gb = max(0, total_used_gb - process_gb) # Subtract current process | |
| return process_gb, available_gb, total_gb, used_by_others_gb | |
| except Exception as e: | |
| if debug: | |
| debug.log(f"Failed to get RAM usage: {e}", level="WARNING", category="memory", force=True) | |
| return 0.0, 0.0, 0.0, 0.0 | |
| # Global cache for OS libraries (initialized once) | |
| _os_memory_lib = None | |
| def clear_memory(debug: Optional['Debug'] = None, deep: bool = False, force: bool = True, | |
| timer_name: Optional[str] = None) -> None: | |
| """ | |
| Clear memory caches with two-tier approach for optimal performance. | |
| Args: | |
| debug: Debug instance for logging (optional) | |
        deep: If True, perform deep cleanup including GC and OS operations.
              If False (default), only perform minimal GPU cache clearing.
        force: If True, always clear. If False, only clear when <5% of memory is free
| timer_name: Optional suffix for timer names to make them unique per invocation | |
| Two-tier approach: | |
| - Minimal mode (deep=False): GPU cache operations (~1-5ms) | |
| Used for frequent calls during batch processing | |
| - Deep mode (deep=True): Complete cleanup with GC and OS operations (~10-50ms) | |
| Used at key points like model switches or final cleanup | |
| """ | |
| global _os_memory_lib | |
| # Create unique timer names if suffix provided | |
| if timer_name: | |
| main_timer = f"memory_clear_{timer_name}" | |
| gpu_timer = f"gpu_cache_clear_{timer_name}" | |
| gc_timer = f"garbage_collection_{timer_name}" | |
| os_timer = f"os_memory_release_{timer_name}" | |
| completion_msg = f"clear_memory() completion ({timer_name})" | |
| else: | |
| main_timer = "memory_clear" | |
| gpu_timer = "gpu_cache_clear" | |
| gc_timer = "garbage_collection" | |
| os_timer = "os_memory_release" | |
| completion_msg = "clear_memory() completion" | |
| # Start timer for entire operation | |
| if debug: | |
| debug.start_timer(main_timer) | |
| # Check if we should clear based on memory pressure | |
| if not force: | |
| should_clear = False | |
| # Use existing function for memory info | |
| mem_info = get_basic_vram_info(device=None) | |
| if "error" not in mem_info and mem_info["total_gb"] > 0: | |
| # Check VRAM/MPS memory pressure (5% free threshold) | |
| free_ratio = mem_info["free_gb"] / mem_info["total_gb"] | |
| if free_ratio < 0.05: | |
| should_clear = True | |
| if debug: | |
| backend = "Unified Memory" if is_mps_available() else "VRAM" | |
| debug.log(f"{backend} pressure: {mem_info['free_gb']:.2f}GB free of {mem_info['total_gb']:.2f}GB", category="memory") | |
| # For non-MPS systems, also check system RAM separately | |
| if not should_clear and not is_mps_available(): | |
| mem = psutil.virtual_memory() | |
| if mem.available < mem.total * 0.05: | |
| should_clear = True | |
| if debug: | |
| debug.log(f"RAM pressure: {mem.available/(1024**3):.2f}GB free of {mem.total/(1024**3):.2f}GB", category="memory") | |
| if not should_clear: | |
| # End timer before early return to keep stack clean | |
| if debug: | |
| debug.end_timer(main_timer) | |
| return | |
| # Determine cleanup level | |
| cleanup_mode = "deep" if deep else "minimal" | |
| if debug: | |
| debug.log(f"Clearing memory caches ({cleanup_mode})...", category="cleanup") | |
| # ===== MINIMAL OPERATIONS (Always performed) ===== | |
| # Step 1: Clear GPU caches - Fast operations (~1-5ms) | |
| if debug: | |
| debug.start_timer(gpu_timer) | |
| if is_cuda_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.ipc_collect() | |
| elif is_mps_available(): | |
| torch.mps.empty_cache() | |
| if debug: | |
| debug.end_timer(gpu_timer, "GPU cache clearing") | |
| # ===== DEEP OPERATIONS (Only when deep=True) ===== | |
| if deep: | |
| # Step 2: Deep garbage collection (expensive ~5-20ms) | |
| if debug: | |
| debug.start_timer(gc_timer) | |
| gc.collect(2) | |
| if debug: | |
| debug.end_timer(gc_timer, "Garbage collection") | |
| # Step 3: Return memory to OS (platform-specific, ~5-30ms) | |
| if debug: | |
| debug.start_timer(os_timer) | |
| try: | |
| if sys.platform == 'linux': | |
| # Linux: malloc_trim | |
| import ctypes # Import only when needed | |
| if _os_memory_lib is None: | |
| _os_memory_lib = ctypes.CDLL("libc.so.6") | |
| _os_memory_lib.malloc_trim(0) | |
| elif sys.platform == 'win32': | |
| # Windows: Trim working set | |
| import ctypes # Import only when needed | |
| if _os_memory_lib is None: | |
| _os_memory_lib = ctypes.windll.kernel32 | |
| handle = _os_memory_lib.GetCurrentProcess() | |
| _os_memory_lib.SetProcessWorkingSetSize(handle, -1, -1) | |
| elif is_mps_available(): | |
| # macOS with MPS | |
| import ctypes # Import only when needed | |
| import ctypes.util | |
| if _os_memory_lib is None: | |
| libc_path = ctypes.util.find_library('c') | |
| if libc_path: | |
| _os_memory_lib = ctypes.CDLL(libc_path) | |
| if _os_memory_lib: | |
| _os_memory_lib.sync() | |
| except Exception as e: | |
| if debug: | |
| debug.log(f"Failed to perform OS memory operations: {e}", level="WARNING", category="memory", force=True) | |
| if debug: | |
| debug.end_timer(os_timer, "OS memory release") | |
| # End overall timer | |
| if debug: | |
| debug.end_timer(main_timer, completion_msg) | |
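# Illustrative two-tier usage sketch (commented out). The loop and batch names are
# hypothetical placeholders; only clear_memory() itself comes from this module.
#
#   for i, batch in enumerate(batches):            # frequent, cheap cache clears
#       process_batch(batch)                        # hypothetical processing step
#       clear_memory(deep=False, force=False, timer_name=f"batch_{i}")
#   clear_memory(deep=True, force=True)             # one deep cleanup at the end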
| def retry_on_oom(func, *args, debug=None, operation_name="operation", **kwargs): | |
| """ | |
| Execute function with single OOM retry after memory cleanup. | |
| Args: | |
| func: Callable to execute | |
| *args: Positional arguments for func | |
| debug: Debug instance for logging (optional) | |
| operation_name: Name for logging | |
| **kwargs: Keyword arguments for func | |
| Returns: | |
| Result of func(*args, **kwargs) | |
| """ | |
| try: | |
| return func(*args, **kwargs) | |
| except (torch.cuda.OutOfMemoryError, RuntimeError) as e: | |
| # Only handle OOM errors | |
| if not any(x in str(e).lower() for x in ["out of memory", "allocation on device"]): | |
| raise | |
        if debug:
            debug.log(f"OOM during {operation_name}: {e}", level="WARNING", category="memory", force=True)
            debug.log("Clearing memory and retrying", category="info", force=True)
        # Clear memory and let it settle before the retry
        clear_memory(debug=debug, deep=True, force=True, timer_name=operation_name)
        time.sleep(0.5)
        if debug:
            debug.log_memory_state("After memory clearing", show_tensors=False, detailed_tensors=False)
| # Single retry | |
| try: | |
| result = func(*args, **kwargs) | |
| if debug: | |
| debug.log(f"Retry successful for {operation_name}", category="success", force=True) | |
| return result | |
| except Exception as retry_e: | |
| if debug: | |
| debug.log(f"Retry failed for {operation_name}: {retry_e}", level="ERROR", category="memory", force=True) | |
| raise | |
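# Illustrative usage sketch (commented out). `vae.decode` and `latent` are hypothetical
# stand-ins for whatever memory-hungry callable the caller wants a single OOM retry on.
#
#   samples = retry_on_oom(vae.decode, latent, debug=debug, operation_name="VAE decode")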
| def reset_vram_peak(device: Optional[torch.device] = None, debug: Optional['Debug'] = None) -> None: | |
| """ | |
| Reset VRAM peak memory statistics for fresh tracking. | |
| Args: | |
| device: Optional device to reset stats for. If None, uses cuda:0 | |
| debug: Optional debug instance for logging | |
| """ | |
| if debug and debug.enabled: | |
| debug.log("Resetting VRAM peak memory statistics", category="memory") | |
| try: | |
| if is_cuda_available(): | |
| if device is None: | |
| device = torch.device("cuda:0") | |
| elif not isinstance(device, torch.device): | |
| device = torch.device(device) | |
| torch.cuda.reset_peak_memory_stats(device) | |
| # Note: MPS doesn't support peak memory reset - no action needed | |
| except Exception as e: | |
| if debug and debug.enabled: | |
| debug.log(f"Failed to reset peak memory stats: {e}", level="WARNING", category="memory", force=True) | |
| def clear_rope_lru_caches(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> int: | |
| """ | |
| Clear ALL LRU caches from RoPE modules. | |
| Args: | |
| model: PyTorch model to clear caches from | |
| debug: Optional debug instance for logging | |
| Returns: | |
| Number of caches cleared | |
| """ | |
| if model is None: | |
| return 0 | |
| cleared_count = 0 | |
| try: | |
| for name, module in model.named_modules(): | |
| if hasattr(module, 'get_axial_freqs') and hasattr(module.get_axial_freqs, 'cache_clear'): | |
| try: | |
| module.get_axial_freqs.cache_clear() | |
| cleared_count += 1 | |
| except Exception as e: | |
| if debug: | |
| debug.log(f"Failed to clear RoPE LRU cache for module {name}: {e}", level="WARNING", category="memory", force=True) | |
| except (AttributeError, RuntimeError) as e: | |
| if debug: | |
| debug.log(f"Failed to iterate model modules for RoPE LRU cache clearing: {e}", level="WARNING", category="memory", force=True) | |
| return cleared_count | |
| def release_tensor_memory(tensor: Optional[torch.Tensor]) -> None: | |
| """Release tensor memory from any device (CPU/CUDA/MPS)""" | |
| if tensor is not None and torch.is_tensor(tensor): | |
| # Release storage for all devices (CPU, CUDA, MPS) | |
| if tensor.numel() > 0: | |
| tensor.data.set_() | |
| tensor.grad = None | |
| def release_tensor_collection(collection: Any, recursive: bool = True) -> None: | |
| """ | |
| Release GPU memory from tensors in any collection (list, tuple, dict, or single tensor). | |
| Args: | |
| collection: Tensor, list, tuple, dict, or nested structure to release | |
| recursive: If True, handle nested structures recursively | |
| Examples: | |
| release_tensor_collection(tensor) # Single tensor | |
| release_tensor_collection([tensor1, tensor2]) # List of tensors | |
| release_tensor_collection([[t1, t2], [t3, t4]]) # Nested lists | |
| release_tensor_collection({'a': tensor}) # Dict values | |
| """ | |
| if collection is None: | |
| return | |
| if torch.is_tensor(collection): | |
| release_tensor_memory(collection) | |
| elif isinstance(collection, dict): | |
| for value in collection.values(): | |
| if recursive: | |
| release_tensor_collection(value, recursive=True) | |
| elif torch.is_tensor(value): | |
| release_tensor_memory(value) | |
| elif isinstance(collection, (list, tuple)): | |
| for item in collection: | |
| if recursive: | |
| release_tensor_collection(item, recursive=True) | |
| elif torch.is_tensor(item): | |
| release_tensor_memory(item) | |
| def release_text_embeddings(*embeddings: torch.Tensor, debug: Optional['Debug'] = None, names: Optional[List[str]] = None) -> None: | |
| """ | |
| Release memory for text embeddings | |
| Args: | |
| *embeddings: Variable number of embedding tensors to release | |
| debug: Optional debug instance for logging | |
| names: Optional list of names for logging | |
| """ | |
| for i, embedding in enumerate(embeddings): | |
| if embedding is not None: | |
| release_tensor_memory(embedding) | |
| if debug and names and i < len(names): | |
| debug.log(f"Cleaned up {names[i]}", category="cleanup") | |
| def cleanup_text_embeddings(ctx: Dict[str, Any], debug: Optional['Debug'] = None) -> None: | |
| """ | |
| Clean up text embeddings from a context dictionary. | |
| Extracts embeddings, releases memory, and clears the context entry. | |
| Args: | |
| ctx: Context dictionary potentially containing 'text_embeds' | |
| debug: Optional debug instance for logging | |
| """ | |
| if not ctx or not ctx.get('text_embeds'): | |
| return | |
| embeddings = [] | |
| names = [] | |
| for key, embeds_list in ctx['text_embeds'].items(): | |
| if embeds_list: | |
| embeddings.extend(embeds_list) | |
| names.append(key) | |
| if embeddings: | |
        release_text_embeddings(*embeddings, debug=debug)
| if debug: | |
| debug.log(f"Cleaned up text embeddings: {', '.join(names)}", category="cleanup") | |
| ctx['text_embeds'] = None | |
| def release_model_memory(model: Optional[torch.nn.Module], debug: Optional['Debug'] = None) -> None: | |
| """ | |
| Release all GPU/MPS memory from model in-place without CPU transfer. | |
| Args: | |
| model: PyTorch model to release memory from | |
| debug: Optional debug instance for logging | |
| """ | |
| if model is None: | |
| return | |
| try: | |
| # Clear gradients first | |
| model.zero_grad(set_to_none=True) | |
| # Release GPU memory directly without CPU transfer | |
| released_params = 0 | |
| released_buffers = 0 | |
| for param in model.parameters(): | |
| if param.is_cuda or param.is_mps: | |
| if param.numel() > 0: | |
| param.data.set_() | |
| released_params += 1 | |
| param.grad = None | |
| for buffer in model.buffers(): | |
| if buffer.is_cuda or buffer.is_mps: | |
| if buffer.numel() > 0: | |
| buffer.data.set_() | |
| released_buffers += 1 | |
| if debug and (released_params > 0 or released_buffers > 0): | |
| debug.log(f"Released memory from {released_params} params and {released_buffers} buffers", category="success") | |
| except (AttributeError, RuntimeError) as e: | |
| if debug: | |
| debug.log(f"Failed to release model memory: {e}", level="WARNING", category="memory", force=True) | |
| def manage_tensor( | |
| tensor: torch.Tensor, | |
| target_device: torch.device, | |
| tensor_name: str = "tensor", | |
| dtype: Optional[torch.dtype] = None, | |
| non_blocking: bool = False, | |
| debug: Optional['Debug'] = None, | |
| reason: Optional[str] = None, | |
| indent_level: int = 0 | |
| ) -> torch.Tensor: | |
| """ | |
| Unified tensor management for device movement and dtype conversion. | |
| Handles both device transfers (CPU ↔ GPU) and dtype conversions (e.g., float16 → bfloat16) | |
| with intelligent early-exit optimization and comprehensive logging. | |
| Args: | |
| tensor: Tensor to manage | |
| target_device: Target device (torch.device object) | |
| tensor_name: Descriptive name for logging (e.g., "latent", "sample", "alpha_channel") | |
| dtype: Optional target dtype to cast to (if None, keeps original dtype) | |
| non_blocking: Whether to use non-blocking transfer | |
| debug: Debug instance for logging | |
| reason: Optional reason for the operation (e.g., "inference", "offload", "dtype alignment") | |
| indent_level: Indentation level for debug logging (0=no indent, 1=2 spaces, etc.) | |
| Returns: | |
| Tensor on target device with optional dtype conversion | |
| Note: | |
| - Skips operation if tensor already has target device and dtype (zero-copy) | |
| - Uses PyTorch's optimized .to() for efficient device/dtype handling | |
| - Logs all operations consistently for tracking and debugging | |
| """ | |
| if tensor is None: | |
| return tensor | |
| # Get current state | |
| current_device = tensor.device | |
| current_dtype = tensor.dtype | |
| target_dtype = dtype if dtype is not None else current_dtype | |
| # Check if movement is actually needed | |
| needs_device_move = _device_str(current_device) != _device_str(target_device) | |
| needs_dtype_change = dtype is not None and current_dtype != target_dtype | |
| if not needs_device_move and not needs_dtype_change: | |
| # Already on target device and dtype - skip | |
| return tensor | |
| # Determine reason for movement | |
| if reason is None: | |
| if needs_device_move and needs_dtype_change: | |
| reason = "device and dtype conversion" | |
| elif needs_device_move: | |
| reason = "device movement" | |
| else: | |
| reason = "dtype conversion" | |
| # Log the movement | |
| if debug: | |
| current_device_str = _device_str(current_device) | |
| target_device_str = _device_str(target_device) | |
| dtype_info = "" | |
| if needs_dtype_change: | |
| dtype_info = f", {current_dtype} → {target_dtype}" | |
| debug.log( | |
| f"Moving {tensor_name} from {current_device_str} to {target_device_str}{dtype_info} ({reason})", | |
| category="general", | |
| indent_level=indent_level | |
| ) | |
| # Perform the operation based on what needs to change | |
| if needs_device_move and needs_dtype_change: | |
| # Both device and dtype need to change | |
| return tensor.to(target_device, dtype=target_dtype, non_blocking=non_blocking) | |
| elif needs_device_move: | |
| # Only device needs to change | |
| return tensor.to(target_device, non_blocking=non_blocking) | |
| else: | |
| # Only dtype needs to change | |
| return tensor.to(dtype=target_dtype) | |
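# Illustrative usage sketch (commented out). The tensor and device below are hypothetical;
# the call is a no-op (returns the same tensor) when device and dtype already match.
#
#   latent = manage_tensor(
#       latent, torch.device("cuda:0"), tensor_name="latent",
#       dtype=torch.bfloat16, reason="inference",
#   )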
| def manage_model_device(model: torch.nn.Module, target_device: torch.device, model_name: str, | |
| debug: Optional['Debug'] = None, reason: Optional[str] = None, | |
| runner: Optional[Any] = None) -> bool: | |
| """ | |
| Move model to target device with optimizations. | |
| Handles BlockSwap-enabled models transparently. | |
| Args: | |
| model: The model to move | |
| target_device: Target device (torch.device object, e.g., torch.device('cuda:0')) | |
| model_name: Name for logging (e.g., "VAE", "DiT") | |
| debug: Debug instance for logging | |
| reason: Optional custom reason for the movement | |
| runner: Optional runner instance for BlockSwap detection | |
| Returns: | |
| bool: True if model was moved, False if already on target device | |
| """ | |
| if model is None: | |
| return False | |
| # Check if this is a BlockSwap-enabled DiT model | |
| is_blockswap_model = False | |
| actual_model = model | |
| if runner and model_name == "DiT": | |
| # Import here to avoid circular dependency | |
| from .blockswap import is_blockswap_enabled | |
| # Check if BlockSwap config exists and is enabled | |
| has_blockswap_config = ( | |
| hasattr(runner, '_dit_block_swap_config') and | |
| is_blockswap_enabled(runner._dit_block_swap_config) | |
| ) | |
| if has_blockswap_config: | |
| is_blockswap_model = True | |
| # Get the actual model (handle CompatibleDiT wrapper) | |
| if hasattr(model, "dit_model"): | |
| actual_model = model.dit_model | |
| # Get current device | |
| try: | |
| current_device = next(model.parameters()).device | |
| except StopIteration: | |
| return False | |
| # Extract device type for comparison (both are torch.device objects) | |
| target_type = target_device.type | |
| current_device_upper = _device_str(current_device) | |
| target_device_upper = _device_str(target_device) | |
| # Compare normalized device types | |
| if current_device_upper == target_device_upper and not is_blockswap_model: | |
| # Already on target device type, no movement needed | |
| if debug: | |
| debug.log(f"{model_name} already on {current_device_upper}, skipping movement", category="general") | |
| return False | |
| # Handle BlockSwap models specially | |
| if is_blockswap_model: | |
| return _handle_blockswap_model_movement( | |
| runner, actual_model, current_device, target_device, target_type, | |
| model_name, debug, reason | |
| ) | |
| # Standard model movement (non-BlockSwap) | |
| return _standard_model_movement( | |
| model, current_device, target_device, target_type, model_name, | |
| debug, reason | |
| ) | |
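# Illustrative usage sketch (commented out). `runner` is assumed to expose `.vae` and an
# optional Debug instance; only manage_model_device() itself is defined in this module.
#
#   moved = manage_model_device(runner.vae, torch.device("cuda:0"), "VAE",
#                               debug=debug, reason="decoding", runner=runner)
#   if not moved:
#       pass  # already on the target device type (or BlockSwap handled the placement)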
| def _handle_blockswap_model_movement(runner: Any, model: torch.nn.Module, | |
| current_device: torch.device, target_device: torch.device, | |
| target_type: str, model_name: str, | |
| debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool: | |
| """ | |
| Handle device movement for BlockSwap-enabled models. | |
| Args: | |
| runner: Runner instance with BlockSwap configuration | |
| model: Model to move (actual unwrapped model) | |
| current_device: Current device of the model | |
| target_device: Target device (torch.device object) | |
| target_type: Target device type (cpu/cuda/mps) | |
| model_name: Model name for logging | |
| debug: Debug instance | |
| reason: Movement reason | |
| Returns: | |
| bool: True if model was moved | |
| """ | |
| # Import here to avoid circular dependency | |
| from .blockswap import set_blockswap_bypass | |
| if target_type == "cpu": | |
| # Moving to offload device (typically CPU) | |
| # Check if any parameter is on GPU (for accurate logging) | |
| actual_source_device = None | |
| for param in model.parameters(): | |
| if param.device.type in ['cuda', 'mps']: | |
| actual_source_device = param.device | |
| break | |
        source_device_desc = _device_str(actual_source_device) if actual_source_device else _device_str(current_device)
| if debug: | |
| debug.log(f"Moving {model_name} from {source_device_desc} to {_device_str(target_device)} ({reason or 'model caching'})", category="general") | |
| # Enable bypass to allow movement | |
| set_blockswap_bypass(runner=runner, bypass=True, debug=debug) | |
| # Start timer | |
| timer_name = f"{model_name.lower()}_to_{target_type}" | |
| if debug: | |
| debug.start_timer(timer_name) | |
| # Move entire model to target offload device | |
| model.to(target_device) | |
| model.zero_grad(set_to_none=True) | |
| if debug: | |
| debug.end_timer(timer_name, f"BlockSwap model offloaded to {_device_str(target_device)}") | |
| return True | |
| else: | |
| # Moving to GPU (reload) | |
| # Check if we're in bypass mode (coming from offload) | |
| if not getattr(model, "_blockswap_bypass_protection", False): | |
| # Not in bypass mode, blocks are already configured | |
| if debug: | |
| debug.log(f"{model_name} with BlockSwap active - blocks already distributed across devices, skipping movement", category="general") | |
| return False | |
| # Get actual current device for accurate logging | |
| actual_current_device = None | |
| for param in model.parameters(): | |
| if param.device.type != 'meta': | |
| actual_current_device = param.device | |
| break | |
| current_device_desc = _device_str(actual_current_device) if actual_current_device else "OFFLOAD" | |
| if debug: | |
| debug.log(f"Moving {model_name} from {current_device_desc} to {_device_str(target_device)} ({reason or 'inference requirement'})", category="general") | |
| timer_name = f"{model_name.lower()}_to_gpu" | |
| if debug: | |
| debug.start_timer(timer_name) | |
| # Restore blocks to their configured devices | |
| if hasattr(model, "blocks") and hasattr(model, "blocks_to_swap"): | |
| # Use configured offload_device from BlockSwap config | |
| offload_device = model._block_swap_config.get("offload_device") | |
| if not offload_device: | |
| raise ValueError("BlockSwap config missing offload_device") | |
| # Move blocks according to BlockSwap configuration | |
| for b, block in enumerate(model.blocks): | |
| if b > model.blocks_to_swap: | |
| # This block should be on GPU | |
| block.to(target_device) | |
| else: | |
| # This block stays on offload device (will be swapped during forward) | |
| block.to(offload_device) | |
| # Handle I/O components | |
| if not model._block_swap_config.get("swap_io_components", False): | |
| # I/O components should be on GPU if not offloaded | |
| for name, module in model.named_children(): | |
| if name != "blocks": | |
| module.to(target_device) | |
| else: | |
| # I/O components stay on offload device | |
| for name, module in model.named_children(): | |
| if name != "blocks": | |
| module.to(offload_device) | |
| if debug: | |
| # Get actual configuration from runner | |
| if hasattr(model, '_block_swap_config'): | |
| blocks_on_gpu = model._block_swap_config.get('total_blocks', 32) - model._block_swap_config.get('blocks_swapped', 16) | |
| total_blocks = model._block_swap_config.get('total_blocks', 32) | |
| main_device = model._block_swap_config.get('main_device', 'GPU') | |
| debug.log(f"BlockSwap blocks restored to configured devices ({blocks_on_gpu}/{total_blocks} blocks on {_device_str(main_device)})", category="success") | |
| else: | |
| debug.log("BlockSwap blocks restored to configured devices", category="success") | |
| # Reactivate BlockSwap now that blocks are restored to their configured devices | |
| runner._blockswap_active = True | |
| # Disable bypass, re-enable protection | |
| set_blockswap_bypass(runner=runner, bypass=False, debug=debug) | |
| if debug: | |
| debug.end_timer(timer_name, "BlockSwap model restored") | |
| return True | |
| def _standard_model_movement(model: torch.nn.Module, current_device: torch.device, | |
| target_device: torch.device, target_type: str, model_name: str, | |
| debug: Optional['Debug'] = None, reason: Optional[str] = None) -> bool: | |
| """ | |
| Handle standard (non-BlockSwap) model movement. | |
| Args: | |
| model: Model to move | |
| current_device: Current device of the model | |
| target_device: Target device (torch.device object) | |
| target_type: Target device type | |
| model_name: Model name for logging | |
| debug: Debug instance | |
| reason: Movement reason | |
| Returns: | |
| bool: True if model was moved | |
| """ | |
| # Check if model is on meta device - can't move meta tensors | |
| if current_device.type == 'meta': | |
| if debug: | |
| debug.log(f"{model_name} is on meta device - skipping movement (will materialize when needed)", | |
| category=model_name.lower()) | |
| return False | |
| # Determine reason for movement | |
| reason = reason or "inference requirement" | |
| # Log the movement with full device strings | |
| if debug: | |
| current_device_str = _device_str(current_device) | |
| target_device_str = _device_str(target_device) | |
| debug.log(f"Moving {model_name} from {current_device_str} to {target_device_str} ({reason})", category="general") | |
| # Start timer based on direction | |
| timer_name = f"{model_name.lower()}_to_{'gpu' if target_type != 'cpu' else 'cpu'}" | |
| if debug: | |
| debug.start_timer(timer_name) | |
| # Move model and clear gradients | |
| model.to(target_device) | |
| model.zero_grad(set_to_none=True) | |
| # Clear VAE memory buffers when moving to CPU | |
| if target_type == 'cpu' and model_name == "VAE": | |
| cleared_count = 0 | |
| for module in model.modules(): | |
| if hasattr(module, 'memory') and module.memory is not None: | |
| if torch.is_tensor(module.memory) and (module.memory.is_cuda or module.memory.is_mps): | |
| module.memory = None | |
| cleared_count += 1 | |
| if cleared_count > 0 and debug: | |
| debug.log(f"Cleared {cleared_count} VAE memory buffers", category="success") | |
| # End timer | |
| if debug: | |
| debug.end_timer(timer_name, f"{model_name} moved to {_device_str(target_device)}") | |
| return True | |
| def clear_runtime_caches(runner: Any, debug: Optional['Debug'] = None) -> int: | |
| """ | |
| Clear all runtime caches and temporary attributes. | |
| """ | |
| if not runner: | |
| return 0 | |
| if debug: | |
| debug.start_timer("runtime_cache_clear") | |
| cleaned_items = 0 | |
| # 1. Clear main runner cache | |
| if hasattr(runner, 'cache') and hasattr(runner.cache, 'cache'): | |
| if debug: | |
| debug.start_timer("runner_cache_clear") | |
| cache_entries = len(runner.cache.cache) | |
| # Properly release tensor memory and delete as we go | |
| for key in list(runner.cache.cache.keys()): | |
| value = runner.cache.cache[key] | |
| if torch.is_tensor(value): | |
| release_tensor_memory(value) | |
| elif isinstance(value, (list, tuple)): | |
| for item in value: | |
| if torch.is_tensor(item): | |
| release_tensor_memory(item) | |
| # Delete immediately to release reference | |
| del runner.cache.cache[key] | |
| # Final clear for safety | |
| runner.cache.cache.clear() | |
| cleaned_items += cache_entries | |
| if debug: | |
| debug.end_timer("runner_cache_clear", f"Clearing main runner cache entries") | |
| if cache_entries > 0: | |
| debug.log(f"Cleared {cache_entries} runtime cache entries", category="success") | |
| # 2. Clear RoPE caches | |
| if hasattr(runner, 'dit'): | |
| if debug: | |
| debug.start_timer("rope_cache_clear") | |
| model = runner.dit | |
| if hasattr(model, 'dit_model'): # Handle wrapper | |
| model = model.dit_model | |
| rope_cleared = clear_rope_lru_caches(model=model, debug=debug) | |
| cleaned_items += rope_cleared | |
| if debug: | |
| debug.end_timer("rope_cache_clear", "Clearing RoPE LRU caches") | |
| if rope_cleared > 0: | |
| debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success") | |
| # 3. Clear temporary attributes | |
| temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache', | |
| '_rope_cache', '_intermediate_cache', '_backward_cache'] | |
| for obj in [runner, getattr(runner, 'dit', None), getattr(runner, 'vae', None)]: | |
| if obj is None: | |
| continue | |
| actual_obj = obj.dit_model if hasattr(obj, 'dit_model') else obj | |
| for attr in temp_attrs: | |
| if hasattr(actual_obj, attr): | |
| delattr(actual_obj, attr) | |
| cleaned_items += 1 | |
| if debug: | |
| debug.end_timer("runtime_cache_clear", f"clear_runtime_caches() completion") | |
| return cleaned_items | |
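# Illustrative usage sketch (commented out). Typically called between generations before
# deciding whether to keep models cached; `runner` and `debug` are assumed caller-side objects.
#
#   cleaned = clear_runtime_caches(runner=runner, debug=debug)
#   clear_memory(debug=debug, deep=False, force=True, timer_name="between_runs")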
| def cleanup_dit(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None: | |
| """ | |
| Cleanup DiT model and BlockSwap state after upscaling phase. | |
| Called at the end of upscale_all_batches when DiT is no longer needed. | |
| Args: | |
| runner: Runner instance containing DiT model | |
| debug: Debug instance for logging | |
| cache_model: If True, move DiT to offload_device; if False, delete completely | |
| """ | |
| if not runner or not hasattr(runner, 'dit'): | |
| return | |
| if debug: | |
| debug.log("Cleaning up DiT components", category="cleanup") | |
| # 1. Clear DiT-specific runtime caches first | |
| if hasattr(runner, 'dit'): | |
| model = runner.dit | |
| if hasattr(model, 'dit_model'): # Handle wrapper | |
| model = model.dit_model | |
| # Clear RoPE caches | |
| rope_cleared = clear_rope_lru_caches(model=model, debug=debug) | |
| if rope_cleared > 0 and debug: | |
| debug.log(f"Cleared {rope_cleared} RoPE LRU caches", category="success") | |
| # Clear DiT temporary attributes | |
| temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache', | |
| '_rope_cache', '_intermediate_cache', '_backward_cache'] | |
| actual_obj = model.dit_model if hasattr(model, 'dit_model') else model | |
| for attr in temp_attrs: | |
| if hasattr(actual_obj, attr): | |
| delattr(actual_obj, attr) | |
| # 2. Handle model offloading (for caching or before deletion) | |
| try: | |
| param_device = next(runner.dit.parameters()).device | |
| # Move model off GPU if needed | |
| if param_device.type not in ['meta', 'cpu']: | |
| # MPS: skip CPU movement before deletion (unified memory, just causes sync) | |
| if param_device.type == 'mps' and not cache_model: | |
| if debug: | |
| debug.log("DiT on MPS - skipping CPU movement before deletion", category="cleanup") | |
| else: | |
| offload_target = getattr(runner, '_dit_offload_device', None) | |
| if offload_target is None or offload_target == 'none': | |
| offload_target = torch.device('cpu') | |
| reason = "model caching" if cache_model else "releasing GPU memory" | |
| manage_model_device(model=runner.dit, target_device=offload_target, model_name="DiT", | |
| debug=debug, reason=reason, runner=runner) | |
| elif param_device.type == 'meta' and debug: | |
| debug.log("DiT on meta device - keeping structure for cache", category="cleanup") | |
| except StopIteration: | |
| pass | |
| # 3. Clean BlockSwap after model movement | |
| if hasattr(runner, "_blockswap_active") and runner._blockswap_active: | |
| # Import here to avoid circular dependency | |
| from .blockswap import cleanup_blockswap | |
| cleanup_blockswap(runner=runner, keep_state_for_cache=cache_model) | |
| # 4. Complete cleanup if not caching | |
| if not cache_model: | |
| release_model_memory(model=runner.dit, debug=debug) | |
| runner.dit = None | |
| if debug: | |
| debug.log("DiT model deleted", category="cleanup") | |
| # Clear DiT config attributes - not needed when model is not cached (will be recreated) | |
| if hasattr(runner, '_dit_compile_args'): | |
| delattr(runner, '_dit_compile_args') | |
| if hasattr(runner, '_dit_block_swap_config'): | |
| delattr(runner, '_dit_block_swap_config') | |
| if hasattr(runner, '_dit_attention_mode'): | |
| delattr(runner, '_dit_attention_mode') | |
    # 5. Clear DiT temporary attributes (should already be cleared in materialize_model)
| runner._dit_checkpoint = None | |
| runner._dit_dtype_override = None | |
| # 6. Clear DiT-related components and temporary attributes | |
| runner.sampler = None | |
| runner.sampling_timesteps = None | |
| runner.schedule = None | |
| def cleanup_vae(runner: Any, debug: Optional['Debug'] = None, cache_model: bool = False) -> None: | |
| """ | |
| Cleanup VAE model after decoding phase. | |
| Called at the end of decode_all_batches when VAE is no longer needed. | |
| Args: | |
| runner: Runner instance containing VAE model | |
| debug: Debug instance for logging | |
| cache_model: If True, move VAE to offload_device; if False, delete completely | |
| """ | |
| if not runner or not hasattr(runner, 'vae'): | |
| return | |
| if debug: | |
| debug.log("Cleaning up VAE components", category="cleanup") | |
| # 1. Clear VAE-specific temporary attributes | |
| if hasattr(runner, 'vae'): | |
| temp_attrs = ['_temp_cache', '_block_cache', '_swap_cache', '_generation_cache', | |
| '_rope_cache', '_intermediate_cache', '_backward_cache'] | |
| for attr in temp_attrs: | |
| if hasattr(runner.vae, attr): | |
| delattr(runner.vae, attr) | |
| # 2. Handle model offloading (for caching or before deletion) | |
| try: | |
| param_device = next(runner.vae.parameters()).device | |
| # Move model off GPU if needed | |
| if param_device.type not in ['meta', 'cpu']: | |
| # MPS: skip CPU movement before deletion (unified memory, just causes sync) | |
| if param_device.type == 'mps' and not cache_model: | |
| if debug: | |
| debug.log("VAE on MPS - skipping CPU movement before deletion", category="cleanup") | |
| else: | |
| offload_target = getattr(runner, '_vae_offload_device', None) | |
| if offload_target is None or offload_target == 'none': | |
| offload_target = torch.device('cpu') | |
| reason = "model caching" if cache_model else "releasing GPU memory" | |
| manage_model_device(model=runner.vae, target_device=offload_target, model_name="VAE", | |
| debug=debug, reason=reason, runner=runner) | |
| elif param_device.type == 'meta' and debug: | |
| debug.log("VAE on meta device - keeping structure for cache", category="cleanup") | |
| except StopIteration: | |
| pass | |
| # 3. Complete cleanup if not caching | |
| if not cache_model: | |
| release_model_memory(model=runner.vae, debug=debug) | |
| runner.vae = None | |
| if debug: | |
| debug.log("VAE model deleted", category="cleanup") | |
| # Clear VAE config attributes - not needed when model is not cached (will be recreated) | |
| if hasattr(runner, '_vae_compile_args'): | |
| delattr(runner, '_vae_compile_args') | |
| if hasattr(runner, '_vae_tiling_config'): | |
| delattr(runner, '_vae_tiling_config') | |
    # 4. Clear VAE temporary attributes (should already be cleared in materialize_model)
| runner._vae_checkpoint = None | |
| runner._vae_dtype_override = None | |
| def complete_cleanup(runner: Any, debug: Optional['Debug'] = None, dit_cache: bool = False, vae_cache: bool = False) -> None: | |
| """ | |
| Complete cleanup of runner and remaining components with independent model caching support. | |
| This is a lightweight cleanup for final stage, as model-specific cleanup | |
| happens in their respective phases (cleanup_dit, cleanup_vae). | |
| Args: | |
| runner: Runner instance to clean up | |
| debug: Debug instance for logging | |
| dit_cache: If True, preserve DiT model on offload_device for future runs | |
| vae_cache: If True, preserve VAE model on offload_device for future runs | |
| Behavior: | |
| - Can cache DiT and VAE independently for flexible memory management | |
| - Preserves _dit_model_name and _vae_model_name when either model is cached for change detection | |
| - Clears all temporary attributes and runtime caches | |
| - Performs deep memory cleanup only when both models are fully released | |
| Note: | |
| Model name tracking (_dit_model_name, _vae_model_name) is only cleared if neither | |
| model is cached, enabling proper model change detection on subsequent runs. | |
| """ | |
| if not runner: | |
| return | |
| if debug: | |
| cleanup_type = "partial cleanup" if (dit_cache or vae_cache) else "full cleanup" | |
| debug.log(f"Starting {cleanup_type}", category="cleanup") | |
| # 1. Cleanup any remaining models if they still exist | |
| # (This handles cases where phases were skipped or errored) | |
| if hasattr(runner, 'dit') and runner.dit is not None: | |
| cleanup_dit(runner=runner, debug=debug, cache_model=dit_cache) | |
| if hasattr(runner, 'vae') and runner.vae is not None: | |
| cleanup_vae(runner=runner, debug=debug, cache_model=vae_cache) | |
| # 2. Clear remaining runtime caches | |
| clear_runtime_caches(runner=runner, debug=debug) | |
| # 3. Clear config and other non-model components when fully releasing runner | |
| if not (dit_cache or vae_cache): | |
| # Full cleanup - clear config and model tracking | |
| runner.config = None | |
| runner._dit_model_name = None | |
| runner._vae_model_name = None | |
| # 4. Final memory cleanup | |
| clear_memory(debug=debug, deep=True, force=True, timer_name="complete_cleanup") | |
    # 5. Clear cuBLAS workspaces when the private helper is available
    if hasattr(torch._C, '_cuda_clearCublasWorkspaces'):
        torch._C._cuda_clearCublasWorkspaces()
    # Log what models are cached for next run
    if debug and (dit_cache or vae_cache):
        cached_models = []
        if dit_cache and hasattr(runner, '_dit_model_name'):
            cached_models.append(f"DiT ({runner._dit_model_name})")
        if vae_cache and hasattr(runner, '_vae_model_name'):
            cached_models.append(f"VAE ({runner._vae_model_name})")
        if cached_models:
            models_str = " and ".join(cached_models)
            debug.log(f"Models cached for next run: {models_str}", category="cache", force=True)
| if debug: | |
| debug.log(f"Completed {cleanup_type}", category="success") |