CanerDedeoglu commited on
Commit
91e44e7
·
verified ·
1 Parent(s): b1d5c20

no stop added

Browse files
Files changed (1) hide show
  1. handler.py +69 -44
handler.py CHANGED
@@ -1,8 +1,9 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler - Demo-like (sampling) LLaVA endpoint
4
- - Demo davranışı: do_sample=True, temperature/top_p payload'dan alınır
5
- - max_new_tokens: payload/slider değeri; bağlam limitine göre güvenli kırpma
 
6
  - Tek görsel işleme; IM_START/END otomatik; 3D/4D/5D tensör uyumlu
7
  - Çıktıya post-format/deduplicate UYGULANMAZ (demo ile bire bir)
8
  """
@@ -99,7 +100,8 @@ def _safe_upload(path: str):
99
  def _conv_log_path():
100
  t = datetime.datetime.now()
101
  p = os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
102
- os.makedirs(os.path.dirname(p), exist_ok=True)
 
103
  return p
104
 
105
  def load_image_any(image_input):
@@ -171,16 +173,6 @@ def _stopping(chatbot, input_ids):
171
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
172
  return KeywordsStoppingCriteria([stop_str], chatbot.tokenizer, input_ids)
173
 
174
- def _safe_max_new_tokens(requested: int, input_len: int, ctx_limit: int) -> int:
175
- """
176
- Demo'da slider değeri doğrudan kullanılıyor; burada ek güvenlik:
177
- toplam (input + new + rezerv) <= ctx_limit olacak şekilde kırp.
178
- """
179
- requested = max(1, min(int(requested), 8192))
180
- reserve = 16
181
- available = max(32, ctx_limit - input_len - reserve)
182
- return max(1, min(requested, available))
183
-
184
  def generate_response(
185
  message_text: str,
186
  image_input,
@@ -191,6 +183,8 @@ def generate_response(
191
  repetition_penalty: float = 1.0,
192
  conv_mode_override: str | None = None,
193
  det_seed: int | None = None,
 
 
194
  ):
195
  if not LLAVA_AVAILABLE:
196
  return {"error": "LLaVA modules not available"}
@@ -248,43 +242,62 @@ def generate_response(
248
  else:
249
  return {"error": "Image processing returned empty"}
250
 
251
- # Demo'da çoğunlukla half + to(device) kullanılıyor
252
- image_tensor = image_tensor.to(device=device, dtype=getattr(torch, "float16", torch.float16))
253
  except Exception as e:
254
  return {"error": f"Image processing failed: {e}"}
255
 
256
  # Prompt & tokenizasyon
257
  prompt, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
258
- stopping = _stopping(chatbot, input_ids)
259
-
260
- # max_new_tokens'ı güvenle kırp (demo slider + bağlam tavanı)
261
- ctx_limit = context_len or getattr(chatbot.model.config, "max_position_embeddings", 8192)
262
- max_new_tokens = _safe_max_new_tokens(max_new_tokens, input_ids.shape[1], ctx_limit)
263
 
264
- # Demo: sampling açık; istenirse deterministik sample için seed verilebilir
265
  if det_seed is not None:
266
- torch.manual_seed(det_seed)
267
- if torch.cuda.is_available():
268
- torch.cuda.manual_seed(det_seed)
269
- torch.cuda.manual_seed_all(det_seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
 
271
  try:
272
  with torch.no_grad():
273
- outputs = chatbot.model.generate(
274
- inputs=input_ids,
275
- images=image_tensor,
276
- do_sample=True,
277
- temperature=float(temperature),
278
- top_p=float(top_p),
279
- repetition_penalty=float(repetition_penalty),
280
- max_new_tokens=int(max_new_tokens),
281
- use_cache=False,
282
- pad_token_id=chatbot.tokenizer.eos_token_id,
283
- eos_token_id=chatbot.tokenizer.eos_token_id,
284
- length_penalty=1.0,
285
- early_stopping=False,
286
- stopping_criteria=[stopping],
287
- )
288
  gen = outputs[0][input_ids.shape[1]:]
289
  text = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
290
 
@@ -303,7 +316,7 @@ def generate_response(
303
  "image_hash": img_hash,
304
  "image_path": img_path or "",
305
  }
306
- with open(_conv_log_path(), "a") as f:
307
  f.write(json.dumps(row, ensure_ascii=False) + "\n")
308
  _safe_upload(_conv_log_path())
309
  if img_path:
@@ -340,7 +353,7 @@ def query(payload: dict):
340
  repetition_penalty = float(payload.get("repetition_penalty", 1.0))
341
  conv_mode_override = payload.get("conv_mode", None)
342
 
343
- # (Opsiyonel) deterministik sample için seed (demo defaultu: None)
344
  det_seed = payload.get("det_seed", None)
345
  if det_seed is not None:
346
  try:
@@ -348,6 +361,17 @@ def query(payload: dict):
348
  except Exception:
349
  det_seed = None
350
 
 
 
 
 
 
 
 
 
 
 
 
351
  return generate_response(
352
  message_text=message,
353
  image_input=image,
@@ -357,6 +381,8 @@ def query(payload: dict):
357
  repetition_penalty=repetition_penalty,
358
  conv_mode_override=conv_mode_override,
359
  det_seed=det_seed,
 
 
360
  )
361
  except Exception as e:
362
  return {"error": f"Query failed: {e}"}
@@ -471,4 +497,3 @@ class EndpointHandler:
471
 
472
  if __name__ == "__main__":
473
  print("Handler ready. Use `EndpointHandler` or `query` for HF Inference Endpoints.")
474
-
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler - Demo-like (sampling) + no_stop bayrağı
4
+ - Demo davranışı: do_sample=True, temperature/top_p payload'dan
5
+ - max_new_tokens: payload/slider değeri (KIRPMA YOK, direkt kullanılır)
6
+ - İsteğe bağlı: no_stop=True ile stopping_criteria devre dışı
7
  - Tek görsel işleme; IM_START/END otomatik; 3D/4D/5D tensör uyumlu
8
  - Çıktıya post-format/deduplicate UYGULANMAZ (demo ile bire bir)
9
  """
 
100
  def _conv_log_path():
101
  t = datetime.datetime.now()
102
  p = os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
103
+ os.makedirs(os.path.dirname(p), exist_ok=True
104
+ )
105
  return p
106
 
107
  def load_image_any(image_input):
 
173
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
174
  return KeywordsStoppingCriteria([stop_str], chatbot.tokenizer, input_ids)
175
 
 
 
 
 
 
 
 
 
 
 
176
  def generate_response(
177
  message_text: str,
178
  image_input,
 
183
  repetition_penalty: float = 1.0,
184
  conv_mode_override: str | None = None,
185
  det_seed: int | None = None,
186
+ no_stop: bool = False,
187
+ min_new_tokens: int | None = None, # opsiyonel, uzunluğu zorlamak istersen
188
  ):
189
  if not LLAVA_AVAILABLE:
190
  return {"error": "LLaVA modules not available"}
 
242
  else:
243
  return {"error": "Image processing returned empty"}
244
 
245
+ # Demo tarafında half + to(device) kalıbı yaygın
246
+ image_tensor = image_tensor.to(device=device, dtype=dtype)
247
  except Exception as e:
248
  return {"error": f"Image processing failed: {e}"}
249
 
250
  # Prompt & tokenizasyon
251
  prompt, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
252
+ stopping = None if no_stop else _stopping(chatbot, input_ids)
 
 
 
 
253
 
254
+ # (opsiyonel) deterministik sampling
255
  if det_seed is not None:
256
+ try:
257
+ det_seed = int(det_seed)
258
+ torch.manual_seed(det_seed)
259
+ if torch.cuda.is_available():
260
+ torch.cuda.manual_seed(det_seed)
261
+ torch.cuda.manual_seed_all(det_seed)
262
+ except Exception:
263
+ pass
264
+
265
+ # EOS/PAD güvenli al
266
+ eos_id = chatbot.tokenizer.eos_token_id
267
+ if eos_id is None:
268
+ try:
269
+ eos_id = chatbot.tokenizer.convert_tokens_to_ids("</s>")
270
+ except Exception:
271
+ eos_id = 0
272
+
273
+ # generate kwargs (demo-like)
274
+ gen_kwargs = dict(
275
+ inputs=input_ids,
276
+ images=image_tensor,
277
+ do_sample=True,
278
+ temperature=float(temperature),
279
+ top_p=float(top_p),
280
+ repetition_penalty=float(repetition_penalty),
281
+ max_new_tokens=int(max_new_tokens), # KIRPMA YOK
282
+ use_cache=False,
283
+ pad_token_id=eos_id,
284
+ eos_token_id=eos_id,
285
+ length_penalty=1.0,
286
+ early_stopping=False,
287
+ stopping_criteria=None if no_stop else [stopping],
288
+ )
289
+ if min_new_tokens is not None:
290
+ try:
291
+ mn = int(min_new_tokens)
292
+ if mn > 0 and mn <= int(max_new_tokens):
293
+ gen_kwargs["min_new_tokens"] = mn
294
+ except Exception:
295
+ pass
296
 
297
+ # Üretim
298
  try:
299
  with torch.no_grad():
300
+ outputs = chatbot.model.generate(**gen_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  gen = outputs[0][input_ids.shape[1]:]
302
  text = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
303
 
 
316
  "image_hash": img_hash,
317
  "image_path": img_path or "",
318
  }
319
+ with open(_conv_log_path(), "a", encoding="utf-8") as f:
320
  f.write(json.dumps(row, ensure_ascii=False) + "\n")
321
  _safe_upload(_conv_log_path())
322
  if img_path:
 
353
  repetition_penalty = float(payload.get("repetition_penalty", 1.0))
354
  conv_mode_override = payload.get("conv_mode", None)
355
 
356
+ # (Opsiyonel) deterministik sample için seed
357
  det_seed = payload.get("det_seed", None)
358
  if det_seed is not None:
359
  try:
 
361
  except Exception:
362
  det_seed = None
363
 
364
+ # (Yeni) stopping_criteria kapatma bayrağı
365
+ no_stop = bool(payload.get("no_stop", False))
366
+
367
+ # (Opsiyonel) min_new_tokens
368
+ mnt = payload.get("min_new_tokens", None)
369
+ if mnt is not None:
370
+ try:
371
+ mnt = int(mnt)
372
+ except Exception:
373
+ mnt = None
374
+
375
  return generate_response(
376
  message_text=message,
377
  image_input=image,
 
381
  repetition_penalty=repetition_penalty,
382
  conv_mode_override=conv_mode_override,
383
  det_seed=det_seed,
384
+ no_stop=no_stop,
385
+ min_new_tokens=mnt,
386
  )
387
  except Exception as e:
388
  return {"error": f"Query failed: {e}"}
 
497
 
498
  if __name__ == "__main__":
499
  print("Handler ready. Use `EndpointHandler` or `query` for HF Inference Endpoints.")