# Token-level text compression/decompression driven by a next-token language model.
import os
import torch

# Run on the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def stringify(array):
    """Render a list of token lists as text: tokens joined by spaces, sentences by newlines."""
    sentences = []
    for token_list in array:
        sentences.append(" ".join(token_list))
    return "\n".join(sentences)
def compress(text, tokenizer, model):
    """
    Replace every token the model can predict from its prefix with "X".

    tokenizer: Tokenizer. Called on a list of sentences; returns an index
        tensor (plus a second value, ignored here). Also exposes `vocab`
        (token -> index mapping) and `unk_index`.
    text: str.
    Each line represents a single document.
    model: returns next-token logits -- assumed shape (batch, seq, vocab)
        given the argmax over dim=2; confirm against the model definition.
    Returns the compressed text, one sentence per line.
    """
    tokens = [sentence.split() for sentence in text.split("\n")]
    indices, _ = tokenizer(text.split("\n"))
    logits = model(indices)
    # greedy decoding: the model's single most likely next token at each position
    next_token_predicted = logits.argmax(dim=2)
    # slices are for skipping edge tokens
    # NOTE(review): mask position j marks whether indices[:, j+1] was predicted
    # from its prefix; using it to blank tokens[i][j] assumes the tokenizer
    # prepends one edge token, so word j maps to index position j+1 -- confirm
    # against the tokenizer implementation.
    prediction_mask = indices[:, 1:] == next_token_predicted[:, :-1]
    # replace correctly predicted tokens with "X"
    for i, sentence_mask in enumerate(prediction_mask):
        sentence_len = len(tokens[i])
        for j, predicted_successfully in enumerate(sentence_mask):
            # length check is to ignore pad tokens; OOV tokens are kept verbatim
            # because the model could not reproduce them at decompression time
            if predicted_successfully and j < sentence_len and tokenizer.vocab[tokens[i][j]] != tokenizer.unk_index:
                tokens[i][j] = "X"
    sentences = [" ".join(sentence) for sentence in tokens]
    document = "\n".join(sentences)
    return document
def decompress(text, tokenizer, model):
    """
    Reconstruct a document compressed by compress(): re-predict every "X".

    text: str.
    Each line represents a single document; tokens that were dropped during
    compression appear as the placeholder "X".
    tokenizer: exposes `vocab` (callable token-list -> index-list mapping,
        with `lookup_token` for the reverse direction).
    model: language model returning next-token logits.
    Returns the reconstructed text, one sentence per line.
    """
    sentence_tokens = [document.split() for document in text.split("\n")]
    uncompressed = []
    for sentence in sentence_tokens:
        # prefix mirrors the edge token the tokenizer prepends, so the model
        # sees the same conditioning context it was trained with
        prefix = ['<EDGE>']
        for token in sentence:
            if token != "X":
                prefix.append(token)
            else:
                # only infer when X is found
                indices = torch.tensor([tokenizer.vocab(prefix)],
                                       dtype=torch.int,
                                       device=device)
                logits = model(indices)
                # prediction logit for X: the logits at the last prefix position
                logit = logits[:, -1, :]
                index = logit.argmax(dim=1)
                # .item() unwraps the 1-element tensor into the plain int that
                # vocab.lookup_token expects
                prefix.append(tokenizer.vocab.lookup_token(index.item()))
        # drop the leading edge token before emitting the sentence
        uncompressed.append(prefix[1:])
    return stringify(uncompressed)
def load_from_checkpoint(model, checkpoint_path):
    """
    Loads a model from a checkpoint and switches it to eval mode.

    Parameters:
    ----------
    model: torch module whose weights are populated in place.
    checkpoint_path: The path to the checkpoint.

    Raises:
    ------
    FileNotFoundError: If no checkpoint is found in the provided path.
        (Subclass of Exception, so callers catching the old generic
        Exception still work.)
    """
    # guard clause keeps the happy path unindented
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError("No checkpoint found in the provided path")
    # map to CPU so checkpoints saved on GPU load on CPU-only machines
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # was an f-string with no placeholders
    print("loaded existing model.")