File size: 1,013 Bytes
d541e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161e02f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from src.model import Model2
from src.tokenizer import Tokenizer
from src.util import *


def evaluate(args):
    """Run compression or decompression on ``args.text`` with a pretrained model.

    Loads the vocabulary from ``args.vocab`` and the model weights from
    ``args.checkpoint`` (both on CPU), then prints the (de)compressed text.

    NOTE(review): this definition is shadowed by the second ``evaluate``
    defined later in this file, so it is unreachable after import —
    consider renaming one of the two.
    """
    vocabulary = torch.load(args.vocab, map_location=torch.device('cpu'))
    net = Model2(len(vocabulary), 300, 256, vocabulary['<PAD>'])
    load_from_checkpoint(net, args.checkpoint)

    print()
    tokenizer = Tokenizer(vocabulary)
    transform = decompress if args.decompress else compress
    print(transform(args.text, tokenizer, net))


def evaluate(text, compression=True, vocab_path="vocab.pt",
             checkpoint_path="model_lr0.0001_bs256_epoch50.pt"):
    """Compress or decompress *text* with a pretrained model.

    NOTE(review): this redefines (shadows) the CLI ``evaluate(args)`` earlier
    in this file — consider renaming one of the two.

    Args:
        text: Input string to process.
        compression: If True (default) compress *text*, otherwise decompress it.
        vocab_path: Path to the torch-saved vocabulary (was hard-coded).
        checkpoint_path: Path to the model checkpoint (was hard-coded).

    Returns:
        Tuple of ``(result, message)`` where *message* is a human-readable
        compression ratio such as ``"12.34% compressed"``.
    """
    vocab = torch.load(vocab_path, map_location=torch.device('cpu'))
    model = Model2(len(vocab), 300, 256, vocab['<PAD>'])
    load_from_checkpoint(model, checkpoint_path)

    tokenizer = Tokenizer(vocab)
    if compression:
        result = compress(text, tokenizer, model)
    else:
        result = decompress(text, tokenizer, model)

    # Guard: empty input would raise ZeroDivisionError in the ratio below.
    if not text:
        return result, "0.00% compressed"

    # Ratio from string lengths; format to 2 decimals instead of emitting
    # the raw float repr (e.g. "23.999999999999996% compressed").
    compression_ratio = (1 - (len(result) / len(text))) * 100
    return result, f"{compression_ratio:.2f}% compressed"