readctrl / code /combine_docid_labels.py
shahidul034's picture
Add files using upload-large-folder tool
1db7196 verified
#!/usr/bin/env python3
import argparse
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Labels each doc_id is checked against: rows carrying any other label are
# counted as unexpected, and a doc_id lacking one of these is counted as
# incomplete in the merge statistics.
EXPECTED_LABELS = (
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
)
@dataclass
class MergeStats:
    """Counters gathered while merging per-label rows into one record per doc_id."""

    # Length of the input row list, including rows that are later skipped.
    total_rows: int = 0
    # Number of distinct doc_id groups that were formed.
    total_doc_ids: int = 0
    # Rows dropped because their "label" value was missing or falsy.
    missing_label_rows: int = 0
    # Rows whose label is not one of EXPECTED_LABELS (they are still merged).
    unexpected_labels: int = 0
    # doc_ids for which at least one of EXPECTED_LABELS never appeared.
    doc_ids_missing_some_labels: int = 0
    # doc_ids whose rows disagree on the (stripped, non-empty) "fulltext".
    doc_ids_fulltext_mismatch: int = 0
    # doc_ids whose rows disagree on the (stripped, non-empty) "summary".
    doc_ids_summary_mismatch: int = 0
    # doc_ids whose rows disagree on the non-empty "fulltext_subclaims".
    doc_ids_fulltext_subclaims_mismatch: int = 0
    # doc_ids whose rows disagree on the non-empty "summary_subclaims".
    doc_ids_summary_subclaims_mismatch: int = 0
def _pick_first_non_empty(values: List[Optional[str]]) -> Optional[str]:
for value in values:
if isinstance(value, str) and value.strip():
return value
for value in values:
if value is not None:
return value
return None
def _normalize_text(value: Any) -> Optional[str]:
if value is None:
return None
if not isinstance(value, str):
return str(value)
return value
def _normalize_string_list(value: Any) -> Optional[Tuple[str, ...]]:
if value is None:
return None
if not isinstance(value, list):
return (str(value),)
normalized: List[str] = []
for item in value:
if item is None:
continue
if isinstance(item, str):
normalized.append(item.strip())
else:
normalized.append(str(item).strip())
return tuple(normalized)
def combine_by_doc_id(
    rows: List[Dict[str, Any]],
    keep_all_fields_per_label: bool = True,
) -> Tuple[List[Dict[str, Any]], MergeStats]:
    """Merge per-label rows into one record per ``doc_id``.

    Rows that are not dicts or that have no ``doc_id`` are ignored. Remaining
    rows are grouped by ``int(doc_id)`` and emitted in ascending doc_id order.
    Within a group, rows with a missing/falsy ``label`` are counted and
    skipped; rows with a label outside ``EXPECTED_LABELS`` are counted but
    still merged. If a label occurs more than once, the last row wins.

    Args:
        rows: Input rows, each expected to carry ``doc_id`` and ``label``.
        keep_all_fields_per_label: When true, each label payload holds every
            row field except the shared ones (``doc_id``, ``label``,
            ``fulltext``, ``summary``, ``fulltext_subclaims``,
            ``summary_subclaims``). When false, it holds only
            ``diff_label_texts`` and ``diff_label_subclaims``.

    Returns:
        A ``(combined, stats)`` pair: the merged records and the counters
        collected along the way.
    """
    # Fields shared within a doc_id; they are emitted once at the top level.
    shared_fields = (
        "doc_id",
        "label",
        "fulltext",
        "summary",
        "fulltext_subclaims",
        "summary_subclaims",
    )

    stats = MergeStats(total_rows=len(rows))

    rows_by_doc: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
    for entry in rows:
        if not isinstance(entry, dict):
            continue
        raw_id = entry.get("doc_id")
        if raw_id is not None:
            rows_by_doc[int(raw_id)].append(entry)
    stats.total_doc_ids = len(rows_by_doc)

    merged: List[Dict[str, Any]] = []
    for doc_id in sorted(rows_by_doc):
        per_label: Dict[str, Dict[str, Any]] = {}
        seen_fulltexts: List[Optional[str]] = []
        seen_summaries: List[Optional[str]] = []
        seen_fulltext_claims: List[Optional[Tuple[str, ...]]] = []
        seen_summary_claims: List[Optional[Tuple[str, ...]]] = []

        for entry in rows_by_doc[doc_id]:
            label = entry.get("label")
            if not label:
                stats.missing_label_rows += 1
                continue
            if label not in EXPECTED_LABELS:
                # Counted, but the row is still merged below.
                stats.unexpected_labels += 1

            seen_fulltexts.append(_normalize_text(entry.get("fulltext")))
            seen_summaries.append(_normalize_text(entry.get("summary")))
            seen_fulltext_claims.append(_normalize_string_list(entry.get("fulltext_subclaims")))
            seen_summary_claims.append(_normalize_string_list(entry.get("summary_subclaims")))

            payload: Dict[str, Any]
            if keep_all_fields_per_label:
                payload = {
                    key: val for key, val in entry.items() if key not in shared_fields
                }
            else:
                payload = {
                    "diff_label_texts": entry.get("diff_label_texts"),
                    "diff_label_subclaims": entry.get("diff_label_subclaims"),
                }
            per_label[str(label)] = payload

        # The first non-empty subclaims tuple (if any) represents the doc_id.
        first_fulltext_claims = next((c for c in seen_fulltext_claims if c), None)
        first_summary_claims = next((c for c in seen_summary_claims if c), None)
        chosen_fulltext_subclaims: Optional[List[str]] = (
            None if first_fulltext_claims is None else list(first_fulltext_claims)
        )
        chosen_summary_subclaims: Optional[List[str]] = (
            None if first_summary_claims is None else list(first_summary_claims)
        )

        # Mismatch detection compares stripped, non-empty values only.
        unique_fulltexts = {
            text.strip() for text in seen_fulltexts if isinstance(text, str) and text.strip()
        }
        unique_summaries = {
            text.strip() for text in seen_summaries if isinstance(text, str) and text.strip()
        }
        unique_fulltext_claims = {claims for claims in seen_fulltext_claims if claims}
        unique_summary_claims = {claims for claims in seen_summary_claims if claims}

        if len(unique_fulltexts) > 1:
            stats.doc_ids_fulltext_mismatch += 1
        if len(unique_summaries) > 1:
            stats.doc_ids_summary_mismatch += 1
        if len(unique_fulltext_claims) > 1:
            stats.doc_ids_fulltext_subclaims_mismatch += 1
        if len(unique_summary_claims) > 1:
            stats.doc_ids_summary_subclaims_mismatch += 1

        if not all(expected in per_label for expected in EXPECTED_LABELS):
            stats.doc_ids_missing_some_labels += 1

        merged.append(
            {
                "doc_id": doc_id,
                "fulltext": _pick_first_non_empty(seen_fulltexts),
                "fulltext_subclaims": chosen_fulltext_subclaims,
                "summary": _pick_first_non_empty(seen_summaries),
                "summary_subclaims": chosen_summary_subclaims,
                "labels": per_label,
            }
        )

    return merged, stats
def main() -> None:
    """Command-line entry point.

    Reads a JSON array of per-label rows from ``--input``, merges them into
    one object per doc_id with ``combine_by_doc_id``, writes the result as
    indented UTF-8 JSON to ``--output`` (default: next to the input, named
    ``<input stem>_by_docid.json``), and prints the merge statistics.

    Raises:
        SystemExit: If the input JSON is not a list (argparse also exits on
            invalid command-line arguments).
    """
    # Function-scope import: the module header only imports `dataclass`.
    from dataclasses import asdict

    parser = argparse.ArgumentParser(
        description=(
            "Combine per-label rows into a single object per doc_id. "
            "Input is a JSON array with repeated doc_id for different labels."
        )
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of rows)",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Path to output JSON file. Default: same folder with *_by_docid.json suffix",
    )
    parser.add_argument(
        "--minimal",
        action="store_true",
        # The per-label payload in minimal mode holds only these two keys;
        # fulltext/summary and their subclaims are always emitted once per
        # doc_id at the top level, never per label.
        help="Only keep diff_label_texts/diff_label_subclaims per label.",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    if args.output:
        output_path = Path(args.output)
    else:
        output_path = input_path.with_name(input_path.stem + "_by_docid.json")

    rows = json.loads(input_path.read_text(encoding="utf-8"))
    if not isinstance(rows, list):
        raise SystemExit("Input JSON must be a list")

    combined, stats = combine_by_doc_id(rows, keep_all_fields_per_label=not args.minimal)

    output_path.write_text(
        json.dumps(combined, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print("Wrote:", str(output_path))
    # asdict() reports every MergeStats field in declaration order, so the
    # printed stats cannot drift from the dataclass definition.
    print("Stats:", json.dumps(asdict(stats), indent=2))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()