Upload in_silico_perturber_stats.py
#313
by
davidjwen
- opened
geneformer/in_silico_perturber_stats.py
CHANGED
|
@@ -192,16 +192,27 @@ def get_impact_component(test_value, gaussian_mixture_model):
|
|
| 192 |
|
| 193 |
|
| 194 |
# aggregate data for single perturbation in multiple cells
|
| 195 |
-
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
|
| 196 |
-
names = ["Cosine_shift"]
|
| 197 |
-
|
| 198 |
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
|
| 207 |
def find(variable, x):
|
|
@@ -1017,8 +1028,8 @@ class InSilicoPerturberStats:
|
|
| 1017 |
cos_sims_df_initial, dict_list, self.combos, self.anchor_token
|
| 1018 |
)
|
| 1019 |
|
| 1020 |
-
elif self.mode == "aggregate_data":
|
| 1021 |
-
cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
|
| 1022 |
|
| 1023 |
elif self.mode == "aggregate_gene_shifts":
|
| 1024 |
cos_sims_df = isp_aggregate_gene_shifts(
|
|
|
|
| 192 |
|
| 193 |
|
| 194 |
# aggregate data for single perturbation in multiple cells
|
| 195 |
+
def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
    """Collect per-cell cosine shifts for each perturbed gene into one table.

    Args:
        cos_sims_df: DataFrame with at least the columns "Ensembl_ID",
            "Gene" and "Gene_name". "Gene" is used as the lookup token
            and "Gene_name" as the label written to the output.
        dict_list: Iterable of dicts keyed by ``(token, "cell_emb")``.
            Each value is presumably a list of cosine shifts, one per
            cell -- confirm against the caller.
        genes_perturbed: Ensembl IDs selecting which rows of
            ``cos_sims_df`` are aggregated.

    Returns:
        DataFrame with columns "Cosine_shift" and "Gene", holding one row
        per collected shift value. If no row of ``cos_sims_df`` matches
        ``genes_perturbed``, an empty DataFrame with those two columns is
        returned.
    """
    names = ["Cosine_shift", "Gene"]

    # Keep only the rows describing the genes that were actually perturbed.
    is_perturbed = np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed)
    gene_ids_df = cos_sims_df.loc[is_perturbed, :]

    cos_sims_full_dfs = []
    for token, symbol in zip(gene_ids_df["Gene"], gene_ids_df["Gene_name"]):
        # Pool the shifts recorded for this token across every dict;
        # dicts that lack the key contribute nothing.
        cos_shift_data = [
            shift
            for dict_i in dict_list
            for shift in dict_i.get((token, "cell_emb"), [])
        ]

        df = pd.DataFrame(columns=names)
        df["Cosine_shift"] = cos_shift_data
        df["Gene"] = symbol
        cos_sims_full_dfs.append(df)

    # pd.concat raises ValueError on an empty list, so return an empty
    # table with the expected columns when nothing matched.
    if not cos_sims_full_dfs:
        return pd.DataFrame(columns=names)

    return pd.concat(cos_sims_full_dfs)
|
| 216 |
|
| 217 |
|
| 218 |
def find(variable, x):
|
|
|
|
| 1028 |
cos_sims_df_initial, dict_list, self.combos, self.anchor_token
|
| 1029 |
)
|
| 1030 |
|
| 1031 |
+
elif self.mode == "aggregate_data":
|
| 1032 |
+
cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
|
| 1033 |
|
| 1034 |
elif self.mode == "aggregate_gene_shifts":
|
| 1035 |
cos_sims_df = isp_aggregate_gene_shifts(
|