add support for v3-32
run_mlm_flax_stream.py  CHANGED  (+5 −25)
@@ -262,29 +262,6 @@ class FlaxDataCollatorForLanguageModeling:
         return inputs, labels
 
 
-@dataclass
-class SamplingArguments:
-    """
-    Arguments pertaining to how to perform sampling of the dataset.
-    """
-
-    perplexity_model: Optional[str] = field(
-        default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
-    )
-    sampling_method: Optional[str] = field(
-        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
-    )
-    sampling_factor: Optional[float] = field(
-        default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
-    )
-    boundaries: Optional[str] = field(
-        default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
-    )
-
-    def __post_init__(self):
-        self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]
-
-
 def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
     num_samples = len(samples_idx)
     samples_to_remove = num_samples % batch_size
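For context on the helper shown at the end of this hunk: the visible lines drop whatever indices do not fill a complete batch (num_samples % batch_size). A minimal stand-alone sketch of that pattern, assuming NumPy index arrays; the function body below is illustrative, not copied from this file:

import numpy as np

def generate_batch_splits_sketch(samples_idx: np.ndarray, batch_size: int) -> list:
    # Drop trailing indices that do not fill a complete batch, then split the
    # remaining indices into equally sized batches.
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    return np.split(samples_idx, len(samples_idx) // batch_size)

# 10 indices with batch_size=4 -> two batches of 4; the last 2 indices are dropped.
print(generate_batch_splits_sketch(np.arange(10), 4))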
@@ -310,7 +287,9 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
             i += len(tokenized_samples["input_ids"])
 
             # concatenate tokenized samples to list
-            samples = {
+            samples = {
+                k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
+            }
 
         # Concatenated tokens are split to lists of length `max_seq_length`.
         # Note that remainedr of % max_seq_length are thrown away.
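The comments in this hunk describe the grouping step: the concatenated token lists are cut into chunks of max_seq_length, and any remainder shorter than a full chunk is thrown away. A minimal sketch of that behaviour, using illustrative names rather than the script's own helper:

def group_samples_sketch(samples: dict, max_seq_length: int) -> list:
    # Cut concatenated token lists into chunks of `max_seq_length`;
    # anything shorter than a full chunk at the end is discarded.
    total_length = (len(samples["input_ids"]) // max_seq_length) * max_seq_length
    return [
        {k: v[start : start + max_seq_length] for k, v in samples.items()}
        for start in range(0, total_length, max_seq_length)
    ]

concatenated = {
    "input_ids": list(range(10)),
    "attention_mask": [1] * 10,
    "special_tokens_mask": [0] * 10,
}
# max_seq_length=4 -> two groups of 4 tokens; the trailing 2 tokens are dropped.
print(group_samples_sketch(concatenated, 4))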
@@ -404,7 +383,7 @@ if __name__ == "__main__":
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
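The corrected parser call follows the usual transformers pattern: a lone .json argument is parsed as a config file, otherwise arguments come from the command line. A hedged sketch of that pattern, with a toy dataclass standing in for the script's own ModelArguments and DataTrainingArguments:

import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, TrainingArguments

@dataclass
class ToyModelArguments:  # illustrative stand-in for the script's ModelArguments
    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Checkpoint to initialize from."}
    )

parser = HfArgumentParser((ToyModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # A single JSON path is treated as a config file instead of CLI flags.
    model_args, training_args = parser.parse_json_file(json_file=sys.argv[1])
else:
    model_args, training_args = parser.parse_args_into_dataclasses()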
@@ -528,6 +507,7 @@ if __name__ == "__main__":
 
     # Data collator
     # This one will take care of randomly masking the tokens.
+    print("DATA COLLATOR")
     data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
 
     # Initialize our training
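The collator instantiated here is the FlaxDataCollatorForLanguageModeling defined earlier in the file; as the comment says, it randomly masks tokens. As a rough illustration of the conventional BERT-style recipe such a collator follows (the 80/10/10 split below is the standard choice, not read from this diff):

import numpy as np

def mask_tokens_sketch(input_ids, mask_token_id, vocab_size, mlm_probability=0.15, seed=0):
    # BERT-style masking: ~mlm_probability of positions become prediction targets;
    # of those, 80% are replaced by [MASK], 10% by a random token, 10% kept as-is.
    rng = np.random.default_rng(seed)
    input_ids = np.array(input_ids)
    labels = input_ids.copy()

    masked = rng.random(input_ids.shape) < mlm_probability
    labels[~masked] = -100  # loss is only computed on masked positions

    use_mask_token = masked & (rng.random(input_ids.shape) < 0.8)
    input_ids = np.where(use_mask_token, mask_token_id, input_ids)

    use_random_token = masked & ~use_mask_token & (rng.random(input_ids.shape) < 0.5)
    input_ids = np.where(use_random_token, rng.integers(0, vocab_size, input_ids.shape), input_ids)
    return input_ids, labels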