wordle / rewards.py
sergiopaniego's picture
Upload folder using huggingface_hub
0f3ecd2 verified
"""Reward provider utilities for TextArena environments."""
from __future__ import annotations
import re
from typing import Dict, List, Protocol, Tuple
try:
from textarena_env.models import TextArenaAction, TextArenaObservation
except ImportError:
from models import TextArenaAction, TextArenaObservation
class RewardProvider(Protocol):
    """Structural interface implemented by auxiliary reward calculators.

    Any object exposing compatible ``reset`` and ``compute`` methods
    satisfies this protocol; no explicit inheritance is required.
    """

    def reset(self) -> None:
        """Drop per-episode state so a fresh episode starts clean."""
        ...

    def compute(
        self,
        *,
        action: TextArenaAction,
        observation: TextArenaObservation,
    ) -> Dict[str, float]:
        """Score a single step.

        Args:
            action: The action submitted for this step (keyword-only).
            observation: The observation produced for this step
                (keyword-only).

        Returns:
            A mapping from reward-signal name to its float value.
        """
        ...
def build_reward_providers(env_id: str) -> List[RewardProvider]:
    """Create the reward providers registered for ``env_id``.

    Args:
        env_id: Environment identifier; compared exactly (case-sensitive).

    Returns:
        A new list holding one Wordle reward provider when ``env_id`` is
        ``"Wordle-v0"``, otherwise a new empty list.
    """
    # Only the Wordle environment has auxiliary rewards defined here.
    if env_id != "Wordle-v0":
        return []
    return [_WordleRewardProvider()]
# A well-formed guess: exactly five ASCII letters wrapped in square brackets.
_WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]")


def extract_guess(text: str) -> str:
    """Turn free-form text into a canonical ``[xxxxx]`` Wordle guess.

    Resolution order:
      1. The first bracketed five-letter token found in ``text``, lowercased.
      2. Otherwise, the first five ASCII letters of the lowercased text
         (all other characters ignored), wrapped in brackets.
      3. Otherwise the placeholder ``"[dunno]"``.

    Args:
        text: Raw text that may contain a guess.

    Returns:
        A bracketed, lowercase guess string or ``"[dunno]"``.
    """
    bracketed = _WORDLE_GUESS_PATTERN.search(text)
    if bracketed is not None:
        return bracketed.group(0).lower()

    # No explicit bracketed guess: fall back to whatever letters are present.
    letters_only = re.sub(r"[^a-z]", "", text.lower())
    if len(letters_only) < 5:
        return "[dunno]"
    return "[" + letters_only[:5] + "]"
def extract_wordle_feedback(observation: TextArenaObservation) -> str:
    """Return the feedback text from the most recent message that has one.

    Messages are scanned from last to first. For the first message whose
    content contains the marker ``"Feedback:"``, everything after the first
    occurrence of that marker is returned with surrounding whitespace
    removed.

    Args:
        observation: Object exposing ``messages``, each with a ``content``
            string.

    Returns:
        The stripped text following the marker, or ``""`` when no message
        contains it.
    """
    for entry in reversed(observation.messages):
        _, marker, remainder = entry.content.strip().partition("Feedback:")
        if marker:  # empty when the marker is absent from this message
            return remainder.strip()
    return ""
def extract_feedback_counts(feedback: str) -> Tuple[int, int]:
    """Count green and yellow markers in a block of Wordle feedback.

    The feedback is split into non-blank, stripped lines. Scanning from the
    last line upward, the first line that (ignoring spaces) consists solely
    of the characters ``G``, ``Y`` and ``X`` is taken as the marker row.

    Args:
        feedback: Feedback text; must contain at least two non-blank lines
            for any marker row to be considered.

    Returns:
        ``(green_count, yellow_count)`` for the marker row, or ``(0, 0)``
        when the feedback is empty, has fewer than two non-blank lines, or
        has no marker row.
    """
    if not feedback:
        return (0, 0)

    rows = [stripped for stripped in map(str.strip, feedback.split("\n")) if stripped]
    if len(rows) < 2:
        return (0, 0)

    marker_alphabet = frozenset("GYX")
    compacted = (row.replace(" ", "") for row in reversed(rows))
    marker_row = next(
        (cells for cells in compacted if cells and marker_alphabet.issuperset(cells)),
        None,
    )
    if marker_row is None:
        return (0, 0)
    return (marker_row.count("G"), marker_row.count("Y"))
class _WordleRewardProvider:
    """Wordle reward signals modelled on the GRPO Wordle heuristics.

    Emits four signals per step (see ``SIGNAL_MAP`` for the output keys):
      * greens      - green markers in the latest feedback, divided by 5.
      * yellows     - yellow markers in the latest feedback, divided by 5.
      * repetitions - ``1.0`` minus the number of times this guess was
                      already submitted in the current episode.
      * correct     - the observation's own reward (``0.0`` when falsy).
    """

    # Short signal name -> key used in the dictionary returned by compute().
    SIGNAL_MAP = {
        "greens": "wordle.greens",
        "yellows": "wordle.yellows",
        "repetitions": "wordle.repetitions",
        "correct": "wordle.correct",
    }

    def __init__(self) -> None:
        # Bracketed guess -> number of times it has been submitted so far.
        self._guess_history: Dict[str, int] = {}

    def reset(self) -> None:
        """Forget every guess recorded so far."""
        self._guess_history.clear()

    def compute(
        self,
        *,
        action: TextArenaAction,
        observation: TextArenaObservation,
    ) -> Dict[str, float]:
        """Score one step and record the guess for repetition tracking.

        Args:
            action: Step action; its ``message`` is parsed for the guess.
            observation: Step observation; supplies the feedback messages
                and the ``reward`` used for the ``correct`` signal.

        Returns:
            Mapping of the four ``SIGNAL_MAP`` output keys to float scores.
        """
        guess = extract_guess(action.message)
        feedback = extract_wordle_feedback(observation)

        # The "[dunno]" placeholder is not a real guess: it is never counted
        # as a repetition and never stored in the history.
        is_real_guess = bool(guess) and guess != "[dunno]"
        seen_before = self._guess_history.get(guess, 0) if is_real_guess else 0

        if feedback:
            greens, yellows = extract_feedback_counts(feedback)
        else:
            greens, yellows = 0, 0

        scores = {
            "greens": greens / 5.0,
            "yellows": yellows / 5.0,
            "repetitions": 1.0 - seen_before,
            "correct": float(observation.reward or 0.0),
        }

        # Record the guess only after scoring so the current submission does
        # not count against itself.
        if is_real_guess:
            self._guess_history[guess] = seen_before + 1

        return {self.SIGNAL_MAP[name]: float(value) for name, value in scores.items()}
# Public API of this module: the names exported by a star-import.
# The concrete `_WordleRewardProvider` class is not listed here; instances are
# created through `build_reward_providers`.
__all__ = [
    "RewardProvider",
    "build_reward_providers",
    "extract_feedback_counts",
    "extract_guess",
    "extract_wordle_feedback",
]