CanerDedeoglu committed on
Commit
05ae2ff
·
verified ·
1 Parent(s): 686dbcb

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +120 -81
handler.py CHANGED
@@ -1,12 +1,10 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler - Demo-like sampling + no_stop (hard) + min_new_tokens auto
4
- - do_sample=True, temperature/top_p payload'dan
5
- - max_new_tokens: payload değeri (kırpma yok)
6
- - no_stop=True: stopping_criteria KAPALI + eos_token_id=None
7
- - no_stop=True ve min_new_tokens boşsa: otomatik min_new_tokens (uzun yanıt garantisi)
8
- - Tek görsel; IM_START/END otomatik; 3D/4D/5D tensör uyumlu
9
- - Post-format yok (demo davranışı)
10
  """
11
 
12
  import os
@@ -15,20 +13,14 @@ import base64
15
  import hashlib
16
  import datetime
17
  from io import BytesIO
 
 
18
 
19
  import torch
20
  from PIL import Image
21
  import requests
22
 
23
- # --- Opsiyonel ---
24
- try:
25
- import cv2
26
- CV2_AVAILABLE = True
27
- except Exception:
28
- CV2_AVAILABLE = False
29
- print("Warning: OpenCV (cv2) not available; video is disabled.")
30
-
31
- # --- LLaVA ---
32
  try:
33
  from llava.constants import (
34
  IMAGE_TOKEN_INDEX,
@@ -48,7 +40,14 @@ try:
48
  LLAVA_AVAILABLE = True
49
  except Exception as e:
50
  LLAVA_AVAILABLE = False
51
- print(f"Warning: LLaVA modules not available: {e}")
 
 
 
 
 
 
 
52
 
53
  # --- HF Hub (opsiyonel logging) ---
54
  try:
@@ -70,10 +69,9 @@ if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
70
  repo_name = ""
71
 
72
  LOGDIR = "./logs"
73
- VOTEDIR = "./votes"
74
  os.makedirs(LOGDIR, exist_ok=True)
75
- os.makedirs(VOTEDIR, exist_ok=True)
76
 
 
77
  tokenizer = None
78
  model = None
79
  image_processor = None
@@ -81,8 +79,10 @@ context_len = None
81
  args = None
82
  model_initialized = False
83
 
 
 
84
  def _safe_upload(path: str):
85
- if api and repo_name and os.path.isfile(path):
86
  try:
87
  api.upload_file(
88
  path_or_fileobj=path,
@@ -100,6 +100,12 @@ def _conv_log_path():
100
  return p
101
 
102
  def load_image_any(image_input):
 
 
 
 
 
 
103
  if isinstance(image_input, str):
104
  s = image_input.strip()
105
  if s.startswith(("http://", "https://")):
@@ -108,6 +114,7 @@ def load_image_any(image_input):
108
  return Image.open(BytesIO(r.content)).convert("RGB")
109
  if os.path.exists(s):
110
  return Image.open(s).convert("RGB")
 
111
  if s.startswith("data:image"):
112
  s = s.split(",", 1)[1]
113
  raw = base64.b64decode(s)
@@ -134,72 +141,87 @@ def _wrap_image_token_if_needed(model_cfg) -> bool:
134
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
135
  use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
136
  if use_wrap:
 
137
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
138
  else:
139
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
 
140
  chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
141
  chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
142
  prompt = chatbot.conversation.get_prompt()
 
143
  input_ids = tokenizer_image_token(
144
  prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
145
  ).unsqueeze(0).to(device)
146
  return prompt, input_ids
147
 
148
- def _stopping(chatbot, input_ids):
149
  conv = chatbot.conversation
150
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
151
- return KeywordsStoppingCriteria([stop_str], chatbot.tokenizer, input_ids)
 
 
 
 
 
152
 
153
  def generate_response(
154
  message_text: str,
155
  image_input,
156
  *,
157
- max_new_tokens: int = 4096,
158
- temperature: float = 0.05,
159
- top_p: float = 1.0,
160
- repetition_penalty: float = 1.0,
161
- conv_mode_override: str | None = None,
162
- det_seed: int | None = None,
 
 
163
  no_stop: bool = False,
164
- min_new_tokens: int | None = None, # otomatik atanabilir
165
  ):
166
- if not LLAVA_AVAILABLE:
167
- return {"error": "LLaVA modules not available"}
168
  if not message_text or image_input is None:
169
  return {"error": "Both 'message' and 'image' are required"}
170
 
 
171
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
172
  if conv_mode_override and conv_mode_override in conv_templates:
173
  chatbot.conversation = conv_templates[conv_mode_override].copy()
174
  else:
175
  chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
176
 
 
177
  try:
178
  pil_img = load_image_any(image_input)
179
  except Exception as e:
180
  return {"error": f"Failed to load image: {e}"}
181
 
182
- # log
183
  img_hash, img_path = "NA", None
184
  try:
185
- buf = BytesIO(); pil_img.save(buf, format="JPEG"); img_bytes = buf.getvalue()
186
- img_hash = hashlib.md5(img_bytes).hexdigest()
187
  t = datetime.datetime.now()
188
  img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
189
  os.makedirs(os.path.dirname(img_path), exist_ok=True)
190
- if not os.path.isfile(img_path): pil_img.save(img_path)
 
191
  except Exception as e:
192
  print(f"[log] saving image failed: {e}")
193
 
194
- # görüntü tensörü
195
  device = next(chatbot.model.parameters()).device
196
  dtype = next(chatbot.model.parameters()).dtype
 
 
197
  try:
198
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
199
  if isinstance(processed, torch.Tensor):
200
  if processed.ndim == 3: image_tensor = processed.unsqueeze(0)
201
  elif processed.ndim == 4: image_tensor = processed
202
- elif processed.ndim == 5:
203
  b,t,c,h,w = processed.shape
204
  image_tensor = processed.reshape(b*t, c, h, w)
205
  else:
@@ -213,11 +235,16 @@ def generate_response(
213
  except Exception as e:
214
  return {"error": f"Image processing failed: {e}"}
215
 
216
- # prompt & ids
217
  _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
218
- stopping = None if no_stop else _stopping(chatbot, input_ids)
219
 
220
- # deterministik sample (opsiyonel)
 
 
 
 
 
 
221
  if det_seed is not None:
222
  try:
223
  det_seed = int(det_seed)
@@ -228,18 +255,15 @@ def generate_response(
228
  except Exception:
229
  pass
230
 
231
- # EOS/PAD
232
- eos_id = tokenizer.eos_token_id
233
- pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else (eos_id if eos_id is not None else 0)
234
- # no_stop=True ise eos'a göre durmayı tamamen kapat
235
- if no_stop:
236
- eos_for_gen = None
237
- else:
238
- eos_for_gen = eos_id
239
 
240
  gen_kwargs = dict(
241
  inputs=input_ids,
242
  images=image_tensor,
 
243
  do_sample=True,
244
  temperature=float(temperature),
245
  top_p=float(top_p),
@@ -250,37 +274,38 @@ def generate_response(
250
  eos_token_id=eos_for_gen,
251
  length_penalty=1.0,
252
  early_stopping=False,
253
- stopping_criteria=None if no_stop else [stopping],
254
  )
255
 
256
- # min_new_tokens otomatik (no_stop=True ve kullanıcı vermediyse)
257
- if no_stop and (min_new_tokens is None):
258
  try:
259
- req = int(max_new_tokens)
260
- auto_min = max(300, min(req - 64, 1024)) # 300–1024 bandında güvenli
261
- if auto_min > 0:
262
- gen_kwargs["min_new_tokens"] = auto_min
263
  except Exception:
264
  pass
265
- elif min_new_tokens is not None:
 
266
  try:
267
  mn = int(min_new_tokens)
268
- if mn > 0 and mn <= int(max_new_tokens):
269
  gen_kwargs["min_new_tokens"] = mn
270
  except Exception:
271
  pass
272
 
273
- # generate
274
  try:
275
- with torch.no_grad():
276
- outputs = chatbot.model.generate(**gen_kwargs)
277
- gen = outputs[0][input_ids.shape[1]:]
278
- text = tokenizer.decode(gen, skip_special_tokens=True)
 
 
279
  chatbot.conversation.messages[-1][-1] = text
280
  except Exception as e:
281
  return {"error": f"Generation failed: {e}"}
282
 
283
- # log
284
  try:
285
  row = {
286
  "time": datetime.datetime.now().isoformat(),
@@ -292,16 +317,16 @@ def generate_response(
292
  }
293
  with open(_conv_log_path(), "a", encoding="utf-8") as f:
294
  f.write(json.dumps(row, ensure_ascii=False) + "\n")
295
- _safe_upload(_conv_log_path()); _safe_upload(img_path or "")
296
  except Exception as e:
297
  print(f"[log] failed: {e}")
298
 
299
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
300
 
301
- # -------- API --------
302
 
303
  def query(payload: dict):
304
- """HF Endpoint entry (demo-like)"""
305
  global model_initialized, tokenizer, model, image_processor, context_len, args
306
  if not model_initialized:
307
  if not initialize_model():
@@ -314,32 +339,43 @@ def query(payload: dict):
314
  if not message.strip(): return {"error": "Missing 'message' text"}
315
  if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
316
 
317
- max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
318
- temperature = float(payload.get("temperature", 0.05))
319
- top_p = float(payload.get("top_p", 1.0))
320
- repetition_penalty = float(payload.get("repetition_penalty", 1.0))
 
 
 
 
 
 
 
 
 
 
 
 
321
  conv_mode_override = payload.get("conv_mode", None)
322
  det_seed = payload.get("det_seed", None)
323
  if det_seed is not None:
324
  try: det_seed = int(det_seed)
325
  except Exception: det_seed = None
326
  no_stop = bool(payload.get("no_stop", False))
327
- mnt = payload.get("min_new_tokens", None)
328
- if mnt is not None:
329
- try: mnt = int(mnt)
330
- except Exception: mnt = None
331
 
332
  return generate_response(
333
  message_text=message,
334
  image_input=image,
335
  max_new_tokens=max_new_tokens,
 
336
  temperature=temperature,
337
  top_p=top_p,
338
  repetition_penalty=repetition_penalty,
 
339
  conv_mode_override=conv_mode_override,
340
  det_seed=det_seed,
341
  no_stop=no_stop,
342
- min_new_tokens=mnt,
343
  )
344
  except Exception as e:
345
  return {"error": f"Query failed: {e}"}
@@ -350,7 +386,7 @@ def health_check():
350
  "model_initialized": model_initialized,
351
  "cuda_available": torch.cuda.is_available(),
352
  "llava_available": LLAVA_AVAILABLE,
353
- "cv2_available": CV2_AVAILABLE,
354
  }
355
 
356
  def get_model_info():
@@ -362,7 +398,7 @@ def get_model_info():
362
  "device": str(next(model.parameters()).device) if model else "Unknown",
363
  }
364
 
365
- # -------- init --------
366
 
367
  class _Args:
368
  def __init__(self):
@@ -370,19 +406,19 @@ class _Args:
370
  self.model_base = None
371
  self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
372
  self.conv_mode = None
373
- self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
374
  self.num_frames = 16
375
  self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
376
  self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
377
  self.debug = bool(int(os.getenv("DEBUG", "0")))
378
 
379
  class InferenceDemo:
380
- def __init__(self, args, model_path, tokenizer, model, image_processor, context_len):
381
  if not LLAVA_AVAILABLE:
382
  raise ImportError("LLaVA modules not available")
383
  disable_torch_init()
384
  self.tokenizer, self.model, self.image_processor, self.context_len = (
385
- tokenizer, model, image_processor, context_len
386
  )
387
  conv_mode_auto = _guess_conv_mode(model_path)
388
  self.conv_mode = args.conv_mode if args.conv_mode else conv_mode_auto
@@ -409,7 +445,7 @@ chat_manager = ChatSessionManager()
409
  def initialize_model():
410
  global tokenizer, model, image_processor, context_len, args
411
  if not LLAVA_AVAILABLE:
412
- print("LLaVA not available; cannot init.")
413
  return False
414
  try:
415
  args = _Args()
@@ -430,7 +466,10 @@ def initialize_model():
430
  print(f"[init] failed: {e}")
431
  return False
432
 
 
 
433
  class EndpointHandler:
 
434
  def __init__(self, model_dir):
435
  self.model_dir = model_dir
436
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler (demo-like streaming)
4
+ - TextIteratorStreamer + skip_prompt=True (dilimleme yok; Step 1 korunur)
5
+ - do_sample=True (demo davranışı), temperature/top_p payload'dan
6
+ - Opsiyonel: no_stop, custom_stop, no_repeat_ngram_size, min_new_tokens
7
+ - IM_START/END otomatik; 3D/4D/5D görüntü tensörü uyumlu; device/dtype eşleştirme
 
 
8
  """
9
 
10
  import os
 
13
  import hashlib
14
  import datetime
15
  from io import BytesIO
16
+ from threading import Thread
17
+ from typing import Optional, List
18
 
19
  import torch
20
  from PIL import Image
21
  import requests
22
 
23
+ # --- LLaVA / Transformers ---
 
 
 
 
 
 
 
 
24
  try:
25
  from llava.constants import (
26
  IMAGE_TOKEN_INDEX,
 
40
  LLAVA_AVAILABLE = True
41
  except Exception as e:
42
  LLAVA_AVAILABLE = False
43
+ print(f"[WARN] LLaVA modules not available: {e}")
44
+
45
+ try:
46
+ from transformers import TextIteratorStreamer
47
+ TRANSFORMERS_AVAILABLE = True
48
+ except Exception as e:
49
+ TRANSFORMERS_AVAILABLE = False
50
+ print(f"[WARN] transformers not available: {e}")
51
 
52
  # --- HF Hub (opsiyonel logging) ---
53
  try:
 
69
  repo_name = ""
70
 
71
  LOGDIR = "./logs"
 
72
  os.makedirs(LOGDIR, exist_ok=True)
 
73
 
74
+ # --- Global Model State ---
75
  tokenizer = None
76
  model = None
77
  image_processor = None
 
79
  args = None
80
  model_initialized = False
81
 
82
+ # ----------------- Utilities -----------------
83
+
84
  def _safe_upload(path: str):
85
+ if api and repo_name and path and os.path.isfile(path):
86
  try:
87
  api.upload_file(
88
  path_or_fileobj=path,
 
100
  return p
101
 
102
  def load_image_any(image_input):
103
+ """
104
+ Desteklenen:
105
+ - URL (http/https)
106
+ - Yerel dosya yolu
107
+ - base64 (opsiyonel data URL prefix ile)
108
+ """
109
  if isinstance(image_input, str):
110
  s = image_input.strip()
111
  if s.startswith(("http://", "https://")):
 
114
  return Image.open(BytesIO(r.content)).convert("RGB")
115
  if os.path.exists(s):
116
  return Image.open(s).convert("RGB")
117
+ # base64
118
  if s.startswith("data:image"):
119
  s = s.split(",", 1)[1]
120
  raw = base64.b64decode(s)
 
141
def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
    """Add the user turn to the conversation and tokenize the resulting prompt.

    The image placeholder is wrapped in the start/end image tokens when
    ``_wrap_image_token_if_needed`` says so for the model config; otherwise
    the bare image token is used. An empty assistant turn is appended so the
    prompt ends where the model should start answering.

    Args:
        chatbot: Session object exposing ``model``, ``tokenizer`` and
            ``conversation``.
        user_text: The user's message, placed on the line after the image token.
        device: Device the token ids are moved to.

    Returns:
        ``(prompt, input_ids)``: the prompt string and the token ids with a
        leading batch dimension added via ``unsqueeze(0)``.
    """
    # <im_start><image><im_end> when wrapping is requested, plain <image> otherwise.
    if _wrap_image_token_if_needed(chatbot.model.config):
        image_part = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}"
    else:
        image_part = f"{DEFAULT_IMAGE_TOKEN}"
    user_turn = f"{image_part}\n{user_text}"

    conversation = chatbot.conversation
    conversation.append_message(conversation.roles[0], user_turn)
    conversation.append_message(conversation.roles[1], None)
    prompt = conversation.get_prompt()

    token_ids = tokenizer_image_token(
        prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    )
    input_ids = token_ids.unsqueeze(0).to(device)
    return prompt, input_ids
157
 
158
def _stopping_keywords(chatbot, input_ids, extra: Optional[List[str]] = None):
    """Build the keyword-based stopping criteria for one generation call.

    The conversation separator is always the first stop keyword (``sep2`` when
    the template uses ``SeparatorStyle.TWO``, ``sep`` otherwise). Keywords
    supplied by the caller in ``extra`` are appended after it.

    Args:
        chatbot: Session object exposing ``conversation`` and ``tokenizer``.
        input_ids: Prompt token ids, forwarded to ``KeywordsStoppingCriteria``.
        extra: Optional additional stop keywords. A single string is accepted
            and treated as one keyword; entries that are not strings or are
            blank are dropped.

    Returns:
        A ``KeywordsStoppingCriteria`` built from the collected keywords.
    """
    conv = chatbot.conversation
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keys = [stop_str]
    if isinstance(extra, str):
        # A bare string would otherwise be iterated character by character,
        # turning every single character into its own stop keyword.
        extra = [extra]
    if extra:
        keys.extend(k for k in extra if isinstance(k, str) and k.strip())
    return KeywordsStoppingCriteria(keys, chatbot.tokenizer, input_ids)
165
+
166
+ # ----------------- Core Generation -----------------
167
 
168
  def generate_response(
169
  message_text: str,
170
  image_input,
171
  *,
172
+ max_new_tokens: int = 1800,
173
+ min_new_tokens: Optional[int] = None,
174
+ temperature: float = 0.20,
175
+ top_p: float = 0.95,
176
+ repetition_penalty: float = 1.20,
177
+ no_repeat_ngram_size: Optional[int] = 6,
178
+ conv_mode_override: Optional[str] = None,
179
+ det_seed: Optional[int] = None,
180
  no_stop: bool = False,
181
+ custom_stop: Optional[List[str]] = None,
182
  ):
183
+ if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
184
+ return {"error": "Required libraries not available (llava/transformers)"}
185
  if not message_text or image_input is None:
186
  return {"error": "Both 'message' and 'image' are required"}
187
 
188
+ # Chat session (fresh conv each call, demo-like)
189
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
190
  if conv_mode_override and conv_mode_override in conv_templates:
191
  chatbot.conversation = conv_templates[conv_mode_override].copy()
192
  else:
193
  chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
194
 
195
+ # Load image
196
  try:
197
  pil_img = load_image_any(image_input)
198
  except Exception as e:
199
  return {"error": f"Failed to load image: {e}"}
200
 
201
+ # Save image to logs (optional)
202
  img_hash, img_path = "NA", None
203
  try:
204
+ buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
205
+ img_hash = hashlib.md5(raw).hexdigest()
206
  t = datetime.datetime.now()
207
  img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
208
  os.makedirs(os.path.dirname(img_path), exist_ok=True)
209
+ if not os.path.isfile(img_path):
210
+ pil_img.save(img_path)
211
  except Exception as e:
212
  print(f"[log] saving image failed: {e}")
213
 
214
+ # To device/dtype
215
  device = next(chatbot.model.parameters()).device
216
  dtype = next(chatbot.model.parameters()).dtype
217
+
218
+ # Preprocess image -> tensor (support 3D/4D/5D)
219
  try:
220
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
221
  if isinstance(processed, torch.Tensor):
222
  if processed.ndim == 3: image_tensor = processed.unsqueeze(0)
223
  elif processed.ndim == 4: image_tensor = processed
224
+ elif processed.ndim == 5: # (B,T,C,H,W) -> (B*T,C,H,W)
225
  b,t,c,h,w = processed.shape
226
  image_tensor = processed.reshape(b*t, c, h, w)
227
  else:
 
235
  except Exception as e:
236
  return {"error": f"Image processing failed: {e}"}
237
 
238
+ # Prompt & ids
239
  _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
 
240
 
241
+ # Stopping criteria
242
+ stopping = None if no_stop else _stopping_keywords(chatbot, input_ids, custom_stop)
243
+ eos_id = chatbot.tokenizer.eos_token_id
244
+ pad_id = chatbot.tokenizer.pad_token_id if chatbot.tokenizer.pad_token_id is not None else (eos_id if eos_id is not None else 0)
245
+ eos_for_gen = None if no_stop else eos_id
246
+
247
+ # Deterministic sampling (optional)
248
  if det_seed is not None:
249
  try:
250
  det_seed = int(det_seed)
 
255
  except Exception:
256
  pass
257
 
258
+ # Streamer (demo-like, avoids manual slicing)
259
+ streamer = TextIteratorStreamer(
260
+ chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
261
+ )
 
 
 
 
262
 
263
  gen_kwargs = dict(
264
  inputs=input_ids,
265
  images=image_tensor,
266
+ streamer=streamer,
267
  do_sample=True,
268
  temperature=float(temperature),
269
  top_p=float(top_p),
 
274
  eos_token_id=eos_for_gen,
275
  length_penalty=1.0,
276
  early_stopping=False,
277
+ stopping_criteria=None if no_stop else ([stopping] if stopping else None),
278
  )
279
 
280
+ if no_repeat_ngram_size:
 
281
  try:
282
+ n = int(no_repeat_ngram_size)
283
+ if n > 0:
284
+ gen_kwargs["no_repeat_ngram_size"] = n
 
285
  except Exception:
286
  pass
287
+
288
+ if min_new_tokens is not None:
289
  try:
290
  mn = int(min_new_tokens)
291
+ if 1 <= mn <= int(max_new_tokens):
292
  gen_kwargs["min_new_tokens"] = mn
293
  except Exception:
294
  pass
295
 
296
+ # Generate in a background thread; collect streamed tokens
297
  try:
298
+ t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
299
+ t.start()
300
+ chunks = []
301
+ for piece in streamer:
302
+ chunks.append(piece)
303
+ text = "".join(chunks)
304
  chatbot.conversation.messages[-1][-1] = text
305
  except Exception as e:
306
  return {"error": f"Generation failed: {e}"}
307
 
308
+ # Log
309
  try:
310
  row = {
311
  "time": datetime.datetime.now().isoformat(),
 
317
  }
318
  with open(_conv_log_path(), "a", encoding="utf-8") as f:
319
  f.write(json.dumps(row, ensure_ascii=False) + "\n")
320
+ _safe_upload(_conv_log_path()); _safe_upload(img_path or "")
321
  except Exception as e:
322
  print(f"[log] failed: {e}")
323
 
324
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
325
 
326
+ # ----------------- Public API -----------------
327
 
328
  def query(payload: dict):
329
+ """HF Endpoint entry (demo-like streaming)"""
330
  global model_initialized, tokenizer, model, image_processor, context_len, args
331
  if not model_initialized:
332
  if not initialize_model():
 
339
  if not message.strip(): return {"error": "Missing 'message' text"}
340
  if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
341
 
342
+ # Demo-like knobs
343
+ max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 1800))))
344
+ min_new_tokens = payload.get("min_new_tokens", None)
345
+ if min_new_tokens is not None:
346
+ try: min_new_tokens = int(min_new_tokens)
347
+ except Exception: min_new_tokens = None
348
+
349
+ temperature = float(payload.get("temperature", 0.20))
350
+ top_p = float(payload.get("top_p", 0.95))
351
+ repetition_penalty = float(payload.get("repetition_penalty", 1.20))
352
+ no_repeat_ngram = payload.get("no_repeat_ngram_size", 6)
353
+ try:
354
+ no_repeat_ngram = int(no_repeat_ngram) if no_repeat_ngram is not None else None
355
+ except Exception:
356
+ no_repeat_ngram = None
357
+
358
  conv_mode_override = payload.get("conv_mode", None)
359
  det_seed = payload.get("det_seed", None)
360
  if det_seed is not None:
361
  try: det_seed = int(det_seed)
362
  except Exception: det_seed = None
363
  no_stop = bool(payload.get("no_stop", False))
364
+ custom_stop = payload.get("custom_stop", None)
 
 
 
365
 
366
  return generate_response(
367
  message_text=message,
368
  image_input=image,
369
  max_new_tokens=max_new_tokens,
370
+ min_new_tokens=min_new_tokens,
371
  temperature=temperature,
372
  top_p=top_p,
373
  repetition_penalty=repetition_penalty,
374
+ no_repeat_ngram_size=no_repeat_ngram,
375
  conv_mode_override=conv_mode_override,
376
  det_seed=det_seed,
377
  no_stop=no_stop,
378
+ custom_stop=custom_stop,
379
  )
380
  except Exception as e:
381
  return {"error": f"Query failed: {e}"}
 
386
  "model_initialized": model_initialized,
387
  "cuda_available": torch.cuda.is_available(),
388
  "llava_available": LLAVA_AVAILABLE,
389
+ "transformers_available": TRANSFORMERS_AVAILABLE,
390
  }
391
 
392
  def get_model_info():
 
398
  "device": str(next(model.parameters()).device) if model else "Unknown",
399
  }
400
 
401
+ # ----------------- Init & Session -----------------
402
 
403
  class _Args:
404
  def __init__(self):
 
406
  self.model_base = None
407
  self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
408
  self.conv_mode = None
409
+ self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "1800"))
410
  self.num_frames = 16
411
  self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
412
  self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
413
  self.debug = bool(int(os.getenv("DEBUG", "0")))
414
 
415
  class InferenceDemo:
416
+ def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
417
  if not LLAVA_AVAILABLE:
418
  raise ImportError("LLaVA modules not available")
419
  disable_torch_init()
420
  self.tokenizer, self.model, self.image_processor, self.context_len = (
421
+ tokenizer_, model_, image_processor_, context_len_
422
  )
423
  conv_mode_auto = _guess_conv_mode(model_path)
424
  self.conv_mode = args.conv_mode if args.conv_mode else conv_mode_auto
 
445
  def initialize_model():
446
  global tokenizer, model, image_processor, context_len, args
447
  if not LLAVA_AVAILABLE:
448
+ print("[init] LLaVA not available; cannot init.")
449
  return False
450
  try:
451
  args = _Args()
 
466
  print(f"[init] failed: {e}")
467
  return False
468
 
469
+ # ----------------- HF EndpointHandler -----------------
470
+
471
  class EndpointHandler:
472
+ """Hugging Face Endpoint uyumlu sınıf"""
473
  def __init__(self, model_dir):
474
  self.model_dir = model_dir
475
  print(f"EndpointHandler initialized with model_dir: {model_dir}")