File size: 5,566 Bytes
06257c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import gc
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from chestxray14 import ChestXray14Dataset
from chexpert import CheXpertDataset
from descriptors import disease_descriptors_chexpert, disease_descriptors_chestxray14
from model import InferenceModel
from utils import calculate_auroc

# Select PyTorch's 'file_system' strategy for sharing CPU tensors between processes.
# NOTE(review): presumably chosen to avoid "too many open files" errors when
# DataLoader worker processes hand tensors back to the main process — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')


def inference_chexpert(split='test'):
    """Evaluate the descriptor-based classifier on a CheXpert split and print AUROC.

    Loads ``data/chexpert/{split}_labels.csv``. For every dataset item, each of
    its images is scored against all descriptors; the descriptor probabilities
    are mean-aggregated over the item's images, turned into per-disease
    probabilities, and collected. Classes without a single positive label are
    excluded before the overall and per-disease AUROC are printed.

    Args:
        split: Name inserted into the labels CSV path. Defaults to ``'test'``,
            the split this function was previously hard-wired to.

    Raises:
        RuntimeError: If the dataset yields no items, so there is nothing to
            evaluate.
    """
    dataset = CheXpertDataset(f'data/chexpert/{split}_labels.csv')
    # batch_size=1 with an identity collate_fn: every batch is a one-element
    # list holding the raw dataset item.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x, num_workers=0)
    inference_model = InferenceModel()
    all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chexpert)

    all_labels = []
    all_probs_neg = []
    disease_keys = None  # class names reported by the dataset; kept from the last item

    for batch in tqdm(dataloader):
        image_paths, labels, disease_keys = batch[0]

        # Score every image belonging to this item.
        per_image_probs = []
        per_image_negative_probs = []
        for image_path in image_paths:
            probs, negative_probs = inference_model.get_descriptor_probs(Path(image_path), descriptors=all_descriptors)
            per_image_probs.append(probs)
            per_image_negative_probs.append(negative_probs)

        # Mean aggregation of the descriptor probabilities across the item's images.
        probs = {
            key: sum(p[key] for p in per_image_probs) / len(per_image_probs)
            for key in per_image_probs[0].keys()
        }
        negative_probs = {
            key: sum(p[key] for p in per_image_negative_probs) / len(per_image_negative_probs)
            for key in per_image_negative_probs[0].keys()
        }

        disease_probs, negative_disease_probs = inference_model.get_diseases_probs(
            disease_descriptors_chexpert, pos_probs=probs, negative_probs=negative_probs)
        # Only the probability vector is needed for AUROC; the predicted
        # disease list returned alongside it is discarded.
        _, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(
            disease_descriptors_chexpert,
            disease_probs=disease_probs,
            negative_disease_probs=negative_disease_probs,
            keys=disease_keys)
        all_labels.append(labels)
        all_probs_neg.append(prob_vector_neg_prompt)

    if not all_labels:
        # torch.stack would fail on an empty list with a far less helpful message.
        raise RuntimeError(f"CheXpert split '{split}' yielded no items; nothing to evaluate")

    all_labels = torch.stack(all_labels)
    all_probs_neg = torch.stack(all_probs_neg)

    # evaluation
    # Keep only classes with at least one positive sample. A tensor reduction
    # replaces the Python-level row-by-row builtin sum.
    existing_mask = all_labels.sum(dim=0) > 0
    all_labels_clean = all_labels[:, existing_mask]
    all_probs_neg_clean = all_probs_neg[:, existing_mask]
    all_keys_clean = [key for idx, key in enumerate(disease_keys) if existing_mask[idx]]

    overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean, all_labels_clean)
    print(f"AUROC: {overall_auroc:.5f}\n")
    for idx, key in enumerate(all_keys_clean):
        print(f'{key}: {per_disease_auroc[idx]:.5f}')

def _identity_collate(batch):
    """Return the list of samples unchanged instead of collating them.

    Defined at module level (rather than as a lambda) so it can be pickled
    when DataLoader worker processes are started with the spawn method.
    """
    return batch


def inference_chestxray14():
    """Evaluate the descriptor-based classifier on ChestX-ray14 and print AUROC.

    Loads ``data/chestxray14/Data_Entry_2017_v2020_modified.csv``, scores each
    image against all descriptors, converts the descriptor probabilities into
    per-disease probabilities and collects them. Classes without a single
    positive label are excluded, the first remaining class is skipped, and the
    overall and per-disease AUROC of the rest are printed.

    Raises:
        RuntimeError: If the dataset yields no items, so there is nothing to
            evaluate.
    """
    dataset = ChestXray14Dataset('data/chestxray14/Data_Entry_2017_v2020_modified.csv')
    # batch_size=1 with an identity collate: every batch is a one-element list
    # holding the raw dataset item.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=_identity_collate, num_workers=1)
    inference_model = InferenceModel()
    all_descriptors = inference_model.get_all_descriptors(disease_descriptors_chestxray14)

    all_labels = []
    all_probs_neg = []
    disease_keys = None  # class names reported by the dataset; kept from the last item

    for batch in tqdm(dataloader):
        image_path, labels, disease_keys = batch[0]
        probs, negative_probs = inference_model.get_descriptor_probs(Path(image_path), descriptors=all_descriptors)
        disease_probs, negative_disease_probs = inference_model.get_diseases_probs(
            disease_descriptors_chestxray14, pos_probs=probs, negative_probs=negative_probs)
        # Only the probability vector is needed for AUROC; the predicted
        # disease list returned alongside it is discarded.
        _, prob_vector_neg_prompt = inference_model.get_predictions_bin_prompting(
            disease_descriptors_chestxray14,
            disease_probs=disease_probs,
            negative_disease_probs=negative_disease_probs,
            keys=disease_keys)
        all_labels.append(labels)
        all_probs_neg.append(prob_vector_neg_prompt)
        gc.collect()

    if not all_labels:
        # torch.stack would fail on an empty list with a far less helpful message.
        raise RuntimeError('ChestX-ray14 dataset yielded no items; nothing to evaluate')

    all_labels = torch.stack(all_labels)
    all_probs_neg = torch.stack(all_probs_neg)

    # Keep only classes with at least one positive sample. A tensor reduction
    # replaces the Python-level row-by-row builtin sum.
    existing_mask = all_labels.sum(dim=0) > 0
    all_labels_clean = all_labels[:, existing_mask]
    all_probs_neg_clean = all_probs_neg[:, existing_mask]
    all_keys_clean = [key for idx, key in enumerate(disease_keys) if existing_mask[idx]]

    # The first remaining class is left out of the AUROC computation and report.
    # NOTE(review): presumably column 0 is a "no finding" class — confirm. If
    # that class ever had no positives, the mask above would already have
    # removed it and this slice would drop a real disease instead.
    overall_auroc, per_disease_auroc = calculate_auroc(all_probs_neg_clean[:, 1:], all_labels_clean[:, 1:])
    print(f"AUROC: {overall_auroc:.5f}\n")
    for idx, key in enumerate(all_keys_clean[1:]):
        print(f'{key}: {per_disease_auroc[idx]:.5f}')


if __name__ == '__main__':
    # Map every supported --dataset value to its evaluation entry point.
    runners = {
        'chexpert': inference_chexpert,
        'chestxray14': inference_chestxray14,
    }

    parser = argparse.ArgumentParser()
    # `choices` makes argparse reject an unknown dataset name with a usage
    # error; previously such a value fell through both branches and the script
    # exited without running anything.
    parser.add_argument('--dataset', type=str, default='chexpert', choices=list(runners),
                        help='chexpert or chestxray14')
    args = parser.parse_args()

    runners[args.dataset]()