File size: 7,729 Bytes
fe3046d
 
 
 
 
 
 
 
 
 
 
 
 
406fdd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe3046d
 
 
406fdd9
fe3046d
406fdd9
 
 
 
 
fe3046d
406fdd9
 
 
fe3046d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a2c6b2
 
 
 
 
 
 
 
 
 
 
fe3046d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a2c6b2
fe3046d
 
 
 
 
 
 
 
 
 
 
 
 
 
5a2c6b2
fe3046d
 
406fdd9
fe3046d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import csv
import json
from pathlib import Path
from typing import Any


# Filesystem layout, all derived from this file's resolved location:
# SPACE_ROOT is the parent of the directory containing this file, and ROOT
# is one level above SPACE_ROOT.
ROOT = Path(__file__).resolve().parents[2]
SPACE_ROOT = Path(__file__).resolve().parents[1]
# Inputs: the sample directory under ROOT, its CSV index (read by
# load_rows) and the folder of per-case JSON files (read by build_task).
SAMPLE_DIR = ROOT / "rca_validation_sample"
SAMPLE_INDEX = SAMPLE_DIR / "sample_index.csv"
SAMPLED_CASES = SAMPLE_DIR / "sampled_cases"
# Output directory for the generated task JSON files (created by main).
OUT_DIR = SPACE_ROOT / "data"

# Root-cause choices embedded in every task, one tuple per choice:
# (major category, subgroup, leaf label). An empty subgroup means the
# choice is grouped under the major category alone. Order matters:
# root_cause_options() compares each entry's group with the previous one
# to flag where a new group starts.
ROOT_CAUSE_OPTIONS = [
    ("INPUT ERROR", "", "Missing fact"),
    ("INPUT ERROR", "", "Incorrect fact"),
    ("INPUT ERROR", "", "Fabricated fact in case record"),
    ("MODEL ERROR", "Facts", "Fabricated fact in reasoning"),
    ("MODEL ERROR", "Facts", "Fact-weighing error"),
    ("MODEL ERROR", "Facts", "Fact omission in reasoning"),
    ("MODEL ERROR", "Issues", "Issue omission"),
    ("MODEL ERROR", "Issues", "Spurious issue"),
    ("MODEL ERROR", "Issues", "Issue misframing"),
    ("MODEL ERROR", "Issues", "Issue misprioritisation"),
    ("MODEL ERROR", "Rules", "Wrong rule source"),
    ("MODEL ERROR", "Rules", "Using outdated / overruled rule"),
    ("MODEL ERROR", "Rules", "Rule misinterpretation"),
    ("MODEL ERROR", "Rules", "Rule misapplication"),
    ("MODEL ERROR", "Rules", "Wrong legal test / threshold"),
    ("MODEL ERROR", "Rules", "Missed exception / qualification"),
    ("MODEL ERROR", "Rules", "Limitation / time-bar error"),
    ("MODEL ERROR", "Rules", "Precedent context mismatch"),
    ("MODEL ERROR", "Rules", "Rule interaction error"),
    ("MODEL ERROR", "Analysis", "Burden misallocation"),
    ("MODEL ERROR", "Analysis", "Neglected rule, fact, or counter-argument"),
    ("MODEL ERROR", "Analysis", "Logical fallacy / leap in logic"),
    ("MODEL ERROR", "Analysis", "Other reasoning error"),
    ("MODEL ERROR", "Conclusion", "Conclusion does not follow from analysis"),
    ("MODEL ERROR", "Conclusion", "Discretionary original judgment"),
]


def root_cause_options() -> list[dict[str, str]]:
    """Render ROOT_CAUSE_OPTIONS as a list of ``{"value", "html"}`` dicts.

    ``value`` is the leaf label. ``html`` is a ``<div>`` holding the group
    heading and the leaf label; the div gets the extra CSS class
    ``root-option-new-group`` whenever its group differs from the group of
    the preceding entry.
    """
    rendered: list[dict[str, str]] = []
    last_heading = None
    for major, subgroup, leaf in ROOT_CAUSE_OPTIONS:
        # Entries without a subgroup are headed by the major category alone.
        heading = major if not subgroup else f"{major}: {subgroup}"
        if heading == last_heading:
            css = "root-option"
        else:
            css = "root-option root-option-new-group"
        last_heading = heading
        markup = "".join(
            [
                f"<div class='{css}'>",
                f"<span class='root-group'>{heading}</span>",
                f"<span class='root-leaf'>{leaf}</span>",
                "</div>",
            ]
        )
        rendered.append({"value": leaf, "html": markup})
    return rendered


def safe_get(obj: dict[str, Any], path: str, default: Any = "") -> Any:
    """Look up a dot-separated key path in nested dicts.

    Walks ``obj`` one path segment at a time. As soon as the current node
    is not a dict, or does not contain the next segment as a key,
    ``default`` is returned. Otherwise the value found at the end of the
    path is returned unchanged (including falsy values such as ``None``).
    """
    node: Any = obj
    for key in path.split("."):
        if isinstance(node, dict) and key in node:
            node = node[key]
        else:
            return default
    return node


def _join_field(value: Any) -> str:
    """Render one metadata field as a comma-separated string.

    A list (or tuple) is joined with ``", "`` after converting each item
    to ``str``; ``None`` items are skipped. A bare scalar is returned as
    ``str(value)`` instead of being iterated, so a plain string is not
    split into its characters. ``None`` yields ``""``.
    """
    if value is None:
        return ""
    if isinstance(value, (list, tuple)):
        return ", ".join(str(item) for item in value if item is not None)
    return str(value)


def _case_identity(case: dict[str, Any]) -> tuple[dict[str, Any], str, str]:
    """Return ``(metadata, case_no, parties)`` extracted from ``case``.

    ``metadata`` is ``case["metadata"]`` when that is a dict, else ``{}``.
    ``case_no`` and the party names come from the first entry of
    ``metadata["details"]`` when that entry is a dict; otherwise they are
    empty. ``parties`` is ``"<petitioner> v. <respondent>"`` when at least
    one side is non-empty, else ``""``.
    """
    metadata = case.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}

    details = metadata.get("details")
    if isinstance(details, list) and details and isinstance(details[0], dict):
        first_details = details[0]
    else:
        first_details = {}

    case_no = _join_field(first_details.get("case_no"))
    petitioner = _join_field(first_details.get("petitioner"))
    respondent = _join_field(first_details.get("respondent"))
    parties = f"{petitioner} v. {respondent}" if petitioner or respondent else ""
    return metadata, case_no, parties


def case_title(case: dict[str, Any], fallback: str) -> str:
    """Build the full display title for a case.

    Joins court, case number(s), parties and judgment date with ``" | "``,
    omitting empty parts. Returns ``fallback`` when every part is empty.
    """
    metadata, case_no, parties = _case_identity(case)
    bits = [
        metadata.get("court", ""),
        case_no,
        parties,
        metadata.get("judgment_date", ""),
    ]
    return " | ".join(str(bit) for bit in bits if bit) or fallback


def short_case_label(case: dict[str, Any], fallback: str) -> str:
    """Build the short label for a case: case number(s) and parties only.

    The non-empty parts are joined with ``" | "``. Returns ``fallback``
    when both parts are empty.
    """
    _, case_no, parties = _case_identity(case)
    return " | ".join(bit for bit in [case_no, parties] if bit) or fallback


def firac_text(firac: Any) -> str:
    """Flatten a FIRAC dict into one text block.

    For each of the keys facts, issues, rules, analysis and conclusion
    (in that order) whose value is truthy, emits the upper-cased key on
    one line followed by the value. Sections are separated by a blank
    line. Anything that is not a dict yields ``""``.
    """
    if not isinstance(firac, dict):
        return ""
    section_names = ("facts", "issues", "rules", "analysis", "conclusion")
    sections = ((name, firac.get(name, "")) for name in section_names)
    return "\n\n".join(f"{name.upper()}\n{body}" for name, body in sections if body)


def machine_rca(row: dict[str, str]) -> str:
    """Format the error-analysis columns of an index row as labelled lines.

    Produces one ``"Label: value"`` line per column, joined with newlines.
    Missing columns render as an empty value, except ``include_facts``,
    which falls back to ``"[]"``.
    """
    labelled_columns = (
        ("Stage", "error_stage"),
        ("Major error", "major_error_category"),
        ("Minor error", "minor_error_category"),
        ("Summary", "one_sentence_summary"),
        ("Evidence", "evidence_snippet"),
        ("Recommended fix", "recommended_fix"),
    )
    lines = [f"{label}: {row.get(column, '')}" for label, column in labelled_columns]
    # The only column with a non-empty default, so it is handled separately.
    lines.append(f"Include facts: {row.get('include_facts', '[]')}")
    return "\n".join(lines)


def load_rows() -> list[dict[str, str]]:
    """Read the sample index CSV and return every row as a dict.

    The file at SAMPLE_INDEX is opened as UTF-8 with ``newline=""`` and
    parsed with ``csv.DictReader``, so the header row supplies the keys.
    """
    with SAMPLE_INDEX.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        # Materialise inside the ``with`` so the file is still open.
        rows = list(reader)
    return rows


def build_task(row: dict[str, str]) -> dict[str, Any]:
    """Assemble one annotation task from a sample-index row.

    Loads the case JSON named by ``row["copied_json_path"]`` from
    SAMPLED_CASES and returns a dict with two sections: ``data`` (the
    fields shown for the task) and ``meta`` (identifiers and the machine
    error labels copied from the row).

    Raises:
        KeyError: if ``row`` lacks ``sample_id`` or ``copied_json_path``.
        OSError / json.JSONDecodeError: if the case file cannot be read
            or parsed.
    """
    sample_id = row["sample_id"]
    # Only the file name of the recorded path is used; the file itself is
    # always looked up inside SAMPLED_CASES.
    filename = Path(row["copied_json_path"]).name
    path = SAMPLED_CASES / filename
    with path.open("r", encoding="utf-8") as f:
        case = json.load(f)

    # Prefer the extracted conclusion; fall back to the rr_based one only
    # when the extracted value is missing or literally "UNKNOWN".
    extracted_conclusion = safe_get(case, "extracted.one_word_conclusion", "UNKNOWN")
    rr_conclusion = safe_get(case, "rr_based.one_word_conclusion", "UNKNOWN")
    reference_outcome = rr_conclusion if extracted_conclusion == "UNKNOWN" else extracted_conclusion

    # A JSON ``"metadata": null`` would otherwise break the ``.get`` calls
    # in the ``meta`` section below.
    metadata = case.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    case_details = f"Case: {short_case_label(case, sample_id)}"

    # Use the judge-role text when present, else the flattened predicted FIRAC.
    generated_judgment = safe_get(case, "predicted.roles.judge", "")
    if not generated_judgment:
        generated_judgment = firac_text(safe_get(case, "predicted.firac", {}))

    # Use the full judgment text when present, else the flattened rr_based FIRAC.
    reference_judgment = case.get("judgment_text") or firac_text(safe_get(case, "rr_based.firac", {}))

    return {
        "data": {
            "case_title": case_title(case, filename),
            "case_details": case_details,
            "generated_outcome": safe_get(case, "predicted.one_word_conclusion", row.get("predicted_conclusion", "")),
            "reference_outcome": reference_outcome,
            # The case record is the extracted facts section only.
            "case_record": safe_get(case, "extracted.firac.facts", ""),
            "generated_judgment": generated_judgment,
            "reference_judgment": reference_judgment,
            "root_cause_options": root_cause_options(),
        },
        "meta": {
            "sample_id": sample_id,
            "case_id": row.get("case_id", ""),
            "experiment": row.get("experiment", ""),
            "model": row.get("model", ""),
            "source_file": filename,
            "court": metadata.get("court", ""),
            "judgment_date": metadata.get("judgment_date", ""),
            "machine_error_stage": row.get("error_stage", ""),
            "machine_major_error_category": row.get("major_error_category", ""),
            "machine_minor_error_category": row.get("minor_error_category", ""),
        },
    }


def _write_tasks(path: Path, tasks: list[dict[str, Any]]) -> None:
    """Serialise ``tasks`` to ``path`` as indented, non-ASCII-preserving UTF-8 JSON."""
    path.write_text(
        json.dumps(tasks, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


def main(num_annotators: int = 3, tasks_per_annotator: int = 10) -> None:
    """Build all tasks and write them to OUT_DIR.

    Writes ``tasks_all.json`` with every task, then one
    ``tasks_annotator_<k>.json`` per annotator holding consecutive,
    non-overlapping slices of ``tasks_per_annotator`` tasks each. A slice
    may be shorter (or empty) when there are not enough tasks.

    Args:
        num_annotators: number of per-annotator files to write.
        tasks_per_annotator: maximum number of tasks in each of them.
    """
    import sys  # local import: only needed for the warning below

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    tasks = [build_task(row) for row in load_rows()]

    _write_tasks(OUT_DIR / "tasks_all.json", tasks)

    for idx in range(num_annotators):
        start = idx * tasks_per_annotator
        chunk = tasks[start : start + tasks_per_annotator]
        _write_tasks(OUT_DIR / f"tasks_annotator_{idx + 1}.json", chunk)

    # Tasks past the last slice appear only in tasks_all.json; say so
    # instead of dropping them from the annotator files silently.
    assigned = num_annotators * tasks_per_annotator
    if len(tasks) > assigned:
        print(
            f"Warning: {len(tasks) - assigned} tasks are not in any annotator file",
            file=sys.stderr,
        )

    print(f"Wrote {len(tasks)} tasks to {OUT_DIR}")


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()