| """Reward provider utilities for TextArena environments.""" | |
| from __future__ import annotations | |
| import re | |
| from typing import Dict, List, Protocol, Tuple | |
| try: | |
| from textarena_env.models import TextArenaAction, TextArenaObservation | |
| except ImportError: | |
| from models import TextArenaAction, TextArenaObservation | |
| class RewardProvider(Protocol): | |
| """Interface for computing auxiliary reward signals.""" | |
| def reset(self) -> None: | |
| """Clear any internal state before a new episode.""" | |
| def compute( | |
| self, *, action: TextArenaAction, observation: TextArenaObservation | |
| ) -> Dict[str, float]: | |
| """Return a mapping of reward names to float values for the step.""" | |
| def build_reward_providers(env_id: str) -> List[RewardProvider]: | |
| """Instantiate reward providers appropriate for the given environment.""" | |
| providers: List[RewardProvider] = [] | |
| if env_id == "Wordle-v0": | |
| providers.append(_WordleRewardProvider()) | |
| return providers | |
| _WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]") | |
| def extract_guess(text: str) -> str: | |
| """Normalize a Wordle guess string from arbitrary text.""" | |
| match = _WORDLE_GUESS_PATTERN.search(text) | |
| if match: | |
| return match.group(0).lower() | |
| cleaned = re.sub(r"[^a-z]", "", text.lower()) | |
| if len(cleaned) >= 5: | |
| return f"[{cleaned[:5]}]" | |
| return "[dunno]" | |
| def extract_wordle_feedback(observation: TextArenaObservation) -> str: | |
| """Pull the latest feedback text from a Wordle observation.""" | |
| for message in reversed(observation.messages): | |
| content = message.content.strip() | |
| if "Feedback:" in content: | |
| return content.split("Feedback:", 1)[-1].strip() | |
| return "" | |
| def extract_feedback_counts(feedback: str) -> Tuple[int, int]: | |
| """Return counts of green (G) and yellow (Y) markers from feedback.""" | |
| if not feedback: | |
| return (0, 0) | |
| lines = [line.strip() for line in feedback.split("\n") if line.strip()] | |
| if len(lines) < 2: | |
| return (0, 0) | |
| for line in reversed(lines): | |
| normalized = line.replace(" ", "") | |
| if normalized and all(c in "GYX" for c in normalized): | |
| green = normalized.count("G") | |
| yellow = normalized.count("Y") | |
| return (green, yellow) | |
| return (0, 0) | |
| class _WordleRewardProvider: | |
| """Reward provider that mirrors the GRPO Wordle heuristics.""" | |
| SIGNAL_MAP = { | |
| "greens": "wordle.greens", | |
| "yellows": "wordle.yellows", | |
| "repetitions": "wordle.repetitions", | |
| "correct": "wordle.correct", | |
| } | |
| def __init__(self) -> None: | |
| self._guess_history: Dict[str, int] = {} | |
| def reset(self) -> None: | |
| self._guess_history.clear() | |
| def compute( | |
| self, *, action: TextArenaAction, observation: TextArenaObservation | |
| ) -> Dict[str, float]: | |
| guess = extract_guess(action.message) | |
| feedback = extract_wordle_feedback(observation) | |
| normalized_guess = guess if guess and guess != "[dunno]" else "" | |
| previous_occurrences = ( | |
| self._guess_history.get(normalized_guess, 0) if normalized_guess else 0 | |
| ) | |
| green_score = 0.0 | |
| yellow_score = 0.0 | |
| if feedback: | |
| green_count, yellow_count = extract_feedback_counts(feedback) | |
| green_score = green_count / 5.0 | |
| yellow_score = yellow_count / 5.0 | |
| repetition_score = 1.0 - previous_occurrences | |
| correct_score = float(observation.reward or 0.0) | |
| if normalized_guess: | |
| self._guess_history[normalized_guess] = previous_occurrences + 1 | |
| return { | |
| self.SIGNAL_MAP["greens"]: float(green_score), | |
| self.SIGNAL_MAP["yellows"]: float(yellow_score), | |
| self.SIGNAL_MAP["repetitions"]: float(repetition_score), | |
| self.SIGNAL_MAP["correct"]: float(correct_score), | |
| } | |
| __all__ = [ | |
| "RewardProvider", | |
| "build_reward_providers", | |
| "extract_feedback_counts", | |
| "extract_guess", | |
| "extract_wordle_feedback", | |
| ] | |