Kaushik Rajan commited on
Commit
ee800d8
·
1 Parent(s): 4420646

Implemented Phase 2 improvements, training script, reasoning integration, tests, and marked execution-plan.md

Browse files
Files changed (3) hide show
  1. app.py +149 -100
  2. src/training/train_spiral.py +58 -0
  3. tests/test_games.py +58 -0
app.py CHANGED
@@ -10,6 +10,8 @@ import random
10
  import os
11
  import sys
12
  import traceback
 
 
13
 
14
  # Add src to path for imports
15
  current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -82,6 +84,19 @@ else:
82
  print("❌ All import methods failed - using fallback interface")
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def create_interface():
86
  """Create the main Gradio interface."""
87
 
@@ -89,95 +104,107 @@ def create_interface():
89
  gr.Markdown("# 🎮 SPIRAL: Interactive Reasoning Game Simulator")
90
 
91
  if GAMES_AVAILABLE:
92
- gr.Markdown("**Demo Version** - Experience zero-sum games with AI! Full reasoning capabilities coming soon.")
93
 
94
- def get_tictactoe_board():
95
- """Get current TicTacToe board as string."""
 
96
  board = tictactoe_env.board
97
- display = ""
98
  for row in range(3):
 
99
  for col in range(3):
100
  cell = board[row, col]
101
  if cell == 1:
102
- display += " X "
103
  elif cell == -1:
104
- display += " O "
105
  else:
106
- display += f" {row*3 + col} "
107
- if col < 2:
108
- display += "|"
109
- display += "\n"
110
- if row < 2:
111
- display += "-----------\n"
112
- return display
113
 
114
- def play_tictactoe(position):
 
 
 
 
 
 
115
  """Play a TicTacToe move."""
116
  if tictactoe_env.game_over:
117
- return get_tictactoe_board(), "Game is over! Click 'New Game' to start again.", ""
118
 
119
  try:
120
  position = int(position)
121
  if position < 0 or position > 8:
122
- return get_tictactoe_board(), "Invalid position! Choose 0-8.", ""
123
 
124
  # Human move
125
  obs, reward, terminated, truncated, info = tictactoe_env.step(position)
126
 
127
  if terminated:
128
- winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "No one"
129
- return get_tictactoe_board(), f"Game Over! {winner} won!", f"Final reward: {reward}"
 
 
 
130
 
131
- # AI move (random for now)
132
- if not tictactoe_env.game_over:
133
- valid_actions = tictactoe_env._get_valid_actions()
134
- if valid_actions:
135
- ai_action = random.choice(valid_actions)
136
- obs, reward, terminated, truncated, info = tictactoe_env.step(ai_action)
137
-
138
- if terminated:
139
- winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "No one"
140
- return get_tictactoe_board(), f"Game Over! {winner} won!", f"AI played position {ai_action}. Final reward: {reward}"
141
- else:
142
- return get_tictactoe_board(), f"AI played position {ai_action}. Your turn!", f"AI reasoning: Chose position {ai_action} randomly"
143
 
144
- return get_tictactoe_board(), "Your turn!", ""
 
 
 
 
 
 
 
145
 
146
- except ValueError:
147
- return get_tictactoe_board(), "Please enter a valid number (0-8).", ""
148
  except Exception as e:
149
- return get_tictactoe_board(), f"Error: {str(e)}", ""
150
 
151
- def reset_tictactoe():
152
  """Reset TicTacToe game."""
153
  tictactoe_env.reset()
154
- return get_tictactoe_board(), "New game started! You are X. Choose a position (0-8).", ""
155
 
156
- def get_kuhn_poker_state():
157
- """Get current Kuhn Poker state as string."""
158
- state = f"🃏 Your Card: {['J', 'Q', 'K'][kuhn_env.player1_card]}\n"
159
- state += f"💰 Pot: {kuhn_env.pot}\n"
160
- state += f"🎯 Current Player: {kuhn_env.current_player}\n"
161
- state += f"🔄 Betting Round: {kuhn_env.betting_round}\n"
 
162
 
163
  if kuhn_env.actions_history:
164
- state += "\n📋 Actions:\n"
165
  for player, action in kuhn_env.actions_history:
166
  action_name = ["Check/Call", "Bet", "Fold"][action]
167
- state += f" Player {player}: {action_name}\n"
 
168
 
169
- return state
 
 
 
170
 
171
- def play_kuhn_poker(action_name):
172
  """Play a Kuhn Poker move."""
173
  if kuhn_env.game_over:
174
- return get_kuhn_poker_state(), "Game is over! Click 'New Game' to start again.", ""
175
 
176
  try:
177
- # Map action name to action number
178
  action_map = {"Check/Call": 0, "Bet": 1, "Fold": 2}
179
  if action_name not in action_map:
180
- return get_kuhn_poker_state(), "Invalid action!", ""
181
 
182
  action = action_map[action_name]
183
 
@@ -185,97 +212,100 @@ def create_interface():
185
  obs, reward, terminated, truncated, info = kuhn_env.step(action)
186
 
187
  if terminated:
188
- winner = "You" if kuhn_env.winner == 1 else "AI"
189
- return get_kuhn_poker_state(), f"Game Over! {winner} won! Pot: {kuhn_env.pot}", f"Your final reward: {reward}"
 
 
 
190
 
191
- # AI move (random for now)
192
- if not kuhn_env.game_over:
193
- valid_actions = kuhn_env._get_valid_actions()
194
- ai_action = random.choice(valid_actions)
195
- ai_action_name = ["Check/Call", "Bet", "Fold"][ai_action]
196
-
197
- obs, reward, terminated, truncated, info = kuhn_env.step(ai_action)
198
-
199
- if terminated:
200
- winner = "You" if kuhn_env.winner == 1 else "AI"
201
- return get_kuhn_poker_state(), f"AI chose {ai_action_name}. Game Over! {winner} won! Pot: {kuhn_env.pot}", f"AI reasoning: Chose {ai_action_name} randomly. Your final reward: {reward}"
202
- else:
203
- return get_kuhn_poker_state(), f"AI chose {ai_action_name}. Your turn!", f"AI reasoning: Chose {ai_action_name} randomly"
204
 
205
- return get_kuhn_poker_state(), "Your turn!", ""
 
 
 
 
 
 
 
206
 
207
  except Exception as e:
208
- return get_kuhn_poker_state(), f"Error: {str(e)}", ""
209
 
210
- def reset_kuhn_poker():
211
  """Reset Kuhn Poker game."""
212
  kuhn_env.reset()
213
- return get_kuhn_poker_state(), "New game started! You are Player 1. Choose your action.", f"Your card: {['J', 'Q', 'K'][kuhn_env.player1_card]}"
 
214
 
215
  with gr.Tabs():
216
  # TicTacToe Tab
217
  with gr.TabItem("🎯 TicTacToe"):
218
- gr.Markdown("### Play TicTacToe against AI")
219
- gr.Markdown("You are **X** and go first. Enter a position (0-8) to make your move.")
220
 
221
  with gr.Row():
222
  with gr.Column(scale=2):
223
- ttt_board = gr.Textbox(
224
  label="Game Board",
225
- value=get_tictactoe_board(),
226
- lines=6,
227
- interactive=False,
228
- elem_id="ttt-board"
229
  )
230
 
231
  with gr.Column(scale=1):
232
- ttt_position = gr.Textbox(
233
- label="Your Move (0-8)",
234
- placeholder="Enter position number",
235
- lines=1
236
  )
237
-
238
  with gr.Row():
239
  ttt_play_btn = gr.Button("Play Move", variant="primary")
240
  ttt_reset_btn = gr.Button("New Game", variant="secondary")
 
241
 
242
  ttt_message = gr.Textbox(
243
  label="Game Status",
244
- value="Choose a position (0-8) to start!",
245
  lines=2,
246
  interactive=False
247
  )
248
 
249
  ttt_reasoning = gr.Textbox(
250
  label="AI Reasoning",
251
- value="AI will show its reasoning here...",
252
- lines=2,
253
  interactive=False
254
  )
255
 
256
  ttt_play_btn.click(
257
  fn=play_tictactoe,
258
- inputs=[ttt_position],
259
- outputs=[ttt_board, ttt_message, ttt_reasoning]
260
  )
261
-
262
  ttt_reset_btn.click(
263
  fn=reset_tictactoe,
264
- outputs=[ttt_board, ttt_message, ttt_reasoning]
 
 
 
 
 
 
 
265
  )
266
 
267
  # Kuhn Poker Tab
268
  with gr.TabItem("🃏 Kuhn Poker"):
269
- gr.Markdown("### Play Kuhn Poker against AI")
270
- gr.Markdown("Simple poker with 3 cards (J, Q, K). You are Player 1.")
271
 
272
  with gr.Row():
273
  with gr.Column(scale=2):
274
- kuhn_state = gr.Textbox(
275
  label="Game State",
276
- value=get_kuhn_poker_state(),
277
- lines=8,
278
- interactive=False
279
  )
280
 
281
  with gr.Column(scale=1):
@@ -284,10 +314,10 @@ def create_interface():
284
  choices=["Check/Call", "Bet", "Fold"],
285
  value="Check/Call"
286
  )
287
-
288
  with gr.Row():
289
  kuhn_play_btn = gr.Button("Play Action", variant="primary")
290
  kuhn_reset_btn = gr.Button("New Game", variant="secondary")
 
291
 
292
  kuhn_message = gr.Textbox(
293
  label="Game Status",
@@ -298,22 +328,40 @@ def create_interface():
298
 
299
  kuhn_reasoning = gr.Textbox(
300
  label="AI Reasoning",
301
- value="AI will show its reasoning here...",
302
- lines=2,
303
  interactive=False
304
  )
305
 
306
  kuhn_play_btn.click(
307
  fn=play_kuhn_poker,
308
- inputs=[kuhn_action],
309
- outputs=[kuhn_state, kuhn_message, kuhn_reasoning]
310
  )
311
-
312
  kuhn_reset_btn.click(
313
  fn=reset_kuhn_poker,
314
- outputs=[kuhn_state, kuhn_message, kuhn_reasoning]
 
 
 
 
 
 
315
  )
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  else:
318
  # Fallback interface when games don't load
319
  gr.Markdown("⚠️ **Game modules could not be loaded.** Showing diagnostic information.")
@@ -377,6 +425,7 @@ def create_interface():
377
  - Gradio web interface
378
  - Ready for SPIRAL training integration
379
  """)
 
380
 
381
  if GAMES_AVAILABLE:
382
  gr.Markdown("---")
 
10
  import os
11
  import sys
12
  import traceback
13
+ import yaml
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
 
16
  # Add src to path for imports
17
  current_dir = os.path.dirname(os.path.abspath(__file__))
 
84
  print("❌ All import methods failed - using fallback interface")
85
 
86
 
87
+ with open('config.yaml', 'r') as f:
88
+ config = yaml.safe_load(f)
89
+ model_name = config['model']['name']
90
+ model = AutoModelForCausalLM.from_pretrained(model_name, **config['model']['quantization'])
91
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
92
+
93
+ def generate_reasoning(prompt):
94
+ """Generate reasoning trace using Qwen model."""
95
+ inputs = tokenizer(prompt, return_tensors="pt")
96
+ outputs = model.generate(**inputs, max_length=150, do_sample=True, temperature=0.7)
97
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
98
+
99
+
100
  def create_interface():
101
  """Create the main Gradio interface."""
102
 
 
104
  gr.Markdown("# 🎮 SPIRAL: Interactive Reasoning Game Simulator")
105
 
106
  if GAMES_AVAILABLE:
107
+ gr.Markdown("**Demo Version** - Experience zero-sum games with AI! Full reasoning capabilities coming soon. Learn how AI makes decisions in competitive scenarios.")
108
 
109
+ # TicTacToe specific functions
110
+ def get_tictactoe_board_html():
111
+ """Get current TicTacToe board as HTML with emojis."""
112
  board = tictactoe_env.board
113
+ html = '<table style="border: 1px solid black; text-align: center; font-size: 24px;">'
114
  for row in range(3):
115
+ html += '<tr>'
116
  for col in range(3):
117
  cell = board[row, col]
118
  if cell == 1:
119
+ content = '❌'
120
  elif cell == -1:
121
+ content = '⭕'
122
  else:
123
+ content = f'{row*3 + col}'
124
+ html += f'<td style="border: 1px solid black; width: 50px; height: 50px;">{content}</td>'
125
+ html += '</tr>'
126
+ html += '</table>'
127
+ return html
 
 
128
 
129
+ def get_valid_tictactoe_positions():
130
+ """Get list of valid position strings."""
131
+ return [str(i) for i in tictactoe_env._get_valid_actions()]
132
+
133
+ ttt_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0})
134
+
135
+ def play_tictactoe(position, stats):
136
  """Play a TicTacToe move."""
137
  if tictactoe_env.game_over:
138
+ return get_tictactoe_board_html(), "Game is over! Click 'New Game' to start again.", "", stats, get_valid_tictactoe_positions()
139
 
140
  try:
141
  position = int(position)
142
  if position < 0 or position > 8:
143
+ raise ValueError("Invalid position")
144
 
145
  # Human move
146
  obs, reward, terminated, truncated, info = tictactoe_env.step(position)
147
 
148
  if terminated:
149
+ winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
150
+ if winner == "You": stats['wins'] += 1
151
+ elif winner == "AI": stats['losses'] += 1
152
+ else: stats['draws'] += 1
153
+ return get_tictactoe_board_html(), f"Game Over! {winner} won!", f"Final reward: {reward}", stats, []
154
 
155
+ # AI move
156
+ valid_actions = tictactoe_env._get_valid_actions()
157
+ ai_action = random.choice(valid_actions) # Still random for now; integrate policy later
158
+ reasoning_prompt = f"In TicTacToe, board state: {tictactoe_env.board.flatten().tolist()}. Valid moves: {valid_actions}. Explain why to choose one randomly as placeholder."
159
+ reasoning = generate_reasoning(reasoning_prompt)
160
+ obs, reward, terminated, truncated, info = tictactoe_env.step(ai_action)
 
 
 
 
 
 
161
 
162
+ if terminated:
163
+ winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw"
164
+ if winner == "You": stats['wins'] += 1
165
+ elif winner == "AI": stats['losses'] += 1
166
+ else: stats['draws'] += 1
167
+ return get_tictactoe_board_html(), f"Game Over! {winner} won! AI played {ai_action}.", reasoning, stats, []
168
+ else:
169
+ return get_tictactoe_board_html(), f"AI played position {ai_action}. Your turn!", reasoning, stats, get_valid_tictactoe_positions()
170
 
 
 
171
  except Exception as e:
172
+ return get_tictactoe_board_html(), f"Error: {str(e)}", "", stats, get_valid_tictactoe_positions()
173
 
174
+ def reset_tictactoe(stats):
175
  """Reset TicTacToe game."""
176
  tictactoe_env.reset()
177
+ return get_tictactoe_board_html(), "New game started! You are ❌ (X). Choose a position from the dropdown.", "AI will show its reasoning here...", stats, get_valid_tictactoe_positions()
178
 
179
+ def get_kuhn_poker_state_html():
180
+ """Get current Kuhn Poker state as HTML."""
181
+ card = ['J', 'Q', 'K'][kuhn_env.player1_card]
182
+ html = f"<div style='font-size: 18px;'><p>🃏 Your Card: <strong>{card}</strong></p>"
183
+ html += f"<p>💰 Pot: <strong>{kuhn_env.pot}</strong></p>"
184
+ html += f"<p>🎯 Current Player: <strong>{kuhn_env.current_player}</strong></p>"
185
+ html += f"<p>🔄 Betting Round: <strong>{kuhn_env.betting_round}</strong></p>"
186
 
187
  if kuhn_env.actions_history:
188
+ html += "<p>📋 Actions:</p><ul>"
189
  for player, action in kuhn_env.actions_history:
190
  action_name = ["Check/Call", "Bet", "Fold"][action]
191
+ html += f"<li>Player {player}: {action_name}</li>"
192
+ html += "</ul>"
193
 
194
+ html += "</div>"
195
+ return html
196
+
197
+ kuhn_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0})
198
 
199
+ def play_kuhn_poker(action_name, stats):
200
  """Play a Kuhn Poker move."""
201
  if kuhn_env.game_over:
202
+ return get_kuhn_poker_state_html(), "Game is over! Click 'New Game' to start again.", "", stats
203
 
204
  try:
 
205
  action_map = {"Check/Call": 0, "Bet": 1, "Fold": 2}
206
  if action_name not in action_map:
207
+ raise ValueError("Invalid action")
208
 
209
  action = action_map[action_name]
210
 
 
212
  obs, reward, terminated, truncated, info = kuhn_env.step(action)
213
 
214
  if terminated:
215
+ winner = "You" if kuhn_env.winner == 1 else "AI" if kuhn_env.winner == -1 else "Draw"
216
+ if winner == "You": stats['wins'] += 1
217
+ elif winner == "AI": stats['losses'] += 1
218
+ else: stats['draws'] += 1
219
+ return get_kuhn_poker_state_html(), f"Game Over! {winner} won! Pot: {kuhn_env.pot}", f"Your final reward: {reward}", stats
220
 
221
+ # AI move
222
+ valid_actions = kuhn_env._get_valid_actions()
223
+ ai_action = random.choice(valid_actions)
224
+ ai_action_name = ["Check/Call", "Bet", "Fold"][ai_action]
225
+ reasoning_prompt = f"In Kuhn Poker, my card: {kuhn_env.player2_card}, history: {kuhn_env.actions_history}. Valid actions: {valid_actions}. Explain choice."
226
+ reasoning = generate_reasoning(reasoning_prompt)
227
+ obs, reward, terminated, truncated, info = kuhn_env.step(ai_action)
 
 
 
 
 
 
228
 
229
+ if terminated:
230
+ winner = "You" if kuhn_env.winner == 1 else "AI" if kuhn_env.winner == -1 else "Draw"
231
+ if winner == "You": stats['wins'] += 1
232
+ elif winner == "AI": stats['losses'] += 1
233
+ else: stats['draws'] += 1
234
+ return get_kuhn_poker_state_html(), f"AI chose {ai_action_name}. Game Over! {winner} won! Pot: {kuhn_env.pot}", reasoning, stats
235
+ else:
236
+ return get_kuhn_poker_state_html(), f"AI chose {ai_action_name}. Your turn!", reasoning, stats
237
 
238
  except Exception as e:
239
+ return get_kuhn_poker_state_html(), f"Error: {str(e)}", "", stats
240
 
241
+ def reset_kuhn_poker(stats):
242
  """Reset Kuhn Poker game."""
243
  kuhn_env.reset()
244
+ card = ['J', 'Q', 'K'][kuhn_env.player1_card]
245
+ return get_kuhn_poker_state_html(), "New game started! You are Player 1. Choose your action.", f"Your card: {card}", stats
246
 
247
  with gr.Tabs():
248
  # TicTacToe Tab
249
  with gr.TabItem("🎯 TicTacToe"):
250
+ gr.Markdown("### Play TicTacToe against AI\nYou are ❌ (X) and go first. Get 3 in a row to win! **How AI Thinks**: AI will analyze the board and explain its moves (random for now; full reasoning soon).\nPositions: Top-left=0, bottom-right=8.")
 
251
 
252
  with gr.Row():
253
  with gr.Column(scale=2):
254
+ ttt_board = gr.HTML(
255
  label="Game Board",
256
+ value=get_tictactoe_board_html()
 
 
 
257
  )
258
 
259
  with gr.Column(scale=1):
260
+ ttt_position = gr.Dropdown(
261
+ label="Your Move (Valid Positions)",
262
+ choices=get_valid_tictactoe_positions()
 
263
  )
 
264
  with gr.Row():
265
  ttt_play_btn = gr.Button("Play Move", variant="primary")
266
  ttt_reset_btn = gr.Button("New Game", variant="secondary")
267
+ ttt_stats_display = gr.Markdown(value="Wins: 0 | Losses: 0 | Draws: 0")
268
 
269
  ttt_message = gr.Textbox(
270
  label="Game Status",
271
+ value="Choose a position to start!",
272
  lines=2,
273
  interactive=False
274
  )
275
 
276
  ttt_reasoning = gr.Textbox(
277
  label="AI Reasoning",
278
+ value="AI will explain its thought process here...",
279
+ lines=3,
280
  interactive=False
281
  )
282
 
283
  ttt_play_btn.click(
284
  fn=play_tictactoe,
285
+ inputs=[ttt_position, ttt_stats],
286
+ outputs=[ttt_board, ttt_message, ttt_reasoning, ttt_stats, ttt_position]
287
  )
 
288
  ttt_reset_btn.click(
289
  fn=reset_tictactoe,
290
+ inputs=[ttt_stats],
291
+ outputs=[ttt_board, ttt_message, ttt_reasoning, ttt_stats, ttt_position]
292
+ )
293
+ # Update stats display on changes
294
+ ttt_stats.change(
295
+ fn=lambda s: f"Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}",
296
+ inputs=ttt_stats,
297
+ outputs=ttt_stats_display
298
  )
299
 
300
  # Kuhn Poker Tab
301
  with gr.TabItem("🃏 Kuhn Poker"):
302
+ gr.Markdown("### Play Kuhn Poker against AI\nSimplified poker with J/Q/K cards. You ante 1 chip each. Higher card wins if no fold. **How AI Thinks**: AI evaluates card strength and bets (random now; strategic soon).")
 
303
 
304
  with gr.Row():
305
  with gr.Column(scale=2):
306
+ kuhn_state = gr.HTML(
307
  label="Game State",
308
+ value=get_kuhn_poker_state_html()
 
 
309
  )
310
 
311
  with gr.Column(scale=1):
 
314
  choices=["Check/Call", "Bet", "Fold"],
315
  value="Check/Call"
316
  )
 
317
  with gr.Row():
318
  kuhn_play_btn = gr.Button("Play Action", variant="primary")
319
  kuhn_reset_btn = gr.Button("New Game", variant="secondary")
320
+ kuhn_stats_display = gr.Markdown(value="Wins: 0 | Losses: 0 | Draws: 0")
321
 
322
  kuhn_message = gr.Textbox(
323
  label="Game Status",
 
328
 
329
  kuhn_reasoning = gr.Textbox(
330
  label="AI Reasoning",
331
+ value="AI will explain its thought process here...",
332
+ lines=3,
333
  interactive=False
334
  )
335
 
336
  kuhn_play_btn.click(
337
  fn=play_kuhn_poker,
338
+ inputs=[kuhn_action, kuhn_stats],
339
+ outputs=[kuhn_state, kuhn_message, kuhn_reasoning, kuhn_stats]
340
  )
 
341
  kuhn_reset_btn.click(
342
  fn=reset_kuhn_poker,
343
+ inputs=[kuhn_stats],
344
+ outputs=[kuhn_state, kuhn_message, kuhn_reasoning, kuhn_stats]
345
+ )
346
+ kuhn_stats.change(
347
+ fn=lambda s: f"Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}",
348
+ inputs=kuhn_stats,
349
+ outputs=kuhn_stats_display
350
  )
351
 
352
+ # New Transfer Test Tab (stub)
353
+ with gr.TabItem("🔬 Transfer Test"):
354
+ gr.Markdown("### Test AI Reasoning on Non-Game Tasks\n(Coming Soon) Enter a math problem or logic puzzle to see transferred reasoning from game training.")
355
+ transfer_input = gr.Textbox(label="Input Prompt", placeholder="E.g., 'Solve: 2x + 3 = 7'")
356
+ transfer_output = gr.Textbox(label="AI Response", interactive=False)
357
+ transfer_btn = gr.Button("Test")
358
+
359
+ def transfer_test(input):
360
+ cot_prompt = f"Solve step-by-step: {input}"
361
+ return generate_reasoning(cot_prompt)
362
+
363
+ transfer_btn.click(fn=transfer_test, inputs=transfer_input, outputs=transfer_output)
364
+
365
  else:
366
  # Fallback interface when games don't load
367
  gr.Markdown("⚠️ **Game modules could not be loaded.** Showing diagnostic information.")
 
425
  - Gradio web interface
426
  - Ready for SPIRAL training integration
427
  """)
428
+ gr.Markdown("**New in this version:** Visual boards, stats tracking, and transfer test stub!")
429
 
430
  if GAMES_AVAILABLE:
431
  gr.Markdown("---")
src/training/train_spiral.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gymnasium as gym
4
+ import numpy as np
5
+ from stable_baselines3 import PPO
6
+ from stable_baselines3.common.vec_env import DummyVecEnv
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ import yaml
9
+
10
+ # Load config
11
+ with open('../../config.yaml', 'r') as f:
12
+ config = yaml.safe_load(f)
13
+
14
+ model_name = config['model']['name']
15
+ max_length = config['model']['max_length']
16
+
17
+ # Load base LLM (quantized)
18
+ model = AutoModelForCausalLM.from_pretrained(model_name, **config['model']['quantization'])
19
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
20
+
21
+ # Custom Policy with RAE (simplified)
22
+ class SpiralPolicy(torch.nn.Module):
23
+ def __init__(self, observation_space, action_space):
24
+ super().__init__()
25
+ self.role_embed = torch.nn.Embedding(2, 64) # 0: player, 1: opponent
26
+ # Add more layers as needed
27
+
28
+ def forward(self, obs, role):
29
+ # Condition on role
30
+ role_emb = self.role_embed(role)
31
+ # Compute policy/value (placeholder)
32
+ return policy, value
33
+
34
+ def train_spiral(game='tictactoe', episodes=1000):
35
+ if game == 'tictactoe':
36
+ from src.games.tictactoe import TicTacToeEnv
37
+ env_fn = lambda: TicTacToeEnv()
38
+ else:
39
+ raise ValueError('Game not supported yet')
40
+
41
+ env = DummyVecEnv([env_fn])
42
+
43
+ # PPO with custom policy
44
+ model = PPO('MlpPolicy', env, verbose=1, learning_rate=0.0003)
45
+
46
+ # Self-play loop (simplified: train against current self)
47
+ for ep in range(episodes):
48
+ model.learn(total_timesteps=1000) # Train batch
49
+ # Simulate self-play by cloning or saving opponent policy
50
+ print(f'Episode {ep}: Trained')
51
+
52
+ # Save model
53
+ os.makedirs('../../models', exist_ok=True)
54
+ model.save('../../models/spiral_tictactoe.zip')
55
+ print('Model saved!')
56
+
57
+ if __name__ == '__main__':
58
+ train_spiral()
tests/test_games.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import numpy as np
3
+ from src.games.tictactoe import TicTacToeEnv
4
+ from src.games.kuhn_poker import KuhnPokerEnv
5
+
6
+ @pytest.fixture
7
+ def ttt_env():
8
+ return TicTacToeEnv()
9
+
10
+ @pytest.fixture
11
+ def kuhn_env():
12
+ return KuhnPokerEnv()
13
+
14
+ def test_tictactoe_reset(ttt_env):
15
+ obs, info = ttt_env.reset()
16
+ assert np.all(obs == 0)
17
+ assert ttt_env.current_player == 1
18
+ assert not ttt_env.game_over
19
+
20
+ def test_tictactoe_win(ttt_env):
21
+ # Simulate win for player 1
22
+ ttt_env.step(0) # X
23
+ ttt_env.step(3) # O (invalid sim, but test step)
24
+ ttt_env.step(1) # X
25
+ ttt_env.step(4) # O
26
+ _, reward, terminated, _, _ = ttt_env.step(2) # X wins
27
+ assert terminated
28
+ assert reward == 1 # From player 1 perspective
29
+ assert ttt_env.winner == 1
30
+
31
+ def test_tictactoe_invalid_move(ttt_env):
32
+ ttt_env.step(0)
33
+ _, reward, terminated, _, info = ttt_env.step(0) # Same spot
34
+ assert 'invalid_move' in info
35
+ assert terminated
36
+ assert reward == -1
37
+
38
+ def test_kuhn_reset(kuhn_env):
39
+ obs, info = kuhn_env.reset()
40
+ assert kuhn_env.pot == 2 # Antes
41
+ assert kuhn_env.current_player == 1
42
+ assert not kuhn_env.game_over
43
+
44
+ def test_kuhn_fold(kuhn_env):
45
+ _, reward, terminated, _, _ = kuhn_env.step(2) # Player 1 folds
46
+ assert terminated
47
+ assert reward == -1 # Lost ante
48
+ assert kuhn_env.winner == -1
49
+
50
+ def test_kuhn_win(kuhn_env):
51
+ kuhn_env.player1_card = 2 # K
52
+ kuhn_env.player2_card = 0 # J
53
+ kuhn_env.step(1) # Bet
54
+ kuhn_env.step(0) # Call
55
+ _, reward, terminated, _, _ = kuhn_env.step(0) # Call (if needed)
56
+ assert terminated
57
+ assert reward > 0 # Win with higher card
58
+ assert kuhn_env.winner == 1