"""
Quick fix script to apply critical improvements to run_dpo.py
Run this to automatically patch the DPO trainer with all critical fixes.
"""


import re
import shutil
from pathlib import Path
|
|
def backup_file(filepath):
    """Copy *filepath* to a sibling ``<name>.backup`` file and return its path.

    Args:
        filepath: Path (or string) of the file to back up.

    Returns:
        Path of the newly created backup file.
    """
    destination = Path(f"{filepath}.backup")
    # copy2 preserves file metadata (mtime, permissions) along with contents.
    shutil.copy2(filepath, destination)
    print(f"✅ Backup created: {destination}")
    return destination
|
|
def apply_fixes(filepath='run_dpo.py'):
    """Apply all critical fixes to the DPO training script.

    Patches the target file in place (after creating a ``.backup`` copy):
    adds imports, logging setup, custom exceptions, a data-validation
    helper, memory cleanup in the adapter merge, a TRL version check,
    and converts selected ``print`` calls to ``logger`` calls.

    Args:
        filepath: Path to the script to patch (default ``run_dpo.py``).

    Returns:
        True if the file was found and processed, False otherwise.
    """
    filepath = Path(filepath)
    if not filepath.exists():
        print(f"❌ Error: {filepath} not found")
        return False

    backup_file(filepath)

    # Explicit encoding: the target may contain non-ASCII text and the
    # platform default codec (e.g. cp1252 on Windows) would corrupt it.
    content = filepath.read_text(encoding='utf-8')

    fixes_applied = []

    def _swap(old, new, description):
        """Replace old->new in content; record the fix ONLY if it matched.

        The previous version appended the description unconditionally, so
        the summary claimed fixes were applied even when the anchor string
        was absent and str.replace() was a no-op.
        """
        nonlocal content
        if old in content:
            content = content.replace(old, new)
            fixes_applied.append(description)
            return True
        return False

    # Fix 1: gc / logging imports.
    if 'import gc' not in content:
        _swap(
            'import time\nfrom pathlib',
            'import gc\nimport time\nimport logging\nfrom pathlib',
            "Added gc and logging imports",
        )

    # Fix 2: logging setup and custom exception classes.
    if 'logging.basicConfig' not in content:
        _swap(
            'wandb = None\n\n\n# --------------------------\n# Helpers',
            '''wandb = None

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --------------------------
# Custom Exceptions
# --------------------------


class DataFormattingError(Exception):
    """Exception raised for errors in data formatting."""
    pass


class DataValidationError(Exception):
    """Exception raised for errors in data validation."""
    pass


# --------------------------
# Helpers''',
            "Added logging setup and custom exceptions",
        )

    # Fix 3: dataset validation helper inserted before build_dpo_datasets.
    if 'def validate_dpo_data' not in content:
        validation_func = '''

def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate DPO dataset has all required fields and proper structure.

    Args:
        dataset: Dataset to validate
        stage: Training stage ("train" or "eval")

    Raises:
        DataValidationError if validation fails
    """
    required_fields = ["prompt", "chosen", "rejected"]

    # Check required fields exist
    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )

    # Sample validation - check first example
    if len(dataset) > 0:
        sample = dataset[0]
        for field in required_fields:
            if not sample[field] or len(sample[field].strip()) == 0:
                logger.warning(f"{stage} dataset has empty {field} in first example")

    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")

'''
        _swap(
            'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
            validation_func + 'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
            "Added data validation function",
        )

    # Fix 4: free GPU/host memory around the adapter merge.
    # NOTE(review): indentation inside these anchors must match run_dpo.py
    # exactly for the replace to fire — confirm against the target file.
    old_merge = '''    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    tok = AutoTokenizer.from_pretrained('''

    new_merge = '''    # Clean up base model to free memory
    del base
    gc.collect()
    torch.cuda.empty_cache()

    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    # Clean up merged model
    del merged
    gc.collect()
    torch.cuda.empty_cache()

    tok = AutoTokenizer.from_pretrained('''

    if 'del base' not in content:
        _swap(old_merge, new_merge, "Added memory cleanup in merge_adapter")

    # Fix 5: warn on old TRL versions right after the TRL import.
    if 'version.parse(trl.__version__)' not in content:
        _swap(
            'from trl import DPOTrainer, DPOConfig',
            '''from trl import DPOTrainer, DPOConfig

# Version check for TRL
try:
    from packaging import version
    import trl
    if version.parse(trl.__version__) < version.parse("0.7.0"):
        print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError:
    print("Warning: Could not verify TRL version")''',
            "Added TRL version check",
        )

    # Fix 6: route selected print() calls through the logger.
    print_to_logger = [
        ('print(f"Using local model at:', 'logger.info(f"Using local model at:'),
        ('print(f"Loading reference model', 'logger.info(f"Loading reference model'),
        ('print(f"DPO Training with beta', 'logger.info(f"DPO Training with beta'),
        ('print(f"Resuming from', 'logger.info(f"Resuming from'),
        ('print("Starting DPO training', 'logger.info("Starting DPO training'),
        ('print(f"Saved best adapter', 'logger.info(f"Saved best adapter'),
    ]
    replaced_any = False
    for old, new in print_to_logger:
        if old in content:
            content = content.replace(old, new)
            replaced_any = True
    # Only claim this fix when at least one call was actually rewritten.
    if replaced_any:
        fixes_applied.append("Replaced print with logger calls")

    filepath.write_text(content, encoding='utf-8')

    print("\n" + "="*80)
    print("DPO TRAINER - FIXES APPLIED")
    print("="*80)
    for i, fix in enumerate(fixes_applied, 1):
        print(f"{i}. ✅ {fix}")
    print("="*80)
    print(f"\n✅ All fixes applied successfully to {filepath}")
    print(f"📁 Original backed up to {filepath}.backup")
    print("\nTo verify: python run_dpo.py --config config_dpo.yaml")

    return True
|
|
if __name__ == "__main__":
    import sys

    # First CLI argument overrides the default target script.
    target = sys.argv[1] if len(sys.argv) > 1 else "run_dpo.py"

    banner = "=" * 80
    print("DPO Trainer - Quick Fix Script")
    print(banner)
    print("This script will apply the following critical fixes:")
    for item in (
        "  1. Add memory cleanup (gc.collect, torch.cuda.empty_cache)",
        "  2. Add logging setup",
        "  3. Add custom exceptions (DataFormattingError, DataValidationError)",
        "  4. Add data validation function",
        "  5. Add TRL version check",
        "  6. Replace print with logger",
    ):
        print(item)
    print(banner)
    print()

    # Proceed only on an explicit 'y'/'Y'; any other answer aborts.
    if input("Apply fixes? [y/N]: ").lower() == 'y':
        apply_fixes(target)
    else:
        print("Cancelled")
|
|