"""ShareGPT + POLAR reward environment.""" from __future__ import annotations from pathlib import Path from typing import Any from datasets import Dataset, load_dataset import asyncio import verifiers as vf from verifiers.types import Messages from xtuner.utils import RewardModelClient DEFAULT_MODEL = "internlm/POLAR-7B" def _load_sharegpt_dataset(path: str | Path) -> Dataset: dataset = load_dataset("json", data_files=str(path), split="train") def to_single_turn(example: dict[str, Any]) -> dict[str, Any]: human_turn = next( turn["value"] for turn in example["conversations"] if turn["from"] == "human" ) assistant_turn = next( turn["value"] for turn in example["conversations"] if turn["from"] == "gpt" ) return { "prompt": [{"role": "user", "content": human_turn}], "info": { "reference": [{"role": "assistant", "content": assistant_turn}], }, } return dataset.map(to_single_turn, remove_columns=dataset.column_names) class PoolingClient: def __init__( self, model_path: str, server_address: str, server_type: str = "lmdeploy", max_length: int = 16384, max_response_length: int = 4096, response_cut_side: str = "left", ): self.client = RewardModelClient( model_path, max_length=max_length, max_response_length=max_response_length, response_cut_side=response_cut_side, server_type=server_type, server_address=server_address, ) def encode(self, sample: dict[str, Any]) -> str: prompt_text = "\n".join( message["content"] for message in sample.get("prompt", []) ) reference_text = "\n".join( message["content"] for message in sample.get("reference", []) ) output_text = "\n".join( message["content"] for message in sample.get("output", []) ) return f"{prompt_text}\n{reference_text}<|reward|>{prompt_text}\n{output_text}[UNUSED_TOKEN_130]" def score(self, payload: list[dict[str, Any]]) -> list[float]: encoded_payload = [self.encode(item) for item in payload] rewards = self.client.lmdeploy_request_reward(encoded_payload) if rewards is None: raise RuntimeError("Failed to get rewards from lmdeploy server") return rewards async def polar_reward( prompt: Messages, completion: Messages, info: dict[str, Any], reward_client: PoolingClient, pooling_semaphore: asyncio.Semaphore, **_: Any, ) -> float: assistant_turns = [msg for msg in completion if msg.get("role") == "assistant"] if not assistant_turns: return 0.0 payload = [ { "prompt": prompt, "reference": info.get("reference", []), "output": [assistant_turns[-1]], } ] async with pooling_semaphore: loop = asyncio.get_running_loop() rewards = await loop.run_in_executor(None, reward_client.score, payload) if rewards: return float(rewards[-1]) * 10.0 raise RuntimeError(f"Unexpected reward response: {rewards}") def load_environment( data_path: str | Path, *, server_address: str, reward_model: str = DEFAULT_MODEL, reward_scheme: type[vf.Rubric] | None = None, server_type: str = "lmdeploy", **env_kwargs: Any, ) -> vf.SingleTurnEnv: dataset = _load_sharegpt_dataset(data_path) client = PoolingClient( model_path=reward_model, server_address=server_address, server_type=server_type, ) rubric_cls = reward_scheme or vf.Rubric rubric = rubric_cls(funcs=[polar_reward]) rubric.class_objects["reward_client"] = client rubric.class_objects.setdefault("pooling_semaphore", asyncio.Semaphore(4)) return vf.SingleTurnEnv(dataset=dataset, rubric=rubric, **env_kwargs)