# multi-classifier/utils/__init__.py
# veryfansome — feat: updates for models/ud_ewt_gum_pud_20250611 (commit cf60c27)
from datasets import DatasetDict
from typing import Optional
import logging
import torch
logger = logging.getLogger(__name__)

# Baseline logging.config.dictConfig()-compatible setup: a single console
# (stderr) handler attached to the root logger at INFO level, using one
# timestamped formatter. Existing loggers are kept alive.
default_logging_config = dict(
    version=1,
    disable_existing_loggers=False,
    formatters={
        "default": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"},
    },
    handlers={
        "console": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    loggers={
        # "" is the root logger.
        "": {"level": "INFO", "handlers": ["console"]},
    },
)
def get_torch_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"  # Apple Silicon GPU backend
    else:
        backend = "cpu"
    device = torch.device(backend)
    logger.info(f"using {device}")
    return device
def get_uniq_training_labels(ds: "DatasetDict", columns_to_exclude: Optional[set[str]] = None):
    """Collect the unique token-level label values of every trainable column.

    Scans all splits of ``ds`` and, for each column of the "train" split that is
    not excluded, gathers the set of distinct label values and per-label
    occurrence counts (counts are only logged).

    Args:
        ds: Dataset splits; each example holds one list entry per token.
        columns_to_exclude: Column names to skip; defaults to {"text", "tokens"}.

    Returns:
        Dict mapping each trained-on column name to its set of unique labels.
    """
    if columns_to_exclude is None:
        columns_to_exclude = {"text", "tokens"}
    columns_to_train_on = [k for k in ds["train"].features.keys() if k not in columns_to_exclude]
    # Per-column occurrence counts and unique values, keyed by column name.
    label_counters = {col: dict() for col in columns_to_train_on}
    unique_label_values = {col: set() for col in columns_to_train_on}
    # Loop through each split and each example, and collect values.
    for dataset_split in ds.values():
        for example in dataset_split:
            # Each of these columns is a list (one entry per token),
            # so we update our set with each token-level value.
            for col in columns_to_train_on:
                unique_label_values[col].update(example[col])
                counts = label_counters[col]
                for label_val in example[col]:
                    counts[label_val] = counts.get(label_val, 0) + 1
    logger.info("Columns:")
    for col in columns_to_train_on:
        logger.info(f"  {col}:")
        # Convert to a sorted list just to have a nice, stable ordering
        vals = sorted(unique_label_values[col])
        logger.info(f"    {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")
    return unique_label_values
def show_examples(ds: "DatasetDict", show_expr: Optional[str]):
    """Log the dataset summary and, optionally, a sample of matching examples.

    Args:
        ds: Dataset splits to report on.
        show_expr: Optional "split/column/label/count" expression, e.g.
            "train/upos/NOUN/3" shows up to 3 shuffled examples from the
            "train" split whose "upos" column contains "NOUN". When falsy,
            only the dataset summary is logged.
    """
    logger.info(f"Dataset:\n{ds}")
    if show_expr:
        split_to_show, col_to_show, label_to_show, count_to_show = show_expr.split("/")
        examples_to_show = ds[split_to_show].filter(
            lambda exp: label_to_show in exp[col_to_show]
        ).shuffle(seed=42)  # fixed seed keeps the sample reproducible
        count_to_show = min(int(count_to_show), len(examples_to_show))
        # Hoisted out of the loop: slicing materializes every column prefix,
        # so do it once instead of once per shown example.
        head = examples_to_show[:count_to_show]
        for i in range(count_to_show):
            logger.info(f"Example {i}:")
            for feature in head.keys():
                logger.info(f"  {feature}: {head[feature][i]}")