Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,42 +14,34 @@ HF_TOKEN = os.getenv("hf_token")
|
|
| 14 |
csv.field_size_limit(sys.maxsize)
|
| 15 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
# 1. Load Models and Tokenizers
|
| 19 |
-
# ==========================================
|
| 20 |
-
# Classification Model
|
| 21 |
tokenizer = AutoTokenizer.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN)
|
| 22 |
model = AutoModelForSequenceClassification.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN).to(device)
|
| 23 |
|
| 24 |
-
# Build a pipeline object for predictions
|
| 25 |
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
|
| 26 |
|
| 27 |
# SHAP explainer
|
| 28 |
explainer = shap.Explainer(pred)
|
| 29 |
|
| 30 |
-
# NER pipeline
|
| 31 |
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
|
| 32 |
-
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
|
| 33 |
-
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple"
|
| 34 |
|
| 35 |
-
# ==========================================
|
| 36 |
-
# 2. Prediction Function
|
| 37 |
-
# ==========================================
|
| 38 |
def adr_predict(x):
|
| 39 |
text_input = str(x).lower()
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
raw_results = pred(text_input)[0]
|
| 43 |
-
label_output = {item['label']: float(item['score']) for item in raw_results}
|
| 44 |
|
| 45 |
-
# 2. SHAP Logic
|
| 46 |
try:
|
| 47 |
shap_values = explainer([text_input])
|
| 48 |
local_plot = shap.plots.text(shap_values[0], display=False)
|
| 49 |
except Exception as e:
|
| 50 |
local_plot = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
|
| 51 |
|
| 52 |
-
# 3. NER Logic
|
| 53 |
try:
|
| 54 |
res = ner_pipe(text_input)
|
| 55 |
entity_colors = {
|
|
@@ -62,32 +54,27 @@ def adr_predict(x):
|
|
| 62 |
'Biological_structure':'#d9d9d9'
|
| 63 |
}
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
htext = "<div style='line-height: 2.0; font-size: 1.1em; color: black;
|
| 67 |
prev_end = 0
|
| 68 |
res = sorted(res, key=lambda x: x['start'])
|
| 69 |
-
|
| 70 |
for entity in res:
|
| 71 |
start, end = entity['start'], entity['end']
|
| 72 |
word = text_input[start:end]
|
| 73 |
color = entity_colors.get(entity['entity_group'], '#f3f3f3')
|
| 74 |
|
| 75 |
-
htext += f"
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
f"<small style='opacity: 0.7; font-size: 0.7em;'>[{entity['entity_group']}]</small></mark>")
|
| 79 |
prev_end = end
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
htext = f"<p style='color: black;'>NER processing error: {e}</p>"
|
| 84 |
|
|
|
|
| 85 |
return label_output, local_plot, htext
|
| 86 |
|
| 87 |
-
#
|
| 88 |
-
# 3. Gradio Interface
|
| 89 |
-
# ==========================================
|
| 90 |
-
# CSS ensures Gradio's dark mode doesn't override the white background and black text
|
| 91 |
custom_css = """
|
| 92 |
.gradio-container { font-family: 'Inter', system-ui, sans-serif; }
|
| 93 |
.main-header { text-align: center; margin-bottom: 2rem; }
|
|
@@ -95,7 +82,7 @@ custom_css = """
|
|
| 95 |
footer { visibility: hidden; }
|
| 96 |
"""
|
| 97 |
|
| 98 |
-
with gr.Blocks(title="ADR Detector"
|
| 99 |
with gr.Column(elem_classes="main-header"):
|
| 100 |
gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
|
| 101 |
gr.Markdown("Analyze clinical text for potential medication-related severity and key medical entities.")
|
|
@@ -115,9 +102,7 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
|
|
| 115 |
gr.Examples(
|
| 116 |
examples=[
|
| 117 |
["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
|
| 118 |
-
["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
|
| 119 |
-
["A 62-year-old female presented with shortness of breath and anaphylaxis minutes after intravenous Penicillin administration."],
|
| 120 |
-
["Patient felt slight drowsiness and dry mouth after taking 10mg of Cetirizine, but no other symptoms were noted."]
|
| 121 |
],
|
| 122 |
inputs=[prob1]
|
| 123 |
)
|
|
@@ -141,5 +126,4 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
|
|
| 141 |
outputs=[label, local_plot, htext]
|
| 142 |
)
|
| 143 |
|
| 144 |
-
|
| 145 |
-
demo.launch()
|
|
|
|
| 14 |
csv.field_size_limit(sys.maxsize)
|
| 15 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 16 |
|
| 17 |
+
# Load models and tokenizer
|
|
|
|
|
|
|
|
|
|
| 18 |
tokenizer = AutoTokenizer.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN)
|
| 19 |
model = AutoModelForSequenceClassification.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN).to(device)
|
| 20 |
|
| 21 |
+
# Build a pipeline object for predictions
|
| 22 |
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
|
| 23 |
|
| 24 |
# SHAP explainer
|
| 25 |
explainer = shap.Explainer(pred)
|
| 26 |
|
| 27 |
+
# NER pipeline
|
| 28 |
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
|
| 29 |
+
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
|
| 30 |
+
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
def adr_predict(x):
|
| 33 |
text_input = str(x).lower()
|
| 34 |
+
encoded_input = tokenizer(text_input, return_tensors='pt').to(device)
|
| 35 |
+
output = model(**encoded_input)
|
| 36 |
|
| 37 |
+
scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
|
|
|
|
|
|
|
| 38 |
|
|
|
|
| 39 |
try:
|
| 40 |
shap_values = explainer([text_input])
|
| 41 |
local_plot = shap.plots.text(shap_values[0], display=False)
|
| 42 |
except Exception as e:
|
| 43 |
local_plot = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
|
| 44 |
|
|
|
|
| 45 |
try:
|
| 46 |
res = ner_pipe(text_input)
|
| 47 |
entity_colors = {
|
|
|
|
| 54 |
'Biological_structure':'#d9d9d9'
|
| 55 |
}
|
| 56 |
|
| 57 |
+
# FIX: Added inline "color: black;" to force all un-highlighted text to be black
|
| 58 |
+
htext = "<div style='line-height: 2.0; font-size: 1.1em; color: black;'>"
|
| 59 |
prev_end = 0
|
| 60 |
res = sorted(res, key=lambda x: x['start'])
|
|
|
|
| 61 |
for entity in res:
|
| 62 |
start, end = entity['start'], entity['end']
|
| 63 |
word = text_input[start:end]
|
| 64 |
color = entity_colors.get(entity['entity_group'], '#f3f3f3')
|
| 65 |
|
| 66 |
+
htext += f"{text_input[prev_end:start]}"
|
| 67 |
+
# Highlighted text is also explicitly set to black
|
| 68 |
+
htext += f"<mark style='background-color:{color}; color: black; padding: 2px 4px; border-radius: 4px; font-weight: 500;'>{word} <small style='opacity: 0.7;'>[{entity['entity_group']}]</small></mark>"
|
|
|
|
| 69 |
prev_end = end
|
| 70 |
+
htext += text_input[prev_end:] + "</div>"
|
| 71 |
+
except:
|
| 72 |
+
htext = "<p style='color: black;'>NER processing error.</p>"
|
|
|
|
| 73 |
|
| 74 |
+
label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
|
| 75 |
return label_output, local_plot, htext
|
| 76 |
|
| 77 |
+
# FIX: Added !important tags to ensure Gradio's dark mode doesn't override the white background and black text
|
|
|
|
|
|
|
|
|
|
| 78 |
custom_css = """
|
| 79 |
.gradio-container { font-family: 'Inter', system-ui, sans-serif; }
|
| 80 |
.main-header { text-align: center; margin-bottom: 2rem; }
|
|
|
|
| 82 |
footer { visibility: hidden; }
|
| 83 |
"""
|
| 84 |
|
| 85 |
+
with gr.Blocks(title="ADR Detector") as demo:
|
| 86 |
with gr.Column(elem_classes="main-header"):
|
| 87 |
gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
|
| 88 |
gr.Markdown("Analyze clinical text for potential medication-related severity and key medical entities.")
|
|
|
|
| 102 |
gr.Examples(
|
| 103 |
examples=[
|
| 104 |
["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
|
| 105 |
+
["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
|
|
|
|
|
|
|
| 106 |
],
|
| 107 |
inputs=[prob1]
|
| 108 |
)
|
|
|
|
| 126 |
outputs=[label, local_plot, htext]
|
| 127 |
)
|
| 128 |
|
| 129 |
+
demo.launch(css=custom_css, theme=gr.themes.Soft())
|
|
|