"""
Utility to optimize transformer configuration for GPU memory constraints.
"""

import argparse
import json
import logging
import os
import re
import shutil
from pathlib import Path

import torch

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Tuning thresholds (all memory figures are in MB).
_CPU_TARGET_MB = 2048          # Conservative default target when no GPU is present
_SMALL_GPU_TOTAL_MB = 7000     # GPUs below this total VRAM get the reduced settings
_TIGHT_TARGET_MB = 5500        # Below this target, sequence length / depth are capped too
_MIN_BATCH_SIZE = 4            # Floor applied when shrinking the batch size
_TIGHT_MAX_SEQ_LENGTH = 256    # Sequence-length cap for tight targets
_TIGHT_MAX_NUM_LAYERS = 6      # Layer-count cap for tight targets


def _optimized_config_path(config_path, gpu_name):
    """Return the sibling path the optimized config for *gpu_name* is written to."""
    source = Path(config_path)
    # Keep the tag filesystem-safe: device names may contain spaces, slashes, etc.
    tag = re.sub(r"[^a-z0-9._-]", "_", gpu_name.lower())
    # Build the name from the stem so only the file name changes (never a
    # directory component) and the result can never equal the input path.
    return source.with_name(f"{source.stem}_{tag}_optimized.json")


def _shrink_for_small_gpu(transformer_cfg, target_vram_mb, batch_reduction_factor):
    """Reduce batch size (and, for tight targets, sequence length and depth) in place."""
    current_batch = transformer_cfg["BATCH_SIZE"]
    reduced_batch = max(_MIN_BATCH_SIZE, int(current_batch * batch_reduction_factor))
    # Never *increase* the batch size: the floor (or a factor > 1) could
    # otherwise push it above the configured value.
    transformer_cfg["BATCH_SIZE"] = min(current_batch, reduced_batch)

    # If the budget is still tight after the batch reduction, cap the
    # sequence length and the model depth as well.
    if target_vram_mb < _TIGHT_TARGET_MB:
        if transformer_cfg["MAX_SEQ_LENGTH"] > _TIGHT_MAX_SEQ_LENGTH:
            transformer_cfg["MAX_SEQ_LENGTH"] = _TIGHT_MAX_SEQ_LENGTH
        if transformer_cfg["NUM_LAYERS"] > _TIGHT_MAX_NUM_LAYERS:
            transformer_cfg["NUM_LAYERS"] = _TIGHT_MAX_NUM_LAYERS


def optimize_config_for_gpu(config_path, target_vram_mb=None, batch_reduction_factor=0.5):
    """
    Optimize config settings for the available GPU memory.

    The input file is never modified; the tuned configuration is written to a
    sibling file named ``<stem>_<device>_optimized.json``.

    Args:
        config_path: Path (str or Path) to the config.json file
        target_vram_mb: Target VRAM usage in MB (if None, will use 80% of the
            detected VRAM, or 2048 when no GPU is available)
        batch_reduction_factor: How much to reduce batch size (0.5 = half)

    Returns:
        Path of the optimized config file, as a string.

    Raises:
        KeyError: If the config lacks TRANSFORMER_CONFIG or one of its
            BATCH_SIZE / MAX_SEQ_LENGTH / NUM_LAYERS entries.
        OSError: If the config cannot be read or the result cannot be written.
    """
    # Query CUDA once and reuse the answer for both detection and tuning.
    has_gpu = torch.cuda.is_available()
    total_memory = None

    if has_gpu:
        gpu_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
        total_memory = gpu_properties.total_memory / (1024 * 1024)  # Convert to MB
        gpu_name = gpu_properties.name
        logger.info("GPU detected: %s with %.0fMB VRAM", gpu_name, total_memory)
        # Set target memory if not specified (80% of available)
        if target_vram_mb is None:
            target_vram_mb = int(total_memory * 0.8)
    else:
        logger.warning("No GPU detected, using conservative settings for CPU")
        gpu_name = "CPU"
        # Only fall back to the default when the caller gave no explicit target.
        if target_vram_mb is None:
            target_vram_mb = _CPU_TARGET_MB

    logger.info("Target VRAM usage: %sMB", target_vram_mb)

    # Load current config
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    transformer_cfg = config["TRANSFORMER_CONFIG"]

    # Store original values for reporting
    original_batch_size = transformer_cfg["BATCH_SIZE"]
    original_sequence_length = transformer_cfg["MAX_SEQ_LENGTH"]
    original_num_layers = transformer_cfg["NUM_LAYERS"]

    # RTX 4050 class devices (around 6GB VRAM) need significant reductions.
    if has_gpu and ("4050" in gpu_name or total_memory < _SMALL_GPU_TOTAL_MB):
        _shrink_for_small_gpu(transformer_cfg, target_vram_mb, batch_reduction_factor)

    # Enable gradient checkpointing (trades compute for memory) and mixed precision
    optimization_cfg = config.setdefault("OPTIMIZATION", {})
    optimization_cfg["USE_GRADIENT_CHECKPOINTING"] = True
    optimization_cfg["USE_MIXED_PRECISION"] = True

    # Save optimized config next to the source config
    opt_config_path = str(_optimized_config_path(config_path, gpu_name))
    with open(opt_config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # Report changes
    logger.info("Optimized configuration saved to: %s", opt_config_path)
    logger.info("Changes made:")
    logger.info(" - Batch size: %s → %s", original_batch_size, transformer_cfg["BATCH_SIZE"])
    logger.info(" - Sequence length: %s → %s", original_sequence_length, transformer_cfg["MAX_SEQ_LENGTH"])
    logger.info(" - Num layers: %s → %s", original_num_layers, transformer_cfg["NUM_LAYERS"])
    logger.info(" - Gradient checkpointing: Enabled")
    logger.info(" - Mixed precision: Enabled")

    return opt_config_path


def apply_optimized_config(opt_config_path, main_config_path=None):
    """
    Apply the optimized config by updating the main config file.

    The main config is backed up to ``<main_config_path>.backup`` the first
    time it is overwritten; an existing backup is never replaced, so it keeps
    the earliest saved state.

    Args:
        opt_config_path: Path to the optimized config file to apply
        main_config_path: Config file to overwrite. If None, the legacy
            location is used: ``config.json`` in the directory two levels
            above the optimized file.

    Raises:
        FileNotFoundError: If the optimized config is missing, or if the main
            config is missing and no backup exists yet.
    """
    # Load optimized config
    with open(opt_config_path, "r", encoding="utf-8") as f:
        opt_config = json.load(f)

    # Get main config path
    if main_config_path is None:
        # Legacy default, kept for callers that rely on this layout.
        main_config_path = os.path.join(
            os.path.dirname(os.path.dirname(opt_config_path)), "config.json"
        )
    main_config_path = os.fspath(main_config_path)

    # Backup original config (only once, so the pristine version survives re-runs)
    backup_path = main_config_path + ".backup"
    if not os.path.exists(backup_path):
        shutil.copy2(main_config_path, backup_path)
        logger.info("Original config backed up to: %s", backup_path)

    # Apply optimized config
    with open(main_config_path, "w", encoding="utf-8") as f:
        json.dump(opt_config, f, indent=2)

    logger.info("Applied optimized config to %s", main_config_path)


def main():
    """Parse command-line arguments, optimize the config and optionally apply it."""
    parser = argparse.ArgumentParser(description="Optimize transformer config for GPU memory constraints")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config file")
    parser.add_argument("--apply", action="store_true", help="Apply optimized config")
    parser.add_argument("--batch-factor", type=float, default=0.5, help="Batch size reduction factor")
    parser.add_argument("--target-vram", type=int, default=None, help="Target VRAM usage in MB")
    args = parser.parse_args()

    # Resolve config path: relative paths are taken from the directory above
    # the one containing this script.
    if not os.path.isabs(args.config):
        config_dir = Path(__file__).resolve().parent.parent
        config_path = os.path.join(config_dir, args.config)
    else:
        config_path = args.config

    # Optimize config
    opt_config_path = optimize_config_for_gpu(
        config_path,
        target_vram_mb=args.target_vram,
        batch_reduction_factor=args.batch_factor,
    )

    # Apply if requested. The optimized file is written beside the source
    # config, so the file to update is the one that was just optimized.
    if args.apply:
        apply_optimized_config(opt_config_path, main_config_path=config_path)


if __name__ == "__main__":
    main()