File size: 4,045 Bytes
6e194fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f3ecd2
 
 
6e194fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f3ecd2
 
 
6e194fa
 
 
 
0f3ecd2
 
 
6e194fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Reward provider utilities for TextArena environments."""

from __future__ import annotations

import re
from typing import Dict, List, Protocol, Tuple

try:
    from textarena_env.models import TextArenaAction, TextArenaObservation
except ImportError:
    from models import TextArenaAction, TextArenaObservation


class RewardProvider(Protocol):
    """Structural interface for objects that emit auxiliary reward signals.

    Any object exposing matching ``reset`` and ``compute`` methods satisfies
    this protocol; no inheritance is required.
    """

    def reset(self) -> None:
        """Discard internal state so a new episode starts clean."""
        ...

    def compute(
        self,
        *,
        action: TextArenaAction,
        observation: TextArenaObservation,
    ) -> Dict[str, float]:
        """Score a single step.

        Both arguments are keyword-only.

        Args:
            action: The action submitted for this step.
            observation: The observation associated with this step.

        Returns:
            A mapping from reward name to its float value for the step.
        """
        ...


def build_reward_providers(env_id: str) -> List[RewardProvider]:
    """Create the reward providers that apply to ``env_id``.

    Args:
        env_id: Environment identifier; compared case-sensitively.

    Returns:
        A fresh list holding one Wordle reward provider when ``env_id`` is
        exactly ``"Wordle-v0"``, otherwise a fresh empty list.
    """
    # Only the Wordle environment has a dedicated provider.
    if env_id != "Wordle-v0":
        return []
    return [_WordleRewardProvider()]


# Matches a guess written as exactly five ASCII letters inside square brackets.
_WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]")


def extract_guess(text: str) -> str:
    """Turn free-form text into a canonical ``[xxxxx]`` Wordle guess.

    Args:
        text: Arbitrary text that may contain a guess.

    Returns:
        The first bracketed five-letter token, lower-cased, when one exists.
        Otherwise the first five ASCII letters of the lower-cased text wrapped
        in brackets, or ``"[dunno]"`` when fewer than five letters are present.
    """
    bracketed = _WORDLE_GUESS_PATTERN.search(text)
    if bracketed is not None:
        return bracketed.group(0).lower()

    # Fallback: keep only a-z and use the leading five letters as the guess.
    letters = "".join(ch for ch in text.lower() if "a" <= ch <= "z")
    if len(letters) < 5:
        return "[dunno]"
    return "[" + letters[:5] + "]"


def extract_wordle_feedback(observation: TextArenaObservation) -> str:
    """Return the feedback text from the newest message that carries one.

    Messages are scanned from last to first. For the first message whose
    stripped content contains ``"Feedback:"``, everything after the first
    occurrence of that marker is returned with surrounding whitespace removed.

    Args:
        observation: Object exposing ``messages``, each with a ``content`` string.

    Returns:
        The feedback text, or ``""`` when no message contains the marker.
    """
    marker = "Feedback:"
    for message in reversed(observation.messages):
        _, found, tail = message.content.strip().partition(marker)
        if found:
            return tail.strip()
    return ""


def extract_feedback_counts(feedback: str) -> Tuple[int, int]:
    """Count green (``G``) and yellow (``Y``) markers in Wordle feedback.

    The feedback is split into non-blank lines. Scanning from the last line
    upwards, the first line that consists solely of ``G``, ``Y`` and ``X``
    (spaces ignored) is treated as the marker row.

    Args:
        feedback: Feedback text; at least two non-blank lines are required.

    Returns:
        ``(green_count, yellow_count)``, or ``(0, 0)`` when the feedback is
        empty, has fewer than two non-blank lines, or has no marker row.
    """
    nothing = (0, 0)
    if not feedback:
        return nothing

    rows = []
    for raw_line in feedback.split("\n"):
        trimmed = raw_line.strip()
        if trimmed:
            rows.append(trimmed)
    if len(rows) < 2:
        return nothing

    marker_alphabet = {"G", "Y", "X"}
    for row in reversed(rows):
        compact = row.replace(" ", "")
        # A marker row is non-empty and uses only the three marker letters.
        if compact and set(compact) <= marker_alphabet:
            return (compact.count("G"), compact.count("Y"))

    return nothing


class _WordleRewardProvider:
    """Reward provider that mirrors the GRPO Wordle heuristics.

    Each step yields four signals: the fraction of green markers, the fraction
    of yellow markers, a repetition score that decreases every time the same
    guess is seen again within an episode, and the observation's own reward.
    """

    # Logical signal name -> key used in the dictionary returned by compute().
    SIGNAL_MAP = {
        "greens": "wordle.greens",
        "yellows": "wordle.yellows",
        "repetitions": "wordle.repetitions",
        "correct": "wordle.correct",
    }

    def __init__(self) -> None:
        # Normalized guess -> number of times it has been seen this episode.
        self._guess_history: Dict[str, int] = {}

    def reset(self) -> None:
        """Forget every guess recorded so far (the dict is cleared in place)."""
        self._guess_history.clear()

    def compute(
        self,
        *,
        action: TextArenaAction,
        observation: TextArenaObservation,
    ) -> Dict[str, float]:
        """Compute the Wordle reward signals for one step.

        Args:
            action: The submitted action; its ``message`` is parsed for a guess.
            observation: The step observation; its messages supply the feedback
                and its ``reward`` supplies the "correct" signal.

        Returns:
            A mapping keyed by the ``SIGNAL_MAP`` values, each with a float.
        """
        guess = extract_guess(action.message)
        feedback = extract_wordle_feedback(observation)

        # The "[dunno]" placeholder is neither tracked nor penalized.
        is_tracked = bool(guess) and guess != "[dunno]"
        times_seen = self._guess_history.get(guess, 0) if is_tracked else 0

        if feedback:
            greens, yellows = extract_feedback_counts(feedback)
        else:
            greens, yellows = 0, 0

        # Five is the divisor for both marker fractions.
        green_fraction = greens / 5.0
        yellow_fraction = yellows / 5.0
        # 1.0 for a first-time guess, dropping by one per earlier occurrence.
        repetition = 1.0 - times_seen
        correct = float(observation.reward or 0.0)

        # Record the guess only after every score has been computed.
        if is_tracked:
            self._guess_history[guess] = times_seen + 1

        signals = self.SIGNAL_MAP
        return {
            signals["greens"]: float(green_fraction),
            signals["yellows"]: float(yellow_fraction),
            signals["repetitions"]: float(repetition),
            signals["correct"]: float(correct),
        }


# Public API of this module. The concrete _WordleRewardProvider class is not
# exported; instances are obtained through build_reward_providers().
__all__ = [
    "RewardProvider",
    "build_reward_providers",
    "extract_feedback_counts",
    "extract_guess",
    "extract_wordle_feedback",
]