import argparse
import copy
import fnmatch
import math
import os.path as osp
import statistics
import time
from collections import Counter
from inspect import signature
from shutil import which
from typing import List, Optional

import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist

from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
                                  TEXT_POSTPROCESSORS)
from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
                               get_infer_output_path, get_logger,
                               task_abbr_from_cfg)


def extract_role_pred(s: str, begin_str: Optional[str],
                      end_str: Optional[str]) -> str:
| """Extract the role prediction from the full prediction string. The role |
| prediction may be the substring between the begin and end string. |
| |
| Args: |
| s (str): Full prediction string. |
| begin_str (str): The beginning string of the role |
| end_str (str): The ending string of the role. |
| |
| Returns: |
| str: The extracted role prediction. |
| """ |
| start = 0 |
| end = len(s) |
|
|
| if begin_str: |
| begin_idx = s.find(begin_str) |
| if begin_idx != -1: |
| start = begin_idx + len(begin_str) |
|
|
| if end_str: |
| |
| |
| end_idx = s.find(end_str, start) |
| if end_idx != -1: |
| end = end_idx |
|
|
| return s[start:end] |
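# A minimal usage sketch of the extraction above (the role markers are
# hypothetical, not taken from any real meta template):
#   extract_role_pred('<BOT>The answer is 42<EOT>', '<BOT>', '<EOT>')
#   -> 'The answer is 42'
#   extract_role_pred('no markers here', '<BOT>', '<EOT>')
#   -> 'no markers here'  (falls back to the full string)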


@TASKS.register_module(force=(__name__ == '__main__'))
class OpenICLEvalTask(BaseTask):
    """OpenICL Evaluation Task.

    This task is used to evaluate the metric between predictions and
    references.
    """

    name_prefix = 'OpenICLEval'
    log_subdir = 'logs/eval'
    output_subdir = 'results'

    def __init__(self, cfg: ConfigDict):
        super().__init__(cfg)
        self.logger = get_logger()
        self.num_gpus = max(
            c.get('eval_cfg', {}).get('num_gpus', 0)
            for c in sum(self.dataset_cfgs, []))
        self.dump_details = cfg.get('eval', {}).get('runner', {}).get(
            'task', {}).get('dump_details', False)

    def get_command(self, cfg_path, template):
        script_path = __file__
        python = 'python3' if which('python3') else 'python'
        command = f'{python} {script_path} {cfg_path}'
        return template.format(task_cmd=command)
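    # For illustration only: with the trivial template '{task_cmd}' this
    # returns something like 'python3 /path/to/openicl_eval.py my_cfg.py'
    # (both paths are made up); runner templates may instead wrap the
    # command with launcher or scheduler directives.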

    def run(self):
        for model_cfg, dataset_cfgs in zip(self.model_cfgs,
                                           self.dataset_cfgs):
            for dataset_cfg in dataset_cfgs:
                self.model_cfg = model_cfg
                self.dataset_cfg = dataset_cfg

                # Load the evaluation config and the reference column
                self.eval_cfg = self.dataset_cfg.get('eval_cfg')
                self.output_column = dataset_cfg['reader_cfg']['output_column']

                # Let a model-specified pred_postprocessor override the
                # dataset-level one when its pattern matches this dataset
                ds_abbr = dataset_abbr_from_cfg(self.dataset_cfg)
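                # fnmatch gives the patterns shell-style wildcard semantics;
                # e.g. a hypothetical model-level entry
                # pred_postprocessor={'gsm8k*': ...} would match every
                # dataset whose abbreviation starts with 'gsm8k'.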
                model_postprocessors = self.model_cfg.get(
                    'pred_postprocessor', {})
                for pattern in model_postprocessors.keys():
                    if fnmatch.fnmatch(ds_abbr, pattern):
                        self.eval_cfg[
                            'pred_postprocessor'] = model_postprocessors[
                                pattern]
                        break

                # Skip datasets whose results already exist
                out_path = get_infer_output_path(
                    self.model_cfg, self.dataset_cfg,
                    osp.join(self.work_dir, 'results'))
                if osp.exists(out_path):
                    continue
                self._score()

    def _score(self):
        test_set = build_dataset_from_cfg(self.dataset_cfg).test
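        # A dataset postprocessor (handled below) rewrites the reference
        # column in place before scoring; e.g. a hypothetical
        # dataset_postprocessor=dict(type='strip') entry would map every
        # gold answer through a registered 'strip' function, assuming one
        # exists in TEXT_POSTPROCESSORS.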
        # Postprocess the dataset if a postprocessor is configured
        if 'dataset_postprocessor' in self.eval_cfg:
            proc = self.eval_cfg['dataset_postprocessor']['type']
            if isinstance(proc, str):
                proc = TEXT_POSTPROCESSORS.get(proc)

            def postprocess(sample):
                s = sample[self.output_column]
                sample[self.output_column] = proc(s)
                return sample

            test_set = test_set.map(postprocess)

        # Load predictions
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        # In case the prediction is partial, shards are named
        # <root>_0<ext>, <root>_1<ext>, ...
        root, ext = osp.splitext(filename)
        partial_filename = root + '_0' + ext

        # Get sc_size if Self-Consistency is used
        sc_size = self.eval_cfg.get('sc_size')

        if not osp.exists(osp.realpath(filename)) and not osp.exists(
                osp.realpath(partial_filename)):
            result = {'error': 'No predictions found.'}
        else:
            if osp.exists(osp.realpath(filename)):
                preds = mmengine.load(filename)
                preds = [preds[str(i)] for i in range(len(preds))]
            else:
                filename = partial_filename
                preds = []
                i = 1
                while osp.exists(osp.realpath(filename)):
                    sub_preds = mmengine.load(filename)
                    preds.extend(
                        [sub_preds[str(i)] for i in range(len(sub_preds))])
                    filename = root + f'_{i}' + ext
                    i += 1
            pred_dicts = copy.deepcopy(preds)
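            # Transpose the list of per-sample dicts into a dict of lists,
            # e.g. [{'prediction': 'A', 'gold': '1'},
            #       {'prediction': 'B', 'gold': '2'}]
            # becomes {'prediction': ['A', 'B'], 'gold': ['1', '2']}.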
            preds = {k: [pred.get(k) for pred in preds] for k in preds[0]}

            pred_strs = preds.pop('prediction', None)
            pred_list_flag = pred_strs is not None and isinstance(
                pred_strs[0], list)
            if ('pred_role' in self.eval_cfg
                    and 'meta_template' in self.model_cfg
                    and not MODELS.get(self.model_cfg['type']).is_api):
                # Parse the meta template to locate the begin/end markers of
                # the role whose output should be scored
                from opencompass.models.base import LMTemplateParser
                parser = LMTemplateParser(self.model_cfg['meta_template'])
                role = parser.roles[self.eval_cfg['pred_role']]
                if sc_size is not None:
                    assert pred_list_flag, (
                        'The prediction for Self-Consistency '
                        'must be a list.')
                if pred_list_flag:
                    pred_strs = [[
                        extract_role_pred(_pred, role.get('begin', None),
                                          role.get('end', None))
                        for _pred in pred
                    ] for pred in pred_strs]
                else:
                    pred_strs = [
                        extract_role_pred(pred, role.get('begin', None),
                                          role.get('end', None))
                        for pred in pred_strs
                    ]

            # Postprocess predictions if a postprocessor is configured
            if 'pred_postprocessor' in self.eval_cfg:
                kwargs = self.eval_cfg['pred_postprocessor']
                proc = kwargs.pop('type')
                if isinstance(proc, str):
                    proc = TEXT_POSTPROCESSORS.get(proc)
                if pred_list_flag:
                    pred_strs = [[proc(s, **kwargs) for s in preds]
                                 for preds in pred_strs]
                else:
                    pred_strs = [proc(s, **kwargs) for s in pred_strs]

            # Get majority-voting predictions if Self-Consistency is used
            if sc_size is not None:
                pred_strs = [
                    Counter(s).most_common(1)[0][0] for s in pred_strs
                ]
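            # e.g. Counter(['A', 'B', 'A']).most_common(1) == [('A', 2)], so
            # three sampled answers A/B/A collapse to the single answer 'A'.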

            icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
            # The evaluator may save auxiliary files next to the results, so
            # hand it the extension-stripped output path
            out_path = get_infer_output_path(
                self.model_cfg, self.dataset_cfg,
                osp.join(self.work_dir, 'results'))
            icl_evaluator._out_dir = osp.splitext(out_path)[0]

            preds['predictions'] = pred_strs
            preds['references'] = (test_set[self.output_column]
                                   if self.output_column else None)
            preds['test_set'] = test_set
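            # Keep only the keyword arguments that the evaluator's score()
            # declares; e.g. a score(predictions, references) signature
            # receives just those two entries.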
            preds = {
                k: preds[k]
                for k in signature(icl_evaluator.score).parameters
            }
            result = icl_evaluator.score(**preds)

            if self.dump_details:
                details = result.get('details', None)
                try:
                    result['details'] = self.format_details(
                        pred_strs, test_set[self.output_column], details,
                        pred_dicts)
                    result['type'] = result['details'].pop('type', None)

                    if 'PPL' in str(
                            self.dataset_cfg.infer_cfg.inferencer.type):
                        result['correct_bpb'], result['incorrect_bpb'] = \
                            self.calculate_bpb(pred_dicts)
                except Exception as e:
                    self.logger.warning(f'Skip dumping details due to: {e}.')
            else:
                result.pop('details', None)

        if 'error' in result:
            self.logger.error(
                f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
            return
        else:
            result_wo_details = {
                i: result[i]
                for i in result if i != 'details'
            }
            self.logger.info(
                f'Task {task_abbr_from_cfg(self.cfg)}: {result_wo_details}')

        # Save result
        out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
                                         osp.join(self.work_dir, 'results'))
        mkdir_or_exist(osp.split(out_path)[0])
        mmengine.dump(result, out_path, ensure_ascii=False, indent=4)

    def format_details(self, predictions, references, details, pred_dicts):
        """Format prediction details for dumping.

        Args:
            predictions (list): The prediction list.
            references (list): The reference list.
            details (list): Contains the 'pred', 'answer' and 'correct' for
                each sample. Such as `[{'pred': '光荣和ωforce',
                'answers': ['光荣和ω-force', '光荣和ωforce'], 'correct': True}]`
            pred_dicts (list): Contains a list of samples with the original
                prompts. Such as
                `[{'origin_prompt': '根据文章回答问题。你的答案应该尽可能3》…………',
                'prediction': ' 光荣和ω-force\n', 'gold': ['光荣和ω-force']}]`

        Returns:
            dict: The formatted prediction details, keyed by sample index,
                with an extra 'type' entry ('PPL' or 'GEN').
        """
        results = {}
        for i in range(len(predictions)):
            ppl_flag = False
            result = {}
            origin_prediction = copy.deepcopy(pred_dicts[i])
            origin_prediction.pop('in-context examples', None)
            origin_prediction.pop('prediction', None)
            keys = copy.deepcopy(list(origin_prediction.keys()))
            for key in keys:
                if key.startswith('label:'):
                    ppl_flag = True
                    origin_prediction[key].pop('testing input', None)
                    new_key = key.replace('label: ', '')
                    origin_prediction[new_key] = origin_prediction.pop(key)
            if ppl_flag:
                results['type'] = 'PPL'
                result['origin_prediction'] = origin_prediction
                result['predictions'] = str(predictions[i])
                result['references'] = str(references[i])
                result['correct'] = str(predictions[i]) == str(references[i])
            elif details is not None:
                results['type'] = 'GEN'
                result['prompt'] = origin_prediction['origin_prompt']
                result['origin_prediction'] = pred_dicts[i]['prediction']
                result['predictions'] = details[i]['pred']
                result['references'] = details[i]['answer']
                result['correct'] = details[i]['correct']
            else:
                results['type'] = 'GEN'
                result['prompt'] = origin_prediction['origin_prompt']
                result['origin_prediction'] = pred_dicts[i]['prediction']
                result['predictions'] = str(predictions[i])
                result['references'] = str(references[i])
            results[str(i)] = result
        return results

    def calculate_bpb(self, pred_dicts: List):
        """Calculate BPB (Bits Per Byte) for the data.

        The correct BPB is taken directly from the values in the
        'predictions' file; the incorrect BPB is the average of the remaining
        per-label BPB values for each sample after removing the correct one.
        BPB is computed similarly to PPL, except that it measures the
        additional bits needed on average, in terms of character length, to
        encode the true sequence based on the predictions, applying a
        weighting factor based on the ratio of words to characters.

        Args:
            pred_dicts (list): Contains a list of samples with each option
                and its BPB score.

        Returns:
            tuple: The correct and incorrect BPB, each scaled by 100.
        """
        incorrect_bpb_list = []
        bpb_list = []
        for pred_dict in pred_dicts:
            preds = {
                key: value
                for key, value in pred_dict.items()
                if key.startswith('label: ')
            }
            bpbs = [value['BPB'] for value in preds.values()]
            incorrect_bpb_list.append(
                (sum(bpbs) - min(bpbs)) / (len(bpbs) - 1))
            bpb_list.append(min(bpbs))
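            # e.g. with per-label BPBs [0.8, 1.2, 1.5], the correct BPB is
            # min(bpbs) == 0.8 and the incorrect BPB is
            # (1.2 + 1.5) / 2 == 1.35.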

        def filters(origins):
            # Drop NaN entries before averaging
            targets = [target for target in origins if not math.isnan(target)]
            return targets

        mean_incorrect = statistics.mean(filters(incorrect_bpb_list))
        mean_correct = statistics.mean(filters(bpb_list))
        return 100 * mean_correct, 100 * mean_incorrect


def parse_args():
    parser = argparse.ArgumentParser(description='Score Calculator')
    parser.add_argument('config', help='Config file path')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    cfg = Config.fromfile(args.config)
    start_time = time.time()
    inferencer = OpenICLEvalTask(cfg)
    inferencer.run()
    end_time = time.time()
    get_logger().info(f'time elapsed: {end_time - start_time:.2f}s')
|