File size: 35,185 Bytes
c6a4629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scanpy as sc\n",
    "import numpy as np \n",
    "from scanpy import AnnData\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 14/14 [01:49<00:00,  7.86s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[AnnData object with n_obs Γ— n_vars = 171297 Γ— 62710 backed at 'data/h5ad/reduced/plate1.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 6892456 Γ— 62710 backed at 'data/h5ad/reduced/plate2.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 1417313 Γ— 62710 backed at 'data/h5ad/reduced/plate3.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 155688 Γ— 62710 backed at 'data/h5ad/reduced/plate4.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 254437 Γ— 62710 backed at 'data/h5ad/reduced/plate5.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 7062275 Γ— 62710 backed at 'data/h5ad/reduced/plate6.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 88033 Γ— 62710 backed at 'data/h5ad/reduced/plate7.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 7611223 Γ— 62710 backed at 'data/h5ad/reduced/plate8.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 1217465 Γ— 62710 backed at 'data/h5ad/reduced/plate9.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 201038 Γ— 62710 backed at 'data/h5ad/reduced/plate10.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 2361999 Γ— 62710 backed at 'data/h5ad/reduced/plate11.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 8480425 Γ— 62710 backed at 'data/h5ad/reduced/plate12.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 2612839 Γ— 62710 backed at 'data/h5ad/reduced/plate13.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate',\n",
       " AnnData object with n_obs Γ— n_vars = 6086200 Γ— 62710 backed at 'data/h5ad/reduced/plate14.h5ad'\n",
       "     obs: 'sample', 'gene_count', 'tscp_count', 'mread_count', 'drugname_drugconc', 'drug', 'cell_line', 'sublibrary', 'BARCODE', 'pcnt_mito', 'S_score', 'G2M_score', 'phase', 'pass_filter', 'cell_name', 'plate']"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Open the 14 per-plate files in backed read mode (backed=\"r\").\n",
    "# Other inputs that were used here during development:\n",
    "#   full plates:   ../Data/h5ad/h5ad/plate{i}_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad\n",
    "#   single sample: data/tahoe_sampled_1M.h5ad.gz (wrapped in a one-element list)\n",
    "Xs = list(\n",
    "    sc.read_h5ad(f\"data/h5ad/reduced/plate{i}.h5ad\", backed=\"r\")\n",
    "    for i in tqdm(range(1, 15))\n",
    ")\n",
    "Xs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out data for the validation and test sets.\n",
    "# The *_moa lists hold drug names (matched against obs[\"drug\"]); the *_organ lists hold\n",
    "# cell-line identifiers (matched against obs[\"cell_line\"]).\n",
    "# NOTE(review): judging by the names, one_per_* presumably holds out one member of each\n",
    "# MoA / organ group and one_entire_* a whole group - confirm against the split design.\n",
    "val_one_per_moa = [\"Phenylephrine (hydrochloride)\",\"Darolutamide\", \"palbociclib\", \"Tolmetin\", \"Procainamide (hydrochloride)\", \"Trifluridine\", \"Simotinib\", \"Methylprednisolone succinate\", \"Dapagliflozin\", \"CP21R7\", \"Panobinostat\", \"Tofacitinib\", \"Trametinib\", \"Vinblastine (sulfate)\", \"Temsirolimus\", \"Sunitinib\", \"Ralimetinib dimesylate\", \"Tirabrutinib\", \"GSK1059615\", \"SBI-0640756\", \"Lonafarnib\", \"Retinoic acid\"]\n",
    "val_one_entire_moa = [\"Bortezomib\", \"Ixazomib\", \"Ixazomib citrate\"]\n",
    "\n",
    "val_one_per_organ = [\"CVCL_1724\", \"CVCL_0179\", \"CVCL_1715\", \"CVCL_0366\", \"CVCL_1550\", \"CVCL_0480\", \"CVCL_0069\" ]\n",
    "val_one_entire_organ = [\"CVCL_0359\"]\n",
    "\n",
    "test_one_per_moa = [\"Vilanterol\", \"Flutamide\", \"Ribociclib\", \"Valdecoxib\", \"Ξ³-Oryzanol\", \"Trimetrexate\", \"Tucatinib\", \"Triamcinolone\", \"Dapagliflozin ((2S)-1,2-propanediol, hydrate)\", \"LY2090314\", \"Tucidinostat\", \"Tofacitinib (citrate)\", \"Trametinib (DMSO_TF solvate)\", \"vincristine\", \"Torkinib\", \"Vandetanib\", \"Temuterkib\", \"Tirabrutinib (hydrochloride)\", \"Ipatasertib\", \"Tomivosertib\", \"RMC-6236\", \"Tazarotene\"]\n",
    "test_one_entire_moa = [\"Glasdegib\" , \"Sonidegib\", \"Vismodegib\"]\n",
    "\n",
    "test_one_per_organ = [\"CVCL_0397\", \"CVCL_1239\", \"CVCL_0371\", \"CVCL_1716\", \"CVCL_C466\", \"CVCL_1666\", \"CVCL_0293\"]\n",
    "test_one_entire_organ = [\"CVCL_1094\"]\n",
    "\n",
    "def val_plate_14(x):\n",
    "    \"\"\"Return the obs_names at even positions (0, 2, 4, ...) of x.\"\"\"\n",
    "    return x.obs_names[0::2]\n",
    "\n",
    "\n",
    "def test_plate_14(x):\n",
    "    \"\"\"Return the obs_names at odd positions (1, 3, 5, ...) of x.\"\"\"\n",
    "    return x.obs_names[1::2]\n",
    "\n",
    "\n",
    "def _held_out_mask(x, drugs, cell_lines, plate_14_names):\n",
    "    \"\"\"Boolean mask over the cells of x that belong to one held-out split.\n",
    "\n",
    "    A cell is held out when its drug is in drugs, or its cell line is in cell_lines,\n",
    "    or its plate is \"plate14\" and its name is in plate_14_names.\n",
    "    \"\"\"\n",
    "    obs = x.obs\n",
    "    plate_14_share = (obs[\"plate\"] == \"plate14\") & x.obs_names.isin(plate_14_names)\n",
    "    return obs[\"drug\"].isin(drugs) | obs[\"cell_line\"].isin(cell_lines) | plate_14_share\n",
    "\n",
    "\n",
    "# NOTE(review): the two masks below are not mutually exclusive - a cell treated with a\n",
    "# validation drug on a test cell line satisfies both. Confirm how the caller resolves this.\n",
    "def held_out_val(x):\n",
    "    \"\"\"Mask of validation cells: validation drugs, validation cell lines, even-position plate14 cells.\"\"\"\n",
    "    return _held_out_mask(\n",
    "        x,\n",
    "        val_one_per_moa + val_one_entire_moa,\n",
    "        val_one_per_organ + val_one_entire_organ,\n",
    "        val_plate_14(x),\n",
    "    )\n",
    "\n",
    "\n",
    "def held_out_test(x):\n",
    "    \"\"\"Mask of test cells: test drugs, test cell lines, odd-position plate14 cells.\"\"\"\n",
    "    return _held_out_mask(\n",
    "        x,\n",
    "        test_one_per_moa + test_one_entire_moa,\n",
    "        test_one_per_organ + test_one_entire_organ,\n",
    "        test_plate_14(x),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><style>\n",
       ".dataframe > thead > tr,\n",
       ".dataframe > tbody > tr {\n",
       "  text-align: right;\n",
       "  white-space: pre-wrap;\n",
       "}\n",
       "</style>\n",
       "<small>shape: (380, 4)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>drug</th><th>drugname_drugconc</th><th>len</th><th>dose</th></tr><tr><td>str</td><td>str</td><td>i64</td><td>str</td></tr></thead><tbody><tr><td>&quot;(R)-Verapamil (hydrochloride)&quot;</td><td>&quot;[(&#x27;(R)-Verapamil (hydrochlorid…</td><td>166943</td><td>&quot; 5.0&quot;</td></tr><tr><td>&quot;(S)-Crizotinib&quot;</td><td>&quot;[(&#x27;(S)-Crizotinib&#x27;, 0.5, &#x27;uM&#x27;)…</td><td>86840</td><td>&quot; 0.5&quot;</td></tr><tr><td>&quot;18Ξ²-Glycyrrhetinic acid&quot;</td><td>&quot;[(&#x27;18Ξ²-Glycyrrhetinic acid&#x27;, 5…</td><td>113159</td><td>&quot; 5.0&quot;</td></tr><tr><td>&quot;4EGI-1&quot;</td><td>&quot;[(&#x27;4EGI-1&#x27;, 0.5, &#x27;uM&#x27;)]&quot;</td><td>128549</td><td>&quot; 0.5&quot;</td></tr><tr><td>&quot;5-Azacytidine&quot;</td><td>&quot;[(&#x27;5-Azacytidine&#x27;, 5.0, &#x27;uM&#x27;)]&quot;</td><td>71466</td><td>&quot; 5.0&quot;</td></tr><tr><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td><td>&hellip;</td></tr><tr><td>&quot;olaparib&quot;</td><td>&quot;[(&#x27;olaparib&#x27;, 0.5, &#x27;uM&#x27;)]&quot;</td><td>136783</td><td>&quot; 0.5&quot;</td></tr><tr><td>&quot;palbociclib&quot;</td><td>&quot;[(&#x27;palbociclib&#x27;, 0.5, &#x27;uM&#x27;)]&quot;</td><td>91681</td><td>&quot; 0.5&quot;</td></tr><tr><td>&quot;venetoclax&quot;</td><td>&quot;[(&#x27;venetoclax&#x27;, 0.5, &#x27;uM&#x27;)]&quot;</td><td>118167</td><td>&quot; 0.5&quot;</td></tr><tr><td>&quot;vincristine&quot;</td><td>&quot;[(&#x27;vincristine&#x27;, 0.5, &#x27;uM&#x27;)]&quot;</td><td>35862</td><td>&quot; 0.5&quot;</td></tr><tr><td>&quot;Ξ³-Oryzanol&quot;</td><td>&quot;[(&#x27;Ξ³-Oryzanol&#x27;, 5.0, &#x27;uM&#x27;)]&quot;</td><td>103024</td><td>&quot; 5.0&quot;</td></tr></tbody></table></div>"
      ],
      "text/plain": [
       "shape: (380, 4)\n",
       "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”\n",
       "β”‚ drug                          ┆ drugname_drugconc               ┆ len    ┆ dose β”‚\n",
       "β”‚ ---                           ┆ ---                             ┆ ---    ┆ ---  β”‚\n",
       "β”‚ str                           ┆ str                             ┆ i64    ┆ str  β”‚\n",
       "β•žβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•ͺ═════════════════════════════════β•ͺ════════β•ͺ══════║\n",
       "β”‚ (R)-Verapamil (hydrochloride) ┆ [('(R)-Verapamil (hydrochlorid… ┆ 166943 ┆  5.0 β”‚\n",
       "β”‚ (S)-Crizotinib                ┆ [('(S)-Crizotinib', 0.5, 'uM')… ┆ 86840  ┆  0.5 β”‚\n",
       "β”‚ 18Ξ²-Glycyrrhetinic acid       ┆ [('18Ξ²-Glycyrrhetinic acid', 5… ┆ 113159 ┆  5.0 β”‚\n",
       "β”‚ 4EGI-1                        ┆ [('4EGI-1', 0.5, 'uM')]         ┆ 128549 ┆  0.5 β”‚\n",
       "β”‚ 5-Azacytidine                 ┆ [('5-Azacytidine', 5.0, 'uM')]  ┆ 71466  ┆  5.0 β”‚\n",
       "β”‚ …                             ┆ …                               ┆ …      ┆ …    β”‚\n",
       "β”‚ olaparib                      ┆ [('olaparib', 0.5, 'uM')]       ┆ 136783 ┆  0.5 β”‚\n",
       "β”‚ palbociclib                   ┆ [('palbociclib', 0.5, 'uM')]    ┆ 91681  ┆  0.5 β”‚\n",
       "β”‚ venetoclax                    ┆ [('venetoclax', 0.5, 'uM')]     ┆ 118167 ┆  0.5 β”‚\n",
       "β”‚ vincristine                   ┆ [('vincristine', 0.5, 'uM')]    ┆ 35862  ┆  0.5 β”‚\n",
       "β”‚ Ξ³-Oryzanol                    ┆ [('Ξ³-Oryzanol', 5.0, 'uM')]     ┆ 103024 ┆  5.0 β”‚\n",
       "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”˜"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import polars as pl\n",
    "# Concentration table, one row per drug.\n",
    "concentrations = pl.read_csv(\"concentrations.csv\")\n",
    "\n",
    "# Lookup from the first column of every row to the second one; further columns are ignored.\n",
    "drug_to_concentration = dict(\n",
    "    (row[0], row[1]) for row in concentrations.iter_rows()\n",
    ")\n",
    "\n",
    "concentrations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: report every held-out drug / cell line that matches no cell in any plate.\n",
    "for d in [val_one_per_moa, val_one_entire_moa, val_one_per_organ, val_one_entire_organ, test_one_per_moa, test_one_entire_moa, test_one_per_organ, test_one_entire_organ]:\n",
    "    for i in d:\n",
    "        # Sum the boolean masks directly instead of building an AnnData view per plate.\n",
    "        count = sum(int(((x.obs[\"drug\"] == i) | (x.obs[\"cell_line\"] == i)).sum()) for x in Xs)\n",
    "        if count == 0:\n",
    "            print(f\"{i} {count} leads to zero cells\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[35]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m Xs = \u001b[43m[\u001b[49m\u001b[43mX\u001b[49m\u001b[43m[\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mpass_filter\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[43m==\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mfull\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[43m&\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdrugname_drugconc\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[43m==\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdrug\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43;01mlambda\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mdrug_to_concentration\u001b[49m\u001b[43m[\u001b[49m\u001b[43mx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mXs\u001b[49m\u001b[43m]\u001b[49m\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[35]\u001b[39m\u001b[32m, line 1\u001b[39m, in \u001b[36m<listcomp>\u001b[39m\u001b[34m(.0)\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m Xs = [X[(X.obs[\u001b[33m\"\u001b[39m\u001b[33mpass_filter\u001b[39m\u001b[33m\"\u001b[39m] == \u001b[33m\"\u001b[39m\u001b[33mfull\u001b[39m\u001b[33m\"\u001b[39m) & (\u001b[43mX\u001b[49m\u001b[43m.\u001b[49m\u001b[43mobs\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mdrugname_drugconc\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m)\u001b[49m == X.obs[\u001b[33m\"\u001b[39m\u001b[33mdrug\u001b[39m\u001b[33m\"\u001b[39m].map(\u001b[38;5;28;01mlambda\u001b[39;00m x: drug_to_concentration[x]).astype(\u001b[38;5;28mstr\u001b[39m))] \u001b[38;5;28;01mfor\u001b[39;00m X \u001b[38;5;129;01min\u001b[39;00m Xs]\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/generic.py:6643\u001b[39m, in \u001b[36mNDFrame.astype\u001b[39m\u001b[34m(self, dtype, copy, errors)\u001b[39m\n\u001b[32m   6637\u001b[39m     results = [\n\u001b[32m   6638\u001b[39m         ser.astype(dtype, copy=copy, errors=errors) \u001b[38;5;28;01mfor\u001b[39;00m _, ser \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.items()\n\u001b[32m   6639\u001b[39m     ]\n\u001b[32m   6641\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m   6642\u001b[39m     \u001b[38;5;66;03m# else, only a single dtype is given\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m6643\u001b[39m     new_data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_mgr\u001b[49m\u001b[43m.\u001b[49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[43m=\u001b[49m\u001b[43merrors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   6644\u001b[39m     res = \u001b[38;5;28mself\u001b[39m._constructor_from_mgr(new_data, axes=new_data.axes)\n\u001b[32m   6645\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m res.__finalize__(\u001b[38;5;28mself\u001b[39m, method=\u001b[33m\"\u001b[39m\u001b[33mastype\u001b[39m\u001b[33m\"\u001b[39m)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/internals/managers.py:430\u001b[39m, in \u001b[36mBaseBlockManager.astype\u001b[39m\u001b[34m(self, dtype, copy, errors)\u001b[39m\n\u001b[32m    427\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m using_copy_on_write():\n\u001b[32m    428\u001b[39m     copy = \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m430\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    431\u001b[39m \u001b[43m    \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mastype\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m    432\u001b[39m \u001b[43m    \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    433\u001b[39m \u001b[43m    \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    434\u001b[39m \u001b[43m    \u001b[49m\u001b[43merrors\u001b[49m\u001b[43m=\u001b[49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    435\u001b[39m \u001b[43m    \u001b[49m\u001b[43musing_cow\u001b[49m\u001b[43m=\u001b[49m\u001b[43musing_copy_on_write\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    436\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/internals/managers.py:363\u001b[39m, in \u001b[36mBaseBlockManager.apply\u001b[39m\u001b[34m(self, f, align_keys, **kwargs)\u001b[39m\n\u001b[32m    361\u001b[39m         applied = b.apply(f, **kwargs)\n\u001b[32m    362\u001b[39m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m363\u001b[39m         applied = \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    364\u001b[39m     result_blocks = extend_blocks(applied, result_blocks)\n\u001b[32m    366\u001b[39m out = \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m).from_blocks(result_blocks, \u001b[38;5;28mself\u001b[39m.axes)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/internals/blocks.py:758\u001b[39m, in \u001b[36mBlock.astype\u001b[39m\u001b[34m(self, dtype, copy, errors, using_cow, squeeze)\u001b[39m\n\u001b[32m    755\u001b[39m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mCan not squeeze with more than one column.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m    756\u001b[39m     values = values[\u001b[32m0\u001b[39m, :]  \u001b[38;5;66;03m# type: ignore[call-overload]\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m758\u001b[39m new_values = \u001b[43mastype_array_safe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[43m=\u001b[49m\u001b[43merrors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    760\u001b[39m new_values = maybe_coerce_values(new_values)\n\u001b[32m    762\u001b[39m refs = \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/dtypes/astype.py:237\u001b[39m, in \u001b[36mastype_array_safe\u001b[39m\u001b[34m(values, dtype, copy, errors)\u001b[39m\n\u001b[32m    234\u001b[39m     dtype = dtype.numpy_dtype\n\u001b[32m    236\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m237\u001b[39m     new_values = \u001b[43mastype_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    238\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mTypeError\u001b[39;00m):\n\u001b[32m    239\u001b[39m     \u001b[38;5;66;03m# e.g. _astype_nansafe can fail on object-dtype of strings\u001b[39;00m\n\u001b[32m    240\u001b[39m     \u001b[38;5;66;03m#  trying to convert to float\u001b[39;00m\n\u001b[32m    241\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m errors == \u001b[33m\"\u001b[39m\u001b[33mignore\u001b[39m\u001b[33m\"\u001b[39m:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/dtypes/astype.py:179\u001b[39m, in \u001b[36mastype_array\u001b[39m\u001b[34m(values, dtype, copy)\u001b[39m\n\u001b[32m    175\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m values\n\u001b[32m    177\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(values, np.ndarray):\n\u001b[32m    178\u001b[39m     \u001b[38;5;66;03m# i.e. ExtensionArray\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m179\u001b[39m     values = \u001b[43mvalues\u001b[49m\u001b[43m.\u001b[49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcopy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    181\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m    182\u001b[39m     values = _astype_nansafe(values, dtype, copy=copy)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/arrays/categorical.py:604\u001b[39m, in \u001b[36mCategorical.astype\u001b[39m\u001b[34m(self, dtype, copy)\u001b[39m\n\u001b[32m    601\u001b[39m         msg = \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mCannot cast \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.categories.dtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m dtype to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdtype\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m    602\u001b[39m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(msg)\n\u001b[32m--> \u001b[39m\u001b[32m604\u001b[39m     result = \u001b[43mtake_nd\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    605\u001b[39m \u001b[43m        \u001b[49m\u001b[43mnew_cats\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mensure_platform_int\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_codes\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfill_value\u001b[49m\n\u001b[32m    606\u001b[39m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    608\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/array_algos/take.py:117\u001b[39m, in \u001b[36mtake_nd\u001b[39m\u001b[34m(arr, indexer, axis, fill_value, allow_fill)\u001b[39m\n\u001b[32m    114\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)\n\u001b[32m    116\u001b[39m arr = np.asarray(arr)\n\u001b[32m--> \u001b[39m\u001b[32m117\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_take_nd_ndarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindexer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mallow_fill\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/array_algos/take.py:162\u001b[39m, in \u001b[36m_take_nd_ndarray\u001b[39m\u001b[34m(arr, indexer, axis, fill_value, allow_fill)\u001b[39m\n\u001b[32m    157\u001b[39m     out = np.empty(out_shape, dtype=dtype)\n\u001b[32m    159\u001b[39m func = _get_take_nd_function(\n\u001b[32m    160\u001b[39m     arr.ndim, arr.dtype, out.dtype, axis=axis, mask_info=mask_info\n\u001b[32m    161\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m162\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindexer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    164\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m flip_order:\n\u001b[32m    165\u001b[39m     out = out.T\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/array_algos/take.py:345\u001b[39m, in \u001b[36m_get_take_nd_function.<locals>.func\u001b[39m\u001b[34m(arr, indexer, out, fill_value)\u001b[39m\n\u001b[32m    343\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mfunc\u001b[39m(arr, indexer, out, fill_value=np.nan) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m    344\u001b[39m     indexer = ensure_platform_int(indexer)\n\u001b[32m--> \u001b[39m\u001b[32m345\u001b[39m     \u001b[43m_take_nd_object\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    346\u001b[39m \u001b[43m        \u001b[49m\u001b[43marr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindexer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m=\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmask_info\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmask_info\u001b[49m\n\u001b[32m    347\u001b[39m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/flow/.venv/lib/python3.11/site-packages/pandas/core/array_algos/take.py:528\u001b[39m, in \u001b[36m_take_nd_object\u001b[39m\u001b[34m(arr, indexer, out, axis, fill_value, mask_info)\u001b[39m\n\u001b[32m    526\u001b[39m     arr = arr.astype(out.dtype)\n\u001b[32m    527\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m arr.shape[axis] > \u001b[32m0\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m528\u001b[39m     \u001b[43marr\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtake\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindexer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maxis\u001b[49m\u001b[43m=\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout\u001b[49m\u001b[43m=\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    529\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m needs_masking:\n\u001b[32m    530\u001b[39m     outindexer = [\u001b[38;5;28mslice\u001b[39m(\u001b[38;5;28;01mNone\u001b[39;00m)] * arr.ndim\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    }
   ],
   "source": [
    "def _full_pass_at_listed_dose(adata):\n",
    "    \"\"\"Boolean mask: pass_filter is \"full\" and drugname_drugconc equals the table entry for the cell's drug.\"\"\"\n",
    "    obs = adata.obs\n",
    "    passed = obs[\"pass_filter\"] == \"full\"\n",
    "    observed = obs[\"drugname_drugconc\"].astype(str)\n",
    "    # Indexing the dict raises KeyError for a drug that is absent from the table.\n",
    "    listed = obs[\"drug\"].map(lambda x: drug_to_concentration[x]).astype(str)\n",
    "    return passed & (observed == listed)\n",
    "\n",
    "\n",
    "# Rebinds Xs to the filtered views.\n",
    "Xs = [X[_full_pass_at_listed_dose(X)] for X in Xs]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Xs[0].write_h5ad(\"data/tahoe_sampled_filtered.h5ad\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable\n",
    "import pandas as pd\n",
    "\n",
    "def random_subset(adatas: list[AnnData], scale: int = 1, filter: Callable[[AnnData], \"pd.Series[bool]\"] | None = None) -> AnnData:\n",
    "    \"\"\"Concatenate a reproducible random 1/`scale` sample of the cells selected by `filter`.\n",
    "\n",
    "    Args:\n",
    "        adatas: AnnData objects to sample from; each one is filtered and sampled independently.\n",
    "        scale: Downsampling factor. `len(filtered) // scale` cells are drawn without\n",
    "            replacement from each object; `scale == 1` keeps every filtered cell.\n",
    "        filter: Callable returning a boolean mask over the observations of an AnnData.\n",
    "            `None` selects every cell.\n",
    "\n",
    "    Returns:\n",
    "        The in-memory concatenation (`sc.concat`) of the selected cells.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `scale` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if scale < 1:\n",
    "        raise ValueError(f\"scale must be a positive integer, got {scale}\")\n",
    "    # Re-seed the global NumPy RNG on every call so that each subset is reproducible.\n",
    "    np.random.seed(42)\n",
    "    if filter is None:\n",
    "        filter = lambda x: pd.Series(True, index=x.obs.index)\n",
    "\n",
    "    # Ys = [X[(X.obs[\"pass_filter\"] == \"full\") & (X.obs[\"drugname_drugconc\"].astype(str) == X.obs[\"drug\"].map(lambda x: drug_to_concentration[x]).astype(str)) & (filter(X))] for X in adatas]\n",
    "    Ys = [X[filter(X)] for X in adatas]\n",
    "    if scale == 1:\n",
    "        # Keep every filtered cell. The row selector has to be the observation names:\n",
    "        # the filtered AnnData views themselves are not usable as an index into X.\n",
    "        Rs = [Y.obs_names for Y in Ys]\n",
    "    else:\n",
    "        Rs = [np.random.choice(Y.obs_names, size=len(Y) // scale, replace=False) for Y in Ys]\n",
    "    xs = [X[r, :] for X, r in zip(adatas, Rs)]\n",
    "    return sc.concat([x.to_memory() for x in xs])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downsampling factors handed to random_subset, one set of subset files per factor.\n",
    "SCALES = [10_000, 1_000, 100, 10, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating train subset with scale 10000\n",
      "Creating val subset with scale 10000\n",
      "Creating test subset with scale 10000\n",
      "Creating train subset with scale 1000\n",
      "Creating val subset with scale 1000\n",
      "Creating test subset with scale 1000\n",
      "Creating train subset with scale 100\n",
      "Creating val subset with scale 100\n",
      "Creating test subset with scale 100\n"
     ]
    }
   ],
   "source": [
    "# (split, file stem, cell filter) for every subset, in the order the files are written.\n",
    "SUBSET_SPECS = [\n",
    "    (\"train\", \"train\", lambda x: (~held_out_val(x)) & (~held_out_test(x))),\n",
    "    # First category of hold outs: Drugs/Cell lines where we have trained on other drugs/cell lines from the same MOA/organ\n",
    "    (\"val\", \"val_shared_MOA\", lambda x: x.obs[\"drug\"].isin(val_one_per_moa)),\n",
    "    (\"val\", \"val_shared_organ\", lambda x: x.obs[\"cell_line\"].isin(val_one_per_organ)),\n",
    "    # Second category of hold outs: Drugs/Cell lines where we have not trained on any other drugs/cell lines from the same MOA/organ\n",
    "    (\"val\", \"val_unseen_MOA\", lambda x: x.obs[\"drug\"].isin(val_one_entire_moa)),\n",
    "    (\"val\", \"val_unseen_organ\", lambda x: x.obs[\"cell_line\"].isin(val_one_entire_organ)),\n",
    "    # plate 14\n",
    "    (\"val\", \"val_plate14\", lambda x: (x.obs[\"plate\"] == \"plate14\") & (x.obs_names.isin(val_plate_14(x)))),\n",
    "    # Now test data\n",
    "    (\"test\", \"test_shared_MOA\", lambda x: x.obs[\"drug\"].isin(test_one_per_moa)),\n",
    "    (\"test\", \"test_shared_organ\", lambda x: x.obs[\"cell_line\"].isin(test_one_per_organ)),\n",
    "    (\"test\", \"test_unseen_MOA\", lambda x: x.obs[\"drug\"].isin(test_one_entire_moa)),\n",
    "    (\"test\", \"test_unseen_organ\", lambda x: x.obs[\"cell_line\"].isin(test_one_entire_organ)),\n",
    "    (\"test\", \"test_plate14\", lambda x: (x.obs[\"plate\"] == \"plate14\") & (x.obs_names.isin(test_plate_14(x)))),\n",
    "]\n",
    "\n",
    "for SCALE in SCALES:\n",
    "    announced_split = None\n",
    "    for split, stem, cell_filter in SUBSET_SPECS:\n",
    "        # Print the progress line once, when the first subset of a split is reached.\n",
    "        if split != announced_split:\n",
    "            print(f\"Creating {split} subset with scale {SCALE}\")\n",
    "            announced_split = split\n",
    "        xs = random_subset(Xs, SCALE, cell_filter)\n",
    "        xs.write_h5ad(f\"../Data/subsets/{stem}_{SCALE}x.h5ad\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}