Kalaoke committed on
Commit
d4884fa
·
verified ·
1 Parent(s): 7d66d82

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +426 -184
handler.py CHANGED
@@ -1,184 +1,426 @@
1
- from __future__ import annotations
2
-
3
- import base64
4
- from dataclasses import dataclass
5
- from io import BytesIO
6
- from typing import Any, Dict, Optional, List
7
-
8
- import torch
9
- from PIL import Image
10
- from transformers import AutoProcessor, LlavaForConditionalGeneration
11
- from transformers.utils import logging
12
-
13
-
14
- logger = logging.get_logger(__name__)
15
- logging.set_verbosity_info()
16
-
17
-
18
- BASE_MODEL_ID = "mistral-community/pixtral-12b"
19
-
20
-
21
- # Prompt par défaut (tu peux l’ajuster ici)
22
- DEFAULT_PROMPT = (
23
- "Here is a photo showing some food waste. "
24
- "Identify each type of food item and the corresponding weight in grams. "
25
- "Reply like: Milk, 120g; Coffee, 45g. "
26
- "Do not add any explanation, no extra text."
27
- )
28
-
29
-
30
- @dataclass
31
- class GenerationConfig:
32
-
33
- max_new_tokens: int = 64
34
- temperature: float = 0.0
35
- no_repeat_ngram_size: int = 6
36
- repetition_penalty: float = 1.1
37
-
38
-
39
- class EndpointHandler:
40
-
41
-
42
- def __init__(self, path: str = ".") -> None:
43
- """
44
- Initializes the model and processor from the `path` directory,
45
- which contains the merged weights (pixtral-12b-foodwaste-merged).
46
- """
47
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
48
- logger.info("Initializing EndpointHandler on device: %s", self.device)
49
-
50
- self.processor = AutoProcessor.from_pretrained(
51
- BASE_MODEL_ID,
52
- trust_remote_code=True,
53
- )
54
-
55
- dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
56
- self.model = LlavaForConditionalGeneration.from_pretrained(
57
- BASE_MODEL_ID,
58
- torch_dtype=dtype,
59
- low_cpu_mem_usage=True,
60
- device_map={"": self.device},
61
- trust_remote_code=True,
62
- )
63
- self.model.eval()
64
- logger.info("Model and processor successfully loaded from '%s'.", path)
65
-
66
- # pad token management
67
- tokenizer = getattr(self.processor, "tokenizer", None)
68
- if tokenizer is not None and tokenizer.pad_token_id is None:
69
- tokenizer.pad_token = tokenizer.eos_token
70
- tokenizer.pad_token_id = tokenizer.eos_token_id
71
-
72
- # Preparation of EOS/PAD IDs for generate
73
- eos_candidates: List[int] = []
74
- if self.model.config.eos_token_id is not None:
75
- eos_candidates.append(self.model.config.eos_token_id)
76
- if tokenizer is not None and tokenizer.eos_token_id is not None:
77
- eos_candidates.append(tokenizer.eos_token_id)
78
-
79
- self.eos_token_ids: List[int] = list({i for i in eos_candidates})
80
- if not self.eos_token_ids:
81
- raise ValueError("No EOS token id found on model or tokenizer.")
82
-
83
- pad_id: Optional[int] = getattr(self.model.config, "pad_token_id", None)
84
- if pad_id is None and tokenizer is not None:
85
- pad_id = tokenizer.pad_token_id
86
- if pad_id is None:
87
- pad_id = self.eos_token_ids[0]
88
-
89
- self.pad_token_id: int = pad_id
90
-
91
- self.gen_config = GenerationConfig()
92
- logger.info(
93
- "Generation config: max_new_tokens=%d, temperature=%.3f",
94
- self.gen_config.max_new_tokens,
95
- self.gen_config.temperature,
96
- )
97
-
98
-
99
- @staticmethod
100
- def _decode_image(image_b64: str) -> Image.Image:
101
-
102
- try:
103
- img_bytes = base64.b64decode(image_b64)
104
- img = Image.open(BytesIO(img_bytes)).convert("RGB")
105
- return img
106
- except Exception as exc: # pragma: no cover - log production
107
- raise ValueError(f"Could not decode base64 image: {exc}") from exc
108
-
109
- def _build_chat_text(self, prompt: str) -> str:
110
-
111
- messages = [
112
- {
113
- "role": "user",
114
- "content": [
115
- {"type": "text", "text": prompt},
116
- {"type": "image"},
117
- ],
118
- }
119
- ]
120
-
121
- chat_text = self.processor.apply_chat_template(
122
- messages,
123
- add_generation_prompt=True,
124
- tokenize=False,
125
- )
126
- return chat_text
127
-
128
-
129
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
130
-
131
- inputs = data.get("inputs", data)
132
-
133
- prompt: str = inputs.get("prompt") or DEFAULT_PROMPT
134
-
135
- image_b64: Optional[str] = inputs.get("image")
136
- if image_b64 is None:
137
- raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")
138
-
139
- image = self._decode_image(image_b64)
140
-
141
- max_new_tokens = int(inputs.get("max_new_tokens", self.gen_config.max_new_tokens))
142
- temperature = float(inputs.get("temperature", self.gen_config.temperature))
143
-
144
- logger.info(
145
- "Received request: max_new_tokens=%d, temperature=%.3f",
146
- max_new_tokens,
147
- temperature,
148
- )
149
-
150
- chat_text = self._build_chat_text(prompt)
151
-
152
- enc = self.processor(
153
- text=[chat_text],
154
- images=[image],
155
- return_tensors="pt",
156
- truncation=False,
157
- )
158
- enc = {k: v.to(self.device) for k, v in enc.items()}
159
- if "pixel_values" in enc:
160
- enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.model.dtype)
161
-
162
- gen_kwargs: Dict[str, Any] = {
163
- "max_new_tokens": max_new_tokens,
164
- "do_sample": temperature > 0.0,
165
- "eos_token_id": self.eos_token_ids,
166
- "pad_token_id": self.pad_token_id,
167
- "no_repeat_ngram_size": self.gen_config.no_repeat_ngram_size,
168
- "repetition_penalty": self.gen_config.repetition_penalty,
169
- }
170
- if temperature > 0.0:
171
- gen_kwargs["temperature"] = temperature
172
-
173
- with torch.inference_mode():
174
- output_ids = self.model.generate(**enc, **gen_kwargs)
175
-
176
- generated_only = output_ids[:, enc["input_ids"].shape[1]:]
177
- generated_text = self.processor.batch_decode(
178
- generated_only,
179
- skip_special_tokens=True,
180
- )[0].strip()
181
-
182
- logger.info("Generated text: %s", generated_text)
183
-
184
- return {"generated_text": generated_text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from __future__ import annotations

import base64
import re
import unicodedata
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, Optional, List, Set

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.utils import logging


# Module-level logger using the transformers logging facility.
logger = logging.get_logger(__name__)
logging.set_verbosity_info()
20
# Default instruction prompt. It embeds a closed CATEGORY DICTIONARY as
# numbered lines; _extract_categories_from_prompt() parses those lines to
# build the set used to validate the model's output.
# NOTE(review): some entries contain spaces ("cherry tomato", "dried fruits")
# rather than underscores, and a few repeat (jam: 26/136, quark: 67/133,
# ice_cream: 70/134) — presumably intentional aliases; confirm with the
# training labels before changing.
DEFAULT_PROMPT = (
    "Here is a picture showing some food waste.\n\n"
    "Task: Provide a list of the food waste items visible in the picture.\n"
    "For each item, output EXACTLY one CATEGORY chosen from the CATEGORY DICTIONARY below.\n\n"
    "Rules:\n"
    "- Output must be a single line (no extra text, no markdown).\n"
    "- Use the exact spelling of CATEGORY as listed.\n"
    "- If you are unsure, choose the closest broader category (e.g., 'fruit', 'vegetable', 'food_waste').\n\n"
    "CATEGORY DICTIONARY:\n"
    "1. cucumber\n"
    "2. tomato\n"
    "3. cherry tomato\n"
    "4. carrot\n"
    "5. lettuce\n"
    "6. bell_pepper\n"
    "7. cabbage\n"
    "8. cauliflower\n"
    "9. broccoli\n"
    "10. onion\n"
    "11. herbs\n"
    "12. vegetable_mix\n"
    "13. vegetable\n"
    "14. potato\n"
    "15. potato_product\n"
    "16. potato_based\n"
    "17. banana\n"
    "18. apple\n"
    "19. mandarin\n"
    "20. orange\n"
    "21. lemon\n"
    "22. grape\n"
    "23. plum\n"
    "24. nectarine\n"
    "25. dried fruits\n"
    "26. jam\n"
    "27. cherry\n"
    "28. fruit\n"
    "29. blueberry\n"
    "30. strawberry\n"
    "31. raspberry\n"
    "32. currant\n"
    "33. lingonberry\n"
    "34. berries\n"
    "35. rice\n"
    "36. pasta\n"
    "37. rice_or_pasta\n"
    "38. dark_bread\n"
    "39. light_bread\n"
    "40. bread\n"
    "41. oatmeal\n"
    "42. semolina_porridge\n"
    "43. rice_porridge\n"
    "44. whipped_porridge\n"
    "45. porridge\n"
    "46. savory_pastry\n"
    "47. sweet_pastry\n"
    "48. biscuit\n"
    "49. baked_goods\n"
    "50. flour\n"
    "51. flakes\n"
    "52. muesli\n"
    "53. cereal\n"
    "54. grain_products\n"
    "55. milk\n"
    "56. plant_based_milk\n"
    "57. children_milk\n"
    "58. dairy_product_milk\n"
    "59. block_cheese\n"
    "60. sliced_cheese\n"
    "61. fresh_cheese\n"
    "62. soft_cheese\n"
    "63. plant_cheese\n"
    "64. cheese\n"
    "65. yoghurt\n"
    "66. buttermilk\n"
    "67. quark\n"
    "68. fermented_milk\n"
    "69. cream\n"
    "70. ice_cream\n"
    "71. plant_based_yogurt\n"
    "72. plant_based_cream\n"
    "73. plant_based_ice_cream\n"
    "74. dairy_product\n"
    "75. egg\n"
    "76. eggs\n"
    "77. beef_steak\n"
    "78. minced_beef\n"
    "79. cold_cut_beef\n"
    "80. beef_sausage\n"
    "81. beef_frankfurter\n"
    "82. beef_product\n"
    "83. beef\n"
    "84. pork_steak\n"
    "85. minced_pork\n"
    "86. cold_cut_pork\n"
    "87. pork_sausage\n"
    "88. pork_frankfurter\n"
    "89. pork_product\n"
    "90. pork\n"
    "91. chicken\n"
    "92. minced_chicken\n"
    "93. cold_cut_chicken\n"
    "94. chicken_sausage\n"
    "95. chicken_frankfurter\n"
    "96. chicken_product\n"
    "97. steak\n"
    "98. minced_meat\n"
    "99. cold_cut_meat\n"
    "100. sausage\n"
    "101. frankfurter\n"
    "102. meat_product\n"
    "103. plant_protein\n"
    "104. meat\n"
    "105. fish_fillet\n"
    "106. fish_strips\n"
    "107. fish_product\n"
    "108. shellfish\n"
    "109. fish\n"
    "110. meat_soup\n"
    "111. meat_stew\n"
    "112. meat_sauce\n"
    "113. meat_pizza\n"
    "114. meat_hamburger\n"
    "115. meat_salad\n"
    "116. meat_dish\n"
    "117. fish_soup\n"
    "118. fish_stew\n"
    "119. fish_sauce\n"
    "120. fish_pizza\n"
    "121. fish_hamburger\n"
    "122. fish_salad\n"
    "123. fish_dish\n"
    "124. vegetarian_soup\n"
    "125. vegetarian_stew\n"
    "126. vegetarian_sauce\n"
    "127. vegetarian_pizza\n"
    "128. vegetarian_hamburger\n"
    "129. vegetarian_salad\n"
    "130. vegetarian_dish\n"
    "131. pastry\n"
    "132. sweet_soup\n"
    "133. quark\n"
    "134. ice_cream\n"
    "135. crepes\n"
    "136. jam\n"
    "137. dessert\n"
    "138. chocolate\n"
    "139. candy\n"
    "140. sweet\n"
    "141. popcorn\n"
    "142. potato_chip\n"
    "143. snack\n"
    "144. sauce_paste\n"
    "145. spice\n"
    "146. sauce_seasoning\n"
    "147. butter\n"
    "148. oil\n"
    "149. fat_oil\n"
    "150. juice\n"
    "151. soda\n"
    "152. soft_drink\n"
    "153. coffee\n"
    "154. tea\n"
    "155. cocoa\n"
    "156. hot_beverage\n"
    "157. beer\n"
    "158. cider\n"
    "159. alcoholic_long_drink\n"
    "160. wine\n"
    "161. strong_alcoholic_beverage\n"
    "162. alcoholic_beverage\n"
    "163. sugar\n"
    "164. honey\n"
    "165. sirup\n"
    "166. sugar_honey_syrup\n"
    "167. food_waste\n\n"
    "OUTPUT FORMAT (single line):\n"
    "CATEGORY1,CATEGORY2,CATEGORY3\n"
)
199
+
200
+
201
@dataclass
class GenerationConfig:
    """Default generation and preprocessing settings for the endpoint.

    Per-request overrides for most of these come in via the request payload
    in ``EndpointHandler.__call__``.
    """

    max_new_tokens: int = 256       # cap on tokens generated per request
    temperature: float = 0.0        # 0.0 => greedy decoding (do_sample=False)
    no_repeat_ngram_size: int = 6   # blocks repeated 6-gram loops
    repetition_penalty: float = 1.1
    max_length: int = 4096          # total token budget (prompt + answer)
    max_side: int = 512             # images are downscaled so max(w, h) <= max_side
210
+
211
+
212
+ _ALLOWED_RE = re.compile(r"[^a-z0-9_\(\);,\|\/\-\s\?\.\:]")
213
+
214
+ ITEM_RE = re.compile(r"\(\s*(.*?)\s*\)\s*\|\s*(\d+)", flags=re.DOTALL)
215
+
216
+ def _clean_text(s: str) -> str:
217
+ s = unicodedata.normalize("NFKC", s).replace("\u00A0", " ")
218
+ s = re.sub(r"[\u200B-\u200F]", "", s)
219
+ s = re.sub(r"\s+", " ", s).strip()
220
+ return s
221
+
222
+
223
+ _CATEGORY_LINE_RE = re.compile(r"^\s*\d+\.\s*(.+?)\s*$", flags=re.MULTILINE)
224
+
225
+ def _extract_categories_from_prompt(prompt: str) -> Set[str]:
226
+ cats = {m.group(1).strip().lower() for m in _CATEGORY_LINE_RE.finditer(prompt)}
227
+ return {c for c in cats if c}
228
+
229
+ def _clean_model_output(s: str) -> str:
230
+ s = _clean_text(s).lower()
231
+ s = _ALLOWED_RE.sub("", s)
232
+ return s
233
+
234
+ def _parse_and_validate_categories(raw: str, allowed: Set[str]) -> List[str]:
235
+ s = _clean_model_output(raw)
236
+ parts = [p.strip() for p in s.split(",") if p.strip()]
237
+ out: List[str] = []
238
+ seen: Set[str] = set()
239
+ for p in parts:
240
+ if p in allowed and p not in seen:
241
+ out.append(p)
242
+ seen.add(p)
243
+ return out
244
+
245
+
246
class EndpointHandler:
    """Inference Endpoint handler for a Pixtral/LLaVA food-waste classifier.

    Loads the processor and model once at startup, then serves requests of
    the form ``{"inputs": {"image": <base64>, ...}}`` and returns the
    comma-separated food-waste categories found in the image, validated
    against the category dictionary embedded in the prompt.
    """

    def __init__(self, path: str = ".") -> None:
        """
        Initializes the model and processor from the `path` directory,
        which contains the merged weights (pixtral-12b-foodwaste-merged).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing EndpointHandler on device: %s", self.device)
        self.processor = AutoProcessor.from_pretrained(
            path,
            trust_remote_code=True,
        )

        # bfloat16 on GPU, float32 on CPU.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = LlavaForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map={"": self.device},
            trust_remote_code=True,
        )
        self.model.eval()
        logger.info("Model and processor successfully loaded from '%s'.", path)

        # pad token management: fall back to EOS when the tokenizer has no PAD.
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None and tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # Preparation of EOS/PAD IDs for generate.
        # FIX: config.eos_token_id may be a single int OR a list of ints per
        # the HF config contract; appending a list into the set comprehension
        # below would raise TypeError (unhashable). Flatten it instead.
        eos_candidates: List[int] = []
        config_eos = self.model.config.eos_token_id
        if config_eos is not None:
            if isinstance(config_eos, (list, tuple)):
                eos_candidates.extend(int(t) for t in config_eos)
            else:
                eos_candidates.append(int(config_eos))
        if tokenizer is not None and tokenizer.eos_token_id is not None:
            eos_candidates.append(tokenizer.eos_token_id)

        self.eos_token_ids: List[int] = list({i for i in eos_candidates})
        if not self.eos_token_ids:
            raise ValueError("No EOS token id found on model or tokenizer.")

        pad_id: Optional[int] = getattr(self.model.config, "pad_token_id", None)
        if pad_id is None and tokenizer is not None:
            pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.eos_token_ids[0]

        self.pad_token_id: int = pad_id

        self.gen_config = GenerationConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.gen_config.max_new_tokens,
            self.gen_config.temperature,
        )

        # Categories parsed once from the default prompt; reused when a custom
        # prompt carries no numbered dictionary of its own.
        self.default_allowed_categories: Set[str] = _extract_categories_from_prompt(DEFAULT_PROMPT)
        logger.info("Extracted %d categories from DEFAULT_PROMPT.", len(self.default_allowed_categories))

    @staticmethod
    def _decode_image(image_b64: str) -> Image.Image:
        """Decode a base64 payload into an RGB PIL image.

        Raises:
            ValueError: if the payload is not valid base64 image data.
        """
        try:
            img_bytes = base64.b64decode(image_b64)
            img = Image.open(BytesIO(img_bytes)).convert("RGB")
            return img
        except Exception as exc:  # pragma: no cover - log production
            raise ValueError(f"Could not decode base64 image: {exc}") from exc

    @staticmethod
    def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
        """Downscale *img* so its longest side is at most *max_side* pixels.

        Returns the image unchanged when it already fits.
        """
        w, h = img.size
        m = max(w, h)
        if m <= max_side:
            return img
        scale = max_side / m
        # LANCZOS = better downscale.
        # FIX: clamp to >= 1 px — int(w * scale) can truncate to 0 on extreme
        # aspect ratios, which Image.resize rejects.
        return img.resize(
            (max(1, int(w * scale)), max(1, int(h * scale))),
            resample=Image.Resampling.LANCZOS,
        )

    def _build_chat_text(self, prompt: str) -> str:
        """Render a single-turn user message (text + one image slot) through
        the processor's chat template, with the generation prompt appended."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image"},
                ],
            }
        ]

        chat_text = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        return chat_text

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Expected payload (all but "image" optional):
        ``{"inputs": {"image": <base64>, "prompt": str, "max_new_tokens": int,
        "temperature": float, "max_length": int, "max_side": int,
        "debug": bool}}``.

        Returns:
            ``{"generated_text": "cat1,cat2,...", "categories": [...]}``

        Raises:
            ValueError: if the base64 "image" field is missing or undecodable.
        """
        inputs = data.get("inputs", data)
        debug = bool(inputs.get("debug", False))

        prompt: str = inputs.get("prompt") or DEFAULT_PROMPT
        # A custom prompt may carry its own numbered dictionary; otherwise
        # validate against the default categories.
        allowed_categories = _extract_categories_from_prompt(prompt) or self.default_allowed_categories
        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")

        image = self._decode_image(image_b64)

        max_length = int(inputs.get("max_length", self.gen_config.max_length))
        max_side = int(inputs.get("max_side", self.gen_config.max_side))
        image = self._resize_max_side(image, max_side=max_side)

        max_new_tokens = int(inputs.get("max_new_tokens", self.gen_config.max_new_tokens))
        temperature = float(inputs.get("temperature", self.gen_config.temperature))

        chat_text = self._build_chat_text(prompt)

        enc = self.processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=False,  # important for correct prompt_len
        )

        # Clamp the answer budget so prompt + answer never exceeds max_length.
        prompt_len = int(enc["input_ids"].shape[1])
        tokens_left = max(1, max_length - prompt_len)
        max_new_tokens = min(max_new_tokens, tokens_left)

        if debug:
            logger.info("===== TOKEN BUDGET DEBUG (inference) =====")
            logger.info("img size: %s", getattr(image, "size", None))
            logger.info("max_length: %d", max_length)
            logger.info("prompt_len: %d", prompt_len)
            logger.info("tokens_left_for_answer(approx): %d", tokens_left)
            tok = getattr(self.processor, "tokenizer", None)
            if tok is not None:
                ids = enc["input_ids"][0].tolist()
                logger.info("[prompt head]\n%s", tok.decode(ids[:120], skip_special_tokens=False))
                logger.info("[prompt tail]\n%s", tok.decode(ids[-120:], skip_special_tokens=False))
            logger.info("=========================================")

        # Move tensors to the model device; pixel values also need the model dtype.
        enc = {k: v.to(self.device) for k, v in enc.items()}
        if "pixel_values" in enc:
            enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.model.dtype)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,
            "eos_token_id": self.eos_token_ids,
            "pad_token_id": self.pad_token_id,
            "no_repeat_ngram_size": self.gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.gen_config.repetition_penalty,
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature

        with torch.inference_mode():
            output_ids = self.model.generate(**enc, **gen_kwargs)

        # Strip the echoed prompt; decode only the newly generated tokens.
        generated_only = output_ids[:, enc["input_ids"].shape[1]:]
        generated_text = self.processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
        )[0].strip()

        # Validate against the dictionary; fall back to the catch-all
        # "food_waste" category when nothing parseable came back.
        cats = _parse_and_validate_categories(generated_text, allowed_categories)
        if not cats:
            cats = ["food_waste"] if "food_waste" in allowed_categories else []
        generated_text = ",".join(cats)

        logger.info("Generated text: %s", generated_text)
        return {"generated_text": generated_text, "categories": cats}