CanerDedeoglu commited on
Commit
686dbcb
·
verified ·
1 Parent(s): 91e44e7

no stop update

Browse files
Files changed (1) hide show
  1. handler.py +68 -120
handler.py CHANGED
@@ -1,11 +1,12 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler - Demo-like (sampling) + no_stop bayrağı
4
- - Demo davranışı: do_sample=True, temperature/top_p payload'dan
5
- - max_new_tokens: payload/slider değeri (KIRPMA YOK, direkt kullanılır)
6
- - İsteğe bağlı: no_stop=True ile stopping_criteria devre dışı
7
- - Tek görsel işleme; IM_START/END otomatik; 3D/4D/5D tensör uyumlu
8
- - Çıktıya post-format/deduplicate UYGULANMAZ (demo ile bire bir)
 
9
  """
10
 
11
  import os
@@ -19,7 +20,7 @@ import torch
19
  from PIL import Image
20
  import requests
21
 
22
- # --- Opsiyonel bağımlılıklar ---
23
  try:
24
  import cv2
25
  CV2_AVAILABLE = True
@@ -27,7 +28,7 @@ except Exception:
27
  CV2_AVAILABLE = False
28
  print("Warning: OpenCV (cv2) not available; video is disabled.")
29
 
30
- # --- LLaVA / Transformers ---
31
  try:
32
  from llava.constants import (
33
  IMAGE_TOKEN_INDEX,
@@ -56,7 +57,6 @@ try:
56
  except Exception:
57
  HF_HUB_AVAILABLE = False
58
 
59
- # ------------- HF Hub init (opsiyonel) -------------
60
  api = None
61
  repo_name = ""
62
  if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
@@ -69,13 +69,11 @@ if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
69
  api = None
70
  repo_name = ""
71
 
72
- # ------------- Klasörler -------------
73
  LOGDIR = "./logs"
74
  VOTEDIR = "./votes"
75
  os.makedirs(LOGDIR, exist_ok=True)
76
  os.makedirs(VOTEDIR, exist_ok=True)
77
 
78
- # ------------- Global durum -------------
79
  tokenizer = None
80
  model = None
81
  image_processor = None
@@ -83,8 +81,6 @@ context_len = None
83
  args = None
84
  model_initialized = False
85
 
86
- # ------------- Yardımcılar -------------
87
-
88
  def _safe_upload(path: str):
89
  if api and repo_name and os.path.isfile(path):
90
  try:
@@ -100,17 +96,10 @@ def _safe_upload(path: str):
100
  def _conv_log_path():
101
  t = datetime.datetime.now()
102
  p = os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
103
- os.makedirs(os.path.dirname(p), exist_ok=True
104
- )
105
  return p
106
 
107
  def load_image_any(image_input):
108
- """
109
- Desteklenen formatlar:
110
- - URL (http/https)
111
- - Yerel dosya yolu
112
- - base64 (opsiyonel data URL prefix ile)
113
- """
114
  if isinstance(image_input, str):
115
  s = image_input.strip()
116
  if s.startswith(("http://", "https://")):
@@ -119,14 +108,10 @@ def load_image_any(image_input):
119
  return Image.open(BytesIO(r.content)).convert("RGB")
120
  if os.path.exists(s):
121
  return Image.open(s).convert("RGB")
122
- # base64
123
  if s.startswith("data:image"):
124
  s = s.split(",", 1)[1]
125
- try:
126
- raw = base64.b64decode(s)
127
- return Image.open(BytesIO(raw)).convert("RGB")
128
- except Exception as e:
129
- raise ValueError(f"Invalid image string (not URL/path/base64): {e}")
130
  elif isinstance(image_input, dict) and "image" in image_input:
131
  return load_image_any(image_input["image"])
132
  else:
@@ -134,14 +119,10 @@ def load_image_any(image_input):
134
 
135
  def _guess_conv_mode(model_path: str) -> str:
136
  name = get_model_name_from_path(model_path).lower()
137
- if "llama-2" in name:
138
- return "llava_llama_2"
139
- if "v1" in name or "pulse" in name:
140
- return "llava_v1"
141
- if "mpt" in name:
142
- return "mpt"
143
- if "qwen" in name:
144
- return "qwen_1_5"
145
  return "llava_v0"
146
 
147
  def _wrap_image_token_if_needed(model_cfg) -> bool:
@@ -150,19 +131,15 @@ def _wrap_image_token_if_needed(model_cfg) -> bool:
150
  except Exception:
151
  return False
152
 
153
- # ------------- Çekirdek üretim -------------
154
-
155
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
156
  use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
157
  if use_wrap:
158
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
159
  else:
160
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
161
-
162
  chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
163
  chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
164
  prompt = chatbot.conversation.get_prompt()
165
-
166
  input_ids = tokenizer_image_token(
167
  prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
168
  ).unsqueeze(0).to(device)
@@ -184,54 +161,45 @@ def generate_response(
184
  conv_mode_override: str | None = None,
185
  det_seed: int | None = None,
186
  no_stop: bool = False,
187
- min_new_tokens: int | None = None, # opsiyonel, uzunluğu zorlamak istersen
188
  ):
189
  if not LLAVA_AVAILABLE:
190
  return {"error": "LLaVA modules not available"}
191
-
192
  if not message_text or image_input is None:
193
  return {"error": "Both 'message' and 'image' are required"}
194
 
195
- # Chatbot/konuşma hazırla (her çağrıda sıfırdan, demo gibi)
196
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
197
  if conv_mode_override and conv_mode_override in conv_templates:
198
  chatbot.conversation = conv_templates[conv_mode_override].copy()
199
  else:
200
  chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
201
 
202
- # Görüntüyü yükle
203
  try:
204
  pil_img = load_image_any(image_input)
205
  except Exception as e:
206
  return {"error": f"Failed to load image: {e}"}
207
 
208
- # Log için kaydet
209
- img_hash = "NA"
210
- img_path = None
211
  try:
212
- buf = BytesIO()
213
- pil_img.save(buf, format="JPEG")
214
- img_bytes = buf.getvalue()
215
  img_hash = hashlib.md5(img_bytes).hexdigest()
216
  t = datetime.datetime.now()
217
  img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
218
  os.makedirs(os.path.dirname(img_path), exist_ok=True)
219
- if not os.path.isfile(img_path):
220
- pil_img.save(img_path)
221
  except Exception as e:
222
  print(f"[log] saving image failed: {e}")
223
 
224
- # Görüntüyü tensöre çevir
225
  device = next(chatbot.model.parameters()).device
226
  dtype = next(chatbot.model.parameters()).dtype
227
  try:
228
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
229
  if isinstance(processed, torch.Tensor):
230
- if processed.ndim == 3: # (C,H,W)
231
- image_tensor = processed.unsqueeze(0)
232
- elif processed.ndim == 4: # (B,C,H,W)
233
- image_tensor = processed
234
- elif processed.ndim == 5: # (B,T,C,H,W) -> (B*T,C,H,W)
235
  b,t,c,h,w = processed.shape
236
  image_tensor = processed.reshape(b*t, c, h, w)
237
  else:
@@ -241,17 +209,15 @@ def generate_response(
241
  image_tensor = first.unsqueeze(0) if isinstance(first, torch.Tensor) and first.ndim == 3 else first
242
  else:
243
  return {"error": "Image processing returned empty"}
244
-
245
- # Demo tarafında half + to(device) kalıbı yaygın
246
  image_tensor = image_tensor.to(device=device, dtype=dtype)
247
  except Exception as e:
248
  return {"error": f"Image processing failed: {e}"}
249
 
250
- # Prompt & tokenizasyon
251
- prompt, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
252
  stopping = None if no_stop else _stopping(chatbot, input_ids)
253
 
254
- # (opsiyonel) deterministik sampling
255
  if det_seed is not None:
256
  try:
257
  det_seed = int(det_seed)
@@ -262,15 +228,15 @@ def generate_response(
262
  except Exception:
263
  pass
264
 
265
- # EOS/PAD güvenli al
266
- eos_id = chatbot.tokenizer.eos_token_id
267
- if eos_id is None:
268
- try:
269
- eos_id = chatbot.tokenizer.convert_tokens_to_ids("</s>")
270
- except Exception:
271
- eos_id = 0
 
272
 
273
- # generate kwargs (demo-like)
274
  gen_kwargs = dict(
275
  inputs=input_ids,
276
  images=image_tensor,
@@ -278,15 +244,25 @@ def generate_response(
278
  temperature=float(temperature),
279
  top_p=float(top_p),
280
  repetition_penalty=float(repetition_penalty),
281
- max_new_tokens=int(max_new_tokens), # KIRPMA YOK
282
  use_cache=False,
283
- pad_token_id=eos_id,
284
- eos_token_id=eos_id,
285
  length_penalty=1.0,
286
  early_stopping=False,
287
  stopping_criteria=None if no_stop else [stopping],
288
  )
289
- if min_new_tokens is not None:
 
 
 
 
 
 
 
 
 
 
290
  try:
291
  mn = int(min_new_tokens)
292
  if mn > 0 and mn <= int(max_new_tokens):
@@ -294,19 +270,17 @@ def generate_response(
294
  except Exception:
295
  pass
296
 
297
- # Üretim
298
  try:
299
  with torch.no_grad():
300
  outputs = chatbot.model.generate(**gen_kwargs)
301
  gen = outputs[0][input_ids.shape[1]:]
302
- text = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
303
-
304
- # Konuşmaya yerleştir (demo gibi)
305
  chatbot.conversation.messages[-1][-1] = text
306
  except Exception as e:
307
  return {"error": f"Generation failed: {e}"}
308
 
309
- # Log yaz
310
  try:
311
  row = {
312
  "time": datetime.datetime.now().isoformat(),
@@ -318,20 +292,17 @@ def generate_response(
318
  }
319
  with open(_conv_log_path(), "a", encoding="utf-8") as f:
320
  f.write(json.dumps(row, ensure_ascii=False) + "\n")
321
- _safe_upload(_conv_log_path())
322
- if img_path:
323
- _safe_upload(img_path)
324
  except Exception as e:
325
  print(f"[log] failed: {e}")
326
 
327
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
328
 
329
- # ------------- API Yüzeyi -------------
330
 
331
  def query(payload: dict):
332
- """HF Endpoint ana giriş noktası (demo uyumlu)"""
333
  global model_initialized, tokenizer, model, image_processor, context_len, args
334
-
335
  if not model_initialized:
336
  if not initialize_model():
337
  return {"error": "Model initialization failed"}
@@ -340,37 +311,23 @@ def query(payload: dict):
340
  try:
341
  message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
342
  image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
 
 
343
 
344
- if not message.strip():
345
- return {"error": "Missing 'message' text"}
346
- if image is None:
347
- return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
348
-
349
- # Demo: slider benzeri parametreler
350
  max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
351
  temperature = float(payload.get("temperature", 0.05))
352
  top_p = float(payload.get("top_p", 1.0))
353
  repetition_penalty = float(payload.get("repetition_penalty", 1.0))
354
  conv_mode_override = payload.get("conv_mode", None)
355
-
356
- # (Opsiyonel) deterministik sample için seed
357
- det_seed = payload.get("det_seed", None)
358
  if det_seed is not None:
359
- try:
360
- det_seed = int(det_seed)
361
- except Exception:
362
- det_seed = None
363
-
364
- # (Yeni) stopping_criteria kapatma bayrağı
365
- no_stop = bool(payload.get("no_stop", False))
366
-
367
- # (Opsiyonel) min_new_tokens
368
- mnt = payload.get("min_new_tokens", None)
369
  if mnt is not None:
370
- try:
371
- mnt = int(mnt)
372
- except Exception:
373
- mnt = None
374
 
375
  return generate_response(
376
  message_text=message,
@@ -405,7 +362,7 @@ def get_model_info():
405
  "device": str(next(model.parameters()).device) if model else "Unknown",
406
  }
407
 
408
- # ------------- Model init -------------
409
 
410
  class _Args:
411
  def __init__(self):
@@ -428,11 +385,8 @@ class InferenceDemo:
428
  tokenizer, model, image_processor, context_len
429
  )
430
  conv_mode_auto = _guess_conv_mode(model_path)
431
- if args.conv_mode and args.conv_mode != conv_mode_auto:
432
- self.conv_mode = args.conv_mode
433
- else:
434
- self.conv_mode = conv_mode_auto
435
- args.conv_mode = conv_mode_auto
436
  self.conversation = conv_templates[self.conv_mode].copy()
437
  self.num_frames = args.num_frames
438
 
@@ -453,7 +407,6 @@ class ChatSessionManager:
453
  chat_manager = ChatSessionManager()
454
 
455
  def initialize_model():
456
- """Modeli yükle (lazy)"""
457
  global tokenizer, model, image_processor, context_len, args
458
  if not LLAVA_AVAILABLE:
459
  print("LLaVA not available; cannot init.")
@@ -464,14 +417,12 @@ def initialize_model():
464
  tokenizer, model, image_processor, context_len = load_pretrained_model(
465
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
466
  )
467
- # Cihaz
468
  try:
469
  _ = next(model.parameters()).device
470
  except Exception:
471
  if torch.cuda.is_available():
472
  model = model.to(torch.device("cuda"))
473
  model.eval()
474
- # Chatbot init
475
  chat_manager.init_if_needed(args, args.model_path, tokenizer, model, image_processor, context_len)
476
  print("[init] model/tokenizer/image_processor loaded.")
477
  return True
@@ -479,10 +430,7 @@ def initialize_model():
479
  print(f"[init] failed: {e}")
480
  return False
481
 
482
- # ------------- HF EndpointHandler -------------
483
-
484
  class EndpointHandler:
485
- """Hugging Face Endpoint uyumlu sınıf"""
486
  def __init__(self, model_dir):
487
  self.model_dir = model_dir
488
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler - Demo-like sampling + no_stop (hard) + min_new_tokens auto
4
+ - do_sample=True, temperature/top_p payload'dan
5
+ - max_new_tokens: payload değeri (kırpma yok)
6
+ - no_stop=True: stopping_criteria KAPALI + eos_token_id=None
7
+ - no_stop=True ve min_new_tokens boşsa: otomatik min_new_tokens (uzun yanıt garantisi)
8
+ - Tek görsel; IM_START/END otomatik; 3D/4D/5D tensör uyumlu
9
+ - Post-format yok (demo davranışı)
10
  """
11
 
12
  import os
 
20
  from PIL import Image
21
  import requests
22
 
23
+ # --- Opsiyonel ---
24
  try:
25
  import cv2
26
  CV2_AVAILABLE = True
 
28
  CV2_AVAILABLE = False
29
  print("Warning: OpenCV (cv2) not available; video is disabled.")
30
 
31
+ # --- LLaVA ---
32
  try:
33
  from llava.constants import (
34
  IMAGE_TOKEN_INDEX,
 
57
  except Exception:
58
  HF_HUB_AVAILABLE = False
59
 
 
60
  api = None
61
  repo_name = ""
62
  if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
 
69
  api = None
70
  repo_name = ""
71
 
 
72
  LOGDIR = "./logs"
73
  VOTEDIR = "./votes"
74
  os.makedirs(LOGDIR, exist_ok=True)
75
  os.makedirs(VOTEDIR, exist_ok=True)
76
 
 
77
  tokenizer = None
78
  model = None
79
  image_processor = None
 
81
  args = None
82
  model_initialized = False
83
 
 
 
84
  def _safe_upload(path: str):
85
  if api and repo_name and os.path.isfile(path):
86
  try:
 
96
  def _conv_log_path():
97
  t = datetime.datetime.now()
98
  p = os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
99
+ os.makedirs(os.path.dirname(p), exist_ok=True)
 
100
  return p
101
 
102
  def load_image_any(image_input):
 
 
 
 
 
 
103
  if isinstance(image_input, str):
104
  s = image_input.strip()
105
  if s.startswith(("http://", "https://")):
 
108
  return Image.open(BytesIO(r.content)).convert("RGB")
109
  if os.path.exists(s):
110
  return Image.open(s).convert("RGB")
 
111
  if s.startswith("data:image"):
112
  s = s.split(",", 1)[1]
113
+ raw = base64.b64decode(s)
114
+ return Image.open(BytesIO(raw)).convert("RGB")
 
 
 
115
  elif isinstance(image_input, dict) and "image" in image_input:
116
  return load_image_any(image_input["image"])
117
  else:
 
119
 
120
  def _guess_conv_mode(model_path: str) -> str:
121
  name = get_model_name_from_path(model_path).lower()
122
+ if "llama-2" in name: return "llava_llama_2"
123
+ if "v1" in name or "pulse" in name: return "llava_v1"
124
+ if "mpt" in name: return "mpt"
125
+ if "qwen" in name: return "qwen_1_5"
 
 
 
 
126
  return "llava_v0"
127
 
128
  def _wrap_image_token_if_needed(model_cfg) -> bool:
 
131
  except Exception:
132
  return False
133
 
 
 
134
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
135
  use_wrap = _wrap_image_token_if_needed(chatbot.model.config)
136
  if use_wrap:
137
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
138
  else:
139
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
 
140
  chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
141
  chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
142
  prompt = chatbot.conversation.get_prompt()
 
143
  input_ids = tokenizer_image_token(
144
  prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
145
  ).unsqueeze(0).to(device)
 
161
  conv_mode_override: str | None = None,
162
  det_seed: int | None = None,
163
  no_stop: bool = False,
164
+ min_new_tokens: int | None = None, # otomatik atanabilir
165
  ):
166
  if not LLAVA_AVAILABLE:
167
  return {"error": "LLaVA modules not available"}
 
168
  if not message_text or image_input is None:
169
  return {"error": "Both 'message' and 'image' are required"}
170
 
 
171
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
172
  if conv_mode_override and conv_mode_override in conv_templates:
173
  chatbot.conversation = conv_templates[conv_mode_override].copy()
174
  else:
175
  chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
176
 
 
177
  try:
178
  pil_img = load_image_any(image_input)
179
  except Exception as e:
180
  return {"error": f"Failed to load image: {e}"}
181
 
182
+ # log
183
+ img_hash, img_path = "NA", None
 
184
  try:
185
+ buf = BytesIO(); pil_img.save(buf, format="JPEG"); img_bytes = buf.getvalue()
 
 
186
  img_hash = hashlib.md5(img_bytes).hexdigest()
187
  t = datetime.datetime.now()
188
  img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
189
  os.makedirs(os.path.dirname(img_path), exist_ok=True)
190
+ if not os.path.isfile(img_path): pil_img.save(img_path)
 
191
  except Exception as e:
192
  print(f"[log] saving image failed: {e}")
193
 
194
+ # görüntü tensörü
195
  device = next(chatbot.model.parameters()).device
196
  dtype = next(chatbot.model.parameters()).dtype
197
  try:
198
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
199
  if isinstance(processed, torch.Tensor):
200
+ if processed.ndim == 3: image_tensor = processed.unsqueeze(0)
201
+ elif processed.ndim == 4: image_tensor = processed
202
+ elif processed.ndim == 5:
 
 
203
  b,t,c,h,w = processed.shape
204
  image_tensor = processed.reshape(b*t, c, h, w)
205
  else:
 
209
  image_tensor = first.unsqueeze(0) if isinstance(first, torch.Tensor) and first.ndim == 3 else first
210
  else:
211
  return {"error": "Image processing returned empty"}
 
 
212
  image_tensor = image_tensor.to(device=device, dtype=dtype)
213
  except Exception as e:
214
  return {"error": f"Image processing failed: {e}"}
215
 
216
+ # prompt & ids
217
+ _, input_ids = _build_prompt_and_ids(chatbot, message_text, device)
218
  stopping = None if no_stop else _stopping(chatbot, input_ids)
219
 
220
+ # deterministik sample (opsiyonel)
221
  if det_seed is not None:
222
  try:
223
  det_seed = int(det_seed)
 
228
  except Exception:
229
  pass
230
 
231
+ # EOS/PAD
232
+ eos_id = tokenizer.eos_token_id
233
+ pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else (eos_id if eos_id is not None else 0)
234
+ # no_stop=True ise eos'a göre durmayı tamamen kapat
235
+ if no_stop:
236
+ eos_for_gen = None
237
+ else:
238
+ eos_for_gen = eos_id
239
 
 
240
  gen_kwargs = dict(
241
  inputs=input_ids,
242
  images=image_tensor,
 
244
  temperature=float(temperature),
245
  top_p=float(top_p),
246
  repetition_penalty=float(repetition_penalty),
247
+ max_new_tokens=int(max_new_tokens),
248
  use_cache=False,
249
+ pad_token_id=pad_id,
250
+ eos_token_id=eos_for_gen,
251
  length_penalty=1.0,
252
  early_stopping=False,
253
  stopping_criteria=None if no_stop else [stopping],
254
  )
255
+
256
+ # min_new_tokens otomatik (no_stop=True ve kullanıcı vermediyse)
257
+ if no_stop and (min_new_tokens is None):
258
+ try:
259
+ req = int(max_new_tokens)
260
+ auto_min = max(300, min(req - 64, 1024)) # 300–1024 bandında güvenli
261
+ if auto_min > 0:
262
+ gen_kwargs["min_new_tokens"] = auto_min
263
+ except Exception:
264
+ pass
265
+ elif min_new_tokens is not None:
266
  try:
267
  mn = int(min_new_tokens)
268
  if mn > 0 and mn <= int(max_new_tokens):
 
270
  except Exception:
271
  pass
272
 
273
+ # generate
274
  try:
275
  with torch.no_grad():
276
  outputs = chatbot.model.generate(**gen_kwargs)
277
  gen = outputs[0][input_ids.shape[1]:]
278
+ text = tokenizer.decode(gen, skip_special_tokens=True)
 
 
279
  chatbot.conversation.messages[-1][-1] = text
280
  except Exception as e:
281
  return {"error": f"Generation failed: {e}"}
282
 
283
+ # log
284
  try:
285
  row = {
286
  "time": datetime.datetime.now().isoformat(),
 
292
  }
293
  with open(_conv_log_path(), "a", encoding="utf-8") as f:
294
  f.write(json.dumps(row, ensure_ascii=False) + "\n")
295
+ _safe_upload(_conv_log_path()); _safe_upload(img_path or "")
 
 
296
  except Exception as e:
297
  print(f"[log] failed: {e}")
298
 
299
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
300
 
301
+ # -------- API --------
302
 
303
  def query(payload: dict):
304
+ """HF Endpoint entry (demo-like)"""
305
  global model_initialized, tokenizer, model, image_processor, context_len, args
 
306
  if not model_initialized:
307
  if not initialize_model():
308
  return {"error": "Model initialization failed"}
 
311
  try:
312
  message = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
313
  image = payload.get("image") or payload.get("image_url") or payload.get("img") or None
314
+ if not message.strip(): return {"error": "Missing 'message' text"}
315
+ if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
316
 
 
 
 
 
 
 
317
  max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
318
  temperature = float(payload.get("temperature", 0.05))
319
  top_p = float(payload.get("top_p", 1.0))
320
  repetition_penalty = float(payload.get("repetition_penalty", 1.0))
321
  conv_mode_override = payload.get("conv_mode", None)
322
+ det_seed = payload.get("det_seed", None)
 
 
323
  if det_seed is not None:
324
+ try: det_seed = int(det_seed)
325
+ except Exception: det_seed = None
326
+ no_stop = bool(payload.get("no_stop", False))
327
+ mnt = payload.get("min_new_tokens", None)
 
 
 
 
 
 
328
  if mnt is not None:
329
+ try: mnt = int(mnt)
330
+ except Exception: mnt = None
 
 
331
 
332
  return generate_response(
333
  message_text=message,
 
362
  "device": str(next(model.parameters()).device) if model else "Unknown",
363
  }
364
 
365
+ # -------- init --------
366
 
367
  class _Args:
368
  def __init__(self):
 
385
  tokenizer, model, image_processor, context_len
386
  )
387
  conv_mode_auto = _guess_conv_mode(model_path)
388
+ self.conv_mode = args.conv_mode if args.conv_mode else conv_mode_auto
389
+ args.conv_mode = self.conv_mode
 
 
 
390
  self.conversation = conv_templates[self.conv_mode].copy()
391
  self.num_frames = args.num_frames
392
 
 
407
  chat_manager = ChatSessionManager()
408
 
409
  def initialize_model():
 
410
  global tokenizer, model, image_processor, context_len, args
411
  if not LLAVA_AVAILABLE:
412
  print("LLaVA not available; cannot init.")
 
417
  tokenizer, model, image_processor, context_len = load_pretrained_model(
418
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
419
  )
 
420
  try:
421
  _ = next(model.parameters()).device
422
  except Exception:
423
  if torch.cuda.is_available():
424
  model = model.to(torch.device("cuda"))
425
  model.eval()
 
426
  chat_manager.init_if_needed(args, args.model_path, tokenizer, model, image_processor, context_len)
427
  print("[init] model/tokenizer/image_processor loaded.")
428
  return True
 
430
  print(f"[init] failed: {e}")
431
  return False
432
 
 
 
433
  class EndpointHandler:
 
434
  def __init__(self, model_dir):
435
  self.model_dir = model_dir
436
  print(f"EndpointHandler initialized with model_dir: {model_dir}")