# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "jinja2"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer
import trackio
print("πŸš€ Starting FunctionGemma 270M Fine-tuning (V2 with Template Fix)")
model_id = "google/functiongemma-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load dataset
dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")
def format_conversation(example):
    # Modern transformers chat templates accept a 'tools' argument.
    # Render each conversation to a plain string so SFTTrainer trains on the
    # exact template output instead of re-applying a template itself.
    text = tokenizer.apply_chat_template(
        example["messages"],
        tools=example["tools"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}
print("πŸ”„ Pre-processing dataset with chat template...")
dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
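# Optional sanity check (illustrative, not required for training): inspect one
# rendered example to confirm the messages and tools were templated as expected.
print(dataset[0]["text"][:500])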
# Training configuration
config = SFTConfig(
    dataset_text_field="text",  # use the pre-rendered text column
    max_seq_length=1024,
    output_dir="vn-function-gemma-270m-finetuned",
    push_to_hub=True,
    hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
    hub_strategy="every_save",
    num_train_epochs=5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # effective batch size of 16 per device
    learning_rate=5e-5,
    logging_steps=5,
    save_strategy="steps",
    save_steps=50,
    report_to="trackio",
    project="vn-function-calling",
    run_name="function-gemma-270m-v2-fixed",
)
# LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
# Initialize and train
trainer = SFTTrainer(
    model=model_id,  # passing the model id string lets SFTTrainer load the model and tokenizer
    train_dataset=dataset,
    peft_config=peft_config,
    args=config,
)
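# Optional sanity check (illustrative): with peft_config supplied, trainer.model
# is a PEFT-wrapped model, so we can report how few parameters LoRA actually trains.
trainer.model.print_trainable_parameters()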
trainer.train()
trainer.push_to_hub()
print("βœ… Training complete and pushed to Hub!")