CanerDedeoglu committed on
Commit 1b1a09c · verified · 1 Parent(s): 4e102e5

Update handler.py

Files changed (1):
  1. handler.py +112 -317

handler.py CHANGED
@@ -1,3 +1,9 @@
+# -*- coding: utf-8 -*-
+# handler.py — PULSE-7B / LLaVA robust endpoint
+# - Safe decode (empty output fix)
+# - PAD/EOS safety
+# - Hugging Face endpoint compatible
+
 import os
 import datetime
 import torch
@@ -43,7 +49,7 @@ except ImportError as e:

 # Try to import transformers
 try:
-    from transformers import TextStreamer, TextIteratorStreamer
+    from transformers import TextStreamer, TextIteratorStreamer, GenerationConfig
     TRANSFORMERS_AVAILABLE = True
 except ImportError:
     TRANSFORMERS_AVAILABLE = False
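Editor's note: the new GenerationConfig import supports the commit's "PAD/EOS safety" fix, where a missing eos_token_id falls back to the pad id. A minimal offline sketch of that pattern (the token ids below are illustrative, not from the commit; the real ones come from PULSE-7B's tokenizer in initialize_model()):

from transformers import GenerationConfig

# Illustrative token ids, not from the commit.
pad_id = 0
eos_id = 2
cfg = GenerationConfig(
    do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=64,
    repetition_penalty=1.0,
    pad_token_id=pad_id,
    eos_token_id=eos_id if eos_id is not None else pad_id,  # PAD/EOS fallback
)
print(cfg.eos_token_id)  # -> 2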
@@ -75,7 +81,7 @@ external_log_dir = "./logs"
 LOGDIR = external_log_dir
 VOTEDIR = "./votes"

-# Global variables for model and tokenizer
+# Global variables
 tokenizer = None
 model = None
 image_processor = None
@@ -115,7 +121,7 @@ def vote_last_response(state, vote_type, model_selector):

 def is_valid_video_filename(name):
     if not CV2_AVAILABLE:
-        return False  # Video processing disabled
+        return False
     video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
     ext = name.split(".")[-1].lower()
     return ext in video_extensions
@@ -127,8 +133,7 @@ def is_valid_image_filename(name):

 def sample_frames(video_file, num_frames):
     if not CV2_AVAILABLE:
-        raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")
-
+        raise ImportError("cv2 not available")
     video = cv2.VideoCapture(video_file)
     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
     interval = total_frames // num_frames
@@ -144,46 +149,32 @@ def sample_frames(video_file, num_frames):
     return frames

 def load_image(image_file):
-    if image_file.startswith("http") or image_file.startswith("https"):
+    if image_file.startswith("http"):
         response = requests.get(image_file)
         if response.status_code == 200:
             image = Image.open(BytesIO(response.content)).convert("RGB")
         else:
             raise ValueError("Failed to load image from URL")
     else:
-        print("Load image from local file")
-        print(image_file)
         image = Image.open(image_file).convert("RGB")
     return image

 def process_base64_image(base64_string):
-    """Process base64 encoded image string"""
-    try:
-        # Remove data URL prefix if present
-        if base64_string.startswith('data:image'):
-            base64_string = base64_string.split(',')[1]
-
-        # Decode base64 to bytes
-        image_data = base64.b64decode(base64_string)
-
-        # Convert to PIL Image
-        image = Image.open(BytesIO(image_data)).convert("RGB")
-        return image
-    except Exception as e:
-        raise ValueError(f"Failed to process base64 image: {e}")
+    if base64_string.startswith('data:image'):
+        base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    image = Image.open(BytesIO(image_data)).convert("RGB")
+    return image

 def process_image_input(image_input):
-    """Process different types of image input (file path, URL, or base64)"""
     if isinstance(image_input, str):
         if image_input.startswith("http"):
             return load_image(image_input)
         elif os.path.exists(image_input):
             return load_image(image_input)
         else:
-            # Try to process as base64
             return process_base64_image(image_input)
     elif isinstance(image_input, dict) and "image" in image_input:
-        # Handle base64 image from dict
         return process_base64_image(image_input["image"])
     else:
         raise ValueError("Unsupported image input format")
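Editor's note: process_image_input() accepts three input forms — an http(s) URL, a local path, or a base64 string (optionally a data URL, or wrapped in a dict under "image"). A sketch of constructing each form (file paths are illustrative):

import base64

url_input = "https://example.com/ecg.png"     # fetched with requests
path_input = "./samples/ecg.png"              # opened from disk
with open("./samples/ecg.png", "rb") as f:    # any other string is tried as base64
    b64_input = "data:image/png;base64," + base64.b64encode(f.read()).decode()

# Each form resolves to a PIL RGB image via the handler:
# image = process_image_input(b64_input)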
@@ -194,14 +185,9 @@ class InferenceDemo(object):
             raise ImportError("LLaVA modules not available")

         disable_torch_init()
-
         self.tokenizer, self.model, self.image_processor, self.context_len = (
-            tokenizer,
-            model,
-            image_processor,
-            context_len,
+            tokenizer, model, image_processor, context_len
         )
-
         model_name = get_model_name_from_path(model_path)
         if "llama-2" in model_name.lower():
             conv_mode = "llava_llama_2"
@@ -213,13 +199,8 @@ class InferenceDemo(object):
             conv_mode = "qwen_1_5"
         else:
             conv_mode = "llava_v0"
-
         if args.conv_mode is not None and conv_mode != args.conv_mode:
-            print(
-                "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
-                    conv_mode, args.conv_mode, args.conv_mode
-                )
-            )
+            print(f"[WARNING] auto inferred conv_mode={conv_mode}, using {args.conv_mode}")
         else:
             args.conv_mode = conv_mode
         self.conv_mode = conv_mode
@@ -229,14 +210,11 @@ class InferenceDemo(object):
 class ChatSessionManager:
     def __init__(self):
         self.chatbot_instance = None
-
     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
-
     def reset_chatbot(self):
         self.chatbot_instance = None
-
     def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
         if self.chatbot_instance is None:
             self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
@@ -245,242 +223,139 @@ class ChatSessionManager:
 chat_manager = ChatSessionManager()

 def clear_history():
-    """Clear conversation history"""
     if not LLAVA_AVAILABLE:
-        return {"error": "LLaVA modules not available"}
-
+        return {"error": "LLaVA not available"}
     try:
         chatbot_instance = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
-        try:
-            if hasattr(chatbot_instance, 'conv_mode') and chatbot_instance.conv_mode and LLAVA_AVAILABLE:
-                chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
-            else:
-                # Use default conversation template
-                chatbot_instance.conversation = chatbot_instance.conversation.__class__()
-        except Exception as e:
-            print(f"[DEBUG] Failed to reset conversation in clear_history: {e}")
-        return {"status": "success", "message": "Conversation history cleared"}
+        mode = getattr(chatbot_instance, 'conv_mode', None)
+        if mode and LLAVA_AVAILABLE and mode in conv_templates:
+            chatbot_instance.conversation = conv_templates[mode].copy()
+        else:
+            chatbot_instance.conversation = chatbot_instance.conversation.__class__()
+        return {"status": "success", "message": "Conversation cleared"}
     except Exception as e:
         return {"error": f"Failed to clear history: {str(e)}"}

-def add_message(message_text, image_input=None):
-    """Add a message to the conversation"""
-    return {"status": "success", "message": "Message added"}
+def _strip_prefix_relaxed(text: str, prefix: str) -> str:
+    try:
+        if text.startswith(prefix):
+            return text[len(prefix):]
+        t_norm = " ".join(text.split())
+        p_norm = " ".join(prefix.split())
+        if t_norm.startswith(p_norm):
+            idx = text.find(prefix.splitlines()[0]) if prefix.splitlines() else -1
+            if idx >= 0:
+                return text[idx + len(prefix.splitlines()[0]):]
+    except Exception:
+        pass
+    return text

 def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, max_output_tokens=4096, repetition_penalty=1.0, conv_mode_override=None):
-    """Generate response for the given message and image"""
     if not LLAVA_AVAILABLE:
-        return {"error": "LLaVA modules not available"}
-
+        return {"error": "LLaVA not available"}
     try:
         if not message_text or not image_input:
-            return {"error": "Both message text and image are required"}
+            return {"error": "Both message and image required"}

         our_chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
+        image = process_image_input(image_input)

-        # Process image input
-        try:
-            image = process_image_input(image_input)
-        except Exception as e:
-            return {"error": f"Failed to process image: {str(e)}"}
-
-        # Save image for logging
-        all_image_hash = []
-        all_image_path = []
-
-        # Generate hash for the image
         img_byte_arr = BytesIO()
         image.save(img_byte_arr, format='JPEG')
-        img_byte_arr = img_byte_arr.getvalue()
-        image_hash = hashlib.md5(img_byte_arr).hexdigest()
-        all_image_hash.append(image_hash)
-
-        # Save image to logs
+        image_hash = hashlib.md5(img_byte_arr.getvalue()).hexdigest()
         t = datetime.datetime.now()
-        filename = os.path.join(
-            LOGDIR,
-            "serve_images",
-            f"{t.year}-{t.month:02d}-{t.day:02d}",
-            f"{image_hash}.jpg",
-        )
-        all_image_path.append(filename)
-        if not os.path.isfile(filename):
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            print("image save to", filename)
-        image.save(filename)
+        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        image.save(filename)

-        # Process image for model
-        try:
-            print(f"[DEBUG] Processing image for model...")
-            processed_images = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)
-            print(f"[DEBUG] Processed images length: {len(processed_images)}")
-
-            if len(processed_images) == 0:
-                return {"error": "Image processing returned empty list"}
-
-            image_tensor = processed_images[0]
-            image_tensor = image_tensor.half().to(our_chatbot.model.device)
-            image_tensor = image_tensor.unsqueeze(0)
-            print(f"[DEBUG] Image tensor shape: {image_tensor.shape}")
-        except Exception as e:
-            print(f"[DEBUG] Image processing error: {str(e)}")
-            return {"error": f"Image processing failed: {str(e)}"}
+        processed_images = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)
+        image_tensor = processed_images[0].half().to(our_chatbot.model.device).unsqueeze(0)

-        # Prepare conversation - reset for each request to avoid history issues
-        try:
-            if hasattr(our_chatbot, 'conv_mode') and our_chatbot.conv_mode and LLAVA_AVAILABLE:
-                our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
-            else:
-                # Use default conversation template
-                our_chatbot.conversation = our_chatbot.conversation.__class__()
-        except Exception as e:
-            print(f"[DEBUG] Failed to reset conversation: {e}")
-            # Continue with existing conversation
+        if conv_mode_override:
+            our_chatbot.conversation = conv_templates[conv_mode_override].copy()
+        else:
+            our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()

         inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
         our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
         our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
         prompt = our_chatbot.conversation.get_prompt()

-        # Tokenize input
-        input_ids = tokenizer_image_token(
-            prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
-        ).unsqueeze(0).to(our_chatbot.model.device)
+        input_ids = tokenizer_image_token(prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(our_chatbot.model.device)

-        # Set up stopping criteria
-        stop_str = (
-            our_chatbot.conversation.sep
-            if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
-            else our_chatbot.conversation.sep2
-        )
-        keywords = [stop_str]
-        stopping_criteria = KeywordsStoppingCriteria(
-            keywords, our_chatbot.tokenizer, input_ids
-        )
+        stop_str = our_chatbot.conversation.sep if our_chatbot.conversation.sep_style != SeparatorStyle.TWO else our_chatbot.conversation.sep2
+        stopping_criteria = KeywordsStoppingCriteria([stop_str], our_chatbot.tokenizer, input_ids)
+
+        pad_id = our_chatbot.tokenizer.pad_token_id
+        eos_id = our_chatbot.tokenizer.eos_token_id if our_chatbot.tokenizer.eos_token_id is not None else pad_id
+        gen_cfg = GenerationConfig(
+            do_sample=True, temperature=float(temperature), top_p=float(top_p),
+            max_new_tokens=int(max_output_tokens), repetition_penalty=float(repetition_penalty),
+            pad_token_id=pad_id, eos_token_id=eos_id
+        )

-        # Generate response
         with torch.no_grad():
             outputs = our_chatbot.model.generate(
                 inputs=input_ids,
                 images=image_tensor,
-                do_sample=True,
-                temperature=temperature,
-                top_p=top_p,
-                max_new_tokens=max_output_tokens,
-                repetition_penalty=repetition_penalty,
-                use_cache=False,
+                generation_config=gen_cfg,
+                use_cache=True,
                 stopping_criteria=[stopping_criteria],
+                return_dict_in_generate=True
             )

-        # Decode response
-        try:
-            print(f"[DEBUG] Outputs shape: {outputs.shape if hasattr(outputs, 'shape') else 'No shape attr'}")
-            print(f"[DEBUG] Outputs length: {len(outputs) if hasattr(outputs, '__len__') else 'No length'}")
-            print(f"[DEBUG] Input IDs shape: {input_ids.shape}")
-
-            if len(outputs) == 0:
-                return {"error": "Model generated empty output"}
-
-            response = our_chatbot.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
-
-            print(f"[DEBUG] Conversation messages length: {len(our_chatbot.conversation.messages)}")
-            if len(our_chatbot.conversation.messages) > 0:
-                last_message = our_chatbot.conversation.messages[-1]
-                print(f"[DEBUG] Last message: {last_message}")
-                if isinstance(last_message, list) and len(last_message) > 1:
-                    our_chatbot.conversation.messages[-1][-1] = response
-                    print(f"[DEBUG] Response added to conversation")
-                else:
-                    print(f"[DEBUG] Last message format unexpected: {last_message}")
-                    # Add response as new message if format is wrong
-                    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
-            else:
-                print("[DEBUG] No conversation messages found")
-                # Add response as new message
-                our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
-
-            print(f"[DEBUG] Generated response length: {len(response)}")
-        except Exception as e:
-            print(f"[DEBUG] Response decoding error: {str(e)}")
-            return {"error": f"Response decoding failed: {str(e)}"}
+        sequences = outputs.sequences
+        gen_ids = sequences[0]
+        full_text = our_chatbot.tokenizer.decode(gen_ids, skip_special_tokens=True)
+        prompt_text = our_chatbot.tokenizer.decode(input_ids[0], skip_special_tokens=True)

+        if gen_ids.shape[0] > input_ids.shape[1]:
+            response = our_chatbot.tokenizer.decode(gen_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
+        else:
+            response = _strip_prefix_relaxed(full_text, prompt_text).strip()
+            if not response:
+                response = full_text.replace(stop_str, "").strip()

-        # Log conversation
-        history = [(message_text, response)]
-        with open(get_conv_log_filename(), "a") as fout:
-            data = {
-                "type": "chat",
-                "model": "PULSE-7b",
-                "state": history,
-                "images": all_image_hash,
-                "images_path": all_image_path
-            }
-            print("#### conv log", data)
-            fout.write(json.dumps(data) + "\n")
+        our_chatbot.conversation.messages[-1][-1] = response

-        # Upload files to Hugging Face if configured
-        if api and repo_name:
-            try:
-                for upload_img in all_image_path:
-                    api.upload_file(
-                        path_or_fileobj=upload_img,
-                        path_in_repo=upload_img.replace("./logs/", ""),
-                        repo_id=repo_name,
-                        repo_type="dataset",
-                    )
-
-                # Upload conversation log
-                api.upload_file(
-                    path_or_fileobj=get_conv_log_filename(),
-                    path_in_repo=get_conv_log_filename().replace("./logs/", ""),
-                    repo_id=repo_name,
-                    repo_type="dataset")
-            except Exception as e:
-                print(f"Failed to upload files: {e}")
+        history = [(message_text, response)]
+        with open(get_conv_log_filename(), "a") as fout:
+            fout.write(json.dumps({
+                "type": "chat", "model": "PULSE-7b", "state": history,
+                "images": [image_hash], "images_path": [filename]
+            }) + "\n")

-        return {
-            "status": "success",
-            "response": response,
-            "conversation_id": id(our_chatbot.conversation)
-        }
-
+        return {"status": "success", "response": response, "conversation_id": id(our_chatbot.conversation)}
     except Exception as e:
         return {"error": f"Generation failed: {str(e)}"}

 def upvote_last_response(conversation_id):
-    """Upvote the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
-        return {"status": "success", "message": "Thank you for your voting!"}
+        return {"status": "success", "message": "Upvoted"}
     except Exception as e:
-        return {"error": f"Failed to upvote: {str(e)}"}
+        return {"error": str(e)}

 def downvote_last_response(conversation_id):
-    """Downvote the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
-        return {"status": "success", "message": "Thank you for your voting!"}
+        return {"status": "success", "message": "Downvoted"}
     except Exception as e:
-        return {"error": f"Failed to downvote: {str(e)}"}
+        return {"error": str(e)}

 def flag_response(conversation_id):
-    """Flag the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
-        return {"status": "success", "message": "Response flagged successfully"}
+        return {"status": "success", "message": "Flagged"}
     except Exception as e:
-        return {"error": f"Failed to flag response: {str(e)}"}
+        return {"error": str(e)}

-# Initialize model when module is imported
 def initialize_model():
-    """Initialize the model and tokenizer"""
     global tokenizer, model, image_processor, context_len, args
-
     if not LLAVA_AVAILABLE:
-        print("LLaVA modules not available, skipping model initialization")
+        print("LLaVA not available")
         return False
-
     try:
-        # Set default arguments
         class Args:
             def __init__(self):
                 self.model_path = "PULSE-ECG/PULSE-7B"
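Editor's note: the "safe decode" fix above is the core of this commit. When generate() returns at least as many ids as the prompt, the handler slices off the prompt ids; otherwise it decodes the full sequence and strips the prompt text with _strip_prefix_relaxed(). A standalone sketch of that fallback, with made-up strings:

def strip_prefix_relaxed(text: str, prefix: str) -> str:
    # Mirrors _strip_prefix_relaxed above: exact prefix first,
    # then a whitespace-normalized match as a fallback.
    if text.startswith(prefix):
        return text[len(prefix):]
    if " ".join(text.split()).startswith(" ".join(prefix.split())):
        first = prefix.splitlines()[0] if prefix.splitlines() else ""
        idx = text.find(first)
        if idx >= 0:
            return text[idx + len(first):]
    return text

full_text = "USER: what does this ECG show? ASSISTANT: Sinus rhythm."
prompt_text = "USER: what does this ECG show? ASSISTANT:"
print(strip_prefix_relaxed(full_text, prompt_text).strip())  # -> "Sinus rhythm."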
@@ -493,95 +368,45 @@ def initialize_model():
                 self.load_8bit = False
                 self.load_4bit = False
                 self.debug = False
-
         args = Args()
-
-        # Load model
-        model_path = args.model_path
-        model_name = get_model_name_from_path(args.model_path)
-        tokenizer, model, image_processor, context_len = load_pretrained_model(
-            args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
-        )
-
-        print("### image_processor", image_processor)
-        print("### tokenizer", tokenizer)
-
-        # Move model to GPU if available
+        tok, mdl, img_proc, ctx_len = load_pretrained_model(args.model_path, args.model_base, get_model_name_from_path(args.model_path), args.load_8bit, args.load_4bit)
+        if tok.eos_token_id is None:
+            tok.add_special_tokens({"eos_token": "</s>"})
+        if tok.pad_token_id is None:
+            tok.pad_token = tok.eos_token
+        tokenizer, model, image_processor, context_len = tok, mdl, img_proc, ctx_len
         if torch.cuda.is_available():
             model = model.to(torch.device('cuda'))
-            print("Model moved to CUDA")
-        else:
-            print("CUDA not available, using CPU")
-
         return True
-
     except Exception as e:
-        print(f"Failed to initialize model: {e}")
+        print(f"Init model fail: {e}")
         return False

-# Don't initialize model on import - do it lazily
 model_initialized = False

-# Main endpoint function for Hugging Face
 def query(payload):
-    """Main endpoint function for Hugging Face inference API"""
     global model_initialized
-
-    # Lazy initialization - initialize model on first call
     if not model_initialized:
-        print("Initializing model on first query...")
         model_initialized = initialize_model()
         if not model_initialized:
-            return {"error": "Model initialization failed"}
-
+            return {"error": "Model init failed"}
     try:
-        print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")
-
-        # Extract prompt with multiple possible keys
-        message_text = (payload.get("message") or
-                        payload.get("query") or
-                        payload.get("prompt") or
-                        payload.get("istem") or "")
-
-        # Extract image with multiple possible keys
-        image_input = (payload.get("image") or
-                       payload.get("image_url") or
-                       payload.get("img") or None)
-
-        # Extract generation parameters with fallbacks
+        message_text = payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or ""
+        image_input = payload.get("image") or payload.get("image_url") or payload.get("img") or None
         temperature = float(payload.get("temperature", 0.05))
         top_p = float(payload.get("top_p", 1.0))
-        max_output_tokens = int(payload.get("max_output_tokens",
-                                payload.get("max_new_tokens",
-                                payload.get("max_tokens", 4096))))
+        max_output_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
         repetition_penalty = float(payload.get("repetition_penalty", 1.0))
         conv_mode_override = payload.get("conv_mode", None)
-
-        if not message_text or not message_text.strip():
-            return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
-
+        if not message_text.strip():
+            return {"error": "Missing prompt text"}
         if not image_input:
-            return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"}
-
-        # Generate response with all parameters
-        result = generate_response(
-            message_text=message_text,
-            image_input=image_input,
-            temperature=temperature,
-            top_p=top_p,
-            max_output_tokens=max_output_tokens,
-            repetition_penalty=repetition_penalty,
-            conv_mode_override=conv_mode_override
-        )
-
-        return result
-
+            return {"error": "Missing image"}
+        return generate_response(message_text, image_input, temperature, top_p, max_output_tokens, repetition_penalty, conv_mode_override)
     except Exception as e:
-        return {"error": f"Query failed: {str(e)}"}
+        return {"error": str(e)}

-# Additional utility endpoints
 def health_check():
-    """Health check endpoint"""
     return {
         "status": "healthy",
         "model_initialized": model_initialized,
@@ -589,56 +414,26 @@ def health_check():
         "llava_available": LLAVA_AVAILABLE,
         "transformers_available": TRANSFORMERS_AVAILABLE,
         "cv2_available": CV2_AVAILABLE,
-        "lazy_loading": True  # Model will be loaded on first query
+        "lazy_loading": True
     }

 def get_model_info():
-    """Get model information"""
     if not model_initialized:
-        return {
-            "error": "Model not initialized yet",
-            "lazy_loading": True,
-            "note": "Model will be loaded on first query"
-        }
-
-    return {
-        "model_path": args.model_path if args else "Unknown",
-        "model_type": "PULSE-7B",
-        "cuda_available": torch.cuda.is_available(),
-        "device": str(model.device) if model else "Unknown"
-    }
+        return {"error": "Not initialized", "lazy_loading": True}
+    return {"model_path": args.model_path if args else "Unknown", "model_type": "PULSE-7B", "cuda_available": torch.cuda.is_available(), "device": str(model.device) if model else "Unknown"}

-# Hugging Face EndpointHandler class
 class EndpointHandler:
-    """Hugging Face endpoint handler class"""
-
     def __init__(self, model_dir):
-        """Initialize the endpoint handler"""
         self.model_dir = model_dir
-        print(f"EndpointHandler initialized with model_dir: {model_dir}")
-
+        print(f"Handler init with model_dir={model_dir}")
     def __call__(self, payload):
-        """Main endpoint function - handles Hugging Face payload format"""
-        # Hugging Face sends payload in "inputs" wrapper
         if "inputs" in payload:
-            # Extract the actual payload from inputs wrapper
-            actual_payload = payload["inputs"]
-            return query(actual_payload)
-        else:
-            # Direct payload (for backward compatibility)
-            return query(payload)
-
+            return query(payload["inputs"])
+        return query(payload)
     def health_check(self):
-        """Health check endpoint"""
         return health_check()
-
     def get_model_info(self):
-        """Get model information"""
         return get_model_info()

-# For backward compatibility and testing
 if __name__ == "__main__":
-    print("Handler module loaded successfully!")
-    print("This handler is now ready for Hugging Face endpoints.")
-    print("Use the 'query' function as the main endpoint.")
-    print("Or use EndpointHandler class for Hugging Face compatibility.")
+    print("Handler loaded and ready.")
 