CanerDedeoglu commited on
Commit
b1d5c20
·
verified ·
1 Parent(s): 5ea9d94

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +264 -558
handler.py CHANGED
@@ -1,41 +1,33 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler - Deterministic ECG Analysis Model (app.py uyumlu)
4
- - Deterministic (do_sample=False, sabit seed)
5
- - Tek görüntü, LLaVA conv_template + <image> token akışı
6
- - Model dtype/device ile uyumlu görüntü tensörü (3D/4D/5D destekli)
7
- - Sağlam URL/base64 işleme, güvenli logging, opsiyonel HF upload
8
- - Zorunlu başlık şablonu + min_new_tokens ile tam Step 1–9 çıktısı
9
- - Tekrarları engelleme (no_repeat_ngram_size) + post-format dedup
10
  """
11
 
12
  import os
13
- import re
14
- import datetime
15
- import torch
16
- import hashlib
17
  import json
18
  import base64
19
- import requests
20
- from PIL import Image
21
  from io import BytesIO
22
 
23
- # --- Opsiyonel bağımlılıklar ---
24
- try:
25
- import numpy as np # isteğe bağlı
26
- except Exception:
27
- np = None
28
 
 
29
  try:
30
  import cv2
31
  CV2_AVAILABLE = True
32
  except Exception:
33
  CV2_AVAILABLE = False
34
- print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")
35
 
36
- # LLaVA
37
  try:
38
- from llava import conversation as conversation_lib
39
  from llava.constants import (
40
  IMAGE_TOKEN_INDEX,
41
  DEFAULT_IMAGE_TOKEN,
@@ -44,55 +36,45 @@ try:
44
  )
45
  from llava.conversation import conv_templates, SeparatorStyle
46
  from llava.model.builder import load_pretrained_model
47
- from llava.utils import disable_torch_init
48
  from llava.mm_utils import (
49
  tokenizer_image_token,
50
  process_images,
51
  get_model_name_from_path,
52
  KeywordsStoppingCriteria,
53
  )
 
54
  LLAVA_AVAILABLE = True
55
  except Exception as e:
56
  LLAVA_AVAILABLE = False
57
  print(f"Warning: LLaVA modules not available: {e}")
58
 
59
- # Transformers
60
- try:
61
- from transformers import TextIteratorStreamer # mevcutsa sorun değil
62
- TRANSFORMERS_AVAILABLE = True
63
- except Exception:
64
- TRANSFORMERS_AVAILABLE = False
65
- print("Warning: Transformers not available")
66
-
67
- # HF Hub (opsiyonel)
68
  try:
69
  from huggingface_hub import HfApi, login
70
  HF_HUB_AVAILABLE = True
71
  except Exception:
72
  HF_HUB_AVAILABLE = False
73
- print("Warning: Hugging Face Hub not available")
74
 
75
- # --- HF Hub init (opsiyonel) ---
 
 
76
  if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
77
  try:
78
  login(token=os.environ["HF_TOKEN"], write_permission=True)
79
  api = HfApi()
80
  repo_name = os.environ.get("LOG_REPO", "")
81
  except Exception as e:
82
- print(f"Failed to initialize HF API: {e}")
83
  api = None
84
  repo_name = ""
85
- else:
86
- api = None
87
- repo_name = ""
88
 
89
- # --- Sabitler / Dizinyapısı ---
90
  LOGDIR = "./logs"
91
  VOTEDIR = "./votes"
92
  os.makedirs(LOGDIR, exist_ok=True)
93
  os.makedirs(VOTEDIR, exist_ok=True)
94
 
95
- # --- Global model durumları ---
96
  tokenizer = None
97
  model = None
98
  image_processor = None
@@ -100,30 +82,9 @@ context_len = None
100
  args = None
101
  model_initialized = False
102
 
103
- # --- Tutarlılık ayarları ---
104
- PROMPT_NORMALIZATION = True
105
- DEFAULT_ECG_PROMPT = (
106
- "Perform a detailed ECG interpretation of the provided image. Analyze step by step the rhythm, heart rate, "
107
- "cardiac axis, P waves, PR interval, QRS complex morphology and duration, ST segments, T waves, and QT/QTc interval. "
108
- "OUTPUT FORMAT (use these exact headings, and include every section even if normal):\n"
109
- "Step 1: Rhythm Analysis\n"
110
- "Step 2: Heart Rate Analysis\n"
111
- "Step 3: Cardiac Axis Analysis\n"
112
- "Step 4: P Wave Analysis\n"
113
- "Step 5: PR Interval Analysis\n"
114
- "Step 6: QRS Complex Analysis\n"
115
- "Step 7: ST Segment Analysis\n"
116
- "Step 8: T Wave Analysis\n"
117
- "Step 9: QT/QTc Interval Analysis\n"
118
- "Structured Clinical Impression:\n"
119
- "If a section is normal, write 'Normal' and give a brief justification. "
120
- "Each section must be 1–3 concise sentences. Do not repeat identical statements. "
121
- "Write the final diagnostic impression only once in 'Structured Clinical Impression' and do not restate it elsewhere."
122
- )
123
-
124
- # ---------- Yardımcılar ----------
125
-
126
- def _safe_upload(path):
127
  if api and repo_name and os.path.isfile(path):
128
  try:
129
  api.upload_file(
@@ -135,290 +96,62 @@ def _safe_upload(path):
135
  except Exception as e:
136
  print(f"[upload] failed for {path}: {e}")
137
 
138
- def get_conv_log_filename():
139
  t = datetime.datetime.now()
140
- name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
141
- os.makedirs(os.path.dirname(name), exist_ok=True
142
- )
143
- return name
144
-
145
- def get_conv_vote_filename():
146
- t = datetime.datetime.now()
147
- name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
148
- os.makedirs(os.path.dirname(name), exist_ok=True)
149
- return name
150
-
151
- def vote_last_response(state, vote_type, model_selector):
152
- try:
153
- with open(get_conv_vote_filename(), "a") as fout:
154
- data = {"type": vote_type, "model": model_selector, "state": state}
155
- fout.write(json.dumps(data) + "\n")
156
- _safe_upload(get_conv_vote_filename())
157
- except Exception as e:
158
- print(f"Failed to record vote: {e}")
159
-
160
- # Yalın uzantı listeleri
161
- IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "jfif"}
162
- try:
163
- import pillow_heif # noqa: F401
164
- IMAGE_EXTS.update({"heic", "heif"})
165
- except Exception:
166
- pass
167
-
168
- VIDEO_EXTS = {"avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"}
169
-
170
- def is_valid_video_filename(name: str) -> bool:
171
- if not CV2_AVAILABLE or not name:
172
- return False
173
- ext = name.split(".")[-1].lower()
174
- return ext in VIDEO_EXTS
175
-
176
- def is_valid_image_filename(name: str) -> bool:
177
- if not name:
178
- return False
179
- ext = name.split(".")[-1].lower()
180
- return ext in IMAGE_EXTS
181
-
182
- def sample_frames(video_file, num_frames):
183
- if not CV2_AVAILABLE:
184
- raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")
185
- cap = cv2.VideoCapture(video_file)
186
- total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
187
- if total <= 0 or num_frames <= 0:
188
- cap.release()
189
- return []
190
- step = max(1, total // num_frames)
191
- idxs = list(range(0, total, step))[:num_frames]
192
- frames = []
193
- for i in idxs:
194
- cap.set(cv2.CAP_PROP_POS_FRAMES, i)
195
- ret, frame = cap.read()
196
- if not ret or frame is None:
197
- continue
198
- pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
199
- frames.append(pil_img)
200
- cap.release()
201
- return frames
202
-
203
- def load_image(image_file):
204
- if image_file.startswith(("http://", "https://")):
205
- try:
206
- r = requests.get(image_file, timeout=(5, 15))
207
- r.raise_for_status()
208
- return Image.open(BytesIO(r.content)).convert("RGB")
209
- except Exception as e:
210
- raise ValueError(f"Failed to load image URL: {e}")
211
- else:
212
- return Image.open(image_file).convert("RGB")
213
-
214
- def process_base64_image(base64_string: str) -> Image.Image:
215
- try:
216
- if base64_string.startswith("data:image"):
217
- base64_string = base64_string.split(",", 1)[1]
218
- image_data = base64.b64decode(base64_string)
219
- image = Image.open(BytesIO(image_data)).convert("RGB")
220
- return image
221
- except Exception as e:
222
- raise ValueError(f"Failed to process base64 image: {e}")
223
 
224
- def process_image_input(image_input):
225
- """Desteklenen formatlar: yerel yol, URL, base64 string veya {'image': base64} sözlüğü."""
226
- if isinstance(image_input, str):
227
- if image_input.startswith(("http://", "https://")):
228
- return load_image(image_input)
229
- if os.path.exists(image_input):
230
- return load_image(image_input)
231
- return process_base64_image(image_input)
232
- if isinstance(image_input, dict) and "image" in image_input:
233
- return process_base64_image(image_input["image"])
234
- raise ValueError("Unsupported image input format")
235
-
236
- # ---------- Şablon dayatma (post-format) ----------
237
-
238
- SECTION_ORDER = [
239
- "Step 1: Rhythm Analysis",
240
- "Step 2: Heart Rate Analysis",
241
- "Step 3: Cardiac Axis Analysis",
242
- "Step 4: P Wave Analysis",
243
- "Step 5: PR Interval Analysis",
244
- "Step 6: QRS Complex Analysis",
245
- "Step 7: ST Segment Analysis",
246
- "Step 8: T Wave Analysis",
247
- "Step 9: QT/QTc Interval Analysis",
248
- "Structured Clinical Impression:",
249
- ]
250
-
251
- _SECTION_RE = re.compile(
252
- r"(Step\s*1:\s*Rhythm Analysis|"
253
- r"Step\s*2:\s*Heart Rate Analysis|"
254
- r"Step\s*3:\s*Cardiac Axis Analysis|"
255
- r"Step\s*4:\s*P Wave Analysis|"
256
- r"Step\s*5:\s*PR Interval Analysis|"
257
- r"Step\s*6:\s*QRS Complex Analysis|"
258
- r"Step\s*7:\s*ST Segment Analysis|"
259
- r"Step\s*8:\s*T Wave Analysis|"
260
- r"Step\s*9:\s*QT/QTc Interval Analysis|"
261
- r"Structured Clinical Impression:)",
262
- flags=re.IGNORECASE
263
- )
264
-
265
- def _enforce_section_template(text: str) -> str:
266
  """
267
- Model çıktısını yakalayıp Step 1–9 + Structured başlıklarını sırayla ve eksiksiz
268
- döndürecek şekilde biçimler. Eksik bölümler 'Normal...' notuyla doldurulur.
 
 
269
  """
270
- pieces = _SECTION_RE.split(text)
271
- found = {}
272
- prefix = None
273
-
274
- if pieces:
275
- if not _SECTION_RE.match(pieces[0] or ""):
276
- prefix = (pieces[0] or "").strip()
277
-
278
- i = 1
279
- while i + 1 < len(pieces):
280
- heading = pieces[i].strip()
281
- content = pieces[i + 1].strip()
282
- for canonical in SECTION_ORDER:
283
- if heading.lower().startswith(canonical.lower().rstrip(":")):
284
- found[canonical] = content
285
- break
286
- i += 2
287
-
288
- filled = []
289
- for sec in SECTION_ORDER:
290
- val = (found.get(sec, "") or "").strip()
291
- if not val:
292
- if sec.startswith("Step"):
293
- val = "Normal. No definite abnormality detected in this section based on the provided ECG image."
294
- else:
295
- val = "Overall impression: No acute life-threatening abnormality identified. Correlate clinically."
296
- filled.append(f"{sec}\n{val}")
297
-
298
- if prefix:
299
- filled[0] = filled[0] + f"\n\n(Additional notes captured before Step 1): {prefix}"
300
-
301
- return "\n\n".join(filled)
302
-
303
- def _sent_split(s: str):
304
- return [x.strip() for x in re.split(r'(?<=[.!?])\s+', s.strip()) if x.strip()]
305
-
306
- def _norm_key(s: str):
307
- return re.sub(r'\W+', ' ', s.lower()).strip()
308
-
309
- def _dedupe_and_clip_sections(text: str) -> str:
310
- """
311
- Şablon oluşmuş metni alır, her bölümde tekrar eden cümleleri siler,
312
- uzunluğu kısaltır (Steps: ≤3 cümle, Impression: ≤6 cümle) ve birleştirir.
313
- """
314
- pieces = _SECTION_RE.split(text)
315
- found = {}
316
- i = 1
317
- while i + 1 < len(pieces):
318
- heading = pieces[i].strip()
319
- content = pieces[i + 1].strip()
320
- for canonical in SECTION_ORDER:
321
- if heading.lower().startswith(canonical.lower().rstrip(":")):
322
- found[canonical] = content
323
- break
324
- i += 2
325
-
326
- out_sections = []
327
- for sec in SECTION_ORDER:
328
- body = (found.get(sec, "") or "").strip()
329
- sents = _sent_split(body)
330
-
331
- seen = set()
332
- deduped = []
333
- for s in sents:
334
- k = _norm_key(s)
335
- if k not in seen:
336
- seen.add(k)
337
- deduped.append(s)
338
-
339
- limit = 3 if sec.startswith("Step") else 6
340
- limited = deduped[:limit] if deduped else []
341
- out_body = " ".join(limited) if limited else body
342
- out_sections.append(f"{sec}\n{out_body}" if out_body else f"{sec}\n")
343
-
344
- return "\n\n".join(out_sections)
345
-
346
- # ---------- Oturum / Konuşma ----------
347
-
348
- class InferenceDemo(object):
349
- def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
350
- if not LLAVA_AVAILABLE:
351
- raise ImportError("LLaVA modules not available")
352
- disable_torch_init()
353
- self.tokenizer, self.model, self.image_processor, self.context_len = (
354
- tokenizer, model, image_processor, context_len
355
- )
356
- model_name = get_model_name_from_path(model_path)
357
- low = model_name.lower()
358
- if "llama-2" in low:
359
- conv_mode = "llava_llama_2"
360
- elif "v1" in low or "pulse" in low:
361
- conv_mode = "llava_v1"
362
- elif "mpt" in low:
363
- conv_mode = "mpt"
364
- elif "qwen" in low:
365
- conv_mode = "qwen_1_5"
366
- else:
367
- conv_mode = "llava_v0"
368
-
369
- if args.conv_mode is not None and conv_mode != args.conv_mode:
370
- print(f"[WARNING] auto conv={conv_mode}, using --conv-mode={args.conv_mode}")
371
- else:
372
- args.conv_mode = conv_mode
373
- self.conv_mode = args.conv_mode
374
- self.conversation = conv_templates[self.conv_mode].copy()
375
- self.num_frames = args.num_frames
376
-
377
- class ChatSessionManager:
378
- def __init__(self):
379
- self.chatbot_instance = None
380
- self.args = None
381
- self.model_path = None
382
-
383
- def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
384
- self.args = args
385
- self.model_path = model_path
386
- self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
387
- print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
388
-
389
- def reset_chatbot(self):
390
- self.chatbot_instance = None
391
-
392
- def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
393
- if self.chatbot_instance is None:
394
- self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
395
- return self.chatbot_instance
396
-
397
- chat_manager = ChatSessionManager()
398
-
399
- def clear_history():
400
- if not LLAVA_AVAILABLE:
401
- return {"error": "LLaVA modules not available"}
402
- try:
403
- chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
404
- tokenizer, model, image_processor, context_len)
405
  try:
406
- chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
 
407
  except Exception as e:
408
- print(f"[DEBUG] Failed to reset conversation: {e}")
409
- return {"status": "success", "message": "Conversation history cleared"}
410
- except Exception as e:
411
- return {"error": f"Failed to clear history: {str(e)}"}
412
-
413
- # ---------- Prompt inşası ----------
414
-
415
- def _build_prompt(chatbot, user_text: str) -> str:
416
- # mm_use_im_start_end konfigürasyonuna göre <image> tokenını sarmala
 
 
 
 
 
 
 
 
 
 
417
  try:
418
- use_wrap = bool(getattr(chatbot.model.config, "mm_use_im_start_end", False))
419
  except Exception:
420
- use_wrap = False
 
 
421
 
 
 
422
  if use_wrap:
423
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
424
  else:
@@ -426,233 +159,207 @@ def _build_prompt(chatbot, user_text: str) -> str:
426
 
427
  chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
428
  chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
429
- return chatbot.conversation.get_prompt()
 
 
 
 
 
430
 
431
- def _stop_criteria_from_conv(chatbot, input_ids):
432
  conv = chatbot.conversation
433
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
434
  return KeywordsStoppingCriteria([stop_str], chatbot.tokenizer, input_ids)
435
 
436
- # ---------- Cevap üretimi ----------
437
-
438
- def generate_response(message_text,
439
- image_input,
440
- max_output_tokens=4096,
441
- repetition_penalty=1.0,
442
- conv_mode_override=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  if not LLAVA_AVAILABLE:
444
  return {"error": "LLaVA modules not available"}
445
 
446
- if not message_text or not image_input:
447
- return {"error": "Both message text and image are required"}
448
-
449
- chatbot = chat_manager.get_chatbot(
450
- args, args.model_path if args else "PULSE-ECG/PULSE-7B",
451
- tokenizer, model, image_processor, context_len
452
- )
453
 
 
 
454
  if conv_mode_override and conv_mode_override in conv_templates:
455
  chatbot.conversation = conv_templates[conv_mode_override].copy()
456
  else:
457
  chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
458
 
459
- # Görüntüyü al/işle
460
  try:
461
- image = process_image_input(image_input)
462
  except Exception as e:
463
- return {"error": f"Failed to process image: {str(e)}"}
464
 
465
  # Log için kaydet
 
 
466
  try:
467
- img_byte_arr = BytesIO()
468
- image.save(img_byte_arr, format="JPEG")
469
- image_hash = hashlib.md5(img_byte_arr.getvalue()).hexdigest()
 
470
  t = datetime.datetime.now()
471
- out_path = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
472
- os.makedirs(os.path.dirname(out_path), exist_ok=True)
473
- if not os.path.isfile(out_path):
474
- image.save(out_path)
475
  except Exception as e:
476
- print(f"[WARN] Failed to save image: {e}")
477
- out_path = None
478
- image_hash = "NA"
479
-
480
- # Model dtype/device
481
- model_device = next(chatbot.model.parameters()).device
482
- model_dtype = next(chatbot.model.parameters()).dtype
483
 
484
- # Görüntü tensörü (Tensor/list/tuple + 3D/4D/5D destekli)
 
 
485
  try:
486
- processed = process_images([image], chatbot.image_processor, chatbot.model.config)
487
-
488
  if isinstance(processed, torch.Tensor):
489
- if processed.ndim == 3:
490
- image_tensor = processed.unsqueeze(0) # (1,C,H,W)
491
- elif processed.ndim == 4:
492
- image_tensor = processed # (B,C,H,W)
493
- elif processed.ndim == 5:
494
- b, t, c, h, w = processed.shape
495
- image_tensor = processed.reshape(b * t, c, h, w) # (B*T,C,H,W)
496
  else:
497
  return {"error": f"Unexpected image tensor shape: {tuple(processed.shape)}"}
498
- elif isinstance(processed, (list, tuple)):
499
- if len(processed) == 0:
500
- return {"error": "Image processing returned empty list"}
501
  first = processed[0]
502
- if not isinstance(first, torch.Tensor):
503
- return {"error": f"Processed image type not tensor: {type(first)}"}
504
- image_tensor = first.unsqueeze(0) if first.ndim == 3 else first
505
  else:
506
- return {"error": f"Unsupported processed type: {type(processed)}"}
507
-
508
- image_tensor = image_tensor.to(device=model_device, dtype=model_dtype)
509
 
 
 
510
  except Exception as e:
511
- return {"error": f"Image processing failed: {str(e)}"}
512
 
513
  # Prompt & tokenizasyon
514
- prompt = _build_prompt(chatbot, message_text)
515
- input_ids = tokenizer_image_token(
516
- prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
517
- ).unsqueeze(0).to(model_device)
518
-
519
- # Stop kriteri
520
- stopping_criteria = _stop_criteria_from_conv(chatbot, input_ids)
521
 
522
- # Deterministik üretim
523
- torch.manual_seed(42)
524
- if torch.cuda.is_available():
525
- torch.cuda.manual_seed(42)
526
- torch.cuda.manual_seed_all(42)
527
 
528
- # EOS/PAD güvenli al
529
- eos_id = chatbot.tokenizer.eos_token_id
530
- if eos_id is None:
531
- try:
532
- eos_id = chatbot.tokenizer.convert_tokens_to_ids("</s>")
533
- except Exception:
534
- eos_id = 0 # son çare
535
 
536
  try:
537
  with torch.no_grad():
538
  outputs = chatbot.model.generate(
539
  inputs=input_ids,
540
  images=image_tensor,
541
- do_sample=False, # deterministik
542
- max_new_tokens=int(max_output_tokens),
543
- min_new_tokens=350, # 800 -> 350 (tekrar riskini azalt)
544
- no_repeat_ngram_size=5, # tekrar bloklarını engelle
545
  repetition_penalty=float(repetition_penalty),
 
546
  use_cache=False,
547
- pad_token_id=eos_id,
548
- eos_token_id=eos_id,
549
  length_penalty=1.0,
550
  early_stopping=False,
551
- stopping_criteria=[stopping_criteria],
552
  )
553
- # Sadece yeni üretilen kısmı çöz
554
  gen = outputs[0][input_ids.shape[1]:]
555
- response = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
556
-
557
- # ŞABLON ZORLAMA + tekrar kırpma
558
- response = _enforce_section_template(response)
559
- response = _dedupe_and_clip_sections(response)
560
-
561
- # Konuşmaya yerleştir
562
- if chatbot.conversation.messages and isinstance(chatbot.conversation.messages[-1], list):
563
- chatbot.conversation.messages[-1][-1] = response
564
- else:
565
- chatbot.conversation.append_message(chatbot.conversation.roles[1], response)
566
 
 
 
567
  except Exception as e:
568
- return {"error": f"Generation failed: {str(e)}"}
569
 
570
- # Log
571
  try:
572
- history = [(message_text, response)]
573
- with open(get_conv_log_filename(), "a") as fout:
574
- data = {
575
- "type": "chat",
576
- "model": "PULSE-7B",
577
- "state": history,
578
- "images": [image_hash],
579
- "images_path": [out_path] if out_path else []
580
- }
581
- fout.write(json.dumps(data) + "\n")
582
- _safe_upload(get_conv_log_filename())
583
- if out_path:
584
- _safe_upload(out_path)
585
  except Exception as e:
586
- print(f"[WARN] Failed to log/upload: {e}")
587
 
588
- return {
589
- "status": "success",
590
- "response": response,
591
- "conversation_id": id(chatbot.conversation)
592
- }
593
 
594
- # ---------- API yüzeyi ----------
595
 
596
- def query(payload):
597
- """HF Endpoint ana giriş noktası"""
598
  global model_initialized, tokenizer, model, image_processor, context_len, args
599
 
600
- # Lazy init
601
  if not model_initialized:
602
- ok = initialize_model()
603
- if not ok:
604
  return {"error": "Model initialization failed"}
605
  model_initialized = True
606
 
607
  try:
608
- # Metin
609
- message_text = (
610
- payload.get("message")
611
- or payload.get("query")
612
- or payload.get("prompt")
613
- or payload.get("istem")
614
- or ""
615
- )
616
-
617
- # Prompt normalization (ECG içeren tüm isteklerde ayrıntılı şablonu zorla)
618
- if PROMPT_NORMALIZATION and "ecg" in message_text.lower():
619
- if "concise" in message_text.lower():
620
- message_text = (
621
- "Provide a short, concise clinical summary of the ECG. "
622
- "Still cover rhythm, rate, axis, PR, QRS, ST-T, QT/QTc in brief."
623
- )
624
- else:
625
- message_text = DEFAULT_ECG_PROMPT
626
-
627
- # Görüntü
628
- image_input = (
629
- payload.get("image")
630
- or payload.get("image_url")
631
- or payload.get("img")
632
- or None
633
- )
634
-
635
- # Parametreler
636
- max_output_tokens = int(payload.get("max_output_tokens",
637
- payload.get("max_new_tokens",
638
- payload.get("max_tokens", 4096))))
639
- repetition_penalty = float(payload.get("repetition_penalty", 1.0))
640
- conv_mode_override = payload.get("conv_mode", None)
641
-
642
- if not message_text.strip():
643
- return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
644
- if image_input is None:
645
- return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"}
646
 
647
  return generate_response(
648
- message_text=message_text,
649
- image_input=image_input,
650
- max_output_tokens=max_output_tokens,
 
 
651
  repetition_penalty=repetition_penalty,
652
- conv_mode_override=conv_mode_override
 
653
  )
654
  except Exception as e:
655
- return {"error": f"Query failed: {str(e)}"}
656
 
657
  def health_check():
658
  return {
@@ -660,109 +367,108 @@ def health_check():
660
  "model_initialized": model_initialized,
661
  "cuda_available": torch.cuda.is_available(),
662
  "llava_available": LLAVA_AVAILABLE,
663
- "transformers_available": TRANSFORMERS_AVAILABLE,
664
  "cv2_available": CV2_AVAILABLE,
665
- "lazy_loading": True
666
  }
667
 
668
  def get_model_info():
669
  if not model_initialized:
670
- return {"error": "Model not initialized yet", "lazy_loading": True}
671
  return {
672
  "model_path": args.model_path if args else "Unknown",
673
- "model_type": "PULSE-7B",
674
- "cuda_available": torch.cuda.is_available(),
675
- "device": str(model.device) if model else "Unknown"
676
  }
677
 
678
- def upvote_last_response(conversation_id):
679
- try:
680
- vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
681
- return {"status": "success", "message": "Thank you for your voting!"}
682
- except Exception as e:
683
- return {"error": f"Failed to upvote: {str(e)}"}
684
 
685
- def downvote_last_response(conversation_id):
686
- try:
687
- vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
688
- return {"status": "success", "message": "Thank you for your voting!"}
689
- except Exception as e:
690
- return {"error": f"Failed to downvote: {str(e)}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
 
692
- def flag_response(conversation_id):
693
- try:
694
- vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
695
- return {"status": "success", "message": "Response flagged successfully"}
696
- except Exception as e:
697
- return {"error": f"Failed to flag response: {str(e)}"}
 
 
 
 
 
 
 
698
 
699
- # ---------- Model init ----------
700
 
701
  def initialize_model():
702
  """Modeli yükle (lazy)"""
703
  global tokenizer, model, image_processor, context_len, args
704
-
705
  if not LLAVA_AVAILABLE:
706
- print("LLaVA modules not available, skipping model initialization")
707
  return False
708
-
709
  try:
710
- class Args:
711
- def __init__(self):
712
- self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
713
- self.model_base = None
714
- self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
715
- self.conv_mode = None
716
- self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
717
- self.num_frames = 16
718
- self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
719
- self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
720
- self.debug = bool(int(os.getenv("DEBUG", "0")))
721
-
722
- globals()["args"] = Args()
723
-
724
  model_name = get_model_name_from_path(args.model_path)
725
- loaded = load_pretrained_model(
726
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
727
  )
728
- globals()["tokenizer"], globals()["model"], globals()["image_processor"], globals()["context_len"] = loaded
729
-
730
- # Device: accelerate devicemap varsa ek .to('cuda') gerekmeyebilir
731
  try:
732
  _ = next(model.parameters()).device
733
  except Exception:
734
  if torch.cuda.is_available():
735
  model = model.to(torch.device("cuda"))
736
-
737
- # Deterministik için dropout vb. kapansın
738
  model.eval()
739
-
740
- print("[init] tokenizer/image_processor/context_len ready")
 
741
  return True
742
-
743
  except Exception as e:
744
- print(f"Failed to initialize model: {e}")
745
  return False
746
 
747
- # ---------- HF EndpointHandler ----------
748
 
749
  class EndpointHandler:
750
- """Hugging Face endpoint handler class"""
751
-
752
  def __init__(self, model_dir):
753
  self.model_dir = model_dir
754
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
755
-
756
  def __call__(self, payload):
757
  if "inputs" in payload:
758
  return query(payload["inputs"])
759
  return query(payload)
760
-
761
  def health_check(self):
762
  return health_check()
763
-
764
  def get_model_info(self):
765
  return get_model_info()
766
 
767
  if __name__ == "__main__":
768
- print("Handler loaded. Use `query` or `EndpointHandler` in HF Inference Endpoints.")
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler - Demo-like (sampling) LLaVA endpoint
4
+ - Demo davranışı: do_sample=True, temperature/top_p payload'dan alınır
5
+ - max_new_tokens: payload/slider değeri; bağlam limitine göre güvenli kırpma
6
+ - Tek görsel işleme; IM_START/END otomatik; 3D/4D/5D tensör uyumlu
7
+ - Çıktıya post-format/deduplicate UYGULANMAZ (demo ile bire bir)
 
 
8
  """
9
 
10
  import os
 
 
 
 
11
  import json
12
  import base64
13
+ import hashlib
14
+ import datetime
15
  from io import BytesIO
16
 
17
+ import torch
18
+ from PIL import Image
19
+ import requests
 
 
20
 
21
+ # --- Opsiyonel bağımlılıklar ---
22
  try:
23
  import cv2
24
  CV2_AVAILABLE = True
25
  except Exception:
26
  CV2_AVAILABLE = False
27
+ print("Warning: OpenCV (cv2) not available; video is disabled.")
28
 
29
+ # --- LLaVA / Transformers ---
30
  try:
 
31
  from llava.constants import (
32
  IMAGE_TOKEN_INDEX,
33
  DEFAULT_IMAGE_TOKEN,
 
36
  )
37
  from llava.conversation import conv_templates, SeparatorStyle
38
  from llava.model.builder import load_pretrained_model
 
39
  from llava.mm_utils import (
40
  tokenizer_image_token,
41
  process_images,
42
  get_model_name_from_path,
43
  KeywordsStoppingCriteria,
44
  )
45
+ from llava.utils import disable_torch_init
46
  LLAVA_AVAILABLE = True
47
  except Exception as e:
48
  LLAVA_AVAILABLE = False
49
  print(f"Warning: LLaVA modules not available: {e}")
50
 
51
+ # --- HF Hub (opsiyonel logging) ---
 
 
 
 
 
 
 
 
52
  try:
53
  from huggingface_hub import HfApi, login
54
  HF_HUB_AVAILABLE = True
55
  except Exception:
56
  HF_HUB_AVAILABLE = False
 
57
 
58
+ # ------------- HF Hub init (opsiyonel) -------------
59
+ api = None
60
+ repo_name = ""
61
  if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
62
  try:
63
  login(token=os.environ["HF_TOKEN"], write_permission=True)
64
  api = HfApi()
65
  repo_name = os.environ.get("LOG_REPO", "")
66
  except Exception as e:
67
+ print(f"[HF Hub] init failed: {e}")
68
  api = None
69
  repo_name = ""
 
 
 
70
 
71
+ # ------------- Klasörler -------------
72
  LOGDIR = "./logs"
73
  VOTEDIR = "./votes"
74
  os.makedirs(LOGDIR, exist_ok=True)
75
  os.makedirs(VOTEDIR, exist_ok=True)
76
 
77
+ # ------------- Global durum -------------
78
  tokenizer = None
79
  model = None
80
  image_processor = None
 
82
  args = None
83
  model_initialized = False
84
 
85
+ # ------------- Yardımcılar -------------
86
+
87
+ def _safe_upload(path: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  if api and repo_name and os.path.isfile(path):
89
  try:
90
  api.upload_file(
 
96
  except Exception as e:
97
  print(f"[upload] failed for {path}: {e}")
98
 
99
+ def _conv_log_path():
100
  t = datetime.datetime.now()
101
+ p = os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
102
+ os.makedirs(os.path.dirname(p), exist_ok=True)
103
+ return p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ def load_image_any(image_input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  """
107
+ Desteklenen formatlar:
108
+ - URL (http/https)
109
+ - Yerel dosya yolu
110
+ - base64 (opsiyonel data URL prefix ile)
111
  """
112
+ if isinstance(image_input, str):
113
+ s = image_input.strip()
114
+ if s.startswith(("http://", "https://")):
115
+ r = requests.get(s, timeout=(5, 15))
116
+ r.raise_for_status()
117
+ return Image.open(BytesIO(r.content)).convert("RGB")
118
+ if os.path.exists(s):
119
+ return Image.open(s).convert("RGB")
120
+ # base64
121
+ if s.startswith("data:image"):
122
+ s = s.split(",", 1)[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  try:
124
+ raw = base64.b64decode(s)
125
+ return Image.open(BytesIO(raw)).convert("RGB")
126
  except Exception as e:
127
+ raise ValueError(f"Invalid image string (not URL/path/base64): {e}")
128
+ elif isinstance(image_input, dict) and "image" in image_input:
129
+ return load_image_any(image_input["image"])
130
+ else:
131
+ raise ValueError("Unsupported image input format")
132
+
133
+ def _guess_conv_mode(model_path: str) -> str:
134
+ name = get_model_name_from_path(model_path).lower()
135
+ if "llama-2" in name:
136
+ return "llava_llama_2"
137
+ if "v1" in name or "pulse" in name:
138
+ return "llava_v1"
139
+ if "mpt" in name:
140
+ return "mpt"
141
+ if "qwen" in name:
142
+ return "qwen_1_5"
143
+ return "llava_v0"
144
+
145
+ def _wrap_image_token_if_needed(model_cfg) -> bool:
146
  try:
147
+ return bool(getattr(model_cfg, "mm_use_im_start_end", False))
148
  except Exception:
149
+ return False
150
+
151
+ # ------------- Çekirdek üretim -------------
152
 
153
+ def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
154
+ use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
155
  if use_wrap:
156
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
157
  else:
 
159
 
160
  chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
161
  chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
162
+ prompt = chatbot.conversation.get_prompt()
163
+
164
+ input_ids = tokenizer_image_token(
165
+ prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
166
+ ).unsqueeze(0).to(device)
167
+ return prompt, input_ids
168
 
169
+ def _stopping(chatbot, input_ids):
170
  conv = chatbot.conversation
171
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
172
  return KeywordsStoppingCriteria([stop_str], chatbot.tokenizer, input_ids)
173
 
174
+ def _safe_max_new_tokens(requested: int, input_len: int, ctx_limit: int) -> int:
175
+ """
176
+ Demo'da slider değeri doğrudan kullanılıyor; burada ek güvenlik:
177
+ toplam (input + new + rezerv) <= ctx_limit olacak şekilde kırp.
178
+ """
179
+ requested = max(1, min(int(requested), 8192))
180
+ reserve = 16
181
+ available = max(32, ctx_limit - input_len - reserve)
182
+ return max(1, min(requested, available))
183
+
184
+ def generate_response(
185
+ message_text: str,
186
+ image_input,
187
+ *,
188
+ max_new_tokens: int = 4096,
189
+ temperature: float = 0.05,
190
+ top_p: float = 1.0,
191
+ repetition_penalty: float = 1.0,
192
+ conv_mode_override: str | None = None,
193
+ det_seed: int | None = None,
194
+ ):
195
  if not LLAVA_AVAILABLE:
196
  return {"error": "LLaVA modules not available"}
197
 
198
+ if not message_text or image_input is None:
199
+ return {"error": "Both 'message' and 'image' are required"}
 
 
 
 
 
200
 
201
+ # Chatbot/konuşma hazırla (her çağrıda sıfırdan, demo gibi)
202
+ chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
203
  if conv_mode_override and conv_mode_override in conv_templates:
204
  chatbot.conversation = conv_templates[conv_mode_override].copy()
205
  else:
206
  chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
207
 
208
+ # Görüntüyü yükle
209
  try:
210
+ pil_img = load_image_any(image_input)
211
  except Exception as e:
212
+ return {"error": f"Failed to load image: {e}"}
213
 
214
  # Log için kaydet
215
+ img_hash = "NA"
216
+ img_path = None
217
  try:
218
+ buf = BytesIO()
219
+ pil_img.save(buf, format="JPEG")
220
+ img_bytes = buf.getvalue()
221
+ img_hash = hashlib.md5(img_bytes).hexdigest()
222
  t = datetime.datetime.now()
223
+ img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
224
+ os.makedirs(os.path.dirname(img_path), exist_ok=True)
225
+ if not os.path.isfile(img_path):
226
+ pil_img.save(img_path)
227
  except Exception as e:
228
+ print(f"[log] saving image failed: {e}")
 
 
 
 
 
 
229
 
230
+ # Görüntüyü tensöre çevir
231
+ device = next(chatbot.model.parameters()).device
232
+ dtype = next(chatbot.model.parameters()).dtype
233
  try:
234
+ processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
 
235
  if isinstance(processed, torch.Tensor):
236
+ if processed.ndim == 3: # (C,H,W)
237
+ image_tensor = processed.unsqueeze(0)
238
+ elif processed.ndim == 4: # (B,C,H,W)
239
+ image_tensor = processed
240
+ elif processed.ndim == 5: # (B,T,C,H,W) -> (B*T,C,H,W)
241
+ b,t,c,h,w = processed.shape
242
+ image_tensor = processed.reshape(b*t, c, h, w)
243
  else:
244
  return {"error": f"Unexpected image tensor shape: {tuple(processed.shape)}"}
245
+ elif isinstance(processed, (list, tuple)) and len(processed) > 0:
 
 
246
  first = processed[0]
247
+ image_tensor = first.unsqueeze(0) if isinstance(first, torch.Tensor) and first.ndim == 3 else first
 
 
248
  else:
249
+ return {"error": "Image processing returned empty"}
 
 
250
 
251
+ # Demo'da çoğunlukla half + to(device) kullanılıyor
252
+ image_tensor = image_tensor.to(device=device, dtype=getattr(torch, "float16", torch.float16))
253
  except Exception as e:
254
+ return {"error": f"Image processing failed: {e}"}
255
 
256
  # Prompt & tokenizasyon
257
+ prompt, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
258
+ stopping = _stopping(chatbot, input_ids)
 
 
 
 
 
259
 
260
+ # max_new_tokens'ı güvenle kırp (demo slider + bağlam tavanı)
261
+ ctx_limit = context_len or getattr(chatbot.model.config, "max_position_embeddings", 8192)
262
+ max_new_tokens = _safe_max_new_tokens(max_new_tokens, input_ids.shape[1], ctx_limit)
 
 
263
 
264
+ # Demo: sampling açık; istenirse deterministik sample için seed verilebilir
265
+ if det_seed is not None:
266
+ torch.manual_seed(det_seed)
267
+ if torch.cuda.is_available():
268
+ torch.cuda.manual_seed(det_seed)
269
+ torch.cuda.manual_seed_all(det_seed)
 
270
 
271
  try:
272
  with torch.no_grad():
273
  outputs = chatbot.model.generate(
274
  inputs=input_ids,
275
  images=image_tensor,
276
+ do_sample=True,
277
+ temperature=float(temperature),
278
+ top_p=float(top_p),
 
279
  repetition_penalty=float(repetition_penalty),
280
+ max_new_tokens=int(max_new_tokens),
281
  use_cache=False,
282
+ pad_token_id=chatbot.tokenizer.eos_token_id,
283
+ eos_token_id=chatbot.tokenizer.eos_token_id,
284
  length_penalty=1.0,
285
  early_stopping=False,
286
+ stopping_criteria=[stopping],
287
  )
 
288
  gen = outputs[0][input_ids.shape[1]:]
289
+ text = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
290
 
291
+ # Konuşmaya yerleştir (demo gibi)
292
+ chatbot.conversation.messages[-1][-1] = text
293
  except Exception as e:
294
+ return {"error": f"Generation failed: {e}"}
295
 
296
+ # Log yaz
297
  try:
298
+ row = {
299
+ "time": datetime.datetime.now().isoformat(),
300
+ "type": "chat",
301
+ "model": "PULSE-7B",
302
+ "state": [(message_text, text)],
303
+ "image_hash": img_hash,
304
+ "image_path": img_path or "",
305
+ }
306
+ with open(_conv_log_path(), "a") as f:
307
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
308
+ _safe_upload(_conv_log_path())
309
+ if img_path:
310
+ _safe_upload(img_path)
311
  except Exception as e:
312
+ print(f"[log] failed: {e}")
313
 
314
+ return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
 
 
 
 
315
 
316
+ # ------------- API Yüzeyi -------------
317
 
318
+ def query(payload: dict):
319
+ """HF Endpoint ana giriş noktası (demo uyumlu)"""
320
  global model_initialized, tokenizer, model, image_processor, context_len, args
321
 
 
322
  if not model_initialized:
323
+ if not initialize_model():
 
324
  return {"error": "Model initialization failed"}
325
  model_initialized = True
326
 
327
  try:
328
+ message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
329
+ image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
330
+
331
+ if not message.strip():
332
+ return {"error": "Missing 'message' text"}
333
+ if image is None:
334
+ return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
335
+
336
+ # Demo: slider benzeri parametreler
337
+ max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
338
+ temperature = float(payload.get("temperature", 0.05))
339
+ top_p = float(payload.get("top_p", 1.0))
340
+ repetition_penalty = float(payload.get("repetition_penalty", 1.0))
341
+ conv_mode_override = payload.get("conv_mode", None)
342
+
343
+ # (Opsiyonel) deterministik sample için seed (demo defaultu: None)
344
+ det_seed = payload.get("det_seed", None)
345
+ if det_seed is not None:
346
+ try:
347
+ det_seed = int(det_seed)
348
+ except Exception:
349
+ det_seed = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  return generate_response(
352
+ message_text=message,
353
+ image_input=image,
354
+ max_new_tokens=max_new_tokens,
355
+ temperature=temperature,
356
+ top_p=top_p,
357
  repetition_penalty=repetition_penalty,
358
+ conv_mode_override=conv_mode_override,
359
+ det_seed=det_seed,
360
  )
361
  except Exception as e:
362
+ return {"error": f"Query failed: {e}"}
363
 
364
  def health_check():
365
  return {
 
367
  "model_initialized": model_initialized,
368
  "cuda_available": torch.cuda.is_available(),
369
  "llava_available": LLAVA_AVAILABLE,
 
370
  "cv2_available": CV2_AVAILABLE,
 
371
  }
372
 
373
  def get_model_info():
374
  if not model_initialized:
375
+ return {"error": "Model not initialized"}
376
  return {
377
  "model_path": args.model_path if args else "Unknown",
378
+ "context_len": context_len,
379
+ "device": str(next(model.parameters()).device) if model else "Unknown",
 
380
  }
381
 
382
+ # ------------- Model init -------------
 
 
 
 
 
383
 
384
+ class _Args:
385
+ def __init__(self):
386
+ self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
387
+ self.model_base = None
388
+ self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
389
+ self.conv_mode = None
390
+ self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
391
+ self.num_frames = 16
392
+ self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
393
+ self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
394
+ self.debug = bool(int(os.getenv("DEBUG", "0")))
395
+
396
+ class InferenceDemo:
397
+ def __init__(self, args, model_path, tokenizer, model, image_processor, context_len):
398
+ if not LLAVA_AVAILABLE:
399
+ raise ImportError("LLaVA modules not available")
400
+ disable_torch_init()
401
+ self.tokenizer, self.model, self.image_processor, self.context_len = (
402
+ tokenizer, model, image_processor, context_len
403
+ )
404
+ conv_mode_auto = _guess_conv_mode(model_path)
405
+ if args.conv_mode and args.conv_mode != conv_mode_auto:
406
+ self.conv_mode = args.conv_mode
407
+ else:
408
+ self.conv_mode = conv_mode_auto
409
+ args.conv_mode = conv_mode_auto
410
+ self.conversation = conv_templates[self.conv_mode].copy()
411
+ self.num_frames = args.num_frames
412
 
413
+ class ChatSessionManager:
414
+ def __init__(self):
415
+ self.chatbot = None
416
+ self.args = None
417
+ self.model_path = None
418
+ def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
419
+ if self.chatbot is None:
420
+ self.args = args
421
+ self.model_path = model_path
422
+ self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
423
+ def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
424
+ self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
425
+ return self.chatbot
426
 
427
+ chat_manager = ChatSessionManager()
428
 
429
  def initialize_model():
430
  """Modeli yükle (lazy)"""
431
  global tokenizer, model, image_processor, context_len, args
 
432
  if not LLAVA_AVAILABLE:
433
+ print("LLaVA not available; cannot init.")
434
  return False
 
435
  try:
436
+ args = _Args()
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  model_name = get_model_name_from_path(args.model_path)
438
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
439
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
440
  )
441
+ # Cihaz
 
 
442
  try:
443
  _ = next(model.parameters()).device
444
  except Exception:
445
  if torch.cuda.is_available():
446
  model = model.to(torch.device("cuda"))
 
 
447
  model.eval()
448
+ # Chatbot init
449
+ chat_manager.init_if_needed(args, args.model_path, tokenizer, model, image_processor, context_len)
450
+ print("[init] model/tokenizer/image_processor loaded.")
451
  return True
 
452
  except Exception as e:
453
+ print(f"[init] failed: {e}")
454
  return False
455
 
456
+ # ------------- HF EndpointHandler -------------
457
 
458
  class EndpointHandler:
459
+ """Hugging Face Endpoint uyumlu sınıf"""
 
460
  def __init__(self, model_dir):
461
  self.model_dir = model_dir
462
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
 
463
  def __call__(self, payload):
464
  if "inputs" in payload:
465
  return query(payload["inputs"])
466
  return query(payload)
 
467
  def health_check(self):
468
  return health_check()
 
469
  def get_model_info(self):
470
  return get_model_info()
471
 
472
  if __name__ == "__main__":
473
+ print("Handler ready. Use `EndpointHandler` or `query` for HF Inference Endpoints.")
474
+