"""Train a TinyTransformer on the word/definition dataset and save it.

Pipeline, executed top to bottom when the script runs:

1. Load ``data/definitions-2.json`` from the directory containing this file.
2. Tokenize the records and build the vocabulary.
3. Wrap the tokenized pairs in a ``WordGenDataset`` / ``DataLoader``.
4. Train a ``TinyTransformer`` with Adam and cross-entropy loss.
5. Persist the model together with its vocabulary via ``save_model``.
"""

import os
import json
import torch
from torch import nn, optim
from pathlib import Path
from torch.utils.data import DataLoader
from services.word_generation_dataset import WordGenDataset, collate_fn
from services.tokenizer import tokenize_dataset, build_vocab
from services.transformer import TinyTransformer
from services.model import save_model


# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
print("Preprocessing...")

# Resolve the data file relative to this script rather than the process
# working directory, so the script can be launched from anywhere.
cwd = os.path.dirname(__file__)
file_path = os.path.join(cwd, "data", "definitions-2.json")
with open(file_path, "r", encoding="utf-8") as f:
    data = json.load(f)
print("Data loaded!")

# Split the raw records into tokenized input / output sequences.
inputs, outputs = tokenize_dataset(data)
print("Data Tokenized!")

# Build the token <-> index mappings over both sides of the data.
vocab, inv_vocab = build_vocab(inputs, outputs)
print("Vocabulary built!")

dataset = WordGenDataset(inputs, outputs, vocab)
print("Dataset built!")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
print("Dataloader built!")

# Fail early with a clear message instead of hitting a ZeroDivisionError
# when the per-epoch average loss is computed over zero batches.
if len(dataloader) == 0:
    raise ValueError(f"No training batches were produced from {file_path!r}")


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
print("Training...")

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# ``torch.backends.mps`` is looked up defensively because PyTorch builds
# older than the MPS backend do not expose that attribute.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

vocab_size = len(vocab)
model = TinyTransformer(vocab_size=vocab_size).to(device)

# Target positions holding the "<pad>" index do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for batch in dataloader:
        src, tgt = batch
        src, tgt = src.to(device), tgt.to(device)

        # Shift the target by one position: the model is fed every token
        # except the last, and is scored against every token except the first.
        tgt_input = tgt[:, :-1]
        tgt_expected = tgt[:, 1:]

        logits = model(src, tgt_input)

        # Flatten so each row of logits is scored against one target index.
        loss = criterion(logits.reshape(-1, vocab_size), tgt_expected.reshape(-1))

        # Standard optimisation step: clear stale gradients, backpropagate,
        # then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")


# ---------------------------------------------------------------------------
# Persist
# ---------------------------------------------------------------------------
print("Saving...")
save_model(model, vocab)
print("Finished!")
|