|
|
|
|
|
import importlib.util |
|
|
import os |
|
|
import subprocess |
|
|
import sys |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
from swift.utils import get_logger |
|
|
|
|
|
# Module-level logger; `_compat_web_ui` uses it to warn when `web-ui` is redirected to `app`.
logger = get_logger()
|
|
|
|
|
# Subcommand name -> dotted name of the module that implements it.
# `cli_main` normalizes underscores in the subcommand to dashes before the
# lookup, resolves the module with `importlib.util.find_spec`, and runs the
# module's file in a child interpreter.
ROUTE_MAPPING: Dict[str, str] = {
    'pt': 'swift.cli.pt',
    'sft': 'swift.cli.sft',
    'infer': 'swift.cli.infer',
    'merge-lora': 'swift.cli.merge_lora',
    'web-ui': 'swift.cli.web_ui',
    'deploy': 'swift.cli.deploy',
    'rollout': 'swift.cli.rollout',
    'rlhf': 'swift.cli.rlhf',
    'sample': 'swift.cli.sample',
    'export': 'swift.cli.export',
    'eval': 'swift.cli.eval',
    'app': 'swift.cli.app',
}
|
|
|
|
|
|
|
|
def use_torchrun() -> bool:
    """Report whether a distributed launch was requested via the environment.

    Returns:
        True when at least one of the ``NPROC_PER_NODE`` or ``NNODES``
        environment variables is set (to any value, including an empty
        string); False when both are absent.
    """
    launcher_env_keys = ('NPROC_PER_NODE', 'NNODES')
    return any(os.getenv(key) is not None for key in launcher_env_keys)
|
|
|
|
|
|
|
|
def get_torchrun_args() -> Optional[List[str]]:
    """Translate launcher-related environment variables into command-line flags.

    Returns:
        None when `use_torchrun()` is false. Otherwise a flat list of
        ``--<lowercased variable name>`` / value pairs, one pair for each of
        NPROC_PER_NODE, MASTER_PORT, NNODES, NODE_RANK and MASTER_ADDR that is
        set, in that order; unset variables are skipped.
    """
    if not use_torchrun():
        return None

    flags: List[str] = []
    for name in ('NPROC_PER_NODE', 'MASTER_PORT', 'NNODES', 'NODE_RANK', 'MASTER_ADDR'):
        value = os.getenv(name)
        if value is not None:
            flags.extend((f'--{name.lower()}', value))
    return flags
|
|
|
|
|
|
|
|
def _compat_web_ui(argv):
    """Redirect legacy `web-ui` invocations that pass a model to `app`, in place.

    Args:
        argv: Command-line tokens with the subcommand first. Mutated in place:
            when the subcommand is ``web-ui``/``web_ui`` and any of
            ``--model``, ``--adapters`` or ``--ckpt_dir`` appears as a token,
            ``argv[0]`` is replaced with ``'app'`` and a warning is logged.
            An empty list is left untouched.
    """
    if not argv:
        # No subcommand to inspect, so there is nothing to rewrite.
        return

    legacy_names = {'web-ui', 'web_ui'}
    model_flags = ('--model', '--adapters', '--ckpt_dir')
    if argv[0] in legacy_names and any(flag in argv for flag in model_flags):
        argv[0] = 'app'
        logger.warning('Please use `swift app`.')
|
|
|
|
|
|
|
|
def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None:
    """Dispatch ``sys.argv`` to the script implementing the requested subcommand.

    The first command-line token selects an entry of `route_mapping`
    (underscores are treated as dashes); the remaining tokens are forwarded
    unchanged. The target module is located with `importlib.util.find_spec`
    and its file is executed in a child interpreter. The child is started
    through ``torch.distributed.run`` when `get_torchrun_args()` returns flags
    and the subcommand is one of pt/sft/rlhf/infer.

    Args:
        route_mapping: Subcommand name -> dotted module name. Falls back to
            `ROUTE_MAPPING` when None or empty.

    Raises:
        SystemExit: With a message when no subcommand is given, the subcommand
            is not in `route_mapping`, or its module file cannot be located;
            with the child's return code when the child exits non-zero.
    """
    route_mapping = route_mapping or ROUTE_MAPPING
    available = ', '.join(route_mapping)

    argv = sys.argv[1:]
    if not argv:
        # Fail with a readable message rather than an IndexError traceback.
        sys.exit(f'No command given. Available commands: {available}')

    # May rewrite argv[0] in place, so it has to run before the lookup below.
    _compat_web_ui(argv)
    method_name = argv[0].replace('_', '-')
    argv = argv[1:]

    if method_name not in route_mapping:
        # Fail with a readable message rather than a KeyError traceback.
        sys.exit(f'Unknown command {method_name!r}. Available commands: {available}')
    module_name = route_mapping[method_name]

    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        # find_spec returns None when the module does not exist.
        sys.exit(f'Cannot locate module {module_name!r} for command {method_name!r}.')
    file_path = spec.origin

    torchrun_args = get_torchrun_args()
    python_cmd = sys.executable
    if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}:
        args = [python_cmd, file_path, *argv]
    else:
        args = [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv]

    print(f"run sh: `{' '.join(args)}`", flush=True)
    result = subprocess.run(args)
    if result.returncode != 0:
        # Propagate the child's failure status to the caller's shell.
        sys.exit(result.returncode)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: dispatch using the default `ROUTE_MAPPING`.
    cli_main()
|
|
|