File size: 2,521 Bytes
09daf0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from data_loader.load_squad_json import load_squad_json
from utils.squad_preprocess import process_sample
from utils.vocab import build_vocab, encode
from models.qa_model import QAModel


class QADataset(Dataset):
    """Map-style dataset of fixed-length token-id sequences with answer spans.

    Each raw sample is passed to ``process_sample``; samples for which it
    returns a falsy value are dropped. Samples whose ``start`` or ``end``
    index is ``>= max_len`` are also dropped, because that part of the input
    would be truncated away. The remaining token lists are encoded with
    ``encode(tokens, vocab)`` and right-padded with ``pad_idx`` (or
    truncated) to exactly ``max_len`` ids.

    Args:
        samples: Raw samples accepted by ``process_sample``.
        vocab: Vocabulary passed through unchanged to ``encode``.
        max_len: Fixed length of every stored id sequence.
        pad_idx: Id used for right-padding. Defaults to 0, the value that was
            previously hard-coded. NOTE(review): presumably this is the
            vocab's padding id; confirm against ``build_vocab``.
    """

    def __init__(self, samples, vocab, max_len=300, *, pad_idx=0):
        self.data = []

        for sample in samples:
            item = process_sample(sample)
            if not item:
                continue

            start = item["start"]
            end = item["end"]
            # The answer span must lie inside the kept window. Checking it
            # first avoids encoding samples that would be discarded anyway.
            if start >= max_len or end >= max_len:
                continue

            encoded = self._fit_length(encode(item["tokens"], vocab), max_len, pad_idx)
            self.data.append((encoded, start, end))

    @staticmethod
    def _fit_length(ids, max_len, pad_idx=0):
        """Return a new list of exactly ``max_len`` ids (truncate or right-pad).

        ``ids`` itself is never mutated, so a list returned by ``encode`` (or
        any caller-owned sequence) is left intact.
        """
        fitted = list(ids[:max_len])
        fitted.extend([pad_idx] * (max_len - len(fitted)))
        return fitted

    def __len__(self):
        """Number of samples that survived preprocessing and span filtering."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, start, end)`` as int64 tensors."""
        x, s, e = self.data[idx]
        # int64 is what nn.Embedding inputs and CrossEntropyLoss class
        # targets expect.
        return (
            torch.tensor(x, dtype=torch.long),
            torch.tensor(s, dtype=torch.long),
            torch.tensor(e, dtype=torch.long),
        )


def _build_vocab_from_samples(raw):
    """Build a vocab from the tokens of every sample ``process_sample`` accepts."""
    all_tokens = []
    for sample in raw:
        item = process_sample(sample)
        if item:
            all_tokens.extend(item["tokens"])
    return build_vocab(all_tokens)


def _train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Run one optimization pass over ``loader``; return the summed batch loss."""
    model.train()
    total_loss = 0.0

    for x, start, end in loader:
        x = x.to(device)
        start = start.to(device)
        end = end.to(device)

        pred_start, pred_end = model(x)
        # Start and end positions are scored separately; the training
        # objective is the sum of the two losses.
        loss = loss_fn(pred_start, start) + loss_fn(pred_end, end)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss


def train(
    data_path="data/train-v2.0.json",
    max_samples=30000,
    *,
    batch_size=32,
    lr=0.001,
    epochs=5,
    output_path="qa_model.pth",
):
    """Train a ``QAModel`` on SQuAD-style data and save a checkpoint.

    Loads samples with ``load_squad_json`` and keeps only the first
    ``max_samples``. Builds a vocabulary from every sample that
    ``process_sample`` accepts. Trains with Adam and the summed start/end
    ``CrossEntropyLoss``. Finally, saves ``{"model_state": ..., "vocab": ...}``
    to ``output_path`` with ``torch.save``. Called with no arguments, it
    reproduces the original hard-coded configuration.

    Args:
        data_path: Path of the JSON training file.
        max_samples: Number of leading samples to keep; ``None`` keeps all.
        batch_size: DataLoader batch size.
        lr: Adam learning rate.
        epochs: Number of passes over the dataset.
        output_path: Destination of the saved checkpoint.

    Raises:
        RuntimeError: If no sample survives preprocessing. Without this
            check, an untrained model would be saved without any warning.
    """
    print("Loading data...")
    raw = load_squad_json(data_path)[:max_samples]

    print("Building vocab...")
    vocab = _build_vocab_from_samples(raw)

    print("Preparing dataset...")
    dataset = QADataset(raw, vocab)
    if len(dataset) == 0:
        raise RuntimeError(f"No usable training samples found in {data_path!r}")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print("Initializing model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QAModel(len(vocab))
    # Move the model before creating the optimizer. The torch.optim
    # documentation recommends this order so the optimizer holds the
    # device-resident parameters.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    print("Training...\n")

    for epoch in range(epochs):
        total_loss = _train_one_epoch(model, loader, optimizer, loss_fn, device)
        print(f"Epoch {epoch+1} Loss: {total_loss:.2f}")

    torch.save({
        "model_state": model.state_dict(),
        "vocab": vocab
    }, output_path)

    print("\n✅ Model trained and saved!")


if __name__ == "__main__":
    # Run the full training pipeline only when this file is executed directly,
    # not when it is imported.
    train()