from os import PathLike
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

from datasets import Dataset
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase


def make_full_text_(title: str, abstract: str) -> str:
    """Join title and abstract into one text, tolerating a missing abstract."""
    if abstract is None:
        abstract = ""
    return title + "\n\n" + abstract


def transform_(
    examples: Dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    categories_column: str,
    cat2ids: Dict[str, int],
    tokenizer_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Tokenize titles and full texts and convert category names to label ids."""
    # With batched=True, `examples` is a dict of lists (one list per column).
    full_texts = [
        make_full_text_(title, abstract)
        for title, abstract in zip(examples["title"], examples["abstract"])
    ]
    title_tokens = tokenizer(examples["title"], **tokenizer_cfg)
    full_tokens = tokenizer(full_texts, **tokenizer_cfg)
    return {
        "title_input_ids": title_tokens["input_ids"],
        "title_attention_mask": title_tokens["attention_mask"],
        "full_input_ids": full_tokens["input_ids"],
        "full_attention_mask": full_tokens["attention_mask"],
        "labels_ids": [
            [cat2ids[cat] for cat in categories]
            for categories in examples[categories_column]
        ],
    }
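

# Illustrative sketch (hypothetical values): with batched=True, `transform_`
# receives column-wise lists and returns equally long lists per output column:
#   transform_(
#       {"title": ["Attention Is All You Need"],
#        "abstract": ["The dominant sequence transduction models ..."],
#        "categories": [["cs.CL", "cs.LG"]]},
#       tokenizer, "categories", {"cs.CL": 0, "cs.LG": 1}, {"truncation": True},
#   )
#   # -> {"title_input_ids": [[...]], ..., "labels_ids": [[0, 1]]}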


def json_to_dataset_(data: Dict[str, Any], categories_column: str) -> Tuple[Dataset, List[str]]:
    """Flatten an {arxiv_id: fields} mapping into a Dataset and collect category names."""
    rows = []
    cats_names = set()
    for arxiv_id, fields in data.items():
        row = {
            "id": arxiv_id,
            **fields,
        }
        rows.append(row)
        cats_names.update(fields[categories_column])
    return Dataset.from_list(rows), list(cats_names)


def dataset_preprocess(
    dataset_path: str | PathLike[str],
    classifier_name: str,
    categories_column: str,
) -> Dict[str, Any]:
    """Load the JSON dataset, build label mappings, and tokenize it for training."""
    if not Path(dataset_path).exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    dataset, categories = json_to_dataset_(data, categories_column)
    # Sort so the category -> id assignment is deterministic across runs.
    categories = sorted(categories)
    cat2ids = {}
    ids2cat = {}
    for idx, cat in enumerate(categories):
        cat2ids[cat] = idx
        ids2cat[idx] = cat
    tokenizer = AutoTokenizer.from_pretrained(classifier_name)
    model_cfg = AutoConfig.from_pretrained(classifier_name)
    tokenizer_cfg = {
        "truncation": True,
        "padding": "max_length",
        # Pad/truncate to the model's maximum sequence length.
        "max_length": model_cfg.max_position_embeddings,
    }
    dataset = dataset.map(
        lambda examples: transform_(
            examples,
            tokenizer,
            categories_column,
            cat2ids,
            tokenizer_cfg,
        ),
        batched=True,
        batch_size=10_000,
        remove_columns=dataset.column_names,
    )
    return {
        "dataset": dataset,
        "cat2ids": cat2ids,
        "ids2cat": ids2cat,
    }
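

# Minimal usage sketch, assuming a JSON file shaped like
# {"<arxiv_id>": {"title": ..., "abstract": ..., "categories": [...]}, ...}.
# The path and checkpoint below are hypothetical placeholders, not part of this module.
if __name__ == "__main__":
    out = dataset_preprocess(
        dataset_path="data/arxiv.json",       # hypothetical path
        classifier_name="bert-base-uncased",  # hypothetical checkpoint
        categories_column="categories",
    )
    print(out["dataset"])
    print(len(out["cat2ids"]), "categories")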