| import logging |
| from typing import Callable, Dict, Union |
|
|
| import evaluate as hf_evaluate |
|
|
| from lm_eval.api.model import LM |
|
|
|
|
# Logger used by every registry helper in this module.
eval_logger = logging.getLogger("lm-eval")


# Alias -> model class. Written by ``register_model`` and read by ``get_model``.
MODEL_REGISTRY = {}
|
|
|
|
def register_model(*names):
    """Return a class decorator that registers a model class under ``names``.

    Args:
        *names: One or more aliases under which the decorated class is stored
            in ``MODEL_REGISTRY``.

    Returns:
        A decorator that validates the class, stores it under every alias and
        returns the class unchanged.

    Raises:
        AssertionError: Raised by the returned decorator when the class does
            not extend ``LM``, or when an alias is already registered or is
            repeated within ``names``.
    """

    def decorate(cls):
        # Validate every alias before touching the registry, so a conflict on
        # a later alias cannot leave the earlier ones half-registered.
        # The checks are explicit raises rather than ``assert`` statements so
        # they stay active under ``python -O``; the exception type is kept.
        seen = set()
        for name in names:
            if not issubclass(cls, LM):
                raise AssertionError(
                    f"Model '{name}' ({cls.__name__}) must extend LM class"
                )
            if name in MODEL_REGISTRY or name in seen:
                raise AssertionError(
                    f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
                )
            seen.add(name)

        for name in names:
            MODEL_REGISTRY[name] = cls
        return cls

    return decorate
|
|
|
|
def get_model(model_name):
    """Return the model class registered under ``model_name``.

    Args:
        model_name: Alias to look up in ``MODEL_REGISTRY``.

    Returns:
        The object stored in ``MODEL_REGISTRY`` for that alias.

    Raises:
        ValueError: If nothing is registered under ``model_name``. The message
            lists every alias currently in the registry.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        supported = ", ".join(MODEL_REGISTRY)
        # ``from None``: the KeyError only repeats the name already present in
        # the message, so it is dropped from the traceback.
        raise ValueError(
            f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {supported}"
        ) from None
|
|
|
|
# Task name -> decorated object. Written by ``register_task``.
TASK_REGISTRY = {}
# Group name -> list of task names. Written by ``register_group``.
GROUP_REGISTRY = {}
# Every name passed to ``register_task`` or ``register_group``.
ALL_TASKS = set()
# ``fn.__name__`` -> task name. Written by ``register_task`` so that
# ``register_group`` can recover the task name from the decorated object.
func2task_index = {}
|
|
|
|
def register_task(name):
    """Return a decorator that registers the decorated object as task ``name``.

    Args:
        name: Task name; must not already be present in ``TASK_REGISTRY``.

    Returns:
        A decorator that stores the object in ``TASK_REGISTRY``, adds ``name``
        to ``ALL_TASKS``, records ``fn.__name__ -> name`` in
        ``func2task_index`` and returns the object unchanged.

    Raises:
        AssertionError: Raised by the returned decorator when ``name`` is
            already registered.
    """

    def decorate(fn):
        # Explicit raise instead of ``assert`` so the duplicate check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if name in TASK_REGISTRY:
            raise AssertionError(
                f"task named '{name}' conflicts with existing registered task!"
            )

        # Read ``__name__`` before mutating anything: an object without it
        # must not leave a half-registered task behind.
        fn_name = fn.__name__

        TASK_REGISTRY[name] = fn
        ALL_TASKS.add(name)
        func2task_index[fn_name] = name
        return fn

    return decorate
|
|
|
|
def register_group(name):
    """Return a decorator that adds an already-registered task to group ``name``.

    The decorated object must already have been passed through
    ``register_task``, because its task name is looked up in
    ``func2task_index`` by ``fn.__name__``. When stacking decorators,
    ``@register_task`` therefore has to be the one closer to the ``def``.

    Args:
        name: Group name. The group is created on first use.

    Returns:
        A decorator that appends the task name to ``GROUP_REGISTRY[name]``,
        adds ``name`` to ``ALL_TASKS`` and returns the object unchanged.

    Raises:
        KeyError: Raised by the returned decorator when the decorated object
            has no entry in ``func2task_index``.
    """

    def decorate(fn):
        try:
            task_name = func2task_index[fn.__name__]
        except KeyError:
            # Same exception type as the plain lookup, but with a message that
            # says what went wrong instead of just echoing the function name.
            raise KeyError(
                f"'{fn.__name__}' must be registered with register_task before it can be added to group '{name}'"
            ) from None

        GROUP_REGISTRY.setdefault(name, []).append(task_name)
        ALL_TASKS.add(name)
        return fn

    return decorate
|
|
|
|
# Not written or read by any code visible in this section.
OUTPUT_TYPE_REGISTRY = {}
# Metric name -> metric function. Written by ``register_metric``.
METRIC_REGISTRY = {}
# Metric name -> aggregation function resolved from ``AGGREGATION_REGISTRY``.
# Written by ``register_metric``.
METRIC_AGGREGATION_REGISTRY = {}
# Aggregation name -> aggregation function. Written by ``register_aggregation``.
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
# Metric name -> the ``higher_is_better`` value passed to ``register_metric``.
HIGHER_IS_BETTER_REGISTRY = {}
# Filter name -> filter class. Written by ``register_filter``.
FILTER_REGISTRY = {}


# Lists of metric names per key. Not read by any code visible in this section.
# NOTE(review): the keys look like task output types -- confirm against callers.
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}
|
|
|
|
def register_metric(**args):
    """Return a decorator that registers a metric function and its metadata.

    Keyword Args:
        metric: Required. Name under which the decorated function is stored in
            ``METRIC_REGISTRY``; must not already be registered.
        higher_is_better: Optional. Stored as given in
            ``HIGHER_IS_BETTER_REGISTRY`` under the metric name.
        aggregation: Optional. Name of an entry in ``AGGREGATION_REGISTRY``;
            that entry is stored in ``METRIC_AGGREGATION_REGISTRY`` under the
            metric name.

    Any other keyword argument is ignored.

    Returns:
        A decorator that performs the registration and returns the function
        unchanged.

    Raises:
        AssertionError: Raised by the returned decorator when ``metric`` is
            missing or the metric name is already registered.
        KeyError: Raised by the returned decorator when ``aggregation`` names
            an aggregation that is not in ``AGGREGATION_REGISTRY``.
    """

    def decorate(fn):
        # Explicit raises instead of ``assert`` so the checks are not stripped
        # under ``python -O``; the exception type is unchanged.
        if "metric" not in args:
            raise AssertionError("register_metric requires a 'metric' keyword")
        name = args["metric"]

        # Every registry here is keyed by metric name, so uniqueness is decided
        # by the metric name alone. Testing the aggregation name or the
        # higher_is_better value against those keys would wrongly reject a
        # metric that reuses an aggregation whose name equals an existing
        # metric name.
        if name in METRIC_REGISTRY:
            raise AssertionError(
                f"metric named '{name}' conflicts with existing registered metric!"
            )

        # Resolve everything first and write afterwards, so an unknown
        # aggregation cannot leave the metric half-registered.
        pending = [(METRIC_REGISTRY, fn)]
        if "higher_is_better" in args:
            pending.append((HIGHER_IS_BETTER_REGISTRY, args["higher_is_better"]))
        if "aggregation" in args:
            aggregation_name = args["aggregation"]
            try:
                aggregation_fn = AGGREGATION_REGISTRY[aggregation_name]
            except KeyError:
                raise KeyError(
                    f"aggregation '{aggregation_name}' requested by metric '{name}' is not registered"
                ) from None
            pending.append((METRIC_AGGREGATION_REGISTRY, aggregation_fn))

        for registry, entry in pending:
            registry[name] = entry

        return fn

    return decorate
|
|
|
|
def get_metric(name: str, hf_evaluate_metric=False) -> Union[Callable, None]:
    """Resolve a metric function by name.

    Args:
        name: Metric name.
        hf_evaluate_metric: When false, ``METRIC_REGISTRY`` is consulted first
            and the HF Evaluate library is only a fallback. When true, the
            local registry is skipped.

    Returns:
        The registered metric function, or the ``compute`` attribute of the
        object returned by ``hf_evaluate.load(name)``. ``None`` if the HF
        lookup fails; an error is logged in that case.
    """
    if not hf_evaluate_metric:
        try:
            return METRIC_REGISTRY[name]
        except KeyError:
            eval_logger.warning(
                f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
            )

    try:
        return hf_evaluate.load(name).compute
    except Exception:
        # Deliberately broad best-effort: any failure is logged, not raised.
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
        )
        return None
|
|
|
|
def register_aggregation(name: str):
    """Return a decorator that registers an aggregation function as ``name``.

    Args:
        name: Aggregation name; must not already be in
            ``AGGREGATION_REGISTRY``.

    Returns:
        A decorator that stores the function in ``AGGREGATION_REGISTRY`` and
        returns it unchanged.

    Raises:
        AssertionError: Raised by the returned decorator when ``name`` is
            already registered.
    """

    def decorate(fn):
        # Explicit raise instead of ``assert`` so the duplicate check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if name in AGGREGATION_REGISTRY:
            raise AssertionError(
                f"aggregation named '{name}' conflicts with existing registered aggregation!"
            )

        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate
|
|
|
|
def get_aggregation(
    name: str,
) -> Union[Callable[[], Dict[str, Callable]], None]:
    """Return the aggregation function registered under ``name``.

    Args:
        name: Aggregation name to look up in ``AGGREGATION_REGISTRY``.

    Returns:
        The registered function, or ``None`` (after logging a warning) when
        ``name`` is not registered.
    """
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} not a registered aggregation metric!")
        return None
|
|
|
|
def get_metric_aggregation(
    name: str,
) -> Union[Callable[[], Dict[str, Callable]], None]:
    """Return the default aggregation function assigned to metric ``name``.

    Args:
        name: Metric name to look up in ``METRIC_AGGREGATION_REGISTRY``.

    Returns:
        The aggregation function stored for that metric, or ``None`` (after
        logging a warning) when the metric has none.
    """
    try:
        return METRIC_AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
        return None
|
|
|
|
def is_higher_better(metric_name) -> Union[bool, None]:
    """Return the ``higher_is_better`` value registered for ``metric_name``.

    Args:
        metric_name: Metric name to look up in ``HIGHER_IS_BETTER_REGISTRY``.

    Returns:
        The stored value, or ``None`` (after logging a warning) when nothing
        was registered for the metric.
    """
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        eval_logger.warning(
            f"higher_is_better not specified for metric '{metric_name}'!"
        )
        return None
|
|
|
|
def register_filter(name):
    """Return a class decorator that stores the class in ``FILTER_REGISTRY``.

    Registering a name that already exists is allowed: the earlier entry is
    replaced and an info-level message is logged.

    Args:
        name: Key under which the decorated class is stored.

    Returns:
        A decorator that registers the class and returns it unchanged.
    """

    def decorate(cls):
        already_registered = name in FILTER_REGISTRY
        if already_registered:
            # Overwriting is intentional; only leave a trace in the log.
            eval_logger.info(
                f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
            )

        FILTER_REGISTRY[name] = cls
        return cls

    return decorate
|
|
|
|
def get_filter(filter_name: Union[str, Callable]) -> Callable:
    """Resolve a filter from its registered name, or pass a callable through.

    Args:
        filter_name: A key of ``FILTER_REGISTRY`` or a callable.

    Returns:
        The registered entry when ``filter_name`` is a key of
        ``FILTER_REGISTRY``; otherwise ``filter_name`` itself if it is callable.

    Raises:
        KeyError: If ``filter_name`` is neither registered nor callable; a
            warning is logged before the error is re-raised.
    """
    try:
        registered = FILTER_REGISTRY[filter_name]
    except KeyError as e:
        if not callable(filter_name):
            eval_logger.warning(f"filter `{filter_name}` is not registered!")
            raise e
        # Unregistered callables are used directly as the filter.
        return filter_name
    else:
        return registered
|
|