Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import os | |
| import re | |
| import sys | |
| import bleach | |
| import cv2 | |
| import gradio as gr | |
| from matplotlib import pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor | |
| from model.LISA import LISAForCausalLM | |
| from model.llava import conversation as conversation_lib | |
| from model.llava.mm_utils import tokenizer_image_token | |
| from model.segment_anything.utils.transforms import ResizeLongestSide | |
| from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, | |
| DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX) | |
| import spaces | |
def parse_args(args):
    """Parse the command-line options of the LISA chat demo.

    Args:
        args: Sequence of argument strings, typically ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with the parsed options. Options spelled
        with a dash (``--vision-tower``, ``--local-rank``) are exposed with
        underscores (``vision_tower``, ``local_rank``).
    """
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="derektan95/LISA-AVS")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    # `store_true` combined with `default=True` can never produce False, so a
    # paired negative flag writing to the same destination is provided. The
    # positive flag is kept so existing command lines continue to work.
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--no-use_mm_start_end",
        dest="use_mm_start_end",
        action="store_false",
        default=True,
        help="do not wrap the image token in start/end tokens",
    )
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Standardize an image tensor per channel and pad it to ``img_size``.

    Args:
        x: Image tensor whose last two dimensions are height and width; it
            must broadcast against ``pixel_mean`` / ``pixel_std``.
        pixel_mean: Per-channel mean subtracted from ``x``.
        pixel_std: Per-channel standard deviation ``x`` is divided by.
        img_size: Target side length of the padded output.

    Returns:
        The normalized tensor, padded on the right and bottom edges so that
        its last two dimensions are both ``img_size``.
    """
    normalized = (x - pixel_mean) / pixel_std

    height, width = normalized.shape[-2:]
    # F.pad orders the padding of the last two dims as
    # (left, right, top, bottom); only right/bottom are extended.
    pad_right = img_size - width
    pad_bottom = img_size - height
    return F.pad(normalized, (0, pad_right, 0, pad_bottom))
# ---------------------------------------------------------------------------
# Module-level setup: parse options, then build the tokenizer, the LISA model
# and the two image pre-processors used by `inference`. Statement order
# matters here (the model is loaded, its vision modules are initialized, and
# only then is it cast/moved), so it is kept exactly as written.
# ---------------------------------------------------------------------------
args = parse_args(sys.argv[1:])
os.makedirs(args.vis_save_path, exist_ok=True)

# Create model
tokenizer = AutoTokenizer.from_pretrained(
    args.version,
    cache_dir=None,
    model_max_length=args.model_max_length,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
# Id of the first token produced for "[SEG]"; handed to the model below.
args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

# Map the --precision string onto a torch dtype (fp32 is the fallback).
torch_dtype = torch.float32
if args.precision == "bf16":
    torch_dtype = torch.bfloat16
elif args.precision == "fp16":
    torch_dtype = torch.half

kwargs = {"torch_dtype": torch_dtype}
if args.load_in_4bit:
    # The quantization settings are carried entirely by `quantization_config`,
    # matching the 8-bit branch. A separate `load_in_4bit=True` kwarg is
    # redundant and is rejected by newer transformers releases when a
    # quantization_config is also supplied.
    kwargs.update(
        {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        }
    )
elif args.load_in_8bit:
    kwargs.update(
        {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        }
    )

model = LISAForCausalLM.from_pretrained(
    args.version,
    low_cpu_mem_usage=True,
    vision_tower=args.vision_tower,
    seg_token_idx=args.seg_token_idx,
    **kwargs,
)

# Keep the model config's special-token ids in sync with the tokenizer.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch_dtype)

# NOTE(review): the bf16 and fp32 branches cast and move the model without
# checking the quantization flags; only the fp16 branch excludes quantized
# loads. Confirm bf16/fp32 are never combined with --load_in_4bit/8bit.
if args.precision == "bf16":
    model = model.bfloat16().cuda()
elif (
    args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
):
    # The vision tower is detached while DeepSpeed wraps the model and is
    # re-attached (in half precision, on GPU) afterwards.
    vision_tower = model.get_model().get_vision_tower()
    model.model.vision_tower = None
    import deepspeed

    model_engine = deepspeed.init_inference(
        model=model,
        dtype=torch.half,
        replace_with_kernel_inject=True,
        replace_method="auto",
    )
    model = model_engine.module
    model.model.vision_tower = vision_tower.half().cuda()
elif args.precision == "fp32":
    model = model.float().cuda()

vision_tower = model.get_model().get_vision_tower()
vision_tower.to(device=args.local_rank)

# Pre-processors: CLIP processing for the vision tower input, and
# longest-side resizing for the segmentation branch input.
clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
transform = ResizeLongestSide(args.image_size)

model.eval()
# Gradio
# Example rows for the demo galleries. Each row is
# [image file path, text instruction], matching the
# (input image, text instruction) inputs the galleries are wired to.
# NOTE(review): the file names look like "<id>_<latitude>_<longitude>.jpg" and
# the directory names like a full taxonomy path -- confirm against the dataset.

# Rows shown under the "In-Domain Taxonomy" heading.
examples_in_domain = [
    [
        "./imgs/examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
        "Where can I find the shore birds (Larus marinus) in this image? Please output segmentation mask.",
    ],
    [
        "./imgs/examples/Animalia_Chordata_Mammalia_Rodentia_Caviidae_Hydrochoerus_hydrochaeris/28871_-12.80255_-69.29999.jpg",
        "Where can I find the capybaras (Hydrochoerus hydrochaeris) in this image? Please output segmentation mask and explain why.",
    ],
    [
        "./imgs/examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
        "Where can I find the crabs (Ocypode quadrata) in this image? Please output segmentation mask.",
    ],
    [
        "./imgs/examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
        "Where can I find the marmots (Marmota marmota) in this image? Please output segmentation mask and explain why.",
    ],
    [
        "./imgs/examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
        "Where can I find monitor lizard (Varanus salvator) in this image? Please output segmentation mask.",
    ],
]

# Rows shown under the "Out-Domain Taxonomy" heading; these prompts spell out
# the full taxonomy string instead of only the binomial name.
examples_out_domain = [
    [
        "./imgs/examples/Animalia_Chordata_Mammalia_Carnivora_Phocidae_Mirounga_angustirostris/27423_35.64005_-121.17595.jpg",
        "Where can I find the seals (Animalia Chordata Mammalia Carnivora Phocidae Mirounga angustirostris) in this image? Please output segmentation mask and explain why.",
    ],
    [
        "./imgs/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Urocyon_littoralis/304160_34.0144_-119.54417.jpg",
        "Where can I find the raccoons (Animalia Chordata Mammalia Carnivora Canidae Urocyon littoralis) in this image? Please output segmentation mask.",
    ],
    [
        "./imgs/examples/Animalia_Chordata_Mammalia_Carnivora_Canidae_Canis_aureus/1528408_13.00422_80.23033.jpg",
        "Where can I find the wolves (Animalia Chordata Mammalia Carnivora Canidae Canis aureus) in this image? Please output segmentation mask and explain why.",
    ],
    [
        "./imgs/examples/Animalia_Chordata_Elasmobranchii_Carcharhiniformes_Carcharhinidae_Triaenodon_obesus/5041_-0.28573_-90.54837.jpg",
        "Where can I find the sharks (Animalia Chordata Elasmobranchii Carcharhiniformes Carcharhinidae Triaenodon obesus) in this image? Please output segmentation mask.",
    ],
    [
        "./imgs/examples/Animalia_Chordata_Reptilia_Crocodylia_Alligatoridae_Caiman_crocodilus/340271_10.52832_-83.49678.jpg",
        "Where can I find the crocodiles (Animalia Chordata Reptilia Crocodylia Alligatoridae Caiman crocodilus) in this image? Please output segmentation mask and explain why.",
    ],
]
# Static text for the demo page.
# NOTE(review): `output_labels` is not referenced anywhere in the visible code.
output_labels = ["Segmentation Output"]

# Page heading; rendered as a level-1 Markdown title by the UI block.
title = "LISA-AVS: LISA 7B Model Finetuned on AVS-Bench Dataset"

# Introductory Markdown/HTML shown under the title. The "\n" sequences are
# ordinary (non-raw) string escapes, so each one adds an extra line break.
description = """
<font size=4>
This is an adapted version of the online demo for <a href='https://github.com/dvlab-research/LISA' target='_blank'>LISA</a>, where we finetune from scratch the LISA model (7B) with data from <a href='https://search-tta.github.io/' target='_blank'>AVS-Bench (Search-TTA)</a>. \n
**Note**: Different prompts can lead to significantly varied results. Please **standardize** your input text prompts to **avoid ambiguity**, and pay attention to whether the **punctuations** of the input are correct. \n
**Usage**: <br>
 (1) To let LISA-AVS **segment something**, input prompt like: "Where can I find the <em>Common Name</em> (<em>Taxonomy Name</em>) in this image? Please output segmentation mask."; <br>
 (2) To let LISA-AVS **output an explanation**, input prompt like: "Where can I find the <em>Common Name</em> (<em>Taxonomy Name</em>) in this image? Please output segmentation mask and explain why."; <br>
 (3) To obtain **solely language output**, you can input like what you should do in current multi-modal LLM (e.g., LLaVA), like: "Where can I find the <em>Common Name</em> (<em>Taxonomy Name</em>) in this image?" <br>
</font>
"""

# Footer links (project page, dataset, upstream LISA repository).
# NOTE(review): the only use of this string in the visible code is a
# commented-out `gr.Markdown` call, so it is currently not rendered; the first
# two <p> elements are also never closed.
article = """
<p style='text-align: center'>
<a href='https://search-tta.github.io/' target='_blank'>
Search-TTA
</a>
\n
<p style='text-align: center'>
<a href='https://huggingface.co/datasets/derektan95/avs-bench' target='_blank'>
AVS-Bench
</a>
\n
<p style='text-align: center'>
<a href='https://github.com/dvlab-research/LISA' target='_blank'> LISA Project </a></p>
"""
def _error_result(message):
    """Return the (image, text) pair shown when a request is rejected."""
    placeholder = np.zeros((128, 128, 3), dtype=np.uint8)
    placeholder[:] = (255, 0, 0)  # solid red in RGB marks an error
    return placeholder, message


def _cast_to_precision(tensor):
    """Cast ``tensor`` to the dtype selected by the ``--precision`` option."""
    if args.precision == "bf16":
        return tensor.bfloat16()
    if args.precision == "fp16":
        return tensor.half()
    return tensor.float()


def _render_score_map(score_map):
    """Colorize a 2-D score map with viridis and append a min/max legend.

    Args:
        score_map: 2-D numpy array of continuous mask scores.

    Returns:
        RGB uint8 array: the colorized map with a 30-pixel-wide vertical
        color bar concatenated on its right edge.
    """
    min_val = float(score_map.min())
    max_val = float(score_map.max())
    # A constant map would divide by zero; fall back to a tiny denominator.
    denom = (max_val - min_val) if (max_val - min_val) != 0 else 1e-8
    normalized = ((score_map - min_val) / denom * 255).astype(np.uint8)

    # OpenCV colormaps return BGR; the UI expects RGB.
    heatmap = cv2.applyColorMap(normalized, cv2.COLORMAP_VIRIDIS)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Vertical legend: gradient from 255 (top) to 0 (bottom), same colormap.
    legend_width = 30
    legend_height = heatmap.shape[0]
    gradient = np.linspace(255, 0, legend_height, dtype=np.uint8).reshape(-1, 1)
    gradient = np.repeat(gradient, legend_width, axis=1)
    legend = cv2.applyColorMap(gradient, cv2.COLORMAP_VIRIDIS)
    legend = cv2.cvtColor(legend, cv2.COLOR_BGR2RGB)

    # Label the top of the bar with the max score and the bottom with the min.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.4
    thickness = 1
    white = (255, 255, 255)
    cv2.putText(
        legend, f"{max_val:.2f}", (2, 12),
        font, font_scale, white, thickness, cv2.LINE_AA,
    )
    cv2.putText(
        legend, f"{min_val:.2f}", (2, legend_height - 4),
        font, font_scale, white, thickness, cv2.LINE_AA,
    )

    return np.concatenate([heatmap, legend], axis=1)


# `spaces.GPU` requests a GPU for the duration of each call on ZeroGPU
# hardware; per the Hugging Face ZeroGPU documentation it has no effect in
# other environments.
@spaces.GPU
def inference(input_image, input_str):
    """Run LISA on one image/instruction pair for the Gradio demo.

    Args:
        input_image: Path of the uploaded image file, or ``None`` when no
            image was provided.
        input_str: The user's text instruction.

    Returns:
        ``(output_image, output_str)``: an RGB uint8 array and the text
        reply. The image is the colorized score map of the last non-empty
        predicted mask, a black 128x128 placeholder when the model produced
        no mask, or a red 128x128 placeholder (with an ``[Error]`` message)
        when the request was rejected.
    """
    # Escape HTML-sensitive characters before the text is used or echoed.
    input_str = bleach.clean(input_str)

    print("input_str: ", input_str, "input_image: ", input_image)

    # Basic validity check: non-empty and reasonable length only.
    if len(input_str.strip()) == 0 or len(input_str) > 1024:
        return _error_result(f"[Error] Invalid input length: {len(input_str)}")

    # cv2.imread returns None instead of raising for unreadable files, and
    # the path itself is None when nothing was uploaded.
    if input_image is None:
        return _error_result("[Error] No input image provided.")
    image_bgr = cv2.imread(input_image)
    if image_bgr is None:
        return _error_result("[Error] Could not read the input image.")
    image_np = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # Build the conversation prompt around the image placeholder token.
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []

    prompt = DEFAULT_IMAGE_TOKEN + "\n" + input_str
    if args.use_mm_start_end:
        replace_token = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], "")
    prompt = conv.get_prompt()

    # Input for the CLIP vision tower.
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")[
            "pixel_values"
        ][0]
        .unsqueeze(0)
        .cuda()
    )
    image_clip = _cast_to_precision(image_clip)

    # Input for the segmentation branch: resize, normalize and pad.
    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]
    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    image = _cast_to_precision(image)

    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()

    output_ids, pred_masks = model.evaluate(
        image_clip,
        image,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=512,
        tokenizer=tokenizer,
    )
    # Drop the image placeholder ids, which the tokenizer cannot decode.
    output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

    text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
    # Remove newlines and collapse double spaces left by decoding.
    text_output = text_output.replace("\n", "").replace("  ", " ")
    text_output = text_output.split("ASSISTANT: ")[-1]

    print("text_output: ", text_output)

    # Each non-empty mask overwrites the previous rendering, so the image
    # returned corresponds to the last non-empty mask.
    save_img = None
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue
        # .float() guards against half-precision dtypes numpy cannot represent.
        score_map = pred_mask.detach().float().cpu().numpy()[0]
        save_img = _render_score_map(score_map)

    output_str = "ASSISTANT: " + text_output
    if save_img is not None:
        output_image = save_img
    else:
        # No segmentation output: return a black placeholder image.
        output_image = np.zeros((128, 128, 3), dtype=np.uint8)
    return output_image, output_str
# Assemble the demo page: heading and description on top, inputs and outputs
# side by side, then one example gallery per taxonomy group.
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    # The `article` footer is currently not rendered.

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            inp_image = gr.Image(
                type="filepath", label="Input Image", sources=["upload"]
            )
            inp_text = gr.Textbox(
                lines=1, placeholder=None, label="Text Instruction"
            )
            run_btn = gr.Button("Run", variant="primary")
        # Right column: outputs.
        with gr.Column():
            out_seg = gr.Image(type="pil", label="Segmentation Output")
            out_text = gr.Textbox(lines=1, label="Text Output")

    # The button runs `inference` on the current inputs.
    run_btn.click(
        fn=inference,
        inputs=[inp_image, inp_text],
        outputs=[out_seg, out_text],
    )

    # Example galleries, in display order; results are not pre-computed.
    for gallery_heading, gallery_rows in (
        ("### In-Domain Taxonomy", examples_in_domain),
        ("### Out-Domain Taxonomy", examples_out_domain),
    ):
        gr.Markdown(gallery_heading)
        gr.Examples(
            examples=gallery_rows,
            inputs=[inp_image, inp_text],
            outputs=[out_seg, out_text],
            fn=inference,
            cache_examples=False,
        )

demo.queue()
demo.launch()