File size: 11,210 Bytes
ef814bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import torch
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from models import SingleTransformer
from utils.helpers import create_multimodal_model
from data.create_dataset import MultiModalDataset
from .attentions import filter_idx

def get_latent_space(id, fold_results, labelled_dataset, 
            model_config, device, batch_size=32, common_samples=True):
    """Collect latent embeddings from each fold's best checkpoint on its validation split.

    Each fold's model is built and loaded lazily, right before it is used, and
    inference is delegated to ``get_latent_space_cached`` so this entry point and
    the preloaded-model path share a single batching / device-transfer loop.

    Args:
        id: one of 'RNA', 'ATAC', 'Flux', 'Multi'. 'Multi' builds the model via
            create_multimodal_model; the others build SingleTransformer(id=id, ...).
        fold_results: iterable of dicts with 'best_model_path' and 'val_idx'.
        labelled_dataset: dataset indexed by 'val_idx'; batches unpack as (x, b, y).
        model_config: configuration used to build each model.
        device: device the models and inputs are moved to.
        batch_size: DataLoader batch size.
        common_samples: if True, 'val_idx' is first passed through filter_idx.

    Returns:
        (latent_space, labels, preds) numpy arrays concatenated over folds in
        iteration order; preds are rounded with np.round.

    Raises:
        ValueError: if ``id`` is not an accepted value or ``fold_results`` is empty.
    """
    if id not in ['RNA', 'ATAC', 'Flux', 'Multi']:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    # Materialise once: the folds are iterated both for loading and for inference.
    folds = list(fold_results)
    if not folds:
        raise ValueError("fold_results must contain at least one fold")

    def _load_fold_models():
        # Generator: each model is created only when get_latent_space_cached asks
        # for the next fold, so the per-fold models are not all held at once.
        for fold in folds:
            if id == 'Multi':
                model = create_multimodal_model(model_config, device, use_mlm=False)
            else:
                model = SingleTransformer(id=id, **model_config).to(device)
            # Load weights to CPU first, then move to target device (handles CUDA->MPS/CPU transfer)
            state_dict = torch.load(fold['best_model_path'], map_location='cpu')
            model.load_state_dict(state_dict)
            yield model.to(device)

    return get_latent_space_cached(_load_fold_models(), folds, labelled_dataset, device,
                                   batch_size=batch_size, common_samples=common_samples)

def get_latent_space_cached(models, fold_results, dataset, device, batch_size=64, common_samples=True):
    """
    Compute latent space using preloaded models.

    Each model is paired (in zip order) with its fold, switched to eval mode and
    run under ``torch.no_grad()`` over that fold's validation subset.

    Args:
        models: per-fold models exposing ``get_latent_space(x, b) -> (latent, pred)``.
        fold_results: per-fold dicts providing 'val_idx'.
        dataset: dataset indexed by 'val_idx'; batches unpack as (x, b, y).
        device: device inputs are moved to.
        batch_size: DataLoader batch size.
        common_samples: if True, 'val_idx' is first passed through filter_idx.

    Returns:
        (latent_space, labels, preds) numpy arrays concatenated across folds in
        iteration order; preds are passed through np.round.
    """
    latent_chunks, label_chunks, pred_chunks = [], [], []
    for fold_model, fold in zip(models, fold_results):
        fold_idx = fold['val_idx']
        if common_samples:
            fold_idx = filter_idx(dataset, fold_idx)
        # Larger batches speed up inference; order is kept so labels stay aligned.
        loader = DataLoader(Subset(dataset, fold_idx), batch_size=batch_size, shuffle=False)
        fold_model.eval()
        with torch.no_grad():
            for inputs, batch_ids, targets in loader:
                if not isinstance(inputs, list):
                    inputs = inputs.to(device)
                else:
                    # Multimodal batch: move the three modalities individually.
                    inputs = (inputs[0].to(device), inputs[1].to(device), inputs[2].to(device))
                latent, pred = fold_model.get_latent_space(inputs, batch_ids.to(device))
                latent_chunks.append(latent.cpu().numpy())
                label_chunks.append(targets.numpy())
                pred_chunks.append(pred.cpu().numpy())
    return (
        np.concatenate(latent_chunks),
        np.concatenate(label_chunks),
        np.round(np.concatenate(pred_chunks)),
    )

def measure_shift(original_latent, perturbed_latent):
    """Return the mean per-row Euclidean distance between two latent matrices.

    Args:
        original_latent: array-like of shape (n, d).
        perturbed_latent: array-like with exactly the same shape.

    Returns:
        Mean over rows of ``||original[i] - perturbed[i]||_2``.

    Raises:
        ValueError: if the shapes differ. Without this check numpy broadcasting
            would silently compare e.g. a (1, d) matrix against every row of an
            (n, d) one and return a meaningless number.
    """
    original = np.asarray(original_latent)
    perturbed = np.asarray(perturbed_latent)
    if original.shape != perturbed.shape:
        raise ValueError(
            f"latent shapes differ: {original.shape} vs {perturbed.shape}; "
            "rows must correspond one-to-one"
        )
    return np.mean(np.linalg.norm(original - perturbed, axis=1))

def perturb_feature(data, feature_idx, perturbation_type='additive', scale=0.1, min_samples_threshold=10):
    """Return a copy of ``data`` with one feature column perturbed.

    Only rows where the feature is non-zero are modified, except for
    'shuffle_all', which permutes the whole column (zeros included).

    Args:
        data: 2-D tensor; ``feature_idx`` selects a column. Not modified.
        feature_idx: index of the column to perturb.
        perturbation_type: one of
            'shuffle'        - permute the non-zero values among the non-zero rows;
            'shuffle_all'    - permute the entire column;
            'additive'       - add Gaussian noise with std ``scale * std(column)``;
            'multiplicative' - multiply by a factor drawn uniformly from
                               [1 - scale/2, 1 + scale/2).
        scale: noise magnitude for 'additive' and 'multiplicative'.
        min_samples_threshold: minimum number of non-zero entries required.

    Returns:
        ``(perturbed_data, False)`` on success, or ``(None, True)`` when the column
        has fewer than ``min_samples_threshold`` non-zero entries.

    Raises:
        ValueError: if ``perturbation_type`` is not one of the values above.
    """
    valid_types = ('shuffle', 'shuffle_all', 'additive', 'multiplicative')
    if perturbation_type not in valid_types:
        # Previously an unknown type returned an unperturbed copy flagged as valid,
        # silently scoring every feature as unimportant.
        raise ValueError(f"perturbation_type must be one of {valid_types}, got {perturbation_type!r}")

    perturbed_data = data.clone()
    non_zero_rows_mask = data[:, feature_idx] != 0

    # Check if feature has enough non-zero samples
    if non_zero_rows_mask.sum() < min_samples_threshold:
        return None, True  # Return None and flag indicating insufficient samples

    # Any non-floating dtype (int32, int64, ...) needs an explicit cast back:
    # in-place float arithmetic on an integer tensor raises in PyTorch.
    is_integer = not perturbed_data.is_floating_point()

    if perturbation_type == 'shuffle':
        # Shuffle only non-zero values (preserves sparsity pattern)
        non_zero_values = perturbed_data[non_zero_rows_mask, feature_idx].clone()
        shuffled_idx = torch.randperm(non_zero_values.size(0), device=perturbed_data.device)
        perturbed_data[non_zero_rows_mask, feature_idx] = non_zero_values[shuffled_idx]

    elif perturbation_type == 'shuffle_all':
        # Shuffle all values (including zeros)
        shuffled_idx = torch.randperm(perturbed_data.size(0), device=perturbed_data.device)
        perturbed_data[:, feature_idx] = data[shuffled_idx, feature_idx]

    elif perturbation_type == 'additive':
        column = perturbed_data[:, feature_idx].float()
        # std is taken over the whole column (zeros included); noise is only applied to non-zero rows.
        noise = torch.randn_like(column) * scale * torch.std(column)
        delta = noise[non_zero_rows_mask]
        if is_integer:
            # Truncates toward zero, so noise with |value| < 1 leaves an entry unchanged.
            delta = delta.to(perturbed_data.dtype)
        perturbed_data[non_zero_rows_mask, feature_idx] += delta

    else:  # 'multiplicative'
        factor = 1 + scale * (torch.rand(perturbed_data.shape[0], device=perturbed_data.device) - 0.5)
        row_factor = factor[non_zero_rows_mask]
        if is_integer:
            scaled = perturbed_data[non_zero_rows_mask, feature_idx].float() * row_factor
            perturbed_data[non_zero_rows_mask, feature_idx] = scaled.to(perturbed_data.dtype)
        else:
            perturbed_data[non_zero_rows_mask, feature_idx] *= row_factor

    return perturbed_data, False  # Return perturbed data and flag indicating sufficient samples

def analyze_feature_importance_multi(id, model_config, fold_results, dataset, feature_names, 
            device, analyse_features='all', perturbation_scale=0.1, min_samples_threshold=10, common_samples=True):
    """Rank input features by how much perturbing each one moves the latent space.

    Every fold's best checkpoint is loaded once and reused. The latent space of
    the unperturbed dataset is computed first; then each feature column of the
    selected modalities is perturbed in turn and the mean per-sample Euclidean
    shift of the latent space (``measure_shift``) is recorded. RNA and ATAC
    columns are permuted among their non-zero entries ('shuffle'); Flux columns
    are permuted across all rows, zeros included ('shuffle_all').

    Args:
        id: 'Multi' builds the model via create_multimodal_model; any other
            value is passed to SingleTransformer(id=id, ...).
        model_config: configuration used to build each model.
        fold_results: re-iterable collection of dicts with 'best_model_path' and
            'val_idx' (iterated once for loading and again for every inference pass).
        dataset: provides rna_data, atac_data, flux_data, batch_no and labels.
        feature_names: names indexed as RNA from 0, ATAC from ``atac_start`` and
            Flux from ``flux_start`` (see the offsets below).
        device: device for models and inputs.
        analyse_features: 'all', 'RNA', 'ATAC' or 'Flux'.
        perturbation_scale: forwarded to perturb_feature as ``scale``.
        min_samples_threshold: features with fewer non-zero entries are skipped,
            reported, and given importance 0.0.
        common_samples: forwarded to get_latent_space_cached.

    Returns:
        List of (feature_name, shift) tuples sorted by shift, largest first.

    Raises:
        ValueError: if ``analyse_features`` is not one of the accepted values.
    """
    if analyse_features not in ['all', 'RNA', 'ATAC', 'Flux']:
        raise ValueError("analyse_features must be one of 'all', 'RNA', 'ATAC', 'Flux'")

    # Load every fold's checkpoint once; all perturbation passes reuse these models.
    models = []
    for fold in fold_results:
        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id=id, **model_config).to(device)
        # Load weights to CPU first, then move to target device (handles CUDA->MPS/CPU transfer)
        state_dict = torch.load(fold['best_model_path'], map_location='cpu')
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        models.append(model)

    def _latent(ds):
        # Latent space of ``ds`` under the cached fold models.
        latent, _, _ = get_latent_space_cached(models, fold_results, ds, device,
                                               batch_size=64, common_samples=common_samples)
        return latent

    # Reference latent space of the unperturbed data, computed once.
    original_latent = _latent(dataset)

    feature_shifts = []
    skipped_features = []  # (name, modality, non-zero count) for features below the threshold
    modality_inputs = (dataset.rna_data, dataset.atac_data, dataset.flux_data)
    b, y = dataset.batch_no, dataset.labels
    rna_input, atac_input, flux_input = modality_inputs
    # NOTE(review): the "+ 1" offsets assume feature_names holds one extra entry after
    # the RNA block and after the ATAC block (e.g. a separator). Confirm against how
    # feature_names is built; for a plain concatenation these would be off by one.
    atac_start = rna_input.shape[1] + 1
    flux_start = atac_start + atac_input.shape[1] + 1
    print("atac start", atac_start, "flux start", flux_start)

    def _score_modality(modality, slot, name_offset, perturb_type):
        """Perturb each column of one modality in turn and record its latent shift."""
        data = modality_inputs[slot]
        print(f"Analyzing {modality} features")
        print(f"Permuting {modality} features with", perturb_type)
        for i in tqdm(range(data.shape[1])):
            perturbed, insufficient_samples = perturb_feature(
                data, i, perturb_type, scale=perturbation_scale,
                min_samples_threshold=min_samples_threshold)
            name = feature_names[name_offset + i]
            if insufficient_samples:
                skipped_features.append((name, modality, (data[:, i] != 0).sum().item()))
                feature_shifts.append((name, 0.0))  # Keep the feature, with 0 importance
                continue
            # Swap in the perturbed modality; the other two stay untouched.
            inputs = list(modality_inputs)
            inputs[slot] = perturbed
            perturbed_dataset = MultiModalDataset(tuple(inputs), b, y)
            shift = measure_shift(original_latent, _latent(perturbed_dataset))
            feature_shifts.append((name, shift))

    if analyse_features in ['RNA', 'all']:
        _score_modality('RNA', 0, 0, 'shuffle')
    if analyse_features in ['ATAC', 'all']:
        _score_modality('ATAC', 1, atac_start, 'shuffle')
    if analyse_features in ['Flux', 'all']:
        # Flux columns are permuted over all rows, zeros included ('shuffle_all').
        # NOTE(review): unlike 'shuffle', this changes which samples are zero; if
        # filter_idx selects samples based on data values, the perturbed latent may
        # cover different rows than original_latent — confirm.
        _score_modality('Flux', 2, flux_start, 'shuffle_all')

    # Log skipped features
    if skipped_features:
        print(f"\nSkipped {len(skipped_features)} features due to insufficient samples (< {min_samples_threshold}):")
        for feature_name, modality, sample_count in skipped_features:
            print(f"  {feature_name} ({modality}): {sample_count} samples")

    return sorted(feature_shifts, key=lambda x: x[1], reverse=True)