Spaces:
Sleeping
Sleeping
| """Calibration module — 100-sample calibration + Cohen's kappa check.""" | |
| import logging | |
| import random | |
| from typing import Dict, List, Tuple | |
| from src.eval.judge import ( | |
| compute_cohens_kappa, | |
| judge_caption_claude, | |
| judge_caption_gpt4o, | |
| ) | |
logger = logging.getLogger(__name__)


def run_calibration(
    captions: List[str],
    styles: List[str],
    style_definitions: Dict[str, str],
    style_anchors: Dict[str, Dict[str, str]],
    n_calibration: int = 100,
    kappa_threshold: float = 0.65,
    seed: int = 42,
    dry_run: bool = False,
) -> Tuple[float, bool]:
    """Run calibration with both judges on a reproducible random sample.

    A seeded random subset of ``captions`` is scored independently by
    ``judge_caption_claude`` and ``judge_caption_gpt4o`` (same inputs to
    both), the two score lists are compared with ``compute_cohens_kappa``,
    and the result is checked against ``kappa_threshold``.

    Args:
        captions: All generated captions.
        styles: Style label for each caption; ``styles[i]`` labels
            ``captions[i]``, so both lists must have the same length.
        style_definitions: Style name -> definition. Styles missing here are
            judged with an empty definition.
        style_anchors: Style name -> {"anchor_ss1": ..., "anchor_ss5": ...}.
            Missing styles or keys fall back to empty strings.
        n_calibration: Number of samples to calibrate on; capped at
            ``len(captions)``.
        kappa_threshold: Minimum acceptable kappa (inclusive).
        seed: Random seed for sampling, so the same subset is drawn each run.
        dry_run: Forwarded to both judges; use mock judge responses.

    Returns:
        ``(kappa, passed)`` where ``passed`` is ``kappa >= kappa_threshold``.
        If the sample is empty (no captions, or ``n_calibration == 0``), no
        judge is called and ``(0.0, False)`` is returned.

    Raises:
        ValueError: If ``captions`` and ``styles`` differ in length, or if
            ``n_calibration`` is negative (raised by ``random.Random.sample``).
    """
    # Validate up front: a mismatch would otherwise surface as an IndexError
    # midway through sampling, after judge calls have already been spent.
    if len(captions) != len(styles):
        raise ValueError(
            "captions and styles must have the same length "
            f"(got {len(captions)} and {len(styles)})"
        )

    rng = random.Random(seed)
    n = min(n_calibration, len(captions))
    if n == 0:
        # Kappa over zero pairs is undefined; report a failed calibration
        # instead of passing empty score lists to compute_cohens_kappa.
        logger.warning("CALIBRATION FAILED: no samples to calibrate on (n=0).")
        return 0.0, False

    indices = rng.sample(range(len(captions)), n)

    claude_scores = []
    gpt4o_scores = []
    for idx in indices:
        caption = captions[idx]
        style = styles[idx]
        defn = style_definitions.get(style, "")
        anchors = style_anchors.get(style, {})
        anchor_ss1 = anchors.get("anchor_ss1", "")
        anchor_ss5 = anchors.get("anchor_ss5", "")
        # Both judges receive identical inputs so their scores are comparable.
        claude_scores.append(
            judge_caption_claude(
                caption, style, defn, anchor_ss1, anchor_ss5, dry_run=dry_run,
            )
        )
        gpt4o_scores.append(
            judge_caption_gpt4o(
                caption, style, defn, anchor_ss1, anchor_ss5, dry_run=dry_run,
            )
        )

    kappa = compute_cohens_kappa(claude_scores, gpt4o_scores)
    passed = kappa >= kappa_threshold
    logger.info(
        "Calibration: kappa=%.3f, threshold=%s, passed=%s, n=%d",
        kappa, kappa_threshold, passed, n,
    )
    if not passed:
        logger.warning(
            "CALIBRATION FAILED: kappa=%.3f < %s. "
            "Revise style definitions / judge prompt before proceeding.",
            kappa, kappa_threshold,
        )
    return kappa, passed