|
|
|
|
|
""" |
|
|
Quick fix script to apply critical improvements to run_dpo.py |
|
|
Run this to automatically patch the DPO trainer with all critical fixes. |
|
|
""" |
|
|
|
|
|
import re |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
|
|
|
def backup_file(filepath):
    """Create a ``.backup`` copy of *filepath* next to the original.

    Args:
        filepath: Path (str or ``Path``) of the file to back up.

    Returns:
        Path to the newly created backup file.
    """
    backup_path = Path(str(filepath) + '.backup')
    # copy2 preserves file metadata (mtime, permissions), not just contents.
    shutil.copy2(filepath, backup_path)
    # NOTE: original status string was mojibake split across two lines
    # (unterminated literal); restored as a single valid f-string.
    print(f"✅ Backup created: {backup_path}")
    return backup_path
|
|
|
|
|
def apply_fixes(filepath='run_dpo.py'):
    """Apply all critical fixes to the DPO training script.

    Patches the target file in place (after backing it up) via a series of
    string replacements. Each fix is guarded by a marker check so re-running
    this script does not apply the same fix twice.

    Args:
        filepath: Path to the ``run_dpo.py`` script to patch.

    Returns:
        True when the file was processed, False when it does not exist.
    """
    filepath = Path(filepath)
    if not filepath.exists():
        print(f"❌ Error: {filepath} not found")
        return False

    backup_file(filepath)

    # Pin the encoding so the file round-trips identically on every platform
    # (the default locale encoding could corrupt non-ASCII content on write).
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    fixes_applied = []

    # Fix 1: make gc/logging importable for the later fixes.
    if 'import gc' not in content:
        content = content.replace(
            'import time\nfrom pathlib',
            'import gc\nimport time\nimport logging\nfrom pathlib'
        )
        fixes_applied.append("Added gc and logging imports")

    # Fix 2: logging setup plus the custom exception hierarchy.
    if 'logging.basicConfig' not in content:
        content = content.replace(
            'wandb = None\n\n\n# --------------------------\n# Helpers',
            '''wandb = None

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --------------------------
# Custom Exceptions
# --------------------------

class DataFormattingError(Exception):
    """Exception raised for errors in data formatting."""
    pass


class DataValidationError(Exception):
    """Exception raised for errors in data validation."""
    pass


# --------------------------
# Helpers'''
        )
        fixes_applied.append("Added logging setup and custom exceptions")

    # Fix 3: dataset validation helper, inserted just before build_dpo_datasets.
    if 'def validate_dpo_data' not in content:
        validation_func = '''
def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate DPO dataset has all required fields and proper structure.

    Args:
        dataset: Dataset to validate
        stage: Training stage ("train" or "eval")

    Raises:
        DataValidationError if validation fails
    """
    required_fields = ["prompt", "chosen", "rejected"]

    # Check required fields exist
    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )

    # Sample validation - check first example
    if len(dataset) > 0:
        sample = dataset[0]
        for field in required_fields:
            if not sample[field] or len(sample[field].strip()) == 0:
                logger.warning(f"{stage} dataset has empty {field} in first example")

    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")


'''

        content = content.replace(
            'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)',
            validation_func + 'def build_dpo_datasets(cfg: Dict[str, Any], tokenizer)'
        )
        fixes_applied.append("Added data validation function")

    # Fix 4: free accelerator/host memory around the adapter merge.
    old_merge = '''    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    tok = AutoTokenizer.from_pretrained('''

    new_merge = '''    # Clean up base model to free memory
    del base
    gc.collect()
    torch.cuda.empty_cache()

    merged.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )

    # Clean up merged model
    del merged
    gc.collect()
    torch.cuda.empty_cache()

    tok = AutoTokenizer.from_pretrained('''

    if old_merge in content and 'del base' not in content:
        content = content.replace(old_merge, new_merge)
        fixes_applied.append("Added memory cleanup in merge_adapter")

    # Fix 5: warn when an older TRL version is installed.
    if 'version.parse(trl.__version__)' not in content:
        content = content.replace(
            'from trl import DPOTrainer, DPOConfig',
            '''from trl import DPOTrainer, DPOConfig

# Version check for TRL
try:
    from packaging import version
    import trl
    if version.parse(trl.__version__) < version.parse("0.7.0"):
        print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError:
    print("Warning: Could not verify TRL version")'''
        )
        fixes_applied.append("Added TRL version check")

    # Fix 6: migrate key status prints to the logger. Record the fix only if
    # at least one replacement actually matched (the original appended the
    # message unconditionally, misreporting a no-op as an applied fix).
    before_logger_swap = content
    content = content.replace('print(f"Using local model at:', 'logger.info(f"Using local model at:')
    content = content.replace('print(f"Loading reference model', 'logger.info(f"Loading reference model')
    content = content.replace('print(f"DPO Training with beta', 'logger.info(f"DPO Training with beta')
    content = content.replace('print(f"Resuming from', 'logger.info(f"Resuming from')
    content = content.replace('print("Starting DPO training', 'logger.info("Starting DPO training')
    content = content.replace('print(f"Saved best adapter', 'logger.info(f"Saved best adapter')
    if content != before_logger_swap:
        fixes_applied.append("Replaced print with logger calls")

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)

    print("\n" + "=" * 80)
    print("DPO TRAINER - FIXES APPLIED")
    print("=" * 80)
    for i, fix in enumerate(fixes_applied, 1):
        # NOTE: original summary strings were mojibake split across lines
        # (unterminated literals); restored as single valid strings.
        print(f"{i}. ✅ {fix}")
    print("=" * 80)
    print(f"\n✅ All fixes applied successfully to {filepath}")
    print(f"📁 Original backed up to {filepath}.backup")
    print("\nTo verify: python run_dpo.py --config config_dpo.yaml")

    return True
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Target script defaults to run_dpo.py; first CLI argument overrides it.
    filepath = sys.argv[1] if len(sys.argv) > 1 else "run_dpo.py"

    print("DPO Trainer - Quick Fix Script")
    print("=" * 80)
    print("This script will apply the following critical fixes:")
    print(" 1. Add memory cleanup (gc.collect, torch.cuda.empty_cache)")
    print(" 2. Add logging setup")
    print(" 3. Add custom exceptions (DataFormattingError, DataValidationError)")
    print(" 4. Add data validation function")
    print(" 5. Add TRL version check")
    print(" 6. Replace print with logger")
    print("=" * 80)
    print()

    try:
        response = input("Apply fixes? [y/N]: ")
    except EOFError:
        # Non-interactive invocation (stdin closed/piped): treat as "no"
        # instead of crashing with an unhandled EOFError.
        response = ""

    if response.lower() == 'y':
        # Propagate failure (e.g. target file missing) as a non-zero exit
        # code so shell/CI callers can detect it.
        sys.exit(0 if apply_fixes(filepath) else 1)
    else:
        print("Cancelled")
|