File size: 3,658 Bytes
608ff23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import functools
from typing import Dict

import seqio
import tensorflow as tf
from datasets import load_dataset, load_from_disk
from t5.evaluation import metrics
from seqio import utils, FunctionDataSource
import t5.data
from datasets import load_dataset, load_from_disk
from t5.data import postprocessors
from t5.data import preprocessors


from ul2_objective import ul2_objective

# Denoiser configurations taken from the UL2 paper
# (https://arxiv.org/pdf/2205.05131.pdf, section 3.1.2, table 1).
# The span-length and corrupt-rate lists are paired index-by-index:
# *_SPAN_LENGTHS[i] is the mean noise span length of the denoiser whose
# corruption rate is *_CORRUPT_RATES[i]. They are concatenated in R-then-X
# order when passed to ul2_objective, so the two lists of each family must
# stay the same length.
R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]

# Mode tokens identifying the denoiser family (R-, X- and S-denoising); they
# are handed to ul2_objective as `optional_task_prefixes` at task registration.
R_DENOISER_TOKEN_PREFIX = "[NLU]"
X_DENOISER_TOKEN_PREFIX = "[NLG]"
S_DENOISER_TOKEN_PREFIX = "[S2S]"

# Short alias for the global seqio task registry.
TaskRegistry = seqio.TaskRegistry

# SentencePiece vocabulary shared by all features. Loaded at import time from
# the relative path 'spiece.model', so the working directory must contain it.
vocabulary = seqio.SentencePieceVocabulary('spiece.model')

# Output features: both get an EOS token appended; "inputs" is optional here
# (required=False) so target-only tasks can reuse this mapping.
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
    "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True),
}

def gen_dataset(split, shuffle=False, seed=None, column="text", path=None, name=None):
    """Yield one field of every example in a streamed dataset split, forever.

    Args:
      split: Split to read; converted with ``str`` before indexing the loaded
        dataset (e.g. "train").
      shuffle: If True, shuffle the split's stream before iterating.
      seed: Optional shuffle seed. ``None`` leaves seed selection to
        ``datasets``; any other value, including ``0``, is used as given.
      column: Name of the example field whose value is yielded.
      path: Dataset path / Hub repo id forwarded to ``load_dataset``.
      name: Optional dataset configuration name forwarded to ``load_dataset``.

    Yields:
      ``item[column]`` for each example, restarting from the beginning of the
      split after every full pass (an endless stream).

    Raises:
      ValueError: If a full pass over the split produces no examples; without
        this check the endless loop would spin forever without yielding.
    """
    # NOTE(review): `use_auth_token` is deprecated in newer `datasets` releases
    # in favour of `token`; switch once the pinned version requires it.
    dataset = load_dataset(path, name, streaming=True, use_auth_token=True)
    split_ds = dataset[str(split)]

    if shuffle:
        # Pass the seed straight through so seed=0 is honoured (the old
        # truthiness check dropped it); seed=None matches a bare shuffle().
        split_ds = split_ds.shuffle(seed=seed)

    while True:
        produced_any = False
        for item in split_ds:
            produced_any = True
            yield item[column]
        if not produced_any:
            raise ValueError(
                f"Split {split!r} of dataset {path!r} produced no examples."
            )


def dataset_fn(split, shuffle_files, seed=None, path=None, name=None):
    """Expose ``gen_dataset`` as a ``tf.data.Dataset`` of scalar strings.

    Used as the ``dataset_fn`` of a ``seqio.FunctionDataSource``; ``path`` and
    ``name`` are expected to be bound with ``functools.partial`` beforehand.

    Args:
      split: Split name forwarded to ``gen_dataset``.
      shuffle_files: Forwarded as ``gen_dataset``'s ``shuffle`` flag.
      seed: Optional shuffle seed forwarded to ``gen_dataset``.
      path: Dataset path; also used as the name of the output TensorSpec.
      name: Optional dataset configuration name.

    Returns:
      A dataset whose elements are ``tf.string`` scalars.
    """
    # Zero-argument callable, as required by from_generator.
    make_examples = functools.partial(
        gen_dataset,
        split=split,
        shuffle=shuffle_files,
        seed=seed,
        path=path,
        name=name,
    )
    scalar_text_spec = tf.TensorSpec(shape=(), dtype=tf.string, name=path)
    return tf.data.Dataset.from_generator(
        make_examples, output_signature=scalar_text_spec
    )


@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Return a copy of ``key_map`` with ``x`` stored under ``target_key``.

    ``key_map`` itself is never mutated: a shallow copy is taken first, and an
    existing ``target_key`` entry keeps its position but gets the new value.
    Wrapped by seqio's ``utils.map_over_dataset`` so it is applied per dataset
    element.
    """
    features = {**key_map}
    features[target_key] = x
    return features


# Register the UL2 mixture-of-denoisers pretraining task over the streamed
# Hugging Face dataset "Siddharth63/medical_dataset".
TaskRegistry.add(
    "pretrain_medical_ul2",
    source=seqio.FunctionDataSource(
        dataset_fn=functools.partial(
            dataset_fn, path="Siddharth63/medical_dataset",
        ),
        splits=("train", "validation"),
        # gen_dataset streams from the Hub and loops forever, so the source
        # is not cached.
        caching_permitted=False,
    ),
    preprocessors=[
        # Turn each raw string into {"inputs": "text", "targets": <example>}.
        # NOTE(review): "inputs" receives the literal string "text", not the
        # example text; presumably ul2_objective rebuilds "inputs" from
        # "targets" — confirm against ul2_objective.
        functools.partial(
            target_to_key,
            key_map={
                "inputs": "text",
                "targets": "text",
            },
            target_key="targets",
        ),
        seqio.preprocessors.tokenize,
        functools.partial(
            ul2_objective,
            shard_ds=False,
            use_prefix_lm_task=True,  # also sample S-denoising (prefix LM) examples
            # Mixture weights, one per task: each R-denoiser gets 0.4/2, each
            # X-denoiser 0.4/4, and the final 0.2 goes to S-denoising.
            rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)] * len(R_DENOISER_SPAN_LENGTHS)
            + [0.4 / len(X_DENOISER_SPAN_LENGTHS)] * len(X_DENOISER_SPAN_LENGTHS)
            + [
                0.2
            ],  # 40% total for R-denoisers, 40% for X-denoisers, 20% for S-denoising (as suggested in section 4.5 of the paper)
            # Span lengths and densities cover only the 6 span-corruption
            # denoisers (R then X); `rates` and `optional_task_prefixes` have a
            # 7th entry for the S-denoiser, which has no span parameters.
            mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
            noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
            optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]
            * len(R_DENOISER_SPAN_LENGTHS)
            + [X_DENOISER_TOKEN_PREFIX] * len(X_DENOISER_SPAN_LENGTHS)
            + [S_DENOISER_TOKEN_PREFIX],
            reserved_for_packing=1,  # reserve one position for the task prefix token
        ),
        seqio.preprocessors.append_eos_after_trim,
    ],
    # Unlike DEFAULT_OUTPUT_FEATURES, "inputs" is required here because the
    # UL2 objective produces an input sequence for every example.
    output_features={
        "targets": DEFAULT_OUTPUT_FEATURES["targets"],
        "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True),
    },
    metric_fns=[metrics.accuracy],
)