"""Utilities for torch device selection and DatasetDict label inspection/logging."""

import logging
from collections import Counter
from typing import Optional

import torch
from datasets import DatasetDict

logger = logging.getLogger(__name__)

# Dict-config suitable for logging.config.dictConfig(): one console handler,
# root logger at INFO.
default_logging_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}


def get_torch_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        # For Apple Silicon MPS
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logger.info(f"using {device}")
    return device


def get_uniq_training_labels(ds: DatasetDict, columns_to_exclude: Optional[set[str]] = None):
    """Collect the unique token-level label values per trainable column.

    Scans every split of ``ds`` and, for each column not excluded, gathers the
    distinct token-level values; per-value frequencies are logged.

    Args:
        ds: DatasetDict whose examples hold token-aligned lists per column.
        columns_to_exclude: column names to skip; defaults to ``{"text", "tokens"}``.
            (Annotation fixed to ``Optional[set[str]]`` — the previous
            ``set[str] = None`` annotation was incorrect.)

    Returns:
        dict mapping column name -> set of unique label values seen across all splits.
    """
    excluded = {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [k for k in ds["train"].features.keys() if k not in excluded]

    # Per-column frequency counters and unique-value sets, keyed by column name.
    label_counters = {col: Counter() for col in columns_to_train_on}
    unique_label_values = {col: set() for col in columns_to_train_on}

    # Loop through each split and each example, and collect values.
    for split_name, dataset_split in ds.items():
        for example in dataset_split:
            # Each of these columns is a list (one entry per token), so we
            # update our set/counter with each token-level value.
            for col in columns_to_train_on:
                unique_label_values[col].update(example[col])
                label_counters[col].update(example[col])

    logger.info("Columns:")
    for col in columns_to_train_on:
        logger.info(f"  {col}:")
        # Convert to a sorted list just to have a nice, stable ordering.
        vals = sorted(unique_label_values[col])
        logger.info(f"    {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")
    return unique_label_values


def show_examples(ds: DatasetDict, show_expr: Optional[str]):
    """Log a handful of examples from ``ds``.

    Args:
        ds: the dataset to sample from.
        show_expr: optional ``"split/column/label/count"`` expression. When
            given, logs up to *count* shuffled examples from *split* whose
            *column* contains *label*. When falsy, logs the first 2 examples
            of the ``train`` split.

    Raises:
        ValueError: if ``show_expr`` does not have exactly four "/"-separated
            parts, or *count* is not an integer.
    """
    logger.info(f"Dataset:\n{ds}")
    if not show_expr:
        count_to_show = 2
        examples_to_show = ds["train"][:count_to_show]
    else:
        args_show_tokens = show_expr.split("/")
        if len(args_show_tokens) != 4:
            # Fail early with a clear message instead of an opaque unpacking error.
            raise ValueError(
                f"show_expr must be 'split/column/label/count', got {show_expr!r}")
        split_to_show, col_to_show, label_to_show, count_to_show = args_show_tokens
        count_to_show = int(count_to_show)
        examples_to_show = ds[split_to_show].filter(
            lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]
    # The filtered split may contain fewer than count_to_show examples; clamp
    # to the actual number available so indexing below cannot raise IndexError.
    available = min((len(v) for v in examples_to_show.values()), default=0)
    for i in range(min(count_to_show, available)):
        logger.info(f"Example {i}:")
        for feature in examples_to_show.keys():
            logger.info(f"  {feature}: {examples_to_show[feature][i]}")