from os import PathLike
from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig
from datasets import Dataset
from typing import Dict, Any, Tuple, List
from pathlib import Path
import json


def make_full_text_(title: str, abstract: str | None) -> str:
    """Concatenate title and abstract into one text, tolerating a missing abstract."""
    if abstract is None:
        abstract = ""
    return title + "\n\n" + abstract


def transform_(
    examples: Dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    categories_column: str,
    cat2ids: Dict[str, int],
    tokenizer_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    # With batched=True, `examples` is a dict of lists (one list per column).
    full_texts = [
        make_full_text_(title, abstract)
        for title, abstract in zip(examples["title"], examples["abstract"])
    ]
    # Tokenize the title alone and the full title+abstract text separately.
    title_tokens = tokenizer(examples["title"], **tokenizer_cfg)
    full_tokens = tokenizer(full_texts, **tokenizer_cfg)
    return {
        "title_input_ids": title_tokens["input_ids"],
        "title_attention_mask": title_tokens["attention_mask"],
        "full_input_ids": full_tokens["input_ids"],
        "full_attention_mask": full_tokens["attention_mask"],
        # Multi-label targets: each example maps to a list of category ids.
        "labels_ids": [
            [cat2ids[cat] for cat in categories]
            for categories in examples[categories_column]
        ],
    }


def json_to_dataset_(data: Dict[str, Any], categories_column: str) -> Tuple[Dataset, List[str]]:
    """Flatten an {arxiv_id: fields} mapping into a Dataset and collect the category vocabulary."""
    rows = []
    cats_names = set()
    for arxiv_id, fields in data.items():
        rows.append({"id": arxiv_id, **fields})
        cats_names.update(fields[categories_column])
    return Dataset.from_list(rows), list(cats_names)


def dataset_preprocess(
    dataset_path: str | PathLike[str],
    classifier_name: str,
    categories_column: str,
) -> Dict[str, Any]:
    if not Path(dataset_path).exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    dataset, categories = json_to_dataset_(data, categories_column)
    # Sort so the label-id assignment is deterministic across runs.
    categories = sorted(categories)

    cat2ids = {}
    ids2cat = {}
    for idx, cat in enumerate(categories):
        cat2ids[cat] = idx
        ids2cat[idx] = cat

    tokenizer = AutoTokenizer.from_pretrained(classifier_name)
    model_cfg = AutoConfig.from_pretrained(classifier_name)
    tokenizer_cfg = {
        "truncation": True,
        "padding": "max_length",
        # Cap sequences at the model's positional-embedding limit. Note that some
        # models (e.g. RoBERTa) reserve extra positions, so tokenizer.model_max_length
        # can be a safer bound.
        "max_length": model_cfg.max_position_embeddings,
    }

    dataset = dataset.map(
        lambda examples: transform_(
            examples, tokenizer, categories_column, cat2ids, tokenizer_cfg
        ),
        batched=True,
        batch_size=10_000,
        remove_columns=dataset.column_names,
    )

    return {
        "dataset": dataset,
        "cat2ids": cat2ids,
        "ids2cat": ids2cat,
    }
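
# Usage sketch (an assumption, not part of the original module): the JSON file is
# expected to map arXiv ids to records carrying "title", "abstract", and a list of
# categories, matching what json_to_dataset_ consumes. The path, checkpoint name,
# and column name below are hypothetical placeholders.
if __name__ == "__main__":
    processed = dataset_preprocess(
        dataset_path="data/arxiv_papers.json",  # hypothetical path
        classifier_name="bert-base-uncased",    # any BERT-like HF checkpoint
        categories_column="categories",         # hypothetical column name
    )
    print(processed["dataset"])                 # tokenized Dataset
    print(len(processed["cat2ids"]), "categories")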