import os
import json
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from services.word_generation_dataset import WordGenDataset, collate_fn
from services.tokenizer import tokenize_dataset, build_vocab
from services.transformer import TinyTransformer
from services.model import save_model

# 1. Preprocessing
print("Preprocessing...")

# 1.1 Load the dataset from JSON
base_dir = os.path.dirname(__file__)
file_path = os.path.join(base_dir, "data", "definitions-2.json")
with open(file_path, "r", encoding="utf-8") as f:
  data = json.load(f)
print("Data loaded!")

# 1.2 Tokenize using internal tokenizer logic
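# tokenize_dataset is assumed to return two parallel lists of token
# sequences, one source and one target sequence per entry (presumably
# definition -> word, given the word-generation naming); the exact scheme
# lives in services.tokenizer.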
inputs, outputs = tokenize_dataset(data)
print("Data Tokenized!")

# 1.3 Build vocabulary
vocab, inv_vocab = build_vocab(inputs, outputs)
print("Vocabulary built!")

# 1.4 Create dataset and dataloader
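# collate_fn is assumed to pad each batch to a common length with
# vocab["<pad>"], matching the ignore_index used by the loss below, so the
# DataLoader yields rectangular (src, tgt) LongTensor pairs.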
dataset = WordGenDataset(inputs, outputs, vocab)
print("Dataset built!")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
print("Dataloader built!")

# 2. Training
print("Training...")

# 2.1 Initialize model
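# Prefer a GPU when available: CUDA on NVIDIA, MPS (Metal) on Apple Silicon,
# otherwise fall back to CPU.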
device = torch.device(
  "cuda" if torch.cuda.is_available()
  else "mps" if torch.backends.mps.is_available()
  else "cpu"
)
vocab_size = len(vocab)
model = TinyTransformer(vocab_size=vocab_size).to(device)

# 2.2 Loss and Optimizer
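# ignore_index excludes <pad> positions from the loss so the model is never
# trained to emit padding.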
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 2.3 Training Loop
num_epochs = 10

for epoch in range(num_epochs):
  model.train()
  total_loss = 0

  for batch in dataloader:
    src, tgt = batch
    src, tgt = src.to(device), tgt.to(device)

    # Teacher forcing: the decoder sees the target shifted right and learns
    # to predict the next token at each position
    tgt_input = tgt[:, :-1]
    tgt_expected = tgt[:, 1:]

    # Forward pass
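    # (logits come out as (batch, tgt_len, vocab_size); the reshape below
    # depends on this)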
    logits = model(src, tgt_input)

    # Reshape for loss: (batch*seq_len, vocab_size)
    loss = criterion(logits.reshape(-1, vocab_size), tgt_expected.reshape(-1))

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

  avg_loss = total_loss / len(dataloader)
  print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

# 2.4 Save model
print("Saving...")
save_model(model, vocab)
print("Finished!")