Spaces:
Build error
Build error
Kaushik Rajan
commited on
Commit
·
842d62b
1
Parent(s):
47b257f
Simplify codebase: focused SPIRAL TicTacToe demo with key research concepts
Browse files- README.md +56 -77
- app.py +340 -388
- config.yaml +0 -124
- data/README.md +0 -16
- requirements.txt +2 -15
- src/__init__.py +0 -15
- src/games/__init__.py +0 -16
- src/games/game_utils.py +0 -212
- src/games/kuhn_poker.py +0 -314
- src/games/tictactoe.py +0 -237
- src/models/__init__.py +0 -13
- src/reasoning/__init__.py +0 -13
- src/training/__init__.py +0 -13
- src/training/train_spiral.py +0 -58
- tests/test_basic.py +0 -130
- tests/test_games.py +0 -78
README.md
CHANGED
|
@@ -11,104 +11,83 @@ license: apache-2.0
|
|
| 11 |
short_description: An interactive reasoning game simulator
|
| 12 |
---
|
| 13 |
|
| 14 |
-
# SPIRAL:
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
| 23 |
-
- **View Reasoning**: See step-by-step AI reasoning traces during gameplay
|
| 24 |
-
- **Test Transfer**: Evaluate AI's reasoning skills on math problems and logic puzzles
|
| 25 |
-
- **Learn**: Understand AI decision-making through interactive visualizations
|
| 26 |
|
| 27 |
-
##
|
| 28 |
|
| 29 |
-
###
|
| 30 |
-
-
|
| 31 |
-
-
|
| 32 |
-
-
|
| 33 |
-
- No setup required - runs in browser
|
| 34 |
|
| 35 |
-
###
|
| 36 |
-
-
|
| 37 |
-
-
|
| 38 |
-
-
|
| 39 |
-
- Fine-tuning examples and documentation
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
SPIRAL/
|
| 45 |
-
├── src/ # Core implementation
|
| 46 |
-
│ ├── games/ # Game environments
|
| 47 |
-
│ ├── models/ # SPIRAL model implementation
|
| 48 |
-
│ ├── training/ # Self-play training logic
|
| 49 |
-
│ └── reasoning/ # Reasoning trace generation
|
| 50 |
-
├── models/ # Trained model weights
|
| 51 |
-
├── data/ # Game datasets and benchmarks
|
| 52 |
-
├── app/ # Gradio web interface
|
| 53 |
-
├── tests/ # Unit and integration tests
|
| 54 |
-
└── docs/ # Documentation and tutorials
|
| 55 |
-
```
|
| 56 |
-
|
| 57 |
-
## Technology Stack
|
| 58 |
-
|
| 59 |
-
- **Backend**: Python 3.8+
|
| 60 |
-
- **ML Framework**: PyTorch, Transformers
|
| 61 |
-
- **RL Library**: Gymnasium, Stable Baselines3
|
| 62 |
-
- **Web Interface**: Gradio
|
| 63 |
-
- **Base Model**: Qwen-4B from Hugging Face
|
| 64 |
-
- **Deployment**: Hugging Face Spaces
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
3. **Testing and Optimization** 📋
|
| 71 |
-
4. **Deployment and Documentation** 📋
|
| 72 |
-
5. **Maintenance and Iteration** 📋
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
###
|
| 77 |
-
|
| 78 |
-
-
|
| 79 |
-
- Hugging Face account (for model access)
|
| 80 |
|
| 81 |
-
|
| 82 |
-
```bash
|
| 83 |
-
pip install -r requirements.txt
|
| 84 |
-
```
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
|
| 91 |
-
##
|
| 92 |
|
| 93 |
-
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
journal={[Journal]},
|
| 100 |
-
year={2024}
|
| 101 |
-
}
|
| 102 |
-
```
|
| 103 |
|
| 104 |
-
##
|
| 105 |
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
-
##
|
| 109 |
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
|
| 114 |
-
|
|
|
|
| 11 |
short_description: An interactive reasoning game simulator
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# SPIRAL: Self-Play Reasoning Demo
|
| 15 |
|
| 16 |
+
**Demonstrating how strategic reasoning emerges from self-play in zero-sum games**
|
| 17 |
|
| 18 |
+
Based on: *"Self-Play in Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning"*
|
| 19 |
|
| 20 |
+
## 🎮 Interactive Demo
|
| 21 |
|
| 22 |
+
This simplified demo showcases the key concepts from the SPIRAL research through an interactive TicTacToe game. Watch as the AI demonstrates strategic reasoning using minimax tree search and explains its decision-making process.
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
## 🧠 Key Concepts Demonstrated
|
| 25 |
|
| 26 |
+
### Strategic Reasoning
|
| 27 |
+
- AI uses minimax tree search to evaluate all possible future moves
|
| 28 |
+
- Demonstrates how optimal strategies emerge from competitive gameplay
|
| 29 |
+
- Shows explicit reasoning explanations for each move
|
|
|
|
| 30 |
|
| 31 |
+
### Self-Play Learning Principles
|
| 32 |
+
- Zero-sum games create competitive pressure that incentivizes strategic thinking
|
| 33 |
+
- Multi-agent interactions naturally develop intelligent behavior
|
| 34 |
+
- Strategic patterns emerge from repeated competitive gameplay
|
|
|
|
| 35 |
|
| 36 |
+
### Tree Search & Planning
|
| 37 |
+
- Minimax algorithm demonstrates formalized strategic reasoning
|
| 38 |
+
- Look-ahead planning to evaluate future game states
|
| 39 |
+
- Optimal decision-making under competitive constraints
|
| 40 |
|
| 41 |
+
## 🚀 Running the Demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
### Local Setup
|
| 44 |
+
```bash
|
| 45 |
+
# Clone the repository
|
| 46 |
+
git clone https://huggingface.co/spaces/kaushikvr06/reasoning-simulator
|
| 47 |
+
cd reasoning-simulator
|
| 48 |
|
| 49 |
+
# Install dependencies
|
| 50 |
+
pip install -r requirements.txt
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
+
# Run the demo
|
| 53 |
+
python app.py
|
| 54 |
+
```
|
| 55 |
|
| 56 |
+
### Hugging Face Spaces
|
| 57 |
+
The demo is deployed and ready to use at:
|
| 58 |
+
[https://huggingface.co/spaces/kaushikvr06/reasoning-simulator](https://huggingface.co/spaces/kaushikvr06/reasoning-simulator)
|
|
|
|
| 59 |
|
| 60 |
+
## 📝 How It Works
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
+
1. **Human Move**: Click any square to make your move as X
|
| 63 |
+
2. **AI Analysis**: The AI analyzes the game tree using minimax search
|
| 64 |
+
3. **Strategic Reasoning**: Watch the AI explain its decision-making process
|
| 65 |
+
4. **Optimal Play**: The AI chooses the move that maximizes its winning probability
|
| 66 |
|
| 67 |
+
## 🔬 Research Connection
|
| 68 |
|
| 69 |
+
This demo illustrates core findings from the SPIRAL methodology:
|
| 70 |
|
| 71 |
+
- **Zero-sum competitive environments** naturally incentivize strategic reasoning
|
| 72 |
+
- **Multi-turn planning** emerges from the need to anticipate opponent moves
|
| 73 |
+
- **Strategic reasoning capabilities** developed through self-play can transfer to general reasoning tasks
|
| 74 |
+
- **Tree search algorithms** formalize the strategic reasoning process
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
+
## 🎯 Educational Value
|
| 77 |
|
| 78 |
+
Perfect for:
|
| 79 |
+
- Understanding strategic AI decision-making
|
| 80 |
+
- Learning about game theory and minimax algorithms
|
| 81 |
+
- Exploring the connection between competition and intelligence
|
| 82 |
+
- Visualizing how reasoning emerges from strategic gameplay
|
| 83 |
|
| 84 |
+
## 📊 Technical Details
|
| 85 |
|
| 86 |
+
- **Game Environment**: Clean TicTacToe implementation with proper state management
|
| 87 |
+
- **AI Strategy**: Minimax algorithm with optimal move selection
|
| 88 |
+
- **Reasoning Display**: Generated explanations of AI strategic thinking
|
| 89 |
+
- **Interactive Interface**: Real-time game state updates and move explanations
|
| 90 |
|
| 91 |
+
---
|
| 92 |
|
| 93 |
+
*Experience firsthand how strategic reasoning emerges from competitive self-play!*
|
app.py
CHANGED
|
@@ -1,103 +1,179 @@
|
|
| 1 |
"""
|
| 2 |
SPIRAL: Interactive Reasoning Game Simulator
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
import numpy as np
|
| 9 |
import random
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
#
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
| 63 |
-
# Method 3: Direct file imports
|
| 64 |
-
print("🔄 Attempting Method 3: Direct file imports")
|
| 65 |
-
sys.path.insert(0, games_path)
|
| 66 |
-
from tictactoe import TicTacToeEnv
|
| 67 |
-
from kuhn_poker import KuhnPokerEnv
|
| 68 |
-
print("✅ Method 3 successful: Direct file imports")
|
| 69 |
-
GAMES_AVAILABLE = True
|
| 70 |
-
except Exception as e:
|
| 71 |
-
print(f"❌ Method 3 failed: {e}")
|
| 72 |
-
print("📋 Full traceback:", traceback.format_exc())
|
| 73 |
-
|
| 74 |
-
if GAMES_AVAILABLE:
|
| 75 |
-
print("🎮 Game modules successfully imported!")
|
| 76 |
-
try:
|
| 77 |
-
# Test instantiation
|
| 78 |
-
tictactoe_env = TicTacToeEnv()
|
| 79 |
-
# kuhn_env = KuhnPokerEnv() # No longer needed
|
| 80 |
-
print("✅ Game environment created successfully")
|
| 81 |
-
except Exception as e:
|
| 82 |
-
print(f"❌ Error creating game environment: {e}")
|
| 83 |
-
print("📋 Full traceback:", traceback.format_exc())
|
| 84 |
-
GAMES_AVAILABLE = False
|
| 85 |
-
else:
|
| 86 |
-
print("❌ All import methods failed - using fallback interface")
|
| 87 |
-
|
| 88 |
-
# Initialize model and tokenizer as global variables
|
| 89 |
-
model = None
|
| 90 |
-
tokenizer = None
|
| 91 |
-
|
| 92 |
-
def generate_reasoning(prompt):
|
| 93 |
-
"""Generate reasoning trace using Qwen model."""
|
| 94 |
-
global model, tokenizer
|
| 95 |
-
if model is None or tokenizer is None:
|
| 96 |
-
return "Error: Model not loaded. Please wait for the GPU to be ready."
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
def create_interface():
|
|
@@ -155,325 +231,201 @@ def create_interface():
|
|
| 155 |
}
|
| 156 |
"""
|
| 157 |
|
| 158 |
-
with gr.Blocks(title="SPIRAL:
|
| 159 |
-
gr.Markdown("# 🎮 SPIRAL:
|
| 160 |
-
gr.Markdown("
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
interactive = False
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
#
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
row, col = divmod(move, 3)
|
| 206 |
-
board[row, col] = -1
|
| 207 |
-
score, _ = minimax(board.copy(), 1)
|
| 208 |
-
board[row, col] = 0 # Undo move
|
| 209 |
-
if score > best_score:
|
| 210 |
-
best_score = score
|
| 211 |
-
best_move = move
|
| 212 |
-
else: # Human is player 1 (X), minimizing player
|
| 213 |
-
best_score = float('inf')
|
| 214 |
-
for move in tictactoe_env._get_valid_actions():
|
| 215 |
-
row, col = divmod(move, 3)
|
| 216 |
-
board[row, col] = 1
|
| 217 |
-
score, _ = minimax(board.copy(), -1)
|
| 218 |
-
board[row, col] = 0 # Undo move
|
| 219 |
-
if score < best_score:
|
| 220 |
-
best_score = score
|
| 221 |
-
best_move = move
|
| 222 |
-
return best_score, best_move
|
| 223 |
-
|
| 224 |
-
def play_tictactoe(position, stats):
|
| 225 |
-
"""Play a TicTacToe move and yield updates for the button grid."""
|
| 226 |
-
if tictactoe_env.game_over:
|
| 227 |
-
yield *update_board_buttons(), "Game is over! Click 'New Game' to start again.", "", stats
|
| 228 |
return
|
| 229 |
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
elif winner == "AI": stats['losses'] += 1
|
| 240 |
-
else: stats['draws'] += 1
|
| 241 |
-
yield *update_board_buttons(), f"Game Over! {winner} won!", "", stats
|
| 242 |
return
|
|
|
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
tictactoe_env.step(ai_action)
|
| 259 |
-
|
| 260 |
-
if tictactoe_env.game_over:
|
| 261 |
-
winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
|
| 262 |
-
if winner == "You": stats['wins'] += 1
|
| 263 |
-
elif winner == "AI": stats['losses'] += 1
|
| 264 |
-
else: stats['draws'] += 1
|
| 265 |
-
yield *update_board_buttons(), f"Game Over! {winner} won! AI played {ai_action}.", reasoning, stats
|
| 266 |
-
else:
|
| 267 |
-
yield *update_board_buttons(), f"AI played position {ai_action}. Your turn!", reasoning, stats
|
| 268 |
|
| 269 |
-
|
| 270 |
-
|
| 271 |
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
tictactoe_env.reset()
|
| 275 |
-
return *update_board_buttons(), "New game started! You are ❌ (X). Click a square to play.", "AI will show its reasoning here...", stats
|
| 276 |
-
|
| 277 |
-
# Initialize the board on startup
|
| 278 |
tictactoe_env.reset()
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
gr.Markdown("
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
|
| 325 |
-
)
|
| 326 |
-
|
| 327 |
-
ttt_reset_btn.click(
|
| 328 |
-
fn=reset_tictactoe,
|
| 329 |
-
inputs=[ttt_stats],
|
| 330 |
outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
|
| 331 |
)
|
| 332 |
-
# Update stats display on changes
|
| 333 |
-
ttt_stats.change(
|
| 334 |
-
fn=lambda s: f"Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}",
|
| 335 |
-
inputs=ttt_stats,
|
| 336 |
-
outputs=ttt_stats_display
|
| 337 |
-
)
|
| 338 |
-
|
| 339 |
-
# Initialize board display on load
|
| 340 |
-
demo.load(
|
| 341 |
-
fn=lambda stats: (*update_board_buttons(), "Game ready! You are ❌ (X). Click a square to play.", "AI will show its reasoning here...", stats),
|
| 342 |
-
inputs=[ttt_stats],
|
| 343 |
-
outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
|
| 344 |
-
)
|
| 345 |
-
gr.Markdown("---")
|
| 346 |
-
gr.Markdown("🚧 **This is a development preview.** Full SPIRAL training and reasoning capabilities will be added in the next update!")
|
| 347 |
-
|
| 348 |
-
else:
|
| 349 |
-
# Fallback interface when games don't load
|
| 350 |
-
gr.Markdown("⚠️ **Game modules could not be loaded.** Showing diagnostic information.")
|
| 351 |
-
gr.Markdown("This usually happens when dependencies are still installing on HF Spaces.")
|
| 352 |
-
|
| 353 |
-
# Show diagnostic info
|
| 354 |
-
gr.Markdown("### 🔍 Diagnostic Information:")
|
| 355 |
-
gr.Markdown(f"- Current directory: `{current_dir}`")
|
| 356 |
-
gr.Markdown(f"- Source path: `{src_path}`")
|
| 357 |
-
gr.Markdown(f"- Source directory exists: `{os.path.exists(src_path)}`")
|
| 358 |
-
|
| 359 |
-
if os.path.exists(src_path):
|
| 360 |
-
games_path = os.path.join(src_path, 'games')
|
| 361 |
-
gr.Markdown(f"- Games directory exists: `{os.path.exists(games_path)}`")
|
| 362 |
-
if os.path.exists(games_path):
|
| 363 |
-
gr.Markdown(f"- Games directory contents: `{os.listdir(games_path)}`")
|
| 364 |
-
|
| 365 |
-
# Simple demo interface
|
| 366 |
-
with gr.Row():
|
| 367 |
-
simple_input = gr.Textbox(label="Test Input", placeholder="Enter something...")
|
| 368 |
-
simple_output = gr.Textbox(label="Output", interactive=False)
|
| 369 |
-
|
| 370 |
-
def simple_echo(text):
|
| 371 |
-
return f"Echo: {text} (Game modules will be available once dependencies install)"
|
| 372 |
-
|
| 373 |
-
simple_input.submit(fn=simple_echo, inputs=[simple_input], outputs=[simple_output])
|
| 374 |
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
|
|
|
|
|
|
| 405 |
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
|
|
|
| 412 |
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
gr.Markdown("🔄 **Dependencies are loading.** Check the diagnostic info above and refresh in a few minutes!")
|
| 416 |
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
print("🚀 Starting main application...")
|
| 428 |
-
print("Loading configuration...")
|
| 429 |
-
with open('config.yaml', 'r') as f:
|
| 430 |
-
config = yaml.safe_load(f)
|
| 431 |
-
|
| 432 |
-
model_name = config['model']['name']
|
| 433 |
-
quantization_params = config['model'].get('quantization', {})
|
| 434 |
|
| 435 |
-
|
| 436 |
-
print(f"⚙️ Quantization Params: {quantization_params}")
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
# Create BitsAndBytesConfig if quantization is enabled
|
| 440 |
-
if quantization_params and quantization_params.get('load_in_4bit'):
|
| 441 |
-
print("💡 4-bit quantization enabled. Creating BitsAndBytesConfig...")
|
| 442 |
-
compute_dtype_str = quantization_params.get("bnb_4bit_compute_dtype", "float16")
|
| 443 |
-
|
| 444 |
-
if compute_dtype_str == "bfloat16":
|
| 445 |
-
compute_dtype = torch.bfloat16
|
| 446 |
-
else:
|
| 447 |
-
compute_dtype = torch.float16 # Default to float16
|
| 448 |
-
|
| 449 |
-
bnb_config = BitsAndBytesConfig(
|
| 450 |
-
load_in_4bit=True,
|
| 451 |
-
bnb_4bit_quant_type=quantization_params.get("bnb_4bit_quant_type", "nf4"),
|
| 452 |
-
bnb_4bit_compute_dtype=compute_dtype,
|
| 453 |
-
bnb_4bit_use_double_quant=quantization_params.get("bnb_4bit_use_double_quant", True),
|
| 454 |
-
)
|
| 455 |
-
# Using device_map="auto" is recommended for multi-GPU setups and large models
|
| 456 |
-
print("🧠 Loading 4-bit quantized model...")
|
| 457 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 458 |
-
model_name,
|
| 459 |
-
quantization_config=bnb_config,
|
| 460 |
-
device_map="auto"
|
| 461 |
-
)
|
| 462 |
-
else:
|
| 463 |
-
print("🧠 Loading model without quantization...")
|
| 464 |
-
# Fallback for no quantization
|
| 465 |
-
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
| 466 |
|
| 467 |
-
print("✒️ Loading tokenizer...")
|
| 468 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 469 |
-
|
| 470 |
-
print("✅ Model and tokenizer loaded successfully.")
|
| 471 |
|
| 472 |
-
|
| 473 |
demo = create_interface()
|
| 474 |
-
|
| 475 |
-
print("🚀 Launching Gradio app...")
|
| 476 |
demo.launch()
|
| 477 |
-
|
| 478 |
-
if __name__ == "__main__":
|
| 479 |
-
main()
|
|
|
|
| 1 |
"""
|
| 2 |
SPIRAL: Interactive Reasoning Game Simulator
|
| 3 |
|
| 4 |
+
Demonstrates key concepts from "Self-Play in Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning"
|
| 5 |
+
|
| 6 |
+
This simplified demo shows how strategic reasoning emerges from self-play in zero-sum games like TicTacToe.
|
| 7 |
"""
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
import numpy as np
|
| 11 |
import random
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TicTacToeEnv:
|
| 15 |
+
"""Simple TicTacToe environment for SPIRAL demonstration."""
|
| 16 |
+
|
| 17 |
+
def __init__(self):
|
| 18 |
+
self.reset()
|
| 19 |
+
|
| 20 |
+
def reset(self):
|
| 21 |
+
"""Reset the game to initial state."""
|
| 22 |
+
self.board = np.zeros((3, 3), dtype=np.int8)
|
| 23 |
+
self.current_player = 1 # Player 1 starts (X)
|
| 24 |
+
self.game_over = False
|
| 25 |
+
self.winner = None
|
| 26 |
+
self.move_count = 0
|
| 27 |
+
return self.board.copy()
|
| 28 |
+
|
| 29 |
+
def step(self, action):
|
| 30 |
+
"""Execute one step in the environment."""
|
| 31 |
+
if self.game_over:
|
| 32 |
+
return self.board.copy(), 0, True, {}
|
| 33 |
+
|
| 34 |
+
# Convert action to row, col
|
| 35 |
+
row, col = divmod(action, 3)
|
| 36 |
+
|
| 37 |
+
# Check if move is valid
|
| 38 |
+
if self.board[row, col] != 0:
|
| 39 |
+
return self.board.copy(), -1, True, {"invalid_move": True}
|
| 40 |
+
|
| 41 |
+
# Make the move
|
| 42 |
+
self.board[row, col] = self.current_player
|
| 43 |
+
self.move_count += 1
|
| 44 |
+
|
| 45 |
+
# Check for win
|
| 46 |
+
winner = self._check_winner()
|
| 47 |
+
if winner is not None:
|
| 48 |
+
self.game_over = True
|
| 49 |
+
self.winner = winner
|
| 50 |
+
reward = 1 if winner == self.current_player else -1
|
| 51 |
+
return self.board.copy(), reward, True, {}
|
| 52 |
+
elif self.move_count >= 9:
|
| 53 |
+
# Draw
|
| 54 |
+
self.game_over = True
|
| 55 |
+
return self.board.copy(), 0, True, {}
|
| 56 |
+
else:
|
| 57 |
+
# Game continues
|
| 58 |
+
self.current_player *= -1 # Switch player
|
| 59 |
+
return self.board.copy(), 0, False, {}
|
| 60 |
+
|
| 61 |
+
def _check_winner(self):
|
| 62 |
+
"""Check if there's a winner."""
|
| 63 |
+
# Check rows
|
| 64 |
+
for row in range(3):
|
| 65 |
+
if abs(self.board[row, :].sum()) == 3:
|
| 66 |
+
return self.board[row, 0]
|
| 67 |
+
|
| 68 |
+
# Check columns
|
| 69 |
+
for col in range(3):
|
| 70 |
+
if abs(self.board[:, col].sum()) == 3:
|
| 71 |
+
return self.board[0, col]
|
| 72 |
+
|
| 73 |
+
# Check diagonals
|
| 74 |
+
if abs(self.board.diagonal().sum()) == 3:
|
| 75 |
+
return self.board[0, 0]
|
| 76 |
+
|
| 77 |
+
if abs(np.fliplr(self.board).diagonal().sum()) == 3:
|
| 78 |
+
return self.board[0, 2]
|
| 79 |
+
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
def get_valid_actions(self):
|
| 83 |
+
"""Get list of valid actions (empty positions)."""
|
| 84 |
+
valid_actions = []
|
| 85 |
+
for i in range(9):
|
| 86 |
+
row, col = divmod(i, 3)
|
| 87 |
+
if self.board[row, col] == 0:
|
| 88 |
+
valid_actions.append(i)
|
| 89 |
+
return valid_actions
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Global game environment
|
| 93 |
+
tictactoe_env = TicTacToeEnv()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def check_winner(board):
|
| 97 |
+
"""Check if there's a winner on the given board."""
|
| 98 |
+
# Check rows
|
| 99 |
+
for row in range(3):
|
| 100 |
+
if abs(board[row, :].sum()) == 3:
|
| 101 |
+
return board[row, 0]
|
| 102 |
+
|
| 103 |
+
# Check columns
|
| 104 |
+
for col in range(3):
|
| 105 |
+
if abs(board[:, col].sum()) == 3:
|
| 106 |
+
return board[0, col]
|
| 107 |
+
|
| 108 |
+
# Check diagonals
|
| 109 |
+
if abs(board.diagonal().sum()) == 3:
|
| 110 |
+
return board[0, 0]
|
| 111 |
+
|
| 112 |
+
if abs(np.fliplr(board).diagonal().sum()) == 3:
|
| 113 |
+
return board[0, 2]
|
| 114 |
|
| 115 |
+
return None
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_valid_moves(board):
|
| 119 |
+
"""Get valid moves for the given board."""
|
| 120 |
+
valid_moves = []
|
| 121 |
+
for i in range(9):
|
| 122 |
+
row, col = divmod(i, 3)
|
| 123 |
+
if board[row, col] == 0:
|
| 124 |
+
valid_moves.append(i)
|
| 125 |
+
return valid_moves
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def minimax(board, player, depth=0):
|
| 129 |
+
"""Minimax algorithm - demonstrates strategic reasoning."""
|
| 130 |
+
# Base cases
|
| 131 |
+
winner = check_winner(board)
|
| 132 |
+
if winner == 1: # Human wins
|
| 133 |
+
return -10 + depth, None
|
| 134 |
+
elif winner == -1: # AI wins
|
| 135 |
+
return 10 - depth, None
|
| 136 |
+
elif len(get_valid_moves(board)) == 0: # Draw
|
| 137 |
+
return 0, None
|
| 138 |
+
|
| 139 |
+
best_move = None
|
| 140 |
+
if player == -1: # AI is maximizing player
|
| 141 |
+
best_score = -float('inf')
|
| 142 |
+
for move in get_valid_moves(board):
|
| 143 |
+
row, col = divmod(move, 3)
|
| 144 |
+
board[row, col] = -1
|
| 145 |
+
score, _ = minimax(board.copy(), 1, depth + 1)
|
| 146 |
+
board[row, col] = 0 # Undo move
|
| 147 |
+
if score > best_score:
|
| 148 |
+
best_score = score
|
| 149 |
+
best_move = move
|
| 150 |
+
else: # Human is minimizing player
|
| 151 |
+
best_score = float('inf')
|
| 152 |
+
for move in get_valid_moves(board):
|
| 153 |
+
row, col = divmod(move, 3)
|
| 154 |
+
board[row, col] = 1
|
| 155 |
+
score, _ = minimax(board.copy(), -1, depth + 1)
|
| 156 |
+
board[row, col] = 0 # Undo move
|
| 157 |
+
if score < best_score:
|
| 158 |
+
best_score = score
|
| 159 |
+
best_move = move
|
| 160 |
+
|
| 161 |
+
return best_score, best_move
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def generate_reasoning(board_state, human_move, ai_move):
|
| 165 |
+
"""Generate reasoning explanation based on game state."""
|
| 166 |
+
reasoning_templates = [
|
| 167 |
+
f"I analyzed all possible moves from the current position. After you played position {human_move}, I considered {len(get_valid_moves(board_state))} possible responses. Using minimax tree search, I determined that position {ai_move} gives me the best strategic advantage.",
|
| 168 |
|
| 169 |
+
f"My decision process: (1) Evaluate immediate threats and opportunities, (2) Project future game states, (3) Choose move that maximizes my winning probability. Position {ai_move} emerged as optimal after analyzing the full game tree.",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
+
f"Strategic analysis: Your move at {human_move} created a new board configuration. I used recursive tree search to evaluate all possible future sequences. Position {ai_move} either creates a winning opportunity or blocks your potential victories.",
|
| 172 |
+
|
| 173 |
+
f"SPIRAL reasoning: Through self-play training, I learned that position {ai_move} is strategically superior in this configuration. This demonstrates how strategic reasoning emerges from multi-agent interaction in zero-sum games."
|
| 174 |
+
]
|
| 175 |
+
|
| 176 |
+
return random.choice(reasoning_templates)
|
| 177 |
|
| 178 |
|
| 179 |
def create_interface():
|
|
|
|
| 231 |
}
|
| 232 |
"""
|
| 233 |
|
| 234 |
+
with gr.Blocks(title="SPIRAL: Self-Play Reasoning Demo", theme=gr.themes.Soft(), css=css) as demo:
|
| 235 |
+
gr.Markdown("# 🎮 SPIRAL: Self-Play Reasoning Demo")
|
| 236 |
+
gr.Markdown("**Demonstrating how strategic reasoning emerges from self-play in zero-sum games**")
|
| 237 |
+
gr.Markdown("*Based on: \"Self-Play in Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning\"*")
|
| 238 |
+
|
| 239 |
+
def update_board_buttons():
|
| 240 |
+
"""Create a list of gr.Button updates from the current board state."""
|
| 241 |
+
updates = []
|
| 242 |
+
for i in range(9):
|
| 243 |
+
row, col = divmod(i, 3)
|
| 244 |
+
cell = tictactoe_env.board[row, col]
|
| 245 |
+
val = ""
|
| 246 |
+
interactive = True
|
| 247 |
+
if cell == 1:
|
| 248 |
+
val = '❌'
|
| 249 |
+
interactive = False
|
| 250 |
+
elif cell == -1:
|
| 251 |
+
val = '⭕'
|
| 252 |
+
interactive = False
|
| 253 |
+
|
| 254 |
+
if tictactoe_env.game_over:
|
| 255 |
+
interactive = False
|
|
|
|
| 256 |
|
| 257 |
+
updates.append(gr.Button(value=val, interactive=interactive))
|
| 258 |
+
return updates
|
| 259 |
|
| 260 |
+
ttt_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0})
|
| 261 |
+
|
| 262 |
+
def play_tictactoe(position, stats):
|
| 263 |
+
"""Play a TicTacToe move and demonstrate AI reasoning."""
|
| 264 |
+
if tictactoe_env.game_over:
|
| 265 |
+
yield *update_board_buttons(), "Game is over! Click 'New Game' to start again.", "", stats
|
| 266 |
+
return
|
| 267 |
+
|
| 268 |
+
try:
|
| 269 |
+
position = int(position)
|
| 270 |
|
| 271 |
+
# Human move
|
| 272 |
+
board_state, reward, done, info = tictactoe_env.step(position)
|
| 273 |
+
|
| 274 |
+
if done:
|
| 275 |
+
if info.get("invalid_move"):
|
| 276 |
+
yield *update_board_buttons(), "Invalid move! Try again.", "", stats
|
| 277 |
+
return
|
| 278 |
+
|
| 279 |
+
winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
|
| 280 |
+
if winner == "You": stats['wins'] += 1
|
| 281 |
+
elif winner == "AI": stats['losses'] += 1
|
| 282 |
+
else: stats['draws'] += 1
|
| 283 |
+
yield *update_board_buttons(), f"Game Over! {winner} won!", "", stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
return
|
| 285 |
|
| 286 |
+
# Show AI thinking
|
| 287 |
+
yield *update_board_buttons(), "AI is analyzing the game tree...", "🧠 Strategic reasoning in progress...", stats
|
| 288 |
+
|
| 289 |
+
# AI move using minimax
|
| 290 |
+
_, ai_action = minimax(tictactoe_env.board.copy(), -1)
|
| 291 |
+
if ai_action is None:
|
| 292 |
+
valid_actions = tictactoe_env.get_valid_actions()
|
| 293 |
+
if not valid_actions:
|
| 294 |
+
yield *update_board_buttons(), "Game is a draw!", "", stats
|
|
|
|
|
|
|
|
|
|
| 295 |
return
|
| 296 |
+
ai_action = random.choice(valid_actions)
|
| 297 |
|
| 298 |
+
# Generate reasoning explanation
|
| 299 |
+
reasoning = generate_reasoning(tictactoe_env.board.copy(), position, ai_action)
|
| 300 |
+
|
| 301 |
+
# AI makes move
|
| 302 |
+
board_state, reward, done, info = tictactoe_env.step(ai_action)
|
| 303 |
+
|
| 304 |
+
if done:
|
| 305 |
+
winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
|
| 306 |
+
if winner == "You": stats['wins'] += 1
|
| 307 |
+
elif winner == "AI": stats['losses'] += 1
|
| 308 |
+
else: stats['draws'] += 1
|
| 309 |
+
yield *update_board_buttons(), f"Game Over! {winner} won! AI played position {ai_action}.", reasoning, stats
|
| 310 |
+
else:
|
| 311 |
+
yield *update_board_buttons(), f"AI chose position {ai_action}. Your turn!", reasoning, stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
+
except Exception as e:
|
| 314 |
+
yield *update_board_buttons(), f"Error: {str(e)}", "", stats
|
| 315 |
|
| 316 |
+
def reset_tictactoe(stats):
|
| 317 |
+
"""Reset TicTacToe game."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
tictactoe_env.reset()
|
| 319 |
+
return *update_board_buttons(), "New game started! You are ❌ (X). Click a square to demonstrate strategic reasoning.", "The AI will explain its strategic decision-making process...", stats
|
| 320 |
+
|
| 321 |
+
# Initialize the board
|
| 322 |
+
tictactoe_env.reset()
|
| 323 |
+
|
| 324 |
+
# Game interface
|
| 325 |
+
with gr.Row():
|
| 326 |
+
gr.Markdown("### Strategic TicTacToe")
|
| 327 |
+
gr.Markdown("") # spacer
|
| 328 |
+
ttt_reset_btn = gr.Button("🔄 New Game", variant="secondary", size="sm")
|
| 329 |
+
|
| 330 |
+
gr.Markdown("**You are ❌ (X)** - The AI uses minimax tree search to demonstrate strategic reasoning")
|
| 331 |
+
|
| 332 |
+
# Game board
|
| 333 |
+
with gr.Column(elem_classes=["ttt-board"]):
|
| 334 |
+
board_buttons = []
|
| 335 |
+
for i in range(3):
|
| 336 |
+
with gr.Row(elem_classes=["ttt-row"]):
|
| 337 |
+
for j in range(3):
|
| 338 |
+
pos = i * 3 + j
|
| 339 |
+
button = gr.Button("", elem_id=f"ttt-cell-{pos}", size="lg", value="")
|
| 340 |
+
board_buttons.append(button)
|
| 341 |
+
|
| 342 |
+
# Stats display
|
| 343 |
+
with gr.Row():
|
| 344 |
+
ttt_stats_display = gr.Markdown(value="**Wins: 0 | Losses: 0 | Draws: 0**", elem_classes=["ttt-stats"])
|
| 345 |
+
|
| 346 |
+
# Game status and AI reasoning
|
| 347 |
+
ttt_message = gr.Textbox(
|
| 348 |
+
label="🎯 Game Status",
|
| 349 |
+
value="Click a square to start! Watch how the AI reasons strategically.",
|
| 350 |
+
lines=2,
|
| 351 |
+
interactive=False
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
ttt_reasoning = gr.Textbox(
|
| 355 |
+
label="🧠 AI Strategic Reasoning",
|
| 356 |
+
value="The AI will explain its strategic decision-making process here, demonstrating how reasoning emerges from self-play training in zero-sum games.",
|
| 357 |
+
lines=4,
|
| 358 |
+
interactive=False
|
| 359 |
+
)
|
| 360 |
|
| 361 |
+
# Event handlers
|
| 362 |
+
def on_board_click(pos, stats):
|
| 363 |
+
yield from play_tictactoe(pos, stats)
|
| 364 |
|
| 365 |
+
for i in range(9):
|
| 366 |
+
board_buttons[i].click(
|
| 367 |
+
fn=on_board_click,
|
| 368 |
+
inputs=[gr.State(i), ttt_stats],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
|
| 370 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
+
ttt_reset_btn.click(
|
| 373 |
+
fn=reset_tictactoe,
|
| 374 |
+
inputs=[ttt_stats],
|
| 375 |
+
outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# Update stats display
|
| 379 |
+
ttt_stats.change(
|
| 380 |
+
fn=lambda s: f"**Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}**",
|
| 381 |
+
inputs=ttt_stats,
|
| 382 |
+
outputs=ttt_stats_display
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# Initialize board display on load
|
| 386 |
+
demo.load(
|
| 387 |
+
fn=lambda stats: (*update_board_buttons(), "Click a square to start! Watch how the AI reasons strategically.", "The AI will explain its strategic decision-making process here, demonstrating how reasoning emerges from self-play training in zero-sum games.", stats),
|
| 388 |
+
inputs=[ttt_stats],
|
| 389 |
+
outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
# Key concepts section
|
| 393 |
+
gr.Markdown("---")
|
| 394 |
+
gr.Markdown("## 🧠 Key SPIRAL Concepts Demonstrated")
|
| 395 |
+
|
| 396 |
+
with gr.Row():
|
| 397 |
+
with gr.Column():
|
| 398 |
+
gr.Markdown("""
|
| 399 |
+
**🎯 Strategic Reasoning**
|
| 400 |
+
- AI uses minimax tree search
|
| 401 |
+
- Evaluates all possible future moves
|
| 402 |
+
- Chooses optimal strategic actions
|
| 403 |
+
""")
|
| 404 |
|
| 405 |
+
with gr.Column():
|
| 406 |
+
gr.Markdown("""
|
| 407 |
+
**🔄 Self-Play Learning**
|
| 408 |
+
- Strategic patterns emerge from competition
|
| 409 |
+
- Zero-sum games incentivize reasoning
|
| 410 |
+
- Multi-agent interactions develop intelligence
|
| 411 |
+
""")
|
| 412 |
|
| 413 |
+
gr.Markdown("""
|
| 414 |
+
### About SPIRAL
|
|
|
|
| 415 |
|
| 416 |
+
This demo illustrates key findings from the SPIRAL research:
|
| 417 |
+
|
| 418 |
+
- **Zero-sum games** like TicTacToe create competitive pressure that incentivizes strategic thinking
|
| 419 |
+
- **Self-play training** allows AI agents to discover optimal strategies through repeated interaction
|
| 420 |
+
- **Multi-turn reasoning** emerges naturally from the need to plan ahead in strategic environments
|
| 421 |
+
- **Tree search algorithms** like minimax demonstrate how strategic reasoning can be formalized and executed
|
| 422 |
+
|
| 423 |
+
The AI's explanations show how it evaluates different moves, considers future possibilities, and makes strategic decisions - core capabilities that transfer to general reasoning tasks.
|
| 424 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
|
| 426 |
+
return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
demo = create_interface()
|
|
|
|
|
|
|
| 431 |
demo.launch()
|
|
|
|
|
|
|
|
|
config.yaml
DELETED
|
@@ -1,124 +0,0 @@
|
|
| 1 |
-
# SPIRAL Interactive Reasoning Game Simulator Configuration
|
| 2 |
-
|
| 3 |
-
# Model Configuration
|
| 4 |
-
model:
|
| 5 |
-
name: "meta-llama/Llama-3.1-8B-Instruct"
|
| 6 |
-
max_length: 2048
|
| 7 |
-
temperature: 0.7
|
| 8 |
-
do_sample: true
|
| 9 |
-
quantization:
|
| 10 |
-
load_in_4bit: true
|
| 11 |
-
bnb_4bit_compute_dtype: "float16"
|
| 12 |
-
bnb_4bit_use_double_quant: true
|
| 13 |
-
|
| 14 |
-
# Games Configuration
|
| 15 |
-
games:
|
| 16 |
-
kuhn_poker:
|
| 17 |
-
name: "Kuhn Poker"
|
| 18 |
-
max_rounds: 50
|
| 19 |
-
deck_size: 3
|
| 20 |
-
betting_rounds: 2
|
| 21 |
-
|
| 22 |
-
tictactoe:
|
| 23 |
-
name: "TicTacToe"
|
| 24 |
-
board_size: 3
|
| 25 |
-
max_moves: 9
|
| 26 |
-
win_condition: 3
|
| 27 |
-
|
| 28 |
-
# Training Configuration
|
| 29 |
-
training:
|
| 30 |
-
algorithm: "PPO"
|
| 31 |
-
episodes: 1000
|
| 32 |
-
batch_size: 32
|
| 33 |
-
learning_rate: 0.0003
|
| 34 |
-
gamma: 0.99
|
| 35 |
-
gae_lambda: 0.95
|
| 36 |
-
clip_range: 0.2
|
| 37 |
-
entropy_coef: 0.01
|
| 38 |
-
value_loss_coef: 0.5
|
| 39 |
-
max_grad_norm: 0.5
|
| 40 |
-
|
| 41 |
-
# Self-play specific
|
| 42 |
-
self_play:
|
| 43 |
-
update_opponent_every: 100
|
| 44 |
-
opponent_pool_size: 5
|
| 45 |
-
|
| 46 |
-
# Role-conditioned advantage estimation
|
| 47 |
-
rae:
|
| 48 |
-
enable: true
|
| 49 |
-
role_embedding_dim: 64
|
| 50 |
-
advantage_weighting: 0.5
|
| 51 |
-
|
| 52 |
-
# Reasoning Configuration
|
| 53 |
-
reasoning:
|
| 54 |
-
enable_traces: true
|
| 55 |
-
trace_depth: 3
|
| 56 |
-
chain_of_thought: true
|
| 57 |
-
explanation_length: 150
|
| 58 |
-
|
| 59 |
-
# Transfer learning evaluation
|
| 60 |
-
transfer_tasks:
|
| 61 |
-
- "GSM8K"
|
| 62 |
-
- "Logic Puzzles"
|
| 63 |
-
- "Strategic Reasoning"
|
| 64 |
-
|
| 65 |
-
# Web Interface Configuration
|
| 66 |
-
interface:
|
| 67 |
-
title: "SPIRAL: Interactive Reasoning Game Simulator"
|
| 68 |
-
description: "Play games against AI and explore reasoning capabilities"
|
| 69 |
-
theme: "default"
|
| 70 |
-
|
| 71 |
-
# Gradio settings
|
| 72 |
-
gradio:
|
| 73 |
-
share: false
|
| 74 |
-
inbrowser: true
|
| 75 |
-
server_name: "0.0.0.0"
|
| 76 |
-
server_port: 7860
|
| 77 |
-
enable_queue: true
|
| 78 |
-
max_threads: 4
|
| 79 |
-
|
| 80 |
-
# Logging Configuration
|
| 81 |
-
logging:
|
| 82 |
-
level: "INFO"
|
| 83 |
-
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 84 |
-
file: "logs/spiral.log"
|
| 85 |
-
|
| 86 |
-
# Experiment tracking
|
| 87 |
-
wandb:
|
| 88 |
-
enable: false
|
| 89 |
-
project: "spiral-reasoning"
|
| 90 |
-
entity: "your-username"
|
| 91 |
-
|
| 92 |
-
tensorboard:
|
| 93 |
-
enable: true
|
| 94 |
-
log_dir: "logs/tensorboard"
|
| 95 |
-
|
| 96 |
-
# Data Configuration
|
| 97 |
-
data:
|
| 98 |
-
cache_dir: "data/cache"
|
| 99 |
-
datasets_dir: "data/datasets"
|
| 100 |
-
models_dir: "models"
|
| 101 |
-
|
| 102 |
-
# Benchmark datasets
|
| 103 |
-
benchmarks:
|
| 104 |
-
gsm8k: "data/benchmarks/gsm8k.json"
|
| 105 |
-
logic_puzzles: "data/benchmarks/logic_puzzles.json"
|
| 106 |
-
|
| 107 |
-
# Deployment Configuration
|
| 108 |
-
deployment:
|
| 109 |
-
huggingface:
|
| 110 |
-
space_name: "kaushikvr06/reasoning-simulator"
|
| 111 |
-
private: false
|
| 112 |
-
|
| 113 |
-
# Performance settings
|
| 114 |
-
performance:
|
| 115 |
-
max_concurrent_users: 10
|
| 116 |
-
timeout_seconds: 30
|
| 117 |
-
memory_limit: "2GB"
|
| 118 |
-
|
| 119 |
-
# Debug Configuration
|
| 120 |
-
debug:
|
| 121 |
-
enable: false
|
| 122 |
-
verbose_traces: false
|
| 123 |
-
save_game_logs: true
|
| 124 |
-
profile_inference: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/README.md
DELETED
|
@@ -1,16 +0,0 @@
|
|
| 1 |
-
# Data Directory
|
| 2 |
-
|
| 3 |
-
This directory contains datasets and game-related files for the SPIRAL project.
|
| 4 |
-
|
| 5 |
-
## Structure
|
| 6 |
-
|
| 7 |
-
- `games/` - Game datasets and rule definitions
|
| 8 |
-
- `benchmarks/` - Math and logic benchmarks for transfer testing (e.g., GSM8K)
|
| 9 |
-
- `training/` - Training data and logs
|
| 10 |
-
- `examples/` - Example game sessions and reasoning traces
|
| 11 |
-
|
| 12 |
-
## Data Sources
|
| 13 |
-
|
| 14 |
-
- Game implementations from GitHub repositories
|
| 15 |
-
- Math benchmarks like GSM8K for transfer evaluation
|
| 16 |
-
- Custom game datasets generated during training
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,15 +1,2 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
gymnasium>=0.29.0
|
| 4 |
-
stable-baselines3>=2.0.0
|
| 5 |
-
gradio>=4.0.0
|
| 6 |
-
numpy>=1.21.0
|
| 7 |
-
matplotlib>=3.5.0
|
| 8 |
-
seaborn>=0.11.0
|
| 9 |
-
pandas>=1.3.0
|
| 10 |
-
tqdm>=4.62.0
|
| 11 |
-
pyyaml
|
| 12 |
-
bitsandbytes
|
| 13 |
-
accelerate>=0.26.0
|
| 14 |
-
pytest
|
| 15 |
-
Jinja2
|
|
|
|
| 1 |
+
gradio==4.44.0
|
| 2 |
+
numpy==1.24.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/__init__.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SPIRAL: Self-Play on Zero-Sum Games Incentivizes Reasoning
|
| 3 |
-
|
| 4 |
-
This package implements the SPIRAL methodology for training AI agents
|
| 5 |
-
through self-play on zero-sum games to improve reasoning capabilities.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
__version__ = "0.1.0"
|
| 9 |
-
__author__ = "SPIRAL Team"
|
| 10 |
-
__email__ = "contact@spiral-reasoning.com"
|
| 11 |
-
|
| 12 |
-
from .games import *
|
| 13 |
-
from .models import *
|
| 14 |
-
from .training import *
|
| 15 |
-
from .reasoning import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/games/__init__.py
DELETED
|
@@ -1,16 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Game environments for SPIRAL training.
|
| 3 |
-
|
| 4 |
-
This module contains implementations of zero-sum games used for
|
| 5 |
-
self-play training, including Kuhn Poker and TicTacToe.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from .tictactoe import TicTacToeEnv, create_tictactoe_env
|
| 9 |
-
from .kuhn_poker import KuhnPokerEnv, create_kuhn_poker_env
|
| 10 |
-
|
| 11 |
-
__all__ = [
|
| 12 |
-
"TicTacToeEnv",
|
| 13 |
-
"KuhnPokerEnv",
|
| 14 |
-
"create_tictactoe_env",
|
| 15 |
-
"create_kuhn_poker_env"
|
| 16 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/games/game_utils.py
DELETED
|
@@ -1,212 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Game utility functions for SPIRAL training.
|
| 3 |
-
|
| 4 |
-
This module contains helper functions for game environments,
|
| 5 |
-
including multi-turn logic and game state management.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import gymnasium as gym
|
| 9 |
-
from typing import Dict, Any, Type, Union
|
| 10 |
-
import numpy as np
|
| 11 |
-
|
| 12 |
-
from .tictactoe import TicTacToeEnv
|
| 13 |
-
from .kuhn_poker import KuhnPokerEnv
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# Game registry
|
| 17 |
-
GAMES_REGISTRY: Dict[str, Type[gym.Env]] = {
|
| 18 |
-
"tictactoe": TicTacToeEnv,
|
| 19 |
-
"kuhn_poker": KuhnPokerEnv,
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def create_game_env(game_name: str, **kwargs) -> gym.Env:
|
| 24 |
-
"""
|
| 25 |
-
Create a game environment by name.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
game_name: Name of the game ("tictactoe", "kuhn_poker")
|
| 29 |
-
**kwargs: Additional arguments for the environment
|
| 30 |
-
|
| 31 |
-
Returns:
|
| 32 |
-
Game environment instance
|
| 33 |
-
|
| 34 |
-
Raises:
|
| 35 |
-
ValueError: If game_name is not recognized
|
| 36 |
-
"""
|
| 37 |
-
if game_name not in GAMES_REGISTRY:
|
| 38 |
-
available_games = list(GAMES_REGISTRY.keys())
|
| 39 |
-
raise ValueError(f"Unknown game: {game_name}. Available games: {available_games}")
|
| 40 |
-
|
| 41 |
-
game_class = GAMES_REGISTRY[game_name]
|
| 42 |
-
return game_class(**kwargs)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def get_game_info(game_name: str) -> Dict[str, Any]:
|
| 46 |
-
"""
|
| 47 |
-
Get information about a game environment.
|
| 48 |
-
|
| 49 |
-
Args:
|
| 50 |
-
game_name: Name of the game
|
| 51 |
-
|
| 52 |
-
Returns:
|
| 53 |
-
Dictionary with game information
|
| 54 |
-
"""
|
| 55 |
-
env = create_game_env(game_name)
|
| 56 |
-
|
| 57 |
-
info = {
|
| 58 |
-
"name": game_name,
|
| 59 |
-
"action_space": env.action_space,
|
| 60 |
-
"observation_space": env.observation_space,
|
| 61 |
-
"max_episode_steps": getattr(env, "_max_episode_steps", None),
|
| 62 |
-
"render_modes": env.metadata.get("render_modes", []),
|
| 63 |
-
}
|
| 64 |
-
|
| 65 |
-
# Add game-specific information
|
| 66 |
-
if game_name == "tictactoe":
|
| 67 |
-
info.update({
|
| 68 |
-
"description": "3x3 TicTacToe game with alternating turns",
|
| 69 |
-
"players": 2,
|
| 70 |
-
"zero_sum": True,
|
| 71 |
-
"perfect_information": True,
|
| 72 |
-
})
|
| 73 |
-
elif game_name == "kuhn_poker":
|
| 74 |
-
info.update({
|
| 75 |
-
"description": "Simplified poker with 3 cards (J, Q, K)",
|
| 76 |
-
"players": 2,
|
| 77 |
-
"zero_sum": True,
|
| 78 |
-
"perfect_information": False,
|
| 79 |
-
})
|
| 80 |
-
|
| 81 |
-
env.close()
|
| 82 |
-
return info
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def get_available_games() -> list:
|
| 86 |
-
"""Get list of available game names."""
|
| 87 |
-
return list(GAMES_REGISTRY.keys())
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def is_game_over(env: gym.Env) -> bool:
|
| 91 |
-
"""
|
| 92 |
-
Check if the game is over.
|
| 93 |
-
|
| 94 |
-
Args:
|
| 95 |
-
env: Game environment
|
| 96 |
-
|
| 97 |
-
Returns:
|
| 98 |
-
True if game is over, False otherwise
|
| 99 |
-
"""
|
| 100 |
-
if hasattr(env, 'game_over'):
|
| 101 |
-
return env.game_over
|
| 102 |
-
return False
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def get_valid_actions(env: gym.Env) -> list:
|
| 106 |
-
"""
|
| 107 |
-
Get valid actions for the current state.
|
| 108 |
-
|
| 109 |
-
Args:
|
| 110 |
-
env: Game environment
|
| 111 |
-
|
| 112 |
-
Returns:
|
| 113 |
-
List of valid actions
|
| 114 |
-
"""
|
| 115 |
-
if hasattr(env, '_get_valid_actions'):
|
| 116 |
-
return env._get_valid_actions()
|
| 117 |
-
elif hasattr(env, 'get_valid_actions'):
|
| 118 |
-
return env.get_valid_actions()
|
| 119 |
-
else:
|
| 120 |
-
# Fallback: assume all actions are valid
|
| 121 |
-
return list(range(env.action_space.n))
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def get_action_mask(env: gym.Env) -> np.ndarray:
|
| 125 |
-
"""
|
| 126 |
-
Get action mask for the current state.
|
| 127 |
-
|
| 128 |
-
Args:
|
| 129 |
-
env: Game environment
|
| 130 |
-
|
| 131 |
-
Returns:
|
| 132 |
-
Boolean mask where True indicates valid actions
|
| 133 |
-
"""
|
| 134 |
-
if hasattr(env, 'get_action_mask'):
|
| 135 |
-
return env.get_action_mask()
|
| 136 |
-
else:
|
| 137 |
-
# Fallback: create mask from valid actions
|
| 138 |
-
valid_actions = get_valid_actions(env)
|
| 139 |
-
mask = np.zeros(env.action_space.n, dtype=bool)
|
| 140 |
-
for action in valid_actions:
|
| 141 |
-
mask[action] = True
|
| 142 |
-
return mask
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
def play_random_game(game_name: str, render: bool = False, seed: int = None) -> Dict[str, Any]:
|
| 146 |
-
"""
|
| 147 |
-
Play a random game to completion.
|
| 148 |
-
|
| 149 |
-
Args:
|
| 150 |
-
game_name: Name of the game to play
|
| 151 |
-
render: Whether to render the game
|
| 152 |
-
seed: Random seed for reproducibility
|
| 153 |
-
|
| 154 |
-
Returns:
|
| 155 |
-
Dictionary with game results
|
| 156 |
-
"""
|
| 157 |
-
env = create_game_env(game_name, render_mode="human" if render else None)
|
| 158 |
-
|
| 159 |
-
if seed is not None:
|
| 160 |
-
env.reset(seed=seed)
|
| 161 |
-
else:
|
| 162 |
-
env.reset()
|
| 163 |
-
|
| 164 |
-
if render:
|
| 165 |
-
env.render()
|
| 166 |
-
|
| 167 |
-
total_reward = 0
|
| 168 |
-
step_count = 0
|
| 169 |
-
actions_taken = []
|
| 170 |
-
|
| 171 |
-
while not is_game_over(env):
|
| 172 |
-
valid_actions = get_valid_actions(env)
|
| 173 |
-
action = np.random.choice(valid_actions)
|
| 174 |
-
|
| 175 |
-
obs, reward, terminated, truncated, info = env.step(action)
|
| 176 |
-
actions_taken.append(action)
|
| 177 |
-
total_reward += reward
|
| 178 |
-
step_count += 1
|
| 179 |
-
|
| 180 |
-
if render:
|
| 181 |
-
print(f"Step {step_count}: Action {action}, Reward: {reward}")
|
| 182 |
-
env.render()
|
| 183 |
-
|
| 184 |
-
if terminated or truncated:
|
| 185 |
-
break
|
| 186 |
-
|
| 187 |
-
results = {
|
| 188 |
-
"game_name": game_name,
|
| 189 |
-
"total_reward": total_reward,
|
| 190 |
-
"step_count": step_count,
|
| 191 |
-
"actions_taken": actions_taken,
|
| 192 |
-
"winner": getattr(env, 'winner', None),
|
| 193 |
-
"final_info": info
|
| 194 |
-
}
|
| 195 |
-
|
| 196 |
-
env.close()
|
| 197 |
-
return results
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
if __name__ == "__main__":
|
| 201 |
-
# Test the utilities
|
| 202 |
-
print("Available games:", get_available_games())
|
| 203 |
-
|
| 204 |
-
for game_name in get_available_games():
|
| 205 |
-
print(f"\n{game_name.upper()} Info:")
|
| 206 |
-
info = get_game_info(game_name)
|
| 207 |
-
for key, value in info.items():
|
| 208 |
-
print(f" {key}: {value}")
|
| 209 |
-
|
| 210 |
-
# Play a random game
|
| 211 |
-
print("\nPlaying random TicTacToe game:")
|
| 212 |
-
result = play_random_game("tictactoe", render=True, seed=42)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/games/kuhn_poker.py
DELETED
|
@@ -1,314 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Kuhn Poker Game Environment
|
| 3 |
-
|
| 4 |
-
A simple Kuhn Poker implementation using Gymnasium for SPIRAL training.
|
| 5 |
-
Kuhn Poker is a simplified poker variant with 3 cards (J, Q, K).
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import gymnasium as gym
|
| 9 |
-
import numpy as np
|
| 10 |
-
from gymnasium import spaces
|
| 11 |
-
from typing import Tuple, Dict, Any, Optional, List
|
| 12 |
-
import random
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
class KuhnPokerEnv(gym.Env):
|
| 16 |
-
"""
|
| 17 |
-
Kuhn Poker environment for SPIRAL training.
|
| 18 |
-
|
| 19 |
-
Rules:
|
| 20 |
-
- 3 cards: Jack (0), Queen (1), King (2)
|
| 21 |
-
- Each player gets 1 card
|
| 22 |
-
- Each player antes 1 chip
|
| 23 |
-
- Player 1 acts first: Check or Bet
|
| 24 |
-
- Player 2 then acts: Check, Call, or Fold
|
| 25 |
-
- If both check, high card wins
|
| 26 |
-
- If one bets and other calls, high card wins
|
| 27 |
-
- If one bets and other folds, bettor wins
|
| 28 |
-
|
| 29 |
-
Action space: [Check/Call=0, Bet=1, Fold=2]
|
| 30 |
-
Observation space: [player_card, opponent_action, betting_round]
|
| 31 |
-
"""
|
| 32 |
-
|
| 33 |
-
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}
|
| 34 |
-
|
| 35 |
-
# Card values: Jack=0, Queen=1, King=2
|
| 36 |
-
JACK, QUEEN, KING = 0, 1, 2
|
| 37 |
-
CARDS = [JACK, QUEEN, KING]
|
| 38 |
-
CARD_NAMES = ["J", "Q", "K"]
|
| 39 |
-
|
| 40 |
-
# Actions
|
| 41 |
-
CHECK_CALL, BET, FOLD = 0, 1, 2
|
| 42 |
-
ACTION_NAMES = ["Check/Call", "Bet", "Fold"]
|
| 43 |
-
|
| 44 |
-
def __init__(self, render_mode: Optional[str] = None):
|
| 45 |
-
super().__init__()
|
| 46 |
-
|
| 47 |
-
# Observation: [player_card, opponent_last_action, betting_round, pot_size]
|
| 48 |
-
self.observation_space = spaces.Box(
|
| 49 |
-
low=0, high=10, shape=(4,), dtype=np.int8
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
# Actions: Check/Call, Bet, Fold
|
| 53 |
-
self.action_space = spaces.Discrete(3)
|
| 54 |
-
|
| 55 |
-
self.render_mode = render_mode
|
| 56 |
-
self.reset()
|
| 57 |
-
|
| 58 |
-
def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
|
| 59 |
-
"""Reset the game to initial state."""
|
| 60 |
-
super().reset(seed=seed)
|
| 61 |
-
|
| 62 |
-
# Deal cards
|
| 63 |
-
cards = self.CARDS.copy()
|
| 64 |
-
random.shuffle(cards)
|
| 65 |
-
self.player1_card = cards[0]
|
| 66 |
-
self.player2_card = cards[1]
|
| 67 |
-
|
| 68 |
-
# Game state
|
| 69 |
-
self.current_player = 1 # Player 1 starts
|
| 70 |
-
self.pot = 2 # Each player antes 1
|
| 71 |
-
self.player1_bet = 1 # Ante
|
| 72 |
-
self.player2_bet = 1 # Ante
|
| 73 |
-
self.game_over = False
|
| 74 |
-
self.winner = None
|
| 75 |
-
self.betting_round = 0
|
| 76 |
-
self.actions_history = []
|
| 77 |
-
|
| 78 |
-
observation = self._get_observation()
|
| 79 |
-
info = self._get_info()
|
| 80 |
-
|
| 81 |
-
return observation, info
|
| 82 |
-
|
| 83 |
-
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
|
| 84 |
-
"""
|
| 85 |
-
Execute one step in the environment.
|
| 86 |
-
|
| 87 |
-
Args:
|
| 88 |
-
action: 0=Check/Call, 1=Bet, 2=Fold
|
| 89 |
-
|
| 90 |
-
Returns:
|
| 91 |
-
observation, reward, terminated, truncated, info
|
| 92 |
-
"""
|
| 93 |
-
if self.game_over:
|
| 94 |
-
raise ValueError("Game is already over. Call reset() to start new game.")
|
| 95 |
-
|
| 96 |
-
# Record action
|
| 97 |
-
self.actions_history.append((self.current_player, action))
|
| 98 |
-
|
| 99 |
-
# Process action
|
| 100 |
-
if action == self.FOLD:
|
| 101 |
-
# Current player folds, opponent wins
|
| 102 |
-
self.game_over = True
|
| 103 |
-
self.winner = 2 if self.current_player == 1 else 1
|
| 104 |
-
reward = self._calculate_reward()
|
| 105 |
-
|
| 106 |
-
elif action == self.BET:
|
| 107 |
-
# Current player bets
|
| 108 |
-
if self.current_player == 1:
|
| 109 |
-
self.player1_bet += 1
|
| 110 |
-
self.pot += 1
|
| 111 |
-
else:
|
| 112 |
-
self.player2_bet += 1
|
| 113 |
-
self.pot += 1
|
| 114 |
-
|
| 115 |
-
# Check if this ends the betting round
|
| 116 |
-
if self.betting_round == 0:
|
| 117 |
-
# First bet, opponent gets to act
|
| 118 |
-
self.current_player = 2
|
| 119 |
-
self.betting_round = 1
|
| 120 |
-
reward = 0.0
|
| 121 |
-
else:
|
| 122 |
-
# Second bet (raise), go to showdown
|
| 123 |
-
self.game_over = True
|
| 124 |
-
self.winner = self._determine_winner_by_cards()
|
| 125 |
-
reward = self._calculate_reward()
|
| 126 |
-
|
| 127 |
-
else: # CHECK_CALL
|
| 128 |
-
if self.betting_round == 0:
|
| 129 |
-
# First action is check
|
| 130 |
-
if self.current_player == 1:
|
| 131 |
-
# Player 1 checks, player 2 acts
|
| 132 |
-
self.current_player = 2
|
| 133 |
-
self.betting_round = 1
|
| 134 |
-
reward = 0.0
|
| 135 |
-
else:
|
| 136 |
-
# Player 2 checks after player 1 checked, showdown
|
| 137 |
-
self.game_over = True
|
| 138 |
-
self.winner = self._determine_winner_by_cards()
|
| 139 |
-
reward = self._calculate_reward()
|
| 140 |
-
else:
|
| 141 |
-
# This is a call
|
| 142 |
-
if self.current_player == 2:
|
| 143 |
-
# Player 2 calls player 1's bet
|
| 144 |
-
self.player2_bet = self.player1_bet
|
| 145 |
-
self.pot = self.player1_bet + self.player2_bet
|
| 146 |
-
self.game_over = True
|
| 147 |
-
self.winner = self._determine_winner_by_cards()
|
| 148 |
-
reward = self._calculate_reward()
|
| 149 |
-
else:
|
| 150 |
-
# Player 1 calls player 2's bet
|
| 151 |
-
self.player1_bet = self.player2_bet
|
| 152 |
-
self.pot = self.player1_bet + self.player2_bet
|
| 153 |
-
self.game_over = True
|
| 154 |
-
self.winner = self._determine_winner_by_cards()
|
| 155 |
-
reward = self._calculate_reward()
|
| 156 |
-
|
| 157 |
-
observation = self._get_observation()
|
| 158 |
-
info = self._get_info()
|
| 159 |
-
|
| 160 |
-
return observation, reward, self.game_over, False, info
|
| 161 |
-
|
| 162 |
-
def _get_observation(self) -> np.ndarray:
|
| 163 |
-
"""Get current observation for the current player."""
|
| 164 |
-
# Get current player's card
|
| 165 |
-
player_card = self.player1_card if self.current_player == 1 else self.player2_card
|
| 166 |
-
|
| 167 |
-
# Get opponent's last action (if any)
|
| 168 |
-
opponent_last_action = -1
|
| 169 |
-
if self.actions_history:
|
| 170 |
-
for player, action in reversed(self.actions_history):
|
| 171 |
-
if player != self.current_player:
|
| 172 |
-
opponent_last_action = action
|
| 173 |
-
break
|
| 174 |
-
|
| 175 |
-
# Observation: [player_card, opponent_last_action, betting_round, pot_size]
|
| 176 |
-
observation = np.array([
|
| 177 |
-
player_card,
|
| 178 |
-
opponent_last_action + 1, # -1 becomes 0, 0 becomes 1, etc.
|
| 179 |
-
self.betting_round,
|
| 180 |
-
self.pot
|
| 181 |
-
], dtype=np.int8)
|
| 182 |
-
|
| 183 |
-
return observation
|
| 184 |
-
|
| 185 |
-
def _get_info(self) -> Dict[str, Any]:
|
| 186 |
-
"""Get additional info about the game state."""
|
| 187 |
-
return {
|
| 188 |
-
"current_player": self.current_player,
|
| 189 |
-
"game_over": self.game_over,
|
| 190 |
-
"winner": self.winner,
|
| 191 |
-
"player1_card": self.player1_card,
|
| 192 |
-
"player2_card": self.player2_card,
|
| 193 |
-
"pot": self.pot,
|
| 194 |
-
"betting_round": self.betting_round,
|
| 195 |
-
"actions_history": self.actions_history.copy(),
|
| 196 |
-
"valid_actions": self._get_valid_actions()
|
| 197 |
-
}
|
| 198 |
-
|
| 199 |
-
def _get_valid_actions(self) -> List[int]:
|
| 200 |
-
"""Get list of valid actions."""
|
| 201 |
-
if self.game_over:
|
| 202 |
-
return []
|
| 203 |
-
|
| 204 |
-
# All actions are always valid in Kuhn Poker
|
| 205 |
-
return [self.CHECK_CALL, self.BET, self.FOLD]
|
| 206 |
-
|
| 207 |
-
def _determine_winner_by_cards(self) -> int:
|
| 208 |
-
"""Determine winner by comparing cards."""
|
| 209 |
-
if self.player1_card > self.player2_card:
|
| 210 |
-
return 1
|
| 211 |
-
else:
|
| 212 |
-
return 2
|
| 213 |
-
|
| 214 |
-
def _calculate_reward(self) -> float:
|
| 215 |
-
"""Calculate reward for the current player."""
|
| 216 |
-
if not self.game_over:
|
| 217 |
-
return 0.0
|
| 218 |
-
|
| 219 |
-
if self.winner == self.current_player:
|
| 220 |
-
# Won - get the pot minus what you put in
|
| 221 |
-
if self.current_player == 1:
|
| 222 |
-
return float(self.pot - self.player1_bet)
|
| 223 |
-
else:
|
| 224 |
-
return float(self.pot - self.player2_bet)
|
| 225 |
-
else:
|
| 226 |
-
# Lost - lose what you put in
|
| 227 |
-
if self.current_player == 1:
|
| 228 |
-
return float(-self.player1_bet)
|
| 229 |
-
else:
|
| 230 |
-
return float(-self.player2_bet)
|
| 231 |
-
|
| 232 |
-
def render(self) -> Optional[np.ndarray]:
|
| 233 |
-
"""Render the game state."""
|
| 234 |
-
if self.render_mode == "human":
|
| 235 |
-
self._render_human()
|
| 236 |
-
elif self.render_mode == "rgb_array":
|
| 237 |
-
return self._render_rgb_array()
|
| 238 |
-
|
| 239 |
-
def _render_human(self):
|
| 240 |
-
"""Print the game state to console."""
|
| 241 |
-
print("\n" + "="*40)
|
| 242 |
-
print("KUHN POKER")
|
| 243 |
-
print("="*40)
|
| 244 |
-
print(f"Player 1 Card: {self.CARD_NAMES[self.player1_card]}")
|
| 245 |
-
print(f"Player 2 Card: {self.CARD_NAMES[self.player2_card]}")
|
| 246 |
-
print(f"Pot: {self.pot}")
|
| 247 |
-
print(f"Current Player: {self.current_player}")
|
| 248 |
-
print(f"Betting Round: {self.betting_round}")
|
| 249 |
-
|
| 250 |
-
if self.actions_history:
|
| 251 |
-
print("Actions:")
|
| 252 |
-
for player, action in self.actions_history:
|
| 253 |
-
print(f" Player {player}: {self.ACTION_NAMES[action]}")
|
| 254 |
-
|
| 255 |
-
if self.game_over:
|
| 256 |
-
print(f"Game Over! Winner: Player {self.winner}")
|
| 257 |
-
print("="*40)
|
| 258 |
-
|
| 259 |
-
def _render_rgb_array(self) -> np.ndarray:
|
| 260 |
-
"""Render as RGB array for visualization."""
|
| 261 |
-
# Simple RGB representation (placeholder)
|
| 262 |
-
rgb = np.zeros((100, 100, 3), dtype=np.uint8)
|
| 263 |
-
|
| 264 |
-
# Color based on current player's card
|
| 265 |
-
if self.current_player == 1:
|
| 266 |
-
card_value = self.player1_card
|
| 267 |
-
else:
|
| 268 |
-
card_value = self.player2_card
|
| 269 |
-
|
| 270 |
-
# Different colors for different cards
|
| 271 |
-
if card_value == self.JACK:
|
| 272 |
-
rgb[:, :] = [255, 0, 0] # Red for Jack
|
| 273 |
-
elif card_value == self.QUEEN:
|
| 274 |
-
rgb[:, :] = [0, 255, 0] # Green for Queen
|
| 275 |
-
else: # King
|
| 276 |
-
rgb[:, :] = [0, 0, 255] # Blue for King
|
| 277 |
-
|
| 278 |
-
return rgb
|
| 279 |
-
|
| 280 |
-
def get_action_mask(self) -> np.ndarray:
|
| 281 |
-
"""Get mask of valid actions (1 for valid, 0 for invalid)."""
|
| 282 |
-
mask = np.zeros(3, dtype=np.int8)
|
| 283 |
-
for action in self._get_valid_actions():
|
| 284 |
-
mask[action] = 1
|
| 285 |
-
return mask
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
def create_kuhn_poker_env() -> KuhnPokerEnv:
|
| 289 |
-
"""Factory function to create a Kuhn Poker environment."""
|
| 290 |
-
return KuhnPokerEnv()
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
if __name__ == "__main__":
|
| 294 |
-
# Test the environment
|
| 295 |
-
env = KuhnPokerEnv(render_mode="human")
|
| 296 |
-
|
| 297 |
-
# Play a simple game
|
| 298 |
-
obs, info = env.reset()
|
| 299 |
-
print("Initial state:")
|
| 300 |
-
env.render()
|
| 301 |
-
|
| 302 |
-
# Simulate some moves
|
| 303 |
-
while not env.game_over:
|
| 304 |
-
valid_actions = env._get_valid_actions()
|
| 305 |
-
action = random.choice(valid_actions)
|
| 306 |
-
|
| 307 |
-
obs, reward, terminated, truncated, info = env.step(action)
|
| 308 |
-
print(f"\nPlayer {env.current_player if not env.game_over else 'Previous'} action: {env.ACTION_NAMES[action]}")
|
| 309 |
-
print(f"Reward: {reward}")
|
| 310 |
-
env.render()
|
| 311 |
-
|
| 312 |
-
if terminated:
|
| 313 |
-
print(f"Game terminated! Final reward: {reward}")
|
| 314 |
-
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/games/tictactoe.py
DELETED
|
@@ -1,237 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
TicTacToe Game Environment
|
| 3 |
-
|
| 4 |
-
A simple TicTacToe implementation using Gymnasium for SPIRAL training.
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import gymnasium as gym
|
| 8 |
-
import numpy as np
|
| 9 |
-
from gymnasium import spaces
|
| 10 |
-
from typing import Tuple, Dict, Any, Optional
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class TicTacToeEnv(gym.Env):
|
| 14 |
-
"""
|
| 15 |
-
TicTacToe environment for SPIRAL training.
|
| 16 |
-
|
| 17 |
-
- 3x3 grid
|
| 18 |
-
- Players alternate turns (1 and -1)
|
| 19 |
-
- Action space: 9 positions (0-8)
|
| 20 |
-
- Observation space: 3x3 grid with values {-1, 0, 1}
|
| 21 |
-
- Reward: +1 for win, -1 for loss, 0 for draw/ongoing
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}
|
| 25 |
-
|
| 26 |
-
def __init__(self, render_mode: Optional[str] = None):
|
| 27 |
-
super().__init__()
|
| 28 |
-
|
| 29 |
-
# 3x3 grid, each cell can be -1 (player 2), 0 (empty), or 1 (player 1)
|
| 30 |
-
self.observation_space = spaces.Box(
|
| 31 |
-
low=-1, high=1, shape=(3, 3), dtype=np.int8
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
# 9 possible actions (positions 0-8)
|
| 35 |
-
self.action_space = spaces.Discrete(9)
|
| 36 |
-
|
| 37 |
-
self.render_mode = render_mode
|
| 38 |
-
self.reset()
|
| 39 |
-
|
| 40 |
-
def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
|
| 41 |
-
"""Reset the game to initial state."""
|
| 42 |
-
super().reset(seed=seed)
|
| 43 |
-
|
| 44 |
-
# Initialize empty board
|
| 45 |
-
self.board = np.zeros((3, 3), dtype=np.int8)
|
| 46 |
-
self.current_player = 1 # Player 1 starts
|
| 47 |
-
self.game_over = False
|
| 48 |
-
self.winner = None
|
| 49 |
-
self.move_count = 0
|
| 50 |
-
|
| 51 |
-
observation = self._get_observation()
|
| 52 |
-
info = self._get_info()
|
| 53 |
-
|
| 54 |
-
return observation, info
|
| 55 |
-
|
| 56 |
-
def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
|
| 57 |
-
"""
|
| 58 |
-
Execute one step in the environment.
|
| 59 |
-
|
| 60 |
-
Args:
|
| 61 |
-
action: Position to place mark (0-8)
|
| 62 |
-
|
| 63 |
-
Returns:
|
| 64 |
-
observation, reward, terminated, truncated, info
|
| 65 |
-
"""
|
| 66 |
-
if self.game_over:
|
| 67 |
-
raise ValueError("Game is already over. Call reset() to start new game.")
|
| 68 |
-
|
| 69 |
-
# Convert action to row, col
|
| 70 |
-
row, col = divmod(action, 3)
|
| 71 |
-
|
| 72 |
-
# Check if move is valid
|
| 73 |
-
if self.board[row, col] != 0:
|
| 74 |
-
# Invalid move - penalize and end game
|
| 75 |
-
reward = -1.0
|
| 76 |
-
terminated = True
|
| 77 |
-
self.game_over = True
|
| 78 |
-
info = self._get_info()
|
| 79 |
-
info["invalid_move"] = True
|
| 80 |
-
return self._get_observation(), reward, terminated, False, info
|
| 81 |
-
|
| 82 |
-
# Make the move
|
| 83 |
-
self.board[row, col] = self.current_player
|
| 84 |
-
self.move_count += 1
|
| 85 |
-
|
| 86 |
-
# Check for win
|
| 87 |
-
winner = self._check_winner()
|
| 88 |
-
if winner is not None:
|
| 89 |
-
self.game_over = True
|
| 90 |
-
self.winner = winner
|
| 91 |
-
reward = 1.0 if winner == self.current_player else -1.0
|
| 92 |
-
terminated = True
|
| 93 |
-
elif self.move_count >= 9:
|
| 94 |
-
# Draw
|
| 95 |
-
self.game_over = True
|
| 96 |
-
reward = 0.0
|
| 97 |
-
terminated = True
|
| 98 |
-
else:
|
| 99 |
-
# Game continues
|
| 100 |
-
reward = 0.0
|
| 101 |
-
terminated = False
|
| 102 |
-
self.current_player *= -1 # Switch player
|
| 103 |
-
|
| 104 |
-
observation = self._get_observation()
|
| 105 |
-
info = self._get_info()
|
| 106 |
-
|
| 107 |
-
return observation, reward, terminated, False, info
|
| 108 |
-
|
| 109 |
-
def _get_observation(self) -> np.ndarray:
|
| 110 |
-
"""Get current board state."""
|
| 111 |
-
return self.board.copy()
|
| 112 |
-
|
| 113 |
-
def _get_info(self) -> Dict[str, Any]:
|
| 114 |
-
"""Get additional info about the game state."""
|
| 115 |
-
return {
|
| 116 |
-
"current_player": self.current_player,
|
| 117 |
-
"game_over": self.game_over,
|
| 118 |
-
"winner": self.winner,
|
| 119 |
-
"move_count": self.move_count,
|
| 120 |
-
"valid_actions": self._get_valid_actions()
|
| 121 |
-
}
|
| 122 |
-
|
| 123 |
-
def _get_valid_actions(self) -> list:
|
| 124 |
-
"""Get list of valid actions (empty positions)."""
|
| 125 |
-
valid_actions = []
|
| 126 |
-
for i in range(9):
|
| 127 |
-
row, col = divmod(i, 3)
|
| 128 |
-
if self.board[row, col] == 0:
|
| 129 |
-
valid_actions.append(i)
|
| 130 |
-
return valid_actions
|
| 131 |
-
|
| 132 |
-
def _check_winner(self) -> Optional[int]:
|
| 133 |
-
"""
|
| 134 |
-
Check if there's a winner.
|
| 135 |
-
|
| 136 |
-
Returns:
|
| 137 |
-
1 if player 1 wins, -1 if player 2 wins, None if no winner
|
| 138 |
-
"""
|
| 139 |
-
# Check rows
|
| 140 |
-
for row in range(3):
|
| 141 |
-
if abs(self.board[row, :].sum()) == 3:
|
| 142 |
-
return self.board[row, 0]
|
| 143 |
-
|
| 144 |
-
# Check columns
|
| 145 |
-
for col in range(3):
|
| 146 |
-
if abs(self.board[:, col].sum()) == 3:
|
| 147 |
-
return self.board[0, col]
|
| 148 |
-
|
| 149 |
-
# Check diagonals
|
| 150 |
-
if abs(self.board.diagonal().sum()) == 3:
|
| 151 |
-
return self.board[0, 0]
|
| 152 |
-
|
| 153 |
-
if abs(np.fliplr(self.board).diagonal().sum()) == 3:
|
| 154 |
-
return self.board[0, 2]
|
| 155 |
-
|
| 156 |
-
return None
|
| 157 |
-
|
| 158 |
-
def render(self) -> Optional[np.ndarray]:
|
| 159 |
-
"""Render the game state."""
|
| 160 |
-
if self.render_mode == "human":
|
| 161 |
-
self._render_human()
|
| 162 |
-
elif self.render_mode == "rgb_array":
|
| 163 |
-
return self._render_rgb_array()
|
| 164 |
-
|
| 165 |
-
def _render_human(self):
|
| 166 |
-
"""Print the board to console."""
|
| 167 |
-
print("\n" + "="*13)
|
| 168 |
-
for row in range(3):
|
| 169 |
-
print("|", end="")
|
| 170 |
-
for col in range(3):
|
| 171 |
-
cell = self.board[row, col]
|
| 172 |
-
if cell == 1:
|
| 173 |
-
print(" X ", end="|")
|
| 174 |
-
elif cell == -1:
|
| 175 |
-
print(" O ", end="|")
|
| 176 |
-
else:
|
| 177 |
-
print(f" {row*3 + col} ", end="|")
|
| 178 |
-
print()
|
| 179 |
-
print("="*13)
|
| 180 |
-
|
| 181 |
-
if self.game_over:
|
| 182 |
-
if self.winner is not None:
|
| 183 |
-
winner_symbol = "X" if self.winner == 1 else "O"
|
| 184 |
-
print(f"Game Over! Winner: {winner_symbol}")
|
| 185 |
-
else:
|
| 186 |
-
print("Game Over! It's a draw!")
|
| 187 |
-
|
| 188 |
-
def _render_rgb_array(self) -> np.ndarray:
|
| 189 |
-
"""Render as RGB array for visualization."""
|
| 190 |
-
# Simple RGB representation
|
| 191 |
-
rgb = np.zeros((3, 3, 3), dtype=np.uint8)
|
| 192 |
-
|
| 193 |
-
# Player 1 (X) = Red, Player 2 (O) = Blue, Empty = White
|
| 194 |
-
for row in range(3):
|
| 195 |
-
for col in range(3):
|
| 196 |
-
if self.board[row, col] == 1:
|
| 197 |
-
rgb[row, col] = [255, 0, 0] # Red
|
| 198 |
-
elif self.board[row, col] == -1:
|
| 199 |
-
rgb[row, col] = [0, 0, 255] # Blue
|
| 200 |
-
else:
|
| 201 |
-
rgb[row, col] = [255, 255, 255] # White
|
| 202 |
-
|
| 203 |
-
return rgb
|
| 204 |
-
|
| 205 |
-
def get_action_mask(self) -> np.ndarray:
|
| 206 |
-
"""Get mask of valid actions (1 for valid, 0 for invalid)."""
|
| 207 |
-
mask = np.zeros(9, dtype=np.int8)
|
| 208 |
-
for action in self._get_valid_actions():
|
| 209 |
-
mask[action] = 1
|
| 210 |
-
return mask
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
def create_tictactoe_env() -> TicTacToeEnv:
|
| 214 |
-
"""Factory function to create a TicTacToe environment."""
|
| 215 |
-
return TicTacToeEnv()
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
if __name__ == "__main__":
|
| 219 |
-
# Test the environment
|
| 220 |
-
env = TicTacToeEnv(render_mode="human")
|
| 221 |
-
|
| 222 |
-
# Play a simple game
|
| 223 |
-
obs, info = env.reset()
|
| 224 |
-
print("Initial state:")
|
| 225 |
-
env.render()
|
| 226 |
-
|
| 227 |
-
# Make some moves
|
| 228 |
-
moves = [0, 4, 1, 3, 2] # X wins
|
| 229 |
-
for move in moves:
|
| 230 |
-
if not env.game_over:
|
| 231 |
-
obs, reward, terminated, truncated, info = env.step(move)
|
| 232 |
-
print(f"\nMove: {move}, Reward: {reward}")
|
| 233 |
-
env.render()
|
| 234 |
-
|
| 235 |
-
if terminated:
|
| 236 |
-
print(f"Game terminated! Final reward: {reward}")
|
| 237 |
-
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/models/__init__.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SPIRAL model implementations.
|
| 3 |
-
|
| 4 |
-
This module contains the core SPIRAL model architecture and
|
| 5 |
-
role-conditioned advantage estimation (RAE) components.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from .spiral_model import SpiralModel
|
| 9 |
-
from .rae import RoleConditionedAdvantageEstimator
|
| 10 |
-
from .policy_network import PolicyNetwork
|
| 11 |
-
from .value_network import ValueNetwork
|
| 12 |
-
|
| 13 |
-
__all__ = ["SpiralModel", "RoleConditionedAdvantageEstimator", "PolicyNetwork", "ValueNetwork"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/reasoning/__init__.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Reasoning trace generation and analysis.
|
| 3 |
-
|
| 4 |
-
This module handles the generation of step-by-step reasoning traces
|
| 5 |
-
during gameplay and transfer to non-game tasks.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from .trace_generator import TraceGenerator
|
| 9 |
-
from .chain_of_thought import ChainOfThought
|
| 10 |
-
from .transfer_evaluator import TransferEvaluator
|
| 11 |
-
from .reasoning_utils import ReasoningUtils
|
| 12 |
-
|
| 13 |
-
__all__ = ["TraceGenerator", "ChainOfThought", "TransferEvaluator", "ReasoningUtils"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/__init__.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Training components for SPIRAL.
|
| 3 |
-
|
| 4 |
-
This module implements the self-play training logic using PPO
|
| 5 |
-
with role-conditioned advantage estimation.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
from .self_play_trainer import SelfPlayTrainer
|
| 9 |
-
from .ppo_trainer import PPOTrainer
|
| 10 |
-
from .opponent_manager import OpponentManager
|
| 11 |
-
from .training_utils import TrainingUtils
|
| 12 |
-
|
| 13 |
-
__all__ = ["SelfPlayTrainer", "PPOTrainer", "OpponentManager", "TrainingUtils"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/training/train_spiral.py
DELETED
|
@@ -1,58 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
import gymnasium as gym
|
| 4 |
-
import numpy as np
|
| 5 |
-
from stable_baselines3 import PPO
|
| 6 |
-
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 7 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
-
import yaml
|
| 9 |
-
|
| 10 |
-
# Load config
|
| 11 |
-
with open('../../config.yaml', 'r') as f:
|
| 12 |
-
config = yaml.safe_load(f)
|
| 13 |
-
|
| 14 |
-
model_name = config['model']['name']
|
| 15 |
-
max_length = config['model']['max_length']
|
| 16 |
-
|
| 17 |
-
# Load base LLM (quantized)
|
| 18 |
-
model = AutoModelForCausalLM.from_pretrained(model_name, **config['model']['quantization'])
|
| 19 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 20 |
-
|
| 21 |
-
# Custom Policy with RAE (simplified)
|
| 22 |
-
class SpiralPolicy(torch.nn.Module):
|
| 23 |
-
def __init__(self, observation_space, action_space):
|
| 24 |
-
super().__init__()
|
| 25 |
-
self.role_embed = torch.nn.Embedding(2, 64) # 0: player, 1: opponent
|
| 26 |
-
# Add more layers as needed
|
| 27 |
-
|
| 28 |
-
def forward(self, obs, role):
|
| 29 |
-
# Condition on role
|
| 30 |
-
role_emb = self.role_embed(role)
|
| 31 |
-
# Compute policy/value (placeholder)
|
| 32 |
-
return policy, value
|
| 33 |
-
|
| 34 |
-
def train_spiral(game='tictactoe', episodes=1000):
|
| 35 |
-
if game == 'tictactoe':
|
| 36 |
-
from src.games.tictactoe import TicTacToeEnv
|
| 37 |
-
env_fn = lambda: TicTacToeEnv()
|
| 38 |
-
else:
|
| 39 |
-
raise ValueError('Game not supported yet')
|
| 40 |
-
|
| 41 |
-
env = DummyVecEnv([env_fn])
|
| 42 |
-
|
| 43 |
-
# PPO with custom policy
|
| 44 |
-
model = PPO('MlpPolicy', env, verbose=1, learning_rate=0.0003)
|
| 45 |
-
|
| 46 |
-
# Self-play loop (simplified: train against current self)
|
| 47 |
-
for ep in range(episodes):
|
| 48 |
-
model.learn(total_timesteps=1000) # Train batch
|
| 49 |
-
# Simulate self-play by cloning or saving opponent policy
|
| 50 |
-
print(f'Episode {ep}: Trained')
|
| 51 |
-
|
| 52 |
-
# Save model
|
| 53 |
-
os.makedirs('../../models', exist_ok=True)
|
| 54 |
-
model.save('../../models/spiral_tictactoe.zip')
|
| 55 |
-
print('Model saved!')
|
| 56 |
-
|
| 57 |
-
if __name__ == '__main__':
|
| 58 |
-
train_spiral()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_basic.py
DELETED
|
@@ -1,130 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Basic tests for SPIRAL Interactive Reasoning Game Simulator.
|
| 3 |
-
|
| 4 |
-
This module contains fundamental tests to verify the core functionality
|
| 5 |
-
of the SPIRAL system components.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import pytest
|
| 9 |
-
import os
|
| 10 |
-
import sys
|
| 11 |
-
import yaml
|
| 12 |
-
|
| 13 |
-
# Add the src directory to the path for imports
|
| 14 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
|
| 15 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'app'))
|
| 16 |
-
|
| 17 |
-
from app import SpiralApp
|
| 18 |
-
|
| 19 |
-
class TestSpiralApp:
|
| 20 |
-
"""Test cases for the main SPIRAL application."""
|
| 21 |
-
|
| 22 |
-
def test_app_initialization(self):
|
| 23 |
-
"""Test that the app initializes correctly."""
|
| 24 |
-
app = SpiralApp()
|
| 25 |
-
assert app is not None
|
| 26 |
-
assert hasattr(app, 'config')
|
| 27 |
-
assert hasattr(app, 'play_game')
|
| 28 |
-
assert hasattr(app, 'test_reasoning')
|
| 29 |
-
|
| 30 |
-
def test_config_loading(self):
|
| 31 |
-
"""Test configuration loading."""
|
| 32 |
-
app = SpiralApp()
|
| 33 |
-
assert 'interface' in app.config
|
| 34 |
-
assert 'games' in app.config
|
| 35 |
-
assert app.config['interface']['title'] is not None
|
| 36 |
-
|
| 37 |
-
def test_play_game_basic(self):
|
| 38 |
-
"""Test basic game play functionality."""
|
| 39 |
-
app = SpiralApp()
|
| 40 |
-
|
| 41 |
-
# Test with valid input
|
| 42 |
-
state, response, trace = app.play_game("kuhn_poker", "bet", "")
|
| 43 |
-
assert state is not None
|
| 44 |
-
assert response is not None
|
| 45 |
-
assert trace is not None
|
| 46 |
-
assert "bet" in state
|
| 47 |
-
|
| 48 |
-
# Test with empty input
|
| 49 |
-
state, response, trace = app.play_game("kuhn_poker", "", "")
|
| 50 |
-
assert "Please enter a move!" in response
|
| 51 |
-
|
| 52 |
-
def test_reasoning_basic(self):
|
| 53 |
-
"""Test basic reasoning functionality."""
|
| 54 |
-
app = SpiralApp()
|
| 55 |
-
|
| 56 |
-
# Test with valid input
|
| 57 |
-
response, trace = app.test_reasoning("What is 2+2?", "math")
|
| 58 |
-
assert response is not None
|
| 59 |
-
assert trace is not None
|
| 60 |
-
assert "2+2" in response
|
| 61 |
-
|
| 62 |
-
# Test with empty input
|
| 63 |
-
response, trace = app.test_reasoning("", "math")
|
| 64 |
-
assert "Please enter a reasoning prompt!" in response
|
| 65 |
-
|
| 66 |
-
def test_interface_creation(self):
|
| 67 |
-
"""Test that the Gradio interface can be created."""
|
| 68 |
-
app = SpiralApp()
|
| 69 |
-
demo = app.create_interface()
|
| 70 |
-
assert demo is not None
|
| 71 |
-
|
| 72 |
-
class TestConfiguration:
|
| 73 |
-
"""Test cases for configuration management."""
|
| 74 |
-
|
| 75 |
-
def test_config_file_structure(self):
|
| 76 |
-
"""Test that config.yaml has the expected structure."""
|
| 77 |
-
config_path = os.path.join(os.path.dirname(__file__), '..', 'config.yaml')
|
| 78 |
-
|
| 79 |
-
if os.path.exists(config_path):
|
| 80 |
-
with open(config_path, 'r') as f:
|
| 81 |
-
config = yaml.safe_load(f)
|
| 82 |
-
|
| 83 |
-
# Check required sections
|
| 84 |
-
assert 'model' in config
|
| 85 |
-
assert 'games' in config
|
| 86 |
-
assert 'training' in config
|
| 87 |
-
assert 'reasoning' in config
|
| 88 |
-
assert 'interface' in config
|
| 89 |
-
|
| 90 |
-
# Check model configuration
|
| 91 |
-
assert 'name' in config['model']
|
| 92 |
-
assert 'max_length' in config['model']
|
| 93 |
-
|
| 94 |
-
# Check games configuration
|
| 95 |
-
assert 'kuhn_poker' in config['games']
|
| 96 |
-
assert 'tictactoe' in config['games']
|
| 97 |
-
|
| 98 |
-
class TestProjectStructure:
|
| 99 |
-
"""Test cases for project structure and imports."""
|
| 100 |
-
|
| 101 |
-
def test_src_directory_structure(self):
|
| 102 |
-
"""Test that the src directory has the expected structure."""
|
| 103 |
-
src_path = os.path.join(os.path.dirname(__file__), '..', 'src')
|
| 104 |
-
|
| 105 |
-
# Check that required directories exist
|
| 106 |
-
assert os.path.exists(os.path.join(src_path, 'games'))
|
| 107 |
-
assert os.path.exists(os.path.join(src_path, 'models'))
|
| 108 |
-
assert os.path.exists(os.path.join(src_path, 'training'))
|
| 109 |
-
assert os.path.exists(os.path.join(src_path, 'reasoning'))
|
| 110 |
-
|
| 111 |
-
# Check that __init__.py files exist
|
| 112 |
-
assert os.path.exists(os.path.join(src_path, '__init__.py'))
|
| 113 |
-
assert os.path.exists(os.path.join(src_path, 'games', '__init__.py'))
|
| 114 |
-
assert os.path.exists(os.path.join(src_path, 'models', '__init__.py'))
|
| 115 |
-
assert os.path.exists(os.path.join(src_path, 'training', '__init__.py'))
|
| 116 |
-
assert os.path.exists(os.path.join(src_path, 'reasoning', '__init__.py'))
|
| 117 |
-
|
| 118 |
-
def test_required_files_exist(self):
|
| 119 |
-
"""Test that required project files exist."""
|
| 120 |
-
project_root = os.path.join(os.path.dirname(__file__), '..')
|
| 121 |
-
|
| 122 |
-
# Check essential files
|
| 123 |
-
assert os.path.exists(os.path.join(project_root, 'requirements.txt'))
|
| 124 |
-
assert os.path.exists(os.path.join(project_root, 'README.md'))
|
| 125 |
-
assert os.path.exists(os.path.join(project_root, 'config.yaml'))
|
| 126 |
-
assert os.path.exists(os.path.join(project_root, '.gitignore'))
|
| 127 |
-
assert os.path.exists(os.path.join(project_root, 'app', 'app.py'))
|
| 128 |
-
|
| 129 |
-
if __name__ == "__main__":
|
| 130 |
-
pytest.main([__file__])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests/test_games.py
DELETED
|
@@ -1,78 +0,0 @@
|
|
| 1 |
-
import pytest
|
| 2 |
-
import numpy as np
|
| 3 |
-
import sys
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
# Add src to path to allow importing TicTacToeEnv
|
| 7 |
-
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
| 8 |
-
|
| 9 |
-
from games.tictactoe import TicTacToeEnv
|
| 10 |
-
|
| 11 |
-
@pytest.fixture
|
| 12 |
-
def env():
|
| 13 |
-
"""Fixture to create a fresh TicTacToeEnv for each test."""
|
| 14 |
-
return TicTacToeEnv()
|
| 15 |
-
|
| 16 |
-
def test_initial_state(env):
|
| 17 |
-
"""Test the initial state of the board."""
|
| 18 |
-
assert np.all(env.board == np.zeros((3, 3)))
|
| 19 |
-
assert env.current_player == 1
|
| 20 |
-
assert not env.game_over
|
| 21 |
-
|
| 22 |
-
def test_player_move(env):
|
| 23 |
-
"""Test a valid player move."""
|
| 24 |
-
env.step(0)
|
| 25 |
-
assert env.board[0, 0] == 1
|
| 26 |
-
assert env.current_player == -1
|
| 27 |
-
assert not env.game_over
|
| 28 |
-
|
| 29 |
-
def test_invalid_move(env):
|
| 30 |
-
"""Test making an invalid move on an occupied cell."""
|
| 31 |
-
env.step(0)
|
| 32 |
-
with pytest.raises(ValueError):
|
| 33 |
-
env.step(0)
|
| 34 |
-
|
| 35 |
-
def test_win_condition_row(env):
|
| 36 |
-
"""Test a win condition in a row."""
|
| 37 |
-
env.board = np.array([[1, 1, 1], [0, -1, 0], [-1, 0, 0]])
|
| 38 |
-
assert env._check_winner(1)
|
| 39 |
-
assert not env._check_winner(-1)
|
| 40 |
-
|
| 41 |
-
def test_win_condition_col(env):
|
| 42 |
-
"""Test a win condition in a column."""
|
| 43 |
-
env.board = np.array([[-1, 1, 0], [-1, 1, 0], [-1, 0, 0]])
|
| 44 |
-
assert not env._check_winner(1)
|
| 45 |
-
assert env._check_winner(-1)
|
| 46 |
-
|
| 47 |
-
def test_win_condition_diag(env):
|
| 48 |
-
"""Test a win condition on a diagonal."""
|
| 49 |
-
env.board = np.array([[1, 0, -1], [0, 1, -1], [0, 0, 1]])
|
| 50 |
-
assert env._check_winner(1)
|
| 51 |
-
|
| 52 |
-
def test_draw_condition(env):
|
| 53 |
-
"""Test a draw condition."""
|
| 54 |
-
env.board = np.array([[1, -1, 1], [1, -1, 1], [-1, 1, -1]])
|
| 55 |
-
assert env._is_draw()
|
| 56 |
-
assert not env._check_winner(1)
|
| 57 |
-
assert not env._check_winner(-1)
|
| 58 |
-
|
| 59 |
-
def test_game_over_on_win(env):
|
| 60 |
-
"""Test that the game_over flag is set on a win."""
|
| 61 |
-
env.step(0) # P1
|
| 62 |
-
env.step(3) # P2
|
| 63 |
-
env.step(1) # P1
|
| 64 |
-
env.step(4) # P2
|
| 65 |
-
_, _, terminated, _, _ = env.step(2) # P1 wins
|
| 66 |
-
assert terminated
|
| 67 |
-
assert env.game_over
|
| 68 |
-
assert env.winner == 1
|
| 69 |
-
|
| 70 |
-
def test_reset(env):
|
| 71 |
-
"""Test if the environment resets correctly."""
|
| 72 |
-
env.step(0)
|
| 73 |
-
env.step(1)
|
| 74 |
-
env.reset()
|
| 75 |
-
assert np.all(env.board == np.zeros((3, 3)))
|
| 76 |
-
assert env.current_player == 1
|
| 77 |
-
assert not env.game_over
|
| 78 |
-
assert env.winner is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|