# -*- coding: utf-8 -*-

import torch


class Embedding(object):

    def __init__(self, tokens, vectors, unk=None):
        super(Embedding, self).__init__()

        self.tokens = list(tokens)
        # Stack the raw per-token lists into a (num_tokens, dim) float tensor.
        self.vectors = torch.tensor(vectors)
        self.pretrained = {w: v for w, v in zip(tokens, vectors)}
        # Honor the caller-supplied unknown token rather than hard-coding one.
        self.unk = unk

    def __len__(self):
        return len(self.tokens)

    def __contains__(self, token):
        return token in self.pretrained

    @property
    def dim(self):
        # size(1) is the embedding dimension; size(0) is the vocabulary size.
        return self.vectors.size(1)

    @property
    def unk_index(self):
        if self.unk is not None:
            return self.tokens.index(self.unk)
        raise AttributeError("no unknown token was provided")

    @classmethod
    def load(cls, path, unk=None):
        # Each line is expected to hold a token followed by its
        # whitespace-separated vector components.
        with open(path, 'r', encoding='utf-8') as f:
            splits = [line.split() for line in f]
        tokens, vectors = zip(*[(s[0], list(map(float, s[1:])))
                                for s in splits])
        return cls(tokens, vectors, unk=unk)
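

# A minimal usage sketch of the class above, assuming a GloVe-style text
# file (one token plus its vector per line). The path 'glove.6B.100d.txt'
# and the '<unk>' token are illustrative assumptions, not part of this
# module or any bundled data.
if __name__ == '__main__':
    embed = Embedding.load('glove.6B.100d.txt', unk='<unk>')
    print(len(embed), embed.dim)        # vocabulary size and vector width
    if 'the' in embed:
        print(embed.pretrained['the'])  # raw vector for a known token
    print(embed.unk_index)              # position of the unknown token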