Update handler.py
Browse files- handler.py +426 -184
handler.py
CHANGED
|
@@ -1,184 +1,426 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import base64
|
| 4 |
-
from
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
import
|
| 9 |
-
from
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
"
|
| 24 |
-
"
|
| 25 |
-
"
|
| 26 |
-
"
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
from curses import raw
|
| 5 |
+
import re
|
| 6 |
+
import unicodedata
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
from typing import Any, Dict, Optional, List, Set
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
| 14 |
+
from transformers.utils import logging
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.get_logger(__name__)
|
| 18 |
+
logging.set_verbosity_info()
|
| 19 |
+
|
| 20 |
+
DEFAULT_PROMPT = (
|
| 21 |
+
"Here is a picture showing some food waste.\n\n"
|
| 22 |
+
"Task: Provide a list of the food waste items visible in the picture.\n"
|
| 23 |
+
"For each item, output EXACTLY one CATEGORY chosen from the CATEGORY DICTIONARY below.\n\n"
|
| 24 |
+
"Rules:\n"
|
| 25 |
+
"- Output must be a single line (no extra text, no markdown).\n"
|
| 26 |
+
"- Use the exact spelling of CATEGORY as listed.\n"
|
| 27 |
+
"- If you are unsure, choose the closest broader category (e.g., 'fruit', 'vegetable', 'food_waste').\n\n"
|
| 28 |
+
"CATEGORY DICTIONARY:\n"
|
| 29 |
+
"1. cucumber\n"
|
| 30 |
+
"2. tomato\n"
|
| 31 |
+
"3. cherry tomato\n"
|
| 32 |
+
"4. carrot\n"
|
| 33 |
+
"5. lettuce\n"
|
| 34 |
+
"6. bell_pepper\n"
|
| 35 |
+
"7. cabbage\n"
|
| 36 |
+
"8. cauliflower\n"
|
| 37 |
+
"9. broccoli\n"
|
| 38 |
+
"10. onion\n"
|
| 39 |
+
"11. herbs\n"
|
| 40 |
+
"12. vegetable_mix\n"
|
| 41 |
+
"13. vegetable\n"
|
| 42 |
+
"14. potato\n"
|
| 43 |
+
"15. potato_product\n"
|
| 44 |
+
"16. potato_based\n"
|
| 45 |
+
"17. banana\n"
|
| 46 |
+
"18. apple\n"
|
| 47 |
+
"19. mandarin\n"
|
| 48 |
+
"20. orange\n"
|
| 49 |
+
"21. lemon\n"
|
| 50 |
+
"22. grape\n"
|
| 51 |
+
"23. plum\n"
|
| 52 |
+
"24. nectarine\n"
|
| 53 |
+
"25. dried fruits\n"
|
| 54 |
+
"26. jam\n"
|
| 55 |
+
"27. cherry\n"
|
| 56 |
+
"28. fruit\n"
|
| 57 |
+
"29. blueberry\n"
|
| 58 |
+
"30. strawberry\n"
|
| 59 |
+
"31. raspberry\n"
|
| 60 |
+
"32. currant\n"
|
| 61 |
+
"33. lingonberry\n"
|
| 62 |
+
"34. berries\n"
|
| 63 |
+
"35. rice\n"
|
| 64 |
+
"36. pasta\n"
|
| 65 |
+
"37. rice_or_pasta\n"
|
| 66 |
+
"38. dark_bread\n"
|
| 67 |
+
"39. light_bread\n"
|
| 68 |
+
"40. bread\n"
|
| 69 |
+
"41. oatmeal\n"
|
| 70 |
+
"42. semolina_porridge\n"
|
| 71 |
+
"43. rice_porridge\n"
|
| 72 |
+
"44. whipped_porridge\n"
|
| 73 |
+
"45. porridge\n"
|
| 74 |
+
"46. savory_pastry\n"
|
| 75 |
+
"47. sweet_pastry\n"
|
| 76 |
+
"48. biscuit\n"
|
| 77 |
+
"49. baked_goods\n"
|
| 78 |
+
"50. flour\n"
|
| 79 |
+
"51. flakes\n"
|
| 80 |
+
"52. muesli\n"
|
| 81 |
+
"53. cereal\n"
|
| 82 |
+
"54. grain_products\n"
|
| 83 |
+
"55. milk\n"
|
| 84 |
+
"56. plant_based_milk\n"
|
| 85 |
+
"57. children_milk\n"
|
| 86 |
+
"58. dairy_product_milk\n"
|
| 87 |
+
"59. block_cheese\n"
|
| 88 |
+
"60. sliced_cheese\n"
|
| 89 |
+
"61. fresh_cheese\n"
|
| 90 |
+
"62. soft_cheese\n"
|
| 91 |
+
"63. plant_cheese\n"
|
| 92 |
+
"64. cheese\n"
|
| 93 |
+
"65. yoghurt\n"
|
| 94 |
+
"66. buttermilk\n"
|
| 95 |
+
"67. quark\n"
|
| 96 |
+
"68. fermented_milk\n"
|
| 97 |
+
"69. cream\n"
|
| 98 |
+
"70. ice_cream\n"
|
| 99 |
+
"71. plant_based_yogurt\n"
|
| 100 |
+
"72. plant_based_cream\n"
|
| 101 |
+
"73. plant_based_ice_cream\n"
|
| 102 |
+
"74. dairy_product\n"
|
| 103 |
+
"75. egg\n"
|
| 104 |
+
"76. eggs\n"
|
| 105 |
+
"77. beef_steak\n"
|
| 106 |
+
"78. minced_beef\n"
|
| 107 |
+
"79. cold_cut_beef\n"
|
| 108 |
+
"80. beef_sausage\n"
|
| 109 |
+
"81. beef_frankfurter\n"
|
| 110 |
+
"82. beef_product\n"
|
| 111 |
+
"83. beef\n"
|
| 112 |
+
"84. pork_steak\n"
|
| 113 |
+
"85. minced_pork\n"
|
| 114 |
+
"86. cold_cut_pork\n"
|
| 115 |
+
"87. pork_sausage\n"
|
| 116 |
+
"88. pork_frankfurter\n"
|
| 117 |
+
"89. pork_product\n"
|
| 118 |
+
"90. pork\n"
|
| 119 |
+
"91. chicken\n"
|
| 120 |
+
"92. minced_chicken\n"
|
| 121 |
+
"93. cold_cut_chicken\n"
|
| 122 |
+
"94. chicken_sausage\n"
|
| 123 |
+
"95. chicken_frankfurter\n"
|
| 124 |
+
"96. chicken_product\n"
|
| 125 |
+
"97. steak\n"
|
| 126 |
+
"98. minced_meat\n"
|
| 127 |
+
"99. cold_cut_meat\n"
|
| 128 |
+
"100. sausage\n"
|
| 129 |
+
"101. frankfurter\n"
|
| 130 |
+
"102. meat_product\n"
|
| 131 |
+
"103. plant_protein\n"
|
| 132 |
+
"104. meat\n"
|
| 133 |
+
"105. fish_fillet\n"
|
| 134 |
+
"106. fish_strips\n"
|
| 135 |
+
"107. fish_product\n"
|
| 136 |
+
"108. shellfish\n"
|
| 137 |
+
"109. fish\n"
|
| 138 |
+
"110. meat_soup\n"
|
| 139 |
+
"111. meat_stew\n"
|
| 140 |
+
"112. meat_sauce\n"
|
| 141 |
+
"113. meat_pizza\n"
|
| 142 |
+
"114. meat_hamburger\n"
|
| 143 |
+
"115. meat_salad\n"
|
| 144 |
+
"116. meat_dish\n"
|
| 145 |
+
"117. fish_soup\n"
|
| 146 |
+
"118. fish_stew\n"
|
| 147 |
+
"119. fish_sauce\n"
|
| 148 |
+
"120. fish_pizza\n"
|
| 149 |
+
"121. fish_hamburger\n"
|
| 150 |
+
"122. fish_salad\n"
|
| 151 |
+
"123. fish_dish\n"
|
| 152 |
+
"124. vegetarian_soup\n"
|
| 153 |
+
"125. vegetarian_stew\n"
|
| 154 |
+
"126. vegetarian_sauce\n"
|
| 155 |
+
"127. vegetarian_pizza\n"
|
| 156 |
+
"128. vegetarian_hamburger\n"
|
| 157 |
+
"129. vegetarian_salad\n"
|
| 158 |
+
"130. vegetarian_dish\n"
|
| 159 |
+
"131. pastry\n"
|
| 160 |
+
"132. sweet_soup\n"
|
| 161 |
+
"133. quark\n"
|
| 162 |
+
"134. ice_cream\n"
|
| 163 |
+
"135. crepes\n"
|
| 164 |
+
"136. jam\n"
|
| 165 |
+
"137. dessert\n"
|
| 166 |
+
"138. chocolate\n"
|
| 167 |
+
"139. candy\n"
|
| 168 |
+
"140. sweet\n"
|
| 169 |
+
"141. popcorn\n"
|
| 170 |
+
"142. potato_chip\n"
|
| 171 |
+
"143. snack\n"
|
| 172 |
+
"144. sauce_paste\n"
|
| 173 |
+
"145. spice\n"
|
| 174 |
+
"146. sauce_seasoning\n"
|
| 175 |
+
"147. butter\n"
|
| 176 |
+
"148. oil\n"
|
| 177 |
+
"149. fat_oil\n"
|
| 178 |
+
"150. juice\n"
|
| 179 |
+
"151. soda\n"
|
| 180 |
+
"152. soft_drink\n"
|
| 181 |
+
"153. coffee\n"
|
| 182 |
+
"154. tea\n"
|
| 183 |
+
"155. cocoa\n"
|
| 184 |
+
"156. hot_beverage\n"
|
| 185 |
+
"157. beer\n"
|
| 186 |
+
"158. cider\n"
|
| 187 |
+
"159. alcoholic_long_drink\n"
|
| 188 |
+
"160. wine\n"
|
| 189 |
+
"161. strong_alcoholic_beverage\n"
|
| 190 |
+
"162. alcoholic_beverage\n"
|
| 191 |
+
"163. sugar\n"
|
| 192 |
+
"164. honey\n"
|
| 193 |
+
"165. sirup\n"
|
| 194 |
+
"166. sugar_honey_syrup\n"
|
| 195 |
+
"167. food_waste\n\n"
|
| 196 |
+
"OUTPUT FORMAT (single line):\n"
|
| 197 |
+
"CATEGORY1,CATEGORY2,CATEGORY3\n"
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
@dataclass
|
| 202 |
+
class GenerationConfig:
|
| 203 |
+
|
| 204 |
+
max_new_tokens: int = 256
|
| 205 |
+
temperature: float = 0.0
|
| 206 |
+
no_repeat_ngram_size: int = 6
|
| 207 |
+
repetition_penalty: float = 1.1
|
| 208 |
+
max_length: int = 4096
|
| 209 |
+
max_side: int = 512
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
_ALLOWED_RE = re.compile(r"[^a-z0-9_\(\);,\|\/\-\s\?\.\:]")
|
| 213 |
+
|
| 214 |
+
ITEM_RE = re.compile(r"\(\s*(.*?)\s*\)\s*\|\s*(\d+)", flags=re.DOTALL)
|
| 215 |
+
|
| 216 |
+
def _clean_text(s: str) -> str:
|
| 217 |
+
s = unicodedata.normalize("NFKC", s).replace("\u00A0", " ")
|
| 218 |
+
s = re.sub(r"[\u200B-\u200F]", "", s)
|
| 219 |
+
s = re.sub(r"\s+", " ", s).strip()
|
| 220 |
+
return s
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
_CATEGORY_LINE_RE = re.compile(r"^\s*\d+\.\s*(.+?)\s*$", flags=re.MULTILINE)
|
| 224 |
+
|
| 225 |
+
def _extract_categories_from_prompt(prompt: str) -> Set[str]:
|
| 226 |
+
cats = {m.group(1).strip().lower() for m in _CATEGORY_LINE_RE.finditer(prompt)}
|
| 227 |
+
return {c for c in cats if c}
|
| 228 |
+
|
| 229 |
+
def _clean_model_output(s: str) -> str:
|
| 230 |
+
s = _clean_text(s).lower()
|
| 231 |
+
s = _ALLOWED_RE.sub("", s)
|
| 232 |
+
return s
|
| 233 |
+
|
| 234 |
+
def _parse_and_validate_categories(raw: str, allowed: Set[str]) -> List[str]:
|
| 235 |
+
s = _clean_model_output(raw)
|
| 236 |
+
parts = [p.strip() for p in s.split(",") if p.strip()]
|
| 237 |
+
out: List[str] = []
|
| 238 |
+
seen: Set[str] = set()
|
| 239 |
+
for p in parts:
|
| 240 |
+
if p in allowed and p not in seen:
|
| 241 |
+
out.append(p)
|
| 242 |
+
seen.add(p)
|
| 243 |
+
return out
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class EndpointHandler:
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def __init__(self, path: str = ".") -> None:
|
| 250 |
+
"""
|
| 251 |
+
Initializes the model and processor from the `path` directory,
|
| 252 |
+
which contains the merged weights (pixtral-12b-foodwaste-merged).
|
| 253 |
+
"""
|
| 254 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 255 |
+
logger.info("Initializing EndpointHandler on device: %s", self.device)
|
| 256 |
+
self.processor = AutoProcessor.from_pretrained(
|
| 257 |
+
path,
|
| 258 |
+
trust_remote_code=True,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
| 262 |
+
self.model = LlavaForConditionalGeneration.from_pretrained(
|
| 263 |
+
path,
|
| 264 |
+
torch_dtype=dtype,
|
| 265 |
+
low_cpu_mem_usage=True,
|
| 266 |
+
device_map={"": self.device},
|
| 267 |
+
trust_remote_code=True,
|
| 268 |
+
)
|
| 269 |
+
self.model.eval()
|
| 270 |
+
logger.info("Model and processor successfully loaded from '%s'.", path)
|
| 271 |
+
|
| 272 |
+
# pad token management
|
| 273 |
+
tokenizer = getattr(self.processor, "tokenizer", None)
|
| 274 |
+
if tokenizer is not None and tokenizer.pad_token_id is None:
|
| 275 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 276 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 277 |
+
|
| 278 |
+
# Preparation of EOS/PAD IDs for generate
|
| 279 |
+
eos_candidates: List[int] = []
|
| 280 |
+
if self.model.config.eos_token_id is not None:
|
| 281 |
+
eos_candidates.append(self.model.config.eos_token_id)
|
| 282 |
+
if tokenizer is not None and tokenizer.eos_token_id is not None:
|
| 283 |
+
eos_candidates.append(tokenizer.eos_token_id)
|
| 284 |
+
|
| 285 |
+
self.eos_token_ids: List[int] = list({i for i in eos_candidates})
|
| 286 |
+
if not self.eos_token_ids:
|
| 287 |
+
raise ValueError("No EOS token id found on model or tokenizer.")
|
| 288 |
+
|
| 289 |
+
pad_id: Optional[int] = getattr(self.model.config, "pad_token_id", None)
|
| 290 |
+
if pad_id is None and tokenizer is not None:
|
| 291 |
+
pad_id = tokenizer.pad_token_id
|
| 292 |
+
if pad_id is None:
|
| 293 |
+
pad_id = self.eos_token_ids[0]
|
| 294 |
+
|
| 295 |
+
self.pad_token_id: int = pad_id
|
| 296 |
+
|
| 297 |
+
self.gen_config = GenerationConfig()
|
| 298 |
+
logger.info(
|
| 299 |
+
"Generation config: max_new_tokens=%d, temperature=%.3f",
|
| 300 |
+
self.gen_config.max_new_tokens,
|
| 301 |
+
self.gen_config.temperature,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
self.default_allowed_categories: Set[str] = _extract_categories_from_prompt(DEFAULT_PROMPT)
|
| 305 |
+
logger.info("Extracted %d categories from DEFAULT_PROMPT.", len(self.default_allowed_categories))
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
@staticmethod
|
| 310 |
+
def _decode_image(image_b64: str) -> Image.Image:
|
| 311 |
+
|
| 312 |
+
try:
|
| 313 |
+
img_bytes = base64.b64decode(image_b64)
|
| 314 |
+
img = Image.open(BytesIO(img_bytes)).convert("RGB")
|
| 315 |
+
return img
|
| 316 |
+
except Exception as exc: # pragma: no cover - log production
|
| 317 |
+
raise ValueError(f"Could not decode base64 image: {exc}") from exc
|
| 318 |
+
|
| 319 |
+
@staticmethod
|
| 320 |
+
def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
|
| 321 |
+
w, h = img.size
|
| 322 |
+
m = max(w, h)
|
| 323 |
+
if m <= max_side:
|
| 324 |
+
return img
|
| 325 |
+
scale = max_side / m
|
| 326 |
+
# LANCZOS = better downscale
|
| 327 |
+
return img.resize((int(w * scale), int(h * scale)), resample=Image.Resampling.LANCZOS)
|
| 328 |
+
|
| 329 |
+
def _build_chat_text(self, prompt: str) -> str:
|
| 330 |
+
|
| 331 |
+
messages = [
|
| 332 |
+
{
|
| 333 |
+
"role": "user",
|
| 334 |
+
"content": [
|
| 335 |
+
{"type": "text", "text": prompt},
|
| 336 |
+
{"type": "image"},
|
| 337 |
+
],
|
| 338 |
+
}
|
| 339 |
+
]
|
| 340 |
+
|
| 341 |
+
chat_text = self.processor.apply_chat_template(
|
| 342 |
+
messages,
|
| 343 |
+
add_generation_prompt=True,
|
| 344 |
+
tokenize=False,
|
| 345 |
+
)
|
| 346 |
+
return chat_text
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 350 |
+
inputs = data.get("inputs", data)
|
| 351 |
+
debug = bool(inputs.get("debug", False))
|
| 352 |
+
|
| 353 |
+
prompt: str = inputs.get("prompt") or DEFAULT_PROMPT
|
| 354 |
+
allowed_categories = _extract_categories_from_prompt(prompt) or self.default_allowed_categories
|
| 355 |
+
image_b64: Optional[str] = inputs.get("image")
|
| 356 |
+
if not image_b64:
|
| 357 |
+
raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")
|
| 358 |
+
|
| 359 |
+
image = self._decode_image(image_b64)
|
| 360 |
+
|
| 361 |
+
max_length = int(inputs.get("max_length", self.gen_config.max_length))
|
| 362 |
+
max_side = int(inputs.get("max_side", self.gen_config.max_side))
|
| 363 |
+
image = self._resize_max_side(image, max_side=max_side)
|
| 364 |
+
|
| 365 |
+
max_new_tokens = int(inputs.get("max_new_tokens", self.gen_config.max_new_tokens))
|
| 366 |
+
temperature = float(inputs.get("temperature", self.gen_config.temperature))
|
| 367 |
+
|
| 368 |
+
chat_text = self._build_chat_text(prompt)
|
| 369 |
+
|
| 370 |
+
enc = self.processor(
|
| 371 |
+
text=[chat_text],
|
| 372 |
+
images=[image],
|
| 373 |
+
return_tensors="pt",
|
| 374 |
+
truncation=True,
|
| 375 |
+
max_length=max_length,
|
| 376 |
+
padding=False, # important for correct prompt_len
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
prompt_len = int(enc["input_ids"].shape[1])
|
| 380 |
+
tokens_left = max(1, max_length - prompt_len)
|
| 381 |
+
max_new_tokens = min(max_new_tokens, tokens_left)
|
| 382 |
+
|
| 383 |
+
if debug:
|
| 384 |
+
logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
|
| 385 |
+
logger.info("img size: %s", getattr(image, "size", None))
|
| 386 |
+
logger.info("max_length: %d", max_length)
|
| 387 |
+
logger.info("prompt_len: %d", prompt_len)
|
| 388 |
+
logger.info("tokens_left_for_answer(approx): %d", tokens_left)
|
| 389 |
+
tok = getattr(self.processor, "tokenizer", None)
|
| 390 |
+
if tok is not None:
|
| 391 |
+
ids = enc["input_ids"][0].tolist()
|
| 392 |
+
logger.info("[prompt head]\n%s", tok.decode(ids[:120], skip_special_tokens=False))
|
| 393 |
+
logger.info("[prompt tail]\n%s", tok.decode(ids[-120:], skip_special_tokens=False))
|
| 394 |
+
logger.info("=========================================")
|
| 395 |
+
|
| 396 |
+
enc = {k: v.to(self.device) for k, v in enc.items()}
|
| 397 |
+
if "pixel_values" in enc:
|
| 398 |
+
enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.model.dtype)
|
| 399 |
+
|
| 400 |
+
gen_kwargs: Dict[str, Any] = {
|
| 401 |
+
"max_new_tokens": max_new_tokens,
|
| 402 |
+
"do_sample": temperature > 0.0,
|
| 403 |
+
"eos_token_id": self.eos_token_ids,
|
| 404 |
+
"pad_token_id": self.pad_token_id,
|
| 405 |
+
"no_repeat_ngram_size": self.gen_config.no_repeat_ngram_size,
|
| 406 |
+
"repetition_penalty": self.gen_config.repetition_penalty,
|
| 407 |
+
}
|
| 408 |
+
if temperature > 0.0:
|
| 409 |
+
gen_kwargs["temperature"] = temperature
|
| 410 |
+
|
| 411 |
+
with torch.inference_mode():
|
| 412 |
+
output_ids = self.model.generate(**enc, **gen_kwargs)
|
| 413 |
+
|
| 414 |
+
generated_only = output_ids[:, enc["input_ids"].shape[1]:]
|
| 415 |
+
generated_text = self.processor.batch_decode(
|
| 416 |
+
generated_only,
|
| 417 |
+
skip_special_tokens=True,
|
| 418 |
+
)[0].strip()
|
| 419 |
+
|
| 420 |
+
cats = _parse_and_validate_categories(generated_text, allowed_categories)
|
| 421 |
+
if not cats:
|
| 422 |
+
cats = ["food_waste"] if "food_waste" in allowed_categories else []
|
| 423 |
+
generated_text = ",".join(cats)
|
| 424 |
+
|
| 425 |
+
logger.info("Generated text: %s", generated_text)
|
| 426 |
+
return {"generated_text": generated_text, "categories": cats}
|