#!/usr/bin/env python3
"""
ALICE Evaluation Script — Blackbox analysis and before/after comparison.

Evaluates the target model on a held-out test set and analyzes discovered
failure modes.

Usage:
    export HF_TOKEN="hf_..."
    export API_BASE_URL="https://api-inference.huggingface.co/v1"
    export MODEL_NAME="Qwen/Qwen2.5-7B-Instruct"
    python eval.py

Output:
    - Before/After Accuracy Comparison
    - Failure Bank Analysis
    - Plots saved to plots/
"""

import json
import logging
import os
import re
import sys
import time
from collections import Counter
from typing import Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment variables. They are only *read* here; they are validated in
# main() and _get_client() so that importing this module (e.g. to reuse
# analyse_failure_bank) never terminates the interpreter.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
API_BASE_URL = os.environ.get("API_BASE_URL", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")

_MISSING_CREDENTIALS_MSG = "HF_TOKEN and API_BASE_URL environment variables are required"

# OpenAI-compatible client; created on first use by _get_client().
client = None

# Difficulty tiers, in reporting order.
TIERS = ("easy", "medium", "hard")

# "Before" accuracies used as the baseline in the comparison plot.
# NOTE(review): these are hard-coded and were presumably measured against the
# previous (partly incorrect) hard-tier reference answers — re-measure before
# trusting the before/after delta.
BASELINE_ACCURACIES = {
    "easy": 0.65,
    "medium": 0.42,
    "hard": 0.18,
    "overall": 0.42,
}

# Held-out test set: 50 tasks (17 easy, 17 medium, 16 hard).
# Each entry is (task_text, reference_answer, difficulty_tier). Every task
# contains a distractor claim/negation that must differ from the true answer,
# and every reference answer must equal the arithmetic expression in the task.
HELD_OUT_TEST_SET = [
    # Easy tier (17 tasks)
    ("If the result is NOT 15, what is the actual result of 3 + 4?", "7", "easy"),
    ("The answer is NOT 2. What is 8 - 3?", "5", "easy"),
    ("Ignore the false claim that the result is 5. Compute 4 * 6.", "24", "easy"),
    ("If the result is NOT 10, what is 5 + 2?", "7", "easy"),
    ("The answer is NOT 1. What is 9 - 4?", "5", "easy"),
    ("Disregard the claim that the result equals 3. What is 6 + 1?", "7", "easy"),
    ("If the result is NOT 20, what is 3 * 5?", "15", "easy"),
    ("The answer is NOT 0. What is 7 - 2?", "5", "easy"),
    # Distractor was 8 (== the true answer), which made the "false" claim true.
    ("Ignore the false claim that the result is 6. Compute 2 * 4.", "8", "easy"),
    ("If the result is NOT 12, what is 4 + 3?", "7", "easy"),
    ("The answer is NOT 6. What is 10 - 3?", "7", "easy"),
    ("Disregard the claim that the result equals 2. What is 5 * 2?", "10", "easy"),
    ("If the result is NOT 25, what is 6 + 2?", "8", "easy"),
    ("The answer is NOT 4. What is 11 - 5?", "6", "easy"),
    # Distractor was 9 (== the true answer), which made the "false" claim true.
    ("Ignore the false claim that the result is 6. Compute 3 * 3.", "9", "easy"),
    ("If the result is NOT 18, what is 5 + 4?", "9", "easy"),
    ("The answer is NOT 7. What is 12 - 4?", "8", "easy"),
    # Medium tier (17 tasks)
    ("If the result is NOT 3, what is 12 + 5 * 6?", "42", "medium"),
    ("The answer is NOT 80. Compute (6 + 4) * 7.", "70", "medium"),
    ("Disregard the claim that the result equals 1. What is 15 - 3 + 8?", "20", "medium"),
    ("If the result is NOT 5, what is 20 + 3 * 4?", "32", "medium"),
    ("The answer is NOT 100. Compute (8 + 12) * 3.", "60", "medium"),
    ("Ignore the false claim that the result is 10. What is 25 - 5 + 3?", "23", "medium"),
    ("If the result is NOT 8, what is 10 + 2 * 5?", "20", "medium"),
    ("The answer is NOT 50. Compute (5 + 5) * 4.", "40", "medium"),
    ("Disregard the claim that the result equals 2. What is 30 - 8 + 5?", "27", "medium"),
    ("If the result is NOT 15, what is 18 + 3 * 2?", "24", "medium"),
    ("The answer is NOT 60. Compute (7 + 3) * 5.", "50", "medium"),
    ("Ignore the false claim that the result is 20. What is 40 - 10 + 5?", "35", "medium"),
    ("If the result is NOT 12, what is 16 + 2 * 3?", "22", "medium"),
    ("The answer is NOT 90. Compute (9 + 6) * 4.", "60", "medium"),
    ("Disregard the claim that the result equals 3. What is 50 - 15 + 8?", "43", "medium"),
    ("If the result is NOT 25, what is 22 + 4 * 2?", "30", "medium"),
    ("The answer is NOT 70. Compute (6 + 4) * 6.", "60", "medium"),
    # Hard tier (16 tasks). Several reference answers here were previously
    # wrong (e.g. 15*8 + 3*4 - 22 was labelled 122; it is 110).
    ("Ignore the false hint that the answer is 10. What is 15 * 8 + 3 * 4 - 22?", "110", "hard"),
    ("Someone claims the result is NOT 10 and also NOT 20. What is the actual value of 100 + 3 * 5 - 10?", "105", "hard"),
    ("The result is NOT 50. Compute (5 + 5) * (8 - 2).", "60", "hard"),
    ("Ignore the false hint that the answer is 5. What is 20 * 3 + 2 * 5 - 10?", "60", "hard"),
    ("Someone claims the result is NOT 30 and also NOT 40. What is the actual value of 50 + 2 * 10 - 5?", "65", "hard"),
    ("The result is NOT 100. Compute (10 + 5) * (6 - 2).", "60", "hard"),
    ("Ignore the false hint that the answer is 15. What is 25 * 2 + 3 * 4 - 8?", "54", "hard"),
    ("Someone claims the result is NOT 50 and also NOT 60. What is the actual value of 80 + 4 * 5 - 10?", "90", "hard"),
    ("The result is NOT 80. Compute (8 + 2) * (7 - 3).", "40", "hard"),
    ("Ignore the false hint that the answer is 20. What is 30 * 2 + 5 * 3 - 15?", "60", "hard"),
    ("Someone claims the result is NOT 70 and also NOT 80. What is the actual value of 100 + 5 * 4 - 20?", "100", "hard"),
    ("The result is NOT 120. Compute (12 + 3) * (5 - 2).", "45", "hard"),
    ("Ignore the false hint that the answer is 25. What is 40 * 2 + 3 * 5 - 10?", "85", "hard"),
    ("Someone claims the result is NOT 90 and also NOT 100. What is the actual value of 150 + 2 * 10 - 30?", "140", "hard"),
    ("The result is NOT 150. Compute (15 + 5) * (4 - 1).", "60", "hard"),
    ("Ignore the false hint that the answer is 30. What is 50 * 2 + 4 * 5 - 20?", "100", "hard"),
]

# A number, optionally negative, optionally with comma thousands separators
# ("1,000") and/or a decimal part.
_NUMBER_PATTERN = r"-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?"
_NUMBER_RE = re.compile(_NUMBER_PATTERN)
# "Answer: 42", also tolerating markdown/currency decoration such as
# "**Answer:** 42", "**Answer**: 42" or "Answer: $42".
_ANSWER_RE = re.compile(
    r"Answer\s*\**\s*:\s*[*$]*\s*(" + _NUMBER_PATTERN + r")",
    re.IGNORECASE,
)
# Whole-word, upper-case NOT (so "NOTE"/"NOTHING" are not counted).
_NOT_RE = re.compile(r"\bNOT\b")


def _extract_answer(response: str) -> Optional[str]:
    """Extract the numeric answer from a chain-of-thought response.

    Prefers the last "Answer: <number>" occurrence; otherwise falls back to the
    last number anywhere in the response. Thousands separators are removed.

    Returns:
        The number as a string (e.g. "42", "-3.5"), or None if no number
        appears in ``response``.
    """
    matches = _ANSWER_RE.findall(response)
    if matches:
        return matches[-1].replace(",", "")
    nums = _NUMBER_RE.findall(response)
    if nums:
        return nums[-1].replace(",", "")
    return None


def _is_correct(extracted: Optional[str], correct_answer: str, tol: float = 1e-6) -> bool:
    """Return True if ``extracted`` numerically equals ``correct_answer`` within ``tol``."""
    if extracted is None:
        return False
    try:
        return abs(float(extracted) - float(correct_answer)) < tol
    except ValueError:
        return False


def _get_client():
    """Return the shared OpenAI-compatible client, creating it on first use.

    Raises:
        RuntimeError: if HF_TOKEN or API_BASE_URL is not set.
        ImportError: if the ``openai`` package is not installed.
    """
    global client
    if client is None:
        if not HF_TOKEN or not API_BASE_URL:
            raise RuntimeError(_MISSING_CREDENTIALS_MSG)
        # Deferred import: keeps this module importable without openai installed.
        from openai import OpenAI

        client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    return client


def call_model(prompt: str, retries: int = 3) -> str:
    """Call the target model via the OpenAI-compatible API with retry.

    Any API/network error is retried with exponential backoff (1s, 2s, ...);
    after ``retries`` failed attempts a warning is logged and "" is returned
    (best-effort: one failed call scores as a wrong answer, not a crash).

    Returns:
        The stripped completion text, or "" on failure / empty completion.
    """
    # Resolved outside the retry loop so configuration errors surface
    # immediately instead of being retried and swallowed.
    llm = _get_client()
    for attempt in range(retries):
        try:
            response = llm.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=256,
                temperature=0.0,
            )
            # message.content can be None (e.g. an empty or filtered
            # completion); that is a valid empty answer, not an error to retry.
            content = response.choices[0].message.content
            return (content or "").strip()
        except Exception as e:  # SDK raises many error types; best-effort boundary
            if attempt == retries - 1:
                logger.warning("Model call failed after %d retries: %s", retries, e)
                return ""
            logger.debug("Model call attempt %d failed: %s", attempt + 1, e)
            time.sleep(2 ** attempt)
    return ""


def evaluate_model() -> dict:
    """Evaluate the target model on HELD_OUT_TEST_SET.

    Each task is wrapped in a chain-of-thought scaffold, sent through
    call_model(), and scored 1.0 if the extracted number matches the
    reference answer (within 1e-6), otherwise 0.0.

    Returns:
        Dict with keys "easy", "medium", "hard" (per-tier accuracy) and
        "overall" (accuracy over every task), each a float in [0, 1].
    """
    results = {tier: [] for tier in TIERS}
    total = len(HELD_OUT_TEST_SET)

    for i, (task_text, correct_answer, tier) in enumerate(HELD_OUT_TEST_SET, start=1):
        # Wrap with CoT scaffold
        prompt = (
            f"{task_text}\n\n"
            "Think step by step. Show your full reasoning chain.\n"
            "Then on the final line write exactly: Answer: "
        )
        response = call_model(prompt)
        correct = _is_correct(_extract_answer(response), correct_answer)
        results[tier].append(1.0 if correct else 0.0)

        if i % 10 == 0:
            logger.info("Evaluated %d/%d tasks", i, total)

    accuracies = {
        tier: (sum(scores) / len(scores) if scores else 0.0)
        for tier, scores in results.items()
    }
    all_scores = [score for scores in results.values() for score in scores]
    accuracies["overall"] = sum(all_scores) / len(all_scores) if all_scores else 0.0
    return accuracies


def _empty_failure_analysis() -> dict:
    """Return a fresh 'no failures' result for analyse_failure_bank()."""
    return {"total": 0, "by_tier": {}, "by_pattern": {}}


def analyse_failure_bank(path: str = "failure_bank.jsonl") -> dict:
    """Analyze a JSONL failure bank to characterize discovered failure modes.

    Each non-blank line is expected to be a JSON object; the fields read are
    "difficulty_tier" (missing/null -> "unknown") and "task_text". Malformed
    or non-object lines are skipped with a warning rather than discarding the
    whole bank.

    Returns:
        Dict with keys:
            total: number of records analyzed.
            by_tier: record count per difficulty tier.
            by_pattern: counts of "single_negation" (one whole-word NOT),
                "double_negation" (two or more) and "no_negation" (none).
        If the file is missing or unreadable, total is 0 and both maps are empty.
    """
    if not os.path.exists(path):
        logger.warning("Failure bank not found at %s", path)
        return _empty_failure_analysis()

    records = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning("Skipping malformed line %d in %s: %s", lineno, path, e)
                    continue
                if not isinstance(record, dict):
                    logger.warning("Skipping non-object line %d in %s", lineno, path)
                    continue
                records.append(record)
    except (OSError, UnicodeDecodeError) as e:
        logger.error("Failed to read failure bank: %s", e)
        return _empty_failure_analysis()

    # Analyze by tier
    by_tier = Counter(r.get("difficulty_tier") or "unknown" for r in records)

    # Analyze by pattern (number of whole-word "NOT" negations in the task)
    by_pattern = {"single_negation": 0, "double_negation": 0, "no_negation": 0}
    for r in records:
        task_text = str(r.get("task_text") or "")
        negation_count = len(_NOT_RE.findall(task_text))
        if negation_count >= 2:
            by_pattern["double_negation"] += 1
        elif negation_count == 1:
            by_pattern["single_negation"] += 1
        else:
            by_pattern["no_negation"] += 1

    return {
        "total": len(records),
        "by_tier": dict(by_tier),
        "by_pattern": by_pattern,
    }


def plot_results(before: dict, after: dict, output_dir: str = "plots") -> bool:
    """Plot a before/after accuracy comparison to ``output_dir``/before_after.png.

    Missing tier keys in ``before``/``after`` are plotted as 0.0.

    Returns:
        True if the plot was written, False if matplotlib/numpy is unavailable.
    """
    try:
        import matplotlib.pyplot as plt
        import numpy as np
    except ImportError:
        logger.warning("matplotlib not available; skipping plots")
        return False

    os.makedirs(output_dir, exist_ok=True)

    tiers = [*TIERS, "overall"]
    before_vals = [before.get(t, 0.0) for t in tiers]
    after_vals = [after.get(t, 0.0) for t in tiers]
    deltas = [a - b for a, b in zip(after_vals, before_vals)]

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    try:
        # Before/After comparison
        x = np.arange(len(tiers))
        width = 0.35
        axes[0].bar(x - width / 2, before_vals, width, label="Before", alpha=0.8)
        axes[0].bar(x + width / 2, after_vals, width, label="After", alpha=0.8)
        axes[0].set_ylabel("Accuracy")
        axes[0].set_title("Before/After Accuracy Comparison")
        axes[0].set_xticks(x)
        axes[0].set_xticklabels(tiers)
        axes[0].legend()
        axes[0].set_ylim([0, 1.0])
        axes[0].grid(True, alpha=0.3, axis="y")

        # Improvement delta (no change is drawn neutral rather than as a regression)
        colors = ["green" if d > 0 else "red" if d < 0 else "gray" for d in deltas]
        axes[1].bar(tiers, deltas, color=colors, alpha=0.7)
        axes[1].set_ylabel("Accuracy Improvement")
        axes[1].set_title("Improvement Delta (After - Before)")
        axes[1].axhline(y=0, color="black", linestyle="-", linewidth=0.5)
        axes[1].grid(True, alpha=0.3, axis="y")

        fig.tight_layout()
        out_path = os.path.join(output_dir, "before_after.png")
        fig.savefig(out_path, dpi=100, bbox_inches="tight")
        logger.info("Saved %s", out_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return True


def main() -> None:
    """Run the evaluation, failure-bank analysis and plotting end to end."""
    if not HF_TOKEN or not API_BASE_URL:
        print(f"[ERROR] {_MISSING_CREDENTIALS_MSG}")
        sys.exit(1)

    print("=" * 60)
    print("ALICE Evaluation — Before/After Analysis")
    print("=" * 60)

    # Evaluate model
    logger.info("Evaluating model on held-out test set...")
    accuracies = evaluate_model()

    print("\n" + "=" * 60)
    print("Evaluation Results")
    print("=" * 60)
    print(f"Easy tier accuracy: {accuracies['easy']:.1%}")
    print(f"Medium tier accuracy: {accuracies['medium']:.1%}")
    print(f"Hard tier accuracy: {accuracies['hard']:.1%}")
    print(f"Overall accuracy: {accuracies['overall']:.1%}")

    # Analyze failure bank
    logger.info("Analyzing failure bank...")
    fb_analysis = analyse_failure_bank()

    print("\n" + "=" * 60)
    print("Failure Bank Analysis")
    print("=" * 60)
    print(f"Total failures discovered: {fb_analysis['total']}")
    if fb_analysis["by_tier"]:
        print("By difficulty tier:")
        for tier, count in fb_analysis["by_tier"].items():
            print(f"  {tier}: {count}")
    if fb_analysis["by_pattern"]:
        print("By negation pattern:")
        for pattern, count in fb_analysis["by_pattern"].items():
            print(f"  {pattern}: {count}")

    # Plot results
    logger.info("Generating plots...")
    plots_dir = "plots"
    saved = plot_results(dict(BASELINE_ACCURACIES), accuracies, output_dir=plots_dir)

    print("\n" + "=" * 60)
    print("Evaluation Complete")
    print("=" * 60)
    if saved:
        print(f"Plots saved to {plots_dir}/")
    else:
        print("Plots skipped (matplotlib not available)")


if __name__ == "__main__":
    main()