from __future__ import annotations

import base64
import re
import unicodedata
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, List, Optional, Set

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.utils import logging

logger = logging.get_logger(__name__)
logging.set_verbosity_info()

BASE_MODEL_ID = "mistral-community/pixtral-12b"

DEFAULT_PROMPT = (
    "Here is a picture showing some food waste.\n\n"
    "Task: Provide a list of the food waste items visible in the picture.\n"
    "For each item, output EXACTLY one CATEGORY chosen from the CATEGORY DICTIONARY below.\n\n"
    "Rules:\n"
    "- Output must be a single line (no extra text, no markdown).\n"
    "- Use the exact spelling of CATEGORY as listed.\n"
    "- If you are unsure, choose the closest broader category (e.g., 'fruit', 'vegetable', 'food_waste').\n\n"
    "CATEGORY DICTIONARY:\n"
    "1. beer\n"
    "2. cabbage\n"
    "3. whipped_porridge\n"
    "4. cider\n"
    "5. orange\n"
    "6. dairy_product_milk\n"
    "7. yoghurt\n"
    "8. herbs\n"
    "9. potato_product\n"
    "10. strawberry\n"
    "11. minced_chicken\n"
    "12. soda\n"
    "13. sauce_paste\n"
    "14. rice_porridge\n"
    "15. blueberry\n"
    "16. light_bread\n"
    "17. semolina_porridge\n"
    "18. egg\n"
    "19. crepes\n"
    "20. sliced_cheese\n"
    "21. chicken\n"
    "22. food_waste\n"
    "23. candy\n"
    "24. vegetarian_sauce\n"
    "25. cheese\n"
    "26. cereal\n"
    "27. pork_product\n"
    "28. vegetable\n"
    "29. honey\n"
    "30. plant_based_cream\n"
    "31. vegetarian_pizza\n"
    "32. dried fruits\n"
    "33. rice_or_pasta\n"
    "34. pork_steak\n"
    "35. wine\n"
    "36. soft_drink\n"
    "37. minced_beef\n"
    "38. steak\n"
    "39. frankfurter\n"
    "40. vegetarian_soup\n"
    "41. beef_frankfurter\n"
    "42. alcoholic_long_drink\n"
    "43. raspberry\n"
    "44. rice\n"
    "45. mandarin\n"
    "46. juice\n"
    "47. porridge\n"
    "48. fish_soup\n"
    "49. fish_hamburger\n"
    "50. fruit\n"
    "51. coffee\n"
    "52. plant_based_ice_cream\n"
    "53. beef_sausage\n"
    "54. minced_pork\n"
    "55. meat\n"
    "56. sweet_pastry\n"
    "57. ice_cream\n"
    "58. fish_salad\n"
    "59. flour\n"
    "60. snack\n"
    "61. fish_product\n"
    "62. cherry\n"
    "63. shellfish\n"
    "64. cream\n"
    "65. dessert\n"
    "66. cold_cut_chicken\n"
    "67. onion\n"
    "68. dark_bread\n"
    "69. plant_based_milk\n"
    "70. fermented_milk\n"
    "71. cocoa\n"
    "72. bell_pepper\n"
    "73. beef\n"
    "74. sausage\n"
    "75. dairy_product\n"
    "76. pork\n"
    "77. fish_fillet\n"
    "78. strong_alcoholic_beverage\n"
    "79. biscuit\n"
    "80. currant\n"
    "81. sweet_soup\n"
    "82. popcorn\n"
    "83. meat_sauce\n"
    "84. meat_pizza\n"
    "85. lemon\n"
    "86. plum\n"
    "87. grain_products\n"
    "88. sugar_honey_syrup\n"
    "89. vegetarian_hamburger\n"
    "90. jam\n"
    "91. vegetarian_stew\n"
    "92. beef_steak\n"
    "93. tomato\n"
    "94. meat_soup\n"
    "95. block_cheese\n"
    "96. potato_based\n"
    "97. potato\n"
    "98. flakes\n"
    "99. soft_cheese\n"
    "100. beef_product\n"
    "101. chocolate\n"
    "102. cauliflower\n"
    "103. nectarine\n"
    "104. banana\n"
    "105. apple\n"
    "106. alcoholic_beverage\n"
    "107. meat_salad\n"
    "108. pork_sausage\n"
    "109. lingonberry\n"
    "110. chicken_sausage\n"
    "111. pork_frankfurter\n"
    "112. minced_meat\n"
    "113. baked_goods\n"
    "114. broccoli\n"
    "115. cold_cut_beef\n"
    "116. meat_dish\n"
    "117. plant_cheese\n"
    "118. butter\n"
    "119. children_milk\n"
    "120. fish_sauce\n"
    "121. buttermilk\n"
    "122. cucumber\n"
    "123. meat_hamburger\n"
    "124. pasta\n"
    "125. tea\n"
    "126. berries\n"
    "127. milk\n"
    "128. fish_pizza\n"
    "129. fresh_cheese\n"
    "130. fish_stew\n"
    "131. sirup\n"
    "132. vegetarian_salad\n"
    "133. spice\n"
    "134. quark\n"
    "135. cold_cut_pork\n"
    "136. sauce_seasoning\n"
    "137. oatmeal\n"
    "138. sugar\n"
    "139. meat_product\n"
    "140. vegetarian_dish\n"
    "141. savory_pastry\n"
    "142. chicken_product\n"
    "143. vegetable_mix\n"
    "144. potato_chip\n"
    "145. muesli\n"
    "146. carrot\n"
    "147. plant_based_yogurt\n"
    "148. sweet\n"
    "149. eggs\n"
    "150. chicken_frankfurter\n"
    "151. bread\n"
    "152. fish_strips\n"
    "153. lettuce\n"
    "154. fish\n"
    "155. meat_stew\n"
    "156. hot_beverage\n"
    "157. cherry tomato\n"
    "158. fish_dish\n"
    "159. fat_oil\n"
    "160. grape\n"
    "161. plant_protein\n"
    "162. pastry\n"
    "163. oil\n"
    "164. cold_cut_meat\n\n"
    "OUTPUT FORMAT (single line):\n"
    "CATEGORY1,CATEGORY2,CATEGORY3\n"
)
oatmeal\n" "138. sugar\n" "139. meat_product\n" "140. vegetarian_dish\n" "141. savory_pastry\n" "142. chicken_product\n" "143. vegetable_mix\n" "144. potato_chip\n" "145. muesli\n" "146. carrot\n" "147. plant_based_yogurt\n" "148. sweet\n" "149. eggs\n" "150. chicken_frankfurter\n" "151. bread\n" "152. fish_strips\n" "153. lettuce\n" "154. fish\n" "155. meat_stew\n" "156. hot_beverage\n" "157. cherry tomato\n" "158. fish_dish\n" "159. fat_oil\n" "160. grape\n" "161. plant_protein\n" "162. pastry\n" "163. oil\n" "164. cold_cut_meat\n\n" "OUTPUT FORMAT (single line):\n" "CATEGORY1,CATEGORY2,CATEGORY3\n" ) @dataclass class GenerationConfig: max_new_tokens: int = 256 temperature: float = 0.0 no_repeat_ngram_size: int = 6 repetition_penalty: float = 1.1 max_length: int = 4096 max_side: int = 512 _ALLOWED_RE = re.compile(r"[^a-z0-9_\(\);,\|\/\-\s\?\.\:]") ITEM_RE = re.compile(r"\(\s*(.*?)\s*\)\s*\|\s*(\d+)", flags=re.DOTALL) def _clean_text(s: str) -> str: s = unicodedata.normalize("NFKC", s).replace("\u00A0", " ") s = re.sub(r"[\u200B-\u200F]", "", s) s = re.sub(r"\s+", " ", s).strip() return s _CATEGORY_LINE_RE = re.compile(r"^\s*\d+\.\s*(.+?)\s*$", flags=re.MULTILINE) def _extract_categories_from_prompt(prompt: str) -> Set[str]: cats = {m.group(1).strip().lower() for m in _CATEGORY_LINE_RE.finditer(prompt)} return {c for c in cats if c} def _clean_model_output(s: str) -> str: s = _clean_text(s).lower() s = _ALLOWED_RE.sub("", s) return s def _parse_and_validate_categories(raw: str, allowed: Set[str]) -> List[str]: s = _clean_model_output(raw) parts = [p.strip() for p in s.split(",") if p.strip()] out: List[str] = [] seen: Set[str] = set() for p in parts: if p in allowed and p not in seen: out.append(p) seen.add(p) return out class EndpointHandler: def __init__(self, path: str = ".") -> None: """ Initializes the model and processor from the `path` directory, which contains the merged weights (pixtral-12b-foodwaste-merged). 
""" self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info("Initializing EndpointHandler on device: %s", self.device) self.processor = AutoProcessor.from_pretrained( BASE_MODEL_ID, trust_remote_code=True, ) dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 self.model = LlavaForConditionalGeneration.from_pretrained( BASE_MODEL_ID, torch_dtype=dtype, low_cpu_mem_usage=True, device_map={"": self.device}, trust_remote_code=True, ) self.model.eval() logger.info("Model and processor successfully loaded from '%s'.", path) # pad token management tokenizer = getattr(self.processor, "tokenizer", None) if tokenizer is not None and tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id # Preparation of EOS/PAD IDs for generate eos_candidates: List[int] = [] if self.model.config.eos_token_id is not None: eos_candidates.append(self.model.config.eos_token_id) if tokenizer is not None and tokenizer.eos_token_id is not None: eos_candidates.append(tokenizer.eos_token_id) self.eos_token_ids: List[int] = list({i for i in eos_candidates}) if not self.eos_token_ids: raise ValueError("No EOS token id found on model or tokenizer.") pad_id: Optional[int] = getattr(self.model.config, "pad_token_id", None) if pad_id is None and tokenizer is not None: pad_id = tokenizer.pad_token_id if pad_id is None: pad_id = self.eos_token_ids[0] self.pad_token_id: int = pad_id self.gen_config = GenerationConfig() logger.info( "Generation config: max_new_tokens=%d, temperature=%.3f", self.gen_config.max_new_tokens, self.gen_config.temperature, ) self.default_allowed_categories: Set[str] = _extract_categories_from_prompt(DEFAULT_PROMPT) logger.info("Extracted %d categories from DEFAULT_PROMPT.", len(self.default_allowed_categories)) @staticmethod def _decode_image(image_b64: str) -> Image.Image: try: img_bytes = base64.b64decode(image_b64) img = Image.open(BytesIO(img_bytes)).convert("RGB") return img except Exception as exc: # pragma: no cover - log production raise ValueError(f"Could not decode base64 image: {exc}") from exc @staticmethod def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image: w, h = img.size m = max(w, h) if m <= max_side: return img scale = max_side / m # LANCZOS = better downscale return img.resize((int(w * scale), int(h * scale)), resample=Image.Resampling.LANCZOS) def _build_chat_text(self, prompt: str) -> str: messages = [ { "role": "user", "content": [ {"type": "text", "text": prompt}, {"type": "image"}, ], } ] chat_text = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=False, ) return chat_text def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: inputs = data.get("inputs", data) debug = bool(inputs.get("debug", False)) prompt: str = inputs.get("prompt") or DEFAULT_PROMPT allowed_categories = _extract_categories_from_prompt(prompt) or self.default_allowed_categories image_b64: Optional[str] = inputs.get("image") if not image_b64: raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.") image = self._decode_image(image_b64) max_length = int(inputs.get("max_length", self.gen_config.max_length)) max_side = int(inputs.get("max_side", self.gen_config.max_side)) image = self._resize_max_side(image, max_side=max_side) max_new_tokens = int(inputs.get("max_new_tokens", self.gen_config.max_new_tokens)) temperature = float(inputs.get("temperature", self.gen_config.temperature)) chat_text = self._build_chat_text(prompt) enc = self.processor( 
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.get("inputs", data)
        debug = bool(inputs.get("debug", False))

        prompt: str = inputs.get("prompt") or DEFAULT_PROMPT
        allowed_categories = _extract_categories_from_prompt(prompt) or self.default_allowed_categories

        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")
        image = self._decode_image(image_b64)

        max_length = int(inputs.get("max_length", self.gen_config.max_length))
        max_side = int(inputs.get("max_side", self.gen_config.max_side))
        image = self._resize_max_side(image, max_side=max_side)

        max_new_tokens = int(inputs.get("max_new_tokens", self.gen_config.max_new_tokens))
        temperature = float(inputs.get("temperature", self.gen_config.temperature))

        chat_text = self._build_chat_text(prompt)
        enc = self.processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=False,  # important for a correct prompt_len
        )

        prompt_len = int(enc["input_ids"].shape[1])
        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(max_new_tokens, tokens_left)

        if debug:
            logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
            logger.info("img size: %s", getattr(image, "size", None))
            logger.info("max_length: %d", max_length)
            logger.info("prompt_len: %d", prompt_len)
            logger.info("tokens_left_for_answer(approx): %d", tokens_left)
            tok = getattr(self.processor, "tokenizer", None)
            if tok is not None:
                ids = enc["input_ids"][0].tolist()
                logger.info("[prompt head]\n%s", tok.decode(ids[:120], skip_special_tokens=False))
                logger.info("[prompt tail]\n%s", tok.decode(ids[-120:], skip_special_tokens=False))
            logger.info("=========================================")

        enc = {k: v.to(self.device) for k, v in enc.items()}
        if "pixel_values" in enc:
            enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.model.dtype)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": self.eos_token_ids,
            "pad_token_id": self.pad_token_id,
            "no_repeat_ngram_size": self.gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.gen_config.repetition_penalty,
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature

        with torch.inference_mode():
            output_ids = self.model.generate(**enc, **gen_kwargs)

        # Decode only the newly generated tokens (everything after the prompt).
        generated_only = output_ids[:, enc["input_ids"].shape[1]:]
        generated_text = self.processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
        )[0].strip()

        cats = _parse_and_validate_categories(generated_text, allowed_categories)
        if not cats:
            # Fall back to the broadest label when nothing validates.
            cats = ["food_waste"] if "food_waste" in allowed_categories else []
        generated_text = ",".join(cats)

        logger.info("Generated text: %s", generated_text)
        return {"generated_text": generated_text, "categories": cats}
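

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the endpoint contract):
    # "sample.jpg" is a hypothetical image path used only for illustration.
    import json
    import sys

    handler = EndpointHandler()
    image_path = sys.argv[1] if len(sys.argv) > 1 else "sample.jpg"
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("ascii")
    result = handler({"inputs": {"image": b64, "debug": True}})
    print(json.dumps(result, indent=2))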