Yuning You
committed on
Commit
·
2f63b5b
1
Parent(s):
9565da7
update
Browse files
models/cifm.py
CHANGED
|
@@ -32,8 +32,8 @@ class CIFM(
|
|
| 32 |
self.hidden_dim = args.hidden_dim
|
| 33 |
self.radius_spatial_graph = args.radius_spatial_graph
|
| 34 |
|
| 35 |
-
def channel_matching(self,
|
| 36 |
-
channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]
|
| 37 |
|
| 38 |
linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
|
| 39 |
linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
|
|
@@ -97,8 +97,6 @@ class CIFM(
|
|
| 97 |
expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
|
| 98 |
dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
|
| 99 |
|
| 100 |
-
# import pdb ; pdb.set_trace()
|
| 101 |
-
|
| 102 |
expressions_dec[dropouts_dec<=0.5] = 0
|
| 103 |
return expressions_dec
|
| 104 |
|
|
|
|
| 32 |
self.hidden_dim = args.hidden_dim
|
| 33 |
self.radius_spatial_graph = args.radius_spatial_graph
|
| 34 |
|
| 35 |
+
def channel_matching(self, channel2ensembl_ids_target, channel2ensembl_ids_source, zero_init_for_unmatched_genes=True):
|
| 36 |
+
# channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]
|
| 37 |
|
| 38 |
linear_in = nn.Linear(len(channel2ensembl_ids_target), self.hidden_dim, bias=False)
|
| 39 |
linear_out1 = nn.Linear(self.hidden_dim, len(channel2ensembl_ids_target), bias=False)
|
|
|
|
| 97 |
expressions_dec = self.relu(self.mask_cell_expression(embeddings_dec))
|
| 98 |
dropouts_dec = self.sigmoid(self.mask_cell_dropout(embeddings_dec))
|
| 99 |
|
|
|
|
|
|
|
| 100 |
expressions_dec[dropouts_dec<=0.5] = 0
|
| 101 |
return expressions_dec
|
| 102 |
|
models/egnn_void_invariant.py
CHANGED
|
@@ -48,15 +48,6 @@ class VIEGNNModel(torch.nn.Module):
|
|
| 48 |
self.convs = torch.nn.ModuleList()
|
| 49 |
for _ in range(num_layers):
|
| 50 |
self.convs.append(EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr))
|
| 51 |
-
|
| 52 |
-
# MLP predictor for invariant tasks using only scalar features
|
| 53 |
-
# self.pred = torch.nn.Sequential(
|
| 54 |
-
# torch.nn.Linear(emb_dim, emb_dim, bias=False),
|
| 55 |
-
# torch.nn.ReLU(),
|
| 56 |
-
# torch.nn.Linear(emb_dim, out_dim, bias=False)
|
| 57 |
-
# )
|
| 58 |
-
# layers = [torch.nn.Linear(emb_dim, emb_dim, bias=False), torch.nn.ReLU()] * (num_mlp_layers_in_module-1) + [torch.nn.Linear(emb_dim, out_dim, bias=False)]
|
| 59 |
-
# self.pred = torch.nn.Sequential(*layers)
|
| 60 |
self.pred = MLPBiasFree(in_dim=emb_dim, out_dim=out_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers_in_module)
|
| 61 |
|
| 62 |
# unroll the batch arguments and comment out the pooling operation
|
|
|
|
| 48 |
self.convs = torch.nn.ModuleList()
|
| 49 |
for _ in range(num_layers):
|
| 50 |
self.convs.append(EGNNLayer(emb_dim, num_mlp_layers_in_module, aggr))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
self.pred = MLPBiasFree(in_dim=emb_dim, out_dim=out_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers_in_module)
|
| 52 |
|
| 53 |
# unroll the batch arguments and comment out the pooling operation
|
models/layers/__init__.py
DELETED
|
File without changes
|
models/layers/__pycache__/__init__.cpython-311.pyc
DELETED
|
Binary file (170 Bytes)
|
|
|
models/layers/__pycache__/egnn_layer_void_invariant.cpython-311.pyc
DELETED
|
Binary file (4.8 kB)
|
|
|
models/layers/egnn_layer_void_invariant.py
CHANGED
|
@@ -23,49 +23,11 @@ class EGNNLayer(MessagePassing):
|
|
| 23 |
super().__init__(aggr=aggr)
|
| 24 |
|
| 25 |
self.emb_dim = emb_dim
|
| 26 |
-
# self.activation = ReLU()
|
| 27 |
|
| 28 |
self.dist_embedding = Linear(1, emb_dim, bias=False)
|
| 29 |
self.innerprod_embedding = MLPBiasFree(in_dim=1, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
|
| 30 |
-
|
| 31 |
-
# MLP `\psi_h` for computing messages `m_ij`
|
| 32 |
-
# self.mlp_msg = Sequential(
|
| 33 |
-
# Linear(2 * emb_dim + 1, emb_dim, bias=False),
|
| 34 |
-
# torch.nn.LayerNorm(emb_dim, bias=False),
|
| 35 |
-
# self.activation,
|
| 36 |
-
# Linear(emb_dim, emb_dim, bias=False),
|
| 37 |
-
# torch.nn.LayerNorm(emb_dim, bias=False),
|
| 38 |
-
# self.activation,
|
| 39 |
-
# )
|
| 40 |
-
# layers = [Linear(2 * emb_dim + 1, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] \
|
| 41 |
-
# + [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * (num_mlp_layers-1)
|
| 42 |
-
# layers = [Linear(3 * emb_dim, emb_dim, bias=False)] \
|
| 43 |
-
# + [self.activation, Linear(emb_dim, emb_dim, bias=False)] * (num_mlp_layers-1) \
|
| 44 |
-
# + [torch.nn.LayerNorm(emb_dim, bias=False)]
|
| 45 |
-
# self.mlp_msg = Sequential(*layers)
|
| 46 |
self.mlp_msg = MLPBiasFree(in_dim=3*emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
|
| 47 |
-
|
| 48 |
-
# MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
|
| 49 |
-
# self.mlp_pos = Sequential(
|
| 50 |
-
# Linear(emb_dim, emb_dim), torch.nn.LayerNorm(emb_dim), self.activation, Linear(emb_dim, 1)
|
| 51 |
-
# )
|
| 52 |
-
# layers = [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * (num_mlp_layers-1) + [Linear(emb_dim, 1, bias=False)]
|
| 53 |
-
# layers = [Linear(emb_dim, emb_dim, bias=False), self.activation] * (num_mlp_layers-1) + [Linear(emb_dim, 1, bias=False)]
|
| 54 |
-
# self.mlp_pos = Sequential(*layers)
|
| 55 |
self.mlp_pos = MLPBiasFree(in_dim=emb_dim, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
|
| 56 |
-
|
| 57 |
-
# MLP `\phi` for computing updated node features `h_i^{l+1}`
|
| 58 |
-
# self.mlp_upd = Sequential(
|
| 59 |
-
# Linear(2 * emb_dim, emb_dim, bias=False),
|
| 60 |
-
# torch.nn.LayerNorm(emb_dim, bias=False),
|
| 61 |
-
# self.activation,
|
| 62 |
-
# Linear(emb_dim, emb_dim, bias=False),
|
| 63 |
-
# torch.nn.LayerNorm(emb_dim, bias=False),
|
| 64 |
-
# self.activation,
|
| 65 |
-
# )
|
| 66 |
-
# layers = [Linear(emb_dim, emb_dim, bias=False), torch.nn.LayerNorm(emb_dim, bias=False), self.activation] * num_mlp_layers
|
| 67 |
-
# layers = [Linear(emb_dim, emb_dim, bias=False)] + [self.activation, Linear(emb_dim, emb_dim, bias=False)] * (num_mlp_layers-1)
|
| 68 |
-
# self.mlp_upd = Sequential(*layers)
|
| 69 |
self.mlp_upd = MLPBiasFree(in_dim=emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
|
| 70 |
|
| 71 |
def forward(self, h, pos, edge_index):
|
|
@@ -83,7 +45,6 @@ class EGNNLayer(MessagePassing):
|
|
| 83 |
def message(self, h_i, h_j, pos_i, pos_j):
|
| 84 |
# Compute messages
|
| 85 |
pos_diff = pos_i - pos_j
|
| 86 |
-
# dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
|
| 87 |
dists = torch.exp(- torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30 ) # reference distances: 30um
|
| 88 |
inner_prod = torch.mean(h_i * h_j, dim=-1).unsqueeze(1)
|
| 89 |
msg = torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1) * self.innerprod_embedding(inner_prod)
|
|
@@ -91,10 +52,6 @@ class EGNNLayer(MessagePassing):
|
|
| 91 |
# Scale magnitude of displacement vector
|
| 92 |
pos_diff = pos_diff * self.mlp_pos(msg)
|
| 93 |
# NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
|
| 94 |
-
# NOTE: lucidrains clamps pos_diff between some [-n, +n], also for stability.
|
| 95 |
-
# print(torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1))
|
| 96 |
-
# print(msg)
|
| 97 |
-
# import pdb; pdb.set_trace()
|
| 98 |
return msg, pos_diff, inner_prod
|
| 99 |
|
| 100 |
def aggregate(self, inputs, index):
|
|
@@ -109,17 +66,12 @@ class EGNNLayer(MessagePassing):
|
|
| 109 |
counts = scatter(counts, index, dim=0, reduce="add")
|
| 110 |
counts[counts==0] = 1
|
| 111 |
pos_aggr = pos_aggr / counts
|
| 112 |
-
# print(msgs)
|
| 113 |
-
# print(msg_aggr)
|
| 114 |
-
# import pdb; pdb.set_trace()
|
| 115 |
return msg_aggr, pos_aggr
|
| 116 |
|
| 117 |
def update(self, aggr_out, h, pos):
|
| 118 |
msg_aggr, pos_aggr = aggr_out
|
| 119 |
-
# upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
|
| 120 |
upd_out = self.mlp_upd(msg_aggr)
|
| 121 |
upd_pos = pos + pos_aggr
|
| 122 |
-
# import pdb; pdb.set_trace()
|
| 123 |
return upd_out, upd_pos
|
| 124 |
|
| 125 |
def __repr__(self) -> str:
|
|
|
|
| 23 |
super().__init__(aggr=aggr)
|
| 24 |
|
| 25 |
self.emb_dim = emb_dim
|
|
|
|
| 26 |
|
| 27 |
self.dist_embedding = Linear(1, emb_dim, bias=False)
|
| 28 |
self.innerprod_embedding = MLPBiasFree(in_dim=1, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
self.mlp_msg = MLPBiasFree(in_dim=3*emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
self.mlp_pos = MLPBiasFree(in_dim=emb_dim, out_dim=1, hidden_dim=emb_dim, num_layer=num_mlp_layers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
self.mlp_upd = MLPBiasFree(in_dim=emb_dim, out_dim=emb_dim, hidden_dim=emb_dim, num_layer=num_mlp_layers)
|
| 32 |
|
| 33 |
def forward(self, h, pos, edge_index):
|
|
|
|
| 45 |
def message(self, h_i, h_j, pos_i, pos_j):
|
| 46 |
# Compute messages
|
| 47 |
pos_diff = pos_i - pos_j
|
|
|
|
| 48 |
dists = torch.exp(- torch.norm(pos_diff, dim=-1).unsqueeze(1) / 30 ) # reference distances: 30um
|
| 49 |
inner_prod = torch.mean(h_i * h_j, dim=-1).unsqueeze(1)
|
| 50 |
msg = torch.cat([h_i, h_j, self.dist_embedding(dists)], dim=-1) * self.innerprod_embedding(inner_prod)
|
|
|
|
| 52 |
# Scale magnitude of displacement vector
|
| 53 |
pos_diff = pos_diff * self.mlp_pos(msg)
|
| 54 |
# NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
return msg, pos_diff, inner_prod
|
| 56 |
|
| 57 |
def aggregate(self, inputs, index):
|
|
|
|
| 66 |
counts = scatter(counts, index, dim=0, reduce="add")
|
| 67 |
counts[counts==0] = 1
|
| 68 |
pos_aggr = pos_aggr / counts
|
|
|
|
|
|
|
|
|
|
| 69 |
return msg_aggr, pos_aggr
|
| 70 |
|
| 71 |
def update(self, aggr_out, h, pos):
|
| 72 |
msg_aggr, pos_aggr = aggr_out
|
|
|
|
| 73 |
upd_out = self.mlp_upd(msg_aggr)
|
| 74 |
upd_pos = pos + pos_aggr
|
|
|
|
| 75 |
return upd_out, upd_pos
|
| 76 |
|
| 77 |
def __repr__(self) -> str:
|
test.ipynb
CHANGED
|
@@ -12,9 +12,16 @@
|
|
| 12 |
"import scanpy as sc"
|
| 13 |
]
|
| 14 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
{
|
| 16 |
"cell_type": "code",
|
| 17 |
-
"execution_count":
|
| 18 |
"metadata": {},
|
| 19 |
"outputs": [
|
| 20 |
{
|
|
@@ -94,13 +101,25 @@
|
|
| 94 |
],
|
| 95 |
"source": [
|
| 96 |
"args_model = torch.load('./model_files/args.pt')\n",
|
| 97 |
-
"
|
|
|
|
| 98 |
"model.eval()"
|
| 99 |
]
|
| 100 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
{
|
| 102 |
"cell_type": "code",
|
| 103 |
-
"execution_count":
|
| 104 |
"metadata": {},
|
| 105 |
"outputs": [
|
| 106 |
{
|
|
@@ -120,7 +139,6 @@
|
|
| 120 |
}
|
| 121 |
],
|
| 122 |
"source": [
|
| 123 |
-
"channel2ensembl = torch.load('./model_files/channel2ensembl.pt')\n",
|
| 124 |
"adata = sc.read_h5ad('./adata.h5ad')\n",
|
| 125 |
"adata.layers['counts'] = adata.X.copy()\n",
|
| 126 |
"sc.pp.normalize_total(adata)\n",
|
|
@@ -128,9 +146,20 @@
|
|
| 128 |
"adata"
|
| 129 |
]
|
| 130 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
{
|
| 132 |
"cell_type": "code",
|
| 133 |
-
"execution_count":
|
| 134 |
"metadata": {},
|
| 135 |
"outputs": [
|
| 136 |
{
|
|
@@ -142,7 +171,16 @@
|
|
| 142 |
}
|
| 143 |
],
|
| 144 |
"source": [
|
| 145 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
]
|
| 147 |
},
|
| 148 |
{
|
|
@@ -174,6 +212,13 @@
|
|
| 174 |
"embeddings, embeddings.shape"
|
| 175 |
]
|
| 176 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
{
|
| 178 |
"cell_type": "code",
|
| 179 |
"execution_count": 5,
|
|
|
|
| 12 |
"import scanpy as sc"
|
| 13 |
]
|
| 14 |
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "markdown",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"source": [
|
| 19 |
+
"### 1. load model"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
{
|
| 23 |
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
"metadata": {},
|
| 26 |
"outputs": [
|
| 27 |
{
|
|
|
|
| 101 |
],
|
| 102 |
"source": [
|
| 103 |
"args_model = torch.load('./model_files/args.pt')\n",
|
| 104 |
+
"device = 'cpu' # or 'cuda\n",
|
| 105 |
+
"model = CIFM.from_pretrained('ynyou/CIFM', args=args_model).to(device)\n",
|
| 106 |
"model.eval()"
|
| 107 |
]
|
| 108 |
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "markdown",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"source": [
|
| 113 |
+
"### 2. load and preprocess sample adata\n",
|
| 114 |
+
"- some requirements for adata:\n",
|
| 115 |
+
"- ```adata.X```: need to the raw count\n",
|
| 116 |
+
"- ```adata.obsm['spatial']```: the coordinates of cells in the unit of micrometer\n",
|
| 117 |
+
"- if in a different unit, it might result in a weird geometric graph: we use a radius 20 (micrometer) to construct the geometric graph in the model, so a different unit might result in a overly sparse or dense graph"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
{
|
| 121 |
"cell_type": "code",
|
| 122 |
+
"execution_count": null,
|
| 123 |
"metadata": {},
|
| 124 |
"outputs": [
|
| 125 |
{
|
|
|
|
| 139 |
}
|
| 140 |
],
|
| 141 |
"source": [
|
|
|
|
| 142 |
"adata = sc.read_h5ad('./adata.h5ad')\n",
|
| 143 |
"adata.layers['counts'] = adata.X.copy()\n",
|
| 144 |
"sc.pp.normalize_total(adata)\n",
|
|
|
|
| 146 |
"adata"
|
| 147 |
]
|
| 148 |
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "markdown",
|
| 151 |
+
"metadata": {},
|
| 152 |
+
"source": [
|
| 153 |
+
"### 3. match feature channels\n",
|
| 154 |
+
"- we need a list which maps feature channels to ensemble ids: ```channel2ensembl_ids_target```\n",
|
| 155 |
+
"- format: ```channel2ensembl_ids_target = [[ensemblid1_for_channel1, ensemblid1_for_channel1, ...], [ensemblid11_for_channel2, ensemblid12_for_channel2, ...], ...]```\n",
|
| 156 |
+
"- one channel could correspond to multiple ensemble ids, e.g., when your original data the channels are annotated with gene names\n",
|
| 157 |
+
"- you can use to BioMart map you each gene name to one or multiple ensemble ids"
|
| 158 |
+
]
|
| 159 |
+
},
|
| 160 |
{
|
| 161 |
"cell_type": "code",
|
| 162 |
+
"execution_count": null,
|
| 163 |
"metadata": {},
|
| 164 |
"outputs": [
|
| 165 |
{
|
|
|
|
| 171 |
}
|
| 172 |
],
|
| 173 |
"source": [
|
| 174 |
+
"channel2ensembl_ids_target = [[i] for i in adata.var.index.tolist()]\n",
|
| 175 |
+
"channel2ensembl_ids_source = torch.load('./model_files/channel2ensembl.pt')\n",
|
| 176 |
+
"model.channel_matching(channel2ensembl_ids_target, channel2ensembl_ids_source)"
|
| 177 |
+
]
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"cell_type": "markdown",
|
| 181 |
+
"metadata": {},
|
| 182 |
+
"source": [
|
| 183 |
+
"### 4. embed the microenvironments centered at each cell"
|
| 184 |
]
|
| 185 |
},
|
| 186 |
{
|
|
|
|
| 212 |
"embeddings, embeddings.shape"
|
| 213 |
]
|
| 214 |
},
|
| 215 |
+
{
|
| 216 |
+
"cell_type": "markdown",
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"source": [
|
| 219 |
+
"### 5. infer the potential gene expressions at certain locations"
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
{
|
| 223 |
"cell_type": "code",
|
| 224 |
"execution_count": 5,
|