"""Gradio demo that answers questions about an image using a 4-bit quantized
LLaVA v1.5 13B model.

Loads the model once at import time, then serves a simple two-input
(image + question) Gradio interface.
"""

import textwrap

import gradio as gr
import torch  # was commented out in the original, but torch.float16 / inference_mode are used below
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

# Skip torch's default weight initialization — the checkpoint overwrites it anyway.
disable_torch_init()

MODEL = "4bit/llava-v1.5-13b-3GB"
CONV_MODE = "llava_v0"

model_name = get_model_name_from_path(MODEL)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL, model_base=None, model_name=model_name, load_4bit=True
)


def process_image(image):
    """Preprocess a PIL image into a half-precision tensor on the model's device.

    Uses "pad" aspect-ratio handling so non-square images are padded rather
    than distorted before being fed to the vision tower.
    """
    args = {"image_aspect_ratio": "pad"}
    image_tensor = process_images([image], image_processor, args)
    return image_tensor.to(model.device, dtype=torch.float16)


def create_prompt(prompt: str):
    """Wrap *prompt* in the llava_v0 conversation template.

    The image placeholder token is prepended so the tokenizer knows where to
    splice in the image features. Returns ``(rendered_prompt, conv)`` — the
    conversation object is needed later for its separator/stop-string config.
    """
    conv = conv_templates[CONV_MODE].copy()
    roles = conv.roles
    prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conv.append_message(roles[0], prompt)
    conv.append_message(roles[1], None)  # empty assistant slot the model will fill
    return conv.get_prompt(), conv


def ask_image(image: Image, prompt: str):
    """Run the model on *image* with *prompt* and return the decoded answer."""
    image_tensor = process_image(image)
    prompt, conv = create_prompt(prompt)
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    # Stop generation once the template's separator string is emitted.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.01,  # near-greedy: keep answers nearly deterministic
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    # Decode only the newly generated tokens (drop the echoed prompt).
    return tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()


def describe_image(image, text):
    """Gradio handler: answer *text* about *image*, wrapped to 110 chars/line."""
    result = ask_image(image, text)
    return textwrap.fill(result, width=110)


# Two inputs (image upload + question text), one text output.
demo = gr.Interface(fn=describe_image, inputs=["image", "text"], outputs="text")
demo.launch(inline=False)