Spaces:
Sleeping
Sleeping
Upload 2 files
Browse filesFix tokenizer and lifestyle card cleaning
- app.py +17 -23
- function_calling.py +37 -3
app.py
CHANGED
|
@@ -102,9 +102,11 @@ except Exception as _e:
|
|
| 102 |
|
| 103 |
print("Loading model...")
|
| 104 |
try:
|
| 105 |
-
from transformers import AutoModelForCausalLM,
|
| 106 |
|
| 107 |
-
|
|
|
|
|
|
|
| 108 |
|
| 109 |
# On T4 (16GB): load in 4-bit via bitsandbytes to fit the 8B model
|
| 110 |
# On H200/ZeroGPU (80GB): load in bfloat16 β no quantization needed
|
|
@@ -128,9 +130,10 @@ try:
|
|
| 128 |
|
| 129 |
except Exception as e:
|
| 130 |
print(f"β Model load failed entirely: {e}")
|
| 131 |
-
model
|
| 132 |
-
processor
|
| 133 |
-
|
|
|
|
| 134 |
|
| 135 |
# ββ Text cleanup for FDA raw strings βββββββββββββββββββββββββββββββββββββββββ
|
| 136 |
|
|
@@ -161,35 +164,26 @@ def _run_inference(user_prompt: str) -> str:
|
|
| 161 |
{"role": "user", "content": user_prompt},
|
| 162 |
]
|
| 163 |
|
| 164 |
-
# apply_chat_template
|
| 165 |
-
|
| 166 |
messages,
|
| 167 |
-
tokenize
|
| 168 |
add_generation_prompt = True,
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
inputs = processor(
|
| 172 |
-
text = prompt_text,
|
| 173 |
-
return_tensors = "pt",
|
| 174 |
-
truncation = True,
|
| 175 |
-
max_length = MAX_SEQ_LENGTH,
|
| 176 |
).to(DEVICE)
|
| 177 |
|
| 178 |
with torch.no_grad():
|
| 179 |
output_ids = model.generate(
|
| 180 |
-
|
| 181 |
max_new_tokens = MAX_NEW_TOKENS,
|
| 182 |
temperature = 0.7,
|
| 183 |
top_p = 0.9,
|
| 184 |
do_sample = True,
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
if hasattr(processor, "tokenizer")
|
| 188 |
-
else processor.eos_token_id
|
| 189 |
-
),
|
| 190 |
)
|
| 191 |
|
| 192 |
-
new_tokens = output_ids[0][
|
| 193 |
return processor.decode(new_tokens, skip_special_tokens=True).strip()
|
| 194 |
|
| 195 |
|
|
@@ -279,7 +273,7 @@ def on_photo_submit(image, age: int):
|
|
| 279 |
yield "", "β οΈ The model is not loaded. Cannot process the image."
|
| 280 |
return
|
| 281 |
|
| 282 |
-
drug_name = extract_drug_name_from_image(image, model,
|
| 283 |
|
| 284 |
if not drug_name:
|
| 285 |
yield (
|
|
|
|
| 102 |
|
| 103 |
print("Loading model...")
|
| 104 |
try:
|
| 105 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 106 |
|
| 107 |
+
from transformers import AutoProcessor
|
| 108 |
+
processor = AutoTokenizer.from_pretrained(MODEL_PATH) # text inference
|
| 109 |
+
vision_processor = AutoProcessor.from_pretrained(MODEL_PATH) # image (photo tab)
|
| 110 |
|
| 111 |
# On T4 (16GB): load in 4-bit via bitsandbytes to fit the 8B model
|
| 112 |
# On H200/ZeroGPU (80GB): load in bfloat16 β no quantization needed
|
|
|
|
| 130 |
|
| 131 |
except Exception as e:
|
| 132 |
print(f"β Model load failed entirely: {e}")
|
| 133 |
+
model = None
|
| 134 |
+
processor = None
|
| 135 |
+
vision_processor = None
|
| 136 |
+
_MODEL_LOADED = False
|
| 137 |
|
| 138 |
# ββ Text cleanup for FDA raw strings βββββββββββββββββββββββββββββββββββββββββ
|
| 139 |
|
|
|
|
| 164 |
{"role": "user", "content": user_prompt},
|
| 165 |
]
|
| 166 |
|
| 167 |
+
# apply_chat_template formats the prompt in Gemma 4's native ChatML format
|
| 168 |
+
input_ids = processor.apply_chat_template(
|
| 169 |
messages,
|
| 170 |
+
tokenize = True,
|
| 171 |
add_generation_prompt = True,
|
| 172 |
+
return_tensors = "pt",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
).to(DEVICE)
|
| 174 |
|
| 175 |
with torch.no_grad():
|
| 176 |
output_ids = model.generate(
|
| 177 |
+
input_ids,
|
| 178 |
max_new_tokens = MAX_NEW_TOKENS,
|
| 179 |
temperature = 0.7,
|
| 180 |
top_p = 0.9,
|
| 181 |
do_sample = True,
|
| 182 |
+
eos_token_id = processor.eos_token_id, # stop at end-of-turn token
|
| 183 |
+
pad_token_id = processor.eos_token_id,
|
|
|
|
|
|
|
|
|
|
| 184 |
)
|
| 185 |
|
| 186 |
+
new_tokens = output_ids[0][input_ids.shape[1]:]
|
| 187 |
return processor.decode(new_tokens, skip_special_tokens=True).strip()
|
| 188 |
|
| 189 |
|
|
|
|
| 273 |
yield "", "β οΈ The model is not loaded. Cannot process the image."
|
| 274 |
return
|
| 275 |
|
| 276 |
+
drug_name = extract_drug_name_from_image(image, model, vision_processor, device=DEVICE)
|
| 277 |
|
| 278 |
if not drug_name:
|
| 279 |
yield (
|
function_calling.py
CHANGED
|
@@ -111,7 +111,7 @@ def _extract_field(record: Optional[dict], *field_names: str) -> str:
|
|
| 111 |
if val:
|
| 112 |
raw = val[0] if isinstance(val, list) else val
|
| 113 |
# Collapse excessive whitespace from FDA's raw text
|
| 114 |
-
return re.sub(r"\s+", " ", str(raw)).strip()[:
|
| 115 |
return ""
|
| 116 |
|
| 117 |
|
|
@@ -202,6 +202,41 @@ def get_geriatric_warnings(drug_name: str) -> str:
|
|
| 202 |
return ". ".join(relevant[:3]) + "." if relevant else ""
|
| 203 |
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
def get_lifestyle_warnings(drug_list: list[str]) -> dict:
|
| 206 |
"""
|
| 207 |
Extract food, alcohol, and lifestyle interaction warnings for a list of drugs.
|
|
@@ -253,8 +288,7 @@ def get_lifestyle_warnings(drug_list: list[str]) -> dict:
|
|
| 253 |
for sentence in sentences:
|
| 254 |
sentence_lower = sentence.lower()
|
| 255 |
if any(kw in sentence_lower for kw in keywords):
|
| 256 |
-
|
| 257 |
-
clean = sentence[:200].strip()
|
| 258 |
if clean and clean not in hits:
|
| 259 |
hits.append(clean)
|
| 260 |
if hits:
|
|
|
|
| 111 |
if val:
|
| 112 |
raw = val[0] if isinstance(val, list) else val
|
| 113 |
# Collapse excessive whitespace from FDA's raw text
|
| 114 |
+
return re.sub(r"\s+", " ", str(raw)).strip()[:2000]
|
| 115 |
return ""
|
| 116 |
|
| 117 |
|
|
|
|
| 202 |
return ". ".join(relevant[:3]) + "." if relevant else ""
|
| 203 |
|
| 204 |
|
| 205 |
+
_SECTION_HEADER_RE = re.compile(r"^\s*\d+(?:\.\d+)?\s+[A-Z][A-Z\s]+") # "7 DRUG INTERACTIONS"
|
| 206 |
+
_PAREN_REF_RE = re.compile(r"\(\s*\d+(?:\.\d+)?\s*\)") # "( 5.1 )"
|
| 207 |
+
_LEADING_NUM_RE = re.compile(r"^\s*\d+(?:\.\d+)?\s+") # "2 DOSAGE..."
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def _clean_lifestyle_sentence(sentence: str) -> str:
|
| 211 |
+
"""
|
| 212 |
+
Strip FDA formatting artifacts from a single sentence before display.
|
| 213 |
+
Returns "" if the sentence is just a section header with no useful content.
|
| 214 |
+
"""
|
| 215 |
+
# Reject pure section headers like "7 DRUG INTERACTIONS"
|
| 216 |
+
if _SECTION_HEADER_RE.match(sentence) and len(sentence.split()) <= 5:
|
| 217 |
+
return ""
|
| 218 |
+
|
| 219 |
+
# Strip inline section references like "( 5.1 )" or "( 2 )"
|
| 220 |
+
sentence = _PAREN_REF_RE.sub("", sentence)
|
| 221 |
+
|
| 222 |
+
# Strip leading section numbers like "2 DOSAGE AND ADMINISTRATION"
|
| 223 |
+
sentence = _LEADING_NUM_RE.sub("", sentence)
|
| 224 |
+
|
| 225 |
+
# Collapse whitespace and truncate
|
| 226 |
+
sentence = re.sub(r"\s+", " ", sentence).strip()
|
| 227 |
+
|
| 228 |
+
# Reject if too short after cleaning or still looks like a header (all caps)
|
| 229 |
+
if len(sentence) < 20 or sentence.isupper():
|
| 230 |
+
return ""
|
| 231 |
+
|
| 232 |
+
# Drop sentences that are too long to be a single clean thought
|
| 233 |
+
# (likely mid-paragraph FDA text split at a bad boundary)
|
| 234 |
+
if len(sentence) > 180:
|
| 235 |
+
return ""
|
| 236 |
+
|
| 237 |
+
return sentence
|
| 238 |
+
|
| 239 |
+
|
| 240 |
def get_lifestyle_warnings(drug_list: list[str]) -> dict:
|
| 241 |
"""
|
| 242 |
Extract food, alcohol, and lifestyle interaction warnings for a list of drugs.
|
|
|
|
| 288 |
for sentence in sentences:
|
| 289 |
sentence_lower = sentence.lower()
|
| 290 |
if any(kw in sentence_lower for kw in keywords):
|
| 291 |
+
clean = _clean_lifestyle_sentence(sentence)
|
|
|
|
| 292 |
if clean and clean not in hits:
|
| 293 |
hits.append(clean)
|
| 294 |
if hits:
|