|
|
"""
|
|
|
Utility to optimize transformer configuration for GPU memory constraints.
|
|
|
"""
|
|
|
import os
|
|
|
import json
|
|
|
import torch
|
|
|
import logging
|
|
|
from pathlib import Path
|
|
|
import argparse
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time; timestamps + logger name make
# multi-module runs traceable. Modules below use the per-module logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

logger = logging.getLogger(__name__)
|
|
|
|
|
|
def optimize_config_for_gpu(config_path, target_vram_mb=None, batch_reduction_factor=0.5):
    """Optimize transformer config settings for the available GPU memory.

    Reads the JSON config at ``config_path``, shrinks batch size, sequence
    length, and layer count when the detected GPU is memory-constrained,
    always enables gradient checkpointing and mixed precision, and writes the
    result to a new ``*_<gpu>_optimized.json`` file next to the original.

    Args:
        config_path: Path to the config.json file (str or path-like).
        target_vram_mb: Target VRAM usage in MB. If None, 80% of the detected
            GPU's total memory is used. On CPU a conservative 2048MB target is
            forced regardless of this argument.
        batch_reduction_factor: Multiplier applied to the batch size on
            memory-constrained GPUs (0.5 = half); the result is floored at 4.

    Returns:
        Path (str) of the optimized config file that was written.

    Raises:
        KeyError: If the config lacks the expected TRANSFORMER_CONFIG keys.
        OSError: If the config file cannot be read or written.
    """
    # Accept pathlib.Path as well as str (str input is unchanged).
    config_path = str(config_path)

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        gpu_properties = torch.cuda.get_device_properties(device)
        # total_memory is reported in bytes; convert to MB.
        total_memory = gpu_properties.total_memory / (1024 * 1024)
        gpu_name = gpu_properties.name
        logger.info(f"GPU detected: {gpu_name} with {total_memory:.0f}MB VRAM")

        if target_vram_mb is None:
            # Leave 20% headroom for the CUDA context, fragmentation, and
            # activation spikes.
            target_vram_mb = int(total_memory * 0.8)
    else:
        logger.warning("No GPU detected, using conservative settings for CPU")
        target_vram_mb = 2048
        gpu_name = "CPU"

    logger.info(f"Target VRAM usage: {target_vram_mb}MB")

    # JSON is UTF-8 by spec; pin the encoding so a non-UTF-8 locale default
    # (e.g. cp1252 on Windows) cannot corrupt the round-trip.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Snapshot originals before mutation so the change log below is accurate.
    original_batch_size = config["TRANSFORMER_CONFIG"]["BATCH_SIZE"]
    original_sequence_length = config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"]
    original_num_layers = config["TRANSFORMER_CONFIG"]["NUM_LAYERS"]

    if torch.cuda.is_available():
        # Small laptop GPUs (RTX 4050 class, or anything under ~7GB) cannot
        # hold the default batch size; scale it down but never below 4.
        if "4050" in gpu_name or total_memory < 7000:
            config["TRANSFORMER_CONFIG"]["BATCH_SIZE"] = max(
                4, int(original_batch_size * batch_reduction_factor)
            )

        if target_vram_mb < 5500:
            # Attention memory grows quadratically with sequence length, so
            # cap it first; then cap depth.
            if config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"] > 256:
                config["TRANSFORMER_CONFIG"]["MAX_SEQ_LENGTH"] = 256
            if config["TRANSFORMER_CONFIG"]["NUM_LAYERS"] > 6:
                config["TRANSFORMER_CONFIG"]["NUM_LAYERS"] = 6

    # Memory-saving training options are always enabled, GPU or not.
    config.setdefault("OPTIMIZATION", {})
    config["OPTIMIZATION"]["USE_GRADIENT_CHECKPOINTING"] = True
    config["OPTIMIZATION"]["USE_MIXED_PRECISION"] = True

    gpu_name_simple = gpu_name.replace(" ", "_").lower()
    # Use splitext rather than str.replace(".json", ...): replace() would also
    # match a ".json" occurring earlier in the path (e.g. a directory name)
    # and, for a file without the suffix, would silently return the input path
    # and overwrite the original config.
    root, ext = os.path.splitext(config_path)
    opt_config_path = f"{root}_{gpu_name_simple}_optimized{ext}"

    with open(opt_config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    logger.info(f"Optimized configuration saved to: {opt_config_path}")
    logger.info("Changes made:")
    logger.info(f"  - Batch size: {original_batch_size} → {config['TRANSFORMER_CONFIG']['BATCH_SIZE']}")
    logger.info(f"  - Sequence length: {original_sequence_length} → {config['TRANSFORMER_CONFIG']['MAX_SEQ_LENGTH']}")
    logger.info(f"  - Num layers: {original_num_layers} → {config['TRANSFORMER_CONFIG']['NUM_LAYERS']}")
    logger.info(f"  - Gradient checkpointing: Enabled")
    logger.info(f"  - Mixed precision: Enabled")

    return opt_config_path
|
|
|
|
|
|
def apply_optimized_config(opt_config_path):
    """Apply an optimized config by overwriting the project's main config file.

    The main config is resolved as ``config.json`` in the grandparent
    directory of ``opt_config_path`` — NOTE(review): this assumes the
    optimized file lives one subdirectory below the project root; confirm
    against the caller's layout. The pristine main config is backed up once
    to ``config.json.backup``; subsequent runs keep that first backup.

    Args:
        opt_config_path: Path to a previously generated optimized config file.

    Raises:
        FileNotFoundError: If either config file is missing.
        OSError: On any other read/write failure.
    """
    import shutil  # only needed here; kept function-local as in the original

    # JSON is UTF-8 by spec; pin the encoding so the locale default cannot
    # corrupt non-ASCII content on read or write.
    with open(opt_config_path, 'r', encoding='utf-8') as f:
        opt_config = json.load(f)

    main_config_path = os.path.join(os.path.dirname(os.path.dirname(opt_config_path)), "config.json")

    # Back up the original exactly once so repeated applies never clobber it.
    backup_path = main_config_path + '.backup'
    if not os.path.exists(backup_path):
        shutil.copy2(main_config_path, backup_path)
        logger.info(f"Original config backed up to: {backup_path}")

    with open(main_config_path, 'w', encoding='utf-8') as f:
        json.dump(opt_config, f, indent=2)

    logger.info(f"Applied optimized config to {main_config_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: optimize the config for the local GPU and optionally
    # apply it in place.
    parser = argparse.ArgumentParser(description="Optimize transformer config for GPU memory constraints")
    parser.add_argument("--config", type=str, default="config.json", help="Path to config file")
    parser.add_argument("--apply", action="store_true", help="Apply optimized config")
    parser.add_argument("--batch-factor", type=float, default=0.5, help="Batch size reduction factor")
    parser.add_argument("--target-vram", type=int, default=None, help="Target VRAM usage in MB")
    args = parser.parse_args()

    # Relative paths are resolved against the project root (one directory
    # above this script); absolute paths are used verbatim.
    config_path = (
        args.config
        if os.path.isabs(args.config)
        else os.path.join(Path(__file__).resolve().parent.parent, args.config)
    )

    optimized_path = optimize_config_for_gpu(
        config_path,
        target_vram_mb=args.target_vram,
        batch_reduction_factor=args.batch_factor,
    )

    if args.apply:
        apply_optimized_config(optimized_path)