wangzeze committed on
Commit
0453c63
·
verified ·
1 Parent(s): 296ec95

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +7 -0
  2. .ipynb_checkpoints/batch_generate-checkpoint.py +401 -0
  3. .ipynb_checkpoints/batch_generate-checkpoint.sh +14 -0
  4. .ipynb_checkpoints/batch_generate_prefill_accelerate-checkpoint.py +418 -0
  5. .ipynb_checkpoints/chat-checkpoint.py +255 -0
  6. .ipynb_checkpoints/chat_prefill-checkpoint.py +282 -0
  7. .ipynb_checkpoints/train_aff-checkpoint.py +620 -0
  8. README.md +79 -3
  9. app.py +329 -0
  10. batch_generate.sh +14 -0
  11. batch_generate_prefill_accelerate.py +418 -0
  12. chat.py +255 -0
  13. chat_prefill.py +282 -0
  14. ckpts/AffordanceVLM-7B/.gitattributes +35 -0
  15. ckpts/AffordanceVLM-7B/README.md +3 -0
  16. ckpts/AffordanceVLM-7B/added_tokens.json +7 -0
  17. ckpts/AffordanceVLM-7B/config.json +42 -0
  18. ckpts/AffordanceVLM-7B/eval_result.txt +1 -0
  19. ckpts/AffordanceVLM-7B/generation_config.json +7 -0
  20. ckpts/AffordanceVLM-7B/pytorch_model-00001-of-00002.bin +3 -0
  21. ckpts/AffordanceVLM-7B/pytorch_model-00002-of-00002.bin +3 -0
  22. ckpts/AffordanceVLM-7B/pytorch_model.bin.index.json +930 -0
  23. ckpts/AffordanceVLM-7B/special_tokens_map.json +24 -0
  24. ckpts/AffordanceVLM-7B/tokenizer.model +3 -0
  25. ckpts/AffordanceVLM-7B/tokenizer_config.json +35 -0
  26. ckpts/sam_vit_h_4b8939.pth +3 -0
  27. client.py +67 -0
  28. data_curation/.ipynb_checkpoints/check_dataset-checkpoint.py +100 -0
  29. data_curation/build_vlpart.py +105 -0
  30. data_curation/check_dataset.py +100 -0
  31. data_curation/prompt_generation_handal_easy_reasoning.py +126 -0
  32. data_curation/prompt_generation_handal_hard_reasoning.py +136 -0
  33. data_curation/vlpart_sam2_tracking.py +187 -0
  34. docs/dataset.md +93 -0
  35. docs/installation.md +10 -0
  36. docs/training_and_evaluation.md +56 -0
  37. imgs/.ipynb_checkpoints/AffordanceNet-checkpoint.jpg +3 -0
  38. imgs/AffordanceNet.jpg +3 -0
  39. imgs/AffordanceNet.png +3 -0
  40. merge_lora_weights_and_save_hf_model.py +162 -0
  41. model/AffordanceVLM.py +428 -0
  42. model/__pycache__/AffordanceVLM.cpython-39.pyc +0 -0
  43. model/llava/__init__.py +1 -0
  44. model/llava/__pycache__/__init__.cpython-39.pyc +0 -0
  45. model/llava/__pycache__/constants.cpython-39.pyc +0 -0
  46. model/llava/__pycache__/conversation.cpython-39.pyc +0 -0
  47. model/llava/__pycache__/mm_utils.cpython-39.pyc +0 -0
  48. model/llava/constants.py +12 -0
  49. model/llava/conversation.py +399 -0
  50. model/llava/mm_utils.py +88 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ imgs/.ipynb_checkpoints/AffordanceNet-checkpoint.jpg filter=lfs diff=lfs merge=lfs -text
37
+ imgs/AffordanceNet.jpg filter=lfs diff=lfs merge=lfs -text
38
+ imgs/AffordanceNet.png filter=lfs diff=lfs merge=lfs -text
39
+ vis_output/.ipynb_checkpoints/my_workspace-checkpoint.JPG filter=lfs diff=lfs merge=lfs -text
40
+ vis_output/.ipynb_checkpoints/my_workspace_masked_img_0-checkpoint.jpg filter=lfs diff=lfs merge=lfs -text
41
+ vis_output/my_workspace.JPG filter=lfs diff=lfs merge=lfs -text
42
+ vis_output/my_workspace_masked_img_0.jpg filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/batch_generate-checkpoint.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch affordance mask generation for per-step datasets.
3
+
4
+ Reads a per-step dataset (converted by convert_lerobot_to_perstep.py) and
5
+ generates affordance masks for every image_primary.jpg and image_wrist.jpg
6
+ using AffordanceVLM.
7
+
8
+ Input structure:
9
+ {data_dir}/
10
+ ├── meta_info.h5
11
+ └── episodes/
12
+ └── {episode_id:06d}/
13
+ └── steps/
14
+ └── {step_id:04d}/
15
+ ├── other.h5 # language_instruction
16
+ ├── image_primary.jpg
17
+ └── image_wrist.jpg
18
+
19
+ Output structure:
20
+ {save_dir}/
21
+ └── episode_{episode_id}/
22
+ └── steps/
23
+ └── step_{step_id}/
24
+ ├── image_primary_mask.png # binary 0/255
25
+ └── image_wrist_mask.png
26
+
27
+ Usage:
28
+ python batch_generate.py \
29
+ --data_dir /path/to/perstep_dataset \
30
+ --save_dir /path/to/mask_output \
31
+ --start_episode 0 --end_episode 10
32
+ """
33
+
34
+ import argparse
35
+ import os
36
+ import sys
37
+ from pathlib import Path
38
+
39
+ import cv2
40
+ import h5py
41
+ import numpy as np
42
+ import torch
43
+ import torch.nn.functional as F
44
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
45
+
46
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
47
+ from model.llava import conversation as conversation_lib
48
+ from model.llava.mm_utils import tokenizer_image_token
49
+ from model.segment_anything.utils.transforms import ResizeLongestSide
50
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
51
+ DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
52
+
53
+
54
def parse_args(args):
    """Parse command-line options for batch mask generation.

    args: argv list without the program name (e.g. ``sys.argv[1:]``).
    Returns an argparse.Namespace.  Note: load_model later attaches
    ``seg_token_idx`` / ``aff_token_idx`` onto the same namespace.
    """
    p = argparse.ArgumentParser(
        description="Batch affordance mask generation for per-step datasets"
    )

    # --- model options (mirrors chat.py) ---
    p.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    p.add_argument(
        "--precision", default="bf16", type=str,
        choices=["fp32", "bf16", "fp16"],
    )
    p.add_argument("--image_size", default=1024, type=int)
    p.add_argument("--model_max_length", default=512, type=int)
    p.add_argument("--lora_r", default=8, type=int)
    p.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    p.add_argument("--local-rank", default=0, type=int)
    p.add_argument("--load_in_8bit", action="store_true", default=False)
    p.add_argument("--load_in_4bit", action="store_true", default=False)
    p.add_argument("--use_mm_start_end", action="store_true", default=True)
    p.add_argument(
        "--conv_type", default="llava_v1", type=str,
        choices=["llava_v1", "llava_llama_2"],
    )

    # --- batch-processing options ---
    p.add_argument("--data_dir", type=str, required=True,
                   help="Root of per-step dataset (contains episodes/)")
    p.add_argument("--save_dir", type=str, required=True,
                   help="Output directory for masks")
    p.add_argument("--prompt_template", type=str,
                   default="{}",
                   help="Template wrapping language_instruction. Use {} as placeholder.")
    p.add_argument("--start_episode", type=int, default=None,
                   help="First episode index to process (inclusive)")
    p.add_argument("--end_episode", type=int, default=None,
                   help="Last episode index to process (exclusive)")
    return p.parse_args(args)
90
+
91
+
92
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize an image tensor with SAM's per-channel statistics, then
    zero-pad the bottom and right edges up to an img_size x img_size square.
    """
    normalized = (x - pixel_mean) / pixel_std
    height, width = normalized.shape[-2:]
    # F.pad order is (left, right, top, bottom) on the last two dims.
    return F.pad(normalized, (0, img_size - width, 0, img_size - height))
105
+
106
+
107
def load_model(args):
    """Load tokenizer and AffordanceVLM model for inference (same recipe as chat.py).

    Returns (model, tokenizer, clip_image_processor, transform), where
    transform is the SAM ResizeLongestSide used on raw images.

    Side effect: stores the token ids of the added "[SEG]" / "[AFF]" tokens
    on ``args`` (seg_token_idx / aff_token_idx) for the model constructor.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the special segmentation/affordance tokens and remember their ids.
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    # Map the --precision flag to a torch dtype for weight loading.
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # 4-bit NF4 quantization; the SAM visual model is kept unquantized.
        kwargs.update({
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        })
    elif args.load_in_8bit:
        kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **kwargs,
    )

    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    # Build the CLIP vision tower and cast it to the chosen dtype before
    # moving the whole model, so its weights match the rest of the network.
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
        # DeepSpeed kernel injection cannot handle the vision tower, so it is
        # detached before init_inference and re-attached (half, on GPU) after.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed
        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    # NOTE(review): args.local_rank is used as the CUDA device index here —
    # assumes one process per GPU; confirm for multi-GPU launches.
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()
    return model, tokenizer, clip_image_processor, transform
192
+
193
+
194
def build_prompt(text: str, args) -> str:
    """Assemble the full conversation prompt for one text query.

    Wraps the query in the conversation template chosen by args.conv_type,
    prefixes the image placeholder token, and leaves the assistant turn
    empty so the model generates the answer.
    """
    conversation = conversation_lib.conv_templates[args.conv_type].copy()
    conversation.messages = []

    user_turn = f"{DEFAULT_IMAGE_TOKEN}\nYou are an embodied robot. {text}"
    if args.use_mm_start_end:
        # Surround the image placeholder with explicit start/end markers.
        wrapped = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        user_turn = user_turn.replace(DEFAULT_IMAGE_TOKEN, wrapped)

    conversation.append_message(conversation.roles[0], user_turn)
    conversation.append_message(conversation.roles[1], "")
    return conversation.get_prompt()
209
+
210
+
211
def infer_single_image(
    image_path: str,
    prompt_str: str,
    model,
    tokenizer,
    clip_image_processor,
    transform,
    args,
) -> "np.ndarray | None":
    """Run affordance inference on one image.

    Returns a binary uint8 mask (H, W) with values 0/255, or None when the
    image cannot be read or the model produces no mask.
    """
    image_np = cv2.imread(image_path)
    if image_np is None:
        print(f" [WARNING] Cannot read image: {image_path}")
        return None
    # OpenCV loads BGR; the model pipelines expect RGB.
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # CLIP preprocessing (input to the vision-language branch).
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image_clip = image_clip.bfloat16()
    elif args.precision == "fp16":
        image_clip = image_clip.half()
    else:
        image_clip = image_clip.float()

    # SAM preprocessing: resize longest side, normalize, pad to square.
    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]
    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image = image.bfloat16()
    elif args.precision == "fp16":
        image = image.half()
    else:
        image = image.float()

    # Tokenize the conversation prompt (image token handled specially).
    input_ids = tokenizer_image_token(prompt_str, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()

    # Autoregressive generation; model.evaluate is expected to return the
    # generated ids plus per-image mask predictions (resized back to the
    # original resolution via resize_list/original_size_list).
    with torch.no_grad():
        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )

    # Merge all predicted masks via union (logical OR); logits > 0 means
    # foreground.
    h, w = original_size_list[0]
    merged = np.zeros((h, w), dtype=bool)
    has_mask = False
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue
        mask_np = pred_mask.detach().cpu().numpy()[0]  # (H, W)
        merged |= (mask_np > 0)
        has_mask = True

    if not has_mask:
        return None

    return (merged.astype(np.uint8) * 255)
287
+
288
+
289
def read_language_instruction(h5_path: str) -> str:
    """Return the language_instruction string stored in a step's other.h5."""
    with h5py.File(h5_path, "r") as h5_file:
        raw = h5_file["language_instruction"][()]
    # h5py may hand back bytes; normalize to str either way.
    return raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
296
+
297
+
298
def main(args):
    """Generate affordance masks for every step image in a per-step dataset.

    args: raw argv list (e.g. ``sys.argv[1:]``); parsed by parse_args.
    Walks {data_dir}/episodes/{episode}/steps/{step}/, runs the model on
    image_primary.jpg / image_wrist.jpg, and writes binary PNG masks under
    {save_dir} with the same directory layout.
    """
    args = parse_args(args)
    data_dir = Path(args.data_dir)
    save_dir = Path(args.save_dir)

    episodes_dir = data_dir / "episodes"
    if not episodes_dir.is_dir():
        print(f"Error: episodes directory not found at {episodes_dir}")
        sys.exit(1)

    # Episode names are zero-padded, so lexicographic sort == numeric sort.
    episode_dirs = sorted(
        [d for d in episodes_dir.iterdir() if d.is_dir()],
        key=lambda p: p.name,
    )

    # Optionally restrict to the half-open range [start_episode, end_episode).
    if args.start_episode is not None or args.end_episode is not None:
        start = args.start_episode if args.start_episode is not None else 0
        end = args.end_episode if args.end_episode is not None else len(episode_dirs)
        # Fix: skip non-numeric directory names instead of crashing in int().
        episode_dirs = [
            d for d in episode_dirs
            if d.name.isdigit() and start <= int(d.name) < end
        ]

    print(f"Data dir : {data_dir}")
    print(f"Save dir : {save_dir}")
    print(f"Episodes : {len(episode_dirs)}")
    print(f"Prompt : {args.prompt_template}")
    print()

    # Load model
    print("Loading model...")
    model, tokenizer, clip_image_processor, transform = load_model(args)
    print("Model loaded.\n")

    total_steps = 0
    empty_mask_count = 0

    for ep_dir in episode_dirs:
        episode_id = ep_dir.name  # e.g. "000000"
        steps_dir = ep_dir / "steps"
        if not steps_dir.is_dir():
            print(f" [WARNING] No steps/ in {ep_dir}, skipping.")
            continue

        step_dirs = sorted(
            [d for d in steps_dir.iterdir() if d.is_dir()],
            key=lambda p: p.name,
        )

        for step_dir in step_dirs:
            step_id = step_dir.name  # e.g. "0000"

            # Read language instruction for this step.
            other_h5 = step_dir / "other.h5"
            if not other_h5.exists():
                print(f" [WARNING] Missing other.h5 in {step_dir}, skipping.")
                continue
            language_instruction = read_language_instruction(str(other_h5))

            # Build the conversation prompt once per step (shared by cameras).
            query_text = args.prompt_template.format(language_instruction)
            prompt_str = build_prompt(query_text, args)

            # Mirror the input layout: episodes/{episode_id}/steps/{step_id}/
            out_dir = save_dir / "episodes" / episode_id / "steps" / step_id
            out_dir.mkdir(parents=True, exist_ok=True)

            # Process both cameras.
            for cam_name in ("image_primary", "image_wrist"):
                img_path = step_dir / f"{cam_name}.jpg"
                mask_path = out_dir / f"{cam_name}_mask.png"

                if not img_path.exists():
                    print(f" [WARNING] Missing {img_path}, skipping.")
                    continue

                mask = infer_single_image(
                    str(img_path), prompt_str,
                    model, tokenizer, clip_image_processor, transform, args,
                )

                if mask is None:
                    # No mask predicted: fall back to an all-black mask.
                    # Fix: cv2.imread returns None for unreadable files (the
                    # same condition that makes infer_single_image return
                    # None), so guard it instead of crashing on `.shape`.
                    img = cv2.imread(str(img_path))
                    if img is None:
                        print(f" [WARNING] Unreadable image {img_path}, skipping.")
                        continue
                    mask = np.zeros(img.shape[:2], dtype=np.uint8)
                    empty_mask_count += 1

                cv2.imwrite(str(mask_path), mask)

                total_steps += 1
                if total_steps % 50 == 0:
                    print(f" Processed {total_steps} steps (episode {episode_id}, step {step_id})")

        print(f"Episode {episode_id} done ({len(step_dirs)} steps)")

    print(f"\nFinished. {total_steps} steps processed, {empty_mask_count} empty masks.")


if __name__ == "__main__":
    main(sys.argv[1:])
.ipynb_checkpoints/batch_generate-checkpoint.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Run batch_generate.py once per LIBERO subset, one subset after another.

SRC_BASE="/gemini/space/wrz/libero_per_frame"
DST_BASE="/gemini/space/wrz/ragnet_results"

for subset in libero_object libero_goal libero_spatial libero_10; do
    echo "========== Processing ${subset} =========="
    CUDA_VISIBLE_DEVICES=0 python batch_generate.py \
        --data_dir "${SRC_BASE}/${subset}_converted" \
        --save_dir "${DST_BASE}/${subset}"
    echo "========== ${subset} done =========="
    echo
done
.ipynb_checkpoints/batch_generate_prefill_accelerate-checkpoint.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch affordance mask generation for per-step datasets.
3
+
4
+ Reads a per-step dataset (converted by convert_lerobot_to_perstep.py) and
5
+ generates affordance masks for every image_primary.jpg and image_wrist.jpg
6
+ using AffordanceVLM.
7
+
8
+ Input structure:
9
+ {data_dir}/
10
+ ├── meta_info.h5
11
+ └── episodes/
12
+ └── {episode_id:06d}/
13
+ └── steps/
14
+ └── {step_id:04d}/
15
+ ├── other.h5 # language_instruction
16
+ ├── image_primary.jpg
17
+ └── image_wrist.jpg
18
+
19
+ Output structure:
20
+ {save_dir}/
21
+ └── episodes/
22
+ └── {episode_id:06d}/
23
+ └── steps/
24
+ └── {step_id:04d}/
25
+ ├── image_primary_mask.png # binary 0/255
26
+ └── image_wrist_mask.png
27
+
28
+ Usage:
29
+ CUDA_VISIBLE_DEVICES=1 python batch_generate_prefill_accelerate.py \
30
+ --data_dir /gemini/space/wrz/libero_per_frame/libero_spatial_converted \
31
+ --save_dir /gemini/space/wrz/ragnet_results/libero_spatial
32
+ """
33
+
34
+ import argparse
35
+ import os
36
+ import sys
37
+ from pathlib import Path
38
+
39
+ import cv2
40
+ import h5py
41
+ import numpy as np
42
+ import torch
43
+ import torch.nn.functional as F
44
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
45
+
46
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
47
+ from model.llava import conversation as conversation_lib
48
+ from model.llava.mm_utils import tokenizer_image_token
49
+ from model.segment_anything.utils.transforms import ResizeLongestSide
50
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
51
+ DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
52
+
53
+
54
def parse_args(args):
    """Parse command-line options for prefill-style batch mask generation.

    args: argv list without the program name (e.g. ``sys.argv[1:]``).
    Returns an argparse.Namespace.  Note: load_model later attaches
    ``seg_token_idx`` / ``aff_token_idx`` onto the same namespace.
    """
    p = argparse.ArgumentParser(
        description="Batch affordance mask generation for per-step datasets"
    )

    # --- model options (mirrors chat.py) ---
    p.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    p.add_argument(
        "--precision", default="bf16", type=str,
        choices=["fp32", "bf16", "fp16"],
    )
    p.add_argument("--image_size", default=1024, type=int)
    p.add_argument("--model_max_length", default=512, type=int)
    p.add_argument("--lora_r", default=8, type=int)
    p.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    p.add_argument("--local-rank", default=0, type=int)
    p.add_argument("--load_in_8bit", action="store_true", default=False)
    p.add_argument("--load_in_4bit", action="store_true", default=False)
    p.add_argument("--use_mm_start_end", action="store_true", default=True)
    p.add_argument(
        "--conv_type", default="llava_v1", type=str,
        choices=["llava_v1", "llava_llama_2"],
    )

    # --- batch-processing options ---
    p.add_argument("--data_dir", type=str, required=True,
                   help="Root of per-step dataset (contains episodes/)")
    p.add_argument("--save_dir", type=str, required=True,
                   help="Output directory for masks")
    p.add_argument("--prompt_template", type=str,
                   default="{}",
                   help="Template wrapping language_instruction. Use {} as placeholder.")
    # Previously-tried templates, kept for reference:
    #   "{}"
    #   Segment the most suitable manipulation region on the single target object for the task '{}'.
    #   Segment the affordance map for the task '{}' in this image.
    #   Segment the affordance map of the single target object for the task '{}' in this image.
    #   Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask.
    #   Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask.
    p.add_argument("--start_episode", type=int, default=None,
                   help="First episode index to process (inclusive)")
    p.add_argument("--end_episode", type=int, default=None,
                   help="Last episode index to process (exclusive)")
    return p.parse_args(args)
96
+
97
+
98
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Apply SAM mean/std normalization, then zero-pad right and bottom
    edges to reach a square img_size canvas."""
    x = (x - pixel_mean) / pixel_std
    h, w = x.shape[-2:]
    pad_right, pad_bottom = img_size - w, img_size - h
    # Pad spec is (left, right, top, bottom) over the trailing two dims.
    return F.pad(x, (0, pad_right, 0, pad_bottom))
111
+
112
+
113
def load_model(args):
    """Load tokenizer and AffordanceVLM model for inference (same recipe as chat.py).

    Returns (model, tokenizer, clip_image_processor, transform), where
    transform is the SAM ResizeLongestSide used on raw images.

    Side effect: stores the token ids of the added "[SEG]" / "[AFF]" tokens
    on ``args`` (seg_token_idx / aff_token_idx) for the model constructor.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the special segmentation/affordance tokens and remember their ids.
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    # Map the --precision flag to a torch dtype for weight loading.
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # 4-bit NF4 quantization; the SAM visual model is kept unquantized.
        kwargs.update({
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        })
    elif args.load_in_8bit:
        kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **kwargs,
    )

    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    # Build the CLIP vision tower and cast it to the chosen dtype before
    # moving the whole model, so its weights match the rest of the network.
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
        # DeepSpeed kernel injection cannot handle the vision tower, so it is
        # detached before init_inference and re-attached (half, on GPU) after.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed
        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    # NOTE(review): args.local_rank is used as the CUDA device index here —
    # assumes one process per GPU; confirm for multi-GPU launches.
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()
    return model, tokenizer, clip_image_processor, transform
198
+
199
+
200
def build_prompt(text: str, args) -> str:
    """Assemble the full conversation prompt for one text query.

    Unlike the autoregressive variant, the assistant turn is pre-filled with
    "[AFF]." so the mask embedding can be extracted from a single forward
    pass (no generation needed).
    """
    conversation = conversation_lib.conv_templates[args.conv_type].copy()
    conversation.messages = []

    user_turn = f"{DEFAULT_IMAGE_TOKEN}\nYou are an embodied robot. {text}"
    if args.use_mm_start_end:
        # Surround the image placeholder with explicit start/end markers.
        wrapped = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        user_turn = user_turn.replace(DEFAULT_IMAGE_TOKEN, wrapped)

    conversation.append_message(conversation.roles[0], user_turn)
    conversation.append_message(conversation.roles[1], "[AFF].")
    return conversation.get_prompt()
215
+
216
+
217
def infer_single_image(
    image_path: str,
    prompt_str: str,
    model,
    tokenizer,
    clip_image_processor,
    transform,
    args,
) -> "np.ndarray | None":
    """Run affordance inference on one image via a single prefill forward pass.

    The prompt already contains the "[AFF]." answer, so instead of calling
    model.evaluate (autoregressive generation) this runs one teacher-forced
    forward with inference=True and reads the predicted masks directly.
    Returns a binary uint8 mask (H, W) with values 0/255, or None when the
    image cannot be read or the model produces no mask.
    """
    image_np = cv2.imread(image_path)
    if image_np is None:
        print(f" [WARNING] Cannot read image: {image_path}")
        return None
    # OpenCV loads BGR; the model pipelines expect RGB.
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # CLIP preprocessing (input to the vision-language branch).
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image_clip = image_clip.bfloat16()
    elif args.precision == "fp16":
        image_clip = image_clip.half()
    else:
        image_clip = image_clip.float()

    # SAM preprocessing: resize longest side, normalize, pad to square.
    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]
    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image = image.bfloat16()
    elif args.precision == "fp16":
        image = image.half()
    else:
        image = image.float()

    # Tokenize
    input_ids = tokenizer_image_token(prompt_str, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()
    attention_masks = input_ids.ne(tokenizer.pad_token_id)

    # Prefill inference (single forward pass instead of autoregressive
    # generation).  labels/offset/masks_list/label_list are presumably dummy
    # placeholders required by the training-style forward signature —
    # TODO(review): confirm against AffordanceVLMForCausalLM.forward.
    h, w = original_size_list[0]
    labels = input_ids.clone()
    offset = torch.LongTensor([0, 1]).cuda()
    masks_list = [torch.zeros(1, h, w).float().cuda()]
    label_list = [torch.zeros(h, w).long().cuda()]

    with torch.no_grad():
        output_dict = model(
            images=image,
            images_clip=image_clip,
            input_ids=input_ids,
            labels=labels,
            attention_masks=attention_masks,
            offset=offset,
            masks_list=masks_list,
            label_list=label_list,
            resize_list=resize_list,
            inference=True,
        )

    pred_masks = output_dict["pred_masks"]

    # Merge all predicted masks via union (logical OR); logits > 0 means
    # foreground.
    merged = np.zeros((h, w), dtype=bool)
    has_mask = False
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue
        mask_np = pred_mask.detach().cpu().numpy()[0]  # (H, W)
        merged |= (mask_np > 0)
        has_mask = True

    if not has_mask:
        return None

    return (merged.astype(np.uint8) * 255)
304
+
305
+
306
def read_language_instruction(h5_path: str) -> str:
    """Fetch the language_instruction entry from a step's other.h5 file."""
    with h5py.File(h5_path, "r") as f:
        value = f["language_instruction"][()]
    # h5py may return raw bytes; decode those, stringify anything else.
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return str(value)
313
+
314
+
315
def main(args):
    """Generate affordance masks for every step image in a per-step dataset.

    args: raw argv list (e.g. ``sys.argv[1:]``); parsed by parse_args.
    Walks {data_dir}/episodes/{episode}/steps/{step}/, runs the model on
    image_primary.jpg / image_wrist.jpg, and writes binary PNG masks under
    {save_dir} with the same directory layout.
    """
    args = parse_args(args)
    data_dir = Path(args.data_dir)
    save_dir = Path(args.save_dir)

    episodes_dir = data_dir / "episodes"
    if not episodes_dir.is_dir():
        print(f"Error: episodes directory not found at {episodes_dir}")
        sys.exit(1)

    # Episode names are zero-padded, so lexicographic sort == numeric sort.
    episode_dirs = sorted(
        [d for d in episodes_dir.iterdir() if d.is_dir()],
        key=lambda p: p.name,
    )

    # Optionally restrict to the half-open range [start_episode, end_episode).
    if args.start_episode is not None or args.end_episode is not None:
        start = args.start_episode if args.start_episode is not None else 0
        end = args.end_episode if args.end_episode is not None else len(episode_dirs)
        # Fix: skip non-numeric directory names instead of crashing in int().
        episode_dirs = [
            d for d in episode_dirs
            if d.name.isdigit() and start <= int(d.name) < end
        ]

    print(f"Data dir : {data_dir}")
    print(f"Save dir : {save_dir}")
    print(f"Episodes : {len(episode_dirs)}")
    print(f"Prompt : {args.prompt_template}")
    print()

    # Load model
    print("Loading model...")
    model, tokenizer, clip_image_processor, transform = load_model(args)
    print("Model loaded.\n")

    total_steps = 0
    empty_mask_count = 0

    for ep_dir in episode_dirs:
        episode_id = ep_dir.name  # e.g. "000000"
        steps_dir = ep_dir / "steps"
        if not steps_dir.is_dir():
            print(f" [WARNING] No steps/ in {ep_dir}, skipping.")
            continue

        step_dirs = sorted(
            [d for d in steps_dir.iterdir() if d.is_dir()],
            key=lambda p: p.name,
        )

        for step_dir in step_dirs:
            step_id = step_dir.name  # e.g. "0000"

            # Read language instruction for this step.
            other_h5 = step_dir / "other.h5"
            if not other_h5.exists():
                print(f" [WARNING] Missing other.h5 in {step_dir}, skipping.")
                continue
            language_instruction = read_language_instruction(str(other_h5))

            # Build the conversation prompt once per step (shared by cameras).
            query_text = args.prompt_template.format(language_instruction)
            prompt_str = build_prompt(query_text, args)

            # Mirror the input layout: episodes/{episode_id}/steps/{step_id}/
            out_dir = save_dir / "episodes" / episode_id / "steps" / step_id
            out_dir.mkdir(parents=True, exist_ok=True)

            # Process both cameras.
            for cam_name in ("image_primary", "image_wrist"):
                img_path = step_dir / f"{cam_name}.jpg"
                mask_path = out_dir / f"{cam_name}_mask.png"

                if not img_path.exists():
                    print(f" [WARNING] Missing {img_path}, skipping.")
                    continue

                mask = infer_single_image(
                    str(img_path), prompt_str,
                    model, tokenizer, clip_image_processor, transform, args,
                )

                if mask is None:
                    # No mask predicted: fall back to an all-black mask.
                    # Fix: cv2.imread returns None for unreadable files (the
                    # same condition that makes infer_single_image return
                    # None), so guard it instead of crashing on `.shape`.
                    img = cv2.imread(str(img_path))
                    if img is None:
                        print(f" [WARNING] Unreadable image {img_path}, skipping.")
                        continue
                    mask = np.zeros(img.shape[:2], dtype=np.uint8)
                    empty_mask_count += 1

                cv2.imwrite(str(mask_path), mask)

                total_steps += 1
                if total_steps % 50 == 0:
                    print(f" Processed {total_steps} steps (episode {episode_id}, step {step_id})")

        print(f"Episode {episode_id} done ({len(step_dirs)} steps)")

    print(f"\nFinished. {total_steps} steps processed, {empty_mask_count} empty masks.")


if __name__ == "__main__":
    main(sys.argv[1:])
.ipynb_checkpoints/chat-checkpoint.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
10
+
11
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
12
+ from model.llava import conversation as conversation_lib
13
+ from model.llava.mm_utils import tokenizer_image_token
14
+ from model.segment_anything.utils.transforms import ResizeLongestSide
15
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
16
+ DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
17
+
18
+
19
+ def parse_args(args):
20
+ parser = argparse.ArgumentParser(description="LISA chat")
21
+ parser.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
22
+ parser.add_argument("--vis_save_path", default="./vis_output", type=str)
23
+ parser.add_argument(
24
+ "--precision",
25
+ default="bf16",
26
+ type=str,
27
+ choices=["fp32", "bf16", "fp16"],
28
+ help="precision for inference",
29
+ )
30
+ parser.add_argument("--image_size", default=1024, type=int, help="image size")
31
+ parser.add_argument("--model_max_length", default=512, type=int)
32
+ parser.add_argument("--lora_r", default=8, type=int)
33
+ parser.add_argument(
34
+ "--vision-tower", default="openai/clip-vit-large-patch14", type=str
35
+ )
36
+ parser.add_argument("--local-rank", default=0, type=int, help="node rank")
37
+ parser.add_argument("--load_in_8bit", action="store_true", default=False)
38
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
39
+ parser.add_argument("--use_mm_start_end", action="store_true", default=True)
40
+ parser.add_argument(
41
+ "--conv_type",
42
+ default="llava_v1",
43
+ type=str,
44
+ choices=["llava_v1", "llava_llama_2"],
45
+ )
46
+ return parser.parse_args(args)
47
+
48
+
49
+ def preprocess(
50
+ x,
51
+ pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
52
+ pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
53
+ img_size=1024,
54
+ ) -> torch.Tensor:
55
+ """Normalize pixel values and pad to a square input."""
56
+ # Normalize colors
57
+ x = (x - pixel_mean) / pixel_std
58
+ # Pad
59
+ h, w = x.shape[-2:]
60
+ padh = img_size - h
61
+ padw = img_size - w
62
+ x = F.pad(x, (0, padw, 0, padh))
63
+ return x
64
+
65
+
66
+ def main(args):
67
+ args = parse_args(args)
68
+ os.makedirs(args.vis_save_path, exist_ok=True)
69
+
70
+ # Create model
71
+ tokenizer = AutoTokenizer.from_pretrained(
72
+ args.version,
73
+ cache_dir=None,
74
+ model_max_length=args.model_max_length,
75
+ padding_side="right",
76
+ use_fast=False,
77
+ )
78
+ tokenizer.pad_token = tokenizer.unk_token
79
+ num_added_tokens = tokenizer.add_tokens("[SEG]")
80
+ args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
81
+ num_added_tokens = tokenizer.add_tokens("[AFF]")
82
+ args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]
83
+
84
+ torch_dtype = torch.float32
85
+ if args.precision == "bf16":
86
+ torch_dtype = torch.bfloat16
87
+ elif args.precision == "fp16":
88
+ torch_dtype = torch.half
89
+
90
+ kwargs = {"torch_dtype": torch_dtype}
91
+ if args.load_in_4bit:
92
+ kwargs.update(
93
+ {
94
+ "torch_dtype": torch.half,
95
+ "load_in_4bit": True,
96
+ "quantization_config": BitsAndBytesConfig(
97
+ load_in_4bit=True,
98
+ bnb_4bit_compute_dtype=torch.float16,
99
+ bnb_4bit_use_double_quant=True,
100
+ bnb_4bit_quant_type="nf4",
101
+ llm_int8_skip_modules=["visual_model"],
102
+ ),
103
+ }
104
+ )
105
+ elif args.load_in_8bit:
106
+ kwargs.update(
107
+ {
108
+ "torch_dtype": torch.half,
109
+ "quantization_config": BitsAndBytesConfig(
110
+ llm_int8_skip_modules=["visual_model"],
111
+ load_in_8bit=True,
112
+ ),
113
+ }
114
+ )
115
+
116
+ model = AffordanceVLMForCausalLM.from_pretrained(
117
+ args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, aff_token_idx=args.aff_token_idx, **kwargs
118
+ )
119
+
120
+ model.config.eos_token_id = tokenizer.eos_token_id
121
+ model.config.bos_token_id = tokenizer.bos_token_id
122
+ model.config.pad_token_id = tokenizer.pad_token_id
123
+
124
+ model.get_model().initialize_vision_modules(model.get_model().config)
125
+ vision_tower = model.get_model().get_vision_tower()
126
+ vision_tower.to(dtype=torch_dtype)
127
+
128
+ if args.precision == "bf16":
129
+ model = model.bfloat16().cuda()
130
+ elif (
131
+ args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
132
+ ):
133
+ vision_tower = model.get_model().get_vision_tower()
134
+ model.model.vision_tower = None
135
+ import deepspeed
136
+
137
+ model_engine = deepspeed.init_inference(
138
+ model=model,
139
+ dtype=torch.half,
140
+ replace_with_kernel_inject=True,
141
+ replace_method="auto",
142
+ )
143
+ model = model_engine.module
144
+ model.model.vision_tower = vision_tower.half().cuda()
145
+ elif args.precision == "fp32":
146
+ model = model.float().cuda()
147
+
148
+ vision_tower = model.get_model().get_vision_tower()
149
+ vision_tower.to(device=args.local_rank)
150
+
151
+ clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
152
+ transform = ResizeLongestSide(args.image_size)
153
+
154
+ model.eval()
155
+
156
+ while True:
157
+ conv = conversation_lib.conv_templates[args.conv_type].copy()
158
+ conv.messages = []
159
+
160
+ prompt = input("Please input your prompt: ")
161
+ prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
162
+ if args.use_mm_start_end:
163
+ replace_token = (
164
+ DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
165
+ )
166
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
167
+
168
+ conv.append_message(conv.roles[0], prompt)
169
+ conv.append_message(conv.roles[1], "")
170
+ prompt = conv.get_prompt()
171
+
172
+ image_path = input("Please input the image path: ")
173
+ if not os.path.exists(image_path):
174
+ print("File not found in {}".format(image_path))
175
+ continue
176
+
177
+ image_np = cv2.imread(image_path)
178
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
179
+ original_size_list = [image_np.shape[:2]]
180
+
181
+ image_clip = (
182
+ clip_image_processor.preprocess(image_np, return_tensors="pt")[
183
+ "pixel_values"
184
+ ][0]
185
+ .unsqueeze(0)
186
+ .cuda()
187
+ )
188
+ if args.precision == "bf16":
189
+ image_clip = image_clip.bfloat16()
190
+ elif args.precision == "fp16":
191
+ image_clip = image_clip.half()
192
+ else:
193
+ image_clip = image_clip.float()
194
+
195
+ image = transform.apply_image(image_np)
196
+ resize_list = [image.shape[:2]]
197
+
198
+ image = (
199
+ preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
200
+ .unsqueeze(0)
201
+ .cuda()
202
+ )
203
+ if args.precision == "bf16":
204
+ image = image.bfloat16()
205
+ elif args.precision == "fp16":
206
+ image = image.half()
207
+ else:
208
+ image = image.float()
209
+
210
+ input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
211
+ input_ids = input_ids.unsqueeze(0).cuda()
212
+
213
+ output_ids, pred_masks = model.evaluate(
214
+ image_clip,
215
+ image,
216
+ input_ids,
217
+ resize_list,
218
+ original_size_list,
219
+ max_new_tokens=512,
220
+ tokenizer=tokenizer,
221
+ )
222
+ output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
223
+
224
+ text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
225
+ text_output = text_output.replace("\n", "").replace(" ", " ")
226
+ print("text_output: ", text_output)
227
+
228
+ for i, pred_mask in enumerate(pred_masks):
229
+ if pred_mask.shape[0] == 0:
230
+ continue
231
+
232
+ pred_mask = pred_mask.detach().cpu().numpy()[0]
233
+ pred_mask = pred_mask > 0
234
+
235
+ save_path = "{}/{}_mask_{}.jpg".format(
236
+ args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
237
+ )
238
+ cv2.imwrite(save_path, pred_mask * 100)
239
+ print("{} has been saved.".format(save_path))
240
+
241
+ save_path = "{}/{}_masked_img_{}.jpg".format(
242
+ args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
243
+ )
244
+ save_img = image_np.copy()
245
+ save_img[pred_mask] = (
246
+ image_np * 0.5
247
+ + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
248
+ )[pred_mask]
249
+ save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
250
+ cv2.imwrite(save_path, save_img)
251
+ print("{} has been saved.".format(save_path))
252
+
253
+
254
+ if __name__ == "__main__":
255
+ main(sys.argv[1:])
.ipynb_checkpoints/chat_prefill-checkpoint.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Interactive affordance mask generation using prefill mode (single forward pass).
3
+
4
+ Same interactive workflow as chat.py, but uses prefill inference instead of
5
+ autoregressive generation. The assistant response "[AFF]." is pre-filled in the
6
+ prompt, so the model only does one forward pass to extract mask embeddings.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import sys
12
+
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
18
+
19
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
20
+ from model.llava import conversation as conversation_lib
21
+ from model.llava.mm_utils import tokenizer_image_token
22
+ from model.segment_anything.utils.transforms import ResizeLongestSide
23
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
24
+ DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
25
+
26
+
27
+ def parse_args(args):
28
+ parser = argparse.ArgumentParser(description="AffordanceVLM chat (prefill mode)")
29
+ parser.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
30
+ parser.add_argument("--vis_save_path", default="./vis_output_prefill", type=str)
31
+ parser.add_argument(
32
+ "--precision", default="bf16", type=str,
33
+ choices=["fp32", "bf16", "fp16"],
34
+ )
35
+ parser.add_argument("--image_size", default=1024, type=int)
36
+ parser.add_argument("--model_max_length", default=512, type=int)
37
+ parser.add_argument("--lora_r", default=8, type=int)
38
+ parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
39
+ parser.add_argument("--local-rank", default=0, type=int)
40
+ parser.add_argument("--load_in_8bit", action="store_true", default=False)
41
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
42
+ parser.add_argument("--use_mm_start_end", action="store_true", default=True)
43
+ parser.add_argument(
44
+ "--conv_type", default="llava_v1", type=str,
45
+ choices=["llava_v1", "llava_llama_2"],
46
+ )
47
+ parser.add_argument("--prompt_template", type=str,
48
+ default="Segment the most suitable manipulation region on the single target object for the task '{}'.",
49
+ help="Template wrapping language_instruction. Use {} as placeholder.")
50
+ # Segment the most suitable manipulation region on the single target object for the task '{}'.
51
+ # Segment the affordance map for the task '{}' in this image.
52
+ # Segment the affordance map of the single target object for the task '{}' in this image.
53
+ # Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask.
54
+ # Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask.
55
+ return parser.parse_args(args)
56
+
57
+
58
+ def preprocess(
59
+ x,
60
+ pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
61
+ pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
62
+ img_size=1024,
63
+ ) -> torch.Tensor:
64
+ """Normalize pixel values and pad to a square input."""
65
+ x = (x - pixel_mean) / pixel_std
66
+ h, w = x.shape[-2:]
67
+ padh = img_size - h
68
+ padw = img_size - w
69
+ x = F.pad(x, (0, padw, 0, padh))
70
+ return x
71
+
72
+
73
+ def main(args):
74
+ args = parse_args(args)
75
+ os.makedirs(args.vis_save_path, exist_ok=True)
76
+
77
+ # Create model
78
+ tokenizer = AutoTokenizer.from_pretrained(
79
+ args.version,
80
+ cache_dir=None,
81
+ model_max_length=args.model_max_length,
82
+ padding_side="right",
83
+ use_fast=False,
84
+ )
85
+ tokenizer.pad_token = tokenizer.unk_token
86
+ tokenizer.add_tokens("[SEG]")
87
+ args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
88
+ tokenizer.add_tokens("[AFF]")
89
+ args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]
90
+
91
+ torch_dtype = torch.float32
92
+ if args.precision == "bf16":
93
+ torch_dtype = torch.bfloat16
94
+ elif args.precision == "fp16":
95
+ torch_dtype = torch.half
96
+
97
+ kwargs = {"torch_dtype": torch_dtype}
98
+ if args.load_in_4bit:
99
+ kwargs.update({
100
+ "torch_dtype": torch.half,
101
+ "load_in_4bit": True,
102
+ "quantization_config": BitsAndBytesConfig(
103
+ load_in_4bit=True,
104
+ bnb_4bit_compute_dtype=torch.float16,
105
+ bnb_4bit_use_double_quant=True,
106
+ bnb_4bit_quant_type="nf4",
107
+ llm_int8_skip_modules=["visual_model"],
108
+ ),
109
+ })
110
+ elif args.load_in_8bit:
111
+ kwargs.update({
112
+ "torch_dtype": torch.half,
113
+ "quantization_config": BitsAndBytesConfig(
114
+ llm_int8_skip_modules=["visual_model"],
115
+ load_in_8bit=True,
116
+ ),
117
+ })
118
+
119
+ model = AffordanceVLMForCausalLM.from_pretrained(
120
+ args.version,
121
+ low_cpu_mem_usage=True,
122
+ vision_tower=args.vision_tower,
123
+ seg_token_idx=args.seg_token_idx,
124
+ aff_token_idx=args.aff_token_idx,
125
+ **kwargs,
126
+ )
127
+
128
+ model.config.eos_token_id = tokenizer.eos_token_id
129
+ model.config.bos_token_id = tokenizer.bos_token_id
130
+ model.config.pad_token_id = tokenizer.pad_token_id
131
+
132
+ model.get_model().initialize_vision_modules(model.get_model().config)
133
+ vision_tower = model.get_model().get_vision_tower()
134
+ vision_tower.to(dtype=torch_dtype)
135
+
136
+ if args.precision == "bf16":
137
+ model = model.bfloat16().cuda()
138
+ elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
139
+ vision_tower = model.get_model().get_vision_tower()
140
+ model.model.vision_tower = None
141
+ import deepspeed
142
+ model_engine = deepspeed.init_inference(
143
+ model=model,
144
+ dtype=torch.half,
145
+ replace_with_kernel_inject=True,
146
+ replace_method="auto",
147
+ )
148
+ model = model_engine.module
149
+ model.model.vision_tower = vision_tower.half().cuda()
150
+ elif args.precision == "fp32":
151
+ model = model.float().cuda()
152
+
153
+ vision_tower = model.get_model().get_vision_tower()
154
+ vision_tower.to(device=args.local_rank)
155
+
156
+ clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
157
+ transform = ResizeLongestSide(args.image_size)
158
+
159
+ model.eval()
160
+
161
+ # debug
162
+ template = "Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask."
163
+
164
+ while True:
165
+ conv = conversation_lib.conv_templates[args.conv_type].copy()
166
+ conv.messages = []
167
+
168
+ prompt = input("Please input your prompt: ")
169
+ # 加入模版
170
+ prompt = args.prompt_template.format(prompt)
171
+
172
+ prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
173
+ if args.use_mm_start_end:
174
+ replace_token = (
175
+ DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
176
+ )
177
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
178
+
179
+ conv.append_message(conv.roles[0], prompt)
180
+ conv.append_message(conv.roles[1], "[AFF].")
181
+ prompt = conv.get_prompt()
182
+
183
+ image_path = input("Please input the image path: ")
184
+ if not os.path.exists(image_path):
185
+ print("File not found in {}".format(image_path))
186
+ continue
187
+
188
+ image_np = cv2.imread(image_path)
189
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
190
+ original_size_list = [image_np.shape[:2]]
191
+ h, w = original_size_list[0]
192
+
193
+ image_clip = (
194
+ clip_image_processor.preprocess(image_np, return_tensors="pt")[
195
+ "pixel_values"
196
+ ][0]
197
+ .unsqueeze(0)
198
+ .cuda()
199
+ )
200
+ if args.precision == "bf16":
201
+ image_clip = image_clip.bfloat16()
202
+ elif args.precision == "fp16":
203
+ image_clip = image_clip.half()
204
+ else:
205
+ image_clip = image_clip.float()
206
+
207
+ image = transform.apply_image(image_np)
208
+ resize_list = [image.shape[:2]]
209
+
210
+ image = (
211
+ preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
212
+ .unsqueeze(0)
213
+ .cuda()
214
+ )
215
+ if args.precision == "bf16":
216
+ image = image.bfloat16()
217
+ elif args.precision == "fp16":
218
+ image = image.half()
219
+ else:
220
+ image = image.float()
221
+
222
+ input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
223
+ input_ids = input_ids.unsqueeze(0).cuda()
224
+ attention_masks = input_ids.ne(tokenizer.pad_token_id)
225
+
226
+ # Print the full prompt text (prefill mode has no generated text)
227
+ # debug
228
+ text_ids = input_ids[0][input_ids[0] != IMAGE_TOKEN_INDEX]
229
+ text_output = tokenizer.decode(text_ids, skip_special_tokens=False)
230
+ text_output = text_output.replace("\n", "").replace(" ", " ")
231
+ print("text_output: ", text_output)
232
+
233
+ # Prefill inference
234
+ labels = input_ids.clone()
235
+ offset = torch.LongTensor([0, 1]).cuda()
236
+ masks_list = [torch.zeros(1, h, w).float().cuda()]
237
+ label_list = [torch.zeros(h, w).long().cuda()]
238
+
239
+ with torch.no_grad():
240
+ output_dict = model(
241
+ images=image,
242
+ images_clip=image_clip,
243
+ input_ids=input_ids,
244
+ labels=labels,
245
+ attention_masks=attention_masks,
246
+ offset=offset,
247
+ masks_list=masks_list,
248
+ label_list=label_list,
249
+ resize_list=resize_list,
250
+ inference=True,
251
+ )
252
+
253
+ pred_masks = output_dict["pred_masks"]
254
+
255
+ for i, pred_mask in enumerate(pred_masks):
256
+ if pred_mask.shape[0] == 0:
257
+ continue
258
+
259
+ pred_mask = pred_mask.detach().cpu().numpy()[0]
260
+ pred_mask = pred_mask > 0
261
+
262
+ save_path = "{}/{}_mask_{}.jpg".format(
263
+ args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
264
+ )
265
+ cv2.imwrite(save_path, pred_mask * 100)
266
+ print("{} has been saved.".format(save_path))
267
+
268
+ save_path = "{}/{}_masked_img_{}.jpg".format(
269
+ args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
270
+ )
271
+ save_img = image_np.copy()
272
+ save_img[pred_mask] = (
273
+ image_np * 0.5
274
+ + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
275
+ )[pred_mask]
276
+ save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
277
+ cv2.imwrite(save_path, save_img)
278
+ print("{} has been saved.".format(save_path))
279
+
280
+
281
+ if __name__ == "__main__":
282
+ main(sys.argv[1:])
.ipynb_checkpoints/train_aff-checkpoint.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import sys
5
+ import time
6
+ from functools import partial
7
+
8
+ import deepspeed
9
+ import numpy as np
10
+ import torch
11
+ import tqdm
12
+ import transformers
13
+ from peft import LoraConfig, get_peft_model
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
17
+ from model.llava import conversation as conversation_lib
18
+ from utils.dataset import HybridDataset, ValDataset, collate_fn
19
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
20
+ AverageMeter, ProgressMeter, Summary, dict_to_cuda,
21
+ intersectionAndUnionGPU)
22
+
23
+ from utils.aff_seg_dataset import AffValDataset
24
+ from utils.reason_aff_dataset import ReasonAffValDataset
25
+
26
+
27
+ def parse_args(args):
28
+ parser = argparse.ArgumentParser(description="LISA Model Training")
29
+ parser.add_argument("--local_rank", default=0, type=int, help="node rank")
30
+ parser.add_argument(
31
+ "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
32
+ )
33
+ parser.add_argument("--vis_save_path", default="./vis_output", type=str)
34
+ parser.add_argument(
35
+ "--precision",
36
+ default="bf16",
37
+ type=str,
38
+ choices=["fp32", "bf16", "fp16"],
39
+ help="precision for inference",
40
+ )
41
+ parser.add_argument("--image_size", default=1024, type=int, help="image size")
42
+ parser.add_argument("--model_max_length", default=512, type=int)
43
+ parser.add_argument("--lora_r", default=8, type=int)
44
+ parser.add_argument(
45
+ "--vision-tower", default="openai/clip-vit-large-patch14", type=str
46
+ )
47
+ parser.add_argument("--load_in_8bit", action="store_true", default=False)
48
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
49
+
50
+ parser.add_argument(
51
+ "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
52
+ )
53
+ parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
54
+ parser.add_argument(
55
+ "--sem_seg_data",
56
+ default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
57
+ type=str,
58
+ )
59
+ parser.add_argument(
60
+ "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
61
+ )
62
+ parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
63
+ parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
64
+ parser.add_argument("--aff_seg_data", default="handal", type=str)
65
+ parser.add_argument("--aff_sample_rates", default="1", type=str)
66
+ parser.add_argument("--reason_aff_data", default="handal_hard_reasoning", type=str)
67
+ parser.add_argument("--reason_aff_sample_rates", default="1", type=str)
68
+ parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
69
+ parser.add_argument("--dataset_dir", default="./dataset", type=str)
70
+ parser.add_argument("--log_base_dir", default="./runs", type=str)
71
+ parser.add_argument("--exp_name", default="lisa", type=str)
72
+ parser.add_argument("--epochs", default=10, type=int)
73
+ parser.add_argument("--steps_per_epoch", default=500, type=int)
74
+ parser.add_argument(
75
+ "--batch_size", default=2, type=int, help="batch size per device per step"
76
+ )
77
+ parser.add_argument(
78
+ "--grad_accumulation_steps",
79
+ default=10,
80
+ type=int,
81
+ )
82
+ parser.add_argument("--val_batch_size", default=1, type=int)
83
+ parser.add_argument("--workers", default=4, type=int)
84
+ parser.add_argument("--lr", default=0.0003, type=float)
85
+ parser.add_argument("--ce_loss_weight", default=1.0, type=float)
86
+ parser.add_argument("--dice_loss_weight", default=0.5, type=float)
87
+ parser.add_argument("--bce_loss_weight", default=2.0, type=float)
88
+ parser.add_argument("--lora_alpha", default=16, type=int)
89
+ parser.add_argument("--lora_dropout", default=0.05, type=float)
90
+ parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
91
+ parser.add_argument("--explanatory", default=0.1, type=float)
92
+ parser.add_argument("--beta1", default=0.9, type=float)
93
+ parser.add_argument("--beta2", default=0.95, type=float)
94
+ parser.add_argument("--num_classes_per_sample", default=3, type=int)
95
+ parser.add_argument("--exclude_val", action="store_true", default=False)
96
+ parser.add_argument("--no_eval", action="store_true", default=False)
97
+ parser.add_argument("--eval_only", action="store_true", default=False)
98
+ parser.add_argument("--eval_affordance", action="store_true", default=False)
99
+ parser.add_argument("--eval_reason_aff", action="store_true", default=False)
100
+ parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
101
+ parser.add_argument("--out_dim", default=256, type=int)
102
+ parser.add_argument("--resume", default="", type=str)
103
+ parser.add_argument("--print_freq", default=1, type=int)
104
+ parser.add_argument("--start_epoch", default=0, type=int)
105
+ parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
106
+ parser.add_argument("--train_mask_decoder", action="store_true", default=True)
107
+ parser.add_argument("--use_mm_start_end", action="store_true", default=True)
108
+ parser.add_argument("--auto_resume", action="store_true", default=True)
109
+ parser.add_argument(
110
+ "--conv_type",
111
+ default="llava_v1",
112
+ type=str,
113
+ choices=["llava_v1", "llava_llama_2"],
114
+ )
115
+ return parser.parse_args(args)
116
+
117
+
118
+ def main(args):
119
+ args = parse_args(args)
120
+ args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
121
+ if args.local_rank == 0:
122
+ os.makedirs(args.log_dir, exist_ok=True)
123
+ writer = SummaryWriter(args.log_dir)
124
+ else:
125
+ writer = None
126
+
127
+ # Create model
128
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
129
+ args.version,
130
+ cache_dir=None,
131
+ model_max_length=args.model_max_length,
132
+ padding_side="right",
133
+ use_fast=False,
134
+ )
135
+ tokenizer.pad_token = tokenizer.unk_token
136
+ num_added_tokens = tokenizer.add_tokens("[SEG]")
137
+ args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
138
+ num_added_tokens = tokenizer.add_tokens("[AFF]")
139
+ args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]
140
+
141
+ if args.use_mm_start_end:
142
+ tokenizer.add_tokens(
143
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
144
+ )
145
+
146
+ model_args = {
147
+ "train_mask_decoder": args.train_mask_decoder,
148
+ "out_dim": args.out_dim,
149
+ "ce_loss_weight": args.ce_loss_weight,
150
+ "dice_loss_weight": args.dice_loss_weight,
151
+ "bce_loss_weight": args.bce_loss_weight,
152
+ "seg_token_idx": args.seg_token_idx,
153
+ "aff_token_idx": args.aff_token_idx,
154
+ "vision_pretrained": args.vision_pretrained,
155
+ "vision_tower": args.vision_tower,
156
+ "use_mm_start_end": args.use_mm_start_end,
157
+ }
158
+ torch_dtype = torch.float32
159
+ if args.precision == "bf16":
160
+ torch_dtype = torch.bfloat16
161
+ elif args.precision == "fp16":
162
+ torch_dtype = torch.half
163
+ model = AffordanceVLMForCausalLM.from_pretrained(
164
+ args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
165
+ )
166
+ model.config.eos_token_id = tokenizer.eos_token_id
167
+ model.config.bos_token_id = tokenizer.bos_token_id
168
+ model.config.pad_token_id = tokenizer.pad_token_id
169
+
170
+ model.enable_input_require_grads()
171
+ model.gradient_checkpointing_enable()
172
+
173
+ model.get_model().initialize_vision_modules(model.get_model().config)
174
+ vision_tower = model.get_model().get_vision_tower()
175
+ vision_tower.to(dtype=torch_dtype, device=args.local_rank)
176
+ if not args.eval_only:
177
+ model.get_model().initialize_lisa_modules(model.get_model().config)
178
+
179
+ for p in vision_tower.parameters():
180
+ p.requires_grad = False
181
+ for p in model.get_model().mm_projector.parameters():
182
+ p.requires_grad = False
183
+
184
+ conversation_lib.default_conversation = conversation_lib.conv_templates[
185
+ args.conv_type
186
+ ]
187
+
188
+ lora_r = args.lora_r
189
+ if lora_r > 0:
190
+
191
+ def find_linear_layers(model, lora_target_modules):
192
+ cls = torch.nn.Linear
193
+ lora_module_names = set()
194
+ for name, module in model.named_modules():
195
+ if (
196
+ isinstance(module, cls)
197
+ and all(
198
+ [
199
+ x not in name
200
+ for x in [
201
+ "visual_model",
202
+ "vision_tower",
203
+ "mm_projector",
204
+ "text_hidden_fcs",
205
+ ]
206
+ ]
207
+ )
208
+ and any([x in name for x in lora_target_modules])
209
+ ):
210
+ lora_module_names.add(name)
211
+ return sorted(list(lora_module_names))
212
+
213
+ lora_alpha = args.lora_alpha
214
+ lora_dropout = args.lora_dropout
215
+ lora_target_modules = find_linear_layers(
216
+ model, args.lora_target_modules.split(",")
217
+ )
218
+ lora_config = LoraConfig(
219
+ r=lora_r,
220
+ lora_alpha=lora_alpha,
221
+ target_modules=lora_target_modules,
222
+ lora_dropout=lora_dropout,
223
+ bias="none",
224
+ task_type="CAUSAL_LM",
225
+ )
226
+ model = get_peft_model(model, lora_config)
227
+ model.print_trainable_parameters()
228
+
229
+ model.resize_token_embeddings(len(tokenizer))
230
+
231
+ # make text_hidden_fcs, mask_decoder, lm_head, embed_tokens trainable
232
+ for n, p in model.named_parameters():
233
+ if any(
234
+ [
235
+ x in n
236
+ for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
237
+ ]
238
+ ):
239
+ print("n: ", n, "p.shape: ", p.shape)
240
+ p.requires_grad = True
241
+
242
+ world_size = torch.cuda.device_count()
243
+ args.distributed = world_size > 1
244
+ train_dataset = HybridDataset(
245
+ args.dataset_dir,
246
+ tokenizer,
247
+ args.vision_tower,
248
+ samples_per_epoch=args.batch_size
249
+ * args.grad_accumulation_steps
250
+ * args.steps_per_epoch
251
+ * world_size,
252
+ precision=args.precision,
253
+ image_size=args.image_size,
254
+ num_classes_per_sample=args.num_classes_per_sample,
255
+ exclude_val=args.exclude_val,
256
+ dataset=args.dataset,
257
+ sample_rate=[float(x) for x in args.sample_rates.split(",")],
258
+ sem_seg_data=args.sem_seg_data,
259
+ refer_seg_data=args.refer_seg_data,
260
+ vqa_data=args.vqa_data,
261
+ reason_seg_data=args.reason_seg_data,
262
+ aff_seg_data=args.aff_seg_data,
263
+ aff_sample_rate=[float(x) for x in args.aff_sample_rates.split(",")],
264
+ reason_aff_data=args.reason_aff_data,
265
+ reason_aff_sample_rate=[float(x) for x in args.reason_aff_sample_rates.split(",")],
266
+ explanatory=args.explanatory,
267
+ )
268
+
269
+ if args.no_eval == False:
270
+ if args.eval_affordance:
271
+ val_dataset = AffValDataset(
272
+ args.dataset_dir,
273
+ tokenizer,
274
+ args.vision_tower,
275
+ args.val_dataset,
276
+ args.image_size,
277
+ )
278
+ elif args.eval_reason_aff:
279
+ val_dataset = ReasonAffValDataset(
280
+ args.dataset_dir,
281
+ tokenizer,
282
+ args.vision_tower,
283
+ args.val_dataset,
284
+ args.image_size,
285
+ )
286
+ else:
287
+ val_dataset = ValDataset(
288
+ args.dataset_dir,
289
+ tokenizer,
290
+ args.vision_tower,
291
+ args.val_dataset,
292
+ args.image_size,
293
+ )
294
+ print(
295
+ f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples."
296
+ )
297
+ else:
298
+ val_dataset = None
299
+ print(f"Training with {len(train_dataset)} examples.")
300
+
301
+ ds_config = {
302
+ "train_micro_batch_size_per_gpu": args.batch_size,
303
+ "gradient_accumulation_steps": args.grad_accumulation_steps,
304
+ "optimizer": {
305
+ "type": "AdamW",
306
+ "params": {
307
+ "lr": args.lr,
308
+ "weight_decay": 0.0,
309
+ "betas": (args.beta1, args.beta2),
310
+ },
311
+ },
312
+ "scheduler": {
313
+ "type": "WarmupDecayLR",
314
+ "params": {
315
+ "total_num_steps": args.epochs * args.steps_per_epoch,
316
+ "warmup_min_lr": 0,
317
+ "warmup_max_lr": args.lr,
318
+ "warmup_num_steps": 100,
319
+ "warmup_type": "linear",
320
+ },
321
+ },
322
+ "fp16": {
323
+ "enabled": args.precision == "fp16",
324
+ },
325
+ "bf16": {
326
+ "enabled": args.precision == "bf16",
327
+ },
328
+ "gradient_clipping": 1.0,
329
+ "zero_optimization": {
330
+ "stage": 2,
331
+ "contiguous_gradients": True,
332
+ "overlap_comm": True,
333
+ "reduce_scatter": True,
334
+ "reduce_bucket_size": 5e8,
335
+ "allgather_bucket_size": 5e8,
336
+ },
337
+ }
338
+ model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
339
+ model=model,
340
+ model_parameters=model.parameters(),
341
+ training_data=train_dataset,
342
+ collate_fn=partial(
343
+ collate_fn,
344
+ tokenizer=tokenizer,
345
+ conv_type=args.conv_type,
346
+ use_mm_start_end=args.use_mm_start_end,
347
+ local_rank=args.local_rank,
348
+ ),
349
+ config=ds_config,
350
+ )
351
+
352
+ # resume deepspeed checkpoint
353
+ if args.auto_resume and len(args.resume) == 0:
354
+ resume = os.path.join(args.log_dir, "ckpt_model")
355
+ if os.path.exists(resume):
356
+ args.resume = resume
357
+
358
+ if args.resume:
359
+ load_path, client_state = model_engine.load_checkpoint(args.resume)
360
+ with open(os.path.join(args.resume, "latest"), "r") as f:
361
+ ckpt_dir = f.readlines()[0].strip()
362
+ args.start_epoch = (
363
+ int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
364
+ )
365
+ print(
366
+ "resume training from {}, start from epoch {}".format(
367
+ args.resume, args.start_epoch
368
+ )
369
+ )
370
+
371
+ # validation dataset
372
+ if val_dataset is not None:
373
+ assert args.val_batch_size == 1
374
+ val_sampler = torch.utils.data.distributed.DistributedSampler(
375
+ val_dataset, shuffle=False, drop_last=False
376
+ )
377
+ val_loader = torch.utils.data.DataLoader(
378
+ val_dataset,
379
+ batch_size=args.val_batch_size,
380
+ shuffle=False,
381
+ num_workers=args.workers,
382
+ pin_memory=False,
383
+ sampler=val_sampler,
384
+ collate_fn=partial(
385
+ collate_fn,
386
+ tokenizer=tokenizer,
387
+ conv_type=args.conv_type,
388
+ use_mm_start_end=args.use_mm_start_end,
389
+ local_rank=args.local_rank,
390
+ ),
391
+ )
392
+
393
+ train_iter = iter(train_loader)
394
+ best_score, cur_ciou = 0.0, 0.0
395
+
396
+ if args.eval_only:
397
+ giou, ciou = validate(val_loader, model_engine, 0, writer, args)
398
+ if args.local_rank == 0:
399
+ with open(os.path.join(args.version, "eval_result.txt"), "a") as f:
400
+ f.write(f"dataset: {args.val_dataset}, giou: {giou}, ciou: {ciou} \n")
401
+ exit()
402
+
403
+ for epoch in range(args.start_epoch, args.epochs):
404
+ # train for one epoch
405
+ train_iter = train(
406
+ train_loader,
407
+ model_engine,
408
+ epoch,
409
+ scheduler,
410
+ writer,
411
+ train_iter,
412
+ args,
413
+ )
414
+
415
+ if args.no_eval == False:
416
+ giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
417
+ is_best = giou > best_score
418
+ best_score = max(giou, best_score)
419
+ cur_ciou = ciou if is_best else cur_ciou
420
+
421
+ if args.no_eval or is_best:
422
+ save_dir = os.path.join(args.log_dir, "ckpt_model")
423
+ if args.local_rank == 0:
424
+ torch.save(
425
+ {"epoch": epoch},
426
+ os.path.join(
427
+ args.log_dir,
428
+ "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
429
+ best_score, cur_ciou
430
+ ),
431
+ ),
432
+ )
433
+ if os.path.exists(save_dir):
434
+ shutil.rmtree(save_dir)
435
+ torch.distributed.barrier()
436
+ model_engine.save_checkpoint(save_dir)
437
+
438
+
439
def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Run one training epoch of ``args.steps_per_epoch`` optimizer steps.

    Each step accumulates gradients over ``args.grad_accumulation_steps``
    micro-batches drawn from ``train_iter``; the iterator is re-created from
    ``train_loader`` whenever it is exhausted, so an epoch always consumes
    exactly ``steps_per_epoch * grad_accumulation_steps`` batches.

    Args:
        train_loader: DataLoader used to re-create the iterator on exhaustion.
        model: DeepSpeed engine exposing ``backward()`` / ``step()``.
        epoch: Zero-based epoch index (used for display and log step offset).
        scheduler: LR scheduler queried for TensorBoard logging.
        writer: TensorBoard SummaryWriter; only rank 0 writes to it.
        train_iter: Iterator over ``train_loader`` carried across epochs.
        args: Parsed CLI namespace (precision, print_freq, local_rank, ...).

    Returns:
        The (possibly re-created) ``train_iter`` so the caller can keep
        consuming the same iterator in the next epoch.
    """
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")

    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        # Monotonic TensorBoard x-axis; the raw per-epoch global_step would
        # overwrite scalars logged in previous epochs.
        log_step = epoch * args.steps_per_epoch + global_step
        for _ in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # Fix: was a bare `except:` — restart the loader only when it
                # is genuinely exhausted instead of masking arbitrary errors.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)

            data_time.update(time.time() - end)
            input_dict = dict_to_cuda(input_dict)

            # Cast image tensors to the configured training precision.
            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()

            output_dict = model(**input_dict)

            loss = output_dict["loss"]
            ce_loss = output_dict["ce_loss"]
            mask_bce_loss = output_dict["mask_bce_loss"]
            mask_dice_loss = output_dict["mask_dice_loss"]
            mask_loss = output_dict["mask_loss"]

            batch_size = input_dict["images"].size(0)
            losses.update(loss.item(), batch_size)
            ce_losses.update(ce_loss.item(), batch_size)
            mask_bce_losses.update(mask_bce_loss.item(), batch_size)
            mask_dice_losses.update(mask_dice_loss.item(), batch_size)
            mask_losses.update(mask_loss.item(), batch_size)
            # DeepSpeed handles gradient-accumulation loss scaling internally.
            model.backward(loss)
            model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if global_step % args.print_freq == 0:
            if args.distributed:
                batch_time.all_reduce()
                data_time.all_reduce()

                losses.all_reduce()
                ce_losses.all_reduce()
                mask_bce_losses.all_reduce()
                mask_dice_losses.all_reduce()
                mask_losses.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                writer.add_scalar("train/loss", losses.avg, log_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, log_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, log_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, log_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, log_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, log_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, log_step
                )

            # Reset window statistics after each report.
            batch_time.reset()
            data_time.reset()
            losses.reset()
            ce_losses.reset()
            mask_bce_losses.reset()
            mask_dice_losses.reset()
            mask_losses.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar("train/lr", curr_lr[0], log_step)

    return train_iter
557
+
558
+
559
def validate(val_loader, model_engine, epoch, writer, args):
    """Evaluate segmentation quality over ``val_loader``.

    Computes two binary-mask IoU metrics (foreground class index 1):
    cIoU (cumulative intersection / cumulative union over the whole set)
    and gIoU (mean per-sample IoU).  Results are all-reduced across ranks;
    rank 0 additionally logs them to TensorBoard.

    Returns:
        (giou, ciou) as scalars.
    """
    # Summary.SUM so all_reduce() accumulates raw sums across ranks.
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    model_engine.eval()

    for input_dict in tqdm.tqdm(val_loader):
        torch.cuda.empty_cache()

        input_dict = dict_to_cuda(input_dict)
        # Cast image tensors to the evaluation precision.
        if args.precision == "fp16":
            input_dict["images"] = input_dict["images"].half()
            input_dict["images_clip"] = input_dict["images_clip"].half()
        elif args.precision == "bf16":
            input_dict["images"] = input_dict["images"].bfloat16()
            input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
        else:
            input_dict["images"] = input_dict["images"].float()
            input_dict["images_clip"] = input_dict["images_clip"].float()

        with torch.no_grad():
            output_dict = model_engine(**input_dict)

        pred_masks = output_dict["pred_masks"]
        masks_list = output_dict["gt_masks"][0].int()
        # Threshold logits at 0 to obtain binary predictions.
        output_list = (pred_masks[0] > 0).int()
        # Validation runs with val_batch_size == 1 (asserted by the caller).
        assert len(pred_masks) == 1

        intersection, union, acc_iou = 0.0, 0.0, 0.0
        for mask_i, output_i in zip(masks_list, output_list):
            # Two classes (background / foreground); 255 marks ignored pixels.
            intersection_i, union_i, _ = intersectionAndUnionGPU(
                output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
            )
            intersection += intersection_i
            union += union_i
            acc_iou += intersection_i / (union_i + 1e-5)
            acc_iou[union_i == 0] += 1.0  # no-object target
        intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
        # Average per-sample IoU over the masks in this sample.
        acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]
        intersection_meter.update(intersection), union_meter.update(
            union
        ), acc_iou_meter.update(acc_iou, n=masks_list.shape[0])

    # Aggregate statistics across all distributed ranks.
    intersection_meter.all_reduce()
    union_meter.all_reduce()
    acc_iou_meter.all_reduce()

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    # Index 1 is the foreground class.
    ciou = iou_class[1]
    giou = acc_iou_meter.avg[1]

    if args.local_rank == 0:
        writer.add_scalar("val/giou", giou, epoch)
        writer.add_scalar("val/ciou", ciou, epoch)
        print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

    return giou, ciou
617
+
618
+
619
# CLI entry point: forward all arguments after the program name to main().
if __name__ == "__main__":
    main(sys.argv[1:])
README.md CHANGED
@@ -1,3 +1,79 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>
3
+ <b>
4
+ RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping
5
+ </b>
6
+ </h1>
7
+ </div>
8
+
9
+ <div align="center">
10
+
11
+ | [**📑 Paper**](https://arxiv.org/abs/2507.23734) | [**🤗 Model**](https://huggingface.co/Dongming97/AffordanceVLM) | [**🤗 Dataset**](https://huggingface.co/datasets/Dongming97/RAGNet) | [**🖥️ Website**](https://wudongming97.github.io/RAGNet/) |
12
+
13
+ </div>
14
+
15
+
16
+ <p align="center"><img src="./imgs/AffordanceNet.jpg" width="800"/></p>
17
+
18
+
19
+ > **[RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping](https://arxiv.org/abs/2507.23734)**
20
+ >
21
+ > Dongming Wu, Yanping Fu, Saike Huang, Yingfei Liu, Fan Jia, Nian Liu, Feng Dai, Tiancai Wang, Rao Muhammad Anwer, Fahad Shahbaz Khan, Jianbing Shen
22
+
23
+ ## 📝 TL;DR
24
+ - To push forward general robotic grasping, we introduce a large-scale reasoning-based affordance segmentation benchmark, **RAGNet**. It contains 273k images, 180 categories, and 26k reasoning instructions.
25
+ - Furthermore, we propose a comprehensive affordance-based grasping framework, named AffordanceNet, which consists of a VLM (named AffordanceVLM) pre-trained on our massive affordance data and a grasping network that conditions an affordance map to grasp the target.
26
+
27
+ ---
28
+
29
+ ## 📰 News
30
+ - [2025.08] Paper is released at [arXiv](https://arxiv.org/abs/2507.23734).
31
+ - [2025.07] Inference code and the [AffordanceVLM](https://huggingface.co/Dongming97/AffordanceVLM) model are released. Welcome to try it!
32
+ - [2025.06] Paper is accepted by ICCV 2025!
33
+
34
+ ---
35
+
36
+ ## 🚀 Getting Started
37
+
38
+ * [Installation](docs/installation.md)
39
+ * [Download dataset](docs/dataset.md)
40
+ * [Training and evaluation](docs/training_and_evaluation.md)
41
+ * To deploy using Gradio, run the following command:
42
+
43
+ ```bash
44
+ python app.py --version='./exps/AffordanceVLM-7B'
45
+ ```
46
+
47
+
48
+
49
+ ## 📊 Main Results
50
+ ### 🔹 Affordance Segmentation
51
+ | Method | HANDAL gIoU | HANDAL cIoU | HANDAL† gIoU | HANDAL† cIoU | GraspNet seen gIoU | GraspNet seen cIoU | GraspNet novel gIoU | GraspNet novel cIoU | 3DOI gIoU | 3DOI cIoU |
52
+ |--------------------------------------|-------------|-------------|---------------|---------------|----------------------|----------------------|------------------------|------------------------|------------|------------|
53
+ | AffordanceNet | 60.3| 60.8 |60.5|60.3|63.3 |64.0| 45.6 |33.2 | 37.4| 37.4 |
54
+
55
+ ### 🔸 Reasoning-Based Affordance Segmentation
56
+
57
+ | Method | HANDAL (easy) gIoU | HANDAL (easy) cIoU | HANDAL (hard) gIoU | HANDAL (hard) cIoU | 3DOI gIoU | 3DOI cIoU |
58
+ |---------|---------------------|---------------------|---------------------|---------------------|-----------|-----------|
59
+ | AffordanceNet| 58.3| 58.1 | 58.2| 57.8 | 38.1 | 39.4|
60
+
61
+
62
+ ## 📚 Citation
63
+ If you find our work useful, please consider citing:
64
+
65
+ ```bibtex
66
+ @inproceedings{wu2025ragnet,
67
+ title={RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping},
68
+ author={Wu, Dongming and Fu, Yanping and Huang, Saike and Liu, Yingfei and Jia, Fan and Liu, Nian and Dai, Feng and Wang, Tiancai and Anwer, Rao Muhammad and Khan, Fahad Shahbaz and others},
69
+ booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
70
+ pages={11980--11990},
71
+ year={2025}
72
+ }
73
+ ```
74
+
75
+ ## 🙏 Acknowledgements
76
+ We thank the authors who open-sourced the following projects.
77
+ - [LISA](https://github.com/dvlab-research/LISA)
78
+ - [LLaVA](https://github.com/haotian-liu/LLaVA)
79
+ - [SAM](https://github.com/facebookresearch/segment-anything)
app.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import re
4
+ import sys
5
+
6
+ import bleach
7
+ import cv2
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from PIL import Image
13
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
14
+
15
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
16
+ from model.llava import conversation as conversation_lib
17
+ from model.llava.mm_utils import tokenizer_image_token
18
+ from model.segment_anything.utils.transforms import ResizeLongestSide
19
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
20
+ DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
21
+
22
+ from datetime import datetime
23
+
24
+
25
def parse_args(args):
    """Parse CLI options for the AffordanceVLM Gradio demo.

    Args:
        args: Argument list (typically ``sys.argv[1:]``).

    Returns:
        The populated ``argparse.Namespace``.
    """
    ap = argparse.ArgumentParser(description="AffordanceVLM chat")
    # Checkpoint and output locations.
    ap.add_argument("--version", default="./exps/AffordanceVLM-7B")
    ap.add_argument("--vis_save_path", default="./vis_output", type=str)
    # Inference precision.
    ap.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    ap.add_argument("--image_size", default=1024, type=int, help="image size")
    ap.add_argument("--model_max_length", default=512, type=int)
    ap.add_argument("--lora_r", default=8, type=int)
    # Vision backbone and device placement.
    ap.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    ap.add_argument("--local-rank", default=0, type=int, help="node rank")
    # Optional bitsandbytes quantization.
    ap.add_argument("--load_in_8bit", action="store_true", default=False)
    ap.add_argument("--load_in_4bit", action="store_true", default=False)
    # Prompt formatting.
    ap.add_argument("--use_mm_start_end", action="store_true", default=True)
    ap.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return ap.parse_args(args)
53
+
54
+
55
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Standardize an image tensor and zero-pad it to img_size x img_size.

    The content stays anchored at the top-left: padding is appended on the
    bottom and right only.  Expects a (..., H, W) layout with H, W <= img_size.
    """
    # Per-channel standardization with ImageNet-style statistics.
    normalized = (x - pixel_mean) / pixel_std
    # Pad bottom/right up to the square target size.
    height, width = normalized.shape[-2:]
    pad_bottom = img_size - height
    pad_right = img_size - width
    return F.pad(normalized, (0, pad_right, 0, pad_bottom))
70
+
71
# ---------------------------------------------------------------------------
# Script-level setup (runs once at import): tokenizer, model, preprocessors.
# ---------------------------------------------------------------------------
args = parse_args(sys.argv[1:])
os.makedirs(args.vis_save_path, exist_ok=True)

# Create model
tokenizer = AutoTokenizer.from_pretrained(
    args.version,
    cache_dir=None,
    model_max_length=args.model_max_length,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
# Ids of the special [SEG]/[AFF] tokens; presumably already present in the
# fine-tuned checkpoint's vocabulary since no add_tokens() call is made
# here — verify against the checkpoint if loading a different model.
args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

# Map the --precision flag onto a torch dtype (fp32 by default).
torch_dtype = torch.float32
if args.precision == "bf16":
    torch_dtype = torch.bfloat16
elif args.precision == "fp16":
    torch_dtype = torch.half

kwargs = {"torch_dtype": torch_dtype}
if args.load_in_4bit:
    # 4-bit NF4 quantization; the SAM backbone ("visual_model") is skipped.
    kwargs.update(
        {
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        }
    )
elif args.load_in_8bit:
    # 8-bit quantization; again keep the SAM backbone unquantized.
    kwargs.update(
        {
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        }
    )

model = AffordanceVLMForCausalLM.from_pretrained(
    args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, aff_token_idx=args.aff_token_idx, **kwargs
)

# Keep the model config's special token ids in sync with the tokenizer.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Instantiate the CLIP vision tower and cast it to the inference dtype.
model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch_dtype)

if args.precision == "bf16":
    model = model.bfloat16().cuda()
elif (
    args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
):
    # fp16 path: wrap the LLM in a DeepSpeed inference engine.  The vision
    # tower is detached first so kernel injection does not touch it, then
    # re-attached in half precision afterwards.
    vision_tower = model.get_model().get_vision_tower()
    model.model.vision_tower = None
    import deepspeed

    model_engine = deepspeed.init_inference(
        model=model,
        dtype=torch.half,
        replace_with_kernel_inject=True,
        replace_method="auto",
    )
    model = model_engine.module
    model.model.vision_tower = vision_tower.half().cuda()
elif args.precision == "fp32":
    model = model.float().cuda()

# Move the vision tower onto the target GPU.
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(device=args.local_rank)

# Preprocessors: CLIP normalization for the VLM branch and longest-side
# resize for the SAM branch.
clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
transform = ResizeLongestSide(args.image_size)

model.eval()
157
+
158
+
159
+ # Gradio
160
+ examples = [
161
+ [
162
+ "Please segment the affordance map of mug in this image.",
163
+ "/data/AffordanceNet/vis_output/my_workspace.JPG",
164
+ ],
165
+ ]
166
+ output_labels = ["Segmentation Output"]
167
+
168
+ title = "RAGNet: Large-scale Reasoning-based Affordance Segmentation Benchmark towards General Grasping"
169
+
170
+ description = """
171
+ <font size=4>
172
+ This is the online demo of AffordanceVLM. \n
173
+ **Note**: **Different prompts can lead to significantly varied results**. \n
174
+ **Note**: Please try to **standardize** your input text prompts to **avoid ambiguity**, and also pay attention to whether the **punctuations** of the input are correct. \n
175
+ **Note**: Current model is **AffordanceVLM-7B**. \n
176
+ **Usage**: <br>
177
+ To let AffordanceVLM **segment something**, input prompt like: "Can you segment the affordance map of xxx in this image?", "What is the affordance map of xxx in this image?"; <br>
178
+ </font>
179
+ """
180
+
181
+ article = """
182
+ <p style='text-align: center'>
183
+ <a href='https://arxiv.org/abs/2507.23734' target='_blank'>
184
+ Preprint Paper
185
+ </a>
186
+ \n
187
+ <p style='text-align: center'>
188
+ <a href='https://github.com/wudongming97/AffordanceNet' target='_blank'> Github Repo </a></p>
189
+ """
190
+
191
+
192
+ ## to be implemented
193
def inference(input_str, input_image):
    """Run AffordanceVLM on one (instruction, image) pair for the Gradio UI.

    Args:
        input_str: Free-form text instruction typed by the user.
        input_image: Filesystem path of the uploaded image
            (``gr.Image(type="filepath")``).

    Returns:
        (output_image, output_str): the input image with the predicted
        affordance mask blended in red (or an error placeholder image), and
        the model's text reply.
    """
    # Sanitize: strip any HTML/script content from the instruction.
    input_str = bleach.clean(input_str)

    print("input_str: ", input_str, "input_image: ", input_image)

    # Reject instructions containing characters outside a conservative
    # whitelist (letters, spaces and basic punctuation).
    if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
        # Fix: the original `"..." , input_str` built a tuple, not a string.
        output_str = "[Error] Invalid input: " + input_str
        output_image = cv2.imread("./resources/error_happened.png")[:, :, ::-1]
        return output_image, output_str

    # Build the conversation prompt with the image token(s) prepended.
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []

    prompt = input_str
    prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
    if args.use_mm_start_end:
        replace_token = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], "")
    prompt = conv.get_prompt()

    image_np = cv2.imread(input_image)

    # Archive the uploaded image with a timestamped name for later inspection.
    SAVE_DIR = "./gradio_images/"
    os.makedirs(SAVE_DIR, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{timestamp}.png"
    save_path = os.path.join(SAVE_DIR, filename)
    cv2.imwrite(save_path, image_np)

    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # CLIP branch input.
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")[
            "pixel_values"
        ][0]
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image_clip = image_clip.bfloat16()
    elif args.precision == "fp16":
        image_clip = image_clip.half()
    else:
        image_clip = image_clip.float()

    # SAM branch input: longest side resized to args.image_size, then padded.
    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]

    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image = image.bfloat16()
    elif args.precision == "fp16":
        image = image.half()
    else:
        image = image.float()

    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()

    output_ids, pred_masks = model.evaluate(
        image_clip,
        image,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=512,
        tokenizer=tokenizer,
    )
    # Drop image-token placeholders before decoding.
    output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

    text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
    # Fix: collapse double spaces (the original `.replace(" ", " ")` was a
    # no-op) and keep only the assistant's reply.
    text_output = text_output.replace("\n", "").replace("  ", " ")
    text_output = text_output.split("ASSISTANT: ")[-1].replace("</s>", "")

    print("text_output: ", text_output)
    save_img = None
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue

        pred_mask = pred_mask.detach().cpu().numpy()[0]
        pred_mask = pred_mask > 0

        # Blend a 50/50 red overlay onto the masked region.
        save_img = image_np.copy()
        save_img[pred_mask] = (
            image_np * 0.5
            + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
        )[pred_mask]

    # Fix: the prefix was misspelled "ASSITANT: " in the original.
    output_str = "ASSISTANT: " + text_output
    if save_img is not None:
        output_image = save_img
    else:
        # Model produced no segmentation output; show a placeholder.
        output_image = cv2.imread("./resources/no_seg_out.png")[:, :, ::-1]
    return output_image, output_str
308
+
309
+
310
# Wire the inference function into a two-input / two-output Gradio UI.
demo = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
        gr.Image(type="filepath", label="Input Image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Affordance Output"),
        gr.Textbox(lines=1, placeholder=None, label="Text Output"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="auto",
)

# Queue requests so concurrent users are served sequentially.
demo.queue()
# demo.launch()
# Listen on all interfaces on port 3200 (exposes the demo to the network).
demo.launch(server_name="0.0.0.0", server_port=3200)
batch_generate.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Batch generate affordance masks for all four LIBERO subsets sequentially.

# Root of the per-frame source datasets and root for the generated masks.
SRC_ROOT="/gemini/space/wrz/libero_per_frame"
TGT_ROOT="/gemini/space/wrz/ragnet_results"

# Process each subset on GPU 0; input directories follow the
# "<subset>_converted" naming convention produced by the conversion script.
for ds in libero_object libero_goal libero_spatial libero_10; do
    echo "========== Processing ${ds} =========="
    CUDA_VISIBLE_DEVICES=0 python batch_generate.py \
        --data_dir "${SRC_ROOT}/${ds}_converted" \
        --save_dir "${TGT_ROOT}/${ds}"
    echo "========== ${ds} done =========="
    echo
done
batch_generate_prefill_accelerate.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch affordance mask generation for per-step datasets.
3
+
4
+ Reads a per-step dataset (converted by convert_lerobot_to_perstep.py) and
5
+ generates affordance masks for every image_primary.jpg and image_wrist.jpg
6
+ using AffordanceVLM.
7
+
8
+ Input structure:
9
+ {data_dir}/
10
+ ├── meta_info.h5
11
+ └── episodes/
12
+ └── {episode_id:06d}/
13
+ └── steps/
14
+ └── {step_id:04d}/
15
+ ├── other.h5 # language_instruction
16
+ ├── image_primary.jpg
17
+ └── image_wrist.jpg
18
+
19
+ Output structure:
20
+ {save_dir}/
21
+ └── episodes/
22
+ └── {episode_id:06d}/
23
+ └── steps/
24
+ └── {step_id:04d}/
25
+ ├── image_primary_mask.png # binary 0/255
26
+ └── image_wrist_mask.png
27
+
28
+ Usage:
29
+ CUDA_VISIBLE_DEVICES=1 python batch_generate_prefill_accelerate.py \
30
+ --data_dir /gemini/space/wrz/libero_per_frame/libero_spatial_converted \
31
+ --save_dir /gemini/space/wrz/ragnet_results/libero_spatial
32
+ """
33
+
34
+ import argparse
35
+ import os
36
+ import sys
37
+ from pathlib import Path
38
+
39
+ import cv2
40
+ import h5py
41
+ import numpy as np
42
+ import torch
43
+ import torch.nn.functional as F
44
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
45
+
46
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
47
+ from model.llava import conversation as conversation_lib
48
+ from model.llava.mm_utils import tokenizer_image_token
49
+ from model.segment_anything.utils.transforms import ResizeLongestSide
50
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
51
+ DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
52
+
53
+
54
def parse_args(args):
    """Parse CLI options for batch affordance-mask generation.

    Args:
        args: Argument list (typically ``sys.argv[1:]``).

    Returns:
        The populated ``argparse.Namespace``.
    """
    ap = argparse.ArgumentParser(
        description="Batch affordance mask generation for per-step datasets"
    )
    # Model arguments (same as chat.py)
    ap.add_argument("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    ap.add_argument(
        "--precision", default="bf16", type=str,
        choices=["fp32", "bf16", "fp16"],
    )
    ap.add_argument("--image_size", default=1024, type=int)
    ap.add_argument("--model_max_length", default=512, type=int)
    ap.add_argument("--lora_r", default=8, type=int)
    ap.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    ap.add_argument("--local-rank", default=0, type=int)
    ap.add_argument("--load_in_8bit", action="store_true", default=False)
    ap.add_argument("--load_in_4bit", action="store_true", default=False)
    ap.add_argument("--use_mm_start_end", action="store_true", default=True)
    ap.add_argument(
        "--conv_type", default="llava_v1", type=str,
        choices=["llava_v1", "llava_llama_2"],
    )

    # Batch processing arguments
    ap.add_argument("--data_dir", type=str, required=True,
                    help="Root of per-step dataset (contains episodes/)")
    ap.add_argument("--save_dir", type=str, required=True,
                    help="Output directory for masks")
    ap.add_argument("--prompt_template", type=str,
                    default="{}",
                    help="Template wrapping language_instruction. Use {} as placeholder.")
    # Alternative templates tried during development:
    #   "Segment the most suitable manipulation region on the single target
    #    object for the task '{}'."
    #   "Segment the affordance map for the task '{}' in this image."
    #   "Segment the affordance map of the single target object for the task
    #    '{}' in this image."
    #   "Given the task instruction '{}', what is the affordance map of the
    #    target object in this image? Please output segmentation mask."
    #   "Given the task instruction '{}', what is the affordance map of the
    #    single target object in this image? There is only one target object.
    #    Please output segmentation mask."
    ap.add_argument("--start_episode", type=int, default=None,
                    help="First episode index to process (inclusive)")
    ap.add_argument("--end_episode", type=int, default=None,
                    help="Last episode index to process (exclusive)")
    return ap.parse_args(args)
96
+
97
+
98
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and zero-pad the image to an img_size square.

    Args:
        x: image tensor with spatial dims last, i.e. (..., H, W).
        pixel_mean: per-channel mean, broadcast over H and W.
        pixel_std: per-channel std, broadcast over H and W.
        img_size: target side length; H and W are padded up to it.

    Returns:
        Normalized tensor padded on the bottom/right to (..., img_size, img_size).
    """
    normalized = (x - pixel_mean) / pixel_std
    height, width = normalized.shape[-2:]
    # Pad only bottom and right so pixel coordinates stay aligned to the origin.
    pad_bottom = img_size - height
    pad_right = img_size - width
    return F.pad(normalized, (0, pad_right, 0, pad_bottom))
111
+
112
+
113
def load_model(args):
    """Load tokenizer and model, identical to chat.py.

    Registers the [SEG]/[AFF] special tokens (storing their ids on ``args``),
    builds the AffordanceVLM model with the requested precision/quantization,
    moves the vision tower to the local device, and returns everything needed
    for inference.

    Returns:
        Tuple of (model, tokenizer, clip_image_processor, transform).
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the segmentation/affordance trigger tokens and record their ids
    # so the model can find them in the token stream.
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # 4-bit NF4 quantization; the SAM visual model is excluded from quantization.
        kwargs.update({
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        })
    elif args.load_in_8bit:
        kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **kwargs,
    )

    # Sync special-token ids between tokenizer and model config.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
        # DeepSpeed kernel injection cannot wrap the vision tower: detach it,
        # wrap the LLM, then re-attach the tower in fp16 on the GPU.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed
        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    # SAM-style resize: longest side scaled to args.image_size.
    transform = ResizeLongestSide(args.image_size)

    model.eval()
    return model, tokenizer, clip_image_processor, transform
198
+
199
+
200
def build_prompt(text: str, args) -> str:
    """Build the full conversation prompt from a text query.

    Wraps *text* into the conversation template selected by ``args.conv_type``,
    prefixing the image token and pre-filling the assistant turn with "[AFF]."
    for single-pass (prefill) inference.
    """
    conv = conversation_lib.conv_templates[args.conv_type].copy()
    conv.messages = []

    user_msg = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + text
    if args.use_mm_start_end:
        # Surround the image token with explicit start/end markers.
        wrapped = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        user_msg = user_msg.replace(DEFAULT_IMAGE_TOKEN, wrapped)

    conv.append_message(conv.roles[0], user_msg)
    # Prefill the assistant response so no autoregressive decoding is needed.
    conv.append_message(conv.roles[1], "[AFF].")
    return conv.get_prompt()
215
+
216
+
217
def infer_single_image(
    image_path: str,
    prompt_str: str,
    model,
    tokenizer,
    clip_image_processor,
    transform,
    args,
) -> "np.ndarray | None":
    """Run inference on a single image. Returns binary mask (H, W) uint8 0/255 or None.

    Returns None when the image cannot be read or when the model produces no
    mask for any prediction slot.
    """
    image_np = cv2.imread(image_path)
    if image_np is None:
        print(f" [WARNING] Cannot read image: {image_path}")
        return None
    image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # CLIP preprocessing (for the LLaVA vision tower)
    image_clip = (
        clip_image_processor.preprocess(image_np, return_tensors="pt")["pixel_values"][0]
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image_clip = image_clip.bfloat16()
    elif args.precision == "fp16":
        image_clip = image_clip.half()
    else:
        image_clip = image_clip.float()

    # SAM preprocessing: resize longest side, normalize, pad to a square
    image = transform.apply_image(image_np)
    resize_list = [image.shape[:2]]
    image = (
        preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
        .unsqueeze(0)
        .cuda()
    )
    if args.precision == "bf16":
        image = image.bfloat16()
    elif args.precision == "fp16":
        image = image.half()
    else:
        image = image.float()

    # Tokenize the prompt (image placeholder handled by tokenizer_image_token)
    input_ids = tokenizer_image_token(prompt_str, tokenizer, return_tensors="pt")
    input_ids = input_ids.unsqueeze(0).cuda()
    attention_masks = input_ids.ne(tokenizer.pad_token_id)

    # Prefill inference (single forward pass instead of autoregressive generation).
    # masks_list/label_list are dummy targets required by the training-style
    # forward signature; they do not influence the predicted masks.
    h, w = original_size_list[0]
    labels = input_ids.clone()
    offset = torch.LongTensor([0, 1]).cuda()
    masks_list = [torch.zeros(1, h, w).float().cuda()]
    label_list = [torch.zeros(h, w).long().cuda()]

    with torch.no_grad():
        output_dict = model(
            images=image,
            images_clip=image_clip,
            input_ids=input_ids,
            labels=labels,
            attention_masks=attention_masks,
            offset=offset,
            masks_list=masks_list,
            label_list=label_list,
            resize_list=resize_list,
            inference=True,
        )

    pred_masks = output_dict["pred_masks"]

    # Merge all predicted masks via union (logical OR); threshold logits at 0.
    merged = np.zeros((h, w), dtype=bool)
    has_mask = False
    for pred_mask in pred_masks:
        if pred_mask.shape[0] == 0:
            continue
        mask_np = pred_mask.detach().cpu().numpy()[0]  # (H, W)
        merged |= (mask_np > 0)
        has_mask = True

    if not has_mask:
        return None

    return (merged.astype(np.uint8) * 255)
304
+
305
+
306
def read_language_instruction(h5_path: str) -> str:
    """Read the language_instruction field from an other.h5 file.

    Decodes a bytes value as UTF-8; any other value is coerced with str().
    """
    with h5py.File(h5_path, "r") as h5_file:
        raw = h5_file["language_instruction"][()]
    return raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
313
+
314
+
315
def main(args):
    """Batch-generate affordance masks for every step of a per-step dataset.

    Walks ``data_dir/episodes/<ep>/steps/<step>/``, reads the language
    instruction from each step's other.h5, runs prefill inference on both
    camera images, and writes binary masks mirroring the input layout under
    ``save_dir``.
    """
    args = parse_args(args)
    data_dir = Path(args.data_dir)
    save_dir = Path(args.save_dir)

    episodes_dir = data_dir / "episodes"
    if not episodes_dir.is_dir():
        print(f"Error: episodes directory not found at {episodes_dir}")
        sys.exit(1)

    # Collect and sort episode directories (names are zero-padded, so
    # lexicographic order equals numeric order).
    episode_dirs = sorted(
        [d for d in episodes_dir.iterdir() if d.is_dir()],
        key=lambda p: p.name,
    )

    # Filter by episode range [start, end)
    if args.start_episode is not None or args.end_episode is not None:
        start = args.start_episode if args.start_episode is not None else 0
        end = args.end_episode if args.end_episode is not None else len(episode_dirs)
        episode_dirs = [
            d for d in episode_dirs
            if start <= int(d.name) < end
        ]

    print(f"Data dir : {data_dir}")
    print(f"Save dir : {save_dir}")
    print(f"Episodes : {len(episode_dirs)}")
    print(f"Prompt : {args.prompt_template}")
    print()

    # Load model once; reused for every image.
    print("Loading model...")
    model, tokenizer, clip_image_processor, transform = load_model(args)
    print("Model loaded.\n")

    total_steps = 0
    empty_mask_count = 0

    for ep_dir in episode_dirs:
        episode_id = ep_dir.name  # e.g. "000000"
        steps_dir = ep_dir / "steps"
        if not steps_dir.is_dir():
            print(f" [WARNING] No steps/ in {ep_dir}, skipping.")
            continue

        step_dirs = sorted(
            [d for d in steps_dir.iterdir() if d.is_dir()],
            key=lambda p: p.name,
        )

        for step_dir in step_dirs:
            step_id = step_dir.name  # e.g. "0000"

            # Read language instruction for this step
            other_h5 = step_dir / "other.h5"
            if not other_h5.exists():
                print(f" [WARNING] Missing other.h5 in {step_dir}, skipping.")
                continue
            language_instruction = read_language_instruction(str(other_h5))
            # debug
            # print(language_instruction)

            # Build prompt (same prompt is reused for both cameras of the step)
            query_text = args.prompt_template.format(language_instruction)
            prompt_str = build_prompt(query_text, args)

            # Output directory (same structure as input: episodes/{episode_id}/steps/{step_id}/)
            out_dir = save_dir / "episodes" / episode_id / "steps" / step_id
            out_dir.mkdir(parents=True, exist_ok=True)

            # Process both cameras
            for cam_name in ("image_primary", "image_wrist"):
                img_path = step_dir / f"{cam_name}.jpg"
                mask_path = out_dir / f"{cam_name}_mask.png"

                if not img_path.exists():
                    print(f" [WARNING] Missing {img_path}, skipping.")
                    continue

                mask = infer_single_image(
                    str(img_path), prompt_str,
                    model, tokenizer, clip_image_processor, transform, args,
                )

                if mask is None:
                    # Save blank mask and warn.
                    # NOTE(review): if infer_single_image returned None because
                    # the image itself was unreadable, this second imread also
                    # returns None and `.shape` raises AttributeError — confirm
                    # whether unreadable images should be skipped here instead.
                    h, w = cv2.imread(str(img_path)).shape[:2]
                    mask = np.zeros((h, w), dtype=np.uint8)
                    empty_mask_count += 1

                cv2.imwrite(str(mask_path), mask)

            total_steps += 1
            if total_steps % 50 == 0:
                print(f" Processed {total_steps} steps (episode {episode_id}, step {step_id})")

        print(f"Episode {episode_id} done ({len(step_dirs)} steps)")

    print(f"\nFinished. {total_steps} steps processed, {empty_mask_count} empty masks.")


if __name__ == "__main__":
    main(sys.argv[1:])
chat.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
10
+
11
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
12
+ from model.llava import conversation as conversation_lib
13
+ from model.llava.mm_utils import tokenizer_image_token
14
+ from model.segment_anything.utils.transforms import ResizeLongestSide
15
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
16
+ DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
17
+
18
+
19
def parse_args(args):
    """Parse command-line options for the interactive LISA chat demo."""
    parser = argparse.ArgumentParser(description="LISA chat")
    add = parser.add_argument
    add("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    add("--vis_save_path", default="./vis_output", type=str)
    add("--precision", default="bf16", type=str,
        choices=["fp32", "bf16", "fp16"], help="precision for inference")
    add("--image_size", default=1024, type=int, help="image size")
    add("--model_max_length", default=512, type=int)
    add("--lora_r", default=8, type=int)
    add("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    add("--local-rank", default=0, type=int, help="node rank")
    add("--load_in_8bit", action="store_true", default=False)
    add("--load_in_4bit", action="store_true", default=False)
    add("--use_mm_start_end", action="store_true", default=True)
    add("--conv_type", default="llava_v1", type=str,
        choices=["llava_v1", "llava_llama_2"])
    return parser.parse_args(args)
47
+
48
+
49
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input.

    Subtracts the per-channel mean, divides by the per-channel std, then
    zero-pads the bottom/right so both spatial dims equal ``img_size``.
    """
    normed = (x - pixel_mean) / pixel_std
    height, width = normed.shape[-2:]
    # Bottom/right padding keeps the image anchored at the top-left corner.
    return F.pad(normed, (0, img_size - width, 0, img_size - height))
64
+
65
+
66
def main(args):
    """Interactive chat loop: read a prompt and an image path, run generation,
    and write the predicted masks plus a red-overlay visualization to
    ``args.vis_save_path``.
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model: tokenizer first, so [SEG]/[AFF] token ids can be passed in.
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    num_added_tokens = tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # 4-bit NF4 quantization; SAM visual model is kept un-quantized.
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "load_in_4bit": True,
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    llm_int8_skip_modules=["visual_model"],
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version, low_cpu_mem_usage=True, vision_tower=args.vision_tower, seg_token_idx=args.seg_token_idx, aff_token_idx=args.aff_token_idx, **kwargs
    )

    # Sync special-token ids between tokenizer and model config.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif (
        args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit)
    ):
        # DeepSpeed kernel injection cannot wrap the vision tower: detach it,
        # wrap the LLM, then re-attach the tower in fp16.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed

        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()

    while True:
        # Fresh conversation per turn — history is not carried over.
        conv = conversation_lib.conv_templates[args.conv_type].copy()
        conv.messages = []

        prompt = input("Please input your prompt: ")
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
        if args.use_mm_start_end:
            replace_token = (
                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        # Empty assistant turn: the model generates the response autoregressively.
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]

        # CLIP preprocessing (LLaVA vision tower input)
        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image_clip = image_clip.bfloat16()
        elif args.precision == "fp16":
            image_clip = image_clip.half()
        else:
            image_clip = image_clip.float()

        # SAM preprocessing (resize longest side + normalize + pad)
        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]

        image = (
            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image = image.bfloat16()
        elif args.precision == "fp16":
            image = image.half()
        else:
            image = image.float()

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()

        # Autoregressive generation plus mask decoding.
        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )
        # Drop image placeholder tokens before decoding to text.
        output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]

        text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
        # NOTE(review): this replace is a no-op as written; likely intended to
        # collapse double spaces ("  " -> " ") — confirm against upstream LISA.
        text_output = text_output.replace("\n", "").replace(" ", " ")
        print("text_output: ", text_output)

        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue

            # Threshold logits at 0 to get a boolean mask.
            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0

            save_path = "{}/{}_mask_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            # Scale to 100 so the binary mask is visible in an 8-bit image.
            cv2.imwrite(save_path, pred_mask * 100)
            print("{} has been saved.".format(save_path))

            save_path = "{}/{}_masked_img_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            # Blend a red overlay (50/50) onto masked pixels only.
            save_img = image_np.copy()
            save_img[pred_mask] = (
                image_np * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
            save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_path, save_img)
            print("{} has been saved.".format(save_path))


if __name__ == "__main__":
    main(sys.argv[1:])
chat_prefill.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Interactive affordance mask generation using prefill mode (single forward pass).
3
+
4
+ Same interactive workflow as chat.py, but uses prefill inference instead of
5
+ autoregressive generation. The assistant response "[AFF]." is pre-filled in the
6
+ prompt, so the model only does one forward pass to extract mask embeddings.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import sys
12
+
13
+ import cv2
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from transformers import AutoTokenizer, BitsAndBytesConfig, CLIPImageProcessor
18
+
19
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
20
+ from model.llava import conversation as conversation_lib
21
+ from model.llava.mm_utils import tokenizer_image_token
22
+ from model.segment_anything.utils.transforms import ResizeLongestSide
23
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
24
+ DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
25
+
26
+
27
def parse_args(args):
    """Parse command-line options for the prefill-mode chat demo."""
    parser = argparse.ArgumentParser(description="AffordanceVLM chat (prefill mode)")
    add = parser.add_argument
    add("--version", default="/gemini/code/AffordanceNet/ckpts/AffordanceVLM-7B")
    add("--vis_save_path", default="./vis_output_prefill", type=str)
    add("--precision", default="bf16", type=str, choices=["fp32", "bf16", "fp16"])
    add("--image_size", default=1024, type=int)
    add("--model_max_length", default=512, type=int)
    add("--lora_r", default=8, type=int)
    add("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    add("--local-rank", default=0, type=int)
    add("--load_in_8bit", action="store_true", default=False)
    add("--load_in_4bit", action="store_true", default=False)
    add("--use_mm_start_end", action="store_true", default=True)
    add("--conv_type", default="llava_v1", type=str,
        choices=["llava_v1", "llava_llama_2"])
    add("--prompt_template", type=str,
        default="Segment the most suitable manipulation region on the single target object for the task '{}'.",
        help="Template wrapping language_instruction. Use {} as placeholder.")
    # Alternative templates tried during development:
    # Segment the affordance map for the task '{}' in this image.
    # Segment the affordance map of the single target object for the task '{}' in this image.
    # Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask.
    # Given the task instruction '{}', what is the affordance map of the single target object in this image? There is only one target object. Please output segmentation mask.
    return parser.parse_args(args)
56
+
57
+
58
def preprocess(
    x,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
) -> torch.Tensor:
    """Normalize pixel values and pad to a square input of side img_size."""
    standardized = (x - pixel_mean) / pixel_std
    h_cur, w_cur = standardized.shape[-2:]
    # Zero-pad bottom and right edges up to the square target size.
    pad_spec = (0, img_size - w_cur, 0, img_size - h_cur)
    return F.pad(standardized, pad_spec)
71
+
72
+
73
def main(args):
    """Interactive prefill-mode loop: read a task prompt and an image path,
    run a single forward pass (assistant turn pre-filled with "[AFF]."), and
    write the predicted masks plus a red-overlay visualization to
    ``args.vis_save_path``.
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model: tokenizer first, so [SEG]/[AFF] token ids can be passed in.
    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half

    kwargs = {"torch_dtype": torch_dtype}
    if args.load_in_4bit:
        # 4-bit NF4 quantization; SAM visual model stays un-quantized.
        kwargs.update({
            "torch_dtype": torch.half,
            "load_in_4bit": True,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                llm_int8_skip_modules=["visual_model"],
            ),
        })
    elif args.load_in_8bit:
        kwargs.update({
            "torch_dtype": torch.half,
            "quantization_config": BitsAndBytesConfig(
                llm_int8_skip_modules=["visual_model"],
                load_in_8bit=True,
            ),
        })

    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version,
        low_cpu_mem_usage=True,
        vision_tower=args.vision_tower,
        seg_token_idx=args.seg_token_idx,
        aff_token_idx=args.aff_token_idx,
        **kwargs,
    )

    # Sync special-token ids between tokenizer and model config.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)

    if args.precision == "bf16":
        model = model.bfloat16().cuda()
    elif args.precision == "fp16" and (not args.load_in_4bit) and (not args.load_in_8bit):
        # DeepSpeed kernel injection cannot wrap the vision tower: detach it,
        # wrap the LLM, then re-attach the tower in fp16.
        vision_tower = model.get_model().get_vision_tower()
        model.model.vision_tower = None
        import deepspeed
        model_engine = deepspeed.init_inference(
            model=model,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            replace_method="auto",
        )
        model = model_engine.module
        model.model.vision_tower = vision_tower.half().cuda()
    elif args.precision == "fp32":
        model = model.float().cuda()

    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(device=args.local_rank)

    clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
    transform = ResizeLongestSide(args.image_size)

    model.eval()

    # debug (unused; kept from development)
    template = "Given the task instruction '{}', what is the affordance map of the target object in this image? Please output segmentation mask."

    while True:
        # Fresh conversation per turn — history is not carried over.
        conv = conversation_lib.conv_templates[args.conv_type].copy()
        conv.messages = []

        prompt = input("Please input your prompt: ")
        # Apply the prompt template to the raw instruction.
        prompt = args.prompt_template.format(prompt)

        prompt = DEFAULT_IMAGE_TOKEN + "\n" + "You are an embodied robot. " + prompt
        if args.use_mm_start_end:
            replace_token = (
                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        # Prefill the assistant turn so only one forward pass is needed.
        conv.append_message(conv.roles[1], "[AFF].")
        prompt = conv.get_prompt()

        image_path = input("Please input the image path: ")
        if not os.path.exists(image_path):
            print("File not found in {}".format(image_path))
            continue

        image_np = cv2.imread(image_path)
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        original_size_list = [image_np.shape[:2]]
        h, w = original_size_list[0]

        # CLIP preprocessing (LLaVA vision tower input)
        image_clip = (
            clip_image_processor.preprocess(image_np, return_tensors="pt")[
                "pixel_values"
            ][0]
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image_clip = image_clip.bfloat16()
        elif args.precision == "fp16":
            image_clip = image_clip.half()
        else:
            image_clip = image_clip.float()

        # SAM preprocessing (resize longest side + normalize + pad)
        image = transform.apply_image(image_np)
        resize_list = [image.shape[:2]]

        image = (
            preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())
            .unsqueeze(0)
            .cuda()
        )
        if args.precision == "bf16":
            image = image.bfloat16()
        elif args.precision == "fp16":
            image = image.half()
        else:
            image = image.float()

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.unsqueeze(0).cuda()
        attention_masks = input_ids.ne(tokenizer.pad_token_id)

        # Print the full prompt text (prefill mode has no generated text)
        # debug
        text_ids = input_ids[0][input_ids[0] != IMAGE_TOKEN_INDEX]
        text_output = tokenizer.decode(text_ids, skip_special_tokens=False)
        # NOTE(review): this replace is a no-op as written; likely intended to
        # collapse double spaces ("  " -> " ") — confirm against chat.py.
        text_output = text_output.replace("\n", "").replace(" ", " ")
        print("text_output: ", text_output)

        # Prefill inference: dummy labels/masks satisfy the training-style
        # forward signature; they do not affect the predicted masks.
        labels = input_ids.clone()
        offset = torch.LongTensor([0, 1]).cuda()
        masks_list = [torch.zeros(1, h, w).float().cuda()]
        label_list = [torch.zeros(h, w).long().cuda()]

        with torch.no_grad():
            output_dict = model(
                images=image,
                images_clip=image_clip,
                input_ids=input_ids,
                labels=labels,
                attention_masks=attention_masks,
                offset=offset,
                masks_list=masks_list,
                label_list=label_list,
                resize_list=resize_list,
                inference=True,
            )

        pred_masks = output_dict["pred_masks"]

        for i, pred_mask in enumerate(pred_masks):
            if pred_mask.shape[0] == 0:
                continue

            # Threshold logits at 0 to get a boolean mask.
            pred_mask = pred_mask.detach().cpu().numpy()[0]
            pred_mask = pred_mask > 0

            save_path = "{}/{}_mask_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            # Scale to 100 so the binary mask is visible in an 8-bit image.
            cv2.imwrite(save_path, pred_mask * 100)
            print("{} has been saved.".format(save_path))

            save_path = "{}/{}_masked_img_{}.jpg".format(
                args.vis_save_path, image_path.split("/")[-1].split(".")[0], i
            )
            # Blend a red overlay (50/50) onto masked pixels only.
            save_img = image_np.copy()
            save_img[pred_mask] = (
                image_np * 0.5
                + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
            )[pred_mask]
            save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite(save_path, save_img)
            print("{} has been saved.".format(save_path))


if __name__ == "__main__":
    main(sys.argv[1:])
ckpts/AffordanceVLM-7B/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
ckpts/AffordanceVLM-7B/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
ckpts/AffordanceVLM-7B/added_tokens.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "<im_end>": 32002,
3
+ "<im_patch>": 32000,
4
+ "<im_start>": 32001,
5
+ "[AFF]": 32004,
6
+ "[SEG]": 32003
7
+ }
ckpts/AffordanceVLM-7B/config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./LLaVA/LLaVA-Lightning-7B-v1-1",
3
+ "architectures": [
4
+ "AffordanceVLMForCausalLM"
5
+ ],
6
+ "bos_token_id": 1,
7
+ "eos_token_id": 2,
8
+ "freeze_mm_mlp_adapter": true,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 4096,
11
+ "image_aspect_ratio": "square",
12
+ "image_grid_pinpoints": null,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 11008,
15
+ "max_position_embeddings": 2048,
16
+ "max_sequence_length": 2048,
17
+ "mm_hidden_size": 1024,
18
+ "mm_use_im_patch_token": false,
19
+ "mm_use_im_start_end": true,
20
+ "mm_vision_select_feature": "patch",
21
+ "mm_vision_select_layer": -2,
22
+ "mm_vision_tower": "openai/clip-vit-large-patch14",
23
+ "model_type": "llava",
24
+ "num_attention_heads": 32,
25
+ "num_hidden_layers": 32,
26
+ "num_key_value_heads": 32,
27
+ "out_dim": 256,
28
+ "pad_token_id": 0,
29
+ "pretrain_mm_mlp_adapter": null,
30
+ "pretraining_tp": 1,
31
+ "rms_norm_eps": 1e-06,
32
+ "rope_scaling": null,
33
+ "tie_word_embeddings": false,
34
+ "torch_dtype": "bfloat16",
35
+ "train_mask_decoder": true,
36
+ "transformers_version": "4.31.0",
37
+ "tune_mm_mlp_adapter": false,
38
+ "use_cache": false,
39
+ "use_mm_proj": true,
40
+ "vision_tower": "openai/clip-vit-large-patch14",
41
+ "vocab_size": 32005
42
+ }
ckpts/AffordanceVLM-7B/eval_result.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ dataset: handal_all, giou: 0.60872483253479, ciou: 0.6054294109344482
ckpts/AffordanceVLM-7B/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.31.0"
7
+ }
ckpts/AffordanceVLM-7B/pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efdb3ff9accdd733412d083c770ba34ae1c6745b28e2bae07d3546dc9356bfec
3
+ size 9976675518
ckpts/AffordanceVLM-7B/pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7259eabdd3c03be21d45a328177ac3e46e1385cbc5ff2d757cd8bb70dec81ae9
3
+ size 6144654233
ckpts/AffordanceVLM-7B/pytorch_model.bin.index.json ADDED
@@ -0,0 +1,930 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 16121002176
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00002-of-00002.bin",
7
+ "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
8
+ "model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
9
+ "model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
11
+ "model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
12
+ "model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
13
+ "model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
14
+ "model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
15
+ "model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
16
+ "model.layers.0.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
17
+ "model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
19
+ "model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
20
+ "model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
21
+ "model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
22
+ "model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
23
+ "model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
24
+ "model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
25
+ "model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
26
+ "model.layers.1.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
27
+ "model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
28
+ "model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
29
+ "model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
30
+ "model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
31
+ "model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
32
+ "model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
33
+ "model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
35
+ "model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
36
+ "model.layers.10.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
37
+ "model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
38
+ "model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
39
+ "model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
40
+ "model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
41
+ "model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
42
+ "model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
43
+ "model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
44
+ "model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
46
+ "model.layers.11.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
47
+ "model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
48
+ "model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
49
+ "model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
50
+ "model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
51
+ "model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
52
+ "model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
53
+ "model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
54
+ "model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
55
+ "model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
56
+ "model.layers.12.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
57
+ "model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
58
+ "model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
59
+ "model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
60
+ "model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
61
+ "model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
62
+ "model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
63
+ "model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
64
+ "model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
65
+ "model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
66
+ "model.layers.13.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
67
+ "model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
68
+ "model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
69
+ "model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
70
+ "model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
71
+ "model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
72
+ "model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
73
+ "model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
75
+ "model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
76
+ "model.layers.14.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
77
+ "model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
78
+ "model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
79
+ "model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
80
+ "model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
81
+ "model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
83
+ "model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
84
+ "model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
86
+ "model.layers.15.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
87
+ "model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
88
+ "model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
89
+ "model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
90
+ "model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
91
+ "model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
92
+ "model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
93
+ "model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
94
+ "model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
95
+ "model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
96
+ "model.layers.16.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
97
+ "model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
98
+ "model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
99
+ "model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
100
+ "model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
101
+ "model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
102
+ "model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
103
+ "model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
104
+ "model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
105
+ "model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "model.layers.17.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
107
+ "model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
108
+ "model.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
109
+ "model.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
110
+ "model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
111
+ "model.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
112
+ "model.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
113
+ "model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
115
+ "model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
116
+ "model.layers.18.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
117
+ "model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
118
+ "model.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
119
+ "model.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
120
+ "model.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
121
+ "model.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
122
+ "model.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
123
+ "model.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
124
+ "model.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
125
+ "model.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
126
+ "model.layers.19.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
127
+ "model.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
128
+ "model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
129
+ "model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
130
+ "model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
131
+ "model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
132
+ "model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
133
+ "model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
134
+ "model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
135
+ "model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
136
+ "model.layers.2.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
137
+ "model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
138
+ "model.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
139
+ "model.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
140
+ "model.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
141
+ "model.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
142
+ "model.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
143
+ "model.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
144
+ "model.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
145
+ "model.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
146
+ "model.layers.20.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
147
+ "model.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
148
+ "model.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
149
+ "model.layers.21.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
150
+ "model.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
151
+ "model.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
152
+ "model.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
153
+ "model.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
154
+ "model.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
155
+ "model.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
156
+ "model.layers.21.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
157
+ "model.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
158
+ "model.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
159
+ "model.layers.22.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
160
+ "model.layers.22.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
161
+ "model.layers.22.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
162
+ "model.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
163
+ "model.layers.22.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
164
+ "model.layers.22.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
165
+ "model.layers.22.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
166
+ "model.layers.22.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
167
+ "model.layers.22.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
168
+ "model.layers.23.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
169
+ "model.layers.23.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
170
+ "model.layers.23.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
171
+ "model.layers.23.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
172
+ "model.layers.23.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
173
+ "model.layers.23.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
174
+ "model.layers.23.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
175
+ "model.layers.23.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
176
+ "model.layers.23.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
177
+ "model.layers.23.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
178
+ "model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
179
+ "model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
180
+ "model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
182
+ "model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
183
+ "model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
184
+ "model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
185
+ "model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "model.layers.24.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
187
+ "model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
188
+ "model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
189
+ "model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
190
+ "model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
191
+ "model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
192
+ "model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
193
+ "model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
194
+ "model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
195
+ "model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
196
+ "model.layers.25.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
197
+ "model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
198
+ "model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
199
+ "model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
200
+ "model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
201
+ "model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
203
+ "model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
204
+ "model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
206
+ "model.layers.26.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
207
+ "model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
208
+ "model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
209
+ "model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
210
+ "model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
211
+ "model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
212
+ "model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
213
+ "model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
214
+ "model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
215
+ "model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
216
+ "model.layers.27.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
217
+ "model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
218
+ "model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
219
+ "model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
220
+ "model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
221
+ "model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
222
+ "model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
223
+ "model.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
224
+ "model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
225
+ "model.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
226
+ "model.layers.28.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
227
+ "model.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
228
+ "model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
229
+ "model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
230
+ "model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
231
+ "model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
232
+ "model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
233
+ "model.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
234
+ "model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
235
+ "model.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
236
+ "model.layers.29.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
237
+ "model.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
238
+ "model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
239
+ "model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
240
+ "model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
241
+ "model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
242
+ "model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
243
+ "model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
244
+ "model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
245
+ "model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
246
+ "model.layers.3.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
247
+ "model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
248
+ "model.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
249
+ "model.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
250
+ "model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
251
+ "model.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
252
+ "model.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
253
+ "model.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
254
+ "model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
255
+ "model.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
256
+ "model.layers.30.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
257
+ "model.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
258
+ "model.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
259
+ "model.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
260
+ "model.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
261
+ "model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
262
+ "model.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
263
+ "model.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
264
+ "model.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
265
+ "model.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
266
+ "model.layers.31.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
267
+ "model.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
268
+ "model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
269
+ "model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
270
+ "model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
271
+ "model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
272
+ "model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
273
+ "model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
274
+ "model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
275
+ "model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
276
+ "model.layers.4.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
277
+ "model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
278
+ "model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
279
+ "model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
280
+ "model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
281
+ "model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
282
+ "model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
283
+ "model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
284
+ "model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
285
+ "model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
286
+ "model.layers.5.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
287
+ "model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
288
+ "model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
289
+ "model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
290
+ "model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
291
+ "model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
292
+ "model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
293
+ "model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
294
+ "model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
295
+ "model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
296
+ "model.layers.6.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
297
+ "model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
298
+ "model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
299
+ "model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
300
+ "model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
301
+ "model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
302
+ "model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
303
+ "model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
304
+ "model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
305
+ "model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
306
+ "model.layers.7.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
307
+ "model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
308
+ "model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
309
+ "model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
310
+ "model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
311
+ "model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
312
+ "model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
313
+ "model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
314
+ "model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
315
+ "model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
316
+ "model.layers.8.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
317
+ "model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
318
+ "model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
319
+ "model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
320
+ "model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
321
+ "model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
322
+ "model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
323
+ "model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
324
+ "model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
325
+ "model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
326
+ "model.layers.9.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
327
+ "model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
328
+ "model.mm_projector.bias": "pytorch_model-00002-of-00002.bin",
329
+ "model.mm_projector.weight": "pytorch_model-00002-of-00002.bin",
330
+ "model.norm.weight": "pytorch_model-00002-of-00002.bin",
331
+ "model.text_hidden_fcs.0.0.bias": "pytorch_model-00002-of-00002.bin",
332
+ "model.text_hidden_fcs.0.0.weight": "pytorch_model-00002-of-00002.bin",
333
+ "model.text_hidden_fcs.0.2.bias": "pytorch_model-00002-of-00002.bin",
334
+ "model.text_hidden_fcs.0.2.weight": "pytorch_model-00002-of-00002.bin",
335
+ "model.visual_model.image_encoder.blocks.0.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
336
+ "model.visual_model.image_encoder.blocks.0.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
337
+ "model.visual_model.image_encoder.blocks.0.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
338
+ "model.visual_model.image_encoder.blocks.0.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
339
+ "model.visual_model.image_encoder.blocks.0.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
340
+ "model.visual_model.image_encoder.blocks.0.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
341
+ "model.visual_model.image_encoder.blocks.0.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
342
+ "model.visual_model.image_encoder.blocks.0.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
343
+ "model.visual_model.image_encoder.blocks.0.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
344
+ "model.visual_model.image_encoder.blocks.0.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
345
+ "model.visual_model.image_encoder.blocks.0.norm1.bias": "pytorch_model-00002-of-00002.bin",
346
+ "model.visual_model.image_encoder.blocks.0.norm1.weight": "pytorch_model-00002-of-00002.bin",
347
+ "model.visual_model.image_encoder.blocks.0.norm2.bias": "pytorch_model-00002-of-00002.bin",
348
+ "model.visual_model.image_encoder.blocks.0.norm2.weight": "pytorch_model-00002-of-00002.bin",
349
+ "model.visual_model.image_encoder.blocks.1.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
350
+ "model.visual_model.image_encoder.blocks.1.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
351
+ "model.visual_model.image_encoder.blocks.1.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
352
+ "model.visual_model.image_encoder.blocks.1.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
353
+ "model.visual_model.image_encoder.blocks.1.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
354
+ "model.visual_model.image_encoder.blocks.1.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
355
+ "model.visual_model.image_encoder.blocks.1.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
356
+ "model.visual_model.image_encoder.blocks.1.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
357
+ "model.visual_model.image_encoder.blocks.1.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
358
+ "model.visual_model.image_encoder.blocks.1.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
359
+ "model.visual_model.image_encoder.blocks.1.norm1.bias": "pytorch_model-00002-of-00002.bin",
360
+ "model.visual_model.image_encoder.blocks.1.norm1.weight": "pytorch_model-00002-of-00002.bin",
361
+ "model.visual_model.image_encoder.blocks.1.norm2.bias": "pytorch_model-00002-of-00002.bin",
362
+ "model.visual_model.image_encoder.blocks.1.norm2.weight": "pytorch_model-00002-of-00002.bin",
363
+ "model.visual_model.image_encoder.blocks.10.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
364
+ "model.visual_model.image_encoder.blocks.10.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
365
+ "model.visual_model.image_encoder.blocks.10.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
366
+ "model.visual_model.image_encoder.blocks.10.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
367
+ "model.visual_model.image_encoder.blocks.10.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
368
+ "model.visual_model.image_encoder.blocks.10.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
369
+ "model.visual_model.image_encoder.blocks.10.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
370
+ "model.visual_model.image_encoder.blocks.10.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
371
+ "model.visual_model.image_encoder.blocks.10.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
372
+ "model.visual_model.image_encoder.blocks.10.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
373
+ "model.visual_model.image_encoder.blocks.10.norm1.bias": "pytorch_model-00002-of-00002.bin",
374
+ "model.visual_model.image_encoder.blocks.10.norm1.weight": "pytorch_model-00002-of-00002.bin",
375
+ "model.visual_model.image_encoder.blocks.10.norm2.bias": "pytorch_model-00002-of-00002.bin",
376
+ "model.visual_model.image_encoder.blocks.10.norm2.weight": "pytorch_model-00002-of-00002.bin",
377
+ "model.visual_model.image_encoder.blocks.11.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
378
+ "model.visual_model.image_encoder.blocks.11.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
379
+ "model.visual_model.image_encoder.blocks.11.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
380
+ "model.visual_model.image_encoder.blocks.11.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
381
+ "model.visual_model.image_encoder.blocks.11.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
382
+ "model.visual_model.image_encoder.blocks.11.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
383
+ "model.visual_model.image_encoder.blocks.11.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
384
+ "model.visual_model.image_encoder.blocks.11.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
385
+ "model.visual_model.image_encoder.blocks.11.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
386
+ "model.visual_model.image_encoder.blocks.11.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
387
+ "model.visual_model.image_encoder.blocks.11.norm1.bias": "pytorch_model-00002-of-00002.bin",
388
+ "model.visual_model.image_encoder.blocks.11.norm1.weight": "pytorch_model-00002-of-00002.bin",
389
+ "model.visual_model.image_encoder.blocks.11.norm2.bias": "pytorch_model-00002-of-00002.bin",
390
+ "model.visual_model.image_encoder.blocks.11.norm2.weight": "pytorch_model-00002-of-00002.bin",
391
+ "model.visual_model.image_encoder.blocks.12.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
392
+ "model.visual_model.image_encoder.blocks.12.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
393
+ "model.visual_model.image_encoder.blocks.12.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
394
+ "model.visual_model.image_encoder.blocks.12.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
395
+ "model.visual_model.image_encoder.blocks.12.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
396
+ "model.visual_model.image_encoder.blocks.12.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
397
+ "model.visual_model.image_encoder.blocks.12.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
398
+ "model.visual_model.image_encoder.blocks.12.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
399
+ "model.visual_model.image_encoder.blocks.12.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
400
+ "model.visual_model.image_encoder.blocks.12.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
401
+ "model.visual_model.image_encoder.blocks.12.norm1.bias": "pytorch_model-00002-of-00002.bin",
402
+ "model.visual_model.image_encoder.blocks.12.norm1.weight": "pytorch_model-00002-of-00002.bin",
403
+ "model.visual_model.image_encoder.blocks.12.norm2.bias": "pytorch_model-00002-of-00002.bin",
404
+ "model.visual_model.image_encoder.blocks.12.norm2.weight": "pytorch_model-00002-of-00002.bin",
405
+ "model.visual_model.image_encoder.blocks.13.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
406
+ "model.visual_model.image_encoder.blocks.13.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
407
+ "model.visual_model.image_encoder.blocks.13.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
408
+ "model.visual_model.image_encoder.blocks.13.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
409
+ "model.visual_model.image_encoder.blocks.13.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
410
+ "model.visual_model.image_encoder.blocks.13.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
411
+ "model.visual_model.image_encoder.blocks.13.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
412
+ "model.visual_model.image_encoder.blocks.13.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
413
+ "model.visual_model.image_encoder.blocks.13.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
414
+ "model.visual_model.image_encoder.blocks.13.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
415
+ "model.visual_model.image_encoder.blocks.13.norm1.bias": "pytorch_model-00002-of-00002.bin",
416
+ "model.visual_model.image_encoder.blocks.13.norm1.weight": "pytorch_model-00002-of-00002.bin",
417
+ "model.visual_model.image_encoder.blocks.13.norm2.bias": "pytorch_model-00002-of-00002.bin",
418
+ "model.visual_model.image_encoder.blocks.13.norm2.weight": "pytorch_model-00002-of-00002.bin",
419
+ "model.visual_model.image_encoder.blocks.14.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
420
+ "model.visual_model.image_encoder.blocks.14.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
421
+ "model.visual_model.image_encoder.blocks.14.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
422
+ "model.visual_model.image_encoder.blocks.14.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
423
+ "model.visual_model.image_encoder.blocks.14.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
424
+ "model.visual_model.image_encoder.blocks.14.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
425
+ "model.visual_model.image_encoder.blocks.14.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
426
+ "model.visual_model.image_encoder.blocks.14.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
427
+ "model.visual_model.image_encoder.blocks.14.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
428
+ "model.visual_model.image_encoder.blocks.14.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
429
+ "model.visual_model.image_encoder.blocks.14.norm1.bias": "pytorch_model-00002-of-00002.bin",
430
+ "model.visual_model.image_encoder.blocks.14.norm1.weight": "pytorch_model-00002-of-00002.bin",
431
+ "model.visual_model.image_encoder.blocks.14.norm2.bias": "pytorch_model-00002-of-00002.bin",
432
+ "model.visual_model.image_encoder.blocks.14.norm2.weight": "pytorch_model-00002-of-00002.bin",
433
+ "model.visual_model.image_encoder.blocks.15.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
434
+ "model.visual_model.image_encoder.blocks.15.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
435
+ "model.visual_model.image_encoder.blocks.15.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
436
+ "model.visual_model.image_encoder.blocks.15.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
437
+ "model.visual_model.image_encoder.blocks.15.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
438
+ "model.visual_model.image_encoder.blocks.15.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
439
+ "model.visual_model.image_encoder.blocks.15.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
440
+ "model.visual_model.image_encoder.blocks.15.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
441
+ "model.visual_model.image_encoder.blocks.15.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
442
+ "model.visual_model.image_encoder.blocks.15.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
443
+ "model.visual_model.image_encoder.blocks.15.norm1.bias": "pytorch_model-00002-of-00002.bin",
444
+ "model.visual_model.image_encoder.blocks.15.norm1.weight": "pytorch_model-00002-of-00002.bin",
445
+ "model.visual_model.image_encoder.blocks.15.norm2.bias": "pytorch_model-00002-of-00002.bin",
446
+ "model.visual_model.image_encoder.blocks.15.norm2.weight": "pytorch_model-00002-of-00002.bin",
447
+ "model.visual_model.image_encoder.blocks.16.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
448
+ "model.visual_model.image_encoder.blocks.16.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
449
+ "model.visual_model.image_encoder.blocks.16.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
450
+ "model.visual_model.image_encoder.blocks.16.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
451
+ "model.visual_model.image_encoder.blocks.16.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
452
+ "model.visual_model.image_encoder.blocks.16.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
453
+ "model.visual_model.image_encoder.blocks.16.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
454
+ "model.visual_model.image_encoder.blocks.16.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
455
+ "model.visual_model.image_encoder.blocks.16.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
456
+ "model.visual_model.image_encoder.blocks.16.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
457
+ "model.visual_model.image_encoder.blocks.16.norm1.bias": "pytorch_model-00002-of-00002.bin",
458
+ "model.visual_model.image_encoder.blocks.16.norm1.weight": "pytorch_model-00002-of-00002.bin",
459
+ "model.visual_model.image_encoder.blocks.16.norm2.bias": "pytorch_model-00002-of-00002.bin",
460
+ "model.visual_model.image_encoder.blocks.16.norm2.weight": "pytorch_model-00002-of-00002.bin",
461
+ "model.visual_model.image_encoder.blocks.17.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
462
+ "model.visual_model.image_encoder.blocks.17.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
463
+ "model.visual_model.image_encoder.blocks.17.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
464
+ "model.visual_model.image_encoder.blocks.17.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
465
+ "model.visual_model.image_encoder.blocks.17.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
466
+ "model.visual_model.image_encoder.blocks.17.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
467
+ "model.visual_model.image_encoder.blocks.17.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
468
+ "model.visual_model.image_encoder.blocks.17.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
469
+ "model.visual_model.image_encoder.blocks.17.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
470
+ "model.visual_model.image_encoder.blocks.17.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
471
+ "model.visual_model.image_encoder.blocks.17.norm1.bias": "pytorch_model-00002-of-00002.bin",
472
+ "model.visual_model.image_encoder.blocks.17.norm1.weight": "pytorch_model-00002-of-00002.bin",
473
+ "model.visual_model.image_encoder.blocks.17.norm2.bias": "pytorch_model-00002-of-00002.bin",
474
+ "model.visual_model.image_encoder.blocks.17.norm2.weight": "pytorch_model-00002-of-00002.bin",
475
+ "model.visual_model.image_encoder.blocks.18.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
476
+ "model.visual_model.image_encoder.blocks.18.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
477
+ "model.visual_model.image_encoder.blocks.18.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
478
+ "model.visual_model.image_encoder.blocks.18.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
479
+ "model.visual_model.image_encoder.blocks.18.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
480
+ "model.visual_model.image_encoder.blocks.18.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
481
+ "model.visual_model.image_encoder.blocks.18.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
482
+ "model.visual_model.image_encoder.blocks.18.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
483
+ "model.visual_model.image_encoder.blocks.18.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
484
+ "model.visual_model.image_encoder.blocks.18.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
485
+ "model.visual_model.image_encoder.blocks.18.norm1.bias": "pytorch_model-00002-of-00002.bin",
486
+ "model.visual_model.image_encoder.blocks.18.norm1.weight": "pytorch_model-00002-of-00002.bin",
487
+ "model.visual_model.image_encoder.blocks.18.norm2.bias": "pytorch_model-00002-of-00002.bin",
488
+ "model.visual_model.image_encoder.blocks.18.norm2.weight": "pytorch_model-00002-of-00002.bin",
489
+ "model.visual_model.image_encoder.blocks.19.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
490
+ "model.visual_model.image_encoder.blocks.19.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
491
+ "model.visual_model.image_encoder.blocks.19.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
492
+ "model.visual_model.image_encoder.blocks.19.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
493
+ "model.visual_model.image_encoder.blocks.19.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
494
+ "model.visual_model.image_encoder.blocks.19.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
495
+ "model.visual_model.image_encoder.blocks.19.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
496
+ "model.visual_model.image_encoder.blocks.19.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
497
+ "model.visual_model.image_encoder.blocks.19.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
498
+ "model.visual_model.image_encoder.blocks.19.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
499
+ "model.visual_model.image_encoder.blocks.19.norm1.bias": "pytorch_model-00002-of-00002.bin",
500
+ "model.visual_model.image_encoder.blocks.19.norm1.weight": "pytorch_model-00002-of-00002.bin",
501
+ "model.visual_model.image_encoder.blocks.19.norm2.bias": "pytorch_model-00002-of-00002.bin",
502
+ "model.visual_model.image_encoder.blocks.19.norm2.weight": "pytorch_model-00002-of-00002.bin",
503
+ "model.visual_model.image_encoder.blocks.2.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
504
+ "model.visual_model.image_encoder.blocks.2.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
505
+ "model.visual_model.image_encoder.blocks.2.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
506
+ "model.visual_model.image_encoder.blocks.2.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
507
+ "model.visual_model.image_encoder.blocks.2.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
508
+ "model.visual_model.image_encoder.blocks.2.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
509
+ "model.visual_model.image_encoder.blocks.2.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
510
+ "model.visual_model.image_encoder.blocks.2.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
511
+ "model.visual_model.image_encoder.blocks.2.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
512
+ "model.visual_model.image_encoder.blocks.2.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
513
+ "model.visual_model.image_encoder.blocks.2.norm1.bias": "pytorch_model-00002-of-00002.bin",
514
+ "model.visual_model.image_encoder.blocks.2.norm1.weight": "pytorch_model-00002-of-00002.bin",
515
+ "model.visual_model.image_encoder.blocks.2.norm2.bias": "pytorch_model-00002-of-00002.bin",
516
+ "model.visual_model.image_encoder.blocks.2.norm2.weight": "pytorch_model-00002-of-00002.bin",
517
+ "model.visual_model.image_encoder.blocks.20.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
518
+ "model.visual_model.image_encoder.blocks.20.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
519
+ "model.visual_model.image_encoder.blocks.20.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
520
+ "model.visual_model.image_encoder.blocks.20.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
521
+ "model.visual_model.image_encoder.blocks.20.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
522
+ "model.visual_model.image_encoder.blocks.20.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
523
+ "model.visual_model.image_encoder.blocks.20.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
524
+ "model.visual_model.image_encoder.blocks.20.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
525
+ "model.visual_model.image_encoder.blocks.20.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
526
+ "model.visual_model.image_encoder.blocks.20.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
527
+ "model.visual_model.image_encoder.blocks.20.norm1.bias": "pytorch_model-00002-of-00002.bin",
528
+ "model.visual_model.image_encoder.blocks.20.norm1.weight": "pytorch_model-00002-of-00002.bin",
529
+ "model.visual_model.image_encoder.blocks.20.norm2.bias": "pytorch_model-00002-of-00002.bin",
530
+ "model.visual_model.image_encoder.blocks.20.norm2.weight": "pytorch_model-00002-of-00002.bin",
531
+ "model.visual_model.image_encoder.blocks.21.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
532
+ "model.visual_model.image_encoder.blocks.21.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
533
+ "model.visual_model.image_encoder.blocks.21.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
534
+ "model.visual_model.image_encoder.blocks.21.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
535
+ "model.visual_model.image_encoder.blocks.21.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
536
+ "model.visual_model.image_encoder.blocks.21.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
537
+ "model.visual_model.image_encoder.blocks.21.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
538
+ "model.visual_model.image_encoder.blocks.21.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
539
+ "model.visual_model.image_encoder.blocks.21.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
540
+ "model.visual_model.image_encoder.blocks.21.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
541
+ "model.visual_model.image_encoder.blocks.21.norm1.bias": "pytorch_model-00002-of-00002.bin",
542
+ "model.visual_model.image_encoder.blocks.21.norm1.weight": "pytorch_model-00002-of-00002.bin",
543
+ "model.visual_model.image_encoder.blocks.21.norm2.bias": "pytorch_model-00002-of-00002.bin",
544
+ "model.visual_model.image_encoder.blocks.21.norm2.weight": "pytorch_model-00002-of-00002.bin",
545
+ "model.visual_model.image_encoder.blocks.22.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
546
+ "model.visual_model.image_encoder.blocks.22.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
547
+ "model.visual_model.image_encoder.blocks.22.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
548
+ "model.visual_model.image_encoder.blocks.22.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
549
+ "model.visual_model.image_encoder.blocks.22.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
550
+ "model.visual_model.image_encoder.blocks.22.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
551
+ "model.visual_model.image_encoder.blocks.22.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
552
+ "model.visual_model.image_encoder.blocks.22.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
553
+ "model.visual_model.image_encoder.blocks.22.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
554
+ "model.visual_model.image_encoder.blocks.22.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
555
+ "model.visual_model.image_encoder.blocks.22.norm1.bias": "pytorch_model-00002-of-00002.bin",
556
+ "model.visual_model.image_encoder.blocks.22.norm1.weight": "pytorch_model-00002-of-00002.bin",
557
+ "model.visual_model.image_encoder.blocks.22.norm2.bias": "pytorch_model-00002-of-00002.bin",
558
+ "model.visual_model.image_encoder.blocks.22.norm2.weight": "pytorch_model-00002-of-00002.bin",
559
+ "model.visual_model.image_encoder.blocks.23.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
560
+ "model.visual_model.image_encoder.blocks.23.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
561
+ "model.visual_model.image_encoder.blocks.23.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
562
+ "model.visual_model.image_encoder.blocks.23.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
563
+ "model.visual_model.image_encoder.blocks.23.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
564
+ "model.visual_model.image_encoder.blocks.23.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
565
+ "model.visual_model.image_encoder.blocks.23.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
566
+ "model.visual_model.image_encoder.blocks.23.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
567
+ "model.visual_model.image_encoder.blocks.23.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
568
+ "model.visual_model.image_encoder.blocks.23.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
569
+ "model.visual_model.image_encoder.blocks.23.norm1.bias": "pytorch_model-00002-of-00002.bin",
570
+ "model.visual_model.image_encoder.blocks.23.norm1.weight": "pytorch_model-00002-of-00002.bin",
571
+ "model.visual_model.image_encoder.blocks.23.norm2.bias": "pytorch_model-00002-of-00002.bin",
572
+ "model.visual_model.image_encoder.blocks.23.norm2.weight": "pytorch_model-00002-of-00002.bin",
573
+ "model.visual_model.image_encoder.blocks.24.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
574
+ "model.visual_model.image_encoder.blocks.24.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
575
+ "model.visual_model.image_encoder.blocks.24.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
576
+ "model.visual_model.image_encoder.blocks.24.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
577
+ "model.visual_model.image_encoder.blocks.24.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
578
+ "model.visual_model.image_encoder.blocks.24.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
579
+ "model.visual_model.image_encoder.blocks.24.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
580
+ "model.visual_model.image_encoder.blocks.24.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
581
+ "model.visual_model.image_encoder.blocks.24.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
582
+ "model.visual_model.image_encoder.blocks.24.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
583
+ "model.visual_model.image_encoder.blocks.24.norm1.bias": "pytorch_model-00002-of-00002.bin",
584
+ "model.visual_model.image_encoder.blocks.24.norm1.weight": "pytorch_model-00002-of-00002.bin",
585
+ "model.visual_model.image_encoder.blocks.24.norm2.bias": "pytorch_model-00002-of-00002.bin",
586
+ "model.visual_model.image_encoder.blocks.24.norm2.weight": "pytorch_model-00002-of-00002.bin",
587
+ "model.visual_model.image_encoder.blocks.25.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
588
+ "model.visual_model.image_encoder.blocks.25.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
589
+ "model.visual_model.image_encoder.blocks.25.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
590
+ "model.visual_model.image_encoder.blocks.25.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
591
+ "model.visual_model.image_encoder.blocks.25.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
592
+ "model.visual_model.image_encoder.blocks.25.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
593
+ "model.visual_model.image_encoder.blocks.25.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
594
+ "model.visual_model.image_encoder.blocks.25.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
595
+ "model.visual_model.image_encoder.blocks.25.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
596
+ "model.visual_model.image_encoder.blocks.25.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
597
+ "model.visual_model.image_encoder.blocks.25.norm1.bias": "pytorch_model-00002-of-00002.bin",
598
+ "model.visual_model.image_encoder.blocks.25.norm1.weight": "pytorch_model-00002-of-00002.bin",
599
+ "model.visual_model.image_encoder.blocks.25.norm2.bias": "pytorch_model-00002-of-00002.bin",
600
+ "model.visual_model.image_encoder.blocks.25.norm2.weight": "pytorch_model-00002-of-00002.bin",
601
+ "model.visual_model.image_encoder.blocks.26.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
602
+ "model.visual_model.image_encoder.blocks.26.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
603
+ "model.visual_model.image_encoder.blocks.26.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
604
+ "model.visual_model.image_encoder.blocks.26.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
605
+ "model.visual_model.image_encoder.blocks.26.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
606
+ "model.visual_model.image_encoder.blocks.26.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
607
+ "model.visual_model.image_encoder.blocks.26.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
608
+ "model.visual_model.image_encoder.blocks.26.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
609
+ "model.visual_model.image_encoder.blocks.26.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
610
+ "model.visual_model.image_encoder.blocks.26.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
611
+ "model.visual_model.image_encoder.blocks.26.norm1.bias": "pytorch_model-00002-of-00002.bin",
612
+ "model.visual_model.image_encoder.blocks.26.norm1.weight": "pytorch_model-00002-of-00002.bin",
613
+ "model.visual_model.image_encoder.blocks.26.norm2.bias": "pytorch_model-00002-of-00002.bin",
614
+ "model.visual_model.image_encoder.blocks.26.norm2.weight": "pytorch_model-00002-of-00002.bin",
615
+ "model.visual_model.image_encoder.blocks.27.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
616
+ "model.visual_model.image_encoder.blocks.27.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
617
+ "model.visual_model.image_encoder.blocks.27.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
618
+ "model.visual_model.image_encoder.blocks.27.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
619
+ "model.visual_model.image_encoder.blocks.27.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
620
+ "model.visual_model.image_encoder.blocks.27.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
621
+ "model.visual_model.image_encoder.blocks.27.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
622
+ "model.visual_model.image_encoder.blocks.27.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
623
+ "model.visual_model.image_encoder.blocks.27.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
624
+ "model.visual_model.image_encoder.blocks.27.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
625
+ "model.visual_model.image_encoder.blocks.27.norm1.bias": "pytorch_model-00002-of-00002.bin",
626
+ "model.visual_model.image_encoder.blocks.27.norm1.weight": "pytorch_model-00002-of-00002.bin",
627
+ "model.visual_model.image_encoder.blocks.27.norm2.bias": "pytorch_model-00002-of-00002.bin",
628
+ "model.visual_model.image_encoder.blocks.27.norm2.weight": "pytorch_model-00002-of-00002.bin",
629
+ "model.visual_model.image_encoder.blocks.28.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
630
+ "model.visual_model.image_encoder.blocks.28.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
631
+ "model.visual_model.image_encoder.blocks.28.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
632
+ "model.visual_model.image_encoder.blocks.28.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
633
+ "model.visual_model.image_encoder.blocks.28.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
634
+ "model.visual_model.image_encoder.blocks.28.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
635
+ "model.visual_model.image_encoder.blocks.28.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
636
+ "model.visual_model.image_encoder.blocks.28.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
637
+ "model.visual_model.image_encoder.blocks.28.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
638
+ "model.visual_model.image_encoder.blocks.28.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
639
+ "model.visual_model.image_encoder.blocks.28.norm1.bias": "pytorch_model-00002-of-00002.bin",
640
+ "model.visual_model.image_encoder.blocks.28.norm1.weight": "pytorch_model-00002-of-00002.bin",
641
+ "model.visual_model.image_encoder.blocks.28.norm2.bias": "pytorch_model-00002-of-00002.bin",
642
+ "model.visual_model.image_encoder.blocks.28.norm2.weight": "pytorch_model-00002-of-00002.bin",
643
+ "model.visual_model.image_encoder.blocks.29.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
644
+ "model.visual_model.image_encoder.blocks.29.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
645
+ "model.visual_model.image_encoder.blocks.29.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
646
+ "model.visual_model.image_encoder.blocks.29.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
647
+ "model.visual_model.image_encoder.blocks.29.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
648
+ "model.visual_model.image_encoder.blocks.29.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
649
+ "model.visual_model.image_encoder.blocks.29.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
650
+ "model.visual_model.image_encoder.blocks.29.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
651
+ "model.visual_model.image_encoder.blocks.29.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
652
+ "model.visual_model.image_encoder.blocks.29.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
653
+ "model.visual_model.image_encoder.blocks.29.norm1.bias": "pytorch_model-00002-of-00002.bin",
654
+ "model.visual_model.image_encoder.blocks.29.norm1.weight": "pytorch_model-00002-of-00002.bin",
655
+ "model.visual_model.image_encoder.blocks.29.norm2.bias": "pytorch_model-00002-of-00002.bin",
656
+ "model.visual_model.image_encoder.blocks.29.norm2.weight": "pytorch_model-00002-of-00002.bin",
657
+ "model.visual_model.image_encoder.blocks.3.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
658
+ "model.visual_model.image_encoder.blocks.3.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
659
+ "model.visual_model.image_encoder.blocks.3.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
660
+ "model.visual_model.image_encoder.blocks.3.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
661
+ "model.visual_model.image_encoder.blocks.3.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
662
+ "model.visual_model.image_encoder.blocks.3.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
663
+ "model.visual_model.image_encoder.blocks.3.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
664
+ "model.visual_model.image_encoder.blocks.3.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
665
+ "model.visual_model.image_encoder.blocks.3.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
666
+ "model.visual_model.image_encoder.blocks.3.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
667
+ "model.visual_model.image_encoder.blocks.3.norm1.bias": "pytorch_model-00002-of-00002.bin",
668
+ "model.visual_model.image_encoder.blocks.3.norm1.weight": "pytorch_model-00002-of-00002.bin",
669
+ "model.visual_model.image_encoder.blocks.3.norm2.bias": "pytorch_model-00002-of-00002.bin",
670
+ "model.visual_model.image_encoder.blocks.3.norm2.weight": "pytorch_model-00002-of-00002.bin",
671
+ "model.visual_model.image_encoder.blocks.30.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
672
+ "model.visual_model.image_encoder.blocks.30.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
673
+ "model.visual_model.image_encoder.blocks.30.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
674
+ "model.visual_model.image_encoder.blocks.30.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
675
+ "model.visual_model.image_encoder.blocks.30.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
676
+ "model.visual_model.image_encoder.blocks.30.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
677
+ "model.visual_model.image_encoder.blocks.30.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
678
+ "model.visual_model.image_encoder.blocks.30.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
679
+ "model.visual_model.image_encoder.blocks.30.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
680
+ "model.visual_model.image_encoder.blocks.30.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
681
+ "model.visual_model.image_encoder.blocks.30.norm1.bias": "pytorch_model-00002-of-00002.bin",
682
+ "model.visual_model.image_encoder.blocks.30.norm1.weight": "pytorch_model-00002-of-00002.bin",
683
+ "model.visual_model.image_encoder.blocks.30.norm2.bias": "pytorch_model-00002-of-00002.bin",
684
+ "model.visual_model.image_encoder.blocks.30.norm2.weight": "pytorch_model-00002-of-00002.bin",
685
+ "model.visual_model.image_encoder.blocks.31.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
686
+ "model.visual_model.image_encoder.blocks.31.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
687
+ "model.visual_model.image_encoder.blocks.31.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
688
+ "model.visual_model.image_encoder.blocks.31.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
689
+ "model.visual_model.image_encoder.blocks.31.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
690
+ "model.visual_model.image_encoder.blocks.31.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
691
+ "model.visual_model.image_encoder.blocks.31.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
692
+ "model.visual_model.image_encoder.blocks.31.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
693
+ "model.visual_model.image_encoder.blocks.31.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
694
+ "model.visual_model.image_encoder.blocks.31.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
695
+ "model.visual_model.image_encoder.blocks.31.norm1.bias": "pytorch_model-00002-of-00002.bin",
696
+ "model.visual_model.image_encoder.blocks.31.norm1.weight": "pytorch_model-00002-of-00002.bin",
697
+ "model.visual_model.image_encoder.blocks.31.norm2.bias": "pytorch_model-00002-of-00002.bin",
698
+ "model.visual_model.image_encoder.blocks.31.norm2.weight": "pytorch_model-00002-of-00002.bin",
699
+ "model.visual_model.image_encoder.blocks.4.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
700
+ "model.visual_model.image_encoder.blocks.4.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
701
+ "model.visual_model.image_encoder.blocks.4.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
702
+ "model.visual_model.image_encoder.blocks.4.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
703
+ "model.visual_model.image_encoder.blocks.4.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
704
+ "model.visual_model.image_encoder.blocks.4.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
705
+ "model.visual_model.image_encoder.blocks.4.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
706
+ "model.visual_model.image_encoder.blocks.4.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
707
+ "model.visual_model.image_encoder.blocks.4.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
708
+ "model.visual_model.image_encoder.blocks.4.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
709
+ "model.visual_model.image_encoder.blocks.4.norm1.bias": "pytorch_model-00002-of-00002.bin",
710
+ "model.visual_model.image_encoder.blocks.4.norm1.weight": "pytorch_model-00002-of-00002.bin",
711
+ "model.visual_model.image_encoder.blocks.4.norm2.bias": "pytorch_model-00002-of-00002.bin",
712
+ "model.visual_model.image_encoder.blocks.4.norm2.weight": "pytorch_model-00002-of-00002.bin",
713
+ "model.visual_model.image_encoder.blocks.5.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
714
+ "model.visual_model.image_encoder.blocks.5.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
715
+ "model.visual_model.image_encoder.blocks.5.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
716
+ "model.visual_model.image_encoder.blocks.5.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
717
+ "model.visual_model.image_encoder.blocks.5.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
718
+ "model.visual_model.image_encoder.blocks.5.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
719
+ "model.visual_model.image_encoder.blocks.5.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
720
+ "model.visual_model.image_encoder.blocks.5.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
721
+ "model.visual_model.image_encoder.blocks.5.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
722
+ "model.visual_model.image_encoder.blocks.5.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
723
+ "model.visual_model.image_encoder.blocks.5.norm1.bias": "pytorch_model-00002-of-00002.bin",
724
+ "model.visual_model.image_encoder.blocks.5.norm1.weight": "pytorch_model-00002-of-00002.bin",
725
+ "model.visual_model.image_encoder.blocks.5.norm2.bias": "pytorch_model-00002-of-00002.bin",
726
+ "model.visual_model.image_encoder.blocks.5.norm2.weight": "pytorch_model-00002-of-00002.bin",
727
+ "model.visual_model.image_encoder.blocks.6.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
728
+ "model.visual_model.image_encoder.blocks.6.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
729
+ "model.visual_model.image_encoder.blocks.6.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
730
+ "model.visual_model.image_encoder.blocks.6.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
731
+ "model.visual_model.image_encoder.blocks.6.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
732
+ "model.visual_model.image_encoder.blocks.6.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
733
+ "model.visual_model.image_encoder.blocks.6.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
734
+ "model.visual_model.image_encoder.blocks.6.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
735
+ "model.visual_model.image_encoder.blocks.6.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
736
+ "model.visual_model.image_encoder.blocks.6.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
737
+ "model.visual_model.image_encoder.blocks.6.norm1.bias": "pytorch_model-00002-of-00002.bin",
738
+ "model.visual_model.image_encoder.blocks.6.norm1.weight": "pytorch_model-00002-of-00002.bin",
739
+ "model.visual_model.image_encoder.blocks.6.norm2.bias": "pytorch_model-00002-of-00002.bin",
740
+ "model.visual_model.image_encoder.blocks.6.norm2.weight": "pytorch_model-00002-of-00002.bin",
741
+ "model.visual_model.image_encoder.blocks.7.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
742
+ "model.visual_model.image_encoder.blocks.7.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
743
+ "model.visual_model.image_encoder.blocks.7.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
744
+ "model.visual_model.image_encoder.blocks.7.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
745
+ "model.visual_model.image_encoder.blocks.7.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
746
+ "model.visual_model.image_encoder.blocks.7.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
747
+ "model.visual_model.image_encoder.blocks.7.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
748
+ "model.visual_model.image_encoder.blocks.7.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
749
+ "model.visual_model.image_encoder.blocks.7.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
750
+ "model.visual_model.image_encoder.blocks.7.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
751
+ "model.visual_model.image_encoder.blocks.7.norm1.bias": "pytorch_model-00002-of-00002.bin",
752
+ "model.visual_model.image_encoder.blocks.7.norm1.weight": "pytorch_model-00002-of-00002.bin",
753
+ "model.visual_model.image_encoder.blocks.7.norm2.bias": "pytorch_model-00002-of-00002.bin",
754
+ "model.visual_model.image_encoder.blocks.7.norm2.weight": "pytorch_model-00002-of-00002.bin",
755
+ "model.visual_model.image_encoder.blocks.8.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
756
+ "model.visual_model.image_encoder.blocks.8.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
757
+ "model.visual_model.image_encoder.blocks.8.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
758
+ "model.visual_model.image_encoder.blocks.8.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
759
+ "model.visual_model.image_encoder.blocks.8.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
760
+ "model.visual_model.image_encoder.blocks.8.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
761
+ "model.visual_model.image_encoder.blocks.8.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
762
+ "model.visual_model.image_encoder.blocks.8.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
763
+ "model.visual_model.image_encoder.blocks.8.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
764
+ "model.visual_model.image_encoder.blocks.8.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
765
+ "model.visual_model.image_encoder.blocks.8.norm1.bias": "pytorch_model-00002-of-00002.bin",
766
+ "model.visual_model.image_encoder.blocks.8.norm1.weight": "pytorch_model-00002-of-00002.bin",
767
+ "model.visual_model.image_encoder.blocks.8.norm2.bias": "pytorch_model-00002-of-00002.bin",
768
+ "model.visual_model.image_encoder.blocks.8.norm2.weight": "pytorch_model-00002-of-00002.bin",
769
+ "model.visual_model.image_encoder.blocks.9.attn.proj.bias": "pytorch_model-00002-of-00002.bin",
770
+ "model.visual_model.image_encoder.blocks.9.attn.proj.weight": "pytorch_model-00002-of-00002.bin",
771
+ "model.visual_model.image_encoder.blocks.9.attn.qkv.bias": "pytorch_model-00002-of-00002.bin",
772
+ "model.visual_model.image_encoder.blocks.9.attn.qkv.weight": "pytorch_model-00002-of-00002.bin",
773
+ "model.visual_model.image_encoder.blocks.9.attn.rel_pos_h": "pytorch_model-00002-of-00002.bin",
774
+ "model.visual_model.image_encoder.blocks.9.attn.rel_pos_w": "pytorch_model-00002-of-00002.bin",
775
+ "model.visual_model.image_encoder.blocks.9.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
776
+ "model.visual_model.image_encoder.blocks.9.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
777
+ "model.visual_model.image_encoder.blocks.9.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
778
+ "model.visual_model.image_encoder.blocks.9.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
779
+ "model.visual_model.image_encoder.blocks.9.norm1.bias": "pytorch_model-00002-of-00002.bin",
780
+ "model.visual_model.image_encoder.blocks.9.norm1.weight": "pytorch_model-00002-of-00002.bin",
781
+ "model.visual_model.image_encoder.blocks.9.norm2.bias": "pytorch_model-00002-of-00002.bin",
782
+ "model.visual_model.image_encoder.blocks.9.norm2.weight": "pytorch_model-00002-of-00002.bin",
783
+ "model.visual_model.image_encoder.neck.0.weight": "pytorch_model-00002-of-00002.bin",
784
+ "model.visual_model.image_encoder.neck.1.bias": "pytorch_model-00002-of-00002.bin",
785
+ "model.visual_model.image_encoder.neck.1.weight": "pytorch_model-00002-of-00002.bin",
786
+ "model.visual_model.image_encoder.neck.2.weight": "pytorch_model-00002-of-00002.bin",
787
+ "model.visual_model.image_encoder.neck.3.bias": "pytorch_model-00002-of-00002.bin",
788
+ "model.visual_model.image_encoder.neck.3.weight": "pytorch_model-00002-of-00002.bin",
789
+ "model.visual_model.image_encoder.patch_embed.proj.bias": "pytorch_model-00002-of-00002.bin",
790
+ "model.visual_model.image_encoder.patch_embed.proj.weight": "pytorch_model-00002-of-00002.bin",
791
+ "model.visual_model.image_encoder.pos_embed": "pytorch_model-00002-of-00002.bin",
792
+ "model.visual_model.mask_decoder.iou_prediction_head.layers.0.bias": "pytorch_model-00002-of-00002.bin",
793
+ "model.visual_model.mask_decoder.iou_prediction_head.layers.0.weight": "pytorch_model-00002-of-00002.bin",
794
+ "model.visual_model.mask_decoder.iou_prediction_head.layers.1.bias": "pytorch_model-00002-of-00002.bin",
795
+ "model.visual_model.mask_decoder.iou_prediction_head.layers.1.weight": "pytorch_model-00002-of-00002.bin",
796
+ "model.visual_model.mask_decoder.iou_prediction_head.layers.2.bias": "pytorch_model-00002-of-00002.bin",
797
+ "model.visual_model.mask_decoder.iou_prediction_head.layers.2.weight": "pytorch_model-00002-of-00002.bin",
798
+ "model.visual_model.mask_decoder.iou_token.weight": "pytorch_model-00002-of-00002.bin",
799
+ "model.visual_model.mask_decoder.mask_tokens.weight": "pytorch_model-00002-of-00002.bin",
800
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.0.bias": "pytorch_model-00002-of-00002.bin",
801
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.0.weight": "pytorch_model-00002-of-00002.bin",
802
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.1.bias": "pytorch_model-00002-of-00002.bin",
803
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.1.weight": "pytorch_model-00002-of-00002.bin",
804
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.2.bias": "pytorch_model-00002-of-00002.bin",
805
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.0.layers.2.weight": "pytorch_model-00002-of-00002.bin",
806
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.0.bias": "pytorch_model-00002-of-00002.bin",
807
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.0.weight": "pytorch_model-00002-of-00002.bin",
808
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.1.bias": "pytorch_model-00002-of-00002.bin",
809
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.1.weight": "pytorch_model-00002-of-00002.bin",
810
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.2.bias": "pytorch_model-00002-of-00002.bin",
811
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.1.layers.2.weight": "pytorch_model-00002-of-00002.bin",
812
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.0.bias": "pytorch_model-00002-of-00002.bin",
813
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.0.weight": "pytorch_model-00002-of-00002.bin",
814
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.1.bias": "pytorch_model-00002-of-00002.bin",
815
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.1.weight": "pytorch_model-00002-of-00002.bin",
816
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.2.bias": "pytorch_model-00002-of-00002.bin",
817
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.2.layers.2.weight": "pytorch_model-00002-of-00002.bin",
818
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.0.bias": "pytorch_model-00002-of-00002.bin",
819
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.0.weight": "pytorch_model-00002-of-00002.bin",
820
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.1.bias": "pytorch_model-00002-of-00002.bin",
821
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.1.weight": "pytorch_model-00002-of-00002.bin",
822
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.2.bias": "pytorch_model-00002-of-00002.bin",
823
+ "model.visual_model.mask_decoder.output_hypernetworks_mlps.3.layers.2.weight": "pytorch_model-00002-of-00002.bin",
824
+ "model.visual_model.mask_decoder.output_upscaling.0.bias": "pytorch_model-00002-of-00002.bin",
825
+ "model.visual_model.mask_decoder.output_upscaling.0.weight": "pytorch_model-00002-of-00002.bin",
826
+ "model.visual_model.mask_decoder.output_upscaling.1.bias": "pytorch_model-00002-of-00002.bin",
827
+ "model.visual_model.mask_decoder.output_upscaling.1.weight": "pytorch_model-00002-of-00002.bin",
828
+ "model.visual_model.mask_decoder.output_upscaling.3.bias": "pytorch_model-00002-of-00002.bin",
829
+ "model.visual_model.mask_decoder.output_upscaling.3.weight": "pytorch_model-00002-of-00002.bin",
830
+ "model.visual_model.mask_decoder.transformer.final_attn_token_to_image.k_proj.bias": "pytorch_model-00002-of-00002.bin",
831
+ "model.visual_model.mask_decoder.transformer.final_attn_token_to_image.k_proj.weight": "pytorch_model-00002-of-00002.bin",
832
+ "model.visual_model.mask_decoder.transformer.final_attn_token_to_image.out_proj.bias": "pytorch_model-00002-of-00002.bin",
833
+ "model.visual_model.mask_decoder.transformer.final_attn_token_to_image.out_proj.weight": "pytorch_model-00002-of-00002.bin",
834
+ "model.visual_model.mask_decoder.transformer.final_attn_token_to_image.q_proj.bias": "pytorch_model-00002-of-00002.bin",
835
+ "model.visual_model.mask_decoder.transformer.final_attn_token_to_image.q_proj.weight": "pytorch_model-00002-of-00002.bin",
836
+ "model.visual_model.mask_decoder.transformer.final_attn_token_to_image.v_proj.bias": "pytorch_model-00002-of-00002.bin",
837
+ "model.visual_model.mask_decoder.transformer.final_attn_token_to_image.v_proj.weight": "pytorch_model-00002-of-00002.bin",
838
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.bias": "pytorch_model-00002-of-00002.bin",
839
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.weight": "pytorch_model-00002-of-00002.bin",
840
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.bias": "pytorch_model-00002-of-00002.bin",
841
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.weight": "pytorch_model-00002-of-00002.bin",
842
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.bias": "pytorch_model-00002-of-00002.bin",
843
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.weight": "pytorch_model-00002-of-00002.bin",
844
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.bias": "pytorch_model-00002-of-00002.bin",
845
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.weight": "pytorch_model-00002-of-00002.bin",
846
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.bias": "pytorch_model-00002-of-00002.bin",
847
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.weight": "pytorch_model-00002-of-00002.bin",
848
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.bias": "pytorch_model-00002-of-00002.bin",
849
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.weight": "pytorch_model-00002-of-00002.bin",
850
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.bias": "pytorch_model-00002-of-00002.bin",
851
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.weight": "pytorch_model-00002-of-00002.bin",
852
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.bias": "pytorch_model-00002-of-00002.bin",
853
+ "model.visual_model.mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.weight": "pytorch_model-00002-of-00002.bin",
854
+ "model.visual_model.mask_decoder.transformer.layers.0.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
855
+ "model.visual_model.mask_decoder.transformer.layers.0.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
856
+ "model.visual_model.mask_decoder.transformer.layers.0.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
857
+ "model.visual_model.mask_decoder.transformer.layers.0.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
858
+ "model.visual_model.mask_decoder.transformer.layers.0.norm1.bias": "pytorch_model-00002-of-00002.bin",
859
+ "model.visual_model.mask_decoder.transformer.layers.0.norm1.weight": "pytorch_model-00002-of-00002.bin",
860
+ "model.visual_model.mask_decoder.transformer.layers.0.norm2.bias": "pytorch_model-00002-of-00002.bin",
861
+ "model.visual_model.mask_decoder.transformer.layers.0.norm2.weight": "pytorch_model-00002-of-00002.bin",
862
+ "model.visual_model.mask_decoder.transformer.layers.0.norm3.bias": "pytorch_model-00002-of-00002.bin",
863
+ "model.visual_model.mask_decoder.transformer.layers.0.norm3.weight": "pytorch_model-00002-of-00002.bin",
864
+ "model.visual_model.mask_decoder.transformer.layers.0.norm4.bias": "pytorch_model-00002-of-00002.bin",
865
+ "model.visual_model.mask_decoder.transformer.layers.0.norm4.weight": "pytorch_model-00002-of-00002.bin",
866
+ "model.visual_model.mask_decoder.transformer.layers.0.self_attn.k_proj.bias": "pytorch_model-00002-of-00002.bin",
867
+ "model.visual_model.mask_decoder.transformer.layers.0.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
868
+ "model.visual_model.mask_decoder.transformer.layers.0.self_attn.out_proj.bias": "pytorch_model-00002-of-00002.bin",
869
+ "model.visual_model.mask_decoder.transformer.layers.0.self_attn.out_proj.weight": "pytorch_model-00002-of-00002.bin",
870
+ "model.visual_model.mask_decoder.transformer.layers.0.self_attn.q_proj.bias": "pytorch_model-00002-of-00002.bin",
871
+ "model.visual_model.mask_decoder.transformer.layers.0.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
872
+ "model.visual_model.mask_decoder.transformer.layers.0.self_attn.v_proj.bias": "pytorch_model-00002-of-00002.bin",
873
+ "model.visual_model.mask_decoder.transformer.layers.0.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
874
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.bias": "pytorch_model-00002-of-00002.bin",
875
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.weight": "pytorch_model-00002-of-00002.bin",
876
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.bias": "pytorch_model-00002-of-00002.bin",
877
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.weight": "pytorch_model-00002-of-00002.bin",
878
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.bias": "pytorch_model-00002-of-00002.bin",
879
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.weight": "pytorch_model-00002-of-00002.bin",
880
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.bias": "pytorch_model-00002-of-00002.bin",
881
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.weight": "pytorch_model-00002-of-00002.bin",
882
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.bias": "pytorch_model-00002-of-00002.bin",
883
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.weight": "pytorch_model-00002-of-00002.bin",
884
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.bias": "pytorch_model-00002-of-00002.bin",
885
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.weight": "pytorch_model-00002-of-00002.bin",
886
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.bias": "pytorch_model-00002-of-00002.bin",
887
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.weight": "pytorch_model-00002-of-00002.bin",
888
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.bias": "pytorch_model-00002-of-00002.bin",
889
+ "model.visual_model.mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.weight": "pytorch_model-00002-of-00002.bin",
890
+ "model.visual_model.mask_decoder.transformer.layers.1.mlp.lin1.bias": "pytorch_model-00002-of-00002.bin",
891
+ "model.visual_model.mask_decoder.transformer.layers.1.mlp.lin1.weight": "pytorch_model-00002-of-00002.bin",
892
+ "model.visual_model.mask_decoder.transformer.layers.1.mlp.lin2.bias": "pytorch_model-00002-of-00002.bin",
893
+ "model.visual_model.mask_decoder.transformer.layers.1.mlp.lin2.weight": "pytorch_model-00002-of-00002.bin",
894
+ "model.visual_model.mask_decoder.transformer.layers.1.norm1.bias": "pytorch_model-00002-of-00002.bin",
895
+ "model.visual_model.mask_decoder.transformer.layers.1.norm1.weight": "pytorch_model-00002-of-00002.bin",
896
+ "model.visual_model.mask_decoder.transformer.layers.1.norm2.bias": "pytorch_model-00002-of-00002.bin",
897
+ "model.visual_model.mask_decoder.transformer.layers.1.norm2.weight": "pytorch_model-00002-of-00002.bin",
898
+ "model.visual_model.mask_decoder.transformer.layers.1.norm3.bias": "pytorch_model-00002-of-00002.bin",
899
+ "model.visual_model.mask_decoder.transformer.layers.1.norm3.weight": "pytorch_model-00002-of-00002.bin",
900
+ "model.visual_model.mask_decoder.transformer.layers.1.norm4.bias": "pytorch_model-00002-of-00002.bin",
901
+ "model.visual_model.mask_decoder.transformer.layers.1.norm4.weight": "pytorch_model-00002-of-00002.bin",
902
+ "model.visual_model.mask_decoder.transformer.layers.1.self_attn.k_proj.bias": "pytorch_model-00002-of-00002.bin",
903
+ "model.visual_model.mask_decoder.transformer.layers.1.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
904
+ "model.visual_model.mask_decoder.transformer.layers.1.self_attn.out_proj.bias": "pytorch_model-00002-of-00002.bin",
905
+ "model.visual_model.mask_decoder.transformer.layers.1.self_attn.out_proj.weight": "pytorch_model-00002-of-00002.bin",
906
+ "model.visual_model.mask_decoder.transformer.layers.1.self_attn.q_proj.bias": "pytorch_model-00002-of-00002.bin",
907
+ "model.visual_model.mask_decoder.transformer.layers.1.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
908
+ "model.visual_model.mask_decoder.transformer.layers.1.self_attn.v_proj.bias": "pytorch_model-00002-of-00002.bin",
909
+ "model.visual_model.mask_decoder.transformer.layers.1.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
910
+ "model.visual_model.mask_decoder.transformer.norm_final_attn.bias": "pytorch_model-00002-of-00002.bin",
911
+ "model.visual_model.mask_decoder.transformer.norm_final_attn.weight": "pytorch_model-00002-of-00002.bin",
912
+ "model.visual_model.prompt_encoder.mask_downscaling.0.bias": "pytorch_model-00002-of-00002.bin",
913
+ "model.visual_model.prompt_encoder.mask_downscaling.0.weight": "pytorch_model-00002-of-00002.bin",
914
+ "model.visual_model.prompt_encoder.mask_downscaling.1.bias": "pytorch_model-00002-of-00002.bin",
915
+ "model.visual_model.prompt_encoder.mask_downscaling.1.weight": "pytorch_model-00002-of-00002.bin",
916
+ "model.visual_model.prompt_encoder.mask_downscaling.3.bias": "pytorch_model-00002-of-00002.bin",
917
+ "model.visual_model.prompt_encoder.mask_downscaling.3.weight": "pytorch_model-00002-of-00002.bin",
918
+ "model.visual_model.prompt_encoder.mask_downscaling.4.bias": "pytorch_model-00002-of-00002.bin",
919
+ "model.visual_model.prompt_encoder.mask_downscaling.4.weight": "pytorch_model-00002-of-00002.bin",
920
+ "model.visual_model.prompt_encoder.mask_downscaling.6.bias": "pytorch_model-00002-of-00002.bin",
921
+ "model.visual_model.prompt_encoder.mask_downscaling.6.weight": "pytorch_model-00002-of-00002.bin",
922
+ "model.visual_model.prompt_encoder.no_mask_embed.weight": "pytorch_model-00002-of-00002.bin",
923
+ "model.visual_model.prompt_encoder.not_a_point_embed.weight": "pytorch_model-00002-of-00002.bin",
924
+ "model.visual_model.prompt_encoder.pe_layer.positional_encoding_gaussian_matrix": "pytorch_model-00002-of-00002.bin",
925
+ "model.visual_model.prompt_encoder.point_embeddings.0.weight": "pytorch_model-00002-of-00002.bin",
926
+ "model.visual_model.prompt_encoder.point_embeddings.1.weight": "pytorch_model-00002-of-00002.bin",
927
+ "model.visual_model.prompt_encoder.point_embeddings.2.weight": "pytorch_model-00002-of-00002.bin",
928
+ "model.visual_model.prompt_encoder.point_embeddings.3.weight": "pytorch_model-00002-of-00002.bin"
929
+ }
930
+ }
ckpts/AffordanceVLM-7B/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<unk>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
ckpts/AffordanceVLM-7B/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
ckpts/AffordanceVLM-7B/tokenizer_config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 512,
23
+ "pad_token": null,
24
+ "padding_side": "right",
25
+ "sp_model_kwargs": {},
26
+ "tokenizer_class": "LlamaTokenizer",
27
+ "unk_token": {
28
+ "__type": "AddedToken",
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": true,
32
+ "rstrip": false,
33
+ "single_word": false
34
+ }
35
+ }
ckpts/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879
client.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Client script to send an image and prompt to a Flask-based vision-language segmentation server.
4
+
5
+ from __future__ import absolute_import, print_function, division
6
+ import requests
7
+ import cv2
8
+ import base64
9
+ import numpy as np
10
+
11
# ---------------------------
# Encode image to base64 string
# ---------------------------
def img2b64(img):
    """Serialize an OpenCV BGR image into a base64 BMP string."""
    success, encoded = cv2.imencode('.bmp', img)  # BMP keeps the payload lossless
    return base64.b64encode(encoded).decode()
18
+
19
# ---------------------------
# Decode base64 string back to image
# ---------------------------
def b642img(pic_str):
    """Decode a base64 string (as produced by img2b64) into a BGR image array."""
    raw_bytes = base64.b64decode(pic_str)
    byte_buffer = np.frombuffer(raw_bytes, np.uint8)
    return cv2.imdecode(byte_buffer, cv2.IMREAD_COLOR)
27
+
28
# ---------------------------
# Send image and prompt to server, receive result and save
# ---------------------------
def post_files():
    """Send a local image plus a text prompt to the segmentation server
    and save the returned affordance mask to disk."""
    path = 'vis_output/my_workspace.JPG'  # Input image path
    image = cv2.imread(path)
    if image is None:
        print(f"Failed to read image at {path}")
        return

    payload = {
        'img': img2b64(image),
        'prompt': 'Please segment the affordance map of mug in this image.'
    }

    # Send POST request to the Flask server.
    r = requests.post('http://localhost:3200/img_mask', json=payload)
    if r.status_code != 200:
        print(f"Request failed with status code {r.status_code}")
        return

    print('Success. Received response from server.')
    result_b64 = r.json().get('img', None)
    if not result_b64:
        print("No image returned in the response.")
        return

    save_path = 'affordance_mask_result.jpg'
    cv2.imwrite(save_path, b642img(result_b64))
    print(f"Result saved to {save_path}")


# ---------------------------
# Main entry
# ---------------------------
if __name__ == '__main__':
    post_files()
67
+
data_curation/.ipynb_checkpoints/check_dataset-checkpoint.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import pickle as pkl

DATA_DIR = '/gemini/space/wrz/AffordanceNet/data'


def resolve_path(path):
    """Convert a relative dataset path (e.g. './data/...') into an absolute one."""
    if path.startswith('./data/'):
        # Strip the './data/' prefix (7 chars) and re-root at DATA_DIR.
        return os.path.join(DATA_DIR, path[7:])
    if path.startswith('./'):
        # Any other relative path is rooted at DATA_DIR's parent directory.
        return os.path.join(os.path.dirname(DATA_DIR), path[2:])
    return path


def get_data_paths():
    """Retrieve train/val/reasoning/non-reasoning pkl file paths."""
    entries = os.listdir(DATA_DIR)
    train_paths = [os.path.join(DATA_DIR, name) for name in entries if name.endswith('train.pkl')]
    val_paths = [os.path.join(DATA_DIR, name) for name in entries if name.endswith('val.pkl')]
    reasoning_paths = [os.path.join(DATA_DIR, name) for name in entries if name.endswith('reasoning_val.pkl')]
    # Reasoning files also end with 'val.pkl', so subtract them from the val set.
    non_reasoning_paths = [p for p in val_paths if p not in reasoning_paths]

    return train_paths, reasoning_paths, non_reasoning_paths


def check_file_exists(file_path, description=""):
    """Assert that the file exists, otherwise raise an error."""
    assert os.path.exists(file_path), f"{description} does not exist: {file_path}"


def check_train_data(train_path):
    """Check frame and mask paths for each sample in training data."""
    print(f"[Train] Checking: {train_path}")
    with open(train_path, "rb") as f:
        data = pkl.load(f)

    for sample in data:
        # Resolve relative paths before testing existence.
        check_file_exists(resolve_path(sample["frame_path"]), "Frame path")
        check_file_exists(resolve_path(sample["mask_path"]), "Mask path")

    print(f"[Train] ✅ Checked {train_path}. Samples: {len(data)}")


def check_val_data(val_path, reasoning=False):
    """Check validation data paths depending on reasoning mode."""
    tag = "Reasoning Val" if reasoning else "Non-Reasoning Val"
    print(f"[{tag}] Checking: {val_path}")

    with open(val_path, "rb") as f:
        data = pkl.load(f)

    if reasoning:
        # Reasoning val data is a flat list of frame/mask samples.
        for sample in data:
            check_file_exists(resolve_path(sample["frame_path"]), "Frame path")
            check_file_exists(resolve_path(sample["mask_path"]), "Mask path")
        print(f"[{tag}] ✅ Checked {val_path}. Samples: {len(data)}")
    else:
        # Non-reasoning val data maps class names to image/label path lists.
        total_images = 0
        for class_name, image_list in data.get('images', {}).items():
            for image_path in image_list:
                check_file_exists(resolve_path(image_path), "Image path")
            total_images += len(image_list)

        for class_name, label_list in data.get('labels', {}).items():
            for label_path in label_list:
                check_file_exists(resolve_path(label_path), "Label path")

        print(f"[{tag}] ✅ Checked {val_path}. Samples: {total_images}")


def main():
    """Validate every train / val / reasoning-val pickle found under DATA_DIR."""
    train_paths, reasoning_paths, non_reasoning_paths = get_data_paths()

    for train_path in train_paths:
        check_train_data(train_path)

    for val_path in non_reasoning_paths:
        check_val_data(val_path, reasoning=False)

    for val_path in reasoning_paths:
        check_val_data(val_path, reasoning=True)


if __name__ == "__main__":
    main()
data_curation/build_vlpart.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import argparse
3
+ import glob
4
+ import multiprocessing as mp
5
+ import numpy as np
6
+ import os
7
+ import tempfile
8
+ import time
9
+ import warnings
10
+ import cv2
11
+ import tqdm
12
+
13
+ from detectron2.config import get_cfg
14
+ from detectron2.data.detection_utils import read_image
15
+ from detectron2.utils.logger import setup_logger
16
+
17
+ import sys
18
+ sys.path.append('.')
19
+ from VLPart.vlpart.config import add_vlpart_config
20
+
21
+ from VLPart.demo.predictor import VisualizationDemo
22
+
23
+
24
+ # constants
25
+ WINDOW_NAME = "image demo"
26
+
27
+
28
def setup_cfg(args):
    """Build a frozen detectron2 config from the parsed CLI arguments.

    Merges the YAML config file plus any KEY VALUE overrides, then applies
    the requested confidence threshold to every score-thresholded head.
    """
    cfg = get_cfg()
    add_vlpart_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Same threshold for every builtin model family.
    threshold = args.confidence_threshold
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = threshold
    cfg.freeze()
    return cfg
40
+
41
+
42
def get_parser():
    """Build the command-line argument parser for the VLPart demo."""
    p = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    p.add_argument(
        "--config-file",
        default="VLPart/configs/joint/swinbase_cascade_lvis_paco_pascalpart_partimagenet.yaml",
        metavar="FILE",
        help="path to config file",
    )
    p.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    p.add_argument("--video-input", help="Path to video file.")
    p.add_argument(
        "--input",
        nargs="+",
        default='',
        help="A list of space separated input images; "
             "or a single glob pattern such as 'directory/*.jpg'",
    )
    p.add_argument(
        "--output",
        default='',
        help="A file or directory to save output visualizations. "
             "If not given, will show output in an OpenCV window.",
    )
    p.add_argument(
        "--vocabulary",
        default="custom",
        choices=['pascal_part', 'partimagenet', 'paco',
                 'voc', 'coco', 'lvis',
                 'pascal_part_voc', 'lvis_paco', 'custom'],
        help="",
    )
    p.add_argument(
        "--custom_vocabulary",
        default="",
        help="",
    )
    p.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.7,
        help="Minimum score for instance predictions to be shown",
    )
    # Trailing KEY VALUE pairs are forwarded verbatim to cfg.merge_from_list.
    p.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=['MODEL.WEIGHTS', "/data/VLPart/ckpts/swinbase_cascade_lvis_paco_pascalpart_partimagenet.pth", "VIS.BOX", False],
        nargs=argparse.REMAINDER,
    )
    return p
92
+
93
def build_vlpart_model(custom_vocabulary):
    """Construct a VLPart VisualizationDemo restricted to *custom_vocabulary*.

    NOTE(review): this calls parse_args() with no argv, so it also consumes
    sys.argv of the embedding process — confirm callers expect that.
    """
    mp.set_start_method("spawn", force=True)

    args = get_parser().parse_args()
    args.custom_vocabulary = custom_vocabulary

    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)
    demo = VisualizationDemo(cfg, args)
    return demo
data_curation/check_dataset.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import pickle as pkl

DATA_DIR = '/gemini/space/wrz/AffordanceNet/data'


def resolve_path(path):
    """Convert a relative dataset path (e.g. './data/...') into an absolute one."""
    if path.startswith('./data/'):
        # Strip the './data/' prefix (7 chars) and re-root at DATA_DIR.
        return os.path.join(DATA_DIR, path[7:])
    if path.startswith('./'):
        # Any other relative path is rooted at DATA_DIR's parent directory.
        return os.path.join(os.path.dirname(DATA_DIR), path[2:])
    return path


def get_data_paths():
    """Retrieve train/val/reasoning/non-reasoning pkl file paths."""
    entries = os.listdir(DATA_DIR)
    train_paths = [os.path.join(DATA_DIR, name) for name in entries if name.endswith('train.pkl')]
    val_paths = [os.path.join(DATA_DIR, name) for name in entries if name.endswith('val.pkl')]
    reasoning_paths = [os.path.join(DATA_DIR, name) for name in entries if name.endswith('reasoning_val.pkl')]
    # Reasoning files also end with 'val.pkl', so subtract them from the val set.
    non_reasoning_paths = [p for p in val_paths if p not in reasoning_paths]

    return train_paths, reasoning_paths, non_reasoning_paths


def check_file_exists(file_path, description=""):
    """Assert that the file exists, otherwise raise an error."""
    assert os.path.exists(file_path), f"{description} does not exist: {file_path}"


def check_train_data(train_path):
    """Check frame and mask paths for each sample in training data."""
    print(f"[Train] Checking: {train_path}")
    with open(train_path, "rb") as f:
        data = pkl.load(f)

    for sample in data:
        # Resolve relative paths before testing existence.
        check_file_exists(resolve_path(sample["frame_path"]), "Frame path")
        check_file_exists(resolve_path(sample["mask_path"]), "Mask path")

    print(f"[Train] ✅ Checked {train_path}. Samples: {len(data)}")


def check_val_data(val_path, reasoning=False):
    """Check validation data paths depending on reasoning mode."""
    tag = "Reasoning Val" if reasoning else "Non-Reasoning Val"
    print(f"[{tag}] Checking: {val_path}")

    with open(val_path, "rb") as f:
        data = pkl.load(f)

    if reasoning:
        # Reasoning val data is a flat list of frame/mask samples.
        for sample in data:
            check_file_exists(resolve_path(sample["frame_path"]), "Frame path")
            check_file_exists(resolve_path(sample["mask_path"]), "Mask path")
        print(f"[{tag}] ✅ Checked {val_path}. Samples: {len(data)}")
    else:
        # Non-reasoning val data maps class names to image/label path lists.
        total_images = 0
        for class_name, image_list in data.get('images', {}).items():
            for image_path in image_list:
                check_file_exists(resolve_path(image_path), "Image path")
            total_images += len(image_list)

        for class_name, label_list in data.get('labels', {}).items():
            for label_path in label_list:
                check_file_exists(resolve_path(label_path), "Label path")

        print(f"[{tag}] ✅ Checked {val_path}. Samples: {total_images}")


def main():
    """Validate every train / val / reasoning-val pickle found under DATA_DIR."""
    train_paths, reasoning_paths, non_reasoning_paths = get_data_paths()

    for train_path in train_paths:
        check_train_data(train_path)

    for val_path in non_reasoning_paths:
        check_val_data(val_path, reasoning=False)

    for val_path in reasoning_paths:
        check_val_data(val_path, reasoning=True)


if __name__ == "__main__":
    main()
data_curation/prompt_generation_handal_easy_reasoning.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import json
import pickle
import requests
from concurrent.futures import ThreadPoolExecutor

# Dataset name
DATASET = 'handal'

# Handle-equipped objects to filter
OBJECTS_WITH_HANDLE = [
    'strainers', 'fixed joint pliers', 'hammers', 'ladles', 'whisks', 'measuring cups',
    'locking pliers', 'power drills', 'adjustable wrenches', 'mugs', 'ratchets', 'utensils',
    'combinational wrenches', 'pots pans', 'spatulas', 'screwdrivers', 'slip joint pliers'
]

# OpenAI API settings (update key!)
API_URL = 'https://api.openai.com/v1/chat/completions'
HEADERS = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer YOUR-API-KEY'  # Replace with your real key
}


def read_pkl_file(pkl_path):
    """Read the val pkl and collect handle-equipped samples not yet processed."""
    with open(pkl_path, 'rb') as f:
        val_data = pickle.load(f)

    pending = []
    for class_name, image_list in val_data['images'].items():
        if class_name not in OBJECTS_WITH_HANDLE:
            continue
        for idx, img in enumerate(image_list):
            class_label = val_data['class_names'][class_name][idx]
            save_path = os.path.join(
                f'./reason_affordance/{DATASET}_easy_reasoning',
                class_label,
                os.path.splitext(os.path.basename(img))[0] + ".json"
            )
            # Skip samples whose output JSON already exists (resume support).
            if not os.path.exists(save_path):
                pending.append({'img_name': img, 'class_name': class_label})
    return pending


def process_sentence(class_name):
    """Send a few-shot prompt to OpenAI and return the generated sentence."""
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'system',
         'content': (
             'Based on several words where the first is category name, '
             'please design an instruction <1> and instruction <2> in embodied scenes. '
             'The instruction <1> must include object category name itself. '
             'The instruction <2> must include the object category name itself. '
             'The instruction <2> must belong to embodied manipulation and give action if instruction <1> provides. '
             'The instruction <2> does not exceed 50 words.'
         )},
        # Few-shot examples pairing a category with the expected <1>/<2> format.
        {'role': 'user', 'content': 'mug'},
        {'role': 'assistant',
         'content': '<1> I need a drink. Please find a mug to fill water. <2> The mug has a handle as affordance map. So the robot can hold its handle.'},
        {'role': 'user', 'content': 'knife'},
        {'role': 'assistant',
         'content': '<1> Please give me a knife to cut apple. <2> The knife has a handle, and you can use its handle to cut apple.'},
        {'role': 'user', 'content': 'hammers'},
        {'role': 'assistant',
         'content': '<1> What is the proper way to hold the hammers? <2> The correct method is to hold the hammer by its handle.'},
        {'role': 'user', 'content': 'fork'},
        {'role': 'assistant',
         'content': '<1> Kindly pick up the fork. <2> You will be holding the fork handle.'},
        {'role': 'user', 'content': 'screwdrivers'},
        {'role': 'assistant',
         'content': '<1> I need a tool to tighten or loosen screws. <2> The screwdriver is here, hold its handle to turn and control screws.'},
        {'role': 'user', 'content': class_name}
    ]

    response = requests.post(API_URL, headers=HEADERS, json={'model': 'gpt-4', 'messages': messages})
    if response.status_code == 200:
        return response.json()['choices'][0]['message']['content']
    print(f"API Error for {class_name}:", response.text)
    return None


def process_json(data):
    """Process a single data entry and save result to JSON file."""
    class_name = data["class_name"]

    # Retry up to 5 times until the model emits both <1> and <2> markers.
    result = None
    for _ in range(5):
        candidate = process_sentence(class_name)
        if candidate and '<1>' in candidate and '<2>' in candidate:
            result = candidate
            break
    if result is None:
        print(f"Failed to process: {class_name}")
        return

    print("Processed:", result)

    try:
        # <1> ... <2> ... -> question / answer split.
        question = result.split('<2>')[0].split('<1>')[-1].strip()
        answer = result.split('<2>')[-1].strip()

        save_dir = os.path.join(f'./reason_affordance/{DATASET}_easy_reasoning', class_name)
        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, os.path.splitext(os.path.basename(data["img_name"]))[0] + ".json")
        output = {'img_name': data["img_name"], 'class_name': class_name, 'question': question, 'answer': answer}

        with open(save_path, 'w') as f:
            json.dump(output, f, indent=4)

    except Exception as e:
        print(f"Error saving file for {class_name}:", e)


def main():
    """Load pending samples and generate reasoning prompts in parallel."""
    pkl_file = f'./data/{DATASET}_val.pkl'
    data_list = read_pkl_file(pkl_file)

    with ThreadPoolExecutor(max_workers=2) as executor:
        executor.map(process_json, data_list)


if __name__ == "__main__":
    main()
data_curation/prompt_generation_handal_hard_reasoning.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import json
import pickle
import requests
from concurrent.futures import ThreadPoolExecutor

# Dataset configuration
DATASET = 'handal'

# Object categories with handle
OBJECTS_WITH_HANDLE = [
    'strainers', 'fixed joint pliers', 'hammers', 'ladles', 'whisks', 'measuring cups',
    'locking pliers', 'power drills', 'adjustable wrenches', 'mugs', 'ratchets', 'utensils',
    'combinational wrenches', 'pots pans', 'spatulas', 'screwdrivers', 'slip joint pliers'
]

# OpenAI API settings (update key!)
API_URL = 'https://api.openai.com/v1/chat/completions'
HEADERS = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer YOUR-API-KEY'  # Replace with your real key
}


def read_pkl_file(pkl_path):
    """
    Load a pickle file and extract data entries containing objects with handles,
    skipping already processed samples.
    """
    with open(pkl_path, 'rb') as f:
        val_data = pickle.load(f)

    pending = []
    for class_name, img_list in val_data['images'].items():
        if class_name not in OBJECTS_WITH_HANDLE:
            continue
        for i, img_path in enumerate(img_list):
            class_label = val_data['class_names'][class_name][i]
            save_path = os.path.join(
                f'./reason_affordance/{DATASET}_hard_reasoning',
                class_label,
                os.path.splitext(os.path.basename(img_path))[0] + ".json"
            )
            # Resume support: only keep samples without an existing output JSON.
            if not os.path.exists(save_path):
                pending.append({'img_name': img_path, 'class_name': class_label})

    return pending


def process_sentence(category):
    """
    Generate reasoning instructions (<1>, <2>) from a category name using GPT.
    Instruction <1> deliberately omits the category name (hard reasoning).
    """
    payload = {
        'model': 'gpt-4',
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'system',
             'content': (
                 'Based on several words where the first is category name, please design an instruction <1> and instruction <2> in embodied scenes. '
                 'The instruction <1> must not include object category name itself. '
                 'The instruction <2> must include the object category name itself. '
                 'The instruction <2> must belong to embodied manipulation and give action if instruction <1> provides. '
                 'The instruction <2> does not exceed 50 words.'
             )},
            # Few-shot examples pairing a category with the expected <1>/<2> format.
            {'role': 'user', 'content': 'microwave, open'},
            {'role': 'assistant', 'content': '<1> Heat up food quickly. <2> The microwave is closed, so it can be open to access the food inside.'},
            {'role': 'user', 'content': 'knife'},
            {'role': 'assistant', 'content': '<1> I want to cut a bread. <2> The knife has a handle, you can use its handle to cut bread.'},
            {'role': 'user', 'content': 'computer mouse'},
            {'role': 'assistant', 'content': '<1> Give me a tool to control the cursor on the screen. <2> The computer mouse is here. It has no handle, so you can grasp its whole body.'},
            {'role': 'user', 'content': 'fork'},
            {'role': 'assistant', 'content': '<1> Use to pierce and lift food. <2> The fork is here, and its handle can be grasped.'},
            {'role': 'user', 'content': 'screwdrivers'},
            {'role': 'assistant', 'content': '<1> I need a tool to tighten or loosen screws. <2> The screwdriver is here, hold its handle to turn and control screws.'},
            {'role': 'user', 'content': category}
        ]
    }

    response = requests.post(API_URL, headers=HEADERS, json=payload)
    if response.status_code == 200:
        return response.json()['choices'][0]['message']['content']
    print(f"[API Error] {category}: {response.status_code} - {response.text}")
    return None


def process_json(entry):
    """
    Process a single image/class entry by generating reasoning and saving result to file.
    """
    class_name = entry['class_name']

    # Retry up to 5 times until the model emits both <1> and <2> markers.
    result = None
    for _ in range(5):
        candidate = process_sentence(class_name)
        if candidate and '<1>' in candidate and '<2>' in candidate:
            result = candidate
            break
    if result is None:
        print(f"[Retry Failed] {class_name}")
        return

    try:
        # <1> ... <2> ... -> question / answer split.
        question = result.split('<2>')[0].split('<1>')[-1].strip()
        answer = result.split('<2>')[-1].strip()

        save_dir = os.path.join(f'./reason_affordance/{DATASET}_hard_reasoning', class_name)
        os.makedirs(save_dir, exist_ok=True)

        save_path = os.path.join(save_dir, os.path.splitext(os.path.basename(entry['img_name']))[0] + ".json")
        output = {
            'img_name': entry['img_name'],
            'class_name': class_name,
            'question': question,
            'answer': answer
        }

        with open(save_path, 'w') as f:
            json.dump(output, f, indent=4)
        print(f"[Saved] {save_path}")
    except Exception as e:
        print(f"[Error] Failed to save {class_name}: {e}")


def main():
    """
    Main execution: loads data, then processes in parallel.
    """
    pkl_path = f'./data/{DATASET}_val.pkl'
    entries = read_pkl_file(pkl_path)

    with ThreadPoolExecutor(max_workers=2) as executor:
        executor.map(process_json, entries)


if __name__ == "__main__":
    main()
data_curation/vlpart_sam2_tracking.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import pickle
5
+ import argparse
6
+ import numpy as np
7
+ import warnings
8
+ from tqdm import tqdm
9
+ from pathlib import Path
10
+ from PIL import Image
11
+
12
+ from detectron2.data.detection_utils import read_image
13
+ from supervision import Detections, BoxAnnotator, MaskAnnotator, LabelAnnotator, mask_to_xyxy
14
+
15
+ from sam2.build_sam import build_sam2_video_predictor
16
+ from VLPart.build_vlpart import build_vlpart_model
17
+
18
+
19
+ warnings.filterwarnings('ignore')
20
+
21
+ # Constants
22
+ SAM2_CONFIG = "sam2_hiera_l.yaml"
23
+ SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
24
+ OUTPUT_ROOT = "/data/robot-merlin/mask_vlpart+sam2_tracking"
25
+ OUTPUT_ROOT_IMG = "/data/robot-merlin/mask_vlpart+sam2_tracking_with_image"
26
+
27
+ # Set up torch environment
28
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
29
+ if torch.cuda.get_device_properties(0).major >= 8:
30
+ torch.backends.cuda.matmul.allow_tf32 = True
31
+ torch.backends.cudnn.allow_tf32 = True
32
+
33
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
34
+
35
+
36
def load_affordance_data(pkl_path):
    """
    Load affordance samples from a pickle file and group them by the
    directory of the video each frame belongs to.

    Args:
        pkl_path (str): Pickle file holding a list of sample dicts,
            each with at least a 'frame_path' key.

    Returns:
        dict: Maps a video directory path to the list of its samples.
    """
    with open(pkl_path, 'rb') as f:
        samples = pickle.load(f)

    grouped = {}
    for sample in samples:
        video_dir = os.path.dirname(sample['frame_path'])
        grouped.setdefault(video_dir, []).append(sample)
    return grouped
52
+
53
+
54
def init_vlpart_once(text, prev_text, vlpart_model):
    """
    Lazily (re)build the VLPart model: only rebuild when the vocabulary
    text differs from the one the current model was built with.

    Returns:
        tuple: (model matching *text*, *text* itself to use as new prev_text).
    """
    if text == prev_text:
        return vlpart_model, text
    if vlpart_model is not None:
        del vlpart_model  # release the stale model before building a new one
    vlpart_model = build_vlpart_model(text)
    return vlpart_model, text
63
+
64
+
65
def run_vlpart_on_first_frame(vlpart_model, image_path):
    """
    Detect the target part on a single image with VLPart.

    Returns the predicted boxes as a numpy array, or None unless exactly
    one instance was detected (ambiguous/missing detections are skipped).
    """
    image = read_image(image_path, format="BGR")
    predictions, _ = vlpart_model.run_on_image(image)
    instances = predictions["instances"]
    if len(instances) != 1:
        return None
    return instances.pred_boxes.tensor.cpu().numpy()
74
+
75
+
76
def run_sam2_tracking(video_dir, frame_names, sam2_predictor, boxes):
    """
    Propagate the first-frame box through the whole clip with SAM2.

    Returns:
        dict: frame index -> {object id: binary mask (numpy bool array)}.
    """
    inference_state = sam2_predictor.init_state(video_path=video_dir)
    sam2_predictor.reset_state(inference_state)

    # Seed frame 0 with the VLPart box; object id 1 is the only tracked object.
    sam2_predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=1,
        box=boxes,
    )

    segments = {}
    for frame_idx, out_ids, out_logits in sam2_predictor.propagate_in_video(inference_state):
        segments[frame_idx] = {
            obj_id: (out_logits[i] > 0).cpu().numpy()
            for i, obj_id in enumerate(out_ids)
        }
    return segments
97
+
98
+
99
def save_tracking_results(video_dir, frame_names, video_segments, object_name, output_base, vid):
    """
    Write per-frame tracking outputs: raw binary masks under *output_base*
    and annotated previews (box + label + mask) under OUTPUT_ROOT_IMG.
    """
    # Single tracked object: SAM2 object id 1 maps to the vocabulary text.
    id_to_objects = {i: obj for i, obj in enumerate([object_name], start=1)}

    mask_dir = Path(f"{output_base}/{vid:06d}")
    mask_dir.mkdir(parents=True, exist_ok=True)

    vis_dir = Path(f"{OUTPUT_ROOT_IMG}/{vid:06d}")
    vis_dir.mkdir(parents=True, exist_ok=True)

    box_annotator = BoxAnnotator()
    label_annotator = LabelAnnotator()
    mask_annotator = MaskAnnotator()

    for frame_idx, id_to_mask in video_segments.items():
        frame = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))

        obj_ids = list(id_to_mask.keys())
        mask_arr = np.concatenate(list(id_to_mask.values()), axis=0)

        detections = Detections(
            xyxy=mask_to_xyxy(mask_arr),
            mask=mask_arr,
            class_id=np.array(obj_ids, dtype=np.int32),
        )

        annotated = box_annotator.annotate(frame.copy(), detections)
        annotated = label_annotator.annotate(annotated, detections, [id_to_objects[i] for i in obj_ids])
        annotated = mask_annotator.annotate(annotated, detections)

        cv2.imwrite(str(vis_dir / frame_names[frame_idx]), annotated)
        # Only one object, so the first mask is the full result (scaled to 0/255).
        cv2.imwrite(str(mask_dir / frame_names[frame_idx]), mask_arr[0] * 255)
135
+
136
+
137
def get_sorted_frame_names(video_dir):
    """Return JPEG frame filenames in *video_dir*, sorted by numeric stem."""
    frames = [
        name for name in os.listdir(video_dir)
        if name.lower().endswith(('.jpg', '.jpeg'))
    ]
    frames.sort(key=lambda name: int(os.path.splitext(name)[0]))
    return frames
142
+
143
+
144
def main(openx_data, text_override=None):
    """
    Run VLPart detection + SAM2 tracking over every clip in the given
    dataset and dump per-frame handle masks.

    Args:
        openx_data (str): Dataset name used to locate the affordance pkl.
        text_override (str | None): Optional fixed VLPart vocabulary; when
            None, "<task class> handle" is derived per clip.
    """
    # You can reorganize the data loading logic as needed
    data_dict = load_affordance_data(f'./data/{openx_data}_for_affordance.pkl')

    # Initialize SAM2 predictor once; VLPart is rebuilt per vocabulary.
    sam2_predictor = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CHECKPOINT, device=device)

    prev_text = ''
    vlpart_model = None

    for video_dir, data_list in tqdm(data_dict.items()):
        first_sample = data_list[0]
        frame_path = first_sample['frame_path']
        task_class = first_sample['task_object_class']

        # Restrict processing to handle-bearing object categories.
        if not any(key in task_class for key in ('door', 'drawer', 'knife')):
            continue

        # Rebuild VLPart only when the vocabulary actually changes.
        input_text = text_override if text_override else f"{task_class} handle"
        vlpart_model, prev_text = init_vlpart_once(input_text, prev_text, vlpart_model)

        # Detect the part on the first frame; skip ambiguous clips.
        boxes = run_vlpart_on_first_frame(vlpart_model, frame_path)
        if boxes is None:
            continue

        # Propagate through the clip and persist the masks.
        frame_names = get_sorted_frame_names(video_dir)
        segments = run_sam2_tracking(video_dir, frame_names, sam2_predictor, boxes)
        save_tracking_results(video_dir, frame_names, segments, input_text,
                              f"{OUTPUT_ROOT}/", first_sample['vid'])
        print(f"[Done] {frame_path} | {task_class}")
178
+
179
+
180
if __name__ == "__main__":
    parser = argparse.ArgumentParser("VLPart + SAM2 Tracking Demo")
    parser.add_argument("--pipeline", type=str, default="referring_expression_segmentation", help="Pipeline task")
    parser.add_argument("--text_input", type=str, default=None, help="Optional override for input text")
    parser.add_argument("--dataset", type=str, default="bridge", help="Dataset name (e.g., bridge)")
    args = parser.parse_args()

    # Bug fix: main(openx_data, text_override=None) accepts at most two
    # arguments; the old call main(args.dataset, args.pipeline, args.text_input)
    # passed three positionals and raised TypeError on every run.
    main(args.dataset, text_override=args.text_input)
docs/dataset.md ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Dataset
2
+
3
+ To train our affordance segmentation model, we use two types of data:
4
+ * **General Segmentation Data**: This follows [LISA](https://github.com/dvlab-research/LISA).
5
+ * **Affordance Segmentation Data**: This is a large-scale dataset that we collect.
6
+
7
+ ### General Segmentation Data
8
+ These data are organized as follows:
9
+ ```
10
+ ./data/
11
+ ├── lisa_data
12
+ │ ├── ade20k
13
+ │ ├── coco
14
+ │ ├── cocostuff
15
+ │ ├── llava_dataset
16
+ │ ├── mapillary
17
+ │ ├── reason_seg
18
+ │ ├── refer_seg
19
+ │ ├── vlpart
20
+ ```
21
+
22
+ ### Affordance Segmentation Data
23
+
24
+ We employ images from HANDAL, Open-X, GraspNet, EgoObjects, and RLBench in our affordance segmentation task.
25
+
26
+ The HANDAL data is downloaded and organized according to its official [repo](https://github.com/NVlabs/HANDAL).
27
+ Other data can be downloaded from the [Hugging Face](https://huggingface.co/datasets/Dongming97/RAGNet).
28
+
29
+ The training data is organized as follows:
30
+ ```
31
+ ./data/
32
+ ├── openx_train.pkl
33
+ ├── graspnet_train.pkl
34
+ ├── egoobjects_train.pkl
35
+ ├── rlbench_train.pkl
36
+ ├── handal_hard_reasoning_train.pkl
37
+ ├── egoobjects_easy_reasoning_train.pkl
38
+ ├── egoobjects_hard_reasoning_train.pkl
39
+ ├── HANDAL
40
+ │ ├── without_depth
41
+ │ ├── handal_dataset_adjustable_wrenches
42
+ │ ├── handal_dataset_combinational_wrenches
43
+ │ ├── handal_dataset_fixed_joint_pliers
44
+ │ ├── ...
45
+ ├── openx
46
+ │ ├── images
47
+ │ ├── fractal20220817_data
48
+ │ ├── bridge
49
+ │ ├── masks
50
+ │ ├── fractal20220817_data
51
+ │ ├── bridge
52
+ ├── graspnet
53
+ │ ├── images
54
+ │ ├── masks
55
+ │ ├── test_seen
56
+ │ ├── test_novel
57
+ ├── egoobjects
58
+ │ ├── images
59
+ │ ├── masks
60
+ ├── rlbench
61
+ │ ├── images
62
+ │ ├── masks
63
+ ├── 3doi
64
+ │ ├── images
65
+ │ ├── masks
66
+ ```
67
+
68
+ The evaluation data is also in the same directory, but with the `*_val.pkl` files instead of `*_train.pkl`.
69
+
70
+ ```
71
+ ./data/
72
+ ├── handal_mini_val.pkl
73
+ ├── graspnet_test_seen_val.pkl
74
+ ├── graspnet_test_novel_val.pkl
75
+ ├── 3doi_val.pkl
76
+ ├── handal_easy_reasoning_val.pkl
77
+ ├── handal_hard_reasoning_val.pkl
78
+ ├── 3doi_easy_reasoning_val.pkl
79
+ ```
80
+
81
+ You can use the following script to confirm if data is organized correctly:
82
+ ```bash
83
+ python data_curation/check_dataset.py
84
+ ```
85
+
86
+ ### About data curation
87
+ 1. **SAM2**: We use SAM2 to generate affordance mask if the dataset provides box annotation.
88
+ 2. **Florence-2 + SAM2**: We use Florence-2 to generate the initial segmentation masks of some complete objects, and then refine them with SAM2. Please see [Florence-2+SAM2](https://github.com/IDEA-Research/Grounded-SAM-2).
89
+ 3. **VLPart + SAM2**: We use VLPart to generate box of object part, and then refine them with SAM2. We refer to [VLPart](https://github.com/facebookresearch/VLPart).
90
+ We provide our inference demo scripts in `data_curation/build_vlpart.py` and `data_curation/vlpart_sam2_tracking.py`.
91
+ 4. **Reasoning Instruction**: We provide two example scripts to generate reasoning instructions for the affordance segmentation task:
92
+ - `data_curation/prompt_generation_handal_easy_reasoning.py`: This script generates easy reasoning instructions for the HANDAL dataset.
93
+ - `data_curation/prompt_generation_handal_hard_reasoning.py`: This script generates hard reasoning instructions for the HANDAL dataset.
docs/installation.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Installation
2
+ The environment installation mainly follows [LISA](https://github.com/dvlab-research/LISA).
3
+ ```
4
+ git clone https://github.com/wudongming97/AffordanceNet.git
5
+ cd AffordanceNet
6
+ conda create -n affordancenet python=3.9
7
+ conda activate affordancenet
8
+ pip install -r requirements.txt
9
+ pip install flash-attn --no-build-isolation
10
+ ```
docs/training_and_evaluation.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Training and Evaluation
2
+
3
+ ### Pre-trained Weights
4
+ #### LLaVA
5
+ For convenience of using pre-trained LLaVA weights, we provide a link from [Hugging Face](https://huggingface.co/Dongming97/LLaVA-Lightning-7B-v1-1).
6
+
7
+ #### SAM
8
+ Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth).
9
+
10
+
11
+ ### Training
12
+ To train AffordanceVLM, you can use the following command.
13
+ ```
14
+ bash ./scripts/train.sh
15
+ ```
16
+ When training is finished, to get the full model weight:
17
+
18
+ ```
19
+ cd ./runs/AffordanceVLM-7B/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
20
+ ```
21
+
22
+ ### Merge LoRA Weight
23
+ Merge the LoRA weights of `pytorch_model.bin`, save the resulting model into your desired path in the Hugging Face format:
24
+ ```
25
+ CUDA_VISIBLE_DEVICES="" python merge_lora_weights_and_save_hf_model.py \
26
+ --version="PATH_TO_LLaVA" \
27
+ --weight="PATH_TO_pytorch_model.bin" \
28
+ --save_path="PATH_TO_SAVED_MODEL"
29
+ ```
30
+
31
+ For example:
32
+ ```
33
+ CUDA_VISIBLE_DEVICES="" python3 merge_lora_weights_and_save_hf_model.py \
34
+ --version="./LLaVA/LLaVA-Lightning-7B-v1-1" \
35
+ --weight="./runs/AffordanceVLM-7B/pytorch_model.bin" \
36
+ --save_path="./exps/AffordanceVLM-7B"
37
+ ```
38
+
39
+ ### Evaluation
40
+ To evaluate AffordanceVLM on the entire [HANDAL](https://github.com/NVlabs/HANDAL) dataset, please adjust the `--dataset_dir` parameter in `evaluate.sh`.
41
+ ```
42
+ bash ./scripts/evaluate.sh
43
+ ```
44
+
45
+ To chat with [AffordanceVLM-7B](https://huggingface.co/Dongming97/AffordanceVLM):
46
+ ```
47
+ CUDA_VISIBLE_DEVICES=0 python chat.py --version=./exps/AffordanceVLM-7B
48
+ ```
49
+
50
+ ### Main Results
51
+
52
+ HANDAL:
53
+
54
+ | Method | gIoU | cIoU |
55
+ |:----------------:|:----:|-----:|
56
+ | AffordanceVLM-7B | 60.3 | 60.8 |
imgs/.ipynb_checkpoints/AffordanceNet-checkpoint.jpg ADDED

Git LFS Details

  • SHA256: 3abd71b7ead1d3353faf60d65da4ceeafed34314a4c123059b5d92f53685c797
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
imgs/AffordanceNet.jpg ADDED

Git LFS Details

  • SHA256: 3abd71b7ead1d3353faf60d65da4ceeafed34314a4c123059b5d92f53685c797
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
imgs/AffordanceNet.png ADDED

Git LFS Details

  • SHA256: 6c1537d2a0442b1685bdfdefbb8f028acf2cc9d90782a8f37c77037126aab550
  • Pointer size: 132 Bytes
  • Size of remote file: 1.88 MB
merge_lora_weights_and_save_hf_model.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import sys
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import transformers
11
+ from peft import LoraConfig, get_peft_model
12
+ from transformers import AutoTokenizer
13
+
14
+ from model.AffordanceVLM import AffordanceVLMForCausalLM
15
+ from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN
16
+
17
+
18
def parse_args(args):
    """Build and apply the CLI parser for the LoRA-merge / HF-export script.

    Args:
        args: Sequence of raw argument strings (e.g. ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with all options; ``--weight`` and ``--save_path``
        are mandatory.
    """
    p = argparse.ArgumentParser(
        description="merge lora weights and save model with hf format"
    )

    # Base model / IO locations.
    p.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    p.add_argument("--vis_save_path", default="./vis_output", type=str)
    p.add_argument("--weight", default="", type=str, required=True)
    p.add_argument("--save_path", default="./lisa_model", type=str, required=True)

    # Numeric precision for loading the checkpoint.
    p.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )

    # Vision / SAM configuration.
    p.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    p.add_argument("--out_dim", default=256, type=int)
    p.add_argument("--image_size", default=1024, type=int, help="image size")
    p.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )

    # Tokenizer / conversation settings.
    p.add_argument("--model_max_length", default=512, type=int)
    p.add_argument("--use_mm_start_end", action="store_true", default=True)
    p.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )

    # LoRA hyper-parameters (must match how the checkpoint was trained).
    p.add_argument("--lora_r", default=8, type=int)
    p.add_argument("--lora_alpha", default=16, type=int)
    p.add_argument("--lora_dropout", default=0.05, type=float)
    p.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)

    # Misc.
    p.add_argument("--local-rank", default=0, type=int, help="node rank")
    p.add_argument("--train_mask_decoder", action="store_true", default=True)

    return p.parse_args(args)
56
+
57
+
58
def main(args):
    """Merge trained LoRA weights into the base model and save an HF checkpoint.

    Steps: rebuild the exact tokenizer/model configuration used at training
    time, wrap the model with the same LoRA adapters, load the full
    `pytorch_model.bin` state dict, merge the adapters into the base weights,
    and save the result (minus the frozen CLIP vision tower) in Hugging Face
    format.

    Args:
        args: Raw CLI argument strings (typically ``sys.argv[1:]``).
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the special [SEG]/[AFF] tokens and remember their ids; the
    # model uses these ids to locate segmentation/affordance queries.
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    num_added_tokens = tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "seg_token_idx": args.seg_token_idx,
        "aff_token_idx": args.aff_token_idx,
        "vision_tower": args.vision_tower,
    }

    # Map the precision flag to a torch dtype (fp32 is the fallback).
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    # Instantiate the CLIP vision tower and the LISA-style SAM/projection
    # modules so the architecture matches the training-time module tree.
    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            # Collect every nn.Linear that matches a LoRA target name but is
            # not part of the vision/projection stacks (those stay frozen).
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))

        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Resize embeddings to account for the tokens added above, matching the
    # checkpoint's embedding-table shape before loading it strictly.
    model.resize_token_embeddings(len(tokenizer))

    state_dict = torch.load(args.weight, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    # Fold the LoRA adapters into the base weights, then drop the vision
    # tower from the saved state dict (it is re-created from CLIP at load).
    model = model.merge_and_unload()
    state_dict = {}
    for k, v in model.state_dict().items():
        if "vision_tower" not in k:
            state_dict[k] = v
    model.save_pretrained(args.save_path, state_dict=state_dict)
    tokenizer.save_pretrained(args.save_path)


if __name__ == "__main__":
    main(sys.argv[1:])
model/AffordanceVLM.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import BitsAndBytesConfig, CLIPVisionModel
7
+
8
+ from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
9
+ DEFAULT_IMAGE_PATCH_TOKEN)
10
+
11
+ from .llava.model.language_model.llava_llama import (LlavaLlamaForCausalLM,
12
+ LlavaLlamaModel)
13
+ from .segment_anything import build_sam_vit_h
14
+
15
+
16
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: Predicted mask logits, shape (num_masks, H, W); spatial dims
            are flattened internally.
        targets: Binary ground-truth masks of the same shape as ``inputs``
            (0 for the negative class and 1 for the positive class).
        num_masks: Normalizer for the summed per-mask losses.
        scale: Divisor applied to both numerator and denominator terms
            (cancels mathematically; kept for numerical range control).
        eps: Smoothing constant avoiding division by zero on empty masks.

    Returns:
        Scalar loss tensor: sum of per-mask dice losses divided by num_masks.
    """
    probs = inputs.sigmoid().flatten(1, 2)
    flat_targets = targets.flatten(1, 2)

    intersection = 2 * (probs / scale * flat_targets).sum(-1)
    union = (probs / scale).sum(-1) + (flat_targets / scale).sum(-1)

    per_mask = 1 - (intersection + eps) / (union + eps)
    return per_mask.sum() / (num_masks + 1e-8)
40
+
41
+
42
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Per-pixel binary cross-entropy on mask logits, averaged per mask.

    Args:
        inputs: Predicted mask logits, shape (num_masks, H, W).
        targets: Binary ground-truth masks of the same shape as ``inputs``
            (0 for the negative class and 1 for the positive class).
        num_masks: Normalizer for the summed per-mask losses.

    Returns:
        Scalar loss tensor: sum of per-mask mean BCE divided by num_masks.
    """
    per_pixel = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    per_mask = per_pixel.flatten(1, 2).mean(1)
    return per_mask.sum() / (num_masks + 1e-8)
60
+
61
+
62
class LisaMetaModel:
    """Mixin that attaches the SAM backbone and text-to-prompt projection.

    Designed to be combined (via MRO) with a LLaVA model class; it calls
    ``super().__init__(config)`` and then grafts a frozen SAM ViT-H plus a
    trainable projection head onto the resulting model.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(LisaMetaModel, self).__init__(config)

        self.config = config
        if not hasattr(self.config, "train_mask_decoder"):
            # Fresh config (training from a plain LLaVA checkpoint): record
            # the decoder/out_dim settings passed in via kwargs.
            self.config.train_mask_decoder = kwargs["train_mask_decoder"]
            self.config.out_dim = kwargs["out_dim"]
            self.vision_pretrained = kwargs.get("vision_pretrained", None)
        else:
            # Config already carries the LISA fields (e.g. reloading a merged
            # checkpoint): build the modules immediately.
            self.vision_pretrained = kwargs.get("vision_pretrained", None)
            self.initialize_lisa_modules(self.config)

    def initialize_lisa_modules(self, config):
        """Build the SAM visual model and the text->prompt projection head.

        SAM is fully frozen; only the mask decoder (if
        ``config.train_mask_decoder``) and the projection MLP are trainable.
        """
        # SAM
        self.visual_model = build_sam_vit_h(self.vision_pretrained)
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if config.train_mask_decoder:
            self.visual_model.mask_decoder.train()
            for param in self.visual_model.mask_decoder.parameters():
                param.requires_grad = True

        # Projection layer: maps LLM hidden states (hidden_size) to SAM's
        # prompt-embedding dimension (out_dim).
        in_dim = config.hidden_size
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim),
            nn.Dropout(0.0),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True
102
+
103
+
104
class LisaModel(LisaMetaModel, LlavaLlamaModel):
    """LLaVA-LLaMA backbone with the LISA SAM/projection mixin applied.

    The constructor only pins the multimodal config flags to the fixed
    settings this project assumes (square images, patch features, frozen
    mm-projector, no image-patch tokens).
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        super(LisaModel, self).__init__(config, **kwargs)

        self.config.use_cache = False
        self.config.vision_tower = self.config.mm_vision_tower
        self.config.mm_vision_select_feature = "patch"
        self.config.image_aspect_ratio = "square"
        self.config.image_grid_pinpoints = None
        self.config.tune_mm_mlp_adapter = False
        self.config.freeze_mm_mlp_adapter = True
        self.config.pretrain_mm_mlp_adapter = None
        self.config.mm_use_im_patch_token = False
122
+
123
class AffordanceVLMForCausalLM(LlavaLlamaForCausalLM):
    """LLaVA-based causal LM that also predicts segmentation masks.

    Hidden states at the positions of the special ``[SEG]``/``[AFF]`` tokens
    are projected into SAM prompt embeddings, and SAM's mask decoder turns
    them into per-query masks. Training combines the LM cross-entropy loss
    with BCE + dice mask losses.
    """

    def __init__(
        self,
        config,
        **kwargs,
    ):
        if not hasattr(config, "train_mask_decoder"):
            # Training from a plain LLaVA checkpoint: take the multimodal and
            # loss-weight settings from kwargs.
            config.mm_use_im_start_end = kwargs.pop("use_mm_start_end", True)
            config.mm_vision_tower = kwargs.get(
                "vision_tower", "openai/clip-vit-large-patch14"
            )
            self.ce_loss_weight = kwargs.pop("ce_loss_weight", None)
            self.dice_loss_weight = kwargs.pop("dice_loss_weight", None)
            self.bce_loss_weight = kwargs.pop("bce_loss_weight", None)
        else:
            # Reloading a merged checkpoint: config already carries everything.
            config.mm_vision_tower = config.vision_tower

        # Token ids of "[SEG]" and "[AFF]"; both mark mask-query positions.
        self.seg_token_idx = kwargs.pop("seg_token_idx")
        self.aff_token_idx = kwargs.pop("aff_token_idx")

        super().__init__(config)

        self.model = LisaModel(config, **kwargs)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_visual_embs(self, pixel_values: torch.FloatTensor):
        """Encode each image through the (frozen) SAM image encoder.

        Images are encoded one at a time (with cache clearing) to bound peak
        GPU memory; results are concatenated along the batch dimension.
        No gradients flow through the encoder.
        """
        with torch.no_grad():
            image_embeddings_list = []
            for i in range(pixel_values.shape[0]):
                torch.cuda.empty_cache()
                image_embeddings = self.model.visual_model.image_encoder(
                    pixel_values[i].unsqueeze(0)
                )
                image_embeddings_list.append(image_embeddings)
            torch.cuda.empty_cache()
            image_embeddings = torch.cat(image_embeddings_list, 0)
        return image_embeddings

    def forward(self, **kwargs):
        # During generation (past_key_values present) behave like plain LLaVA;
        # otherwise run the full segmentation-aware forward pass.
        if "past_key_values" in kwargs:
            return super().forward(**kwargs)
        return self.model_forward(**kwargs)

    def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        resize_list: List[tuple],
        inference: bool = False,
        **kwargs,
    ):
        """Full forward pass producing masks and (in training) losses.

        Args:
            images: SAM-preprocessed images (one per image in the batch).
            images_clip: CLIP-preprocessed images for the LLaVA tower.
            input_ids / labels / attention_masks: tokenized conversations;
                several conversations may share one image.
            offset: Cumulative counts mapping conversations to images
                (len(offset) == num_images + 1).
            masks_list: Ground-truth masks per image.
            label_list: Per-image tensors whose .shape gives the original size.
            resize_list: SAM input sizes used to undo padding/resizing.
            inference: If True, skip loss computation and return predictions.
        """
        image_embeddings = self.get_visual_embs(images)
        batch_size = image_embeddings.shape[0]
        assert batch_size == len(offset) - 1

        # Elementwise OR (via +) of the [SEG] and [AFF] positions, shifted by
        # one so the mask indexes the hidden state *producing* the token.
        seg_token_mask = (input_ids[:, 1:] == self.seg_token_idx) + (input_ids[:, 1:] == self.aff_token_idx)
        seg_token_mask = torch.cat(
            [
                seg_token_mask,
                torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda(),
            ],
            dim=1,
        )
        # hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
        seg_token_mask = torch.cat(
            [torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(), seg_token_mask],
            dim=1,
        )

        if inference:
            # Inference path assumes a single image shared by all
            # conversations; the CLIP image is broadcast across them.
            n_batch = 1
            length = input_ids.shape[0]
            assert images_clip.shape[0] == 1
            images_clip_extend = images_clip.expand(length, -1, -1, -1).contiguous()

            output_hidden_states = []
            for i in range(n_batch):
                start_i, end_i = i * length, min((i + 1) * length, input_ids.shape[0])
                output_i = super().forward(
                    images=images_clip_extend[: end_i - start_i],
                    attention_mask=attention_masks[start_i:end_i],
                    input_ids=input_ids[start_i:end_i],
                    output_hidden_states=True,
                )
                output_hidden_states.append(output_i.hidden_states)
                torch.cuda.empty_cache()

            output_hidden_states_list = []
            output_hidden_states_level = torch.cat(output_hidden_states, dim=0)
            output_hidden_states_list.append(output_hidden_states_level)
            output_hidden_states = output_hidden_states_list
            output = None

        else:
            # Training path: duplicate each CLIP image for every conversation
            # that refers to it, as described by `offset`.
            images_clip_list = []
            for i in range(len(offset) - 1):
                start_i, end_i = offset[i], offset[i + 1]
                images_clip_i = (
                    images_clip[i]
                    .unsqueeze(0)
                    .expand(end_i - start_i, -1, -1, -1)
                    .contiguous()
                )
                images_clip_list.append(images_clip_i)
            images_clip = torch.cat(images_clip_list, dim=0)

            output = super().forward(
                images=images_clip,
                attention_mask=attention_masks,
                input_ids=input_ids,
                labels=labels,
                output_hidden_states=True,
            )
            output_hidden_states = output.hidden_states

        hidden_states = []

        # Project last-layer hidden states to SAM's prompt dimension.
        assert len(self.model.text_hidden_fcs) == 1
        hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1]))

        last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
        pred_embeddings = last_hidden_state[seg_token_mask]
        seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]

        # Regroup the flat per-token embeddings back into per-image lists.
        seg_token_offset = seg_token_counts.cumsum(-1)
        seg_token_offset = torch.cat(
            [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
        )

        seg_token_offset = seg_token_offset[offset]

        pred_embeddings_ = []
        for i in range(len(seg_token_offset) - 1):
            start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
            pred_embeddings_.append(pred_embeddings[start_i:end_i])
        pred_embeddings = pred_embeddings_

        # Decode one mask per query embedding via SAM's prompt encoder +
        # mask decoder, then upsample to the original image size.
        multimask_output = False
        pred_masks = []
        for i in range(len(pred_embeddings)):
            (
                sparse_embeddings,
                dense_embeddings,
            ) = self.model.visual_model.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embeddings[i].unsqueeze(1),
            )
            sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
            low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            pred_mask = self.model.visual_model.postprocess_masks(
                low_res_masks,
                input_size=resize_list[i],
                original_size=label_list[i].shape,
            )
            pred_masks.append(pred_mask[:, 0])

        model_output = output
        gt_masks = masks_list

        if inference:
            return {
                "pred_masks": pred_masks,
                "gt_masks": gt_masks,
            }

        output = model_output.logits

        # Total loss = weighted LM cross-entropy + weighted BCE/dice mask
        # losses, with the mask losses normalized by the total mask count.
        ce_loss = model_output.loss
        ce_loss = ce_loss * self.ce_loss_weight
        mask_bce_loss = 0
        mask_dice_loss = 0
        num_masks = 0
        for batch_idx in range(len(pred_masks)):
            gt_mask = gt_masks[batch_idx]
            pred_mask = pred_masks[batch_idx]

            assert (
                gt_mask.shape[0] == pred_mask.shape[0]
            ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
                gt_mask.shape, pred_mask.shape
            )
            mask_bce_loss += (
                sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            mask_dice_loss += (
                dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
                * gt_mask.shape[0]
            )
            num_masks += gt_mask.shape[0]

        mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
        mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
        mask_loss = mask_bce_loss + mask_dice_loss

        loss = ce_loss + mask_loss

        return {
            "loss": loss,
            "ce_loss": ce_loss,
            "mask_bce_loss": mask_bce_loss,
            "mask_dice_loss": mask_dice_loss,
            "mask_loss": mask_loss,
        }

    def evaluate(
        self,
        images_clip,
        images,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=32,
        tokenizer=None,
    ):
        """Generate text autoregressively, then decode masks for every
        [SEG]/[AFF] token that appears in the generated sequence.

        Returns:
            (output_ids, pred_masks): generated token ids and a list of
            per-sample mask tensors at original image resolution.
        """
        with torch.no_grad():
            outputs = self.generate(
                images=images_clip,
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            output_hidden_states = outputs.hidden_states[-1]
            output_ids = outputs.sequences

            # Elementwise OR (via +) of [SEG] and [AFF] positions in the
            # generated sequence (shifted by one token).
            seg_token_mask = (output_ids[:, 1:] == self.seg_token_idx) + (output_ids[:, 1:] == self.aff_token_idx)
            # hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
            seg_token_mask = torch.cat(
                [
                    torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(),
                    seg_token_mask,
                ],
                dim=1,
            )

            hidden_states = []

            assert len(self.model.text_hidden_fcs) == 1
            hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))

            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
            pred_embeddings = last_hidden_state[seg_token_mask]

            seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]
            seg_token_offset = seg_token_counts.cumsum(-1)
            seg_token_offset = torch.cat(
                [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
            )

            # Split the flat embedding tensor back into one chunk per sample.
            pred_embeddings_ = []
            for i in range(len(seg_token_offset) - 1):
                start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
                pred_embeddings_.append(pred_embeddings[start_i:end_i])
            pred_embeddings = pred_embeddings_

            image_embeddings = self.get_visual_embs(images)

            # Same SAM decoding loop as model_forward, but using the caller's
            # original_size_list instead of ground-truth label shapes.
            multimask_output = False
            pred_masks = []
            for i in range(len(pred_embeddings)):
                (
                    sparse_embeddings,
                    dense_embeddings,
                ) = self.model.visual_model.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=pred_embeddings[i].unsqueeze(1),
                )

                sparse_embeddings = sparse_embeddings.to(pred_embeddings[i].dtype)
                low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
                    image_embeddings=image_embeddings[i].unsqueeze(0),
                    image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=multimask_output,
                )
                pred_mask = self.model.visual_model.postprocess_masks(
                    low_res_masks,
                    input_size=resize_list[i],
                    original_size=original_size_list[i],
                )
                pred_masks.append(pred_mask[:, 0])

        return output_ids, pred_masks
model/__pycache__/AffordanceVLM.cpython-39.pyc ADDED
Binary file (9.71 kB). View file
 
model/llava/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import LlavaLlamaForCausalLM
model/llava/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (192 Bytes). View file
 
model/llava/__pycache__/constants.cpython-39.pyc ADDED
Binary file (454 Bytes). View file
 
model/llava/__pycache__/conversation.cpython-39.pyc ADDED
Binary file (10.4 kB). View file
 
model/llava/__pycache__/mm_utils.cpython-39.pyc ADDED
Binary file (3.4 kB). View file
 
model/llava/constants.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
model/llava/conversation.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import Enum, auto
3
+ from typing import List, Tuple
4
+
5
+
6
class SeparatorStyle(Enum):
    """Different separator style.

    Selects how Conversation.get_prompt() joins system text, roles and
    messages into a single prompt string.
    """

    SINGLE = auto()   # one separator after every message
    TWO = auto()      # alternate sep / sep2 between turns
    MPT = auto()      # role and message concatenated directly, one separator
    PLAIN = auto()    # messages only (no role prefixes), alternating seps
    LLAMA_2 = auto()  # [INST] ... [/INST] wrapping with <<SYS>> system block
15
+
16
+ @dataclasses.dataclass
17
+ class Conversation:
18
+ """A class that keeps all conversation history."""
19
+
20
+ system: str
21
+ roles: List[str]
22
+ messages: List[List[str]]
23
+ offset: int
24
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
25
+ sep: str = "###"
26
+ sep2: str = None
27
+ version: str = "Unknown"
28
+
29
+ skip_next: bool = False
30
+
31
    def get_prompt(self):
        """Render the whole conversation into a single prompt string.

        The layout is selected by ``self.sep_style``; placement of the
        ``<image>`` placeholder in the first turn depends on ``self.version``.

        Returns:
            str: the fully formatted prompt, ending with an open role tag
            when the last message is empty/None (inviting the model to reply).
        """
        messages = self.messages
        # When the first user turn is an image tuple (text, image, process_mode),
        # work on a copy of the message list and normalize where the "<image>"
        # placeholder sits inside the first text message.
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if "mmtag" in self.version:
                # mmtag templates present the image as its own wrapped turn,
                # acknowledged by the assistant with "Received.".
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            # "system###role: msg###...": a single separator after every turn.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        # Image tuples carry (text, image, process_mode); only
                        # the text participates in the prompt.
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            # Two alternating separators (typically " " and an EOS token),
            # chosen by turn parity.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            # MPT role strings already end with "\n", so no ": " glue is added.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            # Llama-2 chat format: "[INST] ... [/INST] answer </s>" with the
            # system prompt folded into the first user message via <<SYS>> tags.
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        # User turn.
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        # Assistant turn.
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            # Drop the leading "<s>" so the very first token is not duplicated.
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            # Bare message bodies joined by alternating separators; roles are
            # omitted entirely (used for plain captioning-style pretraining).
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret
108
+
109
+ def append_message(self, role, message):
110
+ self.messages.append([role, message])
111
+
112
    def get_images(self, return_pil=False):
        """Collect the images attached to the conversation's turns.

        Only even-indexed messages past ``self.offset`` are inspected (in the
        standard alternating layout these are the user turns). Each image is
        first transformed according to its stored ``image_process_mode`` and
        then downscaled so the shortest edge is at most 400px and the longest
        at most 800px.

        Args:
            return_pil: if True, return PIL image objects; otherwise return
                base64-encoded PNG strings.

        Raises:
            ValueError: on an unknown ``image_process_mode``.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                # Image-bearing messages are tuples: (text, image, process_mode).
                if type(msg) is tuple:
                    # Imported lazily so text-only conversations never need PIL.
                    import base64
                    from io import BytesIO

                    from PIL import Image

                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":

                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Pad the shorter side with a neutral gray so the
                            # image becomes square without distortion.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(
                                    pil_img.mode, (width, width), background_color
                                )
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(
                                    pil_img.mode, (height, height), background_color
                                )
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        # "Crop" is a no-op here; any cropping happens elsewhere.
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(
                            f"Invalid image_process_mode: {image_process_mode}"
                        )
                    # Cap the size: short edge <= 400, long edge <= 800,
                    # preserving aspect ratio (and never upscaling past min_hw).
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images
170
+
171
    def to_gradio_chatbot(self):
        """Convert the conversation to Gradio chatbot format.

        Returns:
            list[list]: ``[user_content, assistant_content]`` pairs, where an
            uploaded image becomes an inline ``<img>`` HTML tag followed (when
            non-empty) by its accompanying text as a separate pair.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    # Imported lazily; only needed when images are present.
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    # Downscale the same way as get_images(): short edge <= 400,
                    # long edge <= 800, aspect ratio preserved.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    # NOTE(review): the bytes are saved as JPEG while the data
                    # URI below claims image/png — apparently harmless since
                    # browsers sniff the real content type, but worth confirming.
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    ret.append([img_str, None])
                    msg = msg.replace("<image>", "").strip()
                    if len(msg) > 0:
                        ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # Assistant turn: fill the reply slot of the most recent pair.
                ret[-1][-1] = msg
        return ret
204
+
205
+ def copy(self):
206
+ return Conversation(
207
+ system=self.system,
208
+ roles=self.roles,
209
+ messages=[[x, y] for x, y in self.messages],
210
+ offset=self.offset,
211
+ sep_style=self.sep_style,
212
+ sep=self.sep,
213
+ sep2=self.sep2,
214
+ version=self.version,
215
+ )
216
+
217
+ def dict(self):
218
+ if len(self.get_images()) > 0:
219
+ return {
220
+ "system": self.system,
221
+ "roles": self.roles,
222
+ "messages": [
223
+ [x, y[0] if type(y) is tuple else y] for x, y in self.messages
224
+ ],
225
+ "offset": self.offset,
226
+ "sep": self.sep,
227
+ "sep2": self.sep2,
228
+ }
229
+ return {
230
+ "system": self.system,
231
+ "roles": self.roles,
232
+ "messages": self.messages,
233
+ "offset": self.offset,
234
+ "sep": self.sep,
235
+ "sep2": self.sep2,
236
+ }
237
+
238
+
239
# ---------------------------------------------------------------------------
# Prebuilt conversation templates. These are module-level singletons; callers
# are expected to use ``.copy()`` before appending turns so the templates stay
# pristine.
# ---------------------------------------------------------------------------

# Vicuna v0: "###"-separated Human/Assistant turns, seeded with a two-turn
# few-shot example (hence offset=2, which hides it from get_images/chatbot).
conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Vicuna v1: USER/ASSISTANT with " " between turns and "</s>" closing each
# assistant reply (SeparatorStyle.TWO alternates the two separators).
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Llama-2 chat format with the official safety-oriented system prompt.
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# Llama-2 format with a vision-assistant system prompt (for LLaVA-style use).
conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# MPT/ChatML-style template: role markers carry their own newlines and turns
# end with "<|im_end|>".
conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Plain template: no system prompt, no role tags; messages joined by newlines.
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

# LLaVA v0: like Vicuna v0 but seeded only with a greeting exchange.
conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(("Human", "Hi!"), ("Assistant", "Hi there! How can I help you today?")),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# LLaVA v0 with mmtag image markup (<Image>...</Image>); see get_prompt().
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

# LLaVA v1: Vicuna-v1-style separators with the LLaVA system prompt.
conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# LLaVA v1 with mmtag image markup.
conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)
379
+
380
# Template used when the caller does not specify one.
default_conversation = conv_vicuna_v0

# Name -> template registry; callers select a template by one of these keys.
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "mpt": conv_mpt,
}


if __name__ == "__main__":
    # Quick smoke test: render the default template's prompt.
    print(default_conversation.get_prompt())
model/llava/mm_utils.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from io import BytesIO
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from transformers import StoppingCriteria
7
+
8
+ from .constants import IMAGE_TOKEN_INDEX
9
+
10
+
11
def load_image_from_base64(image):
    """Decode a base64-encoded image string into a PIL image object."""
    raw_bytes = base64.b64decode(image)
    return Image.open(BytesIO(raw_bytes))
13
+
14
+
15
def process_images(images, image_processor, model_cfg):
    """Run the image processor over ``images`` and return pixel tensors.

    ``model_cfg`` is accepted for API compatibility with other LLaVA
    preprocessing variants but is unused here.
    """
    processed = image_processor(images, return_tensors="pt")
    return processed["pixel_values"]
17
+
18
+
19
def tokenizer_image_token(
    prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
):
    """Tokenize ``prompt``, splicing ``image_token_index`` where each
    "<image>" placeholder appears.

    The surrounding text is tokenized chunk by chunk. A leading BOS token is
    emitted exactly once at the start; the duplicate BOS the tokenizer adds to
    every chunk is dropped.

    Returns:
        A list of token ids, or a 1-D ``torch.long`` tensor when
        ``return_tensors == "pt"``.

    Raises:
        ValueError: for any unsupported ``return_tensors`` value.
    """
    chunks = [tokenizer(piece).input_ids for piece in prompt.split("<image>")]

    token_ids = []
    skip = 0
    if (
        len(chunks) > 0
        and len(chunks[0]) > 0
        and chunks[0][0] == tokenizer.bos_token_id
    ):
        # Keep BOS once up front; skip the first token of every chunk below.
        skip = 1
        token_ids.append(chunks[0][0])

    for idx, chunk in enumerate(chunks):
        if idx:
            # One image token between consecutive text chunks.
            token_ids.append(image_token_index)
        token_ids.extend(chunk[skip:])

    if return_tensors is None:
        return token_ids
    if return_tensors == "pt":
        return torch.tensor(token_ids, dtype=torch.long)
    raise ValueError(f"Unsupported tensor type: {return_tensors}")
45
+
46
+
47
def get_model_name_from_path(model_path):
    """Derive a display name for a model from its filesystem path.

    For checkpoint directories ("...model/checkpoint-N") the parent directory
    name is prepended so distinct runs stay distinguishable.
    """
    parts = model_path.strip("/").split("/")
    last = parts[-1]
    if last.startswith("checkpoint-"):
        return f"{parts[-2]}_{last}"
    return last
54
+
55
+
56
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keyword strings is produced.

    Each step, matching is attempted two ways: (1) an exact token-id suffix
    match against the pre-tokenized keywords, and (2) a substring search in
    the decoded tail of the newly generated text (covers keywords that
    tokenize differently depending on context).
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: list of strings whose appearance should stop generation.
            tokenizer: tokenizer used to pre-encode the keywords and to decode
                the generated tail.
            input_ids: prompt ids of shape (1, prompt_len); only the length is
                used, to know where generation starts.
        """
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Strip a leading BOS token so the keyword can match mid-sequence.
            if (
                len(cur_keyword_ids) > 1
                and cur_keyword_ids[0] == tokenizer.bos_token_id
            ):
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        # Decode at most the last 3 newly generated tokens for the text search.
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [
            keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
        ]
        for keyword_id in self.keyword_ids:
            # BUGFIX: the original `if output_ids[0, -n:] == keyword_id:` is an
            # elementwise comparison — for multi-token keywords the resulting
            # bool tensor raises "Boolean value of Tensor with more than one
            # element is ambiguous" inside `if`, and a shorter-than-keyword
            # output makes the shapes mismatch. torch.equal does a whole-tensor
            # comparison and simply returns False when shapes differ.
            if torch.equal(output_ids[0, -keyword_id.shape[0] :], keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(
            output_ids[:, -offset:], skip_special_tokens=True
        )[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
+ return False