test
Browse files- base.gin +1 -1
- batch_nynorsk_NCC_base.sh +11 -0
- tasks.py +42 -6
- train_base.sh → train_exp1_base_engvoc.sh +4 -1
- train_exp1_base_scandvoc.sh +12 -0
base.gin
CHANGED
|
@@ -8,7 +8,7 @@ include 'pretrain_cont.gin'
|
|
| 8 |
import t5.data.mixtures
|
| 9 |
import tasks
|
| 10 |
|
| 11 |
-
MIXTURE_OR_TASK_NAME =
|
| 12 |
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
|
| 13 |
TRAIN_STEPS = 1_700_000
|
| 14 |
DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recommends this.
|
|
|
|
| 8 |
import t5.data.mixtures
|
| 9 |
import tasks
|
| 10 |
|
| 11 |
+
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
|
| 12 |
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
|
| 13 |
TRAIN_STEPS = 1_700_000
|
| 14 |
DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recommends this.
|
batch_nynorsk_NCC_base.sh
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PROJECT_DIR=${HOME}"/models/t5-nynorsk-oversetter"
|
| 2 |
+
export PYTHONPATH=${PROJECT_DIR}
|
| 3 |
+
INITIAL_CHECKPOINT_PATH=\"gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_base/checkpoint_1500000\"
|
| 4 |
+
TRAIN_STEPS=1505000
|
| 5 |
+
|
| 6 |
+
python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v1\" &&
|
| 7 |
+
python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v2\" &&
|
| 8 |
+
python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v3\" &&
|
| 9 |
+
python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v4\" &&
|
| 10 |
+
python3 ../../t5x/t5x/train.py --gin_search_paths="./" --gin.TRAIN_STEPS=${TRAIN_STEPS} --gin_file="finetune_translate_base.gin" --gin.INITIAL_CHECKPOINT_PATH=${INITIAL_CHECKPOINT_PATH} --gin.MIXTURE_OR_TASK_NAME=\"translate\" --gin.MODEL_DIR=\"gs://nb-t5x-us-central2/finetuned/nynorsk_NCC_base_v5\"
|
| 11 |
+
|
tasks.py
CHANGED
|
@@ -9,16 +9,24 @@ from t5.evaluation import metrics
|
|
| 9 |
from seqio import FunctionDataSource, utils
|
| 10 |
|
| 11 |
TaskRegistry = seqio.TaskRegistry
|
| 12 |
-
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
"inputs": seqio.Feature(
|
| 16 |
-
vocabulary=
|
| 17 |
required=False),
|
| 18 |
"targets": seqio.Feature(
|
| 19 |
-
vocabulary=
|
| 20 |
}
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
|
| 24 |
dataset = load_dataset(**dataset_params)
|
|
@@ -52,7 +60,35 @@ dataset_name = 'NbAiLab/scandinavian'
|
|
| 52 |
dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
|
| 53 |
dataset_shapes = None
|
| 54 |
TaskRegistry.add(
|
| 55 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
source=seqio.FunctionDataSource(
|
| 57 |
dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
|
| 58 |
splits=("train", "validation"),
|
|
@@ -70,7 +106,7 @@ TaskRegistry.add(
|
|
| 70 |
preprocessors.span_corruption,
|
| 71 |
seqio.preprocessors.append_eos_after_trim,
|
| 72 |
],
|
| 73 |
-
output_features={"targets":
|
| 74 |
metric_fns=[]
|
| 75 |
)
|
| 76 |
|
|
|
|
| 9 |
from seqio import FunctionDataSource, utils
|
| 10 |
|
| 11 |
TaskRegistry = seqio.TaskRegistry
|
| 12 |
+
scand_vocabulary=seqio.SentencePieceVocabulary('gs://nb-t5/t5/vocabs/wikipedia/no-da-en-sv-nn-is_32000_unigram.sp.model', extra_ids=100)
|
| 13 |
+
eng_vocabulary=seqio.SentencePieceVocabulary('gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model', extra_ids=0)
|
| 14 |
|
| 15 |
+
SCAND_OUTPUT_FEATURES = {
|
| 16 |
"inputs": seqio.Feature(
|
| 17 |
+
vocabulary=scand_vocabulary, add_eos=True,
|
| 18 |
required=False),
|
| 19 |
"targets": seqio.Feature(
|
| 20 |
+
vocabulary=scand_vocabulary, add_eos=True)
|
| 21 |
}
|
| 22 |
|
| 23 |
+
ENG_OUTPUT_FEATURES = {
|
| 24 |
+
"inputs": seqio.Feature(
|
| 25 |
+
vocabulary=eng_vocabulary, add_eos=True,
|
| 26 |
+
required=False),
|
| 27 |
+
"targets": seqio.Feature(
|
| 28 |
+
vocabulary=eng_vocabulary, add_eos=True)
|
| 29 |
+
}
|
| 30 |
|
| 31 |
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
|
| 32 |
dataset = load_dataset(**dataset_params)
|
|
|
|
| 60 |
dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
|
| 61 |
dataset_shapes = None
|
| 62 |
TaskRegistry.add(
|
| 63 |
+
"ncc_scandinavian_span_corruption_stream_engvoc",
|
| 64 |
+
source=seqio.FunctionDataSource(
|
| 65 |
+
dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
|
| 66 |
+
splits=("train", "validation"),
|
| 67 |
+
caching_permitted=False,
|
| 68 |
+
num_input_examples=dataset_shapes,
|
| 69 |
+
),
|
| 70 |
+
preprocessors=[
|
| 71 |
+
functools.partial(
|
| 72 |
+
target_to_key, key_map={
|
| 73 |
+
"inputs": None,
|
| 74 |
+
"targets": None,
|
| 75 |
+
}, target_key="targets"),
|
| 76 |
+
seqio.preprocessors.tokenize,
|
| 77 |
+
# seqio.CacheDatasetPlaceholder(),
|
| 78 |
+
preprocessors.span_corruption,
|
| 79 |
+
seqio.preprocessors.append_eos_after_trim,
|
| 80 |
+
],
|
| 81 |
+
output_features={"targets": ENG_OUTPUT_FEATURES["targets"]},
|
| 82 |
+
metric_fns=[]
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Final pretraining task used in Raffel et al., 2019 adaptated to NCC
|
| 87 |
+
dataset_name = 'NbAiLab/scandinavian'
|
| 88 |
+
dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
|
| 89 |
+
dataset_shapes = None
|
| 90 |
+
TaskRegistry.add(
|
| 91 |
+
"ncc_scandinavian_span_corruption_stream_scandvoc",
|
| 92 |
source=seqio.FunctionDataSource(
|
| 93 |
dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
|
| 94 |
splits=("train", "validation"),
|
|
|
|
| 106 |
preprocessors.span_corruption,
|
| 107 |
seqio.preprocessors.append_eos_after_trim,
|
| 108 |
],
|
| 109 |
+
output_features={"targets": SCAND_OUTPUT_FEATURES["targets"]},
|
| 110 |
metric_fns=[]
|
| 111 |
)
|
| 112 |
|
train_base.sh → train_exp1_base_engvoc.sh
RENAMED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
PROJECT_DIR=${HOME}"/models/long-t5x"
|
| 2 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
| 3 |
-
MODEL_DIR="gs://nb-t5x-us-central2/
|
| 4 |
export PYTHONPATH=${PROJECT_DIR}
|
|
|
|
| 5 |
|
| 6 |
python3 ${T5X_DIR}/t5x/train.py \
|
| 7 |
--gin_search_paths=${PROJECT_DIR} \
|
| 8 |
--gin_file="base.gin" \
|
| 9 |
--gin.MODEL_DIR="'${MODEL_DIR}'" \
|
|
|
|
|
|
|
|
|
| 1 |
PROJECT_DIR=${HOME}"/models/long-t5x"
|
| 2 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
| 3 |
+
MODEL_DIR="gs://nb-t5x-us-central2/exp1-t5-base-engvoc"
|
| 4 |
export PYTHONPATH=${PROJECT_DIR}
|
| 5 |
+
MIXTURE_OR_TASK_NAME="ncc_scandinavian_span_corruption_stream_engvoc"
|
| 6 |
|
| 7 |
python3 ${T5X_DIR}/t5x/train.py \
|
| 8 |
--gin_search_paths=${PROJECT_DIR} \
|
| 9 |
--gin_file="base.gin" \
|
| 10 |
--gin.MODEL_DIR="'${MODEL_DIR}'" \
|
| 11 |
+
--gin.MIXTURE_OR_TASK_NAME="'${MIXTURE_OR_TASK_NAME}'" \
|
| 12 |
+
|
train_exp1_base_scandvoc.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PROJECT_DIR=${HOME}"/models/long-t5x"
|
| 2 |
+
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
| 3 |
+
MODEL_DIR="gs://nb-t5x-us-central2/exp1-t5-base-scandvoc"
|
| 4 |
+
export PYTHONPATH=${PROJECT_DIR}
|
| 5 |
+
MIXTURE_OR_TASK_NAME="ncc_scandinavian_span_corruption_stream_scandvoc"
|
| 6 |
+
|
| 7 |
+
python3 ${T5X_DIR}/t5x/train.py \
|
| 8 |
+
--gin_search_paths=${PROJECT_DIR} \
|
| 9 |
+
--gin_file="base.gin" \
|
| 10 |
+
--gin.MODEL_DIR="'${MODEL_DIR}'" \
|
| 11 |
+
--gin.MIXTURE_OR_TASK_NAME="'${MIXTURE_OR_TASK_NAME}'" \
|
| 12 |
+
|