CanerDedeoglu committed on
Commit
dbf6dc8
·
verified ·
1 Parent(s): 56c38a0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +114 -158
handler.py CHANGED
@@ -1,12 +1,12 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler (demo-like streaming, stable & clean)
4
- - TextIteratorStreamer + skip_prompt=True → baş kesilmesi yok (Step 1 korunur)
5
- - do_sample=True (demo davranışı), temperature/top_p payload’dan
6
- - Anti-tekrar: no_repeat_ngram_size + repetition_penalty
7
- - Opsiyonel: custom_stop (örn. "END OF REPORT") çıktı sonunda trim
8
- - Deterministik mod: aynı görüntü+mesaj için aynı seed (deterministic=True)
9
- - Görsel tensörü 3D/4D/5D uyumlu; device/dtype eşleştirme
10
  """
11
 
12
  import os
@@ -16,13 +16,13 @@ import hashlib
16
  import datetime
17
  from io import BytesIO
18
  from threading import Thread
19
- from typing import Optional, List, Union
20
 
21
  import torch
22
  from PIL import Image
23
  import requests
24
 
25
- # ---------- LLaVA & Transformers ----------
26
  try:
27
  from llava.constants import (
28
  IMAGE_TOKEN_INDEX,
@@ -41,16 +41,16 @@ try:
41
  LLAVA_AVAILABLE = True
42
  except Exception as e:
43
  LLAVA_AVAILABLE = False
44
- print(f"[WARN] LLaVA modules not available: {e}")
45
 
46
  try:
47
- from transformers import TextIteratorStreamer
48
  TRANSFORMERS_AVAILABLE = True
49
  except Exception as e:
50
  TRANSFORMERS_AVAILABLE = False
51
  print(f"[WARN] transformers not available: {e}")
52
 
53
- # ---------- HF Hub (opsiyonel logging) ----------
54
  try:
55
  from huggingface_hub import HfApi, login
56
  HF_HUB_AVAILABLE = True
@@ -72,7 +72,7 @@ if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
72
  LOGDIR = "./logs"
73
  os.makedirs(LOGDIR, exist_ok=True)
74
 
75
- # ---------- Global Model State ----------
76
  tokenizer = None
77
  model = None
78
  image_processor = None
@@ -81,7 +81,7 @@ args = None
81
  model_initialized = False
82
 
83
 
84
- # ======================== Utilities ========================
85
 
86
  def _safe_upload(path: str):
87
  if api and repo_name and path and os.path.isfile(path):
@@ -97,22 +97,20 @@ def _safe_upload(path: str):
97
 
98
  def _conv_log_path() -> str:
99
  t = datetime.datetime.now()
100
- p = os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
101
- os.makedirs(os.path.dirname(p), exist_ok=True)
102
- return p
103
 
104
  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
105
  """
106
  Desteklenen:
107
  - URL (http/https)
108
- - Yerel dosya yolu
109
  - base64 (opsiyonel data URL prefix ile)
110
  - {"image": <base64|dataurl>}
111
  """
112
  if isinstance(image_input, str):
113
  s = image_input.strip()
114
  if s.startswith(("http://", "https://")):
115
- r = requests.get(s, timeout=(5, 15))
116
  r.raise_for_status()
117
  return Image.open(BytesIO(r.content)).convert("RGB")
118
  if os.path.exists(s):
@@ -142,8 +140,36 @@ def _wrap_image_token_if_needed(model_cfg) -> bool:
142
  except Exception:
143
  return False
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
146
- # Demo gibi: <image> token + text (IM_START/END gerekiyorsa sar)
147
  use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
148
  if use_wrap:
149
  # <im_start><image><im_end>\n + user text
@@ -160,50 +186,43 @@ def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
160
  ).unsqueeze(0).to(device)
161
  return prompt, input_ids
162
 
163
- def _stable_seed_from(image_hash: str, message_text: str) -> int:
164
- """Aynı resim+mesaj için aynı seed (deterministik örnekleme)"""
165
- h = hashlib.md5((image_hash + "||" + message_text).encode("utf-8")).digest()
166
- # 32-bit pozitif int
167
- return int.from_bytes(h[:4], "big", signed=False)
168
-
169
-
170
- # ======================== Core Generation ========================
171
-
172
  def generate_response(
173
  message_text: str,
174
  image_input,
175
  *,
176
- max_new_tokens: int = 1800,
177
- min_new_tokens: Optional[int] = 700,
178
- temperature: float = 0.20,
179
- top_p: float = 0.95,
180
- repetition_penalty: float = 1.20,
181
- no_repeat_ngram_size: Optional[int] = 6,
182
  conv_mode_override: Optional[str] = None,
183
- deterministic: bool = False, # True do_sample=False (tam deterministik)
184
- det_seed: Optional[int] = None, # verilirse sabit seed
185
- custom_stop: Optional[List[str]] = None, # ["END OF REPORT"] gibi
186
- no_stop: bool = False, # True → eos/stop yok (önerilmez)
187
  ):
188
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
189
  return {"error": "Required libraries not available (llava/transformers)"}
190
  if not message_text or image_input is None:
191
  return {"error": "Both 'message' and 'image' are required"}
192
 
193
- # Chat oturumu (her çağrıda taze template; demo benzeri)
 
 
 
 
 
 
194
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
195
  if conv_mode_override and conv_mode_override in conv_templates:
196
  chatbot.conversation = conv_templates[conv_mode_override].copy()
197
  else:
198
  chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
199
 
200
- # Görseli yükle
201
  try:
202
  pil_img = load_image_any(image_input)
203
  except Exception as e:
204
  return {"error": f"Failed to load image: {e}"}
205
 
206
- # Log için kaydet (hash + path)
207
  img_hash, img_path = "NA", None
208
  try:
209
  buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
@@ -214,117 +233,75 @@ def generate_response(
214
  if not os.path.isfile(img_path):
215
  pil_img.save(img_path)
216
  except Exception as e:
217
- print(f"[log] saving image failed: {e}")
218
 
219
- # Cihaza/dtype’a taşı
220
  device = next(chatbot.model.parameters()).device
221
- dtype = next(chatbot.model.parameters()).dtype
 
222
 
223
- # Görüntü ön-işleme → tensör (3D/4D/5D destek)
224
  try:
225
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
226
- if isinstance(processed, torch.Tensor):
227
- if processed.ndim == 3: image_tensor = processed.unsqueeze(0) # (1,C,H,W)
228
- elif processed.ndim == 4: image_tensor = processed # (B,C,H,W)
229
- elif processed.ndim == 5: # (B,T,C,H,W) → (B*T,C,H,W)
230
- b,t,c,h,w = processed.shape
231
- image_tensor = processed.reshape(b*t, c, h, w)
232
- else:
233
- return {"error": f"Unexpected image tensor shape: {tuple(processed.shape)}"}
234
- elif isinstance(processed, (list, tuple)) and len(processed) > 0:
235
- first = processed[0]
236
- image_tensor = first.unsqueeze(0) if isinstance(first, torch.Tensor) and first.ndim == 3 else first
237
  else:
238
  return {"error": "Image processing returned empty"}
 
 
 
239
  image_tensor = image_tensor.to(device=device, dtype=dtype)
240
  except Exception as e:
241
  return {"error": f"Image processing failed: {e}"}
242
 
243
- # Prompt & ids
244
  _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
245
 
246
- # Seed ayarı
 
 
 
 
247
  if det_seed is not None:
248
  try:
249
  s = int(det_seed)
 
 
 
 
250
  except Exception:
251
- s = None
252
- elif deterministic:
253
- s = _stable_seed_from(img_hash, message_text)
254
- else:
255
- # Deterministik örnekleme istiyorsan; aynı girdide aynı sonuç için stabil seed de kullanabiliriz
256
- s = _stable_seed_from(img_hash, message_text)
257
-
258
- if s is not None:
259
- torch.manual_seed(s)
260
- if torch.cuda.is_available():
261
- torch.cuda.manual_seed(s)
262
- torch.cuda.manual_seed_all(s)
263
-
264
- # Stopping / EOS
265
- eos_id = chatbot.tokenizer.eos_token_id
266
- pad_id = chatbot.tokenizer.pad_token_id if chatbot.tokenizer.pad_token_id is not None else (eos_id if eos_id is not None else 0)
267
- eos_for_gen = None if no_stop else eos_id
268
 
269
- # Streamer (demo gibi; manuel dilimleme yok → Step 1 korunur)
270
  streamer = TextIteratorStreamer(
271
  chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
272
  )
273
 
274
- # do_sample: demo gibi (True). deterministic=True ise greedy’ye geç
275
- do_sample = not deterministic
276
-
277
  gen_kwargs = dict(
278
  inputs=input_ids,
279
  images=image_tensor,
280
  streamer=streamer,
281
- do_sample=do_sample,
282
- temperature=float(temperature),
283
- top_p=float(top_p),
284
- repetition_penalty=float(repetition_penalty),
285
- max_new_tokens=int(max_new_tokens),
286
  use_cache=False,
287
- pad_token_id=pad_id,
288
- eos_token_id=eos_for_gen,
289
- length_penalty=1.0,
290
- early_stopping=False,
291
- # stopping_criteria vermiyoruz → LLaVA'daki KeywordsStoppingCriteria hatalarından kaçınmak için
292
  )
293
 
294
- if no_repeat_ngram_size:
295
- try:
296
- n = int(no_repeat_ngram_size)
297
- if n > 0:
298
- gen_kwargs["no_repeat_ngram_size"] = n
299
- except Exception:
300
- pass
301
-
302
- if min_new_tokens is not None:
303
- try:
304
- mn = int(min_new_tokens)
305
- if 1 <= mn <= int(max_new_tokens):
306
- gen_kwargs["min_new_tokens"] = mn
307
- except Exception:
308
- pass
309
-
310
- # Üretim (arka thread) + stream toplama
311
  try:
312
  t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
313
  t.start()
314
- chunks: List[str] = []
315
  for piece in streamer:
316
  chunks.append(piece)
317
  text = "".join(chunks)
318
- # custom_stop varsa çıktıdan itibaren kırp
319
- if custom_stop:
320
- if isinstance(custom_stop, str):
321
- custom_stop = [custom_stop]
322
- for tag in custom_stop:
323
- if isinstance(tag, str) and tag:
324
- idx = text.find(tag)
325
- if idx != -1:
326
- text = text[:idx].rstrip()
327
- break
328
  chatbot.conversation.messages[-1][-1] = text
329
  except Exception as e:
330
  return {"error": f"Generation failed: {e}"}
@@ -348,10 +325,10 @@ def generate_response(
348
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
349
 
350
 
351
- # ======================== Public API ========================
352
 
353
  def query(payload: dict):
354
- """HF Endpoint entry (demo-like streaming)"""
355
  global model_initialized, tokenizer, model, image_processor, context_len, args
356
  if not model_initialized:
357
  if not initialize_model():
@@ -364,47 +341,27 @@ def query(payload: dict):
364
  if not message.strip(): return {"error": "Missing 'message' text"}
365
  if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
366
 
367
- # Demo-like varsayılanlar
368
- max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 1800))))
369
- min_new_tokens = payload.get("min_new_tokens", 700)
370
- try:
371
- min_new_tokens = int(min_new_tokens) if min_new_tokens is not None else None
372
- except Exception:
373
- min_new_tokens = None
374
-
375
- temperature = float(payload.get("temperature", 0.20))
376
- top_p = float(payload.get("top_p", 0.95))
377
- repetition_penalty = float(payload.get("repetition_penalty", 1.20))
378
- no_repeat_ngram = payload.get("no_repeat_ngram_size", 6)
379
- try:
380
- no_repeat_ngram = int(no_repeat_ngram) if no_repeat_ngram is not None else None
381
- except Exception:
382
- no_repeat_ngram = None
383
 
384
- conv_mode_override = payload.get("conv_mode", None)
385
- deterministic = bool(payload.get("deterministic", False))
386
- det_seed = payload.get("det_seed", None)
387
  if det_seed is not None:
388
  try: det_seed = int(det_seed)
389
  except Exception: det_seed = None
390
 
391
- custom_stop = payload.get("custom_stop", None)
392
- no_stop = bool(payload.get("no_stop", False)) # genelde False kalsın
393
-
394
  return generate_response(
395
  message_text=message,
396
  image_input=image,
397
- max_new_tokens=max_new_tokens,
398
- min_new_tokens=min_new_tokens,
399
  temperature=temperature,
400
  top_p=top_p,
401
- repetition_penalty=repetition_penalty,
402
- no_repeat_ngram_size=no_repeat_ngram,
403
  conv_mode_override=conv_mode_override,
404
- deterministic=deterministic,
405
  det_seed=det_seed,
406
- custom_stop=custom_stop,
407
- no_stop=no_stop,
408
  )
409
  except Exception as e:
410
  return {"error": f"Query failed: {e}"}
@@ -428,7 +385,7 @@ def get_model_info():
428
  }
429
 
430
 
431
- # ======================== Init & Session ========================
432
 
433
  class _Args:
434
  def __init__(self):
@@ -436,23 +393,22 @@ class _Args:
436
  self.model_base = None
437
  self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
438
  self.conv_mode = None
439
- self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "1800"))
440
  self.num_frames = 16
441
  self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
442
- # 4bit/8bit hız için açık bırakılabilir; accelerate devicemap kullanıyorsanız .to(cuda) gerekmez
443
  self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
444
  self.debug = bool(int(os.getenv("DEBUG", "0")))
445
 
446
  class InferenceDemo:
447
  def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
448
  if not LLAVA_AVAILABLE:
449
- raise ImportError("LLaVA modules not available")
450
  disable_torch_init()
451
  self.tokenizer, self.model, self.image_processor, self.context_len = (
452
  tokenizer_, model_, image_processor_, context_len_
453
  )
454
- conv_mode_auto = _guess_conv_mode(model_path)
455
- self.conv_mode = args.conv_mode if args.conv_mode else conv_mode_auto
456
  args.conv_mode = self.conv_mode
457
  self.conversation = conv_templates[self.conv_mode].copy()
458
  self.num_frames = args.num_frames
@@ -484,7 +440,7 @@ def initialize_model():
484
  tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
485
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
486
  )
487
- # Device
488
  try:
489
  _ = next(model_.parameters()).device
490
  except Exception:
@@ -506,7 +462,7 @@ def initialize_model():
506
  return False
507
 
508
 
509
- # ======================== HF EndpointHandler ========================
510
 
511
  class EndpointHandler:
512
  """Hugging Face Endpoint uyumlu sınıf"""
@@ -523,4 +479,4 @@ class EndpointHandler:
523
  return get_model_info()
524
 
525
  if __name__ == "__main__":
526
- print("Handler ready. Use `EndpointHandler` or `query` for HF Inference Endpoints.")
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler Demo Parity Mode
4
+ - Demo app.py ile aynı üretim ayarları:
5
+ do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
6
+ - Stopping: konuşma ayırıcıda (conv.sep/sep2) güvenli token-eşleşmeli kriter
7
+ - Görsel tensörü: .half() ve model cihazında
8
+ - Streamer: TextIteratorStreamer (demo gibi), thread ile generate
9
+ - Seed/deterministic KAPALI (göndermezseniz); demo gibi stokastik
10
  """
11
 
12
  import os
 
16
  import datetime
17
  from io import BytesIO
18
  from threading import Thread
19
+ from typing import Optional, Union
20
 
21
  import torch
22
  from PIL import Image
23
  import requests
24
 
25
+ # ====== LLaVA & Transformers ======
26
  try:
27
  from llava.constants import (
28
  IMAGE_TOKEN_INDEX,
 
41
  LLAVA_AVAILABLE = True
42
  except Exception as e:
43
  LLAVA_AVAILABLE = False
44
+ print(f"[WARN] LLaVA not available: {e}")
45
 
46
  try:
47
+ from transformers import TextIteratorStreamer, StoppingCriteria
48
  TRANSFORMERS_AVAILABLE = True
49
  except Exception as e:
50
  TRANSFORMERS_AVAILABLE = False
51
  print(f"[WARN] transformers not available: {e}")
52
 
53
+ # ====== HF Hub logging (opsiyonel) ======
54
  try:
55
  from huggingface_hub import HfApi, login
56
  HF_HUB_AVAILABLE = True
 
72
  LOGDIR = "./logs"
73
  os.makedirs(LOGDIR, exist_ok=True)
74
 
75
+ # ====== Global State ======
76
  tokenizer = None
77
  model = None
78
  image_processor = None
 
81
  model_initialized = False
82
 
83
 
84
+ # ===================== Utilities =====================
85
 
86
  def _safe_upload(path: str):
87
  if api and repo_name and path and os.path.isfile(path):
 
97
 
98
def _conv_log_path() -> str:
    """Return the path of today's conversation log file inside ``LOGDIR``.

    The file name has the form ``YYYY-MM-DD-user_conv.json`` and is derived
    from the date part of ``datetime.datetime.now()``.
    """
    today = datetime.datetime.now().date()
    # date.isoformat() yields zero-padded YYYY-MM-DD.
    filename = today.isoformat() + "-user_conv.json"
    return os.path.join(LOGDIR, filename)
 
 
101
 
102
  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
103
  """
104
  Desteklenen:
105
  - URL (http/https)
106
+ - yerel dosya yolu
107
  - base64 (opsiyonel data URL prefix ile)
108
  - {"image": <base64|dataurl>}
109
  """
110
  if isinstance(image_input, str):
111
  s = image_input.strip()
112
  if s.startswith(("http://", "https://")):
113
+ r = requests.get(s, timeout=(5, 20))
114
  r.raise_for_status()
115
  return Image.open(BytesIO(r.content)).convert("RGB")
116
  if os.path.exists(s):
 
140
  except Exception:
141
  return False
142
 
143
+
144
+ # ====== Güvenli Stop Kriteri (demo eşleniği) ======
145
class SafeKeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once the sequence ends with a keyword's token ids.

    Token-level counterpart of LLaVA's ``KeywordsStoppingCriteria``: the
    keyword (a separator string) is tokenized once, and on every step the
    tail of the running sequence is compared against those ids with
    ``torch.equal``, which returns a plain ``bool``.

    Only the first row of ``input_ids`` is inspected, i.e. a batch size of 1
    is assumed.

    NOTE(review): the match is on exact token ids, so it only fires when the
    model emits the keyword with the same tokenization the tokenizer gives
    the keyword in isolation -- confirm this for the tokenizer in use.
    """

    def __init__(self, keyword: str, tokenizer):
        """Tokenize ``keyword`` (without special tokens) for later matching.

        Args:
            keyword: Stop string. ``None`` or ``""`` disables the criterion
                (it then never requests a stop) instead of failing inside
                the tokenizer.
            tokenizer: Tokenizer callable; it is invoked with
                ``add_special_tokens=False`` and ``return_tensors="pt"`` and
                must expose ``input_ids`` on its result.
        """
        super().__init__()
        self.tokenizer = tokenizer
        if keyword:
            encoded = tokenizer(keyword, add_special_tokens=False, return_tensors="pt")
            self.kw_ids = encoded.input_ids[0]  # shape: (n,)
        else:
            # No keyword to look for: keep an empty id tensor so __call__
            # can short-circuit.
            self.kw_ids = torch.empty(0, dtype=torch.long)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when row 0 of ``input_ids`` ends with the keyword ids.

        Args:
            input_ids: Token ids of shape (bsz, seq_len).
            scores: Unused; accepted for the ``StoppingCriteria`` call
                signature.

        Returns:
            ``True`` if the last ``n`` ids of the first row equal the
            keyword ids, ``False`` otherwise (including when there is no
            keyword, no input, or the sequence is shorter than the keyword).
        """
        if input_ids is None or input_ids.shape[0] == 0:
            return False
        n = self.kw_ids.shape[0]
        if n == 0:
            # Guard explicitly: seq[-0:] would select the whole sequence.
            return False
        seq = input_ids[0]  # assume bsz=1
        if seq.shape[0] < n:
            return False
        tail = seq[-n:]
        # Move the keyword ids to the generation device once and cache them,
        # rather than transferring on every decoding step.
        if self.kw_ids.device != tail.device:
            self.kw_ids = self.kw_ids.to(tail.device)
        return torch.equal(tail, self.kw_ids)
167
+
168
+
169
+ # ===================== Core Generation =====================
170
+
171
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
172
+ # demo gibi: <image> + text (IM_START/END gerekiyorsa sar)
173
  use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
174
  if use_wrap:
175
  # <im_start><image><im_end>\n + user text
 
186
  ).unsqueeze(0).to(device)
187
  return prompt, input_ids
188
 
 
 
 
 
 
 
 
 
 
189
  def generate_response(
190
  message_text: str,
191
  image_input,
192
  *,
193
+ temperature: Optional[float] = None,
194
+ top_p: Optional[float] = None,
195
+ max_new_tokens: Optional[int] = None,
 
 
 
196
  conv_mode_override: Optional[str] = None,
197
+ repetition_penalty: Optional[float] = None, # demo'da yok; verilirse 1.0 yaparız
198
+ # NOT: no_repeat_ngram_size / min_new_tokens / custom_stop KULLANMIYORUZ → demo-parite
199
+ det_seed: Optional[int] = None, # seed gönderilmezse stokastik (demo gibi)
 
200
  ):
201
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
202
  return {"error": "Required libraries not available (llava/transformers)"}
203
  if not message_text or image_input is None:
204
  return {"error": "Both 'message' and 'image' are required"}
205
 
206
+ # Varsayılanlar demo
207
+ if temperature is None: temperature = 0.05
208
+ if top_p is None: top_p = 1.0
209
+ if max_new_tokens is None: max_new_tokens = 4096
210
+ if repetition_penalty is None: repetition_penalty = 1.0 # etkisiz
211
+
212
+ # Chat session: her çağrıda taze template
213
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
214
  if conv_mode_override and conv_mode_override in conv_templates:
215
  chatbot.conversation = conv_templates[conv_mode_override].copy()
216
  else:
217
  chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
218
 
219
+ # Görüntü yükle
220
  try:
221
  pil_img = load_image_any(image_input)
222
  except Exception as e:
223
  return {"error": f"Failed to load image: {e}"}
224
 
225
+ # Log için hash+path
226
  img_hash, img_path = "NA", None
227
  try:
228
  buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
 
233
  if not os.path.isfile(img_path):
234
  pil_img.save(img_path)
235
  except Exception as e:
236
+ print(f"[log] save image failed: {e}")
237
 
238
+ # Cihaz/dtype
239
  device = next(chatbot.model.parameters()).device
240
+ # demo half: .half() kullanacağız
241
+ dtype = torch.float16
242
 
243
+ # Görüntü ön-işleme → tensör
244
  try:
245
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
246
+ # LLaVA genelde list döndürür
247
+ if isinstance(processed, (list, tuple)) and len(processed) > 0:
248
+ image_tensor = processed[0]
249
+ elif isinstance(processed, torch.Tensor):
250
+ image_tensor = processed[0] if processed.ndim == 4 else processed # güvenlik
 
 
 
 
 
 
251
  else:
252
  return {"error": "Image processing returned empty"}
253
+ if image_tensor.ndim == 3:
254
+ image_tensor = image_tensor.unsqueeze(0) # (1,C,H,W)
255
+ # demo: half + device
256
  image_tensor = image_tensor.to(device=device, dtype=dtype)
257
  except Exception as e:
258
  return {"error": f"Image processing failed: {e}"}
259
 
260
+ # Prompt & input ids
261
  _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
262
 
263
+ # Stop string from conv
264
+ stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
265
+ stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
266
+
267
+ # Seed (gönderilmediyse stokastik → demo gibi)
268
  if det_seed is not None:
269
  try:
270
  s = int(det_seed)
271
+ torch.manual_seed(s)
272
+ if torch.cuda.is_available():
273
+ torch.cuda.manual_seed(s)
274
+ torch.cuda.manual_seed_all(s)
275
  except Exception:
276
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
+ # Streamer (demo gibi)
279
  streamer = TextIteratorStreamer(
280
  chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
281
  )
282
 
283
+ # Generate kwargs demo ayarları
 
 
284
  gen_kwargs = dict(
285
  inputs=input_ids,
286
  images=image_tensor,
287
  streamer=streamer,
288
+ do_sample=True, # DEMO
289
+ temperature=float(temperature), # DEMO default 0.05
290
+ top_p=float(top_p), # DEMO default 1.0
291
+ max_new_tokens=int(max_new_tokens), # DEMO slider
292
+ repetition_penalty=float(repetition_penalty), # default 1.0 → etkisiz
293
  use_cache=False,
294
+ stopping_criteria=[stopping], # DEMO-benzeri durdurma
 
 
 
 
295
  )
296
 
297
+ # Üretim (arka thread) + akışı topla
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  try:
299
  t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
300
  t.start()
301
+ chunks = []
302
  for piece in streamer:
303
  chunks.append(piece)
304
  text = "".join(chunks)
 
 
 
 
 
 
 
 
 
 
305
  chatbot.conversation.messages[-1][-1] = text
306
  except Exception as e:
307
  return {"error": f"Generation failed: {e}"}
 
325
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
326
 
327
 
328
+ # ===================== Public API =====================
329
 
330
  def query(payload: dict):
331
+ """HF Endpoint entry (demo parity)."""
332
  global model_initialized, tokenizer, model, image_processor, context_len, args
333
  if not model_initialized:
334
  if not initialize_model():
 
341
  if not message.strip(): return {"error": "Missing 'message' text"}
342
  if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
343
 
344
+ # Demo varsayılanları — payload override edebilir
345
+ temperature = float(payload.get("temperature", 0.05))
346
+ top_p = float(payload.get("top_p", 1.0))
347
+ max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
348
+ repetition_penalty = float(payload.get("repetition_penalty", 1.0)) # etkisiz default
 
 
 
 
 
 
 
 
 
 
 
349
 
350
+ conv_mode_override = payload.get("conv_mode", None)
351
+ det_seed = payload.get("det_seed", None)
 
352
  if det_seed is not None:
353
  try: det_seed = int(det_seed)
354
  except Exception: det_seed = None
355
 
 
 
 
356
  return generate_response(
357
  message_text=message,
358
  image_input=image,
 
 
359
  temperature=temperature,
360
  top_p=top_p,
361
+ max_new_tokens=max_new_tokens,
 
362
  conv_mode_override=conv_mode_override,
363
+ repetition_penalty=repetition_penalty,
364
  det_seed=det_seed,
 
 
365
  )
366
  except Exception as e:
367
  return {"error": f"Query failed: {e}"}
 
385
  }
386
 
387
 
388
+ # ===================== Init & Session =====================
389
 
390
  class _Args:
391
  def __init__(self):
 
393
  self.model_base = None
394
  self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
395
  self.conv_mode = None
396
+ self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
397
  self.num_frames = 16
398
  self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
 
399
  self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
400
  self.debug = bool(int(os.getenv("DEBUG", "0")))
401
 
402
  class InferenceDemo:
403
  def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
404
  if not LLAVA_AVAILABLE:
405
+ raise ImportError("LLaVA not available")
406
  disable_torch_init()
407
  self.tokenizer, self.model, self.image_processor, self.context_len = (
408
  tokenizer_, model_, image_processor_, context_len_
409
  )
410
+ auto = _guess_conv_mode(model_path)
411
+ self.conv_mode = args.conv_mode if args.conv_mode else auto
412
  args.conv_mode = self.conv_mode
413
  self.conversation = conv_templates[self.conv_mode].copy()
414
  self.num_frames = args.num_frames
 
440
  tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
441
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
442
  )
443
+ # demo: model'ı genelde cuda’da çalıştırır
444
  try:
445
  _ = next(model_.parameters()).device
446
  except Exception:
 
462
  return False
463
 
464
 
465
+ # ===================== HF EndpointHandler =====================
466
 
467
  class EndpointHandler:
468
  """Hugging Face Endpoint uyumlu sınıf"""
 
479
  return get_model_info()
480
 
481
  if __name__ == "__main__":
482
+ print("Handler ready (Demo Parity Mode). Use `EndpointHandler` or `query`.")