"""Memory management utilities for PyTorch training.
This module provides lightweight utilities to monitor memory usage
and configure PyTorch's built-in memory management features.
"""
import gc
import logging
import os
from typing import Any, Dict

import psutil
import torch

# Configure logging
logger = logging.getLogger(__name__)


class MemoryManager:
    """Simple memory manager that leverages PyTorch's built-in memory optimizations."""

    def __init__(self):
        """Initialize the memory manager."""
        self.cuda_available = torch.cuda.is_available()
        self.process = psutil.Process(os.getpid())
        # Environment variable setup is handled externally in
        # train_for_user.sh, so none is done here.

    def get_memory_info(self) -> Dict[str, Any]:
        """Get current memory usage information."""
        # Take a single RAM snapshot so all reported fields are consistent
        vm = psutil.virtual_memory()
        info = {
            "ram_used_percent": vm.percent,
            "ram_used_gb": vm.used / (1024**3),
            "ram_available_gb": vm.available / (1024**3),
            "ram_total_gb": vm.total / (1024**3),
        }
        if self.cuda_available:
            try:
                info.update({
                    "vram_used_gb": torch.cuda.memory_allocated() / (1024**3),
                    "vram_reserved_gb": torch.cuda.memory_reserved() / (1024**3),
                    "vram_total_gb": torch.cuda.get_device_properties(0).total_memory / (1024**3),
                })
            except RuntimeError as e:
                logger.warning(f"Error getting CUDA memory info: {e}")
                self.cuda_available = False
        return info

    def cleanup_memory(self, force: bool = False) -> None:
        """Free up memory by running garbage collection and emptying the CUDA cache."""
        # Run Python garbage collection
        gc.collect()
        # Empty CUDA cache if available
        if self.cuda_available:
            torch.cuda.empty_cache()
        # Log memory status after cleanup
        if force:
            info = self.get_memory_info()
            logger.info(
                f"Memory after cleanup: RAM: {info['ram_used_gb']:.2f}GB / {info['ram_total_gb']:.2f}GB, "
                f"VRAM: {info.get('vram_used_gb', 0):.2f}GB / {info.get('vram_total_gb', 0):.2f}GB"
            )

    def get_optimal_training_config(self) -> Dict[str, Any]:
        """Get recommended configurations for model training based on hardware capabilities."""
        # Defaults that rely on PyTorch's automatic memory management
        config = {
            "device_map": "auto",
            "fp16": False,
            "bf16": False,
            "gradient_checkpointing": True,
            "gradient_accumulation_steps": 1,
        }
        # Enable mixed precision based on hardware support
        if self.cuda_available:
            capability = torch.cuda.get_device_capability()
            if capability[0] >= 8:  # Ampere or newer (supports BF16)
                config["bf16"] = True
            elif capability[0] >= 7:  # Volta or newer (supports FP16)
                config["fp16"] = True
            # Adjust accumulation steps based on available memory
            vram_gb = self.get_memory_info().get("vram_total_gb", 0)
            if vram_gb < 8:  # Small GPUs
                config["gradient_accumulation_steps"] = 4
            elif vram_gb < 16:  # Medium GPUs
                config["gradient_accumulation_steps"] = 2
        return config

    def optimize_model_for_training(self, model):
        """Apply PyTorch's built-in memory optimizations for training."""
        # Enable gradient checkpointing if available
        if hasattr(model, "gradient_checkpointing_enable"):
            logger.info("Enabling gradient checkpointing for memory efficiency")
            model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            # Enable memory-efficient attention for PyTorch 2.0+
            try:
                model.config.use_memory_efficient_attention = True
            except Exception:
                pass
            # Request flash attention on compatible GPUs (Ampere or newer)
            if self.cuda_available and torch.cuda.get_device_capability()[0] >= 8:
                try:
                    model.config.attn_implementation = "flash_attention_2"
                except Exception:
                    pass
        return model

    def optimize_training_args(self, training_args):
        """Configure training arguments for efficient memory usage."""
        if not training_args:
            return None
        # Get optimal configuration based on hardware
        config = self.get_optimal_training_config()
        # Apply configurations only where the caller has not set them explicitly
        if not getattr(training_args, "fp16", False) and not getattr(training_args, "bf16", False):
            training_args.fp16 = config["fp16"]
            training_args.bf16 = config["bf16"]
        if not getattr(training_args, "gradient_checkpointing", False):
            training_args.gradient_checkpointing = config["gradient_checkpointing"]
        if getattr(training_args, "gradient_accumulation_steps", 1) == 1:
            training_args.gradient_accumulation_steps = config["gradient_accumulation_steps"]
        logger.info("Training configuration optimized for memory efficiency:")
        logger.info(f"  Mixed precision: FP16={training_args.fp16}, BF16={training_args.bf16}")
        logger.info(f"  Gradient checkpointing: {training_args.gradient_checkpointing}")
        logger.info(f"  Gradient accumulation steps: {training_args.gradient_accumulation_steps}")
        return training_args


# Global memory manager instance
memory_manager = MemoryManager()


def get_memory_manager() -> MemoryManager:
    """Get the global memory manager instance."""
    return memory_manager
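

# A minimal usage sketch: print the current memory snapshot and the
# hardware-derived training config, then force a cleanup so the post-cleanup
# log line is exercised. It uses only the functions defined above and is safe
# on CPU-only machines, since all CUDA calls are guarded.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = get_memory_manager()
    print("Memory info:", manager.get_memory_info())
    print("Suggested training config:", manager.get_optimal_training_config())
    manager.cleanup_memory(force=True)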