# rca-judgment-validation / scripts / prepare_tasks.py
# mborcin's picture
# Make root cause options leaf-only
# 406fdd9 verified
# raw
# history blame contribute delete
# 7.73 kB
import csv
import json
from pathlib import Path
from typing import Any
# Filesystem layout, all derived from this script's resolved location.
# ROOT is two levels above the directory containing this script.
ROOT = Path(__file__).resolve().parents[2]
# SPACE_ROOT is one level above the directory containing this script
# (presumably the root of the hosting Space repository -- TODO confirm).
SPACE_ROOT = Path(__file__).resolve().parents[1]
# Input: the validation sample directory, its CSV index, and the folder of
# per-case JSON files referenced by that index.
SAMPLE_DIR = ROOT / "rca_validation_sample"
SAMPLE_INDEX = SAMPLE_DIR / "sample_index.csv"
SAMPLED_CASES = SAMPLE_DIR / "sampled_cases"
# Output: generated task JSON files are written here.
OUT_DIR = SPACE_ROOT / "data"
# Root-cause taxonomy as (major category, subgroup, leaf label) tuples.
# An empty subgroup means the option is grouped under the major category
# alone. Order matters: root_cause_options() marks the first entry of each
# consecutive group, so entries of one group must stay adjacent.
ROOT_CAUSE_OPTIONS = [
    # Errors attributed to the input rather than the model.
    ("INPUT ERROR", "", "Missing fact"),
    ("INPUT ERROR", "", "Incorrect fact"),
    ("INPUT ERROR", "", "Fabricated fact in case record"),
    # Model errors, subgrouped as Facts / Issues / Rules / Analysis / Conclusion.
    ("MODEL ERROR", "Facts", "Fabricated fact in reasoning"),
    ("MODEL ERROR", "Facts", "Fact-weighing error"),
    ("MODEL ERROR", "Facts", "Fact omission in reasoning"),
    ("MODEL ERROR", "Issues", "Issue omission"),
    ("MODEL ERROR", "Issues", "Spurious issue"),
    ("MODEL ERROR", "Issues", "Issue misframing"),
    ("MODEL ERROR", "Issues", "Issue misprioritisation"),
    ("MODEL ERROR", "Rules", "Wrong rule source"),
    ("MODEL ERROR", "Rules", "Using outdated / overruled rule"),
    ("MODEL ERROR", "Rules", "Rule misinterpretation"),
    ("MODEL ERROR", "Rules", "Rule misapplication"),
    ("MODEL ERROR", "Rules", "Wrong legal test / threshold"),
    ("MODEL ERROR", "Rules", "Missed exception / qualification"),
    ("MODEL ERROR", "Rules", "Limitation / time-bar error"),
    ("MODEL ERROR", "Rules", "Precedent context mismatch"),
    ("MODEL ERROR", "Rules", "Rule interaction error"),
    ("MODEL ERROR", "Analysis", "Burden misallocation"),
    ("MODEL ERROR", "Analysis", "Neglected rule, fact, or counter-argument"),
    ("MODEL ERROR", "Analysis", "Logical fallacy / leap in logic"),
    ("MODEL ERROR", "Analysis", "Other reasoning error"),
    ("MODEL ERROR", "Conclusion", "Conclusion does not follow from analysis"),
    ("MODEL ERROR", "Conclusion", "Discretionary original judgment"),
]
def root_cause_options(
    taxonomy: "list[tuple[str, str, str]] | None" = None,
) -> list[dict[str, str]]:
    """Render the root-cause taxonomy as selectable options with HTML labels.

    Args:
        taxonomy: Optional sequence of ``(major, subgroup, leaf)`` tuples.
            Defaults to the module-level ``ROOT_CAUSE_OPTIONS``.

    Returns:
        One dict per entry with ``"value"`` (the raw leaf label) and
        ``"html"`` (a ``<div>`` holding the group label and the leaf label).
        The first option of each consecutive group additionally carries the
        ``root-option-new-group`` CSS class.
    """
    # Local import keeps this block self-contained; the module does not
    # import ``html`` at the top level.
    from html import escape

    entries = ROOT_CAUSE_OPTIONS if taxonomy is None else taxonomy
    rendered: list[dict[str, str]] = []
    previous_group = None
    for major, subgroup, value in entries:
        # An empty subgroup collapses the group label to the major category.
        group = f"{major}: {subgroup}" if subgroup else major
        class_name = (
            "root-option root-option-new-group"
            if group != previous_group
            else "root-option"
        )
        previous_group = group
        # Escape the text so a label containing '&', '<' or quotes cannot
        # break the markup; "value" itself stays unescaped.
        markup = (
            f"<div class='{class_name}'>"
            f"<span class='root-group'>{escape(group)}</span>"
            f"<span class='root-leaf'>{escape(value)}</span>"
            "</div>"
        )
        rendered.append({"value": value, "html": markup})
    return rendered
def safe_get(obj: dict[str, Any], path: str, default: Any = "") -> Any:
    """Look up a dot-separated key path in nested dicts.

    Returns the value found at the end of ``path``. If any step lands on a
    non-dict or on a missing key, ``default`` is returned instead.
    """
    node: Any = obj
    for key in path.split("."):
        if isinstance(node, dict) and key in node:
            node = node[key]
        else:
            # Path cannot be followed any further.
            return default
    return node
def _join_values(value: Any) -> str:
    """Render one metadata field as a comma-separated string.

    Lists/tuples are joined with ", " (``None`` items skipped), a plain
    string is returned unchanged instead of being joined character by
    character, and ``None`` becomes the empty string.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return ", ".join(str(item) for item in value if item is not None)
    return str(value)


def _case_identity(case: dict[str, Any]) -> tuple[dict[str, Any], str, str]:
    """Return ``(metadata, case_no, parties)`` extracted from a case dict.

    Only the first entry of ``metadata["details"]`` is consulted. Missing,
    ``None`` or wrongly-typed containers are treated as empty.
    """
    metadata = case.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    details = metadata.get("details")
    first = details[0] if isinstance(details, (list, tuple)) and details else None
    if not isinstance(first, dict):
        first = {}
    case_no = _join_values(first.get("case_no"))
    petitioner = _join_values(first.get("petitioner"))
    respondent = _join_values(first.get("respondent"))
    parties = f"{petitioner} v. {respondent}" if petitioner or respondent else ""
    return metadata, case_no, parties


def case_title(case: dict[str, Any], fallback: str) -> str:
    """Build a full title: court | case number | parties | judgment date.

    Empty components are omitted; ``fallback`` is returned when every
    component is empty.
    """
    metadata, case_no, parties = _case_identity(case)
    bits = [
        metadata.get("court", ""),
        case_no,
        parties,
        metadata.get("judgment_date", ""),
    ]
    return " | ".join(str(bit) for bit in bits if bit) or fallback


def short_case_label(case: dict[str, Any], fallback: str) -> str:
    """Build a short label: case number | parties.

    Empty components are omitted; ``fallback`` is returned when both are
    empty.
    """
    _, case_no, parties = _case_identity(case)
    return " | ".join(bit for bit in [case_no, parties] if bit) or fallback
def firac_text(firac: Any) -> str:
    """Flatten a FIRAC dict into plain text.

    Each non-empty section (facts, issues, rules, analysis, conclusion, in
    that order) becomes an upper-cased heading followed by its content;
    sections are separated by a blank line. Non-dict input yields "".
    """
    if not isinstance(firac, dict):
        return ""
    section_names = ("facts", "issues", "rules", "analysis", "conclusion")
    sections = ((name, firac.get(name, "")) for name in section_names)
    return "\n\n".join(
        f"{name.upper()}\n{body}" for name, body in sections if body
    )
def machine_rca(row: dict[str, str]) -> str:
    """Format the machine-generated root-cause fields of a row as text.

    Produces one "Label: value" line per field. Missing fields render as an
    empty value, except "Include facts", which falls back to "[]".
    """
    labelled_columns = (
        ("Stage", "error_stage"),
        ("Major error", "major_error_category"),
        ("Minor error", "minor_error_category"),
        ("Summary", "one_sentence_summary"),
        ("Evidence", "evidence_snippet"),
        ("Recommended fix", "recommended_fix"),
    )
    lines = [f"{label}: {row.get(column, '')}" for label, column in labelled_columns]
    lines.append(f"Include facts: {row.get('include_facts', '[]')}")
    return "\n".join(lines)
def load_rows() -> list[dict[str, str]]:
    """Read the sample index CSV and return its rows as dicts keyed by header."""
    # newline="" lets the csv module handle line endings itself.
    with SAMPLE_INDEX.open("r", encoding="utf-8", newline="") as handle:
        rows = list(csv.DictReader(handle))
    return rows
def build_task(row: dict[str, str]) -> dict[str, Any]:
    """Build one annotation task from a sample-index row.

    Loads the case JSON named by ``row["copied_json_path"]`` (only the file
    name is used; it is resolved inside ``SAMPLED_CASES``) and assembles the
    task payload.

    Args:
        row: One row of the sample index. ``sample_id`` and
            ``copied_json_path`` are required; other columns are optional.

    Returns:
        A dict with ``"data"`` (fields shown to the annotator) and ``"meta"``
        (provenance and the machine-assigned error labels).

    Raises:
        KeyError: If ``sample_id`` or ``copied_json_path`` is missing.
        OSError: If the case file cannot be opened.
        json.JSONDecodeError: If the case file is not valid JSON.
    """
    sample_id = row["sample_id"]
    filename = Path(row["copied_json_path"]).name
    path = SAMPLED_CASES / filename
    with path.open("r", encoding="utf-8") as f:
        case = json.load(f)

    # Prefer the extracted outcome; fall back to the rr_based one only when
    # the extracted outcome is the literal "UNKNOWN".
    extracted_conclusion = safe_get(case, "extracted.one_word_conclusion", "UNKNOWN")
    rr_conclusion = safe_get(case, "rr_based.one_word_conclusion", "UNKNOWN")
    reference_outcome = rr_conclusion if extracted_conclusion == "UNKNOWN" else extracted_conclusion

    # Tolerate a null / malformed "metadata" entry in the case JSON.
    metadata = case.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    case_details = f"Case: {short_case_label(case, sample_id)}"

    # Use the judge-role output when present, otherwise the predicted FIRAC.
    generated_judgment = safe_get(case, "predicted.roles.judge", "")
    if not generated_judgment:
        generated_judgment = firac_text(safe_get(case, "predicted.firac", {}))

    # Use the original judgment text when present, otherwise the rr_based FIRAC.
    reference_judgment = case.get("judgment_text") or firac_text(safe_get(case, "rr_based.firac", {}))

    return {
        "data": {
            "case_title": case_title(case, filename),
            "case_details": case_details,
            "generated_outcome": safe_get(case, "predicted.one_word_conclusion", row.get("predicted_conclusion", "")),
            "reference_outcome": reference_outcome,
            # NOTE(review): only the extracted facts are exposed as the case
            # record. A full extracted-FIRAC text used to be computed here but
            # was never used, so it was removed -- confirm facts-only is intended.
            "case_record": safe_get(case, "extracted.firac.facts", ""),
            "generated_judgment": generated_judgment,
            "reference_judgment": reference_judgment,
            "root_cause_options": root_cause_options(),
        },
        "meta": {
            "sample_id": sample_id,
            "case_id": row.get("case_id", ""),
            "experiment": row.get("experiment", ""),
            "model": row.get("model", ""),
            "source_file": filename,
            "court": metadata.get("court", ""),
            "judgment_date": metadata.get("judgment_date", ""),
            "machine_error_stage": row.get("error_stage", ""),
            "machine_major_error_category": row.get("major_error_category", ""),
            "machine_minor_error_category": row.get("minor_error_category", ""),
        },
    }
def main() -> None:
    """Generate all task files under ``OUT_DIR``.

    Writes every task to ``tasks_all.json`` and the first three consecutive
    slices of ten tasks to ``tasks_annotator_1.json`` .. ``_3.json``, then
    prints the total number of tasks.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    tasks = [build_task(row) for row in load_rows()]

    def dump(name: str, payload: list[dict[str, Any]]) -> None:
        # Pretty-printed UTF-8 JSON, keeping non-ASCII characters readable.
        (OUT_DIR / name).write_text(
            json.dumps(payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    dump("tasks_all.json", tasks)
    for annotator, start in enumerate(range(0, 30, 10), start=1):
        dump(f"tasks_annotator_{annotator}.json", tasks[start : start + 10])
    print(f"Wrote {len(tasks)} tasks to {OUT_DIR}")
# Run the task generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()