# app.py
# Streamlit "product-like" Vet De-ID demo (PIPELINE-FREE):
# - Loads model from a Hugging Face repo ID (public or private via HF token)
# - Runs token-classification via tokenizer+model directly (no HF pipeline kwargs issues)
# - Single-note + batch (CSV/TXT) processing
# - Highlighted redaction preview + entity table
# - Downloads: redacted text, JSON entities, redacted CSV

import base64
import html
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
import torch
from dotenv import load_dotenv
from transformers import AutoModelForTokenClassification, AutoTokenizer

load_dotenv()

# Absolute path to this file's directory (e.g., /app/src)
HERE = Path(__file__).resolve().parent
ASSETS_DIR = HERE / "assets"
APP_HOME_URL = "https://www.brundagelab.org/research/apps/"  # change to your desired destination

# Session-state keys used to keep results visible across Streamlit reruns.
_SINGLE_RESULT_KEY = "single_result"
_BATCH_RESULT_KEY = "batch_result"

# Characters allowed between two same-label spans for them to be merged.
_SAFE_GAP_RE = re.compile(r"[ \t,./\-()]*")
_EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
_PHONE_RE = re.compile(r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")


def _img_to_data_uri(path: str) -> str:
    """Read an image file and return it as a base64 ``data:`` URI.

    The MIME type is chosen from the file extension: ``png`` and
    ``jpg``/``jpeg`` are recognised, anything else is labelled SVG.
    ``path`` may be a string or a ``pathlib.Path``.
    """
    ext = os.path.splitext(path)[1].lower().lstrip(".")
    if ext == "png":
        mime = "image/png"
    elif ext in {"jpg", "jpeg"}:
        mime = "image/jpeg"
    else:
        mime = "image/svg+xml"
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{b64}"


def brundage_header():
    """Render the page header: lab logo (linked to APP_HOME_URL) and app title."""
    col1, col2 = st.columns([2, 5])
    with col1:
        logo_path = ASSETS_DIR / "brundage_logo.png"
        try:
            logo_uri = _img_to_data_uri(logo_path)
        except OSError:
            # A missing/unreadable logo must not take the whole app down.
            logo_uri = None
        if logo_uri:
            inner = f'<img src="{logo_uri}" alt="Brundage Lab" class="brundage-logo">'
        else:
            inner = "Brundage Lab"
        st.markdown(
            f'<a href="{APP_HOME_URL}" target="_blank" rel="noopener noreferrer">{inner}</a>',
            unsafe_allow_html=True,
        )
    with col2:
        st.markdown(
            '<h1 class="brundage-title">SpotRemover: Veterinary De-Identification</h1>',
            unsafe_allow_html=True,
        )


def inject_brundage_theme():
    """Inject the small stylesheet used by the header, pills and divider."""
    st.markdown(
        """
<style>
  .brundage-logo { max-height: 72px; max-width: 100%; }
  .brundage-title { margin: 0; padding-top: 0.25rem; font-weight: 700; }
  .brundage-pill {
    display: inline-block; margin-right: 6px; padding: 2px 10px;
    border-radius: 999px; border: 1px solid rgba(0, 0, 0, 0.15);
    font-size: 0.8rem; font-weight: 600;
  }
  .brundage-rule { border: 0; border-top: 1px solid rgba(0, 0, 0, 0.12); margin: 0.5rem 0 1rem 0; }
</style>
        """,
        unsafe_allow_html=True,
    )


# =========================
# Core utilities
# =========================
def get_group(ent: Dict[str, Any]) -> str:
    """Return the entity label: ``entity_group``, else ``entity``, else ``"UNK"``."""
    return ent.get("entity_group") or ent.get("entity") or "UNK"


def norm_contact(s: str) -> str:
    """Normalise a contact string: lower-cased email, or digits only otherwise."""
    s = s.strip().lower()
    if "@" in s:
        return s
    return re.sub(r"\D", "", s)


def resolve_overlaps(entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return a non-overlapping subset of ``entities``.

    Candidates are visited by ascending start; among spans with the same
    start the longest wins, then the higher score. A candidate is kept only
    if it does not overlap a span that was already kept.
    """
    ents = sorted(
        entities,
        key=lambda e: (e["start"], -(e["end"] - e["start"]), -float(e.get("score", 0.0))),
    )
    kept: List[Dict[str, Any]] = []
    for e in ents:
        overlaps_kept = any(e["start"] < k["end"] and e["end"] > k["start"] for k in kept)
        if not overlaps_kept:
            kept.append(e)
    return kept


def dedup_entities_by_span(ents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop entities repeating an earlier (label, start, end); order is preserved."""
    seen = set()
    out = []
    for e in ents:
        key = (get_group(e), int(e["start"]), int(e["end"]))
        if key in seen:
            continue
        seen.add(key)
        out.append(e)
    return out


def is_placeholder(word: str) -> bool:
    """Return True for form-field filler such as ``____`` or ``(__) -``."""
    w = word.strip()
    if re.fullmatch(r"[_\s\-\(\)]+", w):
        return True
    if w.count("_") >= 2 and len(re.sub(r"[_\s\-\(\)]", "", w)) < 2:
        return True
    return False


def merge_adjacent_entities(entities: List[Dict[str, Any]], text: str) -> List[Dict[str, Any]]:
    """
    Merge same-label spans separated only by safe punctuation/whitespace.
    Prevent merges across newlines / field boundaries.

    Two neighbouring spans (by start offset) are merged when they share a
    label, are at most 3 characters apart, and the gap contains only
    spaces, tabs or ``, . / - ( )``. The merged span keeps the higher score.
    Input dicts are copied, never mutated.
    """
    if not entities:
        return []

    ordered = sorted(entities, key=lambda x: x["start"])
    merged = [dict(ordered[0])]
    for nxt in ordered[1:]:
        cur = merged[-1]
        gap_text = text[cur["end"]:nxt["start"]]
        gap = nxt["start"] - cur["end"]

        mergeable = (
            get_group(cur) == get_group(nxt)
            and gap <= 3
            and "\n" not in gap_text
            and "\r" not in gap_text
            and _SAFE_GAP_RE.fullmatch(gap_text) is not None
        )
        if not mergeable:
            merged.append(dict(nxt))
            continue

        # Overlapping windows can yield a span nested inside the current
        # one; never let a merge shrink the span already collected.
        new_end = max(cur["end"], nxt["end"])
        cur["word"] = text[cur["start"]:new_end]
        cur["end"] = new_end
        cur["score"] = max(float(cur.get("score", 0.0)), float(nxt.get("score", 0.0)))
    return merged


def find_structured_pii(text: str) -> List[Dict[str, Any]]:
    """Find emails and US-style phone numbers with regexes.

    Every hit is returned as a ``CONTACT`` entity with score 1.0.
    """
    hits = []
    # Emails first, then phones (US-ish); both are reported as CONTACT.
    for pattern in (_EMAIL_RE, _PHONE_RE):
        for m in pattern.finditer(text):
            hits.append({
                "word": m.group(),
                "entity_group": "CONTACT",
                "score": 1.0,
                "start": m.start(),
                "end": m.end(),
            })
    return hits


def redact_text(text: str, entities: List[Dict[str, Any]], mode: str = "tags") -> str:
    """
    mode="tags": [NAME], [LOC], etc.
    mode="char": ***** preserving length

    Overlapping entities are resolved first. Any ``mode`` other than
    ``"tags"`` masks with asterisks.
    """
    entities = resolve_overlaps(entities)
    # Replace right-to-left so earlier offsets stay valid after each edit.
    entities = sorted(entities, key=lambda x: x["start"], reverse=True)
    redacted = text
    for ent in entities:
        start, end = ent["start"], ent["end"]
        label = get_group(ent)
        replacement = f"[{label}]" if mode == "tags" else "*" * max(1, (end - start))
        redacted = redacted[:start] + replacement + redacted[end:]
    return redacted


def highlight_entities_html(text: str, entities: List[Dict[str, Any]]) -> str:
    """Return an HTML fragment of ``text`` with each entity highlighted.

    Each entity is wrapped in a coloured span (colour by label, background
    opacity scaled by score) followed by a small label tag. All text taken
    from the note or from labels is HTML-escaped.
    """
    entities = resolve_overlaps(entities)
    entities = sorted(entities, key=lambda x: x["start"])

    # RGBA base colors (R,G,B); alpha is scaled by score
    palette_rgb = {
        "NAME": (124, 58, 237),
        "LOC": (59, 130, 246),
        "ORG": (16, 185, 129),
        "DATE": (244, 63, 94),
        "ID": (234, 179, 8),
        "CONTACT": (14, 165, 233),
        "UNK": (107, 114, 128),
    }

    def esc(s: str) -> str:
        return html.escape(s, quote=True)

    css = """
<style>
  .deid-box {
    font-family: system-ui, -apple-system, "Segoe UI", sans-serif;
    font-size: 0.95rem; line-height: 1.9; white-space: pre-wrap;
    padding: 12px 14px; border: 1px solid #e5e7eb; border-radius: 10px;
  }
  .deid-ent { padding: 1px 4px; border-radius: 6px; border: 1px solid rgba(0, 0, 0, 0.12); }
  .deid-tag { font-size: 0.7em; font-weight: 600; margin-left: 4px; opacity: 0.75; }
</style>
"""

    out = []
    cursor = 0
    for e in entities:
        s, t = e["start"], e["end"]
        if s < cursor:
            continue
        out.append(esc(text[cursor:s]))

        label = get_group(e)
        r, g, b = palette_rgb.get(label, palette_rgb["UNK"])
        score = float(e.get("score", 0.0))
        # background alpha: 0.10 to 0.32 depending on confidence
        alpha = 0.10 + 0.22 * max(0.0, min(1.0, score))

        span_text = esc(text[s:t])
        title = f"{label} • {score:.2f}"
        out.append(
            f'<span class="deid-ent" title="{esc(title)}" '
            f'style="background: rgba({r},{g},{b},{alpha:.2f});">'
            f'{span_text}'
            f'<span class="deid-tag">{esc(label)}</span>'
            f"</span>"
        )
        cursor = t
    out.append(esc(text[cursor:]))
    return css + '<div class="deid-box">' + "".join(out) + "</div>"


# =========================
# Model loading from HF (NO PIPELINE)
# =========================
@st.cache_resource
def load_hf_model(
    repo_id: str,
    revision: Optional[str],
    hf_token: Optional[str],
    device_str: str,
):
    """Load tokenizer and token-classification model, move to device, set eval mode.

    Cached by Streamlit per argument combination. Returns
    ``(tokenizer, model, device)``.
    """
    device = torch.device(device_str)
    tok = AutoTokenizer.from_pretrained(repo_id, revision=revision, token=hf_token)
    mdl = AutoModelForTokenClassification.from_pretrained(repo_id, revision=revision, token=hf_token)
    mdl.to(device)
    mdl.eval()
    return tok, mdl, device


# =========================
# NER: model-based inference with offsets (BIO -> spans)
# =========================
def _decode_bio_spans(
    text: str,
    labels: List[str],
    offsets: List[List[int]],
    scores: List[float],
) -> List[Dict[str, Any]]:
    """Turn per-token BIO labels into character-level entity spans."""
    entities: List[Dict[str, Any]] = []
    i = 0
    while i < len(labels):
        lab = labels[i]
        s, e = offsets[i]
        # skip special/empty tokens and tokens outside any entity
        if s == e or lab == "O":
            i += 1
            continue

        # if I- without B-, treat as B-
        if lab.startswith("I-"):
            lab = "B-" + lab[2:]
        if not lab.startswith("B-"):
            i += 1
            continue

        typ = lab[2:]
        start, end = s, e
        span_scores = [scores[i]]
        j = i + 1
        while j < len(labels):
            s2, e2 = offsets[j]
            if s2 == e2:
                j += 1
                continue
            if labels[j] != f"I-{typ}":
                break
            end = e2
            span_scores.append(scores[j])
            j += 1

        entities.append({
            "word": text[start:end],
            "entity_group": typ,
            "start": start,
            "end": end,
            "score": float(sum(span_scores) / max(1, len(span_scores))),  # mean token confidence
        })
        i = j
    return entities


def _predict_spans(
    tokenizer,
    model,
    text: str,
    max_len: int,
    device: torch.device,
) -> Tuple[List[Dict[str, Any]], int]:
    """Run the model once on ``text``; return ``(entities, covered_chars)``.

    ``covered_chars`` is how many leading characters of ``text`` were seen
    by the model. When the encoded sequence reached ``max_len`` the input
    is assumed to have been truncated and the end offset of the last real
    token is reported; otherwise the whole text counts as covered.
    """
    enc = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
        padding=False,
    )
    offsets = enc.pop("offset_mapping")[0].tolist()
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.inference_mode():
        logits = model(**enc).logits[0]  # (seq_len, num_labels)
        probs = torch.softmax(logits, dim=-1)
        pred_ids = probs.argmax(dim=-1).tolist()
        pred_scores = probs.max(dim=-1).values.tolist()

    id2label = model.config.id2label

    def id_to_label(i: int) -> str:
        # id2label keys may be ints or their string form.
        if i in id2label:
            return id2label[i]
        return id2label.get(str(i), "O")

    labels = [id_to_label(i) for i in pred_ids]
    entities = _decode_bio_spans(text, labels, offsets, pred_scores)

    token_ends = [e for s, e in offsets if s != e]
    hit_limit = len(offsets) >= max_len
    covered = max(token_ends) if (hit_limit and token_ends) else len(text)
    return entities, covered


def ner_call_model(tokenizer, model, text: str, max_len: int, device: torch.device) -> List[Dict[str, Any]]:
    """Run token classification on ``text`` and return entity spans.

    Each entity is a dict with ``word``, ``entity_group``, ``start``,
    ``end`` (character offsets into ``text``) and ``score`` (mean token
    confidence).

    NOTE: the input is truncated to ``max_len`` tokens, so text beyond
    that point is not scanned. Use ``run_ner_with_windows_model`` for
    notes that may exceed the limit.
    """
    entities, _ = _predict_spans(tokenizer, model, text, max_len=max_len, device=device)
    return entities


def run_ner_with_windows_model(
    tokenizer,
    model,
    device: torch.device,
    text: str,
    pipe_max_len: int,
    window_chars: int = 2000,
    overlap_chars: int = 250,
) -> List[Dict[str, Any]]:
    """Run NER over ``text`` in overlapping character windows.

    Entity offsets are shifted back to positions in the full ``text``.
    Entities found twice in an overlap are returned twice; callers are
    expected to dedup / resolve overlaps.

    The next window starts ``overlap_chars`` before the end of what the
    model actually covered, so a window cut short by the token limit does
    not leave an unscanned gap. The overlap is capped at half of the
    covered span so every iteration makes progress, even when
    ``overlap_chars >= window_chars``.
    """
    ents: List[Dict[str, Any]] = []
    n = len(text)
    window_chars = max(1, int(window_chars))
    overlap_chars = max(0, int(overlap_chars))

    start = 0
    while start < n:
        end = min(n, start + window_chars)
        chunk = text[start:end]
        chunk_ents, covered = _predict_spans(
            tokenizer, model, chunk, max_len=pipe_max_len, device=device
        )
        for found in chunk_ents:
            shifted = dict(found)
            shifted["start"] += start
            shifted["end"] += start
            shifted["word"] = text[shifted["start"]:shifted["end"]]
            ents.append(shifted)

        covered_end = min(end, start + max(1, covered))
        if covered_end >= n:
            break
        effective_overlap = min(overlap_chars, (covered_end - start) // 2)
        start = covered_end - effective_overlap
    return ents


def propagate_entities(text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Add additional spans by exact/normalized string matching for selected entity types.
    Returns a new entity list (original + propagated), resolved/deduped.

    Propagated spans get score 1.0 and ``"source": "propagated"``.
    """
    # Which labels to propagate and how
    PROPAGATE = {"CONTACT", "ID", "NAME"}  # consider adding DATE if needed
    MIN_ID_LEN = 5  # tune: avoid 2-3 digit labs, doses
    MIN_NAME_LEN = 4  # avoid tiny tokens

    # Build patterns from existing entities (dict keys dedup repeated values).
    patterns: Dict[Tuple[str, str, int], None] = {}
    for e in entities:
        label = get_group(e)
        if label not in PROPAGATE:
            continue
        val = e["word"].strip()
        if not val:
            continue

        if label == "CONTACT":
            # Exact string match (case-insensitive for emails)
            patterns[(label, re.escape(val), int(re.IGNORECASE))] = None
        elif label == "ID":
            # Only propagate "ID-like" tokens
            compact = re.sub(r"\D", "", val)
            if len(compact) < MIN_ID_LEN:
                continue
            # Match the same digit sequence allowing a single separator
            # between digits, e.g. 261808 matches "261808" or "261-808".
            # At most one separator keeps the match from spanning unrelated
            # digits across a sentence; the lookarounds keep it from
            # matching inside a longer number.
            digit_pat = r"(?<!\d)" + r"[ \-./]?".join(compact) + r"(?!\d)"
            patterns[(label, digit_pat, 0)] = None
        elif label == "NAME":
            # Prefer multi-token names; for single token be conservative
            # You can tune this: in vet notes, patient single-token names are still PII.
            is_multi = bool(re.search(r"\s", val))
            if (not is_multi) and len(val) < MIN_NAME_LEN:
                continue
            # Exact token/phrase match on word boundaries. Lookarounds are
            # used instead of \b so values that start or end with
            # punctuation can still match.
            pat = r"(?<!\w)" + re.escape(val) + r"(?!\w)"
            patterns[(label, pat, int(re.IGNORECASE))] = None

    # Find additional occurrences
    added = []
    for label, pat, flags in patterns:
        for m in re.finditer(pat, text, flags=flags):
            added.append({
                "word": text[m.start():m.end()],
                "entity_group": label,
                "score": 1.0,  # propagated
                "start": m.start(),
                "end": m.end(),
                "source": "propagated",
            })

    all_ents = list(entities) + added
    all_ents = sorted(all_ents, key=lambda x: x["start"])
    all_ents = dedup_entities_by_span(all_ents)
    all_ents = resolve_overlaps(all_ents)
    return all_ents


def deidentify_note(
    tokenizer,
    model,
    device: torch.device,
    text: str,
    pipe_max_len: int,
    thresh: Dict[str, float],
    global_stoplist: set,
    stop_by_label: Dict[str, set],
    use_windows: bool,
    window_chars: int,
    overlap_chars: int,
) -> List[Dict[str, Any]]:
    """Detect identifiers in one note and return the final entity list.

    Model spans are merged when adjacent, then filtered by per-label
    threshold (``thresh[label]``, falling back to ``thresh["_default"]``,
    then 0.45), placeholder check, stoplists and a minimum length. Regex
    CONTACT hits are always kept, and any model span overlapping one is
    dropped. The result is sorted, deduplicated and overlap-free.
    """
    def pass_thresh(ent):
        g = get_group(ent)
        return float(ent.get("score", 0.0)) >= float(thresh.get(g, thresh.get("_default", 0.45)))

    def stoplisted(ent):
        g = get_group(ent)
        w = ent["word"].strip().lower()
        if w in global_stoplist:
            return True
        return w in stop_by_label.get(g, set())

    # BERT
    if use_windows:
        bert_results = run_ner_with_windows_model(
            tokenizer,
            model,
            device,
            text,
            pipe_max_len=pipe_max_len,
            window_chars=window_chars,
            overlap_chars=overlap_chars,
        )
    else:
        bert_results = ner_call_model(tokenizer, model, text, max_len=pipe_max_len, device=device)

    # Merge adjacent same-label entities
    bert_results = merge_adjacent_entities(bert_results, text)

    # Regex CONTACT
    regex_results = find_structured_pii(text)

    final_entities: List[Dict[str, Any]] = []
    final_entities.extend(regex_results)

    for ent in bert_results:
        word = ent["word"].strip()
        if not pass_thresh(ent):
            continue
        if is_placeholder(word):
            continue
        if stoplisted(ent):
            continue
        if len(word) < 2 and not word.isdigit():
            continue
        # if overlaps regex CONTACT, skip BERT (regex wins)
        if any(ent["start"] < reg["end"] and ent["end"] > reg["start"] for reg in regex_results):
            continue
        final_entities.append(ent)

    final_entities = sorted(final_entities, key=lambda x: x["start"])
    final_entities = dedup_entities_by_span(final_entities)
    final_entities = resolve_overlaps(final_entities)
    return final_entities


# =========================
# Streamlit UI
# =========================
ABOUT_MARKDOWN = """
### About this tool

This interactive demo is part of the **Brundage Lab (brundagelab.org)** research program on **AI methods for veterinary clinical text** and privacy-preserving data sharing for veterinary and One Health applications.

**What it does**
- Performs **veterinary de-identification** on free-text clinical narratives by detecting and redacting identifiers such as **owner/client names**, **addresses/locations**, **dates**, **IDs**, and **contact information**.
- Uses a **fine-tuned transformer NER model** (selectable backbone such as VetBERT / PetBERT / ClinicalBERT) loaded from a **private Hugging Face repository**.
- Augments model predictions with **high-precision pattern matching** for structured identifiers (e.g., emails and phone numbers).

**How to interpret results**
- This tool prioritizes **high recall** for sensitive identifiers (reducing false negatives), with thresholds adjustable in the sidebar.
- The highlighted view is provided for **demonstration and error analysis**; the redacted output is the intended downstream artifact.

**Engineering notes**
- **Model source**: loaded directly from Hugging Face (optionally pinned to a specific revision for reproducibility).
- **CONTACT**: extracted via regex (emails/phones). If the model also predicts CONTACT, regex is treated as the source of truth on overlaps.
- **Long notes**: optional windowing reduces truncation artifacts and improves coverage across multi-page notes.

**Privacy and intended use**
- This is a **research and demonstration tool**, not a certified de-identification system.
- Do **not** paste sensitive/regulated data unless you are running the tool in an approved environment with appropriate controls.
- For any public deployment, ensure **access control**, **minimal logging**, and a **privacy/security review** consistent with your institution’s policies.

**Citation / attribution**
If you use this tool or its outputs in a manuscript, please cite the Brundage Lab and describe the model backbone, training data composition (real vs. synthetic), and evaluation protocol.
"""

st.set_page_config(page_title="SpotRemover", layout="wide")
st.markdown(
    """
<div>
  <span class="brundage-pill">De-ID</span>
  <span class="brundage-pill">NER</span>
  <span class="brundage-pill">Veterinary</span>
  <span class="brundage-pill">One Health</span>
</div>
    """,
    unsafe_allow_html=True,
)
brundage_header()
st.markdown('<hr class="brundage-rule">', unsafe_allow_html=True)
inject_brundage_theme()

with st.sidebar:
    st.header("Model")

    # 1) Define your private fine-tuned model repo IDs (store actual values in env or hardcode)
    # Option A (recommended): keep repo IDs in env so you don't commit them
    MODEL_REGISTRY = {
        "VetBERT (fine-tuned)": os.environ.get("HF_REPO_VETBERT", ""),
        "PetBERT (fine-tuned)": os.environ.get("HF_REPO_PETBERT", ""),
        "ClinicalBERT (fine-tuned)": os.environ.get("HF_REPO_CLINICALBERT", ""),
    }

    # 2) Dropdown selector
    model_label = st.selectbox(
        "Select model",
        options=list(MODEL_REGISTRY.keys()),
        index=0,
    )
    repo_id = MODEL_REGISTRY[model_label]

    # 3) Optional revision (still OK to keep)
    revision = (os.environ.get("HF_REVISION", "").strip() or None)

    # 4) Token comes ONLY from environment
    hf_token = (os.environ.get("HF_TOKEN", "").strip() or None)

    if not repo_id:
        st.error("Model repo_id is not set. Define HF_REPO_VETBERT / HF_REPO_PETBERT / HF_REPO_CLINICALBERT.")
        st.stop()
    if hf_token is None:
        st.error("HF_TOKEN environment variable is not set (required for private models).")
        st.stop()

    st.header("Runtime")
    use_gpu = False
    device_str = "cuda:0" if (use_gpu and torch.cuda.is_available()) else "cpu"
    pipe_max_len = st.selectbox("Max token length", options=[256, 512], index=0)
    use_windows = st.checkbox("Window long notes (recommended)", value=True)
    window_chars = st.slider("Window size (chars)", 500, 6000, 2000, 100)
    overlap_chars = st.slider("Window overlap (chars)", 0, 1000, 250, 25)

    st.header("Thresholds")
    t_name = st.slider("NAME", 0.0, 1.0, 0.60, 0.01)
    t_org = st.slider("ORG", 0.0, 1.0, 0.60, 0.01)
    t_loc = st.slider("LOC", 0.0, 1.0, 0.60, 0.01)
    t_date = st.slider("DATE", 0.0, 1.0, 0.45, 0.01)
    t_id = st.slider("ID", 0.0, 1.0, 0.50, 0.01)
    t_contact = st.slider("CONTACT (model)", 0.0, 1.0, 0.99, 0.01)  # regex-first anyway
    t_default = st.slider("Default", 0.0, 1.0, 0.45, 0.01)

    redact_mode = st.selectbox("Redaction mode", options=["tags", "char"], index=0)
    show_highlight = st.checkbox("Show highlighted original", value=True)
    # Lives in the sidebar (not inside the "Run" branch) so toggling it does
    # not discard results, and so single-note and batch runs share it.
    enable_propagation = st.checkbox("Propagate exact matches (recommended)", value=True)

# Load model/tokenizer
try:
    tokenizer, model, device = load_hf_model(
        repo_id=repo_id, revision=revision, hf_token=hf_token, device_str=device_str
    )
except Exception as e:  # top-level boundary: report any load failure to the user
    st.error(f"Failed to load model/tokenizer from HF.\n\nrepo_id={repo_id}\nrevision={revision}\n\n{e}")
    st.stop()

# Stoplists (can be made editable later)
GLOBAL_STOPLIST = {"er", "ve", "w", "dvm", "mph", "sex", "male", "female", "kg", "lb", "patient", "owner", "left", "right"}
STOP_BY_LABEL = {
    "LOC": {"dsh", "feline", "canine", "equine", "bovine", "species", "breed", "color"},
    "NAME": {"owner", "patient"},
}
THRESH = {
    "NAME": t_name,
    "ORG": t_org,
    "LOC": t_loc,
    "DATE": t_date,
    "ID": t_id,
    "CONTACT": t_contact,
    "_default": t_default,
}


def _run_deid(note: str) -> List[Dict[str, Any]]:
    """De-identify one note with the current sidebar settings."""
    ents = deidentify_note(
        tokenizer=tokenizer,
        model=model,
        device=device,
        text=note,
        pipe_max_len=pipe_max_len,
        thresh=THRESH,
        global_stoplist=GLOBAL_STOPLIST,
        stop_by_label=STOP_BY_LABEL,
        use_windows=use_windows,
        window_chars=window_chars,
        overlap_chars=overlap_chars,
    )
    if enable_propagation:
        ents = propagate_entities(note, ents)
    return ents


tab1, tab2, tab3 = st.tabs(["Single note", "Batch (CSV/TXT)", "About"])

with tab1:
    st.subheader("Single note")
    default_text = "Paste a veterinary note here..."
    text = st.text_area("Input", height=260, value=default_text)

    colA, colB = st.columns([1, 1])
    with colA:
        run_single = st.button("Run", type="primary")

    if run_single:
        with st.spinner("Running de-identification..."):
            final_ents = _run_deid(text)
        # Keep the result in session state: every widget interaction
        # (including a download click) reruns the script, and the button
        # is only True on the run triggered by the click itself.
        st.session_state[_SINGLE_RESULT_KEY] = {"text": text, "entities": final_ents}

    single_result = st.session_state.get(_SINGLE_RESULT_KEY)
    if single_result is not None:
        result_text = single_result["text"]
        final_ents = single_result["entities"]
        redacted = redact_text(result_text, final_ents, mode=redact_mode)

        left, right = st.columns([1, 1])
        with left:
            st.subheader("Entities")
            if final_ents:
                df = pd.DataFrame([{
                    "type": get_group(e),
                    "text": e["word"],
                    "score": float(e.get("score", 0.0)),
                    "start": int(e["start"]),
                    "end": int(e["end"]),
                } for e in final_ents])
                st.dataframe(df, use_container_width=True)
            else:
                st.write("No entities found.")

            st.download_button(
                "Download entities (JSON)",
                data=json.dumps(final_ents, indent=2).encode("utf-8"),
                file_name="entities.json",
                mime="application/json",
            )

        with right:
            st.subheader("Redacted output")
            st.text_area("Output", height=260, value=redacted)
            st.download_button(
                "Download redacted text",
                data=redacted.encode("utf-8"),
                file_name="redacted.txt",
                mime="text/plain",
            )

        if show_highlight:
            st.subheader("Highlighted original")
            components.html(
                highlight_entities_html(result_text, final_ents),
                height=600,
                scrolling=True,
            )

with tab2:
    st.subheader("Batch processing")
    st.write("Upload a CSV (one note per row) or a TXT file (single note).")
    uploaded = st.file_uploader("Upload CSV or TXT", type=["csv", "txt"])

    if uploaded is not None:
        if uploaded.name.lower().endswith(".txt"):
            raw = uploaded.getvalue().decode("utf-8", errors="replace")
            st.write("Detected TXT input (single note). Use the Single note tab for best UX.")
            st.text_area("Preview", value=raw[:5000], height=200)
        else:
            df_in = None
            try:
                uploaded.seek(0)  # the script reruns; always parse from the start
                df_in = pd.read_csv(uploaded)
            except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as e:
                st.error(f"Could not read the uploaded CSV: {e}")

            if df_in is not None:
                st.write(f"Loaded CSV with {len(df_in)} rows and columns: {list(df_in.columns)}")

            if df_in is not None and df_in.empty:
                st.warning("The uploaded CSV has no rows to process.")
            elif df_in is not None:
                n_rows = len(df_in)
                text_col = st.selectbox("Text column", options=list(df_in.columns), index=0)
                if n_rows > 1:
                    max_rows = st.slider(
                        "Max rows to process (demo)", 1, min(5000, n_rows), min(200, n_rows), 1
                    )
                else:
                    # A slider needs min < max; with a single row there is nothing to choose.
                    max_rows = 1

                if st.button("Run batch de-identification", type="primary"):
                    out_rows = []
                    progress = st.progress(0)
                    # Positional access: works regardless of the frame's index labels.
                    notes = df_in[text_col].iloc[:max_rows]
                    for i, value in enumerate(notes):
                        note = str(value) if pd.notna(value) else ""
                        ents = _run_deid(note)
                        redacted = redact_text(note, ents, mode=redact_mode)
                        out_rows.append({
                            "row": i,
                            "redacted": redacted,
                            "entities_json": json.dumps(ents, ensure_ascii=False),
                            "n_entities": len(ents),
                        })
                        if (i + 1) % 5 == 0 or (i + 1) == max_rows:
                            progress.progress((i + 1) / max_rows)

                    st.session_state[_BATCH_RESULT_KEY] = {
                        "source": uploaded.name,
                        "frame": pd.DataFrame(out_rows),
                    }
                    st.success(f"Processed {max_rows} rows.")

                batch_result = st.session_state.get(_BATCH_RESULT_KEY)
                # Only show results that belong to the file currently uploaded.
                if batch_result is not None and batch_result["source"] == uploaded.name:
                    out_df = batch_result["frame"]
                    st.subheader("Batch results (preview)")
                    st.dataframe(out_df.head(50), use_container_width=True)

                    csv_bytes = out_df.to_csv(index=False).encode("utf-8")
                    st.download_button(
                        "Download redacted CSV",
                        data=csv_bytes,
                        file_name="redacted_output.csv",
                        mime="text/csv",
                    )

with tab3:
    st.subheader("About")
    st.markdown(ABOUT_MARKDOWN)

st.caption("Tip: Select your model backbone and explore a single document. Modify default thresholds to finetune your performance.")