CanerDedeoglu committed (verified)
Commit 27bc9ca · Parent(s): 0b38f8d

Update handler.py

Files changed (1): handler.py (+34 −185)
handler.py CHANGED
@@ -1,31 +1,25 @@
 # -*- coding: utf-8 -*-
-# handler.py — PULSE-7B / LLaVA robust endpoint
-# - fetches the LLaVA source at runtime via git clone
+# handler.py — PULSE-7B / LLaVA endpoint (using mm_utils_local)
+# - fetches the LLaVA sources at runtime via git clone (model builder, conv, constants)
+# - image handling: mm_utils_local.process_images / tokenizer_image_token
 # - image_processor fallback (AutoProcessor / vision_tower)
-# - anyres -> pad safe fallback
-# - abstracts over the preprocess/__call__ difference
-# - attention_mask mandatory (fixes HF generate NoneType.new_ones)
+# - anyres -> pad safe fallback (mm_utils_local is already robust)
 # - forward patch (silently drop cache_position/input_positions)
-# - robust image pipeline (pad_to_multiple, crop_size/shortest_edge detection)
+# - attention_mask: sent only if the model supports it (conditional, avoids unused-kwargs errors)
 
-import os, io, sys, subprocess, base64
+import os, io, sys, subprocess, base64, inspect
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from PIL import Image
 import requests
-import math
 
-# ===== HF model id to use =====
+# ===== Model ID =====
 MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
 
-# Flash Attention / attention impl settings (used if available)
-os.environ.setdefault("FLASH_ATTENTION", "1")
-os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
-
-# ===== Fetch the LLaVA source at runtime (no pip package!) =====
+# ===== Fetch the LLaVA sources at runtime =====
 LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/haotian-liu/LLaVA.git")
-LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1")  # a stable release
+LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1")
 LLAVA_SRC_DIR = os.getenv("LLAVA_SRC_DIR", "/tmp/llava_src/LLaVA")
 
 def _ensure_llava():
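The body of `_ensure_llava()` falls outside this hunk. For orientation, a minimal sketch of the runtime-clone pattern the header comments describe, assuming only the `LLAVA_*` variables above; the exact git flags and error handling are hypothetical:

```python
def _ensure_llava_sketch():
    # Hypothetical sketch: clone the pinned LLaVA ref once, then make it importable.
    if not os.path.isdir(LLAVA_SRC_DIR):
        subprocess.check_call([
            "git", "clone", "--depth", "1", "--branch", LLAVA_GIT_REF,
            LLAVA_GIT_URL, LLAVA_SRC_DIR,
        ])
    if LLAVA_SRC_DIR not in sys.path:
        sys.path.insert(0, LLAVA_SRC_DIR)  # makes `import llava` resolvable
```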
@@ -51,168 +45,16 @@ from llava.constants import (
 from llava.conversation import conv_templates
 from llava.utils import disable_torch_init
 
+# ---- mm_utils_local (your file) ----
+from mm_utils_local import (
+    tokenizer_image_token,
+    process_images,
+    get_model_name_from_path,
+)
+
 # HF processor fallbacks
 from transformers import AutoProcessor, AutoImageProcessor, CLIPImageProcessor
 
-# ==========================
-# Helper Functions
-# ==========================
-
-def get_model_name_from_path(model_path: str) -> str:
-    p = model_path.strip("/").split("/")
-    return (p[-2] + "_" + p[-1]) if p[-1].startswith("checkpoint-") else p[-1]
-
-def load_image_from_base64(image: str) -> Image.Image:
-    return Image.open(io.BytesIO(base64.b64decode(image)))
-
-def expand2square(pil_img: Image.Image, background_color: Tuple[int,int,int]) -> Image.Image:
-    w, h = pil_img.size
-    if w == h:
-        return pil_img
-    if w > h:
-        result = Image.new(pil_img.mode, (w, w), background_color); result.paste(pil_img, (0, (w - h)//2)); return result
-    result = Image.new(pil_img.mode, (h, h), background_color); result.paste(pil_img, ((h - w)//2, 0)); return result
-
-def select_best_resolution(original_size: Tuple[int,int], possible_resolutions: List[Tuple[int,int]]) -> Tuple[int,int]:
-    ow, oh = original_size
-    best, max_eff, min_waste = None, 0, float("inf")
-    for W, H in possible_resolutions:
-        s = min(W/ow, H/oh)
-        dw, dh = int(ow*s), int(oh*s)
-        eff = min(dw*dh, ow*oh)
-        waste = (W*H) - eff
-        if (eff > max_eff) or (eff == max_eff and waste < min_waste):
-            max_eff, min_waste, best = eff, waste, (W, H)
-    return best
-
-def resize_and_pad_image(image: Image.Image, target_resolution: Tuple[int,int]) -> Image.Image:
-    ow, oh = image.size
-    W, H = target_resolution
-    sw, sh = W/ow, H/oh
-    if sw < sh:
-        nw, nh = W, min(math.ceil(oh*sw), H)
-    else:
-        nh, nw = H, min(math.ceil(ow*sh), W)
-    resized = image.resize((nw, nh))
-    canvas = Image.new("RGB", (W, H), (0,0,0))
-    canvas.paste(resized, ((W - nw)//2, (H - nh)//2))
-    return canvas
-
-def pad_to_multiple(image: Image.Image, multiple: int) -> Image.Image:
-    w, h = image.size
-    W = math.ceil(w / multiple) * multiple
-    H = math.ceil(h / multiple) * multiple
-    if (W, H) == (w, h):
-        return image
-    canvas = Image.new(image.mode, (W, H), (0,0,0))
-    canvas.paste(image, (0,0))
-    return canvas
-
-def divide_to_patches(image: Image.Image, patch_size: int) -> List[Image.Image]:
-    patches = []
-    W, H = image.size
-    for y in range(0, H, patch_size):
-        for x in range(0, W, patch_size):
-            patches.append(image.crop((x, y, x+patch_size, y+patch_size)))
-    return patches
-
-def _get_crop_size(processor: Any, default: int = 224) -> int:
-    cs = getattr(processor, "crop_size", None)
-    if cs is None:
-        sz = getattr(processor, "size", None)
-        if isinstance(sz, dict): return int(sz.get("shortest_edge", default))
-        if isinstance(sz, int): return int(sz)
-        return int(default)
-    if isinstance(cs, dict):
-        if "height" in cs: return int(cs["height"])
-        if "shortest_edge" in cs: return int(cs["shortest_edge"])
-        for v in cs.values(): return int(v)
-    return int(cs)
-
-def _get_shortest_edge(processor: Any, fallback: Optional[int] = None) -> int:
-    sz = getattr(processor, "size", None)
-    if isinstance(sz, dict) and "shortest_edge" in sz: return int(sz["shortest_edge"])
-    if isinstance(sz, int): return int(sz)
-    return _get_crop_size(processor, default=(fallback or 224))
-
-def _preprocess_one(processor: Any, img: Image.Image) -> torch.Tensor:
-    if hasattr(processor, "preprocess"):
-        out = processor.preprocess(img, return_tensors="pt")
-    else:
-        out = processor(img, return_tensors="pt")
-    return out["pixel_values"][0]
-
-def process_anyres_image(image: Image.Image, processor: Any, grid_pinpoints: Any) -> torch.Tensor:
-    if isinstance(grid_pinpoints, list):
-        poss = grid_pinpoints
-    else:
-        import ast
-        poss = ast.literal_eval(grid_pinpoints)
-    patch_size = _get_crop_size(processor, 224)
-    shortest = _get_shortest_edge(processor, fallback=patch_size)
-    best = select_best_resolution(image.size, poss)
-    padded = resize_and_pad_image(image, best)
-    padded = pad_to_multiple(padded, patch_size)
-    patches = divide_to_patches(padded, patch_size)
-    resized_orig = image.resize((shortest, shortest))
-    tensors = [_preprocess_one(processor, resized_orig)] + [_preprocess_one(processor, p) for p in patches]
-    return torch.stack(tensors, dim=0)
-
-def process_images(images: List[Image.Image], image_processor: Any, model_cfg: Any) -> torch.Tensor:
-    iar = getattr(model_cfg, "image_aspect_ratio", None) or getattr(model_cfg, "mm_image_aspect_ratio", None)
-    new_images: List[torch.Tensor] = []
-
-    if iar == "pad":
-        for img in images:
-            img_mean = getattr(image_processor, "image_mean", [0.5,0.5,0.5])
-            bg = tuple(int(x*255) for x in img_mean)
-            sq = expand2square(img, bg)
-            new_images.append(_preprocess_one(image_processor, sq))
-
-    elif iar == "anyres":
-        grid = getattr(model_cfg, "image_grid_pinpoints", "[(336,336)]")
-        for img in images:
-            new_images.append(process_anyres_image(img, image_processor, grid))
-
-    else:
-        # if the batched call fails, try one by one
-        try:
-            out = image_processor(images, return_tensors="pt")
-            return out["pixel_values"]
-        except TypeError:
-            outs = [image_processor(im, return_tensors="pt") for im in images]
-            pix = [o["pixel_values"][0] for o in outs]
-            return torch.stack(pix, dim=0)
-
-    if all(x.shape == new_images[0].shape for x in new_images):
-        return torch.stack(new_images, dim=0)
-    return new_images
-
-def tokenizer_image_token(prompt: str, tokenizer: Any, image_token_index: int = IMAGE_TOKEN_INDEX,
-                          return_tensors: Optional[str] = None):
-    chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
-
-    def insert_sep(X, sep):
-        return [e for sub in zip(X, [sep]*len(X)) for e in sub][:-1]
-
-    ids: List[int] = []
-    offset = 0
-    if len(chunks) > 0 and len(chunks[0]) > 0 and chunks[0][0] == tokenizer.bos_token_id:
-        offset = 1
-        ids.append(chunks[0][0])
-
-    for x in insert_sep(chunks, [image_token_index]*(offset+1)):
-        ids.extend(x[offset:])
-
-    if return_tensors is not None:
-        if return_tensors == "pt":
-            return torch.tensor(ids, dtype=torch.long)
-        raise ValueError(f"Unsupported tensor type: {return_tensors}")
-    return ids
-
-# ==========================
-# Endpoint Handler
-# ==========================
 
 class EndpointHandler:
     """
@@ -237,20 +79,19 @@ class EndpointHandler:
             model_path = os.getenv("HF_MODEL_ID").strip()
         else:
             model_path = MODEL_ID
-
         if not model_path:
             raise RuntimeError("Could not determine the model path. Set HF_MODEL_LOCAL_DIR / HF_MODEL_ID / MODEL_ID.")
 
         self.model_name = get_model_name_from_path(model_path)
 
-        # Attention implementation (flash if available, else sdpa)
+        # Attention implementation selection
         try:
             import flash_attn  # noqa: F401
             attn_impl = "flash_attention_2"
         except Exception:
             attn_impl = "sdpa"
 
-        # Load the model (LLaVA loader)
+        # Load the model
         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
             model_path=model_path,
             model_base=None,
@@ -261,7 +102,7 @@
         )
         self.model.eval()
 
-        # ---- forward patch (HF 4.43+ arg compatibility) ----
+        # ---- forward patch: compatibility with newer HF args ----
        def _patch_forward(obj, label="model"):
            try:
                if not hasattr(obj, "forward"): return False
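The rest of `_patch_forward` sits outside this hunk. The kwarg-dropping wrapper the comment describes would look roughly like this (a sketch of the pattern, not the committed body):

```python
import functools

def _drop_unsupported_kwargs(forward_fn):
    # Sketch: silently discard kwargs newer HF versions pass but LLaVA's forward rejects.
    @functools.wraps(forward_fn)
    def wrapped(*args, **kwargs):
        kwargs.pop("cache_position", None)
        kwargs.pop("input_positions", None)
        return forward_fn(*args, **kwargs)
    return wrapped

# e.g. model.forward = _drop_unsupported_kwargs(model.forward)
```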
@@ -309,12 +150,20 @@
 
         # multimodal flags
         self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
-        self.is_multimodal = 'llava' in self.model_name.lower() or 'pulse' in self.model_name.lower()
+        self.is_multimodal = ('llava' in self.model_name.lower()) or ('pulse' in self.model_name.lower())
 
         # Defaults
         self.DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
         self.MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
 
+        # detect attention_mask support once
+        self._supports_attention_mask = False
+        try:
+            sig = inspect.signature(self.model.forward)
+            self._supports_attention_mask = ("attention_mask" in sig.parameters)
+        except Exception:
+            self._supports_attention_mask = False
+
         # -------------------------
         # Internal helpers
         # -------------------------
@@ -368,7 +217,6 @@
             image_sizes = [pil_image.size]
 
             processed_images = process_images(images_list, self.image_processor, self.model.config)
-            # tensor/list to device + dtype
             if isinstance(processed_images, list):
                 images = [img.to(self.model.device, dtype=torch.float16) for img in processed_images]
             else:
@@ -388,7 +236,7 @@
             import traceback; traceback.print_exc()
             images = None; image_sizes = None
 
-        # 3) Tokenization (+ attention_mask)
+        # 3) Tokenization
         try:
             input_ids = tokenizer_image_token(
                 prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
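For intuition, `tokenizer_image_token` (removed above, now imported from mm_utils_local) splits the prompt on `<image>`, keeps a single BOS, and splices the sentinel id IMAGE_TOKEN_INDEX (-200 in upstream LLaVA) between the tokenized chunks. With hypothetical token ids:

```python
# Toy illustration; the chunk ids below are made up, only the splicing is real.
IMAGE_TOKEN_INDEX = -200
prompt = "USER: <image>\nDescribe the ECG. ASSISTANT:"
chunks = prompt.split("<image>")                     # two text chunks
chunk_ids = [[1, 3148, 1001], [1, 13, 4002, 9047]]   # leading 1 = BOS on each chunk
ids = [chunk_ids[0][0]] + chunk_ids[0][1:] + [IMAGE_TOKEN_INDEX] + chunk_ids[1][1:]
print(ids)  # [1, 3148, 1001, -200, 13, 4002, 9047]
```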
@@ -399,6 +247,7 @@
             input_ids = enc.input_ids.to(self.model.device)
             images = None; image_sizes = None
 
+        # attention_mask: built here, attached below only if the model supports it
         attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
 
         # 4) Generation params
@@ -408,7 +257,6 @@
         max_new_tokens = min(int(params.get("max_new_tokens", self.MAX_NEW_TOKENS_DEF)), 1024)
         do_sample = bool(params.get("do_sample", temperature > 0.001))
 
-        # Context-length cap (safety margin)
         max_context_length = getattr(self.model.config, 'max_position_embeddings', 4096)
         max_new_tokens = min(max_new_tokens, max(1, max_context_length - input_ids.shape[-1] - 50))
         if max_new_tokens < 1:
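The cap above just reserves a 50-token safety margin inside the model's context window. With example numbers:

```python
# Example: a 4096-token context with a 3100-token multimodal prompt.
max_context_length = 4096
prompt_len = 3100
requested = 1024
max_new_tokens = min(requested, max(1, max_context_length - prompt_len - 50))
print(max_new_tokens)  # 946
```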
@@ -417,7 +265,6 @@
         # 5) Gen kwargs
         gen_kwargs: Dict[str, Any] = {
             "inputs": input_ids,
-            "attention_mask": attention_mask,
             "max_new_tokens": max_new_tokens,
             "temperature": temperature,
             "top_p": top_p,
@@ -426,6 +273,8 @@
             "use_cache": bool(params.get("use_cache", True)),
             "pad_token_id": self.tokenizer.eos_token_id,
         }
+        if self._supports_attention_mask:
+            gen_kwargs["attention_mask"] = attention_mask
 
         if images is not None and image_sizes is not None:
             gen_kwargs["images"] = images
@@ -439,9 +288,9 @@
             if prompt_clean != prompt:
                 try:
                     input_ids = self.tokenizer(prompt_clean, return_tensors="pt").input_ids.to(self.model.device)
-                    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
                     gen_kwargs["inputs"] = input_ids
-                    gen_kwargs["attention_mask"] = attention_mask
+                    if self._supports_attention_mask:
+                        gen_kwargs["attention_mask"] = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
                 except Exception as e:
                     print(f"[warn] prompt cleanup failed: {e}")
             print("[info] Text-only generation.")
 