Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer | |
# Load Model & Tokenizer: the token-classification checkpoint served by this demo.
model_name = "mmuzamilai/distilbert-drug-ner"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# NER Pipeline: device=-1 keeps inference on the CPU.
ner_pipeline = pipeline(
    task="ner",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)
# Label Map: class index i (emitted by the pipeline as "LABEL_<i>") -> entity name.
label_map = dict(
    enumerate(
        [
            "DOSAGE",
            "DRUG_NAME",
            "EVENT",
            "LOCATION",
            "OTHER",
            "ROA",
            "SYMPTOM",
            "TEMPORAL",
        ]
    )
)
# Color Map for Pills: entity label -> background color of its highlight span.
color_map = dict(
    DOSAGE="#fbcfe8",  # pink
    DRUG_NAME="#a5f3fc",  # sky blue
    EVENT="#fde68a",  # yellow
    LOCATION="#ddd6fe",  # violet
    ROA="#fecaca",  # red
    SYMPTOM="#fcd34d",  # amber
    TEMPORAL="#c7d2fe",  # indigo
)
def _decode_label(raw_label):
    """Translate a pipeline label such as ``"LABEL_3"`` into its ``label_map`` name.

    Indices missing from ``label_map`` decode to ``"O"``. A label that does not
    parse as an integer once ``"LABEL_"`` is removed (e.g. a model whose config
    presumably already carries readable label names) is returned unchanged
    instead of raising ``ValueError``.
    """
    try:
        label_id = int(raw_label.replace("LABEL_", ""))
    except ValueError:
        return raw_label
    return label_map.get(label_id, "O")


def merge_subwords_and_decode(entities):
    """Merge "##"-prefixed sub-word pieces into whole words and decode labels.

    Args:
        entities: Iterable of token dicts as returned by the NER pipeline; each
            must provide a ``"word"`` and an ``"entity"`` key. A word starting
            with ``"##"`` is treated as a continuation of the previous word.

    Returns:
        A list of ``{"word": str, "label": str}`` dicts in input order. A merged
        word takes the label of its first piece. ``"LABEL_<i>"`` labels are
        mapped through ``label_map`` (unknown indices become ``"O"``); labels
        that are not of that form are kept as-is.
    """
    merged = []
    current_word = ""
    current_label = None
    for entity in entities:
        word = entity["word"]
        is_continuation = word.startswith("##")
        if is_continuation and current_word:
            # Continuation piece: glue it onto the word being built; the word
            # keeps the label of its first piece.
            current_word += word[2:]
            continue
        if current_word:
            merged.append({"word": current_word, "label": current_label})
        # A "##" piece with no preceding word starts a word of its own (prefix
        # dropped) so it still gets a real label rather than None.
        current_word = word[2:] if is_continuation else word
        current_label = _decode_label(entity["entity"])
    if current_word:
        merged.append({"word": current_word, "label": current_label})
    return merged
def format_entities_html(entities):
    """Return HTML with color-coded entity spans.

    Args:
        entities: Iterable of dicts with ``"word"`` and ``"label"`` keys, such as
            the output of ``merge_subwords_and_decode``.

    Returns:
        One HTML string with the pieces separated by single spaces. Words whose
        label is ``"O"``, ``"OTHER"`` or empty/``None`` are emitted as plain
        text. Every other word is wrapped in a rounded ``<span>`` whose
        background comes from ``color_map`` (``#e5e7eb`` when the label has no
        entry), followed by the label in small text. All text is HTML-escaped
        because words are derived from user input.
    """
    from html import escape  # function-scope import: the file has no top-level html import

    # "OTHER" is the non-entity class in label_map and has no color_map entry,
    # so it is rendered like "O" instead of as a grey pill.
    plain_labels = {"O", "OTHER"}
    pieces = []
    for ent in entities:
        word = escape(ent["word"])
        label = ent["label"]
        if not label or label in plain_labels:
            pieces.append(word)
            continue
        color = color_map.get(label, "#e5e7eb")
        pieces.append(
            f'<span style="background-color:{color}; padding:2px 6px; margin:2px; '
            f'border-radius:10px; display:inline-block;">{word} '
            f'<small style="opacity:0.7">({escape(label)})</small></span>'
        )
    return " ".join(pieces)
def predict_and_format(text):
    """Run the NER pipeline on *text* and return the color-coded HTML rendering.

    Args:
        text: Input string from the UI textbox.

    Returns:
        The HTML produced by ``format_entities_html``. For ``None``, empty or
        whitespace-only input, ``""`` is returned without calling the model.
    """
    if not text or not text.strip():
        # Nothing to tag: skip the model call (and avoid handing None to it).
        return ""
    raw_entities = ner_pipeline(text)
    cleaned_entities = merge_subwords_and_decode(raw_entities)
    return format_entities_html(cleaned_entities)
# Gradio Interface: a 4-line textbox feeds predict_and_format, whose HTML result
# is shown in an HTML output component; the app is launched immediately.
_text_input = gr.Textbox(lines=4, placeholder="Enter clinical or drug-related text here...")
_html_output = gr.HTML(label="Named Entity Recognition")
demo = gr.Interface(
    fn=predict_and_format,
    inputs=_text_input,
    outputs=_html_output,
    title="π Drug NER Highlighter",
    description="A custom NER model that highlights drug-related named entities with colorful pills.",
)
demo.launch()