File size: 2,303 Bytes
a3b29ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import re
import torch

vocabulary = {}
token_vocabulary = {}
# vocabulary_length = ['<EOS>']

# Each line of the vocabulary file is one token; its 0-based line number is its id.
# A token wrapped in matching single or double quotes has the quotes stripped.
# NOTE(review): quoted entries are not unescaped (a literal "\n" in the file stays as
# backslash + "n") — confirm this matches the file's format.
with open('cl100k_base_vocab_list.txt', 'r', encoding='utf-8') as file:
    for line_count, line in enumerate(file):
        line = line.rstrip('\n')
        # Require at least two characters so a token that is itself a lone quote
        # character is kept instead of being stripped down to ''.
        if len(line) >= 2 and line[0] == line[-1] and line[0] in ('\'', '"'):
            line = line[1:-1]
        vocabulary[line] = line_count
        # Fill the reverse map directly so every id stays decodable even when two
        # lines produce the same token string (vocabulary keeps the later id).
        token_vocabulary[line_count] = line

def get_vocabulary():
    """Return the shared token-to-id dictionary loaded at import time.

    The module-level dict itself is handed back (not a copy), so any
    mutation by the caller changes the shared module state.
    """
    token_to_id = vocabulary
    return token_to_id


def get_token_vocabulary():
    """Return the shared id-to-token dictionary loaded at import time.

    The module-level dict itself is handed back (not a copy), so any
    mutation by the caller changes the shared module state.
    """
    id_to_token = token_vocabulary
    return id_to_token

# def check_vocabulary_length(word):
#     append_length = True
#     for vocab in vocabulary_length:
#         if word == vocab:
#             append_length = False
#             break
#     if append_length == True:
#         vocabulary_length.append(word)
#
# def return_vocabulary_length():
#     return vocabulary_length

def tokenize_sequence(sentence, vocab=None):
    """Encode *sentence* as a list of vocabulary ids terminated by the '<EOS>' id.

    The sentence is split into pieces: whitespace followed by word characters,
    or runs of non-whitespace. Text between matches, such as trailing
    whitespace, is kept as its own piece. A piece found verbatim in the
    vocabulary maps to its id. Otherwise it is segmented greedily left to
    right, taking the longest prefix present in the vocabulary at each
    position. A single character with no entry maps to the '<UNK>' id.

    Args:
        sentence: text to encode.
        vocab: optional token-to-id mapping; defaults to the module-level
            ``vocabulary``.

    Returns:
        List of ids. NOTE(review): ``vocab.get('<UNK>')`` and
        ``vocab.get('<EOS>')`` yield None when those keys are absent, so the
        list may contain None — confirm the vocabulary file defines them.
    """
    # A '# <SOS>' prefix was used at one point: [vocab.get('<SOS>')]
    if vocab is None:
        vocab = vocabulary
    unk_id = vocab.get('<UNK>')
    tokenized_seq = []
    # The capturing group makes re.split keep the matched pieces. Adjacent
    # matches leave empty strings between them; skip those so a '' entry in
    # the vocabulary is never emitted for them.
    for word in re.split(r'(\s+\w+|\S+)', sentence):
        if not word:
            continue
        if word in vocab:
            tokenized_seq.append(vocab[word])
            continue
        i = 0
        while i < len(word):
            # Longest match first: try word[i:j] for j = len(word) down to i + 1.
            for j in range(len(word), i, -1):
                piece = word[i:j]
                if piece in vocab:
                    tokenized_seq.append(vocab[piece])
                    i = j
                    break
            else:
                # Not even the single character word[i] is in the vocabulary.
                tokenized_seq.append(unk_id)
                i += 1
    tokenized_seq.append(vocab.get('<EOS>'))
    return tokenized_seq


def detokenize_sequence(tokenized_seq, token_vocab=None):
    """Decode a sequence of ids back into text.

    Each id is mapped to its token string and the strings are concatenated
    in order. Special tokens such as '<EOS>' are not stripped; they appear in
    the output as their literal text.

    Args:
        tokenized_seq: iterable of ids, e.g. the output of tokenize_sequence.
            Presumably plain Python ints (use ``tensor.tolist()`` for tensors)
            — confirm against callers.
        token_vocab: optional id-to-token mapping; defaults to the
            module-level ``token_vocabulary``.

    Returns:
        The decoded string.

    Raises:
        KeyError: if an id is not present in the mapping.
    """
    if token_vocab is None:
        token_vocab = token_vocabulary
    # str.join avoids the quadratic cost of repeated string concatenation.
    return ''.join(token_vocab[token] for token in tokenized_seq)


def pad_to_length(seq, length, pad_value=0):
    """Return *seq* as a 1-D ``torch.long`` tensor of exactly *length* elements.

    Shorter sequences are right-padded with *pad_value*. Longer sequences
    are truncated to their first *length* elements, which drops any trailing
    ids such as an end-of-sequence marker.

    Args:
        seq: sequence of ints (or a 1-D tensor) to pad.
        length: target number of elements.
        pad_value: value written into padding positions; defaults to 0.

    Returns:
        A ``torch.long`` tensor of shape ``(length,)``.
    """
    # as_tensor avoids an extra copy / warning when seq is already a tensor.
    values = torch.as_tensor(seq, dtype=torch.long)[:length]
    padded_seq = torch.full((length,), fill_value=pad_value, dtype=torch.long)
    padded_seq[:len(values)] = values
    return padded_seq