# infinite-craft-training / train_functiongemma_sft.py
# Author: ericlewis
# Commit 07037a5 (verified): "Upload train_functiongemma_sft.py with huggingface_hub"
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.36.0",
# "accelerate>=0.24.0",
# "trackio",
# "datasets",
# "jinja2>=3.1.0",
# ]
# ///
"""
SFT training for Infinite Craft element fusion using FunctionGemma-270m-it.
This model is pre-trained for function calling, so we frame element fusion
as a tool call: fuse(result="Steam", emoji="♨️")
The model uses special tokens: <start_function_call>, <end_function_call>,
<escape> for structured output, which should give it a head start on
producing valid structured responses.
"""
import json
import re
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
# ---------------------------------------------------------------------------
# 1. Define the fuse tool schema (used in chat template)
# ---------------------------------------------------------------------------
# JSON-schema declaration of the single ``fuse`` tool. Both arguments are
# required strings: the name of the fused element and its emoji.
FUSE_TOOL = dict(
    type="function",
    function=dict(
        name="fuse",
        description="Fuse two elements.",
        parameters=dict(
            type="object",
            properties=dict(
                result=dict(type="string", description="New element name."),
                emoji=dict(type="string", description="Element emoji."),
            ),
            required=["result", "emoji"],
        ),
    ),
)
# ---------------------------------------------------------------------------
# 2. Convert our dataset to FunctionGemma's tool_call format
# ---------------------------------------------------------------------------
def convert_to_functiongemma_format(example):
    """Convert one chat-format row into FunctionGemma's tool-call format.

    Input (SmolLM2 chat format)::

        system:    "Fuse two elements..."
        user:      "Combine Fire and Water"
        assistant: '{"result": "Steam", "emoji": "♨️"}'

    Output (FunctionGemma tool_call format)::

        developer: "You fuse elements together."
        user:      "Combine Fire and Water"
        assistant: tool_calls=[{name: "fuse",
                                arguments: {result: "Steam", emoji: "♨️"}}]

    Turns are located by role rather than by position, so rows with or
    without a leading system message both convert. If a role occurs more
    than once, its last occurrence is used.

    Args:
        example: A dataset row whose ``"messages"`` entry is a list of
            ``{"role": ..., "content": ...}`` dicts containing at least one
            ``"user"`` and one ``"assistant"`` turn.

    Returns:
        ``{"messages": [...], "tools": [FUSE_TOOL]}``. The ``"tools"`` entry
        carries the tool declaration alongside each row.

    Raises:
        KeyError: If the row has no ``"user"`` or no ``"assistant"`` turn.

    Deliberate best-effort: if the assistant content is not a JSON object,
    or a field is missing/null, that argument falls back to
    ``result="Unknown"`` / ``emoji="⚪"`` instead of aborting ``Dataset.map``.
    """
    contents = {msg["role"]: msg["content"] for msg in example["messages"]}
    user_msg = contents["user"]  # e.g. "Combine Fire and Water"
    assistant_msg = contents["assistant"]  # e.g. '{"result": "Steam", "emoji": "♨️"}'

    # Defaults used when the assistant answer cannot be parsed. Key order
    # (result, emoji) matches the tool schema.
    arguments = {"result": "Unknown", "emoji": "⚪"}
    try:
        parsed = json.loads(assistant_msg)
    except (json.JSONDecodeError, TypeError):
        # TypeError: content is not str/bytes (e.g. None).
        parsed = None
    # Valid JSON that is not an object (list, string, number) would make
    # `.get` raise, so only dicts are consulted.
    if isinstance(parsed, dict):
        for key in arguments:
            value = parsed.get(key)
            if value is not None:
                # Coerce to str so every row yields the same Arrow type.
                arguments[key] = str(value)

    new_messages = [
        {"role": "developer", "content": "You fuse elements together."},
        {"role": "user", "content": user_msg},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "function": {
                        "name": "fuse",
                        "arguments": arguments,
                    }
                }
            ],
        },
    ]
    return {"messages": new_messages, "tools": [FUSE_TOOL]}
# ---------------------------------------------------------------------------
# 3. Load and convert datasets
# ---------------------------------------------------------------------------
def _load_split(data_file):
    """Load one JSONL file from the recipes dataset repo and convert every row
    to FunctionGemma's tool-call format."""
    raw_split = load_dataset(
        "ericlewis/infinite-craft-recipes",
        data_files=data_file,
        # A plain `data_files` load (no split mapping) is exposed as the
        # "train" split, even when the file holds validation data.
        split="train",
    )
    return raw_split.map(convert_to_functiongemma_format)


train_dataset = _load_split("data/train_100k.jsonl")
eval_dataset = _load_split("data/val_5k.jsonl")
# ---------------------------------------------------------------------------
# 4. LoRA config — target Gemma3 attention + MLP layers
# ---------------------------------------------------------------------------
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapt every attention projection (q/k/v/o) and every MLP projection
    # (gate/up/down).
    target_modules=[
        f"{prefix}_proj"
        for prefix in ("q", "k", "v", "o", "gate", "up", "down")
    ],
    r=32,
    lora_alpha=64,  # PEFT scales the adapter by lora_alpha / r (= 2 here)
    lora_dropout=0.05,
    bias="none",  # bias terms are not trained
)
# ---------------------------------------------------------------------------
# 5. Training config
# ---------------------------------------------------------------------------
config = SFTConfig(
    # --- Output & Hugging Face Hub ---
    output_dir="infinite-craft-functiongemma-270m",
    push_to_hub=True,
    hub_model_id="ericlewis/infinite-craft-functiongemma-270m",
    hub_private_repo=False,
    hub_strategy="every_save",  # upload whenever a checkpoint is saved
    # --- Optimisation ---
    num_train_epochs=3,
    per_device_train_batch_size=32,
    # 32 per device x 4 accumulation steps = 128 sequences per optimizer
    # step on a single device; multiply by the device count on multi-GPU.
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    bf16=True,
    # The FunctionGemma rendering is ~112 tokens per row (SmolLM2's was ~62).
    # NOTE(review): rows longer than this are truncated, which can cut off
    # the trailing assistant tool call — confirm no row exceeds 128 tokens.
    max_length=128,
    # --- Logging, evaluation & checkpointing (eval and save share a cadence) ---
    logging_steps=25,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    # --- Experiment tracking ---
    report_to="trackio",
    project="infinite-craft",
    run_name="functiongemma-270m-100k-3ep",
)
# ---------------------------------------------------------------------------
# 6. Train
# ---------------------------------------------------------------------------
trainer = SFTTrainer(
    # Passed as a Hub model id; SFTTrainer loads the base model itself.
    model="google/functiongemma-270m-it",
    args=config,
    peft_config=peft_config,  # train LoRA adapters instead of full weights
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
# Checkpoints are already uploaded on each save (hub_strategy="every_save");
# this final call pushes the finished model.
trainer.push_to_hub()