import logging
from collections import Counter
from typing import Optional

import torch
from datasets import DatasetDict

logger = logging.getLogger(__name__)

# dictConfig-style default logging configuration: INFO-level root logger with a
# single console handler.
# NOTE(review): this dict is defined but never passed to
# logging.config.dictConfig() in this module — presumably a caller applies it;
# confirm.
default_logging_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}


def get_torch_device() -> torch.device:
    """Return the best available torch device.

    Preference order: CUDA GPU, then Apple Silicon MPS, then CPU.
    Logs the chosen device at INFO level.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        # Apple Silicon GPU backend.
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    # Lazy %-args: the message is only interpolated if INFO is enabled.
    logger.info("using %s", device)
    return device


def show_class_distribution(
    dataset: DatasetDict,
    split_name: str,
    label_names: list,
) -> None:
    """Log how many samples contain each label in the chosen split.

    This helps identify class imbalance.

    Args:
        dataset: mapping of split name -> huggingface Dataset; each sample's
            ``orig_labels`` column is a list of integer label indices.
        split_name: the split to inspect (e.g. "train").
        label_names: list of label names, indexed by label id.
    """
    label_counter = Counter()
    num_samples = len(dataset[split_name])
    # Each sample's `orig_labels` is a list of label indices; a sample may
    # carry several labels, so counts can sum to more than num_samples.
    for sample_labels in dataset[split_name]["orig_labels"]:
        label_counter.update(sample_labels)
    logger.info(
        "\n--- Class distribution for split '%s' (%d samples) ---",
        split_name,
        num_samples,
    )
    for idx, label_name in enumerate(label_names):
        # Counter returns 0 for labels never seen, so every label is listed.
        logger.info("%02d (%s): count = %d", idx, label_name, label_counter[idx])
    logger.info("---------------------------------------------------------------\n")