import os
import datetime
import base64  # required by process_base64_image below
import hashlib
import json

import torch
import numpy as np
import requests
from PIL import Image
from io import BytesIO


try:
    import cv2
    CV2_AVAILABLE = True
except ImportError:
    CV2_AVAILABLE = False
    print("Warning: cv2 (OpenCV) not available. Video processing will be disabled.")


try:
    from llava import conversation as conversation_lib
    from llava.constants import (
        IMAGE_TOKEN_INDEX,
        DEFAULT_IMAGE_TOKEN,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
    )
    from llava.conversation import conv_templates, SeparatorStyle
    from llava.model.builder import load_pretrained_model
    from llava.utils import disable_torch_init
    from llava.mm_utils import (
        tokenizer_image_token,
        process_images,
        get_model_name_from_path,
        KeywordsStoppingCriteria,
    )
    LLAVA_AVAILABLE = True
except ImportError as e:
    LLAVA_AVAILABLE = False
    print(f"Warning: LLaVA modules not available: {e}")


try:
    from transformers import TextStreamer, TextIteratorStreamer
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Warning: Transformers not available")


try:
    from huggingface_hub import HfApi, login
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("Warning: Hugging Face Hub not available")


# Log in to the Hub only when a token is provided; without it, remote logging
# of conversations and votes is disabled.
if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
    try:
        login(token=os.environ["HF_TOKEN"], write_permission=True)
        api = HfApi()
        repo_name = os.environ.get("LOG_REPO", "")
    except Exception as e:
        print(f"Failed to initialize HF API: {e}")
        api = None
        repo_name = ""
else:
    api = None
    repo_name = ""

external_log_dir = "./logs"
LOGDIR = external_log_dir
VOTEDIR = "./votes"

# Lazily initialized globals; populated by initialize_model() on first query.
tokenizer = None
model = None
image_processor = None
context_len = None
args = None
model_path = None  # set by initialize_model(); used by clear_history/generate_response


def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    os.makedirs(os.path.dirname(name), exist_ok=True)
    return name


def get_conv_vote_filename():
    t = datetime.datetime.now()
    name = os.path.join(VOTEDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_vote.json")
    os.makedirs(os.path.dirname(name), exist_ok=True)
    return name


def vote_last_response(state, vote_type, model_selector):
    # Votes are recorded and uploaded only when the Hub API is configured.
    if api and repo_name:
        try:
            with open(get_conv_vote_filename(), "a") as fout:
                data = {
                    "type": vote_type,
                    "model": model_selector,
                    "state": state,
                }
                fout.write(json.dumps(data) + "\n")

            api.upload_file(
                path_or_fileobj=get_conv_vote_filename(),
                path_in_repo=get_conv_vote_filename().replace("./votes/", ""),
                repo_id=repo_name,
                repo_type="dataset",
            )
        except Exception as e:
            print(f"Failed to upload vote file: {e}")


def is_valid_video_filename(name):
    if not CV2_AVAILABLE:
        return False
    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
    ext = name.split(".")[-1].lower()
    return ext in video_extensions


def is_valid_image_filename(name):
    image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"]
    ext = name.split(".")[-1].lower()
    return ext in image_extensions


def sample_frames(video_file, num_frames):
    if not CV2_AVAILABLE:
        raise ImportError("cv2 (OpenCV) not available. Video processing is disabled.")

    video = cv2.VideoCapture(video_file)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against clips shorter than num_frames, which would make the
    # sampling interval zero and crash the modulo below.
    interval = max(total_frames // num_frames, 1)
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            continue
        if i % interval == 0:
            # OpenCV decodes to BGR; convert to RGB before building a PIL image.
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frames.append(pil_img)
    video.release()
    return frames

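
# Illustrative usage (hedged sketch): evenly sample frames from a local clip.
# The path below is a placeholder, not a file shipped with this handler.
#
#   frames = sample_frames("clips/demo.mp4", num_frames=8)
#   print(f"sampled {len(frames)} frames")  # roughly num_frames PIL images
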
def load_image(image_file):
    # "https" URLs also match this prefix, so a single check suffices.
    if image_file.startswith("http"):
        # Modest timeout so a dead URL cannot hang the endpoint.
        response = requests.get(image_file, timeout=30)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            raise ValueError(f"Failed to load image from URL: HTTP {response.status_code}")
    else:
        print("Loading image from local file:", image_file)
        image = Image.open(image_file).convert("RGB")
    return image


def process_base64_image(base64_string):
    """Process a base64-encoded image string."""
    try:
        # Strip a data-URI prefix such as "data:image/jpeg;base64," if present.
        if base64_string.startswith('data:image'):
            base64_string = base64_string.split(',')[1]

        image_data = base64.b64decode(base64_string)
        image = Image.open(BytesIO(image_data)).convert("RGB")
        return image
    except Exception as e:
        raise ValueError(f"Failed to process base64 image: {e}")


def process_image_input(image_input):
    """Process different types of image input (file path, URL, or base64)."""
    if isinstance(image_input, str):
        if image_input.startswith("http"):
            return load_image(image_input)
        elif os.path.exists(image_input):
            return load_image(image_input)
        else:
            # Not a URL and not an existing path, so assume a raw base64 payload.
            return process_base64_image(image_input)
    elif isinstance(image_input, dict) and "image" in image_input:
        # Dict form: {"image": "<base64 string>"}.
        return process_base64_image(image_input["image"])
    else:
        raise ValueError("Unsupported image input format")

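
# Illustrative usage (hedged sketch): the three input forms accepted by
# process_image_input. The URL and paths below are placeholders.
#
#   img = process_image_input("https://example.com/ecg.png")           # URL
#   img = process_image_input("./samples/ecg.png")                     # local path
#   with open("./samples/ecg.png", "rb") as f:                         # base64
#       img = process_image_input(base64.b64encode(f.read()).decode())
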
class InferenceDemo(object):
    def __init__(self, args, model_path, tokenizer, model, image_processor, context_len) -> None:
        if not LLAVA_AVAILABLE:
            raise ImportError("LLaVA modules not available")

        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = (
            tokenizer,
            model,
            image_processor,
            context_len,
        )

        # Infer the conversation template from the model name.
        model_name = get_model_name_from_path(model_path)
        if "llama-2" in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "v1" in model_name.lower() or "pulse" in model_name.lower():
            conv_mode = "llava_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        elif "qwen" in model_name.lower():
            conv_mode = "qwen_1_5"
        else:
            conv_mode = "llava_v0"

        # An explicit --conv-mode overrides the inferred template.
        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
        else:
            args.conv_mode = conv_mode
        self.conv_mode = args.conv_mode  # keep in sync with the template actually used
        self.conversation = conv_templates[args.conv_mode].copy()
        self.num_frames = args.num_frames

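
# Illustrative (hedged sketch): InferenceDemo is normally constructed via
# ChatSessionManager.get_chatbot() after initialize_model() has populated the
# module globals; direct construction would look like:
#
#   demo = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
#   print(demo.conv_mode)  # e.g. "llava_v1" for PULSE checkpoints
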
class ChatSessionManager:
    def __init__(self):
        self.chatbot_instance = None

    def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        self.chatbot_instance = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
        print(f"Initialized Chatbot instance with ID: {id(self.chatbot_instance)}")

    def reset_chatbot(self):
        self.chatbot_instance = None

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        if self.chatbot_instance is None:
            self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
        return self.chatbot_instance


chat_manager = ChatSessionManager()


def clear_history():
    """Clear conversation history."""
    if not LLAVA_AVAILABLE:
        return {"error": "LLaVA modules not available"}
    if args is None:
        return {"error": "Model not initialized yet"}

    try:
        chatbot_instance = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
        chatbot_instance.conversation = conv_templates[chatbot_instance.conv_mode].copy()
        return {"status": "success", "message": "Conversation history cleared"}
    except Exception as e:
        return {"error": f"Failed to clear history: {str(e)}"}


def add_message(message_text, image_input=None):
    """Add a message to the conversation (placeholder; generation happens in generate_response)."""
    return {"status": "success", "message": "Message added"}


def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, max_output_tokens=4096):
    """Generate a response for the given message and image."""
    if not LLAVA_AVAILABLE:
        return {"error": "LLaVA modules not available"}

    try:
        if not message_text or not image_input:
            return {"error": "Both message text and image are required"}

        our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)

        try:
            image = process_image_input(image_input)
        except Exception as e:
            return {"error": f"Failed to process image: {str(e)}"}

        # Hash the image so it can be deduplicated in the log repo.
        all_image_hash = []
        all_image_path = []

        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format='JPEG')
        img_byte_arr = img_byte_arr.getvalue()
        image_hash = hashlib.md5(img_byte_arr).hexdigest()
        all_image_hash.append(image_hash)

        # Persist the image under a date-stamped directory for logging.
        t = datetime.datetime.now()
        filename = os.path.join(
            LOGDIR,
            "serve_images",
            f"{t.year}-{t.month:02d}-{t.day:02d}",
            f"{image_hash}.jpg",
        )
        all_image_path.append(filename)
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            print("image saved to", filename)
            image.save(filename)

        # Preprocess the image into a model-ready tensor.
        try:
            image_tensor = process_images([image], our_chatbot.image_processor, our_chatbot.model.config)[0]
            image_tensor = image_tensor.half().to(our_chatbot.model.device)
            image_tensor = image_tensor.unsqueeze(0)
        except Exception as e:
            return {"error": f"Image processing failed: {str(e)}"}

        # Build the prompt: image token first, then the user text.
        inp = DEFAULT_IMAGE_TOKEN + "\n" + message_text
        our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
        our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
        prompt = our_chatbot.conversation.get_prompt()

        input_ids = tokenizer_image_token(
            prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(our_chatbot.model.device)

        # Stop on the template's separator sequence.
        stop_str = (
            our_chatbot.conversation.sep
            if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
            else our_chatbot.conversation.sep2
        )
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(
            keywords, our_chatbot.tokenizer, input_ids
        )

        with torch.no_grad():
            outputs = our_chatbot.model.generate(
                inputs=input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_output_tokens,
                use_cache=False,
                stopping_criteria=[stopping_criteria],
            )

        # Decode only the newly generated tokens and record the turn.
        response = our_chatbot.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
        if response.endswith(stop_str):
            response = response[: -len(stop_str)].strip()
        our_chatbot.conversation.messages[-1][-1] = response

        # Append the turn to the local conversation log.
        history = [(message_text, response)]
        with open(get_conv_log_filename(), "a") as fout:
            data = {
                "type": "chat",
                "model": "PULSE-7b",
                "state": history,
                "images": all_image_hash,
                "images_path": all_image_path,
            }
            print("#### conv log", data)
            fout.write(json.dumps(data) + "\n")

        # Mirror the log and images to the Hub dataset repo when configured.
        if api and repo_name:
            try:
                for upload_img in all_image_path:
                    api.upload_file(
                        path_or_fileobj=upload_img,
                        path_in_repo=upload_img.replace("./logs/", ""),
                        repo_id=repo_name,
                        repo_type="dataset",
                    )

                api.upload_file(
                    path_or_fileobj=get_conv_log_filename(),
                    path_in_repo=get_conv_log_filename().replace("./logs/", ""),
                    repo_id=repo_name,
                    repo_type="dataset",
                )
            except Exception as e:
                print(f"Failed to upload files: {e}")

        return {
            "status": "success",
            "response": response,
            "conversation_id": id(our_chatbot.conversation),
        }

    except Exception as e:
        return {"error": f"Generation failed: {str(e)}"}


def upvote_last_response(conversation_id):
    """Upvote the last response."""
    try:
        vote_last_response({"conversation_id": conversation_id}, "upvote", "PULSE-7B")
        return {"status": "success", "message": "Thank you for your voting!"}
    except Exception as e:
        return {"error": f"Failed to upvote: {str(e)}"}


def downvote_last_response(conversation_id):
    """Downvote the last response."""
    try:
        vote_last_response({"conversation_id": conversation_id}, "downvote", "PULSE-7B")
        return {"status": "success", "message": "Thank you for your voting!"}
    except Exception as e:
        return {"error": f"Failed to downvote: {str(e)}"}


def flag_response(conversation_id):
    """Flag the last response."""
    try:
        vote_last_response({"conversation_id": conversation_id}, "flag", "PULSE-7B")
        return {"status": "success", "message": "Response flagged successfully"}
    except Exception as e:
        return {"error": f"Failed to flag response: {str(e)}"}


def initialize_model():
    """Initialize the model and tokenizer."""
    global tokenizer, model, image_processor, context_len, args, model_path

    if not LLAVA_AVAILABLE:
        print("LLaVA modules not available, skipping model initialization")
        return False

    try:
        # Default arguments, mirroring the CLI options of the original demo.
        class Args:
            def __init__(self):
                self.model_path = "PULSE-ECG/PULSE-7B"
                self.model_base = None
                self.num_gpus = 1
                self.conv_mode = None
                self.temperature = 0.05
                self.max_new_tokens = 1024
                self.num_frames = 16
                self.load_8bit = False
                self.load_4bit = False
                self.debug = False

        args = Args()

        model_path = args.model_path
        model_name = get_model_name_from_path(args.model_path)
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
        )

        print("### image_processor", image_processor)
        print("### tokenizer", tokenizer)

        if torch.cuda.is_available():
            model = model.to(torch.device('cuda'))
            print("Model moved to CUDA")
        else:
            print("CUDA not available, using CPU")

        return True

    except Exception as e:
        print(f"Failed to initialize model: {e}")
        return False


# The model is loaded lazily on the first query rather than at import time.
model_initialized = False


def query(payload):
    """Main endpoint function for Hugging Face inference API."""
    global model_initialized

    # Lazy initialization: load the model on the first request.
    if not model_initialized:
        print("Initializing model on first query...")
        model_initialized = initialize_model()
        if not model_initialized:
            return {"error": "Model initialization failed"}

    try:
        message_text = payload.get("message", "")
        image_input = payload.get("image", None)
        temperature = payload.get("temperature", 0.05)
        top_p = payload.get("top_p", 1.0)
        max_output_tokens = payload.get("max_output_tokens", 4096)

        if not message_text or not image_input:
            return {"error": "Both 'message' and 'image' are required in the payload"}

        result = generate_response(
            message_text=message_text,
            image_input=image_input,
            temperature=temperature,
            top_p=top_p,
            max_output_tokens=max_output_tokens,
        )

        return result

    except Exception as e:
        return {"error": f"Query failed: {str(e)}"}

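
# Illustrative payload (hedged sketch): field names follow the payload.get()
# calls above; the image value may be a URL, a local path, or a base64 string.
#
#   result = query({
#       "message": "Describe this ECG.",
#       "image": "https://example.com/ecg.png",
#       "temperature": 0.05,
#       "top_p": 1.0,
#       "max_output_tokens": 1024,
#   })
#   print(result.get("response") or result.get("error"))
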
def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_initialized": model_initialized,
        "cuda_available": torch.cuda.is_available(),
        "llava_available": LLAVA_AVAILABLE,
        "transformers_available": TRANSFORMERS_AVAILABLE,
        "cv2_available": CV2_AVAILABLE,
        "lazy_loading": True,
    }


def get_model_info():
    """Get model information."""
    if not model_initialized:
        return {
            "error": "Model not initialized yet",
            "lazy_loading": True,
            "note": "Model will be loaded on first query",
        }

    return {
        "model_path": args.model_path if args else "Unknown",
        "model_type": "PULSE-7B",
        "cuda_available": torch.cuda.is_available(),
        "device": str(model.device) if model else "Unknown",
    }


class EndpointHandler:
    """Hugging Face endpoint handler class."""

    def __init__(self, model_dir):
        """Initialize the endpoint handler."""
        self.model_dir = model_dir
        print(f"EndpointHandler initialized with model_dir: {model_dir}")

    def __call__(self, payload):
        """Main endpoint function - delegates to the module-level query()."""
        return query(payload)

    # The two methods below resolve to the module-level functions of the same
    # name: inside a method body, an unqualified name is looked up in the
    # global namespace, not the class namespace, so there is no recursion.
    def health_check(self):
        """Health check endpoint."""
        return health_check()

    def get_model_info(self):
        """Get model information."""
        return get_model_info()

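
# Illustrative local test (hedged sketch): Hugging Face endpoints instantiate
# this class with the model directory; the values below are placeholders.
#
#   handler = EndpointHandler(model_dir="PULSE-ECG/PULSE-7B")
#   out = handler({"message": "What rhythm is shown?", "image": "./ecg.jpg"})
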
if __name__ == "__main__":
    print("Handler module loaded successfully!")
    print("This handler is now ready for Hugging Face endpoints.")
    print("Use the 'query' function as the main endpoint.")
    print("Or use EndpointHandler class for Hugging Face compatibility.")