File size: 8,919 Bytes
3608aca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import gradio as gr
from transformers import pipeline
import html
import torch
import numpy as np
import warnings
# Hide the urllib3-vs-LibreSSL compatibility warning (cosmetic; commonly
# emitted on macOS system Python builds — presumably why it is filtered here).
warnings.filterwarnings("ignore", message=".*LibreSSL.*")


# ------------------------------------------------------------
# Device / ASR model
# ------------------------------------------------------------
# Device preference order: Apple-Silicon MPS, then CUDA, then CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)
ASR_MODEL_NAME = "openai/whisper-large-v3"  # most accurate

# ------------------------------------------------------------
# App metadata
# ------------------------------------------------------------
TITLE = "Clinical NER with Text + Audio (Bonus)"
DESCRIPTION = """
This demo extracts biomedical entities (diseases, drugs, anatomy, etc.)
from **clinical text** and from **spoken audio**.

**Models**
- NER → `d4data/biomedical-ner-all`
- ASR → `openai/whisper-large-v3` (high accuracy)

**Disclaimer:** Educational / research only — not for clinical use.
"""
FOOTER = "Built with 🤗 Transformers + Gradio | Educational use only"

# ------------------------------------------------------------
# Load pipelines
# ------------------------------------------------------------
# Token-classification pipeline; "simple" aggregation merges word-piece
# tokens into whole-entity spans with start/end character offsets.
# NOTE(review): no `device=` is passed here, so the NER model runs on CPU
# even when DEVICE is "mps"/"cuda" — confirm that is intentional.
ner = pipeline(
    "token-classification",
    model="d4data/biomedical-ner-all",
    aggregation_strategy="simple",
)

# keep dtype float32 on MPS for stability
# NOTE(review): the `dtype=` kwarg is only accepted by recent transformers
# releases (older versions use `torch_dtype=`) — verify against the pinned
# transformers version.
asr = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    device=("mps" if DEVICE == "mps" else 0 if DEVICE == "cuda" else -1),
    dtype=(torch.float32 if DEVICE == "mps" else None),
)

# ------------------------------------------------------------
# Helper functions
# ------------------------------------------------------------
def highlight_text(text, ents):
    """Render *text* as HTML with colored, labeled entity spans.

    Parameters
    ----------
    text : str
        Source text that the entity character offsets refer to.
    ents : list[dict] | None
        Pipeline predictions, each with ``start``/``end`` character
        offsets, an ``entity_group`` label and a ``score``.

    Returns
    -------
    str
        An HTML fragment; empty string when *text* is falsy.

    Notes
    -----
    Fixes two defects of the previous version: the label is now
    HTML-escaped before being interpolated into the single-quoted
    ``title`` attribute (a label containing ``'`` would have broken the
    markup), and ``ents=None`` no longer raises.
    """
    if not text:
        return ""
    spans = sorted(ents or [], key=lambda x: x["start"])
    html_out, i = [], 0

    # normalize labels to your preferred names (no underscores, friendlier groups)
    rename_map = {
        "PROBLEM": "Sign symptom",
        "SYMPTOM": "Sign symptom",
        "DISEASE": "Sign symptom",
        "TEST": "Diagnostic procedure",
        "TREATMENT": "Diagnostic procedure",
        "PROCEDURE": "Diagnostic procedure",
        "ANATOMY": "Biological structure",
        "BODY_PART": "Biological structure",
        "MEDICATION": "Medication",
        "DRUG": "Medication",
        "CHEMICAL": "Medication",
    }
    colors = {
        "Diagnostic procedure": "#ADD8E6",
        "Biological structure": "#D8BFD8",
        "Sign symptom": "#FF7F7F",
        "Medication": "#FFD700",
    }

    for span in spans:
        s, e = span["start"], span["end"]
        if s < i:
            # Overlapping span — keep the earlier entity, drop this one.
            continue
        html_out.append(html.escape(text[i:s]))
        raw_label = (span.get("entity_group", "") or "").upper()
        label = rename_map.get(raw_label, raw_label.title().replace("_", " "))
        color = colors.get(label, "#E0E0E0")
        chunk = html.escape(text[s:e])
        # html.escape (quote=True by default) also escapes ' and ", so the
        # label is safe inside the single-quoted title attribute.
        esc_label = html.escape(label)
        html_out.append(
            f"<span style='background:{color};padding:2px 4px;"
            f"border-radius:4px;margin:1px;' title='{esc_label} | "
            f"score={span.get('score', 0):.2f}'>{chunk} "
            f"<small style='opacity:.7'>[{esc_label}]</small></span>"
        )
        i = e
    html_out.append(html.escape(text[i:]))
    return "<div style='line-height:1.9'>" + "".join(html_out) + "</div>"


def run_ner(text):
    """Run biomedical NER over *text*.

    Returns a pair ``(highlighted_html, rows)`` where ``rows`` is a list of
    ``[entity, matched_text, start, end, score]`` suitable for the results
    dataframe. Blank / whitespace-only input yields ``("", [])``.
    """
    if not text or not text.strip():
        return "", []

    predictions = ner(text)

    # Table rows use the raw model label with underscores spaced out, so
    # e.g. "Sign_symptom" displays as "Sign symptom".
    table = [
        [
            pred.get("entity_group", "").replace("_", " "),
            text[pred["start"]:pred["end"]],
            pred["start"],
            pred["end"],
            round(float(pred.get("score", 0.0)), 4),
        ]
        for pred in predictions
    ]
    return highlight_text(text, predictions), table


def asr_then_ner(audio):
    """Transcribe *audio* with Whisper, then run NER on the transcript.

    Accepts the ``(sample_rate, samples)`` tuple produced by
    ``gr.Audio(type="numpy")``, an ``{"array", "sampling_rate"}`` dict, or
    any other input the ASR pipeline understands (e.g. a file path).

    Returns ``(transcript, highlighted_html, entity_rows)``; all-empty
    outputs when *audio* is ``None``, and a "(No speech detected)"
    transcript when transcription comes back blank.
    """
    if audio is None:
        return "", "", []

    # Normalize the input into the dict form the pipeline expects.
    if isinstance(audio, tuple) and len(audio) == 2:
        rate, samples = audio
        samples = np.asarray(samples)
        if samples.ndim == 2:
            # Down-mix stereo to mono by averaging channels.
            samples = samples.mean(axis=1)
        payload = {
            "array": samples.astype("float32", copy=False),
            "sampling_rate": int(rate),
        }
    elif isinstance(audio, dict) and "array" in audio:
        payload = {
            "array": np.asarray(audio["array"]).astype("float32", copy=False),
            "sampling_rate": int(audio.get("sampling_rate", 16000)),
        }
    else:
        # File path or other pipeline-native input; pass straight through.
        payload = audio

    # Only forward kwargs Whisper's generate() accepts; temperature=0.0
    # makes decoding deterministic.
    output = asr(
        payload,
        chunk_length_s=30,
        return_timestamps=False,
        generate_kwargs={
            "temperature": 0.0,
            "task": "transcribe",
            # Add "language": "en" here to force English decoding.
        },
    )

    transcript = output.get("text", "") if isinstance(output, dict) else str(output)
    transcript = transcript.strip()
    if not transcript:
        return "(No speech detected)", "", []

    highlighted, rows = run_ner(transcript)
    return transcript, highlighted, rows

# ------------------------------------------------------------
# Build Gradio interface
# ------------------------------------------------------------
# Two-tab UI: text-only NER, and audio → transcript → NER.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    # --------------------------------------------------------
    # Clinical NER (Text) — DO NOT TOUCH
    # --------------------------------------------------------
    with gr.Tab("Clinical NER (Text)"):
        ex_text = "The patient started metformin 500 mg BID for type 2 diabetes and reported mild neuropathy."
        inp = gr.Textbox(label="Enter clinical text", placeholder=ex_text, lines=5)
        btn = gr.Button("Extract Entities")
        out_html = gr.HTML(label="Highlighted text")
        out_table = gr.Dataframe(headers=["entity","text","start","end","score"],
                                 label="Extracted entities", interactive=False)
        # run_ner returns (html, rows) matching the two outputs below.
        btn.click(fn=run_ner, inputs=inp, outputs=[out_html, out_table])
        gr.Examples(
            examples=[
                [ex_text],
                ["CT scan of the abdomen revealed hepatomegaly. Start acetaminophen PRN and check ALT/AST."],
                ["Allergic to penicillin; prescribe azithromycin. Monitor CRP and WBC weekly."],
            ],
            inputs=[inp]
        )

    # --------------------------------------------------------
    # Clinical NER from Audio (Bonus)
    # --------------------------------------------------------
    with gr.Tab("Clinical NER from Audio (Bonus)"):
        gr.Markdown("""
        ### Record or upload ≤30 s audio  
        Whisper transcribes your speech, then the NER model extracts key biomedical entities.  

        **Try saying things like:**
        - “The patient started metformin five hundred milligrams twice daily for diabetes.”  
        - “Order a CT scan of the abdomen and start acetaminophen as needed.”  
        - “Prescribe azithromycin and monitor white blood cell count weekly.”  
        - “Ultrasound shows hepatomegaly and mild fatty liver.”
        """)

        # type="numpy" makes the component deliver an (sr, samples) tuple,
        # which asr_then_ner normalizes before calling the ASR pipeline.
        audio = gr.Audio(sources=["microphone", "upload"],
                         type="numpy", label="🎙️ Record or upload audio (≤30 s)")

        btn2 = gr.Button("Transcribe → Extract Entities")
        transcript = gr.Textbox(label="Transcript", interactive=False)
        out_html2 = gr.HTML(label="Highlighted transcript")
        out_table2 = gr.Dataframe(headers=["entity","text","start","end","score"],
                                  label="Extracted entities", interactive=False)

        # asr_then_ner returns (transcript, html, rows) matching the outputs.
        btn2.click(fn=asr_then_ner, inputs=audio,
                   outputs=[transcript, out_html2, out_table2])

    gr.Markdown(f"---\n{FOOTER}")

# ------------------------------------------------------------
# Launch
# ------------------------------------------------------------
if __name__ == "__main__":
    # share=True creates a temporary public gradio.live URL.
    # NOTE(review): the old comment claimed analytics were disabled here, but
    # launch() is not passed analytics_enabled=False — confirm intent.
    demo.launch(share=True)