File size: 3,697 Bytes
4eae728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
"""
Test script to verify all critical fixes have been applied to run_dpo.py
"""

import sys
from pathlib import Path

def check_fixes(filepath="run_dpo.py"):
    """Check whether all critical fixes are present in a DPO training script.

    Every check is a plain substring search over the script's source text, so
    it confirms that the expected code *appears* in the file, not that it is
    reachable or behaves correctly.

    Args:
        filepath: Path (``str`` or ``Path``) of the script to inspect.
            Defaults to ``run_dpo.py`` relative to the current working
            directory, matching the original behavior.

    Returns:
        True when no check failed (warnings are allowed); False when the file
        cannot be read or at least one check failed. A report is printed to
        stdout in either case.
    """
    # Status markers, defined once so the markers recorded below and the
    # markers counted in the summary can never drift apart.
    ok, fail, warn = "✅", "❌", "⚠️"

    filepath = Path(filepath)
    try:
        # Decode explicitly as UTF-8: the locale default (e.g. cp1252 on
        # Windows) cannot decode scripts that contain emoji or other
        # non-ASCII text.
        content = filepath.read_text(encoding="utf-8")
    except FileNotFoundError:
        print(f"{fail} Error: {filepath} not found")
        return False
    except (OSError, UnicodeDecodeError) as exc:
        print(f"{fail} Error: could not read {filepath}: {exc}")
        return False

    checks = []

    def record(passed, pass_msg, fail_msg, fail_status=fail):
        # Append one (status, message) row; non-passing rows use fail_status.
        checks.append((ok, pass_msg) if passed else (fail_status, fail_msg))

    # Check 1: Memory cleanup imports
    record('import gc' in content, "gc import added", "gc import missing")

    # Check 2: Logging setup
    record(
        'import logging' in content and 'logging.basicConfig' in content,
        "Logging setup configured",
        "Logging setup missing",
    )

    # Check 3: Custom exceptions
    record(
        'class DataFormattingError' in content
        and 'class DataValidationError' in content,
        "Custom exceptions defined",
        "Custom exceptions missing",
    )

    # Check 4: Data validation function; the call is only checked when the
    # function itself is defined.
    if 'def validate_dpo_data' in content:
        checks.append((ok, "Data validation function defined"))
        record(
            'validate_dpo_data(formatted_train' in content,
            "Data validation called for train",
            "Data validation not called",
        )
    else:
        checks.append((fail, "Data validation function missing"))

    # Check 5: Memory cleanup in merge_adapter
    record(
        'del base' in content and 'gc.collect()' in content,
        "Memory cleanup in merge_adapter",
        "Memory cleanup missing",
    )

    # Check 6: TRL version check
    record(
        'version.parse(trl.__version__)' in content,
        "TRL version check added",
        "TRL version check missing",
    )

    # Check 7: Error handling in format function
    record(
        'except (DataFormattingError, Exception) as e:' in content,
        "Error handling in format function",
        "Error handling missing",
    )

    # Check 8: Logger usage (replaced some prints)
    record(
        'logger.info' in content and 'logger.warning' in content,
        "Logger used for logging",
        "Logger not properly used",
    )

    # Check 9: Gradient norm (should be in TrainingArguments). Advisory only,
    # so a miss is a warning rather than a failure.
    record(
        'max_grad_norm' in content,
        "Gradient norm parameter present",
        "Gradient norm parameter not found",
        fail_status=warn,
    )

    # Print results
    rule = "=" * 80
    print(rule)
    print("DPO TRAINER - FIX VERIFICATION")
    print(rule)
    for status, message in checks:
        print(f"{status} {message}")
    print(rule)

    # Summary
    passed = sum(1 for status, _ in checks if status == ok)
    failed = sum(1 for status, _ in checks if status == fail)
    warned = sum(1 for status, _ in checks if status == warn)

    print(f"\nSummary: {passed} passed, {failed} failed, {warned} warnings")

    if failed:
        print(f"\n{fail} Some fixes are missing. Please review the implementation.")
        return False

    print(f"\n{ok} All critical fixes verified successfully!")
    print("\nYou can now proceed with training:")
    print("  python run_dpo.py --config config_dpo.yaml")
    return True

if __name__ == "__main__":
    # Exit status 0 when every critical check passed, 1 otherwise, so the
    # script can gate CI or a shell pipeline.
    sys.exit(0 if check_fixes() else 1)