| |
|
|
| import argparse |
| import json |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
|
|
# The three health-literacy tiers; each doc_id is expected to contribute
# exactly one row per label below.
EXPECTED_LABELS = (
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
)
|
|
|
|
@dataclass
class MergeStats:
    """Counters describing how per-label rows merged into per-doc_id records."""

    total_rows: int = 0  # raw input rows received (including non-dict rows)
    total_doc_ids: int = 0  # distinct doc_ids after grouping
    missing_label_rows: int = 0  # rows skipped because "label" was falsy/absent
    unexpected_labels: int = 0  # rows whose label is outside EXPECTED_LABELS
    doc_ids_missing_some_labels: int = 0  # doc_ids lacking >=1 expected label
    doc_ids_fulltext_mismatch: int = 0  # doc_ids whose rows disagree on fulltext
    doc_ids_summary_mismatch: int = 0  # doc_ids whose rows disagree on summary
    doc_ids_fulltext_subclaims_mismatch: int = 0  # disagreement on fulltext_subclaims
    doc_ids_summary_subclaims_mismatch: int = 0  # disagreement on summary_subclaims
|
|
|
|
| def _pick_first_non_empty(values: List[Optional[str]]) -> Optional[str]: |
| for value in values: |
| if isinstance(value, str) and value.strip(): |
| return value |
| for value in values: |
| if value is not None: |
| return value |
| return None |
|
|
|
|
| def _normalize_text(value: Any) -> Optional[str]: |
| if value is None: |
| return None |
| if not isinstance(value, str): |
| return str(value) |
| return value |
|
|
|
|
| def _normalize_string_list(value: Any) -> Optional[Tuple[str, ...]]: |
| if value is None: |
| return None |
| if not isinstance(value, list): |
| return (str(value),) |
| normalized: List[str] = [] |
| for item in value: |
| if item is None: |
| continue |
| if isinstance(item, str): |
| normalized.append(item.strip()) |
| else: |
| normalized.append(str(item).strip()) |
| return tuple(normalized) |
|
|
|
|
def combine_by_doc_id(
    rows: List[Dict[str, Any]],
    keep_all_fields_per_label: bool = True,
) -> Tuple[List[Dict[str, Any]], MergeStats]:
    """Combine per-label rows into a single object per doc_id.

    Input rows repeat the same ``doc_id`` once per health-literacy label.
    Non-dict rows and rows without a ``doc_id`` are dropped. For each
    doc_id, the shared fields (fulltext/summary and their subclaims) are
    taken from the first row providing a non-empty value, and per-label
    leftovers are nested under ``labels``.

    Args:
        rows: Raw row dicts (one per doc_id/label pair).
        keep_all_fields_per_label: If True, keep every per-row field
            (minus the shared ones) under its label; otherwise keep only
            ``diff_label_texts`` / ``diff_label_subclaims``.

    Returns:
        Tuple of (records sorted by doc_id, MergeStats consistency counters).
    """
    stats = MergeStats(total_rows=len(rows))

    # Bucket rows by integer-coerced doc_id.
    # NOTE(review): int(doc_id) raises ValueError on non-numeric ids --
    # assumes ids are ints or numeric strings; confirm against producers.
    grouped: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
    for row in rows:
        if not isinstance(row, dict):
            continue
        doc_id = row.get("doc_id")
        if doc_id is None:
            continue
        grouped[int(doc_id)].append(row)

    stats.total_doc_ids = len(grouped)

    combined = [
        _combine_bucket(doc_id, grouped[doc_id], keep_all_fields_per_label, stats)
        for doc_id in sorted(grouped)
    ]
    return combined, stats


def _combine_bucket(
    doc_id: int,
    bucket: List[Dict[str, Any]],
    keep_all_fields_per_label: bool,
    stats: MergeStats,
) -> Dict[str, Any]:
    """Merge all rows sharing one doc_id into a single record, updating *stats*."""
    labels_map: Dict[str, Dict[str, Any]] = {}
    fulltexts: List[Optional[str]] = []
    summaries: List[Optional[str]] = []
    fulltext_subclaims_sets: List[Optional[Tuple[str, ...]]] = []
    summary_subclaims_sets: List[Optional[Tuple[str, ...]]] = []

    for row in bucket:
        label = row.get("label")
        if not label:
            # Unlabeled rows contribute nothing -- not even their texts.
            stats.missing_label_rows += 1
            continue
        if label not in EXPECTED_LABELS:
            # Counted, but still merged under its (stringified) label.
            stats.unexpected_labels += 1

        fulltexts.append(_normalize_text(row.get("fulltext")))
        summaries.append(_normalize_text(row.get("summary")))
        fulltext_subclaims_sets.append(
            _normalize_string_list(row.get("fulltext_subclaims"))
        )
        summary_subclaims_sets.append(
            _normalize_string_list(row.get("summary_subclaims"))
        )

        # Duplicate labels within one doc_id silently overwrite (last wins).
        labels_map[str(label)] = _label_payload(row, keep_all_fields_per_label)

    _count_mismatches(
        stats, fulltexts, summaries, fulltext_subclaims_sets, summary_subclaims_sets
    )

    if any(lbl not in labels_map for lbl in EXPECTED_LABELS):
        stats.doc_ids_missing_some_labels += 1

    return {
        "doc_id": doc_id,
        "fulltext": _pick_first_non_empty(fulltexts),
        "fulltext_subclaims": _first_non_empty_list(fulltext_subclaims_sets),
        "summary": _pick_first_non_empty(summaries),
        "summary_subclaims": _first_non_empty_list(summary_subclaims_sets),
        "labels": labels_map,
    }


def _label_payload(row: Dict[str, Any], keep_all_fields_per_label: bool) -> Dict[str, Any]:
    """Build the per-label dict: all non-shared fields, or only the diff fields."""
    if keep_all_fields_per_label:
        shared = (
            "doc_id",
            "label",
            "fulltext",
            "summary",
            "fulltext_subclaims",
            "summary_subclaims",
        )
        return {k: v for k, v in row.items() if k not in shared}
    return {
        "diff_label_texts": row.get("diff_label_texts"),
        "diff_label_subclaims": row.get("diff_label_subclaims"),
    }


def _first_non_empty_list(
    candidates: List[Optional[Tuple[str, ...]]],
) -> Optional[List[str]]:
    """Return the first non-empty tuple converted to a list, else None."""
    for items in candidates:
        if items:
            return list(items)
    return None


def _count_mismatches(
    stats: MergeStats,
    fulltexts: List[Optional[str]],
    summaries: List[Optional[str]],
    fulltext_subclaims_sets: List[Optional[Tuple[str, ...]]],
    summary_subclaims_sets: List[Optional[Tuple[str, ...]]],
) -> None:
    """Bump per-doc mismatch counters when a doc's rows disagree on shared fields."""
    # Texts: compare stripped non-blank strings only.
    if len({t.strip() for t in fulltexts if isinstance(t, str) and t.strip()}) > 1:
        stats.doc_ids_fulltext_mismatch += 1
    if len({t.strip() for t in summaries if isinstance(t, str) and t.strip()}) > 1:
        stats.doc_ids_summary_mismatch += 1
    # Subclaims: tuples are hashable, so set-dedup works directly.
    if len({t for t in fulltext_subclaims_sets if t}) > 1:
        stats.doc_ids_fulltext_subclaims_mismatch += 1
    if len({t for t in summary_subclaims_sets if t}) > 1:
        stats.doc_ids_summary_subclaims_mismatch += 1
|
|
|
|
def main() -> None:
    """CLI entry point: read row list, merge by doc_id, write JSON, print stats."""
    parser = argparse.ArgumentParser(
        description=(
            "Combine per-label rows into a single object per doc_id. "
            "Input is a JSON array with repeated doc_id for different labels."
        )
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of rows)",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Path to output JSON file. Default: same folder with *_by_docid.json suffix",
    )
    parser.add_argument(
        "--minimal",
        action="store_true",
        help="Only keep diff_label_texts/diff_label_subclaims/fulltext_subclaims/summary_subclaims per label.",
    )
    opts = parser.parse_args()

    src = Path(opts.input)
    if opts.output:
        dst = Path(opts.output)
    else:
        dst = src.with_name(src.stem + "_by_docid.json")

    payload = json.loads(src.read_text(encoding="utf-8"))
    if not isinstance(payload, list):
        raise SystemExit("Input JSON must be a list")

    merged, stats = combine_by_doc_id(
        payload, keep_all_fields_per_label=not opts.minimal
    )

    dst.write_text(
        json.dumps(merged, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )

    print("Wrote:", str(dst))
    # vars() on a plain (non-slots) dataclass instance yields its fields in
    # declaration order, so the printed JSON matches the stats layout.
    print("Stats:", json.dumps(vars(stats), indent=2))
|
|
|
|
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|