Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchcrf import CRF | |
class BiLSTM_CRF(nn.Module):
    """Bidirectional LSTM encoder topped with a CRF layer for sequence labeling.

    In training mode (``tags`` given) the forward pass returns the mean
    negative log-likelihood of the gold tag sequences; in inference mode
    it returns the Viterbi-decoded best tag paths.
    """

    def __init__(self, vocab_size, output_size, embedding_size, hidden_size, pad_idx, dropout=0.5):
        super().__init__()
        # padding_idx keeps the pad row of the embedding table at zero
        # and excludes it from gradient updates.
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            input_size=embedding_size,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=True,
        )
        # The bidirectional LSTM concatenates both directions, doubling
        # the feature dimension fed to the emission projection.
        self.fc = nn.Linear(hidden_size * 2, output_size)
        self.crf = CRF(output_size, batch_first=True)

    def forward(self, x, tags=None, mask=None):
        """Compute CRF loss (training) or decoded tag paths (inference).

        Args:
            x: LongTensor of token ids, shape (batch, seq_len).
            tags: optional LongTensor of gold tags, shape (batch, seq_len).
            mask: optional ByteTensor/BoolTensor marking real (non-pad) tokens.

        Returns:
            A scalar loss tensor when ``tags`` is given, otherwise a list of
            decoded tag-id sequences (one list per batch element).
        """
        embedded = self.dropout(self.embedding(x))
        encoded, _ = self.lstm(embedded)
        emissions = self.fc(self.dropout(encoded))
        if tags is None:
            # Inference: Viterbi-decode the most likely tag sequence.
            return self.crf.decode(emissions, mask=mask)
        # Training: CRF returns the log-likelihood, so negate for a loss.
        return -self.crf(emissions, tags, mask=mask, reduction='mean')
def load_model(vocab_size, output_size, embedding_size, hidden_size, dropout=0.5,
               pad_idx=0, model_path='models/BiLSTM_CRF.pth'):
    """Construct a BiLSTM_CRF and load trained weights from a checkpoint.

    Args:
        vocab_size: size of the input vocabulary (embedding rows).
        output_size: number of output tags.
        embedding_size: dimension of the token embeddings.
        hidden_size: hidden size of each LSTM direction.
        dropout: dropout probability applied around the LSTM.
        pad_idx: vocabulary index of the padding token.
        model_path: checkpoint file to load; defaults to the original
            hard-coded location for backward compatibility.

    Returns:
        A BiLSTM_CRF instance with the checkpoint's parameters loaded.
    """
    model = BiLSTM_CRF(vocab_size, output_size, embedding_size, hidden_size, pad_idx, dropout)
    # map_location forces the checkpoint onto CPU so loading works on
    # hosts without a GPU.
    # NOTE(review): torch.load unpickles arbitrary data — only load
    # checkpoints from trusted sources (consider weights_only=True on
    # torch >= 1.13). Also, callers doing inference should call
    # model.eval() to disable dropout; not done here to preserve the
    # original behavior for callers that fine-tune.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model