"""
Test script to verify that all critical fixes have been applied to run_dpo.py.
"""
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
def check_fixes(script_path="run_dpo.py") -> bool:
    """Check whether all critical fixes are present in the DPO training script.

    The target file is scanned for literal marker substrings (it is neither
    parsed nor imported), a per-check report and a summary are printed to
    stdout, and the overall verdict is returned.

    Args:
        script_path: Path of the file to inspect. Defaults to ``run_dpo.py``
            in the current working directory.

    Returns:
        ``True`` when no check failed (a warning does not count as a
        failure); ``False`` when the file is missing or at least one
        critical marker is absent.
    """
    # Status glyphs are written as escapes so they survive the source file
    # being re-saved or transported in a different encoding; the literal
    # emoji had previously degraded into indistinguishable mojibake.
    ok = "\u2705"  # white heavy check mark
    bad = "\u274c"  # cross mark
    warn = "\u26a0\ufe0f"  # warning sign (emoji presentation)

    filepath = Path(script_path)
    if not filepath.is_file():
        print(f"{bad} Error: {filepath} not found")
        return False

    # Every marker searched for is plain ASCII, so undecodable bytes can be
    # replaced safely instead of aborting the whole verification.
    content = filepath.read_text(encoding="utf-8", errors="replace")

    checks = []

    def has(*markers):
        """Return True when every marker substring occurs in the file."""
        return all(marker in content for marker in markers)

    def record(present, ok_message, fail_message, fail_status=bad):
        """Append one (status, message) result to the report."""
        if present:
            checks.append((ok, ok_message))
        else:
            checks.append((fail_status, fail_message))

    record(has("import gc"), "gc import added", "gc import missing")
    record(
        has("import logging", "logging.basicConfig"),
        "Logging setup configured",
        "Logging setup missing",
    )
    record(
        has("class DataFormattingError", "class DataValidationError"),
        "Custom exceptions defined",
        "Custom exceptions missing",
    )

    # The call-site check is only meaningful once the function exists, so a
    # missing definition yields a single failure rather than two.
    if has("def validate_dpo_data"):
        checks.append((ok, "Data validation function defined"))
        record(
            has("validate_dpo_data(formatted_train"),
            "Data validation called for train",
            "Data validation not called",
        )
    else:
        checks.append((bad, "Data validation function missing"))

    record(
        has("del base", "gc.collect()"),
        "Memory cleanup in merge_adapter",
        "Memory cleanup missing",
    )
    record(
        has("version.parse(trl.__version__)"),
        "TRL version check added",
        "TRL version check missing",
    )
    record(
        has("except (DataFormattingError, Exception) as e:"),
        "Error handling in format function",
        "Error handling missing",
    )
    record(
        has("logger.info", "logger.warning"),
        "Logger used for logging",
        "Logger not properly used",
    )
    # A missing gradient-norm setting is reported as a warning only and
    # therefore never fails the verification.
    record(
        has("max_grad_norm"),
        "Gradient norm parameter present",
        "Gradient norm parameter not found",
        fail_status=warn,
    )

    rule = "=" * 80
    print(rule)
    print("DPO TRAINER - FIX VERIFICATION")
    print(rule)

    for status, message in checks:
        print(f"{status} {message}")

    print(rule)

    passed = sum(1 for status, _ in checks if status == ok)
    failed = sum(1 for status, _ in checks if status == bad)
    warnings = sum(1 for status, _ in checks if status == warn)

    print(f"\nSummary: {passed} passed, {failed} failed, {warnings} warnings")

    if failed:
        print(f"\n{bad} Some fixes are missing. Please review the implementation.")
        return False

    print(f"\n{ok} All critical fixes verified successfully!")
    print("\nYou can now proceed with training:")
    print(f" python {filepath} --config config_dpo.yaml")
    return True
|
|
|
|
|
if __name__ == "__main__":
    # Propagate the verification verdict as the process exit status:
    # 0 when every critical check passed, 1 otherwise.
    exit_code = 0 if check_fixes() else 1
    sys.exit(exit_code)
|
|
|