Spaces:
Running on Zero
Running on Zero
| import random | |
| import datasets | |
| from .math_agent import MathAgent | |
# Load the full "train" split once at import time; the two subsets below are
# both carved out of it.
raw_dataset = datasets.load_dataset("nvidia/OpenMathInstruct-2", split="train")

# TRAIN_SIZE leading rows are used for training; the TEST_SIZE rows that
# immediately follow them are held out for validation.
TRAIN_SIZE = 327680
TEST_SIZE = 1024

# Explicit check instead of `assert`, so the guard still runs under `python -O`.
if len(raw_dataset) < TRAIN_SIZE + TEST_SIZE:
    raise ValueError(
        f"nvidia/OpenMathInstruct-2 train split has {len(raw_dataset)} rows, "
        f"but {TRAIN_SIZE + TEST_SIZE} are required "
        f"({TRAIN_SIZE} train + {TEST_SIZE} test)"
    )

# Disjoint, contiguous index ranges:
#   train -> [0, TRAIN_SIZE)
#   test  -> [TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE)
train_dataset = raw_dataset.select(range(TRAIN_SIZE))
test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))
class OpenMathInstructAgent(MathAgent):
    """Math agent whose problems come from the OpenMathInstruct-2 subsets.

    Prompts are built with ``self.make_prefix`` and responses are scored with
    ``self.compute_score``. Neither is defined in this class; presumably both
    are inherited from ``MathAgent`` (confirm against that module).
    """

    env_id: str = "openmath_instruct"

    def get_dataset(self, validation: bool = False):
        """Return ``test_dataset`` if *validation* is true, else ``train_dataset``."""
        return test_dataset if validation else train_dataset

    async def evaluation_prompts(
        self, num_prompts: int, validation: bool = False
    ) -> list[tuple[str, dict]]:
        """Return prompts for the first *num_prompts* rows of the chosen split.

        Rows are taken in order starting at index 0, so repeated calls with
        the same arguments select the same rows.

        Args:
            num_prompts: Number of leading rows to turn into prompts. Values
                of zero or below yield an empty list.
            validation: Select the held-out split instead of the training one.

        Returns:
            A list of ``(prompt, golden)`` pairs, where ``golden`` is the
            dataset row and ``prompt`` is ``self.make_prefix(**golden)``.

        Raises:
            IndexError: If *num_prompts* exceeds the number of rows in the
                chosen split.
        """
        dataset = self.get_dataset(validation)
        available = len(dataset)
        # Fail up front with a message that names the split, instead of
        # failing part-way through indexing with a bare out-of-range error.
        if num_prompts > available:
            split = "validation" if validation else "train"
            raise IndexError(
                f"requested {num_prompts} evaluation prompts but the "
                f"{split} split only has {available} rows"
            )
        examples = (dataset[i] for i in range(num_prompts))
        return [(self.make_prefix(**golden), golden) for golden in examples]

    async def get_prompt(self, validation: bool = False) -> tuple[str, dict]:
        """Return a ``(prompt, golden)`` pair for one uniformly random row.

        Args:
            validation: Sample from the held-out split instead of the
                training one.
        """
        dataset = self.get_dataset(validation)
        golden = dataset[random.randrange(len(dataset))]
        return self.make_prefix(**golden), golden

    async def get_reward(self, response, golden: dict) -> float:
        """Score *response* against the ``"expected_answer"`` field of *golden*.

        Scoring is delegated to ``self.compute_score``.
        """
        return self.compute_score(response, golden, golden_key="expected_answer")