# File size: 655 Bytes (revision 29cdc9d)
import torch
from torch import nn, optim
import json
import time
from pathlib import Path
class TinyTrainer:
    """Minimal helper that performs single optimization steps on one model.

    Pairs the model with an ``AdamW`` optimizer and a ``CrossEntropyLoss``
    criterion, and counts how many updates have been applied.
    """

    def __init__(self, model: nn.Module, lr: float = 1e-5) -> None:
        """Switch ``model`` to training mode and build the optimizer.

        Args:
            model: Module to optimize. It is invoked as ``model(input_ids)``
                and its return value is used directly as the logits tensor.
            lr: Learning rate handed to ``optim.AdamW``.
        """
        self.model = model
        self.model.train()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        # Number of optimizer updates applied so far.
        self.step = 0

    def train_step(
        self, input_ids: torch.Tensor, target_ids: torch.Tensor
    ) -> float:
        """Run one forward/backward pass and apply one optimizer update.

        The logits are flattened to ``(-1, logits.size(-1))`` and the targets
        to ``(-1,)`` before the cross-entropy loss is computed, so the last
        logits dimension is treated as the class dimension.

        Args:
            input_ids: Batch passed unchanged to the model.
            target_ids: Class indices; after flattening there must be one
                target per flattened logits row.

        Returns:
            The loss for this step as a Python float.
        """
        self.optimizer.zero_grad()
        logits = self.model(input_ids)
        num_classes = logits.size(-1)
        # reshape() instead of view(): view() raises on non-contiguous
        # tensors (e.g. shifted targets sliced as batch[:, 1:]), while
        # reshape() returns a view when possible and copies otherwise.
        loss = self.criterion(
            logits.reshape(-1, num_classes),
            target_ids.reshape(-1),
        )
        loss.backward()
        self.optimizer.step()
        self.step += 1
        return loss.item()
|