|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from torch import nn, optim |
|
|
from pathlib import Path |
|
|
from torch.utils.data import DataLoader |
|
|
from services.word_generation_dataset import WordGenDataset, collate_fn |
|
|
from services.tokenizer import tokenize_dataset, build_vocab |
|
|
from services.transformer import TinyTransformer |
|
|
from services.model import save_model |
|
|
|
|
|
|
|
|
print("Preprocessing...")

# Resolve the dataset path relative to this script (not the process cwd) so
# the script works regardless of where it is launched from. Uses pathlib,
# which the file already imports but previously left unused.
cwd = Path(__file__).resolve().parent
file_path = cwd / "data" / "definitions-2.json"

with file_path.open("r", encoding="utf-8") as f:
    data = json.load(f)
print("Data loaded!")
|
|
|
|
|
|
|
|
# Tokenize the raw definition records into parallel input/output token
# sequences (the exact record format is defined by services.tokenizer).
inputs, outputs = tokenize_dataset(data)
print("Data Tokenized!")

# Build the token<->id mapping (`vocab`) and its inverse (`inv_vocab`)
# over both the input and output sides of the corpus.
vocab, inv_vocab = build_vocab(inputs, outputs)
print("Vocabulary built!")

# Dataset wraps the token sequences and maps tokens to ids via `vocab`.
dataset = WordGenDataset(inputs, outputs, vocab)
print("Dataset built!")

# Shuffled mini-batches of 32; collate_fn presumably pads variable-length
# sequences per batch — confirm in services.word_generation_dataset.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
print("Dataloader built!")
|
|
|
|
|
|
|
|
print("Training...")

# Pick the best available accelerator: CUDA (NVIDIA) first, then MPS
# (Apple Silicon), falling back to CPU. The original only checked MPS,
# which silently left CUDA machines training on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

vocab_size = len(vocab)
model = TinyTransformer(vocab_size=vocab_size).to(device)

# Padding positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=1e-4)
|
|
|
|
|
|
|
|
num_epochs = 10

for epoch in range(num_epochs):
    # Training mode: enables dropout / batch-norm updates if the model has any.
    model.train()
    total_loss = 0

    for src, tgt in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing: decoder input is the target shifted right by one,
        # and the expected output is the target shifted left by one.
        tgt_input, tgt_expected = tgt[:, :-1], tgt[:, 1:]

        optimizer.zero_grad()
        logits = model(src, tgt_input)

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for cross-entropy.
        loss = criterion(logits.reshape(-1, vocab_size), tgt_expected.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")
|
|
|
|
|
|
|
|
print("Saving...")
# Persist the trained model together with the vocabulary so inference can
# rebuild the same token<->id mapping (see services.model.save_model).
save_model(model, vocab)
print("Finished!")
|
|
|