|
|
import logging
from collections import Counter
from typing import Optional

import torch
from datasets import DatasetDict
|
|
|
|
|
logger = logging.getLogger(__name__)


# Default logging configuration in logging.config.dictConfig format:
# one console (stderr) handler with a timestamped formatter, attached to the
# root logger at INFO level.
default_logging_config = {
    "version": 1,
    # Keep loggers created before dictConfig() runs (e.g. module-level
    # getLogger calls like the one above) active instead of silencing them.
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            # StreamHandler with no stream argument writes to stderr.
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        # "" is the root logger; all other loggers propagate here by default.
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}
|
|
|
|
|
|
|
|
def get_torch_device() -> torch.device:
    """Select the best available torch device.

    Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    # Lazy %-args: the message is only formatted if INFO is enabled.
    logger.info("using %s", device)
    return device
|
|
|
|
|
|
|
|
def get_uniq_training_labels(ds: "DatasetDict", columns_to_exclude: Optional[set[str]] = None):
    """Collect the unique label values of every label column across all splits.

    Scans every split of ``ds`` and, for each column of the train split that
    is not excluded, gathers the distinct label values seen in any example.
    Per-label occurrence counts are logged at INFO level.

    Args:
        ds: dataset with one or more splits; each example maps a column name
            to an iterable of label values (e.g. per-token tags).
        columns_to_exclude: column names to skip; defaults to
            ``{"text", "tokens"}`` when ``None`` (avoids a mutable default).

    Returns:
        dict[str, set]: mapping of column name to its set of unique labels.
    """
    excluded = {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [k for k in ds["train"].features.keys() if k not in excluded]

    # Counter provides both the unique values (its keys) and their
    # frequencies, replacing the hand-rolled dict + parallel set bookkeeping.
    label_counters = {col: Counter() for col in columns_to_train_on}

    for _split_name, dataset_split in ds.items():
        for example in dataset_split:
            for col in columns_to_train_on:
                label_counters[col].update(example[col])

    logger.info("Columns:")
    for col in columns_to_train_on:
        logger.info("  %s:", col)
        vals = sorted(label_counters[col])
        logger.info("    %s labels: %s", len(vals), [f"{v}:{label_counters[col][v]}" for v in vals])

    return {col: set(label_counters[col]) for col in columns_to_train_on}
|
|
|
|
|
|
|
|
def show_examples(ds: "DatasetDict", show_expr: Optional[str]):
    """Log a handful of examples from the dataset at INFO level.

    Args:
        ds: dataset with named splits.
        show_expr: optional filter of the form
            ``"<split>/<column>/<label>/<count>"``; when given, logs up to
            ``<count>`` shuffled (seed=42) examples from ``<split>`` whose
            ``<column>`` contains ``<label>``. When falsy, logs the first 2
            examples of the train split.
    """
    logger.info("Dataset:\n%s", ds)
    if not show_expr:
        count_to_show = 2
        # Slicing a Dataset yields a dict of column -> list of values.
        examples_to_show = ds["train"][:count_to_show]
    else:
        split_to_show, col_to_show, label_to_show, count_str = show_expr.split("/")
        count_to_show = int(count_str)
        examples_to_show = ds[split_to_show].filter(
            lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]
    # Fewer matching examples than requested must not raise IndexError:
    # clamp the loop to the number of rows actually returned.
    available = min((len(col_vals) for col_vals in examples_to_show.values()), default=0)
    for i in range(min(count_to_show, available)):
        logger.info("Example %s:", i)
        for feature in examples_to_show:
            logger.info("  %s: %s", feature, examples_to_show[feature][i])
|
|
|