# training/train.py
# Minimal training skeleton using Hugging Face transformers Trainer.
# Designed to train Sanchari-S (200-350M params) from scratch or to fine-tune it.
# Run: python training/train.py --config training/config_s.json --tokenizer_dir ../tokenizer
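#
# The config JSON is read by load_config() below; an illustrative sketch of the
# fields this script looks up (the concrete values here are assumptions, not the
# released config_s.json):
# {
#   "block_size": 1024,
#   "model":    {"n_embd": 1024, "n_layer": 16, "n_head": 16},
#   "training": {"per_device_train_batch_size": 2, "gradient_accumulation_steps": 8,
#                "num_train_epochs": 1, "learning_rate": 2e-4, "weight_decay": 0.01,
#                "fp16": true, "logging_steps": 100, "save_steps": 1000}
# }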
import json
import argparse
import os
from pathlib import Path
from datasets import load_dataset
from transformers import (
AutoTokenizer,
GPT2Config,
AutoModelForCausalLM,
DataCollatorForLanguageModeling,
TrainingArguments,
Trainer
)
def load_config(path):
with open(path, "r") as f:
return json.load(f)
def group_texts(examples, block_size):
# concatenate and chunk
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
total_length = len(concatenated["input_ids"])
total_length = (total_length // block_size) * block_size
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated.items()
}
return result
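# Worked example (illustrative): with block_size=4 and a batch
#   {"input_ids": [[1, 2], [3, 4, 5]]}
# the concatenated stream is [1, 2, 3, 4, 5]; only one full block fits, so the
# result is {"input_ids": [[1, 2, 3, 4]]} and the trailing token 5 is dropped.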
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, help="Path to config json")
parser.add_argument("--tokenizer_dir", required=True, help="Path to tokenizer folder (containing .model/.vocab)")
parser.add_argument("--data_file", default="../data/all_texts.txt", help="Single-line text file or newline-separated.")
parser.add_argument("--output_dir", default="./outputs/sanchari-s", help="Output directory")
args = parser.parse_args()
cfg = load_config(args.config)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir, use_fast=False)
# Make sure tokenizer has pad token
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
block_size = cfg.get("block_size", 1024)
# Create or load dataset (text)
if not os.path.exists(args.data_file):
raise FileNotFoundError(f"Data file not found: {args.data_file}")
raw_dsets = load_dataset("text", data_files={"train": args.data_file})
# Tokenize
def tokenize_fn(examples):
return tokenizer(examples["text"], return_special_tokens_mask=False)
tokenized = raw_dsets.map(
tokenize_fn,
batched=True,
remove_columns=["text"],
num_proc=1
)
    # Group the tokenized sequences into fixed blocks of block_size tokens using the
    # group_texts helper defined above; partial trailing blocks are dropped so every
    # training example is exactly block_size long.
    dataset = tokenized["train"].map(
        lambda examples: group_texts(examples, block_size),
        batched=True,
    )
    # Build model config and model. Fall back to the CLS token (or a fixed id)
    # when the tokenizer does not define BOS/EOS tokens.
    if tokenizer.bos_token_id is not None:
        bos_id = tokenizer.bos_token_id
    elif tokenizer.cls_token:
        bos_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)
    else:
        bos_id = 1
    eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2
    model_cfg = GPT2Config(
        vocab_size=len(tokenizer),
        n_positions=block_size,
        n_ctx=block_size,
        n_embd=cfg["model"]["n_embd"],
        n_layer=cfg["model"]["n_layer"],
        n_head=cfg["model"]["n_head"],
        bos_token_id=bos_id,
        eos_token_id=eos_id,
    )
model = AutoModelForCausalLM.from_config(model_cfg)
# resize token embeddings if tokenizer added tokens
model.resize_token_embeddings(len(tokenizer))
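    # Optional sanity check (an addition, not part of the original workflow): report
    # the parameter count so it can be compared against the 200-350M target.
    print(f"Model parameters: {model.num_parameters() / 1e6:.1f}M")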
# Data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
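    # With mlm=False the collator pads each batch and copies input_ids into labels
    # (pad positions set to -100), so the model computes the causal-LM loss directly.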
# Training arguments from config
train_args = cfg["training"]
training_args = TrainingArguments(
output_dir=args.output_dir,
per_device_train_batch_size=train_args.get("per_device_train_batch_size", 2),
gradient_accumulation_steps=train_args.get("gradient_accumulation_steps", 8),
num_train_epochs=train_args.get("num_train_epochs", 1),
learning_rate=train_args.get("learning_rate", 2e-4),
weight_decay=train_args.get("weight_decay", 0.01),
fp16=train_args.get("fp16", True),
logging_steps=train_args.get("logging_steps", 100),
save_steps=train_args.get("save_steps", 1000),
evaluation_strategy="no",
save_total_limit=3,
remove_unused_columns=False,
report_to="none" # disable wandb by default
)
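    # With the defaults above, the effective batch size is
    # per_device_train_batch_size * gradient_accumulation_steps = 2 * 8 = 16 per device.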
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
data_collator=data_collator
)
trainer.train()
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
print("Training complete. Model & tokenizer saved to", args.output_dir)
if __name__ == "__main__":
main()
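
# Example follow-up (illustrative sketch, not part of the training run): load the
# saved checkpoint and generate a few tokens as a smoke test. The prompt text and
# generation settings are assumptions.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("./outputs/sanchari-s", use_fast=False)
#   lm = AutoModelForCausalLM.from_pretrained("./outputs/sanchari-s")
#   ids = tok("Hello", return_tensors="pt").input_ids
#   out = lm.generate(ids, max_new_tokens=20)
#   print(tok.decode(out[0], skip_special_tokens=True))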