Albin Thörn Cleland
Clean initial commit with LFS
19b8775
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence
from stanza.models.common.char_model import CharacterLanguageModelWordAdapter
from stanza.models.common.foundation_cache import load_charlm
class Tokenizer(nn.Module):
def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, foundation_cache=None):
super().__init__()
self.unsaved_modules = []
self.args = args
feat_dim = args['feat_dim']
self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)
self.input_dim = emb_dim + feat_dim
charmodel = None
if args is not None and args.get('charlm_forward_file', None):
charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)
charmodels = nn.ModuleList([charmodel_forward])
charmodel = CharacterLanguageModelWordAdapter(charmodels)
self.input_dim += charmodel.hidden_dim()
self.add_unsaved_module("charmodel", charmodel)
self.rnn = nn.LSTM(self.input_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)
if self.args['conv_res'] is not None:
self.conv_res = nn.ModuleList()
self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')]
for si, size in enumerate(self.conv_sizes):
l = nn.Conv1d(self.input_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))
self.conv_res.append(l)
if self.args.get('hier_conv_res', False):
self.conv_res2 = nn.Conv1d(hidden_dim * 2 * len(self.conv_sizes), hidden_dim * 2, 1)
self.tok_clf = nn.Linear(hidden_dim * 2, 1)
self.sent_clf = nn.Linear(hidden_dim * 2, 1)
if self.args['use_mwt']:
self.mwt_clf = nn.Linear(hidden_dim * 2, 1)
if args['hierarchical']:
in_dim = hidden_dim * 2
self.rnn2 = nn.LSTM(in_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True)
self.tok_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
self.sent_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
if self.args['use_mwt']:
self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
self.dropout = nn.Dropout(dropout)
self.dropout_feat = nn.Dropout(feat_dropout)
self.toknoise = nn.Dropout(self.args['tok_noise'])
def add_unsaved_module(self, name, module):
self.unsaved_modules += [name]
setattr(self, name, module)
def forward(self, x, feats, lengths, raw=None):
emb = self.embeddings(x)
if self.charmodel is not None and raw is not None:
char_emb = self.charmodel(raw, wrap=False)
emb = torch.cat([emb, char_emb], axis=2)
emb = self.dropout(emb)
feats = self.dropout_feat(feats)
emb = torch.cat([emb, feats], 2)
emb = pack_padded_sequence(emb, lengths, batch_first=True)
inp, _ = self.rnn(emb)
inp, _ = pad_packed_sequence(inp, batch_first=True)
if self.args['conv_res'] is not None:
conv_input = emb.transpose(1, 2).contiguous()
if not self.args.get('hier_conv_res', False):
for l in self.conv_res:
inp = inp + l(conv_input).transpose(1, 2).contiguous()
else:
hid = []
for l in self.conv_res:
hid += [l(conv_input)]
hid = torch.cat(hid, 1)
hid = F.relu(hid)
hid = self.dropout(hid)
inp = inp + self.conv_res2(hid).transpose(1, 2).contiguous()
inp = self.dropout(inp)
tok0 = self.tok_clf(inp)
sent0 = self.sent_clf(inp)
if self.args['use_mwt']:
mwt0 = self.mwt_clf(inp)
if self.args['hierarchical']:
inp2 = inp
if self.args['hier_invtemp'] > 0:
inp2 = inp2 * (1 - self.toknoise(torch.sigmoid(-tok0 * self.args['hier_invtemp'])))
inp2 = pack_padded_sequence(inp2, lengths, batch_first=True)
inp2, _ = self.rnn2(inp2)
inp2, _ = pad_packed_sequence(inp2, batch_first=True)
inp2 = self.dropout(inp2)
tok0 = tok0 + self.tok_clf2(inp2)
sent0 = sent0 + self.sent_clf2(inp2)
if self.args['use_mwt']:
mwt0 = mwt0 + self.mwt_clf2(inp2)
nontok = F.logsigmoid(-tok0)
tok = F.logsigmoid(tok0)
nonsent = F.logsigmoid(-sent0)
sent = F.logsigmoid(sent0)
if self.args['use_mwt']:
nonmwt = F.logsigmoid(-mwt0)
mwt = F.logsigmoid(mwt0)
pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)
else:
pred = torch.cat([nontok, tok+nonsent, tok+sent], 2)
return pred