#!/usr/bin/env python3
"""
Quick fix script to apply critical improvements to run_dpo.py

Run this to automatically patch the DPO trainer with all critical fixes.
"""
import re
import shutil
import sys
from pathlib import Path


def backup_file(filepath):
    """Create a backup copy of *filepath* with a ``.backup`` suffix.

    Args:
        filepath: str or Path of the file to back up.

    Returns:
        Path of the newly created backup file.
    """
    backup_path = Path(str(filepath) + '.backup')
    shutil.copy2(filepath, backup_path)
    print(f"āœ… Backup created: {backup_path}")
    return backup_path


def _replace_if_present(content, old, new):
    """Replace *old* with *new* in *content*, reporting whether it happened.

    Returns:
        (updated_content, applied) — *applied* is True only when the anchor
        text *old* was actually found, so callers can report honestly which
        fixes took effect instead of claiming a no-op replacement succeeded.
    """
    if old in content:
        return content.replace(old, new), True
    return content, False


def apply_fixes(filepath='run_dpo.py'):
    """Apply all critical fixes to the DPO training script.

    Each fix is an anchored text substitution on the target file. A fix is
    listed in the final summary only when its anchor pattern was found and
    the substitution actually occurred.

    Args:
        filepath: Path of the script to patch (default: ``run_dpo.py``).

    Returns:
        True if the file was processed and written back, False if it does
        not exist.
    """
    filepath = Path(filepath)

    if not filepath.exists():
        print(f"āŒ Error: {filepath} not found")
        return False

    # Backup original before touching it
    backup_file(filepath)

    # Explicit UTF-8: this script's output and the injected payloads contain
    # non-ASCII characters, so the locale default encoding (e.g. cp1252 on
    # Windows) would raise Unicode errors or corrupt the file.
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    fixes_applied = []

    # Fix 1: Add missing imports
    if 'import gc' not in content:
        content, applied = _replace_if_present(
            content,
            'import time\nfrom pathlib',
            'import gc\nimport time\nimport logging\nfrom pathlib'
        )
        if applied:
            fixes_applied.append("Added gc and logging imports")

    # Fix 2: Add logging setup
    if 'logging.basicConfig' not in content:
        content, applied = _replace_if_present(
            content,
            'wandb = None\n\n\n# --------------------------\n# Helpers',
            '''wandb = None

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --------------------------
# Custom Exceptions
# --------------------------
class DataFormattingError(Exception):
    """Exception raised for errors in data formatting."""
    pass


class DataValidationError(Exception):
    """Exception raised for errors in data validation."""
    pass


# --------------------------
# Helpers'''
        )
        if applied:
            fixes_applied.append("Added logging setup and custom exceptions")

    # Fix 3: Add validation function
    if 'def validate_dpo_data' not in content:
        validation_func = '''
def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate DPO dataset has all required fields and proper structure.

    Args:
        dataset: Dataset to validate
        stage: Training stage ("train" or "eval")

    Raises:
        DataValidationError if validation fails
    """
    required_fields = ["prompt", "chosen", "rejected"]

    # Check required fields exist
    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )

    # Sample validation - check first example
    if len(dataset) > 0:
        sample = dataset[0]
        for field in required_fields:
            if not sample[field] or len(sample[field].strip()) == 0:
                logger.warning(f"{stage} dataset has empty {field} in first example")

    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")


'''
        # Insert immediately before build_dpo_datasets
        content, applied = _replace_if_present(
            content,
            'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
            validation_func + 'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)'
        )
        if applied:
            fixes_applied.append("Added data validation function")

    # Fix 4: Improve merge_adapter with memory cleanup
    old_merge = '''    merged.save_pretrained(
        str(final_dir),
        safe_serialization=True,
        max_shard_size=max_shard_size
    )

    tok = AutoTokenizer.from_pretrained('''

    new_merge = '''    # Clean up base model to free memory
    del base
    gc.collect()
    torch.cuda.empty_cache()

    merged.save_pretrained(
        str(final_dir),
        safe_serialization=True,
        max_shard_size=max_shard_size
    )

    # Clean up merged model
    del merged
    gc.collect()
    torch.cuda.empty_cache()

    tok = AutoTokenizer.from_pretrained('''

    if old_merge in content and 'del base' not in content:
        content = content.replace(old_merge, new_merge)
        fixes_applied.append("Added memory cleanup in merge_adapter")

    # Fix 5: Add TRL version check
    if 'version.parse(trl.__version__)' not in content:
        content, applied = _replace_if_present(
            content,
            'from trl import DPOTrainer, DPOConfig',
            '''from trl import DPOTrainer, DPOConfig

# Version check for TRL
try:
    from packaging import version
    import trl
    if version.parse(trl.__version__) < version.parse("0.7.0"):
        print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError:
    print("Warning: Could not verify TRL version")'''
        )
        if applied:
            fixes_applied.append("Added TRL version check")

    # Fix 6: Replace some critical print statements with logger.
    # Compare before/after so this fix is reported only when at least one
    # of the targeted print calls was actually present.
    before_fix6 = content
    for old, new in (
        ('print(f"Using local model at:', 'logger.info(f"Using local model at:'),
        ('print(f"Loading reference model', 'logger.info(f"Loading reference model'),
        ('print(f"DPO Training with beta', 'logger.info(f"DPO Training with beta'),
        ('print(f"Resuming from', 'logger.info(f"Resuming from'),
        ('print("Starting DPO training', 'logger.info("Starting DPO training'),
        ('print(f"Saved best adapter', 'logger.info(f"Saved best adapter'),
    ):
        content = content.replace(old, new)
    if content != before_fix6:
        fixes_applied.append("Replaced print with logger calls")

    # Write fixed content back (same explicit encoding as the read)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)

    print("\n" + "=" * 80)
    print("DPO TRAINER - FIXES APPLIED")
    print("=" * 80)
    for i, fix in enumerate(fixes_applied, 1):
        print(f"{i}. āœ… {fix}")
    print("=" * 80)
    print(f"\nāœ… All fixes applied successfully to {filepath}")
    print(f"šŸ“ Original backed up to {filepath}.backup")
    print("\nTo verify: python run_dpo.py --config config_dpo.yaml")

    return True


if __name__ == "__main__":
    filepath = sys.argv[1] if len(sys.argv) > 1 else "run_dpo.py"

    print("DPO Trainer - Quick Fix Script")
    print("=" * 80)
    print("This script will apply the following critical fixes:")
    print("  1. Add memory cleanup (gc.collect, torch.cuda.empty_cache)")
    print("  2. Add logging setup")
    print("  3. Add custom exceptions (DataFormattingError, DataValidationError)")
    print("  4. Add data validation function")
    print("  5. Add TRL version check")
    print("  6. Replace print with logger")
    print("=" * 80)
    print()

    # Accept both "y" and "yes", case-insensitively; anything else cancels.
    response = input("Apply fixes? [y/N]: ")
    if response.lower() in ('y', 'yes'):
        apply_fixes(filepath)
    else:
        print("Cancelled")