# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "jinja2"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer
import trackio
import os
# Entry banner for the fine-tuning run.
print("🚀 Starting FunctionGemma 270M Fine-tuning (V5 - Final)")
# Base model: instruction-tuned FunctionGemma 270M from the Hugging Face Hub.
model_id = "google/functiongemma-270m-it"
# Tokenizer is loaded up front so the dataset can be pre-rendered with its chat template.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load dataset
# NOTE(review): assumes the dataset ships "messages" and "tools" columns — confirm on the dataset card.
dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")
def format_conversation(example):
    """Render one dataset row into plain text via the tokenizer's chat template.

    Returns a dict with a single "text" key, which becomes the training
    column consumed by the SFT trainer.
    """
    # NOTE(review): assumes example["messages"] / example["tools"] are already
    # in the structure apply_chat_template expects — verify against the dataset.
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tools=example["tools"],
        tokenize=False,                # keep raw text; tokenization happens in the trainer
        add_generation_prompt=False,   # full conversations, not generation prefixes
    )
    return {"text": rendered}
print("🔄 Pre-processing dataset with chat template...")
# Replace every original column with the single pre-rendered "text" column.
dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
# LoRA adapter settings — only the attention projection matrices get adapters.
attention_projections = ["q_proj", "v_proj", "k_proj", "o_proj"]
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,            # adapter rank
    lora_alpha=32,   # scaling factor (alpha / r = 2.0 effective scale)
    target_modules=attention_projections,
)
# Training configuration for TRL's SFTTrainer (kwargs grouped by concern;
# keyword order has no effect on behavior).
config = SFTConfig(
    output_dir="vn-function-gemma-270m-finetuned",
    # --- data ---
    dataset_text_field="text",
    max_length=1024,  # Confirmed correct for TRL 0.26.2
    # --- optimization ---
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch of 16 per device
    learning_rate=5e-5,
    # --- checkpointing / Hub ---
    save_strategy="steps",
    save_steps=50,
    push_to_hub=True,
    hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
    hub_strategy="every_save",
    # --- logging ---
    logging_steps=5,
    report_to="trackio",
    project="vn-function-calling",
    run_name="function-gemma-270m-final",
)
# Build the trainer and run the fine-tune end to end.
print("🎯 Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model_id,  # string id: SFTTrainer loads the model from the Hub itself
    args=config,
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()
# Explicit final push on top of the periodic hub_strategy="every_save" uploads.
trainer.push_to_hub()
print("✅ Training complete and pushed to Hub!")