#!/usr/bin/env python3
"""
Quick fix script to apply critical improvements to run_dpo.py
Run this to automatically patch the DPO trainer with all critical fixes.
"""
import shutil
from pathlib import Path


def backup_file(filepath):
    """Create backup of original file"""
    backup_path = Path(str(filepath) + '.backup')
    shutil.copy2(filepath, backup_path)
    print(f"✅ Backup created: {backup_path}")
    return backup_path

def apply_fixes(filepath='run_dpo.py'):
    """Apply all critical fixes to the DPO training script"""
    filepath = Path(filepath)
    if not filepath.exists():
        print(f"❌ Error: {filepath} not found")
        return False

    # Backup original
    backup_file(filepath)

    with open(filepath, 'r') as f:
        content = f.read()

    fixes_applied = []
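
    # Each fix below patches run_dpo.py by exact string replacement: it looks
    # for a known snippet from run_dpo.py and substitutes an improved version.
    # If an anchor snippet is missing (e.g. run_dpo.py has drifted from the
    # version this script expects), content.replace() is simply a no-op for
    # that fix.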

    # Fix 1: Add missing imports
    if 'import gc' not in content:
        content = content.replace(
            'import time\nfrom pathlib',
            'import gc\nimport time\nimport logging\nfrom pathlib'
        )
        fixes_applied.append("Added gc and logging imports")

    # Fix 2: Add logging setup
    if 'logging.basicConfig' not in content:
        content = content.replace(
            'wandb = None\n\n\n# --------------------------\n# Helpers',
            '''wandb = None


# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --------------------------
# Custom Exceptions
# --------------------------
class DataFormattingError(Exception):
    """Exception raised for errors in data formatting."""
    pass


class DataValidationError(Exception):
    """Exception raised for errors in data validation."""
    pass


# --------------------------
# Helpers'''
        )
        fixes_applied.append("Added logging setup and custom exceptions")

    # Fix 3: Add validation function
    if 'def validate_dpo_data' not in content:
        validation_func = '''
def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate DPO dataset has all required fields and proper structure.

    Args:
        dataset: Dataset to validate
        stage: Training stage ("train" or "eval")

    Raises:
        DataValidationError if validation fails
    """
    required_fields = ["prompt", "chosen", "rejected"]

    # Check required fields exist
    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )

    # Sample validation - check first example
    if len(dataset) > 0:
        sample = dataset[0]
        for field in required_fields:
            if not sample[field] or len(sample[field].strip()) == 0:
                logger.warning(f"{stage} dataset has empty {field} in first example")

    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")


'''
        # Insert before build_dpo_datasets
        content = content.replace(
            'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
            validation_func + 'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)'
        )
        fixes_applied.append("Added data validation function")

    # Fix 4: Improve merge_adapter with memory cleanup
    old_merge = '''    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )
    tok = AutoTokenizer.from_pretrained('''
    new_merge = '''    # Clean up base model to free memory
    del base
    gc.collect()
    torch.cuda.empty_cache()

    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    # Clean up merged model
    del merged
    gc.collect()
    torch.cuda.empty_cache()

    tok = AutoTokenizer.from_pretrained('''
    if old_merge in content and 'del base' not in content:
        content = content.replace(old_merge, new_merge)
        fixes_applied.append("Added memory cleanup in merge_adapter")

    # Fix 5: Add TRL version check
    if 'version.parse(trl.__version__)' not in content:
        content = content.replace(
            'from trl import DPOTrainer, DPOConfig',
            '''from trl import DPOTrainer, DPOConfig

# Version check for TRL
try:
    from packaging import version
    import trl

    if version.parse(trl.__version__) < version.parse("0.7.0"):
        print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError:
    print("Warning: Could not verify TRL version")'''
        )
        fixes_applied.append("Added TRL version check")

    # Fix 6: Replace some critical print statements with logger
    content = content.replace('print(f"Using local model at:', 'logger.info(f"Using local model at:')
    content = content.replace('print(f"Loading reference model', 'logger.info(f"Loading reference model')
    content = content.replace('print(f"DPO Training with beta', 'logger.info(f"DPO Training with beta')
    content = content.replace('print(f"Resuming from', 'logger.info(f"Resuming from')
    content = content.replace('print("Starting DPO training', 'logger.info("Starting DPO training')
    content = content.replace('print(f"Saved best adapter', 'logger.info(f"Saved best adapter')
    fixes_applied.append("Replaced print with logger calls")

    # Write fixed content
    with open(filepath, 'w') as f:
        f.write(content)

    print("\n" + "="*80)
    print("DPO TRAINER - FIXES APPLIED")
    print("="*80)
    for i, fix in enumerate(fixes_applied, 1):
        print(f"{i}. ✅ {fix}")
    print("="*80)
    print(f"\n✅ All fixes applied successfully to {filepath}")
    print(f"📁 Original backed up to {filepath}.backup")
    print("\nTo verify: python run_dpo.py --config config_dpo.yaml")
    return True
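

# Example (non-interactive use): apply the fixes from another script or a
# notebook instead of the prompt-driven entry point below:
#   from apply_critical_fixes import apply_fixes
#   apply_fixes("path/to/run_dpo.py")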


if __name__ == "__main__":
    import sys

    filepath = sys.argv[1] if len(sys.argv) > 1 else "run_dpo.py"

    print("DPO Trainer - Quick Fix Script")
    print("="*80)
    print("This script will apply the following critical fixes:")
    print("  1. Add memory cleanup (gc.collect, torch.cuda.empty_cache)")
    print("  2. Add logging setup")
    print("  3. Add custom exceptions (DataFormattingError, DataValidationError)")
    print("  4. Add data validation function")
    print("  5. Add TRL version check")
    print("  6. Replace print with logger")
    print("="*80)
    print()

    response = input("Apply fixes? [y/N]: ")
    if response.lower() == 'y':
        apply_fixes(filepath)
    else:
        print("Cancelled")