import json
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Tuple

from datasets import Dataset
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase


def make_full_text_(title: str, abstract: str) -> str:
    """Join title and abstract into one text; a missing abstract becomes empty."""
    if abstract is None:
        abstract = ""
    return title + "\n\n" + abstract


def transform_(
    examples: Dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    categories_column: str,
    cat2ids: Dict[str, int],
    tokenizer_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Tokenize a batch twice (title only, title + abstract) and encode labels."""
    full_texts = [
        make_full_text_(title, abstract)
        for title, abstract in zip(examples["title"], examples["abstract"])
    ]

    title_tokens = tokenizer(examples["title"], **tokenizer_cfg)
    full_tokens = tokenizer(full_texts, **tokenizer_cfg)

    return {
        "title_input_ids": title_tokens["input_ids"],
        "title_attention_mask": title_tokens["attention_mask"],

        "full_input_ids": full_tokens["input_ids"],
        "full_attention_mask": full_tokens["attention_mask"],

        # Each example may carry several categories (multi-label setup).
        "labels_ids": [
            [cat2ids[cat] for cat in categories]
            for categories in examples[categories_column]
        ],
    }


def json_to_dataset_(data: Dict[str, Any], categories_column: str) -> Tuple[Dataset, List[str]]:
    """Flatten an id-keyed JSON mapping into a Dataset and collect category names."""
    rows = []
    cats_names = set()
    for arxiv_id, fields in data.items():
        rows.append({"id": arxiv_id, **fields})
        cats_names.update(fields[categories_column])

    return Dataset.from_list(rows), list(cats_names)


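# Illustrative sketch of the input layout json_to_dataset_ assumes, inferred
# from the parsing code above; the field names other than the top-level id key
# are hypothetical and come from the input file itself:
#
# {
#     "2101.00001": {
#         "title": "Some paper title",
#         "abstract": "Some abstract text ...",
#         "categories": ["cs.LG", "stat.ML"]
#     }
# }

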
def dataset_preprocess(
    dataset_path: str | PathLike[str],
    classifier_name: str,
    categories_column: str,
) -> Dict[str, Any]:
    """Load the JSON dump, build the label mappings, and tokenize the dataset."""
    if not Path(dataset_path).exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    dataset, categories = json_to_dataset_(data, categories_column)
    categories = sorted(categories)  # fix label order so ids are reproducible

    cat2ids = {cat: idx for idx, cat in enumerate(categories)}
    ids2cat = {idx: cat for idx, cat in enumerate(categories)}

    tokenizer = AutoTokenizer.from_pretrained(classifier_name)
    model_cfg = AutoConfig.from_pretrained(classifier_name)

    tokenizer_cfg = {
        "truncation": True,
        "padding": "max_length",
        "max_length": model_cfg.max_position_embeddings,
    }

    dataset = dataset.map(
        lambda examples: transform_(
            examples,
            tokenizer,
            categories_column,
            cat2ids,
            tokenizer_cfg,
        ),
        batched=True,
        batch_size=10_000,
        remove_columns=dataset.column_names,
    )

    return {
        "dataset": dataset,
        "cat2ids": cat2ids,
        "ids2cat": ids2cat,
    }
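

if __name__ == "__main__":
    # Minimal usage sketch. The file path, checkpoint name, and column name
    # below are illustrative assumptions, not part of this module.
    out = dataset_preprocess(
        dataset_path="arxiv_papers.json",
        classifier_name="bert-base-uncased",
        categories_column="categories",
    )
    print(out["dataset"])
    print(f"{len(out['cat2ids'])} distinct categories")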