Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
nomen-ai / scripts /train_sft.py
krystv's picture
Set Qwen chat EOS token in SFT config
0254329 verified
"""LoRA SFT for Nomen-AI on a Colab/T4."""
import os
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainerCallback
from trl import SFTConfig, SFTTrainer
# Runtime configuration; each value can be overridden through an environment variable.
BASE = os.environ.get('BASE_MODEL', 'Qwen/Qwen2.5-1.5B-Instruct')  # model id passed to SFTTrainer
SFT_DS = os.environ.get('SFT_DATASET', 'krystv/nomen-ai-sft')  # dataset id passed to load_dataset
HUB_ID = os.environ.get('HUB_MODEL_ID', 'krystv/nomen-ai-sft-lora')  # Hub repo the trainer pushes to
class AlertCallback(TrainerCallback):
    """Raise trackio alerts on training-loss divergence, plus periodic progress pings.

    Alerting is best-effort. If trackio cannot be imported, the callback does
    nothing. If an individual alert call fails, a RuntimeWarning is emitted and
    training continues instead of being aborted.
    """

    # A logged training loss above this value (or NaN) is reported as divergence.
    DIVERGENCE_LOSS = 20
    # An INFO progress alert is sent whenever global_step is a multiple of this.
    PROGRESS_EVERY = 200

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Inspect one Trainer log payload and send a trackio alert if warranted.

        Only payloads carrying a 'loss' key are considered; any other payload
        (or an empty/None one) is ignored.

        Args:
            args: Trainer arguments (unused).
            state: Trainer state; only ``state.global_step`` is read.
            control: Trainer control object (unused, not modified).
            logs: The dict of logged values for this event, or None.
        """
        loss = (logs or {}).get('loss')
        if loss is None:
            return
        try:
            import trackio
        except Exception:
            # trackio is optional for this callback; without it there is nowhere to alert.
            return
        step = state.global_step
        # `loss != loss` is True only for NaN; +inf is caught by the threshold comparison.
        if loss != loss or loss > self.DIVERGENCE_LOSS:
            self._alert(trackio, 'SFT divergence',
                        f'loss={loss} at step {step} — lr too high, try x0.1', 'ERROR')
        elif step and step % self.PROGRESS_EVERY == 0:
            self._alert(trackio, 'SFT progress', f'loss={loss:.3f} at step {step}', 'INFO')

    @staticmethod
    def _alert(trackio, title, text, level):
        """Send one trackio alert; on failure warn instead of propagating into the training loop."""
        try:
            trackio.alert(title, text, level=level)
        except Exception as err:
            import warnings
            warnings.warn(f'trackio alert {title!r} failed: {err!r}', RuntimeWarning)
def main():
    """Fine-tune BASE with a LoRA adapter on SFT_DS and push the result to HUB_ID.

    Reads the dataset's 'train' and 'test' splits, trains with SFTTrainer,
    pushes to the Hub, then sends a best-effort trackio completion alert.
    """
    dataset = load_dataset(SFT_DS)

    # LoRA adapters on the q/k/v/o attention and gate/up/down MLP projection modules.
    lora_config = LoraConfig(
        r=32,
        lora_alpha=64,
        lora_dropout=0.05,
        bias='none',
        task_type='CAUSAL_LM',
        target_modules=[
            'q_proj', 'k_proj', 'v_proj', 'o_proj',
            'gate_proj', 'up_proj', 'down_proj',
        ],
    )

    training_args = SFTConfig(
        output_dir='/tmp/nomen-sft',
        # Optimisation schedule.
        num_train_epochs=3,
        per_device_train_batch_size=16,
        gradient_accumulation_steps=2,  # 16 x 2 = 32 examples per optimizer step per device
        learning_rate=2e-4,
        lr_scheduler_type='cosine',
        warmup_ratio=0.05,
        # Sequence handling; '<|im_end|>' is the Qwen chat end-of-turn marker.
        max_length=192,
        eos_token='<|im_end|>',
        packing=False,
        # Memory / precision.
        fp16=True,
        gradient_checkpointing=True,
        # Logging, evaluation and checkpointing.
        logging_strategy='steps',
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        eval_strategy='epoch',
        save_strategy='epoch',
        report_to='trackio',
        run_name='sft_qwen2.5-1.5b_r32_lr2e-4',
        push_to_hub=True,
        hub_model_id=HUB_ID,
    )

    trainer = SFTTrainer(
        model=BASE,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        peft_config=lora_config,
        callbacks=[AlertCallback()],
    )
    trainer.train()
    trainer.push_to_hub()

    # Completion notice is best-effort: skipped if trackio is missing or the alert fails.
    try:
        import trackio
        trackio.alert('SFT complete', f'pushed to {HUB_ID}', level='INFO')
    except Exception:
        pass
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()