# MedScribe-AI — app.py
# (Hugging Face Spaces viewer header removed; original upload: walid0795, commit 3608aca)
import gradio as gr
from transformers import pipeline
import html
import torch
import numpy as np
import warnings

# Silence the urllib3/LibreSSL compatibility warning commonly seen on
# macOS system Python; purely cosmetic, does not affect behavior.
warnings.filterwarnings("ignore", message=".*LibreSSL.*")
# ------------------------------------------------------------
# Device / ASR model
# ------------------------------------------------------------
# Prefer Apple-Silicon Metal (MPS), then CUDA, then plain CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Whisper checkpoint used by the ASR pipeline below.
ASR_MODEL_NAME = "openai/whisper-large-v3"  # most accurate
# ------------------------------------------------------------
# App metadata (rendered as Markdown by the Gradio UI below)
# ------------------------------------------------------------
TITLE = "Clinical NER with Text + Audio (Bonus)"
DESCRIPTION = """
This demo extracts biomedical entities (diseases, drugs, anatomy, etc.)
from **clinical text** and from **spoken audio**.
**Models**
- NER → `d4data/biomedical-ner-all`
- ASR → `openai/whisper-large-v3` (high accuracy)
**Disclaimer:** Educational / research only — not for clinical use.
"""
FOOTER = "Built with 🤗 Transformers + Gradio | Educational use only"
# ------------------------------------------------------------
# Load pipelines (heavyweight: downloads model weights on first run)
# ------------------------------------------------------------
# Biomedical NER; "simple" aggregation merges word-piece tokens into
# whole-entity spans carrying start/end character offsets and a score.
ner = pipeline(
    "token-classification",
    model="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
)
# keep dtype float32 on MPS for stability
asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    # transformers device convention: "mps" string, CUDA ordinal 0, or -1 for CPU.
    device=("mps" if DEVICE == "mps" else 0 if DEVICE == "cuda" else -1),
    # NOTE(review): older transformers releases spell this `torch_dtype`,
    # not `dtype` — confirm against the installed transformers version.
    dtype=(torch.float32 if DEVICE == "mps" else None),
)
# ------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------
def highlight_text(text, ents):
    """Return an HTML string with colored, labeled entity spans.

    Args:
        text: The original input text.
        ents: Iterable of entity dicts as produced by a
            token-classification pipeline with aggregation — each has
            "start", "end", "entity_group" and "score" keys.

    Returns:
        An HTML ``<div>`` in which each entity is wrapped in a colored
        ``<span>``. Overlapping spans keep the earliest-starting one.
        Empty input text yields "".
    """
    if not text:
        return ""
    # Normalize model label variants to friendlier display groups
    # (spaces instead of underscores).
    rename_map = {
        "PROBLEM": "Sign symptom",
        "SYMPTOM": "Sign symptom",
        "DISEASE": "Sign symptom",
        "TEST": "Diagnostic procedure",
        "TREATMENT": "Diagnostic procedure",
        "PROCEDURE": "Diagnostic procedure",
        "ANATOMY": "Biological structure",
        "BODY_PART": "Biological structure",
        "MEDICATION": "Medication",
        "DRUG": "Medication",
        "CHEMICAL": "Medication",
    }
    colors = {
        "Diagnostic procedure": "#ADD8E6",
        "Biological structure": "#D8BFD8",
        "Sign symptom": "#FF7F7F",
        "Medication": "#FFD700",
    }
    spans = sorted(ents, key=lambda x: x["start"])
    html_out, cursor = [], 0
    for span in spans:
        s, e = span["start"], span["end"]
        if s < cursor:
            # Overlaps a span already rendered — skip it.
            continue
        html_out.append(html.escape(text[cursor:s]))
        raw_label = (span.get("entity_group", "") or "").upper()
        # BUG FIX: the old fallback used .title(), which produced e.g.
        # "Sign Symptom" for the model's "Sign_symptom" label — never
        # matching the "Sign symptom" color keys, so every real entity
        # fell back to gray. capitalize() after replacing underscores
        # matches the color table exactly.
        label = rename_map.get(raw_label, raw_label.replace("_", " ").capitalize())
        color = colors.get(label, "#E0E0E0")
        chunk = html.escape(text[s:e])
        safe_label = html.escape(label, quote=True)  # safe inside the title attribute
        html_out.append(
            f"<span style='background:{color};padding:2px 4px;"
            f"border-radius:4px;margin:1px;' title='{safe_label} | "
            f"score={span.get('score',0):.2f}'>{chunk} "
            f"<small style='opacity:.7'>[{safe_label}]</small></span>"
        )
        cursor = e
    html_out.append(html.escape(text[cursor:]))
    return "<div style='line-height:1.9'>" + "".join(html_out) + "</div>"
def run_ner(text):
    """Run biomedical NER over *text*.

    Returns:
        A ``(highlighted_html, rows)`` pair where ``rows`` is a list of
        ``[entity, text, start, end, score]`` lists for the results
        table. Empty or whitespace-only input yields ``("", [])``.
    """
    if not text or not text.strip():
        return "", []
    preds = ner(text)
    rows = [
        [
            # Display label with spaces instead of underscores.
            p.get("entity_group", "").replace("_", " "),
            text[p["start"]:p["end"]],
            p["start"],
            p["end"],
            round(float(p.get("score", 0.0)), 4),
        ]
        for p in preds
    ]
    return highlight_text(text, preds), rows
def _to_float32(data):
    """Convert PCM samples to float32 scaled into [-1, 1].

    BUG FIX: ``gr.Audio(type="numpy")`` typically yields int16 PCM; a
    bare ``astype("float32")`` left sample values in ±32767, far
    outside the [-1, 1] range Whisper's feature extractor expects,
    degrading transcription. Integer dtypes are now scaled by their
    dtype maximum; float data passes through unchanged.
    """
    if np.issubdtype(data.dtype, np.integer):
        return data.astype(np.float32) / float(np.iinfo(data.dtype).max)
    return data.astype(np.float32, copy=False)


def _to_asr_input(audio):
    """Normalize Gradio audio input into what the ASR pipeline expects.

    Returns a ``{"array": float32 mono, "sampling_rate": int}`` dict for
    tuple/dict inputs; anything else (e.g. a file path) is returned
    unchanged for the pipeline to load itself.
    """
    if isinstance(audio, tuple) and len(audio) == 2:
        # gr.Audio(type="numpy") provides (sample_rate, data).
        sr, data = audio
        data = _to_float32(np.asarray(data))  # scale BEFORE stereo mixdown
        if data.ndim == 2:  # stereo -> mono
            data = data.mean(axis=1)
        return {"array": data, "sampling_rate": int(sr)}
    if isinstance(audio, dict) and "array" in audio:
        # Already in dict form; still normalize dtype/scale.
        return {
            "array": _to_float32(np.asarray(audio["array"])),
            "sampling_rate": int(audio.get("sampling_rate", 16000)),
        }
    return audio


def asr_then_ner(audio):
    """Transcribe speech with Whisper-Large-v3, then run NER on it.

    Args:
        audio: Output of ``gr.Audio(type="numpy")`` — an
            ``(sample_rate, data)`` tuple — or a dict with
            "array"/"sampling_rate", or a path-like input.

    Returns:
        ``(transcript, highlighted_html, rows)``; empty values when
        there is no audio, and a placeholder transcript when no speech
        is detected.
    """
    if audio is None:
        return "", "", []
    asr_input = _to_asr_input(audio)
    # Only pass generate kwargs Whisper's generate() accepts; keys like
    # condition_on_previous_text raise errors downstream.
    result = asr(
        asr_input,
        chunk_length_s=30,  # robust windowing for longer clips
        return_timestamps=False,
        generate_kwargs=dict(
            temperature=0.0,    # deterministic decoding
            task="transcribe",  # add "language": "en" to force English
        ),
    )
    transcript = (result.get("text", "") if isinstance(result, dict) else str(result)).strip()
    if not transcript:
        return "(No speech detected)", "", []
    html_out, rows = run_ner(transcript)
    return transcript, html_out, rows
# ------------------------------------------------------------
# Build Gradio interface
# ------------------------------------------------------------
# NOTE: component-creation order inside the context managers defines
# the on-screen layout, so statement order matters throughout.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    # --------------------------------------------------------
    # Clinical NER (Text) — DO NOT TOUCH
    # --------------------------------------------------------
    with gr.Tab("Clinical NER (Text)"):
        ex_text = "The patient started metformin 500 mg BID for type 2 diabetes and reported mild neuropathy."
        inp = gr.Textbox(label="Enter clinical text", placeholder=ex_text, lines=5)
        btn = gr.Button("Extract Entities")
        out_html = gr.HTML(label="Highlighted text")
        out_table = gr.Dataframe(headers=["entity","text","start","end","score"],
                                 label="Extracted entities", interactive=False)
        # Wire the button to the text → NER handler.
        btn.click(fn=run_ner, inputs=inp, outputs=[out_html, out_table])
        gr.Examples(
            examples=[
                [ex_text],
                ["CT scan of the abdomen revealed hepatomegaly. Start acetaminophen PRN and check ALT/AST."],
                ["Allergic to penicillin; prescribe azithromycin. Monitor CRP and WBC weekly."],
            ],
            inputs=[inp]
        )
    # --------------------------------------------------------
    # Clinical NER from Audio (Bonus)
    # --------------------------------------------------------
    with gr.Tab("Clinical NER from Audio (Bonus)"):
        gr.Markdown("""
### Record or upload ≤30 s audio
Whisper transcribes your speech, then the NER model extracts key biomedical entities.
**Try saying things like:**
- “The patient started metformin five hundred milligrams twice daily for diabetes.”
- “Order a CT scan of the abdomen and start acetaminophen as needed.”
- “Prescribe azithromycin and monitor white blood cell count weekly.”
- “Ultrasound shows hepatomegaly and mild fatty liver.”
""")
        audio = gr.Audio(sources=["microphone", "upload"],
                         type="numpy", label="🎙️ Record or upload audio (≤30 s)")
        btn2 = gr.Button("Transcribe → Extract Entities")
        transcript = gr.Textbox(label="Transcript", interactive=False)
        out_html2 = gr.HTML(label="Highlighted transcript")
        out_table2 = gr.Dataframe(headers=["entity","text","start","end","score"],
                                  label="Extracted entities", interactive=False)
        # Button triggers the ASR → NER chain.
        btn2.click(fn=asr_then_ner, inputs=audio,
                   outputs=[transcript, out_html2, out_table2])
    gr.Markdown(f"---\n{FOOTER}")
# ------------------------------------------------------------
# Launch
# ------------------------------------------------------------
if __name__ == "__main__":
    # NOTE(review): the previous comment claimed analytics were disabled
    # here, but no argument does so — pass analytics_enabled=False to
    # gr.Blocks(...) if that is actually desired. share=True opens a
    # temporary public tunnel URL in addition to the local server.
    demo.launch(share=True)