import html

import torch
import gradio as gr
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer

# Load Model & Tokenizer
model_name = "mmuzamilai/distilbert-drug-ner"
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NER Pipeline (CPU only)
ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, device=-1)

# Label Map: model output index (from the pipeline's "LABEL_<n>" tags) -> tag name.
label_map = {
    0: "DOSAGE",
    1: "DRUG_NAME",
    2: "EVENT",
    3: "LOCATION",
    4: "OTHER",
    5: "ROA",
    6: "SYMPTOM",
    7: "TEMPORAL",
}

# Labels rendered as plain text instead of pills. "O" is the fallback for
# unknown ids; "OTHER" has no colour in color_map below, so it is treated as
# the non-entity class.
# NOTE(review): confirm "OTHER" is the model's outside/non-entity tag.
PLAIN_LABELS = frozenset({"O", "OTHER"})

# Color Map for Pills
color_map = {
    "DOSAGE": "#fbcfe8",     # pink
    "DRUG_NAME": "#a5f3fc",  # sky blue
    "EVENT": "#fde68a",      # yellow
    "LOCATION": "#ddd6fe",   # violet
    "ROA": "#fecaca",        # red
    "SYMPTOM": "#fcd34d",    # amber
    "TEMPORAL": "#c7d2fe",   # indigo
}

# Background used for any pill label that has no entry in color_map.
DEFAULT_PILL_COLOR = "#e5e7eb"


def _resolve_label(raw_label):
    """Map a pipeline tag such as "LABEL_1" to its name in label_map.

    Known "LABEL_<n>" ids map through label_map, and unknown ids map to "O".
    Tags that are not of that form (e.g. when the model config already carries
    readable id2label names) are returned unchanged instead of raising
    ValueError.
    """
    prefix = "LABEL_"
    if raw_label.startswith(prefix):
        suffix = raw_label[len(prefix):]
        if suffix.isdigit():
            return label_map.get(int(suffix), "O")
    return raw_label


def merge_subwords_and_decode(entities):
    """Join WordPiece continuation tokens ("##...") onto the preceding token.

    Args:
        entities: token-level dicts from the "ner" pipeline; each must have
            "word" and "entity" keys.

    Returns:
        A list of {"word": str, "label": str} dicts, one per merged word.
        A merged word keeps the label of its first sub-token.
    """
    merged = []
    for entity in entities:
        word = entity["word"]
        if word.startswith("##"):
            piece = word[2:]
            if merged:
                merged[-1]["word"] += piece
                continue
            # A continuation with nothing before it starts a new word and
            # still gets a real label (previously the label was None).
            word = piece
        merged.append({"word": word, "label": _resolve_label(entity["entity"])})
    return merged


def format_entities_html(entities):
    """Return HTML with color-coded entity spans.

    Each word whose label is not in PLAIN_LABELS becomes an inline "pill"
    <span> whose background comes from color_map (DEFAULT_PILL_COLOR if the
    label is missing there). All other words are emitted as plain text. Words
    are space-separated.

    All text is HTML-escaped because it originates from user input and is
    rendered as raw HTML by the gr.HTML output component.
    """
    parts = []
    for ent in entities:
        word = html.escape(ent["word"])
        label = ent["label"]
        if label in PLAIN_LABELS:
            parts.append(word)
            continue
        color = color_map.get(label, DEFAULT_PILL_COLOR)
        parts.append(
            f'<span style="background-color:{color}; border-radius:9999px; '
            f'padding:2px 8px; margin:0 2px;">'
            f"{word} <small><b>({html.escape(str(label))})</b></small></span>"
        )
    return " ".join(parts)


def predict_and_format(text):
    """Run NER on *text* and return highlighted HTML ("" for blank input)."""
    if not text or not text.strip():
        # Nothing to tag; skip the model call.
        return ""
    raw_entities = ner_pipeline(text)
    cleaned_entities = merge_subwords_and_decode(raw_entities)
    return format_entities_html(cleaned_entities)


# Gradio Interface
demo = gr.Interface(
    fn=predict_and_format,
    inputs=gr.Textbox(lines=4, placeholder="Enter clinical or drug-related text here..."),
    outputs=gr.HTML(label="Named Entity Recognition"),
    title="💊 Drug NER Highlighter",
    description="A custom NER model that highlights drug-related named entities with colorful pills.",
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module does not
    # start a server.
    demo.launch()