Spaces:
Running
Running
| """Baseline inference script for the Carrom OpenEnv environment. | |
| Runs an LLM agent against the Carrom environment (ICF rules) and reports | |
| game performance and Green Agent efficiency metrics. | |
| Supports any OpenAI-compatible API endpoint. Configure via environment | |
| variables so the same script works with HuggingFace Inference, Nebius, | |
| vLLM, OpenAI, or any other compatible provider: | |
| # HuggingFace Inference Router | |
| export API_BASE_URL="https://api-inference.huggingface.co/v1" | |
| export MODEL_NAME="Qwen/Qwen3-4B" | |
| export HF_TOKEN="hf_..." | |
| # Nebius | |
| export API_BASE_URL="https://api.studio.nebius.com/v1" | |
| export MODEL_NAME="nvidia/Llama-3.1-Nemotron-70B-Instruct-HF" | |
| export NEBIUS_API_KEY="ey..." | |
| # OpenAI | |
| export API_BASE_URL="https://api.openai.com/v1" | |
| export MODEL_NAME="gpt-4o-mini" | |
| export OPENAI_API_KEY="sk-..." | |
| # Local vLLM | |
| export API_BASE_URL="http://localhost:8000/v1" | |
| export MODEL_NAME="Qwen/Qwen2.5-7B-Instruct" | |
| # No key needed for local | |
| Then run: | |
| python inference.py | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import re | |
| import subprocess | |
| import sys | |
| import time | |
| from typing import Optional | |
| import requests | |
| from carrom_env.env import CarromEnv | |
| from carrom_env.models import Action, Observation | |
| from carrom_env.green_agent import GreenCarromAgent, EvalReport, Task | |
# ---------------------------------------------------------------------------
# Configuration — all overridable via environment variables
# ---------------------------------------------------------------------------
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen3-4B")

# API key: the first non-empty value among NEBIUS_API_KEY → OPENAI_API_KEY →
# HF_TOKEN (checked in that order), or "" when none of them is set.
API_KEY: str = next(
    filter(
        None,
        (os.environ.get(var) for var in ("NEBIUS_API_KEY", "OPENAI_API_KEY", "HF_TOKEN")),
    ),
    "",
)

MAX_STEPS = int(os.environ.get("MAX_STEPS", "30"))  # turn horizon of each task
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "3"))  # number of tasks in the suite
TIMEOUT = int(os.environ.get("TIMEOUT_MINUTES", "20")) * 60  # overall budget, in seconds
# ---------------------------------------------------------------------------
# System prompt (ICF rules)
# ---------------------------------------------------------------------------
# Sent as the system message on every LLM call (see call_llm). The text uses
# non-ASCII symbols (×, ±, ≈, →, —); keep this file UTF-8 encoded so the model
# receives them intact rather than as mis-decoded bytes.
SYSTEM_PROMPT = """\
You are an expert Carrom player following ICF (International Carrom Federation) rules.
Board layout
------------
- 1.0 × 1.0 square centred at (0, 0). Pockets at the four corners (±0.5, ±0.5).
- Your striker starts on the BOTTOM baseline (y ≈ -0.42).
- You play WHITE coins. The opponent plays BLACK coins.
Scoring & rules
---------------
- Pocket a WHITE coin → +1 point, take another turn
- Pocket the QUEEN → +3 points; you must then pocket a white coin on the
same shot OR your next turn to "cover" it
- Pocket a BLACK coin → DUE: coin returns to board centre, your turn ENDS
- Pocket the STRIKER → FOUL: one of your pocketed coins returns to board
Action format
-------------
Respond with ONLY a valid JSON object (no markdown, no explanation):
{
"placement_x": <float, -0.4 to 0.4, 0 = centre>,
"angle": <float, radians, 0 = straight ahead toward +y>,
"force": <float, 0.0 to 1.0>
}
Strategy tips
-------------
- Prioritise white coins close to pockets for easy points
- Avoid shooting black coins — even if they are near a pocket
- Queen near centre: aim to pocket it AND a white coin in the same shot
- Adjust placement_x to get a direct line on your target
"""
# ---------------------------------------------------------------------------
# LLM interaction
# ---------------------------------------------------------------------------
def call_llm(observation_text: str) -> Optional[dict]:
    """Request one shot from the configured OpenAI-compatible endpoint.

    Sends ``SYSTEM_PROMPT`` as the system message and *observation_text* as
    the user message to ``{API_BASE_URL}/chat/completions``, then parses the
    reply with ``_parse_json_action``.

    Args:
        observation_text: textual board summary shown to the model.

    Returns:
        A dict with float ``placement_x``, ``angle`` and ``force`` keys, or
        ``None`` when the request fails or the reply holds no usable action.
        Errors are printed, never raised, so one bad call cannot abort a run.
    """
    headers = {"Content-Type": "application/json"}
    # Only send a bearer token when one is configured: "Bearer " with an empty
    # token is a malformed credential that keyless local servers may reject.
    if API_KEY:
        headers["Authorization"] = f"Bearer {API_KEY}"
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": observation_text},
        ],
        # Generous budget to accommodate reasoning models (e.g. MiniMax-M2.5,
        # Nemotron) that emit long CoT before the final JSON answer.
        "max_tokens": int(os.environ.get("MAX_TOKENS", "2048")),
        "temperature": 0.3,
    }
    # Tolerate a trailing slash in API_BASE_URL (".../v1/") so the request is
    # not sent to ".../v1//chat/completions".
    url = f"{API_BASE_URL.rstrip('/')}/chat/completions"
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=120)
        resp.raise_for_status()
        msg = resp.json()["choices"][0]["message"]
        # Reasoning models put their final answer in `content` and the trace in
        # `reasoning_content`. Fall back to reasoning_content if content is
        # null (common when the JSON is inline inside the reasoning).
        text = msg.get("content") or msg.get("reasoning_content") or ""
        return _parse_json_action(text)
    except Exception as e:  # deliberate best-effort: caller falls back to random
        print(f" [LLM error] {e}")
        return None
| def _parse_json_action(text: str) -> Optional[dict]: | |
| text = text.strip() | |
| text = re.sub(r"^```(?:json)?\s*", "", text) | |
| text = re.sub(r"\s*```$", "", text) | |
| # Strip <think>β¦</think> blocks (some reasoning models) | |
| text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip() | |
| match = re.search(r"\{[^}]+\}", text) | |
| if match: | |
| try: | |
| data = json.loads(match.group()) | |
| return { | |
| "placement_x": float(data.get("placement_x", 0.0)), | |
| "angle": float(data.get("angle", 0.0)), | |
| "force": float(data.get("force", 0.5)), | |
| } | |
| except (json.JSONDecodeError, ValueError, TypeError): | |
| pass | |
| return None | |
# ---------------------------------------------------------------------------
# Policies
# ---------------------------------------------------------------------------
# Running count of LLM shots across all tasks; used only to label log lines.
_llm_turn_counter = {"n": 0}


def llm_policy(obs: Observation) -> Action:
    """Purple agent that asks the LLM for every shot.

    Sends ``obs.text_summary`` through ``call_llm``. When the reply cannot be
    parsed, or cannot be turned into an ``Action``, it falls back to a random
    shot (placement_x in [-0.2, 0.2], angle in [-0.5, 0.5], force in
    [0.3, 0.8]) so the evaluation keeps running.
    """
    import random

    _llm_turn_counter["n"] += 1
    shot = _llm_turn_counter["n"]
    parsed = call_llm(obs.text_summary)
    if parsed:
        try:
            action = Action(**parsed)
        except (ValueError, TypeError) as e:
            # NOTE(review): assumes Action may validate its fields (a pydantic
            # ValidationError is a ValueError subclass). An out-of-range reply
            # then degrades to the random fallback instead of aborting the run.
            print(f" [shot {shot:>3}] INVALID ACTION ({type(e).__name__}) — random fallback",
                  flush=True)
        else:
            print(f" [shot {shot:>3}] "
                  f"px={action.placement_x:+.2f} "
                  f"angle={action.angle:+.2f} "
                  f"force={action.force:.2f} "
                  f"(score {obs.agent_score}-{obs.opponent_score}, "
                  f"coins left {obs.remaining_coins})", flush=True)
            return action
    else:
        print(f" [shot {shot:>3}] PARSE FAIL — random fallback", flush=True)
    return Action(
        placement_x=random.uniform(-0.2, 0.2),
        angle=random.uniform(-0.5, 0.5),
        force=random.uniform(0.3, 0.8),
    )
def random_policy(obs: Observation) -> Action:
    """Baseline agent: ignore the board and fire a uniformly random shot.

    placement_x is drawn from [-0.35, 0.35], angle from [-1.0, 1.0] radians
    and force from [0.2, 1.0].
    """
    import random as rng

    px = rng.uniform(-0.35, 0.35)
    aim = rng.uniform(-1.0, 1.0)
    power = rng.uniform(0.2, 1.0)
    return Action(placement_x=px, angle=aim, force=power)
def heuristic_policy(obs: Observation) -> Action:
    """Aim at the WHITE coin (or queen) closest to a pocket; never target black.

    Each non-pocketed, non-black coin is scored by ``coin.pocket_distance``.
    The queen's score is halved so it is preferred over a similarly placed
    white coin. The striker is slid halfway toward the target's x (clamped to
    [-0.35, 0.35]), and the angle is measured from that placement, so the shot
    line actually passes through the target. With no eligible coin it plays a
    centred, straight shot.
    """
    import math

    baseline_y = -0.5 + 0.08  # striker row; matches the prompt's y ≈ -0.42
    best_score = float("inf")
    best_placement = 0.0
    best_angle = 0.0
    for coin in obs.coins:
        # Skip pocketed coins and black coins (pocketing black is a due under ICF rules)
        if coin.pocketed or coin.color == "black":
            continue
        score = coin.pocket_distance
        if coin.color == "queen":
            score *= 0.5
        if score < best_score:
            placement = max(-0.35, min(0.35, coin.x * 0.5))
            best_score = score
            best_placement = placement
            # Aim from where the striker will actually sit, not from x = 0;
            # atan2(dx, dy) gives 0 for straight ahead (+y), as in the prompt.
            best_angle = math.atan2(coin.x - placement, coin.y - baseline_y)
    return Action(placement_x=best_placement, angle=best_angle, force=0.6)
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def build_task_suite(num_episodes: int, max_steps: int) -> list[Task]:
    """Return the flat task suite shared by every baseline.

    Produces ``num_episodes`` tasks named ``ep_0``, ``ep_1``, … with seeds 0,
    100, 200, … and a ``max_steps`` turn horizon, all in the "standard" tier.
    Because the seeds are fixed, every policy is evaluated on the *same*
    set of board states, which keeps the comparison fair.
    """
    seed_stride = 100  # distinct, fixed seed per task
    return [
        Task(
            task_id=f"ep_{idx}",
            seed=idx * seed_stride,
            max_turns=max_steps,
            tier="standard",
        )
        for idx in range(num_episodes)
    ]
def run_baseline(
    policy_fn,
    policy_name: str,
    tasks: list[Task],
) -> EvalReport:
    """Score one purple agent on the shared task suite and print its scorecard.

    Args:
        policy_fn: callable mapping an ``Observation`` to an ``Action``.
        policy_name: label used in the printed headers.
        tasks: the task suite; every baseline receives the same list.

    Returns:
        The ``EvalReport`` produced by ``GreenCarromAgent.evaluate``.
    """
    print(f"\n--- Evaluating: {policy_name} ({len(tasks)} tasks) ---")
    report = GreenCarromAgent(tasks=tasks).evaluate(policy_fn, verbose=True)
    stats = report.summary()
    print(f"\n=== {policy_name} ({stats['n_tasks']} tasks) ===")
    print(f" Avg reward : {stats['avg_reward']:+.3f}")
    print(f" Win rate : {stats['win_rate']:.2f}")
    print(f" Avg coins : {stats['avg_coins_potted']:.1f}")
    print(f" Avg dues : {stats['avg_dues']:.2f} (ICF violations)")
    print(f" Avg fouls : {stats['avg_fouls']:.2f}")
    print(f" ICF compliance : {stats['icf_compliance']:.3f}")
    print(f" Sim steps : {stats['total_sim_steps']}")
    print(f" Efficiency : {stats['efficiency_score']:.4f} coins/1k-steps")
    return report
def launch_web_server(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Start the FastAPI + Gradio server (foreground) and print the watch URL.

    Use this when you want to watch the LLM play on the board and screen-record
    it. Configure the endpoint/model/key inside the "Auto-play with LLM" panel
    in the browser, then click "Auto-play" to stream animated shots.
    The environment variables ``API_BASE_URL``, ``MODEL_NAME``, and an API key
    (``NEBIUS_API_KEY`` / ``OPENAI_API_KEY`` / ``HF_TOKEN``) are inherited as
    defaults in the web form.

    Runs ``uvicorn server.app:app`` in a child process and blocks until it
    exits; Ctrl+C stops it and prints "Server stopped.".
    """
    env = os.environ.copy()
    # Presumably read by server.app to mount the web UI — verify there.
    env["ENABLE_WEB_INTERFACE"] = "true"
    # Keep an existing PYTHONPATH; otherwise use the cwd so `server.app` is
    # importable when launched from the repository root.
    env.setdefault("PYTHONPATH", os.getcwd())
    url = f"http://localhost:{port}/web"
    print("=" * 70)
    print("Carrom server starting with web UI…")
    print(f" Open: {url}")
    print(" Inside the UI, configure model/endpoint, set number of shots,")
    print(" then click the \"🤖 Auto-play with LLM\" button to watch it play.")
    print(" Press Ctrl+C in this terminal to stop the server.")
    print("=" * 70)
    cmd = [
        sys.executable, "-m", "uvicorn",
        "server.app:app",
        "--host", host,
        "--port", str(port),
        "--ws-ping-interval", "60",
        "--ws-ping-timeout", "60",
    ]
    # Foreground: user ctrl-c's to stop
    try:
        subprocess.run(cmd, env=env, check=False)
    except KeyboardInterrupt:
        print("\nServer stopped.")
def main():
    """Run the headless baseline comparison, or start the web UI with ``--web``.

    Headless mode evaluates the Random and Heuristic purple agents, plus the
    LLM agent when an API key is set and more than two minutes of the
    ``TIMEOUT`` budget remain, on one shared task suite. It then prints a
    leaderboard.
    """
    parser = argparse.ArgumentParser(
        description="Carrom inference — headless baselines or live web view."
    )
    parser.add_argument(
        "--web", action="store_true",
        help="Start the env server + Gradio web UI for auto-play watching "
             "(screen-record friendly). No headless baselines run in this mode.",
    )
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    args = parser.parse_args()
    if args.web:
        launch_web_server(host=args.host, port=args.port)
        return

    print(f"API endpoint : {API_BASE_URL}")
    print(f"Model : {MODEL_NAME}")
    print(f"API key set : {'yes' if API_KEY else 'no'}")
    # Shared task suite — every policy sees the same boards (deterministic)
    tasks = build_task_suite(NUM_EPISODES, MAX_STEPS)
    print(f"Task suite : {len(tasks)} × {MAX_STEPS}-turn boards\n")
    start = time.time()
    reports: dict[str, EvalReport] = {}

    print("=" * 60 + "\nPURPLE AGENT: Random\n" + "=" * 60)
    reports["random"] = run_baseline(random_policy, "Random", tasks)
    print("\n" + "=" * 60 + "\nPURPLE AGENT: Heuristic (ICF-aware)\n" + "=" * 60)
    reports["heuristic"] = run_baseline(heuristic_policy, "Heuristic", tasks)

    if API_KEY:
        elapsed = time.time() - start
        # Keep a 120 s safety margin inside the overall TIMEOUT budget.
        if elapsed < TIMEOUT - 120:
            print(f"\n{'=' * 60}\nPURPLE AGENT: LLM ({MODEL_NAME})\n{'=' * 60}")
            reports["llm"] = run_baseline(llm_policy, f"LLM ({MODEL_NAME})", tasks)
        else:
            print(f"\nSkipping LLM baseline — {elapsed:.0f}s elapsed.")
    else:
        print("\nSkipping LLM baseline — no API key (set NEBIUS_API_KEY / OPENAI_API_KEY / HF_TOKEN).")

    # ── Leaderboard ────────────────────────────────────────────────
    print("\n" + "=" * 78 + "\nLEADERBOARD\n" + "=" * 78)
    print(f"{'Purple Agent':<25} {'Reward':>8} {'Win%':>6} {'Coins':>6} {'Dues':>6} {'ICF%':>6} {'Eff':>8}")
    print("-" * 78)
    for name, report in reports.items():
        s = report.summary()
        print(f"{name:<25} {s['avg_reward']:>+8.2f} {s['win_rate']*100:>5.0f}% "
              f"{s['avg_coins_potted']:>6.1f} {s['avg_dues']:>6.2f} "
              f"{s['icf_compliance']*100:>5.0f}% {s['efficiency_score']:>8.3f}")
    print(f"\nTotal runtime: {time.time() - start:.1f}s")


if __name__ == "__main__":
    main()