cprm / scripts /annotation_app.py
Zhuohan's picture
Add HF Spaces deployment and dataset-backed annotation storage
3f6201e
#!/usr/bin/env python3
from __future__ import annotations
import copy
import io
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
import streamlit as st
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
# Directory containing this script; the sibling `validate_compliance_prm`
# module is expected to live here as well.
SCRIPT_DIR = Path(__file__).resolve().parent
# Put the script directory first on sys.path so the sibling module can be
# imported regardless of the current working directory. This must run before
# the import below.
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))
import validate_compliance_prm as validator  # noqa: E402 -- relies on the sys.path tweak above
# Repository root: two levels above this file (the parent of its directory).
ROOT = Path(__file__).resolve().parents[1]
_DATA_DIR = ROOT / "data"

# Machine-generated pilot bundles that annotators start from.
SOURCE_BUNDLES_PATH = _DATA_DIR / "bundles" / "pilot_bundles_v1.jsonl"
# Markdown guideline shown before and during annotation.
GUIDELINE_PATH = _DATA_DIR / "docs" / "pilot_annotation_guideline_v1.md"
# Local annotation store, laid out as <ANNOTATIONS_DIR>/<annotator_id>/<bundle_id>.json.
ANNOTATIONS_DIR = _DATA_DIR / "annotations"

# The six calibration bundles in display order: each prefix (A17, M29) crossed
# with each jurisdiction suffix (CN, US, ISLAMIC), i.e.
# A17_CN, A17_US, A17_ISLAMIC, M29_CN, M29_US, M29_ISLAMIC.
TARGET_BUNDLE_IDS = [
    f"{prefix}_{jurisdiction}_BUNDLE"
    for prefix in ("A17", "M29")
    for jurisdiction in ("CN", "US", "ISLAMIC")
]
# Allowed values for a candidate trace's `label`.
TRACE_LABELS = ["compliant", "deadline_missed", "hard_violation"]
# Allowed values for an annotation's `status`.
STATUSES = ["in_progress", "final"]
def load_jsonl(path: Path) -> list[dict]:
    """Parse a UTF-8 JSON Lines file, skipping blank or whitespace-only lines.

    Returns one decoded JSON value per non-blank line, in file order.
    """
    records: list[dict] = []
    with path.open("r", encoding="utf-8") as stream:
        for raw_line in stream:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records
def load_source_bundles() -> dict[str, dict]:
    """Load the calibration bundles listed in TARGET_BUNDLE_IDS from SOURCE_BUNDLES_PATH.

    Returns:
        Mapping of bundle_id -> bundle dict, ordered like TARGET_BUNDLE_IDS. If
        the file repeats a bundle_id, the last occurrence wins.

    Raises:
        KeyError: if any target bundle is absent from the source file. The
            message lists every missing ID and the file that was read.
    """
    wanted = set(TARGET_BUNDLE_IDS)
    found = {
        bundle["bundle_id"]: bundle
        for bundle in load_jsonl(SOURCE_BUNDLES_PATH)
        if bundle["bundle_id"] in wanted
    }
    missing = [bundle_id for bundle_id in TARGET_BUNDLE_IDS if bundle_id not in found]
    if missing:
        raise KeyError(f"Target bundles {missing} not found in {SOURCE_BUNDLES_PATH}")
    return {bundle_id: found[bundle_id] for bundle_id in TARGET_BUNDLE_IDS}
def annotation_path(annotator_id: str, bundle_id: str) -> Path:
    """Return the local JSON path for one annotator's copy of one bundle.

    Layout: ``ANNOTATIONS_DIR/<annotator_id>/<bundle_id>.json``.

    Raises:
        ValueError: if either ID is empty, is ``.`` or ``..``, or contains a
            slash, backslash or NUL. ``annotator_id`` is free text typed in the
            UI. A ``Path`` join with ``../x``, or with an absolute path such as
            ``/tmp/x`` (which replaces the base entirely), would place the file
            outside ``ANNOTATIONS_DIR``.
    """
    for component in (annotator_id, bundle_id):
        if (
            not component
            or component in (".", "..")
            or any(bad in component for bad in ("/", "\\", "\x00"))
        ):
            raise ValueError(f"Unsafe path component for annotation storage: {component!r}")
    return ANNOTATIONS_DIR / annotator_id / f"{bundle_id}.json"
def dataset_repo_id() -> str:
    """Return the HF dataset repo id from ``HF_DATASET_REPO`` (stripped), or ``""`` if unset."""
    configured = os.environ.get("HF_DATASET_REPO", "")
    return configured.strip()
def dataset_repo_subdir() -> str:
    """Return the folder inside the dataset repo that holds annotations.

    Read from ``HF_DATASET_SUBDIR`` with surrounding whitespace and slashes
    removed. Falls back to ``"annotations"`` when unset or empty after cleaning.
    """
    raw_value = os.environ.get("HF_DATASET_SUBDIR", "annotations")
    cleaned = raw_value.strip().strip("/")
    if cleaned:
        return cleaned
    return "annotations"
def hf_token() -> str:
    """Return the first non-blank Hugging Face token from the environment.

    Checks ``HF_TOKEN`` first, then ``HUGGINGFACEHUB_API_TOKEN``. Surrounding
    whitespace is stripped; ``""`` means no token is configured.
    """
    stripped_values = (
        os.environ.get(variable, "").strip()
        for variable in ("HF_TOKEN", "HUGGINGFACEHUB_API_TOKEN")
    )
    return next((token for token in stripped_values if token), "")
def storage_backend() -> str:
    """Return ``"hf_dataset"`` when both a dataset repo and a token are configured, else ``"local"``."""
    hub_configured = bool(dataset_repo_id()) and bool(hf_token())
    return "hf_dataset" if hub_configured else "local"
def dataset_repo_path(annotator_id: str, bundle_id: str) -> str:
    """Return the in-repo path ``<subdir>/<annotator_id>/<bundle_id>.json`` for the HF dataset.

    Raises:
        ValueError: if either ID is empty, is ``.`` or ``..``, or contains a
            slash, backslash or NUL. ``annotator_id`` is free text typed in the
            UI, and such values would produce paths outside the annotator's
            folder (e.g. ``annotations/../x/<bundle>.json``).
    """
    for component in (annotator_id, bundle_id):
        if (
            not component
            or component in (".", "..")
            or any(bad in component for bad in ("/", "\\", "\x00"))
        ):
            raise ValueError(f"Unsafe path component for dataset storage: {component!r}")
    return f"{dataset_repo_subdir()}/{annotator_id}/{bundle_id}.json"
def build_initial_annotation(bundle: dict, annotator_id: str) -> dict:
    """Create a fresh, unsaved annotation for *annotator_id* from a source bundle.

    The bundle is deep-copied, so the source is never mutated. The copy gets
    ``annotator_id`` and ``status="in_progress"``, with ``updated_at=None`` and
    empty ``change_notes``; any existing values of those keys are overwritten.
    """
    fresh = copy.deepcopy(bundle)
    fresh.update(
        annotator_id=annotator_id,
        status="in_progress",
        updated_at=None,
        change_notes="",
    )
    return fresh
def save_local_annotation(payload: dict) -> Path:
    """Write *payload* as indented UTF-8 JSON to its local annotation file.

    The target is ``annotation_path(payload["annotator_id"], payload["bundle_id"])``;
    missing parent directories are created.

    The JSON text is produced first and written to a uniquely named temporary
    file next to the target, which then replaces the target via ``os.replace``.
    If serialization or the write fails part-way, the previously saved file stays
    intact rather than truncated. A truncated file would make every later
    ``json.load`` of this annotation fail.

    Returns:
        The path of the written file.
    """
    import uuid

    path = annotation_path(payload["annotator_id"], payload["bundle_id"])
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(payload, indent=2, ensure_ascii=False)
    # Unique name so concurrent sessions never share a temp file.
    tmp_path = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as handle:
            handle.write(text)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort removal of the temp file; the original error is re-raised.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
    return path
def load_remote_annotation(bundle: dict, annotator_id: str) -> dict | None:
    """Download this annotator's saved copy of *bundle* from the HF dataset repo.

    Returns:
        The parsed annotation, or ``None`` when nothing has been saved remotely
        yet: the file, the repo, or its default revision does not exist.

    Raises:
        LocalEntryNotFoundError: the Hub could not be reached and no cached copy
            exists. Other download errors (authentication, network, server) also
            propagate.

    Failures are deliberately not mapped to ``None``. load_annotation would then
    fall back to a fresh copy of the source bundle, and the next save would
    overwrite the annotator's remote work with it.
    """
    # Local import keeps this block self-contained w.r.t. the module import list.
    from huggingface_hub.utils import (
        LocalEntryNotFoundError,
        RepositoryNotFoundError,
        RevisionNotFoundError,
    )

    try:
        downloaded_path = hf_hub_download(
            repo_id=dataset_repo_id(),
            filename=dataset_repo_path(annotator_id, bundle["bundle_id"]),
            repo_type="dataset",
            token=hf_token(),
        )
    except LocalEntryNotFoundError:
        # Subclass of EntryNotFoundError meaning "Hub unreachable and not cached":
        # the file may well exist remotely, so it must not be reported as absent.
        raise
    except (EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError):
        return None
    with Path(downloaded_path).open("r", encoding="utf-8") as handle:
        return json.load(handle)
def save_remote_annotation(payload: dict) -> str:
    """Upload *payload* as JSON to the configured HF dataset repo and return its URI.

    The private dataset repo is created if needed (``exist_ok=True`` makes this
    succeed when it already exists). Each call is one commit whose message names
    the bundle and annotator.

    Returns:
        ``hf://datasets/<repo_id>/<path_in_repo>`` of the uploaded file.
    """
    repo_id = dataset_repo_id()
    client = HfApi(token=hf_token())
    client.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)

    annotator_id = payload["annotator_id"]
    bundle_id = payload["bundle_id"]
    target_path = dataset_repo_path(annotator_id, bundle_id)
    serialized = json.dumps(payload, indent=2, ensure_ascii=False)

    client.upload_file(
        path_or_fileobj=io.BytesIO(serialized.encode("utf-8")),
        path_in_repo=target_path,
        repo_id=repo_id,
        repo_type="dataset",
        commit_message=f"Update annotation: {bundle_id} ({annotator_id})",
    )
    return f"hf://datasets/{repo_id}/{target_path}"
def load_annotation(bundle: dict, annotator_id: str) -> dict:
    """Return the saved annotation for *bundle* and *annotator_id*, or a fresh one.

    Lookup order:
    1. The local file from annotation_path(), if it exists.
    2. The remote copy, only when the HF dataset backend is active.
    3. A new in-progress annotation built from the source bundle.
    """
    local_file = annotation_path(annotator_id, bundle["bundle_id"])
    if local_file.exists():
        return json.loads(local_file.read_text(encoding="utf-8"))

    remote = None
    if storage_backend() == "hf_dataset":
        remote = load_remote_annotation(bundle, annotator_id)
    return build_initial_annotation(bundle, annotator_id) if remote is None else remote
def save_annotation(annotation: dict) -> str:
    """Persist a timestamped copy of *annotation* and describe where it went.

    The caller's dict is not modified. The copy's ``updated_at`` is set to the
    current UTC time in ISO format. It is always written locally. When the HF
    dataset backend is active it is also uploaded, and the local file is then
    reported as a mirror.

    Returns:
        A human-readable location string.
    """
    stamped = copy.deepcopy(annotation)
    stamped.update(updated_at=datetime.now(timezone.utc).isoformat())
    local_path = save_local_annotation(stamped)
    if storage_backend() != "hf_dataset":
        return str(local_path)
    remote_location = save_remote_annotation(stamped)
    return f"{remote_location} (local mirror: {local_path})"
def require_password() -> None:
    """Gate the app behind the shared ``ANNOTATION_APP_PASSWORD``, if one is set.

    Returns immediately when the variable is unset or blank, or when this
    session has already unlocked. Otherwise it renders a password prompt and
    ends the script run: ``st.rerun()`` after a correct password, ``st.stop()``
    in every other case.
    """
    import hmac

    expected_password = os.getenv("ANNOTATION_APP_PASSWORD", "").strip()
    if not expected_password:
        return
    if st.session_state.get("authenticated"):
        return
    st.title("CPRM Annotation App")
    st.caption("This instance is password-protected.")
    typed_password = st.text_input("Shared Password", type="password")
    if st.button("Unlock"):
        # Constant-time comparison so response timing does not reveal how much of
        # the password matched. Compare bytes: compare_digest rejects non-ASCII str.
        if hmac.compare_digest(
            typed_password.encode("utf-8"), expected_password.encode("utf-8")
        ):
            st.session_state["authenticated"] = True
            st.rerun()
        st.error("Incorrect password.")
    st.stop()
def read_guideline() -> str:
    """Return the guideline markdown, or a placeholder message when the file is missing."""
    if not GUIDELINE_PATH.exists():
        return "Guideline file not found. Generate `data/docs/pilot_annotation_guideline_v1.md` first."
    return GUIDELINE_PATH.read_text(encoding="utf-8")
def reset_guideline_gate() -> None:
    """Clear the guideline acknowledgement so the next run shows step 1 again."""
    cleared_state = {"guideline_acknowledged": False, "guideline_confirmed_for": None}
    for state_key, state_value in cleared_state.items():
        st.session_state[state_key] = state_value
def render_guideline_gate() -> None:
    """Render step 1 (annotator ID and guideline acknowledgement); always ends the run.

    The typed annotator ID is stored in ``st.session_state["annotator_id"]``. It
    must be non-empty and safe to use as a directory name. The enter button is
    enabled only once the checkbox is ticked. Pressing it records the
    acknowledgement for that ID and reruns into the workspace. Every other path
    halts with ``st.stop()``.
    """
    st.title("CPRM Pilot Annotation App")
    st.caption("Step 1 of 2: read the guideline, confirm it, then enter the annotation workspace.")
    annotator_id = st.text_input(
        "Annotator ID",
        value=st.session_state.get("annotator_id", "solo_annotator"),
        help="Use a stable ID so saved files go to a consistent annotation folder.",
    ).strip()
    st.session_state["annotator_id"] = annotator_id
    if not annotator_id:
        st.info("Enter an annotator ID before continuing.")
        st.stop()
    # The ID becomes a folder name (locally and in the HF dataset repo), so refuse
    # values that could point outside the annotations folder.
    if annotator_id in (".", "..") or any(bad in annotator_id for bad in ("/", "\\", "\x00")):
        st.error("Annotator ID cannot contain slashes or backslashes, and cannot be '.' or '..'.")
        st.stop()
    st.subheader("Guideline")
    st.markdown(read_guideline())
    acknowledged = st.checkbox(
        "I have read the guideline and I understand that round 1 only edits existing fields and does not change step count.",
        value=False,
        key="guideline_ack_checkbox",
    )
    if st.button("Enter Annotation Workspace", type="primary", disabled=not acknowledged):
        st.session_state["guideline_acknowledged"] = True
        st.session_state["guideline_confirmed_for"] = annotator_id
        st.rerun()
    st.stop()
def ensure_working_annotation(source_bundle: dict, annotator_id: str, bundle_id: str) -> dict:
    """Return a private copy of the session's working annotation for this annotator/bundle.

    The annotation is (re)loaded with load_annotation() only when the
    ``"<annotator_id>:<bundle_id>"`` pair differs from the one cached in session
    state. Otherwise the cached ``working_bundle`` is reused across reruns. A
    deep copy is returned so callers can modify it freely.
    """
    target_key = f"{annotator_id}:{bundle_id}"
    if st.session_state.get("working_bundle_key") != target_key:
        # Load before recording the key. If loading raises, the next rerun retries
        # instead of treating the previous bundle's cached data as this bundle's.
        loaded = load_annotation(source_bundle, annotator_id)
        st.session_state["working_bundle"] = loaded
        st.session_state["working_bundle_key"] = target_key
    return copy.deepcopy(st.session_state["working_bundle"])
def step_key(bundle_id: str, trace_id: str, step_id: int, field: str, suffix: str = "") -> str:
    """Build a unique widget key ``bundle:trace:step:field[:suffix]`` for a step field."""
    base = f"{bundle_id}:{trace_id}:{step_id}:{field}"
    return f"{base}:{suffix}" if suffix else base
def trace_key(bundle_id: str, trace_id: str, field: str) -> str:
    """Build a unique widget key ``bundle:trace:field`` for a trace-level field."""
    return "{}:{}:{}".format(bundle_id, trace_id, field)
def get_rule_options(bundle: dict) -> list[str]:
    """Collect every rule ID referenced by *bundle*, de-duplicated, first occurrence first.

    Sources, in order:
    1. The rulebook.
    2. For each candidate step: its ``active_rule_ids``, its
       ``violated_rule_id``, and the keys of its ``soft_coverage_delta``.

    Falsy IDs (``None``, ``""``) are skipped.
    """

    def referenced_ids():
        yield from bundle["rulebook"]
        for candidate in bundle["candidates"]:
            for step in candidate["steps"]:
                yield from step["active_rule_ids"]
                yield step["violated_rule_id"]
                yield from step["soft_coverage_delta"]

    # dict preserves insertion order, giving an ordered de-duplication.
    return list(dict.fromkeys(rule_id for rule_id in referenced_ids() if rule_id))
def render_metadata(bundle: dict, annotator_id: str) -> tuple[str, str]:
    """Render the workspace sidebar for *bundle* and return ``(status, change_notes)``.

    The sidebar contains, in order:
    - a "Back To Guideline" button (resets the gate and reruns);
    - read-only bundle, annotator and storage details;
    - the rulebook;
    - the editable status and change-notes widgets, keyed by bundle ID;
    - a collapsed copy of the guideline.
    """
    with st.sidebar:
        if st.button("Back To Guideline"):
            reset_guideline_gate()
            st.rerun()

        bundle_id = bundle["bundle_id"]
        st.header("Bundle")
        st.write(f"`{bundle_id}`")
        st.write(f"Annotator: `{annotator_id}`")
        for caption, field in (("Jurisdiction", "jurisdiction"), ("Mode", "mode")):
            st.write(f"{caption}: `{bundle[field]}`")

        backend = storage_backend()
        st.write(f"Storage backend: `{backend}`")
        if backend == "hf_dataset":
            st.write(f"Dataset repo: `{dataset_repo_id()}`")

        st.write("Rulebook:")
        for rule_id in bundle["rulebook"]:
            st.code(rule_id)

        current_status = bundle.get("status", "in_progress")
        status = st.selectbox(
            "Bundle Status",
            options=STATUSES,
            index=STATUSES.index(current_status),
            key=f"{bundle_id}:status",
        )
        change_notes = st.text_area(
            "Change Notes",
            value=bundle.get("change_notes", ""),
            height=160,
            key=f"{bundle_id}:change_notes",
            help="Short note on what changed from the machine-generated version.",
        )

        with st.expander("Guideline", expanded=False):
            st.markdown(read_guideline())
    return status, change_notes
def render_step_editor(bundle: dict, trace: dict, step: dict, rule_options: list[str]) -> dict:
    """Render the editing widgets for one step and return the edited step dict.

    ``step_id`` and ``text`` are carried over unchanged; the other fields come
    from the widgets. Soft-coverage values of 0 are dropped, and the rest are
    rounded to two decimals.

    Args:
        bundle: Bundle owning the trace; its ``bundle_id`` namespaces widget keys.
        trace: Candidate trace owning the step; its ``trace_id`` namespaces keys.
        step: The step being edited; supplies the widgets' initial values.
        rule_options: Rule IDs offered in the rule pickers and the soft-coverage grid.
    """
    bundle_id = bundle["bundle_id"]
    trace_id = trace["trace_id"]
    step_id = step["step_id"]

    def key_for(field: str, suffix: str = "") -> str:
        return step_key(bundle_id, trace_id, step_id, field, suffix)

    st.markdown(f"**Step {step_id}:** `{step['text']}`")

    action_choices = sorted(validator.ALLOWED_ACTION_TYPES)
    action_type = st.selectbox(
        f"Action Type ({step_id})",
        options=action_choices,
        index=action_choices.index(step["action_type"]),
        key=key_for("action_type"),
    )
    active_rule_ids = st.multiselect(
        f"Active Rule IDs ({step_id})",
        options=rule_options,
        default=step["active_rule_ids"],
        key=key_for("active_rule_ids"),
    )
    hard_violation = st.checkbox(
        f"Hard Violation ({step_id})",
        value=bool(step["hard_violation"]),
        key=key_for("hard_violation"),
    )
    violation_choices = [None] + rule_options
    violated_rule_id = st.selectbox(
        f"Violated Rule ID ({step_id})",
        options=violation_choices,
        index=violation_choices.index(step["violated_rule_id"]),
        key=key_for("violated_rule_id"),
        format_func=lambda value: "None" if value is None else value,
    )

    st.caption("Soft Coverage Delta")
    soft_coverage_delta: dict[str, float] = {}
    grid = st.columns(len(rule_options) or 1)
    for column, rule_id in zip(grid, rule_options):
        initial = float(step["soft_coverage_delta"].get(rule_id, 0.0))
        with column:
            chosen = st.number_input(
                rule_id,
                min_value=0.0,
                max_value=1.0,
                value=initial,
                step=0.05,
                key=key_for("soft_delta", rule_id),
            )
        if chosen > 0:
            soft_coverage_delta[rule_id] = round(float(chosen), 2)

    return {
        "step_id": step_id,
        "action_type": action_type,
        "text": step["text"],
        "active_rule_ids": active_rule_ids,
        "hard_violation": int(hard_violation),
        "violated_rule_id": violated_rule_id,
        "soft_coverage_delta": soft_coverage_delta,
    }
def render_trace_editor(bundle: dict, trace: dict) -> dict:
    """Render trace-level widgets plus one bordered editor per step.

    Returns a deep copy of *trace* with ``label``, ``overall_compliant``,
    ``first_violation_step`` and ``steps`` replaced by the edited values. The
    number of steps never changes.
    """
    bundle_id = bundle["bundle_id"]
    trace_id = trace["trace_id"]
    rule_options = get_rule_options(bundle)

    def key_for(field: str) -> str:
        return trace_key(bundle_id, trace_id, field)

    label = st.selectbox(
        "Trace Label",
        options=TRACE_LABELS,
        index=TRACE_LABELS.index(trace["label"]),
        key=key_for("label"),
    )
    overall_compliant = st.checkbox(
        "Overall Compliant",
        value=bool(trace["overall_compliant"]),
        key=key_for("overall_compliant"),
    )
    step_choices = [None] + [step["step_id"] for step in trace["steps"]]
    first_violation_step = st.selectbox(
        "First Violation Step",
        options=step_choices,
        index=step_choices.index(trace["first_violation_step"]),
        key=key_for("first_violation_step"),
        format_func=lambda value: "None" if value is None else f"Step {value}",
    )

    rendered_steps = []
    for step in trace["steps"]:
        with st.container(border=True):
            rendered_steps.append(render_step_editor(bundle, trace, step, rule_options))

    edited_trace = copy.deepcopy(trace)
    edited_trace.update(
        label=label,
        overall_compliant=overall_compliant,
        first_violation_step=first_violation_step,
        steps=rendered_steps,
    )
    return edited_trace
def render_bundle_editor(bundle: dict) -> dict:
    """Render one tab per candidate trace and return *bundle* with the edited candidates.

    The input is not modified; a deep copy carries the edited candidate list.
    """
    candidates = bundle["candidates"]
    tab_handles = st.tabs([candidate["trace_id"] for candidate in candidates])
    edited_candidates = []
    for position, candidate in enumerate(candidates):
        with tab_handles[position]:
            edited_candidates.append(render_trace_editor(bundle, candidate))
    result = copy.deepcopy(bundle)
    result["candidates"] = edited_candidates
    return result
def render_validation_panel(bundle: dict, valid_rule_ids: set[str]) -> None:
    """Validate *bundle* with the compliance validator and show the result in an expander.

    Shows an ok/error-count/warning-count summary, then the errors and warnings
    (one per line). When there are neither, it shows a success note instead.
    """
    report = validator.validate_single_bundle(bundle, valid_rule_ids)
    with st.expander("Validation", expanded=True):
        error_lines = report["errors"]
        warning_lines = report["warnings"]
        st.write(
            {
                "ok": report["ok"],
                "errors": len(error_lines),
                "warnings": len(warning_lines),
            }
        )
        if error_lines:
            st.error("\n".join(error_lines))
        if warning_lines:
            st.warning("\n".join(warning_lines))
        if not (error_lines or warning_lines):
            st.success("No validation issues detected.")
def main() -> None:
    """Streamlit entry point: password gate, guideline gate, then the annotation workspace.

    In the workspace the annotator can:
    - pick one of TARGET_BUNDLE_IDS;
    - edit it (sidebar metadata plus per-trace editors) and see live validation;
    - save it through save_annotation(), or download the current edits as JSON.
    """
    st.set_page_config(page_title="CPRM Annotation App", layout="wide")
    require_password()

    current_annotator = st.session_state.get("annotator_id", "").strip()
    if (
        not st.session_state.get("guideline_acknowledged")
        or st.session_state.get("guideline_confirmed_for") != current_annotator
    ):
        # Step 1; this call ends the script run (st.stop() or st.rerun()).
        render_guideline_gate()

    source_bundles = load_source_bundles()
    valid_rule_ids = validator.load_rule_ids(validator.RULE_CARDS_PATH)
    st.title("CPRM Pilot Annotation App")
    st.caption(
        "Step 2 of 2: annotate one of the 6 calibration bundles. Existing steps are editable, but step count is fixed."
    )
    bundle_id = st.selectbox("Bundle", options=TARGET_BUNDLE_IDS)
    source_bundle = source_bundles[bundle_id]
    working_bundle = ensure_working_annotation(source_bundle, current_annotator, bundle_id)
    status, change_notes = render_metadata(working_bundle, current_annotator)

    left, right = st.columns([3, 2])
    with left:
        edited_bundle = render_bundle_editor(working_bundle)
    # Bundle-level fields edited in the sidebar, plus the working copy's
    # updated_at (None until the first save).
    edited_bundle["annotator_id"] = current_annotator
    edited_bundle["status"] = status
    edited_bundle["change_notes"] = change_notes
    edited_bundle["updated_at"] = working_bundle.get("updated_at")

    with right:
        st.subheader("Scenario")
        st.json(
            {
                "bundle_id": source_bundle["bundle_id"],
                "scenario_id": source_bundle["scenario_id"],
                "intent_id": source_bundle["intent_id"],
                "jurisdiction": source_bundle["jurisdiction"],
                "mode": source_bundle["mode"],
                "rulebook": source_bundle["rulebook"],
            },
            expanded=False,
        )
        render_validation_panel(edited_bundle, valid_rule_ids)
        col1, col2 = st.columns(2)
        with col1:
            if st.button("Save Annotation", type="primary"):
                saved_path = save_annotation(edited_bundle)
                # save_annotation stamps updated_at on its own copy and always
                # writes the local file first. Read that file back so the session
                # and the download below carry the timestamp that was saved.
                saved_bundle = load_annotation(source_bundle, current_annotator)
                st.session_state["working_bundle"] = saved_bundle
                edited_bundle["updated_at"] = saved_bundle.get("updated_at")
                st.success(f"Saved to {saved_path}")
        with col2:
            st.download_button(
                "Download JSON",
                data=json.dumps(edited_bundle, indent=2, ensure_ascii=False),
                file_name=f"{edited_bundle['bundle_id']}.json",
                mime="application/json",
            )


if __name__ == "__main__":
    main()