yassine-thlija committed on
Commit
63e7b1f
·
1 Parent(s): 80a6e06

init public frames

Browse files
Files changed (2) hide show
  1. requirements.txt +4 -1
  2. src/streamlit_app.py +1076 -38
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  altair
2
  pandas
3
- streamlit
 
 
 
 
1
  altair
2
  pandas
3
+ streamlit
4
+ numpy
5
+ pandas
6
+ datasets
src/streamlit_app.py CHANGED
@@ -1,40 +1,1078 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import altair as alt
5
+ from datasets import load_from_disk
6
+ from huggingface_hub import snapshot_download
7
+ import colorsys
8
+ import html
9
+ import os
10
+ import streamlit.components.v1 as components
11
+
12
+ # text utils
13
+
14
# Fixed ordering of annotation labels; used to sort the heatmap's y-axis.
LABEL_ORDER = [
    "FA - Factual Argument",
    "FA - Factual Question",
    "FO - Formal Question",
    "FO - Precedent",
    "FO - Systematic Interpretation",
    "FO - Textual Interpretation",
    "SU - Non-Legal Argument",
    "SU - Proportionality Analysis",
    "SU - Substantive Question",
    "SU - Teleological or Purposive Interpretation",
    "Negative Frame (ISS-N)",
    "Positive Frame (ISS-P)",
    "Crime Frame (JUST-C)",
    "Health Frame (JUST-H)",
    "National Security Frame (JUST-S)",
    "Rights Frame (JUST-R)",
]
32
+
33
def concat_global_text(df, webcast_id, text_col="text"):
    """Join all text rows of one document into a single space-separated string.

    Rows are filtered on ``webcast_id`` and ordered by
    ``(segment_id, sequence_id)`` when a ``segment_id`` column exists,
    otherwise by ``paragraph_id`` when that column exists. Missing text cells
    contribute an empty string, so they still add a separator.

    Args:
        df: DataFrame with a ``webcast_id`` column and the ``text_col`` column.
        webcast_id: Identifier of the document to assemble.
        text_col: Name of the column holding the text pieces.

    Returns:
        The concatenated document text (``""`` when no row matches).
    """
    rows = df[df["webcast_id"] == webcast_id]

    if "segment_id" in rows.columns:
        rows = rows.sort_values(["segment_id", "sequence_id"])
    elif "paragraph_id" in rows.columns:
        rows = rows.sort_values("paragraph_id")

    # Cast to str so a stray non-string cell cannot make str.join raise.
    return " ".join(rows[text_col].fillna("").astype(str))
42
+
43
+
44
def sanity_check(df_ann, text_len):
    """Return True when no annotation ends beyond the document text.

    Args:
        df_ann: Annotation DataFrame with a ``global_end`` column.
        text_len: Length of the concatenated document text in characters.

    Returns:
        ``True`` for an empty frame or when the largest ``global_end`` is at
        most ``text_len``; ``False`` otherwise. Always a plain ``bool``.
    """
    if df_ann.empty:
        return True
    # bool() converts the numpy boolean produced by the comparison.
    return bool(df_ann["global_end"].max() <= text_len)
48
+
49
+
50
def get_hf_dataset_root():
    """Return the local directory of the configured Hugging Face dataset repo.

    The repo id is read from the ``HF_DATASET_REPO`` environment variable.
    When it is unset the function returns ``None``. A previously resolved path
    stored in ``st.session_state["hf_dataset_root"]`` is reused. Otherwise the
    repo is downloaded into ``.hf_data_cache`` under the current working
    directory using the ``HF_TOKEN`` environment variable.

    On a missing token or a failed download an error is shown with
    ``st.error`` and the script run is halted via ``st.stop()``.

    Returns:
        The snapshot directory path, or ``None`` when no repo is configured.
    """
    repo = os.getenv("HF_DATASET_REPO")
    if not repo:
        return None

    cached_root = st.session_state.get("hf_dataset_root")
    if cached_root:
        return cached_root

    token = os.getenv("HF_TOKEN")
    if not token:
        st.error("HF_TOKEN secret missing for private dataset access.")
        st.stop()

    cache_dir = os.path.join(os.getcwd(), ".hf_data_cache")
    try:
        snapshot_path = snapshot_download(
            repo_id=repo,
            repo_type="dataset",
            token=token,
            local_dir=cache_dir,
            local_dir_use_symlinks=False,
        )
    except Exception as exc:
        # UI boundary: report any download failure to the user and stop.
        st.error(f"Failed to load dataset repo: {exc}")
        st.stop()

    st.session_state["hf_dataset_root"] = snapshot_path
    return snapshot_path
75
+
76
+
77
def resolve_dataset_path(relative_path):
    """Resolve ``relative_path`` against the downloaded dataset root, if any.

    Returns ``relative_path`` unchanged when ``get_hf_dataset_root()`` yields
    no root (no dataset repo configured).
    """
    root = get_hf_dataset_root()
    if not root:
        return relative_path
    return os.path.join(root, relative_path)
80
+
81
+
82
def format_char_count(value):
    """Format a character offset compactly for display.

    Values of 10000 and above are shown as whole thousands (``"12k"``),
    values from 100 up to 9999 as thousands with at most one decimal
    (``"1.5k"``, ``"1k"``, ``"0.1k"``), and smaller values as the plain
    integer.

    Args:
        value: Anything accepted by ``int()``.

    Returns:
        The formatted string.
    """
    n = int(value)
    if n >= 10000:
        return f"{int(round(n / 1000.0))}k"
    if n >= 100:
        scaled = round(n / 1000.0, 1)
        # Drop a trailing ".0" so 1000 reads "1k" rather than "1.0k".
        text = f"{scaled:.1f}".rstrip("0").rstrip(".")
        return f"{text}k"
    return str(n)
91
+
92
+
93
def normalize_section_title(text):
    """Collapse whitespace in a section title, falling back to ``"Unknown"``.

    Falsy values, float NaN (how pandas represents a missing cell) and
    whitespace-only strings all yield ``"Unknown"``.

    Args:
        text: Title candidate; converted with ``str()`` when present.

    Returns:
        The title with runs of whitespace replaced by single spaces.
    """
    import math

    if not text:
        return "Unknown"
    # NaN is truthy, so it needs an explicit check to avoid the title "nan".
    if isinstance(text, float) and math.isnan(text):
        return "Unknown"
    return " ".join(str(text).split()) or "Unknown"
95
+
96
+
97
def compute_hearing_sections(df_text):
    """Compute one section per hearing segment with global character offsets.

    Rows are ordered by ``(segment_id, sequence_id)`` and grouped by
    ``segment_id``. Offsets assume the document text is the space-joined
    concatenation of all rows in that order, so each row advances the cursor
    by its length plus one separator.

    Args:
        df_text: Text rows of a single hearing with ``segment_id``,
            ``sequence_id`` and ``text`` columns; ``speaker_name`` and
            ``speaker_role`` are read when present. May be ``None``.

    Returns:
        A list of ``{"name", "start", "end"}`` dicts, one per segment, where
        ``name`` is the first row's speaker name, else its role, else
        ``"Unknown"``. Empty list for ``None`` or an empty frame.
    """
    if df_text is None or df_text.empty:
        return []

    def _first_present(*values):
        # Skip None/NaN/empty so a missing speaker_name falls through to role.
        for value in values:
            if value is not None and not pd.isna(value) and value != "":
                return value
        return "Unknown"

    rows = df_text.sort_values(["segment_id", "sequence_id"])
    sections = []
    cursor = 0

    for _, seg_rows in rows.groupby("segment_id"):
        first = seg_rows.iloc[0]
        speaker = _first_present(
            first.get("speaker_name"), first.get("speaker_role")
        )

        # Missing text counts as "", matching fillna("") in the global concat.
        segment_text = " ".join(seg_rows["text"].fillna("").tolist())
        seg_start = cursor
        seg_end = seg_start + len(segment_text)
        # One extra character for the separator that follows this segment.
        cursor = seg_end + 1

        sections.append({
            "name": normalize_section_title(speaker),
            "start": seg_start,
            "end": seg_end,
        })

    return sections
129
+
130
+
131
def compute_judgment_sections(df_text):
    """Locate the FACTS / LAW / OPINION sections of a judgment by offsets.

    Paragraphs are ordered by ``paragraph_id`` and assigned global character
    offsets as if joined with single spaces. The facts section starts at the
    last paragraph containing ``"THE FACTS"``; the law section starts at the
    last later paragraph containing ``"THE LAW"``; every paragraph after that
    containing ``"OPINION"`` starts an opinion section.

    Args:
        df_text: Paragraph rows of a single judgment with ``paragraph_id``
            and ``text`` columns. May be ``None``.

    Returns:
        A list of ``{"name", "start", "end"}`` dicts in document order, or an
        empty list when the input is empty or either heading is not found.
    """
    if df_text is None or df_text.empty:
        return []

    # Missing text counts as "", matching fillna("") in the global concat.
    texts = df_text.sort_values("paragraph_id")["text"].fillna("").tolist()

    paragraphs = []
    cursor = 0
    for ptext in texts:
        end = cursor + len(ptext)
        paragraphs.append({"text": ptext, "start": cursor, "end": end})
        cursor = end + 1  # +1 for the joining space

    def _last_index_containing(marker, first):
        # The search deliberately keeps the LAST match, not the first.
        found = None
        for i in range(first, len(paragraphs)):
            if marker in paragraphs[i]["text"]:
                found = i
        return found

    facts_idx = _last_index_containing("THE FACTS", 0)
    if facts_idx is None:
        return []

    law_idx = _last_index_containing("THE LAW", facts_idx + 1)
    if law_idx is None:
        return []

    opinion_indices = [
        i for i in range(law_idx + 1, len(paragraphs))
        if "OPINION" in paragraphs[i]["text"]
    ]
    doc_end = paragraphs[-1]["end"]
    law_start = paragraphs[law_idx]["start"]

    sections = [
        {
            "name": normalize_section_title(paragraphs[facts_idx]["text"]),
            "start": paragraphs[facts_idx]["start"],
            "end": law_start,
        },
        {
            "name": normalize_section_title(paragraphs[law_idx]["text"]),
            "start": law_start,
            # The law section runs up to the first opinion, if there is one.
            "end": (
                paragraphs[opinion_indices[0]]["start"]
                if opinion_indices
                else doc_end
            ),
        },
    ]

    for pos, op_idx in enumerate(opinion_indices):
        # Each opinion ends where the next begins; the last runs to the end.
        if pos + 1 < len(opinion_indices):
            end = paragraphs[opinion_indices[pos + 1]]["start"]
        else:
            end = doc_end
        sections.append({
            "name": normalize_section_title(paragraphs[op_idx]["text"]),
            "start": paragraphs[op_idx]["start"],
            "end": end,
        })

    return sections
205
+
206
+
207
def render_section_guide(sections, compact_columns=None):
    """Render the "Section Guide" box listing section names and char ranges.

    Args:
        sections: List of dicts with ``name``, ``start`` and ``end`` keys.
        compact_columns: When truthy, sections are laid out as a CSS grid with
            this many columns; otherwise as one row per section.

    The markup is emitted through ``st.markdown`` with
    ``unsafe_allow_html=True``; section names are HTML-escaped first.
    """
    st.markdown("### Section Guide")
    if not sections:
        st.info("No section guide could be generated for this document.")
        return

    if compact_columns:
        cells = []
        for section in sections:
            name = html.escape(str(section["name"]))
            start = format_char_count(section["start"])
            end = format_char_count(section["end"])
            cells.append(
                "<div class='section-guide__cell'>"
                f"<div class='section-guide__name'>{name}</div>"
                f"<div class='section-guide__range'>{start} - {end}</div>"
                "</div>"
            )
        st.markdown(
            "<style>"
            ".section-guide{border:1px solid #e6e6e6;border-radius:6px;"
            "padding:8px 10px;margin:6px 0 12px;}"
            ".section-guide__grid{display:grid;gap:8px;}"
            ".section-guide__cell{padding:6px 8px;border:1px dashed #eee;"
            "border-radius:6px;background:#fafafa;}"
            ".section-guide__name{font-weight:600;font-size:12px;color:#222;}"
            ".section-guide__range{font-family:monospace;font-size:11px;"
            "color:#333;white-space:nowrap;margin-top:2px;}"
            "</style>"
            f"<div class='section-guide section-guide__grid' "
            f"style='grid-template-columns:repeat({int(compact_columns)}, minmax(0, 1fr));'>"
            + "".join(cells)
            + "</div>",
            unsafe_allow_html=True,
        )
    else:
        rows = []
        for section in sections:
            name = html.escape(str(section["name"]))
            start = format_char_count(section["start"])
            end = format_char_count(section["end"])
            rows.append(
                "<div class='section-guide__row'>"
                f"<span class='section-guide__name'>{name}</span>"
                f"<span class='section-guide__range'>{start} - {end}</span>"
                "</div>"
            )
        st.markdown(
            "<style>"
            ".section-guide{border:1px solid #e6e6e6;border-radius:6px;"
            "padding:8px 10px;margin:6px 0 12px;}"
            ".section-guide__row{display:flex;justify-content:space-between;"
            "gap:12px;align-items:baseline;padding:4px 0;"
            "border-bottom:1px dashed #eee;}"
            ".section-guide__row:last-child{border-bottom:none;}"
            ".section-guide__name{font-weight:600;font-size:13px;color:#222;"
            "flex:1 1 auto;}"
            ".section-guide__range{font-family:monospace;font-size:12px;"
            "color:#333;white-space:nowrap;}"
            "</style>"
            "<div class='section-guide'>"
            + "".join(rows)
            + "</div>",
            unsafe_allow_html=True,
        )
272
+
273
+
274
+ # span -> bin coverage
275
+
276
def bin_spans_into_brackets(df_ann, text_len, bin_size):
    """Aggregate annotation spans into fixed-width character bins per label.

    Each span ``[global_begin, global_end)`` is clipped to ``[0, text_len]``
    and split over the bins it touches. Overlap lengths are summed per
    ``(label, bin)`` and divided by ``bin_size``; the ratio is clipped to
    ``[0, 1]`` because overlapping spans of one label can sum past a full bin.

    Args:
        df_ann: Annotations with ``label``, ``global_begin`` and
            ``global_end`` columns.
        text_len: Length of the document text in characters.
        bin_size: Width of one bin in characters; must be positive.

    Returns:
        DataFrame with columns ``label``, ``bin``, ``overlap_len``,
        ``bin_size`` and ``coverage_ratio``; an empty DataFrame (no columns)
        when there is nothing to report.

    Raises:
        ValueError: If ``bin_size`` is not positive and there are annotations.
    """
    if df_ann.empty:
        return pd.DataFrame()
    if bin_size <= 0:
        raise ValueError("bin_size must be a positive number of characters")

    records = []
    for _, row in df_ann.iterrows():
        # Clamp to the text so stray offsets cannot create out-of-range bins.
        s = max(int(row["global_begin"]), 0)
        e = int(min(row["global_end"], text_len))
        if s >= e:
            continue

        # (e - 1) is the last covered character, so no empty trailing bin.
        for b in range(s // bin_size, (e - 1) // bin_size + 1):
            bin_start = b * bin_size
            bin_end = min(bin_start + bin_size, text_len)
            records.append({
                "label": row["label"],
                "bin": b,
                "overlap_len": min(e, bin_end) - max(s, bin_start),
                "bin_size": bin_size,
            })

    if not records:
        return pd.DataFrame()

    df = (
        pd.DataFrame(records)
        .groupby(["label", "bin"], as_index=False)
        .agg({"overlap_len": "sum", "bin_size": "first"})
    )
    df["coverage_ratio"] = (df["overlap_len"] / df["bin_size"]).clip(0, 1)
    return df
322
+
323
+
324
+ # matrix style heatmap
325
+
326
def make_matrix_style_heatmap(df_heat, bin_size, text_len, color="#1f6aff"):
    """Build the label-by-bin coverage heatmap as an Altair chart.

    Args:
        df_heat: Output of ``bin_spans_into_brackets`` (columns ``label``,
            ``bin``, ``coverage_ratio``).
        bin_size: Bin width in characters, used to turn bin indices into
            character positions on the x-axis.
        text_len: Accepted for interface compatibility; not used here.
        color: Colour for fully covered bins; the scale runs from white.

    Returns:
        A rect-mark chart with a click selection named ``heatmap_select`` on
        ``(label, bin)``, or a text-only placeholder chart (without any
        selection parameter) when ``df_heat`` is empty.
    """
    if df_heat.empty:
        # A text mark draws one label per data row, so the placeholder needs
        # exactly one row for the message to appear.
        return alt.Chart(pd.DataFrame({"a": [0]})).mark_text(text="No annotations")

    df = df_heat.copy()
    df["bin_start"] = df["bin"] * bin_size
    df["bin_end"] = df["bin_start"] + bin_size

    heatmap_select = alt.selection_point(
        fields=["label", "bin"],
        on="click",
        clear="dblclick",
        name="heatmap_select",
    )

    chart = (
        alt.Chart(df)
        .mark_rect()
        .encode(
            x=alt.X(
                "bin_start:Q",
                title="Character Bracket",
                axis=alt.Axis(format="~s"),
            ),
            x2="bin_end:Q",
            y=alt.Y(
                "label:N",
                title="Argument Type",
                sort=alt.SortArray(LABEL_ORDER),
                axis=alt.Axis(labelLimit=0, labelPadding=8),
            ),
            color=alt.Color(
                "coverage_ratio:Q",
                title="% of bin covered",
                scale=alt.Scale(domain=[0, 1], range=["#ffffff", color]),
            ),
            tooltip=[
                alt.Tooltip("label:N", title="Argument"),
                alt.Tooltip("bin_start:Q", title="Bin start", format=","),
                alt.Tooltip("bin_end:Q", title="Bin end", format=","),
                alt.Tooltip("coverage_ratio:Q", title="Coverage", format=".0%"),
            ],
        )
        .add_params(heatmap_select)
        .properties(
            width=1200,
            # 40 px per distinct label row.
            height=40 * df["label"].nunique(),
        )
    )

    return chart
383
+
384
+
385
+ # highlighting utils
386
+
387
def generate_color_palette(n):
    """Return ``n`` translucent CSS colours with evenly spaced hues.

    Hues are ``i / n`` for ``i`` in ``range(n)`` at lightness 0.6 and
    saturation 0.8; each colour is an ``rgba(r, g, b, 0.35)`` string.
    """
    def _rgba(hue):
        r, g, b = colorsys.hls_to_rgb(hue, 0.6, 0.8)
        return f"rgba({int(r*255)}, {int(g*255)}, {int(b*255)}, 0.35)"

    # max(1, n) only guards the division; range(n) is already empty for n <= 0.
    return [_rgba(i / max(1, n)) for i in range(n)]
394
+
395
+
396
def make_annotator_color_map(annotators):
    """Map each annotator to its own palette colour, in the given order."""
    return dict(zip(annotators, generate_color_palette(len(annotators))))
399
+
400
+
401
def compute_interval_segments(text_len, spans):
    """Split ``[0, text_len]`` at every span boundary and tag each piece.

    Args:
        text_len: Length of the text in characters.
        spans: Iterable of ``(start, end, ann_id)`` tuples.

    Returns:
        A list of ``(start, end, ann_ids)`` tuples covering the text in order,
        where ``ann_ids`` lists (in ``spans`` order) the ids of all spans that
        overlap the piece; pieces without annotations have an empty list.
    """
    boundaries = {0, text_len}
    for start, end, _ in spans:
        boundaries.add(int(start))
        boundaries.add(int(end))

    # Boundaries outside the text are dropped; spans are effectively clipped.
    cuts = sorted(b for b in boundaries if 0 <= b <= text_len)

    intervals = []
    for s, e in zip(cuts, cuts[1:]):
        active_ids = [span[2] for span in spans if span[0] < e and span[1] > s]
        intervals.append((s, e, active_ids))

    return intervals
423
+
424
+
425
def render_highlighted_html(text, spans, color_map, meta_map, focus_ann_id=None):
    """Render ``text`` as a ``<pre>`` block with annotation highlights.

    Args:
        text: Text to render; it is HTML-escaped.
        spans: ``(start, end, ann_id)`` tuples in coordinates local to
            ``text``.
        color_map: Maps ``ann_id`` to a CSS colour used as background layer.
        meta_map: Maps ``ann_id`` to a dict with ``label``, ``annotator`` and
            ``curation``, shown in the hover tooltip.
        focus_ann_id: Optional annotation to outline. The first highlighted
            piece containing it receives the DOM id ``ann-<focus_ann_id>``.

    Returns:
        An HTML string; plain escaped text in ``<pre>`` when ``spans`` is
        empty.
    """
    if not spans:
        return f"<pre>{html.escape(text)}</pre>"

    # Compare against None so a legitimate id such as 0 can still be focused.
    has_focus = focus_ann_id is not None
    # Escape the id: annotation ids come from CSV data and land in an attribute.
    anchor_id = html.escape(f"ann-{focus_ann_id}", quote=True) if has_focus else ""

    out = []
    anchored = False

    for s, e, ann_ids in compute_interval_segments(len(text), spans):
        chunk = html.escape(text[s:e])

        if ann_ids:
            # One flat gradient layer per annotation; the layers are blended.
            bg_layers = ", ".join(
                f"linear-gradient({color_map[a]} 0 0)" for a in ann_ids
            )
            tooltip_lines = [
                f"{meta_map[a]['label']} — {meta_map[a]['annotator']} "
                f"({meta_map[a]['curation']})"
                for a in ann_ids
            ]
            title_attr = html.escape("\n".join(tooltip_lines))

            is_focus = has_focus and focus_ann_id in ann_ids
            anchor_attr = ""
            if is_focus and not anchored:
                # Only the first piece carries the id so it stays unique here.
                anchor_attr = f' id="{anchor_id}"'
                anchored = True

            focus_style = "box-shadow: inset 0 0 0 2px #111;" if is_focus else ""

            chunk = (
                f'<span title="{title_attr}"{anchor_attr} '
                f'style="background:{bg_layers};'
                f'background-blend-mode:multiply;{focus_style}">'
                f'{chunk}</span>'
            )

        out.append(chunk)

    return "<pre style='line-height:1.5'>" + "".join(out) + "</pre>"
472
+
473
+
474
def extract_heatmap_selection(event):
    """Extract the selected ``{"label", "bin"}`` point from a chart event.

    The event may expose its payload as a ``selection`` attribute or, for a
    dict, under the ``"selection"`` key. The payload is searched recursively
    through nested dicts and lists for the first dict holding both a
    ``"label"`` and a ``"bin"`` key.

    Args:
        event: Value returned by the chart widget, or ``None``.

    Returns:
        The matching dict, or ``None`` when nothing is selected.
    """
    if event is None:
        return None

    selection = getattr(event, "selection", None)
    if selection is None and isinstance(event, dict):
        selection = event.get("selection")

    def pull_fields(sel):
        if isinstance(sel, dict):
            if "label" in sel and "bin" in sel:
                return sel
            if "values" in sel:
                return pull_fields(sel.get("values"))
            candidates = sel.values()
        elif isinstance(sel, list):
            # Scan every entry, not just the first, for a usable point.
            candidates = sel
        else:
            return None

        for candidate in candidates:
            extracted = pull_fields(candidate)
            if extracted:
                return extracted
        return None

    return pull_fields(selection)
499
+
500
+
501
def project_spans_to_interval(spans_global, seg_start, seg_end):
    """Clip global spans to ``[seg_start, seg_end)`` and make them local.

    Args:
        spans_global: ``(start, end, ann_id)`` tuples in document coordinates.
        seg_start: Document offset where the segment begins.
        seg_end: Document offset where the segment ends (exclusive).

    Returns:
        ``(local_start, local_end, ann_id)`` tuples, relative to
        ``seg_start``, for every span that overlaps the segment, in input
        order.
    """
    return [
        (max(g_start, seg_start) - seg_start, min(g_end, seg_end) - seg_start, ann_id)
        for g_start, g_end, ann_id in spans_global
        # A strict overlap test also guarantees local_start < local_end.
        if max(g_start, seg_start) < min(g_end, seg_end)
    ]
515
+
516
+
517
def pick_focus_annotation(df_ann, label, bin_start, bin_end):
    """Pick the annotation of ``label`` that best overlaps a character bin.

    Args:
        df_ann: Annotations with ``label``, ``global_begin``, ``global_end``
            and ``annotation_id`` columns.
        label: Label of the clicked heatmap cell.
        bin_start: First character offset of the bin.
        bin_end: Offset just past the bin.

    Returns:
        The ``annotation_id`` with the largest overlap with the bin (ties are
        broken by the smaller ``global_begin``), or ``None`` when no
        annotation of that label overlaps the bin.
    """
    if df_ann.empty:
        return None

    overlaps_bin = (
        (df_ann["label"] == label)
        & (df_ann["global_begin"] < bin_end)
        & (df_ann["global_end"] > bin_start)
    )
    df_sel = df_ann[overlaps_bin]
    if df_sel.empty:
        return None

    overlap = (
        np.minimum(df_sel["global_end"], bin_end)
        - np.maximum(df_sel["global_begin"], bin_start)
    )
    ranked = df_sel.assign(overlap=overlap).sort_values(
        ["overlap", "global_begin"], ascending=[False, True]
    )
    return ranked.iloc[0]["annotation_id"]
541
+
542
+
543
def scroll_to_annotation(focus_ann_id):
    """Inject a script that scrolls the page to the focused annotation.

    The script looks up the element with id ``ann-<focus_ann_id>`` in the
    parent document and scrolls it to the centre of the view, retrying once
    after 150 ms when the element is not found. Does nothing for ``None``.
    """
    if focus_ann_id is None:
        return

    import json

    # JSON-encode the id so quotes or backslashes in it cannot break out of
    # the JavaScript string literal; "</" is split so it cannot close <script>.
    target_literal = json.dumps("ann-" + str(focus_ann_id)).replace("</", "<\\/")

    components.html(
        "<script>"
        "const targetId = " + target_literal + ";"
        "const tryScroll = () => {"
        "  const el = window.parent.document.getElementById(targetId);"
        "  if (el) {"
        "    el.scrollIntoView({behavior: 'smooth', block: 'center'});"
        "    return true;"
        "  }"
        "  return false;"
        "};"
        "if (!tryScroll()) {"
        "  setTimeout(tryScroll, 150);"
        "}"
        "</script>",
        height=0,
    )
564
+
565
+
566
def scroll_to_heatmap(anchor_id):
    """Inject a script that scrolls the page to the heatmap anchor element.

    The script looks up ``anchor_id`` in the parent document and scrolls it to
    the top of the view, retrying once after 150 ms when the element is not
    found. Does nothing for a falsy ``anchor_id``.
    """
    if not anchor_id:
        return

    import json

    # JSON-encode the id so it is always a well-formed JavaScript string;
    # "</" is split so it cannot close the surrounding <script> element.
    target_literal = json.dumps(str(anchor_id)).replace("</", "<\\/")

    components.html(
        "<script>"
        "const targetId = " + target_literal + ";"
        "const tryScroll = () => {"
        "  const el = window.parent.document.getElementById(targetId);"
        "  if (el) {"
        "    el.scrollIntoView({behavior: 'smooth', block: 'start'});"
        "    return true;"
        "  }"
        "  return false;"
        "};"
        "if (!tryScroll()) {"
        "  setTimeout(tryScroll, 150);"
        "}"
        "</script>",
        height=0,
    )
587
+
588
+
589
def render_floating_heatmap_button(anchor_id, button_id):
    """Inject a script that creates a fixed "Back to heatmap" button.

    The button is created once per ``button_id`` and, when clicked, scrolls
    to the element with id ``anchor_id`` (looked up in the script's own
    document first, then in the parent document). Does nothing for a falsy
    ``anchor_id``.

    NOTE(review): the button is appended to ``window.document``, i.e. the
    document the script itself runs in, and the component is created with
    ``height=0``. If that document is the component's own frame the button
    may not be visible on the page -- confirm whether ``window.parent``
    was intended.
    """
    if not anchor_id:
        return

    components.html(
        "<script>"
        "const btnId = '" + str(button_id) + "';"
        "const anchorId = '" + str(anchor_id) + "';"
        "const doc = window.document;"
        "let btn = doc.getElementById(btnId);"
        "if (!btn) {"
        "  btn = doc.createElement('button');"
        "  btn.id = btnId;"
        "  btn.textContent = 'Back to heatmap';"
        "  btn.style.position = 'fixed';"
        "  btn.style.right = '16px';"
        "  btn.style.bottom = '16px';"
        "  btn.style.zIndex = '2147483647';"
        "  btn.style.padding = '8px 12px';"
        "  btn.style.border = '1px solid #ccc';"
        "  btn.style.borderRadius = '8px';"
        "  btn.style.background = '#fff';"
        "  btn.style.color = '#222';"
        "  btn.style.boxShadow = '0 2px 6px rgba(0,0,0,0.12)';"
        "  btn.style.cursor = 'pointer';"
        "  btn.style.transform = 'none';"
        "  btn.style.margin = '0';"
        "  btn.style.pointerEvents = 'auto';"
        "  doc.body.appendChild(btn);"
        "}"
        "btn.onclick = () => {"
        "  let el = doc.getElementById(anchorId);"
        "  if (!el) {"
        "    try { el = window.parent.document.getElementById(anchorId); } catch (e) {}"
        "  }"
        "  if (el) {"
        "    el.scrollIntoView({behavior: 'smooth', block: 'start'});"
        "  }"
        "};"
        "</script>",
        height=0,
    )
631
+
632
+ # streamlit UI
633
+
634
import hmac

st.set_page_config(page_title="Argument Heatmap Explorer", layout="wide")

st.title("Argument Saturation Heatmap")

# Optional password gate: active only when APP_PASSWORD is set.
app_password = os.getenv("APP_PASSWORD")
if app_password:
    if not st.session_state.get("auth_ok"):
        with st.sidebar:
            st.markdown("### Access")
            pw = st.text_input("Password", type="password")
            if pw:
                # Constant-time comparison; bytes so non-ASCII input works.
                if hmac.compare_digest(
                    pw.encode("utf-8"), app_password.encode("utf-8")
                ):
                    st.session_state["auth_ok"] = True
                else:
                    st.session_state["auth_ok"] = False
                    st.error("Incorrect password.")
    if not st.session_state.get("auth_ok"):
        # Nothing below is rendered until the password has been accepted.
        st.stop()

st.caption("Rows = argument types · Columns = character bins · Color = % coverage")

# Keep the sidebar pinned while the main area scrolls.
st.markdown(
    "<style>"
    "[data-testid='stSidebar']{position:fixed;top:0;left:0;height:100vh;}"
    "[data-testid='stSidebar'] > div:first-child{height:100vh;overflow:auto;}"
    "</style>",
    unsafe_allow_html=True,
)

# Backspace outside of input fields scrolls to the top; bound once per window.
components.html(
    "<script>"
    "if (!window._backspaceScrollBound) {"
    "  window._backspaceScrollBound = true;"
    "  window.addEventListener('keydown', (e) => {"
    "    const tag = (e.target && e.target.tagName) || '';"
    "    const isInput = tag === 'INPUT' || tag === 'TEXTAREA' || e.target.isContentEditable;"
    "    if (!isInput && e.key === 'Backspace') {"
    "      e.preventDefault();"
    "      window.scrollTo({top: 0, behavior: 'smooth'});"
    "    }"
    "  });"
    "}"
    "</script>",
    height=0,
)

# Script-side version of the sidebar pinning; also offsets the main container
# by the sidebar width. Retried twice in case the sidebar is not there yet.
components.html(
    "<script>"
    "const lockSidebar = () => {"
    "  const doc = window.parent.document;"
    "  const sidebar = doc.querySelector('[data-testid=\"stSidebar\"], .stSidebar');"
    "  if (!sidebar) return false;"
    "  sidebar.style.position = 'fixed';"
    "  sidebar.style.top = '0';"
    "  sidebar.style.left = '0';"
    "  sidebar.style.height = '100vh';"
    "  sidebar.style.zIndex = '999';"
    "  const inner = sidebar.querySelector('div');"
    "  if (inner) {"
    "    inner.style.height = '100vh';"
    "    inner.style.overflow = 'auto';"
    "  }"
    "  const main = doc.querySelector('[data-testid=\"stAppViewContainer\"], .main');"
    "  if (main) {"
    "    const w = sidebar.getBoundingClientRect().width;"
    "    main.style.marginLeft = `${w}px`;"
    "  }"
    "  return true;"
    "};"
    "if (!lockSidebar()) {"
    "  setTimeout(lockSidebar, 200);"
    "  setTimeout(lockSidebar, 800);"
    "}"
    "</script>",
    height=0,
)
706
+
707
+
708
+ # sidebar
709
+
710
st.sidebar.header("Load Data")

# Dataset directories; resolved against the dataset root when data is loaded.
hearings_ds_path = st.sidebar.text_input(
    "Hearings dataset path",
    "la_cour_dataset_hearings",
)

judgments_ds_path = st.sidebar.text_input(
    "Judgments dataset path",
    "la_cour_dataset_judgments",
)

# default CSV locations
default_hear_csv = resolve_dataset_path("la_cour_hearings_annotations.csv")
default_judg_csv = resolve_dataset_path("la_cour_judgments_annotations.csv")

st.sidebar.markdown("#### Annotation CSVs")

hear_ann_upload = st.sidebar.file_uploader(
    "Hearing annotations CSV",
    type="csv",
    key="hear_csv_upload",
)

judg_ann_upload = st.sidebar.file_uploader(
    "Judgment annotations CSV",
    type="csv",
    key="judg_csv_upload",
)
739
+
740
+
741
def load_csv_or_default(upload_file, default_path):
    """Load an annotation CSV from an upload, else from a default path.

    Args:
        upload_file: Uploaded file object exposing ``name``, or a falsy value
            when nothing was uploaded.
        default_path: Path tried when there is no upload.

    Returns:
        ``(DataFrame, status)`` where ``status`` is ``"(uploaded) <name>"`` or
        ``"(default) <path>"``; ``(None, "(missing)")`` when neither source
        is available.
    """
    if upload_file:
        return pd.read_csv(upload_file), f"(uploaded) {upload_file.name}"

    # isfile rather than exists: a directory at this path is not a usable CSV.
    if os.path.isfile(default_path):
        return pd.read_csv(default_path), f"(default) {default_path}"

    return None, "(missing)"
749
+
750
+
751
df_hear_ann, hear_status = load_csv_or_default(hear_ann_upload, default_hear_csv)
df_judg_ann, judg_status = load_csv_or_default(judg_ann_upload, default_judg_csv)

st.sidebar.caption(f"Hearing CSV: {hear_status}")
st.sidebar.caption(f"Judgment CSV: {judg_status}")

bin_size = st.sidebar.slider(
    "Characters per bin",
    min_value=50,
    max_value=3000,
    value=400,
    step=50,
)

heat_color = st.sidebar.color_picker("Heatmap color", value="#1f6aff")
go_heatmap = st.sidebar.button("Back to heatmap")

if df_hear_ann is None and df_judg_ann is None:
    st.info("No annotations loaded — upload a CSV or place defaults in working directory.")
    st.stop()

# load datasets lazily: a text dataset is read only when its annotations exist
ds_hear = (
    load_from_disk(resolve_dataset_path(hearings_ds_path))
    if df_hear_ann is not None
    else None
)
ds_judg = (
    load_from_disk(resolve_dataset_path(judgments_ds_path))
    if df_judg_ann is not None
    else None
)

# Compare with None explicitly: truthiness of a loaded dataset depends on its
# length, so an empty dataset would otherwise be treated as "not loaded".
df_hear_text = ds_hear.to_pandas() if ds_hear is not None else None
df_judg_text = ds_judg.to_pandas() if ds_judg is not None else None
790
+
791
+
792
+ # tab renderer
793
+
794
def render_heatmap_tab(df_ann, df_text, title, key, is_hearing):
    """Render one tab: filters, coverage heatmap, section guide, text preview.

    Args:
        df_ann: Annotation DataFrame (``webcast_id``, ``label``,
            ``annotator``, ``curation``, ``global_begin``, ``global_end``,
            ``annotation_id``), or ``None`` when not loaded.
        df_text: Text DataFrame of the same corpus, or ``None``.
        title: Human-readable corpus name used in headings and messages.
        key: Short suffix that keeps widget and session-state keys unique
            per tab.
        is_hearing: ``True`` to lay the preview out per speaker segment,
            ``False`` to lay it out per paragraph.

    Reads the module-level ``bin_size``, ``heat_color`` and ``go_heatmap``
    values set by the sidebar widgets.
    """
    if df_ann is None:
        st.warning(f"No {title.lower()} annotations loaded.")
        return
    if df_text is None:
        # Without the text there is nothing to bin against or preview.
        st.warning(f"No {title.lower()} text dataset loaded.")
        return

    st.subheader(f"{title} Heatmap")

    webcast_ids = sorted(df_ann["webcast_id"].unique())
    webcast = st.selectbox("Select document", webcast_ids, key=f"wc_{key}")

    dfA = df_ann[df_ann["webcast_id"] == webcast]
    dfT = df_text[df_text["webcast_id"] == webcast]

    labels = sorted(dfA["label"].dropna().unique())
    annotators = sorted(dfA["annotator"].dropna().unique())

    c1, c2, c3, c4 = st.columns(4)

    sel_labels = c1.multiselect("Argument types", labels, default=labels, key=f"lbl_{key}")
    sel_ann = c2.multiselect("Annotators (heatmap)", annotators, default=annotators, key=f"ann_{key}")
    valid_only = c3.checkbox("Valid only", value=True, key=f"valid_{key}")
    preview_ann = c4.multiselect(
        "Annotators (preview)",
        annotators,
        default=annotators,
        key=f"hl_{key}",
    )

    dfA = dfA[dfA["label"].isin(sel_labels) & dfA["annotator"].isin(sel_ann)]
    if valid_only:
        dfA = dfA[dfA["curation"] == "valid"]

    full_text = concat_global_text(dfT, webcast)

    if not sanity_check(dfA, len(full_text)):
        st.error("Annotation spans exceed text length.")
        return

    df_heat = bin_spans_into_brackets(dfA, len(full_text), bin_size)

    st.markdown("### Heatmap")
    heatmap_anchor = f"heatmap-anchor-{key}"
    st.markdown(
        f"<div id='{heatmap_anchor}'></div>",
        unsafe_allow_html=True,
    )
    render_floating_heatmap_button(heatmap_anchor, f"heatmap-btn-{key}")
    heatmap_chart = make_matrix_style_heatmap(
        df_heat, bin_size, len(full_text), color=heat_color
    )

    if df_heat.empty:
        # The placeholder chart defines no selection parameter, so it must be
        # drawn without selection events.
        heatmap_event = None
        st.altair_chart(heatmap_chart, use_container_width=True)
    else:
        try:
            heatmap_event = st.altair_chart(
                heatmap_chart,
                use_container_width=True,
                on_select="rerun",
            )
        except TypeError:
            # Fallback for a Streamlit whose altair_chart lacks on_select.
            heatmap_event = None
            st.altair_chart(heatmap_chart, use_container_width=True)

    selected = extract_heatmap_selection(heatmap_event)
    selection_key = f"heat_sel_{key}"
    if selected:
        try:
            st.session_state[selection_key] = {
                "label": selected["label"],
                "bin": int(selected["bin"]),
            }
        except (TypeError, ValueError):
            # Unusable bin value: keep whatever selection was stored before.
            pass
    elif heatmap_event is not None:
        # An event without a point means the selection was cleared.
        st.session_state.pop(selection_key, None)

    sections = (
        compute_hearing_sections(dfT)
        if is_hearing
        else compute_judgment_sections(dfT)
    )
    render_section_guide(sections, compact_columns=10 if is_hearing else None)

    # highlighted text preview

    st.markdown("### Highlighted Text Preview")
    if preview_ann:
        annot_color_map = make_annotator_color_map(preview_ann)
        legend_rows = []
        for annot in preview_ann:
            color = annot_color_map.get(annot, "rgba(0,0,0,0.25)")
            legend_rows.append(
                f"<div class='annotator-legend__row'>"
                f"<span class='annotator-legend__swatch' "
                f"style='background:{color}'></span>"
                f"<span class='annotator-legend__label'>"
                f"{html.escape(str(annot))}</span></div>"
            )
    else:
        annot_color_map = {}
        legend_rows = []

    # The preview applies its own annotator filter, not the heatmap's.
    dfH = df_ann[
        (df_ann["webcast_id"] == webcast)
        & (df_ann["annotator"].isin(preview_ann))
    ]
    if valid_only:
        dfH = dfH[dfH["curation"] == "valid"]

    # Build every per-annotation lookup in a single pass over the frame.
    spans_global = []
    ann_id_to_annot = {}
    color_map = {}
    meta_map = {}
    for _, r in dfH.iterrows():
        ann_id = r["annotation_id"]
        spans_global.append((int(r["global_begin"]), int(r["global_end"]), ann_id))
        ann_id_to_annot[ann_id] = r["annotator"]
        color_map[ann_id] = annot_color_map.get(r["annotator"], "rgba(0,0,0,0.25)")
        meta_map[ann_id] = {
            "label": r["label"],
            "annotator": r["annotator"],
            "curation": r["curation"],
        }

    # Legend entries for every distinct set of annotators that overlap.
    combo_rows = []
    if preview_ann and spans_global:
        intervals = compute_interval_segments(len(full_text), spans_global)
        seen_combos = set()
        for _, _, ann_ids in intervals:
            annotators_in_span = sorted(
                {ann_id_to_annot.get(a) for a in ann_ids if a in ann_id_to_annot}
            )
            if len(annotators_in_span) <= 1:
                continue
            combo_key = tuple(annotators_in_span)
            if combo_key in seen_combos:
                continue
            seen_combos.add(combo_key)
            bg_layers = ", ".join(
                f"linear-gradient({annot_color_map[a]} 0 0)"
                for a in annotators_in_span
                if a in annot_color_map
            )
            label = " + ".join(html.escape(str(a)) for a in combo_key)
            combo_rows.append(
                f"<div class='annotator-legend__row'>"
                f"<span class='annotator-legend__swatch' "
                f"style='background:{bg_layers};"
                f"background-blend-mode:multiply;'></span>"
                f"<span class='annotator-legend__label'>{label}</span></div>"
            )

    if legend_rows or combo_rows:
        st.markdown(
            "<style>"
            ".annotator-legend{position:sticky;top:0;background:white;"
            "padding:6px 8px;border:1px solid #e6e6e6;border-radius:6px;"
            "z-index:10;margin:6px 0 12px 0;display:inline-block;}"
            ".annotator-legend__row{display:flex;align-items:center;"
            "gap:8px;margin:2px 0;}"
            ".annotator-legend__swatch{width:16px;height:16px;"
            "border-radius:3px;display:inline-block;}"
            ".annotator-legend__label{font-size:12px;color:#222;}"
            ".annotator-legend__section{font-size:11px;margin:4px 0 2px;"
            "color:#666;}"
            "</style>"
            "<div class='annotator-legend'>"
            "<div class='annotator-legend__section'>Annotators</div>"
            + "".join(legend_rows)
            + (
                "<div class='annotator-legend__section'>Combinations</div>"
                + "".join(combo_rows)
                if combo_rows else ""
            )
            + "</div>",
            unsafe_allow_html=True,
        )

    # Turn the stored heatmap click into the annotation to focus.
    focus_ann_id = None
    selection_state = st.session_state.get(selection_key)
    if selection_state and not dfH.empty:
        sel_label = selection_state.get("label")
        sel_bin = selection_state.get("bin")
        if sel_label is not None and sel_bin is not None:
            bin_start = sel_bin * bin_size
            bin_end = bin_start + bin_size
            focus_ann_id = pick_focus_annotation(
                dfH, sel_label, bin_start, bin_end
            )

    html_blocks = []

    # hearing preview
    if is_hearing:
        rows = dfT.sort_values(["segment_id", "sequence_id"])

        cursor = 0
        for _, seg_rows in rows.groupby("segment_id"):
            speaker = (
                seg_rows.iloc[0].get("speaker_name")
                or seg_rows.iloc[0].get("speaker_role")
                or "Unknown"
            )

            # Missing text counts as "", as in the global concatenation.
            segment_text = " ".join(seg_rows["text"].fillna("").tolist())
            seg_start = cursor
            seg_end = seg_start + len(segment_text)
            cursor = seg_end + 1  # +1 for the space joining segments

            local_spans = project_spans_to_interval(spans_global, seg_start, seg_end)

            html_blocks.append(f"<b>{html.escape(str(speaker))}</b><br>")
            html_blocks.append(
                render_highlighted_html(
                    segment_text,
                    local_spans,
                    color_map,
                    meta_map,
                    focus_ann_id=focus_ann_id,
                )
            )
            html_blocks.append("<br>")

    # judgment preview
    else:
        rows = dfT.sort_values("paragraph_id")

        cursor = 0
        for ptext in rows["text"].fillna("").tolist():
            seg_start = cursor
            seg_end = seg_start + len(ptext)
            cursor = seg_end + 1  # +1 for the space joining paragraphs

            local_spans = project_spans_to_interval(spans_global, seg_start, seg_end)

            html_blocks.append(
                render_highlighted_html(
                    ptext,
                    local_spans,
                    color_map,
                    meta_map,
                    focus_ann_id=focus_ann_id,
                )
            )
            html_blocks.append("<br>\n")

    st.markdown("".join(html_blocks), unsafe_allow_html=True)

    scroll_to_annotation(focus_ann_id)

    if go_heatmap:
        scroll_to_heatmap(heatmap_anchor)
1069
+
1070
+ # tabs
1071
+
1072
# One tab per corpus; both share the same renderer.
tab1, tab2 = st.tabs(["Hearings", "Judgments"])

with tab1:
    render_heatmap_tab(df_hear_ann, df_hear_text, "Hearing", "hear", is_hearing=True)

with tab2:
    render_heatmap_tab(df_judg_ann, df_judg_text, "Judgment", "judg", is_hearing=False)