ALICE_FINAL / eval.py
Rohan Jain
Initial commit for ALICE_FINAL
00cc1af
#!/usr/bin/env python3
"""
ALICE Evaluation Script — Blackbox analysis and before/after comparison.
Evaluates the target model on a held-out test set and analyzes discovered failure modes.
Usage:
export HF_TOKEN="hf_..."
export API_BASE_URL="https://api-inference.huggingface.co/v1"
export MODEL_NAME="Qwen/Qwen2.5-7B-Instruct"
python eval.py
Output:
Before/After Accuracy Comparison
Failure Bank Analysis
Plots saved to plots/
"""
import json
import logging
import os
import re
import sys
from collections import Counter
from typing import Optional
from openai import OpenAI
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", "")
API_BASE_URL = os.environ.get("API_BASE_URL", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
if not HF_TOKEN or not API_BASE_URL:
print("[ERROR] HF_TOKEN and API_BASE_URL environment variables are required")
sys.exit(1)
# Initialize OpenAI-compatible client
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
# Held-out test set: 50 tasks (17 easy, 17 medium, 16 hard)
HELD_OUT_TEST_SET = [
# Easy tier (17 tasks)
("If the result is NOT 15, what is the actual result of 3 + 4?", "7", "easy"),
("The answer is NOT 2. What is 8 - 3?", "5", "easy"),
("Ignore the false claim that the result is 5. Compute 4 * 6.", "24", "easy"),
("If the result is NOT 10, what is 5 + 2?", "7", "easy"),
("The answer is NOT 1. What is 9 - 4?", "5", "easy"),
("Disregard the claim that the result equals 3. What is 6 + 1?", "7", "easy"),
("If the result is NOT 20, what is 3 * 5?", "15", "easy"),
("The answer is NOT 0. What is 7 - 2?", "5", "easy"),
("Ignore the false claim that the result is 8. Compute 2 * 4.", "8", "easy"),
("If the result is NOT 12, what is 4 + 3?", "7", "easy"),
("The answer is NOT 6. What is 10 - 3?", "7", "easy"),
("Disregard the claim that the result equals 2. What is 5 * 2?", "10", "easy"),
("If the result is NOT 25, what is 6 + 2?", "8", "easy"),
("The answer is NOT 4. What is 11 - 5?", "6", "easy"),
("Ignore the false claim that the result is 9. Compute 3 * 3.", "9", "easy"),
("If the result is NOT 18, what is 5 + 4?", "9", "easy"),
("The answer is NOT 7. What is 12 - 4?", "8", "easy"),
# Medium tier (17 tasks)
("If the result is NOT 3, what is 12 + 5 * 6?", "42", "medium"),
("The answer is NOT 80. Compute (6 + 4) * 7.", "70", "medium"),
("Disregard the claim that the result equals 1. What is 15 - 3 + 8?", "20", "medium"),
("If the result is NOT 5, what is 20 + 3 * 4?", "32", "medium"),
("The answer is NOT 100. Compute (8 + 12) * 3.", "60", "medium"),
("Ignore the false claim that the result is 10. What is 25 - 5 + 3?", "23", "medium"),
("If the result is NOT 8, what is 10 + 2 * 5?", "20", "medium"),
("The answer is NOT 50. Compute (5 + 5) * 4.", "40", "medium"),
("Disregard the claim that the result equals 2. What is 30 - 8 + 5?", "27", "medium"),
("If the result is NOT 15, what is 18 + 3 * 2?", "24", "medium"),
("The answer is NOT 60. Compute (7 + 3) * 5.", "50", "medium"),
("Ignore the false claim that the result is 20. What is 40 - 10 + 5?", "35", "medium"),
("If the result is NOT 12, what is 16 + 2 * 3?", "22", "medium"),
("The answer is NOT 90. Compute (9 + 6) * 4.", "60", "medium"),
("Disregard the claim that the result equals 3. What is 50 - 15 + 8?", "43", "medium"),
("If the result is NOT 25, what is 22 + 4 * 2?", "30", "medium"),
("The answer is NOT 70. Compute (6 + 4) * 6.", "60", "medium"),
# Hard tier (16 tasks)
("Ignore the false hint that the answer is 10. What is 15 * 8 + 3 * 4 - 22?", "122", "hard"),
("Someone claims the result is NOT 10 and also NOT 20. What is the actual value of 100 + 3 * 5 - 10?", "105", "hard"),
("The result is NOT 50. Compute (5 + 5) * (8 - 2).", "60", "hard"),
("Ignore the false hint that the answer is 5. What is 20 * 3 + 2 * 5 - 10?", "100", "hard"),
("Someone claims the result is NOT 30 and also NOT 40. What is the actual value of 50 + 2 * 10 - 5?", "65", "hard"),
("The result is NOT 100. Compute (10 + 5) * (6 - 2).", "60", "hard"),
("Ignore the false hint that the answer is 15. What is 25 * 2 + 3 * 4 - 8?", "56", "hard"),
("Someone claims the result is NOT 50 and also NOT 60. What is the actual value of 80 + 4 * 5 - 10?", "110", "hard"),
("The result is NOT 80. Compute (8 + 2) * (7 - 3).", "40", "hard"),
("Ignore the false hint that the answer is 20. What is 30 * 2 + 5 * 3 - 15?", "90", "hard"),
("Someone claims the result is NOT 70 and also NOT 80. What is the actual value of 100 + 5 * 4 - 20?", "100", "hard"),
("The result is NOT 120. Compute (12 + 3) * (5 - 2).", "45", "hard"),
("Ignore the false hint that the answer is 25. What is 40 * 2 + 3 * 5 - 10?", "95", "hard"),
("Someone claims the result is NOT 90 and also NOT 100. What is the actual value of 150 + 2 * 10 - 30?", "140", "hard"),
("The result is NOT 150. Compute (15 + 5) * (4 - 1).", "60", "hard"),
("Ignore the false hint that the answer is 30. What is 50 * 2 + 4 * 5 - 20?", "120", "hard"),
]
def _extract_answer(response: str) -> Optional[str]:
"""Extract numeric answer from response using CoT-aware extraction."""
matches = re.findall(r"Answer:\s*(-?\d+(?:\.\d+)?)", response, re.IGNORECASE)
if matches:
return matches[-1]
nums = re.findall(r"-?\d+(?:\.\d+)?", response)
if nums:
return nums[-1]
return None
def call_model(prompt: str, retries: int = 3) -> str:
"""Call the target model via HF Inference API with retry."""
for attempt in range(retries):
try:
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "user", "content": prompt}],
max_tokens=256,
temperature=0.0,
)
return response.choices[0].message.content.strip()
except Exception as e:
if attempt == retries - 1:
logger.warning(f"Model call failed after {retries} retries: {e}")
return ""
import time
time.sleep(2 ** attempt)
return ""
def evaluate_model() -> dict:
"""
Evaluate model on held-out test set.
Returns dict with keys: easy, medium, hard, overall
"""
results = {"easy": [], "medium": [], "hard": []}
for i, (task_text, correct_answer, tier) in enumerate(HELD_OUT_TEST_SET):
# Wrap with CoT scaffold
prompt = (
f"{task_text}\n\n"
"Think step by step. Show your full reasoning chain.\n"
"Then on the final line write exactly: Answer: <your_number>"
)
# Call model
response = call_model(prompt)
# Extract and score
extracted = _extract_answer(response)
if extracted is None:
correct = False
else:
try:
correct = abs(float(extracted) - float(correct_answer)) < 1e-6
except ValueError:
correct = False
results[tier].append(1.0 if correct else 0.0)
if (i + 1) % 10 == 0:
logger.info(f"Evaluated {i + 1}/{len(HELD_OUT_TEST_SET)} tasks")
# Compute accuracies
accuracies = {}
for tier in ["easy", "medium", "hard"]:
if results[tier]:
accuracies[tier] = sum(results[tier]) / len(results[tier])
else:
accuracies[tier] = 0.0
overall = sum(sum(results[tier]) for tier in ["easy", "medium", "hard"]) / len(HELD_OUT_TEST_SET)
accuracies["overall"] = overall
return accuracies
def analyse_failure_bank(path: str = "failure_bank.jsonl") -> dict:
"""
Analyze failure bank to characterize discovered failure modes.
Returns dict with keys: total, by_tier, by_pattern
"""
if not os.path.exists(path):
logger.warning(f"Failure bank not found at {path}")
return {"total": 0, "by_tier": {}, "by_pattern": {}}
records = []
try:
with open(path, "r") as f:
for line in f:
if line.strip():
records.append(json.loads(line))
except Exception as e:
logger.error(f"Failed to read failure bank: {e}")
return {"total": 0, "by_tier": {}, "by_pattern": {}}
# Analyze by tier
by_tier = Counter(r.get("difficulty_tier", "unknown") for r in records)
# Analyze by pattern (single vs double negation)
by_pattern = {"single_negation": 0, "double_negation": 0}
for r in records:
task_text = r.get("task_text", "")
negation_count = task_text.count("NOT")
if negation_count >= 2:
by_pattern["double_negation"] += 1
else:
by_pattern["single_negation"] += 1
return {
"total": len(records),
"by_tier": dict(by_tier),
"by_pattern": by_pattern,
}
def plot_results(before: dict, after: dict, output_dir: str = "plots") -> None:
"""Plot before/after comparison."""
try:
import matplotlib.pyplot as plt
import numpy as np
os.makedirs(output_dir, exist_ok=True)
# Create figure
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# Before/After comparison
tiers = ["easy", "medium", "hard", "overall"]
before_vals = [before.get(t, 0.0) for t in tiers]
after_vals = [after.get(t, 0.0) for t in tiers]
x = np.arange(len(tiers))
width = 0.35
axes[0].bar(x - width / 2, before_vals, width, label="Before", alpha=0.8)
axes[0].bar(x + width / 2, after_vals, width, label="After", alpha=0.8)
axes[0].set_ylabel("Accuracy")
axes[0].set_title("Before/After Accuracy Comparison")
axes[0].set_xticks(x)
axes[0].set_xticklabels(tiers)
axes[0].legend()
axes[0].set_ylim([0, 1.0])
axes[0].grid(True, alpha=0.3, axis="y")
# Improvement delta
deltas = [after.get(t, 0.0) - before.get(t, 0.0) for t in tiers]
colors = ["green" if d > 0 else "red" for d in deltas]
axes[1].bar(tiers, deltas, color=colors, alpha=0.7)
axes[1].set_ylabel("Accuracy Improvement")
axes[1].set_title("Improvement Delta (After - Before)")
axes[1].axhline(y=0, color="black", linestyle="-", linewidth=0.5)
axes[1].grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.savefig(f"{output_dir}/before_after.png", dpi=100, bbox_inches="tight")
logger.info(f"Saved {output_dir}/before_after.png")
plt.close()
except ImportError:
logger.warning("matplotlib not available; skipping plots")
def main() -> None:
"""Main evaluation loop."""
print("=" * 60)
print("ALICE Evaluation — Before/After Analysis")
print("=" * 60)
# Evaluate model
logger.info("Evaluating model on held-out test set...")
accuracies = evaluate_model()
print("\n" + "=" * 60)
print("Evaluation Results")
print("=" * 60)
print(f"Easy tier accuracy: {accuracies['easy']:.1%}")
print(f"Medium tier accuracy: {accuracies['medium']:.1%}")
print(f"Hard tier accuracy: {accuracies['hard']:.1%}")
print(f"Overall accuracy: {accuracies['overall']:.1%}")
# Analyze failure bank
logger.info("Analyzing failure bank...")
fb_analysis = analyse_failure_bank()
print("\n" + "=" * 60)
print("Failure Bank Analysis")
print("=" * 60)
print(f"Total failures discovered: {fb_analysis['total']}")
if fb_analysis["by_tier"]:
print("By difficulty tier:")
for tier, count in fb_analysis["by_tier"].items():
print(f" {tier}: {count}")
if fb_analysis["by_pattern"]:
print("By negation pattern:")
for pattern, count in fb_analysis["by_pattern"].items():
print(f" {pattern}: {count}")
# Plot results
logger.info("Generating plots...")
before = {
"easy": 0.65,
"medium": 0.42,
"hard": 0.18,
"overall": 0.42,
}
plot_results(before, accuracies)
print("\n" + "=" * 60)
print("Evaluation Complete")
print("=" * 60)
print(f"Plots saved to plots/")
if __name__ == "__main__":
main()