CanerDedeoglu committed on
Commit 57494c3 · verified · 1 Parent(s): ae66524

Update handler.py

Files changed (1):
  handler.py  +76 -59
handler.py CHANGED
@@ -1,11 +1,10 @@
 # -*- coding: utf-8 -*-
-# handler.py — PULSE-7B / LLaVA robust endpoint (minimal & stable)
-# - Load LLaVA via the PULSE fork (AIMedLab/PULSE:dev)
-# - Safe image loader + processor normalization
-# - ANYRES->PAD fallback
-# - Forward patch: silently drop cache_position/input_positions
-# - CRITICAL FIX: pass generate both `inputs` and `input_ids` (ends the NoneType.new_ones error)
-# - Do not send attention_mask (LLaVA handles it internally)
+# handler.py — PULSE-7B / LLaVA robust endpoint (final fix)
+# - Source: AIMedLab/PULSE (dev) LLaVA fork
+# - Safe image load + processor normalization
+# - Build a FULL attention_mask
+# - Run generation via HF GenerationMixin (bypass LLaVA's generate override)
+# - forward() patch: drop cache_position/input_positions
 
 import os, io, sys, subprocess, base64
 from typing import Any, Dict, List, Optional, Tuple
@@ -18,15 +17,15 @@ import ast
 import inspect
 from urllib.parse import urlparse
 
-# ===== Model/config =====
+# ===== Model / Config =====
 MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
 DEFAULT_VISION_TOWER_ID = os.getenv("HF_VISION_TOWER_ID", "openai/clip-vit-large-patch14-336")
 
-# ===== Flash Attention/env =====
+# Flash Attention
 os.environ.setdefault("FLASH_ATTENTION", "1")
 os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
 
-# ===== Pull LLaVA from PULSE repo =====
+# ===== Fetch the LLaVA (AIMedLab/PULSE dev) source code =====
 LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/AIMedLab/PULSE.git")
 LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "dev")
 LLAVA_SRC_DIR = os.getenv("LLAVA_SRC_DIR", "/tmp/llava_src/PULSE/LLaVA")
@@ -48,6 +47,7 @@ _ensure_llava()
 try:
     from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, load_image_from_base64
 except Exception:
+    # Minimal fallbacks
    from llava.constants import IMAGE_TOKEN_INDEX
 
    def expand2square(pil_img: Image.Image, background_color: Tuple[int,int,int]) -> Image.Image:
@@ -129,15 +129,12 @@ except Exception:
         chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
         def insert_sep(X, sep):
             return [e for sub in zip(X, [sep]*len(X)) for e in sub][:-1]
-        ids = []
-        offset = 0
+        ids = []; offset = 0
         if len(chunks) > 0 and len(chunks[0]) > 0 and chunks[0][0] == tokenizer.bos_token_id:
-            offset = 1
-            ids.append(chunks[0][0])
-        for x in insert_sep(chunks, [IMAGE_TOKEN_INDEX]*(offset+1)):
+            offset = 1; ids.append(chunks[0][0])
+        for x in insert_sep(chunks, [image_token_index]*(offset+1)):
             ids.extend(x[offset:])
-        if return_tensors == 'pt':
-            return torch.tensor(ids, dtype=torch.long)
+        if return_tensors == 'pt': return torch.tensor(ids, dtype=torch.long)
         return ids
 
     def get_model_name_from_path(model_path):
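The splicing above is easier to see on toy ids. Below is a minimal, self-contained sketch of the same logic; the fake tokenizer and its token values are invented for illustration, while `IMAGE_TOKEN_INDEX` matches LLaVA's sentinel value:

```python
# Toy re-implementation of the fallback's '<image>' splicing.
IMAGE_TOKEN_INDEX = -200

class FakeTokenizer:
    bos_token_id = 1
    def __call__(self, text):
        out = type("Out", (), {})()
        # pretend each word tokenizes to its length; prepend BOS like LLaMA
        out.input_ids = [self.bos_token_id] + [len(w) for w in text.split()]
        return out

def splice(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX):
    chunks = [tokenizer(c).input_ids for c in prompt.split('<image>')]
    def insert_sep(X, sep):
        return [e for sub in zip(X, [sep]*len(X)) for e in sub][:-1]
    ids = []; offset = 0
    if len(chunks) > 0 and len(chunks[0]) > 0 and chunks[0][0] == tokenizer.bos_token_id:
        offset = 1; ids.append(chunks[0][0])
    for x in insert_sep(chunks, [image_token_index]*(offset+1)):
        ids.extend(x[offset:])
    return ids

print(splice("USER: <image> describe this ECG", FakeTokenizer()))
# -> [1, 5, -200, 8, 4, 3]: BOS survives once, '<image>' becomes one -200
```

Padding the separator to `offset+1` copies is what leaves exactly one sentinel after the per-chunk `x[offset:]` slice strips each chunk's duplicated BOS.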
@@ -147,7 +144,7 @@ except Exception:
     def load_image_from_base64(image):
         return Image.open(io.BytesIO(base64.b64decode(image)))
 
-# ---- LLaVA (PULSE fork) ----
+# ---- LLaVA components ----
 from llava.model.builder import load_pretrained_model
 from llava.constants import (
     IMAGE_TOKEN_INDEX,
@@ -157,7 +154,10 @@ from llava.constants import (
 )
 from llava.conversation import conv_templates
 from llava.utils import disable_torch_init
+
 from transformers import AutoProcessor, AutoImageProcessor, CLIPImageProcessor
+# IMPORTANT: we call HF GenerationMixin directly (bypassing LLaVA's override)
+from transformers.generation.utils import GenerationMixin as HFGenerationMixin
 
 DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
 MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
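Calling the mixin's method unbound, with the model passed as `self`, is plain Python method dispatch; a toy sketch of the pattern (the classes here are stand-ins, not the LLaVA types):

```python
# Stand-in classes showing why Base.generate(instance, ...) skips an override.
class Base:
    def generate(self, **kw):
        return f"base generate ran with {sorted(kw)}"

class LlavaLike(Base):
    def generate(self, **kw):  # analogous to LLaVA's stricter override
        raise TypeError("override does not accept these kwargs")

m = LlavaLike()
print(Base.generate(m, inputs=[1, 2, 3]))  # base implementation, override bypassed
```

This mirrors `HFGenerationMixin.generate(self.model, **gen_kwargs)` later in the diff: the HF implementation runs, while the model instance is still used for `forward`.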
@@ -175,12 +175,14 @@ class EndpointHandler:
 
         self.model_name = get_model_name_from_path(model_path)
 
+        # attention impl
         try:
             import flash_attn  # noqa
             attn_impl = "flash_attention_2"
         except Exception:
             attn_impl = "sdpa"
 
+        # Load the LLaVA/PULSE model
         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
             model_path=model_path,
             model_base=None,
@@ -204,7 +206,7 @@
         except Exception:
             pass
 
-        # forward patch: drop the unneeded args
+        # forward patch: silently drop unknown kwargs
         def _patch_forward(obj, label="model"):
             try:
                 if not hasattr(obj, "forward"): return False
@@ -234,7 +236,7 @@
         except Exception as e:
             print(f"[warn] AutoProcessor failed: {e}")
             vt_id = self._resolve_vision_tower_id(self.model.config)
-            print(f"[hotfix] trying to load image_processor from vision_tower: {vt_id}")
+            print(f"[hotfix] trying vision_tower: {vt_id}")
             try:
                 self.image_processor = AutoImageProcessor.from_pretrained(vt_id, trust_remote_code=True)
                 print("[info] image_processor loaded via AutoImageProcessor(vision_tower)")
@@ -263,14 +265,14 @@
         self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
         self.is_multimodal = ('llava' in self.model_name.lower()) or ('pulse' in self.model_name.lower())
 
-    # ------------- helpers -------------
+    # ---------- helpers ----------
     def _resolve_vision_tower_id(self, config: Any) -> str:
         for key in ("mm_vision_tower", "vision_tower", "mm_vision_tower_name", "image_tower", "visual_encoder"):
             v = getattr(config, key, None)
             if isinstance(v, str) and v.strip(): return v.strip()
         try:
-            v = getattr(config, "vision_tower", None)
-            name = getattr(getattr(v, "config", None), "_name_or_path", None)
+            vt = getattr(config, "vision_tower", None)
+            name = getattr(getattr(vt, "config", None), "_name_or_path", None)
             if isinstance(name, str) and name.strip(): return name.strip()
         except Exception:
             pass
@@ -318,6 +320,7 @@
             return True
 
         try:
+            # URL
             if isinstance(image_input, str) and image_input.startswith(("http://", "https://")):
                 if not _is_valid_image_format(image_input):
                     print("[warn] Invalid image extension in URL"); return None
@@ -333,6 +336,7 @@
                 img = Image.open(io.BytesIO(data)).convert("RGB")
                 print(f"[info] URL image loaded: size={img.size}"); return img
 
+            # Base64 (including data URLs)
             if isinstance(image_input, str):
                 b64 = image_input.strip()
                 if b64.startswith("data:image"):
@@ -346,6 +350,7 @@
                 img = Image.open(io.BytesIO(data)).convert("RGB")
                 print(f"[info] Base64 image loaded: size={img.size}"); return img
 
+            # Local path
             if isinstance(image_input, str) and os.path.exists(image_input):
                 img = Image.open(image_input).convert("RGB")
                 print(f"[info] Local image loaded: size={img.size}"); return img
@@ -362,7 +367,13 @@
         conv.append_message(conv.roles[1], None)
         return conv.get_prompt()
 
-    # ------------- inference -------------
+    def _create_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
+        attn = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
+        if self.tokenizer.pad_token_id is not None:
+            attn = attn.masked_fill(input_ids == self.tokenizer.pad_token_id, 0)
+        return attn
+
+    # ---------- inference ----------
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         inputs = data.get("inputs") or {}
         params = data.get("parameters") or {}
@@ -382,7 +393,7 @@
         try:
             pil_image = self._load_image(image_f)
             if pil_image is not None and self.image_processor is not None:
-                processed_images = process_images([pil_image], self.image_processor, self.model.config)
+                processed = process_images([pil_image], self.image_processor, self.model.config)
                 # model device/dtype
                 try:
                     mdev = next(self.model.parameters()).device
@@ -390,12 +401,13 @@
                 except Exception:
                     mdev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                 mdtype = torch.float16 if mdev.type == "cuda" else torch.float32
-                if isinstance(processed_images, list):
-                    images = [img.to(mdev, dtype=mdtype) for img in processed_images]
+                if isinstance(processed, list):
+                    images = [img.to(mdev, dtype=mdtype) for img in processed]
                 else:
-                    images = processed_images.to(mdev, dtype=mdtype)
+                    images = processed.to(mdev, dtype=mdtype)
                 image_sizes = [pil_image.size]
-                # image tokens
+
+                # image token(s)
                 prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
                 rep = DEFAULT_IMAGE_TOKEN
                 if self.use_im_start_end:
@@ -412,20 +424,19 @@
         # 3) tokenize
         try:
             mdev = next(self.model.parameters()).device
-            input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') \
-                .unsqueeze(0).to(mdev)
+            input_ids = tokenizer_image_token(
+                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
+            ).unsqueeze(0).to(mdev)
             print(f"[debug] input_ids shape: {input_ids.shape} | has images: {images is not None}")
         except Exception as e:
             print(f"[error] Tokenization failed: {e}")
-            try:
-                input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids.to(next(self.model.parameters()).device)
-                images = None; image_sizes = None
-                print("[warn] Fallback to basic tokenization without image tokens")
-            except Exception as e2:
-                print(f"[error] Even basic tokenization failed: {e2}")
-                return [{"generated_text": f"Error: Tokenization failed: {str(e)}"}]
+            input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids.to(next(self.model.parameters()).device)
+            images = None; image_sizes = None
 
-        # 4) gen params (NO attention_mask)
+        # 4) attention mask
+        attention_mask = self._create_attention_mask(input_ids)
+
+        # 5) generation params
         temperature = float(params.get("temperature", 0.0))
         top_p = float(params.get("top_p", 1.0))
         repetition_penalty = float(params.get("repetition_penalty", 1.0))
@@ -437,50 +448,56 @@
         if max_new_tokens < 1:
             return [{"generated_text": "Error: Input too long, exceeds max token length."}]
 
-        gen_kwargs = {
-            # CRITICAL: pass both `inputs` and `input_ids`
-            "inputs": input_ids,
-            "input_ids": input_ids,
+        # 6) generate via HF GenerationMixin (BYPASSES LLaVA's generate override)
+        common_params = {
             "max_new_tokens": max_new_tokens,
             "temperature": temperature,
             "top_p": top_p,
             "repetition_penalty": repetition_penalty,
             "do_sample": do_sample,
-            # do not pass attention_mask!
             "use_cache": bool(params.get("use_cache", True)),
             "pad_token_id": self.tokenizer.pad_token_id,
             "eos_token_id": getattr(self.tokenizer, "eos_token_id", None),
             "bos_token_id": getattr(self.tokenizer, "bos_token_id", None),
         }
+
+        gen_kwargs = {
+            "inputs": input_ids,               # NOTE: 'inputs'
+            "attention_mask": attention_mask,  # the mask goes here
+            **common_params
+        }
         if images is not None and image_sizes is not None:
             gen_kwargs["images"] = images
             gen_kwargs["image_sizes"] = image_sizes
 
-        # 5) generate
         try:
             with torch.inference_mode():
-                output = self.model.generate(**gen_kwargs)
+                output = HFGenerationMixin.generate(self.model, **gen_kwargs)
         except Exception as e:
-            # last resort: retry with the cache disabled
-            print(f"[warn] First generate failed: {e} | retry with use_cache=False")
-            gen_kwargs["use_cache"] = False
+            # last resort: minimal call, no mask
+            print(f"[warn] HF mixin generate failed: {e} | retry minimal no-mask")
             try:
+                minimal = {
+                    "max_new_tokens": max_new_tokens,
+                    "do_sample": False,
+                    "temperature": 0.0,
+                    "use_cache": False,
+                    "pad_token_id": self.tokenizer.pad_token_id,
+                }
                 with torch.inference_mode():
-                    output = self.model.generate(**gen_kwargs)
+                    output = HFGenerationMixin.generate(self.model, inputs=input_ids, **minimal)
             except Exception as e2:
-                print(f"[error] Generation failed: {e2}")
-                import traceback; traceback.print_exc()
                 return [{"generated_text": f"Error during generation: {str(e2)}"}]
 
-        # 6) decode
+        # 7) decode
         try:
             sequences = output.sequences if hasattr(output, "sequences") else output
-            input_len = input_ids.shape[1]
-            response_ids = sequences[:, input_len:] if sequences.shape[-1] > input_len else sequences
-            text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
+            in_len = input_ids.shape[1]
+            resp_ids = sequences[:, in_len:] if sequences.shape[-1] > in_len else sequences
+            text = self.tokenizer.batch_decode(resp_ids, skip_special_tokens=True)[0].strip()
             if not text:
-                text = "Error: Empty response generated"
+                text = "Error: Empty response"
             return [{"generated_text": text}]
         except Exception as e:
-            print(f"[error] Response decoding failed: {e}")
-            return [{"generated_text": f"Error: Response decoding failed: {str(e)}"}]
+            print(f"[error] Decoding failed: {e}")
+            return [{"generated_text": f"Error during decoding: {str(e)}"}]
 
1
  # -*- coding: utf-8 -*-
2
+ # handler.py — PULSE-7B / LLaVA robust endpoint (final fix)
3
+ # - Kaynak: AIMedLab/PULSE (dev) LLaVA fork
4
+ # - Güvenli image load + processor normalize
5
+ # - DOLU attention_mask oluşturma
6
+ # - Üretimi HF GenerationMixin ile çağır (LLaVA generate override'ını bypass)
7
+ # - forward() patch: cache_position/input_positions düşür
 
8
 
9
  import os, io, sys, subprocess, base64
10
  from typing import Any, Dict, List, Optional, Tuple
 
17
  import inspect
18
  from urllib.parse import urlparse
19
 
20
+ # ===== Model / Config =====
21
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
22
  DEFAULT_VISION_TOWER_ID = os.getenv("HF_VISION_TOWER_ID", "openai/clip-vit-large-patch14-336")
23
 
24
+ # Flash Attention
25
  os.environ.setdefault("FLASH_ATTENTION", "1")
26
  os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
27
 
28
+ # ===== LLaVA (AIMedLab/PULSE dev) kaynak kodunu getir =====
29
  LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/AIMedLab/PULSE.git")
30
  LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "dev")
31
  LLAVA_SRC_DIR = os.getenv("LLAVA_SRC_DIR", "/tmp/llava_src/PULSE/LLaVA")
 
47
  try:
48
  from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path, load_image_from_base64
49
  except Exception:
50
+ # Minimal fallback'lar
51
  from llava.constants import IMAGE_TOKEN_INDEX
52
 
53
  def expand2square(pil_img: Image.Image, background_color: Tuple[int,int,int]) -> Image.Image:
 
129
  chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
130
  def insert_sep(X, sep):
131
  return [e for sub in zip(X, [sep]*len(X)) for e in sub][:-1]
132
+ ids = []; offset = 0
 
133
  if len(chunks) > 0 and len(chunks[0]) > 0 and chunks[0][0] == tokenizer.bos_token_id:
134
+ offset = 1; ids.append(chunks[0][0])
135
+ for x in insert_sep(chunks, [image_token_index]*(offset+1)):
 
136
  ids.extend(x[offset:])
137
+ if return_tensors == 'pt': return torch.tensor(ids, dtype=torch.long)
 
138
  return ids
139
 
140
  def get_model_name_from_path(model_path):
 
144
  def load_image_from_base64(image):
145
  return Image.open(io.BytesIO(base64.b64decode(image)))
146
 
147
+ # ---- LLaVA parçaları ----
148
  from llava.model.builder import load_pretrained_model
149
  from llava.constants import (
150
  IMAGE_TOKEN_INDEX,
 
154
  )
155
  from llava.conversation import conv_templates
156
  from llava.utils import disable_torch_init
157
+
158
  from transformers import AutoProcessor, AutoImageProcessor, CLIPImageProcessor
159
+ # ÖNEMLİ: HF GenerationMixin'i doğrudan çağıracağız (LLaVA override'ını bypass)
160
+ from transformers.generation.utils import GenerationMixin as HFGenerationMixin
161
 
162
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
163
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
 
175
 
176
  self.model_name = get_model_name_from_path(model_path)
177
 
178
+ # attention impl
179
  try:
180
  import flash_attn # noqa
181
  attn_impl = "flash_attention_2"
182
  except Exception:
183
  attn_impl = "sdpa"
184
 
185
+ # LLaVA/PULSE modeli yükle
186
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
187
  model_path=model_path,
188
  model_base=None,
 
206
  except Exception:
207
  pass
208
 
209
+ # forward patch: bilinmeyen kwargs'ları sessiz düşür
210
  def _patch_forward(obj, label="model"):
211
  try:
212
  if not hasattr(obj, "forward"): return False
 
236
  except Exception as e:
237
  print(f"[warn] AutoProcessor başarısız: {e}")
238
  vt_id = self._resolve_vision_tower_id(self.model.config)
239
+ print(f"[hotfix] trying vision_tower: {vt_id}")
240
  try:
241
  self.image_processor = AutoImageProcessor.from_pretrained(vt_id, trust_remote_code=True)
242
  print("[info] image_processor loaded via AutoImageProcessor(vision_tower)")
 
265
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
266
  self.is_multimodal = ('llava' in self.model_name.lower()) or ('pulse' in self.model_name.lower())
267
 
268
+ # ---------- helpers ----------
269
  def _resolve_vision_tower_id(self, config: Any) -> str:
270
  for key in ("mm_vision_tower", "vision_tower", "mm_vision_tower_name", "image_tower", "visual_encoder"):
271
  v = getattr(config, key, None)
272
  if isinstance(v, str) and v.strip(): return v.strip()
273
  try:
274
+ vt = getattr(config, "vision_tower", None)
275
+ name = getattr(getattr(vt, "config", None), "_name_or_path", None)
276
  if isinstance(name, str) and name.strip(): return name.strip()
277
  except Exception:
278
  pass
 
320
  return True
321
 
322
  try:
323
+ # URL
324
  if isinstance(image_input, str) and image_input.startswith(("http://", "https://")):
325
  if not _is_valid_image_format(image_input):
326
  print("[warn] Invalid image extension in URL"); return None
 
336
  img = Image.open(io.BytesIO(data)).convert("RGB")
337
  print(f"[info] URL image loaded: size={img.size}"); return img
338
 
339
+ # Base64 (data URL dahil)
340
  if isinstance(image_input, str):
341
  b64 = image_input.strip()
342
  if b64.startswith("data:image"):
 
350
  img = Image.open(io.BytesIO(data)).convert("RGB")
351
  print(f"[info] Base64 image loaded: size={img.size}"); return img
352
 
353
+ # Yerel path
354
  if isinstance(image_input, str) and os.path.exists(image_input):
355
  img = Image.open(image_input).convert("RGB")
356
  print(f"[info] Local image loaded: size={img.size}"); return img
 
367
  conv.append_message(conv.roles[1], None)
368
  return conv.get_prompt()
369
 
370
+ def _create_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
371
+ attn = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
372
+ if self.tokenizer.pad_token_id is not None:
373
+ attn = attn.masked_fill(input_ids == self.tokenizer.pad_token_id, 0)
374
+ return attn
375
+
376
+ # ---------- inference ----------
377
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
378
  inputs = data.get("inputs") or {}
379
  params = data.get("parameters") or {}
 
393
  try:
394
  pil_image = self._load_image(image_f)
395
  if pil_image is not None and self.image_processor is not None:
396
+ processed = process_images([pil_image], self.image_processor, self.model.config)
397
  # model device/dtype
398
  try:
399
  mdev = next(self.model.parameters()).device
 
401
  except Exception:
402
  mdev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
403
  mdtype = torch.float16 if mdev.type == "cuda" else torch.float32
404
+ if isinstance(processed, list):
405
+ images = [img.to(mdev, dtype=mdtype) for img in processed]
406
  else:
407
+ images = processed.to(mdev, dtype=mdtype)
408
  image_sizes = [pil_image.size]
409
+
410
+ # image token(ları)
411
  prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
412
  rep = DEFAULT_IMAGE_TOKEN
413
  if self.use_im_start_end:
 
424
  # 3) tokenize
425
  try:
426
  mdev = next(self.model.parameters()).device
427
+ input_ids = tokenizer_image_token(
428
+ prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
429
+ ).unsqueeze(0).to(mdev)
430
  print(f"[debug] input_ids shape: {input_ids.shape} | has images: {images is not None}")
431
  except Exception as e:
432
  print(f"[error] Tokenization failed: {e}")
433
+ input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids.to(next(self.model.parameters()).device)
434
+ images = None; image_sizes = None
 
 
 
 
 
435
 
436
+ # 4) attention mask
437
+ attention_mask = self._create_attention_mask(input_ids)
438
+
439
+ # 5) generation params
440
  temperature = float(params.get("temperature", 0.0))
441
  top_p = float(params.get("top_p", 1.0))
442
  repetition_penalty = float(params.get("repetition_penalty", 1.0))
 
448
  if max_new_tokens < 1:
449
  return [{"generated_text": "Error: Input too long, exceeds max token length."}]
450
 
451
+ # 6) HF GenerationMixin ile üret (LLaVA generate override BYPASS)
452
+ common_params = {
 
 
453
  "max_new_tokens": max_new_tokens,
454
  "temperature": temperature,
455
  "top_p": top_p,
456
  "repetition_penalty": repetition_penalty,
457
  "do_sample": do_sample,
 
458
  "use_cache": bool(params.get("use_cache", True)),
459
  "pad_token_id": self.tokenizer.pad_token_id,
460
  "eos_token_id": getattr(self.tokenizer, "eos_token_id", None),
461
  "bos_token_id": getattr(self.tokenizer, "bos_token_id", None),
462
  }
463
+
464
+ gen_kwargs = {
465
+ "inputs": input_ids, # DİKKAT: 'inputs'
466
+ "attention_mask": attention_mask, # Maske burada
467
+ **common_params
468
+ }
469
  if images is not None and image_sizes is not None:
470
  gen_kwargs["images"] = images
471
  gen_kwargs["image_sizes"] = image_sizes
472
 
 
473
  try:
474
  with torch.inference_mode():
475
+ output = HFGenerationMixin.generate(self.model, **gen_kwargs)
476
  except Exception as e:
477
+ # son çare: masksiz minimal
478
+ print(f"[warn] HF mixin generate failed: {e} | retry minimal no-mask")
 
479
  try:
480
+ minimal = {
481
+ "max_new_tokens": max_new_tokens,
482
+ "do_sample": False,
483
+ "temperature": 0.0,
484
+ "use_cache": False,
485
+ "pad_token_id": self.tokenizer.pad_token_id,
486
+ }
487
  with torch.inference_mode():
488
+ output = HFGenerationMixin.generate(self.model, inputs=input_ids, **minimal)
489
  except Exception as e2:
 
 
490
  return [{"generated_text": f"Error during generation: {str(e2)}"}]
491
 
492
+ # 7) decode
493
  try:
494
  sequences = output.sequences if hasattr(output, "sequences") else output
495
+ in_len = input_ids.shape[1]
496
+ resp_ids = sequences[:, in_len:] if sequences.shape[-1] > in_len else sequences
497
+ text = self.tokenizer.batch_decode(resp_ids, skip_special_tokens=True)[0].strip()
498
  if not text:
499
+ text = "Error: Empty response"
500
  return [{"generated_text": text}]
501
  except Exception as e:
502
+ print(f"[error] Decoding failed: {e}")
503
+ return [{"generated_text": f"Error during decoding: {str(e)}"}]