Update modeling_minicpmv.py
#39
by qianyuchen - opened
modeling_minicpmv.py +96 -7
CHANGED
@@ -42,13 +42,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         return model
 
-    def init_resampler(self, embed_dim, vision_dim):
+    def init_resampler(self, embed_dim, vision_dim,):
         return Resampler(
             num_queries=self.config.query_num,
             embed_dim=embed_dim,
             num_heads=embed_dim // 128,
             kv_dim=vision_dim,
-            adaptive=True
+            adaptive=True,
         )
 
     def init_transform(self):
@@ -60,13 +60,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 ),
             ]
         )
-
+
     def get_input_embeddings(self):
         return self.llm.get_input_embeddings()
 
     def set_input_embeddings(self, value):
         self.llm.embed_tokens = value
-
+
     def get_vllm_embedding(self, data):
         if 'vision_hidden_states' not in data:
             dtype = self.vpm.embeddings.position_embedding.weight.dtype
@@ -152,16 +152,105 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                     image_indices = torch.stack(
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)
-
                     cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                                           cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
                 elif self.training:
                     cur_vllm_emb += cur_vs_hs[0].mean() * 0
 
         return vllm_embedding, vision_hidden_states
-
+
     def forward(self, data, **kwargs):
-
+
+        if 'vision_hidden_states' not in data:
+            dtype = self.llm.lm_head.weight.dtype
+            device = self.llm.lm_head.weight.device
+            tgt_sizes = data['tgt_sizes']
+            pixel_values_list = data['pixel_values']
+            vision_hidden_states = []
+            all_pixel_values = []
+            img_cnt = []
+            for pixel_values in pixel_values_list:
+                img_cnt.append(len(pixel_values))
+                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
+
+            # exist image
+            if all_pixel_values:
+                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+
+                if self.config.batch_vision_input:
+                    max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+
+                    all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
+                                                                       padding_value=0.0)
+                    B, L, _ = all_pixel_values.shape
+                    all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+                    patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
+                    for i in range(B):
+                        patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+                    vision_embedding = self.vpm(all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state
+                    vision_embedding = self.resampler(vision_embedding, tgt_sizes)
+                else:
+                    # get vision_embedding foreach
+                    vision_embedding = []
+                    for single_tgt_size, single_pixel_values in zip(tgt_sizes, all_pixel_values):
+                        single_pixel_values = single_pixel_values.unsqueeze(0)
+                        B, L, _ = single_pixel_values.shape
+                        single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+                        single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
+                        single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
+                        vision_embedding.append(single_vision_embedding)
+                    vision_embedding = torch.vstack(vision_embedding)
+
+                start = 0
+                for pixel_values in pixel_values_list:
+                    img_cnt = len(pixel_values)
+                    if img_cnt > 0:
+                        vision_hidden_states.append(vision_embedding[start: start + img_cnt])
+                        start += img_cnt
+                    else:
+                        vision_hidden_states.append([])
+            else:  # no image
+                if self.training:
+                    dummy_image = torch.zeros(
+                        (1, 3, 224, 224),
+                        device=device, dtype=dtype
+                    )
+                    tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32)
+                    dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
+                else:
+                    dummy_feature = []
+                for _ in range(len(pixel_values_list)):
+                    vision_hidden_states.append(dummy_feature)
+
+        else:
+            vision_hidden_states = data['vision_hidden_states']
+
+        if hasattr(self.llm.config, 'scale_emb'):
+            vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+        else:
+            vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
+
+        vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
+            i, torch.Tensor) else i for i in vision_hidden_states]
+
+        bs = len(data['input_ids'])
+        for i in range(bs):
+            cur_vs_hs = vision_hidden_states[i]
+            if len(cur_vs_hs) > 0:
+                cur_vllm_emb = vllm_embedding[i]
+                cur_image_bound = data['image_bound'][i]
+                if len(cur_image_bound) > 0:
+                    image_indices = torch.stack(
+                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
+                    ).to(vllm_embedding.device)
+                    cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
+                                          cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
+                elif self.training:
+                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+
+        # vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
         position_ids = data["position_ids"]
         if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()
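
The rewritten forward inlines the vision-embedding computation that get_vllm_embedding performs (the original call is left commented out at the end of the inlined block), so it reads the same keys from data: input_ids, position_ids, pixel_values (a list of per-sample image-tensor lists), tgt_sizes (the patch-grid size of each image slice), and image_bound (the token ranges to overwrite with the resampled image features). The snippet below is only a schematic sketch of that input contract; the tensor shapes, the patch size of 14, the 96 image tokens, and the model variable are illustrative assumptions, not part of this change.

# Schematic only: shapes, patch size (14), and the 96-token image span are
# assumptions for illustration; they are not defined by this patch.
import torch

h, w = 32, 32                     # hypothetical patch grid for one image slice
patch_size = 14                   # assumed vision patch size
seq_len, query_num = 128, 96      # hypothetical text length / image-token count

data = {
    "input_ids": torch.zeros(1, seq_len, dtype=torch.long),
    "position_ids": torch.arange(seq_len).unsqueeze(0),
    # one sample holding one image tensor shaped (3, patch_size, h * w * patch_size)
    "pixel_values": [[torch.randn(3, patch_size, h * w * patch_size)]],
    "tgt_sizes": [torch.tensor([[h, w]], dtype=torch.int32)],
    # token positions [10, 10 + query_num) receive the resampled vision embeddings
    "image_bound": [torch.tensor([[10, 10 + query_num]])],
}
# logits = model(data)            # `model` would be a loaded MiniCPMV instance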