| import torch | |
| import torch.nn as nn | |
class QAModel(nn.Module):
    """Score every input position as a candidate span start and span end.

    Token ids are embedded, encoded with a bidirectional LSTM, and two
    independent linear heads map each position's ``2 * hidden_dim`` feature
    vector to one start score and one end score.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width, used as the LSTM input size.
        hidden_dim: LSTM hidden size per direction; encoder outputs are
            ``2 * hidden_dim`` wide because both directions are concatenated.
        padding_idx: Optional token id forwarded to ``nn.Embedding`` (its
            embedding row is zero and receives no gradient). ``None`` keeps
            the original behaviour.
        num_layers: Number of stacked LSTM layers. The default of 1 matches
            the original single-layer encoder.
    """

    def __init__(self, vocab_size, embed_dim=200, hidden_dim=256, *,
                 padding_idx=None, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Bidirectional LSTM concatenates forward and backward states -> 2 * hidden_dim.
        self.fc_start = nn.Linear(hidden_dim * 2, 1)
        self.fc_end = nn.Linear(hidden_dim * 2, 1)

    def forward(self, x, mask=None):
        """Return per-position start and end scores.

        Args:
            x: Integer token ids, ``(batch, seq_len)`` given ``batch_first=True``.
            mask: Optional tensor broadcastable to the score shape; truthy
                (nonzero / ``True``) marks real tokens. Falsy positions have
                their scores replaced by the most negative finite value of
                the score dtype, so argmax/softmax never select them.

        Returns:
            Tuple ``(start, end)``, each of shape ``(batch, seq_len)``.
        """
        embedded = self.embedding(x)
        out, _ = self.lstm(embedded)
        # Each head emits one score per position; drop the trailing size-1 dim.
        start = self.fc_start(out).squeeze(-1)
        end = self.fc_end(out).squeeze(-1)
        if mask is not None:
            pad = ~mask.bool()
            # Finite minimum rather than -inf: a fully-masked row then yields a
            # uniform softmax instead of NaN.
            fill = torch.finfo(start.dtype).min
            start = start.masked_fill(pad, fill)
            end = end.masked_fill(pad, fill)
        return start, end