Spaces:
Running
Running
add mmlu and cmmlu
Browse files
tasks.py
CHANGED
|
@@ -3,19 +3,14 @@ from datasets import load_dataset, Dataset
|
|
| 3 |
from functools import cached_property
|
| 4 |
from tqdm.auto import tqdm
|
| 5 |
from typing import Any, Optional, Protocol, Iterable, Callable
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
from .utils import
|
| 8 |
-
NUMERIC_IN_ZH,
|
| 9 |
-
extract_choice_ans,
|
| 10 |
-
extract_numeric,
|
| 11 |
-
get_answer,
|
| 12 |
-
is_equiv,
|
| 13 |
-
)
|
| 14 |
|
| 15 |
from evaluate import load
|
| 16 |
|
| 17 |
-
TextGenerationPipeline = Callable[[Iterable[str]], list[str]]
|
| 18 |
-
|
| 19 |
|
| 20 |
def fake_pipeline(prompts: Iterable[str]) -> list[str]:
|
| 21 |
return [prompt for prompt in tqdm(prompts)]
|
|
@@ -30,14 +25,25 @@ class Task:
|
|
| 30 |
input_column: str = "question"
|
| 31 |
label_column: str = "answer"
|
| 32 |
prompt: Optional[Callable | str] = None
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
self.dataset_name
|
| 38 |
if isinstance(self.dataset_name, str)
|
| 39 |
-
else self.dataset_name
|
| 40 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
@cached_property
|
| 43 |
def samples(self):
|
|
@@ -49,20 +55,38 @@ class Task:
|
|
| 49 |
*self.dataset_name
|
| 50 |
if isinstance(self.dataset_name, tuple)
|
| 51 |
else self.dataset_name,
|
| 52 |
-
split=self.split,
|
| 53 |
)
|
|
|
|
| 54 |
if self.prompt is not None:
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
lambda example: {
|
| 57 |
-
self.input_column: self.
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
}
|
| 61 |
-
if isinstance(self.prompt, str)
|
| 62 |
-
else self.prompt(example),
|
| 63 |
)
|
| 64 |
|
| 65 |
-
return
|
| 66 |
|
| 67 |
@cached_property
|
| 68 |
def metric(self):
|
|
@@ -73,14 +97,44 @@ class Task:
|
|
| 73 |
)
|
| 74 |
return metric
|
| 75 |
|
| 76 |
-
def run(
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
class Metrics:
|
|
|
|
|
|
|
|
|
|
| 84 |
def gsm8k(responses: list[str], answers: list[str | int]):
|
| 85 |
scores = []
|
| 86 |
for response, answer in zip(responses, answers):
|
|
@@ -112,26 +166,287 @@ class Metrics:
|
|
| 112 |
scores.append(1.0 * (pred == gold))
|
| 113 |
return scores
|
| 114 |
|
| 115 |
-
def gsm8k_zh(responses: list[str], answers: list[str]):
|
| 116 |
-
scores = []
|
| 117 |
-
for response, answer in zip(responses, answers):
|
| 118 |
-
pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
|
| 119 |
-
gold = extract_numeric(answer)
|
| 120 |
-
scores.append(1.0 * (pred == gold))
|
| 121 |
-
return scores
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
return scores
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from functools import cached_property
|
| 4 |
from tqdm.auto import tqdm
|
| 5 |
from typing import Any, Optional, Protocol, Iterable, Callable
|
| 6 |
+
import logging
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from functools import partial
|
| 9 |
|
| 10 |
+
from .utils import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from evaluate import load
|
| 13 |
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def fake_pipeline(prompts: Iterable[str]) -> list[str]:
|
| 16 |
return [prompt for prompt in tqdm(prompts)]
|
|
|
|
| 25 |
input_column: str = "question"
|
| 26 |
label_column: str = "answer"
|
| 27 |
prompt: Optional[Callable | str] = None
|
| 28 |
+
few_shot: int = 0
|
| 29 |
+
few_shot_from: Optional[str] = None
|
| 30 |
+
# results: dict[str, Any] = field(default_factory=dict)
|
| 31 |
|
| 32 |
+
def __post_init__(self):
|
| 33 |
+
names = (
|
| 34 |
+
[self.dataset_name]
|
|
|
|
| 35 |
if isinstance(self.dataset_name, str)
|
| 36 |
+
else list(self.dataset_name)
|
| 37 |
+
)
|
| 38 |
+
names[0] = names[0].split("/")[-1]
|
| 39 |
+
|
| 40 |
+
self.name = "-".join(names) + f"-{self.split}"
|
| 41 |
+
if isinstance(self.prompt, str):
|
| 42 |
+
self.prompt = lambda example: {
|
| 43 |
+
self.input_column: self.prompt.format(
|
| 44 |
+
input_column=example[self.input_column]
|
| 45 |
+
)
|
| 46 |
+
}
|
| 47 |
|
| 48 |
@cached_property
|
| 49 |
def samples(self):
|
|
|
|
| 55 |
*self.dataset_name
|
| 56 |
if isinstance(self.dataset_name, tuple)
|
| 57 |
else self.dataset_name,
|
| 58 |
+
# split=self.split,
|
| 59 |
)
|
| 60 |
+
test_ds = ds[self.split]
|
| 61 |
if self.prompt is not None:
|
| 62 |
+
test_ds = test_ds.map(self.prompt)
|
| 63 |
+
|
| 64 |
+
if self.few_shot:
|
| 65 |
+
if self.few_shot_from is None:
|
| 66 |
+
for name in ["train", "validation", "val", "dev"]:
|
| 67 |
+
if name in ds:
|
| 68 |
+
self.few_shot_from = name
|
| 69 |
+
break
|
| 70 |
+
|
| 71 |
+
shots = ds[self.few_shot_from].select(range(self.few_shot))
|
| 72 |
+
if self.prompt is not None:
|
| 73 |
+
shots = shots.map(self.prompt)
|
| 74 |
+
|
| 75 |
+
shots = shots.map(
|
| 76 |
lambda example: {
|
| 77 |
+
self.input_column: example[self.input_column]
|
| 78 |
+
+ example[self.label_column],
|
| 79 |
+
}
|
| 80 |
+
)[self.input_column]
|
| 81 |
+
few_shot_prompts = "\n".join(shots)
|
| 82 |
+
|
| 83 |
+
test_ds = test_ds.map(
|
| 84 |
+
lambda example: {
|
| 85 |
+
self.input_column: few_shot_prompts + example[self.input_column],
|
| 86 |
}
|
|
|
|
|
|
|
| 87 |
)
|
| 88 |
|
| 89 |
+
return test_ds
|
| 90 |
|
| 91 |
@cached_property
|
| 92 |
def metric(self):
|
|
|
|
| 97 |
)
|
| 98 |
return metric
|
| 99 |
|
| 100 |
+
def run(
|
| 101 |
+
self,
|
| 102 |
+
pipeline,
|
| 103 |
+
):
|
| 104 |
+
if (outputs := pipeline(self.samples)) is None:
|
| 105 |
+
logging.warning("pipeline returns None")
|
| 106 |
+
return
|
| 107 |
+
self.outputs = outputs
|
| 108 |
+
try:
|
| 109 |
+
result = self.metric._compute(
|
| 110 |
+
responses=outputs, references=self.dataset[self.label_column]
|
| 111 |
+
)
|
| 112 |
+
except Exception as e:
|
| 113 |
+
result = self.metric.compute(
|
| 114 |
+
responses=outputs, references=self.dataset[self.label_column]
|
| 115 |
+
)
|
| 116 |
+
# if log:
|
| 117 |
+
# name = name or pipeline.__name__
|
| 118 |
+
# self.results[name] = result
|
| 119 |
+
|
| 120 |
+
return result
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def multichoice(responses: Any, references: list[str]):
|
| 124 |
+
if isinstance(responses[0], str):
|
| 125 |
+
responses = [extract_choice(response) for response in responses]
|
| 126 |
+
else:
|
| 127 |
+
responses = decode_choice(responses)
|
| 128 |
+
|
| 129 |
+
return [
|
| 130 |
+
int(response == reference) for reference, response in zip(references, responses)
|
| 131 |
+
]
|
| 132 |
|
| 133 |
|
| 134 |
class Metrics:
|
| 135 |
+
cmmlu = multichoice
|
| 136 |
+
mmlu = multichoice
|
| 137 |
+
|
| 138 |
def gsm8k(responses: list[str], answers: list[str | int]):
|
| 139 |
scores = []
|
| 140 |
for response, answer in zip(responses, answers):
|
|
|
|
| 166 |
scores.append(1.0 * (pred == gold))
|
| 167 |
return scores
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
class CMMLU:
|
| 171 |
+
def prompt_cmmlu(example, chat=False):
|
| 172 |
+
prefix = "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n" if chat else "问题:"
|
| 173 |
+
prompt = prefix + example["Question"]
|
| 174 |
+
for choice in list("ABCD"):
|
| 175 |
+
prompt += f"\n{choice}. {example[choice]}"
|
|
|
|
| 176 |
|
| 177 |
+
prompt += "\n答案:"
|
| 178 |
+
return {"prompt": prompt}
|
| 179 |
+
|
| 180 |
+
subcategories = {
|
| 181 |
+
"agronomy": ["other"],
|
| 182 |
+
"anatomy": ["biology"],
|
| 183 |
+
"ancient_chinese": ["linguistics", "china specific"],
|
| 184 |
+
"arts": ["arts"],
|
| 185 |
+
"astronomy": ["physics"],
|
| 186 |
+
"business_ethics": ["business"],
|
| 187 |
+
"chinese_civil_service_exam": ["politics", "china specific"],
|
| 188 |
+
"chinese_driving_rule": ["other", "china specific"],
|
| 189 |
+
"chinese_food_culture": ["culture", "china specific"],
|
| 190 |
+
"chinese_foreign_policy": ["politics", "china specific"],
|
| 191 |
+
"chinese_history": ["history", "china specific"],
|
| 192 |
+
"chinese_literature": ["literature", "china specific"],
|
| 193 |
+
"chinese_teacher_qualification": ["education", "china specific"],
|
| 194 |
+
"college_actuarial_science": ["math"],
|
| 195 |
+
"college_education": ["education"],
|
| 196 |
+
"college_engineering_hydrology": ["engineering"],
|
| 197 |
+
"college_law": ["law"],
|
| 198 |
+
"college_mathematics": ["math"],
|
| 199 |
+
"college_medical_statistics": ["statistics"],
|
| 200 |
+
"clinical_knowledge": ["other"],
|
| 201 |
+
"college_medicine": ["other"],
|
| 202 |
+
"computer_science": ["computer science"],
|
| 203 |
+
"computer_security": ["other"],
|
| 204 |
+
"conceptual_physics": ["physics"],
|
| 205 |
+
"construction_project_management": ["other", "china specific"],
|
| 206 |
+
"economics": ["economics"],
|
| 207 |
+
"education": ["education"],
|
| 208 |
+
"elementary_chinese": ["linguistics", "china specific"],
|
| 209 |
+
"elementary_commonsense": ["other", "china specific"],
|
| 210 |
+
"elementary_information_and_technology": ["other"],
|
| 211 |
+
"electrical_engineering": ["engineering"],
|
| 212 |
+
"elementary_mathematics": ["math"],
|
| 213 |
+
"ethnology": ["culture", "china specific"],
|
| 214 |
+
"food_science": ["other"],
|
| 215 |
+
"genetics": ["biology"],
|
| 216 |
+
"global_facts": ["global"],
|
| 217 |
+
"high_school_biology": ["biology"],
|
| 218 |
+
"high_school_chemistry": ["chemistry"],
|
| 219 |
+
"high_school_geography": ["geography"],
|
| 220 |
+
"high_school_mathematics": ["math"],
|
| 221 |
+
"high_school_physics": ["physics"],
|
| 222 |
+
"high_school_politics": ["politics", "china specific"],
|
| 223 |
+
"human_sexuality": ["other"],
|
| 224 |
+
"international_law": ["law"],
|
| 225 |
+
"journalism": ["sociology"],
|
| 226 |
+
"jurisprudence": ["law"],
|
| 227 |
+
"legal_and_moral_basis": ["other"],
|
| 228 |
+
"logical": ["philosophy"],
|
| 229 |
+
"machine_learning": ["computer science"],
|
| 230 |
+
"management": ["business"],
|
| 231 |
+
"marketing": ["business"],
|
| 232 |
+
"marxist_theory": ["philosophy"],
|
| 233 |
+
"modern_chinese": ["linguistics", "china specific"],
|
| 234 |
+
"nutrition": ["other"],
|
| 235 |
+
"philosophy": ["philosophy"],
|
| 236 |
+
"professional_accounting": ["business"],
|
| 237 |
+
"professional_law": ["law"],
|
| 238 |
+
"professional_medicine": ["other"],
|
| 239 |
+
"professional_psychology": ["psychology"],
|
| 240 |
+
"public_relations": ["politics"],
|
| 241 |
+
"security_study": ["politics"],
|
| 242 |
+
"sociology": ["culture"],
|
| 243 |
+
"sports_science": ["other"],
|
| 244 |
+
"traditional_chinese_medicine": ["other", "china specific"],
|
| 245 |
+
"virology": ["biology"],
|
| 246 |
+
"world_history": ["history"],
|
| 247 |
+
"world_religions": ["global"],
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
categories = {
|
| 251 |
+
"STEM": [
|
| 252 |
+
"physics",
|
| 253 |
+
"chemistry",
|
| 254 |
+
"biology",
|
| 255 |
+
"computer science",
|
| 256 |
+
"math",
|
| 257 |
+
"engineering",
|
| 258 |
+
"statistics",
|
| 259 |
+
],
|
| 260 |
+
"Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
|
| 261 |
+
"Social Science": [
|
| 262 |
+
"linguistics",
|
| 263 |
+
"business",
|
| 264 |
+
"politics",
|
| 265 |
+
"culture",
|
| 266 |
+
"economics",
|
| 267 |
+
"geography",
|
| 268 |
+
"psychology",
|
| 269 |
+
"education",
|
| 270 |
+
"sociology",
|
| 271 |
+
],
|
| 272 |
+
"Other": ["other"],
|
| 273 |
+
"China specific": ["china specific"],
|
| 274 |
+
"Test": ["computer science"],
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
finer_categories = (
|
| 278 |
+
pd.Series(subcategories) # noqa # type: ignore
|
| 279 |
+
.explode()
|
| 280 |
+
.reset_index()
|
| 281 |
+
.set_index(0)
|
| 282 |
+
.groupby(0)
|
| 283 |
+
.agg(list)["index"]
|
| 284 |
+
.to_dict()
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
@classmethod
|
| 288 |
+
def suite(cls, chat=False):
|
| 289 |
+
suite = {}
|
| 290 |
+
for k, v in cls.categories.items():
|
| 291 |
+
for subject in v:
|
| 292 |
+
suite[k] = [
|
| 293 |
+
Task(
|
| 294 |
+
("haonan-li/cmmlu", subcategories),
|
| 295 |
+
metric_name=("sustech/tlem", "cmmlu"),
|
| 296 |
+
input_column="prompt",
|
| 297 |
+
label_column="Answer",
|
| 298 |
+
prompt=partial(cls.prompt_cmmlu, chat=chat),
|
| 299 |
+
)
|
| 300 |
+
for subcategories in cls.finer_categories[subject]
|
| 301 |
+
]
|
| 302 |
+
return suite
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class MMLU:
|
| 306 |
+
input_column = "prompt"
|
| 307 |
+
label_column = "target"
|
| 308 |
+
|
| 309 |
+
@classmethod
|
| 310 |
+
def prompt_mmlu(cls, example, chat=False):
|
| 311 |
+
prefix = (
|
| 312 |
+
"The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
|
| 313 |
+
if chat
|
| 314 |
+
else "Question: "
|
| 315 |
+
)
|
| 316 |
+
prompt = prefix + example["input"]
|
| 317 |
+
for choice in list("ABCD"):
|
| 318 |
+
prompt += f"\n{choice}. {example[choice]}"
|
| 319 |
+
|
| 320 |
+
prompt += "\nAnswer:"
|
| 321 |
+
return {"prompt": prompt}
|
| 322 |
+
|
| 323 |
+
subcategories = {
|
| 324 |
+
"abstract_algebra": ["math"],
|
| 325 |
+
"anatomy": ["health"],
|
| 326 |
+
"astronomy": ["physics"],
|
| 327 |
+
"business_ethics": ["business"],
|
| 328 |
+
"clinical_knowledge": ["health"],
|
| 329 |
+
"college_biology": ["biology"],
|
| 330 |
+
"college_chemistry": ["chemistry"],
|
| 331 |
+
"college_computer_science": ["computer science"],
|
| 332 |
+
"college_mathematics": ["math"],
|
| 333 |
+
"college_medicine": ["health"],
|
| 334 |
+
"college_physics": ["physics"],
|
| 335 |
+
"computer_security": ["computer science"],
|
| 336 |
+
"conceptual_physics": ["physics"],
|
| 337 |
+
"econometrics": ["economics"],
|
| 338 |
+
"electrical_engineering": ["engineering"],
|
| 339 |
+
"elementary_mathematics": ["math"],
|
| 340 |
+
"formal_logic": ["philosophy"],
|
| 341 |
+
"global_facts": ["other"],
|
| 342 |
+
"high_school_biology": ["biology"],
|
| 343 |
+
"high_school_chemistry": ["chemistry"],
|
| 344 |
+
"high_school_computer_science": ["computer science"],
|
| 345 |
+
"high_school_european_history": ["history"],
|
| 346 |
+
"high_school_geography": ["geography"],
|
| 347 |
+
"high_school_government_and_politics": ["politics"],
|
| 348 |
+
"high_school_macroeconomics": ["economics"],
|
| 349 |
+
"high_school_mathematics": ["math"],
|
| 350 |
+
"high_school_microeconomics": ["economics"],
|
| 351 |
+
"high_school_physics": ["physics"],
|
| 352 |
+
"high_school_psychology": ["psychology"],
|
| 353 |
+
"high_school_statistics": ["math"],
|
| 354 |
+
"high_school_us_history": ["history"],
|
| 355 |
+
"high_school_world_history": ["history"],
|
| 356 |
+
"human_aging": ["health"],
|
| 357 |
+
"human_sexuality": ["culture"],
|
| 358 |
+
"international_law": ["law"],
|
| 359 |
+
"jurisprudence": ["law"],
|
| 360 |
+
"logical_fallacies": ["philosophy"],
|
| 361 |
+
"machine_learning": ["computer science"],
|
| 362 |
+
"management": ["business"],
|
| 363 |
+
"marketing": ["business"],
|
| 364 |
+
"medical_genetics": ["health"],
|
| 365 |
+
"miscellaneous": ["other"],
|
| 366 |
+
"moral_disputes": ["philosophy"],
|
| 367 |
+
"moral_scenarios": ["philosophy"],
|
| 368 |
+
"nutrition": ["health"],
|
| 369 |
+
"philosophy": ["philosophy"],
|
| 370 |
+
"prehistory": ["history"],
|
| 371 |
+
"professional_accounting": ["other"],
|
| 372 |
+
"professional_law": ["law"],
|
| 373 |
+
"professional_medicine": ["health"],
|
| 374 |
+
"professional_psychology": ["psychology"],
|
| 375 |
+
"public_relations": ["politics"],
|
| 376 |
+
"security_studies": ["politics"],
|
| 377 |
+
"sociology": ["culture"],
|
| 378 |
+
"us_foreign_policy": ["politics"],
|
| 379 |
+
"virology": ["health"],
|
| 380 |
+
"world_religions": ["philosophy"],
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
categories = {
|
| 384 |
+
"Math": [
|
| 385 |
+
"math",
|
| 386 |
+
],
|
| 387 |
+
"STEM": [
|
| 388 |
+
"physics",
|
| 389 |
+
"chemistry",
|
| 390 |
+
"biology",
|
| 391 |
+
"computer science",
|
| 392 |
+
"math",
|
| 393 |
+
"engineering",
|
| 394 |
+
],
|
| 395 |
+
"humanities": ["history", "philosophy", "law"],
|
| 396 |
+
"social sciences": [
|
| 397 |
+
"politics",
|
| 398 |
+
"culture",
|
| 399 |
+
"economics",
|
| 400 |
+
"geography",
|
| 401 |
+
"psychology",
|
| 402 |
+
],
|
| 403 |
+
"Other": ["other", "business", "health"],
|
| 404 |
+
"All": [
|
| 405 |
+
"physics",
|
| 406 |
+
"chemistry",
|
| 407 |
+
"biology",
|
| 408 |
+
"computer science",
|
| 409 |
+
"math",
|
| 410 |
+
"engineering",
|
| 411 |
+
"history",
|
| 412 |
+
"philosophy",
|
| 413 |
+
"law",
|
| 414 |
+
"politics",
|
| 415 |
+
"culture",
|
| 416 |
+
"economics",
|
| 417 |
+
"geography",
|
| 418 |
+
"psychology",
|
| 419 |
+
"other",
|
| 420 |
+
"business",
|
| 421 |
+
"health",
|
| 422 |
+
],
|
| 423 |
+
"Test": ["culture"],
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
@classmethod
|
| 427 |
+
def suite(cls, chat=False):
|
| 428 |
+
finer_categories = (
|
| 429 |
+
pd.Series(cls.subcategories) # noqa # type: ignore
|
| 430 |
+
.explode()
|
| 431 |
+
.reset_index()
|
| 432 |
+
.set_index(0)
|
| 433 |
+
.groupby(0)
|
| 434 |
+
.agg(list)["index"]
|
| 435 |
+
.to_dict()
|
| 436 |
+
)
|
| 437 |
+
suite = {}
|
| 438 |
+
for k, v in cls.categories.items():
|
| 439 |
+
for subject in v:
|
| 440 |
+
suite[k] = [
|
| 441 |
+
Task(
|
| 442 |
+
("lukaemon/mmlu", subcategories),
|
| 443 |
+
metric_name=("sustech/tlem", "mmlu"),
|
| 444 |
+
input_column=cls.input_column,
|
| 445 |
+
label_column=cls.label_column,
|
| 446 |
+
prompt=partial(cls.prompt_mmlu, chat=chat),
|
| 447 |
+
few_shot=0 if chat else 5,
|
| 448 |
+
few_shot_from="validation"
|
| 449 |
+
)
|
| 450 |
+
for subcategories in finer_categories[subject]
|
| 451 |
+
]
|
| 452 |
+
return suite
|
tlem.py
CHANGED
|
@@ -11,7 +11,8 @@ from evaluate.evaluation_suite import EvaluationSuite
|
|
| 11 |
import evaluate
|
| 12 |
import numpy as np
|
| 13 |
import datasets
|
| 14 |
-
|
|
|
|
| 15 |
from .utils import is_equiv
|
| 16 |
|
| 17 |
# %%
|
|
@@ -24,56 +25,35 @@ from .utils import is_equiv
|
|
| 24 |
|
| 25 |
# TODO: Add BibTeX citation
|
| 26 |
_CITATION = """\
|
| 27 |
-
@InProceedings{huggingface:module,
|
| 28 |
-
title = {A great new module},
|
| 29 |
-
authors={huggingface, Inc.},
|
| 30 |
-
year={2020}
|
| 31 |
-
}
|
| 32 |
"""
|
| 33 |
|
| 34 |
# TODO: Add description of the module here
|
| 35 |
_DESCRIPTION = """\
|
| 36 |
-
A simple measurement that returns the number of elements in dataset.
|
| 37 |
"""
|
| 38 |
|
| 39 |
|
| 40 |
# TODO: Add description of the arguments of the module here
|
| 41 |
_KWARGS_DESCRIPTION = """
|
| 42 |
-
Calculates number of elements in dataset
|
| 43 |
-
Args:
|
| 44 |
-
data: list of elements.
|
| 45 |
-
Returns:
|
| 46 |
-
element_count: number of elements in dataset,
|
| 47 |
-
Examples:
|
| 48 |
-
>>> measure = evaluate.load("lvwerra/element_count")
|
| 49 |
-
>>> measure.compute(["a", "b", "c")
|
| 50 |
-
{"element_count": 3}
|
| 51 |
"""
|
| 52 |
|
| 53 |
# TODO: Define external resources urls if needed
|
| 54 |
BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
|
| 55 |
|
| 56 |
|
| 57 |
-
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
| 58 |
class ReasoningMetric(evaluate.Metric):
|
| 59 |
"""TODO: Short description of my evaluation module."""
|
| 60 |
|
| 61 |
def _info(self):
|
|
|
|
| 62 |
features = datasets.Features(
|
| 63 |
{
|
| 64 |
"responses": datasets.Value("string"),
|
|
|
|
| 65 |
"references": datasets.Value("string"),
|
| 66 |
}
|
| 67 |
)
|
| 68 |
|
| 69 |
-
if self.config_name == "svamp":
|
| 70 |
-
features = datasets.Features(
|
| 71 |
-
{
|
| 72 |
-
"responses": datasets.Value("string"),
|
| 73 |
-
"references": datasets.Value("float"),
|
| 74 |
-
}
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
| 78 |
return evaluate.EvaluationModuleInfo(
|
| 79 |
# This is the description that will appear on the modules page.
|
|
@@ -90,38 +70,59 @@ class ReasoningMetric(evaluate.Metric):
|
|
| 90 |
reference_urls=["http://path.to.reference.url/new_module"],
|
| 91 |
)
|
| 92 |
|
| 93 |
-
def _compute(self, responses, references
|
| 94 |
-
results = {}
|
| 95 |
scores = getattr(Metrics, self.config_name)(responses, references)
|
| 96 |
-
|
| 97 |
-
results
|
| 98 |
-
"accuracy": acc,
|
| 99 |
-
"scores": scores,
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
if verbose:
|
| 103 |
-
results["references"] = references
|
| 104 |
-
results["answers"] = responses
|
| 105 |
-
# results["scores"] = scores
|
| 106 |
-
|
| 107 |
return results
|
| 108 |
|
| 109 |
|
| 110 |
class Suite(EvaluationSuite):
|
| 111 |
def run(
|
| 112 |
-
self,
|
|
|
|
|
|
|
| 113 |
) -> dict[str, float]:
|
| 114 |
self.assert_suite_nonempty()
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
super().__init__(name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
self.suite = [
|
| 127 |
Task(
|
|
@@ -136,4 +137,3 @@ class Suite(EvaluationSuite):
|
|
| 136 |
|
| 137 |
|
| 138 |
# %%
|
| 139 |
-
|
|
|
|
| 11 |
import evaluate
|
| 12 |
import numpy as np
|
| 13 |
import datasets
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from .tasks import *
|
| 16 |
from .utils import is_equiv
|
| 17 |
|
| 18 |
# %%
|
|
|
|
| 25 |
|
| 26 |
# TODO: Add BibTeX citation
|
| 27 |
_CITATION = """\
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
"""
|
| 29 |
|
| 30 |
# TODO: Add description of the module here
|
| 31 |
_DESCRIPTION = """\
|
|
|
|
| 32 |
"""
|
| 33 |
|
| 34 |
|
| 35 |
# TODO: Add description of the arguments of the module here
|
| 36 |
_KWARGS_DESCRIPTION = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
"""
|
| 38 |
|
| 39 |
# TODO: Define external resources urls if needed
|
| 40 |
BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
|
| 41 |
|
| 42 |
|
| 43 |
+
# @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
|
| 44 |
class ReasoningMetric(evaluate.Metric):
|
| 45 |
"""TODO: Short description of my evaluation module."""
|
| 46 |
|
| 47 |
def _info(self):
|
| 48 |
+
# if self.config_name in ["cmmlu"]:
|
| 49 |
features = datasets.Features(
|
| 50 |
{
|
| 51 |
"responses": datasets.Value("string"),
|
| 52 |
+
# "responses": datasets.Sequence(datasets.Value("float")),
|
| 53 |
"references": datasets.Value("string"),
|
| 54 |
}
|
| 55 |
)
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# TODO: Specifies the evaluate.EvaluationModuleInfo object
|
| 58 |
return evaluate.EvaluationModuleInfo(
|
| 59 |
# This is the description that will appear on the modules page.
|
|
|
|
| 70 |
reference_urls=["http://path.to.reference.url/new_module"],
|
| 71 |
)
|
| 72 |
|
| 73 |
+
def _compute(self, responses, references):
|
|
|
|
| 74 |
scores = getattr(Metrics, self.config_name)(responses, references)
|
| 75 |
+
results = {"Accuracy": np.nanmean(scores)}
|
| 76 |
+
logging.info(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
return results
|
| 78 |
|
| 79 |
|
| 80 |
class Suite(EvaluationSuite):
|
| 81 |
def run(
|
| 82 |
+
self,
|
| 83 |
+
model_or_pipeline: Any,
|
| 84 |
+
name="tlem",
|
| 85 |
) -> dict[str, float]:
|
| 86 |
self.assert_suite_nonempty()
|
| 87 |
|
| 88 |
+
def run_tasks(tasks):
|
| 89 |
+
for task in tqdm(tasks):
|
| 90 |
+
if task.name not in self.cached_result:
|
| 91 |
+
self.cached_result[task.name] = task.run(model_or_pipeline)
|
| 92 |
+
results = [self.cached_result[task.name] for task in tasks]
|
| 93 |
+
return pd.DataFrame(results).mean().to_dict()
|
| 94 |
+
|
| 95 |
+
if isinstance(self.suite, dict):
|
| 96 |
+
for category, tasks in tqdm(self.suite.items()):
|
| 97 |
+
logging.warning(f"Combined results: {category}:{run_tasks(tasks)}")
|
| 98 |
+
else:
|
| 99 |
+
logging.warning(f"Combined results: {run_tasks(self.suite)}")
|
| 100 |
+
|
| 101 |
+
return self.cached_result
|
| 102 |
+
|
| 103 |
+
def add(self, name):
|
| 104 |
+
chat = False
|
| 105 |
+
match name:
|
| 106 |
+
case _ if "chat" in name:
|
| 107 |
+
chat = True
|
| 108 |
+
match name:
|
| 109 |
+
case _ if name.startswith("mmlu"):
|
| 110 |
+
suite = MMLU.suite(chat=chat)
|
| 111 |
+
case _ if name.startswith("cmmlu"):
|
| 112 |
+
suite = CMMLU.suite(chat=chat)
|
| 113 |
+
match name:
|
| 114 |
+
case _ if "test" in name:
|
| 115 |
+
suite = suite["Test"]
|
| 116 |
+
|
| 117 |
+
self.suite = suite
|
| 118 |
+
|
| 119 |
+
def __init__(self, name="tlem"):
|
| 120 |
super().__init__(name)
|
| 121 |
+
self.cached_result = {}
|
| 122 |
+
|
| 123 |
+
match self.name:
|
| 124 |
+
case "cmmlu":
|
| 125 |
+
pass
|
| 126 |
|
| 127 |
self.suite = [
|
| 128 |
Task(
|
|
|
|
| 137 |
|
| 138 |
|
| 139 |
# %%
|
|
|
utils.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import logging
|
| 2 |
import re
|
|
|
|
|
|
|
| 3 |
|
| 4 |
NUMERIC_IN_EN = r"(?:[\s=+-/<>($:\.\*\\])(?=\S)((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)(?:(?![^\s=+-/>)$:\.\*\\])|(?=, ))"
|
| 5 |
NUMERIC_IN_ZH = (
|
|
@@ -7,17 +9,43 @@ NUMERIC_IN_ZH = (
|
|
| 7 |
)
|
| 8 |
|
| 9 |
|
| 10 |
-
def
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
def standardize(ans):
|
| 18 |
-
return ans if len(ans) == 1 else ans[1]
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def extract_numeric(string, pattern=NUMERIC_IN_EN) -> str:
|
|
|
|
| 1 |
import logging
|
| 2 |
import re
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Any
|
| 5 |
|
| 6 |
NUMERIC_IN_EN = r"(?:[\s=+-/<>($:\.\*\\])(?=\S)((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)(?:(?![^\s=+-/>)$:\.\*\\])|(?=, ))"
|
| 7 |
NUMERIC_IN_ZH = (
|
|
|
|
| 9 |
)
|
| 10 |
|
| 11 |
|
| 12 |
+
def extract_choice(gen):
|
| 13 |
+
# answer is A | choice is A | choose A
|
| 14 |
+
res = re.search(
|
| 15 |
+
r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b",
|
| 16 |
+
gen,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# A is correct | A is right
|
| 20 |
+
if res is None:
|
| 21 |
+
res = re.search(
|
| 22 |
+
r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b",
|
| 23 |
+
gen,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# straight answer: A
|
| 27 |
+
if res is None:
|
| 28 |
+
res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen)
|
| 29 |
+
|
| 30 |
+
# simply extract the first appearred letter
|
| 31 |
+
if res is None:
|
| 32 |
+
res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)
|
| 33 |
+
|
| 34 |
+
if res is None:
|
| 35 |
+
res = "A"
|
| 36 |
+
|
| 37 |
+
if isinstance(res, str):
|
| 38 |
+
return res
|
| 39 |
+
|
| 40 |
+
return res.group(1)
|
| 41 |
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
def decode_choice(responses: list[Any]):
|
| 44 |
+
num_choices = responses[0].shape[0]
|
| 45 |
+
choices = np.argmax(np.asarray(responses), axis=1)
|
| 46 |
+
responses = np.array(list("ABCDEFGHIJKL"[:num_choices]))[choices]
|
| 47 |
+
# return (responses == np.array(references)).mean()
|
| 48 |
+
return responses
|
| 49 |
|
| 50 |
|
| 51 |
def extract_numeric(string, pattern=NUMERIC_IN_EN) -> str:
|