import html
import warnings

# Install this filter BEFORE importing gradio/transformers: urllib3's LibreSSL
# warning (NotOpenSSLWarning) is raised when urllib3 is first imported, which
# happens transitively during the imports below. A filter registered after the
# imports would never see it.
warnings.filterwarnings("ignore", message=".*LibreSSL.*")

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from transformers import pipeline  # noqa: E402

# ------------------------------------------------------------
# Device / models
# ------------------------------------------------------------
DEVICE = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
ASR_MODEL_NAME = "openai/whisper-large-v3"  # most accurate
NER_MODEL_NAME = "d4data/biomedical-ner-all"

# ------------------------------------------------------------
# App metadata
# ------------------------------------------------------------
TITLE = "Clinical NER with Text + Audio (Bonus)"
DESCRIPTION = f"""
This demo extracts biomedical entities (diseases, drugs, anatomy, etc.) from **clinical text** and from **spoken audio**.

**Models**
- NER → `{NER_MODEL_NAME}`
- ASR → `{ASR_MODEL_NAME}` (high accuracy)

**Disclaimer:** Educational / research only — not for clinical use.
"""
FOOTER = "Built with 🤗 Transformers + Gradio | Educational use only"

# ------------------------------------------------------------
# Entity label display tables
# ------------------------------------------------------------
# Upper-cased alternative label names folded into the four colour groups below.
_LABEL_RENAMES = {
    "PROBLEM": "Sign symptom",
    "SYMPTOM": "Sign symptom",
    "DISEASE": "Sign symptom",
    "TEST": "Diagnostic procedure",
    "TREATMENT": "Diagnostic procedure",
    "PROCEDURE": "Diagnostic procedure",
    "ANATOMY": "Biological structure",
    "BODY_PART": "Biological structure",
    "MEDICATION": "Medication",
    "DRUG": "Medication",
    "CHEMICAL": "Medication",
}
# Background colours keyed by display label; anything else uses _DEFAULT_COLOR.
_LABEL_COLORS = {
    "Diagnostic procedure": "#ADD8E6",
    "Biological structure": "#D8BFD8",
    "Sign symptom": "#FF7F7F",
    "Medication": "#FFD700",
}
_DEFAULT_COLOR = "#E0E0E0"

# ------------------------------------------------------------
# Load pipelines
# ------------------------------------------------------------
ner = pipeline(
    "token-classification",
    model=NER_MODEL_NAME,
    aggregation_strategy="simple",
)

# keep dtype float32 on MPS for stability
asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    device=("mps" if DEVICE == "mps" else 0 if DEVICE == "cuda" else -1),
    dtype=(torch.float32 if DEVICE == "mps" else None),
)


# ------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------
def _display_label(raw_label):
    """Return the friendly display name for an NER ``entity_group`` value.

    Aliases listed in ``_LABEL_RENAMES`` are mapped directly. Anything else is
    rendered in sentence case with underscores replaced by spaces
    (``"Sign_symptom"`` -> ``"Sign symptom"``), so it matches the
    ``_LABEL_COLORS`` keys. ``None`` or ``""`` yields ``""``.
    """
    key = (raw_label or "").upper()
    if key in _LABEL_RENAMES:
        return _LABEL_RENAMES[key]
    # capitalize(), not title(): title() would give "Sign Symptom", which
    # never matches the colour table.
    return key.capitalize().replace("_", " ")


def highlight_text(text, ents):
    """Render *text* as HTML with each entity span colour-highlighted.

    Args:
        text: The original string the entities were extracted from.
        ents: NER pipeline outputs; each dict must have ``"start"`` and
            ``"end"`` character offsets and may have ``"entity_group"``.

    Returns:
        An HTML string (``""`` for empty *text*). All text and labels are
        HTML-escaped. A span that starts before the previous span ended is
        skipped.
    """
    if not text:
        return ""

    pieces, cursor = [], 0
    for span in sorted(ents, key=lambda x: x["start"]):
        s, e = span["start"], span["end"]
        if s < cursor:  # overlaps an already-rendered span
            continue
        pieces.append(html.escape(text[cursor:s]))
        label = _display_label(span.get("entity_group"))
        color = _LABEL_COLORS.get(label, _DEFAULT_COLOR)
        pieces.append(
            f"<span style='background-color:{color};padding:2px 4px;"
            f"border-radius:4px;'>{html.escape(text[s:e])} "
            f"<b>[{html.escape(label)}]</b></span>"
        )
        cursor = e
    pieces.append(html.escape(text[cursor:]))
    return "<div style='line-height:2;'>" + "".join(pieces) + "</div>"


def run_ner(text):
    """Run biomedical NER on *text*.

    Returns:
        ``(html, rows)``. *html* is the highlighted text from
        :func:`highlight_text`. *rows* is a list of
        ``[entity, text, start, end, score]`` for the results table, with
        *score* rounded to 4 places. Blank input yields ``("", [])``.
    """
    if not text or not text.strip():
        return "", []

    preds = ner(text)
    rows = [
        [
            # Same label normalisation as the highlighted view.
            _display_label(p.get("entity_group")),
            text[p["start"]:p["end"]],
            p["start"],
            p["end"],
            round(float(p.get("score", 0.0)), 4),
        ]
        for p in preds
    ]
    return highlight_text(text, preds), rows


def _to_float32_mono(data):
    """Convert a sample array to 1-D float32 in roughly [-1, 1].

    Integer PCM is rescaled by its dtype's range, and unsigned formats such as
    8-bit WAV are re-centred first. Whisper's feature extractor does not
    rescale raw integer amplitudes. A 2-D array is treated as
    (samples, channels) and averaged to mono. Float input is only cast.
    """
    arr = np.asarray(data)
    if np.issubdtype(arr.dtype, np.integer):
        info = np.iinfo(arr.dtype)
        offset = (int(info.max) + int(info.min) + 1) / 2  # 0 for signed types
        scale = (int(info.max) - int(info.min) + 1) / 2   # 32768 for int16
        arr = (arr.astype(np.float64) - offset) / scale
    if arr.ndim == 2:  # stereo -> mono
        arr = arr.mean(axis=1)
    return arr.astype(np.float32, copy=False)


def asr_then_ner(audio):
    """Transcribe speech with Whisper, then run clinical NER on the transcript.

    Args:
        audio: The ``(sample_rate, samples)`` tuple from
            ``gr.Audio(type="numpy")``, an ``{"array", "sampling_rate"}``
            dict, or any other input the ASR pipeline accepts directly
            (e.g. a file path). ``None`` when nothing was provided.

    Returns:
        ``(transcript, highlighted_html, rows)``. Returns ``("", "", [])`` for
        ``None`` input. *transcript* is ``"(No speech detected)"`` when the
        audio is empty or Whisper returns no text.
    """
    if audio is None:
        return "", "", []

    if isinstance(audio, tuple) and len(audio) == 2:
        sr, data = audio
        samples, rate = _to_float32_mono(data), int(sr)
    elif isinstance(audio, dict) and "array" in audio:
        samples = _to_float32_mono(audio["array"])
        rate = int(audio.get("sampling_rate", 16000))
    else:
        samples = rate = None

    if samples is None:
        asr_input = audio  # let the pipeline handle paths or other inputs
    elif samples.size == 0:
        return "(No speech detected)", "", []
    else:
        # The ASR pipeline accepts this dict form and resamples if `rate`
        # differs from the model's expected sampling rate.
        asr_input = {"array": samples, "sampling_rate": rate}

    # Only decode args that Whisper's generate() accepts are forwarded.
    result = asr(
        asr_input,
        chunk_length_s=30,  # windowed long-form decoding
        return_timestamps=False,
        generate_kwargs={
            "temperature": 0.0,   # deterministic (greedy) decoding
            "task": "transcribe",  # transcribe in the spoken (auto-detected) language
            # add "language": "en" to force English output
        },
    )
    transcript = (
        result.get("text", "") if isinstance(result, dict) else str(result)
    ).strip()
    if not transcript:
        return "(No speech detected)", "", []

    html_out, rows = run_ner(transcript)
    return transcript, html_out, rows


# ------------------------------------------------------------
# Build Gradio interface
# ------------------------------------------------------------
# analytics_enabled=False: skip Gradio's telemetry calls (previously
# observed to throw httpx timeouts; not required for the app).
with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    # --------------------------------------------------------
    # Clinical NER (Text) — DO NOT TOUCH
    # --------------------------------------------------------
    with gr.Tab("Clinical NER (Text)"):
        ex_text = "The patient started metformin 500 mg BID for type 2 diabetes and reported mild neuropathy."
        inp = gr.Textbox(label="Enter clinical text", placeholder=ex_text, lines=5)
        btn = gr.Button("Extract Entities")
        out_html = gr.HTML(label="Highlighted text")
        out_table = gr.Dataframe(headers=["entity", "text", "start", "end", "score"],
                                 label="Extracted entities", interactive=False)
        btn.click(fn=run_ner, inputs=inp, outputs=[out_html, out_table])
        gr.Examples(
            examples=[
                [ex_text],
                ["CT scan of the abdomen revealed hepatomegaly. Start acetaminophen PRN and check ALT/AST."],
                ["Allergic to penicillin; prescribe azithromycin. Monitor CRP and WBC weekly."],
            ],
            inputs=[inp]
        )

    # --------------------------------------------------------
    # Clinical NER from Audio (Bonus)
    # --------------------------------------------------------
    with gr.Tab("Clinical NER from Audio (Bonus)"):
        gr.Markdown("""
### Record or upload ≤30 s audio
Whisper transcribes your speech, then the NER model extracts key biomedical entities.

**Try saying things like:**
- “The patient started metformin five hundred milligrams twice daily for diabetes.”
- “Order a CT scan of the abdomen and start acetaminophen as needed.”
- “Prescribe azithromycin and monitor white blood cell count weekly.”
- “Ultrasound shows hepatomegaly and mild fatty liver.”
""")
        audio = gr.Audio(sources=["microphone", "upload"], type="numpy",
                         label="🎙️ Record or upload audio (≤30 s)")
        btn2 = gr.Button("Transcribe → Extract Entities")
        transcript = gr.Textbox(label="Transcript", interactive=False)
        out_html2 = gr.HTML(label="Highlighted transcript")
        out_table2 = gr.Dataframe(headers=["entity", "text", "start", "end", "score"],
                                  label="Extracted entities", interactive=False)
        btn2.click(fn=asr_then_ner, inputs=audio,
                   outputs=[transcript, out_html2, out_table2])

    gr.Markdown(f"---\n{FOOTER}")

# ------------------------------------------------------------
# Launch
# ------------------------------------------------------------
if __name__ == "__main__":
    # share=True also creates a temporary public gradio.live URL.
    demo.launch(share=True)