Gin Rummy DREAM Model
This is a DREAM (Deep Regret Minimization) agent trained to play 2-player Gin Rummy via self-play.
Model Description
This model was trained using a novel variant of Deep Counterfactual Regret Minimization (Deep CFR) with a dueling network architecture that separates advantage and value estimation. The agent learns to play Gin Rummy through massively parallel self-play in a fully vectorized JAX environment.
Architecture
The neural network uses a dueling architecture that factorizes Q-values into advantage and value components:
```
Input Observation [4887 dims]
              │
Shared Trunk: (1024, 1024, 512, 256)
        ┌──────┴──────┐
   Advantage        Value
      Head           Head
 [107 actions]     [scalar]
```
Key innovation: The advantage head learns relative action quality while the value head learns absolute state value, preventing Q-value drift that plagued earlier versions.
- Total parameters: 6,738,796
- Observation space: 4,887 dimensions (sequential event-based representation)
- Action space: 107 discrete actions (draw, discard, knock variants)
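As a sanity check, the dueling architecture can be sketched in plain numpy; the layer sizes below reproduce the stated 6,738,796 parameter count exactly (the ReLU activations and He initialization are assumptions, not confirmed by the source):

```python
import numpy as np

OBS_DIM, TRUNK, N_ACTIONS = 4887, (1024, 1024, 512, 256), 107

def init_params(rng):
    """He-initialized trunk plus small-scale heads (init scheme assumed)."""
    sizes = (OBS_DIM,) + TRUNK
    trunk = [(rng.standard_normal((i, o)) * np.sqrt(2.0 / i), np.zeros(o))
             for i, o in zip(sizes[:-1], sizes[1:])]
    adv_head = (rng.standard_normal((TRUNK[-1], N_ACTIONS)) * 0.01, np.zeros(N_ACTIONS))
    val_head = (rng.standard_normal((TRUNK[-1], 1)) * 0.01, np.zeros(1))
    return trunk, adv_head, val_head

def forward(params, obs):
    trunk, (wa, ba), (wv, bv) = params
    x = obs
    for w, b in trunk:                  # shared trunk with ReLU activations
        x = np.maximum(x @ w + b, 0.0)
    advantages = x @ wa + ba            # relative action quality, shape (107,)
    value = (x @ wv + bv).item()        # absolute state value, scalar
    return advantages, value

params = init_params(np.random.default_rng(0))
trunk, adv_head, val_head = params
n_params = sum(w.size + b.size for w, b in trunk + [adv_head, val_head])
print(n_params)  # 6738796 -- matches the reported total
```

The two heads share the trunk but are trained against different targets, which is what decouples relative action preferences from absolute state value.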
Training Algorithm: DREAM
DREAM combines Deep CFR with several critical stability improvements:
1. Soft Regret Matching Policy
```
regret = softplus((advantages - max_adv) / temperature)
policy = regret / sum(regret)
```
- Smooth policy gradients (unlike hard max in tabular CFR)
- Temperature annealing: 2.0 → 0.1 over 400 iterations
- Uniform warmup for first 20 iterations
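The two-line rule above can be sketched in numpy; masking of illegal actions is added here as an assumption about the implementation:

```python
import numpy as np

def softplus(x):
    # numerically stable log(1 + e^x)
    return np.logaddexp(x, 0.0)

def soft_regret_matching(advantages, legal_mask, temperature):
    """Soft regret matching over legal actions only (masking assumed)."""
    adv = np.where(legal_mask, advantages, -np.inf)
    regret = softplus((adv - np.max(adv)) / temperature)
    regret = np.where(legal_mask, regret, 0.0)   # illegal actions get zero mass
    return regret / regret.sum()
```

As the temperature approaches 0 the distribution concentrates on the highest-advantage legal action, so the 2.0 → 0.1 anneal moves the policy from near-uniform exploration toward near-greedy play.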
2. Target Network Stabilization
- Problem: Bootstrapping advantages from the current network creates feedback loops → loss divergence
- Solution: Polyak-averaged target network (τ = 0.1) updated every 5 iterations
- Untaken actions use stable target network Q-values; taken actions use actual returns
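The Polyak update is a simple exponential moving average toward the online network; a minimal sketch, assuming a flat dict-of-parameters layout for illustration:

```python
def polyak_update(target_params, online_params, tau=0.1):
    """Move target parameters a fraction tau toward the online parameters
    (tau = 0.1, applied every 5 iterations per the text)."""
    return {name: (1.0 - tau) * target_params[name] + tau * online_params[name]
            for name in target_params}
```

Because the target lags the online network, the Q-values used for untaken actions change slowly, which breaks the bootstrap feedback loop described above.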
3. Reward Shaping for Bootstrap
- Challenge: Random play produces ~0% knock rate → no terminal rewards → no learning gradient
- Solution: Per-step shaping reward proportional to deadwood reduction
- Scale: 0.5 raw points per deadwood point reduced (after normalization: ~0.005)
- Terminal knock rewards: ~25× larger (ensures terminal objectives dominate)
Reward Structure:
```
Gin bonus:      50 points → 0.50 normalized
Undercut bonus: 50 points → 0.50 normalized
Score delta:    winner's margin → normalized by 100
Shaping:        0.5 × deadwood_reduced → ~0.005 per step
Draw penalty:   -5.0 → -0.05 normalized
```
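The listed values are consistent with dividing raw points by 100; a quick arithmetic check (the divisor of 100 is inferred from the score-delta line, not stated explicitly for every entry):

```python
NORM = 100.0  # inferred normalization divisor

def normalize(raw_points):
    return raw_points / NORM

assert normalize(50.0) == 0.50       # gin / undercut bonuses
assert normalize(-5.0) == -0.05      # draw penalty
assert normalize(0.5 * 1) == 0.005   # shaping per deadwood point reduced
```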
4. Dual Loss Training
```python
advantage_loss = MSE(pred_advantages, target_advantages)  # over legal actions
value_loss     = MSE(pred_value, discounted_return)       # over valid states
total_loss     = advantage_loss + 0.5 * value_loss
```
- Advantage clipping: ±10.0 (prevents outlier-driven divergence)
- Gradient clipping: global norm ≤ 1.0
- L2 regularization: 1e-05
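A numpy sketch of the combined loss for a single state (whether the ±10 clipping applies to targets or to predictions is an assumption here):

```python
import numpy as np

ADV_CLIP = 10.0  # advantage clipping bound from the text

def dream_loss(pred_adv, target_adv, legal_mask, pred_value, discounted_return):
    """Dual loss for one state: advantage MSE over legal actions plus
    0.5x value MSE, with advantage targets clipped to +/- ADV_CLIP."""
    target_adv = np.clip(target_adv, -ADV_CLIP, ADV_CLIP)
    err = np.where(legal_mask, pred_adv - target_adv, 0.0)  # zero out illegal actions
    advantage_loss = (err ** 2).sum() / legal_mask.sum()
    value_loss = (pred_value - discounted_return) ** 2
    return advantage_loss + 0.5 * value_loss
```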
Training Configuration
This checkpoint was trained with:
- Environments: 8,192
- Batch size: 2,048
- Iterations: 800
- Learning rate: 0.002
- Epochs per iteration: 40
- Discount (γ): 0.99
- Hardware: NVIDIA H200 GPU
- Throughput: ~800-1,000 games/second during data generation
Environment: Gin Rummy JAX
The environment is a fully vectorized, JIT-compiled Gin Rummy implementation in JAX:
- ✅ Complete rules: knock, gin, undercut, layoffs, stock exhaustion draws
- ✅ Pure JAX: no Python control flow in JIT paths (compatible with `jax.vmap`, `jax.lax.scan`)
- ✅ Optimized deadwood computation: bitmask dynamic programming over card subsets
- ✅ Batched execution: trains on 8,192 parallel games simultaneously
- ✅ Sequential observations: event log captures public history (picks, discards, knocks)
Key optimization: Per-player deadwood caching eliminates expensive recomputation during observation generation.
Observation Space (4,887 dimensions)
The agent observes:
Static card features (156 dims):
- My hand (52 binary)
- Known unplayable cards (52 binary)
- Belief distribution over unknown cards (52 probabilities)
Sequential event log (4,720 dims = 80 events ร 59 features):
- Actor identity (1)
- Action type one-hot (5: draw stock, draw discard, discard, knock, pass)
- Card one-hot (52)
- Turn fraction (1)
Scalar context (11 dims):
- Stock remaining, phase one-hot, turn indicator, deadwood count, can_knock, can_gin
This representation is designed for transformer-based policy distillation (future work).
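The component sizes above add up exactly to the 4,887-dimensional observation:

```python
STATIC    = 52 + 52 + 52       # hand + known-unplayable + belief distribution = 156
EVENT     = 1 + 5 + 52 + 1     # actor + action one-hot + card one-hot + turn fraction = 59
EVENT_LOG = 80 * EVENT         # 80 events x 59 features = 4720
SCALAR    = 11                 # stock, phase, turn, deadwood, can_knock, can_gin

assert STATIC == 156 and EVENT_LOG == 4720
assert STATIC + EVENT_LOG + SCALAR == 4887   # total observation size
```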
Action Space (107 discrete actions)
```
0:      Draw from stock
1:      Draw from discard pile
2:      Pass (initial draw phase only)
3-54:   Discard card i (where i = action - 3)
55-106: Knock and discard card i (where i = action - 55)
```
All actions are masked via `legal_actions(state)` to ensure validity.
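The index layout can be decoded with a small helper (the function name is illustrative, not part of the library):

```python
def describe_action(a):
    """Map an action index to its meaning per the table above."""
    if a == 0:
        return "draw from stock"
    if a == 1:
        return "draw from discard pile"
    if a == 2:
        return "pass"
    if 3 <= a <= 54:
        return f"discard card {a - 3}"
    if 55 <= a <= 106:
        return f"knock and discard card {a - 55}"
    raise ValueError(f"invalid action index: {a}")
```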
Usage
Loading the Model
```python
import pickle
import jax
from huggingface_hub import hf_hub_download

# Download checkpoint
checkpoint_path = hf_hub_download(
    repo_id="GoodStartLabs/gin-rummy-dream",
    filename="checkpoint.pkl",
)

# Load parameters and config
with open(checkpoint_path, "rb") as f:
    checkpoint = pickle.load(f)

params = checkpoint["params"]
config = checkpoint["config"]
```
Running Inference
```python
import jax
from gin_rummy_jax.env import GinRummyEnv
from gin_rummy_jax.rl.dream import make_network, soft_regret_matching

# Initialize environment
rng_key = jax.random.PRNGKey(0)
state = GinRummyEnv.init(rng_key)

# Create network
network_fn = make_network(config)

# Get observation for current player
obs = GinRummyEnv.observe(state, state.current_player)
legal_mask = GinRummyEnv.legal_actions(state)

# Predict advantages and compute policy
advantages, value = network_fn(params, obs)
policy = soft_regret_matching(advantages, legal_mask, temperature=0.1)

# Sample action
action = jax.random.choice(rng_key, len(policy), p=policy)

# Step environment
state, rewards, done = GinRummyEnv.step(state, action)
```
Evaluating Against Random
```python
# See examples/evaluate.py for complete evaluation scripts
# Expected win rate vs random: >95% after convergence
```
Training Progress
Iteration 800 checkpoint exhibits:
- ✅ Meld building: Agent reduces deadwood systematically (vs random play)
- ✅ Knock timing: Learns to knock when advantageous (not just when legal)
- ✅ Discard strategy: Avoids discarding meld-completing cards to the opponent
- 🔄 Gin optimization: Still learning when to hold for gin vs knock early (in progress)
Expected metrics at convergence (iteration ~500):
- Knock rate: 70-80% (vs 0% for random, 2% at iteration 1)
- Win rate vs random: >95%
- Average game length: 15-25 steps (vs 80-step timeout for random)
Technical Innovations
This work demonstrates several novel contributions:
- First fully-vectorized Gin Rummy environment in JAX (enables massive parallelism)
- Dueling architecture for Deep CFR (separates advantage/value to prevent Q-drift)
- Reward shaping for sparse-reward games (bootstraps learning from random init)
- Target network stabilization for CFR (adapts DQN techniques to regret-based learning)
See PHASE5.md for detailed design rationale.
Known Limitations
- Memory requirements: Full checkpoint is ~51MB (network params + target network)
- Inference speed: JIT compilation adds ~2-5 second warmup on first call
- Exploration: Converged policy may not cover all edge cases (rare card distributions)
Citation
If you use this model or environment in your research, please cite:
```bibtex
@software{gin_rummy_dream,
  title  = {Gin Rummy DREAM: Deep Regret Minimization for Card Games},
  author = {Hopkins, Jack},
  year   = {2026},
  url    = {https://github.com/learning-environments/gin-rummy},
  note   = {Dueling architecture with reward shaping and target network stabilization}
}
```
Links
- GitHub Repository: learning-environments/gin-rummy
- Training Logs: Weights & Biases (see training runs)
- Documentation: See repository README and PHASE*.md files for implementation details
- Environment:

```shell
pip install git+https://github.com/learning-environments/gin-rummy.git
```
License
MIT License - Free for academic and commercial use.
Acknowledgments
This work builds on:
- Deep CFR (Brown et al., 2019): Regret minimization with function approximation
- Dueling DQN (Wang et al., 2016): Advantage/value decomposition
- JAX ecosystem: Enables efficient GPU-accelerated RL at scale
Model card last updated at the iteration-800 checkpoint.