BrundageLab committed on
Commit
1adf975
·
verified ·
1 Parent(s): 6f0968f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +699 -32
src/streamlit_app.py CHANGED
@@ -1,40 +1,707 @@
1
- import altair as alt
2
- import numpy as np
 
 
 
 
 
 
 
 
 
 
 
3
  import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
- # Welcome to Streamlit!
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # app.py
2
+ # Streamlit "product-like" Vet De-ID demo (PIPELINE-FREE):
3
+ # - Loads model from a Hugging Face repo ID (public or private via HF token)
4
+ # - Runs token-classification via tokenizer+model directly (no HF pipeline kwargs issues)
5
+ # - Single-note + batch (CSV/TXT) processing
6
+ # - Highlighted redaction preview + entity table
7
+ # - Downloads: redacted text, JSON entities, redacted CSV
8
+
9
+ import os
10
+ import re
11
+ import json
12
+ from typing import List, Dict, Any, Optional
13
+ import streamlit.components.v1 as components
14
  import pandas as pd
15
  import streamlit as st
16
+ import torch
17
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
18
+
19
+
20
+ # =========================
21
+ # Core utilities
22
+ # =========================
23
def get_group(ent: Dict[str, Any]) -> str:
    """Return the entity's label, accepting both HF output key styles.

    Pipeline-style dicts use "entity_group"; raw token output uses
    "entity". Falls back to "UNK" when neither key holds a truthy value.
    """
    for key in ("entity_group", "entity"):
        label = ent.get(key)
        if label:
            return label
    return "UNK"
25
+
26
def norm_contact(s: str) -> str:
    """Canonicalize a contact string for comparison.

    Emails are trimmed and lowercased; anything else is treated as a
    phone number and reduced to its digits only.
    """
    cleaned = s.strip().lower()
    return cleaned if "@" in cleaned else re.sub(r"\D", "", cleaned)
31
+
32
def resolve_overlaps(entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Greedily drop overlapping spans.

    Candidates are ranked by start offset, then by longer span, then by
    higher score; each is kept only if it does not intersect any span
    already accepted.
    """
    ranked = sorted(
        entities,
        key=lambda e: (e["start"], e["start"] - e["end"], -float(e.get("score", 0.0))),
    )
    accepted: List[Dict[str, Any]] = []
    for cand in ranked:
        # Disjoint from every kept span: ends before it starts, or starts after it ends.
        disjoint = all(
            cand["end"] <= kept["start"] or cand["start"] >= kept["end"]
            for kept in accepted
        )
        if disjoint:
            accepted.append(cand)
    return accepted
48
+
49
def dedup_entities_by_span(ents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Remove exact duplicates keyed by (label, start, end), keeping the first."""
    unique: List[Dict[str, Any]] = []
    seen_keys = set()
    for ent in ents:
        span_key = (get_group(ent), int(ent["start"]), int(ent["end"]))
        if span_key in seen_keys:
            continue
        seen_keys.add(span_key)
        unique.append(ent)
    return unique
59
+
60
def is_placeholder(word: str) -> bool:
    """Detect form-template residue (e.g. "____", "(___)") masquerading as PII."""
    stripped = word.strip()
    # Entirely underscores / whitespace / dashes / parentheses.
    if re.fullmatch(r"[_\s\-\(\)]+", stripped):
        return True
    # Two or more underscores with fewer than two real characters left over.
    remainder = re.sub(r"[_\s\-\(\)]", "", stripped)
    return stripped.count("_") >= 2 and len(remainder) < 2
67
+
68
def merge_adjacent_entities(entities: List[Dict[str, Any]], text: str) -> List[Dict[str, Any]]:
    """
    Merge same-label spans separated only by safe punctuation/whitespace.
    Prevent merges across newlines / field boundaries.
    """
    if not entities:
        return []

    ordered = sorted(entities, key=lambda x: x["start"])
    result = [dict(ordered[0])]

    for ent in ordered[1:]:
        last = result[-1]
        between = text[last["end"]:ent["start"]]

        # A newline in the gap marks a field boundary — never merge across it.
        # Otherwise require matching labels and a short gap (<= 3 chars) made
        # only of spaces/tabs and , . / - ( ) punctuation.
        mergeable = (
            "\n" not in between
            and "\r" not in between
            and get_group(last) == get_group(ent)
            and (ent["start"] - last["end"]) <= 3
            and re.fullmatch(r"[ \t,./\-()]*", between) is not None
        )

        if mergeable:
            last["end"] = ent["end"]
            last["word"] = text[last["start"]:last["end"]]
            last["score"] = max(float(last.get("score", 0.0)), float(ent.get("score", 0.0)))
        else:
            result.append(dict(ent))

    return result
99
 
100
def find_structured_pii(text: str) -> List[Dict[str, Any]]:
    """Regex pass for deterministic PII: email addresses and US-style phones.

    Returns pipeline-shaped entity dicts (word/entity_group/score/start/end)
    with score 1.0 so downstream overlap resolution and redaction can treat
    them exactly like model output.
    """
    hits: List[Dict[str, Any]] = []
    # Emails. BUGFIX: the TLD class was [A-Z|a-z]{2,}, which put a literal
    # '|' inside the character class and wrongly accepted pipes in the TLD.
    for m in re.finditer(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", text):
        hits.append({"word": m.group(), "entity_group": "CONTACT", "score": 1.0,
                     "start": m.start(), "end": m.end()})
    # Phones (US-ish): optional parenthesized area code, '-', '.' or space separators.
    for m in re.finditer(r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}", text):
        hits.append({"word": m.group(), "entity_group": "CONTACT", "score": 1.0,
                     "start": m.start(), "end": m.end()})
    return hits
109
+
110
def redact_text(text: str, entities: List[Dict[str, Any]], mode: str = "tags") -> str:
    """
    mode="tags": [NAME], [LOC], etc.
    mode="char": ***** preserving length
    """
    # Resolve overlaps first, then splice right-to-left so earlier
    # character offsets remain valid while we substitute.
    spans = sorted(resolve_overlaps(entities), key=lambda x: x["start"], reverse=True)

    result = text
    for ent in spans:
        begin, finish = ent["start"], ent["end"]
        if mode == "tags":
            filler = f"[{get_group(ent)}]"
        else:
            filler = "*" * max(1, finish - begin)
        result = result[:begin] + filler + result[finish:]
    return result
125
+
126
def highlight_entities_html(text: str, entities: List[Dict[str, Any]]) -> str:
    """Render *text* as HTML with each entity span highlighted.

    Spans are overlap-resolved and walked left-to-right; non-entity text is
    HTML-escaped and entity text is wrapped in a colored <span> whose
    background alpha scales with the entity's confidence score. Returns the
    CSS block plus a <div class='note'> container, suitable for
    components.html().
    """
    entities = resolve_overlaps(entities)
    entities = sorted(entities, key=lambda x: x["start"])

    # RGBA base colors (R,G,B); alpha is scaled by score
    palette_rgb = {
        "NAME": (255, 200, 87),
        "LOC": (120, 180, 255),
        "ORG": (140, 220, 160),
        "DATE": (255, 140, 140),
        "ID": (200, 160, 255),
        "CONTACT": (120, 220, 220),
        "UNK": (200, 200, 200),  # fallback for labels not in the palette
    }

    def esc(s: str) -> str:
        # Minimal HTML escaping for text interpolated into the markup.
        return (s.replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace(">", "&gt;")
                .replace('"', "&quot;")
                .replace("'", "&#39;"))

    # Dark-theme note container + per-entity underline and hover label pill.
    css = """
<style>
.note {
  white-space: pre-wrap;
  font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
  font-size: 13px;
  line-height: 1.45;

  /* add these */
  color: #e8eaed;
  background: #0e1117;
  padding: 12px 14px;
  border-radius: 10px;
}

.ent {
  position: relative;
  border-radius: 4px;
  padding: 0px 2px;
  margin: 0px 1px;
  box-decoration-break: clone;
  -webkit-box-decoration-break: clone;
  transition: filter 120ms ease;
}
.ent:hover { filter: brightness(1.05); }

.ent::after {
  content: "";
  position: absolute;
  left: 0; right: 0; bottom: -1px;
  height: 2px;
  border-radius: 2px;
  background: rgba(var(--rgb), 0.85);
}

.pill {
  display: none;
  position: absolute;
  top: -14px;
  left: 0px;
  font-size: 10px;
  line-height: 1;
  padding: 2px 6px;
  border-radius: 999px;
  background: rgba(var(--rgb), 0.95);
  color: #111;
  box-shadow: 0 2px 8px rgba(0,0,0,0.25);
  white-space: nowrap;
  z-index: 5;
}
.ent:hover .pill { display: inline-block; }
</style>
"""

    out = []
    cursor = 0
    for e in entities:
        s, t = e["start"], e["end"]
        # Defensive: skip any span that still starts before the cursor.
        if s < cursor:
            continue

        # Plain text between the previous span and this one.
        out.append(esc(text[cursor:s]))

        label = get_group(e)
        r, g, b = palette_rgb.get(label, palette_rgb["UNK"])
        score = float(e.get("score", 0.0))
        # background alpha: 0.10 to 0.32 depending on confidence
        alpha = 0.10 + 0.22 * max(0.0, min(1.0, score))

        span_text = esc(text[s:t])
        title = f"{label} • {score:.2f}"

        out.append(
            f'<span class="ent" title="{esc(title)}" style="--rgb:{r},{g},{b}; background: rgba({r},{g},{b},{alpha});">'
            f'{span_text}'
            f'<span class="pill">{label}</span>'
            f"</span>"
        )
        cursor = t

    # Trailing text after the last entity.
    out.append(esc(text[cursor:]))

    return css + "<div class='note'>" + "".join(out) + "</div>"
232
+
233
+
234
+
235
+ # =========================
236
+ # Model loading from HF (NO PIPELINE)
237
+ # =========================
238
@st.cache_resource
def load_hf_model(
    repo_id: str,
    revision: Optional[str],
    hf_token: Optional[str],
    device_str: str,
):
    """Fetch tokenizer + token-classification model from the Hugging Face Hub.

    Cached by Streamlit across reruns (keyed on the arguments). The model is
    moved to the requested device and switched to eval mode before the
    (tokenizer, model, device) triple is returned.
    """
    target_device = torch.device(device_str)
    tok = AutoTokenizer.from_pretrained(repo_id, revision=revision, token=hf_token)
    mdl = AutoModelForTokenClassification.from_pretrained(repo_id, revision=revision, token=hf_token)
    mdl.to(target_device)
    mdl.eval()
    return tok, mdl, target_device
251
+
252
+
253
+ # =========================
254
+ # NER: model-based inference with offsets (BIO -> spans)
255
+ # =========================
256
def ner_call_model(tokenizer, model, text: str, max_len: int, device: torch.device) -> List[Dict[str, Any]]:
    """Run token classification on *text* and decode BIO tags into char spans.

    Tokenizes with offset mapping, takes the argmax label per token, then
    groups B-/I- runs of the same type into entity dicts
    (word/entity_group/start/end/score). The score is the mean per-token
    confidence over the grouped tokens. Input is truncated at *max_len*
    tokens, so callers should window long texts themselves.
    """
    enc = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
        padding=False,
    )
    # Offsets map each token back to (start, end) char positions in *text*.
    offsets = enc.pop("offset_mapping")[0].tolist()
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.inference_mode():
        logits = model(**enc).logits[0]  # (seq_len, num_labels)

    probs = torch.softmax(logits, dim=-1)
    pred_ids = probs.argmax(dim=-1).tolist()
    pred_scores = probs.max(dim=-1).values.tolist()

    id2label = model.config.id2label

    def id_to_label(i: int) -> str:
        # id2label may be keyed by int or str depending on how the config
        # was serialized; try both, defaulting to "O".
        if i in id2label:
            return id2label[i]
        return id2label.get(str(i), "O")

    labels = [id_to_label(i) for i in pred_ids]

    entities: List[Dict[str, Any]] = []
    i = 0
    while i < len(labels):
        lab = labels[i]
        s, e = offsets[i]

        # skip special/empty (special tokens have zero-width offsets)
        if s == e:
            i += 1
            continue
        if lab == "O":
            i += 1
            continue

        # if I- without B-, treat as B- (recover from inconsistent tagging)
        if lab.startswith("I-"):
            lab = "B-" + lab[2:]

        if lab.startswith("B-"):
            typ = lab[2:]
            start = s
            end = e
            scores = [pred_scores[i]]

            # Extend the span over consecutive I-<typ> tokens, skipping
            # zero-width (special) tokens along the way.
            j = i + 1
            while j < len(labels):
                lab2 = labels[j]
                s2, e2 = offsets[j]
                if s2 == e2:
                    j += 1
                    continue
                if lab2 == f"I-{typ}":
                    end = e2
                    scores.append(pred_scores[j])
                    j += 1
                    continue
                break

            entities.append({
                "word": text[start:end],
                "entity_group": typ,
                "start": start,
                "end": end,
                "score": float(sum(scores) / max(1, len(scores))),  # mean token confidence
            })
            i = j
        else:
            i += 1

    return entities
334
+
335
+
336
def run_ner_with_windows_model(
    tokenizer,
    model,
    device: torch.device,
    text: str,
    pipe_max_len: int,
    window_chars: int = 2000,
    overlap_chars: int = 250,
) -> List[Dict[str, Any]]:
    """Run NER over *text* in overlapping character windows.

    Each chunk of *window_chars* characters is passed to ner_call_model;
    consecutive windows overlap by *overlap_chars* so entities straddling a
    boundary are still seen whole by at least one window. Entity offsets are
    shifted back into full-text coordinates. Duplicates produced by the
    overlap are expected to be removed downstream (dedup/resolve).
    """
    ents: List[Dict[str, Any]] = []
    start = 0
    n = len(text)

    while start < n:
        end = min(n, start + window_chars)
        chunk = text[start:end]
        chunk_ents = ner_call_model(tokenizer, model, chunk, max_len=pipe_max_len, device=device)

        for e in chunk_ents:
            e = dict(e)
            # Shift chunk-local offsets into full-text coordinates.
            e["start"] += start
            e["end"] += start
            e["word"] = text[e["start"]:e["end"]]
            ents.append(e)

        if end == n:
            break
        # BUGFIX: was `start = max(0, end - overlap_chars)`, which never
        # advances when overlap_chars >= window_chars (the UI allows overlap
        # up to 1000 with windows as small as 500) — an infinite loop.
        # Guarantee at least one character of forward progress.
        start = max(start + 1, end - overlap_chars)

    return ents
366
def propagate_entities(text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Add additional spans by exact/normalized string matching for selected entity types.
    Returns a new entity list (original + propagated), resolved/deduped.
    """
    # Which labels to propagate and how
    PROPAGATE = {"CONTACT", "ID", "NAME"}  # consider adding DATE if needed
    MIN_ID_LEN = 5    # tune: avoid 2-3 digit labs, doses
    MIN_NAME_LEN = 4  # avoid tiny tokens

    # Build patterns from existing entities.
    # Each pattern is a (label, regex_source, re-flags) triple.
    patterns = []
    for e in entities:
        label = get_group(e)
        if label not in PROPAGATE:
            continue

        val = e["word"].strip()
        if not val:
            continue

        if label == "CONTACT":
            # Exact string match (case-insensitive for emails)
            patterns.append((label, re.escape(val), re.IGNORECASE))

        elif label == "ID":
            # Only propagate "ID-like" tokens
            compact = re.sub(r"\D", "", val)
            if len(compact) < MIN_ID_LEN:
                continue
            # Match the same digit sequence allowing separators
            # e.g. 261808 matches "261808" or "261-808" if present
            digit_pat = r"\D*".join(list(compact))
            patterns.append((label, digit_pat, 0))

        elif label == "NAME":
            # Prefer multi-token names; for single token be conservative
            # You can tune this: in vet notes, patient single-token names are still PII.
            is_multi = bool(re.search(r"\s", val))
            if (not is_multi) and len(val) < MIN_NAME_LEN:
                continue
            # Exact token/phrase match with word boundaries
            pat = r"\b" + re.escape(val) + r"\b"
            patterns.append((label, pat, re.IGNORECASE))

    # Find additional occurrences. Propagated spans get score 1.0 and a
    # "source" marker so they are distinguishable from model output.
    added = []
    for label, pat, flags in patterns:
        for m in re.finditer(pat, text, flags=flags):
            added.append({
                "word": text[m.start():m.end()],
                "entity_group": label,
                "score": 1.0,  # propagated
                "start": m.start(),
                "end": m.end(),
                "source": "propagated",
            })

    # Merge, then dedup exact (label, start, end) repeats, then drop any
    # remaining overlapping spans (resolve_overlaps keeps the best per spot).
    all_ents = list(entities) + added
    all_ents = sorted(all_ents, key=lambda x: x["start"])
    all_ents = dedup_entities_by_span(all_ents)
    all_ents = resolve_overlaps(all_ents)
    return all_ents
429
+
430
def deidentify_note(
    tokenizer,
    model,
    device: torch.device,
    text: str,
    pipe_max_len: int,
    thresh: Dict[str, float],
    global_stoplist: set,
    stop_by_label: Dict[str, set],
    use_windows: bool,
    window_chars: int,
    overlap_chars: int,
) -> List[Dict[str, Any]]:
    """End-to-end entity extraction for one note.

    Pipeline: model NER (optionally windowed) -> merge adjacent same-label
    spans -> add regex CONTACT hits -> filter model spans by per-label
    threshold, placeholder shape, stoplists, and length -> regex wins on any
    CONTACT overlap -> sort/dedup/overlap-resolve. Returns the final entity
    list; redaction is done separately by redact_text().
    """
    def pass_thresh(ent):
        # Per-label confidence threshold, with a "_default" fallback of 0.45.
        g = get_group(ent)
        return float(ent.get("score", 0.0)) >= float(thresh.get(g, thresh.get("_default", 0.45)))

    def stoplisted(ent):
        # Case-insensitive membership in the global list or the label-specific list.
        g = get_group(ent)
        w = ent["word"].strip().lower()
        if w in global_stoplist:
            return True
        return w in stop_by_label.get(g, set())

    # BERT
    if use_windows:
        bert_results = run_ner_with_windows_model(
            tokenizer, model, device, text,
            pipe_max_len=pipe_max_len,
            window_chars=window_chars,
            overlap_chars=overlap_chars,
        )
    else:
        bert_results = ner_call_model(tokenizer, model, text, max_len=pipe_max_len, device=device)

    # Merge adjacent same-label entities
    bert_results = merge_adjacent_entities(bert_results, text)

    # Regex CONTACT
    regex_results = find_structured_pii(text)

    final_entities: List[Dict[str, Any]] = []
    final_entities.extend(regex_results)

    for ent in bert_results:
        word = ent["word"].strip()

        if not pass_thresh(ent):
            continue
        if is_placeholder(word):
            continue
        if stoplisted(ent):
            continue
        # Drop 1-char non-digit fragments (tokenizer noise).
        if len(word) < 2 and not word.isdigit():
            continue

        # if overlaps regex CONTACT, skip BERT (regex wins)
        dup = False
        for reg in regex_results:
            if ent["start"] < reg["end"] and ent["end"] > reg["start"]:
                dup = True
                break
        if dup:
            continue

        final_entities.append(ent)

    final_entities = sorted(final_entities, key=lambda x: x["start"])
    final_entities = dedup_entities_by_span(final_entities)
    final_entities = resolve_overlaps(final_entities)
    return final_entities
501
+
502
+
503
+ # =========================
504
+ # Streamlit UI
505
+ # =========================
506
# ---- Page + sidebar configuration (runs top-to-bottom on every rerun) ----
st.set_page_config(page_title="Vet De-ID Demo", layout="wide")
st.title("Veterinary De-identification Demo (HF model + NER + Regex)")

with st.sidebar:
    # Model source: repo id / revision / token default from env vars so the
    # demo can be preconfigured (see the caption at the bottom of the app).
    st.header("Model (Hugging Face)")
    repo_id = st.text_input("HF repo_id", value=os.environ.get("HF_REPO_ID", "YOUR_ORG/YOUR_VET_DEID_MODEL"))
    revision = st.text_input("Revision (optional)", value=os.environ.get("HF_REVISION", "")).strip() or None
    hf_token = st.text_input("HF token (optional for private repos)", value=os.environ.get("HF_TOKEN", ""), type="password").strip() or None

    st.header("Runtime")
    # GPU checkbox defaults to availability; device_str re-checks availability
    # so a stale checked box can never select a missing CUDA device.
    use_gpu = st.checkbox("Use GPU (CUDA)", value=torch.cuda.is_available())
    device_str = "cuda:0" if (use_gpu and torch.cuda.is_available()) else "cpu"

    pipe_max_len = st.selectbox("Max token length", options=[256, 512], index=0)
    use_windows = st.checkbox("Window long notes (recommended)", value=True)
    window_chars = st.slider("Window size (chars)", 500, 6000, 2000, 100)
    overlap_chars = st.slider("Window overlap (chars)", 0, 1000, 250, 25)

    # Per-label confidence thresholds consumed via the THRESH dict below.
    st.header("Thresholds")
    t_name = st.slider("NAME", 0.0, 1.0, 0.60, 0.01)
    t_org = st.slider("ORG", 0.0, 1.0, 0.60, 0.01)
    t_loc = st.slider("LOC", 0.0, 1.0, 0.60, 0.01)
    t_date = st.slider("DATE", 0.0, 1.0, 0.45, 0.01)
    t_id = st.slider("ID", 0.0, 1.0, 0.50, 0.01)
    t_contact = st.slider("CONTACT (model)", 0.0, 1.0, 0.99, 0.01)  # regex-first anyway
    t_default = st.slider("Default", 0.0, 1.0, 0.45, 0.01)

    redact_mode = st.selectbox("Redaction mode", options=["tags", "char"], index=0)
    show_highlight = st.checkbox("Show highlighted original", value=True)

# Load model/tokenizer (cached by @st.cache_resource); abort the app run
# with a visible error rather than crashing on a bad repo id/token.
try:
    tokenizer, model, device = load_hf_model(repo_id=repo_id, revision=revision, hf_token=hf_token, device_str=device_str)
except Exception as e:
    st.error(f"Failed to load model/tokenizer from HF.\n\nrepo_id={repo_id}\nrevision={revision}\n\n{e}")
    st.stop()

# Stoplists (can be made editable later)
GLOBAL_STOPLIST = {"er", "ve", "w", "dvm", "mph", "sex", "male", "female", "kg", "lb", "patient", "owner", "left", "right"}
STOP_BY_LABEL = {
    "LOC": {"dsh", "feline", "canine", "equine", "bovine", "species", "breed", "color"},
    "NAME": {"owner", "patient"},
}

# Per-label score thresholds; "_default" is the fallback for unlisted labels.
THRESH = {
    "NAME": t_name,
    "ORG": t_org,
    "LOC": t_loc,
    "DATE": t_date,
    "ID": t_id,
    "CONTACT": t_contact,
    "_default": t_default,
}

tab1, tab2, tab3 = st.tabs(["Single note", "Batch (CSV/TXT)", "About"])
561
+
562
with tab1:
    # ---- Single-note workflow: paste text, run, inspect/download results ----
    st.subheader("Single note")
    default_text = "Paste a veterinary note here..."
    text = st.text_area("Input", height=260, value=default_text)

    colA, colB = st.columns([1, 1])
    with colA:
        run_single = st.button("Run", type="primary")
    with colB:
        st.caption("CONTACT is extracted via regex (emails/phones). Model CONTACT output is effectively ignored by default.")

    # BUGFIX: this checkbox used to be created *inside* `if run_single:`.
    # st.button returns True only on the rerun triggered by the click, so
    # toggling the checkbox caused a rerun in which run_single was False and
    # the results (and the checkbox itself) disappeared. Widgets must be
    # created unconditionally so their state survives reruns.
    enable_propagation = st.checkbox("Propagate exact matches (recommended)", value=True)

    if run_single:
        with st.spinner("Running de-identification..."):
            final_ents = deidentify_note(
                tokenizer=tokenizer,
                model=model,
                device=device,
                text=text,
                pipe_max_len=pipe_max_len,
                thresh=THRESH,
                global_stoplist=GLOBAL_STOPLIST,
                stop_by_label=STOP_BY_LABEL,
                use_windows=use_windows,
                window_chars=window_chars,
                overlap_chars=overlap_chars,
            )
        # Optionally re-tag other exact occurrences of found NAME/ID/CONTACT strings.
        if enable_propagation:
            final_ents = propagate_entities(text, final_ents)

        redacted = redact_text(text, final_ents, mode=redact_mode)

        left, right = st.columns([1, 1])

        with left:
            st.subheader("Entities")
            if final_ents:
                df = pd.DataFrame([{
                    "type": get_group(e),
                    "text": e["word"],
                    "score": float(e.get("score", 0.0)),
                    "start": int(e["start"]),
                    "end": int(e["end"]),
                } for e in final_ents])
                st.dataframe(df, use_container_width=True)
            else:
                st.write("No entities found.")

            st.download_button(
                "Download entities (JSON)",
                data=json.dumps(final_ents, indent=2).encode("utf-8"),
                file_name="entities.json",
                mime="application/json",
            )

        with right:
            st.subheader("Redacted output")
            st.text_area("Output", height=260, value=redacted)

            st.download_button(
                "Download redacted text",
                data=redacted.encode("utf-8"),
                file_name="redacted.txt",
                mime="text/plain",
            )

        if show_highlight:
            st.subheader("Highlighted original (for demo)")
            # Rendered in an iframe so the custom CSS cannot leak into the app.
            components.html(
                highlight_entities_html(text, final_ents),
                height=600,
                scrolling=True,
            )
636
+
637
with tab2:
    # ---- Batch workflow: CSV of notes in, redacted CSV out ----
    st.subheader("Batch processing")
    st.write("Upload a CSV (one note per row) or a TXT file (single note).")
    uploaded = st.file_uploader("Upload CSV or TXT", type=["csv", "txt"])

    if uploaded is not None:
        if uploaded.name.lower().endswith(".txt"):
            # TXT: preview only; the single-note tab is the intended path.
            raw = uploaded.getvalue().decode("utf-8", errors="replace")
            st.write("Detected TXT input (single note). Use the Single note tab for best UX.")
            st.text_area("Preview", value=raw[:5000], height=200)

        else:
            df_in = pd.read_csv(uploaded)
            st.write(f"Loaded CSV with {len(df_in)} rows and columns: {list(df_in.columns)}")
            text_col = st.selectbox("Text column", options=list(df_in.columns), index=0)
            max_rows = st.slider("Max rows to process (demo)", 1, min(5000, len(df_in)), min(200, len(df_in)), 1)

            if st.button("Run batch de-identification", type="primary"):
                out_rows = []
                progress = st.progress(0)
                for i in range(max_rows):
                    # ROBUSTNESS: use positional .iloc (and fetch the cell
                    # once). The old `df_in.loc[i, text_col]` used the loop
                    # counter as an index *label*, which breaks for any
                    # non-default index and did the lookup twice.
                    cell = df_in[text_col].iloc[i]
                    note = "" if pd.isna(cell) else str(cell)
                    ents = deidentify_note(
                        tokenizer=tokenizer,
                        model=model,
                        device=device,
                        text=note,
                        pipe_max_len=pipe_max_len,
                        thresh=THRESH,
                        global_stoplist=GLOBAL_STOPLIST,
                        stop_by_label=STOP_BY_LABEL,
                        use_windows=use_windows,
                        window_chars=window_chars,
                        overlap_chars=overlap_chars,
                    )
                    redacted = redact_text(note, ents, mode=redact_mode)
                    out_rows.append({
                        "row": i,
                        "redacted": redacted,
                        "entities_json": json.dumps(ents, ensure_ascii=False),
                        "n_entities": len(ents),
                    })
                    # Update the progress bar every 5 rows (and at the end).
                    if (i + 1) % 5 == 0 or (i + 1) == max_rows:
                        progress.progress((i + 1) / max_rows)

                out_df = pd.DataFrame(out_rows)
                st.success(f"Processed {max_rows} rows.")

                st.subheader("Batch results (preview)")
                st.dataframe(out_df.head(50), use_container_width=True)

                csv_bytes = out_df.to_csv(index=False).encode("utf-8")
                st.download_button(
                    "Download redacted CSV",
                    data=csv_bytes,
                    file_name="redacted_output.csv",
                    mime="text/csv",
                )
695
+
696
with tab3:
    # ---- Static documentation tab ----
    st.subheader("About / demo notes")
    st.markdown(
        """
- **Model source**: loaded directly from a Hugging Face `repo_id` (optionally pinned to a `revision`).
- **CONTACT**: extracted via regex (emails/phones). Model CONTACT output is typically redundant; regex wins on overlaps.
- **Long notes**: enable windowing to avoid truncation artifacts.
- **Security**: run locally for PHI. Do not deploy publicly without access control, logging controls, and a privacy review.
        """
    )

st.caption("Tip: set env vars HF_REPO_ID, HF_REVISION, HF_TOKEN for smoother demos.")