"""Argument Saturation Heatmap — Streamlit explorer for span annotations.

Heatmap rows are argument types, columns are fixed-size character bins over
the concatenated document text, and color encodes the fraction of each bin
covered by annotation spans. A highlighted-text preview shows the raw spans
per annotator, and clicking a heatmap cell scrolls the preview to the
best-matching annotation.

NOTE(review): this file was recovered from a copy in which the literal
HTML/JS markup embedded inside string literals had been stripped
(sanitized). Every piece of markup below is a reconstruction from the
surviving variable usage and must be diffed against version control
before release.
"""

import colorsys
import html
import os

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from datasets import load_from_disk
from huggingface_hub import snapshot_download

# ---------------------------------------------------------------------------
# text utils
# ---------------------------------------------------------------------------

# Fixed display order for the heatmap's y-axis.
LABEL_ORDER = [
    "FA - Factual Argument",
    "FA - Factual Question",
    "FO - Formal Question",
    "FO - Precedent",
    "FO - Systematic Interpretation",
    "FO - Textual Interpretation",
    "SU - Non-Legal Argument",
    "SU - Proportionality Analysis",
    "SU - Substantive Question",
    "SU - Teleological or Purposive Interpretation",
    "Negative Frame (ISS-N)",
    "Positive Frame (ISS-P)",
    "Crime Frame (JUST-C)",
    "Health Frame (JUST-H)",
    "National Security Frame (JUST-S)",
    "Rights Frame (JUST-R)",
]


def concat_global_text(df, webcast_id, text_col="text"):
    """Concatenate one document's text rows, space-separated, in reading order.

    Hearings sort by (segment_id, sequence_id); judgments by paragraph_id.
    The single joining space is what the `+ 1` cursor offsets in the section
    computations below account for — keep them in sync.
    """
    rows = df[df["webcast_id"] == webcast_id]
    if "segment_id" in rows.columns:
        rows = rows.sort_values(["segment_id", "sequence_id"])
    elif "paragraph_id" in rows.columns:
        rows = rows.sort_values("paragraph_id")
    return " ".join(rows[text_col].fillna("").tolist())


def sanity_check(df_ann, text_len):
    """Return True when every annotation span ends within the text."""
    if df_ann.empty:
        return True
    return df_ann["global_end"].max() <= text_len


def get_hf_dataset_root():
    """Download (once per session) the HF dataset snapshot and return its root.

    Returns None when HF_DATASET_REPO is unset (local-files mode). Stops the
    app with an error when HF_TOKEN is missing or the download fails.
    """
    repo = os.getenv("HF_DATASET_REPO")
    if not repo:
        return None
    cached_root = st.session_state.get("hf_dataset_root")
    if cached_root:
        return cached_root
    token = os.getenv("HF_TOKEN")
    if not token:
        st.error("HF_TOKEN secret missing for private dataset access.")
        st.stop()
    cache_dir = os.path.join(os.getcwd(), ".hf_data_cache")
    try:
        snapshot_path = snapshot_download(
            repo_id=repo,
            repo_type="dataset",
            token=token,
            local_dir=cache_dir,
            # NOTE(review): deprecated in recent huggingface_hub releases;
            # harmless today, but slated for removal.
            local_dir_use_symlinks=False,
        )
    except Exception as exc:
        st.error(f"Failed to load dataset repo: {exc}")
        st.stop()
    st.session_state["hf_dataset_root"] = snapshot_path
    return snapshot_path


def resolve_dataset_path(relative_path):
    """Resolve a path inside the HF snapshot, or pass it through unchanged."""
    root = get_hf_dataset_root()
    return os.path.join(root, relative_path) if root else relative_path


def format_char_count(value):
    """Compact human format for character offsets: '42', '0.5k', '12k'."""
    n = int(value)
    if n >= 10000:
        return f"{int(round(n / 1000.0))}k"
    if n >= 100:
        scaled = round(n / 1000.0, 1)
        text = f"{scaled:.1f}".rstrip("0").rstrip(".")
        return f"{text}k"
    return str(n)


def normalize_section_title(text):
    """Collapse internal whitespace; fall back to 'Unknown' for empty titles."""
    return " ".join(str(text).split()) if text else "Unknown"


def compute_hearing_sections(df_text):
    """One section per hearing segment: speaker name plus global char range.

    Offsets mirror concat_global_text: each text piece advances the cursor by
    len(text) + 1 for the joining space.
    """
    if df_text is None or df_text.empty:
        return []
    rows = df_text.sort_values(["segment_id", "sequence_id"])
    sections = []
    cursor = 0
    for _, seg_rows in rows.groupby("segment_id"):
        speaker = (
            seg_rows.iloc[0].get("speaker_name")
            or seg_rows.iloc[0].get("speaker_role")
            or "Unknown"
        )
        seg_start = cursor
        pieces = []
        for _, r in seg_rows.iterrows():
            t = r["text"] or ""
            pieces.append(t)
            cursor += len(t) + 1
        segment_text = " ".join(pieces)
        seg_end = seg_start + len(segment_text)
        sections.append({
            "name": normalize_section_title(speaker),
            "start": seg_start,
            "end": seg_end,
        })
    return sections


def compute_judgment_sections(df_text):
    """Sections for a judgment: THE FACTS, THE LAW, then each OPINION.

    Returns [] when the canonical headings cannot be found. NOTE(review):
    the scans below deliberately do not `break`, so the LAST paragraph
    containing "THE FACTS"/"THE LAW" wins — confirm this is intended for
    documents that quote those headings in the body.
    """
    if df_text is None or df_text.empty:
        return []
    rows = df_text.sort_values("paragraph_id")

    # Recompute global offsets paragraph by paragraph (+1 per joining space).
    paragraphs = []
    cursor = 0
    for _, row in rows.iterrows():
        ptext = row["text"] or ""
        start = cursor
        end = start + len(ptext)
        paragraphs.append({"text": ptext, "start": start, "end": end})
        cursor = end + 1
    if not paragraphs:
        return []

    facts_idx = None
    for i, p in enumerate(paragraphs):
        if "THE FACTS" in p["text"]:
            facts_idx = i
    if facts_idx is None:
        return []

    law_idx = None
    for i in range(facts_idx + 1, len(paragraphs)):
        if "THE LAW" in paragraphs[i]["text"]:
            law_idx = i
    if law_idx is None:
        return []

    opinion_indices = [
        i
        for i in range(law_idx + 1, len(paragraphs))
        if "OPINION" in paragraphs[i]["text"]
    ]

    sections = []
    facts_start = paragraphs[facts_idx]["start"]
    law_start = paragraphs[law_idx]["start"]
    facts_end = law_start
    sections.append({
        "name": normalize_section_title(paragraphs[facts_idx]["text"]),
        "start": facts_start,
        "end": facts_end,
    })
    law_end = (
        paragraphs[opinion_indices[0]]["start"]
        if opinion_indices
        else paragraphs[-1]["end"]
    )
    sections.append({
        "name": normalize_section_title(paragraphs[law_idx]["text"]),
        "start": law_start,
        "end": law_end,
    })
    # Each opinion runs until the next opinion heading (or end of document).
    for idx, op_idx in enumerate(opinion_indices):
        start = paragraphs[op_idx]["start"]
        end = (
            paragraphs[opinion_indices[idx + 1]]["start"]
            if idx + 1 < len(opinion_indices)
            else paragraphs[-1]["end"]
        )
        sections.append({
            "name": normalize_section_title(paragraphs[op_idx]["text"]),
            "start": start,
            "end": end,
        })
    return sections


def render_section_guide(sections, compact_columns=None):
    """Render a section overview: a compact grid (hearings) or a table.

    NOTE(review): the HTML markup in this function was stripped from the
    recovered source; the tags and inline styles below are reconstructions.
    """
    st.markdown("### Section Guide")
    if not sections:
        st.info("No section guide could be generated for this document.")
        return
    if compact_columns:
        cells = []
        for section in sections:
            name = html.escape(str(section["name"]))
            start = format_char_count(section["start"])
            end = format_char_count(section["end"])
            cells.append(
                "<div style='border:1px solid #ddd; border-radius:6px; "
                "padding:4px 6px; font-size:12px;'>"
                f"<div style='font-weight:600;'>{name}</div>"
                f"<div style='color:#666;'>{start} - {end}</div>"
                "</div>"
            )
        st.markdown(
            "<div style='display:grid; "
            f"grid-template-columns:repeat({compact_columns}, 1fr); gap:6px;'>"
            + "".join(cells)
            + "</div>",
            unsafe_allow_html=True,
        )
    else:
        rows = []
        for section in sections:
            name = html.escape(str(section["name"]))
            start = format_char_count(section["start"])
            end = format_char_count(section["end"])
            rows.append(
                "<tr>"
                f"<td style='padding:2px 12px 2px 0;'>{name}</td>"
                f"<td style='color:#666;'>{start} - {end}</td>"
                "</tr>"
            )
        st.markdown(
            "<table style='font-size:13px;'><tbody>"
            + "".join(rows)
            + "</tbody></table>",
            unsafe_allow_html=True,
        )


# ---------------------------------------------------------------------------
# span -> bin coverage
# ---------------------------------------------------------------------------

def bin_spans_into_brackets(df_ann, text_len, bin_size):
    """Aggregate annotation spans into per-(label, bin) coverage ratios.

    Each span is clipped to the text, split across the fixed-size character
    bins it overlaps, and the overlap lengths summed per label/bin. Returns
    an empty DataFrame when there is nothing to plot; otherwise columns are
    label, bin, overlap_len, bin_size, coverage_ratio (clipped to [0, 1] —
    overlapping spans of the same label can over-count).
    """
    if df_ann.empty:
        return pd.DataFrame()
    records = []
    for _, row in df_ann.iterrows():
        s = int(row["global_begin"])
        e = int(min(row["global_end"], text_len))
        if s >= e:
            continue
        start_bin = s // bin_size
        end_bin = e // bin_size
        for b in range(start_bin, end_bin + 1):
            bin_start = b * bin_size
            bin_end = min((b + 1) * bin_size, text_len)
            overlap_start = max(s, bin_start)
            overlap_end = min(e, bin_end)
            if overlap_start < overlap_end:
                overlap_len = overlap_end - overlap_start
                records.append({
                    "label": row["label"],
                    "bin": b,
                    "overlap_len": overlap_len,
                    "bin_size": bin_size,
                })
    if not records:
        return pd.DataFrame()
    df = pd.DataFrame(records)
    df = (
        df.groupby(["label", "bin"], as_index=False)
        .agg({"overlap_len": "sum", "bin_size": "first"})
    )
    df["coverage_ratio"] = (df["overlap_len"] / df["bin_size"]).clip(0, 1)
    return df


# ---------------------------------------------------------------------------
# matrix style heatmap
# ---------------------------------------------------------------------------

def make_matrix_style_heatmap(df_heat, bin_size, text_len, color="#1f6aff"):
    """Build the Altair rect heatmap with a click-to-select point selection."""
    if df_heat.empty:
        return alt.Chart(pd.DataFrame({"a": []})).mark_text(text="No annotations")
    df = df_heat.copy()
    df["bin_start"] = df["bin"] * bin_size
    df["bin_end"] = df["bin_start"] + bin_size
    # Click selects one (label, bin) cell; double-click clears it.
    heatmap_select = alt.selection_point(
        fields=["label", "bin"],
        on="click",
        clear="dblclick",
        name="heatmap_select",
    )
    chart = (
        alt.Chart(df)
        .mark_rect()
        .encode(
            x=alt.X("bin_start:Q", title="Character Bracket", axis=alt.Axis(format="~s")),
            x2="bin_end:Q",
            y=alt.Y(
                "label:N",
                title="Argument Type",
                sort=alt.SortArray(LABEL_ORDER),
                axis=alt.Axis(labelLimit=0, labelPadding=8),
            ),
            color=alt.Color(
                "coverage_ratio:Q",
                title="% of bin covered",
                scale=alt.Scale(domain=[0, 1], range=["#ffffff", color]),
            ),
            tooltip=[
                alt.Tooltip("label:N", title="Argument"),
                alt.Tooltip("bin_start:Q", title="Bin start", format=","),
                alt.Tooltip("bin_end:Q", title="Bin end", format=","),
                alt.Tooltip("coverage_ratio:Q", title="Coverage", format=".0%"),
            ],
        )
        .add_params(heatmap_select)
        .properties(
            width=1200,
            height=40 * df["label"].nunique(),
        )
    )
    return chart


# ---------------------------------------------------------------------------
# highlighting utils
# ---------------------------------------------------------------------------

def generate_color_palette(n):
    """n visually spread, semi-transparent rgba() colors (evenly spaced hues)."""
    colors = []
    for i in range(n):
        hue = i / max(1, n)
        r, g, b = colorsys.hls_to_rgb(hue, 0.6, 0.8)
        colors.append(f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, 0.35)")
    return colors


def make_annotator_color_map(annotators):
    """Stable annotator -> rgba color mapping (order defines the hue)."""
    colors = generate_color_palette(len(annotators))
    return {a: c for a, c in zip(annotators, colors)}


def compute_interval_segments(text_len, spans):
    """Split [0, text_len) at every span boundary.

    `spans` is a list of (start, end, ann_id). Returns a list of
    (start, end, [ann_ids active over that interval]) covering the full text.
    O(cuts * spans) — fine for per-segment span counts seen here.
    """
    boundaries = {0, text_len}
    for s, e, _ in spans:
        boundaries.add(int(s))
        boundaries.add(int(e))
    cuts = sorted(b for b in boundaries if 0 <= b <= text_len)
    intervals = []
    for i in range(len(cuts) - 1):
        s, e = cuts[i], cuts[i + 1]
        if s >= e:
            continue
        active = [span for span in spans if span[0] < e and span[1] > s]
        intervals.append((s, e, [a[2] for a in active]))
    return intervals


def render_highlighted_html(text, spans, color_map, meta_map, focus_ann_id=None):
    """Return HTML for `text` with overlapping span highlights.

    Stacked `linear-gradient` background layers blend one tint per active
    annotation; the tooltip lists "label — annotator (curation)" per span.
    The first interval containing `focus_ann_id` gets an `id="ann-<id>"`
    anchor (scroll target) and an inset outline.

    NOTE(review): in the recovered source the <span>/<div> wrappers were
    stripped, leaving bg_layers/title_attr/anchor_attr/focus_style computed
    but unused — the markup below reconstructs their evident purpose.
    """
    container_open = "<div style='line-height:1.7; font-size:15px;'>"
    if not spans:
        return f"{container_open}{html.escape(text)}</div>"
    intervals = compute_interval_segments(len(text), spans)
    out = []
    anchored = False
    for s, e, ann_ids in intervals:
        chunk = html.escape(text[s:e])
        if ann_ids:
            bg_layers = ", ".join(
                f"linear-gradient({color_map[a]} 0 0)" for a in ann_ids
            )
            tooltip_lines = []
            for a in ann_ids:
                m = meta_map[a]
                tooltip_lines.append(
                    f"{m['label']} — {m['annotator']} ({m['curation']})"
                )
            title_attr = html.escape("\n".join(tooltip_lines))
            is_focus = focus_ann_id in ann_ids if focus_ann_id else False
            anchor_attr = ""
            if is_focus and not anchored:
                # Only the first focused interval carries the scroll anchor.
                anchor_attr = f' id="ann-{focus_ann_id}"'
                anchored = True
            focus_style = "box-shadow: inset 0 0 0 2px #111;" if is_focus else ""
            chunk = (
                f'<span{anchor_attr} title="{title_attr}" '
                f'style="background: {bg_layers}; {focus_style}">'
                f'{chunk}</span>'
            )
        out.append(chunk)
    return container_open + "".join(out) + "</div>"


def extract_heatmap_selection(event):
    """Dig the selected {label, bin} dict out of an st.altair_chart event.

    Streamlit versions differ in how the selection payload is nested, so
    this walks dicts/lists recursively until it finds both keys. Returns
    None when nothing is selected or the event shape is unknown.
    """
    if event is None:
        return None
    selection = getattr(event, "selection", None)
    if selection is None and isinstance(event, dict):
        selection = event.get("selection")

    def pull_fields(sel):
        if sel is None:
            return None
        if isinstance(sel, dict):
            if "label" in sel and "bin" in sel:
                return sel
            if "values" in sel:
                return pull_fields(sel.get("values"))
            for value in sel.values():
                extracted = pull_fields(value)
                if extracted:
                    return extracted
        if isinstance(sel, list) and sel:
            return pull_fields(sel[0])
        return None

    return pull_fields(selection)


def project_spans_to_interval(spans_global, seg_start, seg_end):
    """Clip global-offset spans to [seg_start, seg_end) and rebase to local."""
    projected = []
    for g_start, g_end, ann_id in spans_global:
        if g_end <= seg_start or g_start >= seg_end:
            continue
        local_start = max(g_start, seg_start) - seg_start
        local_end = min(g_end, seg_end) - seg_start
        if local_start < local_end:
            projected.append((local_start, local_end, ann_id))
    return projected


def pick_focus_annotation(df_ann, label, bin_start, bin_end):
    """annotation_id of the span of `label` with the most overlap in the bin.

    Ties break on the earliest global_begin. None when nothing overlaps.
    """
    if df_ann.empty:
        return None
    df_sel = df_ann[
        (df_ann["label"] == label)
        & (df_ann["global_begin"] < bin_end)
        & (df_ann["global_end"] > bin_start)
    ]
    if df_sel.empty:
        return None
    overlaps = (
        df_sel.assign(
            overlap=lambda d: (
                np.minimum(d["global_end"], bin_end)
                - np.maximum(d["global_begin"], bin_start)
            )
        )
        .sort_values(["overlap", "global_begin"], ascending=[False, True])
    )
    return overlaps.iloc[0]["annotation_id"]


def scroll_to_annotation(focus_ann_id):
    """Smooth-scroll the parent page to the focused annotation's anchor.

    NOTE(review): original JS payload was stripped; reconstructed.
    """
    if focus_ann_id is None:
        return
    components.html(
        f"""
        <script>
          const el = window.parent.document.getElementById("ann-{focus_ann_id}");
          if (el) {{ el.scrollIntoView({{behavior: "smooth", block: "center"}}); }}
        </script>
        """,
        height=0,
    )


def scroll_to_heatmap(anchor_id):
    """Smooth-scroll the parent page back to the heatmap anchor.

    NOTE(review): original JS payload was stripped; reconstructed.
    """
    if not anchor_id:
        return
    components.html(
        f"""
        <script>
          const el = window.parent.document.getElementById("{anchor_id}");
          if (el) {{ el.scrollIntoView({{behavior: "smooth", block: "start"}}); }}
        </script>
        """,
        height=0,
    )


def render_floating_heatmap_button(anchor_id, button_id):
    """Inject a fixed-position 'back to heatmap' button into the parent page.

    Idempotent per button_id so reruns don't stack buttons.
    NOTE(review): original JS payload was stripped; reconstructed.
    """
    if not anchor_id:
        return
    components.html(
        f"""
        <script>
          const doc = window.parent.document;
          if (!doc.getElementById("{button_id}")) {{
            const btn = doc.createElement("button");
            btn.id = "{button_id}";
            btn.textContent = "Heatmap";
            btn.style.cssText =
              "position:fixed; bottom:24px; right:24px; z-index:1000;" +
              "padding:8px 14px; border-radius:20px; border:1px solid #ccc;" +
              "background:#fff; cursor:pointer;";
            btn.onclick = () => {{
              const el = doc.getElementById("{anchor_id}");
              if (el) el.scrollIntoView({{behavior: "smooth", block: "start"}});
            }};
            doc.body.appendChild(btn);
          }}
        </script>
        """,
        height=0,
    )


# ---------------------------------------------------------------------------
# streamlit UI
# ---------------------------------------------------------------------------

st.set_page_config(page_title="Argument Heatmap Explorer", layout="wide")
st.title("Argument Saturation Heatmap")

# Optional single-password gate driven by the APP_PASSWORD env var.
app_password = os.getenv("APP_PASSWORD")
if app_password:
    if not st.session_state.get("auth_ok"):
        with st.sidebar:
            st.markdown("### Access")
            pw = st.text_input("Password", type="password")
            if pw:
                if pw == app_password:
                    st.session_state["auth_ok"] = True
                else:
                    st.session_state["auth_ok"] = False
                    st.error("Incorrect password.")
    if not st.session_state.get("auth_ok"):
        st.stop()

st.caption("Rows = argument types · Columns = character bins · Color = % coverage")

# NOTE(review): the page-level CSS passed to st.markdown and the two
# bootstrap scripts passed to components.html were stripped from the
# recovered source — restore the payloads from version control.
st.markdown("", unsafe_allow_html=True)
components.html("", height=0)
components.html("", height=0)

# sidebar
st.sidebar.header("Load Data")
hearings_ds_path = st.sidebar.text_input(
    "Hearings dataset path", "la_cour_dataset_hearings"
)
judgments_ds_path = st.sidebar.text_input(
    "Judgments dataset path", "la_cour_dataset_judgments"
)

# default CSV locations (inside the HF snapshot when one is configured)
default_hear_csv = resolve_dataset_path("la_cour_hearings_annotations.csv")
default_judg_csv = resolve_dataset_path("la_cour_judgments_annotations.csv")

st.sidebar.markdown("#### Annotation CSVs")
hear_ann_upload = st.sidebar.file_uploader(
    "Hearing annotations CSV", type="csv", key="hear_csv_upload"
)
judg_ann_upload = st.sidebar.file_uploader(
    "Judgment annotations CSV", type="csv", key="judg_csv_upload"
)


def load_csv_or_default(upload_file, default_path):
    """Prefer an uploaded CSV, else the default path; return (df|None, status)."""
    if upload_file:
        return pd.read_csv(upload_file), f"(uploaded) {upload_file.name}"
    if os.path.exists(default_path):
        return pd.read_csv(default_path), f"(default) {default_path}"
    return None, "(missing)"


df_hear_ann, hear_status = load_csv_or_default(hear_ann_upload, default_hear_csv)
df_judg_ann, judg_status = load_csv_or_default(judg_ann_upload, default_judg_csv)
st.sidebar.caption(f"Hearing CSV: {hear_status}")
st.sidebar.caption(f"Judgment CSV: {judg_status}")

bin_size = st.sidebar.slider(
    "Characters per bin",
    min_value=50,
    max_value=3000,
    value=400,
    step=50,
)
heat_color = st.sidebar.color_picker("Heatmap color", value="#1f6aff")
go_heatmap = st.sidebar.button("Back to heatmap")

if df_hear_ann is None and df_judg_ann is None:
    st.info("No annotations loaded — upload a CSV or place defaults in working directory.")
    st.stop()

# load datasets lazily — only the side(s) whose annotations are present
ds_hear = (
    load_from_disk(resolve_dataset_path(hearings_ds_path))
    if df_hear_ann is not None
    else None
)
ds_judg = (
    load_from_disk(resolve_dataset_path(judgments_ds_path))
    if df_judg_ann is not None
    else None
)
df_hear_text = ds_hear.to_pandas() if ds_hear else None
df_judg_text = ds_judg.to_pandas() if ds_judg else None


# ---------------------------------------------------------------------------
# tab renderer
# ---------------------------------------------------------------------------

def render_heatmap_tab(df_ann, df_text, title, key, is_hearing):
    """Render one tab: document picker, filters, heatmap, section guide,
    and the highlighted-text preview with click-to-focus scrolling.

    `key` namespaces all widget keys so the two tabs don't collide;
    `is_hearing` switches between segment- and paragraph-based layouts.
    """
    if df_ann is None:
        st.warning(f"No {title.lower()} annotations loaded.")
        return
    st.subheader(f"{title} Heatmap")

    webcast_ids = sorted(df_ann["webcast_id"].unique())
    webcast = st.selectbox("Select document", webcast_ids, key=f"wc_{key}")
    dfA = df_ann[df_ann["webcast_id"] == webcast]
    dfT = df_text[df_text["webcast_id"] == webcast]

    labels = sorted(dfA["label"].dropna().unique())
    annotators = sorted(dfA["annotator"].dropna().unique())
    c1, c2, c3, c4 = st.columns(4)
    sel_labels = c1.multiselect("Argument types", labels, default=labels, key=f"lbl_{key}")
    sel_ann = c2.multiselect("Annotators (heatmap)", annotators, default=annotators, key=f"ann_{key}")
    valid_only = c3.checkbox("Valid only", value=True, key=f"valid_{key}")
    preview_ann = c4.multiselect(
        "Annotators (preview)",
        annotators,
        default=annotators,
        key=f"hl_{key}",
    )

    dfA = dfA[dfA["label"].isin(sel_labels) & dfA["annotator"].isin(sel_ann)]
    if valid_only:
        dfA = dfA[dfA["curation"] == "valid"]

    full_text = concat_global_text(dfT, webcast)
    if not sanity_check(dfA, len(full_text)):
        st.error("Annotation spans exceed text length.")
        return

    df_heat = bin_spans_into_brackets(dfA, len(full_text), bin_size)
    st.markdown("### Heatmap")
    heatmap_anchor = f"heatmap-anchor-{key}"
    # Invisible anchor the floating button / sidebar button scrolls back to.
    # NOTE(review): markup reconstructed (stripped in recovered source).
    st.markdown(
        f"<div id='{heatmap_anchor}'></div>",
        unsafe_allow_html=True,
    )
    render_floating_heatmap_button(heatmap_anchor, f"heatmap-btn-{key}")

    heatmap_chart = make_matrix_style_heatmap(
        df_heat, bin_size, len(full_text), color=heat_color
    )
    try:
        heatmap_event = st.altair_chart(
            heatmap_chart,
            use_container_width=True,
            on_select="rerun",
        )
    except TypeError:
        # Older Streamlit without on_select: render without click selection.
        heatmap_event = None
        st.altair_chart(heatmap_chart, use_container_width=True)

    selected = extract_heatmap_selection(heatmap_event)
    selection_key = f"heat_sel_{key}"
    if selected:
        try:
            st.session_state[selection_key] = {
                "label": selected["label"],
                "bin": int(selected["bin"]),
            }
        except (TypeError, ValueError):
            pass
    elif heatmap_event is not None:
        # Event fired with no selection => the user cleared it.
        st.session_state.pop(selection_key, None)

    sections = (
        compute_hearing_sections(dfT)
        if is_hearing
        else compute_judgment_sections(dfT)
    )
    render_section_guide(sections, compact_columns=10 if is_hearing else None)

    # highlighted text preview
    st.markdown("### Highlighted Text Preview")
    if preview_ann:
        annot_color_map = make_annotator_color_map(preview_ann)
        legend_rows = []
        for annot in preview_ann:
            color = annot_color_map.get(annot, "rgba(0,0,0,0.25)")
            # NOTE(review): legend markup reconstructed.
            legend_rows.append(
                "<div style='display:flex; align-items:center; gap:6px;'>"
                f"<span style='width:14px; height:14px; display:inline-block; "
                f"border:1px solid #999; background:{color};'></span>"
                f"<span>{html.escape(str(annot))}</span>"
                "</div>"
            )
    else:
        annot_color_map = {}
        legend_rows = []

    # Preview spans come from the unfiltered annotation set for this document,
    # restricted only by the preview annotators (and optionally curation).
    dfH = df_ann[
        (df_ann["webcast_id"] == webcast)
        & (df_ann["annotator"].isin(preview_ann))
    ]
    if valid_only:
        dfH = dfH[dfH["curation"] == "valid"]

    spans_global = [
        (int(r["global_begin"]), int(r["global_end"]), r["annotation_id"])
        for _, r in dfH.iterrows()
    ]
    ann_id_to_annot = {
        r["annotation_id"]: r["annotator"] for _, r in dfH.iterrows()
    }

    # Legend entries for every annotator combination that actually overlaps.
    combo_rows = []
    if preview_ann and spans_global:
        intervals = compute_interval_segments(len(full_text), spans_global)
        seen_combos = set()
        for _, _, ann_ids in intervals:
            annotators_in_span = sorted(
                {ann_id_to_annot.get(a) for a in ann_ids if a in ann_id_to_annot}
            )
            if len(annotators_in_span) <= 1:
                continue
            combo_key = tuple(annotators_in_span)
            if combo_key in seen_combos:
                continue
            seen_combos.add(combo_key)
            bg_layers = ", ".join(
                f"linear-gradient({annot_color_map[a]} 0 0)"
                for a in annotators_in_span
                if a in annot_color_map
            )
            label = " + ".join(html.escape(str(a)) for a in combo_key)
            # NOTE(review): combination swatch markup reconstructed.
            combo_rows.append(
                "<div style='display:flex; align-items:center; gap:6px;'>"
                f"<span style='width:14px; height:14px; display:inline-block; "
                f"border:1px solid #999; background:{bg_layers};'></span>"
                f"<span>{label}</span>"
                "</div>"
            )

    color_map = {
        r["annotation_id"]: annot_color_map.get(
            r["annotator"], "rgba(0,0,0,0.25)"
        )
        for _, r in dfH.iterrows()
    }

    if legend_rows or combo_rows:
        # NOTE(review): legend container markup reconstructed.
        st.markdown(
            "<div style='display:flex; flex-wrap:wrap; gap:24px; "
            "margin-bottom:8px;'>"
            "<div>"
            "<div style='font-weight:600;'>Annotators</div>"
            + "".join(legend_rows)
            + (
                "</div><div>"
                "<div style='font-weight:600;'>Combinations</div>"
                + "".join(combo_rows)
                if combo_rows
                else ""
            )
            + "</div></div>",
            unsafe_allow_html=True,
        )

    meta_map = {
        r["annotation_id"]: {
            "label": r["label"],
            "annotator": r["annotator"],
            "curation": r["curation"],
        }
        for _, r in dfH.iterrows()
    }

    # Translate the heatmap cell selection into one annotation to focus.
    focus_ann_id = None
    selection_state = st.session_state.get(selection_key)
    if selection_state and not dfH.empty:
        sel_label = selection_state.get("label")
        sel_bin = selection_state.get("bin")
        if sel_label is not None and sel_bin is not None:
            bin_start = sel_bin * bin_size
            bin_end = bin_start + bin_size
            focus_ann_id = pick_focus_annotation(
                dfH, sel_label, bin_start, bin_end
            )

    # hearing preview: one block per speaker segment
    if is_hearing:
        rows = dfT.sort_values(["segment_id", "sequence_id"])
        html_blocks = []
        cursor = 0
        for seg_id, seg_rows in rows.groupby("segment_id"):
            speaker = (
                seg_rows.iloc[0].get("speaker_name")
                or seg_rows.iloc[0].get("speaker_role")
                or "Unknown"
            )
            pieces = []
            seg_start = cursor
            for _, r in seg_rows.iterrows():
                t = r["text"] or ""
                pieces.append(t)
                cursor += len(t) + 1
            segment_text = " ".join(pieces)
            seg_end = seg_start + len(segment_text)
            local_spans = project_spans_to_interval(spans_global, seg_start, seg_end)
            # NOTE(review): speaker heading markup reconstructed.
            html_blocks.append(
                "<div style='font-weight:600; margin:14px 0 4px;'>"
                f"{html.escape(str(speaker))}</div>"
            )
            html_blocks.append(
                render_highlighted_html(
                    segment_text,
                    local_spans,
                    color_map,
                    meta_map,
                    focus_ann_id=focus_ann_id,
                )
            )
            html_blocks.append("<br/>")
        st.markdown("".join(html_blocks), unsafe_allow_html=True)

    # judgment preview: one block per paragraph
    else:
        rows = dfT.sort_values("paragraph_id")
        html_blocks = []
        cursor = 0
        for _, row in rows.iterrows():
            ptext = row["text"] or ""
            seg_start = cursor
            seg_end = seg_start + len(ptext)
            local_spans = project_spans_to_interval(spans_global, seg_start, seg_end)
            cursor = seg_end + 1
            html_blocks.append(
                render_highlighted_html(
                    ptext,
                    local_spans,
                    color_map,
                    meta_map,
                    focus_ann_id=focus_ann_id,
                )
            )
            html_blocks.append("<br/>\n")
        st.markdown("".join(html_blocks), unsafe_allow_html=True)

    scroll_to_annotation(focus_ann_id)
    if go_heatmap:
        scroll_to_heatmap(heatmap_anchor)


# tabs
tab1, tab2 = st.tabs(["Hearings", "Judgments"])
with tab1:
    render_heatmap_tab(df_hear_ann, df_hear_text, "Hearing", "hear", is_hearing=True)
with tab2:
    render_heatmap_tab(df_judg_ann, df_judg_text, "Judgment", "judg", is_hearing=False)