# Provenance (hosting-page residue, kept as a comment so the file parses):
# riyadhrazzaq — "added inference scripts, model and vocab" — commit d541e5a
import os
import torch
# Prefer GPU when available; decompress() places its prefix tensors on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def stringify(array):
    """Render a list of token lists as text: tokens joined by spaces, rows by newlines."""
    lines = []
    for inner in array:
        lines.append(" ".join(inner))
    return "\n".join(lines)
def compress(text, tokenizer, model):
    """
    Replace every token the language model predicts correctly with "X".

    text: str. Each line represents a single document.
    tokenizer: Tokenizer; called on the list of lines, returns (indices, _),
        and exposes `vocab` (token -> index) and `unk_index`.
    model: callable returning logits of shape (batch, seq, vocab_size).
    """
    lines = text.split("\n")
    token_lists = [line.split() for line in lines]
    indices, _ = tokenizer(lines)
    logits = model(indices)
    predicted = logits.argmax(dim=2)
    # The offset slices skip the edge tokens: mask position j says whether
    # token j was predicted correctly from its left context.
    prediction_mask = indices[:, 1:] == predicted[:, :-1]
    for row, row_mask in enumerate(prediction_mask):
        n_tokens = len(token_lists[row])
        for col, hit in enumerate(row_mask):
            # `col < n_tokens` ignores pad positions; unknown-vocab tokens
            # are never replaced (they could not be reconstructed later).
            if hit and col < n_tokens and tokenizer.vocab[token_lists[row][col]] != tokenizer.unk_index:
                token_lists[row][col] = "X"
    return "\n".join(" ".join(tokens) for tokens in token_lists)
def decompress(text, tokenizer, model):
    """
    Restore a document produced by `compress`: each "X" placeholder is replaced
    by the model's prediction given the tokens to its left.

    text: str. Each line represents a single document.
    tokenizer: provides `vocab` — callable mapping a token list to an index
        list, with `lookup_token(index)` for the reverse direction.
    model: callable returning logits of shape (batch, seq, vocab_size).
    """
    sentence_tokens = [document.split() for document in text.split("\n")]
    # NOTE: the original also called `tokenizer(text.split("\n"))` here, but
    # the result was always overwritten before use — removed as dead code.
    uncompressed = []
    for sentence in sentence_tokens:
        # Re-seed the prefix for every sentence with the edge token, matching
        # the edge-token offset used by `compress`.
        prefix = ['<EDGE>']
        for token in sentence:
            if token != "X":
                prefix.append(token)
            else:
                # Only run inference when a placeholder is found.
                prefix_indices = torch.tensor([tokenizer.vocab(prefix)],
                                              dtype=torch.int,
                                              device=device)
                logits = model(prefix_indices)
                # The prediction for "X" sits at the last sequence position.
                next_index = logits[:, -1, :].argmax(dim=1)
                prefix.append(tokenizer.vocab.lookup_token(next_index))
        # Drop the edge token before emitting the sentence.
        uncompressed.append(prefix[1:])
    return stringify(uncompressed)
def load_from_checkpoint(model, checkpoint_path):
    """
    Load model weights from a checkpoint file and switch the model to eval mode.

    Parameters:
    ----------
    model: torch.nn.Module whose state dict is populated in place.
    checkpoint_path: The path to the checkpoint.

    Raises:
    ------
    FileNotFoundError: If no checkpoint is found in the provided path.
    """
    if not os.path.exists(checkpoint_path):
        # FileNotFoundError is a subclass of Exception, so callers catching
        # the old generic Exception keep working.
        raise FileNotFoundError("No checkpoint found in the provided path")
    # map_location='cpu' lets GPU-trained checkpoints load on any machine.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print("loaded existing model.")