Omegus / train_sft.py
TOKETTER's picture
Align Omegus package with SmolLM2 base
202db76 verified
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "datasets>=2.19.0",
# "huggingface_hub>=0.26.0",
# "trl>=0.12.0",
# "peft>=0.7.0",
# "transformers>=4.46.0",
# "accelerate>=0.33.0",
# "trackio",
# ]
# ///
"""Fine-tune a small Spanish technical chatbot with LoRA SFT.
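Expected dataset format:
    JSONL rows, each with a `messages` list of chat messages.

Environment (every variable is optional):
    HUB_MODEL_ID     Default: TOKETTER/Omegus. Hub repo to push to; also the
                     model repo the dataset is downloaded from when
                     DATASET_PATH does not exist locally.
    BASE_MODEL       Default: HuggingFaceTB/SmolLM2-135M-Instruct
    DATASET_PATH     Default: data/charlie_omega_sft.jsonl
    TRACKIO_PROJECT  Default: Omegus
    PUSH_TO_HUB      1/true/yes/on (case-insensitive) enables the push; any
                     other value disables it. Default: enabled when
                     HUB_MODEL_ID is non-empty.
    REPORT_TO        Default: trackio
    RUN_NAME, NUM_TRAIN_EPOCHS, TRAIN_BATCH_SIZE,
    GRADIENT_ACCUMULATION_STEPS, MAX_LENGTH, LEARNING_RATE
                     Further overrides; defaults are the module-level constants.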
"""
from __future__ import annotations
import os
from pathlib import Path
import trackio
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
# Runtime configuration. Each value is read once at import time from the
# environment variable named in the lookup; the literal is the default.

# Model and data locations.
BASE_MODEL = os.environ.get("BASE_MODEL", "HuggingFaceTB/SmolLM2-135M-Instruct")
DATASET_PATH = os.environ.get("DATASET_PATH", "data/charlie_omega_sft.jsonl")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "TOKETTER/Omegus")

# Experiment tracking.
TRACKIO_PROJECT = os.environ.get("TRACKIO_PROJECT", "Omegus")
# The default run name matches the default BASE_MODEL (SmolLM2-135M). The
# earlier default still named a Qwen 0.5B base, which mislabelled tracked runs.
RUN_NAME = os.environ.get("RUN_NAME", "omegus-smollm2-135m-lora-demo")

# Training hyperparameters. A non-numeric override raises ValueError at import.
NUM_TRAIN_EPOCHS = float(os.environ.get("NUM_TRAIN_EPOCHS", "5"))
TRAIN_BATCH_SIZE = int(os.environ.get("TRAIN_BATCH_SIZE", "2"))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get("GRADIENT_ACCUMULATION_STEPS", "4"))
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "768"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-4"))

REPORT_TO = os.environ.get("REPORT_TO", "trackio")
# Kept as None when PUSH_TO_HUB is unset so that "not provided" can be told
# apart from an explicit off value.
PUSH_TO_HUB_ENV = os.environ.get("PUSH_TO_HUB")
# Values of PUSH_TO_HUB that switch pushing on (compared case-insensitively).
_TRUTHY_FLAGS = frozenset({"1", "true", "yes", "on"})
# Local output directory used when no Hub repo name can be derived.
_DEFAULT_OUTPUT_DIR = "Omegus"


def _flag_enabled(raw: str) -> bool:
    """Return True when *raw* is one of 1/true/yes/on, ignoring case."""
    # Strip first so a value with stray whitespace (e.g. "true ") still counts.
    return raw.strip().lower() in _TRUTHY_FLAGS


def _resolve_dataset_path() -> str:
    """Return a local path to the training JSONL, downloading it if needed.

    Returns:
        DATASET_PATH when that file exists locally, otherwise the path
        returned by ``hf_hub_download`` for the same filename in the
        HUB_MODEL_ID model repo.

    Raises:
        FileNotFoundError: The file is missing locally and HUB_MODEL_ID is
            empty, so there is no repo to download it from.
    """
    if Path(DATASET_PATH).exists():
        return DATASET_PATH
    if not HUB_MODEL_ID:
        # Fail with an actionable message instead of handing an empty repo id
        # to hf_hub_download.
        raise FileNotFoundError(
            f"Dataset {DATASET_PATH!r} not found locally and HUB_MODEL_ID is "
            "empty, so it cannot be downloaded from the Hub."
        )
    print(f"Downloading dataset from {HUB_MODEL_ID}:{DATASET_PATH}")
    return hf_hub_download(
        repo_id=HUB_MODEL_ID,
        filename=DATASET_PATH,
        repo_type="model",
    )


def main() -> None:
    """Run the LoRA SFT job end to end.

    Resolves the JSONL dataset, holds out 15% of it for evaluation, trains
    LoRA adapters on the q/k/v/o projections of BASE_MODEL, then either pushes
    to the Hub or leaves the checkpoints in the local output directory.

    Raises:
        FileNotFoundError: The dataset is missing locally and HUB_MODEL_ID is
            empty (see ``_resolve_dataset_path``).
    """
    dataset_path = _resolve_dataset_path()
    print(f"Loading dataset from {dataset_path}")
    dataset = load_dataset("json", data_files=dataset_path, split="train")
    # Fixed seed keeps the train/eval partition identical across runs.
    split = dataset.train_test_split(test_size=0.15, seed=42)

    # Push by default whenever a repo id is configured; an explicit
    # PUSH_TO_HUB value always wins over that default.
    push_to_hub = bool(HUB_MODEL_ID)
    if PUSH_TO_HUB_ENV is not None:
        push_to_hub = _flag_enabled(PUSH_TO_HUB_ENV)

    # Name the local directory after the repo when pushing. Fall back to the
    # default if the id is empty or ends with "/" so output_dir is never "".
    repo_name = HUB_MODEL_ID.rstrip("/").split("/")[-1] if push_to_hub else ""
    output_dir = repo_name or _DEFAULT_OUTPUT_DIR

    # NOTE(review): `max_length` and `project` are presumably only accepted by
    # recent trl/transformers releases - confirm against the minimum versions
    # pinned in the script header.
    args = SFTConfig(
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        hub_model_id=HUB_MODEL_ID or None,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        max_length=MAX_LENGTH,
        logging_steps=2,
        save_strategy="epoch",
        eval_strategy="epoch",
        warmup_ratio=0.05,
        lr_scheduler_type="cosine",
        report_to=REPORT_TO,
        project=TRACKIO_PROJECT,
        run_name=RUN_NAME,
    )
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    # The base model is passed by id/path; the trainer is left to load it.
    trainer = SFTTrainer(
        model=BASE_MODEL,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        args=args,
        peft_config=peft_config,
    )

    print(f"Training {BASE_MODEL}")
    trainer.train()

    if push_to_hub:
        print(f"Pushing model to https://huggingface.co/{HUB_MODEL_ID}")
        trainer.push_to_hub()
    else:
        print(f"Saved local adapter/checkpoints in {output_dir}")

    if REPORT_TO == "trackio":
        trackio.finish()


# Start training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()