Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import shap
|
| 3 |
import numpy as np
|
|
@@ -12,6 +15,7 @@ import csv
|
|
| 12 |
import io
|
| 13 |
import base64
|
| 14 |
|
|
|
|
| 15 |
# Increase CSV field size limit
|
| 16 |
csv.field_size_limit(sys.maxsize)
|
| 17 |
|
|
@@ -26,12 +30,41 @@ model = AutoModelForSequenceClassification.from_pretrained("TyHamil/ADRv2025").t
|
|
| 26 |
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
|
| 27 |
|
| 28 |
# SHAP explainer
|
| 29 |
-
explainer = shap.Explainer(pred)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# NER pipeline
|
| 32 |
-
ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
|
| 33 |
-
ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
|
| 34 |
-
ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
# SHAP Plotting Function
|
| 37 |
def generate_shap_plot(shap_values):
|
|
@@ -59,13 +92,13 @@ def adr_predict(x):
|
|
| 59 |
local_plot = "<p>SHAP explanation not available.</p>"
|
| 60 |
|
| 61 |
# NER Processing
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
htext = "<div style='line-height: 1.5; font-family: Poppins;'>"
|
| 70 |
prev_end = 0
|
| 71 |
res = sorted(res, key=lambda x: x['start'])
|
|
|
|
| 1 |
+
# NOTE: shell commands such as `pip install ...` are not valid Python; left as
# bare lines in app.py they raise a SyntaxError at import time and break the
# build. Declare these dependencies in requirements.txt instead:
#   scispacy
#   https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_sm-0.5.1.tar.gz
|
| 3 |
+
|
| 4 |
import gradio as gr
|
| 5 |
import shap
|
| 6 |
import numpy as np
|
|
|
|
| 15 |
import io
|
| 16 |
import base64
|
| 17 |
|
| 18 |
+
|
| 19 |
# Increase CSV field size limit
|
| 20 |
csv.field_size_limit(sys.maxsize)
|
| 21 |
|
|
|
|
| 30 |
pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
|
| 31 |
|
| 32 |
# SHAP explainer
|
| 33 |
+
#explainer = shap.Explainer(pred)
|
| 34 |
+
import shap
|
| 35 |
+
|
| 36 |
+
def predict_prob(texts):
    """Return softmax class probabilities for a batch of texts.

    This is the prediction function handed to ``shap.Explainer``. SHAP calls
    it with a NumPy array of (masked) strings, which the tokenizer does not
    accept, so the batch is first normalised to a plain ``list[str]``.

    Args:
        texts: A single string, or an iterable of strings (list, tuple,
            NumPy array, ...).

    Returns:
        A NumPy array with one row per input text, holding the softmax of
        the model logits taken over dim 1.
    """
    # Normalise the input: keep a lone string as a batch of one, and turn any
    # other iterable (notably np.ndarray of np.str_) into a list of str.
    if isinstance(texts, str):
        batch = [texts]
    else:
        batch = [str(t) for t in texts]

    encoded = tokenizer(batch, return_tensors='pt', padding=True, truncation=True).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encoded)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    # Move to CPU before converting; .numpy() does not work on GPU tensors.
    return probs.cpu().numpy()
|
| 42 |
+
|
| 43 |
+
explainer = shap.Explainer(predict_prob, tokenizer)
|
| 44 |
+
|
| 45 |
|
| 46 |
# NER pipeline
|
| 47 |
+
#ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
|
| 48 |
+
#ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
|
| 49 |
+
#ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
|
| 50 |
+
|
| 51 |
+
import spacy
|
| 52 |
+
import scispacy
|
| 53 |
+
nlp = spacy.load("en_core_sci_sm") # Use small SciSpacy model
|
| 54 |
+
|
| 55 |
+
def scispacy_ner(text_input):
    """Return ``text_input`` as HTML with each detected entity highlighted.

    The text is run through the module-level ``nlp`` pipeline and every
    entity in ``doc.ents`` is wrapped in a ``<mark>`` element followed by its
    label in parentheses. Entities whose label contains ``"DISEASE"`` get a
    green background; all others get a dark-blue one.

    All text taken from the input (and the label) is HTML-escaped, so
    user-supplied markup is displayed literally instead of being injected
    into the page.

    Args:
        text_input: The raw text to annotate.

    Returns:
        An HTML fragment (str) containing the escaped text with ``<mark>``
        highlights.
    """
    from html import escape  # local import keeps this block self-contained

    doc = nlp(text_input)
    pieces = []
    cursor = 0  # index in text_input up to which output has been emitted
    # Relies on doc.ents being in document order and non-overlapping, which
    # the previous running-offset arithmetic also assumed.
    for ent in doc.ents:
        label = ent.label_
        color = "#a3e635" if "DISEASE" in label else "#1e3a8a"
        # Plain text between the previous entity and this one.
        pieces.append(escape(text_input[cursor:ent.start_char]))
        pieces.append(
            f"<mark style='background-color:{color}; border-radius: 4px;'>"
            f"{escape(ent.text)} ({escape(label)})</mark>"
        )
        cursor = ent.end_char
    # Trailing text after the last entity (or the whole text if none).
    pieces.append(escape(text_input[cursor:]))
    return "".join(pieces)
|
| 68 |
|
| 69 |
# SHAP Plotting Function
|
| 70 |
def generate_shap_plot(shap_values):
|
|
|
|
| 92 |
local_plot = "<p>SHAP explanation not available.</p>"
|
| 93 |
|
| 94 |
# NER Processing
|
| 95 |
+
try:
|
| 96 |
+
htext = scispacy_ner(text_input)
|
| 97 |
+
except Exception as e:
|
| 98 |
+
print(f"NER processing failed: {e}")
|
| 99 |
+
htext = "<p>NER processing not available.</p>"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
htext = "<div style='line-height: 1.5; font-family: Poppins;'>"
|
| 103 |
prev_end = 0
|
| 104 |
res = sorted(res, key=lambda x: x['start'])
|