|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import datetime |
|
|
import torch |
|
|
import numpy as np |
|
|
import hashlib |
|
|
import json |
|
|
import base64 |
|
|
import requests |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
# Probe for OpenCV; the video helpers consult CV2_AVAILABLE before doing any work.
try:
    import cv2  # noqa: F401  (imported for availability; module name bound for later use)
except ImportError:
    CV2_AVAILABLE = False
    print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")
else:
    CV2_AVAILABLE = True
|
|
|
|
|
|
|
|
# LLaVA provides the model loader, prompt templates and multimodal helpers.
# Every model-facing function checks LLAVA_AVAILABLE before touching these names.
try:
    from llava import conversation as conversation_lib
    from llava.constants import (
        DEFAULT_IMAGE_TOKEN,
        IMAGE_TOKEN_INDEX,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.utils import disable_torch_init
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
        KeywordsStoppingCriteria,
    )
except ImportError as exc:
    LLAVA_AVAILABLE = False
    print(f"Warning: LLaVA modules not available: {exc}")
else:
    LLAVA_AVAILABLE = True
|
|
|
|
|
|
|
|
# GenerationConfig is needed to build generation settings in generate_response.
try:
    from transformers import GenerationConfig
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: Transformers not available")
else:
    TRANSFORMERS_AVAILABLE = True
|
|
|
|
|
|
|
|
# The Hub client is optional: it is only used to upload vote logs.
try:
    from huggingface_hub import HfApi, login
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: Hugging Face Hub not available")
else:
    HF_HUB_AVAILABLE = True
|
|
|
|
|
|
|
|
# Optional Hugging Face Hub client used by vote_last_response to upload vote logs.
# It is enabled only when the hub library imported and HF_TOKEN is set.
# LOG_REPO names the target dataset repo; when it is empty, no upload happens.
api = None
repo_name = ""
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        # `write_permission` is deprecated in huggingface_hub (removed in 1.0), where
        # passing it raises TypeError and silently disabled uploads. It only asserted
        # token scope, so omitting it is safe on every version.
        login(token=os.environ["HF_TOKEN"])
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"Failed to initialize HF API: {e}")
        api = None
        repo_name = ""
|
|
|
|
|
|
|
|
# Local output directories: conversation logs (plus saved request images) and votes.
external_log_dir = "./logs"
LOGDIR = external_log_dir
VOTEDIR = "./votes"

# Lazily populated model state. initialize_model() fills these in, and query()
# flips model_initialized after the first successful load.
tokenizer = model = image_processor = context_len = args = None
model_initialized = False
|
|
|
|
|
|
|
|
def get_conv_log_filename():
    """Return today's conversation-log path under LOGDIR, creating LOGDIR if needed.

    The file is named ``YYYY-MM-DD-user_conv.json`` and is opened in append mode
    by callers. The directory is created here (as get_conv_vote_filename does for
    votes), so appending does not depend on another code path having created it.
    """
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    # dirname is "" when LOGDIR is empty; fall back to the CWD, which always exists.
    os.makedirs(os.path.dirname(name) or ".", exist_ok=True)
    return name
|
|
|
|
|
def get_conv_vote_filename():
    """Return today's vote-log path (``VOTEDIR/YYYY-MM-DD-user_vote.json``).

    If that file does not exist yet, its parent directory is created first,
    so the caller can open it for appending straight away.
    """
    today = datetime.datetime.now()
    stem = "{:d}-{:02d}-{:02d}-user_vote.json".format(today.year, today.month, today.day)
    path = os.path.join(VOTEDIR, stem)
    if os.path.isfile(path):
        return path
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path
|
|
|
|
|
def vote_last_response(state, vote_type, model_selector):
    """Append a vote record to today's vote file and upload that file to the Hub.

    This is a no-op unless the Hub client (``api``) and ``repo_name`` were
    configured at import time. Any failure while writing or uploading is
    printed, never raised.

    Args:
        state: JSON-serialisable payload describing what was voted on.
        vote_type: label such as "upvote", "downvote" or "flag".
        model_selector: model name stored alongside the vote.
    """
    if not (api and repo_name):
        return
    try:
        # Resolve the path once so the append and the upload always refer to the
        # same daily file, even if the call straddles midnight.
        vote_file = get_conv_vote_filename()
        with open(vote_file, "a") as fout:
            fout.write(json.dumps({"type": vote_type, "model": model_selector, "state": state}) + "\n")
        # Repo path relative to VOTEDIR, with forward slashes whatever the OS separator.
        # This matches the old "./votes/" stripping for the default VOTEDIR.
        repo_path = os.path.relpath(vote_file, VOTEDIR).replace(os.sep, "/")
        api.upload_file(
            path_or_fileobj=vote_file,
            path_in_repo=repo_path,
            repo_id=repo_name,
            repo_type="dataset",
        )
    except Exception as e:
        print(f"Failed to upload vote file: {e}")
|
|
|
|
|
def is_valid_video_filename(name):
    """Return True if *name* has a supported video extension and OpenCV is available.

    The extension is whatever follows the last dot, compared case-insensitively.
    A name with no dot is compared as a whole.
    """
    if CV2_AVAILABLE:
        extension = name.rpartition(".")[2].lower()
        return extension in {"avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"}
    return False
|
|
|
|
|
def is_valid_image_filename(name):
    """Return True if the text after the last dot in *name* is a known image extension.

    The comparison is case-insensitive; a name without a dot is compared as a whole.
    """
    allowed = {
        "jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp",
        "heic", "heif", "jfif", "svg", "eps", "raw",
    }
    extension = name.rpartition(".")[2].lower()
    return extension in allowed
|
|
|
|
|
def load_image(image_file, timeout=30):
    """Open *image_file* (an "http..." URL or a local path) as an RGB PIL image.

    Args:
        image_file: URL starting with "http", or a filesystem path.
        timeout: seconds handed to ``requests.get`` for URL inputs. Without a
            timeout, a stalled remote server would block this call indefinitely.

    Raises:
        ValueError: the URL responded with a status other than 200.
        requests.exceptions.RequestException: network failure or timeout.
    """
    if image_file.startswith("http"):
        # SECURITY NOTE(review): the URL can come straight from a request payload,
        # so this is a server-side fetch of arbitrary URLs (SSRF risk). Consider an
        # allow-list if the endpoint is exposed publicly.
        resp = requests.get(image_file, timeout=timeout)
        if resp.status_code != 200:
            raise ValueError("Failed to load image from URL")
        return Image.open(BytesIO(resp.content)).convert("RGB")
    return Image.open(image_file).convert("RGB")
|
|
|
|
|
def process_base64_image(base64_string):
    """Decode base64 image data into an RGB PIL image.

    If the string is a ``data:image...`` URL, only the segment after its first
    comma is decoded; otherwise the whole string is decoded.
    """
    is_data_url = base64_string.startswith('data:image')
    encoded = base64_string.split(',')[1] if is_data_url else base64_string
    raw_bytes = base64.b64decode(encoded)
    picture = Image.open(BytesIO(raw_bytes))
    return picture.convert("RGB")
|
|
|
|
|
def process_image_input(image_input):
    """Turn a request's image field into an RGB PIL image.

    Accepted forms:
      * a string starting with "http", or naming an existing local path -> load_image
      * any other string -> treated as base64 (optionally a data URL)
      * a dict with an "image" key -> that value is treated as base64

    Raises:
        ValueError: for any other input type.
    """
    if isinstance(image_input, dict) and "image" in image_input:
        return process_base64_image(image_input["image"])
    if not isinstance(image_input, str):
        raise ValueError("Unsupported image input format")
    # SECURITY NOTE(review): a caller-supplied string that happens to be an existing
    # server path gets read from the local filesystem. Confirm this is intended for
    # an externally reachable endpoint.
    if image_input.startswith("http") or os.path.exists(image_input):
        return load_image(image_input)
    return process_base64_image(image_input)
|
|
|
|
|
|
|
|
class InferenceDemo(object):
    """Bundle a loaded LLaVA model with the conversation template used to prompt it.

    The template ("conv mode") is guessed from the checkpoint name unless
    ``args.conv_mode`` was set. When it was ``None``, ``args.conv_mode`` is
    overwritten in place with the guess.
    """

    # Ordered (substrings, conv_mode) rules matched against the lower-cased model
    # name. The first rule with any matching substring wins; otherwise "llava_v0".
    _CONV_MODE_RULES = (
        (("llama-2",), "llava_llama_2"),
        (("v1", "pulse"), "llava_v1"),
        (("mpt",), "mpt"),
        (("qwen",), "qwen_1_5"),
    )

    def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA modules not available")
        disable_torch_init()

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.context_len = context_len

        lowered_name = get_model_name_from_path(model_path).lower()
        inferred_mode = next(
            (
                mode
                for needles, mode in self._CONV_MODE_RULES
                if any(needle in lowered_name for needle in needles)
            ),
            "llava_v0",
        )

        if args.conv_mode is None or args.conv_mode == inferred_mode:
            args.conv_mode = inferred_mode
        else:
            # An explicitly configured conv_mode always beats the name-based guess.
            print(f"[WARNING] auto inferred conv_mode={inferred_mode}, using {args.conv_mode}")

        self.conv_mode = args.conv_mode
        self.conversation = conv_templates[self.conv_mode].copy()
        self.num_frames = args.num_frames
|
|
|
|
|
class ChatSessionManager:
    """Own at most one InferenceDemo and build it lazily on first request."""

    def __init__(self):
        self.chatbot_instance = None

    def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Unconditionally build a new InferenceDemo and make it the current one."""
        bot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
        self.chatbot_instance = bot
        print(f"Initialized Chatbot instance with ID: {id(bot)}")

    def reset_chatbot(self):
        """Forget the current chatbot so the next get_chatbot() builds a fresh one."""
        self.chatbot_instance = None

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the cached chatbot, creating it from the given arguments if absent.

        Once an instance exists, the arguments are ignored.
        """
        if self.chatbot_instance is not None:
            return self.chatbot_instance
        self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
        return self.chatbot_instance


# Process-wide session manager shared by all request handlers in this module.
chat_manager = ChatSessionManager()
|
|
|
|
|
def clear_history():
    """Reset the shared chatbot's conversation to an empty template.

    Returns:
        ``{"status": "success", "message": ...}`` on success, otherwise
        ``{"error": ...}``. Exceptions are reported, not raised.
    """
    if not LLAVA_AVAILABLE:
        return {"error": "LLaVA modules not available"}
    if model is None:
        # Without a loaded model, get_chatbot() would build and cache an
        # InferenceDemo wrapping model=None. chat_manager never rebuilds it, so
        # later generate_response calls would keep failing after the model loads.
        return {"error": "Model not initialized yet"}
    try:
        inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
                                        tokenizer, model, image_processor, context_len)
        mode = getattr(inst, 'conv_mode', None)
        if mode in conv_templates:
            inst.conversation = conv_templates[mode].copy()
        else:
            # Conversation objects cannot be built without arguments, so keep the
            # current template settings and only drop the accumulated messages.
            fresh = inst.conversation.copy()
            fresh.messages = []
            inst.conversation = fresh
        return {"status": "success", "message": "Conversation history cleared"}
    except Exception as e:
        return {"error": f"Failed to clear history: {str(e)}"}
|
|
|
|
|
|
|
|
def _strip_prefix_relaxed(text: str, prefix: str) -> str: |
|
|
try: |
|
|
if text.startswith(prefix): |
|
|
return text[len(prefix):] |
|
|
t_norm = " ".join(text.split()) |
|
|
p_norm = " ".join(prefix.split()) |
|
|
if t_norm.startswith(p_norm): |
|
|
idx = text.find(prefix.splitlines()[0]) if prefix.splitlines() else -1 |
|
|
if idx >= 0: |
|
|
return text[idx + len(prefix.splitlines()[0]):] |
|
|
except Exception: |
|
|
pass |
|
|
return text |
|
|
|
|
|
|
|
|
def _save_logged_image(image):
    """Store *image* as JPEG at LOGDIR/serve_images/<YYYY-MM-DD>/<md5>.jpg.

    Returns ``(image_hash, filename)``. ``image_hash`` is the MD5 hex digest of
    the JPEG bytes and serves as a content key in the chat log, not as a
    security measure.
    """
    buffer = BytesIO()
    image.save(buffer, format='JPEG')
    jpeg_bytes = buffer.getvalue()
    image_hash = hashlib.md5(jpeg_bytes).hexdigest()

    t = datetime.datetime.now()
    filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    # Write the bytes already encoded for hashing rather than JPEG-encoding the image twice.
    with open(filename, "wb") as fout:
        fout.write(jpeg_bytes)
    return image_hash, filename


def _decode_new_tokens(tokenizer, gen_ids, input_ids, stop_str):
    """Decode only the tokens ``generate`` produced after the prompt, as stripped text.

    Args:
        tokenizer: tokenizer used to build *input_ids*.
        gen_ids: one row of ``outputs.sequences``.
        input_ids: the ``(1, prompt_len)`` prompt batch passed to ``generate``.
        stop_str: conversation separator; removed if the text ends with it.
    """
    prompt_ids = input_ids[0]
    prompt_len = prompt_ids.shape[0]
    # Some generate() paths return prompt + continuation; others return only the
    # continuation (e.g. when the model is driven by inputs_embeds, as LLaVA-style
    # wrappers typically are; NOTE(review): confirm for the installed llava).
    # Check for an exact echo of the prompt instead of comparing lengths. The old
    # length heuristic chopped the first prompt_len tokens off any long answer.
    echoed = gen_ids.shape[0] >= prompt_len and torch.equal(
        gen_ids[:prompt_len].to(prompt_ids.device), prompt_ids
    )
    new_ids = gen_ids[prompt_len:] if echoed else gen_ids
    # Drop negative ids (the image placeholder index, presumably negative in LLaVA).
    # Real vocabulary ids are never negative, and tokenizers reject them.
    new_ids = new_ids[new_ids >= 0]
    text = tokenizer.decode(new_ids, skip_special_tokens=True).strip()
    if stop_str and text.endswith(stop_str):
        text = text[:-len(stop_str)].strip()
    return text


def generate_response(message_text,
                      image_input,
                      temperature=0.05,
                      top_p=1.0,
                      max_output_tokens=1024,
                      repetition_penalty=1.0,
                      conv_mode_override=None,
                      do_sample=False,
                      seed=None,
                      use_stop=True):
    """Run one single-turn, image-grounded generation with the shared chatbot.

    Args:
        message_text: user prompt; the image token is prepended automatically.
        image_input: anything process_image_input accepts (URL, path, base64, dict).
        temperature, top_p, max_output_tokens, repetition_penalty, do_sample:
            forwarded to ``GenerationConfig``.
        conv_mode_override: name of a template in ``conv_templates``, used
            instead of the chatbot's own conv mode for this call.
        seed: optional int-like seed for torch and numpy (best effort).
        use_stop: stop generation at the conversation separator.

    Returns:
        ``{"status": "success", "response": str, "conversation_id": int}`` on
        success, otherwise ``{"error": str}``. Exceptions are caught and reported.

    Side effects: saves the image under LOGDIR and appends a record to the daily
    conversation log.
    """
    if not LLAVA_AVAILABLE:
        return {"error": "LLaVA modules not available"}

    try:
        if not message_text or not image_input:
            return {"error": "Both message text and image are required"}

        if seed is not None:
            try:
                seed = int(seed)
                torch.manual_seed(seed)
                np.random.seed(seed)
            except (TypeError, ValueError, OverflowError, RuntimeError) as err:
                # Seeding is best effort; an unusable seed must not fail the request.
                print(f"[WARNING] Ignoring unusable seed {seed!r}: {err}")

        inst = chat_manager.get_chatbot(args, args.model_path if args else "PULSE-ECG/PULSE-7B",
                                        tokenizer, model, image_processor, context_len)

        image = process_image_input(image_input)
        image_hash, filename = _save_logged_image(image)

        processed_images = process_images([image], inst.image_processor, inst.model.config)
        if len(processed_images) == 0:
            return {"error": "Image processing returned empty list"}
        image_tensor = processed_images[0].half().to(inst.model.device).unsqueeze(0)

        # Start from a fresh template on every call: each request is single-turn.
        inst.conversation = conv_templates[conv_mode_override or inst.conv_mode].copy()
        conv = inst.conversation
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + message_text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt, inst.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(inst.model.device)

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        stopping_criteria = (
            KeywordsStoppingCriteria([stop_str], inst.tokenizer, input_ids) if use_stop else None
        )

        tok = inst.tokenizer
        if tok.pad_token_id is None:
            # Fall back to EOS as padding (initialize_model normally sets a pad token already).
            tok.add_special_tokens({"pad_token": tok.eos_token or "</s>"})
        pad_id = tok.pad_token_id
        # Compare with None rather than using truthiness, so a legitimate EOS id of 0 is kept.
        eos_id = tok.eos_token_id if tok.eos_token_id is not None else pad_id

        gen_cfg = GenerationConfig(
            do_sample=bool(do_sample),
            temperature=float(temperature),
            top_p=float(top_p),
            max_new_tokens=int(max_output_tokens),
            repetition_penalty=float(repetition_penalty),
            pad_token_id=pad_id,
            eos_token_id=eos_id,
        )

        with torch.no_grad():
            outputs = inst.model.generate(
                inputs=input_ids,
                images=image_tensor,
                generation_config=gen_cfg,
                use_cache=True,
                stopping_criteria=[stopping_criteria] if stopping_criteria is not None else None,
                return_dict_in_generate=True,
            )

        response = _decode_new_tokens(inst.tokenizer, outputs.sequences[0], input_ids, stop_str)

        # Fill the empty assistant slot appended above, or add one if it is missing.
        last = conv.messages[-1] if len(conv.messages) > 0 else None
        if isinstance(last, list) and len(last) > 1:
            last[-1] = response
        else:
            conv.append_message(conv.roles[1], response)

        with open(get_conv_log_filename(), "a") as fout:
            fout.write(json.dumps({
                "type": "chat",
                "model": "PULSE-7b",
                "state": [(message_text, response)],
                "images": [image_hash],
                "images_path": [filename],
            }) + "\n")

        return {"status": "success", "response": response, "conversation_id": id(inst.conversation)}

    except Exception as e:
        return {"error": f"Generation failed: {str(e)}"}
|
|
|
|
|
|
|
|
def _record_feedback(conversation_id, vote_type, success_message):
    """Shared body of the feedback endpoints: log a vote and report the outcome."""
    try:
        vote_last_response({"conversation_id": conversation_id}, vote_type, "PULSE-7B")
    except Exception as e:
        return {"error": str(e)}
    return {"status": "success", "message": success_message}


def upvote_last_response(conversation_id):
    """Record an upvote for *conversation_id*."""
    return _record_feedback(conversation_id, "upvote", "Upvoted")


def downvote_last_response(conversation_id):
    """Record a downvote for *conversation_id*."""
    return _record_feedback(conversation_id, "downvote", "Downvoted")


def flag_response(conversation_id):
    """Flag the response identified by *conversation_id*."""
    return _record_feedback(conversation_id, "flag", "Flagged")
|
|
|
|
|
|
|
|
def initialize_model():
    """Load the PULSE-7B checkpoint and publish it through the module-level globals.

    ``tokenizer``, ``model``, ``image_processor``, ``context_len`` and ``args``
    are assigned only after loading, tokenizer fix-ups and device placement have
    all succeeded. A failed load therefore never leaves a half-initialized state.

    Returns:
        True on success; False if LLaVA is unavailable or loading raised.
    """
    global tokenizer, model, image_processor, context_len, args

    if not LLAVA_AVAILABLE:
        print("LLaVA modules not available, skipping model initialization")
        return False

    try:
        class Args:
            """Fixed inference settings, standing in for a CLI argument namespace."""

            def __init__(self):
                self.model_path = "PULSE-ECG/PULSE-7B"
                self.model_base = None
                self.num_gpus = 1
                self.conv_mode = None
                self.temperature = 0.05
                self.max_new_tokens = 1024
                self.num_frames = 16
                self.load_8bit = False
                self.load_4bit = False
                self.debug = False

        cfg = Args()

        model_name = get_model_name_from_path(cfg.model_path)
        tok, mdl, img_proc, ctx_len = load_pretrained_model(
            cfg.model_path, cfg.model_base, model_name, cfg.load_8bit, cfg.load_4bit
        )

        # Make sure generation has usable EOS and PAD tokens.
        if tok.eos_token_id is None and tok.eos_token is None:
            try:
                tok.add_special_tokens({"eos_token": "</s>"})
            except Exception as err:
                print(f"[WARNING] Could not add an EOS token: {err}")
        if tok.pad_token_id is None:
            if tok.eos_token is not None:
                tok.pad_token = tok.eos_token
            else:
                if tok.unk_token is None:
                    try:
                        tok.add_special_tokens({"unk_token": "<unk>"})
                    except Exception as err:
                        print(f"[WARNING] Could not add an UNK token: {err}")
                tok.pad_token = tok.unk_token or "</s>"

        if torch.cuda.is_available():
            # load_pretrained_model may already have placed (or dispatched) the
            # model on a GPU. Only move it when it is not on a CUDA device yet.
            if mdl.device.type != "cuda":
                mdl = mdl.to(torch.device('cuda'))
                print("Model moved to CUDA")
            else:
                print("Model already on CUDA")
        else:
            print("CUDA not available, using CPU")

        args = cfg
        tokenizer, model, image_processor, context_len = tok, mdl, img_proc, ctx_len
        return True

    except Exception as e:
        print(f"Failed to initialize model: {e}")
        return False
|
|
|
|
|
|
|
|
def _as_bool(value, default):
    """Convert a JSON payload flag to bool.

    ``None`` (key missing or null) yields *default*. Strings are parsed, so
    "false", "0", "no", "off" and "" become False; plain ``bool("false")`` would
    be True. Any other value uses normal truthiness.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(value)


def query(payload):
    """Handle one inference request dict and return a result or error dict.

    The payload is parsed and validated first. The model is loaded lazily on the
    first valid request, so malformed requests fail fast without triggering a
    model load. The work is then delegated to generate_response.

    Recognised keys: "message" / "query" / "prompt" / "istem" (text);
    "image" / "image_url" / "img"; "temperature", "top_p",
    "max_output_tokens" / "max_new_tokens" / "max_tokens",
    "repetition_penalty", "conv_mode", "do_sample", "seed", "use_stop".
    """
    global model_initialized

    try:
        print(f"[DEBUG] query payload keys={list(payload.keys()) if hasattr(payload,'keys') else 'N/A'}")

        message_text = (payload.get("message") or payload.get("query") or payload.get("prompt")
                        or payload.get("istem") or "").strip()
        image_input = (payload.get("image") or payload.get("image_url") or payload.get("img") or None)

        temperature = float(payload.get("temperature", 0.05))
        top_p = float(payload.get("top_p", 1.0))
        max_output_tokens = int(payload.get("max_output_tokens",
                                            payload.get("max_new_tokens", payload.get("max_tokens", 1024))))
        repetition_penalty = float(payload.get("repetition_penalty", 1.0))
        conv_mode_override = payload.get("conv_mode", None)
        do_sample = _as_bool(payload.get("do_sample"), False)
        seed = payload.get("seed", None)
        use_stop = _as_bool(payload.get("use_stop"), True)
    except Exception as e:
        return {"error": f"Query failed: {str(e)}"}

    if not message_text:
        return {"error": "Missing prompt text. Provide 'message' (or 'query'/'prompt'/'istem')."}
    if not image_input:
        return {"error": "Missing image. Provide 'image' (url/base64/path) or 'image_url'/'img'."}

    if not model_initialized:
        print("Initializing model on first query...")
        model_initialized = initialize_model()
        if not model_initialized:
            return {"error": "Model initialization failed"}

    try:
        return generate_response(
            message_text=message_text,
            image_input=image_input,
            temperature=temperature,
            top_p=top_p,
            max_output_tokens=max_output_tokens,
            repetition_penalty=repetition_penalty,
            conv_mode_override=conv_mode_override,
            do_sample=do_sample,
            seed=seed,
            use_stop=use_stop,
        )
    except Exception as e:
        return {"error": f"Query failed: {str(e)}"}
|
|
|
|
|
|
|
|
def health_check():
    """Report liveness plus availability flags for optional deps and the lazy model."""
    report = {"status": "healthy"}
    report["model_initialized"] = model_initialized
    report["cuda_available"] = torch.cuda.is_available()
    report["llava_available"] = LLAVA_AVAILABLE
    report["transformers_available"] = TRANSFORMERS_AVAILABLE
    report["cv2_available"] = CV2_AVAILABLE
    report["lazy_loading"] = True
    return report
|
|
|
|
|
def get_model_info():
    """Describe the loaded model, or report that it has not been loaded yet."""
    if model_initialized:
        path = args.model_path if args else "Unknown"
        device = str(model.device) if model else "Unknown"
        return {
            "model_path": path,
            "model_type": "PULSE-7B",
            "cuda_available": torch.cuda.is_available(),
            "device": device,
        }
    return {"error": "Model not initialized yet", "lazy_loading": True}
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Callable entry point that routes request dicts to the module-level handlers.

    NOTE(review): ``model_dir`` is only recorded and printed; the checkpoint
    loaded on first query is a path hard-coded elsewhere in this module.
    """

    def __init__(self, model_dir):
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        """Unwrap an ``{"inputs": ...}`` envelope if present, then run query()."""
        body = payload["inputs"] if "inputs" in payload else payload
        return query(body)

    def health_check(self):
        """Delegate to the module-level health_check()."""
        return health_check()

    def get_model_info(self):
        """Delegate to the module-level get_model_info()."""
        return get_model_info()
|
|
|
|
|
def _main() -> None:
    """Script entry point: confirm the module imported.

    The model itself is loaded lazily by query() on the first request.
    """
    print("Handler loaded and ready.")


if __name__ == "__main__":
    _main()
|
|
|