import abc
import functools
import os
import time

import bpemb
import torch
import torchtext

# Assumed import: GloVe.tokenize() below calls corenlp.annotate(), but this
# module as given never imported corenlp. A project-local CoreNLP wrapper is
# the most likely source; adjust the path if it lives elsewhere.
from seq2struct.resources import corenlp
from seq2struct.utils import registry


class Embedder(metaclass=abc.ABCMeta):
    '''Interface for pretrained token embeddings.'''

    @abc.abstractmethod
    def tokenize(self, sentence):
        '''Given a string, return a list of tokens suitable for lookup.'''
        pass

    @abc.abstractmethod
    def untokenize(self, tokens):
        '''Undo tokenize.'''
        pass

    @abc.abstractmethod
    def lookup(self, token):
        '''Given a token, return a vector embedding if token is in vocabulary.

        If token is not in the vocabulary, then return None.'''
        pass

    @abc.abstractmethod
    def contains(self, token):
        '''Return whether token is in the vocabulary.'''
        pass

    @abc.abstractmethod
    def to(self, device):
        '''Transfer the pretrained embeddings to the given device.'''
        pass
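

# Illustrative sketch, not part of the original module: the smallest concrete
# Embedder that satisfies the interface above. The whitespace tokenizer,
# vocabulary, and random vectors are stand-ins chosen for demonstration.
class _ExampleWhitespaceEmbedder(Embedder):
    def __init__(self, words, dim=4):
        self.stoi = {word: i for i, word in enumerate(words)}
        self.dim = dim
        self.vectors = torch.randn(len(words), dim)

    def tokenize(self, sentence):
        return sentence.lower().split()

    def untokenize(self, tokens):
        return ' '.join(tokens)

    def lookup(self, token):
        i = self.stoi.get(token)
        return None if i is None else self.vectors[i]

    def contains(self, token):
        return token in self.stoi

    def to(self, device):
        self.vectors = self.vectors.to(device)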


@registry.register('word_emb', 'glove')
class GloVe(Embedder):
    def __init__(self, kind, lemmatize=False):
        cache = os.path.join(os.environ.get('CACHE_DIR', os.getcwd()), '.vector_cache')
        self.glove = torchtext.vocab.GloVe(name=kind, cache=cache)
        self.dim = self.glove.dim
        self.vectors = self.glove.vectors
        self.lemmatize = lemmatize
        self.corenlp_annotators = ['tokenize', 'ssplit']
        if lemmatize:
            self.corenlp_annotators.append('lemma')

    # CoreNLP annotation is expensive, so results are cached per (self, text).
    @functools.lru_cache(maxsize=1024)
    def tokenize(self, text):
        ann = corenlp.annotate(text, self.corenlp_annotators)
        if self.lemmatize:
            return [tok.lemma.lower() for sent in ann.sentence for tok in sent.token]
        else:
            return [tok.word.lower() for sent in ann.sentence for tok in sent.token]

    @functools.lru_cache(maxsize=1024)
    def tokenize_for_copying(self, text):
        # Returns both the (possibly lemmatized) tokens used for embedding
        # lookup and the original surface forms used by copy mechanisms.
        ann = corenlp.annotate(text, self.corenlp_annotators)
        text_for_copying = [tok.originalText.lower() for sent in ann.sentence for tok in sent.token]
        if self.lemmatize:
            text = [tok.lemma.lower() for sent in ann.sentence for tok in sent.token]
        else:
            text = [tok.word.lower() for sent in ann.sentence for tok in sent.token]
        return text, text_for_copying

    def untokenize(self, tokens):
        return ' '.join(tokens)

    def lookup(self, token):
        i = self.glove.stoi.get(token)
        if i is None:
            return None
        return self.vectors[i]

    def contains(self, token):
        return token in self.glove.stoi

    def to(self, device):
        self.vectors = self.vectors.to(device)
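

# Hedged usage sketch, not called anywhere in the module: '42B' is one of
# torchtext's GloVe variants, the vectors are downloaded into CACHE_DIR on
# first use, and corenlp.annotate() needs a working CoreNLP backend.
def _example_glove_usage():
    emb = GloVe(kind='42B')
    tokens = emb.tokenize('How many singers are there?')  # lowercased word tokens
    vectors = [emb.lookup(t) for t in tokens]             # None marks OOV tokens
    emb.to(torch.device('cpu'))                           # move the embedding table
    return emb.untokenize(tokens), vectors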


@registry.register('word_emb', 'bpemb')
class BPEmb(Embedder):
    def __init__(self, dim, vocab_size, lang='en'):
        self.bpemb = bpemb.BPEmb(lang=lang, dim=dim, vs=vocab_size)
        self.dim = dim
        self.vectors = torch.from_numpy(self.bpemb.vectors)

    def tokenize(self, text):
        return self.bpemb.encode(text)

    def untokenize(self, tokens):
        return self.bpemb.decode(tokens)

    def lookup(self, token):
        # SentencePiece maps unknown pieces to unk_id; per the Embedder
        # contract, report those as missing rather than returning a vector.
        i = self.bpemb.spm.PieceToId(token)
        if i == self.bpemb.spm.unk_id():
            return None
        return self.vectors[i]

    def contains(self, token):
        return self.lookup(token) is not None

    def to(self, device):
        self.vectors = self.vectors.to(device)
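

# Hedged usage sketch, not called anywhere in the module: assumes the bpemb
# package can download the pretrained English model (100-dim vectors, 10k
# subword vocabulary) on first use.
def _example_bpemb_usage():
    emb = BPEmb(dim=100, vocab_size=10000)
    pieces = emb.tokenize('concert singer')    # subword pieces, e.g. ['▁concert', '▁singer']
    vectors = [emb.lookup(p) for p in pieces]  # one 100-dim tensor per known piece
    return emb.untokenize(pieces), vectors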