File size: 6,779 Bytes
4eae728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#!/usr/bin/env python3
"""
Quick fix script to apply critical improvements to run_dpo.py
Run this to automatically patch the DPO trainer with all critical fixes.
"""

import re
import shutil
from pathlib import Path

def backup_file(filepath):
    """Copy *filepath* to a sibling ``<name>.backup`` file.

    Args:
        filepath: Path (str or Path) of the file to preserve.

    Returns:
        Path of the newly created backup file.
    """
    dest = Path(f"{filepath}.backup")
    # copy2 keeps metadata (mtime etc.) along with the contents.
    shutil.copy2(filepath, dest)
    print(f"✅ Backup created: {dest}")
    return dest

def apply_fixes(filepath='run_dpo.py'):
    """Apply all critical fixes to the DPO training script.

    Backs up the original file, then patches the source text in place:
    missing imports, logging setup, custom exceptions, a data-validation
    helper, memory cleanup in the adapter merge, a TRL version check, and
    print -> logger replacements. Each textual patch is guarded so running
    the script twice is a no-op for already-applied fixes.

    Args:
        filepath: Path to the run_dpo.py script to patch.

    Returns:
        True if the file was found and patched, False if it does not exist.
    """

    filepath = Path(filepath)
    if not filepath.exists():
        print(f"❌ Error: {filepath} not found")
        return False

    # Backup original
    backup_file(filepath)

    # Explicit encoding: with the platform default (e.g. cp1252 on Windows),
    # reading a UTF-8 source or writing non-ASCII text back would raise a
    # UnicodeDecodeError/UnicodeEncodeError or corrupt the file.
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    fixes_applied = []

    # Fix 1: Add missing imports
    if 'import gc' not in content:
        content = content.replace(
            'import time\nfrom pathlib',
            'import gc\nimport time\nimport logging\nfrom pathlib'
        )
        fixes_applied.append("Added gc and logging imports")

    # Fix 2: Add logging setup
    if 'logging.basicConfig' not in content:
        content = content.replace(
            'wandb = None\n\n\n# --------------------------\n# Helpers',
            '''wandb = None

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --------------------------
# Custom Exceptions
# --------------------------


class DataFormattingError(Exception):
    """Exception raised for errors in data formatting."""
    pass


class DataValidationError(Exception):
    """Exception raised for errors in data validation."""
    pass


# --------------------------
# Helpers'''
        )
        fixes_applied.append("Added logging setup and custom exceptions")

    # Fix 3: Add validation function
    if 'def validate_dpo_data' not in content:
        validation_func = '''

def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate DPO dataset has all required fields and proper structure.
    
    Args:
        dataset: Dataset to validate
        stage: Training stage ("train" or "eval")
    
    Raises:
        DataValidationError if validation fails
    """
    required_fields = ["prompt", "chosen", "rejected"]
    
    # Check required fields exist
    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )
    
    # Sample validation - check first example
    if len(dataset) > 0:
        sample = dataset[0]
        for field in required_fields:
            if not sample[field] or len(sample[field].strip()) == 0:
                logger.warning(f"{stage} dataset has empty {field} in first example")
    
    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")

'''
        # Insert before build_dpo_datasets
        content = content.replace(
            'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
            validation_func + 'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)'
        )
        fixes_applied.append("Added data validation function")

    # Fix 4: Improve merge_adapter with memory cleanup
    old_merge = '''    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    tok = AutoTokenizer.from_pretrained('''

    new_merge = '''    # Clean up base model to free memory
    del base
    gc.collect()
    torch.cuda.empty_cache()

    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    # Clean up merged model
    del merged
    gc.collect()
    torch.cuda.empty_cache()

    tok = AutoTokenizer.from_pretrained('''

    if old_merge in content and 'del base' not in content:
        content = content.replace(old_merge, new_merge)
        fixes_applied.append("Added memory cleanup in merge_adapter")

    # Fix 5: Add TRL version check
    if 'version.parse(trl.__version__)' not in content:
        content = content.replace(
            'from trl import DPOTrainer, DPOConfig',
            '''from trl import DPOTrainer, DPOConfig

# Version check for TRL
try:
    from packaging import version
    import trl
    if version.parse(trl.__version__) < version.parse("0.7.0"):
        print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError:
    print("Warning: Could not verify TRL version")'''
        )
        fixes_applied.append("Added TRL version check")

    # Fix 6: Replace some critical print statements with logger
    print_to_logger = [
        ('print(f"Using local model at:', 'logger.info(f"Using local model at:'),
        ('print(f"Loading reference model', 'logger.info(f"Loading reference model'),
        ('print(f"DPO Training with beta', 'logger.info(f"DPO Training with beta'),
        ('print(f"Resuming from', 'logger.info(f"Resuming from'),
        ('print("Starting DPO training', 'logger.info("Starting DPO training'),
        ('print(f"Saved best adapter', 'logger.info(f"Saved best adapter'),
    ]
    before_replacements = content
    for old, new in print_to_logger:
        content = content.replace(old, new)
    # Only report this fix when something was actually rewritten; the
    # previous version unconditionally claimed it had been applied.
    if content != before_replacements:
        fixes_applied.append("Replaced print with logger calls")

    # Write fixed content (same explicit encoding as the read above)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)

    print("\n" + "="*80)
    print("DPO TRAINER - FIXES APPLIED")
    print("="*80)
    for i, fix in enumerate(fixes_applied, 1):
        print(f"{i}. ✅ {fix}")
    print("="*80)
    print(f"\n✅ All fixes applied successfully to {filepath}")
    print(f"📁 Original backed up to {filepath}.backup")
    print("\nTo verify: python run_dpo.py --config config_dpo.yaml")

    return True

if __name__ == "__main__":
    import sys

    # Optional positional argument: path to the script to patch.
    filepath = sys.argv[1] if len(sys.argv) > 1 else "run_dpo.py"

    print("DPO Trainer - Quick Fix Script")
    print("="*80)
    print("This script will apply the following critical fixes:")
    print("  1. Add memory cleanup (gc.collect, torch.cuda.empty_cache)")
    print("  2. Add logging setup")
    print("  3. Add custom exceptions (DataFormattingError, DataValidationError)")
    print("  4. Add data validation function")
    print("  5. Add TRL version check")
    print("  6. Replace print with logger")
    print("="*80)
    print()

    response = input("Apply fixes? [y/N]: ")
    # Accept "y" or "yes" in any case and ignore surrounding whitespace;
    # the previous strict == 'y' check rejected "yes" and " y ".
    if response.strip().lower() in ('y', 'yes'):
        apply_fixes(filepath)
    else:
        print("Cancelled")