# -*- coding: utf-8 -*- # handler.py — PULSE-7B / LLaVA endpoint (robust + deterministic-ready) import os import datetime import torch import numpy as np import hashlib import json import base64 import requests from PIL import Image from io import BytesIO # Optional cv2 try: import cv2 CV2_AVAILABLE = True except ImportError: CV2_AVAILABLE = False print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.") # LLaVA stack try: from llava import conversation as conversation_lib from llava.constants import DEFAULT_IMAGE_TOKEN from llava.constants import ( IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, ) from llava.conversation import conv_templates, SeparatorStyle from llava.model.builder import load_pretrained_model from llava.utils import disable_torch_init from llava.mm_utils import ( tokenizer_image_token, process_images, get_model_name_from_path, KeywordsStoppingCriteria, ) LLAVA_AVAILABLE = True except ImportError as e: LLAVA_AVAILABLE = False print(f"Warning: LLaVA modules not available: {e}") # Transformers try: from transformers import GenerationConfig TRANSFORMERS_AVAILABLE = True except ImportError: TRANSFORMERS_AVAILABLE = False print("Warning: Transformers not available") # HF Hub (optional) try: from huggingface_hub import HfApi, login HF_HUB_AVAILABLE = True except ImportError: HF_HUB_AVAILABLE = False print("Warning: Hugging Face Hub not available") # HF Hub init (optional) if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ: try: login(token=os.environ["HF_TOKEN"], write_permission=True) api = HfApi() repo_name = os.environ.get("LOG_REPO", "") except Exception as e: print(f"Failed to initialize HF API: {e}") api = None repo_name = "" else: api = None repo_name = "" # Logs external_log_dir = "./logs" LOGDIR = external_log_dir VOTEDIR = "./votes" # Globals tokenizer = None model = None image_processor = None context_len = None args = None model_initialized = False # ----- Utils ----- def get_conv_log_filename(): t = datetime.datetime.now() return os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json") def get_conv_vote_filename(): t = datetime.datetime.now() name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json") if not os.path.isfile(name): os.makedirs(os.path.dirname(name), exist_ok=True) return name def vote_last_response(state, vote_type, model_selector): if api and repo_name: try: with open(get_conv_vote_filename(), "a") as fout: fout.write(json.dumps({"type": vote_type, "model": model_selector, "state": state}) + "\n") api.upload_file( path_or_fileobj=get_conv_vote_filename(), path_in_repo=get_conv_vote_filename().replace("./votes/", ""), repo_id=repo_name, repo_type="dataset") except Exception as e: print(f"Failed to upload vote file: {e}") def is_valid_video_filename(name): if not CV2_AVAILABLE: return False return name.split(".")[-1].lower() in ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"] def is_valid_image_filename(name): return name.split(".")[-1].lower() in ["jpg","jpeg","png","bmp","gif","tiff","webp","heic","heif","jfif","svg","eps","raw"] def load_image(image_file): if image_file.startswith("http"): r = requests.get(image_file) if r.status_code == 200: return Image.open(BytesIO(r.content)).convert("RGB") raise ValueError("Failed to load image from URL") return Image.open(image_file).convert("RGB") def process_base64_image(base64_string): if base64_string.startswith('data:image'): base64_string = base64_string.split(',')[1] image_data = base64.b64decode(base64_string) return Image.open(BytesIO(image_data)).convert("RGB") def process_image_input(image_input): if isinstance(image_input, str): if image_input.startswith("http"): return load_image(image_input) elif os.path.exists(image_input): return load_image(image_input) else: return process_base64_image(image_input) elif isinstance(image_input, dict) and "image" in image_input: return process_base64_image(image_input["image"]) else: raise ValueError("Unsupported image input format") # ----- Chat session ----- class InferenceDemo(object): def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None: if not LLAVA_AVAILABLE: raise ImportError("LLaVA modules not available") disable_torch_init() self.tokenizer, self.model, self.image_processor, self.context_len = ( tokenizer, model, image_processor, context_len ) model_name = get_model_name_from_path(model_path) if "llama-2" in model_name.lower(): conv_mode = "llava_llama_2" elif "v1" in model_name.lower() or "pulse" in model_name.lower(): conv_mode = "llava_v1" elif "mpt" in model_name.lower(): conv_mode = "mpt" elif "qwen" in model_name.lower(): conv_mode = "qwen_1_5" else: conv_mode = "llava_v0" if args.conv_mode is not None and conv_mode != args.conv_mode: print(f"[WARNING] auto inferred conv_mode={conv_mode}, using {args.conv_mode}") else: args.conv_mode = conv_mode self.conv_mode = args.conv_mode self.conversation = conv_templates[self.conv_mode].copy() self.num_frames = args.num_frames class ChatSessionManager: def __init__(self): self.chatbot_instance = None def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len): self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len) print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}") def reset_chatbot(self): self.chatbot_instance = None def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len): if self.chatbot_instance is None: self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len) return self.chatbot_instance chat_manager = ChatSessionManager() def clear_history(): if not LLAVA_AVAILABLE: return {"error": "LLaVA modules not available"} try: inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len) mode = getattr(inst, 'conv_mode', None) if mode and mode in conv_templates: inst.conversation = conv_templates[mode].copy() else: inst.conversation = inst.conversation.__class__() return {"status": "success", "message": "Conversation history cleared"} except Exception as e: return {"error": f"Failed to clear history: {str(e)}"} # ----- Robust prefix stripper ----- def _strip_prefix_relaxed(text: str, prefix: str) -> str: try: if text.startswith(prefix): return text[len(prefix):] t_norm = " ".join(text.split()) p_norm = " ".join(prefix.split()) if t_norm.startswith(p_norm): idx = text.find(prefix.splitlines()[0]) if prefix.splitlines() else -1 if idx >= 0: return text[idx + len(prefix.splitlines()[0]):] except Exception: pass return text # ----- Core generate ----- def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, max_output_tokens=1024, repetition_penalty=1.0, conv_mode_override=None, do_sample=False, # default greedy -> deterministik seed=None, use_stop=True): if not LLAVA_AVAILABLE: return {"error": "LLaVA modules not available"} try: if not message_text or not image_input: return {"error": "Both message text and image are required"} # Determinism knobs if seed is not None: try: seed = int(seed) torch.manual_seed(seed) np.random.seed(seed) except Exception: pass inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B", tokenizer, model, image_processor, context_len) # Image image = process_image_input(image_input) img_byte_arr = BytesIO() image.save(img_byte_arr, format='JPEG') image_hash = hashlib.md5(img_byte_arr.getvalue()).hexdigest() # Save image to logs t = datetime.datetime.now() filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg") os.makedirs(os.path.dirname(filename), exist_ok=True) image.save(filename) # Preprocess processed_images = process_images([image], inst.image_processor, inst.model.config) if len(processed_images) == 0: return {"error": "Image processing returned empty list"} image_tensor = processed_images[0].half().to(inst.model.device).unsqueeze(0) # Conversation if conv_mode_override: inst.conversation = conv_templates[conv_mode_override].copy() else: inst.conversation = conv_templates[inst.conv_mode].copy() inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text inst.conversation.append_message(inst.conversation.roles[0], inp) inst.conversation.append_message(inst.conversation.roles[1], None) prompt = inst.conversation.get_prompt() # Tokenize input_ids = tokenizer_image_token(prompt, inst.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(inst.model.device) # Stop criteria stopping_criteria = None stop_str = inst.conversation.sep if inst.conversation.sep_style != SeparatorStyle.TWO else inst.conversation.sep2 if use_stop: stopping_criteria = KeywordsStoppingCriteria([stop_str], inst.tokenizer, input_ids) # PAD/EOS safety pad_id = inst.tokenizer.pad_token_id eos_id = inst.tokenizer.eos_token_id if inst.tokenizer.eos_token_id is not None else pad_id if pad_id is None: # safety net (rare) inst.tokenizer.add_special_tokens({"pad_token": inst.tokenizer.eos_token or ""}) pad_id = inst.tokenizer.pad_token_id eos_id = inst.tokenizer.eos_token_id or pad_id gen_cfg = GenerationConfig( do_sample=bool(do_sample), temperature=float(temperature), top_p=float(top_p), max_new_tokens=int(max_output_tokens), repetition_penalty=float(repetition_penalty), pad_token_id=pad_id, eos_token_id=eos_id ) with torch.no_grad(): outputs = inst.model.generate( inputs=input_ids, images=image_tensor, generation_config=gen_cfg, use_cache=True, stopping_criteria=[stopping_criteria] if stopping_criteria is not None else None, return_dict_in_generate=True ) # Robust decode sequences = outputs.sequences gen_ids = sequences[0] full_text = inst.tokenizer.decode(gen_ids, skip_special_tokens=True) prompt_text = inst.tokenizer.decode(input_ids[0], skip_special_tokens=True) if gen_ids.shape[0] > input_ids.shape[1]: response = inst.tokenizer.decode(gen_ids[input_ids.shape[1]:], skip_special_tokens=True).strip() else: response = _strip_prefix_relaxed(full_text, prompt_text).strip() if not response: response = full_text.replace(stop_str, "").strip() # Add to conversation if len(inst.conversation.messages) > 0 and isinstance(inst.conversation.messages[-1], list) and len(inst.conversation.messages[-1]) > 1: inst.conversation.messages[-1][-1] = response else: inst.conversation.append_message(inst.conversation.roles[1], response) # Log with open(get_conv_log_filename(), "a") as fout: fout.write(json.dumps({ "type": "chat", "model": "PULSE-7b", "state": [(message_text, response)], "images": [image_hash], "images_path": [filename] }) + "\n") return {"status": "success", "response": response, "conversation_id": id(inst.conversation)} except Exception as e: return {"error": f"Generation failed: {str(e)}"} # ----- Votes ----- def upvote_last_response(conversation_id): try: vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B") return {"status": "success", "message": "Upvoted"} except Exception as e: return {"error": str(e)} def downvote_last_response(conversation_id): try: vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B") return {"status": "success", "message": "Downvoted"} except Exception as e: return {"error": str(e)} def flag_response(conversation_id): try: vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B") return {"status": "success", "message": "Flagged"} except Exception as e: return {"error": str(e)} # ----- Init model (with PAD/EOS safety) ----- def initialize_model(): global tokenizer, model, image_processor, context_len, args if not LLAVA_AVAILABLE: print("LLaVA modules not available, skipping model initialization") return False try: class Args: def __init__(self): self.model_path = "PULSE-ECG/PULSE-7B" self.model_base = None self.num_gpus = 1 self.conv_mode = None self.temperature = 0.05 self.max_new_tokens = 1024 self.num_frames = 16 self.load_8bit = False self.load_4bit = False self.debug = False args = Args() model_name = get_model_name_from_path(args.model_path) tok, mdl, img_proc, ctx_len = load_pretrained_model( args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit ) # PAD/EOS safety if tok.eos_token_id is None and tok.eos_token is None: try: tok.add_special_tokens({"eos_token": ""}) except Exception: pass if tok.pad_token_id is None: if tok.eos_token is not None: tok.pad_token = tok.eos_token else: if tok.unk_token is None: try: tok.add_special_tokens({"unk_token": ""}) except Exception: pass tok.pad_token = tok.unk_token or "" tokenizer, model, image_processor, context_len = tok, mdl, img_proc, ctx_len if torch.cuda.is_available(): model = model.to(torch.device('cuda')) print("Model moved to CUDA") else: print("CUDA not available, using CPU") return True except Exception as e: print(f"Failed to initialize model: {e}") return False # ----- Query entrypoint ----- def query(payload): global model_initialized if not model_initialized: print("Initializing model on first query...") model_initialized = initialize_model() if not model_initialized: return {"error": "Model initialization failed"} try: # Log incoming keys print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}") # Inputs message_text = (payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or "").strip() image_input = (payload.get("image") or payload.get("image_url") or payload.get("img") or None) # Gen params temperature = float(payload.get("temperature", 0.05)) top_p = float(payload.get("top_p", 1.0)) max_output_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 1024)))) repetition_penalty = float(payload.get("repetition_penalty", 1.0)) conv_mode_override = payload.get("conv_mode", None) # Determinism toggles do_sample = bool(payload.get("do_sample", False)) # default greedy seed = payload.get("seed", None) use_stop = bool(payload.get("use_stop", True)) # default stop criteria açık if not message_text: return {"error": "Missing prompt text. Provide 'message' (or 'query'/'prompt'/'istem')."} if not image_input: return {"error": "Missing image. Provide 'image' (url/base64/path) or 'image_url'/'img'."} return generate_response( message_text=message_text, image_input=image_input, temperature=temperature, top_p=top_p, max_output_tokens=max_output_tokens, repetition_penalty=repetition_penalty, conv_mode_override=conv_mode_override, do_sample=do_sample, seed=seed, use_stop=use_stop ) except Exception as e: return {"error": f"Query failed: {str(e)}"} # ----- Health / Info ----- def health_check(): return { "status": "healthy", "model_initialized": model_initialized, "cuda_available": torch.cuda.is_available(), "llava_available": LLAVA_AVAILABLE, "transformers_available": TRANSFORMERS_AVAILABLE, "cv2_available": CV2_AVAILABLE, "lazy_loading": True } def get_model_info(): if not model_initialized: return {"error": "Model not initialized yet", "lazy_loading": True} return { "model_path": args.model_path if args else "Unknown", "model_type": "PULSE-7B", "cuda_available": torch.cuda.is_available(), "device": str(model.device) if model else "Unknown" } # ----- HF Endpoint handler ----- class EndpointHandler: def __init__(self, model_dir): self.model_dir = model_dir print(f"EndpointHandler initialized with model_dir: {model_dir}") def __call__(self, payload): if "inputs" in payload: return query(payload["inputs"]) return query(payload) def health_check(self): return health_check() def get_model_info(self): return get_model_info() if __name__ == "__main__": print("Handler loaded and ready.")