Kalaoke committed on
Commit
7d66d82
·
verified ·
1 Parent(s): 7a3a24a

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +184 -0
handler.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ from dataclasses import dataclass
5
+ from io import BytesIO
6
+ from typing import Any, Dict, Optional, List
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
11
+ from transformers.utils import logging
12
+
13
+
14
+ logger = logging.get_logger(__name__)
15
+ logging.set_verbosity_info()
16
+
17
+
18
# Hub id of the base vision-language checkpoint loaded by EndpointHandler.
BASE_MODEL_ID = "mistral-community/pixtral-12b"


# Default prompt (adjust it here); used when a request carries no "prompt" field.
DEFAULT_PROMPT = (
    "Here is a photo showing some food waste. "
    "Identify each type of food item and the corresponding weight in grams. "
    "Reply like: Milk, 120g; Coffee, 45g. "
    "Do not add any explanation, no extra text."
)
28
+
29
+
30
@dataclass
class GenerationConfig:
    """Default decoding parameters passed to `model.generate`.

    `max_new_tokens` and `temperature` can be overridden per request
    (see `EndpointHandler.__call__`); the repetition controls are fixed.
    """

    # Upper bound on the number of newly generated tokens per request.
    max_new_tokens: int = 64
    # 0.0 → greedy decoding; sampling is only enabled when > 0.
    temperature: float = 0.0
    # Passed to generate(); blocks exact repeats of any 6-gram.
    no_repeat_ngram_size: int = 6
    # Mild penalty (> 1.0) discouraging already-generated tokens.
    repetition_penalty: float = 1.1
37
+
38
+
39
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a Pixtral-12B vision model.

    Protocol: `__init__(path)` loads the processor and model once at startup;
    `__call__(data)` serves one request containing a base64-encoded image and
    an optional prompt, returning `{"generated_text": ...}`.
    """

    def __init__(self, path: str = ".") -> None:
        """Load the processor and model and prepare generation defaults.

        Args:
            path: Endpoint repository directory. NOTE(review): weights are
                actually pulled from ``BASE_MODEL_ID``, not from ``path``;
                the parameter is kept for handler-protocol compatibility.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing EndpointHandler on device: %s", self.device)

        self.processor = AutoProcessor.from_pretrained(
            BASE_MODEL_ID,
            trust_remote_code=True,
        )

        # bf16 on GPU; fall back to fp32 on CPU.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = LlavaForConditionalGeneration.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map={"": self.device},
            trust_remote_code=True,
        )
        self.model.eval()
        # BUGFIX: the previous message claimed the model was loaded from
        # `path`, but loading actually uses BASE_MODEL_ID.
        logger.info(
            "Model and processor successfully loaded from '%s'.", BASE_MODEL_ID
        )

        # Ensure a pad token exists (some checkpoints ship without one).
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is not None and tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # Collect EOS ids for generate(). BUGFIX: config.eos_token_id may be
        # a list on some configs; accept both int and list forms.
        eos_candidates: List[int] = []
        config_eos = self.model.config.eos_token_id
        if isinstance(config_eos, (list, tuple)):
            eos_candidates.extend(config_eos)
        elif config_eos is not None:
            eos_candidates.append(config_eos)
        if tokenizer is not None and tokenizer.eos_token_id is not None:
            eos_candidates.append(tokenizer.eos_token_id)

        # BUGFIX: de-duplicate with dict.fromkeys to keep a deterministic
        # order — list({...}) made eos_token_ids[0] (the pad fallback below)
        # arbitrary.
        self.eos_token_ids: List[int] = list(dict.fromkeys(eos_candidates))
        if not self.eos_token_ids:
            raise ValueError("No EOS token id found on model or tokenizer.")

        # Pad id resolution order: model config → tokenizer → first EOS id.
        pad_id: Optional[int] = getattr(self.model.config, "pad_token_id", None)
        if pad_id is None and tokenizer is not None:
            pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.eos_token_ids[0]

        self.pad_token_id: int = pad_id

        self.gen_config = GenerationConfig()
        logger.info(
            "Generation config: max_new_tokens=%d, temperature=%.3f",
            self.gen_config.max_new_tokens,
            self.gen_config.temperature,
        )

    @staticmethod
    def _decode_image(image_b64: str) -> Image.Image:
        """Decode a base64 string into an RGB PIL image.

        Raises:
            ValueError: if the payload is not valid base64 or not an image.
        """
        try:
            img_bytes = base64.b64decode(image_b64)
            return Image.open(BytesIO(img_bytes)).convert("RGB")
        except Exception as exc:  # pragma: no cover - log production
            raise ValueError(f"Could not decode base64 image: {exc}") from exc

    def _build_chat_text(self, prompt: str) -> str:
        """Render a single-turn user message (text + one image slot) through
        the processor's chat template, appending the generation prompt."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image"},
                ],
            }
        ]
        return self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Serve one inference request.

        Expected payload: ``{"inputs": {"image": <base64>, "prompt": str?,
        "max_new_tokens": int?, "temperature": float?}}``; a flat dict
        without the "inputs" wrapper is also accepted.

        Returns:
            ``{"generated_text": <decoded model output>}``

        Raises:
            ValueError: on a malformed payload or missing/invalid image.
        """
        inputs = data.get("inputs", data)
        # BUGFIX/robustness: a non-dict "inputs" (e.g. a bare string) used to
        # crash with AttributeError on .get(); fail with a clear error instead.
        if not isinstance(inputs, dict):
            raise ValueError("'inputs' must be an object with an 'image' field.")

        prompt: str = inputs.get("prompt") or DEFAULT_PROMPT

        image_b64: Optional[str] = inputs.get("image")
        if image_b64 is None:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")

        image = self._decode_image(image_b64)

        # Per-request overrides of the generation defaults.
        max_new_tokens = int(inputs.get("max_new_tokens", self.gen_config.max_new_tokens))
        temperature = float(inputs.get("temperature", self.gen_config.temperature))

        logger.info(
            "Received request: max_new_tokens=%d, temperature=%.3f",
            max_new_tokens,
            temperature,
        )

        chat_text = self._build_chat_text(prompt)

        enc = self.processor(
            text=[chat_text],
            images=[image],
            return_tensors="pt",
            truncation=False,
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}
        # Match the vision tensor dtype to the model's (bf16 on GPU).
        if "pixel_values" in enc:
            enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self.model.dtype)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0.0,  # temperature == 0 → greedy
            "eos_token_id": self.eos_token_ids,
            "pad_token_id": self.pad_token_id,
            "no_repeat_ngram_size": self.gen_config.no_repeat_ngram_size,
            "repetition_penalty": self.gen_config.repetition_penalty,
        }
        if temperature > 0.0:
            gen_kwargs["temperature"] = temperature

        with torch.inference_mode():
            output_ids = self.model.generate(**enc, **gen_kwargs)

        # Strip the prompt tokens; decode only the newly generated part.
        generated_only = output_ids[:, enc["input_ids"].shape[1]:]
        generated_text = self.processor.batch_decode(
            generated_only,
            skip_special_tokens=True,
        )[0].strip()

        logger.info("Generated text: %s", generated_text)

        return {"generated_text": generated_text}