Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
nomen-ai / scripts /train_dpo.py
krystv's picture
Set left-padding tokenizer for DPOTrainer
439dbca verified
"""DPO anti-derivative training for Nomen-AI.
Uses the SFT LoRA adapter as the starting policy and trains it on
prompt/chosen/rejected preference pairs. The tokenizer is explicitly configured
with padding_side='left' as required by TRL DPOTrainer preference batching.
"""
import logging
import math
import os

import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import DPOConfig, DPOTrainer
# Run configuration. Each value can be overridden through the environment
# variable named in the call; the second argument is the fallback.
BASE = os.getenv('BASE_MODEL', 'Qwen/Qwen2.5-1.5B-Instruct')        # base causal LM
SFT_ADAPTER = os.getenv('SFT_ADAPTER', 'krystv/nomen-ai-sft-lora')  # LoRA adapter to start from
DPO_DS = os.getenv('DPO_DATASET', 'krystv/nomen-ai-dpo')            # preference dataset
HUB_ID = os.getenv('HUB_MODEL_ID', 'krystv/nomen-ai-dpo-lora')      # Hub repo for the result
class AlertCallback(TrainerCallback):
    """Send Trackio alerts about the training loss as the Trainer logs it.

    An ERROR alert is sent when the logged loss is NaN or greater than
    ``DIVERGENCE_LOSS``. Otherwise an INFO alert is sent for log events whose
    global step is a non-zero multiple of ``PROGRESS_INTERVAL``.

    Alerting is best-effort: if ``trackio`` cannot be imported the callback
    does nothing, and a failure while sending an alert is logged as a warning
    instead of propagating into the training loop.
    """

    # Loss above which the run is reported as diverging.
    DIVERGENCE_LOSS = 10
    # A progress alert is sent when the global step is a multiple of this.
    PROGRESS_INTERVAL = 100

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Inspect a log event and send a divergence or progress alert.

        Args:
            args: Training arguments passed by the Trainer (unused).
            state: Trainer state; ``global_step`` is read from it.
            control: Trainer control object (unused).
            logs: Logged values; events without a ``'loss'`` entry are ignored.
            **kwargs: Extra callback arguments (unused).
        """
        if not logs or logs.get('loss') is None:
            return
        # Alert from the main process only so multi-process runs do not send
        # one copy per rank. Falls back to True if the attribute is absent.
        if not getattr(state, 'is_world_process_zero', True):
            return
        try:
            import trackio
        except Exception:
            # Monitoring is optional; training must not depend on it.
            return

        loss = logs['loss']
        step = state.global_step
        if math.isnan(loss) or loss > self.DIVERGENCE_LOSS:
            title = 'DPO divergence'
            text = f'loss={loss} at step {step} — reduce lr x0.1'
            level = 'ERROR'
        elif step and step % self.PROGRESS_INTERVAL == 0:
            title = 'DPO progress'
            text = f'loss={loss:.3f} at step {step}'
            level = 'INFO'
        else:
            return

        try:
            trackio.alert(title, text, level=level)
        except Exception:
            # A failed alert should not abort the training run.
            logging.getLogger(__name__).warning(
                'trackio alert %r failed', title, exc_info=True
            )
def main():
    """Run DPO training from the SFT LoRA adapter and push the result to the Hub.

    Loads the preference dataset ``DPO_DS``, configures a left-padded
    tokenizer for ``BASE``, attaches ``SFT_ADAPTER`` to the base model as
    trainable weights, trains for one epoch with ``DPOTrainer``, pushes the
    result to ``HUB_ID`` and finally tries to send a completion alert through
    Trackio. The alert is best-effort: a failure is logged, not raised.
    """
    ds = load_dataset(DPO_DS)

    tokenizer = AutoTokenizer.from_pretrained(BASE)
    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.float16)
    # NOTE(review): no ref_model is passed below. With a PEFT policy, TRL's
    # DPOTrainer is documented to compute reference log-probs with the adapter
    # disabled, i.e. from the plain base model rather than the SFT policy.
    # Confirm that is the intended reference for this run.
    model = PeftModel.from_pretrained(base, SFT_ADAPTER, is_trainable=True)

    args = DPOConfig(
        output_dir='/tmp/nomen-dpo',
        num_train_epochs=1,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        learning_rate=5e-6,
        max_length=192,
        max_prompt_length=160,
        fp16=True,
        gradient_checkpointing=True,
        logging_strategy='steps',
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        eval_strategy='epoch',
        save_strategy='epoch',
        report_to='trackio',
        run_name='dpo_qwen2.5-1.5b_antiderivative',
        push_to_hub=True,
        hub_model_id=HUB_ID,
    )
    trainer = DPOTrainer(
        model=model,
        args=args,
        train_dataset=ds['train'],
        eval_dataset=ds['test'],
        processing_class=tokenizer,
        callbacks=[AlertCallback()],
    )
    trainer.train()
    trainer.push_to_hub()

    try:
        import trackio
        trackio.alert('DPO complete', f'pushed to {HUB_ID}', level='INFO')
    except Exception:
        # Training and upload already succeeded; report the alert failure
        # instead of hiding it, but do not fail the run over it.
        logging.getLogger(__name__).warning(
            'could not send the DPO completion alert', exc_info=True
        )
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()