Update mllm_encoder/mllm_encoder.py
mllm_encoder/mllm_encoder.py  +36 -36
@@ -2328,9 +2328,9 @@ class MLLMEncoder(ModelMixin, ConfigMixin):
         self.video_queries = nn.Parameter(
             torch.randn(num_video_queries, hidden_size) * 0.02
         )
-        self.ref_queries = nn.Parameter(
-            torch.randn(num_ref_queries, hidden_size) * 0.02
-        )
+        # self.ref_queries = nn.Parameter(
+        #     torch.randn(num_ref_queries, hidden_size) * 0.02
+        # )
 
         # Connector MLP: MLLM hidden → DiT dim
         self.connector = nn.Sequential(
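For context on the pattern this hunk disables: learnable query embeddings such as video_queries (and the now commented-out ref_queries) are small trainable matrices, initialized with 0.02-scaled Gaussian noise, that are typically broadcast over the batch and appended to the sequence the MLLM attends over. A minimal sketch of that usage, with hypothetical sizes (the real num_video_queries and hidden_size come from the encoder config, and the actual wiring inside MLLMEncoder may differ):

    import torch
    import torch.nn as nn

    num_video_queries, hidden_size = 64, 3584  # hypothetical sizes

    video_queries = nn.Parameter(torch.randn(num_video_queries, hidden_size) * 0.02)

    # Broadcast the queries over the batch and append them to the token
    # embeddings so the model can attend from them to the multimodal context.
    batch = 2
    inputs_embeds = torch.randn(batch, 128, hidden_size)        # stand-in for token embeddings
    queries = video_queries.unsqueeze(0).expand(batch, -1, -1)  # (batch, num_queries, hidden)
    inputs_embeds = torch.cat([inputs_embeds, queries], dim=1)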
@@ -2342,13 +2342,13 @@ class MLLMEncoder(ModelMixin, ConfigMixin):
         nn.init.zeros_(self.connector[2].bias)
 
         # Ref connector MLP (separate from main connector)
-        self.ref_connector = nn.Sequential(
-            nn.Linear(hidden_size, dit_dim),
-            nn.GELU(approximate="tanh"),
-            nn.Linear(dit_dim, dit_dim),
-        )
-        nn.init.zeros_(self.ref_connector[2].weight)
-        nn.init.zeros_(self.ref_connector[2].bias)
+        # self.ref_connector = nn.Sequential(
+        #     nn.Linear(hidden_size, dit_dim),
+        #     nn.GELU(approximate="tanh"),
+        #     nn.Linear(dit_dim, dit_dim),
+        # )
+        # nn.init.zeros_(self.ref_connector[2].weight)
+        # nn.init.zeros_(self.ref_connector[2].bias)
 
         # Qwen VL model and processor (loaded lazily)
         self.qwen_model = None
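Both connectors zero-initialize their final Linear layer (weight and bias of index 2 in the Sequential), so the projected features are exactly zero at step 0 and the DiT initially sees no signal from this branch; the contribution is learned gradually, a common warm-start trick when attaching a new conditioning path to a pretrained model. A self-contained sketch of the same pattern, with hidden_size and dit_dim as placeholder dimensions:

    import torch
    import torch.nn as nn

    hidden_size, dit_dim = 3584, 3072  # hypothetical dims

    connector = nn.Sequential(
        nn.Linear(hidden_size, dit_dim),
        nn.GELU(approximate="tanh"),
        nn.Linear(dit_dim, dit_dim),
    )
    # Zero the final layer: connector(x) == 0 for any x until training
    # moves the weights, so the downstream model starts undisturbed.
    nn.init.zeros_(connector[2].weight)
    nn.init.zeros_(connector[2].bias)

    assert connector(torch.randn(1, hidden_size)).abs().sum() == 0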
@@ -2705,31 +2705,31 @@ class MLLMEncoder(ModelMixin, ConfigMixin):
         learnable_query_features = self.connector(learnable_query_features)
 
         # Extract ref image features if in ref mode
-        if ref_image:
-            vision_start_id = self.processor.tokenizer.convert_tokens_to_ids(
-                "<|vision_start|>"
-            )
-            vision_end_id = self.processor.tokenizer.convert_tokens_to_ids(
-                "<|vision_end|>"
-            )
-            input_ids = inputs.input_ids[0]
-            vision_start_indices = (input_ids == vision_start_id).nonzero(
-                as_tuple=True
-            )[-1]
-            if len(vision_start_indices) > 0:
-                last_vision_start = vision_start_indices[-1]
-                remaining_ids = input_ids[last_vision_start:]
-                end_relative_idx = (remaining_ids == vision_end_id).nonzero(
-                    as_tuple=True
-                )[-1]
-                if len(end_relative_idx) > 0:
-                    last_vision_end = last_vision_start + end_relative_idx[0]
-                    ref_image_features = hidden_states[
-                        :, last_vision_start + 1 : last_vision_end, :
-                    ]
-                    ref_image_features = self.ref_connector(ref_image_features)
-                    learnable_query_features = torch.cat(
-                        [ref_image_features, learnable_query_features], dim=1
-                    )
+        # if ref_image:
+        #     vision_start_id = self.processor.tokenizer.convert_tokens_to_ids(
+        #         "<|vision_start|>"
+        #     )
+        #     vision_end_id = self.processor.tokenizer.convert_tokens_to_ids(
+        #         "<|vision_end|>"
+        #     )
+        #     input_ids = inputs.input_ids[0]
+        #     vision_start_indices = (input_ids == vision_start_id).nonzero(
+        #         as_tuple=True
+        #     )[-1]
+        #     if len(vision_start_indices) > 0:
+        #         last_vision_start = vision_start_indices[-1]
+        #         remaining_ids = input_ids[last_vision_start:]
+        #         end_relative_idx = (remaining_ids == vision_end_id).nonzero(
+        #             as_tuple=True
+        #         )[-1]
+        #         if len(end_relative_idx) > 0:
+        #             last_vision_end = last_vision_start + end_relative_idx[0]
+        #             ref_image_features = hidden_states[
+        #                 :, last_vision_start + 1 : last_vision_end, :
+        #             ]
+        #             ref_image_features = self.ref_connector(ref_image_features)
+        #             learnable_query_features = torch.cat(
+        #                 [ref_image_features, learnable_query_features], dim=1
+        #             )
 
         return learnable_query_features
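The disabled forward-pass block slices out the hidden states between the last <|vision_start|>/<|vision_end|> marker pair — the tokens Qwen-VL-style processors use to delimit image patches — projects them through ref_connector, and prepends them to the query features. The span-finding logic in isolation, on dummy ids (the token id values here are invented for illustration; real ids come from the tokenizer):

    import torch

    vision_start_id, vision_end_id = 151652, 151653  # assumed ids, illustration only

    input_ids = torch.tensor([1, vision_start_id, 7, 7, vision_end_id,
                              5, vision_start_id, 9, 9, 9, vision_end_id, 2])

    # Position of the last <|vision_start|> token in the sequence.
    vision_start_indices = (input_ids == vision_start_id).nonzero(as_tuple=True)[-1]
    if len(vision_start_indices) > 0:
        last_vision_start = vision_start_indices[-1]
        # First <|vision_end|> at or after it, found relative to the start.
        remaining_ids = input_ids[last_vision_start:]
        end_relative_idx = (remaining_ids == vision_end_id).nonzero(as_tuple=True)[-1]
        if len(end_relative_idx) > 0:
            last_vision_end = last_vision_start + end_relative_idx[0]
            # Tokens strictly between the markers, excluding both delimiters.
            span = input_ids[last_vision_start + 1 : last_vision_end]
            print(span)  # tensor([9, 9, 9])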