| | """ |
| | This file contains the definition of encoders used in https://arxiv.org/pdf/1705.02364.pdf |
| | """ |
| |
|
| | import numpy as np |
| | import time |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
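# Bi-directional LSTM (BLSTM) sentence encoder with max- or mean-pooling over
# the hidden states (selected via config['pool_type']).
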
class InferSent(nn.Module):

    def __init__(self, config):
        super(InferSent, self).__init__()
        self.bsize = config['bsize']
        self.word_emb_dim = config['word_emb_dim']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.pool_type = config['pool_type']
        self.dpout_model = config['dpout_model']
        self.version = 1 if 'version' not in config else config['version']

        self.enc_lstm = nn.LSTM(self.word_emb_dim, self.enc_lstm_dim, 1,
                                bidirectional=True, dropout=self.dpout_model)

        assert self.version in [1, 2]
        if self.version == 1:
            self.bos = '<s>'
            self.eos = '</s>'
            self.max_pad = True
            self.moses_tok = False
        elif self.version == 2:
            self.bos = '<p>'
            self.eos = '</p>'
            self.max_pad = False
            self.moses_tok = True

    def is_cuda(self):
        # either all weights are on cpu or they are on gpu
        return self.enc_lstm.bias_hh_l0.data.is_cuda

    def forward(self, sent_tuple):
        # sent_len: [max_len, ..., min_len] (bsize)
        # sent: (seqlen x bsize x word_emb_dim)
        sent, sent_len = sent_tuple

        # Sort by length (keep idx to unsort later)
        sent_len_sorted, idx_sort = np.sort(sent_len)[::-1], np.argsort(-sent_len)
        sent_len_sorted = sent_len_sorted.copy()
        idx_unsort = np.argsort(idx_sort)

        idx_sort = torch.from_numpy(idx_sort).cuda() if self.is_cuda() \
            else torch.from_numpy(idx_sort)
        sent = sent.index_select(1, idx_sort)

        # Handle padding in recurrent networks
        sent_packed = nn.utils.rnn.pack_padded_sequence(sent, sent_len_sorted)
        sent_output = self.enc_lstm(sent_packed)[0]  # seqlen x bsize x 2*enc_lstm_dim
        sent_output = nn.utils.rnn.pad_packed_sequence(sent_output)[0]

        # Un-sort by length
        idx_unsort = torch.from_numpy(idx_unsort).cuda() if self.is_cuda() \
            else torch.from_numpy(idx_unsort)
        sent_output = sent_output.index_select(1, idx_unsort)

        # Pooling
        if self.pool_type == "mean":
            sent_len = torch.FloatTensor(sent_len.copy()).unsqueeze(1)
            if self.is_cuda():
                sent_len = sent_len.cuda()
            emb = torch.sum(sent_output, 0).squeeze(0)
            emb = emb / sent_len.expand_as(emb)
        elif self.pool_type == "max":
            if not self.max_pad:
                sent_output[sent_output == 0] = -1e9
            emb = torch.max(sent_output, 0)[0]
            if emb.ndimension() == 3:
                emb = emb.squeeze(0)
                assert emb.ndimension() == 2

        return emb

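    # Sketch of calling forward() directly on an InferSent instance `model`
    # (encode() below builds these inputs for you); the tokens are illustrative
    # and assume build_vocab() has already loaded vectors for them:
    #
    #   batch = model.get_batch([['<s>', 'hello', 'world', '</s>']])
    #   lengths = np.array([4])
    #   emb = model((batch, lengths))  # shape (1, 2 * enc_lstm_dim)
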
    def set_w2v_path(self, w2v_path):
        self.w2v_path = w2v_path

    def get_word_dict(self, sentences, tokenize=True):
        # create vocab of words
        word_dict = {}
        sentences = [s.split() if not tokenize else self.tokenize(s) for s in sentences]
        for sent in sentences:
            for word in sent:
                if word not in word_dict:
                    word_dict[word] = ''
        word_dict[self.bos] = ''
        word_dict[self.eos] = ''
        return word_dict

    def get_w2v(self, word_dict):
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        # create word_vec with w2v vectors
        word_vec = {}
        with open(self.w2v_path, encoding='utf-8') as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if word in word_dict:
                    word_vec[word] = np.fromstring(vec, sep=' ')
        print('Found %s(/%s) words with w2v vectors' % (len(word_vec), len(word_dict)))
        return word_vec

    def get_w2v_k(self, K):
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        # create word_vec with the K first w2v vectors (plus bos/eos tokens)
        k = 0
        word_vec = {}
        with open(self.w2v_path, encoding='utf-8') as f:
            for line in f:
                word, vec = line.split(' ', 1)
                if k <= K:
                    word_vec[word] = np.fromstring(vec, sep=' ')
                    k += 1
                if k > K:
                    if word in [self.bos, self.eos]:
                        word_vec[word] = np.fromstring(vec, sep=' ')

                if k > K and all([w in word_vec for w in [self.bos, self.eos]]):
                    break
        return word_vec

    def build_vocab(self, sentences, tokenize=True):
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        word_dict = self.get_word_dict(sentences, tokenize)
        self.word_vec = self.get_w2v(word_dict)
        print('Vocab size : %s' % (len(self.word_vec)))

    # build w2v vocab with the K most frequent words
    def build_vocab_k_words(self, K):
        assert hasattr(self, 'w2v_path'), 'w2v path not set'
        self.word_vec = self.get_w2v_k(K)
        print('Vocab size : %s' % (K))

    def update_vocab(self, sentences, tokenize=True):
        assert hasattr(self, 'w2v_path'), 'warning : w2v path not set'
        assert hasattr(self, 'word_vec'), 'build_vocab before updating it'
        word_dict = self.get_word_dict(sentences, tokenize)

        # keep only new words
        for word in self.word_vec:
            if word in word_dict:
                del word_dict[word]

        # update vocabulary
        if word_dict:
            new_word_vec = self.get_w2v(word_dict)
            self.word_vec.update(new_word_vec)
        else:
            new_word_vec = []
        print('New vocab size : %s (added %s words)' % (len(self.word_vec), len(new_word_vec)))

    def get_batch(self, batch):
        # batch: list of tokenized sentences, sorted by decreasing length
        # returns a (max_len, bsize, word_emb_dim) tensor of word embeddings
        embed = np.zeros((len(batch[0]), len(batch), self.word_emb_dim))

        for i in range(len(batch)):
            for j in range(len(batch[i])):
                embed[j, i, :] = self.word_vec[batch[i][j]]

        return torch.FloatTensor(embed)

    def tokenize(self, s):
        from nltk.tokenize import word_tokenize
        if self.moses_tok:
            s = ' '.join(word_tokenize(s))
            s = s.replace(" n't ", "n 't ")  # HACK to approximate MOSES tokenization
            return s.split()
        else:
            return word_tokenize(s)

    def prepare_samples(self, sentences, bsize, tokenize, verbose):
        sentences = [[self.bos] + s.split() + [self.eos] if not tokenize else
                     [self.bos] + self.tokenize(s) + [self.eos] for s in sentences]
        n_w = np.sum([len(x) for x in sentences])

        # filter out words without w2v vectors
        for i in range(len(sentences)):
            s_f = [word for word in sentences[i] if word in self.word_vec]
            if not s_f:
                import warnings
                warnings.warn('No words in "%s" (idx=%s) have w2v vectors. '
                              'Replacing by "%s"..' % (sentences[i], i, self.eos))
                s_f = [self.eos]
            sentences[i] = s_f

        lengths = np.array([len(s) for s in sentences])
        n_wk = np.sum(lengths)
        if verbose:
            print('Nb words kept : %s/%s (%.1f%s)' % (
                n_wk, n_w, 100.0 * n_wk / n_w, '%'))

        # sort by decreasing length
        lengths, idx_sort = np.sort(lengths)[::-1], np.argsort(-lengths)
        sentences = np.array(sentences)[idx_sort]

        return sentences, lengths, idx_sort

    def encode(self, sentences, bsize=64, tokenize=True, verbose=False):
        tic = time.time()
        sentences, lengths, idx_sort = self.prepare_samples(
            sentences, bsize, tokenize, verbose)

        embeddings = []
        for stidx in range(0, len(sentences), bsize):
            batch = self.get_batch(sentences[stidx:stidx + bsize])
            if self.is_cuda():
                batch = batch.cuda()
            with torch.no_grad():
                batch = self.forward((batch, lengths[stidx:stidx + bsize])).data.cpu().numpy()
            embeddings.append(batch)
        embeddings = np.vstack(embeddings)

        # unsort
        idx_unsort = np.argsort(idx_sort)
        embeddings = embeddings[idx_unsort]

        if verbose:
            print('Speed : %.1f sentences/s (%s mode, bsize=%s)' % (
                len(embeddings) / (time.time() - tic),
                'gpu' if self.is_cuda() else 'cpu', bsize))
        return embeddings

    def visualize(self, sent, tokenize=True):

        sent = sent.split() if not tokenize else self.tokenize(sent)
        sent = [[self.bos] + [word for word in sent if word in self.word_vec] + [self.eos]]

        if ' '.join(sent[0]) == '%s %s' % (self.bos, self.eos):
            import warnings
            warnings.warn('No words in "%s" have w2v vectors. Replacing '
                          'by "%s %s"..' % (sent, self.bos, self.eos))
        batch = self.get_batch(sent)

        if self.is_cuda():
            batch = batch.cuda()
        output = self.enc_lstm(batch)[0]
        output, idxs = torch.max(output, 0)
        # for each word position, count how many LSTM features reach their max there
        idxs = idxs.data.cpu().numpy()
        argmaxs = [np.sum((idxs == k)) for k in range(len(sent[0]))]

        # visualize word importance
        import matplotlib.pyplot as plt
        x = range(len(sent[0]))
        y = [100.0 * n / np.sum(argmaxs) for n in argmaxs]
        plt.xticks(x, sent[0], rotation=45)
        plt.bar(x, y)
        plt.ylabel('%')
        plt.title('Visualisation of words importance')
        plt.show()

        return output, idxs
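
# Example usage (a minimal sketch; the file paths, dimensions and other values
# below are assumptions, not defined in this module):
#
#   params = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
#             'pool_type': 'max', 'dpout_model': 0.0, 'version': 1}
#   model = InferSent(params)
#   model.load_state_dict(torch.load('infersent1.pkl'))  # pretrained weights (path assumed)
#   model.set_w2v_path('glove.840B.300d.txt')            # word vectors in text format (path assumed)
#   model.build_vocab_k_words(K=100000)
#   embeddings = model.encode(['A man is playing a guitar.'], tokenize=True)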