"""Interactive affordance mask generation using prefill mode (single forward pass).

Same interactive workflow as chat.py, but uses prefill inference instead of
autoregressive generation. The assistant response "[AFF]." is pre-filled in the
prompt, so the model only does one forward pass to extract mask embeddings.
"""
import argparse
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from model.AffordanceVLM import AffordanceVLMForCausalLM
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)


def parse_args(args):
    """Parse command-line options for the prefill chat script.

    Args:
        args: List of argument strings, typically ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="AffordanceVLM chat (prefill mode)")
    parser.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    parser.add_argument("--vis_save_path", default="./vis_output_prefill", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
    )
    parser.add_argument("--image_size", default=1024, type=int)
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    parser.add_argument("--local-rank", default=0, type=int)
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    # NOTE(review): store_true with default=True means this option can never be
    # switched off from the command line; kept unchanged for compatibility.
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    # Alternative templates that have been used with this script (pass one via
    # --prompt_template):
    #   Segment the most suitable manipulation region on the single target object for the task '{}'.
    #   Segment the affordance map for the task '{}' in this image.
    #   Segment the affordance map of the single target object for the task '{}' in this image.
    #   Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask.
    #   Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask.
    parser.add_argument(
        "--prompt_template",
        type=str,
        default="Segment the most suitable manipulation region on the single target object for the task '{}'.",
        help="Template wrapping language_instruction. Use {} as placeholder.",
    )
    return parser.parse_args(args)


def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input.

    Args:
        x: Image tensor whose last two dimensions are (H, W). The default
            mean/std are shaped (3, 1, 1), so they broadcast over a
            channel-first (3, H, W) image.
        pixel_mean: Per-channel mean subtracted from ``x``.
        pixel_std: Per-channel std that ``x`` is divided by.
        img_size: Target side length of the padded square.

    Returns:
        The normalized tensor, zero-padded on the bottom and right to
        (img_size, img_size). Assumes H and W are already <= img_size
        (negative padding would crop instead).
    """
    x = (x - pixel_mean) / pixel_std
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    # F.pad pads the last dim by (left, right) and the one before by (top, bottom).
    x = F.pad(x, (0, padw, 0, padh))
    return x


def _resolve_dtype(precision):
    """Map a --precision choice to the torch dtype used for the model weights."""
    if precision == "bf16":
        return torch.bfloat16
    if precision == "fp16":
        return torch.half
    return torch.float32


def _cast_to_precision(tensor, precision):
    """Cast an input tensor to the dtype that matches --precision."""
    if precision == "bf16":
        return tensor.bfloat16()
    if precision == "fp16":
        return tensor.half()
    return tensor.float()


def _load_tokenizer(args):
    """Load the tokenizer and register the [SEG] and [AFF] tokens.

    Side effect: stores the new token ids on ``args.seg_token_idx`` and
    ``args.aff_token_idx`` so the model constructor can receive them.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]
    return tokenizer


def _load_model(args, tokenizer):
    """Build the AffordanceVLM model for the requested precision/quantization.

    Returns the model in eval mode, with its vision tower moved to
    ``args.local_rank``.
    """
    torch_dtype = _resolve_dtype(args.precision)

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # Only quantization_config is passed (as in the 8-bit branch): newer
        # transformers versions reject load_in_4bit=True alongside it.
        kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        })
    elif args.load_in_8bit:
        kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **kwargs,
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and not args.load_in_4bit and not args.load_in_8bit:
        # Detach the vision tower so DeepSpeed kernel injection only touches
        # the language model, then re-attach it in fp16 afterwards.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    model.eval()
    return model


def _build_prompt(args, instruction):
    """Wrap a raw task instruction into the full conversation prompt.

    The assistant turn is pre-filled with "[AFF]." so that a single forward
    pass contains the [AFF] token the mask decoder reads from.
    """
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []

    # Apply the instruction template, then prepend the image token and role text.
    prompt = args.prompt_template.format(instruction)
    prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
    if args.use_mm_start_end:
        replace_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], "[AFF].")
    return conv.get_prompt()


def _read_input(message):
    """Prompt the user; return None if stdin is closed or the user interrupts."""
    try:
        return input(message)
    except (EOFError, KeyboardInterrupt):
        print()
        return None


def _save_predictions(save_dir, image_path, image_np, pred_masks):
    """Write each non-empty predicted mask and a red-tinted overlay to save_dir.

    Args:
        save_dir: Output directory.
        image_path: Source image path; its basename without extension names the files.
        image_np: Original RGB image as a uint8 array (H, W, 3).
        pred_masks: Sequence of mask tensors; entries with an empty first
            dimension are skipped, otherwise the first mask (index 0) is used.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    for i, pred_mask in enumerate(pred_masks):
        if pred_mask.shape[0] == 0:
            continue
        # Threshold on the tensor side so bf16/fp16 outputs never hit .numpy().
        pred_mask = (pred_mask[0] > 0).cpu().numpy()

        save_path = os.path.join(save_dir, "{}_mask_{}.jpg".format(stem, i))
        cv2.imwrite(save_path, pred_mask.astype(np.uint8) * 100)
        print("{} has been saved.".format(save_path))

        save_path = os.path.join(save_dir, "{}_masked_img_{}.jpg".format(stem, i))
        save_img = image_np.copy()
        # Blend masked pixels 50/50 with pure red (image is RGB at this point).
        save_img[pred_mask] = (
            image_np * 0.5
            + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
        )[pred_mask]
        save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(save_path, save_img)
        print("{} has been saved.".format(save_path))


def main(args):
    """Run the interactive prefill loop: read instruction + image, save masks.

    Args:
        args: Raw command-line argument list (parsed with ``parse_args``).
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    tokenizer = _load_tokenizer(args)
    model = _load_model(args, tokenizer)
    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    while True:
        instruction = _read_input("Please input your prompt: ")
        if instruction is None:
            break
        prompt = _build_prompt(args, instruction)

        image_path = _read_input("Please input the image path: ")
        if image_path is None:
            break
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        image_np = cv2.imread(image_path)
        if image_np is None:
            # cv2.imread returns None for files it cannot decode.
            print("Failed to read image {}".format(image_path))
            continue
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        h, w = image_np.shape[:2]

        # CLIP branch input.
        image_clip = clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
        image_clip = _cast_to_precision(image_clip.unsqueeze(0).cuda(), args.precision)

        # SAM branch input: resize longest side, normalize, pad to square.
        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        image = _cast_to_precision(image.unsqueeze(0).cuda(), args.precision)

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()
        attention_masks = input_ids.ne(tokenizer.pad_token_id)

        # Echo the full prompt (prefill mode has no generated text). The image
        # placeholder id is filtered out before decoding.
        text_ids = input_ids[0][input_ids[0] != IMAGE_TOKEN_INDEX]
        text_output = tokenizer.decode(text_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        print("text_output: ", text_output)

        # Prefill inference. offset/masks_list/label_list describe a batch of
        # one image with one conversation; the zero masks are placeholders
        # (presumably unused when inference=True — TODO confirm in model.forward).
        labels = input_ids.clone()
        offset = torch.LongTensor([0, 1]).cuda()
        masks_list = [torch.zeros(1, h, w).float().cuda()]
        label_list = [torch.zeros(h, w).long().cuda()]
        with torch.no_grad():
            output_dict = model(
                images=image,
                images_clip=image_clip,
                input_ids=input_ids,
                labels=labels,
                attention_masks=attention_masks,
                offset=offset,
                masks_list=masks_list,
                label_list=label_list,
                resize_list=resize_list,
                inference=True,
            )

        _save_predictions(args.vis_save_path, image_path, image_np, output_dict["pred_masks"])


if __name__ == "__main__":
    main(sys.argv[1:])