scigeek commited on
Commit
84a71ec
Β·
verified Β·
1 Parent(s): 617b2c8

Upload 2 files

Browse files

Fix tokenizer and lifestyle card cleaning

Files changed (2) hide show
  1. app.py +17 -23
  2. function_calling.py +37 -3
app.py CHANGED
@@ -102,9 +102,11 @@ except Exception as _e:
102
 
103
  print("Loading model...")
104
  try:
105
- from transformers import AutoModelForCausalLM, AutoProcessor
106
 
107
- processor = AutoProcessor.from_pretrained(MODEL_PATH)
 
 
108
 
109
  # On T4 (16GB): load in 4-bit via bitsandbytes to fit the 8B model
110
  # On H200/ZeroGPU (80GB): load in bfloat16 β€” no quantization needed
@@ -128,9 +130,10 @@ try:
128
 
129
  except Exception as e:
130
  print(f"βœ— Model load failed entirely: {e}")
131
- model = None
132
- processor = None
133
- _MODEL_LOADED = False
 
134
 
135
  # ── Text cleanup for FDA raw strings ─────────────────────────────────────────
136
 
@@ -161,35 +164,26 @@ def _run_inference(user_prompt: str) -> str:
161
  {"role": "user", "content": user_prompt},
162
  ]
163
 
164
- # apply_chat_template builds the native Gemma 4 ChatML format
165
- prompt_text = processor.apply_chat_template(
166
  messages,
167
- tokenize = False,
168
  add_generation_prompt = True,
169
- ) if hasattr(processor, "apply_chat_template") else user_prompt
170
-
171
- inputs = processor(
172
- text = prompt_text,
173
- return_tensors = "pt",
174
- truncation = True,
175
- max_length = MAX_SEQ_LENGTH,
176
  ).to(DEVICE)
177
 
178
  with torch.no_grad():
179
  output_ids = model.generate(
180
- **inputs,
181
  max_new_tokens = MAX_NEW_TOKENS,
182
  temperature = 0.7,
183
  top_p = 0.9,
184
  do_sample = True,
185
- pad_token_id = (
186
- processor.tokenizer.eos_token_id
187
- if hasattr(processor, "tokenizer")
188
- else processor.eos_token_id
189
- ),
190
  )
191
 
192
- new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
193
  return processor.decode(new_tokens, skip_special_tokens=True).strip()
194
 
195
 
@@ -279,7 +273,7 @@ def on_photo_submit(image, age: int):
279
  yield "", "⚠️ The model is not loaded. Cannot process the image."
280
  return
281
 
282
- drug_name = extract_drug_name_from_image(image, model, processor, device=DEVICE)
283
 
284
  if not drug_name:
285
  yield (
 
102
 
103
  print("Loading model...")
104
  try:
105
+ from transformers import AutoModelForCausalLM, AutoTokenizer
106
 
107
+ from transformers import AutoProcessor
108
+ processor = AutoTokenizer.from_pretrained(MODEL_PATH) # text inference
109
+ vision_processor = AutoProcessor.from_pretrained(MODEL_PATH) # image (photo tab)
110
 
111
  # On T4 (16GB): load in 4-bit via bitsandbytes to fit the 8B model
112
  # On H200/ZeroGPU (80GB): load in bfloat16 β€” no quantization needed
 
130
 
131
  except Exception as e:
132
  print(f"βœ— Model load failed entirely: {e}")
133
+ model = None
134
+ processor = None
135
+ vision_processor = None
136
+ _MODEL_LOADED = False
137
 
138
  # ── Text cleanup for FDA raw strings ─────────────────────────────────────────
139
 
 
164
  {"role": "user", "content": user_prompt},
165
  ]
166
 
167
+ # apply_chat_template formats the prompt in Gemma 4's native ChatML format
168
+ input_ids = processor.apply_chat_template(
169
  messages,
170
+ tokenize = True,
171
  add_generation_prompt = True,
172
+ return_tensors = "pt",
 
 
 
 
 
 
173
  ).to(DEVICE)
174
 
175
  with torch.no_grad():
176
  output_ids = model.generate(
177
+ input_ids,
178
  max_new_tokens = MAX_NEW_TOKENS,
179
  temperature = 0.7,
180
  top_p = 0.9,
181
  do_sample = True,
182
+ eos_token_id = processor.eos_token_id, # stop at end-of-turn token
183
+ pad_token_id = processor.eos_token_id,
 
 
 
184
  )
185
 
186
+ new_tokens = output_ids[0][input_ids.shape[1]:]
187
  return processor.decode(new_tokens, skip_special_tokens=True).strip()
188
 
189
 
 
273
  yield "", "⚠️ The model is not loaded. Cannot process the image."
274
  return
275
 
276
+ drug_name = extract_drug_name_from_image(image, model, vision_processor, device=DEVICE)
277
 
278
  if not drug_name:
279
  yield (
function_calling.py CHANGED
@@ -111,7 +111,7 @@ def _extract_field(record: Optional[dict], *field_names: str) -> str:
111
  if val:
112
  raw = val[0] if isinstance(val, list) else val
113
  # Collapse excessive whitespace from FDA's raw text
114
- return re.sub(r"\s+", " ", str(raw)).strip()[:600]
115
  return ""
116
 
117
 
@@ -202,6 +202,41 @@ def get_geriatric_warnings(drug_name: str) -> str:
202
  return ". ".join(relevant[:3]) + "." if relevant else ""
203
 
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def get_lifestyle_warnings(drug_list: list[str]) -> dict:
206
  """
207
  Extract food, alcohol, and lifestyle interaction warnings for a list of drugs.
@@ -253,8 +288,7 @@ def get_lifestyle_warnings(drug_list: list[str]) -> dict:
253
  for sentence in sentences:
254
  sentence_lower = sentence.lower()
255
  if any(kw in sentence_lower for kw in keywords):
256
- # Truncate long sentences
257
- clean = sentence[:200].strip()
258
  if clean and clean not in hits:
259
  hits.append(clean)
260
  if hits:
 
111
  if val:
112
  raw = val[0] if isinstance(val, list) else val
113
  # Collapse excessive whitespace from FDA's raw text
114
+ return re.sub(r"\s+", " ", str(raw)).strip()[:2000]
115
  return ""
116
 
117
 
 
202
  return ". ".join(relevant[:3]) + "." if relevant else ""
203
 
204
 
205
+ _SECTION_HEADER_RE = re.compile(r"^\s*\d+(?:\.\d+)?\s+[A-Z][A-Z\s]+") # "7 DRUG INTERACTIONS"
206
+ _PAREN_REF_RE = re.compile(r"\(\s*\d+(?:\.\d+)?\s*\)") # "( 5.1 )"
207
+ _LEADING_NUM_RE = re.compile(r"^\s*\d+(?:\.\d+)?\s+") # "2 DOSAGE..."
208
+
209
+
210
+ def _clean_lifestyle_sentence(sentence: str) -> str:
211
+ """
212
+ Strip FDA formatting artifacts from a single sentence before display.
213
+ Returns "" if the sentence is just a section header with no useful content.
214
+ """
215
+ # Reject pure section headers like "7 DRUG INTERACTIONS"
216
+ if _SECTION_HEADER_RE.match(sentence) and len(sentence.split()) <= 5:
217
+ return ""
218
+
219
+ # Strip inline section references like "( 5.1 )" or "( 2 )"
220
+ sentence = _PAREN_REF_RE.sub("", sentence)
221
+
222
+ # Strip leading section numbers like "2 DOSAGE AND ADMINISTRATION"
223
+ sentence = _LEADING_NUM_RE.sub("", sentence)
224
+
225
+ # Collapse whitespace and truncate
226
+ sentence = re.sub(r"\s+", " ", sentence).strip()
227
+
228
+ # Reject if too short after cleaning or still looks like a header (all caps)
229
+ if len(sentence) < 20 or sentence.isupper():
230
+ return ""
231
+
232
+ # Drop sentences that are too long to be a single clean thought
233
+ # (likely mid-paragraph FDA text split at a bad boundary)
234
+ if len(sentence) > 180:
235
+ return ""
236
+
237
+ return sentence
238
+
239
+
240
  def get_lifestyle_warnings(drug_list: list[str]) -> dict:
241
  """
242
  Extract food, alcohol, and lifestyle interaction warnings for a list of drugs.
 
288
  for sentence in sentences:
289
  sentence_lower = sentence.lower()
290
  if any(kw in sentence_lower for kw in keywords):
291
+ clean = _clean_lifestyle_sentence(sentence)
 
292
  if clean and clean not in hits:
293
  hits.append(clean)
294
  if hits: