Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from src.utils import get_batch | |
@torch.no_grad()  # evaluation only: skip autograd graph construction (saves memory/time)
def estimate_loss(model: nn.Module, eval_iters, block_size, batch_size, device):
    """Estimate mean loss on the train and val splits.

    Runs ``eval_iters`` forward passes per split with gradients disabled and
    averages the per-batch losses. The model is put in eval mode for the
    duration and restored to train mode before returning.

    Args:
        model: Model whose forward returns ``(logits, loss)`` given ``(X, Y)``.
        eval_iters: Number of batches to average over per split.
        block_size: Context length passed through to ``get_batch``.
        batch_size: Batch size passed through to ``get_batch``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Dict mapping ``"train"`` and ``"val"`` to a scalar tensor of the
        mean loss over ``eval_iters`` batches.
    """
    out = {}
    model.eval()  # disable dropout / use running batch-norm stats, if any
    for split in ["train", "val"]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, block_size, batch_size)
            X, Y = X.to(device), Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()  # restore training mode for the caller's training loop
    return out
def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
):
    """Run the training loop, periodically evaluating and checkpointing.

    Every ``eval_interval`` steps the train/val losses are estimated and the
    model is saved to ``checkpoints/model.pth`` whenever the val loss improves
    on the best seen so far.

    Args:
        model: Model whose forward returns ``(logits, loss)`` given ``(x, y)``.
        optimizer: Optimizer stepping the model's parameters.
        max_iters: Total number of training steps.
        eval_interval: Evaluate (and maybe checkpoint) every this many steps.
        eval_iters: Batches averaged per split during evaluation.
        block_size: Context length passed through to ``get_batch``.
        batch_size: Batch size passed through to ``get_batch``.
        device: Device the batches are moved to before the forward pass.
    """
    best_val_loss = None  # best val loss seen so far; None until first eval
    # "step", not "iter": avoid shadowing the builtin.
    for step in range(max_iters):
        if step % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            print(
                f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )
            # Checkpoint on a *new best* val loss, and keep the best updated.
            # (Previously the threshold was never advanced past the first
            # evaluation, so later-but-worse models could overwrite a better
            # checkpoint; the first model was also never saved.)
            if best_val_loss is None or losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                # NOTE(review): saves the whole module via pickle; prefer
                # torch.save(model.state_dict(), ...) if loaders can change.
                torch.save(model, "checkpoints/model.pth")
        xb, yb = get_batch("train", block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)  # free grad memory between steps
        loss.backward()
        optimizer.step()