update
Browse files- T5Trainer.py +16 -18
T5Trainer.py
CHANGED
|
@@ -27,33 +27,31 @@ def create_arg_parser():
|
|
| 27 |
parser.add_argument("-tf", "--transformer", default="google/byt5-small",
|
| 28 |
type=str, help="this argument takes the pretrained "
|
| 29 |
"language model URL from HuggingFace "
|
| 30 |
-
"default is
|
| 31 |
"HuggingFace for full URL")
|
| 32 |
parser.add_argument("-c_model", "--custom_model",
|
| 33 |
type=str, help="this argument takes a custom "
|
| 34 |
"pretrained checkpoint")
|
| 35 |
-
parser.add_argument("-train", "--train_data", default='
|
| 36 |
type=str, help="this argument takes the train "
|
| 37 |
"data file as input")
|
| 38 |
-
parser.add_argument("-dev", "--dev_data", default='
|
| 39 |
-
help="this argument takes the dev data file
|
| 40 |
-
|
| 41 |
-
parser.add_argument("-
|
| 42 |
-
help="class weights for custom loss calculation")
|
| 43 |
-
parser.add_argument("-lr", "--learn_rate", default=1e-3, type=float,
|
| 44 |
help="Set a custom learn rate for "
|
| 45 |
-
"the
|
| 46 |
-
parser.add_argument("-bs", "--batch_size", default=
|
| 47 |
help="Set a custom batch size for "
|
| 48 |
"the pretrained language model, default is 8")
|
| 49 |
-
parser.add_argument("-sl_train", "--sequence_length_train", default=
|
| 50 |
type=int, help="Set a custom maximum sequence length"
|
| 51 |
"for the pretrained language model,"
|
| 52 |
-
"default is
|
| 53 |
-
parser.add_argument("-sl_dev", "--sequence_length_dev", default=
|
| 54 |
type=int, help="Set a custom maximum sequence length"
|
| 55 |
"for the pretrained language model,"
|
| 56 |
-
"default is
|
| 57 |
parser.add_argument("-ep", "--epochs", default=1, type=int,
|
| 58 |
help="This argument selects the amount of epochs "
|
| 59 |
"to run the model with, default is 1 epoch")
|
|
@@ -61,7 +59,7 @@ def create_arg_parser():
|
|
| 61 |
help="Set the value to monitor for earlystopping")
|
| 62 |
parser.add_argument("-es_p", "--early_stop_patience", default=2,
|
| 63 |
type=int, help="Set the patience value for "
|
| 64 |
-
"earlystopping")
|
| 65 |
args = parser.parse_args()
|
| 66 |
return args
|
| 67 |
|
|
@@ -131,7 +129,7 @@ def create_data(data):
|
|
| 131 |
|
| 132 |
|
| 133 |
def split_sent(data, max_length):
|
| 134 |
-
'''Splitting sentences if longer than given
|
| 135 |
short_sent = []
|
| 136 |
long_sent = []
|
| 137 |
for n in data:
|
|
@@ -159,7 +157,7 @@ def split_sent(data, max_length):
|
|
| 159 |
|
| 160 |
|
| 161 |
def preprocess_function(tk, s, t):
|
| 162 |
-
'''tokenizing
|
| 163 |
model_inputs = tk(s)
|
| 164 |
|
| 165 |
with tk.as_target_tokenizer():
|
|
@@ -195,7 +193,7 @@ def convert_tok(tok, sl):
|
|
| 195 |
|
| 196 |
|
| 197 |
def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train, dev):
|
| 198 |
-
'''Finetune and save a given T5 version with given
|
| 199 |
print('Training model: {}\nWith parameters:\nLearn rate: {}, '
|
| 200 |
'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
|
| 201 |
'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))
|
|
|
|
| 27 |
parser.add_argument("-tf", "--transformer", default="google/byt5-small",
|
| 28 |
type=str, help="this argument takes the pretrained "
|
| 29 |
"language model URL from HuggingFace "
|
| 30 |
+
"default is ByT5-small, please visit "
|
| 31 |
"HuggingFace for full URL")
|
| 32 |
parser.add_argument("-c_model", "--custom_model",
|
| 33 |
type=str, help="this argument takes a custom "
|
| 34 |
"pretrained checkpoint")
|
| 35 |
+
parser.add_argument("-train", "--train_data", default='training_data10k.txt',
|
| 36 |
type=str, help="this argument takes the train "
|
| 37 |
"data file as input")
|
| 38 |
+
parser.add_argument("-dev", "--dev_data", default='validation_data.txt',
|
| 39 |
+
type=str, help="this argument takes the dev data file "
|
| 40 |
+
"as input")
|
| 41 |
+
parser.add_argument("-lr", "--learn_rate", default=5e-5, type=float,
|
|
|
|
|
|
|
| 42 |
help="Set a custom learn rate for "
|
| 43 |
+
"the model, default is 5e-5")
|
| 44 |
+
parser.add_argument("-bs", "--batch_size", default=8, type=int,
|
| 45 |
help="Set a custom batch size for "
|
| 46 |
"the pretrained language model, default is 8")
|
| 47 |
+
parser.add_argument("-sl_train", "--sequence_length_train", default=155,
|
| 48 |
type=int, help="Set a custom maximum sequence length"
|
| 49 |
"for the pretrained language model,"
|
| 50 |
+
"default is 155")
|
| 51 |
+
parser.add_argument("-sl_dev", "--sequence_length_dev", default=155,
|
| 52 |
type=int, help="Set a custom maximum sequence length"
|
| 53 |
"for the pretrained language model,"
|
| 54 |
+
"default is 155")
|
| 55 |
parser.add_argument("-ep", "--epochs", default=1, type=int,
|
| 56 |
help="This argument selects the amount of epochs "
|
| 57 |
"to run the model with, default is 1 epoch")
|
|
|
|
| 59 |
help="Set the value to monitor for earlystopping")
|
| 60 |
parser.add_argument("-es_p", "--early_stop_patience", default=2,
|
| 61 |
type=int, help="Set the patience value for "
|
| 62 |
+
"earlystopping, default is 2")
|
| 63 |
args = parser.parse_args()
|
| 64 |
return args
|
| 65 |
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
def split_sent(data, max_length):
|
| 132 |
+
'''Splitting sentences if longer than given max_length value'''
|
| 133 |
short_sent = []
|
| 134 |
long_sent = []
|
| 135 |
for n in data:
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
def preprocess_function(tk, s, t):
|
| 160 |
+
'''tokenizing text and labels'''
|
| 161 |
model_inputs = tk(s)
|
| 162 |
|
| 163 |
with tk.as_target_tokenizer():
|
|
|
|
| 193 |
|
| 194 |
|
| 195 |
def train_model(model_name, lr, bs, sl_train, sl_dev, ep, es, es_p, train, dev):
|
| 196 |
+
'''Finetune and save a given T5 version with given parameters'''
|
| 197 |
print('Training model: {}\nWith parameters:\nLearn rate: {}, '
|
| 198 |
'Batch size: {}\nSequence length train: {}, sequence length dev: {}\n'
|
| 199 |
'Epochs: {}'.format(model_name, lr, bs, sl_train, sl_dev, ep))
|