import textwrap
from io import BytesIO

import requests
import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
# Skip redundant default weight initialization to speed up model loading.
disable_torch_init()
MODEL = "4bit/llava-v1.5-13b-3GB"
model_name = get_model_name_from_path(MODEL)
model_name
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL, model_base=None, model_name=model_name, load_4bit=True
)
def process_image(image):
    # Pad the image to a square so it isn't distorted when resized.
    args = {"image_aspect_ratio": "pad"}
    image_tensor = process_images([image], image_processor, args)
    return image_tensor.to(model.device, dtype=torch.float16)
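The `image` variable used below is never defined in this excerpt; judging by the `requests`, `BytesIO`, and `PIL.Image` imports, it was presumably fetched from a URL. A minimal sketch, with a hypothetical URL:

# Hypothetical example URL; substitute any RGB image.
IMAGE_URL = "https://example.com/sample.jpg"

response = requests.get(IMAGE_URL)
image = Image.open(BytesIO(response.content)).convert("RGB")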
processed_image = process_image(image)
type(processed_image), processed_image.shape
CONV_MODE = "llava_v0"
def create_prompt(prompt: str):
    conv = conv_templates[CONV_MODE].copy()
    roles = conv.roles
    # Prepend the image placeholder token so the model knows where the image belongs.
    prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conv.append_message(roles[0], prompt)
    conv.append_message(roles[1], None)
    return conv.get_prompt(), conv
prompt, _ = create_prompt("Describe the image")
print(prompt)
def ask_image(image: Image.Image, prompt: str):
    image_tensor = process_image(image)
    prompt, conv = create_prompt(prompt)
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    # Stop generation as soon as the conversation separator is produced.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.01,  # near-greedy sampling for stable answers
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    # Decode only the newly generated tokens, skipping the prompt.
    return tokenizer.decode(
        output_ids[0, input_ids.shape[1] :], skip_special_tokens=True
    ).strip()
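With the helper in place, a single call runs the whole pipeline; the question below is just an illustrative example:

answer = ask_image(image, "What objects are visible in this image?")
print(answer)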
import gradio as gr
def describe_image(image, text):
    result = ask_image(image, text)
    # Wrap long answers so they display nicely in the output box.
    formatted_result = textwrap.fill(result, width=110)
    return formatted_result
# type="pil" makes Gradio pass the upload as a PIL image, matching ask_image.
demo = gr.Interface(
    fn=describe_image,
    inputs=[gr.Image(type="pil"), "text"],
    outputs="text",
)

demo.launch(inline=False)