CanerDedeoglu committed
Commit 6cee168 · verified · 1 Parent(s): d4d405a

deterministic output added
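The handler now generates with do_sample=False and a fixed torch seed, so repeating the same request should return the same text. A minimal client sketch, assuming a deployed Inference Endpoint; the endpoint URL, token, and image file below are placeholders, and the payload keys mirror what query() in handler.py accepts:

import base64
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder token

with open("ecg.jpg", "rb") as f:  # any local ECG image
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {  # EndpointHandler unwraps the "inputs" key before calling query()
        "message": "What are the main features and diagnosis in this ECG image?",
        "image": img_b64,  # 'image_url' or 'img' are accepted as well
        "max_output_tokens": 2048,
        "repetition_penalty": 1.0,
    }
}
r = requests.post(ENDPOINT_URL, json=payload,
                  headers={"Authorization": f"Bearer {HF_TOKEN}"})
print(r.json().get("response"))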

Files changed (1)
  1. handler.py +363 -463
handler.py CHANGED
@@ -1,21 +1,15 @@
 """
-PULSE ECG Handler - Deterministic ECG Analysis Model
-
-This handler provides consistent, deterministic responses for ECG analysis.
-All generation parameters are fixed to ensure reproducible results across
-different API calls and clients.
-
-Key Features:
-- Deterministic generation (do_sample=False)
-- Fixed random seed for consistency
-- No temperature/top_p sampling parameters
-- Consistent response lengths and content
 """
 
 import os
 import datetime
 import torch
-import numpy as np
 import hashlib
 import json
 import base64
@@ -23,18 +17,22 @@ import requests
 from PIL import Image
 from io import BytesIO
 
-# Try to import cv2, but make it optional
 try:
     import cv2
     CV2_AVAILABLE = True
-except ImportError:
     CV2_AVAILABLE = False
     print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")
 
-# Try to import llava modules, but make them optional
 try:
     from llava import conversation as conversation_lib
-    from llava.constants import DEFAULT_IMAGE_TOKEN
     from llava.constants import (
         IMAGE_TOKEN_INDEX,
         DEFAULT_IMAGE_TOKEN,
@@ -51,27 +49,27 @@ try:
         KeywordsStoppingCriteria,
     )
     LLAVA_AVAILABLE = True
-except ImportError as e:
     LLAVA_AVAILABLE = False
     print(f"Warning: LLaVA modules not available: {e}")
 
-# Try to import transformers
 try:
-    from transformers import TextStreamer, TextIteratorStreamer
     TRANSFORMERS_AVAILABLE = True
-except ImportError:
     TRANSFORMERS_AVAILABLE = False
     print("Warning: Transformers not available")
 
-# Try to import huggingface_hub
 try:
     from huggingface_hub import HfApi, login
     HF_HUB_AVAILABLE = True
-except ImportError:
     HF_HUB_AVAILABLE = False
     print("Warning: Hugging Face Hub not available")
 
-# Initialize Hugging Face API
 if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
     try:
         login(token=os.environ["HF_TOKEN"], write_permission=True)
@@ -85,173 +83,183 @@ else:
     api = None
     repo_name = ""
 
-external_log_dir = "./logs"
-LOGDIR = external_log_dir
 VOTEDIR = "./votes"
 
-# Global variables for model and tokenizer
 tokenizer = None
 model = None
 image_processor = None
 context_len = None
 args = None
 
-# Configuration for consistent responses
-PROMPT_NORMALIZATION = True  # Set to False to disable prompt normalization
-DEFAULT_ECG_PROMPT = "What are the main features and diagnosis in this ECG image? Provide a comprehensive clinical analysis"
 
-# Note: When PROMPT_NORMALIZATION is True, all ECG diagnosis requests will be
-# standardized to ensure consistent response lengths and content across different clients.
 
 def get_conv_log_filename():
     t = datetime.datetime.now()
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
     return name
 
 def get_conv_vote_filename():
     t = datetime.datetime.now()
     name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
-    if not os.path.isfile(name):
-        os.makedirs(os.path.dirname(name), exist_ok=True)
     return name
 
 def vote_last_response(state, vote_type, model_selector):
-    if api and repo_name:
-        try:
-            with open(get_conv_vote_filename(), "a") as fout:
-                data = {
-                    "type": vote_type,
-                    "model": model_selector,
-                    "state": state,
-                }
-                fout.write(json.dumps(data) + "\n")
-
-            api.upload_file(
-                path_or_fileobj=get_conv_vote_filename(),
-                path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
-                repo_id=repo_name,
-                repo_type="dataset")
-        except Exception as e:
-            print(f"Failed to upload vote file: {e}")
 
-def is_valid_video_filename(name):
-    if not CV2_AVAILABLE:
-        return False  # Video processing disabled
-    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
     ext = name.split(".")[-1].lower()
-    return ext in video_extensions
 
-def is_valid_image_filename(name):
-    image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
     ext = name.split(".")[-1].lower()
-    return ext in image_extensions
 
 def sample_frames(video_file, num_frames):
     if not CV2_AVAILABLE:
         raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")
-
-    video = cv2.VideoCapture(video_file)
-    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
-    interval = total_frames // num_frames
     frames = []
-    for i in range(total_frames):
-        ret, frame = video.read()
-        if not ret:
             continue
-        if i % interval == 0:
-            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-            frames.append(pil_img)
-    video.release()
     return frames
 
 def load_image(image_file):
-    if image_file.startswith("http") or image_file.startswith("https"):
-        response = requests.get(image_file)
-        if response.status_code == 200:
-            image = Image.open(BytesIO(response.content)).convert("RGB")
-        else:
-            raise ValueError("Failed to load image from URL")
     else:
-        print("Load image from local file")
-        print(image_file)
-        image = Image.open(image_file).convert("RGB")
-    return image
 
-def process_base64_image(base64_string):
-    """Process base64 encoded image string"""
     try:
-        # Remove data URL prefix if present
-        if base64_string.startswith('data:image'):
-            base64_string = base64_string.split(',')[1]
-
-        # Decode base64 to bytes
         image_data = base64.b64decode(base64_string)
-
-        # Convert to PIL Image
         image = Image.open(BytesIO(image_data)).convert("RGB")
         return image
     except Exception as e:
         raise ValueError(f"Failed to process base64 image: {e}")
 
 def process_image_input(image_input):
-    """Process different types of image input (file path, URL, or base64)"""
     if isinstance(image_input, str):
-        if image_input.startswith("http"):
             return load_image(image_input)
-        elif os.path.exists(image_input):
             return load_image(image_input)
-        else:
-            # Try to process as base64
-            return process_base64_image(image_input)
-    elif isinstance(image_input, dict) and "image" in image_input:
-        # Handle base64 image from dict
         return process_base64_image(image_input["image"])
-    else:
-        raise ValueError("Unsupported image input format")
 
 class InferenceDemo(object):
     def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
         if not LLAVA_AVAILABLE:
             raise ImportError("LLaVA modules not available")
-
         disable_torch_init()
-
         self.tokenizer, self.model, self.image_processor, self.context_len = (
-            tokenizer,
-            model,
-            image_processor,
-            context_len,
         )
-
         model_name = get_model_name_from_path(model_path)
-        if "llama-2" in model_name.lower():
             conv_mode = "llava_llama_2"
-        elif "v1" in model_name.lower() or "pulse" in model_name.lower():
             conv_mode = "llava_v1"
-        elif "mpt" in model_name.lower():
             conv_mode = "mpt"
-        elif "qwen" in model_name.lower():
             conv_mode = "qwen_1_5"
         else:
             conv_mode = "llava_v0"
 
         if args.conv_mode is not None and conv_mode != args.conv_mode:
-            print(
-                "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
-                    conv_mode, args.conv_mode, args.conv_mode
-                )
-            )
         else:
             args.conv_mode = conv_mode
-        self.conv_mode = conv_mode
-        self.conversation = conv_templates[args.conv_mode].copy()
         self.num_frames = args.num_frames
 
 class ChatSessionManager:
     def __init__(self):
         self.chatbot_instance = None
 
     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
 
@@ -266,225 +274,242 @@ class ChatSessionManager:
 chat_manager = ChatSessionManager()
 
 def clear_history():
-    """Clear conversation history"""
     if not LLAVA_AVAILABLE:
         return {"error": "LLaVA modules not available"}
-
     try:
-        chatbot_instance = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
         try:
-            if hasattr(chatbot_instance, 'conv_mode') and chatbot_instance.conv_mode and LLAVA_AVAILABLE:
-                chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
-            else:
-                # Use default conversation template
-                chatbot_instance.conversation = chatbot_instance.conversation.__class__()
         except Exception as e:
-            print(f"[DEBUG] Failed to reset conversation in clear_history: {e}")
         return {"status": "success", "message": "Conversation history cleared"}
     except Exception as e:
         return {"error": f"Failed to clear history: {str(e)}"}
 
-def add_message(message_text, image_input=None):
-    """Add a message to the conversation"""
-    return {"status": "success", "message": "Message added"}
-
-def generate_response(message_text, image_input, max_output_tokens=4096, repetition_penalty=1.0, conv_mode_override=None):
-    """Generate response for the given message and image using deterministic generation for consistency"""
     if not LLAVA_AVAILABLE:
         return {"error": "LLaVA modules not available"}
-
     try:
-        if not message_text or not image_input:
-            return {"error": "Both message text and image are required"}
-
-        our_chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len)
-
-        # Process image input
-        try:
-            image = process_image_input(image_input)
-        except Exception as e:
-            return {"error": f"Failed to process image: {str(e)}"}
-
-        # Save image for logging
-        all_image_hash = []
-        all_image_path = []
-
-        # Generate hash for the image
         img_byte_arr = BytesIO()
-        image.save(img_byte_arr, format='JPEG')
-        img_byte_arr = img_byte_arr.getvalue()
-        image_hash = hashlib.md5(img_byte_arr).hexdigest()
-        all_image_hash.append(image_hash)
-
-        # Save image to logs
         t = datetime.datetime.now()
-        filename = os.path.join(
-            LOGDIR,
-            "serve_images",
-            f"{t.year}-{t.month:02d}-{t.day:02d}",
-            f"{image_hash}.jpg",
-        )
-        all_image_path.append(filename)
-        if not os.path.isfile(filename):
-            os.makedirs(os.path.dirname(filename), exist_ok=True)
-            print("image save to", filename)
-            image.save(filename)
-
-        # Process image for model
-        try:
-            print(f"[DEBUG] Processing image for model...")
-            processed_images = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)
-            print(f"[DEBUG] Processed images length: {len(processed_images)}")
-
-            if len(processed_images) == 0:
-                return {"error": "Image processing returned empty list"}
-
-            image_tensor = processed_images[0]
-            image_tensor = image_tensor.half().to(our_chatbot.model.device)
-            image_tensor = image_tensor.unsqueeze(0)
-            print(f"[DEBUG] Image tensor shape: {image_tensor.shape}")
-        except Exception as e:
-            print(f"[DEBUG] Image processing error: {str(e)}")
-            return {"error": f"Image processing failed: {str(e)}"}
-
-        # Prepare conversation - reset for each request to avoid history issues
-        try:
-            if hasattr(our_chatbot, 'conv_mode') and our_chatbot.conv_mode and LLAVA_AVAILABLE:
-                our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
-                print(f"[DEBUG] Reset conversation using conv_mode: {our_chatbot.conv_mode}")
-            else:
-                # Use default conversation template
-                our_chatbot.conversation = our_chatbot.conversation.__class__()
-                print(f"[DEBUG] Reset conversation using default template")
-        except Exception as e:
-            print(f"[DEBUG] Failed to reset conversation: {e}")
-            # Continue with existing conversation
-
-        inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
-        our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
-        our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
-        prompt = our_chatbot.conversation.get_prompt()
-
-        print(f"[DEBUG] Conversation template: {type(our_chatbot.conversation).__name__}")
-        print(f"[DEBUG] Conversation roles: {our_chatbot.conversation.roles if hasattr(our_chatbot.conversation, 'roles') else 'No roles'}")
-        print(f"[DEBUG] Final prompt length: {len(prompt)} characters")
-
-        # Tokenize input
-        input_ids = tokenizer_image_token(
-            prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
-        ).unsqueeze(0).to(our_chatbot.model.device)
-
-        # No stopping criteria - let model generate freely up to max_new_tokens
-        print(f"[DEBUG] No stopping criteria - free generation up to {max_output_tokens} tokens")
-        print(f"[DEBUG] Input prompt length: {len(prompt)} characters")
-        print(f"[DEBUG] Input tokens: {input_ids.shape[1]} tokens")
-        stopping_criteria = None
-
-        # Set seed for deterministic generation
-        # This ensures the same input always produces the same output
-        torch.manual_seed(42)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed(42)
-            torch.cuda.manual_seed_all(42)
-
-        # Generate response using deterministic greedy decoding
-        # This eliminates randomness and ensures consistent responses
-        print(f"[DEBUG] About to generate with max_new_tokens: {max_output_tokens}")
-        print(f"[DEBUG] Model device: {our_chatbot.model.device}")
-        print(f"[DEBUG] Image tensor device: {image_tensor.device}")
-
         with torch.no_grad():
-            outputs = our_chatbot.model.generate(
                 inputs=input_ids,
                 images=image_tensor,
-                do_sample=False,  # Deterministic generation for consistency
-                max_new_tokens=max_output_tokens,
-                repetition_penalty=repetition_penalty,
                 use_cache=False,
-                pad_token_id=our_chatbot.tokenizer.eos_token_id,
-                eos_token_id=our_chatbot.tokenizer.eos_token_id,
-                length_penalty=1.0,  # Don't penalize longer sequences
-                early_stopping=False,  # Ensure no early stopping
             )
-
-        # Decode response
-        try:
-            print(f"[DEBUG] Outputs shape: {outputs.shape if hasattr(outputs, 'shape') else 'No shape attr'}")
-            print(f"[DEBUG] Outputs length: {len(outputs) if hasattr(outputs, '__len__') else 'No length'}")
-            print(f"[DEBUG] Input IDs shape: {input_ids.shape}")
-            print(f"[DEBUG] Generated tokens: {outputs.shape[1] - input_ids.shape[1] if hasattr(outputs, 'shape') else 'Unknown'}")
-            print(f"[DEBUG] Expected max tokens: {max_output_tokens}")
-
-            if len(outputs) == 0:
-                return {"error": "Model generated empty output"}
-
-            response = our_chatbot.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
-
-            print(f"[DEBUG] Conversation messages length: {len(our_chatbot.conversation.messages)}")
-            if len(our_chatbot.conversation.messages) > 0:
-                last_message = our_chatbot.conversation.messages[-1]
-                print(f"[DEBUG] Last message: {last_message}")
-                if isinstance(last_message, list) and len(last_message) > 1:
-                    our_chatbot.conversation.messages[-1][-1] = response
-                    print(f"[DEBUG] Response added to conversation")
-                else:
-                    print(f"[DEBUG] Last message format unexpected: {last_message}")
-                    # Add response as new message if format is wrong
-                    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
-            else:
-                print("[DEBUG] No conversation messages found")
-                # Add response as new message
-                our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
-
-            print(f"[DEBUG] Generated response length: {len(response)}")
-        except Exception as e:
-            print(f"[DEBUG] Response decoding error: {str(e)}")
-            return {"error": f"Response decoding failed: {str(e)}"}
-
-        # Log conversation
         history = [(message_text, response)]
         with open(get_conv_log_filename(), "a") as fout:
             data = {
                 "type": "chat",
-                "model": "PULSE-7b",
                 "state": history,
-                "images": all_image_hash,
-                "images_path": all_image_path
             }
-            print("#### conv log", data)
             fout.write(json.dumps(data) + "\n")
-
-        # Upload files to Hugging Face if configured
-        if api and repo_name:
-            try:
-                for upload_img in all_image_path:
-                    api.upload_file(
-                        path_or_fileobj=upload_img,
-                        path_in_repo=upload_img.replace("./logs/", ""),
-                        repo_id=repo_name,
-                        repo_type="dataset",
-                    )
-
-                # Upload conversation log
-                api.upload_file(
-                    path_or_fileobj=get_conv_log_filename(),
-                    path_in_repo=get_conv_log_filename().replace("./logs/", ""),
-                    repo_id=repo_name,
-                    repo_type="dataset")
-            except Exception as e:
-                print(f"Failed to upload files: {e}")
-
-        return {
-            "status": "success",
-            "response": response,
-            "conversation_id": id(our_chatbot.conversation)
-        }
-
     except Exception as e:
-        return {"error": f"Generation failed: {str(e)}"}
 
 def upvote_last_response(conversation_id):
-    """Upvote the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
         return {"status": "success", "message": "Thank you for your voting!"}
@@ -492,7 +517,6 @@ def upvote_last_response(conversation_id):
         return {"error": f"Failed to upvote: {str(e)}"}
 
 def downvote_last_response(conversation_id):
-    """Downvote the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
         return {"status": "success", "message": "Thank you for your voting!"}
@@ -500,200 +524,76 @@ def downvote_last_response(conversation_id):
         return {"error": f"Failed to downvote: {str(e)}"}
 
 def flag_response(conversation_id):
-    """Flag the last response"""
     try:
         vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
         return {"status": "success", "message": "Response flagged successfully"}
     except Exception as e:
         return {"error": f"Failed to flag response: {str(e)}"}
 
-# Initialize model when module is imported
 def initialize_model():
-    """Initialize the model and tokenizer"""
     global tokenizer, model, image_processor, context_len, args
-
     if not LLAVA_AVAILABLE:
         print("LLaVA modules not available, skipping model initialization")
         return False
-
     try:
-        # Set default arguments
         class Args:
             def __init__(self):
-                self.model_path = "PULSE-ECG/PULSE-7B"
                 self.model_base = None
-                self.num_gpus = 1
                 self.conv_mode = None
-                self.max_new_tokens = 4096
                 self.num_frames = 16
-                self.load_8bit = False
-                self.load_4bit = False
-                self.debug = False
-
         args = Args()
-
-        # Load model
-        model_path = args.model_path
        model_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
        )
-
-        print("### image_processor", image_processor)
-        print("### tokenizer", tokenizer)
-
-        # Move model to GPU if available
-        if torch.cuda.is_available():
-            model = model.to(torch.device('cuda'))
-            print("Model moved to CUDA")
-        else:
-            print("CUDA not available, using CPU")
-
        return True
-
    except Exception as e:
        print(f"Failed to initialize model: {e}")
        return False
 
-# Don't initialize model on import - do it lazily
-model_initialized = False
-
-# Main endpoint function for Hugging Face
-def query(payload):
-    """Main endpoint function for Hugging Face inference API"""
-    global model_initialized
-
-    # Lazy initialization - initialize model on first call
-    if not model_initialized:
-        print("Initializing model on first query...")
-        model_initialized = initialize_model()
-        if not model_initialized:
-            return {"error": "Model initialization failed"}
-
-    try:
-        print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")
-
-        # Extract prompt with multiple possible keys
-        message_text = (payload.get("message") or
-                        payload.get("query") or
-                        payload.get("prompt") or
-                        payload.get("istem") or "")
-
-        # Normalize prompt to ensure consistent responses
-        # This helps maintain consistency across different clients
-        if PROMPT_NORMALIZATION and "ecg" in message_text.lower() and "diagnosis" in message_text.lower():
-            # Standardize ECG analysis prompts for consistency
-            if "comprehensive" in message_text.lower():
-                message_text = DEFAULT_ECG_PROMPT
-            elif "concise" in message_text.lower():
-                message_text = "What are the main features and diagnosis in this ECG image? Provide a concise, clinical answer."
-            else:
-                # Default to comprehensive analysis for consistency
-                message_text = DEFAULT_ECG_PROMPT
-            print(f"[DEBUG] Normalized prompt to: {message_text}")
-
-        # Extract image with multiple possible keys
-        image_input = (payload.get("image") or
-                       payload.get("image_url") or
-                       payload.get("img") or None)
-
-        # Extract generation parameters with fallbacks
-        max_output_tokens = int(payload.get("max_output_tokens",
-                                payload.get("max_new_tokens",
-                                payload.get("max_tokens", 8192))))
-        repetition_penalty = float(payload.get("repetition_penalty", 1.0))
-        conv_mode_override = payload.get("conv_mode", None)
-
-        # Debug: Log all generation parameters
-        print(f"[DEBUG] Generation parameters:")
-        print(f"[DEBUG] max_output_tokens: {max_output_tokens}")
-        print(f"[DEBUG] repetition_penalty: {repetition_penalty}")
-        print(f"[DEBUG] Original payload max_output_tokens: {payload.get('max_output_tokens')}")
-        print(f"[DEBUG] Original payload max_new_tokens: {payload.get('max_new_tokens')}")
-        print(f"[DEBUG] Original payload max_tokens: {payload.get('max_tokens')}")
-        print(f"[DEBUG] Full payload keys: {list(payload.keys())}")
-        print(f"[DEBUG] Payload values: {dict(payload)}")
-
-        if not message_text or not message_text.strip():
-            return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
-
-        if not image_input:
-            return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"}
-
-        # Generate response with deterministic parameters
-        result = generate_response(
-            message_text=message_text,
-            image_input=image_input,
-            max_output_tokens=max_output_tokens,
-            repetition_penalty=repetition_penalty,
-            conv_mode_override=conv_mode_override
-        )
-
-        return result
-
-    except Exception as e:
-        return {"error": f"Query failed: {str(e)}"}
-
-# Additional utility endpoints
-def health_check():
-    """Health check endpoint"""
-    return {
-        "status": "healthy",
-        "model_initialized": model_initialized,
-        "cuda_available": torch.cuda.is_available(),
-        "llava_available": LLAVA_AVAILABLE,
-        "transformers_available": TRANSFORMERS_AVAILABLE,
-        "cv2_available": CV2_AVAILABLE,
-        "lazy_loading": True  # Model will be loaded on first query
-    }
 
-def get_model_info():
-    """Get model information"""
-    if not model_initialized:
-        return {
-            "error": "Model not initialized yet",
-            "lazy_loading": True,
-            "note": "Model will be loaded on first query"
-        }
-
-    return {
-        "model_path": args.model_path if args else "Unknown",
-        "model_type": "PULSE-7B",
-        "cuda_available": torch.cuda.is_available(),
-        "device": str(model.device) if model else "Unknown"
-    }
-
-# Hugging Face EndpointHandler class
 class EndpointHandler:
     """Hugging Face endpoint handler class"""
-
     def __init__(self, model_dir):
-        """Initialize the endpoint handler"""
         self.model_dir = model_dir
         print(f"EndpointHandler initialized with model_dir: {model_dir}")
-
     def __call__(self, payload):
-        """Main endpoint function - handles Hugging Face payload format"""
-        # Hugging Face sends payload in "inputs" wrapper
         if "inputs" in payload:
-            # Extract the actual payload from inputs wrapper
-            actual_payload = payload["inputs"]
-            return query(actual_payload)
-        else:
-            # Direct payload (for backward compatibility)
-            return query(payload)
-
     def health_check(self):
-        """Health check endpoint"""
         return health_check()
-
     def get_model_info(self):
-        """Get model information"""
         return get_model_info()
 
-# For backward compatibility and testing
 if __name__ == "__main__":
-    print("Handler module loaded successfully!")
-    print("This handler is now ready for Hugging Face endpoints.")
-    print("Use the 'query' function as the main endpoint.")
-    print("Or use EndpointHandler class for Hugging Face compatibility.")
+# -*- coding: utf-8 -*-
 """
+PULSE ECG Handler - Deterministic ECG Analysis Model (app.py compatible)
+- Deterministic (do_sample=False, fixed seed)
+- Single image, LLaVA conv_template + <image> token flow
+- Image tensor matched to the model's dtype/device
+- Robust URL/base64 handling, safe logging, optional HF upload
 """
 
 import os
 import datetime
 import torch
 import hashlib
 import json
 import base64
 
 from PIL import Image
 from io import BytesIO
 
+# --- Optional dependencies ---
+try:
+    import numpy as np  # optional, may be used
+except Exception:
+    np = None
+
 try:
     import cv2
     CV2_AVAILABLE = True
+except Exception:
     CV2_AVAILABLE = False
     print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")
 
+# LLaVA
 try:
     from llava import conversation as conversation_lib
     from llava.constants import (
         IMAGE_TOKEN_INDEX,
         DEFAULT_IMAGE_TOKEN,
 
         KeywordsStoppingCriteria,
     )
     LLAVA_AVAILABLE = True
+except Exception as e:
     LLAVA_AVAILABLE = False
     print(f"Warning: LLaVA modules not available: {e}")
 
+# Transformers
 try:
+    from transformers import TextIteratorStreamer  # unused, but harmless if present
     TRANSFORMERS_AVAILABLE = True
+except Exception:
     TRANSFORMERS_AVAILABLE = False
     print("Warning: Transformers not available")
 
+# HF Hub (optional)
 try:
     from huggingface_hub import HfApi, login
     HF_HUB_AVAILABLE = True
+except Exception:
     HF_HUB_AVAILABLE = False
     print("Warning: Hugging Face Hub not available")
 
+# --- HF Hub init (optional) ---
 if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
     try:
         login(token=os.environ["HF_TOKEN"], write_permission=True)
 
     api = None
     repo_name = ""
 
+# --- Constants / directory layout ---
+LOGDIR = "./logs"
 VOTEDIR = "./votes"
+os.makedirs(LOGDIR, exist_ok=True)
+os.makedirs(VOTEDIR, exist_ok=True)
 
+# --- Global model state ---
 tokenizer = None
 model = None
 image_processor = None
 context_len = None
 args = None
+model_initialized = False
+
+# --- Consistency settings ---
+PROMPT_NORMALIZATION = True
+DEFAULT_ECG_PROMPT = "Perform a detailed ECG interpretation of the provided image. Analyze step by step the rhythm, heart rate, cardiac axis, P waves, PR interval, QRS complex morphology and duration, ST segments, T waves, and QT/QTc interval. Highlight any abnormalities, conduction disturbances, or ischemic changes you detect. Conclude with a structured clinical impression of the overall ECG."
 
+# ---------- Helpers ----------
+
+def _safe_upload(path):
+    if api and repo_name and os.path.isfile(path):
+        try:
+            api.upload_file(
+                path_or_fileobj=path,
+                path_in_repo=path.replace("./logs/", ""),
+                repo_id=repo_name,
+                repo_type="dataset",
+            )
+        except Exception as e:
+            print(f"[upload] failed for {path}: {e}")
 
 def get_conv_log_filename():
     t = datetime.datetime.now()
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
+    os.makedirs(os.path.dirname(name), exist_ok=True)
     return name
 
 def get_conv_vote_filename():
     t = datetime.datetime.now()
     name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
+    os.makedirs(os.path.dirname(name), exist_ok=True)
     return name
 
 def vote_last_response(state, vote_type, model_selector):
+    try:
+        with open(get_conv_vote_filename(), "a") as fout:
+            data = {
+                "type": vote_type,
+                "model": model_selector,
+                "state": state,
+            }
+            fout.write(json.dumps(data) + "\n")
+        _safe_upload(get_conv_vote_filename())
+    except Exception as e:
+        print(f"Failed to record vote: {e}")
 
+# Plain extension allow-lists (problematic formats removed)
+IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "jfif"}
+# HEIC/HEIF: only supported when pillow-heif is installed
+try:
+    import pillow_heif  # noqa: F401
+    IMAGE_EXTS.update({"heic", "heif"})
+except Exception:
+    pass
+
+VIDEO_EXTS = {"avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"}
+
+def is_valid_video_filename(name: str) -> bool:
+    if not CV2_AVAILABLE or not name:
+        return False
     ext = name.split(".")[-1].lower()
+    return ext in VIDEO_EXTS
 
+def is_valid_image_filename(name: str) -> bool:
+    if not name:
+        return False
     ext = name.split(".")[-1].lower()
+    return ext in IMAGE_EXTS
 
 def sample_frames(video_file, num_frames):
     if not CV2_AVAILABLE:
         raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")
+    cap = cv2.VideoCapture(video_file)
+    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    if total <= 0 or num_frames <= 0:
+        cap.release()
+        return []
+    step = max(1, total // num_frames)
+    idxs = list(range(0, total, step))[:num_frames]
     frames = []
+    for i in idxs:
+        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        ret, frame = cap.read()
+        if not ret or frame is None:
             continue
+        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+        frames.append(pil_img)
+    cap.release()
     return frames
 
 def load_image(image_file):
+    if image_file.startswith(("http://", "https://")):
+        try:
+            r = requests.get(image_file, timeout=(5, 15))
+            r.raise_for_status()
+            return Image.open(BytesIO(r.content)).convert("RGB")
+        except Exception as e:
+            raise ValueError(f"Failed to load image URL: {e}")
     else:
+        return Image.open(image_file).convert("RGB")
 
+def process_base64_image(base64_string: str) -> Image.Image:
     try:
+        if base64_string.startswith("data:image"):
+            base64_string = base64_string.split(",", 1)[1]
         image_data = base64.b64decode(base64_string)
         image = Image.open(BytesIO(image_data)).convert("RGB")
         return image
     except Exception as e:
         raise ValueError(f"Failed to process base64 image: {e}")
 
 def process_image_input(image_input):
+    """Supported formats: local path, URL, base64 string, or an {'image': base64} dict."""
     if isinstance(image_input, str):
+        if image_input.startswith(("http://", "https://")):
             return load_image(image_input)
+        if os.path.exists(image_input):
             return load_image(image_input)
+        # probably base64
+        return process_base64_image(image_input)
+    if isinstance(image_input, dict) and "image" in image_input:
         return process_base64_image(image_input["image"])
+    raise ValueError("Unsupported image input format")
+
+# ---------- Session / Conversation ----------
 
 class InferenceDemo(object):
     def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
         if not LLAVA_AVAILABLE:
             raise ImportError("LLaVA modules not available")
         disable_torch_init()
         self.tokenizer, self.model, self.image_processor, self.context_len = (
+            tokenizer, model, image_processor, context_len
         )
         model_name = get_model_name_from_path(model_path)
+        low = model_name.lower()
+        if "llama-2" in low:
             conv_mode = "llava_llama_2"
+        elif "v1" in low or "pulse" in low:
             conv_mode = "llava_v1"
+        elif "mpt" in low:
             conv_mode = "mpt"
+        elif "qwen" in low:
             conv_mode = "qwen_1_5"
         else:
             conv_mode = "llava_v0"
 
         if args.conv_mode is not None and conv_mode != args.conv_mode:
+            print(f"[WARNING] auto conv={conv_mode}, using --conv-mode={args.conv_mode}")
         else:
             args.conv_mode = conv_mode
+        self.conv_mode = args.conv_mode
+        self.conversation = conv_templates[self.conv_mode].copy()
         self.num_frames = args.num_frames
 
 class ChatSessionManager:
     def __init__(self):
         self.chatbot_instance = None
+        self.args = None
+        self.model_path = None
 
     def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
+        self.args = args
+        self.model_path = model_path
         self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
         print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")
 
 chat_manager = ChatSessionManager()
 
 def clear_history():
     if not LLAVA_AVAILABLE:
         return {"error": "LLaVA modules not available"}
     try:
+        chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
+                                           tokenizer, model, image_processor, context_len)
         try:
+            chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
         except Exception as e:
+            print(f"[DEBUG] Failed to reset conversation: {e}")
         return {"status": "success", "message": "Conversation history cleared"}
     except Exception as e:
         return {"error": f"Failed to clear history: {str(e)}"}
 
+# ---------- Response generation ----------
+
+def _build_prompt(chatbot, user_text: str) -> str:
+    # Same as app.py: <image> token + user text
+    image_token = DEFAULT_IMAGE_TOKEN
+    inp = image_token + "\n" + user_text
+    chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
+    chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
+    prompt = chatbot.conversation.get_prompt()
+    return prompt
+
+def _stop_criteria_from_conv(chatbot, input_ids):
+    conv = chatbot.conversation
+    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+    return KeywordsStoppingCriteria([stop_str], chatbot.tokenizer, input_ids)
+
+def generate_response(message_text,
+                      image_input,
+                      max_output_tokens=2048,
+                      repetition_penalty=1.0,
+                      conv_mode_override=None):
     if not LLAVA_AVAILABLE:
         return {"error": "LLaVA modules not available"}
+
+    # Required inputs
+    if not message_text or not image_input:
+        return {"error": "Both message text and image are required"}
+
+    # Get the chatbot
+    chatbot = chat_manager.get_chatbot(
+        args, args.model_path if args else "PULSE-ECG/PULSE-7B",
+        tokenizer, model, image_processor, context_len
+    )
+
+    # Optional conv override
+    if conv_mode_override and conv_mode_override in conv_templates:
+        chatbot.conversation = conv_templates[conv_mode_override].copy()
+    else:
+        chatbot.conversation = conv_templates[chatbot.conv_mode].copy()
+
+    # Fetch/process the image
+    try:
+        image = process_image_input(image_input)
+    except Exception as e:
+        return {"error": f"Failed to process image: {str(e)}"}
+
+    # Save for logging
     try:
         img_byte_arr = BytesIO()
+        image.save(img_byte_arr, format="JPEG")
+        image_hash = hashlib.md5(img_byte_arr.getvalue()).hexdigest()
         t = datetime.datetime.now()
+        out_path = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
+        os.makedirs(os.path.dirname(out_path), exist_ok=True)
+        if not os.path.isfile(out_path):
+            image.save(out_path)
+    except Exception as e:
+        print(f"[WARN] Failed to save image: {e}")
+        out_path = None
+        image_hash = "NA"
+
+    # Model dtype/device
+    model_device = next(chatbot.model.parameters()).device
+    model_dtype = next(chatbot.model.parameters()).dtype
+
+    # Image tensor (matching the model's dtype/device)
+    try:
+        processed = process_images([image], chatbot.image_processor, chatbot.model.config)
+        if not processed:
+            return {"error": "Image processing returned empty list"}
+        image_tensor = processed[0].to(device=model_device, dtype=model_dtype).unsqueeze(0)
+    except Exception as e:
+        return {"error": f"Image processing failed: {str(e)}"}
+
+    # Prompt & tokenization
+    prompt = _build_prompt(chatbot, message_text)
+    input_ids = tokenizer_image_token(
+        prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+    ).unsqueeze(0).to(model_device)
+
+    # Stop criteria (app.py compatible)
+    stopping_criteria = _stop_criteria_from_conv(chatbot, input_ids)
+
+    # Deterministic generation
+    torch.manual_seed(42)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(42)
+        torch.cuda.manual_seed_all(42)
+
+    try:
         with torch.no_grad():
+            outputs = chatbot.model.generate(
                 inputs=input_ids,
                 images=image_tensor,
+                do_sample=False,  # deterministic
+                max_new_tokens=int(max_output_tokens),
+                repetition_penalty=float(repetition_penalty),
                 use_cache=False,
+                pad_token_id=chatbot.tokenizer.eos_token_id,
+                eos_token_id=chatbot.tokenizer.eos_token_id,
+                length_penalty=1.0,
+                early_stopping=False,
+                stopping_criteria=[stopping_criteria],
             )
+        # Decode only the newly generated tokens
+        gen = outputs[0][input_ids.shape[1]:]
+        response = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
+
+        # Write the response back into the conversation
+        if chatbot.conversation.messages and isinstance(chatbot.conversation.messages[-1], list):
+            chatbot.conversation.messages[-1][-1] = response
+        else:
+            chatbot.conversation.append_message(chatbot.conversation.roles[1], response)
+
+    except Exception as e:
+        return {"error": f"Generation failed: {str(e)}"}
+
+    # Log
+    try:
         history = [(message_text, response)]
         with open(get_conv_log_filename(), "a") as fout:
             data = {
                 "type": "chat",
+                "model": "PULSE-7B",
                 "state": history,
+                "images": [image_hash],
+                "images_path": [out_path] if out_path else []
             }
             fout.write(json.dumps(data) + "\n")
+        _safe_upload(get_conv_log_filename())
+        if out_path:
+            _safe_upload(out_path)
     except Exception as e:
+        print(f"[WARN] Failed to log/upload: {e}")
+
+    return {
+        "status": "success",
+        "response": response,
+        "conversation_id": id(chatbot.conversation)
+    }
+
+# ---------- API surface ----------
+
+def query(payload):
+    """Main entry point for the HF endpoint"""
+    global model_initialized, tokenizer, model, image_processor, context_len, args
+
+    # Lazy init
+    if not model_initialized:
+        ok = initialize_model()
+        if not ok:
+            return {"error": "Model initialization failed"}
+        model_initialized = True
+
+    try:
+        # Text
+        message_text = (
+            payload.get("message")
+            or payload.get("query")
+            or payload.get("prompt")
+            or payload.get("istem")
+            or ""
+        )
+
+        # Prompt normalization (when the text mentions ECG)
+        if PROMPT_NORMALIZATION and "ecg" in message_text.lower():
+            if "concise" in message_text.lower():
+                message_text = "Provide a short, concise clinical summary of the ECG."
+            else:
+                message_text = DEFAULT_ECG_PROMPT
+
+        # Image
+        image_input = (
+            payload.get("image")
+            or payload.get("image_url")
+            or payload.get("img")
+            or None
+        )
+
+        # Parameters
+        max_output_tokens = int(payload.get("max_output_tokens",
+                                payload.get("max_new_tokens",
+                                payload.get("max_tokens", 2048))))
+        repetition_penalty = float(payload.get("repetition_penalty", 1.0))
+        conv_mode_override = payload.get("conv_mode", None)
+
+        if not message_text.strip():
+            return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
+        if image_input is None:
+            return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"}
+
+        return generate_response(
+            message_text=message_text,
+            image_input=image_input,
+            max_output_tokens=max_output_tokens,
+            repetition_penalty=repetition_penalty,
+            conv_mode_override=conv_mode_override
+        )
+    except Exception as e:
+        return {"error": f"Query failed: {str(e)}"}
+
+def health_check():
+    return {
+        "status": "healthy",
+        "model_initialized": model_initialized,
+        "cuda_available": torch.cuda.is_available(),
+        "llava_available": LLAVA_AVAILABLE,
+        "transformers_available": TRANSFORMERS_AVAILABLE,
+        "cv2_available": CV2_AVAILABLE,
+        "lazy_loading": True
+    }
+
+def get_model_info():
+    if not model_initialized:
+        return {"error": "Model not initialized yet", "lazy_loading": True}
+    return {
+        "model_path": args.model_path if args else "Unknown",
+        "model_type": "PULSE-7B",
+        "cuda_available": torch.cuda.is_available(),
+        "device": str(model.device) if model else "Unknown"
+    }
 
 def upvote_last_response(conversation_id):
     try:
         vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
         return {"status": "success", "message": "Thank you for your voting!"}
 
         return {"error": f"Failed to upvote: {str(e)}"}
 
 def downvote_last_response(conversation_id):
     try:
         vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
         return {"status": "success", "message": "Thank you for your voting!"}
 
         return {"error": f"Failed to downvote: {str(e)}"}
 
 def flag_response(conversation_id):
     try:
         vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
         return {"status": "success", "message": "Response flagged successfully"}
     except Exception as e:
         return {"error": f"Failed to flag response: {str(e)}"}
 
+# ---------- Model init ----------
+
 def initialize_model():
+    """Load the model (lazy)"""
     global tokenizer, model, image_processor, context_len, args
+
     if not LLAVA_AVAILABLE:
         print("LLaVA modules not available, skipping model initialization")
         return False
+
     try:
         class Args:
             def __init__(self):
+                self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
                 self.model_base = None
+                self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
                 self.conv_mode = None
+                self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "2048"))
                 self.num_frames = 16
+                self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
+                self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
+                self.debug = bool(int(os.getenv("DEBUG", "0")))
+
         args = Args()
+
         model_name = get_model_name_from_path(args.model_path)
         tokenizer, model, image_processor, context_len = load_pretrained_model(
             args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
         )
+
+        # Device: if accelerate loaded the model with a device_map, no extra .to('cuda') is needed
+        try:
+            _ = next(model.parameters()).device
+        except Exception:
+            # safe fallback move
+            if torch.cuda.is_available():
+                model = model.to(torch.device("cuda"))
+
+        print("[init] tokenizer/image_processor/context_len ready")
         return True
+
     except Exception as e:
         print(f"Failed to initialize model: {e}")
         return False
 
+# ---------- HF EndpointHandler ----------
 
 class EndpointHandler:
     """Hugging Face endpoint handler class"""
+
     def __init__(self, model_dir):
         self.model_dir = model_dir
         print(f"EndpointHandler initialized with model_dir: {model_dir}")
+
     def __call__(self, payload):
         if "inputs" in payload:
+            return query(payload["inputs"])
+        return query(payload)
+
     def health_check(self):
         return health_check()
+
     def get_model_info(self):
         return get_model_info()
 
 if __name__ == "__main__":
+    print("Handler loaded. Use `query` or `EndpointHandler` in HF Inference Endpoints.")