CanerDedeoglu committed
Commit 18475c7 · verified · 1 Parent(s): 27bc9ca

Update handler.py

Files changed (1)
  1. handler.py +254 -151
handler.py CHANGED
@@ -1,25 +1,24 @@
- # -*- coding: utf-8 -*-
- # handler.py — PULSE-7B / LLaVA endpoint (using mm_utils_local)
- # - Fetches the LLaVA sources at runtime via git clone (model builder, conv, constants)
- # - Image processing: mm_utils_local.process_images / tokenizer_image_token
- # - image_processor fallback (AutoProcessor / vision_tower)
- # - anyres -> pad safe fallback (mm_utils_local is already robust)
- # - forward patch (silently drop cache_position/input_positions)
- # - attention_mask: only sent if the model supports it (conditional, to avoid unused-kwargs errors)

- import os, io, sys, subprocess, base64, inspect
- from typing import Any, Dict, List, Optional, Tuple

  import torch
  from PIL import Image
  import requests

- # ===== Model ID =====
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")

- # ===== Fetch the LLaVA sources at runtime =====
  LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/haotian-liu/LLaVA.git")
- LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1")
  LLAVA_SRC_DIR = os.getenv("LLAVA_SRC_DIR", "/tmp/llava_src/LLaVA")

  def _ensure_llava():
@@ -34,7 +33,141 @@ def _ensure_llava():

  _ensure_llava()

- # ---- LLaVA components ----
  from llava.model.builder import load_pretrained_model
  from llava.constants import (
      IMAGE_TOKEN_INDEX,
@@ -45,53 +178,42 @@ from llava.constants import (
  from llava.conversation import conv_templates
  from llava.utils import disable_torch_init

- # ---- mm_utils_local (your own file) ----
- from mm_utils_local import (
-     tokenizer_image_token,
-     process_images,
-     get_model_name_from_path,
- )
-
- # HF processor fallbacks
- from transformers import AutoProcessor, AutoImageProcessor, CLIPImageProcessor
-

  class EndpointHandler:
      """
      Input:
-     {
-       "inputs": { "query": "...", "image": "<url|dataurl|path>" },
-       "parameters": {
-         "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
-         "repetition_penalty": 1.0, "do_sample": false, "use_cache": true
-       },
-       "conv_mode": "llava_v2"  # optional
-     }
      Output: [ { "generated_text": "..." } ]
      """
      def __init__(self, path: str = "") -> None:
          disable_torch_init()

-         # Model path priority: HF_MODEL_LOCAL_DIR > HF_MODEL_ID > MODEL_ID
          if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
              model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
          elif os.getenv("HF_MODEL_ID", "").strip():
              model_path = os.getenv("HF_MODEL_ID").strip()
          else:
-             model_path = MODEL_ID
-         if not model_path:
-             raise RuntimeError("Could not determine a model path. Set HF_MODEL_LOCAL_DIR / HF_MODEL_ID / MODEL_ID.")

          self.model_name = get_model_name_from_path(model_path)

-         # Attention implementation selection
          try:
-             import flash_attn  # noqa: F401
              attn_impl = "flash_attention_2"
-         except Exception:
              attn_impl = "sdpa"
-
-         # Load the model
          self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
              model_path=model_path,
              model_base=None,
@@ -102,15 +224,18 @@ class EndpointHandler:
          )
          self.model.eval()

-         # ---- forward patch: compatibility with newer HF kwargs ----
          def _patch_forward(obj, label="model"):
              try:
-                 if not hasattr(obj, "forward"): return False
-                 orig = obj.forward
                  def patched_forward(*args, **kwargs):
                      kwargs.pop("cache_position", None)
                      kwargs.pop("input_positions", None)
-                     return orig(*args, **kwargs)
                  obj.forward = patched_forward
                  print(f"[hotfix] Patched forward on {label}")
                  return True
@@ -118,64 +243,34 @@ class EndpointHandler:
              print(f"[warn] forward patch failed on {label}: {e}")
              return False

          _patch_forward(self.model, "self.model")
-         if hasattr(self.model, "model"): _patch_forward(self.model.model, "self.model.model")
-         if hasattr(self.model, "base_model"): _patch_forward(self.model.base_model, "self.model.base_model")

-         # ---- image_processor fallback ----
-         if self.image_processor is None:
-             print("[hotfix] image_processor is None, trying the AutoProcessor fallback...")
-             try:
-                 proc = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
-                 self.image_processor = getattr(proc, "image_processor", proc)
-             except Exception as e:
-                 print(f"[warn] AutoProcessor failed: {e}")
-                 vt = getattr(self.model.config, "vision_tower", None)
-                 if vt:
-                     try:
-                         self.image_processor = AutoImageProcessor.from_pretrained(vt, trust_remote_code=True)
-                     except Exception:
-                         self.image_processor = CLIPImageProcessor.from_pretrained(vt)
-
-         # anyres -> pad fallback (when processor/crop_size is missing)
-         iar = getattr(self.model.config, "mm_image_aspect_ratio", None) or \
-               getattr(self.model.config, "image_aspect_ratio", None)
-         needs_crop = (self.image_processor is None) or (getattr(self.image_processor, "crop_size", None) is None)
-         if iar == "anyres" and needs_crop:
-             print("[hotfix] image_aspect_ratio:anyres -> pad (processor/crop_size missing)")
-             if hasattr(self.model.config, "image_aspect_ratio"):
-                 self.model.config.image_aspect_ratio = "pad"
-             if hasattr(self.model.config, "mm_image_aspect_ratio"):
-                 self.model.config.mm_image_aspect_ratio = "pad"
-
-         # multimodal flags
          self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
-         self.is_multimodal = ('llava' in self.model_name.lower()) or ('pulse' in self.model_name.lower())

-         # Defaults
-         self.DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
-         self.MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
-
-         # Detect attention_mask support once
-         self._supports_attention_mask = False
-         try:
-             sig = inspect.signature(self.model.forward)
-             self._supports_attention_mask = ("attention_mask" in sig.parameters)
-         except Exception:
-             self._supports_attention_mask = False
-
-     # -------------------------
-     # Internal helpers
-     # -------------------------
      def _load_image(self, img_field: str) -> Optional[Image.Image]:
          """URL / base64 / path -> PIL.Image"""
-         if not img_field: return None
          try:
              if img_field.startswith("data:image"):
                  _, b64 = img_field.split(",", 1)
                  return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
              if img_field.startswith(("http://", "https://")):
-                 r = requests.get(img_field, timeout=20); r.raise_for_status()
                  return Image.open(io.BytesIO(r.content)).convert("RGB")
              return Image.open(img_field).convert("RGB")
          except Exception as e:
@@ -183,88 +278,106 @@ class EndpointHandler:
              return None

      def _build_prompt(self, user_text: str, conv_mode: str) -> str:
-         """Build the prompt in LLaVA model-worker style."""
          if conv_mode not in conv_templates:
-             conv_mode = self.DEFAULT_CONV_MODE
          conv = conv_templates[conv_mode].copy()
          conv.append_message(conv.roles[0], user_text)
          conv.append_message(conv.roles[1], None)
          return conv.get_prompt()

-     # -------------------------
-     # Inference entry
-     # -------------------------
      def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
-         inputs: Dict[str, Any] = data.get("inputs") or {}
-         params: Dict[str, Any] = data.get("parameters") or {}
-         conv_mode_req: Optional[str] = data.get("conv_mode")

-         conv_mode = conv_mode_req if conv_mode_req in conv_templates else self.DEFAULT_CONV_MODE
          query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
          image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")

-         # 1) Prompt
          prompt = self._build_prompt(query_text, conv_mode)
-
-         # 2) Image processing
          images = None
          image_sizes = None
          if image_f and self.is_multimodal:
              try:
                  pil_image = self._load_image(image_f)
-                 if pil_image is not None and self.image_processor is not None:
                      images_list = [pil_image]
                      image_sizes = [pil_image.size]
-
                      processed_images = process_images(images_list, self.image_processor, self.model.config)
                      if isinstance(processed_images, list):
                          images = [img.to(self.model.device, dtype=torch.float16) for img in processed_images]
                      else:
                          images = processed_images.to(self.model.device, dtype=torch.float16)
-
-                     # Add the image token + wrap with im_start/end
-                     prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
                      replace_token = DEFAULT_IMAGE_TOKEN
                      if self.use_im_start_end:
                          replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                      prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-                     print("[info] Image processed successfully.")
                  else:
-                     print("[warn] Could not load image or image_processor is None.")
              except Exception as e:
                  print(f"[warn] Image processing failed: {e}")
-                 import traceback; traceback.print_exc()
-                 images = None; image_sizes = None

-         # 3) Tokenization
          try:
              input_ids = tokenizer_image_token(
                  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
              ).unsqueeze(0).to(self.model.device)
          except Exception as e:
              print(f"[error] Tokenization failed: {e}")
-             enc = self.tokenizer(query_text, return_tensors="pt")
-             input_ids = enc.input_ids.to(self.model.device)
-             images = None; image_sizes = None
-
-         # attention_mask: built here; added later only if the model supports it
-         attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)

-         # 4) Generation params
          temperature = float(params.get("temperature", 0.0))
          top_p = float(params.get("top_p", 1.0))
          repetition_penalty = float(params.get("repetition_penalty", 1.0))
-         max_new_tokens = min(int(params.get("max_new_tokens", self.MAX_NEW_TOKENS_DEF)), 1024)
          do_sample = bool(params.get("do_sample", temperature > 0.001))
-
-         max_context_length = getattr(self.model.config, 'max_position_embeddings', 4096)
-         max_new_tokens = min(max_new_tokens, max(1, max_context_length - input_ids.shape[-1] - 50))
          if max_new_tokens < 1:
              return [{"generated_text": "Error: Input too long, exceeds max token length."}]

-         # 5) Gen kwargs
-         gen_kwargs: Dict[str, Any] = {
-             "inputs": input_ids,
              "max_new_tokens": max_new_tokens,
              "temperature": temperature,
              "top_p": top_p,
@@ -273,40 +386,30 @@ class EndpointHandler:
              "use_cache": bool(params.get("use_cache", True)),
              "pad_token_id": self.tokenizer.eos_token_id,
          }
-         if self._supports_attention_mask:
-             gen_kwargs["attention_mask"] = attention_mask

          if images is not None and image_sizes is not None:
              gen_kwargs["images"] = images
              gen_kwargs["image_sizes"] = image_sizes
-             print("[info] Using images in generation.")
          else:
-             # Strip any image tokens left in the prompt (text-only safety)
-             prompt_clean = prompt.replace(DEFAULT_IMAGE_TOKEN, "") \
-                                  .replace(DEFAULT_IM_START_TOKEN, "") \
-                                  .replace(DEFAULT_IM_END_TOKEN, "")
-             if prompt_clean != prompt:
-                 try:
-                     input_ids = self.tokenizer(prompt_clean, return_tensors="pt").input_ids.to(self.model.device)
-                     gen_kwargs["inputs"] = input_ids
-                     if self._supports_attention_mask:
-                         gen_kwargs["attention_mask"] = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
-                 except Exception as e:
-                     print(f"[warn] prompt cleanup failed: {e}")
-             print("[info] Text-only generation.")
-
-         # 6) Generate
          try:
              with torch.inference_mode():
                  output_ids = self.model.generate(**gen_kwargs)
-             if output_ids.shape[-1] > gen_kwargs["inputs"].shape[-1]:
-                 response_ids = output_ids[:, gen_kwargs["inputs"].shape[-1]:]
                  text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
              else:
                  text = "Error: No response generated"
          except Exception as e:
              print(f"Generation error: {e}")
-             import traceback; traceback.print_exc()
              text = f"Error during generation: {str(e)}"
-
-         return [{"generated_text": text}]
@@ -1,25 +1,24 @@

+ # -*- coding: utf-8 -*-
+ import os, io, sys, subprocess, base64
+ from typing import Any, Dict, List, Optional

  import torch
  from PIL import Image
  import requests
+ import math
+ import ast

+ # ===== HF model id to use =====
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")

+ # Environment variables for Flash Attention
+ os.environ.setdefault("FLASH_ATTENTION", "1")
+ os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
+
+ # ===== Fetch the LLaVA source code at runtime (no pip install) =====
  LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/haotian-liu/LLaVA.git")
+ LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1")  # known-good, stable
  LLAVA_SRC_DIR = os.getenv("LLAVA_SRC_DIR", "/tmp/llava_src/LLaVA")

  def _ensure_llava():
@@ -34,7 +33,141 @@ def _ensure_llava():

  _ensure_llava()

+ # ---- Try to import the mm_utils functions; otherwise fall back to our own implementation ----
+ try:
+     from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, load_image_from_base64
+ except ImportError:
+     # Fallback: use our own implementation
+     from llava.constants import IMAGE_TOKEN_INDEX
+
+     def expand2square(pil_img, background_color):
+         width, height = pil_img.size
+         if width == height:
+             return pil_img
+         elif width > height:
+             result = Image.new(pil_img.mode, (width, width), background_color)
+             result.paste(pil_img, (0, (width - height) // 2))
+             return result
+         else:
+             result = Image.new(pil_img.mode, (height, height), background_color)
+             result.paste(pil_img, ((height - width) // 2, 0))
+             return result
+
+     def select_best_resolution(original_size, possible_resolutions):
+         original_width, original_height = original_size
+         best_fit = None
+         max_effective_resolution = 0
+         min_wasted_resolution = float('inf')
+
+         for width, height in possible_resolutions:
+             scale = min(width / original_width, height / original_height)
+             downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+             effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+             wasted_resolution = (width * height) - effective_resolution
+
+             if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+                 max_effective_resolution = effective_resolution
+                 min_wasted_resolution = wasted_resolution
+                 best_fit = (width, height)
+         return best_fit
+
+     def resize_and_pad_image(image, target_resolution):
+         original_width, original_height = image.size
+         target_width, target_height = target_resolution
+
+         scale_w = target_width / original_width
+         scale_h = target_height / original_height
+
+         if scale_w < scale_h:
+             new_width = target_width
+             new_height = min(math.ceil(original_height * scale_w), target_height)
+         else:
+             new_height = target_height
+             new_width = min(math.ceil(original_width * scale_h), target_width)
+
+         resized_image = image.resize((new_width, new_height))
+         new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+         paste_x = (target_width - new_width) // 2
+         paste_y = (target_height - new_height) // 2
+         new_image.paste(resized_image, (paste_x, paste_y))
+         return new_image
+
+     def divide_to_patches(image, patch_size):
+         patches = []
+         width, height = image.size
+         for i in range(0, height, patch_size):
+             for j in range(0, width, patch_size):
+                 box = (j, i, j + patch_size, i + patch_size)
+                 patch = image.crop(box)
+                 patches.append(patch)
+         return patches
+
+     def process_anyres_image(image, processor, grid_pinpoints):
+         if type(grid_pinpoints) is list:
+             possible_resolutions = grid_pinpoints
+         else:
+             possible_resolutions = ast.literal_eval(grid_pinpoints)
+         best_resolution = select_best_resolution(image.size, possible_resolutions)
+         image_padded = resize_and_pad_image(image, best_resolution)
+         patches = divide_to_patches(image_padded, processor.crop_size['height'])
+         image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+         image_patches = [image_original_resize] + patches
+         image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
+                          for image_patch in image_patches]
+         return torch.stack(image_patches, dim=0)
+
+     def process_images(images, image_processor, model_cfg):
+         """CRITICAL: full mm_utils.py implementation"""
+         image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+         new_images = []
+         if image_aspect_ratio == 'pad':
+             for image in images:
+                 image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+                 image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+                 new_images.append(image)
+         elif image_aspect_ratio == "anyres":
+             for image in images:
+                 image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+                 new_images.append(image)
+         else:
+             return image_processor(images, return_tensors='pt')['pixel_values']
+         if all(x.shape == new_images[0].shape for x in new_images):
+             new_images = torch.stack(new_images, dim=0)
+         return new_images
+
+     def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+         prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
+
+         def insert_separator(X, sep):
+             return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+         input_ids = []
+         offset = 0
+         if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+             offset = 1
+             input_ids.append(prompt_chunks[0][0])
+
+         for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+             input_ids.extend(x[offset:])
+
+         if return_tensors is not None:
+             if return_tensors == 'pt':
+                 return torch.tensor(input_ids, dtype=torch.long)
+             raise ValueError(f'Unsupported tensor type: {return_tensors}')
+         return input_ids
+
+     def get_model_name_from_path(model_path):
+         model_path = model_path.strip("/")
+         model_paths = model_path.split("/")
+         if model_paths[-1].startswith('checkpoint-'):
+             return model_paths[-2] + "_" + model_paths[-1]
+         else:
+             return model_paths[-1]
+
+     def load_image_from_base64(image):
+         return Image.open(io.BytesIO(base64.b64decode(image)))
+
+ # ---- LLaVA components (taken from the model worker) ----
  from llava.model.builder import load_pretrained_model
  from llava.constants import (
      IMAGE_TOKEN_INDEX,
@@ -45,53 +178,42 @@ from llava.constants import (
  from llava.conversation import conv_templates
  from llava.utils import disable_torch_init

+ # Defaults
+ DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
+ MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))

  class EndpointHandler:
      """
      Input:
+     {
+       "inputs": { "query": "...", "image": "<url|dataurl|path>" },
+       "parameters": { "max_new_tokens": 256, "temperature": 0.0, "top_p": 1.0,
+                       "repetition_penalty": 1.0, "do_sample": false, "use_cache": true },
+       "conv_mode": "llava_v2"  # optional
+     }
      Output: [ { "generated_text": "..." } ]
      """
      def __init__(self, path: str = "") -> None:
          disable_torch_init()

+         # Determine the model path, whether PULSE-7B comes from the HF Hub or a local dir
          if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
              model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
          elif os.getenv("HF_MODEL_ID", "").strip():
              model_path = os.getenv("HF_MODEL_ID").strip()
          else:
+             model_path = MODEL_ID  # default: PULSE-7B from the HF Hub

          self.model_name = get_model_name_from_path(model_path)

+         # Pick the attention implementation automatically
          try:
+             import flash_attn
              attn_impl = "flash_attention_2"
+         except ImportError:
              attn_impl = "sdpa"
+
+         # PULSE is LLaVA-based, so it is loaded with the LLaVA loader
          self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
              model_path=model_path,
              model_base=None,
@@ -102,15 +224,18 @@ class EndpointHandler:
          )
          self.model.eval()

          def _patch_forward(obj, label="model"):
              try:
+                 if not hasattr(obj, "forward"):
+                     return False
+                 orig_forward = obj.forward
+
                  def patched_forward(*args, **kwargs):
+                     # Newer kwargs to drop silently
                      kwargs.pop("cache_position", None)
                      kwargs.pop("input_positions", None)
+                     return orig_forward(*args, **kwargs)
+
                  obj.forward = patched_forward
                  print(f"[hotfix] Patched forward on {label}")
                  return True
@@ -118,64 +243,34 @@ class EndpointHandler:
              print(f"[warn] forward patch failed on {label}: {e}")
              return False

+         # Try on the top-level model first
          _patch_forward(self.model, "self.model")

+         # In some versions the forward chain also goes through an inner module
+         if hasattr(self.model, "model"):
+             _patch_forward(self.model.model, "self.model.model")
+         if hasattr(self.model, "base_model"):
+             _patch_forward(self.model.base_model, "self.model.base_model")
+         # =======================================================================
+
+         # From the model worker: multimodal check
+         self.is_multimodal = 'llava' in self.model_name.lower() or 'pulse' in self.model_name.lower()
+
+         # Image token markers (LLaVA config)
          self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)

+     # ---- helpers ----
      def _load_image(self, img_field: str) -> Optional[Image.Image]:
          """URL / base64 / path -> PIL.Image"""
+         if not img_field:
+             return None
          try:
              if img_field.startswith("data:image"):
                  _, b64 = img_field.split(",", 1)
                  return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
              if img_field.startswith(("http://", "https://")):
+                 r = requests.get(img_field, timeout=20)
+                 r.raise_for_status()
                  return Image.open(io.BytesIO(r.content)).convert("RGB")
              return Image.open(img_field).convert("RGB")
          except Exception as e:
@@ -183,88 +278,106 @@ class EndpointHandler:
              return None

      def _build_prompt(self, user_text: str, conv_mode: str) -> str:
+         """Build the prompt in model-worker style"""
          if conv_mode not in conv_templates:
+             conv_mode = DEFAULT_CONV_MODE
          conv = conv_templates[conv_mode].copy()
+
+         # In the model worker, image tokens are substituted in later;
+         # start with text only for now
          conv.append_message(conv.roles[0], user_text)
          conv.append_message(conv.roles[1], None)
          return conv.get_prompt()

+     # ---- inference ----
      def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         inputs = data.get("inputs") or {}
+         params = data.get("parameters") or {}
+         conv_mode_req = data.get("conv_mode")

+         conv_mode = conv_mode_req if conv_mode_req in conv_templates else DEFAULT_CONV_MODE
          query_text = inputs.get("query", "") or inputs.get("text", "") or inputs.get("prompt", "")
          image_f = inputs.get("image") or inputs.get("image_url") or inputs.get("image_base64")

+         # 1) Build the initial prompt (without the image)
          prompt = self._build_prompt(query_text, conv_mode)
+
+         # 2) Image processing (model-worker style)
          images = None
          image_sizes = None
+
          if image_f and self.is_multimodal:
              try:
                  pil_image = self._load_image(image_f)
+                 if pil_image is not None:
                      images_list = [pil_image]
                      image_sizes = [pil_image.size]
+
+                     # Process as in the model worker
                      processed_images = process_images(images_list, self.image_processor, self.model.config)
+
                      if isinstance(processed_images, list):
                          images = [img.to(self.model.device, dtype=torch.float16) for img in processed_images]
                      else:
                          images = processed_images.to(self.model.device, dtype=torch.float16)
+
+                     # Adjust the prompt as in the model worker:
+                     # prepend DEFAULT_IMAGE_TOKEN
+                     prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
+
+                     # Compute the replace token (from the model worker)
                      replace_token = DEFAULT_IMAGE_TOKEN
                      if self.use_im_start_end:
                          replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+
+                     # Replace the image tokens in the prompt
                      prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+                     print(f"[info] Image processed successfully")
+                     print(f"[debug] Final prompt: {repr(prompt[:200])}")
                  else:
+                     print("[warn] Could not load image")
              except Exception as e:
                  print(f"[warn] Image processing failed: {e}")
+                 import traceback
+                 traceback.print_exc()
+                 images = None
+                 image_sizes = None

+         # 3) Tokenize (model-worker style)
          try:
              input_ids = tokenizer_image_token(
                  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
              ).unsqueeze(0).to(self.model.device)
+
+             print(f"[debug] input_ids shape: {input_ids.shape}")
+             print(f"[debug] Has images: {images is not None}")
+
          except Exception as e:
              print(f"[error] Tokenization failed: {e}")
+             # Fallback to text-only
+             input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids
+             input_ids = input_ids.to(self.model.device)
+             images = None
+             image_sizes = None

+         # 4) Generation parameters (model-worker style)
          temperature = float(params.get("temperature", 0.0))
          top_p = float(params.get("top_p", 1.0))
          repetition_penalty = float(params.get("repetition_penalty", 1.0))
+         max_new_tokens = min(int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF)), 1024)
          do_sample = bool(params.get("do_sample", temperature > 0.001))
+
+         # Context length check
+         max_context_length = getattr(self.model.config, 'max_position_embeddings', 2048)
+         max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - 50)
+
          if max_new_tokens < 1:
              return [{"generated_text": "Error: Input too long, exceeds max token length."}]

+         # 5) Generation kwargs (model-worker style)
+         gen_kwargs = {
+             "inputs": input_ids,  # the model worker uses 'inputs'
              "max_new_tokens": max_new_tokens,
              "temperature": temperature,
              "top_p": top_p,
@@ -273,40 +386,30 @@ class EndpointHandler:
              "use_cache": bool(params.get("use_cache", True)),
              "pad_token_id": self.tokenizer.eos_token_id,
          }

+         # Image args (model-worker style)
          if images is not None and image_sizes is not None:
              gen_kwargs["images"] = images
              gen_kwargs["image_sizes"] = image_sizes
+             print(f"[info] Using images in generation")
          else:
+             print("[info] Text-only generation")
+
          try:
              with torch.inference_mode():
                  output_ids = self.model.generate(**gen_kwargs)
+
+             # Separate the generated output from the input
+             if output_ids.shape[-1] > input_ids.shape[-1]:
+                 response_ids = output_ids[:, input_ids.shape[-1]:]
                  text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
              else:
                  text = "Error: No response generated"
+
          except Exception as e:
              print(f"Generation error: {e}")
+             import traceback
+             traceback.print_exc()
              text = f"Error during generation: {str(e)}"
+
+         return [{"generated_text": text}]