#!/usr/bin/env python3
"""Setup, configuration, and storage layer for the CPRM pilot annotation app.

Annotations are always written locally under ``data/annotations/``; when a
Hugging Face dataset repo and token are configured via environment variables,
they are additionally mirrored to that dataset repo.
"""
from __future__ import annotations

import copy
import hmac
import io
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path

import streamlit as st
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

# Make the sibling validator module importable regardless of the CWD the app
# is launched from.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))

import validate_compliance_prm as validator

ROOT = Path(__file__).resolve().parents[1]
SOURCE_BUNDLES_PATH = ROOT / "data" / "bundles" / "pilot_bundles_v1.jsonl"
GUIDELINE_PATH = ROOT / "data" / "docs" / "pilot_annotation_guideline_v1.md"
ANNOTATIONS_DIR = ROOT / "data" / "annotations"

# The six calibration bundles for this round, in the order they are offered
# in the bundle selector.
TARGET_BUNDLE_IDS = [
    "A17_CN_BUNDLE",
    "A17_US_BUNDLE",
    "A17_ISLAMIC_BUNDLE",
    "M29_CN_BUNDLE",
    "M29_US_BUNDLE",
    "M29_ISLAMIC_BUNDLE",
]
TRACE_LABELS = ["compliant", "deadline_missed", "hard_violation"]
STATUSES = ["in_progress", "final"]


def load_jsonl(path: Path) -> list[dict]:
    """Parse one JSON object per non-blank line of *path*."""
    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def load_source_bundles() -> dict[str, dict]:
    """Load the calibration bundles keyed by id, in TARGET_BUNDLE_IDS order.

    Raises:
        ValueError: if any id in TARGET_BUNDLE_IDS is absent from the source
            JSONL file (previously this surfaced as a bare ``KeyError``).
    """
    bundles = {
        bundle["bundle_id"]: bundle
        for bundle in load_jsonl(SOURCE_BUNDLES_PATH)
        if bundle["bundle_id"] in TARGET_BUNDLE_IDS
    }
    missing = [bundle_id for bundle_id in TARGET_BUNDLE_IDS if bundle_id not in bundles]
    if missing:
        raise ValueError(
            f"Missing bundle ids in {SOURCE_BUNDLES_PATH}: {', '.join(missing)}"
        )
    # Re-key to fix iteration order to the display order.
    return {bundle_id: bundles[bundle_id] for bundle_id in TARGET_BUNDLE_IDS}


def annotation_path(annotator_id: str, bundle_id: str) -> Path:
    """Local file path for one annotator's copy of one bundle."""
    return ANNOTATIONS_DIR / annotator_id / f"{bundle_id}.json"


def dataset_repo_id() -> str:
    """Hugging Face dataset repo id, or "" when remote sync is not configured."""
    return os.getenv("HF_DATASET_REPO", "").strip()


def dataset_repo_subdir() -> str:
    """Subdirectory inside the dataset repo; defaults to "annotations"."""
    return os.getenv("HF_DATASET_SUBDIR", "annotations").strip().strip("/") or "annotations"


def hf_token() -> str:
    """First non-empty Hugging Face token from the supported env vars, or ""."""
    for key in ("HF_TOKEN", "HUGGINGFACEHUB_API_TOKEN"):
        value = os.getenv(key, "").strip()
        if value:
            return value
    return ""


def storage_backend() -> str:
    """Return "hf_dataset" when both repo id and token are set, else "local"."""
    if dataset_repo_id() and hf_token():
        return "hf_dataset"
    return "local"


def dataset_repo_path(annotator_id: str, bundle_id: str) -> str:
    """Path of the annotation file inside the dataset repo."""
    return f"{dataset_repo_subdir()}/{annotator_id}/{bundle_id}.json"


def build_initial_annotation(bundle: dict, annotator_id: str) -> dict:
    """Seed a fresh, editable annotation from the machine-generated bundle."""
    annotation = copy.deepcopy(bundle)
    annotation["annotator_id"] = annotator_id
    annotation["status"] = "in_progress"
    annotation["updated_at"] = None
    annotation["change_notes"] = ""
    return annotation


def save_local_annotation(payload: dict) -> Path:
    """Write *payload* to its local annotation path, creating parent dirs."""
    path = annotation_path(payload["annotator_id"], payload["bundle_id"])
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False)
    return path


def load_remote_annotation(bundle: dict, annotator_id: str) -> dict | None:
    """Best-effort fetch of a previously synced annotation from the HF repo.

    Returns None both when no remote copy exists yet and when the download
    fails for any other reason (network/auth); the caller then falls back to
    a fresh local annotation instead of blocking the annotator.
    """
    try:
        downloaded_path = hf_hub_download(
            repo_id=dataset_repo_id(),
            filename=dataset_repo_path(annotator_id, bundle["bundle_id"]),
            repo_type="dataset",
            token=hf_token(),
        )
    except EntryNotFoundError:
        # Expected for the first visit to a bundle: nothing synced yet.
        return None
    except Exception:
        # Deliberate best-effort fallback — do not crash the UI on transient
        # network or auth errors; the local copy remains authoritative.
        return None
    with Path(downloaded_path).open("r", encoding="utf-8") as handle:
        return json.load(handle)


def save_remote_annotation(payload: dict) -> str:
    """Upload *payload* to the configured HF dataset repo; return its URI.

    Creates the (private) dataset repo on first use.
    """
    repo_id = dataset_repo_id()
    api = HfApi(token=hf_token())
    api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
    repo_path = dataset_repo_path(payload["annotator_id"], payload["bundle_id"])
    payload_bytes = json.dumps(payload, indent=2, ensure_ascii=False).encode("utf-8")
    api.upload_file(
        path_or_fileobj=io.BytesIO(payload_bytes),
        path_in_repo=repo_path,
        repo_id=repo_id,
        repo_type="dataset",
        commit_message=f"Update annotation: {payload['bundle_id']} ({payload['annotator_id']})",
    )
    return f"hf://datasets/{repo_id}/{repo_path}"


def load_annotation(bundle: dict, annotator_id: str) -> dict:
    """Load the working annotation: local file, then remote mirror, then fresh."""
    path = annotation_path(annotator_id, bundle["bundle_id"])
    if path.exists():
        with path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    if storage_backend() == "hf_dataset":
        remote = load_remote_annotation(bundle, annotator_id)
        if remote is not None:
            return remote
    return build_initial_annotation(bundle, annotator_id)


def save_annotation(annotation: dict) -> str:
    """Persist *annotation* (stamped with a UTC timestamp); return where it went.

    Always writes the local file first, so a failed remote upload still leaves
    a durable local copy.
    """
    payload = copy.deepcopy(annotation)
    payload["updated_at"] = datetime.now(timezone.utc).isoformat()
    local_path = save_local_annotation(payload)
    if storage_backend() == "hf_dataset":
        remote_path = save_remote_annotation(payload)
        return f"{remote_path} (local mirror: {local_path})"
    return str(local_path)


def require_password() -> None:
    """Gate the app behind ANNOTATION_APP_PASSWORD when that env var is set.

    No-op when no password is configured or the session is already unlocked;
    otherwise renders the login form and halts the script run.
    """
    expected_password = os.getenv("ANNOTATION_APP_PASSWORD", "").strip()
    if not expected_password:
        return
    if st.session_state.get("authenticated"):
        return
    st.title("CPRM Annotation App")
    st.caption("This instance is password-protected.")
    typed_password = st.text_input("Shared Password", type="password")
    if st.button("Unlock"):
        # Constant-time comparison avoids leaking the password via timing.
        if hmac.compare_digest(typed_password, expected_password):
            st.session_state["authenticated"] = True
            st.rerun()
        st.error("Incorrect password.")
    st.stop()


def read_guideline() -> str:
    """Return the guideline markdown, or a pointer on how to generate it."""
    if GUIDELINE_PATH.exists():
        return GUIDELINE_PATH.read_text(encoding="utf-8")
    return "Guideline file not found. Generate `data/docs/pilot_annotation_guideline_v1.md` first."
def reset_guideline_gate() -> None:
    """Clear the guideline acknowledgement so the gate is shown again."""
    st.session_state["guideline_acknowledged"] = False
    st.session_state["guideline_confirmed_for"] = None


def render_guideline_gate() -> None:
    """Step 1 screen: collect an annotator ID and require guideline sign-off.

    Always ends the script run (via st.rerun on confirm, or st.stop), so
    callers can treat it as terminal for the current rerun.
    """
    st.title("CPRM Pilot Annotation App")
    st.caption("Step 1 of 2: read the guideline, confirm it, then enter the annotation workspace.")
    annotator_id = st.text_input(
        "Annotator ID",
        value=st.session_state.get("annotator_id", "solo_annotator"),
        help="Use a stable ID so saved files go to a consistent annotation folder.",
    ).strip()
    st.session_state["annotator_id"] = annotator_id
    if not annotator_id:
        st.info("Enter an annotator ID before continuing.")
        st.stop()
    st.subheader("Guideline")
    st.markdown(read_guideline())
    acknowledged = st.checkbox(
        "I have read the guideline and I understand that round 1 only edits existing fields and does not change step count.",
        value=False,
        key="guideline_ack_checkbox",
    )
    # Button stays disabled until the checkbox is ticked.
    if st.button("Enter Annotation Workspace", type="primary", disabled=not acknowledged):
        st.session_state["guideline_acknowledged"] = True
        # Remember which annotator confirmed, so changing the ID re-gates.
        st.session_state["guideline_confirmed_for"] = annotator_id
        st.rerun()
    st.stop()


def ensure_working_annotation(source_bundle: dict, annotator_id: str, bundle_id: str) -> dict:
    """Return a deep copy of the session's working annotation for this bundle.

    Reloads from storage only when the (annotator, bundle) pair changes, so
    in-progress edits survive Streamlit reruns.
    """
    state_key = "working_bundle_key"
    target_key = f"{annotator_id}:{bundle_id}"
    if st.session_state.get(state_key) != target_key:
        st.session_state[state_key] = target_key
        st.session_state["working_bundle"] = load_annotation(source_bundle, annotator_id)
    # Deep copy so widget edits never mutate the cached session copy directly.
    return copy.deepcopy(st.session_state["working_bundle"])


def step_key(bundle_id: str, trace_id: str, step_id: int, field: str, suffix: str = "") -> str:
    """Unique Streamlit widget key for a per-step field (optional suffix)."""
    extra = f":{suffix}" if suffix else ""
    return f"{bundle_id}:{trace_id}:{step_id}:{field}{extra}"


def trace_key(bundle_id: str, trace_id: str, field: str) -> str:
    """Unique Streamlit widget key for a per-trace field."""
    return f"{bundle_id}:{trace_id}:{field}"


def get_rule_options(bundle: dict) -> list[str]:
    """Collect all rule ids referenced anywhere in the bundle, de-duplicated.

    Order: rulebook first, then first appearance across candidate steps
    (active rules, violated rule, soft-coverage keys).
    """
    seen: set[str] = set()
    ordered_rule_ids: list[str] = []

    def add_rule(rule_id: str | None) -> None:
        # Skip None/empty and already-seen ids; otherwise keep insertion order.
        if not rule_id or rule_id in seen:
            return
        seen.add(rule_id)
        ordered_rule_ids.append(rule_id)

    for rule_id in bundle["rulebook"]:
        add_rule(rule_id)
    for candidate in bundle["candidates"]:
        for step in candidate["steps"]:
            for rule_id in step["active_rule_ids"]:
                add_rule(rule_id)
            add_rule(step["violated_rule_id"])
            for rule_id in step["soft_coverage_delta"]:
                add_rule(rule_id)
    return ordered_rule_ids


def render_metadata(bundle: dict, annotator_id: str) -> tuple[str, str]:
    """Render the sidebar (bundle info, status, notes); return (status, notes)."""
    with st.sidebar:
        if st.button("Back To Guideline"):
            reset_guideline_gate()
            st.rerun()
        st.header("Bundle")
        st.write(f"`{bundle['bundle_id']}`")
        st.write(f"Annotator: `{annotator_id}`")
        st.write(f"Jurisdiction: `{bundle['jurisdiction']}`")
        st.write(f"Mode: `{bundle['mode']}`")
        st.write(f"Storage backend: `{storage_backend()}`")
        if storage_backend() == "hf_dataset":
            st.write(f"Dataset repo: `{dataset_repo_id()}`")
        st.write("Rulebook:")
        for rule_id in bundle["rulebook"]:
            st.code(rule_id)
        status = st.selectbox(
            "Bundle Status",
            options=STATUSES,
            index=STATUSES.index(bundle.get("status", "in_progress")),
            key=f"{bundle['bundle_id']}:status",
        )
        change_notes = st.text_area(
            "Change Notes",
            value=bundle.get("change_notes", ""),
            height=160,
            key=f"{bundle['bundle_id']}:change_notes",
            help="Short note on what changed from the machine-generated version.",
        )
        with st.expander("Guideline", expanded=False):
            st.markdown(read_guideline())
    return status, change_notes


def render_step_editor(bundle: dict, trace: dict, step: dict, rule_options: list[str]) -> dict:
    """Render editable widgets for one step; return the edited step dict.

    Step id and text are fixed by design (round 1 never changes step count).
    """
    bundle_id = bundle["bundle_id"]
    trace_id = trace["trace_id"]
    step_id = step["step_id"]
    st.markdown(f"**Step {step_id}:** `{step['text']}`")
    action_type = st.selectbox(
        f"Action Type ({step_id})",
        options=sorted(validator.ALLOWED_ACTION_TYPES),
        index=sorted(validator.ALLOWED_ACTION_TYPES).index(step["action_type"]),
        key=step_key(bundle_id, trace_id, step_id, "action_type"),
    )
    active_rule_ids = st.multiselect(
        f"Active Rule IDs ({step_id})",
        options=rule_options,
        default=step["active_rule_ids"],
        key=step_key(bundle_id, trace_id, step_id, "active_rule_ids"),
    )
    hard_violation = st.checkbox(
        f"Hard Violation ({step_id})",
        value=bool(step["hard_violation"]),
        key=step_key(bundle_id, trace_id, step_id, "hard_violation"),
    )
    # None is a valid choice ("no rule violated"); rendered as "None".
    violated_rule_id = st.selectbox(
        f"Violated Rule ID ({step_id})",
        options=[None] + rule_options,
        index=([None] + rule_options).index(step["violated_rule_id"]),
        key=step_key(bundle_id, trace_id, step_id, "violated_rule_id"),
        format_func=lambda value: "None" if value is None else value,
    )
    st.caption("Soft Coverage Delta")
    soft_coverage_delta: dict[str, float] = {}
    # One number input per rule, laid out side by side (at least one column).
    columns = st.columns(len(rule_options) or 1)
    for index, rule_id in enumerate(rule_options):
        default_value = float(step["soft_coverage_delta"].get(rule_id, 0.0))
        with columns[index]:
            value = st.number_input(
                rule_id,
                min_value=0.0,
                max_value=1.0,
                value=default_value,
                step=0.05,
                key=step_key(bundle_id, trace_id, step_id, "soft_delta", rule_id),
            )
        # Zero entries are dropped so the delta dict stays sparse.
        if value > 0:
            soft_coverage_delta[rule_id] = round(float(value), 2)
    return {
        "step_id": step_id,
        "action_type": action_type,
        "text": step["text"],
        "active_rule_ids": active_rule_ids,
        "hard_violation": int(hard_violation),
        "violated_rule_id": violated_rule_id,
        "soft_coverage_delta": soft_coverage_delta,
    }


def render_trace_editor(bundle: dict, trace: dict) -> dict:
    """Render trace-level widgets plus every step editor; return edited trace."""
    bundle_id = bundle["bundle_id"]
    trace_id = trace["trace_id"]
    rule_options = get_rule_options(bundle)
    label = st.selectbox(
        "Trace Label",
        options=TRACE_LABELS,
        index=TRACE_LABELS.index(trace["label"]),
        key=trace_key(bundle_id, trace_id, "label"),
    )
    overall_compliant = st.checkbox(
        "Overall Compliant",
        value=bool(trace["overall_compliant"]),
        key=trace_key(bundle_id, trace_id, "overall_compliant"),
    )
    step_ids = [step["step_id"] for step in trace["steps"]]
    # None means "no violation in this trace".
    first_violation_step = st.selectbox(
        "First Violation Step",
        options=[None] + step_ids,
        index=([None] + step_ids).index(trace["first_violation_step"]),
        key=trace_key(bundle_id, trace_id, "first_violation_step"),
        format_func=lambda value: "None" if value is None else f"Step {value}",
    )
    edited_steps = []
    for step in trace["steps"]:
        with st.container(border=True):
            edited_steps.append(render_step_editor(bundle, trace, step, rule_options))
    edited_trace = copy.deepcopy(trace)
    edited_trace["label"] = label
    edited_trace["overall_compliant"] = overall_compliant
    edited_trace["first_violation_step"] = first_violation_step
    edited_trace["steps"] = edited_steps
    return edited_trace


def render_bundle_editor(bundle: dict) -> dict:
    """Render one tab per candidate trace; return the bundle with edits applied."""
    tabs = st.tabs([candidate["trace_id"] for candidate in bundle["candidates"]])
    edited_candidates = []
    for tab, candidate in zip(tabs, bundle["candidates"]):
        with tab:
            edited_candidates.append(render_trace_editor(bundle, candidate))
    edited_bundle = copy.deepcopy(bundle)
    edited_bundle["candidates"] = edited_candidates
    return edited_bundle


def render_validation_panel(bundle: dict, valid_rule_ids: set[str]) -> None:
    """Run the bundle validator and surface its errors/warnings inline."""
    result = validator.validate_single_bundle(bundle, valid_rule_ids)
    with st.expander("Validation", expanded=True):
        st.write(
            {
                "ok": result["ok"],
                "errors": len(result["errors"]),
                "warnings": len(result["warnings"]),
            }
        )
        if result["errors"]:
            st.error("\n".join(result["errors"]))
        if result["warnings"]:
            st.warning("\n".join(result["warnings"]))
        if not result["errors"] and not result["warnings"]:
            st.success("No validation issues detected.")


def main() -> None:
    """App entry point: auth gate, guideline gate, then the annotation workspace."""
    st.set_page_config(page_title="CPRM Annotation App", layout="wide")
    require_password()
    current_annotator = st.session_state.get("annotator_id", "").strip()
    # Re-show the guideline gate if it was never confirmed, or if the
    # annotator ID changed since confirmation. render_guideline_gate()
    # ends the script run, so no explicit return is needed here.
    if (
        not st.session_state.get("guideline_acknowledged")
        or st.session_state.get("guideline_confirmed_for") != current_annotator
    ):
        render_guideline_gate()
    source_bundles = load_source_bundles()
    valid_rule_ids = validator.load_rule_ids(validator.RULE_CARDS_PATH)
    st.title("CPRM Pilot Annotation App")
    st.caption(
        "Step 2 of 2: annotate one of the 6 calibration bundles. Existing steps are editable, but step count is fixed."
    )
    bundle_id = st.selectbox("Bundle", options=TARGET_BUNDLE_IDS)
    source_bundle = source_bundles[bundle_id]
    working_bundle = ensure_working_annotation(source_bundle, current_annotator, bundle_id)
    status, change_notes = render_metadata(working_bundle, current_annotator)
    left, right = st.columns([3, 2])
    with left:
        edited_bundle = render_bundle_editor(working_bundle)
    with right:
        # Read-only scenario context from the pristine source bundle.
        st.subheader("Scenario")
        st.json(
            {
                "bundle_id": source_bundle["bundle_id"],
                "scenario_id": source_bundle["scenario_id"],
                "intent_id": source_bundle["intent_id"],
                "jurisdiction": source_bundle["jurisdiction"],
                "mode": source_bundle["mode"],
                "rulebook": source_bundle["rulebook"],
            },
            expanded=False,
        )
    edited_bundle["annotator_id"] = current_annotator
    edited_bundle["status"] = status
    edited_bundle["change_notes"] = change_notes
    # Preserve the last-saved timestamp; save_annotation() stamps a new one.
    edited_bundle["updated_at"] = working_bundle.get("updated_at")
    render_validation_panel(edited_bundle, valid_rule_ids)
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Save Annotation", type="primary"):
            saved_path = save_annotation(edited_bundle)
            # Sync the session copy so a rerun shows the saved state.
            st.session_state["working_bundle"] = copy.deepcopy(edited_bundle)
            st.success(f"Saved to {saved_path}")
    with col2:
        st.download_button(
            "Download JSON",
            data=json.dumps(edited_bundle, indent=2, ensure_ascii=False),
            file_name=f"{edited_bundle['bundle_id']}.json",
            mime="application/json",
        )


if __name__ == "__main__":
    main()