File size: 2,804 Bytes
d541e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import torch

# Default device for tensors built in this module: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def stringify(array):
    """
    Join a nested sequence of tokens into multi-line text.

    array: iterable of token sequences (each a sequence of str).
        Each inner sequence becomes one line whose tokens are separated by a
        single space; the lines are separated by newline characters.
    """
    lines = (" ".join(tokens) for tokens in array)
    return "\n".join(lines)


def compress(text, tokenizer, model):
    """
    Replace every token the model predicts correctly with the placeholder "X".

    tokenizer: Tokenizer.
        Called on the list of lines; its first return value is fed to the
        model. Its ``vocab`` (indexed by token) and ``unk_index`` are used to
        leave out-of-vocabulary tokens untouched.
    model: callable mapping the tokenizer's indices to next-token logits;
        dim 2 of the result is reduced with ``argmax``.
    text: str.
        Each line represents a single document.

    Returns the compressed text. There is one output line per input line.
    Whitespace within a line is collapsed to single spaces, and every
    correctly predicted in-vocabulary token is replaced with "X".
    """
    lines = text.split("\n")
    tokens = [line.split() for line in lines]
    indices, _ = tokenizer(lines)

    # Inference only: don't build an autograd graph for the whole batch.
    with torch.no_grad():
        logits = model(indices)
        next_token_predicted = logits.argmax(dim=2)

    # Slices skip the edge token: the prediction made at position j is
    # compared with the actual token at position j + 1. This assumes the
    # tokenizer prepends exactly one edge token per line -- TODO confirm.
    # .tolist() converts once instead of reading one tensor element per check.
    prediction_mask = (indices[:, 1:] == next_token_predicted[:, :-1]).tolist()

    # Replace correctly predicted tokens with "X".
    # NOTE(review): a literal "X" in the input is not escaped, so decompress()
    # will treat it as a placeholder and may not restore it.
    # NOTE(review): a lossless round trip assumes decompress()'s one-sequence,
    # unpadded predictions match these batched ones (causal model) -- confirm.
    for sentence, sentence_mask in zip(tokens, prediction_mask):
        # Truncating the mask to the sentence length ignores pad positions.
        for j, predicted_successfully in enumerate(sentence_mask[:len(sentence)]):
            # Unknown tokens are kept verbatim; a prediction of the unknown
            # index could not reproduce their original spelling.
            if predicted_successfully and tokenizer.vocab[sentence[j]] != tokenizer.unk_index:
                sentence[j] = "X"

    return "\n".join(" ".join(sentence) for sentence in tokens)


def decompress(text, tokenizer, model):
    """
    Fill every "X" placeholder with the model's greedy next-token prediction.

    tokenizer: Tokenizer.
        Only ``tokenizer.vocab`` is used. Calling it on a list of tokens gives
        their indices, and ``lookup_token`` maps a predicted index back to a
        token.
    model: callable mapping a (1, prefix_length) index tensor to logits; the
        last position's logits (``logits[:, -1, :]``) are used.
    text: str.
        Each line represents a single document.

    Returns the reconstructed text, one line per input line, with tokens
    joined by single spaces.
    """
    uncompressed = []
    for line in text.split("\n"):
        # Each line is decoded independently, starting from the edge marker.
        prefix = ['<EDGE>']
        for token in line.split():
            if token != "X":
                prefix.append(token)
                continue

            # Only run the model when a placeholder needs to be filled.
            indices = torch.tensor([tokenizer.vocab(prefix)],
                                   dtype=torch.int,
                                   device=device)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                logits = model(indices)
            # Prediction for the position right after the current prefix.
            # .item() turns the 1-element tensor into the int lookup_token expects.
            index = logits[:, -1, :].argmax(dim=1).item()
            prefix.append(tokenizer.vocab.lookup_token(index))

        # Drop the edge marker and keep this line's decoded tokens.
        uncompressed.append(prefix[1:])

    return stringify(uncompressed)


def load_from_checkpoint(model, checkpoint_path):
    """
    Loads a model from a checkpoint.

    Restores ``model`` in place from the ``'model_state_dict'`` entry of the
    checkpoint file, then switches it to evaluation mode.

    Parameters:
    ----------
    model: The ``torch.nn.Module`` whose state is overwritten.
    checkpoint_path: The path to the checkpoint.

    Raises:
    ------
    FileNotFoundError: If no checkpoint is found in the provided path.
        It subclasses ``Exception``, so existing ``except Exception``
        handlers still catch it.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError("No checkpoint found in the provided path")

    # Deserialize onto the CPU; load_state_dict then copies the tensors into
    # the model's existing parameters.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints, or consider weights_only=True on
    # PyTorch versions that support it.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("loaded existing model.")