import torch |
|
|
class Embedding(object):
    """A lookup table of pretrained token embeddings."""

    def __init__(self, tokens, vectors, unk=None):
        super(Embedding, self).__init__()

        self.tokens = tokens
        # Stack the per-token vectors into a (num_tokens, dim) float tensor.
        self.vectors = torch.tensor(vectors)
        self.pretrained = {w: v for w, v in zip(tokens, vectors)}
        # Record the designated unknown token (None if there is none).
        self.unk = unk
|
    def __len__(self):
        return len(self.tokens)
|
    def __contains__(self, token):
        return token in self.pretrained
|
    @property
    def dim(self):
        # Dimensionality of a single embedding vector (the second axis),
        # not the vocabulary size.
        return self.vectors.size(1)
|
    @property
    def unk_index(self):
        if self.unk is not None:
            return self.tokens.index(self.unk)
        else:
            raise AttributeError("this embedding has no unknown token")
|
    @classmethod
    def load(cls, path, unk=None):
        # Each line is expected to hold a token followed by the components
        # of its vector, as in GloVe-style text files: "the 0.418 0.24968 ...".
        with open(path, 'r') as f:
            splits = [line.split() for line in f]
        tokens, vectors = zip(*[(s[0], list(map(float, s[1:])))
                                for s in splits])

        return cls(tokens, vectors, unk=unk)
|
|
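# A minimal usage sketch, not part of the original module: write a tiny
# whitespace-separated embedding file in the GloVe-style text layout
# ("token v1 v2 ... vd" per line), load it, and query the class. The file
# name 'toy_vectors.txt' and the '[UNK]' token are illustrative assumptions.
if __name__ == '__main__':
    with open('toy_vectors.txt', 'w') as f:
        f.write('the 0.1 0.2 0.3\n')
        f.write('cat 0.4 0.5 0.6\n')
        f.write('[UNK] 0.0 0.0 0.0\n')

    emb = Embedding.load('toy_vectors.txt', unk='[UNK]')
    print(len(emb))       # 3 tokens in the vocabulary
    print(emb.dim)        # 3, the vector dimensionality
    print('cat' in emb)   # True
    print(emb.unk_index)  # 2, the position of '[UNK]' in the token list

    # The stacked tensor can also seed an nn.Embedding layer, a common
    # downstream use of a class like this one.
    layer = torch.nn.Embedding.from_pretrained(emb.vectors)
    print(layer.weight.size())  # torch.Size([3, 3])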