# SQuAD / train.py
# Committed by tnp554 in 09daf0b: "feat: deploy SQuAD backend with all AI models"
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data_loader.load_squad_json import load_squad_json
from utils.squad_preprocess import process_sample
from utils.vocab import build_vocab, encode
from models.qa_model import QAModel
class QADataset(Dataset):
    """Map-style dataset of fixed-length token-id sequences with answer spans.

    Each raw sample is run through ``process_sample``. A sample is dropped when
    ``process_sample`` returns a falsy value, or when its answer span is not a
    valid ``start <= end`` pair of positions inside the first ``max_len``
    tokens (a span that would be lost to truncation cannot be learned).

    Args:
        samples: Iterable of raw samples (e.g. the list produced by
            ``load_squad_json``).
        vocab: Vocabulary handed to ``encode`` to map tokens to ids.
        max_len: Length every encoded sequence is truncated / right-padded to.
        pad_idx: Id used for right-padding. Defaults to 0, the value this class
            always used; presumably the padding id reserved by ``build_vocab``
            -- TODO confirm against utils.vocab.
    """

    def __init__(self, samples, vocab, max_len=300, pad_idx=0):
        self.data = []
        for sample in samples:
            item = process_sample(sample)
            if not item:
                continue

            start = item["start"]
            end = item["end"]
            # Keep only spans that are real token positions inside the
            # truncated window. Negative indices are not valid
            # CrossEntropyLoss class targets (only ignore_index is tolerated),
            # and end < start is not a valid answer span. Checking before
            # encoding also skips work for samples that would be discarded.
            if not (0 <= start <= end < max_len):
                continue

            # Build a fresh list instead of extending encode()'s return value
            # in place; this also works if encode() returns a tuple.
            encoded = list(encode(item["tokens"], vocab))[:max_len]
            encoded.extend([pad_idx] * (max_len - len(encoded)))
            self.data.append((encoded, start, end))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, start, end)`` as int64 tensors."""
        x, s, e = self.data[idx]
        return (
            torch.tensor(x, dtype=torch.long),
            torch.tensor(s, dtype=torch.long),
            torch.tensor(e, dtype=torch.long),
        )
def _collect_vocab_tokens(raw):
    """Return the concatenated tokens of every sample ``process_sample`` accepts."""
    tokens = []
    for sample in raw:
        item = process_sample(sample)
        if item:
            tokens.extend(item["tokens"])
    return tokens


def train(
    data_path="data/train-v2.0.json",
    max_samples=30000,
    epochs=5,
    batch_size=32,
    lr=0.001,
    save_path="qa_model.pth",
):
    """Train a ``QAModel`` on SQuAD data and save its weights with the vocab.

    Loads samples from ``data_path`` (keeping the first ``max_samples``; pass
    ``None`` to keep all), builds a vocabulary from their tokens, and trains
    with Adam. The loss is the sum of the cross-entropy over start positions
    and over end positions. The result is saved with ``torch.save`` as a dict
    with keys ``"model_state"`` and ``"vocab"``.

    All arguments default to the values this script previously hard-coded, so
    ``train()`` behaves as before.

    Raises:
        ValueError: If no sample survives preprocessing, so there is nothing
            to train on.
    """
    print("Loading data...")
    raw = load_squad_json(data_path)[:max_samples]

    print("Building vocab...")
    vocab = build_vocab(_collect_vocab_tokens(raw))

    print("Preparing dataset...")
    dataset = QADataset(raw, vocab)
    if len(dataset) == 0:
        # Without this check, DataLoader(shuffle=True) fails inside
        # RandomSampler with a much less helpful message.
        raise ValueError(f"No usable training samples found in {data_path!r}")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print("Initializing model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model before building the optimizer, as the torch.optim docs
    # recommend, so the optimizer holds the device-resident parameters.
    model = QAModel(len(vocab)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    print("Training...\n")
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x, start, end in loader:
            x = x.to(device)
            start = start.to(device)
            end = end.to(device)

            pred_start, pred_end = model(x)
            loss = loss_fn(pred_start, start) + loss_fn(pred_end, end)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        # Reported value is the sum of per-batch losses for the epoch.
        print(f"Epoch {epoch+1} Loss: {total_loss:.2f}")

    torch.save({
        "model_state": model.state_dict(),
        "vocab": vocab
    }, save_path)
    print("\n✅ Model trained and saved!")
if __name__ == "__main__":
    # Script entry point: run training with its default settings only when
    # this file is executed directly, not when it is imported.
    train()