# BERT/src/preprocess.py
import json
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Tuple

from datasets import Dataset
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase

def make_full_text_(title: str, abstract: str) -> str:
    """Concatenate title and abstract into a single text block."""
    if abstract is None:
        abstract = ""
    return title + "\n\n" + abstract

def transform_(
    examples: Dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    categories_column: str,
    cat2ids: Dict[str, int],
    tokenizer_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    # With batched=True, `examples` is a dict of lists (one list per column).
    full_texts = [
        make_full_text_(title, abstract)
        for title, abstract in zip(examples["title"], examples["abstract"])
    ]
    # Tokenize the title alone and the title+abstract text separately.
    title_tokens = tokenizer(examples["title"], **tokenizer_cfg)
    full_tokens = tokenizer(full_texts, **tokenizer_cfg)
    return {
        "title_input_ids": title_tokens["input_ids"],
        "title_attention_mask": title_tokens["attention_mask"],
        "full_input_ids": full_tokens["input_ids"],
        "full_attention_mask": full_tokens["attention_mask"],
        # Each example can carry several categories (multi-label).
        "labels_ids": [
            [cat2ids[cat] for cat in categories]
            for categories in examples[categories_column]
        ],
    }
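
# Illustrative sketch of the input JSON consumed below. The field names are
# inferred from how the code accesses them ("title", "abstract", and whatever
# `categories_column` names); the real data may carry additional fields:
# {
#   "2101.00001": {
#     "title": "Some paper title",
#     "abstract": "Some abstract text...",
#     "categories": ["cs.LG", "stat.ML"]
#   },
#   ...
# }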

def json_to_dataset_(data: Dict[str, Any], categories_column: str) -> Tuple[Dataset, List[str]]:
    """Flatten the id -> fields mapping into a Dataset and collect all category names."""
    rows = []
    cats_names = set()
    for arxiv_id, fields in data.items():
        row = {
            "id": arxiv_id,
            **fields,
        }
        rows.append(row)
        cats_names.update(fields[categories_column])
    return Dataset.from_list(rows), list(cats_names)

def dataset_preprocess(
    dataset_path: str | PathLike[str],
    classifier_name: str,
    categories_column: str,
) -> Dict[str, Any]:
    if not Path(dataset_path).exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    dataset, categories = json_to_dataset_(data, categories_column)
    # Sort so label ids are deterministic across runs.
    categories = sorted(categories)
    cat2ids = {}
    ids2cat = {}
    for idx, cat in enumerate(categories):
        cat2ids[cat] = idx
        ids2cat[idx] = cat
    tokenizer = AutoTokenizer.from_pretrained(classifier_name)
    model_cfg = AutoConfig.from_pretrained(classifier_name)
    # Pad/truncate every example to the model's maximum input length.
    tokenizer_cfg = {
        "truncation": True,
        "padding": "max_length",
        "max_length": model_cfg.max_position_embeddings,
    }
    dataset = dataset.map(
        lambda examples: transform_(
            examples,
            tokenizer,
            categories_column,
            cat2ids,
            tokenizer_cfg,
        ),
        batched=True,
        batch_size=10_000,
        remove_columns=dataset.column_names,
    )
    return {
        "dataset": dataset,
        "cat2ids": cat2ids,
        "ids2cat": ids2cat,
    }
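
if __name__ == "__main__":
    # Minimal usage sketch. The path, model name, and column name below are
    # placeholder assumptions, not values taken from this repository.
    result = dataset_preprocess(
        dataset_path="data/arxiv.json",       # hypothetical path
        classifier_name="bert-base-uncased",  # any HF encoder checkpoint works
        categories_column="categories",       # assumed column name
    )
    print(result["dataset"])
    print(len(result["cat2ids"]), "categories")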