# RAGNet / chat_prefill.py
"""
Interactive affordance mask generation using prefill mode (single forward pass).
Same interactive workflow as chat.py, but uses prefill inference instead of
autoregressive generation. The assistant response "[AFF]." is pre-filled in the
prompt, so the model only does one forward pass to extract mask embeddings.
"""
import argparse
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor

from model.AffordanceVLM import AffordanceVLMForCausalLM
from model.llava import conversation as conversation_lib
from model.llava.mm_utils import tokenizer_image_token
from model.segment_anything.utils.transforms import ResizeLongestSide
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)

def parse_args(args):
    parser = argparse.ArgumentParser(description="AffordanceVLM chat (prefill mode)")
    parser.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    parser.add_argument("--vis_save_path", default="./vis_output_prefill", type=str)
    parser.add_argument(
        "--precision", default="bf16", type=str,
        choices=["fp32", "bf16", "fp16"],
    )
    parser.add_argument("--image_size", default=1024, type=int)
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    parser.add_argument("--local-rank", default=0, type=int)
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type", default="llava_v1", type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    parser.add_argument("--prompt_template", type=str,
                        default="Segment the most suitable manipulation region on the single target object for the task '{}'.",
                        help="Template wrapping the language instruction. Use {} as the placeholder.")
    # Alternative prompt templates kept for reference:
    #   Segment the most suitable manipulation region on the single target object for the task '{}'.
    #   Segment the affordance map for the task '{}' in this image.
    #   Segment the affordance map of the single target object for the task '{}' in this image.
    #   Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask.
    #   Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask.
    return parser.parse_args(args)

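
# `preprocess` mirrors SAM's default image preprocessing: pixel_mean/pixel_std
# are the ImageNet statistics in the 0-255 range, and F.pad extends the image
# at the bottom/right to an img_size square after ResizeLongestSide resizing.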
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input."""
    x = (x - pixel_mean) / pixel_std
    h, w = x.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    x = F.pad(x, (0, padw, 0, padh))
    return x

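
# For example (shapes only, assuming img_size=1024): a 480x640 RGB image is
# resized by ResizeLongestSide to 768x1024, then padded here to 1024x1024, so
# the SAM image encoder always sees a fixed square input.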
def main(args):
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # [SEG] and [AFF] are extra vocabulary tokens; their ids are passed to the
    # model so it can locate the hidden states that drive the mask decoder.
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        kwargs.update({
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        })
    elif args.load_in_8bit:
        kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })
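    # Note: llm_int8_skip_modules=["visual_model"] keeps the SAM image encoder
    # out of bitsandbytes quantization (it stays in half precision), so only
    # the language-model weights are quantized.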

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **kwargs,
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
        # Detach the vision tower so DeepSpeed's kernel injection only wraps
        # the language model, then re-attach the tower in fp16.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()
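
    # Interactive loop: each iteration reads one task instruction and one image
    # path from stdin, runs a single prefill forward pass, and writes the
    # predicted mask and overlay into args.vis_save_path. Ctrl+C to exit.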
    while True:
        conv = conversation_lib.conv_templates[args.conv_type].copy()
        conv.messages = []

        prompt = input("Please input your prompt: ")
        # Apply the prompt template.
        prompt = args.prompt_template.format(prompt)
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
        if args.use_mm_start_end:
            replace_token = (
                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        # Prefill: the assistant answer "[AFF]." is appended up front, so the
        # finished prompt already contains the model's response.
        conv.append_message(conv.roles[1], "[AFF].")
        prompt = conv.get_prompt()
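        # With llava_v1 the final prompt looks roughly like:
        #   "... USER: <im_start><image><im_end>\nYou are an embodied robot. ... ASSISTANT: [AFF]."
        # Since the answer is already present, one forward pass below suffices;
        # no autoregressive decoding takes place.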

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found: {}".format(image_path))
            continue

        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]
        h, w = original_size_list[0]

        # Two preprocessing paths: `image_clip` feeds the CLIP vision tower of
        # the VLM, `image` feeds the SAM image encoder at args.image_size.
        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image_clip = image_clip.bfloat16()
        elif args.precision == "fp16":
            image_clip = image_clip.half()
        else:
            image_clip = image_clip.float()

        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]
        image = (
            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image = image.bfloat16()
        elif args.precision == "fp16":
            image = image.half()
        else:
            image = image.float()

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()
        attention_masks = input_ids.ne(tokenizer.pad_token_id)

        # Print the full prompt text (prefill mode produces no generated text).
        text_ids = input_ids[0][input_ids[0] != IMAGE_TOKEN_INDEX]
        text_output = tokenizer.decode(text_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        print("text_output: ", text_output)

        # Prefill inference: the training-style forward pass expects labels and
        # ground-truth masks, so the prompt itself is passed as labels and zero
        # tensors as dummy masks. `offset` marks one conversation per image.
        labels = input_ids.clone()
        offset = torch.LongTensor([0, 1]).cuda()
        masks_list = [torch.zeros(1, h, w).float().cuda()]
        label_list = [torch.zeros(h, w).long().cuda()]

        with torch.no_grad():
            output_dict = model(
                images=image,
                images_clip=image_clip,
                input_ids=input_ids,
                labels=labels,
                attention_masks=attention_masks,
                offset=offset,
                masks_list=masks_list,
                label_list=label_list,
                resize_list=resize_list,
                inference=True,
            )
        pred_masks = output_dict["pred_masks"]
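
        # `pred_masks` is a list with one tensor per image; each tensor holds
        # the predicted masks at the original image resolution.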
        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue

            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0

            save_path = "{}/{}_mask_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            # Cast to uint8 so cv2.imwrite accepts the array (bool * int is int64).
            cv2.imwrite(save_path, (pred_mask * 100).astype(np.uint8))
            print("{} has been saved.".format(save_path))

            save_path = "{}/{}_masked_img_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            # Overlay the mask in red at 50% opacity.
            save_img = image_np.copy()
            save_img[pred_mask] = (
                image_np * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
            save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_path, save_img)
            print("{} has been saved.".format(save_path))


if __name__ == "__main__":
    main(sys.argv[1:])