# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "jinja2"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer
import trackio
import os
import json
print("🚀 Starting FunctionGemma 270M Fine-tuning (V6 - Optimized with Sample Best Practices)")
model_id = "google/functiongemma-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Ensure pad token is set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load dataset
dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")
def format_conversation(example):
    # Following the sample notebook, split each conversation into separate
    # "prompt" and "completion" columns so that SFTTrainer with
    # completion_only_loss=True computes the loss only on the assistant's reply.
    full_text = tokenizer.apply_chat_template(
        example["messages"],
        tools=example["tools"],
        tokenize=False,
        add_generation_prompt=False,
    )
    prompt_text = tokenizer.apply_chat_template(
        example["messages"][:-1],  # everything except the final assistant message
        tools=example["tools"],
        tokenize=False,
        add_generation_prompt=True,  # include the 'model' turn header
    )
    completion_text = full_text[len(prompt_text):]
    return {
        "prompt": prompt_text,
        "completion": completion_text,
    }
print("🔄 Pre-processing dataset with prompt/completion split...")
dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
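# Optional sanity check (an addition, not in the original script): inspect one
# formatted example to confirm the prompt ends with the generation header and
# the completion contains only the assistant's function-call turn.
print("Prompt preview:", dataset[0]["prompt"][-200:])
print("Completion preview:", dataset[0]["completion"][:200])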
# LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
# Training configuration (Optimized with Sample Best Practices)
config = SFTConfig(
    output_dir="vn-function-gemma-270m-finetuned",
    max_length=1024,
    push_to_hub=True,
    hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
    hub_strategy="every_save",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,  # from the sample: more conservative
    lr_scheduler_type="cosine",  # from the sample
    optim="adamw_torch_fused",  # from the sample
    logging_steps=5,
    save_strategy="steps",
    save_steps=50,
    report_to="trackio",
    project="vn-function-calling",
    run_name="function-gemma-270m-v6-optimized",
    completion_only_loss=True,  # focus loss on assistant responses
    packing=False,
)
# Initialize and train
print("🎯 Initializing SFTTrainer with optimized configuration...")
trainer = SFTTrainer(
    model=model_id,
    processing_class=tokenizer,  # reuse the tokenizer configured above (pad token set)
    train_dataset=dataset,
    peft_config=peft_config,
    args=config,
)
trainer.train()
trainer.push_to_hub()
print("✅ Training complete and pushed to Hub!")