Spaces: Running on Zero
###############################
# chat.py
# Inference for LISA (terminal-based)
###############################
| import argparse | |
| import os | |
| import sys | |
| import cv2 | |
| from matplotlib import pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor | |
| from model.LISA import LISAForCausalLM | |
| from model.llava import conversation as conversation_lib | |
| from model.llava.mm_utils import tokenizer_image_token | |
| from model.segment_anything.utils.transforms import ResizeLongestSide | |
| from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, | |
| DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX) | |
def parse_args(args):
    """Parse command-line options for the LISA terminal chat.

    Args:
        args: List of argument strings, typically ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace``. argparse converts dashes to underscores,
        so ``--vision-tower`` and ``--local-rank`` are exposed as
        ``vision_tower`` and ``local_rank``.
    """
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="xinlai/LISA-13B-llama2-v1")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    # `store_true` combined with `default=True` means this flag alone can never
    # switch the option off, so a negative switch sharing the same dest is
    # provided. Existing invocations that pass `--use_mm_start_end` still work.
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--no_use_mm_start_end",
        dest="use_mm_start_end",
        action="store_false",
        help="do not wrap the image token with image start/end tokens",
    )
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Standardize pixel values per channel, then zero-pad to ``img_size``.

    Args:
        x: Image tensor whose last two dims are (height, width) and whose
            channel dim broadcasts against the (C, 1, 1) mean/std tensors.
        pixel_mean: Per-channel mean subtracted from ``x``.
        pixel_std: Per-channel std that the centered values are divided by.
        img_size: Target side length; padding is added only on the right
            and bottom edges.

    Returns:
        The normalized tensor padded to ``img_size`` x ``img_size`` in its
        last two dims.
    """
    standardized = (x - pixel_mean) / pixel_std
    height, width = standardized.shape[-2:]
    # F.pad's tuple is (left, right, top, bottom) for the last two dims.
    pad_spec = (0, img_size - width, 0, img_size - height)
    return F.pad(standardized, pad_spec)
def _resolve_dtype(precision):
    """Return the torch dtype that corresponds to a ``--precision`` choice."""
    if precision == "bf16":
        return torch.bfloat16
    if precision == "fp16":
        return torch.half
    return torch.float32


def _cast_to_precision(tensor, precision):
    """Cast ``tensor`` to the floating dtype selected by ``--precision``."""
    if precision == "bf16":
        return tensor.bfloat16()
    if precision == "fp16":
        return tensor.half()
    return tensor.float()


def _quantization_kwargs(args):
    """Return extra ``from_pretrained`` kwargs for 4-/8-bit loading, or {}."""
    if args.load_in_4bit:
        # The quantization_config already requests 4-bit loading. Also passing
        # `load_in_4bit=True` is redundant, and newer transformers releases
        # raise a ValueError when both are given. This now matches the 8-bit
        # branch, which only passes the config.
        return {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        }
    if args.load_in_8bit:
        return {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        }
    return {}


def _load_model(args, tokenizer, torch_dtype):
    """Load LISA, initialize its vision modules and move it to the GPU."""
    kwargs = {"torch_dtype": torch_dtype}
    kwargs.update(_quantization_kwargs(args))
    model = LISAForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        **kwargs,
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif (
        args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
    ):
        # The vision tower is detached while DeepSpeed wraps the model and is
        # re-attached in half precision afterwards. Presumably this keeps
        # kernel injection away from it; confirm before changing.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)
    return model


def _read_input(message):
    """Prompt on stdin; return None when the user ends input (EOF / Ctrl-C)."""
    try:
        return input(message)
    except (EOFError, KeyboardInterrupt):
        print()
        return None


def _prepare_images(args, image_np, clip_image_processor, transform):
    """Build the (image_clip, image, resize_list) model inputs for an RGB image."""
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")[
            "pixel_values"
        ][0]
        .unsqueeze(0)
        .cuda()
    )
    image_clip = _cast_to_precision(image_clip, args.precision)

    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]
    # preprocess() pads to its default img_size (1024), independent of
    # --image_size. Presumably the segmentation encoder expects a fixed-size
    # input; confirm before changing.
    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    image = _cast_to_precision(image, args.precision)
    return image_clip, image, resize_list


def _build_input_ids(args, question, tokenizer):
    """Wrap ``question`` in the conversation template and tokenize it on GPU."""
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []
    prompt = DEFAULT_IMAGE_TOKEN + "\n" + question
    if args.use_mm_start_end:
        replace_token = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], "")
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
    return input_ids.unsqueeze(0).cuda()


def _answer_text(text_output):
    """Return the text from the first "ASSISTANT" marker on, else all of it."""
    marker = text_output.find("ASSISTANT")
    # str.find returns -1 when absent; slicing [-1:] would keep only the last char.
    return text_output if marker < 0 else text_output[marker:]


def _save_and_show_mask(pred_mask_np, image_rgb, question, answer, save_path):
    """Save a 3-panel figure (image / red overlay / raw mask), then show it."""
    # Pixels where the raw mask is > 0 are half-blended with red.
    binary_mask = pred_mask_np > 0
    masked_image = image_rgb.copy()
    red_color = np.array([255, 0, 0], dtype=np.float32)
    masked_image[binary_mask] = image_rgb[binary_mask] * 0.5 + red_color * 0.5

    # Min-max normalize the raw mask to [0, 1]; guard against a constant mask.
    min_val = float(pred_mask_np.min())
    max_val = float(pred_mask_np.max())
    denom = (max_val - min_val) if (max_val - min_val) != 0 else 1e-8
    normalized_mask = (pred_mask_np - min_val) / denom

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    ax1.imshow(image_rgb.astype(np.uint8))
    ax1.set_title("Original Image")
    ax1.axis("off")

    ax2.imshow(masked_image.astype(np.uint8))
    ax2.set_title("Binary Mask (>0) in Red")
    ax2.axis("off")

    im3 = ax3.imshow(normalized_mask, cmap='jet', vmin=0, vmax=1)
    ax3.set_title("Raw Mask (Continuous)")
    ax3.axis("off")
    cbar = fig.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04)
    cbar.set_label("Normalized Mask Value")

    fig.suptitle(f"Question: {question}")
    fig.text(0.5, 0.05, f"{answer}", ha='center', va='center')

    # Save BEFORE showing. Once the window is closed some backends tear down
    # the canvas, and a later savefig can produce an empty image.
    fig.savefig(save_path)
    print(f"Figure saved to: {save_path}")
    plt.show(block=True)  # Blocks until the window is closed.
    plt.close(fig)


def main(args):
    """Run the interactive LISA chat loop in the terminal.

    Loads the tokenizer and model named by ``--version``. It then repeatedly
    reads a prompt and an image path from stdin and runs ``model.evaluate``,
    printing the decoded text. For every non-empty predicted mask it writes a
    matplotlib figure into ``--vis_save_path`` and displays it. The loop ends
    when stdin is closed (EOF) or Ctrl-C is pressed at a prompt.

    Args:
        args: Raw command-line arguments, parsed with ``parse_args``.
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # NOTE: the tokenizer is loaded from args.version. An earlier variant used
    # 'Intel/neural-chat-7b-v3-3' as the base for
    # "BigData-KSU/RS-llava-v1.5-7b-LoRA"; it is unclear whether that is needed.
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # NOTE(review): no tokens are added here. This assumes "[SEG]" and the
    # image start/end tokens are already in the checkpoint's vocabulary; confirm.
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]

    torch_dtype = _resolve_dtype(args.precision)
    model = _load_model(args, tokenizer, torch_dtype)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)
    model.eval()

    while True:
        question = _read_input("Please input your prompt: ")
        if question is None:
            break
        image_path = _read_input("Please input the image path: ")
        if image_path is None:
            break
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            # cv2.imread returns None (rather than raising) for unreadable files.
            print("Could not read an image from {}".format(image_path))
            continue
        image_np = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        image_clip, image, resize_list = _prepare_images(
            args, image_np, clip_image_processor, transform
        )
        input_ids = _build_input_ids(args, question, tokenizer)

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        # Drop newlines and collapse double spaces into single ones.
        text_output = text_output.replace("\n", "").replace("  ", " ")
        print("text_output: ", text_output)

        # Invariants shared by every mask of this image.
        answer = _answer_text(text_output)
        image_rgb = image_np.astype(np.float32)
        stem = os.path.splitext(os.path.basename(image_path))[0]

        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue
            # NumPy has no bfloat16, so cast to float before converting.
            pred_mask_np = pred_mask.detach().float().cpu().numpy()[0]
            save_path = os.path.join(
                args.vis_save_path, "{}_matplotlib_{}.png".format(stem, i)
            )
            _save_and_show_mask(pred_mask_np, image_rgb, question, answer, save_path)
if __name__ == "__main__":
    # Pass every CLI token after the program name to main(), which parses them.
    cli_arguments = sys.argv[1:]
    main(cli_arguments)