pere committed
Commit ea6dc92 · 1 parent: dbf98c8
train_exp1_base_engvoc.sh → :q RENAMED
@@ -1,8 +1,8 @@
-PROJECT_DIR=${HOME}"/models/long-t5x"
+PROJECT_DIR=${HOME}"/models/ul2-t5x"
 T5X_DIR="../../t5x" # directory where the t5x is cloned.
-MODEL_DIR="gs://nb-t5x-us-central2/exp1-t5-base-engvoc"
+MODEL_DIR="gs://nb-t5x-us-central2/exp3-t5-base-ul2-engvoc"
 export PYTHONPATH=${PROJECT_DIR}
-MIXTURE_OR_TASK_NAME="ncc_scandinavian_span_corruption_stream_engvoc"
+MIXTURE_OR_TASK_NAME="scandinavian_ul2_engvoc"

 python3 ${T5X_DIR}/t5x/train.py \
   --gin_search_paths=${PROJECT_DIR} \
__pycache__/tasks.cpython-38.pyc ADDED
Binary file (3.57 kB).
 
__pycache__/ul2_objective.cpython-38.pyc ADDED
Binary file (6.21 kB).
 
base.gin CHANGED
@@ -3,14 +3,13 @@ include 'pretrain_cont.gin'
 #include 't5x/configs/runs/pretrain.gin'
 #iinclude 't5x/configs/runs/finetune.gin'

-
 # Register necessary SeqIO Tasks/Mixtures.
 import t5.data.mixtures
 import tasks

 MIXTURE_OR_TASK_NAME = %gin.REQUIRED
 TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
-TRAIN_STEPS = 1_700_000
+TRAIN_STEPS = 1_300_000
 DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recomments this.
 #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_base/checkpoint_1500000"
 #INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000"
base_modified_scand.gin ADDED
@@ -0,0 +1,56 @@
+# T5.1.1 Base model.
+from __gin__ import dynamic_registration
+
+import seqio
+from t5x import adafactor
+from t5x import models
+from t5x.examples.t5 import network
+
+# ------------------- Loss HParam ----------------------------------------------
+Z_LOSS = 0.0001
+LABEL_SMOOTHING = 0.0
+# NOTE: When fine-tuning the public T5 checkpoints (trained in T5 MeshTF)
+# the loss normalizing factor should be set to pretraining batch_size *
+# target_token_length.
+LOSS_NORMALIZING_FACTOR = None
+# Dropout should be specified in the "run" files
+DROPOUT_RATE = %gin.REQUIRED
+
+# Vocabulary (shared by encoder and decoder)
+VOCABULARY = @seqio.SentencePieceVocabulary()
+seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://nb-t5/t5/vocabs/wikipedia/no-da-en-sv-nn-is_32000_unigram.sp.model"
+seqio.SentencePieceVocabulary.extra_ids = 100
+
+# ------------------- Optimizer ------------------------------------------------
+# `learning_rate` is set by `Trainer.learning_rate_fn`.
+OPTIMIZER = @adafactor.Adafactor()
+adafactor.Adafactor:
+  decay_rate = 0.8
+  step_offset = 0
+  logical_factor_rules = @adafactor.standard_logical_factor_rules()
+
+# ------------------- Model ----------------------------------------------------
+MODEL = @models.EncoderDecoderModel()
+models.EncoderDecoderModel:
+  module = @network.Transformer()
+  input_vocabulary = %VOCABULARY
+  output_vocabulary = %VOCABULARY
+  optimizer_def = %OPTIMIZER
+  z_loss = %Z_LOSS
+  label_smoothing = %LABEL_SMOOTHING
+  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
+
+# ------------------- Network specification ------------------------------------
+network.Transformer.config = @network.T5Config()
+network.T5Config:
+  vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency
+  dtype = 'bfloat16'
+  emb_dim = 768
+  num_heads = 12
+  num_encoder_layers = 12
+  num_decoder_layers = 12
+  head_dim = 64
+  mlp_dim = 2048
+  mlp_activations = ('gelu', 'linear')
+  dropout_rate = %DROPOUT_RATE
+  logits_via_embedding = False
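
Editor's note on vocab_size in the config above (a sanity check derived from the config, not part of the commit): the SentencePiece model contributes 32,000 pieces and seqio adds 100 extra sentinel IDs, so 32,100 token IDs are in use; vocab_size = 32128 is the next multiple of 128, which keeps the embedding table TPU-friendly. A minimal sketch of the arithmetic, plain Python with no dependencies:

# Illustrative arithmetic behind network.T5Config.vocab_size:
sp_pieces = 32_000                  # pieces in no-da-en-sv-nn-is_32000_unigram.sp.model
extra_ids = 100                     # seqio.SentencePieceVocabulary.extra_ids
used_ids = sp_pieces + extra_ids    # 32_100 token IDs actually used
padded = -(-used_ids // 128) * 128  # round up to a multiple of 128
assert padded == 32_128             # matches vocab_size in the gin config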
base_scandvoc.gin ADDED
@@ -0,0 +1,21 @@
+include 'base_modified_scand.gin'
+include 'pretrain_cont.gin'
+#include 't5x/configs/runs/pretrain.gin'
+#iinclude 't5x/configs/runs/finetune.gin'
+
+# Register necessary SeqIO Tasks/Mixtures.
+import t5.data.mixtures
+import tasks
+
+MIXTURE_OR_TASK_NAME = %gin.REQUIRED
+TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
+TRAIN_STEPS = 1_300_000
+DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recomments this.
+#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_base/checkpoint_1500000"
+#INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000"
+INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/flan_t5_base/checkpoint_1184000"
+
+
+PjitPartitioner.num_partitions = 1
+
+
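
A consequence of the values above worth spelling out (an inference, assuming T5X's usual convention that TRAIN_STEPS is an absolute final step count rather than a delta): restoring flan_t5_base at checkpoint_1184000 and training to TRAIN_STEPS = 1_300_000 amounts to 116,000 steps of continued pretraining.

# Illustrative arithmetic only:
resume_step = 1_184_000          # flan_t5_base/checkpoint_1184000
final_step = 1_300_000           # TRAIN_STEPS in base_scandvoc.gin
print(final_step - resume_step)  # 116000 additional pretraining steps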
batch_nynorsk_NCC_base.sh DELETED
@@ -1,11 +0,0 @@
-PROJECT_DIR=${HOME}"/models/t5-nynorsk-oversetter"
-export PYTHONPATH=${PROJECT_DIR}
-INITIAL_CHECKPOINT_PATH=\"gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_base/checkpoint_1500000\"
-TRAIN_STEPS=1505000
-
-python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v1\" &&
-python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v2\" &&
-python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v3\" &&
-python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v4\" &&
-python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v5\"
-
pretrain_cont.gin CHANGED
@@ -96,7 +96,7 @@ utils.RestoreCheckpointConfig:
   mode = 'specific'
   dtype = 'float32'
 utils.SaveCheckpointConfig:
-  period = 1000
+  period = 10000
   dtype = 'float32'
   keep = None # keep all checkpoints
   save_dataset = False # don't checkpoint dataset state
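
Raising period from 1000 to 10000 cuts the number of checkpoints written tenfold, which matters here because keep = None retains every one. A rough, illustrative estimate (assumptions: ~248M parameters for a T5-base model, float32 weights only, ignoring Adafactor state, over the 116,000-step scandvoc window computed above):

# Hypothetical back-of-the-envelope numbers, not from the commit:
params = 248_000_000             # approx. T5-base parameter count
ckpt_bytes = params * 4          # float32 weights per checkpoint
ckpts = 116_000 // 10_000        # 11 checkpoints over the continued run
print(ckpts * ckpt_bytes / 1e9)  # ~11 GB retained, vs ~115 GB at period = 1000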
tasks.py CHANGED
@@ -8,6 +8,18 @@ from t5.data import preprocessors
 from t5.evaluation import metrics
 from seqio import FunctionDataSource, utils

+from ul2_objective import ul2_objective
+# values from UL2 paper https://arxiv.org/pdf/2205.05131.pdf chapter 3.1.2 table 1
+R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
+X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
+R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
+X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]
+
+R_DENOISER_TOKEN_PREFIX = '[NLU]'
+X_DENOISER_TOKEN_PREFIX = '[NLG]'
+S_DENOISER_TOKEN_PREFIX = '[S2S]'
+
+
 TaskRegistry = seqio.TaskRegistry
 scand_vocabulary=seqio.SentencePieceVocabulary('gs://nb-t5/t5/vocabs/wikipedia/no-da-en-sv-nn-is_32000_unigram.sp.model', extra_ids=100)
 eng_vocabulary=seqio.SentencePieceVocabulary('gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model', extra_ids=0)
@@ -60,7 +72,7 @@ dataset_name = 'NbAiLab/scandinavian'
 dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
 dataset_shapes = None
 TaskRegistry.add(
-    "ncc_scandinavian_span_corruption_stream_engvoc",
+    "scandinavian_span_engvoc",
     source=seqio.FunctionDataSource(
         dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
         splits=("train", "validation"),
@@ -83,12 +95,12 @@ TaskRegistry.add(
 )


-# Final pretraining task used in Raffel et al., 2019 adaptated to NCC
+# Final pretraining task used in Tay et al., 2022 adaptated by @beats
 dataset_name = 'NbAiLab/scandinavian'
 dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
 dataset_shapes = None
 TaskRegistry.add(
-    "ncc_scandinavian_span_corruption_stream_scandvoc",
+    "scandinavian_span_scandvoc",
     source=seqio.FunctionDataSource(
         dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
         splits=("train", "validation"),
@@ -111,4 +123,80 @@ TaskRegistry.add(
 )


+# Final pretraining task used in Tay et al., 2022 adaptated by @beats
+dataset_name = 'NbAiLab/scandinavian'
+dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
+dataset_shapes = None
+TaskRegistry.add(
+    "scandinavian_ul2_engvoc",
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
+        splits=("train", "validation"),
+        caching_permitted=False,
+        num_input_examples=dataset_shapes,
+    ),
+    preprocessors=[
+        functools.partial(
+            target_to_key, key_map={
+                "inputs": None,
+                "targets": None,
+            }, target_key="targets"),
+        seqio.preprocessors.tokenize,
+        functools.partial(
+            ul2_objective,
+            shard_ds=False,
+            use_prefix_lm_task=True,  # use S-denoising
+            rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)]*len(R_DENOISER_SPAN_LENGTHS) + [
+                0.4 / len(X_DENOISER_SPAN_LENGTHS)]*len(X_DENOISER_SPAN_LENGTHS) + [0.2],  # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested at the paper chapter 4.5)
+            mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
+            noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
+            optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]*len(R_DENOISER_SPAN_LENGTHS) + [
+                X_DENOISER_TOKEN_PREFIX]*len(X_DENOISER_SPAN_LENGTHS) + [S_DENOISER_TOKEN_PREFIX],
+            reserved_for_packing=1,  # make room for task prefix token
+        ),
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    output_features={"targets": ENG_OUTPUT_FEATURES["targets"]},
+    metric_fns=[]
+)
+
+
+# Final pretraining task used in Raffel et al., 2019 adaptated to NCC
+dataset_name = 'NbAiLab/scandinavian'
+dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
+dataset_shapes = None
+TaskRegistry.add(
+    "scandinavian_ul2_scandvoc",
+    source=seqio.FunctionDataSource(
+        dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
+        splits=("train", "validation"),
+        caching_permitted=False,
+        num_input_examples=dataset_shapes,
+    ),
+    preprocessors=[
+        functools.partial(
+            target_to_key, key_map={
+                "inputs": None,
+                "targets": None,
+            }, target_key="targets"),
+        seqio.preprocessors.tokenize,
+        functools.partial(
+            ul2_objective,
+            shard_ds=False,
+            use_prefix_lm_task=True,  # use S-denoising
+            rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)]*len(R_DENOISER_SPAN_LENGTHS) + [
+                0.4 / len(X_DENOISER_SPAN_LENGTHS)]*len(X_DENOISER_SPAN_LENGTHS) + [0.2],  # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested at the paper chapter 4.5)
+            mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
+            noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
+            optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]*len(R_DENOISER_SPAN_LENGTHS) + [
+                X_DENOISER_TOKEN_PREFIX]*len(X_DENOISER_SPAN_LENGTHS) + [S_DENOISER_TOKEN_PREFIX],
+            reserved_for_packing=1,  # make room for task prefix token
+        ),
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    output_features={"targets": SCAND_OUTPUT_FEATURES["targets"]},
+    metric_fns=[]
+)
+
+
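As a quick check on the UL2 mixture defined above (illustrative, computed directly from the constants in tasks.py): each of the two R-denoiser settings is sampled at 0.4/2 = 0.2, each of the four X-denoiser settings at 0.4/4 = 0.1, and the S-denoiser at 0.2, so the rates sum to 1.0 with one task prefix per denoiser and the S-denoiser prefix last, as ul2_objective requires.

# Illustrative check of the mixing rates passed to ul2_objective:
R = [3.0, 8.0]                        # R_DENOISER_SPAN_LENGTHS
X = [3.0, 8.0, 64.0, 64.0]            # X_DENOISER_SPAN_LENGTHS
rates = [0.4 / len(R)] * len(R) + [0.4 / len(X)] * len(X) + [0.2]
print(rates)                          # [0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.2]
assert abs(sum(rates) - 1.0) < 1e-9   # 40% R + 40% X + 20% S
prefixes = ['[NLU]'] * len(R) + ['[NLG]'] * len(X) + ['[S2S]']
assert len(prefixes) == len(rates)    # one prefix per task, S-denoiser last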
train_exp1_base_scandvoc.sh → train_exp1_base_ul2_engvoc.sh RENAMED
@@ -1,8 +1,8 @@
-PROJECT_DIR=${HOME}"/models/long-t5x"
+PROJECT_DIR=${HOME}"/models/ul2-t5x"
 T5X_DIR="../../t5x" # directory where the t5x is cloned.
-MODEL_DIR="gs://nb-t5x-us-central2/exp1-t5-base-scandvoc"
+MODEL_DIR="gs://nb-t5x-us-central2/exp3-t5-base-ul2-engvoc"
 export PYTHONPATH=${PROJECT_DIR}
-MIXTURE_OR_TASK_NAME="ncc_scandinavian_span_corruption_stream_scandvoc"
+MIXTURE_OR_TASK_NAME="scandinavian_ul2_engvoc"

 python3 ${T5X_DIR}/t5x/train.py \
   --gin_search_paths=${PROJECT_DIR} \
train_exp2_base_ul2_scandvoc.sh ADDED
@@ -0,0 +1,12 @@
+PROJECT_DIR=${HOME}"/models/ul2-t5x"
+T5X_DIR="../../t5x" # directory where the t5x is cloned.
+MODEL_DIR="gs://nb-t5x-us-central2/exp2-t5-base-ul2-scandvoc"
+export PYTHONPATH=${PROJECT_DIR}
+MIXTURE_OR_TASK_NAME="scandinavian_ul2_scandvoc"
+
+python3 ${T5X_DIR}/t5x/train.py \
+  --gin_search_paths=${PROJECT_DIR} \
+  --gin_file="base_scandvoc.gin" \
+  --gin.MODEL_DIR="'${MODEL_DIR}'" \
+  --gin.MIXTURE_OR_TASK_NAME="'${MIXTURE_OR_TASK_NAME}'" \
+
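
A note on the quoting in these scripts (an observation about gin-config's flag convention, not something the commit documents): each --gin.PARAM value is parsed as a gin/Python expression, so string bindings need an inner layer of single quotes that survives after the shell strips the outer double quotes. A minimal sketch of what the flag parser ends up seeing:

import ast
# After bash strips the outer double quotes, gin receives this literal value:
argv_value = "'gs://nb-t5x-us-central2/exp2-t5-base-ul2-scandvoc'"
# gin parses it the way Python parses a string literal:
assert ast.literal_eval(argv_value) == "gs://nb-t5x-us-central2/exp2-t5-base-ul2-scandvoc"

Without the inner single quotes, gin would try to interpret the bucket path as an identifier and fail.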
train_exp3_base_span_engvoc.sh ADDED
@@ -0,0 +1,12 @@
+PROJECT_DIR=${HOME}"/models/ul2-t5x"
+T5X_DIR="../../t5x" # directory where the t5x is cloned.
+MODEL_DIR="gs://nb-t5x-us-central2/exp3-t5-base-span-engvoc"
+export PYTHONPATH=${PROJECT_DIR}
+MIXTURE_OR_TASK_NAME="scandinavian_span_engvoc"
+
+python3 ${T5X_DIR}/t5x/train.py \
+  --gin_search_paths=${PROJECT_DIR} \
+  --gin_file="base.gin" \
+  --gin.MODEL_DIR="'${MODEL_DIR}'" \
+  --gin.MIXTURE_OR_TASK_NAME="'${MIXTURE_OR_TASK_NAME}'" \
+
train_exp4_base_span_scandvoc.sh ADDED
@@ -0,0 +1,12 @@
+PROJECT_DIR=${HOME}"/models/ul2-t5x"
+T5X_DIR="../../t5x" # directory where the t5x is cloned.
+MODEL_DIR="gs://nb-t5x-us-central2/exp4-t5-base-span-scandvoc"
+export PYTHONPATH=${PROJECT_DIR}
+MIXTURE_OR_TASK_NAME="scandinavian_span_scandvoc"
+
+python3 ${T5X_DIR}/t5x/train.py \
+  --gin_search_paths=${PROJECT_DIR} \
+  --gin_file="base_scandvoc.gin" \
+  --gin.MODEL_DIR="'${MODEL_DIR}'" \
+  --gin.MIXTURE_OR_TASK_NAME="'${MIXTURE_OR_TASK_NAME}'" \
+
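
Taken together, the four train_exp*.sh scripts form a 2x2 grid of pretraining objective by vocabulary. An illustrative summary (task names taken from the scripts above; note the renamed exp1 script's MODEL_DIR still points at an exp3 bucket):

# Summary of the experiment grid in this commit (objective, vocabulary, task):
experiments = {
    "exp1": ("ul2",  "engvoc",   "scandinavian_ul2_engvoc"),    # train_exp1_base_ul2_engvoc.sh
    "exp2": ("ul2",  "scandvoc", "scandinavian_ul2_scandvoc"),  # train_exp2_base_ul2_scandvoc.sh
    "exp3": ("span", "engvoc",   "scandinavian_span_engvoc"),   # train_exp3_base_span_engvoc.sh
    "exp4": ("span", "scandvoc", "scandinavian_span_scandvoc"), # train_exp4_base_span_scandvoc.sh
}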
ul2_objective.py ADDED
@@ -0,0 +1,184 @@
+import functools
+import tensorflow as tf
+import seqio
+import t5.data
+from typing import Optional, Sequence
+
+# found this function and modified from https://github.com/GoogleCloudPlatform/t5x-on-vertex-ai/blob/main/tasks/custom_tasks.py#L78
+# UL2 paper appendix code missed this function
+def prepend_prompt(dataset: tf.data.Dataset,
+                   output_features: seqio.preprocessors.OutputFeaturesType,
+                   sequence_length: Optional[
+                       seqio.preprocessors.SequenceLengthType] = None,
+                   prompt_mode: str = "",
+                   key: str = "inputs",
+                   mode: str = "") -> tf.data.Dataset:
+    """Prepends a prompt at the beginning of an input sequence."""
+    del sequence_length
+    if prompt_mode and mode:
+        # output_features may not have inputs key
+        out_keys = list(output_features.keys())
+        prompt_tokens = output_features[out_keys[0]
+                                        ].vocabulary.encode_tf(prompt_mode)
+
+        def add_to_inputs(x):
+            x[key] = tf.concat([prompt_tokens, x[key]], axis=0)
+            return x
+
+        dataset = dataset.map(add_to_inputs)
+    return dataset
+
+# modified from t5.data.preprocessors because output_features may not have inputs key
+def split_tokens_to_inputs_length(dataset, sequence_length,
+                                  output_features, **kwargs):
+    max_tokens = sequence_length['inputs']
+    # output_features may not have inputs key
+    out_keys = list(output_features.keys())
+    if output_features[out_keys[0]].add_eos:
+        # Leave room to insert an EOS token.
+        max_tokens -= 1
+
+    return t5.data.preprocessors.split_tokens(dataset, max_tokens_per_segment=max_tokens, **kwargs)
+
+# modified from t5.data.preprocessors because output_features may not have inputs key
+def prefix_lm(dataset, sequence_length, output_features):
+    """Prefix language modeling objective used in Raffel et al. 2019."""
+    ds = dataset
+    ds = t5.data.preprocessors.select_random_chunk(ds, output_features=output_features,
+                                                   feature_key='targets', max_length=65536)
+    ds = split_tokens_to_inputs_length(ds, output_features=output_features,
+                                       sequence_length=sequence_length)
+    ds = t5.data.preprocessors.denoise(
+        ds,
+        output_features,
+        inputs_fn=t5.data.preprocessors.drop_nonnoise_tokens,
+        targets_fn=t5.data.preprocessors.drop_noise_tokens,
+        noise_density=0.5,
+        noise_mask_fn=t5.data.preprocessors.random_prefix_noise_mask,
+    )
+    return ds
+
+# copied from UL2 paper https://arxiv.org/pdf/2205.05131.pdf appendix chapter 9.2
+# note: modified to use the prefix_lm() from above instead of the default t5.data.preprocessors.prefix_lm() because output_features may not have inputs key
+def ul2_objective(dataset: tf.data.Dataset,
+                  sequence_length: seqio.preprocessors.SequenceLengthType,
+                  output_features: seqio.preprocessors.OutputFeaturesType,
+                  use_prefix_lm_task: bool = False,
+                  rates: Optional[Sequence[float]] = None,
+                  mean_noise_span_lengths: Sequence[float] = (3.0,),
+                  noise_densities: Sequence[float] = (0.15,),
+                  shard_ds: bool = True,
+                  optional_task_prefixes: Optional[Sequence[str]] = None,
+                  input_feature_key: str = "inputs",
+                  merge_examples_to_reduce_padding: bool = True,
+                  reserved_for_packing: bool = None,
+                  seed: int = 7) -> tf.data.Dataset:
+    """UL2-like pre-training objectives.
+    This preprocessor amounts to calling the 'span_corruption' function several
+    times with different values of 'noise_density' and 'mean_noise_span_length'.
+    We either shard or copy the dataset, then apply each function to each shard.
+    Add S-denoising (prefixLM) using use_prefix_lm_task.
+    Args:
+        dataset: A tf.data.Dataset with dictionaries containing the key 'input_feature_key'.
+        sequence_length: dict mapping of feature key to int length for that feature.
+        output_features: mapping of keys to features.
+        use_prefix_lm_task: <bool> If True, include PrefixLM in the task mix.
+        rates: <Optional<List<float>> List of rates per task. If None, tasks are sampled uniformly.
+        mean_noise_span_lengths: List of mean number of tokens per masked span per example.
+        noise_densities: List of what fraction of the tokens to mask.
+        shard_ds: <bool> If True, shard dataset per objective.
+        optional_task_prefixes: <Optional<list<str>> Strings to prepend for each corruption scheme. NOTE: If including prefixLM task, it must be the last prefix.
+        input_feature_key: which feature to use from the dataset as the input text tokens.
+        merge_examples_to_reduce_padding: if True, combines multiple input examples to reduce padding.
+        reserved_for_packing: if specified, reduces the desired inputs length by the specified amount to enable multiple examples to be packed together downstream.
+        seed: tf.int64 for controlling the random choice of spans.
+    Returns:
+        a dataset
+    """
+
+    if optional_task_prefixes:  # Ensure each task has a prefix.
+        num_tasks = len(noise_densities) + int(use_prefix_lm_task)
+        valid_number_of_prefixes = num_tasks == len(optional_task_prefixes)
+        if not valid_number_of_prefixes:
+            raise ValueError(
+                "Number of task prefixes must match number of tasks.")
+    inputs_length = sequence_length[input_feature_key]
+    input_lengths, targets_lengths = [], []
+    sequence_lengths = {x: y for x, y in sequence_length.items()}
+    if reserved_for_packing:
+        inputs_length -= reserved_for_packing
+        for x, y in sequence_length.items():
+            sequence_lengths[x] = y - reserved_for_packing
+    hyperparams = list(zip(mean_noise_span_lengths, noise_densities))
+    for mean_noise_span_length, noise_density in hyperparams:
+        input_length, targets_length = t5.data.preprocessors.random_spans_helper(
+            extra_tokens_per_span_inputs=1,
+            extra_tokens_per_span_targets=1,
+            inputs_length=inputs_length,
+            mean_noise_span_length=mean_noise_span_length,
+            noise_density=noise_density)
+        input_lengths.append(input_length)
+        targets_lengths.append(targets_length)
+
+        if sequence_length["targets"] < targets_length:
+            upper_bound = max(targets_lengths)
+            raise ValueError(
+                f'Expected max targets length for span corruption ({upper_bound}) is '
+                f'greater than configured targets length '
+                f"({sequence_length['targets']})")
+    ds = dataset
+    ds = t5.data.preprocessors.select_random_chunk(
+        ds,
+        output_features=output_features,
+        feature_key="targets",
+        max_length=65536)
+    if merge_examples_to_reduce_padding:
+        ds = t5.data.preprocessors.reduce_concat_tokens(
+            ds, feature_key="targets", batch_size=128)
+    num_shards = len(input_lengths) + int(use_prefix_lm_task)
+    if shard_ds:
+        ds_shards = [ds.shard(num_shards, i) for i in range(num_shards)]
+    else:
+        ds_shards = [ds for _ in range(num_shards)]
+    processed_ds = []
+    hyperparams = zip(input_lengths, hyperparams, range(num_shards))
+    for input_length, (noise_span_length, noise_density), i in hyperparams:
+        ds = ds_shards[i]
+        ds = t5.data.preprocessors.split_tokens(
+            ds,
+            feature_key="targets",
+            min_tokens_per_segment=None,
+            max_tokens_per_segment=input_length)
+        ds = t5.data.preprocessors.denoise(
+            ds,
+            output_features,
+            inputs_fn=t5.data.preprocessors.noise_span_to_unique_sentinel,
+            targets_fn=t5.data.preprocessors.nonnoise_span_to_unique_sentinel,
+            noise_density=noise_density,
+            noise_mask_fn=functools.partial(
+                t5.data.preprocessors.random_spans_noise_mask,
+                mean_noise_span_length=noise_span_length),
+            input_feature_key=input_feature_key)
+        if optional_task_prefixes:
+            ds = prepend_prompt(
+                ds,
+                output_features,
+                prompt_mode=optional_task_prefixes[i],
+                mode=optional_task_prefixes[i],
+                key=input_feature_key)
+        processed_ds.append(ds)
+    if use_prefix_lm_task:
+        ds = ds_shards[-1]
+        ds = prefix_lm(
+            ds, sequence_lengths, output_features)
+        if optional_task_prefixes:
+            ds = prepend_prompt(
+                ds,
+                output_features,
+                prompt_mode=optional_task_prefixes[-1],
+                mode=optional_task_prefixes[-1],
+                key=input_feature_key)
+        processed_ds.append(ds)
+    ds = tf.data.experimental.sample_from_datasets(processed_ds, rates, seed)
+    return ds
+
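
To make the length bookkeeping above concrete, a minimal sketch (not part of the commit; assumes the t5 package is installed) of how the same helper ul2_objective calls per denoiser turns one R-denoiser setting from tasks.py into a packed segment length and an expected targets length, using the 512-token TASK_FEATURE_LENGTHS from the gin configs:

# Illustrative use of t5.data.preprocessors.random_spans_helper:
import t5.data

tokens_length, targets_length = t5.data.preprocessors.random_spans_helper(
    inputs_length=512 - 1,        # reserved_for_packing=1 leaves room for the task prefix
    noise_density=0.15,           # R-denoiser corrupt rate
    mean_noise_span_length=3.0,   # R-denoiser mean span length
    extra_tokens_per_span_inputs=1,
    extra_tokens_per_span_targets=1)
print(tokens_length, targets_length)  # raw tokens per segment, max targets length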