willwim commited on
Commit
f780603
·
verified ·
1 Parent(s): 4b5eb9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -37
app.py CHANGED
@@ -14,42 +14,34 @@ HF_TOKEN = os.getenv("hf_token")
14
  csv.field_size_limit(sys.maxsize)
15
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
 
17
- # ==========================================
18
- # 1. Load Models and Tokenizers
19
- # ==========================================
20
- # Classification Model
21
  tokenizer = AutoTokenizer.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN)
22
  model = AutoModelForSequenceClassification.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN).to(device)
23
 
24
- # Build a pipeline object for predictions (used for both label and SHAP)
25
  pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
26
 
27
  # SHAP explainer
28
  explainer = shap.Explainer(pred)
29
 
30
- # NER pipeline (Added device mapping for faster inference)
31
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
32
- ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all").to(device)
33
- ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple", device=device)
34
 
35
- # ==========================================
36
- # 2. Prediction Function
37
- # ==========================================
38
  def adr_predict(x):
39
  text_input = str(x).lower()
 
 
40
 
41
- # 1. Classification via the pipeline (reusing the pred object)
42
- raw_results = pred(text_input)[0]
43
- label_output = {item['label']: float(item['score']) for item in raw_results}
44
 
45
- # 2. SHAP Logic
46
  try:
47
  shap_values = explainer([text_input])
48
  local_plot = shap.plots.text(shap_values[0], display=False)
49
  except Exception as e:
50
  local_plot = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
51
 
52
- # 3. NER Logic
53
  try:
54
  res = ner_pipe(text_input)
55
  entity_colors = {
@@ -62,32 +54,27 @@ def adr_predict(x):
62
  'Biological_structure':'#d9d9d9'
63
  }
64
 
65
- # Inline styles to force black text and white background for readability
66
- htext = "<div style='line-height: 2.0; font-size: 1.1em; color: black; background-color: white; padding: 10px;'>"
67
  prev_end = 0
68
  res = sorted(res, key=lambda x: x['start'])
69
-
70
  for entity in res:
71
  start, end = entity['start'], entity['end']
72
  word = text_input[start:end]
73
  color = entity_colors.get(entity['entity_group'], '#f3f3f3')
74
 
75
- htext += f"<span>{text_input[prev_end:start]}</span>"
76
- htext += (f"<mark style='background-color:{color}; color: black; padding: 2px 4px; "
77
- f"border-radius: 4px; font-weight: 500;'>{word} "
78
- f"<small style='opacity: 0.7; font-size: 0.7em;'>[{entity['entity_group']}]</small></mark>")
79
  prev_end = end
80
-
81
- htext += f"<span>{text_input[prev_end:]}</span></div>"
82
- except Exception as e:
83
- htext = f"<p style='color: black;'>NER processing error: {e}</p>"
84
 
 
85
  return label_output, local_plot, htext
86
 
87
- # ==========================================
88
- # 3. Gradio Interface
89
- # ==========================================
90
- # CSS ensures Gradio's dark mode doesn't override the white background and black text
91
  custom_css = """
92
  .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
93
  .main-header { text-align: center; margin-bottom: 2rem; }
@@ -95,7 +82,7 @@ custom_css = """
95
  footer { visibility: hidden; }
96
  """
97
 
98
- with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as demo:
99
  with gr.Column(elem_classes="main-header"):
100
  gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
101
  gr.Markdown("Analyze clinical text for potential medication-related severity and key medical entities.")
@@ -115,9 +102,7 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
115
  gr.Examples(
116
  examples=[
117
  ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
118
- ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."],
119
- ["A 62-year-old female presented with shortness of breath and anaphylaxis minutes after intravenous Penicillin administration."],
120
- ["Patient felt slight drowsiness and dry mouth after taking 10mg of Cetirizine, but no other symptoms were noted."]
121
  ],
122
  inputs=[prob1]
123
  )
@@ -141,5 +126,4 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
141
  outputs=[label, local_plot, htext]
142
  )
143
 
144
- if __name__ == "__main__":
145
- demo.launch()
 
14
  csv.field_size_limit(sys.maxsize)
15
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
 
17
+ # Load models and tokenizer
 
 
 
18
  tokenizer = AutoTokenizer.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN)
19
  model = AutoModelForSequenceClassification.from_pretrained("willwim/adr_SJM_Notebook-Copy_for_T3", token=HF_TOKEN).to(device)
20
 
21
+ # Build a pipeline object for predictions
22
  pred = transformers.pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None, device=device)
23
 
24
  # SHAP explainer
25
  explainer = shap.Explainer(pred)
26
 
27
+ # NER pipeline
28
  ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
29
+ ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
30
+ ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
31
 
 
 
 
32
  def adr_predict(x):
33
  text_input = str(x).lower()
34
+ encoded_input = tokenizer(text_input, return_tensors='pt').to(device)
35
+ output = model(**encoded_input)
36
 
37
+ scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
 
 
38
 
 
39
  try:
40
  shap_values = explainer([text_input])
41
  local_plot = shap.plots.text(shap_values[0], display=False)
42
  except Exception as e:
43
  local_plot = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
44
 
 
45
  try:
46
  res = ner_pipe(text_input)
47
  entity_colors = {
 
54
  'Biological_structure':'#d9d9d9'
55
  }
56
 
57
+ # FIX: Added inline "color: black;" to force all un-highlighted text to be black
58
+ htext = "<div style='line-height: 2.0; font-size: 1.1em; color: black;'>"
59
  prev_end = 0
60
  res = sorted(res, key=lambda x: x['start'])
 
61
  for entity in res:
62
  start, end = entity['start'], entity['end']
63
  word = text_input[start:end]
64
  color = entity_colors.get(entity['entity_group'], '#f3f3f3')
65
 
66
+ htext += f"{text_input[prev_end:start]}"
67
+ # Highlighted text is also explicitly set to black
68
+ htext += f"<mark style='background-color:{color}; color: black; padding: 2px 4px; border-radius: 4px; font-weight: 500;'>{word} <small style='opacity: 0.7;'>[{entity['entity_group']}]</small></mark>"
 
69
  prev_end = end
70
+ htext += text_input[prev_end:] + "</div>"
71
+ except:
72
+ htext = "<p style='color: black;'>NER processing error.</p>"
 
73
 
74
+ label_output = {"Severe Reaction": float(scores[1]), "Non-severe Reaction": float(scores[0])}
75
  return label_output, local_plot, htext
76
 
77
+ # FIX: Added !important tags to ensure Gradio's dark mode doesn't override the white background and black text
 
 
 
78
  custom_css = """
79
  .gradio-container { font-family: 'Inter', system-ui, sans-serif; }
80
  .main-header { text-align: center; margin-bottom: 2rem; }
 
82
  footer { visibility: hidden; }
83
  """
84
 
85
+ with gr.Blocks(title="ADR Detector") as demo:
86
  with gr.Column(elem_classes="main-header"):
87
  gr.Markdown("# Adverse Drug Reaction (ADR) Detector")
88
  gr.Markdown("Analyze clinical text for potential medication-related severity and key medical entities.")
 
102
  gr.Examples(
103
  examples=[
104
  ["A 35 year-old male had severe headache after taking Aspirin. The lab results were normal."],
105
+ ["A 35 year-old female had minor pain in upper abdomen after taking Acetaminophen."]
 
 
106
  ],
107
  inputs=[prob1]
108
  )
 
126
  outputs=[label, local_plot, htext]
127
  )
128
 
129
+ demo.launch(css=custom_css, theme=gr.themes.Soft())