Spaces:
Running
Running
verbose mode
Browse files
tasks.py
CHANGED
|
@@ -126,9 +126,10 @@ def multichoice(responses: Any, references: list[str]):
|
|
| 126 |
else:
|
| 127 |
responses = decode_choice(responses)
|
| 128 |
|
| 129 |
-
return [
|
| 130 |
-
|
| 131 |
-
]
|
|
|
|
| 132 |
|
| 133 |
|
| 134 |
class Metrics:
|
|
@@ -136,12 +137,18 @@ class Metrics:
|
|
| 136 |
mmlu = multichoice
|
| 137 |
|
| 138 |
def gsm8k(responses: list[str], answers: list[str | int]):
|
| 139 |
-
scores = []
|
| 140 |
-
for response, answer in zip(responses, answers):
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
def MATH(responses: list[str], answers: list[str]):
|
| 147 |
scores = []
|
|
@@ -445,7 +452,7 @@ class MMLU:
|
|
| 445 |
label_column=cls.label_column,
|
| 446 |
prompt=partial(cls.prompt_mmlu, chat=chat),
|
| 447 |
few_shot=0 if chat else 5,
|
| 448 |
-
few_shot_from="validation"
|
| 449 |
)
|
| 450 |
for subcategories in finer_categories[subject]
|
| 451 |
]
|
|
|
|
| 126 |
else:
|
| 127 |
responses = decode_choice(responses)
|
| 128 |
|
| 129 |
+
# return [
|
| 130 |
+
# int(response == reference) for reference, response in zip(references, responses)
|
| 131 |
+
# ]
|
| 132 |
+
return responses, references
|
| 133 |
|
| 134 |
|
| 135 |
class Metrics:
|
|
|
|
| 137 |
mmlu = multichoice
|
| 138 |
|
| 139 |
def gsm8k(responses: list[str], answers: list[str | int]):
    """Normalise GSM8K responses and gold answers into comparable values.

    Every response is passed through ``extract_numeric``. An answer that is
    a ``str`` is passed through ``extract_numeric`` too; any other answer is
    converted with ``str``.

    Returns:
        A ``(responses, answers)`` pair of lists in input order. No scoring
        is done here; the two lists are handed back for the caller to
        compare.
    """

    def normalise_answer(answer):
        # Only string answers are parsed; everything else is stringified.
        if isinstance(answer, str):
            return extract_numeric(answer)
        return str(answer)

    predictions = list(map(extract_numeric, responses))
    golds = list(map(normalise_answer, answers))
    return predictions, golds
|
| 152 |
|
| 153 |
def MATH(responses: list[str], answers: list[str]):
|
| 154 |
scores = []
|
|
|
|
| 452 |
label_column=cls.label_column,
|
| 453 |
prompt=partial(cls.prompt_mmlu, chat=chat),
|
| 454 |
few_shot=0 if chat else 5,
|
| 455 |
+
few_shot_from="validation",
|
| 456 |
)
|
| 457 |
for subcategories in finer_categories[subject]
|
| 458 |
]
|
tlem.py
CHANGED
|
@@ -6,6 +6,7 @@ except Exception as e:
|
|
| 6 |
import logging
|
| 7 |
|
| 8 |
from typing import Any, Optional, Protocol, Iterable, Callable
|
|
|
|
| 9 |
from tqdm.auto import tqdm
|
| 10 |
from evaluate.evaluation_suite import EvaluationSuite
|
| 11 |
import evaluate
|
|
@@ -70,10 +71,26 @@ class ReasoningMetric(evaluate.Metric):
|
|
| 70 |
reference_urls=["http://path.to.reference.url/new_module"],
|
| 71 |
)
|
| 72 |
|
| 73 |
-
def _compute(self, responses, references):
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
logging.info(results)
|
|
|
|
|
|
|
| 77 |
return results
|
| 78 |
|
| 79 |
|
|
|
|
| 6 |
import logging
|
| 7 |
|
| 8 |
from typing import Any, Optional, Protocol, Iterable, Callable
|
| 9 |
+
from numpy.lib import extract
|
| 10 |
from tqdm.auto import tqdm
|
| 11 |
from evaluate.evaluation_suite import EvaluationSuite
|
| 12 |
import evaluate
|
|
|
|
| 71 |
reference_urls=["http://path.to.reference.url/new_module"],
|
| 72 |
)
|
| 73 |
|
| 74 |
+
def _compute(self, responses, references, verbose=False):
    """Compute accuracy for the metric selected by ``self.config_name``.

    The function stored on ``Metrics`` under ``self.config_name`` is called
    with the raw responses and references and is unpacked as an
    ``(extracted_responses, extracted_references)`` pair.
    NOTE(review): this assumes every function on ``Metrics`` returns such a
    pair - confirm for each supported config.

    Args:
        responses: Raw model outputs.
        references: Raw reference answers.
        verbose: When true, the per-example DataFrame is added to the result
            under the ``"df"`` key.

    Returns:
        A dict with ``"Accuracy"`` (mean of the element-wise equality between
        extracted references and extracted responses) and, if ``verbose``,
        ``"df"``.
    """
    extractor = getattr(Metrics, self.config_name)
    parsed_responses, parsed_references = extractor(responses, references)

    # Keep raw and extracted values side by side so they can be inspected.
    df = pd.DataFrame({"responses": responses, "references": references})
    df["extract_responses"] = parsed_responses
    df["extract_references"] = parsed_references

    matches = df["extract_references"] == df["extract_responses"]
    results = {"Accuracy": matches.astype(int).mean()}
    logging.info(results)

    if verbose:
        results["df"] = df
    return results
|
| 95 |
|
| 96 |
|