CanerDedeoglu commited on
Commit
bca3b45
·
verified ·
1 Parent(s): bed7a91

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +130 -111
handler.py CHANGED
@@ -1,5 +1,12 @@
1
-
2
  # -*- coding: utf-8 -*-
 
 
 
 
 
 
 
 
3
  import os, io, sys, subprocess, base64
4
  from typing import Any, Dict, List, Optional
5
 
@@ -8,9 +15,8 @@ from PIL import Image
8
  import requests
9
  import math
10
  import ast
11
- from io import BytesIO
12
  from urllib.parse import urlparse
13
-
14
 
15
  # ===== Kullanılacak HF model id =====
16
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
@@ -42,7 +48,7 @@ try:
42
  except ImportError:
43
  # Fallback: kendi implementasyonumuzu kullan
44
  from llava.constants import IMAGE_TOKEN_INDEX
45
-
46
  def expand2square(pil_img, background_color):
47
  width, height = pil_img.size
48
  if width == height:
@@ -61,13 +67,11 @@ except ImportError:
61
  best_fit = None
62
  max_effective_resolution = 0
63
  min_wasted_resolution = float('inf')
64
-
65
  for width, height in possible_resolutions:
66
  scale = min(width / original_width, height / original_height)
67
  downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
68
  effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
69
  wasted_resolution = (width * height) - effective_resolution
70
-
71
  if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
72
  max_effective_resolution = effective_resolution
73
  min_wasted_resolution = wasted_resolution
@@ -77,17 +81,14 @@ except ImportError:
77
  def resize_and_pad_image(image, target_resolution):
78
  original_width, original_height = image.size
79
  target_width, target_height = target_resolution
80
-
81
  scale_w = target_width / original_width
82
  scale_h = target_height / original_height
83
-
84
  if scale_w < scale_h:
85
  new_width = target_width
86
  new_height = min(math.ceil(original_height * scale_w), target_height)
87
  else:
88
  new_height = target_height
89
  new_width = min(math.ceil(original_width * scale_h), target_width)
90
-
91
  resized_image = image.resize((new_width, new_height))
92
  new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
93
  paste_x = (target_width - new_width) // 2
@@ -181,6 +182,9 @@ from llava.constants import (
181
  from llava.conversation import conv_templates
182
  from llava.utils import disable_torch_init
183
 
 
 
 
184
  # Varsayılanlar
185
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
186
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
@@ -211,11 +215,11 @@ class EndpointHandler:
211
 
212
  # Attention implementation otomatik seç
213
  try:
214
- import flash_attn
215
  attn_impl = "flash_attention_2"
216
  except ImportError:
217
  attn_impl = "sdpa"
218
-
219
  # PULSE, LLaVA tabanlı olduğundan LLaVA loader ile yüklenir
220
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
221
  model_path=model_path,
@@ -227,18 +231,31 @@ class EndpointHandler:
227
  )
228
  self.model.eval()
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  def _patch_forward(obj, label="model"):
231
  try:
232
  if not hasattr(obj, "forward"):
233
  return False
234
  orig_forward = obj.forward
235
-
236
  def patched_forward(*args, **kwargs):
237
- # Sessizce düşürülecek yeni anahtarlar
238
  kwargs.pop("cache_position", None)
239
  kwargs.pop("input_positions", None)
240
  return orig_forward(*args, **kwargs)
241
-
242
  obj.forward = patched_forward
243
  print(f"[hotfix] Patched forward on {label}")
244
  return True
@@ -246,63 +263,88 @@ class EndpointHandler:
246
  print(f"[warn] forward patch failed on {label}: {e}")
247
  return False
248
 
249
- # Ana modelde dene
250
  _patch_forward(self.model, "self.model")
251
-
252
- # Bazı sürümlerde forward zinciri iç modüle de gider
253
  if hasattr(self.model, "model"):
254
  _patch_forward(self.model.model, "self.model.model")
255
  if hasattr(self.model, "base_model"):
256
  _patch_forward(self.model.base_model, "self.model.base_model")
257
- # =======================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  # Model worker'dan: multimodal check
260
  self.is_multimodal = 'llava' in self.model_name.lower() or 'pulse' in self.model_name.lower()
261
-
262
  # Görsel token işaretleri (LLaVA config)
263
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
264
 
265
- # ---- yardımcılar ----
 
 
 
 
 
 
 
266
  def _load_image(self, image_input: str) -> Optional[Image.Image]:
267
  """
268
  URL / base64 / yerel path -> PIL.Image
269
- - URL için: içerik tipi/uzantı kontrolü, UA header, redirect, boyut sınırı
270
- - base64 için: data URL prefix temizliği + padding düzeltme
271
- - path için: doğrudan aç
272
  """
273
  if not image_input:
274
  return None
275
-
276
- # 25MB üstünü reddet (isteğe göre ayarlanabilir)
277
- MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", "26214400"))
278
  ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
279
-
280
  def _is_valid_image_format(url: str) -> bool:
281
  try:
282
  ext = os.path.splitext(urlparse(url).path.lower())[1]
283
- # Uzantı yoksa da kabul et; Content-Type ile kontrol ederiz
284
  return (ext in ALLOWED_EXTS) or (ext == "")
285
  except Exception:
286
  return True
287
-
288
  try:
289
- # ---- URL input ----
290
  if isinstance(image_input, str) and image_input.startswith(("http://", "https://")):
291
  if not _is_valid_image_format(image_input):
292
  print("[warn] Invalid image extension in URL")
293
  return None
294
-
295
  headers = {"User-Agent": "Mozilla/5.0"}
296
  resp = requests.get(image_input, timeout=20, headers=headers, allow_redirects=True, stream=True)
297
  resp.raise_for_status()
298
-
299
- # Content-Type kontrolü (varsa)
300
  ctype = resp.headers.get("Content-Type", "").lower()
301
  if ctype and not ctype.startswith("image/"):
302
  print(f"[warn] Non-image content-type: {ctype}")
303
  return None
304
-
305
- # Boyut kontrolü (Content-Length varsa)
306
  clen = resp.headers.get("Content-Length")
307
  if clen is not None:
308
  try:
@@ -311,68 +353,54 @@ class EndpointHandler:
311
  return None
312
  except Exception:
313
  pass
314
-
315
- # Stream’den kontrollü oku
316
  data = resp.content
317
  if len(data) > MAX_IMAGE_BYTES:
318
  print(f"[warn] Image too large (actual): {len(data)} bytes")
319
  return None
320
-
321
- img = Image.open(BytesIO(data)).convert("RGB")
322
  print(f"[info] URL image loaded: size={img.size}")
323
  return img
324
-
325
- # ---- Base64 input (data URL dahil) ----
326
  if isinstance(image_input, str):
327
  b64 = image_input.strip()
328
-
329
- # data URL prefix varsa ayıkla
330
  if b64.startswith("data:image"):
331
- # ör: data:image/png;base64,AAAA...
332
  if "base64," in b64:
333
  b64 = b64.split("base64,", 1)[1]
334
  else:
335
- # bazen ;base64 sonrası newline vb olabilir
336
  b64 = b64.split(",", 1)[-1]
337
-
338
- # boşluk/newline temizliği + padding düzeltme
339
  b64 = b64.replace("\n", "").replace("\r", "").replace(" ", "")
340
  missing = (4 - len(b64) % 4) % 4
341
  b64 += "=" * missing
342
-
343
  try:
344
  data = base64.b64decode(b64, validate=False)
345
  if len(data) > MAX_IMAGE_BYTES:
346
  print(f"[warn] Base64 image too large: {len(data)} bytes")
347
  return None
348
- img = Image.open(BytesIO(data)).convert("RGB")
349
  print(f"[info] Base64 image loaded: size={img.size}")
350
  return img
351
  except Exception as e:
352
  print(f"[warn] Base64 decode/open failed: {e}")
353
- # Devam edip path olarak deneyeceğiz
354
-
355
- # ---- Yerel path ----
356
  if isinstance(image_input, str) and os.path.exists(image_input):
357
  img = Image.open(image_input).convert("RGB")
358
  print(f"[info] Local image loaded: size={img.size}")
359
  return img
360
-
361
  except Exception as e:
362
  print(f"[warn] image load failed: {e}")
363
  return None
364
-
365
- return None
366
 
 
367
 
368
  def _build_prompt(self, user_text: str, conv_mode: str) -> str:
369
  """Model worker tarzında prompt oluştur"""
370
  if conv_mode not in conv_templates:
371
  conv_mode = DEFAULT_CONV_MODE
372
  conv = conv_templates[conv_mode].copy()
373
-
374
- # Model worker'da görüntüler sonradan replace edilir
375
- # Şimdilik sadece text ile başlayalım
376
  conv.append_message(conv.roles[0], user_text)
377
  conv.append_message(conv.roles[1], None)
378
  return conv.get_prompt()
@@ -387,95 +415,81 @@ class EndpointHandler:
387
  query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
388
  image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")
389
 
390
- # 1) İlk prompt oluştur (görüntü olmadan)
391
  prompt = self._build_prompt(query_text, conv_mode)
392
-
393
- # 2) Görüntü işleme (model worker tarzında)
394
  images = None
395
  image_sizes = None
396
-
397
  if image_f and self.is_multimodal:
398
  try:
399
  pil_image = self._load_image(image_f)
400
- if pil_image is not None:
401
  images_list = [pil_image]
402
  image_sizes = [pil_image.size]
403
-
404
- # Model worker'daki gibi process et
405
  processed_images = process_images(images_list, self.image_processor, self.model.config)
406
-
407
  if isinstance(processed_images, list):
408
  images = [img.to(self.model.device, dtype=torch.float16) for img in processed_images]
409
  else:
410
  images = processed_images.to(self.model.device, dtype=torch.float16)
411
-
412
- # Model worker'daki gibi prompt'u düzenle
413
- # DEFAULT_IMAGE_TOKEN'ı prompt'a ekle
414
  prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
415
-
416
- # Replace token hesapla (model worker'dan)
417
  replace_token = DEFAULT_IMAGE_TOKEN
418
  if self.use_im_start_end:
419
  replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
420
-
421
- # Prompt'taki image token'ları replace et
422
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
423
-
424
  print(f"[info] Image processed successfully")
425
- print(f"[debug] Final prompt: {repr(prompt[:200])}")
426
  else:
427
- print("[warn] Could not load image")
428
  except Exception as e:
429
  print(f"[warn] Image processing failed: {e}")
430
- import traceback
431
- traceback.print_exc()
432
- images = None
433
- image_sizes = None
434
 
435
- # 3) Tokenize (model worker tarzında)
436
  try:
437
  input_ids = tokenizer_image_token(
438
  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
439
  ).unsqueeze(0).to(self.model.device)
440
-
441
- print(f"[debug] input_ids shape: {input_ids.shape}")
442
- print(f"[debug] Has images: {images is not None}")
443
-
444
  except Exception as e:
445
  print(f"[error] Tokenization failed: {e}")
446
  # Fallback to text-only
447
- input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids
448
- input_ids = input_ids.to(self.model.device)
449
  images = None
450
  image_sizes = None
451
 
452
- # 4) Generation parameters (model worker tarzında)
453
  temperature = float(params.get("temperature", 0.0))
454
  top_p = float(params.get("top_p", 1.0))
455
  repetition_penalty = float(params.get("repetition_penalty", 1.0))
456
  max_new_tokens = min(int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF)), 1024)
457
  do_sample = bool(params.get("do_sample", temperature > 0.001))
458
-
459
- # Context length check
460
- max_context_length = getattr(self.model.config, 'max_position_embeddings', 2048)
461
- max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - 50)
462
-
463
  if max_new_tokens < 1:
464
  return [{"generated_text": "Error: Input too long, exceeds max token length."}]
465
 
466
- # 5) Generation kwargs (model worker tarzında)
467
  gen_kwargs = {
468
- "inputs": input_ids, # model worker 'inputs' kullanır
469
  "max_new_tokens": max_new_tokens,
470
  "temperature": temperature,
471
  "top_p": top_p,
472
  "repetition_penalty": repetition_penalty,
473
  "do_sample": do_sample,
474
  "use_cache": bool(params.get("use_cache", True)),
475
- "pad_token_id": self.tokenizer.eos_token_id,
476
  }
477
 
478
- # Image args (model worker tarzında)
 
 
 
 
 
479
  if images is not None and image_sizes is not None:
480
  gen_kwargs["images"] = images
481
  gen_kwargs["image_sizes"] = image_sizes
@@ -483,21 +497,26 @@ class EndpointHandler:
483
  else:
484
  print("[info] Text-only generation")
485
 
 
486
  try:
487
  with torch.inference_mode():
488
  output_ids = self.model.generate(**gen_kwargs)
489
-
490
- # Output'u input'tan ayır
491
- if output_ids.shape[-1] > input_ids.shape[-1]:
492
- response_ids = output_ids[:, input_ids.shape[-1]:]
493
- text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
 
494
  else:
495
- text = "Error: No response generated"
496
-
497
- except Exception as e:
498
- print(f"Generation error: {e}")
499
- import traceback
500
- traceback.print_exc()
501
- text = f"Error during generation: {str(e)}"
502
-
503
- return [{"generated_text": text}]
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
+ # handler.py — PULSE-7B / LLaVA robust endpoint
3
+ # - LLaVA kaynak kodunu runtime'da git clone ile getirir
4
+ # - image_processor fallback (AutoProcessor / vision_tower)
5
+ # - anyres -> pad güvenli düşüş
6
+ # - gelişmiş _load_image: UA, redirect, Content-Type kontrolü, base64 padding, boyut sınırı
7
+ # - forward patch (cache_position/input_positions sessizce düşür)
8
+ # - pad_token garanti + conditional attention_mask (+ retry) — HF generate hatalarını önler
9
+
10
  import os, io, sys, subprocess, base64
11
  from typing import Any, Dict, List, Optional
12
 
 
15
  import requests
16
  import math
17
  import ast
 
18
  from urllib.parse import urlparse
19
+ import inspect
20
 
21
  # ===== Kullanılacak HF model id =====
22
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
 
48
  except ImportError:
49
  # Fallback: kendi implementasyonumuzu kullan
50
  from llava.constants import IMAGE_TOKEN_INDEX
51
+
52
  def expand2square(pil_img, background_color):
53
  width, height = pil_img.size
54
  if width == height:
 
67
  best_fit = None
68
  max_effective_resolution = 0
69
  min_wasted_resolution = float('inf')
 
70
  for width, height in possible_resolutions:
71
  scale = min(width / original_width, height / original_height)
72
  downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
73
  effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
74
  wasted_resolution = (width * height) - effective_resolution
 
75
  if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
76
  max_effective_resolution = effective_resolution
77
  min_wasted_resolution = wasted_resolution
 
81
  def resize_and_pad_image(image, target_resolution):
82
  original_width, original_height = image.size
83
  target_width, target_height = target_resolution
 
84
  scale_w = target_width / original_width
85
  scale_h = target_height / original_height
 
86
  if scale_w < scale_h:
87
  new_width = target_width
88
  new_height = min(math.ceil(original_height * scale_w), target_height)
89
  else:
90
  new_height = target_height
91
  new_width = min(math.ceil(original_width * scale_h), target_width)
 
92
  resized_image = image.resize((new_width, new_height))
93
  new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
94
  paste_x = (target_width - new_width) // 2
 
182
  from llava.conversation import conv_templates
183
  from llava.utils import disable_torch_init
184
 
185
+ # HF processor fallback'ları
186
+ from transformers import AutoProcessor, AutoImageProcessor, CLIPImageProcessor
187
+
188
  # Varsayılanlar
189
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
190
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
 
215
 
216
  # Attention implementation otomatik seç
217
  try:
218
+ import flash_attn # noqa: F401
219
  attn_impl = "flash_attention_2"
220
  except ImportError:
221
  attn_impl = "sdpa"
222
+
223
  # PULSE, LLaVA tabanlı olduğundan LLaVA loader ile yüklenir
224
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
225
  model_path=model_path,
 
231
  )
232
  self.model.eval()
233
 
234
+ # --- PAD TOKEN FIX -------------------------------------------------
235
+ if self.tokenizer.pad_token_id is None:
236
+ self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
237
+ try:
238
+ self.model.resize_token_embeddings(len(self.tokenizer))
239
+ except Exception as e:
240
+ print(f"[warn] resize_token_embeddings failed: {e}")
241
+
242
+ self.model.config.pad_token_id = self.tokenizer.pad_token_id
243
+ try:
244
+ self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
245
+ except Exception:
246
+ pass
247
+ # -------------------------------------------------------------------
248
+
249
+ # ---- forward patch (HF yeni arg uyumu) ----
250
  def _patch_forward(obj, label="model"):
251
  try:
252
  if not hasattr(obj, "forward"):
253
  return False
254
  orig_forward = obj.forward
 
255
  def patched_forward(*args, **kwargs):
 
256
  kwargs.pop("cache_position", None)
257
  kwargs.pop("input_positions", None)
258
  return orig_forward(*args, **kwargs)
 
259
  obj.forward = patched_forward
260
  print(f"[hotfix] Patched forward on {label}")
261
  return True
 
263
  print(f"[warn] forward patch failed on {label}: {e}")
264
  return False
265
 
 
266
  _patch_forward(self.model, "self.model")
 
 
267
  if hasattr(self.model, "model"):
268
  _patch_forward(self.model.model, "self.model.model")
269
  if hasattr(self.model, "base_model"):
270
  _patch_forward(self.model.base_model, "self.model.base_model")
271
+
272
+ # ---- image_processor fallback ----
273
+ if self.image_processor is None:
274
+ print("[hotfix] image_processor None, AutoProcessor/vision_tower fallback deneniyor...")
275
+ try:
276
+ proc = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
277
+ self.image_processor = getattr(proc, "image_processor", proc)
278
+ except Exception as e:
279
+ print(f"[warn] AutoProcessor başarısız: {e}")
280
+ vt = getattr(self.model.config, "vision_tower", None)
281
+ if vt:
282
+ try:
283
+ self.image_processor = AutoImageProcessor.from_pretrained(vt, trust_remote_code=True)
284
+ except Exception as e2:
285
+ print(f"[warn] AutoImageProcessor failed: {e2}")
286
+ try:
287
+ self.image_processor = CLIPImageProcessor.from_pretrained(vt)
288
+ except Exception as e3:
289
+ print(f"[warn] CLIPImageProcessor failed: {e3}")
290
+
291
+ # anyres -> pad fallback (processor/crop_size yoksa)
292
+ iar = getattr(self.model.config, "mm_image_aspect_ratio", None) or getattr(self.model.config, "image_aspect_ratio", None)
293
+ needs_crop = (self.image_processor is None) or (getattr(self.image_processor, "crop_size", None) is None)
294
+ if iar == "anyres" and needs_crop:
295
+ print("[hotfix] image_aspect_ratio:anyres -> pad (processor/crop_size eksik)")
296
+ if hasattr(self.model.config, "image_aspect_ratio"):
297
+ self.model.config.image_aspect_ratio = "pad"
298
+ if hasattr(self.model.config, "mm_image_aspect_ratio"):
299
+ self.model.config.mm_image_aspect_ratio = "pad"
300
 
301
  # Model worker'dan: multimodal check
302
  self.is_multimodal = 'llava' in self.model_name.lower() or 'pulse' in self.model_name.lower()
303
+
304
  # Görsel token işaretleri (LLaVA config)
305
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
306
 
307
+ # attention_mask desteğini tespit et
308
+ try:
309
+ sig = inspect.signature(self.model.forward)
310
+ self._supports_attention_mask = ("attention_mask" in sig.parameters)
311
+ except Exception:
312
+ self._supports_attention_mask = False
313
+
314
+ # ---- gelişmiş image loader ----
315
  def _load_image(self, image_input: str) -> Optional[Image.Image]:
316
  """
317
  URL / base64 / yerel path -> PIL.Image
318
+ - URL: UA header, redirect, Content-Type kontrolü, boyut sınırı
319
+ - base64: data URL prefix temizliği + padding düzeltme
320
+ - path: doğrudan aç
321
  """
322
  if not image_input:
323
  return None
324
+
325
+ MAX_IMAGE_BYTES = int(os.getenv("MAX_IMAGE_BYTES", "26214400")) # 25MB
 
326
  ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}
327
+
328
  def _is_valid_image_format(url: str) -> bool:
329
  try:
330
  ext = os.path.splitext(urlparse(url).path.lower())[1]
 
331
  return (ext in ALLOWED_EXTS) or (ext == "")
332
  except Exception:
333
  return True
334
+
335
  try:
336
+ # URL
337
  if isinstance(image_input, str) and image_input.startswith(("http://", "https://")):
338
  if not _is_valid_image_format(image_input):
339
  print("[warn] Invalid image extension in URL")
340
  return None
 
341
  headers = {"User-Agent": "Mozilla/5.0"}
342
  resp = requests.get(image_input, timeout=20, headers=headers, allow_redirects=True, stream=True)
343
  resp.raise_for_status()
 
 
344
  ctype = resp.headers.get("Content-Type", "").lower()
345
  if ctype and not ctype.startswith("image/"):
346
  print(f"[warn] Non-image content-type: {ctype}")
347
  return None
 
 
348
  clen = resp.headers.get("Content-Length")
349
  if clen is not None:
350
  try:
 
353
  return None
354
  except Exception:
355
  pass
 
 
356
  data = resp.content
357
  if len(data) > MAX_IMAGE_BYTES:
358
  print(f"[warn] Image too large (actual): {len(data)} bytes")
359
  return None
360
+ img = Image.open(io.BytesIO(data)).convert("RGB")
 
361
  print(f"[info] URL image loaded: size={img.size}")
362
  return img
363
+
364
+ # Base64 (data URL dahil)
365
  if isinstance(image_input, str):
366
  b64 = image_input.strip()
 
 
367
  if b64.startswith("data:image"):
 
368
  if "base64," in b64:
369
  b64 = b64.split("base64,", 1)[1]
370
  else:
 
371
  b64 = b64.split(",", 1)[-1]
 
 
372
  b64 = b64.replace("\n", "").replace("\r", "").replace(" ", "")
373
  missing = (4 - len(b64) % 4) % 4
374
  b64 += "=" * missing
 
375
  try:
376
  data = base64.b64decode(b64, validate=False)
377
  if len(data) > MAX_IMAGE_BYTES:
378
  print(f"[warn] Base64 image too large: {len(data)} bytes")
379
  return None
380
+ img = Image.open(io.BytesIO(data)).convert("RGB")
381
  print(f"[info] Base64 image loaded: size={img.size}")
382
  return img
383
  except Exception as e:
384
  print(f"[warn] Base64 decode/open failed: {e}")
385
+ # path olarak denemeye devam
386
+
387
+ # Yerel path
388
  if isinstance(image_input, str) and os.path.exists(image_input):
389
  img = Image.open(image_input).convert("RGB")
390
  print(f"[info] Local image loaded: size={img.size}")
391
  return img
392
+
393
  except Exception as e:
394
  print(f"[warn] image load failed: {e}")
395
  return None
 
 
396
 
397
+ return None
398
 
399
  def _build_prompt(self, user_text: str, conv_mode: str) -> str:
400
  """Model worker tarzında prompt oluştur"""
401
  if conv_mode not in conv_templates:
402
  conv_mode = DEFAULT_CONV_MODE
403
  conv = conv_templates[conv_mode].copy()
 
 
 
404
  conv.append_message(conv.roles[0], user_text)
405
  conv.append_message(conv.roles[1], None)
406
  return conv.get_prompt()
 
415
  query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
416
  image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")
417
 
418
+ # 1) İlk prompt
419
  prompt = self._build_prompt(query_text, conv_mode)
420
+
421
+ # 2) Görüntü işleme
422
  images = None
423
  image_sizes = None
 
424
  if image_f and self.is_multimodal:
425
  try:
426
  pil_image = self._load_image(image_f)
427
+ if pil_image is not None and self.image_processor is not None:
428
  images_list = [pil_image]
429
  image_sizes = [pil_image.size]
 
 
430
  processed_images = process_images(images_list, self.image_processor, self.model.config)
 
431
  if isinstance(processed_images, list):
432
  images = [img.to(self.model.device, dtype=torch.float16) for img in processed_images]
433
  else:
434
  images = processed_images.to(self.model.device, dtype=torch.float16)
435
+ # Görsel token ekle + im_start/end sarma
 
 
436
  prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
 
 
437
  replace_token = DEFAULT_IMAGE_TOKEN
438
  if self.use_im_start_end:
439
  replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
 
 
440
  prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
 
441
  print(f"[info] Image processed successfully")
 
442
  else:
443
+ print("[warn] Could not load image or image_processor is None.")
444
  except Exception as e:
445
  print(f"[warn] Image processing failed: {e}")
446
+ import traceback; traceback.print_exc()
447
+ images = None; image_sizes = None
 
 
448
 
449
+ # 3) Tokenize
450
  try:
451
  input_ids = tokenizer_image_token(
452
  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
453
  ).unsqueeze(0).to(self.model.device)
454
+ print(f"[debug] input_ids shape: {input_ids.shape} | has images: {images is not None}")
 
 
 
455
  except Exception as e:
456
  print(f"[error] Tokenization failed: {e}")
457
  # Fallback to text-only
458
+ input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids.to(self.model.device)
 
459
  images = None
460
  image_sizes = None
461
 
462
+ # 4) Generation parameters
463
  temperature = float(params.get("temperature", 0.0))
464
  top_p = float(params.get("top_p", 1.0))
465
  repetition_penalty = float(params.get("repetition_penalty", 1.0))
466
  max_new_tokens = min(int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF)), 1024)
467
  do_sample = bool(params.get("do_sample", temperature > 0.001))
468
+
469
+ # Context length check (güvenli boşluk)
470
+ max_context_length = getattr(self.model.config, 'max_position_embeddings', 4096)
471
+ max_new_tokens = min(max_new_tokens, max(1, max_context_length - input_ids.shape[-1] - 50))
 
472
  if max_new_tokens < 1:
473
  return [{"generated_text": "Error: Input too long, exceeds max token length."}]
474
 
475
+ # 5) Generation kwargs
476
  gen_kwargs = {
477
+ "inputs": input_ids,
478
  "max_new_tokens": max_new_tokens,
479
  "temperature": temperature,
480
  "top_p": top_p,
481
  "repetition_penalty": repetition_penalty,
482
  "do_sample": do_sample,
483
  "use_cache": bool(params.get("use_cache", True)),
484
+ "pad_token_id": self.tokenizer.pad_token_id, # pad garanti
485
  }
486
 
487
+ # attention_mask'i model destekliyorsa ekle
488
+ if getattr(self, "_supports_attention_mask", False):
489
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
490
+ gen_kwargs["attention_mask"] = attention_mask
491
+
492
+ # Image args
493
  if images is not None and image_sizes is not None:
494
  gen_kwargs["images"] = images
495
  gen_kwargs["image_sizes"] = image_sizes
 
497
  else:
498
  print("[info] Text-only generation")
499
 
500
+ # 6) Generate (+ unused kwargs için retry)
501
  try:
502
  with torch.inference_mode():
503
  output_ids = self.model.generate(**gen_kwargs)
504
+ except ValueError as e:
505
+ if "model_kwargs" in str(e) and "attention_mask" in str(e):
506
+ print("[hotfix] model doesn't accept attention_mask; retrying without it")
507
+ gen_kwargs.pop("attention_mask", None)
508
+ with torch.inference_mode():
509
+ output_ids = self.model.generate(**gen_kwargs)
510
  else:
511
+ print(f"Generation error: {e}")
512
+ import traceback; traceback.print_exc()
513
+ return [{"generated_text": f"Error during generation: {str(e)}"}]
514
+
515
+ # 7) Output'u input'tan ayır
516
+ if output_ids.shape[-1] > input_ids.shape[-1]:
517
+ response_ids = output_ids[:, input_ids.shape[-1]:]
518
+ text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
519
+ else:
520
+ text = "Error: No response generated"
521
+
522
+ return [{"generated_text": text}]