# voice_agent / train_sft.py
# Author: Ram Narayanan Ananthakrishnapuram Sampath
# Added dashboard and rendered with a local LLM to validate Env interaction
# (commit 8918e76)
# NOTE: unsloth must be imported before transformers/trl so it can patch them.
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template, train_on_responses_only

import torch

from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
# 1. Pull down the Qwen 2.5 3B instruct checkpoint. Loading in 4-bit keeps
# the quantized weights small enough for a modest GPU.
BASE_MODEL = "unsloth/Qwen2.5-3B-Instruct"
MAX_SEQ_LENGTH = 2048

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,  # 4-bit quantization so the 3B model fits easily
)
# 2. Switch the tokenizer to ChatML, Qwen's native conversation format.
tokenizer = get_chat_template(tokenizer, chat_template="chatml")

# Attach LoRA adapters to every attention and MLP projection so only a
# small set of low-rank weights is trained.
LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                         # LoRA rank
    target_modules=LORA_TARGET_MODULES,
    lora_alpha=16,                # scaling factor (alpha == r → scale 1.0)
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # unsloth's memory-saving variant
)
# Load the raw SFT conversations from the local JSON file.
dataset = load_dataset("json", split="train", data_files="sft_data.json")
def format_prompts(examples):
    """Render a batch of raw conversations into ChatML training strings.

    Args:
        examples: a batched dataset slice with a "conversations" column,
            where each entry is a list of chat messages.

    Returns:
        A dict with a single "text" column of templated strings, suitable
        for SFTTrainer's ``dataset_text_field``.
    """
    rendered = []
    for conversation in examples["conversations"]:
        # tokenize=False keeps the result as a plain string; no generation
        # prompt is appended because these are complete training examples.
        text = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=False
        )
        rendered.append(text)
    return {"text": rendered}
dataset = dataset.map(format_prompts, batched=True)

# Detect bf16 support once and reuse it for both flags.
# FIX: the original called FastLanguageModel.is_bfloat16_supported(), but
# unsloth exposes `is_bfloat16_supported` at the package level, not as an
# attribute of FastLanguageModel, so that call raises AttributeError.
# torch.cuda.is_bf16_supported() is the equivalent check.
use_bf16 = torch.cuda.is_bf16_supported()

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch size = 2 * 4 = 8
        warmup_steps=5,
        max_steps=150,
        learning_rate=2e-4,
        fp16=not use_bf16,  # fall back to fp16 only on pre-Ampere GPUs
        bf16=use_bf16,
        logging_steps=10,
        optim="adamw_8bit",  # 8-bit optimizer states to save VRAM
        output_dir="sft_outputs",
        seed=3407,
    ),
)

# 3. Mask the loss on everything except the assistant turns, using Qwen's
# ChatML role tags to locate instruction vs. response spans.
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|im_start|>user\n",
    response_part="<|im_start|>assistant\n",
)

print("Starting Supervised Fine-Tuning on Qwen 3B...")
trainer.train()

# Persist the LoRA-adapted model and tokenizer for the next pipeline stage.
model.save_pretrained("voice_agent_sft")
tokenizer.save_pretrained("voice_agent_sft")
print("SFT complete! Base model saved to ./voice_agent_sft")