Spaces:
Configuration error
Configuration error
Upload stratego\prompts\evaluate_prompts_multiturn.py with huggingface_hub
Browse files
stratego//prompts//evaluate_prompts_multiturn.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ollama
|
| 2 |
+
import re
|
| 3 |
+
import random
|
| 4 |
+
import statistics
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# --- Définition du PromptPack ---
|
| 9 |
+
@dataclass
class PromptPack:
    """A named prompt variant: a system text plus a guidance template.

    ``guidance_template`` is rendered with ``str.format`` and is given a single
    keyword, ``board_slice``.
    """

    # Label of the variant.
    name: str
    # Text placed first in the built prompt.
    system: str
    # ``str.format`` template that receives ``board_slice``.
    guidance_template: str

    def build_prompt(self, board_slice: str) -> str:
        """Return the system text and the rendered guidance, separated by a blank line."""
        guidance = self.guidance_template.format(board_slice=board_slice)
        return f"{self.system}\n\n{guidance}"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# --- Variantes de prompts Stratégo ---
|
| 20 |
+
# Prompt variants, in the order they are listed here.
# NOTE(review): only the "base" system text spells out the ``[SRC DST]`` output
# format; the other two do not mention it — confirm that is intended.
PROMPTS = [
    PromptPack(
        name="base",
        system="You are a Stratego-playing AI. Output exactly one move [SRC DST].",
        guidance_template="{board_slice}\nPick one move from 'Available Moves:' and avoid 'FORBIDDEN:'.",
    ),
    PromptPack(
        name="defensive",
        system="You are a defensive Stratego AI. Prefer safe and backward moves.",
        guidance_template="{board_slice}\nPick one safe move and avoid 'FORBIDDEN:'.",
    ),
    PromptPack(
        name="adaptive",
        system="You are an expert Stratego AI. Balance offense and defense smartly.",
        guidance_template="{board_slice}\nChoose one optimal move considering both safety and progress. Avoid 'FORBIDDEN:'.",
    ),
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# --- Fonctions utilitaires ---
|
| 40 |
+
def extract_moves(board_slice: str):
    """Parse the available and forbidden move lists out of a board description.

    The first line containing ``"Available Moves:"`` and the first line
    containing ``"FORBIDDEN:"`` are each scanned for bracketed moves such as
    ``[A1 B2]`` (an uppercase letter followed by one digit, twice, separated
    by a single space). A marker that appears on no line yields an empty list.

    Returns:
        An ``(available, forbidden)`` pair of lists of move strings.
    """
    move_pattern = r"\[[A-Z][0-9] [A-Z][0-9]\]"

    def moves_on_first_line_with(marker):
        # Only the first line carrying the marker is considered.
        for line in board_slice.splitlines():
            if marker in line:
                return re.findall(move_pattern, line)
        return []

    available_moves = moves_on_first_line_with("Available Moves:")
    forbidden_moves = moves_on_first_line_with("FORBIDDEN:")
    return available_moves, forbidden_moves
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def is_valid_move(move: str, available: list, forbidden: list):
    """Report whether ``move`` is listed in ``available`` and absent from ``forbidden``."""
    if move not in available:
        return False
    return move not in forbidden
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def query_ollama(model: str, prompt: str) -> str:
    """Send ``prompt`` to an Ollama chat model and return the first move in its reply.

    The prompt is sent as a single ``user`` message. The reply text is searched
    for the first bracketed move such as ``[A1 B2]``.

    Returns:
        The matched move string, or ``"INVALID"`` when the reply contains no
        such move or when any exception is raised along the way (the error is
        printed, not propagated).
    """
    try:
        conversation = [{"role": "user", "content": prompt}]
        reply = ollama.chat(model=model, messages=conversation)
        found = re.search(r"\[[A-Z][0-9] [A-Z][0-9]\]", reply["message"]["content"])
        if found:
            return found.group(0)
        return "INVALID"
    except Exception as err:
        # Best-effort: any failure is reported and mapped to the same sentinel
        # as an unparseable reply.
        print(f"⚠️ Ollama error: {err}")
        return "INVALID"
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# --- Simulation de plusieurs tours ---
|
| 64 |
+
def generate_board_slices(num_rounds=5):
    """Generate random board descriptions for simulated rounds.

    Each description is a two-line string: ``Available Moves:`` followed by
    four comma-separated moves, then ``FORBIDDEN:`` followed by zero or one of
    those same moves. A move looks like ``[A1 B2]``; each square is a letter
    A-F followed by a digit 1-6.

    Within a round every move joins two different squares and no move is
    listed twice, so a null move such as ``[A1 A1]`` can never be offered.

    Args:
        num_rounds: Number of board descriptions to produce.

    Returns:
        A list of ``num_rounds`` board description strings.
    """
    letters = ["A", "B", "C", "D", "E", "F"]
    squares = [f"{letter}{digit}" for letter in letters for digit in range(1, 7)]
    moves_per_round = 4

    boards = []
    for _ in range(num_rounds):
        available = []
        while len(available) < moves_per_round:
            # Sampling two squares without replacement guarantees src != dst.
            src, dst = random.sample(squares, 2)
            move = f"[{src} {dst}]"
            if move not in available:  # keep the offered moves distinct
                available.append(move)
        forbidden = random.sample(available, k=random.randint(0, 1))
        board_slice = f"Available Moves: {', '.join(available)}\nFORBIDDEN: {', '.join(forbidden)}"
        boards.append(board_slice)
    return boards
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# --- Évaluation multi-turn ---
|
| 76 |
+
def evaluate_prompts_multiturn(model: str, num_rounds=5):
    """Score every prompt variant in ``PROMPTS`` over several simulated rounds.

    For each generated board description, every prompt variant is built,
    sent to the model through ``query_ollama``, and the returned move is
    checked against the board's available and forbidden moves. Progress and a
    final summary are printed.

    Args:
        model: Model name passed through to ``query_ollama``.
        num_rounds: Number of board descriptions to generate and evaluate.

    Returns:
        A dict mapping each prompt name to a list with one entry per round:
        ``1`` when the returned move was valid, ``0`` otherwise. The lists are
        empty when no round was played.
    """
    boards = generate_board_slices(num_rounds)
    scores = {p.name: [] for p in PROMPTS}

    print(f"\n Starting evaluation on {num_rounds} rounds with model: {model}\n")

    for round_idx, board_slice in enumerate(boards, start=1):
        available, forbidden = extract_moves(board_slice)
        print(f"\n===== ROUND {round_idx} =====")
        print(board_slice)

        for pack in PROMPTS:
            prompt_text = pack.build_prompt(board_slice)
            move = query_ollama(model, prompt_text)
            valid = is_valid_move(move, available, forbidden)
            scores[pack.name].append(1 if valid else 0)
            print(f"→ {pack.name.upper():<10} | Move: {move:<10} | Valid: {valid}")

    # --- Overall summary ---
    print("\n === FINAL RESULTS ===")
    for name, result_list in scores.items():
        # statistics.mean() raises on empty data, so a prompt with no rounds
        # is reported as 0/0 (0.0%) instead of crashing the summary.
        avg = statistics.mean(result_list) if result_list else 0.0
        print(f"{name.capitalize():<10}: {sum(result_list)}/{len(result_list)} valid moves ({avg*100:.1f}%)")

    # Only prompts that were actually scored can compete for "best"; on ties
    # max() keeps the first one in PROMPTS order.
    averages = {name: statistics.mean(results) for name, results in scores.items() if results}
    if averages:
        best_prompt = max(averages, key=averages.get)
        print(f"\n Best performing prompt: {best_prompt.upper()}")
    else:
        print("\n No rounds were evaluated; no best prompt to report.")
    return scores
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# --- Lancer le test ---
|
| 106 |
+
#if __name__ == "__main__":
|
| 107 |
+
# evaluate_prompts_multiturn("gemma:2b", num_rounds=5)
|