File size: 655 Bytes
29cdc9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from torch import nn, optim
import json
import time
from pathlib import Path

class TinyTrainer:
    """Minimal training loop: AdamW + cross-entropy over a single model.

    Each call to :meth:`train_step` performs one forward pass, one backward
    pass and one optimizer update, and returns the loss as a Python float.

    Attributes:
        model: The ``nn.Module`` being trained.
        optimizer: ``torch.optim.AdamW`` over all of ``model.parameters()``.
        criterion: ``nn.CrossEntropyLoss`` with its default settings.
        step: Number of optimizer updates performed so far.
    """

    def __init__(self, model, lr=1e-5):
        """Set up the optimizer and loss for ``model``.

        Args:
            model: Module whose forward pass takes the ``input_ids`` given to
                :meth:`train_step` and returns logits with the class
                dimension last.
            lr: AdamW learning rate.
        """
        self.model = model
        self.model.train()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.step = 0

    def train_step(self, input_ids, target_ids):
        """Run one optimization step and return its loss.

        Args:
            input_ids: Model input, passed straight to ``self.model``
                (presumably token ids — depends on the model).
            target_ids: Integer class targets; after flattening, must contain
                one target per row of the flattened logits, i.e.
                ``logits.numel() // logits.size(-1)`` elements.

        Returns:
            The loss for this batch as a Python float, computed before the
            parameter update.
        """
        # Re-enter training mode on every step: the caller may have switched
        # the model to eval() (e.g. for validation) since the last step, which
        # would otherwise silently disable dropout / freeze batch-norm stats.
        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(input_ids)
        # reshape (not view): logits or targets may be non-contiguous (e.g. a
        # transpose inside the model), for which view() raises RuntimeError.
        loss = self.criterion(
            logits.reshape(-1, logits.size(-1)),
            target_ids.reshape(-1),
        )
        loss.backward()
        self.optimizer.step()
        self.step += 1
        return loss.item()