CanerDedeoglu committed on
Commit 83e2bc9 · verified · 1 Parent(s): 6d1697f

Update handler.py

Files changed (1)
  1. handler.py +296 -248
handler.py CHANGED
@@ -1,24 +1,31 @@
-
  # -*- coding: utf-8 -*-
  import os, io, sys, subprocess, base64
- from typing import Any, Dict, List, Optional

  import torch
  from PIL import Image
  import requests
  import math
- import ast

  # ===== HF model id to use =====
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")

- # Environment settings for Flash Attention
  os.environ.setdefault("FLASH_ATTENTION", "1")
  os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")

- # ===== Fetch the LLaVA source at runtime (no pip) =====
  LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/haotian-liu/LLaVA.git")
- LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1")  # proven, stable
  LLAVA_SRC_DIR = os.getenv("LLAVA_SRC_DIR", "/tmp/llava_src/LLaVA")

  def _ensure_llava():
@@ -33,141 +40,7 @@ def _ensure_llava():

  _ensure_llava()

- # ---- Try to import the mm_utils functions; fall back to our own implementations ----
- try:
-     from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, load_image_from_base64
- except ImportError:
-     # Fallback: use our own implementation
-     from llava.constants import IMAGE_TOKEN_INDEX
-
-     def expand2square(pil_img, background_color):
-         width, height = pil_img.size
-         if width == height:
-             return pil_img
-         elif width > height:
-             result = Image.new(pil_img.mode, (width, width), background_color)
-             result.paste(pil_img, (0, (width - height) // 2))
-             return result
-         else:
-             result = Image.new(pil_img.mode, (height, height), background_color)
-             result.paste(pil_img, ((height - width) // 2, 0))
-             return result
-
-     def select_best_resolution(original_size, possible_resolutions):
-         original_width, original_height = original_size
-         best_fit = None
-         max_effective_resolution = 0
-         min_wasted_resolution = float('inf')
-
-         for width, height in possible_resolutions:
-             scale = min(width / original_width, height / original_height)
-             downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
-             effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
-             wasted_resolution = (width * height) - effective_resolution
-
-             if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
-                 max_effective_resolution = effective_resolution
-                 min_wasted_resolution = wasted_resolution
-                 best_fit = (width, height)
-         return best_fit
-
-     def resize_and_pad_image(image, target_resolution):
-         original_width, original_height = image.size
-         target_width, target_height = target_resolution
-
-         scale_w = target_width / original_width
-         scale_h = target_height / original_height
-
-         if scale_w < scale_h:
-             new_width = target_width
-             new_height = min(math.ceil(original_height * scale_w), target_height)
-         else:
-             new_height = target_height
-             new_width = min(math.ceil(original_width * scale_h), target_width)
-
-         resized_image = image.resize((new_width, new_height))
-         new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
-         paste_x = (target_width - new_width) // 2
-         paste_y = (target_height - new_height) // 2
-         new_image.paste(resized_image, (paste_x, paste_y))
-         return new_image
-
-     def divide_to_patches(image, patch_size):
-         patches = []
-         width, height = image.size
-         for i in range(0, height, patch_size):
-             for j in range(0, width, patch_size):
-                 box = (j, i, j + patch_size, i + patch_size)
-                 patch = image.crop(box)
-                 patches.append(patch)
-         return patches
-
-     def process_anyres_image(image, processor, grid_pinpoints):
-         if type(grid_pinpoints) is list:
-             possible_resolutions = grid_pinpoints
-         else:
-             possible_resolutions = ast.literal_eval(grid_pinpoints)
-         best_resolution = select_best_resolution(image.size, possible_resolutions)
-         image_padded = resize_and_pad_image(image, best_resolution)
-         patches = divide_to_patches(image_padded, processor.crop_size['height'])
-         image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
-         image_patches = [image_original_resize] + patches
-         image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
-                          for image_patch in image_patches]
-         return torch.stack(image_patches, dim=0)
-
-     def process_images(images, image_processor, model_cfg):
-         """CRITICAL: full mm_utils.py implementation"""
-         image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
-         new_images = []
-         if image_aspect_ratio == 'pad':
-             for image in images:
-                 image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
-                 image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
-                 new_images.append(image)
-         elif image_aspect_ratio == "anyres":
-             for image in images:
-                 image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
-                 new_images.append(image)
-         else:
-             return image_processor(images, return_tensors='pt')['pixel_values']
-         if all(x.shape == new_images[0].shape for x in new_images):
-             new_images = torch.stack(new_images, dim=0)
-         return new_images
-
-     def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
-         prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
-
-         def insert_separator(X, sep):
-             return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
-
-         input_ids = []
-         offset = 0
-         if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
-             offset = 1
-             input_ids.append(prompt_chunks[0][0])
-
-         for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
-             input_ids.extend(x[offset:])
-
-         if return_tensors is not None:
-             if return_tensors == 'pt':
-                 return torch.tensor(input_ids, dtype=torch.long)
-             raise ValueError(f'Unsupported tensor type: {return_tensors}')
-         return input_ids
-
-     def get_model_name_from_path(model_path):
-         model_path = model_path.strip("/")
-         model_paths = model_path.split("/")
-         if model_paths[-1].startswith('checkpoint-'):
-             return model_paths[-2] + "_" + model_paths[-1]
-         else:
-             return model_paths[-1]
-
-     def load_image_from_base64(image):
-         return Image.open(io.BytesIO(base64.b64decode(image)))
-
- # ---- LLaVA components (taken from the model worker) ----
  from llava.model.builder import load_pretrained_model
  from llava.constants import (
      IMAGE_TOKEN_INDEX,
@@ -178,42 +51,206 @@ from llava.constants import (
  from llava.conversation import conv_templates
  from llava.utils import disable_torch_init

- # Defaults
- DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
- MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))

  class EndpointHandler:
      """
      Input:
-       {
-         "inputs": { "query": "...", "image": "<url|dataurl|path>" },
-         "parameters": { "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
-                         "repetition_penalty": 1.0, "do_sample": false, "use_cache": true },
-         "conv_mode": "llava_v2"  # optional
-       }
      Output: [ { "generated_text": "..." } ]
      """
      def __init__(self, path: str = "") -> None:
          disable_torch_init()

-         # Determine where PULSE-7B is loaded from (HF Hub or local)
          if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
              model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
          elif os.getenv("HF_MODEL_ID", "").strip():
              model_path = os.getenv("HF_MODEL_ID").strip()
          else:
-             model_path = MODEL_ID  # default: HF Hub PULSE-7B

          self.model_name = get_model_name_from_path(model_path)

-         # Auto-select the attention implementation
          try:
-             import flash_attn
              attn_impl = "flash_attention_2"
-         except ImportError:
              attn_impl = "sdpa"
-
-         # PULSE is LLaVA-based, so it is loaded with the LLaVA loader
          self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
              model_path=model_path,
              model_base=None,
@@ -224,18 +261,15 @@ class EndpointHandler:
          )
          self.model.eval()

          def _patch_forward(obj, label="model"):
              try:
-                 if not hasattr(obj, "forward"):
-                     return False
-                 orig_forward = obj.forward
-
                  def patched_forward(*args, **kwargs):
-                     # New kwargs to drop silently
                      kwargs.pop("cache_position", None)
                      kwargs.pop("input_positions", None)
-                     return orig_forward(*args, **kwargs)
-
                  obj.forward = patched_forward
                  print(f"[hotfix] Patched forward on {label}")
                  return True
@@ -243,34 +277,56 @@ class EndpointHandler:
              print(f"[warn] forward patch failed on {label}: {e}")
              return False

-         # Try on the main model
          _patch_forward(self.model, "self.model")

-         # In some versions the forward chain also goes through an inner module
-         if hasattr(self.model, "model"):
-             _patch_forward(self.model.model, "self.model.model")
-         if hasattr(self.model, "base_model"):
-             _patch_forward(self.model.base_model, "self.model.base_model")
-         # =======================================================================
-
-         # From the model worker: multimodal check
-         self.is_multimodal = 'llava' in self.model_name.lower() or 'pulse' in self.model_name.lower()
-
-         # Image token markers (LLaVA config)
          self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)

-     # ---- helpers ----
      def _load_image(self, img_field: str) -> Optional[Image.Image]:
          """URL / base64 / path -> PIL.Image"""
-         if not img_field:
-             return None
          try:
              if img_field.startswith("data:image"):
                  _, b64 = img_field.split(",", 1)
                  return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
              if img_field.startswith(("http://", "https://")):
-                 r = requests.get(img_field, timeout=20)
-                 r.raise_for_status()
                  return Image.open(io.BytesIO(r.content)).convert("RGB")
              return Image.open(img_field).convert("RGB")
          except Exception as e:
@@ -278,106 +334,90 @@ class EndpointHandler:
              return None

      def _build_prompt(self, user_text: str, conv_mode: str) -> str:
-         """Build the prompt in model-worker style"""
          if conv_mode not in conv_templates:
-             conv_mode = DEFAULT_CONV_MODE
          conv = conv_templates[conv_mode].copy()
-
-         # In the model worker, images are substituted later;
-         # for now, start with text only
          conv.append_message(conv.roles[0], user_text)
          conv.append_message(conv.roles[1], None)
          return conv.get_prompt()

-     # ---- inference ----
      def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-         inputs = data.get("inputs") or {}
-         params = data.get("parameters") or {}
-         conv_mode_req = data.get("conv_mode")

-         conv_mode = conv_mode_req if conv_mode_req in conv_templates else DEFAULT_CONV_MODE
          query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
          image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")

-         # 1) Build the initial prompt (without the image)
          prompt = self._build_prompt(query_text, conv_mode)
-
-         # 2) Image processing (model-worker style)
          images = None
          image_sizes = None
-
          if image_f and self.is_multimodal:
              try:
                  pil_image = self._load_image(image_f)
-                 if pil_image is not None:
                      images_list = [pil_image]
                      image_sizes = [pil_image.size]
-
-                     # Process as in the model worker
                      processed_images = process_images(images_list, self.image_processor, self.model.config)
-
                      if isinstance(processed_images, list):
                          images = [img.to(self.model.device, dtype=torch.float16) for img in processed_images]
                      else:
                          images = processed_images.to(self.model.device, dtype=torch.float16)
-
-                     # Adjust the prompt as in the model worker:
-                     # prepend DEFAULT_IMAGE_TOKEN to the prompt
-                     prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
-
-                     # Compute the replace token (from the model worker)
                      replace_token = DEFAULT_IMAGE_TOKEN
                      if self.use_im_start_end:
                          replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
-
-                     # Replace the image tokens in the prompt
                      prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-
-                     print(f"[info] Image processed successfully")
-                     print(f"[debug] Final prompt: {repr(prompt[:200])}")
                  else:
-                     print("[warn] Could not load image")
              except Exception as e:
                  print(f"[warn] Image processing failed: {e}")
-                 import traceback
-                 traceback.print_exc()
-                 images = None
-                 image_sizes = None

-         # 3) Tokenize (model-worker style)
          try:
              input_ids = tokenizer_image_token(
                  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
              ).unsqueeze(0).to(self.model.device)
-
-             print(f"[debug] input_ids shape: {input_ids.shape}")
-             print(f"[debug] Has images: {images is not None}")
-
          except Exception as e:
              print(f"[error] Tokenization failed: {e}")
-             # Fallback to text-only
-             input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids
-             input_ids = input_ids.to(self.model.device)
-             images = None
-             image_sizes = None

-         # 4) Generation parameters (model-worker style)
          temperature = float(params.get("temperature", 0.0))
          top_p = float(params.get("top_p", 1.0))
          repetition_penalty = float(params.get("repetition_penalty", 1.0))
-         max_new_tokens = min(int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF)), 1024)
          do_sample = bool(params.get("do_sample", temperature > 0.001))
-
-         # Context length check
-         max_context_length = getattr(self.model.config, 'max_position_embeddings', 2048)
-         max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - 50)
-
          if max_new_tokens < 1:
              return [{"generated_text": "Error: Input too long, exceeds max token length."}]

-         # 5) Generation kwargs (model-worker style)
-         gen_kwargs = {
-             "inputs": input_ids,  # the model worker uses 'inputs'
              "max_new_tokens": max_new_tokens,
              "temperature": temperature,
              "top_p": top_p,
@@ -387,29 +427,37 @@ class EndpointHandler:
              "pad_token_id": self.tokenizer.eos_token_id,
          }

-         # Image args (model-worker style)
          if images is not None and image_sizes is not None:
              gen_kwargs["images"] = images
              gen_kwargs["image_sizes"] = image_sizes
-             print(f"[info] Using images in generation")
          else:
-             print("[info] Text-only generation")
-
          try:
              with torch.inference_mode():
                  output_ids = self.model.generate(**gen_kwargs)
-
-             # Separate the output from the input
-             if output_ids.shape[-1] > input_ids.shape[-1]:
-                 response_ids = output_ids[:, input_ids.shape[-1]:]
                  text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
              else:
                  text = "Error: No response generated"
-
          except Exception as e:
              print(f"Generation error: {e}")
-             import traceback
-             traceback.print_exc()
              text = f"Error during generation: {str(e)}"
-
-         return [{"generated_text": text}]
 
  # -*- coding: utf-8 -*-
+ # handler.py — PULSE-7B / LLaVA robust endpoint
+ # - fetches the LLaVA source at runtime via git clone
+ # - image_processor fallback (AutoProcessor / vision_tower)
+ # - safe fallback from anyres -> pad
+ # - abstracts over the preprocess/__call__ difference
+ # - attention_mask made mandatory (fixes HF generate NoneType.new_ones)
+ # - forward patch (silently drops cache_position/input_positions)
+ # - robust image pipeline (pad_to_multiple, crop_size/shortest_edge detection)
+
  import os, io, sys, subprocess, base64
+ from typing import Any, Dict, List, Optional, Tuple

  import torch
  from PIL import Image
  import requests
  import math

  # ===== HF model id to use =====
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")

+ # Flash Attention / attention impl settings (used when available)
  os.environ.setdefault("FLASH_ATTENTION", "1")
  os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")

+ # ===== Fetch the LLaVA source at runtime (no pip available!) =====
  LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/haotian-liu/LLaVA.git")
+ LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1")  # a stable release
  LLAVA_SRC_DIR = os.getenv("LLAVA_SRC_DIR", "/tmp/llava_src/LLaVA")

  def _ensure_llava():

  _ensure_llava()

+ # ---- LLaVA components ----

  from llava.model.builder import load_pretrained_model
  from llava.constants import (
      IMAGE_TOKEN_INDEX,
  from llava.conversation import conv_templates
  from llava.utils import disable_torch_init

+ # HF processor fallbacks
+ from transformers import AutoProcessor, AutoImageProcessor, CLIPImageProcessor
+
+ # ==========================
+ # Helper Functions
+ # ==========================
+
+ def get_model_name_from_path(model_path: str) -> str:
+     p = model_path.strip("/").split("/")
+     return (p[-2] + "_" + p[-1]) if p[-1].startswith("checkpoint-") else p[-1]
+
+ def load_image_from_base64(image: str) -> Image.Image:
+     return Image.open(io.BytesIO(base64.b64decode(image)))
+
+ def expand2square(pil_img: Image.Image, background_color: Tuple[int,int,int]) -> Image.Image:
+     w, h = pil_img.size
+     if w == h:
+         return pil_img
+     if w > h:
+         result = Image.new(pil_img.mode, (w, w), background_color); result.paste(pil_img, (0, (w - h)//2)); return result
+     result = Image.new(pil_img.mode, (h, h), background_color); result.paste(pil_img, ((h - w)//2, 0)); return result
+
+ def select_best_resolution(original_size: Tuple[int,int], possible_resolutions: List[Tuple[int,int]]) -> Tuple[int,int]:
+     ow, oh = original_size
+     best, max_eff, min_waste = None, 0, float("inf")
+     for W, H in possible_resolutions:
+         s = min(W/ow, H/oh)
+         dw, dh = int(ow*s), int(oh*s)
+         eff = min(dw*dh, ow*oh)
+         waste = (W*H) - eff
+         if (eff > max_eff) or (eff == max_eff and waste < min_waste):
+             max_eff, min_waste, best = eff, waste, (W, H)
+     return best
+
+ def resize_and_pad_image(image: Image.Image, target_resolution: Tuple[int,int]) -> Image.Image:
+     ow, oh = image.size
+     W, H = target_resolution
+     sw, sh = W/ow, H/oh
+     if sw < sh:
+         nw, nh = W, min(math.ceil(oh*sw), H)
+     else:
+         nh, nw = H, min(math.ceil(ow*sh), W)
+     resized = image.resize((nw, nh))
+     canvas = Image.new("RGB", (W, H), (0,0,0))
+     canvas.paste(resized, ((W - nw)//2, (H - nh)//2))
+     return canvas
+
+ def pad_to_multiple(image: Image.Image, multiple: int) -> Image.Image:
+     w, h = image.size
+     W = math.ceil(w / multiple) * multiple
+     H = math.ceil(h / multiple) * multiple
+     if (W, H) == (w, h):
+         return image
+     canvas = Image.new(image.mode, (W, H), (0,0,0))
+     canvas.paste(image, (0,0))
+     return canvas
+
+ def divide_to_patches(image: Image.Image, patch_size: int) -> List[Image.Image]:
+     patches = []
+     W, H = image.size
+     for y in range(0, H, patch_size):
+         for x in range(0, W, patch_size):
+             patches.append(image.crop((x, y, x+patch_size, y+patch_size)))
+     return patches
+
+ def _get_crop_size(processor: Any, default: int = 224) -> int:
+     cs = getattr(processor, "crop_size", None)
+     if cs is None:
+         sz = getattr(processor, "size", None)
+         if isinstance(sz, dict): return int(sz.get("shortest_edge", default))
+         if isinstance(sz, int): return int(sz)
+         return int(default)
+     if isinstance(cs, dict):
+         if "height" in cs: return int(cs["height"])
+         if "shortest_edge" in cs: return int(cs["shortest_edge"])
+         for v in cs.values(): return int(v)
+     return int(cs)
+
+ def _get_shortest_edge(processor: Any, fallback: Optional[int] = None) -> int:
+     sz = getattr(processor, "size", None)
+     if isinstance(sz, dict) and "shortest_edge" in sz: return int(sz["shortest_edge"])
+     if isinstance(sz, int): return int(sz)
+     return _get_crop_size(processor, default=(fallback or 224))
+
+ def _preprocess_one(processor: Any, img: Image.Image) -> torch.Tensor:
+     if hasattr(processor, "preprocess"):
+         out = processor.preprocess(img, return_tensors="pt")
+     else:
+         out = processor(img, return_tensors="pt")
+     return out["pixel_values"][0]
+
+ def process_anyres_image(image: Image.Image, processor: Any, grid_pinpoints: Any) -> torch.Tensor:
+     if isinstance(grid_pinpoints, list):
+         poss = grid_pinpoints
+     else:
+         import ast
+         poss = ast.literal_eval(grid_pinpoints)
+     patch_size = _get_crop_size(processor, 224)
+     shortest = _get_shortest_edge(processor, fallback=patch_size)
+     best = select_best_resolution(image.size, poss)
+     padded = resize_and_pad_image(image, best)
+     padded = pad_to_multiple(padded, patch_size)
+     patches = divide_to_patches(padded, patch_size)
+     resized_orig = image.resize((shortest, shortest))
+     tensors = [_preprocess_one(processor, resized_orig)] + [_preprocess_one(processor, p) for p in patches]
+     return torch.stack(tensors, dim=0)
+
+ def process_images(images: List[Image.Image], image_processor: Any, model_cfg: Any) -> torch.Tensor:
+     iar = getattr(model_cfg, "image_aspect_ratio", None) or getattr(model_cfg, "mm_image_aspect_ratio", None)
+     new_images: List[torch.Tensor] = []
+
+     if iar == "pad":
+         for img in images:
+             img_mean = getattr(image_processor, "image_mean", [0.5,0.5,0.5])
+             bg = tuple(int(x*255) for x in img_mean)
+             sq = expand2square(img, bg)
+             new_images.append(_preprocess_one(image_processor, sq))
+
+     elif iar == "anyres":
+         grid = getattr(model_cfg, "image_grid_pinpoints", "[(336,336)]")
+         for img in images:
+             new_images.append(process_anyres_image(img, image_processor, grid))
+
+     else:
+         # if the batched call fails, fall back to one-by-one
+         try:
+             out = image_processor(images, return_tensors="pt")
+             return out["pixel_values"]
+         except TypeError:
+             outs = [image_processor(im, return_tensors="pt") for im in images]
+             pix = [o["pixel_values"][0] for o in outs]
+             return torch.stack(pix, dim=0)
+
+     if all(x.shape == new_images[0].shape for x in new_images):
+         return torch.stack(new_images, dim=0)
+     return new_images
+
+ def tokenizer_image_token(prompt: str, tokenizer: Any, image_token_index: int = IMAGE_TOKEN_INDEX,
+                           return_tensors: Optional[str] = None):
+     chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
+
+     def insert_sep(X, sep):
+         return [e for sub in zip(X, [sep]*len(X)) for e in sub][:-1]
+
+     ids: List[int] = []
+     offset = 0
+     if len(chunks) > 0 and len(chunks[0]) > 0 and chunks[0][0] == tokenizer.bos_token_id:
+         offset = 1
+         ids.append(chunks[0][0])
+
+     for x in insert_sep(chunks, [image_token_index]*(offset+1)):
+         ids.extend(x[offset:])
+
+     if return_tensors is not None:
+         if return_tensors == "pt":
+             return torch.tensor(ids, dtype=torch.long)
+         raise ValueError(f"Unsupported tensor type: {return_tensors}")
+     return ids
+
+ # ==========================
+ # Endpoint Handler
+ # ==========================

  class EndpointHandler:
      """
      Input:
+     {
+       "inputs": { "query": "...", "image": "<url|dataurl|path>" },
+       "parameters": {
+           "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
+           "repetition_penalty": 1.0, "do_sample": false, "use_cache": true
+       },
+       "conv_mode": "llava_v2"  # optional
+     }
      Output: [ { "generated_text": "..." } ]
      """
      def __init__(self, path: str = "") -> None:
          disable_torch_init()

+         # Model path priority: HF_MODEL_LOCAL_DIR > HF_MODEL_ID > MODEL_ID
          if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
              model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
          elif os.getenv("HF_MODEL_ID", "").strip():
              model_path = os.getenv("HF_MODEL_ID").strip()
          else:
+             model_path = MODEL_ID
+
+         if not model_path:
+             raise RuntimeError("Could not determine the model path. Set HF_MODEL_LOCAL_DIR / HF_MODEL_ID / MODEL_ID.")

          self.model_name = get_model_name_from_path(model_path)

+         # Attention implementation (flash if available, otherwise sdpa)
          try:
+             import flash_attn  # noqa: F401
              attn_impl = "flash_attention_2"
+         except Exception:
              attn_impl = "sdpa"
+
+         # Load the model (LLaVA loader)
          self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
              model_path=model_path,
              model_base=None,

          )
          self.model.eval()

+         # ---- forward patch (HF 4.43+ argument compatibility) ----
          def _patch_forward(obj, label="model"):
              try:
+                 if not hasattr(obj, "forward"): return False
+                 orig = obj.forward
                  def patched_forward(*args, **kwargs):
                      kwargs.pop("cache_position", None)
                      kwargs.pop("input_positions", None)
+                     return orig(*args, **kwargs)
                  obj.forward = patched_forward
                  print(f"[hotfix] Patched forward on {label}")
                  return True

                  print(f"[warn] forward patch failed on {label}: {e}")
                  return False

          _patch_forward(self.model, "self.model")
+         if hasattr(self.model, "model"): _patch_forward(self.model.model, "self.model.model")
+         if hasattr(self.model, "base_model"): _patch_forward(self.model.base_model, "self.model.base_model")

+         # ---- image_processor fallback ----
+         if self.image_processor is None:
+             print("[hotfix] image_processor is None, trying the AutoProcessor fallback...")
+             try:
+                 proc = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+                 self.image_processor = getattr(proc, "image_processor", proc)
+             except Exception as e:
+                 print(f"[warn] AutoProcessor failed: {e}")
+                 vt = getattr(self.model.config, "vision_tower", None)
+                 if vt:
+                     try:
+                         self.image_processor = AutoImageProcessor.from_pretrained(vt, trust_remote_code=True)
+                     except Exception:
+                         self.image_processor = CLIPImageProcessor.from_pretrained(vt)
+
+         # anyres -> pad fallback (if the processor/crop_size is missing)
+         iar = getattr(self.model.config, "mm_image_aspect_ratio", None) or \
+               getattr(self.model.config, "image_aspect_ratio", None)
+         needs_crop = (self.image_processor is None) or (getattr(self.image_processor, "crop_size", None) is None)
+         if iar == "anyres" and needs_crop:
+             print("[hotfix] image_aspect_ratio:anyres -> pad (processor/crop_size missing)")
+             if hasattr(self.model.config, "image_aspect_ratio"):
+                 self.model.config.image_aspect_ratio = "pad"
+             if hasattr(self.model.config, "mm_image_aspect_ratio"):
+                 self.model.config.mm_image_aspect_ratio = "pad"
+
+         # multimodal flags
          self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
+         self.is_multimodal = 'llava' in self.model_name.lower() or 'pulse' in self.model_name.lower()
+
+         # Defaults
+         self.DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
+         self.MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))

+     # -------------------------
+     # Internal helpers
+     # -------------------------
      def _load_image(self, img_field: str) -> Optional[Image.Image]:
          """URL / base64 / path -> PIL.Image"""
+         if not img_field: return None
          try:
              if img_field.startswith("data:image"):
                  _, b64 = img_field.split(",", 1)
                  return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
              if img_field.startswith(("http://", "https://")):
+                 r = requests.get(img_field, timeout=20); r.raise_for_status()
                  return Image.open(io.BytesIO(r.content)).convert("RGB")
              return Image.open(img_field).convert("RGB")
          except Exception as e:

              return None

      def _build_prompt(self, user_text: str, conv_mode: str) -> str:
+         """Build a prompt in LLaVA model-worker style."""
          if conv_mode not in conv_templates:
+             conv_mode = self.DEFAULT_CONV_MODE
          conv = conv_templates[conv_mode].copy()
          conv.append_message(conv.roles[0], user_text)
          conv.append_message(conv.roles[1], None)
          return conv.get_prompt()

+     # -------------------------
+     # Inference Entry
+     # -------------------------
      def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         inputs: Dict[str, Any] = data.get("inputs") or {}
+         params: Dict[str, Any] = data.get("parameters") or {}
+         conv_mode_req: Optional[str] = data.get("conv_mode")

+         conv_mode = conv_mode_req if conv_mode_req in conv_templates else self.DEFAULT_CONV_MODE
          query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
          image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")

+         # 1) Prompt
          prompt = self._build_prompt(query_text, conv_mode)
+
+         # 2) Image processing
          images = None
          image_sizes = None
          if image_f and self.is_multimodal:
              try:
                  pil_image = self._load_image(image_f)
+                 if pil_image is not None and self.image_processor is not None:
                      images_list = [pil_image]
                      image_sizes = [pil_image.size]
+
                      processed_images = process_images(images_list, self.image_processor, self.model.config)
+                     # tensor/list to device + dtype
                      if isinstance(processed_images, list):
                          images = [img.to(self.model.device, dtype=torch.float16) for img in processed_images]
                      else:
                          images = processed_images.to(self.model.device, dtype=torch.float16)
+
+                     # Add the image token + wrap with im_start/end
+                     prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
                      replace_token = DEFAULT_IMAGE_TOKEN
                      if self.use_im_start_end:
                          replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                      prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+                     print("[info] Image processed successfully.")
                  else:
+                     print("[warn] Could not load image or image_processor is None.")
              except Exception as e:
                  print(f"[warn] Image processing failed: {e}")
+                 import traceback; traceback.print_exc()
+                 images = None; image_sizes = None

+         # 3) Tokenization (+ attention_mask)
          try:
              input_ids = tokenizer_image_token(
                  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
              ).unsqueeze(0).to(self.model.device)
          except Exception as e:
              print(f"[error] Tokenization failed: {e}")
+             enc = self.tokenizer(query_text, return_tensors="pt")
+             input_ids = enc.input_ids.to(self.model.device)
+             images = None; image_sizes = None
+
+         attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)

+         # 4) Generation params
          temperature = float(params.get("temperature", 0.0))
          top_p = float(params.get("top_p", 1.0))
          repetition_penalty = float(params.get("repetition_penalty", 1.0))
+         max_new_tokens = min(int(params.get("max_new_tokens", self.MAX_NEW_TOKENS_DEF)), 1024)
          do_sample = bool(params.get("do_sample", temperature > 0.001))
+
+         # Context length limit (safety margin)
+         max_context_length = getattr(self.model.config, 'max_position_embeddings', 4096)
+         max_new_tokens = min(max_new_tokens, max(1, max_context_length - input_ids.shape[-1] - 50))
          if max_new_tokens < 1:
              return [{"generated_text": "Error: Input too long, exceeds max token length."}]

+         # 5) Gen kwargs
+         gen_kwargs: Dict[str, Any] = {
+             "inputs": input_ids,
+             "attention_mask": attention_mask,
              "max_new_tokens": max_new_tokens,
              "temperature": temperature,
              "top_p": top_p,

              "pad_token_id": self.tokenizer.eos_token_id,
          }

          if images is not None and image_sizes is not None:
              gen_kwargs["images"] = images
              gen_kwargs["image_sizes"] = image_sizes
+             print("[info] Using images in generation.")
          else:
+             # Strip any image tokens left in the prompt (text-only safety)
+             prompt_clean = prompt.replace(DEFAULT_IMAGE_TOKEN, "") \
+                                  .replace(DEFAULT_IM_START_TOKEN, "") \
+                                  .replace(DEFAULT_IM_END_TOKEN, "")
+             if prompt_clean != prompt:
+                 try:
+                     input_ids = self.tokenizer(prompt_clean, return_tensors="pt").input_ids.to(self.model.device)
+                     attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
+                     gen_kwargs["inputs"] = input_ids
+                     gen_kwargs["attention_mask"] = attention_mask
+                 except Exception as e:
+                     print(f"[warn] prompt cleanup failed: {e}")
+             print("[info] Text-only generation.")
+
+         # 6) Generate
          try:
              with torch.inference_mode():
                  output_ids = self.model.generate(**gen_kwargs)
+             if output_ids.shape[-1] > gen_kwargs["inputs"].shape[-1]:
+                 response_ids = output_ids[:, gen_kwargs["inputs"].shape[-1]:]
                  text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
              else:
                  text = "Error: No response generated"
          except Exception as e:
              print(f"Generation error: {e}")
+             import traceback; traceback.print_exc()
              text = f"Error during generation: {str(e)}"
+
+         return [{"generated_text": text}]
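
For reference, a minimal smoke test of the updated handler, a sketch based only on the input/output schema documented in the EndpointHandler docstring above. The question text, image URL, and local import are illustrative assumptions and not part of this commit:

    # hypothetical local smoke test -- not part of this commit
    from handler import EndpointHandler  # assumes handler.py is importable

    handler = EndpointHandler()  # loads PULSE-ECG/PULSE-7B unless HF_MODEL_LOCAL_DIR / HF_MODEL_ID override it
    payload = {
        "inputs": {
            "query": "Describe the key findings in this ECG.",  # placeholder question
            "image": "https://example.com/ecg.png",             # URL, data URL, or local path
        },
        "parameters": {"max_new_tokens": 256, "temperature": 0.0, "do_sample": False},
        "conv_mode": "llava_v1",  # optional
    }
    result = handler(payload)  # -> [{"generated_text": "..."}]
    print(result[0]["generated_text"])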