{ "cells": [ { "cell_type": "markdown", "id": "a1b2c3d4", "metadata": {}, "source": [ "# Training LLMs to Play Carrom (ICF Rules) with GRPO + Unsloth\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bp-high/carrom_rl_env/blob/main/examples/grpo_carrom_tutorial.ipynb)\n", "\n", "This notebook trains a language model to play Carrom using **GRPO** (Group Relative Policy\n", "Optimization) via TRL + Unsloth on the **OpenEnv Carrom** environment with full\n", "International Carrom Federation (ICF) rule compliance.\n", "\n", "## What you'll learn\n", "1. Explore the physics-based Carrom environment (Coulomb friction, ICF rules)\n", "2. Benchmark random and heuristic baselines\n", "3. Fine-tune **Gemma-3-4B** with GRPO (4-bit, Unsloth)\n", "4. Run inference via any OpenAI-compatible endpoint (Nebius, HF, vLLM)\n", "5. Scale up on **Modal** for under $25\n", "\n", "## Quick-start cost guide\n", "| Setting | GPU memory | Est. time | Est. cost |\n", "|---------|------------|-----------|-----------|\n", "| Colab T4 (free) | 16 GB | ~20 min | $0 |\n", "| Modal A10G — smoke test | 24 GB | ~15 min | ~$0.30 |\n", "| Modal A10G — blog run | 24 GB | ~2.5 hr | ~$2.75 |\n", "| Modal A100 — best results | 40 GB | ~1.5 hr | ~$5.60 |" ] }, { "cell_type": "markdown", "id": "b2c3d4e5", "metadata": {}, "source": [ "## 1. 
Installation" ] }, { "cell_type": "code", "execution_count": null, "id": "c3d4e5f6", "metadata": {}, "outputs": [], "source": [ "# Detect Colab vs local and install accordingly\n", "import subprocess, sys\n", "\n", "IN_COLAB = 'google.colab' in sys.modules\n", "\n", "if IN_COLAB:\n", "    # Unsloth Colab install (handles CUDA version automatically)\n", "    subprocess.run([sys.executable, '-m', 'pip', 'install',\n", "                    'unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git',\n", "                    '--quiet'], check=True)\n", "    subprocess.run([sys.executable, '-m', 'pip', 'install',\n", "                    '--no-deps', 'trl>=0.15', 'peft', 'accelerate', 'bitsandbytes',\n", "                    '--quiet'], check=True)\n", "    subprocess.run([sys.executable, '-m', 'pip', 'install',\n", "                    'openenv-core', 'pymunk>=6.5', 'numpy', 'pydantic>=2.0',\n", "                    '--quiet'], check=True)\n", "    # Install carrom env from the repo root\n", "    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-e', '..'], check=True)\n", "else:\n", "    # Local: assume the repo is already installed\n", "    print('Running locally — skipping pip installs.')\n", "\n", "print('Setup complete!')" ] }, { "cell_type": "code", "execution_count": null, "id": "d4e5f6a7", "metadata": {}, "outputs": [], "source": [ "import sys, os\n", "# Add repo root to path when running from examples/\n", "sys.path.insert(0, os.path.join(os.getcwd(), '..'))\n", "\n", "import math, random, json, re\n", "import torch\n", "\n", "from carrom_env.env import CarromEnv\n", "from carrom_env.models import Action, Observation\n", "from carrom_env.green_agent import GreenCarromAgent\n", "from examples.grpo_utils import (\n", "    format_chat_prompt, parse_response, carrom_reward_for_trl,\n", "    CARROM_SYSTEM_PROMPT,\n", ")\n", "\n", "print(f'PyTorch : {torch.__version__}')\n", "print(f'CUDA : {torch.cuda.is_available()} ({torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"none\"})')" ] }, { "cell_type": "markdown", "id": "e5f6a7b8", "metadata": {}, "source": [ "## 2. 
Explore the Environment\n", "\n", "Carrom is a physics-based board game with **Coulomb board friction** (constant\n", "deceleration, not viscous drag) and full ICF rule compliance:\n", "- **You play WHITE** coins (+1 pt each)\n", "- **Due rule**: pocketing a black coin returns it to the board, no score, turn ends\n", "- **Queen cover**: pocket queen → must cover with own coin next shot" ] }, { "cell_type": "code", "execution_count": null, "id": "f6a7b8c9", "metadata": {}, "outputs": [], "source": [ "env = CarromEnv(seed=42)\n", "obs = env.reset()\n", "print(obs.text_summary)" ] }, { "cell_type": "code", "execution_count": null, "id": "a7b8c9d0", "metadata": {}, "outputs": [], "source": [ "# Take a shot and inspect the result\n", "action = Action(placement_x=0.0, angle=0.0, force=0.7)\n", "obs, reward, terminated, truncated, info = env.step(action)\n", "\n", "print(f'Reward : {reward:.3f}')\n", "print(f'Own coins potted : {int(info[\"coin_potted\"])}')\n", "print(f'Due coins (opponent pocketed) : {int(info[\"due_coins\"])}')\n", "print(f'Foul : {bool(info[\"foul\"])}')\n", "print(f'Sim steps : {int(info[\"sim_steps\"])}')\n", "print()\n", "print(obs.text_summary)" ] }, { "cell_type": "markdown", "id": "b8c9d0e1", "metadata": {}, "source": [ "## 3. Baseline Benchmarks\n", "\n", "Establish baselines with random and ICF-aware heuristic policies." 
] }, { "cell_type": "code", "execution_count": null, "id": "c9d0e1f2", "metadata": {}, "outputs": [], "source": [ "def random_policy(obs: Observation) -> Action:\n", "    return Action(\n", "        placement_x=random.uniform(-0.35, 0.35),\n", "        angle=random.uniform(-1.0, 1.0),\n", "        force=random.uniform(0.2, 1.0),\n", "    )\n", "\n", "def heuristic_policy(obs: Observation) -> Action:\n", "    \"\"\"ICF-aware heuristic: aim at the WHITE coin (or queen) closest to a pocket.\"\"\"\n", "    best_angle, best_x, best_score = 0.0, 0.0, float('inf')\n", "    baseline_y = -0.5 + 0.08\n", "    for coin in obs.coins:\n", "        if coin.pocketed or coin.color == 'black':  # skip black (due rule)\n", "            continue\n", "        # Lower score = better target; the queen's distance is halved to prioritise it\n", "        score = coin.pocket_distance * (0.5 if coin.color == 'queen' else 1.0)\n", "        if score < best_score:\n", "            best_score = score\n", "            best_x = max(-0.35, min(0.35, coin.x * 0.5))\n", "            # Aim from the striker's actual placement, not from the board centre\n", "            dx, dy = coin.x - best_x, coin.y - baseline_y\n", "            best_angle = math.atan2(dx, dy)\n", "    return Action(placement_x=best_x, angle=best_angle, force=0.6)\n", "\n", "# Double-quoted outer f-string: nesting the same quote type is a SyntaxError before Python 3.12\n", "print(f\"{'Policy':<20} {'Avg Reward':>12} {'Avg Coins':>12} {'Avg Fouls':>10} {'Efficiency':>12}\")\n", "print('-' * 70)\n", "for name, policy in [('Random', random_policy), ('Heuristic (ICF)', heuristic_policy)]:\n", "    all_m = []\n", "    for ep in range(5):\n", "        agent = GreenCarromAgent(policy=policy, seed=ep * 42)\n", "        all_m.append(agent.run_episode(max_steps=30, seed=ep * 42))\n", "    avg_r = sum(sum(m.rewards) for m in all_m) / 5\n", "    avg_c = sum(m.coins_potted for m in all_m) / 5\n", "    avg_f = sum(m.total_fouls for m in all_m) / 5\n", "    avg_e = sum(m.efficiency_score for m in all_m) / 5\n", "    print(f'{name:<20} {avg_r:>12.3f} {avg_c:>12.1f} {avg_f:>10.1f} {avg_e:>12.4f}')" ] }, { "cell_type": "markdown", "id": "d0e1f2a3", "metadata": {}, "source": [ "## 4. Generate Training Data\n", "\n", "We create diverse board states by stepping random actions from various seeds." 
] }, { "cell_type": "code", "execution_count": null, "id": "e1f2a3b4", "metadata": {}, "outputs": [], "source": [ "from datasets import Dataset\n", "\n", "def make_dataset(n: int = 300, max_random_steps: int = 6) -> Dataset:\n", "    samples = []\n", "    for seed in range(n * 3):\n", "        if len(samples) >= n:\n", "            break\n", "        env = CarromEnv(seed=seed)\n", "        obs = env.reset()\n", "        for _ in range(random.randint(0, max_random_steps)):\n", "            a = Action(\n", "                placement_x=random.uniform(-0.3, 0.3),\n", "                angle=random.uniform(-0.9, 0.9),\n", "                force=random.uniform(0.25, 0.85),\n", "            )\n", "            obs, _, done, trunc, _ = env.step(a)\n", "            if done or trunc:\n", "                break\n", "        if obs.remaining_coins > 2:\n", "            samples.append({'prompt': format_chat_prompt(obs)})\n", "    return Dataset.from_list(samples)\n", "\n", "train_ds = make_dataset(300)\n", "print(f'Training samples : {len(train_ds)}')\n", "print('Example prompt (truncated):')\n", "print(train_ds[0]['prompt'][1]['content'][:400])" ] }, { "cell_type": "markdown", "id": "f2a3b4c5", "metadata": {}, "source": [ "## 5. 
Load Gemma-3-4B with Unsloth (4-bit)\n", "\n", "Unsloth's optimised 4-bit loading fits **Gemma-3-4B** in ~10 GB VRAM —\n", "well within a Colab T4 (16 GB) or Modal A10G (24 GB).\n", "\n", "**Alternative models** (change `MODEL_ID` below):\n", "| Model | VRAM | Quality |\n", "|-------|------|---------|\n", "| `unsloth/gemma-3-1b-it` | 6 GB | Fast demo |\n", "| `unsloth/gemma-3-4b-it` | 10 GB | **Recommended** |\n", "| `unsloth/Qwen2.5-1.5B-Instruct` | 5 GB | Very fast, good JSON |\n", "| `unsloth/Qwen2.5-3B-Instruct` | 8 GB | Good balance |\n", "| `unsloth/Qwen2.5-7B-Instruct` | 18 GB | Best (needs A10G) |" ] }, { "cell_type": "code", "execution_count": null, "id": "a3b4c5d6", "metadata": {}, "outputs": [], "source": [ "from unsloth import FastLanguageModel\n", "\n", "# Change this to try different models (see table above)\n", "MODEL_ID = 'unsloth/gemma-3-4b-it'\n", "\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", "    model_name=MODEL_ID,\n", "    max_seq_length=1024,\n", "    load_in_4bit=True,\n", "    dtype=None,  # auto-detect: bfloat16 on Ampere+ GPUs, float16 on a T4 (which has no bf16 support)\n", ")\n", "\n", "model = FastLanguageModel.get_peft_model(\n", "    model,\n", "    r=8,\n", "    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj',\n", "                    'gate_proj', 'up_proj', 'down_proj'],\n", "    lora_alpha=8,\n", "    lora_dropout=0,\n", "    bias='none',\n", "    use_gradient_checkpointing='unsloth',  # saves ~30% VRAM vs standard checkpointing\n", "    random_state=42,\n", ")\n", "\n", "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "total = sum(p.numel() for p in model.parameters())\n", "print(f'Trainable params : {trainable:,} ({100*trainable/total:.2f}% of {total:,})')" ] }, { "cell_type": "markdown", "id": "b4c5d6e7", "metadata": {}, "source": [ "## 6. Reward Function\n", "\n", "GRPO needs a reward for each completion. Ours combines:\n", "1. **Format** — valid JSON (+0.3, or -0.5 if unparseable)\n", "2. **Range** — reasonable placement / force / angle (+0.1 each)\n", "3. 
**Game** — actual environment reward from executing the action\n", "4. **ICF-aware** — extra +0.5 for scoring own coins; -0.3 for each due violation" ] }, { "cell_type": "code", "execution_count": null, "id": "c5d6e7f8", "metadata": {}, "outputs": [], "source": [ "# Test the reward function on a few example completions\n", "test_cases = [\n", " '{\"placement_x\": 0.0, \"angle\": 0.15, \"force\": 0.65}', # good white-targeting shot\n", " 'I think we should shoot left', # unparseable\n", " '{\"placement_x\": 0.2, \"angle\": -0.3, \"force\": 0.85}', # valid, aggressive\n", " '```json\\n{\"placement_x\": -0.1, \"angle\": 0.05, \"force\": 0.5}\\n```', # markdown-fenced\n", "]\n", "rewards = carrom_reward_for_trl(test_cases)\n", "for case, r in zip(test_cases, rewards):\n", " print(f' {case[:55]:55s} → reward = {r:+.3f}')" ] }, { "cell_type": "markdown", "id": "d6e7f8a9", "metadata": {}, "source": [ "## 7. GRPO Training\n", "\n", "We use TRL's `GRPOTrainer` — no value network needed, just the reward function.\n", "\n", "> **Colab T4 users**: `max_steps=100` runs in ~10 min and gives a feel for the training loop.\n", "> For a real result, use `max_steps=500+` or run [the Modal script](#8-scale-up-with-modal)." 
] }, { "cell_type": "code", "execution_count": null, "id": "e7f8a9b0", "metadata": {}, "outputs": [], "source": [ "from trl import GRPOConfig, GRPOTrainer\n", "\n", "# bfloat16 needs an Ampere-or-newer GPU (compute capability >= 8.0); a Colab T4 falls back to float16\n", "use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8\n", "\n", "training_args = GRPOConfig(\n", "    output_dir='./carrom-grpo-output',\n", "    num_train_epochs=1,  # ignored while max_steps is set\n", "    per_device_train_batch_size=1,\n", "    gradient_accumulation_steps=8,\n", "    num_generations=4,  # G in GRPO: completions sampled per prompt\n", "    max_prompt_length=512,\n", "    max_completion_length=256,\n", "    max_steps=100,  # increase to 500+ for real training\n", "    learning_rate=5e-6,\n", "    warmup_ratio=0.05,\n", "    lr_scheduler_type='cosine',\n", "    optim='adamw_8bit',\n", "    bf16=use_bf16,\n", "    fp16=torch.cuda.is_available() and not use_bf16,\n", "    logging_steps=5,\n", "    save_steps=50,\n", "    report_to='none',\n", ")\n", "\n", "trainer = GRPOTrainer(\n", "    model=model,\n", "    processing_class=tokenizer,\n", "    reward_funcs=carrom_reward_for_trl,\n", "    args=training_args,\n", "    train_dataset=train_ds,\n", ")\n", "\n", "print('Trainer ready. Starting training...')\n", "trainer.train()\n", "print('Training complete!')" ] }, { "cell_type": "code", "execution_count": null, "id": "f8a9b0c1", "metadata": {}, "outputs": [], "source": [ "# Save locally (optional: push to HF Hub)\n", "trainer.save_model('./carrom-grpo-output/final')\n", "tokenizer.save_pretrained('./carrom-grpo-output/final')\n", "print('Saved to ./carrom-grpo-output/final')\n", "\n", "# Uncomment to push to HF Hub:\n", "# from huggingface_hub import login\n", "# login()\n", "# trainer.model.push_to_hub('your-username/carrom-grpo-gemma')\n", "# tokenizer.push_to_hub('your-username/carrom-grpo-gemma')" ] }, { "cell_type": "markdown", "id": "a9b0c1d2", "metadata": {}, "source": [ "## 8. 
Scale Up with Modal\n", "\n", "For longer runs (500–2000 steps) or bigger models (7B), use the Modal training script.\n", "Everything runs on cloud GPUs; pay only for what you use.\n", "\n", "```bash\n", "# Install Modal\n", "pip install modal && modal setup\n", "\n", "# Create HF token secret in Modal\n", "modal secret create hf-token HF_TOKEN=hf_...\n", "\n", "# Quick smoke-test (200 steps, ~$0.30)\n", "modal run examples/train_modal.py --dry-run # preview cost\n", "modal run examples/train_modal.py # run it\n", "\n", "# Blog-quality training (2000 steps, ~$2.75)\n", "modal run examples/train_modal.py --steps 2000 --push --repo your-user/carrom-grpo-gemma\n", "\n", "# Larger model\n", "modal run examples/train_modal.py --model qwen2.5-7b --gpu a100 --steps 2000\n", "```\n", "\n", "All options under $25. See `examples/train_modal.py` for the full model menu." ] }, { "cell_type": "markdown", "id": "b0c1d2e3", "metadata": {}, "source": [ "## 9. Evaluate the Trained Model" ] }, { "cell_type": "code", "execution_count": null, "id": "c1d2e3f4", "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline\n", "\n", "# Enable native (fast) inference with Unsloth\n", "FastLanguageModel.for_inference(model)\n", "\n", "def make_trained_policy(model, tokenizer):\n", "    pipe = pipeline(\n", "        'text-generation',\n", "        model=model,\n", "        tokenizer=tokenizer,\n", "        max_new_tokens=128,\n", "        temperature=0.3,\n", "        do_sample=True,\n", "    )\n", "    def _policy(obs: Observation) -> Action:\n", "        messages = format_chat_prompt(obs)\n", "        result = pipe(messages)\n", "        # Chat-style input returns the full message list; the last entry is the model's reply\n", "        text = result[0]['generated_text'][-1]['content']\n", "        action = parse_response(text)\n", "        # Fall back to a neutral centre shot when the reply cannot be parsed\n", "        return action if action else Action(placement_x=0.0, angle=0.0, force=0.5)\n", "    return _policy\n", "\n", "trained_policy = make_trained_policy(model, tokenizer)\n", "\n", "# Compare all policies\n", "N_EVAL = 5\n", "# Double-quoted outer f-string: nesting the same quote type is a SyntaxError before Python 3.12\n", "print(f\"{'Policy':<22} {'Avg Reward':>12} {'Avg Coins':>10} {'Avg Fouls':>10} {'Efficiency':>12}\")\n", 
"print('-' * 70)\n", "for name, policy in [\n", " ('Random', random_policy),\n", " ('Heuristic (ICF)', heuristic_policy),\n", " ('GRPO-Trained', trained_policy),\n", "]:\n", " ms = [GreenCarromAgent(policy=policy, seed=ep*7).run_episode(max_steps=30, seed=ep*7)\n", " for ep in range(N_EVAL)]\n", " print(f'{name:<22}'\n", " f' {sum(sum(m.rewards) for m in ms)/N_EVAL:>12.3f}'\n", " f' {sum(m.coins_potted for m in ms)/N_EVAL:>10.1f}'\n", " f' {sum(m.total_fouls for m in ms)/N_EVAL:>10.1f}'\n", " f' {sum(m.efficiency_score for m in ms)/N_EVAL:>12.4f}')" ] }, { "cell_type": "markdown", "id": "d2e3f4a5", "metadata": {}, "source": [ "## 10. Inference via OpenAI-Compatible APIs\n", "\n", "Once you've pushed a model (or want to use a hosted one), you can plug any\n", "OpenAI-compatible endpoint into the Carrom environment.\n", "\n", "### Supported providers (same code, different env vars)\n", "\n", "| Provider | `API_BASE_URL` | Key env var |\n", "|----------|---------------|-------------|\n", "| HuggingFace Inference | `https://api-inference.huggingface.co/v1` | `HF_TOKEN` |\n", "| Nebius | `https://api.studio.nebius.com/v1` | `NEBIUS_API_KEY` |\n", "| OpenAI | `https://api.openai.com/v1` | `OPENAI_API_KEY` |\n", "| Local vLLM | `http://localhost:8000/v1` | — |\n", "| Local Ollama | `http://localhost:11434/v1` | — |" ] }, { "cell_type": "code", "execution_count": null, "id": "e3f4a5b6", "metadata": {}, "outputs": [], "source": [ "import os, requests\n", "\n", "# ── Configure your endpoint here ──────────────────────────────────────────\n", "API_BASE_URL = os.environ.get('API_BASE_URL', 'https://api-inference.huggingface.co/v1')\n", "MODEL_NAME = os.environ.get('MODEL_NAME', 'Qwen/Qwen3-4B')\n", "API_KEY = (\n", " os.environ.get('NEBIUS_API_KEY')\n", " or os.environ.get('OPENAI_API_KEY')\n", " or os.environ.get('HF_TOKEN', '')\n", ")\n", "# ──────────────────────────────────────────────────────────────────────────\n", "\n", "SYSTEM_PROMPT = \"\"\"You are a Carrom expert 
(ICF rules). You play WHITE coins.\n", "Pocket WHITE coins (+1 pt). Pocketing BLACK coins is a DUE — coin returns to board.\n", "Respond with ONLY JSON: {\"placement_x\": <-0.4..0.4>, \"angle\": , \"force\": <0..1>}\"\"\"\n", "\n", "def api_call(obs_text: str) -> Action | None:\n", "    \"\"\"Query the endpoint; return the parsed Action, or None on any request or parse failure.\"\"\"\n", "    try:\n", "        resp = requests.post(\n", "            f'{API_BASE_URL}/chat/completions',\n", "            headers={'Authorization': f'Bearer {API_KEY}', 'Content-Type': 'application/json'},\n", "            json={\n", "                'model': MODEL_NAME,\n", "                'messages': [\n", "                    {'role': 'system', 'content': SYSTEM_PROMPT},\n", "                    {'role': 'user', 'content': obs_text},\n", "                ],\n", "                'max_tokens': 128,\n", "                'temperature': 0.3,\n", "            },\n", "            timeout=30,\n", "        )\n", "        resp.raise_for_status()\n", "        text = resp.json()['choices'][0]['message']['content']\n", "        return parse_response(text)\n", "    except Exception as e:\n", "        print(f'API error: {e}')\n", "        return None\n", "\n", "def api_policy(obs: Observation) -> Action:\n", "    action = api_call(obs.text_summary)\n", "    if action:\n", "        return action\n", "    return Action(placement_x=0.0, angle=0.0, force=0.5)  # fallback\n", "\n", "if API_KEY:\n", "    print(f'Running 3 episodes with {MODEL_NAME} ...')\n", "    ms = [GreenCarromAgent(policy=api_policy, seed=ep*5).run_episode(max_steps=15, seed=ep*5)\n", "          for ep in range(3)]\n", "    avg_r = sum(sum(m.rewards) for m in ms) / 3\n", "    avg_c = sum(m.coins_potted for m in ms) / 3\n", "    print(f'Avg reward: {avg_r:.3f} | Avg coins: {avg_c:.1f}')\n", "else:\n", "    print('No API key set. Set NEBIUS_API_KEY / OPENAI_API_KEY / HF_TOKEN to test inference.')" ] }, { "cell_type": "markdown", "id": "f4a5b6c7", "metadata": {}, "source": [ "## Key Takeaways\n", "\n", "1. **The environment pairs real physics with ICF rules** — Coulomb friction (constant deceleration, not viscous drag),\n", "   due rule, queen cover, foul handling\n", "2. 
**Text actions bridge LLMs to physics** — model outputs JSON, env executes physics\n", "3. **GRPO + Unsloth is efficient** — 4-bit + LoRA trains 4B models on a T4 in <20 min\n", "4. **Due rule shapes learning** — model learns to avoid black coins; `-0.3` per due\n", " in the reward is a surprisingly strong signal\n", "5. **Modal scales cheaply** — 2000-step production run under $3\n", "\n", "### Next steps\n", "- Use **torchforge** for distributed multi-GPU rollout collection\n", "- Try **vision input** (rendered board state) for multimodal training\n", "- Tune `num_generations` (G) — higher G = better gradient signal but more memory\n", "- Push environment to the OpenEnv Hub: `openenv push --repo-id openenv/carrom`" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }