# src/streamlit_app.py
import streamlit as st
import pandas as pd
import numpy as np
import altair as alt
from datasets import load_from_disk
from huggingface_hub import snapshot_download
import colorsys
import html
import os
import streamlit.components.v1 as components
# text utils
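# Fixed row order for the heatmap: argument labels grouped by prefix
# (FA = factual, FO = formal, SU = substantive), followed by issue frames
# (ISS-*) and justification frames (JUST-*).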
LABEL_ORDER = [
"FA - Factual Argument",
"FA - Factual Question",
"FO - Formal Question",
"FO - Precedent",
"FO - Systematic Interpretation",
"FO - Textual Interpretation",
"SU - Non-Legal Argument",
"SU - Proportionality Analysis",
"SU - Substantive Question",
"SU - Teleological or Purposive Interpretation",
"Negative Frame (ISS-N)",
"Positive Frame (ISS-P)",
"Crime Frame (JUST-C)",
"Health Frame (JUST-H)",
"National Security Frame (JUST-S)",
"Rights Frame (JUST-R)",
]
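# Global character offsets: rows are joined with a single space, so each
# piece advances the offset by len(text) + 1. Annotation global_begin /
# global_end values are expected to index into this concatenation.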
def concat_global_text(df, webcast_id, text_col="text"):
rows = df[df["webcast_id"] == webcast_id]
if "segment_id" in rows.columns:
rows = rows.sort_values(["segment_id", "sequence_id"])
elif "paragraph_id" in rows.columns:
rows = rows.sort_values("paragraph_id")
return " ".join(rows[text_col].fillna("").tolist())
def sanity_check(df_ann, text_len):
if df_ann.empty:
return True
return df_ann["global_end"].max() <= text_len
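# Pull data from a private HF dataset repo when HF_DATASET_REPO is set; the
# snapshot path is cached in session_state so the download happens once per
# session. Note: local_dir_use_symlinks is deprecated (and ignored) in
# recent huggingface_hub releases.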
def get_hf_dataset_root():
repo = os.getenv("HF_DATASET_REPO")
if not repo:
return None
cached_root = st.session_state.get("hf_dataset_root")
if cached_root:
return cached_root
token = os.getenv("HF_TOKEN")
if not token:
st.error("HF_TOKEN secret missing for private dataset access.")
st.stop()
cache_dir = os.path.join(os.getcwd(), ".hf_data_cache")
try:
snapshot_path = snapshot_download(
repo_id=repo,
repo_type="dataset",
token=token,
local_dir=cache_dir,
local_dir_use_symlinks=False,
)
except Exception as exc:
st.error(f"Failed to load dataset repo: {exc}")
st.stop()
st.session_state["hf_dataset_root"] = snapshot_path
return snapshot_path
def resolve_dataset_path(relative_path):
root = get_hf_dataset_root()
return os.path.join(root, relative_path) if root else relative_path
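# Compact character counts for the section guide,
# e.g. 50 -> "50", 400 -> "0.4k", 1500 -> "1.5k", 2000 -> "2k", 12345 -> "12k".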
def format_char_count(value):
n = int(value)
if n >= 10000:
return f"{int(round(n / 1000.0))}k"
if n >= 100:
scaled = round(n / 1000.0, 1)
text = f"{scaled:.1f}".rstrip("0").rstrip(".")
return f"{text}k"
return str(n)
def normalize_section_title(text):
return " ".join(str(text).split()) if text else "Unknown"
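# One section per speaker segment; start/end use the same single-space join
# arithmetic as concat_global_text so they line up with annotation offsets.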
def compute_hearing_sections(df_text):
if df_text is None or df_text.empty:
return []
rows = df_text.sort_values(["segment_id", "sequence_id"])
sections = []
cursor = 0
for _, seg_rows in rows.groupby("segment_id"):
speaker = (
seg_rows.iloc[0].get("speaker_name")
or seg_rows.iloc[0].get("speaker_role")
or "Unknown"
)
seg_start = cursor
pieces = []
for _, r in seg_rows.iterrows():
t = r["text"] or ""
pieces.append(t)
cursor += len(t) + 1
segment_text = " ".join(pieces)
seg_end = seg_start + len(segment_text)
sections.append({
"name": normalize_section_title(speaker),
"start": seg_start,
"end": seg_end,
})
return sections
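# Split a judgment into FACTS / LAW / separate-opinion sections by scanning
# paragraph text for headings. The scans keep the last match for "THE FACTS"
# and "THE LAW", so earlier incidental mentions (e.g. a table of contents)
# do not cut the sections short.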
def compute_judgment_sections(df_text):
if df_text is None or df_text.empty:
return []
rows = df_text.sort_values("paragraph_id")
paragraphs = []
cursor = 0
for _, row in rows.iterrows():
ptext = row["text"] or ""
start = cursor
end = start + len(ptext)
paragraphs.append({"text": ptext, "start": start, "end": end})
cursor = end + 1
if not paragraphs:
return []
facts_idx = None
for i, p in enumerate(paragraphs):
if "THE FACTS" in p["text"]:
facts_idx = i
if facts_idx is None:
return []
law_idx = None
for i in range(facts_idx + 1, len(paragraphs)):
if "THE LAW" in paragraphs[i]["text"]:
law_idx = i
if law_idx is None:
return []
opinion_indices = [
i for i in range(law_idx + 1, len(paragraphs))
if "OPINION" in paragraphs[i]["text"]
]
sections = []
facts_start = paragraphs[facts_idx]["start"]
law_start = paragraphs[law_idx]["start"]
facts_end = law_start
sections.append({
"name": normalize_section_title(paragraphs[facts_idx]["text"]),
"start": facts_start,
"end": facts_end,
})
law_end = (
paragraphs[opinion_indices[0]]["start"]
if opinion_indices
else paragraphs[-1]["end"]
)
sections.append({
"name": normalize_section_title(paragraphs[law_idx]["text"]),
"start": law_start,
"end": law_end,
})
for idx, op_idx in enumerate(opinion_indices):
start = paragraphs[op_idx]["start"]
end = (
paragraphs[opinion_indices[idx + 1]]["start"]
if idx + 1 < len(opinion_indices)
else paragraphs[-1]["end"]
)
sections.append({
"name": normalize_section_title(paragraphs[op_idx]["text"]),
"start": start,
"end": end,
})
return sections
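# Renders section names with start/end offsets either as a compact N-column
# grid (compact_columns set) or as stacked rows.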
def render_section_guide(sections, compact_columns=None):
st.markdown("### Section Guide")
if not sections:
st.info("No section guide could be generated for this document.")
return
if compact_columns:
cells = []
for section in sections:
name = html.escape(str(section["name"]))
start = format_char_count(section["start"])
end = format_char_count(section["end"])
cells.append(
"<div class='section-guide__cell'>"
f"<div class='section-guide__name'>{name}</div>"
f"<div class='section-guide__range'>{start} - {end}</div>"
"</div>"
)
st.markdown(
"<style>"
".section-guide{border:1px solid #e6e6e6;border-radius:6px;"
"padding:8px 10px;margin:6px 0 12px;}"
".section-guide__grid{display:grid;gap:8px;}"
".section-guide__cell{padding:6px 8px;border:1px dashed #eee;"
"border-radius:6px;background:#fafafa;}"
".section-guide__name{font-weight:600;font-size:12px;color:#222;}"
".section-guide__range{font-family:monospace;font-size:11px;"
"color:#333;white-space:nowrap;margin-top:2px;}"
"</style>"
f"<div class='section-guide section-guide__grid' "
f"style='grid-template-columns:repeat({int(compact_columns)}, minmax(0, 1fr));'>"
+ "".join(cells)
+ "</div>",
unsafe_allow_html=True,
)
else:
rows = []
for section in sections:
name = html.escape(str(section["name"]))
start = format_char_count(section["start"])
end = format_char_count(section["end"])
rows.append(
"<div class='section-guide__row'>"
f"<span class='section-guide__name'>{name}</span>"
f"<span class='section-guide__range'>{start} - {end}</span>"
"</div>"
)
st.markdown(
"<style>"
".section-guide{border:1px solid #e6e6e6;border-radius:6px;"
"padding:8px 10px;margin:6px 0 12px;}"
".section-guide__row{display:flex;justify-content:space-between;"
"gap:12px;align-items:baseline;padding:4px 0;"
"border-bottom:1px dashed #eee;}"
".section-guide__row:last-child{border-bottom:none;}"
".section-guide__name{font-weight:600;font-size:13px;color:#222;"
"flex:1 1 auto;}"
".section-guide__range{font-family:monospace;font-size:12px;"
"color:#333;white-space:nowrap;}"
"</style>"
"<div class='section-guide'>"
+ "".join(rows)
+ "</div>",
unsafe_allow_html=True,
)
# span -> bin coverage
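# For each (label, bin) pair, coverage_ratio = annotated characters in the
# bin / bin size. Overlapping spans of the same label can push the sum past
# the bin size, hence the clip to [0, 1].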
def bin_spans_into_brackets(df_ann, text_len, bin_size):
if df_ann.empty:
return pd.DataFrame()
records = []
for _, row in df_ann.iterrows():
s = int(row["global_begin"])
e = int(min(row["global_end"], text_len))
if s >= e:
continue
start_bin = s // bin_size
end_bin = e // bin_size
for b in range(start_bin, end_bin + 1):
bin_start = b * bin_size
bin_end = min((b + 1) * bin_size, text_len)
overlap_start = max(s, bin_start)
overlap_end = min(e, bin_end)
if overlap_start < overlap_end:
overlap_len = overlap_end - overlap_start
records.append({
"label": row["label"],
"bin": b,
"overlap_len": overlap_len,
"bin_size": bin_size,
})
if not records:
return pd.DataFrame()
df = pd.DataFrame(records)
df = (
df.groupby(["label", "bin"], as_index=False)
.agg({"overlap_len": "sum", "bin_size": "first"})
)
df["coverage_ratio"] = (df["overlap_len"] / df["bin_size"]).clip(0, 1)
return df
# matrix style heatmap
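# One row per label, x axis in character offsets. The point selection
# ("heatmap_select") is what extract_heatmap_selection later digs out of
# the st.altair_chart event payload.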
def make_matrix_style_heatmap(df_heat, bin_size, text_len, color="#1f6aff"):
    if df_heat.empty:
        # mark_text draws one mark per data row, so a placeholder row is
        # needed for the "No annotations" message to actually appear.
        return (
            alt.Chart(pd.DataFrame({"message": ["No annotations"]}))
            .mark_text(size=14)
            .encode(text="message:N")
        )
df = df_heat.copy()
df["bin_start"] = df["bin"] * bin_size
df["bin_end"] = df["bin_start"] + bin_size
heatmap_select = alt.selection_point(
fields=["label", "bin"],
on="click",
clear="dblclick",
name="heatmap_select",
)
chart = (
alt.Chart(df)
.mark_rect()
.encode(
x=alt.X("bin_start:Q",
title="Character Bracket",
axis=alt.Axis(format="~s")),
x2="bin_end:Q",
y=alt.Y(
"label:N",
title="Argument Type",
sort=alt.SortArray(LABEL_ORDER),
axis=alt.Axis(labelLimit=0, labelPadding=8),
),
color=alt.Color(
"coverage_ratio:Q",
title="% of bin covered",
scale=alt.Scale(
domain=[0, 1],
range=["#ffffff", color]
)
),
tooltip=[
alt.Tooltip("label:N", title="Argument"),
alt.Tooltip("bin_start:Q", title="Bin start", format=","),
alt.Tooltip("bin_end:Q", title="Bin end", format=","),
alt.Tooltip("coverage_ratio:Q", title="Coverage", format=".0%")
],
)
.add_params(heatmap_select)
.properties(
width=1200,
height=40 * df["label"].nunique(),
)
)
return chart
# highlighting utils
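# Evenly spaced HLS hues rendered as translucent rgba, so overlapping
# annotator highlights stay readable when blended.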
def generate_color_palette(n):
colors = []
for i in range(n):
hue = i / max(1, n)
r, g, b = colorsys.hls_to_rgb(hue, 0.6, 0.8)
colors.append(f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, 0.35)")
return colors
def make_annotator_color_map(annotators):
colors = generate_color_palette(len(annotators))
return {a: c for a, c in zip(annotators, colors)}
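# Boundary sweep: collect all span endpoints, then emit maximal intervals
# over which the set of active annotation ids is constant.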
def compute_interval_segments(text_len, spans):
boundaries = {0, text_len}
for s, e, _ in spans:
boundaries.add(int(s))
boundaries.add(int(e))
cuts = sorted(b for b in boundaries if 0 <= b <= text_len)
intervals = []
for i in range(len(cuts) - 1):
s, e = cuts[i], cuts[i+1]
if s >= e:
continue
active = [span for span in spans if span[0] < e and span[1] > s]
intervals.append((s, e, [a[2] for a in active]))
return intervals
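# Each interval gets one stacked linear-gradient layer per active annotation
# (multiply-blended), a tooltip listing label/annotator/curation, and, for
# the focused annotation, an "ann-<id>" anchor for scroll_to_annotation.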
def render_highlighted_html(text, spans, color_map, meta_map, focus_ann_id=None):
if not spans:
return f"<pre>{html.escape(text)}</pre>"
intervals = compute_interval_segments(len(text), spans)
out = []
anchored = False
for s, e, ann_ids in intervals:
chunk = html.escape(text[s:e])
if ann_ids:
bg_layers = ", ".join(
f"linear-gradient({color_map[a]} 0 0)" for a in ann_ids
)
tooltip_lines = []
for a in ann_ids:
m = meta_map[a]
                    tooltip_lines.append(
                        f"{m['label']} · {m['annotator']} ({m['curation']})"
                    )
title_attr = html.escape("\n".join(tooltip_lines))
is_focus = focus_ann_id in ann_ids if focus_ann_id else False
anchor_attr = ""
if is_focus and not anchored:
anchor_attr = f' id="ann-{focus_ann_id}"'
anchored = True
focus_style = "box-shadow: inset 0 0 0 2px #111;" if is_focus else ""
chunk = (
f'<span title="{title_attr}"{anchor_attr} '
f'style="background:{bg_layers};'
f'background-blend-mode:multiply;{focus_style}">'
f'{chunk}</span>'
)
out.append(chunk)
return "<pre style='line-height:1.5'>" + "".join(out) + "</pre>"
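# The shape of the on_select payload varies across Streamlit/Vega versions,
# so search the selection object recursively for a dict carrying both
# "label" and "bin".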
def extract_heatmap_selection(event):
if event is None:
return None
selection = getattr(event, "selection", None)
if selection is None and isinstance(event, dict):
selection = event.get("selection")
def pull_fields(sel):
if sel is None:
return None
if isinstance(sel, dict):
if "label" in sel and "bin" in sel:
return sel
if "values" in sel:
return pull_fields(sel.get("values"))
for value in sel.values():
extracted = pull_fields(value)
if extracted:
return extracted
if isinstance(sel, list) and sel:
return pull_fields(sel[0])
return None
return pull_fields(selection)
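# Clip global spans to [seg_start, seg_end) and shift them to segment-local
# offsets for per-segment rendering.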
def project_spans_to_interval(spans_global, seg_start, seg_end):
projected = []
for g_start, g_end, ann_id in spans_global:
if g_end <= seg_start or g_start >= seg_end:
continue
local_start = max(g_start, seg_start) - seg_start
local_end = min(g_end, seg_end) - seg_start
if local_start < local_end:
projected.append((local_start, local_end, ann_id))
return projected
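# Among annotations of the clicked label that overlap the clicked bin, focus
# the one with the largest overlap (ties broken by earliest start).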
def pick_focus_annotation(df_ann, label, bin_start, bin_end):
if df_ann.empty:
return None
df_sel = df_ann[
(df_ann["label"] == label)
& (df_ann["global_begin"] < bin_end)
& (df_ann["global_end"] > bin_start)
]
if df_sel.empty:
return None
overlaps = (
df_sel.assign(
overlap=lambda d: (
np.minimum(d["global_end"], bin_end)
- np.maximum(d["global_begin"], bin_start)
)
)
.sort_values(["overlap", "global_begin"], ascending=[False, True])
)
return overlaps.iloc[0]["annotation_id"]
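# The two scroll helpers inject a small script that looks the anchor up in
# the parent document (components run inside an iframe) and retries once
# after 150 ms in case the target has not rendered yet.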
def scroll_to_annotation(focus_ann_id):
if focus_ann_id is None:
return
components.html(
"<script>"
"const targetId = 'ann-" + str(focus_ann_id) + "';"
"const tryScroll = () => {"
" const el = window.parent.document.getElementById(targetId);"
" if (el) {"
" el.scrollIntoView({behavior: 'smooth', block: 'center'});"
" return true;"
" }"
" return false;"
"};"
"if (!tryScroll()) {"
" setTimeout(tryScroll, 150);"
"}"
"</script>",
height=0,
)
def scroll_to_heatmap(anchor_id):
if not anchor_id:
return
components.html(
"<script>"
"const targetId = '" + str(anchor_id) + "';"
"const tryScroll = () => {"
" const el = window.parent.document.getElementById(targetId);"
" if (el) {"
" el.scrollIntoView({behavior: 'smooth', block: 'start'});"
" return true;"
" }"
" return false;"
"};"
"if (!tryScroll()) {"
" setTimeout(tryScroll, 150);"
"}"
"</script>",
height=0,
)
def render_floating_heatmap_button(anchor_id, button_id):
if not anchor_id:
return
    components.html(
        "<script>"
        "const btnId = '" + str(button_id) + "';"
        "const anchorId = '" + str(anchor_id) + "';"
        # Create the button in the parent document: a fixed-position element
        # inside this zero-height component iframe would be clipped from view.
        "const doc = window.parent.document;"
        "let btn = doc.getElementById(btnId);"
        "if (!btn) {"
        " btn = doc.createElement('button');"
        " btn.id = btnId;"
        " btn.textContent = 'Back to heatmap';"
        " btn.style.position = 'fixed';"
        " btn.style.right = '16px';"
        " btn.style.bottom = '16px';"
        " btn.style.zIndex = '2147483647';"
        " btn.style.padding = '8px 12px';"
        " btn.style.border = '1px solid #ccc';"
        " btn.style.borderRadius = '8px';"
        " btn.style.background = '#fff';"
        " btn.style.color = '#222';"
        " btn.style.boxShadow = '0 2px 6px rgba(0,0,0,0.12)';"
        " btn.style.cursor = 'pointer';"
        " btn.style.transform = 'none';"
        " btn.style.margin = '0';"
        " btn.style.pointerEvents = 'auto';"
        " doc.body.appendChild(btn);"
        "}"
        # Reassign the handler on every rerun so the anchor id stays current.
        "btn.onclick = () => {"
        " const el = doc.getElementById(anchorId);"
        " if (el) {"
        "  el.scrollIntoView({behavior: 'smooth', block: 'start'});"
        " }"
        "};"
        "</script>",
        height=0,
    )
# streamlit UI
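# Environment configuration: APP_PASSWORD (optional access gate),
# HF_DATASET_REPO + HF_TOKEN (optional private dataset source).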
st.set_page_config(page_title="Argument Heatmap Explorer", layout="wide")
st.title("Argument Saturation Heatmap")
app_password = os.getenv("APP_PASSWORD")
if app_password:
if not st.session_state.get("auth_ok"):
with st.sidebar:
st.markdown("### Access")
pw = st.text_input("Password", type="password")
if pw:
if pw == app_password:
st.session_state["auth_ok"] = True
else:
st.session_state["auth_ok"] = False
st.error("Incorrect password.")
if not st.session_state.get("auth_ok"):
st.stop()
st.caption("Rows = argument types · Columns = character bins · Color = % coverage")
st.markdown(
"<style>"
"[data-testid='stSidebar']{position:fixed;top:0;left:0;height:100vh;}"
"[data-testid='stSidebar'] > div:first-child{height:100vh;overflow:auto;}"
"</style>",
unsafe_allow_html=True,
)
components.html(
    "<script>"
    # Key presses land on the parent page, not inside this zero-height
    # iframe, so the listener (and its dedup flag) must live on the parent.
    "const pwin = window.parent;"
    "if (!pwin._backspaceScrollBound) {"
    " pwin._backspaceScrollBound = true;"
    " pwin.document.addEventListener('keydown', (e) => {"
    "  const tag = (e.target && e.target.tagName) || '';"
    "  const isInput = tag === 'INPUT' || tag === 'TEXTAREA' || (e.target && e.target.isContentEditable);"
    "  if (!isInput && e.key === 'Backspace') {"
    "   e.preventDefault();"
    "   pwin.scrollTo({top: 0, behavior: 'smooth'});"
    "  }"
    " });"
    "}"
    "</script>",
    height=0,
)
components.html(
"<script>"
"const lockSidebar = () => {"
" const doc = window.parent.document;"
" const sidebar = doc.querySelector('[data-testid=\"stSidebar\"], .stSidebar');"
" if (!sidebar) return false;"
" sidebar.style.position = 'fixed';"
" sidebar.style.top = '0';"
" sidebar.style.left = '0';"
" sidebar.style.height = '100vh';"
" sidebar.style.zIndex = '999';"
" const inner = sidebar.querySelector('div');"
" if (inner) {"
" inner.style.height = '100vh';"
" inner.style.overflow = 'auto';"
" }"
" const main = doc.querySelector('[data-testid=\"stAppViewContainer\"], .main');"
" if (main) {"
" const w = sidebar.getBoundingClientRect().width;"
" main.style.marginLeft = `${w}px`;"
" }"
" return true;"
"};"
"if (!lockSidebar()) {"
" setTimeout(lockSidebar, 200);"
" setTimeout(lockSidebar, 800);"
"}"
"</script>",
height=0,
)
# sidebar
st.sidebar.header("Load Data")
hearings_ds_path = st.sidebar.text_input(
"Hearings dataset path",
"la_cour_dataset_hearings"
)
judgments_ds_path = st.sidebar.text_input(
"Judgments dataset path",
"la_cour_dataset_judgments"
)
# default CSV locations
default_hear_csv = resolve_dataset_path("la_cour_hearings_annotations.csv")
default_judg_csv = resolve_dataset_path("la_cour_judgments_annotations.csv")
st.sidebar.markdown("#### Annotation CSVs")
hear_ann_upload = st.sidebar.file_uploader(
"Hearing annotations CSV",
type="csv",
key="hear_csv_upload"
)
judg_ann_upload = st.sidebar.file_uploader(
"Judgment annotations CSV",
type="csv",
key="judg_csv_upload"
)
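# Prefer an uploaded CSV; otherwise fall back to the default path
# (a local file or the downloaded HF snapshot).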
def load_csv_or_default(upload_file, default_path):
if upload_file:
return pd.read_csv(upload_file), f"(uploaded) {upload_file.name}"
if os.path.exists(default_path):
return pd.read_csv(default_path), f"(default) {default_path}"
return None, "(missing)"
df_hear_ann, hear_status = load_csv_or_default(hear_ann_upload, default_hear_csv)
df_judg_ann, judg_status = load_csv_or_default(judg_ann_upload, default_judg_csv)
st.sidebar.caption(f"Hearing CSV: {hear_status}")
st.sidebar.caption(f"Judgment CSV: {judg_status}")
bin_size = st.sidebar.slider(
"Characters per bin",
min_value=50,
max_value=3000,
value=400,
step=50,
)
heat_color = st.sidebar.color_picker("Heatmap color", value="#1f6aff")
go_heatmap = st.sidebar.button("Back to heatmap")
if df_hear_ann is None and df_judg_ann is None:
    st.info("No annotations loaded. Upload a CSV or place the default files in the working directory.")
st.stop()
# load datasets lazily
ds_hear = (
load_from_disk(resolve_dataset_path(hearings_ds_path))
if df_hear_ann is not None
else None
)
ds_judg = (
load_from_disk(resolve_dataset_path(judgments_ds_path))
if df_judg_ann is not None
else None
)
df_hear_text = ds_hear.to_pandas() if ds_hear is not None else None
df_judg_text = ds_judg.to_pandas() if ds_judg is not None else None
# tab renderer
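# Expected annotation columns: webcast_id, label, annotator, curation,
# annotation_id, global_begin, global_end. Text datasets need webcast_id,
# text, and either segment_id + sequence_id (hearings, plus speaker_name /
# speaker_role) or paragraph_id (judgments).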
def render_heatmap_tab(df_ann, df_text, title, key, is_hearing):
if df_ann is None:
st.warning(f"No {title.lower()} annotations loaded.")
return
st.subheader(f"{title} Heatmap")
webcast_ids = sorted(df_ann["webcast_id"].unique())
webcast = st.selectbox("Select document", webcast_ids, key=f"wc_{key}")
dfA = df_ann[df_ann["webcast_id"] == webcast]
dfT = df_text[df_text["webcast_id"] == webcast]
labels = sorted(dfA["label"].dropna().unique())
annotators = sorted(dfA["annotator"].dropna().unique())
c1, c2, c3, c4 = st.columns(4)
sel_labels = c1.multiselect("Argument types", labels, default=labels, key=f"lbl_{key}")
sel_ann = c2.multiselect("Annotators (heatmap)", annotators, default=annotators, key=f"ann_{key}")
valid_only = c3.checkbox("Valid only", value=True, key=f"valid_{key}")
preview_ann = c4.multiselect(
"Annotators (preview)",
annotators,
default=annotators,
key=f"hl_{key}",
)
dfA = dfA[dfA["label"].isin(sel_labels) & dfA["annotator"].isin(sel_ann)]
if valid_only:
dfA = dfA[dfA["curation"] == "valid"]
full_text = concat_global_text(dfT, webcast)
if not sanity_check(dfA, len(full_text)):
st.error("Annotation spans exceed text length.")
return
df_heat = bin_spans_into_brackets(dfA, len(full_text), bin_size)
st.markdown("### Heatmap")
heatmap_anchor = f"heatmap-anchor-{key}"
st.markdown(
f"<div id='{heatmap_anchor}'></div>",
unsafe_allow_html=True,
)
render_floating_heatmap_button(heatmap_anchor, f"heatmap-btn-{key}")
heatmap_chart = make_matrix_style_heatmap(
df_heat, bin_size, len(full_text), color=heat_color
)
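    # on_select needs a newer Streamlit; older versions raise TypeError,
    # in which case fall back to a non-interactive chart.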
try:
heatmap_event = st.altair_chart(
heatmap_chart,
use_container_width=True,
on_select="rerun",
)
except TypeError:
heatmap_event = None
st.altair_chart(heatmap_chart, use_container_width=True)
selected = extract_heatmap_selection(heatmap_event)
selection_key = f"heat_sel_{key}"
if selected:
try:
st.session_state[selection_key] = {
"label": selected["label"],
"bin": int(selected["bin"]),
}
except (TypeError, ValueError):
pass
elif heatmap_event is not None:
st.session_state.pop(selection_key, None)
sections = (
compute_hearing_sections(dfT)
if is_hearing
else compute_judgment_sections(dfT)
)
render_section_guide(sections, compact_columns=10 if is_hearing else None)
# highlighted text preview
st.markdown("### Highlighted Text Preview")
if preview_ann:
annot_color_map = make_annotator_color_map(preview_ann)
legend_rows = []
for annot in preview_ann:
color = annot_color_map.get(annot, "rgba(0,0,0,0.25)")
legend_rows.append(
f"<div class='annotator-legend__row'>"
f"<span class='annotator-legend__swatch' "
f"style='background:{color}'></span>"
f"<span class='annotator-legend__label'>"
f"{html.escape(str(annot))}</span></div>"
)
else:
annot_color_map = {}
legend_rows = []
dfH = df_ann[
(df_ann["webcast_id"] == webcast)
& (df_ann["annotator"].isin(preview_ann))
]
if valid_only:
dfH = dfH[dfH["curation"] == "valid"]
spans_global = [
(int(r["global_begin"]), int(r["global_end"]), r["annotation_id"])
for _, r in dfH.iterrows()
]
ann_id_to_annot = {
r["annotation_id"]: r["annotator"]
for _, r in dfH.iterrows()
}
combo_rows = []
if preview_ann and spans_global:
intervals = compute_interval_segments(len(full_text), spans_global)
seen_combos = set()
for _, _, ann_ids in intervals:
annotators_in_span = sorted(
{ann_id_to_annot.get(a) for a in ann_ids if a in ann_id_to_annot}
)
if len(annotators_in_span) <= 1:
continue
combo_key = tuple(annotators_in_span)
if combo_key in seen_combos:
continue
seen_combos.add(combo_key)
bg_layers = ", ".join(
f"linear-gradient({annot_color_map[a]} 0 0)"
for a in annotators_in_span
if a in annot_color_map
)
label = " + ".join(html.escape(str(a)) for a in combo_key)
combo_rows.append(
f"<div class='annotator-legend__row'>"
f"<span class='annotator-legend__swatch' "
f"style='background:{bg_layers};"
f"background-blend-mode:multiply;'></span>"
f"<span class='annotator-legend__label'>{label}</span></div>"
)
color_map = {
r["annotation_id"]: annot_color_map.get(
r["annotator"], "rgba(0,0,0,0.25)"
)
for _, r in dfH.iterrows()
}
if legend_rows or combo_rows:
st.markdown(
"<style>"
".annotator-legend{position:sticky;top:0;background:white;"
"padding:6px 8px;border:1px solid #e6e6e6;border-radius:6px;"
"z-index:10;margin:6px 0 12px 0;display:inline-block;}"
".annotator-legend__row{display:flex;align-items:center;"
"gap:8px;margin:2px 0;}"
".annotator-legend__swatch{width:16px;height:16px;"
"border-radius:3px;display:inline-block;}"
".annotator-legend__label{font-size:12px;color:#222;}"
".annotator-legend__section{font-size:11px;margin:4px 0 2px;"
"color:#666;}"
"</style>"
"<div class='annotator-legend'>"
"<div class='annotator-legend__section'>Annotators</div>"
+ "".join(legend_rows)
+ (
"<div class='annotator-legend__section'>Combinations</div>"
+ "".join(combo_rows)
if combo_rows else ""
)
+ "</div>",
unsafe_allow_html=True,
)
meta_map = {
r["annotation_id"]: {
"label": r["label"],
"annotator": r["annotator"],
"curation": r["curation"]
}
for _, r in dfH.iterrows()
}
focus_ann_id = None
selection_state = st.session_state.get(selection_key)
if selection_state and not dfH.empty:
sel_label = selection_state.get("label")
sel_bin = selection_state.get("bin")
if sel_label is not None and sel_bin is not None:
bin_start = sel_bin * bin_size
bin_end = bin_start + bin_size
focus_ann_id = pick_focus_annotation(
dfH, sel_label, bin_start, bin_end
)
# hearing preview
if is_hearing:
rows = dfT.sort_values(["segment_id", "sequence_id"])
html_blocks = []
cursor = 0
        for _, seg_rows in rows.groupby("segment_id"):
speaker = (
seg_rows.iloc[0].get("speaker_name")
or seg_rows.iloc[0].get("speaker_role")
or "Unknown"
)
pieces = []
seg_start = cursor
for _, r in seg_rows.iterrows():
t = r["text"] or ""
pieces.append(t)
cursor += len(t) + 1
segment_text = " ".join(pieces)
seg_end = seg_start + len(segment_text)
local_spans = project_spans_to_interval(spans_global, seg_start, seg_end)
html_blocks.append(f"<b>{html.escape(str(speaker))}</b><br>")
html_blocks.append(
render_highlighted_html(
segment_text,
local_spans,
color_map,
meta_map,
focus_ann_id=focus_ann_id,
)
)
html_blocks.append("<br>")
st.markdown("".join(html_blocks), unsafe_allow_html=True)
# judgment preview
else:
rows = dfT.sort_values("paragraph_id")
html_blocks = []
cursor = 0
for _, row in rows.iterrows():
ptext = row["text"] or ""
seg_start = cursor
seg_end = seg_start + len(ptext)
local_spans = project_spans_to_interval(spans_global, seg_start, seg_end)
cursor = seg_end + 1
html_blocks.append(
render_highlighted_html(
ptext,
local_spans,
color_map,
meta_map,
focus_ann_id=focus_ann_id,
)
)
html_blocks.append("<br>\n")
st.markdown("".join(html_blocks), unsafe_allow_html=True)
scroll_to_annotation(focus_ann_id)
if go_heatmap:
scroll_to_heatmap(heatmap_anchor)
# tabs
tab1, tab2 = st.tabs(["Hearings", "Judgments"])
with tab1:
render_heatmap_tab(df_hear_ann, df_hear_text, "Hearing", "hear", is_hearing=True)
with tab2:
render_heatmap_tab(df_judg_ann, df_judg_text, "Judgment", "judg", is_hearing=False)