# Rapid_ECG / handler.py — copied from the Hugging Face Hub file view
# (author: ismailhakki37; commit d790281 "Update handler.py"; 19.4 kB).
# -*- coding: utf-8 -*-
# handler.py — PULSE-7B / LLaVA endpoint (robust + deterministic-ready)
import os
import datetime
import torch
import numpy as np
import hashlib
import json
import base64
import requests
from PIL import Image
from io import BytesIO
# Optional cv2 (only needed for video inputs).
try:
    import cv2
    CV2_AVAILABLE = True
except ImportError:
    CV2_AVAILABLE = False
    print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")

# LLaVA stack. Any exception is caught here, not only ImportError: a llava install
# that is out of step with the installed transformers can fail during import with
# other errors (e.g. a ValueError when registering a model type that already
# exists). The rest of this module already degrades gracefully when
# LLAVA_AVAILABLE is False, so a broken install should not crash the handler.
try:
    from llava import conversation as conversation_lib
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.utils import disable_torch_init
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
        KeywordsStoppingCriteria,
    )
    LLAVA_AVAILABLE = True
except Exception as e:
    LLAVA_AVAILABLE = False
    print(f"Warning: LLaVA modules not available: {e}")

# Transformers (GenerationConfig is used by generate_response).
try:
    from transformers import GenerationConfig
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: Transformers not available")

# HF Hub (optional; only used to upload vote files).
try:
    from huggingface_hub import HfApi, login
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: Hugging Face Hub not available")
# HF Hub init (optional): log in only when HF_TOKEN is set. `api` and `repo_name`
# (from LOG_REPO) are read by vote_last_response to upload vote files; they stay
# None / "" when the Hub is unavailable, no token is given, or login fails.
api = None
repo_name = ""
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        # `write_permission` is deprecated in huggingface_hub and slated for removal;
        # passing it makes login() fail on releases that no longer accept it, which
        # would silently disable vote uploads.
        login(token=os.environ["HF_TOKEN"])
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"Failed to initialize HF API: {e}")
        api = None
        repo_name = ""
# Log locations: chat transcripts and served images are written under LOGDIR,
# vote records under VOTEDIR (both relative to the working directory).
LOGDIR = external_log_dir = "./logs"
VOTEDIR = "./votes"

# Model components, populated by initialize_model(); None until then.
tokenizer = model = image_processor = context_len = args = None
# Set by query() to the result of initialize_model() on first use.
model_initialized = False
# ----- Utils -----
def get_conv_log_filename():
    """Return today's conversation-log path (``LOGDIR/YYYY-MM-DD-user_conv.json``).

    LOGDIR is created if it does not exist yet, mirroring get_conv_vote_filename,
    so callers can open the returned path in append mode straight away.
    """
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    os.makedirs(LOGDIR, exist_ok=True)
    return name
def get_conv_vote_filename():
    """Return today's vote-log path (``VOTEDIR/YYYY-MM-DD-user_vote.json``).

    When the file does not exist yet, its parent directory is created so the
    caller can open the path for appending.
    """
    today = datetime.datetime.now()
    stamp = f"{today.year}-{today.month:02d}-{today.day:02d}"
    path = os.path.join(VOTEDIR, stamp + "-user_vote.json")
    if os.path.isfile(path):
        return path
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path
def vote_last_response(state, vote_type, model_selector):
    """Append a vote record to today's vote file and upload that file to the HF dataset repo.

    Does nothing unless both the module-level ``api`` client and ``repo_name`` were
    configured at import time. Any failure is printed and swallowed (best effort).

    Args:
        state: JSON-serializable payload describing what was voted on.
        vote_type: label such as "upvote", "downvote" or "flag".
        model_selector: model name recorded with the vote.
    """
    if not (api and repo_name):
        return
    try:
        # Resolve the path once so the local write and the upload always refer to the
        # same file, even if the date rolls over mid-call.
        vote_file = get_conv_vote_filename()
        with open(vote_file, "a", encoding="utf-8") as fout:
            fout.write(json.dumps({"type": vote_type, "model": model_selector, "state": state}) + "\n")
        api.upload_file(
            path_or_fileobj=vote_file,
            # The file sits directly in VOTEDIR, so its repo path is just the file name.
            path_in_repo=os.path.basename(vote_file),
            repo_id=repo_name,
            repo_type="dataset",
        )
    except Exception as e:
        print(f"Failed to upload vote file: {e}")
def is_valid_video_filename(name):
    """Return True if ``name`` has a video file extension and OpenCV is available.

    The extension is taken with ``os.path.splitext`` (case-insensitive), so a bare
    name such as ``"mp4"`` with no dot is not mistaken for a video file.
    """
    if not CV2_AVAILABLE:
        return False
    ext = os.path.splitext(name)[1][1:].lower()
    return ext in {"avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"}
def is_valid_image_filename(name):
    """Return True if ``name`` ends in a recognised image file extension.

    The extension is taken with ``os.path.splitext`` (case-insensitive), so a bare
    name such as ``"png"`` with no dot is not treated as an image file.
    NOTE(review): some listed formats (e.g. svg, heic, eps) likely need extra
    plugins or tools before PIL can open them — confirm before relying on this.
    """
    ext = os.path.splitext(name)[1][1:].lower()
    return ext in {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic",
                   "heif", "jfif", "svg", "eps", "raw"}
def load_image(image_file, timeout=30):
    """Load an image from an http(s) URL or a local path and return it in RGB mode.

    Args:
        image_file: URL starting with ``http://`` or ``https://`` (case-insensitive),
            otherwise treated as a filesystem path.
        timeout: seconds to wait for the HTTP request. Without a timeout,
            ``requests.get`` can block indefinitely on an unresponsive host.

    Raises:
        ValueError: if the URL responds with a status other than 200.
    """
    if image_file.lower().startswith(("http://", "https://")):
        r = requests.get(image_file, timeout=timeout)
        if r.status_code != 200:
            raise ValueError("Failed to load image from URL")
        source = BytesIO(r.content)
    else:
        source = image_file
    # Use a context manager so the underlying file is closed once the RGB copy exists.
    with Image.open(source) as img:
        return img.convert("RGB")
def process_base64_image(base64_string):
    """Decode a base64 image (optionally a ``data:image/...;base64,`` URI) to an RGB image.

    Whitespace in the payload is ignored, and payloads whose trailing ``=``
    padding was stripped by the client are accepted.

    Raises:
        ValueError: if a data URI has no ``,`` separator or the payload is not
            valid base64 (``binascii.Error`` is a ``ValueError`` subclass).
    """
    if base64_string.startswith('data:image'):
        _header, sep, payload = base64_string.partition(',')
        if not sep:
            raise ValueError("Malformed data URI: missing ',' before base64 payload")
        base64_string = payload
    payload = "".join(base64_string.split())
    try:
        image_data = base64.b64decode(payload)
    except ValueError:
        # Retry with the padding restored; a genuinely invalid payload fails again.
        image_data = base64.b64decode(payload + "=" * (-len(payload) % 4))
    with Image.open(BytesIO(image_data)) as img:
        return img.convert("RGB")
def process_image_input(image_input):
    """Turn an incoming image reference into an RGB PIL image.

    * dict with an ``"image"`` key: the value is decoded as base64.
    * str starting with ``"http"`` or naming an existing path: passed to load_image.
    * any other str: decoded as base64 (a data URI prefix is allowed).

    Raises:
        ValueError: for any other input.
    """
    if isinstance(image_input, dict) and "image" in image_input:
        return process_base64_image(image_input["image"])
    if not isinstance(image_input, str):
        raise ValueError("Unsupported image input format")
    if image_input.startswith("http") or os.path.exists(image_input):
        return load_image(image_input)
    return process_base64_image(image_input)
# ----- Chat session -----
class InferenceDemo(object):
    """Bundle of loaded model components plus the conversation template to use."""

    def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
        """Store the model components and resolve the conversation mode.

        The mode is inferred from the model name. If ``args.conv_mode`` is already set
        to something different, that explicit value is kept and a warning is printed;
        otherwise ``args.conv_mode`` is overwritten with the inferred mode (note: this
        mutates the shared ``args`` object).

        Raises:
            ImportError: when the LLaVA modules could not be imported.
        """
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA modules not available")
        disable_torch_init()
        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.context_len = context_len

        lowered = get_model_name_from_path(model_path).lower()
        # First matching keyword wins, so the order of these rules matters.
        rules = (
            ("llama-2", "llava_llama_2"),
            ("v1", "llava_v1"),
            ("pulse", "llava_v1"),
            ("mpt", "mpt"),
            ("qwen", "qwen_1_5"),
        )
        inferred = next((mode for key, mode in rules if key in lowered), "llava_v0")

        if args.conv_mode is None or args.conv_mode == inferred:
            args.conv_mode = inferred
        else:
            print(f"[WARNING] auto inferred conv_mode={inferred}, using {args.conv_mode}")
        self.conv_mode = args.conv_mode
        self.conversation = conv_templates[self.conv_mode].copy()
        self.num_frames = args.num_frames
class ChatSessionManager:
    """Lazily create and cache one shared InferenceDemo instance."""

    def __init__(self):
        # Cached chatbot; None until first requested or after reset_chatbot().
        self.chatbot_instance = None

    def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Build a new InferenceDemo and make it the cached instance."""
        bot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
        self.chatbot_instance = bot
        print(f"Initialized Chatbot instance with ID: {id(bot)}")

    def reset_chatbot(self):
        """Forget the cached instance so the next get_chatbot() call rebuilds it."""
        self.chatbot_instance = None

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the cached chatbot, creating it on first use.

        Once an instance exists, the arguments are ignored.
        """
        if self.chatbot_instance is not None:
            return self.chatbot_instance
        self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
        return self.chatbot_instance


# Process-wide session manager shared by every request handled by this module.
chat_manager = ChatSessionManager()
def clear_history():
    """Reset the shared chatbot's conversation to a fresh copy of its template.

    Returns a status dict. Failures are reported under an ``"error"`` key and
    never raised. NOTE: generate_response() already starts every request from a
    fresh template, so this mainly matters to callers that inspect
    ``conversation`` between requests.
    """
    if not LLAVA_AVAILABLE:
        return {"error": "LLaVA modules not available"}
    if args is None:
        # initialize_model() has not run, so a chatbot cannot be built; say so
        # clearly instead of surfacing an AttributeError from InferenceDemo.
        return {"error": "Failed to clear history: model not initialized"}
    try:
        inst = chat_manager.get_chatbot(args, args.model_path,
                                        tokenizer, model, image_processor, context_len)
        mode = getattr(inst, "conv_mode", None)
        if mode not in conv_templates:
            # No template to reset to; report it rather than guessing a constructor.
            return {"error": f"Failed to clear history: unknown conv_mode {mode!r}"}
        inst.conversation = conv_templates[mode].copy()
        return {"status": "success", "message": "Conversation history cleared"}
    except Exception as e:
        return {"error": f"Failed to clear history: {str(e)}"}
# ----- Robust prefix stripper -----
def _strip_prefix_relaxed(text: str, prefix: str) -> str:
    """Remove ``prefix`` from the start of ``text``, tolerating whitespace differences.

    An exact prefix is removed directly. Otherwise the prefix is matched word by
    word:
    * leading whitespace in ``text`` is skipped;
    * every whitespace run in ``prefix`` must line up with at least one whitespace
      character (any kind, any amount) in ``text``;
    * the non-whitespace characters must match exactly.

    The whole prefix is consumed, including all of its lines, so multi-line
    prompts are fully stripped. If the prefix does not match, or either argument
    is not a string, ``text`` is returned unchanged.
    """
    if not isinstance(text, str) or not isinstance(prefix, str):
        return text
    if text.startswith(prefix):
        return text[len(prefix):]
    pos = 0
    end = len(text)
    for i, word in enumerate(prefix.split()):
        start = pos
        while pos < end and text[pos].isspace():
            pos += 1
        if i > 0 and pos == start:
            # The prefix separates these words with whitespace but text does not.
            return text
        if not text.startswith(word, pos):
            return text
        pos += len(word)
    return text[pos:]
# ----- Core generate -----
def _hash_and_store_image(image):
    """JPEG-encode ``image`` and save it under ``LOGDIR/serve_images/<YYYY-MM-DD>/``.

    Returns ``(md5_hex, path)``. The file is named after the MD5 of the JPEG bytes,
    so identical uploads on the same day map to the same file.
    """
    buf = BytesIO()
    image.save(buf, format="JPEG")
    jpeg_bytes = buf.getvalue()
    image_hash = hashlib.md5(jpeg_bytes).hexdigest()
    t = datetime.datetime.now()
    day = f"{t.year}-{t.month:02d}-{t.day:02d}"
    path = os.path.join(LOGDIR, "serve_images", day, f"{image_hash}.jpg")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # Write the bytes already encoded instead of JPEG-encoding the image a second time.
    with open(path, "wb") as fout:
        fout.write(jpeg_bytes)
    return image_hash, path


def _decode_generated_text(tok, gen_ids, input_ids, stop_str):
    """Turn one generated id sequence into the assistant's reply text.

    Args:
        tok: tokenizer used for decoding.
        gen_ids: 1-D id tensor for the single batch item returned by ``generate``.
        input_ids: ``(1, prompt_len)`` prompt ids that were passed to ``generate``.
        stop_str: conversation separator used as the stop keyword.
    """
    prompt_ids = input_ids[0].to(gen_ids.device)
    prompt_len = prompt_ids.shape[0]
    full_text = None
    # Upstream LLaVA's generate() hands inputs_embeds to HF generate, which then
    # returns only the new tokens (verify for the installed llava version). Slice off
    # a prompt prefix only when the output really starts with the prompt ids. A pure
    # length comparison would chop the beginning off any reply longer than the prompt.
    if gen_ids.shape[0] >= prompt_len and torch.equal(gen_ids[:prompt_len], prompt_ids):
        response = tok.decode(gen_ids[prompt_len:], skip_special_tokens=True).strip()
    else:
        full_text = tok.decode(gen_ids, skip_special_tokens=True)
        # IMAGE_TOKEN_INDEX marks where image features are spliced in. Upstream LLaVA
        # uses an out-of-vocabulary (negative) id for it, so drop it before decoding.
        text_ids = prompt_ids[prompt_ids != IMAGE_TOKEN_INDEX]
        prompt_text = tok.decode(text_ids, skip_special_tokens=True)
        response = _strip_prefix_relaxed(full_text, prompt_text).strip()
    # Drop a trailing stop keyword that survived decoding (e.g. a non-special separator).
    if stop_str and response.endswith(stop_str):
        response = response[: -len(stop_str)].strip()
    if not response:
        if full_text is None:
            full_text = tok.decode(gen_ids, skip_special_tokens=True)
        response = full_text.replace(stop_str, "").strip()
    return response


def generate_response(message_text,
                      image_input,
                      temperature=0.05,
                      top_p=1.0,
                      max_output_tokens=1024,
                      repetition_penalty=1.0,
                      conv_mode_override=None,
                      do_sample=False,  # default: greedy decoding -> deterministic
                      seed=None,
                      use_stop=True):
    """Run one single-turn image+text generation with the shared chatbot.

    Args:
        message_text: user prompt; it is prefixed with the image token.
        image_input: anything process_image_input accepts (URL, local path,
            base64 / data URI string, or ``{"image": <base64>}``).
        temperature, top_p: used only when sampling. Sampling with
            ``temperature <= 0`` falls back to greedy decoding.
        max_output_tokens: maximum number of new tokens.
        repetition_penalty: passed through to GenerationConfig.
        conv_mode_override: conv_templates key to use instead of the chatbot's mode.
        do_sample: sample instead of greedy decoding.
        seed: optional int seed for torch and numpy (best effort).
        use_stop: stop when the conversation separator is generated.

    Returns:
        ``{"status": "success", "response": str, "conversation_id": int}`` or
        ``{"error": str}``. Exceptions are caught and reported as errors.

    Side effects: saves the image under ``LOGDIR/serve_images/`` and appends a
    record to the daily conversation log.
    """
    if not LLAVA_AVAILABLE:
        return {"error": "LLaVA modules not available"}
    try:
        if not message_text or not image_input:
            return {"error": "Both message text and image are required"}

        # Determinism knobs (best effort: an unusable seed is reported and ignored).
        if seed is not None:
            try:
                seed = int(seed)
                torch.manual_seed(seed)
                np.random.seed(seed)
            except Exception as e:
                print(f"[WARNING] ignoring invalid seed {seed!r}: {e}")

        inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
                                        tokenizer, model, image_processor, context_len)

        # Image: decode it, then keep a content-addressed copy for the chat log.
        image = process_image_input(image_input)
        image_hash, filename = _hash_and_store_image(image)

        processed_images = process_images([image], inst.image_processor, inst.model.config)
        if len(processed_images) == 0:
            return {"error": "Image processing returned empty list"}
        image_tensor = processed_images[0].half().to(inst.model.device).unsqueeze(0)

        # Conversation: every request starts from a fresh template (single turn).
        mode = conv_mode_override or inst.conv_mode
        if mode not in conv_templates:
            return {"error": f"Unknown conv_mode: {mode}"}
        inst.conversation = conv_templates[mode].copy()
        conv = inst.conversation
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + message_text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt, inst.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(inst.model.device)

        # Stop criteria.
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        stopping_criteria = None
        if use_stop:
            stopping_criteria = KeywordsStoppingCriteria([stop_str], inst.tokenizer, input_ids)

        # PAD/EOS safety. Compare with `is not None` (not `or`) so an id of 0 is kept.
        tok = inst.tokenizer
        pad_id = tok.pad_token_id
        eos_id = tok.eos_token_id if tok.eos_token_id is not None else pad_id
        if pad_id is None:
            # Safety net (rare): initialize_model() normally assigns a pad token.
            tok.add_special_tokens({"pad_token": tok.eos_token or "</s>"})
            pad_id = tok.pad_token_id
            eos_id = tok.eos_token_id if tok.eos_token_id is not None else pad_id

        # transformers rejects sampling with temperature <= 0, so treat that as greedy.
        # Sampling-only knobs are left out in greedy mode so they are not flagged as unused.
        temperature = float(temperature)
        sample = bool(do_sample) and temperature > 0
        gen_kwargs = {
            "do_sample": sample,
            "max_new_tokens": int(max_output_tokens),
            "repetition_penalty": float(repetition_penalty),
            "pad_token_id": pad_id,
            "eos_token_id": eos_id,
        }
        if sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = float(top_p)
        gen_cfg = GenerationConfig(**gen_kwargs)

        with torch.no_grad():
            outputs = inst.model.generate(
                inputs=input_ids,
                images=image_tensor,
                generation_config=gen_cfg,
                use_cache=True,
                stopping_criteria=[stopping_criteria] if stopping_criteria is not None else None,
                return_dict_in_generate=True,
            )

        response = _decode_generated_text(tok, outputs.sequences[0], input_ids, stop_str)

        # Record the reply in the conversation (fill the pending assistant slot if present).
        last = conv.messages[-1] if conv.messages else None
        if isinstance(last, list) and len(last) > 1:
            last[-1] = response
        else:
            conv.append_message(conv.roles[1], response)

        with open(get_conv_log_filename(), "a") as fout:
            fout.write(json.dumps({
                "type": "chat",
                "model": "PULSE-7b",
                "state": [(message_text, response)],
                "images": [image_hash],
                "images_path": [filename],
            }) + "\n")
        return {"status": "success", "response": response, "conversation_id": id(conv)}
    except Exception as e:
        return {"error": f"Generation failed: {str(e)}"}
# ----- Votes -----
def _vote_result(conversation_id, vote_type, message):
    """Record ``vote_type`` for ``conversation_id`` and wrap the outcome in a status dict."""
    try:
        vote_last_response({"conversation_id": conversation_id}, vote_type, "PULSE-7B")
    except Exception as e:
        return {"error": str(e)}
    return {"status": "success", "message": message}


def upvote_last_response(conversation_id):
    """Record an upvote for ``conversation_id``."""
    return _vote_result(conversation_id, "upvote", "Upvoted")


def downvote_last_response(conversation_id):
    """Record a downvote for ``conversation_id``."""
    return _vote_result(conversation_id, "downvote", "Downvoted")


def flag_response(conversation_id):
    """Flag the response of ``conversation_id``."""
    return _vote_result(conversation_id, "flag", "Flagged")
# ----- Init model (with PAD/EOS safety) -----
def initialize_model():
    """Load PULSE-7B and publish it through the module-level globals.

    On success, sets ``tokenizer``, ``model``, ``image_processor``, ``context_len``
    and ``args``, moves the model to CUDA when available, and returns True. Any
    failure is printed and reported as False; nothing is raised.
    """
    global tokenizer, model, image_processor, context_len, args
    if not LLAVA_AVAILABLE:
        print("LLaVA modules not available, skipping model initialization")
        return False
    try:
        class Args:
            # Fixed settings shaped like a LLaVA CLI argument namespace.
            def __init__(self):
                self.model_path = "PULSE-ECG/PULSE-7B"
                self.model_base = None
                self.num_gpus = 1
                self.conv_mode = None  # filled in by InferenceDemo from the model name
                self.temperature = 0.05
                self.max_new_tokens = 1024
                self.num_frames = 16
                self.load_8bit = False
                self.load_4bit = False
                self.debug = False

        args = Args()
        loaded = load_pretrained_model(
            args.model_path,
            args.model_base,
            get_model_name_from_path(args.model_path),
            args.load_8bit,
            args.load_4bit,
        )
        tok, mdl, img_proc, ctx_len = loaded

        # PAD/EOS safety, step 1: make sure an EOS token exists (best effort).
        if tok.eos_token_id is None and tok.eos_token is None:
            try:
                tok.add_special_tokens({"eos_token": "</s>"})
            except Exception:
                pass

        # Step 2: make sure a PAD token exists. Reuse EOS, else UNK (added if
        # missing), else the literal "</s>".
        if tok.pad_token_id is None:
            if tok.eos_token is None:
                if tok.unk_token is None:
                    try:
                        tok.add_special_tokens({"unk_token": "<unk>"})
                    except Exception:
                        pass
                replacement = tok.unk_token or "</s>"
            else:
                replacement = tok.eos_token
            tok.pad_token = replacement

        tokenizer, model, image_processor, context_len = tok, mdl, img_proc, ctx_len
        if not torch.cuda.is_available():
            print("CUDA not available, using CPU")
            return True
        model = model.to(torch.device("cuda"))
        print("Model moved to CUDA")
        return True
    except Exception as e:
        print(f"Failed to initialize model: {e}")
        return False
# ----- Query entrypoint -----
def _as_bool(value):
    """Interpret a payload flag; the strings "", "0", "false", "no" and "off" count as False.

    ``bool("false")`` is True in Python, so without this a client sending string
    flags would silently enable the option it meant to disable.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(value)


def query(payload):
    """Validate a request payload and run generation, loading the model on first use.

    Accepted keys:
        prompt: "message" | "query" | "prompt" | "istem".
        image: "image" | "image_url" | "img".
        generation: "temperature", "top_p", "max_output_tokens" (or
            "max_new_tokens" / "max_tokens"), "repetition_penalty", "conv_mode",
            "do_sample", "seed", "use_stop".

    Returns:
        generate_response()'s dict, or ``{"error": str}`` when validation or model
        initialization fails.
    """
    global model_initialized
    if not model_initialized:
        print("Initializing model on first query...")
        model_initialized = initialize_model()
        if not model_initialized:
            return {"error": "Model initialization failed"}
    try:
        # Log incoming keys.
        print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")
        # Inputs.
        message_text = (payload.get("message") or payload.get("query") or payload.get("prompt") or payload.get("istem") or "").strip()
        image_input = (payload.get("image") or payload.get("image_url") or payload.get("img") or None)
        # Generation parameters.
        temperature = float(payload.get("temperature", 0.05))
        top_p = float(payload.get("top_p", 1.0))
        max_output_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 1024))))
        repetition_penalty = float(payload.get("repetition_penalty", 1.0))
        conv_mode_override = payload.get("conv_mode", None)
        # Determinism toggles.
        do_sample = _as_bool(payload.get("do_sample", False))  # default: greedy decoding
        seed = payload.get("seed", None)
        use_stop = _as_bool(payload.get("use_stop", True))  # default: stop criteria enabled
        if not message_text:
            return {"error": "Missing prompt text. Provide 'message' (or 'query'/'prompt'/'istem')."}
        if not image_input:
            return {"error": "Missing image. Provide 'image' (url/base64/path) or 'image_url'/'img'."}
        return generate_response(
            message_text=message_text,
            image_input=image_input,
            temperature=temperature,
            top_p=top_p,
            max_output_tokens=max_output_tokens,
            repetition_penalty=repetition_penalty,
            conv_mode_override=conv_mode_override,
            do_sample=do_sample,
            seed=seed,
            use_stop=use_stop,
        )
    except Exception as e:
        return {"error": f"Query failed: {str(e)}"}
# ----- Health / Info -----
def health_check():
    """Report liveness, whether the model is loaded, and which optional backends are present."""
    report = {"status": "healthy", "model_initialized": model_initialized}
    report["cuda_available"] = torch.cuda.is_available()
    report.update(
        llava_available=LLAVA_AVAILABLE,
        transformers_available=TRANSFORMERS_AVAILABLE,
        cv2_available=CV2_AVAILABLE,
        lazy_loading=True,
    )
    return report
def get_model_info():
    """Describe the loaded model, or report that lazy loading has not happened yet."""
    if not model_initialized:
        return {"error": "Model not initialized yet", "lazy_loading": True}
    info = {"model_path": "Unknown", "model_type": "PULSE-7B"}
    if args:
        info["model_path"] = args.model_path
    info["cuda_available"] = torch.cuda.is_available()
    info["device"] = str(model.device) if model else "Unknown"
    return info
# ----- HF Endpoint handler -----
class EndpointHandler:
    """Hugging Face Inference Endpoints entry point that delegates to the module-level query()."""

    def __init__(self, model_dir):
        # Recorded for reference only: query() loads the model lazily from the
        # path hard-coded in initialize_model().
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    @staticmethod
    def _prepare_payload(payload):
        """Unwrap the ``{"inputs": ..., "parameters": {...}}`` request envelope.

        Returns ``payload["inputs"]`` when that key is present. If both ``inputs``
        and ``parameters`` are dicts, the parameters are merged in underneath
        (keys in ``inputs`` win). Anything else is returned unchanged, including
        non-dict payloads, which query() then reports as an error.
        """
        if not isinstance(payload, dict) or "inputs" not in payload:
            return payload
        inputs = payload["inputs"]
        params = payload.get("parameters")
        if isinstance(inputs, dict) and isinstance(params, dict):
            return {**params, **inputs}
        return inputs

    def __call__(self, payload):
        return query(self._prepare_payload(payload))

    def health_check(self):
        return health_check()

    def get_model_info(self):
        return get_model_info()
if __name__ == "__main__":
    # Running the file directly only confirms that it imports; the model itself
    # is loaded lazily on the first query().
    print("Handler loaded and ready.")