import torch
from torch.utils.data import TensorDataset, DataLoader
from models import SingleTransformer, MultiModalTransformer
import config
from data import create_dataset

def create_masked_input(input_tensor, mask_token, mask_prob=0.20):
    """
    Creates a masked input tensor by randomly replacing elements with a mask token.
    Args:
        input_tensor (torch.Tensor): The input tensor to be masked.
        mask_token: The token to be used for masking.
        mask_prob (float, optional): The probability of masking an element. Defaults to 0.20.
    Returns:
        torch.Tensor: The masked input tensor.
        torch.Tensor: A boolean mask indicating which elements were masked.
    """

    # Draw the mask on the same device as the input so GPU tensors work too.
    mask = torch.rand(input_tensor.shape, device=input_tensor.device) < mask_prob
    masked_input = input_tensor.clone()
    masked_input[mask] = mask_token
    return masked_input, mask
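
# Example usage (a minimal sketch; the shape and the mask token value 0.0 are
# illustrative assumptions, not project conventions):
#
#   x = torch.randn(4, 10)                      # batch of 4 feature vectors
#   masked_x, mask = create_masked_input(x, mask_token=0.0)
#   assert (masked_x[mask] == 0.0).all()        # masked positions hold the token
#   assert (masked_x[~mask] == x[~mask]).all()  # remaining positions are unchanged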

def get_max(adata):
    """
    Get the maximum value across a list of AnnData objects.
    Args:
        adata (list): A list of AnnData objects whose .X is a sparse matrix.
    Returns:
        float: The maximum value found in any object's .X.
    """
    assert isinstance(adata, list), "adata must be a list of AnnData objects."
    return max(float(a.X.toarray().max()) for a in adata)
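
# Example usage (hypothetical `adata_list`; get_max assumes each .X is a scipy
# sparse matrix, since it calls .toarray()):
#
#   vmax = get_max(adata_list)
#   rna = adata_list[0].X.toarray() / vmax   # e.g. scale values into [0, 1]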

def get_token_embeddings(model, dataset, device):
    """
    Get the token embeddings for the dataset.
    Args:
        model (torch.nn.Module): Model.
        dataset (torch.utils.data.Dataset): Dataset.
        device (str): Device to use.
    Returns:
        torch.Tensor: Embeddings.
    """
    model.eval()
    embeddings = []
    loader = DataLoader(dataset, batch_size=32, shuffle=False)
    with torch.no_grad():
        for batch in loader:
            if len(batch) == 3:
                inputs, bi, _ = batch
            elif len(batch) == 2:
                inputs, bi = batch
            else:
                raise ValueError(f"Unexpected batch structure of length {len(batch)}.")
            if isinstance(inputs, list):
                # Multimodal batch: move each modality to the device separately.
                rna = inputs[0].to(device)
                atac = inputs[1].to(device)
                flux = inputs[2].to(device)
                inputs = (rna, atac, flux)
            else:
                inputs = inputs.to(device)
            bi = bi.to(device)

            output = model(inputs, bi, return_embeddings=True)
            embeddings.append(output.detach().cpu())

    # Concatenate embeddings across batches
    embeddings = torch.cat(embeddings, dim=0)  # shape: (n_samples, seq_len, d_model)
    return embeddings
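
# Example usage (illustrative; assumes `model` and `dataset` follow the batch
# layouts handled above, i.e. (inputs, batch_no) or (inputs, batch_no, label)):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   emb = get_token_embeddings(model.to(device), dataset, device)
#   emb.shape  # -> (n_samples, seq_len, d_model)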

def get_all_modalities_available_samples(dataset):
    """
    Filter a multimodal dataset down to samples measured in all three modalities.
    Args:
        dataset (MultiModalDataset): Dataset exposing rna_data, atac_data and flux_data.
    Returns:
        MultiModalDataset: Dataset containing only fully observed samples.
    """
    rna = dataset.rna_data
    atac = dataset.atac_data
    flux = dataset.flux_data
    # Keep a sample only if every modality has at least one non-zero entry.
    mask = (rna != 0).any(axis=1) & (atac != 0).any(axis=1) & (flux != 0).any(axis=1)
    new_ds = create_dataset.MultiModalDataset((rna[mask], atac[mask], flux[mask]), 
                                              dataset.batch_no[mask], 
                                              dataset.labels[mask])
    return new_ds
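
# Example usage (hypothetical `mm_ds`, a create_dataset.MultiModalDataset):
#
#   full_ds = get_all_modalities_available_samples(mm_ds)
#   # full_ds keeps only samples with non-zero RNA, ATAC and flux measurements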

def separate_dataset(ds):
    """
    Separate a dataset into two groups based on the labels.
    Args:
        ds (TensorDataset): Dataset.
    Returns:
        TensorDataset: Dataset with label 0.
        TensorDataset: Dataset with label 1.
    """
    X, b, y = ds.tensors

    # Create masks for labels 0 and 1
    mask_0 = (y == 0)
    mask_1 = (y == 1)

    # Filter the tensors based on the masks
    X_0, b_0, y_0 = X[mask_0], b[mask_0], y[mask_0]
    X_1, b_1, y_1 = X[mask_1], b[mask_1], y[mask_1]

    # Create new datasets for each group
    dataset_0 = TensorDataset(X_0, b_0, y_0)  # Dataset with y == 0
    dataset_1 = TensorDataset(X_1, b_1, y_1)  # Dataset with y == 1

    return dataset_0, dataset_1
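
# Example usage (a small synthetic sketch):
#
#   X = torch.randn(6, 8)
#   b = torch.zeros(6, dtype=torch.long)
#   y = torch.tensor([0, 1, 0, 1, 1, 0])
#   ds0, ds1 = separate_dataset(TensorDataset(X, b, y))
#   len(ds0), len(ds1)  # -> (3, 3)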

def create_multimodal_model(model_config, device, use_mlm=False):
    """
    Create a multimodal model.
    Args:
        model_config (dict): Model configuration.
        device (str): Device to use.
        use_mlm (bool, optional): Whether to use MLM pretraining. Defaults to False.
    Returns:
        MultiModalTransformer: Multimodal model.
    """
    model_config_rna, model_config_atac, model_config_flux = model_config['RNA'], model_config['ATAC'], model_config['Flux']
    share_config, model_config_multi = model_config['Share'], model_config['Multi']
    rna_model = SingleTransformer("RNA", **model_config_rna, **share_config).to(device)
    atac_model = SingleTransformer("ATAC", **model_config_atac, **share_config).to(device)
    flux_model = SingleTransformer("Flux", **model_config_flux, **share_config).to(device)
    if use_mlm:
        # Load MLM-pretrained weights for each encoder; map_location keeps this
        # working when the checkpoints were saved on a different device.
        rna_model.load_state_dict(torch.load(config.MLM_RNA_CKP, map_location=device), strict=False)
        atac_model.load_state_dict(torch.load(config.MLM_ATAC_CKP, map_location=device), strict=False)
        flux_model.load_state_dict(torch.load(config.MLM_FLUX_CKP, map_location=device), strict=False)
    model = MultiModalTransformer(rna_model, atac_model, flux_model, **model_config_multi).to(device)
    return model
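
# Example usage (illustrative only; the keys inside each sub-config are whatever
# SingleTransformer and MultiModalTransformer in models.py accept):
#
#   model_config = {
#       "RNA":   {...},  # per-modality encoder settings
#       "ATAC":  {...},
#       "Flux":  {...},
#       "Share": {...},  # settings shared by all three encoders
#       "Multi": {...},  # fusion-level settings for MultiModalTransformer
#   }
#   model = create_multimodal_model(model_config, device="cpu", use_mlm=False)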