"""PULSE ECG Handler - deterministic ECG analysis model.

This handler serves ECG image analysis with the PULSE (LLaVA-based) model and
is built to return reproducible answers. All generation parameters are fixed,
so the same prompt and image produce the same response across API calls and
clients.

Key features:
- Deterministic greedy generation (``do_sample=False``)
- Fixed random seed (42), re-applied before every generation
- No temperature/top_p sampling parameters
- Optional prompt normalization (``PROMPT_NORMALIZATION``), so ECG diagnosis
  requests from different clients map to one standardized prompt and yield
  consistent response lengths and content
"""

import base64
import datetime
import hashlib
import json
import os
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image

# Try to import cv2, but make it optional
try:
    import cv2

    CV2_AVAILABLE = True
except ImportError:
    CV2_AVAILABLE = False
    print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")

# Try to import llava modules, but make them optional
try:
    from llava import conversation as conversation_lib
    from llava.constants import DEFAULT_IMAGE_TOKEN
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.utils import disable_torch_init
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
        KeywordsStoppingCriteria,
    )

    LLAVA_AVAILABLE = True
except ImportError as e:
    LLAVA_AVAILABLE = False
    print(f"Warning: LLaVA modules not available: {e}")

# Try to import transformers
try:
    from transformers import TextStreamer, TextIteratorStreamer

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: Transformers not available")

# Try to import huggingface_hub
try:
    from huggingface_hub import HfApi, login

    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: Hugging Face Hub not available")

# Initialize the Hugging Face API. It is only used to upload logs and votes to
# the dataset repo named by LOG_REPO; uploads are skipped when either is unset.
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        # NOTE(review): `write_permission` is deprecated in recent
        # huggingface_hub releases; confirm against the pinned version.
        login(token=os.environ["HF_TOKEN"], write_permission=True)
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"Failed to initialize HF API: {e}")
        api = None
        repo_name = ""
else:
    api = None
    repo_name = ""

external_log_dir = "./logs"
LOGDIR = external_log_dir
VOTEDIR = "./votes"

# Global variables for model and tokenizer (populated by initialize_model)
tokenizer = None
model = None
image_processor = None
context_len = None
args = None

# Configuration for consistent responses
PROMPT_NORMALIZATION = True  # Set to False to disable prompt normalization
DEFAULT_ECG_PROMPT = "What are the main features and diagnosis in this ECG image? Provide a comprehensive clinical analysis"
CONCISE_ECG_PROMPT = "What are the main features and diagnosis in this ECG image? Provide a concise, clinical answer."
# Note: When PROMPT_NORMALIZATION is True, every request whose text mentions both
# "ecg" and "diagnosis" is replaced by one of the standardized prompts above, so
# that different clients receive consistent response lengths and content.


def get_conv_log_filename():
    """Return today's conversation-log path under LOGDIR, creating the directory if needed."""
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    # Callers open this path in append mode, so make sure its directory exists
    # (mirrors get_conv_vote_filename).
    os.makedirs(os.path.dirname(name), exist_ok=True)
    return name


def get_conv_vote_filename():
    """Return today's vote-log path under VOTEDIR, creating the directory if needed."""
    t = datetime.datetime.now()
    name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
    if not os.path.isfile(name):
        os.makedirs(os.path.dirname(name), exist_ok=True)
    return name


def vote_last_response(state, vote_type, model_selector):
    """Append a vote record to today's vote file and upload that file to the HF log repo.

    Votes are recorded only when the Hugging Face API and LOG_REPO are configured.
    Any failure is printed and swallowed, so voting never raises to the caller.
    """
    if not (api and repo_name):
        return
    try:
        # Resolve the filename once so the write and the upload refer to the
        # same file even if the date rolls over in between.
        vote_file = get_conv_vote_filename()
        with open(vote_file, "a") as fout:
            data = {
                "type": vote_type,
                "model": model_selector,
                "state": state,
            }
            fout.write(json.dumps(data) + "\n")
        # Upload after the file is closed so the new line is flushed to disk.
        api.upload_file(
            path_or_fileobj=vote_file,
            path_in_repo=vote_file.replace("./votes/", ""),
            repo_id=repo_name,
            repo_type="dataset",
        )
    except Exception as e:
        print(f"Failed to upload vote file: {e}")


def is_valid_video_filename(name):
    """Return True if ``name`` has a known video extension and OpenCV is available."""
    if not CV2_AVAILABLE:
        return False  # Video processing disabled
    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
    ext = name.split(".")[-1].lower()
    return ext in video_extensions


def is_valid_image_filename(name):
    """Return True if ``name`` ends in a known image extension (case-insensitive)."""
    image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
    ext = name.split(".")[-1].lower()
    return ext in image_extensions


def sample_frames(video_file, num_frames):
    """Return RGB PIL frames sampled at a fixed interval from ``video_file``.

    The interval is ``total_frames // num_frames``, clamped to at least 1.

    Raises:
        ImportError: OpenCV is not installed.
        ValueError: ``num_frames`` is smaller than 1.
    """
    if not CV2_AVAILABLE:
        raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")
    if num_frames < 1:
        raise ValueError("num_frames must be a positive integer")
    video = cv2.VideoCapture(video_file)
    try:
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp to 1: a video shorter than num_frames would otherwise give an
        # interval of 0 and a ZeroDivisionError in the modulo below.
        interval = max(1, total_frames // num_frames)
        frames = []
        for i in range(total_frames):
            ret, frame = video.read()
            if not ret:
                continue
            if i % interval == 0:
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        video.release()
    return frames


def load_image(image_file, timeout=30):
    """Load an RGB PIL image from an http(s) URL or a local file path.

    Args:
        image_file: URL starting with ``http://``/``https://``, or a local path.
        timeout: Seconds to wait for the HTTP request before giving up.

    Raises:
        ValueError: The URL answered with a non-200 status code.
    """
    # SECURITY NOTE(review): image_file can originate from the request payload,
    # so this fetches arbitrary URLs and opens arbitrary local paths. Consider
    # an allow-list if the endpoint is exposed to untrusted clients.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            raise ValueError("Failed to load image from URL")
    else:
        print("Load image from local file")
        print(image_file)
        image = Image.open(image_file).convert("RGB")
    return image


def process_base64_image(base64_string):
    """Decode a base64 string (optionally a ``data:image/...;base64,`` URL) into an RGB PIL image.

    Raises:
        ValueError: The string could not be decoded into an image.
    """
    try:
        # Remove data URL prefix if present
        if base64_string.startswith('data:image'):
            base64_string = base64_string.split(',')[1]
        image_data = base64.b64decode(base64_string)
        return Image.open(BytesIO(image_data)).convert("RGB")
    except Exception as e:
        raise ValueError(f"Failed to process base64 image: {e}") from e


def process_image_input(image_input):
    """Turn a URL, local path, base64 string, or ``{"image": <base64>}`` dict into an RGB PIL image.

    Raises:
        ValueError: Unsupported input type, or the content could not be decoded.
    """
    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")) or os.path.exists(image_input):
            return load_image(image_input)
        # Anything else is treated as base64-encoded image data.
        return process_base64_image(image_input)
    if isinstance(image_input, dict) and "image" in image_input:
        # Handle base64 image from dict
        return process_base64_image(image_input["image"])
    raise ValueError("Unsupported image input format")


class InferenceDemo(object):
    """Holds the loaded model components plus the active conversation template."""

    def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA modules not available")
        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = (
            tokenizer,
            model,
            image_processor,
            context_len,
        )

        # Infer the conversation template from the model name.
        model_name = get_model_name_from_path(model_path)
        if "llama-2" in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "v1" in model_name.lower() or "pulse" in model_name.lower():
            conv_mode = "llava_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        elif "qwen" in model_name.lower():
            conv_mode = "qwen_1_5"
        else:
            conv_mode = "llava_v0"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
        else:
            args.conv_mode = conv_mode
        # Store the mode actually in use (the explicit override when given), so
        # later conversation resets use the same template as the initial one.
        self.conv_mode = args.conv_mode
        self.conversation = conv_templates[args.conv_mode].copy()
        self.num_frames = args.num_frames


class ChatSessionManager:
    """Lazily creates and caches a single shared InferenceDemo instance."""

    def __init__(self):
        self.chatbot_instance = None

    def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
        print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")

    def reset_chatbot(self):
        self.chatbot_instance = None

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        if self.chatbot_instance is None:
            self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
        return self.chatbot_instance


chat_manager = ChatSessionManager()


def _get_chatbot():
    """Return the shared InferenceDemo built from the module-level model globals."""
    model_path = args.model_path if args else "PULSE-ECG/PULSE-7B"
    return chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)


def _reset_conversation(chatbot, conv_mode=None):
    """Replace ``chatbot.conversation`` with a fresh copy of a conversation template.

    Uses ``conv_mode`` when given, otherwise ``chatbot.conv_mode``. If neither is
    set, falls back to default-constructing the current conversation's class.
    Exceptions propagate; callers decide whether to keep the old conversation.
    """
    mode = conv_mode or getattr(chatbot, "conv_mode", None)
    if mode:
        chatbot.conversation = conv_templates[mode].copy()
    else:
        # Use default conversation template
        chatbot.conversation = chatbot.conversation.__class__()


def clear_history():
    """Clear conversation history"""
    if not LLAVA_AVAILABLE:
        return {"error": "LLaVA modules not available"}
    try:
        chatbot_instance = _get_chatbot()
        try:
            _reset_conversation(chatbot_instance)
        except Exception as e:
            print(f"[DEBUG] Failed to reset conversation in clear_history: {e}")
        return {"status": "success", "message": "Conversation history cleared"}
    except Exception as e:
        return {"error": f"Failed to clear history: {str(e)}"}


def add_message(message_text, image_input=None):
    """Add a message to the conversation (no-op placeholder that always reports success)."""
    return {"status": "success", "message": "Message added"}


def _save_image_for_logging(image):
    """Hash ``image`` and store a copy under LOGDIR/serve_images/<YYYY-MM-DD>/.

    The hash is the MD5 of the image's JPEG encoding, used as a content-based file
    name. The file is written only if it does not already exist.

    Returns:
        (image_hash, file_path)
    """
    buffer = BytesIO()
    image.save(buffer, format='JPEG')
    image_hash = hashlib.md5(buffer.getvalue()).hexdigest()

    t = datetime.datetime.now()
    filename = os.path.join(
        LOGDIR,
        "serve_images",
        f"{t.year}-{t.month:02d}-{t.day:02d}",
        f"{image_hash}.jpg",
    )
    if not os.path.isfile(filename):
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        print("image save to", filename)
        image.save(filename)
    return image_hash, filename


def _strip_prompt_tokens(output_ids, prompt_ids):
    """Drop the echoed prompt from a generated token sequence, if it is present.

    Depending on the installed LLaVA/transformers versions, ``generate`` may return
    prompt + new tokens, or only the new tokens (presumably when generation runs
    from ``inputs_embeds``; TODO confirm for the pinned version). Slicing
    unconditionally by the prompt length would cut off the start of the answer
    in the second case, so only strip when the prefix actually matches.
    """
    prompt_len = prompt_ids.shape[-1]
    if output_ids.shape[-1] >= prompt_len:
        prefix = output_ids[:prompt_len].to(prompt_ids.device)
        if bool((prefix == prompt_ids).all()):
            return output_ids[prompt_len:]
    return output_ids


def _log_conversation(message_text, response, image_hashes, image_paths):
    """Append one exchange to today's conversation log and upload logs to the HF repo.

    Best-effort: failures are printed instead of raised, so a response that was
    already generated is never lost because of a logging problem.
    """
    history = [(message_text, response)]
    try:
        log_file = get_conv_log_filename()
        with open(log_file, "a") as fout:
            data = {
                "type": "chat",
                "model": "PULSE-7b",
                "state": history,
                "images": image_hashes,
                "images_path": image_paths,
            }
            print("#### conv log", data)
            fout.write(json.dumps(data) + "\n")
    except Exception as e:
        print(f"Failed to write conversation log: {e}")
        return

    # Upload files to Hugging Face if configured
    if api and repo_name:
        try:
            for upload_img in image_paths:
                api.upload_file(
                    path_or_fileobj=upload_img,
                    path_in_repo=upload_img.replace("./logs/", ""),
                    repo_id=repo_name,
                    repo_type="dataset",
                )
            # Upload conversation log
            api.upload_file(
                path_or_fileobj=log_file,
                path_in_repo=log_file.replace("./logs/", ""),
                repo_id=repo_name,
                repo_type="dataset",
            )
        except Exception as e:
            print(f"Failed to upload files: {e}")


def generate_response(message_text, image_input, max_output_tokens=4096, repetition_penalty=1.0, conv_mode_override=None):
    """Generate a response for one message + image using deterministic greedy decoding.

    Args:
        message_text: User prompt; the image token is prepended automatically.
        image_input: Anything accepted by ``process_image_input``.
        max_output_tokens: Upper bound on newly generated tokens.
        repetition_penalty: Passed straight to ``model.generate``.
        conv_mode_override: Optional conversation-template name; used for this
            request when it names a known template, otherwise ignored with a warning.

    Returns:
        ``{"status": "success", "response": str, "conversation_id": int}`` on
        success, or ``{"error": str}`` on any failure (this function does not raise).
    """
    if not LLAVA_AVAILABLE:
        return {"error": "LLaVA modules not available"}

    try:
        if not message_text or not image_input:
            return {"error": "Both message text and image are required"}

        our_chatbot = _get_chatbot()

        # Process image input
        try:
            image = process_image_input(image_input)
        except Exception as e:
            return {"error": f"Failed to process image: {str(e)}"}

        # Save image for logging
        image_hash, image_path = _save_image_for_logging(image)
        all_image_hash = [image_hash]
        all_image_path = [image_path]

        # Process image for model
        try:
            print("[DEBUG] Processing image for model...")
            processed_images = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)
            print(f"[DEBUG] Processed images length: {len(processed_images)}")

            if len(processed_images) == 0:
                return {"error": "Image processing returned empty list"}

            # NOTE(review): .half() assumes the model runs in float16; on CPU this
            # may be unsupported depending on the torch version.
            image_tensor = processed_images[0].half().to(our_chatbot.model.device).unsqueeze(0)
            print(f"[DEBUG] Image tensor shape: {image_tensor.shape}")
        except Exception as e:
            print(f"[DEBUG] Image processing error: {str(e)}")
            return {"error": f"Image processing failed: {str(e)}"}

        # Honor a valid conversation-template override for this request only.
        conv_mode = None
        if conv_mode_override:
            if isinstance(conv_mode_override, str) and conv_mode_override in conv_templates:
                conv_mode = conv_mode_override
            else:
                print(f"[WARNING] Unknown conv_mode override {conv_mode_override!r}; using the default template")

        # Prepare conversation - reset for each request to avoid history issues
        try:
            _reset_conversation(our_chatbot, conv_mode)
        except Exception as e:
            print(f"[DEBUG] Failed to reset conversation: {e}")
            # Continue with existing conversation

        inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
        our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
        our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
        prompt = our_chatbot.conversation.get_prompt()

        # Tokenize input (the image placeholder becomes IMAGE_TOKEN_INDEX)
        input_ids = tokenizer_image_token(
            prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(our_chatbot.model.device)

        # No stopping criteria - let model generate freely up to max_new_tokens
        print(f"[DEBUG] No stopping criteria - free generation up to {max_output_tokens} tokens")

        # Set seed for deterministic generation, so the same input always
        # produces the same output.
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)
            torch.cuda.manual_seed_all(42)

        # Deterministic greedy decoding eliminates sampling randomness.
        # NOTE(review): use_cache=False recomputes attention over the full
        # sequence at every step, which gets slow for long outputs.
        with torch.no_grad():
            outputs = our_chatbot.model.generate(
                inputs=input_ids,
                images=image_tensor,
                do_sample=False,  # Deterministic generation for consistency
                max_new_tokens=max_output_tokens,
                repetition_penalty=repetition_penalty,
                use_cache=False,
                pad_token_id=our_chatbot.tokenizer.eos_token_id,
                eos_token_id=our_chatbot.tokenizer.eos_token_id,
                length_penalty=1.0,  # Only relevant to beam search; kept for clarity
            )

        # Decode response
        try:
            print(f"[DEBUG] Outputs shape: {outputs.shape if hasattr(outputs, 'shape') else 'No shape attr'}")
            print(f"[DEBUG] Outputs length: {len(outputs) if hasattr(outputs, '__len__') else 'No length'}")
            print(f"[DEBUG] Input IDs shape: {input_ids.shape}")

            if len(outputs) == 0:
                return {"error": "Model generated empty output"}

            new_token_ids = _strip_prompt_tokens(outputs[0], input_ids[0])
            response = our_chatbot.tokenizer.decode(new_token_ids, skip_special_tokens=True)

            messages = our_chatbot.conversation.messages
            print(f"[DEBUG] Conversation messages length: {len(messages)}")
            if len(messages) > 0:
                last_message = messages[-1]
                print(f"[DEBUG] Last message: {last_message}")
                if isinstance(last_message, list) and len(last_message) > 1:
                    messages[-1][-1] = response
                    print("[DEBUG] Response added to conversation")
                else:
                    print(f"[DEBUG] Last message format unexpected: {last_message}")
                    # Add response as new message if format is wrong
                    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)
            else:
                print("[DEBUG] No conversation messages found")
                # Add response as new message
                our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], response)

            print(f"[DEBUG] Generated response length: {len(response)}")
        except Exception as e:
            print(f"[DEBUG] Response decoding error: {str(e)}")
            return {"error": f"Response decoding failed: {str(e)}"}

        # Log conversation (best-effort; never discards the response)
        _log_conversation(message_text, response, all_image_hash, all_image_path)

        return {
            "status": "success",
            "response": response,
            "conversation_id": id(our_chatbot.conversation),
        }

    except Exception as e:
        return {"error": f"Generation failed: {str(e)}"}


def upvote_last_response(conversation_id):
    """Upvote the last response"""
    try:
        vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
        return {"status": "success", "message": "Thank you for your voting!"}
    except Exception as e:
        return {"error": f"Failed to upvote: {str(e)}"}


def downvote_last_response(conversation_id):
    """Downvote the last response"""
    try:
        vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
        return {"status": "success", "message": "Thank you for your voting!"}
    except Exception as e:
        return {"error": f"Failed to downvote: {str(e)}"}


def flag_response(conversation_id):
    """Flag the last response"""
    try:
        vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
        return {"status": "success", "message": "Response flagged successfully"}
    except Exception as e:
        return {"error": f"Failed to flag response: {str(e)}"}


def initialize_model():
    """Load the PULSE model into the module globals.

    Returns:
        True on success; False if LLaVA is unavailable or loading fails.
    """
    global tokenizer, model, image_processor, context_len, args

    if not LLAVA_AVAILABLE:
        print("LLaVA modules not available, skipping model initialization")
        return False

    try:
        # Set default arguments
        class Args:
            def __init__(self):
                self.model_path = "PULSE-ECG/PULSE-7B"
                self.model_base = None
                self.num_gpus = 1
                self.conv_mode = None
                self.max_new_tokens = 1024
                self.num_frames = 16
                self.load_8bit = False
                self.load_4bit = False
                self.debug = False

        args = Args()

        # Load model
        model_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
        )
        print("### image_processor", image_processor)
        print("### tokenizer", tokenizer)

        # Move model to GPU if available
        if torch.cuda.is_available():
            model = model.to(torch.device('cuda'))
            print("Model moved to CUDA")
        else:
            print("CUDA not available, using CPU")

        return True
    except Exception as e:
        print(f"Failed to initialize model: {e}")
        return False


# Don't initialize model on import - do it lazily
model_initialized = False


def query(payload):
    """Main endpoint function for the Hugging Face inference API.

    Accepted payload keys:
        prompt: "message" | "query" | "prompt" | "istem"
        image:  "image" | "image_url" | "img"
        tokens: "max_output_tokens" | "max_new_tokens" | "max_tokens" (default 8192)
        also:   "repetition_penalty" (default 1.0), "conv_mode" (optional)

    Returns the ``generate_response`` result dict, or ``{"error": str}``.
    """
    global model_initialized

    # Lazy initialization - initialize model on first call
    if not model_initialized:
        print("Initializing model on first query...")
        model_initialized = initialize_model()
        if not model_initialized:
            return {"error": "Model initialization failed"}

    try:
        print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")
        if not hasattr(payload, "get"):
            return {"error": "Invalid payload: expected a JSON object"}

        # Extract prompt with multiple possible keys
        message_text = (payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or "")

        # Normalize prompt so different clients get consistent responses.
        lowered = message_text.lower()
        if PROMPT_NORMALIZATION and "ecg" in lowered and "diagnosis" in lowered:
            # "concise" selects the short prompt unless "comprehensive" is also
            # present; every other ECG diagnosis request maps to the default.
            if "concise" in lowered and "comprehensive" not in lowered:
                message_text = CONCISE_ECG_PROMPT
            else:
                message_text = DEFAULT_ECG_PROMPT
            print(f"[DEBUG] Normalized prompt to: {message_text}")

        # Extract image with multiple possible keys
        image_input = (payload.get("image") or payload.get("image_url") or payload.get("img") or None)

        # Extract generation parameters with fallbacks
        max_output_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 8192))))
        repetition_penalty = float(payload.get("repetition_penalty", 1.0))
        conv_mode_override = payload.get("conv_mode", None)

        if not message_text or not message_text.strip():
            return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
        if not image_input:
            return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"}

        # Generate response with deterministic parameters
        return generate_response(
            message_text=message_text,
            image_input=image_input,
            max_output_tokens=max_output_tokens,
            repetition_penalty=repetition_penalty,
            conv_mode_override=conv_mode_override,
        )
    except Exception as e:
        return {"error": f"Query failed: {str(e)}"}


# Additional utility endpoints
def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_initialized": model_initialized,
        "cuda_available": torch.cuda.is_available(),
        "llava_available": LLAVA_AVAILABLE,
        "transformers_available": TRANSFORMERS_AVAILABLE,
        "cv2_available": CV2_AVAILABLE,
        "lazy_loading": True,  # Model will be loaded on first query
    }


def get_model_info():
    """Get model information"""
    if not model_initialized:
        return {
            "error": "Model not initialized yet",
            "lazy_loading": True,
            "note": "Model will be loaded on first query",
        }

    return {
        "model_path": args.model_path if args else "Unknown",
        "model_type": "PULSE-7B",
        "cuda_available": torch.cuda.is_available(),
        "device": str(model.device) if model else "Unknown",
    }


class EndpointHandler:
    """Hugging Face endpoint handler class"""

    def __init__(self, model_dir):
        """Initialize the endpoint handler (model_dir is recorded but not used for loading)."""
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        """Handle a request, unwrapping the Hugging Face ``{"inputs": ...}`` envelope if present."""
        # isinstance guard: a plain-string payload containing "inputs" would
        # otherwise pass the `in` test and fail when indexed.
        if isinstance(payload, dict) and "inputs" in payload:
            return query(payload["inputs"])
        # Direct payload (for backward compatibility)
        return query(payload)

    def health_check(self):
        """Health check endpoint"""
        return health_check()

    def get_model_info(self):
        """Get model information"""
        return get_model_info()


# For backward compatibility and testing
if __name__ == "__main__":
    print("Handler module loaded successfully!")
    print("This handler is now ready for Hugging Face endpoints.")
    print("Use the 'query' function as the main endpoint.")
    print("Or use EndpointHandler class for Hugging Face compatibility.")