# train.py — fine-tune FunctionGemma 2B on a Vietnamese function-calling dataset
# (uploaded with huggingface_hub, commit b41204b)
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
import os
def main() -> None:
    """Fine-tune google/function-gemma-2b with LoRA using TRL's SFTTrainer.

    Loads the train split of epinfomax/vn-function-calling-dataset, trains
    for 3 epochs, saves a checkpoint every 50 steps (each checkpoint is also
    pushed to the Hub via ``hub_strategy="every_save"``), and pushes the
    final adapter to epinfomax/vn-function-gemma-finetuned.
    """
    print("🚀 Starting FunctionGemma 2B Fine-tuning")

    # Load dataset (downloads from the Hub on first run)
    dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")

    # Training configuration — effective batch size = 4 * 4 = 16 per device
    config = SFTConfig(
        output_dir="vn-function-gemma-finetuned",
        push_to_hub=True,
        hub_model_id="epinfomax/vn-function-gemma-finetuned",
        hub_strategy="every_save",  # upload every saved checkpoint, not only the final one
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-5,
        logging_steps=10,
        save_strategy="steps",
        save_steps=50,
        report_to="trackio",  # metrics go to trackio (imported at file top)
        project="vn-function-calling",
        run_name="function-gemma-2b-baseline",
    )

    # LoRA configuration — adapt all four attention projections; scaling = alpha/r = 2
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )

    # Initialize and train; passing a string model id lets SFTTrainer load the
    # base model itself. NOTE(review): confirm "google/function-gemma-2b" is the
    # intended Hub model id — it must exist and be accessible for this to run.
    trainer = SFTTrainer(
        model="google/function-gemma-2b",
        train_dataset=dataset,
        peft_config=peft_config,
        args=config,
    )
    trainer.train()
    trainer.push_to_hub()
    print("✅ Training complete and pushed to Hub!")


# Guard the entry point so importing this module does not launch training.
if __name__ == "__main__":
    main()