Spaces:
Runtime error
Runtime error
| from tasks.base import BaseProbInference | |
class SST2ProbInferenceForMC(BaseProbInference):
    """SST-2 sentiment classification posed as a two-way multiple-choice task.

    Every example pairs a review sentence with the candidate labels held in
    ``CHOICES``. ``exemplar_seperator`` and ``multiple_choice_promptify``
    accept only prompt versions whose name starts with ``"sp"`` and raise
    ``ValueError`` for anything else.
    """

    def __init__(self, prompt_version):
        """Run the base-class initialiser, then set the task-specific fields."""
        super().__init__(prompt_version)
        # "answer_idx" in the preprocessed examples is the raw dataset label;
        # presumably 0 -> "negative" and 1 -> "positive" (verify against the
        # dataset card).
        self.CHOICES = ["negative", "positive"]
        self.can_be_stratified = True
        # One base shot per candidate label.
        self.num_base_shot = len(self.CHOICES)

    def default_prompt_version(self):
        """Return the name of the prompt version used by default."""
        return "sp"

    def dataset_signature(self):
        """Return the dataset identifiers for the "result" and "sample" roles.

        Each value is a three-string tuple; presumably (dataset, config,
        split) as understood by the base class's loader -- confirm there.
        """
        return {
            "result": ("glue", "sst2", "validation"),
            "sample": ("glue", "sst2", "train"),
        }

    def dataset_preprocess(self, raw_data):
        """Convert raw records into multiple-choice examples.

        Each record must provide ``"sentence"`` and ``"label"``. The sentence
        is stripped of surrounding whitespace; the label is passed through
        unchanged as ``"answer_idx"``. Every example references the same
        ``self.CHOICES`` list object (it is not copied).
        """
        return [
            {
                "query": record["sentence"].strip(),
                "choices": self.CHOICES,
                "answer_idx": record["label"],
            }
            for record in raw_data
        ]

    def handcrafted_exemplars(self):
        """Not available for this task; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def exemplar_seperator(self):
        """Return the text placed between exemplars (a blank line).

        Raises:
            ValueError: if ``self.prompt_version`` does not start with "sp".
        """
        if not self.prompt_version.startswith("sp"):
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")
        return "\n\n"

    def multiple_choice_promptify(self, query, choice):
        """Build the prompt for ``query`` with and without ``choice`` appended.

        Returns:
            A 2-tuple ``(prompt, prompt_with_choice)`` where the second item
            is the first followed by a single space and ``choice``.

        Raises:
            ValueError: if ``self.prompt_version`` does not start with "sp".
        """
        if not self.prompt_version.startswith("sp"):
            raise ValueError(f"SST2: Not supported prompt_version: {self.prompt_version}")
        prompt = f"Review: {query}\nSentiment:"
        prompt_with_choice = f"{prompt} {choice}"
        return prompt, prompt_with_choice