willwim commited on
Commit
86232e6
Β·
verified Β·
1 Parent(s): 19bf2e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -88
app.py CHANGED
@@ -10,6 +10,8 @@ import matplotlib.pyplot as plt
10
  import matplotlib.patches as mpatches
11
  import io
12
  import base64
 
 
13
  import sys
14
  import csv
15
  import os
@@ -42,27 +44,92 @@ ner_pipe = pipeline(
42
  )
43
 
44
 
45
- # ── Custom SHAP bar-chart renderer ─────────────────────────────────────────────
46
- def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
47
  """
48
- Builds a horizontal bar chart (red = pushes toward ADR, teal = pushes away)
49
- that mirrors the style shown in the reference screenshot.
50
- Returns an <img> HTML tag with an embedded base64 PNG.
 
 
 
 
51
  """
52
- # shap_values is a shap.Explanation object for a single sample
53
- # .values shape: (n_tokens, n_classes) or (n_tokens,) when binary
54
- values = shap_values.values # (n_tokens, n_classes)
55
- tokens = shap_values.data # list/array of token strings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  if values.ndim == 2:
58
- sv = values[:, class_idx] # SHAP values for "Severe Reaction"
59
  else:
60
  sv = values
61
 
62
- # ── Filter out noise tokens ──────────────────────────────────────────────
63
- import string
64
- import re
65
-
66
  STOPWORDS = {
67
  "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
68
  "of", "with", "by", "from", "is", "was", "are", "were", "be", "been",
@@ -78,19 +145,14 @@ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
78
  t = tok.strip().lower()
79
  if not t:
80
  return False
81
- # Pure punctuation or whitespace
82
  if all(c in string.punctuation + " \t\n" for c in t):
83
  return False
84
- # Single characters (letters or digits alone)
85
  if len(t) <= 1:
86
  return False
87
- # Common stopwords
88
  if t in STOPWORDS:
89
  return False
90
- # Subword fragments starting with ## (BERT-style)
91
  if t.startswith("##"):
92
  return False
93
- # Tokens that are only digits
94
  if re.fullmatch(r"\d+", t):
95
  return False
96
  return True
@@ -100,30 +162,26 @@ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
100
  sv_f = sv[mask]
101
  tok_f = tok_arr[mask]
102
 
103
- # ── Deduplicate: for repeated tokens keep the one with the largest |SHAP| ─
104
  seen = {}
105
  for i, tok in enumerate(tok_f):
106
  key = tok.strip().lower()
107
  if key not in seen or abs(sv_f[i]) > abs(sv_f[seen[key]]):
108
  seen[key] = i
109
- keep = sorted(seen.values())
110
  sv_f = sv_f[keep]
111
  tok_f = tok_f[keep]
112
 
113
- # ── Top-N by absolute magnitude, then sort for chart display ─────────────
114
- TOP_N = 20
115
- order = np.argsort(np.abs(sv_f))[::-1][:TOP_N]
116
- sv_top = sv_f[order]
117
- tok_top = tok_f[order]
118
-
119
- # Re-sort so chart reads largest positive at top, largest negative at bottom
120
  plot_order = np.argsort(sv_top)
121
- sv_plot = sv_top[plot_order]
122
- tok_plot = tok_top[plot_order]
123
-
124
- COLOR_POSITIVE = "#cc1111" # bold red β€” increases severe ADR probability
125
- COLOR_NEGATIVE = "#1a6fcc" # strong blue β€” decreases severe ADR probability
126
 
 
 
127
  colors = [COLOR_POSITIVE if v > 0 else COLOR_NEGATIVE for v in sv_plot]
128
 
129
  fig_height = max(4, len(sv_plot) * 0.38)
@@ -131,11 +189,8 @@ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
131
  ax.set_facecolor("white")
132
 
133
  y_pos = np.arange(len(sv_plot))
134
- bars = ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
135
-
136
- # Zero line
137
  ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
138
-
139
  ax.set_yticks(y_pos)
140
  ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
141
  ax.set_xlabel("SHAP Value β€” impact on ADR prediction", fontsize=10, color="#444444")
@@ -144,34 +199,83 @@ def render_shap_bar_chart(shap_values, class_idx: int = 1) -> str:
144
  fontsize=12, fontweight="bold", color="#222222", pad=12
145
  )
146
 
147
- # Legend patches β€” colors match the bars exactly
148
- red_patch = mpatches.Patch(color=COLOR_POSITIVE,
149
- label="Increases severe ADR probability")
150
- blue_patch = mpatches.Patch(color=COLOR_NEGATIVE,
151
- label="Decreases severe ADR probability")
152
- ax.legend(handles=[red_patch, blue_patch], fontsize=9,
153
- loc="lower right", framealpha=0.7)
154
 
155
  ax.spines["top"].set_visible(False)
156
  ax.spines["right"].set_visible(False)
157
  ax.spines["left"].set_visible(False)
158
  ax.tick_params(axis="y", length=0)
159
  ax.tick_params(axis="x", colors="#555555")
160
- ax.xaxis.label.set_color("#555555")
161
 
162
  plt.tight_layout()
163
-
164
  buf = io.BytesIO()
165
- fig.savefig(buf, format="png", dpi=130, bbox_inches="tight",
166
- facecolor="white")
167
  plt.close(fig)
168
  buf.seek(0)
169
  b64 = base64.b64encode(buf.read()).decode("utf-8")
170
  return (
171
- f"<div style='background:white; padding:12px; border-radius:8px;'>"
172
- f"<img src='data:image/png;base64,{b64}' "
173
- f"style='width:100%; max-width:760px; display:block; margin:auto;' />"
174
- f"</div>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  )
176
 
177
 
@@ -182,18 +286,16 @@ def adr_predict(x):
182
  output = model(**encoded_input)
183
  scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
184
 
185
- # ── SHAP (bar chart) ──────────────────────────────────────────────────────
186
  try:
187
  shap_values = explainer([text_input])
188
  shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
189
  except Exception as e:
190
- shap_html = f"<p style='color:red;'>SHAP explanation error: {e}</p>"
191
 
192
- # ── NER ───────────────────────────────────────────────────────────────────
193
  try:
194
  res = ner_pipe(text_input)
195
-
196
- # Richer config: background, border accent, and human-readable display name
197
  entity_config = {
198
  "Severity": {"bg": "#ffe0de", "border": "#e07070", "label": "Severity"},
199
  "Sign_symptom": {"bg": "#d4f5d4", "border": "#5aaa5a", "label": "Symptom"},
@@ -205,30 +307,27 @@ def adr_predict(x):
205
  }
206
  default_cfg = {"bg": "#f0f0f0", "border": "#aaaaaa", "label": "Other"}
207
 
208
- # Legend β€” only show entity types actually found in this result
209
- seen_groups = list(dict.fromkeys(e["entity_group"] for e in sorted(res, key=lambda e: e["start"])))
 
210
  legend_html = (
211
  "<div style='display:flex; flex-wrap:wrap; gap:10px; "
212
- "margin-bottom:16px; padding-bottom:12px; "
213
- "border-bottom:1px solid #e0e0e0;'>"
214
  )
215
  for grp in seen_groups:
216
- cfg = entity_config.get(grp, default_cfg)
217
  bg = cfg["bg"]
218
  border = cfg["border"]
219
  lbl = cfg["label"]
220
  legend_html += (
221
- f"<span style='display:inline-flex; align-items:center; gap:6px; "
222
- f"font-size:0.8em; font-weight:600; color:#444; "
223
- f"font-family:system-ui, sans-serif;'>"
224
- f"<span style='display:inline-block; width:13px; height:13px; "
225
- f"border-radius:3px; background:{bg}; "
226
- f"border:2px solid {border};'></span>"
227
- f"{lbl}</span>"
228
  )
229
  legend_html += "</div>"
230
 
231
- # Annotated text β€” each entity as a pill with a small label badge above it
232
  text_html = (
233
  "<div style='line-height:3.4; font-size:1.05em; color:#111; "
234
  "font-family:Georgia, serif; letter-spacing:0.01em;'>"
@@ -243,36 +342,35 @@ def adr_predict(x):
243
  border = cfg["border"]
244
  lbl = cfg["label"]
245
 
246
- # Plain text before this entity
247
- text_html += f"<span style='color:#111;'>{text_input[prev_end:start]}</span>"
248
-
249
- # Entity pill: small uppercase label above, coloured word below
250
  text_html += (
251
- f"<span style='display:inline-block; position:relative; "
252
- f"vertical-align:middle; margin:0 2px; text-align:center;'>"
253
- f"<span style='display:block; font-size:0.6em; font-weight:800; "
254
- f"letter-spacing:0.08em; text-transform:uppercase; "
255
- f"color:{border}; font-family:system-ui, sans-serif; "
256
- f"line-height:1.1; margin-bottom:2px;'>{lbl}</span>"
257
- f"<span style='background:{bg}; border:1.5px solid {border}; "
258
- f"color:#111; padding:3px 8px; border-radius:6px; "
259
- f"font-weight:600; white-space:nowrap;'>{word}</span>"
260
- f"</span>"
261
  )
262
  prev_end = end
263
 
264
- text_html += f"<span style='color:#111;'>{text_input[prev_end:]}</span></div>"
265
  htext = legend_html + text_html
266
 
267
  except Exception as ex:
268
- htext = f"<p style='color:#c00;'>NER processing error: {ex}</p>"
269
 
270
  label_output = {
271
  "Severe Reaction": float(scores[1]),
272
  "Non-severe Reaction": float(scores[0]),
273
  }
274
 
275
- return label_output, shap_html, htext
 
 
276
 
277
 
278
  # ── UI ─────────────────────────────────────────────────────────────────────────
@@ -330,6 +428,9 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
330
  gr.Markdown("### Classification")
331
  label = gr.Label(label="Severity Probability")
332
 
 
 
 
333
  gr.Markdown("### Medical Entities")
334
  htext_out = gr.HTML(label="NER Mapping", elem_classes="output-box")
335
 
@@ -345,7 +446,7 @@ with gr.Blocks(title="ADR Detector", css=custom_css, theme=gr.themes.Soft()) as
345
  submit_btn.click(
346
  fn=adr_predict,
347
  inputs=[prob1],
348
- outputs=[label, shap_out, htext_out],
349
  )
350
 
351
  demo.launch()
 
10
  import matplotlib.patches as mpatches
11
  import io
12
  import base64
13
+ import string
14
+ import re
15
  import sys
16
  import csv
17
  import os
 
44
  )
45
 
46
 
47
+ # ── Severity rating renderer ──────────────────────────────────────────────────
48
+ def build_severity_html(severe_prob: float) -> str:
49
  """
50
+ Converts the Severe Reaction probability into a Low / Medium / High badge
51
+ with a colour-coded progress bar and plain-language description.
52
+
53
+ Thresholds (tunable):
54
+ < 0.35 β†’ Low
55
+ 0.35 – 0.65 β†’ Medium
56
+ >= 0.65 β†’ High
57
  """
58
+ if severe_prob < 0.35:
59
+ level = "Low"
60
+ bar_color = "#2eaa5a" # green
61
+ bg_color = "#eafaf1"
62
+ border_col = "#2eaa5a"
63
+ description = (
64
+ "The clinical text shows limited indicators of a serious adverse drug reaction. "
65
+ "Routine monitoring is advised."
66
+ )
67
+ elif severe_prob < 0.65:
68
+ level = "Medium"
69
+ bar_color = "#e8a020" # amber
70
+ bg_color = "#fffbea"
71
+ border_col = "#e8a020"
72
+ description = (
73
+ "The clinical text contains some features associated with a significant adverse "
74
+ "drug reaction. Further clinical assessment is recommended."
75
+ )
76
+ else:
77
+ level = "High"
78
+ bar_color = "#cc1111" # red
79
+ bg_color = "#fff0f0"
80
+ border_col = "#cc1111"
81
+ description = (
82
+ "The clinical text shows strong indicators of a severe adverse drug reaction. "
83
+ "Prompt clinical review is strongly recommended."
84
+ )
85
+
86
+ pct = round(severe_prob * 100, 1)
87
+ bar_width = round(severe_prob * 100, 1)
88
+
89
+ html = (
90
+ "<div style='"
91
+ "background:" + bg_color + "; "
92
+ "border:1.5px solid " + border_col + "; "
93
+ "border-radius:10px; padding:18px 20px; font-family:system-ui, sans-serif;'>"
94
+
95
+ # Rating badge + percentage on same row
96
+ "<div style='display:flex; align-items:center; justify-content:space-between; "
97
+ "margin-bottom:12px;'>"
98
+ "<span style='font-size:1.5em; font-weight:800; color:" + bar_color + ";'>"
99
+ + level +
100
+ "</span>"
101
+ "<span style='font-size:1.0em; font-weight:600; color:#555;'>"
102
+ + str(pct) + "% severe probability"
103
+ "</span>"
104
+ "</div>"
105
+
106
+ # Progress bar track
107
+ "<div style='background:#e0e0e0; border-radius:6px; height:14px; "
108
+ "overflow:hidden; margin-bottom:14px;'>"
109
+ "<div style='background:" + bar_color + "; width:" + str(bar_width) + "%; "
110
+ "height:100%; border-radius:6px; "
111
+ "transition:width 0.4s ease;'></div>"
112
+ "</div>"
113
+
114
+ # Description
115
+ "<p style='margin:0; font-size:0.9em; color:#444; line-height:1.6;'>"
116
+ + description +
117
+ "</p>"
118
+ "</div>"
119
+ )
120
+ return html
121
+
122
+
123
+ # ── Custom SHAP bar-chart renderer ─────────────────────────────────────────────
124
+ def render_shap_bar_chart(shap_values, class_idx=1):
125
+ values = shap_values.values
126
+ tokens = shap_values.data
127
 
128
  if values.ndim == 2:
129
+ sv = values[:, class_idx]
130
  else:
131
  sv = values
132
 
 
 
 
 
133
  STOPWORDS = {
134
  "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
135
  "of", "with", "by", "from", "is", "was", "are", "were", "be", "been",
 
145
  t = tok.strip().lower()
146
  if not t:
147
  return False
 
148
  if all(c in string.punctuation + " \t\n" for c in t):
149
  return False
 
150
  if len(t) <= 1:
151
  return False
 
152
  if t in STOPWORDS:
153
  return False
 
154
  if t.startswith("##"):
155
  return False
 
156
  if re.fullmatch(r"\d+", t):
157
  return False
158
  return True
 
162
  sv_f = sv[mask]
163
  tok_f = tok_arr[mask]
164
 
165
+ # Deduplicate: keep highest |SHAP| per unique token
166
  seen = {}
167
  for i, tok in enumerate(tok_f):
168
  key = tok.strip().lower()
169
  if key not in seen or abs(sv_f[i]) > abs(sv_f[seen[key]]):
170
  seen[key] = i
171
+ keep = sorted(seen.values())
172
  sv_f = sv_f[keep]
173
  tok_f = tok_f[keep]
174
 
175
+ TOP_N = 20
176
+ order = np.argsort(np.abs(sv_f))[::-1][:TOP_N]
177
+ sv_top = sv_f[order]
178
+ tok_top = tok_f[order]
 
 
 
179
  plot_order = np.argsort(sv_top)
180
+ sv_plot = sv_top[plot_order]
181
+ tok_plot = tok_top[plot_order]
 
 
 
182
 
183
+ COLOR_POSITIVE = "#cc1111"
184
+ COLOR_NEGATIVE = "#1a6fcc"
185
  colors = [COLOR_POSITIVE if v > 0 else COLOR_NEGATIVE for v in sv_plot]
186
 
187
  fig_height = max(4, len(sv_plot) * 0.38)
 
189
  ax.set_facecolor("white")
190
 
191
  y_pos = np.arange(len(sv_plot))
192
+ ax.barh(y_pos, sv_plot, color=colors, height=0.6, edgecolor="none")
 
 
193
  ax.axvline(0, color="#333333", linewidth=0.9, zorder=3)
 
194
  ax.set_yticks(y_pos)
195
  ax.set_yticklabels(tok_plot, fontsize=10, color="#222222")
196
  ax.set_xlabel("SHAP Value β€” impact on ADR prediction", fontsize=10, color="#444444")
 
199
  fontsize=12, fontweight="bold", color="#222222", pad=12
200
  )
201
 
202
+ red_patch = mpatches.Patch(color=COLOR_POSITIVE, label="Increases severe ADR probability")
203
+ blue_patch = mpatches.Patch(color=COLOR_NEGATIVE, label="Decreases severe ADR probability")
204
+ ax.legend(handles=[red_patch, blue_patch], fontsize=9, loc="lower right", framealpha=0.7)
 
 
 
 
205
 
206
  ax.spines["top"].set_visible(False)
207
  ax.spines["right"].set_visible(False)
208
  ax.spines["left"].set_visible(False)
209
  ax.tick_params(axis="y", length=0)
210
  ax.tick_params(axis="x", colors="#555555")
 
211
 
212
  plt.tight_layout()
 
213
  buf = io.BytesIO()
214
+ fig.savefig(buf, format="png", dpi=130, bbox_inches="tight", facecolor="white")
 
215
  plt.close(fig)
216
  buf.seek(0)
217
  b64 = base64.b64encode(buf.read()).decode("utf-8")
218
  return (
219
+ "<div style='background:white; padding:12px; border-radius:8px;'>"
220
+ "<img src='data:image/png;base64," + b64 + "' "
221
+ "style='width:100%; max-width:760px; display:block; margin:auto;' /></div>"
222
+ )
223
+
224
+
225
+ # ── Severity rating widget ─────────────────────────────────────────────────────
226
+ def build_severity_html(severe_prob):
227
+ if severe_prob >= 0.70:
228
+ rating = "HIGH"
229
+ rating_color = "#cc1111"
230
+ rating_bg = "#fff0f0"
231
+ rating_border = "#e07070"
232
+ rating_desc = "Strong indicators of a serious adverse drug reaction are present."
233
+ rating_icon = "\u26a0\ufe0f"
234
+ elif severe_prob >= 0.40:
235
+ rating = "MEDIUM"
236
+ rating_color = "#b07000"
237
+ rating_bg = "#fffbe6"
238
+ rating_border = "#d4a800"
239
+ rating_desc = "Some indicators of an adverse drug reaction detected. Clinical review recommended."
240
+ rating_icon = "\U0001f536"
241
+ else:
242
+ rating = "LOW"
243
+ rating_color = "#2a7a2a"
244
+ rating_bg = "#f0fff0"
245
+ rating_border = "#5aaa5a"
246
+ rating_desc = "Few or no indicators of a severe adverse drug reaction detected."
247
+ rating_icon = "\u2705"
248
+
249
+ bar_pct = int(severe_prob * 100)
250
+
251
+ return (
252
+ "<div style='background:" + rating_bg + "; border:2px solid " + rating_border + "; "
253
+ "border-radius:10px; padding:16px 20px; font-family:system-ui, sans-serif;'>"
254
+
255
+ "<div style='display:flex; align-items:center; gap:12px; margin-bottom:10px;'>"
256
+ "<span style='font-size:1.8em;'>" + rating_icon + "</span>"
257
+
258
+ "<div>"
259
+ "<div style='font-size:0.75em; font-weight:700; letter-spacing:0.1em; "
260
+ "text-transform:uppercase; color:#666;'>ADR Severity Rating</div>"
261
+ "<div style='font-size:1.6em; font-weight:900; color:" + rating_color + "; "
262
+ "letter-spacing:0.04em;'>" + rating + "</div>"
263
+ "</div>"
264
+
265
+ "<div style='margin-left:auto; text-align:right;'>"
266
+ "<div style='font-size:0.72em; color:#888; font-weight:600;'>Severe reaction probability</div>"
267
+ "<div style='font-size:1.4em; font-weight:800; color:" + rating_color + ";'>"
268
+ + str(bar_pct) + "%</div>"
269
+ "</div>"
270
+ "</div>"
271
+
272
+ "<div style='background:#e0e0e0; border-radius:999px; height:10px; margin-bottom:10px;'>"
273
+ "<div style='background:" + rating_color + "; width:" + str(bar_pct) + "%; "
274
+ "height:10px; border-radius:999px;'></div>"
275
+ "</div>"
276
+
277
+ "<div style='font-size:0.88em; color:#555; margin-top:4px;'>" + rating_desc + "</div>"
278
+ "</div>"
279
  )
280
 
281
 
 
286
  output = model(**encoded_input)
287
  scores = torch.softmax(output.logits, dim=-1)[0].detach().cpu().numpy()
288
 
289
+ # SHAP
290
  try:
291
  shap_values = explainer([text_input])
292
  shap_html = render_shap_bar_chart(shap_values[0], class_idx=1)
293
  except Exception as e:
294
+ shap_html = "<p style='color:red;'>SHAP explanation error: " + str(e) + "</p>"
295
 
296
+ # NER
297
  try:
298
  res = ner_pipe(text_input)
 
 
299
  entity_config = {
300
  "Severity": {"bg": "#ffe0de", "border": "#e07070", "label": "Severity"},
301
  "Sign_symptom": {"bg": "#d4f5d4", "border": "#5aaa5a", "label": "Symptom"},
 
307
  }
308
  default_cfg = {"bg": "#f0f0f0", "border": "#aaaaaa", "label": "Other"}
309
 
310
+ seen_groups = list(dict.fromkeys(
311
+ e["entity_group"] for e in sorted(res, key=lambda e: e["start"])
312
+ ))
313
  legend_html = (
314
  "<div style='display:flex; flex-wrap:wrap; gap:10px; "
315
+ "margin-bottom:16px; padding-bottom:12px; border-bottom:1px solid #e0e0e0;'>"
 
316
  )
317
  for grp in seen_groups:
318
+ cfg = entity_config.get(grp, default_cfg)
319
  bg = cfg["bg"]
320
  border = cfg["border"]
321
  lbl = cfg["label"]
322
  legend_html += (
323
+ "<span style='display:inline-flex; align-items:center; gap:6px; "
324
+ "font-size:0.8em; font-weight:600; color:#444; font-family:system-ui, sans-serif;'>"
325
+ "<span style='display:inline-block; width:13px; height:13px; border-radius:3px; "
326
+ "background:" + bg + "; border:2px solid " + border + ";'></span>"
327
+ + lbl + "</span>"
 
 
328
  )
329
  legend_html += "</div>"
330
 
 
331
  text_html = (
332
  "<div style='line-height:3.4; font-size:1.05em; color:#111; "
333
  "font-family:Georgia, serif; letter-spacing:0.01em;'>"
 
342
  border = cfg["border"]
343
  lbl = cfg["label"]
344
 
345
+ text_html += "<span style='color:#111;'>" + text_input[prev_end:start] + "</span>"
 
 
 
346
  text_html += (
347
+ "<span style='display:inline-block; position:relative; "
348
+ "vertical-align:middle; margin:0 2px; text-align:center;'>"
349
+ "<span style='display:block; font-size:0.6em; font-weight:800; "
350
+ "letter-spacing:0.08em; text-transform:uppercase; color:" + border + "; "
351
+ "font-family:system-ui, sans-serif; line-height:1.1; margin-bottom:2px;'>"
352
+ + lbl + "</span>"
353
+ "<span style='background:" + bg + "; border:1.5px solid " + border + "; "
354
+ "color:#111; padding:3px 8px; border-radius:6px; "
355
+ "font-weight:600; white-space:nowrap;'>" + word + "</span>"
356
+ "</span>"
357
  )
358
  prev_end = end
359
 
360
+ text_html += "<span style='color:#111;'>" + text_input[prev_end:] + "</span></div>"
361
  htext = legend_html + text_html
362
 
363
  except Exception as ex:
364
+ htext = "<p style='color:#c00;'>NER processing error: " + str(ex) + "</p>"
365
 
366
  label_output = {
367
  "Severe Reaction": float(scores[1]),
368
  "Non-severe Reaction": float(scores[0]),
369
  }
370
 
371
+ severity_html = build_severity_html(float(scores[1]))
372
+
373
+ return label_output, severity_html, shap_html, htext
374
 
375
 
376
  # ── UI ─────────────────────────────────────────────────────────────────────────
 
428
  gr.Markdown("### Classification")
429
  label = gr.Label(label="Severity Probability")
430
 
431
+ gr.Markdown("### Severity Rating")
432
+ severity_out = gr.HTML(label="Severity Rating", elem_classes="output-box")
433
+
434
  gr.Markdown("### Medical Entities")
435
  htext_out = gr.HTML(label="NER Mapping", elem_classes="output-box")
436
 
 
446
  submit_btn.click(
447
  fn=adr_predict,
448
  inputs=[prob1],
449
+ outputs=[label, severity_out, shap_out, htext_out],
450
  )
451
 
452
  demo.launch()