Files changed (1)
  1. handler.py +368 -232
handler.py CHANGED
@@ -1,5 +1,16 @@
- # -*- coding: utf-8 -*-
- # handler.py PULSE-7B / LLaVA endpoint (robust + deterministic-ready)
+ """
+ PULSE ECG Handler - Deterministic ECG Analysis Model
+
+ This handler provides consistent, deterministic responses for ECG analysis.
+ All generation parameters are fixed to ensure reproducible results across
+ different API calls and clients.
+
+ Key Features:
+ - Deterministic generation (do_sample=False)
+ - Fixed random seed for consistency
+ - No temperature/top_p sampling parameters
+ - Consistent response lengths and content
+ """

  import os
  import datetime
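
The determinism the new docstring promises comes from greedy decoding plus a fixed seed rather than anything PULSE-specific. A minimal sketch of the pattern (generic `model`/`tokenizer` stand-ins, not the PULSE-7B loading code below):

    import torch

    def greedy_generate(model, tokenizer, prompt):
        # With do_sample=False, generate() takes the argmax token at each step,
        # so the same input yields the same text on every call; the seed only
        # matters for any remaining stochastic ops.
        torch.manual_seed(42)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        with torch.no_grad():
            out = model.generate(input_ids, do_sample=False, max_new_tokens=64)
        # Slice off the echoed prompt before decoding.
        return tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)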
@@ -12,7 +23,7 @@ import requests
  from PIL import Image
  from io import BytesIO

- # Optional cv2
+ # Try to import cv2, but make it optional
  try:
      import cv2
      CV2_AVAILABLE = True
@@ -20,7 +31,7 @@ except ImportError:
      CV2_AVAILABLE = False
      print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")

- # LLaVA stack
+ # Try to import llava modules, but make them optional
  try:
      from llava import conversation as conversation_lib
      from llava.constants import DEFAULT_IMAGE_TOKEN
@@ -44,15 +55,15 @@ except ImportError as e:
      LLAVA_AVAILABLE = False
      print(f"Warning: LLaVA modules not available: {e}")

- # Transformers
+ # Try to import transformers
  try:
-     from transformers import GenerationConfig
+     from transformers import TextStreamer, TextIteratorStreamer
      TRANSFORMERS_AVAILABLE = True
  except ImportError:
      TRANSFORMERS_AVAILABLE = False
      print("Warning: Transformers not available")

- # HF Hub (optional)
+ # Try to import huggingface_hub
  try:
      from huggingface_hub import HfApi, login
      HF_HUB_AVAILABLE = True
@@ -60,7 +71,7 @@ except ImportError:
      HF_HUB_AVAILABLE = False
      print("Warning: Hugging Face Hub not available")

- # HF Hub init (optional)
+ # Initialize Hugging Face API
  if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
      try:
          login(token=os.environ["HF_TOKEN"], write_permission=True)
@@ -74,23 +85,21 @@ else:
      api = None
      repo_name = ""

- # Logs
  external_log_dir = "./logs"
  LOGDIR = external_log_dir
  VOTEDIR = "./votes"

- # Globals
+ # Global variables for model and tokenizer
  tokenizer = None
  model = None
  image_processor = None
  context_len = None
  args = None
- model_initialized = False

- # ----- Utils -----
  def get_conv_log_filename():
      t = datetime.datetime.now()
-     return os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
+     return name

  def get_conv_vote_filename():
      t = datetime.datetime.now()
@@ -103,7 +112,13 @@ def vote_last_response(state, vote_type, model_selector):
      if api and repo_name:
          try:
              with open(get_conv_vote_filename(), "a") as fout:
-                 fout.write(json.dumps({"type": vote_type, "model": model_selector, "state": state}) + "\n")
+                 data = {
+                     "type": vote_type,
+                     "model": model_selector,
+                     "state": state,
+                 }
+                 fout.write(json.dumps(data) + "\n")
+
              api.upload_file(
                  path_or_fileobj=get_conv_vote_filename(),
                  path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
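
For reference, each vote lands in the day's vote file as a single JSON line before the upload; an illustrative call and the record it appends (the conversation id is a placeholder):

    vote_last_response({"conversation_id": 140234567890}, "upvote", "PULSE-7B")
    # appends: {"type": "upvote", "model": "PULSE-7B", "state": {"conversation_id": 140234567890}}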
@@ -114,48 +129,93 @@

  def is_valid_video_filename(name):
      if not CV2_AVAILABLE:
-         return False
-     return name.split(".")[-1].lower() in ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
+         return False  # Video processing disabled
+     video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
+     ext = name.split(".")[-1].lower()
+     return ext in video_extensions

  def is_valid_image_filename(name):
-     return name.split(".")[-1].lower() in ["jpg","jpeg","png","bmp","gif","tiff","webp","heic","heif","jfif","svg","eps","raw"]
+     image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
+     ext = name.split(".")[-1].lower()
+     return ext in image_extensions
+
+ def sample_frames(video_file, num_frames):
+     if not CV2_AVAILABLE:
+         raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")
+
+     video = cv2.VideoCapture(video_file)
+     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+     interval = total_frames // num_frames
+     frames = []
+     for i in range(total_frames):
+         ret, frame = video.read()
+         if not ret:
+             continue
+         if i % interval == 0:
+             pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+             frames.append(pil_img)
+     video.release()
+     return frames

  def load_image(image_file):
-     if image_file.startswith("http"):
-         r = requests.get(image_file)
-         if r.status_code == 200:
-             return Image.open(BytesIO(r.content)).convert("RGB")
-         raise ValueError("Failed to load image from URL")
-     return Image.open(image_file).convert("RGB")
+     if image_file.startswith("http") or image_file.startswith("https"):
+         response = requests.get(image_file)
+         if response.status_code == 200:
+             image = Image.open(BytesIO(response.content)).convert("RGB")
+         else:
+             raise ValueError("Failed to load image from URL")
+     else:
+         print("Load image from local file")
+         print(image_file)
+         image = Image.open(image_file).convert("RGB")
+     return image

  def process_base64_image(base64_string):
-     if base64_string.startswith('data:image'):
-         base64_string = base64_string.split(',')[1]
-     image_data = base64.b64decode(base64_string)
-     return Image.open(BytesIO(image_data)).convert("RGB")
+     """Process base64 encoded image string"""
+     try:
+         # Remove data URL prefix if present
+         if base64_string.startswith('data:image'):
+             base64_string = base64_string.split(',')[1]
+
+         # Decode base64 to bytes
+         image_data = base64.b64decode(base64_string)
+
+         # Convert to PIL Image
+         image = Image.open(BytesIO(image_data)).convert("RGB")
+         return image
+     except Exception as e:
+         raise ValueError(f"Failed to process base64 image: {e}")

  def process_image_input(image_input):
+     """Process different types of image input (file path, URL, or base64)"""
      if isinstance(image_input, str):
          if image_input.startswith("http"):
              return load_image(image_input)
          elif os.path.exists(image_input):
              return load_image(image_input)
          else:
+             # Try to process as base64
              return process_base64_image(image_input)
      elif isinstance(image_input, dict) and "image" in image_input:
+         # Handle base64 image from dict
          return process_base64_image(image_input["image"])
      else:
          raise ValueError("Unsupported image input format")

- # ----- Chat session -----
  class InferenceDemo(object):
      def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
          if not LLAVA_AVAILABLE:
              raise ImportError("LLaVA modules not available")
+
          disable_torch_init()
+
          self.tokenizer, self.model, self.image_processor, self.context_len = (
-             tokenizer, model, image_processor, context_len
+             tokenizer,
+             model,
+             image_processor,
+             context_len,
          )
+
          model_name = get_model_name_from_path(model_path)
          if "llama-2" in model_name.lower():
              conv_mode = "llava_llama_2"
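
One caveat if the new sample_frames helper gets reused elsewhere: interval comes from integer division, so a clip with fewer frames than num_frames makes i % interval divide by zero. A guarded wrapper, as a sketch:

    def sample_frames_safe(video_file, num_frames):
        video = cv2.VideoCapture(video_file)
        total = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        video.release()
        # Clamp the request so the interval inside sample_frames is never zero.
        return sample_frames(video_file, max(1, min(num_frames, total or 1)))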
@@ -167,22 +227,30 @@ class InferenceDemo(object):
              conv_mode = "qwen_1_5"
          else:
              conv_mode = "llava_v0"
+
          if args.conv_mode is not None and conv_mode != args.conv_mode:
-             print(f"[WARNING] auto inferred conv_mode={conv_mode}, using {args.conv_mode}")
+             print(
+                 "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+                     conv_mode, args.conv_mode, args.conv_mode
+                 )
+             )
          else:
              args.conv_mode = conv_mode
-         self.conv_mode = args.conv_mode
-         self.conversation = conv_templates[self.conv_mode].copy()
+         self.conv_mode = conv_mode
+         self.conversation = conv_templates[args.conv_mode].copy()
          self.num_frames = args.num_frames

  class ChatSessionManager:
      def __init__(self):
          self.chatbot_instance = None
+
      def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
          self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
          print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
+
      def reset_chatbot(self):
          self.chatbot_instance = None
+
      def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
          if self.chatbot_instance is None:
              self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
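
Both sides of this hunk lean on LLaVA's conversation templates; the flow they set up is easier to see in isolation. A sketch using the API imported above (the template name stands in for whatever conv_mode resolved to):

    conv = conv_templates["llava_v1"].copy()   # fresh copy so requests don't share state
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + "Describe this ECG.")
    conv.append_message(conv.roles[1], None)   # empty slot the model's reply will fill
    prompt = conv.get_prompt()                 # roles + separators serialized into one string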
@@ -191,295 +259,339 @@ class ChatSessionManager:
  chat_manager = ChatSessionManager()

  def clear_history():
+     """Clear conversation history"""
      if not LLAVA_AVAILABLE:
          return {"error": "LLaVA modules not available"}
+
      try:
-         inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
-                                         tokenizer, model, image_processor, context_len)
-         mode = getattr(inst, 'conv_mode', None)
-         if mode and mode in conv_templates:
-             inst.conversation = conv_templates[mode].copy()
-         else:
-             inst.conversation = inst.conversation.__class__()
+         chatbot_instance = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
+         try:
+             if hasattr(chatbot_instance, 'conv_mode') and chatbot_instance.conv_mode and LLAVA_AVAILABLE:
+                 chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
+             else:
+                 # Use default conversation template
+                 chatbot_instance.conversation = chatbot_instance.conversation.__class__()
+         except Exception as e:
+             print(f"[DEBUG] Failed to reset conversation in clear_history: {e}")
          return {"status": "success", "message": "Conversation history cleared"}
      except Exception as e:
          return {"error": f"Failed to clear history: {str(e)}"}

- # ----- Robust prefix stripper -----
- def _strip_prefix_relaxed(text: str, prefix: str) -> str:
-     try:
-         if text.startswith(prefix):
-             return text[len(prefix):]
-         t_norm = " ".join(text.split())
-         p_norm = " ".join(prefix.split())
-         if t_norm.startswith(p_norm):
-             idx = text.find(prefix.splitlines()[0]) if prefix.splitlines() else -1
-             if idx >= 0:
-                 return text[idx + len(prefix.splitlines()[0]):]
-     except Exception:
-         pass
-     return text
-
- # ----- Core generate -----
- def generate_response(message_text,
-                       image_input,
-                       temperature=0.05,
-                       top_p=1.0,
-                       max_output_tokens=1024,
-                       repetition_penalty=1.0,
-                       conv_mode_override=None,
-                       do_sample=False,  # default greedy -> deterministic
-                       seed=None,
-                       use_stop=True):
+ def add_message(message_text, image_input=None):
+     """Add a message to the conversation"""
+     return {"status": "success", "message": "Message added"}
+
+ def generate_response(message_text, image_input, max_output_tokens=4096, repetition_penalty=1.0, conv_mode_override=None):
+     """Generate response for the given message and image using deterministic generation for consistency"""
      if not LLAVA_AVAILABLE:
          return {"error": "LLaVA modules not available"}

      try:
          if not message_text or not image_input:
              return {"error": "Both message text and image are required"}
-
-         # Determinism knobs
-         if seed is not None:
-             try:
-                 seed = int(seed)
-                 torch.manual_seed(seed)
-                 np.random.seed(seed)
-             except Exception:
-                 pass
-
-         inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
-                                         tokenizer, model, image_processor, context_len)
-
-         # Image
-         image = process_image_input(image_input)
+
+         our_chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
+
+         # Process image input
+         try:
+             image = process_image_input(image_input)
+         except Exception as e:
+             return {"error": f"Failed to process image: {str(e)}"}
+
+         # Save image for logging
+         all_image_hash = []
+         all_image_path = []
+
+         # Generate hash for the image
          img_byte_arr = BytesIO()
          image.save(img_byte_arr, format='JPEG')
-         image_hash = hashlib.md5(img_byte_arr.getvalue()).hexdigest()
-
+         img_byte_arr = img_byte_arr.getvalue()
+         image_hash = hashlib.md5(img_byte_arr).hexdigest()
+         all_image_hash.append(image_hash)
+
          # Save image to logs
          t = datetime.datetime.now()
-         filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
-         os.makedirs(os.path.dirname(filename), exist_ok=True)
-         image.save(filename)
-
-         # Preprocess
-         processed_images = process_images([image], inst.image_processor, inst.model.config)
-         if len(processed_images) == 0:
-             return {"error": "Image processing returned empty list"}
-         image_tensor = processed_images[0].half().to(inst.model.device).unsqueeze(0)
-
-         # Conversation
-         if conv_mode_override:
-             inst.conversation = conv_templates[conv_mode_override].copy()
-         else:
-             inst.conversation = conv_templates[inst.conv_mode].copy()
-
+         filename = os.path.join(
+             LOGDIR,
+             "serve_images",
+             f"{t.year}-{t.month:02d}-{t.day:02d}",
+             f"{image_hash}.jpg",
+         )
+         all_image_path.append(filename)
+         if not os.path.isfile(filename):
+             os.makedirs(os.path.dirname(filename), exist_ok=True)
+             print("image save to", filename)
+             image.save(filename)
+
+         # Process image for model
+         try:
+             print(f"[DEBUG] Processing image for model...")
+             processed_images = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)
+             print(f"[DEBUG] Processed images length: {len(processed_images)}")
+
+             if len(processed_images) == 0:
+                 return {"error": "Image processing returned empty list"}
+
+             image_tensor = processed_images[0]
+             image_tensor = image_tensor.half().to(our_chatbot.model.device)
+             image_tensor = image_tensor.unsqueeze(0)
+             print(f"[DEBUG] Image tensor shape: {image_tensor.shape}")
+         except Exception as e:
+             print(f"[DEBUG] Image processing error: {str(e)}")
+             return {"error": f"Image processing failed: {str(e)}"}
+
+         # Prepare conversation - reset for each request to avoid history issues
+         try:
+             if hasattr(our_chatbot, 'conv_mode') and our_chatbot.conv_mode and LLAVA_AVAILABLE:
+                 our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
+             else:
+                 # Use default conversation template
+                 our_chatbot.conversation = our_chatbot.conversation.__class__()
+         except Exception as e:
+             print(f"[DEBUG] Failed to reset conversation: {e}")
+             # Continue with existing conversation
+
          inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
-         inst.conversation.append_message(inst.conversation.roles[0], inp)
-         inst.conversation.append_message(inst.conversation.roles[1], None)
-         prompt = inst.conversation.get_prompt()
-
-         # Tokenize
-         input_ids = tokenizer_image_token(prompt, inst.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(inst.model.device)
-
-         # Stop criteria
+         our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
+         our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
+         prompt = our_chatbot.conversation.get_prompt()
+
+         # Tokenize input
+         input_ids = tokenizer_image_token(
+             prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+         ).unsqueeze(0).to(our_chatbot.model.device)
+
+         # No stopping criteria - let model generate freely up to max_new_tokens
+         print(f"[DEBUG] No stopping criteria - free generation up to {max_output_tokens} tokens")
          stopping_criteria = None
-         stop_str = inst.conversation.sep if inst.conversation.sep_style != SeparatorStyle.TWO else inst.conversation.sep2
-         if use_stop:
-             stopping_criteria = KeywordsStoppingCriteria([stop_str], inst.tokenizer, input_ids)
-
-         # PAD/EOS safety
-         pad_id = inst.tokenizer.pad_token_id
-         eos_id = inst.tokenizer.eos_token_id if inst.tokenizer.eos_token_id is not None else pad_id
-         if pad_id is None:
-             # safety net (rare)
-             inst.tokenizer.add_special_tokens({"pad_token": inst.tokenizer.eos_token or "</s>"})
-             pad_id = inst.tokenizer.pad_token_id
-             eos_id = inst.tokenizer.eos_token_id or pad_id
-
-         gen_cfg = GenerationConfig(
-             do_sample=bool(do_sample),
-             temperature=float(temperature),
-             top_p=float(top_p),
-             max_new_tokens=int(max_output_tokens),
-             repetition_penalty=float(repetition_penalty),
-             pad_token_id=pad_id,
-             eos_token_id=eos_id
-         )
-
+
+         # Set seed for deterministic generation
+         # This ensures the same input always produces the same output
+         torch.manual_seed(42)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(42)
+             torch.cuda.manual_seed_all(42)
+
+         # Generate response using deterministic greedy decoding
+         # This eliminates randomness and ensures consistent responses
          with torch.no_grad():
-             outputs = inst.model.generate(
+             outputs = our_chatbot.model.generate(
                  inputs=input_ids,
                  images=image_tensor,
-                 generation_config=gen_cfg,
-                 use_cache=True,
-                 stopping_criteria=[stopping_criteria] if stopping_criteria is not None else None,
-                 return_dict_in_generate=True
+                 do_sample=False,  # Deterministic generation for consistency
+                 max_new_tokens=max_output_tokens,
+                 repetition_penalty=repetition_penalty,
+                 use_cache=False,
+                 pad_token_id=our_chatbot.tokenizer.eos_token_id,
+                 eos_token_id=our_chatbot.tokenizer.eos_token_id,
+                 length_penalty=1.0,  # Don't penalize longer sequences
              )
-
-         # Robust decode
-         sequences = outputs.sequences
-         gen_ids = sequences[0]
-         full_text = inst.tokenizer.decode(gen_ids, skip_special_tokens=True)
-         prompt_text = inst.tokenizer.decode(input_ids[0], skip_special_tokens=True)
-
-         if gen_ids.shape[0] > input_ids.shape[1]:
-             response = inst.tokenizer.decode(gen_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
-         else:
-             response = _strip_prefix_relaxed(full_text, prompt_text).strip()
-
-         if not response:
-             response = full_text.replace(stop_str, "").strip()
-
-         # Add to conversation
-         if len(inst.conversation.messages) > 0 and isinstance(inst.conversation.messages[-1], list) and len(inst.conversation.messages[-1]) > 1:
-             inst.conversation.messages[-1][-1] = response
-         else:
-             inst.conversation.append_message(inst.conversation.roles[1], response)
-
-         # Log
+
+         # Decode response
+         try:
+             print(f"[DEBUG] Outputs shape: {outputs.shape if hasattr(outputs, 'shape') else 'No shape attr'}")
+             print(f"[DEBUG] Outputs length: {len(outputs) if hasattr(outputs, '__len__') else 'No length'}")
+             print(f"[DEBUG] Input IDs shape: {input_ids.shape}")
+
+             if len(outputs) == 0:
+                 return {"error": "Model generated empty output"}
+
+             response = our_chatbot.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
+
+             print(f"[DEBUG] Conversation messages length: {len(our_chatbot.conversation.messages)}")
+             if len(our_chatbot.conversation.messages) > 0:
+                 last_message = our_chatbot.conversation.messages[-1]
+                 print(f"[DEBUG] Last message: {last_message}")
+                 if isinstance(last_message, list) and len(last_message) > 1:
+                     our_chatbot.conversation.messages[-1][-1] = response
+                     print(f"[DEBUG] Response added to conversation")
+                 else:
+                     print(f"[DEBUG] Last message format unexpected: {last_message}")
+                     # Add response as new message if format is wrong
+                     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
+             else:
+                 print("[DEBUG] No conversation messages found")
+                 # Add response as new message
+                 our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
+
+             print(f"[DEBUG] Generated response length: {len(response)}")
+         except Exception as e:
+             print(f"[DEBUG] Response decoding error: {str(e)}")
+             return {"error": f"Response decoding failed: {str(e)}"}
+
+         # Log conversation
+         history = [(message_text, response)]
          with open(get_conv_log_filename(), "a") as fout:
-             fout.write(json.dumps({
+             data = {
                  "type": "chat",
                  "model": "PULSE-7b",
-                 "state": [(message_text, response)],
-                 "images": [image_hash],
-                 "images_path": [filename]
-             }) + "\n")
-
-         return {"status": "success", "response": response, "conversation_id": id(inst.conversation)}
-
+                 "state": history,
+                 "images": all_image_hash,
+                 "images_path": all_image_path
+             }
+             print("#### conv log", data)
+             fout.write(json.dumps(data) + "\n")
+
+         # Upload files to Hugging Face if configured
+         if api and repo_name:
+             try:
+                 for upload_img in all_image_path:
+                     api.upload_file(
+                         path_or_fileobj=upload_img,
+                         path_in_repo=upload_img.replace("./logs/", ""),
+                         repo_id=repo_name,
+                         repo_type="dataset",
+                     )
+
+                 # Upload conversation log
+                 api.upload_file(
+                     path_or_fileobj=get_conv_log_filename(),
+                     path_in_repo=get_conv_log_filename().replace("./logs/", ""),
+                     repo_id=repo_name,
+                     repo_type="dataset")
+             except Exception as e:
+                 print(f"Failed to upload files: {e}")
+
+         return {
+             "status": "success",
+             "response": response,
+             "conversation_id": id(our_chatbot.conversation)
+         }
+
      except Exception as e:
          return {"error": f"Generation failed: {str(e)}"}

- # ----- Votes -----
  def upvote_last_response(conversation_id):
+     """Upvote the last response"""
      try:
          vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
-         return {"status": "success", "message": "Upvoted"}
+         return {"status": "success", "message": "Thank you for your vote!"}
      except Exception as e:
-         return {"error": str(e)}
+         return {"error": f"Failed to upvote: {str(e)}"}

  def downvote_last_response(conversation_id):
+     """Downvote the last response"""
      try:
          vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
-         return {"status": "success", "message": "Downvoted"}
+         return {"status": "success", "message": "Thank you for your vote!"}
      except Exception as e:
-         return {"error": str(e)}
+         return {"error": f"Failed to downvote: {str(e)}"}

  def flag_response(conversation_id):
+     """Flag the last response"""
      try:
          vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
-         return {"status": "success", "message": "Flagged"}
+         return {"status": "success", "message": "Response flagged successfully"}
      except Exception as e:
-         return {"error": str(e)}
+         return {"error": f"Failed to flag response: {str(e)}"}

- # ----- Init model (with PAD/EOS safety) -----
+ # Initialize model when module is imported
  def initialize_model():
+     """Initialize the model and tokenizer"""
      global tokenizer, model, image_processor, context_len, args
+
      if not LLAVA_AVAILABLE:
          print("LLaVA modules not available, skipping model initialization")
          return False
+
      try:
+         # Set default arguments
          class Args:
              def __init__(self):
                  self.model_path = "PULSE-ECG/PULSE-7B"
                  self.model_base = None
                  self.num_gpus = 1
                  self.conv_mode = None
-                 self.temperature = 0.05
                  self.max_new_tokens = 1024
                  self.num_frames = 16
                  self.load_8bit = False
                  self.load_4bit = False
                  self.debug = False
+
          args = Args()
-
+
+         # Load model
+         model_path = args.model_path
          model_name = get_model_name_from_path(args.model_path)
-         tok, mdl, img_proc, ctx_len = load_pretrained_model(
+         tokenizer, model, image_processor, context_len = load_pretrained_model(
              args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
          )
-
-         # PAD/EOS safety
-         if tok.eos_token_id is None and tok.eos_token is None:
-             try:
-                 tok.add_special_tokens({"eos_token": "</s>"})
-             except Exception:
-                 pass
-         if tok.pad_token_id is None:
-             if tok.eos_token is not None:
-                 tok.pad_token = tok.eos_token
-             else:
-                 if tok.unk_token is None:
-                     try:
-                         tok.add_special_tokens({"unk_token": "<unk>"})
-                     except Exception:
-                         pass
-                 tok.pad_token = tok.unk_token or "</s>"
-
-         tokenizer, model, image_processor, context_len = tok, mdl, img_proc, ctx_len
+
+         print("### image_processor", image_processor)
+         print("### tokenizer", tokenizer)
+
+         # Move model to GPU if available
          if torch.cuda.is_available():
              model = model.to(torch.device('cuda'))
              print("Model moved to CUDA")
          else:
              print("CUDA not available, using CPU")
+
          return True
+
      except Exception as e:
          print(f"Failed to initialize model: {e}")
          return False
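
The PAD/EOS block removed above was a defensive variant of a common Hugging Face idiom; the short form, should anyone need to restore it, is (a sketch, not part of this commit):

    # LLaMA-family tokenizers often ship without a pad token; reusing the EOS
    # token is the usual fallback before calling generate() with padding.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token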
 
- # ----- Query entrypoint -----
+ # Don't initialize model on import - do it lazily
+ model_initialized = False
+
+ # Main endpoint function for Hugging Face
  def query(payload):
+     """Main endpoint function for Hugging Face inference API"""
      global model_initialized
+
+     # Lazy initialization - initialize model on first call
      if not model_initialized:
          print("Initializing model on first query...")
          model_initialized = initialize_model()
          if not model_initialized:
              return {"error": "Model initialization failed"}
+
      try:
-         # Log incoming keys
          print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")
-
-         # Inputs
-         message_text = (payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or "").strip()
-         image_input = (payload.get("image") or payload.get("image_url") or payload.get("img") or None)
-
-         # Gen params
-         temperature = float(payload.get("temperature", 0.05))
-         top_p = float(payload.get("top_p", 1.0))
-         max_output_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 1024))))
+
+         # Extract prompt with multiple possible keys
+         message_text = (payload.get("message") or
+                         payload.get("query") or
+                         payload.get("prompt") or
+                         payload.get("istem") or "")
+
+         # Extract image with multiple possible keys
+         image_input = (payload.get("image") or
+                        payload.get("image_url") or
+                        payload.get("img") or None)
+
+         # Extract generation parameters with fallbacks
+         max_output_tokens = int(payload.get("max_output_tokens",
+                                 payload.get("max_new_tokens",
+                                 payload.get("max_tokens", 8192))))
          repetition_penalty = float(payload.get("repetition_penalty", 1.0))
          conv_mode_override = payload.get("conv_mode", None)
-
-         # Determinism toggles
-         do_sample = bool(payload.get("do_sample", False))  # default greedy
-         seed = payload.get("seed", None)
-         use_stop = bool(payload.get("use_stop", True))  # stopping criteria enabled by default
-
-         if not message_text:
-             return {"error": "Missing prompt text. Provide 'message' (or 'query'/'prompt'/'istem')."}
+
+         if not message_text or not message_text.strip():
+             return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
+
          if not image_input:
-             return {"error": "Missing image. Provide 'image' (url/base64/path) or 'image_url'/'img'."}
-
-         return generate_response(
+             return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"}
+
+         # Generate response with deterministic parameters
+         result = generate_response(
              message_text=message_text,
              image_input=image_input,
-             temperature=temperature,
-             top_p=top_p,
              max_output_tokens=max_output_tokens,
              repetition_penalty=repetition_penalty,
-             conv_mode_override=conv_mode_override,
-             do_sample=do_sample,
-             seed=seed,
-             use_stop=use_stop
+             conv_mode_override=conv_mode_override
          )
-
+
+         return result
+
      except Exception as e:
          return {"error": f"Query failed: {str(e)}"}

- # ----- Health / Info -----
+ # Additional utility endpoints
  def health_check():
+     """Health check endpoint"""
      return {
          "status": "healthy",
          "model_initialized": model_initialized,
@@ -487,12 +599,18 @@ def health_check():
          "llava_available": LLAVA_AVAILABLE,
          "transformers_available": TRANSFORMERS_AVAILABLE,
          "cv2_available": CV2_AVAILABLE,
-         "lazy_loading": True
+         "lazy_loading": True  # Model will be loaded on first query
      }

  def get_model_info():
+     """Get model information"""
      if not model_initialized:
-         return {"error": "Model not initialized yet", "lazy_loading": True}
+         return {
+             "error": "Model not initialized yet",
+             "lazy_loading": True,
+             "note": "Model will be loaded on first query"
+         }
+
      return {
          "model_path": args.model_path if args else "Unknown",
          "model_type": "PULSE-7B",
@@ -500,19 +618,37 @@ def get_model_info():
          "device": str(model.device) if model else "Unknown"
      }

- # ----- HF Endpoint handler -----
+ # Hugging Face EndpointHandler class
  class EndpointHandler:
+     """Hugging Face endpoint handler class"""
+
      def __init__(self, model_dir):
+         """Initialize the endpoint handler"""
          self.model_dir = model_dir
          print(f"EndpointHandler initialized with model_dir: {model_dir}")
+
      def __call__(self, payload):
+         """Main endpoint function - handles Hugging Face payload format"""
+         # Hugging Face sends payload in "inputs" wrapper
          if "inputs" in payload:
-             return query(payload["inputs"])
-         return query(payload)
+             # Extract the actual payload from inputs wrapper
+             actual_payload = payload["inputs"]
+             return query(actual_payload)
+         else:
+             # Direct payload (for backward compatibility)
+             return query(payload)
+
      def health_check(self):
+         """Health check endpoint"""
          return health_check()
+
      def get_model_info(self):
+         """Get model information"""
          return get_model_info()

+ # For backward compatibility and testing
  if __name__ == "__main__":
-     print("Handler loaded and ready.")
+     print("Handler module loaded successfully!")
+     print("This handler is now ready for Hugging Face endpoints.")
+     print("Use the 'query' function as the main endpoint.")
+     print("Or use EndpointHandler class for Hugging Face compatibility.")
 
 
 
 