"""
© SupraLabs 2026 - SFT script for Supra-50M on alpaca-cleaned.

No TRL. Uses HuggingFace Trainer with prompt-masked cross-entropy loss.
"""

import os

# Pin to GPU 0 by default, but respect a CUDA_VISIBLE_DEVICES value the caller
# already exported. This runs before torch is imported so CUDA init sees it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

print("[*] Loading libraries...")

import torch
import numpy as np
from dataclasses import dataclass
from typing import Optional
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)
from torch.utils.data import Dataset

# ── Config ────────────────────────────────────────────────────────────────────
MODEL_ID = "./Chimera-FINAL"
OUTPUT_DIR = "./Supra-50M-SFT"
TOKENIZER_VOCAB = "custom_llama_tokenizer-vocab.json"
TOKENIZER_MERGES = "custom_llama_tokenizer-merges.txt"
MAX_LENGTH = 512  # alpaca samples are short, 512 is plenty
IGNORE_INDEX = -100  # standard label mask value for cross-entropy

# Conservative hyperparameters — small model, don't nuke the pretraining
LEARNING_RATE = 3e-4
EPOCHS = 4
BATCH_SIZE = 8
GRAD_ACCUM = 2  # effective batch size = 16
WARMUP_RATIO = 0.1
WEIGHT_DECAY = 0.0
MAX_GRAD_NORM = 1.0

# ── Alpaca prompt template ────────────────────────────────────────────────────
PROMPT_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. Write a response that appropriately "
    "completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n"
)

PROMPT_WITHOUT_INPUT = (
    "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n"
    "### Response:\n"
)


def build_prompt(sample: dict) -> tuple[str, str]:
    """Split an alpaca record into prompt text and response text.

    The two parts are returned separately so the caller can tokenize them
    independently and mask the prompt out of the loss.

    Args:
        sample: Mapping with "instruction" and "output" keys and an optional
            "input" key. A missing or None "input" is treated as empty.

    Returns:
        (prompt, response). The prompt uses PROMPT_WITH_INPUT when the
        stripped input is non-empty, otherwise PROMPT_WITHOUT_INPUT.
    """
    instruction = sample["instruction"].strip()
    # `or ""` also covers an explicit None, which `.get("input", "")` would
    # hand straight to `.strip()` and crash on.
    inp = (sample.get("input") or "").strip()
    output = sample["output"].strip()

    if inp:
        prompt = PROMPT_WITH_INPUT.format(instruction=instruction, input=inp)
    else:
        prompt = PROMPT_WITHOUT_INPUT.format(instruction=instruction)

    return prompt, output


# ── Dataset ───────────────────────────────────────────────────────────────────
class AlpacaDataset(Dataset):
    """
    Tokenizes each sample and masks the prompt portion of the labels so the
    model only computes loss on the response tokens — not on the instruction.

    Each item is [bos] + prompt + response + [eos], truncated to max_length.
    Truncation caveats: a truncated response loses its EOS, and if the prompt
    alone fills the window every label is IGNORE_INDEX, so the sample
    contributes no supervised tokens.
    """

    def __init__(self, hf_dataset, tokenizer: PreTrainedTokenizerBase, max_length: int):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = hf_dataset

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        prompt, response = build_prompt(self.samples[idx])

        # Tokenize prompt and response separately so we know the prompt length
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids  # explicit BOS
        response_ids = self.tokenizer.encode(response, add_special_tokens=False) + [self.tokenizer.eos_token_id]

        input_ids = prompt_ids + response_ids

        # Truncate to max_length
        input_ids = input_ids[:self.max_length]

        # Labels: mask prompt tokens with IGNORE_INDEX. The model shifts
        # labels internally, so labels stay aligned with input_ids here.
        prompt_len = min(len(prompt_ids), len(input_ids))
        labels = [IGNORE_INDEX] * prompt_len + input_ids[prompt_len:]

        # Sanity: both must be the same length after truncation
        assert len(input_ids) == len(labels)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


# ── Collator ──────────────────────────────────────────────────────────────────
@dataclass
class PaddingCollator:
    """
    Right-pads input_ids and labels to the longest sequence in the batch
    (capped at max_length). input_ids are padded with the tokenizer's
    pad_token_id; labels are padded with IGNORE_INDEX so padding never
    contributes to the loss. attention_mask is 1 for real tokens, 0 for pad.
    """

    tokenizer: PreTrainedTokenizerBase
    max_length: int

    def __call__(self, batch):
        max_len = max(len(x["input_ids"]) for x in batch)
        max_len = min(max_len, self.max_length)

        input_ids_padded = []
        labels_padded = []
        attention_masks = []

        for item in batch:
            ids = item["input_ids"][:max_len]
            lbls = item["labels"][:max_len]
            pad_n = max_len - len(ids)

            input_ids_padded.append(
                torch.cat([ids, torch.full((pad_n,), self.tokenizer.pad_token_id, dtype=torch.long)])
            )
            labels_padded.append(
                torch.cat([lbls, torch.full((pad_n,), IGNORE_INDEX, dtype=torch.long)])
            )
            attention_masks.append(
                torch.cat([torch.ones(len(ids), dtype=torch.long), torch.zeros(pad_n, dtype=torch.long)])
            )

        return {
            "input_ids": torch.stack(input_ids_padded),
            "labels": torch.stack(labels_padded),
            "attention_mask": torch.stack(attention_masks),
        }


# ── Main ──────────────────────────────────────────────────────────────────────
def _load_tokenizer() -> PreTrainedTokenizerFast:
    """Build the fast tokenizer from the local byte-level BPE vocab/merges files."""
    from tokenizers import ByteLevelBPETokenizer

    fast_tokenizer = ByteLevelBPETokenizer(TOKENIZER_VOCAB, TOKENIZER_MERGES)
    # NOTE(review): the special-token strings below are empty in this copy of
    # the script and look like placeholders. Confirm them against the
    # tokenizer used for pretraining. _check_special_tokens() fails fast if
    # they don't resolve to usable ids.
    return PreTrainedTokenizerFast(
        tokenizer_object=fast_tokenizer,
        bos_token="",
        eos_token="",
        unk_token="",
        pad_token="",
    )


def _check_special_tokens(tokenizer, model) -> None:
    """Raise ValueError if bos/eos/pad ids are unusable for this model.

    AlpacaDataset inserts bos/eos ids directly and PaddingCollator pads with
    pad_token_id. Each must therefore be an integer that indexes into the
    model's input embedding table. Otherwise training fails much later, with
    a less obvious error: a None inside torch.tensor, or an out-of-range
    embedding lookup.
    """
    vocab_rows = model.get_input_embeddings().weight.shape[0]
    for name in ("bos_token_id", "eos_token_id", "pad_token_id"):
        token_id = getattr(tokenizer, name)
        if token_id is None or not 0 <= token_id < vocab_rows:
            raise ValueError(
                f"tokenizer.{name} is {token_id!r}, which is not a valid index "
                f"into the model's {vocab_rows}-row embedding table; check the "
                f"special-token strings passed to PreTrainedTokenizerFast."
            )


def main():
    """Fine-tune MODEL_ID on yahma/alpaca-cleaned and save to OUTPUT_DIR-FINAL."""
    # Tokenizer comes from local vocab/merges files; the model from MODEL_ID.
    print(f"[*] Loading tokenizer from {TOKENIZER_VOCAB} + {TOKENIZER_MERGES}...")
    tokenizer = _load_tokenizer()

    print(f"[*] Loading model from {MODEL_ID}...")
    # Load weights in fp32. With `bf16=True` below, the Trainer runs compute
    # under bf16 mixed precision while parameters and AdamW state stay fp32.
    # Loading directly in bf16 would make the parameters and optimizer moments
    # bf16 too, so small late-schedule updates can round away entirely. For a
    # model of this size the extra memory is negligible.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.float32,
        device_map="auto",
    )
    print(f"[+] Model loaded — {model.num_parameters():,} parameters")

    _check_special_tokens(tokenizer, model)

    # Load alpaca-cleaned (≈52k instruction-tuning pairs)
    print("[*] Loading alpaca-cleaned dataset...")
    raw = load_dataset("yahma/alpaca-cleaned", split="train")
    print(f"[+] Dataset: {len(raw):,} samples")

    # Optional: quick sanity-check split (comment out for full training)
    # raw = raw.select(range(1000))

    split = raw.train_test_split(test_size=0.01, seed=42)
    train_dataset = AlpacaDataset(split["train"], tokenizer, MAX_LENGTH)
    eval_dataset = AlpacaDataset(split["test"], tokenizer, MAX_LENGTH)
    collator = PaddingCollator(tokenizer=tokenizer, max_length=MAX_LENGTH)

    print(f"[+] Dataset ready: {len(train_dataset):,} samples")
    print(f"[+] Example prompt preview:\n{build_prompt(raw[0])[0][:800]}...")

    # Training arguments
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=WARMUP_RATIO,
        weight_decay=WEIGHT_DECAY,
        max_grad_norm=MAX_GRAD_NORM,
        bf16=True,
        fp16=False,
        logging_steps=50,
        save_total_limit=2,
        report_to="none",
        dataloader_num_workers=8,
        dataloader_pin_memory=True,
        optim="adamw_torch_fused",
        adam_beta1=0.9,
        adam_beta2=0.999,
        push_to_hub=False,
        seed=42,
        data_seed=42,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
    )

    print("[*] Starting SFT...")
    trainer.train()

    final_dir = f"{OUTPUT_DIR}-FINAL"
    print(f"[*] Saving final model to {final_dir} ...")
    # Training used fp32 parameters; export in bf16 to keep the checkpoint
    # compact.
    trainer.model.to(torch.bfloat16)
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)
    print("[+] Done.")


if __name__ == "__main__":
    main()