CanerDedeoglu committed on
Commit f04ae31 · verified · 1 Parent(s): 7133210

Update handler.py

Files changed (1):
  1. handler.py +55 -52
handler.py CHANGED
@@ -1,16 +1,14 @@
  # -*- coding: utf-8 -*-
  """
- PULSE ECG Handler — Demo Parity + Style Hint
- - Same generation settings as the demo app.py:
-   do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
- - Stopping: safe token-matching criterion on the conversation separator (conv.sep/sep2)
+ PULSE ECG Handler — Deterministic Version
+ - Generation settings: do_sample=False (consistent output); temperature/top_p have no effect
+ - Stopping: safe token-matching criterion on the conversation separator (conv.sep/sep2)
  - Image tensor: .half() and on the model device
  - Streamer: TextIteratorStreamer (as in the demo), generate in a thread
- - Seed/deterministic OFF (unless you send one); stochastic like the demo
- - STYLE_HINT: to approximate the demo style (narrative + one structured impression line at the end)
- - Post-process: ONLY whitespace/format normalization (management/recommendation sentences are kept)
+ - Seed/deterministic OFF (determinism comes from do_sample=False)
+ - STYLE_HINT: to approximate the demo style
+ - Post-process: ONLY whitespace/format normalization
  """
-
  import os
  import re
  import json
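For context on the setting the new docstring highlights: with `do_sample=False`, transformers falls back to greedy decoding, so `temperature` and `top_p` no longer influence the output. A minimal sketch with a generic causal LM (the `gpt2` checkpoint is only a placeholder here, not part of this repo):

```python
# Illustrative only: contrast the old sampling settings with the new greedy mode.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")            # placeholder model
lm = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tok("ECG interpretation:", return_tensors="pt").input_ids

sampled = lm.generate(ids, do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=16)
greedy = lm.generate(ids, do_sample=False, max_new_tokens=16)  # repeatable across calls
```

Two greedy runs over the same input produce identical token sequences, which is the consistency this commit is after.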
@@ -20,7 +18,6 @@ import datetime
  from io import BytesIO
  from threading import Thread
  from typing import Optional, Union
-
  import torch
  from PIL import Image
  import requests
@@ -129,10 +126,10 @@ def load_image_any(image_input: Union[str, dict]) -> Image.Image:
          s = s.split(",", 1)[1]
          raw = base64.b64decode(s)
          return Image.open(BytesIO(raw)).convert("RGB")
-
+
      if isinstance(image_input, dict) and "image" in image_input:
          return load_image_any(image_input["image"])
-
+
      raise ValueError("Unsupported image input format")

  def _normalize_whitespace(text: str) -> str:
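For reference, the base64 branch above splits a data URL on its first comma before decoding. A small sketch of producing a value `load_image_any` would accept (the file path is hypothetical):

```python
import base64

# Encode a local PNG as a data URL; load_image_any() strips everything up to the
# first comma and base64-decodes the rest. "ecg.png" is just an example path.
with open("ecg.png", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("ascii")
payload_image = f"data:image/png;base64,{b64}"
```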
@@ -161,7 +158,7 @@ class SafeKeywordsStoppingCriteria(StoppingCriteria):
          self.tokenizer = tokenizer
          tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
          self.kw_ids = tok  # shape: (n,)
-
+
      def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
          if input_ids is None or input_ids.shape[0] == 0:
              return False
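The class above follows the usual token-matching pattern for a transformers `StoppingCriteria`. A self-contained sketch of the same idea (class and variable names are illustrative, not the repo's):

```python
import torch
from transformers import StoppingCriteria

class TailKeywordStop(StoppingCriteria):
    """Stop generation once the sequence ends with the keyword's token ids."""
    def __init__(self, keyword: str, tokenizer):
        self.kw_ids = tokenizer(keyword, add_special_tokens=False,
                                return_tensors="pt").input_ids[0]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        n = self.kw_ids.shape[0]
        if input_ids.shape[1] < n:
            return False
        tail = input_ids[0, -n:].cpu()          # last n generated tokens
        return bool(torch.equal(tail, self.kw_ids.cpu()))
```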
@@ -193,11 +190,13 @@ class ChatSessionManager:
          self.chatbot = None
          self.args = None
          self.model_path = None
+
      def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
          if self.chatbot is None:
              self.args = args
              self.model_path = model_path
              self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
+
      def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
          self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
          # Fresh template on every call (a new turn, as in the demo)
@@ -212,7 +211,6 @@ def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
      chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
      chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
      prompt = chatbot.conversation.get_prompt()
-
      input_ids = tokenizer_image_token(
          prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
      ).unsqueeze(0).to(device)
@@ -222,35 +220,33 @@ def generate_response(
      message_text: str,
      image_input,
      *,
-     temperature: Optional[float] = None,
-     top_p: Optional[float] = None,
+     temperature: Optional[float] = None,  # ignored in deterministic mode
+     top_p: Optional[float] = None,        # ignored in deterministic mode
      max_new_tokens: Optional[int] = None,
      conv_mode_override: Optional[str] = None,
      repetition_penalty: Optional[float] = None,
-     det_seed: Optional[int] = None,  # None = stochastic (as in the demo)
+     det_seed: Optional[int] = None,  # ignored in deterministic mode
  ):
      if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
          return {"error": "Required libraries not available (llava/transformers)"}
      if not message_text or image_input is None:
          return {"error": "Both 'message' and 'image' are required"}
-
-     # Defaults → demo
-     if temperature is None: temperature = 0.05
-     if top_p is None: top_p = 1.0
+
+     # Defaults
      if max_new_tokens is None: max_new_tokens = 4096
      if repetition_penalty is None: repetition_penalty = 1.0  # no effect
-
+
      # Chat session
      chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
      if conv_mode_override and conv_mode_override in conv_templates:
          chatbot.conversation = conv_templates[conv_mode_override].copy()
-
+
      # Load image
      try:
          pil_img = load_image_any(image_input)
      except Exception as e:
          return {"error": f"Failed to load image: {e}"}
-
+
      # Hash + path for logging
      img_hash, img_path = "NA", None
      try:
@@ -263,11 +259,11 @@
          pil_img.save(img_path)
      except Exception as e:
          print(f"[log] save image failed: {e}")
-
+
      # Device/dtype
      device = next(chatbot.model.parameters()).device
      dtype = torch.float16  # demo: half
-
+
      # Image preprocessing → tensor
      try:
          processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
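The device/dtype step amounts to casting the preprocessed tensor to the model's device in float16. A one-line sketch of that idea (assumes an already loaded torch model and tensor):

```python
import torch

def to_model_half(image_tensor: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
    # Mirror the handler's "half + device" step: put the tensor where the weights live.
    device = next(model.parameters()).device
    return image_tensor.to(device=device, dtype=torch.float16)
```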
@@ -277,22 +273,23 @@
              image_tensor = processed[0] if processed.ndim == 4 else processed
          else:
              return {"error": "Image processing returned empty"}
+
          if image_tensor.ndim == 3:
              image_tensor = image_tensor.unsqueeze(0)  # (1,C,H,W)
          image_tensor = image_tensor.to(device=device, dtype=dtype)  # demo: half + device
      except Exception as e:
          return {"error": f"Image processing failed: {e}"}
-
+
      # Append STYLE_HINT and build the prompt
      msg = (message_text or "").strip()
      msg = f"{msg}\n\n{STYLE_HINT}"
      _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
-
+
      # Stop string (conv separator) → safe criterion
      stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
      stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
-
-     # Seed (stochastic like the demo if none is provided)
+
+     # Seed (irrelevant with do_sample=False, but kept in the code)
      if det_seed is not None:
          try:
              s = int(det_seed)
@@ -302,26 +299,31 @@
              torch.cuda.manual_seed_all(s)
          except Exception:
              pass
-
+
      # Streamer (as in the demo)
      streamer = TextIteratorStreamer(
          chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
      )
-
-     # Generate kwargs — demo settings
+
+     # Generate kwargs — deterministic settings
      gen_kwargs = dict(
          inputs=input_ids,
          images=image_tensor,
          streamer=streamer,
-         do_sample=True,                                 # DEMO
-         temperature=float(temperature),                 # DEMO default 0.05
-         top_p=float(top_p),                             # DEMO default 1.0
-         max_new_tokens=int(max_new_tokens),             # DEMO slider
-         repetition_penalty=float(repetition_penalty),   # default 1.0, no effect
+
+         # 🟢 KEY CHANGE: enable deterministic (consistent) decoding
+         do_sample=False,
+
+         # temperature and top_p are now ignored
+         # temperature=float(temperature),
+         # top_p=float(top_p),
+
+         max_new_tokens=int(max_new_tokens),
+         repetition_penalty=float(repetition_penalty),
          use_cache=False,
-         stopping_criteria=[stopping],  # demo-like stopping
+         stopping_criteria=[stopping],
      )
-
+
      # Generation (background thread) + collect the stream
      try:
          t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
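The generation call above is the standard `TextIteratorStreamer` recipe: run `generate` in a background thread and consume the streamer on the calling thread. A condensed sketch (assumes `model`, `tokenizer`, and `input_ids` already exist):

```python
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(inputs=input_ids, streamer=streamer, do_sample=False, max_new_tokens=4096)

worker = Thread(target=model.generate, kwargs=gen_kwargs)
worker.start()
text = "".join(chunk for chunk in streamer)  # yields decoded text as tokens arrive
worker.join()
```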
@@ -334,7 +336,7 @@
          chatbot.conversation.messages[-1][-1] = text
      except Exception as e:
          return {"error": f"Generation failed: {e}"}
-
+
      # Log
      try:
          row = {
@@ -350,7 +352,7 @@
          _safe_upload(_conv_log_path()); _safe_upload(img_path or "")
      except Exception as e:
          print(f"[log] failed: {e}")
-
+
      return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}

  # ===================== Public API =====================
@@ -362,25 +364,24 @@ def query(payload: dict):
          if not initialize_model():
              return {"error": "Model initialization failed"}
          model_initialized = True
-
      try:
          message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
          image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
+
          if not message.strip(): return {"error": "Missing 'message' text"}
          if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
-
-         # Demo defaults; the payload can override them
-         temperature = float(payload.get("temperature", 0.05))
+
+         # temperature/top_p are ignored in deterministic mode but kept for API compatibility
+         temperature = float(payload.get("temperature", 0.0))  # default 0.0
          top_p = float(payload.get("top_p", 1.0))
          max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
-         repetition_penalty = float(payload.get("repetition_penalty", 1.0))  # inert default
-
+         repetition_penalty = float(payload.get("repetition_penalty", 1.0))
          conv_mode_override = payload.get("conv_mode", None)
          det_seed = payload.get("det_seed", None)
          if det_seed is not None:
              try: det_seed = int(det_seed)
              except Exception: det_seed = None
-
+
          return generate_response(
              message_text=message,
              image_input=image,
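For illustration, a payload exercising the keys `query()` accepts (all values are examples; `image` may be a URL, a data URL, or raw base64):

```python
payload = {
    "message": "Interpret this 12-lead ECG.",   # also accepted: query / prompt / istem
    "image": "https://example.com/ecg.png",     # placeholder URL
    "max_new_tokens": 1024,                     # aliases: max_output_tokens, max_tokens
    "conv_mode": None,                          # optional conversation template override
    "det_seed": None,                           # kept for compatibility; decoding is greedy
}
result = query(payload)
print(result.get("response") or result.get("error"))
```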
@@ -437,19 +438,18 @@ def initialize_model():
      tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
          args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
      )
-     # demo: usually runs the model on cuda
+     # move the model to cuda
      try:
          _ = next(model_.parameters()).device
      except Exception:
          if torch.cuda.is_available():
              model_ = model_.to(torch.device("cuda"))
+
      model_.eval()
-
      globals()["tokenizer"] = tokenizer_
      globals()["model"] = model_
      globals()["image_processor"] = image_processor_
      globals()["context_len"] = context_len_
-
      chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
      print("[init] model/tokenizer/image_processor loaded.")
      return True
@@ -464,14 +464,17 @@ class EndpointHandler:
      def __init__(self, model_dir):
          self.model_dir = model_dir
          print(f"EndpointHandler initialized with model_dir: {model_dir}")
+
      def __call__(self, payload):
          if "inputs" in payload:
              return query(payload["inputs"])
          return query(payload)
+
      def health_check(self):
          return health_check()
+
      def get_model_info(self):
          return get_model_info()

  if __name__ == "__main__":
-     print("Handler ready (Demo Parity + Style Hint + whitespace post-process). Use `EndpointHandler` or `query`.")
+     print("Handler ready (Deterministic mode: do_sample=False). Use `EndpointHandler` or `query`.")
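And the Inference Endpoints entry point, where the payload may arrive wrapped under `"inputs"`. A sketch, assuming `handler.py` is importable from the working directory:

```python
from handler import EndpointHandler

handler = EndpointHandler(model_dir=".")
out = handler({"inputs": {"message": "Describe the rhythm.", "image": "<base64 or URL>"}})
print(out.get("response") or out.get("error"))
```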
 