| import os |
| import os.path as osp |
| import random |
| import subprocess |
| import time |
| from functools import partial |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import mmengine |
| from mmengine.config import ConfigDict |
| from mmengine.utils import track_parallel_progress |
|
|
| from opencompass.registry import RUNNERS, TASKS |
| from opencompass.utils import get_logger |
|
|
| from .base import BaseRunner |
|
|
|
|
@RUNNERS.register_module()
class SlurmRunner(BaseRunner):
    """Distributed runner based on Slurm. It will launch tasks in parallel
    using `srun` command.

    Args:
        task (ConfigDict): Task type config.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 32.
        retry (int): Number of retries if the job failed. Defaults to 2.
        partition (str): Slurm partition name. Defaults to None.
        quotatype (str): Slurm quota type. Defaults to None.
        qos (str): Slurm quality of service. Defaults to None.
        debug (bool): Whether to run in debug mode. Defaults to False.
        lark_bot_url (str): Lark bot url. Defaults to None.
        extra_command (List, optional): Extra slurm command.
            For example ['-c 12', '-w node1']. Defaults to None.
    """

    def __init__(self,
                 task: ConfigDict,
                 max_num_workers: int = 32,
                 retry: int = 2,
                 partition: str = None,
                 quotatype: str = None,
                 qos: str = None,
                 debug: bool = False,
                 lark_bot_url: str = None,
                 extra_command: Optional[List[str]] = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.retry = retry
        self.partition = partition
        self.quotatype = quotatype
        self.qos = qos
        # Normalize None/empty to a fresh list (avoids a shared mutable
        # default) and reject non-list values early.
        if not extra_command:
            extra_command = []
        assert isinstance(extra_command, list)
        self.extra_command = extra_command

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch multiple tasks.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """
        if not self.debug:
            # Fan the tasks out across worker processes; order of results
            # does not matter, so keep_order=False lets fast jobs return
            # early.
            status = track_parallel_progress(self._launch,
                                             tasks,
                                             nproc=self.max_num_workers,
                                             keep_order=False)
        else:
            # Debug mode runs tasks serially in-process; no random sleep
            # is needed since launches cannot collide.
            status = [self._launch(task, random_sleep=False) for task in tasks]
        return status

    def _launch(self, cfg: ConfigDict, random_sleep: bool = True):
        """Launch a single task.

        Args:
            cfg (ConfigDict): Task config.
            random_sleep (bool): Whether to sleep for a random time before
                running the command. This avoids cluster error when launching
                multiple tasks at the same time. Default: True.

        Returns:
            tuple[str, int]: Task name and exit code.
        """
        task = TASKS.build(dict(cfg=cfg, type=self.task_cfg['type']))
        num_gpus = task.num_gpus
        task_name = task.name

        # Dump the task config to a temp file keyed by pid so concurrent
        # worker processes do not clobber each other's param files.
        mmengine.mkdir_or_exist('tmp/')
        param_file = f'tmp/{os.getpid()}_params.py'
        stdout = None  # log file handle; stays None in debug mode
        try:
            cfg.dump(param_file)

            # Build the srun command template from the configured
            # partition/quota/qos/gpu options plus any extra flags.
            tmpl = 'srun'
            if self.partition:
                tmpl += f' -p {self.partition}'
            if self.quotatype:
                tmpl += f' --quotatype={self.quotatype}'
            if self.qos:
                tmpl += f' --qos={self.qos}'
            if num_gpus > 0:
                tmpl += f' --gres=gpu:{num_gpus}'
            for extra_cmd in self.extra_command:
                tmpl += f' {extra_cmd}'
            # Slurm job names are limited in length; truncate to 512 chars.
            tmpl += f" -N1 -u -J '{task_name[:512]}'" + ' {task_cmd}'
            get_cmd = partial(task.get_command,
                              cfg_path=param_file,
                              template=tmpl)
            cmd = get_cmd()

            logger = get_logger()
            logger.debug(f'Running command: {cmd}')

            # In debug mode, inherit the parent's stdout/stderr; otherwise
            # redirect both streams to the task's log file.
            out_path = None
            if not self.debug:
                out_path = task.get_log_path(file_extension='out')
                mmengine.mkdir_or_exist(osp.split(out_path)[0])
                stdout = open(out_path, 'w', encoding='utf-8')

            if random_sleep:
                # Stagger launches to avoid hammering the Slurm controller
                # when many workers start simultaneously.
                time.sleep(random.randint(0, 10))
            result = subprocess.run(cmd,
                                    shell=True,
                                    text=True,
                                    stdout=stdout,
                                    stderr=stdout)

            # Retry while the job either returned non-zero or failed to
            # produce all expected output files.
            retry = self.retry
            output_paths = task.get_output_paths()
            while self._job_failed(result.returncode,
                                   output_paths) and retry > 0:
                retry -= 1
                if random_sleep:
                    time.sleep(random.randint(0, 10))
                # Rebuild the command in case the template is stateful.
                cmd = get_cmd()
                result = subprocess.run(cmd,
                                        shell=True,
                                        text=True,
                                        stdout=stdout,
                                        stderr=stdout)

            if result.returncode != 0 and not self.debug:
                logger.warning(f'task {task_name} fail, see\n{out_path}')
        finally:
            # BUGFIX: close the log file handle (previously leaked on every
            # non-debug launch).
            if stdout is not None:
                stdout.close()
            # Clean up the temp param file; it may not exist if cfg.dump
            # raised, and removing unconditionally would mask that error.
            if osp.exists(param_file):
                os.remove(param_file)
        # BUGFIX: the return used to live inside `finally:`, which silently
        # swallowed any exception raised above and could itself raise
        # NameError when `result` was never assigned. Returning here lets
        # real failures propagate.
        return task_name, result.returncode

    def _job_failed(self, return_code: int, output_paths: List[str]) -> bool:
        """A job counts as failed if it exited non-zero or any expected
        output file is missing."""
        return return_code != 0 or not all(
            osp.exists(output_path) for output_path in output_paths)
|
|