import torch
from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from conversation import conv_templates, SeparatorStyle
from builder import load_pretrained_model
from utils import disable_torch_init
from mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer
import spaces
from functools import partial


def load_image(image_file):
    """Load an image from a URL or a local path and convert it to RGB.

    PIL raises an exception (rather than returning None) when a file cannot
    be opened, so no explicit None check is needed.
    """
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

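
# On Hugging Face ZeroGPU Spaces, functions that use CUDA must be decorated
# with @spaces.GPU so a GPU is attached for the duration of the call; this is
# what the `spaces` import above is for.
@spaces.GPU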
def run_inference(
    model_path,
    image_file,
    prompt_text,
    model_base=None,
    device="cuda",
    conv_mode=None,
    temperature=0.2,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
    debug=False,
):
    # Model initialization
    disable_torch_init()
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base, model_name, load_8bit, load_4bit
    )
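
    # Note: context_len (the model's maximum context length) is returned by
    # the loader but not enforced below; very long prompts are not truncated.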

    # Determine the conversation template from the model name.
    if "llama-2" in model_name.lower():
        conv_mode_inferred = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode_inferred = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode_inferred = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode_inferred = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode_inferred = "mpt"
    elif "gemma" in model_name.lower():
        conv_mode_inferred = "ferret_gemma_instruct"
    elif "llama" in model_name.lower():
        conv_mode_inferred = "ferret_llama_3"
    else:
        conv_mode_inferred = "llava_v0"

    # Prefer a user-specified conversation mode over the inferred one.
    conv_mode = conv_mode or conv_mode_inferred
    if conv_mode != conv_mode_inferred:
        print(
            f"[WARNING] the auto-inferred conversation mode is {conv_mode_inferred}, "
            f"but `conv_mode` is {conv_mode}; using {conv_mode}"
        )

    conv = conv_templates[conv_mode].copy()
    roles = ('user', 'assistant') if "mpt" in model_name.lower() else conv.roles

    # Load and process the image.
    print("loading image", image_file)
    image = load_image(image_file)
    image_size = image.size
    image_h = 336  # target height for the vision encoder
    image_w = 336  # target width for the vision encoder

    # Preprocess the image according to the model's aspect-ratio setting.
    if model.config.image_aspect_ratio == "square_nocrop":
        # Resize to a fixed square without center-cropping.
        image_tensor = image_processor.preprocess(
            image, return_tensors='pt', do_resize=True,
            do_center_crop=False, size=[image_h, image_w]
        )['pixel_values'][0]
    elif model.config.image_aspect_ratio == "anyres":
        # "anyres" (LLaVA-1.6-style) processing tiles the image, so pass the
        # resize-only preprocessor through to process_images.
        image_process_func = partial(
            image_processor.preprocess, return_tensors='pt', do_resize=True,
            do_center_crop=False, size=[image_h, image_w]
        )
        image_tensor = process_images(
            [image], image_processor, model.config, image_process_func=image_process_func
        )[0]
    else:
        image_tensor = process_images([image], image_processor, model.config)[0]

    # Match the image dtype to the model: float16 if the model was loaded in
    # half precision, float32 otherwise.
    data_type = torch.float16 if model.dtype == torch.float16 else torch.float32

    # Add the batch dimension and move to the GPU.
    images = image_tensor.unsqueeze(0).to(data_type).cuda()

    # Prepend the image token(s) to the first user message.
    if model.config.mm_use_im_start_end:
        prompt_text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt_text
    else:
        prompt_text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text

    # Build the conversation prompt and tokenize it, mapping the image token
    # to IMAGE_TOKEN_INDEX.
    conv.append_message(conv.roles[0], prompt_text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).to(model.device)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2

    # Stream decoded tokens to stdout as they are generated.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    print("image size: ", image_size)

    # Generate the model's response (greedy decoding when temperature is 0).
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images,
            image_sizes=[image_size],
            do_sample=temperature > 0,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            num_beams=1,
            use_cache=True,
        )

    # Decode the output, dropping special tokens and any trailing stop string.
    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    if outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)].strip()
    conv.messages[-1][-1] = outputs

    if debug:
        print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

    return outputs


# Example usage:
# response = run_inference("path_to_model", "path_to_image", "your_prompt")
# print(response)
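
# A minimal command-line entry point, sketched here for convenience; the flag
# names below are illustrative and not part of the original script.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run single-image inference.")
    parser.add_argument("--model-path", required=True)
    parser.add_argument("--image-file", required=True)
    parser.add_argument("--prompt", required=True)
    parser.add_argument("--temperature", type=float, default=0.2)
    args = parser.parse_args()

    response = run_inference(
        args.model_path, args.image_file, args.prompt, temperature=args.temperature
    )
    print(response)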