File size: 7,841 Bytes
1db7196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#!/usr/bin/env python3

import argparse
import json
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


# Labels every doc_id is checked against in combine_by_doc_id: a row whose
# label is not listed here is counted as unexpected, and a doc_id lacking any
# of these labels is counted as missing some labels.
EXPECTED_LABELS = (
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
)


@dataclass
class MergeStats:
    """Counters collected by combine_by_doc_id while merging rows per doc_id."""

    # Length of the input list (includes rows that are later skipped).
    total_rows: int = 0
    # Number of distinct doc_id values after int() coercion.
    total_doc_ids: int = 0
    # Rows skipped because their "label" was missing or falsy.
    missing_label_rows: int = 0
    # Rows whose label is not in EXPECTED_LABELS (they are still merged).
    unexpected_labels: int = 0
    # doc_ids that lack at least one of the EXPECTED_LABELS.
    doc_ids_missing_some_labels: int = 0
    # doc_ids whose rows carry more than one distinct non-blank (stripped) fulltext.
    doc_ids_fulltext_mismatch: int = 0
    # doc_ids whose rows carry more than one distinct non-blank (stripped) summary.
    doc_ids_summary_mismatch: int = 0
    # doc_ids whose rows carry more than one distinct non-empty fulltext_subclaims list.
    doc_ids_fulltext_subclaims_mismatch: int = 0
    # doc_ids whose rows carry more than one distinct non-empty summary_subclaims list.
    doc_ids_summary_subclaims_mismatch: int = 0


def _pick_first_non_empty(values: List[Optional[str]]) -> Optional[str]:
    for value in values:
        if isinstance(value, str) and value.strip():
            return value
    for value in values:
        if value is not None:
            return value
    return None


def _normalize_text(value: Any) -> Optional[str]:
    if value is None:
        return None
    if not isinstance(value, str):
        return str(value)
    return value


def _normalize_string_list(value: Any) -> Optional[Tuple[str, ...]]:
    if value is None:
        return None
    if not isinstance(value, list):
        return (str(value),)
    normalized: List[str] = []
    for item in value:
        if item is None:
            continue
        if isinstance(item, str):
            normalized.append(item.strip())
        else:
            normalized.append(str(item).strip())
    return tuple(normalized)


def combine_by_doc_id(rows: List[Dict[str, Any]], keep_all_fields_per_label: bool = True) -> Tuple[List[Dict[str, Any]], MergeStats]:
    """Merge per-label rows into one record per ``doc_id``.

    Rows that are not dicts or that have no ``doc_id`` are ignored. The
    remaining rows are grouped by ``int(doc_id)`` and emitted in ascending
    ``doc_id`` order. Rows with a missing/falsy ``label`` are counted and
    skipped; rows with a label outside ``EXPECTED_LABELS`` are counted but
    still merged. If a label repeats within a ``doc_id``, the last row wins.

    Args:
        rows: Input rows, one per (doc_id, label) pair.
        keep_all_fields_per_label: When true, each label keeps every row field
            except the shared ones hoisted to the top level; when false, each
            label keeps only ``diff_label_texts`` and ``diff_label_subclaims``.

    Returns:
        A ``(combined, stats)`` pair: the merged records and the counters
        gathered while merging.
    """
    # Fields stored once at the top level of a merged record, never per label.
    shared_fields = (
        "doc_id",
        "label",
        "fulltext",
        "summary",
        "fulltext_subclaims",
        "summary_subclaims",
    )
    minimal_fields = ("diff_label_texts", "diff_label_subclaims")

    stats = MergeStats(total_rows=len(rows))

    by_doc: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
    for entry in rows:
        raw_id = entry.get("doc_id") if isinstance(entry, dict) else None
        if raw_id is not None:
            by_doc[int(raw_id)].append(entry)
    stats.total_doc_ids = len(by_doc)

    merged: List[Dict[str, Any]] = []
    for doc_id in sorted(by_doc):
        per_label: Dict[str, Dict[str, Any]] = {}
        seen_fulltexts: List[Optional[str]] = []
        seen_summaries: List[Optional[str]] = []
        seen_fulltext_claims: List[Optional[Tuple[str, ...]]] = []
        seen_summary_claims: List[Optional[Tuple[str, ...]]] = []

        for entry in by_doc[doc_id]:
            label = entry.get("label")
            if not label:
                stats.missing_label_rows += 1
                continue
            if label not in EXPECTED_LABELS:
                stats.unexpected_labels += 1

            seen_fulltexts.append(_normalize_text(entry.get("fulltext")))
            seen_summaries.append(_normalize_text(entry.get("summary")))
            seen_fulltext_claims.append(_normalize_string_list(entry.get("fulltext_subclaims")))
            seen_summary_claims.append(_normalize_string_list(entry.get("summary_subclaims")))

            if keep_all_fields_per_label:
                payload: Dict[str, Any] = {
                    key: val for key, val in entry.items() if key not in shared_fields
                }
            else:
                payload = {key: entry.get(key) for key in minimal_fields}
            per_label[str(label)] = payload

        # Mismatch counters: more than one distinct non-empty value per doc_id.
        unique_fulltexts = {
            text.strip() for text in seen_fulltexts if isinstance(text, str) and text.strip()
        }
        if len(unique_fulltexts) > 1:
            stats.doc_ids_fulltext_mismatch += 1
        unique_summaries = {
            text.strip() for text in seen_summaries if isinstance(text, str) and text.strip()
        }
        if len(unique_summaries) > 1:
            stats.doc_ids_summary_mismatch += 1
        if len({claims for claims in seen_fulltext_claims if claims}) > 1:
            stats.doc_ids_fulltext_subclaims_mismatch += 1
        if len({claims for claims in seen_summary_claims if claims}) > 1:
            stats.doc_ids_summary_subclaims_mismatch += 1

        if not all(expected in per_label for expected in EXPECTED_LABELS):
            stats.doc_ids_missing_some_labels += 1

        merged.append(
            {
                "doc_id": doc_id,
                "fulltext": _pick_first_non_empty(seen_fulltexts),
                # First non-empty subclaim tuple wins; None when no row has one.
                "fulltext_subclaims": next(
                    (list(claims) for claims in seen_fulltext_claims if claims), None
                ),
                "summary": _pick_first_non_empty(seen_summaries),
                "summary_subclaims": next(
                    (list(claims) for claims in seen_summary_claims if claims), None
                ),
                "labels": per_label,
            }
        )

    return merged, stats


def main() -> None:
    """Command-line entry point.

    Reads a JSON array of per-label rows from ``--input``, merges them per
    ``doc_id`` with ``combine_by_doc_id``, writes the merged list to
    ``--output`` (default: ``<input stem>_by_docid.json`` next to the input),
    and prints the output path followed by the merge statistics.

    Raises:
        SystemExit: If the input file cannot be read, is not valid JSON, or
            its top-level value is not a list.
    """
    # Local import keeps this block self-contained; the module only imports
    # ``dataclass`` from ``dataclasses`` at the top of the file.
    from dataclasses import asdict

    parser = argparse.ArgumentParser(
        description=(
            "Combine per-label rows into a single object per doc_id. "
            "Input is a JSON array with repeated doc_id for different labels."
        )
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of rows)",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Path to output JSON file. Default: same folder with *_by_docid.json suffix",
    )
    parser.add_argument(
        "--minimal",
        action="store_true",
        # Only these two fields are kept per label in minimal mode; the
        # fulltext/summary subclaims are always stored once at the top level.
        help="Only keep diff_label_texts and diff_label_subclaims per label.",
    )

    args = parser.parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output) if args.output else input_path.with_name(input_path.stem + "_by_docid.json")

    # Report input problems the same way as the non-list case below, instead
    # of surfacing a raw traceback to the command-line user.
    try:
        rows = json.loads(input_path.read_text(encoding="utf-8"))
    except OSError as exc:
        raise SystemExit(f"Cannot read input file {input_path}: {exc}") from exc
    except ValueError as exc:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError).
        raise SystemExit(f"Input file {input_path} is not valid UTF-8 JSON: {exc}") from exc
    if not isinstance(rows, list):
        raise SystemExit("Input JSON must be a list")

    combined, stats = combine_by_doc_id(rows, keep_all_fields_per_label=not args.minimal)

    output_path.write_text(
        json.dumps(combined, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )

    print("Wrote:", str(output_path))
    # asdict() emits the dataclass fields in declaration order, so the report
    # always covers every counter without a hand-maintained copy of the names.
    print("Stats:", json.dumps(asdict(stats), indent=2))


if __name__ == "__main__":
    main()