DarshanScripts commited on
Commit
aedcb93
·
verified ·
1 Parent(s): 49f6957

Upload stratego\models\ollama_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. stratego//models//ollama_model.py +383 -0
stratego//models//ollama_model.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import re
5
+ from typing import Optional, Tuple
6
+
7
+ from langchain_ollama import ChatOllama
8
+ import requests
9
+
10
+ from .base import AgentLike
11
+ from ..utils.parsing import (
12
+ extract_legal_moves, slice_board_and_moves, strip_think, MOVE_RE, extract_forbidden
13
+ )
14
+
15
+
16
+ # I seperated Prompts from the code
17
+ from ..prompts import PromptPack, get_prompt_pack
18
+
19
+ # 🧩 Import strategies
20
+ from ..strategies.base import Strategy
21
+ from ..strategies.aggressive_strategy import AggressiveStrategy
22
+ from ..strategies.defensive_strategy import DefensiveStrategy
23
+ from ..strategies.random_move import RandomStrategy
24
+
25
+
26
+ class OllamaAgent(AgentLike):
27
+ def __init__(
28
+ self,
29
+ model_name: str,
30
+ system_prompt: Optional[str] = None,
31
+ host: Optional[str] = None,
32
+ prompt_pack: Optional[PromptPack | str] = None,
33
+ strategy: Optional[Strategy] = None,
34
+ **kwargs,
35
+ ):
36
+ self.model_name = model_name
37
+
38
+ self.STRATEGIC_GUIDANCE = """
39
+ You are a skilled Stratego player.
40
+ You must choose the SINGLE best legal move from the given board, legal moves, forbidden moves, and move history.
41
+
42
+ GENERAL RULES:
43
+ 1. Output EXACTLY ONE MOVE in the form [A0 B0].
44
+ 2. NEVER output explanations, commentary, or reasoning.
45
+ 3. Try to choose a move that would be legal in Stratego rules.
46
+ 4. NEVER repeat a previous move unless it creates a tactical advantage (capture, reveal, escape).
47
+ 5. AVOID back-and-forth oscillations (e.g., A5->A6 then A6->A5).
48
+ 6. It would be considered a SERIOUS MISTAKE, which leads you to lose the game, to attempt illegal moves such as moving a Flag or Bomb, moving in an impossible way, moving upon its own pieces, or trying to move opponent's pieces.
49
+
50
+ STRATEGIC PRINCIPLES:
51
+ 1. Avoid random or pointless shuffling of pieces.
52
+ 2. Prefer moves that improve board position, uncover information, or apply pressure.
53
+ 3. Avoid moving high-value officers (Marshal, General, Colonel) blindly into unknown pieces.
54
+ 4. Prefer advancing Scouts for reconnaissance.
55
+ 5. Avoid moving bombs unless revealed and forced.
56
+ 6. Do NOT walk pieces next to the same unknown piece repeatedly without purpose.
57
+ 7. Do NOT afraid to sacrifice low-rank pieces for information gain.
58
+
59
+ CAPTURE & SAFETY RULES:
60
+ 1. If you can capture a known weaker enemy piece safely, prefer that move.
61
+ 2. NEVER attack a higher-ranked or unknown piece with a valuable piece unless strategically justified.
62
+ 3. If the enemy piece is revealed as weaker, press the advantage.
63
+ 4. If your piece is threatened, retreat or reposition instead of repeating the last move.
64
+
65
+ USE OF HISTORY:
66
+ 1. Avoid repeating cycles recognized in the history (e.g., A->B->A->B).
67
+ 2. Track revealed enemy pieces from history and use rank knowledge:
68
+ - If they moved, they are not Bombs or Flags.
69
+ - If they captured, infer their rank and avoid attacking with weaker pieces.
70
+ 3. If an enemy repeatedly retreats from your piece, continue safe pressure.
71
+
72
+ POSITIONING RULES:
73
+ 1. Advance pieces that have strategic value while keeping your formation stable.
74
+ 2. Keep bombs guarding high-value territory; avoid unnecessary bomb movement.
75
+ 3. Push on flanks where the opponent retreats often.
76
+ 4. Maintain escape squares for your high-ranking leaders.
77
+
78
+ ENDGAME LOGIC:
79
+ 1. Prioritize discovering and attacking the opponent's flag location.
80
+ 2. Secure safe paths for Miners to remove bombs.
81
+ 3. In endgame, prioritize mobility and avoid blockades caused by your own pieces.
82
+
83
+ CHOOSE THE BEST MOVE:
84
+ Evaluate all legal moves and pick the one that:
85
+ - improves position, OR
86
+ - pressures an opponent safely, OR
87
+ - increases information, OR
88
+ - avoids known traps or loops, OR
89
+ - ensures safety of valuable pieces.
90
+
91
+ Output ONLY one legal move in the exact format [A0 B0]. Nothing else.
92
+ """
93
+ # self.VALIDATION_GUIDANCE = """
94
+ # You are validating a Stratego move. Decide if the move obeys Stratego rules given the board and history.
95
+ # Rules to enforce:
96
+ # - Pieces cannot move into lakes or off-board.
97
+ # - Immovable pieces (Bomb, Flag) cannot move.
98
+ # - A piece cannot capture its own piece.
99
+ # - Only Scouts can move more than one square in straight lines; others move exactly one square orthogonally.
100
+ # - No diagonal movement.
101
+ # - Respect revealed information from history (if it moved before, it is not a Bomb/Flag).
102
+ # - If an 'Available Moves:' list is present, moves not in that list are almost always invalid.
103
+ # - If a 'FORBIDDEN' list is present, those moves are invalid.
104
+ # - On small custom boards (size <= 5), there are NO lakes unless the board explicitly shows '~'. If you do not see '~', assume no lakes exist.
105
+
106
+ # Respond with either:
107
+ # - VALID
108
+ # - INVALID: <short reason>
109
+ # """
110
+ if isinstance(prompt_pack, str) or prompt_pack is None:
111
+ self.prompt_pack: PromptPack = get_prompt_pack(prompt_pack)
112
+ else:
113
+ self.prompt_pack = prompt_pack
114
+
115
+
116
+
117
+ if system_prompt is not None:
118
+ self.system_prompt = system_prompt
119
+ else:
120
+ # if there is already an existing updated prompt, we use that one
121
+ prompt_path = os.path.join(os.path.dirname(__file__), "..", "prompts", "current_prompt.txt")
122
+ if os.path.exists(prompt_path):
123
+ with open(prompt_path, "r", encoding="utf-8") as f:
124
+ self.system_prompt = f.read()
125
+ else:
126
+ self.system_prompt = self.prompt_pack.system
127
+
128
+
129
+ self.initial_prompt = self.system_prompt
130
+ # Setup Ollama client
131
+ base_url = host or os.getenv("OLLAMA_HOST", "http://localhost:11434")
132
+ model_kwargs = {
133
+ "temperature": kwargs.pop("temperature", 0.1),
134
+ "top_p": kwargs.pop("top_p", 0.9),
135
+ "repeat_penalty": kwargs.pop("repeat_penalty", 1.05),
136
+ "num_predict": kwargs.pop("num_predict", 24),
137
+ **kwargs,
138
+ }
139
+
140
+ # Only print connection message if explicitly enabled (for CLI use, not web UI)
141
+ # print("🚀 Connecting to Ollama at:", base_url)
142
+ self.client = ChatOllama(model=model_name, base_url=base_url, model_kwargs=model_kwargs)
143
+
144
+ # Simple move history tracking
145
+ self.move_history = []
146
+ self.player_id = None
147
+
148
+ def set_move_history(self, history):
149
+ """Set the recent move history for this agent."""
150
+ self.move_history = history
151
+
152
+ # def _validate_move(self, context: str, move: str) -> Tuple[bool, str]:
153
+ # """Ask the LLM to self-check legality based on board + history."""
154
+ # prompt = (
155
+ # self.VALIDATION_GUIDANCE
156
+ # + "\n\nBOARD + HISTORY CONTEXT:\n"
157
+ # + context
158
+ # + f"\n\nCANDIDATE MOVE: {move}\nRespond strictly with VALID or INVALID and a reason."
159
+ # )
160
+ # verdict = self._llm_once(prompt)
161
+ # if not verdict:
162
+ # return False, "empty validation response"
163
+ # verdict_upper = verdict.strip().upper()
164
+ # if verdict_upper.startswith("VALID"):
165
+ # return True, ""
166
+ # if verdict_upper.startswith("INVALID"):
167
+ # reason = verdict.split(":", 1)[1].strip() if ":" in verdict else "marked invalid"
168
+ # return False, reason
169
+ # return False, f"unrecognized verdict: {verdict[:60]}"
170
+
171
+ # Run one LLM call
172
+ def _llm_once(self, prompt: str) -> str:
173
+
174
+ """Send request directly to Ollama REST API (fixes Windows LangChain bug)."""
175
+ try:
176
+ response = requests.post(
177
+ "http://localhost:11434/api/generate",
178
+ json={
179
+ "model": self.model_name,
180
+ "prompt": prompt,
181
+ "stream": False
182
+ },
183
+ timeout=300
184
+ )
185
+ if response.status_code == 200:
186
+ data = response.json()
187
+ return (data.get("response") or "").strip()
188
+ else:
189
+ print(f"Ollama returned HTTP {response.status_code}: {response.text}")
190
+ return ""
191
+ except Exception as e:
192
+ print(f"Ollama request failed: {e}")
193
+ return ""
194
+
195
+ def __call__(self, observation: str) -> str:
196
+ # Build context
197
+ slim = slice_board_and_moves(observation)
198
+ available_moves = set(extract_legal_moves(observation))
199
+ forbidden_moves = set(extract_forbidden(observation))
200
+
201
+ prompt_history_lines = []
202
+ for line in observation.splitlines():
203
+ if line.startswith("Turn ") or "played[" in line:
204
+ prompt_history_lines.append(line)
205
+ history = "\n".join(prompt_history_lines)
206
+ full_context = slim + ("\n\nMOVE HISTORY:\n" + history if history else "")
207
+
208
+ def _detect_board_size(obs: str) -> Optional[int]:
209
+ """Infer board size from numeric header (e.g., '0 1 2 3')."""
210
+ header_re = re.compile(r"^\s*0(\s+\d+)+\s*$")
211
+ lines = obs.splitlines()
212
+ for i in range(len(lines) - 1, -1, -1):
213
+ if header_re.match(lines[i].strip()):
214
+ nums = [int(n) for n in lines[i].split() if n.isdigit()]
215
+ if nums:
216
+ return max(nums) + 1
217
+ return None
218
+
219
+ def _build_board_map(obs: str) -> dict[str, str]:
220
+ size_local = _detect_board_size(obs)
221
+ if not size_local:
222
+ return {}
223
+ lines = obs.splitlines()
224
+ header_idx = None
225
+ header_re = re.compile(r"^\s*0(\s+\d+)+\s*$")
226
+ for i in range(len(lines)):
227
+ if header_re.match(lines[i].strip()):
228
+ header_idx = i
229
+ break
230
+ if header_idx is None:
231
+ return {}
232
+ board_map: dict[str, str] = {}
233
+ # Expect size_local lines after header
234
+ for r in range(size_local):
235
+ line_idx = header_idx + 1 + r
236
+ if line_idx >= len(lines):
237
+ break
238
+ parts = lines[line_idx].split()
239
+ if not parts:
240
+ continue
241
+ row_label = parts[0]
242
+ cells = parts[1:]
243
+ if len(cells) < size_local:
244
+ continue
245
+ for c in range(size_local):
246
+ pos = f"{row_label.upper()}{c}"
247
+ board_map[pos] = cells[c]
248
+ return board_map
249
+
250
+ board_map = _build_board_map(observation)
251
+
252
+ # >>> THE CRITICAL FIX <<<
253
+ guidance = (
254
+ self.STRATEGIC_GUIDANCE
255
+ + "\n\n"
256
+ + self.prompt_pack.guidance(full_context)
257
+ )
258
+
259
+ recent_moves = set()
260
+ if len(self.move_history) >= 2:
261
+ recent_moves = {m["move"] for m in self.move_history[-2:]}
262
+
263
+ last_error = None
264
+ last_raw: str = ""
265
+ invalid_memory = []
266
+ BARE_MOVE_RE = re.compile(r"\b([A-Z]\d+)\s+([A-Z]\d+)\b")
267
+
268
+ def _extract_move(raw: str):
269
+ m = MOVE_RE.search(raw or "")
270
+
271
+ if m:
272
+ return m.group(0)
273
+ m2 = BARE_MOVE_RE.search(raw or "")
274
+ if m2:
275
+ return f"[{m2.group(1)} {m2.group(2)}]"
276
+ return None
277
+
278
+ # generation + self-validation loop (4 attempts max)
279
+ for attempt in range(4):
280
+ decorated_guidance = guidance
281
+ if invalid_memory:
282
+ decorated_guidance += "\n\nPreviously invalid moves (avoid these):\n" + "\n".join(invalid_memory)
283
+
284
+ raw = self._llm_once(decorated_guidance)
285
+ last_raw = raw or last_raw
286
+ if not raw:
287
+ last_error = "empty response (timeout or HTTP error)"
288
+ continue
289
+
290
+ mv = _extract_move(raw)
291
+ if not mv:
292
+ last_error = f"no move found in response: {raw[:80]!r}"
293
+ continue
294
+
295
+ # Geometric sanity check: block diagonals and multi-step moves from non-Scout pieces
296
+ try:
297
+ src, dst = mv.strip("[]").split()
298
+ sr, sc = ord(src[0]) - 65, int(src[1:])
299
+ dr, dc = ord(dst[0]) - 65, int(dst[1:])
300
+ drow = abs(dr - sr)
301
+ dcol = abs(dc - sc)
302
+ src_token = board_map.get(src, "")
303
+ # Block moving empty/unknown/lake squares
304
+ if src_token in {"", ".", "?", "~"}:
305
+ invalid_memory.append(f"{mv} (source not movable)")
306
+ last_error = "source not movable"
307
+ continue
308
+ # Diagonal
309
+ if drow > 0 and dcol > 0:
310
+ invalid_memory.append(f"{mv} (diagonal not allowed)")
311
+ last_error = "diagonal"
312
+ continue
313
+ # Multi-step non-Scout
314
+ if drow + dcol > 1:
315
+ is_scout = src_token.upper() in {"SC", "SCOUT"}
316
+ if not is_scout:
317
+ invalid_memory.append(f"{mv} (non-Scout multi-step)")
318
+ last_error = "non-Scout multi-step"
319
+ continue
320
+ except Exception:
321
+ pass
322
+
323
+ # quick deterministic veto using env-provided lists
324
+ if available_moves and mv not in available_moves:
325
+ invalid_memory.append(f"{mv} (not in Available Moves)")
326
+ last_error = f"{mv} not in Available Moves"
327
+ print(f" LLM proposed move not in Available Moves: {mv}")
328
+ continue
329
+ if mv in forbidden_moves:
330
+ invalid_memory.append(f"{mv} (in FORBIDDEN)")
331
+ last_error = f"{mv} in FORBIDDEN"
332
+ print(f" LLM proposed forbidden move {mv}")
333
+ continue
334
+
335
+ if mv in recent_moves and len(recent_moves) > 0:
336
+ last_error = f"repeated move {mv}"
337
+ print(f" LLM proposed recent move {mv}, trying alternatives...")
338
+ continue
339
+
340
+ if available_moves:
341
+ return mv
342
+
343
+ def _first_valid_from_list(candidates):
344
+ for mv in candidates:
345
+ if available_moves and mv not in available_moves:
346
+ # print(f" Fallback move not in Available Moves: {mv}")
347
+ continue
348
+ if mv in forbidden_moves:
349
+ # print(f" Fallback forbidden move: {mv}")
350
+ continue
351
+ if mv in recent_moves and len(recent_moves) > 0:
352
+ continue
353
+ if available_moves:
354
+ return mv
355
+ # is_valid, reason = self._validate_move(full_context, mv)
356
+ # if is_valid:
357
+ # return mv
358
+ # print(f" Fallback invalid move {mv}: {reason}")
359
+ return None
360
+
361
+ if last_raw:
362
+ candidates = MOVE_RE.findall(last_raw or "")
363
+ if candidates:
364
+ mv = _first_valid_from_list(candidates)
365
+ if mv:
366
+ return mv
367
+
368
+ # Try to pick a random valid move from available moves
369
+ obs_moves = MOVE_RE.findall(observation)
370
+ if obs_moves:
371
+ mv = _first_valid_from_list(obs_moves)
372
+ if mv:
373
+ return mv
374
+ non_recent = [mv for mv in obs_moves if mv not in recent_moves]
375
+ if non_recent:
376
+ return random.choice(non_recent)
377
+ return random.choice(obs_moves)
378
+
379
+ print(f"[AGENT] {self.model_name} failed to produce valid move after retries.")
380
+ if last_error:
381
+ print(f" Last error: {last_error}")
382
+
383
+ return ""