|
|
from pathlib import Path |
|
|
import torch |
|
|
import os |
|
|
import random |
|
|
import argparse |
|
|
import json |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from sklearn.metrics import precision_recall_fscore_support |
|
|
from ast import literal_eval |
|
|
|
|
|
|
|
|
def pred_by_threshold(
    t: float,
    y_true: np.ndarray,
    similarities: np.ndarray,
    classes: dict,
):
    """Binarize similarity scores at threshold ``t`` and score the result.

    Args:
        t (float): decision threshold; a similarity >= t predicts the label.
        y_true (np.ndarray): gold labels. Presumably a multi-label binary
            indicator matrix with one column per class — confirm with caller.
        similarities (np.ndarray): per-label similarity scores, same shape
            as ``y_true``.
        classes (dict): mapping of label name -> column index, used to
            report per-label F1.

    Returns:
        dict: macro-averaged "f1", "P", "R" plus one "<label>_f1" entry
        per class in ``classes``.
    """
    preds = (similarities >= t) * 1
    # zero_division=0 keeps the 0.0 score sklearn already assigns to labels
    # with no predicted/true samples, but without UndefinedMetricWarning.
    sk_results = precision_recall_fscore_support(
        y_true,
        preds,
        zero_division=0,
    )
    # sk_results is (precision, recall, f1, support), each an array with
    # one entry per label; averaging yields macro scores.
    outputs = {
        "f1": np.average(sk_results[2]),
        "P": np.average(sk_results[0]),
        "R": np.average(sk_results[1]),
    }
    for label_name, idx in classes.items():
        outputs[f"{label_name}_f1"] = sk_results[2][idx]
    return outputs
|
|
|
|
|
|
|
|
def get_avg_length(dataset: torch.utils.data.Dataset):
    """Return the mean length of the "input_ids" field across ``dataset``.

    Args:
        dataset: an indexable dataset whose items are dicts containing
            an "input_ids" sequence.

    Returns:
        float: average number of input ids per example.
    """
    n_examples = len(dataset)
    total = sum(len(dataset[idx]["input_ids"]) for idx in range(n_examples))
    return total / n_examples
|
|
|
|
|
|
|
|
def load_csv_multi_label(filename: str, col_name: str = "labels") -> pd.DataFrame:
    """Read a csv, parsing ``col_name`` cells back into Python objects.

    Pandas would otherwise load a serialized list of ints as a plain
    string; routing the column through ``ast.literal_eval`` restores it.

    Args:
        filename (str): path of a csv file
        col_name (str, optional): column name of lists of int. Defaults to 'labels'.

    Returns:
        pd.DataFrame: a Pandas dataframe
    """
    converters = {col_name: literal_eval}
    return pd.read_csv(filename, converters=converters)
|
|
|
|
|
|
|
|
def save_logged_results(filename: str, results: dict):
    """Append ``results`` rows to a csv log, creating the file if absent.

    Args:
        filename (str): path of the csv log file.
        results (dict): column-name -> list-of-values mapping, one row
            per list element.
    """
    new_rows = pd.DataFrame(results)
    try:
        previous = pd.read_csv(filename)
    except FileNotFoundError:
        merged = new_rows
    else:
        merged = pd.concat([previous, new_rows], ignore_index=True)

    merged.to_csv(filename, index=None)
|
|
|
|
|
|
|
|
def set_seed(seed):
    """Seed every RNG in use (hash, random, numpy, torch) for reproducibility.

    Args:
        seed: an integer number to initialize a pseudorandom number generator
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # Seed every visible GPU, not only the current device, so
        # multi-GPU runs are reproducible too.
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
def save_baseline_table(
    y_preds: list,
    baseline_name: str,
    baseline_result_file: str = "results/baselines.pkl",
    all_doc_idx: list = None,
) -> None:
    """Store one baseline's predictions as a column of a pickled table.

    Args:
        y_preds (list): predictions, aligned with the table's "doc_idx" rows.
        baseline_name (str): column name under which to store ``y_preds``.
        baseline_result_file (str, optional): path of the pickled table.
        all_doc_idx (list, optional): document indices; required only when
            the table does not exist yet and must be created.

    Raises:
        ValueError: if the table must be created but ``all_doc_idx`` is None.
    """
    result_path = Path(baseline_result_file)
    if result_path.exists():
        df = pd.read_pickle(result_path)
    else:
        # Explicit raise instead of assert: asserts are stripped under -O.
        if all_doc_idx is None:
            raise ValueError(
                "all_doc_idx is required when creating a new baseline table"
            )
        # Make sure the target directory (e.g. results/) exists.
        result_path.parent.mkdir(parents=True, exist_ok=True)
        df = pd.DataFrame({"doc_idx": all_doc_idx})

    df[baseline_name] = y_preds
    df.to_pickle(result_path)
|
|
|
|
|
|
|
|
def load_params(path_of_params):
    """Load a JSON file of hyper-parameters into an argparse.Namespace.

    Args:
        path_of_params: path of a JSON file whose top-level keys become
            namespace attributes.

    Returns:
        argparse.Namespace: attribute-style access to the parameters.
    """
    params = json.loads(Path(path_of_params).read_text())
    return argparse.Namespace(**params)
|
|
|
|
|
|
|
|
def get_label_words(classes: list, use_multi_label_words=False) -> list:
    """Map class names to their verbalizer word(s).

    Args:
        classes (list): class names to look up.
        use_multi_label_words: when True, each class maps to a list of
            candidate words instead of a single word.

    Returns:
        list: one entry per class — a str, or a list of str when
        ``use_multi_label_words`` is True.
    """
    single_word = {
        "cyst": "cyst",
        "HCC": "hcc",
        "cirrhosis": "cirrhosis",
        "post-treatment": "posttreatment",
        "steatosis": "steatosis",
        "metastasis": "metastasis",
        "hemangioma": "hemangioma",
    }
    multi_word = {
        "cyst": ["cyst"],
        "HCC": ["hcc", "hepatoma"],
        "cirrhosis": ["cirrhosis"],
        "post-treatment": ["posttreatment"],
        "steatosis": ["steatosis", "steatohepatitis"],
        "metastasis": ["metastasis"],
        "hemangioma": ["hemangioma"],
    }
    lookup = multi_word if use_multi_label_words else single_word
    return [lookup[c] for c in classes]
|
|
|
|
|
|
|
|
def seed_mapper(data_type: str) -> list:
    """Return the fixed list of random seeds associated with ``data_type``.

    Args:
        data_type (str): one of the supported split names ("train_8",
            "train_32").

    Returns:
        list: seeds to use for that split.

    Raises:
        NotImplementedError: for any unsupported ``data_type``.
    """
    seeds_per_type = {
        "train_8": [2, 4, 7, 11, 21, 23, 24, 36, 44, 128],
        "train_32": [0, 1, 3, 7, 10],
    }
    try:
        return seeds_per_type[data_type]
    except KeyError:
        raise NotImplementedError
|
|
|