Update geneformer/evaluation_utils.py to be compatible with different versions of the datasets package.
#552
by
IchigoJiken
- opened
geneformer/evaluation_utils.py
CHANGED
|
@@ -8,6 +8,7 @@ import numpy as np
|
|
| 8 |
import pandas as pd
|
| 9 |
import seaborn as sns
|
| 10 |
import torch
|
|
|
|
| 11 |
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
| 12 |
from sklearn import preprocessing
|
| 13 |
from sklearn.metrics import (
|
|
@@ -103,6 +104,10 @@ def classifier_predict(model, classifier_type, evalset, forward_batch_size, gene
|
|
| 103 |
)
|
| 104 |
padded_batch.set_format(type="torch")
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
input_data_batch = padded_batch["input_ids"]
|
| 107 |
attn_msk_batch = padded_batch["attention_mask"]
|
| 108 |
label_batch = padded_batch[label_name]
|
|
|
|
| 8 |
import pandas as pd
|
| 9 |
import seaborn as sns
|
| 10 |
import torch
|
| 11 |
+
import datasets
|
| 12 |
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
|
| 13 |
from sklearn import preprocessing
|
| 14 |
from sklearn.metrics import (
|
|
|
|
| 104 |
)
|
| 105 |
padded_batch.set_format(type="torch")
|
| 106 |
|
| 107 |
+
# For datasets>=4.0.0, convert to dict to avoid format issues
|
| 108 |
+
if int(datasets.__version__.split(".")[0]) >= 4:
|
| 109 |
+
padded_batch = padded_batch[:]
|
| 110 |
+
|
| 111 |
input_data_batch = padded_batch["input_ids"]
|
| 112 |
attn_msk_batch = padded_batch["attention_mask"]
|
| 113 |
label_batch = padded_batch[label_name]
|