import csv
import json
from pathlib import Path
from typing import Any
# Resolved absolute path of this script; every directory below is derived from it.
_SCRIPT_PATH = Path(__file__).resolve()

# parents[0] is the script's own directory, so SPACE_ROOT is one level above
# that directory and ROOT is one level above SPACE_ROOT.
SPACE_ROOT = _SCRIPT_PATH.parents[1]
ROOT = _SCRIPT_PATH.parents[2]

# Inputs: the sample index CSV and the directory of per-case JSON files.
SAMPLE_DIR = ROOT / "rca_validation_sample"
SAMPLE_INDEX = SAMPLE_DIR / "sample_index.csv"
SAMPLED_CASES = SAMPLE_DIR / "sampled_cases"

# Output directory for the generated task files.
OUT_DIR = SPACE_ROOT / "data"
# Root-cause taxonomy, grouped as (major category, subgroup, leaf labels).
# An empty subgroup means the leaves sit directly under the major category.
_ROOT_CAUSE_GROUPS = (
    (
        "INPUT ERROR",
        "",
        (
            "Missing fact",
            "Incorrect fact",
            "Fabricated fact in case record",
        ),
    ),
    (
        "MODEL ERROR",
        "Facts",
        (
            "Fabricated fact in reasoning",
            "Fact-weighing error",
            "Fact omission in reasoning",
        ),
    ),
    (
        "MODEL ERROR",
        "Issues",
        (
            "Issue omission",
            "Spurious issue",
            "Issue misframing",
            "Issue misprioritisation",
        ),
    ),
    (
        "MODEL ERROR",
        "Rules",
        (
            "Wrong rule source",
            "Using outdated / overruled rule",
            "Rule misinterpretation",
            "Rule misapplication",
            "Wrong legal test / threshold",
            "Missed exception / qualification",
            "Limitation / time-bar error",
            "Precedent context mismatch",
            "Rule interaction error",
        ),
    ),
    (
        "MODEL ERROR",
        "Analysis",
        (
            "Burden misallocation",
            "Neglected rule, fact, or counter-argument",
            "Logical fallacy / leap in logic",
            "Other reasoning error",
        ),
    ),
    (
        "MODEL ERROR",
        "Conclusion",
        (
            "Conclusion does not follow from analysis",
            "Discretionary original judgment",
        ),
    ),
)

# Flat list of (major, subgroup, leaf) triples in display order.
ROOT_CAUSE_OPTIONS = [
    (major, subgroup, leaf)
    for major, subgroup, leaves in _ROOT_CAUSE_GROUPS
    for leaf in leaves
]
def root_cause_options() -> list[dict[str, str]]:
    """Render every entry of ROOT_CAUSE_OPTIONS as an HTML option.

    Returns:
        One dict per taxonomy entry, in table order, with two keys:
        ``"value"`` (the leaf label) and ``"html"`` (markup showing the
        group label and the leaf). The first entry of each run of identical
        group labels also carries the ``root-option-new-group`` CSS class.
    """
    rendered: list[dict[str, str]] = []
    last_label = None
    for major, subgroup, leaf in ROOT_CAUSE_OPTIONS:
        # Entries without a subgroup are labelled by the major category alone.
        label = major if not subgroup else f"{major}: {subgroup}"
        if label == last_label:
            css = "root-option"
        else:
            css = "root-option root-option-new-group"
        last_label = label
        markup = "".join(
            [
                f"<div class='{css}'>",
                f"<span class='root-group'>{label}</span>",
                f"<span class='root-leaf'>{leaf}</span>",
                "</div>",
            ]
        )
        rendered.append({"value": leaf, "html": markup})
    return rendered
def safe_get(obj: dict[str, Any], path: str, default: Any = "") -> Any:
    """Look up a dot-separated key path in nested dicts.

    Args:
        obj: The outermost mapping to search.
        path: Keys joined by ``"."``, e.g. ``"extracted.firac.facts"``.
        default: Value returned when the path cannot be followed.

    Returns:
        The value stored at ``path`` (whatever it is, including ``None``),
        or ``default`` if any step hits a non-dict or a missing key.
    """
    node: Any = obj
    for key in path.split("."):
        if isinstance(node, dict) and key in node:
            node = node[key]
        else:
            return default
    return node
def case_title(case: dict[str, Any], fallback: str) -> str:
    """Build a one-line display title for a case.

    The title joins, with ``" | "``, whichever of these are non-empty:
    ``metadata.court``, the case number(s), ``"<petitioner> v. <respondent>"``
    and ``metadata.judgment_date``. Case numbers and parties are read from
    the first entry of ``metadata.details``.

    Args:
        case: Parsed case record.
        fallback: Text returned when none of the title parts is available.

    Returns:
        The assembled title, or ``fallback`` if every part is empty.
    """

    def joined(value: Any) -> str:
        # A bare string must not be iterated: ", ".join("12") would yield
        # "1, 2" instead of "12".
        if isinstance(value, str):
            return value
        if isinstance(value, (list, tuple)):
            # str() keeps non-string items (e.g. numeric case numbers) from
            # raising TypeError inside join.
            return ", ".join(str(item) for item in value if item is not None)
        return "" if value is None else str(value)

    # "metadata" may be present but null in the JSON; treat that as absent.
    metadata = case.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}

    details = metadata.get("details")
    has_first = isinstance(details, (list, tuple)) and len(details) > 0
    first_details = details[0] if has_first and isinstance(details[0], dict) else {}

    case_no = joined(first_details.get("case_no"))
    petitioner = joined(first_details.get("petitioner"))
    respondent = joined(first_details.get("respondent"))
    parties = f"{petitioner} v. {respondent}" if petitioner or respondent else ""

    bits = [
        metadata.get("court", ""),
        case_no,
        parties,
        metadata.get("judgment_date", ""),
    ]
    return " | ".join(str(bit) for bit in bits if bit) or fallback
def short_case_label(case: dict[str, Any], fallback: str) -> str:
    """Build a compact label from the case number(s) and the parties.

    Both parts come from the first entry of ``metadata.details`` and are
    joined with ``" | "``; empty parts are skipped.

    Args:
        case: Parsed case record.
        fallback: Text returned when neither part is available.

    Returns:
        ``"<case_no> | <petitioner> v. <respondent>"`` (with missing parts
        omitted), or ``fallback`` if both parts are empty.
    """

    def joined(value: Any) -> str:
        # A bare string must not be iterated: ", ".join("12") would yield
        # "1, 2" instead of "12".
        if isinstance(value, str):
            return value
        if isinstance(value, (list, tuple)):
            # str() keeps non-string items from raising TypeError inside join.
            return ", ".join(str(item) for item in value if item is not None)
        return "" if value is None else str(value)

    # "metadata" may be present but null in the JSON; treat that as absent.
    metadata = case.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}

    details = metadata.get("details")
    has_first = isinstance(details, (list, tuple)) and len(details) > 0
    first_details = details[0] if has_first and isinstance(details[0], dict) else {}

    case_no = joined(first_details.get("case_no"))
    petitioner = joined(first_details.get("petitioner"))
    respondent = joined(first_details.get("respondent"))
    parties = f"{petitioner} v. {respondent}" if petitioner or respondent else ""

    return " | ".join(bit for bit in (case_no, parties) if bit) or fallback
def firac_text(firac: Any) -> str:
    """Flatten a FIRAC mapping into headed plain-text sections.

    Args:
        firac: Expected to be a dict that may hold the keys ``facts``,
            ``issues``, ``rules``, ``analysis`` and ``conclusion``.

    Returns:
        Each non-empty section rendered as its upper-cased key on one line
        followed by its value, with sections separated by a blank line.
        Returns ``""`` when ``firac`` is not a dict or has no non-empty
        section.
    """
    if not isinstance(firac, dict):
        return ""
    section_order = ("facts", "issues", "rules", "analysis", "conclusion")
    # A truthy .get() implies the key exists, so direct indexing is safe.
    return "\n\n".join(
        f"{name.upper()}\n{firac[name]}"
        for name in section_order
        if firac.get(name, "")
    )
def machine_rca(row: dict[str, str]) -> str:
    """Format the machine root-cause fields of an index row as labelled lines.

    Args:
        row: One row of the sample index, keyed by column name.

    Returns:
        Seven ``"<label>: <value>"`` lines joined by newlines. Missing
        columns render as an empty value, except ``include_facts`` which
        renders as ``[]``.
    """
    # (label, column, value used when the column is absent)
    layout = (
        ("Stage", "error_stage", ""),
        ("Major error", "major_error_category", ""),
        ("Minor error", "minor_error_category", ""),
        ("Summary", "one_sentence_summary", ""),
        ("Evidence", "evidence_snippet", ""),
        ("Recommended fix", "recommended_fix", ""),
        ("Include facts", "include_facts", "[]"),
    )
    return "\n".join(
        f"{label}: {row.get(column, missing)}" for label, column, missing in layout
    )
def load_rows() -> list[dict[str, str]]:
    """Read the sample index CSV into a list of row dicts.

    Returns:
        One dict per data row of ``SAMPLE_INDEX``, keyed by the CSV header.
    """
    # newline="" lets the csv module handle line endings itself.
    with SAMPLE_INDEX.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        rows = list(reader)
    return rows
def build_task(row: dict[str, str]) -> dict[str, Any]:
    """Build one annotation task from a sample-index row and its case JSON.

    The case file is looked up in ``SAMPLED_CASES`` using only the base name
    of the row's ``copied_json_path``.

    Args:
        row: One row of the sample index. ``sample_id`` and
            ``copied_json_path`` are required; other columns are optional.

    Returns:
        A dict with two sections: ``"data"`` (fields shown to the annotator)
        and ``"meta"`` (provenance plus the machine-assigned error labels).

    Raises:
        KeyError: If ``sample_id`` or ``copied_json_path`` is missing.
        OSError: If the case file cannot be opened.
        json.JSONDecodeError: If the case file is not valid JSON.
    """
    sample_id = row["sample_id"]
    filename = Path(row["copied_json_path"]).name
    with (SAMPLED_CASES / filename).open("r", encoding="utf-8") as f:
        case = json.load(f)

    # Prefer the extracted conclusion; use the rr_based one only when the
    # extracted value is missing or the literal "UNKNOWN".
    extracted_conclusion = safe_get(case, "extracted.one_word_conclusion", "UNKNOWN")
    rr_conclusion = safe_get(case, "rr_based.one_word_conclusion", "UNKNOWN")
    reference_outcome = rr_conclusion if extracted_conclusion == "UNKNOWN" else extracted_conclusion

    # "metadata" may be present but null in the JSON; treat that as absent.
    metadata = case.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}

    # Use the judge role output when present, otherwise the predicted FIRAC.
    generated_judgment = safe_get(case, "predicted.roles.judge", "")
    if not generated_judgment:
        generated_judgment = firac_text(safe_get(case, "predicted.firac", {}))

    # Use the raw judgment text when present, otherwise the rr_based FIRAC.
    reference_judgment = case.get("judgment_text") or firac_text(safe_get(case, "rr_based.firac", {}))

    return {
        "data": {
            "case_title": case_title(case, filename),
            "case_details": f"Case: {short_case_label(case, sample_id)}",
            "generated_outcome": safe_get(
                case, "predicted.one_word_conclusion", row.get("predicted_conclusion", "")
            ),
            "reference_outcome": reference_outcome,
            # Only the extracted facts section is exposed as the case record.
            "case_record": safe_get(case, "extracted.firac.facts", ""),
            "generated_judgment": generated_judgment,
            "reference_judgment": reference_judgment,
            "root_cause_options": root_cause_options(),
        },
        "meta": {
            "sample_id": sample_id,
            "case_id": row.get("case_id", ""),
            "experiment": row.get("experiment", ""),
            "model": row.get("model", ""),
            "source_file": filename,
            "court": metadata.get("court", ""),
            "judgment_date": metadata.get("judgment_date", ""),
            "machine_error_stage": row.get("error_stage", ""),
            "machine_major_error_category": row.get("major_error_category", ""),
            "machine_minor_error_category": row.get("minor_error_category", ""),
        },
    }
def main(annotators: int = 3, tasks_per_annotator: int = 10) -> None:
    """Build every task and write the JSON task files into ``OUT_DIR``.

    Writes ``tasks_all.json`` with all tasks, then one
    ``tasks_annotator_<k>.json`` per annotator holding consecutive,
    non-overlapping slices of the task list. Tasks beyond
    ``annotators * tasks_per_annotator`` appear only in ``tasks_all.json``;
    if there are fewer tasks than that, later annotator files are shorter
    or empty.

    Args:
        annotators: Number of per-annotator files to write.
        tasks_per_annotator: Number of tasks placed in each annotator file.

    Raises:
        ValueError: If either argument is negative.
    """
    if annotators < 0 or tasks_per_annotator < 0:
        raise ValueError("annotators and tasks_per_annotator must be non-negative")

    OUT_DIR.mkdir(parents=True, exist_ok=True)
    tasks = [build_task(row) for row in load_rows()]

    def dump(name: str, payload: list) -> None:
        # Single place for the serialisation settings shared by all outputs.
        (OUT_DIR / name).write_text(
            json.dumps(payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    dump("tasks_all.json", tasks)
    for number in range(1, annotators + 1):
        start = (number - 1) * tasks_per_annotator
        dump(f"tasks_annotator_{number}.json", tasks[start : start + tasks_per_annotator])

    print(f"Wrote {len(tasks)} tasks to {OUT_DIR}")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|