# Extraction residue from the web file viewer this copy was captured from:
# Space status "Sleeping"; file size 13,548 bytes; ref 13b4881.
"""Baseline inference script for the Carrom OpenEnv environment.

Runs an LLM agent against the Carrom environment (ICF rules) and reports
game performance and Green Agent efficiency metrics.

Supports any OpenAI-compatible API endpoint. Configure via environment
variables so the same script works with HuggingFace Inference, Nebius,
vLLM, OpenAI, or any other compatible provider:

    # HuggingFace Inference Router
    export API_BASE_URL="https://api-inference.huggingface.co/v1"
    export MODEL_NAME="Qwen/Qwen3-4B"
    export HF_TOKEN="hf_..."

    # Nebius
    export API_BASE_URL="https://api.studio.nebius.com/v1"
    export MODEL_NAME="nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
    export NEBIUS_API_KEY="ey..."

    # OpenAI
    export API_BASE_URL="https://api.openai.com/v1"
    export MODEL_NAME="gpt-4o-mini"
    export OPENAI_API_KEY="sk-..."

    # Local vLLM
    export API_BASE_URL="http://localhost:8000/v1"
    export MODEL_NAME="Qwen/Qwen2.5-7B-Instruct"
    # No key needed for local

Then run:

    python inference.py
"""
from __future__ import annotations
import argparse
import json
import os
import re
import subprocess
import sys
import time
from typing import Optional
import requests
from carrom_env.env import CarromEnv
from carrom_env.models import Action, Observation
from carrom_env.green_agent import GreenCarromAgent, EvalReport, Task
# ---------------------------------------------------------------------------
# Configuration — all overridable via environment variables
# ---------------------------------------------------------------------------
# Base URL of the OpenAI-compatible API; call_llm() posts to
# "<API_BASE_URL>/chat/completions".
API_BASE_URL = os.environ.get(
    "API_BASE_URL",
    "https://api-inference.huggingface.co/v1",
)
# Model identifier sent in the "model" field of every chat-completion request.
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen3-4B")
# API key: checked in priority order — NEBIUS_API_KEY → OPENAI_API_KEY → HF_TOKEN.
# Resolves to "" when none of the three variables is set (or all are empty).
API_KEY: str = (
    os.environ.get("NEBIUS_API_KEY")
    or os.environ.get("OPENAI_API_KEY")
    or os.environ.get("HF_TOKEN")
    or ""
)
# Turn horizon per task; main() passes it to build_task_suite() as max_steps.
MAX_STEPS = int(os.environ.get("MAX_STEPS", "30"))
# Number of tasks in the shared evaluation suite.
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "3"))
# Wall-clock budget in seconds (TIMEOUT_MINUTES * 60). main() only starts the
# LLM baseline while more than 120 s of this budget remain.
TIMEOUT = int(os.environ.get("TIMEOUT_MINUTES", "20")) * 60
# ---------------------------------------------------------------------------
# System prompt (ICF rules)
# ---------------------------------------------------------------------------
# System message sent with every request in call_llm(). It describes the board
# geometry, the ICF scoring rules and the exact JSON shape the model must
# answer with (the keys read back by _parse_json_action()).
SYSTEM_PROMPT = """\
You are an expert Carrom player following ICF (International Carrom Federation) rules.

Board layout
------------
- 1.0 × 1.0 square centred at (0, 0). Pockets at the four corners (±0.5, ±0.5).
- Your striker starts on the BOTTOM baseline (y ≈ -0.42).
- You play WHITE coins. The opponent plays BLACK coins.

Scoring & rules
---------------
- Pocket a WHITE coin → +1 point, take another turn
- Pocket the QUEEN → +3 points; you must then pocket a white coin on the
  same shot OR your next turn to "cover" it
- Pocket a BLACK coin → DUE: coin returns to board centre, your turn ENDS
- Pocket the STRIKER → FOUL: one of your pocketed coins returns to board

Action format
-------------
Respond with ONLY a valid JSON object (no markdown, no explanation):
{
  "placement_x": <float, -0.4 to 0.4, 0 = centre>,
  "angle": <float, radians, 0 = straight ahead toward +y>,
  "force": <float, 0.0 to 1.0>
}

Strategy tips
-------------
- Prioritise white coins close to pockets for easy points
- Avoid shooting black coins — even if they are near a pocket
- Queen near centre: aim to pocket it AND a white coin in the same shot
- Adjust placement_x to get a direct line on your target
"""
# ---------------------------------------------------------------------------
# LLM interaction
# ---------------------------------------------------------------------------
def call_llm(observation_text: str) -> Optional[dict]:
    """Ask the configured chat-completions endpoint for the next shot.

    Sends ``SYSTEM_PROMPT`` as the system message and *observation_text* as
    the user message to ``<API_BASE_URL>/chat/completions``, then hands the
    reply text to ``_parse_json_action``.

    Args:
        observation_text: Text description of the current board state.

    Returns:
        The parsed action dict (``placement_x``, ``angle``, ``force``), or
        ``None`` when the request fails or no action could be parsed. Errors
        are printed, never raised, so a flaky endpoint cannot abort a run.
    """
    headers = {"Content-Type": "application/json"}
    # Only authenticate when a key is configured: keyless endpoints (e.g. a
    # local vLLM server) should not receive an empty "Bearer " token.
    if API_KEY:
        headers["Authorization"] = f"Bearer {API_KEY}"
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": observation_text},
        ],
        # Generous budget to accommodate reasoning models (e.g. MiniMax-M2.5,
        # Nemotron) that emit long CoT before the final JSON answer.
        "max_tokens": int(os.environ.get("MAX_TOKENS", "2048")),
        "temperature": 0.3,
    }
    # Tolerate a trailing slash in API_BASE_URL ("…/v1/") so the request does
    # not go to "…/v1//chat/completions".
    endpoint = API_BASE_URL.rstrip("/") + "/chat/completions"
    try:
        resp = requests.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=120,
        )
        resp.raise_for_status()
        msg = resp.json()["choices"][0]["message"]
        # Reasoning models put their final answer in `content` and the trace in
        # `reasoning_content`. Fall back to reasoning_content if content is
        # null (common when the JSON is inline inside the reasoning).
        text = msg.get("content") or msg.get("reasoning_content") or ""
        return _parse_json_action(text)
    except Exception as e:
        # Deliberate best-effort boundary: the caller falls back to a random
        # shot, so report the failure and keep the evaluation running.
        print(f" [LLM error] {e}")
        return None
def _parse_json_action(text: str) -> Optional[dict]:
text = text.strip()
text = re.sub(r"^```(?:json)?\s*", "", text)
text = re.sub(r"\s*```$", "", text)
# Strip <think>β¦</think> blocks (some reasoning models)
text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
match = re.search(r"\{[^}]+\}", text)
if match:
try:
data = json.loads(match.group())
return {
"placement_x": float(data.get("placement_x", 0.0)),
"angle": float(data.get("angle", 0.0)),
"force": float(data.get("force", 0.5)),
}
except (json.JSONDecodeError, ValueError, TypeError):
pass
return None
# ---------------------------------------------------------------------------
# Policies
# ---------------------------------------------------------------------------
# Running count of LLM shots across all tasks; used only to label log lines.
_llm_turn_counter = {"n": 0}


def llm_policy(obs: Observation) -> Action:
    """Choose a shot by querying the LLM, with a random shot as fallback.

    Args:
        obs: Current observation; ``obs.text_summary`` is sent to the model,
            and the scores / remaining coins are echoed in the log line.

    Returns:
        The action proposed by the model, or a random moderate shot when the
        model's reply could not be turned into an ``Action``.
    """
    import random

    _llm_turn_counter["n"] += 1
    shot_no = _llm_turn_counter["n"]
    parsed = call_llm(obs.text_summary)
    if parsed:
        try:
            action = Action(**parsed)
        except (ValueError, TypeError) as exc:
            # NOTE(review): Action is defined outside this file. If it validates
            # its fields, an out-of-range model answer would otherwise abort the
            # whole evaluation; treat it like a parse failure instead.
            print(f" [shot {shot_no:>3}] INVALID ACTION ({type(exc).__name__}) — random fallback",
                  flush=True)
        else:
            print(f" [shot {shot_no:>3}] "
                  f"px={action.placement_x:+.2f} "
                  f"angle={action.angle:+.2f} "
                  f"force={action.force:.2f} "
                  f"(score {obs.agent_score}-{obs.opponent_score}, "
                  f"coins left {obs.remaining_coins})", flush=True)
            return action
    else:
        print(f" [shot {shot_no:>3}] PARSE FAIL — random fallback", flush=True)
    # Fallback: a moderate shot near the centre of the baseline.
    return Action(
        placement_x=random.uniform(-0.2, 0.2),
        angle=random.uniform(-0.5, 0.5),
        force=random.uniform(0.3, 0.8),
    )
def random_policy(obs: Observation) -> Action:
    """Return a uniformly random shot; the observation is not consulted."""
    from random import uniform

    # Draw order: placement, then angle, then force.
    striker_x = uniform(-0.35, 0.35)
    shot_angle = uniform(-1.0, 1.0)
    shot_force = uniform(0.2, 1.0)
    return Action(placement_x=striker_x, angle=shot_angle, force=shot_force)
def heuristic_policy(obs: Observation) -> Action:
    """Aim straight at the white coin (or queen) with the best pocket score.

    Every coin still on the board except black ones is scored by
    ``coin.pocket_distance`` (halved for the queen, so it is preferred); the
    lowest score wins. The striker is placed at half the target's x offset
    (clamped to ±0.35) and aimed along the line from that position to the coin.

    Args:
        obs: Current observation; only ``obs.coins`` is read.

    Returns:
        The shot at the selected coin with force 0.6, or a straight shot from
        the centre when no eligible coin is left.
    """
    import math

    baseline_y = -0.5 + 0.08  # y of the striker line (-0.42)
    best_angle = 0.0
    best_placement = 0.0
    best_score = float("inf")
    for coin in obs.coins:
        if coin.pocketed:
            continue
        # Skip black coins — pocketing them is a due under ICF rules
        if coin.color == "black":
            continue
        score = coin.pocket_distance
        if coin.color == "queen":
            score *= 0.5
        if score < best_score:
            placement = max(-0.35, min(0.35, coin.x * 0.5))
            best_score = score
            best_placement = placement
            # Measure the angle from where the striker will actually sit, not
            # from x = 0, otherwise any non-zero placement shifts the shot line
            # off the target. 0 rad points toward +y, hence atan2(dx, dy).
            best_angle = math.atan2(coin.x - placement, coin.y - baseline_y)
    return Action(placement_x=best_placement, angle=best_angle, force=0.6)
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def build_task_suite(num_episodes: int, max_steps: int) -> list[Task]:
    """Create the flat baseline suite shared by every policy.

    Produces ``num_episodes`` tasks with ids ``ep_0``, ``ep_1``, … and seeds
    0, 100, 200, …, all using ``max_steps`` as turn limit and the "standard"
    tier, so each policy is evaluated on the same set of boards.
    """

    def make_task(index: int) -> Task:
        return Task(
            task_id=f"ep_{index}",
            seed=index * 100,
            max_turns=max_steps,
            tier="standard",
        )

    return [make_task(index) for index in range(num_episodes)]
def run_baseline(policy_fn, policy_name: str, tasks: list[Task]) -> EvalReport:
    """Evaluate one purple agent on the shared suite and print its scorecard.

    Args:
        policy_fn: Policy callable handed to the green-agent evaluator.
        policy_name: Label used in the printed headings.
        tasks: Task suite to evaluate on.

    Returns:
        The report produced by ``GreenCarromAgent.evaluate``.
    """
    print(f"\n--- Evaluating: {policy_name} ({len(tasks)} tasks) ---")
    report = GreenCarromAgent(tasks=tasks).evaluate(policy_fn, verbose=True)
    stats = report.summary()
    scorecard = (
        f"\n=== {policy_name} ({stats['n_tasks']} tasks) ===",
        f" Avg reward : {stats['avg_reward']:+.3f}",
        f" Win rate : {stats['win_rate']:.2f}",
        f" Avg coins : {stats['avg_coins_potted']:.1f}",
        f" Avg dues : {stats['avg_dues']:.2f} (ICF violations)",
        f" Avg fouls : {stats['avg_fouls']:.2f}",
        f" ICF compliance : {stats['icf_compliance']:.3f}",
        f" Sim steps : {stats['total_sim_steps']}",
        f" Efficiency : {stats['efficiency_score']:.4f} coins/1k-steps",
    )
    for line in scorecard:
        print(line)
    return report
def launch_web_server(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Start the FastAPI + Gradio server (foreground) and print the watch URL.

    Use this when you want to watch the LLM play on the board and screen-record
    it. Configure the endpoint/model/key inside the "Auto-play with LLM" panel
    in the browser, then click "Auto-play" to stream animated shots.

    The environment variables ``API_BASE_URL``, ``MODEL_NAME``, and an API key
    (``NEBIUS_API_KEY`` / ``OPENAI_API_KEY`` / ``HF_TOKEN``) are inherited as
    defaults in the web form.

    Args:
        host: Interface passed to uvicorn's ``--host``.
        port: Port passed to uvicorn's ``--port``; also used in the printed URL.
    """
    # The child process gets a copy of our environment with the web UI enabled.
    env = os.environ.copy()
    env["ENABLE_WEB_INTERFACE"] = "true"
    # Only supplies a PYTHONPATH when the caller has not set one.
    env.setdefault("PYTHONPATH", os.getcwd())
    url = f"http://localhost:{port}/web"
    print("=" * 70)
    print("Carrom server starting with web UI…")
    print(f" Open: {url}")
    print(" Inside the UI, configure model/endpoint, set number of shots,")
    print(' then click the "🤖 Auto-play with LLM" button to watch it play.')
    print(" Press Ctrl+C in this terminal to stop the server.")
    print("=" * 70)
    cmd = [
        sys.executable, "-m", "uvicorn",
        "server.app:app",
        "--host", host,
        "--port", str(port),
        "--ws-ping-interval", "60",
        "--ws-ping-timeout", "60",
    ]
    # Foreground: user ctrl-c's to stop
    try:
        subprocess.run(cmd, env=env, check=False)
    except KeyboardInterrupt:
        print("\nServer stopped.")
def _is_local_endpoint(url: str) -> bool:
    """Return True when *url* targets this machine (localhost / loopback)."""
    from urllib.parse import urlparse

    try:
        host = urlparse(url).hostname or ""
    except ValueError:
        # Malformed URL (e.g. unbalanced IPv6 brackets): treat as not local.
        return False
    return host == "localhost" or host.startswith("127.") or host in {"::1", "0.0.0.0"}


def main():
    """Command-line entry point.

    With ``--web`` the web server is started and nothing else runs. Otherwise
    the Random and Heuristic baselines are evaluated on a shared task suite,
    followed by the LLM baseline when an endpoint is usable and enough of the
    time budget is left, and a leaderboard is printed.
    """
    parser = argparse.ArgumentParser(
        description="Carrom inference — headless baselines or live web view."
    )
    parser.add_argument(
        "--web", action="store_true",
        help="Start the env server + Gradio web UI for auto-play watching "
             "(screen-record friendly). No headless baselines run in this mode.",
    )
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    args = parser.parse_args()
    if args.web:
        launch_web_server(host=args.host, port=args.port)
        return

    print(f"API endpoint : {API_BASE_URL}")
    print(f"Model : {MODEL_NAME}")
    print(f"API key set : {'yes' if API_KEY else 'no'}")
    # Shared task suite — every policy sees the same boards (deterministic)
    tasks = build_task_suite(NUM_EPISODES, MAX_STEPS)
    print(f"Task suite : {len(tasks)} × {MAX_STEPS}-turn boards\n")
    start = time.time()
    reports: dict[str, EvalReport] = {}

    print("=" * 60 + "\nPURPLE AGENT: Random\n" + "=" * 60)
    reports["random"] = run_baseline(random_policy, "Random", tasks)

    print("\n" + "=" * 60 + "\nPURPLE AGENT: Heuristic (ICF-aware)\n" + "=" * 60)
    reports["heuristic"] = run_baseline(heuristic_policy, "Heuristic", tasks)

    # A key is required for hosted providers, but a server on this machine
    # (the documented local vLLM setup) is expected to work without one.
    if API_KEY or _is_local_endpoint(API_BASE_URL):
        elapsed = time.time() - start
        # Keep a two-minute safety margin inside the overall time budget.
        if elapsed < TIMEOUT - 120:
            print(f"\n{'=' * 60}\nPURPLE AGENT: LLM ({MODEL_NAME})\n{'=' * 60}")
            reports["llm"] = run_baseline(llm_policy, f"LLM ({MODEL_NAME})", tasks)
        else:
            print(f"\nSkipping LLM baseline — {elapsed:.0f}s elapsed.")
    else:
        print("\nSkipping LLM baseline — no API key (set NEBIUS_API_KEY / OPENAI_API_KEY / HF_TOKEN).")

    # ── Leaderboard ────────────────────────────────────────────────
    print("\n" + "=" * 78 + "\nLEADERBOARD\n" + "=" * 78)
    print(f"{'Purple Agent':<25} {'Reward':>8} {'Win%':>6} {'Coins':>6} {'Dues':>6} {'ICF%':>6} {'Eff':>8}")
    print("-" * 78)
    for name, report in reports.items():
        s = report.summary()
        print(f"{name:<25} {s['avg_reward']:>+8.2f} {s['win_rate']*100:>5.0f}% "
              f"{s['avg_coins_potted']:>6.1f} {s['avg_dues']:>6.2f} "
              f"{s['icf_compliance']*100:>5.0f}% {s['efficiency_score']:>8.3f}")
    print(f"\nTotal runtime: {time.time() - start:.1f}s")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|