CanerDedeoglu committed on
Commit
2ad0be8
·
verified ·
1 Parent(s): 1d76b32

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +58 -29
handler.py CHANGED
@@ -9,6 +9,10 @@ import requests
9
  # ===== Kullanılacak HF model id =====
10
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
11
 
 
 
 
 
12
  # ===== LLaVA kaynak kodunu runtime'da getir (pip yok) =====
13
  LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/haotian-liu/LLaVA.git")
14
  LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1") # kanıtlı, stabil
@@ -28,35 +32,26 @@ _ensure_llava()
28
 
29
  # ---- LLaVA parçaları (demo akışı) ----
30
  from llava.model.builder import load_pretrained_model
31
- from llava.mm_utils import tokenizer_image_token, process_images
32
  from llava.constants import (
33
  IMAGE_TOKEN_INDEX,
34
  DEFAULT_IMAGE_TOKEN,
35
  DEFAULT_IM_START_TOKEN,
36
  DEFAULT_IM_END_TOKEN,
 
 
 
37
  )
38
  from llava.conversation import conv_templates
39
  from llava.utils import disable_torch_init
40
 
41
- # Eksik fonksiyonu kendimiz tanımlıyoruz
42
- def get_model_name_from_path(model_path):
43
-
44
- model_path = model_path.strip("/")
45
-
46
- model_paths = model_path.split("/")
47
-
48
- if model_paths[-1].startswith('checkpoint-'):
49
-
50
- return model_paths[-2] + "_" + model_paths[-1]
51
-
52
- else:
53
-
54
- return model_paths[-1]
55
 
56
  # Varsayılanlar
57
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_llama_2")
58
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
59
- os.environ.setdefault("ATTN_IMPLEMENTATION", os.getenv("ATTN_IMPLEMENTATION", "sdpa"))
60
 
61
  class EndpointHandler:
62
  """
@@ -82,22 +77,28 @@ class EndpointHandler:
82
 
83
  self.model_name = get_model_name_from_path(model_path)
84
 
 
 
 
 
 
 
 
85
  # PULSE, LLaVA tabanlı olduğundan LLaVA loader ile yüklenir
86
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
87
  model_path=model_path,
88
  model_base=None,
89
  model_name=self.model_name,
90
  torch_dtype="auto",
91
- attn_implementation=os.getenv("ATTN_IMPLEMENTATION", "sdpa"),
92
  device_map="auto",
93
  )
94
  self.model.eval()
95
 
96
  # Görsel token işaretleri (LLaVA config)
97
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
98
- self.image_token = DEFAULT_IMAGE_TOKEN
99
- self.im_start = DEFAULT_IM_START_TOKEN
100
- self.im_end = DEFAULT_IM_END_TOKEN
101
 
102
  # ---- yardımcılar ----
103
  def _load_image(self, img_field: str) -> Optional[Image.Image]:
@@ -122,10 +123,13 @@ class EndpointHandler:
122
  if conv_mode not in conv_templates:
123
  conv_mode = DEFAULT_CONV_MODE
124
  conv = conv_templates[conv_mode].copy()
 
 
125
  if self.use_im_start_end:
126
- content = f"{self.im_start}{self.image_token}{self.im_end}\n{user_text}"
127
  else:
128
- content = f"{self.image_token}\n{user_text}"
 
129
  conv.append_message(conv.roles[0], content)
130
  conv.append_message(conv.roles[1], None)
131
  return conv.get_prompt()
@@ -148,13 +152,25 @@ class EndpointHandler:
148
  if image_f:
149
  pil = self._load_image(image_f)
150
  if pil is not None:
151
- image_tensors = process_images([pil], self.image_processor, self.model.config)
152
- image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
 
 
 
 
 
 
 
153
 
154
  # 3) tokenize (image token'ı gömülü)
155
  input_ids = tokenizer_image_token(
156
  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
157
- ).to(self.model.device, non_blocking=True)
 
 
 
 
 
158
 
159
  # 4) güvenli max_new_tokens
160
  requested = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
@@ -162,16 +178,29 @@ class EndpointHandler:
162
  max_new_tokens = max(1, min(requested, avail))
163
 
164
  gen_kwargs = {
 
 
165
  "max_new_tokens": max_new_tokens,
166
  "temperature": float(params.get("temperature", 0.0)),
167
  "top_p": float(params.get("top_p", 1.0)),
168
  "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
169
  "do_sample": bool(params.get("do_sample", float(params.get("temperature", 0.0)) > 0)),
170
  "use_cache": bool(params.get("use_cache", True)),
 
171
  }
172
 
173
- with torch.inference_mode():
174
- output_ids = self.model.generate(input_ids, images=image_tensors, **gen_kwargs)
175
-
176
- text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
 
 
 
 
 
 
 
 
 
 
177
  return [{"generated_text": text}]
 
9
  # ===== Kullanılacak HF model id =====
10
  MODEL_ID = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
11
 
12
+ # Flash Attention için environment
13
+ os.environ.setdefault("FLASH_ATTENTION", "1")
14
+ os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
15
+
16
  # ===== LLaVA kaynak kodunu runtime'da getir (pip yok) =====
17
  LLAVA_GIT_URL = os.getenv("LLAVA_GIT_URL", "https://github.com/haotian-liu/LLaVA.git")
18
  LLAVA_GIT_REF = os.getenv("LLAVA_GIT_REF", "v1.2.2.post1") # kanıtlı, stabil
 
32
 
33
  # ---- LLaVA parçaları (demo akışı) ----
34
  from llava.model.builder import load_pretrained_model
35
+ from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
36
  from llava.constants import (
37
  IMAGE_TOKEN_INDEX,
38
  DEFAULT_IMAGE_TOKEN,
39
  DEFAULT_IM_START_TOKEN,
40
  DEFAULT_IM_END_TOKEN,
41
+ DEFAULT_IMAGE_PATCH_TOKEN,
42
+ IMAGE_PLACEHOLDER,
43
+ IGNORE_INDEX,
44
  )
45
  from llava.conversation import conv_templates
46
  from llava.utils import disable_torch_init
47
 
48
+ # Eksik fonksiyonu kaldır - artık mm_utils'ten import ediyoruz
49
+ # def get_model_name_from_path() artık gerekli değil
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Varsayılanlar
52
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_llama_2")
53
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
54
+ # ATTN_IMPLEMENTATION artık otomatik seçiliyor, bu satırı kaldırıyoruz
55
 
56
  class EndpointHandler:
57
  """
 
77
 
78
  self.model_name = get_model_name_from_path(model_path)
79
 
80
+ # Attention implementation otomatik seç
81
+ try:
82
+ import flash_attn
83
+ attn_impl = "flash_attention_2"
84
+ except ImportError:
85
+ attn_impl = "sdpa"
86
+
87
  # PULSE, LLaVA tabanlı olduğundan LLaVA loader ile yüklenir
88
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
89
  model_path=model_path,
90
  model_base=None,
91
  model_name=self.model_name,
92
  torch_dtype="auto",
93
+ attn_implementation=attn_impl,
94
  device_map="auto",
95
  )
96
  self.model.eval()
97
 
98
  # Görsel token işaretleri (LLaVA config)
99
  self.use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
100
+ # Constants'tan direkt kullan
101
+ # self.image_token, self.im_start, self.im_end artık gerekli değil
 
102
 
103
  # ---- yardımcılar ----
104
  def _load_image(self, img_field: str) -> Optional[Image.Image]:
 
123
  if conv_mode not in conv_templates:
124
  conv_mode = DEFAULT_CONV_MODE
125
  conv = conv_templates[conv_mode].copy()
126
+
127
+ # Image token'ları doğru yerleştir
128
  if self.use_im_start_end:
129
+ content = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
130
  else:
131
+ content = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
132
+
133
  conv.append_message(conv.roles[0], content)
134
  conv.append_message(conv.roles[1], None)
135
  return conv.get_prompt()
 
152
  if image_f:
153
  pil = self._load_image(image_f)
154
  if pil is not None:
155
+ try:
156
+ # LLaVA'nın gelişmiş process_images fonksiyonunu kullan
157
+ # Bu fonksiyon anyres, pad gibi farklı aspect ratio modlarını destekler
158
+ image_tensors = process_images([pil], self.image_processor, self.model.config)
159
+ if image_tensors is not None and len(image_tensors) > 0:
160
+ image_tensors = image_tensors.to(self.model.device, dtype=torch.float16, non_blocking=True)
161
+ except Exception as e:
162
+ print(f"[warn] image processing failed: {e}")
163
+ image_tensors = None
164
 
165
  # 3) tokenize (image token'ı gömülü)
166
  input_ids = tokenizer_image_token(
167
  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
168
+ ).unsqueeze(0).to(self.model.device, non_blocking=True) # unsqueeze ekledik
169
+
170
+ # Input uzunluk kontrolü
171
+ if input_ids.shape[-1] > self.context_len - 100:
172
+ # Prompt'u kısalt
173
+ input_ids = input_ids[:, -(self.context_len - 200):]
174
 
175
  # 4) güvenli max_new_tokens
176
  requested = int(params.get("max_new_tokens", MAX_NEW_TOKENS_DEF))
 
178
  max_new_tokens = max(1, min(requested, avail))
179
 
180
  gen_kwargs = {
181
+ "input_ids": input_ids,
182
+ "images": image_tensors,
183
  "max_new_tokens": max_new_tokens,
184
  "temperature": float(params.get("temperature", 0.0)),
185
  "top_p": float(params.get("top_p", 1.0)),
186
  "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
187
  "do_sample": bool(params.get("do_sample", float(params.get("temperature", 0.0)) > 0)),
188
  "use_cache": bool(params.get("use_cache", True)),
189
+ "pad_token_id": self.tokenizer.eos_token_id,
190
  }
191
 
192
+ try:
193
+ with torch.inference_mode():
194
+ output_ids = self.model.generate(**gen_kwargs)
195
+
196
+ # Output'u input'tan ayır
197
+ if output_ids.shape[-1] > input_ids.shape[-1]:
198
+ response_ids = output_ids[:, input_ids.shape[-1]:]
199
+ text = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0].strip()
200
+ else:
201
+ text = "Error: No response generated"
202
+
203
+ except Exception as e:
204
+ print(f"Generation error: {e}")
205
+ text = f"Error during generation: {str(e)}"
206
  return [{"generated_text": text}]