updates
Browse files
- run.sh +2 -2
- run_mlm_flax_stream.py +6 -2
run.sh
CHANGED
|
@@ -6,8 +6,8 @@ python run_mlm_flax_stream.py \
|
|
| 6 |
--dataset_name="NbAiLab/scandinavian" \
|
| 7 |
--max_seq_length="512" \
|
| 8 |
--weight_decay="0.01" \
|
| 9 |
-
--per_device_train_batch_size="
|
| 10 |
-
--per_device_eval_batch_size="
|
| 11 |
--learning_rate="1e-4" \
|
| 12 |
--warmup_steps="10000" \
|
| 13 |
--overwrite_output_dir \
|
|
|
|
| 6 |
--dataset_name="NbAiLab/scandinavian" \
|
| 7 |
--max_seq_length="512" \
|
| 8 |
--weight_decay="0.01" \
|
| 9 |
+
--per_device_train_batch_size="12" \
|
| 10 |
+
--per_device_eval_batch_size="12" \
|
| 11 |
--learning_rate="1e-4" \
|
| 12 |
--warmup_steps="10000" \
|
| 13 |
--overwrite_output_dir \
|
run_mlm_flax_stream.py
CHANGED
|
@@ -395,11 +395,11 @@ if __name__ == "__main__":
|
|
| 395 |
|
| 396 |
if model_args.tokenizer_name:
|
| 397 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 398 |
-
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
| 399 |
)
|
| 400 |
elif model_args.model_name_or_path:
|
| 401 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 402 |
-
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
| 403 |
)
|
| 404 |
else:
|
| 405 |
raise ValueError(
|
|
@@ -451,6 +451,10 @@ if __name__ == "__main__":
|
|
| 451 |
num_epochs = int(training_args.num_train_epochs)
|
| 452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
| 453 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
# define number steps per stream epoch
|
| 456 |
num_train_steps = data_args.num_train_steps
|
|
|
|
| 395 |
|
| 396 |
if model_args.tokenizer_name:
|
| 397 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 398 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer,model_max_length=512
|
| 399 |
)
|
| 400 |
elif model_args.model_name_or_path:
|
| 401 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 402 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer,model_max_length=512
|
| 403 |
)
|
| 404 |
else:
|
| 405 |
raise ValueError(
|
|
|
|
| 451 |
num_epochs = int(training_args.num_train_epochs)
|
| 452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
| 453 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
| 454 |
+
|
| 455 |
+
print("***************************")
|
| 456 |
+
print(f"Train Batch Size: {train_batch_size}")
|
| 457 |
+
print("***************************")
|
| 458 |
|
| 459 |
# define number steps per stream epoch
|
| 460 |
num_train_steps = data_args.num_train_steps
|