File size: 4,138 Bytes
151b875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os,csv,time
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM
from tqdm import tqdm

from anticipation.config import EVENT_SIZE

def log_loss(model, datafile, subsample):
    """Compute per-token cross-entropy of `model` over a tokenized dataset.

    Args:
        model: a causal LM (HF-style: `model(ids).logits`) already on CUDA.
        datafile: path to a text file with one whitespace-separated token
            sequence per line.
        subsample: keep only every `subsample`-th line (lines where
            `i % subsample == 0`).

    Returns:
        A 1-D CPU tensor of per-token cross-entropy values, concatenated
        across all evaluated lines (empty tensor if no lines were kept).
    """
    losses = []
    with open(datafile, 'r') as data:
        for i, line in tqdm(list(enumerate(data))):
            if i % subsample != 0:
                continue

            tokens = [int(token) for token in line.split()]
            tokens = torch.tensor(tokens).unsqueeze(0).cuda()
            with torch.no_grad():
                logits = model(tokens).logits[0]
                # next-token objective: logits at position t predict token t+1
                losses.append(
                    F.cross_entropy(logits[:-1], tokens[0, 1:], reduction='none').cpu())

    # Concatenate once at the end: calling torch.cat inside the loop re-copies
    # the accumulated tensor every iteration (quadratic in total token count).
    return torch.cat(losses) if losses else torch.empty(0)


def main(args):
    """Evaluate log-loss for one or all checkpoints under args.model.

    Scans `args.model` for `step-<N>/hf` checkpoint directories, computes
    per-token cross-entropy on `args.filename` (sub-sampled), and writes a
    CSV of results to `<args.model>/<args.output>`.

    Raises:
        FileNotFoundError: if no `step-*` checkpoint directories are found.
    """
    print(f'Sub-sampling results at rate {args.subsample}')

    results = os.path.join(args.model, args.output)
    print(f'Storing results at {results}')

    # checkpoints are laid out as <model>/step-<N>/hf
    checkpoints = [os.path.join(f.path, 'hf') for f in os.scandir(args.model)
                   if f.is_dir() and f.name.startswith('step-')]
    if not checkpoints:
        # explicit error: otherwise max() below fails with a cryptic ValueError
        raise FileNotFoundError(f'no step-* checkpoint directories under {args.model}')

    if args.all:
        print('Calculating log-loss for checkpoints:')
        for ckpt in checkpoints:
            print('  ', ckpt)
    else:
        # keep only the checkpoint with the largest step number;
        # the step is the path component 'step-<N>' (strip the 5-char prefix)
        steps = [int(ckpt.split(os.sep)[-2][5:]) for ckpt in checkpoints]
        checkpoints = [os.path.join(args.model, f'step-{max(steps)}', 'hf')]
        print('Calculating log-loss for final checkpoint:')
        print('  ', checkpoints[0])

    print('Calculating log-loss on dataset:')
    print('  ', args.filename)
    with open(results, 'w', newline='') as f:
        fields = ['step', 'loss']
        if args.bpe:
            fields.append('bpe')
        if not args.interarrival:
            # arrival-time encoding packs events as (onset, duration, note)
            # triples, so per-field perplexities are well-defined
            fields.extend(['event_ppl', 'onset_ppl', 'dur_ppl', 'note_ppl'])

        writer = csv.DictWriter(f, fieldnames=fields)
        writer.writeheader()
        for ckpt in checkpoints:
            step = int(ckpt.split(os.sep)[-2][5:])
            print(f'Loading checkpoint (step {step}):')
            print('  ', ckpt)
            t0 = time.time()
            model = AutoModelForCausalLM.from_pretrained(ckpt).cuda()
            print(f'  loaded in {time.time()-t0} seconds')

            ce = log_loss(model, args.filename, args.subsample)

            res = {}
            res['step'] = step
            res['loss'] = np.round(ce.mean().item(), 3)
            if args.bpe:
                # hardcoding length of the LakhMidi test set in hours: 560.98
                assert os.path.basename(args.filename) == 'test.txt'
                # bits-per-event-hour: nats -> bits via log2(e), scaled by the
                # subsampling ratio and normalized by dataset length in seconds
                res['bpe'] = args.subsample*ce.mean().item()*np.log2(np.e)*(len(ce) / (560.98*3600))
            if not args.interarrival:
                res['event_ppl'] = np.round(np.exp(EVENT_SIZE*ce.mean().item()), 3)
                # strided slices pick out the three fields of each event triple
                res['onset_ppl'] = np.round(np.exp(ce[0::3].mean().item()), 3)
                res['dur_ppl'] = np.round(np.exp(ce[1::3].mean().item()), 3)
                res['note_ppl'] = np.round(np.exp(ce[2::3].mean().item()), 3)

            writer.writerow(res)


if __name__ == '__main__':
    parser = ArgumentParser(description='evaluate log-loss for a tokenized dataset')
    parser.add_argument('-f', '--filename', help='file containing a tokenized dataset')
    parser.add_argument('-m', '--model', help='file containing a model to evaluate')
    parser.add_argument('-o', '--output', help='output file')
    parser.add_argument('-v', '--verbose', action='store_true', help='verbose console output')
    parser.add_argument('-a', '--all', action='store_true',
            help='calculate loss for all checkpoints')
    # fixed: help text was a copy-paste of --all's description
    parser.add_argument('--bpe', action='store_true',
            help='report bits per event-hour (requires the LakhMidi test set)')
    parser.add_argument('-i', '--interarrival', action='store_true',
            help='request interarrival-time encoding (default to arrival-time encoding)')
    parser.add_argument('-s', '--subsample', type=int, default=10,
            help='dataset subsampling ratio')

    main(parser.parse_args())