TyHamil commited on
Commit
10a77d3
·
verified ·
1 Parent(s): e2076ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -11
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import shap
3
  import numpy as np
@@ -12,6 +15,7 @@ import csv
12
  import io
13
  import base64
14
 
 
15
  # Increase CSV field size limit
16
  csv.field_size_limit(sys.maxsize)
17
 
@@ -26,12 +30,41 @@ model = AutoModelForSequenceClassification.from_pretrained("TyHamil/ADRv2025").t
26
  pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
27
 
28
  # SHAP explainer
29
- explainer = shap.Explainer(pred)
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # NER pipeline
32
- ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
33
- ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
34
- ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # SHAP Plotting Function
37
  def generate_shap_plot(shap_values):
@@ -59,13 +92,13 @@ def adr_predict(x):
59
  local_plot = "<p>SHAP explanation not available.</p>"
60
 
61
  # NER Processing
62
- try:
63
- res = ner_pipe(text_input)
64
- entity_colors = {
65
- 'Severity': '#a3e635', 'Sign_symptom': '#1e3a8a', 'Medication': '#c0c0c0',
66
- 'Age': '#a3e635', 'Sex': '#a3e635', 'Diagnostic_procedure': '#c0c0c0',
67
- 'Biological_structure': '#c0c0c0'
68
- }
69
  htext = "<div style='line-height: 1.5; font-family: Poppins;'>"
70
  prev_end = 0
71
  res = sorted(res, key=lambda x: x['start'])
 
1
+ pip install scispacy
2
+ pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_sm-0.5.1.tar.gz
3
+
4
  import gradio as gr
5
  import shap
6
  import numpy as np
 
15
  import io
16
  import base64
17
 
18
+
19
  # Increase CSV field size limit
20
  csv.field_size_limit(sys.maxsize)
21
 
 
30
  pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
31
 
32
  # SHAP explainer
33
+ #explainer = shap.Explainer(pred)
34
+ import shap
35
+
36
+ def predict_prob(texts):
37
+ encoded = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).to(device)
38
+ with torch.no_grad():
39
+ outputs = model(**encoded)
40
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
41
+ return probs.cpu().numpy()
42
+
43
+ explainer = shap.Explainer(predict_prob, tokenizer)
44
+
45
 
46
  # NER pipeline
47
+ #ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
48
+ #ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
49
+ #ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
50
+
51
+ import spacy
52
+ import scispacy
53
+ nlp = spacy.load("en_core_sci_sm") # Use small SciSpacy model
54
+
55
+ def scispacy_ner(text_input):
56
+ doc = nlp(text_input)
57
+ highlighted = text_input
58
+ offset = 0
59
+ for ent in doc.ents:
60
+ start = ent.start_char + offset
61
+ end = ent.end_char + offset
62
+ label = ent.label_
63
+ color = "#a3e635" if "DISEASE" in label else "#1e3a8a"
64
+ replacement = f"<mark style='background-color:{color}; border-radius: 4px;'>{ent.text} ({label})</mark>"
65
+ highlighted = highlighted[:start] + replacement + highlighted[end:]
66
+ offset += len(replacement) - (end - start)
67
+ return highlighted
68
 
69
  # SHAP Plotting Function
70
  def generate_shap_plot(shap_values):
 
92
  local_plot = "<p>SHAP explanation not available.</p>"
93
 
94
  # NER Processing
95
+ try:
96
+ htext = scispacy_ner(text_input)
97
+ except Exception as e:
98
+ print(f"NER processing failed: {e}")
99
+ htext = "<p>NER processing not available.</p>"
100
+
101
+
102
  htext = "<div style='line-height: 1.5; font-family: Poppins;'>"
103
  prev_end = 0
104
  res = sorted(res, key=lambda x: x['start'])