File size: 6,097 Bytes
8adb2f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Game controller - wrapper around StrategoEnv and AI agents"""

import logging
import os
from typing import Tuple, Optional, List, Dict, Any

from stratego.env.stratego_env import StrategoEnv
from stratego.utils.parsing import extract_legal_moves, extract_forbidden, extract_board_block_lines
from stratego.game_logger import GameLogger


class GameController:
    """Thin wrapper around StrategoEnv for web UI.

    Owns one ``StrategoEnv`` plus an AI agent callable, tracks the moves made
    in the current game per player, and forwards each move to a ``GameLogger``
    when one is active. Player 0 is the human and player 1 is the AI (see
    ``human_player_id`` / ``ai_player_id``).
    """

    # Diagnostics for best-effort game-log failures (they must never abort a game).
    _log = logging.getLogger(__name__)

    def __init__(self, env_id: str, size: int, ai_agent: object, prompt_name: str = "base"):
        """Create the environment and make sure the game-log directory exists.

        Args:
            env_id: Environment identifier forwarded to ``StrategoEnv``.
            size: Board size forwarded to ``StrategoEnv`` and used when
                extracting the board from observations.
            ai_agent: Callable taking an observation string and returning a
                move; its optional ``model_name`` attribute is sent to the logger.
            prompt_name: Prompt identifier recorded in game logs.
        """
        self.env = StrategoEnv(env_id=env_id, size=size)
        self.env_id = env_id
        self.ai_agent = ai_agent
        self.prompt_name = prompt_name
        self.size = size
        self.human_player_id = 0
        self.ai_player_id = 1
        # player id -> list of move_info dicts built by execute_move()
        self.move_history = {0: [], 1: []}
        self.game_done = False
        self.game_info = {}
        self.current_player_id = None
        self.current_observation = ""
        self.game_logger = None
        self.logs_dir = os.path.join(os.path.dirname(__file__), "..", "..", "logs", "games")
        os.makedirs(self.logs_dir, exist_ok=True)

    def reset(self) -> Tuple[int, str]:
        """Start a new two-player game and a fresh ``GameLogger``.

        The logger's ``out_dir`` is the parent of ``logs_dir``.

        Returns:
            ``(player_id, observation)`` for the player to move first.
        """
        self.env.reset(num_players=2)
        self.move_history = {0: [], 1: []}
        self.game_done = False
        self.game_info = {}
        self.game_logger = GameLogger(
            out_dir=os.path.join(self.logs_dir, ".."),
            game_type=self.env_id,
            board_size=self.size,
            prompt_name=self.prompt_name,
        )
        return self.get_current_player()

    def get_current_player(self) -> Tuple[int, str]:
        """Refresh and return ``(current_player_id, current_observation)`` from the env."""
        self.current_player_id, self.current_observation = self.env.get_observation()
        return self.current_player_id, self.current_observation

    def get_legal_moves(self, observation: Optional[str] = None) -> List[str]:
        """Return the legal moves parsed from an observation, minus forbidden ones.

        Args:
            observation: Observation text; falsy values fall back to
                ``current_observation``.

        Returns:
            Legal moves not listed as forbidden. If filtering would leave
            nothing, the unfiltered legal list is returned instead. Returns an
            empty list when no observation is available.
        """
        obs_to_use = observation or self.current_observation
        if not obs_to_use:
            return []
        legal = extract_legal_moves(obs_to_use)
        forbidden = set(extract_forbidden(obs_to_use))
        result = [m for m in legal if m not in forbidden]
        return result if result else legal

    def get_board_display(self, observation: Optional[str] = None) -> str:
        """Return the board block of an observation as text.

        Never raises: returns a placeholder when there is no observation and an
        error message if board extraction fails.
        """
        obs_to_use = observation or self.current_observation
        if not obs_to_use:
            return "No board state available"
        try:
            board_lines = extract_board_block_lines(obs_to_use, self.size)
            return "\n".join(board_lines)
        except Exception as e:
            return f"Error rendering board: {str(e)}"

    def execute_move(self, move_str: str) -> Tuple[bool, Dict[str, Any]]:
        """Apply ``move_str`` for the current player and record it.

        Args:
            move_str: Move text passed unchanged to ``env.step``.

        Returns:
            ``(done, info)`` from ``env.step``. Returns ``(True, {"error": ...})``
            if the game had already finished or if stepping/bookkeeping raised.
        """
        if self.game_done:
            return True, {"error": "Game already finished"}
        try:
            done, info = self.env.step(action=move_str)
            player = self.current_player_id
            # Global 1-based move number, stored so the history can be shown
            # in true play order regardless of who moved when.
            turn = self.get_turn_count() + 1
            move_info = {"player": player, "move": move_str, "done": done, "info": info, "turn": turn}
            self.move_history[player].append(move_info)
            if self.game_logger:
                try:
                    self.game_logger.log_move(
                        turn=turn,
                        player=player,
                        move=move_str,
                        # NOTE(review): the AI's model_name is also sent for
                        # human moves — confirm GameLogger expects that.
                        model_name=getattr(self.ai_agent, "model_name", "human"),
                        prompt_name=self.prompt_name,
                        game_done=done,
                        game_info=info,
                    )
                except Exception:
                    # Logging is best-effort: report the failure but keep playing.
                    self._log.warning("Failed to log move %r", move_str, exc_info=True)
            if done:
                self.game_done = True
                self.game_info = info
            return done, info
        except Exception as e:
            return True, {"error": f"Move execution error: {str(e)}"}

    def get_ai_move(self, observation: Optional[str] = None) -> Tuple[str, Optional[str]]:
        """Ask the AI agent for a move.

        Returns:
            ``(move, error)``. If the agent returns nothing or raises, the first
            legal move is used (or the placeholder ``"[A0 B0]"`` if there are
            none); ``error`` describes the agent failure, else ``None``.
        """
        obs_to_use = observation or self.current_observation
        if not obs_to_use:
            return "[A0 B0]", "No observation available"
        try:
            move = self.ai_agent(obs_to_use)
            if not move:
                legal_moves = self.get_legal_moves(obs_to_use)
                move = legal_moves[0] if legal_moves else "[A0 B0]"
            return str(move).strip(), None
        except Exception as e:
            legal_moves = self.get_legal_moves(obs_to_use)
            return (legal_moves[0] if legal_moves else "[A0 B0]"), f"AI error: {str(e)}"

    def close(self) -> Tuple[Dict, Dict]:
        """Close the env and finalize the game log (best-effort).

        Returns:
            ``(rewards, info)`` from ``env.close``, or ``({}, {"error": ...})``
            if closing the env raised.
        """
        try:
            rewards, info = self.env.close()
            if self.game_logger:
                try:
                    winner = info.get("winner") if info else None
                    result = info.get("result") if info else None
                    self.game_logger.finalize_game(winner=winner, result=result)
                except Exception:
                    # Finalizing the log must not hide the env's close result.
                    self._log.warning("Failed to finalize game log", exc_info=True)
            return rewards, info
        except Exception as e:
            return {}, {"error": str(e)}

    def is_human_turn(self) -> bool:
        """True if the last observed current player is the human."""
        return self.current_player_id == self.human_player_id

    def is_ai_turn(self) -> bool:
        """True if the last observed current player is the AI."""
        return self.current_player_id == self.ai_player_id

    def get_turn_count(self) -> int:
        """Total number of moves recorded for both players."""
        return sum(len(v) for v in self.move_history.values())

    def get_move_history_display(self, limit: Optional[int] = None) -> List[str]:
        """Get move history in CHRONOLOGICAL order, labelled "You" / "AI".

        Moves recorded by ``execute_move`` carry a global ``"turn"`` number and
        are ordered by it, so the output stays chronological even if the AI
        moves first or one player moves twice in a row. Entries without a
        ``"turn"`` fall back to strict alternation starting with the human.

        Args:
            limit: Max number of most recent moves to return. ``None`` (or any
                value <= 0) returns all moves.

        Returns:
            Strings such as ``"You: [A0 B0]"`` and ``"AI: [A9 B9]"``.
        """
        entries = []
        labelled = ((self.human_player_id, "You"), (self.ai_player_id, "AI"))
        for offset, (player_id, label) in enumerate(labelled):
            for idx, m in enumerate(self.move_history.get(player_id, [])):
                # Fallback order reproduces the legacy You, AI, You, AI... interleave.
                order = m.get("turn", 2 * idx + offset + 1)
                entries.append((order, f"{label}: {m['move']}"))
        # list.sort is stable, so equal keys keep their insertion order.
        entries.sort(key=lambda entry: entry[0])
        moves = [text for _, text in entries]

        if limit and limit > 0:
            return moves[-limit:]
        return moves