sy1998 committed
Commit f17c66b · verified · 1 Parent(s): 246b86d

Update weights

__pycache__/configuration_earthmind_chat.cpython-310.pyc CHANGED
Binary files a/__pycache__/configuration_earthmind_chat.cpython-310.pyc and b/__pycache__/configuration_earthmind_chat.cpython-310.pyc differ
 
__pycache__/configuration_intern_vit.cpython-310.pyc CHANGED
Binary files a/__pycache__/configuration_intern_vit.cpython-310.pyc and b/__pycache__/configuration_intern_vit.cpython-310.pyc differ
 
__pycache__/configuration_internlm2.cpython-310.pyc CHANGED
Binary files a/__pycache__/configuration_internlm2.cpython-310.pyc and b/__pycache__/configuration_internlm2.cpython-310.pyc differ
 
__pycache__/configuration_phi3.cpython-310.pyc CHANGED
Binary files a/__pycache__/configuration_phi3.cpython-310.pyc and b/__pycache__/configuration_phi3.cpython-310.pyc differ
 
__pycache__/flash_attention.cpython-310.pyc CHANGED
Binary files a/__pycache__/flash_attention.cpython-310.pyc and b/__pycache__/flash_attention.cpython-310.pyc differ
 
__pycache__/modeling_earthmind_chat.cpython-310.pyc CHANGED
Binary files a/__pycache__/modeling_earthmind_chat.cpython-310.pyc and b/__pycache__/modeling_earthmind_chat.cpython-310.pyc differ
 
__pycache__/modeling_intern_vit.cpython-310.pyc CHANGED
Binary files a/__pycache__/modeling_intern_vit.cpython-310.pyc and b/__pycache__/modeling_intern_vit.cpython-310.pyc differ
 
__pycache__/modeling_internlm2.cpython-310.pyc CHANGED
Binary files a/__pycache__/modeling_internlm2.cpython-310.pyc and b/__pycache__/modeling_internlm2.cpython-310.pyc differ
 
__pycache__/modeling_phi3.cpython-310.pyc CHANGED
Binary files a/__pycache__/modeling_phi3.cpython-310.pyc and b/__pycache__/modeling_phi3.cpython-310.pyc differ
 
__pycache__/sam2.cpython-310.pyc CHANGED
Binary files a/__pycache__/sam2.cpython-310.pyc and b/__pycache__/sam2.cpython-310.pyc differ
 
__pycache__/templates.cpython-310.pyc CHANGED
Binary files a/__pycache__/templates.cpython-310.pyc and b/__pycache__/templates.cpython-310.pyc differ
 
config.json CHANGED
@@ -102,7 +102,7 @@
   "select_layer": -1,
   "template": "phi3_chat",
   "tie_word_embeddings": false,
-  "torch_dtype": "float32",
+  "torch_dtype": "bfloat16",
   "transformers_version": null,
   "use_backbone_lora": 0,
   "use_llm_lora": 0,
model-00001-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:39e63f528727404f1bcdbf73e14d96ed8d27e98ed19c3fe9bbced85c0604130a
-size 4971490432
+oid sha256:ffb2ee69f62e4ee01ef85baa7899ef2a2058a1ddfa8f7d56ca015b7a57ae57cc
+size 4971473960
model-00002-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2bfb324f82fc39c5726321b3415ed4be3e942967f0cfbb50729cb5948ac78a6c
+oid sha256:da0766717029b20f96662a83365ebbca20b8f71b3a4886f21851d58620c023a6
 size 4932952216
model-00003-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:42f354c28221ee31e11f68d6981a3ad3a0d8d0264b2978957d01746233a30b29
+oid sha256:ffc1300216ae100b755f554627400b6e039bdcf06d889b9c5f7fe198445dc2ca
 size 4995688160
model-00004-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9ab80a04199bb5a17e5696a716037554e7d2a8ffc230e3863cd86dd29b587acc
+oid sha256:c94104c4b476c29e652268cbe580828af39841c553d443c3984c0d4619b572ce
 size 259328744
model.safetensors.index.json CHANGED
@@ -1,6 +1,6 @@
 {
   "metadata": {
-    "total_size": 15159230664
+    "total_size": 15159214280
   },
   "weight_map": {
     "grounding_encoder.sam2_model.image_encoder.neck.convs.0.conv.bias": "model-00004-of-00004.safetensors",
@@ -1338,7 +1338,6 @@
     "language_model.model.layers.9.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
     "language_model.model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
     "language_model.model.norm.weight": "model-00003-of-00004.safetensors",
-    "local_query": "model-00001-of-00004.safetensors",
     "mlp1.0.bias": "model-00003-of-00004.safetensors",
     "mlp1.0.weight": "model-00003-of-00004.safetensors",
     "mlp1.1.bias": "model-00003-of-00004.safetensors",
modeling_earthmind_chat.py CHANGED
@@ -3,7 +3,7 @@
 # Copyright (c) 2024 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-
+from math import sqrt
 import warnings
 from typing import Any, List, Optional, Tuple, Union
 
@@ -113,7 +113,9 @@ class Sa2VAChatModel(PreTrainedModel):
         self.ps_version = config.ps_version
         self.llm_arch_name = config.llm_config.architectures[0]
 
-        self.local_query = nn.Parameter(torch.randn(2, 2048))
+
+
+        self.hca_tau = 1.0
 
         use_flash_attn = use_flash_attn if has_flash_attn else False
         config.vision_config.use_flash_attn = True if use_flash_attn else False
@@ -334,7 +336,6 @@
 
         B, N, C = input_embeds.shape
         input_embeds = input_embeds.reshape(B * N, C)
-
 
         self._count += 1
 
@@ -550,6 +551,7 @@
     ) -> torch.LongTensor:
         device = self.device
         assert self.img_context_token_id is not None
+        input_embeds = self.language_model.get_input_embeddings()(input_ids.to(device))
 
         if pixel_values is not None:
             if visual_features is not None:
@@ -572,14 +574,60 @@
 
         vit_embeds = self.extract_feature(pixel_values.to(device))
         rgb_vit_embeds = self.extract_feature(rgb_pixel_values.to(device))
-        vit_embeds = torch.cat([vit_embeds, rgb_vit_embeds], dim=1)  # 10, 512, 2048
+
+
+        print("extract_featrues", rgb_vit_embeds.shape, vit_embeds.shape)
+        if rgb_vit_embeds.shape[0] != vit_embeds.shape[0]:
+            # average-pool over the batch dimension, keeping the last two dimensions unchanged
+            rgb_vit_embeds = rgb_vit_embeds.mean(dim=0, keepdim=True)
+            print("after avgpooling:", rgb_vit_embeds.shape)
+        X_sar = vit_embeds
+        X_rgb = rgb_vit_embeds
+
+
+        tau = 0.2  # temperature, spreads out the distribution
+        tau_txt = 0.2
+        D = X_rgb.size(-1)
+
+        # cross-modal attention
+        A_rs = torch.matmul(X_rgb, X_sar.transpose(-2, -1)) / (sqrt(D) * tau)  # (B, N, N)
+        A_rs = F.softmax(A_rs, dim=-1)  # RGB -> SAR
+
+        A_sr = torch.matmul(X_sar, X_rgb.transpose(-2, -1)) / (sqrt(D) * tau)  # (B, N, N)
+        A_sr = F.softmax(A_sr, dim=-1)  # SAR -> RGB
+
+        # diagonal = mutual affinity between corresponding positions
+        r_sar = torch.diagonal(A_rs, dim1=-2, dim2=-1)  # (B, N) reliability of the matching SAR position, as judged by RGB
+        r_rgb = torch.diagonal(A_sr, dim1=-2, dim2=-1)  # (B, N) reliability of the matching RGB position, as judged by SAR
+
+        ###################### add visual-text cross attention
+        t = input_embeds.mean(dim=1)
+        beta_rgb = torch.matmul(X_rgb, t.unsqueeze(-1)).squeeze(-1) / tau_txt  # (B, N)
+        beta_sar = torch.matmul(X_sar, t.unsqueeze(-1)).squeeze(-1) / tau_txt  # (B, N)
+
+
+        # turn both branches into probabilities first; linear interpolation in logit space is more stable
+        eps = 1e-6
+        vis_pair = torch.stack([r_rgb, r_sar], dim=-1)  # (B, N, 2)
+        txt_pair = torch.stack([beta_rgb, beta_sar], dim=-1)  # (B, N, 2)
+
+
+        logits = 0.5 * vis_pair + 0.5 * txt_pair  # (B, N, 2)
+
+        alpha = F.softmax(logits, dim=-1)  # (B, N, 2)
+        alpha_rgb = alpha[..., 0].unsqueeze(-1)  # (B, N, 1)
+        alpha_sar = alpha[..., 1].unsqueeze(-1)  # (B, N, 1)
+
+        # position-wise weighted fusion
+        Z = alpha_rgb * X_rgb + alpha_sar * X_sar  # (B, N, D)
+        vit_embeds = Z
 
         image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
         image_flags = image_flags.long()
         vit_embeds = vit_embeds[image_flags == 1]
 
 
-        input_embeds = self.language_model.get_input_embeddings()(input_ids.to(device))
+
         B, N, C = input_embeds.shape
         input_embeds = input_embeds.reshape(B * N, C)
 
@@ -648,7 +696,7 @@
 
         # print("generate", encode_outputs.hidden_states[-1][0].shape)
         encode_feature = encode_outputs.hidden_states[-1][0]
-        return outputs, encode_feature, encode_outputs.attentions
+        return outputs, encode_feature, encode_outputs.attentions, (alpha_rgb, alpha_sar)
 
     def preparing_for_generation(self, tokenizer, max_new_tokens=2048, torch_dtype=torch.bfloat16):
         # set stop criteria and generation configs for model
@@ -784,7 +832,7 @@
                 self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
             ]).to(self.torch_dtype)
 
-            images = dynamic_preprocess(image, self.min_dynamic_patch,
+            images, weight = dynamic_preprocess(image, self.min_dynamic_patch,
                                         self.max_dynamic_patch,
                                         self.image_size, self.use_thumbnail)
 
@@ -899,11 +947,6 @@
 
 
 
-
-
-
-
-
     def predict_forward_multi(
             self,
             image=None,
@@ -966,7 +1009,7 @@
             input_dict['vp_overall_mask'] = None
         else:
             ori_image_size = image.size
-
+
             # prepare grounding images
             g_image = np.array(image)  # for grounding
             g_image = self.extra_image_processor.apply_image(g_image)
@@ -976,14 +1019,11 @@
                 self.grounding_encoder.preprocess_image(pixel) for pixel in extra_pixel_values
            ]).to(self.torch_dtype)
 
-            images = dynamic_preprocess(image, self.min_dynamic_patch,
+            images, sta = dynamic_preprocess(image, self.min_dynamic_patch,
                                         self.max_dynamic_patch,
                                         self.image_size, self.use_thumbnail)
 
-
-            rgb_images = dynamic_preprocess(rgb_image, self.min_dynamic_patch,
-                                            self.max_dynamic_patch,
-                                            self.image_size, self.use_thumbnail)
+
 
 
 
@@ -996,11 +1036,34 @@
             pixel_values = [self.transformer(image) for image in images]
             pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
 
-
-            rgb_pixel_values = [self.transformer(image) for image in rgb_images]
-            rgb_pixel_values = torch.stack(rgb_pixel_values).to(self.torch_dtype)
+            if type(rgb_image) is list:
+                pixel_values_list = []
+                for img_rgb_like in rgb_image:
+                    sub_images, _ = dynamic_preprocess(
+                        img_rgb_like,
+                        self.min_dynamic_patch,
+                        self.max_dynamic_patch,
+                        self.image_size,
+                        self.use_thumbnail
+                    )
+                    pixel_values_list.extend([self.transformer(si) for si in sub_images])
+
+                rgb_pixel_values = torch.stack(pixel_values_list).to(self.torch_dtype)  # shape: [M_total, 3, 44
+
+            else:
+                rgb_images, sta = dynamic_preprocess(rgb_image, self.min_dynamic_patch,
+                                                     self.max_dynamic_patch,
+                                                     self.image_size, self.use_thumbnail)
+
+
+
+
+                rgb_pixel_values = [self.transformer(image) for image in rgb_images]
+                rgb_pixel_values = torch.stack(rgb_pixel_values).to(self.torch_dtype)
+
+            print("input", rgb_pixel_values.shape, pixel_values.shape)
 
-            num_image_tokens = pixel_values.shape[0] * self.patch_token * 2
+            num_image_tokens = pixel_values.shape[0] * self.patch_token
             num_frames = 1
             input_dict['g_pixel_values'] = g_pixel_values
             input_dict['pixel_values'] = pixel_values
@@ -1067,7 +1130,7 @@
             'vp_overall_mask': input_dict['vp_overall_mask'],
         }
 
-        generate_output, encode_feature, encode_attention = self.generate_multi(
+        generate_output, encode_feature, encode_attention, vis_weight = self.generate_multi(
             **mm_inputs,
             generation_config=self.gen_config,
             streamer=None,
@@ -1109,7 +1172,7 @@
             masks = masks.cpu().numpy()
             ret_masks.append(masks)
 
-        return {'prediction': predict, 'prediction_masks': ret_masks,}
+        return {'prediction': predict, 'prediction_masks': ret_masks, 'vis_weight': vis_weight, "sta": sta}
 
 def get_seg_hidden_states(hidden_states, output_ids, seg_id):
     seg_mask = output_ids == seg_id
@@ -1159,6 +1222,11 @@ def dynamic_preprocess(image,
     target_height = image_size * target_aspect_ratio[1]
     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
 
+    cols = target_aspect_ratio[0]
+    rows = target_aspect_ratio[1]
+
+
+
     # resize the image
     resized_img = image.resize((target_width, target_height))
     processed_images = []
@@ -1174,7 +1242,7 @@
     if use_thumbnail and len(processed_images) != 1:
         thumbnail_img = image.resize((image_size, image_size))
         processed_images.append(thumbnail_img)
-    return processed_images
+    return processed_images, (cols, rows)
 
 
 from transformers.cache_utils import Cache, DynamicCache
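
For reference, the RGB/SAR fusion this commit adds can be read as one self-contained function. This is an illustrative sketch, not part of the committed file: the helper name fuse_rgb_sar is hypothetical, while the tensor shapes, the 0.2 temperatures, and the mean-pooled text embedding follow the diff above.

from math import sqrt
import torch
import torch.nn.functional as F

def fuse_rgb_sar(X_rgb, X_sar, text_embeds, tau=0.2, tau_txt=0.2):
    """X_rgb, X_sar: (B, N, D) visual tokens; text_embeds: (B, T, D) text token embeddings."""
    D = X_rgb.size(-1)

    # cross-modal attention in both directions
    A_rs = F.softmax(torch.matmul(X_rgb, X_sar.transpose(-2, -1)) / (sqrt(D) * tau), dim=-1)  # RGB -> SAR
    A_sr = F.softmax(torch.matmul(X_sar, X_rgb.transpose(-2, -1)) / (sqrt(D) * tau), dim=-1)  # SAR -> RGB

    # diagonal entries: how strongly each token attends to its counterpart in the other modality
    r_sar = torch.diagonal(A_rs, dim1=-2, dim2=-1)  # (B, N)
    r_rgb = torch.diagonal(A_sr, dim1=-2, dim2=-1)  # (B, N)

    # text-guided scores: similarity of each visual token to the mean text embedding
    t = text_embeds.mean(dim=1)                                            # (B, D)
    beta_rgb = torch.matmul(X_rgb, t.unsqueeze(-1)).squeeze(-1) / tau_txt  # (B, N)
    beta_sar = torch.matmul(X_sar, t.unsqueeze(-1)).squeeze(-1) / tau_txt  # (B, N)

    # average visual and textual evidence, then softmax over the two modalities per position
    logits = 0.5 * torch.stack([r_rgb, r_sar], dim=-1) + 0.5 * torch.stack([beta_rgb, beta_sar], dim=-1)
    alpha = F.softmax(logits, dim=-1)  # (B, N, 2)

    # per-position weighted fusion of the two modalities
    return alpha[..., 0:1] * X_rgb + alpha[..., 1:2] * X_sar  # (B, N, D)

# quick shape check
Z = fuse_rgb_sar(torch.randn(1, 256, 2048), torch.randn(1, 256, 2048), torch.randn(1, 32, 2048))

In the committed code the result overwrites vit_embeds before the image-token replacement, and the per-position weights (alpha_rgb, alpha_sar) are additionally returned through generate_multi and surfaced as 'vis_weight' in predict_forward_multi.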