echo_env / src /core /rubrics /llm_judge.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
7d23275 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""LLM-as-a-judge rubric for reward computation.
Uses an LLM endpoint (via LLMClient) to evaluate agent actions/observations.
Usage:
client = OpenAIClient("http://localhost", 8000, model="meta-llama/...")
judge = LLMJudge(
prompt_template="Rate this code solution:\\n{action}\\n\\nScore (0-1):",
client=client,
)
score = await judge(action, observation)
See RFC 004 for full design: rfcs/004-rubrics.md
"""
import re
from typing import Any, Dict
from openenv.core.llm_client import LLMClient
from openenv.core.rubrics.base import Rubric
class LLMJudge(Rubric):
    """Rubric that uses an LLM to evaluate agent actions/observations.

    The prompt template is formatted with ``{action}`` and ``{observation}``
    placeholders. The LLM response is parsed for a numeric score.

    Args:
        prompt_template: Template string with {action} and {observation} placeholders.
        client: An LLMClient instance for making LLM calls.
        score_pattern: Regex to extract the score from the LLM response.
            Defaults to matching the first (optionally signed) decimal number.
        default_score: Score returned when parsing fails.
        normalize: If True, clamp extracted score to [0, 1].
    """

    # Matches the first signed decimal number: "0.75", ".5", "-1", "7".
    # The previous default, r"(\d+\.?\d*)", dropped a leading minus sign
    # (so "-0.5" parsed as 0.5) and failed on bare fractions like ".5".
    # Non-capturing groups keep match.lastindex falsy, so _parse_score
    # falls through to group(0), preserving the group-selection logic.
    _DEFAULT_SCORE_PATTERN = r"[-+]?(?:\d+\.\d*|\.\d+|\d+)"

    def __init__(
        self,
        prompt_template: str,
        client: LLMClient,
        *,
        score_pattern: str | None = None,
        default_score: float = 0.0,
        normalize: bool = True,
    ):
        super().__init__()
        self.prompt_template = prompt_template
        self._client = client
        # Compile once at construction; reused for every response parse.
        self._score_pattern = re.compile(
            score_pattern or self._DEFAULT_SCORE_PATTERN
        )
        self.default_score = default_score
        self.normalize = normalize

    async def forward(self, action: Any, observation: Any) -> float:
        """Evaluate by sending a prompt to the LLM and parsing the score.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Parsed score from the LLM response (``default_score`` if the
            response contains no parseable number).
        """
        prompt = self._render_prompt(action, observation)
        response = await self._client.complete(prompt)
        return self._parse_score(response)

    def _render_prompt(self, action: Any, observation: Any) -> str:
        """Format the prompt template with action and observation.

        Override in subclasses for custom prompt construction.
        """
        return self.prompt_template.format(action=action, observation=observation)

    def _parse_score(self, response: str) -> float:
        """Extract a numeric score from the LLM response.

        Uses the configured regex pattern to find the first match.
        Returns default_score if no match is found or the matched text
        is not a valid float.
        """
        match = self._score_pattern.search(response)
        if match is None:
            return self.default_score
        try:
            # Use first capture group if present, otherwise full match —
            # lets callers supply patterns like r"Score:\s*(\d+)".
            text = match.group(1) if match.lastindex else match.group(0)
            score = float(text)
        except (ValueError, IndexError):
            return self.default_score
        if self.normalize:
            # Clamp into [0, 1]; negative judgments become 0.0 rather
            # than being misread as positive scores.
            score = max(0.0, min(1.0, score))
        return score

    def state_dict(self) -> Dict[str, Any]:
        """Serialize rubric configuration.

        Note: the LLM client is intentionally not serialized; it must be
        supplied again at construction time.
        """
        return {
            "prompt_template": self.prompt_template,
            "score_pattern": self._score_pattern.pattern,
            "default_score": self.default_score,
            "normalize": self.normalize,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Load rubric configuration from checkpoint.

        Keys absent from ``state`` leave the current configuration
        untouched, so partial checkpoints are accepted.
        """
        if "prompt_template" in state:
            self.prompt_template = state["prompt_template"]
        if "score_pattern" in state:
            self._score_pattern = re.compile(state["score_pattern"])
        if "default_score" in state:
            self.default_score = state["default_score"]
        if "normalize" in state:
            self.normalize = state["normalize"]