TurkishCodeMan committed on
Commit
4db9aa3
·
verified ·
1 Parent(s): bffd49d

Upload folder using huggingface_hub

Browse files
models/local_nemotron/__init__.py ADDED
File without changes
models/local_nemotron/configuration_llama_nemotron_vl.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0.
3
+
4
+ from typing import Optional
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.models.llama.configuration_llama import LlamaConfig
8
+ from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
9
+ from transformers.utils import logging
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ # ============================================================================
15
+ # Bidirectional LLaMA Configuration
16
+ # ============================================================================
17
+
18
+
19
class LlamaBidirectionalConfig(LlamaConfig):
    """Configuration for bidirectional (non-causal) LLaMA model.

    Extends the stock ``LlamaConfig`` with two embedding-oriented knobs:

    Args:
        pooling: Strategy used to collapse per-token hidden states into one
            vector per sequence (default ``"avg"``).
        temperature: Scaling factor applied to scores/logits downstream
            (default ``1.0``).
    """

    model_type = "llama_bidirec"

    def __init__(self, pooling="avg", temperature=1.0, **kwargs):
        # Record the extra fields first, then let LlamaConfig handle every
        # remaining keyword exactly as usual.
        self.pooling = pooling
        self.temperature = temperature
        super().__init__(**kwargs)
35
+
36
+
37
+ # ============================================================================
38
+ # LlamaNemotronVL Configuration Classes
39
+ # ============================================================================
40
+
41
+
42
class LlamaNemotronVLConfig(PretrainedConfig):
    """
    Base configuration for vision-language models combining vision and language components.

    This serves as the foundation for LlamaNemotronVL configurations.

    Args:
        vision_config: dict for a SigLIP vision encoder (``model_type`` must be
            ``"siglip_vision_model"``); a default ``SiglipVisionConfig`` is used when None.
        llm_config: dict for the language model; its ``architectures[0]`` must be a
            supported bidirectional LLaMA class. Defaults to ``LlamaBidirectionalConfig()``
            when None.
        use_backbone_lora / use_llm_lora: LoRA ranks (0 disables LoRA).
        select_layer: Vision-encoder hidden layer to read features from (-1 = last).
        force_image_size: Overrides the vision config's image size when set.
        downsample_ratio: Spatial downsampling applied to vision tokens (pixel shuffle).
        template: Name of the conversation/prompt template.
        dynamic_image_size / use_thumbnail / min_dynamic_patch / max_dynamic_patch:
            Dynamic tiling options for high-resolution images.
        mlp_checkpoint: Enable activation checkpointing in the projector MLP.
        pre_feature_reduction / keep_aspect_ratio: Feature/preprocessing flags.
        vocab_size: Kept for API compatibility; the effective value is always read
            from the resolved ``llm_config``.
        q_max_length / p_max_length: Max token lengths for queries / passages.
        query_prefix / passage_prefix: Text prefixes prepended before encoding.
        pooling: Sequence pooling strategy (e.g. "last", "avg").
        bidirectional_attention: Whether the LLM attends bidirectionally.
        max_input_tiles: Max image tiles fed to the model.
        img_context_token_id: Token id of "<IMG_CONTEXT>" placeholder.
    """

    model_type = "llama_nemotron_vl"
    is_composition = True
    # is_composition was renamed to has_no_defaults_at_init in transformers 4.52.1
    # In PR https://github.com/huggingface/transformers/pull/36263
    has_no_defaults_at_init = True

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        select_layer=-1,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        mlp_checkpoint=True,
        pre_feature_reduction=False,
        keep_aspect_ratio=False,
        vocab_size=-1,
        q_max_length: Optional[int] = 512,
        p_max_length: Optional[int] = 10240,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        pooling: str = "last",
        bidirectional_attention: bool = False,
        max_input_tiles: int = 2,
        img_context_token_id: int = 128258,  # tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
        **kwargs,
    ):
        if vision_config is None:
            logger.info(
                "vision_config is None. Initializing Vision Encoders with default values."
            )
            # BUG FIX: the original code only logged here and never assigned
            # self.vision_config, leaving the attribute missing downstream.
            self.vision_config = SiglipVisionConfig()
        elif vision_config["model_type"] == "siglip_vision_model":
            self.vision_config = SiglipVisionConfig(**vision_config)
        else:
            raise ValueError(
                "Unsupported model_type: {}".format(vision_config["model_type"])
            )

        if llm_config is None:
            logger.info(
                "llm_config is None. Initializing the LLM config with default values"
            )
            # BUG FIX: the original code never assigned self.llm_config on this
            # path, so the read of self.llm_config.vocab_size below raised
            # AttributeError whenever llm_config was omitted.
            self.llm_config = LlamaBidirectionalConfig()
        elif llm_config["architectures"][0] in {
            "LlamaBidirectionalModel",
            "LlamaBidirectionalForSequenceClassification",
        }:
            self.llm_config = LlamaBidirectionalConfig(**llm_config)
        else:
            raise ValueError(
                "Unsupported architecture: {}".format(
                    llm_config["architectures"][0]
                )
            )

        # The `vocab_size` argument is intentionally ignored; the LLM config is
        # the single source of truth for the vocabulary size.
        self.vocab_size = self.llm_config.vocab_size
        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.mlp_checkpoint = mlp_checkpoint
        self.pre_feature_reduction = pre_feature_reduction
        self.keep_aspect_ratio = keep_aspect_ratio

        self.q_max_length = q_max_length
        self.p_max_length = p_max_length
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.pooling = pooling
        self.bidirectional_attention = bidirectional_attention
        self.img_context_token_id = img_context_token_id
        self.max_input_tiles = max_input_tiles
        super().__init__(**kwargs)
models/local_nemotron/modeling_llama_nemotron_vl.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0.
3
+
4
+ import math
5
+ from typing import List, Optional, Tuple, Union, Any, Dict
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+ from transformers import AutoProcessor, PreTrainedModel, AutoConfig
11
+ from transformers.cache_utils import Cache
12
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
13
+ from transformers.modeling_outputs import (
14
+ CausalLMOutputWithPast,
15
+ SequenceClassifierOutputWithPast,
16
+ )
17
+ from transformers.models.llama.modeling_llama import (
18
+ LlamaForSequenceClassification,
19
+ LlamaModel,
20
+ )
21
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
22
+ from transformers.utils import logging
23
+
24
+ from .configuration_llama_nemotron_vl import (
25
+ LlamaBidirectionalConfig,
26
+ LlamaNemotronVLConfig,
27
+ )
28
+
29
+ from .processing_llama_nemotron_vl import LlamaNemotronVLProcessor
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
def split_model(model_path, device):
    """Build a module-name -> GPU-index device map for multi-GPU loading.

    Decoder layers are spread roughly evenly across all visible GPUs; the
    vision tower, projector (``mlp1``), embeddings, norm, head and the final
    decoder layer are pinned to ``device`` so they sit next to each other.

    Args:
        model_path: Checkpoint path / repo id; its config provides the number
            of hidden layers (read from ``config.llm_config``).
        device: Index of the GPU that hosts all non-layer modules.

    Returns:
        A dict usable as a Hugging Face ``device_map``.
    """
    device_map = {}
    world_size = torch.cuda.device_count()
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_layers = config.llm_config.num_hidden_layers

    # BUG FIX: with zero or one visible GPU the original code divided by
    # (world_size - 1) and raised ZeroDivisionError. Map everything onto
    # `device` in that case ("" means "the whole model" to HF loaders).
    if world_size <= 1:
        return {"": device}

    logger.info("world_size: %s", world_size)
    num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
    num_layers_per_gpu = [num_layers_per_gpu_] * world_size
    # `device` absorbs the remainder so every layer is assigned exactly once.
    num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size - 1)
    logger.info("layers per GPU: %s", num_layers_per_gpu)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f"language_model.model.layers.{layer_cnt}"] = i
            layer_cnt += 1
    # Pin the boundary/shared modules together on `device`.
    device_map["vision_model"] = device
    device_map["mlp1"] = device
    device_map["language_model.model.tok_embeddings"] = device
    device_map["language_model.model.embed_tokens"] = device
    device_map["language_model.output"] = device
    device_map["language_model.model.norm"] = device
    device_map["language_model.lm_head"] = device
    device_map["language_model.model.rotary_emb"] = device
    # Keep the last decoder layer with the head to avoid a cross-GPU hop.
    device_map[f"language_model.model.layers.{num_layers - 1}"] = device
    return device_map
+ return device_map
60
+
61
+
62
def pool(
    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str
) -> torch.Tensor:
    """Collapse per-token hidden states into one embedding per sequence.

    Args:
        last_hidden_states: Float tensor of shape (batch, seq_len, hidden).
        attention_mask: (batch, seq_len) mask; nonzero marks real tokens.
        pool_type: One of "avg", "weighted_avg", "cls", "cls_last", "last",
            "colbert".

    Returns:
        (batch, hidden) embeddings, except "colbert" which keeps the full
        (batch, seq_len, hidden) tensor (padding positions zeroed).

    Raises:
        ValueError: For an unrecognized ``pool_type``.
    """
    # Zero out padded positions so they cannot leak into any pooling mode.
    hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
        token_counts = attention_mask.sum(dim=1)[..., None]
        return hidden.sum(dim=1) / token_counts
    if pool_type == "weighted_avg":
        # Padding already zeroed above, so a plain sum is the masked sum.
        return hidden.sum(dim=1)
    if pool_type in ("cls", "cls_last"):
        # Both variants take the first-token representation.
        return hidden[:, 0]
    if pool_type == "last":
        # If every row ends with a real token, inputs are left-padded and the
        # last column already holds each sequence's final token.
        all_rows_end_real = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if all_rows_end_real:
            return hidden[:, -1]
        # Right padding: index each row at its own final real position.
        end_positions = attention_mask.sum(dim=1) - 1
        row_idx = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[row_idx, end_positions]
    if pool_type == "colbert":
        return hidden
    raise ValueError(f"pool_type {pool_type} not supported")
+
92
+
93
+ # ============================================================================
94
+ # Bidirectional LLaMA Model
95
+ # ============================================================================
96
+
97
+
98
class LlamaBidirectionalModel(LlamaModel):
    """LLaMA model with bidirectional (non-causal) attention.

    Behaves like ``LlamaModel`` except that every decoder layer's attention is
    switched to non-causal, so each token can attend to all positions. Meant
    for embedding/reranking-style use, not autoregressive generation.
    """

    config_class = LlamaBidirectionalConfig

    def __init__(self, config: LlamaBidirectionalConfig):
        # ✅ FIX: Force eager attention before super().__init__ triggers FA2 checks
        # NOTE(review): this silently overrides whatever attention
        # implementation the caller requested (sdpa/flash_attention_2) —
        # confirm that is intended, since _update_causal_mask below still
        # handles the FA2 case.
        config._attn_implementation = "eager"
        if hasattr(config, 'llm_config'):
            config.llm_config._attn_implementation = "eager"

        super().__init__(config)
        # Turn off the causal flag on every layer so attention is bidirectional.
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build a padding-only (non-causal) attention mask.

        Overrides the parent hook so that no causal triangle is applied: only
        the padding structure of ``attention_mask`` is preserved.
        ``cache_position``/``past_key_values``/``output_attentions`` are
        accepted for signature compatibility but unused.
        """
        assert self.config._attn_implementation in ["flash_attention_2", "eager", "sdpa"], (
            f"Unsupported attention implementation: {self.config._attn_implementation}, "
            "only support flash_attention_2, eager or sdpa"
        )

        if self.config._attn_implementation == "flash_attention_2":
            # FA2 consumes the 2D padding mask directly; None means "no padding".
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        elif self.config._attn_implementation in {"eager", "sdpa"}:
            # Expand the 2D padding mask to the 4D additive form WITHOUT the
            # causal triangle — this is what makes the model bidirectional.
            # NOTE(review): assumes attention_mask is not None on this path.
            causal_mask = _prepare_4d_attention_mask(
                attention_mask,
                dtype=input_tensor.dtype,
            )
            return causal_mask
+
137
+
138
class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification):
    """LLaMA sequence classification model with bidirectional attention.

    Same head/loss behavior as ``LlamaForSequenceClassification`` but the
    backbone is replaced with ``LlamaBidirectionalModel`` and the classified
    representation comes from ``pool(...)`` (per ``config.pooling``) instead of
    the last-token logit, with logits scaled by ``config.temperature``.
    """

    config_class = LlamaBidirectionalConfig

    def __init__(self, config):
        super().__init__(config)
        # Releasing the parameters of LlamaModel created by parent
        del self.model
        self.model = LlamaBidirectionalModel(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        Run the bidirectional backbone, pool the final hidden states, and
        score them with the classification head.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Pool tokens into one vector per sequence (strategy from config).
        # NOTE(review): pool() dereferences attention_mask, so a None mask
        # would crash here — callers must always pass one.
        pooled_hidden_states = pool(
            last_hidden_states=hidden_states,
            attention_mask=attention_mask,
            pool_type=self.config.pooling,
        )

        pooled_logits = self.score(pooled_hidden_states)
        # Temperature scaling of the classification scores.
        pooled_logits = pooled_logits / self.config.temperature

        loss = None
        if labels is not None:
            labels = labels.to(pooled_logits.device)
            # Infer the problem type once (HF convention) when not configured.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            # Legacy tuple output: (loss?, logits, *backbone extras).
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
+ )
235
+
236
+
237
+ # ============================================================================
238
+ # LlamaNemotronVL Model Classes
239
+ # ============================================================================
240
+
241
+
242
class LlamaNemotronVLModel(PreTrainedModel):
    """
    LlamaNemotron VL model for vision-language reranking.
    Combines a vision encoder (SigLIP) with a bidirectional language model (LLaMA)
    for cross-modal reranking tasks.
    """

    config_class = LlamaNemotronVLConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["LlamaDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config: LlamaNemotronVLConfig,
        vision_model: Optional[PreTrainedModel] = None,
        language_model: Optional[PreTrainedModel] = None,
    ):
        """Build (or adopt) the vision tower, LLM and the projector between them.

        Args:
            config: Composite VL configuration.
            vision_model: Optional pre-built vision encoder; constructed from
                ``config.vision_config`` when None.
            language_model: Optional pre-built LLM; constructed from
                ``config.llm_config`` when None.
        """
        # ✅ FIX: Force eager attention here as well
        # NOTE(review): overrides any caller-requested attention backend.
        config._attn_implementation = "eager"
        super().__init__(config)

        # Calculate image token count: how many LLM tokens one image tile
        # produces after pixel-shuffle downsampling.
        image_size = config.force_image_size or config.vision_config.image_size
        if hasattr(config.vision_config, "grid_size"):
            grid_size = config.vision_config.grid_size
            # NOTE(review): patch size is hard-coded to 14 on the grid_size
            # path — confirm this matches the vision checkpoint.
            self.patch_size = 14
            self.num_image_token = int((grid_size * config.downsample_ratio) ** 2)
        else:
            patch_size = config.vision_config.patch_size
            self.patch_size = patch_size
            self.num_image_token = int(
                (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
            )

        self.select_layer = config.select_layer
        self.template = config.template
        self.downsample_ratio = config.downsample_ratio

        logger.info(f"num_image_token: {self.num_image_token}")
        if vision_model is not None:
            self.vision_model = vision_model
        else:
            if config.vision_config.model_type == "siglip_vision_model":
                config.vision_config._attn_implementation = config._attn_implementation
                self.vision_model = SiglipVisionModel(config.vision_config)
            else:
                raise NotImplementedError(
                    f"Unsupported vision model type: {config.vision_config.model_type}"
                )

        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == "LlamaBidirectionalModel":
                config.llm_config._attn_implementation = config._attn_implementation
                self.language_model = LlamaBidirectionalModel(config.llm_config)
            else:
                raise NotImplementedError(
                    f"{config.llm_config.architectures[0]} is not implemented."
                )

        # Vision-to-language projection: pixel-shuffle multiplies the channel
        # dim by (1/downsample_ratio)^2, hence the scaled LayerNorm/Linear input.
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
            nn.Linear(
                vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
                llm_hidden_size,
            ),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size),
        )
        # Set lazily; forward() reads self.config.img_context_token_id instead.
        self.img_context_token_id = None

        # Initialize processor
        # NOTE(review): performs disk/network I/O at construction time and
        # assumes config.name_or_path points at a loadable processor.
        self.processor = AutoProcessor.from_pretrained(
            config.name_or_path, trust_remote_code=True
        )

    def _embed_batch(self, inputs: Dict[str, Any], pool_type: Optional[str] = None):
        """
        Encodes the inputs into a tensor of embeddings.
        Args:
            inputs: A dictionary of inputs to the model. You can prepare the inputs using the processor.process_queries and processor.process_documents methods.
            pool_type: The type of pooling to use. If None, the pooling type is set to the pooling type configured in the model.
        Returns:
            A tensor of embeddings.
        """
        # Move tensor inputs onto the model's device; leave everything else as-is.
        inputs = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }

        outputs = self.forward(**inputs, output_hidden_states=True, return_dict=True)
        if not pool_type:
            pool_type = self.config.pooling
        # Pool the last layer's hidden states; requires "attention_mask" in inputs.
        embeddings = pool(last_hidden_states=outputs.hidden_states[-1], attention_mask=inputs["attention_mask"], pool_type=pool_type)
        return embeddings

    def encode_queries(self, queries: List[str], **kwargs):
        """
        Encodes the input queries into a tensor of embeddings.
        Args:
            queries: A list of queries.
        Returns:
            A tensor of embeddings.
        """
        queries_dict = self.processor.process_queries(queries)
        queries_embeddings = self._embed_batch(inputs=queries_dict, **kwargs)
        return queries_embeddings

    def encode_documents(self, images: Optional[List[Any]] = None, texts: Optional[List[str]] = None, **kwargs):
        """
        Encodes the input document images and texts into a tensor of embeddings.
        Args:
            images: A list of PIL.Image of document pages images.
            texts: A list of document page texts.
        Returns:
            A tensor of embeddings.
        """
        # Pair images with texts when both are given; otherwise fill the
        # missing modality with an empty string.
        if images and texts:
            # NOTE(review): zip() silently truncates if the lists differ in
            # length — confirm callers guarantee equal lengths.
            examples = [{
                "image": image,
                "text": doc_text
            } for image, doc_text in zip(images, texts)]

        elif images:
            examples = [{
                "image": image,
                "text": ""
            } for image in images]

        elif texts:
            examples = [{
                "image": "",
                "text": doc_text
            } for doc_text in texts]
        else:
            raise ValueError("At least docs_images or docs_texts need to be provided")

        docs_dict = self.processor.process_documents(examples)
        docs_embeddings = self._embed_batch(inputs=docs_dict, **kwargs)
        return docs_embeddings

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_patches_list: Optional[List[torch.Tensor]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Embed text, splice projected vision tokens over <IMG_CONTEXT>
        placeholders, and run the language model.

        ``num_patches_list`` is accepted for interface compatibility but unused
        here.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Get text embeddings
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Process and inject vision embeddings if present
        if pixel_values is not None:
            if image_flags is None:
                # Default: treat every tile as a real image.
                # NOTE(review): this tensor is created on CPU with float dtype;
                # confirm the later `image_flags == 1` indexing behaves on GPU.
                image_flags = torch.ones(pixel_values.shape[0])
            image_flags = image_flags.squeeze(-1)
            vit_embeds = self.extract_feature(pixel_values).to(
                device=input_embeds.device
            )

            if not isinstance(image_flags, list):
                # NOTE(review): second squeeze(-1) after the one above —
                # redundant for 1-D flags; presumably harmless.
                image_flags = image_flags.squeeze(-1)
                vit_embeds = vit_embeds[image_flags == 1]

            # Inject vision tokens into text embeddings: flatten the batch,
            # then overwrite the <IMG_CONTEXT> positions with vision features.
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)
            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.config.img_context_token_id).to(input_embeds.device)
            try:
                # `* 0.0 +` keeps the text-embedding graph connected so
                # gradients flow through both branches.
                input_embeds[selected] = input_embeds[
                    selected
                ] * 0.0 + vit_embeds.reshape(-1, C)
            except Exception as e:
                # Fallback when placeholder count != vision-token count:
                # fill as many placeholders as there are vision tokens.
                # NOTE(review): broad except + print is best-effort by design;
                # a count mismatch upstream is silently papered over here.
                vit_embeds = vit_embeds.reshape(-1, C)
                print(
                    f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
                    f"vit_embeds.shape={vit_embeds.shape}"
                )
                n_token = selected.sum()
                input_embeds[selected] = (
                    input_embeds[selected] * 0.0 + vit_embeds[:n_token]
                )

            input_embeds = input_embeds.reshape(B, N, C)

        # Forward through language model
        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = None
        loss = None

        # The backbone may be a bare LlamaModel (no logits); only compute the
        # LM loss when a head is actually present.
        if hasattr(outputs, "logits"):
            logits = outputs.logits
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(
                    -1, self.language_model.config.vocab_size
                )
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Space-to-depth shuffle: trade spatial resolution for channels.

        Input ``x`` is (N, W, H, C); output is
        (N, W*scale, H*scale, C/scale^2) — fewer tokens, wider channels.
        """
        n, w, h, c = x.shape
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Extract and project vision features to language model space."""
        # Extract features from vision encoder
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values, output_hidden_states=False, return_dict=True
            )
            if hasattr(vit_embeds, "last_hidden_state"):
                vit_embeds = vit_embeds.last_hidden_state
        else:
            # Read an intermediate layer's hidden states instead of the last.
            vit_embeds = self.vision_model(
                pixel_values=pixel_values, output_hidden_states=True, return_dict=True
            ).hidden_states[self.select_layer]

        # Remove CLS token if not using SigLIP
        if not isinstance(self.vision_model, SiglipVisionModel):
            vit_embeds = vit_embeds[:, 1:, :]

        # Apply pixel shuffle and MLP projection
        # NOTE(review): assumes a square token grid (n must be a perfect square).
        _, n, c = vit_embeds.shape
        h = w = int(n**0.5)
        vit_embeds = vit_embeds.reshape(-1, h, w, c)  # (B, H, W, C)
        vit_embeds = self.pixel_shuffle(
            vit_embeds, scale_factor=self.downsample_ratio
        )  # (B, H/s, W/s, C*s*s)
        _, h_s, w_s, c_s = vit_embeds.shape
        vit_embeds = vit_embeds.reshape(-1, h_s * w_s, c_s)  # (B, (H/s)*(W/s), C*s*s)
        vit_embeds = self.mlp1(vit_embeds)

        return vit_embeds

    def get_input_embeddings(self):
        # Delegate to the wrapped language model.
        return self.language_model.get_input_embeddings()

    def get_output_embeddings(self):
        # Delegate to the wrapped language model.
        return self.language_model.get_output_embeddings()

    def build_collator(self, processor=None, **kwargs):
        """Return the collator for training: an explicit processor if given,
        otherwise the model's own processor."""
        return processor or self.processor

    def post_loss(self, loss, inputs):
        # Add Dummy Gradients for Vision Encoder to ensure multi-GPU synchronization when there are batches with only text samples
        # and other batches with images.
        if "pixel_values" in inputs and inputs["pixel_values"] is None:
            # A zero-weighted pass through the vision tower touches all of its
            # parameters without changing the loss value.
            dummy_pixels = torch.zeros(
                1, 3, 512, 512, device=loss.device, dtype=self.vision_model.dtype
            )
            dummy_output = self.extract_feature(dummy_pixels)
            loss = loss + dummy_output.sum() * 0.0
        return loss
models/local_nemotron/processing_llama_nemotron_vl.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0.
3
+
4
+ import base64
5
+ import os
6
+ from io import BytesIO
7
+ from typing import Any, Dict, List, Optional, Union, Tuple
8
+ import dataclasses
9
+ from dataclasses import field
10
+
11
+ import requests
12
+ import torch
13
+ import torchvision.transforms as T
14
+ from PIL import Image
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from transformers import ProcessorMixin
17
+
18
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
19
+ IMAGENET_STD = (0.229, 0.224, 0.225)
20
+
21
+ SIGLIP_MEAN = (0.5, 0.5, 0.5)
22
+ SIGLIP_STD = (0.5, 0.5, 0.5)
23
+
24
+
25
@dataclasses.dataclass
class Conversation:
    """Manages prompt construction with system messages and multi-turn dialogues."""

    # System instruction prepended to prompts
    system_message: str = ""
    # Role identifiers for dialogue turns
    roles: Tuple[str, str] = ("", "")
    # Message history as (role, content) pairs
    messages: List[List[str]] = field(default_factory=list)
    # Separator token between messages
    sep: str = ""
    # Token IDs that trigger generation stopping
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Construct the formatted prompt string from system message and dialogue history."""
        # Turns with empty/None content contribute only their role marker
        # (e.g. an open assistant turn awaiting generation).
        pieces = [self.system_message + self.sep]
        for role, message in self.messages:
            pieces.append(role + message + self.sep if message else role)
        return "".join(pieces)

    def append_message(self, role: str, message: str):
        """Add a message turn to the dialogue history."""
        self.messages.append([role, message])
53
+
54
+
55
def get_conv_template(name: str) -> Conversation:
    """Initialize a conversation instance with default configuration.

    NOTE(review): ``name`` is currently ignored — every template gets the
    same stop tokens; confirm whether per-name templates are planned.
    """
    template = Conversation(stop_token_ids=[128259, 128001])
    return template
60
+
61
+
62
def load_image(image):
    """Normalize the accepted image inputs into a ``PIL.Image``.

    Accepts an already-open PIL image, an existing filesystem path, or a dict
    holding one of: "disk_path", "base64", "url", or "bytes" (checked in that
    order). Anything else raises ValueError.
    """
    if isinstance(image, Image.Image):
        return image

    if isinstance(image, str) and os.path.exists(image):
        return Image.open(image)

    if isinstance(image, dict):
        if "disk_path" in image:
            return Image.open(image["disk_path"])
        if "base64" in image:
            return Image.open(BytesIO(base64.b64decode(image["base64"])))
        if "url" in image:
            # NOTE(review): network fetch with no timeout or status check.
            response = requests.get(image["url"])
            return Image.open(BytesIO(response.content))
        if "bytes" in image:
            return Image.open(BytesIO(image["bytes"]))
        raise ValueError(f"Invalid image: {image}")

    raise ValueError(f"Invalid image: {image}")
81
+
82
+
83
def build_transform(input_size, norm_type="imagenet"):
    """Build the square-resize + normalize preprocessing pipeline.

    Args:
        input_size: Target side length in pixels; images are resized to
            (input_size, input_size) with bicubic interpolation.
        norm_type: Normalization statistics to use: "imagenet" or "siglip".

    Returns:
        A torchvision transform mapping a PIL image to a normalized tensor.

    Raises:
        ValueError: If ``norm_type`` is not a recognized scheme.
    """
    if norm_type == "imagenet":
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif norm_type == "siglip":
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        # BUG FIX: an unknown norm_type previously left MEAN/STD unbound and
        # crashed later with an unhelpful NameError; fail fast instead.
        raise ValueError(f"Unsupported norm_type: {norm_type!r}")

    transform = T.Compose(
        [
            # Convert palette/grayscale inputs so normalization sees 3 channels.
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD),
        ]
    )
    return transform
98
+
99
+
100
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose shape best matches the image.

    Scores each candidate grid by the product of (a) how much of the original
    area the tiled canvas would cover, capped at 60%, and (b) how close the
    grid's aspect ratio is to the image's. Returns the best-scoring candidate
    (first wins on ties).
    """
    image_area = width * height
    best_ratio = (1, 1)
    best_score = float("-inf")

    for ratio in target_ratios:
        grid_aspect = ratio[0] / ratio[1]
        # Fraction of the original area covered by the candidate tile canvas.
        coverage = (ratio[0] * ratio[1] * image_size * image_size) / image_area
        # new area > 60% of original image area is enough.
        score = min(coverage, 0.6) * min(
            grid_aspect / aspect_ratio, aspect_ratio / grid_aspect
        )
        if score > best_score:
            best_score = score
            best_ratio = ratio

    return best_ratio
121
+
122
+
123
def dynamic_preprocess(
    image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
):
    """Split an image into between min_num and max_num square tiles.

    Chooses the tile grid (cols x rows) whose aspect ratio best matches the
    image via ``find_closest_aspect_ratio``, resizes the image to fill that
    grid, crops it into ``image_size`` x ``image_size`` tiles, and optionally
    appends a whole-image thumbnail when more than one tile was produced.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every (cols, rows) grid whose tile count lies in range.
    candidate_grids = {
        (cols, rows)
        for total in range(min_num, max_num + 1)
        for cols in range(1, total + 1)
        for rows in range(1, total + 1)
        if min_num <= cols * rows <= max_num
    }
    candidate_grids = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # Best-matching grid for this image's shape.
    grid_cols, grid_rows = find_closest_aspect_ratio(
        aspect_ratio, candidate_grids, orig_width, orig_height, image_size
    )

    target_width = image_size * grid_cols
    target_height = image_size * grid_rows
    num_tiles = grid_cols * grid_rows

    # Resize to exactly fill the grid, then crop tiles row by row.
    resized_img = image.resize((target_width, target_height))
    tiles = []
    tiles_per_row = target_width // image_size
    for idx in range(num_tiles):
        col = idx % tiles_per_row
        row = idx // tiles_per_row
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        tiles.append(resized_img.crop(box))
    assert len(tiles) == num_tiles

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
167
+
168
+
169
class LlamaNemotronVLProcessor(ProcessorMixin):
    """Processor for the LlamaNemotron VL retriever.

    Wraps a text tokenizer together with image tiling/normalization so that
    queries (text only) and documents (text and/or images) can be converted
    into batched model inputs. ``<image>`` placeholders in document text are
    expanded into ``<img><IMG_CONTEXT>...</img>`` runs sized to the number of
    image tiles produced for that document.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer: Any,
        q_max_length: Optional[int] = None,
        p_max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        max_input_tiles: int = 6,
        num_image_token: int = 128258,
        dynamic_image_size: bool = True,
        image_size: int = 512,
        use_thumbnail: bool = True,
        template: str = "bidirectional-llama-retriever",
        num_channels: int = 3,
        norm_type: str = "siglip",
        system_message: str = "",
        padding: Union[bool, str] = True,
        **kwargs,
    ):
        """Configure the processor and mutate the wrapped tokenizer in place.

        Args:
            tokenizer: Hugging Face tokenizer to wrap (modified in place).
            q_max_length: Truncation length for query tokenization.
            p_max_length: Truncation length for document tokenization.
            pad_to_multiple_of: Optional padding multiple passed to the tokenizer.
            query_prefix: Prefix prepended to every query string.
            passage_prefix: Prefix prepended to every document string.
            max_input_tiles: Maximum number of image tiles per document.
            num_image_token: Number of <IMG_CONTEXT> tokens emitted per tile.
            dynamic_image_size: Whether to tile images via dynamic_preprocess.
            image_size: Square tile side length in pixels.
            use_thumbnail: Whether dynamic tiling appends a thumbnail tile.
            template: Conversation template name for get_conv_template.
            num_channels: Expected image channel count (stored, not enforced here).
            norm_type: Normalization statistics name passed to build_transform.
            system_message: System message installed on the conversation template.
            padding: Padding strategy forwarded to the tokenizer.
        """
        # Drop every additional special token except the box/ref markers.
        tokens_to_keep = ["<box>", "</box>", "<ref>", "</ref>"]
        tokenizer.additional_special_tokens = [
            item
            for item in tokenizer.additional_special_tokens
            if item not in tokens_to_keep
        ]
        # Left padding: required by the "last token" style pooling downstream
        # (presumably — TODO confirm against the model's pooling config).
        tokenizer.padding_side = "left"
        tokenizer.model_input_names = tokenizer.model_input_names + ["pixel_values"]
        self.tokenizer = tokenizer

        self.q_max_length = q_max_length
        self.p_max_length = p_max_length
        self.pad_to_multiple_of = pad_to_multiple_of
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.max_input_tiles = max_input_tiles
        self.num_image_token = num_image_token
        self.dynamic_image_size = dynamic_image_size
        self.image_size = image_size
        self.use_thumbnail = use_thumbnail
        self.template = template
        self.num_channels = num_channels
        self.norm_type = norm_type
        self.system_message = system_message
        self.padding = padding

        super().__init__(self.tokenizer)

    def process_documents(self, documents: Union[Dict, List[Dict]], **kwargs):
        """Tokenize documents and preprocess their images into pixel values.

        Args:
            documents: Either a dict with parallel ``"images"`` / ``"texts"``
                lists, or a list of ``{"image": ..., "text": ...}`` dicts.
                An image entry of ``None`` or ``""`` marks a text-only document.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``pixel_values``
            (``None`` when no document carried an image).

        Raises:
            ValueError: If ``documents`` is neither a dict nor a list.
        """
        if isinstance(documents, dict):
            images = documents["images"]
            texts = documents["texts"]
            assert len(texts) == len(images)
        elif isinstance(documents, list):
            images = [pair["image"] for pair in documents]
            texts = [pair["text"] for pair in documents]
        else:
            raise ValueError("The documents need to be a dict or list of dicts")

        contents, pil_images, max_input_tile_list, llm_onlys = [], [], [], []
        for image, text in zip(images, texts):
            prefix = ""
            llm_only = True
            if image is not None and image != "":
                pil_images.append(load_image(image))
                prefix = "<image>"
                max_input_tile_list.append(self.max_input_tiles)
                llm_only = False
            else:
                # Keep list lengths aligned even for text-only documents.
                pil_images.append(None)
                max_input_tile_list.append(self.max_input_tiles)

            llm_onlys.append(llm_only)

            # ToDo: Order is hardcoded and different than before. No \n after <image>
            # Final layout: "<passage_prefix> <image> <text>".
            content = text
            if prefix != "":
                content = prefix + " " + content
            if self.passage_prefix:
                content = self.passage_prefix + " " + content
            contents.append(content)

        # Sanity check: the three parallel lists must stay in lockstep.
        try:
            assert len(max_input_tile_list) == len(pil_images), (
                "The number of max_input_tile_list and pil_images should be the same."
            )
            assert len(max_input_tile_list) == len(contents), (
                "The number of max_input_tile_list and pil_images should be the same."
            )
        except Exception as e:
            print(f"Error: {e}")
            print(
                f"max_input_tile_list: {max_input_tile_list}, pil_images: {pil_images}"
            )
            raise e

        transform = build_transform(
            input_size=self.image_size, norm_type=self.norm_type
        )

        template = get_conv_template(self.template)
        template.system_message = self.system_message

        content_prompts = []
        pixel_values_list = []
        for content, pil_image, max_input_tiles, llm_only in zip(
            contents, pil_images, max_input_tile_list, llm_onlys
        ):
            if pil_image is not None:
                if self.dynamic_image_size:
                    image_tiles = dynamic_preprocess(
                        pil_image,
                        image_size=self.image_size,
                        max_num=max_input_tiles,
                        use_thumbnail=self.use_thumbnail,
                    )
                else:
                    image_tiles = [pil_image]

                # Stack tiles into (num_tiles, C, H, W); bfloat16 matches the
                # model's expected input dtype.
                pixel_values = [transform(item) for item in image_tiles]
                pixel_values = torch.stack(pixel_values).to(dtype=torch.bfloat16)
                pixel_values_list.append(pixel_values)
            else:
                pixel_values = None

            IMG_START_TOKEN = "<img>"
            IMG_END_TOKEN = "</img>"
            IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"

            # Guarantee an <image> placeholder exists for image documents.
            if pixel_values is not None and "<image>" not in content and not llm_only:
                content = "<image> " + content

            # Reset conversation messages from the previous document.
            template.messages.clear()

            # TODO: do we need this template?
            template.append_message(template.roles[0], content)  # user
            template.append_message(template.roles[1], None)  # assistant
            content_prompt = template.get_prompt()

            if pixel_values is not None:
                # Expand the single <image> placeholder into one context-token
                # run per tile, bracketed by the image start/end tokens.
                num_patches = pixel_values.shape[0]
                image_tokens = (
                    IMG_START_TOKEN
                    + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                    + IMG_END_TOKEN
                )
                content_prompt = content_prompt.replace("<image>", image_tokens, 1)

            content_prompts.append(content_prompt)

        model_inputs = self.tokenizer(
            content_prompts,
            truncation=True,
            max_length=self.p_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Concatenate all documents' tiles along dim 0 so the model receives a
        # single (total_tiles, C, H, W) tensor.
        if len(pixel_values_list) > 1:
            pixel_values_squeezed = torch.concat(pixel_values_list, axis=0)
        elif len(pixel_values_list) == 1:
            pixel_values_squeezed = pixel_values_list[0]
        else:
            pixel_values_squeezed = None

        batch_docs = {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "pixel_values": None,
        }
        if pixel_values_squeezed is not None:
            batch_docs["pixel_values"] = pixel_values_squeezed

        return batch_docs

    def process_queries(self, queries: List[str], **kwargs):
        """Tokenize text-only queries through the conversation template.

        Args:
            queries: Query strings; each gets ``query_prefix`` prepended.

        Returns:
            A ``BatchEncoding`` with ``input_ids`` and ``attention_mask``.
        """
        template = get_conv_template(self.template)
        template.system_message = self.system_message

        query_prompts = []
        for query in queries:
            if self.query_prefix:
                query = f"{self.query_prefix} {query}"

            # Reset conversation messages from the previous query.
            template.messages.clear()

            template.append_message(template.roles[0], query)  # user
            template.append_message(template.roles[1], None)  # assistant
            query_prompt = template.get_prompt()

            query_prompts.append(query_prompt)

        batch_query = self.tokenizer(
            query_prompts,
            truncation=True,
            max_length=self.q_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        return batch_query

    def process_queries_documents_biencoder(self, features: Dict, **kwargs):
        """Build a merged bi-encoder batch from (query, documents) features.

        Each feature is expected to carry a ``"question"`` string plus parallel
        ``"doc_text"`` / ``"doc_image"`` lists of positive and negative
        documents. Queries and documents are processed separately, then merged
        into one dict with ``q_``/``d_`` key prefixes and a dummy ``labels``
        tensor.
        """
        queries = []
        pos_neg_text_batch = []
        pos_neg_image_batch = []
        for feature in features:
            queries.append(feature["question"])
            pos_neg_text_batch.extend(feature["doc_text"])
            pos_neg_image_batch.extend(feature["doc_image"])

        query_batch_dict = self.process_queries(queries, **kwargs)
        doc_batch_dict = self.process_documents(
            {"images": pos_neg_image_batch, "texts": pos_neg_text_batch}, **kwargs
        )

        merged_batch_dict = self.merge_batch_dict(query_batch_dict, doc_batch_dict)
        merged_batch_dict = self.add_dummy_labels(queries, merged_batch_dict)
        return merged_batch_dict

    def merge_batch_dict(self, query_batch_dict, doc_batch_dict):
        """Merge query and document batches into one dict with key prefixes.

        Query keys get a ``q_`` prefix and document keys a ``d_`` prefix; both
        input dicts are emptied as their entries are moved.
        """
        q_prefix, d_prefix = "q_", "d_"
        # merge into a single BatchEncoding by adding prefix
        merged_batch_dict = {}
        for k in list(query_batch_dict.keys()):
            merged_batch_dict[q_prefix + k] = query_batch_dict[k]
            del query_batch_dict[k]
        for k in list(doc_batch_dict.keys()):
            merged_batch_dict[d_prefix + k] = doc_batch_dict[k]
            del doc_batch_dict[k]
        return merged_batch_dict

    def add_dummy_labels(self, questions, merged_batch_dict):
        """Attach a zero ``labels`` tensor (one entry per query).

        The labels are placeholders required by the training harness and are
        not used to compute the loss.
        """
        labels = torch.zeros(len(questions), dtype=torch.long)
        merged_batch_dict["labels"] = labels
        return merged_batch_dict
models/local_nemotron_rerank/__init__.py ADDED
File without changes
models/local_nemotron_rerank/configuration_llama_nemotron_vl.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0.
3
+
4
+ from typing import Optional
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.models.llama.configuration_llama import LlamaConfig
8
+ from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
9
+ from transformers.utils import logging
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ # ============================================================================
15
+ # Bidirectional LLaMA Configuration
16
+ # ============================================================================
17
+
18
+
19
class LlamaBidirectionalConfig(LlamaConfig):
    """Configuration for bidirectional (non-causal) LLaMA model.

    Extends ``LlamaConfig`` with the pooling strategy and logit temperature
    used by the reranking head.
    """

    model_type = "llama_bidirec"

    def __init__(
        self,
        pooling="avg",
        temperature=1.0,
        **kwargs,
    ):
        # pooling: hidden-state pooling strategy name (e.g. "avg", "last").
        # temperature: divisor applied to the classification logits.
        # Assign before super().__init__ so the attributes are present when
        # the base class serializes the config.
        self.pooling = pooling
        self.temperature = temperature
        super().__init__(**kwargs)
33
+
34
+
35
+ # ============================================================================
36
+ # LlamaNemotronVL Configuration Classes
37
+ # ============================================================================
38
+
39
+
40
class LlamaNemotronVLConfig(PretrainedConfig):
    """
    Base configuration for vision-language models combining vision and language components.

    This serves as the foundation for LlamaNemotronVL configurations. It holds a
    SigLIP vision config, a bidirectional LLaMA LLM config, the vision-language
    glue parameters (tiling, downsampling, LoRA switches), and retrieval /
    reranking parameters (prefixes, max lengths, pooling).
    """

    model_type = "llama_nemotron_vl"
    is_composition = True
    # is_composition was renamed to has_no_defaults_at_init in transformers 4.52.1
    # In PR https://github.com/huggingface/transformers/pull/36263
    has_no_defaults_at_init = True

    def __init__(
        self,
        # Vision-language parameters
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        select_layer=-1,
        force_image_size=None,
        downsample_ratio=0.5,
        template=None,
        dynamic_image_size=False,
        use_thumbnail=False,
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        mlp_checkpoint=True,
        pre_feature_reduction=False,
        keep_aspect_ratio=False,
        vocab_size=-1,
        q_max_length: Optional[int] = 512,
        p_max_length: Optional[int] = 10240,
        query_prefix: str = "query:",
        passage_prefix: str = "passage:",
        pooling: str = "last",
        bidirectional_attention: bool = False,
        max_input_tiles: int = 2,
        img_context_token_id: int = 128258,  # tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
        **kwargs,
    ):
        # --- Vision config ---------------------------------------------------
        if vision_config is None:
            logger.info(
                "vision_config is None. Initializing Vision Encoders with default values."
            )
            # Bug fix: previously this branch only logged and never assigned
            # self.vision_config, leaving the attribute unset despite the log
            # message promising default initialization.
            self.vision_config = SiglipVisionConfig()
        elif vision_config["model_type"] == "siglip_vision_model":
            self.vision_config = SiglipVisionConfig(**vision_config)
        else:
            raise ValueError(
                "Unsupported model_type: {}".format(vision_config["model_type"])
            )

        # --- LLM config ------------------------------------------------------
        if llm_config is None:
            logger.info(
                "llm_config is None. Initializing the LLM config with default values"
            )
            # Bug fix: same as the vision branch — materialize the promised
            # default so self.llm_config (and vocab_size below) always exist.
            self.llm_config = LlamaBidirectionalConfig()
        elif llm_config["architectures"][0] in {
            "LlamaBidirectionalModel",
            "LlamaBidirectionalForSequenceClassification",
        }:
            self.llm_config = LlamaBidirectionalConfig(**llm_config)
        else:
            raise ValueError(
                "Unsupported architecture: {}".format(
                    llm_config["architectures"][0]
                )
            )
        # vocab_size always mirrors the LLM config; the explicit ``vocab_size``
        # parameter is accepted for checkpoint compatibility but superseded here.
        self.vocab_size = self.llm_config.vocab_size

        # Vision-language parameters
        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.mlp_checkpoint = mlp_checkpoint
        self.pre_feature_reduction = pre_feature_reduction
        self.keep_aspect_ratio = keep_aspect_ratio

        # Reranking-specific parameters
        self.q_max_length = q_max_length
        self.p_max_length = p_max_length
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.pooling = pooling
        self.bidirectional_attention = bidirectional_attention
        self.img_context_token_id = img_context_token_id
        self.max_input_tiles = max_input_tiles

        super().__init__(**kwargs)
142
+
143
+
144
class LlamaNemotronVLForSequenceClassificationConfig(LlamaNemotronVLConfig):
    """
    Configuration class for LlamaNemotron VL sequence classification model.

    This configuration extends LlamaNemotronVLConfig with parameters specific to
    sequence classification tasks (reranking).
    """

    model_type = "llama_nemotron_vl_rerank"

    def __init__(
        self,
        rerank_max_length: Optional[int] = 512,
        temperature: float = 1.0,
        prompt_template: Optional[str] = None,
        **kwargs,
    ):
        # rerank_max_length: max token length of a (query, document) pair.
        # temperature: divisor applied to the relevance logits.
        # prompt_template: optional template string for formatting rerank
        #   inputs — exact placeholder format is defined by the caller
        #   (TODO confirm against the processor).
        self.rerank_max_length = rerank_max_length
        self.temperature = temperature
        self.prompt_template = prompt_template
        super().__init__(**kwargs)
models/local_nemotron_rerank/modeling_llama_nemotron_vl.py ADDED
@@ -0,0 +1,678 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0.
3
+
4
+ import math
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+ from transformers import AutoProcessor, PreTrainedModel
11
+ from transformers.cache_utils import Cache
12
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
13
+ from transformers.modeling_outputs import (
14
+ CausalLMOutputWithPast,
15
+ SequenceClassifierOutputWithPast,
16
+ )
17
+ from transformers.models.llama.modeling_llama import (
18
+ LlamaForSequenceClassification,
19
+ LlamaModel,
20
+ )
21
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
22
+ from transformers.utils import logging
23
+
24
+ from .configuration_llama_nemotron_vl import (
25
+ LlamaBidirectionalConfig,
26
+ LlamaNemotronVLConfig,
27
+ LlamaNemotronVLForSequenceClassificationConfig,
28
+ )
29
+ from .processing_llama_nemotron_vl import LlamaNemotronVLRerankProcessor
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
def pool(
    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str
) -> torch.Tensor:
    """
    Pool hidden states according to the specified pooling strategy.

    Padding positions are zeroed out first so they never contribute to the
    result.

    Args:
        last_hidden_states: Tensor of shape (batch_size, seq_len, hidden_size)
        attention_mask: Tensor of shape (batch_size, seq_len)
        pool_type: Pooling strategy ('avg', 'weighted_avg', 'cls', 'last', 'cls_last', 'colbert')

    Returns:
        Pooled embeddings

    Raises:
        ValueError: If ``pool_type`` is not a supported strategy.
    """
    pad_positions = ~attention_mask[..., None].bool()
    masked_hidden = last_hidden_states.masked_fill(pad_positions, 0.0)

    if pool_type == "avg":
        # Mean over non-padding positions only.
        return masked_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    if pool_type == "weighted_avg":
        # NOTE(review): despite the name this is a plain masked sum — confirm
        # whether position weighting was intended.
        return masked_hidden.sum(dim=1)

    if pool_type in ("cls", "cls_last"):
        # First-token embedding ('cls_last' currently behaves like 'cls').
        return masked_hidden[:, 0]

    if pool_type == "last":
        # If every sequence attends at the final position the batch is
        # left-padded, so the last column holds each sequence's final token.
        every_row_ends_attended = (
            attention_mask[:, -1].sum() == attention_mask.shape[0]
        )
        if every_row_ends_attended:
            return masked_hidden[:, -1]
        # Right padding: index each row at its own last attended position.
        final_positions = attention_mask.sum(dim=1) - 1
        row_indices = torch.arange(
            masked_hidden.shape[0], device=masked_hidden.device
        )
        return masked_hidden[row_indices, final_positions]

    if pool_type == "colbert":
        # Token-level embeddings, padding zeroed.
        return masked_hidden

    raise ValueError(f"pool_type {pool_type} not supported")
74
+
75
+
76
+ # ============================================================================
77
+ # Bidirectional LLaMA Model
78
+ # ============================================================================
79
+
80
+
81
class LlamaBidirectionalModel(LlamaModel):
    """LLaMA model with bidirectional (non-causal) attention.

    Every self-attention layer is switched to non-causal, and the causal-mask
    construction is overridden so all (non-padding) positions can attend to
    each other.
    """

    config_class = LlamaBidirectionalConfig

    def __init__(self, config: LlamaBidirectionalConfig):
        # ✅ FIX: Force eager attention before super().__init__ triggers FA2 checks
        config._attn_implementation = "eager"
        if hasattr(config, 'llm_config'):
            config.llm_config._attn_implementation = "eager"

        super().__init__(config)

        # Set non-causal attention for all layers
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """
        Update causal mask for bidirectional attention.
        Supports flash_attention_2, sdpa, and eager implementations.

        Note: since __init__ forces _attn_implementation to "eager", the
        flash_attention_2 and sdpa branches are only reachable if the config
        is changed after construction.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: only pass mask if there are actual masks
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None

        elif self.config._attn_implementation == "sdpa":
            # SDPA: prepare 4D attention mask for bidirectional attention
            if attention_mask is not None:
                # Convert 2D mask to 4D: (batch_size, 1, seq_len, seq_len)
                causal_mask = _prepare_4d_attention_mask(
                    attention_mask,
                    dtype=input_tensor.dtype,
                    tgt_len=input_tensor.shape[1],
                )
                return causal_mask
            return None

        elif self.config._attn_implementation == "eager":
            # Eager: standard 4D attention mask.
            # NOTE(review): unlike the sdpa branch, attention_mask is not
            # checked for None here — confirm callers always supply a mask
            # when running in eager mode.
            causal_mask = _prepare_4d_attention_mask(
                attention_mask,
                dtype=input_tensor.dtype,
            )
            return causal_mask

        else:
            raise ValueError(
                f"Unsupported attention implementation: {self.config._attn_implementation}. "
                "Supported values: ['flash_attention_2', 'sdpa', 'eager']"
            )
+
142
+
143
class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification):
    """LLaMA sequence classification model with bidirectional attention.

    Replaces the causal backbone of ``LlamaForSequenceClassification`` with
    :class:`LlamaBidirectionalModel` and scores a pooled (per-config) hidden
    state instead of the last-token hidden state.
    """

    config_class = LlamaBidirectionalConfig

    def __init__(self, config):
        super().__init__(config)
        # Release the parameters of LlamaModel created by parent
        del self.model
        self.model = LlamaBidirectionalModel(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss.

        Runs the bidirectional backbone, pools the final hidden states
        according to ``config.pooling``, scores them with the classification
        head, and scales the logits by ``1 / config.temperature``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Pool per-token hidden states into one vector per sequence.
        pooled_hidden_states = pool(
            last_hidden_states=hidden_states,
            attention_mask=attention_mask,
            pool_type=self.config.pooling,
        )

        pooled_logits = self.score(pooled_hidden_states)
        # Temperature scaling of the relevance/classification logits.
        pooled_logits = pooled_logits / self.config.temperature

        loss = None
        if labels is not None:
            labels = labels.to(pooled_logits.device)
            # Infer the problem type once (standard HF convention) when the
            # config does not pin it explicitly.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            # Tuple output: (loss?, logits, *remaining transformer outputs).
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
+ )
238
+
239
+
240
+ # ============================================================================
241
+ # LlamaNemotronVL Model Classes
242
+ # ============================================================================
243
+
244
+
245
+ class LlamaNemotronVLModel(PreTrainedModel):
246
+ """
247
+ LlamaNemotron VL model for vision-language reranking.
248
+
249
+ Combines a vision encoder (SigLIP) with a bidirectional language model (LLaMA)
250
+ for cross-modal reranking tasks.
251
+
252
+ Supports flash_attention_2, sdpa, and eager attention implementations.
253
+ """
254
+
255
+ config_class = LlamaNemotronVLConfig
256
+ main_input_name = "pixel_values"
257
+ _no_split_modules = ["LlamaDecoderLayer"]
258
+ _supports_flash_attn_2 = True
259
+ _supports_sdpa = True
260
+
261
+ def __init__(self, config: LlamaNemotronVLConfig, *model_args, **model_kwargs):
262
+ # ✅ FIX: Force eager attention here as well
263
+ config._attn_implementation = "eager"
264
+ super().__init__(config, *model_args, **model_kwargs)
265
+
266
+ # Calculate image token count
267
+ image_size = config.force_image_size or config.vision_config.image_size
268
+ if hasattr(config.vision_config, "grid_size"):
269
+ grid_size = config.vision_config.grid_size
270
+ self.patch_size = 14
271
+ self.num_image_token = int((grid_size * config.downsample_ratio) ** 2)
272
+ else:
273
+ patch_size = config.vision_config.patch_size
274
+ self.patch_size = patch_size
275
+ self.num_image_token = int(
276
+ (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
277
+ )
278
+
279
+ self.select_layer = config.select_layer
280
+ self.template = config.template
281
+ self.downsample_ratio = config.downsample_ratio
282
+
283
+ logger.info(f"num_image_token: {self.num_image_token}")
284
+
285
+ # Initialize vision encoder
286
+ if config.vision_config.model_type == "siglip_vision_model":
287
+ self.vision_model = SiglipVisionModel(config.vision_config)
288
+ else:
289
+ raise NotImplementedError(
290
+ f"Unsupported vision model type: {config.vision_config.model_type}"
291
+ )
292
+
293
+ # Set attention implementation (default to flash_attention_2 if available)
294
+ if not hasattr(config.llm_config, '_attn_implementation'):
295
+ if torch.cuda.is_available() and hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
296
+ config.llm_config._attn_implementation = "sdpa"
297
+ logger.info("Using SDPA attention implementation")
298
+ else:
299
+ config.llm_config._attn_implementation = "eager"
300
+ logger.info("Using eager attention implementation")
301
+ else:
302
+ logger.info(f"Using {config.llm_config._attn_implementation} attention implementation")
303
+
304
+ # Initialize language model (bidirectional for reranking)
305
+ if config.llm_config.architectures[0] in [
306
+ "LlamaBidirectionalModel",
307
+ "LlamaBidirectionalForSequenceClassification",
308
+ ]:
309
+ self.language_model = LlamaBidirectionalModel(config.llm_config)
310
+ else:
311
+ raise NotImplementedError(
312
+ f"{config.llm_config.architectures[0]} is not implemented for reranking."
313
+ )
314
+
315
+ # Vision-to-language projection
316
+ vit_hidden_size = config.vision_config.hidden_size
317
+ llm_hidden_size = config.llm_config.hidden_size
318
+ self.mlp1 = nn.Sequential(
319
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
320
+ nn.Linear(
321
+ vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
322
+ llm_hidden_size,
323
+ ),
324
+ nn.GELU(),
325
+ nn.Linear(llm_hidden_size, llm_hidden_size),
326
+ )
327
+ self.img_context_token_id = None
328
+
329
+ # Initialize processor
330
+ self.processor = AutoProcessor.from_pretrained(
331
+ config.name_or_path, trust_remote_code=True
332
+ )
333
+
334
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_patches_list: Optional[List[torch.Tensor]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the vision-language model over text and optional image inputs.

        Text token embeddings are looked up first; when ``pixel_values`` is
        provided, vision features are extracted, projected, and written over
        the embeddings at every position whose token id equals
        ``self.config.img_context_token_id``, then the fused sequence is fed
        through the language model.

        Args:
            pixel_values: Image tiles for the vision encoder (may be None for
                text-only batches).
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask / position_ids / past_key_values / use_cache /
                output_attentions / output_hidden_states / return_dict:
                Standard HF language-model arguments, forwarded unchanged.
            image_flags: Per-image 0/1 flags; rows of the vision embeddings
                with flag 0 are dropped before injection. Defaults to all ones.
            labels: Optional next-token targets; enables a shifted
                cross-entropy loss when the LM head produces logits.
            num_patches_list: Accepted for interface parity; unused here.

        Returns:
            ``CausalLMOutputWithPast`` (or a tuple when ``return_dict`` is
            falsy) with ``loss`` set only when ``labels`` were given and the
            backbone exposes logits.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Get text embeddings
        input_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Process and inject vision embeddings if present
        if pixel_values is not None:
            if image_flags is None:
                # Default: every image tile in the batch is "real".
                image_flags = torch.ones(pixel_values.shape[0])
            image_flags = image_flags.squeeze(-1)
            vit_embeds = self.extract_feature(pixel_values).to(
                device=input_embeds.device
            )

            # NOTE(review): image_flags was already squeezed above, so this
            # second squeeze is a no-op for tensors — confirm whether a list
            # input is actually possible here.
            if not isinstance(image_flags, list):
                image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]

            # Inject vision tokens into text embeddings: flatten to token
            # level so boolean indexing can address individual positions.
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)
            input_ids = input_ids.reshape(B * N)
            # NOTE(review): reads the id from config, while __init__ also
            # defines self.img_context_token_id = None — verify both stay in sync.
            selected = input_ids == self.config.img_context_token_id
            try:
                # Zero-scale then add keeps the write differentiable w.r.t.
                # both the text embedding and the vision features.
                input_embeds[selected] = input_embeds[
                    selected
                ] * 0.0 + vit_embeds.reshape(-1, C)
            except Exception as e:
                # Fallback when the number of <IMG_CONTEXT> slots does not
                # match the number of vision tokens (e.g. truncated prompts):
                # fill as many slots as we have tokens for and warn.
                vit_embeds = vit_embeds.reshape(-1, C)
                logger.warning(
                    f"Shape mismatch in vision embedding injection: {e}, "
                    f"input_embeds[selected].shape={input_embeds[selected].shape}, "
                    f"vit_embeds.shape={vit_embeds.shape}"
                )
                n_token = selected.sum()
                input_embeds[selected] = (
                    input_embeds[selected] * 0.0 + vit_embeds[:n_token]
                )

            input_embeds = input_embeds.reshape(B, N, C)

        # Forward through language model
        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        logits = None
        loss = None

        # Bidirectional backbones may not expose logits; loss is only
        # computable when they do.
        if hasattr(outputs, "logits"):
            logits = outputs.logits
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(
                    -1, self.language_model.config.vocab_size
                )
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
432
+
433
+ def pixel_shuffle(self, x, scale_factor=0.5):
434
+ """
435
+ Rearrange pixels for downsampling/upsampling.
436
+
437
+ Args:
438
+ x: Input tensor of shape (N, W, H, C)
439
+ scale_factor: Scaling factor for shuffle operation
440
+
441
+ Returns:
442
+ Shuffled tensor
443
+ """
444
+ n, w, h, c = x.shape
445
+ # N, W, H, C --> N, W, H * scale, C // scale
446
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
447
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
448
+ x = x.permute(0, 2, 1, 3).contiguous()
449
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
450
+ x = x.view(
451
+ n,
452
+ int(h * scale_factor),
453
+ int(w * scale_factor),
454
+ int(c / (scale_factor * scale_factor)),
455
+ )
456
+ x = x.permute(0, 2, 1, 3).contiguous()
457
+ return x
458
+
459
    def extract_feature(self, pixel_values):
        """
        Extract and project vision features to language model space.

        Pipeline: vision encoder -> (optional CLS removal) -> pixel shuffle
        downsampling -> two-layer MLP projection (``self.mlp1``).

        Args:
            pixel_values: Image tensor batch for the vision encoder.

        Returns:
            Projected vision embeddings of shape
            (num_images, tokens_per_image, llm_hidden_size).
        """
        # Extract features from vision encoder. select_layer == -1 means
        # "use the final layer output"; otherwise index into hidden_states.
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values, output_hidden_states=False, return_dict=True
            )
            if hasattr(vit_embeds, "last_hidden_state"):
                vit_embeds = vit_embeds.last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values, output_hidden_states=True, return_dict=True
            ).hidden_states[self.select_layer]

        # Remove CLS token if not using SigLIP (SigLIP has no CLS token).
        if not isinstance(self.vision_model, SiglipVisionModel):
            vit_embeds = vit_embeds[:, 1:, :]

        # Apply pixel shuffle and MLP projection.
        # Assumes the token count n is a perfect square (square patch grid)
        # — TODO confirm for all supported vision configs.
        _, n, c = vit_embeds.shape
        h = w = int(n**0.5)
        vit_embeds = vit_embeds.reshape(-1, h, w, c)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        _, h_s, w_s, c_s = vit_embeds.shape
        vit_embeds = vit_embeds.reshape(-1, h_s * w_s, c_s)
        vit_embeds = self.mlp1(vit_embeds)

        return vit_embeds
495
+
496
+ def build_collator(self, tokenizer, **kwargs):
497
+ return self.processor
498
+
499
+ def post_loss(self, loss, inputs):
500
+ """
501
+ Add dummy gradients for vision encoder to ensure multi-GPU synchronization.
502
+
503
+ Args:
504
+ loss: Computed loss
505
+ inputs: Input dictionary
506
+
507
+ Returns:
508
+ Modified loss with dummy gradients
509
+ """
510
+ if "pixel_values" in inputs and inputs["pixel_values"] is None:
511
+ dummy_pixels = torch.zeros(
512
+ 1, 3, 512, 512, device=loss.device, dtype=self.vision_model.dtype
513
+ )
514
+ dummy_output = self.extract_feature(dummy_pixels)
515
+ loss = loss + dummy_output.sum() * 0.0
516
+ return loss
517
+
518
+
519
class CrossEncoderHead(nn.Linear):
    """Linear scoring head for cross-encoder reranking.

    Behaves exactly like ``nn.Linear``; the subclass exists only so this
    module can be targeted by ``isinstance`` checks (e.g. in weight
    initialization hooks) separately from other linear layers.
    """
523
+
524
+
525
class LlamaNemotronVLForSequenceClassification(PreTrainedModel):
    """
    LlamaNemotron VL model for sequence classification (reranking).

    Wraps a ``LlamaNemotronVLModel`` backbone and scores the pooled final
    hidden state with a linear ``CrossEncoderHead``.

    Supports flash_attention_2, sdpa, and eager attention implementations.
    """

    config_class = LlamaNemotronVLForSequenceClassificationConfig
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    # Keep the whole VL backbone together when sharding across devices.
    _no_split_modules = ["LlamaNemotronVLModel"]

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.num_labels = config.num_labels

        self.add_module("model", LlamaNemotronVLModel(config))

        # Head parameters are created in float32; ``forward`` casts the
        # pooled hidden states to match before scoring.
        score = CrossEncoderHead(
            config.llm_config.hidden_size,
            self.num_labels,
            bias=False,
            dtype=torch.float32,
        )
        self.add_module("score", score)

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        """Initialize weights for the model."""
        super()._init_weights(module)
        if isinstance(module, CrossEncoderHead):
            # Initialize cross-encoder head to avoid NaN/Inf loss
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))

    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_patches_list: Optional[List[torch.Tensor]] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        Forward pass for sequence classification (reranking).

        Runs the VL backbone, pools its last hidden states according to
        ``self.config.pooling``, and scores the pooled vector with the
        cross-encoder head, scaled by ``1 / self.config.temperature``.

        Args:
            pixel_values: Image pixel values
            input_ids: Input token IDs
            attention_mask: Attention mask
            position_ids: Position IDs
            image_flags: Flags indicating image presence
            past_key_values: Cached key-value pairs
            inputs_embeds: Input embeddings (alternative to input_ids).
                NOTE(review): accepted but not forwarded to the backbone —
                confirm whether this parameter should be plumbed through.
            labels: Labels for classification.
                NOTE(review): accepted but ``loss`` is never computed here
                (stays ``None``); presumably the training loss is applied by
                an external wrapper — confirm.
            use_cache: Whether to use KV cache
            output_attentions: Whether to output attention weights
            output_hidden_states: Whether to output hidden states
            return_dict: Whether to return ModelOutput
            num_patches_list: List of number of patches per image

        Returns:
            SequenceClassifierOutputWithPast or tuple
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # output_hidden_states is forced on: pooling below needs the final
        # hidden states regardless of the caller's preference.
        transformer_outputs = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            image_flags=image_flags,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
            num_patches_list=num_patches_list,
        )

        hidden_states = transformer_outputs.hidden_states[-1]

        pooled_hidden_states = pool(
            last_hidden_states=hidden_states,
            attention_mask=attention_mask,
            pool_type=self.config.pooling,
        )

        # Cast to the head's (float32) dtype before scoring.
        pooled_logits = self.score(pooled_hidden_states.to(self.score.weight.dtype))
        pooled_logits = pooled_logits / self.config.temperature

        # Fail fast instead of silently propagating NaNs into ranking scores.
        if torch.isnan(pooled_logits).any():
            raise ValueError("NaN detected in pooled_logits!")

        loss = None

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def build_collator(self, tokenizer, **kwargs):
        """Build data collator for reranking.

        Collator settings default to the model config but can be overridden
        via keyword arguments.
        """
        rerank_max_length = kwargs.pop(
            "rerank_max_length", self.config.rerank_max_length
        )
        max_input_tiles = kwargs.pop("max_input_tiles", self.config.max_input_tiles)
        prompt_template = kwargs.pop("prompt_template", self.config.prompt_template)
        return LlamaNemotronVLRerankProcessor(
            tokenizer=tokenizer,
            rerank_max_length=rerank_max_length,
            max_input_tiles=max_input_tiles,
            num_image_token=self.model.num_image_token,
            prompt_template=prompt_template,
            **kwargs,
        )

    def post_loss(self, loss, inputs):
        """
        Add dummy gradients for vision encoder to ensure multi-GPU synchronization.

        On text-only batches (``pixel_values`` explicitly None) the vision
        tower would otherwise receive no gradient; a zero-weighted dummy
        forward keeps its parameters in the autograd graph.

        Args:
            loss: Computed loss
            inputs: Input dictionary

        Returns:
            Modified loss with dummy gradients
        """
        if "pixel_values" in inputs and inputs["pixel_values"] is None:
            dummy_pixels = torch.zeros(
                1, 3, 512, 512, device=loss.device, dtype=self.model.vision_model.dtype
            )
            dummy_output = self.model.extract_feature(dummy_pixels)
            loss = loss + dummy_output.sum() * 0.0
        return loss
678
+ return loss
models/local_nemotron_rerank/processing_llama_nemotron_vl.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0.
3
+
4
+ import base64
5
+ import os
6
+ from io import BytesIO
7
+ from typing import Any, Dict, List, Optional, Union, Tuple
8
+ import dataclasses
9
+ from dataclasses import field
10
+
11
+ import requests
12
+ import torch
13
+ import torchvision.transforms as T
14
+ from PIL import Image
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from transformers import ProcessorMixin
17
+
18
# Per-channel normalization statistics for ImageNet-trained vision backbones.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# SigLIP normalization: mean/std of 0.5 maps pixel values into [-1, 1].
SIGLIP_MEAN = (0.5, 0.5, 0.5)
SIGLIP_STD = (0.5, 0.5, 0.5)
23
+
24
+
25
@dataclasses.dataclass
class Conversation:
    """Builds chat-style prompt strings from a system message plus dialogue turns."""

    # System instruction prepended to every prompt
    system_message: str = ""
    # Role identifiers for dialogue turns (e.g. user / assistant markers)
    roles: Tuple[str, str] = ("", "")
    # Message history as (role, content) pairs
    messages: List[List[str]] = field(default_factory=list)
    # Separator token appended after the system message and each filled turn
    sep: str = ""
    # Token IDs that trigger generation stopping
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Render the system message and dialogue history into one prompt string.

        An empty/None message emits only the role marker (an open turn for
        the model to complete), with no separator.
        """
        pieces = [self.system_message, self.sep]
        for role, message in self.messages:
            if message:
                pieces.extend((role, message, self.sep))
            else:
                pieces.append(role)
        return "".join(pieces)

    def append_message(self, role: str, message: str):
        """Record one (role, message) turn in the dialogue history."""
        self.messages.append([role, message])
53
+
54
+
55
def get_conv_template(name: str) -> Conversation:
    """Create a fresh conversation with the default stop tokens.

    ``name`` is accepted for API compatibility with template registries, but
    every template currently shares the same configuration.
    """
    conv = Conversation(stop_token_ids=[128259, 128001])
    return conv
60
+
61
+
62
def load_image(image):
    """Coerce *image* into a PIL image.

    Accepted forms: an already-open PIL image, a path string to an existing
    file, or a dict carrying the picture under one of the keys
    ``disk_path``, ``base64``, ``url`` or ``bytes`` (checked in that order).

    Raises:
        ValueError: If the input matches none of the supported forms.
    """
    if isinstance(image, Image.Image):
        return image

    if isinstance(image, str) and os.path.exists(image):
        return Image.open(image)

    if isinstance(image, dict):
        if "disk_path" in image:
            return Image.open(image["disk_path"])
        if "base64" in image:
            decoded = base64.b64decode(image["base64"])
            return Image.open(BytesIO(decoded))
        if "url" in image:
            payload = requests.get(image["url"]).content
            return Image.open(BytesIO(payload))
        if "bytes" in image:
            return Image.open(BytesIO(image["bytes"]))

    raise ValueError(f"Invalid image: {image}")
81
+
82
+
83
def build_transform(input_size, norm_type="imagenet"):
    """Build the preprocessing pipeline for one image tile.

    The pipeline converts to RGB, resizes to a square of ``input_size``
    pixels (bicubic), converts to a tensor, and normalizes with the chosen
    statistics.

    Args:
        input_size: Side length of the square output, in pixels.
        norm_type: Normalization scheme, either "imagenet" or "siglip".

    Returns:
        A ``torchvision.transforms.Compose`` pipeline.

    Raises:
        ValueError: If ``norm_type`` is not a recognized scheme.
    """
    if norm_type == "imagenet":
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif norm_type == "siglip":
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        # Previously an unknown norm_type fell through and raised a confusing
        # NameError on MEAN below; fail explicitly instead.
        raise ValueError(
            f"Unknown norm_type {norm_type!r}; expected 'imagenet' or 'siglip'"
        )

    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD),
        ]
    )
    return transform
98
+
99
+
100
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the tile grid that best matches the image's aspect ratio.

    Unlike a pure ratio match, the score also rewards grids whose total tile
    area covers a larger fraction of the original image, with that area
    contribution capped at 60%.

    Args:
        aspect_ratio: Width/height ratio of the source image.
        target_ratios: Candidate (cols, rows) grid shapes.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile.

    Returns:
        The best (cols, rows) pair, or (1, 1) when no candidates are given.
    """
    source_area = width * height

    def score(ratio):
        cols, rows = ratio
        grid_aspect = cols / rows
        # Fraction of the original image the tiled canvas would cover;
        # anything above 60% counts the same.
        area_fraction = (cols * rows * image_size * image_size) / source_area
        closeness = min(grid_aspect / aspect_ratio, aspect_ratio / grid_aspect)
        return min(area_fraction, 0.6) * closeness

    # max() keeps the first candidate on ties, matching a strict ">" scan.
    return max(target_ratios, key=score, default=(1, 1))
121
+
122
+
123
def dynamic_preprocess(
    image, min_num=1, max_num=6, image_size=448, use_thumbnail=False
):
    """Split an image into a grid of square ``image_size`` tiles.

    The grid shape (cols, rows) is chosen by ``find_closest_aspect_ratio``
    so the tiling best matches the image's aspect ratio while keeping the
    tile count within ``[min_num, max_num]``.

    Args:
        image: Source PIL image.
        min_num: Minimum total number of tiles.
        max_num: Maximum total number of tiles.
        image_size: Side length of each square tile in pixels.
        use_thumbnail: If True and more than one tile was produced, append a
            full-image thumbnail as an extra tile.

    Returns:
        List of PIL image tiles (plus the optional thumbnail).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every (cols, rows) grid whose tile count lies in range.
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize once to the full tiled canvas, then crop out each grid cell.
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        # Tile i is addressed row-major over a (cols x rows) grid:
        # column = i % cols, row = i // cols.
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
167
+
168
+
169
class LlamaNemotronVLRerankProcessor(ProcessorMixin):
    """Collator/processor that turns (query, document) pairs — with optional
    images — into tokenized, image-token-expanded model inputs for reranking."""

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer: Any,
        padding: Union[bool, str] = True,
        rerank_max_length: Optional[int] = 512,
        pad_to_multiple_of: Optional[int] = None,
        max_input_tiles: int = 2,
        num_image_token: int = None,
        prompt_template: str = None,
        force_image_size: int = 512,
        template: str = "bidirectional-llama-retriever",
        dynamic_image_size: bool = True,
        use_thumbnail: bool = True,
        **kwargs,
    ):
        """Configure tokenization, tiling, and prompt-template behavior.

        Args:
            tokenizer: HF tokenizer; mutated in place (special tokens filtered,
                padding side forced to left).
            padding / rerank_max_length / pad_to_multiple_of: Passed straight
                to the tokenizer call.
            max_input_tiles: Max image tiles per document.
            num_image_token: Number of <IMG_CONTEXT> tokens emitted per tile.
            prompt_template: Selects the query/passage text template ("v1"
                or default).
            force_image_size: Square tile side length in pixels.
            template: Conversation-template name passed to get_conv_template.
            dynamic_image_size: If True, tile images adaptively; otherwise one tile.
            use_thumbnail: Append a whole-image thumbnail tile when tiling.
        """
        self.padding = padding
        self.rerank_max_length = rerank_max_length
        self.pad_to_multiple_of = pad_to_multiple_of

        # NOTE(review): despite the name, this comprehension *removes* the
        # tokens listed in tokens_to_keep from additional_special_tokens —
        # confirm whether keeping or dropping them was intended.
        tokens_to_keep = ["<box>", "</box>", "<ref>", "</ref>"]
        tokenizer.additional_special_tokens = [
            item
            for item in tokenizer.additional_special_tokens
            if item not in tokens_to_keep
        ]
        # Left padding so the end of each sequence (used for pooling/scoring)
        # is aligned across the batch.
        tokenizer.padding_side = "left"
        self.tokenizer = tokenizer

        self.norm_type = "siglip"
        self.image_size = force_image_size
        self.max_input_tiles = max_input_tiles
        self.num_image_token = num_image_token
        self.system_message = ""
        self.prompt_template = prompt_template
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail

        super().__init__(self.tokenizer)

    def process_query_documents(self, documents: Union[Dict, List[Dict]], **kwargs):
        """Tokenize documents (text plus optional image) into model inputs.

        Args:
            documents: Either a dict with parallel "images"/"texts" lists, or
                a list of {"image": ..., "text": ...} dicts.

        Returns:
            Dict with "input_ids", "attention_mask" and "pixel_values"
            (None when no document carries an image).
        """
        if isinstance(documents, dict):
            images = documents["images"]
            texts = documents["texts"]
            assert len(texts) == len(images)
        elif isinstance(documents, list):
            images = [pair["image"] for pair in documents]
            texts = [pair["text"] for pair in documents]
        else:
            raise ValueError("The documents need to be a dict or list of dicts")

        # First pass: load images and prefix "<image>" onto texts that have one.
        contents, pil_images, max_input_tile_list, llm_onlys = [], [], [], []
        for image, text in zip(images, texts):
            prefix = ""
            llm_only = True
            if image is not None and image != "":
                pil_images.append(load_image(image))
                prefix = "<image>"
                max_input_tile_list.append(self.max_input_tiles)
                llm_only = False
            else:
                pil_images.append(None)
                max_input_tile_list.append(self.max_input_tiles)
            llm_onlys.append(llm_only)

            # ToDo: Order is hardcoded and different than before. No \n after <image>
            content = text
            if prefix != "":
                content = prefix + " " + content
            contents.append(content)

        assert len(max_input_tile_list) == len(pil_images), (
            "The number of max_input_tile_list and pil_images should be the same."
        )
        assert len(max_input_tile_list) == len(contents), (
            "The number of max_input_tile_list and contents should be the same."
        )

        transform = build_transform(
            input_size=self.image_size, norm_type=self.norm_type
        )

        template = get_conv_template(self.template)
        template.system_message = self.system_message

        # Second pass: tile/transform each image and expand the "<image>"
        # placeholder into the actual image-token run.
        content_prompts = []
        pixel_values_list = []
        for content, pil_image, max_input_tiles, llm_only in zip(
            contents, pil_images, max_input_tile_list, llm_onlys
        ):
            if pil_image is not None:
                if self.dynamic_image_size:
                    image_tiles = dynamic_preprocess(
                        pil_image,
                        image_size=self.image_size,
                        max_num=max_input_tiles,
                        use_thumbnail=self.use_thumbnail,
                    )
                else:
                    image_tiles = [pil_image]

                pixel_values = [transform(item) for item in image_tiles]
                pixel_values = torch.stack(pixel_values).to(dtype=torch.bfloat16)
                pixel_values_list.append(pixel_values)
            else:
                pixel_values = None

            IMG_START_TOKEN = "<img>"
            IMG_END_TOKEN = "</img>"
            IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"

            # Safety net: ensure the placeholder is present for image docs
            # even if the first pass did not add it.
            if pixel_values is not None and "<image>" not in content and not llm_only:
                content = "<image> " + content

            # Reset the conversation history before building this prompt.
            template.messages.clear()

            # TODO: do we need this template?
            template.append_message(template.roles[0], content)  # user
            template.append_message(template.roles[1], None)  # assistant
            content_prompt = template.get_prompt()

            if "<image>" in content:
                # One <IMG_CONTEXT> token per vision token per tile, wrapped
                # in <img>...</img> markers.
                num_patches = pixel_values.shape[0]
                image_tokens = (
                    IMG_START_TOKEN
                    + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                    + IMG_END_TOKEN
                )
                content_prompt = content_prompt.replace("<image>", image_tokens, 1)

            content_prompts.append(content_prompt)

        model_inputs = self.tokenizer(
            content_prompts,
            truncation=True,
            max_length=self.rerank_max_length,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Concatenate all documents' tiles along the batch axis.
        if len(pixel_values_list) > 1:
            pixel_values_squeezed = torch.concat(pixel_values_list, axis=0)
        elif len(pixel_values_list) == 1:
            pixel_values_squeezed = pixel_values_list[0]
        else:
            pixel_values_squeezed = None

        batch_docs = {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "pixel_values": None,
        }

        if pixel_values_squeezed is not None:
            batch_docs["pixel_values"] = pixel_values_squeezed

        return batch_docs

    def prompt_template_question_passage(self, question, text):
        """Format a (question, passage) pair using the "v1" prompt layout."""
        return f"question:{question} \n \n passage:{text}"

    def process_queries_documents_crossencoder(self, features: List[Dict], **kwargs):
        """Collate cross-encoder features into a tokenized batch.

        Each feature dict provides "question", "doc_text" and "doc_image";
        question and document text are joined per ``self.prompt_template``
        before tokenization.
        """
        images = [feature["doc_image"] for feature in features]
        if self.prompt_template == "v1":
            questions_texts = [
                self.prompt_template_question_passage(
                    feature["question"], feature["doc_text"]
                )
                for feature in features
            ]
        else:
            questions_texts = [
                f"{feature['question']} \n {feature['doc_text']}"
                for feature in features
            ]
        batch_dict = self.process_query_documents(
            {"images": images, "texts": questions_texts}, **kwargs
        )

        # NOTE(review): labels are all zeros sized by num_labels — presumably
        # the "correct document first" convention for listwise ranking loss;
        # confirm against the training loop.
        if "num_labels" in features[0]:
            batch_dict["labels"] = torch.zeros(
                features[0]["num_labels"], dtype=torch.long
            )

        return batch_dict
+ return batch_dict
models/model_loader.py CHANGED
@@ -9,17 +9,23 @@ def load_embed_model(model_path: str = "nvidia/llama-nemotron-embed-vl-1b-v2"):
9
 
10
  print(f"🔄 Loading embedding model on {device}...")
11
 
 
12
  config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
13
- # ✅ FIX: Removed SDPA config override which causes issues in HF Spaces
14
 
15
- # FIX: Use manual device instead of device_map="auto"
16
- model = AutoModel.from_pretrained(
 
 
 
 
 
 
17
  model_path,
18
  config=config,
19
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
20
- trust_remote_code=True,
21
- low_cpu_mem_usage=True, # ✅ CPU optimization
22
- attn_implementation="eager", # FIX: Force eager execution
23
  ).to(device).eval()
24
 
25
  print(f"✅ Embedding model loaded on {device}")
@@ -34,10 +40,22 @@ def load_rerank_model(model_path: str = "nvidia/llama-nemotron-rerank-vl-1b-v2")
34
  print(f"🔄 Loading reranking model on {device}...")
35
 
36
  # ✅ FIX: Use manual device instead of device_map="auto"
37
- model = AutoModelForSequenceClassification.from_pretrained(
 
 
 
 
 
 
 
 
 
 
 
38
  model_path,
 
39
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
40
- trust_remote_code=True,
41
  attn_implementation="eager",
42
  ).to(device).eval()
43
 
 
9
 
10
  print(f"🔄 Loading embedding model on {device}...")
11
 
12
+ # ✅ FIX: Load CONFIG from hub but CODE from local patched file
13
  config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
14
 
15
+ # Import local patched model class
16
+ import sys
17
+ import os
18
+ sys.path.append(os.path.join(os.path.dirname(__file__), "local_nemotron"))
19
+ from local_nemotron.modeling_llama_nemotron_vl import LlamaNemotronVLModel
20
+
21
+ # Initialize model using local class
22
+ model = LlamaNemotronVLModel.from_pretrained(
23
  model_path,
24
  config=config,
25
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
26
+ trust_remote_code=False, # We are using local code now
27
+ low_cpu_mem_usage=True,
28
+ # attn_implementation="eager", # Explicitly set in __init__ patch now
29
  ).to(device).eval()
30
 
31
  print(f"✅ Embedding model loaded on {device}")
 
40
  print(f"🔄 Loading reranking model on {device}...")
41
 
42
  # ✅ FIX: Use manual device instead of device_map="auto"
43
+ # FIX: Load CONFIG from hub but CODE from local patched file
44
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
45
+
46
+ # Import local patched model class
47
+ import sys
48
+ import os
49
+ sys.path.append(os.path.join(os.path.dirname(__file__), "local_nemotron_rerank"))
50
+ # Rerank model usually uses ForSequenceClassification variant, checking imports
51
+ from local_nemotron_rerank.modeling_llama_nemotron_vl import LlamaNemotronVLForSequenceClassification
52
+
53
+ # Initialize model using local class
54
+ model = LlamaNemotronVLForSequenceClassification.from_pretrained(
55
  model_path,
56
+ config=config,
57
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
58
+ trust_remote_code=False,
59
  attn_implementation="eager",
60
  ).to(device).eval()
61