""" Best-Worst Scaling Score Estimation Computes item scores from BWS annotations using three methods: 1. Counting: score = (best_count - worst_count) / appearances (no dependencies) 2. Bradley-Terry: pairwise comparison model via choix (requires choix) 3. Plackett-Luce: partial ranking model via choix (requires choix) Usage as library: from potato.bws_scoring import BwsScorer scorer = BwsScorer(annotations, pool_items, id_key) scores = scorer.counting() Usage as CLI: python -m potato.bws_scoring --config config.yaml --method counting """ import argparse import csv import json import logging import os import sys from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) class BwsScorer: """Compute BWS scores from annotations.""" def __init__( self, annotations: List[Dict[str, Any]], pool_items: List[Dict[str, Any]], id_key: str, text_key: str = "text", ): """ Args: annotations: List of annotation dicts, each with: - "instance_id": tuple instance ID (e.g. "bws_tuple_0001") - "bws_items": list of {source_id, text, position} - "best": position label (e.g. "B") - "worst": position label (e.g. "D") - "annotator": username pool_items: Original pool items id_key: Key for item IDs in pool_items text_key: Key for item text in pool_items """ self.annotations = annotations self.pool_items = pool_items self.id_key = id_key self.text_key = text_key # Build item index self.item_ids = [str(item[id_key]) for item in pool_items] self.item_texts = { str(item[id_key]): str(item.get(text_key, "")) for item in pool_items } self.item_id_to_idx = {iid: idx for idx, iid in enumerate(self.item_ids)} def _resolve_annotation( self, ann: Dict[str, Any] ) -> Optional[Tuple[str, str, List[str]]]: """Resolve an annotation to (best_source_id, worst_source_id, all_source_ids). Returns None if annotation is incomplete. """ best_pos = ann.get("best") worst_pos = ann.get("worst") bws_items = ann.get("bws_items", []) if not best_pos or not worst_pos or not bws_items: return None pos_to_id = {item["position"]: item["source_id"] for item in bws_items} best_id = pos_to_id.get(best_pos) worst_id = pos_to_id.get(worst_pos) if not best_id or not worst_id: return None all_ids = [item["source_id"] for item in bws_items] return best_id, worst_id, all_ids def counting(self) -> Dict[str, Dict[str, Any]]: """Counting method: score = (best_count - worst_count) / appearances. Returns dict mapping item_id to {score, best_count, worst_count, appearances, text}. """ best_counts = {iid: 0 for iid in self.item_ids} worst_counts = {iid: 0 for iid in self.item_ids} appearances = {iid: 0 for iid in self.item_ids} for ann in self.annotations: resolved = self._resolve_annotation(ann) if not resolved: continue best_id, worst_id, all_ids = resolved for iid in all_ids: if iid in appearances: appearances[iid] += 1 if best_id in best_counts: best_counts[best_id] += 1 if worst_id in worst_counts: worst_counts[worst_id] += 1 scores = {} for iid in self.item_ids: app = appearances[iid] if app > 0: score = (best_counts[iid] - worst_counts[iid]) / app else: score = 0.0 scores[iid] = { "score": score, "best_count": best_counts[iid], "worst_count": worst_counts[iid], "appearances": app, "text": self.item_texts.get(iid, ""), } return scores def bradley_terry(self) -> Dict[str, Dict[str, Any]]: """Bradley-Terry model via choix. Converts each BWS annotation to pairwise comparisons: - Best item beats every other item (K-1 comparisons) - Every item beats the worst item (K-1 comparisons) """ try: import choix except ImportError: raise ImportError( "Bradley-Terry scoring requires the 'choix' package. " "Install it with: pip install choix" ) n_items = len(self.item_ids) comparisons = [] for ann in self.annotations: resolved = self._resolve_annotation(ann) if not resolved: continue best_id, worst_id, all_ids = resolved best_idx = self.item_id_to_idx.get(best_id) worst_idx = self.item_id_to_idx.get(worst_id) if best_idx is None or worst_idx is None: continue # Best beats all others for iid in all_ids: idx = self.item_id_to_idx.get(iid) if idx is not None and idx != best_idx: comparisons.append((best_idx, idx)) # All others beat worst for iid in all_ids: idx = self.item_id_to_idx.get(iid) if idx is not None and idx != worst_idx: comparisons.append((idx, worst_idx)) if not comparisons: return { iid: {"score": 0.0, "text": self.item_texts.get(iid, "")} for iid in self.item_ids } params = choix.ilsr_pairwise(n_items, comparisons, alpha=0.01) scores = {} for iid in self.item_ids: idx = self.item_id_to_idx[iid] scores[iid] = { "score": float(params[idx]), "text": self.item_texts.get(iid, ""), } return scores def plackett_luce(self) -> Dict[str, Dict[str, Any]]: """Plackett-Luce model via choix. Converts BWS to partial rankings: Each annotation yields top-1 (best) selections, processed via ilsr_top1. """ try: import choix except ImportError: raise ImportError( "Plackett-Luce scoring requires the 'choix' package. " "Install it with: pip install choix" ) n_items = len(self.item_ids) # Use pairwise comparisons to approximate partial rankings # Best > middle items, middle items > worst comparisons = [] for ann in self.annotations: resolved = self._resolve_annotation(ann) if not resolved: continue best_id, worst_id, all_ids = resolved best_idx = self.item_id_to_idx.get(best_id) worst_idx = self.item_id_to_idx.get(worst_id) if best_idx is None or worst_idx is None: continue middle_ids = [ iid for iid in all_ids if iid != best_id and iid != worst_id ] # Best beats all middle items for iid in middle_ids: idx = self.item_id_to_idx.get(iid) if idx is not None: comparisons.append((best_idx, idx)) # All middle items beat worst for iid in middle_ids: idx = self.item_id_to_idx.get(iid) if idx is not None: comparisons.append((idx, worst_idx)) # Best beats worst comparisons.append((best_idx, worst_idx)) if not comparisons: return { iid: {"score": 0.0, "text": self.item_texts.get(iid, "")} for iid in self.item_ids } params = choix.ilsr_pairwise(n_items, comparisons, alpha=0.01) scores = {} for iid in self.item_ids: idx = self.item_id_to_idx[iid] scores[iid] = { "score": float(params[idx]), "text": self.item_texts.get(iid, ""), } return scores def score(self, method: str = "counting") -> Dict[str, Dict[str, Any]]: """Compute scores using the specified method.""" if method == "counting": return self.counting() elif method == "bradley_terry": return self.bradley_terry() elif method == "plackett_luce": return self.plackett_luce() else: raise ValueError( f"Unknown scoring method: {method}. " "Use 'counting', 'bradley_terry', or 'plackett_luce'." ) def write_scores( scores: Dict[str, Dict[str, Any]], output_path: str, ) -> None: """Write scores to a TSV file. Output columns: item_id, text, score, best_count, worst_count, appearances, rank """ # Sort by score descending sorted_items = sorted(scores.items(), key=lambda x: x[1]["score"], reverse=True) os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True) with open(output_path, "w", newline="") as f: writer = csv.writer(f, delimiter="\t") writer.writerow( ["item_id", "text", "score", "best_count", "worst_count", "appearances", "rank"] ) for rank, (item_id, data) in enumerate(sorted_items, 1): writer.writerow([ item_id, data.get("text", ""), f"{data['score']:.6f}", data.get("best_count", ""), data.get("worst_count", ""), data.get("appearances", ""), rank, ]) logger.info(f"Wrote BWS scores to {output_path}") def collect_annotations_from_output( output_dir: str, bws_schema_name: str, config: dict ) -> List[Dict[str, Any]]: """Collect BWS annotations from Potato's output directory. Reads annotation files and reconstructs BWS annotation records. """ annotations = [] pool_items_by_tuple = {} # Get pool items from config bws_pool = config.get("_bws_pool_items", []) id_key = config["item_properties"]["id_key"] # We need to read the saved annotations from the output dir # Potato saves annotations as {output_dir}/{annotator}.jsonl if not os.path.isdir(output_dir): logger.warning(f"Output directory not found: {output_dir}") return annotations for fname in os.listdir(output_dir): if not fname.endswith(".jsonl"): continue annotator = fname.replace(".jsonl", "") fpath = os.path.join(output_dir, fname) with open(fpath, "r") as f: for line in f: line = line.strip() if not line: continue try: record = json.loads(line) except json.JSONDecodeError: continue instance_id = record.get("id") ann_data = record.get("annotation", {}) # Look for BWS schema annotations best_val = None worst_val = None for schema_name, schema_ann in ann_data.items(): if schema_name == bws_schema_name: best_val = schema_ann.get("best") worst_val = schema_ann.get("worst") break if not best_val or not worst_val: continue # Get BWS items from the instance data bws_items = record.get("_bws_items", []) annotations.append({ "instance_id": instance_id, "bws_items": bws_items, "best": best_val, "worst": worst_val, "annotator": annotator, }) return annotations def main(): """CLI entry point for BWS scoring.""" parser = argparse.ArgumentParser( description="Compute BWS scores from Potato annotation output" ) parser.add_argument( "--config", required=True, help="Path to Potato config YAML file" ) parser.add_argument( "--method", default="counting", choices=["counting", "bradley_terry", "plackett_luce"], help="Scoring method (default: counting)", ) parser.add_argument( "--output", default=None, help="Output TSV file path (default: {output_dir}/bws_scores.tsv)", ) args = parser.parse_args() logging.basicConfig(level=logging.INFO) # Load config import yaml with open(args.config, "r") as f: config = yaml.safe_load(f) output_dir = config.get("output_annotation_dir", "annotation_output") id_key = config["item_properties"]["id_key"] text_key = config["item_properties"]["text_key"] # Find BWS schema name bws_schema_name = None for scheme in config.get("annotation_schemes", []): if scheme.get("annotation_type") == "bws": bws_schema_name = scheme["name"] break if not bws_schema_name: print("Error: No BWS annotation scheme found in config", file=sys.stderr) sys.exit(1) # Load pool items from data files pool_items = [] for data_file in config.get("data_files", []): if isinstance(data_file, dict): data_file = data_file.get("path") if not data_file: continue with open(data_file, "r") as f: if data_file.endswith(".json"): pool_items.extend(json.load(f)) else: for line in f: line = line.strip() if line: pool_items.append(json.loads(line)) # Collect annotations annotations = collect_annotations_from_output(output_dir, bws_schema_name, config) if not annotations: print("No BWS annotations found in output directory", file=sys.stderr) sys.exit(1) print(f"Found {len(annotations)} BWS annotations for {len(pool_items)} pool items") # Score scorer = BwsScorer(annotations, pool_items, id_key, text_key) scores = scorer.score(args.method) # Write output output_path = args.output or os.path.join(output_dir, "bws_scores.tsv") write_scores(scores, output_path) print(f"Scores written to {output_path}") if __name__ == "__main__": main()