File size: 3,899 Bytes
be34e96
 
 
 
 
9d31597
be34e96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202db76
be34e96
 
61fc4b0
 
be34e96
 
 
 
 
9d31597
be34e96
 
 
9d31597
be34e96
 
 
 
202db76
be34e96
 
 
 
f85dbd3
 
 
 
 
61fc4b0
 
be34e96
 
 
9d31597
 
 
 
 
 
 
 
 
 
 
be34e96
 
 
61fc4b0
 
be34e96
 
 
 
 
 
f85dbd3
 
 
 
 
be34e96
 
 
 
 
61fc4b0
be34e96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61fc4b0
 
be34e96
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "datasets>=2.19.0",
#     "huggingface_hub>=0.26.0",
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.46.0",
#     "accelerate>=0.33.0",
#     "trackio",
# ]
# ///

"""Fine-tune a small Spanish technical chatbot with LoRA SFT.

Expected dataset format:
  JSONL rows with a `messages` list of chat messages.

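Environment (all optional):
  HUB_MODEL_ID      Default: TOKETTER/Omegus. Target Hub repo; also the repo
                    the dataset is downloaded from when DATASET_PATH does not
                    exist locally.
  BASE_MODEL        Default: HuggingFaceTB/SmolLM2-135M-Instruct
  DATASET_PATH      Default: data/charlie_omega_sft.jsonl
  TRACKIO_PROJECT   Default: Omegus
  PUSH_TO_HUB       Default: enabled when HUB_MODEL_ID is non-empty. When set,
                    only 1/true/yes/on (case-insensitive) enable pushing.
  REPORT_TO         Default: trackio
  RUN_NAME          Default: omegus-qwen-0.5b-lora-demo
  NUM_TRAIN_EPOCHS  Default: 5
  TRAIN_BATCH_SIZE  Default: 2 (per device)
  GRADIENT_ACCUMULATION_STEPS  Default: 4
  MAX_LENGTH        Default: 768
  LEARNING_RATE     Default: 2e-4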
"""

from __future__ import annotations

import os
from pathlib import Path

import trackio
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer


# Runtime configuration. Every value is read once from the environment when
# the module is imported; unset variables fall back to the defaults below.

# Model, dataset and Hub locations.
BASE_MODEL = os.getenv("BASE_MODEL", "HuggingFaceTB/SmolLM2-135M-Instruct")
DATASET_PATH = os.getenv("DATASET_PATH", "data/charlie_omega_sft.jsonl")
HUB_MODEL_ID = os.getenv("HUB_MODEL_ID", "TOKETTER/Omegus")

# Experiment tracking.
TRACKIO_PROJECT = os.getenv("TRACKIO_PROJECT", "Omegus")
# NOTE(review): the default run name mentions "qwen-0.5b" although BASE_MODEL
# defaults to SmolLM2-135M — confirm whether the label should be updated.
RUN_NAME = os.getenv("RUN_NAME", "omegus-qwen-0.5b-lora-demo")
REPORT_TO = os.getenv("REPORT_TO", "trackio")

# Hyperparameters. Numeric values are parsed eagerly, so a malformed value
# raises ValueError at import time rather than in the middle of training.
NUM_TRAIN_EPOCHS = float(os.getenv("NUM_TRAIN_EPOCHS", "5"))
TRAIN_BATCH_SIZE = int(os.getenv("TRAIN_BATCH_SIZE", "2"))
GRADIENT_ACCUMULATION_STEPS = int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "4"))
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "768"))
LEARNING_RATE = float(os.getenv("LEARNING_RATE", "2e-4"))

# Raw PUSH_TO_HUB value; None means "not set", which lets main() fall back to
# deciding from HUB_MODEL_ID.
PUSH_TO_HUB_ENV = os.getenv("PUSH_TO_HUB")


def main() -> None:
    """Fine-tune ``BASE_MODEL`` with LoRA SFT and optionally push it to the Hub.

    The pipeline runs in order:

    1. Resolve the dataset: use ``DATASET_PATH`` if it exists locally,
       otherwise download that file from the ``HUB_MODEL_ID`` model repo.
    2. Load it as JSON and hold out 15% (fixed seed 42) for evaluation.
    3. Decide whether to push: defaults to "push when ``HUB_MODEL_ID`` is
       non-empty", overridden by ``PUSH_TO_HUB`` when that variable is set.
    4. Train LoRA adapters with ``SFTTrainer``, then push or report the local
       output directory.
    5. Call ``trackio.finish()`` when ``REPORT_TO`` is exactly ``"trackio"``.

    Raises:
        ValueError: If the dataset is missing locally and ``HUB_MODEL_ID`` is
            empty (nothing to download from), or if pushing is requested
            while ``HUB_MODEL_ID`` is empty.
    """
    dataset_path = DATASET_PATH
    if not Path(dataset_path).exists():
        # Without a repo id the download below can only fail with an opaque
        # validation error, so fail early with an actionable message.
        if not HUB_MODEL_ID:
            raise ValueError(
                f"Dataset {DATASET_PATH!r} was not found locally and "
                "HUB_MODEL_ID is empty, so there is no repo to download it from."
            )
        print(f"Downloading dataset from {HUB_MODEL_ID}:{DATASET_PATH}")
        # The dataset file is fetched from the *model* repo, not a dataset repo.
        dataset_path = hf_hub_download(
            repo_id=HUB_MODEL_ID,
            filename=DATASET_PATH,
            repo_type="model",
        )

    print(f"Loading dataset from {dataset_path}")
    dataset = load_dataset("json", data_files=dataset_path, split="train")
    # Fixed seed keeps the train/eval partition identical across runs.
    split = dataset.train_test_split(test_size=0.15, seed=42)

    # Default: push whenever a target repo is configured. An explicit
    # PUSH_TO_HUB value always wins, in either direction.
    push_to_hub = bool(HUB_MODEL_ID)
    if PUSH_TO_HUB_ENV is not None:
        push_to_hub = PUSH_TO_HUB_ENV.lower() in {"1", "true", "yes", "on"}

    # PUSH_TO_HUB=1 with an empty HUB_MODEL_ID previously produced
    # output_dir="" and hub_model_id=None, which only failed later and
    # obscurely. Reject the combination up front.
    if push_to_hub and not HUB_MODEL_ID:
        raise ValueError(
            "PUSH_TO_HUB is enabled but HUB_MODEL_ID is empty; "
            "set HUB_MODEL_ID to '<user>/<repo>' or disable PUSH_TO_HUB."
        )

    # When pushing, name the local directory after the repo (last path part).
    output_dir = HUB_MODEL_ID.split("/")[-1] if push_to_hub else "Omegus"

    # NOTE(review): `max_length` and `project` are presumably only accepted by
    # recent TRL / transformers releases — confirm against the minimum
    # versions pinned in the script header.
    args = SFTConfig(
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        hub_model_id=HUB_MODEL_ID or None,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        max_length=MAX_LENGTH,
        logging_steps=2,
        save_strategy="epoch",
        eval_strategy="epoch",
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        report_to=REPORT_TO,
        project=TRACKIO_PROJECT,
        run_name=RUN_NAME,
    )

    # LoRA adapters on the q/k/v/o projection modules only.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    # The base model is passed by name; the trainer loads it itself.
    trainer = SFTTrainer(
        model=BASE_MODEL,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        args=args,
        peft_config=peft_config,
    )

    print(f"Training {BASE_MODEL}")
    trainer.train()

    if push_to_hub:
        print(f"Pushing model to https://huggingface.co/{HUB_MODEL_ID}")
        trainer.push_to_hub()
    else:
        print(f"Saved local adapter/checkpoints in {output_dir}")

    # Only an exact "trackio" value triggers the explicit finish call.
    if REPORT_TO == "trackio":
        trackio.finish()


# Run the training pipeline only when executed as a script, so importing this
# module does not start a training run.
if __name__ == "__main__":
    main()