Christina Theodoris
committed on
Commit
·
67f674c
1
Parent(s):
d20ad0a
Add uniform max len for padding for predictions
Browse files
examples/gene_classification.ipynb
CHANGED
|
@@ -139,14 +139,15 @@
|
|
| 139 |
"metadata": {},
|
| 140 |
"outputs": [],
|
| 141 |
"source": [
|
| 142 |
-
"def preprocess_classifier_batch(cell_batch):\n",
|
| 143 |
-
"
|
|
|
|
| 144 |
" def pad_label_example(example):\n",
|
| 145 |
" example[\"labels\"] = np.pad(example[\"labels\"], \n",
|
| 146 |
-
" (0,
|
| 147 |
" mode='constant', constant_values=-100)\n",
|
| 148 |
" example[\"input_ids\"] = np.pad(example[\"input_ids\"], \n",
|
| 149 |
-
" (0,
|
| 150 |
" mode='constant', constant_values=token_dictionary.get(\"<pad>\"))\n",
|
| 151 |
" example[\"attention_mask\"] = (example[\"input_ids\"] != token_dictionary.get(\"<pad>\")).astype(int)\n",
|
| 152 |
" return example\n",
|
|
@@ -158,10 +159,19 @@
|
|
| 158 |
" predict_logits = []\n",
|
| 159 |
" predict_labels = []\n",
|
| 160 |
" model.eval()\n",
|
| 161 |
-
"
|
| 162 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
" batch_evalset = evalset.select([i for i in range(i, max_range)])\n",
|
| 164 |
-
" padded_batch = preprocess_classifier_batch(batch_evalset)\n",
|
| 165 |
" padded_batch.set_format(type=\"torch\")\n",
|
| 166 |
" \n",
|
| 167 |
" input_data_batch = padded_batch[\"input_ids\"]\n",
|
|
@@ -224,7 +234,16 @@
|
|
| 224 |
" all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]\n",
|
| 225 |
" roc_auc = np.sum(all_weighted_roc_auc)\n",
|
| 226 |
" roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))\n",
|
| 227 |
-
" return mean_tpr, roc_auc, roc_auc_sd"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
]
|
| 229 |
},
|
| 230 |
{
|
|
@@ -327,7 +346,7 @@
|
|
| 327 |
" \n",
|
| 328 |
" # load model\n",
|
| 329 |
" model = BertForTokenClassification.from_pretrained(\n",
|
| 330 |
-
" \"/
|
| 331 |
" num_labels=2,\n",
|
| 332 |
" output_attentions = False,\n",
|
| 333 |
" output_hidden_states = False\n",
|
|
|
|
| 139 |
"metadata": {},
|
| 140 |
"outputs": [],
|
| 141 |
"source": [
|
| 142 |
+
"def preprocess_classifier_batch(cell_batch, max_len):\n",
|
| 143 |
+
" if max_len == None:\n",
|
| 144 |
+
" max_len = max([len(i) for i in cell_batch[\"input_ids\"]])\n",
|
| 145 |
" def pad_label_example(example):\n",
|
| 146 |
" example[\"labels\"] = np.pad(example[\"labels\"], \n",
|
| 147 |
+
" (0, max_len-len(example[\"input_ids\"])), \n",
|
| 148 |
" mode='constant', constant_values=-100)\n",
|
| 149 |
" example[\"input_ids\"] = np.pad(example[\"input_ids\"], \n",
|
| 150 |
+
" (0, max_len-len(example[\"input_ids\"])), \n",
|
| 151 |
" mode='constant', constant_values=token_dictionary.get(\"<pad>\"))\n",
|
| 152 |
" example[\"attention_mask\"] = (example[\"input_ids\"] != token_dictionary.get(\"<pad>\")).astype(int)\n",
|
| 153 |
" return example\n",
|
|
|
|
| 159 |
" predict_logits = []\n",
|
| 160 |
" predict_labels = []\n",
|
| 161 |
" model.eval()\n",
|
| 162 |
+
" \n",
|
| 163 |
+
" # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims\n",
|
| 164 |
+
" evalset_len = len(evalset)\n",
|
| 165 |
+
" max_divisible = find_largest_div(evalset_len, forward_batch_size)\n",
|
| 166 |
+
" if len(evalset) - max_divisible == 1:\n",
|
| 167 |
+
" evalset_len = max_divisible\n",
|
| 168 |
+
" \n",
|
| 169 |
+
" max_evalset_len = max(evalset.select([i for i in range(evalset_len)])[\"length\"])\n",
|
| 170 |
+
" \n",
|
| 171 |
+
" for i in range(0, evalset_len, forward_batch_size):\n",
|
| 172 |
+
" max_range = min(i+forward_batch_size, evalset_len)\n",
|
| 173 |
" batch_evalset = evalset.select([i for i in range(i, max_range)])\n",
|
| 174 |
+
" padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)\n",
|
| 175 |
" padded_batch.set_format(type=\"torch\")\n",
|
| 176 |
" \n",
|
| 177 |
" input_data_batch = padded_batch[\"input_ids\"]\n",
|
|
|
|
| 234 |
" all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]\n",
|
| 235 |
" roc_auc = np.sum(all_weighted_roc_auc)\n",
|
| 236 |
" roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))\n",
|
| 237 |
+
" return mean_tpr, roc_auc, roc_auc_sd\n",
|
| 238 |
+
"\n",
|
| 239 |
+
"# Function to find the largest number smaller\n",
|
| 240 |
+
"# than or equal to N that is divisible by k\n",
|
| 241 |
+
"def find_largest_div(N, K):\n",
|
| 242 |
+
" rem = N % K\n",
|
| 243 |
+
" if(rem == 0):\n",
|
| 244 |
+
" return N\n",
|
| 245 |
+
" else:\n",
|
| 246 |
+
" return N - rem"
|
| 247 |
]
|
| 248 |
},
|
| 249 |
{
|
|
|
|
| 346 |
" \n",
|
| 347 |
" # load model\n",
|
| 348 |
" model = BertForTokenClassification.from_pretrained(\n",
|
| 349 |
+
" \"/gladstone/theodoris/lab/ctheodoris/archive/geneformer_files/geneformer/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/\",\n",
|
| 350 |
" num_labels=2,\n",
|
| 351 |
" output_attentions = False,\n",
|
| 352 |
" output_hidden_states = False\n",
|