File size: 18,898 Bytes
6aff2d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
import argparse
import os
import re
import time
import random
# from stratego.prompt_optimizer import improve_prompt
from stratego.env.stratego_env import StrategoEnv
from stratego.prompts import get_prompt_pack
from stratego.utils.parsing import extract_board_block_lines, extract_legal_moves, extract_forbidden
from stratego.utils.game_move_tracker import GameMoveTracker as MoveTrackerClass
from stratego.utils.move_processor import process_move
from stratego.game_logger import GameLogger
from stratego.game_analyzer import analyze_and_update_prompt
from stratego.datasets import auto_push_after_game


#Revised to set temperature(13 Nov 2025)
def build_agent(spec: str, prompt_name: str):
    """ 

    Creates and configures an AI agent based on the input string.

    Example spec: 'ollama:phi3:3.8b'

    """
    kind, name = spec.split(":", 1)  # Split string to get model type and name
    
    if kind == "ollama":
        from stratego.models.ollama_model import OllamaAgent
        # Define the temperature value explicitly
        AGENT_TEMPERATURE = 0.2
        
        # Create the Ollama agent
        agent = OllamaAgent(
            model_name=name, 
            temperature=AGENT_TEMPERATURE, 
            num_predict=128,  # Allow enough tokens for a complete move response
            prompt_pack=get_prompt_pack(prompt_name) # Load strategy prompt
        )
        
        # Store temperature for logging
        agent.temperature = AGENT_TEMPERATURE
        
        return agent
    if kind == "hf":
        from stratego.models.hf_model import HFLocalAgent
        return HFLocalAgent(model_id=name, prompt_pack=prompt_name)
    raise ValueError(f"Unknown agent spec: {spec}")

def print_board(observation: str, size: int = 10):
    block = extract_board_block_lines(observation, size)
    if block:
        print("\n".join(block))

# --- Main Command Line Interface (CLI) ---
def cli():
    DEFAULT_ENV = "Stratego-v0"
    DUEL_ENV = "Stratego-duel"
    CUSTOM_ENV = "Stratego-custom"
    tracker = MoveTrackerClass()
    p = argparse.ArgumentParser()
    p.add_argument("--p0", default="ollama:deepseek-r1:32b")
    p.add_argument("--p1", default="ollama:gemma3:1b")
    # UPDATED HELP TEXT to explain how this parameter relates to VRAM utilization
    # For large models (120B, 70B), you MUST set this value based on available VRAM(13 Nov 2025)
    # UPDATED GPU arguments for VRAM control (now defaults to CPU-only)
    p.add_argument("--p0-num-gpu", type=int, default=0,
                    help="Number of GPU layers to offload for Player 0. Default is 0 (CPU-only mode). Use a positive number (e.g., 50) to offload layers to GPU/VRAM, or 999 for maximum GPU use.")
    p.add_argument("--p1-num-gpu", type=int, default=0,
                    help="Number of GPU layers to offload for Player 1. Default is 0 (CPU-only mode). Use a positive number (e.g., 40) to offload layers to GPU/VRAM, or 999 for maximum GPU use.")    
    #(13 Nov 2025) NOTE: Default env_id is used as a flag to trigger the interactive menu
    p.add_argument("--prompt", default="base", help="Prompt preset name (e.g. base|concise|adaptive)")
    p.add_argument("--env_id", default=DEFAULT_ENV, help="TextArena environment id")
    p.add_argument("--log-dir", default="logs", help="Directory for per-game CSV logs")
    p.add_argument("--game-id", default=None, help="Optional custom game id in CSV filename")
    p.add_argument("--size", type=int, default=10, help="Board size NxN")
    p.add_argument("--max-turns", type=int, default=None, help="Maximum turns before stopping (for testing). E.g., --max-turns 10")

    args = p.parse_args()

    #(13 Nov 2025) --- INTERACTIVE ENVIRONMENT SELECTION ---
    if args.env_id == DEFAULT_ENV:
        print("\n--- Stratego Version Selection ---")
        print(f"1. Standard Game ({DEFAULT_ENV})")
        print(f"2. Duel Mode ({DUEL_ENV})")
        print(f"3. Custom Mode ({CUSTOM_ENV})")
        
        while True:
            choice = input("Enter your choice (1, 2, or 3): ").strip()
            if not choice or choice == '1':
                print(f"Selected: {DEFAULT_ENV}")
                break
            elif choice == '2':
                args.env_id = DUEL_ENV
                args.size = 6
                print(f"Selected: {DUEL_ENV}")
                break
            elif choice == '3':
                # [CHANGE] Updated prompt range description
                board = input("Please enter your custom board size in range of 4~9: ").strip()
                # [CHANGE] Added '4' and '5' to valid options
                if board in ['4', '5', '6', '7', '8', '9']:
                    args.env_id = CUSTOM_ENV
                    args.size = int(board)
                    print(f"Selected: {CUSTOM_ENV} with size {args.size}x{args.size}")
                    break
                else:
                    print("Invalid choice.")
            else:
                print("Invalid choice.")

    # --- Setup Game ---
    agents = {
        0: build_agent(args.p0, args.prompt),
        1: build_agent(args.p1, args.prompt),
    }
    # Check if it is really normal Stratego version
    if (args.env_id == CUSTOM_ENV): 
        env = StrategoEnv(env_id=CUSTOM_ENV, size=args.size)
        game_type = "custom"
    elif (args.env_id == DUEL_ENV):
        env = StrategoEnv(env_id=DUEL_ENV)
        game_type = "duel"
        args.size = 6  # Duel mode uses 6x6 board
    else:
        env = StrategoEnv()
        game_type = "standard"
    env.reset(num_players=2)
    
    # Track game start time
    game_start_time = time.time()
    
    # Simple move history tracker (separate for each player)
    move_history = {0: [], 1: []}

    with GameLogger(out_dir=args.log_dir, game_id=args.game_id, prompt_name=args.prompt, game_type=game_type, board_size=args.size) as logger:
        for pid in (0, 1):
            if hasattr(agents[pid], "logger"):
                agents[pid].logger = logger
                agents[pid].player_id = pid

        done = False
        turn = 0
        print("\n--- Stratego LLM Match Started ---")
        print(f"Player 1 Agent: {agents[0].model_name}")
        print(f"Player 2 Agent: {agents[1].model_name}")
        if args.max_turns:
            print(f"⏱️  Max turns limit: {args.max_turns} (testing mode)")
        print()
        while not done:
            # Check max turns limit
            if args.max_turns and turn >= args.max_turns:
                print(f"\n⏱️  Reached max turns limit ({args.max_turns}). Stopping game early.")
                break
            
            player_id, observation = env.get_observation()
            current_agent = agents[player_id]
            player_display = f"Player {player_id+1}"
            model_name = current_agent.model_name
            
            # --- NEW LOGGING FOR TURN, PLAYER, AND MODEL ---
            print(f"\n>>>> TURN {turn}: {player_display} ({model_name}) is moving...")

            if (args.size == 10):
                print_board(observation)
            else:
                print_board(observation, args.size)
            # Pass recent move history to agent
            current_agent.set_move_history(move_history[player_id][-10:])
            history_str = tracker.to_prompt_string(player_id)
            
            # --- [CHANGE] INJECT AGGRESSION WARNING ---
            # If the game drags on (e.g. > 20 turns), force them to wake up
            if turn > 20:
                observation += "\n\n[SYSTEM MESSAGE]: The game is stalling. You MUST ATTACK or ADVANCE immediately. Passive play is forbidden."
            
            if turn > 50:
                 observation += "\n[CRITICAL]: STOP MOVING BACK AND FORTH. Pick a piece and move it FORWARD now."
            # ------------------------------------------

            observation = observation + history_str
            # print(tracker.to_prompt_string(player_id))
            lines = history_str.strip().splitlines()
            if len(lines) <= 1:
                print(history_str)
            else:
                header = lines[0:1]
                body = lines[1:]
                tail = body[-5:]  # Show only last 5 moves
                print("\n".join(header + tail))
            
            # The agent (LLM) generates the action, retry a few times; fallback to available moves
            action = ""
            max_agent_attempts = 3
            for attempt in range(max_agent_attempts):
                action = current_agent(observation)
                if action:
                    break
                print(f"[TURN {turn}] {model_name} failed to produce a move (attempt {attempt+1}/{max_agent_attempts}). Retrying...")

            if not action:
                legal = extract_legal_moves(observation)
                forbidden = set(extract_forbidden(observation))
                legal_filtered = [m for m in legal if m not in forbidden] or legal
                if legal_filtered:
                    action = random.choice(legal_filtered)
                    print(f"[TURN {turn}] Fallback to random available move: {action}")
                else:
                    print(f"[TURN {turn}] No legal moves available for fallback; ending game loop.")
                    break
                
            # --- NEW LOGGING FOR STRATEGY/MODEL DECISION ---
            print(f"  > AGENT DECISION: {model_name} -> {action}")
            print(f"  > Strategy/Model: Ollama Agent (T={current_agent.temperature}, Prompt='{args.prompt}')")

            # Extract move details for logging
            move_pattern = r'\[([A-J]\d+)\s+([A-J]\d+)\]'
            match = re.search(move_pattern, action)
            # src_pos = match.group(1) if match else ""
            # dst_pos = match.group(2) if match else ""
            
            # # Get piece type from board (simplified extraction)
            # piece_type = ""
            # if src_pos and hasattr(env, 'game_state') and hasattr(env.game_state, 'board'):
            #     try:
            #         # Parse position like "D4" -> row=3, col=3
            #         col = ord(src_pos[0]) - ord('A')
            #         row = int(src_pos[1:]) - 1
            #         piece = env.game_state.board[row][col]
            #         if piece and hasattr(piece, 'rank_name'):
            #             piece_type = piece.rank_name
            #     except:
            #         piece_type = "Unknown"
            
            # # Check if this is a repeated move (last 3 moves)
            # was_repeated = False
            # recent_moves = [m["move"] for m in move_history[player_id][-3:]]
            # if action in recent_moves:
            #     was_repeated = True
            
            # Record this move in history
            move_history[player_id].append({
                "turn": turn,
                "move": action,
                "text": f"Turn {turn}: You played {action}"
            })

            # Process move details for logging BEFORE making the environment step
            move_details = process_move(
                action=action,
                board=env.env.board,
                observation=observation,
                player_id=player_id
            )

            # Execute the action exactly once in the environment
            done, info = env.step(action=action)
            
            # Determine battle outcome by checking if target piece was there
            battle_outcome = ""
            if move_details.target_piece:
                # There was a piece at destination, so battle occurred
                # Check what's at destination now to determine outcome
                dst_row = ord(move_details.dst_pos[0]) - ord('A')
                dst_col = int(move_details.dst_pos[1:])
                cell_after = env.env.board[dst_row][dst_col]
                
                if cell_after is None:
                    # Both pieces removed = draw
                    battle_outcome = "draw"
                elif isinstance(cell_after, dict):
                    if cell_after.get('player') == player_id:
                        battle_outcome = "won"
                    else:
                        battle_outcome = "lost"
            
            # Extract outcome from environment observation
            outcome = "move"
            # captured = ""
            obs_text = ""
            # if isinstance(info, (list, tuple)) and len(info) > 1:
            #     obs_text = str(info[1])
            # else:
            #     obs_text = str(info)
            if isinstance(info, (list, tuple)):
                if 0 <= player_id < len(info):
                    obs_text = str(info[player_id])
                else:
                    obs_text = " ".join(str(x) for x in info)
            else:
                obs_text = str(info)

            low = obs_text.lower()
            if "invalid" in low or "illegal" in low:
                outcome = "invalid"
            elif "captured" in low or "won the battle" in low:
                outcome = "won_battle"
            elif "lost the battle" in low or "defeated" in low:
                outcome = "lost_battle"
            elif "draw" in low or "tie" in low:
                outcome = "draw"
                    
            event = info.get("event") if isinstance(info, dict) else None
            extra = info.get("detail") if isinstance(info, dict) else None
            
            if outcome != "invalid":
                # Record this move in history
                move_history[player_id].append({
                    "turn": turn,
                    "move": action,
                    "text": f"Turn {turn}: You played {action}"
                })

                tracker.record(
                    player=player_id,
                    move=action,
                    event=event,
                    extra=extra
                )
            else:
                move_history[player_id].append({
                    "turn": turn,
                    "move": action,
                    "text": f"Turn {turn}: INVALID move {action}"
                })
                tracker.record(
                    player=player_id,
                    move=action,
                    event="invalid_move",
                    extra=extra
                )
                print(f"[HISTORY] Skipping invalid move from history: {action}")

            logger.log_move(turn=turn,
                                player=player_id,
                                model_name=getattr(current_agent, "model_name", "unknown"),
                                move=action,
                                src=move_details.src_pos,
                                dst=move_details.dst_pos,
                                piece_type=move_details.piece_type,
                                board_state=move_details.board_state,
                                available_moves=move_details.available_moves,
                                move_direction=move_details.move_direction,
                                target_piece=move_details.target_piece,
                                battle_outcome=battle_outcome,
                            )
            turn += 1


        # --- Game Over & Winner Announcement ---
        rewards, game_info = env.close()
        print("\n" + "="*50)
        print("--- GAME OVER ---")
        game_duration = time.time() - game_start_time
        # Print summary
        print(f"\nGame finished. Duration: {int(game_duration // 60)}m {int(game_duration % 60)}s")
        print(f"Result: {rewards} | {game_info}")
        
        # Logic to declare the specific winner based on rewards
        # Rewards are usually {0: 1, 1: -1} (P0 Wins) or {0: -1, 1: 1} (P1 Wins)
        p0_score = rewards.get(0, 0)
        p1_score = rewards.get(1, 0)
        winner = None
        game_result = ""

        if p0_score > p1_score:
            winner = 0
            game_result = "player0"
            print(f"\n🏆 * * * PLAYER 0 WINS! * * * 🏆")
            print(f"Agent: {agents[0].model_name}")
        elif p1_score > p0_score:
            winner = 1
            game_result = "player1"
            print(f"\n🏆 * * * PLAYER 1 WINS! * * * 🏆")
            print(f"Agent: {agents[1].model_name}")
        else:
            game_result = "draw"
            print(f"\n🤝 * * * IT'S A DRAW! * * * 🤝")

        print("\nDetails:")
        print(f"Final Rewards: {rewards}")
        print(f"Game Info: {game_info}")
    
    try:
        invalid_players = [
            pid for pid, info_dict in (game_info or {}).items()
            if isinstance(info_dict, dict) and info_dict.get("invalid_move")
        ]
        if invalid_players:
            import csv
            csv_path = logger.path
            rows = []
            fieldnames = None

            with open(csv_path, "r", encoding="utf-8", newline="") as f:
                reader = csv.DictReader(f)
                fieldnames = reader.fieldnames
                for r in reader:
                    rows.append(r)

            if rows and fieldnames and "outcome" in fieldnames:
                rows[-1]["outcome"] = "invalid"
                with open(csv_path, "w", encoding="utf-8", newline="") as f:
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()
                    writer.writerows(rows)

                print("\n[LOG PATCH] Last move outcome patched to 'invalid' "
                      f"(player {invalid_players[0]} made an invalid move).")
                
    except Exception as e:
        print(f"[LOG PATCH] Failed to patch CSV outcome: {e}")
            
        # Finalize the game log with winner info in every row
        logger.finalize_game(winner=winner, game_result=game_result)
    
    # LLM analyzes the game CSV and updates prompt
    analyze_and_update_prompt(
        csv_path=logger.path,
        prompts_dir="stratego/prompts",
        logs_dir=args.log_dir,
        model_name="mistral:7b",  # Analysis model
        models_used=[agents[0].model_name, agents[1].model_name],
        game_duration_seconds=game_duration,
        winner=winner,
        total_turns=turn - 1
    )
    
    # Auto-push game data to Hugging Face Hub
    print("\nSyncing game data to Hugging Face...")
    auto_push_after_game(
        logs_dir=os.path.join(args.log_dir, "games"),
        repo_id="STRATEGO-LLM-TRAINING/stratego",
    )