File size: 3,061 Bytes
ab4b5ab
 
 
c5081c8
ab4b5ab
 
360e354
 
0cdb887
 
 
 
 
360e354
0cdb887
 
 
 
 
360e354
0cdb887
 
 
 
 
360e354
0cdb887
 
ab4b5ab
 
c5081c8
 
 
 
 
 
 
 
 
 
 
ab4b5ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from datasets import DatasetDict
from typing import Optional
import logging
import torch

logger = logging.getLogger(__name__)

# Baseline config for logging.config.dictConfig: a single console (stderr)
# handler attached to the root logger ("") at INFO, with a timestamped
# "time - name - level - message" format. disable_existing_loggers=False so
# module-level loggers created before configuration keep working.
default_logging_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}


def get_torch_device():
    """Return the best available torch device: CUDA, then Apple-Silicon MPS, then CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    logger.info(f"using {device}")
    return device


def get_uniq_training_labels(ds: DatasetDict, columns_to_exclude: Optional[set[str]] = None):
    """Collect the unique token-level label values for each trainable column.

    Every column of the "train" split except those in *columns_to_exclude*
    (default: {"text", "tokens"}) is treated as a label column whose values are
    lists of per-token labels. All splits of *ds* are scanned.

    Args:
        ds: Dataset with a "train" split; each example stores a list of
            token-level values under every label column.
        columns_to_exclude: Column names to skip; None selects the default set.

    Returns:
        dict mapping each label column name to the set of unique values seen
        across all splits. Per-value occurrence counts are logged but not
        returned (interface kept as-is for existing callers).
    """
    excluded = {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [k for k in ds["train"].features.keys() if k not in excluded]

    # Per-column occurrence counts and per-column unique-value sets.
    label_counters = {col: {} for col in columns_to_train_on}
    unique_label_values = {col: set() for col in columns_to_train_on}

    # Scan every split; each column value is a list with one entry per token.
    for dataset_split in ds.values():
        for example in dataset_split:
            for col in columns_to_train_on:
                unique_label_values[col].update(example[col])
                for label_val in example[col]:
                    label_counters[col][label_val] = label_counters[col].get(label_val, 0) + 1

    logger.info("Columns:")
    for col in columns_to_train_on:
        logger.info(f"  {col}:")
        # Sorted for a stable, human-friendly log ordering.
        vals = sorted(unique_label_values[col])
        logger.info(f"    {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")

    return unique_label_values


def show_examples(ds: DatasetDict, show_expr: Optional[str]):
    """Log a few examples from the dataset, feature by feature.

    Args:
        ds: Dataset to sample from.
        show_expr: Either None/empty — log the first 2 "train" examples — or a
            "split/column/label/count" expression (e.g. "train/ner/B-ORG/3"):
            take *count* shuffled (seed=42) examples of *split* whose *column*
            contains *label*.
    """
    logger.info(f"Dataset:\n{ds}")
    if not show_expr:
        count_to_show = 2
        examples_to_show = ds["train"][:count_to_show]
    else:
        split_to_show, col_to_show, label_to_show, count_to_show = show_expr.split("/")
        count_to_show = int(count_to_show)
        examples_to_show = ds[split_to_show].filter(
            lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]
    # Slicing a split yields a dict of per-feature lists; it may hold fewer
    # rows than requested, so clamp the loop bound to avoid an IndexError.
    if examples_to_show:
        count_to_show = min(count_to_show, min(len(v) for v in examples_to_show.values()))
    for i in range(count_to_show):
        logger.info(f"Example {i}:")
        for feature in examples_to_show.keys():
            logger.info(f"  {feature}: {examples_to_show[feature][i]}")