CanerDedeoglu committed on
Commit
637c5d5
·
verified ·
1 Parent(s): 70e1ff6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +66 -29
handler.py CHANGED
@@ -45,13 +45,9 @@ from llava.constants import (
45
  from llava.conversation import conv_templates
46
  from llava.utils import disable_torch_init
47
 
48
- # Eksik fonksiyonu kaldır - artık mm_utils'ten import ediyoruz
49
- # def get_model_name_from_path() artık gerekli değil
50
-
51
  # Varsayılanlar
52
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
53
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
54
- # ATTN_IMPLEMENTATION artık otomatik seçiliyor, bu satırı kaldırıyoruz
55
 
56
  class EndpointHandler:
57
  """
@@ -97,8 +93,6 @@ class EndpointHandler:
97
 
98
  # Görsel token işaretleri (LLaVA config)
99
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
100
- # Constants'tan direkt kullan
101
- # self.image_token, self.im_start, self.im_end artık gerekli değil
102
 
103
  # ---- yardımcılar ----
104
  def _load_image(self, img_field: str) -> Optional[Image.Image]:
@@ -115,20 +109,24 @@ class EndpointHandler:
115
  return Image.open(io.BytesIO(r.content)).convert("RGB")
116
  return Image.open(img_field).convert("RGB")
117
  except Exception as e:
118
- # Görsel opsiyoneldir; okunamazsa kullanıcıya hata dönmek yerine None bırakabiliriz.
119
  print(f"[warn] image load failed: {e}")
120
  return None
121
 
122
- def _build_prompt(self, user_text: str, conv_mode: str) -> str:
 
123
  if conv_mode not in conv_templates:
124
  conv_mode = DEFAULT_CONV_MODE
125
  conv = conv_templates[conv_mode].copy()
126
 
127
- # Image token'ları doğru yerleştir
128
- if self.use_im_start_end:
129
- content = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
 
 
 
130
  else:
131
- content = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
 
132
 
133
  conv.append_message(conv.roles[0], content)
134
  conv.append_message(conv.roles[1], None)
@@ -144,42 +142,71 @@ class EndpointHandler:
144
  query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
145
  image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")
146
 
147
- # 1) prompt
148
- prompt = self._build_prompt(query_text, conv_mode)
149
-
150
- # 2) image -> tensor (opsiyonel)
151
  image_tensors = None
 
 
152
  if image_f:
153
  pil = self._load_image(image_f)
154
  if pil is not None:
155
  try:
156
- # LLaVA'nın gelişmiş process_images fonksiyonunu kullan
157
- # Bu fonksiyon anyres, pad gibi farklı aspect ratio modlarını destekler
158
- image_tensors = process_images([pil], self.image_processor, self.model.config)
159
- if image_tensors is not None and len(image_tensors) > 0:
160
- image_tensors = image_tensors.to(self.model.device, dtype=torch.float16, non_blocking=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  except Exception as e:
162
  print(f"[warn] image processing failed: {e}")
163
  image_tensors = None
164
-
165
- # 3) tokenize (image token'ı gömülü)
166
- input_ids = tokenizer_image_token(
167
- prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
168
- ).unsqueeze(0).to(self.model.device, non_blocking=True) # unsqueeze ekledik
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  # Input uzunluk kontrolü
171
  if input_ids.shape[-1] > self.context_len - 100:
172
- # Prompt'u kısalt
173
  input_ids = input_ids[:, -(self.context_len - 200):]
174
 
175
- # 4) güvenli max_new_tokens
176
  requested = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
177
  avail = max(16, int(self.context_len) - int(input_ids.shape[-1]) - 8)
178
  max_new_tokens = max(1, min(requested, avail))
179
 
180
  gen_kwargs = {
181
  "input_ids": input_ids,
182
- "images": image_tensors,
183
  "max_new_tokens": max_new_tokens,
184
  "temperature": float(params.get("temperature", 0.0)),
185
  "top_p": float(params.get("top_p", 1.0)),
@@ -189,6 +216,13 @@ class EndpointHandler:
189
  "pad_token_id": self.tokenizer.eos_token_id,
190
  }
191
 
 
 
 
 
 
 
 
192
  try:
193
  with torch.inference_mode():
194
  output_ids = self.model.generate(**gen_kwargs)
@@ -202,5 +236,8 @@ class EndpointHandler:
202
 
203
  except Exception as e:
204
  print(f"Generation error: {e}")
 
 
205
  text = f"Error during generation: {str(e)}"
 
206
  return [{"generated_text": text}]
 
45
  from llava.conversation import conv_templates
46
  from llava.utils import disable_torch_init
47
 
 
 
 
48
  # Varsayılanlar
49
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
50
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
 
51
 
52
  class EndpointHandler:
53
  """
 
93
 
94
  # Görsel token işaretleri (LLaVA config)
95
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
 
 
96
 
97
  # ---- yardımcılar ----
98
  def _load_image(self, img_field: str) -> Optional[Image.Image]:
 
109
  return Image.open(io.BytesIO(r.content)).convert("RGB")
110
  return Image.open(img_field).convert("RGB")
111
  except Exception as e:
 
112
  print(f"[warn] image load failed: {e}")
113
  return None
114
 
115
+ def _build_prompt(self, user_text: str, conv_mode: str, has_image: bool = False) -> str:
116
+ """Prompt oluştur - görüntü olup olmadığına göre"""
117
  if conv_mode not in conv_templates:
118
  conv_mode = DEFAULT_CONV_MODE
119
  conv = conv_templates[conv_mode].copy()
120
 
121
+ # Sadece görüntü varsa image token'ları ekle
122
+ if has_image:
123
+ if self.use_im_start_end:
124
+ content = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
125
+ else:
126
+ content = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
127
  else:
128
+ # Görüntü yoksa sadece text
129
+ content = user_text
130
 
131
  conv.append_message(conv.roles[0], content)
132
  conv.append_message(conv.roles[1], None)
 
142
  query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
143
  image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")
144
 
145
+ # 1) Görüntü işleme (önce)
 
 
 
146
  image_tensors = None
147
+ has_image = False
148
+
149
  if image_f:
150
  pil = self._load_image(image_f)
151
  if pil is not None:
152
  try:
153
+ # LLaVA'nın process_images fonksiyonunu kullan
154
+ processed_images = process_images([pil], self.image_processor, self.model.config)
155
+
156
+ if processed_images is not None:
157
+ # Tensor formatını kontrol et ve düzelt
158
+ if isinstance(processed_images, list):
159
+ if len(processed_images) > 0:
160
+ image_tensors = torch.stack(processed_images, dim=0)
161
+ else:
162
+ image_tensors = None
163
+ else:
164
+ image_tensors = processed_images
165
+
166
+ if image_tensors is not None:
167
+ image_tensors = image_tensors.to(
168
+ self.model.device,
169
+ dtype=torch.float16,
170
+ non_blocking=True
171
+ )
172
+ has_image = True
173
+ print(f"[info] Image processed successfully, shape: {image_tensors.shape}")
174
+ else:
175
+ print("[warn] Image processing returned None")
176
+
177
  except Exception as e:
178
  print(f"[warn] image processing failed: {e}")
179
  image_tensors = None
180
+ has_image = False
181
+
182
+ # 2) Prompt oluştur (görüntü durumuna göre)
183
+ prompt = self._build_prompt(query_text, conv_mode, has_image)
184
+ print(f"[debug] Generated prompt: {repr(prompt[:200])}")
185
+
186
+ # 3) Tokenize
187
+ if has_image:
188
+ # Görüntü varsa IMAGE_TOKEN_INDEX ile tokenize et
189
+ input_ids = tokenizer_image_token(
190
+ prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
191
+ )
192
+ else:
193
+ # Görüntü yoksa normal tokenize
194
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0]
195
+
196
+ # Batch dimension ekle
197
+ input_ids = input_ids.unsqueeze(0).to(self.model.device, non_blocking=True)
198
 
199
  # Input uzunluk kontrolü
200
  if input_ids.shape[-1] > self.context_len - 100:
 
201
  input_ids = input_ids[:, -(self.context_len - 200):]
202
 
203
+ # 4) Generation parameters
204
  requested = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
205
  avail = max(16, int(self.context_len) - int(input_ids.shape[-1]) - 8)
206
  max_new_tokens = max(1, min(requested, avail))
207
 
208
  gen_kwargs = {
209
  "input_ids": input_ids,
 
210
  "max_new_tokens": max_new_tokens,
211
  "temperature": float(params.get("temperature", 0.0)),
212
  "top_p": float(params.get("top_p", 1.0)),
 
216
  "pad_token_id": self.tokenizer.eos_token_id,
217
  }
218
 
219
+ # Görüntü varsa images parametresini ekle
220
+ if has_image and image_tensors is not None:
221
+ gen_kwargs["images"] = image_tensors
222
+ print(f"[info] Using images in generation, shape: {image_tensors.shape}")
223
+ else:
224
+ print("[info] No images in generation")
225
+
226
  try:
227
  with torch.inference_mode():
228
  output_ids = self.model.generate(**gen_kwargs)
 
236
 
237
  except Exception as e:
238
  print(f"Generation error: {e}")
239
+ import traceback
240
+ traceback.print_exc()
241
  text = f"Error during generation: {str(e)}"
242
+
243
  return [{"generated_text": text}]