Spaces:
Running
Running
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import altair as alt | |
| from datasets import load_from_disk | |
| from huggingface_hub import snapshot_download | |
| import colorsys | |
| import html | |
| import os | |
| import streamlit.components.v1 as components | |
# text utils
# Canonical top-to-bottom row order for the heatmap's y axis:
# argument-type labels (FA/FO/SU prefixes) first, then issue-frame
# (ISS-*) and justification-frame (JUST-*) labels.
LABEL_ORDER = [
    "FA - Factual Argument",
    "FA - Factual Question",
    "FO - Formal Question",
    "FO - Precedent",
    "FO - Systematic Interpretation",
    "FO - Textual Interpretation",
    "SU - Non-Legal Argument",
    "SU - Proportionality Analysis",
    "SU - Substantive Question",
    "SU - Teleological or Purposive Interpretation",
    "Negative Frame (ISS-N)",
    "Positive Frame (ISS-P)",
    "Crime Frame (JUST-C)",
    "Health Frame (JUST-H)",
    "National Security Frame (JUST-S)",
    "Rights Frame (JUST-R)",
]
def concat_global_text(df, webcast_id, text_col="text"):
    """Join all text rows of one document into a single space-separated string.

    Hearing rows are ordered by (segment_id, sequence_id); judgment rows by
    paragraph_id. Missing text values are treated as empty strings.
    """
    subset = df[df["webcast_id"] == webcast_id]
    columns = subset.columns
    if "segment_id" in columns:
        subset = subset.sort_values(["segment_id", "sequence_id"])
    elif "paragraph_id" in columns:
        subset = subset.sort_values("paragraph_id")
    pieces = subset[text_col].fillna("").tolist()
    return " ".join(pieces)
def sanity_check(df_ann, text_len):
    """Return True when every annotation span fits inside the document text.

    An empty annotation frame is vacuously valid. Fix: the original only
    checked the upper bound, so a span with a negative ``global_begin``
    passed silently; both bounds are now validated.
    """
    if df_ann.empty:
        return True
    # Starts must be non-negative and ends must not run past the text.
    return bool(
        df_ann["global_begin"].min() >= 0
        and df_ann["global_end"].max() <= text_len
    )
def get_hf_dataset_root():
    """Download the private HF dataset repo once per session; return its local root.

    Returns None when HF_DATASET_REPO is unset. Halts the Streamlit app when
    the access token is missing or the snapshot download fails.
    """
    repo = os.getenv("HF_DATASET_REPO")
    if not repo:
        return None
    # Reuse the path cached in session state from an earlier rerun.
    root = st.session_state.get("hf_dataset_root")
    if root:
        return root
    token = os.getenv("HF_TOKEN")
    if not token:
        st.error("HF_TOKEN secret missing for private dataset access.")
        st.stop()
    try:
        path = snapshot_download(
            repo_id=repo,
            repo_type="dataset",
            token=token,
            local_dir=os.path.join(os.getcwd(), ".hf_data_cache"),
            local_dir_use_symlinks=False,
        )
    except Exception as exc:
        st.error(f"Failed to load dataset repo: {exc}")
        st.stop()
    st.session_state["hf_dataset_root"] = path
    return path
def resolve_dataset_path(relative_path):
    """Anchor *relative_path* under the HF dataset root, if one is configured."""
    root = get_hf_dataset_root()
    if root:
        return os.path.join(root, relative_path)
    return relative_path
def format_char_count(value):
    """Render a character offset compactly: 42 -> '42', 400 -> '0.4k', 12000 -> '12k'."""
    count = int(value)
    if count >= 10000:
        return f"{int(round(count / 1000.0))}k"
    if count >= 100:
        # One decimal place, with trailing '.0' stripped (e.g. '1.0' -> '1').
        compact = f"{round(count / 1000.0, 1):.1f}".rstrip("0").rstrip(".")
        return compact + "k"
    return str(count)
def normalize_section_title(text):
    """Collapse all internal whitespace to single spaces; falsy input -> 'Unknown'."""
    if not text:
        return "Unknown"
    return " ".join(str(text).split())
def compute_hearing_sections(df_text):
    """Map each hearing segment to its [start, end) range in the concatenated text.

    Offsets mirror concat_global_text: rows are space-joined, so every row
    advances the running cursor by len(text) + 1. Section names come from
    the segment's first row (speaker_name, falling back to speaker_role).
    """
    if df_text is None or df_text.empty:
        return []
    ordered = df_text.sort_values(["segment_id", "sequence_id"])
    guide = []
    offset = 0
    for _, segment in ordered.groupby("segment_id"):
        first = segment.iloc[0]
        speaker = first.get("speaker_name") or first.get("speaker_role") or "Unknown"
        start = offset
        texts = []
        for _, row in segment.iterrows():
            piece = row["text"] or ""
            texts.append(piece)
            offset += len(piece) + 1
        joined = " ".join(texts)
        guide.append({
            "name": normalize_section_title(speaker),
            "start": start,
            "end": start + len(joined),
        })
    return guide
def compute_judgment_sections(df_text):
    """Derive FACTS / LAW / separate-opinion section ranges for a judgment.

    Boundaries are located by heading markers inside paragraph text: the
    LAST paragraph containing "THE FACTS", the LAST subsequent paragraph
    containing "THE LAW", and every later paragraph containing "OPINION".
    Returns [] when either of the first two markers is missing. Offsets
    follow the space-joined layout used by concat_global_text.
    """
    if df_text is None or df_text.empty:
        return []
    ordered = df_text.sort_values("paragraph_id")
    paragraphs = []
    offset = 0
    for _, row in ordered.iterrows():
        body = row["text"] or ""
        paragraphs.append({"text": body, "start": offset, "end": offset + len(body)})
        offset += len(body) + 1
    if not paragraphs:
        return []
    # Keep the LAST occurrence of each heading (matches original behavior).
    facts_idx = None
    for idx, para in enumerate(paragraphs):
        if "THE FACTS" in para["text"]:
            facts_idx = idx
    if facts_idx is None:
        return []
    law_idx = None
    for idx in range(facts_idx + 1, len(paragraphs)):
        if "THE LAW" in paragraphs[idx]["text"]:
            law_idx = idx
    if law_idx is None:
        return []
    opinion_indices = [
        idx for idx in range(law_idx + 1, len(paragraphs))
        if "OPINION" in paragraphs[idx]["text"]
    ]
    doc_end = paragraphs[-1]["end"]
    law_start = paragraphs[law_idx]["start"]
    sections = [
        {
            "name": normalize_section_title(paragraphs[facts_idx]["text"]),
            "start": paragraphs[facts_idx]["start"],
            "end": law_start,
        },
        {
            "name": normalize_section_title(paragraphs[law_idx]["text"]),
            "start": law_start,
            "end": paragraphs[opinion_indices[0]]["start"] if opinion_indices else doc_end,
        },
    ]
    # Each opinion runs until the next opinion heading (or the document end).
    for pos, op_idx in enumerate(opinion_indices):
        next_start = (
            paragraphs[opinion_indices[pos + 1]]["start"]
            if pos + 1 < len(opinion_indices)
            else doc_end
        )
        sections.append({
            "name": normalize_section_title(paragraphs[op_idx]["text"]),
            "start": paragraphs[op_idx]["start"],
            "end": next_start,
        })
    return sections
def render_section_guide(sections, compact_columns=None):
    """Render a styled guide listing each section's name and character range.

    sections: list of {"name", "start", "end"} dicts (character offsets).
    compact_columns: when set, lay sections out in a CSS grid with that many
    columns (used for hearings); otherwise render stacked rows.
    """
    st.markdown("### Section Guide")
    if not sections:
        st.info("No section guide could be generated for this document.")
        return
    if compact_columns:
        # Compact grid layout: one small cell per section.
        cells = []
        for section in sections:
            name = html.escape(str(section["name"]))
            start = format_char_count(section["start"])
            end = format_char_count(section["end"])
            cells.append(
                "<div class='section-guide__cell'>"
                f"<div class='section-guide__name'>{name}</div>"
                f"<div class='section-guide__range'>{start} - {end}</div>"
                "</div>"
            )
        # Inline CSS + grid wrapper; column count comes from compact_columns.
        st.markdown(
            "<style>"
            ".section-guide{border:1px solid #e6e6e6;border-radius:6px;"
            "padding:8px 10px;margin:6px 0 12px;}"
            ".section-guide__grid{display:grid;gap:8px;}"
            ".section-guide__cell{padding:6px 8px;border:1px dashed #eee;"
            "border-radius:6px;background:#fafafa;}"
            ".section-guide__name{font-weight:600;font-size:12px;color:#222;}"
            ".section-guide__range{font-family:monospace;font-size:11px;"
            "color:#333;white-space:nowrap;margin-top:2px;}"
            "</style>"
            f"<div class='section-guide section-guide__grid' "
            f"style='grid-template-columns:repeat({int(compact_columns)}, minmax(0, 1fr));'>"
            + "".join(cells)
            + "</div>",
            unsafe_allow_html=True,
        )
    else:
        # Row layout: section name on the left, character range on the right.
        rows = []
        for section in sections:
            name = html.escape(str(section["name"]))
            start = format_char_count(section["start"])
            end = format_char_count(section["end"])
            rows.append(
                "<div class='section-guide__row'>"
                f"<span class='section-guide__name'>{name}</span>"
                f"<span class='section-guide__range'>{start} - {end}</span>"
                "</div>"
            )
        st.markdown(
            "<style>"
            ".section-guide{border:1px solid #e6e6e6;border-radius:6px;"
            "padding:8px 10px;margin:6px 0 12px;}"
            ".section-guide__row{display:flex;justify-content:space-between;"
            "gap:12px;align-items:baseline;padding:4px 0;"
            "border-bottom:1px dashed #eee;}"
            ".section-guide__row:last-child{border-bottom:none;}"
            ".section-guide__name{font-weight:600;font-size:13px;color:#222;"
            "flex:1 1 auto;}"
            ".section-guide__range{font-family:monospace;font-size:12px;"
            "color:#333;white-space:nowrap;}"
            "</style>"
            "<div class='section-guide'>"
            + "".join(rows)
            + "</div>",
            unsafe_allow_html=True,
        )
# span -> bin coverage
def bin_spans_into_brackets(df_ann, text_len, bin_size):
    """Aggregate annotation spans into fixed-width character bins per label.

    Returns one row per (label, bin) holding the total number of covered
    characters (overlap_len) and the resulting coverage_ratio in [0, 1].
    Span ends are clipped to text_len; empty input yields an empty frame.
    """
    if df_ann.empty:
        return pd.DataFrame()
    rows = []
    for _, ann in df_ann.iterrows():
        begin = int(ann["global_begin"])
        end = int(min(ann["global_end"], text_len))
        if begin >= end:
            continue
        first_bin = begin // bin_size
        last_bin = end // bin_size
        for idx in range(first_bin, last_bin + 1):
            # Intersect the span with this bin (the final bin may be short).
            lo = max(begin, idx * bin_size)
            hi = min(end, min((idx + 1) * bin_size, text_len))
            if lo < hi:
                rows.append({
                    "label": ann["label"],
                    "bin": idx,
                    "overlap_len": hi - lo,
                    "bin_size": bin_size,
                })
        # (spans of length 0 within a bin contribute nothing)
    if not rows:
        return pd.DataFrame()
    out = pd.DataFrame(rows).groupby(["label", "bin"], as_index=False).agg(
        {"overlap_len": "sum", "bin_size": "first"}
    )
    out["coverage_ratio"] = (out["overlap_len"] / out["bin_size"]).clip(0, 1)
    return out
# matrix style heatmap
def make_matrix_style_heatmap(df_heat, bin_size, text_len, color="#1f6aff"):
    """Build the label-by-bin Altair heatmap.

    df_heat is the output of bin_spans_into_brackets (label, bin,
    coverage_ratio). Cells are clickable: a point selection named
    "heatmap_select" carries the clicked label/bin back to Streamlit.
    NOTE(review): text_len is currently unused by this function; kept for
    interface stability.
    """
    if df_heat.empty:
        # Placeholder chart when there is nothing to plot.
        return alt.Chart(pd.DataFrame({"a": []})).mark_text(text="No annotations")
    df = df_heat.copy()
    # Convert bin indices to absolute character offsets for the x axis.
    df["bin_start"] = df["bin"] * bin_size
    df["bin_end"] = df["bin_start"] + bin_size
    heatmap_select = alt.selection_point(
        fields=["label", "bin"],
        on="click",
        clear="dblclick",
        name="heatmap_select",
    )
    chart = (
        alt.Chart(df)
        .mark_rect()
        .encode(
            x=alt.X("bin_start:Q",
                    title="Character Bracket",
                    axis=alt.Axis(format="~s")),
            x2="bin_end:Q",
            y=alt.Y(
                "label:N",
                title="Argument Type",
                # Fixed row order from the module-level LABEL_ORDER constant.
                sort=alt.SortArray(LABEL_ORDER),
                axis=alt.Axis(labelLimit=0, labelPadding=8),
            ),
            color=alt.Color(
                "coverage_ratio:Q",
                title="% of bin covered",
                # White at 0% coverage ramping to the configured color at 100%.
                scale=alt.Scale(
                    domain=[0, 1],
                    range=["#ffffff", color]
                )
            ),
            tooltip=[
                alt.Tooltip("label:N", title="Argument"),
                alt.Tooltip("bin_start:Q", title="Bin start", format=","),
                alt.Tooltip("bin_end:Q", title="Bin end", format=","),
                alt.Tooltip("coverage_ratio:Q", title="Coverage", format=".0%")
            ],
        )
        .add_params(heatmap_select)
        .properties(
            width=1200,
            # One 40px row per distinct label present in the data.
            height=40 * df["label"].nunique(),
        )
    )
    return chart
# highlighting utils
def generate_color_palette(n):
    """Produce n distinct translucent rgba() colors via evenly spaced hues."""
    denominator = max(1, n)
    palette = []
    for index in range(n):
        red, green, blue = colorsys.hls_to_rgb(index / denominator, 0.6, 0.8)
        palette.append(
            f"rgba({int(red*255)}, {int(green*255)}, {int(blue*255)}, 0.35)"
        )
    return palette
def make_annotator_color_map(annotators):
    """Assign each annotator a stable color from the generated palette."""
    palette = generate_color_palette(len(annotators))
    return dict(zip(annotators, palette))
def compute_interval_segments(text_len, spans):
    """Split [0, text_len) at every span boundary.

    Returns (start, end, ann_ids) triples, where ann_ids lists — in input
    order — the ids of all spans overlapping that sub-interval.
    """
    cut_points = {0, text_len}
    for begin, end, _ in spans:
        cut_points.add(int(begin))
        cut_points.add(int(end))
    # Ignore boundaries that fall outside the text.
    ordered_cuts = sorted(p for p in cut_points if 0 <= p <= text_len)
    segments = []
    for left, right in zip(ordered_cuts, ordered_cuts[1:]):
        if left >= right:
            continue
        covering = [ann for begin, end, ann in spans if begin < right and end > left]
        segments.append((left, right, covering))
    return segments
def render_highlighted_html(text, spans, color_map, meta_map, focus_ann_id=None):
    """Return HTML for *text* with annotation spans rendered as colored layers.

    spans: (start, end, annotation_id) triples local to *text*.
    color_map: annotation_id -> rgba background color.
    meta_map: annotation_id -> {"label", "annotator", "curation"} for tooltips.
    focus_ann_id: when given, the first chunk containing it receives an
    id="ann-<id>" anchor plus an inset outline so it can be scrolled to.
    """
    if not spans:
        return f"<pre>{html.escape(text)}</pre>"
    intervals = compute_interval_segments(len(text), spans)
    out = []
    anchored = False
    for s, e, ann_ids in intervals:
        chunk = html.escape(text[s:e])
        if ann_ids:
            # Stack one translucent gradient layer per overlapping annotation;
            # blend-mode:multiply darkens regions covered by several annotators.
            bg_layers = ", ".join(
                f"linear-gradient({color_map[a]} 0 0)" for a in ann_ids
            )
            tooltip_lines = []
            for a in ann_ids:
                m = meta_map[a]
                tooltip_lines.append(
                    f"{m['label']} — {m['annotator']} ({m['curation']})"
                )
            title_attr = html.escape("\n".join(tooltip_lines))
            is_focus = focus_ann_id in ann_ids if focus_ann_id else False
            anchor_attr = ""
            # Only the first focused chunk receives the scroll anchor id.
            if is_focus and not anchored:
                anchor_attr = f' id="ann-{focus_ann_id}"'
                anchored = True
            focus_style = "box-shadow: inset 0 0 0 2px #111;" if is_focus else ""
            chunk = (
                f'<span title="{title_attr}"{anchor_attr} '
                f'style="background:{bg_layers};'
                f'background-blend-mode:multiply;{focus_style}">'
                f'{chunk}</span>'
            )
        out.append(chunk)
    return "<pre style='line-height:1.5'>" + "".join(out) + "</pre>"
def extract_heatmap_selection(event):
    """Dig the clicked {label, bin} record out of an Altair selection event.

    Handles both object-style events (with a .selection attribute) and plain
    dict payloads; returns None when no point selection is present.
    """
    if event is None:
        return None
    selection = getattr(event, "selection", None)
    if selection is None and isinstance(event, dict):
        selection = event.get("selection")

    def dig(node):
        # Depth-first search for the first dict carrying both target fields.
        if node is None:
            return None
        if isinstance(node, dict):
            if "label" in node and "bin" in node:
                return node
            if "values" in node:
                return dig(node.get("values"))
            for child in node.values():
                found = dig(child)
                if found:
                    return found
        if isinstance(node, list) and node:
            return dig(node[0])
        return None

    return dig(selection)
def project_spans_to_interval(spans_global, seg_start, seg_end):
    """Translate document-global spans into coordinates local to one segment.

    Spans are clipped to [seg_start, seg_end); spans falling entirely
    outside are dropped.
    """
    local_spans = []
    for begin, end, ann_id in spans_global:
        clipped_begin = max(begin, seg_start)
        clipped_end = min(end, seg_end)
        if clipped_begin < clipped_end:
            local_spans.append(
                (clipped_begin - seg_start, clipped_end - seg_start, ann_id)
            )
    return local_spans
def pick_focus_annotation(df_ann, label, bin_start, bin_end):
    """Choose the annotation id to highlight for a clicked heatmap cell.

    Among annotations of *label* overlapping [bin_start, bin_end), prefer
    the one with the largest overlap, breaking ties by earliest start.
    Returns None when nothing matches.
    """
    if df_ann.empty:
        return None
    candidates = df_ann[
        (df_ann["label"] == label)
        & (df_ann["global_begin"] < bin_end)
        & (df_ann["global_end"] > bin_start)
    ]
    if candidates.empty:
        return None
    overlap = (
        np.minimum(candidates["global_end"], bin_end)
        - np.maximum(candidates["global_begin"], bin_start)
    )
    ranked = candidates.assign(overlap=overlap).sort_values(
        ["overlap", "global_begin"], ascending=[False, True]
    )
    return ranked.iloc[0]["annotation_id"]
def scroll_to_annotation(focus_ann_id):
    """Smooth-scroll the parent page to the span anchored as ann-<focus_ann_id>.

    Injects a zero-height JS component; retries once after 150 ms because
    the anchor span may not be in the DOM yet when this runs.
    """
    if focus_ann_id is None:
        return
    components.html(
        "<script>"
        "const targetId = 'ann-" + str(focus_ann_id) + "';"
        "const tryScroll = () => {"
        " const el = window.parent.document.getElementById(targetId);"
        " if (el) {"
        " el.scrollIntoView({behavior: 'smooth', block: 'center'});"
        " return true;"
        " }"
        " return false;"
        "};"
        "if (!tryScroll()) {"
        " setTimeout(tryScroll, 150);"
        "}"
        "</script>",
        height=0,
    )
def scroll_to_heatmap(anchor_id):
    """Smooth-scroll the parent page back to the heatmap anchor div.

    Same retry pattern as scroll_to_annotation, but aligns the target to the
    top of the viewport (block: 'start').
    """
    if not anchor_id:
        return
    components.html(
        "<script>"
        "const targetId = '" + str(anchor_id) + "';"
        "const tryScroll = () => {"
        " const el = window.parent.document.getElementById(targetId);"
        " if (el) {"
        " el.scrollIntoView({behavior: 'smooth', block: 'start'});"
        " return true;"
        " }"
        " return false;"
        "};"
        "if (!tryScroll()) {"
        " setTimeout(tryScroll, 150);"
        "}"
        "</script>",
        height=0,
    )
def render_floating_heatmap_button(anchor_id, button_id):
    """Inject a fixed-position "Back to heatmap" button into the component frame.

    The button is created once (keyed by button_id) and, on click, scrolls
    to the heatmap anchor — first searching this frame's document, then the
    parent document as a fallback.
    """
    if not anchor_id:
        return
    components.html(
        "<script>"
        "const btnId = '" + str(button_id) + "';"
        "const anchorId = '" + str(anchor_id) + "';"
        "const doc = window.document;"
        "let btn = doc.getElementById(btnId);"
        "if (!btn) {"
        " btn = doc.createElement('button');"
        " btn.id = btnId;"
        " btn.textContent = 'Back to heatmap';"
        " btn.style.position = 'fixed';"
        " btn.style.right = '16px';"
        " btn.style.bottom = '16px';"
        " btn.style.zIndex = '2147483647';"
        " btn.style.padding = '8px 12px';"
        " btn.style.border = '1px solid #ccc';"
        " btn.style.borderRadius = '8px';"
        " btn.style.background = '#fff';"
        " btn.style.color = '#222';"
        " btn.style.boxShadow = '0 2px 6px rgba(0,0,0,0.12)';"
        " btn.style.cursor = 'pointer';"
        " btn.style.transform = 'none';"
        " btn.style.margin = '0';"
        " btn.style.pointerEvents = 'auto';"
        " doc.body.appendChild(btn);"
        "}"
        "btn.onclick = () => {"
        " let el = doc.getElementById(anchorId);"
        " if (!el) {"
        " try { el = window.parent.document.getElementById(anchorId); } catch (e) {}"
        " }"
        " if (el) {"
        " el.scrollIntoView({behavior: 'smooth', block: 'start'});"
        " }"
        "};"
        "</script>",
        height=0,
    )
# streamlit UI
st.set_page_config(page_title="Argument Heatmap Explorer", layout="wide")
st.title("Argument Saturation Heatmap")
# Optional password gate: only active when APP_PASSWORD is set in the env.
app_password = os.getenv("APP_PASSWORD")
if app_password:
    if not st.session_state.get("auth_ok"):
        with st.sidebar:
            st.markdown("### Access")
            pw = st.text_input("Password", type="password")
            if pw:
                if pw == app_password:
                    st.session_state["auth_ok"] = True
                else:
                    st.session_state["auth_ok"] = False
                    st.error("Incorrect password.")
    # Halt rendering for unauthenticated sessions.
    if not st.session_state.get("auth_ok"):
        st.stop()
st.caption("Rows = argument types · Columns = character bins · Color = % coverage")
# Pin the sidebar via CSS so it stays visible while the main pane scrolls.
st.markdown(
    "<style>"
    "[data-testid='stSidebar']{position:fixed;top:0;left:0;height:100vh;}"
    "[data-testid='stSidebar'] > div:first-child{height:100vh;overflow:auto;}"
    "</style>",
    unsafe_allow_html=True,
)
# Keyboard shortcut: Backspace (outside text inputs) scrolls back to the top.
components.html(
    "<script>"
    "if (!window._backspaceScrollBound) {"
    " window._backspaceScrollBound = true;"
    " window.addEventListener('keydown', (e) => {"
    " const tag = (e.target && e.target.tagName) || '';"
    " const isInput = tag === 'INPUT' || tag === 'TEXTAREA' || e.target.isContentEditable;"
    " if (!isInput && e.key === 'Backspace') {"
    " e.preventDefault();"
    " window.scrollTo({top: 0, behavior: 'smooth'});"
    " }"
    " });"
    "}"
    "</script>",
    height=0,
)
# JS fallback that fixes the sidebar in place and offsets the main container.
# Retried (200 ms / 800 ms) because Streamlit may not have mounted the DOM yet.
components.html(
    "<script>"
    "const lockSidebar = () => {"
    " const doc = window.parent.document;"
    " const sidebar = doc.querySelector('[data-testid=\"stSidebar\"], .stSidebar');"
    " if (!sidebar) return false;"
    " sidebar.style.position = 'fixed';"
    " sidebar.style.top = '0';"
    " sidebar.style.left = '0';"
    " sidebar.style.height = '100vh';"
    " sidebar.style.zIndex = '999';"
    " const inner = sidebar.querySelector('div');"
    " if (inner) {"
    " inner.style.height = '100vh';"
    " inner.style.overflow = 'auto';"
    " }"
    " const main = doc.querySelector('[data-testid=\"stAppViewContainer\"], .main');"
    " if (main) {"
    " const w = sidebar.getBoundingClientRect().width;"
    " main.style.marginLeft = `${w}px`;"
    " }"
    " return true;"
    "};"
    "if (!lockSidebar()) {"
    " setTimeout(lockSidebar, 200);"
    " setTimeout(lockSidebar, 800);"
    "}"
    "</script>",
    height=0,
)
# sidebar
st.sidebar.header("Load Data")
# Paths are relative to the HF dataset root (or the working directory).
hearings_ds_path = st.sidebar.text_input(
    "Hearings dataset path",
    "la_cour_dataset_hearings"
)
judgments_ds_path = st.sidebar.text_input(
    "Judgments dataset path",
    "la_cour_dataset_judgments"
)
# default CSV locations
default_hear_csv = resolve_dataset_path("la_cour_hearings_annotations.csv")
default_judg_csv = resolve_dataset_path("la_cour_judgments_annotations.csv")
st.sidebar.markdown("#### Annotation CSVs")
# Optional uploads that override the default CSVs above.
hear_ann_upload = st.sidebar.file_uploader(
    "Hearing annotations CSV",
    type="csv",
    key="hear_csv_upload"
)
judg_ann_upload = st.sidebar.file_uploader(
    "Judgment annotations CSV",
    type="csv",
    key="judg_csv_upload"
)
def load_csv_or_default(upload_file, default_path):
    """Read an uploaded CSV when provided, else fall back to a CSV on disk.

    Returns (DataFrame, status_label); (None, "(missing)") when neither
    source is available.
    """
    if upload_file:
        frame = pd.read_csv(upload_file)
        status = f"(uploaded) {upload_file.name}"
    elif os.path.exists(default_path):
        frame = pd.read_csv(default_path)
        status = f"(default) {default_path}"
    else:
        frame, status = None, "(missing)"
    return frame, status
# Load annotation CSVs (uploads win over defaults) and report their source.
df_hear_ann, hear_status = load_csv_or_default(hear_ann_upload, default_hear_csv)
df_judg_ann, judg_status = load_csv_or_default(judg_ann_upload, default_judg_csv)
st.sidebar.caption(f"Hearing CSV: {hear_status}")
st.sidebar.caption(f"Judgment CSV: {judg_status}")
# Heatmap controls shared by both tabs.
bin_size = st.sidebar.slider(
    "Characters per bin",
    min_value=50,
    max_value=3000,
    value=400,
    step=50,
)
heat_color = st.sidebar.color_picker("Heatmap color", value="#1f6aff")
go_heatmap = st.sidebar.button("Back to heatmap")
# Nothing to show at all — stop before loading the (expensive) datasets.
if df_hear_ann is None and df_judg_ann is None:
    st.info("No annotations loaded — upload a CSV or place defaults in working directory.")
    st.stop()
# load datasets lazily
ds_hear = (
    load_from_disk(resolve_dataset_path(hearings_ds_path))
    if df_hear_ann is not None
    else None
)
ds_judg = (
    load_from_disk(resolve_dataset_path(judgments_ds_path))
    if df_judg_ann is not None
    else None
)
df_hear_text = ds_hear.to_pandas() if ds_hear else None
df_judg_text = ds_judg.to_pandas() if ds_judg else None
# tab renderer
def render_heatmap_tab(df_ann, df_text, title, key, is_hearing):
    """Render one full tab: filters, heatmap, section guide and text preview.

    df_ann: annotation DataFrame for this corpus (None when not loaded).
    df_text: text DataFrame from the matching dataset.
    title: human-readable tab title ("Hearing"/"Judgment").
    key: widget-key prefix so hearing/judgment widgets don't collide.
    is_hearing: selects segment-based vs paragraph-based text layout.
    """
    if df_ann is None:
        st.warning(f"No {title.lower()} annotations loaded.")
        return
    st.subheader(f"{title} Heatmap")
    # Document picker + per-document slices of annotations and text.
    webcast_ids = sorted(df_ann["webcast_id"].unique())
    webcast = st.selectbox("Select document", webcast_ids, key=f"wc_{key}")
    dfA = df_ann[df_ann["webcast_id"] == webcast]
    dfT = df_text[df_text["webcast_id"] == webcast]
    labels = sorted(dfA["label"].dropna().unique())
    annotators = sorted(dfA["annotator"].dropna().unique())
    # Filter row: labels/annotators for the heatmap, plus a separate
    # annotator filter for the highlighted-text preview.
    c1, c2, c3, c4 = st.columns(4)
    sel_labels = c1.multiselect("Argument types", labels, default=labels, key=f"lbl_{key}")
    sel_ann = c2.multiselect("Annotators (heatmap)", annotators, default=annotators, key=f"ann_{key}")
    valid_only = c3.checkbox("Valid only", value=True, key=f"valid_{key}")
    preview_ann = c4.multiselect(
        "Annotators (preview)",
        annotators,
        default=annotators,
        key=f"hl_{key}",
    )
    dfA = dfA[dfA["label"].isin(sel_labels) & dfA["annotator"].isin(sel_ann)]
    if valid_only:
        dfA = dfA[dfA["curation"] == "valid"]
    full_text = concat_global_text(dfT, webcast)
    if not sanity_check(dfA, len(full_text)):
        st.error("Annotation spans exceed text length.")
        return
    df_heat = bin_spans_into_brackets(dfA, len(full_text), bin_size)
    st.markdown("### Heatmap")
    # Invisible anchor div so "Back to heatmap" can scroll here.
    heatmap_anchor = f"heatmap-anchor-{key}"
    st.markdown(
        f"<div id='{heatmap_anchor}'></div>",
        unsafe_allow_html=True,
    )
    render_floating_heatmap_button(heatmap_anchor, f"heatmap-btn-{key}")
    heatmap_chart = make_matrix_style_heatmap(
        df_heat, bin_size, len(full_text), color=heat_color
    )
    # on_select="rerun" needs a recent Streamlit; fall back to a plain,
    # non-interactive chart on older versions that raise TypeError.
    try:
        heatmap_event = st.altair_chart(
            heatmap_chart,
            use_container_width=True,
            on_select="rerun",
        )
    except TypeError:
        heatmap_event = None
        st.altair_chart(heatmap_chart, use_container_width=True)
    # Persist the clicked (label, bin) cell; a click elsewhere clears it.
    selected = extract_heatmap_selection(heatmap_event)
    selection_key = f"heat_sel_{key}"
    if selected:
        try:
            st.session_state[selection_key] = {
                "label": selected["label"],
                "bin": int(selected["bin"]),
            }
        except (TypeError, ValueError):
            pass
    elif heatmap_event is not None:
        st.session_state.pop(selection_key, None)
    sections = (
        compute_hearing_sections(dfT)
        if is_hearing
        else compute_judgment_sections(dfT)
    )
    render_section_guide(sections, compact_columns=10 if is_hearing else None)
    # highlighted text preview
    st.markdown("### Highlighted Text Preview")
    # Legend rows: one swatch per selected preview annotator.
    if preview_ann:
        annot_color_map = make_annotator_color_map(preview_ann)
        legend_rows = []
        for annot in preview_ann:
            color = annot_color_map.get(annot, "rgba(0,0,0,0.25)")
            legend_rows.append(
                f"<div class='annotator-legend__row'>"
                f"<span class='annotator-legend__swatch' "
                f"style='background:{color}'></span>"
                f"<span class='annotator-legend__label'>"
                f"{html.escape(str(annot))}</span></div>"
            )
    else:
        annot_color_map = {}
        legend_rows = []
    # Preview uses its own annotator filter (not the heatmap's).
    dfH = df_ann[
        (df_ann["webcast_id"] == webcast)
        & (df_ann["annotator"].isin(preview_ann))
    ]
    if valid_only:
        dfH = dfH[dfH["curation"] == "valid"]
    spans_global = [
        (int(r["global_begin"]), int(r["global_end"]), r["annotation_id"])
        for _, r in dfH.iterrows()
    ]
    ann_id_to_annot = {
        r["annotation_id"]: r["annotator"]
        for _, r in dfH.iterrows()
    }
    # Legend entries for blended colors where multiple annotators overlap.
    combo_rows = []
    if preview_ann and spans_global:
        intervals = compute_interval_segments(len(full_text), spans_global)
        seen_combos = set()
        for _, _, ann_ids in intervals:
            annotators_in_span = sorted(
                {ann_id_to_annot.get(a) for a in ann_ids if a in ann_id_to_annot}
            )
            if len(annotators_in_span) <= 1:
                continue
            combo_key = tuple(annotators_in_span)
            if combo_key in seen_combos:
                continue
            seen_combos.add(combo_key)
            bg_layers = ", ".join(
                f"linear-gradient({annot_color_map[a]} 0 0)"
                for a in annotators_in_span
                if a in annot_color_map
            )
            label = " + ".join(html.escape(str(a)) for a in combo_key)
            combo_rows.append(
                f"<div class='annotator-legend__row'>"
                f"<span class='annotator-legend__swatch' "
                f"style='background:{bg_layers};"
                f"background-blend-mode:multiply;'></span>"
                f"<span class='annotator-legend__label'>{label}</span></div>"
            )
    # annotation_id -> highlight color used by render_highlighted_html.
    color_map = {
        r["annotation_id"]: annot_color_map.get(
            r["annotator"], "rgba(0,0,0,0.25)"
        )
        for _, r in dfH.iterrows()
    }
    if legend_rows or combo_rows:
        # Sticky legend that stays visible while the preview scrolls.
        st.markdown(
            "<style>"
            ".annotator-legend{position:sticky;top:0;background:white;"
            "padding:6px 8px;border:1px solid #e6e6e6;border-radius:6px;"
            "z-index:10;margin:6px 0 12px 0;display:inline-block;}"
            ".annotator-legend__row{display:flex;align-items:center;"
            "gap:8px;margin:2px 0;}"
            ".annotator-legend__swatch{width:16px;height:16px;"
            "border-radius:3px;display:inline-block;}"
            ".annotator-legend__label{font-size:12px;color:#222;}"
            ".annotator-legend__section{font-size:11px;margin:4px 0 2px;"
            "color:#666;}"
            "</style>"
            "<div class='annotator-legend'>"
            "<div class='annotator-legend__section'>Annotators</div>"
            + "".join(legend_rows)
            + (
                "<div class='annotator-legend__section'>Combinations</div>"
                + "".join(combo_rows)
                if combo_rows else ""
            )
            + "</div>",
            unsafe_allow_html=True,
        )
    # annotation_id -> tooltip metadata for the highlighted spans.
    meta_map = {
        r["annotation_id"]: {
            "label": r["label"],
            "annotator": r["annotator"],
            "curation": r["curation"]
        }
        for _, r in dfH.iterrows()
    }
    # Resolve the heatmap click (if any) to a concrete annotation to focus.
    focus_ann_id = None
    selection_state = st.session_state.get(selection_key)
    if selection_state and not dfH.empty:
        sel_label = selection_state.get("label")
        sel_bin = selection_state.get("bin")
        if sel_label is not None and sel_bin is not None:
            bin_start = sel_bin * bin_size
            bin_end = bin_start + bin_size
            focus_ann_id = pick_focus_annotation(
                dfH, sel_label, bin_start, bin_end
            )
    # hearing preview
    if is_hearing:
        # Walk segments in reading order, tracking the same global cursor
        # layout as concat_global_text (len + 1 per row for the join space).
        rows = dfT.sort_values(["segment_id", "sequence_id"])
        html_blocks = []
        cursor = 0
        for seg_id, seg_rows in rows.groupby("segment_id"):
            speaker = (
                seg_rows.iloc[0].get("speaker_name")
                or seg_rows.iloc[0].get("speaker_role")
                or "Unknown"
            )
            pieces = []
            seg_start = cursor
            for _, r in seg_rows.iterrows():
                t = r["text"] or ""
                pieces.append(t)
                cursor += len(t) + 1
            segment_text = " ".join(pieces)
            seg_end = seg_start + len(segment_text)
            local_spans = project_spans_to_interval(spans_global, seg_start, seg_end)
            html_blocks.append(f"<b>{html.escape(str(speaker))}</b><br>")
            html_blocks.append(
                render_highlighted_html(
                    segment_text,
                    local_spans,
                    color_map,
                    meta_map,
                    focus_ann_id=focus_ann_id,
                )
            )
            html_blocks.append("<br>")
        st.markdown("".join(html_blocks), unsafe_allow_html=True)
    # judgment preview
    else:
        # One highlighted block per paragraph, same cursor layout as above.
        rows = dfT.sort_values("paragraph_id")
        html_blocks = []
        cursor = 0
        for _, row in rows.iterrows():
            ptext = row["text"] or ""
            seg_start = cursor
            seg_end = seg_start + len(ptext)
            local_spans = project_spans_to_interval(spans_global, seg_start, seg_end)
            cursor = seg_end + 1
            html_blocks.append(
                render_highlighted_html(
                    ptext,
                    local_spans,
                    color_map,
                    meta_map,
                    focus_ann_id=focus_ann_id,
                )
            )
            html_blocks.append("<br>\n")
        st.markdown("".join(html_blocks), unsafe_allow_html=True)
    # Jump to the focused annotation, or back to the heatmap on request.
    scroll_to_annotation(focus_ann_id)
    if go_heatmap:
        scroll_to_heatmap(heatmap_anchor)
# tabs
# One tab per corpus; each renders the full filter/heatmap/preview pipeline.
tab1, tab2 = st.tabs(["Hearings", "Judgments"])
with tab1:
    render_heatmap_tab(df_hear_ann, df_hear_text, "Hearing", "hear", is_hearing=True)
with tab2:
    render_heatmap_tab(df_judg_ann, df_judg_text, "Judgment", "judg", is_hearing=False)