# LH-Tech-AI — "Create sft.py" (commit 1606f40, verified)
"""
© SupraLabs 2026 - SFT script for Supra-50M on alpaca-cleaned
No TRL. Uses HuggingFace Trainer with prompt-masked cross-entropy loss.
"""
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print("[*] Loading libraries...")
import torch
import numpy as np
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
Trainer,
TrainingArguments,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast
)
from torch.utils.data import Dataset
# ── Config ────────────────────────────────────────────────────────────────────
# Paths: pretrained checkpoint to start from, and the SFT output directory.
MODEL_ID = "./Chimera-FINAL"
OUTPUT_DIR = "./Supra-50M-SFT"

# Sequence handling.
MAX_LENGTH = 512  # alpaca samples are short, 512 tokens is plenty
IGNORE_INDEX = -100  # label value the cross-entropy loss skips (prompt/padding)

# Optimisation. Intended to be gentle on a small model's pretrained weights.
# NOTE(review): 3e-4 over 4 epochs is on the aggressive side for full
# fine-tuning; lower it if the model starts losing pretraining knowledge.
LEARNING_RATE = 3e-4
EPOCHS = 4
BATCH_SIZE = 8
GRAD_ACCUM = 2  # effective batch = BATCH_SIZE * GRAD_ACCUM = 16
WARMUP_RATIO = 0.1
WEIGHT_DECAY = 0.0
MAX_GRAD_NORM = 1.0
# ── Alpaca prompt template ────────────────────────────────────────────────────
# Both templates end with the "### Response:" header; the response text is
# appended directly after it during training, so inference prompts should use
# exactly the same wording.
PROMPT_WITH_INPUT = "".join((
    "Below is an instruction that describes a task, paired with an input that ",
    "provides further context. Write a response that appropriately completes ",
    "the request.\n\n",
    "### Instruction:\n{instruction}\n\n",
    "### Input:\n{input}\n\n",
    "### Response:\n",
))
PROMPT_WITHOUT_INPUT = "".join((
    "Below is an instruction that describes a task. Write a response that ",
    "appropriately completes the request.\n\n",
    "### Instruction:\n{instruction}\n\n",
    "### Response:\n",
))
def build_prompt(sample: dict) -> tuple[str, str]:
    """Render one Alpaca record into a ``(prompt, response)`` pair.

    The two halves are returned separately so the caller can mask the prompt
    tokens out of the loss.

    Args:
        sample: Record with string fields ``"instruction"`` and ``"output"``
            and an optional ``"input"`` field (absent, empty or None all mean
            "no input").

    Returns:
        ``(prompt, response)``: the filled-in template (ending with the
        ``### Response:`` header) and the stripped target text.
    """
    instruction = sample["instruction"].strip()
    # ``.get("input", "")`` alone still returns None when the key exists with a
    # null value, which would crash on ``.strip()``; treat None like "".
    inp = (sample.get("input") or "").strip()
    output = sample["output"].strip()
    if inp:
        prompt = PROMPT_WITH_INPUT.format(instruction=instruction, input=inp)
    else:
        prompt = PROMPT_WITHOUT_INPUT.format(instruction=instruction)
    return prompt, output
# ── Dataset ───────────────────────────────────────────────────────────────────
class AlpacaDataset(Dataset):
    """Map-style dataset of prompt-masked causal-LM examples.

    Each item is ``[BOS] + prompt + response + [EOS]`` cut to ``max_length``
    tokens. Labels copy the token ids, except that every prompt position
    (BOS included) is set to ``IGNORE_INDEX``. The loss is therefore computed
    only on the response tokens, never on the instruction.
    """

    def __init__(self, hf_dataset, tokenizer: PreTrainedTokenizerBase, max_length: int):
        """Wrap ``hf_dataset`` for on-the-fly tokenization.

        Args:
            hf_dataset: Indexable, sized collection of Alpaca records (in this
                script a ``datasets`` split); each record is rendered by
                ``build_prompt``.
            tokenizer: Tokenizer used to encode prompt and response; it must
                define BOS and EOS tokens.
            max_length: Maximum number of tokens kept per example.

        Raises:
            ValueError: If the tokenizer has no BOS or no EOS token id.
        """
        # Checked once up front: a missing id would otherwise put ``None`` into
        # every id list and make ``torch.tensor`` fail for each item, inside a
        # dataloader worker and far from the real cause.
        if tokenizer.bos_token_id is None or tokenizer.eos_token_id is None:
            raise ValueError(
                "AlpacaDataset needs a tokenizer with both bos_token and eos_token set"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = hf_dataset

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``{"input_ids", "labels"}`` as equal-length 1-D long tensors."""
        prompt, response = build_prompt(self.samples[idx])
        tok = self.tokenizer

        # Encode prompt and response separately so the prompt length is known
        # exactly. BOS/EOS are added explicitly rather than by the tokenizer.
        prompt_ids = [tok.bos_token_id] + tok.encode(prompt, add_special_tokens=False)
        response_ids = tok.encode(response, add_special_tokens=False) + [tok.eos_token_id]

        # Truncation happens on the right, so over-long examples lose the end
        # of the response (including EOS). If the prompt alone reaches
        # max_length, no response token survives and the example contributes
        # nothing to the loss.
        input_ids = (prompt_ids + response_ids)[: self.max_length]

        # Mask whatever part of the prompt survived truncation.
        prompt_len = min(len(prompt_ids), len(input_ids))
        labels = [IGNORE_INDEX] * prompt_len + input_ids[prompt_len:]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
# ── Collator ──────────────────────────────────────────────────────────────────
@dataclass
class PaddingCollator:
    """Collate dataset items into one right-padded training batch.

    Each item's ``input_ids`` and ``labels`` are cut to ``max_length``, and the
    batch is padded to its longest remaining sequence. Padding uses:

    - a pad id for ``input_ids``;
    - ``IGNORE_INDEX`` for ``labels``, so padding never contributes to the loss;
    - an ``attention_mask`` that is 1 for real tokens and 0 for padding.
    """

    tokenizer: PreTrainedTokenizerBase
    max_length: int

    def _pad_id(self) -> int:
        """Return the id written into padded ``input_ids`` positions.

        Padded positions are masked out of attention and loss, so any in-range
        id works. Prefer the pad token, then EOS, then 0. A tokenizer without
        a pad token therefore does not break padding.
        """
        for candidate in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id):
            if candidate is not None:
                return candidate
        return 0

    def __call__(self, batch):
        """Pad ``batch`` into stacked tensors.

        Args:
            batch: Non-empty list of dicts holding 1-D ``"input_ids"`` and
                ``"labels"`` tensors of equal length (as ``AlpacaDataset``
                produces).

        Returns:
            Dict with ``"input_ids"``, ``"labels"`` and ``"attention_mask"``,
            each a ``torch.long`` tensor of shape ``(len(batch), seq_len)``
            where ``seq_len = min(longest item, max_length)``.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        if not batch:
            raise ValueError("PaddingCollator received an empty batch")
        seq_len = min(max(len(item["input_ids"]) for item in batch), self.max_length)

        # Truncate first; pad_sequence then pads to the longest truncated item,
        # which is exactly seq_len.
        ids = [item["input_ids"][:seq_len].long() for item in batch]
        labels = [item["labels"][:seq_len].long() for item in batch]
        masks = [torch.ones(len(t), dtype=torch.long) for t in ids]

        pad = torch.nn.utils.rnn.pad_sequence
        return {
            "input_ids": pad(ids, batch_first=True, padding_value=self._pad_id()),
            "labels": pad(labels, batch_first=True, padding_value=IGNORE_INDEX),
            "attention_mask": pad(masks, batch_first=True, padding_value=0),
        }
# ── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Fine-tune the local Supra-50M checkpoint on alpaca-cleaned.

    Steps:
      1. Build a fast tokenizer from local byte-level BPE vocab/merges files.
      2. Load the causal LM from ``MODEL_ID``. Check that the BOS/EOS/PAD ids
         the script feeds the model fit inside its embedding table.
      3. Split alpaca-cleaned 99/1 (seed 42) into train/eval and wrap both in
         prompt-masked ``AlpacaDataset`` objects.
      4. Train with ``Trainer``. Evaluate and checkpoint every epoch, and
         reload the checkpoint with the lowest eval loss at the end.
      5. Save the model and tokenizer to ``f"{OUTPUT_DIR}-FINAL"``.

    Raises:
        ValueError: If the BOS/EOS/PAD token ids are undefined or fall outside
            the model's input embedding table.
    """
    from tokenizers import ByteLevelBPETokenizer

    # 1. Tokenizer. It is rebuilt from raw BPE files in the working directory;
    # nothing is read from MODEL_ID here.
    vocab_file = "custom_llama_tokenizer-vocab.json"
    merges_file = "custom_llama_tokenizer-merges.txt"
    print(f"[*] Loading tokenizer from {vocab_file} / {merges_file}...")
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=ByteLevelBPETokenizer(vocab_file, merges_file),
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
    )

    # 2. Model.
    print(f"[*] Loading model from {MODEL_ID}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        # Keep fp32 master weights. ``bf16=True`` below still runs the
        # forward/backward passes in bf16 via autocast. Loading the weights in
        # bf16 instead would make AdamW update bf16 parameters, where small
        # steps (lr * grad) are easily rounded away. For a ~50M-parameter model
        # the extra fp32 memory is negligible.
        dtype=torch.float32,
        device_map="auto",
    )
    print(f"[+] Model loaded — {model.num_parameters():,} parameters")

    # The dataset emits BOS/EOS ids and the collator emits PAD ids. If one of
    # these tokens is missing from the BPE vocab, the tokenizer may assign it a
    # new id past the embedding table. That would otherwise only surface
    # mid-training as an opaque indexing error.
    n_embeddings = model.get_input_embeddings().weight.shape[0]
    emitted_ids = {
        "bos_token": tokenizer.bos_token_id,
        "eos_token": tokenizer.eos_token_id,
        "pad_token": tokenizer.pad_token_id,
    }
    bad_ids = {
        name: tid
        for name, tid in emitted_ids.items()
        if tid is None or not 0 <= tid < n_embeddings
    }
    if bad_ids:
        raise ValueError(
            f"Special token ids {bad_ids} are undefined or outside the model's "
            f"embedding table (size {n_embeddings}); check the tokenizer files."
        )

    # 3. Data: alpaca-cleaned (≈52k instruction-tuning pairs).
    print("[*] Loading alpaca-cleaned dataset...")
    raw = load_dataset("yahma/alpaca-cleaned", split="train")
    print(f"[+] Dataset: {len(raw):,} samples")
    # Uncomment for a quick smoke test on a 1k-sample subset:
    # raw = raw.select(range(1000))
    split = raw.train_test_split(test_size=0.01, seed=42)
    train_dataset = AlpacaDataset(split["train"], tokenizer, MAX_LENGTH)
    eval_dataset = AlpacaDataset(split["test"], tokenizer, MAX_LENGTH)
    collator = PaddingCollator(tokenizer=tokenizer, max_length=MAX_LENGTH)
    print(f"[+] Dataset ready: {len(train_dataset):,} samples")
    print(f"[+] Example prompt preview:\n{build_prompt(raw[0])[0][:800]}...")

    # 4. Training. Eval and save strategies must match ("epoch") for
    # load_best_model_at_end to pick the lowest-eval-loss checkpoint.
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=WARMUP_RATIO,
        weight_decay=WEIGHT_DECAY,
        max_grad_norm=MAX_GRAD_NORM,
        bf16=True,
        fp16=False,
        logging_steps=50,
        save_total_limit=2,
        report_to="none",
        dataloader_num_workers=8,
        dataloader_pin_memory=True,
        optim="adamw_torch_fused",
        adam_beta1=0.9,
        adam_beta2=0.999,
        push_to_hub=False,
        seed=42,
        data_seed=42,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
    )
    print("[*] Starting SFT...")
    trainer.train()

    # 5. Save the (best) model plus the tokenizer side by side.
    print(f"[*] Saving final model to {OUTPUT_DIR}-FINAL ...")
    trainer.save_model(f"{OUTPUT_DIR}-FINAL")
    tokenizer.save_pretrained(f"{OUTPUT_DIR}-FINAL")
    print("[+] Done.")


if __name__ == "__main__":
    main()