# stylsteer-vlm/src/eval/judge.py
"""LLM Judge — evaluates captions on Style Score, Semantic Score, Fluency Score.
Primary judge: Claude Sonnet 4.5
Spot-check judge: GPT-4o (20% sample)
Anti-anchoring: SS=1 and SS=5 anchor examples for each style are included in every prompt.
"""
import json
import logging
import os
import re
import time
from typing import Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Cost tracking
_api_call_log = []
def _log_api_call(model: str, tokens: int, cost_estimate: float):
"""Log API call for budget tracking."""
_api_call_log.append({
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"model": model,
"tokens": tokens,
"cost_estimate_usd": cost_estimate,
})
def get_api_spend() -> float:
"""Get total estimated API spend so far."""
return sum(c["cost_estimate_usd"] for c in _api_call_log)
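

# The module docstring mentions a 20% GPT-4o spot-check, but the selection
# logic is not in this file. The helper below is a minimal sketch of one
# deterministic scheme (a hash gate keyed on the caption text); the function
# name and the keying choice are assumptions, not part of the original pipeline.
def should_spot_check(caption_id: str, rate: float = 0.2) -> bool:
    """Deterministically route ~`rate` of captions to the spot-check judge."""
    import zlib
    return (zlib.crc32(caption_id.encode("utf-8")) % 100) < int(rate * 100)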
def build_judge_prompt(
caption: str,
style: str,
style_definition: str,
anchor_ss1: str,
anchor_ss5: str,
image_description: Optional[str] = None,
) -> str:
"""Build the LLM judge evaluation prompt.
Includes anti-anchoring examples (SS=1 and SS=5 anchors).
"""
prompt = f"""You are an expert evaluator of image captions. Score the following caption on three dimensions.
## Style: {style}
Definition: {style_definition}
## Scoring Rubric
### Style Score (SS) — 1 to 5
How well does the caption adhere to the "{style}" style?
- SS=1 (No style): "{anchor_ss1}"
- SS=5 (Perfect style): "{anchor_ss5}"
### Semantic Score (SemS) — 1 to 5
How accurately does the caption describe the image content?
- SemS=1: Caption is completely unrelated to the image
- SemS=5: Caption accurately describes all key elements in the image
### Fluency Score (Flu) — 1 to 5
How grammatically correct and natural does the caption read?
- Flu=1: Incoherent, garbled text
- Flu=5: Perfectly natural, well-formed English
## Caption to Evaluate
"{caption}"
"""
if image_description:
prompt += f"\n## Image Description (for SemS reference)\n{image_description}\n"
prompt += """
## Instructions
Respond with ONLY a JSON object:
{"ss": <1-5>, "sems": <1-5>, "flu": <1-5>}
Be critical and use the full range. Do not default to middle scores.
"""
return prompt
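

# Illustrative call (the "noir" style and anchor strings are invented
# placeholders, not entries from the project's style registry):
#
#   prompt = build_judge_prompt(
#       caption="The neon bled into the gutters like a bad alibi.",
#       style="noir",
#       style_definition="Moody, hard-boiled detective-fiction narration.",
#       anchor_ss1="A city street at night.",
#       anchor_ss5="The rain kept secrets the way the city kept bodies.",
#   )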
def _coerce_scores(data) -> Optional[Dict[str, float]]:
    """Normalise judge JSON key variants (case and naming) into ss/sems/flu floats."""
    if not isinstance(data, dict):
        return None
    try:
        return {
            "ss": float(data.get("ss", data.get("SS", 3))),
            "sems": float(data.get("sems", data.get("SemS", data.get("semantic", 3)))),
            "flu": float(data.get("flu", data.get("Flu", data.get("fluency", 3)))),
        }
    except (TypeError, ValueError):
        return None


def parse_judge_response(response: str) -> Dict[str, float]:
    """Parse the judge's JSON response.

    Returns:
        Dict with keys "ss", "sems", "flu". Falls back to neutral 3.0 scores
        when no JSON object can be recovered from the response.
    """
    response = response.strip()
    # Try a direct JSON parse first
    try:
        scores = _coerce_scores(json.loads(response))
        if scores is not None:
            return scores
    except json.JSONDecodeError:
        pass
    # Otherwise look for the first {...} object embedded in surrounding text
    json_match = re.search(r'\{[^}]+\}', response)
    if json_match:
        try:
            scores = _coerce_scores(json.loads(json_match.group()))
            if scores is not None:
                return scores
        except json.JSONDecodeError:
            pass
    logger.warning(f"Failed to parse judge response: {response[:200]}")
    return {"ss": 3.0, "sems": 3.0, "flu": 3.0}  # Fallback to neutral
def judge_caption_claude(
caption: str,
style: str,
style_definition: str,
anchor_ss1: str,
anchor_ss5: str,
image_description: Optional[str] = None,
    model: str = "claude-sonnet-4-5-20250929",  # Claude Sonnet 4.5 snapshot ID
dry_run: bool = False,
) -> Dict[str, float]:
"""Score a caption using Claude Sonnet.
Args:
caption: The caption to evaluate
style: Style name
style_definition: Style definition
anchor_ss1: SS=1 anchor example
anchor_ss5: SS=5 anchor example
image_description: Optional neutral image description for SemS
model: Claude model ID
dry_run: If True, return mock scores
Returns:
Dict with "ss", "sems", "flu" scores
"""
    if dry_run:
        import random
        # Seed with the raw string: stable across runs, unlike hash(),
        # which is salted per-interpreter in Python 3.
        rng = random.Random(caption + style)
        return {
            "ss": round(rng.uniform(1, 5), 1),
            "sems": round(rng.uniform(2, 5), 1),
            "flu": round(rng.uniform(3, 5), 1),
        }
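    # Lazy import keeps the anthropic SDK optional when only dry-run is used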
import anthropic
client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
prompt = build_judge_prompt(
caption, style, style_definition, anchor_ss1, anchor_ss5, image_description
)
try:
response = client.messages.create(
model=model,
max_tokens=100,
messages=[{"role": "user", "content": prompt}],
)
text = response.content[0].text
tokens = response.usage.input_tokens + response.usage.output_tokens
        _log_api_call(model, tokens, tokens * 3e-6)  # Rough blended rate, ~$3 per 1M tokens
return parse_judge_response(text)
except Exception as e:
logger.error(f"Claude judge error: {e}")
return {"ss": 3.0, "sems": 3.0, "flu": 3.0}
def judge_caption_gpt4o(
caption: str,
style: str,
style_definition: str,
anchor_ss1: str,
anchor_ss5: str,
image_description: Optional[str] = None,
model: str = "gpt-4o",
dry_run: bool = False,
) -> Dict[str, float]:
"""Score a caption using GPT-4o (spot-check judge)."""
    if dry_run:
        import random
        # String seed keeps mock scores stable across runs; the "#gpt4o"
        # suffix decorrelates this judge from the Claude dry-run scores.
        rng = random.Random(caption + style + "#gpt4o")
        return {
            "ss": round(rng.uniform(1, 5), 1),
            "sems": round(rng.uniform(2, 5), 1),
            "flu": round(rng.uniform(3, 5), 1),
        }
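    # Lazy import keeps the openai SDK optional when only dry-run is used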
from openai import OpenAI
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
prompt = build_judge_prompt(
caption, style, style_definition, anchor_ss1, anchor_ss5, image_description
)
try:
response = client.chat.completions.create(
model=model,
max_tokens=100,
messages=[{"role": "user", "content": prompt}],
)
text = response.choices[0].message.content
tokens = response.usage.total_tokens
        _log_api_call(model, tokens, tokens * 5e-6)  # Rough blended rate, ~$5 per 1M tokens
return parse_judge_response(text)
except Exception as e:
logger.error(f"GPT-4o judge error: {e}")
return {"ss": 3.0, "sems": 3.0, "flu": 3.0}
def compute_cohens_kappa(
    scores_a: List[Dict[str, float]],
    scores_b: List[Dict[str, float]],
    threshold: float = 0.5,
) -> float:
    """Compute a Cohen's-kappa-style agreement statistic between two judges.

    Each caption pair is discretised into agree/disagree: the judges agree
    when all three scores differ by at most `threshold`. Note that the
    chance-agreement term p_e is a fixed heuristic rather than estimated
    from the score marginals, so this is only an approximation of true
    Cohen's kappa.
    """
    assert len(scores_a) == len(scores_b)
    n = len(scores_a)
    if n == 0:
        return 0.0
    agreements = 0
    for a, b in zip(scores_a, scores_b):
        # Judges agree only if within threshold on all three dimensions
        if (abs(a["ss"] - b["ss"]) <= threshold and
                abs(a["sems"] - b["sems"]) <= threshold and
                abs(a["flu"] - b["flu"]) <= threshold):
            agreements += 1
    p_o = agreements / n  # Observed agreement
    # Chance agreement: fixed heuristic for a 5-point scale at this threshold
    p_e = 0.2
    kappa = (p_o - p_e) / (1 - p_e + 1e-8)
    return max(min(kappa, 1.0), -1.0)
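

# The fixed p_e above makes the statistic only kappa-like. For a single score
# dimension, textbook Cohen's kappa can be computed with empirical marginals
# by binning scores to the integer 1-5 levels. This variant is a sketch added
# for reference; it is not part of the original evaluation protocol.
def compute_cohens_kappa_per_dim(
    scores_a: List[Dict[str, float]],
    scores_b: List[Dict[str, float]],
    dim: str = "ss",
) -> float:
    """Cohen's kappa for one dimension, with chance agreement from marginals."""
    a = [int(round(s[dim])) for s in scores_a]
    b = [int(round(s[dim])) for s in scores_b]
    n = len(a)
    if n == 0:
        return 0.0
    p_o = sum(x == y for x, y in zip(a, b)) / n  # Observed agreement
    # Chance agreement from each judge's marginal distribution over levels 1-5
    p_e = sum((a.count(k) / n) * (b.count(k) / n) for k in range(1, 6))
    if p_e >= 1.0:
        return 1.0  # Both judges constant and identical: perfect agreement
    return (p_o - p_e) / (1 - p_e)


if __name__ == "__main__":
    # Dry-run smoke test: deterministic mock scores, no API keys required.
    # The "noir" style and anchor strings are illustrative placeholders.
    kwargs = dict(
        caption="The neon bled into the gutters like a bad alibi.",
        style="noir",
        style_definition="Moody, hard-boiled detective-fiction narration.",
        anchor_ss1="A city street at night.",
        anchor_ss5="The rain kept secrets the way the city kept bodies.",
        dry_run=True,
    )
    claude_scores = judge_caption_claude(**kwargs)
    gpt_scores = judge_caption_gpt4o(**kwargs)
    print("Claude :", claude_scores)
    print("GPT-4o :", gpt_scores)
    print("kappa  :", compute_cohens_kappa([claude_scores], [gpt_scores]))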