"""Inference Endpoint handler that estimates food waste from a single photo.

Two fine-tuned Llava-architecture models (Pixtral-based, per the repo names)
are chained:

1. the *categories* model lists the food-waste categories visible in the image
   (validated against the dictionary embedded in ``CATEGORIES_DEFAULT_PROMPT``);
2. the *weight* model is queried once per predicted category and returns an
   integer weight in grams, constrained to ``1..WeightGenConfig.max_w``.
"""

from __future__ import annotations

import base64
import re
import unicodedata
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.utils import logging

logger = logging.get_logger(__name__)
# Global side effect: raises transformers' library-wide log level to INFO.
logging.set_verbosity_info()

CATEGORIES_REPO = "HCKLab/pixtral-12b-foodwaste-classification-merged-2"
WEIGHT_REPO = "HCKLab/pixtral-12b-foodwaste-weight-merged"


def _clean_text(s: str) -> str:
    """Normalize text: NFKC, NBSP -> space, drop U+200B..U+200F, collapse whitespace."""
    s = unicodedata.normalize("NFKC", s).replace("\u00A0", " ")
    s = re.sub(r"[\u200B-\u200F]", "", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


def _decode_image(image_b64: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:...;base64,`` URL) into an RGB image.

    Raises:
        ValueError: if the payload cannot be decoded or opened as an image.
    """
    payload: Any = image_b64
    if isinstance(payload, str):
        payload = payload.strip()
        # b64decode silently drops non-alphabet characters but would keep the
        # letters and '/' of a data-URL header, corrupting the bytes, so strip it.
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]
    try:
        img_bytes = base64.b64decode(payload)
        return Image.open(BytesIO(img_bytes)).convert("RGB")
    except Exception as exc:  # pragma: no cover - log production
        raise ValueError(f"Could not decode base64 image: {exc}") from exc


def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
    """Downscale ``img`` so its longest side is at most ``max_side`` (never upscales)."""
    w, h = img.size
    m = max(w, h)
    if m <= max_side:
        return img
    scale = max_side / m
    # Clamp to 1 px so extreme aspect ratios never produce a zero-sized image.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return img.resize(new_size, resample=Image.Resampling.LANCZOS)


def _build_chat_text(prompt: str, processor: AutoProcessor) -> str:
    """Render a single user turn (one image + ``prompt``) with the processor's chat template."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )


def _as_bool(value: Any) -> bool:
    """Interpret a JSON flag; strings such as "false"/"0" are treated as False."""
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "y", "on"}
    return bool(value)


# ----------------------------
# Categories task
# ----------------------------


@dataclass
class CategoriesGenConfig:
    """Decoding settings for the categories model."""

    max_new_tokens: int = 128
    temperature: float = 0.0  # 0.0 => greedy decoding (do_sample=False)
    no_repeat_ngram_size: int = 6
    repetition_penalty: float = 1.1
    max_length: int = 4096  # prompt + answer token budget
    max_side: int = 768  # images are downscaled to this longest side


CATEGORIES_DEFAULT_PROMPT = (
    "Here is a picture showing some food waste.\n\n"
    "Task: Provide a list of the food waste items visible in the picture.\n"
    "For each item, output EXACTLY one CATEGORY chosen from the CATEGORY DICTIONARY below.\n\n"
    "Rules:\n"
    "- Output must be a single line (no extra text, no markdown).\n"
    "- Use the exact spelling of CATEGORY as listed.\n"
    "- If you are unsure, choose the closest broader category (e.g., 'fruit', 'vegetable', 'food_waste').\n\n"
    "CATEGORY DICTIONARY:\n"
    "1. beer\n"
    "2. cabbage\n"
    "3. whipped_porridge\n"
    "4. cider\n"
    "5. orange\n"
    "6. dairy_product_milk\n"
    "7. yoghurt\n"
    "8. herbs\n"
    "9. potato_product\n"
    "10. strawberry\n"
    "11. minced_chicken\n"
    "12. soda\n"
    "13. sauce_paste\n"
    "14. rice_porridge\n"
    "15. blueberry\n"
    "16. light_bread\n"
    "17. semolina_porridge\n"
    "18. egg\n"
    "19. crepes\n"
    "20. sliced_cheese\n"
    "21. chicken\n"
    "22. food_waste\n"
    "23. candy\n"
    "24. vegetarian_sauce\n"
    "25. cheese\n"
    "26. cereal\n"
    "27. pork_product\n"
    "28. vegetable\n"
    "29. honey\n"
    "30. plant_based_cream\n"
    "31. vegetarian_pizza\n"
    "32. dried fruits\n"
    "33. rice_or_pasta\n"
    "34. pork_steak\n"
    "35. wine\n"
    "36. soft_drink\n"
    "37. minced_beef\n"
    "38. steak\n"
    "39. frankfurter\n"
    "40. vegetarian_soup\n"
    "41. beef_frankfurter\n"
    "42. alcoholic_long_drink\n"
    "43. raspberry\n"
    "44. rice\n"
    "45. mandarin\n"
    "46. juice\n"
    "47. porridge\n"
    "48. fish_soup\n"
    "49. fish_hamburger\n"
    "50. fruit\n"
    "51. coffee\n"
    "52. plant_based_ice_cream\n"
    "53. beef_sausage\n"
    "54. minced_pork\n"
    "55. meat\n"
    "56. sweet_pastry\n"
    "57. ice_cream\n"
    "58. fish_salad\n"
    "59. flour\n"
    "60. snack\n"
    "61. fish_product\n"
    "62. cherry\n"
    "63. shellfish\n"
    "64. cream\n"
    "65. dessert\n"
    "66. cold_cut_chicken\n"
    "67. onion\n"
    "68. dark_bread\n"
    "69. plant_based_milk\n"
    "70. fermented_milk\n"
    "71. cocoa\n"
    "72. bell_pepper\n"
    "73. beef\n"
    "74. sausage\n"
    "75. dairy_product\n"
    "76. pork\n"
    "77. fish_fillet\n"
    "78. strong_alcoholic_beverage\n"
    "79. biscuit\n"
    "80. currant\n"
    "81. sweet_soup\n"
    "82. popcorn\n"
    "83. meat_sauce\n"
    "84. meat_pizza\n"
    "85. lemon\n"
    "86. plum\n"
    "87. grain_products\n"
    "88. sugar_honey_syrup\n"
    "89. vegetarian_hamburger\n"
    "90. jam\n"
    "91. vegetarian_stew\n"
    "92. beef_steak\n"
    "93. tomato\n"
    "94. meat_soup\n"
    "95. block_cheese\n"
    "96. potato_based\n"
    "97. potato\n"
    "98. flakes\n"
    "99. soft_cheese\n"
    "100. beef_product\n"
    "101. chocolate\n"
    "102. cauliflower\n"
    "103. nectarine\n"
    "104. banana\n"
    "105. apple\n"
    "106. alcoholic_beverage\n"
    "107. meat_salad\n"
    "108. pork_sausage\n"
    "109. lingonberry\n"
    "110. chicken_sausage\n"
    "111. pork_frankfurter\n"
    "112. minced_meat\n"
    "113. baked_goods\n"
    "114. broccoli\n"
    "115. cold_cut_beef\n"
    "116. meat_dish\n"
    "117. plant_cheese\n"
    "118. butter\n"
    "119. children_milk\n"
    "120. fish_sauce\n"
    "121. buttermilk\n"
    "122. cucumber\n"
    "123. meat_hamburger\n"
    "124. pasta\n"
    "125. tea\n"
    "126. berries\n"
    "127. milk\n"
    "128. fish_pizza\n"
    "129. fresh_cheese\n"
    "130. fish_stew\n"
    "131. sirup\n"
    "132. vegetarian_salad\n"
    "133. spice\n"
    "134. quark\n"
    "135. cold_cut_pork\n"
    "136. sauce_seasoning\n"
    "137. oatmeal\n"
    "138. sugar\n"
    "139. meat_product\n"
    "140. vegetarian_dish\n"
    "141. savory_pastry\n"
    "142. chicken_product\n"
    "143. vegetable_mix\n"
    "144. potato_chip\n"
    "145. muesli\n"
    "146. carrot\n"
    "147. plant_based_yogurt\n"
    "148. sweet\n"
    "149. eggs\n"
    "150. chicken_frankfurter\n"
    "151. bread\n"
    "152. fish_strips\n"
    "153. lettuce\n"
    "154. fish\n"
    "155. meat_stew\n"
    "156. hot_beverage\n"
    "157. cherry tomato\n"
    "158. fish_dish\n"
    "159. fat_oil\n"
    "160. grape\n"
    "161. plant_protein\n"
    "162. pastry\n"
    "163. oil\n"
    "164. cold_cut_meat\n\n"
    "OUTPUT FORMAT (single line):\n"
    "CATEGORY1,CATEGORY2,CATEGORY3\n"
)

# Characters kept when cleaning model output; everything else is removed.
_ALLOWED_RE = re.compile(r"[^a-z0-9_\(\);,\|\/\-\s\?\.\:]")
# Matches numbered dictionary lines such as "12. soda" and captures the name.
_CATEGORY_LINE_RE = re.compile(r"^\s*\d+\.\s*(.+?)\s*$", flags=re.MULTILINE)
ITEM_RE = re.compile(r"\(\s*(.*?)\s*\)\s*\|\s*(\d+)", flags=re.DOTALL)


def _extract_categories_from_prompt(prompt: str) -> Set[str]:
    """Return the lower-cased category names listed as ``N. name`` lines in ``prompt``."""
    cats = {m.group(1).strip().lower() for m in _CATEGORY_LINE_RE.finditer(prompt)}
    return {c for c in cats if c}


def _clean_model_output(s: str) -> str:
    """Normalize, lower-case and strip disallowed characters from raw model text."""
    s = _clean_text(s).lower()
    return _ALLOWED_RE.sub("", s)


def _parse_and_validate_categories(raw: str, allowed: Set[str]) -> List[str]:
    """Split comma-separated model output and keep allowed categories, first-seen order, no duplicates."""
    s = _clean_model_output(raw)
    parts = [p.strip() for p in s.split(",") if p.strip()]
    out: List[str] = []
    seen: Set[str] = set()
    for p in parts:
        if p in allowed and p not in seen:
            out.append(p)
            seen.add(p)
    return out


# ----------------------------
# Weight task
# ----------------------------

# Text appended at the very end of the weight prompt (currently empty).
ANS_START = ""


@dataclass
class WeightGenConfig:
    """Decoding settings for the weight model."""

    max_new_tokens: int = 8
    temperature: float = 0.0  # 0.0 => greedy decoding (do_sample=False)
    no_repeat_ngram_size: int = 0
    repetition_penalty: float = 1.0
    max_length: int = 3000  # prompt + answer token budget
    max_side: int = 768  # images are downscaled to this longest side
    max_w: int = 5000  # upper bound (grams) enforced by RangeConstraint and clamping


# NOTE(review): the line breaks of this prompt were lost in a whitespace-mangled
# copy of this file and have been reconstructed; confirm they match the prompt
# used when fine-tuning WEIGHT_REPO.
WEIGHT_DEFAULT_PROMPT = """
Here is a photo of some food waste.
Your task is to provide a weight estimate (in grams) for the following food waste: {SUB_CATEGORY}
While providing your weight estimate, please keep in mind the following:
*** there might be other food items on the photo, I want you to provide the weight estimate for {SUB_CATEGORY} only!
*** make sure to provide your best guess, never say 0!
OUTPUT (exactly one line, no extra text):
WEIGHT
""".strip()

_INT_RE = re.compile(r"[0-9]+")


def _extract_int_digits_only(text: str, default: int = 1) -> int:
    """Concatenate every ASCII digit in ``text`` into an integer (minimum 1).

    Returns ``max(1, default)`` when ``text`` contains no digit.
    """
    t = _clean_text(text)
    digits = "".join(ch for ch in t if "0" <= ch <= "9")
    if not digits:
        return max(1, default)
    # Never return 0: the prompt forbids it and callers expect a positive weight.
    return max(1, int(digits))


def _resolve_sub_category(inputs: Dict[str, Any]) -> str:
    """Return the stripped sub-category from ``inputs['sub_category']`` or ``inputs['subcat']``.

    Raises:
        ValueError: if neither key holds a non-blank value.
    """
    for key in ("sub_category", "subcat"):
        value = inputs.get(key)
        if value and str(value).strip():
            return str(value).strip()
    raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.")


def _clamp_w(w: int, lo: int, hi: int) -> int:
    """Clamp ``w`` (coerced to int) into ``[lo, hi]``."""
    return max(lo, min(hi, int(w)))


def _as_int_list(value: Any) -> List[int]:
    """Normalize a token-id config value (None, int, or list/tuple) to a list of ints."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return [int(v) for v in value if v is not None]
    return [int(value)]


_TRIE_END = "_end_"


def _build_trie(token_seqs) -> Dict[Any, Any]:
    """Build a nested-dict prefix tree; ``_TRIE_END`` marks nodes that end a sequence."""
    trie: Dict[Any, Any] = {}
    for seq in token_seqs:
        node = trie
        for tok in seq:
            node = node.setdefault(tok, {})
        node[_TRIE_END] = True
    return trie


def _trie_node_for_prefix(trie: Dict[Any, Any], prefix) -> Optional[Dict[Any, Any]]:
    """Return the trie node reached by following ``prefix``, or None if it leaves the trie."""
    node = trie
    for tok in prefix:
        node = node.get(tok)
        if node is None:
            return None
    return node


class RangeConstraint:
    """Constrain generation to the token encodings of integers in ``[min_w, max_w]``.

    Use ``bind(prompt_len)`` to obtain a per-request ``prefix_allowed_tokens_fn``
    for ``generate``. The legacy ``prefix_allowed_tokens_fn`` method reads the
    shared ``prompt_len`` attribute instead.
    """

    def __init__(self, tokenizer, min_w: int = 1, max_w: int = 1200) -> None:
        if min_w > max_w:
            raise ValueError(f"min_w ({min_w}) must be <= max_w ({max_w}).")
        if tokenizer.eos_token_id is None:
            # Without EOS the constraint can neither stop nor reject a dead end.
            raise ValueError("RangeConstraint requires a tokenizer with an eos_token_id.")
        self.tok = tokenizer
        seqs = [
            tokenizer.encode(str(w), add_special_tokens=False)
            for w in range(min_w, max_w + 1)
        ]
        self.trie = _build_trie(seqs)
        self.prompt_len: Optional[int] = None

    @staticmethod
    def _row(batch_id: int, input_ids):
        """Select the sequence for ``batch_id`` when ``input_ids`` is 2-D; else use it as-is."""
        return input_ids[batch_id] if getattr(input_ids, "dim", lambda: 1)() == 2 else input_ids

    def _allowed_after(self, gen_prefix: List[int]) -> List[int]:
        """Token ids that may follow the generated ``gen_prefix`` (EOS once a number is complete)."""
        eos = self.tok.eos_token_id
        node = _trie_node_for_prefix(self.trie, gen_prefix)
        if node is None:
            return [eos]
        allowed = [k for k in node if k != _TRIE_END]
        if node.get(_TRIE_END):
            allowed.append(eos)
        # Never return an empty list: that would mask every token.
        return allowed or [eos]

    def bind(self, prompt_len: int) -> Callable[[int, Any], List[int]]:
        """Return a ``prefix_allowed_tokens_fn`` that treats tokens after ``prompt_len`` as the answer."""

        def _fn(batch_id: int, input_ids) -> List[int]:
            ids = self._row(batch_id, input_ids)
            return self._allowed_after(ids[prompt_len:].tolist())

        return _fn

    def prefix_allowed_tokens_fn(self, batch_id: int, input_ids) -> List[int]:
        """Legacy hook using the shared ``prompt_len`` attribute (must be set first)."""
        if self.prompt_len is None:
            raise RuntimeError("RangeConstraint.prompt_len must be set before generation.")
        ids = self._row(batch_id, input_ids)
        return self._allowed_after(ids[self.prompt_len :].tolist())


# ----------------------------
# Router EndpointHandler
# ----------------------------


class EndpointHandler:
    """Run the categories model, then the weight model once per predicted category."""

    def _load_processor_and_model(self, repo: str):
        """Load processor and model from ``repo`` onto ``self.device`` in eval mode (bf16 on CUDA)."""
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        processor = AutoProcessor.from_pretrained(
            repo,
            trust_remote_code=True,
        )
        model = LlavaForConditionalGeneration.from_pretrained(
            repo,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map={"": self.device},
            trust_remote_code=True,
        )
        model.eval()
        logger.info("Model and processor successfully loaded from '%s'.", repo)
        return processor, model

    @staticmethod
    def eos_and_pad_token_id(processor, model) -> Tuple[List[int], int]:
        """Collect EOS ids (model config, its ``text_config`` if any, tokenizer) and pick a pad id.

        Side effect: if the tokenizer has no pad token, EOS is assigned as pad.

        Raises:
            ValueError: if no EOS id can be found.
        """
        tokenizer = getattr(processor, "tokenizer", None)
        if tokenizer is not None and tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        eos_candidates: List[int] = []
        for cfg in (model.config, getattr(model.config, "text_config", None)):
            if cfg is not None:
                # eos_token_id may be an int or a list of ints depending on the config.
                eos_candidates.extend(_as_int_list(getattr(cfg, "eos_token_id", None)))
        if tokenizer is not None and tokenizer.eos_token_id is not None:
            eos_candidates.append(int(tokenizer.eos_token_id))

        # Deduplicate while keeping a deterministic order.
        eos_token_ids: List[int] = list(dict.fromkeys(eos_candidates))
        if not eos_token_ids:
            raise ValueError("No EOS token id found on model or tokenizer.")

        pad_id: Optional[int] = getattr(model.config, "pad_token_id", None)
        if pad_id is None and tokenizer is not None:
            pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = eos_token_ids[0]
        return eos_token_ids, pad_id

    def __init__(self, path: str = ".") -> None:
        # ``path`` is part of the endpoint-handler signature; models load from fixed repos.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing EndpointHandler on device: %s", self.device)

        self.cat_processor, self.cat_model = self._load_processor_and_model(CATEGORIES_REPO)
        self.cat_eos_token_ids, self.cat_pad_token_id = self.eos_and_pad_token_id(
            self.cat_processor,
            self.cat_model,
        )
        self.cat_gen_config = CategoriesGenConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.cat_gen_config.max_new_tokens,
            self.cat_gen_config.temperature,
        )

        self.weight_processor, self.weight_model = self._load_processor_and_model(WEIGHT_REPO)
        self.weight_eos_token_ids, self.weight_pad_token_id = self.eos_and_pad_token_id(
            self.weight_processor,
            self.weight_model,
        )
        self.weight_gen_config = WeightGenConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.weight_gen_config.max_new_tokens,
            self.weight_gen_config.temperature,
        )
        self.range_constraint = RangeConstraint(
            self.weight_processor.tokenizer,
            min_w=1,
            max_w=self.weight_gen_config.max_w,
        )

        self.default_allowed_categories: Set[str] = _extract_categories_from_prompt(
            CATEGORIES_DEFAULT_PROMPT
        )
        logger.info(
            "Extracted %d categories from CATEGORIES_DEFAULT_PROMPT.",
            len(self.default_allowed_categories),
        )

    # ---- shared generation helpers ----

    @staticmethod
    def _encode_checked(processor, chat_text: str, image: Image.Image, max_length: int):
        """Tokenize text+image without truncation; return ``(encoding, prompt_len)``.

        Truncating a multimodal prompt would cut the generation prompt or image
        tokens, so an over-long prompt is rejected instead.

        Raises:
            ValueError: if the prompt alone reaches ``max_length`` tokens.
        """
        enc = processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=False,
            padding=False,
        )
        prompt_len = int(enc["input_ids"].shape[1])
        if prompt_len >= max_length:
            raise ValueError(
                f"Prompt too long ({prompt_len} tokens) for max_length={max_length}. "
                "Increase max_length or shorten the prompt."
            )
        return enc, prompt_len

    def _to_device(self, enc, model) -> Dict[str, Any]:
        """Move every tensor to ``self.device``; cast ``pixel_values`` to the model dtype."""
        moved = {k: v.to(self.device) for k, v in enc.items()}
        if "pixel_values" in moved:
            moved["pixel_values"] = moved["pixel_values"].to(self.device, dtype=model.dtype)
        return moved

    @staticmethod
    def _build_gen_kwargs(cfg, max_new_tokens: int, eos_token_ids: List[int], pad_token_id: int) -> Dict[str, Any]:
        """Build ``generate`` kwargs from a gen config; sampling only when temperature > 0."""
        temperature = float(cfg.temperature)
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": eos_token_ids,
            "pad_token_id": pad_token_id,
            "no_repeat_ngram_size": cfg.no_repeat_ngram_size,
            "repetition_penalty": cfg.repetition_penalty,
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature
        return gen_kwargs

    @staticmethod
    def _generate_text(model, processor, enc: Dict[str, Any], gen_kwargs: Dict[str, Any], **decode_kwargs) -> str:
        """Generate and decode only the newly produced tokens of the first sequence."""
        with torch.inference_mode():
            output_ids = model.generate(**enc, **gen_kwargs)
        generated_only = output_ids[:, enc["input_ids"].shape[1] :]
        return processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
            **decode_kwargs,
        )[0].strip()

    @staticmethod
    def _log_token_budget(image, max_length: int, prompt_len: int, tokens_left: int, sub_category: Optional[str] = None) -> None:
        """Log the opening lines of the token-budget debug section."""
        logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
        if sub_category is not None:
            logger.info("sub_category: %s", sub_category)
        logger.info("img size: %s", getattr(image, "size", None))
        logger.info("max_length: %d", max_length)
        logger.info("prompt_len: %d", prompt_len)
        logger.info("tokens_left_for_answer(approx): %d", tokens_left)

    # ---- tasks ----

    def _predict_categories(self, image: Image.Image, debug: bool) -> List[str]:
        """Return validated categories for ``image``; falls back to ``["food_waste"]`` if none parse."""
        cfg = self.cat_gen_config
        max_length = int(cfg.max_length)
        image = _resize_max_side(image, max_side=int(cfg.max_side))

        chat_text = _build_chat_text(CATEGORIES_DEFAULT_PROMPT, self.cat_processor)
        enc, prompt_len = self._encode_checked(self.cat_processor, chat_text, image, max_length)

        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(int(cfg.max_new_tokens), tokens_left)

        if debug:
            self._log_token_budget(image, max_length, prompt_len, tokens_left)
            tok = getattr(self.cat_processor, "tokenizer", None)
            if tok is not None:
                ids = enc["input_ids"][0].tolist()
                logger.info("[prompt head]\n%s", tok.decode(ids[:120], skip_special_tokens=False))
                logger.info("[prompt tail]\n%s", tok.decode(ids[-120:], skip_special_tokens=False))
            logger.info("=========================================")

        enc = self._to_device(enc, self.cat_model)
        gen_kwargs = self._build_gen_kwargs(
            cfg, max_new_tokens, self.cat_eos_token_ids, self.cat_pad_token_id
        )
        generated_text = self._generate_text(self.cat_model, self.cat_processor, enc, gen_kwargs)

        cats = _parse_and_validate_categories(generated_text, self.default_allowed_categories)
        if not cats:
            cats = ["food_waste"] if "food_waste" in self.default_allowed_categories else []
        return cats

    def _predict_weight(self, image: Image.Image, sub_category: str, debug: bool) -> int:
        """Return the estimated weight in grams of ``sub_category``, clamped to ``[1, max_w]``.

        Decoding is constrained by ``RangeConstraint`` so only integers in range can be produced.

        Raises:
            ValueError: if the prompt alone reaches ``max_length`` tokens.
        """
        cfg = self.weight_gen_config
        max_w = int(cfg.max_w)

        prompt = WEIGHT_DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category).rstrip()
        prompt += (
            "\n\nRules:\n"
            "- output digits only (no units, no text)\n"
            f"- output an integer between 1 and {cfg.max_w}\n"
            "- never output 0\n"
            f"\n{ANS_START}"
        )

        max_length = int(cfg.max_length)
        image = _resize_max_side(image, max_side=int(cfg.max_side))

        chat_text = _build_chat_text(prompt, self.weight_processor)
        enc, prompt_len = self._encode_checked(self.weight_processor, chat_text, image, max_length)
        # Kept in sync for callers that still use the shared legacy hook.
        self.range_constraint.prompt_len = prompt_len

        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(int(cfg.max_new_tokens), tokens_left)

        if debug:
            self._log_token_budget(image, max_length, prompt_len, tokens_left, sub_category=sub_category)
            logger.info("=========================================")

        enc = self._to_device(enc, self.weight_model)
        gen_kwargs = self._build_gen_kwargs(
            cfg, max_new_tokens, self.weight_eos_token_ids, self.weight_pad_token_id
        )
        # Per-call closure: no shared mutable state between concurrent requests.
        gen_kwargs["prefix_allowed_tokens_fn"] = self.range_constraint.bind(prompt_len)

        generated_text_raw = self._generate_text(
            self.weight_model,
            self.weight_processor,
            enc,
            gen_kwargs,
            clean_up_tokenization_spaces=False,
        )

        weight_g = _extract_int_digits_only(generated_text_raw, default=1)
        weight_g = _clamp_w(weight_g, lo=1, hi=max_w)

        if debug:
            logger.info("Generated raw: %s", generated_text_raw)
            logger.info("Final weight_g: %d", weight_g)

        return weight_g

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        ``data`` is either ``{"inputs": {...}}`` or the inputs dict itself, with
        ``image`` (base64 string) and optional ``debug`` flag.

        Returns:
            dict with ``categories`` (list), ``items`` (list of
            ``{"category", "weight_g"}``), ``weights`` (category -> grams) and
            ``total_weight_g``.

        Raises:
            ValueError: on missing/invalid ``inputs`` or ``image``.
        """
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object containing an 'image' field.")
        debug = _as_bool(inputs.get("debug", False))

        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")
        image = _decode_image(image_b64)

        cats = self._predict_categories(image, debug=debug)

        items: List[Dict[str, Any]] = []
        weights_by_category: Dict[str, int] = {}
        total_weight_g = 0
        for cat in cats:
            w_pred = int(self._predict_weight(image, sub_category=cat, debug=debug))
            items.append({"category": cat, "weight_g": w_pred})
            weights_by_category[cat] = w_pred
            total_weight_g += w_pred

        if debug:
            logger.info("Predicted categories: %s", cats)
            logger.info("Weights by category: %s", weights_by_category)
            logger.info("Total weight (g): %d", total_weight_g)

        return {
            "categories": cats,
            "items": items,
            "weights": weights_by_category,
            "total_weight_g": total_weight_g,
        }