Gin Rummy DREAM Model

This is a DREAM (Deep Regret Minimization) agent trained to play 2-player Gin Rummy via self-play.

Model Description

This model was trained using a novel variant of Deep Counterfactual Regret Minimization (Deep CFR) with a dueling network architecture that separates advantage and value estimation. The agent learns to play Gin Rummy through massively parallel self-play in a fully vectorized JAX environment.

Architecture

The neural network uses a dueling architecture that factorizes Q-values into advantage and value components:

Input Observation [4887 dims]
            ↓
Shared Trunk: (1024, 1024, 512, 256)
      ↓                 ↓
Advantage Head      Value Head
[107 actions]       [scalar]

Key innovation: The advantage head learns relative action quality while the value head learns absolute state value, preventing Q-value drift that plagued earlier versions.

  • Total parameters: 6,738,796
  • Observation space: 4,887 dimensions (sequential event-based representation)
  • Action space: 107 discrete actions (draw, discard, knock variants)
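As a minimal sketch (in NumPy rather than the project's JAX code), the standard dueling combination from Wang et al. pins the advantages to zero mean, so the value head alone carries the absolute scale — the anti-drift property noted above:

```python
import numpy as np

def dueling_q(advantages, value):
    """Combine dueling heads: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
    Subtracting the mean advantage keeps the advantage head purely relative,
    leaving the absolute level to the value head."""
    return value + advantages - advantages.mean(axis=-1, keepdims=True)

q = dueling_q(np.array([1.0, 2.0, 3.0]), 0.5)
# q == [-0.5, 0.5, 1.5]: ordering comes from the advantages, level from the value
```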

Training Algorithm: DREAM

DREAM combines Deep CFR with several critical stability improvements:

1. Soft Regret Matching Policy

regret = softplus((advantages - max_adv) / temperature)
policy = regret / sum(regret)
  • Smooth policy gradients (unlike hard max in tabular CFR)
  • Temperature annealing: 2.0 → 0.1 over 400 iterations
  • Uniform warmup for first 20 iterations
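A self-contained NumPy sketch of the policy above (the repository's JAX `soft_regret_matching` is assumed to behave equivalently):

```python
import numpy as np

def soft_regret_matching_np(advantages, legal_mask, temperature=1.0):
    """Soft regret matching: softplus of max-shifted, temperature-scaled
    advantages, masked to legal actions and normalized to a distribution."""
    shifted = (advantages - advantages.max()) / temperature
    regret = np.logaddexp(0.0, shifted) * legal_mask  # numerically stable softplus
    total = regret.sum()
    if total <= 0.0:
        return legal_mask / legal_mask.sum()  # uniform fallback over legal actions
    return regret / total
```

Because softplus is strictly positive, every legal action keeps nonzero probability, which is what gives the smooth gradients noted above.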

2. Target Network Stabilization

  • Problem: Bootstrapping advantages from current network creates feedback loops → loss divergence
  • Solution: Polyak-averaged target network (τ=0.1) updated every 5 iterations
  • Untaken actions use stable target network Q-values; taken actions use actual returns
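The Polyak update itself is simple; here is a sketch over a flat dict of parameter arrays (real JAX code would apply this leaf-wise over a nested pytree with `jax.tree_util.tree_map`):

```python
def polyak_update(target_params, online_params, tau=0.1):
    """Soft target update: target <- (1 - tau) * target + tau * online, per leaf.
    With tau = 0.1 applied every 5 iterations, the target trails the online
    network slowly enough to break the bootstrapping feedback loop."""
    return {name: (1.0 - tau) * target_params[name] + tau * online_params[name]
            for name in target_params}
```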

3. Reward Shaping for Bootstrap

  • Challenge: Random play produces ~0% knock rate → no terminal rewards → no learning gradient
  • Solution: Per-step shaping reward proportional to deadwood reduction
    • Scale: 0.5 raw points per deadwood point reduced (after normalization: ~0.005)
    • Terminal knock rewards: ~25× larger (ensures terminal objectives dominate)

Reward Structure:

Gin bonus:      50 points → 0.50 normalized
Undercut bonus: 50 points → 0.50 normalized
Score delta:    winner's margin → normalized by 100
Shaping:        0.5 × deadwood_reduced → ~0.005 per step
Draw penalty:   -5.0 → -0.05 normalized
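Putting numbers to the table (the 0.5 scale and the /100 normalization are taken from the figures quoted in this card):

```python
def shaping_reward(prev_deadwood, new_deadwood, scale=0.5, norm=100.0):
    """Per-step shaping: scale * deadwood reduction, normalized by 100.
    A 1-point reduction yields 0.5 / 100 = 0.005, matching the table above;
    increasing deadwood is penalized symmetrically."""
    return scale * (prev_deadwood - new_deadwood) / norm
```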

4. Dual Loss Training

advantage_loss = MSE(pred_advantages, target_advantages)  # over legal actions
value_loss = MSE(pred_value, discounted_return)           # over valid states
total_loss = advantage_loss + 0.5 * value_loss
  • Advantage clipping: ±10.0 (prevents outlier-driven divergence)
  • Gradient clipping: global norm ≤ 1.0
  • L2 regularization: 1e-05
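A NumPy sketch of the dual loss for a single state; clipping is applied to the advantage targets here, which is an assumption (the card does not say whether predictions or targets are clipped):

```python
import numpy as np

def dual_loss(pred_adv, target_adv, legal_mask, pred_value, ret,
              value_weight=0.5, adv_clip=10.0):
    """total = masked advantage MSE + 0.5 * value MSE, with advantage
    targets clipped to +/-10 before the squared error."""
    clipped = np.clip(target_adv, -adv_clip, adv_clip)
    sq_err = ((pred_adv - clipped) ** 2) * legal_mask   # only legal actions count
    advantage_loss = sq_err.sum() / max(legal_mask.sum(), 1.0)
    value_loss = (pred_value - ret) ** 2
    return advantage_loss + value_weight * value_loss
```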

Training Configuration

This checkpoint was trained with:

Environments:      8,192
Batch size:        2,048
Iterations:        800
Learning rate:     0.002
Epochs per iter:   40
Discount (γ):      0.99

Hardware: Trained on NVIDIA H200 GPU
Throughput: ~800-1000 games/second during generation

Environment: Gin Rummy JAX

The environment is a fully vectorized, JIT-compiled Gin Rummy implementation in JAX:

  • ✅ Complete rules: knock, gin, undercut, layoffs, stock exhaustion draws
  • ✅ Pure JAX: No Python control flow in JIT paths (compatible with jax.vmap, jax.lax.scan)
  • ✅ Optimized deadwood computation: Bitmask dynamic programming over card subsets
  • ✅ Batched execution: Trains on 8,192 parallel games simultaneously
  • ✅ Sequential observations: Event log captures public history (picks, discards, knocks)

Key optimization: Per-player deadwood caching eliminates expensive recomputation during observation generation.
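The repository's vectorized bitmask DP is not reproduced here, but the idea can be sketched in plain Python for a single hand: enumerate meld subsets, then run a DP over covered-card bitmasks, either assigning the lowest uncovered card to a meld or counting it as deadwood. Card encoding and the function name are illustrative, not the project's API:

```python
from itertools import combinations

def min_deadwood(hand):
    """Minimum deadwood for `hand`, a list of (rank 1-13, suit 0-3) tuples.
    Melds are sets (3+ of a rank) or suited runs of 3-5 consecutive ranks;
    longer runs decompose into smaller ones, so 3-5 suffices."""
    n = len(hand)
    value = lambda i: min(hand[i][0], 10)  # face cards count 10 points

    def is_meld(idxs):
        ranks = sorted(hand[i][0] for i in idxs)
        suits = {hand[i][1] for i in idxs}
        if len(set(ranks)) == 1:
            return True                                    # set of a kind
        return len(suits) == 1 and all(
            b == a + 1 for a, b in zip(ranks, ranks[1:]))  # suited run

    melds = [sum(1 << i for i in idxs)
             for size in (3, 4, 5)
             for idxs in combinations(range(n), size) if is_meld(idxs)]

    INF = float("inf")
    dp = [INF] * (1 << n)  # dp[mask] = best deadwood with cards in `mask` resolved
    dp[0] = 0
    for mask in range(1 << n):
        if dp[mask] == INF:
            continue
        i = next((i for i in range(n) if not mask >> i & 1), None)
        if i is None:
            continue
        dp[mask | 1 << i] = min(dp[mask | 1 << i], dp[mask] + value(i))  # deadwood
        for m in melds:
            if m >> i & 1 and not m & mask:
                dp[mask | m] = min(dp[mask | m], dp[mask])               # in a meld
    return dp[(1 << n) - 1]
```

For a 10-11 card hand this is 2^11 masks, small enough that the real implementation can afford it per player per step once cached.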

Observation Space (4,887 dimensions)

The agent observes:

  1. Static card features (156 dims):

    • My hand (52 binary)
    • Known unplayable cards (52 binary)
    • Belief distribution over unknown cards (52 probabilities)
  2. Sequential event log (4,720 dims = 80 events ร— 59 features):

    • Actor identity (1)
    • Action type one-hot (5: draw stock, draw discard, discard, knock, pass)
    • Card one-hot (52)
    • Turn fraction (1)
  3. Scalar context (11 dims):

    • Stock remaining, phase one-hot, turn indicator, deadwood count, can_knock, can_gin

This representation is designed for transformer-based policy distillation (future work).
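The layout above composes exactly to the advertised total:

```python
# Sanity check of the observation layout described above
static  = 3 * 52        # hand + unplayable cards + belief distribution
events  = 80 * 59       # event log: 1 actor + 5 action one-hot + 52 card + 1 turn
scalars = 11            # stock, phase, turn, deadwood, can_knock, can_gin
assert (static, events) == (156, 4720)
assert static + events + scalars == 4887
```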

Action Space (107 discrete actions)

0:      Draw from stock
1:      Draw from discard pile
2:      Pass (initial draw phase only)
3-54:   Discard card i (where i = action - 3)
55-106: Knock and discard card i (where i = action - 55)

All actions are masked via legal_actions(state) to ensure validity.
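A small decoder for the table above (the function name is illustrative, not part of the package):

```python
def decode_action(a):
    """Map a discrete action index (0-106) to (kind, card_index) per the table."""
    if a == 0:
        return ("draw_stock", None)
    if a == 1:
        return ("draw_discard", None)
    if a == 2:
        return ("pass", None)
    if 3 <= a <= 54:
        return ("discard", a - 3)         # discard card i = action - 3
    if 55 <= a <= 106:
        return ("knock_discard", a - 55)  # knock, then discard card i = action - 55
    raise ValueError(f"action {a} out of range")
```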

Usage

Loading the Model

import pickle
import jax
from huggingface_hub import hf_hub_download

# Download checkpoint
checkpoint_path = hf_hub_download(
    repo_id="GoodStartLabs/gin-rummy-dream",
    filename="checkpoint.pkl"
)

# Load parameters and config
with open(checkpoint_path, "rb") as f:
    checkpoint = pickle.load(f)

params = checkpoint['params']
config = checkpoint['config']

Running Inference

from gin_rummy_jax.env import GinRummyEnv
from gin_rummy_jax.rl.dream import make_network, soft_regret_matching

# Initialize environment (split the key so it is not reused later)
rng_key = jax.random.PRNGKey(0)
rng_key, init_key = jax.random.split(rng_key)
state = GinRummyEnv.init(init_key)

# Create network
network_fn = make_network(config)

# Get observation for current player
obs = GinRummyEnv.observe(state, state.current_player)
legal_mask = GinRummyEnv.legal_actions(state)

# Predict advantages and compute policy
advantages, value = network_fn(params, obs)
policy = soft_regret_matching(advantages, legal_mask, temperature=0.1)

# Sample action with a fresh subkey (JAX keys should never be reused)
rng_key, action_key = jax.random.split(rng_key)
action = jax.random.choice(action_key, len(policy), p=policy)

# Step environment
state, rewards, done = GinRummyEnv.step(state, action)

Evaluating Against Random

# See examples/evaluate.py for complete evaluation scripts
# Expected win rate vs random: >95% after convergence

Training Progress

Iteration 800 checkpoint exhibits:

  • ✅ Meld building: Agent reduces deadwood systematically (vs random play)
  • ✅ Knock timing: Learns to knock when advantageous (not just when legal)
  • ✅ Discard strategy: Avoids discarding meld-completing cards to opponent
  • 🔄 Gin optimization: Still learning when to hold for gin vs knock early (in progress)

Expected metrics at convergence (iteration ~500):

  • Knock rate: 70-80% (vs 0% for random, 2% at iteration 1)
  • Win rate vs random: >95%
  • Average game length: 15-25 steps (vs 80-step timeout for random)

Technical Innovations

This work demonstrates several novel contributions:

  1. First fully-vectorized Gin Rummy environment in JAX (enables massive parallelism)
  2. Dueling architecture for Deep CFR (separates advantage/value to prevent Q-drift)
  3. Reward shaping for sparse-reward games (bootstraps learning from random init)
  4. Target network stabilization for CFR (adapts DQN techniques to regret-based learning)

See PHASE5.md for detailed design rationale.

Known Limitations

  • Memory requirements: Full checkpoint is ~51MB (network params + target network)
  • Inference speed: JIT compilation adds ~2-5 second warmup on first call
  • Exploration: Converged policy may not cover all edge cases (rare card distributions)

Citation

If you use this model or environment in your research, please cite:

@software{gin_rummy_dream,
  title = {Gin Rummy DREAM: Deep Regret Minimization for Card Games},
  author = {Hopkins, Jack},
  year = {2026},
  url = {https://github.com/learning-environments/gin-rummy},
  note = {Dueling architecture with reward shaping and target network stabilization}
}

Links

  • 📂 GitHub Repository: learning-environments/gin-rummy
  • 📊 Training Logs: Weights & Biases (see training runs)
  • 📖 Documentation: See repository README and PHASE*.md files for implementation details
  • 🎮 Environment: pip install git+https://github.com/learning-environments/gin-rummy.git

License

MIT License - Free for academic and commercial use.

Acknowledgments

This work builds on:

  • Deep CFR (Brown et al., 2019): Regret minimization with function approximation
  • Dueling DQN (Wang et al., 2016): Advantage/value decomposition
  • JAX ecosystem: Enables efficient GPU-accelerated RL at scale

Model card last updated at the iteration-800 checkpoint.
