File size: 3,697 Bytes
4eae728 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
#!/usr/bin/env python3
"""
Test script to verify all critical fixes have been applied to run_dpo.py
"""
import sys
from pathlib import Path
# Status markers, written as escapes so this source stays ASCII-only and the
# glyphs cannot be corrupted if the file is re-encoded or copy-pasted.
_PASS = "\u2705"        # white heavy check mark
_FAIL = "\u274c"        # cross mark
_WARN = "\u26a0\ufe0f"  # warning sign (emoji presentation)


def check_fixes(filepath="run_dpo.py") -> bool:
    """Check if all critical fixes are present in run_dpo.py.

    The target file is read as text and scanned for fixed substrings; it is
    never imported or executed.  Because the checks are plain substring
    tests, they confirm that a snippet occurs *somewhere* in the file, not
    that it sits in a particular function.

    A report is printed to stdout: one line per check, followed by a summary
    of passed / failed / warning counts.

    Args:
        filepath: Path (str or ``Path``) of the script to inspect.  Defaults
            to ``run_dpo.py`` in the current working directory.

    Returns:
        True when no check failed (warnings do not count as failures),
        False when the file is missing or at least one check failed.
    """
    filepath = Path(filepath)
    # is_file() rather than exists(): a directory with this name cannot be
    # read and should be reported like a missing file instead of crashing.
    if not filepath.is_file():
        print(f"{_FAIL} Error: {filepath} not found")
        return False

    # Python sources are UTF-8 by default; do not depend on the locale
    # encoding.  Undecodable bytes are replaced because they are irrelevant
    # to the ASCII substring checks below.
    content = filepath.read_text(encoding="utf-8", errors="replace")

    checks = []

    def has(*needles):
        """Return True when every needle occurs in the file content."""
        return all(needle in content for needle in needles)

    def record(ok, passed_msg, failed_msg, failed_status=_FAIL):
        """Append one (status, message) entry for a check outcome."""
        if ok:
            checks.append((_PASS, passed_msg))
        else:
            checks.append((failed_status, failed_msg))

    # Check 1: Memory cleanup imports
    record(has("import gc"), "gc import added", "gc import missing")

    # Check 2: Logging setup
    record(
        has("import logging", "logging.basicConfig"),
        "Logging setup configured",
        "Logging setup missing",
    )

    # Check 3: Custom exceptions
    record(
        has("class DataFormattingError", "class DataValidationError"),
        "Custom exceptions defined",
        "Custom exceptions missing",
    )

    # Check 4: Data validation function.  The call-site check is only
    # meaningful (and only reported) when the function is defined.
    if has("def validate_dpo_data"):
        checks.append((_PASS, "Data validation function defined"))
        record(
            has("validate_dpo_data(formatted_train"),
            "Data validation called for train",
            "Data validation not called",
        )
    else:
        checks.append((_FAIL, "Data validation function missing"))

    # Check 5: Memory cleanup in merge_adapter
    record(
        has("del base", "gc.collect()"),
        "Memory cleanup in merge_adapter",
        "Memory cleanup missing",
    )

    # Check 6: TRL version check
    record(
        has("version.parse(trl.__version__)"),
        "TRL version check added",
        "TRL version check missing",
    )

    # Check 7: Error handling in format function
    record(
        has("except (DataFormattingError, Exception) as e:"),
        "Error handling in format function",
        "Error handling missing",
    )

    # Check 8: Logger usage (replaced some prints)
    record(
        has("logger.info", "logger.warning"),
        "Logger used for logging",
        "Logger not properly used",
    )

    # Check 9: Gradient norm validation (should be in TrainingArguments).
    # Absence is only a warning and does not affect the return value.
    record(
        has("max_grad_norm"),
        "Gradient norm parameter present",
        "Gradient norm parameter not found",
        failed_status=_WARN,
    )

    # Print results
    print("=" * 80)
    print("DPO TRAINER - FIX VERIFICATION")
    print("=" * 80)
    for status, message in checks:
        print(f"{status} {message}")
    print("=" * 80)

    # Summary
    statuses = [status for status, _ in checks]
    passed = statuses.count(_PASS)
    failed = statuses.count(_FAIL)
    warnings = statuses.count(_WARN)
    print(f"\nSummary: {passed} passed, {failed} failed, {warnings} warnings")

    if failed == 0:
        print(f"\n{_PASS} All critical fixes verified successfully!")
        print("\nYou can now proceed with training:")
        print(" python run_dpo.py --config config_dpo.yaml")
        return True

    print(f"\n{_FAIL} Some fixes are missing. Please review the implementation.")
    return False
if __name__ == "__main__":
    # Script entry point: run the verification and mirror its outcome in the
    # process exit status (0 = no failed checks, 1 = otherwise).
    success = check_fixes()
    if success:
        sys.exit(0)
    sys.exit(1)
|