Kaushik Rajan committed
Commit 842d62b · 1 Parent(s): 47b257f

Simplify codebase: focused SPIRAL TicTacToe demo with key research concepts
README.md CHANGED
@@ -11,104 +11,83 @@ license: apache-2.0
 short_description: An interactive reasoning game simulator
 ---

- # SPIRAL: Interactive Reasoning Game Simulator
-
- A practical, interactive tool based on the SPIRAL paper ("Self-Play on Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning") deployed on Hugging Face Spaces.
-
- ## Overview
-
- This tool demonstrates how self-play training on zero-sum games can improve AI reasoning capabilities. Users can:
-
- - **Play Games**: Engage with AI in games like Kuhn Poker and TicTacToe
- - **View Reasoning**: See step-by-step AI reasoning traces during gameplay
- - **Test Transfer**: Evaluate AI's reasoning skills on math problems and logic puzzles
- - **Learn**: Understand AI decision-making through interactive visualizations
-
- ## Features
-
- ### For Non-Technical Users
- - Simple web interface for playing games
- - Visual reasoning explanations
- - Educational tutorials about AI thinking
- - No setup required - runs in browser
-
- ### For Technical Users
- - Access to model weights and training scripts
- - API endpoints for extending the system
- - Custom game integration capabilities
- - Fine-tuning examples and documentation
-
- ## Project Structure
-
- ```
- SPIRAL/
- ├── src/            # Core implementation
- │   ├── games/      # Game environments
- │   ├── models/     # SPIRAL model implementation
- │   ├── training/   # Self-play training logic
- │   └── reasoning/  # Reasoning trace generation
- ├── models/         # Trained model weights
- ├── data/           # Game datasets and benchmarks
- ├── app/            # Gradio web interface
- ├── tests/          # Unit and integration tests
- └── docs/           # Documentation and tutorials
- ```
-
- ## Technology Stack
-
- - **Backend**: Python 3.8+
- - **ML Framework**: PyTorch, Transformers
- - **RL Library**: Gymnasium, Stable Baselines3
- - **Web Interface**: Gradio
- - **Base Model**: Qwen-4B from Hugging Face
- - **Deployment**: Hugging Face Spaces
-
- ## Development Phases
-
- 1. **Research and Planning** ✅
- 2. **Implementation** 🔄
- 3. **Testing and Optimization** 📋
- 4. **Deployment and Documentation** 📋
- 5. **Maintenance and Iteration** 📋
-
- ## Getting Started
-
- ### Prerequisites
- - Python 3.8+
- - PyTorch
- - Hugging Face account (for model access)
-
- ### Installation
- ```bash
- pip install -r requirements.txt
- ```
-
- ### Quick Start
- ```bash
- python app/app.py
- ```
-
- ## Citation
-
- If you use this tool in your research, please cite the original SPIRAL paper:
-
- ```bibtex
- @article{spiral2024,
-   title={Self-Play on Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning},
-   author={[Authors]},
-   journal={[Journal]},
-   year={2024}
- }
- ```
-
- ## License
-
- This project is licensed under the MIT License - see the LICENSE file for details.
-
- ## Contributing
-
- We welcome contributions! Please see CONTRIBUTING.md for guidelines.
-
- ## Support
-
- For issues and questions, please use the GitHub Issues or contact us via Hugging Face Spaces.
+ # SPIRAL: Self-Play Reasoning Demo
+
+ **Demonstrating how strategic reasoning emerges from self-play in zero-sum games**
+
+ Based on: *"Self-Play on Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning"*
+
+ ## 🎮 Interactive Demo
+
+ This simplified demo showcases the key concepts from the SPIRAL research through an interactive TicTacToe game. Watch as the AI demonstrates strategic reasoning using minimax tree search and explains its decision-making process.
+
+ ## 🧠 Key Concepts Demonstrated
+
+ ### Strategic Reasoning
+ - AI uses minimax tree search to evaluate all possible future moves
+ - Demonstrates how optimal strategies emerge from competitive gameplay
+ - Shows explicit reasoning explanations for each move
+
+ ### Self-Play Learning Principles
+ - Zero-sum games create competitive pressure that incentivizes strategic thinking
+ - Multi-agent interactions naturally develop intelligent behavior
+ - Strategic patterns emerge from repeated competitive gameplay
+
+ ### Tree Search & Planning
+ - Minimax algorithm demonstrates formalized strategic reasoning
+ - Look-ahead planning to evaluate future game states
+ - Optimal decision-making under competitive constraints
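The look-ahead planning described above can be sketched concisely. The following is a simplified, self-contained illustration (plain-list board, win/loss/draw scores only), not the demo's exact implementation, which lives in `app.py`:

```python
# Minimal minimax sketch for TicTacToe (illustrative only).
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
         (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
         (0, 4, 8), (2, 4, 6)]              # diagonals

def winner(board):
    """Return 'X' or 'O' if someone has three in a row, else None."""
    for a, b, c in LINES:
        if board[a] != ' ' and board[a] == board[b] == board[c]:
            return board[a]
    return None

def minimax(board, player):
    """Return (score, move) from O's perspective: +1 win, -1 loss, 0 draw."""
    w = winner(board)
    if w == 'O':
        return 1, None
    if w == 'X':
        return -1, None
    moves = [i for i, cell in enumerate(board) if cell == ' ']
    if not moves:
        return 0, None  # board full: draw
    best = None
    for m in moves:
        board[m] = player
        score, _ = minimax(board, 'X' if player == 'O' else 'O')
        board[m] = ' '  # undo the trial move
        if best is None or (player == 'O' and score > best[0]) \
                        or (player == 'X' and score < best[0]):
            best = (score, m)
    return best

# X holds 0 and 1 and threatens the top row; O's only non-losing
# reply is to block at position 2, which leads to a forced draw.
score, move = minimax(list('XX  O    '), 'O')
print(score, move)  # 0 2
```

Exhaustive search is feasible here because TicTacToe's game tree is tiny (well under 9! ≈ 362,880 move sequences), which is what makes the game a clean showcase for look-ahead planning.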
+ ## 🚀 Running the Demo
+
+ ### Local Setup
+ ```bash
+ # Clone the repository
+ git clone https://huggingface.co/spaces/kaushikvr06/reasoning-simulator
+ cd reasoning-simulator
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Run the demo
+ python app.py
+ ```
+
+ ### Hugging Face Spaces
+ The demo is deployed and ready to use at:
+ [https://huggingface.co/spaces/kaushikvr06/reasoning-simulator](https://huggingface.co/spaces/kaushikvr06/reasoning-simulator)
+
+ ## 📝 How It Works
+
+ 1. **Human Move**: Click any square to make your move as X
+ 2. **AI Analysis**: The AI analyzes the game tree using minimax search
+ 3. **Strategic Reasoning**: Watch the AI explain its decision-making process
+ 4. **Optimal Play**: The AI chooses the move that maximizes its winning probability
+ ## 🔬 Research Connection
+
+ This demo illustrates core findings from the SPIRAL methodology:
+
+ - **Zero-sum competitive environments** naturally incentivize strategic reasoning
+ - **Multi-turn planning** emerges from the need to anticipate opponent moves
+ - **Strategic reasoning capabilities** developed through self-play can transfer to general reasoning tasks
+ - **Tree search algorithms** formalize the strategic reasoning process
+
+ ## 🎯 Educational Value
+
+ Perfect for:
+ - Understanding strategic AI decision-making
+ - Learning about game theory and minimax algorithms
+ - Exploring the connection between competition and intelligence
+ - Visualizing how reasoning emerges from strategic gameplay
+
+ ## 📊 Technical Details
+
+ - **Game Environment**: Clean TicTacToe implementation with proper state management
+ - **AI Strategy**: Minimax algorithm with optimal move selection
+ - **Reasoning Display**: Generated explanations of AI strategic thinking
+ - **Interactive Interface**: Real-time game state updates and move explanations
+
+ ---
+
+ *Experience firsthand how strategic reasoning emerges from competitive self-play!*
app.py CHANGED
@@ -1,103 +1,179 @@
 """
 SPIRAL: Interactive Reasoning Game Simulator

- Main Gradio application for the SPIRAL demo on Hugging Face Spaces.
 """

 import gradio as gr
 import numpy as np
 import random
- import os
- import sys
- import traceback
- import yaml
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
- import torch
- import spaces
-
- # Add src to path for imports
- current_dir = os.path.dirname(os.path.abspath(__file__))
- src_path = os.path.join(current_dir, 'src')
- sys.path.insert(0, src_path)
-
- print(f"🔍 Current directory: {current_dir}")
- print(f"🔍 Source path: {src_path}")
- print(f"🔍 Python path: {sys.path[:3]}")  # Show first 3 entries
-
- # Check if src directory exists
- if os.path.exists(src_path):
-     print(f"✅ Source directory exists: {src_path}")
-     games_path = os.path.join(src_path, 'games')
-     if os.path.exists(games_path):
-         print(f"✅ Games directory exists: {games_path}")
-         print(f"📁 Games directory contents: {os.listdir(games_path)}")
-     else:
-         print(f"❌ Games directory not found: {games_path}")
- else:
-     print(f"❌ Source directory not found: {src_path}")
-
- # Try multiple import approaches
- GAMES_AVAILABLE = False
- tictactoe_env = None
- kuhn_env = None
-
- try:
-     # Method 1: Direct import from games module
-     print("🔄 Attempting Method 1: Direct import from games")
-     from games import TicTacToeEnv, KuhnPokerEnv
-     print("✅ Method 1 successful: Imported from games module")
-     GAMES_AVAILABLE = True
- except ImportError as e:
-     print(f"❌ Method 1 failed: {e}")
-
- try:
-     # Method 2: Import from src.games
-     print("🔄 Attempting Method 2: Import from src.games")
-     from src.games import TicTacToeEnv, KuhnPokerEnv
-     print("✅ Method 2 successful: Imported from src.games")
-     GAMES_AVAILABLE = True
- except ImportError as e:
-     print(f"❌ Method 2 failed: {e}")
-
- try:
-     # Method 3: Direct file imports
-     print("🔄 Attempting Method 3: Direct file imports")
-     sys.path.insert(0, games_path)
-     from tictactoe import TicTacToeEnv
-     from kuhn_poker import KuhnPokerEnv
-     print("✅ Method 3 successful: Direct file imports")
-     GAMES_AVAILABLE = True
- except Exception as e:
-     print(f"❌ Method 3 failed: {e}")
-     print("📋 Full traceback:", traceback.format_exc())
-
- if GAMES_AVAILABLE:
-     print("🎮 Game modules successfully imported!")
-     try:
-         # Test instantiation
-         tictactoe_env = TicTacToeEnv()
-         # kuhn_env = KuhnPokerEnv()  # No longer needed
-         print("✅ Game environment created successfully")
-     except Exception as e:
-         print(f"❌ Error creating game environment: {e}")
-         print("📋 Full traceback:", traceback.format_exc())
-         GAMES_AVAILABLE = False
- else:
-     print("❌ All import methods failed - using fallback interface")
-
- # Initialize model and tokenizer as global variables
- model = None
- tokenizer = None
-
- def generate_reasoning(prompt):
-     """Generate reasoning trace using Qwen model."""
-     global model, tokenizer
-     if model is None or tokenizer is None:
-         return "Error: Model not loaded. Please wait for the GPU to be ready."
-
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(**inputs, max_length=150, do_sample=True, temperature=0.7)
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)


 def create_interface():
@@ -155,325 +231,201 @@ def create_interface():
     }
     """

-     with gr.Blocks(title="SPIRAL: Interactive Reasoning Game Simulator", theme=gr.themes.Soft(), css=css) as demo:
-         gr.Markdown("# 🎮 SPIRAL: Interactive Reasoning Game Simulator")
-         gr.Markdown("Play TicTacToe against an AI, see its step-by-step reasoning, and learn how it thinks!")
-
-         if GAMES_AVAILABLE:
-
-             def update_board_buttons():
-                 """Create a list of gr.Button updates from the current board state."""
-                 updates = []
-                 for i in range(9):
-                     row, col = divmod(i, 3)
-                     cell = tictactoe_env.board[row, col]
-                     val = ""
-                     interactive = True
-                     if cell == 1:
-                         val = '❌'
-                         interactive = False
-                     elif cell == -1:
-                         val = '⭕'
-                         interactive = False
-
-                     if tictactoe_env.game_over:
-                         interactive = False
-
-                     updates.append(gr.Button(value=val, interactive=interactive))
-                 return updates
-
-             # TicTacToe specific functions (no longer need get_tictactoe_board_html)
-
-             ttt_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0})
-
-             def minimax(board, player):
-                 """Minimax algorithm to find the best move."""
-                 # Base cases
-                 winner = tictactoe_env._check_winner()
-                 if winner == 1:  # Human wins
-                     return -10, None
-                 elif winner == -1:  # AI wins
-                     return 10, None
-                 elif tictactoe_env._is_draw():
-                     return 0, None
-
-                 best_move = None
-                 if player == -1:  # AI is player -1 (O), maximizing player
-                     best_score = -float('inf')
-                     for move in tictactoe_env._get_valid_actions():
-                         row, col = divmod(move, 3)
-                         board[row, col] = -1
-                         score, _ = minimax(board.copy(), 1)
-                         board[row, col] = 0  # Undo move
-                         if score > best_score:
-                             best_score = score
-                             best_move = move
-                 else:  # Human is player 1 (X), minimizing player
-                     best_score = float('inf')
-                     for move in tictactoe_env._get_valid_actions():
-                         row, col = divmod(move, 3)
-                         board[row, col] = 1
-                         score, _ = minimax(board.copy(), -1)
-                         board[row, col] = 0  # Undo move
-                         if score < best_score:
-                             best_score = score
-                             best_move = move
-                 return best_score, best_move
-
-             def play_tictactoe(position, stats):
-                 """Play a TicTacToe move and yield updates for the button grid."""
-                 if tictactoe_env.game_over:
-                     yield *update_board_buttons(), "Game is over! Click 'New Game' to start again.", "", stats
                     return
-
-                 try:
-                     position = int(position)
-
-                     # Human move
-                     tictactoe_env.step(position)
-
-                     if tictactoe_env.game_over:
-                         winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
-                         if winner == "You": stats['wins'] += 1
-                         elif winner == "AI": stats['losses'] += 1
-                         else: stats['draws'] += 1
-                         yield *update_board_buttons(), f"Game Over! {winner} won!", "", stats
                         return
-
-                     # Show "thinking" indicator
-                     yield *update_board_buttons(), "AI is thinking...", "🧠...", stats
-
-                     # AI move
-                     _, ai_action = minimax(tictactoe_env.board.copy(), -1)
-                     if ai_action is None:
-                         valid_actions = tictactoe_env._get_valid_actions()
-                         if not valid_actions:
-                             yield *update_board_buttons(), "Game is a draw!", "", stats
-                             return
-                         ai_action = random.choice(valid_actions)
-
-                     reasoning_prompt = f"In TicTacToe, the board is currently: {tictactoe_env.board.flatten().tolist()}. The human player (X) played position {position}. I am the AI (O). The available moves are {tictactoe_env._get_valid_actions()}. I have analyzed the game tree using minimax and determined the optimal move is {ai_action}. Explain my strategy."
-                     reasoning = generate_reasoning(reasoning_prompt)
-                     tictactoe_env.step(ai_action)
-
-                     if tictactoe_env.game_over:
-                         winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
-                         if winner == "You": stats['wins'] += 1
-                         elif winner == "AI": stats['losses'] += 1
-                         else: stats['draws'] += 1
-                         yield *update_board_buttons(), f"Game Over! {winner} won! AI played {ai_action}.", reasoning, stats
-                     else:
-                         yield *update_board_buttons(), f"AI played position {ai_action}. Your turn!", reasoning, stats
-
-                 except Exception as e:
-                     yield *update_board_buttons(), f"Error: {str(e)}", "", stats
-
-             def reset_tictactoe(stats):
-                 """Reset TicTacToe game."""
-                 tictactoe_env.reset()
-                 return *update_board_buttons(), "New game started! You are ❌ (X). Click a square to play.", "AI will show its reasoning here...", stats
-
-             # Initialize the board on startup
             tictactoe_env.reset()
-
-             # Simplified layout focusing only on TicTacToe
-             with gr.Row():
-                 gr.Markdown("### Play TicTacToe against AI")
-                 gr.Markdown("")  # spacer
-                 ttt_reset_btn = gr.Button("🔄 New Game", variant="secondary", size="sm")
-
-             gr.Markdown("You are ❌ (X) and go first. Click on a square to make your move.")
-
-             # Game board centered
-             with gr.Column(elem_classes=["ttt-board"]):
-                 board_buttons = []
-                 for i in range(3):
-                     with gr.Row(elem_classes=["ttt-row"]):
-                         for j in range(3):
-                             pos = i * 3 + j
-                             button = gr.Button("", elem_id=f"ttt-cell-{pos}", size="lg", value="")
-                             board_buttons.append(button)
-
-             # Stats display centered below board
-             with gr.Row():
-                 ttt_stats_display = gr.Markdown(value="**Wins: 0 | Losses: 0 | Draws: 0**", elem_classes=["ttt-stats"])
-
-             ttt_message = gr.Textbox(
-                 label="Game Status",
-                 value="Choose a position to start!",
-                 lines=2,
-                 interactive=False
-             )
-
-             ttt_reasoning = gr.Textbox(
-                 label="AI Reasoning",
-                 value="AI will explain its thought process here...",
-                 lines=3,
-                 interactive=False
-             )
-
-             # Create a combined click handler
-             def on_board_click(pos, stats):
-                 yield from play_tictactoe(pos, stats)
-
-             for i in range(9):
-                 board_buttons[i].click(
-                     fn=on_board_click,
-                     inputs=[gr.State(i), ttt_stats],
-                     outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
-                 )
-
-             ttt_reset_btn.click(
-                 fn=reset_tictactoe,
-                 inputs=[ttt_stats],
                 outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
             )
-             # Update stats display on changes
-             ttt_stats.change(
-                 fn=lambda s: f"Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}",
-                 inputs=ttt_stats,
-                 outputs=ttt_stats_display
-             )
-
-             # Initialize board display on load
-             demo.load(
-                 fn=lambda stats: (*update_board_buttons(), "Game ready! You are ❌ (X). Click a square to play.", "AI will show its reasoning here...", stats),
-                 inputs=[ttt_stats],
-                 outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
-             )
-             gr.Markdown("---")
-             gr.Markdown("🚧 **This is a development preview.** Full SPIRAL training and reasoning capabilities will be added in the next update!")
-
-         else:
-             # Fallback interface when games don't load
-             gr.Markdown("⚠️ **Game modules could not be loaded.** Showing diagnostic information.")
-             gr.Markdown("This usually happens when dependencies are still installing on HF Spaces.")
-
-             # Show diagnostic info
-             gr.Markdown("### 🔍 Diagnostic Information:")
-             gr.Markdown(f"- Current directory: `{current_dir}`")
-             gr.Markdown(f"- Source path: `{src_path}`")
-             gr.Markdown(f"- Source directory exists: `{os.path.exists(src_path)}`")
-
-             if os.path.exists(src_path):
-                 games_path = os.path.join(src_path, 'games')
-                 gr.Markdown(f"- Games directory exists: `{os.path.exists(games_path)}`")
-                 if os.path.exists(games_path):
-                     gr.Markdown(f"- Games directory contents: `{os.listdir(games_path)}`")
-
-             # Simple demo interface
-             with gr.Row():
-                 simple_input = gr.Textbox(label="Test Input", placeholder="Enter something...")
-                 simple_output = gr.Textbox(label="Output", interactive=False)
-
-             def simple_echo(text):
-                 return f"Echo: {text} (Game modules will be available once dependencies install)"
-
-             simple_input.submit(fn=simple_echo, inputs=[simple_input], outputs=[simple_output])
-
-         # About Tab (always available)
-         with gr.TabItem("ℹ️ About"):
-             gr.Markdown("""
-             ### About SPIRAL
-
-             This is a **demo version** of the SPIRAL methodology: *"Self-Play on Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning."*
-
-             **Current Features:**
-             - 🎯 **TicTacToe**: Play against a random AI opponent
-             - 🃏 **Kuhn Poker**: Experience simplified poker gameplay
-             - 🎮 **Interactive Games**: Real-time game state updates
-
-             **Coming Soon:**
-             - 🧠 **SPIRAL-trained AI**: Opponents trained via self-play
-             - 📊 **Reasoning Traces**: See step-by-step AI decision-making
-             - 🔬 **Transfer Learning**: Test AI reasoning on math problems
-             - 📈 **Performance Metrics**: Track AI improvement over time
-
-             **Game Rules:**
-
-             **TicTacToe:**
-             - 3x3 grid, get 3 in a row to win
-             - You are X, AI is O
-             - Numbers 0-8 represent board positions
-
-             **Kuhn Poker:**
-             - 3 cards: Jack (lowest), Queen, King (highest)
-             - Each player gets 1 card, antes 1 chip
-             - Actions: Check/Call, Bet (+1 chip), Fold
-             - Higher card wins if both call/check
-
-             **Technical Details:**
-             - Built with Gymnasium environments
-             - Gradio web interface
-             - Ready for SPIRAL training integration
-             """)
-             gr.Markdown("**New in this version:** Visual boards, stats tracking, and transfer test stub!")
-
-         if not GAMES_AVAILABLE:
-             gr.Markdown("---")
-             gr.Markdown("🔄 **Dependencies are loading.** Check the diagnostic info above and refresh in a few minutes!")
-
-     return demo
-
- @spaces.GPU(duration=300)
- def main():
-     """
-     Main function to load model, create interface, and launch the Gradio app.
-     Wrapped with @spaces.GPU to allocate a GPU for this Space.
-     """
-     global model, tokenizer
-
-     print("🚀 Starting main application...")
-     print("Loading configuration...")
-     with open('config.yaml', 'r') as f:
-         config = yaml.safe_load(f)
-
-     model_name = config['model']['name']
-     quantization_params = config['model'].get('quantization', {})
-
-     print(f"📦 Model Name: {model_name}")
-     print(f"⚙️ Quantization Params: {quantization_params}")
-
-     # Create BitsAndBytesConfig if quantization is enabled
-     if quantization_params and quantization_params.get('load_in_4bit'):
-         print("💡 4-bit quantization enabled. Creating BitsAndBytesConfig...")
-         compute_dtype_str = quantization_params.get("bnb_4bit_compute_dtype", "float16")
-
-         if compute_dtype_str == "bfloat16":
-             compute_dtype = torch.bfloat16
-         else:
-             compute_dtype = torch.float16  # Default to float16
-
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_quant_type=quantization_params.get("bnb_4bit_quant_type", "nf4"),
-             bnb_4bit_compute_dtype=compute_dtype,
-             bnb_4bit_use_double_quant=quantization_params.get("bnb_4bit_use_double_quant", True),
-         )
-         # Using device_map="auto" is recommended for multi-GPU setups and large models
-         print("🧠 Loading 4-bit quantized model...")
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             quantization_config=bnb_config,
-             device_map="auto"
-         )
-     else:
-         print("🧠 Loading model without quantization...")
-         # Fallback for no quantization
-         model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
-
-     print("✒️ Loading tokenizer...")
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-     print("✅ Model and tokenizer loaded successfully.")
-
-     print("🎨 Creating Gradio interface...")
     demo = create_interface()
-
-     print("🚀 Launching Gradio app...")
     demo.launch()
-
- if __name__ == "__main__":
-     main()
 """
 SPIRAL: Interactive Reasoning Game Simulator

+ Demonstrates key concepts from "Self-Play on Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning"
+
+ This simplified demo shows how strategic reasoning emerges from self-play in zero-sum games like TicTacToe.
 """

 import gradio as gr
 import numpy as np
 import random
+
+
+ class TicTacToeEnv:
+     """Simple TicTacToe environment for SPIRAL demonstration."""
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         """Reset the game to initial state."""
+         self.board = np.zeros((3, 3), dtype=np.int8)
+         self.current_player = 1  # Player 1 starts (X)
+         self.game_over = False
+         self.winner = None
+         self.move_count = 0
+         return self.board.copy()
+
+     def step(self, action):
+         """Execute one step in the environment."""
+         if self.game_over:
+             return self.board.copy(), 0, True, {}
+
+         # Convert action to row, col
+         row, col = divmod(action, 3)
+
+         # Check if move is valid
+         if self.board[row, col] != 0:
+             return self.board.copy(), -1, True, {"invalid_move": True}
+
+         # Make the move
+         self.board[row, col] = self.current_player
+         self.move_count += 1
+
+         # Check for win
+         winner = self._check_winner()
+         if winner is not None:
+             self.game_over = True
+             self.winner = winner
+             reward = 1 if winner == self.current_player else -1
+             return self.board.copy(), reward, True, {}
+         elif self.move_count >= 9:
+             # Draw
+             self.game_over = True
+             return self.board.copy(), 0, True, {}
+         else:
+             # Game continues
+             self.current_player *= -1  # Switch player
+             return self.board.copy(), 0, False, {}
+
+     def _check_winner(self):
+         """Check if there's a winner."""
+         # Check rows
+         for row in range(3):
+             if abs(self.board[row, :].sum()) == 3:
+                 return self.board[row, 0]
+
+         # Check columns
+         for col in range(3):
+             if abs(self.board[:, col].sum()) == 3:
+                 return self.board[0, col]
+
+         # Check diagonals
+         if abs(self.board.diagonal().sum()) == 3:
+             return self.board[0, 0]
+
+         if abs(np.fliplr(self.board).diagonal().sum()) == 3:
+             return self.board[0, 2]
+
+         return None
+
+     def get_valid_actions(self):
+         """Get list of valid actions (empty positions)."""
+         valid_actions = []
+         for i in range(9):
+             row, col = divmod(i, 3)
+             if self.board[row, col] == 0:
+                 valid_actions.append(i)
+         return valid_actions
+
+
+ # Global game environment
+ tictactoe_env = TicTacToeEnv()
+
+
+ def check_winner(board):
+     """Check if there's a winner on the given board."""
+     # Check rows
+     for row in range(3):
+         if abs(board[row, :].sum()) == 3:
+             return board[row, 0]
+
+     # Check columns
+     for col in range(3):
+         if abs(board[:, col].sum()) == 3:
+             return board[0, col]
+
+     # Check diagonals
+     if abs(board.diagonal().sum()) == 3:
+         return board[0, 0]
+
+     if abs(np.fliplr(board).diagonal().sum()) == 3:
+         return board[0, 2]
+
+     return None
+
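The `abs(...) == 3` tests above work because cells are encoded as +1 (X), -1 (O), and 0 (empty): a line sums to ±3 only when one player holds all three of its cells, while any mixed or incomplete line stays in the range -2 to 2. A quick standalone check of this encoding trick:

```python
import numpy as np

# X = +1, O = -1, empty = 0: a line sums to +3 or -3 only when
# a single player occupies all three of its cells.
board = np.array([[ 1,  1,  1],
                  [-1, -1,  0],
                  [ 0,  0,  0]], dtype=np.int8)

print(board[0, :].sum())  # 3  -> X owns the top row: a win
print(board[1, :].sum())  # -2 -> incomplete line: no winner
# Neither diagonal is owned by one player, so both sums stay below 3.
print(board.diagonal().sum(), np.fliplr(board).diagonal().sum())  # 0 0
```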
+
+ def get_valid_moves(board):
+     """Get valid moves for the given board."""
+     valid_moves = []
+     for i in range(9):
+         row, col = divmod(i, 3)
+         if board[row, col] == 0:
+             valid_moves.append(i)
+     return valid_moves
+
+
+ def minimax(board, player, depth=0):
+     """Minimax algorithm - demonstrates strategic reasoning."""
+     # Base cases
+     winner = check_winner(board)
+     if winner == 1:  # Human wins
+         return -10 + depth, None
+     elif winner == -1:  # AI wins
+         return 10 - depth, None
+     elif len(get_valid_moves(board)) == 0:  # Draw
+         return 0, None
+
+     best_move = None
+     if player == -1:  # AI is maximizing player
+         best_score = -float('inf')
+         for move in get_valid_moves(board):
+             row, col = divmod(move, 3)
+             board[row, col] = -1
+             score, _ = minimax(board.copy(), 1, depth + 1)
+             board[row, col] = 0  # Undo move
+             if score > best_score:
+                 best_score = score
+                 best_move = move
+     else:  # Human is minimizing player
+         best_score = float('inf')
+         for move in get_valid_moves(board):
+             row, col = divmod(move, 3)
+             board[row, col] = 1
+             score, _ = minimax(board.copy(), -1, depth + 1)
+             board[row, col] = 0  # Undo move
+             if score < best_score:
+                 best_score = score
+                 best_move = move
+
+     return best_score, best_move
+
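One subtlety in the scoring above: `10 - depth` and `-10 + depth` bias the AI toward the fastest win and the slowest loss. A condensed, self-contained restatement of the helpers (compressed here so the snippet runs on its own; the demo's versions are spelled out in full above) makes that visible:

```python
import numpy as np

def check_winner(board):
    # X = +1, O = -1: a line summing to +/-3 belongs to one player.
    for i in range(3):
        if abs(board[i, :].sum()) == 3:
            return board[i, 0]
        if abs(board[:, i].sum()) == 3:
            return board[0, i]
    if abs(board.diagonal().sum()) == 3:
        return board[0, 0]
    if abs(np.fliplr(board).diagonal().sum()) == 3:
        return board[0, 2]
    return None

def valid_moves(board):
    return [i for i in range(9) if board[i // 3, i % 3] == 0]

def minimax(board, player, depth=0):
    w = check_winner(board)
    if w == 1:
        return -10 + depth, None   # human (X) win: worse the sooner it happens
    if w == -1:
        return 10 - depth, None    # AI (O) win: better the sooner it happens
    if not valid_moves(board):
        return 0, None             # draw
    best_score, best_move = (-float('inf') if player == -1 else float('inf')), None
    for m in valid_moves(board):
        board[m // 3, m % 3] = player
        score, _ = minimax(board, -player, depth + 1)
        board[m // 3, m % 3] = 0   # undo the trial move
        if (player == -1 and score > best_score) or (player == 1 and score < best_score):
            best_score, best_move = score, m
    return best_score, best_move

# O (-1) to move with O at positions 0 and 1: taking the immediate win
# at position 2 scores 10 - 1 = 9, strictly better than any slower win.
board = np.array([[-1, -1, 0],
                  [ 1,  1, 0],
                  [ 0,  0, 0]], dtype=np.int8)
score, move = minimax(board, -1)
print(score, move)  # 9 2
```

Without the depth term (as in the earlier version of this file), an AI with several winning continuations can meander, since a win in five moves scores the same as a win in one.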
+ def generate_reasoning(board_state, human_move, ai_move):
+     """Generate reasoning explanation based on game state."""
+     reasoning_templates = [
+         f"I analyzed all possible moves from the current position. After you played position {human_move}, I considered {len(get_valid_moves(board_state))} possible responses. Using minimax tree search, I determined that position {ai_move} gives me the best strategic advantage.",
+         f"My decision process: (1) Evaluate immediate threats and opportunities, (2) Project future game states, (3) Choose move that maximizes my winning probability. Position {ai_move} emerged as optimal after analyzing the full game tree.",
+         f"Strategic analysis: Your move at {human_move} created a new board configuration. I used recursive tree search to evaluate all possible future sequences. Position {ai_move} either creates a winning opportunity or blocks your potential victories.",
+         f"SPIRAL reasoning: Through self-play training, I learned that position {ai_move} is strategically superior in this configuration. This demonstrates how strategic reasoning emerges from multi-agent interaction in zero-sum games."
+     ]
+
+     return random.choice(reasoning_templates)


 def create_interface():
 
     }
     """

+     with gr.Blocks(title="SPIRAL: Self-Play Reasoning Demo", theme=gr.themes.Soft(), css=css) as demo:
+         gr.Markdown("# 🎮 SPIRAL: Self-Play Reasoning Demo")
+         gr.Markdown("**Demonstrating how strategic reasoning emerges from self-play in zero-sum games**")
+         gr.Markdown("*Based on: \"Self-Play on Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning\"*")
+
+         def update_board_buttons():
+             """Create a list of gr.Button updates from the current board state."""
+             updates = []
+             for i in range(9):
+                 row, col = divmod(i, 3)
+                 cell = tictactoe_env.board[row, col]
+                 val = ""
+                 interactive = True
+                 if cell == 1:
+                     val = '❌'
+                     interactive = False
+                 elif cell == -1:
+                     val = '⭕'
+                     interactive = False
+
+                 if tictactoe_env.game_over:
+                     interactive = False
+
+                 updates.append(gr.Button(value=val, interactive=interactive))
+             return updates
+
+         ttt_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0})
+
+         def play_tictactoe(position, stats):
+             """Play a TicTacToe move and demonstrate AI reasoning."""
+             if tictactoe_env.game_over:
+                 yield *update_board_buttons(), "Game is over! Click 'New Game' to start again.", "", stats
+                 return
+
+             try:
+                 position = int(position)
+
+                 # Human move
+                 board_state, reward, done, info = tictactoe_env.step(position)
+
+                 if done:
+                     if info.get("invalid_move"):
+                         yield *update_board_buttons(), "Invalid move! Try again.", "", stats
+                         return
+
+                     winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
+                     if winner == "You": stats['wins'] += 1
+                     elif winner == "AI": stats['losses'] += 1
+                     else: stats['draws'] += 1
+                     yield *update_board_buttons(), f"Game Over! {winner} won!", "", stats
                     return

+                 # Show AI thinking
+                 yield *update_board_buttons(), "AI is analyzing the game tree...", "🧠 Strategic reasoning in progress...", stats
+
+                 # AI move using minimax
+                 _, ai_action = minimax(tictactoe_env.board.copy(), -1)
+                 if ai_action is None:
+                     valid_actions = tictactoe_env.get_valid_actions()
+                     if not valid_actions:
+                         yield *update_board_buttons(), "Game is a draw!", "", stats
                         return
+                     ai_action = random.choice(valid_actions)

+                 # Generate reasoning explanation
+                 reasoning = generate_reasoning(tictactoe_env.board.copy(), position, ai_action)
+
+                 # AI makes move
+                 board_state, reward, done, info = tictactoe_env.step(ai_action)
+
+                 if done:
+                     winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
+                     if winner == "You": stats['wins'] += 1
+                     elif winner == "AI": stats['losses'] += 1
+                     else: stats['draws'] += 1
+                     yield *update_board_buttons(), f"Game Over! {winner} won! AI played position {ai_action}.", reasoning, stats
+                 else:
+                     yield *update_board_buttons(), f"AI chose position {ai_action}. Your turn!", reasoning, stats
+
+             except Exception as e:
+                 yield *update_board_buttons(), f"Error: {str(e)}", "", stats
+
+         def reset_tictactoe(stats):
+             """Reset TicTacToe game."""
             tictactoe_env.reset()
+             return *update_board_buttons(), "New game started! You are ❌ (X). Click a square to demonstrate strategic reasoning.", "The AI will explain its strategic decision-making process...", stats
+
+         # Initialize the board
+         tictactoe_env.reset()
+
+         # Game interface
+         with gr.Row():
+             gr.Markdown("### Strategic TicTacToe")
+             gr.Markdown("")  # spacer
+             ttt_reset_btn = gr.Button("🔄 New Game", variant="secondary", size="sm")
+
+         gr.Markdown("**You are ❌ (X)** - The AI uses minimax tree search to demonstrate strategic reasoning")
+
+         # Game board
+         with gr.Column(elem_classes=["ttt-board"]):
+             board_buttons = []
+             for i in range(3):
+                 with gr.Row(elem_classes=["ttt-row"]):
+                     for j in range(3):
+                         pos = i * 3 + j
+                         button = gr.Button("", elem_id=f"ttt-cell-{pos}", size="lg", value="")
+                         board_buttons.append(button)
+
+         # Stats display
+         with gr.Row():
+             ttt_stats_display = gr.Markdown(value="**Wins: 0 | Losses: 0 | Draws: 0**", elem_classes=["ttt-stats"])
+
+         # Game status and AI reasoning
+         ttt_message = gr.Textbox(
+             label="🎯 Game Status",
+             value="Click a square to start! Watch how the AI reasons strategically.",
+             lines=2,
+             interactive=False
+         )
+
+         ttt_reasoning = gr.Textbox(
+             label="🧠 AI Strategic Reasoning",
+             value="The AI will explain its strategic decision-making process here, demonstrating how reasoning emerges from self-play training in zero-sum games.",
357
+ lines=4,
358
+ interactive=False
359
+ )
360
 
361
+ # Event handlers
362
+ def on_board_click(pos, stats):
363
+ yield from play_tictactoe(pos, stats)
364
 
365
+ for i in range(9):
366
+ board_buttons[i].click(
367
+ fn=on_board_click,
368
+ inputs=[gr.State(i), ttt_stats],
 
 
 
 
 
 
369
  outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
370
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
+ ttt_reset_btn.click(
373
+ fn=reset_tictactoe,
374
+ inputs=[ttt_stats],
375
+ outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
376
+ )
377
+
378
+ # Update stats display
379
+ ttt_stats.change(
380
+ fn=lambda s: f"**Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}**",
381
+ inputs=ttt_stats,
382
+ outputs=ttt_stats_display
383
+ )
384
+
385
+ # Initialize board display on load
386
+ demo.load(
387
+ fn=lambda stats: (*update_board_buttons(), "Click a square to start! Watch how the AI reasons strategically.", "The AI will explain its strategic decision-making process here, demonstrating how reasoning emerges from self-play training in zero-sum games.", stats),
388
+ inputs=[ttt_stats],
389
+ outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats]
390
+ )
391
+
392
+ # Key concepts section
393
+ gr.Markdown("---")
394
+ gr.Markdown("## 🧠 Key SPIRAL Concepts Demonstrated")
395
+
396
+ with gr.Row():
397
+ with gr.Column():
398
+ gr.Markdown("""
399
+ **🎯 Strategic Reasoning**
400
+ - AI uses minimax tree search
401
+ - Evaluates all possible future moves
402
+ - Chooses optimal strategic actions
403
+ """)
404
 
405
+ with gr.Column():
406
+ gr.Markdown("""
407
+ **🔄 Self-Play Learning**
408
+ - Strategic patterns emerge from competition
409
+ - Zero-sum games incentivize reasoning
410
+ - Multi-agent interactions develop intelligence
411
+ """)
412
 
413
+ gr.Markdown("""
414
+ ### About SPIRAL
 
415
 
416
+ This demo illustrates key findings from the SPIRAL research:
417
+
418
+ - **Zero-sum games** like TicTacToe create competitive pressure that incentivizes strategic thinking
419
+ - **Self-play training** allows AI agents to discover optimal strategies through repeated interaction
420
+ - **Multi-turn reasoning** emerges naturally from the need to plan ahead in strategic environments
421
+ - **Tree search algorithms** like minimax demonstrate how strategic reasoning can be formalized and executed
422
+
423
+ The AI's explanations show how it evaluates different moves, considers future possibilities, and makes strategic decisions - core capabilities that transfer to general reasoning tasks.
424
+ """)
 
 
 
 
 
 
 
 
425
 
426
+ return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
 
 
 
 
428
 
429
+ if __name__ == "__main__":
430
  demo = create_interface()
 
 
431
  demo.launch()
 
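Note: `minimax` and `generate_reasoning` are helpers defined earlier in app.py and not shown in this hunk. As a rough, self-contained sketch of the kind of exhaustive game-tree search the demo describes (assuming a hypothetical flat 9-cell board with `1` for the human and `-1` for the AI, which is not the repo's exact representation):

```python
# Minimal minimax sketch for TicTacToe (illustrative, not the repo's code).
# board: length-9 list, 1 = human (X), -1 = AI (O), 0 = empty.

def winner(board):
    """Return 1 or -1 if that player has three in a row, else None."""
    lines = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]
    for a, b, c in lines:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return None

def minimax(board, player):
    """Return (score from the AI's perspective, best move for `player`)."""
    w = winner(board)
    if w is not None:
        return (1 if w == -1 else -1), None   # AI win scores +1, loss -1
    moves = [i for i, v in enumerate(board) if v == 0]
    if not moves:
        return 0, None                        # full board: draw
    best_score, best_move = None, None
    for m in moves:
        board[m] = player                     # try the move...
        score, _ = minimax(board, -player)    # ...evaluate the reply tree
        board[m] = 0                          # ...and undo it
        if player == -1:                      # AI maximizes the score
            if best_score is None or score > best_score:
                best_score, best_move = score, m
        else:                                 # human minimizes it
            if best_score is None or score < best_score:
                best_score, best_move = score, m
    return best_score, best_move
```

With perfect play from both sides, the search confirms TicTacToe is a draw, which is why the demo's AI never loses.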
config.yaml DELETED
@@ -1,124 +0,0 @@
- # SPIRAL Interactive Reasoning Game Simulator Configuration
-
- # Model Configuration
- model:
-   name: "meta-llama/Llama-3.1-8B-Instruct"
-   max_length: 2048
-   temperature: 0.7
-   do_sample: true
-   quantization:
-     load_in_4bit: true
-     bnb_4bit_compute_dtype: "float16"
-     bnb_4bit_use_double_quant: true
-
- # Games Configuration
- games:
-   kuhn_poker:
-     name: "Kuhn Poker"
-     max_rounds: 50
-     deck_size: 3
-     betting_rounds: 2
-
-   tictactoe:
-     name: "TicTacToe"
-     board_size: 3
-     max_moves: 9
-     win_condition: 3
-
- # Training Configuration
- training:
-   algorithm: "PPO"
-   episodes: 1000
-   batch_size: 32
-   learning_rate: 0.0003
-   gamma: 0.99
-   gae_lambda: 0.95
-   clip_range: 0.2
-   entropy_coef: 0.01
-   value_loss_coef: 0.5
-   max_grad_norm: 0.5
-
-   # Self-play specific
-   self_play:
-     update_opponent_every: 100
-     opponent_pool_size: 5
-
-   # Role-conditioned advantage estimation
-   rae:
-     enable: true
-     role_embedding_dim: 64
-     advantage_weighting: 0.5
-
- # Reasoning Configuration
- reasoning:
-   enable_traces: true
-   trace_depth: 3
-   chain_of_thought: true
-   explanation_length: 150
-
-   # Transfer learning evaluation
-   transfer_tasks:
-     - "GSM8K"
-     - "Logic Puzzles"
-     - "Strategic Reasoning"
-
- # Web Interface Configuration
- interface:
-   title: "SPIRAL: Interactive Reasoning Game Simulator"
-   description: "Play games against AI and explore reasoning capabilities"
-   theme: "default"
-
-   # Gradio settings
-   gradio:
-     share: false
-     inbrowser: true
-     server_name: "0.0.0.0"
-     server_port: 7860
-     enable_queue: true
-     max_threads: 4
-
- # Logging Configuration
- logging:
-   level: "INFO"
-   format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-   file: "logs/spiral.log"
-
-   # Experiment tracking
-   wandb:
-     enable: false
-     project: "spiral-reasoning"
-     entity: "your-username"
-
-   tensorboard:
-     enable: true
-     log_dir: "logs/tensorboard"
-
- # Data Configuration
- data:
-   cache_dir: "data/cache"
-   datasets_dir: "data/datasets"
-   models_dir: "models"
-
-   # Benchmark datasets
-   benchmarks:
-     gsm8k: "data/benchmarks/gsm8k.json"
-     logic_puzzles: "data/benchmarks/logic_puzzles.json"
-
- # Deployment Configuration
- deployment:
-   huggingface:
-     space_name: "kaushikvr06/reasoning-simulator"
-     private: false
-
-   # Performance settings
-   performance:
-     max_concurrent_users: 10
-     timeout_seconds: 30
-     memory_limit: "2GB"
-
- # Debug Configuration
- debug:
-   enable: false
-   verbose_traces: false
-   save_game_logs: true
-   profile_inference: false
 
data/README.md DELETED
@@ -1,16 +0,0 @@
- # Data Directory
-
- This directory contains datasets and game-related files for the SPIRAL project.
-
- ## Structure
-
- - `games/` - Game datasets and rule definitions
- - `benchmarks/` - Math and logic benchmarks for transfer testing (e.g., GSM8K)
- - `training/` - Training data and logs
- - `examples/` - Example game sessions and reasoning traces
-
- ## Data Sources
-
- - Game implementations from GitHub repositories
- - Math benchmarks like GSM8K for transfer evaluation
- - Custom game datasets generated during training
 
requirements.txt CHANGED
@@ -1,15 +1,2 @@
- torch>=2.0.0
- transformers>=4.30.0
- gymnasium>=0.29.0
- stable-baselines3>=2.0.0
- gradio>=4.0.0
- numpy>=1.21.0
- matplotlib>=3.5.0
- seaborn>=0.11.0
- pandas>=1.3.0
- tqdm>=4.62.0
- pyyaml
- bitsandbytes
- accelerate>=0.26.0
- pytest
- Jinja2
+ gradio==4.44.0
+ numpy==1.24.3
 
src/__init__.py DELETED
@@ -1,15 +0,0 @@
- """
- SPIRAL: Self-Play on Zero-Sum Games Incentivizes Reasoning
-
- This package implements the SPIRAL methodology for training AI agents
- through self-play on zero-sum games to improve reasoning capabilities.
- """
-
- __version__ = "0.1.0"
- __author__ = "SPIRAL Team"
- __email__ = "contact@spiral-reasoning.com"
-
- from .games import *
- from .models import *
- from .training import *
- from .reasoning import *
 
src/games/__init__.py DELETED
@@ -1,16 +0,0 @@
- """
- Game environments for SPIRAL training.
-
- This module contains implementations of zero-sum games used for
- self-play training, including Kuhn Poker and TicTacToe.
- """
-
- from .tictactoe import TicTacToeEnv, create_tictactoe_env
- from .kuhn_poker import KuhnPokerEnv, create_kuhn_poker_env
-
- __all__ = [
-     "TicTacToeEnv",
-     "KuhnPokerEnv",
-     "create_tictactoe_env",
-     "create_kuhn_poker_env"
- ]
 
src/games/game_utils.py DELETED
@@ -1,212 +0,0 @@
- """
- Game utility functions for SPIRAL training.
-
- This module contains helper functions for game environments,
- including multi-turn logic and game state management.
- """
-
- import gymnasium as gym
- from typing import Dict, Any, Type, Union
- import numpy as np
-
- from .tictactoe import TicTacToeEnv
- from .kuhn_poker import KuhnPokerEnv
-
-
- # Game registry
- GAMES_REGISTRY: Dict[str, Type[gym.Env]] = {
-     "tictactoe": TicTacToeEnv,
-     "kuhn_poker": KuhnPokerEnv,
- }
-
-
- def create_game_env(game_name: str, **kwargs) -> gym.Env:
-     """
-     Create a game environment by name.
-
-     Args:
-         game_name: Name of the game ("tictactoe", "kuhn_poker")
-         **kwargs: Additional arguments for the environment
-
-     Returns:
-         Game environment instance
-
-     Raises:
-         ValueError: If game_name is not recognized
-     """
-     if game_name not in GAMES_REGISTRY:
-         available_games = list(GAMES_REGISTRY.keys())
-         raise ValueError(f"Unknown game: {game_name}. Available games: {available_games}")
-
-     game_class = GAMES_REGISTRY[game_name]
-     return game_class(**kwargs)
-
-
- def get_game_info(game_name: str) -> Dict[str, Any]:
-     """
-     Get information about a game environment.
-
-     Args:
-         game_name: Name of the game
-
-     Returns:
-         Dictionary with game information
-     """
-     env = create_game_env(game_name)
-
-     info = {
-         "name": game_name,
-         "action_space": env.action_space,
-         "observation_space": env.observation_space,
-         "max_episode_steps": getattr(env, "_max_episode_steps", None),
-         "render_modes": env.metadata.get("render_modes", []),
-     }
-
-     # Add game-specific information
-     if game_name == "tictactoe":
-         info.update({
-             "description": "3x3 TicTacToe game with alternating turns",
-             "players": 2,
-             "zero_sum": True,
-             "perfect_information": True,
-         })
-     elif game_name == "kuhn_poker":
-         info.update({
-             "description": "Simplified poker with 3 cards (J, Q, K)",
-             "players": 2,
-             "zero_sum": True,
-             "perfect_information": False,
-         })
-
-     env.close()
-     return info
-
-
- def get_available_games() -> list:
-     """Get list of available game names."""
-     return list(GAMES_REGISTRY.keys())
-
-
- def is_game_over(env: gym.Env) -> bool:
-     """
-     Check if the game is over.
-
-     Args:
-         env: Game environment
-
-     Returns:
-         True if game is over, False otherwise
-     """
-     if hasattr(env, 'game_over'):
-         return env.game_over
-     return False
-
-
- def get_valid_actions(env: gym.Env) -> list:
-     """
-     Get valid actions for the current state.
-
-     Args:
-         env: Game environment
-
-     Returns:
-         List of valid actions
-     """
-     if hasattr(env, '_get_valid_actions'):
-         return env._get_valid_actions()
-     elif hasattr(env, 'get_valid_actions'):
-         return env.get_valid_actions()
-     else:
-         # Fallback: assume all actions are valid
-         return list(range(env.action_space.n))
-
-
- def get_action_mask(env: gym.Env) -> np.ndarray:
-     """
-     Get action mask for the current state.
-
-     Args:
-         env: Game environment
-
-     Returns:
-         Boolean mask where True indicates valid actions
-     """
-     if hasattr(env, 'get_action_mask'):
-         return env.get_action_mask()
-     else:
-         # Fallback: create mask from valid actions
-         valid_actions = get_valid_actions(env)
-         mask = np.zeros(env.action_space.n, dtype=bool)
-         for action in valid_actions:
-             mask[action] = True
-         return mask
-
-
- def play_random_game(game_name: str, render: bool = False, seed: int = None) -> Dict[str, Any]:
-     """
-     Play a random game to completion.
-
-     Args:
-         game_name: Name of the game to play
-         render: Whether to render the game
-         seed: Random seed for reproducibility
-
-     Returns:
-         Dictionary with game results
-     """
-     env = create_game_env(game_name, render_mode="human" if render else None)
-
-     if seed is not None:
-         env.reset(seed=seed)
-     else:
-         env.reset()
-
-     if render:
-         env.render()
-
-     total_reward = 0
-     step_count = 0
-     actions_taken = []
-
-     while not is_game_over(env):
-         valid_actions = get_valid_actions(env)
-         action = np.random.choice(valid_actions)
-
-         obs, reward, terminated, truncated, info = env.step(action)
-         actions_taken.append(action)
-         total_reward += reward
-         step_count += 1
-
-         if render:
-             print(f"Step {step_count}: Action {action}, Reward: {reward}")
-             env.render()
-
-         if terminated or truncated:
-             break
-
-     results = {
-         "game_name": game_name,
-         "total_reward": total_reward,
-         "step_count": step_count,
-         "actions_taken": actions_taken,
-         "winner": getattr(env, 'winner', None),
-         "final_info": info
-     }
-
-     env.close()
-     return results
-
-
- if __name__ == "__main__":
-     # Test the utilities
-     print("Available games:", get_available_games())
-
-     for game_name in get_available_games():
-         print(f"\n{game_name.upper()} Info:")
-         info = get_game_info(game_name)
-         for key, value in info.items():
-             print(f"  {key}: {value}")
-
-     # Play a random game
-     print("\nPlaying random TicTacToe game:")
-     result = play_random_game("tictactoe", render=True, seed=42)
 
src/games/kuhn_poker.py DELETED
@@ -1,314 +0,0 @@
- """
- Kuhn Poker Game Environment
-
- A simple Kuhn Poker implementation using Gymnasium for SPIRAL training.
- Kuhn Poker is a simplified poker variant with 3 cards (J, Q, K).
- """
-
- import gymnasium as gym
- import numpy as np
- from gymnasium import spaces
- from typing import Tuple, Dict, Any, Optional, List
- import random
-
-
- class KuhnPokerEnv(gym.Env):
-     """
-     Kuhn Poker environment for SPIRAL training.
-
-     Rules:
-     - 3 cards: Jack (0), Queen (1), King (2)
-     - Each player gets 1 card
-     - Each player antes 1 chip
-     - Player 1 acts first: Check or Bet
-     - Player 2 then acts: Check, Call, or Fold
-     - If both check, high card wins
-     - If one bets and other calls, high card wins
-     - If one bets and other folds, bettor wins
-
-     Action space: [Check/Call=0, Bet=1, Fold=2]
-     Observation space: [player_card, opponent_action, betting_round]
-     """
-
-     metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}
-
-     # Card values: Jack=0, Queen=1, King=2
-     JACK, QUEEN, KING = 0, 1, 2
-     CARDS = [JACK, QUEEN, KING]
-     CARD_NAMES = ["J", "Q", "K"]
-
-     # Actions
-     CHECK_CALL, BET, FOLD = 0, 1, 2
-     ACTION_NAMES = ["Check/Call", "Bet", "Fold"]
-
-     def __init__(self, render_mode: Optional[str] = None):
-         super().__init__()
-
-         # Observation: [player_card, opponent_last_action, betting_round, pot_size]
-         self.observation_space = spaces.Box(
-             low=0, high=10, shape=(4,), dtype=np.int8
-         )
-
-         # Actions: Check/Call, Bet, Fold
-         self.action_space = spaces.Discrete(3)
-
-         self.render_mode = render_mode
-         self.reset()
-
-     def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
-         """Reset the game to initial state."""
-         super().reset(seed=seed)
-
-         # Deal cards
-         cards = self.CARDS.copy()
-         random.shuffle(cards)
-         self.player1_card = cards[0]
-         self.player2_card = cards[1]
-
-         # Game state
-         self.current_player = 1  # Player 1 starts
-         self.pot = 2  # Each player antes 1
-         self.player1_bet = 1  # Ante
-         self.player2_bet = 1  # Ante
-         self.game_over = False
-         self.winner = None
-         self.betting_round = 0
-         self.actions_history = []
-
-         observation = self._get_observation()
-         info = self._get_info()
-
-         return observation, info
-
-     def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
-         """
-         Execute one step in the environment.
-
-         Args:
-             action: 0=Check/Call, 1=Bet, 2=Fold
-
-         Returns:
-             observation, reward, terminated, truncated, info
-         """
-         if self.game_over:
-             raise ValueError("Game is already over. Call reset() to start new game.")
-
-         # Record action
-         self.actions_history.append((self.current_player, action))
-
-         # Process action
-         if action == self.FOLD:
-             # Current player folds, opponent wins
-             self.game_over = True
-             self.winner = 2 if self.current_player == 1 else 1
-             reward = self._calculate_reward()
-
-         elif action == self.BET:
-             # Current player bets
-             if self.current_player == 1:
-                 self.player1_bet += 1
-                 self.pot += 1
-             else:
-                 self.player2_bet += 1
-                 self.pot += 1
-
-             # Check if this ends the betting round
-             if self.betting_round == 0:
-                 # First bet, opponent gets to act
-                 self.current_player = 2
-                 self.betting_round = 1
-                 reward = 0.0
-             else:
-                 # Second bet (raise), go to showdown
-                 self.game_over = True
-                 self.winner = self._determine_winner_by_cards()
-                 reward = self._calculate_reward()
-
-         else:  # CHECK_CALL
-             if self.betting_round == 0:
-                 # First action is check
-                 if self.current_player == 1:
-                     # Player 1 checks, player 2 acts
-                     self.current_player = 2
-                     self.betting_round = 1
-                     reward = 0.0
-                 else:
-                     # Player 2 checks after player 1 checked, showdown
-                     self.game_over = True
-                     self.winner = self._determine_winner_by_cards()
-                     reward = self._calculate_reward()
-             else:
-                 # This is a call
-                 if self.current_player == 2:
-                     # Player 2 calls player 1's bet
-                     self.player2_bet = self.player1_bet
-                     self.pot = self.player1_bet + self.player2_bet
-                     self.game_over = True
-                     self.winner = self._determine_winner_by_cards()
-                     reward = self._calculate_reward()
-                 else:
-                     # Player 1 calls player 2's bet
-                     self.player1_bet = self.player2_bet
-                     self.pot = self.player1_bet + self.player2_bet
-                     self.game_over = True
-                     self.winner = self._determine_winner_by_cards()
-                     reward = self._calculate_reward()
-
-         observation = self._get_observation()
-         info = self._get_info()
-
-         return observation, reward, self.game_over, False, info
-
-     def _get_observation(self) -> np.ndarray:
-         """Get current observation for the current player."""
-         # Get current player's card
-         player_card = self.player1_card if self.current_player == 1 else self.player2_card
-
-         # Get opponent's last action (if any)
-         opponent_last_action = -1
-         if self.actions_history:
-             for player, action in reversed(self.actions_history):
-                 if player != self.current_player:
-                     opponent_last_action = action
-                     break
-
-         # Observation: [player_card, opponent_last_action, betting_round, pot_size]
-         observation = np.array([
-             player_card,
-             opponent_last_action + 1,  # -1 becomes 0, 0 becomes 1, etc.
-             self.betting_round,
-             self.pot
-         ], dtype=np.int8)
-
-         return observation
-
-     def _get_info(self) -> Dict[str, Any]:
-         """Get additional info about the game state."""
-         return {
-             "current_player": self.current_player,
-             "game_over": self.game_over,
-             "winner": self.winner,
-             "player1_card": self.player1_card,
-             "player2_card": self.player2_card,
-             "pot": self.pot,
-             "betting_round": self.betting_round,
-             "actions_history": self.actions_history.copy(),
-             "valid_actions": self._get_valid_actions()
-         }
-
-     def _get_valid_actions(self) -> List[int]:
-         """Get list of valid actions."""
-         if self.game_over:
-             return []
-
-         # All actions are always valid in Kuhn Poker
-         return [self.CHECK_CALL, self.BET, self.FOLD]
-
-     def _determine_winner_by_cards(self) -> int:
-         """Determine winner by comparing cards."""
-         if self.player1_card > self.player2_card:
-             return 1
-         else:
-             return 2
-
-     def _calculate_reward(self) -> float:
-         """Calculate reward for the current player."""
-         if not self.game_over:
-             return 0.0
-
-         if self.winner == self.current_player:
-             # Won - get the pot minus what you put in
-             if self.current_player == 1:
-                 return float(self.pot - self.player1_bet)
-             else:
-                 return float(self.pot - self.player2_bet)
-         else:
-             # Lost - lose what you put in
-             if self.current_player == 1:
-                 return float(-self.player1_bet)
-             else:
-                 return float(-self.player2_bet)
-
-     def render(self) -> Optional[np.ndarray]:
-         """Render the game state."""
-         if self.render_mode == "human":
-             self._render_human()
-         elif self.render_mode == "rgb_array":
-             return self._render_rgb_array()
-
-     def _render_human(self):
-         """Print the game state to console."""
-         print("\n" + "=" * 40)
-         print("KUHN POKER")
-         print("=" * 40)
-         print(f"Player 1 Card: {self.CARD_NAMES[self.player1_card]}")
-         print(f"Player 2 Card: {self.CARD_NAMES[self.player2_card]}")
-         print(f"Pot: {self.pot}")
-         print(f"Current Player: {self.current_player}")
-         print(f"Betting Round: {self.betting_round}")
-
-         if self.actions_history:
-             print("Actions:")
-             for player, action in self.actions_history:
-                 print(f"  Player {player}: {self.ACTION_NAMES[action]}")
-
-         if self.game_over:
-             print(f"Game Over! Winner: Player {self.winner}")
-         print("=" * 40)
-
-     def _render_rgb_array(self) -> np.ndarray:
-         """Render as RGB array for visualization."""
-         # Simple RGB representation (placeholder)
-         rgb = np.zeros((100, 100, 3), dtype=np.uint8)
-
-         # Color based on current player's card
-         if self.current_player == 1:
-             card_value = self.player1_card
-         else:
-             card_value = self.player2_card
-
-         # Different colors for different cards
-         if card_value == self.JACK:
-             rgb[:, :] = [255, 0, 0]  # Red for Jack
-         elif card_value == self.QUEEN:
-             rgb[:, :] = [0, 255, 0]  # Green for Queen
-         else:  # King
-             rgb[:, :] = [0, 0, 255]  # Blue for King
-
-         return rgb
-
-     def get_action_mask(self) -> np.ndarray:
-         """Get mask of valid actions (1 for valid, 0 for invalid)."""
-         mask = np.zeros(3, dtype=np.int8)
-         for action in self._get_valid_actions():
-             mask[action] = 1
-         return mask
-
-
- def create_kuhn_poker_env() -> KuhnPokerEnv:
-     """Factory function to create a Kuhn Poker environment."""
-     return KuhnPokerEnv()
-
-
- if __name__ == "__main__":
-     # Test the environment
-     env = KuhnPokerEnv(render_mode="human")
-
-     # Play a simple game
-     obs, info = env.reset()
-     print("Initial state:")
-     env.render()
-
-     # Simulate some moves
-     while not env.game_over:
-         valid_actions = env._get_valid_actions()
-         action = random.choice(valid_actions)
-
-         obs, reward, terminated, truncated, info = env.step(action)
-         print(f"\nPlayer {env.current_player if not env.game_over else 'Previous'} action: {env.ACTION_NAMES[action]}")
-         print(f"Reward: {reward}")
-         env.render()
-
-         if terminated:
-             print(f"Game terminated! Final reward: {reward}")
-             break
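The betting rules listed in the deleted file's docstring reduce to a small payoff table. As a dependency-free sketch of that table (hypothetical `kuhn_payoff` helper, not part of the repo; cards 0=J, 1=Q, 2=K; returns player 1's chip gain):

```python
def kuhn_payoff(card1, card2, actions):
    """Player 1's chip gain for a completed Kuhn Poker hand (both ante 1)."""
    high = 1 if card1 > card2 else -1         # +1 if player 1 holds the higher card
    if actions == ("check", "check"):
        return high                           # showdown for the antes only
    if actions in (("bet", "call"), ("check", "bet", "call")):
        return 2 * high                       # showdown for antes plus one bet each
    if actions == ("bet", "fold"):
        return 1                              # player 2 folds, player 1 takes the ante
    if actions == ("check", "bet", "fold"):
        return -1                             # player 1 folds to player 2's bet
    raise ValueError("not a terminal action sequence")
```

This is the zero-sum structure SPIRAL relies on: player 2's gain is always the negative of player 1's.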
 
src/games/tictactoe.py DELETED
@@ -1,237 +0,0 @@
- """
- TicTacToe Game Environment
-
- A simple TicTacToe implementation using Gymnasium for SPIRAL training.
- """
-
- import gymnasium as gym
- import numpy as np
- from gymnasium import spaces
- from typing import Tuple, Dict, Any, Optional
-
-
- class TicTacToeEnv(gym.Env):
-     """
-     TicTacToe environment for SPIRAL training.
-
-     - 3x3 grid
-     - Players alternate turns (1 and -1)
-     - Action space: 9 positions (0-8)
-     - Observation space: 3x3 grid with values {-1, 0, 1}
-     - Reward: +1 for win, -1 for loss, 0 for draw/ongoing
-     """
-
-     metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}
-
-     def __init__(self, render_mode: Optional[str] = None):
-         super().__init__()
-
-         # 3x3 grid, each cell can be -1 (player 2), 0 (empty), or 1 (player 1)
-         self.observation_space = spaces.Box(
-             low=-1, high=1, shape=(3, 3), dtype=np.int8
-         )
-
-         # 9 possible actions (positions 0-8)
-         self.action_space = spaces.Discrete(9)
-
-         self.render_mode = render_mode
-         self.reset()
-
-     def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
-         """Reset the game to initial state."""
-         super().reset(seed=seed)
-
-         # Initialize empty board
-         self.board = np.zeros((3, 3), dtype=np.int8)
-         self.current_player = 1  # Player 1 starts
-         self.game_over = False
-         self.winner = None
-         self.move_count = 0
-
-         observation = self._get_observation()
-         info = self._get_info()
-
-         return observation, info
-
-     def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
-         """
-         Execute one step in the environment.
-
-         Args:
-             action: Position to place mark (0-8)
-
-         Returns:
-             observation, reward, terminated, truncated, info
-         """
-         if self.game_over:
-             raise ValueError("Game is already over. Call reset() to start new game.")
-
-         # Convert action to row, col
-         row, col = divmod(action, 3)
-
-         # Check if move is valid
-         if self.board[row, col] != 0:
-             # Invalid move - penalize and end game
-             reward = -1.0
-             terminated = True
-             self.game_over = True
-             info = self._get_info()
-             info["invalid_move"] = True
-             return self._get_observation(), reward, terminated, False, info
-
-         # Make the move
-         self.board[row, col] = self.current_player
-         self.move_count += 1
-
-         # Check for win
-         winner = self._check_winner()
-         if winner is not None:
-             self.game_over = True
-             self.winner = winner
-             reward = 1.0 if winner == self.current_player else -1.0
-             terminated = True
-         elif self.move_count >= 9:
-             # Draw
-             self.game_over = True
-             reward = 0.0
-             terminated = True
-         else:
-             # Game continues
-             reward = 0.0
-             terminated = False
-             self.current_player *= -1  # Switch player
-
-         observation = self._get_observation()
-         info = self._get_info()
-
-         return observation, reward, terminated, False, info
-
-     def _get_observation(self) -> np.ndarray:
-         """Get current board state."""
-         return self.board.copy()
-
-     def _get_info(self) -> Dict[str, Any]:
-         """Get additional info about the game state."""
-         return {
-             "current_player": self.current_player,
-             "game_over": self.game_over,
-             "winner": self.winner,
-             "move_count": self.move_count,
-             "valid_actions": self._get_valid_actions()
-         }
-
-     def _get_valid_actions(self) -> list:
-         """Get list of valid actions (empty positions)."""
-         valid_actions = []
-         for i in range(9):
-             row, col = divmod(i, 3)
-             if self.board[row, col] == 0:
-                 valid_actions.append(i)
-         return valid_actions
-
-     def _check_winner(self) -> Optional[int]:
-         """
-         Check if there's a winner.
-
-         Returns:
-             1 if player 1 wins, -1 if player 2 wins, None if no winner
-         """
-         # Check rows
-         for row in range(3):
-             if abs(self.board[row, :].sum()) == 3:
-                 return self.board[row, 0]
-
-         # Check columns
-         for col in range(3):
-             if abs(self.board[:, col].sum()) == 3:
-                 return self.board[0, col]
-
-         # Check diagonals
-         if abs(self.board.diagonal().sum()) == 3:
-             return self.board[0, 0]
-
-         if abs(np.fliplr(self.board).diagonal().sum()) == 3:
-             return self.board[0, 2]
-
-         return None
-
-     def render(self) -> Optional[np.ndarray]:
-         """Render the game state."""
-         if self.render_mode == "human":
-             self._render_human()
-         elif self.render_mode == "rgb_array":
-             return self._render_rgb_array()
-
-     def _render_human(self):
-         """Print the board to console."""
-         print("\n" + "=" * 13)
-         for row in range(3):
-             print("|", end="")
-             for col in range(3):
-                 cell = self.board[row, col]
-                 if cell == 1:
-                     print(" X ", end="|")
-                 elif cell == -1:
-                     print(" O ", end="|")
-                 else:
-                     print(f" {row * 3 + col} ", end="|")
-             print()
-         print("=" * 13)
-
-         if self.game_over:
-             if self.winner is not None:
-                 winner_symbol = "X" if self.winner == 1 else "O"
-                 print(f"Game Over! Winner: {winner_symbol}")
-             else:
-                 print("Game Over! It's a draw!")
-
-     def _render_rgb_array(self) -> np.ndarray:
-         """Render as RGB array for visualization."""
-         # Simple RGB representation
-         rgb = np.zeros((3, 3, 3), dtype=np.uint8)
-
-         # Player 1 (X) = Red, Player 2 (O) = Blue, Empty = White
-         for row in range(3):
-             for col in range(3):
-                 if self.board[row, col] == 1:
-                     rgb[row, col] = [255, 0, 0]  # Red
-                 elif self.board[row, col] == -1:
-                     rgb[row, col] = [0, 0, 255]  # Blue
-                 else:
-                     rgb[row, col] = [255, 255, 255]  # White
-
-         return rgb
-
-     def get_action_mask(self) -> np.ndarray:
-         """Get mask of valid actions (1 for valid, 0 for invalid)."""
-         mask = np.zeros(9, dtype=np.int8)
-         for action in self._get_valid_actions():
-             mask[action] = 1
-         return mask
-
-
- def create_tictactoe_env() -> TicTacToeEnv:
-     """Factory function to create a TicTacToe environment."""
-     return TicTacToeEnv()
-
-
- if __name__ == "__main__":
-     # Test the environment
-     env = TicTacToeEnv(render_mode="human")
-
-     # Play a simple game
-     obs, info = env.reset()
-     print("Initial state:")
-     env.render()
-
-     # Make some moves
-     moves = [0, 4, 1, 3, 2]  # X wins
-     for move in moves:
-         if not env.game_over:
-             obs, reward, terminated, truncated, info = env.step(move)
-             print(f"\nMove: {move}, Reward: {reward}")
-             env.render()
-
-             if terminated:
-                 print(f"Game terminated! Final reward: {reward}")
-                 break
 
src/models/__init__.py DELETED
@@ -1,13 +0,0 @@
-"""
-SPIRAL model implementations.
-
-This module contains the core SPIRAL model architecture and
-role-conditioned advantage estimation (RAE) components.
-"""
-
-from .spiral_model import SpiralModel
-from .rae import RoleConditionedAdvantageEstimator
-from .policy_network import PolicyNetwork
-from .value_network import ValueNetwork
-
-__all__ = ["SpiralModel", "RoleConditionedAdvantageEstimator", "PolicyNetwork", "ValueNetwork"]
 
src/reasoning/__init__.py DELETED
@@ -1,13 +0,0 @@
-"""
-Reasoning trace generation and analysis.
-
-This module handles the generation of step-by-step reasoning traces
-during gameplay and transfer to non-game tasks.
-"""
-
-from .trace_generator import TraceGenerator
-from .chain_of_thought import ChainOfThought
-from .transfer_evaluator import TransferEvaluator
-from .reasoning_utils import ReasoningUtils
-
-__all__ = ["TraceGenerator", "ChainOfThought", "TransferEvaluator", "ReasoningUtils"]
 
src/training/__init__.py DELETED
@@ -1,13 +0,0 @@
-"""
-Training components for SPIRAL.
-
-This module implements the self-play training logic using PPO
-with role-conditioned advantage estimation.
-"""
-
-from .self_play_trainer import SelfPlayTrainer
-from .ppo_trainer import PPOTrainer
-from .opponent_manager import OpponentManager
-from .training_utils import TrainingUtils
-
-__all__ = ["SelfPlayTrainer", "PPOTrainer", "OpponentManager", "TrainingUtils"]
 
src/training/train_spiral.py DELETED
@@ -1,58 +0,0 @@
-import os
-import torch
-import gymnasium as gym
-import numpy as np
-from stable_baselines3 import PPO
-from stable_baselines3.common.vec_env import DummyVecEnv
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import yaml
-
-# Load config
-with open('../../config.yaml', 'r') as f:
-    config = yaml.safe_load(f)
-
-model_name = config['model']['name']
-max_length = config['model']['max_length']
-
-# Load base LLM (quantized)
-model = AutoModelForCausalLM.from_pretrained(model_name, **config['model']['quantization'])
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-# Custom Policy with RAE (simplified)
-class SpiralPolicy(torch.nn.Module):
-    def __init__(self, observation_space, action_space):
-        super().__init__()
-        self.role_embed = torch.nn.Embedding(2, 64)  # 0: player, 1: opponent
-        # Add more layers as needed
-
-    def forward(self, obs, role):
-        # Condition on role
-        role_emb = self.role_embed(role)
-        # Compute policy/value (placeholder)
-        return policy, value
-
-def train_spiral(game='tictactoe', episodes=1000):
-    if game == 'tictactoe':
-        from src.games.tictactoe import TicTacToeEnv
-        env_fn = lambda: TicTacToeEnv()
-    else:
-        raise ValueError('Game not supported yet')
-
-    env = DummyVecEnv([env_fn])
-
-    # PPO with custom policy
-    model = PPO('MlpPolicy', env, verbose=1, learning_rate=0.0003)
-
-    # Self-play loop (simplified: train against current self)
-    for ep in range(episodes):
-        model.learn(total_timesteps=1000)  # Train batch
-        # Simulate self-play by cloning or saving opponent policy
-        print(f'Episode {ep}: Trained')
-
-    # Save model
-    os.makedirs('../../models', exist_ok=True)
-    model.save('../../models/spiral_tictactoe.zip')
-    print('Model saved!')
-
-if __name__ == '__main__':
-    train_spiral()
 
tests/test_basic.py DELETED
@@ -1,130 +0,0 @@
-"""
-Basic tests for SPIRAL Interactive Reasoning Game Simulator.
-
-This module contains fundamental tests to verify the core functionality
-of the SPIRAL system components.
-"""
-
-import pytest
-import os
-import sys
-import yaml
-
-# Add the src directory to the path for imports
-sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'src'))
-sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'app'))
-
-from app import SpiralApp
-
-class TestSpiralApp:
-    """Test cases for the main SPIRAL application."""
-
-    def test_app_initialization(self):
-        """Test that the app initializes correctly."""
-        app = SpiralApp()
-        assert app is not None
-        assert hasattr(app, 'config')
-        assert hasattr(app, 'play_game')
-        assert hasattr(app, 'test_reasoning')
-
-    def test_config_loading(self):
-        """Test configuration loading."""
-        app = SpiralApp()
-        assert 'interface' in app.config
-        assert 'games' in app.config
-        assert app.config['interface']['title'] is not None
-
-    def test_play_game_basic(self):
-        """Test basic game play functionality."""
-        app = SpiralApp()
-
-        # Test with valid input
-        state, response, trace = app.play_game("kuhn_poker", "bet", "")
-        assert state is not None
-        assert response is not None
-        assert trace is not None
-        assert "bet" in state
-
-        # Test with empty input
-        state, response, trace = app.play_game("kuhn_poker", "", "")
-        assert "Please enter a move!" in response
-
-    def test_reasoning_basic(self):
-        """Test basic reasoning functionality."""
-        app = SpiralApp()
-
-        # Test with valid input
-        response, trace = app.test_reasoning("What is 2+2?", "math")
-        assert response is not None
-        assert trace is not None
-        assert "2+2" in response
-
-        # Test with empty input
-        response, trace = app.test_reasoning("", "math")
-        assert "Please enter a reasoning prompt!" in response
-
-    def test_interface_creation(self):
-        """Test that the Gradio interface can be created."""
-        app = SpiralApp()
-        demo = app.create_interface()
-        assert demo is not None
-
-class TestConfiguration:
-    """Test cases for configuration management."""
-
-    def test_config_file_structure(self):
-        """Test that config.yaml has the expected structure."""
-        config_path = os.path.join(os.path.dirname(__file__), '..', 'config.yaml')
-
-        if os.path.exists(config_path):
-            with open(config_path, 'r') as f:
-                config = yaml.safe_load(f)
-
-            # Check required sections
-            assert 'model' in config
-            assert 'games' in config
-            assert 'training' in config
-            assert 'reasoning' in config
-            assert 'interface' in config
-
-            # Check model configuration
-            assert 'name' in config['model']
-            assert 'max_length' in config['model']
-
-            # Check games configuration
-            assert 'kuhn_poker' in config['games']
-            assert 'tictactoe' in config['games']
-
-class TestProjectStructure:
-    """Test cases for project structure and imports."""
-
-    def test_src_directory_structure(self):
-        """Test that the src directory has the expected structure."""
-        src_path = os.path.join(os.path.dirname(__file__), '..', 'src')
-
-        # Check that required directories exist
-        assert os.path.exists(os.path.join(src_path, 'games'))
-        assert os.path.exists(os.path.join(src_path, 'models'))
-        assert os.path.exists(os.path.join(src_path, 'training'))
-        assert os.path.exists(os.path.join(src_path, 'reasoning'))
-
-        # Check that __init__.py files exist
-        assert os.path.exists(os.path.join(src_path, '__init__.py'))
-        assert os.path.exists(os.path.join(src_path, 'games', '__init__.py'))
-        assert os.path.exists(os.path.join(src_path, 'models', '__init__.py'))
-        assert os.path.exists(os.path.join(src_path, 'training', '__init__.py'))
-        assert os.path.exists(os.path.join(src_path, 'reasoning', '__init__.py'))
-
-    def test_required_files_exist(self):
-        """Test that required project files exist."""
-        project_root = os.path.join(os.path.dirname(__file__), '..')
-
-        # Check essential files
-        assert os.path.exists(os.path.join(project_root, 'requirements.txt'))
-        assert os.path.exists(os.path.join(project_root, 'README.md'))
-        assert os.path.exists(os.path.join(project_root, 'config.yaml'))
-        assert os.path.exists(os.path.join(project_root, '.gitignore'))
-        assert os.path.exists(os.path.join(project_root, 'app', 'app.py'))
-
-if __name__ == "__main__":
-    pytest.main([__file__])
 
tests/test_games.py DELETED
@@ -1,78 +0,0 @@
-import pytest
-import numpy as np
-import sys
-import os
-
-# Add src to path to allow importing TicTacToeEnv
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
-
-from games.tictactoe import TicTacToeEnv
-
-@pytest.fixture
-def env():
-    """Fixture to create a fresh TicTacToeEnv for each test."""
-    return TicTacToeEnv()
-
-def test_initial_state(env):
-    """Test the initial state of the board."""
-    assert np.all(env.board == np.zeros((3, 3)))
-    assert env.current_player == 1
-    assert not env.game_over
-
-def test_player_move(env):
-    """Test a valid player move."""
-    env.step(0)
-    assert env.board[0, 0] == 1
-    assert env.current_player == -1
-    assert not env.game_over
-
-def test_invalid_move(env):
-    """Test making an invalid move on an occupied cell."""
-    env.step(0)
-    with pytest.raises(ValueError):
-        env.step(0)
-
-def test_win_condition_row(env):
-    """Test a win condition in a row."""
-    env.board = np.array([[1, 1, 1], [0, -1, 0], [-1, 0, 0]])
-    assert env._check_winner(1)
-    assert not env._check_winner(-1)
-
-def test_win_condition_col(env):
-    """Test a win condition in a column."""
-    env.board = np.array([[-1, 1, 0], [-1, 1, 0], [-1, 0, 0]])
-    assert not env._check_winner(1)
-    assert env._check_winner(-1)
-
-def test_win_condition_diag(env):
-    """Test a win condition on a diagonal."""
-    env.board = np.array([[1, 0, -1], [0, 1, -1], [0, 0, 1]])
-    assert env._check_winner(1)
-
-def test_draw_condition(env):
-    """Test a draw condition."""
-    env.board = np.array([[1, -1, 1], [1, -1, 1], [-1, 1, -1]])
-    assert env._is_draw()
-    assert not env._check_winner(1)
-    assert not env._check_winner(-1)
-
-def test_game_over_on_win(env):
-    """Test that the game_over flag is set on a win."""
-    env.step(0)  # P1
-    env.step(3)  # P2
-    env.step(1)  # P1
-    env.step(4)  # P2
-    _, _, terminated, _, _ = env.step(2)  # P1 wins
-    assert terminated
-    assert env.game_over
-    assert env.winner == 1
-
-def test_reset(env):
-    """Test if the environment resets correctly."""
-    env.step(0)
-    env.step(1)
-    env.reset()
-    assert np.all(env.board == np.zeros((3, 3)))
-    assert env.current_player == 1
-    assert not env.game_over
-    assert env.winner is None