# epinfomax's picture
# Upload train.py with huggingface_hub
# e2c3e92 verified
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "jinja2"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer
import trackio
import os
import json
# Flat training script, executed top-to-bottom: load tokenizer, dataset,
# then fine-tune below.
print("πŸš€ Starting FunctionGemma 270M Fine-tuning (V6 - Optimized with Sample Best Practices)")
# Hub id of the base model; reused when constructing the trainer below.
model_id = "google/functiongemma-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Ensure pad token is set
# NOTE(review): this tokenizer instance is configured here, but the trainer
# below is built from the model id string only — confirm the pad-token fix
# actually reaches training (e.g. by passing the tokenizer to SFTTrainer).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load dataset
# Function-calling dataset; only the "train" split is used.
dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")
def format_conversation(example):
    """Split one chat example into separate prompt and completion strings.

    Renders the whole conversation through the chat template, renders the
    prompt-only portion (every message except the final assistant turn, with
    the generation header appended), and takes the completion as the suffix
    of the full render beyond the prompt. The returned ``{"prompt": ...,
    "completion": ...}`` columns are the layout SFTTrainer expects when
    training with ``completion_only_loss=True``.
    """
    # Full conversation, rendered without a trailing generation header.
    rendered_full = tokenizer.apply_chat_template(
        example["messages"],
        tools=example["tools"],
        tokenize=False,
        add_generation_prompt=False,
    )
    # Prompt-only render: drop the last (assistant) message and append the
    # generation header, so the completion begins exactly at the model turn.
    rendered_prompt = tokenizer.apply_chat_template(
        example["messages"][:-1],
        tools=example["tools"],
        tokenize=False,
        add_generation_prompt=True,
    )
    # Whatever the full render adds after the prompt is the completion.
    return {
        "prompt": rendered_prompt,
        "completion": rendered_full[len(rendered_prompt):],
    }
print("πŸ”„ Pre-processing dataset with prompt/completion split...")
# Replace the raw columns with the derived prompt/completion pair.
dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
# LoRA configuration
# Low-rank adapters on the attention projections only (rank 16, alpha 32).
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
# Training configuration (Optimized with Sample Best Practices)
config = SFTConfig(
    output_dir="vn-function-gemma-270m-finetuned",
    max_length=1024,  # sequences truncated to 1024 tokens
    push_to_hub=True,
    hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
    hub_strategy="every_save",  # upload a checkpoint at every save step
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch of 16 per device
    learning_rate=1e-5,  # From sample: more conservative
    lr_scheduler_type="cosine",  # From sample
    optim="adamw_torch_fused",  # From sample
    logging_steps=5,
    save_strategy="steps",
    save_steps=50,
    report_to="trackio",
    project="vn-function-calling",
    run_name="function-gemma-270m-v6-optimized",
    completion_only_loss=True,  # Focus on assistant responses
    packing=False  # keep prompt/completion pairs intact, no example packing
)
# Initialize and train
print("🎯 Initializing SFTTrainer with optimized configuration...")
trainer = SFTTrainer(
    model=model_id,
    train_dataset=dataset,
    # Fix: pass the tokenizer configured above (pad token set to EOS).
    # With model given as a bare Hub id, TRL would otherwise reload a fresh
    # tokenizer and silently drop the pad-token assignment.
    processing_class=tokenizer,
    peft_config=peft_config,
    args=config,
)
trainer.train()
# Upload the final adapter/model to the Hub repo named in SFTConfig.
trainer.push_to_hub()
print("βœ… Training complete and pushed to Hub!")