Kaushik Rajan committed on
Commit
b1670f3
·
1 Parent(s): 6be63cd

feat(tictactoe): Refine UI, implement Minimax AI, and add tests

Browse files
Files changed (4) hide show
  1. .gitignore +9 -1
  2. app.py +109 -200
  3. requirements.txt +3 -1
  4. tests/test_games.py +68 -48
.gitignore CHANGED
@@ -230,4 +230,12 @@ gradio_cached_examples/
230
  execution-plan.md
231
 
232
  # Research paper images - not needed in repo
233
- research-paper-snips/
 
 
 
 
 
 
 
 
 
230
  execution-plan.md
231
 
232
  # Research paper images - not needed in repo
233
+ research-paper-snips/
234
+
235
+ # huggingface
236
+ *.sagemaker_notebook.ipynb
237
+
238
+ # virtualenv
239
+ .venv/
240
+ venv/
241
+ ENV/
app.py CHANGED
@@ -76,10 +76,10 @@ if GAMES_AVAILABLE:
76
  try:
77
  # Test instantiation
78
  tictactoe_env = TicTacToeEnv()
79
- kuhn_env = KuhnPokerEnv()
80
- print("✅ Game environments created successfully")
81
  except Exception as e:
82
- print(f"❌ Error creating game environments: {e}")
83
  print("📋 Full traceback:", traceback.format_exc())
84
  GAMES_AVAILABLE = False
85
  else:
@@ -105,10 +105,9 @@ def create_interface():
105
 
106
  with gr.Blocks(title="SPIRAL: Interactive Reasoning Game Simulator", theme=gr.themes.Soft()) as demo:
107
  gr.Markdown("# 🎮 SPIRAL: Interactive Reasoning Game Simulator")
108
-
 
109
  if GAMES_AVAILABLE:
110
- gr.Markdown("**Demo Version** - Experience zero-sum games with AI! Full reasoning capabilities coming soon. Learn how AI makes decisions in competitive scenarios.")
111
-
112
  # TicTacToe specific functions
113
  def get_tictactoe_board_html():
114
  """Get current TicTacToe board as HTML with emojis."""
@@ -135,11 +134,46 @@ def create_interface():
135
 
136
  ttt_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0})
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def play_tictactoe(position, stats):
139
  """Play a TicTacToe move."""
140
  if tictactoe_env.game_over:
141
- return get_tictactoe_board_html(), "Game is over! Click 'New Game' to start again.", "", stats, get_valid_tictactoe_positions()
142
-
 
143
  try:
144
  position = int(position)
145
  if position < 0 or position > 8:
@@ -153,12 +187,23 @@ def create_interface():
153
  if winner == "You": stats['wins'] += 1
154
  elif winner == "AI": stats['losses'] += 1
155
  else: stats['draws'] += 1
156
- return get_tictactoe_board_html(), f"Game Over! {winner} won!", f"Final reward: {reward}", stats, []
157
-
 
 
 
 
158
  # AI move
159
- valid_actions = tictactoe_env._get_valid_actions()
160
- ai_action = random.choice(valid_actions) # Still random for now; integrate policy later
161
- reasoning_prompt = f"In TicTacToe, board state: {tictactoe_env.board.flatten().tolist()}. Valid moves: {valid_actions}. Explain why to choose one randomly as placeholder."
 
 
 
 
 
 
 
162
  reasoning = generate_reasoning(reasoning_prompt)
163
  obs, reward, terminated, truncated, info = tictactoe_env.step(ai_action)
164
 
@@ -167,204 +212,71 @@ def create_interface():
167
  if winner == "You": stats['wins'] += 1
168
  elif winner == "AI": stats['losses'] += 1
169
  else: stats['draws'] += 1
170
- return get_tictactoe_board_html(), f"Game Over! {winner} won! AI played {ai_action}.", reasoning, stats, []
171
  else:
172
- return get_tictactoe_board_html(), f"AI played position {ai_action}. Your turn!", reasoning, stats, get_valid_tictactoe_positions()
173
 
174
  except Exception as e:
175
- return get_tictactoe_board_html(), f"Error: {str(e)}", "", stats, get_valid_tictactoe_positions()
176
 
177
  def reset_tictactoe(stats):
178
  """Reset TicTacToe game."""
179
  tictactoe_env.reset()
180
  return get_tictactoe_board_html(), "New game started! You are ❌ (X). Choose a position from the dropdown.", "AI will show its reasoning here...", stats, get_valid_tictactoe_positions()
181
 
182
- def get_kuhn_poker_state_html():
183
- """Get current Kuhn Poker state as HTML."""
184
- card = ['J', 'Q', 'K'][kuhn_env.player1_card]
185
- html = f"<div style='font-size: 18px;'><p>🃏 Your Card: <strong>{card}</strong></p>"
186
- html += f"<p>💰 Pot: <strong>{kuhn_env.pot}</strong></p>"
187
- html += f"<p>🎯 Current Player: <strong>{kuhn_env.current_player}</strong></p>"
188
- html += f"<p>🔄 Betting Round: <strong>{kuhn_env.betting_round}</strong></p>"
189
-
190
- if kuhn_env.actions_history:
191
- html += "<p>📋 Actions:</p><ul>"
192
- for player, action in kuhn_env.actions_history:
193
- action_name = ["Check/Call", "Bet", "Fold"][action]
194
- html += f"<li>Player {player}: {action_name}</li>"
195
- html += "</ul>"
196
-
197
- html += "</div>"
198
- return html
199
-
200
- kuhn_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0})
201
-
202
- def play_kuhn_poker(action_name, stats):
203
- """Play a Kuhn Poker move."""
204
- if kuhn_env.game_over:
205
- return get_kuhn_poker_state_html(), "Game is over! Click 'New Game' to start again.", "", stats
206
-
207
- try:
208
- action_map = {"Check/Call": 0, "Bet": 1, "Fold": 2}
209
- if action_name not in action_map:
210
- raise ValueError("Invalid action")
211
-
212
- action = action_map[action_name]
213
-
214
- # Human move
215
- obs, reward, terminated, truncated, info = kuhn_env.step(action)
216
-
217
- if terminated:
218
- winner = "You" if kuhn_env.winner == 1 else "AI" if kuhn_env.winner == -1 else "Draw"
219
- if winner == "You": stats['wins'] += 1
220
- elif winner == "AI": stats['losses'] += 1
221
- else: stats['draws'] += 1
222
- return get_kuhn_poker_state_html(), f"Game Over! {winner} won! Pot: {kuhn_env.pot}", f"Your final reward: {reward}", stats
223
-
224
- # AI move
225
- valid_actions = kuhn_env._get_valid_actions()
226
- ai_action = random.choice(valid_actions)
227
- ai_action_name = ["Check/Call", "Bet", "Fold"][ai_action]
228
- reasoning_prompt = f"In Kuhn Poker, my card: {kuhn_env.player2_card}, history: {kuhn_env.actions_history}. Valid actions: {valid_actions}. Explain choice."
229
- reasoning = generate_reasoning(reasoning_prompt)
230
- obs, reward, terminated, truncated, info = kuhn_env.step(ai_action)
231
-
232
- if terminated:
233
- winner = "You" if kuhn_env.winner == 1 else "AI" if kuhn_env.winner == -1 else "Draw"
234
- if winner == "You": stats['wins'] += 1
235
- elif winner == "AI": stats['losses'] += 1
236
- else: stats['draws'] += 1
237
- return get_kuhn_poker_state_html(), f"AI chose {ai_action_name}. Game Over! {winner} won! Pot: {kuhn_env.pot}", reasoning, stats
238
- else:
239
- return get_kuhn_poker_state_html(), f"AI chose {ai_action_name}. Your turn!", reasoning, stats
240
-
241
- except Exception as e:
242
- return get_kuhn_poker_state_html(), f"Error: {str(e)}", "", stats
243
-
244
- def reset_kuhn_poker(stats):
245
- """Reset Kuhn Poker game."""
246
- kuhn_env.reset()
247
- card = ['J', 'Q', 'K'][kuhn_env.player1_card]
248
- return get_kuhn_poker_state_html(), "New game started! You are Player 1. Choose your action.", f"Your card: {card}", stats
249
 
250
- with gr.Tabs():
251
- # TicTacToe Tab
252
- with gr.TabItem("🎯 TicTacToe"):
253
- gr.Markdown("### Play TicTacToe against AI\nYou are ❌ (X) and go first. Get 3 in a row to win! **How AI Thinks**: AI will analyze the board and explain its moves (random for now; full reasoning soon).\nPositions: Top-left=0, bottom-right=8.")
254
-
255
- with gr.Row():
256
- with gr.Column(scale=2):
257
- ttt_board = gr.HTML(
258
- label="Game Board",
259
- value=get_tictactoe_board_html()
260
- )
261
-
262
- with gr.Column(scale=1):
263
- ttt_position = gr.Dropdown(
264
- label="Your Move (Valid Positions)",
265
- choices=get_valid_tictactoe_positions()
266
- )
267
- with gr.Row():
268
- ttt_play_btn = gr.Button("Play Move", variant="primary")
269
- ttt_reset_btn = gr.Button("New Game", variant="secondary")
270
- ttt_stats_display = gr.Markdown(value="Wins: 0 | Losses: 0 | Draws: 0")
271
-
272
- ttt_message = gr.Textbox(
273
- label="Game Status",
274
- value="Choose a position to start!",
275
- lines=2,
276
- interactive=False
277
- )
278
-
279
- ttt_reasoning = gr.Textbox(
280
- label="AI Reasoning",
281
- value="AI will explain its thought process here...",
282
- lines=3,
283
- interactive=False
284
  )
285
 
286
- ttt_play_btn.click(
287
- fn=play_tictactoe,
288
- inputs=[ttt_position, ttt_stats],
289
- outputs=[ttt_board, ttt_message, ttt_reasoning, ttt_stats, ttt_position]
290
- )
291
- ttt_reset_btn.click(
292
- fn=reset_tictactoe,
293
- inputs=[ttt_stats],
294
- outputs=[ttt_board, ttt_message, ttt_reasoning, ttt_stats, ttt_position]
295
  )
296
- # Update stats display on changes
297
- ttt_stats.change(
298
- fn=lambda s: f"Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}",
299
- inputs=ttt_stats,
300
- outputs=ttt_stats_display
301
- )
302
-
303
- # Kuhn Poker Tab
304
- with gr.TabItem("🃏 Kuhn Poker"):
305
- gr.Markdown("### Play Kuhn Poker against AI\nSimplified poker with J/Q/K cards. You ante 1 chip each. Higher card wins if no fold. **How AI Thinks**: AI evaluates card strength and bets (random now; strategic soon).")
306
-
307
  with gr.Row():
308
- with gr.Column(scale=2):
309
- kuhn_state = gr.HTML(
310
- label="Game State",
311
- value=get_kuhn_poker_state_html()
312
- )
313
-
314
- with gr.Column(scale=1):
315
- kuhn_action = gr.Dropdown(
316
- label="Your Action",
317
- choices=["Check/Call", "Bet", "Fold"],
318
- value="Check/Call"
319
- )
320
- with gr.Row():
321
- kuhn_play_btn = gr.Button("Play Action", variant="primary")
322
- kuhn_reset_btn = gr.Button("New Game", variant="secondary")
323
- kuhn_stats_display = gr.Markdown(value="Wins: 0 | Losses: 0 | Draws: 0")
324
-
325
- kuhn_message = gr.Textbox(
326
- label="Game Status",
327
- value="Choose your action!",
328
- lines=2,
329
- interactive=False
330
- )
331
-
332
- kuhn_reasoning = gr.Textbox(
333
- label="AI Reasoning",
334
- value="AI will explain its thought process here...",
335
- lines=3,
336
- interactive=False
337
- )
338
-
339
- kuhn_play_btn.click(
340
- fn=play_kuhn_poker,
341
- inputs=[kuhn_action, kuhn_stats],
342
- outputs=[kuhn_state, kuhn_message, kuhn_reasoning, kuhn_stats]
343
- )
344
- kuhn_reset_btn.click(
345
- fn=reset_kuhn_poker,
346
- inputs=[kuhn_stats],
347
- outputs=[kuhn_state, kuhn_message, kuhn_reasoning, kuhn_stats]
348
- )
349
- kuhn_stats.change(
350
- fn=lambda s: f"Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}",
351
- inputs=kuhn_stats,
352
- outputs=kuhn_stats_display
353
- )
354
-
355
- # New Transfer Test Tab (stub)
356
- with gr.TabItem("🔬 Transfer Test"):
357
- gr.Markdown("### Test AI Reasoning on Non-Game Tasks\n(Coming Soon) Enter a math problem or logic puzzle to see transferred reasoning from game training.")
358
- transfer_input = gr.Textbox(label="Input Prompt", placeholder="E.g., 'Solve: 2x + 3 = 7'")
359
- transfer_output = gr.Textbox(label="AI Response", interactive=False)
360
- transfer_btn = gr.Button("Test")
361
-
362
- def transfer_test(input):
363
- cot_prompt = f"Solve step-by-step: {input}"
364
- return generate_reasoning(cot_prompt)
365
-
366
- transfer_btn.click(fn=transfer_test, inputs=transfer_input, outputs=transfer_output)
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  else:
369
  # Fallback interface when games don't load
370
  gr.Markdown("⚠️ **Game modules could not be loaded.** Showing diagnostic information.")
@@ -430,10 +342,7 @@ def create_interface():
430
  """)
431
  gr.Markdown("**New in this version:** Visual boards, stats tracking, and transfer test stub!")
432
 
433
- if GAMES_AVAILABLE:
434
- gr.Markdown("---")
435
- gr.Markdown("🚧 **This is a development preview.** Full SPIRAL training and reasoning capabilities will be added in the next update!")
436
- else:
437
  gr.Markdown("---")
438
  gr.Markdown("🔄 **Dependencies are loading.** Check the diagnostic info above and refresh in a few minutes!")
439
 
 
76
  try:
77
  # Test instantiation
78
  tictactoe_env = TicTacToeEnv()
79
+ # kuhn_env = KuhnPokerEnv() # No longer needed
80
+ print("✅ Game environment created successfully")
81
  except Exception as e:
82
+ print(f"❌ Error creating game environment: {e}")
83
  print("📋 Full traceback:", traceback.format_exc())
84
  GAMES_AVAILABLE = False
85
  else:
 
105
 
106
  with gr.Blocks(title="SPIRAL: Interactive Reasoning Game Simulator", theme=gr.themes.Soft()) as demo:
107
  gr.Markdown("# 🎮 SPIRAL: Interactive Reasoning Game Simulator")
108
+ gr.Markdown("Play TicTacToe against an AI, see its step-by-step reasoning, and learn how it thinks!")
109
+
110
  if GAMES_AVAILABLE:
 
 
111
  # TicTacToe specific functions
112
  def get_tictactoe_board_html():
113
  """Get current TicTacToe board as HTML with emojis."""
 
134
 
135
  ttt_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0})
136
 
137
+ def minimax(board, player):
138
+ """Minimax algorithm to find the best move."""
139
+
140
+ # Base cases
141
+ if tictactoe_env._check_winner(1):
142
+ return -10, None
143
+ elif tictactoe_env._check_winner(-1):
144
+ return 10, None
145
+ elif tictactoe_env._is_draw():
146
+ return 0, None
147
+
148
+ best_move = None
149
+ if player == -1: # AI is player -1 (O), maximizing player
150
+ best_score = -float('inf')
151
+ for move in tictactoe_env._get_valid_actions():
152
+ row, col = divmod(move, 3)
153
+ board[row, col] = -1
154
+ score, _ = minimax(board.copy(), 1)
155
+ board[row, col] = 0 # Undo move
156
+ if score > best_score:
157
+ best_score = score
158
+ best_move = move
159
+ else: # Human is player 1 (X), minimizing player
160
+ best_score = float('inf')
161
+ for move in tictactoe_env._get_valid_actions():
162
+ row, col = divmod(move, 3)
163
+ board[row, col] = 1
164
+ score, _ = minimax(board.copy(), -1)
165
+ board[row, col] = 0 # Undo move
166
+ if score < best_score:
167
+ best_score = score
168
+ best_move = move
169
+ return best_score, best_move
170
+
171
  def play_tictactoe(position, stats):
172
  """Play a TicTacToe move."""
173
  if tictactoe_env.game_over:
174
+ yield get_tictactoe_board_html(), "Game is over! Click 'New Game' to start again.", "", stats, get_valid_tictactoe_positions()
175
+ return
176
+
177
  try:
178
  position = int(position)
179
  if position < 0 or position > 8:
 
187
  if winner == "You": stats['wins'] += 1
188
  elif winner == "AI": stats['losses'] += 1
189
  else: stats['draws'] += 1
190
+ yield get_tictactoe_board_html(), f"Game Over! {winner} won!", f"Final reward: {reward}", stats, []
191
+ return
192
+
193
+ # Show "thinking" indicator
194
+ yield get_tictactoe_board_html(), "AI is thinking...", "🧠...", stats, []
195
+
196
  # AI move
197
+ _, ai_action = minimax(tictactoe_env.board.copy(), -1)
198
+ if ai_action is None: # Handle case where minimax returns no move (e.g., game over)
199
+ valid_actions = tictactoe_env._get_valid_actions()
200
+ if not valid_actions: # No actions left
201
+ yield get_tictactoe_board_html(), "Game is a draw!", "", stats, []
202
+ return
203
+ ai_action = random.choice(valid_actions)
204
+
205
+
206
+ reasoning_prompt = f"In TicTacToe, the board is currently: {tictactoe_env.board.flatten().tolist()}. The human player (X) played position {position}. I am the AI (O). The available moves are {tictactoe_env._get_valid_actions()}. I have analyzed the game tree using minimax and determined the optimal move is {ai_action}. Explain my strategy."
207
  reasoning = generate_reasoning(reasoning_prompt)
208
  obs, reward, terminated, truncated, info = tictactoe_env.step(ai_action)
209
 
 
212
  if winner == "You": stats['wins'] += 1
213
  elif winner == "AI": stats['losses'] += 1
214
  else: stats['draws'] += 1
215
+ yield get_tictactoe_board_html(), f"Game Over! {winner} won! AI played {ai_action}.", reasoning, stats, []
216
  else:
217
+ yield get_tictactoe_board_html(), f"AI played position {ai_action}. Your turn!", reasoning, stats, get_valid_tictactoe_positions()
218
 
219
  except Exception as e:
220
+ yield get_tictactoe_board_html(), f"Error: {str(e)}", "", stats, get_valid_tictactoe_positions()
221
 
222
  def reset_tictactoe(stats):
223
  """Reset TicTacToe game."""
224
  tictactoe_env.reset()
225
  return get_tictactoe_board_html(), "New game started! You are ❌ (X). Choose a position from the dropdown.", "AI will show its reasoning here...", stats, get_valid_tictactoe_positions()
226
 
227
+ # Simplified layout focusing only on TicTacToe
228
+ gr.Markdown("### Play TicTacToe against AI\nYou are ❌ (X) and go first. Get 3 in a row to win! **How AI Thinks**: AI will analyze the board and explain its moves.\nPositions: Top-left=0, bottom-right=8.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
+ with gr.Row():
231
+ with gr.Column(scale=2):
232
+ ttt_board = gr.HTML(
233
+ label="Game Board",
234
+ value=get_tictactoe_board_html()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  )
236
 
237
+ with gr.Column(scale=1):
238
+ ttt_position = gr.Dropdown(
239
+ label="Your Move (Valid Positions)",
240
+ choices=get_valid_tictactoe_positions()
 
 
 
 
 
241
  )
 
 
 
 
 
 
 
 
 
 
 
242
  with gr.Row():
243
+ ttt_play_btn = gr.Button("Play Move", variant="primary")
244
+ ttt_reset_btn = gr.Button("New Game", variant="secondary")
245
+ ttt_stats_display = gr.Markdown(value="Wins: 0 | Losses: 0 | Draws: 0")
246
+
247
+ ttt_message = gr.Textbox(
248
+ label="Game Status",
249
+ value="Choose a position to start!",
250
+ lines=2,
251
+ interactive=False
252
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ ttt_reasoning = gr.Textbox(
255
+ label="AI Reasoning",
256
+ value="AI will explain its thought process here...",
257
+ lines=3,
258
+ interactive=False
259
+ )
260
+
261
+ ttt_play_btn.click(
262
+ fn=play_tictactoe,
263
+ inputs=[ttt_position, ttt_stats],
264
+ outputs=[ttt_board, ttt_message, ttt_reasoning, ttt_stats, ttt_position]
265
+ )
266
+ ttt_reset_btn.click(
267
+ fn=reset_tictactoe,
268
+ inputs=[ttt_stats],
269
+ outputs=[ttt_board, ttt_message, ttt_reasoning, ttt_stats, ttt_position]
270
+ )
271
+ # Update stats display on changes
272
+ ttt_stats.change(
273
+ fn=lambda s: f"Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}",
274
+ inputs=ttt_stats,
275
+ outputs=ttt_stats_display
276
+ )
277
+ gr.Markdown("---")
278
+ gr.Markdown("🚧 **This is a development preview.** Full SPIRAL training and reasoning capabilities will be added in the next update!")
279
+
280
  else:
281
  # Fallback interface when games don't load
282
  gr.Markdown("⚠️ **Game modules could not be loaded.** Showing diagnostic information.")
 
342
  """)
343
  gr.Markdown("**New in this version:** Visual boards, stats tracking, and transfer test stub!")
344
 
345
+ if not GAMES_AVAILABLE:
 
 
 
346
  gr.Markdown("---")
347
  gr.Markdown("🔄 **Dependencies are loading.** Check the diagnostic info above and refresh in a few minutes!")
348
 
requirements.txt CHANGED
@@ -10,4 +10,6 @@ pandas>=1.3.0
10
  tqdm>=4.62.0
11
  pyyaml
12
  bitsandbytes
13
- accelerate>=0.26.0
 
 
 
10
  tqdm>=4.62.0
11
  pyyaml
12
  bitsandbytes
13
+ accelerate>=0.26.0
14
+ pytest
15
+ Jinja2
tests/test_games.py CHANGED
@@ -1,58 +1,78 @@
1
  import pytest
2
  import numpy as np
3
- from src.games.tictactoe import TicTacToeEnv
4
- from src.games.kuhn_poker import KuhnPokerEnv
 
 
 
 
 
5
 
6
  @pytest.fixture
7
- def ttt_env():
 
8
  return TicTacToeEnv()
9
 
10
- @pytest.fixture
11
- def kuhn_env():
12
- return KuhnPokerEnv()
13
-
14
- def test_tictactoe_reset(ttt_env):
15
- obs, info = ttt_env.reset()
16
- assert np.all(obs == 0)
17
- assert ttt_env.current_player == 1
18
- assert not ttt_env.game_over
19
-
20
- def test_tictactoe_win(ttt_env):
21
- # Simulate win for player 1
22
- ttt_env.step(0) # X
23
- ttt_env.step(3) # O (invalid sim, but test step)
24
- ttt_env.step(1) # X
25
- ttt_env.step(4) # O
26
- _, reward, terminated, _, _ = ttt_env.step(2) # X wins
27
- assert terminated
28
- assert reward == 1 # From player 1 perspective
29
- assert ttt_env.winner == 1
30
 
31
- def test_tictactoe_invalid_move(ttt_env):
32
- ttt_env.step(0)
33
- _, reward, terminated, _, info = ttt_env.step(0) # Same spot
34
- assert 'invalid_move' in info
35
- assert terminated
36
- assert reward == -1
37
 
38
- def test_kuhn_reset(kuhn_env):
39
- obs, info = kuhn_env.reset()
40
- assert kuhn_env.pot == 2 # Antes
41
- assert kuhn_env.current_player == 1
42
- assert not kuhn_env.game_over
43
 
44
- def test_kuhn_fold(kuhn_env):
45
- _, reward, terminated, _, _ = kuhn_env.step(2) # Player 1 folds
46
- assert terminated
47
- assert reward == -1 # Lost ante
48
- assert kuhn_env.winner == -1
49
-
50
- def test_kuhn_win(kuhn_env):
51
- kuhn_env.player1_card = 2 # K
52
- kuhn_env.player2_card = 0 # J
53
- kuhn_env.step(1) # Bet
54
- kuhn_env.step(0) # Call
55
- _, reward, terminated, _, _ = kuhn_env.step(0) # Call (if needed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  assert terminated
57
- assert reward > 0 # Win with higher card
58
- assert kuhn_env.winner == 1
 
 
 
 
 
 
 
 
 
 
 
1
  import pytest
2
  import numpy as np
3
+ import sys
4
+ import os
5
+
6
+ # Add src to path to allow importing TicTacToeEnv
7
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
8
+
9
+ from games.tictactoe import TicTacToeEnv
10
 
11
  @pytest.fixture
12
+ def env():
13
+ """Fixture to create a fresh TicTacToeEnv for each test."""
14
  return TicTacToeEnv()
15
 
16
+ def test_initial_state(env):
17
+ """Test the initial state of the board."""
18
+ assert np.all(env.board == np.zeros((3, 3)))
19
+ assert env.current_player == 1
20
+ assert not env.game_over
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def test_player_move(env):
23
+ """Test a valid player move."""
24
+ env.step(0)
25
+ assert env.board[0, 0] == 1
26
+ assert env.current_player == -1
27
+ assert not env.game_over
28
 
29
+ def test_invalid_move(env):
30
+ """Test making an invalid move on an occupied cell."""
31
+ env.step(0)
32
+ with pytest.raises(ValueError):
33
+ env.step(0)
34
 
35
+ def test_win_condition_row(env):
36
+ """Test a win condition in a row."""
37
+ env.board = np.array([[1, 1, 1], [0, -1, 0], [-1, 0, 0]])
38
+ assert env._check_winner(1)
39
+ assert not env._check_winner(-1)
40
+
41
+ def test_win_condition_col(env):
42
+ """Test a win condition in a column."""
43
+ env.board = np.array([[-1, 1, 0], [-1, 1, 0], [-1, 0, 0]])
44
+ assert not env._check_winner(1)
45
+ assert env._check_winner(-1)
46
+
47
+ def test_win_condition_diag(env):
48
+ """Test a win condition on a diagonal."""
49
+ env.board = np.array([[1, 0, -1], [0, 1, -1], [0, 0, 1]])
50
+ assert env._check_winner(1)
51
+
52
+ def test_draw_condition(env):
53
+ """Test a draw condition."""
54
+ env.board = np.array([[1, -1, 1], [1, -1, 1], [-1, 1, -1]])
55
+ assert env._is_draw()
56
+ assert not env._check_winner(1)
57
+ assert not env._check_winner(-1)
58
+
59
+ def test_game_over_on_win(env):
60
+ """Test that the game_over flag is set on a win."""
61
+ env.step(0) # P1
62
+ env.step(3) # P2
63
+ env.step(1) # P1
64
+ env.step(4) # P2
65
+ _, _, terminated, _, _ = env.step(2) # P1 wins
66
  assert terminated
67
+ assert env.game_over
68
+ assert env.winner == 1
69
+
70
+ def test_reset(env):
71
+ """Test if the environment resets correctly."""
72
+ env.step(0)
73
+ env.step(1)
74
+ env.reset()
75
+ assert np.all(env.board == np.zeros((3, 3)))
76
+ assert env.current_player == 1
77
+ assert not env.game_over
78
+ assert env.winner is None