File size: 16,863 Bytes
7aceaa5
 
 
 
 
3f6201e
7aceaa5
 
 
 
 
 
 
3f6201e
 
7aceaa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f6201e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7aceaa5
 
 
 
 
 
 
 
 
3f6201e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7aceaa5
 
 
 
 
3f6201e
 
 
 
7aceaa5
 
 
3f6201e
7aceaa5
 
3f6201e
 
 
 
 
7aceaa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f6201e
7aceaa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f6201e
 
 
7aceaa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f6201e
 
 
7aceaa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
#!/usr/bin/env python3

from __future__ import annotations

import copy
import io
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path

import streamlit as st
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

# Directory containing this script. It is prepended to sys.path (once) before
# the `validate_compliance_prm` import that follows — presumably so that module,
# expected to live next to this file, resolves regardless of the launch directory.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))

import validate_compliance_prm as validator

# Two levels up from this file, i.e. the parent of the script directory.
ROOT = Path(__file__).resolve().parents[1]
# JSONL file holding the source bundles that annotations start from.
SOURCE_BUNDLES_PATH = ROOT / "data" / "bundles" / "pilot_bundles_v1.jsonl"
# Markdown guideline shown on the gate page and in the sidebar expander.
GUIDELINE_PATH = ROOT / "data" / "docs" / "pilot_annotation_guideline_v1.md"
# Local annotations are stored under <ANNOTATIONS_DIR>/<annotator_id>/<bundle_id>.json.
ANNOTATIONS_DIR = ROOT / "data" / "annotations"

# Bundle IDs offered in the bundle selector, in display order. Every ID listed
# here must exist in SOURCE_BUNDLES_PATH (load_source_bundles indexes by each one).
TARGET_BUNDLE_IDS = [
    "A17_CN_BUNDLE",
    "A17_US_BUNDLE",
    "A17_ISLAMIC_BUNDLE",
    "M29_CN_BUNDLE",
    "M29_US_BUNDLE",
    "M29_ISLAMIC_BUNDLE",
]
# Options for the per-trace "Trace Label" selector.
TRACE_LABELS = ["compliant", "deadline_missed", "hard_violation"]
# Options for the "Bundle Status" selector; new annotations start as "in_progress".
STATUSES = ["in_progress", "final"]


def load_jsonl(path: Path) -> list[dict]:
    """Parse a JSON Lines file into a list of decoded records.

    Lines that are empty or contain only whitespace are skipped; every other
    line is decoded with ``json.loads`` in file order.
    """
    records: list[dict] = []
    with path.open("r", encoding="utf-8") as stream:
        for raw_line in stream:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records


def load_source_bundles() -> dict[str, dict]:
    """Load the target bundles from the source JSONL file, keyed by bundle ID.

    The returned mapping follows the order of ``TARGET_BUNDLE_IDS``. If the
    file repeats a bundle ID, the last occurrence is kept. A ``KeyError`` is
    raised when a target ID is missing from the file.
    """
    found: dict[str, dict] = {}
    for record in load_jsonl(SOURCE_BUNDLES_PATH):
        record_id = record["bundle_id"]
        if record_id in TARGET_BUNDLE_IDS:
            found[record_id] = record

    ordered: dict[str, dict] = {}
    for wanted_id in TARGET_BUNDLE_IDS:
        ordered[wanted_id] = found[wanted_id]
    return ordered


def annotation_path(annotator_id: str, bundle_id: str) -> Path:
    """Return the local JSON path for one annotator's copy of a bundle.

    The layout is ``ANNOTATIONS_DIR/<annotator_id>/<bundle_id>.json``.
    """
    annotator_dir = ANNOTATIONS_DIR / annotator_id
    file_name = f"{bundle_id}.json"
    return annotator_dir / file_name


def dataset_repo_id() -> str:
    """Return the ``HF_DATASET_REPO`` environment value with whitespace stripped.

    An unset variable yields the empty string.
    """
    raw_value = os.environ.get("HF_DATASET_REPO", "")
    return raw_value.strip()


def dataset_repo_subdir() -> str:
    """Return the folder inside the dataset repo that holds annotations.

    Reads ``HF_DATASET_SUBDIR`` (default ``"annotations"``), strips surrounding
    whitespace and then leading/trailing slashes. If nothing is left, the
    default ``"annotations"`` is returned.
    """
    configured = os.environ.get("HF_DATASET_SUBDIR", "annotations")
    trimmed = configured.strip().strip("/")
    if trimmed:
        return trimmed
    return "annotations"


def hf_token() -> str:
    """Return the first non-blank Hugging Face token found in the environment.

    ``HF_TOKEN`` is checked before ``HUGGINGFACEHUB_API_TOKEN``; values are
    whitespace-stripped. Returns the empty string when neither is set.
    """
    stripped_values = (
        os.environ.get(name, "").strip()
        for name in ("HF_TOKEN", "HUGGINGFACEHUB_API_TOKEN")
    )
    return next((token for token in stripped_values if token), "")


def storage_backend() -> str:
    """Name the active storage backend.

    ``"hf_dataset"`` is used only when both a dataset repo ID and a token are
    configured; otherwise annotations are stored locally (``"local"``).
    """
    return "hf_dataset" if dataset_repo_id() and hf_token() else "local"


def dataset_repo_path(annotator_id: str, bundle_id: str) -> str:
    """Return the in-repo path of an annotation file in the dataset repo.

    The layout is ``<subdir>/<annotator_id>/<bundle_id>.json``.
    """
    segments = (dataset_repo_subdir(), f"{annotator_id}", f"{bundle_id}.json")
    return "/".join(segments)


def build_initial_annotation(bundle: dict, annotator_id: str) -> dict:
    """Create a fresh, unsaved annotation from a source bundle.

    The bundle is deep-copied (the input is not modified) and the bookkeeping
    fields are set: the annotator ID, status ``"in_progress"``, no
    ``updated_at`` timestamp yet, and empty change notes.
    """
    fresh = copy.deepcopy(bundle)
    fresh.update(
        {
            "annotator_id": annotator_id,
            "status": "in_progress",
            "updated_at": None,
            "change_notes": "",
        }
    )
    return fresh


def save_local_annotation(payload: dict) -> Path:
    """Write an annotation to its local JSON file and return that path.

    The destination is derived from ``payload["annotator_id"]`` and
    ``payload["bundle_id"]``; parent directories are created as needed.

    The payload is serialized before anything is written, and the text goes to
    a sibling ``*.tmp`` file that then replaces the destination. A payload
    that cannot be serialized, or a write that fails part-way, therefore
    leaves any previously saved annotation intact instead of truncating it.

    Raises:
        KeyError: if the payload lacks ``annotator_id`` or ``bundle_id``.
        TypeError / ValueError: if the payload is not JSON-serializable.
        OSError: if the file cannot be written.
    """
    path = annotation_path(payload["annotator_id"], payload["bundle_id"])
    path.parent.mkdir(parents=True, exist_ok=True)

    # Serialize first: a failure here must not touch the existing file.
    serialized = json.dumps(payload, indent=2, ensure_ascii=False)

    temp_path = path.with_name(f"{path.name}.tmp")
    try:
        temp_path.write_text(serialized, encoding="utf-8")
        # Same-directory rename, so the destination is swapped in one step.
        temp_path.replace(path)
    except OSError:
        temp_path.unlink(missing_ok=True)
        raise
    return path


def load_remote_annotation(bundle: dict, annotator_id: str) -> dict | None:
    """Fetch this annotator's saved copy of a bundle from the dataset repo.

    Returns the decoded annotation, or ``None`` when it could not be fetched.
    A missing file (``EntryNotFoundError``) is the normal "nothing saved yet"
    case and is silent. Any other download failure (authentication, network,
    missing repo, ...) is still treated as best-effort and returns ``None``,
    but it is logged and shown as a warning so the annotator knows the
    workspace did not start from the remote copy.

    Raises:
        KeyError: if ``bundle`` has no ``bundle_id``.
        json.JSONDecodeError: if the downloaded file is not valid JSON.
    """
    import logging

    repo_file = dataset_repo_path(annotator_id, bundle["bundle_id"])
    try:
        downloaded_path = hf_hub_download(
            repo_id=dataset_repo_id(),
            filename=repo_file,
            repo_type="dataset",
            token=hf_token(),
        )
    except EntryNotFoundError:
        # Nothing has been saved remotely for this annotator/bundle yet.
        return None
    except Exception as error:
        # Best-effort load: do not crash the app, but do not hide the failure.
        logging.getLogger(__name__).warning(
            "Could not download remote annotation %s", repo_file, exc_info=True
        )
        st.warning(
            f"Could not load the remote annotation `{repo_file}` "
            f"({type(error).__name__}). Continuing without it; saving will "
            "overwrite any remote copy."
        )
        return None

    with Path(downloaded_path).open("r", encoding="utf-8") as handle:
        return json.load(handle)


def save_remote_annotation(payload: dict) -> str:
    """Upload an annotation to the dataset repo and return its ``hf://`` URI.

    The (private) dataset repo is created if it does not exist yet, then the
    payload is uploaded as pretty-printed UTF-8 JSON at the path given by
    ``dataset_repo_path`` for the payload's annotator and bundle IDs.
    """
    target_repo = dataset_repo_id()
    client = HfApi(token=hf_token())
    client.create_repo(repo_id=target_repo, repo_type="dataset", exist_ok=True, private=True)

    annotator_id = payload["annotator_id"]
    bundle_id = payload["bundle_id"]
    file_in_repo = dataset_repo_path(annotator_id, bundle_id)
    encoded = json.dumps(payload, indent=2, ensure_ascii=False).encode("utf-8")
    client.upload_file(
        path_or_fileobj=io.BytesIO(encoded),
        path_in_repo=file_in_repo,
        repo_id=target_repo,
        repo_type="dataset",
        commit_message=f"Update annotation: {bundle_id} ({annotator_id})",
    )
    return f"hf://datasets/{target_repo}/{file_in_repo}"


def load_annotation(bundle: dict, annotator_id: str) -> dict:
    """Return the annotation to start editing for this annotator and bundle.

    Lookup order: the local annotation file if it exists; otherwise, when the
    ``hf_dataset`` backend is active, the remote copy; otherwise a fresh
    annotation built from the source bundle.
    """
    local_file = annotation_path(annotator_id, bundle["bundle_id"])
    if local_file.exists():
        return json.loads(local_file.read_text(encoding="utf-8"))

    remote_copy = None
    if storage_backend() == "hf_dataset":
        remote_copy = load_remote_annotation(bundle, annotator_id)
    if remote_copy is None:
        return build_initial_annotation(bundle, annotator_id)
    return remote_copy


def save_annotation(annotation: dict) -> str:
    """Persist an annotation and return a description of where it was saved.

    A deep copy of the annotation is stamped with the current UTC time in
    ``updated_at`` (the caller's dict is left untouched) and written locally.
    With the ``hf_dataset`` backend it is also uploaded, and the returned
    text names both the remote URI and the local mirror.
    """
    stamped = copy.deepcopy(annotation)
    stamped["updated_at"] = datetime.now(timezone.utc).isoformat()
    mirror_path = save_local_annotation(stamped)
    if storage_backend() != "hf_dataset":
        return str(mirror_path)
    remote_uri = save_remote_annotation(stamped)
    return f"{remote_uri} (local mirror: {mirror_path})"


def require_password() -> None:
    """Block the app behind a shared password when one is configured.

    Returns immediately when ``ANNOTATION_APP_PASSWORD`` is unset/blank or the
    session is already authenticated. Otherwise the unlock form is rendered
    and the script run is halted with ``st.stop()``; a correct password marks
    the session as authenticated and triggers a rerun.
    """
    import hmac

    expected_password = os.getenv("ANNOTATION_APP_PASSWORD", "").strip()
    if not expected_password:
        return

    if st.session_state.get("authenticated"):
        return

    st.title("CPRM Annotation App")
    st.caption("This instance is password-protected.")
    typed_password = st.text_input("Shared Password", type="password")
    if st.button("Unlock"):
        # Constant-time comparison; compare bytes because compare_digest
        # rejects non-ASCII str arguments.
        passwords_match = hmac.compare_digest(
            typed_password.encode("utf-8"),
            expected_password.encode("utf-8"),
        )
        if passwords_match:
            st.session_state["authenticated"] = True
            st.rerun()
        st.error("Incorrect password.")
    st.stop()


def read_guideline() -> str:
    """Return the guideline markdown, or a placeholder message if it is missing."""
    if not GUIDELINE_PATH.exists():
        return "Guideline file not found. Generate `data/docs/pilot_annotation_guideline_v1.md` first."
    return GUIDELINE_PATH.read_text(encoding="utf-8")


def reset_guideline_gate() -> None:
    """Clear the guideline acknowledgement so the gate page is shown again."""
    cleared_flags = (
        ("guideline_acknowledged", False),
        ("guideline_confirmed_for", None),
    )
    for flag_name, flag_value in cleared_flags:
        st.session_state[flag_name] = flag_value


def render_guideline_gate() -> None:
    """Render step 1: collect the annotator ID and the guideline acknowledgement.

    This function never returns normally: every path ends in ``st.stop()`` or,
    once the annotator confirms, ``st.rerun()`` after recording the
    acknowledgement in the session state.

    The annotator ID is used verbatim as a directory name for local annotation
    files and as a path segment in the dataset repo, so IDs containing path
    separators, or equal to ``.`` / ``..``, are rejected to keep saved files
    inside the annotations folder.
    """
    st.title("CPRM Pilot Annotation App")
    st.caption("Step 1 of 2: read the guideline, confirm it, then enter the annotation workspace.")

    annotator_id = st.text_input(
        "Annotator ID",
        value=st.session_state.get("annotator_id", "solo_annotator"),
        help="Use a stable ID so saved files go to a consistent annotation folder.",
    ).strip()
    st.session_state["annotator_id"] = annotator_id

    if not annotator_id:
        st.info("Enter an annotator ID before continuing.")
        st.stop()

    # Guard against path traversal / nested paths built from user input.
    has_separator = any(separator in annotator_id for separator in ("/", "\\"))
    if has_separator or annotator_id in {".", ".."}:
        st.error("Annotator ID cannot contain slashes or backslashes, and cannot be '.' or '..'.")
        st.stop()

    st.subheader("Guideline")
    st.markdown(read_guideline())

    acknowledged = st.checkbox(
        "I have read the guideline and I understand that round 1 only edits existing fields and does not change step count.",
        value=False,
        key="guideline_ack_checkbox",
    )

    if st.button("Enter Annotation Workspace", type="primary", disabled=not acknowledged):
        st.session_state["guideline_acknowledged"] = True
        # Remember which ID confirmed, so a different ID must confirm again.
        st.session_state["guideline_confirmed_for"] = annotator_id
        st.rerun()

    st.stop()


def ensure_working_annotation(source_bundle: dict, annotator_id: str, bundle_id: str) -> dict:
    """Return a private copy of the annotation being edited in this session.

    The annotation is (re)loaded only when the annotator/bundle pair differs
    from the one cached in the session state; otherwise the cached copy is
    reused. The caller always receives a deep copy, so edits do not alter the
    cached bundle until it is explicitly replaced.

    The cache key is recorded only after the load succeeds. If loading raises,
    the next rerun retries instead of serving the previously cached bundle
    under the new key.
    """
    state_key = "working_bundle_key"
    target_key = f"{annotator_id}:{bundle_id}"
    if st.session_state.get(state_key) != target_key:
        loaded = load_annotation(source_bundle, annotator_id)
        st.session_state["working_bundle"] = loaded
        st.session_state[state_key] = target_key
    return copy.deepcopy(st.session_state["working_bundle"])


def step_key(bundle_id: str, trace_id: str, step_id: int, field: str, suffix: str = "") -> str:
    """Build the colon-separated widget key for a step-level field.

    The key is ``bundle:trace:step:field``, with ``:suffix`` appended only
    when a non-empty suffix is given.
    """
    parts = [bundle_id, trace_id, step_id, field]
    if suffix:
        parts.append(suffix)
    return ":".join(f"{part}" for part in parts)


def trace_key(bundle_id: str, trace_id: str, field: str) -> str:
    """Build the colon-separated widget key for a trace-level field."""
    return ":".join(f"{part}" for part in (bundle_id, trace_id, field))


def get_rule_options(bundle: dict) -> list[str]:
    """Collect every rule ID mentioned in a bundle, de-duplicated, in first-seen order.

    Rule IDs are taken first from the bundle's rulebook, then, for each step
    of each candidate, from its active rule IDs, its violated rule ID, and
    the keys of its soft coverage delta. Empty or ``None`` IDs are ignored.
    """

    def mentioned_rule_ids():
        # Yields IDs in the order they should appear, duplicates included.
        yield from bundle["rulebook"]
        for candidate in bundle["candidates"]:
            for step in candidate["steps"]:
                yield from step["active_rule_ids"]
                yield step["violated_rule_id"]
                yield from step["soft_coverage_delta"]

    # dict.fromkeys keeps the first occurrence of each key in insertion order.
    unique_in_order = dict.fromkeys(
        rule_id for rule_id in mentioned_rule_ids() if rule_id
    )
    return list(unique_in_order)


def render_metadata(bundle: dict, annotator_id: str) -> tuple[str, str]:
    """Draw the sidebar for the current bundle and return its editable metadata.

    The sidebar offers a way back to the guideline gate, shows read-only
    bundle and storage details, and hosts the status selector, the change
    notes box and a collapsible copy of the guideline.

    Returns:
        ``(status, change_notes)`` as currently selected / typed.
    """
    sidebar = st.sidebar
    with sidebar:
        go_back = st.button("Back To Guideline")
        if go_back:
            reset_guideline_gate()
            st.rerun()

        st.header("Bundle")
        bundle_id = bundle["bundle_id"]
        st.write(f"`{bundle_id}`")
        st.write(f"Annotator: `{annotator_id}`")
        st.write(f"Jurisdiction: `{bundle['jurisdiction']}`")
        st.write(f"Mode: `{bundle['mode']}`")

        backend = storage_backend()
        st.write(f"Storage backend: `{backend}`")
        if backend == "hf_dataset":
            st.write(f"Dataset repo: `{dataset_repo_id()}`")

        st.write("Rulebook:")
        for rule_id in bundle["rulebook"]:
            st.code(rule_id)

        # Pre-select whatever status the loaded annotation carries.
        saved_status = bundle.get("status", "in_progress")
        status = st.selectbox(
            "Bundle Status",
            options=STATUSES,
            index=STATUSES.index(saved_status),
            key=f"{bundle_id}:status",
        )

        saved_notes = bundle.get("change_notes", "")
        change_notes = st.text_area(
            "Change Notes",
            value=saved_notes,
            height=160,
            key=f"{bundle_id}:change_notes",
            help="Short note on what changed from the machine-generated version.",
        )

        guideline_box = st.expander("Guideline", expanded=False)
        with guideline_box:
            st.markdown(read_guideline())

    return status, change_notes


def render_step_editor(bundle: dict, trace: dict, step: dict, rule_options: list[str]) -> dict:
    """Render the editing widgets for one step and return the edited step.

    Widgets are pre-filled from ``step``. The step's ID and text are carried
    over unchanged. The returned soft coverage delta contains only rules whose
    entered value is greater than zero, rounded to two decimals.
    """
    step_id = step["step_id"]

    def key_for(field: str, suffix: str = "") -> str:
        # All widget keys of this step share the bundle/trace/step prefix.
        return step_key(bundle["bundle_id"], trace["trace_id"], step_id, field, suffix)

    def show_rule(value):
        return "None" if value is None else value

    st.markdown(f"**Step {step_id}:** `{step['text']}`")

    action_types = sorted(validator.ALLOWED_ACTION_TYPES)
    action_type = st.selectbox(
        f"Action Type ({step_id})",
        options=action_types,
        index=action_types.index(step["action_type"]),
        key=key_for("action_type"),
    )
    active_rule_ids = st.multiselect(
        f"Active Rule IDs ({step_id})",
        options=rule_options,
        default=step["active_rule_ids"],
        key=key_for("active_rule_ids"),
    )
    hard_violation = st.checkbox(
        f"Hard Violation ({step_id})",
        value=bool(step["hard_violation"]),
        key=key_for("hard_violation"),
    )

    # ``None`` is offered first so "no violated rule" can be selected.
    violated_choices = [None] + rule_options
    violated_rule_id = st.selectbox(
        f"Violated Rule ID ({step_id})",
        options=violated_choices,
        index=violated_choices.index(step["violated_rule_id"]),
        key=key_for("violated_rule_id"),
        format_func=show_rule,
    )

    st.caption("Soft Coverage Delta")
    soft_coverage_delta: dict[str, float] = {}
    # One column per rule; a single empty column when there are no rules.
    rule_columns = st.columns(len(rule_options) or 1)
    for column, rule_id in zip(rule_columns, rule_options):
        stored_delta = float(step["soft_coverage_delta"].get(rule_id, 0.0))
        with column:
            entered = st.number_input(
                rule_id,
                min_value=0.0,
                max_value=1.0,
                value=stored_delta,
                step=0.05,
                key=key_for("soft_delta", rule_id),
            )
            if entered > 0:
                soft_coverage_delta[rule_id] = round(float(entered), 2)

    edited_step = {
        "step_id": step_id,
        "action_type": action_type,
        "text": step["text"],
        "active_rule_ids": active_rule_ids,
        "hard_violation": int(hard_violation),
        "violated_rule_id": violated_rule_id,
        "soft_coverage_delta": soft_coverage_delta,
    }
    return edited_step


def render_trace_editor(bundle: dict, trace: dict) -> dict:
    """Render the editor for one trace and return the edited trace.

    Trace-level widgets (label, overall compliance, first violation step) are
    drawn first, followed by one bordered step editor per step. The result is
    a deep copy of ``trace`` with those fields and the steps replaced.
    """
    bundle_id = bundle["bundle_id"]
    trace_id = trace["trace_id"]
    rule_options = get_rule_options(bundle)

    def show_step(value):
        return "None" if value is None else f"Step {value}"

    label = st.selectbox(
        "Trace Label",
        options=TRACE_LABELS,
        index=TRACE_LABELS.index(trace["label"]),
        key=trace_key(bundle_id, trace_id, "label"),
    )
    overall_compliant = st.checkbox(
        "Overall Compliant",
        value=bool(trace["overall_compliant"]),
        key=trace_key(bundle_id, trace_id, "overall_compliant"),
    )

    # ``None`` comes first so "no violation" can be selected.
    violation_choices = [None] + [step["step_id"] for step in trace["steps"]]
    first_violation_step = st.selectbox(
        "First Violation Step",
        options=violation_choices,
        index=violation_choices.index(trace["first_violation_step"]),
        key=trace_key(bundle_id, trace_id, "first_violation_step"),
        format_func=show_step,
    )

    edited_steps = []
    for original_step in trace["steps"]:
        with st.container(border=True):
            edited_step = render_step_editor(bundle, trace, original_step, rule_options)
            edited_steps.append(edited_step)

    edited_trace = copy.deepcopy(trace)
    edited_trace.update(
        {
            "label": label,
            "overall_compliant": overall_compliant,
            "first_violation_step": first_violation_step,
            "steps": edited_steps,
        }
    )
    return edited_trace


def render_bundle_editor(bundle: dict) -> dict:
    """Render one tab per candidate trace and return the edited bundle.

    The result is a deep copy of ``bundle`` whose ``candidates`` list holds
    the edited traces in their original order.
    """
    traces = bundle["candidates"]
    trace_tabs = st.tabs([trace["trace_id"] for trace in traces])

    edited_traces = []
    for trace_tab, trace in zip(trace_tabs, traces):
        with trace_tab:
            edited_trace = render_trace_editor(bundle, trace)
            edited_traces.append(edited_trace)

    result = copy.deepcopy(bundle)
    result["candidates"] = edited_traces
    return result


def render_validation_panel(bundle: dict, valid_rule_ids: set[str]) -> None:
    """Validate the edited bundle and show the outcome in an expanded panel.

    Displays a summary (ok flag, error count, warning count), then the error
    and warning messages if any, or a success note when there are neither.
    """
    result = validator.validate_single_bundle(bundle, valid_rule_ids)
    with st.expander("Validation", expanded=True):
        summary = {
            "ok": result["ok"],
            "errors": len(result["errors"]),
            "warnings": len(result["warnings"]),
        }
        st.write(summary)

        errors = result["errors"]
        if errors:
            st.error("\n".join(errors))

        warnings = result["warnings"]
        if warnings:
            st.warning("\n".join(warnings))

        if not (result["errors"] or result["warnings"]):
            st.success("No validation issues detected.")


def main() -> None:
    """Run one Streamlit script pass of the annotation app.

    Flow: optional password gate, then the guideline gate (until the current
    annotator ID has acknowledged the guideline), then the annotation
    workspace: bundle picker, sidebar metadata, per-trace editors, scenario
    summary, validation panel, and the save / download actions.
    """
    st.set_page_config(page_title="CPRM Annotation App", layout="wide")
    require_password()

    current_annotator = st.session_state.get("annotator_id", "").strip()
    if (
        not st.session_state.get("guideline_acknowledged")
        or st.session_state.get("guideline_confirmed_for") != current_annotator
    ):
        # Does not return: the gate ends the run with st.stop() or st.rerun().
        render_guideline_gate()

    source_bundles = load_source_bundles()
    valid_rule_ids = validator.load_rule_ids(validator.RULE_CARDS_PATH)

    st.title("CPRM Pilot Annotation App")
    st.caption(
        "Step 2 of 2: annotate one of the 6 calibration bundles. Existing steps are editable, but step count is fixed."
    )

    bundle_id = st.selectbox("Bundle", options=TARGET_BUNDLE_IDS)
    source_bundle = source_bundles[bundle_id]
    working_bundle = ensure_working_annotation(source_bundle, current_annotator, bundle_id)
    status, change_notes = render_metadata(working_bundle, current_annotator)

    left, right = st.columns([3, 2])
    with left:
        edited_bundle = render_bundle_editor(working_bundle)
    with right:
        st.subheader("Scenario")
        st.json(
            {
                "bundle_id": source_bundle["bundle_id"],
                "scenario_id": source_bundle["scenario_id"],
                "intent_id": source_bundle["intent_id"],
                "jurisdiction": source_bundle["jurisdiction"],
                "mode": source_bundle["mode"],
                "rulebook": source_bundle["rulebook"],
            },
            expanded=False,
        )

    # Merge the sidebar metadata into the edited bundle; the timestamp shown
    # is the one from the last save (None for a never-saved annotation).
    edited_bundle["annotator_id"] = current_annotator
    edited_bundle["status"] = status
    edited_bundle["change_notes"] = change_notes
    edited_bundle["updated_at"] = working_bundle.get("updated_at")

    render_validation_panel(edited_bundle, valid_rule_ids)

    col1, col2 = st.columns(2)
    with col1:
        if st.button("Save Annotation", type="primary"):
            saved_path = save_annotation(edited_bundle)
            # save_annotation stamps `updated_at` on its own deep copy, so
            # re-read what was just written; otherwise the session copy (and
            # later downloads) would keep the pre-save timestamp.
            st.session_state["working_bundle"] = load_annotation(edited_bundle, current_annotator)
            st.success(f"Saved to {saved_path}")
    with col2:
        st.download_button(
            "Download JSON",
            data=json.dumps(edited_bundle, indent=2, ensure_ascii=False),
            file_name=f"{edited_bundle['bundle_id']}.json",
            mime="application/json",
        )


# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()