#!/usr/bin/env python3
"""Combine per-label rows of a JSON array into one record per ``doc_id``.

The input is a JSON list of row objects in which the same ``doc_id`` appears
once per label. The output is a JSON list with one object per ``doc_id``; the
fields in ``_SHARED_KEYS`` are stored once at the top level and the remaining
per-row data is stored under ``labels[<label>]``.
"""

import argparse
import json
from collections import defaultdict
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

EXPECTED_LABELS = (
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
)

# Row keys that are hoisted to the top level of a combined record (or used for
# grouping) and therefore left out of the per-label payload.
_SHARED_KEYS = frozenset(
    (
        "doc_id",
        "label",
        "fulltext",
        "summary",
        "fulltext_subclaims",
        "summary_subclaims",
    )
)


@dataclass
class MergeStats:
    """Counters collected while merging rows in ``combine_by_doc_id``."""

    # Length of the input list, including rows that end up skipped.
    total_rows: int = 0
    # Number of distinct doc_ids that produced an output record.
    total_doc_ids: int = 0
    # Rows whose "label" value is missing or falsy.
    missing_label_rows: int = 0
    # Rows whose label is not in EXPECTED_LABELS (they are still merged).
    unexpected_labels: int = 0
    # doc_ids for which at least one of EXPECTED_LABELS is absent.
    doc_ids_missing_some_labels: int = 0
    # doc_ids whose rows disagree on the respective shared field.
    doc_ids_fulltext_mismatch: int = 0
    doc_ids_summary_mismatch: int = 0
    doc_ids_fulltext_subclaims_mismatch: int = 0
    doc_ids_summary_subclaims_mismatch: int = 0
    # Rows dropped before grouping: not a dict, no doc_id, or a doc_id that
    # cannot be converted to int.
    skipped_rows: int = 0


def _pick_first_non_empty(values: List[Optional[str]]) -> Optional[str]:
    """Return the first non-blank string, else the first non-None value, else None."""
    for value in values:
        if isinstance(value, str) and value.strip():
            return value
    # No string with visible content: fall back to whatever is not None.
    for value in values:
        if value is not None:
            return value
    return None


def _normalize_text(value: Any) -> Optional[str]:
    """Return ``value`` as a string, passing ``None`` through unchanged."""
    if value is None:
        return None
    if not isinstance(value, str):
        return str(value)
    return value


def _normalize_string_list(value: Any) -> Optional[Tuple[str, ...]]:
    """Return ``value`` as a tuple of stripped strings.

    ``None`` stays ``None``; a non-list value becomes a one-element tuple of
    its ``str()`` form (not stripped); ``None`` items inside a list are dropped.
    """
    if value is None:
        return None
    if not isinstance(value, list):
        return (str(value),)
    return tuple(
        (item if isinstance(item, str) else str(item)).strip()
        for item in value
        if item is not None
    )


def _coerce_doc_id(value: Any) -> Optional[int]:
    """Return ``int(value)``, or ``None`` when ``value`` is None or not convertible."""
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # e.g. "abc", a list, NaN, or Infinity
        return None


def _first_non_empty_list(
    candidates: List[Optional[Tuple[str, ...]]],
) -> Optional[List[str]]:
    """Return the first truthy tuple in ``candidates`` as a list, or None."""
    return next((list(items) for items in candidates if items), None)


def _distinct_stripped(texts: List[Optional[str]]) -> set:
    """Return the set of stripped, non-blank strings found in ``texts``."""
    return {t.strip() for t in texts if isinstance(t, str) and t.strip()}


def combine_by_doc_id(
    rows: List[Dict[str, Any]],
    keep_all_fields_per_label: bool = True,
) -> Tuple[List[Dict[str, Any]], MergeStats]:
    """Merge per-label rows into one record per ``doc_id``.

    Args:
        rows: Input rows. Entries that are not dicts, have no ``doc_id``, or
            have a ``doc_id`` that cannot be converted to ``int`` are skipped
            and counted in ``MergeStats.skipped_rows``.
        keep_all_fields_per_label: When true, each label payload holds every
            row key except those in ``_SHARED_KEYS``. When false, it holds
            only ``diff_label_texts`` and ``diff_label_subclaims``.

    Returns:
        A ``(combined, stats)`` pair. ``combined`` is sorted by ``doc_id``;
        each record has the keys ``doc_id``, ``fulltext``,
        ``fulltext_subclaims``, ``summary``, ``summary_subclaims`` and
        ``labels``.
    """
    stats = MergeStats(total_rows=len(rows))

    grouped: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
    for row in rows:
        doc_id = _coerce_doc_id(row.get("doc_id")) if isinstance(row, dict) else None
        if doc_id is None:
            stats.skipped_rows += 1
            continue
        grouped[doc_id].append(row)

    stats.total_doc_ids = len(grouped)

    combined: List[Dict[str, Any]] = []
    for doc_id in sorted(grouped):
        labels_map: Dict[str, Dict[str, Any]] = {}
        fulltexts: List[Optional[str]] = []
        summaries: List[Optional[str]] = []
        fulltext_subclaims_sets: List[Optional[Tuple[str, ...]]] = []
        summary_subclaims_sets: List[Optional[Tuple[str, ...]]] = []

        for row in grouped[doc_id]:
            label = row.get("label")
            if not label:
                # Unlabelled rows contribute nothing, not even shared fields.
                stats.missing_label_rows += 1
                continue
            if label not in EXPECTED_LABELS:
                # Counted, but the row is still merged under its own label.
                stats.unexpected_labels += 1

            fulltexts.append(_normalize_text(row.get("fulltext")))
            summaries.append(_normalize_text(row.get("summary")))
            fulltext_subclaims_sets.append(
                _normalize_string_list(row.get("fulltext_subclaims"))
            )
            summary_subclaims_sets.append(
                _normalize_string_list(row.get("summary_subclaims"))
            )

            label_payload: Dict[str, Any]
            if keep_all_fields_per_label:
                # Shared within a doc_id; keep them only once at top-level.
                label_payload = {
                    key: value
                    for key, value in row.items()
                    if key not in _SHARED_KEYS
                }
            else:
                label_payload = {
                    "diff_label_texts": row.get("diff_label_texts"),
                    "diff_label_subclaims": row.get("diff_label_subclaims"),
                }
            # A later row with the same label replaces the earlier payload.
            labels_map[str(label)] = label_payload

        if len(_distinct_stripped(fulltexts)) > 1:
            stats.doc_ids_fulltext_mismatch += 1
        if len(_distinct_stripped(summaries)) > 1:
            stats.doc_ids_summary_mismatch += 1
        if len({t for t in fulltext_subclaims_sets if t}) > 1:
            stats.doc_ids_fulltext_subclaims_mismatch += 1
        if len({t for t in summary_subclaims_sets if t}) > 1:
            stats.doc_ids_summary_subclaims_mismatch += 1

        if any(lbl not in labels_map for lbl in EXPECTED_LABELS):
            stats.doc_ids_missing_some_labels += 1

        combined.append(
            {
                "doc_id": doc_id,
                "fulltext": _pick_first_non_empty(fulltexts),
                "fulltext_subclaims": _first_non_empty_list(fulltext_subclaims_sets),
                "summary": _pick_first_non_empty(summaries),
                "summary_subclaims": _first_non_empty_list(summary_subclaims_sets),
                "labels": labels_map,
            }
        )

    return combined, stats


def main() -> None:
    """Parse CLI arguments, merge the input file, write the result, print stats."""
    parser = argparse.ArgumentParser(
        description=(
            "Combine per-label rows into a single object per doc_id. "
            "Input is a JSON array with repeated doc_id for different labels."
        )
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of rows)",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Path to output JSON file. Default: same folder with *_by_docid.json suffix",
    )
    parser.add_argument(
        "--minimal",
        action="store_true",
        help=(
            "Only keep diff_label_texts/diff_label_subclaims per label "
            "(fulltext/summary and their subclaims stay at the top level)."
        ),
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    if args.output:
        output_path = Path(args.output)
    else:
        output_path = input_path.with_name(input_path.stem + "_by_docid.json")

    rows = json.loads(input_path.read_text(encoding="utf-8"))
    if not isinstance(rows, list):
        raise SystemExit("Input JSON must be a list")

    combined, stats = combine_by_doc_id(
        rows, keep_all_fields_per_label=not args.minimal
    )

    output_path.write_text(
        json.dumps(combined, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )

    print("Wrote:", str(output_path))
    # asdict() emits the fields in declaration order.
    print("Stats:", json.dumps(asdict(stats), indent=2))


if __name__ == "__main__":
    main()