import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

from stanza.models.common.char_model import CharacterLanguageModelWordAdapter
from stanza.models.common.foundation_cache import load_charlm
class Tokenizer(nn.Module):
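    """
    Character-level tokenizer model.

    A bidirectional LSTM over character embeddings (optionally concatenated with
    character language model representations and per-character features) feeds
    per-position classifiers for token ends, sentence ends, and, when use_mwt is
    set, multi-word token ends.  Optional 1D convolutional residual connections
    and an optional second "hierarchical" LSTM refine the scores.  The forward
    pass returns log-probabilities over the joint label space.
    """
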
    def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, foundation_cache=None):
        super().__init__()

        self.unsaved_modules = []

        self.args = args
        feat_dim = args['feat_dim']

        self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)

        self.input_dim = emb_dim + feat_dim
        charmodel = None
        if args is not None and args.get('charlm_forward_file', None):
            charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)
            charmodels = nn.ModuleList([charmodel_forward])
            charmodel = CharacterLanguageModelWordAdapter(charmodels)
            self.input_dim += charmodel.hidden_dim()
        # register the charlm even when it is None so that forward() can always check self.charmodel
        self.add_unsaved_module("charmodel", charmodel)
        self.rnn = nn.LSTM(self.input_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)

        if self.args['conv_res'] is not None:
            self.conv_res = nn.ModuleList()
            self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')]

            for si, size in enumerate(self.conv_sizes):
                l = nn.Conv1d(self.input_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))
                self.conv_res.append(l)

            if self.args.get('hier_conv_res', False):
                self.conv_res2 = nn.Conv1d(hidden_dim * 2 * len(self.conv_sizes), hidden_dim * 2, 1)

        self.tok_clf = nn.Linear(hidden_dim * 2, 1)
        self.sent_clf = nn.Linear(hidden_dim * 2, 1)
        if self.args['use_mwt']:
            self.mwt_clf = nn.Linear(hidden_dim * 2, 1)

        if args['hierarchical']:
            in_dim = hidden_dim * 2

            self.rnn2 = nn.LSTM(in_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)

            self.tok_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            self.sent_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
            if self.args['use_mwt']:
                self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.dropout_feat = nn.Dropout(feat_dropout)

        self.toknoise = nn.Dropout(self.args['tok_noise'])
    def add_unsaved_module(self, name, module):
        # record modules (such as the charlm) that are not meant to be serialized with this model
        self.unsaved_modules += [name]
        setattr(self, name, module)
    def forward(self, x, feats, lengths, raw=None):
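        """
        Score every character position of a padded batch.

        x:       character ids, (batch, seq_len), with 0 as the padding index
        feats:   per-character features, (batch, seq_len, feat_dim)
        lengths: sequence lengths used to pack the batch for the LSTM
        raw:     the raw characters of the batch, only used when a character
                 language model is loaded

        Returns log-probabilities of shape (batch, padded length, 5) when
        use_mwt is set, covering [not a token end, token end, token + sentence
        end, token end + MWT, token + sentence end + MWT]; otherwise
        (batch, padded length, 3) without the MWT combinations.
        """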
        emb = self.embeddings(x)

        if self.charmodel is not None and raw is not None:
            char_emb = self.charmodel(raw, wrap=False)
            emb = torch.cat([emb, char_emb], axis=2)

        emb = self.dropout(emb)
        feats = self.dropout_feat(feats)

        emb = torch.cat([emb, feats], 2)

        emb = pack_padded_sequence(emb, lengths, batch_first=True)
        inp, _ = self.rnn(emb)
        inp, _ = pad_packed_sequence(inp, batch_first=True)
        if self.args['conv_res'] is not None:
            # the convolutions need the padded embeddings, not the PackedSequence
            conv_input, _ = pad_packed_sequence(emb, batch_first=True)
            conv_input = conv_input.transpose(1, 2).contiguous()
            if not self.args.get('hier_conv_res', False):
                for l in self.conv_res:
                    inp = inp + l(conv_input).transpose(1, 2).contiguous()
            else:
                hid = []
                for l in self.conv_res:
                    hid += [l(conv_input)]
                hid = torch.cat(hid, 1)
                hid = F.relu(hid)
                hid = self.dropout(hid)

                inp = inp + self.conv_res2(hid).transpose(1, 2).contiguous()

        inp = self.dropout(inp)
        tok0 = self.tok_clf(inp)
        sent0 = self.sent_clf(inp)
        if self.args['use_mwt']:
            mwt0 = self.mwt_clf(inp)

        if self.args['hierarchical']:
            inp2 = inp
            if self.args['hier_invtemp'] > 0:
                inp2 = inp2 * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp'])))

            inp2 = pack_padded_sequence(inp2, lengths, batch_first=True)
            inp2, _ = self.rnn2(inp2)
            inp2, _ = pad_packed_sequence(inp2, batch_first=True)
            inp2 = self.dropout(inp2)

            tok0 = tok0 + self.tok_clf2(inp2)
            sent0 = sent0 + self.sent_clf2(inp2)
            if self.args['use_mwt']:
                mwt0 = mwt0 + self.mwt_clf2(inp2)

        nontok = F.logsigmoid(-tok0)
        tok = F.logsigmoid(tok0)
        nonsent = F.logsigmoid(-sent0)
        sent = F.logsigmoid(sent0)
        if self.args['use_mwt']:
            nonmwt = F.logsigmoid(-mwt0)
            mwt = F.logsigmoid(mwt0)
            pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)
        else:
            pred = torch.cat([nontok, tok+nonsent, tok+sent], 2)

        return pred
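
# A minimal smoke test showing how the model is wired together.  The argument
# values below are illustrative placeholders, not the defaults used by stanza's
# tokenizer training scripts, which build the args dict from their own parser.
if __name__ == '__main__':
    args = {
        'feat_dim': 4,
        'rnn_layers': 1,
        'conv_res': None,
        'hier_conv_res': False,
        'use_mwt': True,
        'hierarchical': False,
        'hier_invtemp': 0.5,
        'tok_noise': 0.02,
        'charlm_forward_file': None,
    }
    model = Tokenizer(args, nchars=100, emb_dim=32, hidden_dim=64,
                      dropout=0.3, feat_dropout=0.3)
    batch_size, seq_len = 2, 20
    x = torch.randint(1, 100, (batch_size, seq_len))            # character ids, 0 reserved for padding
    feats = torch.randn(batch_size, seq_len, args['feat_dim'])  # per-character features
    lengths = [seq_len, seq_len - 5]                            # decreasing order, as required for packing
    pred = model(x, feats, lengths)
    print(pred.shape)                                           # (batch_size, max(lengths), 5) log-probabilities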