import json
import time
from pathlib import Path

import torch
from torch import nn, optim


class TinyTrainer:
    """Minimal single-model training wrapper.

    Bundles a model with an AdamW optimizer and a cross-entropy criterion,
    and counts how many optimization steps have been taken.
    """

    def __init__(self, model: nn.Module, lr: float = 1e-5) -> None:
        """Put *model* in training mode and build its optimizer and loss.

        Args:
            model: Module to optimize; it is called as ``model(input_ids)``.
            lr: Learning rate handed to ``optim.AdamW``.
        """
        self.model = model
        self.model.train()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.step = 0  # number of completed train_step() calls

    def train_step(
        self, input_ids: torch.Tensor, target_ids: torch.Tensor
    ) -> float:
        """Run one forward/backward/optimizer update and return the loss.

        The logits are flattened to ``(-1, logits.size(-1))`` and the targets
        to one dimension before being passed to the criterion.
        NOTE(review): presumably logits are (batch, seq, vocab) and targets
        are (batch, seq) class indices -- confirm against the caller.

        Args:
            input_ids: Batch passed straight to the model.
            target_ids: Targets with one entry per flattened logits row.

        Returns:
            The loss for this step as a Python float.
        """
        self.optimizer.zero_grad()
        logits = self.model(input_ids)

        # reshape() rather than view(): view() raises on non-contiguous
        # tensors, and shifted slices such as tokens[:, 1:] are
        # non-contiguous. For contiguous inputs reshape() is still a view.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = target_ids.reshape(-1)

        loss = self.criterion(flat_logits, flat_targets)
        loss.backward()
        self.optimizer.step()
        self.step += 1
        return loss.item()