|
|
'Model training for NLP' |
|
|
from ..torch_core import * |
|
|
from ..basic_train import * |
|
|
from ..callbacks import * |
|
|
from ..data_block import CategoryList |
|
|
from ..basic_data import * |
|
|
from ..datasets import * |
|
|
from ..metrics import accuracy |
|
|
from ..train import GradientClipping |
|
|
from ..layers import * |
|
|
from .models import * |
|
|
from .transform import * |
|
|
from .data import * |
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = [
    'RNNLearner', 'LanguageLearner', 'convert_weights', 'decode_spec_tokens',
    'get_language_model', 'language_model_learner',
    'MultiBatchEncoder', 'get_text_classifier', 'text_classifier_learner',
    'PoolingLinearClassifier',
]

# Per-architecture settings read by the model and learner factories in this module:
#   hid_name      -- key of the config entry used to size the decoder / classifier input
#   url, url_bwd  -- locations of pretrained weights; a missing key means "none available"
#   config_lm     -- default config for the language model
#   split_lm      -- split function handed to the language-model learner
#   config_clas   -- default config for the classifier
#   split_clas    -- split function handed to the classifier learner
_model_meta = {
    AWD_LSTM: {
        'hid_name': 'emb_sz',
        'url': URLs.WT103_FWD,
        'url_bwd': URLs.WT103_BWD,
        'config_lm': awd_lstm_lm_config,
        'split_lm': awd_lstm_lm_split,
        'config_clas': awd_lstm_clas_config,
        'split_clas': awd_lstm_clas_split,
    },
    Transformer: {
        'hid_name': 'd_model',
        'url': URLs.OPENAI_TRANSFORMER,
        'config_lm': tfmer_lm_config,
        'split_lm': tfmer_lm_split,
        'config_clas': tfmer_clas_config,
        'split_clas': tfmer_clas_split,
    },
    TransformerXL: {
        'hid_name': 'd_model',
        'config_lm': tfmerXL_lm_config,
        'split_lm': tfmerXL_lm_split,
        'config_clas': tfmerXL_clas_config,
        'split_clas': tfmerXL_clas_split,
    },
}
|
|
|
|
|
def convert_weights(wgts:Weights, stoi_wgts:Dict[str,int], itos_new:Collection[str]) -> Weights:
    """Convert the model `wgts` to go with a new vocabulary.

    `stoi_wgts` maps each token of the old vocabulary to its row in the old
    weights and `itos_new` lists the tokens of the new vocabulary in order.
    Tokens found in `stoi_wgts` keep their old embedding row (and decoder bias
    entry, when `wgts` has a `'1.decoder.bias'`); all other tokens receive the
    mean over the old rows. `wgts` is modified in place and returned.
    """
    old_emb = wgts['0.encoder.weight']
    old_bias = wgts.get('1.decoder.bias')
    has_bias = old_bias is not None

    # Fallback values for tokens the old vocabulary did not contain.
    emb_mean = old_emb.mean(0)
    bias_mean = old_bias.mean(0) if has_bias else None

    vocab_sz = len(itos_new)
    emb = old_emb.new_zeros((vocab_sz, old_emb.size(1)))
    bias = old_bias.new_zeros((vocab_sz,)) if has_bias else None

    for row, tok in enumerate(itos_new):
        src = stoi_wgts.get(tok, -1)
        known = src >= 0
        emb[row] = old_emb[src] if known else emb_mean
        if has_bias:
            bias[row] = old_bias[src] if known else bias_mean

    wgts['0.encoder.weight'] = emb
    # The dropout-embedding copy is only refreshed when the checkpoint already has one.
    if '0.encoder_dp.emb.weight' in wgts:
        wgts['0.encoder_dp.emb.weight'] = emb.clone()
    wgts['1.decoder.weight'] = emb.clone()
    if has_bias:
        wgts['1.decoder.bias'] = bias
    return wgts
|
|
|
|
|
class RNNLearner(Learner):
    "Basic class for a `Learner` in NLP."
    def __init__(self, data:DataBunch, model:nn.Module, split_func:OptSplitFunc=None, clip:float=None,
                 alpha:float=2., beta:float=1., metrics=None, **learn_kwargs):
        """Build the learner around `data` and `model`.

        When `metrics` is None it defaults to `[accuracy]` if the training labels
        are a `CategoryList` or `LMLabelList`, and to `[]` otherwise. An
        `RNNTrainer` callback configured with `alpha` and `beta` is always added;
        `GradientClipping` is added when `clip` is truthy and the model is split
        with `split_func` when one is given. Remaining keyword arguments are
        forwarded to `Learner`.
        """
        labels = getattr(data.train_ds, 'y', None)
        is_class = isinstance(labels, (CategoryList, LMLabelList))
        metrics = ifnone(metrics, ([accuracy] if is_class else []))
        super().__init__(data, model, metrics=metrics, **learn_kwargs)
        self.callbacks.append(RNNTrainer(self, alpha=alpha, beta=beta))
        if clip: self.callback_fns.append(partial(GradientClipping, clip=clip))
        if split_func: self.split(split_func)

    def save_encoder(self, name:str):
        "Save the encoder to `name` inside the model directory."
        if is_pathlike(name): self._test_writeable_path()
        encoder = get_model(self.model)[0]
        # Unwrap a wrapper that exposes the real encoder as `.module`.
        if hasattr(encoder, 'module'): encoder = encoder.module
        torch.save(encoder.state_dict(), self.path/self.model_dir/f'{name}.pth')

    def load_encoder(self, name:str, device:torch.device=None):
        "Load the encoder `name` from the model directory (onto `device`, default `self.data.device`), then freeze."
        encoder = get_model(self.model)[0]
        if device is None: device = self.data.device
        if hasattr(encoder, 'module'): encoder = encoder.module
        encoder.load_state_dict(torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device))
        self.freeze()

    def load_pretrained(self, wgts_fname:str, itos_fname:str, strict:bool=True):
        """Load a pretrained model and adapt it to the data vocabulary.

        `itos_fname` is a pickled list of the tokens the weights in `wgts_fname`
        were trained with; the weights are remapped to the training set's
        vocabulary with `convert_weights` before being loaded. `strict` is
        passed on to `load_state_dict`.
        """
        # NOTE(review): both files are unpickled, which can execute arbitrary code;
        # only point this at files from a trusted source.
        # The context manager closes the vocabulary file once it has been read.
        with open(itos_fname, 'rb') as itos_file:
            old_itos = pickle.load(itos_file)
        old_stoi = {v:k for k,v in enumerate(old_itos)}
        # Keep every tensor on the storage it was deserialised to (no device remapping).
        wgts = torch.load(wgts_fname, map_location=lambda storage, loc: storage)
        # Some checkpoints nest the state dict under a 'model' key.
        if 'model' in wgts: wgts = wgts['model']
        wgts = convert_weights(wgts, old_stoi, self.data.train_ds.vocab.itos)
        self.model.load_state_dict(wgts, strict=strict)

    def get_preds(self, ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None, with_loss:bool=False, n_batch:Optional[int]=None,
                  pbar:Optional[PBar]=None, ordered:bool=False) -> List[Tensor]:
        """Return predictions and targets on the valid, train, or test set, depending on `ds_type`.

        With `ordered=True` the results are re-indexed by the inverse of the
        dataloader's sampler order (when the dataloader has a `sampler`).
        """
        self.model.reset()
        # The numpy RNG is seeded identically before predicting and before
        # re-reading the sampler -- presumably so both passes yield the same order.
        if ordered: np.random.seed(42)
        preds = super().get_preds(ds_type=ds_type, activ=activ, with_loss=with_loss, n_batch=n_batch, pbar=pbar)
        if ordered and hasattr(self.dl(ds_type), 'sampler'):
            np.random.seed(42)
            sampler = [i for i in self.dl(ds_type).sampler]
            reverse_sampler = np.argsort(sampler)
            preds = [p[reverse_sampler] for p in preds]
        return preds
|
|
|
|
|
def decode_spec_tokens(tokens):
    """Apply the special marker tokens in `tokens` and return the resulting token list.

    - `TK_MAJ` capitalises the next token, `TK_UP` upper-cases it.
    - `TK_REP` / `TK_WREP` read the next token as an integer count `n` and the
      one after that as the payload: `TK_REP` emits the payload string repeated
      `n` times as one token, `TK_WREP` emits the payload as `n` separate tokens.
    - If the count is not an integer, the marker is abandoned and that token is
      discarded (as before).

    A marker only affects the tokens it consumes: state is cleared after each
    expansion, so later tokens are emitted unchanged.
    """
    new_toks,rule,arg = [],None,None
    for t in tokens:
        if t in (TK_MAJ, TK_UP, TK_REP, TK_WREP):
            # A new marker always starts with no pending count.
            rule,arg = t,None
        elif rule is None:
            new_toks.append(t)
        elif rule == TK_MAJ:
            new_toks.append(t[:1].upper() + t[1:].lower())
            rule = None
        elif rule == TK_UP:
            new_toks.append(t.upper())
            rule = None
        elif arg is None:
            # First token after TK_REP / TK_WREP is the repetition count.
            try:
                arg = int(t)
            except (ValueError, TypeError):
                rule = None
        else:
            if rule == TK_REP: new_toks.append(t * arg)
            else: new_toks += [t] * arg
            # Without this reset every following token would be repeated as well.
            rule,arg = None,None
    return new_toks
|
|
|
|
|
class LanguageLearner(RNNLearner):
    "Subclass of RNNLearner for predictions."

    def predict(self, text:str, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',
                decoder=decode_spec_tokens):
        """Return the `n_words` that come after `text`.

        One token is sampled per step with `torch.multinomial` from the last
        prediction of the batch and fed back as the next input. `no_unk` zeroes
        the weight of the `UNK` token, `min_p` zeroes every weight below it
        (with a warning instead when nothing would be left), and `temperature`
        raises the weights to the power `1/temperature`. The sampled ids are
        turned into text by the vocabulary, passed through `decoder`, and
        appended to `text` using `sep`.
        """
        self.model.reset()
        xb,yb = self.data.one_item(text)
        new_idx = []
        for _ in range(n_words):
            res = self.pred_batch(batch=(xb,yb))[0][-1]
            if no_unk: res[self.data.vocab.stoi[UNK]] = 0.
            if min_p is not None:
                if (res >= min_p).float().sum() == 0:
                    warn(f"There is no item with probability >= {min_p}, try a lower value.")
                else: res[res < min_p] = 0.
            if temperature != 1.: res.pow_(1 / temperature)
            idx = torch.multinomial(res, 1).item()
            new_idx.append(idx)
            # Next step sees only the token just sampled, shaped as a batch of one.
            xb = xb.new_tensor([idx])[None]
        return text + sep + sep.join(decoder(self.data.vocab.textify(new_idx, sep=None)))

    def beam_search(self, text:str, n_words:int, no_unk:bool=True, top_k:int=10, beam_sz:int=1000, temperature:float=1.,
                    sep:str=' ', decoder=decode_spec_tokens):
        """Return the `n_words` that come after `text` using beam search.

        At every step each kept sequence is extended with its `top_k` most
        likely next tokens; the `beam_sz` candidates with the lowest accumulated
        negative log-likelihood are kept. The returned sequence is sampled from
        the final candidates with weight `exp(-score)` after dividing the scores
        by `temperature`.
        """
        self.model.reset()
        self.model.eval()
        xb, _ = self.data.one_item(text)
        nodes = xb.clone()
        scores = xb.new_zeros(1).float()
        with torch.no_grad():
            for _ in progress_bar(range(n_words), leave=False):
                out = F.log_softmax(self.model(xb)[0][:,-1], dim=-1)
                if no_unk: out[:,self.data.vocab.stoi[UNK]] = -float('Inf')
                values, indices = out.topk(top_k, dim=-1)
                # Scores are negative log-likelihoods, so lower is better.
                scores = (-values + scores[:,None]).view(-1)
                # For every candidate, the index of the sequence it extends.
                indices_idx = torch.arange(0,nodes.size(0))[:,None].expand(nodes.size(0), top_k).contiguous().view(-1)
                sort_idx = scores.argsort()[:beam_sz]
                scores = scores[sort_idx]
                nodes = torch.cat([nodes[:,None].expand(nodes.size(0),top_k,nodes.size(1)),
                                   indices[:,:,None].expand(nodes.size(0),top_k,1),], dim=2)
                nodes = nodes.view(-1, nodes.size(2))[sort_idx]
                # Keep the hidden state of the parent sequence of each surviving candidate.
                self.model[0].select_hidden(indices_idx[sort_idx])
                xb = nodes[:,-1][:,None]
        if temperature != 1.: scores.div_(temperature)
        node_idx = torch.multinomial(torch.exp(-scores), 1).item()
        # NOTE(review): `nodes` starts as a copy of the prompt ids, so the decoded tail
        # appears to contain the prompt tokens after the first one as well as the new
        # words -- confirm this is intended.
        return text + sep + sep.join(decoder(self.data.vocab.textify([i.item() for i in nodes[node_idx][1:] ], sep=None)))

    def show_results(self, ds_type=DatasetType.Valid, rows:int=5, max_len:int=20):
        "Show `rows` result of predictions on `ds_type` dataset."
        from IPython.display import display, HTML
        ds = self.dl(ds_type).dataset
        x,y = self.data.one_batch(ds_type, detach=False, denorm=False)
        preds = self.pred_batch(batch=(x,y))
        y = y.view(*x.size())
        z = preds.view(*x.size(),-1).argmax(dim=2)
        xs = [ds.x.reconstruct(grab_idx(x, i)) for i in range(rows)]
        ys = [ds.x.reconstruct(grab_idx(y, i)) for i in range(rows)]
        zs = [ds.x.reconstruct(grab_idx(z, i)) for i in range(rows)]
        items,names = [],['text', 'target', 'pred']
        for src,tgt,pred in zip(xs,ys,zs):
            # Input shows the first `max_len` words; target and prediction show the
            # `max_len` words starting at position `max_len-1`.
            txt_x = ' '.join(src.text.split(' ')[:max_len])
            txt_y = ' '.join(tgt.text.split(' ')[max_len-1:2*max_len-1])
            txt_z = ' '.join(pred.text.split(' ')[max_len-1:2*max_len-1])
            items.append([txt_x, txt_y, txt_z])
        items = np.array(items)
        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)
        # NOTE(review): newer pandas releases deprecate -1 here in favour of None;
        # confirm the supported pandas range before changing it.
        with pd.option_context('display.max_colwidth', -1):
            display(HTML(df.to_html(index=False)))
|
|
|
|
|
def get_language_model(arch:Callable, vocab_sz:int, config:dict=None, drop_mult:float=1.):
    """Create a language model from `arch` and its `config`.

    `config` defaults to the architecture's `config_lm` entry in `_model_meta`
    and is copied before use. Every entry whose key ends in `_p` is multiplied
    by `drop_mult`. The result is a `SequentialRNN` of the `arch` encoder and a
    `LinearDecoder`; when the config carries an `init` function it is applied
    to the model.
    """
    meta = _model_meta[arch]
    config = ifnone(config, meta['config_lm']).copy()
    for name in config:
        if name.endswith('_p'):
            config[name] *= drop_mult

    # These entries configure the decoder and must not reach the encoder constructor.
    tie_weights = config.pop('tie_weights')
    output_p = config.pop('output_p')
    out_bias = config.pop('out_bias')
    init = config.pop('init', None)

    encoder = arch(vocab_sz, **config)
    tied = encoder.encoder if tie_weights else None
    decoder = LinearDecoder(vocab_sz, config[meta['hid_name']], output_p, tie_encoder=tied, bias=out_bias)
    model = SequentialRNN(encoder, decoder)
    if init is None:
        return model
    return model.apply(init)
|
|
|
|
|
def language_model_learner(data:DataBunch, arch, config:dict=None, drop_mult:float=1., pretrained:bool=True,
                           pretrained_fnames:OptStrTuple=None, **learn_kwargs) -> 'LanguageLearner':
    """Create a `Learner` with a language model from `data` and `arch`.

    Weights are loaded when `pretrained` is true or `pretrained_fnames` is given.
    `pretrained_fnames` is a (weights, vocab) pair of file stems resolved inside
    the learner's model directory with the `.pth` and `.pkl` extensions;
    otherwise the architecture's download URL is used (`url_bwd` when
    `data.backwards`, else `url`). If the architecture has no such URL a
    warning is issued and the learner is returned without pretrained weights.
    The learner is frozen after loading.
    """
    model = get_language_model(arch, len(data.vocab.itos), config=config, drop_mult=drop_mult)
    meta = _model_meta[arch]
    learn = LanguageLearner(data, model, split_func=meta['split_lm'], **learn_kwargs)
    url = 'url_bwd' if data.backwards else 'url'

    if not (pretrained or pretrained_fnames):
        return learn

    exts = ['pth', 'pkl']
    if pretrained_fnames is not None:
        fnames = [learn.path/learn.model_dir/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, exts)]
    elif url not in meta:
        warn("There are no pretrained weights for that architecture yet!")
        return learn
    else:
        model_path = untar_data(meta[url] , data=False)
        # Take the first file of each kind found in the downloaded folder.
        fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in exts]

    learn.load_pretrained(*fnames)
    learn.freeze()
    return learn
|
|
|
|
|
def masked_concat_pool(outputs, mask):
    """Pool MultiBatchEncoder outputs into one vector [last_hidden, max_pool, avg_pool].

    Only the last entry of `outputs` is pooled. Positions where `mask` is true
    are excluded from both the average and the maximum over dim 1; the "last"
    part is simply the final position along dim 1.
    """
    last_layer = outputs[-1]
    seq_len = last_layer.size(1)
    masked = mask[:, :, None]

    # Mean over all positions with masked ones zeroed, rescaled to the unmasked count.
    avg_pool = last_layer.masked_fill(masked, 0).mean(dim=1)
    n_masked = mask.type(avg_pool.dtype).sum(dim=1)
    avg_pool *= seq_len / (seq_len - n_masked)[:,None]

    # Masked positions are set to -inf so they can never win the max.
    max_pool = last_layer.masked_fill(masked, -float('inf')).max(dim=1)[0]

    return torch.cat([last_layer[:,-1], max_pool, avg_pool], 1)
|
|
|
|
|
class PoolingLinearClassifier(Module):
    "Create a linear classifier with pooling."
    # No explicit `super().__init__()` call, exactly as before -- presumably the
    # project's `Module` base takes care of it; verify before changing.
    def __init__(self, layers:Collection[int], drops:Collection[float]):
        """Stack one `bn_drop_lin` group per consecutive pair of sizes in `layers`.

        `drops` supplies one dropout value per group. Every group except the
        last is followed by a ReLU. Raises `ValueError` when `drops` does not
        have exactly `len(layers)-1` entries.
        """
        n_groups = len(layers) - 1
        if len(drops) != n_groups: raise ValueError("Number of layers and dropout values do not match.")
        relu = nn.ReLU(inplace=True)
        activs = [relu] * (n_groups - 1) + [None]
        blocks = []
        for n_in, n_out, p, actn in zip(layers[:-1], layers[1:], drops, activs):
            blocks.extend(bn_drop_lin(n_in, n_out, p=p, actn=actn))
        self.layers = nn.Sequential(*blocks)

    def forward(self, input:Tuple[Tensor,Tensor, Tensor])->Tuple[Tensor,Tensor,Tensor]:
        "Pool the encoder `outputs` under `mask`, classify, and pass `raw_outputs` and `outputs` through."
        raw_outputs,outputs,mask = input
        pooled = masked_concat_pool(outputs, mask)
        return self.layers(pooled), raw_outputs, outputs
|
|
|
|
|
class MultiBatchEncoder(Module):
    "Create an encoder over `module` that can process a full sentence."
    def __init__(self, bptt:int, max_len:int, module:nn.Module, pad_idx:int=1):
        "Wrap `module`; inputs are fed in chunks of `bptt` and `pad_idx` marks padding in the returned mask."
        self.max_len,self.bptt,self.module,self.pad_idx = max_len,bptt,module,pad_idx

    def concat(self, arrs:Collection[Tensor])->List[Tensor]:
        """Concatenate the per-chunk results in `arrs` along dim 1.

        Each element of `arrs` is itself indexable; position `si` of every
        element is concatenated into entry `si` of the returned list.
        """
        return [torch.cat([l[si] for l in arrs], dim=1) for si in range_of(arrs[0])]

    def reset(self):
        "Reset the wrapped module if it supports it."
        if hasattr(self.module, 'reset'): self.module.reset()

    def forward(self, input:LongTensor)->Tuple[List[Tensor],List[Tensor],Tensor]:
        """Run `module` over `input` in chunks of `bptt` columns.

        Every chunk is passed through `module`, but only chunks whose start
        index is greater than `sl - max_len` are kept. Returns the concatenated
        raw outputs, the concatenated outputs, and a mask that is true where the
        kept input equals `pad_idx`.
        """
        # Unpacking also requires `input` to be 2-D; dim 0 is the batch, dim 1 the sequence.
        bs,sl = input.size()
        self.reset()
        raw_outputs,outputs,masks = [],[],[]
        for i in range(0, sl, self.bptt):
            chunk = input[:,i: min(i+self.bptt, sl)]
            r, o = self.module(chunk)
            # Earlier chunks still go through the module but their outputs are dropped.
            if i>(sl-self.max_len):
                masks.append(chunk == self.pad_idx)
                raw_outputs.append(r)
                outputs.append(o)
        return self.concat(raw_outputs),self.concat(outputs),torch.cat(masks,dim=1)
|
|
|
|
|
def get_text_classifier(arch:Callable, vocab_sz:int, n_class:int, bptt:int=70, max_len:int=20*70, config:dict=None,
                        drop_mult:float=1., lin_ftrs:Collection[int]=None, ps:Collection[float]=None,
                        pad_idx:int=1) -> nn.Module:
    """Create a text classifier from `arch` and its `config`.

    `config` defaults to the architecture's `config_clas` entry in `_model_meta`
    and is copied before use; every entry whose key ends in `_p` is multiplied
    by `drop_mult`. `lin_ftrs` (default `[50]`) gives the sizes of the hidden
    linear layers and `ps` (default `0.1` each) their dropout values; any
    collection is accepted for both. The config's `output_p` is used as the
    dropout of the first classifier layer. The result is a `SequentialRNN` of a
    `MultiBatchEncoder` around `arch` and a `PoolingLinearClassifier`; when the
    config carries an `init` function it is applied to the model.
    """
    meta = _model_meta[arch]
    config = ifnone(config, meta['config_clas']).copy()
    for k in config.keys():
        if k.endswith('_p'): config[k] *= drop_mult
    # Normalise to lists so tuples (or other collections) can be concatenated below.
    lin_ftrs = [50] if lin_ftrs is None else list(lin_ftrs)
    ps = [0.1]*len(lin_ftrs) if ps is None else list(ps)
    # The classifier input is three times the encoder size (the pooled vector has three parts).
    layers = [config[meta['hid_name']] * 3] + lin_ftrs + [n_class]
    ps = [config.pop('output_p')] + ps
    init = config.pop('init') if 'init' in config else None
    encoder = MultiBatchEncoder(bptt, max_len, arch(vocab_sz, **config), pad_idx=pad_idx)
    model = SequentialRNN(encoder, PoolingLinearClassifier(layers, ps))
    return model if init is None else model.apply(init)
|
|
|
|
|
def text_classifier_learner(data:DataBunch, arch:Callable, bptt:int=70, max_len:int=70*20, config:dict=None,
                            pretrained:bool=True, drop_mult:float=1., lin_ftrs:Collection[int]=None,
                            ps:Collection[float]=None, **learn_kwargs) -> 'TextClassifierLearner':
    """Create a `Learner` with a text classifier from `data` and `arch`.

    The model comes from `get_text_classifier` and is wrapped in an
    `RNNLearner` split with the architecture's `split_clas` function. With
    `pretrained`, the architecture's `url` weights are downloaded, loaded
    non-strictly, and the learner is frozen; if the architecture has no `url`
    a warning is issued and the learner is returned as is.
    """
    vocab_sz = len(data.vocab.itos)
    model = get_text_classifier(arch, vocab_sz, data.c, bptt=bptt, max_len=max_len,
                                config=config, drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps)
    meta = _model_meta[arch]
    learn = RNNLearner(data, model, split_func=meta['split_clas'], **learn_kwargs)

    if not pretrained:
        return learn
    if 'url' not in meta:
        warn("There are no pretrained weights for that architecture yet!")
        return learn

    model_path = untar_data(meta['url'], data=False)
    # Take the first weights file and the first vocabulary file in the downloaded folder.
    fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
    # Non-strict: the state dict is allowed not to match the classifier's keys exactly.
    learn.load_pretrained(*fnames, strict=False)
    learn.freeze()
    return learn
|
|
|