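"""Evaluate a list of LLMs on the full task suite and log results to Weights & Biases.

For each model, the script instantiates it through the model factory, runs it on
every task, saves predictions and metrics under ./results, and logs both payloads
to the configured wandb project.
"""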
import argparse
import gc
import logging
from datetime import datetime
import torch
import wandb
from tqdm import tqdm
from predictions.all_llms import llms
from src import WANDB_PROJECT
from src.evaluation.llm_evaluator import ModelEvaluator
from src.evaluation.llm_factory import model_factory
from src.evaluation.tools import split_llm_list
from src.task.task_factory import tasks_factory
from src.task.task_names import Tasks
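
# Command-line options: test mode, example cap, model selection, batch size,
# and optional splitting/skipping of the model list.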
parser = argparse.ArgumentParser()
parser.add_argument(
    "--test",
    help="If set, run in test mode: only a small model and a few examples.",
    action="store_true",
)
parser.add_argument(
    "--max_examples",
    "-m",
    help="The maximum number of examples to use; defaults to None (all examples).",
    type=int,
    default=None,
)
parser.add_argument(
    "--models_name",
    "-mn",
    help="The model(s) to load: a group name from `llms` or a comma-separated list.",
    type=str,
    default=None,
)
parser.add_argument(
    "--batch_size",
    help="The batch size to use during the evaluation.",
    type=int,
    default=32,
)
parser.add_argument(
    "--llm_split",
    help="The split of the LLM list to use: 1, 2, or 3.",
    type=int,
    default=None,
    choices=[1, 2, 3],
)
parser.add_argument(
    "--skip_first_n",
    help="The number of models to skip at the start of the (split) list.",
    type=int,
    default=None,
)
args = parser.parse_args()
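
# Instantiate every task defined in the Tasks enum.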
tasks_names = list(Tasks)
tasks = tasks_factory(tasks_names)
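
# Resolve which models to evaluate: a named group from the `llms` registry,
# an explicit comma-separated list, or all registered models.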
models = []
if args.models_name is not None:
    if args.models_name in llms:
        models = llms[args.models_name]
    else:
        models = args.models_name.split(",")
else:
    models = llms["all"]
models = split_llm_list(models=models, llm_split=args.llm_split)
if args.skip_first_n is not None:
    models = models[args.skip_first_n:]
logging.info("Starting Evaluation")
time_start = datetime.now()
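
# Evaluate each model in turn; a failure on one model must not abort the sweep.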
for model_name in tqdm(
    models, total=len(models), desc="Processing LLM inference on tasks."
):
    try:
        logging.info("Creating model")
        model = model_factory(model_name, batch_size=args.batch_size)
        evaluator = ModelEvaluator()
        logging.info("Evaluating model")
        wandb.init(
            project=WANDB_PROJECT,
            entity="doctorate",
            config={
                "model_name": model_name,
                # Stringify in case Tasks members are not plain strings.
                "tasks": "; ".join(str(task) for task in tasks_names),
                "batch_size": args.batch_size,
            },
            name=model_name,
        )
        predictions_payload = evaluator.evaluate_subset(
            model, tasks, args.max_examples
        )
        wandb.log(predictions_payload)
        logging.info("Saving results")
        evaluator.save_results("./results")
        metrics_payload = evaluator.compute_metrics()
        evaluator.save_metrics("./results")
        wandb.log(metrics_payload)
    except Exception as e:
        logging.error(f"Evaluation failed for model {model_name}: {e}")
        wandb.finish(exit_code=1)
        continue
    finally:
        # Free the model and evaluator so the next iteration starts with
        # released GPU memory.
        if "model" in locals():
            del model
        if "evaluator" in locals():
            del evaluator
        gc.collect()
        torch.cuda.empty_cache()
    wandb.finish(exit_code=0)
time_end = datetime.now()
logging.info(f"End time: {time_end}")
elapsed_time = time_end - time_start
logging.info(f"Elapsed time: {elapsed_time}")