import gradio as gr import random import numpy as np import os import requests import torch import torchvision.transforms as T from PIL import Image from transformers import AutoProcessor, AutoModelForVision2Seq import cv2 import spaces import ast colors = [ (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), (0, 255, 255), (114, 128, 250), (0, 165, 255), (0, 128, 0), (144, 238, 144), (238, 238, 175), (255, 191, 0), (0, 128, 0), (226, 43, 138), (255, 0, 255), (0, 215, 255), (255, 0, 0), ] color_map = { f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for color_id, color in enumerate(colors) } def is_overlapping(rect1, rect2): x1, y1, x2, y2 = rect1 x3, y3, x4, y4 = rect2 return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4) @spaces.GPU def draw_entity_boxes_on_image(image, entities, show=False, save_path=None, entity_index=-1): """_summary_ Args: image (_type_): image or image path collect_entity_location (_type_): _description_ """ if isinstance(image, Image.Image): image_h = image.height image_w = image.width image = np.array(image)[:, :, [2, 1, 0]] elif isinstance(image, str): if os.path.exists(image): pil_img = Image.open(image).convert("RGB") image = np.array(pil_img)[:, :, [2, 1, 0]] image_h = pil_img.height image_w = pil_img.width else: raise ValueError(f"invaild image path, {image}") elif isinstance(image, torch.Tensor): # pdb.set_trace() image_tensor = image.cpu() reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None] reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None] image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean pil_img = T.ToPILImage()(image_tensor) image_h = pil_img.height image_w = pil_img.width image = np.array(pil_img)[:, :, [2, 1, 0]] else: raise ValueError(f"invaild image format, {type(image)} for {image}") if len(entities) == 0: return image indices = list(range(len(entities))) if entity_index >= 0: indices = [entity_index] # Not to show too many bboxes entities = entities[:len(color_map)] new_image = image.copy() previous_bboxes = [] # size of text text_size = 1 # thickness of text text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1)) box_line = 3 (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line) base_height = int(text_height * 0.675) text_offset_original = text_height - base_height text_spaces = 3 # num_bboxes = sum(len(x[-1]) for x in entities) used_colors = colors # random.sample(colors, k=num_bboxes) color_id = -1 for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities): color_id += 1 if entity_idx not in indices: continue for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes): orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h) # draw bbox color = used_colors[color_id] new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line) l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1 x1 = orig_x1 - l_o y1 = orig_y1 - l_o if y1 < text_height + text_offset_original + 2 * text_spaces: y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces x1 = orig_x1 + r_o (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line) text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1 for prev_bbox in previous_bboxes: while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox): text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces) text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces) y1 += (text_height + text_offset_original + 2 * text_spaces) if text_bg_y2 >= image_h: text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces)) text_bg_y2 = image_h y1 = image_h break alpha = 0.5 for i in range(text_bg_y1, text_bg_y2): for j in range(text_bg_x1, text_bg_x2): if i < image_h and j < image_w: if j < text_bg_x1 + 1.35 * c_width: bg_color = color else: bg_color = [255, 255, 255] new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8) cv2.putText( new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA ) previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2)) pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]]) if save_path: pil_image.save(save_path) if show: pil_image.show() return pil_image ckpt = "microsoft/kosmos-2-patch14-224" model = AutoModelForVision2Seq.from_pretrained(ckpt) processor = AutoProcessor.from_pretrained(ckpt) @spaces.GPU def generate_predictions(image_input, text_input, question=None): user_image_path = "/tmp/user_input_test_image.jpg" image_input.save(user_image_path) image_input = Image.open(user_image_path) if text_input == "Brief": text_input = "An image of" elif text_input == "Detailed": text_input = "Describe this image in detail:" if question: text_input = f"{question}" inputs = processor(text=text_input, images=image_input, return_tensors="pt") generated_ids = model.generate( pixel_values=inputs["pixel_values"], input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], image_embeds=None, image_embeds_position_mask=inputs["image_embeds_position_mask"], use_cache=True, max_new_tokens=128, ) generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] processed_text, entities = processor.post_process_generation(generated_text) annotated_image = draw_entity_boxes_on_image(image_input, entities, show=False) color_id = -1 entity_info = [] filtered_entities = [] for entity in entities: entity_name, (start, end), bboxes = entity if start == end: continue color_id += 1 entity_info.append(((start, end), color_id)) filtered_entities.append(entity) colored_text = [] prev_start = 0 end = 0 for idx, ((start, end), color_id) in enumerate(entity_info): if start > prev_start: colored_text.append((processed_text[prev_start:start], None)) colored_text.append((processed_text[start:end], f"{color_id}")) prev_start = end if end < len(processed_text): colored_text.append((processed_text[end:len(processed_text)], None)) return annotated_image, colored_text, str(filtered_entities) term_of_use = """ ### Terms of use By using this model, users are required to agree to the following terms: The model is intended for academic and research purposes. The utilization of the model to create unsuitable material is strictly forbidden and not endorsed by this work. The accountability for any improper or unacceptable application of the model rests exclusively with the individuals who generated such content. """ # Custom CSS styles for Gradio interface custom_css = """ /* Add your custom CSS styles here */ .gradio-root { font-family: Arial, sans-serif; } .gradio-dropdown select { padding: 8px 10px; border-radius: 5px; border: 1px solid #ccc; background-color: #f9f9f9; } .gradio-radio input[type="radio"]:checked+label { background-color: #007bff; color: #fff; } .gradio-radio input[type="radio"]:not(:checked)+label { background-color: #fff; color: #555; } .gradio-radio input[type="radio"]:focus+label { outline: none; border-color: #007bff; } .gradio-radio label { border-radius: 5px; padding: 8px 12px; margin: 0; cursor: pointer; } .gradio-radio label:hover { background-color: #f0f0f0; } .gradio-slider-container { padding: 10px 0; } .gradio-slider { -webkit-appearance: none; width: 100%; height: 8px; border-radius: 5px; background-color: #f9f9f9; outline: none; opacity: 0.7; -webkit-transition: .2s; transition: opacity .2s; } .gradio-slider::-webkit-slider-thumb { -webkit-appearance: none; appearance: none; width: 16px; height: 16px; border-radius: 50%; background-color: #007bff; cursor: pointer; } .gradio-slider::-moz-range-thumb { width: 16px; height: 16px; border-radius: 50%; background-color: #007bff; cursor: pointer; } """ # Create Gradio interface with gr.Blocks(title="Kosmos-2", theme=gr.themes.Base(), css=custom_css).queue() as demo: # Add Gradio interface components # Add Gradio interface components gr.Markdown((""" # Kosmos-2: Grounding Multimodal Large Language Models to the World ### This model can answer visual questions, does localize objects in a given image, and even caption the image without hallucination! ### To get started, simply pick one of the images. Pick "Brief" or "Detailed" input for captioning. For visual question answering, pick "None" and enter your question. """)) with gr.Row(): with gr.Column(): image_input = gr.Image(type="pil", label="Test Image") text_input = gr.Radio(["Brief", "Detailed", "None"], label="Captioning Detail", value="Brief") question = gr.Textbox(label="Visual Question Answering") run_button = gr.Button(value="Run", visible=True) with gr.Column(): image_output = gr.Image(type="pil") text_output1 = gr.HighlightedText( label="Generated Description", combine_adjacent=False, show_legend=True, ) with gr.Row(): with gr.Column(): gr.Examples(examples=[ ["/content/krina2.png", "Detailed", None], ["/content/krina.png", "Brief", None], ["/content/krina3.png", "None", "What is in this image?"], ], inputs=[image_input, text_input, question]) gr.Markdown(term_of_use) selected = gr.Number(-1, show_label=False, visible=False) entity_output = gr.Textbox(visible=False) def get_text_span_label(evt: gr.SelectData): if evt.value[-1] is None: return -1 return int(evt.value[-1]) text_output1.select(get_text_span_label, None, selected) def update_output_image(img_input, image_output, entities, idx): entities = ast.literal_eval(entities) updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx) return updated_image selected.change(update_output_image, [image_input, image_output, entity_output, selected], [image_output]) run_button.click(fn=generate_predictions, inputs=[image_input, text_input, question], outputs=[image_output, text_output1, entity_output], show_progress=True, queue=True) demo.launch(debug=True)