| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
|
|
| from data_loader.load_squad_json import load_squad_json |
| from utils.squad_preprocess import process_sample |
| from utils.vocab import build_vocab, encode |
| from models.qa_model import QAModel |
|
|
|
|
class QADataset(Dataset):
    """Fixed-length encoded QA samples paired with answer-span indices.

    Every raw sample is passed through ``process_sample``. A sample is
    dropped when ``process_sample`` returns a falsy value, or when its
    ``start`` or ``end`` index is not below ``max_len`` (the answer would
    fall outside the truncated window).

    Args:
        samples: Iterable of raw samples accepted by ``process_sample``.
        vocab: Vocabulary object forwarded to ``encode``.
        max_len: Length every encoded sequence is truncated or padded to.
        pad_id: Id used to right-pad short sequences. Defaults to ``0``,
            the value that was previously hard-coded.
            NOTE(review): assumes ``pad_id`` matches the padding entry of
            ``vocab`` -- confirm against ``build_vocab``.
    """

    def __init__(self, samples, vocab, max_len=300, pad_id=0):
        super().__init__()
        # Each entry is (encoded_ids, start_index, end_index).
        self.data = []

        for sample in samples:
            item = process_sample(sample)
            if not item:
                continue

            start, end = item["start"], item["end"]
            # Filter first: no point encoding/padding a sample whose answer
            # span lies outside the window and will be discarded anyway.
            if start >= max_len or end >= max_len:
                continue

            # Slicing copies, so the list returned by encode() is never
            # mutated; a non-positive pad count yields an empty list.
            encoded = encode(item["tokens"], vocab)[:max_len]
            encoded += [pad_id] * (max_len - len(encoded))

            self.data.append((encoded, start, end))

    def __len__(self):
        """Return the number of samples that survived filtering."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, start, end)`` as int64 tensors for ``idx``."""
        encoded, start, end = self.data[idx]
        # Explicit dtype: inference would give float32 for an empty sequence,
        # and index/target tensors must be integral.
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(start, dtype=torch.long),
            torch.tensor(end, dtype=torch.long),
        )
|
|
|
|
def train(
    data_path="data/train-v2.0.json",
    max_samples=30000,
    epochs=5,
    batch_size=32,
    lr=0.001,
    save_path="qa_model.pth",
):
    """Train the QA span model and save a checkpoint.

    Loads samples with ``load_squad_json``, builds a vocabulary from the
    tokens of every sample ``process_sample`` accepts, trains ``QAModel``
    with Adam on the summed start/end cross-entropy loss, and writes the
    model state dict together with the vocabulary to ``save_path``.

    Calling ``train()`` with no arguments reproduces the previously
    hard-coded configuration.

    Args:
        data_path: Path of the JSON file handed to ``load_squad_json``.
        max_samples: Upper bound on the number of loaded samples used
            (``None`` keeps all of them).
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Adam learning rate.
        save_path: Destination of the checkpoint written by ``torch.save``.

    Raises:
        ValueError: If no sample survives preprocessing, so there is
            nothing to train on.
    """
    print("Loading data...")
    raw = load_squad_json(data_path)[:max_samples]

    print("Building vocab...")
    all_tokens = []
    for sample in raw:
        item = process_sample(sample)
        if item:
            all_tokens.extend(item["tokens"])

    vocab = build_vocab(all_tokens)

    print("Preparing dataset...")
    dataset = QADataset(raw, vocab)
    if len(dataset) == 0:
        # Fail loudly instead of hitting an obscure sampler error or saving
        # an untrained model.
        raise ValueError(
            f"No usable training samples were produced from {data_path!r}"
        )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print("Initializing model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QAModel(len(vocab))
    # Move parameters to the target device *before* building the optimizer,
    # as recommended by the torch.optim documentation.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    print("Training...\n")

    for epoch in range(epochs):
        total_loss = 0

        for x, start, end in loader:
            x = x.to(device)
            start = start.to(device)
            end = end.to(device)

            # NOTE(review): assumes the model returns per-position logits for
            # the start and end index -- confirm against QAModel.forward.
            pred_start, pred_end = model(x)

            loss = loss_fn(pred_start, start) + loss_fn(pred_end, end)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Reported value is the sum of batch losses over the epoch.
        print(f"Epoch {epoch+1} Loss: {total_loss:.2f}")

    # The vocab is stored with the weights so inference can encode inputs
    # the same way.
    torch.save({
        "model_state": model.state_dict(),
        "vocab": vocab
    }, save_path)

    print("\n✅ Model trained and saved!")
|
|
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()
|
|