CanerDedeoglu committed on
Commit
05bcc3b
·
verified ·
1 Parent(s): dbf6dc8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +19 -20
handler.py CHANGED
@@ -1,12 +1,13 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler — Demo Parity Mode
4
  - Demo app.py ile aynı üretim ayarları:
5
  do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
6
  - Stopping: konuşma ayırıcıda (conv.sep/sep2) güvenli token-eşleşmeli kriter
7
  - Görsel tensörü: .half() ve model cihazında
8
  - Streamer: TextIteratorStreamer (demo gibi), thread ile generate
9
  - Seed/deterministic KAPALI (göndermezseniz); demo gibi stokastik
 
10
  """
11
 
12
  import os
@@ -80,6 +81,13 @@ context_len = None
80
  args = None
81
  model_initialized = False
82
 
 
 
 
 
 
 
 
83
 
84
  # ===================== Utilities =====================
85
 
@@ -140,7 +148,6 @@ def _wrap_image_token_if_needed(model_cfg) -> bool:
140
  except Exception:
141
  return False
142
 
143
-
144
  # ====== Güvenli Stop Kriteri (demo eşleniği) ======
145
  class SafeKeywordsStoppingCriteria(StoppingCriteria):
146
  """
@@ -153,7 +160,6 @@ class SafeKeywordsStoppingCriteria(StoppingCriteria):
153
  self.kw_ids = tok # shape: (n,)
154
 
155
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
156
- # input_ids: (bsz, seq_len)
157
  if input_ids is None or input_ids.shape[0] == 0:
158
  return False
159
  out = input_ids[0] # assume bsz=1
@@ -161,18 +167,15 @@ class SafeKeywordsStoppingCriteria(StoppingCriteria):
161
  if out.shape[0] < n:
162
  return False
163
  tail = out[-n:]
164
- # cihaz hizası
165
  kw = self.kw_ids.to(tail.device)
166
  return torch.equal(tail, kw)
167
 
168
-
169
  # ===================== Core Generation =====================
170
 
171
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
172
  # demo gibi: <image> + text (IM_START/END gerekiyorsa sar)
173
  use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
174
  if use_wrap:
175
- # <im_start><image><im_end>\n + user text
176
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
177
  else:
178
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
@@ -195,7 +198,6 @@ def generate_response(
195
  max_new_tokens: Optional[int] = None,
196
  conv_mode_override: Optional[str] = None,
197
  repetition_penalty: Optional[float] = None, # demo'da yok; verilirse 1.0 yaparız
198
- # NOT: no_repeat_ngram_size / min_new_tokens / custom_stop KULLANMIYORUZ → demo-parite
199
  det_seed: Optional[int] = None, # seed gönderilmezse stokastik (demo gibi)
200
  ):
201
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
@@ -237,30 +239,31 @@ def generate_response(
237
 
238
  # Cihaz/dtype
239
  device = next(chatbot.model.parameters()).device
240
- # demo half: .half() kullanacağız
241
- dtype = torch.float16
242
 
243
  # Görüntü ön-işleme → tensör
244
  try:
245
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
246
- # LLaVA genelde list döndürür
247
  if isinstance(processed, (list, tuple)) and len(processed) > 0:
248
  image_tensor = processed[0]
249
  elif isinstance(processed, torch.Tensor):
250
- image_tensor = processed[0] if processed.ndim == 4 else processed # güvenlik
251
  else:
252
  return {"error": "Image processing returned empty"}
253
  if image_tensor.ndim == 3:
254
  image_tensor = image_tensor.unsqueeze(0) # (1,C,H,W)
255
- # demo: half + device
256
- image_tensor = image_tensor.to(device=device, dtype=dtype)
257
  except Exception as e:
258
  return {"error": f"Image processing failed: {e}"}
259
 
 
 
 
 
260
  # Prompt & input ids
261
  _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
262
 
263
- # Stop string from conv
264
  stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
265
  stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
266
 
@@ -324,11 +327,10 @@ def generate_response(
324
 
325
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
326
 
327
-
328
  # ===================== Public API =====================
329
 
330
  def query(payload: dict):
331
- """HF Endpoint entry (demo parity)."""
332
  global model_initialized, tokenizer, model, image_processor, context_len, args
333
  if not model_initialized:
334
  if not initialize_model():
@@ -384,7 +386,6 @@ def get_model_info():
384
  "device": str(next(model.parameters()).device) if model else "Unknown",
385
  }
386
 
387
-
388
  # ===================== Init & Session =====================
389
 
390
  class _Args:
@@ -448,7 +449,6 @@ def initialize_model():
448
  model_ = model_.to(torch.device("cuda"))
449
  model_.eval()
450
 
451
- # assign globals
452
  globals()["tokenizer"] = tokenizer_
453
  globals()["model"] = model_
454
  globals()["image_processor"] = image_processor_
@@ -461,7 +461,6 @@ def initialize_model():
461
  print(f"[init] failed: {e}")
462
  return False
463
 
464
-
465
  # ===================== HF EndpointHandler =====================
466
 
467
  class EndpointHandler:
@@ -479,4 +478,4 @@ class EndpointHandler:
479
  return get_model_info()
480
 
481
  if __name__ == "__main__":
482
- print("Handler ready (Demo Parity Mode). Use `EndpointHandler` or `query`.")
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler — Demo Parity + Style Hint
4
  - Demo app.py ile aynı üretim ayarları:
5
  do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
6
  - Stopping: konuşma ayırıcıda (conv.sep/sep2) güvenli token-eşleşmeli kriter
7
  - Görsel tensörü: .half() ve model cihazında
8
  - Streamer: TextIteratorStreamer (demo gibi), thread ile generate
9
  - Seed/deterministic KAPALI (göndermezseniz); demo gibi stokastik
10
+ - STYLE_HINT: demo üslubuna (narratif + sonda tek satır structured impression) yaklaşmak için
11
  """
12
 
13
  import os
 
81
  args = None
82
  model_initialized = False
83
 
84
+ # ====== Demo üslubuna yönlendiren stil ipucu ======
85
+ STYLE_HINT = (
86
+ "Write a concise diagnostic narrative as in a cardiology read: "
87
+ "use 2–3 short paragraphs describing rhythm, rate, axis, chamber enlargement, conduction, QRS, ST–T, QT; "
88
+ "then finish with a single final line starting exactly with 'Structured clinical impression:'. "
89
+ "Do not include recommendations, prognosis, follow-up, or risk counseling. No emojis or bullet points."
90
+ )
91
 
92
  # ===================== Utilities =====================
93
 
 
148
  except Exception:
149
  return False
150
 
 
151
  # ====== Güvenli Stop Kriteri (demo eşleniği) ======
152
  class SafeKeywordsStoppingCriteria(StoppingCriteria):
153
  """
 
160
  self.kw_ids = tok # shape: (n,)
161
 
162
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
163
  if input_ids is None or input_ids.shape[0] == 0:
164
  return False
165
  out = input_ids[0] # assume bsz=1
 
167
  if out.shape[0] < n:
168
  return False
169
  tail = out[-n:]
 
170
  kw = self.kw_ids.to(tail.device)
171
  return torch.equal(tail, kw)
172
 
 
173
  # ===================== Core Generation =====================
174
 
175
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
176
  # demo gibi: <image> + text (IM_START/END gerekiyorsa sar)
177
  use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
178
  if use_wrap:
 
179
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
180
  else:
181
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
 
198
  max_new_tokens: Optional[int] = None,
199
  conv_mode_override: Optional[str] = None,
200
  repetition_penalty: Optional[float] = None, # demo'da yok; verilirse 1.0 yaparız
 
201
  det_seed: Optional[int] = None, # seed gönderilmezse stokastik (demo gibi)
202
  ):
203
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
 
239
 
240
  # Cihaz/dtype
241
  device = next(chatbot.model.parameters()).device
242
+ dtype = torch.float16 # demo: half
 
243
 
244
  # Görüntü ön-işleme → tensör
245
  try:
246
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
 
247
  if isinstance(processed, (list, tuple)) and len(processed) > 0:
248
  image_tensor = processed[0]
249
  elif isinstance(processed, torch.Tensor):
250
+ image_tensor = processed[0] if processed.ndim == 4 else processed
251
  else:
252
  return {"error": "Image processing returned empty"}
253
  if image_tensor.ndim == 3:
254
  image_tensor = image_tensor.unsqueeze(0) # (1,C,H,W)
255
+ image_tensor = image_tensor.to(device=device, dtype=dtype) # demo: half + device
 
256
  except Exception as e:
257
  return {"error": f"Image processing failed: {e}"}
258
 
259
+ # --------- STIL İPUCU EKLEME ---------
260
+ message_text = (message_text or "").strip() + "\n\n" + STYLE_HINT
261
+ # -------------------------------------
262
+
263
  # Prompt & input ids
264
  _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
265
 
266
+ # Stop string (conv separator) → güvenli kriter
267
  stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
268
  stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
269
 
 
327
 
328
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
329
 
 
330
  # ===================== Public API =====================
331
 
332
  def query(payload: dict):
333
+ """HF Endpoint entry (demo parity + style hint)."""
334
  global model_initialized, tokenizer, model, image_processor, context_len, args
335
  if not model_initialized:
336
  if not initialize_model():
 
386
  "device": str(next(model.parameters()).device) if model else "Unknown",
387
  }
388
 
 
389
  # ===================== Init & Session =====================
390
 
391
  class _Args:
 
449
  model_ = model_.to(torch.device("cuda"))
450
  model_.eval()
451
 
 
452
  globals()["tokenizer"] = tokenizer_
453
  globals()["model"] = model_
454
  globals()["image_processor"] = image_processor_
 
461
  print(f"[init] failed: {e}")
462
  return False
463
 
 
464
  # ===================== HF EndpointHandler =====================
465
 
466
  class EndpointHandler:
 
478
  return get_model_info()
479
 
480
  if __name__ == "__main__":
481
+ print("Handler ready (Demo Parity + Style Hint). Use `EndpointHandler` or `query`.")