# DrugNER / app.py
# Author: MMuzamilAI
# Commit fac4991: "Create app.py" (verified)
import torch
import gradio as gr
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
# Fetch the fine-tuned token-classification checkpoint named below together
# with its matching tokenizer, then wrap both in a "ner" pipeline.
model_name = "mmuzamilai/distilbert-drug-ner"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# No aggregation strategy is requested, so the pipeline yields one prediction
# per sub-token; merge_subwords_and_decode re-joins them into whole words.
# device=-1 keeps inference on the CPU.
ner_pipeline = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)
# Class id -> readable tag. merge_subwords_and_decode strips the "LABEL_"
# prefix from each pipeline prediction and looks the numeric id up here.
label_map = dict(
    enumerate(
        [
            "DOSAGE",     # 0
            "DRUG_NAME",  # 1
            "EVENT",      # 2
            "LOCATION",   # 3
            "OTHER",      # 4
            "ROA",        # 5
            "SYMPTOM",    # 6
            "TEMPORAL",   # 7
        ]
    )
)

# Background colour of the highlighted "pill" drawn for each tag.
# NOTE(review): "OTHER" has no entry, so format_entities_html falls back to its
# grey default and still highlights it (only "O" is left plain) — confirm
# whether OTHER is meant to be the non-entity class.
color_map = dict(
    DOSAGE="#fbcfe8",     # pink
    DRUG_NAME="#a5f3fc",  # sky blue
    EVENT="#fde68a",      # yellow
    LOCATION="#ddd6fe",   # violet
    ROA="#fecaca",        # red
    SYMPTOM="#fcd34d",    # amber
    TEMPORAL="#c7d2fe",   # indigo
)
def _decode_label(raw_label):
    """Translate one pipeline label (e.g. ``"LABEL_3"``) into a readable tag.

    ``"LABEL_<digits>"`` is looked up in the module-level ``label_map``
    (unknown ids become ``"O"``). Any other string — e.g. when the model
    config already carries real label names — is returned unchanged instead
    of raising ``ValueError``.
    """
    prefix = "LABEL_"
    if raw_label.startswith(prefix):
        suffix = raw_label[len(prefix):]
        if suffix.isdigit():
            return label_map.get(int(suffix), "O")
    return raw_label


def merge_subwords_and_decode(entities):
    """Merge WordPiece sub-tokens back into whole words and decode their labels.

    Args:
        entities: per-token predictions from the NER pipeline; each item is a
            dict with at least ``"word"`` and ``"entity"`` keys. Tokens whose
            text starts with ``"##"`` are treated as continuations of the
            preceding token.

    Returns:
        A list of ``{"word": str, "label": str}`` dicts in input order. A
        merged word keeps the label of its first sub-token.
    """
    merged = []
    current_word = ""
    current_label = None

    for entity in entities:
        word = entity["word"]
        label = _decode_label(entity["entity"])

        if word.startswith("##") and current_word:
            # Continuation piece: glue it onto the word being built.
            current_word += word[2:]
            continue

        if current_word:
            merged.append({"word": current_word, "label": current_label})
        # A "##" piece with nothing before it starts a word of its own (with
        # its own label) rather than being emitted later with a None label.
        current_word = word[2:] if word.startswith("##") else word
        current_label = label

    if current_word:
        merged.append({"word": current_word, "label": current_label})
    return merged
def format_entities_html(entities):
    """Return HTML with color-coded entity spans.

    Args:
        entities: merged words as produced by ``merge_subwords_and_decode``;
            each item is a dict with ``"word"`` and ``"label"`` keys.

    Returns:
        A single HTML string. Words labelled ``"O"`` are emitted as plain
        (escaped) text. Every other word becomes a rounded "pill" coloured
        from ``color_map`` (grey for tags without an entry). Items are
        separated by single spaces.
    """
    # Local import: the file does not import the stdlib `html` module at top level.
    from html import escape

    parts = []
    for ent in entities:
        # Words come from user-typed text, so escape them before embedding in HTML.
        word = escape(str(ent["word"]))
        label = ent["label"]
        if label != "O":
            color = color_map.get(label, "#e5e7eb")
            parts.append(
                f'<span style="background-color:{color}; padding:2px 6px; margin:2px; '
                f'border-radius:10px; display:inline-block;">{word} '
                f'<small style="opacity:0.7">({escape(str(label))})</small></span>'
            )
        else:
            parts.append(word)
    return " ".join(parts).strip()
def predict_and_format(text):
    """Gradio callback: tag *text* with the NER model and return it as HTML.

    Flow: per-token pipeline predictions -> whole words with readable labels
    -> colour-coded HTML snippet.
    """
    return format_entities_html(
        merge_subwords_and_decode(ner_pipeline(text))
    )
# Gradio UI: one multi-line text box in, the highlighted HTML out.
demo = gr.Interface(
    fn=predict_and_format,
    inputs=gr.Textbox(lines=4, placeholder="Enter clinical or drug-related text here..."),
    outputs=gr.HTML(label="Named Entity Recognition"),
    # Title previously held mojibake ("πŸ’Š") for the pill emoji; restored here.
    title="💊 Drug NER Highlighter",
    description="A custom NER model that highlights drug-related named entities with colorful pills.",
)
# Start the web app when the script is run (unconditional, as before).
demo.launch()