# Wildnerve-tlm01_Hybrid_Model/utils/gpu_config_optimizer.py
"""
Utility to optimize transformer configuration for GPU memory constraints.
"""
import os
import json
import torch
import logging
from pathlib import Path
import argparse
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def optimize_config_for_gpu(config_path, target_vram_mb=None, batch_reduction_factor=0.5):
    """
    Optimize transformer config settings for the available GPU memory.

    Reads ``config_path``, shrinks batch size / sequence length / layer count
    for memory-constrained GPUs (RTX 4050 class, <7GB VRAM), always enables
    gradient checkpointing and mixed precision, and writes the result to a
    sibling ``*_<gpu>_optimized.json`` file.

    Args:
        config_path: Path to the config.json file (str or Path).
        target_vram_mb: Target VRAM usage in MB. If None, defaults to 80% of
            detected VRAM on GPU, or a conservative 2048MB on CPU.
        batch_reduction_factor: Multiplier applied to BATCH_SIZE on
            constrained GPUs (0.5 = half), floored at 4.

    Returns:
        str: Path of the optimized config file that was written.
    """
    # Accept pathlib.Path too; str.replace below requires a string.
    config_path = str(config_path)

    # Probe the current GPU (if any) to decide how aggressive to be.
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        gpu_properties = torch.cuda.get_device_properties(device)
        total_memory = gpu_properties.total_memory / (1024 * 1024)  # bytes -> MB
        gpu_name = gpu_properties.name
        logger.info(f"GPU detected: {gpu_name} with {total_memory:.0f}MB VRAM")
        # Default target: 80% of total VRAM leaves headroom for the CUDA context.
        if target_vram_mb is None:
            target_vram_mb = int(total_memory * 0.8)
    else:
        logger.warning("No GPU detected, using conservative settings for CPU")
        total_memory = 0  # no VRAM to speak of
        gpu_name = "CPU"
        # Only fall back to the conservative default when the caller did not
        # specify a target (the old code clobbered an explicit value here).
        if target_vram_mb is None:
            target_vram_mb = 2048  # Conservative default for CPU

    logger.info(f"Target VRAM usage: {target_vram_mb}MB")

    # Load current config
    with open(config_path, 'r') as f:
        config = json.load(f)

    transformer_cfg = config["TRANSFORMER_CONFIG"]

    # Store original values for reporting
    original_batch_size = transformer_cfg["BATCH_SIZE"]
    original_sequence_length = transformer_cfg["MAX_SEQ_LENGTH"]
    original_num_layers = transformer_cfg["NUM_LAYERS"]

    # Adjust batch size based on GPU
    if torch.cuda.is_available():
        # RTX 4050 specific optimizations (around 6GB VRAM)
        if "4050" in gpu_name or total_memory < 7000:
            # Significant reductions needed, but never drop below a batch of 4.
            transformer_cfg["BATCH_SIZE"] = max(4, int(original_batch_size * batch_reduction_factor))
            # If still too large after batch reduction, reduce sequence length too
            if target_vram_mb < 5500:  # RTX 4050 has ~6GB VRAM
                if transformer_cfg["MAX_SEQ_LENGTH"] > 256:
                    transformer_cfg["MAX_SEQ_LENGTH"] = 256
                # Reduce model complexity if still needed
                if transformer_cfg["NUM_LAYERS"] > 6:
                    transformer_cfg["NUM_LAYERS"] = 6

    # Enable gradient checkpointing (trades compute for memory) and AMP,
    # unconditionally — they help on every device.
    config.setdefault("OPTIMIZATION", {})
    config["OPTIMIZATION"]["USE_GRADIENT_CHECKPOINTING"] = True
    config["OPTIMIZATION"]["USE_MIXED_PRECISION"] = True

    # Create optimized filename tagged with the GPU type
    gpu_name_simple = gpu_name.replace(" ", "_").lower()
    opt_config_path = config_path.replace(".json", f"_{gpu_name_simple}_optimized.json")

    # Save optimized config
    with open(opt_config_path, 'w') as f:
        json.dump(config, f, indent=2)

    # Report changes (old -> new); the original log lines had lost their
    # arrow separator and ran the two values together.
    logger.info(f"Optimized configuration saved to: {opt_config_path}")
    logger.info("Changes made:")
    logger.info(f"  - Batch size: {original_batch_size} -> {transformer_cfg['BATCH_SIZE']}")
    logger.info(f"  - Sequence length: {original_sequence_length} -> {transformer_cfg['MAX_SEQ_LENGTH']}")
    logger.info(f"  - Num layers: {original_num_layers} -> {transformer_cfg['NUM_LAYERS']}")
    logger.info("  - Gradient checkpointing: Enabled")
    logger.info("  - Mixed precision: Enabled")

    return opt_config_path
def apply_optimized_config(opt_config_path):
    """
    Overwrite the project's main config.json with an optimized config.

    The main config is assumed to live one directory level above the
    optimized file. A one-time backup (``config.json.backup``) of the
    original main config is created before the first overwrite.

    Args:
        opt_config_path: Path to a config produced by optimize_config_for_gpu().
    """
    import shutil

    # Read the optimized settings.
    with open(opt_config_path, 'r') as src:
        optimized = json.load(src)

    # Main config sits one directory above the optimized file.
    parent_dir = os.path.dirname(os.path.dirname(opt_config_path))
    main_config_path = os.path.join(parent_dir, "config.json")

    # Preserve the very first version of the main config exactly once.
    backup_path = main_config_path + '.backup'
    if not os.path.exists(backup_path):
        shutil.copy2(main_config_path, backup_path)
        logger.info(f"Original config backed up to: {backup_path}")

    # Replace the main config with the optimized one.
    with open(main_config_path, 'w') as dst:
        json.dump(optimized, dst, indent=2)

    logger.info(f"Applied optimized config to {main_config_path}")
# CLI entry point: optimize (and optionally apply) the transformer config.
# The original paste had lost all indentation under the guard, which made
# the whole file a SyntaxError; logic is unchanged.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Optimize transformer config for GPU memory constraints")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config file")
    parser.add_argument("--apply", action="store_true", help="Apply optimized config")
    parser.add_argument("--batch-factor", type=float, default=0.5, help="Batch size reduction factor")
    parser.add_argument("--target-vram", type=int, default=None, help="Target VRAM usage in MB")
    args = parser.parse_args()

    # Resolve relative config paths against the repo root (parent of utils/).
    if not os.path.isabs(args.config):
        config_dir = Path(__file__).resolve().parent.parent
        config_path = os.path.join(config_dir, args.config)
    else:
        config_path = args.config

    # Optimize config
    opt_config_path = optimize_config_for_gpu(
        config_path,
        target_vram_mb=args.target_vram,
        batch_reduction_factor=args.batch_factor,
    )

    # Apply if requested
    if args.apply:
        apply_optimized_config(opt_config_path)