| | |
| |
|
| | import argparse |
| | import json |
| | from collections import defaultdict |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Tuple |
| |
|
| |
|
# The health-literacy labels every doc_id is expected to carry, ordered from
# lowest to highest proficiency. Documents lacking any of them are counted in
# MergeStats.doc_ids_missing_some_labels; other labels are counted as unexpected.
EXPECTED_LABELS = tuple(
    f"{tier}_health_literacy" for tier in ("low", "intermediate", "proficient")
)
| |
|
| |
|
@dataclass
class MergeStats:
    """Counters collected while merging per-label rows into per-doc_id records.

    All counters start at zero. Declaration order matters: it is the positional
    argument order of the generated ``__init__`` and the key order of
    ``dataclasses.asdict``.
    """

    # --- input volume -------------------------------------------------------
    total_rows: int = 0  # length of the input list, including rows later skipped
    total_doc_ids: int = 0  # number of distinct doc_id groups
    # --- per-row label bookkeeping -----------------------------------------
    missing_label_rows: int = 0  # rows with a falsy "label"; they contribute nothing
    unexpected_labels: int = 0  # rows whose label is outside EXPECTED_LABELS (still kept)
    # --- per-doc_id consistency checks (each doc_id counted at most once) ---
    doc_ids_missing_some_labels: int = 0
    doc_ids_fulltext_mismatch: int = 0  # >1 distinct non-blank fulltext across rows
    doc_ids_summary_mismatch: int = 0  # >1 distinct non-blank summary across rows
    doc_ids_fulltext_subclaims_mismatch: int = 0
    doc_ids_summary_subclaims_mismatch: int = 0
| |
|
| |
|
| | def _pick_first_non_empty(values: List[Optional[str]]) -> Optional[str]: |
| | for value in values: |
| | if isinstance(value, str) and value.strip(): |
| | return value |
| | for value in values: |
| | if value is not None: |
| | return value |
| | return None |
| |
|
| |
|
| | def _normalize_text(value: Any) -> Optional[str]: |
| | if value is None: |
| | return None |
| | if not isinstance(value, str): |
| | return str(value) |
| | return value |
| |
|
| |
|
| | def _normalize_string_list(value: Any) -> Optional[Tuple[str, ...]]: |
| | if value is None: |
| | return None |
| | if not isinstance(value, list): |
| | return (str(value),) |
| | normalized: List[str] = [] |
| | for item in value: |
| | if item is None: |
| | continue |
| | if isinstance(item, str): |
| | normalized.append(item.strip()) |
| | else: |
| | normalized.append(str(item).strip()) |
| | return tuple(normalized) |
| |
|
| |
|
# Row keys never copied into a per-label payload: the grouping keys (doc_id,
# label) and the document-level fields, which are stored once per doc_id.
_DOC_LEVEL_FIELDS = frozenset(
    (
        "doc_id",
        "label",
        "fulltext",
        "summary",
        "fulltext_subclaims",
        "summary_subclaims",
    )
)


def _group_rows_by_doc_id(rows: List[Dict[str, Any]]) -> Dict[int, List[Dict[str, Any]]]:
    """Bucket rows by ``int(doc_id)``, skipping non-dict rows and rows whose doc_id is None."""
    grouped: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
    for row in rows:
        if not isinstance(row, dict):
            continue
        doc_id = row.get("doc_id")
        if doc_id is None:
            continue
        grouped[int(doc_id)].append(row)
    return grouped


def _build_label_payload(row: Dict[str, Any], keep_all_fields: bool) -> Dict[str, Any]:
    """Return the per-label part of *row*: all non-document fields, or only the diff fields."""
    if keep_all_fields:
        return {key: value for key, value in row.items() if key not in _DOC_LEVEL_FIELDS}
    return {
        "diff_label_texts": row.get("diff_label_texts"),
        "diff_label_subclaims": row.get("diff_label_subclaims"),
    }


def _first_non_empty_list(candidates: List[Optional[Tuple[str, ...]]]) -> Optional[List[str]]:
    """Return the first non-empty candidate as a list, or None if every candidate is None/empty."""
    for items in candidates:
        if items:
            return list(items)
    return None


def _count_distinct_texts(values: List[Optional[str]]) -> int:
    """Count distinct non-blank strings in *values*, comparing them after strip()."""
    return len({text.strip() for text in values if isinstance(text, str) and text.strip()})


def _count_distinct_lists(values: List[Optional[Tuple[str, ...]]]) -> int:
    """Count distinct non-empty tuples in *values* (None and empty tuples are ignored)."""
    return len({items for items in values if items})


def combine_by_doc_id(rows: List[Dict[str, Any]], keep_all_fields_per_label: bool = True) -> Tuple[List[Dict[str, Any]], MergeStats]:
    """Merge per-label rows into one record per doc_id.

    Args:
        rows: Input rows. Non-dict rows and rows whose ``doc_id`` is None are
            skipped (they still count toward ``total_rows``). Rows with a falsy
            ``label`` are counted in ``missing_label_rows`` and otherwise ignored.
        keep_all_fields_per_label: When True, each label entry keeps every row
            field except doc_id/label/fulltext/summary/*_subclaims. When False,
            it keeps only ``diff_label_texts`` and ``diff_label_subclaims``.

    Returns:
        ``(combined, stats)``. ``combined`` is sorted by integer doc_id, and each
        record has keys doc_id, fulltext, fulltext_subclaims, summary,
        summary_subclaims and labels (label -> payload).

    Raises:
        ValueError / TypeError: if a doc_id cannot be converted with ``int()``.
    """
    stats = MergeStats(total_rows=len(rows))
    grouped = _group_rows_by_doc_id(rows)
    stats.total_doc_ids = len(grouped)

    combined: List[Dict[str, Any]] = []
    for doc_id in sorted(grouped):
        labels_map: Dict[str, Dict[str, Any]] = {}
        fulltexts: List[Optional[str]] = []
        summaries: List[Optional[str]] = []
        fulltext_subclaims_sets: List[Optional[Tuple[str, ...]]] = []
        summary_subclaims_sets: List[Optional[Tuple[str, ...]]] = []

        for row in grouped[doc_id]:
            label = row.get("label")
            if not label:
                stats.missing_label_rows += 1
                continue
            if label not in EXPECTED_LABELS:
                # Unexpected labels are counted but still merged.
                stats.unexpected_labels += 1

            fulltexts.append(_normalize_text(row.get("fulltext")))
            summaries.append(_normalize_text(row.get("summary")))
            fulltext_subclaims_sets.append(_normalize_string_list(row.get("fulltext_subclaims")))
            summary_subclaims_sets.append(_normalize_string_list(row.get("summary_subclaims")))

            # NOTE: a label repeated within one doc_id silently replaces the
            # earlier payload (last row wins) and is not counted in stats.
            labels_map[str(label)] = _build_label_payload(row, keep_all_fields_per_label)

        # Rows for the same doc_id should agree on the shared fields; record
        # every doc_id where they do not.
        if _count_distinct_texts(fulltexts) > 1:
            stats.doc_ids_fulltext_mismatch += 1
        if _count_distinct_texts(summaries) > 1:
            stats.doc_ids_summary_mismatch += 1
        if _count_distinct_lists(fulltext_subclaims_sets) > 1:
            stats.doc_ids_fulltext_subclaims_mismatch += 1
        if _count_distinct_lists(summary_subclaims_sets) > 1:
            stats.doc_ids_summary_subclaims_mismatch += 1

        if any(lbl not in labels_map for lbl in EXPECTED_LABELS):
            stats.doc_ids_missing_some_labels += 1

        combined.append(
            {
                "doc_id": doc_id,
                "fulltext": _pick_first_non_empty(fulltexts),
                "fulltext_subclaims": _first_non_empty_list(fulltext_subclaims_sets),
                "summary": _pick_first_non_empty(summaries),
                "summary_subclaims": _first_non_empty_list(summary_subclaims_sets),
                "labels": labels_map,
            }
        )

    return combined, stats
| |
|
| |
|
def main() -> None:
    """Command-line entry point: merge a per-label JSON array into one object per doc_id.

    Reads ``--input`` (a JSON list of rows) and writes the merged list to
    ``--output``; the default output is ``<input stem>_by_docid.json`` next to
    the input. Then prints the output path and the merge statistics as JSON.

    Raises:
        SystemExit: if the input file cannot be read, is not valid JSON, or does
            not contain a JSON list.
    """
    # Local import keeps this entry point self-contained. asdict emits every
    # MergeStats field in declaration order, so the stats report cannot drift
    # out of sync with the dataclass.
    from dataclasses import asdict

    parser = argparse.ArgumentParser(
        description=(
            "Combine per-label rows into a single object per doc_id. "
            "Input is a JSON array with repeated doc_id for different labels."
        )
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of rows)",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Path to output JSON file. Default: same folder with *_by_docid.json suffix",
    )
    parser.add_argument(
        "--minimal",
        action="store_true",
        # Minimal per-label payloads hold only the two diff_label_* fields; the
        # fulltext/summary subclaims always live once at the doc_id level.
        help=(
            "Only keep diff_label_texts/diff_label_subclaims per label "
            "(fulltext_subclaims/summary_subclaims are always stored once per doc_id)."
        ),
    )

    args = parser.parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output) if args.output else input_path.with_name(input_path.stem + "_by_docid.json")

    try:
        rows = json.loads(input_path.read_text(encoding="utf-8"))
    except OSError as err:
        raise SystemExit(f"Cannot read input file {input_path}: {err}") from err
    except json.JSONDecodeError as err:
        raise SystemExit(f"Input file {input_path} is not valid JSON: {err}") from err
    if not isinstance(rows, list):
        raise SystemExit("Input JSON must be a list")

    combined, stats = combine_by_doc_id(rows, keep_all_fields_per_label=not args.minimal)

    output_path.write_text(
        json.dumps(combined, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )

    print("Wrote:", str(output_path))
    print("Stats:", json.dumps(asdict(stats), indent=2))
| |
|
| |
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module has no side effects.
    main()
| |
|