# Source: trl-mcsd / trl/cli/accelerate_config.py
# Commit 1fa3c6c (verified) by ihbkaiser: "Implement MCSD for experimental SDPO"
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.resources as resources
from pathlib import Path
def resolve_accelerate_config_argument(launch_args: list[str]) -> list[str]:
    """
    Resolve `--accelerate_config` from CLI arguments into `accelerate --config_file`.

    The function supports either a filesystem path or a predefined config name shipped in `trl/accelerate_configs`
    (without the `.yaml` suffix). Both the `--accelerate_config <value>` and the `--accelerate_config=<value>`
    spellings are recognized; only the first occurrence is rewritten.

    Args:
        launch_args (`list[str]`):
            Command-line arguments to inspect. The list is never mutated.

    Returns:
        `list[str]`:
            `launch_args` itself when the flag is absent; otherwise a new list that starts with
            `["--config_file", <resolved path>]`, followed by the remaining arguments in their original order.

    Raises:
        `ValueError`:
            If the flag is given without a value, or if the value is neither an existing file nor the name of a
            config found in the `trl.accelerate_configs` package.
    """
    flag = "--accelerate_config"
    inline_prefix = flag + "="

    # Locate the first occurrence of the flag and remember how many tokens it occupies.
    for config_index, argument in enumerate(launch_args):
        if argument == flag:
            if config_index + 1 >= len(launch_args):
                raise ValueError("Expected a value after `--accelerate_config`.")
            config_name = launch_args[config_index + 1]
            consumed = 2
            break
        if argument.startswith(inline_prefix):
            config_name = argument[len(inline_prefix) :]
            if not config_name:
                raise ValueError("Expected a value after `--accelerate_config`.")
            consumed = 1
            break
    else:
        # Flag not present: hand the arguments back untouched.
        return launch_args

    if Path(config_name).is_file():
        # An explicit path on disk always wins over a packaged config of the same name.
        accelerate_config_path = config_name
    else:
        candidate = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml")
        # `Traversable` only guarantees `is_file()` / `is_dir()`; `exists()` is not part of that interface, and a
        # directory named `<config_name>.yaml` would not be a usable config anyway.
        if not candidate.is_file():
            raise ValueError(
                f"Accelerate config {config_name} is neither a file nor a valid config in the `trl` package. "
                "Please provide a valid config name or a path to a config file."
            )
        accelerate_config_path = candidate

    # Remove the flag (and its value token, if separate) and put `--config_file <path>` at the front.
    remaining_args = launch_args[:config_index] + launch_args[config_index + consumed :]
    return ["--config_file", str(accelerate_config_path)] + remaining_args