| import argparse |
| import os |
| import re |
| import sys |
|
|
| import bleach |
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from PIL import Image |
| from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor |
|
|
| from model.AffordanceVLM import AffordanceVLMForCausalLM |
| from model.llava import conversation as conversation_lib |
| from model.llava.mm_utils import tokenizer_image_token |
| from model.segment_anything.utils.transforms import ResizeLongestSide |
| from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, |
| DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX) |
|
|
| from datetime import datetime |
|
|
|
|
def parse_args(args):
    """Parse command-line options for the AffordanceVLM chat demo.

    Args:
        args: Sequence of argument strings to parse (for example
            ``sys.argv[1:]``).

    Returns:
        argparse.Namespace: The parsed options. Hyphenated flags such as
        ``--vision-tower`` and ``--local-rank`` are exposed as
        ``vision_tower`` and ``local_rank``.
    """
    parser = argparse.ArgumentParser(description="AffordanceVLM chat")
    parser.add_argument("--version", default="./exps/AffordanceVLM-7B")
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    # ``store_true`` combined with ``default=True`` can never yield False, so
    # the flag above is effectively a no-op. Keep it for compatibility and add
    # an explicit opt-out that writes to the same destination.
    parser.add_argument(
        "--no_use_mm_start_end",
        dest="use_mm_start_end",
        action="store_false",
        help="do not wrap the image token in image start/end tokens",
    )
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)
|
|
|
|
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Standardize an image tensor and pad it out to ``img_size`` per side.

    Args:
        x: Image tensor whose last two dimensions are height and width.
        pixel_mean: Mean subtracted from ``x`` (broadcast against it).
        pixel_std: Divisor applied after the mean subtraction.
        img_size: Target size of the last two dimensions after padding.

    Returns:
        The normalized tensor, padded on the right and bottom so that its
        last two dimensions both equal ``img_size``.
    """
    normalized = (x - pixel_mean) / pixel_std

    height, width = normalized.shape[-2], normalized.shape[-1]
    # F.pad takes (left, right, top, bottom) for the last two dimensions;
    # only the right and bottom edges are extended.
    padding = (0, img_size - width, 0, img_size - height)
    return F.pad(normalized, padding)
|
|
# Parse the command line once at import time; the rest of the script reads
# its configuration from ``args``.
args = parse_args(sys.argv[1:])
os.makedirs(args.vis_save_path, exist_ok=True)

# Load the (slow) tokenizer that ships with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    args.version,
    use_fast=False,
    padding_side="right",
    model_max_length=args.model_max_length,
    cache_dir=None,
)
# Reuse the unknown token as the padding token.
tokenizer.pad_token = tokenizer.unk_token

# Record the id of the first token produced for each marker string; both ids
# are later passed to the model constructor.
args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]
|
|
# Translate the --precision flag into a torch dtype. Anything that is not
# bf16 or fp16 (i.e. "fp32") falls back to float32.
torch_dtype = {"bf16": torch.bfloat16, "fp16": torch.half}.get(
    args.precision, torch.float32
)

# Keyword arguments forwarded to ``from_pretrained``. The quantized modes
# force half precision and leave "visual_model" out of quantization.
kwargs = {"torch_dtype": torch_dtype}
if args.load_in_4bit:
    # The 4-bit setting is carried by ``quantization_config`` alone, matching
    # the 8-bit branch. Passing ``load_in_4bit=True`` as a separate kwarg is
    # redundant and may be rejected by newer transformers releases when a
    # quantization_config is also given.
    kwargs["torch_dtype"] = torch.half
    kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        llm_int8_skip_modules=["visual_model"],
    )
elif args.load_in_8bit:
    kwargs["torch_dtype"] = torch.half
    kwargs["quantization_config"] = BitsAndBytesConfig(
        llm_int8_skip_modules=["visual_model"],
        load_in_8bit=True,
    )
|
|
# Build the model from the checkpoint, forwarding the dtype / quantization
# keyword arguments prepared above.
model = AffordanceVLMForCausalLM.from_pretrained(
    args.version,
    low_cpu_mem_usage=True,
    vision_tower=args.vision_tower,
    seg_token_idx=args.seg_token_idx,
    aff_token_idx=args.aff_token_idx,
    **kwargs,
)

# Copy the tokenizer's special-token ids into the model config.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Initialize the vision modules and cast the vision tower to the chosen dtype.
model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch_dtype)

# Cast and place the model according to the requested precision. A quantized
# fp16 run matches none of the branches and is left as loaded.
if args.precision == "bf16":
    model = model.bfloat16().cuda()
elif args.precision == "fp16" and not (args.load_in_4bit or args.load_in_8bit):
    # Detach the vision tower while DeepSpeed wraps the model, then reattach
    # it afterwards in half precision on the GPU.
    # NOTE(review): presumably this keeps kernel injection away from the
    # vision tower - confirm.
    vision_tower = model.get_model().get_vision_tower()
    model.model.vision_tower = None
    import deepspeed

    model_engine = deepspeed.init_inference(
        model=model,
        dtype=torch.half,
        replace_with_kernel_inject=True,
        replace_method="auto",
    )
    model = model_engine.module
    model.model.vision_tower = vision_tower.half().cuda()
elif args.precision == "fp32":
    model = model.float().cuda()

# Move the vision tower to the device given by --local-rank.
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(device=args.local_rank)

# Image pre-processing helpers used by ``inference``.
clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
transform = ResizeLongestSide(args.image_size)

model.eval()
|
|
|
|
| |
# Static content for the Gradio page: heading, Markdown/HTML description,
# footer links and the pre-filled example row.
title = "RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping"

description = """
<font size=4>
This is the online demo of AffordanceVLM. \n
**Note**: **Different prompts can lead to significantly varied results**. \n
**Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n
**Note**: Current model is **AffordanceVLM-7B**. \n
**Usage**: <br>
To let AffordanceVLM **segment something**, input prompt like: "Can you segment the affordance map of xxx in this image?", "What is the affordance map of xxx in this image?"; <br>
</font>
"""

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2507.23734' target='_blank'>
Preprint Paper
</a>
\n
<p style='text-align: center'>
<a href='https://github.com/wudongming97/AffordanceNet' target='_blank'> Github Repo </a></p>
"""

# Each example row is [text instruction, image path], matching the order of
# the two inputs declared on the interface.
examples = [
    [
        "Please segment the affordance map of mug in this image.",
        "/data/AffordanceNet/vis_output/my_workspace.JPG",
    ],
]
output_labels = ["Segmentation Output"]
|
|
|
|
| |
def _load_rgb_image(path):
    """Read the image at ``path`` as RGB, or return None if it cannot be read."""
    bgr = cv2.imread(path)
    if bgr is None:
        return None
    # OpenCV decodes to BGR channel order; reverse the last axis for display.
    return bgr[:, :, ::-1]


def _cast_to_precision(tensor):
    """Cast ``tensor`` to the dtype selected by ``args.precision``."""
    if args.precision == "bf16":
        return tensor.bfloat16()
    if args.precision == "fp16":
        return tensor.half()
    return tensor.float()


def _build_prompt(input_str):
    """Wrap the user instruction in the configured conversation template."""
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []

    prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + input_str
    if args.use_mm_start_end:
        replace_token = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

    conv.append_message(conv.roles[0], prompt)
    # Empty second-role message: the model generates this turn.
    conv.append_message(conv.roles[1], "")
    return conv.get_prompt()


def inference(input_str, input_image):
    """Run the model on one text instruction and one image.

    Args:
        input_str: Text instruction from the UI. It is sanitized with
            ``bleach.clean`` and must consist only of ASCII letters, spaces
            and the punctuation ``, . ! ? ' "``.
        input_image: Filesystem path of the uploaded image.

    Returns:
        tuple: ``(output_image, output_str)``. ``output_image`` is the input
        image in RGB with the predicted mask blended in red, or a bundled
        status image when the input is rejected or no mask is produced (None
        if that status image cannot be read). ``output_str`` is the model's
        text reply or an error message.
    """
    input_str = bleach.clean(input_str or "")

    print("input_str: ", input_str, "input_image: ", input_image)

    # Input validation. The ``+`` quantifier already rejects empty strings.
    if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str):
        # Build a real string; a stray comma previously made this a tuple.
        output_str = "[Error] Invalid input: " + input_str
        return _load_rgb_image("./resources/error_happened.png"), output_str

    # A missing upload or an undecodable file must not crash the request.
    image_np = cv2.imread(input_image) if input_image else None
    if image_np is None:
        return (
            _load_rgb_image("./resources/error_happened.png"),
            "[Error] Could not read the input image.",
        )

    prompt = _build_prompt(input_str)

    # Keep a copy of every upload. Microseconds are part of the file name so
    # that two requests within the same second do not overwrite each other.
    save_dir = "./gradio_images/"
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    cv2.imwrite(os.path.join(save_dir, f"{timestamp}.png"), image_np)

    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # CLIP-preprocessed copy of the image, batched and moved to the GPU.
    image_clip = _cast_to_precision(
        clip_image_processor.preprocess(image_np, return_tensors="pt")[
            "pixel_values"
        ][0]
        .unsqueeze(0)
        .cuda()
    )

    # Resized, normalized and padded copy of the image, batched and on GPU.
    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]
    image = _cast_to_precision(
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )

    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()

    output_ids, pred_masks = model.evaluate(
        image_clip,
        image,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=512,
        tokenizer=tokenizer,
    )
    # Drop image placeholder ids before decoding.
    output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

    text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
    # Remove newlines and collapse doubled spaces.
    text_output = text_output.replace("\n", "").replace("  ", " ")
    # Keep only the text after the last "ASSISTANT: " marker and strip the
    # end-of-sequence tag.
    text_output = text_output.split("ASSISTANT: ")[-1].replace("</s>", "")

    print("text_output: ", text_output)

    save_img = None
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue

        # Use the first mask of this entry; positive values are foreground.
        mask = pred_mask.detach().cpu().numpy()[0] > 0

        # Blend the masked pixels 50/50 with pure red.
        save_img = image_np.copy()
        save_img[mask] = (
            image_np * 0.5
            + mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
        )[mask]

    output_str = "ASSISTANT: " + text_output
    if save_img is None:
        # No mask was produced: show the bundled placeholder instead.
        return _load_rgb_image("./resources/no_seg_out.png"), output_str
    return save_img, output_str
|
|
|
|
# Wire ``inference`` into a two-input / two-output Gradio interface. The
# input order (text, image path) and output order (image, text) match the
# function's parameters and return tuple.
demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
        gr.Image(type="filepath", label="Input Image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Affordance Output"),
        gr.Textbox(lines=1, placeholder=None, label="Text Output"),
    ],
    examples=examples,
    title=title,
    description=description,
    article=article,
    allow_flagging="auto",
)

# Enable request queuing, then serve on all interfaces at port 3200.
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=3200)