import os

import torch
import torch.nn as nn

from model import MiniText

# -----------------------
# hyperparameters
# -----------------------
SEQ_LEN = 64
EPOCHS = 12000
LR = 1e-4
SAVE_EVERY = 2000  # save a checkpoint every X epochs
CHECKPOINT_PATH = "checkpoint.pt"

# -----------------------
# dataset
# -----------------------
# Read the corpus as raw bytes, so the vocabulary is exactly 256 tokens.
with open("dataset.txt", "rb") as f:
    data = torch.tensor(list(f.read()), dtype=torch.long)

# -----------------------
# model + optimizer
# -----------------------
model = MiniText()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()
start_epoch = 0

# -----------------------
# load checkpoint (if one exists)
# -----------------------
if os.path.exists(CHECKPOINT_PATH):
    print("Checkpoint found, resuming training...")
    checkpoint = torch.load(CHECKPOINT_PATH)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    start_epoch = checkpoint["epoch"] + 1
else:
    print("No checkpoint found, training from scratch.")

# -----------------------
# batch sampler
# -----------------------
def get_batch():
    # Pick a random window; the target is the input shifted one byte to the right.
    idx = torch.randint(0, len(data) - SEQ_LEN - 1, (1,)).item()
    x = data[idx:idx + SEQ_LEN].unsqueeze(0)
    y = data[idx + 1:idx + SEQ_LEN + 1].unsqueeze(0)
    return x, y

# -----------------------
# training loop
# -----------------------
model.train()
for epoch in range(start_epoch, EPOCHS):
    x, y = get_batch()
    logits, _ = model(x)
    # Flatten (batch, seq, vocab) -> (batch * seq, vocab) for cross-entropy.
    loss = loss_fn(logits.view(-1, 256), y.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}/{EPOCHS} | Loss: {loss.item():.4f}")

    # save checkpoint
    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }, CHECKPOINT_PATH)
        print("Checkpoint saved.")

# -----------------------
# save final model
# -----------------------
torch.save(model.state_dict(), "minitext.pt")
print("Training finished. Model saved to minitext.pt")
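
# -----------------------
# sampling sketch (optional)
# -----------------------
# Minimal sketch of loading minitext.pt and sampling bytes from it, for a
# quick sanity check of the trained weights. This ASSUMES MiniText's forward
# pass takes a (batch, seq) tensor of byte ids and returns (logits, hidden)
# with logits of shape (batch, seq, 256), as the training loop above implies;
# check model.py before relying on it.
def sample(prompt: bytes = b"Hello", length: int = 200) -> bytes:
    net = MiniText()
    net.load_state_dict(torch.load("minitext.pt"))
    net.eval()
    out = list(prompt)
    with torch.no_grad():
        for _ in range(length):
            # Feed at most the last SEQ_LEN bytes as context.
            ctx = torch.tensor(out[-SEQ_LEN:], dtype=torch.long).unsqueeze(0)
            logits, _ = net(ctx)
            # Sample the next byte from the distribution at the last position.
            probs = torch.softmax(logits[0, -1], dim=-1)
            out.append(torch.multinomial(probs, 1).item())
    return bytes(out)

# Example: print(sample().decode("utf-8", errors="replace"))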