Spaces:
Running
Running
| from dataclasses import dataclass, field | |
| from datasets import load_dataset, Dataset | |
| from functools import cached_property | |
| from tqdm.auto import tqdm | |
| from typing import Any, Optional, Protocol, Iterable, Callable | |
| import logging | |
| import pandas as pd | |
| from functools import partial | |
| from .utils import * | |
| from evaluate import load | |
| from collections import defaultdict | |
def fake_pipeline(prompts: Iterable[str]) -> list[str]:
    """Echo every prompt back unchanged, showing a progress bar while iterating.

    Stands in for a real generation pipeline: the "response" to each prompt
    is the prompt itself.
    """
    return list(tqdm(prompts))
| class Task: | |
| dataset_name: str | tuple[str, str] = ("gsm8k", "main") | |
| split: str = "test" | |
| # metrics: list[str] = field(default_factory=list) | |
| metric_name: str | tuple[str, str] = ("sustech/tlem", "gsm8k") | |
| input_column: str = "question" | |
| label_column: str = "answer" | |
| prompt: Optional[Callable | str] = None | |
| few_shot: int = 0 | |
| few_shot_from: Optional[str] = None | |
| # results: dict[str, Any] = field(default_factory=dict) | |
| def __post_init__(self): | |
| names = ( | |
| [self.dataset_name] | |
| if isinstance(self.dataset_name, str) | |
| else list(self.dataset_name) | |
| ) | |
| names[0] = names[0].split("/")[-1] | |
| self.name = "-".join(names) + f"-{self.split}" | |
| if isinstance(self.prompt, str): | |
| self.prompt = lambda example: { | |
| self.input_column: self.prompt.format( | |
| input_column=example[self.input_column] | |
| ) | |
| } | |
| def samples(self): | |
| return self.dataset[self.input_column] | |
| def dataset(self): | |
| ds = load_dataset( | |
| *self.dataset_name | |
| if isinstance(self.dataset_name, tuple) | |
| else self.dataset_name, | |
| # split=self.split, | |
| ) | |
| test_ds = ds[self.split] | |
| if self.prompt is not None: | |
| test_ds = test_ds.map(self.prompt) | |
| if self.few_shot: | |
| if self.few_shot_from is None: | |
| for name in ["train", "validation", "val", "dev"]: | |
| if name in ds: | |
| self.few_shot_from = name | |
| break | |
| shots = ds[self.few_shot_from].select(range(self.few_shot)) | |
| if self.prompt is not None: | |
| shots = shots.map(self.prompt) | |
| shots = shots.map( | |
| lambda example: { | |
| self.input_column: example[self.input_column] | |
| + example[self.label_column], | |
| } | |
| )[self.input_column] | |
| few_shot_prompts = "\n\n".join(shots) | |
| test_ds = test_ds.map( | |
| lambda example: { | |
| self.input_column: few_shot_prompts | |
| + "\n\n" | |
| + example[self.input_column], | |
| } | |
| ) | |
| return test_ds | |
| def metric(self): | |
| metric = ( | |
| load(self.metric_name) | |
| if isinstance(self.metric_name, str) | |
| else load(*self.metric_name) | |
| ) | |
| return metric | |
| def run( | |
| self, | |
| pipeline, | |
| ): | |
| if (outputs := pipeline(self.samples)) is None: | |
| logging.warning("pipeline returns None") | |
| return | |
| self.outputs = outputs | |
| try: | |
| result = self.metric._compute( | |
| responses=outputs, references=self.dataset[self.label_column] | |
| ) | |
| except Exception as e: | |
| result = self.metric.compute( | |
| responses=outputs, references=self.dataset[self.label_column] | |
| ) | |
| # if log: | |
| # name = name or pipeline.__name__ | |
| # self.results[name] = result | |
| return result | |
def multichoice(responses: Any, references: list[str]):
    """Normalise multiple-choice responses and pair them with the references.

    String responses are passed one by one through ``extract_choice``; any
    other kind of ``responses`` is handed as a whole to ``decode_choice``.
    The type is decided from the first element only.

    Returns:
        A ``(responses, references)`` tuple; ``references`` is returned as given.
    """
    if not isinstance(responses[0], str):
        return decode_choice(responses), references
    return [extract_choice(text) for text in responses], references
class Metrics:
    """Namespace of per-benchmark answer-extraction and scoring functions.

    The functions take no ``self`` and are meant to be called through the
    class, e.g. ``Metrics.gsm8k(responses, answers)``.
    """

    cmmlu = multichoice
    mmlu = multichoice

    def gsm8k(responses: list[str], answers: list[str | int]):
        """Return ``(predictions, golds)`` with numeric answers extracted.

        String answers go through ``extract_numeric``; non-string answers are
        converted with ``str``.
        """
        predictions = [extract_numeric(text) for text in responses]
        golds = []
        for answer in answers:
            if isinstance(answer, str):
                golds.append(extract_numeric(answer))
            else:
                golds.append(str(answer))
        return predictions, golds

    def MATH(responses: list[str], answers: list[str]):
        """Score each response by the text between its last two ``$`` signs."""
        scores = []
        for response, answer in zip(responses, answers):
            dollar_positions = [
                index for index, symbol in enumerate(response) if symbol == "$"
            ]
            # NOTE(review): a response with exactly two "$" is scored 0 here
            # even though a delimited span exists - confirm this is intended.
            if len(dollar_positions) > 2:
                start = dollar_positions[-2] + 1
                stop = dollar_positions[-1]
                scores.append(
                    1.0 * is_equiv(response[start:stop], get_answer(answer))
                )
            else:
                scores.append(0)
        return scores

    def math23k(responses: list[str], answers: list[str]):
        """Score 1.0 where the extracted numbers of response and answer match."""
        return [
            1.0
            * (
                extract_numeric(response, pattern=NUMERIC_IN_ZH)
                == extract_numeric(answer, pattern=NUMERIC_IN_ZH)
            )
            for response, answer in zip(responses, answers)
        ]
class CMMLU:
    """Builds the suite of ``Task`` objects for the CMMLU benchmark.

    ``subcategories`` maps each dataset config name to its subject tags, and
    ``categories`` maps each reported category to the subject tags it covers.
    """

    input_column = "prompt"
    label_column = "Answer"

    @staticmethod
    def prompt_cmmlu(example, chat=False):
        """Format one row as a multiple-choice prompt.

        Args:
            example: Row with a "Question" key and one key per choice "A".."D".
            chat: Use the instruction-style prefix instead of the short one.

        Returns:
            ``{"prompt": <formatted text>}``.
        """
        prefix = "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n" if chat else "问题:"
        choices = "".join(f"\n{choice}. {example[choice]}" for choice in "ABCD")
        return {"prompt": prefix + example["Question"] + choices + "\n答案:"}

    subcategories = {
        "agronomy": ["other"],
        "anatomy": ["biology"],
        "ancient_chinese": ["linguistics", "china specific"],
        "arts": ["arts"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "chinese_civil_service_exam": ["politics", "china specific"],
        "chinese_driving_rule": ["other", "china specific"],
        "chinese_food_culture": ["culture", "china specific"],
        "chinese_foreign_policy": ["politics", "china specific"],
        "chinese_history": ["history", "china specific"],
        "chinese_literature": ["literature", "china specific"],
        "chinese_teacher_qualification": ["education", "china specific"],
        "college_actuarial_science": ["math"],
        "college_education": ["education"],
        "college_engineering_hydrology": ["engineering"],
        "college_law": ["law"],
        "college_mathematics": ["math"],
        "college_medical_statistics": ["statistics"],
        "clinical_knowledge": ["other"],
        "college_medicine": ["other"],
        "computer_science": ["computer science"],
        "computer_security": ["other"],
        "conceptual_physics": ["physics"],
        "construction_project_management": ["other", "china specific"],
        "economics": ["economics"],
        "education": ["education"],
        "elementary_chinese": ["linguistics", "china specific"],
        "elementary_commonsense": ["other", "china specific"],
        "elementary_information_and_technology": ["other"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "ethnology": ["culture", "china specific"],
        "food_science": ["other"],
        "genetics": ["biology"],
        "global_facts": ["global"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_geography": ["geography"],
        "high_school_mathematics": ["math"],
        "high_school_physics": ["physics"],
        "high_school_politics": ["politics", "china specific"],
        "human_sexuality": ["other"],
        "international_law": ["law"],
        "journalism": ["sociology"],
        "jurisprudence": ["law"],
        "legal_and_moral_basis": ["other"],
        "logical": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "marxist_theory": ["philosophy"],
        "modern_chinese": ["linguistics", "china specific"],
        "nutrition": ["other"],
        "philosophy": ["philosophy"],
        "professional_accounting": ["business"],
        "professional_law": ["law"],
        "professional_medicine": ["other"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_study": ["politics"],
        "sociology": ["culture"],
        "sports_science": ["other"],
        "traditional_chinese_medicine": ["other", "china specific"],
        "virology": ["biology"],
        "world_history": ["history"],
        "world_religions": ["global"],
    }

    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
            "statistics",
        ],
        "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
        "Social Science": [
            "linguistics",
            "business",
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
            "education",
            "sociology",
        ],
        "Other": ["other"],
        "China specific": ["china specific"],
        "Test": ["computer science"],
    }

    @classmethod
    def _configs_by_subject(cls):
        """Invert ``subcategories``: subject tag -> config names, in declaration order."""
        grouped = defaultdict(list)
        for config, subjects in cls.subcategories.items():
            for subject in subjects:
                grouped[subject].append(config)
        # Return a plain dict so an unknown subject raises KeyError instead
        # of silently yielding an empty list.
        return dict(grouped)

    @classmethod
    def suite(cls, chat=False):
        """Build one ``Task`` per dataset config, grouped by category.

        Args:
            chat: When true, use the instruction-style prompt with no few-shot
                examples; otherwise use the short prompt with 5 shots taken
                from the "dev" split.

        Returns:
            ``defaultdict(list)`` mapping each key of ``categories`` to its tasks.

        Raises:
            KeyError: if a category lists a subject that no config is tagged with.
        """
        configs_by_subject = cls._configs_by_subject()
        suite = defaultdict(list)
        for category, subjects in cls.categories.items():
            for subject in subjects:
                suite[category].extend(
                    Task(
                        ("haonan-li/cmmlu", config),
                        metric_name=("sustech/tlem", "cmmlu"),
                        input_column=cls.input_column,
                        label_column=cls.label_column,
                        prompt=partial(cls.prompt_cmmlu, chat=chat),
                        few_shot=0 if chat else 5,
                        few_shot_from="dev",
                    )
                    for config in configs_by_subject[subject]
                )
        return suite
class MMLU:
    """Builds the suite of ``Task`` objects for the MMLU benchmark.

    ``subcategories`` maps each dataset config name to its subject tags, and
    ``categories`` maps each reported category to the subject tags it covers.
    """

    input_column = "prompt"
    label_column = "target"

    @classmethod
    def prompt_mmlu(cls, example, chat=False):
        """Format one row as a multiple-choice prompt.

        Args:
            example: Row with an "input" key and one key per choice "A".."D".
            chat: Use the instruction-style prefix instead of the short one.

        Returns:
            ``{"prompt": <formatted text>}``.
        """
        prefix = (
            "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
            if chat
            else "Question: "
        )
        choices = "".join(f"\n{choice}. {example[choice]}" for choice in "ABCD")
        return {"prompt": prefix + example["input"] + choices + "\nAnswer:"}

    subcategories = {
        "abstract_algebra": ["math"],
        "anatomy": ["health"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "clinical_knowledge": ["health"],
        "college_biology": ["biology"],
        "college_chemistry": ["chemistry"],
        "college_computer_science": ["computer science"],
        "college_mathematics": ["math"],
        "college_medicine": ["health"],
        "college_physics": ["physics"],
        "computer_security": ["computer science"],
        "conceptual_physics": ["physics"],
        "econometrics": ["economics"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "formal_logic": ["philosophy"],
        "global_facts": ["other"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_computer_science": ["computer science"],
        "high_school_european_history": ["history"],
        "high_school_geography": ["geography"],
        "high_school_government_and_politics": ["politics"],
        "high_school_macroeconomics": ["economics"],
        "high_school_mathematics": ["math"],
        "high_school_microeconomics": ["economics"],
        "high_school_physics": ["physics"],
        "high_school_psychology": ["psychology"],
        "high_school_statistics": ["math"],
        "high_school_us_history": ["history"],
        "high_school_world_history": ["history"],
        "human_aging": ["health"],
        "human_sexuality": ["culture"],
        "international_law": ["law"],
        "jurisprudence": ["law"],
        "logical_fallacies": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "medical_genetics": ["health"],
        "miscellaneous": ["other"],
        "moral_disputes": ["philosophy"],
        "moral_scenarios": ["philosophy"],
        "nutrition": ["health"],
        "philosophy": ["philosophy"],
        "prehistory": ["history"],
        "professional_accounting": ["other"],
        "professional_law": ["law"],
        "professional_medicine": ["health"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_studies": ["politics"],
        "sociology": ["culture"],
        "us_foreign_policy": ["politics"],
        "virology": ["health"],
        "world_religions": ["philosophy"],
    }

    categories = {
        "STEM": [
            "physics",
            "chemistry",
            "biology",
            "computer science",
            "math",
            "engineering",
        ],
        "humanities": ["history", "philosophy", "law"],
        "social sciences": [
            "politics",
            "culture",
            "economics",
            "geography",
            "psychology",
        ],
        "other": ["other", "business", "health"],
        "Test": ["culture"],
    }

    @classmethod
    def _configs_by_subject(cls):
        """Invert ``subcategories``: subject tag -> config names, in declaration order."""
        grouped = defaultdict(list)
        for config, subjects in cls.subcategories.items():
            for subject in subjects:
                grouped[subject].append(config)
        # Return a plain dict so an unknown subject raises KeyError instead
        # of silently yielding an empty list.
        return dict(grouped)

    @classmethod
    def suite(cls, chat=False):
        """Build one ``Task`` per dataset config, grouped by category.

        Args:
            chat: When true, use the instruction-style prompt with no few-shot
                examples; otherwise use the short prompt with 5 shots taken
                from the "validation" split.

        Returns:
            ``defaultdict(list)`` mapping each key of ``categories`` to its tasks.

        Raises:
            KeyError: if a category lists a subject that no config is tagged with.
        """
        configs_by_subject = cls._configs_by_subject()
        suite = defaultdict(list)
        for category, subjects in cls.categories.items():
            for subject in subjects:
                suite[category].extend(
                    Task(
                        ("lukaemon/mmlu", config),
                        metric_name=("sustech/tlem", "mmlu"),
                        input_column=cls.input_column,
                        label_column=cls.label_column,
                        prompt=partial(cls.prompt_mmlu, chat=chat),
                        few_shot=0 if chat else 5,
                        few_shot_from="validation",
                    )
                    for config in configs_by_subject[subject]
                )
        return suite