DarshanScripts committed on
Commit
93ac131
·
verified ·
1 Parent(s): baa0817

Upload stratego\prompts\evaluate_prompts_multiturn.py with huggingface_hub

Browse files
stratego/prompts/evaluate_prompts_multiturn.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ollama
2
+ import re
3
+ import random
4
+ import statistics
5
+ from dataclasses import dataclass
6
+
7
+
8
+ # --- Définition du PromptPack ---
9
@dataclass
class PromptPack:
    """A named prompt variant: a system message plus a per-turn guidance template.

    Attributes:
        name: Short identifier used when reporting scores.
        system: System-level instruction given to the model.
        guidance_template: Turn template; must contain a ``{board_slice}`` slot.
    """

    name: str
    system: str
    guidance_template: str

    def build_prompt(self, board_slice: str) -> str:
        """Render the complete prompt for one board state."""
        guidance = self.guidance_template.format(board_slice=board_slice)
        return "\n\n".join((self.system, guidance))
17
+
18
+
19
+ # --- Variantes de prompts Stratégo ---
20
# Prompt variants compared against each other. The strings are part of the
# experiment and are kept exactly as tuned.
PROMPTS = [
    PromptPack(
        name="base",
        system="You are a Stratego-playing AI. Output exactly one move [SRC DST].",
        guidance_template="{board_slice}\nPick one move from 'Available Moves:' and avoid 'FORBIDDEN:'.",
    ),
    PromptPack(
        name="defensive",
        system="You are a defensive Stratego AI. Prefer safe and backward moves.",
        guidance_template="{board_slice}\nPick one safe move and avoid 'FORBIDDEN:'.",
    ),
    PromptPack(
        name="adaptive",
        system="You are an expert Stratego AI. Balance offense and defense smartly.",
        guidance_template="{board_slice}\nChoose one optimal move considering both safety and progress. Avoid 'FORBIDDEN:'.",
    ),
]
37
+
38
+
39
+ # --- Fonctions utilitaires ---
40
def extract_moves(board_slice: str):
    """Pull the available and forbidden move lists out of a board description.

    Scans for the first line containing ``Available Moves:`` and the first
    containing ``FORBIDDEN:``, then collects every ``[A1 B2]``-shaped token on
    each. Returns an ``(available, forbidden)`` pair of string lists; either
    list is empty when its marker line is absent.
    """
    move_re = re.compile(r"\[[A-Z][0-9] [A-Z][0-9]\]")
    available_line = ""
    forbidden_line = ""
    for line in board_slice.splitlines():
        if not available_line and "Available Moves:" in line:
            available_line = line
        if not forbidden_line and "FORBIDDEN:" in line:
            forbidden_line = line
    return move_re.findall(available_line), move_re.findall(forbidden_line)
46
+
47
+
48
def is_valid_move(move: str, available: list, forbidden: list):
    """Return True when *move* is offered in *available* and not banned in *forbidden*."""
    if move not in available:
        return False
    return move not in forbidden
50
+
51
+
52
def query_ollama(model: str, prompt: str) -> str:
    """Ask the model for a move and return the first ``[SRC DST]`` token.

    Returns the literal string ``"INVALID"`` when the reply contains no move
    token or when the Ollama call fails; failures are printed, never raised,
    so one bad round cannot abort an evaluation run.
    """
    messages = [{"role": "user", "content": prompt}]
    try:
        reply = ollama.chat(model=model, messages=messages)["message"]["content"]
    except Exception as e:
        print(f"⚠️ Ollama error: {e}")
        return "INVALID"
    match = re.search(r"\[[A-Z][0-9] [A-Z][0-9]\]", reply)
    return match.group(0) if match else "INVALID"
61
+
62
+
63
+ # --- Simulation de plusieurs tours ---
64
def generate_board_slices(num_rounds=5):
    """Fabricate *num_rounds* synthetic board descriptions for testing.

    Each board offers four random moves and marks at most one of them as
    forbidden, mimicking the engine's textual output format.
    """
    letters = ["A", "B", "C", "D", "E", "F"]

    def random_move() -> str:
        # Coordinates drawn in the same order as the original f-string
        # (choice, randint, choice, randint) so seeded runs are reproducible.
        src = f"{random.choice(letters)}{random.randint(1, 6)}"
        dst = f"{random.choice(letters)}{random.randint(1, 6)}"
        return f"[{src} {dst}]"

    boards = []
    for _ in range(num_rounds):
        available = [random_move() for _ in range(4)]
        forbidden = random.sample(available, k=random.randint(0, 1))
        boards.append(
            "Available Moves: " + ", ".join(available)
            + "\nFORBIDDEN: " + ", ".join(forbidden)
        )
    return boards
73
+
74
+
75
+ # --- Évaluation multi-turn ---
76
def evaluate_prompts_multiturn(model: str, num_rounds=5):
    """Score every prompt pack in ``PROMPTS`` over *num_rounds* simulated boards.

    For each round, each pack is asked for a move via ``query_ollama``; a move
    scores 1 when it appears in the round's available list and not in its
    forbidden list, else 0. Prints a per-round trace and a final summary.

    Args:
        model: Ollama model identifier passed through to ``query_ollama``.
        num_rounds: Number of synthetic boards to evaluate (default 5).

    Returns:
        dict mapping prompt-pack name to its list of 0/1 validity scores.
    """
    boards = generate_board_slices(num_rounds)
    scores = {p.name: [] for p in PROMPTS}

    print(f"\n Starting evaluation on {num_rounds} rounds with model: {model}\n")

    for round_idx, board_slice in enumerate(boards, start=1):
        available, forbidden = extract_moves(board_slice)
        print(f"\n===== ROUND {round_idx} =====")
        print(board_slice)

        for pack in PROMPTS:
            prompt_text = pack.build_prompt(board_slice)
            move = query_ollama(model, prompt_text)
            valid = is_valid_move(move, available, forbidden)
            scores[pack.name].append(1 if valid else 0)
            print(f"→ {pack.name.upper():<10} | Move: {move:<10} | Valid: {valid}")

    # --- Global summary ---
    print("\n === FINAL RESULTS ===")
    if not boards:
        # Fix: statistics.mean raises StatisticsError on empty data, which the
        # original code hit whenever num_rounds == 0. Report and return early.
        print("No rounds were played; nothing to score.")
        return scores
    for name, result_list in scores.items():
        avg = statistics.mean(result_list)
        print(f"{name.capitalize():<10}: {sum(result_list)}/{len(result_list)} valid moves ({avg*100:.1f}%)")

    best_prompt = max(scores.items(), key=lambda x: statistics.mean(x[1]))[0]
    print(f"\n Best performing prompt: {best_prompt.upper()}")
    return scores
103
+
104
+
105
+ # --- Lancer le test ---
106
+ #if __name__ == "__main__":
107
+ # evaluate_prompts_multiturn("gemma:2b", num_rounds=5)