add stream
Browse files
- run_mlm_flax_stream.py +11 -28
run_mlm_flax_stream.py
CHANGED
|
@@ -308,7 +308,7 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
|
|
| 308 |
while i < num_total_tokens:
|
| 309 |
tokenized_samples = next(train_iterator)
|
| 310 |
i += len(tokenized_samples["input_ids"])
|
| 311 |
-
|
| 312 |
# concatenate tokenized samples to list
|
| 313 |
samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
|
| 314 |
|
|
@@ -451,30 +451,13 @@ if __name__ == "__main__":
|
|
| 451 |
# 'text' is found. You can easily tweak this behavior (see below).
|
| 452 |
if data_args.dataset_name is not None:
|
| 453 |
# Downloading and loading a dataset from the hub.
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
data_args.dataset_name,
|
| 462 |
-
data_args.dataset_config_name,
|
| 463 |
-
cache_dir=model_args.cache_dir,
|
| 464 |
-
streaming=True,
|
| 465 |
-
split="train",
|
| 466 |
-
)
|
| 467 |
-
except Exception as exc:
|
| 468 |
-
logger.warning(
|
| 469 |
-
f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
|
| 470 |
-
)
|
| 471 |
-
dataset = load_dataset(
|
| 472 |
-
data_args.dataset_name,
|
| 473 |
-
data_args.dataset_config_name,
|
| 474 |
-
cache_dir=model_args.cache_dir,
|
| 475 |
-
streaming=True,
|
| 476 |
-
split="train",
|
| 477 |
-
)
|
| 478 |
|
| 479 |
if model_args.config_name:
|
| 480 |
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
|
@@ -505,13 +488,13 @@ if __name__ == "__main__":
|
|
| 505 |
return tokenizer(
|
| 506 |
examples[data_args.text_column_name],
|
| 507 |
max_length=512,
|
| 508 |
-
truncation=True,
|
| 509 |
return_special_tokens_mask=True
|
| 510 |
)
|
| 511 |
|
| 512 |
tokenized_datasets = dataset.map(
|
| 513 |
tokenize_function,
|
| 514 |
batched=True,
|
|
|
|
| 515 |
)
|
| 516 |
|
| 517 |
shuffle_seed = training_args.seed
|
|
@@ -524,8 +507,8 @@ if __name__ == "__main__":
|
|
| 524 |
# Enable Weight&Biases
|
| 525 |
import wandb
|
| 526 |
wandb.init(
|
| 527 |
-
entity='
|
| 528 |
-
project='roberta-
|
| 529 |
sync_tensorboard=True,
|
| 530 |
)
|
| 531 |
wandb.config.update(training_args)
|
|
|
|
| 308 |
while i < num_total_tokens:
|
| 309 |
tokenized_samples = next(train_iterator)
|
| 310 |
i += len(tokenized_samples["input_ids"])
|
| 311 |
+
|
| 312 |
# concatenate tokenized samples to list
|
| 313 |
samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
|
| 314 |
|
|
|
|
| 451 |
# 'text' is found. You can easily tweak this behavior (see below).
|
| 452 |
if data_args.dataset_name is not None:
|
| 453 |
# Downloading and loading a dataset from the hub.
|
| 454 |
+
dataset = load_dataset(
|
| 455 |
+
data_args.dataset_name,
|
| 456 |
+
data_args.dataset_config_name,
|
| 457 |
+
cache_dir=model_args.cache_dir,
|
| 458 |
+
streaming=True,
|
| 459 |
+
split="train",
|
| 460 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
if model_args.config_name:
|
| 463 |
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
|
|
|
| 488 |
return tokenizer(
|
| 489 |
examples[data_args.text_column_name],
|
| 490 |
max_length=512,
|
|
|
|
| 491 |
return_special_tokens_mask=True
|
| 492 |
)
|
| 493 |
|
| 494 |
tokenized_datasets = dataset.map(
|
| 495 |
tokenize_function,
|
| 496 |
batched=True,
|
| 497 |
+
remove_columns=list(dataset.features.keys()),
|
| 498 |
)
|
| 499 |
|
| 500 |
shuffle_seed = training_args.seed
|
|
|
|
| 507 |
# Enable Weight&Biases
|
| 508 |
import wandb
|
| 509 |
wandb.init(
|
| 510 |
+
entity='wandb',
|
| 511 |
+
project='hf-flax-bertin-roberta-es',
|
| 512 |
sync_tensorboard=True,
|
| 513 |
)
|
| 514 |
wandb.config.update(training_args)
|