|
|
from datasets import DatasetDict |
|
|
from typing import Optional |
|
|
import logging |
|
|
import torch |
|
|
|
|
|
# Module-level logger; emits through whatever handlers the application
# configures (e.g. by passing default_logging_config to dictConfig).
logger = logging.getLogger(__name__)
|
|
|
|
|
# Default schema for logging.config.dictConfig: one console (stderr)
# StreamHandler with a timestamped formatter, attached to the root logger
# at INFO level.
default_logging_config = {
    "version": 1,
    # Keep loggers created before dictConfig() runs (such as this module's
    # own getLogger(__name__) logger) active instead of silencing them.
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        # "" addresses the root logger, so every logger propagates here.
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}
|
|
|
|
|
|
|
|
def get_torch_device():
    """Pick the best available torch device.

    Preference order: CUDA GPU, then Apple Metal (MPS), then CPU fallback.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    # getLogger(__name__) returns the same module logger object.
    logging.getLogger(__name__).info("using %s", device)
    return device
|
|
|
|
|
|
|
|
def show_class_distribution(dataset, split_name, label_names, label_column="orig_labels"):
    """Log how many samples carry each label in the chosen split.

    Helps identify class imbalance before training.

    Args:
        dataset: mapping from split name to a huggingface-style Dataset
            (supports ``len(split)`` and column access ``split[column]``).
        split_name: which split to inspect, e.g. "train".
        label_names: list of label names; index ``i`` names label id ``i``.
        label_column: column holding each sample's list of integer label
            ids (defaults to "orig_labels" for backward compatibility).
    """
    from collections import Counter

    # Hoist the split lookup; the original re-indexed dataset[split_name]
    # for both the length and the label column.
    split = dataset[split_name]
    num_samples = len(split)

    label_counter = Counter()
    # Each row holds a list of label ids; count every occurrence.
    for sample_labels in split[label_column]:
        label_counter.update(sample_labels)

    log = logging.getLogger(__name__)  # same module logger object
    log.info(f"\n--- Class distribution for split '{split_name}' ({num_samples} samples) ---")
    for idx, label_name in enumerate(label_names):
        log.info(f"{idx:02d} ({label_name}): count = {label_counter[idx]}")
    log.info("---------------------------------------------------------------\n")
|
|
|