|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import os
|
|
|
from model import MiniText
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Training hyperparameters ---

# Length of each training subsequence, in bytes.
SEQ_LEN = 64

# Total number of optimization steps (one random batch per "epoch").
EPOCHS = 12000

# Adam learning rate.
LR = 1e-4

# Save a resumable checkpoint every this many steps.
SAVE_EVERY = 2000

# Path of the resumable training checkpoint (model + optimizer + progress).
CHECKPOINT_PATH = "checkpoint.pt"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the raw training corpus and expose it as one long 1-D LongTensor of
# byte values (0-255) — the model is trained at the byte level.
with open("dataset.txt", "rb") as corpus_file:
    raw_bytes = corpus_file.read()

data = torch.tensor(list(raw_bytes), dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model to train; the architecture lives in model.py.
model = MiniText()

# Adam over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Next-byte classification loss; expects raw (unnormalized) logits.
loss_fn = nn.CrossEntropyLoss()
|
|
|
|
|
|
# Step to resume from; stays 0 when no checkpoint exists.
start_epoch = 0

# Resume a previous run when a checkpoint file is present on disk.
if not os.path.exists(CHECKPOINT_PATH):
    print("Nenhum checkpoint encontrado, treino do zero.")
else:
    print("Checkpoint encontrado, retomando treino...")
    ckpt = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    # The stored epoch is the last completed step, so resume at the next one.
    start_epoch = ckpt["epoch"] + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_batch(batch_size=1):
    """Sample a random (input, target) batch from the byte corpus.

    Picks `batch_size` independent random start offsets and returns:
        x: LongTensor of shape (batch_size, SEQ_LEN) — input bytes.
        y: LongTensor of shape (batch_size, SEQ_LEN) — x shifted right by
           one position (next-byte prediction targets).

    The default `batch_size=1` reproduces the original single-sample
    behaviour, so existing callers are unaffected.
    """
    # Highest valid start leaves room for SEQ_LEN inputs plus one extra
    # byte for the shifted targets.
    starts = torch.randint(0, len(data) - SEQ_LEN - 1, (batch_size,)).tolist()
    x = torch.stack([data[i:i + SEQ_LEN] for i in starts])
    y = torch.stack([data[i + 1:i + SEQ_LEN + 1] for i in starts])
    return x, y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Training loop: one random sequence per step ---
for epoch in range(start_epoch, EPOCHS):
    x, y = get_batch()

    # The model returns (logits, extra); only the logits feed the loss.
    logits, _ = model(x)

    # Flatten to (tokens, vocab) for CrossEntropyLoss. The vocab size is
    # read off the logits instead of being hard-coded to 256, so changing
    # the vocabulary in model.py cannot silently mis-shape the loss input.
    loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {loss.item():.4f}")

    # Periodic resumable checkpoint: weights, optimizer state, and the
    # last completed step (the resume code starts at epoch + 1).
    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }, CHECKPOINT_PATH)
        print("Checkpoint salvo.")

# Final weights-only export for inference use.
torch.save(model.state_dict(), "minitext.pt")
print("Treino finalizado. Modelo salvo em minitext.pt")
|
|
|
|