File size: 13,548 Bytes
13b4881
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
"""Baseline inference script for the Carrom OpenEnv environment.

Runs an LLM agent against the Carrom environment (ICF rules) and reports
game performance and Green Agent efficiency metrics.

Supports any OpenAI-compatible API endpoint.  Configure via environment
variables so the same script works with HuggingFace Inference, Nebius,
vLLM, OpenAI, or any other compatible provider:

    # HuggingFace Inference Router
    export API_BASE_URL="https://api-inference.huggingface.co/v1"
    export MODEL_NAME="Qwen/Qwen3-4B"
    export HF_TOKEN="hf_..."

    # Nebius
    export API_BASE_URL="https://api.studio.nebius.com/v1"
    export MODEL_NAME="nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
    export NEBIUS_API_KEY="ey..."

    # OpenAI
    export API_BASE_URL="https://api.openai.com/v1"
    export MODEL_NAME="gpt-4o-mini"
    export OPENAI_API_KEY="sk-..."

    # Local vLLM
    export API_BASE_URL="http://localhost:8000/v1"
    export MODEL_NAME="Qwen/Qwen2.5-7B-Instruct"
    # No key needed for local

Then run:
    python inference.py
"""

from __future__ import annotations

import argparse
import json
import os
import re
import subprocess
import sys
import time
from typing import Optional

import requests

from carrom_env.env import CarromEnv
from carrom_env.models import Action, Observation
from carrom_env.green_agent import GreenCarromAgent, EvalReport, Task


# ---------------------------------------------------------------------------
# Configuration — all overridable via environment variables
# ---------------------------------------------------------------------------

# Endpoint and model used for the LLM baseline.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen3-4B")

# API key: the first non-empty value wins, checked in priority order
# NEBIUS_API_KEY -> OPENAI_API_KEY -> HF_TOKEN; empty string when none is set.
API_KEY: str = next(
    (
        candidate
        for candidate in map(
            os.environ.get, ("NEBIUS_API_KEY", "OPENAI_API_KEY", "HF_TOKEN")
        )
        if candidate
    ),
    "",
)

# Evaluation budget: turns per task, number of tasks, and the overall
# time limit (configured in minutes, stored in seconds).
MAX_STEPS = int(os.environ.get("MAX_STEPS", "30"))
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "3"))
TIMEOUT = 60 * int(os.environ.get("TIMEOUT_MINUTES", "20"))

# ---------------------------------------------------------------------------
# System prompt (ICF rules)
# ---------------------------------------------------------------------------

# System prompt sent as the "system" message on every LLM call.  It describes
# the board geometry, the ICF scoring rules, and the exact JSON action schema
# the reply must follow.  Non-ASCII symbols (×, ±, ≈, →, —) are written as
# real Unicode characters so the model does not receive mis-encoded text.
SYSTEM_PROMPT = """\
You are an expert Carrom player following ICF (International Carrom Federation) rules.

Board layout
------------
- 1.0 × 1.0 square centred at (0, 0).  Pockets at the four corners (±0.5, ±0.5).
- Your striker starts on the BOTTOM baseline (y ≈ -0.42).
- You play WHITE coins.  The opponent plays BLACK coins.

Scoring & rules
---------------
- Pocket a WHITE coin  → +1 point, take another turn
- Pocket the QUEEN     → +3 points; you must then pocket a white coin on the
                         same shot OR your next turn to "cover" it
- Pocket a BLACK coin  → DUE: coin returns to board centre, your turn ENDS
- Pocket the STRIKER   → FOUL: one of your pocketed coins returns to board

Action format
-------------
Respond with ONLY a valid JSON object (no markdown, no explanation):
{
  "placement_x": <float, -0.4 to 0.4, 0 = centre>,
  "angle":       <float, radians, 0 = straight ahead toward +y>,
  "force":       <float, 0.0 to 1.0>
}

Strategy tips
-------------
- Prioritise white coins close to pockets for easy points
- Avoid shooting black coins — even if they are near a pocket
- Queen near centre: aim to pocket it AND a white coin in the same shot
- Adjust placement_x to get a direct line on your target
"""


# ---------------------------------------------------------------------------
# LLM interaction
# ---------------------------------------------------------------------------

def call_llm(observation_text: str) -> Optional[dict]:
    """Ask the configured chat-completions endpoint for the next shot.

    Posts SYSTEM_PROMPT plus *observation_text* to
    ``{API_BASE_URL}/chat/completions`` and passes the reply text to
    ``_parse_json_action``.

    Returns:
        The dict produced by ``_parse_json_action``, or ``None`` when the
        request, the response decoding, or the parsing step raises (the
        error is printed rather than propagated).
    """
    # Generous default budget to accommodate reasoning models (e.g.
    # MiniMax-M2.5, Nemotron) that emit a long chain of thought before the
    # final JSON answer.  Read outside the try block, so a non-integer
    # MAX_TOKENS value raises to the caller.
    token_budget = int(os.environ.get("MAX_TOKENS", "2048"))
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": observation_text},
    ]
    request_body = {
        "model": MODEL_NAME,
        "messages": conversation,
        "max_tokens": token_budget,
        "temperature": 0.3,
    }
    request_headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(
            f"{API_BASE_URL}/chat/completions",
            headers=request_headers,
            json=request_body,
            timeout=120,
        )
        response.raise_for_status()
        message = response.json()["choices"][0]["message"]
        # Reasoning models put the final answer in `content` and the trace in
        # `reasoning_content`; use the trace only when `content` is empty
        # (common when the JSON is inline inside the reasoning).
        reply = message.get("content")
        if not reply:
            reply = message.get("reasoning_content")
        if not reply:
            reply = ""
        return _parse_json_action(reply)
    except Exception as err:
        print(f"  [LLM error] {err}")
        return None


def _parse_json_action(text: str) -> Optional[dict]:
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$",          "", text)
    # Strip <think>…</think> blocks (some reasoning models)
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    match = re.search(r"\{[^}]+\}", text)
    if match:
        try:
            data = json.loads(match.group())
            return {
                "placement_x": float(data.get("placement_x", 0.0)),
                "angle":       float(data.get("angle",       0.0)),
                "force":       float(data.get("force",       0.5)),
            }
        except (json.JSONDecodeError, ValueError, TypeError):
            pass
    return None


# ---------------------------------------------------------------------------
# Policies
# ---------------------------------------------------------------------------

# Running count of LLM shots across all tasks, used only to label log lines.
_llm_turn_counter = {"n": 0}


def llm_policy(obs: Observation) -> Action:
    """Choose a shot by querying the LLM, with a random fallback.

    Sends ``obs.text_summary`` to ``call_llm``.  When a usable action comes
    back it is logged and returned; otherwise a mild random shot is returned
    so the episode can continue.

    Args:
        obs: Current observation; ``text_summary``, ``agent_score``,
            ``opponent_score`` and ``remaining_coins`` are read.

    Returns:
        The LLM's action, or a random fallback action.
    """
    import random

    _llm_turn_counter["n"] += 1
    shot_no = _llm_turn_counter["n"]

    parsed = call_llm(obs.text_summary)
    if parsed is not None:
        try:
            action = Action(**parsed)
        except (TypeError, ValueError):
            # NOTE(review): assumes Action signals rejected field values with
            # TypeError/ValueError (e.g. if it validates ranges) — confirm
            # against carrom_env.models.  LLM output is untrusted, so a
            # rejected action is handled like a parse failure.
            print(f"  [shot {shot_no:>3}] INVALID ACTION → random fallback", flush=True)
        else:
            print(f"  [shot {shot_no:>3}] "
                  f"px={action.placement_x:+.2f} "
                  f"angle={action.angle:+.2f} "
                  f"force={action.force:.2f}   "
                  f"(score {obs.agent_score}-{obs.opponent_score}, "
                  f"coins left {obs.remaining_coins})", flush=True)
            return action
    else:
        print(f"  [shot {shot_no:>3}] PARSE FAIL → random fallback", flush=True)

    return Action(
        placement_x=random.uniform(-0.2, 0.2),
        angle=random.uniform(-0.5, 0.5),
        force=random.uniform(0.3, 0.8),
    )


def random_policy(obs: Observation) -> Action:
    """Return a uniformly random shot; the observation is ignored.

    Draws placement_x from [-0.35, 0.35], angle from [-1.0, 1.0] and force
    from [0.2, 1.0], in that order.
    """
    from random import uniform

    placement = uniform(-0.35, 0.35)
    direction = uniform(-1.0, 1.0)
    power = uniform(0.2, 1.0)
    return Action(placement_x=placement, angle=direction, force=power)


def heuristic_policy(obs: Observation) -> Action:
    """Aim at the white coin or queen that lies closest to a pocket.

    Pocketed coins and black coins are skipped (pocketing a black coin is a
    due under ICF rules).  Candidates are ranked by ``pocket_distance``, with
    the queen's distance halved so it is preferred.  The striker is placed at
    half the target's x offset (clamped to [-0.35, 0.35]) and the angle is
    measured from that striker position, so the shot line actually passes
    through the target.

    Args:
        obs: Current observation; only ``obs.coins`` is read.

    Returns:
        An Action with force 0.6; a straight centre shot (placement 0,
        angle 0) when no eligible coin remains.
    """
    import math

    baseline_y = -0.5 + 0.08  # y coordinate the striker is assumed to shoot from
    best_angle = 0.0
    best_placement = 0.0
    best_score = float("inf")

    for coin in obs.coins:
        if coin.pocketed or coin.color == "black":
            continue

        score = coin.pocket_distance
        if coin.color == "queen":
            score *= 0.5
        if not (score < best_score):
            continue

        placement = max(-0.35, min(0.35, coin.x * 0.5))
        best_score = score
        best_placement = placement
        # Angle 0 points toward +y, hence atan2(dx, dy).  dx is taken from
        # the chosen striker position rather than from x = 0; aiming from
        # the board centre while shooting from `placement` would miss.
        best_angle = math.atan2(coin.x - placement, coin.y - baseline_y)

    return Action(placement_x=best_placement, angle=best_angle, force=0.6)


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------

def build_task_suite(num_episodes: int, max_steps: int) -> list[Task]:
    """Build the flat task list shared by all baseline policies.

    Produces ``num_episodes`` tasks named ``ep_0``, ``ep_1``, ... in the
    "standard" tier.  Task *i* uses seed ``i * 100`` and a horizon of
    ``max_steps`` turns, so every policy is evaluated on the same seeds.
    """
    suite: list[Task] = []
    for index in range(num_episodes):
        suite.append(
            Task(
                task_id=f"ep_{index}",
                seed=index * 100,
                max_turns=max_steps,
                tier="standard",
            )
        )
    return suite


def run_baseline(
    policy_fn,
    policy_name: str,
    tasks: list[Task],
) -> EvalReport:
    """Evaluate one policy on the shared task suite and print its scorecard.

    Args:
        policy_fn: The purple agent, passed to ``GreenCarromAgent.evaluate``.
        policy_name: Label used in the printed headers.
        tasks: Task suite handed to the green-agent evaluator.

    Returns:
        The EvalReport produced by the evaluator.
    """
    print(f"\n--- Evaluating: {policy_name} ({len(tasks)} tasks) ---")
    green_agent = GreenCarromAgent(tasks=tasks)
    report = green_agent.evaluate(policy_fn, verbose=True)

    summary = report.summary()
    print(f"\n=== {policy_name} ({summary['n_tasks']} tasks) ===")
    print(f"  Avg reward     : {summary['avg_reward']:+.3f}")
    print(f"  Win rate       : {summary['win_rate']:.2f}")
    print(f"  Avg coins      : {summary['avg_coins_potted']:.1f}")
    print(f"  Avg dues       : {summary['avg_dues']:.2f}   (ICF violations)")
    print(f"  Avg fouls      : {summary['avg_fouls']:.2f}")
    print(f"  ICF compliance : {summary['icf_compliance']:.3f}")
    print(f"  Sim steps      : {summary['total_sim_steps']}")
    print(f"  Efficiency     : {summary['efficiency_score']:.4f} coins/1k-steps")
    return report


def launch_web_server(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Run the uvicorn server for ``server.app:app`` in the foreground.

    Use this to watch the LLM play on the board (e.g. for screen recording).
    Configure the endpoint/model/key inside the "Auto-play with LLM" panel in
    the browser, then start auto-play to stream animated shots.

    The child process receives a copy of the current environment with
    ``ENABLE_WEB_INTERFACE=true`` and, if not already set, ``PYTHONPATH``
    pointing at the current working directory.  ``API_BASE_URL``,
    ``MODEL_NAME`` and the API key variables are therefore passed through to
    the server.

    Args:
        host: Interface uvicorn binds to.
        port: Port uvicorn listens on; also used in the printed URL.
    """
    env = os.environ.copy()
    env["ENABLE_WEB_INTERFACE"] = "true"
    env.setdefault("PYTHONPATH", os.getcwd())

    # The printed URL always uses localhost, independent of the bind host.
    url = f"http://localhost:{port}/web"
    print("=" * 70)
    print("Carrom server starting with web UI…")
    print(f"  Open:   {url}")
    print("  Inside the UI, configure model/endpoint, set number of shots,")
    print('  then click the "🤖 Auto-play with LLM" button to watch it play.')
    print("  Press Ctrl+C in this terminal to stop the server.")
    print("=" * 70)

    cmd = [
        sys.executable, "-m", "uvicorn",
        "server.app:app",
        "--host", host,
        "--port", str(port),
        "--ws-ping-interval", "60",
        "--ws-ping-timeout", "60",
    ]
    # Foreground run: the user stops the server with Ctrl+C.  check=False
    # because a non-zero exit on shutdown is not treated as an error here.
    try:
        subprocess.run(cmd, env=env, check=False)
    except KeyboardInterrupt:
        print("\nServer stopped.")


def main():
    """Command-line entry point.

    With ``--web`` the web server is launched and nothing else runs.
    Otherwise the Random and Heuristic policies are evaluated on a shared
    task suite, followed by the LLM policy when an API key is configured and
    at least two minutes of the time budget remain; a leaderboard is printed
    at the end.
    """
    parser = argparse.ArgumentParser(
        description="Carrom inference — headless baselines or live web view."
    )
    parser.add_argument(
        "--web", action="store_true",
        help="Start the env server + Gradio web UI for auto-play watching "
             "(screen-record friendly). No headless baselines run in this mode.",
    )
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    args = parser.parse_args()

    if args.web:
        launch_web_server(host=args.host, port=args.port)
        return

    print(f"API endpoint : {API_BASE_URL}")
    print(f"Model        : {MODEL_NAME}")
    print(f"API key set  : {'yes' if API_KEY else 'no'}")

    # Shared task suite — every policy sees the same boards (deterministic)
    tasks = build_task_suite(NUM_EPISODES, MAX_STEPS)
    print(f"Task suite   : {len(tasks)} × {MAX_STEPS}-turn boards\n")

    start = time.time()
    reports: dict[str, EvalReport] = {}

    print("=" * 60 + "\nPURPLE AGENT: Random\n" + "=" * 60)
    reports["random"] = run_baseline(random_policy, "Random", tasks)

    print("\n" + "=" * 60 + "\nPURPLE AGENT: Heuristic (ICF-aware)\n" + "=" * 60)
    reports["heuristic"] = run_baseline(heuristic_policy, "Heuristic", tasks)

    # NOTE(review): the LLM baseline is skipped whenever no API key is set,
    # which also skips key-less local endpoints — confirm this is intended.
    if API_KEY:
        elapsed = time.time() - start
        # Keep a two-minute margin before the overall time limit.
        if elapsed < TIMEOUT - 120:
            print(f"\n{'=' * 60}\nPURPLE AGENT: LLM ({MODEL_NAME})\n{'=' * 60}")
            reports["llm"] = run_baseline(llm_policy, f"LLM ({MODEL_NAME})", tasks)
        else:
            print(f"\nSkipping LLM baseline — {elapsed:.0f}s elapsed.")
    else:
        print("\nSkipping LLM baseline — no API key (set NEBIUS_API_KEY / OPENAI_API_KEY / HF_TOKEN).")

    # ── Leaderboard ──────────────────────────────────────────────────
    print("\n" + "=" * 78 + "\nLEADERBOARD\n" + "=" * 78)
    print(f"{'Purple Agent':<25} {'Reward':>8} {'Win%':>6} {'Coins':>6} {'Dues':>6} {'ICF%':>6} {'Eff':>8}")
    print("-" * 78)
    for name, report in reports.items():
        s = report.summary()
        print(f"{name:<25} {s['avg_reward']:>+8.2f} {s['win_rate']*100:>5.0f}% "
              f"{s['avg_coins_potted']:>6.1f} {s['avg_dues']:>6.2f} "
              f"{s['icf_compliance']*100:>5.0f}% {s['efficiency_score']:>8.3f}")

    print(f"\nTotal runtime: {time.time() - start:.1f}s")


if __name__ == "__main__":
    main()