linyq committed (verified)
Commit da71758 · 1 Parent(s): 9eccfc5

Update mllm_encoder/mllm_encoder.py

Files changed (1):
  1. mllm_encoder/mllm_encoder.py +36 -36
mllm_encoder/mllm_encoder.py CHANGED
@@ -2328,9 +2328,9 @@ class MLLMEncoder(ModelMixin, ConfigMixin):
         self.video_queries = nn.Parameter(
             torch.randn(num_video_queries, hidden_size) * 0.02
         )
-        self.ref_queries = nn.Parameter(
-            torch.randn(num_ref_queries, hidden_size) * 0.02
-        )
+        # self.ref_queries = nn.Parameter(
+        #     torch.randn(num_ref_queries, hidden_size) * 0.02
+        # )

         # Connector MLP: MLLM hidden → DiT dim
         self.connector = nn.Sequential(
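For context on the surviving path: video_queries is a standard learnable-query bank (a shared set of query embeddings, Gaussian-initialized with std 0.02). A minimal sketch of that pattern follows; the batch-expansion usage and the sizes are assumptions, since this hunk does not show how the queries are consumed downstream.

import torch
import torch.nn as nn

class QueryBank(nn.Module):
    def __init__(self, num_video_queries: int, hidden_size: int):
        super().__init__()
        # Same init as the diff: small-std Gaussian, like transformer embeddings
        self.video_queries = nn.Parameter(
            torch.randn(num_video_queries, hidden_size) * 0.02
        )

    def forward(self, batch_size: int) -> torch.Tensor:
        # Assumed usage: broadcast the shared bank across the batch so every
        # sample attends with the same learned queries
        return self.video_queries.unsqueeze(0).expand(batch_size, -1, -1)

bank = QueryBank(num_video_queries=64, hidden_size=1024)  # placeholder sizes
queries = bank(batch_size=2)  # shape: (2, 64, 1024)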
@@ -2342,13 +2342,13 @@ class MLLMEncoder(ModelMixin, ConfigMixin):
         nn.init.zeros_(self.connector[2].bias)

         # Ref connector MLP (separate from main connector)
-        self.ref_connector = nn.Sequential(
-            nn.Linear(hidden_size, dit_dim),
-            nn.GELU(approximate="tanh"),
-            nn.Linear(dit_dim, dit_dim),
-        )
-        nn.init.zeros_(self.ref_connector[2].weight)
-        nn.init.zeros_(self.ref_connector[2].bias)
+        # self.ref_connector = nn.Sequential(
+        #     nn.Linear(hidden_size, dit_dim),
+        #     nn.GELU(approximate="tanh"),
+        #     nn.Linear(dit_dim, dit_dim),
+        # )
+        # nn.init.zeros_(self.ref_connector[2].weight)
+        # nn.init.zeros_(self.ref_connector[2].bias)

         # Qwen VL model and processor (loaded lazily)
         self.qwen_model = None
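Both connectors zero-initialize the weights and bias of their final Linear, so the projected features start at exactly zero and the branch ramps up from a no-op during training; this is a common stabilization trick when grafting a new module onto a pretrained backbone. A self-contained sketch of the pattern, with placeholder dimensions:

import torch
import torch.nn as nn

def make_connector(hidden_size: int, dit_dim: int) -> nn.Sequential:
    connector = nn.Sequential(
        nn.Linear(hidden_size, dit_dim),
        nn.GELU(approximate="tanh"),
        nn.Linear(dit_dim, dit_dim),
    )
    # Zero-init the last layer so the connector contributes nothing at step 0
    nn.init.zeros_(connector[2].weight)
    nn.init.zeros_(connector[2].bias)
    return connector

connector = make_connector(hidden_size=1024, dit_dim=768)  # placeholder dims
x = torch.randn(2, 64, 1024)
assert connector(x).abs().max().item() == 0.0  # exact zeros before training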
@@ -2705,31 +2705,31 @@ class MLLMEncoder(ModelMixin, ConfigMixin):
         learnable_query_features = self.connector(learnable_query_features)

         # Extract ref image features if in ref mode
-        if ref_image:
-            vision_start_id = self.processor.tokenizer.convert_tokens_to_ids(
-                "<|vision_start|>"
-            )
-            vision_end_id = self.processor.tokenizer.convert_tokens_to_ids(
-                "<|vision_end|>"
-            )
-            input_ids = inputs.input_ids[0]
-            vision_start_indices = (input_ids == vision_start_id).nonzero(
-                as_tuple=True
-            )[-1]
-            if len(vision_start_indices) > 0:
-                last_vision_start = vision_start_indices[-1]
-                remaining_ids = input_ids[last_vision_start:]
-                end_relative_idx = (remaining_ids == vision_end_id).nonzero(
-                    as_tuple=True
-                )[-1]
-                if len(end_relative_idx) > 0:
-                    last_vision_end = last_vision_start + end_relative_idx[0]
-                    ref_image_features = hidden_states[
-                        :, last_vision_start + 1 : last_vision_end, :
-                    ]
-                    ref_image_features = self.ref_connector(ref_image_features)
-                    learnable_query_features = torch.cat(
-                        [ref_image_features, learnable_query_features], dim=1
-                    )
+        # if ref_image:
+        #     vision_start_id = self.processor.tokenizer.convert_tokens_to_ids(
+        #         "<|vision_start|>"
+        #     )
+        #     vision_end_id = self.processor.tokenizer.convert_tokens_to_ids(
+        #         "<|vision_end|>"
+        #     )
+        #     input_ids = inputs.input_ids[0]
+        #     vision_start_indices = (input_ids == vision_start_id).nonzero(
+        #         as_tuple=True
+        #     )[-1]
+        #     if len(vision_start_indices) > 0:
+        #         last_vision_start = vision_start_indices[-1]
+        #         remaining_ids = input_ids[last_vision_start:]
+        #         end_relative_idx = (remaining_ids == vision_end_id).nonzero(
+        #             as_tuple=True
+        #         )[-1]
+        #         if len(end_relative_idx) > 0:
+        #             last_vision_end = last_vision_start + end_relative_idx[0]
+        #             ref_image_features = hidden_states[
+        #                 :, last_vision_start + 1 : last_vision_end, :
+        #             ]
+        #             ref_image_features = self.ref_connector(ref_image_features)
+        #             learnable_query_features = torch.cat(
+        #                 [ref_image_features, learnable_query_features], dim=1
+        #             )

         return learnable_query_features
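The block commented out above located the last <|vision_start|> … <|vision_end|> pair in the tokenized prompt and sliced the corresponding hidden states out as reference-image features. A standalone sketch of that span-location logic (the token ids below are made-up placeholders; the real ones come from processor.tokenizer.convert_tokens_to_ids):

import torch

# Hypothetical placeholder ids, not the actual tokenizer vocabulary ids
VISION_START_ID, VISION_END_ID = 101, 102

def last_vision_span(input_ids: torch.Tensor):
    # Indices of every <|vision_start|> token in the 1-D id sequence
    starts = (input_ids == VISION_START_ID).nonzero(as_tuple=True)[-1]
    if len(starts) == 0:
        return None
    start = starts[-1]  # last vision segment, matching the diff's behavior
    # First <|vision_end|> at or after that start
    ends = (input_ids[start:] == VISION_END_ID).nonzero(as_tuple=True)[-1]
    if len(ends) == 0:
        return None
    end = start + ends[0]
    # Exclusive bounds around the marker tokens, as in the hidden_states slice
    return int(start) + 1, int(end)

ids = torch.tensor([5, 101, 7, 8, 102, 9])
print(last_vision_span(ids))  # (2, 4): indices 2 and 3 sit between the markers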
 