"""Reward provider utilities for TextArena environments.""" from __future__ import annotations import re from typing import Dict, List, Protocol, Tuple try: from textarena_env.models import TextArenaAction, TextArenaObservation except ImportError: from models import TextArenaAction, TextArenaObservation class RewardProvider(Protocol): """Interface for computing auxiliary reward signals.""" def reset(self) -> None: """Clear any internal state before a new episode.""" def compute( self, *, action: TextArenaAction, observation: TextArenaObservation ) -> Dict[str, float]: """Return a mapping of reward names to float values for the step.""" def build_reward_providers(env_id: str) -> List[RewardProvider]: """Instantiate reward providers appropriate for the given environment.""" providers: List[RewardProvider] = [] if env_id == "Wordle-v0": providers.append(_WordleRewardProvider()) return providers _WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]") def extract_guess(text: str) -> str: """Normalize a Wordle guess string from arbitrary text.""" match = _WORDLE_GUESS_PATTERN.search(text) if match: return match.group(0).lower() cleaned = re.sub(r"[^a-z]", "", text.lower()) if len(cleaned) >= 5: return f"[{cleaned[:5]}]" return "[dunno]" def extract_wordle_feedback(observation: TextArenaObservation) -> str: """Pull the latest feedback text from a Wordle observation.""" for message in reversed(observation.messages): content = message.content.strip() if "Feedback:" in content: return content.split("Feedback:", 1)[-1].strip() return "" def extract_feedback_counts(feedback: str) -> Tuple[int, int]: """Return counts of green (G) and yellow (Y) markers from feedback.""" if not feedback: return (0, 0) lines = [line.strip() for line in feedback.split("\n") if line.strip()] if len(lines) < 2: return (0, 0) for line in reversed(lines): normalized = line.replace(" ", "") if normalized and all(c in "GYX" for c in normalized): green = normalized.count("G") yellow = normalized.count("Y") return (green, yellow) return (0, 0) class _WordleRewardProvider: """Reward provider that mirrors the GRPO Wordle heuristics.""" SIGNAL_MAP = { "greens": "wordle.greens", "yellows": "wordle.yellows", "repetitions": "wordle.repetitions", "correct": "wordle.correct", } def __init__(self) -> None: self._guess_history: Dict[str, int] = {} def reset(self) -> None: self._guess_history.clear() def compute( self, *, action: TextArenaAction, observation: TextArenaObservation ) -> Dict[str, float]: guess = extract_guess(action.message) feedback = extract_wordle_feedback(observation) normalized_guess = guess if guess and guess != "[dunno]" else "" previous_occurrences = ( self._guess_history.get(normalized_guess, 0) if normalized_guess else 0 ) green_score = 0.0 yellow_score = 0.0 if feedback: green_count, yellow_count = extract_feedback_counts(feedback) green_score = green_count / 5.0 yellow_score = yellow_count / 5.0 repetition_score = 1.0 - previous_occurrences correct_score = float(observation.reward or 0.0) if normalized_guess: self._guess_history[normalized_guess] = previous_occurrences + 1 return { self.SIGNAL_MAP["greens"]: float(green_score), self.SIGNAL_MAP["yellows"]: float(yellow_score), self.SIGNAL_MAP["repetitions"]: float(repetition_score), self.SIGNAL_MAP["correct"]: float(correct_score), } __all__ = [ "RewardProvider", "build_reward_providers", "extract_feedback_counts", "extract_guess", "extract_wordle_feedback", ]