CanerDedeoglu committed
Commit f9e643a · verified · 1 Parent(s): ab44485

Update handler.py

Files changed (1)
  1. handler.py +70 -39
handler.py CHANGED
@@ -3,8 +3,9 @@
 PULSE ECG Handler - Deterministic ECG Analysis Model (app.py compatible)
 - Deterministic (do_sample=False, fixed seed)
 - Single image, LLaVA conv_template + <image> token flow
-- Image tensor matched to the model dtype/device
+- Image tensor matched to the model dtype/device (3D/4D/5D supported)
 - Robust URL/base64 handling, safe logging, optional HF upload
+- Mandatory heading template + min_new_tokens for a complete Step 1-9 output
 """
 
 import os
@@ -19,7 +20,7 @@ from io import BytesIO
 
 # --- Optional dependencies ---
 try:
-    import numpy as np  # optional, may be used
+    import numpy as np  # optional
 except Exception:
     np = None
 
@@ -55,7 +56,7 @@ except Exception as e:
 
 # Transformers
 try:
-    from transformers import TextIteratorStreamer  # unused, but harmless if present
+    from transformers import TextIteratorStreamer  # harmless if present
     TRANSFORMERS_AVAILABLE = True
 except Exception:
     TRANSFORMERS_AVAILABLE = False
@@ -98,10 +99,23 @@ args = None
 model_initialized = False
 
 # --- Consistency settings ---
-# Consistency settings
 PROMPT_NORMALIZATION = True
-DEFAULT_ECG_PROMPT = "Perform a detailed ECG interpretation of the provided image. Analyze step by step the rhythm, heart rate, cardiac axis, P waves, PR interval, QRS complex morphology and duration, ST segments, T waves, and QT/QTc interval. Highlight any abnormalities, conduction disturbances, or ischemic changes you detect. Conclude with a structured clinical impression of the overall ECG."
-
+DEFAULT_ECG_PROMPT = (
+    "Perform a detailed ECG interpretation of the provided image. Analyze step by step the rhythm, heart rate, "
+    "cardiac axis, P waves, PR interval, QRS complex morphology and duration, ST segments, T waves, and QT/QTc interval. "
+    "OUTPUT FORMAT (use these exact headings, and include every section even if normal):\n"
+    "Step 1: Rhythm Analysis\n"
+    "Step 2: Heart Rate Analysis\n"
+    "Step 3: Cardiac Axis Analysis\n"
+    "Step 4: P Wave Analysis\n"
+    "Step 5: PR Interval Analysis\n"
+    "Step 6: QRS Complex Analysis\n"
+    "Step 7: ST Segment Analysis\n"
+    "Step 8: T Wave Analysis\n"
+    "Step 9: QT/QTc Interval Analysis\n"
+    "Structured Clinical Impression:\n"
+    "If a section is normal, write 'Normal' and give a brief justification."
+)
 
 # ---------- Helpers ----------
 
@@ -287,22 +301,32 @@ def clear_history():
     except Exception as e:
         return {"error": f"Failed to clear history: {str(e)}"}
 
-# ---------- Response generation ----------
+# ---------- Prompt construction ----------
 
 def _build_prompt(chatbot, user_text: str) -> str:
-    # Same as app.py: <image> token + user text
-    image_token = DEFAULT_IMAGE_TOKEN
-    inp = image_token + "\n" + user_text
+    # Wrap the <image> token according to the mm_use_im_start_end config
+    try:
+        use_wrap = bool(getattr(chatbot.model.config, "mm_use_im_start_end", False))
+    except Exception:
+        use_wrap = False
+
+    if use_wrap:
+        # <im_start><image></im_end>\n + text
+        inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
+    else:
+        inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
+
     chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
     chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
-    prompt = chatbot.conversation.get_prompt()
-    return prompt
+    return chatbot.conversation.get_prompt()
 
 def _stop_criteria_from_conv(chatbot, input_ids):
     conv = chatbot.conversation
     stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
     return KeywordsStoppingCriteria([stop_str], chatbot.tokenizer, input_ids)
 
+# ---------- Response generation ----------
+
 def generate_response(message_text,
                       image_input,
                       max_output_tokens=4096,
@@ -351,45 +375,37 @@ def generate_response(message_text,
     # Model dtype/device
     model_device = next(chatbot.model.parameters()).device
     model_dtype = next(chatbot.model.parameters()).dtype
-
+
     # Image tensor (Tensor/list/tuple, 3D/4D/5D supported)
     try:
         processed = process_images([image], chatbot.image_processor, chatbot.model.config)
-
+
         if isinstance(processed, torch.Tensor):
-            # Possible shapes: (C,H,W), (B,C,H,W), (B,T,C,H,W)
+            # (C,H,W) / (B,C,H,W) / (B,T,C,H,W)
            if processed.ndim == 3:
-                # (C,H,W) -> (1,C,H,W)
-                image_tensor = processed.unsqueeze(0)
+                image_tensor = processed.unsqueeze(0)  # (1,C,H,W)
             elif processed.ndim == 4:
-                # (B,C,H,W)
-                image_tensor = processed
+                image_tensor = processed  # (B,C,H,W)
             elif processed.ndim == 5:
-                # (B,T,C,H,W) -> (B*T,C,H,W)
                 b, t, c, h, w = processed.shape
-                image_tensor = processed.reshape(b * t, c, h, w)
+                image_tensor = processed.reshape(b * t, c, h, w)  # (B*T,C,H,W)
             else:
                 return {"error": f"Unexpected image tensor shape: {tuple(processed.shape)}"}
-
         elif isinstance(processed, (list, tuple)):
             if len(processed) == 0:
                 return {"error": "Image processing returned empty list"}
             first = processed[0]
             if not isinstance(first, torch.Tensor):
                 return {"error": f"Processed image type not tensor: {type(first)}"}
-            # first: (C,H,W) or (B,C,H,W)
             image_tensor = first.unsqueeze(0) if first.ndim == 3 else first
         else:
             return {"error": f"Unsupported processed type: {type(processed)}"}
-
-        # Match device and dtype
+
         image_tensor = image_tensor.to(device=model_device, dtype=model_dtype)
-
+
     except Exception as e:
         return {"error": f"Image processing failed: {str(e)}"}
 
-
-
     # Prompt & tokenization
     prompt = _build_prompt(chatbot, message_text)
     input_ids = tokenizer_image_token(
@@ -405,6 +421,14 @@
     torch.cuda.manual_seed(42)
     torch.cuda.manual_seed_all(42)
 
+    # Resolve EOS/PAD safely
+    eos_id = chatbot.tokenizer.eos_token_id
+    if eos_id is None:
+        try:
+            eos_id = chatbot.tokenizer.convert_tokens_to_ids("</s>")
+        except Exception:
+            eos_id = 0  # last resort
+
     try:
         with torch.no_grad():
             outputs = chatbot.model.generate(
@@ -412,10 +436,11 @@
                 images=image_tensor,
                 do_sample=False,  # deterministic
                 max_new_tokens=int(max_output_tokens),
+                min_new_tokens=600,  # generate at least this much (keeps the Step headings)
                 repetition_penalty=float(repetition_penalty),
                 use_cache=False,
-                pad_token_id=chatbot.tokenizer.eos_token_id,
-                eos_token_id=chatbot.tokenizer.eos_token_id,
+                pad_token_id=eos_id,
+                eos_token_id=eos_id,
                 length_penalty=1.0,
                 early_stopping=False,
                 stopping_criteria=[stopping_criteria],
@@ -480,10 +505,13 @@ def query(payload):
         or ""
     )
 
-    # Prompt normalization (if the text mentions ECG + diagnosis)
+    # Prompt normalization (force the detailed template for every ECG request)
     if PROMPT_NORMALIZATION and "ecg" in message_text.lower():
         if "concise" in message_text.lower():
-            message_text = "Provide a short, concise clinical summary of the ECG."
+            message_text = (
+                "Provide a short, concise clinical summary of the ECG. "
+                "Still cover rhythm, rate, axis, PR, QRS, ST-T, QT/QTc in brief."
+            )
         else:
             message_text = DEFAULT_ECG_PROMPT
 
@@ -498,7 +526,7 @@ def query(payload):
     # Parameters
     max_output_tokens = int(payload.get("max_output_tokens",
                             payload.get("max_new_tokens",
-                            payload.get("max_tokens", 2048))))
+                            payload.get("max_tokens", 4096))))
     repetition_penalty = float(payload.get("repetition_penalty", 1.0))
     conv_mode_override = payload.get("conv_mode", None)
 
@@ -576,27 +604,30 @@ def initialize_model():
             self.model_base = None
             self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
             self.conv_mode = None
-            self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "2048"))
+            self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
             self.num_frames = 16
             self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
             self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
             self.debug = bool(int(os.getenv("DEBUG", "0")))
-
-    args = Args()
+    # Assign to the module-level args
+    globals()["args"] = Args()
 
     model_name = get_model_name_from_path(args.model_path)
-    tokenizer, model, image_processor, context_len = load_pretrained_model(
+    loaded = load_pretrained_model(
         args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
     )
+    globals()["tokenizer"], globals()["model"], globals()["image_processor"], globals()["context_len"] = loaded
 
-    # Device: no extra .to('cuda') is needed when accelerate uses a device_map
+    # Device: an extra .to('cuda') may not be needed when accelerate uses a device_map
    try:
         _ = next(model.parameters()).device
     except Exception:
-        # safe move
         if torch.cuda.is_available():
             model = model.to(torch.device("cuda"))
 
+    # Disable dropout etc. for determinism
+    model.eval()
+
     print("[init] tokenizer/image_processor/context_len ready")
     return True
 
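Note on the image-tensor handling in this commit: the same 3D/4D/5D normalization can be exercised without loading the model. A minimal sketch, assuming only torch; normalize_image_tensor is a hypothetical helper named here for illustration and is not a function in handler.py:

import torch

def normalize_image_tensor(processed, device="cpu", dtype=torch.float32):
    """Accept (C,H,W), (B,C,H,W) or (B,T,C,H,W) and return (N,C,H,W) on the given device/dtype."""
    if isinstance(processed, (list, tuple)):
        if not processed:
            raise ValueError("empty image batch")
        processed = processed[0]
    if not isinstance(processed, torch.Tensor):
        raise TypeError(f"expected a tensor, got {type(processed)}")
    if processed.ndim == 3:                      # (C,H,W) -> (1,C,H,W)
        processed = processed.unsqueeze(0)
    elif processed.ndim == 5:                    # (B,T,C,H,W) -> (B*T,C,H,W)
        b, t, c, h, w = processed.shape
        processed = processed.reshape(b * t, c, h, w)
    elif processed.ndim != 4:
        raise ValueError(f"unexpected shape: {tuple(processed.shape)}")
    return processed.to(device=device, dtype=dtype)

print(normalize_image_tensor(torch.rand(3, 336, 336)).shape)        # torch.Size([1, 3, 336, 336])
print(normalize_image_tensor(torch.rand(2, 4, 3, 336, 336)).shape)  # torch.Size([8, 3, 336, 336])

The handler inlines this logic inside generate_response so every processed image ends up as a 4-D batch matching the model's dtype and device.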
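Note on the new generate() kwargs: passing an explicit pad_token_id avoids the usual transformers warning when the tokenizer defines no pad token, and the commit resolves that id defensively before generation. A minimal sketch of the fallback chain, assuming any tokenizer-like object; DummyTokenizer exists only for this example:

def resolve_eos_id(tokenizer):
    # Prefer the declared EOS id, then the "</s>" token, then 0 as a last resort.
    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is None:
        try:
            eos_id = tokenizer.convert_tokens_to_ids("</s>")
        except Exception:
            eos_id = 0
    return eos_id

class DummyTokenizer:
    eos_token_id = None
    def convert_tokens_to_ids(self, token):
        return {"</s>": 2}.get(token, 0)

print(resolve_eos_id(DummyTokenizer()))  # 2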
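Note on the query() parameters: after this commit the output-length budget falls back through max_output_tokens, then max_new_tokens, then max_tokens, and finally the new default of 4096. A minimal sketch of that resolution with an illustrative payload (the key names mirror the diff; other fields the handler accepts are omitted here):

payload = {"max_tokens": 1024, "repetition_penalty": 1.1}

max_output_tokens = int(payload.get("max_output_tokens",
                        payload.get("max_new_tokens",
                        payload.get("max_tokens", 4096))))
repetition_penalty = float(payload.get("repetition_penalty", 1.0))

print(max_output_tokens, repetition_penalty)  # 1024 1.1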
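Note on initialize_model(): the switch from a bare `args = Args()` to `globals()["args"] = Args()` (and the same pattern for tokenizer/model/image_processor/context_len) matters because a plain assignment inside a function only binds a local name. A minimal sketch of the difference, unrelated to the actual Args class in handler.py:

args = None

def init_local():
    args = "local"                 # binds a new local; the module-level args stays None

def init_global():
    globals()["args"] = "global"   # rebinds the module-level name read elsewhere

init_local()
print(args)   # None
init_global()
print(args)   # global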