| | import sys |
| | import os |
| | import traceback |
| | import json |
| | import pickle |
| | import numpy as np |
| | import scanpy as sc |
| | import pandas as pd |
| | import networkx as nx |
| | from tqdm import tqdm |
| | import logging |
| | import torch |
| | import torch.optim as optim |
| | import torch.nn as nn |
| | from sklearn.metrics import r2_score |
| | from torch.optim.lr_scheduler import StepLR |
| | from torch_geometric.nn import SGConv |
| | from copy import deepcopy |
| | from torch_geometric.data import Data, DataLoader |
| | from multiprocessing import Pool |
| | from torch.nn import Sequential, Linear, ReLU |
| | from scipy.stats import pearsonr |
| | from sklearn.metrics import mean_squared_error as mse |
| | from sklearn.metrics import mean_absolute_error as mae |
| |
|
class MLP(torch.nn.Module):
    """
    Feed-forward stack of Linear -> (BatchNorm1d) -> ReLU layers.

    The ReLU that would follow the final layer is dropped, so the network
    ends on the last BatchNorm1d (or on the last Linear when ``batch_norm``
    is falsy).
    """

    def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
        """
        :param sizes: layer widths; each consecutive pair becomes one Linear
        :param batch_norm: add a BatchNorm1d after every Linear layer
        :param last_layer_act: stored on ``self.activation``; ``forward``
            does not apply it
        """
        super().__init__()

        modules = []
        for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
            modules.append(torch.nn.Linear(in_dim, out_dim))
            if batch_norm:
                modules.append(torch.nn.BatchNorm1d(out_dim))
            modules.append(torch.nn.ReLU())

        self.activation = last_layer_act
        # Slicing (rather than pop) keeps the empty-stack case valid.
        self.network = torch.nn.Sequential(*modules[:-1])
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.network(x)
| |
|
| |
|
class GEARS_Model(torch.nn.Module):
    """
    GEARS model

    Combines learned gene embeddings (refined over a co-expression graph)
    with perturbation embeddings (refined over a GO similarity graph) and
    decodes a per-gene output that is added to the input expression.
    """

    def __init__(self, args):
        """
        :param args: arguments dictionary
        """
        super().__init__()
        self.args = args
        self.num_genes = args['num_genes']
        self.num_perts = args['num_perts']
        hidden_size = args['hidden_size']
        self.uncertainty = args['uncertainty']
        self.num_layers = args['num_go_gnn_layers']
        self.indv_out_hidden_size = args['decoder_hidden_size']
        self.num_layers_gene_pos = args['num_gene_gnn_layers']
        self.no_perturb = args['no_perturb']
        self.pert_emb_lambda = 0.2

        # NOTE: sub-modules are created in a fixed order; reordering them
        # would change random initialisation and state-dict ordering.
        self.pert_w = nn.Linear(1, hidden_size)

        # Gene / perturbation embedding tables.
        self.gene_emb = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
        self.pert_emb = nn.Embedding(self.num_perts, hidden_size, max_norm=True)

        # Transformation layers.
        self.emb_trans = nn.ReLU()
        self.pert_base_trans = nn.ReLU()
        self.transform = nn.ReLU()
        self.emb_trans_v2 = MLP([hidden_size, hidden_size, hidden_size],
                                last_layer_act='ReLU')
        self.pert_fuse = MLP([hidden_size, hidden_size, hidden_size],
                             last_layer_act='ReLU')

        # Co-expression graph and the GNN stack that runs over it.
        self.G_coexpress = args['G_coexpress'].to(args['device'])
        self.G_coexpress_weight = args['G_coexpress_weight'].to(args['device'])

        self.emb_pos = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
        self.layers_emb_pos = torch.nn.ModuleList()
        for _ in range(self.num_layers_gene_pos):
            self.layers_emb_pos.append(SGConv(hidden_size, hidden_size, 1))

        # GO similarity graph and the GNN stack that runs over it.
        self.G_sim = args['G_go'].to(args['device'])
        self.G_sim_weight = args['G_go_weight'].to(args['device'])

        self.sim_layers = torch.nn.ModuleList()
        for _ in range(self.num_layers):
            self.sim_layers.append(SGConv(hidden_size, hidden_size, 1))

        # Decoder shared across genes.
        self.recovery_w = MLP([hidden_size, hidden_size * 2, hidden_size],
                              last_layer_act='linear')

        # Gene-specific decoder (first stage).
        self.indv_w1 = nn.Parameter(torch.rand(self.num_genes,
                                               hidden_size, 1))
        self.indv_b1 = nn.Parameter(torch.rand(self.num_genes, 1))
        self.act = nn.ReLU()
        nn.init.xavier_normal_(self.indv_w1)
        nn.init.xavier_normal_(self.indv_b1)

        # Cross-gene state encoder.
        self.cross_gene_state = MLP([self.num_genes, hidden_size,
                                     hidden_size])

        # Gene-specific decoder (second stage): own value + cross-gene state.
        self.indv_w2 = nn.Parameter(torch.rand(1, self.num_genes,
                                               hidden_size + 1))
        self.indv_b2 = nn.Parameter(torch.rand(1, self.num_genes))
        nn.init.xavier_normal_(self.indv_w2)
        nn.init.xavier_normal_(self.indv_b2)

        # Batch norms.
        self.bn_emb = nn.BatchNorm1d(hidden_size)
        self.bn_pert_base = nn.BatchNorm1d(hidden_size)
        self.bn_pert_base_trans = nn.BatchNorm1d(hidden_size)

        # Optional head predicting a log-variance per gene.
        if self.uncertainty:
            self.uncertainty_w = MLP([hidden_size, hidden_size * 2, hidden_size, 1],
                                     last_layer_act='linear')

    def forward(self, data):
        """
        Forward pass of the model

        :param data: batch object exposing ``x``, ``pert_idx`` and ``batch``
        :return: tensor with one row of ``num_genes`` values per graph; when
            ``uncertainty`` is enabled, a ``(prediction, logvar)`` tuple
        """
        x, pert_idx = data.x, data.pert_idx

        if self.no_perturb:
            # Baseline mode: return the input regrouped as one row per graph.
            flat = torch.flatten(x.reshape(-1, 1))
            return torch.stack(torch.split(flat, self.num_genes))

        device = self.args['device']
        num_graphs = len(data.batch.unique())

        # Gene index range repeated once per graph; used for both lookups.
        gene_ids = torch.arange(self.num_genes).repeat(num_graphs, ).to(device)

        # Base gene embeddings.
        emb = self.gene_emb(gene_ids)
        emb = self.bn_emb(emb)
        base_emb = self.emb_trans(emb)

        # Positional embeddings propagated over the co-expression graph.
        pos_emb = self.emb_pos(gene_ids)
        last_pos_layer = len(self.layers_emb_pos) - 1
        for layer_no, layer in enumerate(self.layers_emb_pos):
            pos_emb = layer(pos_emb, self.G_coexpress, self.G_coexpress_weight)
            if layer_no < last_pos_layer:
                pos_emb = pos_emb.relu()

        base_emb = base_emb + 0.2 * pos_emb
        base_emb = self.emb_trans_v2(base_emb)

        # (graph index, perturbation index) pairs; -1 marks "no perturbation".
        pert_index = [[graph_no, pert]
                      for graph_no, perts in enumerate(pert_idx)
                      for pert in perts
                      if pert != -1]
        pert_index = torch.tensor(pert_index).T

        # Perturbation embeddings propagated over the GO graph.
        pert_global_emb = self.pert_emb(
            torch.arange(self.num_perts).to(device))
        for layer_no, layer in enumerate(self.sim_layers):
            pert_global_emb = layer(pert_global_emb, self.G_sim, self.G_sim_weight)
            if layer_no < self.num_layers - 1:
                pert_global_emb = pert_global_emb.relu()

        base_emb = base_emb.reshape(num_graphs, self.num_genes, -1)

        if pert_index.shape[0] != 0:
            # Sum the embeddings of all perturbations applied to each graph.
            pert_track = {}
            for graph_no, pert_no in zip(pert_index[0].tolist(), pert_index[1]):
                pert_vec = pert_global_emb[pert_no]
                if graph_no in pert_track:
                    pert_track[graph_no] = pert_track[graph_no] + pert_vec
                else:
                    pert_track[graph_no] = pert_vec

            summed = list(pert_track.values())
            if len(summed) > 0:
                if len(summed) == 1:
                    # Duplicate the single row before fusing (the fuse MLP
                    # contains BatchNorm1d layers).
                    emb_total = self.pert_fuse(torch.stack(summed * 2))
                else:
                    emb_total = self.pert_fuse(torch.stack(summed))

                # Add each fused vector to every gene of its graph.
                for row, graph_no in enumerate(pert_track.keys()):
                    base_emb[graph_no] = base_emb[graph_no] + emb_total[row]

        base_emb = base_emb.reshape(num_graphs * self.num_genes, -1)
        base_emb = self.bn_pert_base(base_emb)

        # Shared decoder followed by the gene-specific first stage.
        base_emb = self.transform(base_emb)
        out = self.recovery_w(base_emb)
        out = out.reshape(num_graphs, self.num_genes, -1)
        out = out.unsqueeze(-1) * self.indv_w1
        w = torch.sum(out, axis=2)
        out = w + self.indv_b1

        # Cross-gene state: encode all gene outputs of a graph, then hand the
        # same encoding to every gene of that graph.
        cross_gene_embed = self.cross_gene_state(
            out.reshape(num_graphs, self.num_genes, -1).squeeze(2))
        cross_gene_embed = cross_gene_embed.repeat(1, self.num_genes)
        cross_gene_embed = cross_gene_embed.reshape([num_graphs, self.num_genes, -1])
        cross_gene_out = torch.cat([out, cross_gene_embed], 2)

        cross_gene_out = cross_gene_out * self.indv_w2
        cross_gene_out = torch.sum(cross_gene_out, axis=2)
        out = cross_gene_out + self.indv_b2

        # Residual connection to the input expression.
        out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1, 1)
        out = torch.split(torch.flatten(out), self.num_genes)

        if self.uncertainty:
            out_logvar = self.uncertainty_w(base_emb)
            out_logvar = torch.split(torch.flatten(out_logvar), self.num_genes)
            return torch.stack(out), torch.stack(out_logvar)

        return torch.stack(out)
| |
|
class GEARS:
    """
    GEARS base model class

    Holds the data loaders and metadata taken from a ``pert_data`` object,
    builds the two graphs needed by ``GEARS_Model``, and provides training,
    saving and loading of the model.
    """

    def __init__(self, pert_data,
                 device='cuda',
                 weight_bias_track=True,
                 proj_name='GEARS',
                 exp_name='GEARS'):
        """
        :param pert_data: object exposing the loaders, AnnData and metadata
            attributes read below
        :param device: device string the model and tensors are moved to
        :param weight_bias_track: when true, import wandb and start a run
        :param proj_name: wandb project name
        :param exp_name: wandb run name
        """
        self.weight_bias_track = weight_bias_track

        if self.weight_bias_track:
            # Imported lazily so wandb is only required when tracking is on.
            import wandb
            wandb.init(project=proj_name, name=exp_name)
            self.wandb = wandb
        else:
            self.wandb = None

        self.device = device
        self.config = None

        self.dataloader = pert_data.dataloader
        self.adata = pert_data.adata
        self.node_map = pert_data.node_map
        self.node_map_pert = pert_data.node_map_pert
        self.data_path = pert_data.data_path
        self.dataset_name = pert_data.dataset_name
        self.split = pert_data.split
        self.seed = pert_data.seed
        self.train_gene_set_size = pert_data.train_gene_set_size
        self.set2conditions = pert_data.set2conditions
        self.subgroup = pert_data.subgroup
        self.gene_list = pert_data.gene_names.values.tolist()
        self.pert_list = pert_data.pert_names.tolist()
        self.num_genes = len(self.gene_list)
        self.num_perts = len(self.pert_list)
        self.default_pert_graph = pert_data.default_pert_graph
        self.saved_pred = {}
        self.saved_logvar_sum = {}

        # Mean expression over all cells whose condition is 'ctrl'.
        ctrl_mask = self.adata.obs['condition'].values == 'ctrl'
        self.ctrl_expression = torch.tensor(
            np.mean(self.adata.X[ctrl_mask], axis=0)
        ).reshape(-1, ).to(self.device)

        pert_full_id2pert = dict(
            self.adata.obs[['condition_name', 'condition']].values)
        self.dict_filter = {
            pert_full_id2pert[i]: j
            for i, j in self.adata.uns['non_zeros_gene_idx'].items()
            if i in pert_full_id2pert
        }
        self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl']

        # Perturbation index -> gene index, for perturbations that are also
        # measured genes. Membership is tested on the dict (O(1)) instead of
        # scanning the gene list for every perturbation.
        gene_dict = {g: i for i, g in enumerate(self.gene_list)}
        self.pert2gene = {p: gene_dict[pert]
                          for p, pert in enumerate(self.pert_list)
                          if pert in gene_dict}

    def model_initialize(self, hidden_size=64,
                         num_go_gnn_layers=1,
                         num_gene_gnn_layers=1,
                         decoder_hidden_size=16,
                         num_similar_genes_go_graph=20,
                         num_similar_genes_co_express_graph=20,
                         coexpress_threshold=0.4,
                         uncertainty=False,
                         uncertainty_reg=1,
                         direction_lambda=1e-1,
                         G_go=None,
                         G_go_weight=None,
                         G_coexpress=None,
                         G_coexpress_weight=None,
                         no_perturb=False,
                         **kwargs
                         ):
        """
        Build ``self.config`` and instantiate the model.

        Graphs that are not supplied (``G_coexpress`` / ``G_go`` is None) are
        constructed through ``get_similarity_network`` and
        ``GeneSimNetwork``. Extra keyword arguments are accepted and ignored.
        """
        self.config = {'hidden_size': hidden_size,
                       'num_go_gnn_layers': num_go_gnn_layers,
                       'num_gene_gnn_layers': num_gene_gnn_layers,
                       'decoder_hidden_size': decoder_hidden_size,
                       'num_similar_genes_go_graph': num_similar_genes_go_graph,
                       'num_similar_genes_co_express_graph': num_similar_genes_co_express_graph,
                       'coexpress_threshold': coexpress_threshold,
                       'uncertainty': uncertainty,
                       'uncertainty_reg': uncertainty_reg,
                       'direction_lambda': direction_lambda,
                       'G_go': G_go,
                       'G_go_weight': G_go_weight,
                       'G_coexpress': G_coexpress,
                       'G_coexpress_weight': G_coexpress_weight,
                       'device': self.device,
                       'num_genes': self.num_genes,
                       'num_perts': self.num_perts,
                       'no_perturb': no_perturb
                       }

        if self.wandb:
            self.wandb.config.update(self.config)

        if self.config['G_coexpress'] is None:
            # Gene co-expression graph.
            edge_list = get_similarity_network(network_type='co-express',
                                               adata=self.adata,
                                               threshold=coexpress_threshold,
                                               k=num_similar_genes_co_express_graph,
                                               data_path=self.data_path,
                                               data_name=self.dataset_name,
                                               split=self.split, seed=self.seed,
                                               train_gene_set_size=self.train_gene_set_size,
                                               set2conditions=self.set2conditions)

            sim_network = GeneSimNetwork(edge_list, self.gene_list,
                                         node_map=self.node_map)
            self.config['G_coexpress'] = sim_network.edge_index
            self.config['G_coexpress_weight'] = sim_network.edge_weight

        if self.config['G_go'] is None:
            # Perturbation similarity graph derived from GO annotations.
            edge_list = get_similarity_network(network_type='go',
                                               adata=self.adata,
                                               threshold=coexpress_threshold,
                                               k=num_similar_genes_go_graph,
                                               pert_list=self.pert_list,
                                               data_path=self.data_path,
                                               data_name=self.dataset_name,
                                               split=self.split, seed=self.seed,
                                               train_gene_set_size=self.train_gene_set_size,
                                               set2conditions=self.set2conditions,
                                               default_pert_graph=self.default_pert_graph)

            sim_network = GeneSimNetwork(edge_list, self.pert_list,
                                         node_map=self.node_map_pert)
            self.config['G_go'] = sim_network.edge_index
            self.config['G_go_weight'] = sim_network.edge_weight

        self.model = GEARS_Model(self.config).to(self.device)
        self.best_model = deepcopy(self.model)

    def load_pretrained(self, path):
        """
        Restore a model written by ``save_model``.

        :param path: directory containing ``config.pkl`` and ``model.pt``
        """
        # SECURITY: pickle.load / torch.load execute arbitrary code from the
        # file; only load checkpoints from a trusted source.
        with open(os.path.join(path, 'config.pkl'), 'rb') as f:
            config = pickle.load(f)

        # These three are re-derived from the current object, so drop the
        # stored copies (tolerating configs that never contained them).
        for key in ('device', 'num_genes', 'num_perts'):
            config.pop(key, None)
        self.model_initialize(**config)
        # Keep the complete config produced by model_initialize (plus any
        # extra stored keys). Replacing it with the stripped dict would make
        # a later save_model() write a config that load_pretrained() cannot
        # read back.
        self.config = {**config, **self.config}

        state_dict = torch.load(os.path.join(path, 'model.pt'),
                                map_location=torch.device('cpu'))
        if next(iter(state_dict)).startswith('module.'):
            # Keys carry a 'module.' prefix (7 characters); strip it.
            from collections import OrderedDict
            state_dict = OrderedDict(
                (k[7:], v) for k, v in state_dict.items())

        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        self.best_model = self.model

    def save_model(self, path):
        """
        Write ``config.pkl`` and ``model.pt`` (best model weights) to ``path``.

        :param path: target directory; created, including parents, if missing
        :raises ValueError: if no model has been initialized
        """
        # Validate first so a failed call leaves no empty directory behind.
        if self.config is None:
            raise ValueError('No model is initialized...')

        os.makedirs(path, exist_ok=True)

        with open(os.path.join(path, 'config.pkl'), 'wb') as f:
            pickle.dump(self.config, f)

        torch.save(self.best_model.state_dict(), os.path.join(path, 'model.pt'))

    def train(self, epochs=20,
              lr=1e-3,
              weight_decay=5e-4
              ):
        """
        Train the model

        Parameters
        ----------
        epochs: int
            number of epochs to train
        lr: float
            learning rate
        weight_decay: float
            weight decay

        Returns
        -------
        None

        """
        train_loader = self.dataloader['train_loader']
        val_loader = self.dataloader['val_loader']

        self.model = self.model.to(self.device)
        best_model = deepcopy(self.model)
        optimizer = optim.Adam(self.model.parameters(), lr=lr,
                               weight_decay=weight_decay)
        # step_size=1, gamma=0.5: the learning rate is halved every epoch.
        scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

        min_val = np.inf
        print_sys('Start Training...')

        for epoch in range(epochs):
            self.model.train()

            for step, batch in enumerate(train_loader):
                batch.to(self.device)
                optimizer.zero_grad()
                y = batch.y
                if self.config['uncertainty']:
                    pred, logvar = self.model(batch)
                    loss = uncertainty_loss_fct(pred, logvar, y, batch.pert,
                                                reg=self.config['uncertainty_reg'],
                                                ctrl=self.ctrl_expression,
                                                dict_filter=self.dict_filter,
                                                direction_lambda=self.config['direction_lambda'])
                else:
                    pred = self.model(batch)
                    loss = loss_fct(pred, y, batch.pert,
                                    ctrl=self.ctrl_expression,
                                    dict_filter=self.dict_filter,
                                    direction_lambda=self.config['direction_lambda'])
                loss.backward()
                nn.utils.clip_grad_value_(self.model.parameters(), clip_value=1.0)
                optimizer.step()

                if self.wandb:
                    self.wandb.log({'training_loss': loss.item()})

                if step % 50 == 0:
                    log = "Epoch {} Step {} Train Loss: {:.4f}"
                    print_sys(log.format(epoch + 1, step + 1, loss.item()))

            scheduler.step()

            # Evaluate on the train and validation sets after each epoch.
            train_res = evaluate(train_loader, self.model,
                                 self.config['uncertainty'], self.device)
            val_res = evaluate(val_loader, self.model,
                               self.config['uncertainty'], self.device)
            train_metrics, _ = compute_metrics(train_res)
            val_metrics, _ = compute_metrics(val_res)

            log = "Epoch {}: Train Overall MSE: {:.4f} " \
                  "Validation Overall MSE: {:.4f}. "
            print_sys(log.format(epoch + 1, train_metrics['mse'],
                                 val_metrics['mse']))

            log = "Train Top 20 DE MSE: {:.4f} " \
                  "Validation Top 20 DE MSE: {:.4f}. "
            print_sys(log.format(train_metrics['mse_de'],
                                 val_metrics['mse_de']))

            if self.wandb:
                for m in ['mse', 'pearson']:
                    self.wandb.log({'train_' + m: train_metrics[m],
                                    'val_' + m: val_metrics[m],
                                    'train_de_' + m: train_metrics[m + '_de'],
                                    'val_de_' + m: val_metrics[m + '_de']})

            # Model selection on the validation DE-gene MSE.
            if val_metrics['mse_de'] < min_val:
                min_val = val_metrics['mse_de']
                best_model = deepcopy(self.model)

        print_sys("Done!")
        self.best_model = best_model

        if 'test_loader' not in self.dataloader:
            print_sys('Done! No test dataloader detected.')
            return

        # Final evaluation of the selected model on the test set.
        test_loader = self.dataloader['test_loader']
        print_sys("Start Testing...")
        test_res = evaluate(test_loader, self.best_model,
                            self.config['uncertainty'], self.device)
        test_metrics, test_pert_res = compute_metrics(test_res)
        log = "Best performing model: Test Top 20 DE MSE: {:.4f}"
        print_sys(log.format(test_metrics['mse_de']))

        if self.wandb:
            for m in ['mse', 'pearson']:
                self.wandb.log({'test_' + m: test_metrics[m],
                                'test_de_' + m: test_metrics[m + '_de']
                                })

        print_sys('Done!')
        self.test_metrics = test_metrics
| |
|
def np_pearson_cor(x, y):
    """
    Pearson correlation between every column of ``x`` and every column of
    ``y``.

    Both inputs are centred along axis 0. Entry ``[i, j]`` of the result
    relates column ``i`` of ``x`` to column ``j`` of ``y``. Values are
    limited to the range [-1, 1]; a zero-variance column produces NaN.
    """
    x_centered = x - x.mean(axis=0)
    y_centered = y - y.mean(axis=0)

    x_sq_sum = (x_centered * x_centered).sum(axis=0)
    y_sq_sum = (y_centered * y_centered).sum(axis=0)
    scale = np.sqrt(np.outer(x_sq_sum, y_sq_sum))

    corr = np.matmul(x_centered.transpose(), y_centered) / scale

    # Guard against values drifting marginally outside [-1, 1].
    return np.clip(corr, -1.0, 1.0)
| |
|
| | |
class GeneSimNetwork():
    """
    GeneSimNetwork class

    Args:
        edge_list (pd.DataFrame): edge list of the network
        gene_list (list): list of gene names
        node_map (dict): dictionary mapping gene names to node indices

    Attributes:
        edge_index (torch.Tensor): edge index of the network
        edge_weight (torch.Tensor): edge weight of the network
        G (nx.DiGraph): networkx graph object
    """

    def __init__(self, edge_list, gene_list, node_map):
        """
        Build the directed graph and its tensor representation
        """
        self.edge_list = edge_list
        self.G = nx.from_pandas_edgelist(self.edge_list, source='source',
                                         target='target',
                                         edge_attr=['importance'],
                                         create_using=nx.DiGraph())
        self.gene_list = gene_list

        # Genes that never appear in an edge still become (isolated) nodes.
        for gene in self.gene_list:
            if gene not in self.G:
                self.G.add_node(gene)

        edges = list(self.G.edges)

        # Translate node names into integer indices, one column per edge.
        index_pairs = [(node_map[src], node_map[dst]) for src, dst in edges]
        self.edge_index = torch.tensor(index_pairs, dtype=torch.long).T

        # Edge weights, in the same order as the columns of edge_index.
        importance_by_edge = nx.get_edge_attributes(self.G, 'importance')
        importance = np.array([importance_by_edge[e] for e in edges])
        self.edge_weight = torch.Tensor(importance)
| |
|
def get_GO_edge_list(args):
    """
    Get gene ontology edge list

    Args:
        args (tuple): ``(g1, gene2go)`` where ``g1`` is the source gene and
            ``gene2go`` maps gene names to sets of GO terms. Packed into a
            single argument so the function can be used with ``Pool.imap``.

    Returns:
        list: ``(g1, g2, score)`` tuples for every gene ``g2`` (including
        ``g1`` itself) whose Jaccard similarity with ``g1`` exceeds 0.1.
    """
    g1, gene2go = args
    terms_g1 = gene2go[g1]

    edge_list = []
    for g2, terms_g2 in gene2go.items():
        union_size = len(terms_g1.union(terms_g2))
        if union_size == 0:
            # Neither gene has any GO term: similarity is undefined, so no
            # edge is emitted (previously a ZeroDivisionError).
            continue
        score = len(terms_g1.intersection(terms_g2)) / union_size
        if score > 0.1:
            edge_list.append((g1, g2, score))
    return edge_list
| | |
def make_GO(data_path, pert_list, data_name, num_workers=25, save=True):
    """
    Creates Gene Ontology graph from a custom set of genes

    Args:
        data_path (str): directory containing ``gene2go_all.pkl``
        pert_list (list): genes to include; each must be a key of the
            gene-to-GO mapping (a missing gene raises ``KeyError``)
        data_name (str): dataset name, used in the cache file name
        num_workers (int): number of worker processes
        save (bool): write the resulting edge list to the cache file

    Returns:
        pd.DataFrame: columns ``source``, ``target``, ``importance``
    """
    from itertools import chain

    # The cache lives under ./data relative to the working directory.
    fname = './data/go_essential_' + data_name + '.csv'
    if os.path.exists(fname):
        return pd.read_csv(fname)

    # SECURITY: pickle.load can execute arbitrary code; the file must come
    # from a trusted source.
    with open(os.path.join(data_path, 'gene2go_all.pkl'), 'rb') as f:
        gene2go = pickle.load(f)
    gene2go = {i: gene2go[i] for i in pert_list}

    print('Creating custom GO graph, this can take a few minutes')
    with Pool(num_workers) as p:
        all_edge_list = list(
            tqdm(p.imap(get_GO_edge_list, ((g, gene2go) for g in gene2go)),
                 total=len(gene2go)))

    # Flatten in one pass; repeated list concatenation is quadratic.
    edge_list = list(chain.from_iterable(all_edge_list))

    # Naming the columns here keeps them present when there are no edges.
    df_edge_list = pd.DataFrame(edge_list,
                                columns=['source', 'target', 'importance'])

    if save:
        print('Saving edge_list to file')
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        df_edge_list.to_csv(fname, index=False)

    return df_edge_list
| |
|
def get_similarity_network(network_type, adata, threshold, k,
                           data_path, data_name, split, seed, train_gene_set_size,
                           set2conditions, default_pert_graph=True, pert_list=None):
    """
    Return the edge list for one of the two similarity graphs.

    Args:
        network_type (str): ``'co-express'`` or ``'go'``
        adata: passed through to the co-expression builder
        threshold (float): co-expression threshold (unused for ``'go'``)
        k (int): number of neighbours to keep; ``k + 1`` edges are kept per
            target for the GO graph
        data_path (str): data directory
        data_name (str): dataset name
        split, seed, train_gene_set_size, set2conditions: passed through to
            the co-expression builder
        default_pert_graph (bool): for ``'go'``, read the precomputed
            ``go_essential_all/go_essential_all.csv`` under ``data_path``
            instead of building a graph with ``make_GO``
        pert_list (list): genes for a custom GO graph

    Returns:
        pd.DataFrame: edge list with ``source``, ``target``, ``importance``

    Raises:
        ValueError: if ``network_type`` is not recognised
    """
    if network_type == 'co-express':
        return get_coexpression_network_from_train(adata, threshold, k,
                                                   data_path, data_name, split,
                                                   seed, train_gene_set_size,
                                                   set2conditions)

    if network_type != 'go':
        raise ValueError("network_type must be 'co-express' or 'go', got "
                         + repr(network_type))

    if default_pert_graph:
        # The precomputed graph file is expected to exist already; this
        # function does not fetch it.
        df_jaccard = pd.read_csv(os.path.join(
            data_path, 'go_essential_all/go_essential_all.csv'))
    else:
        df_jaccard = make_GO(data_path, pert_list, data_name)

    # Top (k + 1) edges per target, ordered by target and then by descending
    # importance. Sorting + head() keeps the 'target' column regardless of
    # how groupby.apply treats grouping columns in the installed pandas.
    ranked = df_jaccard.sort_values(['target', 'importance'],
                                    ascending=[True, False])
    return ranked.groupby('target').head(k + 1).reset_index(drop=True)
| |
|
def get_coexpression_network_from_train(adata, threshold, k, data_path,
                                        data_name, split, seed, train_gene_set_size,
                                        set2conditions):
    """
    Infer co-expression network from training data

    The result is cached as a CSV under ``data_path/data_name``; if the cache
    file exists it is returned without touching ``adata``.

    Args:
        adata (anndata.AnnData): anndata object
        threshold (float): threshold for co-expression; only edges with an
            absolute correlation strictly above it are kept
        k (int): number of edges to keep (``k + 1`` candidates per gene)
        data_path (str): path to data
        data_name (str): name of dataset
        split (str): split of dataset
        seed (int): seed for random number generator
        train_gene_set_size (int): size of training gene set
        set2conditions (dict): dictionary of perturbations to conditions

    Returns:
        pd.DataFrame: columns ``source``, ``target``, ``importance``
    """
    fname = os.path.join(os.path.join(data_path, data_name), split + '_' +
                         str(seed) + '_' + str(train_gene_set_size) + '_' +
                         str(threshold) + '_' + str(k) +
                         '_co_expression_network.csv')

    if os.path.exists(fname):
        return pd.read_csv(fname)

    gene_list = list(adata.var.gene_name.values)

    # Rows whose condition is a training condition containing 'ctrl'.
    train_perts = set2conditions['train']
    X_tr = adata.X[np.isin(adata.obs.condition,
                           [i for i in train_perts if 'ctrl' in i])]
    # adata.X may be sparse or already dense; only densify when needed.
    if hasattr(X_tr, 'toarray'):
        X_tr = X_tr.toarray()

    out = np_pearson_cor(X_tr, X_tr)
    out[np.isnan(out)] = 0
    out = np.abs(out)

    # Per gene (row), the k + 1 strongest partners; values are gathered with
    # the same indices instead of sorting the matrix a second time.
    top_idx = np.argsort(out)[:, -(k + 1):]
    top_val = np.take_along_axis(out, top_idx, axis=1)

    n_rows, n_cols = top_idx.shape
    edges = [(gene_list[top_idx[i, j]], gene_list[i], top_val[i, j])
             for i in range(n_rows)
             for j in range(n_cols)
             if top_val[i, j] > threshold]

    # Naming the columns here keeps them present when no edge passes.
    df_co_expression = pd.DataFrame(edges,
                                    columns=['source', 'target', 'importance'])

    os.makedirs(os.path.dirname(fname), exist_ok=True)
    df_co_expression.to_csv(fname, index=False)
    return df_co_expression
| | |
def uncertainty_loss_fct(pred, logvar, y, perts, reg=0.1, ctrl=None,
                         direction_lambda=1e-3, dict_filter=None):
    """
    Uncertainty loss function

    Args:
        pred (torch.tensor): predicted values
        logvar (torch.tensor): log variance
        y (torch.tensor): true values
        perts (list): perturbation label of every row of ``pred``
        reg (float): regularization parameter
        ctrl (torch.tensor): control expression; indexed with the retained
            gene indices and subtracted from ``y`` / ``pred``
        direction_lambda (float): direction loss weight hyperparameter
        dict_filter (dict): perturbation label -> gene indices to retain

    Returns:
        torch.tensor: scalar loss averaged over the distinct perturbations
    """
    gamma = 2
    perts = np.array(perts)
    unique_perts = set(perts)

    # Accumulate out-of-place (as loss_fct does): an in-place ``+=`` on this
    # leaf tensor raises a RuntimeError whenever ``.to`` returns the same
    # tensor, i.e. when ``pred`` lives on the CPU.
    losses = torch.tensor(0.0, requires_grad=True).to(pred.device)

    for p in unique_perts:
        pert_idx = np.where(perts == p)[0]

        if p != 'ctrl':
            # Restrict perturbed samples to their retained genes.
            retain_idx = dict_filter[p]
            pred_p = pred[pert_idx][:, retain_idx]
            y_p = y[pert_idx][:, retain_idx]
            logvar_p = logvar[pert_idx][:, retain_idx]
            ctrl_p = ctrl[retain_idx]
        else:
            pred_p = pred[pert_idx]
            y_p = y[pert_idx]
            logvar_p = logvar[pert_idx]
            ctrl_p = ctrl

        # Error term raised to the power (2 + gamma), plus the same term
        # weighted by exp(-logvar).
        err = (pred_p - y_p) ** (2 + gamma)
        losses = losses + torch.sum(
            err + reg * torch.exp(-logvar_p) * err
        ) / pred_p.shape[0] / pred_p.shape[1]

        # Direction term: penalise a change relative to control in the
        # wrong direction.
        losses = losses + torch.sum(
            direction_lambda *
            (torch.sign(y_p - ctrl_p) - torch.sign(pred_p - ctrl_p)) ** 2
        ) / pred_p.shape[0] / pred_p.shape[1]

    return losses / (len(unique_perts))
| |
|
| |
|
def loss_fct(pred, y, perts, ctrl=None, direction_lambda=1e-3, dict_filter=None):
    """
    Main loss function: error raised to the power (2 + gamma) plus a
    direction loss

    Args:
        pred (torch.tensor): predicted values
        y (torch.tensor): true values
        perts (list): perturbation label of every row of ``pred``
        ctrl (torch.tensor): control expression, subtracted from ``y`` and
            ``pred`` in the direction term
        direction_lambda (float): direction loss weight hyperparameter
        dict_filter (dict): perturbation label -> gene indices to retain

    Returns:
        torch.tensor: scalar loss averaged over the distinct perturbations
    """
    gamma = 2
    perts = np.array(perts)
    total = torch.tensor(0.0, requires_grad=True).to(pred.device)

    for p in set(perts):
        rows = np.where(perts == p)[0]

        if p == 'ctrl':
            # Control samples use every gene.
            pred_p, y_p, ctrl_p = pred[rows], y[rows], ctrl
        else:
            # Perturbed samples are restricted to their retained genes.
            retain_idx = dict_filter[p]
            pred_p = pred[rows][:, retain_idx]
            y_p = y[rows][:, retain_idx]
            ctrl_p = ctrl[retain_idx]

        n_rows, n_cols = pred_p.shape[0], pred_p.shape[1]

        # Error term.
        total = total + torch.sum((pred_p - y_p) ** (2 + gamma)) / n_rows / n_cols

        # Direction term: sign of the change relative to control.
        sign_gap = torch.sign(y_p - ctrl_p) - torch.sign(pred_p - ctrl_p)
        total = total + torch.sum(direction_lambda * sign_gap ** 2) / n_rows / n_cols

    return total / (len(set(perts)))
def evaluate(loader, model, uncertainty, device):
    """
    Run model in inference mode using a given data loader

    Args:
        loader: iterable of batches exposing ``to``, ``pert``, ``y`` and
            ``de_idx``
        model: module called on each batch; returns predictions, or a
            ``(predictions, logvar)`` pair when ``uncertainty`` is true
        uncertainty (bool): whether the model also returns a log-variance
        device: device the model and batches are moved to

    Returns:
        dict: numpy arrays under ``pert_cat``, ``pred``, ``truth``,
        ``pred_de``, ``truth_de`` and, with uncertainty, ``logvar``
    """
    model.eval()
    model.to(device)

    pert_cat, pred, truth = [], [], []
    pred_de, truth_de, logvar = [], [], []

    for batch in loader:
        batch.to(device)
        pert_cat.extend(batch.pert)

        with torch.no_grad():
            if uncertainty:
                p, unc = model(batch)
                logvar.extend(unc.cpu())
            else:
                p = model(batch)
            t = batch.y
            pred.extend(p.cpu())
            truth.extend(t.cpu())

            # Per-sample subset selected by that sample's de_idx.
            for row, de_idx in enumerate(batch.de_idx):
                pred_de.append(p[row, de_idx])
                truth_de.append(t[row, de_idx])

    def _as_numpy(rows):
        # Stack the collected row tensors and convert to a numpy array.
        return torch.stack(rows).detach().cpu().numpy()

    results = {}
    results['pert_cat'] = np.array(pert_cat)
    results['pred'] = _as_numpy(pred)
    results['truth'] = _as_numpy(truth)
    results['pred_de'] = _as_numpy(pred_de)
    results['truth_de'] = _as_numpy(truth_de)

    if uncertainty:
        results['logvar'] = _as_numpy(logvar)

    return results
| |
|
| |
|
def compute_metrics(results):
    """
    Given results from a model run and the ground truth, compute metrics

    For every perturbation the predictions and truths are averaged over its
    samples, then scored with MSE and Pearson correlation, once on the full
    arrays and once on the ``*_de`` arrays. 'ctrl' gets 0 for its ``*_de``
    entries and is left out of the ``*_de`` averages.

    Returns:
        tuple: ``(metrics, metrics_pert)`` - metrics averaged over
        perturbations, and the per-perturbation values
    """
    metric2fct = {
        'mse': mse,
        'pearson': pearsonr
    }

    def _score(name, fct, pred_mean, truth_mean):
        # pearsonr returns a pair; keep the coefficient and map NaN to 0.
        val = fct(pred_mean, truth_mean)
        if name != 'pearson':
            return val
        val = val[0]
        if np.isnan(val):
            val = 0
        return val

    metrics = {}
    for m in metric2fct:
        metrics[m] = []
        metrics[m + '_de'] = []
    metrics_pert = {}

    for pert in np.unique(results['pert_cat']):
        per_pert = {}
        metrics_pert[pert] = per_pert
        p_idx = np.where(results['pert_cat'] == pert)[0]

        # Scores over the full arrays.
        pred_mean = results['pred'][p_idx].mean(0)
        truth_mean = results['truth'][p_idx].mean(0)
        for m, fct in metric2fct.items():
            per_pert[m] = _score(m, fct, pred_mean, truth_mean)
            metrics[m].append(per_pert[m])

        if pert == 'ctrl':
            for m in metric2fct:
                per_pert[m + '_de'] = 0
            continue

        # Scores over the *_de arrays.
        pred_de_mean = results['pred_de'][p_idx].mean(0)
        truth_de_mean = results['truth_de'][p_idx].mean(0)
        for m, fct in metric2fct.items():
            per_pert[m + '_de'] = _score(m, fct, pred_de_mean, truth_de_mean)
            metrics[m + '_de'].append(per_pert[m + '_de'])

    for m in metric2fct:
        metrics[m] = np.mean(metrics[m])
        metrics[m + '_de'] = np.mean(metrics[m + '_de'])

    return metrics, metrics_pert
| |
|
def filter_pert_in_go(condition, pert_names):
    """
    Filter perturbations in GO graph

    Args:
        condition (str): condition label - ``'ctrl'`` or gene names joined
            by ``'+'`` (e.g. ``'A+ctrl'``, ``'A+B'``)
        pert_names (list): list of perturbations

    Returns:
        bool: True when every component of ``condition`` is either
        ``'ctrl'`` or contained in ``pert_names``.
    """
    if condition == 'ctrl':
        return True

    # Checking every '+'-separated component also covers labels with a single
    # gene and no '+' (which used to raise IndexError) and labels with more
    # than two components.
    return all(part == 'ctrl' or part in pert_names
               for part in condition.split('+'))
| |
|
| | class PertData: |
def __init__(self, data_path,
             gene_set_path=None,
             default_pert_graph=True):
    """
    Set up an empty data container and load the gene-to-GO mapping.

    Args:
        data_path (str): data directory; created (with parents) if missing
        gene_set_path (str): optional path to a pickled custom gene list
        default_pert_graph (bool): whether the default perturbation graph
            is to be used

    Raises:
        FileNotFoundError: if ``gene2go_all.pkl`` is not in ``data_path``
    """
    # Dataset / dataloader related state.
    self.data_path = data_path
    self.default_pert_graph = default_pert_graph
    self.gene_set_path = gene_set_path
    self.dataset_name = None
    self.dataset_path = None
    self.adata = None
    self.dataset_processed = None
    self.ctrl_adata = None
    self.gene_names = []
    self.node_map = {}

    # Split related state.
    self.split = None
    self.seed = None
    self.subgroup = None
    self.train_gene_set_size = None

    # makedirs handles nested paths and an already existing directory.
    os.makedirs(self.data_path, exist_ok=True)

    # NOTE: gene2go_all.pkl must already be present in data_path; nothing
    # here fetches it.
    # SECURITY: pickle.load can execute arbitrary code; the file must come
    # from a trusted source.
    with open(os.path.join(self.data_path, 'gene2go_all.pkl'), 'rb') as f:
        self.gene2go = pickle.load(f)
| | |
def set_pert_genes(self):
    """
    Set the list of genes that can be perturbed and are to be included in
    perturbation graph

    Sets ``self.pert_names`` (sorted unique gene names that have an entry in
    ``self.gene2go``) and ``self.node_map_pert`` (name -> index). When a
    custom gene set path is given, ``self.default_pert_graph`` is switched
    off.
    """
    # SECURITY: pickle.load can execute arbitrary code; both pickle files
    # below must come from a trusted source.
    if self.gene_set_path is not None:
        # Custom gene set supplied by the user.
        self.default_pert_graph = False
        with open(self.gene_set_path, 'rb') as f:
            essential_genes = pickle.load(f)

    elif self.default_pert_graph is False:
        # All measured genes plus every gene named in a perturbation.
        all_pert_genes = get_genes_from_perts(self.adata.obs['condition'])
        essential_genes = list(self.adata.var['gene_name'].values)
        essential_genes += all_pert_genes

    else:
        # Default gene set; the file is expected to exist in data_path.
        path_ = os.path.join(self.data_path,
                             'essential_all_data_pert_genes.pkl')
        with open(path_, 'rb') as f:
            essential_genes = pickle.load(f)

    # Only genes with GO annotations can be placed in the graph.
    gene2go = {i: self.gene2go[i] for i in essential_genes if i in self.gene2go}

    self.pert_names = np.unique(list(gene2go))
    self.node_map_pert = {x: it for it, x in enumerate(self.pert_names)}
| | |
| | def load(self, data_name = None, data_path = None): |
| | if data_name in ['norman', 'adamson', 'dixit', |
| | 'replogle_k562_essential', |
| | 'replogle_rpe1_essential']: |
| | data_path = os.path.join(self.data_path, data_name) |
| | |
| | self.dataset_name = data_path.split('/')[-1] |
| | self.dataset_path = data_path |
| | adata_path = os.path.join(data_path, 'perturb_processed.h5ad') |
| | self.adata = sc.read_h5ad(adata_path) |
| |
|
| | elif os.path.exists(data_path): |
| | adata_path = os.path.join(data_path, 'perturb_processed.h5ad') |
| | self.adata = sc.read_h5ad(adata_path) |
| | self.dataset_name = data_path.split('/')[-1] |
| | self.dataset_path = data_path |
| | else: |
| | raise ValueError("data attribute is either norman, adamson, dixit " |
| | "replogle_k562 or replogle_rpe1 " |
| | "or a path to an h5ad file") |
| | |
| | self.set_pert_genes() |
| | print_sys('These perturbations are not in the GO graph and their ' |
| | 'perturbation can thus not be predicted') |
| | not_in_go_pert = np.array(self.adata.obs[ |
| | self.adata.obs.condition.apply( |
| | lambda x:not filter_pert_in_go(x, |
| | self.pert_names))].condition.unique()) |
| | print_sys(not_in_go_pert) |
| | |
| | filter_go = self.adata.obs[self.adata.obs.condition.apply( |
| | lambda x: filter_pert_in_go(x, self.pert_names))] |
| | self.adata = self.adata[filter_go.index.values, :] |
| | pyg_path = os.path.join(data_path, 'data_pyg') |
| | if not os.path.exists(pyg_path): |
| | os.mkdir(pyg_path) |
| | dataset_fname = os.path.join(pyg_path, 'cell_graphs.pkl') |
| | |
| | if os.path.isfile(dataset_fname): |
| | print_sys("Local copy of pyg dataset is detected. Loading...") |
| | self.dataset_processed = pickle.load(open(dataset_fname, "rb")) |
| | print_sys("Done!") |
| | else: |
| | self.ctrl_adata = self.adata[self.adata.obs['condition'] == 'ctrl'] |
| | self.gene_names = self.adata.var.gene_name |
| | |
| | |
| | print_sys("Creating pyg object for each cell in the data...") |
| | self.create_dataset_file() |
| | print_sys("Saving new dataset pyg object at " + dataset_fname) |
| | pickle.dump(self.dataset_processed, open(dataset_fname, "wb")) |
| | print_sys("Done!") |
| | |
| | |
| | def prepare_split(self, split = 'simulation', |
| | seed = 1, |
| | train_gene_set_size = 0.75, |
| | combo_seen2_train_frac = 0.75, |
| | combo_single_split_test_set_fraction = 0.1, |
| | test_perts = None, |
| | only_test_set_perts = False, |
| | test_pert_genes = None, |
| | split_dict_path=None): |
| |
|
| | """ |
| | Prepare splits for training and testing |
| | |
| | Parameters |
| | ---------- |
| | split: str |
| | Type of split to use. Currently, we support 'simulation', |
| | 'simulation_single', 'combo_seen0', 'combo_seen1', 'combo_seen2', |
| | 'single', 'no_test', 'no_split', 'custom' |
| | seed: int |
| | Random seed |
| | train_gene_set_size: float |
| | Fraction of genes to use for training |
| | combo_seen2_train_frac: float |
| | Fraction of combo seen2 perturbations to use for training |
| | combo_single_split_test_set_fraction: float |
| | Fraction of combo single perturbations to use for testing |
| | test_perts: list |
| | List of perturbations to use for testing |
| | only_test_set_perts: bool |
| | If True, only use test set perturbations for testing |
| | test_pert_genes: list |
| | List of genes to use for testing |
| | split_dict_path: str |
| | Path to dictionary used for custom split. Sample format: |
| | {'train': [X, Y], 'val': [P, Q], 'test': [Z]} |
| | |
| | Returns |
| | ------- |
| | None |
| | |
| | """ |
| | available_splits = ['simulation', 'simulation_single', 'combo_seen0', |
| | 'combo_seen1', 'combo_seen2', 'single', 'no_test', |
| | 'no_split', 'custom'] |
| | if split not in available_splits: |
| | raise ValueError('currently, we only support ' + ','.join(available_splits)) |
| | self.split = split |
| | self.seed = seed |
| | self.subgroup = None |
| | |
| | if split == 'custom': |
| | try: |
| | with open(split_dict_path, 'rb') as f: |
| | self.set2conditions = pickle.load(f) |
| | except: |
| | raise ValueError('Please set split_dict_path for custom split') |
| | return |
| | |
| | self.train_gene_set_size = train_gene_set_size |
| | split_folder = os.path.join(self.dataset_path, 'splits') |
| | if not os.path.exists(split_folder): |
| | os.mkdir(split_folder) |
| | split_file = self.dataset_name + '_' + split + '_' + str(seed) + '_' \ |
| | + str(train_gene_set_size) + '.pkl' |
| | split_path = os.path.join(split_folder, split_file) |
| | |
| | if test_perts: |
| | split_path = split_path[:-4] + '_' + test_perts + '.pkl' |
| | |
| | if os.path.exists(split_path): |
| | print('here1') |
| | print_sys("Local copy of split is detected. Loading...") |
| | set2conditions = pickle.load(open(split_path, "rb")) |
| | if split == 'simulation': |
| | subgroup_path = split_path[:-4] + '_subgroup.pkl' |
| | subgroup = pickle.load(open(subgroup_path, "rb")) |
| | self.subgroup = subgroup |
| | else: |
| | print_sys("Creating new splits....") |
| | if test_perts: |
| | test_perts = test_perts.split('_') |
| | |
| | if split in ['simulation', 'simulation_single']: |
| | |
| | DS = DataSplitter(self.adata, split_type=split) |
| | |
| | adata, subgroup = DS.split_data(train_gene_set_size = train_gene_set_size, |
| | combo_seen2_train_frac = combo_seen2_train_frac, |
| | seed=seed, |
| | test_perts = test_perts, |
| | only_test_set_perts = only_test_set_perts |
| | ) |
| | subgroup_path = split_path[:-4] + '_subgroup.pkl' |
| | pickle.dump(subgroup, open(subgroup_path, "wb")) |
| | self.subgroup = subgroup |
| | |
| | elif split[:5] == 'combo': |
| | |
| | split_type = 'combo' |
| | seen = int(split[-1]) |
| |
|
| | if test_pert_genes: |
| | test_pert_genes = test_pert_genes.split('_') |
| | |
| | DS = DataSplitter(self.adata, split_type=split_type, seen=int(seen)) |
| | adata = DS.split_data(test_size=combo_single_split_test_set_fraction, |
| | test_perts=test_perts, |
| | test_pert_genes=test_pert_genes, |
| | seed=seed) |
| |
|
| | elif split == 'single': |
| | |
| | DS = DataSplitter(self.adata, split_type=split) |
| | adata = DS.split_data(test_size=combo_single_split_test_set_fraction, |
| | seed=seed) |
| |
|
| | elif split == 'no_test': |
| | |
| | DS = DataSplitter(self.adata, split_type=split) |
| | adata = DS.split_data(seed=seed) |
| | |
| | elif split == 'no_split': |
| | |
| | adata = self.adata |
| | adata.obs['split'] = 'test' |
| | |
| | set2conditions = dict(adata.obs.groupby('split').agg({'condition': |
| | lambda x: x}).condition) |
| | set2conditions = {i: j.unique().tolist() for i,j in set2conditions.items()} |
| | pickle.dump(set2conditions, open(split_path, "wb")) |
| | print_sys("Saving new splits at " + split_path) |
| | |
| | self.set2conditions = set2conditions |
| |
|
| | if split == 'simulation': |
| | print_sys('Simulation split test composition:') |
| | for i,j in subgroup['test_subgroup'].items(): |
| | print_sys(i + ':' + str(len(j))) |
| | print_sys("Done!") |
| | |
| | def get_dataloader(self, batch_size, test_batch_size = None): |
| | """ |
| | Get dataloaders for training and testing |
| | |
| | Parameters |
| | ---------- |
| | batch_size: int |
| | Batch size for training |
| | test_batch_size: int |
| | Batch size for testing |
| | |
| | Returns |
| | ------- |
| | dict |
| | Dictionary of dataloaders |
| | |
| | """ |
| | if test_batch_size is None: |
| | test_batch_size = batch_size |
| | |
| | self.node_map = {x: it for it, x in enumerate(self.adata.var.gene_name)} |
| | self.gene_names = self.adata.var.gene_name |
| | |
| | |
| | cell_graphs = {} |
| | if self.split == 'no_split': |
| | i = 'test' |
| | cell_graphs[i] = [] |
| | for p in self.set2conditions[i]: |
| | if p != 'ctrl': |
| | cell_graphs[i].extend(self.dataset_processed[p]) |
| | |
| | print_sys("Creating dataloaders....") |
| | |
| | test_loader = DataLoader(cell_graphs['test'], |
| | batch_size=batch_size, shuffle=False) |
| |
|
| | print_sys("Dataloaders created...") |
| | return {'test_loader': test_loader} |
| | else: |
| | if self.split =='no_test': |
| | splits = ['train','val'] |
| | else: |
| | splits = ['train','val','test'] |
| | for i in splits: |
| | cell_graphs[i] = [] |
| | for p in self.set2conditions[i]: |
| | cell_graphs[i].extend(self.dataset_processed[p]) |
| |
|
| | print_sys("Creating dataloaders....") |
| | |
| | |
| | train_loader = DataLoader(cell_graphs['train'], |
| | batch_size=batch_size, shuffle=True, drop_last = True) |
| | val_loader = DataLoader(cell_graphs['val'], |
| | batch_size=batch_size, shuffle=True) |
| | |
| | if self.split !='no_test': |
| | test_loader = DataLoader(cell_graphs['test'], |
| | batch_size=batch_size, shuffle=False) |
| | self.dataloader = {'train_loader': train_loader, |
| | 'val_loader': val_loader, |
| | 'test_loader': test_loader} |
| |
|
| | else: |
| | self.dataloader = {'train_loader': train_loader, |
| | 'val_loader': val_loader} |
| | print_sys("Done!") |
| |
|
| | def get_pert_idx(self, pert_category): |
| | """ |
| | Get perturbation index for a given perturbation category |
| | |
| | Parameters |
| | ---------- |
| | pert_category: str |
| | Perturbation category |
| | |
| | Returns |
| | ------- |
| | list |
| | List of perturbation indices |
| | |
| | """ |
| | try: |
| | pert_idx = [np.where(p == self.pert_names)[0][0] |
| | for p in pert_category.split('+') |
| | if p != 'ctrl'] |
| | except: |
| | print(pert_category) |
| | pert_idx = None |
| | |
| | return pert_idx |
| |
|
| | def create_cell_graph(self, X, y, de_idx, pert, pert_idx=None): |
| | """ |
| | Create a cell graph from a given cell |
| | |
| | Parameters |
| | ---------- |
| | X: np.ndarray |
| | Gene expression matrix |
| | y: np.ndarray |
| | Label vector |
| | de_idx: np.ndarray |
| | DE gene indices |
| | pert: str |
| | Perturbation category |
| | pert_idx: list |
| | List of perturbation indices |
| | |
| | Returns |
| | ------- |
| | torch_geometric.data.Data |
| | Cell graph to be used in dataloader |
| | |
| | """ |
| |
|
| | feature_mat = torch.Tensor(X).T |
| | if pert_idx is None: |
| | pert_idx = [-1] |
| | return Data(x=feature_mat, pert_idx=pert_idx, |
| | y=torch.Tensor(y), de_idx=de_idx, pert=pert) |
| |
|
| | def create_cell_graph_dataset(self, split_adata, pert_category, |
| | num_samples=1): |
| | """ |
| | Combine cell graphs to create a dataset of cell graphs |
| | |
| | Parameters |
| | ---------- |
| | split_adata: anndata.AnnData |
| | Annotated data matrix |
| | pert_category: str |
| | Perturbation category |
| | num_samples: int |
| | Number of samples to create per perturbed cell (i.e. number of |
| | control cells to map to each perturbed cell) |
| | |
| | Returns |
| | ------- |
| | list |
| | List of cell graphs |
| | |
| | """ |
| |
|
| | num_de_genes = 20 |
| | adata_ = split_adata[split_adata.obs['condition'] == pert_category] |
| | if 'rank_genes_groups_cov_all' in adata_.uns: |
| | de_genes = adata_.uns['rank_genes_groups_cov_all'] |
| | de = True |
| | else: |
| | de = False |
| | num_de_genes = 1 |
| | Xs = [] |
| | ys = [] |
| |
|
| | |
| | if pert_category != 'ctrl': |
| | |
| | pert_idx = self.get_pert_idx(pert_category) |
| |
|
| | |
| | pert_de_category = adata_.obs['condition_name'][0] |
| | if de: |
| | de_idx = np.where(adata_.var_names.isin( |
| | np.array(de_genes[pert_de_category][:num_de_genes])))[0] |
| | else: |
| | de_idx = [-1] * num_de_genes |
| | for cell_z in adata_.X: |
| | |
| | ctrl_samples = self.ctrl_adata[np.random.randint(0, |
| | len(self.ctrl_adata), num_samples), :] |
| | for c in ctrl_samples.X: |
| | Xs.append(c) |
| | ys.append(cell_z) |
| |
|
| | |
| | else: |
| | pert_idx = None |
| | de_idx = [-1] * num_de_genes |
| | for cell_z in adata_.X: |
| | Xs.append(cell_z) |
| | ys.append(cell_z) |
| |
|
| | |
| | cell_graphs = [] |
| | for X, y in zip(Xs, ys): |
| | cell_graphs.append(self.create_cell_graph(X.toarray(), |
| | y.toarray(), de_idx, pert_category, pert_idx)) |
| |
|
| | return cell_graphs |
| |
|
| | def create_dataset_file(self): |
| | """ |
| | Create dataset file for each perturbation condition |
| | """ |
| | print_sys("Creating dataset file...") |
| | self.dataset_processed = {} |
| | for p in tqdm(self.adata.obs['condition'].unique()): |
| | self.dataset_processed[p] = self.create_cell_graph_dataset(self.adata, p) |
| | print_sys("Done!") |
| |
|
| |
|
def main(data_path='./data', out_dir='./saved_models', device='cuda:0',
         epochs=None):
    """
    Load the 'norman' dataset, train a GEARS model, save it and write the
    test metric to ``final_info.json`` in ``out_dir``.

    Parameters
    ----------
    data_path: str
        Directory holding the datasets (created if missing)
    out_dir: str
        Directory for the saved model and ``final_info.json`` (created if
        missing)
    device: str
        Device string handed to the GEARS wrapper
    epochs: int
        Number of training epochs. When None, the value is taken from the
        module-level CLI namespace ``args`` if one exists, otherwise 20.

    Returns
    -------
    None
    """
    if epochs is None:
        # The previous revision read the global ``args`` unconditionally,
        # which raised NameError when main() was called from an import.
        cli_args = globals().get('args')
        epochs = getattr(cli_args, 'epochs', 20)

    os.makedirs(data_path, exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)

    os.environ["WANDB_SILENT"] = "true"
    os.environ["WANDB_ERROR_REPORTING"] = "false"

    print_sys("=== data loading ===")
    pert_data = PertData(data_path)
    pert_data.load(data_name='norman')
    pert_data.prepare_split(split='simulation', seed=1)
    pert_data.get_dataloader(batch_size=32, test_batch_size=128)

    print_sys("\n=== model traing ===")
    gears_model = GEARS(
        pert_data,
        device=device,
        weight_bias_track=True,
        proj_name='GEARS',
        exp_name='gears_norman'
    )
    gears_model.model_initialize(hidden_size=64)

    gears_model.train(epochs=epochs, lr=1e-3)

    model_dir = os.path.join(out_dir, 'norman_full_model')
    gears_model.save_model(model_dir)
    print_sys(f"model saved to {out_dir}")
    gears_model.load_pretrained(model_dir)

    final_infos = {
        "Gears": {
            "means": {
                "Test Top 20 DE MSE": float(gears_model.test_metrics['mse_de'].item())
            }
        }
    }

    with open(os.path.join(out_dir, 'final_info.json'), 'w') as f:
        json.dump(final_infos, f, indent=4)
    print_sys("final info saved.")
| | |
def print_sys(s):
    """system print

    Writes ``s`` to stderr (flushed) and forwards it to the root logger at
    INFO level. When a module-level CLI namespace ``args`` exists, the root
    logger is configured once to write to ``args.out_dir/args.log_file``;
    without it (e.g. when the module is imported) the message is still
    printed and no log file is configured.

    Args:
        s (str): the string to print
    """
    print(s, flush=True, file=sys.stderr)

    # The previous revision read the global ``args`` unconditionally, which
    # raised NameError outside the CLI, and called basicConfig on every call.
    cli_args = globals().get('args')
    if cli_args is not None and not getattr(print_sys, '_log_configured', False):
        log_path = os.path.join(cli_args.out_dir, cli_args.log_file)
        logging.basicConfig(
            filename=log_path,
            level=logging.INFO,
        )
        print_sys._log_configured = True

    logging.getLogger().info(s)
| |
|
| |
|
# Script entry point: parse CLI options, run main(), and on failure write the
# traceback to <out_dir>/traceback.log before re-raising.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='./data')
    parser.add_argument('--out_dir', type=str, default='run_1')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--log_file', type=str, default="training_ds.log")
    parser.add_argument('--epochs', type=int, default=20)
    # ``args`` is intentionally module-level: main() and print_sys() read it.
    args = parser.parse_args()

    try:
        main(
            data_path=args.data_path,
            out_dir=args.out_dir,
            device=args.device
        )
    except Exception:
        print("Origin error in main process:", flush=True)
        # Ensure the directory exists and close the log file deterministically.
        os.makedirs(args.out_dir, exist_ok=True)
        with open(os.path.join(args.out_dir, "traceback.log"), "w") as tb_file:
            traceback.print_exc(file=tb_file)
        raise
| |
|
| | |