"""Train a QA span model on a SQuAD-format JSON file and save a checkpoint."""

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from data_loader.load_squad_json import load_squad_json
from utils.squad_preprocess import process_sample
from utils.vocab import build_vocab, encode
from models.qa_model import QAModel


class QADataset(Dataset):
    """Fixed-length encoded samples paired with their ``start``/``end`` labels.

    Each raw sample is passed through ``process_sample``; samples for which
    it returns a falsy value are skipped. The remaining token sequences are
    encoded with ``vocab``, then truncated or right-padded with ``0`` to
    exactly ``max_len`` ids.

    Args:
        samples: Iterable of raw samples accepted by ``process_sample``.
        vocab: Vocabulary object accepted by ``encode``.
        max_len: Length every encoded sequence is padded/truncated to.
            Samples whose ``start`` or ``end`` is ``>= max_len`` are dropped.
    """

    def __init__(self, samples, vocab, max_len=300):
        self.data = []
        for sample in samples:
            item = process_sample(sample)
            if not item:
                continue

            # Check the labels first so rejected samples are never encoded.
            # NOTE(review): start/end are presumably token indices of the
            # answer span; a label at or beyond max_len would point into the
            # part of the sequence that truncation removes.
            start, end = item["start"], item["end"]
            if start >= max_len or end >= max_len:
                continue

            # Slice into a fresh list so the object returned by ``encode``
            # is never mutated, then right-pad up to max_len.
            # NOTE(review): assumes id 0 is the padding id in ``vocab`` -
            # confirm against utils.vocab.
            encoded = list(encode(item["tokens"], vocab)[:max_len])
            encoded += [0] * (max_len - len(encoded))

            self.data.append((encoded, start, end))

    def __len__(self):
        """Return the number of samples that survived filtering."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, start, end)`` as ``torch.long`` tensors."""
        encoded, start, end = self.data[idx]
        # nn.CrossEntropyLoss requires class-index targets of dtype long,
        # so pin the dtype instead of relying on inference.
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(start, dtype=torch.long),
            torch.tensor(end, dtype=torch.long),
        )


def train(
    *,
    data_path="data/train-v2.0.json",
    max_samples=30000,
    batch_size=32,
    epochs=5,
    lr=0.001,
    save_path="qa_model.pth",
):
    """Train ``QAModel`` and save its weights together with the vocabulary.

    Calling ``train()`` with no arguments reproduces the previous
    hard-coded configuration.

    Args:
        data_path: Path passed to ``load_squad_json``.
        max_samples: Only the first ``max_samples`` loaded samples are used;
            ``None`` uses all of them.
        batch_size: Mini-batch size for the ``DataLoader``.
        epochs: Number of passes over the dataset.
        lr: Learning rate for the Adam optimizer.
        save_path: Destination of the checkpoint, a dict with the keys
            ``"model_state"`` and ``"vocab"``.

    Raises:
        ValueError: If no sample survives preprocessing and filtering.
    """
    print("Loading data...")
    raw = load_squad_json(data_path)[:max_samples]

    print("Building vocab...")
    all_tokens = []
    for sample in raw:
        item = process_sample(sample)
        if item:
            all_tokens.extend(item["tokens"])
    vocab = build_vocab(all_tokens)

    print("Preparing dataset...")
    dataset = QADataset(raw, vocab)
    if len(dataset) == 0:
        # Fail early: otherwise the loop below would run zero steps (or the
        # sampler would raise, depending on the torch version) and an
        # untrained model would be written to disk.
        raise ValueError(
            f"No usable training samples were produced from {data_path!r}"
        )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print("Initializing model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QAModel(len(vocab))
    # Move the model to its device before constructing the optimizer.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    print("Training...\n")
    for epoch in range(epochs):
        total_loss = 0
        for x, start, end in loader:
            x = x.to(device)
            start = start.to(device)
            end = end.to(device)

            pred_start, pred_end = model(x)
            # The two heads are trained jointly: sum of both losses.
            loss = loss_fn(pred_start, start) + loss_fn(pred_end, end)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Reported value is the sum of per-batch losses, not the mean.
        print(f"Epoch {epoch+1} Loss: {total_loss:.2f}")

    torch.save(
        {
            "model_state": model.state_dict(),
            "vocab": vocab,
        },
        save_path,
    )
    print("\n✅ Model trained and saved!")


if __name__ == "__main__":
    train()