# handler.py — custom Inference Endpoints handler chaining two Pixtral-12B
# fine-tunes: food-waste category classification, then per-category weight
# estimation with range-constrained decoding.
from __future__ import annotations
import base64
import re
import unicodedata
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.utils import logging
# Module logger (transformers' logging wrapper, so verbosity follows HF settings).
logger = logging.get_logger(__name__)
logging.set_verbosity_info()
# Merged-weight hub repos: one fine-tune per task (classification / weight regression).
CATEGORIES_REPO = "HCKLab/pixtral-12b-foodwaste-classification-merged-2"
WEIGHT_REPO = "HCKLab/pixtral-12b-foodwaste-weight-merged"
def _clean_text(s: str) -> str:
s = unicodedata.normalize("NFKC", s).replace("\u00A0", " ")
s = re.sub(r"[\u200B-\u200F]", "", s)
s = re.sub(r"\s+", " ", s).strip()
return s
def _decode_image(image_b64: str) -> Image.Image:
try:
img_bytes = base64.b64decode(image_b64)
img = Image.open(BytesIO(img_bytes)).convert("RGB")
return img
except Exception as exc: # pragma: no cover - log production
raise ValueError(f"Could not decode base64 image: {exc}") from exc
def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
    """Downscale *img* so its longer side is at most *max_side*; never upscales."""
    width, height = img.size
    longest = max(width, height)
    if longest <= max_side:
        return img
    ratio = max_side / longest
    new_size = (int(width * ratio), int(height * ratio))
    return img.resize(new_size, resample=Image.Resampling.LANCZOS)
def _build_chat_text(prompt: str, processor: AutoProcessor) -> str:
    """Render one user turn (image placeholder + text) via the chat template.

    Returns the untokenized string with the generation prompt appended.
    """
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    return processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )
# ----------------------------
# Categories task
# ----------------------------
@dataclass
class CategoriesGenConfig:
    """Generation settings for the category-classification pass."""

    max_new_tokens: int = 128  # budget for the comma-separated category list
    temperature: float = 0.0  # 0.0 -> greedy decoding (do_sample=False)
    no_repeat_ngram_size: int = 6
    repetition_penalty: float = 1.1
    max_length: int = 4096  # total prompt + answer token budget
    max_side: int = 768  # image is downscaled to this max side before encoding
# Fixed classification prompt; the numbered CATEGORY DICTIONARY below is also
# parsed at startup by _extract_categories_from_prompt to build the allow-list,
# so the "N. name" line format must stay intact.
# NOTE(review): entries 32 ("dried fruits") and 157 ("cherry tomato") use spaces
# where every other entry uses underscores — presumably matches the fine-tune's
# training labels; confirm before "fixing".
CATEGORIES_DEFAULT_PROMPT = (
    "Here is a picture showing some food waste.\n\n"
    "Task: Provide a list of the food waste items visible in the picture.\n"
    "For each item, output EXACTLY one CATEGORY chosen from the CATEGORY DICTIONARY below.\n\n"
    "Rules:\n"
    "- Output must be a single line (no extra text, no markdown).\n"
    "- Use the exact spelling of CATEGORY as listed.\n"
    "- If you are unsure, choose the closest broader category (e.g., 'fruit', 'vegetable', 'food_waste').\n\n"
    "CATEGORY DICTIONARY:\n"
    "1. beer\n"
    "2. cabbage\n"
    "3. whipped_porridge\n"
    "4. cider\n"
    "5. orange\n"
    "6. dairy_product_milk\n"
    "7. yoghurt\n"
    "8. herbs\n"
    "9. potato_product\n"
    "10. strawberry\n"
    "11. minced_chicken\n"
    "12. soda\n"
    "13. sauce_paste\n"
    "14. rice_porridge\n"
    "15. blueberry\n"
    "16. light_bread\n"
    "17. semolina_porridge\n"
    "18. egg\n"
    "19. crepes\n"
    "20. sliced_cheese\n"
    "21. chicken\n"
    "22. food_waste\n"
    "23. candy\n"
    "24. vegetarian_sauce\n"
    "25. cheese\n"
    "26. cereal\n"
    "27. pork_product\n"
    "28. vegetable\n"
    "29. honey\n"
    "30. plant_based_cream\n"
    "31. vegetarian_pizza\n"
    "32. dried fruits\n"
    "33. rice_or_pasta\n"
    "34. pork_steak\n"
    "35. wine\n"
    "36. soft_drink\n"
    "37. minced_beef\n"
    "38. steak\n"
    "39. frankfurter\n"
    "40. vegetarian_soup\n"
    "41. beef_frankfurter\n"
    "42. alcoholic_long_drink\n"
    "43. raspberry\n"
    "44. rice\n"
    "45. mandarin\n"
    "46. juice\n"
    "47. porridge\n"
    "48. fish_soup\n"
    "49. fish_hamburger\n"
    "50. fruit\n"
    "51. coffee\n"
    "52. plant_based_ice_cream\n"
    "53. beef_sausage\n"
    "54. minced_pork\n"
    "55. meat\n"
    "56. sweet_pastry\n"
    "57. ice_cream\n"
    "58. fish_salad\n"
    "59. flour\n"
    "60. snack\n"
    "61. fish_product\n"
    "62. cherry\n"
    "63. shellfish\n"
    "64. cream\n"
    "65. dessert\n"
    "66. cold_cut_chicken\n"
    "67. onion\n"
    "68. dark_bread\n"
    "69. plant_based_milk\n"
    "70. fermented_milk\n"
    "71. cocoa\n"
    "72. bell_pepper\n"
    "73. beef\n"
    "74. sausage\n"
    "75. dairy_product\n"
    "76. pork\n"
    "77. fish_fillet\n"
    "78. strong_alcoholic_beverage\n"
    "79. biscuit\n"
    "80. currant\n"
    "81. sweet_soup\n"
    "82. popcorn\n"
    "83. meat_sauce\n"
    "84. meat_pizza\n"
    "85. lemon\n"
    "86. plum\n"
    "87. grain_products\n"
    "88. sugar_honey_syrup\n"
    "89. vegetarian_hamburger\n"
    "90. jam\n"
    "91. vegetarian_stew\n"
    "92. beef_steak\n"
    "93. tomato\n"
    "94. meat_soup\n"
    "95. block_cheese\n"
    "96. potato_based\n"
    "97. potato\n"
    "98. flakes\n"
    "99. soft_cheese\n"
    "100. beef_product\n"
    "101. chocolate\n"
    "102. cauliflower\n"
    "103. nectarine\n"
    "104. banana\n"
    "105. apple\n"
    "106. alcoholic_beverage\n"
    "107. meat_salad\n"
    "108. pork_sausage\n"
    "109. lingonberry\n"
    "110. chicken_sausage\n"
    "111. pork_frankfurter\n"
    "112. minced_meat\n"
    "113. baked_goods\n"
    "114. broccoli\n"
    "115. cold_cut_beef\n"
    "116. meat_dish\n"
    "117. plant_cheese\n"
    "118. butter\n"
    "119. children_milk\n"
    "120. fish_sauce\n"
    "121. buttermilk\n"
    "122. cucumber\n"
    "123. meat_hamburger\n"
    "124. pasta\n"
    "125. tea\n"
    "126. berries\n"
    "127. milk\n"
    "128. fish_pizza\n"
    "129. fresh_cheese\n"
    "130. fish_stew\n"
    "131. sirup\n"
    "132. vegetarian_salad\n"
    "133. spice\n"
    "134. quark\n"
    "135. cold_cut_pork\n"
    "136. sauce_seasoning\n"
    "137. oatmeal\n"
    "138. sugar\n"
    "139. meat_product\n"
    "140. vegetarian_dish\n"
    "141. savory_pastry\n"
    "142. chicken_product\n"
    "143. vegetable_mix\n"
    "144. potato_chip\n"
    "145. muesli\n"
    "146. carrot\n"
    "147. plant_based_yogurt\n"
    "148. sweet\n"
    "149. eggs\n"
    "150. chicken_frankfurter\n"
    "151. bread\n"
    "152. fish_strips\n"
    "153. lettuce\n"
    "154. fish\n"
    "155. meat_stew\n"
    "156. hot_beverage\n"
    "157. cherry tomato\n"
    "158. fish_dish\n"
    "159. fat_oil\n"
    "160. grape\n"
    "161. plant_protein\n"
    "162. pastry\n"
    "163. oil\n"
    "164. cold_cut_meat\n\n"
    "OUTPUT FORMAT (single line):\n"
    "CATEGORY1,CATEGORY2,CATEGORY3\n"
)
# Characters to STRIP from model output: anything not lowercase alphanumeric,
# underscore, or light punctuation survives the lower()ed text.
_ALLOWED_RE = re.compile(r"[^a-z0-9_\(\);,\|\/\-\s\?\.\:]")
# One "N. category_name" line of the CATEGORY DICTIONARY (multiline mode).
_CATEGORY_LINE_RE = re.compile(r"^\s*\d+\.\s*(.+?)\s*$", flags=re.MULTILINE)
# "(name) | weight" pairs — not referenced anywhere in this file; presumably
# legacy from an earlier output format. TODO confirm before removing.
ITEM_RE = re.compile(r"\(\s*(.*?)\s*\)\s*\|\s*(\d+)", flags=re.DOTALL)
def _extract_categories_from_prompt(prompt: str) -> Set[str]:
cats = {m.group(1).strip().lower() for m in _CATEGORY_LINE_RE.finditer(prompt)}
return {c for c in cats if c}
def _clean_model_output(s: str) -> str:
    """Normalize + lowercase model text, then strip disallowed characters."""
    lowered = _clean_text(s).lower()
    return _ALLOWED_RE.sub("", lowered)
def _parse_and_validate_categories(raw: str, allowed: Set[str]) -> List[str]:
    """Split the model answer on commas, keeping only allowed categories.

    Order of first appearance is preserved; duplicates are dropped.
    """
    cleaned = _clean_model_output(raw)
    result: List[str] = []
    for piece in cleaned.split(","):
        candidate = piece.strip()
        if candidate and candidate in allowed and candidate not in result:
            result.append(candidate)
    return result
# ----------------------------
# Weight task
# ----------------------------
# Tag appended to the end of the weight prompt; the model continues after it.
ANS_START = "<answers>"
@dataclass
class WeightGenConfig:
    """Generation settings for the weight-estimation pass."""

    max_new_tokens: int = 8  # a gram count needs only a few tokens
    temperature: float = 0.0  # greedy decoding
    no_repeat_ngram_size: int = 0  # disabled — digits legitimately repeat
    repetition_penalty: float = 1.0
    max_length: int = 3000  # total prompt + answer token budget
    max_side: int = 768  # image downscale bound
    max_w: int = 5000  # upper bound (grams) for the constrained-decoding range
# {SUB_CATEGORY} is substituted per predicted category; extra output rules and
# the ANS_START tag are appended at call time in _predict_weight.
WEIGHT_DEFAULT_PROMPT = """
Here is a photo of some food waste.
Your task is to provide a weight estimate (in grams) for the following food waste:
{SUB_CATEGORY}
While providing your weight estimate, please keep in mind the following:
*** there might be other food items on the photo, I want you to provide the weight estimate for {SUB_CATEGORY} only!
*** make sure to provide your best guess, never say 0!
OUTPUT (exactly one line, no extra text):
WEIGHT
""".strip()
_INT_RE = re.compile(r"[0-9]+")
def _extract_int_digits_only(text: str, default: int = 1) -> int:
t = _clean_text(text)
digits = "".join(ch for ch in t if "0" <= ch <= "9")
if digits:
return int(digits)
#fallback
m= _INT_RE.search(t)
if not m:
return max(1, default)
return max(1, int(m.group(0)))
def _resolve_sub_category(inputs: Dict[str, Any]) -> str:
subcat = inputs.get("sub_category")
if not subcat or not str(subcat).strip():
raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.")
return str(subcat).strip()
def _clamp_w(w: int, lo: int, hi: int) -> int:
return max(lo, min(hi, int(w)))
_TRIE_END = "_end_"
def _build_trie(token_seqs):
trie = {}
for seq in token_seqs:
node = trie
for tok in seq:
node = node.setdefault(tok, {})
node[_TRIE_END] = True
return trie
def _trie_node_for_prefix(trie, prefix):
node = trie
for tok in prefix:
node = node.get(tok)
if node is None:
return None
return node
class RangeConstraint:
    """Restricts generation to integers in [min_w, max_w] via a token trie.

    Every integer in the range is tokenized once up front; during decoding
    only continuations along the trie are allowed, plus EOS whenever a
    complete number has already been emitted.
    """

    def __init__(self, tokenizer, min_w=1, max_w=1200):
        self.tok = tokenizer
        token_seqs = [
            tokenizer.encode(str(value), add_special_tokens=False)
            for value in range(min_w, max_w + 1)
        ]
        self.trie = _build_trie(token_seqs)
        # Must be set by the caller (to the prompt's token length) before generate().
        self.prompt_len = None

    def prefix_allowed_tokens_fn(self, batch_id, input_ids):
        """Return the token ids permitted next, given what was generated so far."""
        # generate() may hand us a 2-D batch tensor or a single 1-D row.
        if getattr(input_ids, "dim", lambda: 1)() == 2:
            ids = input_ids[batch_id]
        else:
            ids = input_ids
        generated_prefix = ids[self.prompt_len :].tolist()
        node = _trie_node_for_prefix(self.trie, generated_prefix)
        if node is None:
            # Off the trie (shouldn't happen under the constraint): force EOS.
            return [self.tok.eos_token_id]
        allowed = [token for token in node if token != _TRIE_END]
        if node.get(_TRIE_END) and self.tok.eos_token_id is not None:
            allowed.append(self.tok.eos_token_id)
        return allowed
# ----------------------------
# Router EndpointHandler
# ----------------------------
class EndpointHandler:
    """Inference Endpoints entry point chaining two Pixtral/LLaVA fine-tunes.

    Pass 1 predicts food-waste categories from the image; pass 2 queries the
    weight model once per category, with decoding constrained to integers in
    [1, max_w] grams.
    """

    def _load_processor_and_model(self, repo: str):
        """Load processor + LLaVA model from *repo* onto self.device.

        Uses bfloat16 on CUDA, float32 on CPU; the model is placed entirely
        on one device and switched to eval mode.
        """
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        processor = AutoProcessor.from_pretrained(
            repo,
            trust_remote_code=True,
        )
        model = LlavaForConditionalGeneration.from_pretrained(
            repo,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map={"": self.device},  # whole model on a single device
            trust_remote_code=True,
        )
        model.eval()
        logger.info("Model and processor successfully loaded from '%s'.", repo)
        return processor, model

    @staticmethod
    def eos_and_pad_token_id(processor, model) -> tuple[List[int], int]:
        """Collect EOS token id(s) and a pad id for generation.

        Backfills the tokenizer's pad token from EOS when unset. Returns
        (eos_token_ids, pad_token_id); raises ValueError when no EOS id can
        be found on either the model config or the tokenizer.
        (Annotation fixed: the original said `-> int` but a 2-tuple is returned.)
        """
        tokenizer = getattr(processor, "tokenizer", None)
        if tokenizer is not None and tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        eos_candidates: List[int] = []
        if model.config.eos_token_id is not None:
            eos_candidates.append(model.config.eos_token_id)
        if tokenizer is not None and tokenizer.eos_token_id is not None:
            eos_candidates.append(tokenizer.eos_token_id)
        # Dedupe — model config and tokenizer usually report the same id.
        eos_token_ids: List[int] = list({i for i in eos_candidates})
        if not eos_token_ids:
            raise ValueError("No EOS token id found on model or tokenizer.")
        pad_id: Optional[int] = getattr(model.config, "pad_token_id", None)
        if pad_id is None and tokenizer is not None:
            pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = eos_token_ids[0]
        return eos_token_ids, pad_id

    def __init__(self, path: str = ".") -> None:
        """Load both models and precompute the category allow-list.

        *path* is part of the Inference Endpoints contract but unused here:
        both models are pulled from their hub repos instead.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing EndpointHandler on device: %s", self.device)
        self.cat_processor, self.cat_model = self._load_processor_and_model(CATEGORIES_REPO)
        self.cat_eos_token_ids, self.cat_pad_token_id = self.eos_and_pad_token_id(
            self.cat_processor,
            self.cat_model,
        )
        self.cat_gen_config = CategoriesGenConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.cat_gen_config.max_new_tokens,
            self.cat_gen_config.temperature,
        )
        self.weight_processor, self.weight_model = self._load_processor_and_model(WEIGHT_REPO)
        self.weight_eos_token_ids, self.weight_pad_token_id = self.eos_and_pad_token_id(
            self.weight_processor,
            self.weight_model,
        )
        self.weight_gen_config = WeightGenConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.weight_gen_config.max_new_tokens,
            self.weight_gen_config.temperature,
        )
        # Constrained decoding: the weight model may only emit ints in [1, max_w].
        self.range_constraint = RangeConstraint(self.weight_processor.tokenizer, min_w=1, max_w=self.weight_gen_config.max_w)
        self.default_allowed_categories: Set[str] = _extract_categories_from_prompt(CATEGORIES_DEFAULT_PROMPT)
        logger.info(
            "Extracted %d categories from CATEGORIES_DEFAULT_PROMPT.",
            len(self.default_allowed_categories),
        )

    def _predict_categories(self, image: Image.Image, debug: bool) -> List[str]:
        """Run the category model; return validated, deduped category names.

        Falls back to ["food_waste"] when nothing in the output validates.
        """
        max_length = int(self.cat_gen_config.max_length)
        max_side = int(self.cat_gen_config.max_side)
        image = _resize_max_side(image, max_side=max_side)
        max_new_tokens = int(self.cat_gen_config.max_new_tokens)
        temperature = float(self.cat_gen_config.temperature)
        chat_text = _build_chat_text(CATEGORIES_DEFAULT_PROMPT, self.cat_processor)
        enc = self.cat_processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=False,  # fail loudly below instead of silently cutting the prompt
            padding=False,
        )
        prompt_len = int(enc["input_ids"].shape[1])
        if prompt_len >= max_length:
            raise ValueError(
                f"Prompt too long ({prompt_len} tokens) for max_length={max_length}. "
                "Increase max_length or shorten the prompt."
            )
        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(max_new_tokens, tokens_left)
        if debug:
            logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
            logger.info("img size: %s", getattr(image, "size", None))
            logger.info("max_length: %d", max_length)
            logger.info("prompt_len: %d", prompt_len)
            logger.info("tokens_left_for_answer(approx): %d", tokens_left)
            tok = getattr(self.cat_processor, "tokenizer", None)
            if tok is not None:
                ids = enc["input_ids"][0].tolist()
                logger.info("[prompt head]\n%s", tok.decode(ids[:120], skip_special_tokens=False))
                logger.info("[prompt tail]\n%s", tok.decode(ids[-120:], skip_special_tokens=False))
            logger.info("=========================================")
        enc = {k: v.to(self.device) for k, v in enc.items()}
        if "pixel_values" in enc:
            # Match the model dtype (bf16 on CUDA) to avoid implicit upcasts.
            enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.cat_model.dtype)
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": self.cat_eos_token_ids,
            "pad_token_id": self.cat_pad_token_id,
            "no_repeat_ngram_size": self.cat_gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.cat_gen_config.repetition_penalty,
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature
        with torch.inference_mode():
            output_ids = self.cat_model.generate(**enc, **gen_kwargs)
        # Keep only newly generated tokens (drop the echoed prompt).
        generated_only = output_ids[:, enc["input_ids"].shape[1] :]
        generated_text = self.cat_processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
        )[0].strip()
        cats = _parse_and_validate_categories(generated_text, self.default_allowed_categories)
        if not cats:
            # Generic fallback label rather than an empty prediction.
            cats = ["food_waste"] if "food_waste" in self.default_allowed_categories else []
        return cats

    def _predict_weight(self, image: Image.Image, sub_category: str, debug: bool) -> int:
        """Ask the weight model for a gram estimate of *sub_category*.

        Decoding is trie-constrained to integers in [1, max_w]; the parsed
        result is clamped into the same range. Returns at least 1.
        """
        max_w = int(self.weight_gen_config.max_w)
        prompt = WEIGHT_DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category).rstrip()
        prompt += (
            "\n\nRules:\n"
            "- output digits only (no units, no text)\n"
            f"- output an integer between 1 and {self.weight_gen_config.max_w}\n"
            "- never output 0\n"
            f"\n{ANS_START}"
        )
        max_length = int(self.weight_gen_config.max_length)
        max_side = int(self.weight_gen_config.max_side)
        image = _resize_max_side(image, max_side=max_side)
        max_new_tokens = int(self.weight_gen_config.max_new_tokens)
        temperature = float(self.weight_gen_config.temperature)
        chat_text = _build_chat_text(prompt, self.weight_processor)
        enc = self.weight_processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=True,  # unlike categories, truncate rather than raise
            max_length=max_length,
            padding=False,
        )
        prompt_len = int(enc["input_ids"].shape[1])
        # The constraint needs the prompt length to isolate generated tokens.
        # NOTE(review): shared mutable state — fine single-threaded, not
        # safe for concurrent calls on one handler instance.
        self.range_constraint.prompt_len = prompt_len
        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(max_new_tokens, tokens_left)
        if debug:
            logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
            logger.info("sub_category: %s", sub_category)
            logger.info("img size: %s", getattr(image, "size", None))
            logger.info("max_length: %d", max_length)
            logger.info("prompt_len: %d", prompt_len)
            logger.info("tokens_left_for_answer(approx): %d", tokens_left)
            logger.info("=========================================")
        enc = {k: v.to(self.device) for k, v in enc.items()}
        if "pixel_values" in enc:
            enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.weight_model.dtype)
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": self.weight_eos_token_ids,
            "pad_token_id": self.weight_pad_token_id,
            "no_repeat_ngram_size": self.weight_gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.weight_gen_config.repetition_penalty,
            "prefix_allowed_tokens_fn": self.range_constraint.prefix_allowed_tokens_fn
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature
        with torch.inference_mode():
            output_ids = self.weight_model.generate(**enc, **gen_kwargs)
        generated_only = output_ids[:, enc["input_ids"].shape[1] :]
        generated_text_raw = self.weight_processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0].strip()
        weight_g = _extract_int_digits_only(generated_text_raw, default=1)
        weight_g = _clamp_w(weight_g, lo=1, hi=max_w)
        if debug:
            logger.info("Generated raw: %s", generated_text_raw)
            logger.info("Final weight_g: %d", weight_g)
        return weight_g

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request: decode image, classify, then weigh each category.

        Expects {"inputs": {"image": <base64>, "debug": bool?}} (or the same
        keys at the top level). Returns categories, per-item weights, a
        category->weight map, and the summed total in grams.
        Raises ValueError when 'image' is missing or undecodable.
        """
        inputs = data.get("inputs", data)
        debug = bool(inputs.get("debug", False))
        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")
        image = _decode_image(image_b64)
        cats = self._predict_categories(image, debug=debug)
        items: List[Dict[str, Any]] = []
        weights_by_category: Dict[str, int] = {}
        total_weight_g = 0
        for cat in cats:
            # One weight query per predicted category, on the same image.
            w_pred = int(self._predict_weight(image, sub_category=cat, debug=debug))
            items.append({"category": cat, "weight_g": w_pred})
            weights_by_category[cat] = w_pred
            total_weight_g += w_pred
        if debug:
            logger.info("Predicted categories: %s", cats)
            logger.info("Weights by category: %s", weights_by_category)
            logger.info("Total weight (g): %d", total_weight_g)
        return {
            "categories": cats,
            "items": items,
            "weights": weights_by_category,
            "total_weight_g": total_weight_g,
        }