# Hugging Face Space (page status at capture time: Sleeping)
import gradio as gr
from transformers import pipeline
import html
import torch
import numpy as np
import warnings

# Silence the noisy urllib3/LibreSSL compatibility warning on macOS.
warnings.filterwarnings("ignore", message=".*LibreSSL.*")

# ------------------------------------------------------------
# Device / ASR model
# ------------------------------------------------------------
# Prefer Apple-Silicon MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

ASR_MODEL_NAME = "openai/whisper-large-v3"  # most accurate
# ------------------------------------------------------------
# App metadata
# ------------------------------------------------------------
# Strings rendered verbatim by the Gradio UI (title, markdown intro, footer).
TITLE = "Clinical NER with Text + Audio (Bonus)"

DESCRIPTION = """
This demo extracts biomedical entities (diseases, drugs, anatomy, etc.)
from **clinical text** and from **spoken audio**.
**Models**
- NER → `d4data/biomedical-ner-all`
- ASR → `openai/whisper-large-v3` (high accuracy)
**Disclaimer:** Educational / research only — not for clinical use.
"""

FOOTER = "Built with 🤗 Transformers + Gradio | Educational use only"
# ------------------------------------------------------------
# Load pipelines
# ------------------------------------------------------------
# Biomedical NER; "simple" aggregation merges word-piece tokens into
# whole-entity spans with start/end character offsets.
ner = pipeline(
    "token-classification",
    model="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
)

# Map our device string onto the pipeline's `device` argument:
# "mps" passes through, CUDA selects GPU index 0, everything else is CPU (-1).
if DEVICE == "mps":
    _asr_device = "mps"
elif DEVICE == "cuda":
    _asr_device = 0
else:
    _asr_device = -1

# keep dtype float32 on MPS for stability
asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    device=_asr_device,
    dtype=torch.float32 if DEVICE == "mps" else None,
)
| # ------------------------------------------------------------ | |
| # Helper functions | |
| # ------------------------------------------------------------ | |
def highlight_text(text, ents):
    """Return an HTML string with colored entity spans.

    Args:
        text: The original input text the offsets refer to.
        ents: Entity dicts from the HF token-classification pipeline
            (``aggregation_strategy="simple"``); each must carry ``start``,
            ``end``, ``entity_group`` and optionally ``score``.

    Returns:
        HTML markup in which each entity is wrapped in a colored ``<span>``
        (with a hover tooltip showing label and score); all other text is
        HTML-escaped and passed through unchanged.
    """
    if not text:
        return ""
    spans = sorted(ents, key=lambda x: x["start"])
    html_out, i = [], 0
    # Normalize labels to friendlier display names (spaces, not underscores).
    rename_map = {
        "PROBLEM": "Sign symptom",
        "SYMPTOM": "Sign symptom",
        "DISEASE": "Sign symptom",
        "TEST": "Diagnostic procedure",
        "TREATMENT": "Diagnostic procedure",
        "PROCEDURE": "Diagnostic procedure",
        "ANATOMY": "Biological structure",
        "BODY_PART": "Biological structure",
        "MEDICATION": "Medication",
        "DRUG": "Medication",
        "CHEMICAL": "Medication",
    }
    colors = {
        "Diagnostic procedure": "#ADD8E6",
        "Biological structure": "#D8BFD8",
        "Sign symptom": "#FF7F7F",
        "Medication": "#FFD700",
    }
    for span in spans:
        s, e = span["start"], span["end"]
        if s < i:
            # Skip spans overlapping an entity we already emitted.
            continue
        html_out.append(html.escape(text[i:s]))
        raw_label = (span.get("entity_group", "") or "").upper()
        # BUGFIX: the old fallback used .title(), producing e.g.
        # "Sign Symptom" / "Diagnostic Procedure", which never matched the
        # "Sign symptom" / "Diagnostic procedure" keys in `colors` — so raw
        # model labels always fell back to gray. Capitalizing only the first
        # word lines the display label up with the color table.
        label = rename_map.get(raw_label, raw_label.replace("_", " ").capitalize())
        color = colors.get(label, "#E0E0E0")  # gray for unrecognized groups
        chunk = html.escape(text[s:e])
        html_out.append(
            f"<span style='background:{color};padding:2px 4px;"
            f"border-radius:4px;margin:1px;' title='{label} | "
            f"score={span.get('score',0):.2f}'>{chunk} "
            f"<small style='opacity:.7'>[{label}]</small></span>"
        )
        i = e
    html_out.append(html.escape(text[i:]))
    return "<div style='line-height:1.9'>" + "".join(html_out) + "</div>"
def run_ner(text):
    """Run biomedical NER over *text*.

    Returns:
        A pair ``(html, rows)``: the highlighted-HTML rendering of the text
        plus one table row ``[entity, text, start, end, score]`` per entity.
        Blank / whitespace-only input yields ``("", [])``.
    """
    if not text or not text.strip():
        return "", []
    preds = ner(text)
    rows = [
        [
            p.get("entity_group", "").replace("_", " "),  # underscores -> spaces
            text[p["start"]:p["end"]],
            p["start"],
            p["end"],
            round(float(p.get("score", 0.0)), 4),
        ]
        for p in preds
    ]
    return highlight_text(text, preds), rows
def asr_then_ner(audio):
    """Transcribe speech with Whisper, then run NER on the transcript.

    Args:
        audio: Output of ``gr.Audio(type="numpy")`` — normally a
            ``(sample_rate, data)`` tuple — or a ``{"array", "sampling_rate"}``
            dict, or any path/bytes input the ASR pipeline can load itself.

    Returns:
        ``(transcript, html, rows)``; all-empty values for ``None`` input and
        a placeholder transcript when no speech is detected.
    """
    if audio is None:
        return "", "", []

    def _to_float32(data):
        # BUGFIX: gr.Audio(type="numpy") usually yields int16 PCM; Whisper
        # expects float32 in [-1, 1]. The old bare astype() kept values in
        # the thousands, badly degrading transcription — scale integer
        # samples down by the dtype's max.
        data = np.asarray(data)
        if np.issubdtype(data.dtype, np.integer):
            return data.astype("float32") / np.iinfo(data.dtype).max
        return data.astype("float32", copy=False)

    # Normalize input from gr.Audio(type="numpy"): it provides a tuple (sr, data)
    if isinstance(audio, tuple) and len(audio) == 2:
        sr, data = audio
        data = _to_float32(data)
        if data.ndim == 2:  # stereo -> mono
            data = data.mean(axis=1)
        asr_input = {"array": data, "sampling_rate": int(sr)}
    elif isinstance(audio, dict) and "array" in audio:
        # Already in dict form; default to 16 kHz if the rate is missing.
        asr_input = {
            "array": _to_float32(audio["array"]),
            "sampling_rate": int(audio.get("sampling_rate", 16000)),
        }
    else:
        # Let the pipeline handle paths or other inputs directly.
        asr_input = audio

    # Call the pipeline with stable settings; only forward generate kwargs
    # Whisper's generate() actually accepts (no condition_on_previous_text etc.).
    result = asr(
        asr_input,
        chunk_length_s=30,          # robust windowing
        return_timestamps=False,
        generate_kwargs=dict(
            temperature=0.0,        # deterministic decoding
            task="transcribe",      # language auto-detect; add "language": "en" to force English
        ),
    )
    transcript = (result.get("text", "") if isinstance(result, dict) else str(result)).strip()
    if not transcript:
        return "(No speech detected)", "", []
    html_out, rows = run_ner(transcript)
    return transcript, html_out, rows
# ------------------------------------------------------------
# Build Gradio interface
# ------------------------------------------------------------
# Two tabs: text-in NER, and audio-in (Whisper transcription then NER).
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    # --------------------------------------------------------
    # Clinical NER (Text) — DO NOT TOUCH
    # --------------------------------------------------------
    with gr.Tab("Clinical NER (Text)"):
        ex_text = "The patient started metformin 500 mg BID for type 2 diabetes and reported mild neuropathy."
        inp = gr.Textbox(label="Enter clinical text", placeholder=ex_text, lines=5)
        btn = gr.Button("Extract Entities")
        out_html = gr.HTML(label="Highlighted text")
        out_table = gr.Dataframe(headers=["entity","text","start","end","score"],
                                 label="Extracted entities", interactive=False)
        # Wire the button through run_ner -> (highlighted HTML, table rows).
        btn.click(fn=run_ner, inputs=inp, outputs=[out_html, out_table])
        gr.Examples(
            examples=[
                [ex_text],
                ["CT scan of the abdomen revealed hepatomegaly. Start acetaminophen PRN and check ALT/AST."],
                ["Allergic to penicillin; prescribe azithromycin. Monitor CRP and WBC weekly."],
            ],
            inputs=[inp]
        )

    # --------------------------------------------------------
    # Clinical NER from Audio (Bonus)
    # --------------------------------------------------------
    with gr.Tab("Clinical NER from Audio (Bonus)"):
        gr.Markdown("""
### Record or upload ≤30 s audio
Whisper transcribes your speech, then the NER model extracts key biomedical entities.
**Try saying things like:**
- “The patient started metformin five hundred milligrams twice daily for diabetes.”
- “Order a CT scan of the abdomen and start acetaminophen as needed.”
- “Prescribe azithromycin and monitor white blood cell count weekly.”
- “Ultrasound shows hepatomegaly and mild fatty liver.”
""")
        # type="numpy" makes Gradio hand asr_then_ner a (sample_rate, data) tuple.
        audio = gr.Audio(sources=["microphone", "upload"],
                         type="numpy", label="🎙️ Record or upload audio (≤30 s)")
        btn2 = gr.Button("Transcribe → Extract Entities")
        transcript = gr.Textbox(label="Transcript", interactive=False)
        out_html2 = gr.HTML(label="Highlighted transcript")
        out_table2 = gr.Dataframe(headers=["entity","text","start","end","score"],
                                  label="Extracted entities", interactive=False)
        # Audio path: asr_then_ner -> (transcript, highlighted HTML, table rows).
        btn2.click(fn=asr_then_ner, inputs=audio,
                   outputs=[transcript, out_html2, out_table2])

    gr.Markdown(f"---\n{FOOTER}")
# ------------------------------------------------------------
# Launch
# ------------------------------------------------------------
if __name__ == "__main__":
    # share=True additionally exposes a temporary public Gradio link.
    # NOTE(review): the old comment claimed analytics was disabled here, but
    # launch() is not passed analytics_enabled=False — confirm intent.
    demo.launch(share=True)