import logging
import sys
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


@dataclass
class ScriptArguments:
    """
    Arguments that are not covered by TrainingArguments.
    """

    resume_from_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "Path to a checkpoint to resume training from."}
    )
    dataset_id: str = field(
        default=None, metadata={"help": "The repository id of the dataset to use (via the datasets library)."}
    )
    tokenizer_id: str = field(
        default=None, metadata={"help": "The repository id of the tokenizer to use (via AutoTokenizer)."}
    )
    repository_id: str = field(
        default=None,
        metadata={"help": "The repository id where the model will be saved or loaded from for further pre-training."},
    )
    model_config_id: Optional[str] = field(
        default="bert-base-uncased", metadata={"help": "Pretrained config name or path if not the same as model_name."}
    )
    per_device_train_batch_size: Optional[int] = field(
        default=16,
        metadata={"help": "The batch size per device used during training."},
    )
    max_steps: Optional[int] = field(
        default=1_000_000,
        metadata={"help": "The number of training steps to perform."},
    )
    learning_rate: Optional[float] = field(default=1e-4, metadata={"help": "Learning rate for training."})
    mlm_probability: Optional[float] = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for the masked language modeling loss."}
    )


def run_mlm():
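    # parse script arguments passed on the command line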
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]
    logger.info(f"Script parameters {script_args}")

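    # set seed for reproducibility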
    seed = 34
    set_seed(seed)

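    # load the (already tokenized) dataset via the datasets library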
    train_dataset = load_dataset(script_args.dataset_id, split="train")

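    # load the trained tokenizer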
    tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_id)

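    # initialize the model from the config only, i.e. train from scratch without pretrained weights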
    logger.info("Training new model from scratch")
    config = AutoConfig.from_pretrained(script_args.model_config_id)
    model = AutoModelForMaskedLM.from_config(config)

    logger.info(f"Resizing token embedding to {len(tokenizer)}")
    model.resize_token_embeddings(len(tokenizer))

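    # data collator that randomly masks tokens for the MLM objective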
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=script_args.mlm_probability, pad_to_multiple_of=8
    )

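    # define the training arguments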
    training_args = TrainingArguments(
        output_dir=script_args.repository_id,
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        learning_rate=script_args.learning_rate,
        seed=seed,
        max_steps=script_args.max_steps,
        # logging & saving strategies
        logging_dir=f"{script_args.repository_id}/logs",
        logging_strategy="steps",
        logging_steps=100,
        save_strategy="steps",
        save_steps=5_000,
        save_total_limit=2,
        report_to="tensorboard",
        # distributed training config
        ddp_find_unused_parameters=True,
    )

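    # create the Trainer instance that handles the training loop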
    trainer = Trainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

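    # start training, resuming from the given checkpoint if one was provided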
    trainer.train(resume_from_checkpoint=script_args.resume_from_checkpoint)


if __name__ == "__main__":
    run_mlm()