Spaces:
Sleeping
Sleeping
File size: 8,919 Bytes
3608aca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
import gradio as gr
from transformers import pipeline
import html
import torch
import numpy as np
import warnings
warnings.filterwarnings("ignore", message=".*LibreSSL.*")
# ------------------------------------------------------------
# Device / ASR model
# ------------------------------------------------------------
# Prefer Apple-Silicon MPS, then CUDA, then CPU fallback.
DEVICE = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
ASR_MODEL_NAME = "openai/whisper-large-v3"  # most accurate Whisper checkpoint
# ------------------------------------------------------------
# App metadata
# ------------------------------------------------------------
TITLE = "Clinical NER with Text + Audio (Bonus)"
DESCRIPTION = """
This demo extracts biomedical entities (diseases, drugs, anatomy, etc.)
from **clinical text** and from **spoken audio**.
**Models**
- NER → `d4data/biomedical-ner-all`
- ASR → `openai/whisper-large-v3` (high accuracy)
**Disclaimer:** Educational / research only — not for clinical use.
"""
FOOTER = "Built with 🤗 Transformers + Gradio | Educational use only"
# ------------------------------------------------------------
# Load pipelines
# ------------------------------------------------------------
# Biomedical NER; aggregation_strategy="simple" merges word-pieces into whole
# entity spans carrying start/end character offsets (used by highlight_text).
# NOTE(review): no `device=` is passed here, so this pipeline presumably runs
# on CPU regardless of DEVICE — confirm that is intended.
ner = pipeline(
    "token-classification",
    model="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
)
# keep dtype float32 on MPS for stability
# NOTE(review): the `dtype=` kwarg requires a recent transformers release
# (older versions spell it `torch_dtype`) — verify against the pinned version.
asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    # transformers accepts "mps" as a device string; otherwise an index:
    # 0 = first CUDA GPU, -1 = CPU.
    device=("mps" if DEVICE == "mps" else 0 if DEVICE == "cuda" else -1),
    dtype=(torch.float32 if DEVICE == "mps" else None),
)
# ------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------
def highlight_text(text, ents):
    """Return HTML with colored, labeled entity spans over *text*.

    Parameters
    ----------
    text : str
        Original input text the span offsets index into.
    ents : list[dict]
        Pipeline predictions with ``start``, ``end``, ``entity_group`` and
        ``score`` keys (as produced by aggregation_strategy="simple").

    Returns
    -------
    str
        An HTML ``<div>`` where each entity is wrapped in a colored
        ``<span>`` with its label shown inline; "" when *text* is empty.
    """
    if not text:
        return ""
    spans = sorted(ents, key=lambda x: x["start"])
    html_out, i = [], 0
    # Normalize raw model labels (upper-cased) to friendlier display groups.
    rename_map = {
        "PROBLEM": "Sign symptom",
        "SYMPTOM": "Sign symptom",
        "DISEASE": "Sign symptom",
        "TEST": "Diagnostic procedure",
        "TREATMENT": "Diagnostic procedure",
        "PROCEDURE": "Diagnostic procedure",
        "ANATOMY": "Biological structure",
        "BODY_PART": "Biological structure",
        "MEDICATION": "Medication",
        "DRUG": "Medication",
        "CHEMICAL": "Medication",
    }
    colors = {
        "Diagnostic procedure": "#ADD8E6",
        "Biological structure": "#D8BFD8",
        "Sign symptom": "#FF7F7F",
        "Medication": "#FFD700",
    }
    for span in spans:
        s, e = span["start"], span["end"]
        if s < i:
            # Overlapping span — keep the earlier one, skip this one.
            continue
        html_out.append(html.escape(text[i:s]))
        raw_label = (span.get("entity_group", "") or "").upper()
        # BUG FIX: the fallback previously used str.title(), yielding e.g.
        # "Sign Symptom" for a raw "Sign_symptom" group — which missed the
        # "Sign symptom" key in `colors`, so real predictions rendered gray.
        # Capitalize only the first word so fallback labels match the color
        # table and the table labels produced by run_ner.
        label = rename_map.get(raw_label, raw_label.replace("_", " ").capitalize())
        color = colors.get(label, "#E0E0E0")
        chunk = html.escape(text[s:e])
        html_out.append(
            f"<span style='background:{color};padding:2px 4px;"
            f"border-radius:4px;margin:1px;' title='{label} | "
            f"score={span.get('score',0):.2f}'>{chunk} "
            f"<small style='opacity:.7'>[{label}]</small></span>"
        )
        i = e
    html_out.append(html.escape(text[i:]))
    return "<div style='line-height:1.9'>" + "".join(html_out) + "</div>"
def run_ner(text):
    """Run biomedical NER over *text*.

    Returns a tuple ``(highlighted_html, rows)`` where *rows* is a list of
    ``[entity, text, start, end, score]`` records for the results table.
    Blank or whitespace-only input yields ``("", [])``.
    """
    if not (text and text.strip()):
        return "", []
    predictions = ner(text)
    # One table row per prediction; underscores in the raw label become
    # spaces for display (e.g. "Sign_symptom" -> "Sign symptom").
    rows = [
        [
            p.get("entity_group", "").replace("_", " "),
            text[p["start"]:p["end"]],
            p["start"],
            p["end"],
            round(float(p.get("score", 0.0)), 4),
        ]
        for p in predictions
    ]
    return highlight_text(text, predictions), rows
def asr_then_ner(audio):
    """Transcribe speech with Whisper-Large-v3, then run NER on the transcript.

    Accepts the ``(sample_rate, samples)`` tuple produced by
    ``gr.Audio(type="numpy")``, an ``{"array", "sampling_rate"}`` dict, or
    anything else the ASR pipeline handles directly (e.g. a file path).
    Only generate kwargs Whisper's ``generate()`` accepts are forwarded
    (no ``condition_on_previous_text`` / ``sampling_rate``).

    Returns ``(transcript, highlighted_html, rows)``.
    """
    if audio is None:
        return "", "", []

    # Normalize the input into the dict form the pipeline expects.
    if isinstance(audio, tuple) and len(audio) == 2:
        rate, samples = audio
        samples = np.asarray(samples)
        if samples.ndim == 2:
            samples = samples.mean(axis=1)  # downmix stereo to mono
        payload = {
            "array": samples.astype("float32", copy=False),
            "sampling_rate": int(rate),
        }
    elif isinstance(audio, dict) and "array" in audio:
        payload = {
            "array": np.asarray(audio["array"]).astype("float32", copy=False),
            "sampling_rate": int(audio.get("sampling_rate", 16000)),
        }
    else:
        payload = audio  # path or other input: let the pipeline handle it

    result = asr(
        payload,
        chunk_length_s=30,           # robust windowing
        return_timestamps=False,
        generate_kwargs={
            "temperature": 0.0,      # deterministic decoding
            "task": "transcribe",    # add "language": "en" to force English
        },
    )

    transcript = result.get("text", "") if isinstance(result, dict) else str(result)
    transcript = transcript.strip()
    if not transcript:
        return "(No speech detected)", "", []

    highlighted, rows = run_ner(transcript)
    return transcript, highlighted, rows
# ------------------------------------------------------------
# Build Gradio interface
# ------------------------------------------------------------
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    # --------------------------------------------------------
    # Clinical NER (Text) — DO NOT TOUCH
    # --------------------------------------------------------
    with gr.Tab("Clinical NER (Text)"):
        ex_text = "The patient started metformin 500 mg BID for type 2 diabetes and reported mild neuropathy."
        inp = gr.Textbox(label="Enter clinical text", placeholder=ex_text, lines=5)
        btn = gr.Button("Extract Entities")
        out_html = gr.HTML(label="Highlighted text")
        out_table = gr.Dataframe(headers=["entity","text","start","end","score"],
                                 label="Extracted entities", interactive=False)
        # run_ner returns (html, rows), matching the two outputs in order.
        btn.click(fn=run_ner, inputs=inp, outputs=[out_html, out_table])
        gr.Examples(
            examples=[
                [ex_text],
                ["CT scan of the abdomen revealed hepatomegaly. Start acetaminophen PRN and check ALT/AST."],
                ["Allergic to penicillin; prescribe azithromycin. Monitor CRP and WBC weekly."],
            ],
            inputs=[inp]
        )
    # --------------------------------------------------------
    # Clinical NER from Audio (Bonus)
    # --------------------------------------------------------
    with gr.Tab("Clinical NER from Audio (Bonus)"):
        gr.Markdown("""
### Record or upload ≤30 s audio
Whisper transcribes your speech, then the NER model extracts key biomedical entities.
**Try saying things like:**
- “The patient started metformin five hundred milligrams twice daily for diabetes.”
- “Order a CT scan of the abdomen and start acetaminophen as needed.”
- “Prescribe azithromycin and monitor white blood cell count weekly.”
- “Ultrasound shows hepatomegaly and mild fatty liver.”
""")
        audio = gr.Audio(sources=["microphone", "upload"],
                         type="numpy", label="🎙️ Record or upload audio (≤30 s)")
        btn2 = gr.Button("Transcribe → Extract Entities")
        transcript = gr.Textbox(label="Transcript", interactive=False)
        out_html2 = gr.HTML(label="Highlighted transcript")
        out_table2 = gr.Dataframe(headers=["entity","text","start","end","score"],
                                  label="Extracted entities", interactive=False)
        # asr_then_ner returns (transcript, html, rows), matching the outputs.
        btn2.click(fn=asr_then_ner, inputs=audio,
                   outputs=[transcript, out_html2, out_table2])
    gr.Markdown(f"---\n{FOOTER}")
# ------------------------------------------------------------
# Launch
# ------------------------------------------------------------
if __name__ == "__main__":
    # NOTE(review): an earlier comment claimed analytics calls were disabled,
    # but no analytics flag is set here — share=True only requests a public
    # Gradio share link.
    demo.launch(share=True)
|