File size: 1,744 Bytes
bf1f674
 
d9d8dae
bf1f674
 
 
 
d9d8dae
bf1f674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9d8dae
bf1f674
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from module import config, transformers_utility as tr, utils, metrics, dataio
from prettytable import PrettyTable
import numpy as np

# Filesystem locations, all taken from the project config module.
DATA_DIR = config.data
TOKENIZER_DIR = config.models / "byte-level-bpe-tokenizer"
PRETRAINED_MODEL = (
    config.models
    / "transformer"
    / "prediction-model"
    / "saved_model.pth"
)

# Shared output table; its column headers are taken from config.tissues.
# main() appends one row per prediction to this object.
table = PrettyTable()
table.field_names = config.tissues

def load_model(args, settings):
    """Load the model by delegating to ``tr.load_model``.

    Args:
        args: Object exposing ``model_name``, ``tokenizer_dir``,
            ``pretrained_model`` and ``log_offset`` attributes.
        settings: Mapping of extra keyword arguments forwarded unchanged to
            ``tr.load_model``.

    Returns:
        Whatever ``tr.load_model`` returns (``main`` unpacks it into three
        values: a config object, a tokenizer and a model).
    """
    loader = tr.load_model
    name = args.model_name
    tokenizer_location = args.tokenizer_dir
    checkpoint = args.pretrained_model
    offset = args.log_offset
    loaded = loader(
        name,
        tokenizer_location,
        pretrained_model=checkpoint,
        log_offset=offset,
        **settings,
    )
    return loaded

def main(TEST_DATA):
    """Run the pretrained prediction model on a data file and print a table.

    Builds the run arguments with ``utils.get_args`` (passing the same file
    as both training and test data), loads the tokenizer and model, loads the
    datasets, and prints one table row per prediction.

    Rows are appended to the module-level ``table``, so calling this function
    more than once accumulates rows from every call.

    Args:
        TEST_DATA: Input file name handed to ``utils.get_args`` as both
            ``train_data`` and ``test_data``. Presumably resolved relative
            to ``DATA_DIR`` -- TODO confirm in ``utils.get_args``.

    Returns:
        None. The table is printed to standard output.
    """
    args = utils.get_args(
        data_dir=DATA_DIR,
        train_data=TEST_DATA,
        test_data=TEST_DATA,
        pretrained_model=PRETRAINED_MODEL,
        tokenizer_dir=TOKENIZER_DIR,
        model_name="roberta-pred-mean-pool",
    )

    settings = utils.get_model_settings(config.settings, args)
    if args.output_mode:
        settings["output_mode"] = args.output_mode
    if args.tissue_subset is not None:
        settings["num_labels"] = len(args.tissue_subset)

    print("Loading model...")
    # Only the tokenizer and model are used below; the config object is not.
    _, tokenizer, model = load_model(args, settings)

    print("Loading data...")
    datasets = dataio.load_datasets(
        tokenizer,
        args.train_data,
        eval_data=args.eval_data,
        test_data=args.test_data,
        seq_key="text",
        file_type="text",
        filter_empty=args.filter_empty,
        shuffle=False,
    )
    # The input file was also supplied as train_data, so it is read back from
    # the "train" split here.
    # NOTE(review): presumably chosen together with shuffle=False to keep the
    # rows in file order -- confirm against dataio.load_datasets.
    dataset_test = datasets["train"]

    print("Getting predictions:")
    raw_preds = np.array(metrics.get_predictions(model, dataset_test))
    # Undo the log transform: expm1(x) == exp(x) - 1, but without the
    # precision loss that the explicit subtraction has for values near zero.
    # NOTE(review): this hard-codes an offset of 1 regardless of
    # args.log_offset -- confirm the model was trained with log_offset == 1.
    preds = np.expm1(raw_preds)
    for row in preds:
        table.add_row(row)
    print(table)

# Script entry point: when executed directly, run predictions on "test.txt".
if __name__ == "__main__":
    main("test.txt")