File size: 6,044 Bytes
0861a59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""

Utility to optimize transformer configuration for GPU memory constraints.

"""
import os
import json
import torch
import logging
from pathlib import Path
import argparse

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

def optimize_config_for_gpu(config_path, target_vram_mb=None, batch_reduction_factor=0.5):
    """Optimize config settings for the available GPU memory.

    Args:
        config_path: Path to the config.json file (str or os.PathLike).
        target_vram_mb: Target VRAM usage in MB. If None, 80% of total VRAM
            is used on GPU; a conservative 2048MB default is used on CPU.
        batch_reduction_factor: Multiplier applied to the batch size on
            memory-constrained GPUs (0.5 = half).

    Returns:
        Path (str) of the newly written optimized config file.
    """
    # Detect GPU capacity; fall back to conservative CPU defaults.
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        gpu_properties = torch.cuda.get_device_properties(device)
        total_memory = gpu_properties.total_memory / (1024 * 1024)  # bytes -> MB
        gpu_name = gpu_properties.name
        logger.info(f"GPU detected: {gpu_name} with {total_memory:.0f}MB VRAM")

        # Set target memory if not specified (80% of available)
        if target_vram_mb is None:
            target_vram_mb = int(total_memory * 0.8)
    else:
        logger.warning("No GPU detected, using conservative settings for CPU")
        target_vram_mb = 2048  # Conservative default for CPU
        gpu_name = "CPU"

    logger.info(f"Target VRAM usage: {target_vram_mb}MB")

    # Load current config
    with open(config_path, 'r') as f:
        config = json.load(f)

    transformer_cfg = config["TRANSFORMER_CONFIG"]

    # Store original values for reporting
    original_batch_size = transformer_cfg["BATCH_SIZE"]
    original_sequence_length = transformer_cfg["MAX_SEQ_LENGTH"]
    original_num_layers = transformer_cfg["NUM_LAYERS"]

    # Adjust batch size based on GPU
    if torch.cuda.is_available():
        # RTX 4050-class cards (~6GB VRAM) need aggressive reductions.
        if "4050" in gpu_name or total_memory < 7000:
            transformer_cfg["BATCH_SIZE"] = max(4, int(original_batch_size * batch_reduction_factor))

            # If still too large after batch reduction, reduce sequence length too
            if target_vram_mb < 5500:  # RTX 4050 has ~6GB VRAM
                if transformer_cfg["MAX_SEQ_LENGTH"] > 256:
                    transformer_cfg["MAX_SEQ_LENGTH"] = 256

                # Reduce model complexity if still needed
                if transformer_cfg["NUM_LAYERS"] > 6:
                    transformer_cfg["NUM_LAYERS"] = 6

    # Gradient checkpointing / mixed precision trade compute for memory.
    config.setdefault("OPTIMIZATION", {})
    config["OPTIMIZATION"]["USE_GRADIENT_CHECKPOINTING"] = True
    config["OPTIMIZATION"]["USE_MIXED_PRECISION"] = True

    # Create optimized filename with the GPU type.
    # str() so pathlib.Path inputs also work (original required a plain str).
    gpu_name_simple = gpu_name.replace(" ", "_").lower()
    opt_config_path = str(config_path).replace(".json", f"_{gpu_name_simple}_optimized.json")

    # Save optimized config
    with open(opt_config_path, 'w') as f:
        json.dump(config, f, indent=2)

    # Report changes.
    # BUGFIX: old and new values were concatenated with no separator
    # (e.g. "Batch size: 3216"); add an explicit "->" between them.
    logger.info(f"Optimized configuration saved to: {opt_config_path}")
    logger.info("Changes made:")
    logger.info(f" - Batch size: {original_batch_size} -> {transformer_cfg['BATCH_SIZE']}")
    logger.info(f" - Sequence length: {original_sequence_length} -> {transformer_cfg['MAX_SEQ_LENGTH']}")
    logger.info(f" - Num layers: {original_num_layers} -> {transformer_cfg['NUM_LAYERS']}")
    logger.info(" - Gradient checkpointing: Enabled")
    logger.info(" - Mixed precision: Enabled")

    return opt_config_path

def apply_optimized_config(opt_config_path):
    """
    Apply the optimized config by updating the main config file.

    The main config.json is assumed to live two directory levels above the
    optimized file; it is backed up once (to config.json.backup) before
    being overwritten with the optimized settings.
    """
    # Read the optimized settings that will replace the main config.
    with open(opt_config_path, 'r') as fh:
        optimized = json.load(fh)

    # Main config lives two directory levels above the optimized file.
    grandparent_dir = os.path.dirname(os.path.dirname(opt_config_path))
    main_config_path = os.path.join(grandparent_dir, "config.json")

    # One-time backup: never clobber an existing .backup file.
    backup_path = main_config_path + '.backup'
    if not os.path.exists(backup_path):
        import shutil
        shutil.copy2(main_config_path, backup_path)
        logger.info(f"Original config backed up to: {backup_path}")

    # Overwrite the main config with the optimized settings.
    with open(main_config_path, 'w') as fh:
        json.dump(optimized, fh, indent=2)

    logger.info(f"Applied optimized config to {main_config_path}")

if __name__ == "__main__":
    # Command-line entry point: optimize (and optionally apply) the config.
    parser = argparse.ArgumentParser(
        description="Optimize transformer config for GPU memory constraints"
    )
    parser.add_argument("--config", type=str, default="config.json", help="Path to config file")
    parser.add_argument("--apply", action="store_true", help="Apply optimized config")
    parser.add_argument("--batch-factor", type=float, default=0.5, help="Batch size reduction factor")
    parser.add_argument("--target-vram", type=int, default=None, help="Target VRAM usage in MB")
    args = parser.parse_args()

    # Relative paths are resolved against the project root (one level above
    # this script's directory).
    config_path = args.config
    if not os.path.isabs(config_path):
        project_root = Path(__file__).resolve().parent.parent
        config_path = os.path.join(project_root, config_path)

    optimized_path = optimize_config_for_gpu(
        config_path,
        target_vram_mb=args.target_vram,
        batch_reduction_factor=args.batch_factor,
    )

    # Overwrite the main config only when explicitly requested.
    if args.apply:
        apply_optimized_config(optimized_path)