# deberta-goemotions / utils / __init__.py
# Author: veryfansome — commit ef613cf: "feat: adding training scripts"
from datasets import DatasetDict
from typing import Optional
import logging
import torch
# Module-level logger named after this package (standard logging convention).
logger = logging.getLogger(__name__)
# dictConfig-style logging setup (see logging.config.dictConfig schema, version 1).
default_logging_config = {
    "version": 1,
    # Keep loggers created before dictConfig() is called alive.
    "disable_existing_loggers": False,
    # One plain formatter: timestamp, logger name, level, message.
    "formatters": {
        "default": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"},
    },
    # Single console handler (StreamHandler writes to stderr by default).
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    # Root logger ("") at INFO, routed to the console handler.
    "loggers": {
        "": {"level": "INFO", "handlers": ["console"]},
    },
}
def get_torch_device():
    """
    Select the best available torch device.

    Preference order: CUDA GPU, then Apple Silicon MPS, then CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    # Older torch builds lack the `mps` backend attribute entirely, so probe
    # defensively before calling is_available() to avoid an AttributeError.
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    # Lazy %-style args: the message is only built if INFO is enabled.
    logger.info("using %s", device)
    return device
def show_class_distribution(dataset, split_name, label_names):
    """
    Log how many samples carry each label in one split of the dataset.

    Useful for spotting class imbalance before training.
    - dataset: mapping of split name -> huggingface Dataset
    - split_name: which split to inspect
    - label_names: ordered label names; index i corresponds to label id i
    """
    from collections import Counter

    split = dataset[split_name]
    total = len(split)
    # `orig_labels` holds one list of label indices per sample (multi-label).
    counts = Counter()
    for sample_labels in split["orig_labels"]:
        counts.update(sample_labels)
    logger.info(f"\n--- Class distribution for split '{split_name}' ({total} samples) ---")
    for label_id, name in enumerate(label_names):
        logger.info(f"{label_id:02d} ({name}): count = {counts[label_id]}")
    logger.info("---------------------------------------------------------------\n")