Spaces:
Runtime error
Runtime error
Ram Narayanan Ananthakrishnapuram Sampath
Added dashboard and rendered with a local LLM to validate Env interaction
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, train_on_responses_only
# 1. Load Qwen 2.5 (3B Instruct) in 4-bit so it fits in modest GPU RAM.
BASE_MODEL = "unsloth/Qwen2.5-3B-Instruct"
MAX_SEQ_LEN = 2048

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=MAX_SEQ_LEN,
    load_in_4bit=True,  # 4-bit quantization compresses the weights
)

# 2. Switch the tokenizer to ChatML — Qwen's native conversation format.
tokenizer = get_chat_template(tokenizer, chat_template="chatml")
# Attach LoRA adapters to every attention and MLP projection layer.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                 # LoRA rank
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    target_modules=lora_target_modules,
    use_gradient_checkpointing="unsloth",  # unsloth's memory-saving variant
)
# Load the raw SFT conversations and render each one as ChatML text.
dataset = load_dataset("json", data_files="sft_data.json", split="train")


def format_prompts(examples):
    """Render a batch of raw conversations into ChatML-formatted strings.

    Expects ``examples["conversations"]`` to hold lists of chat turns
    (the structure ``apply_chat_template`` accepts); returns a ``"text"``
    column for the trainer to consume.
    """
    def render(convo):
        return tokenizer.apply_chat_template(
            convo, tokenize=False, add_generation_prompt=False
        )

    return {"text": [render(c) for c in examples["conversations"]]}


dataset = dataset.map(format_prompts, batched=True)
# Configure supervised fine-tuning.
# FIX: the original called FastLanguageModel.is_bfloat16_supported(), but
# is_bfloat16_supported is a module-level unsloth helper, not a classmethod
# of FastLanguageModel — that call raised AttributeError at runtime (the
# Space's "Runtime error"). It is now imported from unsloth and called directly.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",   # column produced by format_prompts
    max_seq_length=2048,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch size = 2 * 4 = 8
        warmup_steps=5,
        max_steps=150,
        learning_rate=2e-4,
        # Prefer bf16 where the hardware supports it; otherwise fall back to fp16.
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=10,
        optim="adamw_8bit",             # 8-bit optimizer states save GPU memory
        output_dir="sft_outputs",
        seed=3407,
    ),
)
# 3. Mask the loss on everything but the assistant's replies, using
# Qwen's ChatML role tags as delimiters.
user_tag = "<|im_start|>user\n"
assistant_tag = "<|im_start|>assistant\n"
trainer = train_on_responses_only(
    trainer,
    instruction_part=user_tag,
    response_part=assistant_tag,
)
# Run training, then persist the LoRA-adapted model and its tokenizer.
save_dir = "voice_agent_sft"
print("Starting Supervised Fine-Tuning on Qwen 3B...")
trainer.train()
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_dir)
print("SFT complete! Base model saved to ./voice_agent_sft")