khala / models /Megatron /examples /rl /environments /math /openmath_agent.py
multimodalart's picture
multimodalart HF Staff
Initial best-effort ZeroGPU port of Khala song generation
d1f1097 verified
import random
import datasets
from .math_agent import MathAgent
# Load the "train" split once at import time; both slices below are carved
# out of this single dataset object.
raw_dataset = datasets.load_dataset("nvidia/OpenMathInstruct-2", split="train")

# Number of leading rows used for training, and number of rows directly after
# them that are held out for validation.
TRAIN_SIZE = 327680
TEST_SIZE = 1024

# Explicit check rather than `assert`, so the guard still runs when Python is
# started with -O (which strips assert statements).
if len(raw_dataset) < TRAIN_SIZE + TEST_SIZE:
    raise ValueError(
        f"nvidia/OpenMathInstruct-2 train split has {len(raw_dataset)} rows; "
        f"at least {TRAIN_SIZE + TEST_SIZE} are required "
        f"(TRAIN_SIZE={TRAIN_SIZE} + TEST_SIZE={TEST_SIZE})"
    )

# Contiguous, non-overlapping index ranges:
#   train -> [0, TRAIN_SIZE)
#   test  -> [TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE)
train_dataset = raw_dataset.select(range(TRAIN_SIZE))
test_dataset = raw_dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))
class OpenMathInstructAgent(MathAgent):
    """Math agent that draws its problems from the OpenMathInstruct slices.

    Training rows come from the module-level ``train_dataset`` and validation
    rows from ``test_dataset``. Prompt construction (``make_prefix``) and
    scoring (``compute_score``) are not defined in this class; presumably they
    are inherited from ``MathAgent`` -- confirm there.
    """

    # Identifier string for this environment.
    env_id: str = "openmath_instruct"

    def get_dataset(self, validation: bool = False):
        """Return ``test_dataset`` if ``validation`` is truthy, else ``train_dataset``."""
        if validation:
            return test_dataset
        return train_dataset

    async def evaluation_prompts(
        self, num_prompts: int, validation: bool = False
    ) -> list[tuple[str, dict]]:
        """Build ``(prompt, row)`` pairs from the first ``num_prompts`` rows.

        Rows are read in dataset order starting at index 0, so repeated calls
        with the same arguments read the same indices.
        """
        source = self.get_dataset(validation)
        # Fetch every row first; prompts are only built once all lookups succeed.
        rows = [source[index] for index in range(num_prompts)]
        return [(self.make_prefix(**row), row) for row in rows]

    async def get_prompt(self, validation=False) -> tuple[str, dict]:
        """Pick one row at a random index and return ``(prompt, row)``."""
        source = self.get_dataset(validation)
        position = random.randrange(len(source))
        row = source[position]
        # The row's fields are passed to make_prefix as keyword arguments.
        return self.make_prefix(**row), row

    async def get_reward(self, response, golden: dict) -> float:
        """Return ``compute_score`` for ``response`` using ``expected_answer`` as the golden key."""
        score = self.compute_score(response, golden, golden_key="expected_answer")
        return score