# MiniText-v1.0-base / train.py
import torch
import torch.nn as nn
import os
from model import MiniText
# -----------------------
# hyperparameters
# -----------------------
SEQ_LEN = 64
EPOCHS = 12000          # each "epoch" here is one optimizer step on a single random window
LR = 1e-4
SAVE_EVERY = 2000       # save a checkpoint every SAVE_EVERY epochs
CHECKPOINT_PATH = "checkpoint.pt"
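# Optional (not in the original script): fixing the RNG seed makes the
# randomly sampled training windows reproducible across runs.
# torch.manual_seed(0)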
# -----------------------
# dataset
# -----------------------
with open("dataset.txt", "rb") as f:
    data = torch.tensor(list(f.read()), dtype=torch.long)  # raw bytes -> token ids in [0, 255]
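# Optional sketch (not in the original): report the corpus size, a quick
# sanity check that dataset.txt was read as raw bytes.
print(f"Dataset: {len(data)} bytes")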
# -----------------------
# model + optimizer
# -----------------------
model = MiniText()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()
start_epoch = 0
# -----------------------
# load checkpoint (if one exists)
# -----------------------
if os.path.exists(CHECKPOINT_PATH):
    print("Checkpoint found, resuming training...")
    checkpoint = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    start_epoch = checkpoint["epoch"] + 1
else:
    print("No checkpoint found, training from scratch.")
# -----------------------
# batch sampler
# -----------------------
def get_batch():
    # sample one random window of SEQ_LEN bytes; the target is the same
    # window shifted one byte to the right
    idx = torch.randint(0, len(data) - SEQ_LEN - 1, (1,)).item()
    x = data[idx:idx + SEQ_LEN].unsqueeze(0)          # shape (1, SEQ_LEN)
    y = data[idx + 1:idx + SEQ_LEN + 1].unsqueeze(0)  # shape (1, SEQ_LEN)
    return x, y
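# Optional sketch (not part of the original script): a batched variant that
# stacks several random windows per step. It assumes MiniText accepts input
# of shape (batch, SEQ_LEN), which the single-window loop below implies.
def get_batch_batched(batch_size=32):
    idx = torch.randint(0, len(data) - SEQ_LEN - 1, (batch_size,))
    x = torch.stack([data[i:i + SEQ_LEN] for i in idx])
    y = torch.stack([data[i + 1:i + SEQ_LEN + 1] for i in idx])
    return x, y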
# -----------------------
# training loop
# -----------------------
for epoch in range(start_epoch, EPOCHS):
    x, y = get_batch()
    logits, _ = model(x)
    # flatten to (batch * SEQ_LEN, vocab); 256 is the byte-level vocab size
    loss = loss_fn(logits.view(-1, 256), y.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {loss.item():.4f}")
    # save checkpoint
    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }, CHECKPOINT_PATH)
        print("Checkpoint saved.")
# -----------------------
# save final model
# -----------------------
torch.save(model.state_dict(), "minitext.pt")
print("Treino finalizado. Modelo salvo em minitext.pt")