# testing/sharegpt_polar.py
# Uploaded by Delta-Vector via huggingface_hub (revision 86dd177, verified).
"""ShareGPT + POLAR reward environment."""
from __future__ import annotations
from pathlib import Path
from typing import Any
from datasets import Dataset, load_dataset
import asyncio
import verifiers as vf
from verifiers.types import Messages
from xtuner.utils import RewardModelClient
# Default POLAR reward-model checkpoint (Hugging Face Hub id) used when the
# caller does not pass an explicit `reward_model` to `load_environment`.
DEFAULT_MODEL = "internlm/POLAR-7B"
def _load_sharegpt_dataset(path: str | Path) -> Dataset:
    """Load a ShareGPT-format JSON file as single-turn prompt/reference pairs.

    Each example's first ``human`` turn becomes the user prompt and its first
    ``gpt`` turn becomes the reference completion stored under ``info``.

    Args:
        path: Path to a JSON/JSONL file whose rows have a ``conversations``
            list of ``{"from": ..., "value": ...}`` turns.

    Returns:
        A ``Dataset`` with ``prompt`` and ``info`` columns only.

    Raises:
        ValueError: If an example contains no ``human`` or no ``gpt`` turn.
    """
    dataset = load_dataset("json", data_files=str(path), split="train")

    def to_single_turn(example: dict[str, Any]) -> dict[str, Any]:
        # Use a sentinel default so a malformed example raises a clear
        # ValueError instead of an opaque StopIteration escaping from
        # inside `Dataset.map`.
        human_turn = next(
            (turn["value"] for turn in example["conversations"] if turn["from"] == "human"),
            None,
        )
        assistant_turn = next(
            (turn["value"] for turn in example["conversations"] if turn["from"] == "gpt"),
            None,
        )
        if human_turn is None or assistant_turn is None:
            raise ValueError(
                "ShareGPT example must contain at least one 'human' and one 'gpt' turn"
            )
        return {
            "prompt": [{"role": "user", "content": human_turn}],
            "info": {
                "reference": [{"role": "assistant", "content": assistant_turn}],
            },
        }

    return dataset.map(to_single_turn, remove_columns=dataset.column_names)
class PoolingClient:
    """Thin wrapper around xtuner's ``RewardModelClient`` for POLAR scoring.

    Encodes (prompt, reference, output) message triples into POLAR's
    reference-vs-output comparison format and requests scalar rewards from a
    remote reward-model server.
    """

    def __init__(
        self,
        model_path: str,
        server_address: str,
        server_type: str = "lmdeploy",
        max_length: int = 16384,
        max_response_length: int = 4096,
        response_cut_side: str = "left",
    ):
        """Create a reward client bound to one serving backend.

        Args:
            model_path: Reward-model path/id used for tokenization limits.
            server_address: ``host:port`` of the reward server.
            server_type: Serving backend ("lmdeploy", "sglang", or "vllm").
            max_length: Maximum total encoded length.
            max_response_length: Maximum response length before truncation.
            response_cut_side: Side to truncate over-long responses from.
        """
        # Remember the backend so `score` can dispatch to the matching
        # request method (the original always used the lmdeploy path).
        self.server_type = server_type
        self.client = RewardModelClient(
            model_path,
            max_length=max_length,
            max_response_length=max_response_length,
            response_cut_side=response_cut_side,
            server_type=server_type,
            server_address=server_address,
        )

    def encode(self, sample: dict[str, Any]) -> str:
        """Flatten a sample into POLAR's ``reference<|reward|>output`` format.

        Missing keys are treated as empty message lists; message contents are
        joined with newlines.
        """
        prompt_text = "\n".join(
            message["content"] for message in sample.get("prompt", [])
        )
        reference_text = "\n".join(
            message["content"] for message in sample.get("reference", [])
        )
        output_text = "\n".join(
            message["content"] for message in sample.get("output", [])
        )
        return f"{prompt_text}\n{reference_text}<|reward|>{prompt_text}\n{output_text}[UNUSED_TOKEN_130]"

    def score(self, payload: list[dict[str, Any]]) -> list[float]:
        """Encode *payload* and request one reward per item.

        Raises:
            RuntimeError: If the reward server returns no result.
        """
        encoded_payload = [self.encode(item) for item in payload]
        # BUG FIX: dispatch on the configured backend instead of
        # unconditionally calling the lmdeploy request method.
        if self.server_type == "sglang":
            rewards = self.client.sglang_request_reward(encoded_payload)
        elif self.server_type == "vllm":
            rewards = self.client.vllm_request_reward(encoded_payload)
        else:
            rewards = self.client.lmdeploy_request_reward(encoded_payload)
        if rewards is None:
            raise RuntimeError(
                f"Failed to get rewards from {self.server_type} server"
            )
        return rewards
async def polar_reward(
prompt: Messages,
completion: Messages,
info: dict[str, Any],
reward_client: PoolingClient,
pooling_semaphore: asyncio.Semaphore,
**_: Any,
) -> float:
assistant_turns = [msg for msg in completion if msg.get("role") == "assistant"]
if not assistant_turns:
return 0.0
payload = [
{
"prompt": prompt,
"reference": info.get("reference", []),
"output": [assistant_turns[-1]],
}
]
async with pooling_semaphore:
loop = asyncio.get_running_loop()
rewards = await loop.run_in_executor(None, reward_client.score, payload)
if rewards:
return float(rewards[-1]) * 10.0
raise RuntimeError(f"Unexpected reward response: {rewards}")
def load_environment(
    data_path: str | Path,
    *,
    server_address: str,
    reward_model: str = DEFAULT_MODEL,
    reward_scheme: type[vf.Rubric] | None = None,
    server_type: str = "lmdeploy",
    **env_kwargs: Any,
) -> vf.SingleTurnEnv:
    """Build a single-turn environment scored by a remote POLAR reward server.

    Args:
        data_path: Path to a ShareGPT-format JSON/JSONL dataset.
        server_address: ``host:port`` of the running reward server.
        reward_model: Reward-model path/id (defaults to ``DEFAULT_MODEL``).
        reward_scheme: Optional Rubric subclass to instantiate instead of
            ``vf.Rubric``.
        server_type: Reward-server backend, passed through to the client.
        **env_kwargs: Extra keyword arguments forwarded to ``SingleTurnEnv``.
    """
    prepared_dataset = _load_sharegpt_dataset(data_path)
    reward_client = PoolingClient(
        model_path=reward_model,
        server_address=server_address,
        server_type=server_type,
    )
    rubric = (reward_scheme or vf.Rubric)(funcs=[polar_reward])
    # Shared objects are resolved by parameter name inside `polar_reward`.
    rubric.class_objects["reward_client"] = reward_client
    # Default to at most 4 concurrent reward requests unless the rubric
    # (or subclass) already provided its own semaphore.
    rubric.class_objects.setdefault("pooling_semaphore", asyncio.Semaphore(4))
    return vf.SingleTurnEnv(dataset=prepared_dataset, rubric=rubric, **env_kwargs)