victor HF Staff commited on
Commit
e42a55a
·
verified ·
1 Parent(s): 85146f9

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +119 -0
train.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = ["unsloth[colab-new]", "trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "xformers"]
3
+ # ///
4
+ """
5
+ Fine-tune FunctionGemma for llama-agent on HuggingFace Jobs.
6
+
7
+ Submit with:
8
+ hf_jobs("uv", {
9
+ "script": "<this script content>",
10
+ "flavor": "a10g-large",
11
+ "timeout": "2h",
12
+ "secrets": {"HF_TOKEN": "$HF_TOKEN"}
13
+ })
14
+ """
15
+
16
+ import os
17
+
18
+ # Config - override via environment variables
19
+ MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/functiongemma-270m-it")
20
+ DATASET_NAME = os.environ.get("DATASET_NAME", "victor/functiongemma-agent-sft")
21
+ OUTPUT_REPO = os.environ.get("OUTPUT_REPO", "victor/functiongemma-agent-finetuned")
22
+ MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))
23
+ LORA_R = int(os.environ.get("LORA_R", "128"))
24
+ LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "256"))
25
+ NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "3"))
26
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "4"))
27
+ GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "2"))
28
+ LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-4"))
29
+
30
+ # Imports
31
+ from unsloth import FastLanguageModel
32
+ from unsloth.chat_templates import train_on_responses_only
33
+ from datasets import load_dataset
34
+ from trl import SFTTrainer, SFTConfig
35
+ import trackio
36
+
37
+ print(f"Loading model: {MODEL_NAME}")
38
+ model, tokenizer = FastLanguageModel.from_pretrained(
39
+ model_name=MODEL_NAME,
40
+ max_seq_length=MAX_SEQ_LENGTH,
41
+ load_in_4bit=False,
42
+ load_in_8bit=False,
43
+ load_in_16bit=True,
44
+ full_finetuning=False,
45
+ )
46
+
47
+ print(f"Adding LoRA adapters (r={LORA_R}, alpha={LORA_ALPHA})")
48
+ model = FastLanguageModel.get_peft_model(
49
+ model,
50
+ r=LORA_R,
51
+ lora_alpha=LORA_ALPHA,
52
+ lora_dropout=0,
53
+ target_modules=[
54
+ "q_proj", "k_proj", "v_proj", "o_proj",
55
+ "gate_proj", "up_proj", "down_proj",
56
+ ],
57
+ bias="none",
58
+ use_gradient_checkpointing="unsloth",
59
+ random_state=3407,
60
+ use_rslora=False,
61
+ loftq_config=None,
62
+ )
63
+
64
+ print(f"Loading dataset: {DATASET_NAME}")
65
+ dataset = load_dataset(DATASET_NAME, split="train")
66
+ print(f"Dataset size: {len(dataset)} examples")
67
+
68
+ # SFTConfig with Trackio monitoring
69
+ sft_config = SFTConfig(
70
+ dataset_text_field="text",
71
+ per_device_train_batch_size=BATCH_SIZE,
72
+ gradient_accumulation_steps=GRAD_ACCUM,
73
+ warmup_steps=5,
74
+ num_train_epochs=NUM_EPOCHS,
75
+ learning_rate=LEARNING_RATE,
76
+ logging_steps=10,
77
+ optim="adamw_8bit",
78
+ weight_decay=0.001,
79
+ lr_scheduler_type="linear",
80
+ seed=3407,
81
+ output_dir="./outputs",
82
+ save_steps=500,
83
+ save_total_limit=3,
84
+ max_seq_length=MAX_SEQ_LENGTH,
85
+ # Trackio monitoring
86
+ report_to="trackio",
87
+ run_name="functiongemma-agent-sft",
88
+ # Hub push (CRITICAL - environment is ephemeral!)
89
+ push_to_hub=True,
90
+ hub_model_id=OUTPUT_REPO,
91
+ hub_strategy="every_save",
92
+ )
93
+
94
+ # Create trainer
95
+ trainer = SFTTrainer(
96
+ model=model,
97
+ tokenizer=tokenizer,
98
+ train_dataset=dataset,
99
+ eval_dataset=None,
100
+ args=sft_config,
101
+ )
102
+
103
+ # CRITICAL: Only train on model responses, not instructions
104
+ print("Applying train_on_responses_only (masking instruction tokens)...")
105
+ trainer = train_on_responses_only(
106
+ trainer,
107
+ instruction_part="<start_of_turn>user\n",
108
+ response_part="<start_of_turn>model\n",
109
+ )
110
+
111
+ print("Starting training...")
112
+ trainer_stats = trainer.train()
113
+
114
+ # Final push to hub
115
+ print(f"Pushing final model to {OUTPUT_REPO}...")
116
+ trainer.push_to_hub()
117
+
118
+ print("Training complete!")
119
+ print(f"Model saved to: https://huggingface.co/{OUTPUT_REPO}")