| | from __future__ import annotations |
| |
|
| | import base64 |
| | import re |
| | import unicodedata |
| | from dataclasses import dataclass |
| | from io import BytesIO |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Set |
| |
|
| | import torch |
| | from PIL import Image |
| | from transformers import AutoProcessor, LlavaForConditionalGeneration |
| | from transformers.utils import logging |
| |
|
| |
|
# Module-level logger (transformers' logging wrapper) with info verbosity.
logger = logging.get_logger(__name__)
logging.set_verbosity_info()


# Hugging Face Hub repositories for the two fine-tuned Pixtral models:
# one classifies food-waste categories, the other estimates item weights.
CATEGORIES_REPO = "HCKLab/pixtral-12b-foodwaste-classification-merged-2"
WEIGHT_REPO = "HCKLab/pixtral-12b-foodwaste-weight-merged"
| |
|
| | def _clean_text(s: str) -> str: |
| | s = unicodedata.normalize("NFKC", s).replace("\u00A0", " ") |
| | s = re.sub(r"[\u200B-\u200F]", "", s) |
| | s = re.sub(r"\s+", " ", s).strip() |
| | return s |
| |
|
| |
|
| | def _decode_image(image_b64: str) -> Image.Image: |
| | try: |
| | img_bytes = base64.b64decode(image_b64) |
| | img = Image.open(BytesIO(img_bytes)).convert("RGB") |
| | return img |
| | except Exception as exc: |
| | raise ValueError(f"Could not decode base64 image: {exc}") from exc |
| |
|
| | def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image: |
| | w, h = img.size |
| | m = max(w, h) |
| | if m <= max_side: |
| | return img |
| | scale = max_side / m |
| | return img.resize((int(w * scale), int(h * scale)), resample=Image.Resampling.LANCZOS) |
| |
|
| | def _build_chat_text(prompt: str, processor: AutoProcessor) -> str: |
| | messages = [ |
| | { |
| | "role": "user", |
| | "content": [ |
| | {"type": "image"}, |
| | {"type": "text", "text": prompt}, |
| | ], |
| | } |
| | ] |
| |
|
| | chat_text = processor.apply_chat_template( |
| | messages, |
| | add_generation_prompt=True, |
| | tokenize=False, |
| | ) |
| | return chat_text |
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class CategoriesGenConfig:
    """Generation settings for the category-classification model."""

    max_new_tokens: int = 128      # token budget for the answer
    temperature: float = 0.0       # 0.0 => greedy decoding
    no_repeat_ngram_size: int = 6  # discourage repeated category runs
    repetition_penalty: float = 1.1
    max_length: int = 4096         # hard cap on prompt + answer tokens
    max_side: int = 768            # images are downscaled to this max side (px)
| |
|
# Fixed classification prompt. The numbered CATEGORY DICTIONARY below is also
# parsed by _extract_categories_from_prompt() to build the allowed-category
# set, so its exact "N. name" formatting is load-bearing — keep it verbatim.
CATEGORIES_DEFAULT_PROMPT = (
    "Here is a picture showing some food waste.\n\n"
    "Task: Provide a list of the food waste items visible in the picture.\n"
    "For each item, output EXACTLY one CATEGORY chosen from the CATEGORY DICTIONARY below.\n\n"
    "Rules:\n"
    "- Output must be a single line (no extra text, no markdown).\n"
    "- Use the exact spelling of CATEGORY as listed.\n"
    "- If you are unsure, choose the closest broader category (e.g., 'fruit', 'vegetable', 'food_waste').\n\n"
    "CATEGORY DICTIONARY:\n"
    "1. beer\n"
    "2. cabbage\n"
    "3. whipped_porridge\n"
    "4. cider\n"
    "5. orange\n"
    "6. dairy_product_milk\n"
    "7. yoghurt\n"
    "8. herbs\n"
    "9. potato_product\n"
    "10. strawberry\n"
    "11. minced_chicken\n"
    "12. soda\n"
    "13. sauce_paste\n"
    "14. rice_porridge\n"
    "15. blueberry\n"
    "16. light_bread\n"
    "17. semolina_porridge\n"
    "18. egg\n"
    "19. crepes\n"
    "20. sliced_cheese\n"
    "21. chicken\n"
    "22. food_waste\n"
    "23. candy\n"
    "24. vegetarian_sauce\n"
    "25. cheese\n"
    "26. cereal\n"
    "27. pork_product\n"
    "28. vegetable\n"
    "29. honey\n"
    "30. plant_based_cream\n"
    "31. vegetarian_pizza\n"
    "32. dried fruits\n"
    "33. rice_or_pasta\n"
    "34. pork_steak\n"
    "35. wine\n"
    "36. soft_drink\n"
    "37. minced_beef\n"
    "38. steak\n"
    "39. frankfurter\n"
    "40. vegetarian_soup\n"
    "41. beef_frankfurter\n"
    "42. alcoholic_long_drink\n"
    "43. raspberry\n"
    "44. rice\n"
    "45. mandarin\n"
    "46. juice\n"
    "47. porridge\n"
    "48. fish_soup\n"
    "49. fish_hamburger\n"
    "50. fruit\n"
    "51. coffee\n"
    "52. plant_based_ice_cream\n"
    "53. beef_sausage\n"
    "54. minced_pork\n"
    "55. meat\n"
    "56. sweet_pastry\n"
    "57. ice_cream\n"
    "58. fish_salad\n"
    "59. flour\n"
    "60. snack\n"
    "61. fish_product\n"
    "62. cherry\n"
    "63. shellfish\n"
    "64. cream\n"
    "65. dessert\n"
    "66. cold_cut_chicken\n"
    "67. onion\n"
    "68. dark_bread\n"
    "69. plant_based_milk\n"
    "70. fermented_milk\n"
    "71. cocoa\n"
    "72. bell_pepper\n"
    "73. beef\n"
    "74. sausage\n"
    "75. dairy_product\n"
    "76. pork\n"
    "77. fish_fillet\n"
    "78. strong_alcoholic_beverage\n"
    "79. biscuit\n"
    "80. currant\n"
    "81. sweet_soup\n"
    "82. popcorn\n"
    "83. meat_sauce\n"
    "84. meat_pizza\n"
    "85. lemon\n"
    "86. plum\n"
    "87. grain_products\n"
    "88. sugar_honey_syrup\n"
    "89. vegetarian_hamburger\n"
    "90. jam\n"
    "91. vegetarian_stew\n"
    "92. beef_steak\n"
    "93. tomato\n"
    "94. meat_soup\n"
    "95. block_cheese\n"
    "96. potato_based\n"
    "97. potato\n"
    "98. flakes\n"
    "99. soft_cheese\n"
    "100. beef_product\n"
    "101. chocolate\n"
    "102. cauliflower\n"
    "103. nectarine\n"
    "104. banana\n"
    "105. apple\n"
    "106. alcoholic_beverage\n"
    "107. meat_salad\n"
    "108. pork_sausage\n"
    "109. lingonberry\n"
    "110. chicken_sausage\n"
    "111. pork_frankfurter\n"
    "112. minced_meat\n"
    "113. baked_goods\n"
    "114. broccoli\n"
    "115. cold_cut_beef\n"
    "116. meat_dish\n"
    "117. plant_cheese\n"
    "118. butter\n"
    "119. children_milk\n"
    "120. fish_sauce\n"
    "121. buttermilk\n"
    "122. cucumber\n"
    "123. meat_hamburger\n"
    "124. pasta\n"
    "125. tea\n"
    "126. berries\n"
    "127. milk\n"
    "128. fish_pizza\n"
    "129. fresh_cheese\n"
    "130. fish_stew\n"
    "131. sirup\n"
    "132. vegetarian_salad\n"
    "133. spice\n"
    "134. quark\n"
    "135. cold_cut_pork\n"
    "136. sauce_seasoning\n"
    "137. oatmeal\n"
    "138. sugar\n"
    "139. meat_product\n"
    "140. vegetarian_dish\n"
    "141. savory_pastry\n"
    "142. chicken_product\n"
    "143. vegetable_mix\n"
    "144. potato_chip\n"
    "145. muesli\n"
    "146. carrot\n"
    "147. plant_based_yogurt\n"
    "148. sweet\n"
    "149. eggs\n"
    "150. chicken_frankfurter\n"
    "151. bread\n"
    "152. fish_strips\n"
    "153. lettuce\n"
    "154. fish\n"
    "155. meat_stew\n"
    "156. hot_beverage\n"
    "157. cherry tomato\n"
    "158. fish_dish\n"
    "159. fat_oil\n"
    "160. grape\n"
    "161. plant_protein\n"
    "162. pastry\n"
    "163. oil\n"
    "164. cold_cut_meat\n\n"
    "OUTPUT FORMAT (single line):\n"
    "CATEGORY1,CATEGORY2,CATEGORY3\n"
)
| |
|
| | _ALLOWED_RE = re.compile(r"[^a-z0-9_\(\);,\|\/\-\s\?\.\:]") |
| | _CATEGORY_LINE_RE = re.compile(r"^\s*\d+\.\s*(.+?)\s*$", flags=re.MULTILINE) |
| |
|
| | ITEM_RE = re.compile(r"\(\s*(.*?)\s*\)\s*\|\s*(\d+)", flags=re.DOTALL) |
| |
|
| | def _extract_categories_from_prompt(prompt: str) -> Set[str]: |
| | cats = {m.group(1).strip().lower() for m in _CATEGORY_LINE_RE.finditer(prompt)} |
| | return {c for c in cats if c} |
| |
|
def _clean_model_output(s: str) -> str:
    """Lower-case, normalize and strip disallowed characters from model text."""
    lowered = _clean_text(s).lower()
    return _ALLOWED_RE.sub("", lowered)
| |
|
def _parse_and_validate_categories(raw: str, allowed: Set[str]) -> List[str]:
    """Split *raw* on commas, keeping only known categories (deduped, order kept)."""
    cleaned = _clean_model_output(raw)
    candidates = (piece.strip() for piece in cleaned.split(","))
    result: List[str] = []
    seen: Set[str] = set()
    for candidate in candidates:
        if candidate and candidate in allowed and candidate not in seen:
            seen.add(candidate)
            result.append(candidate)
    return result
| |
|
| | |
| | |
| | |
| |
|
| | ANS_START = "<answers>" |
| |
|
@dataclass
class WeightGenConfig:
    """Generation settings for the weight-estimation model."""

    max_new_tokens: int = 8        # a few digits is all the answer needs
    temperature: float = 0.0       # 0.0 => greedy decoding
    no_repeat_ngram_size: int = 0  # disabled: digits legitimately repeat
    repetition_penalty: float = 1.0
    max_length: int = 3000         # hard cap on prompt + answer tokens
    max_side: int = 768            # images are downscaled to this max side (px)
    max_w: int = 5000              # upper bound (grams) for constrained decoding
| |
|
# Prompt template for the weight model; {SUB_CATEGORY} is filled per item via
# str.format in _predict_weight. Kept verbatim — this text is model input.
WEIGHT_DEFAULT_PROMPT = """
Here is a photo of some food waste.

Your task is to provide a weight estimate (in grams) for the following food waste:
{SUB_CATEGORY}

While providing your weight estimate, please keep in mind the following:
*** there might be other food items on the photo, I want you to provide the weight estimate for {SUB_CATEGORY} only!
*** make sure to provide your best guess, never say 0!

OUTPUT (exactly one line, no extra text):
WEIGHT
""".strip()
| |
|
# Matches the first contiguous run of ASCII digits.
_INT_RE = re.compile(r"[0-9]+")


def _extract_int_digits_only(text: str, default: int = 1) -> int:
    """Extract the first integer found in *text*, floored at 1.

    Fixes over the previous version:
    - It concatenated every digit in the string, merging separate numbers
      ("10 to 20" -> 1020); we now take the first contiguous digit run.
    - Its ``_INT_RE.search`` fallback was unreachable (if any digit existed it
      returned early; if none did, the regex could never match either).
    - The digit path could return 0 despite the "never say 0" contract; the
      result is now always >= 1, matching the fallback's floor.

    Args:
        text: raw model output (expected to be digits, but defended anyway).
        default: value used when *text* contains no digits.

    Returns:
        The parsed weight in grams, always at least 1.
    """
    match = _INT_RE.search(_clean_text(text))
    if match is None:
        return max(1, default)
    return max(1, int(match.group(0)))
| |
|
| | def _resolve_sub_category(inputs: Dict[str, Any]) -> str: |
| | subcat = inputs.get("sub_category") |
| | if not subcat or not str(subcat).strip(): |
| | raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.") |
| | return str(subcat).strip() |
| |
|
| | def _clamp_w(w: int, lo: int, hi: int) -> int: |
| | return max(lo, min(hi, int(w))) |
| |
|
| | _TRIE_END = "_end_" |
| |
|
| | def _build_trie(token_seqs): |
| | trie = {} |
| | for seq in token_seqs: |
| | node = trie |
| | for tok in seq: |
| | node = node.setdefault(tok, {}) |
| | node[_TRIE_END] = True |
| | return trie |
| |
|
| | def _trie_node_for_prefix(trie, prefix): |
| | node = trie |
| | for tok in prefix: |
| | node = node.get(tok) |
| | if node is None: |
| | return None |
| | return node |
| |
|
class RangeConstraint:
    """Constrains generation to the digits of an integer in [min_w, max_w].

    A trie over the token sequences of every admissible integer backs a
    ``prefix_allowed_tokens_fn`` compatible with ``model.generate``.
    ``prompt_len`` must be set to the current prompt length before each
    generate call so generated ids can be isolated from the prompt.
    """

    def __init__(self, tokenizer, min_w=1, max_w=1200):
        self.tok = tokenizer
        # Tokenize every admissible integer once up front (O(max_w) work).
        seqs = [
            tokenizer.encode(str(w), add_special_tokens=False)
            for w in range(min_w, max_w + 1)
        ]
        self.trie = _build_trie(seqs)
        self.prompt_len = None

    def prefix_allowed_tokens_fn(self, batch_id, input_ids):
        # generate() may hand us a 2-D batch or a single 1-D row.
        row = input_ids[batch_id] if getattr(input_ids, "dim", lambda: 1)() == 2 else input_ids
        generated = row[self.prompt_len :].tolist()
        node = _trie_node_for_prefix(self.trie, generated)
        if node is None:
            # Dead prefix: only EOS remains, forcing generation to stop.
            return [self.tok.eos_token_id]
        allowed = [token for token in node if token != _TRIE_END]
        if node.get(_TRIE_END) and self.tok.eos_token_id is not None:
            allowed.append(self.tok.eos_token_id)
        return allowed
| |
|
| | |
| | |
| | |
| |
|
class EndpointHandler:
    """Inference endpoint handler.

    Pipeline: decode the incoming image, classify food-waste categories with
    one Pixtral/Llava model, then estimate a per-category weight (grams) with
    a second model whose decoding is constrained to integers.
    """

    def _load_processor_and_model(self, repo: str):
        """Load processor + Llava model from *repo* onto ``self.device``.

        Uses bfloat16 when CUDA is available, float32 otherwise.
        """
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        processor = AutoProcessor.from_pretrained(
            repo,
            trust_remote_code=True,
        )
        model = LlavaForConditionalGeneration.from_pretrained(
            repo,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map={"": self.device},
            trust_remote_code=True,
        )
        model.eval()
        logger.info("Model and processor successfully loaded from '%s'.", repo)
        return processor, model

    @staticmethod
    def eos_and_pad_token_id(processor, model) -> tuple[List[int], int]:
        """Resolve the EOS token ids and a pad token id for generation.

        Ensures the tokenizer has a pad token (falling back to EOS), gathers
        EOS candidates from both the model config and the tokenizer, then
        picks a pad id (model config, then tokenizer, then an EOS id).

        Returns:
            (eos_token_ids, pad_token_id).

        Raises:
            ValueError: if no EOS token id is found anywhere.
        """
        tokenizer = getattr(processor, "tokenizer", None)
        if tokenizer is not None and tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        eos_candidates: List[int] = []
        if model.config.eos_token_id is not None:
            eos_candidates.append(model.config.eos_token_id)
        if tokenizer is not None and tokenizer.eos_token_id is not None:
            eos_candidates.append(tokenizer.eos_token_id)

        # NOTE(review): the set dedup below has arbitrary ordering, so the
        # eos_token_ids[0] pad fallback is not guaranteed to be the
        # model-config EOS id — confirm this is acceptable.
        eos_token_ids: List[int] = list({i for i in eos_candidates})
        if not eos_token_ids:
            raise ValueError("No EOS token id found on model or tokenizer.")

        pad_id: Optional[int] = getattr(model.config, "pad_token_id", None)
        if pad_id is None and tokenizer is not None:
            pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = eos_token_ids[0]

        return eos_token_ids, pad_id

    def __init__(self, path: str = ".") -> None:
        """Load both models and precompute generation settings.

        *path* is accepted for Hugging Face endpoint compatibility but unused.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing EndpointHandler on device: %s", self.device)

        # Category-classification model.
        self.cat_processor, self.cat_model = self._load_processor_and_model(CATEGORIES_REPO)

        self.cat_eos_token_ids, self.cat_pad_token_id = self.eos_and_pad_token_id(
            self.cat_processor,
            self.cat_model,
        )
        self.cat_gen_config = CategoriesGenConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.cat_gen_config.max_new_tokens,
            self.cat_gen_config.temperature,
        )

        # Weight-estimation model.
        self.weight_processor, self.weight_model = self._load_processor_and_model(WEIGHT_REPO)
        self.weight_eos_token_ids, self.weight_pad_token_id = self.eos_and_pad_token_id(
            self.weight_processor,
            self.weight_model,
        )
        self.weight_gen_config = WeightGenConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.weight_gen_config.max_new_tokens,
            self.weight_gen_config.temperature,
        )
        # NOTE(review): this constraint object is shared across requests and
        # its prompt_len is mutated per call in _predict_weight — looks unsafe
        # under concurrent requests; confirm the endpoint is single-threaded.
        self.range_constraint = RangeConstraint(self.weight_processor.tokenizer, min_w=1, max_w=self.weight_gen_config.max_w)

        # The allowed-category set comes from the prompt's dictionary section.
        self.default_allowed_categories: Set[str] = _extract_categories_from_prompt(CATEGORIES_DEFAULT_PROMPT)
        logger.info(
            "Extracted %d categories from CATEGORIES_DEFAULT_PROMPT.",
            len(self.default_allowed_categories),
        )

    def _predict_categories(self, image: Image.Image, debug: bool) -> List[str]:
        """Run the classification model and return validated category names.

        Falls back to ["food_waste"] when nothing in the model output matches
        the allowed set (or to [] if even that category is unavailable).
        """
        max_length = int(self.cat_gen_config.max_length)
        max_side = int(self.cat_gen_config.max_side)
        image = _resize_max_side(image, max_side=max_side)

        max_new_tokens = int(self.cat_gen_config.max_new_tokens)
        temperature = float(self.cat_gen_config.temperature)

        chat_text = _build_chat_text(CATEGORIES_DEFAULT_PROMPT, self.cat_processor)

        enc = self.cat_processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=False,
            padding=False,
        )

        # Fail fast if the (fixed) prompt leaves no room for an answer.
        prompt_len = int(enc["input_ids"].shape[1])
        if prompt_len >= max_length:
            raise ValueError(
                f"Prompt too long ({prompt_len} tokens) for max_length={max_length}. "
                "Increase max_length or shorten the prompt."
            )

        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(max_new_tokens, tokens_left)

        if debug:
            logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
            logger.info("img size: %s", getattr(image, "size", None))
            logger.info("max_length: %d", max_length)
            logger.info("prompt_len: %d", prompt_len)
            logger.info("tokens_left_for_answer(approx): %d", tokens_left)
            tok = getattr(self.cat_processor, "tokenizer", None)
            if tok is not None:
                ids = enc["input_ids"][0].tolist()
                logger.info("[prompt head]\n%s", tok.decode(ids[:120], skip_special_tokens=False))
                logger.info("[prompt tail]\n%s", tok.decode(ids[-120:], skip_special_tokens=False))
            logger.info("=========================================")

        # Move tensors to the device; pixel values also need the model dtype.
        enc = {k: v.to(self.device) for k, v in enc.items()}
        if "pixel_values" in enc:
            enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.cat_model.dtype)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": self.cat_eos_token_ids,
            "pad_token_id": self.cat_pad_token_id,
            "no_repeat_ngram_size": self.cat_gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.cat_gen_config.repetition_penalty,
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature

        with torch.inference_mode():
            output_ids = self.cat_model.generate(**enc, **gen_kwargs)

        # Decode only the newly generated tokens (strip the echoed prompt).
        generated_only = output_ids[:, enc["input_ids"].shape[1] :]
        generated_text = self.cat_processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
        )[0].strip()

        cats = _parse_and_validate_categories(generated_text, self.default_allowed_categories)
        if not cats:
            cats = ["food_waste"] if "food_waste" in self.default_allowed_categories else []
        return cats

    def _predict_weight(self, image: Image.Image, sub_category: str, debug: bool) -> int:
        """Estimate the weight (grams) of *sub_category* visible in *image*.

        Decoding is constrained to integers in [1, max_w]; the parsed result
        is clamped into that range again as a final safety net.
        """
        max_w = int(self.weight_gen_config.max_w)
        prompt = WEIGHT_DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category).rstrip()
        prompt += (
            "\n\nRules:\n"
            "- output digits only (no units, no text)\n"
            f"- output an integer between 1 and {self.weight_gen_config.max_w}\n"
            "- never output 0\n"
            f"\n{ANS_START}"
        )

        max_length = int(self.weight_gen_config.max_length)
        max_side = int(self.weight_gen_config.max_side)
        image = _resize_max_side(image, max_side=max_side)

        max_new_tokens = int(self.weight_gen_config.max_new_tokens)
        temperature = float(self.weight_gen_config.temperature)

        chat_text = _build_chat_text(prompt, self.weight_processor)

        enc = self.weight_processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=False,
        )

        # The range constraint needs the prompt length to isolate generated ids.
        prompt_len = int(enc["input_ids"].shape[1])
        self.range_constraint.prompt_len = prompt_len
        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(max_new_tokens, tokens_left)

        if debug:
            logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
            logger.info("sub_category: %s", sub_category)
            logger.info("img size: %s", getattr(image, "size", None))
            logger.info("max_length: %d", max_length)
            logger.info("prompt_len: %d", prompt_len)
            logger.info("tokens_left_for_answer(approx): %d", tokens_left)
            logger.info("=========================================")

        # Move tensors to the device; pixel values also need the model dtype.
        enc = {k: v.to(self.device) for k, v in enc.items()}
        if "pixel_values" in enc:
            enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.weight_model.dtype)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": self.weight_eos_token_ids,
            "pad_token_id": self.weight_pad_token_id,
            "no_repeat_ngram_size": self.weight_gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.weight_gen_config.repetition_penalty,
            "prefix_allowed_tokens_fn": self.range_constraint.prefix_allowed_tokens_fn
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature

        with torch.inference_mode():
            output_ids = self.weight_model.generate(**enc, **gen_kwargs)

        # Decode only the newly generated tokens (strip the echoed prompt).
        generated_only = output_ids[:, enc["input_ids"].shape[1] :]
        generated_text_raw = self.weight_processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0].strip()

        weight_g = _extract_int_digits_only(generated_text_raw, default=1)
        weight_g = _clamp_w(weight_g, lo=1, hi=max_w)

        if debug:
            logger.info("Generated raw: %s", generated_text_raw)
            logger.info("Final weight_g: %d", weight_g)

        return weight_g

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Endpoint entry point.

        Expects ``data["inputs"]`` (or ``data`` itself) with:
          - "image": base64-encoded image (required)
          - "debug": optional bool enabling verbose logging

        Returns a dict with the detected categories, per-item weights, a
        per-category weight map, and the total weight in grams.
        """
        inputs = data.get("inputs", data)
        debug = bool(inputs.get("debug", False))

        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")

        image = _decode_image(image_b64)

        # One weight prediction per detected category.
        cats = self._predict_categories(image, debug=debug)
        items: List[Dict[str, Any]] = []
        weights_by_category: Dict[str, int] = {}
        total_weight_g = 0
        for cat in cats:
            w_pred = int(self._predict_weight(image, sub_category=cat, debug=debug))
            items.append({"category": cat, "weight_g": w_pred})
            weights_by_category[cat] = w_pred
            total_weight_g += w_pred

        if debug:
            logger.info("Predicted categories: %s", cats)
            logger.info("Weights by category: %s", weights_by_category)
            logger.info("Total weight (g): %d", total_weight_g)

        return {
            "categories": cats,
            "items": items,
            "weights": weights_by_category,
            "total_weight_g": total_weight_g,
        }
| | |
| |
|