| import logging |
| import os |
| import os.path as osp |
| import subprocess |
| import sys |
| import time |
| import traceback |
| from multiprocessing import Manager, Pool |
| from multiprocessing.managers import SyncManager |
| from typing import Any, Dict, List, Tuple |
|
|
| import mmengine |
| from mmengine.config import ConfigDict |
| from tqdm import tqdm |
|
|
| from opencompass.registry import RUNNERS, TASKS |
| from opencompass.tasks import OpenICLInferTask |
| from opencompass.tasks.base import BaseTask |
| from opencompass.utils import (build_dataset_from_cfg, build_model_from_cfg, |
| get_infer_output_path, get_logger, |
| task_abbr_from_cfg) |
|
|
| from .base import BaseRunner |
|
|
|
|
def monkey_run(self, tokens: SyncManager.Semaphore):
    """Patched ``run`` for the infer task, semaphore-aware for multiprocess.

    Mirrors the stock infer-task run loop, except the shared semaphore
    ``tokens`` is attached to every built API model so that concurrent
    worker processes draw from one global request budget.
    """
    self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
    for m_cfg, d_cfg_list in zip(self.model_cfgs, self.dataset_cfgs):
        self.max_out_len = m_cfg.get('max_out_len', None)
        self.min_out_len = m_cfg.get('min_out_len', None)
        self.batch_size = m_cfg.get('batch_size', None)
        self.model = build_model_from_cfg(m_cfg)

        assert self.model.is_api, 'Only API model is supported.'
        # Hand the multiprocessing semaphore to the model so each API call
        # is gated by the shared concurrency token pool.
        self.model.tokens = tokens

        for d_cfg in d_cfg_list:
            self.model_cfg = m_cfg
            self.dataset_cfg = d_cfg
            self.infer_cfg = d_cfg['infer_cfg']
            self.dataset = build_dataset_from_cfg(d_cfg)
            self.sub_cfg = {
                'models': [self.model_cfg],
                'datasets': [[self.dataset_cfg]],
            }
            pred_path = get_infer_output_path(
                self.model_cfg, self.dataset_cfg,
                osp.join(self.work_dir, 'predictions'))
            # Resume support: skip datasets whose predictions already exist.
            if osp.exists(pred_path):
                continue
            self._inference()
|
|
|
|
# Saved process-wide streams, captured at import time so that `reset_std`
# can restore them after `redirect_std_to_file` has pointed stdout/stderr
# at a per-task log file.
old_stdout = sys.stdout
old_stderr = sys.stderr
|
|
|
|
def redirect_std_to_file(filename: str):
    """Redirect stdout and stderr, also change logger stream handler.

    Opens *filename* for writing and points ``sys.stdout``/``sys.stderr``
    at it, then re-targets the stream handlers of the main opencompass
    logger and the gen-inferencer logger so their output lands in the
    same file.
    """
    log_file = open(filename, 'w', encoding='utf-8')
    sys.stdout = log_file
    sys.stderr = log_file

    # Re-point every StreamHandler of both loggers at the new stdout.
    for target_logger in (
            get_logger(),
            logging.getLogger(
                'opencompass.openicl.icl_inferencer.icl_gen_inferencer'),
    ):
        for handler in target_logger.handlers:
            if isinstance(handler, logging.StreamHandler):
                handler.stream = sys.stdout
|
|
|
|
def reset_std():
    """Reset stdout and stderr, also change logger stream handler.

    Closes the redirected log file, restores the original process
    streams saved at module import, and re-targets the stream handlers
    of both loggers back to the restored stdout.
    """
    sys.stdout.close()
    sys.stdout = old_stdout
    sys.stderr = old_stderr

    # Point every StreamHandler of both loggers back at the real stdout.
    for target_logger in (
            get_logger(),
            logging.getLogger(
                'opencompass.openicl.icl_inferencer.icl_gen_inferencer'),
    ):
        for handler in target_logger.handlers:
            if isinstance(handler, logging.StreamHandler):
                handler.stream = sys.stdout
|
|
|
|
def launch(task: BaseTask, tokens: SyncManager.Semaphore):
    """Launch a single task.

    Args:
        task (BaseTask): Task to launch.
        tokens (SyncManager.Semaphore): Multiprocessing semaphore
            for every subprocess to follow.

    Returns:
        tuple[str, int]: Task name and exit code (0 on success, 1 on
        failure; the traceback is written to the task's ``.out`` file).
    """
    task_name = task.name
    returncode = 0
    logger = get_logger()

    # Resolve the log path *before* the try block: previously a failure in
    # get_log_path/mkdir raised before `out_path` was bound, so the except
    # handler crashed with NameError instead of reporting the real error.
    out_path = task.get_log_path(file_extension='out')
    mmengine.mkdir_or_exist(osp.split(out_path)[0])

    try:
        redirect_std_to_file(out_path)

        start_time = time.time()
        inferencer = OpenICLInferTask(task.cfg)
        origin_run = inferencer.run
        # Swap in the semaphore-aware run; monkey_run is unbound, so the
        # instance is passed explicitly.
        inferencer.run = monkey_run
        try:
            inferencer.run(inferencer, tokens)
        finally:
            # Restore the original bound method even if inference failed.
            inferencer.run = origin_run
        end_time = time.time()
        logger.info(f'time elapsed: {end_time - start_time:.2f}s')
    except Exception:
        # The traceback goes into the redirected .out file for inspection.
        traceback.print_exc()
        reset_std()
        logger.warning(f'task {task_name} fail, see\n{out_path}')
        returncode = 1
    else:
        reset_std()
    return task_name, returncode
|
|
|
|
def submit(task, type, tokens):
    """Helper for launch the task.

    Builds the task object from its config via the TASKS registry,
    announces it on the progress bar, and delegates to :func:`launch`.
    """
    built_task = TASKS.build(dict(cfg=task, type=type))
    tqdm.write(f'Launch {built_task.name} on CPU ')
    return launch(built_task, tokens)
|
|
|
|
@RUNNERS.register_module()
class LocalAPIRunner(BaseRunner):
    """Local API Runner. Start tasks by local python.

    The query per second cannot guarantee the number of concurrents,
    therefore supports concurrent users with multiple tasks. Applied for
    those apis which has a restriction on concurrent numbers.

    Args:
        task (ConfigDict): Task type config.
        concurrent_users (int): Max number of concurrent workers to request
            the resources.
        max_num_workers (int): Max number of workers to run in parallel.
            Defaults to 16.
        debug (bool): Whether to run in debug mode.
        lark_bot_url (str): Lark bot url.
    """

    def __init__(self,
                 task: ConfigDict,
                 concurrent_users: int,
                 max_num_workers: int = 16,
                 debug: bool = False,
                 lark_bot_url: str = None):
        super().__init__(task=task, debug=debug, lark_bot_url=lark_bot_url)
        self.max_num_workers = max_num_workers
        self.concurrent_users = concurrent_users
        # The monkey-patched run only makes sense for API inference tasks.
        assert task['type'] in [
            'OpenICLInferTask',
            'opencompass.tasks.OpenICLInferTask',
        ], 'Only supported for api infer task.'

    def launch(self, tasks: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
        """Launch multiple tasks.

        Args:
            tasks (list[dict]): A list of task configs, usually generated by
                Partitioner.

        Returns:
            list[tuple[str, int]]: A list of (task name, exit code).
        """
        status = []
        if self.debug:
            # Debug mode: run tasks sequentially in (or from) this process.
            for task in tasks:
                task = TASKS.build(dict(cfg=task, type=self.task_cfg['type']))
                task_name = task.name
                # Dump the task config so the task command can reload it.
                mmengine.mkdir_or_exist('tmp/')
                param_file = f'tmp/{os.getpid()}_params.py'
                try:
                    task.cfg.dump(param_file)
                    cmd = task.get_command(cfg_path=param_file,
                                           template='{task_cmd}')
                    # Plain python commands run in-process for easier
                    # debugging; anything else goes through the shell.
                    if cmd.startswith('python'):
                        task.run()
                    else:
                        subprocess.run(cmd, shell=True, text=True)
                finally:
                    os.remove(param_file)
                # NOTE(review): debug mode reports exit code 0 even if the
                # shell command failed — verify whether the subprocess
                # returncode should be propagated here.
                status.append((task_name, 0))
        else:
            pbar = tqdm(total=len(tasks))

            get_logger().info('All the logs and processes for each task'
                              ' should be checked in each infer/.out file.')
            with Manager() as manager:
                # Shared semaphore capping concurrent in-flight API requests.
                tokens = manager.Semaphore(self.concurrent_users)
                # Shared counter driving the progress bar from the parent.
                pbar_counter = manager.Value('i', 0)
                status = []

                def update(args):
                    """Update pbar counter when callback."""
                    pbar_counter.value += 1
                    status.append(args)

                def on_error(err):
                    """Record a crashed worker and keep the counter moving.

                    Without an error_callback, an uncaught exception in
                    ``submit`` would never fire ``update``, the counter
                    would stay below ``len(tasks)``, and the polling loop
                    below would spin forever.
                    """
                    get_logger().error(f'worker raised: {err!r}')
                    pbar_counter.value += 1
                    status.append((repr(err), 1))

                with Pool(processes=self.max_num_workers) as pool:
                    for task in tasks:
                        pool.apply_async(submit,
                                         (task, self.task_cfg['type'], tokens),
                                         callback=update,
                                         error_callback=on_error)
                    pool.close()

                    # Poll the shared counter to drive tqdm from the parent
                    # process; callbacks run in a pool-internal thread.
                    while True:
                        cur_count = pbar_counter.value
                        if cur_count > pbar.n:
                            pbar.update(cur_count - pbar.n)
                        if cur_count >= pbar.total:
                            pbar.close()
                            break
                        time.sleep(1)

                    # Join before the Pool context exits (which terminates).
                    pool.join()
        return status
|
|