from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Tuple
import json

from datasets import Dataset
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase


def make_full_text_(title: str, abstract: str) -> str:
    # Concatenate title and abstract; some records have no abstract.
    if abstract is None:
        abstract = ""
    return title + "\n\n" + abstract


def transform_(
    examples: Dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    categories_column: str,
    cat2ids: Dict[str, int],
    tokenizer_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    # With batched=True, `examples` is a dict of lists (one list per column).
    full_texts = [
        make_full_text_(title, abstract)
        for title, abstract in zip(examples["title"], examples["abstract"])
    ]

    title_tokens = tokenizer(examples["title"], **tokenizer_cfg)
    full_tokens = tokenizer(full_texts, **tokenizer_cfg)

    return {
        "title_input_ids": title_tokens["input_ids"],
        "title_attention_mask": title_tokens["attention_mask"],
        "full_input_ids": full_tokens["input_ids"],
        "full_attention_mask": full_tokens["attention_mask"],
        # Multi-label targets: map each category name to its integer id.
        "labels_ids": [
            [cat2ids[cat] for cat in categories]
            for categories in examples[categories_column]
        ],
    }


def json_to_dataset_(data: Dict[str, Any], categories_column: str) -> Tuple[Dataset, List[str]]:
    # Flatten {arxiv_id: fields} into rows and collect every category name
    # seen across the dataset.
    rows = []
    cats_names = set()
    for arxiv_id, fields in data.items():
        rows.append({"id": arxiv_id, **fields})
        cats_names.update(fields[categories_column])

    return Dataset.from_list(rows), list(cats_names)


def dataset_preprocess(
    dataset_path: str | PathLike[str],
    classifier_name: str,
    categories_column: str,
) -> Dict[str, Any]:
    dataset_path = Path(dataset_path)
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    dataset, categories = json_to_dataset_(data, categories_column)
    categories = sorted(categories)

    # Build stable category <-> id mappings (order fixed by the sort above).
    cat2ids = {cat: idx for idx, cat in enumerate(categories)}
    ids2cat = {idx: cat for idx, cat in enumerate(categories)}

    tokenizer = AutoTokenizer.from_pretrained(classifier_name)
    model_cfg = AutoConfig.from_pretrained(classifier_name)

    # Pad/truncate to the model's maximum sequence length. Note that for some
    # architectures (e.g. RoBERTa) max_position_embeddings includes a position
    # offset and exceeds the usable input length; tokenizer.model_max_length
    # can be the safer bound in that case.
    tokenizer_cfg = {
        "truncation": True,
        "padding": "max_length",
        "max_length": model_cfg.max_position_embeddings,
    }

    # Tokenize in large batches; drop the raw columns so only the tensors
    # and label ids remain.
    dataset = dataset.map(
        lambda examples: transform_(
            examples,
            tokenizer,
            categories_column,
            cat2ids,
            tokenizer_cfg,
        ),
        batched=True,
        batch_size=10_000,
        remove_columns=dataset.column_names,
    )

    return {
        "dataset": dataset,
        "cat2ids": cat2ids,
        "ids2cat": ids2cat,
    }
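

# A minimal usage sketch, not part of the module's public surface: the file
# path, checkpoint name, and column name below are assumptions for
# illustration, not values taken from this repository.
if __name__ == "__main__":
    result = dataset_preprocess(
        # assumed layout: {arxiv_id: {"title": ..., "abstract": ..., "categories": [...]}}
        dataset_path="arxiv_papers.json",
        classifier_name="bert-base-uncased",  # any HF encoder checkpoint
        categories_column="categories",
    )
    print(result["dataset"])
    print(f"{len(result['cat2ids'])} categories")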