Update weights
- __pycache__/configuration_earthmind_chat.cpython-310.pyc +0 -0
- __pycache__/configuration_intern_vit.cpython-310.pyc +0 -0
- __pycache__/configuration_internlm2.cpython-310.pyc +0 -0
- __pycache__/configuration_phi3.cpython-310.pyc +0 -0
- __pycache__/flash_attention.cpython-310.pyc +0 -0
- __pycache__/modeling_earthmind_chat.cpython-310.pyc +0 -0
- __pycache__/modeling_intern_vit.cpython-310.pyc +0 -0
- __pycache__/modeling_internlm2.cpython-310.pyc +0 -0
- __pycache__/modeling_phi3.cpython-310.pyc +0 -0
- __pycache__/sam2.cpython-310.pyc +0 -0
- __pycache__/templates.cpython-310.pyc +0 -0
- config.json +1 -1
- model-00001-of-00004.safetensors +2 -2
- model-00002-of-00004.safetensors +1 -1
- model-00003-of-00004.safetensors +1 -1
- model-00004-of-00004.safetensors +1 -1
- model.safetensors.index.json +1 -2
- modeling_earthmind_chat.py +93 -25
__pycache__/configuration_earthmind_chat.cpython-310.pyc
CHANGED
Binary files a/__pycache__/configuration_earthmind_chat.cpython-310.pyc and b/__pycache__/configuration_earthmind_chat.cpython-310.pyc differ

__pycache__/configuration_intern_vit.cpython-310.pyc
CHANGED
Binary files a/__pycache__/configuration_intern_vit.cpython-310.pyc and b/__pycache__/configuration_intern_vit.cpython-310.pyc differ

__pycache__/configuration_internlm2.cpython-310.pyc
CHANGED
Binary files a/__pycache__/configuration_internlm2.cpython-310.pyc and b/__pycache__/configuration_internlm2.cpython-310.pyc differ

__pycache__/configuration_phi3.cpython-310.pyc
CHANGED
Binary files a/__pycache__/configuration_phi3.cpython-310.pyc and b/__pycache__/configuration_phi3.cpython-310.pyc differ

__pycache__/flash_attention.cpython-310.pyc
CHANGED
Binary files a/__pycache__/flash_attention.cpython-310.pyc and b/__pycache__/flash_attention.cpython-310.pyc differ

__pycache__/modeling_earthmind_chat.cpython-310.pyc
CHANGED
Binary files a/__pycache__/modeling_earthmind_chat.cpython-310.pyc and b/__pycache__/modeling_earthmind_chat.cpython-310.pyc differ

__pycache__/modeling_intern_vit.cpython-310.pyc
CHANGED
Binary files a/__pycache__/modeling_intern_vit.cpython-310.pyc and b/__pycache__/modeling_intern_vit.cpython-310.pyc differ

__pycache__/modeling_internlm2.cpython-310.pyc
CHANGED
Binary files a/__pycache__/modeling_internlm2.cpython-310.pyc and b/__pycache__/modeling_internlm2.cpython-310.pyc differ

__pycache__/modeling_phi3.cpython-310.pyc
CHANGED
Binary files a/__pycache__/modeling_phi3.cpython-310.pyc and b/__pycache__/modeling_phi3.cpython-310.pyc differ

__pycache__/sam2.cpython-310.pyc
CHANGED
Binary files a/__pycache__/sam2.cpython-310.pyc and b/__pycache__/sam2.cpython-310.pyc differ

__pycache__/templates.cpython-310.pyc
CHANGED
Binary files a/__pycache__/templates.cpython-310.pyc and b/__pycache__/templates.cpython-310.pyc differ
config.json
CHANGED
@@ -102,7 +102,7 @@
   "select_layer": -1,
   "template": "phi3_chat",
   "tie_word_embeddings": false,
-  "torch_dtype": "
+  "torch_dtype": "bfloat16",
   "transformers_version": null,
   "use_backbone_lora": 0,
   "use_llm_lora": 0,
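The substantive change above is that config.json now records "torch_dtype": "bfloat16". A minimal loading sketch, assuming a local checkout of this repository and the usual transformers remote-code flow (the ./EarthMind-Chat path is a placeholder, not part of this commit):

import torch
from transformers import AutoModel, AutoTokenizer

path = "./EarthMind-Chat"  # placeholder: local path to this repository
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,  # matches the dtype now recorded in config.json
    trust_remote_code=True,
).eval()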
model-00001-of-00004.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ffb2ee69f62e4ee01ef85baa7899ef2a2058a1ddfa8f7d56ca015b7a57ae57cc
+size 4971473960

model-00002-of-00004.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:da0766717029b20f96662a83365ebbca20b8f71b3a4886f21851d58620c023a6
 size 4932952216

model-00003-of-00004.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ffc1300216ae100b755f554627400b6e039bdcf06d889b9c5f7fe198445dc2ca
 size 4995688160

model-00004-of-00004.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:c94104c4b476c29e652268cbe580828af39841c553d443c3984c0d4619b572ce
 size 259328744
model.safetensors.index.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "metadata": {
-    "total_size":
+    "total_size": 15159214280
   },
   "weight_map": {
     "grounding_encoder.sam2_model.image_encoder.neck.convs.0.conv.bias": "model-00004-of-00004.safetensors",
@@ -1338,7 +1338,6 @@
     "language_model.model.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
     "language_model.model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
     "language_model.model.norm.weight": "model-00003-of-00004.safetensors",
-    "local_query": "model-00001-of-00004.safetensors",
     "mlp1.0.bias": "model-00003-of-00004.safetensors",
     "mlp1.0.weight": "model-00003-of-00004.safetensors",
     "mlp1.1.bias": "model-00003-of-00004.safetensors",
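The index now carries the updated total_size and drops the stale "local_query" entry. As a quick sanity check, a small sketch (assuming the index file and the four shards sit in the current directory) that compares the declared total_size with the shard bytes on disk:

import json
import os

with open("model.safetensors.index.json") as f:
    index = json.load(f)

shards = sorted(set(index["weight_map"].values()))
on_disk = sum(os.path.getsize(s) for s in shards)

print("declared total_size:", index["metadata"]["total_size"])  # 15159214280
print("shard bytes on disk:", on_disk)  # slightly larger, since each shard also stores a safetensors header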
modeling_earthmind_chat.py
CHANGED
@@ -3,7 +3,7 @@
 # Copyright (c) 2024 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-
+from math import sqrt
 import warnings
 from typing import Any, List, Optional, Tuple, Union

@@ -113,7 +113,9 @@ class Sa2VAChatModel(PreTrainedModel):
         self.ps_version = config.ps_version
         self.llm_arch_name = config.llm_config.architectures[0]

-
+
+
+        self.hca_tau = 1.0

         use_flash_attn = use_flash_attn if has_flash_attn else False
         config.vision_config.use_flash_attn = True if use_flash_attn else False
@@ -334,7 +336,6 @@ class Sa2VAChatModel(PreTrainedModel):

         B, N, C = input_embeds.shape
         input_embeds = input_embeds.reshape(B * N, C)
-

         self._count += 1

@@ -550,6 +551,7 @@ class Sa2VAChatModel(PreTrainedModel):
     ) -> torch.LongTensor:
         device = self.device
         assert self.img_context_token_id is not None
+        input_embeds = self.language_model.get_input_embeddings()(input_ids.to(device))

         if pixel_values is not None:
             if visual_features is not None:
@@ -572,14 +574,60 @@ class Sa2VAChatModel(PreTrainedModel):

             vit_embeds = self.extract_feature(pixel_values.to(device))
             rgb_vit_embeds = self.extract_feature(rgb_pixel_values.to(device))
-
+
+
+            print("extract_featrues",rgb_vit_embeds.shape,vit_embeds.shape)
+            if rgb_vit_embeds.shape[0] != vit_embeds.shape[0]:
+                # average-pool over the batch dimension, keeping the last two dimensions unchanged
+                rgb_vit_embeds = rgb_vit_embeds.mean(dim=0, keepdim=True)
+                print("after avgpooling:", rgb_vit_embeds.shape)
+            X_sar = vit_embeds
+            X_rgb = rgb_vit_embeds
+
+
+            tau = 0.2  # temperature, spreads out the distribution
+            tau_txt = 0.2
+            D = X_rgb.size(-1)
+
+            # cross-attention between the two modalities
+            A_rs = torch.matmul(X_rgb, X_sar.transpose(-2, -1)) / (sqrt(D) * tau)  # (B,N,N)
+            A_rs = F.softmax(A_rs, dim=-1)  # RGB->SAR
+
+            A_sr = torch.matmul(X_sar, X_rgb.transpose(-2, -1)) / (sqrt(D) * tau)  # (B,N,N)
+            A_sr = F.softmax(A_sr, dim=-1)  # SAR->RGB
+
+            # diagonal = mutual affinity of corresponding positions
+            r_sar = torch.diagonal(A_rs, dim1=-2, dim2=-1)  # (B, N) reliability of the matching SAR position as judged by RGB
+            r_rgb = torch.diagonal(A_sr, dim1=-2, dim2=-1)  # (B, N) reliability of the matching RGB position as judged by SAR
+
+            ###################### add visual-text cross attention
+            t = input_embeds.mean(dim=1)
+            beta_rgb = torch.matmul(X_rgb, t.unsqueeze(-1)).squeeze(-1) / tau_txt  # (B, N)
+            beta_sar = torch.matmul(X_sar, t.unsqueeze(-1)).squeeze(-1) / tau_txt  # (B, N)
+
+
+            # turn both branches into probabilities first; linear interpolation in logit space is more stable
+            eps = 1e-6
+            vis_pair = torch.stack([r_rgb, r_sar], dim=-1)  # (B, N, 2)
+            txt_pair = torch.stack([beta_rgb, beta_sar], dim=-1)  # (B, N, 2)
+
+
+            logits = 0.5 * vis_pair + 0.5 * txt_pair  # (B, N, 2)
+
+            alpha = F.softmax(logits, dim=-1)  # (B, N, 2)
+            alpha_rgb = alpha[..., 0].unsqueeze(-1)  # (B, N, 1)
+            alpha_sar = alpha[..., 1].unsqueeze(-1)  # (B, N, 1)
+
+            # position-wise weighted fusion
+            Z = alpha_rgb * X_rgb + alpha_sar * X_sar  # (B,N,D)
+            vit_embeds = Z

             image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
             image_flags = image_flags.long()
             vit_embeds = vit_embeds[image_flags == 1]


-
+
             B, N, C = input_embeds.shape
             input_embeds = input_embeds.reshape(B * N, C)

@@ -648,7 +696,7 @@ class Sa2VAChatModel(PreTrainedModel):

         # print("generate",encode_outputs.hidden_states[-1][0].shape)
         encode_feature=encode_outputs.hidden_states[-1][0]
-        return outputs,encode_feature,encode_outputs.attentions
+        return outputs,encode_feature,encode_outputs.attentions,(alpha_rgb,alpha_sar)

     def preparing_for_generation(self, tokenizer, max_new_tokens=2048, torch_dtype=torch.bfloat16):
         # set stop criteria and generation configs for model
@@ -784,7 +832,7 @@ class Sa2VAChatModel(PreTrainedModel):
                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
            ]).to(self.torch_dtype)

-           images = dynamic_preprocess(image, self.min_dynamic_patch,
+           images,weight = dynamic_preprocess(image, self.min_dynamic_patch,
                                        self.max_dynamic_patch,
                                        self.image_size, self.use_thumbnail)

@@ -899,11 +947,6 @@ class Sa2VAChatModel(PreTrainedModel):



-
-
-
-
-
    def predict_forward_multi(
            self,
            image=None,
@@ -966,7 +1009,7 @@ class Sa2VAChatModel(PreTrainedModel):
            input_dict['vp_overall_mask'] = None
        else:
            ori_image_size = image.size
-
+
            # prepare grounding images
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
@@ -976,14 +1019,11 @@ class Sa2VAChatModel(PreTrainedModel):
                self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
            ]).to(self.torch_dtype)

-           images = dynamic_preprocess(image, self.min_dynamic_patch,
+           images,sta = dynamic_preprocess(image, self.min_dynamic_patch,
                                        self.max_dynamic_patch,
                                        self.image_size, self.use_thumbnail)

-
-           rgb_images = dynamic_preprocess(rgb_image, self.min_dynamic_patch,
-                                       self.max_dynamic_patch,
-                                       self.image_size, self.use_thumbnail)
+



@@ -996,11 +1036,34 @@ class Sa2VAChatModel(PreTrainedModel):
            pixel_values = [self.transformer(image) for image in images]
            pixel_values = torch.stack(pixel_values).to(self.torch_dtype)

-
-
-
+           if type(rgb_image) is list:
+               pixel_values_list = []
+               for img_rgb_like in rgb_image:
+                   sub_images,_ = dynamic_preprocess(
+                       img_rgb_like,
+                       self.min_dynamic_patch,
+                       self.max_dynamic_patch,
+                       self.image_size,
+                       self.use_thumbnail
+                   )
+                   pixel_values_list.extend([self.transformer(si) for si in sub_images])
+
+               rgb_pixel_values = torch.stack(pixel_values_list).to(self.torch_dtype)  # shape: [M_total, 3, 44
+
+           else:
+               rgb_images,sta = dynamic_preprocess(rgb_image, self.min_dynamic_patch,
+                                           self.max_dynamic_patch,
+                                           self.image_size, self.use_thumbnail)
+
+
+
+
+               rgb_pixel_values = [self.transformer(image) for image in rgb_images]
+               rgb_pixel_values = torch.stack(rgb_pixel_values).to(self.torch_dtype)
+
+           print("input",rgb_pixel_values.shape,pixel_values.shape)

-           num_image_tokens = pixel_values.shape[0] * self.patch_token
+           num_image_tokens = pixel_values.shape[0] * self.patch_token
            num_frames = 1
            input_dict['g_pixel_values'] = g_pixel_values
            input_dict['pixel_values'] = pixel_values
@@ -1067,7 +1130,7 @@ class Sa2VAChatModel(PreTrainedModel):
            'vp_overall_mask': input_dict['vp_overall_mask'],
        }

-       generate_output,encode_feature,encode_attention = self.generate_multi(
+       generate_output,encode_feature,encode_attention,vis_weight = self.generate_multi(
           **mm_inputs,
           generation_config=self.gen_config,
           streamer=None,
@@ -1109,7 +1172,7 @@ class Sa2VAChatModel(PreTrainedModel):
            masks = masks.cpu().numpy()
            ret_masks.append(masks)

-       return {'prediction': predict, 'prediction_masks': ret_masks,}
+       return {'prediction': predict, 'prediction_masks': ret_masks,'vis_weight':vis_weight,"sta":sta}

 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id
@@ -1159,6 +1222,11 @@ def dynamic_preprocess(image,
     target_height = image_size * target_aspect_ratio[1]
     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

+    cols = target_aspect_ratio[0]
+    rows = target_aspect_ratio[1]
+
+
+
     # resize the image
     resized_img = image.resize((target_width, target_height))
     processed_images = []
@@ -1174,7 +1242,7 @@ def dynamic_preprocess(image,
     if use_thumbnail and len(processed_images) != 1:
         thumbnail_img = image.resize((image_size, image_size))
         processed_images.append(thumbnail_img)
-    return processed_images
+    return processed_images,(cols,rows)


 from transformers.cache_utils import Cache, DynamicCache
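For orientation, the core of this commit is the RGB-SAR token fusion added inside the generation path: each modality's token is scored by its cross-modal agreement (the diagonal of the RGB<->SAR attention) and by its affinity to the pooled text embedding, and the two branches are then mixed with softmax weights. A standalone sketch of that step, assuming two aligned token grids of shape (B, N, D) and a pooled text vector of shape (B, D); the function and tensor names here are illustrative, not the model's API:

from math import sqrt
import torch
import torch.nn.functional as F

def fuse_rgb_sar(x_rgb, x_sar, text, tau=0.2, tau_txt=0.2):
    """Position-wise fusion of RGB and SAR tokens, mirroring the logic in this commit."""
    d = x_rgb.size(-1)
    # cross-modal attention; the diagonal scores how strongly position i in one
    # modality attends to the same position i in the other
    a_rs = F.softmax(x_rgb @ x_sar.transpose(-2, -1) / (sqrt(d) * tau), dim=-1)
    a_sr = F.softmax(x_sar @ x_rgb.transpose(-2, -1) / (sqrt(d) * tau), dim=-1)
    r_sar = torch.diagonal(a_rs, dim1=-2, dim2=-1)            # (B, N)
    r_rgb = torch.diagonal(a_sr, dim1=-2, dim2=-1)            # (B, N)
    # affinity of each visual token to the pooled text embedding
    beta_rgb = (x_rgb @ text.unsqueeze(-1)).squeeze(-1) / tau_txt
    beta_sar = (x_sar @ text.unsqueeze(-1)).squeeze(-1) / tau_txt
    # average the visual and textual scores, then softmax over the two modalities
    logits = 0.5 * torch.stack([r_rgb, r_sar], dim=-1) + 0.5 * torch.stack([beta_rgb, beta_sar], dim=-1)
    alpha = F.softmax(logits, dim=-1)                          # (B, N, 2)
    return alpha[..., 0:1] * x_rgb + alpha[..., 1:2] * x_sar   # (B, N, D)

# toy usage with random tensors
x_rgb, x_sar = torch.randn(1, 256, 64), torch.randn(1, 256, 64)
text = torch.randn(1, 64)
print(fuse_rgb_sar(x_rgb, x_sar, text).shape)  # torch.Size([1, 256, 64])

The per-position weights returned as alpha in the model (alpha_rgb, alpha_sar) are what generate_multi now surfaces back to predict_forward_multi as vis_weight, alongside the (cols, rows) tiling statistics from dynamic_preprocess.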