CanerDedeoglu commited on
Commit
12fd62f
·
verified ·
1 Parent(s): 1b1a09c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +317 -112
handler.py CHANGED
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- # handler.py — PULSE-7B / LLaVA robust endpoint
3
- # - Safe decode (empty output fix)
4
- # - PAD/EOS safety
5
- # - Hugging Face endpoint compatible
6
-
7
  import os
8
  import datetime
9
  import torch
@@ -49,7 +43,7 @@ except ImportError as e:
49
 
50
  # Try to import transformers
51
  try:
52
- from transformers import TextStreamer, TextIteratorStreamer, GenerationConfig
53
  TRANSFORMERS_AVAILABLE = True
54
  except ImportError:
55
  TRANSFORMERS_AVAILABLE = False
@@ -81,7 +75,7 @@ external_log_dir = "./logs"
81
  LOGDIR = external_log_dir
82
  VOTEDIR = "./votes"
83
 
84
- # Global variables
85
  tokenizer = None
86
  model = None
87
  image_processor = None
@@ -121,7 +115,7 @@ def vote_last_response(state, vote_type, model_selector):
121
 
122
  def is_valid_video_filename(name):
123
  if not CV2_AVAILABLE:
124
- return False
125
  video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
126
  ext = name.split(".")[-1].lower()
127
  return ext in video_extensions
@@ -133,7 +127,8 @@ def is_valid_image_filename(name):
133
 
134
  def sample_frames(video_file, num_frames):
135
  if not CV2_AVAILABLE:
136
- raise ImportError("cv2 not available")
 
137
  video = cv2.VideoCapture(video_file)
138
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
139
  interval = total_frames // num_frames
@@ -149,32 +144,46 @@ def sample_frames(video_file, num_frames):
149
  return frames
150
 
151
  def load_image(image_file):
152
- if image_file.startswith("http"):
153
  response = requests.get(image_file)
154
  if response.status_code == 200:
155
  image = Image.open(BytesIO(response.content)).convert("RGB")
156
  else:
157
  raise ValueError("Failed to load image from URL")
158
  else:
 
 
159
  image = Image.open(image_file).convert("RGB")
160
  return image
161
 
162
  def process_base64_image(base64_string):
163
- if base64_string.startswith('data:image'):
164
- base64_string = base64_string.split(',')[1]
165
- image_data = base64.b64decode(base64_string)
166
- image = Image.open(BytesIO(image_data)).convert("RGB")
167
- return image
 
 
 
 
 
 
 
 
 
168
 
169
  def process_image_input(image_input):
 
170
  if isinstance(image_input, str):
171
  if image_input.startswith("http"):
172
  return load_image(image_input)
173
  elif os.path.exists(image_input):
174
  return load_image(image_input)
175
  else:
 
176
  return process_base64_image(image_input)
177
  elif isinstance(image_input, dict) and "image" in image_input:
 
178
  return process_base64_image(image_input["image"])
179
  else:
180
  raise ValueError("Unsupported image input format")
@@ -185,9 +194,14 @@ class InferenceDemo(object):
185
  raise ImportError("LLaVA modules not available")
186
 
187
  disable_torch_init()
 
188
  self.tokenizer, self.model, self.image_processor, self.context_len = (
189
- tokenizer, model, image_processor, context_len
 
 
 
190
  )
 
191
  model_name = get_model_name_from_path(model_path)
192
  if "llama-2" in model_name.lower():
193
  conv_mode = "llava_llama_2"
@@ -199,8 +213,13 @@ class InferenceDemo(object):
199
  conv_mode = "qwen_1_5"
200
  else:
201
  conv_mode = "llava_v0"
 
202
  if args.conv_mode is not None and conv_mode != args.conv_mode:
203
- print(f"[WARNING] auto inferred conv_mode={conv_mode}, using {args.conv_mode}")
 
 
 
 
204
  else:
205
  args.conv_mode = conv_mode
206
  self.conv_mode = conv_mode
@@ -210,11 +229,14 @@ class InferenceDemo(object):
210
  class ChatSessionManager:
211
  def __init__(self):
212
  self.chatbot_instance = None
 
213
  def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
214
  self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
215
  print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
 
216
  def reset_chatbot(self):
217
  self.chatbot_instance = None
 
218
  def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
219
  if self.chatbot_instance is None:
220
  self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
@@ -223,139 +245,242 @@ class ChatSessionManager:
223
  chat_manager = ChatSessionManager()
224
 
225
  def clear_history():
 
226
  if not LLAVA_AVAILABLE:
227
- return {"error": "LLaVA not available"}
 
228
  try:
229
  chatbot_instance = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
230
- mode = getattr(chatbot_instance, 'conv_mode', None)
231
- if mode and LLAVA_AVAILABLE and mode in conv_templates:
232
- chatbot_instance.conversation = conv_templates[mode].copy()
233
- else:
234
- chatbot_instance.conversation = chatbot_instance.conversation.__class__()
235
- return {"status": "success", "message": "Conversation cleared"}
 
 
 
236
  except Exception as e:
237
  return {"error": f"Failed to clear history: {str(e)}"}
238
 
239
- def _strip_prefix_relaxed(text: str, prefix: str) -> str:
240
- try:
241
- if text.startswith(prefix):
242
- return text[len(prefix):]
243
- t_norm = " ".join(text.split())
244
- p_norm = " ".join(prefix.split())
245
- if t_norm.startswith(p_norm):
246
- idx = text.find(prefix.splitlines()[0]) if prefix.splitlines() else -1
247
- if idx >= 0:
248
- return text[idx + len(prefix.splitlines()[0]):]
249
- except Exception:
250
- pass
251
- return text
252
 
253
  def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, max_output_tokens=4096, repetition_penalty=1.0, conv_mode_override=None):
 
254
  if not LLAVA_AVAILABLE:
255
- return {"error": "LLaVA not available"}
 
256
  try:
257
  if not message_text or not image_input:
258
- return {"error": "Both message and image required"}
259
 
260
  our_chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
261
- image = process_image_input(image_input)
262
 
 
 
 
 
 
 
 
 
 
 
 
263
  img_byte_arr = BytesIO()
264
  image.save(img_byte_arr, format='JPEG')
265
- image_hash = hashlib.md5(img_byte_arr.getvalue()).hexdigest()
 
 
 
 
266
  t = datetime.datetime.now()
267
- filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
268
- os.makedirs(os.path.dirname(filename), exist_ok=True)
269
- image.save(filename)
 
 
 
 
 
 
 
 
270
 
271
- processed_images = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)
272
- image_tensor = processed_images[0].half().to(our_chatbot.model.device).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
- if conv_mode_override:
275
- our_chatbot.conversation = conv_templates[conv_mode_override].copy()
276
- else:
277
- our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
 
 
 
 
 
 
278
 
279
  inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
280
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
281
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
282
  prompt = our_chatbot.conversation.get_prompt()
283
 
284
- input_ids = tokenizer_image_token(prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(our_chatbot.model.device)
 
 
 
285
 
286
- stop_str = our_chatbot.conversation.sep if our_chatbot.conversation.sep_style != SeparatorStyle.TWO else our_chatbot.conversation.sep2
287
- stopping_criteria = KeywordsStoppingCriteria([stop_str], our_chatbot.tokenizer, input_ids)
288
-
289
- pad_id = our_chatbot.tokenizer.pad_token_id
290
- eos_id = our_chatbot.tokenizer.eos_token_id if our_chatbot.tokenizer.eos_token_id is not None else pad_id
291
- gen_cfg = GenerationConfig(
292
- do_sample=True, temperature=float(temperature), top_p=float(top_p),
293
- max_new_tokens=int(max_output_tokens), repetition_penalty=float(repetition_penalty),
294
- pad_token_id=pad_id, eos_token_id=eos_id
295
  )
296
 
 
297
  with torch.no_grad():
298
  outputs = our_chatbot.model.generate(
299
  inputs=input_ids,
300
  images=image_tensor,
301
- generation_config=gen_cfg,
302
- use_cache=True,
 
 
 
 
303
  stopping_criteria=[stopping_criteria],
304
- return_dict_in_generate=True
305
  )
306
 
307
- sequences = outputs.sequences
308
- gen_ids = sequences[0]
309
- full_text = our_chatbot.tokenizer.decode(gen_ids, skip_special_tokens=True)
310
- prompt_text = our_chatbot.tokenizer.decode(input_ids[0], skip_special_tokens=True)
311
-
312
- if gen_ids.shape[0] > input_ids.shape[1]:
313
- response = our_chatbot.tokenizer.decode(gen_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
314
- else:
315
- response = _strip_prefix_relaxed(full_text, prompt_text).strip()
316
- if not response:
317
- response = full_text.replace(stop_str, "").strip()
318
-
319
- our_chatbot.conversation.messages[-1][-1] = response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
 
321
  history = [(message_text, response)]
322
  with open(get_conv_log_filename(), "a") as fout:
323
- fout.write(json.dumps({
324
- "type": "chat", "model": "PULSE-7b", "state": history,
325
- "images": [image_hash], "images_path": [filename]
326
- }) + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
- return {"status": "success", "response": response, "conversation_id": id(our_chatbot.conversation)}
329
  except Exception as e:
330
  return {"error": f"Generation failed: {str(e)}"}
331
 
332
  def upvote_last_response(conversation_id):
 
333
  try:
334
  vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
335
- return {"status": "success", "message": "Upvoted"}
336
  except Exception as e:
337
- return {"error": str(e)}
338
 
339
  def downvote_last_response(conversation_id):
 
340
  try:
341
  vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
342
- return {"status": "success", "message": "Downvoted"}
343
  except Exception as e:
344
- return {"error": str(e)}
345
 
346
  def flag_response(conversation_id):
 
347
  try:
348
  vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
349
- return {"status": "success", "message": "Flagged"}
350
  except Exception as e:
351
- return {"error": str(e)}
352
 
 
353
  def initialize_model():
 
354
  global tokenizer, model, image_processor, context_len, args
 
355
  if not LLAVA_AVAILABLE:
356
- print("LLaVA not available")
357
  return False
 
358
  try:
 
359
  class Args:
360
  def __init__(self):
361
  self.model_path = "PULSE-ECG/PULSE-7B"
@@ -368,45 +493,95 @@ def initialize_model():
368
  self.load_8bit = False
369
  self.load_4bit = False
370
  self.debug = False
 
371
  args = Args()
372
- tok, mdl, img_proc, ctx_len = load_pretrained_model(args.model_path, args.model_base, get_model_name_from_path(args.model_path), args.load_8bit, args.load_4bit)
373
- if tok.eos_token_id is None:
374
- tok.add_special_tokens({"eos_token": "</s>"})
375
- if tok.pad_token_id is None:
376
- tok.pad_token = tok.eos_token
377
- tokenizer, model, image_processor, context_len = tok, mdl, img_proc, ctx_len
 
 
 
 
 
 
378
  if torch.cuda.is_available():
379
  model = model.to(torch.device('cuda'))
 
 
 
 
380
  return True
 
381
  except Exception as e:
382
- print(f"Init model fail: {e}")
383
  return False
384
 
 
385
  model_initialized = False
386
 
 
387
  def query(payload):
 
388
  global model_initialized
 
 
389
  if not model_initialized:
 
390
  model_initialized = initialize_model()
391
  if not model_initialized:
392
- return {"error": "Model init failed"}
 
393
  try:
394
- message_text = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
395
- image_input = payload.get("image") or payload.get("image_url") or payload.get("img") or None
 
 
 
 
 
 
 
 
 
 
 
 
396
  temperature = float(payload.get("temperature", 0.05))
397
  top_p = float(payload.get("top_p", 1.0))
398
- max_output_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
 
 
399
  repetition_penalty = float(payload.get("repetition_penalty", 1.0))
400
  conv_mode_override = payload.get("conv_mode", None)
401
- if not message_text.strip():
402
- return {"error": "Missing prompt text"}
 
 
403
  if not image_input:
404
- return {"error": "Missing image"}
405
- return generate_response(message_text, image_input, temperature, top_p, max_output_tokens, repetition_penalty, conv_mode_override)
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  except Exception as e:
407
- return {"error": str(e)}
408
 
 
409
  def health_check():
 
410
  return {
411
  "status": "healthy",
412
  "model_initialized": model_initialized,
@@ -414,26 +589,56 @@ def health_check():
414
  "llava_available": LLAVA_AVAILABLE,
415
  "transformers_available": TRANSFORMERS_AVAILABLE,
416
  "cv2_available": CV2_AVAILABLE,
417
- "lazy_loading": True
418
  }
419
 
420
  def get_model_info():
 
421
  if not model_initialized:
422
- return {"error": "Not initialized", "lazy_loading": True}
423
- return {"model_path": args.model_path if args else "Unknown", "model_type": "PULSE-7B", "cuda_available": torch.cuda.is_available(), "device": str(model.device) if model else "Unknown"}
 
 
 
 
 
 
 
 
 
 
424
 
 
425
  class EndpointHandler:
 
 
426
  def __init__(self, model_dir):
 
427
  self.model_dir = model_dir
428
- print(f"Handler init with model_dir={model_dir}")
 
429
  def __call__(self, payload):
 
 
430
  if "inputs" in payload:
431
- return query(payload["inputs"])
432
- return query(payload)
 
 
 
 
 
433
  def health_check(self):
 
434
  return health_check()
 
435
  def get_model_info(self):
 
436
  return get_model_info()
437
 
 
438
  if __name__ == "__main__":
439
- print("Handler loaded and ready.")
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import datetime
3
  import torch
 
43
 
44
  # Try to import transformers
45
  try:
46
+ from transformers import TextStreamer, TextIteratorStreamer
47
  TRANSFORMERS_AVAILABLE = True
48
  except ImportError:
49
  TRANSFORMERS_AVAILABLE = False
 
75
  LOGDIR = external_log_dir
76
  VOTEDIR = "./votes"
77
 
78
+ # Global variables for model and tokenizer
79
  tokenizer = None
80
  model = None
81
  image_processor = None
 
115
 
116
  def is_valid_video_filename(name):
117
  if not CV2_AVAILABLE:
118
+ return False # Video processing disabled
119
  video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
120
  ext = name.split(".")[-1].lower()
121
  return ext in video_extensions
 
127
 
128
  def sample_frames(video_file, num_frames):
129
  if not CV2_AVAILABLE:
130
+ raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")
131
+
132
  video = cv2.VideoCapture(video_file)
133
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
134
  interval = total_frames // num_frames
 
144
  return frames
145
 
146
  def load_image(image_file):
147
+ if image_file.startswith("http") or image_file.startswith("https"):
148
  response = requests.get(image_file)
149
  if response.status_code == 200:
150
  image = Image.open(BytesIO(response.content)).convert("RGB")
151
  else:
152
  raise ValueError("Failed to load image from URL")
153
  else:
154
+ print("Load image from local file")
155
+ print(image_file)
156
  image = Image.open(image_file).convert("RGB")
157
  return image
158
 
159
  def process_base64_image(base64_string):
160
+ """Process base64 encoded image string"""
161
+ try:
162
+ # Remove data URL prefix if present
163
+ if base64_string.startswith('data:image'):
164
+ base64_string = base64_string.split(',')[1]
165
+
166
+ # Decode base64 to bytes
167
+ image_data = base64.b64decode(base64_string)
168
+
169
+ # Convert to PIL Image
170
+ image = Image.open(BytesIO(image_data)).convert("RGB")
171
+ return image
172
+ except Exception as e:
173
+ raise ValueError(f"Failed to process base64 image: {e}")
174
 
175
  def process_image_input(image_input):
176
+ """Process different types of image input (file path, URL, or base64)"""
177
  if isinstance(image_input, str):
178
  if image_input.startswith("http"):
179
  return load_image(image_input)
180
  elif os.path.exists(image_input):
181
  return load_image(image_input)
182
  else:
183
+ # Try to process as base64
184
  return process_base64_image(image_input)
185
  elif isinstance(image_input, dict) and "image" in image_input:
186
+ # Handle base64 image from dict
187
  return process_base64_image(image_input["image"])
188
  else:
189
  raise ValueError("Unsupported image input format")
 
194
  raise ImportError("LLaVA modules not available")
195
 
196
  disable_torch_init()
197
+
198
  self.tokenizer, self.model, self.image_processor, self.context_len = (
199
+ tokenizer,
200
+ model,
201
+ image_processor,
202
+ context_len,
203
  )
204
+
205
  model_name = get_model_name_from_path(model_path)
206
  if "llama-2" in model_name.lower():
207
  conv_mode = "llava_llama_2"
 
213
  conv_mode = "qwen_1_5"
214
  else:
215
  conv_mode = "llava_v0"
216
+
217
  if args.conv_mode is not None and conv_mode != args.conv_mode:
218
+ print(
219
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
220
+ conv_mode, args.conv_mode, args.conv_mode
221
+ )
222
+ )
223
  else:
224
  args.conv_mode = conv_mode
225
  self.conv_mode = conv_mode
 
229
  class ChatSessionManager:
230
  def __init__(self):
231
  self.chatbot_instance = None
232
+
233
  def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
234
  self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
235
  print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
236
+
237
  def reset_chatbot(self):
238
  self.chatbot_instance = None
239
+
240
  def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
241
  if self.chatbot_instance is None:
242
  self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
 
245
  chat_manager = ChatSessionManager()
246
 
247
  def clear_history():
248
+ """Clear conversation history"""
249
  if not LLAVA_AVAILABLE:
250
+ return {"error": "LLaVA modules not available"}
251
+
252
  try:
253
  chatbot_instance = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
254
+ try:
255
+ if hasattr(chatbot_instance, 'conv_mode') and chatbot_instance.conv_mode and LLAVA_AVAILABLE:
256
+ chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
257
+ else:
258
+ # Use default conversation template
259
+ chatbot_instance.conversation = chatbot_instance.conversation.__class__()
260
+ except Exception as e:
261
+ print(f"[DEBUG] Failed to reset conversation in clear_history: {e}")
262
+ return {"status": "success", "message": "Conversation history cleared"}
263
  except Exception as e:
264
  return {"error": f"Failed to clear history: {str(e)}"}
265
 
266
+ def add_message(message_text, image_input=None):
267
+ """Add a message to the conversation"""
268
+ return {"status": "success", "message": "Message added"}
 
 
 
 
 
 
 
 
 
 
269
 
270
  def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, max_output_tokens=4096, repetition_penalty=1.0, conv_mode_override=None):
271
+ """Generate response for the given message and image"""
272
  if not LLAVA_AVAILABLE:
273
+ return {"error": "LLaVA modules not available"}
274
+
275
  try:
276
  if not message_text or not image_input:
277
+ return {"error": "Both message text and image are required"}
278
 
279
  our_chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
 
280
 
281
+ # Process image input
282
+ try:
283
+ image = process_image_input(image_input)
284
+ except Exception as e:
285
+ return {"error": f"Failed to process image: {str(e)}"}
286
+
287
+ # Save image for logging
288
+ all_image_hash = []
289
+ all_image_path = []
290
+
291
+ # Generate hash for the image
292
  img_byte_arr = BytesIO()
293
  image.save(img_byte_arr, format='JPEG')
294
+ img_byte_arr = img_byte_arr.getvalue()
295
+ image_hash = hashlib.md5(img_byte_arr).hexdigest()
296
+ all_image_hash.append(image_hash)
297
+
298
+ # Save image to logs
299
  t = datetime.datetime.now()
300
+ filename = os.path.join(
301
+ LOGDIR,
302
+ "serve_images",
303
+ f"{t.year}-{t.month:02d}-{t.day:02d}",
304
+ f"{image_hash}.jpg",
305
+ )
306
+ all_image_path.append(filename)
307
+ if not os.path.isfile(filename):
308
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
309
+ print("image save to", filename)
310
+ image.save(filename)
311
 
312
+ # Process image for model
313
+ try:
314
+ print(f"[DEBUG] Processing image for model...")
315
+ processed_images = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)
316
+ print(f"[DEBUG] Processed images length: {len(processed_images)}")
317
+
318
+ if len(processed_images) == 0:
319
+ return {"error": "Image processing returned empty list"}
320
+
321
+ image_tensor = processed_images[0]
322
+ image_tensor = image_tensor.half().to(our_chatbot.model.device)
323
+ image_tensor = image_tensor.unsqueeze(0)
324
+ print(f"[DEBUG] Image tensor shape: {image_tensor.shape}")
325
+ except Exception as e:
326
+ print(f"[DEBUG] Image processing error: {str(e)}")
327
+ return {"error": f"Image processing failed: {str(e)}"}
328
 
329
+ # Prepare conversation - reset for each request to avoid history issues
330
+ try:
331
+ if hasattr(our_chatbot, 'conv_mode') and our_chatbot.conv_mode and LLAVA_AVAILABLE:
332
+ our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
333
+ else:
334
+ # Use default conversation template
335
+ our_chatbot.conversation = our_chatbot.conversation.__class__()
336
+ except Exception as e:
337
+ print(f"[DEBUG] Failed to reset conversation: {e}")
338
+ # Continue with existing conversation
339
 
340
  inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
341
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
342
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
343
  prompt = our_chatbot.conversation.get_prompt()
344
 
345
+ # Tokenize input
346
+ input_ids = tokenizer_image_token(
347
+ prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
348
+ ).unsqueeze(0).to(our_chatbot.model.device)
349
 
350
+ # Set up stopping criteria
351
+ stop_str = (
352
+ our_chatbot.conversation.sep
353
+ if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
354
+ else our_chatbot.conversation.sep2
355
+ )
356
+ keywords = [stop_str]
357
+ stopping_criteria = KeywordsStoppingCriteria(
358
+ keywords, our_chatbot.tokenizer, input_ids
359
  )
360
 
361
+ # Generate response
362
  with torch.no_grad():
363
  outputs = our_chatbot.model.generate(
364
  inputs=input_ids,
365
  images=image_tensor,
366
+ do_sample=True,
367
+ temperature=temperature,
368
+ top_p=top_p,
369
+ max_new_tokens=max_output_tokens,
370
+ repetition_penalty=repetition_penalty,
371
+ use_cache=False,
372
  stopping_criteria=[stopping_criteria],
 
373
  )
374
 
375
+ # Decode response
376
+ try:
377
+ print(f"[DEBUG] Outputs shape: {outputs.shape if hasattr(outputs, 'shape') else 'No shape attr'}")
378
+ print(f"[DEBUG] Outputs length: {len(outputs) if hasattr(outputs, '__len__') else 'No length'}")
379
+ print(f"[DEBUG] Input IDs shape: {input_ids.shape}")
380
+
381
+ if len(outputs) == 0:
382
+ return {"error": "Model generated empty output"}
383
+
384
+ response = our_chatbot.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
385
+
386
+ print(f"[DEBUG] Conversation messages length: {len(our_chatbot.conversation.messages)}")
387
+ if len(our_chatbot.conversation.messages) > 0:
388
+ last_message = our_chatbot.conversation.messages[-1]
389
+ print(f"[DEBUG] Last message: {last_message}")
390
+ if isinstance(last_message, list) and len(last_message) > 1:
391
+ our_chatbot.conversation.messages[-1][-1] = response
392
+ print(f"[DEBUG] Response added to conversation")
393
+ else:
394
+ print(f"[DEBUG] Last message format unexpected: {last_message}")
395
+ # Add response as new message if format is wrong
396
+ our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
397
+ else:
398
+ print("[DEBUG] No conversation messages found")
399
+ # Add response as new message
400
+ our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
401
+
402
+ print(f"[DEBUG] Generated response length: {len(response)}")
403
+ except Exception as e:
404
+ print(f"[DEBUG] Response decoding error: {str(e)}")
405
+ return {"error": f"Response decoding failed: {str(e)}"}
406
 
407
+ # Log conversation
408
  history = [(message_text, response)]
409
  with open(get_conv_log_filename(), "a") as fout:
410
+ data = {
411
+ "type": "chat",
412
+ "model": "PULSE-7b",
413
+ "state": history,
414
+ "images": all_image_hash,
415
+ "images_path": all_image_path
416
+ }
417
+ print("#### conv log", data)
418
+ fout.write(json.dumps(data) + "\n")
419
+
420
+ # Upload files to Hugging Face if configured
421
+ if api and repo_name:
422
+ try:
423
+ for upload_img in all_image_path:
424
+ api.upload_file(
425
+ path_or_fileobj=upload_img,
426
+ path_in_repo=upload_img.replace("./logs/", ""),
427
+ repo_id=repo_name,
428
+ repo_type="dataset",
429
+ )
430
+
431
+ # Upload conversation log
432
+ api.upload_file(
433
+ path_or_fileobj=get_conv_log_filename(),
434
+ path_in_repo=get_conv_log_filename().replace("./logs/", ""),
435
+ repo_id=repo_name,
436
+ repo_type="dataset")
437
+ except Exception as e:
438
+ print(f"Failed to upload files: {e}")
439
+
440
+ return {
441
+ "status": "success",
442
+ "response": response,
443
+ "conversation_id": id(our_chatbot.conversation)
444
+ }
445
 
 
446
  except Exception as e:
447
  return {"error": f"Generation failed: {str(e)}"}
448
 
449
  def upvote_last_response(conversation_id):
450
+ """Upvote the last response"""
451
  try:
452
  vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
453
+ return {"status": "success", "message": "Thank you for your voting!"}
454
  except Exception as e:
455
+ return {"error": f"Failed to upvote: {str(e)}"}
456
 
457
  def downvote_last_response(conversation_id):
458
+ """Downvote the last response"""
459
  try:
460
  vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
461
+ return {"status": "success", "message": "Thank you for your voting!"}
462
  except Exception as e:
463
+ return {"error": f"Failed to downvote: {str(e)}"}
464
 
465
  def flag_response(conversation_id):
466
+ """Flag the last response"""
467
  try:
468
  vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
469
+ return {"status": "success", "message": "Response flagged successfully"}
470
  except Exception as e:
471
+ return {"error": f"Failed to flag response: {str(e)}"}
472
 
473
+ # Initialize model when module is imported
474
  def initialize_model():
475
+ """Initialize the model and tokenizer"""
476
  global tokenizer, model, image_processor, context_len, args
477
+
478
  if not LLAVA_AVAILABLE:
479
+ print("LLaVA modules not available, skipping model initialization")
480
  return False
481
+
482
  try:
483
+ # Set default arguments
484
  class Args:
485
  def __init__(self):
486
  self.model_path = "PULSE-ECG/PULSE-7B"
 
493
  self.load_8bit = False
494
  self.load_4bit = False
495
  self.debug = False
496
+
497
  args = Args()
498
+
499
+ # Load model
500
+ model_path = args.model_path
501
+ model_name = get_model_name_from_path(args.model_path)
502
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
503
+ args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
504
+ )
505
+
506
+ print("### image_processor", image_processor)
507
+ print("### tokenizer", tokenizer)
508
+
509
+ # Move model to GPU if available
510
  if torch.cuda.is_available():
511
  model = model.to(torch.device('cuda'))
512
+ print("Model moved to CUDA")
513
+ else:
514
+ print("CUDA not available, using CPU")
515
+
516
  return True
517
+
518
  except Exception as e:
519
+ print(f"Failed to initialize model: {e}")
520
  return False
521
 
522
+ # Don't initialize model on import - do it lazily
523
  model_initialized = False
524
 
525
+ # Main endpoint function for Hugging Face
526
  def query(payload):
527
+ """Main endpoint function for Hugging Face inference API"""
528
  global model_initialized
529
+
530
+ # Lazy initialization - initialize model on first call
531
  if not model_initialized:
532
+ print("Initializing model on first query...")
533
  model_initialized = initialize_model()
534
  if not model_initialized:
535
+ return {"error": "Model initialization failed"}
536
+
537
  try:
538
+ print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")
539
+
540
+ # Extract prompt with multiple possible keys
541
+ message_text = (payload.get("message") or
542
+ payload.get("query") or
543
+ payload.get("prompt") or
544
+ payload.get("istem") or "")
545
+
546
+ # Extract image with multiple possible keys
547
+ image_input = (payload.get("image") or
548
+ payload.get("image_url") or
549
+ payload.get("img") or None)
550
+
551
+ # Extract generation parameters with fallbacks
552
  temperature = float(payload.get("temperature", 0.05))
553
  top_p = float(payload.get("top_p", 1.0))
554
+ max_output_tokens = int(payload.get("max_output_tokens",
555
+ payload.get("max_new_tokens",
556
+ payload.get("max_tokens", 4096))))
557
  repetition_penalty = float(payload.get("repetition_penalty", 1.0))
558
  conv_mode_override = payload.get("conv_mode", None)
559
+
560
+ if not message_text or not message_text.strip():
561
+ return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
562
+
563
  if not image_input:
564
+ return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"}
565
+
566
+ # Generate response with all parameters
567
+ result = generate_response(
568
+ message_text=message_text,
569
+ image_input=image_input,
570
+ temperature=temperature,
571
+ top_p=top_p,
572
+ max_output_tokens=max_output_tokens,
573
+ repetition_penalty=repetition_penalty,
574
+ conv_mode_override=conv_mode_override
575
+ )
576
+
577
+ return result
578
+
579
  except Exception as e:
580
+ return {"error": f"Query failed: {str(e)}"}
581
 
582
+ # Additional utility endpoints
583
  def health_check():
584
+ """Health check endpoint"""
585
  return {
586
  "status": "healthy",
587
  "model_initialized": model_initialized,
 
589
  "llava_available": LLAVA_AVAILABLE,
590
  "transformers_available": TRANSFORMERS_AVAILABLE,
591
  "cv2_available": CV2_AVAILABLE,
592
+ "lazy_loading": True # Model will be loaded on first query
593
  }
594
 
595
  def get_model_info():
596
+ """Get model information"""
597
  if not model_initialized:
598
+ return {
599
+ "error": "Model not initialized yet",
600
+ "lazy_loading": True,
601
+ "note": "Model will be loaded on first query"
602
+ }
603
+
604
+ return {
605
+ "model_path": args.model_path if args else "Unknown",
606
+ "model_type": "PULSE-7B",
607
+ "cuda_available": torch.cuda.is_available(),
608
+ "device": str(model.device) if model else "Unknown"
609
+ }
610
 
611
+ # Hugging Face EndpointHandler class
612
  class EndpointHandler:
613
+ """Hugging Face endpoint handler class"""
614
+
615
  def __init__(self, model_dir):
616
+ """Initialize the endpoint handler"""
617
  self.model_dir = model_dir
618
+ print(f"EndpointHandler initialized with model_dir: {model_dir}")
619
+
620
  def __call__(self, payload):
621
+ """Main endpoint function - handles Hugging Face payload format"""
622
+ # Hugging Face sends payload in "inputs" wrapper
623
  if "inputs" in payload:
624
+ # Extract the actual payload from inputs wrapper
625
+ actual_payload = payload["inputs"]
626
+ return query(actual_payload)
627
+ else:
628
+ # Direct payload (for backward compatibility)
629
+ return query(payload)
630
+
631
  def health_check(self):
632
+ """Health check endpoint"""
633
  return health_check()
634
+
635
  def get_model_info(self):
636
+ """Get model information"""
637
  return get_model_info()
638
 
639
+ # For backward compatibility and testing
640
  if __name__ == "__main__":
641
+ print("Handler module loaded successfully!")
642
+ print("This handler is now ready for Hugging Face endpoints.")
643
+ print("Use the 'query' function as the main endpoint.")
644
+ print("Or use EndpointHandler class for Hugging Face compatibility.")