# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
#     "datasets",
#     "jinja2>=3.1.0",
# ]
# ///
"""
SFT training for Infinite Craft element fusion using FunctionGemma-270m-it.

This model is pre-trained for function calling, so we frame element fusion
as a tool call: fuse(result="Steam", emoji="♨️")

The model uses dedicated special tokens to delimit structured function-call
output, which should give it a head start on producing valid structured
responses.
"""

import json
import re

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# ---------------------------------------------------------------------------
# 1. Define the fuse tool schema (used in chat template)
# ---------------------------------------------------------------------------

FUSE_TOOL = {
    "type": "function",
    "function": {
        "name": "fuse",
        "description": "Fuse two elements.",
        "parameters": {
            "type": "object",
            "properties": {
                "result": {"type": "string", "description": "New element name."},
                "emoji": {"type": "string", "description": "Element emoji."},
            },
            "required": ["result", "emoji"],
        },
    },
}

# ---------------------------------------------------------------------------
# 2. Convert our dataset to FunctionGemma's tool_call format
# ---------------------------------------------------------------------------

# Values substituted when the assistant payload cannot be used as-is.
_FALLBACK_RESULT = "Unknown"
_FALLBACK_EMOJI = "⚪"

# Developer-turn prompt placed at the start of every converted conversation.
_DEVELOPER_PROMPT = "You fuse elements together."


def _parse_fusion(assistant_msg):
    """Return ``(result, emoji)`` parsed from the assistant's JSON payload.

    Falls back to ``("Unknown", "⚪")`` (per field where possible) when the
    payload is not valid JSON, is not a JSON object, or a field is missing
    or not a string.
    """
    try:
        payload = json.loads(assistant_msg)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers non-string content such as None.
        return _FALLBACK_RESULT, _FALLBACK_EMOJI

    # json.loads happily returns lists, strings, numbers or None; only a
    # JSON object has the fields we need.
    if not isinstance(payload, dict):
        return _FALLBACK_RESULT, _FALLBACK_EMOJI

    result_name = payload.get("result", _FALLBACK_RESULT)
    result_emoji = payload.get("emoji", _FALLBACK_EMOJI)

    # The tool schema declares both arguments as strings; keep the training
    # targets consistent with it.
    if not isinstance(result_name, str):
        result_name = _FALLBACK_RESULT
    if not isinstance(result_emoji, str):
        result_emoji = _FALLBACK_EMOJI

    return result_name, result_emoji


def convert_to_functiongemma_format(example):
    """
    Convert from SmolLM2 chat format:
        system: "Fuse two elements..."
        user: "Combine Fire and Water"
        assistant: '{"result": "Steam", "emoji": "♨️"}'

    To FunctionGemma tool_call format:
        developer: "You fuse elements together."
        user: "Combine Fire and Water"
        assistant: tool_calls=[{name: "fuse", arguments: {result: "Steam", emoji: "♨️"}}]

    Args:
        example: A dataset row with a ``"messages"`` list whose second and
            third entries are the user and assistant turns.

    Returns:
        A dict with the converted ``"messages"`` and a ``"tools"`` list
        containing ``FUSE_TOOL``. Unparseable assistant payloads are mapped
        to the fallback result/emoji rather than raising.
    """
    messages = example["messages"]
    user_msg = messages[1]["content"]  # "Combine Fire and Water"
    assistant_msg = messages[2]["content"]  # '{"result": "Steam", "emoji": "♨️"}'

    result_name, result_emoji = _parse_fusion(assistant_msg)

    new_messages = [
        {"role": "developer", "content": _DEVELOPER_PROMPT},
        {"role": "user", "content": user_msg},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "function": {
                        "name": "fuse",
                        "arguments": {
                            "result": result_name,
                            "emoji": result_emoji,
                        },
                    }
                }
            ],
        },
    ]

    return {"messages": new_messages, "tools": [FUSE_TOOL]}


# ---------------------------------------------------------------------------
# 3. Load and convert datasets
# ---------------------------------------------------------------------------

train_dataset = load_dataset(
    "ericlewis/infinite-craft-recipes",
    data_files="data/train_100k.jsonl",
    split="train",
)
eval_dataset = load_dataset(
    "ericlewis/infinite-craft-recipes",
    data_files="data/val_5k.jsonl",
    split="train",  # a single data file is always exposed as the "train" split
)

train_dataset = train_dataset.map(convert_to_functiongemma_format)
eval_dataset = eval_dataset.map(convert_to_functiongemma_format)

# ---------------------------------------------------------------------------
# 4. LoRA config — target Gemma3 attention + MLP layers
# ---------------------------------------------------------------------------

peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# ---------------------------------------------------------------------------
# 5. Training config
# ---------------------------------------------------------------------------

config = SFTConfig(
    output_dir="infinite-craft-functiongemma-270m",
    push_to_hub=True,
    hub_model_id="ericlewis/infinite-craft-functiongemma-270m",
    hub_strategy="every_save",
    hub_private_repo=False,
    num_train_epochs=3,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,  # effective batch size: 128
    learning_rate=3e-4,
    max_length=128,  # FunctionGemma format is ~112 tokens (more than SmolLM2's 62)
    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=500,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    bf16=True,
    report_to="trackio",
    project="infinite-craft",
    run_name="functiongemma-270m-100k-3ep",
)

# ---------------------------------------------------------------------------
# 6. Train
# ---------------------------------------------------------------------------

trainer = SFTTrainer(
    model="google/functiongemma-270m-it",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    args=config,
)

trainer.train()
trainer.push_to_hub()