import torch
import torch.nn as nn


class QAModel(nn.Module):
    """Bidirectional-LSTM encoder with per-token start/end scoring heads.

    Token ids are embedded, encoded by a single-layer bidirectional LSTM
    (``batch_first=True``), and every position is projected to one start
    score and one end score -- presumably logits for extractive answer-span
    prediction (inferred from the class/head names; confirm with callers).

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding dimension, used as the LSTM input size.
        hidden_dim: LSTM hidden size per direction. The heads receive
            ``2 * hidden_dim`` features because a bidirectional LSTM
            concatenates the forward and backward outputs.
        padding_idx: Optional token id forwarded to ``nn.Embedding``; that
            row receives no gradient. ``None`` (default) matches the
            original behaviour.
    """

    def __init__(self, vocab_size, embed_dim=200, hidden_dim=256, padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim, batch_first=True, bidirectional=True
        )
        # Both directions are concatenated, hence the doubled input width.
        self.fc_start = nn.Linear(hidden_dim * 2, 1)
        self.fc_end = nn.Linear(hidden_dim * 2, 1)

    def forward(self, x, mask=None):
        """Score every position as a span start and a span end.

        Args:
            x: Integer token-id tensor, expected shape ``(batch, seq_len)``
                (``batch_first=True`` on the LSTM).
            mask: Optional tensor broadcastable to ``(batch, seq_len)``;
                truthy entries mark real tokens. Positions where it is
                falsy get the most negative finite value of the logits'
                dtype, so softmax/argmax over the sequence ignores them.
                ``None`` (default) leaves all scores untouched.

        Returns:
            Tuple ``(start, end)`` of logits, each of shape
            ``(batch, seq_len)``.

        Note:
            The mask is only applied to the output scores. The LSTM still
            reads padded positions (no ``pack_padded_sequence`` is used),
            so with right padding the backward direction can carry padding
            information into real positions.
        """
        out, _ = self.lstm(self.embedding(x))
        start = self.fc_start(out).squeeze(-1)
        end = self.fc_end(out).squeeze(-1)
        if mask is not None:
            # Convert to bool first: `~` on an integer 0/1 mask would be a
            # bitwise NOT, not a logical one.
            keep = mask.to(dtype=torch.bool, device=start.device)
            # finfo.min instead of -inf so a fully masked row cannot turn
            # into NaN under softmax.
            fill = torch.finfo(start.dtype).min
            start = start.masked_fill(~keep, fill)
            end = end.masked_fill(~keep, fill)
        return start, end