Upload folder using huggingface_hub
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +7 -0
- .ipynb_checkpoints/batch_generate-checkpoint.py +401 -0
- .ipynb_checkpoints/batch_generate-checkpoint.sh +14 -0
- .ipynb_checkpoints/batch_generate_prefill_accelerate-checkpoint.py +418 -0
- .ipynb_checkpoints/chat-checkpoint.py +255 -0
- .ipynb_checkpoints/chat_prefill-checkpoint.py +282 -0
- .ipynb_checkpoints/train_aff-checkpoint.py +620 -0
- README.md +79 -3
- app.py +329 -0
- batch_generate.sh +14 -0
- batch_generate_prefill_accelerate.py +418 -0
- chat.py +255 -0
- chat_prefill.py +282 -0
- ckpts/AffordanceVLM-7B/.gitattributes +35 -0
- ckpts/AffordanceVLM-7B/README.md +3 -0
- ckpts/AffordanceVLM-7B/added_tokens.json +7 -0
- ckpts/AffordanceVLM-7B/config.json +42 -0
- ckpts/AffordanceVLM-7B/eval_result.txt +1 -0
- ckpts/AffordanceVLM-7B/generation_config.json +7 -0
- ckpts/AffordanceVLM-7B/pytorch_model-00001-of-00002.bin +3 -0
- ckpts/AffordanceVLM-7B/pytorch_model-00002-of-00002.bin +3 -0
- ckpts/AffordanceVLM-7B/pytorch_model.bin.index.json +930 -0
- ckpts/AffordanceVLM-7B/special_tokens_map.json +24 -0
- ckpts/AffordanceVLM-7B/tokenizer.model +3 -0
- ckpts/AffordanceVLM-7B/tokenizer_config.json +35 -0
- ckpts/sam_vit_h_4b8939.pth +3 -0
- client.py +67 -0
- data_curation/.ipynb_checkpoints/check_dataset-checkpoint.py +100 -0
- data_curation/build_vlpart.py +105 -0
- data_curation/check_dataset.py +100 -0
- data_curation/prompt_generation_handal_easy_reasoning.py +126 -0
- data_curation/prompt_generation_handal_hard_reasoning.py +136 -0
- data_curation/vlpart_sam2_tracking.py +187 -0
- docs/dataset.md +93 -0
- docs/installation.md +10 -0
- docs/training_and_evaluation.md +56 -0
- imgs/.ipynb_checkpoints/AffordanceNet-checkpoint.jpg +3 -0
- imgs/AffordanceNet.jpg +3 -0
- imgs/AffordanceNet.png +3 -0
- merge_lora_weights_and_save_hf_model.py +162 -0
- model/AffordanceVLM.py +428 -0
- model/__pycache__/AffordanceVLM.cpython-39.pyc +0 -0
- model/llava/__init__.py +1 -0
- model/llava/__pycache__/__init__.cpython-39.pyc +0 -0
- model/llava/__pycache__/constants.cpython-39.pyc +0 -0
- model/llava/__pycache__/conversation.cpython-39.pyc +0 -0
- model/llava/__pycache__/mm_utils.cpython-39.pyc +0 -0
- model/llava/constants.py +12 -0
- model/llava/conversation.py +399 -0
- model/llava/mm_utils.py +88 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
imgs/.ipynb_checkpoints/AffordanceNet-checkpoint.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
imgs/AffordanceNet.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
imgs/AffordanceNet.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
vis_output/.ipynb_checkpoints/my_workspace-checkpoint.JPG filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
vis_output/.ipynb_checkpoints/my_workspace_masked_img_0-checkpoint.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
vis_output/my_workspace.JPG filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
vis_output/my_workspace_masked_img_0.jpg filter=lfs diff=lfs merge=lfs -text
|
.ipynb_checkpoints/batch_generate-checkpoint.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch affordance mask generation for per-step datasets.
|
| 3 |
+
|
| 4 |
+
Reads a per-step dataset (converted by convert_lerobot_to_perstep.py) and
|
| 5 |
+
generates affordance masks for every image_primary.jpg and image_wrist.jpg
|
| 6 |
+
using AffordanceVLM.
|
| 7 |
+
|
| 8 |
+
Input structure:
|
| 9 |
+
{data_dir}/
|
| 10 |
+
├── meta_info.h5
|
| 11 |
+
└── episodes/
|
| 12 |
+
└── {episode_id:06d}/
|
| 13 |
+
└── steps/
|
| 14 |
+
└── {step_id:04d}/
|
| 15 |
+
├── other.h5 # language_instruction
|
| 16 |
+
├── image_primary.jpg
|
| 17 |
+
└── image_wrist.jpg
|
| 18 |
+
|
| 19 |
+
Output structure:
|
| 20 |
+
{save_dir}/
|
| 21 |
+
└── episode_{episode_id}/
|
| 22 |
+
└── steps/
|
| 23 |
+
└── step_{step_id}/
|
| 24 |
+
├── image_primary_mask.png # binary 0/255
|
| 25 |
+
└── image_wrist_mask.png
|
| 26 |
+
|
| 27 |
+
Usage:
|
| 28 |
+
python batch_generate.py \
|
| 29 |
+
--data_dir /path/to/perstep_dataset \
|
| 30 |
+
--save_dir /path/to/mask_output \
|
| 31 |
+
--start_episode 0 --end_episode 10
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import argparse
|
| 35 |
+
import os
|
| 36 |
+
import sys
|
| 37 |
+
from pathlib import Path
|
| 38 |
+
|
| 39 |
+
import cv2
|
| 40 |
+
import h5py
|
| 41 |
+
import numpy as np
|
| 42 |
+
import torch
|
| 43 |
+
import torch.nn.functional as F
|
| 44 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
| 45 |
+
|
| 46 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 47 |
+
from model.llava import conversation as conversation_lib
|
| 48 |
+
from model.llava.mm_utils import tokenizer_image_token
|
| 49 |
+
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 50 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 51 |
+
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def parse_args(args):
    """Parse command-line options for batch affordance-mask generation.

    Returns an argparse.Namespace with model settings (mirroring chat.py)
    plus the batch-processing paths and episode range.
    """
    p = argparse.ArgumentParser(
        description="Batch affordance mask generation for per-step datasets"
    )

    # --- Model arguments (kept identical to chat.py) ---
    p.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    p.add_argument("--precision", default="bf16", type=str,
                   choices=["fp32", "bf16", "fp16"])
    p.add_argument("--image_size", default=1024, type=int)
    p.add_argument("--model_max_length", default=512, type=int)
    p.add_argument("--lora_r", default=8, type=int)
    p.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    p.add_argument("--local-rank", default=0, type=int)
    p.add_argument("--load_in_8bit", action="store_true", default=False)
    p.add_argument("--load_in_4bit", action="store_true", default=False)
    p.add_argument("--use_mm_start_end", action="store_true", default=True)
    p.add_argument("--conv_type", default="llava_v1", type=str,
                   choices=["llava_v1", "llava_llama_2"])

    # --- Batch processing arguments ---
    p.add_argument("--data_dir", type=str, required=True,
                   help="Root of per-step dataset (contains episodes/)")
    p.add_argument("--save_dir", type=str, required=True,
                   help="Output directory for masks")
    p.add_argument("--prompt_template", type=str, default="{}",
                   help="Template wrapping language_instruction. Use {} as placeholder.")
    p.add_argument("--start_episode", type=int, default=None,
                   help="First episode index to process (inclusive)")
    p.add_argument("--end_episode", type=int, default=None,
                   help="Last episode index to process (exclusive)")
    return p.parse_args(args)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize a (C, H, W) image tensor and zero-pad it to img_size square.

    Normalization uses ImageNet channel statistics; padding is applied on
    the bottom/right only so pixel coordinates stay aligned with the
    original top-left origin (SAM-style preprocessing).
    """
    normalized = (x - pixel_mean) / pixel_std
    height, width = normalized.shape[-2:]
    # F.pad order is (left, right, top, bottom) for the last two dims.
    return F.pad(normalized, (0, img_size - width, 0, img_size - height))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def load_model(args):
    """Instantiate the tokenizer and AffordanceVLM model for inference.

    Mirrors the setup in chat.py: registers the [SEG]/[AFF] control tokens
    (storing their ids on ``args``), selects the compute dtype from
    --precision, optionally enables 4/8-bit quantization, and moves the
    vision tower onto the local device.

    Returns:
        (model, tokenizer, clip_image_processor, transform)
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the segmentation / affordance control tokens and remember
    # their ids so the model can locate them in generated sequences.
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    dtype = {"bf16": torch.bfloat16, "fp16": torch.half}.get(args.precision, torch.float32)

    load_kwargs = {"torch_dtype": dtype}
    if args.load_in_4bit:
        load_kwargs.update({
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],  # keep SAM un-quantized
            ),
        })
    elif args.load_in_8bit:
        load_kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **load_kwargs,
    )

    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
        # DeepSpeed kernel injection cannot handle the CLIP tower, so detach
        # it during init_inference and re-attach it (fp16, on GPU) afterwards.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()
    return model, tokenizer, clip_image_processor, transform
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def build_prompt(text: str, args) -> str:
    """Assemble the conversation prompt that wraps *text* for the model.

    Prepends the image placeholder token and a fixed embodied-robot
    preamble, optionally surrounds the image token with im_start/im_end
    markers, then renders the turn through the configured conversation
    template with an empty assistant turn (free generation).
    """
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []

    user_turn = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + text
    if args.use_mm_start_end:
        wrapped = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        user_turn = user_turn.replace(DEFAULT_IMAGE_TOKEN, wrapped)

    conv.append_message(conv.roles[0], user_turn)
    conv.append_message(conv.roles[1], "")
    return conv.get_prompt()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def infer_single_image(
    image_path: str,
    prompt_str: str,
    model,
    tokenizer,
    clip_image_processor,
    transform,
    args,
) -> "np.ndarray | None":
    """Run AffordanceVLM on one image.

    Returns the union of all predicted masks as a binary uint8 array
    (values 0/255) at the original resolution, or None when the image
    cannot be read or the model produced no mask.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        print(f" [WARNING] Cannot read image: {image_path}")
        return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    original_size_list = [rgb.shape[:2]]

    # Branch 1: CLIP-normalized tensor for the vision tower.
    clip_tensor = (
        clip_image_processor.preprocess(rgb, return_tensors="pt")["pixel_values"][0]
        .unsqueeze(0)
        .cuda()
    )

    # Branch 2: SAM-style resized + padded tensor for the mask decoder.
    resized = transform.apply_image(rgb)
    resize_list = [resized.shape[:2]]
    sam_tensor = (
        preprocess(torch.from_numpy(resized).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )

    # Cast both inputs to the requested compute precision.
    if args.precision == "bf16":
        clip_tensor, sam_tensor = clip_tensor.bfloat16(), sam_tensor.bfloat16()
    elif args.precision == "fp16":
        clip_tensor, sam_tensor = clip_tensor.half(), sam_tensor.half()
    else:
        clip_tensor, sam_tensor = clip_tensor.float(), sam_tensor.float()

    # Tokenize the prompt (image token expanded by tokenizer_image_token).
    input_ids = tokenizer_image_token(prompt_str, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()

    with torch.no_grad():
        output_ids, pred_masks = model.evaluate(
            clip_tensor,
            sam_tensor,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )

    # Union (logical OR) of every predicted mask at the original resolution.
    h, w = original_size_list[0]
    union = np.zeros((h, w), dtype=bool)
    any_mask = False
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue
        union |= pred_mask.detach().cpu().numpy()[0] > 0  # (H, W) logits > 0
        any_mask = True

    if not any_mask:
        return None

    return union.astype(np.uint8) * 255
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def read_language_instruction(h5_path: str) -> str:
    """Fetch the ``language_instruction`` field from a step's other.h5."""
    with h5py.File(h5_path, "r") as f:
        raw = f["language_instruction"][()]
    # h5py returns string scalars as bytes; normalize to str.
    return raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def main(args):
    """Entry point: walk the per-step dataset and write one mask per camera image.

    For every episode/step, reads the language instruction from other.h5,
    runs AffordanceVLM on image_primary.jpg and image_wrist.jpg, and saves
    a binary mask PNG mirroring the input directory layout under save_dir.
    """
    args = parse_args(args)
    data_dir = Path(args.data_dir)
    save_dir = Path(args.save_dir)

    episodes_dir = data_dir / "episodes"
    if not episodes_dir.is_dir():
        print(f"Error: episodes directory not found at {episodes_dir}")
        sys.exit(1)

    # Collect and sort episode directories (names are zero-padded ints).
    episode_dirs = sorted(
        [d for d in episodes_dir.iterdir() if d.is_dir()],
        key=lambda p: p.name,
    )

    # Optionally restrict to the half-open range [start_episode, end_episode).
    if args.start_episode is not None or args.end_episode is not None:
        start = args.start_episode if args.start_episode is not None else 0
        end = args.end_episode if args.end_episode is not None else len(episode_dirs)
        episode_dirs = [
            d for d in episode_dirs
            if start <= int(d.name) < end
        ]

    print(f"Data dir : {data_dir}")
    print(f"Save dir : {save_dir}")
    print(f"Episodes : {len(episode_dirs)}")
    print(f"Prompt : {args.prompt_template}")
    print()

    # Load model once, up front — dominates startup time.
    print("Loading model...")
    model, tokenizer, clip_image_processor, transform = load_model(args)
    print("Model loaded.\n")

    total_steps = 0
    empty_mask_count = 0

    for ep_dir in episode_dirs:
        episode_id = ep_dir.name  # e.g. "000000"
        steps_dir = ep_dir / "steps"
        if not steps_dir.is_dir():
            print(f" [WARNING] No steps/ in {ep_dir}, skipping.")
            continue

        step_dirs = sorted(
            [d for d in steps_dir.iterdir() if d.is_dir()],
            key=lambda p: p.name,
        )

        for step_dir in step_dirs:
            step_id = step_dir.name  # e.g. "0000"

            # Read language instruction for this step.
            other_h5 = step_dir / "other.h5"
            if not other_h5.exists():
                print(f" [WARNING] Missing other.h5 in {step_dir}, skipping.")
                continue
            language_instruction = read_language_instruction(str(other_h5))

            # Build the full conversation prompt from the template.
            query_text = args.prompt_template.format(language_instruction)
            prompt_str = build_prompt(query_text, args)

            # Output directory mirrors the input: episodes/{episode_id}/steps/{step_id}/.
            out_dir = save_dir / "episodes" / episode_id / "steps" / step_id
            out_dir.mkdir(parents=True, exist_ok=True)

            # Process both cameras.
            for cam_name in ("image_primary", "image_wrist"):
                img_path = step_dir / f"{cam_name}.jpg"
                mask_path = out_dir / f"{cam_name}_mask.png"

                if not img_path.exists():
                    print(f" [WARNING] Missing {img_path}, skipping.")
                    continue

                mask = infer_single_image(
                    str(img_path), prompt_str,
                    model, tokenizer, clip_image_processor, transform, args,
                )

                if mask is None:
                    # No prediction: fall back to an all-zero mask so every
                    # image has a counterpart file. BUGFIX: guard the re-read —
                    # if the image itself is unreadable (the likely reason
                    # infer_single_image returned None), the previous code
                    # crashed here with AttributeError on `.shape` of None.
                    img = cv2.imread(str(img_path))
                    if img is None:
                        print(f" [WARNING] Unreadable image {img_path}, no mask written.")
                        continue
                    mask = np.zeros(img.shape[:2], dtype=np.uint8)
                    empty_mask_count += 1

                cv2.imwrite(str(mask_path), mask)

            total_steps += 1
            if total_steps % 50 == 0:
                print(f" Processed {total_steps} steps (episode {episode_id}, step {step_id})")

        print(f"Episode {episode_id} done ({len(step_dirs)} steps)")

    print(f"\nFinished. {total_steps} steps processed, {empty_mask_count} empty masks.")
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
if __name__ == "__main__":
|
| 401 |
+
main(sys.argv[1:])
|
.ipynb_checkpoints/batch_generate-checkpoint.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Run affordance-mask generation over each of the four LIBERO subsets,
# one after another, on GPU 0.

SRC_ROOT="/gemini/space/wrz/libero_per_frame"
TGT_ROOT="/gemini/space/wrz/ragnet_results"

datasets=(libero_object libero_goal libero_spatial libero_10)

for ds in "${datasets[@]}"; do
    echo "========== Processing ${ds} =========="
    CUDA_VISIBLE_DEVICES=0 python batch_generate.py \
        --data_dir "${SRC_ROOT}/${ds}_converted" \
        --save_dir "${TGT_ROOT}/${ds}"
    echo "========== ${ds} done =========="
    echo
done
|
.ipynb_checkpoints/batch_generate_prefill_accelerate-checkpoint.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch affordance mask generation for per-step datasets.
|
| 3 |
+
|
| 4 |
+
Reads a per-step dataset (converted by convert_lerobot_to_perstep.py) and
|
| 5 |
+
generates affordance masks for every image_primary.jpg and image_wrist.jpg
|
| 6 |
+
using AffordanceVLM.
|
| 7 |
+
|
| 8 |
+
Input structure:
|
| 9 |
+
{data_dir}/
|
| 10 |
+
├── meta_info.h5
|
| 11 |
+
└── episodes/
|
| 12 |
+
└── {episode_id:06d}/
|
| 13 |
+
└── steps/
|
| 14 |
+
└── {step_id:04d}/
|
| 15 |
+
├── other.h5 # language_instruction
|
| 16 |
+
├── image_primary.jpg
|
| 17 |
+
└── image_wrist.jpg
|
| 18 |
+
|
| 19 |
+
Output structure:
|
| 20 |
+
{save_dir}/
|
| 21 |
+
└── episodes/
|
| 22 |
+
└── {episode_id:06d}/
|
| 23 |
+
└── steps/
|
| 24 |
+
└── {step_id:04d}/
|
| 25 |
+
├── image_primary_mask.png # binary 0/255
|
| 26 |
+
└── image_wrist_mask.png
|
| 27 |
+
|
| 28 |
+
Usage:
|
| 29 |
+
CUDA_VISIBLE_DEVICES=1 python batch_generate_prefill_accelerate.py \
|
| 30 |
+
--data_dir /gemini/space/wrz/libero_per_frame/libero_spatial_converted \
|
| 31 |
+
--save_dir /gemini/space/wrz/ragnet_results/libero_spatial
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import argparse
|
| 35 |
+
import os
|
| 36 |
+
import sys
|
| 37 |
+
from pathlib import Path
|
| 38 |
+
|
| 39 |
+
import cv2
|
| 40 |
+
import h5py
|
| 41 |
+
import numpy as np
|
| 42 |
+
import torch
|
| 43 |
+
import torch.nn.functional as F
|
| 44 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
| 45 |
+
|
| 46 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 47 |
+
from model.llava import conversation as conversation_lib
|
| 48 |
+
from model.llava.mm_utils import tokenizer_image_token
|
| 49 |
+
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 50 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 51 |
+
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def parse_args(args):
    """Parse command-line options for batch (prefill) affordance-mask generation.

    Returns an argparse.Namespace with model settings (mirroring chat.py)
    plus the batch-processing paths and episode range.
    """
    p = argparse.ArgumentParser(
        description="Batch affordance mask generation for per-step datasets"
    )

    # --- Model arguments (kept identical to chat.py) ---
    p.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    p.add_argument("--precision", default="bf16", type=str,
                   choices=["fp32", "bf16", "fp16"])
    p.add_argument("--image_size", default=1024, type=int)
    p.add_argument("--model_max_length", default=512, type=int)
    p.add_argument("--lora_r", default=8, type=int)
    p.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    p.add_argument("--local-rank", default=0, type=int)
    p.add_argument("--load_in_8bit", action="store_true", default=False)
    p.add_argument("--load_in_4bit", action="store_true", default=False)
    p.add_argument("--use_mm_start_end", action="store_true", default=True)
    p.add_argument("--conv_type", default="llava_v1", type=str,
                   choices=["llava_v1", "llava_llama_2"])

    # --- Batch processing arguments ---
    p.add_argument("--data_dir", type=str, required=True,
                   help="Root of per-step dataset (contains episodes/)")
    p.add_argument("--save_dir", type=str, required=True,
                   help="Output directory for masks")
    p.add_argument("--prompt_template", type=str, default="{}",
                   help="Template wrapping language_instruction. Use {} as placeholder.")
    # Alternative templates tried during development (kept for reference):
    #   "{}"
    #   Segment the most suitable manipulation region on the single target object for the task '{}'.
    #   Segment the affordance map for the task '{}' in this image.
    #   Segment the affordance map of the single target object for the task '{}' in this image.
    #   Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask.
    #   Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask.
    p.add_argument("--start_episode", type=int, default=None,
                   help="First episode index to process (inclusive)")
    p.add_argument("--end_episode", type=int, default=None,
                   help="Last episode index to process (exclusive)")
    return p.parse_args(args)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize a (C, H, W) image tensor and zero-pad it to img_size square.

    Normalization uses ImageNet channel statistics; padding is applied on
    the bottom/right only so pixel coordinates stay aligned with the
    original top-left origin (SAM-style preprocessing).
    """
    normalized = (x - pixel_mean) / pixel_std
    height, width = normalized.shape[-2:]
    # F.pad order is (left, right, top, bottom) for the last two dims.
    return F.pad(normalized, (0, img_size - width, 0, img_size - height))
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def load_model(args):
    """Instantiate the tokenizer and AffordanceVLM model for inference.

    Mirrors the setup in chat.py: registers the [SEG]/[AFF] control tokens
    (storing their ids on ``args``), selects the compute dtype from
    --precision, optionally enables 4/8-bit quantization, and moves the
    vision tower onto the local device.

    Returns:
        (model, tokenizer, clip_image_processor, transform)
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the segmentation / affordance control tokens and remember
    # their ids so the model can locate them in generated sequences.
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    dtype = {"bf16": torch.bfloat16, "fp16": torch.half}.get(args.precision, torch.float32)

    load_kwargs = {"torch_dtype": dtype}
    if args.load_in_4bit:
        load_kwargs.update({
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],  # keep SAM un-quantized
            ),
        })
    elif args.load_in_8bit:
        load_kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **load_kwargs,
    )

    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
        # DeepSpeed kernel injection cannot handle the CLIP tower, so detach
        # it during init_inference and re-attach it (fp16, on GPU) afterwards.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()
    return model, tokenizer, clip_image_processor, transform
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def build_prompt(text: str, args) -> str:
    """Assemble the full conversation prompt for a single text query.

    The assistant turn is pre-filled with "[AFF]." so that a single forward
    pass (prefill mode) is enough to produce the affordance embedding.
    """
    template = conversation_lib.conv_templates[args.conv_type].copy()
    template.messages = []

    query = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + text
    if args.use_mm_start_end:
        # Wrap the image placeholder with explicit start/end markers.
        wrapped = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        query = query.replace(DEFAULT_IMAGE_TOKEN, wrapped)

    template.append_message(template.roles[0], query)
    template.append_message(template.roles[1], "[AFF].")
    return template.get_prompt()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def infer_single_image(
    image_path: str,
    prompt_str: str,
    model,
    tokenizer,
    clip_image_processor,
    transform,
    args,
) -> "np.ndarray | None":
    """Run inference on a single image. Returns binary mask (H, W) uint8 0/255 or None.

    Pipeline: read + RGB-convert the image, preprocess it twice (CLIP branch
    for the VLM and SAM branch for the mask decoder), tokenize the pre-filled
    prompt, run ONE forward pass with ``inference=True`` (no autoregressive
    generation), then union every predicted mask into a single binary mask.
    Returns None when the image is unreadable or no mask was produced.
    """
    image_np = cv2.imread(image_path)
    if image_np is None:
        print(f" [WARNING] Cannot read image: {image_path}")
        return None
    # cv2 loads BGR; downstream processors expect RGB.
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # CLIP preprocessing
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
        .unsqueeze(0)
        .cuda()
    )
    # Cast the CLIP input to match the model's precision.
    if args.precision == "bf16":
        image_clip = image_clip.bfloat16()
    elif args.precision == "fp16":
        image_clip = image_clip.half()
    else:
        image_clip = image_clip.float()

    # SAM preprocessing
    image = transform.apply_image(image_np)
    # Size after longest-side resize, needed later to undo SAM's square padding.
    resize_list = [image.shape[:2]]
    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image = image.bfloat16()
    elif args.precision == "fp16":
        image = image.half()
    else:
        image = image.float()

    # Tokenize
    input_ids = tokenizer_image_token(prompt_str, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()
    attention_masks = input_ids.ne(tokenizer.pad_token_id)

    # Prefill inference (single forward pass instead of autoregressive generation)
    h, w = original_size_list[0]
    # Dummy training-style inputs: the forward signature expects labels/masks
    # even at inference time — presumably only the [AFF] embeddings are used
    # to drive the mask decoder (TODO confirm against AffordanceVLM.forward).
    labels = input_ids.clone()
    offset = torch.LongTensor([0, 1]).cuda()  # one sample: batch slice [0, 1)
    masks_list = [torch.zeros(1, h, w).float().cuda()]
    label_list = [torch.zeros(h, w).long().cuda()]

    with torch.no_grad():
        output_dict = model(
            images=image,
            images_clip=image_clip,
            input_ids=input_ids,
            labels=labels,
            attention_masks=attention_masks,
            offset=offset,
            masks_list=masks_list,
            label_list=label_list,
            resize_list=resize_list,
            inference=True,
        )

    pred_masks = output_dict["pred_masks"]

    # Merge all predicted masks via union (logical OR)
    merged = np.zeros((h, w), dtype=bool)
    has_mask = False
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue
        # Logits > 0 is the binarization threshold.
        mask_np = pred_mask.detach().cpu().numpy()[0]  # (H, W)
        merged |= (mask_np > 0)
        has_mask = True

    if not has_mask:
        return None

    # 0/255 uint8 so the result can be written directly as a PNG mask.
    return (merged.astype(np.uint8) * 255)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def read_language_instruction(h5_path: str) -> str:
    """Return the ``language_instruction`` entry stored in an ``other.h5`` file."""
    with h5py.File(h5_path, "r") as handle:
        raw = handle["language_instruction"][()]
    # HDF5 string scalars come back as bytes; decode before stringifying.
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    return str(raw)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def main(args):
    """Batch-generate affordance masks for every step of every episode.

    Walks ``data_dir/episodes/{episode}/steps/{step}``, reads the language
    instruction from ``other.h5``, runs prefill inference on both camera
    images, and writes ``{camera}_mask.png`` under the mirrored layout in
    ``save_dir``. Exits with status 1 when the episodes directory is missing.
    """
    args = parse_args(args)
    data_dir = Path(args.data_dir)
    save_dir = Path(args.save_dir)

    episodes_dir = data_dir / "episodes"
    if not episodes_dir.is_dir():
        print(f"Error: episodes directory not found at {episodes_dir}")
        sys.exit(1)

    # Collect and sort episode directories (names are zero-padded, e.g. "000000")
    episode_dirs = sorted(
        [d for d in episodes_dir.iterdir() if d.is_dir()],
        key=lambda p: p.name,
    )

    # Restrict to the half-open range [start_episode, end_episode), if given.
    if args.start_episode is not None or args.end_episode is not None:
        start = args.start_episode if args.start_episode is not None else 0
        end = args.end_episode if args.end_episode is not None else len(episode_dirs)
        episode_dirs = [
            d for d in episode_dirs
            if start <= int(d.name) < end
        ]

    print(f"Data dir : {data_dir}")
    print(f"Save dir : {save_dir}")
    print(f"Episodes : {len(episode_dirs)}")
    print(f"Prompt : {args.prompt_template}")
    print()

    # Load model
    print("Loading model...")
    model, tokenizer, clip_image_processor, transform = load_model(args)
    print("Model loaded.\n")

    total_steps = 0
    empty_mask_count = 0

    for ep_dir in episode_dirs:
        episode_id = ep_dir.name  # e.g. "000000"
        steps_dir = ep_dir / "steps"
        if not steps_dir.is_dir():
            print(f" [WARNING] No steps/ in {ep_dir}, skipping.")
            continue

        step_dirs = sorted(
            [d for d in steps_dir.iterdir() if d.is_dir()],
            key=lambda p: p.name,
        )

        for step_dir in step_dirs:
            step_id = step_dir.name  # e.g. "0000"

            # Read language instruction
            other_h5 = step_dir / "other.h5"
            if not other_h5.exists():
                print(f" [WARNING] Missing other.h5 in {step_dir}, skipping.")
                continue
            language_instruction = read_language_instruction(str(other_h5))

            # Build prompt
            query_text = args.prompt_template.format(language_instruction)
            prompt_str = build_prompt(query_text, args)

            # Output directory (same structure as input: episodes/{episode_id}/steps/{step_id}/)
            out_dir = save_dir / "episodes" / episode_id / "steps" / step_id
            out_dir.mkdir(parents=True, exist_ok=True)

            # Process both cameras
            for cam_name in ("image_primary", "image_wrist"):
                img_path = step_dir / f"{cam_name}.jpg"
                mask_path = out_dir / f"{cam_name}_mask.png"

                if not img_path.exists():
                    print(f" [WARNING] Missing {img_path}, skipping.")
                    continue

                mask = infer_single_image(
                    str(img_path), prompt_str,
                    model, tokenizer, clip_image_processor, transform, args,
                )

                if mask is None:
                    # Save a blank mask so downstream consumers find a file
                    # for every step, and warn.
                    src = cv2.imread(str(img_path))
                    if src is None:
                        # BUGFIX: the original called .shape on cv2.imread()'s
                        # return without a None check; an unreadable image
                        # (the likely reason infer returned None) crashed here.
                        print(f" [WARNING] Cannot read {img_path} for blank mask, skipping.")
                        continue
                    h, w = src.shape[:2]
                    mask = np.zeros((h, w), dtype=np.uint8)
                    empty_mask_count += 1

                cv2.imwrite(str(mask_path), mask)

            total_steps += 1
            if total_steps % 50 == 0:
                print(f" Processed {total_steps} steps (episode {episode_id}, step {step_id})")

        print(f"Episode {episode_id} done ({len(step_dirs)} steps)")

    print(f"\nFinished. {total_steps} steps processed, {empty_mask_count} empty masks.")
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
if __name__ == "__main__":
|
| 418 |
+
main(sys.argv[1:])
|
.ipynb_checkpoints/chat-checkpoint.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
| 10 |
+
|
| 11 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 12 |
+
from model.llava import conversation as conversation_lib
|
| 13 |
+
from model.llava.mm_utils import tokenizer_image_token
|
| 14 |
+
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 15 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 16 |
+
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parse_args(args):
    """Parse command-line options for the interactive chat demo.

    Returns an argparse.Namespace with model path, precision, quantization
    flags, and conversation-template settings.
    """
    p = argparse.ArgumentParser(description="LISA chat")
    # Model checkpoint and output location.
    p.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    p.add_argument("--vis_save_path", default="./vis_output", type=str)
    # Numeric precision used for inference.
    p.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    p.add_argument("--image_size", default=1024, type=int, help="image size")
    p.add_argument("--model_max_length", default=512, type=int)
    p.add_argument("--lora_r", default=8, type=int)
    p.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    p.add_argument("--local-rank", default=0, type=int, help="node rank")
    # Optional bitsandbytes quantization.
    p.add_argument("--load_in_8bit", action="store_true", default=False)
    p.add_argument("--load_in_4bit", action="store_true", default=False)
    p.add_argument("--use_mm_start_end", action="store_true", default=True)
    p.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return p.parse_args(args)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and zero-pad the image to an img_size square.

    Expects a (C, H, W) tensor; padding is applied on the bottom and right so
    the image content stays anchored at the top-left corner.
    """
    normalized = (x - pixel_mean) / pixel_std
    height, width = normalized.shape[-2:]
    return F.pad(normalized, (0, img_size - width, 0, img_size - height))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def main(args):
    """Interactive affordance chat: load the model once, then loop over
    (prompt, image path) pairs from stdin, run autoregressive generation via
    model.evaluate, and save each predicted mask and overlay image.
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the special segmentation/affordance tokens and record their ids
    # (add_tokens return value is unused).
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    num_added_tokens = tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    # Map the requested precision onto a torch dtype (fp32 default).
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # 4-bit NF4 quantization; the SAM visual model is kept un-quantized.
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, aff_token_idx=args.aff_token_idx, **kwargs
    )

    # Keep the model config in sync with the (extended) tokenizer.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif (
        args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
    ):
        # fp16 path uses deepspeed kernel injection; the vision tower is
        # detached first so deepspeed does not try to replace its modules,
        # then re-attached in half precision afterwards.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()

    # Interactive loop: one fresh conversation per (prompt, image) pair.
    while True:
        conv = conversation_lib.conv_templates[args.conv_type].copy()
        conv.messages = []

        prompt = input("Please input your prompt: ")
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
        if args.use_mm_start_end:
            # Wrap the image placeholder with explicit start/end markers.
            replace_token = (
                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        # Empty assistant turn: the reply is generated autoregressively.
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        # cv2 loads BGR; convert to RGB for the processors.
        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        # CLIP branch preprocessing.
        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image_clip = image_clip.bfloat16()
        elif args.precision == "fp16":
            image_clip = image_clip.half()
        else:
            image_clip = image_clip.float()

        # SAM branch preprocessing (longest-side resize + pad to square).
        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]

        image = (
            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image = image.bfloat16()
        elif args.precision == "fp16":
            image = image.half()
        else:
            image = image.float()

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        # Autoregressive generation + mask decoding in one call.
        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        # Strip image placeholder ids before decoding the text reply.
        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace(" ", " ")
        print("text_output: ", text_output)

        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue

            # Binarize logits at 0.
            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0

            # Save the raw binary mask (scaled for visibility).
            save_path = "{}/{}_mask_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            cv2.imwrite(save_path, pred_mask * 100)
            print("{} has been saved.".format(save_path))

            # Save a red-tinted overlay of the mask on the original image.
            save_path = "{}/{}_masked_img_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            save_img = image_np.copy()
            save_img[pred_mask] = (
                image_np * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
            save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_path, save_img)
            print("{} has been saved.".format(save_path))
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
if __name__ == "__main__":
|
| 255 |
+
main(sys.argv[1:])
|
.ipynb_checkpoints/chat_prefill-checkpoint.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Interactive affordance mask generation using prefill mode (single forward pass).
|
| 3 |
+
|
| 4 |
+
Same interactive workflow as chat.py, but uses prefill inference instead of
|
| 5 |
+
autoregressive generation. The assistant response "[AFF]." is pre-filled in the
|
| 6 |
+
prompt, so the model only does one forward pass to extract mask embeddings.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
| 18 |
+
|
| 19 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 20 |
+
from model.llava import conversation as conversation_lib
|
| 21 |
+
from model.llava.mm_utils import tokenizer_image_token
|
| 22 |
+
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 23 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 24 |
+
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def parse_args(args):
    """Parse command-line options for the prefill-mode chat demo.

    Returns an argparse.Namespace; --prompt_template must contain a '{}'
    placeholder that the language instruction is substituted into.
    """
    ap = argparse.ArgumentParser(description="AffordanceVLM chat (prefill mode)")
    ap.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    ap.add_argument("--vis_save_path", default="./vis_output_prefill", type=str)
    ap.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
    )
    ap.add_argument("--image_size", default=1024, type=int)
    ap.add_argument("--model_max_length", default=512, type=int)
    ap.add_argument("--lora_r", default=8, type=int)
    ap.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    ap.add_argument("--local-rank", default=0, type=int)
    ap.add_argument("--load_in_8bit", action="store_true", default=False)
    ap.add_argument("--load_in_4bit", action="store_true", default=False)
    ap.add_argument("--use_mm_start_end", action="store_true", default=True)
    ap.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    ap.add_argument(
        "--prompt_template",
        type=str,
        default="Segment the most suitable manipulation region on the single target object for the task '{}'.",
        help="Template wrapping language_instruction. Use {} as placeholder.",
    )
    # Alternative templates tried during development:
    # Segment the most suitable manipulation region on the single target object for the task '{}'.
    # Segment the affordance map for the task '{}' in this image.
    # Segment the affordance map of the single target object for the task '{}' in this image.
    # Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask.
    # Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask.
    return ap.parse_args(args)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values, then pad bottom/right to an img_size square."""
    out = (x - pixel_mean) / pixel_std
    h, w = out.shape[-2:]
    # F.pad's last-dim-first ordering: (left, right, top, bottom).
    out = F.pad(out, (0, img_size - w, 0, img_size - h))
    return out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def main(args):
|
| 74 |
+
args = parse_args(args)
|
| 75 |
+
os.makedirs(args.vis_save_path, exist_ok=True)
|
| 76 |
+
|
| 77 |
+
# Create model
|
| 78 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 79 |
+
args.version,
|
| 80 |
+
cache_dir=None,
|
| 81 |
+
model_max_length=args.model_max_length,
|
| 82 |
+
padding_side="right",
|
| 83 |
+
use_fast=False,
|
| 84 |
+
)
|
| 85 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 86 |
+
tokenizer.add_tokens("[SEG]")
|
| 87 |
+
args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 88 |
+
tokenizer.add_tokens("[AFF]")
|
| 89 |
+
args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]
|
| 90 |
+
|
| 91 |
+
torch_dtype = torch.float32
|
| 92 |
+
if args.precision == "bf16":
|
| 93 |
+
torch_dtype = torch.bfloat16
|
| 94 |
+
elif args.precision == "fp16":
|
| 95 |
+
torch_dtype = torch.half
|
| 96 |
+
|
| 97 |
+
kwargs = {"torch_dtype": torch_dtype}
|
| 98 |
+
if args.load_in_4bit:
|
| 99 |
+
kwargs.update({
|
| 100 |
+
"torch_dtype": torch.half,
|
| 101 |
+
"load_in_4bit": True,
|
| 102 |
+
"quantization_config": BitsAndBytesConfig(
|
| 103 |
+
load_in_4bit=True,
|
| 104 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 105 |
+
bnb_4bit_use_double_quant=True,
|
| 106 |
+
bnb_4bit_quant_type="nf4",
|
| 107 |
+
llm_int8_skip_modules=["visual_model"],
|
| 108 |
+
),
|
| 109 |
+
})
|
| 110 |
+
elif args.load_in_8bit:
|
| 111 |
+
kwargs.update({
|
| 112 |
+
"torch_dtype": torch.half,
|
| 113 |
+
"quantization_config": BitsAndBytesConfig(
|
| 114 |
+
llm_int8_skip_modules=["visual_model"],
|
| 115 |
+
load_in_8bit=True,
|
| 116 |
+
),
|
| 117 |
+
})
|
| 118 |
+
|
| 119 |
+
model = AffordanceVLMForCausalLM.from_pretrained(
|
| 120 |
+
args.version,
|
| 121 |
+
low_cpu_mem_usage=True,
|
| 122 |
+
vision_tower=args.vision_tower,
|
| 123 |
+
seg_token_idx=args.seg_token_idx,
|
| 124 |
+
aff_token_idx=args.aff_token_idx,
|
| 125 |
+
**kwargs,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
model.config.eos_token_id = tokenizer.eos_token_id
|
| 129 |
+
model.config.bos_token_id = tokenizer.bos_token_id
|
| 130 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
| 131 |
+
|
| 132 |
+
model.get_model().initialize_vision_modules(model.get_model().config)
|
| 133 |
+
vision_tower = model.get_model().get_vision_tower()
|
| 134 |
+
vision_tower.to(dtype=torch_dtype)
|
| 135 |
+
|
| 136 |
+
if args.precision == "bf16":
|
| 137 |
+
model = model.bfloat16().cuda()
|
| 138 |
+
elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
|
| 139 |
+
vision_tower = model.get_model().get_vision_tower()
|
| 140 |
+
model.model.vision_tower = None
|
| 141 |
+
import deepspeed
|
| 142 |
+
model_engine = deepspeed.init_inference(
|
| 143 |
+
model=model,
|
| 144 |
+
dtype=torch.half,
|
| 145 |
+
replace_with_kernel_inject=True,
|
| 146 |
+
replace_method="auto",
|
| 147 |
+
)
|
| 148 |
+
model = model_engine.module
|
| 149 |
+
model.model.vision_tower = vision_tower.half().cuda()
|
| 150 |
+
elif args.precision == "fp32":
|
| 151 |
+
model = model.float().cuda()
|
| 152 |
+
|
| 153 |
+
vision_tower = model.get_model().get_vision_tower()
|
| 154 |
+
vision_tower.to(device=args.local_rank)
|
| 155 |
+
|
| 156 |
+
clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
|
| 157 |
+
transform = ResizeLongestSide(args.image_size)
|
| 158 |
+
|
| 159 |
+
model.eval()
|
| 160 |
+
|
| 161 |
+
# debug
|
| 162 |
+
template = "Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask."
|
| 163 |
+
|
| 164 |
+
while True:
|
| 165 |
+
conv = conversation_lib.conv_templates[args.conv_type].copy()
|
| 166 |
+
conv.messages = []
|
| 167 |
+
|
| 168 |
+
prompt = input("Please input your prompt: ")
|
| 169 |
+
# 加入模版
|
| 170 |
+
prompt = args.prompt_template.format(prompt)
|
| 171 |
+
|
| 172 |
+
prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
|
| 173 |
+
if args.use_mm_start_end:
|
| 174 |
+
replace_token = (
|
| 175 |
+
DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
| 176 |
+
)
|
| 177 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
| 178 |
+
|
| 179 |
+
conv.append_message(conv.roles[0], prompt)
|
| 180 |
+
conv.append_message(conv.roles[1], "[AFF].")
|
| 181 |
+
prompt = conv.get_prompt()
|
| 182 |
+
|
| 183 |
+
image_path = input("Please input the image path: ")
|
| 184 |
+
if not os.path.exists(image_path):
|
| 185 |
+
print("File not found in {}".format(image_path))
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
image_np = cv2.imread(image_path)
|
| 189 |
+
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
|
| 190 |
+
original_size_list = [image_np.shape[:2]]
|
| 191 |
+
h, w = original_size_list[0]
|
| 192 |
+
|
| 193 |
+
image_clip = (
|
| 194 |
+
clip_image_processor.preprocess(image_np, return_tensors="pt")[
|
| 195 |
+
"pixel_values"
|
| 196 |
+
][0]
|
| 197 |
+
.unsqueeze(0)
|
| 198 |
+
.cuda()
|
| 199 |
+
)
|
| 200 |
+
if args.precision == "bf16":
|
| 201 |
+
image_clip = image_clip.bfloat16()
|
| 202 |
+
elif args.precision == "fp16":
|
| 203 |
+
image_clip = image_clip.half()
|
| 204 |
+
else:
|
| 205 |
+
image_clip = image_clip.float()
|
| 206 |
+
|
| 207 |
+
image = transform.apply_image(image_np)
|
| 208 |
+
resize_list = [image.shape[:2]]
|
| 209 |
+
|
| 210 |
+
image = (
|
| 211 |
+
preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
|
| 212 |
+
.unsqueeze(0)
|
| 213 |
+
.cuda()
|
| 214 |
+
)
|
| 215 |
+
if args.precision == "bf16":
|
| 216 |
+
image = image.bfloat16()
|
| 217 |
+
elif args.precision == "fp16":
|
| 218 |
+
image = image.half()
|
| 219 |
+
else:
|
| 220 |
+
image = image.float()
|
| 221 |
+
|
| 222 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 223 |
+
input_ids = input_ids.unsqueeze(0).cuda()
|
| 224 |
+
attention_masks = input_ids.ne(tokenizer.pad_token_id)
|
| 225 |
+
|
| 226 |
+
# Print the full prompt text (prefill mode has no generated text)
|
| 227 |
+
# debug
|
| 228 |
+
text_ids = input_ids[0][input_ids[0] != IMAGE_TOKEN_INDEX]
|
| 229 |
+
text_output = tokenizer.decode(text_ids, skip_special_tokens=False)
|
| 230 |
+
text_output = text_output.replace("\n", "").replace(" ", " ")
|
| 231 |
+
print("text_output: ", text_output)
|
| 232 |
+
|
| 233 |
+
# Prefill inference
|
| 234 |
+
labels = input_ids.clone()
|
| 235 |
+
offset = torch.LongTensor([0, 1]).cuda()
|
| 236 |
+
masks_list = [torch.zeros(1, h, w).float().cuda()]
|
| 237 |
+
label_list = [torch.zeros(h, w).long().cuda()]
|
| 238 |
+
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
output_dict = model(
|
| 241 |
+
images=image,
|
| 242 |
+
images_clip=image_clip,
|
| 243 |
+
input_ids=input_ids,
|
| 244 |
+
labels=labels,
|
| 245 |
+
attention_masks=attention_masks,
|
| 246 |
+
offset=offset,
|
| 247 |
+
masks_list=masks_list,
|
| 248 |
+
label_list=label_list,
|
| 249 |
+
resize_list=resize_list,
|
| 250 |
+
inference=True,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
pred_masks = output_dict["pred_masks"]
|
| 254 |
+
|
| 255 |
+
for i, pred_mask in enumerate(pred_masks):
|
| 256 |
+
if pred_mask.shape[0] == 0:
|
| 257 |
+
continue
|
| 258 |
+
|
| 259 |
+
pred_mask = pred_mask.detach().cpu().numpy()[0]
|
| 260 |
+
pred_mask = pred_mask > 0
|
| 261 |
+
|
| 262 |
+
save_path = "{}/{}_mask_{}.jpg".format(
|
| 263 |
+
args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
|
| 264 |
+
)
|
| 265 |
+
cv2.imwrite(save_path, pred_mask * 100)
|
| 266 |
+
print("{} has been saved.".format(save_path))
|
| 267 |
+
|
| 268 |
+
save_path = "{}/{}_masked_img_{}.jpg".format(
|
| 269 |
+
args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
|
| 270 |
+
)
|
| 271 |
+
save_img = image_np.copy()
|
| 272 |
+
save_img[pred_mask] = (
|
| 273 |
+
image_np * 0.5
|
| 274 |
+
+ pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
|
| 275 |
+
)[pred_mask]
|
| 276 |
+
save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
|
| 277 |
+
cv2.imwrite(save_path, save_img)
|
| 278 |
+
print("{} has been saved.".format(save_path))
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# Script entry point: forward CLI arguments (minus the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
|
.ipynb_checkpoints/train_aff-checkpoint.py
ADDED
|
@@ -0,0 +1,620 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import sys
|
| 5 |
+
import time
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import deepspeed
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import tqdm
|
| 12 |
+
import transformers
|
| 13 |
+
from peft import LoraConfig, get_peft_model
|
| 14 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 15 |
+
|
| 16 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 17 |
+
from model.llava import conversation as conversation_lib
|
| 18 |
+
from utils.dataset import HybridDataset, ValDataset, collate_fn
|
| 19 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 20 |
+
AverageMeter, ProgressMeter, Summary, dict_to_cuda,
|
| 21 |
+
intersectionAndUnionGPU)
|
| 22 |
+
|
| 23 |
+
from utils.aff_seg_dataset import AffValDataset
|
| 24 |
+
from utils.reason_aff_dataset import ReasonAffValDataset
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def parse_args(args):
    """Build and apply the command-line interface for LISA model training.

    Args:
        args: raw argument strings, e.g. ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` holding every training option.
    """
    ap = argparse.ArgumentParser(description="LISA Model Training")

    # --- model loading / precision ---
    ap.add_argument("--local_rank", default=0, type=int, help="node rank")
    ap.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    ap.add_argument("--vis_save_path", default="./vis_output", type=str)
    ap.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    ap.add_argument("--image_size", default=1024, type=int, help="image size")
    ap.add_argument("--model_max_length", default=512, type=int)
    ap.add_argument("--lora_r", default=8, type=int)
    # NOTE: dashes become underscores, so this is accessed as args.vision_tower.
    ap.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    ap.add_argument("--load_in_8bit", action="store_true", default=False)
    ap.add_argument("--load_in_4bit", action="store_true", default=False)

    # --- datasets and sampling ---
    ap.add_argument(
        "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
    )
    ap.add_argument("--sample_rates", default="9,3,3,1", type=str)
    ap.add_argument(
        "--sem_seg_data",
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
        type=str,
    )
    ap.add_argument(
        "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
    )
    ap.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    ap.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    ap.add_argument("--aff_seg_data", default="handal", type=str)
    ap.add_argument("--aff_sample_rates", default="1", type=str)
    ap.add_argument("--reason_aff_data", default="handal_hard_reasoning", type=str)
    ap.add_argument("--reason_aff_sample_rates", default="1", type=str)
    ap.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    ap.add_argument("--dataset_dir", default="./dataset", type=str)

    # --- logging / schedule ---
    ap.add_argument("--log_base_dir", default="./runs", type=str)
    ap.add_argument("--exp_name", default="lisa", type=str)
    ap.add_argument("--epochs", default=10, type=int)
    ap.add_argument("--steps_per_epoch", default=500, type=int)
    ap.add_argument(
        "--batch_size", default=2, type=int, help="batch size per device per step"
    )
    ap.add_argument("--grad_accumulation_steps", default=10, type=int)
    ap.add_argument("--val_batch_size", default=1, type=int)
    ap.add_argument("--workers", default=4, type=int)

    # --- optimization / loss weights ---
    ap.add_argument("--lr", default=0.0003, type=float)
    ap.add_argument("--ce_loss_weight", default=1.0, type=float)
    ap.add_argument("--dice_loss_weight", default=0.5, type=float)
    ap.add_argument("--bce_loss_weight", default=2.0, type=float)
    ap.add_argument("--lora_alpha", default=16, type=int)
    ap.add_argument("--lora_dropout", default=0.05, type=float)
    ap.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    ap.add_argument("--explanatory", default=0.1, type=float)
    ap.add_argument("--beta1", default=0.9, type=float)
    ap.add_argument("--beta2", default=0.95, type=float)
    ap.add_argument("--num_classes_per_sample", default=3, type=int)

    # --- evaluation switches ---
    ap.add_argument("--exclude_val", action="store_true", default=False)
    ap.add_argument("--no_eval", action="store_true", default=False)
    ap.add_argument("--eval_only", action="store_true", default=False)
    ap.add_argument("--eval_affordance", action="store_true", default=False)
    ap.add_argument("--eval_reason_aff", action="store_true", default=False)

    # --- checkpointing / misc ---
    ap.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    ap.add_argument("--out_dim", default=256, type=int)
    ap.add_argument("--resume", default="", type=str)
    ap.add_argument("--print_freq", default=1, type=int)
    ap.add_argument("--start_epoch", default=0, type=int)
    # NOTE(review): store_true combined with default=True means the next four
    # flags are permanently on and cannot be disabled from the CLI — confirm
    # whether that is intended before changing it.
    ap.add_argument("--gradient_checkpointing", action="store_true", default=True)
    ap.add_argument("--train_mask_decoder", action="store_true", default=True)
    ap.add_argument("--use_mm_start_end", action="store_true", default=True)
    ap.add_argument("--auto_resume", action="store_true", default=True)
    ap.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )

    return ap.parse_args(args)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def main(args):
    """End-to-end training entry point: builds the tokenizer, model, LoRA
    adapters, datasets, and DeepSpeed engine, then runs the epoch loop with
    optional validation and best-checkpoint saving.

    Args:
        args: raw CLI argument strings; parsed by ``parse_args``.
    """
    args = parse_args(args)
    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    # Only rank 0 writes TensorBoard logs; other ranks get writer=None and
    # must not touch it (train() guards on args.local_rank == 0).
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
    else:
        writer = None

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the special segmentation / affordance tokens and record their
    # ids; the model uses these ids to locate where to emit masks.
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    num_added_tokens = tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    # Extra keyword arguments forwarded to the AffordanceVLM constructor.
    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "seg_token_idx": args.seg_token_idx,
        "aff_token_idx": args.aff_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
    }
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    # Required for gradient checkpointing to work with frozen embeddings.
    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype, device=args.local_rank)
    # LISA modules (SAM decoder etc.) are only initialized for training;
    # eval-only runs load them from the checkpoint instead.
    if not args.eval_only:
        model.get_model().initialize_lisa_modules(model.get_model().config)

    # Freeze the CLIP vision tower and multimodal projector.
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False

    conversation_lib.default_conversation = conversation_lib.conv_templates[
        args.conv_type
    ]

    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            """Collect names of nn.Linear layers eligible for LoRA, excluding
            the vision stack and the mask-decoder projection heads."""
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))

        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Must follow tokenizer.add_tokens() calls above so the embedding matrix
    # covers the new [SEG]/[AFF]/image-boundary tokens.
    model.resize_token_embeddings(len(tokenizer))

    # make text_hidden_fcs, mask_decoder, lm_head, embed_tokens trainable
    for n, p in model.named_parameters():
        if any(
            [
                x in n
                for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
            ]
        ):
            print("n: ", n, "p.shape: ", p.shape)
            p.requires_grad = True

    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1
    # samples_per_epoch is sized so each epoch covers exactly
    # steps_per_epoch optimizer steps across all GPUs.
    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        aff_seg_data=args.aff_seg_data,
        aff_sample_rate=[float(x) for x in args.aff_sample_rates.split(",")],
        reason_aff_data=args.reason_aff_data,
        reason_aff_sample_rate=[float(x) for x in args.reason_aff_sample_rates.split(",")],
        explanatory=args.explanatory,
    )

    # Pick the validation dataset variant matching the requested eval mode.
    if args.no_eval == False:
        if args.eval_affordance:
            val_dataset = AffValDataset(
                args.dataset_dir,
                tokenizer,
                args.vision_tower,
                args.val_dataset,
                args.image_size,
            )
        elif args.eval_reason_aff:
            val_dataset = ReasonAffValDataset(
                args.dataset_dir,
                tokenizer,
                args.vision_tower,
                args.val_dataset,
                args.image_size,
            )
        else:
            val_dataset = ValDataset(
                args.dataset_dir,
                tokenizer,
                args.vision_tower,
                args.val_dataset,
                args.image_size,
            )
        print(
            f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
        )
    else:
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")

    # DeepSpeed ZeRO-2 configuration; mixed precision follows args.precision.
    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
    # DeepSpeed builds the distributed train_loader from train_dataset.
    model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(
            collate_fn,
            tokenizer=tokenizer,
            conv_type=args.conv_type,
            use_mm_start_end=args.use_mm_start_end,
            local_rank=args.local_rank,
        ),
        config=ds_config,
    )

    # resume deepspeed checkpoint
    if args.auto_resume and len(args.resume) == 0:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume

    if args.resume:
        load_path, client_state = model_engine.load_checkpoint(args.resume)
        # The "latest" file contains the global_step tag of the checkpoint;
        # derive the epoch to restart from it.
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readlines()[0].strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        print(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )

    # validation dataset
    if val_dataset is not None:
        assert args.val_batch_size == 1
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=partial(
                collate_fn,
                tokenizer=tokenizer,
                conv_type=args.conv_type,
                use_mm_start_end=args.use_mm_start_end,
                local_rank=args.local_rank,
            ),
        )

    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0

    # Eval-only mode: run one validation pass, append metrics, and exit.
    if args.eval_only:
        giou, ciou = validate(val_loader, model_engine, 0, writer, args)
        if args.local_rank == 0:
            with open(os.path.join(args.version, "eval_result.txt"), "a") as f:
                f.write(f"dataset: {args.val_dataset}, giou: {giou}, ciou: {ciou} \n")
        exit()

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
        )

        if args.no_eval == False:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou

        # Save every epoch when eval is disabled, otherwise only on a new best
        # gIoU. (Short-circuit keeps is_best unevaluated when no_eval is set.)
        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
                # Remove the previous checkpoint dir before all ranks write.
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Run one epoch of training (``args.steps_per_epoch`` optimizer steps).

    Each step accumulates gradients over ``args.grad_accumulation_steps``
    micro-batches drawn from ``train_iter``, re-creating the iterator when
    the loader is exhausted.

    Args:
        train_loader: DeepSpeed-built training DataLoader.
        model: DeepSpeed model engine (provides .backward()/.step()).
        epoch: current epoch index (used only for log display).
        scheduler: LR scheduler attached to the engine.
        writer: TensorBoard SummaryWriter on rank 0, None elsewhere.
        train_iter: iterator over train_loader, carried across epochs.
        args: parsed CLI namespace.

    Returns:
        The (possibly re-created) training iterator, so the caller can keep
        consuming from it next epoch.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")

    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        for i in range(args.grad_accumulation_steps):
            # Fix: the original bare `except:` swallowed every exception
            # (CUDA OOMs, collate bugs, KeyboardInterrupt) and silently
            # restarted the loader; only loader exhaustion should do that.
            try:
                input_dict = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                input_dict = next(train_iter)

            data_time.update(time.time() - end)
            input_dict = dict_to_cuda(input_dict)

            # Cast image tensors to match the engine's compute precision.
            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()

            output_dict = model(**input_dict)

            loss = output_dict["loss"]
            ce_loss = output_dict["ce_loss"]
            mask_bce_loss = output_dict["mask_bce_loss"]
            mask_dice_loss = output_dict["mask_dice_loss"]
            mask_loss = output_dict["mask_loss"]

            losses.update(loss.item(), input_dict["images"].size(0))
            ce_losses.update(ce_loss.item(), input_dict["images"].size(0))
            mask_bce_losses.update(mask_bce_loss.item(), input_dict["images"].size(0))
            mask_dice_losses.update(mask_dice_loss.item(), input_dict["images"].size(0))
            mask_losses.update(mask_loss.item(), input_dict["images"].size(0))
            # DeepSpeed handles gradient accumulation internally; step() only
            # applies the optimizer every grad_accumulation_steps calls.
            model.backward(loss)
            model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if global_step % args.print_freq == 0:
            if args.distributed:
                batch_time.all_reduce()
                data_time.all_reduce()

                losses.all_reduce()
                ce_losses.all_reduce()
                mask_bce_losses.all_reduce()
                mask_dice_losses.all_reduce()
                mask_losses.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                # NOTE(review): scalars are logged against the within-epoch
                # global_step, so successive epochs overwrite the same x-axis
                # range in TensorBoard — confirm whether a running step
                # (epoch * steps_per_epoch + global_step) is wanted instead.
                writer.add_scalar("train/loss", losses.avg, global_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, global_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, global_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, global_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, global_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, global_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, global_step
                )

            batch_time.reset()
            data_time.reset()
            losses.reset()
            ce_losses.reset()
            mask_bce_losses.reset()
            mask_dice_losses.reset()
            mask_losses.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], global_step)

    return train_iter
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def validate(val_loader, model_engine, epoch, writer, args):
    """Evaluate segmentation quality on the validation set.

    Computes per-class IoU statistics with distributed all-reduce and
    reports two metrics for the foreground class (index 1):
      - gIoU: mean of per-sample IoUs (accumulated in acc_iou_meter)
      - cIoU: cumulative intersection / cumulative union

    Args:
        val_loader: validation DataLoader yielding dicts with "images" and
            "images_clip" tensors (plus whatever the model forward expects).
        model_engine: the (DeepSpeed-wrapped) model; called as model_engine(**input_dict).
        epoch: current epoch index, used as the TensorBoard step.
        writer: TensorBoard SummaryWriter (only rank 0 logs).
        args: namespace providing .precision, .distributed, .local_rank.

    Returns:
        (giou, ciou) for the foreground class.
    """
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    model_engine.eval()

    for input_dict in tqdm.tqdm(val_loader):
        torch.cuda.empty_cache()

        input_dict = dict_to_cuda(input_dict)
        # Cast image tensors to the inference precision.
        if args.precision == "fp16":
            input_dict["images"] = input_dict["images"].half()
            input_dict["images_clip"] = input_dict["images_clip"].half()
        elif args.precision == "bf16":
            input_dict["images"] = input_dict["images"].bfloat16()
            input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
        else:
            input_dict["images"] = input_dict["images"].float()
            input_dict["images_clip"] = input_dict["images_clip"].float()

        with torch.no_grad():
            output_dict = model_engine(**input_dict)

        pred_masks = output_dict["pred_masks"]
        masks_list = output_dict["gt_masks"][0].int()
        output_list = (pred_masks[0] > 0).int()
        # Validation is expected to run with batch size 1 per device.
        assert len(pred_masks) == 1

        intersection, union, acc_iou = 0.0, 0.0, 0.0
        for mask_i, output_i in zip(masks_list, output_list):
            intersection_i, union_i, _ = intersectionAndUnionGPU(
                output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
            )
            intersection += intersection_i
            union += union_i
            acc_iou += intersection_i / (union_i + 1e-5)
            acc_iou[union_i == 0] += 1.0  # no-object target
        intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
        acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]
        # Idiom fix: the original chained the three .update() calls in a
        # single comma expression (building a throwaway tuple for side
        # effects); explicit statements are clearer and equivalent.
        intersection_meter.update(intersection)
        union_meter.update(union)
        acc_iou_meter.update(acc_iou, n=masks_list.shape[0])

    # Aggregate partial sums across distributed workers before computing metrics.
    intersection_meter.all_reduce()
    union_meter.all_reduce()
    acc_iou_meter.all_reduce()

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    ciou = iou_class[1]
    giou = acc_iou_meter.avg[1]

    if args.local_rank == 0:
        writer.add_scalar("val/giou", giou, epoch)
        writer.add_scalar("val/ciou", ciou, epoch)
        print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

    return giou, ciou
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
if __name__ == "__main__":
    # Entry point: forward the CLI arguments (minus the program name) to main().
    main(sys.argv[1:])
|
README.md
CHANGED
|
@@ -1,3 +1,79 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h1>
|
| 3 |
+
<b>
|
| 4 |
+
RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping
|
| 5 |
+
</b>
|
| 6 |
+
</h1>
|
| 7 |
+
</div>
|
| 8 |
+
|
| 9 |
+
<div align="center">
|
| 10 |
+
|
| 11 |
+
| [**📑 Paper**](https://arxiv.org/abs/2507.23734) | [**🤗 Model**](https://huggingface.co/Dongming97/AffordanceVLM) | [**🤗 Dataset**](https://huggingface.co/datasets/Dongming97/RAGNet) | [**🖥️ Website**](https://wudongming97.github.io/RAGNet/) |
|
| 12 |
+
|
| 13 |
+
</div>
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
<p align="center"><img src="./imgs/AffordanceNet.jpg" width="800"/></p>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
> **[RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping](https://arxiv.org/abs/2507.23734)**
|
| 20 |
+
>
|
| 21 |
+
> Dongming Wu, Yanping Fu, Saike Huang, Yingfei Liu, Fan Jia, Nian Liu, Feng Dai, Tiancai Wang, Rao Muhammad Anwer, Fahad Shahbaz Khan, Jianbing Shen
|
| 22 |
+
|
| 23 |
+
## 📝 TL;DR
|
| 24 |
+
- To push forward general robotic grasping, we introduce a large-scale reasoning-based affordance segmentation benchmark, **RAGNet**. It contains 273k images, 180 categories, and 26k reasoning instructions.
|
| 25 |
+
- Furthermore, we propose a comprehensive affordance-based grasping framework, named AffordanceNet, which consists of a VLM (named AffordanceVLM) pre-trained on our massive affordance data and a grasping network that conditions an affordance map to grasp the target.
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## 📰 News
|
| 30 |
+
- [2025.08] Paper is released at [arXiv](https://arxiv.org/abs/2507.23734).
|
| 31 |
+
- [2025.07] Inference code and the [AffordanceVLM](https://huggingface.co/Dongming97/AffordanceVLM) model are released. Welcome to try it!
|
| 32 |
+
- [2025.06] Paper is accepted by ICCV 2025!
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## 🚀 Getting Started
|
| 37 |
+
|
| 38 |
+
* [Installation](docs/installation.md)
|
| 39 |
+
* [Download dataset](docs/dataset.md)
|
| 40 |
+
* [Training and evaluation](docs/training_and_evaluation.md)
|
| 41 |
+
* To deploy using Gradio, run the following command:
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
python app.py --version='./exps/AffordanceVLM-7B'
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
## 📊 Main Results
|
| 50 |
+
### 🔹 Affordance Segmentation
|
| 51 |
+
| Method | HANDAL gIoU | HANDAL cIoU | HANDAL† gIoU | HANDAL† cIoU | GraspNet seen gIoU | GraspNet seen cIoU | GraspNet novel gIoU | GraspNet novel cIoU | 3DOI gIoU | 3DOI cIoU |
|
| 52 |
+
|--------------------------------------|-------------|-------------|---------------|---------------|----------------------|----------------------|------------------------|------------------------|------------|------------|
|
| 53 |
+
| AffordanceNet | 60.3| 60.8 |60.5|60.3|63.3 |64.0| 45.6 |33.2 | 37.4| 37.4 |
|
| 54 |
+
|
| 55 |
+
### 🔸 Reasoning-Based Affordance Segmentation
|
| 56 |
+
|
| 57 |
+
| Method | HANDAL (easy) gIoU | HANDAL (easy) cIoU | HANDAL (hard) gIoU | HANDAL (hard) cIoU | 3DOI gIoU | 3DOI cIoU |
|
| 58 |
+
|---------|---------------------|---------------------|---------------------|---------------------|-----------|-----------|
|
| 59 |
+
| AffordanceNet| 58.3| 58.1 | 58.2| 57.8 | 38.1 | 39.4|
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
## 📚 Citation
|
| 63 |
+
If you find our work useful, please consider citing:
|
| 64 |
+
|
| 65 |
+
```bibtex
|
| 66 |
+
@inproceedings{wu2025ragnet,
|
| 67 |
+
title={RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping},
|
| 68 |
+
author={Wu, Dongming and Fu, Yanping and Huang, Saike and Liu, Yingfei and Jia, Fan and Liu, Nian and Dai, Feng and Wang, Tiancai and Anwer, Rao Muhammad and Khan, Fahad Shahbaz and others},
|
| 69 |
+
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
|
| 70 |
+
pages={11980--11990},
|
| 71 |
+
year={2025}
|
| 72 |
+
}
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## 🙏 Acknowledgements
|
| 76 |
+
We thank the authors who open-sourced the following projects.
|
| 77 |
+
- [LISA](https://github.com/dvlab-research/LISA)
|
| 78 |
+
- [LLaVA](https://github.com/haotian-liu/LLaVA)
|
| 79 |
+
- [SAM](https://github.com/facebookresearch/segment-anything)
|
app.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import bleach
|
| 7 |
+
import cv2
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
| 14 |
+
|
| 15 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 16 |
+
from model.llava import conversation as conversation_lib
|
| 17 |
+
from model.llava.mm_utils import tokenizer_image_token
|
| 18 |
+
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 19 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 20 |
+
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 21 |
+
|
| 22 |
+
from datetime import datetime
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_args(args):
    """Build and parse the command-line arguments for the AffordanceVLM demo."""
    arg_parser = argparse.ArgumentParser(description="AffordanceVLM chat")

    # Checkpoint and output locations.
    arg_parser.add_argument("--version", default="./exps/AffordanceVLM-7B")
    arg_parser.add_argument("--vis_save_path", default="./vis_output", type=str)

    # Numeric precision used for inference.
    arg_parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )

    # Model / tokenizer configuration.
    arg_parser.add_argument("--image_size", default=1024, type=int, help="image size")
    arg_parser.add_argument("--model_max_length", default=512, type=int)
    arg_parser.add_argument("--lora_r", default=8, type=int)
    arg_parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    arg_parser.add_argument("--local-rank", default=0, type=int, help="node rank")

    # Optional weight quantization.
    arg_parser.add_argument("--load_in_8bit", action="store_true", default=False)
    arg_parser.add_argument("--load_in_4bit", action="store_true", default=False)

    # Prompt formatting options.
    arg_parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    arg_parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )

    return arg_parser.parse_args(args)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input.

    The image tensor (C, H, W) is standardized per channel with the given
    mean/std, then zero-padded on the bottom and right so both spatial
    dimensions equal ``img_size``.
    """
    # Per-channel standardization.
    normalized = (x - pixel_mean) / pixel_std
    # Zero-pad bottom/right up to a square of side img_size.
    height, width = normalized.shape[-2:]
    pad_bottom = img_size - height
    pad_right = img_size - width
    return F.pad(normalized, (0, pad_right, 0, pad_bottom))
|
| 70 |
+
|
| 71 |
+
# Module-level setup: parse CLI args, load tokenizer + AffordanceVLM model,
# and prepare the image preprocessors used by the Gradio inference function.
args = parse_args(sys.argv[1:])
os.makedirs(args.vis_save_path, exist_ok=True)

# Create model
tokenizer = AutoTokenizer.from_pretrained(
    args.version,
    cache_dir=None,
    model_max_length=args.model_max_length,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
# Ids of the special tokens the model emits to trigger mask decoding.
# NOTE(review): no tokenizer.add_tokens("[SEG]"/"[AFF]") here (unlike the
# batch script) — presumably the checkpoint tokenizer already contains
# them; verify against added_tokens.json.
args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

# Select the torch dtype matching the requested precision.
torch_dtype = torch.float32
if args.precision == "bf16":
    torch_dtype = torch.bfloat16
elif args.precision == "fp16":
    torch_dtype = torch.half

# Optional 4-/8-bit quantization; both paths keep the SAM visual model in
# full precision via llm_int8_skip_modules.
kwargs = {"torch_dtype": torch_dtype}
if args.load_in_4bit:
    kwargs.update(
        {
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        }
    )
elif args.load_in_8bit:
    kwargs.update(
        {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        }
    )

model = AffordanceVLMForCausalLM.from_pretrained(
    args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, aff_token_idx=args.aff_token_idx, **kwargs
)

# Keep the model config's special-token ids in sync with the tokenizer.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch_dtype)

# Move the model to GPU in the chosen precision. The fp16 (non-quantized)
# path temporarily detaches the vision tower so DeepSpeed kernel injection
# only wraps the language model, then re-attaches it in half precision.
if args.precision == "bf16":
    model = model.bfloat16().cuda()
elif (
    args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
):
    vision_tower = model.get_model().get_vision_tower()
    model.model.vision_tower = None
    import deepspeed

    model_engine = deepspeed.init_inference(
        model=model,
        dtype=torch.half,
        replace_with_kernel_inject=True,
        replace_method="auto",
    )
    model = model_engine.module
    model.model.vision_tower = vision_tower.half().cuda()
elif args.precision == "fp32":
    model = model.float().cuda()

vision_tower = model.get_model().get_vision_tower()
vision_tower.to(device=args.local_rank)

# CLIP preprocessing for the VLM branch and SAM-style resize for the mask branch.
clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
transform = ResizeLongestSide(args.image_size)

model.eval()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# Gradio
# Static UI content for the demo page: example inputs, page title, usage
# notes, and footer links. These strings are passed verbatim to gr.Interface.
examples = [
    [
        "Please segment the affordance map of mug in this image.",
        "/data/AffordanceNet/vis_output/my_workspace.JPG",
    ],
]
output_labels = ["Segmentation Output"]

title = "RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping"

description = """
<font size=4>
This is the online demo of AffordanceVLM. \n
**Note**: **Different prompts can lead to significantly varied results**. \n
**Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n
**Note**: Current model is **AffordanceVLM-7B**. \n
**Usage**: <br>
&ensp;(1) To let AffordanceVLM **segment something**, input prompt like: "Can you segment the affordance map of xxx in this image?", "What is the affordance map of xxx in this image?"; <br>
</font>
"""

article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2507.23734' target='_blank'>
Preprint Paper
</a>
\n
<p style='text-align: center'>
<a href='https://github.com/wudongming97/AffordanceNet' target='_blank'> Github Repo </a></p>
"""
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def inference(input_str, input_image):
    """Run AffordanceVLM on one (instruction, image) pair for the Gradio demo.

    Args:
        input_str: Natural-language instruction typed by the user.
        input_image: Path to the uploaded image (gr.Image type="filepath").

    Returns:
        (output_image, output_str): an RGB numpy image with the predicted
        affordance region blended in red, and the model's text reply.
        On invalid input or no predicted mask, a placeholder image from
        ./resources/ is returned instead.
    """
    ## filter out special chars
    input_str = bleach.clean(input_str)

    print("input_str: ", input_str, "input_image: ", input_image)

    ## input valid check
    if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
        # BUGFIX: the original assigned a tuple ("[Error] ...", input_str)
        # due to a trailing comma; build a proper string instead.
        output_str = "[Error] Invalid input: " + input_str
        # output_image = np.zeros((128, 128, 3))
        ## error happened
        output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
        return output_image, output_str

    # Model Inference: build a llava-style conversation prompt around the
    # image token and the user instruction.
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []

    prompt = input_str
    prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
    if args.use_mm_start_end:
        replace_token = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], "")
    prompt = conv.get_prompt()

    image_np = cv2.imread(input_image)

    # save the input image (timestamped) for later inspection
    SAVE_DIR = "./gradio_images/"
    os.makedirs(SAVE_DIR, exist_ok=True)

    # generate a timestamped filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{timestamp}.png"
    save_path = os.path.join(SAVE_DIR, filename)

    # save the image
    cv2.imwrite(save_path, image_np)
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # CLIP branch input (batch of 1 on GPU).
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")[
            "pixel_values"
        ][0]
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image_clip = image_clip.bfloat16()
    elif args.precision == "fp16":
        image_clip = image_clip.half()
    else:
        image_clip = image_clip.float()

    # SAM branch input: resize longest side, normalize, pad to square.
    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]

    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image = image.bfloat16()
    elif args.precision == "fp16":
        image = image.half()
    else:
        image = image.float()

    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()

    output_ids, pred_masks = model.evaluate(
        image_clip,
        image,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=512,
        tokenizer=tokenizer,
    )
    # Strip image-token placeholders before decoding.
    output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

    text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
    text_output = text_output.replace("\n", "").replace("  ", " ")
    text_output = text_output.split("ASSISTANT: ")[-1].replace('</s>', '')

    print("text_output: ", text_output)
    save_img = None
    for i, pred_mask in enumerate(pred_masks):
        if pred_mask.shape[0] == 0:
            continue

        pred_mask = pred_mask.detach().cpu().numpy()[0]
        pred_mask = pred_mask > 0

        # Blend the predicted region with red at 50% opacity.
        save_img = image_np.copy()
        save_img[pred_mask] = (
            image_np * 0.5
            + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
        )[pred_mask]

    # BUGFIX: corrected user-facing typo "ASSITANT" -> "ASSISTANT".
    output_str = "ASSISTANT: " + text_output  # input_str
    if save_img is not None:
        output_image = save_img  # input_image
    else:
        ## no seg output
        output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
    return output_image, output_str
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# Wire the inference function into a Gradio interface: one text instruction
# plus one uploaded image in, the highlighted affordance image and the
# model's text reply out.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
        gr.Image(type="filepath", label="Input Image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Affordance Output"),
        gr.Textbox(lines=1, placeholder=None, label="Text Output"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="auto",
)

# Queue requests so concurrent users are served sequentially, then expose
# the demo on all interfaces at port 3200.
demo.queue()
# demo.launch()
demo.launch(server_name="0.0.0.0", server_port=3200)
|
batch_generate.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Batch generate affordance masks for all four LIBERO subsets sequentially.
# Each subset is read from ${SRC_ROOT}/<subset>_converted and the masks are
# written under ${TGT_ROOT}/<subset>. Runs on a single GPU (device 0).

SRC_ROOT="/gemini/space/wrz/libero_per_frame"
TGT_ROOT="/gemini/space/wrz/ragnet_results"

for ds in libero_object libero_goal libero_spatial libero_10; do
    echo "========== Processing ${ds} =========="
    # One python invocation per subset; subsets are processed in order.
    CUDA_VISIBLE_DEVICES=0 python batch_generate.py \
        --data_dir "${SRC_ROOT}/${ds}_converted" \
        --save_dir "${TGT_ROOT}/${ds}"
    echo "========== ${ds} done =========="
    echo
done
|
batch_generate_prefill_accelerate.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch affordance mask generation for per-step datasets.
|
| 3 |
+
|
| 4 |
+
Reads a per-step dataset (converted by convert_lerobot_to_perstep.py) and
|
| 5 |
+
generates affordance masks for every image_primary.jpg and image_wrist.jpg
|
| 6 |
+
using AffordanceVLM.
|
| 7 |
+
|
| 8 |
+
Input structure:
|
| 9 |
+
{data_dir}/
|
| 10 |
+
├── meta_info.h5
|
| 11 |
+
└── episodes/
|
| 12 |
+
└── {episode_id:06d}/
|
| 13 |
+
└── steps/
|
| 14 |
+
└── {step_id:04d}/
|
| 15 |
+
├── other.h5 # language_instruction
|
| 16 |
+
├── image_primary.jpg
|
| 17 |
+
└── image_wrist.jpg
|
| 18 |
+
|
| 19 |
+
Output structure:
|
| 20 |
+
{save_dir}/
|
| 21 |
+
└── episodes/
|
| 22 |
+
└── {episode_id:06d}/
|
| 23 |
+
└── steps/
|
| 24 |
+
└── {step_id:04d}/
|
| 25 |
+
├── image_primary_mask.png # binary 0/255
|
| 26 |
+
└── image_wrist_mask.png
|
| 27 |
+
|
| 28 |
+
Usage:
|
| 29 |
+
CUDA_VISIBLE_DEVICES=1 python batch_generate_prefill_accelerate.py \
|
| 30 |
+
--data_dir /gemini/space/wrz/libero_per_frame/libero_spatial_converted \
|
| 31 |
+
--save_dir /gemini/space/wrz/ragnet_results/libero_spatial
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import argparse
|
| 35 |
+
import os
|
| 36 |
+
import sys
|
| 37 |
+
from pathlib import Path
|
| 38 |
+
|
| 39 |
+
import cv2
|
| 40 |
+
import h5py
|
| 41 |
+
import numpy as np
|
| 42 |
+
import torch
|
| 43 |
+
import torch.nn.functional as F
|
| 44 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
| 45 |
+
|
| 46 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 47 |
+
from model.llava import conversation as conversation_lib
|
| 48 |
+
from model.llava.mm_utils import tokenizer_image_token
|
| 49 |
+
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 50 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 51 |
+
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def parse_args(args):
    """Parse command-line arguments for batch affordance-mask generation."""
    p = argparse.ArgumentParser(
        description="Batch affordance mask generation for per-step datasets"
    )

    # Model arguments (same as chat.py)
    p.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    p.add_argument(
        "--precision", default="bf16", type=str,
        choices=["fp32", "bf16", "fp16"],
    )
    p.add_argument("--image_size", default=1024, type=int)
    p.add_argument("--model_max_length", default=512, type=int)
    p.add_argument("--lora_r", default=8, type=int)
    p.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    p.add_argument("--local-rank", default=0, type=int)
    p.add_argument("--load_in_8bit", action="store_true", default=False)
    p.add_argument("--load_in_4bit", action="store_true", default=False)
    p.add_argument("--use_mm_start_end", action="store_true", default=True)
    p.add_argument(
        "--conv_type", default="llava_v1", type=str,
        choices=["llava_v1", "llava_llama_2"],
    )

    # Batch processing arguments
    p.add_argument("--data_dir", type=str, required=True,
                   help="Root of per-step dataset (contains episodes/)")
    p.add_argument("--save_dir", type=str, required=True,
                   help="Output directory for masks")
    # Templates tried during development (kept for reference):
    #   "Segment the most suitable manipulation region on the single target object for the task '{}'."
    #   "Segment the affordance map for the task '{}' in this image."
    #   "Segment the affordance map of the single target object for the task '{}' in this image."
    #   "Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask."
    #   "Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask."
    p.add_argument("--prompt_template", type=str,
                   default="{}",
                   help="Template wrapping language_instruction. Use {} as placeholder.")
    p.add_argument("--start_episode", type=int, default=None,
                   help="First episode index to process (inclusive)")
    p.add_argument("--end_episode", type=int, default=None,
                   help="Last episode index to process (exclusive)")

    return p.parse_args(args)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input."""
    standardized = (x - pixel_mean) / pixel_std
    h, w = standardized.shape[-2:]
    # Zero-pad only on the bottom/right so the top-left content is unchanged.
    return F.pad(standardized, (0, img_size - w, 0, img_size - h))
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def load_model(args):
    """Load tokenizer and model, identical to chat.py.

    Side effects: sets ``args.seg_token_idx`` and ``args.aff_token_idx`` from
    the freshly added special tokens, and moves the vision tower onto the
    device given by ``args.local_rank``.

    Returns:
        (model, tokenizer, clip_image_processor, transform) — the model is in
        eval mode and placed on GPU according to ``args.precision``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # The base tokenizer has no pad token; reuse UNK so padding is well-defined.
    tokenizer.pad_token = tokenizer.unk_token
    # Register the segmentation/affordance trigger tokens and remember their ids
    # so the model can locate them in the output sequence.
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    # Map the requested precision onto a torch dtype (default fp32).
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # NF4 double quantization; the SAM visual model is kept unquantized.
        kwargs.update({
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        })
    elif args.load_in_8bit:
        kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **kwargs,
    )

    # Keep model config's special-token ids in sync with the tokenizer.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
        # DeepSpeed kernel injection cannot handle the vision tower, so detach
        # it, wrap the LLM, then reattach the tower in fp16 on GPU.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed
        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()
    # NOTE(review): when 4-bit/8-bit quantization is enabled with fp16, no
    # explicit .cuda() happens here — presumably bitsandbytes handles device
    # placement; confirm before changing.

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()
    return model, tokenizer, clip_image_processor, transform
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def build_prompt(text: str, args) -> str:
    """Assemble the full conversation prompt for one text query.

    Prepends the image placeholder token and a fixed embodied-robot preamble,
    optionally wraps the image token with start/end markers, and pre-fills the
    assistant turn with "[AFF]." so a single forward pass suffices.
    """
    template = conversation_lib.conv_templates[args.conv_type].copy()
    template.messages = []

    user_turn = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + text
    if args.use_mm_start_end:
        wrapped_image = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        user_turn = user_turn.replace(DEFAULT_IMAGE_TOKEN, wrapped_image)

    template.append_message(template.roles[0], user_turn)
    template.append_message(template.roles[1], "[AFF].")
    return template.get_prompt()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def infer_single_image(
    image_path: str,
    prompt_str: str,
    model,
    tokenizer,
    clip_image_processor,
    transform,
    args,
) -> "np.ndarray | None":
    """Run inference on a single image. Returns binary mask (H, W) uint8 0/255 or None.

    Runs the model once in "prefill" mode (the [AFF] answer is already in
    ``prompt_str``), then unions all predicted masks. Returns None when the
    image cannot be read or no mask was produced.
    """
    image_np = cv2.imread(image_path)
    if image_np is None:
        print(f" [WARNING] Cannot read image: {image_path}")
        return None
    # cv2 loads BGR; the CLIP/SAM pipelines expect RGB.
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # CLIP preprocessing
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
        .unsqueeze(0)
        .cuda()
    )
    # Cast the CLIP input to the inference precision.
    if args.precision == "bf16":
        image_clip = image_clip.bfloat16()
    elif args.precision == "fp16":
        image_clip = image_clip.half()
    else:
        image_clip = image_clip.float()

    # SAM preprocessing
    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]
    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image = image.bfloat16()
    elif args.precision == "fp16":
        image = image.half()
    else:
        image = image.float()

    # Tokenize
    input_ids = tokenizer_image_token(prompt_str, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()
    attention_masks = input_ids.ne(tokenizer.pad_token_id)

    # Prefill inference (single forward pass instead of autoregressive generation)
    h, w = original_size_list[0]
    # labels/offset/masks_list/label_list are dummies required by the training
    # forward signature; only pred_masks from the output is used.
    labels = input_ids.clone()
    offset = torch.LongTensor([0, 1]).cuda()
    masks_list = [torch.zeros(1, h, w).float().cuda()]
    label_list = [torch.zeros(h, w).long().cuda()]

    with torch.no_grad():
        output_dict = model(
            images=image,
            images_clip=image_clip,
            input_ids=input_ids,
            labels=labels,
            attention_masks=attention_masks,
            offset=offset,
            masks_list=masks_list,
            label_list=label_list,
            resize_list=resize_list,
            inference=True,
        )

    pred_masks = output_dict["pred_masks"]

    # Merge all predicted masks via union (logical OR)
    merged = np.zeros((h, w), dtype=bool)
    has_mask = False
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue
        mask_np = pred_mask.detach().cpu().numpy()[0]  # (H, W)
        # Positive logits are foreground.
        merged |= (mask_np > 0)
        has_mask = True

    if not has_mask:
        return None

    return (merged.astype(np.uint8) * 255)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def read_language_instruction(h5_path: str) -> str:
    """Return the ``language_instruction`` field stored in an HDF5 file.

    The stored value may be raw bytes; in that case it is decoded as UTF-8.
    Any other stored type is coerced to ``str``.
    """
    with h5py.File(h5_path, "r") as handle:
        raw = handle["language_instruction"][()]
    if isinstance(raw, bytes):
        return raw.decode("utf-8")
    return str(raw)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def main(args):
    """Batch-generate affordance masks for every step of every episode.

    Walks ``data_dir/episodes/<episode>/steps/<step>``, builds a prompt from
    each step's language instruction, runs prefill inference on both camera
    images, and writes binary PNG masks under the mirrored ``save_dir`` tree.
    """
    args = parse_args(args)
    data_dir = Path(args.data_dir)
    save_dir = Path(args.save_dir)

    episodes_dir = data_dir / "episodes"
    if not episodes_dir.is_dir():
        print(f"Error: episodes directory not found at {episodes_dir}")
        sys.exit(1)

    # Collect and sort episode directories (lexicographic on zero-padded names).
    episode_dirs = sorted(
        [d for d in episodes_dir.iterdir() if d.is_dir()],
        key=lambda p: p.name,
    )

    # Filter by episode range [start, end); directory names are assumed to be
    # numeric (e.g. "000000") — int(d.name) would raise otherwise.
    if args.start_episode is not None or args.end_episode is not None:
        start = args.start_episode if args.start_episode is not None else 0
        end = args.end_episode if args.end_episode is not None else len(episode_dirs)
        episode_dirs = [
            d for d in episode_dirs
            if start <= int(d.name) < end
        ]

    print(f"Data dir : {data_dir}")
    print(f"Save dir : {save_dir}")
    print(f"Episodes : {len(episode_dirs)}")
    print(f"Prompt : {args.prompt_template}")
    print()

    # Load model
    print("Loading model...")
    model, tokenizer, clip_image_processor, transform = load_model(args)
    print("Model loaded.\n")

    total_steps = 0       # camera images processed (counted per image)
    empty_mask_count = 0  # images for which no mask was predicted

    for ep_dir in episode_dirs:
        episode_id = ep_dir.name  # e.g. "000000"
        steps_dir = ep_dir / "steps"
        if not steps_dir.is_dir():
            print(f" [WARNING] No steps/ in {ep_dir}, skipping.")
            continue

        step_dirs = sorted(
            [d for d in steps_dir.iterdir() if d.is_dir()],
            key=lambda p: p.name,
        )

        for step_dir in step_dirs:
            step_id = step_dir.name  # e.g. "0000"

            # Read language instruction
            other_h5 = step_dir / "other.h5"
            if not other_h5.exists():
                print(f" [WARNING] Missing other.h5 in {step_dir}, skipping.")
                continue
            language_instruction = read_language_instruction(str(other_h5))
            # debug
            # print(language_instruction)

            # Build prompt
            query_text = args.prompt_template.format(language_instruction)
            prompt_str = build_prompt(query_text, args)

            # Output directory (same structure as input: episodes/{episode_id}/steps/{step_id}/)
            out_dir = save_dir / "episodes" / episode_id / "steps" / step_id
            out_dir.mkdir(parents=True, exist_ok=True)

            # Process both cameras
            for cam_name in ("image_primary", "image_wrist"):
                img_path = step_dir / f"{cam_name}.jpg"
                mask_path = out_dir / f"{cam_name}_mask.png"

                if not img_path.exists():
                    print(f" [WARNING] Missing {img_path}, skipping.")
                    continue

                mask = infer_single_image(
                    str(img_path), prompt_str,
                    model, tokenizer, clip_image_processor, transform, args,
                )

                if mask is None:
                    # Save blank mask and warn (image is re-read only to get its size)
                    h, w = cv2.imread(str(img_path)).shape[:2]
                    mask = np.zeros((h, w), dtype=np.uint8)
                    empty_mask_count += 1

                cv2.imwrite(str(mask_path), mask)

                total_steps += 1
                if total_steps % 50 == 0:
                    print(f" Processed {total_steps} steps (episode {episode_id}, step {step_id})")

        print(f"Episode {episode_id} done ({len(step_dirs)} steps)")

    print(f"\nFinished. {total_steps} steps processed, {empty_mask_count} empty masks.")
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# Script entry point: forward CLI arguments (excluding the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
|
chat.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
| 10 |
+
|
| 11 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 12 |
+
from model.llava import conversation as conversation_lib
|
| 13 |
+
from model.llava.mm_utils import tokenizer_image_token
|
| 14 |
+
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 15 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 16 |
+
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parse_args(args):
    """Parse the chat CLI options and return the resulting namespace."""
    cli = argparse.ArgumentParser(description="LISA chat")

    # Model / checkpoint selection.
    cli.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    cli.add_argument("--vis_save_path", type=str, default="./vis_output")
    cli.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )

    # Input / architecture sizes.
    cli.add_argument("--image_size", type=int, default=1024, help="image size")
    cli.add_argument("--model_max_length", type=int, default=512)
    cli.add_argument("--lora_r", type=int, default=8)
    cli.add_argument("--vision-tower", type=str, default="openai/clip-vit-large-patch14")
    cli.add_argument("--local-rank", type=int, default=0, help="node rank")

    # Quantization and prompt-format toggles.
    cli.add_argument("--load_in_8bit", action="store_true", default=False)
    cli.add_argument("--load_in_4bit", action="store_true", default=False)
    cli.add_argument("--use_mm_start_end", action="store_true", default=True)
    cli.add_argument(
        "--conv_type",
        type=str,
        default="llava_v1",
        choices=["llava_v1", "llava_llama_2"],
    )

    return cli.parse_args(args)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize an image tensor and pad it to a square of side ``img_size``.

    Channel-wise standardization with the given mean/std, followed by
    zero-padding on the right and bottom edges only.
    """
    # Channel-wise standardization.
    result = (x - pixel_mean) / pixel_std

    # Amount of padding needed on each trailing dimension.
    rows, cols = result.shape[-2:]
    pad_spec = (0, img_size - cols, 0, img_size - rows)

    return F.pad(result, pad_spec)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def main(args):
    """Interactive chat loop: prompt + image path in, mask visualizations out.

    Loads the AffordanceVLM model once, then repeatedly reads a text prompt
    and an image path from stdin, runs autoregressive generation via
    ``model.evaluate``, and saves the predicted masks and overlaid images
    under ``args.vis_save_path``.
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    # The base tokenizer has no pad token; reuse UNK so padding is well-defined.
    tokenizer.pad_token = tokenizer.unk_token
    # Register the segmentation/affordance trigger tokens and record their ids.
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    num_added_tokens = tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    # Map the requested precision onto a torch dtype (default fp32).
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # NF4 double quantization; the SAM visual model stays unquantized.
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, aff_token_idx=args.aff_token_idx, **kwargs
    )

    # Keep the model config's special-token ids in sync with the tokenizer.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif (
        args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
    ):
        # DeepSpeed kernel injection cannot handle the vision tower: detach it,
        # wrap the LLM, then reattach the tower in fp16 on GPU.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()

    # Interactive REPL: one prompt + image per iteration, runs until killed.
    while True:
        conv = conversation_lib.conv_templates[args.conv_type].copy()
        conv.messages = []

        prompt = input("Please input your prompt: ")
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
        if args.use_mm_start_end:
            replace_token = (
                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

        # Empty assistant turn: the model generates the answer autoregressively.
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        # cv2 loads BGR; the CLIP/SAM pipelines expect RGB.
        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        # CLIP branch input.
        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image_clip = image_clip.bfloat16()
        elif args.precision == "fp16":
            image_clip = image_clip.half()
        else:
            image_clip = image_clip.float()

        # SAM branch input: resize longest side, normalize, pad to square.
        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]

        image = (
            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image = image.bfloat16()
        elif args.precision == "fp16":
            image = image.half()
        else:
            image = image.float()

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        # Drop image placeholder indices before decoding the text answer.
        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        text_output = text_output.replace("\n", "").replace("  ", " ")
        print("text_output: ", text_output)

        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue

            # Positive logits are foreground.
            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0

            # Save the raw binary mask (scaled to 100 for visibility).
            save_path = "{}/{}_mask_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            cv2.imwrite(save_path, pred_mask * 100)
            print("{} has been saved.".format(save_path))

            # Save a red-tinted overlay of the mask on the original image.
            save_path = "{}/{}_masked_img_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            save_img = image_np.copy()
            save_img[pred_mask] = (
                image_np * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
            save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_path, save_img)
            print("{} has been saved.".format(save_path))
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# Script entry point: forward CLI arguments (excluding the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
|
chat_prefill.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Interactive affordance mask generation using prefill mode (single forward pass).
|
| 3 |
+
|
| 4 |
+
Same interactive workflow as chat.py, but uses prefill inference instead of
|
| 5 |
+
autoregressive generation. The assistant response "[AFF]." is pre-filled in the
|
| 6 |
+
prompt, so the model only does one forward pass to extract mask embeddings.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
|
| 18 |
+
|
| 19 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 20 |
+
from model.llava import conversation as conversation_lib
|
| 21 |
+
from model.llava.mm_utils import tokenizer_image_token
|
| 22 |
+
from model.segment_anything.utils.transforms import ResizeLongestSide
|
| 23 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 24 |
+
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def parse_args(args):
    """Parse CLI options for the prefill-mode chat script."""
    cli = argparse.ArgumentParser(description="AffordanceVLM chat (prefill mode)")

    # Checkpoint and output locations.
    cli.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    cli.add_argument("--vis_save_path", type=str, default="./vis_output_prefill")

    # Numeric precision for inference.
    cli.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
    )

    # Input / architecture sizes.
    cli.add_argument("--image_size", type=int, default=1024)
    cli.add_argument("--model_max_length", type=int, default=512)
    cli.add_argument("--lora_r", type=int, default=8)
    cli.add_argument("--vision-tower", type=str, default="openai/clip-vit-large-patch14")
    cli.add_argument("--local-rank", type=int, default=0)

    # Quantization and prompt-format toggles.
    cli.add_argument("--load_in_8bit", action="store_true", default=False)
    cli.add_argument("--load_in_4bit", action="store_true", default=False)
    cli.add_argument("--use_mm_start_end", action="store_true", default=True)
    cli.add_argument(
        "--conv_type",
        type=str,
        default="llava_v1",
        choices=["llava_v1", "llava_llama_2"],
    )

    # Template applied to the user's language instruction. Alternative phrasings
    # that were experimented with:
    #   Segment the most suitable manipulation region on the single target object for the task '{}'.
    #   Segment the affordance map for the task '{}' in this image.
    #   Segment the affordance map of the single target object for the task '{}' in this image.
    #   Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask.
    #   Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask.
    cli.add_argument(
        "--prompt_template",
        type=str,
        default="Segment the most suitable manipulation region on the single target object for the task '{}'.",
        help="Template wrapping language_instruction. Use {} as placeholder.",
    )

    return cli.parse_args(args)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Whiten an image tensor and pad it into an img_size x img_size square.

    Subtracts the per-channel mean, divides by the per-channel std, and
    zero-pads the bottom/right edges so both spatial dimensions reach
    ``img_size``.
    """
    whitened = x.sub(pixel_mean).div(pixel_std)
    rows = whitened.shape[-2]
    cols = whitened.shape[-1]
    # Padding tuple is (left, right, top, bottom) for the last two dims.
    return F.pad(whitened, (0, img_size - cols, 0, img_size - rows))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def main(args):
|
| 74 |
+
args = parse_args(args)
|
| 75 |
+
os.makedirs(args.vis_save_path, exist_ok=True)
|
| 76 |
+
|
| 77 |
+
# Create model
|
| 78 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 79 |
+
args.version,
|
| 80 |
+
cache_dir=None,
|
| 81 |
+
model_max_length=args.model_max_length,
|
| 82 |
+
padding_side="right",
|
| 83 |
+
use_fast=False,
|
| 84 |
+
)
|
| 85 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 86 |
+
tokenizer.add_tokens("[SEG]")
|
| 87 |
+
args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 88 |
+
tokenizer.add_tokens("[AFF]")
|
| 89 |
+
args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]
|
| 90 |
+
|
| 91 |
+
torch_dtype = torch.float32
|
| 92 |
+
if args.precision == "bf16":
|
| 93 |
+
torch_dtype = torch.bfloat16
|
| 94 |
+
elif args.precision == "fp16":
|
| 95 |
+
torch_dtype = torch.half
|
| 96 |
+
|
| 97 |
+
kwargs = {"torch_dtype": torch_dtype}
|
| 98 |
+
if args.load_in_4bit:
|
| 99 |
+
kwargs.update({
|
| 100 |
+
"torch_dtype": torch.half,
|
| 101 |
+
"load_in_4bit": True,
|
| 102 |
+
"quantization_config": BitsAndBytesConfig(
|
| 103 |
+
load_in_4bit=True,
|
| 104 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 105 |
+
bnb_4bit_use_double_quant=True,
|
| 106 |
+
bnb_4bit_quant_type="nf4",
|
| 107 |
+
llm_int8_skip_modules=["visual_model"],
|
| 108 |
+
),
|
| 109 |
+
})
|
| 110 |
+
elif args.load_in_8bit:
|
| 111 |
+
kwargs.update({
|
| 112 |
+
"torch_dtype": torch.half,
|
| 113 |
+
"quantization_config": BitsAndBytesConfig(
|
| 114 |
+
llm_int8_skip_modules=["visual_model"],
|
| 115 |
+
load_in_8bit=True,
|
| 116 |
+
),
|
| 117 |
+
})
|
| 118 |
+
|
| 119 |
+
model = AffordanceVLMForCausalLM.from_pretrained(
|
| 120 |
+
args.version,
|
| 121 |
+
low_cpu_mem_usage=True,
|
| 122 |
+
vision_tower=args.vision_tower,
|
| 123 |
+
seg_token_idx=args.seg_token_idx,
|
| 124 |
+
aff_token_idx=args.aff_token_idx,
|
| 125 |
+
**kwargs,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
model.config.eos_token_id = tokenizer.eos_token_id
|
| 129 |
+
model.config.bos_token_id = tokenizer.bos_token_id
|
| 130 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
| 131 |
+
|
| 132 |
+
model.get_model().initialize_vision_modules(model.get_model().config)
|
| 133 |
+
vision_tower = model.get_model().get_vision_tower()
|
| 134 |
+
vision_tower.to(dtype=torch_dtype)
|
| 135 |
+
|
| 136 |
+
if args.precision == "bf16":
|
| 137 |
+
model = model.bfloat16().cuda()
|
| 138 |
+
elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
|
| 139 |
+
vision_tower = model.get_model().get_vision_tower()
|
| 140 |
+
model.model.vision_tower = None
|
| 141 |
+
import deepspeed
|
| 142 |
+
model_engine = deepspeed.init_inference(
|
| 143 |
+
model=model,
|
| 144 |
+
dtype=torch.half,
|
| 145 |
+
replace_with_kernel_inject=True,
|
| 146 |
+
replace_method="auto",
|
| 147 |
+
)
|
| 148 |
+
model = model_engine.module
|
| 149 |
+
model.model.vision_tower = vision_tower.half().cuda()
|
| 150 |
+
elif args.precision == "fp32":
|
| 151 |
+
model = model.float().cuda()
|
| 152 |
+
|
| 153 |
+
vision_tower = model.get_model().get_vision_tower()
|
| 154 |
+
vision_tower.to(device=args.local_rank)
|
| 155 |
+
|
| 156 |
+
clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
|
| 157 |
+
transform = ResizeLongestSide(args.image_size)
|
| 158 |
+
|
| 159 |
+
model.eval()
|
| 160 |
+
|
| 161 |
+
# debug
|
| 162 |
+
template = "Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask."
|
| 163 |
+
|
| 164 |
+
while True:
|
| 165 |
+
conv = conversation_lib.conv_templates[args.conv_type].copy()
|
| 166 |
+
conv.messages = []
|
| 167 |
+
|
| 168 |
+
prompt = input("Please input your prompt: ")
|
| 169 |
+
# 加入模版
|
| 170 |
+
prompt = args.prompt_template.format(prompt)
|
| 171 |
+
|
| 172 |
+
prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
|
| 173 |
+
if args.use_mm_start_end:
|
| 174 |
+
replace_token = (
|
| 175 |
+
DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
| 176 |
+
)
|
| 177 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
| 178 |
+
|
| 179 |
+
conv.append_message(conv.roles[0], prompt)
|
| 180 |
+
conv.append_message(conv.roles[1], "[AFF].")
|
| 181 |
+
prompt = conv.get_prompt()
|
| 182 |
+
|
| 183 |
+
image_path = input("Please input the image path: ")
|
| 184 |
+
if not os.path.exists(image_path):
|
| 185 |
+
print("File not found in {}".format(image_path))
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
image_np = cv2.imread(image_path)
|
| 189 |
+
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
|
| 190 |
+
original_size_list = [image_np.shape[:2]]
|
| 191 |
+
h, w = original_size_list[0]
|
| 192 |
+
|
| 193 |
+
image_clip = (
|
| 194 |
+
clip_image_processor.preprocess(image_np, return_tensors="pt")[
|
| 195 |
+
"pixel_values"
|
| 196 |
+
][0]
|
| 197 |
+
.unsqueeze(0)
|
| 198 |
+
.cuda()
|
| 199 |
+
)
|
| 200 |
+
if args.precision == "bf16":
|
| 201 |
+
image_clip = image_clip.bfloat16()
|
| 202 |
+
elif args.precision == "fp16":
|
| 203 |
+
image_clip = image_clip.half()
|
| 204 |
+
else:
|
| 205 |
+
image_clip = image_clip.float()
|
| 206 |
+
|
| 207 |
+
image = transform.apply_image(image_np)
|
| 208 |
+
resize_list = [image.shape[:2]]
|
| 209 |
+
|
| 210 |
+
image = (
|
| 211 |
+
preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
|
| 212 |
+
.unsqueeze(0)
|
| 213 |
+
.cuda()
|
| 214 |
+
)
|
| 215 |
+
if args.precision == "bf16":
|
| 216 |
+
image = image.bfloat16()
|
| 217 |
+
elif args.precision == "fp16":
|
| 218 |
+
image = image.half()
|
| 219 |
+
else:
|
| 220 |
+
image = image.float()
|
| 221 |
+
|
| 222 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
|
| 223 |
+
input_ids = input_ids.unsqueeze(0).cuda()
|
| 224 |
+
attention_masks = input_ids.ne(tokenizer.pad_token_id)
|
| 225 |
+
|
| 226 |
+
# Print the full prompt text (prefill mode has no generated text)
|
| 227 |
+
# debug
|
| 228 |
+
text_ids = input_ids[0][input_ids[0] != IMAGE_TOKEN_INDEX]
|
| 229 |
+
text_output = tokenizer.decode(text_ids, skip_special_tokens=False)
|
| 230 |
+
text_output = text_output.replace("\n", "").replace(" ", " ")
|
| 231 |
+
print("text_output: ", text_output)
|
| 232 |
+
|
| 233 |
+
# Prefill inference
|
| 234 |
+
labels = input_ids.clone()
|
| 235 |
+
offset = torch.LongTensor([0, 1]).cuda()
|
| 236 |
+
masks_list = [torch.zeros(1, h, w).float().cuda()]
|
| 237 |
+
label_list = [torch.zeros(h, w).long().cuda()]
|
| 238 |
+
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
output_dict = model(
|
| 241 |
+
images=image,
|
| 242 |
+
images_clip=image_clip,
|
| 243 |
+
input_ids=input_ids,
|
| 244 |
+
labels=labels,
|
| 245 |
+
attention_masks=attention_masks,
|
| 246 |
+
offset=offset,
|
| 247 |
+
masks_list=masks_list,
|
| 248 |
+
label_list=label_list,
|
| 249 |
+
resize_list=resize_list,
|
| 250 |
+
inference=True,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
pred_masks = output_dict["pred_masks"]
|
| 254 |
+
|
| 255 |
+
for i, pred_mask in enumerate(pred_masks):
|
| 256 |
+
if pred_mask.shape[0] == 0:
|
| 257 |
+
continue
|
| 258 |
+
|
| 259 |
+
pred_mask = pred_mask.detach().cpu().numpy()[0]
|
| 260 |
+
pred_mask = pred_mask > 0
|
| 261 |
+
|
| 262 |
+
save_path = "{}/{}_mask_{}.jpg".format(
|
| 263 |
+
args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
|
| 264 |
+
)
|
| 265 |
+
cv2.imwrite(save_path, pred_mask * 100)
|
| 266 |
+
print("{} has been saved.".format(save_path))
|
| 267 |
+
|
| 268 |
+
save_path = "{}/{}_masked_img_{}.jpg".format(
|
| 269 |
+
args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
|
| 270 |
+
)
|
| 271 |
+
save_img = image_np.copy()
|
| 272 |
+
save_img[pred_mask] = (
|
| 273 |
+
image_np * 0.5
|
| 274 |
+
+ pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
|
| 275 |
+
)[pred_mask]
|
| 276 |
+
save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
|
| 277 |
+
cv2.imwrite(save_path, save_img)
|
| 278 |
+
print("{} has been saved.".format(save_path))
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
if __name__ == "__main__":
|
| 282 |
+
main(sys.argv[1:])
|
ckpts/AffordanceVLM-7B/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
ckpts/AffordanceVLM-7B/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
---
|
ckpts/AffordanceVLM-7B/added_tokens.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<im_end>": 32002,
|
| 3 |
+
"<im_patch>": 32000,
|
| 4 |
+
"<im_start>": 32001,
|
| 5 |
+
"[AFF]": 32004,
|
| 6 |
+
"[SEG]": 32003
|
| 7 |
+
}
|
ckpts/AffordanceVLM-7B/config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "./LLaVA/LLaVA-Lightning-7B-v1-1",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"AffordanceVLMForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"bos_token_id": 1,
|
| 7 |
+
"eos_token_id": 2,
|
| 8 |
+
"freeze_mm_mlp_adapter": true,
|
| 9 |
+
"hidden_act": "silu",
|
| 10 |
+
"hidden_size": 4096,
|
| 11 |
+
"image_aspect_ratio": "square",
|
| 12 |
+
"image_grid_pinpoints": null,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"intermediate_size": 11008,
|
| 15 |
+
"max_position_embeddings": 2048,
|
| 16 |
+
"max_sequence_length": 2048,
|
| 17 |
+
"mm_hidden_size": 1024,
|
| 18 |
+
"mm_use_im_patch_token": false,
|
| 19 |
+
"mm_use_im_start_end": true,
|
| 20 |
+
"mm_vision_select_feature": "patch",
|
| 21 |
+
"mm_vision_select_layer": -2,
|
| 22 |
+
"mm_vision_tower": "openai/clip-vit-large-patch14",
|
| 23 |
+
"model_type": "llava",
|
| 24 |
+
"num_attention_heads": 32,
|
| 25 |
+
"num_hidden_layers": 32,
|
| 26 |
+
"num_key_value_heads": 32,
|
| 27 |
+
"out_dim": 256,
|
| 28 |
+
"pad_token_id": 0,
|
| 29 |
+
"pretrain_mm_mlp_adapter": null,
|
| 30 |
+
"pretraining_tp": 1,
|
| 31 |
+
"rms_norm_eps": 1e-06,
|
| 32 |
+
"rope_scaling": null,
|
| 33 |
+
"tie_word_embeddings": false,
|
| 34 |
+
"torch_dtype": "bfloat16",
|
| 35 |
+
"train_mask_decoder": true,
|
| 36 |
+
"transformers_version": "4.31.0",
|
| 37 |
+
"tune_mm_mlp_adapter": false,
|
| 38 |
+
"use_cache": false,
|
| 39 |
+
"use_mm_proj": true,
|
| 40 |
+
"vision_tower": "openai/clip-vit-large-patch14",
|
| 41 |
+
"vocab_size": 32005
|
| 42 |
+
}
|
ckpts/AffordanceVLM-7B/eval_result.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
dataset: handal_all, giou: 0.60872483253479, ciou: 0.6054294109344482
|
ckpts/AffordanceVLM-7B/generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 0,
|
| 4 |
+
"eos_token_id": 1,
|
| 5 |
+
"pad_token_id": 0,
|
| 6 |
+
"transformers_version": "4.31.0"
|
| 7 |
+
}
|
ckpts/AffordanceVLM-7B/pytorch_model-00001-of-00002.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:efdb3ff9accdd733412d083c770ba34ae1c6745b28e2bae07d3546dc9356bfec
|
| 3 |
+
size 9976675518
|
ckpts/AffordanceVLM-7B/pytorch_model-00002-of-00002.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7259eabdd3c03be21d45a328177ac3e46e1385cbc5ff2d757cd8bb70dec81ae9
|
| 3 |
+
size 6144654233
|
ckpts/AffordanceVLM-7B/pytorch_model.bin.index.json
ADDED
|
@@ -0,0 +1,930 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 16121002176
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"lm_head.weight": "pytorch_model-00002-of-00002.bin",
|
| 7 |
+
"model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
|
| 8 |
+
"model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 9 |
+
"model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 10 |
+
"model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 11 |
+
"model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 12 |
+
"model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 13 |
+
"model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 14 |
+
"model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 15 |
+
"model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 16 |
+
"model.layers.0.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 17 |
+
"model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 18 |
+
"model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 19 |
+
"model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 20 |
+
"model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 21 |
+
"model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 22 |
+
"model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 23 |
+
"model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 24 |
+
"model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 25 |
+
"model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 26 |
+
"model.layers.1.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 27 |
+
"model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 28 |
+
"model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 29 |
+
"model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 30 |
+
"model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 31 |
+
"model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 32 |
+
"model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 33 |
+
"model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 34 |
+
"model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 35 |
+
"model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 36 |
+
"model.layers.10.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 37 |
+
"model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 38 |
+
"model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 39 |
+
"model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 40 |
+
"model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 41 |
+
"model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 42 |
+
"model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 43 |
+
"model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 44 |
+
"model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 45 |
+
"model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 46 |
+
"model.layers.11.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 47 |
+
"model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 48 |
+
"model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 49 |
+
"model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 50 |
+
"model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 51 |
+
"model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 52 |
+
"model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 53 |
+
"model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 54 |
+
"model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 55 |
+
"model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 56 |
+
"model.layers.12.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 57 |
+
"model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 58 |
+
"model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 59 |
+
"model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 60 |
+
"model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 61 |
+
"model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 62 |
+
"model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 63 |
+
"model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 64 |
+
"model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 65 |
+
"model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 66 |
+
"model.layers.13.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 67 |
+
"model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 68 |
+
"model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 69 |
+
"model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 70 |
+
"model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 71 |
+
"model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 72 |
+
"model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 73 |
+
"model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 74 |
+
"model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 75 |
+
"model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 76 |
+
"model.layers.14.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 77 |
+
"model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 78 |
+
"model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 79 |
+
"model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 80 |
+
"model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 81 |
+
"model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 82 |
+
"model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 83 |
+
"model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 84 |
+
"model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 85 |
+
"model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 86 |
+
"model.layers.15.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 87 |
+
"model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 88 |
+
"model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 89 |
+
"model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 90 |
+
"model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 91 |
+
"model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 92 |
+
"model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 93 |
+
"model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 94 |
+
"model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 95 |
+
"model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 96 |
+
"model.layers.16.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 97 |
+
"model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 98 |
+
"model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 99 |
+
"model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 100 |
+
"model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 101 |
+
"model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 102 |
+
"model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 103 |
+
"model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 104 |
+
"model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 105 |
+
"model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 106 |
+
"model.layers.17.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 107 |
+
"model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 108 |
+
"model.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 109 |
+
"model.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 110 |
+
"model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 111 |
+
"model.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 112 |
+
"model.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 113 |
+
"model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 114 |
+
"model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 115 |
+
"model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 116 |
+
"model.layers.18.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 117 |
+
"model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 118 |
+
"model.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 119 |
+
"model.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 120 |
+
"model.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 121 |
+
"model.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 122 |
+
"model.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 123 |
+
"model.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 124 |
+
"model.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 125 |
+
"model.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 126 |
+
"model.layers.19.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 127 |
+
"model.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 128 |
+
"model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 129 |
+
"model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 130 |
+
"model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 131 |
+
"model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 132 |
+
"model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 133 |
+
"model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 134 |
+
"model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 135 |
+
"model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 136 |
+
"model.layers.2.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 137 |
+
"model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 138 |
+
"model.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 139 |
+
"model.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 140 |
+
"model.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 141 |
+
"model.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 142 |
+
"model.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 143 |
+
"model.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 144 |
+
"model.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 145 |
+
"model.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 146 |
+
"model.layers.20.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 147 |
+
"model.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 148 |
+
"model.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 149 |
+
"model.layers.21.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 150 |
+
"model.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 151 |
+
"model.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 152 |
+
"model.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 153 |
+
"model.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 154 |
+
"model.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 155 |
+
"model.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 156 |
+
"model.layers.21.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 157 |
+
"model.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 158 |
+
"model.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 159 |
+
"model.layers.22.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 160 |
+
"model.layers.22.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 161 |
+
"model.layers.22.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 162 |
+
"model.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 163 |
+
"model.layers.22.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 164 |
+
"model.layers.22.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 165 |
+
"model.layers.22.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 166 |
+
"model.layers.22.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 167 |
+
"model.layers.22.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 168 |
+
"model.layers.23.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 169 |
+
"model.layers.23.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 170 |
+
"model.layers.23.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 171 |
+
"model.layers.23.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 172 |
+
"model.layers.23.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 173 |
+
"model.layers.23.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 174 |
+
"model.layers.23.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 175 |
+
"model.layers.23.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 176 |
+
"model.layers.23.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 177 |
+
"model.layers.23.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 178 |
+
"model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 179 |
+
"model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 180 |
+
"model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 181 |
+
"model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 182 |
+
"model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 183 |
+
"model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 184 |
+
"model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 185 |
+
"model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 186 |
+
"model.layers.24.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
| 187 |
+
"model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 188 |
+
"model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 189 |
+
"model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 190 |
+
"model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 191 |
+
"model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 192 |
+
"model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 193 |
+
"model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 194 |
+
"model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 195 |
+
"model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 196 |
+
"model.layers.25.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
| 197 |
+
"model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 198 |
+
"model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 199 |
+
"model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 200 |
+
"model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 201 |
+
"model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 202 |
+
"model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 203 |
+
"model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 204 |
+
"model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 205 |
+
"model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 206 |
+
"model.layers.26.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
| 207 |
+
"model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 208 |
+
"model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 209 |
+
"model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 210 |
+
"model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 211 |
+
"model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 212 |
+
"model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 213 |
+
"model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 214 |
+
"model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 215 |
+
"model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 216 |
+
"model.layers.27.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
| 217 |
+
"model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 218 |
+
"model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 219 |
+
"model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 220 |
+
"model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 221 |
+
"model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 222 |
+
"model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 223 |
+
"model.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 224 |
+
"model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 225 |
+
"model.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 226 |
+
"model.layers.28.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
| 227 |
+
"model.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 228 |
+
"model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 229 |
+
"model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 230 |
+
"model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 231 |
+
"model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 232 |
+
"model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 233 |
+
"model.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 234 |
+
"model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 235 |
+
"model.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 236 |
+
"model.layers.29.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
| 237 |
+
"model.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 238 |
+
"model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 239 |
+
"model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 240 |
+
"model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 241 |
+
"model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 242 |
+
"model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 243 |
+
"model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 244 |
+
"model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 245 |
+
"model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 246 |
+
"model.layers.3.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 247 |
+
"model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 248 |
+
"model.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 249 |
+
"model.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 250 |
+
"model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 251 |
+
"model.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 252 |
+
"model.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 253 |
+
"model.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 254 |
+
"model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 255 |
+
"model.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 256 |
+
"model.layers.30.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
| 257 |
+
"model.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 258 |
+
"model.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 259 |
+
"model.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 260 |
+
"model.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 261 |
+
"model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 262 |
+
"model.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
|
| 263 |
+
"model.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 264 |
+
"model.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 265 |
+
"model.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 266 |
+
"model.layers.31.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
|
| 267 |
+
"model.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 268 |
+
"model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 269 |
+
"model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 270 |
+
"model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 271 |
+
"model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 272 |
+
"model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 273 |
+
"model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 274 |
+
"model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 275 |
+
"model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 276 |
+
"model.layers.4.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 277 |
+
"model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 278 |
+
"model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 279 |
+
"model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 280 |
+
"model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 281 |
+
"model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 282 |
+
"model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 283 |
+
"model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 284 |
+
"model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 285 |
+
"model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 286 |
+
"model.layers.5.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 287 |
+
"model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 288 |
+
"model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 289 |
+
"model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 290 |
+
"model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 291 |
+
"model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 292 |
+
"model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 293 |
+
"model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 294 |
+
"model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 295 |
+
"model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 296 |
+
"model.layers.6.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 297 |
+
"model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 298 |
+
"model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 299 |
+
"model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 300 |
+
"model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 301 |
+
"model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 302 |
+
"model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 303 |
+
"model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 304 |
+
"model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 305 |
+
"model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 306 |
+
"model.layers.7.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 307 |
+
"model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 308 |
+
"model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 309 |
+
"model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 310 |
+
"model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 311 |
+
"model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 312 |
+
"model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 313 |
+
"model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 314 |
+
"model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 315 |
+
"model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 316 |
+
"model.layers.8.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 317 |
+
"model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 318 |
+
"model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 319 |
+
"model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 320 |
+
"model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 321 |
+
"model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 322 |
+
"model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
|
| 323 |
+
"model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 324 |
+
"model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 325 |
+
"model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 326 |
+
"model.layers.9.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
|
| 327 |
+
"model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
|
| 328 |
+
"model.mm_projector.bias": "pytorch_model-00002-of-00002.bin",
|
| 329 |
+
"model.mm_projector.weight": "pytorch_model-00002-of-00002.bin",
|
| 330 |
+
"model.norm.weight": "pytorch_model-00002-of-00002.bin",
|
| 331 |
+
"model.text_hidden_fcs.0.0.bias": "pytorch_model-00002-of-00002.bin",
|
| 332 |
+
"model.text_hidden_fcs.0.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 333 |
+
"model.text_hidden_fcs.0.2.bias": "pytorch_model-00002-of-00002.bin",
|
| 334 |
+
"model.text_hidden_fcs.0.2.weight": "pytorch_model-00002-of-00002.bin",
|
| 335 |
+
"model.visual_model.image_encoder.blocks.0.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 336 |
+
"model.visual_model.image_encoder.blocks.0.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 337 |
+
"model.visual_model.image_encoder.blocks.0.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 338 |
+
"model.visual_model.image_encoder.blocks.0.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 339 |
+
"model.visual_model.image_encoder.blocks.0.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 340 |
+
"model.visual_model.image_encoder.blocks.0.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 341 |
+
"model.visual_model.image_encoder.blocks.0.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 342 |
+
"model.visual_model.image_encoder.blocks.0.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 343 |
+
"model.visual_model.image_encoder.blocks.0.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 344 |
+
"model.visual_model.image_encoder.blocks.0.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 345 |
+
"model.visual_model.image_encoder.blocks.0.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 346 |
+
"model.visual_model.image_encoder.blocks.0.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 347 |
+
"model.visual_model.image_encoder.blocks.0.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 348 |
+
"model.visual_model.image_encoder.blocks.0.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 349 |
+
"model.visual_model.image_encoder.blocks.1.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 350 |
+
"model.visual_model.image_encoder.blocks.1.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 351 |
+
"model.visual_model.image_encoder.blocks.1.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 352 |
+
"model.visual_model.image_encoder.blocks.1.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 353 |
+
"model.visual_model.image_encoder.blocks.1.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 354 |
+
"model.visual_model.image_encoder.blocks.1.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 355 |
+
"model.visual_model.image_encoder.blocks.1.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 356 |
+
"model.visual_model.image_encoder.blocks.1.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 357 |
+
"model.visual_model.image_encoder.blocks.1.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 358 |
+
"model.visual_model.image_encoder.blocks.1.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 359 |
+
"model.visual_model.image_encoder.blocks.1.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 360 |
+
"model.visual_model.image_encoder.blocks.1.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 361 |
+
"model.visual_model.image_encoder.blocks.1.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 362 |
+
"model.visual_model.image_encoder.blocks.1.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 363 |
+
"model.visual_model.image_encoder.blocks.10.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 364 |
+
"model.visual_model.image_encoder.blocks.10.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 365 |
+
"model.visual_model.image_encoder.blocks.10.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 366 |
+
"model.visual_model.image_encoder.blocks.10.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 367 |
+
"model.visual_model.image_encoder.blocks.10.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 368 |
+
"model.visual_model.image_encoder.blocks.10.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 369 |
+
"model.visual_model.image_encoder.blocks.10.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 370 |
+
"model.visual_model.image_encoder.blocks.10.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 371 |
+
"model.visual_model.image_encoder.blocks.10.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 372 |
+
"model.visual_model.image_encoder.blocks.10.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 373 |
+
"model.visual_model.image_encoder.blocks.10.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 374 |
+
"model.visual_model.image_encoder.blocks.10.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 375 |
+
"model.visual_model.image_encoder.blocks.10.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 376 |
+
"model.visual_model.image_encoder.blocks.10.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 377 |
+
"model.visual_model.image_encoder.blocks.11.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 378 |
+
"model.visual_model.image_encoder.blocks.11.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 379 |
+
"model.visual_model.image_encoder.blocks.11.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 380 |
+
"model.visual_model.image_encoder.blocks.11.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 381 |
+
"model.visual_model.image_encoder.blocks.11.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 382 |
+
"model.visual_model.image_encoder.blocks.11.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 383 |
+
"model.visual_model.image_encoder.blocks.11.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 384 |
+
"model.visual_model.image_encoder.blocks.11.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 385 |
+
"model.visual_model.image_encoder.blocks.11.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 386 |
+
"model.visual_model.image_encoder.blocks.11.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 387 |
+
"model.visual_model.image_encoder.blocks.11.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 388 |
+
"model.visual_model.image_encoder.blocks.11.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 389 |
+
"model.visual_model.image_encoder.blocks.11.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 390 |
+
"model.visual_model.image_encoder.blocks.11.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 391 |
+
"model.visual_model.image_encoder.blocks.12.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 392 |
+
"model.visual_model.image_encoder.blocks.12.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 393 |
+
"model.visual_model.image_encoder.blocks.12.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 394 |
+
"model.visual_model.image_encoder.blocks.12.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 395 |
+
"model.visual_model.image_encoder.blocks.12.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 396 |
+
"model.visual_model.image_encoder.blocks.12.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 397 |
+
"model.visual_model.image_encoder.blocks.12.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 398 |
+
"model.visual_model.image_encoder.blocks.12.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 399 |
+
"model.visual_model.image_encoder.blocks.12.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 400 |
+
"model.visual_model.image_encoder.blocks.12.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 401 |
+
"model.visual_model.image_encoder.blocks.12.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 402 |
+
"model.visual_model.image_encoder.blocks.12.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 403 |
+
"model.visual_model.image_encoder.blocks.12.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 404 |
+
"model.visual_model.image_encoder.blocks.12.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 405 |
+
"model.visual_model.image_encoder.blocks.13.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 406 |
+
"model.visual_model.image_encoder.blocks.13.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 407 |
+
"model.visual_model.image_encoder.blocks.13.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 408 |
+
"model.visual_model.image_encoder.blocks.13.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 409 |
+
"model.visual_model.image_encoder.blocks.13.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 410 |
+
"model.visual_model.image_encoder.blocks.13.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 411 |
+
"model.visual_model.image_encoder.blocks.13.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 412 |
+
"model.visual_model.image_encoder.blocks.13.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 413 |
+
"model.visual_model.image_encoder.blocks.13.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 414 |
+
"model.visual_model.image_encoder.blocks.13.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 415 |
+
"model.visual_model.image_encoder.blocks.13.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 416 |
+
"model.visual_model.image_encoder.blocks.13.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 417 |
+
"model.visual_model.image_encoder.blocks.13.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 418 |
+
"model.visual_model.image_encoder.blocks.13.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 419 |
+
"model.visual_model.image_encoder.blocks.14.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 420 |
+
"model.visual_model.image_encoder.blocks.14.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 421 |
+
"model.visual_model.image_encoder.blocks.14.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 422 |
+
"model.visual_model.image_encoder.blocks.14.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 423 |
+
"model.visual_model.image_encoder.blocks.14.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 424 |
+
"model.visual_model.image_encoder.blocks.14.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 425 |
+
"model.visual_model.image_encoder.blocks.14.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 426 |
+
"model.visual_model.image_encoder.blocks.14.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 427 |
+
"model.visual_model.image_encoder.blocks.14.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 428 |
+
"model.visual_model.image_encoder.blocks.14.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 429 |
+
"model.visual_model.image_encoder.blocks.14.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 430 |
+
"model.visual_model.image_encoder.blocks.14.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 431 |
+
"model.visual_model.image_encoder.blocks.14.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 432 |
+
"model.visual_model.image_encoder.blocks.14.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 433 |
+
"model.visual_model.image_encoder.blocks.15.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 434 |
+
"model.visual_model.image_encoder.blocks.15.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 435 |
+
"model.visual_model.image_encoder.blocks.15.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 436 |
+
"model.visual_model.image_encoder.blocks.15.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 437 |
+
"model.visual_model.image_encoder.blocks.15.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 438 |
+
"model.visual_model.image_encoder.blocks.15.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 439 |
+
"model.visual_model.image_encoder.blocks.15.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 440 |
+
"model.visual_model.image_encoder.blocks.15.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 441 |
+
"model.visual_model.image_encoder.blocks.15.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 442 |
+
"model.visual_model.image_encoder.blocks.15.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 443 |
+
"model.visual_model.image_encoder.blocks.15.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 444 |
+
"model.visual_model.image_encoder.blocks.15.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 445 |
+
"model.visual_model.image_encoder.blocks.15.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 446 |
+
"model.visual_model.image_encoder.blocks.15.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 447 |
+
"model.visual_model.image_encoder.blocks.16.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 448 |
+
"model.visual_model.image_encoder.blocks.16.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 449 |
+
"model.visual_model.image_encoder.blocks.16.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 450 |
+
"model.visual_model.image_encoder.blocks.16.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 451 |
+
"model.visual_model.image_encoder.blocks.16.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 452 |
+
"model.visual_model.image_encoder.blocks.16.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 453 |
+
"model.visual_model.image_encoder.blocks.16.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 454 |
+
"model.visual_model.image_encoder.blocks.16.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 455 |
+
"model.visual_model.image_encoder.blocks.16.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 456 |
+
"model.visual_model.image_encoder.blocks.16.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 457 |
+
"model.visual_model.image_encoder.blocks.16.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 458 |
+
"model.visual_model.image_encoder.blocks.16.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 459 |
+
"model.visual_model.image_encoder.blocks.16.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 460 |
+
"model.visual_model.image_encoder.blocks.16.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 461 |
+
"model.visual_model.image_encoder.blocks.17.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 462 |
+
"model.visual_model.image_encoder.blocks.17.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 463 |
+
"model.visual_model.image_encoder.blocks.17.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 464 |
+
"model.visual_model.image_encoder.blocks.17.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 465 |
+
"model.visual_model.image_encoder.blocks.17.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 466 |
+
"model.visual_model.image_encoder.blocks.17.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 467 |
+
"model.visual_model.image_encoder.blocks.17.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 468 |
+
"model.visual_model.image_encoder.blocks.17.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 469 |
+
"model.visual_model.image_encoder.blocks.17.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 470 |
+
"model.visual_model.image_encoder.blocks.17.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 471 |
+
"model.visual_model.image_encoder.blocks.17.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 472 |
+
"model.visual_model.image_encoder.blocks.17.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 473 |
+
"model.visual_model.image_encoder.blocks.17.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 474 |
+
"model.visual_model.image_encoder.blocks.17.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 475 |
+
"model.visual_model.image_encoder.blocks.18.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 476 |
+
"model.visual_model.image_encoder.blocks.18.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 477 |
+
"model.visual_model.image_encoder.blocks.18.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 478 |
+
"model.visual_model.image_encoder.blocks.18.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 479 |
+
"model.visual_model.image_encoder.blocks.18.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 480 |
+
"model.visual_model.image_encoder.blocks.18.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 481 |
+
"model.visual_model.image_encoder.blocks.18.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 482 |
+
"model.visual_model.image_encoder.blocks.18.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 483 |
+
"model.visual_model.image_encoder.blocks.18.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 484 |
+
"model.visual_model.image_encoder.blocks.18.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 485 |
+
"model.visual_model.image_encoder.blocks.18.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 486 |
+
"model.visual_model.image_encoder.blocks.18.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 487 |
+
"model.visual_model.image_encoder.blocks.18.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 488 |
+
"model.visual_model.image_encoder.blocks.18.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 489 |
+
"model.visual_model.image_encoder.blocks.19.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 490 |
+
"model.visual_model.image_encoder.blocks.19.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 491 |
+
"model.visual_model.image_encoder.blocks.19.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 492 |
+
"model.visual_model.image_encoder.blocks.19.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 493 |
+
"model.visual_model.image_encoder.blocks.19.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 494 |
+
"model.visual_model.image_encoder.blocks.19.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 495 |
+
"model.visual_model.image_encoder.blocks.19.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 496 |
+
"model.visual_model.image_encoder.blocks.19.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 497 |
+
"model.visual_model.image_encoder.blocks.19.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 498 |
+
"model.visual_model.image_encoder.blocks.19.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 499 |
+
"model.visual_model.image_encoder.blocks.19.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 500 |
+
"model.visual_model.image_encoder.blocks.19.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 501 |
+
"model.visual_model.image_encoder.blocks.19.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 502 |
+
"model.visual_model.image_encoder.blocks.19.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 503 |
+
"model.visual_model.image_encoder.blocks.2.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 504 |
+
"model.visual_model.image_encoder.blocks.2.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 505 |
+
"model.visual_model.image_encoder.blocks.2.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 506 |
+
"model.visual_model.image_encoder.blocks.2.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 507 |
+
"model.visual_model.image_encoder.blocks.2.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 508 |
+
"model.visual_model.image_encoder.blocks.2.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 509 |
+
"model.visual_model.image_encoder.blocks.2.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 510 |
+
"model.visual_model.image_encoder.blocks.2.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 511 |
+
"model.visual_model.image_encoder.blocks.2.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 512 |
+
"model.visual_model.image_encoder.blocks.2.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 513 |
+
"model.visual_model.image_encoder.blocks.2.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 514 |
+
"model.visual_model.image_encoder.blocks.2.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 515 |
+
"model.visual_model.image_encoder.blocks.2.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 516 |
+
"model.visual_model.image_encoder.blocks.2.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 517 |
+
"model.visual_model.image_encoder.blocks.20.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 518 |
+
"model.visual_model.image_encoder.blocks.20.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 519 |
+
"model.visual_model.image_encoder.blocks.20.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 520 |
+
"model.visual_model.image_encoder.blocks.20.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 521 |
+
"model.visual_model.image_encoder.blocks.20.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 522 |
+
"model.visual_model.image_encoder.blocks.20.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 523 |
+
"model.visual_model.image_encoder.blocks.20.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 524 |
+
"model.visual_model.image_encoder.blocks.20.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 525 |
+
"model.visual_model.image_encoder.blocks.20.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 526 |
+
"model.visual_model.image_encoder.blocks.20.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 527 |
+
"model.visual_model.image_encoder.blocks.20.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 528 |
+
"model.visual_model.image_encoder.blocks.20.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 529 |
+
"model.visual_model.image_encoder.blocks.20.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 530 |
+
"model.visual_model.image_encoder.blocks.20.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 531 |
+
"model.visual_model.image_encoder.blocks.21.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 532 |
+
"model.visual_model.image_encoder.blocks.21.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 533 |
+
"model.visual_model.image_encoder.blocks.21.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 534 |
+
"model.visual_model.image_encoder.blocks.21.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 535 |
+
"model.visual_model.image_encoder.blocks.21.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 536 |
+
"model.visual_model.image_encoder.blocks.21.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 537 |
+
"model.visual_model.image_encoder.blocks.21.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 538 |
+
"model.visual_model.image_encoder.blocks.21.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 539 |
+
"model.visual_model.image_encoder.blocks.21.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 540 |
+
"model.visual_model.image_encoder.blocks.21.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 541 |
+
"model.visual_model.image_encoder.blocks.21.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 542 |
+
"model.visual_model.image_encoder.blocks.21.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 543 |
+
"model.visual_model.image_encoder.blocks.21.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 544 |
+
"model.visual_model.image_encoder.blocks.21.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 545 |
+
"model.visual_model.image_encoder.blocks.22.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 546 |
+
"model.visual_model.image_encoder.blocks.22.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 547 |
+
"model.visual_model.image_encoder.blocks.22.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 548 |
+
"model.visual_model.image_encoder.blocks.22.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 549 |
+
"model.visual_model.image_encoder.blocks.22.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 550 |
+
"model.visual_model.image_encoder.blocks.22.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 551 |
+
"model.visual_model.image_encoder.blocks.22.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 552 |
+
"model.visual_model.image_encoder.blocks.22.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 553 |
+
"model.visual_model.image_encoder.blocks.22.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 554 |
+
"model.visual_model.image_encoder.blocks.22.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 555 |
+
"model.visual_model.image_encoder.blocks.22.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 556 |
+
"model.visual_model.image_encoder.blocks.22.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 557 |
+
"model.visual_model.image_encoder.blocks.22.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 558 |
+
"model.visual_model.image_encoder.blocks.22.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 559 |
+
"model.visual_model.image_encoder.blocks.23.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 560 |
+
"model.visual_model.image_encoder.blocks.23.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 561 |
+
"model.visual_model.image_encoder.blocks.23.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 562 |
+
"model.visual_model.image_encoder.blocks.23.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 563 |
+
"model.visual_model.image_encoder.blocks.23.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 564 |
+
"model.visual_model.image_encoder.blocks.23.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 565 |
+
"model.visual_model.image_encoder.blocks.23.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 566 |
+
"model.visual_model.image_encoder.blocks.23.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 567 |
+
"model.visual_model.image_encoder.blocks.23.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 568 |
+
"model.visual_model.image_encoder.blocks.23.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 569 |
+
"model.visual_model.image_encoder.blocks.23.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 570 |
+
"model.visual_model.image_encoder.blocks.23.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 571 |
+
"model.visual_model.image_encoder.blocks.23.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 572 |
+
"model.visual_model.image_encoder.blocks.23.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 573 |
+
"model.visual_model.image_encoder.blocks.24.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 574 |
+
"model.visual_model.image_encoder.blocks.24.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 575 |
+
"model.visual_model.image_encoder.blocks.24.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 576 |
+
"model.visual_model.image_encoder.blocks.24.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 577 |
+
"model.visual_model.image_encoder.blocks.24.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 578 |
+
"model.visual_model.image_encoder.blocks.24.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 579 |
+
"model.visual_model.image_encoder.blocks.24.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 580 |
+
"model.visual_model.image_encoder.blocks.24.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 581 |
+
"model.visual_model.image_encoder.blocks.24.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 582 |
+
"model.visual_model.image_encoder.blocks.24.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 583 |
+
"model.visual_model.image_encoder.blocks.24.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 584 |
+
"model.visual_model.image_encoder.blocks.24.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 585 |
+
"model.visual_model.image_encoder.blocks.24.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 586 |
+
"model.visual_model.image_encoder.blocks.24.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 587 |
+
"model.visual_model.image_encoder.blocks.25.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 588 |
+
"model.visual_model.image_encoder.blocks.25.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 589 |
+
"model.visual_model.image_encoder.blocks.25.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 590 |
+
"model.visual_model.image_encoder.blocks.25.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 591 |
+
"model.visual_model.image_encoder.blocks.25.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 592 |
+
"model.visual_model.image_encoder.blocks.25.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 593 |
+
"model.visual_model.image_encoder.blocks.25.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 594 |
+
"model.visual_model.image_encoder.blocks.25.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 595 |
+
"model.visual_model.image_encoder.blocks.25.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 596 |
+
"model.visual_model.image_encoder.blocks.25.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 597 |
+
"model.visual_model.image_encoder.blocks.25.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 598 |
+
"model.visual_model.image_encoder.blocks.25.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 599 |
+
"model.visual_model.image_encoder.blocks.25.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 600 |
+
"model.visual_model.image_encoder.blocks.25.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 601 |
+
"model.visual_model.image_encoder.blocks.26.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 602 |
+
"model.visual_model.image_encoder.blocks.26.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 603 |
+
"model.visual_model.image_encoder.blocks.26.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 604 |
+
"model.visual_model.image_encoder.blocks.26.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 605 |
+
"model.visual_model.image_encoder.blocks.26.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 606 |
+
"model.visual_model.image_encoder.blocks.26.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 607 |
+
"model.visual_model.image_encoder.blocks.26.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 608 |
+
"model.visual_model.image_encoder.blocks.26.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 609 |
+
"model.visual_model.image_encoder.blocks.26.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 610 |
+
"model.visual_model.image_encoder.blocks.26.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 611 |
+
"model.visual_model.image_encoder.blocks.26.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 612 |
+
"model.visual_model.image_encoder.blocks.26.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 613 |
+
"model.visual_model.image_encoder.blocks.26.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 614 |
+
"model.visual_model.image_encoder.blocks.26.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 615 |
+
"model.visual_model.image_encoder.blocks.27.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 616 |
+
"model.visual_model.image_encoder.blocks.27.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 617 |
+
"model.visual_model.image_encoder.blocks.27.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 618 |
+
"model.visual_model.image_encoder.blocks.27.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 619 |
+
"model.visual_model.image_encoder.blocks.27.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 620 |
+
"model.visual_model.image_encoder.blocks.27.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 621 |
+
"model.visual_model.image_encoder.blocks.27.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 622 |
+
"model.visual_model.image_encoder.blocks.27.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 623 |
+
"model.visual_model.image_encoder.blocks.27.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 624 |
+
"model.visual_model.image_encoder.blocks.27.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 625 |
+
"model.visual_model.image_encoder.blocks.27.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 626 |
+
"model.visual_model.image_encoder.blocks.27.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 627 |
+
"model.visual_model.image_encoder.blocks.27.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 628 |
+
"model.visual_model.image_encoder.blocks.27.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 629 |
+
"model.visual_model.image_encoder.blocks.28.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 630 |
+
"model.visual_model.image_encoder.blocks.28.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 631 |
+
"model.visual_model.image_encoder.blocks.28.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 632 |
+
"model.visual_model.image_encoder.blocks.28.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 633 |
+
"model.visual_model.image_encoder.blocks.28.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 634 |
+
"model.visual_model.image_encoder.blocks.28.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 635 |
+
"model.visual_model.image_encoder.blocks.28.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 636 |
+
"model.visual_model.image_encoder.blocks.28.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 637 |
+
"model.visual_model.image_encoder.blocks.28.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 638 |
+
"model.visual_model.image_encoder.blocks.28.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 639 |
+
"model.visual_model.image_encoder.blocks.28.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 640 |
+
"model.visual_model.image_encoder.blocks.28.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 641 |
+
"model.visual_model.image_encoder.blocks.28.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 642 |
+
"model.visual_model.image_encoder.blocks.28.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 643 |
+
"model.visual_model.image_encoder.blocks.29.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 644 |
+
"model.visual_model.image_encoder.blocks.29.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 645 |
+
"model.visual_model.image_encoder.blocks.29.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 646 |
+
"model.visual_model.image_encoder.blocks.29.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 647 |
+
"model.visual_model.image_encoder.blocks.29.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 648 |
+
"model.visual_model.image_encoder.blocks.29.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 649 |
+
"model.visual_model.image_encoder.blocks.29.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 650 |
+
"model.visual_model.image_encoder.blocks.29.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 651 |
+
"model.visual_model.image_encoder.blocks.29.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 652 |
+
"model.visual_model.image_encoder.blocks.29.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 653 |
+
"model.visual_model.image_encoder.blocks.29.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 654 |
+
"model.visual_model.image_encoder.blocks.29.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 655 |
+
"model.visual_model.image_encoder.blocks.29.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 656 |
+
"model.visual_model.image_encoder.blocks.29.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 657 |
+
"model.visual_model.image_encoder.blocks.3.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 658 |
+
"model.visual_model.image_encoder.blocks.3.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 659 |
+
"model.visual_model.image_encoder.blocks.3.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 660 |
+
"model.visual_model.image_encoder.blocks.3.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 661 |
+
"model.visual_model.image_encoder.blocks.3.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 662 |
+
"model.visual_model.image_encoder.blocks.3.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 663 |
+
"model.visual_model.image_encoder.blocks.3.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 664 |
+
"model.visual_model.image_encoder.blocks.3.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 665 |
+
"model.visual_model.image_encoder.blocks.3.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 666 |
+
"model.visual_model.image_encoder.blocks.3.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 667 |
+
"model.visual_model.image_encoder.blocks.3.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 668 |
+
"model.visual_model.image_encoder.blocks.3.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 669 |
+
"model.visual_model.image_encoder.blocks.3.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 670 |
+
"model.visual_model.image_encoder.blocks.3.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 671 |
+
"model.visual_model.image_encoder.blocks.30.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 672 |
+
"model.visual_model.image_encoder.blocks.30.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 673 |
+
"model.visual_model.image_encoder.blocks.30.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 674 |
+
"model.visual_model.image_encoder.blocks.30.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 675 |
+
"model.visual_model.image_encoder.blocks.30.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 676 |
+
"model.visual_model.image_encoder.blocks.30.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 677 |
+
"model.visual_model.image_encoder.blocks.30.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 678 |
+
"model.visual_model.image_encoder.blocks.30.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 679 |
+
"model.visual_model.image_encoder.blocks.30.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 680 |
+
"model.visual_model.image_encoder.blocks.30.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 681 |
+
"model.visual_model.image_encoder.blocks.30.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 682 |
+
"model.visual_model.image_encoder.blocks.30.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 683 |
+
"model.visual_model.image_encoder.blocks.30.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 684 |
+
"model.visual_model.image_encoder.blocks.30.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 685 |
+
"model.visual_model.image_encoder.blocks.31.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 686 |
+
"model.visual_model.image_encoder.blocks.31.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 687 |
+
"model.visual_model.image_encoder.blocks.31.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 688 |
+
"model.visual_model.image_encoder.blocks.31.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 689 |
+
"model.visual_model.image_encoder.blocks.31.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 690 |
+
"model.visual_model.image_encoder.blocks.31.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 691 |
+
"model.visual_model.image_encoder.blocks.31.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 692 |
+
"model.visual_model.image_encoder.blocks.31.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 693 |
+
"model.visual_model.image_encoder.blocks.31.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 694 |
+
"model.visual_model.image_encoder.blocks.31.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 695 |
+
"model.visual_model.image_encoder.blocks.31.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 696 |
+
"model.visual_model.image_encoder.blocks.31.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 697 |
+
"model.visual_model.image_encoder.blocks.31.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 698 |
+
"model.visual_model.image_encoder.blocks.31.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 699 |
+
"model.visual_model.image_encoder.blocks.4.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 700 |
+
"model.visual_model.image_encoder.blocks.4.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 701 |
+
"model.visual_model.image_encoder.blocks.4.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 702 |
+
"model.visual_model.image_encoder.blocks.4.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 703 |
+
"model.visual_model.image_encoder.blocks.4.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 704 |
+
"model.visual_model.image_encoder.blocks.4.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 705 |
+
"model.visual_model.image_encoder.blocks.4.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 706 |
+
"model.visual_model.image_encoder.blocks.4.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 707 |
+
"model.visual_model.image_encoder.blocks.4.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 708 |
+
"model.visual_model.image_encoder.blocks.4.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 709 |
+
"model.visual_model.image_encoder.blocks.4.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 710 |
+
"model.visual_model.image_encoder.blocks.4.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 711 |
+
"model.visual_model.image_encoder.blocks.4.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 712 |
+
"model.visual_model.image_encoder.blocks.4.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 713 |
+
"model.visual_model.image_encoder.blocks.5.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 714 |
+
"model.visual_model.image_encoder.blocks.5.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 715 |
+
"model.visual_model.image_encoder.blocks.5.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 716 |
+
"model.visual_model.image_encoder.blocks.5.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 717 |
+
"model.visual_model.image_encoder.blocks.5.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 718 |
+
"model.visual_model.image_encoder.blocks.5.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 719 |
+
"model.visual_model.image_encoder.blocks.5.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 720 |
+
"model.visual_model.image_encoder.blocks.5.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 721 |
+
"model.visual_model.image_encoder.blocks.5.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 722 |
+
"model.visual_model.image_encoder.blocks.5.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 723 |
+
"model.visual_model.image_encoder.blocks.5.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 724 |
+
"model.visual_model.image_encoder.blocks.5.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 725 |
+
"model.visual_model.image_encoder.blocks.5.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 726 |
+
"model.visual_model.image_encoder.blocks.5.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 727 |
+
"model.visual_model.image_encoder.blocks.6.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 728 |
+
"model.visual_model.image_encoder.blocks.6.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 729 |
+
"model.visual_model.image_encoder.blocks.6.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 730 |
+
"model.visual_model.image_encoder.blocks.6.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 731 |
+
"model.visual_model.image_encoder.blocks.6.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 732 |
+
"model.visual_model.image_encoder.blocks.6.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 733 |
+
"model.visual_model.image_encoder.blocks.6.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 734 |
+
"model.visual_model.image_encoder.blocks.6.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 735 |
+
"model.visual_model.image_encoder.blocks.6.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 736 |
+
"model.visual_model.image_encoder.blocks.6.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 737 |
+
"model.visual_model.image_encoder.blocks.6.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 738 |
+
"model.visual_model.image_encoder.blocks.6.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 739 |
+
"model.visual_model.image_encoder.blocks.6.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 740 |
+
"model.visual_model.image_encoder.blocks.6.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 741 |
+
"model.visual_model.image_encoder.blocks.7.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 742 |
+
"model.visual_model.image_encoder.blocks.7.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 743 |
+
"model.visual_model.image_encoder.blocks.7.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 744 |
+
"model.visual_model.image_encoder.blocks.7.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 745 |
+
"model.visual_model.image_encoder.blocks.7.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 746 |
+
"model.visual_model.image_encoder.blocks.7.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 747 |
+
"model.visual_model.image_encoder.blocks.7.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 748 |
+
"model.visual_model.image_encoder.blocks.7.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 749 |
+
"model.visual_model.image_encoder.blocks.7.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 750 |
+
"model.visual_model.image_encoder.blocks.7.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 751 |
+
"model.visual_model.image_encoder.blocks.7.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 752 |
+
"model.visual_model.image_encoder.blocks.7.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 753 |
+
"model.visual_model.image_encoder.blocks.7.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 754 |
+
"model.visual_model.image_encoder.blocks.7.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 755 |
+
"model.visual_model.image_encoder.blocks.8.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 756 |
+
"model.visual_model.image_encoder.blocks.8.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 757 |
+
"model.visual_model.image_encoder.blocks.8.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 758 |
+
"model.visual_model.image_encoder.blocks.8.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 759 |
+
"model.visual_model.image_encoder.blocks.8.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 760 |
+
"model.visual_model.image_encoder.blocks.8.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 761 |
+
"model.visual_model.image_encoder.blocks.8.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 762 |
+
"model.visual_model.image_encoder.blocks.8.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 763 |
+
"model.visual_model.image_encoder.blocks.8.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 764 |
+
"model.visual_model.image_encoder.blocks.8.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 765 |
+
"model.visual_model.image_encoder.blocks.8.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 766 |
+
"model.visual_model.image_encoder.blocks.8.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 767 |
+
"model.visual_model.image_encoder.blocks.8.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 768 |
+
"model.visual_model.image_encoder.blocks.8.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 769 |
+
"model.visual_model.image_encoder.blocks.9.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 770 |
+
"model.visual_model.image_encoder.blocks.9.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 771 |
+
"model.visual_model.image_encoder.blocks.9.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
|
| 772 |
+
"model.visual_model.image_encoder.blocks.9.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
|
| 773 |
+
"model.visual_model.image_encoder.blocks.9.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
|
| 774 |
+
"model.visual_model.image_encoder.blocks.9.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
|
| 775 |
+
"model.visual_model.image_encoder.blocks.9.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 776 |
+
"model.visual_model.image_encoder.blocks.9.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 777 |
+
"model.visual_model.image_encoder.blocks.9.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 778 |
+
"model.visual_model.image_encoder.blocks.9.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 779 |
+
"model.visual_model.image_encoder.blocks.9.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 780 |
+
"model.visual_model.image_encoder.blocks.9.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 781 |
+
"model.visual_model.image_encoder.blocks.9.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 782 |
+
"model.visual_model.image_encoder.blocks.9.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 783 |
+
"model.visual_model.image_encoder.neck.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 784 |
+
"model.visual_model.image_encoder.neck.1.bias": "pytorch_model-00002-of-00002.bin",
|
| 785 |
+
"model.visual_model.image_encoder.neck.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 786 |
+
"model.visual_model.image_encoder.neck.2.weight": "pytorch_model-00002-of-00002.bin",
|
| 787 |
+
"model.visual_model.image_encoder.neck.3.bias": "pytorch_model-00002-of-00002.bin",
|
| 788 |
+
"model.visual_model.image_encoder.neck.3.weight": "pytorch_model-00002-of-00002.bin",
|
| 789 |
+
"model.visual_model.image_encoder.patch_embed.proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 790 |
+
"model.visual_model.image_encoder.patch_embed.proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 791 |
+
"model.visual_model.image_encoder.pos_embed": "pytorch_model-00002-of-00002.bin",
|
| 792 |
+
"model.visual_model.mask_decoder.iou_prediction_head.layers.0.bias": "pytorch_model-00002-of-00002.bin",
|
| 793 |
+
"model.visual_model.mask_decoder.iou_prediction_head.layers.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 794 |
+
"model.visual_model.mask_decoder.iou_prediction_head.layers.1.bias": "pytorch_model-00002-of-00002.bin",
|
| 795 |
+
"model.visual_model.mask_decoder.iou_prediction_head.layers.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 796 |
+
"model.visual_model.mask_decoder.iou_prediction_head.layers.2.bias": "pytorch_model-00002-of-00002.bin",
|
| 797 |
+
"model.visual_model.mask_decoder.iou_prediction_head.layers.2.weight": "pytorch_model-00002-of-00002.bin",
|
| 798 |
+
"model.visual_model.mask_decoder.iou_token.weight": "pytorch_model-00002-of-00002.bin",
|
| 799 |
+
"model.visual_model.mask_decoder.mask_tokens.weight": "pytorch_model-00002-of-00002.bin",
|
| 800 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.0.bias": "pytorch_model-00002-of-00002.bin",
|
| 801 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 802 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.1.bias": "pytorch_model-00002-of-00002.bin",
|
| 803 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 804 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.2.bias": "pytorch_model-00002-of-00002.bin",
|
| 805 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.2.weight": "pytorch_model-00002-of-00002.bin",
|
| 806 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.0.bias": "pytorch_model-00002-of-00002.bin",
|
| 807 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 808 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.1.bias": "pytorch_model-00002-of-00002.bin",
|
| 809 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 810 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.2.bias": "pytorch_model-00002-of-00002.bin",
|
| 811 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.2.weight": "pytorch_model-00002-of-00002.bin",
|
| 812 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.0.bias": "pytorch_model-00002-of-00002.bin",
|
| 813 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 814 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.1.bias": "pytorch_model-00002-of-00002.bin",
|
| 815 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 816 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.2.bias": "pytorch_model-00002-of-00002.bin",
|
| 817 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.2.weight": "pytorch_model-00002-of-00002.bin",
|
| 818 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.0.bias": "pytorch_model-00002-of-00002.bin",
|
| 819 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 820 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.1.bias": "pytorch_model-00002-of-00002.bin",
|
| 821 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 822 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.2.bias": "pytorch_model-00002-of-00002.bin",
|
| 823 |
+
"model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.2.weight": "pytorch_model-00002-of-00002.bin",
|
| 824 |
+
"model.visual_model.mask_decoder.output_upscaling.0.bias": "pytorch_model-00002-of-00002.bin",
|
| 825 |
+
"model.visual_model.mask_decoder.output_upscaling.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 826 |
+
"model.visual_model.mask_decoder.output_upscaling.1.bias": "pytorch_model-00002-of-00002.bin",
|
| 827 |
+
"model.visual_model.mask_decoder.output_upscaling.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 828 |
+
"model.visual_model.mask_decoder.output_upscaling.3.bias": "pytorch_model-00002-of-00002.bin",
|
| 829 |
+
"model.visual_model.mask_decoder.output_upscaling.3.weight": "pytorch_model-00002-of-00002.bin",
|
| 830 |
+
"model.visual_model.mask_decoder.transformer.final_attn_token_to_image.k_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 831 |
+
"model.visual_model.mask_decoder.transformer.final_attn_token_to_image.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 832 |
+
"model.visual_model.mask_decoder.transformer.final_attn_token_to_image.out_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 833 |
+
"model.visual_model.mask_decoder.transformer.final_attn_token_to_image.out_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 834 |
+
"model.visual_model.mask_decoder.transformer.final_attn_token_to_image.q_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 835 |
+
"model.visual_model.mask_decoder.transformer.final_attn_token_to_image.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 836 |
+
"model.visual_model.mask_decoder.transformer.final_attn_token_to_image.v_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 837 |
+
"model.visual_model.mask_decoder.transformer.final_attn_token_to_image.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 838 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 839 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 840 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 841 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 842 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 843 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 844 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 845 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 846 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 847 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 848 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 849 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 850 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 851 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 852 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 853 |
+
"model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 854 |
+
"model.visual_model.mask_decoder.transformer.layers.0.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 855 |
+
"model.visual_model.mask_decoder.transformer.layers.0.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 856 |
+
"model.visual_model.mask_decoder.transformer.layers.0.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 857 |
+
"model.visual_model.mask_decoder.transformer.layers.0.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 858 |
+
"model.visual_model.mask_decoder.transformer.layers.0.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 859 |
+
"model.visual_model.mask_decoder.transformer.layers.0.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 860 |
+
"model.visual_model.mask_decoder.transformer.layers.0.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 861 |
+
"model.visual_model.mask_decoder.transformer.layers.0.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 862 |
+
"model.visual_model.mask_decoder.transformer.layers.0.norm3.bias": "pytorch_model-00002-of-00002.bin",
|
| 863 |
+
"model.visual_model.mask_decoder.transformer.layers.0.norm3.weight": "pytorch_model-00002-of-00002.bin",
|
| 864 |
+
"model.visual_model.mask_decoder.transformer.layers.0.norm4.bias": "pytorch_model-00002-of-00002.bin",
|
| 865 |
+
"model.visual_model.mask_decoder.transformer.layers.0.norm4.weight": "pytorch_model-00002-of-00002.bin",
|
| 866 |
+
"model.visual_model.mask_decoder.transformer.layers.0.self_attn.k_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 867 |
+
"model.visual_model.mask_decoder.transformer.layers.0.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 868 |
+
"model.visual_model.mask_decoder.transformer.layers.0.self_attn.out_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 869 |
+
"model.visual_model.mask_decoder.transformer.layers.0.self_attn.out_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 870 |
+
"model.visual_model.mask_decoder.transformer.layers.0.self_attn.q_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 871 |
+
"model.visual_model.mask_decoder.transformer.layers.0.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 872 |
+
"model.visual_model.mask_decoder.transformer.layers.0.self_attn.v_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 873 |
+
"model.visual_model.mask_decoder.transformer.layers.0.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 874 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 875 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 876 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 877 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 878 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 879 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 880 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 881 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 882 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 883 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 884 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 885 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 886 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 887 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 888 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 889 |
+
"model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 890 |
+
"model.visual_model.mask_decoder.transformer.layers.1.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
|
| 891 |
+
"model.visual_model.mask_decoder.transformer.layers.1.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
|
| 892 |
+
"model.visual_model.mask_decoder.transformer.layers.1.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
|
| 893 |
+
"model.visual_model.mask_decoder.transformer.layers.1.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
|
| 894 |
+
"model.visual_model.mask_decoder.transformer.layers.1.norm1.bias": "pytorch_model-00002-of-00002.bin",
|
| 895 |
+
"model.visual_model.mask_decoder.transformer.layers.1.norm1.weight": "pytorch_model-00002-of-00002.bin",
|
| 896 |
+
"model.visual_model.mask_decoder.transformer.layers.1.norm2.bias": "pytorch_model-00002-of-00002.bin",
|
| 897 |
+
"model.visual_model.mask_decoder.transformer.layers.1.norm2.weight": "pytorch_model-00002-of-00002.bin",
|
| 898 |
+
"model.visual_model.mask_decoder.transformer.layers.1.norm3.bias": "pytorch_model-00002-of-00002.bin",
|
| 899 |
+
"model.visual_model.mask_decoder.transformer.layers.1.norm3.weight": "pytorch_model-00002-of-00002.bin",
|
| 900 |
+
"model.visual_model.mask_decoder.transformer.layers.1.norm4.bias": "pytorch_model-00002-of-00002.bin",
|
| 901 |
+
"model.visual_model.mask_decoder.transformer.layers.1.norm4.weight": "pytorch_model-00002-of-00002.bin",
|
| 902 |
+
"model.visual_model.mask_decoder.transformer.layers.1.self_attn.k_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 903 |
+
"model.visual_model.mask_decoder.transformer.layers.1.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 904 |
+
"model.visual_model.mask_decoder.transformer.layers.1.self_attn.out_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 905 |
+
"model.visual_model.mask_decoder.transformer.layers.1.self_attn.out_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 906 |
+
"model.visual_model.mask_decoder.transformer.layers.1.self_attn.q_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 907 |
+
"model.visual_model.mask_decoder.transformer.layers.1.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 908 |
+
"model.visual_model.mask_decoder.transformer.layers.1.self_attn.v_proj.bias": "pytorch_model-00002-of-00002.bin",
|
| 909 |
+
"model.visual_model.mask_decoder.transformer.layers.1.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
|
| 910 |
+
"model.visual_model.mask_decoder.transformer.norm_final_attn.bias": "pytorch_model-00002-of-00002.bin",
|
| 911 |
+
"model.visual_model.mask_decoder.transformer.norm_final_attn.weight": "pytorch_model-00002-of-00002.bin",
|
| 912 |
+
"model.visual_model.prompt_encoder.mask_downscaling.0.bias": "pytorch_model-00002-of-00002.bin",
|
| 913 |
+
"model.visual_model.prompt_encoder.mask_downscaling.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 914 |
+
"model.visual_model.prompt_encoder.mask_downscaling.1.bias": "pytorch_model-00002-of-00002.bin",
|
| 915 |
+
"model.visual_model.prompt_encoder.mask_downscaling.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 916 |
+
"model.visual_model.prompt_encoder.mask_downscaling.3.bias": "pytorch_model-00002-of-00002.bin",
|
| 917 |
+
"model.visual_model.prompt_encoder.mask_downscaling.3.weight": "pytorch_model-00002-of-00002.bin",
|
| 918 |
+
"model.visual_model.prompt_encoder.mask_downscaling.4.bias": "pytorch_model-00002-of-00002.bin",
|
| 919 |
+
"model.visual_model.prompt_encoder.mask_downscaling.4.weight": "pytorch_model-00002-of-00002.bin",
|
| 920 |
+
"model.visual_model.prompt_encoder.mask_downscaling.6.bias": "pytorch_model-00002-of-00002.bin",
|
| 921 |
+
"model.visual_model.prompt_encoder.mask_downscaling.6.weight": "pytorch_model-00002-of-00002.bin",
|
| 922 |
+
"model.visual_model.prompt_encoder.no_mask_embed.weight": "pytorch_model-00002-of-00002.bin",
|
| 923 |
+
"model.visual_model.prompt_encoder.not_a_point_embed.weight": "pytorch_model-00002-of-00002.bin",
|
| 924 |
+
"model.visual_model.prompt_encoder.pe_layer.positional_encoding_gaussian_matrix": "pytorch_model-00002-of-00002.bin",
|
| 925 |
+
"model.visual_model.prompt_encoder.point_embeddings.0.weight": "pytorch_model-00002-of-00002.bin",
|
| 926 |
+
"model.visual_model.prompt_encoder.point_embeddings.1.weight": "pytorch_model-00002-of-00002.bin",
|
| 927 |
+
"model.visual_model.prompt_encoder.point_embeddings.2.weight": "pytorch_model-00002-of-00002.bin",
|
| 928 |
+
"model.visual_model.prompt_encoder.point_embeddings.3.weight": "pytorch_model-00002-of-00002.bin"
|
| 929 |
+
}
|
| 930 |
+
}
|
ckpts/AffordanceVLM-7B/special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": true,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": true,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "<unk>",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<unk>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": true,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
ckpts/AffordanceVLM-7B/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
ckpts/AffordanceVLM-7B/tokenizer_config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"bos_token": {
|
| 5 |
+
"__type": "AddedToken",
|
| 6 |
+
"content": "<s>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": true,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"clean_up_tokenization_spaces": false,
|
| 13 |
+
"eos_token": {
|
| 14 |
+
"__type": "AddedToken",
|
| 15 |
+
"content": "</s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": true,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"legacy": true,
|
| 22 |
+
"model_max_length": 512,
|
| 23 |
+
"pad_token": null,
|
| 24 |
+
"padding_side": "right",
|
| 25 |
+
"sp_model_kwargs": {},
|
| 26 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 27 |
+
"unk_token": {
|
| 28 |
+
"__type": "AddedToken",
|
| 29 |
+
"content": "<unk>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": true,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false
|
| 34 |
+
}
|
| 35 |
+
}
|
ckpts/sam_vit_h_4b8939.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
|
| 3 |
+
size 2564550879
|
client.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Client script to send an image and prompt to a Flask-based vision-language segmentation server.
|
| 4 |
+
|
| 5 |
+
from __future__ import absolute_import, print_function, division
|
| 6 |
+
import requests
|
| 7 |
+
import cv2
|
| 8 |
+
import base64
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
# ---------------------------
|
| 12 |
+
# Encode image to base64 string
|
| 13 |
+
# ---------------------------
|
| 14 |
+
def img2b64(img):
    """Serialize an OpenCV image array to a base64-encoded BMP string."""
    _, encoded = cv2.imencode('.bmp', img)  # lossless BMP keeps the mask exact
    return base64.b64encode(encoded).decode()
| 18 |
+
|
| 19 |
+
# ---------------------------
|
| 20 |
+
# Decode base64 string back to image
|
| 21 |
+
# ---------------------------
|
| 22 |
+
def b642img(pic_str):
    """Decode a base64 string (as produced by img2b64) back into a BGR image."""
    raw_bytes = base64.b64decode(pic_str)
    buf = np.frombuffer(raw_bytes, np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)
| 27 |
+
|
| 28 |
+
# ---------------------------
|
| 29 |
+
# Send image and prompt to server, receive result and save
|
| 30 |
+
# ---------------------------
|
| 31 |
+
def post_files():
    """Send a demo image and affordance prompt to the server; save the returned mask.

    Reads a hard-coded local image, POSTs it (base64-encoded) together with a
    text prompt to the Flask endpoint, and writes the decoded response image
    to disk. All outcomes are reported via print statements.
    """
    path = 'vis_output/my_workspace.JPG'  # Input image path
    img = cv2.imread(path)
    if img is None:
        print(f"Failed to read image at {path}")
        return

    payload = {
        'img': img2b64(img),
        'prompt': 'Please segment the affordance map of mug in this image.'
    }

    # POST to the Flask server; it answers with a base64-encoded mask image.
    r = requests.post('http://localhost:3200/img_mask', json=payload)

    if r.status_code != 200:
        print(f"Request failed with status code {r.status_code}")
        return

    print('Success. Received response from server.')
    result_b64 = r.json().get('img', None)
    if not result_b64:
        print("No image returned in the response.")
        return

    save_path = 'affordance_mask_result.jpg'
    cv2.imwrite(save_path, b642img(result_b64))
    print(f"Result saved to {save_path}")
| 61 |
+
|
| 62 |
+
# ---------------------------
|
| 63 |
+
# Main entry
|
| 64 |
+
# ---------------------------
|
| 65 |
+
# Script entry point: run the demo client once when executed directly.
if __name__ == '__main__':
    post_files()
| 67 |
+
|
data_curation/.ipynb_checkpoints/check_dataset-checkpoint.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle as pkl
|
| 3 |
+
|
| 4 |
+
DATA_DIR = '/gemini/space/wrz/AffordanceNet/data'
|
| 5 |
+
|
| 6 |
+
# 新增一个路径修复函数
|
| 7 |
+
def resolve_path(path):
    """Resolve a dataset-relative path (e.g. './data/...') to an absolute path.

    Paths recorded in the pkl files are relative to the repo root; re-root them
    under the real DATA_DIR. Absolute paths are returned unchanged.
    """
    data_prefix = './data/'
    if path.startswith(data_prefix):
        # Strip the recorded './data/' prefix (no magic slice length) and
        # re-root the remainder under the real DATA_DIR.
        return os.path.join(DATA_DIR, path[len(data_prefix):])
    elif path.startswith('./'):
        # Other './'-relative paths are rooted at DATA_DIR's parent directory.
        return os.path.join(os.path.dirname(DATA_DIR), path[2:])
    return path
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_data_paths():
    """Retrieve train/val/reasoning/non-reasoning pkl file paths from DATA_DIR."""
    entries = os.listdir(DATA_DIR)

    def _collect(suffix):
        # One pass per split kind, matched purely on filename suffix.
        return [os.path.join(DATA_DIR, name) for name in entries if name.endswith(suffix)]

    train_paths = _collect('train.pkl')
    val_paths = _collect('val.pkl')
    reasoning_paths = _collect('reasoning_val.pkl')
    # Reasoning files also end with 'val.pkl', so subtract them from val_paths.
    non_reasoning_paths = [p for p in val_paths if p not in reasoning_paths]

    return train_paths, reasoning_paths, non_reasoning_paths
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def check_file_exists(file_path, description=""):
    """Raise FileNotFoundError if *file_path* does not exist.

    Uses an explicit raise instead of `assert`: assertions are stripped when
    Python runs with -O, which would silently disable this validation.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{description} does not exist: {file_path}")
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def check_train_data(train_path):
    """Verify that every frame/mask path referenced by a training pkl exists."""
    print(f"[Train] Checking: {train_path}")
    with open(train_path, "rb") as f:
        samples = pkl.load(f)

    for sample in samples:
        # Recorded paths are repo-relative; map onto the real data root first.
        check_file_exists(resolve_path(sample["frame_path"]), "Frame path")
        check_file_exists(resolve_path(sample["mask_path"]), "Mask path")

    print(f"[Train] ✅ Checked {train_path}. Samples: {len(samples)}")
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def check_val_data(val_path, reasoning=False):
    """Check validation data paths; layout differs between reasoning and plain splits."""
    tag = "Reasoning Val" if reasoning else "Non-Reasoning Val"
    print(f"[{tag}] Checking: {val_path}")

    with open(val_path, "rb") as f:
        data = pkl.load(f)

    if reasoning:
        # Reasoning splits are flat sample lists with per-sample frame/mask paths.
        for sample in data:
            check_file_exists(resolve_path(sample["frame_path"]), "Frame path")
            check_file_exists(resolve_path(sample["mask_path"]), "Mask path")
        print(f"[{tag}] ✅ Checked {val_path}. Samples: {len(data)}")
        return

    # Non-reasoning splits group image/label path lists by class name.
    total_images = 0
    for image_list in data.get('images', {}).values():
        for image_path in image_list:
            check_file_exists(resolve_path(image_path), "Image path")
        total_images += len(image_list)

    for label_list in data.get('labels', {}).values():
        for label_path in label_list:
            check_file_exists(resolve_path(label_path), "Label path")

    print(f"[{tag}] ✅ Checked {val_path}. Samples: {total_images}")
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def main():
    """Validate every dataset split discovered under DATA_DIR."""
    train_paths, reasoning_paths, non_reasoning_paths = get_data_paths()

    for path in train_paths:
        check_train_data(path)

    for path in non_reasoning_paths:
        check_val_data(path, reasoning=False)

    for path in reasoning_paths:
        check_val_data(path, reasoning=True)
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Script entry point: run the full dataset consistency check.
if __name__ == "__main__":
    main()
data_curation/build_vlpart.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
import argparse
|
| 3 |
+
import glob
|
| 4 |
+
import multiprocessing as mp
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
import time
|
| 9 |
+
import warnings
|
| 10 |
+
import cv2
|
| 11 |
+
import tqdm
|
| 12 |
+
|
| 13 |
+
from detectron2.config import get_cfg
|
| 14 |
+
from detectron2.data.detection_utils import read_image
|
| 15 |
+
from detectron2.utils.logger import setup_logger
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
sys.path.append('.')
|
| 19 |
+
from VLPart.vlpart.config import add_vlpart_config
|
| 20 |
+
|
| 21 |
+
from VLPart.demo.predictor import VisualizationDemo
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# constants
|
| 25 |
+
WINDOW_NAME = "image demo"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def setup_cfg(args):
    """Build a frozen detectron2 config from the config file and CLI opts in *args*."""
    cfg = get_cfg()
    add_vlpart_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # Apply one confidence threshold to every head the builtin models expose.
    threshold = args.confidence_threshold
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = threshold

    cfg.freeze()
    return cfg
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_parser():
    """Create the command-line argument parser for the VLPart demo."""
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")

    parser.add_argument(
        "--config-file",
        metavar="FILE",
        default="VLPart/configs/joint/swinbase_cascade_lvis_paco_pascalpart_partimagenet.yaml",
        help="path to config file",
    )
    parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        default='',
        help="A list of space separated input images; "
             "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        default='',
        help="A file or directory to save output visualizations. "
             "If not given, will show output in an OpenCV window.",
    )

    vocab_choices = ['pascal_part', 'partimagenet', 'paco',
                     'voc', 'coco', 'lvis',
                     'pascal_part_voc', 'lvis_paco', 'custom']
    parser.add_argument("--vocabulary", default="custom", choices=vocab_choices, help="")
    parser.add_argument("--custom_vocabulary", default="", help="")
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.7,
        help="Minimum score for instance predictions to be shown",
    )
    # REMAINDER swallows everything after --opts as raw 'KEY VALUE' pairs.
    parser.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=['MODEL.WEIGHTS', "/data/VLPart/ckpts/swinbase_cascade_lvis_paco_pascalpart_partimagenet.pth", "VIS.BOX", False],
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return parser
| 92 |
+
|
| 93 |
+
def build_vlpart_model(custom_vocabulary):
    """Instantiate a VLPart VisualizationDemo configured for *custom_vocabulary*."""
    # detectron2 demos spawn worker processes; force the spawn start method up front.
    mp.set_start_method("spawn", force=True)

    args = get_parser().parse_args()
    args.custom_vocabulary = custom_vocabulary

    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)
    return VisualizationDemo(cfg, args)
|
data_curation/check_dataset.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle as pkl
|
| 3 |
+
|
| 4 |
+
DATA_DIR = '/gemini/space/wrz/AffordanceNet/data'


def resolve_path(path):
    """Map a repo-relative path (e.g. './data/...') onto the real data root.

    Paths beginning with './data/' are re-rooted at DATA_DIR; other './'
    paths are re-rooted at DATA_DIR's parent directory; absolute or plain
    relative paths are returned untouched.
    """
    data_prefix = './data/'
    if path.startswith(data_prefix):
        # Strip the './data/' prefix and graft the remainder onto DATA_DIR.
        return os.path.join(DATA_DIR, path[len(data_prefix):])
    if path.startswith('./'):
        # Fallback for other relative paths: anchor at DATA_DIR's parent.
        return os.path.join(os.path.dirname(DATA_DIR), path[2:])
    return path
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_data_paths():
    """Collect train / reasoning-val / non-reasoning-val pkl paths in DATA_DIR.

    Returns:
        tuple: (train_paths, reasoning_paths, non_reasoning_paths), each a
        list of absolute file paths.
    """
    entries = os.listdir(DATA_DIR)

    def _matching(suffix):
        # Absolute paths of directory entries whose names end with *suffix*.
        return [os.path.join(DATA_DIR, name) for name in entries if name.endswith(suffix)]

    train_paths = _matching('train.pkl')
    val_paths = _matching('val.pkl')  # note: also matches reasoning_val.pkl
    reasoning_paths = _matching('reasoning_val.pkl')
    non_reasoning_paths = [path for path in val_paths if path not in reasoning_paths]

    return train_paths, reasoning_paths, non_reasoning_paths
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def check_file_exists(file_path, description=""):
    """Raise AssertionError when *file_path* does not exist on disk."""
    message = f"{description} does not exist: {file_path}"
    assert os.path.exists(file_path), message
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def check_train_data(train_path):
    """Validate that every frame/mask referenced by a training pkl exists.

    Raises AssertionError (via check_file_exists) on the first missing file.
    """
    print(f"[Train] Checking: {train_path}")
    with open(train_path, "rb") as f:
        samples = pkl.load(f)

    for sample in samples:
        # Translate repo-relative paths to absolute ones before checking.
        frame = resolve_path(sample["frame_path"])
        mask = resolve_path(sample["mask_path"])
        check_file_exists(frame, "Frame path")
        check_file_exists(mask, "Mask path")

    print(f"[Train] ✅ Checked {train_path}. Samples: {len(samples)}")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def check_val_data(val_path, reasoning=False):
    """Verify every file referenced by a validation pkl.

    Reasoning pkls are flat lists of {frame_path, mask_path} records;
    non-reasoning pkls are dicts with per-class 'images'/'labels' path lists.
    """
    tag = "Reasoning Val" if reasoning else "Non-Reasoning Val"
    print(f"[{tag}] Checking: {val_path}")

    with open(val_path, "rb") as f:
        payload = pkl.load(f)

    if reasoning:
        for record in payload:
            # Resolve repo-relative paths before the existence check.
            check_file_exists(resolve_path(record["frame_path"]), "Frame path")
            check_file_exists(resolve_path(record["mask_path"]), "Mask path")
        print(f"[{tag}] ✅ Checked {val_path}. Samples: {len(payload)}")
        return

    total_images = 0
    for image_list in payload.get('images', {}).values():
        for image_path in image_list:
            check_file_exists(resolve_path(image_path), "Image path")
        total_images += len(image_list)

    for label_list in payload.get('labels', {}).values():
        for label_path in label_list:
            check_file_exists(resolve_path(label_path), "Label path")

    print(f"[{tag}] ✅ Checked {val_path}. Samples: {total_images}")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def main():
    """Run existence checks over every train/val pkl found in DATA_DIR."""
    train_paths, reasoning_paths, non_reasoning_paths = get_data_paths()

    for path in train_paths:
        check_train_data(path)

    for path in non_reasoning_paths:
        check_val_data(path, reasoning=False)

    for path in reasoning_paths:
        check_val_data(path, reasoning=True)


if __name__ == "__main__":
    main()
|
data_curation/prompt_generation_handal_easy_reasoning.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import pickle
|
| 4 |
+
import requests
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 6 |
+
|
| 7 |
+
# Dataset name: selects both the input pkl (./data/<DATASET>_val.pkl) and the
# output directory (./reason_affordance/<DATASET>_easy_reasoning).
DATASET = 'handal'

# Handle-equipped objects to filter: only these HANDAL categories are sent to
# GPT for instruction generation.
OBJECTS_WITH_HANDLE = [
    'strainers', 'fixed joint pliers', 'hammers', 'ladles', 'whisks', 'measuring cups',
    'locking pliers', 'power drills', 'adjustable wrenches', 'mugs', 'ratchets', 'utensils',
    'combinational wrenches', 'pots pans', 'spatulas', 'screwdrivers', 'slip joint pliers'
]

# OpenAI Chat Completions endpoint settings (update key before running!).
API_URL = 'https://api.openai.com/v1/chat/completions'
HEADERS = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer YOUR-API-KEY'  # Replace with your real key
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def read_pkl_file(pkl_path):
    """Load the val pkl and keep handle-object entries with no output JSON yet.

    Returns:
        list[dict]: pending entries with 'img_name' and 'class_name' keys.
    """
    with open(pkl_path, 'rb') as f:
        val_data = pickle.load(f)

    out_root = f'./reason_affordance/{DATASET}_easy_reasoning'
    pending = []
    for class_name, image_list in val_data['images'].items():
        if class_name not in OBJECTS_WITH_HANDLE:
            continue
        for idx, img in enumerate(image_list):
            label = val_data['class_names'][class_name][idx]
            stem = os.path.splitext(os.path.basename(img))[0]
            # Skip samples whose reasoning JSON was already generated.
            if not os.path.exists(os.path.join(out_root, label, stem + ".json")):
                pending.append({'img_name': img, 'class_name': label})
    return pending
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def process_sentence(class_name):
    """Ask GPT-4 for an instruction pair (<1>, <2>) for *class_name*.

    Returns the raw completion text, or None on an HTTP error.
    """
    system_rules = (
        'Based on several words where the first is category name, '
        'please design an instruction <1> and instruction <2> in embodied scenes. '
        'The instruction <1> must include object category name itself. '
        'The instruction <2> must include the object category name itself. '
        'The instruction <2> must belong to embodied manipulation and give action if instruction <1> provides. '
        'The instruction <2> does not exceed 50 words.'
    )
    # Few-shot examples steering the <1>/<2> output format.
    few_shot = [
        ('mug', '<1> I need a drink. Please find a mug to fill water. <2> The mug has a handle as affordance map. So the robot can hold its handle.'),
        ('knife', '<1> Please give me a knife to cut apple. <2> The knife has a handle, and you can use its handle to cut apple.'),
        ('hammers', '<1> What is the proper way to hold the hammers? <2> The correct method is to hold the hammer by its handle.'),
        ('fork', '<1> Kindly pick up the fork. <2> You will be holding the fork handle.'),
        ('screwdrivers', '<1> I need a tool to tighten or loosen screws. <2> The screwdriver is here, hold its handle to turn and control screws.'),
    ]

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'system', 'content': system_rules},
    ]
    for user_text, assistant_text in few_shot:
        messages.append({'role': 'user', 'content': user_text})
        messages.append({'role': 'assistant', 'content': assistant_text})
    messages.append({'role': 'user', 'content': class_name})

    response = requests.post(API_URL, headers=HEADERS, json={'model': 'gpt-4', 'messages': messages})
    if response.status_code == 200:
        return response.json()['choices'][0]['message']['content']
    print(f"API Error for {class_name}:", response.text)
    return None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def process_json(data):
    """Generate a reasoning pair for one entry and persist it as JSON.

    Retries the GPT call up to 5 times until both <1> and <2> markers appear.
    """
    class_name = data["class_name"]

    result = None
    for _ in range(5):
        candidate = process_sentence(class_name)
        if candidate and '<1>' in candidate and '<2>' in candidate:
            result = candidate
            break
    if result is None:
        print(f"Failed to process: {class_name}")
        return

    print("Processed:", result)

    try:
        # Text between <1> and <2> is the question; text after <2> the answer.
        question = result.split('<2>')[0].split('<1>')[-1].strip()
        answer = result.split('<2>')[-1].strip()

        save_dir = os.path.join(f'./reason_affordance/{DATASET}_easy_reasoning', class_name)
        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, os.path.splitext(os.path.basename(data["img_name"]))[0] + ".json")
        output = {'img_name': data["img_name"], 'class_name': class_name, 'question': question, 'answer': answer}

        with open(save_path, 'w') as f:
            json.dump(output, f, indent=4)

    except Exception as e:
        print(f"Error saving file for {class_name}:", e)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def main():
    """Entry point: collect pending samples and process them concurrently."""
    data_list = read_pkl_file(f'./data/{DATASET}_val.pkl')

    # A small pool keeps the request rate to the OpenAI endpoint modest.
    with ThreadPoolExecutor(max_workers=2) as pool:
        pool.map(process_json, data_list)


if __name__ == "__main__":
    main()
|
data_curation/prompt_generation_handal_hard_reasoning.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import pickle
|
| 4 |
+
import requests
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 6 |
+
|
| 7 |
+
# Dataset configuration: selects the input pkl (./data/<DATASET>_val.pkl) and
# the output directory (./reason_affordance/<DATASET>_hard_reasoning).
DATASET = 'handal'

# Object categories with handle: only these HANDAL classes get instructions.
OBJECTS_WITH_HANDLE = [
    'strainers', 'fixed joint pliers', 'hammers', 'ladles', 'whisks', 'measuring cups',
    'locking pliers', 'power drills', 'adjustable wrenches', 'mugs', 'ratchets', 'utensils',
    'combinational wrenches', 'pots pans', 'spatulas', 'screwdrivers', 'slip joint pliers'
]

# OpenAI Chat Completions endpoint settings (update key before running!).
API_URL = 'https://api.openai.com/v1/chat/completions'
HEADERS = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer YOUR-API-KEY'  # Replace with your real key
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def read_pkl_file(pkl_path):
    """Load val data and list handle-object samples lacking an output JSON.

    Returns:
        list[dict]: pending entries with 'img_name' and 'class_name' keys.
    """
    with open(pkl_path, 'rb') as f:
        val_data = pickle.load(f)

    out_root = f'./reason_affordance/{DATASET}_hard_reasoning'
    filtered_data = []
    for class_name, img_list in val_data['images'].items():
        if class_name in OBJECTS_WITH_HANDLE:
            for i, img_path in enumerate(img_list):
                class_label = val_data['class_names'][class_name][i]
                stem = os.path.splitext(os.path.basename(img_path))[0]
                # Already-processed samples have a JSON on disk: skip them.
                if not os.path.exists(os.path.join(out_root, class_label, stem + ".json")):
                    filtered_data.append({'img_name': img_path, 'class_name': class_label})

    return filtered_data
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def process_sentence(category):
    """Ask GPT-4 for a hard-reasoning instruction pair (<1>, <2>).

    Unlike the "easy" variant, instruction <1> must NOT name the category.
    Returns the raw completion text, or None on an HTTP error.
    """
    system_rules = (
        'Based on several words where the first is category name, please design an instruction <1> and instruction <2> in embodied scenes. '
        'The instruction <1> must not include object category name itself. '
        'The instruction <2> must include the object category name itself. '
        'The instruction <2> must belong to embodied manipulation and give action if instruction <1> provides. '
        'The instruction <2> does not exceed 50 words.'
    )
    # Few-shot examples steering the <1>/<2> output format.
    few_shot = [
        ('microwave, open', '<1> Heat up food quickly. <2> The microwave is closed, so it can be open to access the food inside.'),
        ('knife', '<1> I want to cut a bread. <2> The knife has a handle, you can use its handle to cut bread.'),
        ('computer mouse', '<1> Give me a tool to control the cursor on the screen. <2> The computer mouse is here. It has no handle, so you can grasp its whole body.'),
        ('fork', '<1> Use to pierce and lift food. <2> The fork is here, and its handle can be grasped.'),
        ('screwdrivers', '<1> I need a tool to tighten or loosen screws. <2> The screwdriver is here, hold its handle to turn and control screws.'),
    ]

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'system', 'content': system_rules},
    ]
    for user_text, assistant_text in few_shot:
        messages.append({'role': 'user', 'content': user_text})
        messages.append({'role': 'assistant', 'content': assistant_text})
    messages.append({'role': 'user', 'content': category})

    response = requests.post(API_URL, headers=HEADERS, json={'model': 'gpt-4', 'messages': messages})
    if response.status_code == 200:
        return response.json()['choices'][0]['message']['content']
    print(f"[API Error] {category}: {response.status_code} - {response.text}")
    return None
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def process_json(entry):
    """Generate a hard-reasoning pair for one entry and save it as JSON.

    Retries the GPT call up to 5 times until both <1> and <2> markers appear.
    """
    class_name = entry['class_name']

    result = None
    for _ in range(5):
        candidate = process_sentence(class_name)
        if candidate and '<1>' in candidate and '<2>' in candidate:
            result = candidate
            break
    if result is None:
        print(f"[Retry Failed] {class_name}")
        return

    try:
        # Text between <1> and <2> is the question; text after <2> the answer.
        question = result.split('<2>')[0].split('<1>')[-1].strip()
        answer = result.split('<2>')[-1].strip()

        save_dir = os.path.join(f'./reason_affordance/{DATASET}_hard_reasoning', class_name)
        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, os.path.splitext(os.path.basename(entry['img_name']))[0] + ".json")
        output = {
            'img_name': entry['img_name'],
            'class_name': class_name,
            'question': question,
            'answer': answer
        }

        with open(save_path, 'w') as f:
            json.dump(output, f, indent=4)
        print(f"[Saved] {save_path}")
    except Exception as e:
        print(f"[Error] Failed to save {class_name}: {e}")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def main():
    """Entry point: load pending entries, then fan work out over a thread pool."""
    entries = read_pkl_file(f'./data/{DATASET}_val.pkl')

    # A small pool keeps the request rate to the OpenAI endpoint modest.
    with ThreadPoolExecutor(max_workers=2) as pool:
        pool.map(process_json, entries)


if __name__ == "__main__":
    main()
|
data_curation/vlpart_sam2_tracking.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import torch
|
| 4 |
+
import pickle
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
import warnings
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
from detectron2.data.detection_utils import read_image
|
| 13 |
+
from supervision import Detections, BoxAnnotator, MaskAnnotator, LabelAnnotator, mask_to_xyxy
|
| 14 |
+
|
| 15 |
+
from sam2.build_sam import build_sam2_video_predictor
|
| 16 |
+
from VLPart.build_vlpart import build_vlpart_model
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
warnings.filterwarnings('ignore')
|
| 20 |
+
|
| 21 |
+
# Constants
# SAM2 model configuration and checkpoint used for video mask propagation.
SAM2_CONFIG = "sam2_hiera_l.yaml"
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
# Output roots: raw binary masks vs. annotated preview frames.
OUTPUT_ROOT = "/data/robot-merlin/mask_vlpart+sam2_tracking"
OUTPUT_ROOT_IMG = "/data/robot-merlin/mask_vlpart+sam2_tracking_with_image"

# Set up torch environment.
# NOTE(review): the autocast context is entered globally and never exited, so
# every subsequent CUDA op in this process runs under bfloat16 autocast.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Enable TF32 matmul/conv on GPUs with compute capability >= 8 (Ampere+).
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_affordance_data(pkl_path):
    """Group affordance records by the directory of their source frame.

    Args:
        pkl_path (str): Pickle file holding a list of record dicts, each
            containing a 'frame_path' key.

    Returns:
        dict: video directory path -> list of records from that directory.
    """
    with open(pkl_path, 'rb') as f:
        records = pickle.load(f)

    grouped = {}
    for record in records:
        video_dir = os.path.dirname(record['frame_path'])
        if video_dir not in grouped:
            grouped[video_dir] = []
        grouped[video_dir].append(record)
    return grouped
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def init_vlpart_once(text, prev_text, vlpart_model):
    """(Re)build the VLPart model only when the vocabulary text changed.

    Returns:
        tuple: (model, text) — the existing model when *text* equals
        *prev_text*, otherwise a freshly built one.
    """
    if text == prev_text:
        return vlpart_model, text
    # Drop the stale model reference before loading the replacement.
    if vlpart_model is not None:
        del vlpart_model
    return build_vlpart_model(text), text
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def run_vlpart_on_first_frame(vlpart_model, image_path):
    """Detect the target part on one frame and return its boxes, or None.

    Only frames with exactly one detected instance are accepted so the
    downstream tracker is seeded unambiguously.
    """
    frame = read_image(image_path, format="BGR")
    predictions, _ = vlpart_model.run_on_image(frame)
    instances = predictions["instances"]
    if len(instances) != 1:
        return None
    return instances.pred_boxes.tensor.cpu().numpy()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def run_sam2_tracking(video_dir, frame_names, sam2_predictor, boxes):
    """Propagate a first-frame box through the video with SAM2.

    Returns:
        dict: frame index -> {object id -> boolean mask array}.
    """
    state = sam2_predictor.init_state(video_path=video_dir)
    sam2_predictor.reset_state(state)

    # Seed the tracker with the detection box on frame 0 (single object, id 1).
    sam2_predictor.add_new_points_or_box(
        inference_state=state,
        frame_idx=0,
        obj_id=1,
        box=boxes,
    )

    segments = {}
    for frame_idx, obj_ids, logits in sam2_predictor.propagate_in_video(state):
        frame_masks = {}
        for pos, obj_id in enumerate(obj_ids):
            # Positive logits mark foreground pixels.
            frame_masks[obj_id] = (logits[pos] > 0).cpu().numpy()
        segments[frame_idx] = frame_masks
    return segments
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def save_tracking_results(video_dir, frame_names, video_segments, object_name, output_base, vid):
    """Write raw binary masks and annotated preview frames for one video.

    Masks go under <output_base>/<vid>; annotated previews (box + label +
    mask overlays) under OUTPUT_ROOT_IMG/<vid>.
    """
    # Single tracked object with id 1 (matches run_sam2_tracking's obj_id).
    label_by_id = {1: object_name}

    mask_dir = Path(f"{output_base}/{vid:06d}")
    mask_dir.mkdir(parents=True, exist_ok=True)

    preview_dir = Path(f"{OUTPUT_ROOT_IMG}/{vid:06d}")
    preview_dir.mkdir(parents=True, exist_ok=True)

    box_annotator = BoxAnnotator()
    label_annotator = LabelAnnotator()
    mask_annotator = MaskAnnotator()

    for frame_idx, masks_by_id in video_segments.items():
        frame = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))

        ids = list(masks_by_id.keys())
        stacked_masks = np.concatenate(list(masks_by_id.values()), axis=0)

        detections = Detections(
            xyxy=mask_to_xyxy(stacked_masks),
            mask=stacked_masks,
            class_id=np.array(ids, dtype=np.int32),
        )

        annotated = box_annotator.annotate(frame.copy(), detections)
        annotated = label_annotator.annotate(annotated, detections, [label_by_id[i] for i in ids])
        annotated = mask_annotator.annotate(annotated, detections)

        cv2.imwrite(str(preview_dir / frame_names[frame_idx]), annotated)
        # The raw mask is persisted as a 0/255 image.
        cv2.imwrite(str(mask_dir / frame_names[frame_idx]), stacked_masks[0] * 255)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_sorted_frame_names(video_dir):
    """List JPEG frames in *video_dir*, ordered by their numeric stem."""
    def frame_number(filename):
        stem, _ = os.path.splitext(filename)
        return int(stem)

    jpegs = [name for name in os.listdir(video_dir)
             if name.lower().endswith(('.jpg', '.jpeg'))]
    jpegs.sort(key=frame_number)
    return jpegs
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def main(openx_data, text_override=None):
    """Detect handle parts with VLPart and track them through videos with SAM2.

    Args:
        openx_data: Dataset tag used to locate ./data/<tag>_for_affordance.pkl.
        text_override: Optional vocabulary string replacing the per-task
            "<object> handle" prompt.
    """
    # Affordance records grouped per video directory.
    videos = load_affordance_data(f'./data/{openx_data}_for_affordance.pkl')

    # One SAM2 video predictor is shared across all videos.
    sam2_predictor = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)

    cached_text = ''
    detector = None

    for video_dir, records in tqdm(videos.items()):
        sample = records[0]
        frame_path = sample['frame_path']
        task_class = sample['task_object_class']

        # Only process specific classes.
        if not any(keyword in task_class for keyword in ['door', 'drawer', 'knife']):
            continue

        # Rebuild VLPart only when the vocabulary actually changes.
        query = f"{task_class} handle" if not text_override else text_override
        detector, cached_text = init_vlpart_once(query, cached_text, detector)

        # Seed box from the first frame; skip ambiguous detections.
        boxes = run_vlpart_on_first_frame(detector, frame_path)
        if boxes is None:
            continue

        frame_names = get_sorted_frame_names(video_dir)
        segments = run_sam2_tracking(video_dir, frame_names, sam2_predictor, boxes)
        save_tracking_results(video_dir, frame_names, segments, query,
                              f"{OUTPUT_ROOT}/", sample['vid'])
        print(f"[Done] {frame_path} | {task_class}")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser("VLPart + SAM2 Tracking Demo")
    parser.add_argument("--pipeline", type=str, default="referring_expression_segmentation", help="Pipeline task")
    parser.add_argument("--text_input", type=str, default=None, help="Optional override for input text")
    parser.add_argument("--dataset", type=str, default="bridge", help="Dataset name (e.g., bridge)")
    args = parser.parse_args()

    # BUG FIX: main() takes (openx_data, text_override=None); the previous call
    # passed three positional arguments (dataset, pipeline, text_input), which
    # raised TypeError and, ordering aside, would have fed the pipeline name in
    # as the vocabulary override. Pass only the arguments main() accepts.
    main(args.dataset, text_override=args.text_input)
|
docs/dataset.md
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Dataset
|
| 2 |
+
|
| 3 |
+
To train our affordance segmentation model, we use two types of data:
|
| 4 |
+
* **General Segmentation Data**: This follows [LISA](https://github.com/dvlab-research/LISA).
|
| 5 |
+
* **Affordance Segmentation Data**: This is a large-scale dataset that we collect.
|
| 6 |
+
|
| 7 |
+
### General Segmentation Data
|
| 8 |
+
This data is organized as follows:
|
| 9 |
+
```
|
| 10 |
+
./data/
|
| 11 |
+
├── lisa_data
|
| 12 |
+
│ ├── ade20k
|
| 13 |
+
│ ├── coco
|
| 14 |
+
│ ├── cocostuff
|
| 15 |
+
│ ├── llava_dataset
|
| 16 |
+
│ ├── mapillary
|
| 17 |
+
│ ├── reason_seg
|
| 18 |
+
│ ├── refer_seg
|
| 19 |
+
│ ├── vlpart
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### Affordance Segmentation Data
|
| 23 |
+
|
| 24 |
+
We employ images from HANDAL, Open-X, GraspNet, EgoObjects, and RLBench in our affordance segmentation task.
|
| 25 |
+
|
| 26 |
+
The HANDAL data is downloaded and organized according to its official [repo](https://github.com/NVlabs/HANDAL).
|
| 27 |
+
Other data can be downloaded from the [Hugging Face](https://huggingface.co/datasets/Dongming97/RAGNet).
|
| 28 |
+
|
| 29 |
+
The training data is organized as follows:
|
| 30 |
+
```
|
| 31 |
+
./data/
|
| 32 |
+
├── openx_train.pkl
|
| 33 |
+
├── graspnet_train.pkl
|
| 34 |
+
├── egoobjects_train.pkl
|
| 35 |
+
├── rlbench_train.pkl
|
| 36 |
+
├── handal_hard_reasoning_train.pkl
|
| 37 |
+
├── egoobjects_easy_reasoning_train.pkl
|
| 38 |
+
├── egoobjects_hard_reasoning_train.pkl
|
| 39 |
+
├── HANDAL
|
| 40 |
+
│ ├── without_depth
|
| 41 |
+
│ ├── handal_dataset_adjustable_wrenches
|
| 42 |
+
│ ├── handal_dataset_combinational_wrenches
|
| 43 |
+
│ ├── handal_dataset_fixed_joint_pliers
|
| 44 |
+
│ ├── ...
|
| 45 |
+
├── openx
|
| 46 |
+
│ ├── images
|
| 47 |
+
│ ├── fractal20220817_data
|
| 48 |
+
│ ├── bridge
|
| 49 |
+
│ ├── masks
|
| 50 |
+
│ ├── fractal20220817_data
|
| 51 |
+
│ ├── bridge
|
| 52 |
+
├── graspnet
|
| 53 |
+
│ ├── images
|
| 54 |
+
│ ├── masks
|
| 55 |
+
│ ├── test_seen
|
| 56 |
+
│ ├── test_novel
|
| 57 |
+
├── egoobjects
|
| 58 |
+
│ ├── images
|
| 59 |
+
│ ├── masks
|
| 60 |
+
├── rlbench
|
| 61 |
+
│ ├── images
|
| 62 |
+
│ ├── masks
|
| 63 |
+
├── 3doi
|
| 64 |
+
│ ├── images
|
| 65 |
+
│ ├── masks
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
The evaluation data is in the same directory, but uses the `*_val.pkl` files instead of `*_train.pkl`.
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
./data/
|
| 72 |
+
├── handal_mini_val.pkl
|
| 73 |
+
├── graspnet_test_seen_val.pkl
|
| 74 |
+
├── graspnet_test_novel_val.pkl
|
| 75 |
+
├── 3doi_val.pkl
|
| 76 |
+
├── handal_easy_reasoning_val.pkl
|
| 77 |
+
├── handal_hard_reasoning_val.pkl
|
| 78 |
+
├── 3doi_easy_reasoning_val.pkl
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
You can use the following script to confirm if data is organized correctly:
|
| 82 |
+
```bash
|
| 83 |
+
python data_curation/check_dataset.py
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### About data curation
|
| 87 |
+
1. **SAM2**: We use SAM2 to generate affordance mask if the dataset provides box annotation.
|
| 88 |
+
2. **Florence-2 + SAM2**: We use Florence-2 to generate the initial segmentation masks of some complete objects, and then refine them with SAM2. Please see [Florence-2+SAM2](https://github.com/IDEA-Research/Grounded-SAM-2).
|
| 89 |
+
3. **VLPart + SAM2**: We use VLPart to generate box of object part, and then refine them with SAM2. We refer to [VLPart](https://github.com/facebookresearch/VLPart).
|
| 90 |
+
We provide our inference demo scripts in `data_curation/build_vlpart.py` and `data_curation/vlpart_sam2_tracking.py`.
|
| 91 |
+
4. **Reasoning Instruction**: We provide two example scripts to generate reasoning instructions for the affordance segmentation task:
|
| 92 |
+
- `data_curation/prompt_generation_handal_easy_reasoning.py`: This script generates easy reasoning instructions for the HANDAL dataset.
|
| 93 |
+
- `data_curation/prompt_generation_handal_hard_reasoning.py`: This script generates hard reasoning instructions for the HANDAL dataset.
|
docs/installation.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Installation
|
| 2 |
+
The environment installation mainly follows [LISA](https://github.com/dvlab-research/LISA).
|
| 3 |
+
```
|
| 4 |
+
git clone https://github.com/wudongming97/AffordanceNet.git
|
| 5 |
+
cd AffordanceNet
|
| 6 |
+
conda create -n affordancenet python=3.9
|
| 7 |
+
conda activate affordancenet
|
| 8 |
+
pip install -r requirements.txt
|
| 9 |
+
pip install flash-attn --no-build-isolation
|
| 10 |
+
```
|
docs/training_and_evaluation.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Training and Evaluation
|
| 2 |
+
|
| 3 |
+
### Pre-trained Weights
|
| 4 |
+
#### LLaVA
|
| 5 |
+
For convenience of using pre-trained LLaVA weights, we provide a link from [Hugging Face](https://huggingface.co/Dongming97/LLaVA-Lightning-7B-v1-1).
|
| 6 |
+
|
| 7 |
+
#### SAM
|
| 8 |
+
Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
### Training
|
| 12 |
+
To train AffordanceVLM, you can use the following command.
|
| 13 |
+
```
|
| 14 |
+
bash ./scripts/train.sh
|
| 15 |
+
```
|
| 16 |
+
When training is finished, to get the full model weight:
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
cd ./runs/AffordanceVLM-7B/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### Merge LoRA Weight
|
| 23 |
+
Merge the LoRA weights of `pytorch_model.bin`, save the resulting model into your desired path in the Hugging Face format:
|
| 24 |
+
```
|
| 25 |
+
CUDA_VISIBLE_DEVICES="" python merge_lora_weights_and_save_hf_model.py \
|
| 26 |
+
--version="PATH_TO_LLaVA" \
|
| 27 |
+
--weight="PATH_TO_pytorch_model.bin" \
|
| 28 |
+
--save_path="PATH_TO_SAVED_MODEL"
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
For example:
|
| 32 |
+
```
|
| 33 |
+
CUDA_VISIBLE_DEVICES="" python3 merge_lora_weights_and_save_hf_model.py \
|
| 34 |
+
--version="./LLaVA/LLaVA-Lightning-7B-v1-1" \
|
| 35 |
+
--weight="./runs/AffordanceVLM-7B/pytorch_model.bin" \
|
| 36 |
+
--save_path="./exps/AffordanceVLM-7B"
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
### Evaluation
|
| 40 |
+
To evaluate AffordanceVLM on the entire [HANDAL](https://github.com/NVlabs/HANDAL) dataset, please adjust the `--dataset_dir` parameter in `evaluate.sh`.
|
| 41 |
+
```
|
| 42 |
+
bash ./scripts/evaluate.sh
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
To chat with [AffordanceVLM-7B](https://huggingface.co/Dongming97/AffordanceVLM):
|
| 46 |
+
```
|
| 47 |
+
CUDA_VISIBLE_DEVICES=0 python chat.py --version=./exps/AffordanceVLM-7B
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Main Results
|
| 51 |
+
|
| 52 |
+
HANDAL:
|
| 53 |
+
|
| 54 |
+
| Method | gIoU | cIoU |
|
| 55 |
+
|:----------------:|:----:|-----:|
|
| 56 |
+
| AffordanceVLM-7B | 60.3 | 60.8 |
|
imgs/.ipynb_checkpoints/AffordanceNet-checkpoint.jpg
ADDED
|
Git LFS Details
|
imgs/AffordanceNet.jpg
ADDED
|
Git LFS Details
|
imgs/AffordanceNet.png
ADDED
|
Git LFS Details
|
merge_lora_weights_and_save_hf_model.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import transformers
|
| 11 |
+
from peft import LoraConfig, get_peft_model
|
| 12 |
+
from transformers import AutoTokenizer
|
| 13 |
+
|
| 14 |
+
from model.AffordanceVLM import AffordanceVLMForCausalLM
|
| 15 |
+
from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_args(args):
    """Build and parse the CLI options for merging LoRA weights into a
    Hugging Face checkpoint.

    Args:
        args: List of argument strings (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with all options; ``--weight`` and ``--save_path``
        are required.
    """
    ap = argparse.ArgumentParser(
        description="merge lora weights and save model with hf format"
    )
    # Base LLaVA checkpoint and required in/out paths.
    ap.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    ap.add_argument("--weight", default="", type=str, required=True)
    ap.add_argument("--save_path", default="./lisa_model", type=str, required=True)
    ap.add_argument("--vis_save_path", default="./vis_output", type=str)
    # Numeric precision used to instantiate the model.
    ap.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    # SAM / vision settings.
    ap.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    ap.add_argument("--out_dim", default=256, type=int)
    ap.add_argument("--image_size", default=1024, type=int, help="image size")
    ap.add_argument("--model_max_length", default=512, type=int)
    ap.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    # LoRA hyper-parameters (must match training to reload the weights).
    ap.add_argument("--lora_r", default=8, type=int)
    ap.add_argument("--lora_alpha", default=16, type=int)
    ap.add_argument("--lora_dropout", default=0.05, type=float)
    ap.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    ap.add_argument("--local-rank", default=0, type=int, help="node rank")
    ap.add_argument("--train_mask_decoder", action="store_true", default=True)
    ap.add_argument("--use_mm_start_end", action="store_true", default=True)
    ap.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return ap.parse_args(args)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main(args):
    """Rebuild the training-time model, load the trained LoRA checkpoint,
    merge the LoRA weights into the base weights, and save the result (plus
    the tokenizer) in Hugging Face format at ``--save_path``.

    The construction order below must mirror training exactly (tokens added,
    then model built, then LoRA wrap, then embedding resize) so that
    ``load_state_dict(..., strict=True)`` finds every key.
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the special segmentation/affordance tokens and record their ids;
    # the model routes hidden states at these positions into the SAM decoder.
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    num_added_tokens = tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    # Extra kwargs consumed by AffordanceVLMForCausalLM.__init__.
    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "seg_token_idx": args.seg_token_idx,
        "aff_token_idx": args.aff_token_idx,
        "vision_tower": args.vision_tower,
    }

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    # Build the CLIP vision tower and the SAM/projection modules so the
    # checkpoint keys for those modules exist before loading.
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            # Collect names of nn.Linear modules that match a target pattern,
            # excluding the vision stack and projection head (those are not
            # LoRA-adapted during training).
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))

        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Resize embeddings to account for the tokens added above.
    model.resize_token_embeddings(len(tokenizer))

    state_dict = torch.load(args.weight, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    # Fold LoRA deltas into the base weights and drop the (frozen, re-creatable)
    # vision tower from the saved state dict to keep the checkpoint small.
    model = model.merge_and_unload()
    state_dict = {}
    for k, v in model.state_dict().items():
        if "vision_tower" not in k:
            state_dict[k] = v
    model.save_pretrained(args.save_path, state_dict=state_dict)
    tokenizer.save_pretrained(args.save_path)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Script entry point: forward CLI arguments (minus the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
|
model/AffordanceVLM.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from transformers import BitsAndBytesConfig, CLIPVisionModel
|
| 7 |
+
|
| 8 |
+
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
|
| 9 |
+
DEFAULT_IMAGE_PATCH_TOKEN)
|
| 10 |
+
|
| 11 |
+
from .llava.model.language_model.llava_llama import (LlavaLlamaForCausalLM,
|
| 12 |
+
LlavaLlamaModel)
|
| 13 |
+
from .segment_anything import build_sam_vit_h
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """DICE loss (a soft IoU surrogate) averaged over `num_masks` masks.

    Args:
        inputs: Raw logits, one prediction per mask element.
        targets: Binary ground-truth tensor of the same shape
            (0 = negative class, 1 = positive class).
        num_masks: Normalizer for the summed per-mask losses.
        scale: Divisor applied to both tensors for numerical stability.
        eps: Small constant guarding against a zero denominator.
    """
    probs = inputs.sigmoid().flatten(1, 2)
    gt = targets.flatten(1, 2)
    intersection = (probs / scale * gt).sum(-1)
    union = (probs / scale).sum(-1) + (gt / scale).sum(-1)
    per_mask = 1 - (2 * intersection + eps) / (union + eps)
    return per_mask.sum() / (num_masks + 1e-8)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """Per-mask-averaged binary cross-entropy on logits.

    Args:
        inputs: Raw logits, one prediction per mask element.
        targets: Binary ground-truth tensor of the same shape
            (0 = negative class, 1 = positive class).
        num_masks: Normalizer for the summed per-mask losses.

    Returns:
        Scalar loss tensor.
    """
    elementwise = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    per_mask = elementwise.flatten(1, 2).mean(1)
    return per_mask.sum() / (num_masks + 1e-8)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class LisaMetaModel:
    """Mixin that attaches the SAM backbone and the text-to-prompt projection
    head onto a LLaVA language model (LISA-style architecture)."""

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(LisaMetaModel, self).__init__(config)

        self.config = config
        if not hasattr(self.config, "train_mask_decoder"):
            # Fresh config (building from a base LLaVA checkpoint): copy the
            # LISA-specific options from kwargs onto the config. SAM modules
            # are created later via an explicit initialize_lisa_modules call.
            self.config.train_mask_decoder = kwargs["train_mask_decoder"]
            self.config.out_dim = kwargs["out_dim"]
            self.vision_pretrained = kwargs.get("vision_pretrained", None)
        else:
            # Config already carries the LISA fields (loading a merged
            # checkpoint): build the SAM modules immediately.
            self.vision_pretrained = kwargs.get("vision_pretrained", None)
            self.initialize_lisa_modules(self.config)

    def initialize_lisa_modules(self, config):
        # SAM: ViT-H backbone, frozen; only the mask decoder is optionally
        # made trainable (controlled by config.train_mask_decoder).
        self.visual_model = build_sam_vit_h(self.vision_pretrained)
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if config.train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True

        # Projection layer: maps LLM hidden states (at [SEG]/[AFF] positions)
        # from hidden_size down to SAM's prompt-embedding dimension (out_dim).
        in_dim = config.hidden_size
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        # Wrapped in a ModuleList; callers assume exactly one projection head.
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class LisaModel(LisaMetaModel, LlavaLlamaModel):
    """LLaVA-Llama backbone combined with the LISA SAM/projection mixin."""

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(LisaModel, self).__init__(config, **kwargs)

        # Pin the multimodal configuration: frozen CLIP tower whose "patch"
        # features are consumed directly, no KV cache, no adapter tuning.
        self.config.use_cache = False
        self.config.vision_tower = self.config.mm_vision_tower
        self.config.mm_vision_select_feature = "patch"
        self.config.image_aspect_ratio = "square"
        self.config.image_grid_pinpoints = None
        self.config.tune_mm_mlp_adapter = False
        self.config.freeze_mm_mlp_adapter = True
        self.config.pretrain_mm_mlp_adapter = None
        self.config.mm_use_im_patch_token = False
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class AffordanceVLMForCausalLM(LlavaLlamaForCausalLM):
    """LLaVA-based causal LM extended with SAM mask decoding.

    Hidden states at the special [SEG]/[AFF] token positions are projected
    into SAM's prompt space and decoded into affordance segmentation masks.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        if not hasattr(config, "train_mask_decoder"):
            # Training path (base LLaVA config): pull training-time options
            # out of kwargs before the config is finalized.
            config.mm_use_im_start_end = kwargs.pop("use_mm_start_end", True)
            config.mm_vision_tower = kwargs.get(
                "vision_tower", "openai/clip-vit-large-patch14"
            )
            self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
            self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
            self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        else:
            # Inference path (merged checkpoint): config already has everything.
            config.mm_vision_tower = config.vision_tower

        # Tokenizer ids of the special [SEG] and [AFF] tokens (required kwargs).
        self.seg_token_idx = kwargs.pop("seg_token_idx")
        self.aff_token_idx = kwargs.pop("aff_token_idx")

        super().__init__(config)

        self.model = LisaModel(config, **kwargs)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        """Encode SAM-preprocessed images with the frozen SAM image encoder,
        one image at a time to bound peak GPU memory; return them concatenated
        along the batch dimension."""
        with torch.no_grad():
            image_embeddings_list = []
            for i in range(pixel_values.shape[0]):
                torch.cuda.empty_cache()
                image_embeddings = self.model.visual_model.image_encoder(
                    pixel_values[i].unsqueeze(0)
                )
                image_embeddings_list.append(image_embeddings)
            torch.cuda.empty_cache()
        image_embeddings = torch.cat(image_embeddings_list, 0)
        return image_embeddings

    def forward(self, **kwargs):
        # During autoregressive generation (past_key_values supplied) defer to
        # the plain LLaVA forward; otherwise run the segmentation-aware path.
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)

    def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        resize_list: List[tuple],
        inference: bool = False,
        **kwargs,
    ):
        """Joint language + segmentation forward pass.

        `images` are SAM inputs, `images_clip` are CLIP inputs; `offset`
        delimits, per image, which rows of `input_ids` belong to it (multiple
        conversations may share one image). Returns predicted masks (and
        losses when not in inference mode).
        """
        image_embeddings = self.get_visual_embs(images)
        batch_size = image_embeddings.shape[0]
        assert batch_size == len(offset) - 1

        # Mark positions of [SEG]/[AFF] tokens, shifted left by one because
        # the hidden state at position t is used to predict token t+1.
        seg_token_mask = (input_ids[:, 1:] == self.seg_token_idx) + (input_ids[:, 1:] == self.aff_token_idx)
        # Pad one trailing False column to restore the original sequence length.
        seg_token_mask = torch.cat(
            [
                seg_token_mask,
                torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda(),
            ],
            dim=1,
        )
        # hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
        seg_token_mask = torch.cat(
            [torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(), seg_token_mask],
            dim=1,
        )

        if inference:
            # Single-image inference: broadcast the one CLIP image across all
            # conversation rows and run the LM once.
            n_batch = 1
            length = input_ids.shape[0]
            assert images_clip.shape[0] == 1
            images_clip_extend = images_clip.expand(length, -1, -1, -1).contiguous()

            output_hidden_states = []
            for i in range(n_batch):
                start_i, end_i = i * length, min((i + 1) * length, input_ids.shape[0])
                output_i = super().forward(
                    images=images_clip_extend[: end_i - start_i],
                    attention_mask=attention_masks[start_i:end_i],
                    input_ids=input_ids[start_i:end_i],
                    output_hidden_states=True,
                )
                output_hidden_states.append(output_i.hidden_states)
                torch.cuda.empty_cache()

            output_hidden_states_list = []
            output_hidden_states_level = torch.cat(output_hidden_states, dim=0)
            output_hidden_states_list.append(output_hidden_states_level)
            output_hidden_states = output_hidden_states_list
            output = None

        else:
            # Training: replicate each CLIP image once per conversation row
            # that references it (per the offset boundaries).
            images_clip_list = []
            for i in range(len(offset) - 1):
                start_i, end_i = offset[i], offset[i + 1]
                images_clip_i = (
                    images_clip[i]
                    .unsqueeze(0)
                    .expand(end_i - start_i, -1, -1, -1)
                    .contiguous()
                )
                images_clip_list.append(images_clip_i)
            images_clip = torch.cat(images_clip_list, dim=0)

            output = super().forward(
                images=images_clip,
                attention_mask=attention_masks,
                input_ids=input_ids,
                labels=labels,
                output_hidden_states=True,
            )
            output_hidden_states = output.hidden_states

        hidden_states = []

        # Project last-layer hidden states into SAM's prompt-embedding space.
        assert len(self.model.text_hidden_fcs) == 1
        hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1]))

        last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
        # Gather the embeddings at every [SEG]/[AFF] position, then regroup
        # them per input image using cumulative counts indexed by `offset`.
        pred_embeddings = last_hidden_state[seg_token_mask]
        seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]

        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat(
            [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
        )

        seg_token_offset = seg_token_offset[offset]

        pred_embeddings_ = []
        for i in range(len(seg_token_offset) - 1):
            start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
            pred_embeddings_.append(pred_embeddings[start_i:end_i])
        pred_embeddings = pred_embeddings_

        # Decode one mask per projected embedding, per image.
        multimask_output = False
        pred_masks = []
        for i in range(len(pred_embeddings)):
            (
                sparse_embeddings,
                dense_embeddings,
            ) = self.model.visual_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embeddings[i].unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
            low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            # Resize low-res mask logits back to the original image resolution.
            pred_mask = self.model.visual_model.postprocess_masks(
                low_res_masks,
                input_size=resize_list[i],
                original_size=label_list[i].shape,
            )
            pred_masks.append(pred_mask[:, 0])

        model_output = output
        gt_masks = masks_list

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        output = model_output.logits

        # Total loss = weighted LM cross-entropy + weighted mask losses
        # (BCE + DICE), each mask loss normalized by the total mask count.
        ce_loss = model_output.loss
        ce_loss = ce_loss * self.ce_loss_weight
        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0
        for batch_idx in range(len(pred_masks)):
            gt_mask = gt_masks[batch_idx]
            pred_mask = pred_masks[batch_idx]

            assert (
                gt_mask.shape[0] == pred_mask.shape[0]
            ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
                gt_mask.shape, pred_mask.shape
            )
            mask_bce_loss += (
                sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            mask_dice_loss += (
                dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            num_masks += gt_mask.shape[0]

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        loss = ce_loss + mask_loss

        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": mask_bce_loss,
            "mask_dice_loss": mask_dice_loss,
            "mask_loss": mask_loss,
        }

    def evaluate(
        self,
        images_clip,
        images,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=32,
        tokenizer=None,
    ):
        """Generate a text response and decode masks for every [SEG]/[AFF]
        token that appears in the generated sequence.

        Returns:
            (output_ids, pred_masks): generated token ids and a list of
            per-input mask tensors at their original resolutions.
        """
        with torch.no_grad():
            outputs = self.generate(
                images=images_clip,
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            # Last element = hidden states of the final generated step.
            output_hidden_states = outputs.hidden_states[-1]
            output_ids = outputs.sequences

            seg_token_mask = (output_ids[:, 1:] == self.seg_token_idx) + (output_ids[:, 1:] == self.aff_token_idx)
            # hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
            seg_token_mask = torch.cat(
                [
                    torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(),
                    seg_token_mask,
                ],
                dim=1,
            )

            hidden_states = []

            assert len(self.model.text_hidden_fcs) == 1
            hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))

            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
            # Same gather-and-regroup scheme as model_forward, but without the
            # per-image `offset` indexing (one conversation per row here).
            pred_embeddings = last_hidden_state[seg_token_mask]

            seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]
            seg_token_offset = seg_token_counts.cumsum(-1)
            seg_token_offset = torch.cat(
                [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
            )

            pred_embeddings_ = []
            for i in range(len(seg_token_offset) - 1):
                start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
                pred_embeddings_.append(pred_embeddings[start_i:end_i])
            pred_embeddings = pred_embeddings_

            image_embeddings = self.get_visual_embs(images)

            multimask_output = False
            pred_masks = []
            for i in range(len(pred_embeddings)):
                (
                    sparse_embeddings,
                    dense_embeddings,
                ) = self.model.visual_model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i].unsqueeze(1),
                )

                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )
                pred_mask = self.model.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=original_size_list[i],
                )
                pred_masks.append(pred_mask[:, 0])

        return output_ids, pred_masks
|
model/__pycache__/AffordanceVLM.cpython-39.pyc
ADDED
|
Binary file (9.71 kB). View file
|
|
|
model/llava/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .model import LlavaLlamaForCausalLM
|
model/llava/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
model/llava/__pycache__/constants.cpython-39.pyc
ADDED
|
Binary file (454 Bytes). View file
|
|
|
model/llava/__pycache__/conversation.cpython-39.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
model/llava/__pycache__/mm_utils.cpython-39.pyc
ADDED
|
Binary file (3.4 kB). View file
|
|
|
model/llava/constants.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Heartbeat timing (seconds) for the controller/worker serving stack.
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."  # directory where log files are written

# Model Constants
IGNORE_INDEX = -100  # label value excluded from loss (PyTorch ignore_index convention)
IMAGE_TOKEN_INDEX = -200  # sentinel id marking where image features are spliced in
DEFAULT_IMAGE_TOKEN = "<image>"  # textual placeholder for an image in prompts
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"  # per-patch image placeholder token
DEFAULT_IM_START_TOKEN = "<im_start>"  # marks the start of the image region
DEFAULT_IM_END_TOKEN = "<im_end>"  # marks the end of the image region
|
model/llava/conversation.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from enum import Enum, auto
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SeparatorStyle(Enum):
    """Different separator style.

    Selects how ``Conversation.get_prompt`` joins the system prompt, roles
    and messages into one prompt string.
    """

    SINGLE = auto()   # one separator (e.g. "###") appended after every turn
    TWO = auto()      # alternating sep / sep2 after user / assistant turns
    MPT = auto()      # MPT chat markup (<|im_start|>role ... <|im_end|>)
    PLAIN = auto()    # no role prefixes; messages joined with sep / sep2
    LLAMA_2 = auto()  # Llama-2 [INST] ... [/INST] wrapping with <<SYS>> block
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Holds a system prompt and an ordered list of ``[role, message]`` pairs
    and renders them into a single prompt string according to ``sep_style``.
    A message may also be a tuple ``(text, image, image_process_mode)`` when
    the turn carries an image.
    """

    system: str  # system prompt prepended to every rendered prompt
    roles: List[str]  # (user role name, assistant role name)
    messages: List[List[str]]  # [role, message] pairs; message may be an image tuple
    offset: int  # number of leading few-shot messages skipped by get_images/to_gradio_chatbot
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"  # primary separator
    sep2: str = None  # secondary separator (TWO / PLAIN / LLAMA_2 styles); may be None
    version: str = "Unknown"  # template version tag; "mmtag" variants wrap images in <Image> tags

    skip_next: bool = False  # UI flag: skip generating the next response

    def get_prompt(self):
        """Render the conversation into a single prompt string.

        If the first message carries an image tuple, its "<image>" token is
        re-positioned (or wrapped in <Image></Image> for "mmtag" versions)
        before rendering.
        """
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # First turn contains an image: normalize its "<image>" placement
            # on a shallow copy so the stored history is untouched.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if "mmtag" in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            # "system###role: message###..."; an empty message leaves a
            # trailing "role:" for the model to complete.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternate sep (after user turns) and sep2 (after assistant turns).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            # Roles already include the "<|im_start|>role\n" markup.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        # The system prompt is folded into the first user turn.
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        # Even indices are user turns, wrapped in [INST].
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            # No role prefixes; bare messages joined with alternating seps.
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect images attached to user turns (after ``offset``).

        Each image is padded / cropped / resized per its process mode, then
        downscaled (longest edge <= 800, shortest <= 400). Returns PIL images
        when ``return_pil`` is True, otherwise base64-encoded PNG strings.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:  # even indices after offset are user turns
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    from PIL import Image

                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":

                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Pad the shorter side so the image becomes square.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(
                                    pil_img.mode, (width, width), background_color
                                )
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(
                                    pil_img.mode, (height, height), background_color
                                )
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        pass  # leave the image as-is
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(
                            f"Invalid image_process_mode: {image_process_mode}"
                        )
                    # Downscale while preserving aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Convert the history into Gradio chatbot pairs.

        User images are inlined as base64 <img> tags; assistant replies fill
        the second slot of the most recent pair.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:  # user turn
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    # Downscale while preserving aspect ratio (same limits as
                    # get_images).
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    # NOTE(review): the bytes are JPEG-encoded but the data
                    # URI below declares image/png — browsers sniff the real
                    # type, so this renders, but the labels disagree.
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    ret.append([img_str, None])
                    msg = msg.replace("<image>", "").strip()
                    if len(msg) > 0:
                        ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # Assistant turn: fill the reply slot of the latest pair.
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a deep-enough copy (messages list is duplicated)."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

    def dict(self):
        """Serialize to a plain dict; image tuples are reduced to their text."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [
                    [x, y[0] if type(y) is tuple else y] for x, y in self.messages
                ],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# --- Pre-built conversation templates -------------------------------------
# Each template below fixes a system prompt, role names and separator style;
# callers .copy() a template and append their own messages.

# Vicuna v0: single-separator template with a hard-coded few-shot exchange
# (offset=2 so get_images/to_gradio_chatbot skip the example turns).
conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Vicuna v1: two-separator template ("</s>" closes assistant turns).
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Plain Llama-2 chat template with the standard safety system prompt.
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# Llama-2 template with a vision-assistant system prompt for LLaVA.
conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# MPT chat-markup template; role strings carry the <|im_start|> markers.
conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Bare template used for pretraining-style plain captioning data.
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

# LLaVA v0: single-separator template with a trivial greeting few-shot.
conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(("Human", "Hi!"), ("Assistant", "Hi there! How can I help you today?")),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# LLaVA v0 "mmtag" variant: images are wrapped in <Image></Image> tags
# (see Conversation.get_prompt's "mmtag" branch).
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

# LLaVA v1: two-separator template.
conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# LLaVA v1 "mmtag" variant.
conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

# Template used when no explicit conversation version is requested.
default_conversation = conv_vicuna_v0
# Registry mapping conversation-version names to their templates.
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "mpt": conv_mpt,
}


if __name__ == "__main__":
    # Quick visual smoke test of the default template's rendering.
    print(default_conversation.get_prompt())
|
model/llava/mm_utils.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from transformers import StoppingCriteria
|
| 7 |
+
|
| 8 |
+
from .constants import IMAGE_TOKEN_INDEX
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_image_from_base64(image):
    """Decode a base64-encoded image string into a PIL image."""
    raw_bytes = base64.b64decode(image)
    return Image.open(BytesIO(raw_bytes))
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def process_images(images, image_processor, model_cfg):
    """Run the image processor on *images* and return the pixel tensor.

    ``model_cfg`` is accepted for API compatibility with upstream LLaVA but
    is not consulted here.
    """
    processed = image_processor(images, return_tensors="pt")
    return processed["pixel_values"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def tokenizer_image_token(
    prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
):
    """Tokenize a prompt containing "<image>" placeholders.

    Each "<image>" occurrence becomes a single ``image_token_index`` id in
    the output; surrounding text is tokenized normally, and a single leading
    BOS token is kept if the tokenizer emits one.

    Returns a list of ids, or a 1-D long tensor when ``return_tensors="pt"``.
    """
    chunk_ids = [tokenizer(piece).input_ids for piece in prompt.split("<image>")]

    # Keep the BOS token exactly once (the tokenizer prepends it to every
    # chunk); every chunk then has its first `skip` ids sliced off below.
    skip = 0
    token_ids = []
    if chunk_ids and chunk_ids[0] and chunk_ids[0][0] == tokenizer.bos_token_id:
        skip = 1
        token_ids.append(chunk_ids[0][0])

    # The separator is `skip + 1` sentinels long so that slicing off the
    # first `skip` entries leaves exactly one sentinel per "<image>".
    separator = [image_token_index] * (skip + 1)
    interleaved = []
    for idx, chunk in enumerate(chunk_ids):
        if idx:
            interleaved.append(separator)
        interleaved.append(chunk)
    for piece in interleaved:
        token_ids.extend(piece[skip:])

    if return_tensors is None:
        return token_ids
    if return_tensors == "pt":
        return torch.tensor(token_ids, dtype=torch.long)
    raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_model_name_from_path(model_path):
    """Derive a human-readable model name from a filesystem path.

    For checkpoint directories ("checkpoint-N") the parent folder name is
    prefixed so different runs stay distinguishable.
    """
    segments = model_path.strip("/").split("/")
    last = segments[-1]
    if last.startswith("checkpoint-"):
        return segments[-2] + "_" + last
    return last
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword appears in the generated suffix.

    Two checks run per generation step:
      1. an exact token-id suffix match against each pre-tokenized keyword;
      2. a substring match on the decoded tail of the output (catches
         keywords whose tokenization depends on surrounding context).
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: list of strings whose appearance should end generation.
            tokenizer: tokenizer used to pre-tokenize the keywords and to
                decode the generated tail.
            input_ids: prompt ids of shape (1, prompt_len); only the length
                is used, to locate where generated tokens start.
        """
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so the suffix comparison only sees the
            # keyword's own ids.
            if (
                len(cur_keyword_ids) > 1
                and cur_keyword_ids[0] == tokenizer.bos_token_id
            ):
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        # Decode at most the last 3 generated tokens for the substring check.
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [
            keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
        ]
        for keyword_id in self.keyword_ids:
            n_tokens = keyword_id.shape[0]
            # BUGFIX: the previous `if output_ids[0, -n:] == keyword_id:`
            # evaluated the truthiness of a multi-element tensor, raising
            # "Boolean value of Tensor with more than one element is
            # ambiguous" for multi-token keywords. torch.equal compares the
            # whole suffix; the length guard avoids mis-broadcast when the
            # output is still shorter than the keyword.
            if output_ids.shape[1] >= n_tokens and torch.equal(
                output_ids[0, -n_tokens:], keyword_id
            ):
                return True
        outputs = self.tokenizer.batch_decode(
            output_ids[:, -offset:], skip_special_tokens=True
        )[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
|