"""Reward provider utilities for TextArena environments."""
from __future__ import annotations
import re
from typing import Dict, List, Protocol, Tuple
try:
from textarena_env.models import TextArenaAction, TextArenaObservation
except ImportError:
from models import TextArenaAction, TextArenaObservation
class RewardProvider(Protocol):
    """Structural interface for objects that emit auxiliary reward signals.

    Any object exposing compatible ``reset`` and ``compute`` methods
    satisfies this protocol; explicit inheritance is not required.
    """

    def reset(self) -> None:
        """Drop any per-episode state so a new episode starts clean."""
        ...

    def compute(
        self,
        *,
        action: TextArenaAction,
        observation: TextArenaObservation,
    ) -> Dict[str, float]:
        """Score a single environment step.

        Both arguments are keyword-only.

        Args:
            action: The action that was submitted for this step.
            observation: The observation produced for this step.

        Returns:
            A dict keyed by reward-signal name with float values.
        """
        ...
def build_reward_providers(env_id: str) -> List[RewardProvider]:
    """Return the reward providers that apply to ``env_id``.

    Only the exact id ``"Wordle-v0"`` has a provider registered here; any
    other id produces an empty list. A new list (and, for Wordle, a new
    provider instance) is created on every call.

    Args:
        env_id: Environment identifier to look up.

    Returns:
        A list holding zero or one provider instances.
    """
    if env_id != "Wordle-v0":
        return []
    return [_WordleRewardProvider()]
# Matches a bracketed run of exactly five ASCII letters, e.g. "[crane]".
_WORDLE_GUESS_PATTERN = re.compile(r"\[[A-Za-z]{5}\]")


def extract_guess(text: str) -> str:
    """Coerce free-form text into a bracketed, lower-case Wordle guess.

    Resolution order:

    1. The first ``[xxxxx]`` substring (five ASCII letters in brackets),
       lower-cased.
    2. Otherwise, the text is lower-cased, every character outside ``a-z``
       is dropped, and the first five remaining letters are bracketed.
    3. Otherwise (fewer than five letters available), ``"[dunno]"``.

    Args:
        text: Arbitrary text that may contain a guess.

    Returns:
        A seven-character string of the form ``"[abcde]"``.
    """
    bracketed = _WORDLE_GUESS_PATTERN.search(text)
    if bracketed is not None:
        return bracketed.group(0).lower()

    # No explicit bracketed guess: fall back to the leading letters.
    letters = re.sub(r"[^a-z]", "", text.lower())
    if len(letters) < 5:
        return "[dunno]"
    return f"[{letters[:5]}]"
def extract_wordle_feedback(observation: TextArenaObservation) -> str:
    """Return the text that follows ``Feedback:`` in the newest matching message.

    ``observation.messages`` is walked from last to first. For the first
    message whose stripped ``content`` contains ``"Feedback:"``, everything
    after the first occurrence of that marker is returned with surrounding
    whitespace removed.

    Args:
        observation: Object exposing a ``messages`` sequence whose items
            have a string ``content`` attribute.

    Returns:
        The feedback text, or ``""`` when no message contains the marker.
    """
    for entry in reversed(observation.messages):
        # partition() yields an empty separator when the marker is absent.
        _, marker, tail = entry.content.strip().partition("Feedback:")
        if marker:
            return tail.strip()
    return ""
def extract_feedback_counts(feedback: str) -> Tuple[int, int]:
    """Count green (``G``) and yellow (``Y``) markers in a feedback block.

    The feedback is split on newlines and blank lines are discarded. At
    least two non-blank lines are required. Scanning from the bottom up,
    the first line that (after removing spaces) consists solely of the
    upper-case characters ``G``, ``Y`` and ``X`` is treated as the marker
    row.

    Args:
        feedback: Feedback text; may be empty.

    Returns:
        ``(green_count, yellow_count)``, or ``(0, 0)`` when the input is
        empty, has fewer than two non-blank lines, or has no marker row.
    """
    if not feedback:
        return (0, 0)

    stripped_rows = (raw.strip() for raw in feedback.split("\n"))
    rows = [row for row in stripped_rows if row]
    if len(rows) < 2:
        return (0, 0)

    allowed = {"G", "Y", "X"}
    for row in rows[::-1]:
        marks = row.replace(" ", "")
        if marks and set(marks) <= allowed:
            return (marks.count("G"), marks.count("Y"))
    return (0, 0)
class _WordleRewardProvider:
    """Reward provider that mirrors the GRPO Wordle heuristics.

    Emits four signals per step: the fraction of green markers, the
    fraction of yellow markers, a repetition score, and the environment's
    own reward. A per-episode tally of guesses drives the repetition score.
    """

    # Short signal key -> fully qualified name used in the returned dict.
    SIGNAL_MAP = {
        "greens": "wordle.greens",
        "yellows": "wordle.yellows",
        "repetitions": "wordle.repetitions",
        "correct": "wordle.correct",
    }

    def __init__(self) -> None:
        """Start with an empty guess tally."""
        # Maps a bracketed guess (e.g. "[crane]") to how often it was seen.
        self._guess_history: Dict[str, int] = {}

    def reset(self) -> None:
        """Forget every guess recorded so far."""
        self._guess_history.clear()

    def compute(
        self,
        *,
        action: TextArenaAction,
        observation: TextArenaObservation,
    ) -> Dict[str, float]:
        """Compute the Wordle reward signals for one step.

        Args:
            action: Step action; its ``message`` is parsed for the guess.
            observation: Step observation; its messages supply the feedback
                and its ``reward`` supplies the "correct" signal.

        Returns:
            A dict with the four names in ``SIGNAL_MAP`` as keys:
            greens and yellows are marker counts divided by 5, repetitions
            is ``1.0`` minus the number of earlier occurrences of this
            guess (so it drops below zero after the second repeat), and
            correct is ``observation.reward`` (``0.0`` when falsy).
        """
        guess = extract_guess(action.message)
        feedback = extract_wordle_feedback(observation)

        # The "[dunno]" placeholder is never recorded in the tally.
        tracked = bool(guess) and guess != "[dunno]"
        seen_before = self._guess_history.get(guess, 0) if tracked else 0

        if feedback:
            greens, yellows = extract_feedback_counts(feedback)
        else:
            greens, yellows = 0, 0
        green_fraction = greens / 5.0
        yellow_fraction = yellows / 5.0

        repetition = 1.0 - seen_before
        correct = float(observation.reward or 0.0)

        # Record the guess only after all scores have been derived.
        if tracked:
            self._guess_history[guess] = seen_before + 1

        names = self.SIGNAL_MAP
        return {
            names["greens"]: float(green_fraction),
            names["yellows"]: float(yellow_fraction),
            names["repetitions"]: float(repetition),
            names["correct"]: float(correct),
        }
# Names exported by a star-import of this module. The Wordle provider class
# is not listed; build_reward_providers() is what instantiates it.
__all__ = [
"RewardProvider",
"build_reward_providers",
"extract_feedback_counts",
"extract_guess",
"extract_wordle_feedback",
]
|