File size: 11,284 Bytes
ef814bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from utils.helpers import create_multimodal_model
from models import SingleTransformer
from scipy.sparse import csr_matrix

def filter_idx(dataset, idx):
    """
    Keep only the indices whose RNA, ATAC and Flux rows each contain a non-zero value.

    Args:
        dataset: Dataset object exposing ``rna_data``, ``atac_data`` and ``flux_data``
            (row-indexed per sample, supporting ``!= 0`` and ``.any(axis=1)``).
        idx: Iterable of sample indices to filter.
    Returns:
        List of the indices from ``idx`` (original order, duplicates kept) for which
        no modality row is entirely zero.
    """
    def row_has_signal(matrix):
        # True for every row with at least one non-zero entry.
        return (matrix != 0).any(axis=1)

    informative = (
        row_has_signal(dataset.rna_data)
        & row_has_signal(dataset.atac_data)
        & row_has_signal(dataset.flux_data)
    )
    return [sample for sample in idx if informative[sample]]


def _as_index_list(values):
    """Return ``values`` as a plain Python list (numpy arrays / tensors via ``tolist``)."""
    return values.tolist() if hasattr(values, 'tolist') else list(values)


def _flow_attention_to_cpu(attention):
    """Copy one batch's flow-attention dict to CPU, wrapping a bare tensor in a one-layer list."""
    return {
        key: [layer.cpu() for layer in value] if isinstance(value, list) else [value.cpu()]
        for key, value in attention.items()
    }


def _concat_flow_attention(batches):
    """Concatenate per-batch flow-attention dicts along dim 0, layer by layer, per key."""
    att_w = {'rna': [], 'atac': [], 'flux': [], 'cls': []}
    for key in batches[0]:
        num_layers = len(batches[0][key])
        att_w[key] = [
            torch.cat([batch[key][layer] for batch in batches], dim=0)
            for layer in range(num_layers)
        ]
    return att_w


def analyze_cls_attention(id, fold_results, dataset, model_config, device, indices, 
                          average_heads=True, return_flow_attention=False):
    """
    Extract the attention weights of every fold's validation samples.

    For each fold the validation indices are restricted to ``indices`` (and, for
    ``id == 'Multi'``, to samples with no all-zero modality via ``filter_idx``), the
    fold's best checkpoint is loaded, and the model is run with
    ``return_attention=True`` on those samples.

    Args:
        id: The type of data to use. Must be one of 'RNA', 'ATAC', 'Flux', 'Multi'.
        fold_results: List of dicts, each with 'val_idx' (validation sample indices)
            and 'best_model_path' (path of a state_dict loadable with ``torch.load``).
        dataset: Dataset object whose items unpack as ``(x, b, _)``.
        model_config: Model configuration passed to the model constructor.
        device: Device to run the model on.
        indices: Sample indices to analyze; validation samples not in it are ignored.
            May be a list, set, numpy array or tensor.
        average_heads: Whether to average the CLS attention across heads. Only used
            when ``return_flow_attention`` is False. Defaults to True.
        return_flow_attention: Passed to the model; if True the per-modality layer
            attentions are collected and returned as a dict. Defaults to False.
    Returns:
        If ``return_flow_attention`` is False: numpy array concatenated over all
        analyzed samples, (n_samples, seq_len) or (n_samples, num_heads, seq_len)
        when ``average_heads`` is False (assuming the model returns
        (batch, num_heads, 1, seq_len) — TODO confirm).
        Otherwise: dict with at least 'rna', 'atac', 'flux', 'cls'; every key the
        model returned maps to a list (one entry per layer) of CPU tensors
        concatenated along dim 0. Keys the model did not return map to [].
    Raises:
        ValueError: if ``id`` is invalid, or no fold contained any sample to analyze.
    """
    if id not in ['RNA', 'ATAC', 'Flux', 'Multi']:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    # Build the membership set once; `i in indices` on a list/array is O(len(indices)).
    keep = set(_as_index_list(indices))

    all_attention_weights = []

    for fold in fold_results:
        val_idx = [i for i in _as_index_list(fold['val_idx']) if i in keep]

        if id == 'Multi':
            # Drop samples where any modality is entirely zero.
            val_idx = filter_idx(dataset, val_idx)

        if len(val_idx) == 0:
            print('No samples of the specified type in the validation set. Skipping...')
            continue

        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=32, shuffle=False)

        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id=id, **model_config).to(device)

        state_dict = torch.load(fold['best_model_path'], map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()

        with torch.no_grad():
            for x, b, _ in val_loader:
                if isinstance(x, list):
                    # Multimodal batch: (rna, atac, flux).
                    x = (x[0].to(device), x[1].to(device), x[2].to(device))
                else:
                    x = x.to(device)
                b = b.to(device)

                _, _, attention_weights = model(
                    x, b, return_attention=True, return_flow_attention=return_flow_attention)

                if return_flow_attention:
                    # Move to CPU per batch so accelerator memory does not grow with
                    # the number of batches.
                    all_attention_weights.append(_flow_attention_to_cpu(attention_weights))
                else:
                    # Drop the singleton query axis: (batch, num_heads, 1, seq_len)
                    # -> (batch, num_heads, seq_len); optionally average the heads.
                    attention_weights = attention_weights.squeeze(-2)
                    if average_heads:
                        attention_weights = attention_weights.mean(dim=1)
                    # np.concatenate cannot read accelerator tensors; copy to host now.
                    all_attention_weights.append(attention_weights.cpu().numpy())

    if not all_attention_weights:
        raise ValueError('No validation samples matched `indices` in any fold; nothing to analyze.')

    if not return_flow_attention:
        return np.concatenate(all_attention_weights, axis=0)
    return _concat_flow_attention(all_attention_weights)


# def compute_attention_rollout(attention_weights):
#     num_layers = len(attention_weights)
#     combined_attention = torch.eye(attention_weights[0].size(-1)).to(attention_weights[0].device)
#     for layer in range(num_layers):
#         layer_attention = attention_weights[layer].mean(dim=1)  # Average over heads
#         combined_attention = torch.matmul(layer_attention, combined_attention)
#     return combined_attention
def compute_attention_rollout(attention_weights, add_residual=False):
    """
    Compute the attention rollout for a batch of samples.

    Each layer's attention is averaged over heads, and the layers are chained by
    batched matrix multiplication with the first layer applied first. The result is
    ``A_{L-1} @ ... @ A_1 @ A_0`` per sample.

    Args:
        attention_weights: Sequence (length num_layers, first layer first) of tensors
            of shape (batch, num_heads, seq_len, seq_len).
        add_residual: If True, account for residual connections as in Abnar &
            Zuidema (2020). Each head-averaged layer is replaced by ``A + I``,
            row-normalized, before chaining. Defaults to False, which gives the plain
            product.

    Returns:
        Tensor of shape (batch, seq_len, seq_len), with the same dtype and device as
        the inputs. Assuming each layer matrix is indexed [query, key], row ``q`` is
        how token ``q`` after the last layer attends, through all layers, to each
        input token.

    Raises:
        ValueError: if ``attention_weights`` is empty.
    """
    if len(attention_weights) == 0:
        raise ValueError("attention_weights must contain at least one layer")

    first = attention_weights[0]
    batch_size, _, seq_len, _ = first.shape

    # Match the inputs' dtype: a default float32 identity makes torch.bmm fail for
    # float16/bfloat16/float64 attention.
    identity = torch.eye(seq_len, device=first.device, dtype=first.dtype)
    rollout = identity.unsqueeze(0).repeat(batch_size, 1, 1)

    for layer_attention in attention_weights:
        # Average over heads -> (batch, seq_len, seq_len).
        attention = layer_attention.mean(dim=1)
        if add_residual:
            attention = attention + identity
            attention = attention / attention.sum(dim=-1, keepdim=True)
        rollout = torch.bmm(attention, rollout)
    return rollout


def multimodal_attention_rollout(all_attention_weights):
    """
    Propagate the head-averaged CLS attention through each modality's attention rollout.

    Each of the 'rna', 'atac' and 'flux' entries is rolled out with
    ``compute_attention_rollout``. The first CLS layer's attention is averaged over
    heads and split into one chunk per modality; chunk sizes are each rollout's
    token count. Each chunk is then multiplied into its modality's rollout.

    Args:
        all_attention_weights: Dict with 'rna', 'atac', 'flux' (lists of per-layer
            attention tensors) and 'cls' (list whose first element is the CLS
            attention; assumed (batch, num_heads, 1, total_tokens) — TODO confirm).

    Returns:
        Tensor of shape (batch, total_tokens), given the CLS shape assumed above. It
        holds the CLS attention mapped onto every modality's input tokens,
        concatenated in rna, atac, flux order.
    """
    modality_rollouts = []
    for modality in ('rna', 'atac', 'flux'):
        modality_rollouts.append(compute_attention_rollout(all_attention_weights[modality]))

    # Average over heads, then drop the singleton query axis.
    head_mean = all_attention_weights['cls'][0].mean(dim=1)
    cls_attention = head_mean.squeeze(1)

    chunk_sizes = [rollout.size(1) for rollout in modality_rollouts]
    cls_chunks = cls_attention.split(chunk_sizes, dim=1)

    propagated = [
        torch.matmul(chunk.unsqueeze(1), rollout)
        for chunk, rollout in zip(cls_chunks, modality_rollouts)
    ]
    # (batch, 1, total_tokens) -> (batch, total_tokens): remove the singleton query axis.
    return torch.cat(propagated, dim=2).squeeze(1)

def print_top_features(attention_weights, feature_names, top_n=5, modality=None):
    """
    Print the features with the highest attention averaged over axis 0 (samples).

    Args:
        attention_weights: 2-D numpy array or torch tensor (samples x features).
            Tensors may live on any device or require grad.
        feature_names: Indexable of names, one per feature column.
        top_n: Number of features to print. A falsy value (0 or None) prints all.
        modality: Label used only in the printed header.
    """
    print(f"\nTop {top_n} attended features ({modality} samples):")
    if torch.is_tensor(attention_weights):
        # detach + cpu: Tensor.numpy() refuses CUDA tensors and tensors requiring grad.
        avg_attention = attention_weights.detach().mean(dim=0).cpu().numpy()
    else:
        avg_attention = attention_weights.mean(axis=0)

    ranked = avg_attention.argsort()[::-1]  # descending
    if top_n:
        ranked = ranked[:top_n]
    for i in ranked:
        print(f"{feature_names[i]}: {avg_attention[i]:.4f}")

def get_top_features(attention_weights, feature_names, top_n=100, modality=None):
    """
    Return (feature_name, mean_attention) pairs sorted by descending mean attention.

    Args:
        attention_weights: 2-D numpy array or torch tensor (samples x features).
            Tensors may live on any device or require grad.
        feature_names: Indexable of names, one per feature column.
        top_n: Number of pairs to return. A falsy value (0 or None) returns all features.
        modality: Unused; kept for signature parity with ``print_top_features``.
    Returns:
        List of (name, value) tuples, where value is the attention averaged over axis 0.
    """
    if torch.is_tensor(attention_weights):
        # detach + cpu: Tensor.numpy() refuses CUDA tensors and tensors requiring grad.
        avg_attention = attention_weights.detach().mean(dim=0).cpu().numpy()
    else:
        avg_attention = attention_weights.mean(axis=0)

    ranked = avg_attention.argsort()[::-1]  # descending
    if top_n:
        ranked = ranked[:top_n]
    return [(feature_names[i], avg_attention[i]) for i in ranked]

from scipy.sparse.csgraph import maximum_flow

def _attention_flow_single(layer_attentions, capacity_scale):
    """Max-flow matrix (num_tokens, num_tokens) for one sample's head-averaged layers."""
    num_layers = len(layer_attentions)
    num_tokens = layer_attentions[0].shape[-1]
    identity = np.eye(num_tokens)

    # Node-layer 0 holds the input tokens; node-layer l + 1 is the output of layer l.
    num_nodes = (num_layers + 1) * num_tokens
    capacities = np.zeros((num_nodes, num_nodes), dtype=np.int32)
    for layer, attention in enumerate(layer_attentions):
        keys = slice(layer * num_tokens, (layer + 1) * num_tokens)
        queries = slice((layer + 1) * num_tokens, (layer + 2) * num_tokens)
        # Edge query(l+1) -> key(l) with capacity A[q, k] (+1 on the diagonal for the
        # residual path), assuming A is indexed [query, key]. Scaled and rounded
        # because scipy's maximum_flow only accepts integer capacities.
        capacities[queries, keys] = np.rint((attention + identity) * capacity_scale).astype(np.int32)
    graph = csr_matrix(capacities)  # built once and reused for every source/sink pair

    top = num_layers * num_tokens
    flows = np.zeros((num_tokens, num_tokens))
    for q in range(num_tokens):
        for k in range(num_tokens):
            flows[q, k] = maximum_flow(graph, top + q, k).flow_value
    return flows / capacity_scale


def compute_attention_flow(attention_weights, capacity_scale=10**6):
    """
    Compute attention flow (Abnar & Zuidema, 2020) between output and input tokens.

    The attention of each layer is averaged over heads and turned into a layered flow
    network. Node-layer 0 holds the input tokens and node-layer l + 1 holds the
    output of attention layer l. The edge capacity from query q (layer l + 1) to key k
    (layer l) is ``A_l[q, k] + [q == k]``, where the identity accounts for the
    residual connection. Every attention layer is used, first layer at the bottom,
    which matches ``compute_attention_rollout``.

    Args:
        attention_weights: Sequence (length num_layers, first layer first) of tensors
            of shape (batch, num_heads, num_tokens, num_tokens).
        capacity_scale: Capacities are multiplied by this and rounded, because
            scipy's ``maximum_flow`` requires integer capacities. The result is
            divided back by it. ``capacity_scale * max(A + I)`` must fit in int32.

    Returns:
        Tensor ``flows`` with the dtype and device of the inputs, where ``flows[q, k]``
        is the max flow from output token q to input token k. Its shape is
        (num_tokens, num_tokens) when batch == 1, otherwise
        (batch, num_tokens, num_tokens).

    Raises:
        ValueError: if ``attention_weights`` is empty or its tensors are not 4-D.
    """
    if len(attention_weights) == 0:
        raise ValueError("attention_weights must contain at least one layer")
    reference = attention_weights[0]

    # Average over heads -> one (batch, num_tokens, num_tokens) float64 array per layer.
    layers = [
        layer.detach().mean(dim=1).to(torch.float64).cpu().numpy()
        for layer in attention_weights
    ]
    if layers[0].ndim != 3:
        raise ValueError("expected attention tensors of shape (batch, num_heads, num_tokens, num_tokens)")

    batch_size = layers[0].shape[0]
    flows = np.stack([
        _attention_flow_single([layer[sample] for layer in layers], capacity_scale)
        for sample in range(batch_size)
    ])
    if batch_size == 1:
        # Keep the 2-D (num_tokens, num_tokens) result for a single sample.
        flows = flows[0]
    return torch.as_tensor(flows, dtype=reference.dtype, device=reference.device)

def multimodal_attention_flow(all_attention_weights):
    """
    Propagate the head-averaged CLS attention through each modality's attention flow.

    Each of the 'rna', 'atac' and 'flux' entries is turned into a flow matrix with
    ``compute_attention_flow``, and each matrix is row-normalized over input tokens.
    The first CLS layer's attention is averaged over heads, split into one chunk per
    modality (chunk sizes are each flow's token count), and each chunk is multiplied
    into its modality's flow.

    Args:
        all_attention_weights: Dict with 'rna', 'atac', 'flux' (lists of per-layer
            attention tensors) and 'cls' (list whose first element is the CLS
            attention; assumed (batch, num_heads, 1, total_tokens) — TODO confirm).

    Returns:
        Tensor of shape (batch, total_tokens), given the CLS shape assumed above,
        concatenated in rna, atac, flux order.
    """
    flows = [compute_attention_flow(all_attention_weights[name]) for name in ('rna', 'atac', 'flux')]

    cls_attention = all_attention_weights['cls'][0].mean(dim=1).squeeze(1)  # Average over heads

    # Normalize each output token's row over input tokens. The last dim handles both
    # (tokens, tokens) and (batch, tokens, tokens) flows. Then match the CLS
    # attention's dtype/device, since the flow may come back as float64 and matmul
    # rejects mixed dtypes.
    flows = [
        (flow / flow.sum(dim=-1, keepdim=True)).to(device=cls_attention.device, dtype=cls_attention.dtype)
        for flow in flows
    ]

    # Split CLS attention for each modality
    cls_parts = cls_attention.split([flow.size(-1) for flow in flows], dim=1)

    final_flow = torch.cat(
        [part.unsqueeze(1) @ flow for part, flow in zip(cls_parts, flows)], dim=2)
    # (batch, 1, total_tokens) -> (batch, total_tokens): remove the singleton query axis.
    return final_flow.squeeze(1)