| | |
| | |
| | |
| |
|
| | import dataclasses |
| | import pprint |
| | from functools import partial |
| | import os |
| | from tqdm import tqdm, trange |
| | import numpy as np |
| | import mlxu |
| |
|
| | from flax.traverse_util import flatten_dict |
| | from lm_eval import evaluator, tasks |
| | from lm_eval.base import LM |
| |
|
| | from EasyLM.serving import LMClient |
| |
|
| |
|
# Default value for every command-line flag this script defines. The dict is
# unpacked as keyword arguments, so the flags are registered in this order.
_FLAG_DEFAULTS = dict(
    # Comma-separated lm_eval task names; main() splits this on ','.
    tasks='wsc,piqa,winogrande,openbookqa,logiqa',
    # Passed to ``evaluator.evaluate`` as the few-shot count.
    shots=0,
    # Per-task example cap; main() passes ``limit=None`` when this is <= 0.
    limit=0,
    # Forwarded to ``evaluator.evaluate(write_out=...)``.
    write_out=False,
    # Default config handed to the EasyLM ``LMClient`` built in main().
    lm_client=LMClient.get_default_config(),
    # Default config for the ``mlxu.WandBLogger`` built in main().
    logger=mlxu.WandBLogger.get_default_config(),
)

FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(**_FLAG_DEFAULTS)
| |
|
| |
|
class LMEvalHarnessInterface(LM):
    """Adapter that exposes an EasyLM ``LMClient`` through lm-eval's ``LM`` API.

    Each request method unpacks the harness's per-request tuples into the
    parallel sequences ``LMClient`` takes. Where the client returns
    ``(loglikelihoods, is_greedy)`` sequences, the method zips them back into
    one ``(loglikelihood, is_greedy)`` pair per request.
    """

    def __init__(self, lm_client):
        """Wrap ``lm_client``; every request is forwarded to it."""
        # Run the base initializer so any state it sets up (in lm_eval 0.3.x,
        # ``self.cache_hook``) also exists on this adapter.
        super().__init__()
        self.lm_client = lm_client

    def greedy_until(self, inputs):
        """Generate greedily for each ``(prefix, until)`` request.

        Returns the result of ``lm_client.greedy_until`` unchanged. When
        ``inputs`` is empty, returns ``[]`` without contacting the client.
        """
        inputs = list(inputs)
        if not inputs:
            # ``zip(*[])`` produces nothing, so the two-name unpacking below
            # would raise ValueError on an empty batch.
            return []
        prefix, until = zip(*inputs)
        return self.lm_client.greedy_until(prefix, until)

    def loglikelihood_rolling(self, inputs):
        """Score each request's full text with the client's rolling loglikelihood.

        Returns one ``(loglikelihood, is_greedy)`` pair per request, or ``[]``
        for an empty batch.
        """
        inputs = list(inputs)
        if not inputs:
            return []
        # NOTE(review): ``inputs`` is forwarded as-is. If the harness passes
        # each request as its args tuple (e.g. ``(text,)``) rather than a bare
        # string, confirm LMClient.loglikelihood_rolling accepts that shape,
        # and that the harness expects pairs (not bare floats) back here.
        loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
        return list(zip(loglikelihood, is_greedy))

    def loglikelihood(self, inputs):
        """Score each ``(prefix, text)`` request as the continuation of its prefix.

        Returns one ``(loglikelihood, is_greedy)`` pair per request, or ``[]``
        for an empty batch.
        """
        inputs = list(inputs)
        if not inputs:
            # Same empty-batch hazard as in greedy_until.
            return []
        prefix, text = zip(*inputs)
        loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
        return list(zip(loglikelihood, is_greedy))
| |
|
| |
|
def main(argv):
    """Evaluate the configured lm_eval tasks against an EasyLM ``LMClient``.

    Wraps an ``LMClient`` in ``LMEvalHarnessInterface`` and runs every task
    named in ``--tasks`` through ``evaluator.evaluate``. The per-task metrics
    are flattened with '/' separators and logged via ``mlxu.WandBLogger``,
    then the full results dict is pretty-printed.

    Args:
        argv: positional arguments passed in by ``mlxu.run``; unused.

    Raises:
        ValueError: if ``--tasks`` contains no task names.
    """
    logger = mlxu.WandBLogger(
        config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    )
    model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))

    # Tolerate whitespace and stray commas: "wsc, piqa," -> ["wsc", "piqa"].
    # A bare split would yield " piqa" and "", neither of which is a task.
    task_list = [name.strip() for name in FLAGS.tasks.split(',') if name.strip()]
    if not task_list:
        raise ValueError('--tasks must name at least one lm_eval task.')

    # Keyword arguments throughout: the harness marks ``evaluate`` as
    # positional-deprecated.
    results = evaluator.evaluate(
        lm=model,
        task_dict=tasks.get_task_dict(task_list),
        provide_description=False,
        num_fewshot=FLAGS.shots,
        # A non-positive --limit means "evaluate every example".
        limit=FLAGS.limit if FLAGS.limit > 0 else None,
        write_out=FLAGS.write_out,
    )
    logger.log(flatten_dict(results['results'], sep='/'))
    pprint.pprint(results)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: hand ``main`` to mlxu's runner, which invokes it
    # with argv (presumably after flag parsing, as absl's app.run does).
    mlxu.run(main)
| |
|