| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| SFT training for Infinite Craft element fusion using FunctionGemma-270m-it. |
| |
| This model is pre-trained for function calling, so we frame element fusion |
| as a tool call: fuse(result="Steam", emoji="♨️") |
| |
| The model uses special tokens: <start_function_call>, <end_function_call>, |
| <escape> for structured output, which should give it a head start on |
| producing valid structured responses. |
| """ |
|
|
| import json |
| import re |
| from datasets import load_dataset |
| from peft import LoraConfig |
| from trl import SFTTrainer, SFTConfig |
|
|
|
|
| |
| |
| |
# JSON-schema object describing the two string arguments of the `fuse` tool.
# Both arguments are listed as required.
_FUSE_PARAMETERS = {
    "type": "object",
    "properties": {
        "result": {"type": "string", "description": "New element name."},
        "emoji": {"type": "string", "description": "Element emoji."},
    },
    "required": ["result", "emoji"],
}

# Declaration of the single tool the model is trained to call. It is attached
# to every converted example under the "tools" key.
FUSE_TOOL = {
    "type": "function",
    "function": {
        "name": "fuse",
        "description": "Fuse two elements.",
        "parameters": _FUSE_PARAMETERS,
    },
}
|
|
|
|
| |
| |
| |
def convert_to_functiongemma_format(example):
    """Convert one SmolLM2-style chat example to FunctionGemma tool-call format.

    Source format (``example["messages"]``):
        system:    "Fuse two elements..."
        user:      "Combine Fire and Water"
        assistant: '{"result": "Steam", "emoji": "♨️"}'

    Target format:
        developer: "You fuse elements together."
        user:      "Combine Fire and Water"
        assistant: tool_calls=[{name: "fuse", arguments: {result: "Steam", emoji: "♨️"}}]

    Args:
        example: Mapping with a "messages" list. The user turn is read from
            index 1 and the assistant turn from index 2; the assistant content
            is expected to be a JSON object encoded as a string.

    Returns:
        A dict with the rewritten "messages" list and a "tools" list holding
        the ``fuse`` tool declaration.

    Raises:
        KeyError, IndexError: If "messages" is missing or has fewer than
            three entries, or an entry has no "content".

    Note:
        An assistant payload that cannot be used (malformed JSON, JSON that is
        not an object, or non-string content) is replaced by the placeholder
        result "Unknown" / "⚪" rather than raising, so a single bad row does
        not abort a whole ``Dataset.map`` pass.
    """
    messages = example["messages"]
    user_msg = messages[1]["content"]
    assistant_msg = messages[2]["content"]

    try:
        parsed = json.loads(assistant_msg)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers content that is None / not str-or-bytes.
        parsed = None

    # json.loads happily returns str/list/int/None for valid non-object JSON;
    # those have no .get(), so treat them like an unparseable payload.
    if not isinstance(parsed, dict):
        parsed = {}

    result_name = parsed.get("result", "Unknown")
    result_emoji = parsed.get("emoji", "⚪")

    new_messages = [
        {"role": "developer", "content": "You fuse elements together."},
        {"role": "user", "content": user_msg},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "function": {
                        "name": "fuse",
                        "arguments": {
                            "result": result_name,
                            "emoji": result_emoji,
                        },
                    }
                }
            ],
        },
    ]
    return {"messages": new_messages, "tools": [FUSE_TOOL]}
|
|
|
|
| |
| |
| |
# Load the training and validation files from the same Hub dataset repo and
# rewrite every row into FunctionGemma tool-call format. Each file is loaded
# on its own via `data_files`, so both are requested as split="train".
train_dataset, eval_dataset = (
    load_dataset(
        "ericlewis/infinite-craft-recipes",
        data_files=data_file,
        split="train",
    ).map(convert_to_functiongemma_format)
    for data_file in ("data/train_100k.jsonl", "data/val_5k.jsonl")
)
|
|
|
|
| |
| |
| |
# LoRA adapter settings: rank 32 with alpha 64 (alpha / r = 2), light dropout,
# and no bias parameters trained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    # Names of the modules that receive adapters.
    target_modules=[
        # attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # feed-forward projections
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
|
|
|
|
| |
| |
| |
# Training arguments for the SFT run.
config = SFTConfig(
    # --- Output / Hub: every saved checkpoint is pushed to a public repo ---
    output_dir="infinite-craft-functiongemma-270m",
    hub_model_id="ericlewis/infinite-craft-functiongemma-270m",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_private_repo=False,

    # --- Optimisation ---
    # Effective batch per device per optimizer step: 32 * 4 = 128 sequences.
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=3e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    bf16=True,
    # NOTE(review): 128 tokens must hold the rendered tool declaration, the
    # prompt and the tool call -- confirm no examples are being truncated.
    max_length=128,

    # --- Checkpointing and evaluation share the same 500-step cadence ---
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=500,

    # --- Logging / experiment tracking ---
    logging_steps=25,
    report_to="trackio",
    project="infinite-craft",
    run_name="functiongemma-270m-100k-3ep",
)
|
|
|
|
| |
| |
| |
# Wire everything together. The base model is given by its Hub id, and the
# LoRA config is passed so the trainer applies the adapters itself.
trainer = SFTTrainer(
    args=config,
    model="google/functiongemma-270m-it",
    peft_config=peft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Run fine-tuning, then do a final upload of the result to the Hub.
trainer.train()
trainer.push_to_hub()
|
|