from types import SimpleNamespace

import gradio as gr
import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    process_images,
    tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
# Disable redundant PyTorch default initialization to speed up model loading
disable_torch_init()

# Load the pretrained LLaVA model in 4-bit precision
MODEL = "4bit/llava-v1.5-13b-3GB"
model_name = get_model_name_from_path(MODEL)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL, model_base=None, model_name=model_name, load_4bit=True
)
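# Note: load_4bit=True uses bitsandbytes quantization under the hood, so the
# runtime needs a CUDA GPU and the bitsandbytes package installed.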
# Build a llava_v0 conversation containing the image token and the user prompt
def create_prompt(prompt: str):
    conv = conv_templates["llava_v0"].copy()
    roles = conv.roles
    prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conv.append_message(roles[0], prompt)
    conv.append_message(roles[1], None)
    return conv.get_prompt(), conv
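# For example, create_prompt("Describe the image") yields the llava_v0 template
# with "<image>\nDescribe the image" as the user turn and an empty assistant
# turn awaiting generation (the exact separator strings depend on the installed
# llava version).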
# Preprocess a PIL image into a padded half-precision tensor on the model device
def process_image(image):
    # process_images reads image_aspect_ratio via attribute access, so pass a
    # namespace rather than a plain dict
    args = SimpleNamespace(image_aspect_ratio="pad")
    image_tensor = process_images([image], image_processor, args)
    return image_tensor.to(model.device, dtype=torch.float16)
# Generate a description for a PIL image supplied by the Gradio interface
def describe_image(image):
    # gr.Image(type="pil") already hands us a PIL image, so no Image.open here;
    # resize returns a new image rather than modifying in place
    image = image.resize((500, 500))
    processed_image = process_image(image)
    prompt, conv = create_prompt("Describe the image")
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=processed_image,
            do_sample=True,
            temperature=0.01,
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    # Decode only the newly generated tokens, skipping the prompt
    description = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return description
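# Optional smoke test, a sketch assuming a local "sample.jpg" (the filename is
# a placeholder); uncomment to verify generation before launching the UI:
# print(describe_image(Image.open("sample.jpg")))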
# Wire the description function into a simple Gradio interface
iface = gr.Interface(
    fn=describe_image,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=gr.Textbox(label="Description"),
    live=True,
)

# Launch the Gradio interface
iface.launch()