import itertools
import logging
from typing import Optional

import sentencepiece as spm
import torch
from datasets import DatasetDict


logger = logging.getLogger(__name__)


# Module-level SentencePiece tokenizer; the model file is expected to live
# alongside this module.
sp = spm.SentencePieceProcessor()
sp.LoadFromFile("sp.model")


default_logging_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}
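# A minimal sketch of how this config might be wired in; it assumes the
# program's entry point applies it once at startup, before any log records
# are emitted:
#
#     import logging.config
#     logging.config.dictConfig(default_logging_config)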


def get_torch_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logger.info(f"using {device}")
    return device
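# Sketch of typical use of get_torch_device(); ``model`` and ``batch`` here
# are hypothetical objects, not defined in this module:
#
#     device = get_torch_device()
#     model = model.to(device)
#     batch = {k: v.to(device) for k, v in batch.items()}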


def get_uniq_training_labels(ds: DatasetDict, columns_to_exclude: Optional[set[str]] = None):
    """Collect the unique label values of every label column, across all splits.

    Every column except the excluded ones (by default "text" and "tokens") is
    treated as a label column; occurrence counts are logged per label value.
    """
    excluded = {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [k for k in ds["train"].features.keys() if k not in excluded]

    label_counters = {col: dict() for col in columns_to_train_on}
    unique_label_values = {col: set() for col in columns_to_train_on}

    for split_name, dataset_split in ds.items():
        for example in dataset_split:
            for col in columns_to_train_on:
                unique_label_values[col].update(example[col])
                for label_val in example[col]:
                    if label_val not in label_counters[col]:
                        label_counters[col][label_val] = 0
                    label_counters[col][label_val] += 1

    logger.info("Columns:")
    for col in columns_to_train_on:
        logger.info(f"  {col}:")
        vals = sorted(unique_label_values[col])
        logger.info(f"    {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")

    return unique_label_values
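# Sketch of expected use; the dataset name and the extra excluded column are
# hypothetical, not part of this module:
#
#     from datasets import load_dataset
#     ds = load_dataset("some_token_classification_dataset")
#     labels = get_uniq_training_labels(ds, columns_to_exclude={"text", "tokens", "id"})
#     # -> e.g. {"ner_tags": {"O", "B-PER", ...}, ...} for the remaining label columns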


def show_examples(ds: DatasetDict, show_expr: Optional[str]):
    """Log a few examples from the dataset.

    ``show_expr`` has the form "split/column/label/count"; when it is empty,
    the first two training examples are shown instead.
    """
    logger.info(f"Dataset:\n{ds}")
    if not show_expr:
        count_to_show = 2
        examples_to_show = ds["train"][:count_to_show]
    else:
        split_to_show, col_to_show, label_to_show, count_to_show = show_expr.split("/")
        count_to_show = int(count_to_show)
        examples_to_show = ds[split_to_show].filter(
            lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]
    # The filter may match fewer examples than requested.
    count_to_show = min([count_to_show] + [len(v) for v in examples_to_show.values()])
    for i in range(count_to_show):
        logger.info(f"Example {i}:")
        for feature in examples_to_show.keys():
            logger.info(f"  {feature}: {examples_to_show[feature][i]}")
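# Sketch of the two calling modes; the split/column/label values below are
# hypothetical:
#
#     show_examples(ds, None)                      # first two training examples
#     show_examples(ds, "train/ner_tags/B-PER/3")  # 3 shuffled examples whose
#                                                  # "ner_tags" contain "B-PER"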


def sp_tokenize(text: str):
    """Tokenize ``text`` with the SentencePiece model, dropping the "▁" boundary markers."""
    return list(itertools.chain.from_iterable(
        [s.strip("▁").split("▁") for s in sp.EncodeAsPieces(text)]))
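# Sketch of use; the exact pieces depend on the vocabulary in "sp.model":
#
#     sp_tokenize("some example sentence")
#     # -> e.g. ["some", "example", "sent", "ence"], i.e. subword pieces with
#     #    the "▁" word-boundary markers removed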