Spaces:
Paused
Paused
| """ | |
| Judge ↔ human alignment: persistence + agreement computation. | |
| Pairs each LLM-judge verdict (from ``potato/ai/judge.py``) with the human gold | |
| label for the same instance/schema and computes Cohen's κ, a confusion matrix, | |
| agreement rate, and the list of disagreements. Judge predictions are persisted | |
| per *prompt version* so the admin report can track κ as the rubric is calibrated. | |
| Layout under ``{task_dir}/judge_alignment/``: | |
| predictions.json -> {prompt_version: {"<instance>::<schema>": JudgePrediction}} | |
| comparisons.json -> [{instance_id, schema, human_label, judge_label, agrees, prompt_version}] | |
| (running log written by the inline capture path) | |
| The κ computation reuses ``potato/agreement.py`` (judge vs. human gold as two | |
| "annotators"). The pure ``compute_alignment_from_pairs`` is the unit-testable core. | |
| """ | |
| import json | |
| import logging | |
| import os | |
| from collections import Counter, defaultdict | |
| from typing import Any, Dict, List, Optional, Tuple | |
| logger = logging.getLogger(__name__) | |
| # ----- paths / persistence ---------------------------------------------- | |
def _dir(config: Dict[str, Any]) -> str:
    """Return the ``judge_alignment`` directory path for this config.

    The parent directory is the first truthy value among
    ``config["output_annotation_dir"]`` and ``config["task_dir"]``; when
    neither is set, the current directory (``"."``) is used.
    """
    for key in ("output_annotation_dir", "task_dir"):
        root = config.get(key)
        if root:
            return os.path.join(root, "judge_alignment")
    return os.path.join(".", "judge_alignment")
def _load_json(path: str, default):
    """Load a JSON document from ``path``, falling back to ``default``.

    ``default`` is returned when:

    * the file does not exist,
    * the content cannot be decoded/parsed (``ValueError`` covers both
      ``json.JSONDecodeError`` and ``UnicodeDecodeError``), or
    * the parsed top-level value is not of the same type as ``default``
      (e.g. a list on disk where the caller expects a dict).

    The last check keeps callers that immediately ``.append`` / ``.setdefault``
    on the result from crashing on a hand-edited or mis-shaped file. It is
    skipped when ``default`` is ``None``.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (FileNotFoundError, ValueError):
        return default
    if default is not None and not isinstance(data, type(default)):
        return default
    return data
def _save_json(path: str, data) -> None:
    """Serialise ``data`` as pretty-printed UTF-8 JSON at ``path``.

    The parent directory is created when ``path`` has one (a bare filename
    is written to the current directory). The document is first written to
    ``<path>.tmp`` and then moved over ``path`` with ``os.replace``, so a
    failed serialisation never leaves a truncated file behind.

    Raises whatever ``open`` / ``json.dump`` / ``os.replace`` raise; in that
    case any existing file at ``path`` is left untouched.
    """
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(parent, exist_ok=True)
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, path)
    except BaseException:
        # Drop the partial temp file; the previous content of `path` is intact.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def predictions_path(config: Dict[str, Any]) -> str:
    """Return the path of ``predictions.json`` inside the judge-alignment directory."""
    alignment_dir = _dir(config)
    return os.path.join(alignment_dir, "predictions.json")
def comparisons_path(config: Dict[str, Any]) -> str:
    """Return the path of ``comparisons.json`` inside the judge-alignment directory."""
    alignment_dir = _dir(config)
    return os.path.join(alignment_dir, "comparisons.json")
def load_predictions(config: Dict[str, Dict[str, dict]]) -> Dict[str, Dict[str, dict]]:
    """Read the persisted predictions file.

    Returns whatever ``_load_json`` yields for ``predictions_path(config)``,
    using an empty dict as the fallback value.
    """
    path = predictions_path(config)
    return _load_json(path, {})
def save_prediction(config: Dict[str, Any], pred) -> None:
    """Store one judge prediction and rewrite the predictions file.

    The prediction is filed under ``pred.prompt_version`` (``"default"`` when
    that is falsy) with the key ``"<instance_id>::<schema_name>"``; an existing
    entry with the same key is overwritten. ``pred`` must expose
    ``prompt_version``, ``instance_id``, ``schema_name`` and ``to_dict()``.
    """
    store = load_predictions(config)
    serialised = pred.to_dict()
    version_key = pred.prompt_version or "default"
    entry_key = f"{pred.instance_id}::{pred.schema_name}"
    bucket = store.setdefault(version_key, {})
    bucket[entry_key] = serialised
    _save_json(predictions_path(config), store)
def latest_prompt_version(config: Dict[str, Any]) -> Optional[str]:
    """Return the prompt version holding the most persisted predictions.

    Despite the name, "latest" is decided by population, not by time: the
    version with the largest number of entries wins, and on a tie the one that
    appears first in the predictions mapping is kept. Returns ``None`` when no
    predictions are stored.
    """
    by_version = load_predictions(config)
    if not by_version:
        return None
    winner = None
    winner_size = -1
    for version in by_version.keys():
        size = len(by_version[version])
        # Strict ">" keeps the first version seen among equally sized ones.
        if size > winner_size:
            winner, winner_size = version, size
    return winner
def record_comparison(config: Dict[str, Any], instance_id: str, schema: str,
                      human_label: Any, judge_label: Any, prompt_version: str) -> None:
    """Append one human↔judge comparison to ``comparisons.json``.

    Both labels are stored as strings, and ``agrees`` records whether those
    string forms are equal. The whole log is loaded, extended by one entry and
    written back.
    """
    log_path = comparisons_path(config)
    entries = _load_json(log_path, [])
    human_text = str(human_label)
    judge_text = str(judge_label)
    entry = {
        "instance_id": instance_id,
        "schema": schema,
        "human_label": human_text,
        "judge_label": judge_text,
        "agrees": human_text == judge_text,
        "prompt_version": prompt_version,
    }
    entries.append(entry)
    _save_json(log_path, entries)
def running_agreement(config: Dict[str, Any], schema: Optional[str] = None) -> Dict[str, Any]:
    """Summarise the comparison log (used for the inline badge).

    Args:
        config: task configuration; locates ``comparisons.json``.
        schema: when given, only log entries for this schema are counted and
            Cohen's κ is computed for it. Without a schema, all entries are
            counted and ``kappa`` is ``None``.

    Returns:
        ``{"n", "agreements", "agreement_rate", "kappa"}`` where
        ``agreement_rate`` is rounded to 3 decimals (``0.0`` for an empty log).
    """
    log = _load_json(comparisons_path(config), [])
    if schema:
        log = [c for c in log if c.get("schema") == schema]
    n = len(log)
    agree = sum(1 for c in log if c.get("agrees"))

    kappa = None
    if schema and log:
        # Pairs are only needed for κ, so build them only on this path.
        # `.get` keeps a log entry with a missing field from raising KeyError;
        # compute_alignment_from_pairs drops pairs whose label is None.
        pairs = [
            (c.get("instance_id"), c.get("human_label"), c.get("judge_label"), None, "")
            for c in log
        ]
        res = compute_alignment_from_pairs({schema: pairs}).get(schema, {})
        kappa = res.get("kappa")

    return {"n": n, "agreements": agree,
            "agreement_rate": round(agree / n, 3) if n else 0.0, "kappa": kappa}
| # ----- human label extraction -------------------------------------------- | |
def human_label_for(instance_id: str, schema_name: str, username: str) -> Optional[str]:
    """Return the label ``username`` stored for ``schema_name`` on an instance.

    Looks up the user's annotations for ``instance_id`` and returns the first
    key of the entry stored under ``schema_name``. Returns ``None`` when the
    user has no annotations, no entry for the schema, or an empty entry.
    """
    from potato.flask_server import get_annotations_for_user_on

    user_annotations = get_annotations_for_user_on(username, instance_id) or {}
    selection = user_annotations.get(schema_name)
    if not selection:
        return None
    # Treated as single-choice: only the first selected label name is used.
    for label_name in selection.keys():
        return label_name
    return None
def majority_human_label(instance_id: str, schema_name: str, users: List[str]) -> Optional[str]:
    """Return the most frequent label the given users assigned, or ``None``.

    Users without a label for this instance/schema do not vote. When several
    labels share the top count, the one encountered first (in ``users`` order)
    is returned.
    """
    tally = Counter()
    for user in users:
        label = human_label_for(instance_id, schema_name, user)
        if label is None:
            continue
        tally[label] += 1
    if not tally:
        return None
    top_label, _count = tally.most_common(1)[0]
    return top_label
| # ----- agreement computation (pure core) ---------------------------------- | |
def compute_alignment_from_pairs(
    pairs_by_schema: Dict[str, List[Tuple[str, Any, Any, Optional[float], str]]],
) -> Dict[str, Any]:
    """Compute per-schema judge↔human alignment from resolved pairs.

    Args:
        pairs_by_schema: ``{schema: [(instance_id, human_label, judge_label,
            judge_confidence|None, reasoning), ...]}``. Pairs whose human or
            judge label is ``None`` are ignored. Labels are compared by their
            ``str()`` form.

    Returns:
        ``{schema: {kappa, interpretation, agreement_rate, n, confusion,
        disagreements[]}}``. ``confusion`` maps human label → judge label →
        count. ``kappa`` is rounded to 3 decimals, or ``None`` when it could
        not be computed (then ``interpretation`` is ``"n/a"``, or
        ``"no overlap"`` for a schema with no usable pairs).

    The κ step is best-effort: any failure in it (including its imports) is
    logged as a warning and leaves ``kappa`` as ``None``, while agreement rate,
    confusion matrix and disagreements are still returned.
    """
    import math

    out: Dict[str, Any] = {}
    for schema, pairs in pairs_by_schema.items():
        resolved = [p for p in pairs if p[1] is not None and p[2] is not None]
        n = len(resolved)
        if n == 0:
            out[schema] = {"kappa": None, "interpretation": "no overlap",
                           "agreement_rate": 0.0, "n": 0, "confusion": {},
                           "disagreements": []}
            continue

        # Single pass: agreement count, confusion matrix, disagreement list and
        # the long-format rows (one per rater per unit) fed to the κ helper.
        agree = 0
        confusion: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
        disagreements = []
        rows = []
        for inst, h, j, conf, reason in resolved:
            human, judge = str(h), str(j)
            confusion[human][judge] += 1
            rows.append({"unit": inst, "annotator": "human", "annotation": human})
            rows.append({"unit": inst, "annotator": "judge", "annotation": judge})
            if human == judge:
                agree += 1
            else:
                disagreements.append({
                    "instance_id": inst, "human_label": human, "judge_label": judge,
                    "judge_confidence": conf, "reasoning": reason,
                })

        kappa = None
        interp = "n/a"
        try:
            # Imported inside the guard so a missing dependency only costs κ,
            # not the whole report.
            from potato.agreement import cohen_kappa_pairwise, interpret_kappa
            import pandas as pd

            res = cohen_kappa_pairwise(pd.DataFrame(rows))
            kappa = res.get("mean_kappa")
            # Defensive: a NaN κ (undefined statistic) is reported as missing
            # rather than interpreted or averaged downstream.
            if isinstance(kappa, float) and math.isnan(kappa):
                kappa = None
            if kappa is not None:
                interp = interpret_kappa(kappa)
        except Exception as e:
            logger.warning(f"Judge alignment: kappa failed for {schema}: {e}")

        out[schema] = {
            "kappa": round(kappa, 3) if isinstance(kappa, (int, float)) else None,
            "interpretation": interp,
            "agreement_rate": round(agree / n, 3),
            "n": n,
            "confusion": {h: dict(js) for h, js in confusion.items()},
            "disagreements": disagreements,
        }
    return out
| # ----- gathering from persisted predictions + live human labels ----------- | |
def judge_scoped_schemas(config: Dict[str, Any]) -> List[dict]:
    """Annotation schemes the judge should evaluate (categorical only).

    Keeps the entries of ``config["annotation_schemes"]`` whose
    ``annotation_type`` is ``radio``, ``select`` or ``likert``. When
    ``config["judge_alignment"]["schemas"]`` is non-empty, its keys act as an
    allow-list on the scheme ``name``.

    A missing or ``None`` ``annotation_schemes``, ``judge_alignment`` or
    ``schemas`` entry is treated as empty (no schemes / no allow-list).
    """
    schemes = config.get("annotation_schemes", []) or []
    alignment_cfg = config.get("judge_alignment", {}) or {}
    # `or {}` matters: a present-but-null "schemas" key would otherwise
    # reach set() as None.
    allow = set(alignment_cfg.get("schemas") or {})
    cats = {"radio", "select", "likert"}
    return [
        s for s in schemes
        if s.get("annotation_type") in cats
        and (not allow or s.get("name") in allow)
    ]
def gather_pairs(config: Dict[str, Any], users: List[str], schema_names: List[str],
                 prompt_version: Optional[str]) -> Dict[str, List[Tuple]]:
    """Build (instance, human_gold, judge_label, conf, reasoning) pairs.

    Args:
        config: task configuration; locates the persisted predictions.
        users: usernames whose labels vote for the human gold label.
        schema_names: schemas to collect pairs for; every name gets an entry
            in the result, possibly an empty list.
        prompt_version: prediction set to use; when falsy, the result of
            ``latest_prompt_version`` is used instead.

    Returns:
        ``{schema: [(instance_id, gold, predicted_label, confidence,
        reasoning), ...]}``. Predictions for other schemas, and instances with
        no human label, are skipped.
    """
    preds = load_predictions(config)
    version = prompt_version or latest_prompt_version(config)
    version_preds = preds.get(version, {}) if version else {}
    pairs_by_schema: Dict[str, List[Tuple]] = {s: [] for s in schema_names}
    for key, pred in version_preds.items():
        # Keys are "<instance_id>::<schema>". Split at the first "::" as
        # before; if that does not yield a requested schema, retry at the last
        # "::" so instance ids that themselves contain "::" are not dropped.
        for split in (key.partition, key.rpartition):
            instance_id, sep, schema = split("::")
            if sep and schema in pairs_by_schema:
                break
        else:
            continue
        gold = majority_human_label(instance_id, schema, users)
        if gold is None:
            continue
        pairs_by_schema[schema].append((
            instance_id, gold, pred.get("predicted_label"),
            pred.get("confidence"), pred.get("reasoning", ""),
        ))
    return pairs_by_schema
def annotated_instance_ids(users: List[str], schema_name: str) -> List[str]:
    """Return the sorted ids of instances labeled for ``schema_name``.

    An instance is included when at least one of ``users`` has a label for the
    schema on it. Users for whom ``get_user_state`` returns nothing are skipped.
    """
    from potato.flask_server import get_user_state

    labeled = set()
    for user in users:
        state = get_user_state(user)
        if state:
            labeled.update(
                iid for iid in state.get_annotated_instance_ids()
                if human_label_for(iid, schema_name, user) is not None
            )
    return sorted(labeled)
def run_judge_batch(config: Dict[str, Any], users: List[str],
                    rubric_overrides: Optional[Dict[str, str]] = None,
                    max_per_schema: Optional[int] = None) -> Dict[str, Any]:
    """Run the judge over human-annotated instances and persist predictions.

    Args:
        config: task configuration. It is shallow-copied; rubric overrides are
            applied to the copy, not to the caller's dict.
        users: usernames whose annotations decide which instances are judged.
        rubric_overrides: ``{schema_name: rubric}`` written into the working
            copy's ``judge_alignment.schemas[<name>]["rubric"]``.
        max_per_schema: when truthy, only the first N (sorted) instance ids
            per schema are judged.

    Returns:
        ``{"judged", "failed", "prompt_version"}`` — counts of persisted and
        ``None`` predictions, and the prompt version of the last persisted
        prediction (``None`` if nothing was persisted).

    When few-shot is enabled, the instance being judged is excluded from its
    own examples. If an item's text cannot be loaded, a warning is logged and
    the judge is still called with an empty text.
    """
    from potato.ai.judge import JudgeService
    from potato.item_state_management import get_item_state_manager

    # Apply rubric overrides into a working config copy.
    cfg = dict(config)
    ja = dict(cfg.get("judge_alignment", {}) or {})
    if rubric_overrides:
        schemas_cfg = dict(ja.get("schemas", {}) or {})
        for name, rubric in rubric_overrides.items():
            sc = dict(schemas_cfg.get(name, {}) or {})
            sc["rubric"] = rubric
            schemas_cfg[name] = sc
        ja["schemas"] = schemas_cfg
    cfg["judge_alignment"] = ja

    service = JudgeService(cfg)
    ism = get_item_state_manager()
    few_shot_cfg = (ja.get("few_shot") or {})
    use_few_shot = bool(few_shot_cfg.get("enabled", False))

    n_judged, n_failed, version_seen = 0, 0, None
    for schema in judge_scoped_schemas(cfg):
        schema_name = schema.get("name")
        ids = annotated_instance_ids(users, schema_name)
        if max_per_schema:
            ids = ids[:max_per_schema]
        examples = _few_shot_examples(schema_name, use_few_shot, few_shot_cfg)
        for iid in ids:
            try:
                item = ism.get_item(iid)
                text = item.get_text() if item else ""
            except Exception as e:
                # Still best-effort, but no longer silent: an empty text makes
                # the resulting prediction suspect, so leave a trace.
                logger.warning(f"Judge batch: could not load text for {iid}: {e}")
                text = ""
            # Never show the judge the gold label of the instance it is judging.
            shots = [e for e in examples if e.get("id") != iid] or None
            pred = service.judge_instance(iid, schema, text, few_shot_examples=shots)
            if pred is None:
                n_failed += 1
                continue
            save_prediction(cfg, pred)
            version_seen = pred.prompt_version
            n_judged += 1
    return {"judged": n_judged, "failed": n_failed, "prompt_version": version_seen}
def _few_shot_examples(schema_name: str, enabled: bool, cfg: Dict[str, Any]) -> List[dict]:
    """Collect few-shot examples for ``schema_name`` from the ICL labeler.

    Returns ``[]`` when few-shot is disabled, when no labeler is available, or
    when anything in the gathering step raises (a warning is logged in that
    case). Otherwise returns at most ``cfg["max_examples"]`` (default 5) dicts
    with the keys ``id``, ``text`` and ``label``, read from the example
    objects via attribute lookups with fallbacks.
    """
    if not enabled:
        return []
    try:
        from potato.ai.icl_labeler import get_icl_labeler

        labeler = get_icl_labeler()
        if labeler is None:
            return []
        refreshed = labeler.refresh_high_confidence_examples()
        candidates = refreshed.get(schema_name, [])
        limit = int(cfg.get("max_examples", 5))
        shots = []
        for example in candidates[:limit]:
            shots.append({
                "id": getattr(example, "instance_id", ""),
                "text": getattr(example, "instance_text", getattr(example, "text", "")),
                "label": getattr(example, "label", getattr(example, "agreed_label", "")),
            })
        return shots
    except Exception as e:
        logger.warning(f"Judge few-shot example gathering failed: {e}")
        return []
def compute_judge_alignment(config: Dict[str, Any], users: List[str],
                            prompt_version: Optional[str] = None) -> Dict[str, Any]:
    """Full report: per-schema alignment for a prompt version + version list.

    Args:
        config: task configuration.
        users: usernames whose labels form the human gold label.
        prompt_version: version to report on; when falsy, the result of
            ``latest_prompt_version`` is used.

    Returns:
        ``{"prompt_version", "per_schema", "prompt_versions"}`` where
        ``prompt_versions`` lists every stored version (sorted by name) with
        its prediction count and the mean of its non-``None`` per-schema κ
        values (``None`` when no schema has a κ).
    """
    schemas = [s.get("name") for s in judge_scoped_schemas(config)]
    version = prompt_version or latest_prompt_version(config)
    per_schema = compute_alignment_from_pairs(
        gather_pairs(config, users, schemas, version)
    )

    preds = load_predictions(config)
    versions = []
    for v, v_preds in preds.items():
        if v == version:
            # Same inputs as the report above; reuse it instead of re-reading
            # predictions and re-resolving every human label.
            v_report = per_schema
        else:
            v_report = compute_alignment_from_pairs(
                gather_pairs(config, users, schemas, v)
            )
        kappas = [r["kappa"] for r in v_report.values() if r.get("kappa") is not None]
        versions.append({
            "prompt_version": v,
            "n_predictions": len(v_preds),
            "mean_kappa": round(sum(kappas) / len(kappas), 3) if kappas else None,
        })

    return {
        "prompt_version": version,
        "per_schema": per_schema,
        "prompt_versions": sorted(versions, key=lambda x: x["prompt_version"]),
    }