| __author__ = 'Taneem Jan, taneemishere.github.io' |
|
|
| import sys |
| import numpy as np |
|
|
# Reserved tokens: Vocabulary.__init__ registers these three, in this order,
# before any other token is appended.
START_TOKEN = "<START>"
END_TOKEN = "<END>"
PLACEHOLDER = " "

# Delimiter written between a token and its serialized vector on each entry
# of the vocabulary file, and searched for when the file is read back.
SEPARATOR = '->'
|
|
|
|
class Vocabulary:
    """Token vocabulary with one-hot ("binary") vectors and a text file format.

    Attributes:
        vocabulary: dict mapping token -> integer index.
        token_lookup: dict mapping integer index -> token.
        binary_vocabulary: dict mapping token -> one-hot numpy vector; filled
            by ``create_binary_representation`` or ``retrieve``.
        size: number of tokens currently known.
    """

    def __init__(self):
        """Create a vocabulary holding only the three reserved tokens."""
        self.binary_vocabulary = {}
        self.vocabulary = {}
        self.token_lookup = {}
        self.size = 0

        # Registered first, so they receive indices 0, 1 and 2.
        self.append(START_TOKEN)
        self.append(END_TOKEN)
        self.append(PLACEHOLDER)

    def append(self, token):
        """Register ``token`` under the next free index; known tokens are ignored."""
        if token not in self.vocabulary:
            self.vocabulary[token] = self.size
            self.token_lookup[self.size] = token
            self.size += 1

    def create_binary_representation(self):
        """Build a one-hot vector of length ``self.size`` for every token."""
        # dict.items() exists on both Python 2 and 3, so no version switch
        # is needed here.
        for token, index in self.vocabulary.items():
            one_hot = np.zeros(self.size)
            one_hot[index] = 1
            self.binary_vocabulary[token] = one_hot

    def get_serialized_binary_representation(self):
        """Return the vocabulary as text, one ``token->v0,v1,...`` entry per token.

        The one-hot vectors are created first if none exist yet.

        Returns:
            str: concatenation of newline-terminated entries.
        """
        if not self.binary_vocabulary:
            self.create_binary_representation()

        entries = []
        for token, vector in self.binary_vocabulary.items():
            # An unbounded width keeps each vector on a single line, and an
            # unbounded threshold stops numpy from abbreviating vectors of
            # more than 1000 elements with "..."; either would make the
            # entry unreadable for retrieve().
            as_text = np.array2string(
                vector,
                separator=',',
                max_line_width=sys.maxsize,
                threshold=sys.maxsize,
            )
            # Drop the surrounding "[" and "]".
            entries.append("{}{}{}\n".format(token, SEPARATOR, as_text[1:-1]))
        return "".join(entries)

    def save(self, path):
        """Write the serialized vocabulary to ``<path>/words.vocab``.

        Args:
            path: directory that will contain ``words.vocab``.
        """
        output_file_name = "{}/words.vocab".format(path)
        # Serialize before opening so that a serialization failure does not
        # leave an already-truncated file behind.
        serialized = self.get_serialized_binary_representation()
        with open(output_file_name, 'w') as output_file:
            output_file.write(serialized)

    def retrieve(self, path):
        """Load entries from ``<path>/words.vocab`` into this vocabulary.

        Entries read from the file are added to (or overwrite) the entries
        already present; ``size`` is then set to the number of known tokens.

        Args:
            path: directory containing ``words.vocab``.

        Raises:
            IndexError: if an entry's vector has no element equal to 1.
        """
        pending = ""  # text of a token that spans several physical lines
        with open("{}/words.vocab".format(path), 'r') as input_file:
            for line in input_file:
                position = line.find(SEPARATOR)
                if position < 0:
                    # No separator: this line is part of a multi-line token.
                    pending += line
                    continue

                key = pending + line[:position]
                pending = ""
                vector_text = line[position + len(SEPARATOR):].strip()
                value = np.fromstring(vector_text, sep=',')
                # Plain int keeps indices consistent with those made by append().
                index = int(np.where(value == 1)[0][0])

                self.binary_vocabulary[key] = value
                self.vocabulary[key] = index
                self.token_lookup[index] = key
        self.size = len(self.vocabulary)
|
|