| """GPU memory management utilities for training and inference.""" |
|
|
| import functools |
| import gc |
| import subprocess |
| from typing import Callable, TypeVar |
|
|
| import torch |
|
|
| from ltx_trainer import logger |
|
|
| F = TypeVar("F", bound=Callable) |
|
|
|
|
def free_gpu_memory(log: bool = False) -> None:
    """Run Python garbage collection and, if CUDA is available, empty the CUDA cache.

    Args:
        log: When True (and CUDA is available), emit a debug message with the
            allocated and reserved memory measured after clearing, expressed
            in GB (bytes / 1024**3).
    """
    gc.collect()

    # No CUDA runtime available: there is no cache to empty and nothing to report.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    if not log:
        return

    bytes_per_gb = 1024**3
    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    reserved = torch.cuda.memory_reserved() / bytes_per_gb
    logger.debug(f"GPU memory freed. Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
|
|
|
|
class free_gpu_memory_context:
    """Free GPU memory around a block of code or a function call.

    Usable either as a decorator or as a context manager.

    As a decorator:
        @free_gpu_memory_context(after=True)
        def my_function():
            ...

    As a context manager:
        with free_gpu_memory_context():
            heavy_operation()

    Args:
        before: Free memory on entry (default: False)
        after: Free memory on exit (default: True)
        log: Passed to free_gpu_memory as its ``log`` flag (default: False)
    """

    def __init__(self, *, before: bool = False, after: bool = True, log: bool = False) -> None:
        self.before, self.after, self.log = before, after, log

    def _free_if(self, enabled: bool) -> None:
        """Call free_gpu_memory with this instance's log flag when ``enabled`` is true."""
        if enabled:
            free_gpu_memory(log=self.log)

    def __enter__(self) -> "free_gpu_memory_context":
        self._free_if(self.before)
        return self

    def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: object) -> None:
        # Invoked on both normal and exceptional exit; returning None means an
        # exception raised inside the block is not suppressed.
        self._free_if(self.after)

    def __call__(self, func: F) -> F:
        """Wrap ``func`` so that each call executes inside this context manager."""

        def wrapper(*args, **kwargs) -> object:
            with self:
                return func(*args, **kwargs)

        # Copy name, docstring and other metadata from the wrapped function.
        return functools.update_wrapper(wrapper, func)
|
|
|
|
| def get_gpu_memory_gb(device: torch.device) -> float: |
| """Get current GPU memory usage in GB using nvidia-smi. |
| Args: |
| device: torch.device to get memory usage for |
| Returns: |
| Current GPU memory usage in GB |
| """ |
| try: |
| device_id = device.index if device.index is not None else 0 |
| result = subprocess.check_output( |
| [ |
| "nvidia-smi", |
| "--query-gpu=memory.used", |
| "--format=csv,nounits,noheader", |
| "-i", |
| str(device_id), |
| ], |
| encoding="utf-8", |
| ) |
| return float(result.strip()) / 1024 |
| except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e: |
| logger.error(f"Failed to get GPU memory from nvidia-smi: {e}") |
| |
| return torch.cuda.memory_allocated(device) / 1024**3 |
|
|