# pipelines/memory_monitor.py
import torch
import gc
import psutil
import os
from typing import Dict, Any, Optional
from rich.console import Console
console = Console()
class MemoryMonitor:
    """Memory monitoring and error recovery for distributed training.

    Tracks GPU/CPU memory pressure, triggers periodic cache cleanups, and
    signals the training loop when the batch size should be reduced or when
    an NCCL/CUDA error looks recoverable.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize from a training config dict.

        Recognized keys (all optional): ``distributed`` (bool, default False),
        ``rank`` (int, default 0), ``device`` (str, default "cuda:0").
        """
        self.config = config
        self.distributed = config.get("distributed", False)
        self.rank = config.get("rank", 0)
        self.device = config.get("device", "cuda:0")
        self.memory_threshold = 0.85  # trigger cleanup above 85% memory usage
        self.cleanup_frequency = 10   # consider cleanup every 10 steps

    def check_memory_usage(self) -> Dict[str, float]:
        """Return current GPU and CPU memory usage as fractions in [0, 1].

        GPU usage is measured against the device's total capacity.
        Fix: the previous implementation divided by
        ``torch.cuda.max_memory_allocated()``, which is 0 before the first
        allocation (ZeroDivisionError) and measures fraction-of-peak, not
        fraction-of-capacity — the 85%/95% thresholds below assume capacity.
        """
        if not torch.cuda.is_available():
            return {"gpu_memory": 0.0, "cpu_memory": 0.0}
        device = torch.device(self.device)
        total = torch.cuda.get_device_properties(device).total_memory
        gpu_memory = (torch.cuda.memory_allocated(device) / total) if total else 0.0
        gpu_memory = min(gpu_memory, 1.0)  # cap at 100%
        cpu_memory = psutil.virtual_memory().percent / 100.0
        return {
            "gpu_memory": gpu_memory,
            "cpu_memory": cpu_memory,
        }

    def should_cleanup(self, step: int) -> bool:
        """Return True when ``step`` is a cleanup step AND GPU usage is high.

        Memory is only sampled on cleanup steps to keep the per-step cost low.
        """
        if step % self.cleanup_frequency != 0:
            return False
        return self.check_memory_usage()["gpu_memory"] > self.memory_threshold

    def cleanup_memory(self) -> None:
        """Release cached GPU memory and force a garbage-collection pass."""
        if not torch.cuda.is_available():
            return
        # Return cached blocks to the driver; synchronize so the numbers
        # logged below reflect the post-cleanup state.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()
        memory_usage = self.check_memory_usage()
        if self.rank == 0:  # Only log from rank 0
            console.print(f"[blue]🧹 Memory cleanup: GPU {memory_usage['gpu_memory']:.1%}, CPU {memory_usage['cpu_memory']:.1%}[/blue]")

    def monitor_training_step(self, step: int, model, optimizer) -> bool:
        """Monitor one training step; return False to signal batch-size reduction.

        ``model`` and ``optimizer`` are accepted for interface compatibility
        with the training loop but are not inspected here.

        Fix: memory usage is RE-READ after the emergency cleanup before
        deciding whether to reduce the batch size — the original compared
        the stale pre-cleanup reading, so every >95% spike forced a
        reduction even when the cleanup had freed enough memory.
        """
        try:
            # Periodic, threshold-gated cleanup.
            if self.should_cleanup(step):
                self.cleanup_memory()
            # Emergency path: near-OOM.
            memory_usage = self.check_memory_usage()
            if memory_usage["gpu_memory"] > 0.95:  # 95% threshold for OOM
                console.print(f"[red]⚠️ High memory usage: {memory_usage['gpu_memory']:.1%}[/red]")
                self.cleanup_memory()
                # Re-check AFTER cleanup (fix: original used the stale reading).
                memory_usage = self.check_memory_usage()
                if memory_usage["gpu_memory"] > 0.90:
                    console.print("[yellow]⚠️ Reducing batch size due to memory pressure[/yellow]")
                    return False  # Signal to reduce batch size
            return True
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                console.print(f"[red]❌ OOM detected: {e}[/red]")
                self.cleanup_memory()
                return False  # Signal to reduce batch size
            raise  # preserve original traceback for non-OOM errors

    def handle_nccl_error(self, error: Exception) -> bool:
        """Attempt recovery from an NCCL/CUDA error.

        Returns True when a retry is worth attempting (memory freed below
        80%), False when the caller should fall back to single-GPU or the
        error is unrelated to NCCL/CUDA.
        """
        error_str = str(error).lower()
        if "nccl" in error_str or "cuda error" in error_str:
            console.print(f"[red]❌ NCCL/CUDA error detected: {error}[/red]")
            self.cleanup_memory()
            # Recovery is only plausible if cleanup brought memory down.
            memory_usage = self.check_memory_usage()
            if memory_usage["gpu_memory"] < 0.80:  # If memory is low enough
                console.print("[yellow]🔄 Attempting recovery...[/yellow]")
                return True  # Try to recover
            console.print("[red]❌ Memory too high for recovery, falling back to single GPU[/red]")
            return False  # Fall back to single GPU
        return False  # Not an NCCL error

    def get_memory_stats(self) -> Dict[str, Any]:
        """Return detailed memory statistics for ``self.device``.

        When CUDA is unavailable the dict contains only
        ``{"gpu_available": False}``; callers must check that key first.
        """
        if not torch.cuda.is_available():
            return {"gpu_available": False}
        device = torch.device(self.device)
        memory_allocated = torch.cuda.memory_allocated(device)
        memory_reserved = torch.cuda.memory_reserved(device)
        memory_max = torch.cuda.max_memory_allocated(device)
        return {
            "gpu_available": True,
            "device": str(device),
            "memory_allocated": memory_allocated,
            "memory_reserved": memory_reserved,
            "memory_max": memory_max,
            "memory_allocated_gb": memory_allocated / 1024**3,
            "memory_reserved_gb": memory_reserved / 1024**3,
            "memory_max_gb": memory_max / 1024**3,
        }

    def log_memory_stats(self, step: int) -> None:
        """Log memory statistics for ``step`` (rank 0 only, GPU only)."""
        if self.rank != 0:  # Only log from rank 0
            return
        stats = self.get_memory_stats()
        if stats["gpu_available"]:
            console.print(f"[blue]📊 Step {step}: GPU Memory - Allocated: {stats['memory_allocated_gb']:.2f}GB, "
                          f"Reserved: {stats['memory_reserved_gb']:.2f}GB, Max: {stats['memory_max_gb']:.2f}GB[/blue]")