jiang-cc committed on
Commit
989ad6f
·
verified ·
1 Parent(s): 5fc4dd0

Upload processor

Browse files
Files changed (2) hide show
  1. modeling_yangjian.py +102 -44
  2. tokenizer_config.json +0 -4
modeling_yangjian.py CHANGED
@@ -279,7 +279,7 @@ class YangJianCompareVisualEncoder(nn.Module):
279
  # Decoder 部分:Query 与编码特征交互
280
  # 可学习的 Query Embeddings
281
  self.query_embeddings = nn.Parameter(
282
- torch.randn(self.token_size, self.hidden_size) * 0.02
283
  )
284
 
285
  # 只保留 Cross Attention for queries to attend to encoded features
@@ -314,47 +314,94 @@ class YangJianCompareVisualEncoder(nn.Module):
314
  self.encoder_mlp2 = self.encoder_mlp2.to(device=device, dtype=dtype)
315
  self.decoder_mlp = self.decoder_mlp.to(device=device, dtype=dtype)
316
 
317
- def forward(self, images_hidden_states: list) -> list:
 
 
 
318
  """
319
  Args:
320
  images_hidden_states: List of tensor, each tensor has shape [seq_len, hidden_size]
321
 
322
  Returns:
323
- List of compare visual embeddings, each has shape [token_size, hidden_size]
324
  """
325
  if not images_hidden_states:
326
- return []
327
 
328
  # 确保所有组件的设备和数据类型一致
329
- self._ensure_device_dtype_consistency(images_hidden_states[0])
330
-
331
- compare_visual_embeds = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
- for i in range(len(images_hidden_states)):
334
- current_hidden_state = images_hidden_states[i] # [seq_len_current, hidden_size]
335
- previous_hidden_state = images_hidden_states[i-1] if i > 0 else current_hidden_state # [seq_len_prev, hidden_size]
336
-
337
- # Encoder 部分:双向图像特征交互
338
- encoded_features = self._encoder_forward(current_hidden_state, previous_hidden_state)
339
-
340
- # Decoder 部分:Query 与编码特征交互
341
- compare_visual_embed = self._decoder_forward(encoded_features)
342
-
343
- compare_visual_embeds.append(compare_visual_embed)
344
 
345
- return compare_visual_embeds
346
 
347
- def _encoder_forward(self, current_features, previous_features):
348
  """
349
  Encoder: 双向图像特征交互
350
- 1. previous attend to current
351
- 2. current attend to previous
 
 
 
352
  """
353
- # 确保数据类型和设备一致
354
- device = current_features.device
355
- dtype = current_features.dtype
356
- previous_features = previous_features.to(device=device, dtype=dtype)
357
-
358
  # 第一步:previous attend to current
359
  residual = previous_features
360
 
@@ -365,7 +412,8 @@ class YangJianCompareVisualEncoder(nn.Module):
365
  # Cross attention: previous attend to current
366
  cross_attn_output1 = self.encoder_cross_attn1(
367
  query_states=previous_normed,
368
- key_value_states=current_normed1
 
369
  )
370
 
371
  # Residual connection
@@ -382,12 +430,13 @@ class YangJianCompareVisualEncoder(nn.Module):
382
 
383
  # Layer norm
384
  current_normed2 = self.encoder_norm3(current_features)
385
- previous_normed2 = self.encoder_norm3(previous_features) # 使用增强后的 previous features
386
 
387
  # Cross attention: current attend to previous
388
  cross_attn_output2 = self.encoder_cross_attn2(
389
  query_states=current_normed2,
390
- key_value_states=previous_normed2
 
391
  )
392
 
393
  # Residual connection
@@ -401,17 +450,15 @@ class YangJianCompareVisualEncoder(nn.Module):
401
 
402
  return current_features
403
 
404
- def _decoder_forward(self, encoded_features):
405
  """
406
- Decoder: Query 与编码特征交互(仅使用 cross attention)
 
 
 
 
 
407
  """
408
- # 获取设备和数据类型
409
- device = encoded_features.device
410
- dtype = encoded_features.dtype
411
-
412
- # 初始化 queries 并确保设备和数据类型一致
413
- queries = self.query_embeddings.to(device=device, dtype=dtype)
414
-
415
  # Cross attention: queries attend to encoded features
416
  residual = queries
417
  queries_normed = self.decoder_norm1(queries)
@@ -419,7 +466,8 @@ class YangJianCompareVisualEncoder(nn.Module):
419
 
420
  cross_attn_output = self.decoder_cross_attn(
421
  query_states=queries_normed,
422
- key_value_states=encoded_normed
 
423
  )
424
 
425
  queries = residual + cross_attn_output
@@ -430,7 +478,7 @@ class YangJianCompareVisualEncoder(nn.Module):
430
  mlp_output = self.decoder_mlp(mlp_input)
431
  queries = residual + mlp_output
432
 
433
- return queries # [token_size, hidden_size]
434
 
435
 
436
  # 先把组件继承出来方便修改
@@ -497,10 +545,20 @@ class YangJianVisionTransformerPretrainedModel(Qwen2_5_VisionTransformerPretrain
497
 
498
  split_sizes = grid_thw.prod(-1).tolist()
499
  splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
 
500
  compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)
501
- # compare_visual_embeds = self.merger(compare_visual_embeds)
502
- for i, embeds in enumerate(compare_visual_embeds):
503
- compare_visual_embeds[i] = self.merger(embeds)
 
 
 
 
 
 
 
 
 
504
 
505
  hidden_states = self.merger(hidden_states)
506
  reverse_indices = torch.argsort(window_index)
 
279
  # Decoder 部分:Query 与编码特征交互
280
  # 可学习的 Query Embeddings
281
  self.query_embeddings = nn.Parameter(
282
+ torch.empty(self.token_size, self.hidden_size)
283
  )
284
 
285
  # 只保留 Cross Attention for queries to attend to encoded features
 
314
  self.encoder_mlp2 = self.encoder_mlp2.to(device=device, dtype=dtype)
315
  self.decoder_mlp = self.decoder_mlp.to(device=device, dtype=dtype)
316
 
317
+ def _initialize_weights(self):
318
+ nn.init.normal_(self.query_embeddings.weight, mean=0.0, std=0.02)
319
+
320
+ def forward(self, images_hidden_states: list) -> torch.Tensor:
321
  """
322
  Args:
323
  images_hidden_states: List of tensor, each tensor has shape [seq_len, hidden_size]
324
 
325
  Returns:
326
+ Tensor of shape [total_images, token_size, hidden_size]
327
  """
328
  if not images_hidden_states:
329
+ return torch.empty(0, self.token_size, self.hidden_size)
330
 
331
  # 确保所有组件的设备和数据类型一致
332
+ # self._ensure_device_dtype_consistency(images_hidden_states[0])
333
+
334
+ # 检查 query_embeddings 是否包含 NaN
335
+ if torch.isnan(self.query_embeddings).any():
336
+ print("警告:query_embeddings 包含 NaN 值,重新初始化")
337
+ nn.init.normal_(self.query_embeddings, mean=0.0, std=0.02)
338
+
339
+ # 获取每个图像的序列长度
340
+ seq_lengths = [state.size(0) for state in images_hidden_states]
341
+ max_seq_len = max(seq_lengths)
342
+ batch_size = len(images_hidden_states)
343
+ device = images_hidden_states[0].device
344
+ dtype = images_hidden_states[0].dtype
345
+
346
+ # 将所有图像填充到相同长度并堆叠
347
+ padded_states = []
348
+ attention_masks = []
349
+ for state in images_hidden_states:
350
+ pad_len = max_seq_len - state.size(0)
351
+ if pad_len > 0:
352
+ # 填充序列
353
+ padded_state = F.pad(state, (0, 0, 0, pad_len), mode='constant', value=0)
354
+ # 创建注意力掩码
355
+ attention_mask = torch.ones(max_seq_len, dtype=torch.bool, device=device)
356
+ attention_mask[state.size(0):] = False
357
+ else:
358
+ padded_state = state
359
+ attention_mask = torch.ones(max_seq_len, dtype=torch.bool, device=device)
360
+ padded_states.append(padded_state)
361
+ attention_masks.append(attention_mask)
362
+
363
+ # [batch_size, max_seq_len, hidden_size]
364
+ batched_states = torch.stack(padded_states)
365
+ # [batch_size, max_seq_len]
366
+ attention_masks = torch.stack(attention_masks)
367
+
368
+ # 创建循环移位的状态用于对比
369
+ # 对于第一个图像,使用自身作为previous
370
+ previous_states = torch.roll(batched_states, shifts=1, dims=0)
371
+ previous_states[0] = batched_states[0]
372
+ previous_masks = torch.roll(attention_masks, shifts=1, dims=0)
373
+ previous_masks[0] = attention_masks[0]
374
+
375
+ # Encoder: 批量处理所有图像
376
+ encoded_features = self._encoder_forward(
377
+ batched_states, # [batch_size, max_seq_len, hidden_size]
378
+ previous_states, # [batch_size, max_seq_len, hidden_size]
379
+ attention_masks, # [batch_size, max_seq_len]
380
+ previous_masks # [batch_size, max_seq_len]
381
+ )
382
 
383
+ # Decoder: 批量处理所有图像
384
+ # 扩展query_embeddings到batch维度
385
+ batch_queries = self.query_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
386
+ # [batch_size, token_size, hidden_size]
387
+ compare_visual_embeds = self._decoder_forward(
388
+ batch_queries,
389
+ encoded_features,
390
+ torch.ones(batch_size, self.token_size, dtype=torch.bool, device=device), # query掩码
391
+ attention_masks # encoded特征的掩码
392
+ )
 
393
 
394
+ return compare_visual_embeds # [batch_size, token_size, hidden_size]
395
 
396
+ def _encoder_forward(self, current_features, previous_features, current_mask=None, previous_mask=None):
397
  """
398
  Encoder: 双向图像特征交互
399
+ Args:
400
+ current_features: [batch_size, seq_len, hidden_size]
401
+ previous_features: [batch_size, seq_len, hidden_size]
402
+ current_mask: [batch_size, seq_len]
403
+ previous_mask: [batch_size, seq_len]
404
  """
 
 
 
 
 
405
  # 第一步:previous attend to current
406
  residual = previous_features
407
 
 
412
  # Cross attention: previous attend to current
413
  cross_attn_output1 = self.encoder_cross_attn1(
414
  query_states=previous_normed,
415
+ key_value_states=current_normed1,
416
+ attention_mask=current_mask.unsqueeze(1).unsqueeze(2) if current_mask is not None else None
417
  )
418
 
419
  # Residual connection
 
430
 
431
  # Layer norm
432
  current_normed2 = self.encoder_norm3(current_features)
433
+ previous_normed2 = self.encoder_norm3(previous_features)
434
 
435
  # Cross attention: current attend to previous
436
  cross_attn_output2 = self.encoder_cross_attn2(
437
  query_states=current_normed2,
438
+ key_value_states=previous_normed2,
439
+ attention_mask=previous_mask.unsqueeze(1).unsqueeze(2) if previous_mask is not None else None
440
  )
441
 
442
  # Residual connection
 
450
 
451
  return current_features
452
 
453
+ def _decoder_forward(self, queries, encoded_features, query_mask=None, encoded_mask=None):
454
  """
455
+ Decoder: Query 与编码特征交互
456
+ Args:
457
+ queries: [batch_size, token_size, hidden_size]
458
+ encoded_features: [batch_size, seq_len, hidden_size]
459
+ query_mask: [batch_size, token_size]
460
+ encoded_mask: [batch_size, seq_len]
461
  """
 
 
 
 
 
 
 
462
  # Cross attention: queries attend to encoded features
463
  residual = queries
464
  queries_normed = self.decoder_norm1(queries)
 
466
 
467
  cross_attn_output = self.decoder_cross_attn(
468
  query_states=queries_normed,
469
+ key_value_states=encoded_normed,
470
+ attention_mask=encoded_mask.unsqueeze(1).unsqueeze(2) if encoded_mask is not None else None
471
  )
472
 
473
  queries = residual + cross_attn_output
 
478
  mlp_output = self.decoder_mlp(mlp_input)
479
  queries = residual + mlp_output
480
 
481
+ return queries # [batch_size, token_size, hidden_size]
482
 
483
 
484
  # 先把组件继承出来方便修改
 
545
 
546
  split_sizes = grid_thw.prod(-1).tolist()
547
  splited_hidden_states_before_merger = torch.split(hidden_states, split_sizes)
548
+ # [total_images, token_size, hidden_size]
549
  compare_visual_embeds = self.compare_visual_encoder(splited_hidden_states_before_merger)
550
+ # 记录每个batch的token数量
551
+ batch_size = compare_visual_embeds.size(0)
552
+ token_size = compare_visual_embeds.size(1)
553
+ # 将所有batch的数据拼接在一起
554
+ # [batch_size * token_size, hidden_size]
555
+ flattened_embeds = compare_visual_embeds.view(-1, compare_visual_embeds.size(-1))
556
+ # 一次性进行merger操作
557
+ # 假设merger会将token数量变为原来的1/4
558
+ merged = self.merger(flattened_embeds) # [(batch_size * token_size)/4, merged_hidden_size]
559
+ merged_token_size = token_size // self.spatial_merge_size**2
560
+ # [batch_size, merged_token_size, merged_hidden_size]
561
+ compare_visual_embeds = merged.view(batch_size, merged_token_size, -1)
562
 
563
  hidden_states = self.merger(hidden_states)
564
  reverse_indices = torch.argsort(window_index)
tokenizer_config.json CHANGED
@@ -202,12 +202,8 @@
202
  "eos_token": "<|im_end|>",
203
  "errors": "replace",
204
  "extra_special_tokens": {},
205
- "max_length": null,
206
  "model_max_length": 131072,
207
- "pad_to_multiple_of": null,
208
  "pad_token": "<|endoftext|>",
209
- "pad_token_type_id": 0,
210
- "padding_side": "right",
211
  "processor_class": "YangJianProcessor",
212
  "split_special_tokens": false,
213
  "tokenizer_class": "Qwen2Tokenizer",
 
202
  "eos_token": "<|im_end|>",
203
  "errors": "replace",
204
  "extra_special_tokens": {},
 
205
  "model_max_length": 131072,
 
206
  "pad_token": "<|endoftext|>",
 
 
207
  "processor_class": "YangJianProcessor",
208
  "split_special_tokens": false,
209
  "tokenizer_class": "Qwen2Tokenizer",