# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any, Optional, Tuple
from uuid import uuid4

from verl.utils.reward_score import gsm8k

from .base_tool import BaseTool
from .schemas import OpenAIFunctionToolSchema

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class Gsm8kTool(BaseTool):
    """A demo tool for calculating the reward of gsm8k.

    - `get_openai_tool_schema`: return the tool schema in OpenAI format.
    - `create`: create a tool instance for a trajectory.
    - `execute`: execute the tool.
    - `calc_reward`: calculate the reward respect to tool state.
    - `release`: release the tool instance.
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        """Initialize the tool and its per-instance state table.

        Both arguments are forwarded unchanged to ``BaseTool.__init__``.

        An example ``tool_schema``::

            _tool_schema = OpenAIFunctionToolSchema.model_validate({
                "type": "function",
                "function": {
                    "name": "calc_gsm8k_reward",
                    "description": "A tool for calculating the reward of gsm8k",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "answer": {
                                "type": "string",
                                "description": "The answer to the question",
                            },
                        },
                        "required": ["answer"],
                    },
                }
            })
        """
        super().__init__(config, tool_schema)
        # instance_id -> {"response": str, "ground_truth": Optional[str], "reward": float}
        self._instance_dict = {}

    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
        """Return the tool schema held by this tool.

        NOTE(review): ``self.tool_schema`` is presumably assigned by
        ``BaseTool.__init__`` -- confirm against the base class.
        """
        return self.tool_schema

    async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str:
        """Register a new tool instance and return its id.

        Args:
            instance_id: Id to register under. A random UUID4 string is
                generated when this is ``None``. Re-using an existing id
                overwrites that instance's state.
            ground_truth: Reference answer stored for later scoring.
            **kwargs: Accepted and ignored.

        Returns:
            The id the instance was registered under.
        """
        if instance_id is None:
            instance_id = str(uuid4())
        self._instance_dict[instance_id] = {
            "response": "",
            "ground_truth": ground_truth,
            "reward": 0.0,
        }
        return instance_id

    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
        """Record a submitted answer, score it, and report the result.

        Args:
            instance_id: Id previously returned by `create`.
            parameters: Tool-call arguments; only the ``"answer"`` key is read
                (missing key is treated as an empty string, non-string values
                are converted with ``str``).
            **kwargs: Accepted and ignored.

        Returns:
            A ``(message, tool_reward, metrics)`` tuple: a text summary of the
            parsed answer and its score, the step-level tool reward (``0.0``
            when the score strictly improved over the stored one, ``-0.05``
            otherwise), and an empty dict.

        Raises:
            KeyError: If ``instance_id`` has not been created (or was released).
        """
        answer = parameters.get("answer", "")
        if not isinstance(answer, str):
            answer = str(answer)

        instance = self._instance_dict[instance_id]
        # Stored responses always carry the "#### " prefix; add it when the
        # submission omitted it.
        instance["response"] = answer if answer.startswith("#### ") else "#### " + answer

        reward = await self.calc_reward(instance_id)
        # penalty for non improved answer submission
        tool_reward = 0.0 if reward > instance["reward"] else -0.05
        # update the reward
        instance["reward"] = reward

        return f"Current parsed {answer=} {reward=}", tool_reward, {}

    async def calc_reward(self, instance_id: str, **kwargs) -> float:
        """Score the instance's stored response against its ground truth.

        Delegates to ``gsm8k.compute_score`` with ``method="flexible"``,
        ``format_score=0.0`` and ``score=1.0``.

        Raises:
            KeyError: If ``instance_id`` has not been created (or was released).
        """
        instance = self._instance_dict[instance_id]
        return gsm8k.compute_score(
            instance["response"],
            instance["ground_truth"],
            method="flexible",
            format_score=0.0,
            score=1.0,
        )

    async def release(self, instance_id: str, **kwargs) -> None:
        """Drop the stored state for ``instance_id``.

        Releasing an id that is unknown or already released is a no-op, so
        cleanup paths can call this safely more than once.
        """
        self._instance_dict.pop(instance_id, None)