#!/usr/bin/env python3
"""
Quick fix script to apply critical improvements to run_dpo.py
Run this to automatically patch the DPO trainer with all critical fixes.
"""
import re
import shutil
from pathlib import Path
def backup_file(filepath):
    """Copy *filepath* to ``<filepath>.backup`` and return the backup path.

    Uses :func:`shutil.copy2` so file metadata (timestamps, permissions)
    is preserved along with the contents.

    Args:
        filepath: Path (or str) of the file to back up. Must exist.

    Returns:
        pathlib.Path: Path of the newly created ``.backup`` file.
    """
    backup_path = Path(f"{filepath}.backup")
    shutil.copy2(filepath, backup_path)
    print(f"✅ Backup created: {backup_path}")
    return backup_path
def apply_fixes(filepath='run_dpo.py'):
    """Apply all critical fixes to the DPO training script.

    Patches the target file in place (after writing a ``.backup`` copy via
    :func:`backup_file`):

      1. add missing ``gc``/``logging`` imports
      2. add logging setup plus custom exception classes
      3. add a DPO data-validation helper
      4. add memory cleanup around adapter merging
      5. add a TRL version check
      6. swap a handful of ``print`` calls for ``logger`` calls

    Each fix is recorded in the summary ONLY when its anchor text was
    actually found and replaced (``str.replace`` is a silent no-op when the
    anchor is absent, so a blind append would over-report).

    NOTE(review): the embedded anchor/replacement snippets are whitespace-
    sensitive — they must match run_dpo.py byte-for-byte for the textual
    patch to fire. Verify their indentation against the real run_dpo.py.

    Args:
        filepath: Path of the script to patch (default ``run_dpo.py``).

    Returns:
        bool: True if the file was found and processed, False otherwise.
    """
    filepath = Path(filepath)
    if not filepath.exists():
        print(f"❌ Error: {filepath} not found")
        return False

    # Backup original before touching anything.
    backup_file(filepath)

    # Explicit encoding: the target file contains non-ASCII text (emoji),
    # so relying on the platform default (e.g. cp1252 on Windows) can fail.
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    fixes_applied = []

    def _patch(old, new, note):
        # Replace `old` with `new` in `content`; record `note` only when the
        # anchor text was actually present and something changed.
        nonlocal content
        updated = content.replace(old, new)
        changed = updated != content
        if changed:
            content = updated
            fixes_applied.append(note)
        return changed

    # Fix 1: Add missing imports
    if 'import gc' not in content:
        _patch(
            'import time\nfrom pathlib',
            'import gc\nimport time\nimport logging\nfrom pathlib',
            "Added gc and logging imports",
        )

    # Fix 2: Add logging setup
    if 'logging.basicConfig' not in content:
        _patch(
            'wandb = None\n\n\n# --------------------------\n# Helpers',
            '''wandb = None


# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --------------------------
# Custom Exceptions
# --------------------------
class DataFormattingError(Exception):
    """Exception raised for errors in data formatting."""
    pass


class DataValidationError(Exception):
    """Exception raised for errors in data validation."""
    pass


# --------------------------
# Helpers''',
            "Added logging setup and custom exceptions",
        )

    # Fix 3: Add validation function (inserted just before build_dpo_datasets)
    if 'def validate_dpo_data' not in content:
        validation_func = '''
def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate DPO dataset has all required fields and proper structure.

    Args:
        dataset: Dataset to validate
        stage: Training stage ("train" or "eval")

    Raises:
        DataValidationError if validation fails
    """
    required_fields = ["prompt", "chosen", "rejected"]

    # Check required fields exist
    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )

    # Sample validation - check first example
    if len(dataset) > 0:
        sample = dataset[0]
        for field in required_fields:
            if not sample[field] or len(sample[field].strip()) == 0:
                logger.warning(f"{stage} dataset has empty {field} in first example")

    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")


'''
        _patch(
            'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
            validation_func + 'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
            "Added data validation function",
        )

    # Fix 4: Improve merge_adapter with memory cleanup
    old_merge = '''    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    tok = AutoTokenizer.from_pretrained('''
    new_merge = '''    # Clean up base model to free memory
    del base
    gc.collect()
    torch.cuda.empty_cache()

    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    # Clean up merged model
    del merged
    gc.collect()
    torch.cuda.empty_cache()

    tok = AutoTokenizer.from_pretrained('''
    if 'del base' not in content:
        _patch(old_merge, new_merge, "Added memory cleanup in merge_adapter")

    # Fix 5: Add TRL version check
    if 'version.parse(trl.__version__)' not in content:
        _patch(
            'from trl import DPOTrainer, DPOConfig',
            '''from trl import DPOTrainer, DPOConfig

# Version check for TRL
try:
    from packaging import version
    import trl
    if version.parse(trl.__version__) < version.parse("0.7.0"):
        print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError:
    print("Warning: Could not verify TRL version")''',
            "Added TRL version check",
        )

    # Fix 6: Replace some critical print statements with logger
    print_to_logger = [
        ('print(f"Using local model at:', 'logger.info(f"Using local model at:'),
        ('print(f"Loading reference model', 'logger.info(f"Loading reference model'),
        ('print(f"DPO Training with beta', 'logger.info(f"DPO Training with beta'),
        ('print(f"Resuming from', 'logger.info(f"Resuming from'),
        ('print("Starting DPO training', 'logger.info("Starting DPO training'),
        ('print(f"Saved best adapter', 'logger.info(f"Saved best adapter'),
    ]
    swapped = False
    for old, new in print_to_logger:
        updated = content.replace(old, new)
        swapped = swapped or updated != content
        content = updated
    if swapped:
        fixes_applied.append("Replaced print with logger calls")

    # Write fixed content back (same explicit encoding as the read).
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)

    print("\n" + "=" * 80)
    print("DPO TRAINER - FIXES APPLIED")
    print("=" * 80)
    for i, fix in enumerate(fixes_applied, 1):
        print(f"{i}. ✅ {fix}")
    print("=" * 80)
    print(f"\n✅ All fixes applied successfully to {filepath}")
    print(f"📁 Original backed up to {filepath}.backup")
    print("\nTo verify: python run_dpo.py --config config_dpo.yaml")
    return True
if __name__ == "__main__":
    import sys

    # Optional positional argument: path of the script to patch.
    filepath = sys.argv[1] if len(sys.argv) > 1 else "run_dpo.py"

    banner = "=" * 80
    print("DPO Trainer - Quick Fix Script")
    print(banner)
    print("This script will apply the following critical fixes:")
    for line in (
        "  1. Add memory cleanup (gc.collect, torch.cuda.empty_cache)",
        "  2. Add logging setup",
        "  3. Add custom exceptions (DataFormattingError, DataValidationError)",
        "  4. Add data validation function",
        "  5. Add TRL version check",
        "  6. Replace print with logger",
    ):
        print(line)
    print(banner)
    print()

    # Require explicit confirmation before rewriting the target file.
    response = input("Apply fixes? [y/N]: ")
    if response.lower() == 'y':
        apply_fixes(filepath)
    else:
        print("Cancelled")