DarshanScripts commited on
Commit
8adb2f1
·
verified ·
1 Parent(s): ad372ca

Upload stratego/web/game_controller.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stratego/web/game_controller.py +141 -0
stratego/web/game_controller.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Game controller - wrapper around StrategoEnv and AI agents"""
2
+
3
+ from typing import Tuple, Optional, List, Dict, Any
4
+ from stratego.env.stratego_env import StrategoEnv
5
+ from stratego.utils.parsing import extract_legal_moves, extract_forbidden, extract_board_block_lines
6
+ from stratego.game_logger import GameLogger
7
+ import os
8
+
9
+
10
+ class GameController:
11
+ """Thin wrapper around StrategoEnv for web UI."""
12
+
13
+ def __init__(self, env_id: str, size: int, ai_agent: object, prompt_name: str = "base"):
14
+ self.env = StrategoEnv(env_id=env_id, size=size)
15
+ self.env_id = env_id
16
+ self.ai_agent = ai_agent
17
+ self.prompt_name = prompt_name
18
+ self.size = size
19
+ self.human_player_id = 0
20
+ self.ai_player_id = 1
21
+ self.move_history = {0: [], 1: []}
22
+ self.game_done = False
23
+ self.game_info = {}
24
+ self.current_player_id = None
25
+ self.current_observation = ""
26
+ self.game_logger = None
27
+ self.logs_dir = os.path.join(os.path.dirname(__file__),"..","..","logs","games")
28
+ os.makedirs(self.logs_dir, exist_ok=True)
29
+
30
+ def reset(self) -> Tuple[int, str]:
31
+ self.env.reset(num_players=2)
32
+ self.move_history = {0: [], 1: []}
33
+ self.game_done = False
34
+ self.game_info = {}
35
+ self.game_logger = GameLogger(out_dir=os.path.join(self.logs_dir, ".."), game_type=self.env_id, board_size=self.size, prompt_name=self.prompt_name)
36
+ return self.get_current_player()
37
+
38
+ def get_current_player(self) -> Tuple[int, str]:
39
+ self.current_player_id, self.current_observation = self.env.get_observation()
40
+ return self.current_player_id, self.current_observation
41
+
42
+ def get_legal_moves(self, observation: Optional[str] = None) -> List[str]:
43
+ obs_to_use = observation or self.current_observation
44
+ if not obs_to_use:
45
+ return []
46
+ legal = extract_legal_moves(obs_to_use)
47
+ forbidden = set(extract_forbidden(obs_to_use))
48
+ result = [m for m in legal if m not in forbidden]
49
+ return result if result else legal
50
+
51
+ def get_board_display(self, observation: Optional[str] = None) -> str:
52
+ obs_to_use = observation or self.current_observation
53
+ if not obs_to_use:
54
+ return "No board state available"
55
+ try:
56
+ board_lines = extract_board_block_lines(obs_to_use, self.size)
57
+ return "\n".join(board_lines)
58
+ except Exception as e:
59
+ return f"Error rendering board: {str(e)}"
60
+
61
+ def execute_move(self, move_str: str) -> Tuple[bool, Dict[str, Any]]:
62
+ if self.game_done:
63
+ return True, {"error": "Game already finished"}
64
+ try:
65
+ done, info = self.env.step(action=move_str)
66
+ move_info = {"player": self.current_player_id, "move": move_str, "done": done, "info": info}
67
+ self.move_history[self.current_player_id].append(move_info)
68
+ if self.game_logger:
69
+ try:
70
+ self.game_logger.log_move(turn=sum(len(v) for v in self.move_history.values()), player=self.current_player_id, move=move_str, model_name=getattr(self.ai_agent,"model_name","human"), prompt_name=self.prompt_name, game_done=done, game_info=info)
71
+ except:
72
+ pass
73
+ if done:
74
+ self.game_done = True
75
+ self.game_info = info
76
+ return done, info
77
+ except Exception as e:
78
+ return True, {"error": f"Move execution error: {str(e)}"}
79
+
80
+ def get_ai_move(self, observation: Optional[str] = None) -> Tuple[str, Optional[str]]:
81
+ obs_to_use = observation or self.current_observation
82
+ if not obs_to_use:
83
+ return "[A0 B0]", "No observation available"
84
+ try:
85
+ move = self.ai_agent(obs_to_use)
86
+ if not move:
87
+ legal_moves = self.get_legal_moves(obs_to_use)
88
+ move = legal_moves[0] if legal_moves else "[A0 B0]"
89
+ return str(move).strip(), None
90
+ except Exception as e:
91
+ legal_moves = self.get_legal_moves(obs_to_use)
92
+ return (legal_moves[0] if legal_moves else "[A0 B0]"), f"AI error: {str(e)}"
93
+
94
+ def close(self) -> Tuple[Dict, Dict]:
95
+ try:
96
+ rewards, info = self.env.close()
97
+ if self.game_logger:
98
+ try:
99
+ winner = info.get("winner") if info else None
100
+ result = info.get("result") if info else None
101
+ self.game_logger.finalize_game(winner=winner, result=result)
102
+ except:
103
+ pass
104
+ return rewards, info
105
+ except Exception as e:
106
+ return {}, {"error": str(e)}
107
+
108
+ def is_human_turn(self) -> bool:
109
+ return self.current_player_id == self.human_player_id
110
+
111
+ def is_ai_turn(self) -> bool:
112
+ return self.current_player_id == self.ai_player_id
113
+
114
+ def get_turn_count(self) -> int:
115
+ return sum(len(v) for v in self.move_history.values())
116
+
117
+ def get_move_history_display(self, limit: int = None) -> List[str]:
118
+ """Get move history in CHRONOLOGICAL order (alternating: You, AI, You, AI...)
119
+
120
+ Args:
121
+ limit: Max moves to return. None = return all moves
122
+ """
123
+ moves = []
124
+ max_moves = max(len(self.move_history[0]), len(self.move_history[1]))
125
+
126
+ # Interleave moves from both players chronologically
127
+ for i in range(max_moves):
128
+ # Your move (Player 0)
129
+ if i < len(self.move_history[0]):
130
+ m = self.move_history[0][i]
131
+ moves.append(f"You: {m['move']}")
132
+
133
+ # AI move (Player 1)
134
+ if i < len(self.move_history[1]):
135
+ m = self.move_history[1][i]
136
+ moves.append(f"AI: {m['move']}")
137
+
138
+ # Return all moves (or last 'limit' if specified)
139
+ if limit:
140
+ return moves[-limit:]
141
+ return moves