File size: 4,246 Bytes
7d23275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""LLM-as-a-judge rubric for reward computation.

Uses an LLM endpoint (via LLMClient) to evaluate agent actions/observations.

Usage:
    client = OpenAIClient("http://localhost", 8000, model="meta-llama/...")
    judge = LLMJudge(
        prompt_template="Rate this code solution:\\n{action}\\n\\nScore (0-1):",
        client=client,
    )
    score = await judge(action, observation)

See RFC 004 for full design: rfcs/004-rubrics.md
"""

import math
import re
from typing import Any, Dict

from openenv.core.llm_client import LLMClient
from openenv.core.rubrics.base import Rubric


class LLMJudge(Rubric):
    """Rubric that uses an LLM to evaluate agent actions/observations.

    The prompt template is formatted with ``{action}`` and ``{observation}``
    placeholders via ``str.format``; the LLM response is then searched with a
    regex and the matched text is converted to a float score.

    Args:
        prompt_template: Template string with {action} and {observation} placeholders.
        client: An LLMClient instance for making LLM calls.
        score_pattern: Regex to extract the score from the LLM response.
            Defaults to matching the first decimal number. If the pattern
            has capture groups, the first group supplies the score text.
        default_score: Score returned when parsing fails.
        normalize: If True, clamp extracted score to [0, 1].
    """

    def __init__(
        self,
        prompt_template: str,
        client: LLMClient,
        *,
        score_pattern: str | None = None,
        default_score: float = 0.0,
        normalize: bool = True,
    ):
        super().__init__()
        self.prompt_template = prompt_template
        self._client = client
        # A falsy pattern (None or "") selects the built-in decimal matcher.
        self._score_pattern = re.compile(score_pattern or r"(\d+\.?\d*)")
        self.default_score = default_score
        self.normalize = normalize

    async def forward(self, action: Any, observation: Any) -> float:
        """Evaluate by sending a prompt to the LLM and parsing the score.

        Args:
            action: The action taken by the agent.
            observation: The resulting observation.

        Returns:
            Parsed score from the LLM response, or ``default_score`` when no
            usable number can be extracted.
        """
        prompt = self._render_prompt(action, observation)
        response = await self._client.complete(prompt)
        return self._parse_score(response)

    def _render_prompt(self, action: Any, observation: Any) -> str:
        """Format the prompt template with action and observation.

        Override in subclasses for custom prompt construction.
        """
        return self.prompt_template.format(action=action, observation=observation)

    def _parse_score(self, response: str) -> float:
        """Extract a numeric score from the LLM response.

        Uses the configured regex pattern to find the first match. The first
        capture group is used when the pattern defines groups, otherwise the
        whole match.

        Returns:
            The parsed score (clamped to [0, 1] when ``normalize`` is set), or
            ``default_score`` if there is no match, the selected group did not
            participate in the match, the text is not a valid float, or the
            value is NaN.
        """
        match = self._score_pattern.search(response)
        if match is None:
            return self.default_score

        # Use first capture group if present, otherwise full match.
        text = match.group(1) if match.lastindex else match.group(0)
        try:
            # TypeError covers the case where a later group matched but
            # group 1 did not, so ``text`` is None.
            score = float(text)
        except (ValueError, TypeError):
            return self.default_score

        # NaN compares false against everything, so clamping would turn it
        # into 1.0; treat it as a parse failure instead.
        if math.isnan(score):
            return self.default_score

        if self.normalize:
            score = max(0.0, min(1.0, score))
        return score

    def state_dict(self) -> Dict[str, Any]:
        """Serialize rubric configuration.

        The client is not included; only the template, the regex source
        string, the default score and the normalize flag are returned.
        """
        return {
            "prompt_template": self.prompt_template,
            "score_pattern": self._score_pattern.pattern,
            "default_score": self.default_score,
            "normalize": self.normalize,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        """Load rubric configuration from checkpoint.

        Only keys present in ``state`` are applied; missing keys leave the
        current value untouched. The score pattern is recompiled from its
        source string.
        """
        if "prompt_template" in state:
            self.prompt_template = state["prompt_template"]
        if "score_pattern" in state:
            self._score_pattern = re.compile(state["score_pattern"])
        if "default_score" in state:
            self.default_score = state["default_score"]
        if "normalize" in state:
            self.normalize = state["normalize"]