codebook / potato /solo_mode /refinement /validation.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
9.04 kB
"""
Validation infrastructure for refinement strategies.
- ValidationSplit: deterministically splits disagreements into train/val
- CandidateEvaluator: labels val instances with a candidate prompt/ICL
and returns accuracy against human labels
"""
from __future__ import annotations
import logging
import random
from collections import Counter
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
@dataclass
class SplitResult:
train: List[Dict[str, Any]]
val: List[Dict[str, Any]]
seed: int
class ValidationSplit:
"""Split human-LLM comparison records into train/val deterministically.
Split is seeded by prompt_version so it's stable within a version
(a refinement cycle can re-run without the split changing) but
different across versions (preventing val leakage across cycles).
Only disagreements (agrees=False) are split, because the refinement
process works on disagreements. Agreements stay in the train side as
useful context but aren't needed for eval.
"""
def __init__(
self,
val_ratio: float = 0.3,
min_val: int = 5,
min_train: int = 5,
prefer_consistent: bool = False,
):
"""
Args:
val_ratio: fraction of disagreements held out for validation
min_val: minimum val size; if fewer disagreements exist, returns empty val
min_train: minimum train size; if fewer, returns empty train
prefer_consistent: if True, prefer val instances that have disagreed
across ≥2 labeling passes (systematic errors), falling back to
one-off disagreements only when too few qualify.
"""
self.val_ratio = val_ratio
self.min_val = min_val
self.min_train = min_train
self.prefer_consistent = prefer_consistent
def split(
self,
comparisons: List[Dict[str, Any]],
prompt_version: int,
) -> SplitResult:
"""Split comparisons into train/val.
Args:
comparisons: list of {instance_id, human_label, llm_label, agrees, ...}
prompt_version: used to seed the split deterministically
Returns:
SplitResult with train and val lists; either can be empty if not
enough disagreements are available.
"""
# Only disagreements go to val; agreements stay in train
disagreements = [c for c in comparisons if not c.get('agrees')]
agreements = [c for c in comparisons if c.get('agrees')]
if len(disagreements) < (self.min_val + self.min_train):
logger.info(
f"[ValidationSplit] Only {len(disagreements)} disagreements available; "
f"need at least {self.min_val + self.min_train}. Returning empty splits."
)
return SplitResult(train=[], val=[], seed=prompt_version)
rng = random.Random(f"val_split_v{prompt_version}")
# If preferring consistent disagreements: seed val from instances that
# have disagreed ≥2 times. If that pool is too small, top up with
# one-off disagreements so we still meet min_val.
val, train_disagreements = self._partition(
disagreements, prompt_version, rng
)
# Combine train disagreements with agreements (useful context for
# rule generation — but agreements aren't used for scoring)
train = train_disagreements + agreements
logger.info(
f"[ValidationSplit] v{prompt_version}: "
f"{len(train_disagreements)} train disagreements, "
f"{len(val)} val disagreements, "
f"{len(agreements)} agreements in train context"
)
return SplitResult(train=train, val=val, seed=prompt_version)
def _partition(
self,
disagreements: List[Dict[str, Any]],
prompt_version: int,
rng: random.Random,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Pick val set (optionally preferring consistent disagreements) and return (val, train_disagreements)."""
target_val_size = max(self.min_val, int(len(disagreements) * self.val_ratio))
if not self.prefer_consistent:
shuffled = list(disagreements)
rng.shuffle(shuffled)
return shuffled[:target_val_size], shuffled[target_val_size:]
# Count disagreements per instance_id. An instance that shows up
# multiple times in the disagreement list has failed across at least
# that many labeling passes — treat it as a systematic error rather
# than a one-off stochastic flip.
counts = Counter(c['instance_id'] for c in disagreements)
# Keep only the *latest* disagreement record per instance to avoid
# the same instance appearing multiple times in the val set.
latest_by_iid: Dict[str, Dict[str, Any]] = {}
for c in disagreements:
latest_by_iid[c['instance_id']] = c
consistent = [
latest_by_iid[iid] for iid, n in counts.items() if n >= 2
]
oneoff = [
latest_by_iid[iid] for iid, n in counts.items() if n < 2
]
rng.shuffle(consistent)
rng.shuffle(oneoff)
if len(consistent) >= target_val_size:
val = consistent[:target_val_size]
topup_used = 0
else:
# Not enough consistent disagreements — top up with one-offs so
# we still meet min_val. This makes the filter a preference,
# not a hard gate.
need = target_val_size - len(consistent)
val = consistent + oneoff[:need]
topup_used = min(need, len(oneoff))
val_ids = {c['instance_id'] for c in val}
# Train gets all disagreement records whose instance_id isn't in val.
train_disagreements = [c for c in disagreements if c['instance_id'] not in val_ids]
logger.info(
f"[ValidationSplit] prefer_consistent: {len(consistent)} instances "
f"with ≥2 disagreements, {len(oneoff)} one-offs; "
f"val drew {len(val) - topup_used} consistent + {topup_used} one-off"
)
return val, train_disagreements
@dataclass
class EvalResult:
accuracy: float
correct_count: int
total: int
per_instance: List[Dict[str, Any]] # {instance_id, predicted, human, correct}
class CandidateEvaluator:
"""Evaluate a candidate (prompt edit or ICL example) on a validation set.
Uses a single labeling call per instance (no sampling diversity for speed).
Compares predicted label against the human label already recorded.
"""
def __init__(
self,
label_fn: Callable[[str, str, str], Optional[str]],
get_text_fn: Callable[[str], str],
):
"""
Args:
label_fn: callable(instance_id, text, prompt) -> predicted_label or None
get_text_fn: callable(instance_id) -> text string
"""
self.label_fn = label_fn
self.get_text_fn = get_text_fn
def evaluate(
self,
candidate_prompt: str,
val_comparisons: List[Dict[str, Any]],
sample_size: Optional[int] = None,
) -> EvalResult:
"""Label each val instance with the candidate prompt, compute accuracy.
Args:
candidate_prompt: the full prompt text to evaluate
val_comparisons: list of comparison dicts with human_label
sample_size: if set, randomly sample this many from val_comparisons
Returns:
EvalResult with accuracy and per-instance breakdown
"""
if sample_size and len(val_comparisons) > sample_size:
val_comparisons = random.sample(val_comparisons, sample_size)
correct = 0
per_instance = []
for comp in val_comparisons:
iid = comp['instance_id']
human_label = comp.get('human_label')
if human_label is None:
continue
try:
text = self.get_text_fn(iid)
predicted = self.label_fn(iid, text, candidate_prompt)
except Exception as e:
logger.warning(f"[CandidateEval] Failed to label {iid}: {e}")
predicted = None
is_correct = (
predicted is not None
and str(predicted) == str(human_label)
)
if is_correct:
correct += 1
per_instance.append({
'instance_id': iid,
'predicted': predicted,
'human': human_label,
'correct': is_correct,
})
total = len(per_instance)
accuracy = correct / total if total > 0 else 0.0
return EvalResult(
accuracy=accuracy,
correct_count=correct,
total=total,
per_instance=per_instance,
)