File size: 5,548 Bytes
ef814bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
from utils.helpers import create_multimodal_model
from models import SingleTransformer
from utils.helpers import get_all_modalities_available_samples
from data import create_dataset
import shap

def filter_ds(dataset, indices):
    """
    Build a new MultiModalDataset containing only the rows at ``indices``.
    Args:
        dataset: object exposing rna_data, atac_data, flux_data, batch_no
            and labels, each indexable by ``indices``
        indices: row selector applied identically to every field
    Returns:
        A MultiModalDataset built from the selected rows
    """
    # Index the three modalities in the order the dataset constructor expects.
    modalities = tuple(
        getattr(dataset, field)[indices]
        for field in ("rna_data", "atac_data", "flux_data")
    )
    batch_subset = dataset.batch_no[indices]
    label_subset = dataset.labels[indices]
    return create_dataset.MultiModalDataset(modalities, batch_subset, label_subset)

def get_background_data(id, dataset, samples=100, return_other_samples=False):
    """
    Get background data with balanced samples from each label
    Args:
        id: One of 'RNA', 'ATAC', 'Flux', 'Multi'. Only 'Multi' is implemented.
        dataset: MultiModalDataset object
        samples: Total number of background samples to get; it is split
            evenly (integer division) across the unique labels
        return_other_samples: If True, return other samples as well
    Returns:
        new_ds: MultiModalDataset object with background samples
        background_indices: Indices of background samples
        other_ds: MultiModalDataset object with other samples
            (only when return_other_samples is True)
        other_indices: Indices of other samples
            (only when return_other_samples is True)
        All indices refer to the dataset returned by
        get_all_modalities_available_samples, not to the dataset passed in.
    Raises:
        ValueError: If id is not a known modality, or is a known modality
            other than 'Multi' (not implemented).
    """
    if id not in ['RNA', 'ATAC', 'Flux', 'Multi']:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    if id != 'Multi':
        raise ValueError("Not Implemented")

    dataset = get_all_modalities_available_samples(dataset)
    labels = dataset.labels

    # get a balance of samples between labels: the first `samples_per_label`
    # occurrences of every label, in sorted-label order
    unique_labels = torch.unique(labels)
    samples_per_label = samples // len(unique_labels)
    # torch.cat keeps an integer dtype even when nothing is selected, so the
    # result is always usable as an index tensor.
    background_indices = torch.cat(
        [torch.where(labels == label)[0][:samples_per_label] for label in unique_labels]
    ).cpu()
    bg_ds = filter_ds(dataset, background_indices)

    if not return_other_samples:
        return bg_ds, background_indices

    # Everything that is not background, found with a boolean mask instead of
    # a per-element membership scan over the index tensor.
    is_other = torch.ones(len(labels), dtype=torch.bool)
    is_other[background_indices] = False
    other_indices = torch.where(is_other)[0]
    other_ds = filter_ds(dataset, other_indices)
    return bg_ds, background_indices, other_ds, other_indices
    
class ShapWrapper(torch.nn.Module):
    """
    Adapt a multimodal model to the single-tensor input a SHAP explainer uses.

    The wrapper receives one 2-D tensor whose columns are laid out as
    ``[rna | atac | flux | batch]`` (a single trailing batch column), splits
    it back into the ``(rna, atac, flux)`` tuple plus batch indices, calls the
    wrapped model and returns the sigmoid of its first output.

    Args:
        model: module called as ``model((rna, atac, flux), batch)`` and
            returning a 2-tuple; it is switched to eval mode on construction
        rna_dim: number of leading columns that belong to the RNA modality
        atac_dim: number of columns, after RNA, that belong to ATAC
    """

    def __init__(self, model, rna_dim=944, atac_dim=883):
        super().__init__()
        self.model = model
        self.model.eval()
        self.rna_dim = rna_dim
        self.atac_dim = atac_dim

    def forward(self, x):
        # Exactly one batch column is appended to the features, so strip one
        # column only; stripping two would discard the last flux feature.
        # The batch column stays 1-D even for a single-row input.
        features, b = x[:, :-1], x[:, -1].long()
        rna_end = self.rna_dim
        atac_end = rna_end + self.atac_dim
        inputs = (
            features[:, :rna_end].long(),
            features[:, rna_end:atac_end].float(),
            features[:, atac_end:].float(),  # everything left over is flux
        )
        preds, _ = self.model(inputs, b)
        # presumably the model emits logits; sigmoid maps them to (0, 1)
        return torch.sigmoid(preds)
    
def compute_shap_values(id, fold_results, dataset, model_config, device):
    """
    Compute SHAP values for the validation samples of every fold.

    A balanced background set is drawn once with get_background_data; for each
    fold the fold's best checkpoint is loaded, wrapped in ShapWrapper and
    explained with shap.GradientExplainer on the validation samples that are
    not part of the background set.

    Args:
        id: One of 'RNA', 'ATAC', 'Flux', 'Multi'. Only 'Multi' is implemented.
        fold_results: iterable of dicts with keys 'fold', 'val_idx' and
            'best_model_path'
        dataset: MultiModalDataset object
        model_config: configuration passed to create_multimodal_model
        device: device the model and tensors are moved to
    Returns:
        List with one explainer result per fold that had samples to explain
    Raises:
        ValueError: If id is not a known modality, or is a known modality
            other than 'Multi' (not implemented).
    """
    if id not in ['RNA', 'ATAC', 'Flux', 'Multi']:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    # Background sampling and the rna|atac|flux packing below exist only for
    # the multimodal case; fail clearly instead of hitting undefined names.
    if id != 'Multi':
        raise ValueError("Not Implemented")

    bg_ds, bg_idx, _, other_idx = get_background_data(id, dataset, samples=50, return_other_samples=True)
    print("total background samples: ", len(bg_idx), "total test samples: ", len(other_idx))

    # Set membership is O(1) per lookup; `in` on a tensor scans it each time.
    explainable = set(other_idx.tolist())

    # The background batch is identical for every fold, so build it once.
    bg_x = torch.cat([bg_ds.rna_data, bg_ds.atac_data, bg_ds.flux_data], dim=1).to(device)
    bg_b = bg_ds.batch_no.to(device)
    bgx = torch.cat([bg_x, bg_b[..., None]], dim=-1)

    all_shap_values = []
    for fold in fold_results:
        # NOTE(review): other_idx indexes the dataset returned by
        # get_all_modalities_available_samples, while val_idx is used below to
        # index `dataset` as passed in - confirm both index spaces agree.
        val_idx = [i for i in fold['val_idx'] if int(i) in explainable]

        if not val_idx:
            print('No samples of the specified type in the validation set. Skipping...')
            continue
        print(f'fold {fold["fold"]} -> {len(val_idx)} samples')

        val_ds = filter_ds(dataset, val_idx)

        model = create_multimodal_model(model_config, device, use_mlm=False)
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(fold['best_model_path'], map_location=device))
        model.eval()
        wrapped_model = ShapWrapper(model).to(device)
        explainer = shap.GradientExplainer(wrapped_model, bgx)

        # Pack the validation samples in the same column layout as bgx.
        val_x = torch.cat([val_ds.rna_data, val_ds.atac_data, val_ds.flux_data], dim=1).to(device)
        val_b = val_ds.batch_no.to(device)
        shap_values = explainer(torch.cat([val_x, val_b[..., None]], dim=-1))
        all_shap_values.append(shap_values)

    return all_shap_values