| | from datasets import DatasetDict |
| | from typing import Optional |
| | import logging |
| | import torch |
| |
|
# Module-level logger for this file's helpers.
logger = logging.getLogger(__name__)


# Baseline configuration in logging.config.dictConfig format: one console
# StreamHandler with a timestamped "time - name - level - message" format,
# attached to the root logger ("") at INFO level. Pre-existing loggers are
# left enabled (disable_existing_loggers=False).
default_logging_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}
| |
|
| |
|
def get_torch_device():
    """Return the best available torch device: CUDA, else MPS, else CPU."""
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    device = torch.device(backend)
    logger.info(f"using {device}")
    return device
| |
|
| |
|
def get_uniq_training_labels(ds: "DatasetDict", columns_to_exclude: Optional[set[str]] = None):
    """Collect the unique label values (with occurrence counts) per label column.

    Scans every split of the dataset — not just "train" — so labels that only
    occur in validation/test splits are still reported.

    Args:
        ds: dataset dict whose "train" split's features define the columns;
            each example's label columns are assumed iterable (e.g. per-token
            tag sequences).
        columns_to_exclude: feature names to skip; defaults to {"text", "tokens"}.

    Returns:
        Mapping of column name -> set of unique label values seen in any split.
    """
    excluded = {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [k for k in ds["train"].features.keys() if k not in excluded]

    label_counters = {col: {} for col in columns_to_train_on}
    unique_label_values = {col: set() for col in columns_to_train_on}

    # Walk all splits; split names themselves are not needed, only the rows.
    for dataset_split in ds.values():
        for example in dataset_split:
            for col in columns_to_train_on:
                unique_label_values[col].update(example[col])
                for label_val in example[col]:
                    label_counters[col][label_val] = label_counters[col].get(label_val, 0) + 1

    logger.info("Columns:")
    for col in columns_to_train_on:
        logger.info(f"  {col}:")
        vals = sorted(unique_label_values[col])
        logger.info(f"  {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")

    return unique_label_values
| |
|
| |
|
def show_examples(ds: "DatasetDict", show_expr: Optional[str]):
    """Log a handful of examples from the dataset.

    Args:
        ds: dataset dict with at least a "train" split.
        show_expr: optional "split/column/label/count" expression; selects up
            to ``count`` seed-42-shuffled examples from ``split`` whose
            ``column`` contains ``label``. When falsy, the first two training
            examples are shown instead.
    """
    logger.info(f"Dataset:\n{ds}")
    if not show_expr:
        count_to_show = 2
        examples_to_show = ds["train"][:count_to_show]
    else:
        args_show_tokens = show_expr.split("/")
        split_to_show, col_to_show, label_to_show, count_to_show = args_show_tokens
        count_to_show = int(count_to_show)
        examples_to_show = ds[split_to_show].filter(
            lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]
    # Clamp to what is actually available: the filter (or a short train split)
    # may yield fewer than count_to_show examples, which previously raised
    # IndexError in the loop below.
    available = min((len(v) for v in examples_to_show.values()), default=0)
    for i in range(min(count_to_show, available)):
        logger.info(f"Example {i}:")
        for feature in examples_to_show.keys():
            logger.info(f"  {feature}: {examples_to_show[feature][i]}")
| |
|