# Provenance (copied from the file-view page header):
# author hsienberg · commit 941b965 "Create app.py" (verified) · 13.5 kB
import gradio as gr
from transformers import pipeline
# ── Model ─────────────────────────────────────────────────────────────────────
# Hub id of the biomedical NER checkpoint used by the token-classification
# pipeline below.
NER_MODEL_ID = "d4data/biomedical-ner-all"

# Built once at import time; extract_entities calls this same object for every
# request. aggregation_strategy="simple" groups sub-word tokens into entity
# spans, and extract_entities reads their entity_group/start/end/score fields.
ner = pipeline(
    task="token-classification",
    model=NER_MODEL_ID,
    aggregation_strategy="simple",
)
# ── Entity color map ──────────────────────────────────────────────────────────
# Label → highlight color passed to gr.HighlightedText(color_map=...).
# Related labels share a color (e.g. Chemical/Medication) so they read as one
# group in the UI.
# NOTE(review): keys must match the model's label strings exactly; confirm that
# "Chemical" and "Body_part" are actually emitted by d4data/biomedical-ner-all.
COLOR_MAP = dict(
    Disease_disorder="#FF6B6B",        # coral-red
    Sign_symptom="#FFB347",            # amber
    Chemical="#4ECDC4",                # teal
    Medication="#4ECDC4",              # teal (shared with Chemical)
    Body_part="#A78BFA",               # violet
    Biological_structure="#A78BFA",    # violet (shared with Body_part)
    Diagnostic_procedure="#60A5FA",    # sky-blue
    Therapeutic_procedure="#34D399",   # emerald
    Lab_value="#F9A8D4",               # pink
    Clinical_event="#FCD34D",          # yellow
    Date="#94A3B8",                    # slate
    Age="#94A3B8",                     # slate
    Severity="#FB923C",                # orange
    Biological_attribute="#C084FC",    # purple
)
# Sample inputs for gr.Examples: two clinical snippets and one non-medical
# sentence (presumably a negative control that should yield few/no entities).
EXAMPLES = [
    (
        "Patient presents with acute onset of substernal chest pain radiating "
        "to the left arm. Started on aspirin and heparin drip."
    ),
    (
        "The patient has a headache, fever, and sore throat. "
        "She was prescribed ibuprofen and amoxicillin."
    ),
    "The warrior stepped into the dungeon and drew his sword.",
]
# ── Core function ─────────────────────────────────────────────────────────────
def extract_entities(text: str):
if not text.strip():
return [], []
results = ner(text)
# ── Build HighlightedText spans ──
highlighted = []
cursor = 0
for ent in results:
start, end = ent["start"], ent["end"]
label = ent.get("entity_group", ent.get("entity", "ENTITY"))
score = ent["score"]
# plain text before this entity
if start > cursor:
highlighted.append((text[cursor:start], None))
highlighted.append((text[start:end], label))
cursor = end
# trailing plain text
if cursor < len(text):
highlighted.append((text[cursor:], None))
# ── Build Dataframe rows ──
rows = [
{
"Entity": ent["word"],
"Label": ent.get("entity_group", ent.get("entity", "ENTITY")),
"Score": round(float(ent["score"]), 2),
}
for ent in results
]
return highlighted, rows
# ── CSS ───────────────────────────────────────────────────────────────────────
# Custom stylesheet passed to gr.Blocks(css=...).
# The examples rules target `#examples-section` (an id selector) because the UI
# gives gr.Examples only elem_id="examples-section"; the previous class
# selectors (`.examples-section`) never matched anything.
# NOTE(review): with a single input component Gradio may render examples as
# buttons rather than a table, in which case the td/tr rules will not apply;
# verify in the browser.
# NOTE(review): the universal `*` reset also zeroes Gradio's own margins and
# padding; confirm that is intended.
css = """
@import url('https://fonts.googleapis.com/css2?family=Syne:wght@400;500;700;800&family=IBM+Plex+Mono:wght@300;400;500&family=IBM+Plex+Sans:wght@300;400;500&display=swap');
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
:root {
--bg: #0D1117;
--surface: #161B22;
--border: #30363D;
--text: #E6EDF3;
--muted: #8B949E;
--accent: #58A6FF;
--accent2: #3FB950;
--danger: #FF7B72;
--mono: 'IBM Plex Mono', monospace;
--display: 'Syne', sans-serif;
--body: 'IBM Plex Sans', sans-serif;
--radius: 8px;
}
body, .gradio-container {
background: var(--bg) !important;
font-family: var(--body) !important;
color: var(--text) !important;
}
/* ── Header ── */
.app-header {
padding: 32px 0 24px;
border-bottom: 1px solid var(--border);
margin-bottom: 28px;
display: flex;
align-items: flex-end;
justify-content: space-between;
flex-wrap: wrap;
gap: 12px;
}
.app-header h1 {
font-family: var(--display);
font-size: 2rem;
font-weight: 800;
letter-spacing: -0.03em;
color: var(--text);
line-height: 1;
}
.app-header h1 .accent { color: var(--accent); }
.header-right {
font-family: var(--mono);
font-size: 0.68rem;
letter-spacing: 0.1em;
text-transform: uppercase;
color: var(--muted);
text-align: right;
line-height: 1.8;
}
.tag {
display: inline-block;
background: #1F2937;
border: 1px solid var(--border);
color: var(--accent);
font-family: var(--mono);
font-size: 0.62rem;
letter-spacing: 0.08em;
padding: 2px 8px;
border-radius: 4px;
text-transform: uppercase;
}
/* ── Section labels ── */
.section-label {
font-family: var(--mono);
font-size: 0.65rem;
letter-spacing: 0.16em;
text-transform: uppercase;
color: var(--muted);
border-bottom: 1px solid var(--border);
padding-bottom: 8px;
margin-bottom: 12px;
}
/* ── Inputs ── */
textarea, input[type="text"] {
font-family: var(--mono) !important;
font-size: 0.84rem !important;
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
color: var(--text) !important;
line-height: 1.7 !important;
transition: border-color 0.15s;
}
textarea:focus, input:focus {
border-color: var(--accent) !important;
outline: none !important;
box-shadow: 0 0 0 3px rgba(88,166,255,0.15) !important;
}
/* ── HighlightedText output ── */
.highlighted-text-output {
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
padding: 16px !important;
font-family: var(--body) !important;
font-size: 0.92rem !important;
line-height: 1.9 !important;
color: var(--text) !important;
min-height: 80px;
}
/* entity chips */
.highlighted-text-output mark,
.highlighted-text-output span[data-label] {
border-radius: 3px !important;
padding: 1px 4px !important;
font-weight: 500 !important;
}
/* ── Dataframe ── */
.dataframe-output table {
font-family: var(--mono) !important;
font-size: 0.78rem !important;
background: var(--surface) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
width: 100% !important;
border-collapse: collapse !important;
}
.dataframe-output th {
background: #1C2130 !important;
color: var(--accent) !important;
font-size: 0.65rem !important;
letter-spacing: 0.12em !important;
text-transform: uppercase !important;
padding: 10px 14px !important;
border-bottom: 1px solid var(--border) !important;
text-align: left !important;
}
.dataframe-output td {
color: var(--text) !important;
padding: 8px 14px !important;
border-bottom: 1px solid #1E242C !important;
}
.dataframe-output tr:hover td { background: #1A2030 !important; }
/* ── Buttons ── */
.analyze-btn {
background: var(--accent) !important;
color: #0D1117 !important;
font-family: var(--mono) !important;
font-size: 0.78rem !important;
font-weight: 500 !important;
letter-spacing: 0.1em !important;
text-transform: uppercase !important;
border: none !important;
border-radius: var(--radius) !important;
padding: 12px 28px !important;
cursor: pointer !important;
transition: opacity 0.15s, transform 0.1s !important;
}
.analyze-btn:hover { opacity: 0.85 !important; transform: translateY(-1px) !important; }
.analyze-btn:active { transform: translateY(0) !important; }
.clear-btn {
background: transparent !important;
color: var(--muted) !important;
font-family: var(--mono) !important;
font-size: 0.75rem !important;
letter-spacing: 0.08em !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
padding: 10px 20px !important;
cursor: pointer !important;
transition: border-color 0.15s, color 0.15s !important;
}
.clear-btn:hover {
border-color: var(--danger) !important;
color: var(--danger) !important;
}
/* ── Legend ── */
.legend {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin: 16px 0 4px;
}
.legend-chip {
font-family: var(--mono);
font-size: 0.62rem;
letter-spacing: 0.06em;
padding: 3px 9px;
border-radius: 3px;
font-weight: 500;
color: #0D1117;
}
/* ── Examples ── */
#examples-section table { border-collapse: collapse; width: 100%; }
#examples-section td {
font-family: var(--mono) !important;
font-size: 0.75rem !important;
color: var(--muted) !important;
padding: 7px 12px !important;
border-bottom: 1px solid #1E242C !important;
cursor: pointer;
}
#examples-section tr:hover td { color: var(--text) !important; background: #1A2030 !important; }
/* ── Disclaimer ── */
.disclaimer {
margin-top: 32px;
padding: 12px 16px;
border: 1px solid var(--border);
border-left: 3px solid var(--muted);
border-radius: var(--radius);
background: var(--surface);
font-family: var(--mono);
font-size: 0.72rem;
color: var(--muted);
letter-spacing: 0.03em;
line-height: 1.6;
}
/* ── Gradio chrome ── */
.gradio-container > .main > .wrap { padding: 20px 32px 40px !important; }
footer { display: none !important; }
label span { font-family: var(--mono) !important; font-size: 0.7rem !important;
letter-spacing: 0.1em !important; text-transform: uppercase !important;
color: var(--muted) !important; }
"""
# ── UI ────────────────────────────────────────────────────────────────────────
# Layout: header + legend, then a two-column row (input on the left, highlighted
# entities and entity table on the right), then examples and a disclaimer.
with gr.Blocks(css=css, title="Medical Entity Extractor") as demo:
    gr.HTML("""
<div class="app-header">
<div>
<div style="margin-bottom:8px;">
<span class="tag">NER</span>
<span class="tag" style="margin-left:4px;">Biomedical</span>
<span class="tag" style="margin-left:4px;">d4data</span>
</div>
<h1>Medical <span class="accent">Entity</span> Extractor</h1>
</div>
<div class="header-right">
d4data/biomedical-ner-all<br>
Diseases · Drugs · Symptoms · Body Parts · Procedures
</div>
</div>
<div class="legend">
<span class="legend-chip" style="background:#FF6B6B;">Disease / Disorder</span>
<span class="legend-chip" style="background:#FFB347;">Sign / Symptom</span>
<span class="legend-chip" style="background:#4ECDC4;">Chemical / Drug</span>
<span class="legend-chip" style="background:#A78BFA; color:#fff;">Body Part</span>
<span class="legend-chip" style="background:#60A5FA;">Diagnostic Procedure</span>
<span class="legend-chip" style="background:#34D399;">Therapeutic Procedure</span>
<span class="legend-chip" style="background:#F9A8D4;">Lab Value</span>
<span class="legend-chip" style="background:#FB923C;">Severity</span>
</div>
""")
    with gr.Row(equal_height=False):
        with gr.Column(scale=5):
            gr.HTML('<div class="section-label">Input Text</div>')
            text_input = gr.Textbox(
                label="",
                placeholder="Paste clinical notes, symptom descriptions, or any medical text…",
                lines=8,
                max_lines=20,
            )
            with gr.Row():
                analyze_btn = gr.Button("Analyze →", elem_classes="analyze-btn")
                clear_btn = gr.ClearButton(
                    components=[text_input],
                    value="Clear",
                    elem_classes="clear-btn",
                )
        with gr.Column(scale=6):
            gr.HTML('<div class="section-label">Highlighted Entities</div>')
            highlighted_out = gr.HighlightedText(
                label="",
                color_map=COLOR_MAP,
                show_legend=False,
                elem_classes="highlighted-text-output",
            )
            gr.HTML('<div class="section-label" style="margin-top:24px;">Entity Table</div>')
            table_out = gr.Dataframe(
                headers=["Entity", "Label", "Score"],
                datatype=["str", "str", "number"],
                label="",
                elem_classes="dataframe-output",
                wrap=True,
            )

    # The output components are created after the Clear button, so they are
    # registered with it here. Without this, "Clear" empties the input but
    # leaves stale highlights and table rows from the previous run on screen.
    clear_btn.add([highlighted_out, table_out])

    gr.HTML('<div class="section-label" style="margin-top:24px;">Example Texts</div>')
    # With cache_examples=False (and run_on_click left at its default), clicking
    # an example fills text_input; fn/outputs are kept for parity.
    gr.Examples(
        examples=EXAMPLES,
        inputs=text_input,
        outputs=[highlighted_out, table_out],
        fn=extract_entities,
        cache_examples=False,
        elem_id="examples-section",
    )
    gr.HTML("""
<div class="disclaimer">
ℹ️ &nbsp;This model identifies biomedical terms in text. It does not provide medical advice.
Results are for research and educational purposes only.
</div>
""")

    # ── Wire up ──
    # Both the button and submitting the textbox run the same extraction.
    analyze_btn.click(
        fn=extract_entities,
        inputs=text_input,
        outputs=[highlighted_out, table_out],
    )
    text_input.submit(
        fn=extract_entities,
        inputs=text_input,
        outputs=[highlighted_out, table_out],
    )
def main() -> None:
    """Launch the Gradio app built in the `demo` Blocks context."""
    demo.launch()


if __name__ == "__main__":
    main()