import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

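# ADR severity classifier (RoBERTa checkpoint) and its tokenizer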
model = AutoModelForSequenceClassification.from_pretrained("calerio-uva/roberta-adr-model")
tokenizer = AutoTokenizer.from_pretrained("calerio-uva/roberta-adr-model")

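# Biomedical NER pipeline used to extract symptoms, diseases, and medications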
ner = pipeline(
    "ner",
    model="d4data/biomedical-ner-all",
    tokenizer="d4data/biomedical-ner-all",
    aggregation_strategy="simple"
)

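# Entity groups (lowercased) mapped to the three output categories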
SYMPTOM_TAGS = {"sign_symptom", "symptom"}
DISEASE_TAGS = {"disease_disorder"}
MED_TAGS = {"medication", "administration", "therapeutic_procedure"}

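# Drop tokens shorter than 3 characters and case-insensitive duplicates, preserving order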
def dedupe_and_filter(tokens):
    seen, out = set(), []
    for tok in tokens:
        w = tok.strip()
        if len(w) < 3:
            continue
        lw = w.lower()
        if lw not in seen:
            seen.add(lw)
            out.append(w)
    return out


def classify_adr(text: str):
    print("[DEBUG] Running classify_adr", flush=True)

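    # Light cleanup: trim whitespace, drop literal "nan" strings, collapse doubled spaces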
    clean = text.strip().replace("nan", "").replace("  ", " ")
    print("[DEBUG] clean[:50]:", clean[:50], "...", flush=True)

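    # Severity classification with the fine-tuned model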
    inputs = tokenizer(clean, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

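    # Named-entity extraction on the cleaned text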
    ents = ner(clean)
    print("[DEBUG] raw ents:", [(e["entity_group"], e["word"], e["start"], e["end"]) for e in ents], flush=True)

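    # Merge adjacent or overlapping entities of the same group into single spans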
    spans = []
    for ent in ents:
        grp, start, end, score = ent["entity_group"].lower(), ent["start"], ent["end"], ent.get("score", 1.0)
        if spans and spans[-1]["group"] == grp and start <= spans[-1]["end"]:
            spans[-1]["end"] = max(spans[-1]["end"], end)
            spans[-1]["score"] = max(spans[-1]["score"], score)
        else:
            spans.append({"group": grp, "start": start, "end": end, "score": score})
    print("[DEBUG] merged spans:", spans, flush=True)

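    # Extend medication spans forward to the end of the current word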
    for s in spans:
        if s["group"] in MED_TAGS:
            st, en = s["start"], s["end"]
            while en < len(clean) and clean[en].isalpha():
                en += 1
            s["end"] = en

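    # Keep only reasonably confident spans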
    spans = [s for s in spans if s["score"] >= 0.6]
    print("[DEBUG] post-filter spans:", spans, flush=True)

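    # Surface text for each surviving span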
    tokens = [clean[s["start"]:s["end"]] for s in spans]
    print("[DEBUG] tokens:", tokens, flush=True)

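    # Bucket the extracted entities into symptoms, diseases, and medications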
    symptoms = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in SYMPTOM_TAGS])
    diseases = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in DISEASE_TAGS])
    medications = dedupe_and_filter([t for t, s in zip(tokens, spans) if s["group"] in MED_TAGS])

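    # Human-readable interpretation of the severe-class probability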
    if probs[1] > 0.9:
        comment = "❗ High confidence this is a severe ADR."
    elif probs[1] > 0.5:
        comment = "⚠️ Borderline case – may be severe."
    else:
        comment = "✅ Likely not severe."

    return (
        f"Not Severe (0): {probs[0]:.3f}\nSevere (1): {probs[1]:.3f}",
        "\n".join(symptoms) or "None detected",
        "\n".join(diseases) or "None detected",
        "\n".join(medications) or "None detected",
        comment
    )


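# Gradio UI: one input textbox, five output textboxes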
demo = gr.Interface(
    fn=classify_adr,
    inputs=gr.Textbox(lines=4, label="ADR Description"),
    outputs=[
        gr.Textbox(label="Predicted Probabilities"),
        gr.Textbox(label="Symptoms"),
        gr.Textbox(label="Diseases or Conditions"),
        gr.Textbox(label="Medications"),
        gr.Textbox(label="Interpretation"),
    ],
    title="ADR Severity & NER Classifier",
    description="Paste an ADR description to classify severity and extract symptoms, diseases & medications.",
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()