import os
import datetime
import base64
import hashlib
import json

import torch
import numpy as np
import requests
from PIL import Image
from io import BytesIO

# Try to import cv2, but make it optional
try:
    import cv2
    CV2_AVAILABLE = True
except ImportError:
    CV2_AVAILABLE = False
    print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")

# Try to import llava modules, but make them optional
try:
    from llava import conversation as conversation_lib
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.utils import disable_torch_init
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
        KeywordsStoppingCriteria,
    )
    LLAVA_AVAILABLE = True
except ImportError as e:
    LLAVA_AVAILABLE = False
    print(f"Warning: LLaVA modules not available: {e}")

# Try to import transformers
try:
    from transformers import TextStreamer, TextIteratorStreamer
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: Transformers not available")

# Try to import huggingface_hub
try:
    from huggingface_hub import HfApi, login
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: Hugging Face Hub not available")

# Initialize the Hugging Face API if a token is configured
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        login(token=os.environ["HF_TOKEN"], write_permission=True)
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"Failed to initialize HF API: {e}")
        api = None
        repo_name = ""
else:
    api = None
    repo_name = ""

external_log_dir = "./logs"
LOGDIR = external_log_dir
VOTEDIR = "./votes"

# Global variables for model and tokenizer
tokenizer = None
model = None
image_processor = None
context_len = None
args = None


def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    os.makedirs(os.path.dirname(name), exist_ok=True)
    return name


def get_conv_vote_filename():
    t = datetime.datetime.now()
    name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
    os.makedirs(os.path.dirname(name), exist_ok=True)
    return name


def vote_last_response(state, vote_type, model_selector):
    # Always record the vote locally; upload to the Hub only when configured.
    try:
        with open(get_conv_vote_filename(), "a") as fout:
            data = {
                "type": vote_type,
                "model": model_selector,
                "state": state,
            }
            fout.write(json.dumps(data) + "\n")
        if api and repo_name:
            api.upload_file(
                path_or_fileobj=get_conv_vote_filename(),
                path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
                repo_id=repo_name,
                repo_type="dataset",
            )
    except Exception as e:
        print(f"Failed to record vote: {e}")


def is_valid_video_filename(name):
    if not CV2_AVAILABLE:
        return False  # Video processing disabled
    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
    ext = name.split(".")[-1].lower()
    return ext in video_extensions


def is_valid_image_filename(name):
    image_extensions = [
        "jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp",
        "heic", "heif", "jfif", "svg", "eps", "raw",
    ]
    ext = name.split(".")[-1].lower()
    return ext in image_extensions
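# The optional Hugging Face upload of conversation/vote logs (configured near
# the top of this module) is driven by two environment variables. A
# hypothetical configuration (placeholder values):
#
#   export HF_TOKEN=hf_xxxxxxxxxxxxxxxx       # token with write access
#   export LOG_REPO=your-org/pulse-7b-logs    # dataset repo receiving the logs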
Video processing is disabled.") video = cv2.VideoCapture(video_file) total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) interval = total_frames // num_frames frames = [] for i in range(total_frames): ret, frame = video.read() if not ret: continue if i % interval == 0: pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) frames.append(pil_img) video.release() return frames def load_image(image_file): if image_file.startswith("http") or image_file.startswith("https"): response = requests.get(image_file) if response.status_code == 200: image = Image.open(BytesIO(response.content)).convert("RGB") else: raise ValueError("Failed to load image from URL") else: print("Load image from local file") print(image_file) image = Image.open(image_file).convert("RGB") return image def process_base64_image(base64_string): """Process base64 encoded image string""" try: # Remove data URL prefix if present if base64_string.startswith('data:image'): base64_string = base64_string.split(',')[1] # Decode base64 to bytes image_data = base64.b64decode(base64_string) # Convert to PIL Image image = Image.open(BytesIO(image_data)).convert("RGB") return image except Exception as e: raise ValueError(f"Failed to process base64 image: {e}") def process_image_input(image_input): """Process different types of image input (file path, URL, or base64)""" if isinstance(image_input, str): if image_input.startswith("http"): return load_image(image_input) elif os.path.exists(image_input): return load_image(image_input) else: # Try to process as base64 return process_base64_image(image_input) elif isinstance(image_input, dict) and "image" in image_input: # Handle base64 image from dict return process_base64_image(image_input["image"]) else: raise ValueError("Unsupported image input format") class InferenceDemo(object): def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None: if not LLAVA_AVAILABLE: raise ImportError("LLaVA modules not available") disable_torch_init() self.tokenizer, self.model, self.image_processor, self.context_len = ( tokenizer, model, image_processor, context_len, ) model_name = get_model_name_from_path(model_path) if "llama-2" in model_name.lower(): conv_mode = "llava_llama_2" elif "v1" in model_name.lower() or "pulse" in model_name.lower(): conv_mode = "llava_v1" elif "mpt" in model_name.lower(): conv_mode = "mpt" elif "qwen" in model_name.lower(): conv_mode = "qwen_1_5" else: conv_mode = "llava_v0" if args.conv_mode is not None and conv_mode != args.conv_mode: print( "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( conv_mode, args.conv_mode, args.conv_mode ) ) else: args.conv_mode = conv_mode self.conv_mode = conv_mode self.conversation = conv_templates[args.conv_mode].copy() self.num_frames = args.num_frames class ChatSessionManager: def __init__(self): self.chatbot_instance = None def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len): self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len) print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}") def reset_chatbot(self): self.chatbot_instance = None def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len): if self.chatbot_instance is None: self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len) return self.chatbot_instance chat_manager = ChatSessionManager() def clear_history(): 
"""Clear conversation history""" if not LLAVA_AVAILABLE: return {"error": "LLaVA modules not available"} try: chatbot_instance = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len) chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy() return {"status": "success", "message": "Conversation history cleared"} except Exception as e: return {"error": f"Failed to clear history: {str(e)}"} def add_message(message_text, image_input=None): """Add a message to the conversation""" return {"status": "success", "message": "Message added"} def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, max_output_tokens=4096, repetition_penalty=1.0, conv_mode_override=None): """Generate response for the given message and image""" if not LLAVA_AVAILABLE: return {"error": "LLaVA modules not available"} try: if not message_text or not image_input: return {"error": "Both message text and image are required"} our_chatbot = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len) # Process image input try: image = process_image_input(image_input) except Exception as e: return {"error": f"Failed to process image: {str(e)}"} # Save image for logging all_image_hash = [] all_image_path = [] # Generate hash for the image img_byte_arr = BytesIO() image.save(img_byte_arr, format='JPEG') img_byte_arr = img_byte_arr.getvalue() image_hash = hashlib.md5(img_byte_arr).hexdigest() all_image_hash.append(image_hash) # Save image to logs t = datetime.datetime.now() filename = os.path.join( LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg", ) all_image_path.append(filename) if not os.path.isfile(filename): os.makedirs(os.path.dirname(filename), exist_ok=True) print("image save to", filename) image.save(filename) # Process image for model try: image_tensor = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)[0] image_tensor = image_tensor.half().to(our_chatbot.model.device) image_tensor = image_tensor.unsqueeze(0) except Exception as e: return {"error": f"Image processing failed: {str(e)}"} # Prepare conversation inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp) our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None) prompt = our_chatbot.conversation.get_prompt() # Tokenize input input_ids = tokenizer_image_token( prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" ).unsqueeze(0).to(our_chatbot.model.device) # Set up stopping criteria stop_str = ( our_chatbot.conversation.sep if our_chatbot.conversation.sep_style != SeparatorStyle.TWO else our_chatbot.conversation.sep2 ) keywords = [stop_str] stopping_criteria = KeywordsStoppingCriteria( keywords, our_chatbot.tokenizer, input_ids ) # Generate response with torch.no_grad(): outputs = our_chatbot.model.generate( inputs=input_ids, images=image_tensor, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=max_output_tokens, repetition_penalty=repetition_penalty, use_cache=False, stopping_criteria=[stopping_criteria], ) # Decode response response = our_chatbot.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True) our_chatbot.conversation.messages[-1][-1] = response # Log conversation history = [(message_text, response)] with open(get_conv_log_filename(), "a") as fout: data = { 
"type": "chat", "model": "PULSE-7b", "state": history, "images": all_image_hash, "images_path": all_image_path } print("#### conv log", data) fout.write(json.dumps(data) + "\n") # Upload files to Hugging Face if configured if api and repo_name: try: for upload_img in all_image_path: api.upload_file( path_or_fileobj=upload_img, path_in_repo=upload_img.replace("./logs/", ""), repo_id=repo_name, repo_type="dataset", ) # Upload conversation log api.upload_file( path_or_fileobj=get_conv_log_filename(), path_in_repo=get_conv_log_filename().replace("./logs/", ""), repo_id=repo_name, repo_type="dataset") except Exception as e: print(f"Failed to upload files: {e}") return { "status": "success", "response": response, "conversation_id": id(our_chatbot.conversation) } except Exception as e: return {"error": f"Generation failed: {str(e)}"} def upvote_last_response(conversation_id): """Upvote the last response""" try: vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B") return {"status": "success", "message": "Thank you for your voting!"} except Exception as e: return {"error": f"Failed to upvote: {str(e)}"} def downvote_last_response(conversation_id): """Downvote the last response""" try: vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B") return {"status": "success", "message": "Thank you for your voting!"} except Exception as e: return {"error": f"Failed to downvote: {str(e)}"} def flag_response(conversation_id): """Flag the last response""" try: vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B") return {"status": "success", "message": "Response flagged successfully"} except Exception as e: return {"error": f"Failed to flag response: {str(e)}"} # Initialize model when module is imported def initialize_model(): """Initialize the model and tokenizer""" global tokenizer, model, image_processor, context_len, args if not LLAVA_AVAILABLE: print("LLaVA modules not available, skipping model initialization") return False try: # Set default arguments class Args: def __init__(self): self.model_path = "PULSE-ECG/PULSE-7B" self.model_base = None self.num_gpus = 1 self.conv_mode = None self.temperature = 0.05 self.max_new_tokens = 1024 self.num_frames = 16 self.load_8bit = False self.load_4bit = False self.debug = False args = Args() # Load model model_path = args.model_path model_name = get_model_name_from_path(args.model_path) tokenizer, model, image_processor, context_len = load_pretrained_model( args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit ) print("### image_processor", image_processor) print("### tokenizer", tokenizer) # Move model to GPU if available if torch.cuda.is_available(): model = model.to(torch.device('cuda')) print("Model moved to CUDA") else: print("CUDA not available, using CPU") return True except Exception as e: print(f"Failed to initialize model: {e}") return False # Don't initialize model on import - do it lazily model_initialized = False # Main endpoint function for Hugging Face def query(payload): """Main endpoint function for Hugging Face inference API""" global model_initialized # Lazy initialization - initialize model on first call if not model_initialized: print("Initializing model on first query...") model_initialized = initialize_model() if not model_initialized: return {"error": "Model initialization failed"} try: print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}") # Extract prompt with multiple possible keys message_text = 
(payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or "") # Extract image with multiple possible keys image_input = (payload.get("image") or payload.get("image_url") or payload.get("img") or None) # Extract generation parameters with fallbacks temperature = float(payload.get("temperature", 0.05)) top_p = float(payload.get("top_p", 1.0)) max_output_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096)))) repetition_penalty = float(payload.get("repetition_penalty", 1.0)) conv_mode_override = payload.get("conv_mode", None) if not message_text or not message_text.strip(): return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"} if not image_input: return {"error": "Missing image. Use 'image', 'image_url', or 'img' key"} # Generate response with all parameters result = generate_response( message_text=message_text, image_input=image_input, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens, repetition_penalty=repetition_penalty, conv_mode_override=conv_mode_override ) return result except Exception as e: return {"error": f"Query failed: {str(e)}"} # Additional utility endpoints def health_check(): """Health check endpoint""" return { "status": "healthy", "model_initialized": model_initialized, "cuda_available": torch.cuda.is_available(), "llava_available": LLAVA_AVAILABLE, "transformers_available": TRANSFORMERS_AVAILABLE, "cv2_available": CV2_AVAILABLE, "lazy_loading": True # Model will be loaded on first query } def get_model_info(): """Get model information""" if not model_initialized: return { "error": "Model not initialized yet", "lazy_loading": True, "note": "Model will be loaded on first query" } return { "model_path": args.model_path if args else "Unknown", "model_type": "PULSE-7B", "cuda_available": torch.cuda.is_available(), "device": str(model.device) if model else "Unknown" } # Hugging Face EndpointHandler class class EndpointHandler: """Hugging Face endpoint handler class""" def __init__(self, model_dir): """Initialize the endpoint handler""" self.model_dir = model_dir print(f"EndpointHandler initialized with model_dir: {model_dir}") def __call__(self, payload): """Main endpoint function - handles Hugging Face payload format""" # Hugging Face sends payload in "inputs" wrapper if "inputs" in payload: # Extract the actual payload from inputs wrapper actual_payload = payload["inputs"] return query(actual_payload) else: # Direct payload (for backward compatibility) return query(payload) def health_check(self): """Health check endpoint""" return health_check() def get_model_info(self): """Get model information""" return get_model_info() # For backward compatibility and testing if __name__ == "__main__": print("Handler module loaded successfully!") print("This handler is now ready for Hugging Face endpoints.") print("Use the 'query' function as the main endpoint.") print("Or use EndpointHandler class for Hugging Face compatibility.")