# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Argument utils

Dataclass definitions for model / data / training / inference arguments, plus
`parse_args`, which fills them from the command line and an optional
YAML/JSON config file, and `save_args`, which dumps them back to YAML.
"""

import argparse
import json
import math
import os
import sys
import types
from collections import defaultdict
from dataclasses import MISSING, asdict, dataclass, field, fields
from enum import Enum
from inspect import isclass
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, get_type_hints

import yaml

from . import logging


T = TypeVar("T")

logger = logging.get_logger(__name__)


@dataclass
class ModelArguments:
    """Model-related arguments (paths, encoder/decoder configs and feature switches)."""

    config_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the model config. Defaults to `model_path`."},
    )
    model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the pre-trained model. If unspecified, use random init."},
    )
    tokenizer_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the tokenizer. Defaults to `config_path`."},
    )
    vlm_repo_id: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the VLM. Defaults to None."},
    )
    post_training: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use post training."},
    )
    vocab_size: Optional[int] = field(
        default=0,
        metadata={"help": "Vocab size. 257152 is for paligemma in initial pi0."},
    )
    incremental_training: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to apply incremental training."},
    )
    depth_incremental_training: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to re-init depth_align_head."},
    )
    adanorm_time: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to apply extra time embed to ada_norm in expert."},
    )
    encoders: Dict[Literal["image"], Dict[str, str]] = field(
        default_factory=dict,
        metadata={"help": "Multimodal encoder config and weights."},
    )
    decoders: Dict[Literal["image"], Dict[str, str]] = field(
        default_factory=dict,
        metadata={"help": "Multimodal decoder config and weights."},
    )
    input_encoder: Literal["encoder", "decoder"] = field(
        default="encoder",
        metadata={"help": "Use encoder to encode input images or use decoder.encoder to encode input images."},
    )
    output_encoder: Literal["encoder", "decoder"] = field(
        default="decoder",
        metadata={"help": "Use encoder to encode output images or use decoder.encoder to encode output images."},
    )
    encode_target: bool = field(
        default=False,
        metadata={"help": "Whether to encode target with decoder. Only supports stable diffusion as decoder."},
    )
    attn_implementation: Optional[Literal["eager", "sdpa", "flash_attention_2", "flash_attention_3"]] = field(
        default="flash_attention_2",
        metadata={"help": "Attention implementation to use."},
    )
    moe_implementation: Optional[Literal[None, "eager", "fused"]] = field(
        default=None,
        metadata={"help": "MoE implementation to use."},
    )
    basic_modules: Optional[List[str]] = field(
        default_factory=list,
        metadata={"help": "Basic modules beyond model._no_split_modules to be sharded in FSDP."},
    )
    force_use_huggingface: bool = field(
        default=False,
        metadata={"help": "Force loading model from huggingface."},
    )
    use_lm_head: bool = field(
        default=False,
        metadata={"help": "Whether to use lm_head."},
    )
    split_gate_liner: bool = field(
        default=False,
        metadata={"help": "Whether to split gate liner in adanorm."},
    )
    nosplit_gate_liner: bool = field(
        default=False,
        metadata={"help": "Whether to nosplit gate liner in adanorm."},
    )
    separate_time_proj: bool = field(
        default=False,
        metadata={"help": "Whether to split time proj in embed_suffix."},
    )
    final_norm_adanorm: bool = field(
        default=False,
        metadata={"help": "Whether to use adanorm in final norm."},
    )
    old_adanorm: bool = field(
        default=False,
        metadata={"help": "Whether to use old adanorm."},
    )
    # Annotated Optional because the default is None; the parser still treats these as `str`.
    moge_path: Optional[str] = field(
        default=None,
        metadata={"help": "path of MgGe."},
    )
    morgbd_path: Optional[str] = field(
        default=None,
        metadata={"help": "path of LingBot-Depth."},
    )

    def __post_init__(self):
        """Resolve path fallbacks and validate the encoder/decoder configs.

        `config_path` falls back to `model_path` and `tokenizer_path` falls back to
        `config_path`. Each entry of `encoders` / `decoders` must be keyed by `image`
        and provide at least one of `config_path` / `model_path`.

        Raises:
            ValueError: If both `config_path` and `model_path` are None (globally or for
                an encoder/decoder entry), or if an encoder/decoder type is unsupported.
        """
        if self.config_path is None and self.model_path is None:
            raise ValueError("`config_path` must be specified when `model_path` is None.")

        if self.config_path is None:
            self.config_path = self.model_path

        if self.tokenizer_path is None:
            self.tokenizer_path = self.config_path

        # Encoders are validated first, then decoders; both follow the same rules.
        for kind, modules in (("encoder", self.encoders), ("decoder", self.decoders)):
            for module_type, module_args in modules.items():
                if module_type not in ["image"]:
                    raise ValueError(f"Unsupported {kind} type: {module_type}. Should be one of {{image}}.")

                if module_args.get("config_path") is None and module_args.get("model_path") is None:
                    raise ValueError("`config_path` and `model_path` cannot be both empty.")

                if module_args.get("config_path") is None:
                    module_args["config_path"] = module_args["model_path"]


@dataclass
class DataArguments:
    """Dataset and dataloader arguments."""

    train_path: str = field(
        metadata={"help": "Path of the training data. Use comma to separate multiple datasets."},
    )
    train_size: int = field(
        default=10_000_000,
        metadata={"help": "Number of tokens for training to compute training steps for dynamic batch dataloader."},
    )
    data_type: Literal["plaintext", "conversation", "diffusion"] = field(
        default="conversation",
        metadata={"help": "Type of the training data."},
    )
    dataloader_type: Literal["native"] = field(
        default="native",
        metadata={"help": "Type of the dataloader."},
    )
    datasets_type: Literal["mapping", "iterable", "vla"] = field(
        default="mapping",
        metadata={"help": "Type of the datasets."},
    )
    data_name: Optional[str] = field(
        default=None,
        metadata={"help": "Dataset name for multimodal training."},
    )
    data_root: Optional[str] = field(
        default=None,
        metadata={"help": "Root path of datasets."},
    )
    data_tag: Literal["default", "mmtag"] = field(
        default="default",
        metadata={"help": "Dataset tag for multimodal training."},
    )
    text_keys: Optional[str] = field(
        default=None,
        metadata={"help": "Key to get text from the training data."},
    )
    image_keys: str = field(
        default="images",
        metadata={"help": "Key to get images from the training data."},
    )
    chat_template: str = field(
        default="default",
        metadata={"help": "Chat template to use."},
    )
    max_seq_len: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length in training."},
    )
    num_workers: int = field(
        default=20,
        metadata={"help": "Number of workers to load data."},
    )
    prefetch_factor: int = field(
        default=4,
        metadata={"help": "Number of batches loaded in advance by each worker."},
    )
    drop_last: bool = field(
        default=True,
        metadata={"help": "Whether to drop the last incomplete batch."},
    )
    pin_memory: bool = field(
        default=True,
        metadata={"help": "Whether to pin memory for dataloader."},
    )

    def __post_init__(self):
        """Pick a default `text_keys` from `data_type` when none was given.

        Raises:
            ValueError: If `text_keys` is None and `data_type` is neither `plaintext`
                nor `conversation` (so `diffusion` requires an explicit `text_keys`).
        """
        if self.text_keys is None:
            if self.data_type == "plaintext":
                self.text_keys = "content_split"
            elif self.data_type == "conversation":
                self.text_keys = "messages"
            else:
                raise ValueError(f"Unknown data type: {self.data_type}")


@dataclass
class TrainingArguments:
    """Optimization, parallelism, checkpointing, logging and profiling arguments."""

    output_dir: str = field(
        metadata={"help": "Path to save model checkpoints."},
    )
    lr: float = field(
        default=5e-5,
        metadata={"help": "Maximum learning rate or defult learning rate, or init learning rate for warmup."},
    )
    lr_min: float = field(
        default=1e-7,
        metadata={"help": "Minimum learning rate."},
    )
    lr_start: float = field(
        default=0.0,
        metadata={"help": "Learning rate for warmup start. Default to 0.0."},
    )
    weight_decay: float = field(
        default=0,
        metadata={"help": "L2 regularization strength."},
    )
    optimizer: Literal["adamw", "anyprecision_adamw"] = field(
        default="adamw",
        metadata={"help": "Optimizer. Default to adamw."},
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Clip value for gradient norm."},
    )
    micro_batch_size: int = field(
        default=1,
        metadata={"help": "Micro batch size. The number of samples per iteration on each device."},
    )
    global_batch_size: Optional[int] = field(
        default=None,
        metadata={"help": "Global batch size. If None, use `micro_batch_size` * `data_parallel_size`."},
    )
    num_train_epochs: int = field(
        default=1,
        metadata={"help": "Epochs to train."},
    )
    rmpad: bool = field(
        default=True,
        metadata={"help": "Enable padding-free training by using the cu_seqlens."},
    )
    rmpad_with_pos_ids: bool = field(
        default=False,
        metadata={"help": "Enable padding-free training by using the position_ids."},
    )
    dyn_bsz: bool = field(
        default=True,
        metadata={"help": "Enable dynamic batch size for padding-free training."},
    )
    dyn_bsz_margin: int = field(
        default=0,
        metadata={"help": "Number of pad tokens in dynamic batch."},
    )
    dyn_bsz_buffer_size: int = field(
        default=200,
        metadata={"help": "Buffer size for dynamic batch size."},
    )
    bsz_warmup_ratio: float = field(
        default=0,
        metadata={"help": "Ratio of batch size warmup steps."},
    )
    bsz_warmup_init_mbtoken: int = field(
        default=200,
        metadata={"help": "Initial number of tokens in a batch in warmup phase."},
    )
    lr_warmup_ratio: float = field(
        default=0,
        metadata={"help": "Ratio of learning rate warmup steps."},
    )
    lr_decay_style: str = field(
        default="constant",
        metadata={"help": "Name of the learning rate scheduler."},
    )
    lr_decay_ratio: float = field(
        default=1.0,
        metadata={"help": "Ratio of learning rate decay steps."},
    )
    use_doptim: bool = field(
        default=False,
        metadata={"help": "Use veScale's ZeRO optimizer."},
    )
    enable_mixed_precision: bool = field(
        default=True,
        metadata={"help": "Enable mixed precision training."},  # false -> torch_dtype when loading model is bf16
    )
    enable_fp32: bool = field(
        default=False,
        metadata={"help": "Enable fp32 training."},
    )
    enable_resume: bool = field(
        default=False,
        metadata={"help": "Whether to automatically resume training from a checkpoint."},
    )
    enable_gradient_checkpointing: bool = field(
        default=True,
        metadata={"help": "Enable gradient checkpointing."},
    )
    enable_reentrant: bool = field(
        default=False,
        metadata={"help": "Use reentrant gradient checkpointing."},
    )
    enable_full_shard: bool = field(
        default=True,
        metadata={"help": "Enable fully shard for FSDP training (ZeRO-3)."},
    )
    enable_forward_prefetch: bool = field(
        default=True,
        metadata={"help": "Enable forward prefetch for FSDP1."},
    )
    enable_fsdp_offload: bool = field(
        default=False,
        metadata={"help": "Enable CPU offload for FSDP1."},
    )
    enable_activation_offload: bool = field(
        default=False,
        metadata={"help": "Enable activation offload to CPU."},
    )
    activation_gpu_limit: float = field(
        default=0.0,
        metadata={
            "help": "When enabling activation offload, `activation_gpu_limit` GB activations are allowed to reserve on GPU."
        },
    )
    enable_manual_eager: bool = field(
        default=False,
        metadata={"help": "Enable veScale's manual eager."},
    )
    init_device: Literal["cpu", "cuda", "meta"] = field(
        default="cuda",
        metadata={
            "help": "Device to initialize model weights. 1. `cpu`: Init parameters on CPU in rank0 only. 2. `cuda`: Init parameters on GPU. 3. `meta`: Init parameters on meta."
        },
    )
    enable_full_determinism: bool = field(
        default=False,
        metadata={"help": "Enable full determinism."},
    )
    empty_cache_steps: int = field(
        default=500,
        metadata={"help": "Number of steps between two empty cache operations."},
    )
    data_parallel_mode: Literal["ddp", "fsdp1", "fsdp2", "fsdp2-vescale"] = field(
        default="ddp",
        metadata={"help": "Data parallel mode."},
    )
    use_compile: bool = field(
        default=False,
        metadata={"help": "wether to enable torch.compile."},
    )
    module_fsdp_enable: bool = field(
        default=True,
        metadata={"help": "Enable FSDP for module."},
    )
    data_parallel_replicate_size: int = field(
        default=-1,
        metadata={"help": "Data parallel replicate size."},
    )
    data_parallel_shard_size: int = field(
        default=-1,
        metadata={"help": "Data parallel shard degree."},
    )
    tensor_parallel_size: int = field(
        default=1,
        metadata={"help": "Tensor parallel size."},
    )
    expert_parallel_size: int = field(
        default=1,
        metadata={"help": "Expert parallel size."},
    )
    pipeline_parallel_size: int = field(
        default=1,
        metadata={"help": "Pipeline parallel size."},
    )
    ulysses_parallel_size: int = field(
        default=1,
        metadata={"help": "Ulysses sequence parallel size."},
    )
    context_parallel_size: int = field(
        default=1,
        metadata={"help": "Ring-attn context parallel size."},
    )
    ckpt_manager: Literal["bytecheckpoint", "dcp"] = field(
        default="dcp",
        metadata={"help": "Checkpoint manager."},
    )
    load_checkpoint_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to bytecheckpoint checkpoint to resume from."},
    )
    save_steps: int = field(
        default=0,
        metadata={"help": "Number of steps between two checkpoint saves."},
    )
    save_epochs: int = field(
        default=1,
        metadata={"help": "Number of epochs between two checkpoint saves."},
    )
    save_hf_weights: bool = field(
        default=True,
        metadata={"help": "Save the huggingface format weights to the last checkpoint dir."},
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed."},
    )
    use_wandb: bool = field(
        default=True,
        metadata={"help": "Use wandb to log experiment."},
    )
    wandb_project: str = field(
        default="LingBotVLA",
        metadata={"help": "Wandb project name."},
    )
    wandb_name: Optional[str] = field(
        default=None,
        metadata={"help": "Wandb experiment name."},
    )
    enable_profiling: bool = field(
        default=False,
        metadata={"help": "Enable profiling."},
    )
    profile_start_step: int = field(
        default=1,
        metadata={"help": "Start step for profiling."},
    )
    profile_end_step: int = field(
        default=2,
        metadata={"help": "End step for profiling."},
    )
    profile_trace_dir: str = field(
        default="./trace",
        metadata={"help": "Direction to export the profiling result."},
    )
    profile_record_shapes: bool = field(
        default=True,
        metadata={"help": "Whether or not to record the shapes of the input tensors."},
    )
    profile_profile_memory: bool = field(
        default=True,
        metadata={"help": "Whether or not to profile the memory usage."},
    )
    profile_with_stack: bool = field(
        default=True,
        metadata={"help": "Whether or not to record the stack traces."},
    )
    max_steps: Optional[int] = field(
        default=None,
        metadata={"help": "Max training steps per epoch. (for debug)"},
    )

    def __post_init__(self):
        """Derive distributed sizes, batch sizes and output paths from the raw arguments.

        Sets `local_rank`, `global_rank`, `world_size`, `data_parallel_size`,
        `data_parallel_replicate_size`, `data_parallel_shard_size`, `global_batch_size`
        (when None), `gradient_accumulation_steps`, `dataloader_batch_size`,
        `save_checkpoint_path` and `model_assets_dir`.

        Raises:
            ValueError: If the world size, data-parallel sizes or batch sizes are
                inconsistent, if `rmpad` and `rmpad_with_pos_ids` are both True, or if
                gradient accumulation is combined with FSDP offload.
            AssertionError: If tensor/pipeline parallelism is requested, if the explicit
                replicate/shard sizes do not multiply to the data parallel size, or if
                cpu init is combined with expert parallelism.
        """
        self._train_steps = -1  # sentinel: `compute_train_steps` has not run yet

        # These environment variables must be set (presumably by the distributed launcher);
        # `os.getenv` returns None otherwise and `int(None)` raises TypeError.
        self.local_rank = int(os.getenv("LOCAL_RANK"))
        self.global_rank = int(os.getenv("RANK"))
        self.world_size = int(os.getenv("WORLD_SIZE"))

        # Number of ranks consumed by the non-data-parallel dimensions.
        model_parallel_size = (
            self.pipeline_parallel_size
            * self.ulysses_parallel_size
            * self.context_parallel_size
            * self.tensor_parallel_size
        )
        if self.world_size % model_parallel_size != 0:
            raise ValueError(
                f"World size should be a multiple of pipeline_parallel_size: {self.pipeline_parallel_size}, ulysses_parallel_size: {self.ulysses_parallel_size}, context_parallel_size: {self.context_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}."
            )

        assert self.tensor_parallel_size == 1, "Tensor parallel size not supported yet."
        assert self.pipeline_parallel_size == 1, "Pipeline parallel size not supported yet."

        self.data_parallel_size = self.world_size // model_parallel_size

        # configure data parallel size
        if self.data_parallel_replicate_size > 0 and self.data_parallel_shard_size > 0:
            assert self.data_parallel_size == self.data_parallel_replicate_size * self.data_parallel_shard_size, (
                f"data_parallel_size should be equal to data_parallel_replicate_size: {self.data_parallel_replicate_size} * data_parallel_shard_size: {self.data_parallel_shard_size}."
            )
        elif self.data_parallel_replicate_size > 0:
            if self.data_parallel_size % self.data_parallel_replicate_size != 0:
                raise ValueError("data_parallel_size should be a multiple of data_parallel_replicate_size.")

            self.data_parallel_shard_size = self.data_parallel_size // self.data_parallel_replicate_size
        elif self.data_parallel_shard_size > 0:
            if self.data_parallel_size % self.data_parallel_shard_size != 0:
                raise ValueError("data_parallel_size should be a multiple of data_parallel_shard_size.")

            self.data_parallel_replicate_size = self.data_parallel_size // self.data_parallel_shard_size
        else:
            # Neither size given: no replication, shard across the whole data-parallel group.
            self.data_parallel_replicate_size = 1
            self.data_parallel_shard_size = self.data_parallel_size

        if self.rmpad and self.rmpad_with_pos_ids:
            raise ValueError("`rmpad` and `rmpad_with_pos_ids` cannot be both True.")

        # init method check
        assert self.expert_parallel_size == 1 or self.init_device != "cpu", (
            "cpu init is not supported when enable ep. Please use `init_device = cuda` or `init_device = meta` instead."
        )

        # calculate gradient accumulation steps
        samples_per_step = self.micro_batch_size * self.data_parallel_size
        if self.global_batch_size is None:
            self.global_batch_size = samples_per_step
            self.gradient_accumulation_steps = 1
            logger.info_rank0("`global_batch_size` is None, disable gradient accumulation.")
        elif self.global_batch_size % samples_per_step == 0:
            self.gradient_accumulation_steps = self.global_batch_size // samples_per_step
            logger.info_rank0(f"Set gradient accumulation to {self.gradient_accumulation_steps}.")
        else:
            raise ValueError(
                f"`global_batch_size` should be a multiple of {self.micro_batch_size * self.data_parallel_size}."
            )

        if self.gradient_accumulation_steps > 1 and self.enable_fsdp_offload:
            raise ValueError("Gradient accumulation is not supported with FSDP offload.")

        self.dataloader_batch_size = self.global_batch_size // self.data_parallel_size  # = micro bsz * grad accu

        # merlin save paths
        self.save_checkpoint_path = os.path.join(self.output_dir, "checkpoints")
        self.model_assets_dir = os.path.join(self.output_dir, "model_assets")

    def compute_train_steps(
        self, max_seq_len: Optional[int] = None, train_size: Optional[int] = None, dataset_length: Optional[int] = None
    ) -> None:
        """
        Computes the training steps per epoch according to the data length.

        With padding-free training (`rmpad` or `rmpad_with_pos_ids`) the step count is
        derived from the token budget `train_size`; otherwise from `dataset_length`,
        and as a last resort from `max_steps`.

        Args:
            max_seq_len (int, optional): Maximum sequence length; required for padding-free training.
            train_size (int, optional): Number of training tokens; required for padding-free training.
            dataset_length (int, optional): Number of samples; used when padding-free training is off.

        Raises:
            AssertionError: If padding-free training is on and `max_seq_len` or `train_size` is None.
            ValueError: If padding-free training is off and neither `dataset_length` nor `max_steps` is set.
        """
        if self.rmpad or self.rmpad_with_pos_ids:
            assert max_seq_len is not None and train_size is not None, "max_seq_len and train_size are required."
            token_micro_bsz = self.micro_batch_size * max_seq_len
            train_size = int(train_size * (1 + self.bsz_warmup_ratio / 2))
            # Fraction of each micro batch that holds real tokens once the pad margin is removed.
            eff_token_rate = (token_micro_bsz - self.dyn_bsz_margin) / token_micro_bsz
            self._train_steps = math.ceil(train_size / (self.global_batch_size * max_seq_len * eff_token_rate))
        elif dataset_length is not None:
            self._train_steps = math.floor(
                dataset_length / (self.dataloader_batch_size * self.world_size)
            )  # assuming drop_last is true
        elif self.max_steps is not None:
            self._train_steps = self.max_steps
        else:
            raise ValueError("Please provide `dataset_length` or `max_steps`!")

    @property
    def train_steps(self) -> int:
        """Training steps per epoch, capped at `max_steps` when that is set.

        Raises:
            ValueError: If `compute_train_steps` has not been called yet.
        """
        if self.max_steps is not None and self._train_steps >= self.max_steps:
            logger.warning_once(f"Set train_steps to {self.max_steps}. It should be for debug purpose only.")
            return self.max_steps

        if self._train_steps == -1:
            raise ValueError("Please run `compute_train_steps` first!")

        return self._train_steps


@dataclass
class InferArguments:
    """Inference (generation) arguments."""

    model_path: str = field(
        metadata={"help": "Path to the pre-trained model."},
    )
    tokenizer_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the tokenizer. Defaults to `config_path`."},
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed."},
    )
    do_sample: bool = field(
        default=True,
        metadata={"help": "Whether or not to use sampling in decoding."},
    )
    temperature: float = field(
        default=1.0,
        metadata={"help": "The temperature value of decoding."},
    )
    top_p: float = field(
        default=1.0,
        metadata={"help": "The top_p value of decoding."},
    )
    max_tokens: int = field(
        default=1024,
        metadata={"help": "Max tokens to generate."},
    )

    def __post_init__(self):
        """Fall back to `model_path` when `tokenizer_path` is not given."""
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path


def _string_to_bool(value: Union[bool, str]) -> bool:
    """
    Converts a string input to bool value.

    Taken from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    Args:
        value (bool or str): A bool (returned unchanged) or one of yes/no, true/false,
            t/f, y/n, 1/0 (case insensitive).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognized truthy/falsy token.
    """
    if isinstance(value, bool):
        return value

    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True

    if lowered in ("no", "false", "f", "n", "0"):
        return False

    raise argparse.ArgumentTypeError(
        f"Truthy value expected: got {value} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )


def _convert_str_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Converts string values of a (possibly nested) dictionary to bool / int / float where they
    look like one. The dictionary is modified in place and also returned.

    Taken from: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/training_args.py#L189

    Args:
        input_dict (dict): Dictionary whose string values should be converted.

    Returns:
        dict: The same dictionary object with converted values.
    """
    for key, value in input_dict.items():
        if isinstance(value, dict):
            input_dict[key] = _convert_str_dict(value)
        elif isinstance(value, str):
            if value.lower() in ("true", "false"):  # check for bool
                input_dict[key] = value.lower() == "true"
            elif value.isdigit():  # check for digit
                input_dict[key] = int(value)
            elif value.replace(".", "", 1).isdigit():  # digits with a single dot -> float
                input_dict[key] = float(value)

    return input_dict


def _make_choice_type_function(choices: List[Any]) -> Callable[[str], Any]:
    """
    Creates a mapping function from each choices string representation to the actual value.
    Used to support multiple value types for a single argument.

    Based on: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/hf_argparser.py#L48

    Args:
        choices (list): List of choices.

    Returns:
        Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
            Strings that match no choice are returned unchanged.
    """
    str_to_choice = {str(choice): choice for choice in choices}
    return lambda arg: str_to_choice.get(arg, arg)


def parse_args(rootclass: T) -> T:
    """
    Parses the root argument class using the CLI inputs or yaml inputs.

    Every field of `rootclass` must be a dataclass given through `default_factory`; each of its
    init fields becomes a `--<group>.<field>` option. If the first CLI argument ends with
    `.yaml` / `.yml` / `.json`, it is loaded as a `{group: {field: value}}` config whose values
    have lower priority than options given explicitly on the command line.

    Based on: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/hf_argparser.py#L266

    Args:
        rootclass (dataclass): Root dataclass type grouping the argument dataclasses.

    Returns:
        An instance of `rootclass` populated with the parsed argument dataclasses.

    Raises:
        RuntimeError: If a field type cannot be resolved or is an unsupported union.
        ValueError: If unknown arguments are passed, or a dict argument is not a json string.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    base_to_subclass = {}
    dict_fields = set()
    list_fields = set()
    none_type = type(None)
    for subclass in fields(rootclass):
        base = subclass.name
        base_to_subclass[base] = subclass.default_factory
        try:
            type_hints: Dict[str, type] = get_type_hints(subclass.default_factory)
        except Exception as err:
            # Chain the original error so the real cause is visible in the traceback.
            raise RuntimeError(f"Type resolution failed for {subclass.default_factory}.") from err

        for attr in fields(subclass.default_factory):
            if not attr.init:
                continue

            attr_type = type_hints[attr.name]
            origin_type = getattr(attr_type, "__origin__", attr_type)
            if isinstance(attr_type, str):
                raise RuntimeError(f"Cannot resolve type {attr.type} of {attr.name}.")

            if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
                if len(attr_type.__args__) != 2 or none_type not in attr_type.__args__:  # only allows Optional[X]
                    raise RuntimeError(f"Cannot resolve type {attr.type} of {attr.name}.")

                if bool not in attr_type.__args__:  # except for `Union[bool, NoneType]`
                    # Unwrap Optional[X] to X. Comparing against NoneType by identity avoids
                    # calling isinstance() on subscripted generics, which raises TypeError.
                    attr_type = next(arg for arg in attr_type.__args__ if arg is not none_type)
                    origin_type = getattr(attr_type, "__origin__", attr_type)

            parser_kwargs = attr.metadata.copy()
            if origin_type is Literal or (isinstance(attr_type, type) and issubclass(attr_type, Enum)):
                if origin_type is Literal:
                    parser_kwargs["choices"] = attr_type.__args__
                else:
                    parser_kwargs["choices"] = [x.value for x in attr_type]

                parser_kwargs["type"] = _make_choice_type_function(parser_kwargs["choices"])

                if attr.default is not MISSING:
                    parser_kwargs["default"] = attr.default
                else:
                    parser_kwargs["required"] = True

            elif attr_type is bool or attr_type == Optional[bool]:
                parser_kwargs["type"] = _string_to_bool
                if attr_type is bool or (attr.default is not None and attr.default is not MISSING):
                    parser_kwargs["default"] = False if attr.default is MISSING else attr.default
                    # Accept both `--flag` (-> const True) and `--flag <value>`.
                    parser_kwargs["nargs"] = "?"
                    parser_kwargs["const"] = True

            elif isclass(origin_type) and issubclass(origin_type, list):
                parser_kwargs["type"] = attr_type.__args__[0]
                parser_kwargs["nargs"] = "+"
                list_fields.add(f"{base}.{attr.name}")
                if attr.default_factory is not MISSING:
                    parser_kwargs["default"] = attr.default_factory()
                elif attr.default is MISSING:
                    parser_kwargs["required"] = True

            elif isclass(origin_type) and issubclass(origin_type, dict):
                parser_kwargs["type"] = str  # parse dict inputs with json string
                dict_fields.add(f"{base}.{attr.name}")
                if attr.default_factory is not MISSING:
                    # The default is parsed back with json.loads below, so it must be valid
                    # JSON; str() would emit single-quoted keys for a non-empty dict.
                    parser_kwargs["default"] = json.dumps(attr.default_factory())
                elif attr.default is MISSING:
                    parser_kwargs["required"] = True

            else:
                parser_kwargs["type"] = attr_type
                if attr.default is not MISSING:
                    parser_kwargs["default"] = attr.default
                elif attr.default_factory is not MISSING:
                    parser_kwargs["default"] = attr.default_factory()
                else:
                    parser_kwargs["required"] = True

            parser.add_argument(f"--{base}.{attr.name}", **parser_kwargs)

    cmd_args = sys.argv[1:]
    # use `=` to mark the end of arg name; the trailing `=` also terminates the last token so
    # that a flag given as the final CLI argument is still recognized as explicitly set.
    cmd_args_string = "=".join(cmd_args) + "="
    input_data = {}
    # Guard against an empty command line before looking at the first argument.
    if cmd_args and cmd_args[0].endswith((".yaml", ".yml")):
        input_path = cmd_args.pop(0)
        with open(os.path.abspath(input_path), encoding="utf-8") as f:
            # An empty document loads as None.
            input_data: Dict[str, Dict[str, Any]] = yaml.safe_load(f) or {}

    elif cmd_args and cmd_args[0].endswith(".json"):
        input_path = cmd_args.pop(0)
        with open(os.path.abspath(input_path), encoding="utf-8") as f:
            input_data: Dict[str, Dict[str, Any]] = json.load(f) or {}

    for base, arg_dict in input_data.items():
        # An empty section (e.g. `model:` with no entries) loads as None.
        for arg_name, arg_value in (arg_dict or {}).items():
            if f"--{base}.{arg_name}=" in cmd_args_string:  # lower priority than the CLI
                continue

            # Skip None values so that the field default is used; forwarding them would turn
            # them into the literal string "null".
            if arg_value is None:
                continue

            cmd_args.append(f"--{base}.{arg_name}")
            if f"{base}.{arg_name}" in list_fields and isinstance(arg_value, list):
                # For list fields, extend the arguments with individual elements
                cmd_args.extend([str(item) for item in arg_value])
            else:
                cmd_args.append(arg_value if isinstance(arg_value, str) else json.dumps(arg_value))

    args, remaining_args = parser.parse_known_args(cmd_args)
    if remaining_args:
        raise ValueError(f"Some specified arguments are not used by the ArgumentParser: {remaining_args}")

    parse_result = defaultdict(dict)
    for key, value in vars(args).items():
        if key in dict_fields:
            if isinstance(value, str) and value.startswith("{"):
                value = _convert_str_dict(json.loads(value))
            else:
                raise ValueError(f"Expect a json string for dict argument, but got {value}")

        base, name = key.split(".", maxsplit=1)
        parse_result[base][name] = value

    data_classes = {}
    for base, subclass_type in base_to_subclass.items():
        data_classes[base] = subclass_type(**parse_result.get(base, {}))

    return rootclass(**data_classes)


def save_args(args: T, output_path: str) -> None:
    """
    Saves arguments to a yaml file named `lingbotvla_cli.yaml` inside `output_path`.

    Args:
        args (dataclass): Arguments.
        output_path (str): Output directory; created if it does not exist.
    """
    local_dir = output_path
    os.makedirs(local_dir, exist_ok=True)
    local_path = os.path.join(local_dir, "lingbotvla_cli.yaml")
    # Explicit encoding so the output does not depend on the platform locale.
    with open(local_path, "w", encoding="utf-8") as f:
        f.write(yaml.safe_dump(asdict(args), default_flow_style=False))