File size: 6,009 Bytes
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Memory management utilities for PyTorch training.

This module provides lightweight utilities to monitor memory usage
and configure PyTorch's built-in memory management features.
"""

import os
import gc
import logging
import psutil
import torch
from typing import Dict, Any

# Module-level logger, named after this module via __name__.
# Only the logger object is obtained here; no handlers or levels are set in this file.
logger = logging.getLogger(__name__)

class MemoryManager:
    """Simple memory manager that leverages PyTorch's built-in memory optimizations.

    The manager reports RAM/VRAM usage, triggers garbage collection and CUDA
    cache release, and derives memory-friendly training settings from the
    detected hardware.

    Attributes:
        cuda_available: Whether CUDA was usable at construction time. It is
            switched to ``False`` if a later CUDA memory query raises
            ``RuntimeError`` (see ``get_memory_info``).
        process: ``psutil.Process`` handle for the current process.
    """

    # Divisor used to convert byte counts to gibibytes (reported as "GB").
    _BYTES_PER_GB = 1024 ** 3

    def __init__(self):
        """Initialize the memory manager."""
        self.cuda_available = torch.cuda.is_available()
        self.process = psutil.Process(os.getpid())

        # Environment variables are intentionally not set here; that is
        # handled in train_for_user.sh.

    def get_memory_info(self) -> Dict[str, Any]:
        """Get current memory usage information.

        Returns:
            A dict with ``ram_used_percent``, ``ram_used_gb``,
            ``ram_available_gb`` and ``ram_total_gb``. When CUDA is available
            it additionally contains ``vram_used_gb``, ``vram_reserved_gb``
            and ``vram_total_gb`` for the current CUDA device.

        Side effects:
            If querying CUDA raises ``RuntimeError`` a warning is logged, the
            VRAM keys are omitted and ``self.cuda_available`` is set to
            ``False`` so later calls skip CUDA entirely.
        """
        gb = self._BYTES_PER_GB

        # Take one snapshot so all RAM figures describe the same instant.
        vm = psutil.virtual_memory()
        info: Dict[str, Any] = {
            "ram_used_percent": vm.percent,
            "ram_used_gb": vm.used / gb,
            "ram_available_gb": vm.available / gb,
            "ram_total_gb": vm.total / gb,
        }

        if self.cuda_available:
            try:
                # Query the same device for all three numbers; previously the
                # total was always read from device 0 while the usage figures
                # came from the current device.
                device = torch.cuda.current_device()
                info.update({
                    "vram_used_gb": torch.cuda.memory_allocated(device) / gb,
                    "vram_reserved_gb": torch.cuda.memory_reserved(device) / gb,
                    "vram_total_gb": torch.cuda.get_device_properties(device).total_memory / gb,
                })
            except RuntimeError as e:
                logger.warning("Error getting CUDA memory info: %s", e)
                self.cuda_available = False

        return info

    def cleanup_memory(self, force: bool = False) -> None:
        """Free up memory by garbage collection and emptying CUDA cache.

        Args:
            force: When true, the memory status after the cleanup is logged at
                INFO level. The cleanup itself is identical either way.
        """
        # Run Python garbage collection
        gc.collect()

        # Empty CUDA cache if available
        if self.cuda_available:
            torch.cuda.empty_cache()

        # Log memory status after cleanup
        if force:
            info = self.get_memory_info()
            logger.info(
                "Memory after cleanup: RAM: %.2fGB / %.2fGB, VRAM: %.2fGB / %.2fGB",
                info["ram_used_gb"],
                info["ram_total_gb"],
                # VRAM keys are absent without CUDA; report 0 in that case.
                info.get("vram_used_gb", 0),
                info.get("vram_total_gb", 0),
            )

    def get_optimal_training_config(self) -> Dict[str, Any]:
        """Get recommended configurations for model training based on hardware capabilities.

        Returns:
            A dict with the keys ``device_map``, ``fp16``, ``bf16``,
            ``gradient_checkpointing`` and ``gradient_accumulation_steps``.
            Without CUDA the defaults are returned unchanged (no mixed
            precision, one accumulation step).
        """
        # Default configs that rely on PyTorch's automatic memory management
        config: Dict[str, Any] = {
            "device_map": "auto",
            "fp16": False,
            "bf16": False,
            "gradient_checkpointing": True,
            "gradient_accumulation_steps": 1,
        }

        if not self.cuda_available:
            return config

        # Enable mixed precision based on hardware support
        major = torch.cuda.get_device_capability()[0]
        if major >= 8:  # Ampere or newer (supports BF16)
            config["bf16"] = True
        elif major >= 7:  # Volta or newer (supports FP16)
            config["fp16"] = True

        # Adjust accumulation steps based on total device memory. If the VRAM
        # query failed the total falls back to 0, which selects the most
        # conservative (small GPU) setting.
        vram_gb = self.get_memory_info().get("vram_total_gb", 0)
        if vram_gb < 8:  # Small GPUs
            config["gradient_accumulation_steps"] = 4
        elif vram_gb < 16:  # Medium GPUs
            config["gradient_accumulation_steps"] = 2

        return config

    @staticmethod
    def _try_set_config_attr(config: Any, name: str, value: Any) -> bool:
        """Best-effort ``setattr`` on a model config; return True on success."""
        try:
            setattr(config, name, value)
        except (AttributeError, TypeError, ValueError) as exc:
            # The config rejected the attribute: keep going, but leave a trace
            # instead of silently swallowing the failure.
            logger.debug("Could not set model.config.%s=%r: %s", name, value, exc)
            return False
        return True

    def optimize_model_for_training(self, model: Any) -> Any:
        """Apply PyTorch's built-in memory optimizations for training.

        Args:
            model: The model to adjust. ``gradient_checkpointing_enable()`` is
                called if the model has it, and attention-related flags are
                set on ``model.config`` if present.

        Returns:
            The same ``model`` object, modified in place.
        """
        # Enable gradient checkpointing if available
        if hasattr(model, "gradient_checkpointing_enable"):
            logger.info("Enabling gradient checkpointing for memory efficiency")
            model.gradient_checkpointing_enable()

        config = getattr(model, "config", None)
        if config is not None:
            # NOTE(review): whether a model honours these config flags depends
            # on the model implementation - confirm against the model class.
            self._try_set_config_attr(config, "use_memory_efficient_attention", True)

            # Request flash attention only on compute capability >= 8 GPUs.
            if self.cuda_available and torch.cuda.get_device_capability()[0] >= 8:
                self._try_set_config_attr(config, "attn_implementation", "flash_attention_2")

        return model

    def optimize_training_args(self, training_args: Any) -> Any:
        """Configure training arguments for efficient memory usage.

        Settings the caller already chose are respected: mixed precision is
        only filled in when neither ``fp16`` nor ``bf16`` is enabled, gradient
        checkpointing only when it is off, and accumulation steps only when
        they are missing or still 1.

        Args:
            training_args: Object carrying the training attributes; it is
                modified in place. A falsy value yields ``None``.

        Returns:
            The updated ``training_args``, or ``None`` if it was falsy.
        """
        if not training_args:
            return None

        # Get optimal configuration based on hardware
        config = self.get_optimal_training_config()

        # Apply configurations to training arguments
        fp16_on = getattr(training_args, "fp16", False)
        bf16_on = getattr(training_args, "bf16", False)
        if not fp16_on and not bf16_on:
            training_args.fp16 = config["fp16"]
            training_args.bf16 = config["bf16"]

        if not getattr(training_args, "gradient_checkpointing", False):
            training_args.gradient_checkpointing = config["gradient_checkpointing"]

        # Read with a default, like the fields above, so an object without
        # the attribute is treated as "not configured" instead of raising.
        if getattr(training_args, "gradient_accumulation_steps", 1) == 1:
            training_args.gradient_accumulation_steps = config["gradient_accumulation_steps"]

        logger.info("Training configuration optimized for memory efficiency:")
        logger.info(
            "  Mixed precision: FP16=%s, BF16=%s",
            getattr(training_args, "fp16", False),
            getattr(training_args, "bf16", False),
        )
        logger.info("  Gradient checkpointing: %s", training_args.gradient_checkpointing)
        logger.info("  Gradient accumulation steps: %s", training_args.gradient_accumulation_steps)

        return training_args


# Global memory manager instance, created once when this module is imported.
# Construction runs MemoryManager.__init__, which calls
# torch.cuda.is_available() and builds a psutil.Process for the current PID.
memory_manager = MemoryManager()

def get_memory_manager() -> MemoryManager:
    """Get the global memory manager instance.

    Returns:
        The object bound to the module-level ``memory_manager`` name. No new
        ``MemoryManager`` is created by this call.
    """
    return memory_manager