Files changed (1)
  1. handler.py +228 -349
handler.py CHANGED
@@ -1,3 +1,6 @@
 import os
 import datetime
 import torch
@@ -9,7 +12,7 @@ import requests
 from PIL import Image
 from io import BytesIO
 
-# Try to import cv2, but make it optional
 try:
     import cv2
     CV2_AVAILABLE = True
@@ -17,7 +20,7 @@ except ImportError:
     CV2_AVAILABLE = False
     print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")
 
-# Try to import llava modules, but make them optional
 try:
     from llava import conversation as conversation_lib
     from llava.constants import DEFAULT_IMAGE_TOKEN
@@ -41,15 +44,15 @@ except ImportError as e:
     LLAVA_AVAILABLE = False
     print(f"Warning: LLaVA modules not available: {e}")
 
-# Try to import transformers
 try:
-    from transformers import TextStreamer, TextIteratorStreamer
     TRANSFORMERS_AVAILABLE = True
 except ImportError:
     TRANSFORMERS_AVAILABLE = False
     print("Warning: Transformers not available")
 
-# Try to import huggingface_hub
 try:
     from huggingface_hub import HfApi, login
     HF_HUB_AVAILABLE = True
@@ -57,7 +60,7 @@ except ImportError:
     HF_HUB_AVAILABLE = False
     print("Warning: Hugging Face Hub not available")
 
-# Initialize Hugging Face API
 if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
     try:
         login(token=os.environ["HF_TOKEN"], write_permission=True)
@@ -71,21 +74,23 @@ else:
     api = None
     repo_name = ""
 
 external_log_dir = "./logs"
 LOGDIR = external_log_dir
 VOTEDIR = "./votes"
 
-# Global variables for model and tokenizer
 tokenizer = None
 model = None
 image_processor = None
 context_len = None
 args = None
 
 def get_conv_log_filename():
     t = datetime.datetime.now()
-    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
-    return name
 
 def get_conv_vote_filename():
     t = datetime.datetime.now()
@@ -98,13 +103,7 @@ def vote_last_response(state, vote_type, model_selector):
     if api and repo_name:
         try:
             with open(get_conv_vote_filename(), "a") as fout:
-                data = {
-                    "type": vote_type,
-                    "model": model_selector,
-                    "state": state,
-                }
-                fout.write(json.dumps(data) + "\n")
-
             api.upload_file(
                 path_or_fileobj=get_conv_vote_filename(),
                 path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
@@ -115,93 +114,48 @@ def vote_last_response(state, vote_type, model_selector):
 
 def is_valid_video_filename(name):
     if not CV2_AVAILABLE:
-        return False  # Video processing disabled
-    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
-    ext = name.split(".")[-1].lower()
-    return ext in video_extensions
 
 def is_valid_image_filename(name):
-    image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
-    ext = name.split(".")[-1].lower()
-    return ext in image_extensions
-
-def sample_frames(video_file, num_frames):
-    if not CV2_AVAILABLE:
-        raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")
-
-    video = cv2.VideoCapture(video_file)
-    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
-    interval = total_frames // num_frames
-    frames = []
-    for i in range(total_frames):
-        ret, frame = video.read()
-        if not ret:
-            continue
-        if i % interval == 0:
-            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-            frames.append(pil_img)
-    video.release()
-    return frames
 
 def load_image(image_file):
-    if image_file.startswith("http") or image_file.startswith("https"):
-        response = requests.get(image_file)
-        if response.status_code == 200:
-            image = Image.open(BytesIO(response.content)).convert("RGB")
-        else:
-            raise ValueError("Failed to load image from URL")
-    else:
-        print("Load image from local file")
-        print(image_file)
-        image = Image.open(image_file).convert("RGB")
-    return image
 
 def process_base64_image(base64_string):
-    """Process base64 encoded image string"""
-    try:
-        # Remove data URL prefix if present
-        if base64_string.startswith('data:image'):
-            base64_string = base64_string.split(',')[1]
-
-        # Decode base64 to bytes
-        image_data = base64.b64decode(base64_string)
-
-        # Convert to PIL Image
-        image = Image.open(BytesIO(image_data)).convert("RGB")
-        return image
-    except Exception as e:
-        raise ValueError(f"Failed to process base64 image: {e}")
 
 def process_image_input(image_input):
-    """Process different types of image input (file path, URL, or base64)"""
     if isinstance(image_input, str):
         if image_input.startswith("http"):
             return load_image(image_input)
         elif os.path.exists(image_input):
             return load_image(image_input)
         else:
-            # Try to process as base64
             return process_base64_image(image_input)
     elif isinstance(image_input, dict) and "image" in image_input:
-        # Handle base64 image from dict
         return process_base64_image(image_input["image"])
     else:
         raise ValueError("Unsupported image input format")
 
 class InferenceDemo(object):
     def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
         if not LLAVA_AVAILABLE:
             raise ImportError("LLaVA modules not available")
-
         disable_torch_init()
-
         self.tokenizer, self.model, self.image_processor, self.context_len = (
-            tokenizer,
-            model,
-            image_processor,
-            context_len,
         )
-
         model_name = get_model_name_from_path(model_path)
         if "llama-2" in model_name.lower():
             conv_mode = "llava_llama_2"
@@ -213,30 +167,22 @@ class InferenceDemo(object):
             conv_mode = "qwen_1_5"
         else:
             conv_mode = "llava_v0"
-
         if args.conv_mode is not None and conv_mode != args.conv_mode:
-            print(
-                "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
-                    conv_mode, args.conv_mode, args.conv_mode
-                )
-            )
         else:
             args.conv_mode = conv_mode
-        self.conv_mode = conv_mode
-        self.conversation = conv_templates[args.conv_mode].copy()
         self.num_frames = args.num_frames
 
 class ChatSessionManager:
     def __init__(self):
         self.chatbot_instance = None
-
     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
-
     def reset_chatbot(self):
         self.chatbot_instance = None
-
     def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
         if self.chatbot_instance is None:
             self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
@@ -245,237 +191,196 @@ class ChatSessionManager:
 chat_manager = ChatSessionManager()
 
 def clear_history():
-    """Clear conversation history"""
     if not LLAVA_AVAILABLE:
         return {"error": "LLaVA modules not available"}
-
     try:
-        chatbot_instance = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
-        try:
-            if hasattr(chatbot_instance, 'conv_mode') and chatbot_instance.conv_mode and LLAVA_AVAILABLE:
-                chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
-            else:
-                # Use default conversation template
-                chatbot_instance.conversation = chatbot_instance.conversation.__class__()
-        except Exception as e:
-            print(f"[DEBUG] Failed to reset conversation in clear_history: {e}")
         return {"status": "success", "message": "Conversation history cleared"}
     except Exception as e:
         return {"error": f"Failed to clear history: {str(e)}"}
 
-def add_message(message_text, image_input=None):
-    """Add a message to the conversation"""
-    return {"status": "success", "message": "Message added"}
-
-def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, max_output_tokens=4096, repetition_penalty=1.0, conv_mode_override=None):
-    """Generate response for the given message and image"""
     if not LLAVA_AVAILABLE:
         return {"error": "LLaVA modules not available"}
-
     try:
         if not message_text or not image_input:
             return {"error": "Both message text and image are required"}
-
-        our_chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
-
-        # Process image input
-        try:
-            image = process_image_input(image_input)
-        except Exception as e:
-            return {"error": f"Failed to process image: {str(e)}"}
-
-        # Save image for logging
-        all_image_hash = []
-        all_image_path = []
-
-        # Generate hash for the image
         img_byte_arr = BytesIO()
         image.save(img_byte_arr, format='JPEG')
-        img_byte_arr = img_byte_arr.getvalue()
-        image_hash = hashlib.md5(img_byte_arr).hexdigest()
-        all_image_hash.append(image_hash)
-
         # Save image to logs
         t = datetime.datetime.now()
-        filename = os.path.join(
-            LOGDIR,
-            "serve_images",
-            f"{t.year}-{t.month:02d}-{t.day:02d}",
-            f"{image_hash}.jpg",
-        )
-        all_image_path.append(filename)
-        if not os.path.isfile(filename):
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            print("image save to", filename)
-            image.save(filename)
-
-        # Process image for model
-        try:
-            print(f"[DEBUG] Processing image for model...")
-            processed_images = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)
-            print(f"[DEBUG] Processed images length: {len(processed_images)}")
-
-            if len(processed_images) == 0:
-                return {"error": "Image processing returned empty list"}
-
-            image_tensor = processed_images[0]
-            image_tensor = image_tensor.half().to(our_chatbot.model.device)
-            image_tensor = image_tensor.unsqueeze(0)
-            print(f"[DEBUG] Image tensor shape: {image_tensor.shape}")
-        except Exception as e:
-            print(f"[DEBUG] Image processing error: {str(e)}")
-            return {"error": f"Image processing failed: {str(e)}"}
-
-        # Prepare conversation - reset for each request to avoid history issues
-        try:
-            if hasattr(our_chatbot, 'conv_mode') and our_chatbot.conv_mode and LLAVA_AVAILABLE:
-                our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
-            else:
-                # Use default conversation template
-                our_chatbot.conversation = our_chatbot.conversation.__class__()
-        except Exception as e:
-            print(f"[DEBUG] Failed to reset conversation: {e}")
-            # Continue with existing conversation
-
         inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
-        our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
-        our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
-        prompt = our_chatbot.conversation.get_prompt()
-
-        # Tokenize input
-        input_ids = tokenizer_image_token(
-            prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
-        ).unsqueeze(0).to(our_chatbot.model.device)
-
-        # No stopping criteria - let model generate freely up to max_new_tokens
-        print(f"[DEBUG] No stopping criteria - free generation up to {max_output_tokens} tokens")
         stopping_criteria = None
-
-        # Generate response
         with torch.no_grad():
-            outputs = our_chatbot.model.generate(
                 inputs=input_ids,
                 images=image_tensor,
-                do_sample=False,
-                temperature=temperature,
-                top_p=top_p,
-                max_new_tokens=max_output_tokens,
-                repetition_penalty=repetition_penalty,
-                use_cache=False,
-                pad_token_id=our_chatbot.tokenizer.eos_token_id,
-                eos_token_id=our_chatbot.tokenizer.eos_token_id,
-                length_penalty=1.0,  # Don't penalize longer sequences
             )
-
-        # Decode response
-        try:
-            print(f"[DEBUG] Outputs shape: {outputs.shape if hasattr(outputs, 'shape') else 'No shape attr'}")
-            print(f"[DEBUG] Outputs length: {len(outputs) if hasattr(outputs, '__len__') else 'No length'}")
-            print(f"[DEBUG] Input IDs shape: {input_ids.shape}")
-
-            if len(outputs) == 0:
-                return {"error": "Model generated empty output"}
-
-            response = our_chatbot.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
-
-            print(f"[DEBUG] Conversation messages length: {len(our_chatbot.conversation.messages)}")
-            if len(our_chatbot.conversation.messages) > 0:
-                last_message = our_chatbot.conversation.messages[-1]
-                print(f"[DEBUG] Last message: {last_message}")
-                if isinstance(last_message, list) and len(last_message) > 1:
-                    our_chatbot.conversation.messages[-1][-1] = response
-                    print(f"[DEBUG] Response added to conversation")
-                else:
-                    print(f"[DEBUG] Last message format unexpected: {last_message}")
-                    # Add response as new message if format is wrong
-                    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
-            else:
-                print("[DEBUG] No conversation messages found")
-                # Add response as new message
-                our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
-
-            print(f"[DEBUG] Generated response length: {len(response)}")
-        except Exception as e:
-            print(f"[DEBUG] Response decoding error: {str(e)}")
-            return {"error": f"Response decoding failed: {str(e)}"}
-
-        # Log conversation
-        history = [(message_text, response)]
         with open(get_conv_log_filename(), "a") as fout:
-            data = {
                 "type": "chat",
                 "model": "PULSE-7b",
-                "state": history,
-                "images": all_image_hash,
-                "images_path": all_image_path
-            }
-            print("#### conv log", data)
-            fout.write(json.dumps(data) + "\n")
-
-        # Upload files to Hugging Face if configured
-        if api and repo_name:
-            try:
-                for upload_img in all_image_path:
-                    api.upload_file(
-                        path_or_fileobj=upload_img,
-                        path_in_repo=upload_img.replace("./logs/", ""),
-                        repo_id=repo_name,
-                        repo_type="dataset",
-                    )
-
-                # Upload conversation log
-                api.upload_file(
-                    path_or_fileobj=get_conv_log_filename(),
-                    path_in_repo=get_conv_log_filename().replace("./logs/", ""),
-                    repo_id=repo_name,
-                    repo_type="dataset")
-            except Exception as e:
-                print(f"Failed to upload files: {e}")
-
-        return {
-            "status": "success",
-            "response": response,
-            "conversation_id": id(our_chatbot.conversation)
-        }
-
     except Exception as e:
         return {"error": f"Generation failed: {str(e)}"}
 
 def upvote_last_response(conversation_id):
-    """Upvote the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
-        return {"status": "success", "message": "Thank you for your voting!"}
     except Exception as e:
-        return {"error": f"Failed to upvote: {str(e)}"}
 
 def downvote_last_response(conversation_id):
-    """Downvote the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
-        return {"status": "success", "message": "Thank you for your voting!"}
     except Exception as e:
-        return {"error": f"Failed to downvote: {str(e)}"}
 
 def flag_response(conversation_id):
-    """Flag the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
-        return {"status": "success", "message": "Response flagged successfully"}
     except Exception as e:
-        return {"error": f"Failed to flag response: {str(e)}"}
 
-# Initialize model when module is imported
 def initialize_model():
-    """Initialize the model and tokenizer"""
     global tokenizer, model, image_processor, context_len, args
-
     if not LLAVA_AVAILABLE:
         print("LLaVA modules not available, skipping model initialization")
         return False
-
     try:
-        # Set default arguments
         class Args:
             def __init__(self):
                 self.model_path = "PULSE-ECG/PULSE-7B"
@@ -488,95 +393,93 @@ def initialize_model():
                 self.load_8bit = False
                 self.load_4bit = False
                 self.debug = False
-
         args = Args()
-
-        # Load model
-        model_path = args.model_path
         model_name = get_model_name_from_path(args.model_path)
-        tokenizer, model, image_processor, context_len = load_pretrained_model(
             args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
         )
-
-        print("### image_processor", image_processor)
-        print("### tokenizer", tokenizer)
-
-        # Move model to GPU if available
         if torch.cuda.is_available():
             model = model.to(torch.device('cuda'))
             print("Model moved to CUDA")
         else:
             print("CUDA not available, using CPU")
-
         return True
-
     except Exception as e:
         print(f"Failed to initialize model: {e}")
         return False
 
-# Don't initialize model on import - do it lazily
-model_initialized = False
-
-# Main endpoint function for Hugging Face
 def query(payload):
-    """Main endpoint function for Hugging Face inference API"""
     global model_initialized
-
-    # Lazy initialization - initialize model on first call
     if not model_initialized:
         print("Initializing model on first query...")
         model_initialized = initialize_model()
         if not model_initialized:
             return {"error": "Model initialization failed"}
-
     try:
         print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")
-
-        # Extract prompt with multiple possible keys
-        message_text = (payload.get("message") or
-                        payload.get("query") or
-                        payload.get("prompt") or
-                        payload.get("istem") or "")
-
-        # Extract image with multiple possible keys
-        image_input = (payload.get("image") or
-                       payload.get("image_url") or
-                       payload.get("img") or None)
-
-        # Extract generation parameters with fallbacks
         temperature = float(payload.get("temperature", 0.05))
         top_p = float(payload.get("top_p", 1.0))
-        max_output_tokens = int(payload.get("max_output_tokens",
-                                payload.get("max_new_tokens",
-                                payload.get("max_tokens", 8192))))
         repetition_penalty = float(payload.get("repetition_penalty", 1.0))
         conv_mode_override = payload.get("conv_mode", None)
-
-        if not message_text or not message_text.strip():
-            return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
-
         if not image_input:
-            return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"}
-
-        # Generate response with all parameters
-        result = generate_response(
             message_text=message_text,
             image_input=image_input,
             temperature=temperature,
             top_p=top_p,
             max_output_tokens=max_output_tokens,
             repetition_penalty=repetition_penalty,
-            conv_mode_override=conv_mode_override
         )
-
-        return result
-
     except Exception as e:
         return {"error": f"Query failed: {str(e)}"}
 
-# Additional utility endpoints
 def health_check():
-    """Health check endpoint"""
     return {
         "status": "healthy",
         "model_initialized": model_initialized,
@@ -584,18 +487,12 @@ def health_check():
         "llava_available": LLAVA_AVAILABLE,
         "transformers_available": TRANSFORMERS_AVAILABLE,
         "cv2_available": CV2_AVAILABLE,
-        "lazy_loading": True  # Model will be loaded on first query
     }
 
 def get_model_info():
-    """Get model information"""
     if not model_initialized:
-        return {
-            "error": "Model not initialized yet",
-            "lazy_loading": True,
-            "note": "Model will be loaded on first query"
-        }
-
     return {
         "model_path": args.model_path if args else "Unknown",
         "model_type": "PULSE-7B",
@@ -603,37 +500,19 @@ def get_model_info():
         "device": str(model.device) if model else "Unknown"
     }
 
-# Hugging Face EndpointHandler class
 class EndpointHandler:
-    """Hugging Face endpoint handler class"""
-
     def __init__(self, model_dir):
-        """Initialize the endpoint handler"""
         self.model_dir = model_dir
         print(f"EndpointHandler initialized with model_dir: {model_dir}")
-
     def __call__(self, payload):
-        """Main endpoint function - handles Hugging Face payload format"""
-        # Hugging Face sends payload in "inputs" wrapper
        if "inputs" in payload:
-            # Extract the actual payload from inputs wrapper
-            actual_payload = payload["inputs"]
-            return query(actual_payload)
-        else:
-            # Direct payload (for backward compatibility)
-            return query(payload)
-
     def health_check(self):
-        """Health check endpoint"""
         return health_check()
-
     def get_model_info(self):
-        """Get model information"""
         return get_model_info()
 
-# For backward compatibility and testing
 if __name__ == "__main__":
-    print("Handler module loaded successfully!")
-    print("This handler is now ready for Hugging Face endpoints.")
-    print("Use the 'query' function as the main endpoint.")
-    print("Or use EndpointHandler class for Hugging Face compatibility.")
@@ -1,3 +1,6 @@
+# -*- coding: utf-8 -*-
+# handler.py — PULSE-7B / LLaVA endpoint (robust + deterministic-ready)
+
 import os
 import datetime
 import torch
@@ -9,7 +12,7 @@ import requests
 from PIL import Image
 from io import BytesIO
 
+# Optional cv2
 try:
     import cv2
     CV2_AVAILABLE = True
@@ -17,7 +20,7 @@ except ImportError:
     CV2_AVAILABLE = False
     print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")
 
+# LLaVA stack
 try:
     from llava import conversation as conversation_lib
     from llava.constants import DEFAULT_IMAGE_TOKEN
@@ -41,15 +44,15 @@ except ImportError as e:
     LLAVA_AVAILABLE = False
     print(f"Warning: LLaVA modules not available: {e}")
 
+# Transformers
 try:
+    from transformers import GenerationConfig
     TRANSFORMERS_AVAILABLE = True
 except ImportError:
     TRANSFORMERS_AVAILABLE = False
     print("Warning: Transformers not available")
 
+# HF Hub (optional)
 try:
     from huggingface_hub import HfApi, login
     HF_HUB_AVAILABLE = True
@@ -57,7 +60,7 @@ except ImportError:
     HF_HUB_AVAILABLE = False
     print("Warning: Hugging Face Hub not available")
 
+# HF Hub init (optional)
 if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
     try:
         login(token=os.environ["HF_TOKEN"], write_permission=True)
@@ -71,21 +74,23 @@ else:
     api = None
     repo_name = ""
 
+# Logs
 external_log_dir = "./logs"
 LOGDIR = external_log_dir
 VOTEDIR = "./votes"
 
+# Globals
 tokenizer = None
 model = None
 image_processor = None
 context_len = None
 args = None
+model_initialized = False
 
+# ----- Utils -----
 def get_conv_log_filename():
     t = datetime.datetime.now()
+    return os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
 
 def get_conv_vote_filename():
     t = datetime.datetime.now()
@@ -98,13 +103,7 @@ def vote_last_response(state, vote_type, model_selector):
     if api and repo_name:
         try:
             with open(get_conv_vote_filename(), "a") as fout:
+                fout.write(json.dumps({"type": vote_type, "model": model_selector, "state": state}) + "\n")
             api.upload_file(
                 path_or_fileobj=get_conv_vote_filename(),
                 path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
@@ -115,93 +114,48 @@ def vote_last_response(state, vote_type, model_selector):
 
 def is_valid_video_filename(name):
     if not CV2_AVAILABLE:
+        return False
+    return name.split(".")[-1].lower() in ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
 
 def is_valid_image_filename(name):
+    return name.split(".")[-1].lower() in ["jpg","jpeg","png","bmp","gif","tiff","webp","heic","heif","jfif","svg","eps","raw"]
 
 def load_image(image_file):
+    if image_file.startswith("http"):
+        r = requests.get(image_file)
+        if r.status_code == 200:
+            return Image.open(BytesIO(r.content)).convert("RGB")
+        raise ValueError("Failed to load image from URL")
+    return Image.open(image_file).convert("RGB")
 
 def process_base64_image(base64_string):
+    if base64_string.startswith('data:image'):
+        base64_string = base64_string.split(',')[1]
+    image_data = base64.b64decode(base64_string)
+    return Image.open(BytesIO(image_data)).convert("RGB")
 
 def process_image_input(image_input):
     if isinstance(image_input, str):
         if image_input.startswith("http"):
             return load_image(image_input)
         elif os.path.exists(image_input):
             return load_image(image_input)
         else:
             return process_base64_image(image_input)
     elif isinstance(image_input, dict) and "image" in image_input:
         return process_base64_image(image_input["image"])
     else:
         raise ValueError("Unsupported image input format")
 
+# ----- Chat session -----
 class InferenceDemo(object):
     def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
         if not LLAVA_AVAILABLE:
             raise ImportError("LLaVA modules not available")
         disable_torch_init()
         self.tokenizer, self.model, self.image_processor, self.context_len = (
+            tokenizer, model, image_processor, context_len
         )
         model_name = get_model_name_from_path(model_path)
         if "llama-2" in model_name.lower():
             conv_mode = "llava_llama_2"
@@ -213,30 +167,22 @@ class InferenceDemo(object):
             conv_mode = "qwen_1_5"
         else:
             conv_mode = "llava_v0"
         if args.conv_mode is not None and conv_mode != args.conv_mode:
+            print(f"[WARNING] auto inferred conv_mode={conv_mode}, using {args.conv_mode}")
         else:
             args.conv_mode = conv_mode
+        self.conv_mode = args.conv_mode
+        self.conversation = conv_templates[self.conv_mode].copy()
         self.num_frames = args.num_frames
 
 class ChatSessionManager:
     def __init__(self):
         self.chatbot_instance = None
     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
     def reset_chatbot(self):
         self.chatbot_instance = None
     def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
         if self.chatbot_instance is None:
             self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
@@ -245,237 +191,196 @@ class ChatSessionManager:
 chat_manager = ChatSessionManager()
 
 def clear_history():
     if not LLAVA_AVAILABLE:
         return {"error": "LLaVA modules not available"}
     try:
+        inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
+                                        tokenizer, model, image_processor, context_len)
+        mode = getattr(inst, 'conv_mode', None)
+        if mode and mode in conv_templates:
+            inst.conversation = conv_templates[mode].copy()
+        else:
+            inst.conversation = inst.conversation.__class__()
         return {"status": "success", "message": "Conversation history cleared"}
     except Exception as e:
         return {"error": f"Failed to clear history: {str(e)}"}
 
+# ----- Robust prefix stripper -----
+def _strip_prefix_relaxed(text: str, prefix: str) -> str:
+    try:
+        if text.startswith(prefix):
+            return text[len(prefix):]
+        t_norm = " ".join(text.split())
+        p_norm = " ".join(prefix.split())
+        if t_norm.startswith(p_norm):
+            idx = text.find(prefix.splitlines()[0]) if prefix.splitlines() else -1
+            if idx >= 0:
+                return text[idx + len(prefix.splitlines()[0]):]
+    except Exception:
+        pass
+    return text
+
+# ----- Core generate -----
+def generate_response(message_text,
+                      image_input,
+                      temperature=0.05,
+                      top_p=1.0,
+                      max_output_tokens=1024,
+                      repetition_penalty=1.0,
+                      conv_mode_override=None,
+                      do_sample=False,  # greedy by default -> deterministic
+                      seed=None,
+                      use_stop=True):
     if not LLAVA_AVAILABLE:
         return {"error": "LLaVA modules not available"}
+
     try:
         if not message_text or not image_input:
             return {"error": "Both message text and image are required"}
+
+        # Determinism knobs
+        if seed is not None:
+            try:
+                seed = int(seed)
+                torch.manual_seed(seed)
+                np.random.seed(seed)
+            except Exception:
+                pass
+
+        inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
+                                        tokenizer, model, image_processor, context_len)
+
+        # Image
+        image = process_image_input(image_input)
         img_byte_arr = BytesIO()
         image.save(img_byte_arr, format='JPEG')
+        image_hash = hashlib.md5(img_byte_arr.getvalue()).hexdigest()
+
         # Save image to logs
         t = datetime.datetime.now()
+        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        image.save(filename)
+
+        # Preprocess
+        processed_images = process_images([image], inst.image_processor, inst.model.config)
+        if len(processed_images) == 0:
+            return {"error": "Image processing returned empty list"}
+        image_tensor = processed_images[0].half().to(inst.model.device).unsqueeze(0)
+
+        # Conversation
+        if conv_mode_override:
+            inst.conversation = conv_templates[conv_mode_override].copy()
+        else:
+            inst.conversation = conv_templates[inst.conv_mode].copy()
+
         inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
+        inst.conversation.append_message(inst.conversation.roles[0], inp)
+        inst.conversation.append_message(inst.conversation.roles[1], None)
+        prompt = inst.conversation.get_prompt()
+
+        # Tokenize
+        input_ids = tokenizer_image_token(prompt, inst.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(inst.model.device)
+
+        # Stop criteria
         stopping_criteria = None
+        stop_str = inst.conversation.sep if inst.conversation.sep_style != SeparatorStyle.TWO else inst.conversation.sep2
+        if use_stop:
+            stopping_criteria = KeywordsStoppingCriteria([stop_str], inst.tokenizer, input_ids)
+
+        # PAD/EOS safety
+        pad_id = inst.tokenizer.pad_token_id
+        eos_id = inst.tokenizer.eos_token_id if inst.tokenizer.eos_token_id is not None else pad_id
+        if pad_id is None:
+            # safety net (rare)
+            inst.tokenizer.add_special_tokens({"pad_token": inst.tokenizer.eos_token or "</s>"})
+            pad_id = inst.tokenizer.pad_token_id
+            eos_id = inst.tokenizer.eos_token_id or pad_id
+
+        gen_cfg = GenerationConfig(
+            do_sample=bool(do_sample),
+            temperature=float(temperature),
+            top_p=float(top_p),
+            max_new_tokens=int(max_output_tokens),
+            repetition_penalty=float(repetition_penalty),
+            pad_token_id=pad_id,
+            eos_token_id=eos_id
+        )
+
         with torch.no_grad():
+            outputs = inst.model.generate(
                 inputs=input_ids,
                 images=image_tensor,
+                generation_config=gen_cfg,
+                use_cache=True,
+                stopping_criteria=[stopping_criteria] if stopping_criteria is not None else None,
+                return_dict_in_generate=True
             )
+
+        # Robust decode
+        sequences = outputs.sequences
+        gen_ids = sequences[0]
+        full_text = inst.tokenizer.decode(gen_ids, skip_special_tokens=True)
+        prompt_text = inst.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+
+        if gen_ids.shape[0] > input_ids.shape[1]:
+            response = inst.tokenizer.decode(gen_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
+        else:
+            response = _strip_prefix_relaxed(full_text, prompt_text).strip()
+
+        if not response:
+            response = full_text.replace(stop_str, "").strip()
+
+        # Add to conversation
+        if len(inst.conversation.messages) > 0 and isinstance(inst.conversation.messages[-1], list) and len(inst.conversation.messages[-1]) > 1:
+            inst.conversation.messages[-1][-1] = response
+        else:
+            inst.conversation.append_message(inst.conversation.roles[1], response)
+
+        # Log
         with open(get_conv_log_filename(), "a") as fout:
+            fout.write(json.dumps({
                 "type": "chat",
                 "model": "PULSE-7b",
+                "state": [(message_text, response)],
+                "images": [image_hash],
+                "images_path": [filename]
+            }) + "\n")
+
+        return {"status": "success", "response": response, "conversation_id": id(inst.conversation)}
+
     except Exception as e:
         return {"error": f"Generation failed: {str(e)}"}
 
+# ----- Votes -----
 def upvote_last_response(conversation_id):
     try:
         vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
+        return {"status": "success", "message": "Upvoted"}
     except Exception as e:
+        return {"error": str(e)}
 
 def downvote_last_response(conversation_id):
     try:
         vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
+        return {"status": "success", "message": "Downvoted"}
     except Exception as e:
+        return {"error": str(e)}
 
 def flag_response(conversation_id):
     try:
         vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
+        return {"status": "success", "message": "Flagged"}
     except Exception as e:
+        return {"error": str(e)}
 
378
  def initialize_model():
 
379
  global tokenizer, model, image_processor, context_len, args
 
380
  if not LLAVA_AVAILABLE:
381
  print("LLaVA modules not available, skipping model initialization")
382
  return False
 
383
  try:
 
384
  class Args:
385
  def __init__(self):
386
  self.model_path = "PULSE-ECG/PULSE-7B"
 
393
  self.load_8bit = False
394
  self.load_4bit = False
395
  self.debug = False
 
396
  args = Args()
397
+
 
 
398
  model_name = get_model_name_from_path(args.model_path)
399
+ tok, mdl, img_proc, ctx_len = load_pretrained_model(
400
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
401
  )
402
+
403
+ # PAD/EOS safety
404
+ if tok.eos_token_id is None and tok.eos_token is None:
405
+ try:
406
+ tok.add_special_tokens({"eos_token": "</s>"})
407
+ except Exception:
408
+ pass
409
+ if tok.pad_token_id is None:
410
+ if tok.eos_token is not None:
411
+ tok.pad_token = tok.eos_token
412
+ else:
413
+ if tok.unk_token is None:
414
+ try:
415
+ tok.add_special_tokens({"unk_token": "<unk>"})
416
+ except Exception:
417
+ pass
418
+ tok.pad_token = tok.unk_token or "</s>"
419
+
420
+ tokenizer, model, image_processor, context_len = tok, mdl, img_proc, ctx_len
421
  if torch.cuda.is_available():
422
  model = model.to(torch.device('cuda'))
423
  print("Model moved to CUDA")
424
  else:
425
  print("CUDA not available, using CPU")
 
426
  return True
 
427
  except Exception as e:
428
  print(f"Failed to initialize model: {e}")
429
  return False
430
 
431
+# ----- Query entrypoint -----
 def query(payload):
     global model_initialized
     if not model_initialized:
         print("Initializing model on first query...")
         model_initialized = initialize_model()
         if not model_initialized:
             return {"error": "Model initialization failed"}
+
     try:
+        # Log incoming keys
         print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")
+
+        # Inputs
+        message_text = (payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or "").strip()
+        image_input = (payload.get("image") or payload.get("image_url") or payload.get("img") or None)
+
+        # Gen params
         temperature = float(payload.get("temperature", 0.05))
         top_p = float(payload.get("top_p", 1.0))
+        max_output_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 1024))))
         repetition_penalty = float(payload.get("repetition_penalty", 1.0))
         conv_mode_override = payload.get("conv_mode", None)
+
+        # Determinism toggles
+        do_sample = bool(payload.get("do_sample", False))  # default greedy
+        seed = payload.get("seed", None)
+        use_stop = bool(payload.get("use_stop", True))  # stop criteria enabled by default
+
+        if not message_text:
+            return {"error": "Missing prompt text. Provide 'message' (or 'query'/'prompt'/'istem')."}
         if not image_input:
+            return {"error": "Missing image. Provide 'image' (url/base64/path) or 'image_url'/'img'."}
+
+        return generate_response(
             message_text=message_text,
             image_input=image_input,
             temperature=temperature,
             top_p=top_p,
             max_output_tokens=max_output_tokens,
             repetition_penalty=repetition_penalty,
+            conv_mode_override=conv_mode_override,
+            do_sample=do_sample,
+            seed=seed,
+            use_stop=use_stop
        )
+
     except Exception as e:
         return {"error": f"Query failed: {str(e)}"}
 
482
  def health_check():
 
483
  return {
484
  "status": "healthy",
485
  "model_initialized": model_initialized,
 
487
  "llava_available": LLAVA_AVAILABLE,
488
  "transformers_available": TRANSFORMERS_AVAILABLE,
489
  "cv2_available": CV2_AVAILABLE,
490
+ "lazy_loading": True
491
  }
492
 
493
  def get_model_info():
 
494
  if not model_initialized:
495
+ return {"error": "Model not initialized yet", "lazy_loading": True}
 
 
 
 
 
496
  return {
497
  "model_path": args.model_path if args else "Unknown",
498
  "model_type": "PULSE-7B",
 
500
  "device": str(model.device) if model else "Unknown"
501
  }
502
 
503
+ # ----- HF Endpoint handler -----
504
  class EndpointHandler:
 
 
505
  def __init__(self, model_dir):
 
506
  self.model_dir = model_dir
507
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
 
508
  def __call__(self, payload):
 
 
509
  if "inputs" in payload:
510
+ return query(payload["inputs"])
511
+ return query(payload)
 
 
 
 
 
512
  def health_check(self):
 
513
  return health_check()
 
514
  def get_model_info(self):
 
515
  return get_model_info()
516
 
 
517
  if __name__ == "__main__":
518
+ print("Handler loaded and ready.")