| |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| import transformers |
| from transformers import HfArgumentParser, TrainingArguments |
|
|
| from flame.logging import get_logger |
|
|
# Module-level logger named after this module, obtained from flame's logging helper.
logger = get_logger(__name__)
|
|
|
|
@dataclass
class TrainingArguments(TrainingArguments):
    """Hugging Face ``TrainingArguments`` extended with model, tokenizer and dataset options.

    The class reuses the name of the imported base class. Python evaluates the
    base-class expression before rebinding the name, so from this point on the
    module-level ``TrainingArguments`` refers to this subclass.

    Every field carries its description in ``metadata["help"]``. Fields whose
    default is ``None`` are annotated ``Optional[...]`` so the annotation
    agrees with the default.
    """

    # --- model / tokenizer -------------------------------------------------
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
        },
    )
    tokenizer: str = field(
        default="fla-hub/gla-1.3B-100B",
        metadata={"help": "Name of the tokenizer to use."},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
    from_config: bool = field(
        default=True,
        metadata={"help": "Whether to initialize models from scratch."},
    )

    # --- dataset selection and loading ---------------------------------------
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of provided dataset(s) to use."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the cached tokenized dataset."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."},
    )

    # --- preprocessing / sequence layout -------------------------------------
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    buffer_size: int = field(
        default=2048,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    context_length: int = field(
        default=2048,
        metadata={"help": "The context length of the tokenized inputs in the dataset."},
    )
    varlen: bool = field(
        default=False,
        metadata={"help": "Enable training with variable length inputs."},
    )
|
|
|
|
def get_train_args():
    """Parse the command line into a ``TrainingArguments`` instance.

    Unrecognised arguments are treated as an error: the parser help and the
    offending arguments are printed, then a ``ValueError`` is raised. On
    success, transformers logging is configured when ``should_log`` is set
    on the parsed arguments, and the random seed is set from ``seed``.

    Returns:
        The parsed ``TrainingArguments`` object.

    Raises:
        ValueError: If any command-line argument was not consumed by the parser.
    """
    arg_parser = HfArgumentParser(TrainingArguments)
    train_args, leftover = arg_parser.parse_args_into_dataclasses(return_remaining_strings=True)

    # Reject anything the parser did not recognise instead of silently ignoring it.
    if leftover:
        print(arg_parser.format_help())
        print(f"Got unknown args, potentially deprecated arguments: {leftover}")
        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {leftover}")

    # Configure transformers' logging only when the parsed arguments ask for it.
    if train_args.should_log:
        hf_logging = transformers.utils.logging
        hf_logging.set_verbosity(train_args.get_process_log_level())
        hf_logging.enable_default_handler()
        hf_logging.enable_explicit_format()

    # Seed the random number generators from the parsed arguments.
    transformers.set_seed(train_args.seed)
    return train_args
|
|