Spaces:
Running
Running
| # app.py | |
| # Streamlit "product-like" Vet De-ID demo (PIPELINE-FREE): | |
| # - Loads model from a Hugging Face repo ID (public or private via HF token) | |
| # - Runs token-classification via tokenizer+model directly (no HF pipeline kwargs issues) | |
| # - Single-note + batch (CSV/TXT) processing | |
| # - Highlighted redaction preview + entity table | |
| # - Downloads: redacted text, JSON entities, redacted CSV | |
| import os | |
| import re | |
| import json | |
| from typing import List, Dict, Any, Optional | |
| import streamlit.components.v1 as components | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification | |
| from dotenv import load_dotenv | |
| from pathlib import Path | |
| load_dotenv() | |
| import base64 | |
# Directory that contains this file (e.g. /app/src), with symlinks resolved so
# bundled assets are located relative to the real source file.
HERE = Path(__file__).resolve().parent
# Bundled static files such as the header logo.
ASSETS_DIR = HERE.joinpath("assets")
# Destination of the clickable header logo; change to your desired destination.
APP_HOME_URL = "https://www.brundagelab.org/research/apps/"
| def _img_to_data_uri(path: str) -> str: | |
| ext = os.path.splitext(path)[1].lower().lstrip(".") | |
| mime = "image/png" if ext == "png" else "image/jpeg" if ext in {"jpg","jpeg"} else "image/svg+xml" | |
| with open(path, "rb") as f: | |
| b64 = base64.b64encode(f.read()).decode("utf-8") | |
| return f"data:{mime};base64,{b64}" | |
def brundage_header():
    """Render the page header: clickable lab logo (left) and app title (right).

    The logo is inlined as a data URI and wrapped in a link to ``APP_HOME_URL``.
    If the logo file is missing or unreadable, a plain-text link is shown
    instead, so a deployment without the asset still renders the app.
    """
    col1, col2 = st.columns([2, 5])
    with col1:
        logo_path = ASSETS_DIR / "brundage_logo.png"
        try:
            logo_uri = _img_to_data_uri(logo_path)
        except OSError:
            logo_uri = None
        if logo_uri:
            inner = (
                f'<img src="{logo_uri}" alt="Brundage Lab" '
                'style="width:256px; height:auto; cursor:pointer;" />'
            )
        else:
            # Text fallback keeps the header clickable without the image.
            inner = '<span style="font-size:20px; font-weight:800; color:#6D28D9;">Brundage Lab</span>'
        st.markdown(
            f"""
<a href="{APP_HOME_URL}" target="_self" style="display:inline-block;">
{inner}
</a>
""",
            unsafe_allow_html=True,
        )
    with col2:
        st.markdown(
            """
<div style="padding-top:24px;">
<div style="font-size:34px; font-weight:850; letter-spacing:-0.02em; color:#111827;">
SpotRemover: Veterinary De-Identification
</div>
</div>
""",
            unsafe_allow_html=True,
        )
# Global theme stylesheet (fonts, hidden Streamlit chrome, headings, sidebar,
# buttons, inputs, tabs, dataframes, helper classes). Injected into the page by
# inject_brundage_theme().
_BRUNDAGE_THEME_CSS = """
<style>
/* ---- App canvas ---- */
html, body, [class*="css"] {
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Apple Color Emoji","Segoe UI Emoji";
}
/* Remove Streamlit default top padding a bit */
.block-container {
padding-top: 2.2rem;
padding-bottom: 2.5rem;
max-width: 1200px;
}
/* Hide Streamlit chrome */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* Headings: closer to your site (bold, clean) */
h1, h2, h3, h4 {
letter-spacing: -0.02em;
color: #111827;
}
h1 { font-weight: 800; }
h2 { font-weight: 800; }
h3 { font-weight: 750; }
/* Sidebar polish */
section[data-testid="stSidebar"] {
background: #FFFFFF;
border-right: 1px solid #EEF2F7;
}
section[data-testid="stSidebar"] .block-container {
padding-top: 1.5rem;
}
/* Buttons: purple primary like your site */
.stButton > button {
border-radius: 12px;
padding: 0.55rem 0.95rem;
border: 1px solid #E5E7EB;
background: #FFFFFF;
color: #111827;
font-weight: 650;
}
.stButton > button:hover {
border-color: #C7B7FF;
background: #FBFAFF;
}
.stButton > button[kind="primary"] {
background: #6D28D9;
color: #FFFFFF;
border: 1px solid #6D28D9;
box-shadow: 0 6px 18px rgba(109, 40, 217, 0.18);
}
.stButton > button[kind="primary"]:hover {
background: #5B21B6;
border-color: #5B21B6;
}
/* Inputs: rounded + subtle border */
div[data-baseweb="input"] > div,
div[data-baseweb="textarea"] > div,
div[data-baseweb="select"] > div {
border-radius: 12px !important;
border-color: #E5E7EB !important;
box-shadow: none !important;
}
div[data-baseweb="textarea"] textarea {
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
font-size: 13px;
}
/* Tabs: reduce Streamlit “blocky” feel */
button[data-baseweb="tab"] {
border-radius: 12px 12px 0 0;
font-weight: 650;
}
/* Dataframe / tables: softer container */
div[data-testid="stDataFrame"] {
border: 1px solid #EEF2F7;
border-radius: 14px;
overflow: hidden;
}
/* “Card” helper class you can use via st.markdown */
.card {
border: 1px solid #EEF2F7;
border-radius: 16px;
padding: 14px 16px;
background: #FFFFFF;
box-shadow: 0 10px 24px rgba(17, 24, 39, 0.06);
}
/* Highlight rendering container (so text stays readable on light bg) */
.note {
white-space: pre-wrap;
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 13px;
line-height: 1.45;
color: #111827;
}
/* Entity pill styling (used in highlight) */
.ent {
border-radius: 10px;
padding: 2px 6px;
border: 1px solid rgba(17, 24, 39, 0.08);
box-shadow: 0 6px 14px rgba(17, 24, 39, 0.06);
}
.ent sup {
font-size: 10px;
margin-left: 6px;
opacity: 0.75;
}
/* Make captions less “default streamlit” */
.stCaption, small {
color: #6B7280;
}
</style>
"""


def inject_brundage_theme():
    """Apply the lab's global CSS theme to the current Streamlit page.

    The stylesheet is emitted as a raw-HTML markdown block, so it affects every
    element rendered on the page, not just later ones.
    """
    st.markdown(_BRUNDAGE_THEME_CSS, unsafe_allow_html=True)
| # ========================= | |
| # Core utilities | |
| # ========================= | |
def get_group(ent: Dict[str, Any]) -> str:
    """Return the label of an entity dict.

    Uses ``entity_group`` (aggregated span label) when it is truthy, otherwise
    ``entity`` (raw token label), and finally the sentinel ``"UNK"``.
    """
    for key in ("entity_group", "entity"):
        label = ent.get(key)
        if label:
            return label
    return "UNK"
def norm_contact(s: str) -> str:
    """Normalize a contact string so equivalent spellings compare equal.

    The input is trimmed and lower-cased. Values containing ``@`` (emails) are
    returned as-is after that; anything else is reduced to its digits only, so
    ``"(555) 123-4567"`` and ``"555.123.4567"`` both become ``"5551234567"``.
    """
    cleaned = s.strip().lower()
    return cleaned if "@" in cleaned else re.sub(r"\D", "", cleaned)
def resolve_overlaps(entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop overlapping spans, keeping the longest (then highest-scoring) one.

    Spans are half-open ``[start, end)`` character ranges. Candidates are
    visited longest-first, ties broken by higher ``score`` and then earlier
    ``start``; a candidate is kept only if it overlaps none of the spans kept
    so far. The survivors are returned sorted by ``start``.

    Ranking by length first matters for de-identification. If candidates were
    ranked by ``start`` first, a short fragment that starts earlier (e.g. "Jo")
    would evict the full overlapping identifier ("Jo Smith-Lee"), and the rest
    of that identifier would stay unredacted.

    Args:
        entities: Entity dicts with at least ``start`` and ``end``; ``score``
            defaults to 0.0 when absent.

    Returns:
        A new list of non-overlapping entity dicts (the same dict objects).
    """
    ranked = sorted(
        entities,
        key=lambda e: (-(e["end"] - e["start"]), -float(e.get("score", 0.0)), e["start"]),
    )
    kept: List[Dict[str, Any]] = []
    for cand in ranked:
        # Keep the candidate only if it lies entirely before or after every kept span.
        if all(cand["end"] <= k["start"] or cand["start"] >= k["end"] for k in kept):
            kept.append(cand)
    return sorted(kept, key=lambda e: e["start"])
def dedup_entities_by_span(ents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Remove entities that repeat an earlier ``(label, start, end)`` triple.

    The first occurrence of each triple is kept, and the input order of the
    kept entities is preserved.
    """
    first_by_key: Dict[tuple, Dict[str, Any]] = {}
    for ent in ents:
        span_key = (get_group(ent), int(ent["start"]), int(ent["end"]))
        # setdefault only stores the first entity seen for this key.
        first_by_key.setdefault(span_key, ent)
    return list(first_by_key.values())
def is_placeholder(word: str) -> bool:
    """Return True for form-template blanks such as ``"____"`` or ``"(__)"``.

    After trimming, a word counts as a placeholder if either:

    * it consists only of underscores, whitespace, hyphens and parentheses; or
    * it has at least two underscores and fewer than two other characters.
    """
    stripped = word.strip()
    only_filler = re.fullmatch(r"[_\s\-\(\)]+", stripped) is not None
    mostly_underscores = (
        stripped.count("_") >= 2
        and len(re.sub(r"[_\s\-\(\)]", "", stripped)) < 2
    )
    return only_filler or mostly_underscores
def merge_adjacent_entities(entities: List[Dict[str, Any]], text: str) -> List[Dict[str, Any]]:
    """
    Merge same-label spans separated only by safe punctuation/whitespace.
    Prevent merges across newlines / field boundaries.

    Two consecutive spans (ordered by ``start``) are merged when they share a
    label, the gap between them is at most 3 characters, contains no newline,
    and consists only of spaces, tabs or ``, . / - ( )``. Overlapping spans
    (negative gap, empty gap text) also qualify. A merged span covers both
    inputs, its ``word`` is re-sliced from ``text``, and its ``score`` is the
    max of the two.

    Args:
        entities: Entity dicts with ``start``/``end``/``word`` and optional
            ``score``. Not modified; merged results are copies.
        text: The string the offsets refer to.

    Returns:
        A new list of entity dicts sorted by ``start``.
    """
    if not entities:
        return []
    ordered = sorted(entities, key=lambda x: x["start"])
    merged = [dict(ordered[0])]
    for nxt in ordered[1:]:
        cur = merged[-1]
        gap_text = text[cur["end"]:nxt["start"]]  # "" when the spans touch or overlap
        gap = nxt["start"] - cur["end"]  # negative when the spans overlap
        if "\n" in gap_text or "\r" in gap_text:
            merged.append(dict(nxt))
            continue
        safe_gap = re.fullmatch(r"[ \t,./\-()]*", gap_text) is not None
        if get_group(cur) == get_group(nxt) and gap <= 3 and safe_gap:
            # Overlapping duplicates (e.g. from overlapping inference windows)
            # may end before ``cur`` does; never shrink the merged span, or the
            # tail of the identifier would be left unredacted.
            new_end = max(cur["end"], nxt["end"])
            cur["end"] = new_end
            cur["word"] = text[cur["start"]:new_end]
            cur["score"] = max(float(cur.get("score", 0.0)), float(nxt.get("score", 0.0)))
        else:
            merged.append(dict(nxt))
    return merged
# Email address: local@domain.tld. The TLD class is [A-Za-z]; writing it as
# [A-Z|a-z] would also accept a literal "|".
_EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")

# US-style phone number with an optional +1 / 1 country prefix:
# (555) 123-4567, 555.123.4567, 5551234567, +1 555-123-4567, ...
# The digit look-arounds stop it from matching 10 digits inside a longer number
# (e.g. a 15-digit microchip ID). Such a partial CONTACT hit would win the
# overlap against the model's full ID span in deidentify_note and leave the
# remaining digits unredacted.
_PHONE_RE = re.compile(
    r"(?<![\d+])(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}(?!\d)"
)


def find_structured_pii(text: str) -> List[Dict[str, Any]]:
    """Find emails and phone numbers in ``text`` with regular expressions.

    Returns:
        One ``CONTACT`` entity dict per match (emails first, then phones), each
        with ``word``, ``entity_group``, ``score`` (always 1.0), ``start`` and
        ``end`` character offsets.
    """
    hits: List[Dict[str, Any]] = []
    for pattern in (_EMAIL_RE, _PHONE_RE):
        for m in pattern.finditer(text):
            hits.append({
                "word": m.group(),
                "entity_group": "CONTACT",
                "score": 1.0,
                "start": m.start(),
                "end": m.end(),
            })
    return hits
def redact_text(text: str, entities: List[Dict[str, Any]], mode: str = "tags") -> str:
    """Return ``text`` with every entity span replaced.

    mode="tags": [NAME], [LOC], etc.
    mode="char": ***** preserving length (any mode other than "tags" behaves
    this way; an empty span still gets one ``*``).

    Overlaps are resolved first. Spans are then replaced right-to-left so the
    offsets of spans not yet processed stay valid.
    """
    result = text
    for ent in sorted(resolve_overlaps(entities), key=lambda x: x["start"], reverse=True):
        start, end = ent["start"], ent["end"]
        if mode == "tags":
            replacement = f"[{get_group(ent)}]"
        else:
            replacement = "*" * max(1, end - start)
        result = "".join((result[:start], replacement, result[end:]))
    return result
def highlight_entities_html(text: str, entities: List[Dict[str, Any]]) -> str:
    """Build a standalone HTML snippet showing ``text`` with entities highlighted.

    Overlaps are resolved first. Each entity becomes a ``<span class="ent">``:

    * the background is tinted with the label colour, with alpha 0.10-0.32
      scaled by score;
    * an underline and a hover "pill" show the label;
    * the tooltip reads ``"<LABEL> • <score>"``.

    The result starts with its own ``<style>`` block, so it can be rendered in
    an isolated iframe such as ``components.html``.

    All note text and labels are HTML-escaped. The note is user input and the
    snippet is rendered as raw HTML, so unescaped text could inject markup or
    script.
    """
    import html

    entities = sorted(resolve_overlaps(entities), key=lambda x: x["start"])
    # RGBA base colors (R,G,B); alpha is scaled by score
    palette_rgb = {
        "NAME": (124, 58, 237),
        "LOC": (59, 130, 246),
        "ORG": (16, 185, 129),
        "DATE": (244, 63, 94),
        "ID": (234, 179, 8),
        "CONTACT": (14, 165, 233),
        "UNK": (107, 114, 128),
    }

    def esc(s: str) -> str:
        # Escapes &, <, >, " and ' (the latter two matter inside attributes).
        return html.escape(s, quote=True)

    css = """
<style>
.note {
white-space: pre-wrap;
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 13px;
line-height: 1.45;
/* Light card styling */
color: #111827;
background: #FFFFFF;
border: 1px solid #EEF2F7;
padding: 12px 14px;
border-radius: 14px;
box-shadow: 0 10px 24px rgba(17, 24, 39, 0.06);
}
.ent {
position: relative;
border-radius: 8px;
padding: 1px 4px;
margin: 0px 1px;
box-decoration-break: clone;
-webkit-box-decoration-break: clone;
transition: filter 120ms ease;
border: 1px solid rgba(17, 24, 39, 0.08);
}
.ent:hover { filter: brightness(1.03); }
.ent::after {
content: "";
position: absolute;
left: 6px; right: 6px; bottom: -2px;
height: 2px;
border-radius: 2px;
background: rgba(var(--rgb), 0.85);
}
.pill {
display: none;
position: absolute;
top: -18px;
left: 0px;
font-size: 10px;
line-height: 1;
padding: 3px 8px;
border-radius: 999px;
background: rgba(var(--rgb), 0.95);
color: #111827;
box-shadow: 0 6px 16px rgba(17, 24, 39, 0.12);
white-space: nowrap;
z-index: 5;
}
.ent:hover .pill { display: inline-block; }
</style>
"""
    out = []
    cursor = 0
    for e in entities:
        s, t = e["start"], e["end"]
        if s < cursor:  # defensive: overlaps were already resolved above
            continue
        out.append(esc(text[cursor:s]))
        label = get_group(e)
        r, g, b = palette_rgb.get(label, palette_rgb["UNK"])
        score = float(e.get("score", 0.0))
        # background alpha: 0.10 to 0.32 depending on confidence
        alpha = 0.10 + 0.22 * max(0.0, min(1.0, score))
        span_text = esc(text[s:t])
        title = f"{label} • {score:.2f}"
        out.append(
            f'<span class="ent" title="{esc(title)}" style="--rgb:{r},{g},{b}; background: rgba({r},{g},{b},{alpha});">'
            f"{span_text}"
            f'<span class="pill">{esc(label)}</span>'
            f"</span>"
        )
        cursor = t
    out.append(esc(text[cursor:]))
    return css + "<div class='note'>" + "".join(out) + "</div>"
| # ========================= | |
| # Model loading from HF (NO PIPELINE) | |
| # ========================= | |
def load_hf_model(
    repo_id: str,
    revision: Optional[str],
    hf_token: Optional[str],
    device_str: str,
):
    """Fetch a tokenizer and token-classification model from the Hugging Face Hub.

    Args:
        repo_id: Hub repository ID of the fine-tuned model.
        revision: Optional revision passed through to ``from_pretrained``
            (``None`` leaves the choice to the library).
        hf_token: Access token for private repos, or ``None``.
        device_str: A ``torch.device`` string such as ``"cpu"`` or ``"cuda:0"``.

    Returns:
        ``(tokenizer, model, device)``, with the model moved to ``device`` and
        switched to eval mode.
    """
    hub_options = {"revision": revision, "token": hf_token}
    target = torch.device(device_str)
    tokenizer = AutoTokenizer.from_pretrained(repo_id, **hub_options)
    net = AutoModelForTokenClassification.from_pretrained(repo_id, **hub_options)
    net.to(target)
    net.eval()  # inference behaviour (e.g. no dropout)
    return tokenizer, net, target
| # ========================= | |
| # NER: model-based inference with offsets (BIO -> spans) | |
| # ========================= | |
def _decode_bio_spans(
    text: str,
    labels: List[str],
    offsets: List[List[int]],
    scores: List[float],
) -> List[Dict[str, Any]]:
    """Group one chunk's per-token BIO labels into character-level entity spans.

    Handling of individual tokens:

    * Tokens with an empty offset (``start == end``: special or padding tokens)
      are ignored.
    * A stray ``I-X`` without a preceding ``B-X`` opens a new span.
    * Labels that are neither ``O`` nor ``B-``/``I-`` prefixed are skipped.

    Each span's ``score`` is the mean confidence of its tokens.
    """
    spans: List[Dict[str, Any]] = []
    n = len(labels)
    i = 0
    while i < n:
        lab = labels[i]
        s, e = offsets[i]
        # skip special/empty and outside tokens
        if s == e or lab == "O":
            i += 1
            continue
        # if I- without B-, treat as B-
        if lab.startswith("I-"):
            lab = "B-" + lab[2:]
        if not lab.startswith("B-"):
            i += 1
            continue
        typ = lab[2:]
        start, end = s, e
        span_scores = [scores[i]]
        j = i + 1
        while j < n:
            s2, e2 = offsets[j]
            if s2 == e2:
                j += 1
                continue
            if labels[j] != f"I-{typ}":
                break
            end = e2
            span_scores.append(scores[j])
            j += 1
        spans.append({
            "word": text[start:end],
            "entity_group": typ,
            "start": start,
            "end": end,
            "score": float(sum(span_scores) / len(span_scores)),  # mean token confidence
        })
        i = j
    return spans


def ner_call_model(
    tokenizer,
    model,
    text: str,
    max_len: int,
    device: torch.device,
    stride: int = 32,
) -> List[Dict[str, Any]]:
    """Run token classification over all of ``text`` and decode BIO tags into spans.

    The tokenizer splits ``text`` into as many ``max_len``-token chunks as
    needed, with consecutive chunks sharing ``stride`` tokens, so no part of
    the note is silently dropped. Plain ``truncation=True`` discarded
    everything after the first ``max_len`` tokens; a 2000-character window
    (the sidebar default) can easily exceed 256 tokens. Entities in the
    overlap of two chunks may be reported twice; the callers' merge/dedup and
    overlap resolution remove such duplicates.

    Args:
        tokenizer: A Hugging Face *fast* tokenizer (``return_offsets_mapping``
            requires one).
        model: Token-classification model whose ``config.id2label`` holds BIO labels.
        text: Input string.
        max_len: Maximum tokens per chunk, including special tokens.
        device: Device the model lives on.
        stride: Tokens shared between consecutive chunks; clamped to at most a
            quarter of ``max_len`` so each chunk is mostly new text.

    Returns:
        Entity dicts with ``word``, ``entity_group``, ``start``/``end``
        (character offsets into ``text``) and ``score``.
    """
    if not text:
        return []
    stride = max(0, min(int(stride), max_len // 4))
    enc = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_len,
        stride=stride,
        return_overflowing_tokens=True,
        # Chunks differ in length (the last is shorter); pad to stack them.
        padding=True,
        return_tensors="pt",
    )
    # Not model inputs: per-token char offsets and the chunk -> sample index.
    offsets_per_chunk = enc.pop("offset_mapping").tolist()
    enc.pop("overflow_to_sample_mapping", None)
    inputs = {k: v.to(device) for k, v in enc.items()}

    with torch.inference_mode():
        logits = model(**inputs).logits  # (num_chunks, seq_len, num_labels)
    probs = torch.softmax(logits, dim=-1)
    ids_per_chunk = probs.argmax(dim=-1).tolist()
    scores_per_chunk = probs.max(dim=-1).values.tolist()

    id2label = model.config.id2label

    def id_to_label(i: int) -> str:
        # id2label keys may be ints or strings depending on how the config was saved.
        if i in id2label:
            return id2label[i]
        return id2label.get(str(i), "O")

    entities: List[Dict[str, Any]] = []
    for offsets, ids, scores in zip(offsets_per_chunk, ids_per_chunk, scores_per_chunk):
        labels = [id_to_label(i) for i in ids]
        entities.extend(_decode_bio_spans(text, labels, offsets, scores))
    return entities
def run_ner_with_windows_model(
    tokenizer,
    model,
    device: torch.device,
    text: str,
    pipe_max_len: int,
    window_chars: int = 2000,
    overlap_chars: int = 250,
) -> List[Dict[str, Any]]:
    """Run NER over overlapping character windows of ``text``.

    Each window of ``window_chars`` characters is passed to ``ner_call_model``.
    Consecutive windows share ``overlap_chars`` characters. Entity offsets are
    shifted back to positions in ``text`` and ``word`` is re-sliced from it.
    Entities inside an overlap can appear twice; callers deduplicate.

    Args:
        window_chars: Window length in characters (values < 1 are treated as 1).
        overlap_chars: Characters shared by consecutive windows; clamped to
            ``[0, window_chars - 1]``.

    Returns:
        Entity dicts with offsets relative to ``text``.
    """
    window_chars = max(1, int(window_chars))
    # Each window must start strictly after the previous one. An overlap >= the
    # window size (reachable from the sidebar, e.g. window 500 / overlap 1000)
    # would otherwise restart at the same position forever.
    overlap_chars = min(max(0, int(overlap_chars)), window_chars - 1)

    ents: List[Dict[str, Any]] = []
    start = 0
    n = len(text)
    while start < n:
        end = min(n, start + window_chars)
        chunk_ents = ner_call_model(tokenizer, model, text[start:end], max_len=pipe_max_len, device=device)
        for ent in chunk_ents:
            shifted = dict(ent)
            shifted["start"] += start
            shifted["end"] += start
            shifted["word"] = text[shifted["start"]:shifted["end"]]
            ents.append(shifted)
        if end == n:
            break
        start = end - overlap_chars
    return ents
def propagate_entities(text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Add additional spans by exact/normalized string matching for selected entity types.
    Returns a new entity list (original + propagated), resolved/deduped.

    For every detected CONTACT, ID or NAME entity, other occurrences of the same
    value in ``text`` are added with ``score=1.0`` and ``source="propagated"``:

    * CONTACT: exact, case-insensitive match.
    * ID: the same digit sequence (at least ``MIN_ID_LEN`` digits), allowing a
      single ``-``, ``.``, ``/`` or whitespace separator between digits, and
      not embedded in a longer number.
    * NAME: case-insensitive phrase match not adjoining other word characters.
      Single-token names shorter than ``MIN_NAME_LEN`` are not propagated.
    """
    PROPAGATE = {"CONTACT", "ID", "NAME"}  # consider adding DATE if needed
    MIN_ID_LEN = 5  # tune: avoid 2-3 digit labs, doses
    MIN_NAME_LEN = 4  # avoid tiny tokens

    # (label, regex, flags) -> None. A dict drops repeated patterns (the same
    # name detected many times) while keeping first-seen order.
    patterns: Dict[tuple, None] = {}
    for e in entities:
        label = get_group(e)
        if label not in PROPAGATE:
            continue
        val = e["word"].strip()
        if not val:
            continue
        if label == "CONTACT":
            # Exact string match (case-insensitive for emails)
            patterns[(label, re.escape(val), re.IGNORECASE)] = None
        elif label == "ID":
            # Only propagate "ID-like" tokens
            compact = re.sub(r"\D", "", val)
            if len(compact) < MIN_ID_LEN:
                continue
            # Same digit sequence, e.g. 261808 matches "261808", "261-808" or
            # "261 808". The separator must stay bounded: an unbounded \D*
            # between digits lets the pattern pick digits out of unrelated text
            # (across words and lines) and create huge spans.
            digit_pat = r"(?<!\d)" + r"[-\s./]?".join(compact) + r"(?!\d)"
            patterns[(label, digit_pat, 0)] = None
        elif label == "NAME":
            # Prefer multi-token names; for single token be conservative.
            # In vet notes, single-token patient names are still PII.
            is_multi = bool(re.search(r"\s", val))
            if (not is_multi) and len(val) < MIN_NAME_LEN:
                continue
            # (?<!\w)/(?!\w) instead of \b, so names that start or end with
            # punctuation (e.g. "Jr.") still match themselves.
            pat = r"(?<!\w)" + re.escape(val) + r"(?!\w)"
            patterns[(label, pat, re.IGNORECASE)] = None

    # Find additional occurrences
    added = [
        {
            "word": text[m.start():m.end()],
            "entity_group": label,
            "score": 1.0,  # propagated
            "start": m.start(),
            "end": m.end(),
            "source": "propagated",
        }
        for label, pat, flags in patterns
        for m in re.finditer(pat, text, flags=flags)
    ]

    all_ents = sorted(list(entities) + added, key=lambda x: x["start"])
    all_ents = dedup_entities_by_span(all_ents)
    return resolve_overlaps(all_ents)
def deidentify_note(
    tokenizer,
    model,
    device: torch.device,
    text: str,
    pipe_max_len: int,
    thresh: Dict[str, float],
    global_stoplist: set,
    stop_by_label: Dict[str, set],
    use_windows: bool,
    window_chars: int,
    overlap_chars: int,
) -> List[Dict[str, Any]]:
    """Detect identifiers in one note and return the final non-overlapping entities.

    Pipeline:

    1. Model NER, either windowed (``run_ner_with_windows_model``) or a single
       ``ner_call_model`` call. Adjacent same-label spans are then merged.
    2. Regex CONTACT hits (emails / phones) via ``find_structured_pii``; these
       are always kept.
    3. Model spans are dropped when any of these hold:
       * the score is below the per-label threshold (``thresh[label]``,
         else ``thresh["_default"]``, else 0.45);
       * the span is a form placeholder;
       * the span is in the global or per-label stoplist (case-insensitive);
       * it is a single non-digit character;
       * it overlaps a regex hit (regex wins).
    4. Results are sorted by start, deduplicated by span, and overlaps resolved.
    """
    fallback_cut = thresh.get("_default", 0.45)

    if use_windows:
        model_hits = run_ner_with_windows_model(
            tokenizer, model, device, text,
            pipe_max_len=pipe_max_len,
            window_chars=window_chars,
            overlap_chars=overlap_chars,
        )
    else:
        model_hits = ner_call_model(tokenizer, model, text, max_len=pipe_max_len, device=device)
    model_hits = merge_adjacent_entities(model_hits, text)

    regex_hits = find_structured_pii(text)

    def accept(ent: Dict[str, Any]) -> bool:
        """Return True if a merged model span survives every filter."""
        label = get_group(ent)
        word = ent["word"].strip()
        # Written as `not >=` so a NaN score is rejected, exactly like a failed
        # ">= threshold" test.
        if not float(ent.get("score", 0.0)) >= float(thresh.get(label, fallback_cut)):
            return False
        if is_placeholder(word):
            return False
        lowered = word.lower()
        if lowered in global_stoplist or lowered in stop_by_label.get(label, set()):
            return False
        if len(word) < 2 and not word.isdigit():
            return False
        return not any(
            ent["start"] < hit["end"] and ent["end"] > hit["start"] for hit in regex_hits
        )

    survivors = [ent for ent in model_hits if accept(ent)]
    combined = sorted(regex_hits + survivors, key=lambda x: x["start"])
    return resolve_overlaps(dedup_entities_by_span(combined))
| # ========================= | |
| # Streamlit UI | |
| # ========================= | |
st.set_page_config(page_title="SpotRemover", layout="wide")
# Topic "pills" shown above the header.
st.markdown(
    """
<div style="display:flex; gap:8px; flex-wrap:wrap; margin: 8px 0 16px 0;">
<span style="border:1px solid #E5E7EB; padding:4px 10px; border-radius:999px; font-weight:600; font-size:13px;">De-ID</span>
<span style="border:1px solid #E5E7EB; padding:4px 10px; border-radius:999px; font-weight:600; font-size:13px;">NER</span>
<span style="border:1px solid #E5E7EB; padding:4px 10px; border-radius:999px; font-weight:600; font-size:13px;">Veterinary</span>
<span style="border:1px solid #E5E7EB; padding:4px 10px; border-radius:999px; font-weight:600; font-size:13px;">One Health</span>
</div>
""",
    unsafe_allow_html=True,
)
brundage_header()
st.markdown("<div style='height:14px;'></div>", unsafe_allow_html=True)
inject_brundage_theme()

# Streamlit re-executes this whole script on every widget interaction,
# including download-button clicks and sidebar changes. The latest results are
# kept in st.session_state under these keys. Results that only existed inside
# an `if st.button(...)` branch would vanish on the next rerun.
SINGLE_RESULT_KEY = "single_note_result"
BATCH_RESULT_KEY = "batch_result"

with st.sidebar:
    st.header("Model")
    # Fine-tuned model repo IDs come from the environment, so private repo
    # names are never committed. An unset variable yields "" and stops below.
    MODEL_REGISTRY = {
        "VetBERT (fine-tuned)": os.environ.get("HF_REPO_VETBERT", ""),
        "PetBERT (fine-tuned)": os.environ.get("HF_REPO_PETBERT", ""),
        "ClinicalBERT (fine-tuned)": os.environ.get("HF_REPO_CLINICALBERT", ""),
    }
    model_label = st.selectbox(
        "Select model",
        options=list(MODEL_REGISTRY.keys()),
        index=0,
    )
    repo_id = MODEL_REGISTRY[model_label]
    # Optional Hub revision to pin for reproducibility (None = not pinned).
    revision = (os.environ.get("HF_REVISION", "").strip() or None)
    # The access token is read ONLY from the environment, never from the UI.
    hf_token = (os.environ.get("HF_TOKEN", "").strip() or None)
    if not repo_id:
        st.error("Model repo_id is not set. Define HF_REPO_VETBERT / HF_REPO_PETBERT / HF_REPO_CLINICALBERT.")
        st.stop()
    if hf_token is None:
        # Public repos load without a token. A private repo fails in the load
        # step below with the Hub's error, so warn here instead of stopping.
        st.warning("HF_TOKEN environment variable is not set (required for private models).")

    st.header("Runtime")
    use_gpu = False  # CPU-only by default; set True to use CUDA when available.
    device_str = "cuda:0" if (use_gpu and torch.cuda.is_available()) else "cpu"
    pipe_max_len = st.selectbox("Max token length", options=[256, 512], index=0)
    use_windows = st.checkbox("Window long notes (recommended)", value=True)
    window_chars = st.slider("Window size (chars)", 500, 6000, 2000, 100)
    overlap_chars = st.slider("Window overlap (chars)", 0, 1000, 250, 25)

    st.header("Thresholds")
    t_name = st.slider("NAME", 0.0, 1.0, 0.60, 0.01)
    t_org = st.slider("ORG", 0.0, 1.0, 0.60, 0.01)
    t_loc = st.slider("LOC", 0.0, 1.0, 0.60, 0.01)
    t_date = st.slider("DATE", 0.0, 1.0, 0.45, 0.01)
    t_id = st.slider("ID", 0.0, 1.0, 0.50, 0.01)
    t_contact = st.slider("CONTACT (model)", 0.0, 1.0, 0.99, 0.01)  # regex-first anyway
    t_default = st.slider("Default", 0.0, 1.0, 0.45, 0.01)

    redact_mode = st.selectbox("Redaction mode", options=["tags", "char"], index=0)
    show_highlight = st.checkbox("Show highlighted original", value=True)
    # Lives here rather than inside the Run branch. Toggling it then re-renders
    # the stored results instead of hiding them, and it also applies to batch runs.
    enable_propagation = st.checkbox("Propagate exact matches (recommended)", value=True)

# Cache the loaded tokenizer/model across reruns and sessions. Without this,
# every widget interaction re-instantiated the model via from_pretrained.
_load_hf_model_cached = st.cache_resource(show_spinner="Loading model...")(load_hf_model)
try:
    tokenizer, model, device = _load_hf_model_cached(
        repo_id=repo_id, revision=revision, hf_token=hf_token, device_str=device_str
    )
except Exception as e:
    st.error(f"Failed to load model/tokenizer from HF.\n\nrepo_id={repo_id}\nrevision={revision}\n\n{e}")
    st.stop()

# Stoplists (can be made editable later)
GLOBAL_STOPLIST = {"er", "ve", "w", "dvm", "mph", "sex", "male", "female", "kg", "lb", "patient", "owner", "left", "right"}
STOP_BY_LABEL = {
    "LOC": {"dsh", "feline", "canine", "equine", "bovine", "species", "breed", "color"},
    "NAME": {"owner", "patient"},
}
THRESH = {
    "NAME": t_name,
    "ORG": t_org,
    "LOC": t_loc,
    "DATE": t_date,
    "ID": t_id,
    "CONTACT": t_contact,
    "_default": t_default,
}


def _deidentify(note: str) -> List[Dict[str, Any]]:
    """Run deidentify_note on ``note`` with the current sidebar settings.

    Propagation is applied when the sidebar checkbox is enabled.
    """
    ents = deidentify_note(
        tokenizer=tokenizer,
        model=model,
        device=device,
        text=note,
        pipe_max_len=pipe_max_len,
        thresh=THRESH,
        global_stoplist=GLOBAL_STOPLIST,
        stop_by_label=STOP_BY_LABEL,
        use_windows=use_windows,
        window_chars=window_chars,
        overlap_chars=overlap_chars,
    )
    if enable_propagation:
        ents = propagate_entities(note, ents)
    return ents


def _render_batch_csv(uploaded) -> None:
    """Read an uploaded CSV, de-identify up to N rows, then show and offer the results."""
    try:
        df_in = pd.read_csv(uploaded)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.error(f"Could not read the CSV file: {exc}")
        return
    n_rows = len(df_in)
    st.write(f"Loaded CSV with {n_rows} rows and columns: {list(df_in.columns)}")
    if n_rows == 0:
        st.warning("The CSV contains no rows to process.")
        return

    text_col = st.selectbox("Text column", options=list(df_in.columns), index=0)
    row_cap = min(5000, n_rows)
    if row_cap > 1:
        max_rows = st.slider("Max rows to process (demo)", 1, row_cap, min(200, row_cap), 1)
    else:
        max_rows = 1  # st.slider needs min_value < max_value

    if st.button("Run batch de-identification", type="primary"):
        out_rows = []
        progress = st.progress(0)
        notes = df_in[text_col]
        for i in range(max_rows):
            cell = notes.iloc[i]  # positional, independent of the index labels
            note = str(cell) if pd.notna(cell) else ""
            ents = _deidentify(note)
            out_rows.append({
                "row": i,
                "redacted": redact_text(note, ents, mode=redact_mode),
                "entities_json": json.dumps(ents, ensure_ascii=False),
                "n_entities": len(ents),
            })
            if (i + 1) % 5 == 0 or (i + 1) == max_rows:
                progress.progress((i + 1) / max_rows)
        st.session_state[BATCH_RESULT_KEY] = {
            "source": uploaded.name,
            "rows": max_rows,
            "df": pd.DataFrame(out_rows),
        }

    batch = st.session_state.get(BATCH_RESULT_KEY)
    # Only show results that belong to the file currently uploaded.
    if batch is None or batch["source"] != uploaded.name:
        return
    out_df = batch["df"]
    st.success(f"Processed {batch['rows']} rows.")
    st.subheader("Batch results (preview)")
    st.dataframe(out_df.head(50), use_container_width=True)
    st.download_button(
        "Download redacted CSV",
        data=out_df.to_csv(index=False).encode("utf-8"),
        file_name="redacted_output.csv",
        mime="text/csv",
    )


tab1, tab2, tab3 = st.tabs(["Single note", "Batch (CSV/TXT)", "About"])

with tab1:
    st.subheader("Single note")
    text = st.text_area("Input", height=260, placeholder="Paste a veterinary note here...")
    col_run, _ = st.columns([1, 1])
    with col_run:
        run_single = st.button("Run", type="primary")
    if run_single:
        if not text.strip():
            st.warning("Paste a veterinary note into the Input box first.")
        else:
            with st.spinner("Running de-identification..."):
                # Store the raw model output; propagation and redaction are cheap
                # and are re-applied below with the current sidebar settings.
                st.session_state[SINGLE_RESULT_KEY] = {
                    "text": text,
                    "entities": deidentify_note(
                        tokenizer=tokenizer,
                        model=model,
                        device=device,
                        text=text,
                        pipe_max_len=pipe_max_len,
                        thresh=THRESH,
                        global_stoplist=GLOBAL_STOPLIST,
                        stop_by_label=STOP_BY_LABEL,
                        use_windows=use_windows,
                        window_chars=window_chars,
                        overlap_chars=overlap_chars,
                    ),
                }

    single = st.session_state.get(SINGLE_RESULT_KEY)
    if single is not None:
        st.caption("Showing results from the last Run. Click Run again after changing the note, model or thresholds.")
        note_text = single["text"]
        final_ents = single["entities"]
        if enable_propagation:
            final_ents = propagate_entities(note_text, final_ents)
        redacted = redact_text(note_text, final_ents, mode=redact_mode)

        left, right = st.columns([1, 1])
        with left:
            st.subheader("Entities")
            if final_ents:
                df = pd.DataFrame([{
                    "type": get_group(e),
                    "text": e["word"],
                    "score": float(e.get("score", 0.0)),
                    "start": int(e["start"]),
                    "end": int(e["end"]),
                } for e in final_ents])
                st.dataframe(df, use_container_width=True)
            else:
                st.write("No entities found.")
            st.download_button(
                "Download entities (JSON)",
                data=json.dumps(final_ents, indent=2, ensure_ascii=False).encode("utf-8"),
                file_name="entities.json",
                mime="application/json",
            )
        with right:
            st.subheader("Redacted output")
            st.text_area("Output", height=260, value=redacted)
            st.download_button(
                "Download redacted text",
                data=redacted.encode("utf-8"),
                file_name="redacted.txt",
                mime="text/plain",
            )
        if show_highlight:
            st.subheader("Highlighted original")
            components.html(
                highlight_entities_html(note_text, final_ents),
                height=600,
                scrolling=True,
            )

with tab2:
    st.subheader("Batch processing")
    st.write("Upload a CSV (one note per row) or a TXT file (single note).")
    uploaded = st.file_uploader("Upload CSV or TXT", type=["csv", "txt"])
    if uploaded is not None:
        if uploaded.name.lower().endswith(".txt"):
            raw = uploaded.getvalue().decode("utf-8", errors="replace")
            st.write("Detected TXT input (single note). Use the Single note tab for best UX.")
            st.text_area("Preview", value=raw[:5000], height=200)
        else:
            _render_batch_csv(uploaded)

with tab3:
    st.subheader("About")
    # Blank lines separate Markdown blocks; without them the bold section
    # titles are folded into the preceding paragraph or bullet.
    st.markdown(
        """
### About this tool

This interactive demo is part of the **Brundage Lab (brundagelab.org)** research program on **AI methods for veterinary clinical text** and privacy-preserving data sharing for veterinary and One Health applications.

**What it does**

- Performs **veterinary de-identification** on free-text clinical narratives by detecting and redacting identifiers such as **owner/client names**, **addresses/locations**, **dates**, **IDs**, and **contact information**.
- Uses a **fine-tuned transformer NER model** (selectable backbone such as VetBERT / PetBERT / ClinicalBERT) loaded from a **private Hugging Face repository**.
- Augments model predictions with **high-precision pattern matching** for structured identifiers (e.g., emails and phone numbers).

**How to interpret results**

- This tool prioritizes **high recall** for sensitive identifiers (reducing false negatives), with thresholds adjustable in the sidebar.
- The highlighted view is provided for **demonstration and error analysis**; the redacted output is the intended downstream artifact.

**Engineering notes**

- **Model source**: loaded directly from Hugging Face (optionally pinned to a specific revision for reproducibility).
- **CONTACT**: extracted via regex (emails/phones). If the model also predicts CONTACT, regex is treated as the source of truth on overlaps.
- **Long notes**: optional windowing reduces truncation artifacts and improves coverage across multi-page notes.

**Privacy and intended use**

- This is a **research and demonstration tool**, not a certified de-identification system.
- Do **not** paste sensitive/regulated data unless you are running the tool in an approved environment with appropriate controls.
- For any public deployment, ensure **access control**, **minimal logging**, and a **privacy/security review** consistent with your institution’s policies.

**Citation / attribution**

If you use this tool or its outputs in a manuscript, please cite the Brundage Lab and describe the model backbone, training data composition (real vs. synthetic), and evaluation protocol.
"""
    )

st.caption("Tip: Select your model backbone and explore a single document. Modify default thresholds to finetune your performance.")