CanerDedeoglu committed on
Commit ae66524 · verified · 1 Parent(s): 4d6ae93

Update handler.py

Files changed (1): handler.py +47 -204
handler.py CHANGED
@@ -1,11 +1,11 @@
 # -*- coding: utf-8 -*-
-# handler.py — PULSE-7B / LLaVA robust endpoint (fixed: 'inputs' & NoneType.new_ones)
-# - Uses PULSE fork of LLaVA (AIMedLab/PULSE:dev)
-# - Safe image loading + processor normalization
-# - Attention-mask creation + "mask injection" fallback
-# - Fix for duplicate 'inputs' kwarg during fallback
-# - Small forward() patch to drop unknown kwargs
-# - FIXED: NoneType.new_ones error in mask injection
+# handler.py — PULSE-7B / LLaVA robust endpoint (minimal & stable)
+# - Loads LLaVA via the PULSE fork (AIMedLab/PULSE:dev)
+# - Safe image loader + processor normalization
+# - ANYRES->PAD fallback
+# - Forward patch: silently drop cache_position/input_positions
+# - CRITICAL FIX: pass both `inputs` and `input_ids` to generate (ends the NoneType.new_ones error)
+# - Do not pass attention_mask (LLaVA handles it internally)
 
 import os, io, sys, subprocess, base64
 from typing import Any, Dict, List, Optional, Tuple
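Context for the "CRITICAL FIX" bullet: the NoneType.new_ones crash comes from HF generation internals. When no attention_mask is supplied, transformers builds a default one by calling new_ones on the prompt tensor; if the fork's custom generate path ends up holding None instead of the prompt ids, that call explodes. A minimal sketch of the failure mode (the helper name is illustrative, not the library's actual function):

```python
import torch

def default_mask(input_ids):
    # roughly what transformers does when no attention_mask is given
    return input_ids.new_ones(input_ids.shape, dtype=torch.long)

ids = torch.tensor([[1, 2, 3]])
print(default_mask(ids))   # tensor([[1, 1, 1]])
# default_mask(None)       # AttributeError: 'NoneType' object has no attribute 'new_ones'
```

Per the commit's own rationale, supplying the prompt under both `inputs` (the name GenerationMixin.generate accepts) and `input_ids` (the name the fork reads from kwargs) keeps either lookup from coming back None.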
@@ -134,7 +134,7 @@ except Exception:
     if len(chunks) > 0 and len(chunks[0]) > 0 and chunks[0][0] == tokenizer.bos_token_id:
         offset = 1
         ids.append(chunks[0][0])
-    for x in insert_sep(chunks, [image_token_index]*(offset+1)):
+    for x in insert_sep(chunks, [IMAGE_TOKEN_INDEX]*(offset+1)):
         ids.extend(x[offset:])
     if return_tensors == 'pt':
         return torch.tensor(ids, dtype=torch.long)
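This one-line change is a NameError fix: the lowercase `image_token_index` is never bound in this fallback implementation of tokenizer_image_token, so any prompt containing an image placeholder would crash; the module constant IMAGE_TOKEN_INDEX (-200 in LLaVA) is what was intended. A toy sketch of the interleaving idea, with a plausible insert_sep implementation standing in for the real helper, which sits outside this diff:

```python
IMAGE_TOKEN_INDEX = -200  # LLaVA's placeholder id for image patches

def insert_sep(chunks, sep):
    # [[a], [b]] with separator s -> [[a], s, [b]]
    out = []
    for i, ch in enumerate(chunks):
        if i:
            out.append(sep)
        out.append(ch)
    return out

chunks = [[1, 5, 6], [7, 8]]  # token ids around one "<image>" split
ids = []
for x in insert_sep(chunks, [IMAGE_TOKEN_INDEX]):
    ids.extend(x)
print(ids)  # [1, 5, 6, -200, 7, 8]
```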
@@ -157,7 +157,6 @@ from llava.constants import (
 )
 from llava.conversation import conv_templates
 from llava.utils import disable_torch_init
-
 from transformers import AutoProcessor, AutoImageProcessor, CLIPImageProcessor
 
 DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v1")
@@ -174,7 +173,7 @@ class EndpointHandler:
         else:
             model_path = MODEL_ID
 
-        self.model_name = self._get_model_name_from_path(model_path)
+        self.model_name = get_model_name_from_path(model_path)
 
         try:
             import flash_attn  # noqa
@@ -234,7 +233,7 @@
             print("[info] image_processor loaded via AutoProcessor(model_path)")
         except Exception as e:
             print(f"[warn] AutoProcessor failed: {e}")
-            vt_id = self._resolve_vision_tower_id(self.model.config, model_path)
+            vt_id = self._resolve_vision_tower_id(self.model.config)
             print(f"[hotfix] trying to load image_processor from vision_tower: {vt_id}")
             try:
                 self.image_processor = AutoImageProcessor.from_pretrained(vt_id, trust_remote_code=True)
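Dropping the unused model_path argument here is a pure signature cleanup; the fallback behavior is unchanged. When AutoProcessor cannot be built from the model repo, the handler loads the image processor directly from the resolved vision-tower checkpoint, roughly like this (the repo id below is illustrative, a typical LLaVA tower):

```python
from transformers import CLIPImageProcessor

vt_id = "openai/clip-vit-large-patch14-336"  # common LLaVA vision tower
image_processor = CLIPImageProcessor.from_pretrained(vt_id)
print(type(image_processor).__name__)  # CLIPImageProcessor
```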
@@ -264,131 +263,17 @@
         self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
         self.is_multimodal = ('llava' in self.model_name.lower()) or ('pulse' in self.model_name.lower())
 
-    # ---- FIXED: mask injection helper ----
-    def _generate_with_injected_mask(self, input_ids, images, image_sizes, attention_mask, base_kwargs):
-        """
-        Inject attention_mask inside prepare_inputs_for_generation so HF generate uses it,
-        while avoiding duplicate kwargs like 'inputs' or 'attention_mask'.
-        FIXED: Better handling of None values and tensor validation.
-        """
-        orig_prepare = getattr(self.model, "prepare_inputs_for_generation", None)
-        if orig_prepare is None:
-            print("[error] Model has no prepare_inputs_for_generation method")
-            raise RuntimeError("Model doesn't support mask injection fallback")
-
-        def patched_prepare(input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
-            try:
-                # Call original prepare method
-                model_inputs = orig_prepare(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
-
-                # Validate model_inputs is not None and is a dict
-                if model_inputs is None:
-                    print("[error] prepare_inputs_for_generation returned None")
-                    model_inputs = {}
-                elif not isinstance(model_inputs, dict):
-                    print(f"[error] prepare_inputs_for_generation returned non-dict: {type(model_inputs)}")
-                    model_inputs = {}
-
-                # Only inject attention_mask if it's not already present and we have a valid mask
-                if model_inputs.get("attention_mask", None) is None and attention_mask is not None:
-                    # Validate attention_mask
-                    if isinstance(attention_mask, torch.Tensor) and attention_mask.numel() > 0:
-                        model_inputs["attention_mask"] = attention_mask
-                        print(f"[debug] Injected attention_mask with shape: {attention_mask.shape}")
-                    else:
-                        print("[warn] Invalid attention_mask, skipping injection")
-
-                # Ensure input_ids is present
-                if "input_ids" not in model_inputs and input_ids is not None:
-                    model_inputs["input_ids"] = input_ids
-
-                return model_inputs
-
-            except Exception as e:
-                print(f"[error] Error in patched_prepare: {e}")
-                # Return minimal valid dict to avoid None errors
-                return {"input_ids": input_ids}
-
-        # Apply the patch
-        self.model.prepare_inputs_for_generation = patched_prepare
-
-        try:
-            # IMPORTANT: Remove 'attention_mask' and 'inputs' from kwargs to avoid conflicts
-            patched_kwargs = {k: v for k, v in base_kwargs.items() if k not in ("attention_mask", "inputs")}
-
-            # Add images and image_sizes if they exist
-            if images is not None:
-                patched_kwargs["images"] = images
-            if image_sizes is not None:
-                patched_kwargs["image_sizes"] = image_sizes
-
-            # Validate input_ids before generation
-            if input_ids is None or not isinstance(input_ids, torch.Tensor) or input_ids.numel() == 0:
-                raise ValueError("Invalid input_ids for generation")
-
-            print(f"[debug] Starting generation with input_ids shape: {input_ids.shape}")
-            with torch.inference_mode():
-                output = self.model.generate(inputs=input_ids, **patched_kwargs)
-
-            return output
-
-        except Exception as e:
-            print(f"[error] Generation failed in mask injection: {e}")
-            raise e
-        finally:
-            # Always restore original method
-            self.model.prepare_inputs_for_generation = orig_prepare
-
-    # ---- ADDED: Simplified fallback without mask injection ----
-    def _generate_without_mask(self, input_ids, images, image_sizes, base_kwargs):
-        """
-        Fallback generation without attention_mask for models that don't support it well.
-        """
-        try:
-            # Remove problematic arguments
-            clean_kwargs = {k: v for k, v in base_kwargs.items()
-                            if k not in ("attention_mask", "inputs")}
-
-            # Add multimodal inputs if present
-            if images is not None:
-                clean_kwargs["images"] = images
-            if image_sizes is not None:
-                clean_kwargs["image_sizes"] = image_sizes
-
-            # Force use_cache=False for stability
-            clean_kwargs["use_cache"] = False
-
-            # Ensure we have basic required parameters
-            clean_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
-            clean_kwargs.setdefault("eos_token_id", self.tokenizer.eos_token_id)
-
-            print(f"[debug] Fallback generation without mask, kwargs: {list(clean_kwargs.keys())}")
-
-            with torch.inference_mode():
-                output = self.model.generate(inputs=input_ids, **clean_kwargs)
-
-            return output
-
-        except Exception as e:
-            print(f"[error] Fallback generation failed: {e}")
-            raise e
-
     # ------------- helpers -------------
-    def _get_model_name_from_path(self, model_path: str) -> str:
-        p = model_path.strip("/").split("/")
-        return (p[-2] + "_" + p[-1]) if p[-1].startswith("checkpoint-") else p[-1]
-
-    def _resolve_vision_tower_id(self, config: Any, model_path: str) -> str:
+    def _resolve_vision_tower_id(self, config: Any) -> str:
         for key in ("mm_vision_tower", "vision_tower", "mm_vision_tower_name", "image_tower", "visual_encoder"):
             v = getattr(config, key, None)
             if isinstance(v, str) and v.strip(): return v.strip()
-        for key in ("mm_vision_tower", "vision_tower"):
-            v = getattr(config, key, None)
-            try:
-                name = getattr(getattr(v, "config", None), "_name_or_path", None)
-                if isinstance(name, str) and name.strip(): return name.strip()
-            except Exception:
-                pass
+        try:
+            v = getattr(config, "vision_tower", None)
+            name = getattr(getattr(v, "config", None), "_name_or_path", None)
+            if isinstance(name, str) and name.strip(): return name.strip()
+        except Exception:
+            pass
         return DEFAULT_VISION_TOWER_ID
 
     def _normalize_image_processor(self) -> bool:
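The rewritten _resolve_vision_tower_id keeps the same probe order but drops the redundant second loop: scan a fixed list of config attribute names for a usable string, fall back to the nested vision_tower.config._name_or_path, and finally return DEFAULT_VISION_TOWER_ID. A standalone sketch of the probing, with SimpleNamespace standing in for the model config and an illustrative default:

```python
from types import SimpleNamespace

DEFAULT_VISION_TOWER_ID = "openai/clip-vit-large-patch14-336"  # illustrative

def resolve(config):
    for key in ("mm_vision_tower", "vision_tower", "mm_vision_tower_name"):
        v = getattr(config, key, None)
        if isinstance(v, str) and v.strip():
            return v.strip()
    return DEFAULT_VISION_TOWER_ID

print(resolve(SimpleNamespace(mm_vision_tower="org/custom-tower")))  # org/custom-tower
print(resolve(SimpleNamespace()))  # falls back to the default
```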
@@ -477,26 +362,7 @@
         conv.append_message(conv.roles[1], None)
         return conv.get_prompt()
 
-    def _create_robust_attention_mask(self, input_ids: torch.Tensor) -> Optional[torch.Tensor]:
-        try:
-            if input_ids is None or not isinstance(input_ids, torch.Tensor):
-                print("[warn] Invalid input_ids for attention mask creation")
-                return None
-
-            device = input_ids.device
-            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
-
-            if self.tokenizer.pad_token_id is not None:
-                attention_mask = attention_mask.masked_fill(input_ids == self.tokenizer.pad_token_id, 0)
-
-            print(f"[debug] Created attention_mask: shape={attention_mask.shape}, device={device}")
-            return attention_mask
-
-        except Exception as e:
-            print(f"[error] Failed to create attention_mask: {e}")
-            return None
-
-    # ------------- IMPROVED: inference with better error handling -------------
+    # ------------- inference -------------
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         inputs = data.get("inputs") or {}
         params = data.get("parameters") or {}
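__call__ follows the standard Hugging Face Inference Endpoints handler contract: the request body arrives as a dict with "inputs" and "parameters" keys. An illustrative payload (the field names inside "inputs", such as a text or image key, are hypothetical here and handled by loader code outside this diff):

```python
data = {
    "inputs": {
        "text": "Describe this ECG.",          # hypothetical field name
        # "image": "<base64-encoded bytes>",   # hypothetical field name
    },
    "parameters": {"temperature": 0.2, "max_new_tokens": 256},
}
# handler = EndpointHandler()
# print(handler(data)[0]["generated_text"])
```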
@@ -543,7 +409,7 @@
             import traceback; traceback.print_exc()
             images = None; image_sizes = None
 
-        # 3) tokenize + robust mask
+        # 3) tokenize
         try:
             mdev = next(self.model.parameters()).device
             input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') \
@@ -551,12 +417,15 @@
             print(f"[debug] input_ids shape: {input_ids.shape} | has images: {images is not None}")
         except Exception as e:
             print(f"[error] Tokenization failed: {e}")
-            input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids.to(next(self.model.parameters()).device)
-            images = None; image_sizes = None
-
-        attention_mask = self._create_robust_attention_mask(input_ids)
+            try:
+                input_ids = self.tokenizer(query_text, return_tensors="pt").input_ids.to(next(self.model.parameters()).device)
+                images = None; image_sizes = None
+                print("[warn] Fallback to basic tokenization without image tokens")
+            except Exception as e2:
+                print(f"[error] Even basic tokenization failed: {e2}")
+                return [{"generated_text": f"Error: Tokenization failed: {str(e)}"}]
 
-        # 4) gen params
+        # 4) gen params (no attention_mask)
         temperature = float(params.get("temperature", 0.0))
         top_p = float(params.get("top_p", 1.0))
         repetition_penalty = float(params.get("repetition_penalty", 1.0))
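These request parameters map one-to-one onto HF generate() kwargs; the do_sample flag is computed in unchanged code just below this hunk, presumably from the temperature. A sketch of the mapping under that assumption:

```python
params = {"temperature": 0.0, "top_p": 1.0}
temperature = float(params.get("temperature", 0.0))
do_sample = temperature > 0.0   # assumption: greedy when temperature == 0
print({"do_sample": do_sample, "temperature": temperature,
       "top_p": float(params.get("top_p", 1.0))})
# {'do_sample': False, 'temperature': 0.0, 'top_p': 1.0}
```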
@@ -568,76 +437,50 @@
         if max_new_tokens < 1:
             return [{"generated_text": "Error: Input too long, exceeds max token length."}]
 
-        # --- Strategy 1: Normal path with attention_mask ---
         gen_kwargs = {
+            # CRITICAL: pass both `inputs` and `input_ids`
             "inputs": input_ids,
-            "attention_mask": attention_mask,
-            "use_cache": bool(params.get("use_cache", True)),
+            "input_ids": input_ids,
             "max_new_tokens": max_new_tokens,
             "temperature": temperature,
             "top_p": top_p,
             "repetition_penalty": repetition_penalty,
             "do_sample": do_sample,
+            # do not pass attention_mask!
+            "use_cache": bool(params.get("use_cache", True)),
             "pad_token_id": self.tokenizer.pad_token_id,
             "eos_token_id": getattr(self.tokenizer, "eos_token_id", None),
             "bos_token_id": getattr(self.tokenizer, "bos_token_id", None),
         }
-        if images is not None:
+        if images is not None and image_sizes is not None:
            gen_kwargs["images"] = images
            gen_kwargs["image_sizes"] = image_sizes
 
+        # 5) generate
         try:
-            print("[debug] Trying generation: normal path (with attention_mask)")
             with torch.inference_mode():
                 output = self.model.generate(**gen_kwargs)
-        except ValueError as e:
-            msg = str(e)
-            if "model_kwargs" in msg and "attention_mask" in msg and "not used" in msg:
-                print("[hotfix] model rejected attention_mask; retrying via mask injection (no kwargs mask) + use_cache=False")
-                gen_kwargs_no_mask = {k: v for k, v in gen_kwargs.items() if k not in ("attention_mask", "inputs")}
-                gen_kwargs_no_mask["use_cache"] = False
-                output = self._generate_with_injected_mask(
-                    input_ids=input_ids,
-                    images=images,
-                    image_sizes=image_sizes,
-                    attention_mask=attention_mask,
-                    base_kwargs=gen_kwargs_no_mask
-                )
-            else:
-                print(f"Generation error: {e}")
-                import traceback; traceback.print_exc()
-                return [{"generated_text": f"Error during generation: {msg}"}]
         except Exception as e:
-            emsg = str(e)
-            print(f"[warn] Normal path failed: {emsg}")
-            print("[hotfix] retry via NO-MASK fallback")
+            # last resort: retry with the cache disabled
+            print(f"[warn] First generate failed: {e} | retry with use_cache=False")
+            gen_kwargs["use_cache"] = False
             try:
-                output = self._generate_without_mask(
-                    input_ids=input_ids,
-                    images=images,
-                    image_sizes=image_sizes,
-                    base_kwargs=gen_kwargs
-                )
+                with torch.inference_mode():
+                    output = self.model.generate(**gen_kwargs)
            except Exception as e2:
-                print(f"[error] No-mask fallback failed: {e2}")
+                print(f"[error] Generation failed: {e2}")
                 import traceback; traceback.print_exc()
-                return [{"generated_text": f"Error during generation: {emsg} | fallback: {e2}"}]
+                return [{"generated_text": f"Error during generation: {str(e2)}"}]
 
-        # 5) Decode response
+        # 6) decode
         try:
             sequences = output.sequences if hasattr(output, "sequences") else output
             input_len = input_ids.shape[1]
-            if sequences.shape[-1] > input_len:
-                response_ids = sequences[:, input_len:]
-            else:
-                response_ids = sequences
+            response_ids = sequences[:, input_len:] if sequences.shape[-1] > input_len else sequences
             text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
-            return [{
-                "generated_text": text,
-                "input_tokens": int(input_len),
-                "output_tokens": int(response_ids.shape[-1]),
-                "strategy_used": "normal_or_injected_mask_or_nomask"
-            }]
+            if not text:
+                text = "Error: Empty response generated"
+            return [{"generated_text": text}]
         except Exception as e:
-            print(f"[error] Decoding failed: {e}")
-            return [{"generated_text": f"Error during decoding: {str(e)}"}]
+            print(f"[error] Response decoding failed: {e}")
+            return [{"generated_text": f"Error: Response decoding failed: {str(e)}"}]
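The decode step slices the prompt off only when generate() actually returned it: standard decoder-only generation in transformers prepends the input ids to the output sequence, while some multimodal code paths return only the completion, which is why the new one-liner guards on length. A toy illustration:

```python
import torch

input_ids = torch.tensor([[10, 11, 12]])           # prompt, 3 tokens
sequences = torch.tensor([[10, 11, 12, 42, 43]])   # prompt + 2 new tokens
input_len = input_ids.shape[1]
response_ids = sequences[:, input_len:] if sequences.shape[-1] > input_len else sequences
print(response_ids.tolist())  # [[42, 43]]
```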