dpe1/jules-tinyreasoner / src /sft_train.py
dpe1's picture
download
raw
3.6 kB
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pytorch_optimizer import SOAP
import json
import os
import sys
from src.tokenizer import CharTokenizer
from src.model import TinyReasonerModel
class SFTDataset(Dataset):
    """Map-style dataset of prompt/completion records for supervised fine-tuning.

    Every record is rendered as ``[BOS]`` + prompt + newline + completion +
    ``[EOS]``, encoded with the supplied tokenizer, and then padded or
    truncated to exactly ``seq_len + 1`` token ids so it can be split into an
    input sequence and a one-step-shifted target sequence of length
    ``seq_len`` each.
    """

    def __init__(self, tokenizer, data_path, seq_len=512):
        """Load the JSON records stored at *data_path*.

        Args:
            tokenizer: Object providing ``encode(text)`` (returning a list of
                token ids) and a ``pad_token_id`` attribute.
            data_path: Path to a JSON file. Its top-level value is indexed by
                position and each record must have ``"prompt"`` and
                ``"completion"`` keys.
            seq_len: Length of each input/target sequence that
                ``__getitem__`` returns.
        """
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        # JSON text is UTF-8; decode it explicitly instead of relying on the
        # platform's locale encoding, which can corrupt non-ASCII characters.
        with open(data_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` next-token training pair for record *idx*.

        Returns:
            Two ``torch.long`` tensors of length ``seq_len``; ``y`` is ``x``
            shifted left by one position.
        """
        item = self.data[idx]
        full_text = f"[BOS]{item['prompt']}\n{item['completion']}[EOS]"
        tokens = self.tokenizer.encode(full_text)

        # One extra token is kept so that the input and the shifted target
        # both end up with exactly ``seq_len`` entries.
        window = self.seq_len + 1
        tokens = tokens[:window]
        tokens = tokens + [self.tokenizer.pad_token_id] * (window - len(tokens))

        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        return x, y
def train_sft(data_path="data/sft_data.json", output_path="models/sft_model.pt", base_model_path=None, num_epochs=10):
    """Fine-tune a ``TinyReasonerModel`` on prompt/completion data and save it.

    Initial weights are taken from ``base_model_path`` when that file exists,
    otherwise from ``models/pretrained.pt`` when present; if neither exists a
    warning is printed and the freshly constructed model is trained as is.

    Args:
        data_path: JSON file consumed by ``SFTDataset``.
        output_path: File the trained ``state_dict`` is written to. Its parent
            directory is created when the path has one.
        base_model_path: Optional checkpoint to start from.
        num_epochs: Number of passes over the dataset.

    Returns:
        None. When ``data_path`` is missing or holds no records, an error is
        printed and the function returns without training or saving.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Validate the input first so a bad path fails before anything is built.
    if not os.path.exists(data_path):
        print(f"Error: {data_path} not found.")
        return

    tokenizer = CharTokenizer()
    dataset = SFTDataset(tokenizer, data_path)
    # An empty dataset cannot be trained on, and would make the per-epoch
    # average below divide by zero.
    if len(dataset) == 0:
        print(f"Error: {data_path} contains no training examples.")
        return
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    model = TinyReasonerModel(tokenizer.vocab_size).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources, or pass weights_only=True if the installed torch
    # version supports it.
    if base_model_path and os.path.exists(base_model_path):
        model.load_state_dict(torch.load(base_model_path, map_location=device))
        print(f"Loaded base model from {base_model_path}.")
    elif os.path.exists("models/pretrained.pt"):
        model.load_state_dict(torch.load("models/pretrained.pt", map_location=device))
        print("Loaded pretrained model.")
    else:
        print("Warning: No base model found. Training from scratch.")

    # Split parameters by identity rather than by name so that every parameter
    # lands in exactly one optimizer group: none is silently left untrained
    # and none is registered twice.
    embedding_params = list(model.embedding.parameters())
    embedding_ids = {id(p) for p in embedding_params}
    other_params = [p for p in model.parameters() if id(p) not in embedding_ids]
    param_groups = [
        {"params": other_params},
        # Per-group SOAP override for the embedding table.
        # NOTE(review): presumably limits preconditioning for this group;
        # confirm against the pytorch_optimizer SOAP documentation.
        {"params": embedding_params, "max_precond_dim": 1},
    ]
    optimizer = SOAP(param_groups, lr=5e-4)
    # Padding positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            # The model returns a pair; only the logits are used here.
            logits, _ = model(x)
            # Flatten to (tokens, vocab) vs (tokens,) for the token-level loss.
            loss = criterion(logits.view(-1, tokenizer.vocab_size), y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch}, Avg Loss: {avg_loss:.4f}")

    # os.path.dirname() is "" for a bare filename, and os.makedirs("") raises,
    # so only create the directory when the path actually has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), output_path)
    print(f"Model saved to {output_path}")
if __name__ == "__main__":
    # Command-line entry point: map the four optional flags onto train_sft().
    import argparse

    cli = argparse.ArgumentParser()
    # (flag, value type, default) for every supported option.
    option_table = (
        ("--data", str, "data/sft_data.json"),
        ("--output", str, "models/sft_model.pt"),
        ("--base", str, None),
        ("--epochs", int, 10),
    )
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    opts = cli.parse_args()

    train_sft(
        data_path=opts.data,
        output_path=opts.output,
        base_model_path=opts.base,
        num_epochs=opts.epochs,
    )

Xet Storage Details

Size:
3.6 kB
·
Xet hash:
546b6bc59d3cf0220fbc1d8e0c03086f72dd22d761b7061b1d085a291ebb7e5b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.