#!/usr/bin/env python3
"""
Test script to verify all critical fixes have been applied to run_dpo.py
"""

import sys
from pathlib import Path


def check_fixes(filepath="run_dpo.py") -> bool:
    """Check if all critical fixes are present in run_dpo.py.

    The check is purely textual: the target file is read as text and searched
    for fixed marker substrings. The file is never imported or executed, so a
    marker appearing anywhere (including inside a comment or a string) counts
    as present.

    Args:
        filepath: Path (``str`` or ``pathlib.Path``) of the script to
            inspect. Defaults to ``"run_dpo.py"`` relative to the current
            working directory.

    Returns:
        ``True`` when no check is reported as failed (warnings do not count
        as failures); ``False`` when at least one check failed or the file
        does not exist.

    Side effects:
        Prints a per-check report and a summary to stdout.
    """
    filepath = Path(filepath)
    if not filepath.exists():
        print(f"❌ Error: {filepath} not found")
        return False

    # Decode explicitly as UTF-8 instead of the locale's default encoding so
    # the result does not depend on the platform. Every marker searched for
    # below is plain ASCII, so replacing undecodable bytes cannot change the
    # outcome of a check -- it only prevents a crash on a stray byte.
    content = filepath.read_text(encoding="utf-8", errors="replace")

    checks = []

    def record(ok, passed_msg, failed_msg, failed_mark="❌"):
        """Append one (status, message) row depending on *ok*."""
        checks.append(("✅", passed_msg) if ok else (failed_mark, failed_msg))

    # Check 1: Memory cleanup imports
    record('import gc' in content, "gc import added", "gc import missing")

    # Check 2: Logging setup
    record(
        'import logging' in content and 'logging.basicConfig' in content,
        "Logging setup configured",
        "Logging setup missing",
    )

    # Check 3: Custom exceptions
    record(
        'class DataFormattingError' in content
        and 'class DataValidationError' in content,
        "Custom exceptions defined",
        "Custom exceptions missing",
    )

    # Check 4: Data validation function. The "called" sub-check is only
    # reported when the function definition itself is present.
    if 'def validate_dpo_data' in content:
        checks.append(("✅", "Data validation function defined"))
        record(
            'validate_dpo_data(formatted_train' in content,
            "Data validation called for train",
            "Data validation not called",
        )
    else:
        checks.append(("❌", "Data validation function missing"))

    # Check 5: Memory cleanup in merge_adapter (searched file-wide, not
    # scoped to a particular function)
    record(
        'del base' in content and 'gc.collect()' in content,
        "Memory cleanup in merge_adapter",
        "Memory cleanup missing",
    )

    # Check 6: TRL version check
    record(
        'version.parse(trl.__version__)' in content,
        "TRL version check added",
        "TRL version check missing",
    )

    # Check 7: Error handling in format function (exact-text match of the
    # except clause)
    record(
        'except (DataFormattingError, Exception) as e:' in content,
        "Error handling in format function",
        "Error handling missing",
    )

    # Check 8: Logger usage (replaced some prints)
    record(
        'logger.info' in content and 'logger.warning' in content,
        "Logger used for logging",
        "Logger not properly used",
    )

    # Check 9: Gradient norm validation (should be in TrainingArguments).
    # Absence is only a warning, so it never fails the run.
    record(
        'max_grad_norm' in content,
        "Gradient norm parameter present",
        "Gradient norm parameter not found",
        failed_mark="⚠️",
    )

    # Print results
    print("="*80)
    print("DPO TRAINER - FIX VERIFICATION")
    print("="*80)

    for status, message in checks:
        print(f"{status} {message}")

    print("="*80)

    # Summary
    passed = sum(1 for s, _ in checks if s == "✅")
    failed = sum(1 for s, _ in checks if s == "❌")
    warnings = sum(1 for s, _ in checks if s == "⚠️")

    print(f"\nSummary: {passed} passed, {failed} failed, {warnings} warnings")

    if failed == 0:
        print("\n✅ All critical fixes verified successfully!")
        print("\nYou can now proceed with training:")
        print(" python run_dpo.py --config config_dpo.yaml")
        return True

    print("\n❌ Some fixes are missing. Please review the implementation.")
    return False


if __name__ == "__main__":
    success = check_fixes()
    # Exit status 0 when every critical check passed, 1 otherwise.
    sys.exit(0 if success else 1)