Janushi committed on
Commit
ef8cea9
·
1 Parent(s): 9ee2757

Update app with casual examples, CPU mode, state persistence

Browse files
Files changed (1) hide show
  1. app.py +151 -42
app.py CHANGED
@@ -9,13 +9,24 @@ MODEL_ID = "Janushi/ClinicalDistill-Gemma-1B"
9
  tokenizer = None
10
  model = None
11
 
12
- INSTRUCTION = """Extract symptoms from the clinical note and return JSON with this exact format:
 
 
13
  {
14
  "symptoms": ["symptom1", "symptom2"],
15
  "duration": ["duration1", "duration2"],
16
  "severity": ["severity1", "severity2"],
17
  "urgent": true/false
18
- }"""
 
 
 
 
 
 
 
 
 
19
 
20
  EXAMPLES = [
21
  ["been feeling off for a few days, chest feels weird and i get tired just walking to the kitchen"],
@@ -59,6 +70,28 @@ CSS = """
59
  footer { display: none !important; }
60
  """
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def load_model():
63
  global tokenizer, model
64
  if model is not None:
@@ -67,10 +100,11 @@ def load_model():
67
  model = AutoModelForCausalLM.from_pretrained(
68
  MODEL_ID,
69
  torch_dtype=torch.float32, # float32 for CPU
70
- device_map="cpu"
71
  )
72
  model.eval()
73
 
 
74
  def build_prompt(clinical_note: str) -> str:
75
  return (
76
  f"<instruction>\n{INSTRUCTION}\n</instruction>\n\n"
@@ -78,36 +112,82 @@ def build_prompt(clinical_note: str) -> str:
78
  f"<output>\n"
79
  )
80
 
 
81
  def parse_output(raw: str) -> dict:
82
  raw = raw.split("</output>")[0].strip()
83
  match = re.search(r"\{.*\}", raw, re.DOTALL)
84
  if match:
85
  raw = match.group(0)
86
  result = json.loads(raw)
 
87
  for key in ("symptoms", "duration", "severity"):
88
  if key not in result or not isinstance(result[key], list):
89
  result[key] = []
90
  if "urgent" not in result:
91
  result["urgent"] = False
92
- n = max(len(result["symptoms"]), len(result["duration"]), len(result["severity"]), 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  for key in ("symptoms", "duration", "severity"):
94
  while len(result[key]) < n:
95
  result[key].append("unspecified")
 
96
  return result
97
 
 
98
  def severity_badge(s: str) -> str:
99
  s = (s or "").lower().strip()
100
- if not s or s in ("unspecified", "unknown", "n/a", "none", "not mentioned", "—"):
101
  return '<span style="color:#d1d5db;font-style:italic">—</span>'
102
- if any(w in s for w in ("severe", "critical", "extreme", "crushing", "sudden", "acute", "high")):
103
  return f'<span style="background:#fef2f2;color:#dc2626;font-weight:600;padding:2px 8px;border-radius:4px;font-size:0.85rem">▲ {s}</span>'
104
- if any(w in s for w in ("moderate", "significant", "worsening", "progressive")):
105
  return f'<span style="background:#fffbeb;color:#d97706;font-weight:600;padding:2px 8px;border-radius:4px;font-size:0.85rem">● {s}</span>'
106
- if any(w in s for w in ("mild", "slight", "minor", "low", "minimal")):
107
  return f'<span style="background:#f0fdf4;color:#16a34a;font-weight:600;padding:2px 8px;border-radius:4px;font-size:0.85rem">▼ {s}</span>'
108
  return f'<span style="background:#f1f5f9;color:#475569;padding:2px 8px;border-radius:4px;font-size:0.85rem">{s}</span>'
109
 
110
- def format_results(result: dict):
 
 
 
 
 
 
 
 
111
  symptoms = result["symptoms"]
112
  durations = result["duration"]
113
  severities = result["severity"]
@@ -120,17 +200,11 @@ def format_results(result: dict):
120
  sev = severities[i] if i < len(severities) else "unspecified"
121
  bg = "#fff7f7" if urgent else ("#f8faff" if i % 2 == 0 else "white")
122
 
123
- dur_html = (
124
- '<span style="color:#d1d5db;font-style:italic">—</span>'
125
- if dur.lower() in ("unspecified", "unknown", "", "n/a")
126
- else f'<span style="color:#4b5563">{dur}</span>'
127
- )
128
-
129
  rows += f"""
130
  <tr style="background:{bg}">
131
  <td style="padding:10px 14px;font-weight:600;color:#111827;
132
  border-left:3px solid {accent}">{sym}</td>
133
- <td style="padding:10px 14px">{dur_html}</td>
134
  <td style="padding:10px 14px">{severity_badge(sev)}</td>
135
  </tr>"""
136
 
@@ -157,9 +231,15 @@ def format_results(result: dict):
157
  </div>
158
  """
159
 
 
160
  def extract(clinical_note: str, state: dict):
161
  if not clinical_note.strip():
162
- return "<p style='color:#9ca3af'>Enter a clinical note to see results.</p>", "{}", state
 
 
 
 
 
163
 
164
  load_model()
165
  prompt = build_prompt(clinical_note)
@@ -183,10 +263,10 @@ def extract(clinical_note: str, state: dict):
183
  table_html = format_results(result)
184
  json_out = json.dumps(result, indent=2)
185
  new_state = {"table": table_html, "json": json_out}
186
- return table_html, json_out, new_state
187
  except (json.JSONDecodeError, KeyError, IndexError):
188
- err = f"<pre style='color:#dc2626'>Parse error:\n{generated}</pre>"
189
- return err, "{}", state
190
 
191
 
192
  with gr.Blocks(css=CSS, title="ClinicalDistill") as demo:
@@ -194,27 +274,49 @@ with gr.Blocks(css=CSS, title="ClinicalDistill") as demo:
194
  result_state = gr.State(value={"table": "", "json": "{}"})
195
 
196
  gr.HTML("""
197
- <h1 id="title" style="font-size:2rem;font-weight:800;background:linear-gradient(135deg,#667eea,#764ba2);
198
- -webkit-background-clip:text;-webkit-text-fill-color:transparent;margin-top:1rem">
 
 
199
  🏥 ClinicalDistill
200
  </h1>
201
- <p id="subtitle">Structured symptom extraction from clinical notes · Gemma-3-1B fine-tuned with LoRA</p>
 
 
 
202
  <div id="stats-row">
203
- <div class="stat-card"><div class="stat-val">0.781</div><div class="stat-lbl">F1 Score</div></div>
204
- <div class="stat-card"><div class="stat-val">85.7%</div><div class="stat-lbl">Urgent Accuracy</div></div>
205
- <div class="stat-card"><div class="stat-val">100%</div><div class="stat-lbl">Valid JSON</div></div>
206
- <div class="stat-card"><div class="stat-val">1B</div><div class="stat-lbl">Parameters</div></div>
 
 
 
 
 
 
 
 
 
 
 
 
207
  </div>
208
  """)
209
 
210
- gr.HTML("""
211
- <div style="text-align:center;margin-bottom:1rem;padding:0.6rem 1rem;
212
- background:#fffbeb;border:1px solid #fde68a;border-radius:8px;
213
- color:#92400e;font-size:0.85rem;max-width:600px;margin-left:auto;margin-right:auto">
214
- Running on CPU — inference takes ~60 seconds.
215
- Results persist after completion.
216
- </div>
217
- """)
 
 
 
 
 
218
 
219
  with gr.Row():
220
  with gr.Column(scale=1):
@@ -226,13 +328,18 @@ with gr.Blocks(css=CSS, title="ClinicalDistill") as demo:
226
  submit_btn = gr.Button(
227
  "⚡ Extract Symptoms",
228
  variant="primary",
229
- elem_id="submit-btn"
 
 
 
 
 
230
  )
231
- gr.Examples(examples=EXAMPLES, inputs=note_input, label="Try an Example")
232
 
233
  with gr.Column(scale=1):
234
  table_output = gr.HTML(
235
- value="<p style='color:#9ca3af;text-align:center;margin-top:2rem'>Results will appear here.</p>",
 
236
  )
237
  with gr.Accordion("Raw JSON Output", open=False):
238
  json_output = gr.Code(language="json", label="")
@@ -240,19 +347,21 @@ with gr.Blocks(css=CSS, title="ClinicalDistill") as demo:
240
  submit_btn.click(
241
  fn=extract,
242
  inputs=[note_input, result_state],
243
- outputs=[table_output, json_output, result_state]
244
  )
245
  note_input.submit(
246
  fn=extract,
247
  inputs=[note_input, result_state],
248
- outputs=[table_output, json_output, result_state]
249
  )
250
 
251
  gr.HTML("""
252
  <div style="text-align:center;margin-top:2rem;padding-top:1rem;
253
  border-top:1px solid #e5e7eb;color:#9ca3af;font-size:0.85rem">
254
- ClinicalDistill · Fine-tuned on 145 synthetic clinical examples ·
255
- <a href="https://github.com/JanushiShastri/ClinicalDistill" style="color:#667eea">GitHub</a>
 
 
256
  </div>
257
  """)
258
 
 
9
  tokenizer = None
10
  model = None
11
 
12
+ INSTRUCTION = """You are a clinical NLP model. Extract ONLY medical symptoms from the clinical note.
13
+
14
+ Return JSON in this exact format:
15
  {
16
  "symptoms": ["symptom1", "symptom2"],
17
  "duration": ["duration1", "duration2"],
18
  "severity": ["severity1", "severity2"],
19
  "urgent": true/false
20
+ }
21
+
22
+ Rules:
23
+ - symptoms: ONLY medical symptoms (fever, back pain, headache, nausea, cough, dizziness). NOT observations, context, or descriptions like "seems okay", "a little cranky", "not sure"
24
+ - duration: how long each symptom has lasted. Use "unspecified" if not mentioned
25
+ - severity: how severe each symptom is. Use "unspecified" if not clearly stated — do NOT guess severity
26
+ - urgent=true ONLY for: chest pain, difficulty breathing, stroke symptoms (slurred speech, facial drooping, arm weakness), severe bleeding, loss of consciousness
27
+ - urgent=false for: back pain, headache, nausea, fever, diarrhea, fatigue, sneezing, runny nose, cough, irritability, dizziness, stomach ache
28
+ - Never duplicate symptoms
29
+ - All arrays must be the same length"""
30
 
31
  EXAMPLES = [
32
  ["been feeling off for a few days, chest feels weird and i get tired just walking to the kitchen"],
 
70
  footer { display: none !important; }
71
  """
72
 
# Words and phrases that mark an extracted item as an observation, hedge, or
# filler rather than a medical symptom.
NON_SYMPTOM_PHRASES = [
    "seems okay", "seems fine", "otherwise fine", "no fever", "a little",
    "otherwise", "seems", "appears", "looks", "none", "normal", "okay",
    "fine", "not sure", "cranky", "irritable", "fussy", "acting up",
    "feeling off", "feeling drained", "feeling tired", "feeling weak"
]

# Match the phrases only as whole words. A plain substring test would reject
# legitimate symptoms such as "abnormal heartbeat" (contains "normal").
_NON_SYMPTOM_RE = re.compile(
    r"(?<!\w)(?:"
    + "|".join(re.escape(phrase) for phrase in NON_SYMPTOM_PHRASES)
    + r")(?!\w)"
)


def is_valid_symptom(s: str) -> bool:
    """Return True when *s* looks like a short medical symptom.

    A candidate is rejected when it is not a string, is shorter than three
    characters, is longer than five words, or contains any entry of
    ``NON_SYMPTOM_PHRASES`` as a whole word or phrase (case-insensitive).
    """
    # Model output is parsed from JSON, so entries may be null, numbers, etc.
    if not isinstance(s, str):
        return False

    s_lower = s.lower().strip()

    # Too short to be meaningful.
    if len(s_lower) < 3:
        return False
    # Too long to be a single symptom (more than 5 words).
    if len(s_lower.split()) > 5:
        return False
    # Contains a non-symptom word or phrase.
    return _NON_SYMPTOM_RE.search(s_lower) is None
93
+
94
+
95
  def load_model():
96
  global tokenizer, model
97
  if model is not None:
 
100
  model = AutoModelForCausalLM.from_pretrained(
101
  MODEL_ID,
102
  torch_dtype=torch.float32, # float32 for CPU
103
+ device_map="cpu",
104
  )
105
  model.eval()
106
 
107
+
108
  def build_prompt(clinical_note: str) -> str:
109
  return (
110
  f"<instruction>\n{INSTRUCTION}\n</instruction>\n\n"
 
112
  f"<output>\n"
113
  )
114
 
115
+
def _first_json_object(text: str) -> dict:
    """Decode and return the first JSON object embedded in *text*.

    Raises ``json.JSONDecodeError`` when no decodable object is present.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the object, so trailing text
            # (even text containing braces) does not break parsing.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return candidate
    raise json.JSONDecodeError("No JSON object found in model output", text, 0)


def _entry_at(values: list, index: int) -> str:
    """Return ``values[index]`` as text, or "unspecified" when absent/null."""
    if index >= len(values) or values[index] is None:
        return "unspecified"
    value = values[index]
    return value if isinstance(value, str) else str(value)


def parse_output(raw: str) -> dict:
    """Parse raw model text into a normalized extraction result.

    Everything after the first ``</output>`` marker is discarded, the first
    JSON object in the remainder is decoded, and the result is normalized so
    that ``symptoms``, ``duration`` and ``severity`` are equal-length lists
    (at least one entry each) and ``urgent`` is a bool.

    Symptoms are de-duplicated case-insensitively (first occurrence wins) and
    then filtered with ``is_valid_symptom``; if the filter would remove every
    symptom, the first one is kept. Duration and severity entries stay
    aligned with the symptom they were reported for.

    Raises:
        json.JSONDecodeError: if no JSON object can be decoded from *raw*.
    """
    raw = raw.split("</output>")[0].strip()
    result = _first_json_object(raw)

    for key in ("symptoms", "duration", "severity"):
        if not isinstance(result.get(key), list):
            result[key] = []

    # A string such as "false" would otherwise be truthy downstream.
    urgent = result.get("urgent", False)
    if isinstance(urgent, str):
        urgent = urgent.strip().lower() in ("true", "yes", "1")
    result["urgent"] = bool(urgent)

    symptoms = result["symptoms"]

    # Deduplicate symptoms preserving order; non-string entries are skipped
    # because they cannot be normalized or displayed as a symptom name.
    seen = set()
    unique_indices = []
    for i, sym in enumerate(symptoms):
        if not isinstance(sym, str):
            continue
        normalized = sym.lower().strip()
        if normalized not in seen:
            seen.add(normalized)
            unique_indices.append(i)

    # Filter out non-medical symptom descriptions.
    valid_indices = [i for i in unique_indices if is_valid_symptom(symptoms[i])]
    # Keep at least one symptom even if the filter removes everything.
    if not valid_indices and unique_indices:
        valid_indices = unique_indices[:1]

    result["duration"] = [_entry_at(result["duration"], i) for i in valid_indices]
    result["severity"] = [_entry_at(result["severity"], i) for i in valid_indices]
    result["symptoms"] = [symptoms[i] for i in valid_indices]

    # With no symptoms at all, emit a single placeholder row.
    if not result["symptoms"]:
        for key in ("symptoms", "duration", "severity"):
            result[key] = ["unspecified"]

    return result
168
 
169
+
def severity_badge(s: str) -> str:
    """Render a severity label as a colored HTML badge.

    The label is lower-cased and stripped. Empty or placeholder values
    ("unspecified", "n/a", ...) render as a grey dash. Otherwise the badge
    color is chosen by keyword: red for severe-type words, amber for
    moderate-type words, green for mild-type words, and neutral slate for
    anything else. Checks run in that order, so the first match wins.

    The label comes from model output, so it is HTML-escaped before being
    placed in the markup.
    """
    import html

    # str() guards against non-string JSON values such as numbers.
    label = str(s or "").lower().strip()
    if not label or label in (
        "unspecified", "unknown", "n/a", "none", "not mentioned", "—", "not stated"
    ):
        return '<span style="color:#d1d5db;font-style:italic">—</span>'

    safe = html.escape(label)
    if any(w in label for w in ("severe", "critical", "extreme", "crushing", "sudden", "acute", "high", "sharp")):
        return f'<span style="background:#fef2f2;color:#dc2626;font-weight:600;padding:2px 8px;border-radius:4px;font-size:0.85rem">▲ {safe}</span>'
    if any(w in label for w in ("moderate", "significant", "worsening", "progressive", "persistent")):
        return f'<span style="background:#fffbeb;color:#d97706;font-weight:600;padding:2px 8px;border-radius:4px;font-size:0.85rem">● {safe}</span>'
    if any(w in label for w in ("mild", "slight", "minor", "low", "minimal", "light")):
        return f'<span style="background:#f0fdf4;color:#16a34a;font-weight:600;padding:2px 8px;border-radius:4px;font-size:0.85rem">▼ {safe}</span>'
    return f'<span style="background:#f1f5f9;color:#475569;padding:2px 8px;border-radius:4px;font-size:0.85rem">{safe}</span>'
181
 
182
+
def format_duration(d: str) -> str:
    """Render a duration value as an HTML span.

    The value is lower-cased and stripped. Empty or placeholder values
    ("unspecified", "n/a", ...) render as a grey italic dash; anything else
    is shown as grey text. The value comes from model output, so it is
    HTML-escaped before being placed in the markup.
    """
    import html

    # str() guards against non-string JSON values such as numbers.
    text = str(d or "").lower().strip()
    if not text or text in (
        "unspecified", "unknown", "n/a", "none", "not mentioned", "—", "not stated"
    ):
        return '<span style="color:#d1d5db;font-style:italic">—</span>'
    return f'<span style="color:#4b5563">{html.escape(text)}</span>'
188
+
189
+
190
+ def format_results(result: dict) -> str:
191
  symptoms = result["symptoms"]
192
  durations = result["duration"]
193
  severities = result["severity"]
 
200
  sev = severities[i] if i < len(severities) else "unspecified"
201
  bg = "#fff7f7" if urgent else ("#f8faff" if i % 2 == 0 else "white")
202
 
 
 
 
 
 
 
203
  rows += f"""
204
  <tr style="background:{bg}">
205
  <td style="padding:10px 14px;font-weight:600;color:#111827;
206
  border-left:3px solid {accent}">{sym}</td>
207
+ <td style="padding:10px 14px">{format_duration(dur)}</td>
208
  <td style="padding:10px 14px">{severity_badge(sev)}</td>
209
  </tr>"""
210
 
 
231
  </div>
232
  """
233
 
234
+
235
  def extract(clinical_note: str, state: dict):
236
  if not clinical_note.strip():
237
+ return (
238
+ "<p style='color:#9ca3af'>Enter a clinical note to see results.</p>",
239
+ "{}",
240
+ state,
241
+ gr.update(visible=True), # keep warning visible
242
+ )
243
 
244
  load_model()
245
  prompt = build_prompt(clinical_note)
 
263
  table_html = format_results(result)
264
  json_out = json.dumps(result, indent=2)
265
  new_state = {"table": table_html, "json": json_out}
266
+ return table_html, json_out, new_state, gr.update(visible=False) # hide warning
267
  except (json.JSONDecodeError, KeyError, IndexError):
268
+ err = f"<pre style='color:#dc2626'>Parse error. Raw output:\n{generated}</pre>"
269
+ return err, "{}", state, gr.update(visible=False)
270
 
271
 
272
  with gr.Blocks(css=CSS, title="ClinicalDistill") as demo:
 
274
  result_state = gr.State(value={"table": "", "json": "{}"})
275
 
276
  gr.HTML("""
277
+ <h1 id="title" style="font-size:2rem;font-weight:800;
278
+ background:linear-gradient(135deg,#667eea,#764ba2);
279
+ -webkit-background-clip:text;-webkit-text-fill-color:transparent;
280
+ margin-top:1rem">
281
  🏥 ClinicalDistill
282
  </h1>
283
+ <p id="subtitle">
284
+ Structured symptom extraction from clinical notes ·
285
+ Gemma-3-1B fine-tuned with LoRA
286
+ </p>
287
  <div id="stats-row">
288
+ <div class="stat-card">
289
+ <div class="stat-val">0.781</div>
290
+ <div class="stat-lbl">F1 Score</div>
291
+ </div>
292
+ <div class="stat-card">
293
+ <div class="stat-val">85.7%</div>
294
+ <div class="stat-lbl">Urgent Accuracy</div>
295
+ </div>
296
+ <div class="stat-card">
297
+ <div class="stat-val">100%</div>
298
+ <div class="stat-lbl">Valid JSON</div>
299
+ </div>
300
+ <div class="stat-card">
301
+ <div class="stat-val">1B</div>
302
+ <div class="stat-lbl">Parameters</div>
303
+ </div>
304
  </div>
305
  """)
306
 
307
+ # Warning banner — hidden after first inference
308
+ cpu_warning = gr.HTML(
309
+ value="""
310
+ <div style="text-align:center;margin-bottom:1rem;padding:0.6rem 1rem;
311
+ background:#fffbeb;border:1px solid #fde68a;border-radius:8px;
312
+ color:#92400e;font-size:0.85rem;max-width:600px;
313
+ margin-left:auto;margin-right:auto">
314
+ ⏳ Running on CPU — inference takes ~60 seconds.
315
+ Results persist after completion.
316
+ </div>
317
+ """,
318
+ visible=True,
319
+ )
320
 
321
  with gr.Row():
322
  with gr.Column(scale=1):
 
328
  submit_btn = gr.Button(
329
  "⚡ Extract Symptoms",
330
  variant="primary",
331
+ elem_id="submit-btn",
332
+ )
333
+ gr.Examples(
334
+ examples=EXAMPLES,
335
+ inputs=note_input,
336
+ label="Try an Example",
337
  )
 
338
 
339
  with gr.Column(scale=1):
340
  table_output = gr.HTML(
341
+ value="<p style='color:#9ca3af;text-align:center;margin-top:2rem'>"
342
+ "Results will appear here.</p>",
343
  )
344
  with gr.Accordion("Raw JSON Output", open=False):
345
  json_output = gr.Code(language="json", label="")
 
347
  submit_btn.click(
348
  fn=extract,
349
  inputs=[note_input, result_state],
350
+ outputs=[table_output, json_output, result_state, cpu_warning],
351
  )
352
  note_input.submit(
353
  fn=extract,
354
  inputs=[note_input, result_state],
355
+ outputs=[table_output, json_output, result_state, cpu_warning],
356
  )
357
 
358
  gr.HTML("""
359
  <div style="text-align:center;margin-top:2rem;padding-top:1rem;
360
  border-top:1px solid #e5e7eb;color:#9ca3af;font-size:0.85rem">
361
+ ClinicalDistill · Fine-tuned on 145 synthetic clinical examples
362
+ (cardiac, respiratory, neurological, GI) ·
363
+ <a href="https://github.com/JanushiShastri/ClinicalDistill"
364
+ style="color:#667eea">GitHub</a>
365
  </div>
366
  """)
367