import streamlit as st import pandas as pd import numpy as np import altair as alt from datasets import load_from_disk from huggingface_hub import snapshot_download import colorsys import html import os import streamlit.components.v1 as components # text utils LABEL_ORDER = [ "FA - Factual Argument", "FA - Factual Question", "FO - Formal Question", "FO - Precedent", "FO - Systematic Interpretation", "FO - Textual Interpretation", "SU - Non-Legal Argument", "SU - Proportionality Analysis", "SU - Substantive Question", "SU - Teleological or Purposive Interpretation", "Negative Frame (ISS-N)", "Positive Frame (ISS-P)", "Crime Frame (JUST-C)", "Health Frame (JUST-H)", "National Security Frame (JUST-S)", "Rights Frame (JUST-R)", ] def concat_global_text(df, webcast_id, text_col="text"): rows = df[df["webcast_id"] == webcast_id] if "segment_id" in rows.columns: rows = rows.sort_values(["segment_id", "sequence_id"]) elif "paragraph_id" in rows.columns: rows = rows.sort_values("paragraph_id") return " ".join(rows[text_col].fillna("").tolist()) def sanity_check(df_ann, text_len): if df_ann.empty: return True return df_ann["global_end"].max() <= text_len def get_hf_dataset_root(): repo = os.getenv("HF_DATASET_REPO") if not repo: return None cached_root = st.session_state.get("hf_dataset_root") if cached_root: return cached_root token = os.getenv("HF_TOKEN") if not token: st.error("HF_TOKEN secret missing for private dataset access.") st.stop() cache_dir = os.path.join(os.getcwd(), ".hf_data_cache") try: snapshot_path = snapshot_download( repo_id=repo, repo_type="dataset", token=token, local_dir=cache_dir, local_dir_use_symlinks=False, ) except Exception as exc: st.error(f"Failed to load dataset repo: {exc}") st.stop() st.session_state["hf_dataset_root"] = snapshot_path return snapshot_path def resolve_dataset_path(relative_path): root = get_hf_dataset_root() return os.path.join(root, relative_path) if root else relative_path def format_char_count(value): n = int(value) if n 
>= 10000: return f"{int(round(n / 1000.0))}k" if n >= 100: scaled = round(n / 1000.0, 1) text = f"{scaled:.1f}".rstrip("0").rstrip(".") return f"{text}k" return str(n) def normalize_section_title(text): return " ".join(str(text).split()) if text else "Unknown" def compute_hearing_sections(df_text): if df_text is None or df_text.empty: return [] rows = df_text.sort_values(["segment_id", "sequence_id"]) sections = [] cursor = 0 for _, seg_rows in rows.groupby("segment_id"): speaker = ( seg_rows.iloc[0].get("speaker_name") or seg_rows.iloc[0].get("speaker_role") or "Unknown" ) seg_start = cursor pieces = [] for _, r in seg_rows.iterrows(): t = r["text"] or "" pieces.append(t) cursor += len(t) + 1 segment_text = " ".join(pieces) seg_end = seg_start + len(segment_text) sections.append({ "name": normalize_section_title(speaker), "start": seg_start, "end": seg_end, }) return sections def compute_judgment_sections(df_text): if df_text is None or df_text.empty: return [] rows = df_text.sort_values("paragraph_id") paragraphs = [] cursor = 0 for _, row in rows.iterrows(): ptext = row["text"] or "" start = cursor end = start + len(ptext) paragraphs.append({"text": ptext, "start": start, "end": end}) cursor = end + 1 if not paragraphs: return [] facts_idx = None for i, p in enumerate(paragraphs): if "THE FACTS" in p["text"]: facts_idx = i if facts_idx is None: return [] law_idx = None for i in range(facts_idx + 1, len(paragraphs)): if "THE LAW" in paragraphs[i]["text"]: law_idx = i if law_idx is None: return [] opinion_indices = [ i for i in range(law_idx + 1, len(paragraphs)) if "OPINION" in paragraphs[i]["text"] ] sections = [] facts_start = paragraphs[facts_idx]["start"] law_start = paragraphs[law_idx]["start"] facts_end = law_start sections.append({ "name": normalize_section_title(paragraphs[facts_idx]["text"]), "start": facts_start, "end": facts_end, }) law_end = ( paragraphs[opinion_indices[0]]["start"] if opinion_indices else paragraphs[-1]["end"] ) sections.append({ 
"name": normalize_section_title(paragraphs[law_idx]["text"]), "start": law_start, "end": law_end, }) for idx, op_idx in enumerate(opinion_indices): start = paragraphs[op_idx]["start"] end = ( paragraphs[opinion_indices[idx + 1]]["start"] if idx + 1 < len(opinion_indices) else paragraphs[-1]["end"] ) sections.append({ "name": normalize_section_title(paragraphs[op_idx]["text"]), "start": start, "end": end, }) return sections def render_section_guide(sections, compact_columns=None): st.markdown("### Section Guide") if not sections: st.info("No section guide could be generated for this document.") return if compact_columns: cells = [] for section in sections: name = html.escape(str(section["name"])) start = format_char_count(section["start"]) end = format_char_count(section["end"]) cells.append( "
{html.escape(text)}"
# --- Fragment of the highlighted-text HTML renderer. The enclosing `def`
# line is not visible in this view; from usage it receives `text` (the full
# document string), `spans` (global (start, end, annotation_id) triples),
# `color_map` (annotation_id -> CSS color), `meta_map` (annotation_id ->
# {"label", "annotator", "curation"}), and `focus_ann_id` (optional id to
# anchor/outline) — TODO confirm against the original signature.
# Splits `text` into maximal intervals covered by a constant set of
# annotations, then emits HTML: plain escaped text for unannotated
# intervals, a styled span for annotated ones.
# NOTE(review): the empty f'' literal near the end of the loop appears to
# have lost its HTML markup during extraction — the original opening tag
# (which presumably consumed bg_layers/title_attr/anchor_attr/focus_style)
# must be recovered from version control.
intervals = compute_interval_segments(len(text), spans)
# Accumulated HTML fragments, one per interval, joined by the caller's return.
out = []
# Only the FIRST interval of the focused annotation gets the scroll anchor id.
anchored = False
for s, e, ann_ids in intervals:
# Escape the raw slice first; emitted as-is when no annotation covers it.
chunk = html.escape(text[s:e])
if ann_ids:
# One solid color layer per covering annotation, stacked via CSS
# multiple backgrounds ("linear-gradient(<color> 0 0)" paints a
# uniform layer of that color, letting overlaps blend visually).
bg_layers = ", ".join(
f"linear-gradient({color_map[a]} 0 0)" for a in ann_ids
)
# Tooltip text: one "label — annotator (curation)" line per annotation.
tooltip_lines = []
for a in ann_ids:
m = meta_map[a]
tooltip_lines.append(
f"{m['label']} — {m['annotator']} ({m['curation']})"
)
# Escaped so newlines/quotes survive inside the title="" attribute.
title_attr = html.escape("\n".join(tooltip_lines))
# Focus handling: outline the focused annotation's intervals and give
# the first one an "ann-<id>" anchor (scroll_to_annotation targets it).
is_focus = focus_ann_id in ann_ids if focus_ann_id else False
anchor_attr = ""
if is_focus and not anchored:
anchor_attr = f' id="ann-{focus_ann_id}"'
anchored = True
# Inset box-shadow draws a 2px inner outline without shifting layout.
focus_style = "box-shadow: inset 0 0 0 2px #111;" if is_focus else ""
# Wrap the escaped text in the annotation span.
# NOTE(review): the opening-tag f-string below is empty in this copy —
# its markup was stripped; as written, bg_layers, title_attr,
# anchor_attr and focus_style are computed but never emitted.
chunk = (
f''
f'{chunk}'
)
out.append(chunk)
return "" + "".join(out) + "" def extract_heatmap_selection(event): if event is None: return None selection = getattr(event, "selection", None) if selection is None and isinstance(event, dict): selection = event.get("selection") def pull_fields(sel): if sel is None: return None if isinstance(sel, dict): if "label" in sel and "bin" in sel: return sel if "values" in sel: return pull_fields(sel.get("values")) for value in sel.values(): extracted = pull_fields(value) if extracted: return extracted if isinstance(sel, list) and sel: return pull_fields(sel[0]) return None return pull_fields(selection) def project_spans_to_interval(spans_global, seg_start, seg_end): projected = [] for g_start, g_end, ann_id in spans_global: if g_end <= seg_start or g_start >= seg_end: continue local_start = max(g_start, seg_start) - seg_start local_end = min(g_end, seg_end) - seg_start if local_start < local_end: projected.append((local_start, local_end, ann_id)) return projected def pick_focus_annotation(df_ann, label, bin_start, bin_end): if df_ann.empty: return None df_sel = df_ann[ (df_ann["label"] == label) & (df_ann["global_begin"] < bin_end) & (df_ann["global_end"] > bin_start) ] if df_sel.empty: return None overlaps = ( df_sel.assign( overlap=lambda d: ( np.minimum(d["global_end"], bin_end) - np.maximum(d["global_begin"], bin_start) ) ) .sort_values(["overlap", "global_begin"], ascending=[False, True]) ) return overlaps.iloc[0]["annotation_id"] def scroll_to_annotation(focus_ann_id): if focus_ann_id is None: return components.html( "", height=0, ) def scroll_to_heatmap(anchor_id): if not anchor_id: return components.html( "", height=0, ) def render_floating_heatmap_button(anchor_id, button_id): if not anchor_id: return components.html( "", height=0, ) # streamlit UI st.set_page_config(page_title="Argument Heatmap Explorer", layout="wide") st.title("Argument Saturation Heatmap") app_password = os.getenv("APP_PASSWORD") if app_password: if not st.session_state.get("auth_ok"): with 
st.sidebar: st.markdown("### Access") pw = st.text_input("Password", type="password") if pw: if pw == app_password: st.session_state["auth_ok"] = True else: st.session_state["auth_ok"] = False st.error("Incorrect password.") if not st.session_state.get("auth_ok"): st.stop() st.caption("Rows = argument types · Columns = character bins · Color = % coverage") st.markdown( "", unsafe_allow_html=True, ) components.html( "", height=0, ) components.html( "", height=0, ) # sidebar st.sidebar.header("Load Data") hearings_ds_path = st.sidebar.text_input( "Hearings dataset path", "la_cour_dataset_hearings" ) judgments_ds_path = st.sidebar.text_input( "Judgments dataset path", "la_cour_dataset_judgments" ) # default CSV locations default_hear_csv = resolve_dataset_path("la_cour_hearings_annotations.csv") default_judg_csv = resolve_dataset_path("la_cour_judgments_annotations.csv") st.sidebar.markdown("#### Annotation CSVs") hear_ann_upload = st.sidebar.file_uploader( "Hearing annotations CSV", type="csv", key="hear_csv_upload" ) judg_ann_upload = st.sidebar.file_uploader( "Judgment annotations CSV", type="csv", key="judg_csv_upload" ) def load_csv_or_default(upload_file, default_path): if upload_file: return pd.read_csv(upload_file), f"(uploaded) {upload_file.name}" if os.path.exists(default_path): return pd.read_csv(default_path), f"(default) {default_path}" return None, "(missing)" df_hear_ann, hear_status = load_csv_or_default(hear_ann_upload, default_hear_csv) df_judg_ann, judg_status = load_csv_or_default(judg_ann_upload, default_judg_csv) st.sidebar.caption(f"Hearing CSV: {hear_status}") st.sidebar.caption(f"Judgment CSV: {judg_status}") bin_size = st.sidebar.slider( "Characters per bin", min_value=50, max_value=3000, value=400, step=50, ) heat_color = st.sidebar.color_picker("Heatmap color", value="#1f6aff") go_heatmap = st.sidebar.button("Back to heatmap") if df_hear_ann is None and df_judg_ann is None: st.info("No annotations loaded — upload a CSV or place defaults in 
working directory.") st.stop() # load datasets lazily ds_hear = ( load_from_disk(resolve_dataset_path(hearings_ds_path)) if df_hear_ann is not None else None ) ds_judg = ( load_from_disk(resolve_dataset_path(judgments_ds_path)) if df_judg_ann is not None else None ) df_hear_text = ds_hear.to_pandas() if ds_hear else None df_judg_text = ds_judg.to_pandas() if ds_judg else None # tab renderer def render_heatmap_tab(df_ann, df_text, title, key, is_hearing): if df_ann is None: st.warning(f"No {title.lower()} annotations loaded.") return st.subheader(f"{title} Heatmap") webcast_ids = sorted(df_ann["webcast_id"].unique()) webcast = st.selectbox("Select document", webcast_ids, key=f"wc_{key}") dfA = df_ann[df_ann["webcast_id"] == webcast] dfT = df_text[df_text["webcast_id"] == webcast] labels = sorted(dfA["label"].dropna().unique()) annotators = sorted(dfA["annotator"].dropna().unique()) c1, c2, c3, c4 = st.columns(4) sel_labels = c1.multiselect("Argument types", labels, default=labels, key=f"lbl_{key}") sel_ann = c2.multiselect("Annotators (heatmap)", annotators, default=annotators, key=f"ann_{key}") valid_only = c3.checkbox("Valid only", value=True, key=f"valid_{key}") preview_ann = c4.multiselect( "Annotators (preview)", annotators, default=annotators, key=f"hl_{key}", ) dfA = dfA[dfA["label"].isin(sel_labels) & dfA["annotator"].isin(sel_ann)] if valid_only: dfA = dfA[dfA["curation"] == "valid"] full_text = concat_global_text(dfT, webcast) if not sanity_check(dfA, len(full_text)): st.error("Annotation spans exceed text length.") return df_heat = bin_spans_into_brackets(dfA, len(full_text), bin_size) st.markdown("### Heatmap") heatmap_anchor = f"heatmap-anchor-{key}" st.markdown( f"", unsafe_allow_html=True, ) render_floating_heatmap_button(heatmap_anchor, f"heatmap-btn-{key}") heatmap_chart = make_matrix_style_heatmap( df_heat, bin_size, len(full_text), color=heat_color ) try: heatmap_event = st.altair_chart( heatmap_chart, use_container_width=True, on_select="rerun", 
) except TypeError: heatmap_event = None st.altair_chart(heatmap_chart, use_container_width=True) selected = extract_heatmap_selection(heatmap_event) selection_key = f"heat_sel_{key}" if selected: try: st.session_state[selection_key] = { "label": selected["label"], "bin": int(selected["bin"]), } except (TypeError, ValueError): pass elif heatmap_event is not None: st.session_state.pop(selection_key, None) sections = ( compute_hearing_sections(dfT) if is_hearing else compute_judgment_sections(dfT) ) render_section_guide(sections, compact_columns=10 if is_hearing else None) # highlighted text preview st.markdown("### Highlighted Text Preview") if preview_ann: annot_color_map = make_annotator_color_map(preview_ann) legend_rows = [] for annot in preview_ann: color = annot_color_map.get(annot, "rgba(0,0,0,0.25)") legend_rows.append( f"