File size: 3,478 Bytes
e42a55a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# /// script
# dependencies = ["unsloth[colab-new]", "trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "xformers"]
# ///
"""
Fine-tune FunctionGemma for llama-agent on HuggingFace Jobs.

Submit with:
    hf_jobs("uv", {
        "script": "<this script content>",
        "flavor": "a10g-large",
        "timeout": "2h",
        "secrets": {"HF_TOKEN": "$HF_TOKEN"}
    })
"""

import os

# Training configuration. Every value below can be overridden by exporting
# an environment variable of the same name before the job is submitted;
# numeric values are parsed from their string form.
_env = os.environ.get

MODEL_NAME = _env("MODEL_NAME", "unsloth/functiongemma-270m-it")
DATASET_NAME = _env("DATASET_NAME", "victor/functiongemma-agent-sft")
OUTPUT_REPO = _env("OUTPUT_REPO", "victor/functiongemma-agent-finetuned")
MAX_SEQ_LENGTH = int(_env("MAX_SEQ_LENGTH", "4096"))
LORA_R = int(_env("LORA_R", "128"))
LORA_ALPHA = int(_env("LORA_ALPHA", "256"))
NUM_EPOCHS = int(_env("NUM_EPOCHS", "3"))
BATCH_SIZE = int(_env("BATCH_SIZE", "4"))
GRAD_ACCUM = int(_env("GRAD_ACCUM", "2"))
LEARNING_RATE = float(_env("LEARNING_RATE", "2e-4"))

# Imports
from unsloth import FastLanguageModel
from unsloth.chat_templates import train_on_responses_only
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
import trackio

print(f"Loading model: {MODEL_NAME}")
# Load the base model and tokenizer through Unsloth's optimized loader.
# 4-bit and 8-bit quantization are explicitly disabled and weights load in
# 16-bit; full_finetuning=False because LoRA adapters are attached below.
# NOTE(review): `load_in_16bit` is only accepted by recent Unsloth
# releases — confirm the resolved unsloth version supports this kwarg.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,  # context window used during training
    load_in_4bit=False,
    load_in_8bit=False,
    load_in_16bit=True,
    full_finetuning=False,
)

print(f"Adding LoRA adapters (r={LORA_R}, alpha={LORA_ALPHA})")
# Wrap the base model with LoRA adapters (PEFT): only adapter weights are
# trained. target_modules lists the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_R,                 # adapter rank
    lora_alpha=LORA_ALPHA,    # LoRA scaling factor (defaults give alpha = 2*r)
    lora_dropout=0,           # no dropout on adapter layers
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    bias="none",              # do not train bias terms
    use_gradient_checkpointing="unsloth",  # Unsloth's memory-saving checkpointing
    random_state=3407,        # fixed seed for reproducible adapter init
    use_rslora=False,
    loftq_config=None,
)

print(f"Loading dataset: {DATASET_NAME}")
# Pull the prepared SFT dataset from the Hugging Face Hub (train split only).
dataset = load_dataset(DATASET_NAME, split="train")
print(f"Dataset size: {len(dataset)} examples")

# SFTConfig with Trackio monitoring.
#
# Compatibility fix: trl renamed SFTConfig's `max_seq_length` parameter to
# `max_length` in 0.13 and later removed the old name. The script's
# dependency pin is open-ended (trl>=0.12.0), so a fresh job resolves a
# recent trl where passing `max_seq_length=` raises a TypeError. Detect
# which keyword this trl accepts instead of hard-coding one.
import inspect

_sft_kwargs = dict(
    dataset_text_field="text",               # dataset column with rendered chat text
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,  # effective batch = BATCH_SIZE * GRAD_ACCUM
    warmup_steps=5,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=10,
    optim="adamw_8bit",                      # 8-bit AdamW to cut optimizer memory
    weight_decay=0.001,
    lr_scheduler_type="linear",
    seed=3407,
    output_dir="./outputs",
    save_steps=500,
    save_total_limit=3,
    # Trackio monitoring
    report_to="trackio",
    run_name="functiongemma-agent-sft",
    # Hub push (CRITICAL - environment is ephemeral!)
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_strategy="every_save",               # upload a checkpoint on every save
)
if "max_seq_length" in inspect.signature(SFTConfig.__init__).parameters:
    _sft_kwargs["max_seq_length"] = MAX_SEQ_LENGTH  # trl < 0.13
else:
    _sft_kwargs["max_length"] = MAX_SEQ_LENGTH      # trl >= 0.13
sft_config = SFTConfig(**_sft_kwargs)

# Create trainer
# NOTE(review): `tokenizer=` is the trl 0.12-era keyword; newer trl prefers
# `processing_class=` — confirm against the resolved trl version.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    eval_dataset=None,  # no held-out eval split is used
    args=sft_config,
)

# CRITICAL: Only train on model responses, not instructions
# Wraps the already-constructed trainer so that tokens in the user turn are
# masked out of the loss; the markers are Gemma chat-template delimiters.
print("Applying train_on_responses_only (masking instruction tokens)...")
trainer = train_on_responses_only(
    trainer,
    instruction_part="<start_of_turn>user\n",
    response_part="<start_of_turn>model\n",
)

print("Starting training...")
# Run the SFT loop; the trainer's returned stats (loss, runtime) are kept
# so they appear in job logs if inspected.
trainer_stats = trainer.train()

# Final push to hub
# Explicit final push in addition to hub_strategy="every_save", so the last
# state reaches the Hub even if training ended between periodic saves.
print(f"Pushing final model to {OUTPUT_REPO}...")
trainer.push_to_hub()

print("Training complete!")
print(f"Model saved to: https://huggingface.co/{OUTPUT_REPO}")