| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import enum |
| | import logging |
| | from pathlib import Path |
| |
|
| | import yaml |
| |
|
| | from accelerate.commands.utils import CustomArgumentParser |
| |
|
| |
|
| | class ConversionStatus(enum.Enum): |
| | NOT_YET_IMPLEMENTED = 0 |
| | REMOVED = -1 |
| |
|
| |
|
| | ARGUMENT_KEY_MAPPING = { |
| | |
| | "fsdp_version": "fsdp_version", |
| | "fsdp_reshard_after_forward": "fsdp_reshard_after_forward", |
| | |
| | |
| | "fsdp_auto_wrap_policy": "fsdp_auto_wrap_policy", |
| | "fsdp_backward_prefetch": ConversionStatus.REMOVED, |
| | "fsdp_forward_prefetch": ConversionStatus.NOT_YET_IMPLEMENTED, |
| | "fsdp_cpu_ram_efficient_loading": "fsdp_cpu_ram_efficient_loading", |
| | "fsdp_offload_params": "fsdp_offload_params", |
| | "fsdp_sharding_strategy": "fsdp_reshard_after_forward", |
| | "fsdp_state_dict_type": "fsdp_state_dict_type", |
| | "fsdp_sync_module_states": ConversionStatus.REMOVED, |
| | "fsdp_transformer_layer_cls_to_wrap": "fsdp_transformer_layer_cls_to_wrap", |
| | "fsdp_min_num_params": "fsdp_min_num_params", |
| | "fsdp_use_orig_params": ConversionStatus.REMOVED, |
| | "fsdp_activation_checkpointing": "fsdp_activation_checkpointing", |
| | } |
| |
|
| | ARGUMENT_VALUE_MAPPING = { |
| | "fsdp_sharding_strategy": { |
| | "FULL_SHARD": True, |
| | "SHARD_GRAD_OP": False, |
| | "HYBRID_SHARD": True, |
| | "HYBRID_SHARD_ZERO2": False, |
| | "NO_SHARD": False, |
| | }, |
| | "fsdp_reshard_after_forward": { |
| | "FULL_SHARD": True, |
| | "SHARD_GRAD_OP": False, |
| | "HYBRID_SHARD": True, |
| | "HYBRID_SHARD_ZERO2": False, |
| | "NO_SHARD": False, |
| | }, |
| | } |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| | def _validate_to_fsdp2_args(args): |
| | if not Path(args.config_file).exists(): |
| | raise FileNotFoundError(f"Config file {args.config_file} not found") |
| |
|
| | if not args.overwrite and args.output_file is None: |
| | raise ValueError("If --overwrite is not set, --output_file must be provided") |
| |
|
| | if not args.overwrite and Path(args.output_file).exists(): |
| | raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set") |
| |
|
| |
|
| | def convert_config_to_fsdp2(config: dict) -> dict: |
| | fsdp_config = config.get("fsdp_config", {}) |
| |
|
| | if not fsdp_config: |
| | logger.info("No FSDP config found in the config file, skipping conversion...") |
| | return config |
| |
|
| | new_fsdp_config = {} |
| |
|
| | if fsdp_config.get("fsdp_version", 1) == 2: |
| | logger.warning("Config already specifies FSDP2, skipping conversion...") |
| | logger.warning( |
| | "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command." |
| | ) |
| | return config |
| |
|
| | for key, value in fsdp_config.items(): |
| | conversion_status = ARGUMENT_KEY_MAPPING.get(key, None) |
| | if isinstance(conversion_status, ConversionStatus) or conversion_status is None: |
| | conversion_status = key |
| | new_fsdp_config[conversion_status] = value |
| | continue |
| |
|
| | if conversion_status == ConversionStatus.REMOVED: |
| | logger.warning(f"Argument {key} has been removed in FSDP2, skipping this key...") |
| | continue |
| |
|
| | if conversion_status == ConversionStatus.NOT_YET_IMPLEMENTED: |
| | logger.warning(f"Argument {key} is not yet implemented in FSDP2, skipping this key...") |
| | continue |
| |
|
| | if conversion_status is None: |
| | logger.warning(f"Argument {key} is not being converted, skipping this key...") |
| | new_fsdp_config[key] = value |
| | else: |
| | if key in ARGUMENT_VALUE_MAPPING: |
| | value = ARGUMENT_VALUE_MAPPING[key].get(value, value) |
| | new_fsdp_config[ARGUMENT_KEY_MAPPING[key]] = value |
| |
|
| | new_fsdp_config["fsdp_version"] = 2 |
| | config["fsdp_config"] = new_fsdp_config |
| | return config |
| |
|
| |
|
| | def to_fsdp2_command_parser(subparsers=None): |
| | description = "Convert an Accelerate config from FSDP1 to FSDP2" |
| |
|
| | if subparsers is not None: |
| | parser = subparsers.add_parser("to-fsdp2", description=description) |
| | else: |
| | parser = CustomArgumentParser(description=description) |
| |
|
| | parser.add_argument("--config_file", type=str, help="The config file to convert to FSDP2", required=True) |
| | parser.add_argument( |
| | "--overwrite", |
| | action="store_true", |
| | help="Overwrite the config file if it exists", |
| | default=False, |
| | ) |
| | parser.add_argument( |
| | "--output_file", |
| | type=str, |
| | help="The path to the output file to write the converted config to. If not provided, the input file will be overwritten (if --overwrite is set)", |
| | default=None, |
| | ) |
| | if subparsers is not None: |
| | parser.set_defaults(func=to_fsdp2_command) |
| |
|
| | return parser |
| |
|
| |
|
| | def load_config(config_file: str) -> dict: |
| | with open(config_file) as f: |
| | config = yaml.safe_load(f) |
| | if not config: |
| | raise ValueError("Config file is empty") |
| |
|
| | return config |
| |
|
| |
|
| | def to_fsdp2_command(args): |
| | _validate_to_fsdp2_args(args) |
| | config = load_config(args.config_file) |
| |
|
| | if args.overwrite and args.output_file is None: |
| | args.output_file = args.config_file |
| |
|
| | new_config = convert_config_to_fsdp2(config) |
| |
|
| | with open(args.output_file, "w") as f: |
| | yaml.dump(new_config, f) |
| |
|