NER_17_labels / models /bilstm_crf.py
dungquang's picture
Add: models
350e014 verified
import torch
import torch.nn as nn
from torchcrf import CRF
class BiLSTM_CRF(nn.Module):
    """Sequence tagger: embedding -> bidirectional LSTM -> linear layer -> CRF.

    Both the LSTM and the CRF are built with ``batch_first=True``, so inputs
    are laid out as (batch, sequence, ...).
    """

    def __init__(self, vocab_size, output_size, embedding_size, hidden_size, pad_idx, dropout=0.5):
        """Create the layers of the tagger.

        Args:
            vocab_size: Number of rows in the embedding table.
            output_size: Number of tags scored by the linear layer and the CRF.
            embedding_size: Width of each token embedding (LSTM input size).
            hidden_size: Hidden size of the LSTM, per direction.
            pad_idx: Index passed to the embedding layer as ``padding_idx``.
            dropout: Probability used by the shared dropout layer.
        """
        super().__init__()
        # Attribute names double as state_dict keys, so saved checkpoints
        # depend on them.
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_idx)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True, bidirectional=True)
        # Forward and backward LSTM states are concatenated, hence 2x width.
        self.fc = nn.Linear(2 * hidden_size, output_size)
        self.crf = CRF(output_size, batch_first=True)

    def forward(self, x, tags=None, mask=None):
        """Score a batch and either compute the loss or decode tag paths.

        Args:
            x: Token indices fed to the embedding layer.
            tags: Gold tags. When given, the loss is returned; when ``None``,
                the CRF decodes instead.
            mask: Optional mask forwarded unchanged to the CRF.

        Returns:
            With ``tags``: the negated value of the CRF called with
            ``reduction='mean'``. Without ``tags``: whatever
            ``self.crf.decode`` returns (the decoded best paths).
        """
        embedded = self.dropout(self.embedding(x))
        lstm_out, _ = self.lstm(embedded)
        # The same dropout layer is applied before and after the LSTM.
        scores = self.fc(self.dropout(lstm_out))

        if tags is None:
            # Validation mode: decode the best path
            return self.crf.decode(scores, mask=mask)

        # Training mode
        log_likelihood = self.crf(scores, tags, mask=mask, reduction='mean')
        return -log_likelihood
def load_model(vocab_size, output_size, embedding_size, hidden_size, dropout=0.5, pad_idx=0,
               *, checkpoint_path='models/BiLSTM_CRF.pth'):
    """Build a ``BiLSTM_CRF`` and restore its weights from a saved state dict.

    The hyper-parameters must match those the checkpoint was saved with,
    otherwise ``load_state_dict`` will reject the weights.

    Args:
        vocab_size: Number of rows in the embedding table.
        output_size: Number of tags.
        embedding_size: Width of each token embedding.
        hidden_size: Hidden size of the LSTM, per direction.
        dropout: Dropout probability passed to the model.
        pad_idx: Padding index passed to the model.
        checkpoint_path: Location of the saved state dict. Defaults to the
            path that was previously hard-coded, so existing callers are
            unaffected.

    Returns:
        The model with the checkpoint weights loaded onto the CPU. This
        function does not call ``eval()``; call ``model.eval()`` before
        inference so that dropout is disabled.
    """
    model = BiLSTM_CRF(vocab_size, output_size, embedding_size, hidden_size, pad_idx, dropout)
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    # map_location forces all tensors onto the CPU, so a checkpoint saved on a
    # GPU still loads on a machine without one.
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model