# SpotRemover — src/streamlit_app.py (BrundageLab, Hugging Face revision 50c7ced)
# app.py
# Streamlit "product-like" Vet De-ID demo (PIPELINE-FREE):
# - Loads model from a Hugging Face repo ID (public or private via HF token)
# - Runs token-classification via tokenizer+model directly (no HF pipeline kwargs issues)
# - Single-note + batch (CSV/TXT) processing
# - Highlighted redaction preview + entity table
# - Downloads: redacted text, JSON entities, redacted CSV
import os
import re
import json
from typing import List, Dict, Any, Optional
import streamlit.components.v1 as components
import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from dotenv import load_dotenv
from pathlib import Path
load_dotenv()
import base64
# Absolute path to this file's directory (e.g., /app/src)
# Absolute path to this file's directory (e.g., /app/src)
HERE = Path(__file__).resolve().parent
# Static assets (logo image) are expected to live next to this file.
ASSETS_DIR = HERE / "assets"
# Destination the header logo links to.
APP_HOME_URL = "https://www.brundagelab.org/research/apps/" # change to your desired destination
def _img_to_data_uri(path: str) -> str:
ext = os.path.splitext(path)[1].lower().lstrip(".")
mime = "image/png" if ext == "png" else "image/jpeg" if ext in {"jpg","jpeg"} else "image/svg+xml"
with open(path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("utf-8")
return f"data:{mime};base64,{b64}"
def brundage_header():
    """Render the page header: clickable lab logo (left) and app title (right)."""
    col1, col2 = st.columns([2, 5])
    with col1:
        # Embed the logo as a base64 data URI so it renders inside raw HTML.
        logo_path = ASSETS_DIR / "brundage_logo.png"
        logo_uri = _img_to_data_uri(logo_path)
        st.markdown(
            f"""
            <a href="{APP_HOME_URL}" target="_self" style="display:inline-block;">
            <img src="{logo_uri}" alt="Brundage Lab" style="width:256px; height:auto; cursor:pointer;" />
            </a>
            """,
            unsafe_allow_html=True,
        )
    with col2:
        # App title styled to match the lab site typography.
        st.markdown(
            """
            <div style="padding-top:24px;">
            <div style="font-size:34px; font-weight:850; letter-spacing:-0.02em; color:#111827;">
            SpotRemover: Veterinary De-Identification
            </div>
            </div>
            """,
            unsafe_allow_html=True
        )
def inject_brundage_theme():
    """Inject the site-wide CSS theme (fonts, buttons, inputs, cards, entity pills)."""
    st.markdown(
        """
        <style>
        /* ---- App canvas ---- */
        html, body, [class*="css"] {
        font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Apple Color Emoji","Segoe UI Emoji";
        }
        /* Remove Streamlit default top padding a bit */
        .block-container {
        padding-top: 2.2rem;
        padding-bottom: 2.5rem;
        max-width: 1200px;
        }
        /* Hide Streamlit chrome */
        #MainMenu {visibility: hidden;}
        footer {visibility: hidden;}
        header {visibility: hidden;}
        /* Headings: closer to your site (bold, clean) */
        h1, h2, h3, h4 {
        letter-spacing: -0.02em;
        color: #111827;
        }
        h1 { font-weight: 800; }
        h2 { font-weight: 800; }
        h3 { font-weight: 750; }
        /* Sidebar polish */
        section[data-testid="stSidebar"] {
        background: #FFFFFF;
        border-right: 1px solid #EEF2F7;
        }
        section[data-testid="stSidebar"] .block-container {
        padding-top: 1.5rem;
        }
        /* Buttons: purple primary like your site */
        .stButton > button {
        border-radius: 12px;
        padding: 0.55rem 0.95rem;
        border: 1px solid #E5E7EB;
        background: #FFFFFF;
        color: #111827;
        font-weight: 650;
        }
        .stButton > button:hover {
        border-color: #C7B7FF;
        background: #FBFAFF;
        }
        .stButton > button[kind="primary"] {
        background: #6D28D9;
        color: #FFFFFF;
        border: 1px solid #6D28D9;
        box-shadow: 0 6px 18px rgba(109, 40, 217, 0.18);
        }
        .stButton > button[kind="primary"]:hover {
        background: #5B21B6;
        border-color: #5B21B6;
        }
        /* Inputs: rounded + subtle border */
        div[data-baseweb="input"] > div,
        div[data-baseweb="textarea"] > div,
        div[data-baseweb="select"] > div {
        border-radius: 12px !important;
        border-color: #E5E7EB !important;
        box-shadow: none !important;
        }
        div[data-baseweb="textarea"] textarea {
        font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
        font-size: 13px;
        }
        /* Tabs: reduce Streamlit “blocky” feel */
        button[data-baseweb="tab"] {
        border-radius: 12px 12px 0 0;
        font-weight: 650;
        }
        /* Dataframe / tables: softer container */
        div[data-testid="stDataFrame"] {
        border: 1px solid #EEF2F7;
        border-radius: 14px;
        overflow: hidden;
        }
        /* “Card” helper class you can use via st.markdown */
        .card {
        border: 1px solid #EEF2F7;
        border-radius: 16px;
        padding: 14px 16px;
        background: #FFFFFF;
        box-shadow: 0 10px 24px rgba(17, 24, 39, 0.06);
        }
        /* Highlight rendering container (so text stays readable on light bg) */
        .note {
        white-space: pre-wrap;
        font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
        font-size: 13px;
        line-height: 1.45;
        color: #111827;
        }
        /* Entity pill styling (used in highlight) */
        .ent {
        border-radius: 10px;
        padding: 2px 6px;
        border: 1px solid rgba(17, 24, 39, 0.08);
        box-shadow: 0 6px 14px rgba(17, 24, 39, 0.06);
        }
        .ent sup {
        font-size: 10px;
        margin-left: 6px;
        opacity: 0.75;
        }
        /* Make captions less “default streamlit” */
        .stCaption, small {
        color: #6B7280;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )
# =========================
# Core utilities
# =========================
def get_group(ent: Dict[str, Any]) -> str:
    """Return the entity's label: prefer 'entity_group', then 'entity', else 'UNK'."""
    for key in ("entity_group", "entity"):
        label = ent.get(key)
        if label:
            return label
    return "UNK"
def norm_contact(s: str) -> str:
    """Normalize a contact string: emails -> trimmed lowercase; phones -> digits only."""
    cleaned = s.strip().lower()
    if "@" in cleaned:
        # Looks like an email; keep it as-is (already lowercased).
        return cleaned
    # Phone-like value: strip everything that is not a digit.
    return re.sub(r"\D", "", cleaned)
def resolve_overlaps(entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop overlapping spans, keeping the earliest-starting, then longest,
    then highest-scoring entity.

    Entities are sorted by (start, -length, -score). Because kept spans are
    accumulated in start order and are mutually non-overlapping, a candidate
    overlaps some kept span iff it starts before the furthest kept end — so a
    single running ``max_end`` replaces the previous O(n^2) pairwise scan.
    """
    ents = sorted(
        entities,
        key=lambda e: (e["start"], -(e["end"] - e["start"]), -float(e.get("score", 0.0)))
    )
    kept: List[Dict[str, Any]] = []
    max_end = -1  # furthest character end among kept spans
    for e in ents:
        # Every kept span starts at or before e["start"], so overlap reduces
        # to: does e start before the furthest kept end?
        if e["start"] >= max_end:
            kept.append(e)
            max_end = max(max_end, e["end"])
    return kept
def dedup_entities_by_span(ents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Remove entities that duplicate an earlier (label, start, end) triple,
    keeping the first occurrence and preserving order."""
    unique: List[Dict[str, Any]] = []
    seen_keys = set()
    for ent in ents:
        # Label resolution mirrors get_group(): entity_group, then entity, else UNK.
        label = ent.get("entity_group") or ent.get("entity") or "UNK"
        span_key = (label, int(ent["start"]), int(ent["end"]))
        if span_key not in seen_keys:
            seen_keys.add(span_key)
            unique.append(ent)
    return unique
def is_placeholder(word: str) -> bool:
    """True if *word* looks like a blank/fill-in token (e.g. '____', '- ( )')."""
    w = word.strip()
    # Entirely underscores / whitespace / dashes / parens -> placeholder.
    if re.fullmatch(r"[_\s\-\(\)]+", w):
        return True
    # Mostly underscores with fewer than two real characters left over.
    return w.count("_") >= 2 and len(re.sub(r"[_\s\-\(\)]", "", w)) < 2
def merge_adjacent_entities(entities: List[Dict[str, Any]], text: str) -> List[Dict[str, Any]]:
    """
    Merge same-label spans separated only by safe punctuation/whitespace.
    Prevent merges across newlines / field boundaries.
    """
    if not entities:
        return []
    ordered = sorted(entities, key=lambda x: x["start"])
    result = [dict(ordered[0])]
    for cand in ordered[1:]:
        prev = result[-1]
        between = text[prev["end"]:cand["start"]]
        # Never merge across a line break: notes are field/line oriented.
        if "\n" in between or "\r" in between:
            result.append(dict(cand))
            continue
        prev_label = prev.get("entity_group") or prev.get("entity") or "UNK"
        cand_label = cand.get("entity_group") or cand.get("entity") or "UNK"
        gap_len = cand["start"] - prev["end"]
        gap_is_safe = re.fullmatch(r"[ \t,./\-()]*", between) is not None
        if prev_label == cand_label and gap_len <= 3 and gap_is_safe:
            # Absorb cand into prev: extend the span, keep the higher score.
            prev["end"] = cand["end"]
            prev["word"] = text[prev["start"]:prev["end"]]
            prev["score"] = max(float(prev.get("score", 0.0)), float(cand.get("score", 0.0)))
        else:
            result.append(dict(cand))
    return result
def find_structured_pii(text: str) -> List[Dict[str, Any]]:
    """Regex-detect structured PII and return it as CONTACT entities.

    Emails first, then US-style phone numbers; each hit carries score=1.0 and
    character offsets into *text*. Downstream, these regex hits win over any
    overlapping model prediction.
    """
    hits: List[Dict[str, Any]] = []
    # Emails. FIX: the TLD class was [A-Z|a-z]{2,} — the '|' is a literal
    # inside a character class, so junk like "a@b.c|d" could match.
    for m in re.finditer(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", text):
        hits.append({"word": m.group(), "entity_group": "CONTACT", "score": 1.0, "start": m.start(), "end": m.end()})
    # Phones (US-ish): 3-3-4 digits with optional parens/separators.
    for m in re.finditer(r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}", text):
        hits.append({"word": m.group(), "entity_group": "CONTACT", "score": 1.0, "start": m.start(), "end": m.end()})
    return hits
def redact_text(text: str, entities: List[Dict[str, Any]], mode: str = "tags") -> str:
    """
    Replace entity spans in *text*.

    mode="tags": each span becomes [LABEL] (e.g. [NAME], [LOC]).
    mode="char": each span becomes '*' repeated to preserve its length.
    """
    # Splice right-to-left so earlier character offsets remain valid.
    spans = sorted(resolve_overlaps(entities), key=lambda x: x["start"], reverse=True)
    result = text
    for span in spans:
        s, e = span["start"], span["end"]
        label = span.get("entity_group") or span.get("entity") or "UNK"
        if mode == "tags":
            replacement = f"[{label}]"
        else:
            replacement = "*" * max(1, e - s)
        result = result[:s] + replacement + result[e:]
    return result
def highlight_entities_html(text: str, entities: List[Dict[str, Any]]) -> str:
    """Render *text* as a self-contained HTML fragment with entity highlights.

    Overlaps are resolved first; each remaining span becomes a color-coded
    <span> whose background opacity scales with model confidence, plus a
    hover pill showing the label. FIX: the tooltip previously concatenated
    label and score with no separator ("NAME0.95"); a space is now inserted.
    Returns CSS + a <div class='note'> wrapping the escaped text.
    """
    entities = resolve_overlaps(entities)
    entities = sorted(entities, key=lambda x: x["start"])
    # RGBA base colors (R,G,B); alpha is scaled by score
    palette_rgb = {
        "NAME": (124, 58, 237),
        "LOC": (59, 130, 246),
        "ORG": (16, 185, 129),
        "DATE": (244, 63, 94),
        "ID": (234, 179, 8),
        "CONTACT": (14, 165, 233),
        "UNK": (107, 114, 128),
    }

    def esc(s: str) -> str:
        # Minimal HTML escaping for text interpolated into the markup.
        return (s.replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace(">", "&gt;")
                .replace('"', "&quot;")
                .replace("'", "&#39;"))

    css = """
    <style>
    .note {
    white-space: pre-wrap;
    font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
    font-size: 13px;
    line-height: 1.45;
    /* Light card styling */
    color: #111827;
    background: #FFFFFF;
    border: 1px solid #EEF2F7;
    padding: 12px 14px;
    border-radius: 14px;
    box-shadow: 0 10px 24px rgba(17, 24, 39, 0.06);
    }
    .ent {
    position: relative;
    border-radius: 8px;
    padding: 1px 4px;
    margin: 0px 1px;
    box-decoration-break: clone;
    -webkit-box-decoration-break: clone;
    transition: filter 120ms ease;
    border: 1px solid rgba(17, 24, 39, 0.08);
    }
    .ent:hover { filter: brightness(1.03); }
    .ent::after {
    content: "";
    position: absolute;
    left: 6px; right: 6px; bottom: -2px;
    height: 2px;
    border-radius: 2px;
    background: rgba(var(--rgb), 0.85);
    }
    .pill {
    display: none;
    position: absolute;
    top: -18px;
    left: 0px;
    font-size: 10px;
    line-height: 1;
    padding: 3px 8px;
    border-radius: 999px;
    background: rgba(var(--rgb), 0.95);
    color: #111827;
    box-shadow: 0 6px 16px rgba(17, 24, 39, 0.12);
    white-space: nowrap;
    z-index: 5;
    }
    .ent:hover .pill { display: inline-block; }
    </style>
    """
    out = []
    cursor = 0
    for e in entities:
        s, t = e["start"], e["end"]
        if s < cursor:
            # Defensive: skip anything still overlapping after resolution.
            continue
        out.append(esc(text[cursor:s]))
        label = get_group(e)
        r, g, b = palette_rgb.get(label, palette_rgb["UNK"])
        score = float(e.get("score", 0.0))
        # background alpha: 0.10 to 0.32 depending on confidence
        alpha = 0.10 + 0.22 * max(0.0, min(1.0, score))
        span_text = esc(text[s:t])
        title = f"{label} {score:.2f}"  # was f"{label}{score:.2f}" -> "NAME0.95"
        out.append(
            f'<span class="ent" title="{esc(title)}" style="--rgb:{r},{g},{b}; background: rgba({r},{g},{b},{alpha});">'
            f'{span_text}'
            f'<span class="pill">{label}</span>'
            f"</span>"
        )
        cursor = t
    out.append(esc(text[cursor:]))
    return css + "<div class='note'>" + "".join(out) + "</div>"
# =========================
# Model loading from HF (NO PIPELINE)
# =========================
@st.cache_resource
def load_hf_model(
    repo_id: str,
    revision: Optional[str],
    hf_token: Optional[str],
    device_str: str,
):
    """Load a tokenizer + token-classification model from the Hugging Face Hub.

    Cached by Streamlit (st.cache_resource) so reruns with the same arguments
    reuse the loaded objects. Returns (tokenizer, model, device); the model is
    moved to *device_str* and switched to eval mode.

    Args:
        repo_id: HF repo id (public or private).
        revision: optional git revision/tag pinned for reproducibility.
        hf_token: HF access token (needed for private repos; may be None).
        device_str: torch device string, e.g. "cpu" or "cuda:0".
    """
    device = torch.device(device_str)
    tok = AutoTokenizer.from_pretrained(repo_id, revision=revision, token=hf_token)
    mdl = AutoModelForTokenClassification.from_pretrained(repo_id, revision=revision, token=hf_token)
    mdl.to(device)
    mdl.eval()
    return tok, mdl, device
# =========================
# NER: model-based inference with offsets (BIO -> spans)
# =========================
def ner_call_model(tokenizer, model, text: str, max_len: int, device: torch.device) -> List[Dict[str, Any]]:
    """Run token classification on *text* and decode BIO tags into char spans.

    Tokenizes with offset mapping, takes the argmax label per token, then
    groups B-/I- runs into entities carrying character offsets into *text*.
    Returns dicts with keys: word, entity_group, start, end, score (score is
    the mean of per-token max softmax probabilities across the span).
    NOTE: input is truncated at *max_len* tokens; use the windowed runner for
    long notes.
    """
    enc = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
        padding=False,
    )
    # Offsets map token index -> (char_start, char_end); keep them on CPU.
    offsets = enc.pop("offset_mapping")[0].tolist()
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.inference_mode():
        logits = model(**enc).logits[0]  # (seq_len, num_labels)
    probs = torch.softmax(logits, dim=-1)
    pred_ids = probs.argmax(dim=-1).tolist()
    pred_scores = probs.max(dim=-1).values.tolist()
    id2label = model.config.id2label

    def id_to_label(i: int) -> str:
        # Some configs key id2label by int, others by str (JSON round-trip).
        if i in id2label:
            return id2label[i]
        return id2label.get(str(i), "O")

    labels = [id_to_label(i) for i in pred_ids]
    entities: List[Dict[str, Any]] = []
    i = 0
    while i < len(labels):
        lab = labels[i]
        s, e = offsets[i]
        # skip special/empty tokens (e.g. CLS/SEP have zero-width offsets)
        if s == e:
            i += 1
            continue
        if lab == "O":
            i += 1
            continue
        # if I- without B-, treat as B- (tolerate malformed BIO sequences)
        if lab.startswith("I-"):
            lab = "B-" + lab[2:]
        if lab.startswith("B-"):
            typ = lab[2:]
            start = s
            end = e
            scores = [pred_scores[i]]
            j = i + 1
            # Extend the span while continuation tokens (I-<typ>) follow.
            while j < len(labels):
                lab2 = labels[j]
                s2, e2 = offsets[j]
                if s2 == e2:
                    j += 1
                    continue
                if lab2 == f"I-{typ}":
                    end = e2
                    scores.append(pred_scores[j])
                    j += 1
                    continue
                break
            entities.append({
                "word": text[start:end],
                "entity_group": typ,
                "start": start,
                "end": end,
                "score": float(sum(scores) / max(1, len(scores))),  # mean token confidence
            })
            i = j
        else:
            i += 1
    return entities
def run_ner_with_windows_model(
    tokenizer,
    model,
    device: torch.device,
    text: str,
    pipe_max_len: int,
    window_chars: int = 2000,
    overlap_chars: int = 250,
) -> List[Dict[str, Any]]:
    """Run NER over a long note in overlapping character windows.

    Each window of *window_chars* characters is passed to ner_call_model and
    the resulting chunk-local offsets are shifted back into full-text
    coordinates. Consecutive windows share *overlap_chars* of context so
    entities straddling a boundary are still seen whole (duplicates are
    removed downstream).

    FIX: with overlap_chars >= window_chars (reachable via the sidebar
    sliders), ``start = max(0, end - overlap_chars)`` could fail to advance,
    looping forever; the next start is now forced past the previous one.
    """
    ents: List[Dict[str, Any]] = []
    start = 0
    n = len(text)
    while start < n:
        end = min(n, start + window_chars)
        chunk = text[start:end]
        chunk_ents = ner_call_model(tokenizer, model, chunk, max_len=pipe_max_len, device=device)
        for e in chunk_ents:
            e = dict(e)
            # Shift chunk-local offsets back to full-text coordinates.
            e["start"] += start
            e["end"] += start
            e["word"] = text[e["start"]:e["end"]]
            ents.append(e)
        if end == n:
            break
        # Step forward keeping `overlap_chars` of context, but always make
        # progress (guards against an infinite loop with pathological settings).
        start = max(start + 1, end - overlap_chars)
    return ents
def propagate_entities(text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Add additional spans by exact/normalized string matching for selected entity types.
    Returns a new entity list (original + propagated), resolved/deduped.
    """
    # Which labels to propagate and how
    PROPAGATE = {"CONTACT", "ID", "NAME"}  # consider adding DATE if needed
    MIN_ID_LEN = 5    # tune: avoid 2-3 digit labs, doses
    MIN_NAME_LEN = 4  # avoid tiny tokens

    # Build (label, regex, flags) patterns from the detected entities.
    patterns = []
    for ent in entities:
        label = ent.get("entity_group") or ent.get("entity") or "UNK"
        if label not in PROPAGATE:
            continue
        value = ent["word"].strip()
        if not value:
            continue
        if label == "CONTACT":
            # Exact string match (case-insensitive for emails)
            patterns.append((label, re.escape(value), re.IGNORECASE))
        elif label == "ID":
            # Only propagate "ID-like" digit runs of sufficient length.
            digits = re.sub(r"\D", "", value)
            if len(digits) < MIN_ID_LEN:
                continue
            # Match the same digit sequence allowing separators,
            # e.g. 261808 matches "261808" or "261-808" if present.
            patterns.append((label, r"\D*".join(digits), 0))
        else:  # NAME
            # Prefer multi-token names; be conservative with single tokens
            # (in vet notes, single-token patient names are still PII).
            multi_token = re.search(r"\s", value) is not None
            if not multi_token and len(value) < MIN_NAME_LEN:
                continue
            # Exact token/phrase match with word boundaries.
            patterns.append((label, r"\b" + re.escape(value) + r"\b", re.IGNORECASE))

    # Collect every additional occurrence of each pattern in the note.
    propagated = [
        {
            "word": text[m.start():m.end()],
            "entity_group": label,
            "score": 1.0,  # propagated
            "start": m.start(),
            "end": m.end(),
            "source": "propagated",
        }
        for label, pattern, flags in patterns
        for m in re.finditer(pattern, text, flags=flags)
    ]
    combined = sorted(list(entities) + propagated, key=lambda x: x["start"])
    return resolve_overlaps(dedup_entities_by_span(combined))
def deidentify_note(
    tokenizer,
    model,
    device: torch.device,
    text: str,
    pipe_max_len: int,
    thresh: Dict[str, float],
    global_stoplist: set,
    stop_by_label: Dict[str, set],
    use_windows: bool,
    window_chars: int,
    overlap_chars: int,
) -> List[Dict[str, Any]]:
    """De-identify one note: model NER + regex CONTACT, filtered and merged.

    Pipeline: (1) run the model (optionally windowed for long notes),
    (2) merge adjacent same-label spans, (3) add regex-found emails/phones as
    CONTACT, (4) drop model spans below their per-label threshold, placeholder
    tokens, stoplisted words, tiny non-numeric tokens, and any model span that
    overlaps a regex CONTACT hit (regex wins), (5) sort/dedup/resolve.
    Returns the final entity list sorted by start offset.
    """
    def pass_thresh(ent):
        # Per-label confidence gate; falls back to the "_default" threshold.
        g = get_group(ent)
        return float(ent.get("score", 0.0)) >= float(thresh.get(g, thresh.get("_default", 0.45)))

    def stoplisted(ent):
        # Case-insensitive lookup in the global and per-label stoplists.
        g = get_group(ent)
        w = ent["word"].strip().lower()
        if w in global_stoplist:
            return True
        return w in stop_by_label.get(g, set())

    # BERT
    if use_windows:
        bert_results = run_ner_with_windows_model(
            tokenizer, model, device, text,
            pipe_max_len=pipe_max_len,
            window_chars=window_chars,
            overlap_chars=overlap_chars,
        )
    else:
        bert_results = ner_call_model(tokenizer, model, text, max_len=pipe_max_len, device=device)
    # Merge adjacent same-label entities
    bert_results = merge_adjacent_entities(bert_results, text)
    # Regex CONTACT
    regex_results = find_structured_pii(text)
    final_entities: List[Dict[str, Any]] = []
    final_entities.extend(regex_results)
    for ent in bert_results:
        word = ent["word"].strip()
        if not pass_thresh(ent):
            continue
        if is_placeholder(word):
            continue
        if stoplisted(ent):
            continue
        if len(word) < 2 and not word.isdigit():
            continue
        # if overlaps regex CONTACT, skip BERT (regex wins)
        dup = False
        for reg in regex_results:
            if ent["start"] < reg["end"] and ent["end"] > reg["start"]:
                dup = True
                break
        if dup:
            continue
        final_entities.append(ent)
    final_entities = sorted(final_entities, key=lambda x: x["start"])
    final_entities = dedup_entities_by_span(final_entities)
    final_entities = resolve_overlaps(final_entities)
    return final_entities
# =========================
# Streamlit UI
# =========================
st.set_page_config(page_title="SpotRemover", layout="wide")
st.markdown(
"""
<div style="display:flex; gap:8px; flex-wrap:wrap; margin: 8px 0 16px 0;">
<span style="border:1px solid #E5E7EB; padding:4px 10px; border-radius:999px; font-weight:600; font-size:13px;">De-ID</span>
<span style="border:1px solid #E5E7EB; padding:4px 10px; border-radius:999px; font-weight:600; font-size:13px;">NER</span>
<span style="border:1px solid #E5E7EB; padding:4px 10px; border-radius:999px; font-weight:600; font-size:13px;">Veterinary</span>
<span style="border:1px solid #E5E7EB; padding:4px 10px; border-radius:999px; font-weight:600; font-size:13px;">One Health</span>
</div>
""",
unsafe_allow_html=True
)
brundage_header()
st.markdown("<div style='height:14px;'></div>", unsafe_allow_html=True)
inject_brundage_theme()
# ---- Sidebar: model selection, runtime options, thresholds ----
with st.sidebar:
    st.header("Model")
    # 1) Define your private fine-tuned model repo IDs (store actual values in env or hardcode)
    # Option A (recommended): keep repo IDs in env so you don't commit them
    MODEL_REGISTRY = {
        "VetBERT (fine-tuned)": os.environ.get("HF_REPO_VETBERT", ""),
        "PetBERT (fine-tuned)": os.environ.get("HF_REPO_PETBERT", ""),
        "ClinicalBERT (fine-tuned)": os.environ.get("HF_REPO_CLINICALBERT", ""),
    }
    # 2) Dropdown selector
    model_label = st.selectbox(
        "Select model",
        options=list(MODEL_REGISTRY.keys()),
        index=0,
    )
    repo_id = MODEL_REGISTRY[model_label]
    # 3) Optional revision (still OK to keep)
    revision = (os.environ.get("HF_REVISION", "").strip() or None)
    # 4) Token comes ONLY from environment
    hf_token = (os.environ.get("HF_TOKEN", "").strip() or None)
    # Fail fast with a visible error if required configuration is missing.
    if not repo_id:
        st.error("Model repo_id is not set. Define HF_REPO_VETBERT / HF_REPO_PETBERT / HF_REPO_CLINICALBERT.")
        st.stop()
    if hf_token is None:
        st.error("HF_TOKEN environment variable is not set (required for private models).")
        st.stop()
    st.header("Runtime")
    # GPU is disabled for the hosted demo; flip to True where CUDA is available.
    use_gpu = False
    device_str = "cuda:0" if (use_gpu and torch.cuda.is_available()) else "cpu"
    pipe_max_len = st.selectbox("Max token length", options=[256, 512], index=0)
    use_windows = st.checkbox("Window long notes (recommended)", value=True)
    window_chars = st.slider("Window size (chars)", 500, 6000, 2000, 100)
    overlap_chars = st.slider("Window overlap (chars)", 0, 1000, 250, 25)
    st.header("Thresholds")
    # Per-label confidence thresholds fed into deidentify_note().
    t_name = st.slider("NAME", 0.0, 1.0, 0.60, 0.01)
    t_org = st.slider("ORG", 0.0, 1.0, 0.60, 0.01)
    t_loc = st.slider("LOC", 0.0, 1.0, 0.60, 0.01)
    t_date = st.slider("DATE", 0.0, 1.0, 0.45, 0.01)
    t_id = st.slider("ID", 0.0, 1.0, 0.50, 0.01)
    t_contact = st.slider("CONTACT (model)", 0.0, 1.0, 0.99, 0.01)  # regex-first anyway
    t_default = st.slider("Default", 0.0, 1.0, 0.45, 0.01)
    redact_mode = st.selectbox("Redaction mode", options=["tags", "char"], index=0)
    show_highlight = st.checkbox("Show highlighted original", value=True)
# Load model/tokenizer
try:
    tokenizer, model, device = load_hf_model(repo_id=repo_id, revision=revision, hf_token=hf_token, device_str=device_str)
except Exception as e:
    # Surface the failing repo/revision to aid debugging, then halt the app.
    st.error(f"Failed to load model/tokenizer from HF.\n\nrepo_id={repo_id}\nrevision={revision}\n\n{e}")
    st.stop()
# Stoplists (can be made editable later)
GLOBAL_STOPLIST = {"er", "ve", "w", "dvm", "mph", "sex", "male", "female", "kg", "lb", "patient", "owner", "left", "right"}
STOP_BY_LABEL = {
    "LOC": {"dsh", "feline", "canine", "equine", "bovine", "species", "breed", "color"},
    "NAME": {"owner", "patient"},
}
# Per-label thresholds assembled from the sidebar sliders.
THRESH = {
    "NAME": t_name,
    "ORG": t_org,
    "LOC": t_loc,
    "DATE": t_date,
    "ID": t_id,
    "CONTACT": t_contact,
    "_default": t_default,
}
tab1, tab2, tab3 = st.tabs(["Single note", "Batch (CSV/TXT)", "About"])
# ---- Tab 1: de-identify one pasted note ----
with tab1:
    st.subheader("Single note")
    default_text = "Paste a veterinary note here..."
    text = st.text_area("Input", height=260, value=default_text)
    colA, colB = st.columns([1, 1])
    with colA:
        run_single = st.button("Run", type="primary")
    with colB:
        # FIX: this checkbox used to be created inside `if run_single:`, so it
        # only existed during the single rerun triggered by the Run click and
        # could never actually be toggled before results were computed.
        # Declaring it before the button makes it persistent and effective.
        enable_propagation = st.checkbox("Propagate exact matches (recommended)", value=True)
    if run_single:
        with st.spinner("Running de-identification..."):
            final_ents = deidentify_note(
                tokenizer=tokenizer,
                model=model,
                device=device,
                text=text,
                pipe_max_len=pipe_max_len,
                thresh=THRESH,
                global_stoplist=GLOBAL_STOPLIST,
                stop_by_label=STOP_BY_LABEL,
                use_windows=use_windows,
                window_chars=window_chars,
                overlap_chars=overlap_chars,
            )
        if enable_propagation:
            # Mirror each detected NAME/ID/CONTACT across the whole note.
            final_ents = propagate_entities(text, final_ents)
        redacted = redact_text(text, final_ents, mode=redact_mode)
        left, right = st.columns([1, 1])
        with left:
            st.subheader("Entities")
            if final_ents:
                df = pd.DataFrame([{
                    "type": get_group(e),
                    "text": e["word"],
                    "score": float(e.get("score", 0.0)),
                    "start": int(e["start"]),
                    "end": int(e["end"]),
                } for e in final_ents])
                st.dataframe(df, use_container_width=True)
            else:
                st.write("No entities found.")
            st.download_button(
                "Download entities (JSON)",
                data=json.dumps(final_ents, indent=2).encode("utf-8"),
                file_name="entities.json",
                mime="application/json",
            )
        with right:
            st.subheader("Redacted output")
            st.text_area("Output", height=260, value=redacted)
            st.download_button(
                "Download redacted text",
                data=redacted.encode("utf-8"),
                file_name="redacted.txt",
                mime="text/plain",
            )
        if show_highlight:
            st.subheader("Highlighted original")
            # Render inside an iframe component so the injected <style> applies.
            components.html(
                highlight_entities_html(text, final_ents),
                height=600,
                scrolling=True,
            )
# ---- Tab 2: batch de-identification of an uploaded CSV (or TXT preview) ----
with tab2:
    st.subheader("Batch processing")
    st.write("Upload a CSV (one note per row) or a TXT file (single note).")
    uploaded = st.file_uploader("Upload CSV or TXT", type=["csv", "txt"])
    if uploaded is not None:
        if uploaded.name.lower().endswith(".txt"):
            # TXT is treated as one note; just preview it here.
            raw = uploaded.getvalue().decode("utf-8", errors="replace")
            st.write("Detected TXT input (single note). Use the Single note tab for best UX.")
            st.text_area("Preview", value=raw[:5000], height=200)
        else:
            df_in = pd.read_csv(uploaded)
            st.write(f"Loaded CSV with {len(df_in)} rows and columns: {list(df_in.columns)}")
            text_col = st.selectbox("Text column", options=list(df_in.columns), index=0)
            max_rows = st.slider("Max rows to process (demo)", 1, min(5000, len(df_in)), min(200, len(df_in)), 1)
            if st.button("Run batch de-identification", type="primary"):
                out_rows = []
                progress = st.progress(0)
                # FIX: use positional access (.iloc) instead of label-based
                # .loc so the row counter is robust even if the frame ends up
                # with a non-default index; also hoist the column lookup.
                col_values = df_in[text_col]
                for i in range(max_rows):
                    cell = col_values.iloc[i]
                    note = str(cell) if pd.notna(cell) else ""
                    ents = deidentify_note(
                        tokenizer=tokenizer,
                        model=model,
                        device=device,
                        text=note,
                        pipe_max_len=pipe_max_len,
                        thresh=THRESH,
                        global_stoplist=GLOBAL_STOPLIST,
                        stop_by_label=STOP_BY_LABEL,
                        use_windows=use_windows,
                        window_chars=window_chars,
                        overlap_chars=overlap_chars,
                    )
                    redacted = redact_text(note, ents, mode=redact_mode)
                    out_rows.append({
                        "row": i,
                        "redacted": redacted,
                        "entities_json": json.dumps(ents, ensure_ascii=False),
                        "n_entities": len(ents),
                    })
                    # Update the progress bar every 5 rows (and on the final row).
                    if (i + 1) % 5 == 0 or (i + 1) == max_rows:
                        progress.progress((i + 1) / max_rows)
                out_df = pd.DataFrame(out_rows)
                st.success(f"Processed {max_rows} rows.")
                st.subheader("Batch results (preview)")
                st.dataframe(out_df.head(50), use_container_width=True)
                csv_bytes = out_df.to_csv(index=False).encode("utf-8")
                st.download_button(
                    "Download redacted CSV",
                    data=csv_bytes,
                    file_name="redacted_output.csv",
                    mime="text/csv",
                )
# ---- Tab 3: static project description (rendered as Markdown) ----
with tab3:
    st.subheader("About")
    st.markdown(
        """
        ### About this tool
        This interactive demo is part of the **Brundage Lab (brundagelab.org)** research program on **AI methods for veterinary clinical text** and privacy-preserving data sharing for veterinary and One Health applications.
        **What it does**
        - Performs **veterinary de-identification** on free-text clinical narratives by detecting and redacting identifiers such as **owner/client names**, **addresses/locations**, **dates**, **IDs**, and **contact information**.
        - Uses a **fine-tuned transformer NER model** (selectable backbone such as VetBERT / PetBERT / ClinicalBERT) loaded from a **private Hugging Face repository**.
        - Augments model predictions with **high-precision pattern matching** for structured identifiers (e.g., emails and phone numbers).
        **How to interpret results**
        - This tool prioritizes **high recall** for sensitive identifiers (reducing false negatives), with thresholds adjustable in the sidebar.
        - The highlighted view is provided for **demonstration and error analysis**; the redacted output is the intended downstream artifact.
        **Engineering notes**
        - **Model source**: loaded directly from Hugging Face (optionally pinned to a specific revision for reproducibility).
        - **CONTACT**: extracted via regex (emails/phones). If the model also predicts CONTACT, regex is treated as the source of truth on overlaps.
        - **Long notes**: optional windowing reduces truncation artifacts and improves coverage across multi-page notes.
        **Privacy and intended use**
        - This is a **research and demonstration tool**, not a certified de-identification system.
        - Do **not** paste sensitive/regulated data unless you are running the tool in an approved environment with appropriate controls.
        - For any public deployment, ensure **access control**, **minimal logging**, and a **privacy/security review** consistent with your institution’s policies.
        **Citation / attribution**
        If you use this tool or its outputs in a manuscript, please cite the Brundage Lab and describe the model backbone, training data composition (real vs. synthetic), and evaluation protocol.
        """
    )
st.caption("Tip: Select your model backbone and explore a single document. Modify default thresholds to finetune your performance.")