Commit
·
c48e37c
1
Parent(s):
748f48a
Update minor formatting
Browse files
examples/hyperparam_optimiz_for_disease_classifier.py
CHANGED
|
@@ -50,7 +50,7 @@ def initialize_ray_with_check(ip_address):
|
|
| 50 |
# Usage:
|
| 51 |
ip = 'your_ip:xxxx' # Replace with your actual IP address and port
|
| 52 |
if initialize_ray_with_check(ip):
|
| 53 |
-
print("Ray initialized successfully
|
| 54 |
else:
|
| 55 |
print("Error during Ray initialization.")
|
| 56 |
|
|
@@ -62,7 +62,7 @@ import seaborn as sns; sns.set()
|
|
| 62 |
from collections import Counter
|
| 63 |
from datasets import load_from_disk
|
| 64 |
from scipy.stats import ranksums
|
| 65 |
-
from sklearn.metrics import accuracy_score
|
| 66 |
from transformers import BertForSequenceClassification
|
| 67 |
from transformers import Trainer
|
| 68 |
from transformers.training_args import TrainingArguments
|
|
@@ -155,6 +155,7 @@ def model_init():
|
|
| 155 |
return model
|
| 156 |
|
| 157 |
# define metrics
|
|
|
|
| 158 |
def compute_metrics(pred):
|
| 159 |
labels = pred.label_ids
|
| 160 |
preds = pred.predictions.argmax(-1)
|
|
|
|
| 50 |
# Usage:
|
| 51 |
ip = 'your_ip:xxxx' # Replace with your actual IP address and port
|
| 52 |
if initialize_ray_with_check(ip):
|
| 53 |
+
print("Ray initialized successfully.")
|
| 54 |
else:
|
| 55 |
print("Error during Ray initialization.")
|
| 56 |
|
|
|
|
| 62 |
from collections import Counter
|
| 63 |
from datasets import load_from_disk
|
| 64 |
from scipy.stats import ranksums
|
| 65 |
+
from sklearn.metrics import accuracy_score
|
| 66 |
from transformers import BertForSequenceClassification
|
| 67 |
from transformers import Trainer
|
| 68 |
from transformers.training_args import TrainingArguments
|
|
|
|
| 155 |
return model
|
| 156 |
|
| 157 |
# define metrics
|
| 158 |
+
# note: macro f1 score recommended for imbalanced multiclass classifiers
|
| 159 |
def compute_metrics(pred):
|
| 160 |
labels = pred.label_ids
|
| 161 |
preds = pred.predictions.argmax(-1)
|