File size: 11,629 Bytes
4942b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
#!/usr/bin/env python3
"""
Bánh mì chuyển ngữ — Gemma 4 Unsloth fine-tuning pipeline.

Submission to the Gemma 4 Good Hackathon, Unsloth $10K track:
"best fine-tuned Gemma 4 model created using Unsloth, optimized for a
specific, impactful task" — short-utterance translation.

Defaults match the winning config (exp08) — training loss 2.916 -> 0.0115
(-99.6%, ~250x), achieved on a single NVIDIA L4 in ~12 h with checkpoint
resume. Running the script with no flags reproduces that run end-to-end.

Usage:
    # Reproduce the submission run (defaults = exp08 winning config)
    python scripts/train.py

    # Override individual hyperparameters
    python scripts/train.py --lora-rank 128 --learning-rate 5e-5

    # Use a custom local dataset
    python scripts/train.py --dataset data/processed/train.jsonl

    # Use the pinned YAML directly
    python scripts/train.py $(python -c "import yaml; [print(f'--{k.replace(\"_\",\"-\")}', *(() if v is True else (v,))) for k,v in yaml.safe_load(open('configs/train_config.yaml')).items() if v is not None and v is not False]")
"""

# IMPORTANT: import unsloth FIRST before other ML libraries
import unsloth
import argparse
import json
import os

import torch
from unsloth import FastModel
from unsloth.chat_templates import get_chat_template, train_on_responses_only
from datasets import load_dataset, Dataset
from trl import SFTTrainer, SFTConfig


def parse_args():
    """Parse command-line options for the fine-tuning run.

    Reads ``sys.argv`` through ``argparse``. Every option has a default, so
    calling the script with no flags yields the configuration described in
    the help strings as "exp08".

    Quantization flags are resolved here so the result is always a pair of
    real booleans:

    * neither flag given      -> ``load_4bit=True``,  ``load_16bit=False``
    * ``--load-16bit`` only   -> ``load_4bit=False``, ``load_16bit=True``
    * ``--load-4bit`` given   -> ``load_4bit=True`` (an explicit 4-bit
      request keeps its historical precedence over ``--load-16bit``)

    Returns:
        argparse.Namespace: the parsed options, with dashes in option names
        converted to underscores (e.g. ``--lora-rank`` -> ``lora_rank``).
    """
    parser = argparse.ArgumentParser(description="Fine-tune Gemma 4 with Unsloth")

    # Model — winning config: Unsloth Gemma 4 E4B 4-bit
    parser.add_argument("--model", type=str,
                        default="unsloth/gemma-4-E4B-it-unsloth-bnb-4bit",
                        help="Pretrained model name (default = exp08)")
    parser.add_argument("--max-seq-length", type=int, default=2048)
    # default=None (not True) so "flag absent" can be told apart from
    # "flag given"; with default=True the flag could never be False and
    # --load-16bit was silently ignored downstream.
    parser.add_argument("--load-4bit", action="store_true", default=None,
                        help="QLoRA (4-bit) — default on for the exp08 reproduction")
    parser.add_argument("--load-16bit", action="store_true", help="bf16 LoRA (for MoE)")

    # LoRA — winning config: r=64 with RSLoRA
    parser.add_argument("--lora-rank", type=int, default=64,
                        help="LoRA rank (default = exp08 winning value)")
    parser.add_argument("--lora-alpha", type=int, default=None,
                        help="LoRA alpha (defaults to lora-rank)")
    parser.add_argument("--lora-dropout", type=float, default=0.0)
    # BooleanOptionalAction keeps --use-rslora working and adds
    # --no-use-rslora, so the default-on option can actually be disabled.
    parser.add_argument("--use-rslora", action=argparse.BooleanOptionalAction,
                        default=True,
                        help="Rank-stabilized LoRA (default on, required for r>=64)")

    # Data — winning config: 10k FineTome-100k samples
    parser.add_argument("--dataset", type=str, default="mlabonne/FineTome-100k",
                        help="Dataset name or path to local JSONL")
    parser.add_argument("--max-samples", type=int, default=10000,
                        help="Dataset sample cap (default = exp08 winning value)")
    parser.add_argument("--system-prompt", type=str, default=None)

    # Training — winning config: lr=7e-5, 5 epochs, grad_accum=8
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--grad-accum", type=int, default=8,
                        help="Gradient accumulation steps (default = exp08)")
    parser.add_argument("--learning-rate", type=float, default=7e-5,
                        help="Learning rate (default = exp08 winning value)")
    parser.add_argument("--max-steps", type=int, default=None,
                        help="Set only if num-epochs is None")
    parser.add_argument("--num-epochs", type=int, default=5,
                        help="Training epochs (default = exp08 winning value)")
    parser.add_argument("--warmup-steps", type=int, default=50,
                        help="LR warmup steps (default = exp08)")
    parser.add_argument("--weight-decay", type=float, default=0.01,
                        help="Weight decay (default = exp08)")
    parser.add_argument("--save-steps", type=int, default=250,
                        help="Checkpoint every N steps (default = exp08, crash-safe)")
    parser.add_argument("--save-total-limit", type=int, default=3,
                        help="Keep only the last N checkpoints")
    parser.add_argument("--scheduler", type=str, default="cosine",
                        choices=["cosine", "linear", "constant"])
    parser.add_argument("--seed", type=int, default=3407)

    # Output
    parser.add_argument("--output-dir", type=str, default="outputs")
    parser.add_argument("--save-path", type=str, default="checkpoints/finetuned/lora_adapter")
    parser.add_argument("--logging-steps", type=int, default=1)
    parser.add_argument("--resume-from", type=str, default=None,
                        help="Resume training from a checkpoint dir (e.g. outputs/exp06/checkpoint-2500)")

    # Post-training pipeline
    parser.add_argument("--experiment-name", type=str, default="experiment",
                        help="Name for this experiment (used in logs/reports)")

    args = parser.parse_args()

    # Resolve the quantization mode: 4-bit is the default unless the user
    # asked for 16-bit; an explicit --load-4bit is left as given.
    if args.load_4bit is None:
        args.load_4bit = not args.load_16bit

    return args


def load_local_jsonl(path, max_samples=None):
    """Load a local JSON Lines file into a ``datasets.Dataset``.

    Args:
        path: Path to a text file holding one JSON value per line.
        max_samples: Stop after this many records have been read. ``None``
            or ``0`` reads the whole file.

    Returns:
        The result of ``Dataset.from_list`` over the parsed records, in
        file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The
            message names the file and the 1-based line number.
    """
    records = []
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # would fail on or corrupt non-ASCII text.
    with open(path, encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (commonly a trailing one) carry no record.
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type, but say where the bad record is.
                raise json.JSONDecodeError(
                    f"{err.msg} (in {path}, line {line_no})", err.doc, err.pos
                ) from err
            if max_samples and len(records) >= max_samples:
                break
    return Dataset.from_list(records)


def main():
    """Run the fine-tuning pipeline end to end from the command line.

    Steps, in order: parse CLI options, load the base model with Unsloth's
    ``FastModel``, attach LoRA adapters, install the ``gemma-4`` chat
    template, load the dataset and render each conversation into a ``text``
    column, train with ``SFTTrainer`` (wrapped by
    ``train_on_responses_only``), then save the adapter, the tokenizer and
    a training summary.

    Training length: a positive ``--max-steps`` takes precedence over
    ``--num-epochs``; otherwise ``--num-epochs`` is used.

    Raises:
        FileNotFoundError: if ``--dataset`` ends in ``.jsonl`` but no such
            file exists.
        ValueError: if neither a positive ``--max-steps`` nor a non-zero
            ``--num-epochs`` was supplied.
    """
    args = parse_args()

    if args.lora_alpha is None:
        args.lora_alpha = args.lora_rank

    # Determine loading mode
    if not args.load_4bit and not args.load_16bit:
        args.load_4bit = True  # Default to QLoRA

    # Validate cheap things before the (slow) model load so that obvious
    # mistakes fail immediately.
    use_max_steps = args.max_steps is not None and args.max_steps > 0
    if not use_max_steps and not args.num_epochs:
        raise ValueError(
            "Nothing to train: pass a positive --max-steps or a non-zero --num-epochs"
        )

    use_local_jsonl = args.dataset.endswith(".jsonl")
    if use_local_jsonl and not os.path.exists(args.dataset):
        # Previously this fell through to a Hub lookup of the file path,
        # which failed later with a misleading error.
        raise FileNotFoundError(f"Local dataset not found: {args.dataset}")

    print("=" * 60)
    print("Gemma 4 Fine-Tuning with Unsloth")
    print("=" * 60)
    print(f"Model:          {args.model}")
    print(f"Quantization:   {'4-bit (QLoRA)' if args.load_4bit else '16-bit (bf16 LoRA)'}")
    print(f"LoRA rank:      {args.lora_rank}")
    print(f"LoRA alpha:     {args.lora_alpha}")
    print(f"Learning rate:  {args.learning_rate}")
    print(f"Max steps:      {args.max_steps}")
    print(f"Dataset:        {args.dataset}")
    print(f"Max samples:    {args.max_samples}")
    print("=" * 60)

    # ---- Load Model ----
    print("\n[1/5] Loading model...")
    model, tokenizer = FastModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=args.load_4bit,
        # 4-bit wins when both modes are set.
        load_in_16bit=args.load_16bit if not args.load_4bit else False,
        full_finetuning=False,
    )

    # ---- Configure LoRA ----
    print("\n[2/5] Configuring LoRA adapters...")
    model = FastModel.get_peft_model(
        model,
        finetune_vision_layers=False,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        random_state=args.seed,
        use_rslora=args.use_rslora,
    )

    # Print trainable params
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"  Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

    # ---- Setup Chat Template ----
    print("\n[3/5] Setting up chat template...")
    tokenizer = get_chat_template(tokenizer, chat_template="gemma-4")

    # ---- Load & Format Dataset ----
    print("\n[4/5] Loading and formatting dataset...")

    if use_local_jsonl:
        dataset = load_local_jsonl(args.dataset, args.max_samples)

        # Dataset is already in messages format, apply chat template
        def format_messages(examples):
            texts = []
            for messages in examples["messages"]:
                # Strip the leading "<bos>" from the rendered text.
                # NOTE(review): presumably because the trainer's tokenizer
                # adds its own BOS — confirm.
                text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=False
                ).removeprefix("<bos>")
                texts.append(text)
            return {"text": texts}
        dataset = dataset.map(format_messages, batched=True)
    else:
        # Load from HuggingFace
        from unsloth.chat_templates import standardize_data_formats

        if args.max_samples:
            dataset = load_dataset(args.dataset, split=f"train[:{args.max_samples}]")
        else:
            dataset = load_dataset(args.dataset, split="train")

        dataset = standardize_data_formats(dataset)

        def formatting_prompts_func(examples):
            convos = examples["conversations"]
            texts = [
                tokenizer.apply_chat_template(
                    convo, tokenize=False, add_generation_prompt=False
                ).removeprefix("<bos>")
                for convo in convos
            ]
            return {"text": texts}

        dataset = dataset.map(formatting_prompts_func, batched=True)

    print(f"  Dataset size: {len(dataset)} examples")

    # ---- Setup Logger ----
    # Project-local module, imported here rather than at file top.
    from training_logger import TrainingLogger
    log_dir = os.path.join(args.output_dir, "logs")
    training_logger = TrainingLogger(
        output_dir=log_dir,
        experiment_name=args.experiment_name,
    )

    # ---- Train ----
    print("\n[5/5] Starting training...")

    training_kwargs = dict(
        dataset_text_field="text",
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        warmup_steps=args.warmup_steps,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        optim="adamw_8bit",
        weight_decay=args.weight_decay,
        lr_scheduler_type=args.scheduler,
        seed=args.seed,
        output_dir=args.output_dir,
        report_to="none",
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
    )

    # An explicit positive --max-steps wins over the epoch count. Before,
    # --max-steps was silently ignored because --num-epochs defaults to 5,
    # and --num-epochs 0 alone passed max_steps=None to the config.
    if use_max_steps:
        training_kwargs["max_steps"] = args.max_steps
    else:
        training_kwargs["num_train_epochs"] = args.num_epochs

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        args=SFTConfig(**training_kwargs),
        callbacks=[training_logger],
    )

    # Train on responses only (mask user/system tokens)
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|turn>user\n",
        response_part="<|turn>model\n",
    )

    # resume_from_checkpoint=None starts a fresh run.
    trainer_stats = trainer.train(resume_from_checkpoint=args.resume_from)

    # ---- Save ----
    print(f"\nTraining complete!")
    print(f"  Runtime: {trainer_stats.metrics['train_runtime']:.1f}s")
    print(f"  Final loss: {trainer_stats.metrics.get('train_loss', 'N/A')}")

    print(f"\nSaving LoRA adapter to {args.save_path}...")
    os.makedirs(args.save_path, exist_ok=True)
    model.save_pretrained(args.save_path)
    tokenizer.save_pretrained(args.save_path)

    # Save training logs
    training_logger.save_summary(
        trainer_stats,
        config={
            "model_name": args.model,
            "dataset_name": args.dataset,
            "dataset_size": len(dataset),
            "lora_rank": args.lora_rank,
        },
    )

    # Single-line key=value summary of the run.
    print(f"\nMETRICS: loss={trainer_stats.metrics.get('train_loss', -1):.4f} "
          f"runtime={trainer_stats.metrics['train_runtime']:.1f} "
          f"samples={len(dataset)} lora_rank={args.lora_rank} lr={args.learning_rate}")

    print("\nDone! Next steps:")
    print(f"  1. Evaluate: python scripts/evaluate.py --model {args.save_path}")
    print(f"  2. Export:   python scripts/export_model.py --model {args.save_path}")


# Run the pipeline only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()