File size: 3,028 Bytes
b41204b
ef5974f
b41204b
 
 
 
 
e3473e4
b41204b
 
e2c3e92
b41204b
e2c3e92
ef5974f
 
 
b41204b
e2c3e92
 
 
 
b41204b
 
 
ef5974f
e2c3e92
 
 
 
 
 
ef5974f
 
 
 
 
e2c3e92
 
 
 
 
 
 
 
 
 
 
 
 
 
ef5974f
e2c3e92
ef5974f
 
e2c3e92
b41204b
 
 
 
 
 
 
e2c3e92
e3473e4
 
e2c3e92
e3473e4
 
 
 
 
 
e2c3e92
 
 
e3473e4
 
 
 
 
e2c3e92
 
 
e3473e4
 
 
e2c3e92
e3473e4
 
 
 
 
 
b41204b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "jinja2"]
# ///

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer
import trackio
import os
import json

print("🚀 Starting FunctionGemma 270M Fine-tuning (V6 - Optimized with Sample Best Practices)")

# Base checkpoint id; reused below both for the tokenizer and as the model to train.
model_id = "google/functiongemma-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Fall back to the EOS token for padding when the tokenizer ships without a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Training split of the function-calling dataset.
dataset = load_dataset(
    "epinfomax/vn-function-calling-dataset",
    split="train",
)

def format_conversation(example):
    """Split one chat example into ``prompt`` and ``completion`` strings.

    The conversation is rendered twice with the module-level ``tokenizer``'s
    chat template: once in full, and once without its final message but with
    the generation prompt appended. The completion is the text the full
    rendering adds after the prompt rendering.

    Args:
        example: A dataset row providing ``"messages"`` (the chat turns; the
            last one is treated as the target response) and ``"tools"``
            (forwarded to the chat template).

    Returns:
        A dict with the keys ``"prompt"`` and ``"completion"``.

    Raises:
        ValueError: If the rendered prompt is not a prefix of the rendered
            full conversation. Slicing by length would then produce a
            misaligned completion, so the row is rejected instead.
    """
    messages = example["messages"]
    tools = example["tools"]

    full_text = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Drop the final message and add the generation prompt so the prompt ends
    # exactly where the model's response is expected to begin.
    prompt_text = tokenizer.apply_chat_template(
        messages[:-1],
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The length-based slice below is only meaningful when the prompt is an
    # exact prefix of the full text; fail loudly rather than train on garbage.
    if not full_text.startswith(prompt_text):
        raise ValueError(
            "Rendered prompt is not a prefix of the rendered conversation; "
            "cannot derive the completion by slicing."
        )

    return {
        "prompt": prompt_text,
        "completion": full_text[len(prompt_text):],
    }

print("🔄 Pre-processing dataset with prompt/completion split...")
# Replace every original column with the prompt/completion pair produced
# by format_conversation.
dataset = dataset.map(
    format_conversation,
    remove_columns=dataset.column_names,
)

# LoRA adapter settings: rank-16 adapters (alpha 32) on the attention
# projection modules named below.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    r=16,
    lora_alpha=32,
)

# SFT training arguments (values marked "from sample" follow the reference sample)
config = SFTConfig(
    # --- Output and Hub upload ---
    output_dir="vn-function-gemma-270m-finetuned",
    push_to_hub=True,
    hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
    hub_strategy="every_save",
    # --- Sequence handling and loss ---
    max_length=1024,
    packing=False,
    completion_only_loss=True,  # loss on the completion part only
    # --- Optimisation ---
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,  # from sample: conservative
    lr_scheduler_type="cosine",  # from sample
    optim="adamw_torch_fused",  # from sample
    # --- Logging and checkpoints ---
    logging_steps=5,
    save_strategy="steps",
    save_steps=50,
    report_to="trackio",
    project="vn-function-calling",
    run_name="function-gemma-270m-v6-optimized",
)

# Build the trainer, run training, then push the result to the Hub.
print("🎯 Initializing SFTTrainer with optimized configuration...")
trainer = SFTTrainer(
    model=model_id,
    args=config,
    train_dataset=dataset,
    # Pass the tokenizer prepared earlier in this script so the trainer
    # tokenizes with the same instance (including its pad-token fallback)
    # that rendered the prompt/completion text, rather than loading a
    # separate default copy from the model id.
    processing_class=tokenizer,
    peft_config=peft_config,
)

trainer.train()
# Explicit final upload once training has finished.
trainer.push_to_hub()
print("✅ Training complete and pushed to Hub!")