import logging

import torch
from datasets import DatasetDict

logger = logging.getLogger(__name__)

default_logging_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}


def get_torch_device() -> torch.device:
    """Pick the best available torch device: CUDA, then Apple Silicon MPS, then CPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():  # Apple Silicon (Metal Performance Shaders)
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")
    return device


def show_class_distribution(dataset: DatasetDict, split_name: str, label_names: list) -> None:
    """
    Log how many samples contain each label in the chosen split;
    this helps identify class imbalance.

    - dataset: a Hugging Face DatasetDict; dataset[split_name] is a Dataset
    - split_name: name of the split to inspect (e.g. "train")
    - label_names: list of label names, indexed by label id
    """
    from collections import Counter

    label_counter = Counter()
    num_samples = len(dataset[split_name])

    # Each sample's `orig_labels` is a list of label indices (multi-label)
    for ex in dataset[split_name]["orig_labels"]:
        label_counter.update(ex)

    logger.info(f"\n--- Class distribution for split '{split_name}' ({num_samples} samples) ---")
    for idx, label_name in enumerate(label_names):
        logger.info(f"{idx:02d} ({label_name}): count = {label_counter[idx]}")
    logger.info("---------------------------------------------------------------\n")
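# A minimal standalone sketch of the counting logic in show_class_distribution,
# using a toy `orig_labels` column in place of a real Hugging Face split (the
# label names below are made up for illustration):
#
#     from collections import Counter
#
#     # Each sample carries a list of label indices, as the function expects.
#     orig_labels = [[0, 2], [1], [0], [1, 2], [0]]
#     label_names = ["anger", "joy", "sadness"]
#
#     label_counter = Counter()
#     for ex in orig_labels:
#         label_counter.update(ex)
#
#     for idx, name in enumerate(label_names):
#         print(f"{idx:02d} ({name}): count = {label_counter[idx]}")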