carrom_rl_env / examples /grpo_utils.py
bpHigh's picture
Initial: OpenEnv Carrom v0.3.0 — ICF-compliant physics RL environment
13b4881
"""GRPO training utilities for the Carrom environment (ICF rules).
Provides reward functions, prompt formatting, and rollout collection
compatible with TRL's GRPOTrainer.
"""
from __future__ import annotations
import json
import math
import re
from typing import Any, Dict, List, Optional
from carrom_env.env import CarromEnv
from carrom_env.models import Action, Observation
# ---------------------------------------------------------------------------
# System prompt — ICF-aware
# ---------------------------------------------------------------------------
CARROM_SYSTEM_PROMPT = """\
You are a Carrom agent playing under ICF (International Carrom Federation) rules.
Board: 1.0 × 1.0 square centred at (0,0). Pockets at the four corners (±0.5, ±0.5).
Your striker is on the BOTTOM baseline (y ≈ -0.42).
Colour assignment
-----------------
YOU play WHITE coins (+1 pt each). Opponent plays BLACK coins.
Queen (red, centre) = +3 pts — must be "covered" by pocketing a white coin on the
same or next shot.
ICF rules to follow
-------------------
- Pocket a WHITE coin → score +1, take another turn
- Pocket a BLACK coin → DUE: coin returns to board centre, turn ENDS (no score)
- Pocket STRIKER → FOUL: one pocketed coin returns to board, turn ends
- Miss (nothing own) → turn passes to opponent
Respond with a JSON object and nothing else:
{"placement_x": <-0.4 to 0.4>, "angle": <radians, 0=straight>, "force": <0.0 to 1.0>}
Think step by step: identify reachable WHITE coins near pockets, then choose angle/force."""
# ---------------------------------------------------------------------------
# Prompt formatting
# ---------------------------------------------------------------------------
def format_prompt(obs: Observation) -> str:
return f"<|system|>\n{CARROM_SYSTEM_PROMPT}\n<|user|>\n{obs.text_summary}\n<|assistant|>"
def format_chat_prompt(obs: Observation) -> List[Dict[str, str]]:
return [
{"role": "system", "content": CARROM_SYSTEM_PROMPT},
{"role": "user", "content": obs.text_summary},
]
# ---------------------------------------------------------------------------
# Response parsing
# ---------------------------------------------------------------------------
def parse_response(response: str) -> Optional[Action]:
text = response.strip()
text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
text = re.sub(r"^```(?:json)?\s*", "", text)
text = re.sub(r"\s*```$", "", text).strip()
match = re.search(r"\{[^}]+\}", text)
if not match:
return None
try:
data = json.loads(match.group())
return Action(
placement_x=float(data.get("placement_x", 0.0)),
angle=float(data.get("angle", 0.0)),
force=max(0.0, min(1.0, float(data.get("force", 0.5)))),
)
except (json.JSONDecodeError, ValueError, TypeError):
return None
# ---------------------------------------------------------------------------
# Reward functions for GRPO
# ---------------------------------------------------------------------------
def carrom_reward_fn(
prompts: List[str],
completions: List[str],
env_rewards: List[float],
**kwargs,
) -> List[float]:
"""Combined reward for GRPO: format + range + environment outcome."""
rewards = []
for prompt, completion, env_reward in zip(prompts, completions, env_rewards):
reward = float(env_reward)
action = parse_response(completion)
if action is not None:
reward += 0.2
if -0.4 <= action.placement_x <= 0.4: reward += 0.05
if 0.1 <= action.force <= 0.95: reward += 0.05
if -math.pi / 2 <= action.angle <= math.pi / 2: reward += 0.05
else:
reward -= 0.5
rewards.append(reward)
return rewards
def carrom_reward_for_trl(completions: List[Any], **kwargs) -> List[float]:
"""Reward function matching TRL GRPOTrainer's expected signature.
Evaluates each completion by:
1. Parsing the JSON action (+0.3 if valid, -0.5 if not)
2. Checking parameter ranges (+0.1 each for placement, force, angle)
3. Executing in a fresh env instance for the actual game reward
— includes due-coin penalty from ICF rules
Completions can be plain strings or chat message lists (TRL passes both).
"""
rewards = []
for completion in completions:
# TRL may pass chat lists (vLLM path) or plain strings
if isinstance(completion, list):
text = completion[-1]["content"] if completion else ""
else:
text = str(completion)
reward = 0.0
action = parse_response(text)
if action is not None:
reward += 0.3
if -0.4 <= action.placement_x <= 0.4: reward += 0.1
if 0.15 <= action.force <= 0.9: reward += 0.1
if -math.pi / 2 <= action.angle <= math.pi / 2: reward += 0.1
try:
env = CarromEnv(seed=hash(text) % 100_000)
env.reset()
_, env_reward, _, _, info = env.step(action)
reward += env_reward
if info.get("coin_potted", 0) > 0:
reward += 0.5
# Additional signal for ICF due violations
reward -= 0.3 * info.get("due_coins", 0)
except Exception:
pass
else:
reward -= 0.5
rewards.append(reward)
return rewards
# ---------------------------------------------------------------------------
# Helpers for offline rollout collection
# ---------------------------------------------------------------------------
def compute_env_reward(
env: CarromEnv,
response: str,
) -> tuple[float, bool, Observation]:
action = parse_response(response)
if action is None:
obs = Observation(
positions=[], velocities=[], pocketed=[],
agent_score=env.agent_score,
opponent_score=env.opponent_score,
current_player="agent",
remaining_coins=0,
text_summary="Parse error.",
)
return -0.5, False, obs
obs, reward, terminated, truncated, info = env.step(action)
return reward, terminated or truncated, obs
def collect_rollouts(
generate_fn,
num_rollouts: int = 8,
max_turns_per_episode: int = 15,
seed: int = 0,
) -> Dict[str, List]:
"""Collect rollouts for offline GRPO training.
Args:
generate_fn: Callable(prompt: str) -> str.
num_rollouts: Number of episodes (group size G in GRPO).
max_turns_per_episode: Max steps per episode.
seed: Base seed.
"""
all_prompts, all_completions, all_rewards = [], [], []
for rollout_idx in range(num_rollouts):
env = CarromEnv(seed=seed + rollout_idx)
obs = env.reset()
for _ in range(max_turns_per_episode):
prompt = format_prompt(obs)
completion = generate_fn(prompt)
reward, done, obs = compute_env_reward(env, completion)
all_prompts.append(prompt)
all_completions.append(completion)
all_rewards.append(reward)
if done:
break
return {"prompts": all_prompts, "completions": all_completions, "rewards": all_rewards}