SirajRLX's picture
Upload folder using huggingface_hub
4eae728 verified
#!/usr/bin/env python3
"""
Test script to verify all critical fixes have been applied to run_dpo.py
"""
import sys
from pathlib import Path
def check_fixes(filepath="run_dpo.py"):
    """Check whether all critical fixes are present in the DPO training script.

    The script is scanned as plain text for a fixed set of marker substrings
    (imports, class/function definitions, call sites). Nothing is imported or
    executed, so a marker that appears only inside a comment or a string
    literal still counts as present.

    Args:
        filepath: Path of the script to inspect. Defaults to ``run_dpo.py``
            relative to the current working directory.

    Returns:
        True when no check failed (warnings are not failures); False when the
        file does not exist or at least one check failed.
    """
    passed_mark, failed_mark, warning_mark = "✅", "❌", "⚠️"

    filepath = Path(filepath)
    # is_file() rather than exists(): a directory with this name cannot be read.
    if not filepath.is_file():
        print(f"❌ Error: {filepath} not found")
        return False

    # Decode explicitly as UTF-8 so the result does not depend on the platform
    # locale. Every marker searched for below is pure ASCII, so replacing any
    # undecodable byte cannot change the outcome of a check.
    content = filepath.read_text(encoding="utf-8", errors="replace")

    checks = []

    def record(ok, ok_message, bad_message, bad_mark=failed_mark):
        """Append one (status, message) result and return ``ok`` unchanged."""
        if ok:
            checks.append((passed_mark, ok_message))
        else:
            checks.append((bad_mark, bad_message))
        return ok

    # Check 1: Memory cleanup imports
    record("import gc" in content, "gc import added", "gc import missing")

    # Check 2: Logging setup
    record(
        "import logging" in content and "logging.basicConfig" in content,
        "Logging setup configured",
        "Logging setup missing",
    )

    # Check 3: Custom exceptions
    record(
        "class DataFormattingError" in content
        and "class DataValidationError" in content,
        "Custom exceptions defined",
        "Custom exceptions missing",
    )

    # Check 4: Data validation function; the call-site check is only
    # meaningful (and only reported) when the function itself is defined.
    if record(
        "def validate_dpo_data" in content,
        "Data validation function defined",
        "Data validation function missing",
    ):
        record(
            "validate_dpo_data(formatted_train" in content,
            "Data validation called for train",
            "Data validation not called",
        )

    # Check 5: Memory cleanup in merge_adapter. Both markers are matched
    # against the whole file, not just the body of merge_adapter.
    record(
        "del base" in content and "gc.collect()" in content,
        "Memory cleanup in merge_adapter",
        "Memory cleanup missing",
    )

    # Check 6: TRL version check
    record(
        "version.parse(trl.__version__)" in content,
        "TRL version check added",
        "TRL version check missing",
    )

    # Check 7: Error handling in format function (exact text of the clause)
    record(
        "except (DataFormattingError, Exception) as e:" in content,
        "Error handling in format function",
        "Error handling missing",
    )

    # Check 8: Logger usage (replaced some prints)
    record(
        "logger.info" in content and "logger.warning" in content,
        "Logger used for logging",
        "Logger not properly used",
    )

    # Check 9: Gradient norm validation (should be in TrainingArguments).
    # A miss here is reported as a warning only and does not fail the run.
    record(
        "max_grad_norm" in content,
        "Gradient norm parameter present",
        "Gradient norm parameter not found",
        bad_mark=warning_mark,
    )

    # Print results
    print("=" * 80)
    print("DPO TRAINER - FIX VERIFICATION")
    print("=" * 80)
    for status, message in checks:
        print(f"{status} {message}")
    print("=" * 80)

    # Summary
    statuses = [status for status, _ in checks]
    passed = statuses.count(passed_mark)
    failed = statuses.count(failed_mark)
    warnings = statuses.count(warning_mark)
    print(f"\nSummary: {passed} passed, {failed} failed, {warnings} warnings")

    if failed == 0:
        print("\n✅ All critical fixes verified successfully!")
        print("\nYou can now proceed with training:")
        print(" python run_dpo.py --config config_dpo.yaml")
        return True

    print("\n❌ Some fixes are missing. Please review the implementation.")
    return False
if __name__ == "__main__":
    # Exit status 0 when check_fixes() reports no failed checks, 1 otherwise.
    sys.exit(0 if check_fixes() else 1)