Upload processor
- modeling_yangjian.py +102 -44
- tokenizer_config.json +0 -4
modeling_yangjian.py
CHANGED
@@ -279,7 +279,7 @@ class YangJianCompareVisualEncoder(nn.Module):
        # Decoder part: queries interact with the encoded features
        # Learnable query embeddings
        self.query_embeddings = nn.Parameter(
-           torch.
+           torch.empty(self.token_size, self.hidden_size)
        )

        # Keep only the cross attention for queries to attend to the encoded features
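Note: the new code allocates query_embeddings with torch.empty, which leaves the memory uninitialized, so the separate _initialize_weights step added below has to fill it before the first forward pass (and since an nn.Parameter has no .weight attribute, the init applies to the parameter itself). A minimal, self-contained sketch of the pattern; token_size and hidden_size here are arbitrary stand-ins:

import torch
import torch.nn as nn

class QueryPool(nn.Module):
    # Sketch of the learnable-query pattern; sizes are arbitrary.
    def __init__(self, token_size: int = 64, hidden_size: int = 1280):
        super().__init__()
        # torch.empty allocates without initializing: values are garbage
        # (possibly NaN/inf) until an explicit init runs.
        self.query_embeddings = nn.Parameter(torch.empty(token_size, hidden_size))
        self._initialize_weights()

    def _initialize_weights(self):
        # Init the nn.Parameter directly; it has no .weight attribute.
        nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)

pool = QueryPool()
assert not torch.isnan(pool.query_embeddings).any()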
@@ -314,47 +314,94 @@ class YangJianCompareVisualEncoder(nn.Module):
        self.encoder_mlp2 = self.encoder_mlp2.to(device=device, dtype=dtype)
        self.decoder_mlp = self.decoder_mlp.to(device=device, dtype=dtype)

-   def forward(self, images_hidden_states):
+   def _initialize_weights(self):
+       nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)
+
+   def forward(self, images_hidden_states: list) -> torch.Tensor:
        """
        Args:
            images_hidden_states: List of tensors, each with shape [seq_len, hidden_size]

        Returns:
+           Tensor of shape [total_images, token_size, hidden_size]
        """
        if not images_hidden_states:
-           return
+           return torch.empty(0, self.token_size, self.hidden_size)

        # Ensure all components share the same device and dtype
-       self._ensure_device_dtype_consistency(images_hidden_states[0])
-       compare_visual_embeds.append(compare_visual_embed)
-
-       return compare_visual_embeds
+       # self._ensure_device_dtype_consistency(images_hidden_states[0])
+
+       # Reinitialize query_embeddings if it contains NaN values
+       if torch.isnan(self.query_embeddings).any():
+           print("Warning: query_embeddings contains NaN values, reinitializing")
+           nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)
+
+       # Sequence length of each image
+       seq_lengths = [state.size(0) for state in images_hidden_states]
+       max_seq_len = max(seq_lengths)
+       batch_size = len(images_hidden_states)
+       device = images_hidden_states[0].device
+       dtype = images_hidden_states[0].dtype
+
+       # Pad all images to the same length and stack them
+       padded_states = []
+       attention_masks = []
+       for state in images_hidden_states:
+           pad_len = max_seq_len - state.size(0)
+           if pad_len > 0:
+               # Pad the sequence
+               padded_state = F.pad(state, (0, 0, 0, pad_len), mode='constant', value=0)
+               # Build the attention mask
+               attention_mask = torch.ones(max_seq_len, dtype=torch.bool, device=device)
+               attention_mask[state.size(0):] = False
+           else:
+               padded_state = state
+               attention_mask = torch.ones(max_seq_len, dtype=torch.bool, device=device)
+           padded_states.append(padded_state)
+           attention_masks.append(attention_mask)
+
+       # [batch_size, max_seq_len, hidden_size]
+       batched_states = torch.stack(padded_states)
+       # [batch_size, max_seq_len]
+       attention_masks = torch.stack(attention_masks)
+
+       # Circularly shift the batch to build the comparison pairs;
+       # the first image uses itself as "previous"
+       previous_states = torch.roll(batched_states, shifts=1, dims=0)
+       previous_states[0] = batched_states[0]
+       previous_masks = torch.roll(attention_masks, shifts=1, dims=0)
+       previous_masks[0] = attention_masks[0]
+
+       # Encoder: process all images as one batch
+       encoded_features = self._encoder_forward(
+           batched_states,   # [batch_size, max_seq_len, hidden_size]
+           previous_states,  # [batch_size, max_seq_len, hidden_size]
+           attention_masks,  # [batch_size, max_seq_len]
+           previous_masks    # [batch_size, max_seq_len]
+       )
+
+       # Decoder: process all images as one batch
+       # Expand query_embeddings to the batch dimension
+       batch_queries = self.query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
+       # [batch_size, token_size, hidden_size]
+       compare_visual_embeds = self._decoder_forward(
+           batch_queries,
+           encoded_features,
+           torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device),  # query mask
+           attention_masks  # mask over the encoded features
+       )
+
+       return compare_visual_embeds  # [batch_size, token_size, hidden_size]

-   def _encoder_forward(self, current_features, previous_features):
+   def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
        """
        Encoder: bidirectional interaction between image features
+       Args:
+           current_features: [batch_size, seq_len, hidden_size]
+           previous_features: [batch_size, seq_len, hidden_size]
+           current_mask: [batch_size, seq_len]
+           previous_mask: [batch_size, seq_len]
        """
-       # Ensure dtype and device are consistent
-       device = current_features.device
-       dtype = current_features.dtype
-       previous_features = previous_features.to(device=device, dtype=dtype)
-
        # Step 1: previous attends to current
        residual = previous_features
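The rewritten forward replaces the old per-image loop with one batched pass: variable-length per-image features are padded to a common length, boolean masks record which positions are real, and torch.roll pairs each image with its predecessor (the first image is paired with itself). A toy sketch of that preprocessing, with arbitrary shapes:

import torch
import torch.nn.functional as F

# Toy per-image features with different sequence lengths.
images = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(4, 8)]

max_len = max(x.size(0) for x in images)
padded, masks = [], []
for x in images:
    pad = max_len - x.size(0)
    # For 2-D input, F.pad's pair order is (left, right, top, bottom),
    # so (0, 0, 0, pad) pads rows (the sequence dimension) at the end.
    padded.append(F.pad(x, (0, 0, 0, pad)))
    m = torch.ones(max_len, dtype=torch.bool)
    m[x.size(0):] = False
    masks.append(m)

batched = torch.stack(padded)   # [3, max_len, 8]
mask = torch.stack(masks)       # [3, max_len]

# roll(shifts=1) pairs image i with image i-1; the first image pairs with itself.
previous = torch.roll(batched, shifts=1, dims=0)
previous[0] = batched[0]
print(batched.shape, previous.shape, mask.shape)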
@@ -365,7 +412,8 @@ class YangJianCompareVisualEncoder(nn.Module):
        # Cross attention: previous attend to current
        cross_attn_output1 = self.encoder_cross_attn1(
            query_states=previous_normed,
-           key_value_states=current_normed1
+           key_value_states=current_normed1,
+           attention_mask=current_mask.unsqueeze(1).unsqueeze(2) if current_mask is not None else None
        )

        # Residual connection
@@ -382,12 +430,13 @@ class YangJianCompareVisualEncoder(nn.Module):

        # Layer norm
        current_normed2 = self.encoder_norm3(current_features)
-       previous_normed2 = self.encoder_norm3(previous_features)
+       previous_normed2 = self.encoder_norm3(previous_features)

        # Cross attention: current attend to previous
        cross_attn_output2 = self.encoder_cross_attn2(
            query_states=current_normed2,
-           key_value_states=previous_normed2
+           key_value_states=previous_normed2,
+           attention_mask=previous_mask.unsqueeze(1).unsqueeze(2) if previous_mask is not None else None
        )

        # Residual connection
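All three new cross-attention call sites (both encoder directions here, and the decoder below) pass the key/value mask as mask.unsqueeze(1).unsqueeze(2), turning [batch_size, seq_len] into [batch_size, 1, 1, seq_len] so it broadcasts over heads and query positions. How the custom attention modules consume the mask is not shown in this diff; the sketch below assumes the convention PyTorch's F.scaled_dot_product_attention uses, where True in a boolean mask means "may attend":

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, dim = 2, 4, 6, 10, 16
q = torch.randn(batch, heads, q_len, dim)
k = torch.randn(batch, heads, kv_len, dim)
v = torch.randn(batch, heads, kv_len, dim)

kv_mask = torch.ones(batch, kv_len, dtype=torch.bool)
kv_mask[0, 7:] = False  # last three key/value positions of sample 0 are padding

# [batch, kv_len] -> [batch, 1, 1, kv_len]: broadcasts over heads and query rows.
attn_mask = kv_mask.unsqueeze(1).unsqueeze(2)

# With a boolean mask, True marks positions that may be attended to.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 6, 16])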
@@ -401,17 +450,15 @@ class YangJianCompareVisualEncoder(nn.Module):

        return current_features

-   def _decoder_forward(self, encoded_features):
+   def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
        """
-       Decoder: queries interact with the encoded features
+       Decoder: queries interact with the encoded features
+       Args:
+           queries: [batch_size, token_size, hidden_size]
+           encoded_features: [batch_size, seq_len, hidden_size]
+           query_mask: [batch_size, token_size]
+           encoded_mask: [batch_size, seq_len]
        """
-       # Get device and dtype
-       device = encoded_features.device
-       dtype = encoded_features.dtype
-
-       # Initialize the queries, keeping device and dtype consistent
-       queries = self.query_embeddings.to(device=device, dtype=dtype)
-
        # Cross attention: queries attend to encoded features
        residual = queries
        queries_normed = self.decoder_norm1(queries)
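_decoder_forward no longer materializes its own queries from self.query_embeddings; forward hands in batch_queries built with unsqueeze(0).expand(batch_size, -1, -1). expand creates a broadcast view rather than a copy, which is safe here because the queries are only read before the first residual addition produces a fresh tensor. A quick demonstration:

import torch
import torch.nn as nn

token_size, hidden_size, batch_size = 64, 1280, 3
query_embeddings = nn.Parameter(torch.randn(token_size, hidden_size))

# expand() adds no memory: every batch entry views the same underlying storage.
batch_queries = query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
print(batch_queries.shape)                                       # torch.Size([3, 64, 1280])
print(batch_queries.data_ptr() == query_embeddings.data_ptr())   # True: a view, not a copy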
@@ -419,7 +466,8 @@ class YangJianCompareVisualEncoder(nn.Module):

        cross_attn_output = self.decoder_cross_attn(
            query_states=queries_normed,
-           key_value_states=encoded_normed
+           key_value_states=encoded_normed,
+           attention_mask=encoded_mask.unsqueeze(1).unsqueeze(2) if encoded_mask is not None else None
        )

        queries = residual + cross_attn_output
@@ -430,7 +478,7 @@ class YangJianCompareVisualEncoder(nn.Module):
        mlp_output = self.decoder_mlp(mlp_input)
        queries = residual + mlp_output

-       return queries  # [token_size, hidden_size]
+       return queries  # [batch_size, token_size, hidden_size]


# Subclass the components first to make them easy to modify
@@ -497,10 +545,20 @@ class YangJianVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):

        split_sizes = grid_thw.prod(-1).tolist()
        splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
+       # [total_images, token_size, hidden_size]
        compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)
-       #
+       # Record the token count per batch entry
+       batch_size = compare_visual_embeds.size(0)
+       token_size = compare_visual_embeds.size(1)
+       # Concatenate the data from all batch entries
+       # [batch_size * token_size, hidden_size]
+       flattened_embeds = compare_visual_embeds.view(-1, compare_visual_embeds.size(-1))
+       # Run the merger in a single pass
+       # (assumes the merger reduces the token count to 1/4 of the original)
+       merged = self.merger(flattened_embeds)  # [(batch_size * token_size)/4, merged_hidden_size]
+       merged_token_size = token_size // self.spatial_merge_size**2
+       # [batch_size, merged_token_size, merged_hidden_size]
+       compare_visual_embeds = merged.view(batch_size, merged_token_size, -1)

        hidden_states = self.merger(hidden_states)
        reverse_indices = torch.argsort(window_index)
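The new merger path flattens [batch_size, token_size, hidden_size] to [batch_size * token_size, hidden_size], runs self.merger once, and reshapes back with merged_token_size = token_size // spatial_merge_size**2. The shape arithmetic, using a stand-in linear merger and the diff's own assumption that every spatial_merge_size**2 = 4 tokens merge into one (token_size must be divisible by that factor):

import torch
import torch.nn as nn

batch_size, token_size, hidden = 3, 64, 1280
spatial_merge_size = 2
group = spatial_merge_size ** 2  # 4 tokens merge into 1

# Stand-in for self.merger: projects each group of 4 consecutive tokens to one vector.
merger = nn.Linear(hidden * group, hidden * 2)

embeds = torch.randn(batch_size, token_size, hidden)
flattened = embeds.view(-1, hidden)                   # [192, 1280]
merged = merger(flattened.view(-1, hidden * group))   # [48, 2560]

merged_token_size = token_size // group               # 16
out = merged.view(batch_size, merged_token_size, -1)  # [3, 16, 2560]
print(out.shape)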
tokenizer_config.json
CHANGED
@@ -202,12 +202,8 @@
     "eos_token": "<|im_end|>",
     "errors": "replace",
     "extra_special_tokens": {},
-    "max_length": null,
     "model_max_length": 131072,
-    "pad_to_multiple_of": null,
     "pad_token": "<|endoftext|>",
-    "pad_token_type_id": 0,
-    "padding_side": "right",
     "processor_class": "YangJianProcessor",
     "split_special_tokens": false,
     "tokenizer_class": "Qwen2Tokenizer",
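The tokenizer_config.json change only drops four padding-related keys; nothing is added. A tokenizer loaded from the updated config falls back to the class defaults for the removed fields (padding_side defaults to "right" on most Hugging Face tokenizers, so behavior should be unchanged). A quick check, with a placeholder checkpoint path:

from transformers import AutoTokenizer

# "path/to/checkpoint" is a placeholder for wherever this config lives.
tok = AutoTokenizer.from_pretrained("path/to/checkpoint")
print(tok.padding_side)      # falls back to the class default when the key is absent
print(tok.model_max_length)  # 131072, still set by the remaining config key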