nalazhar committed on
Commit 405458c · verified · 1 Parent(s): 3b31ddf

Upload whaleye.patch with huggingface_hub

Files changed (1)
  1. whaleye.patch +631 -0
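The patch below registers a new "WhaleyeForConditionalGeneration" architecture in vLLM's model registries and adds vllm/model_executor/models/whaleye.py, which pairs the Pixtral vision encoder with a DeepSeek V3.2 language model. A minimal sketch of applying it to a local vLLM source checkout (the file name comes from this commit; the checkout path and the use of git apply are assumptions, not part of the patch):

# Minimal sketch: apply whaleye.patch to a local vLLM git checkout.
# Assumes the patch file has already been downloaded from this repo and
# sits one directory above the checkout; "--check" does a dry run first.
import subprocess

VLLM_CHECKOUT = "./vllm"  # hypothetical path to a vLLM clone

subprocess.run(["git", "apply", "--check", "../whaleye.patch"], cwd=VLLM_CHECKOUT, check=True)
subprocess.run(["git", "apply", "../whaleye.patch"], cwd=VLLM_CHECKOUT, check=True)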
whaleye.patch ADDED
@@ -0,0 +1,631 @@
+ diff --git a/tests/models/registry.py b/tests/models/registry.py
+ index 020cb7493..7a9e16c00 100644
+ --- a/tests/models/registry.py
+ +++ b/tests/models/registry.py
+ @@ -845,6 +845,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
+         # disable this temporarily until we support HF format
+         is_available_online=False,
+     ),
+ +    "WhaleyeForConditionalGeneration": _HfExamplesInfo(
+ +        "umans-ai/Whaleye-V0",
+ +        is_available_online=False,
+ +    ),
+     # [Encoder-decoder]
+     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),
+     # [Cross-encoder]
+ diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
+ index a4a964bc7..fd40ff25c 100644
+ --- a/vllm/model_executor/models/registry.py
+ +++ b/vllm/model_executor/models/registry.py
+ @@ -411,6 +411,7 @@ _MULTIMODAL_MODELS = {
+     ),
+     "UltravoxModel": ("ultravox", "UltravoxModel"),
+     "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"),  # noqa: E501
+ +    "WhaleyeForConditionalGeneration": ("whaleye", "WhaleyeForConditionalGeneration"),  # noqa: E501
+     # [Encoder-decoder]
+     "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),  # noqa: E501
+ }
+ diff --git a/vllm/model_executor/models/whaleye.py b/vllm/model_executor/models/whaleye.py
+ new file mode 100644
+ index 000000000..60d8f8b22
+ --- /dev/null
+ +++ b/vllm/model_executor/models/whaleye.py
+ @@ -0,0 +1,598 @@
+ +# SPDX-License-Identifier: Apache-2.0
+ +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ +"""
+ +Whaleye: Pixtral Vision Encoder + DeepSeek V3.2 Language Model
+ +"""
+ +
+ +from collections.abc import Iterable, Mapping, Sequence
+ +from dataclasses import fields
+ +from functools import cached_property
+ +
+ +import torch
+ +from torch import nn
+ +from mistral_common.protocol.instruct.chunk import ImageChunk
+ +from mistral_common.tokens.tokenizers.image import (
+ +    ImageConfig,
+ +    ImageEncoder,
+ +    SpecialImageIDs,
+ +)
+ +from PIL import Image
+ +from transformers import TensorType
+ +from transformers.feature_extraction_utils import BatchFeature
+ +from transformers.image_utils import ImageInput
+ +from transformers.tokenization_utils_base import TextInput
+ +
+ +from vllm.config import VllmConfig
+ +from vllm.config.multimodal import BaseDummyOptions
+ +from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
+ +from vllm.multimodal.inputs import (
+ +    MultiModalDataDict,
+ +    MultiModalFieldConfig,
+ +    MultiModalUUIDDict,
+ +    NestedTensors,
+ +)
+ +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+ +from vllm.multimodal.processing import (
+ +    BaseMultiModalProcessor,
+ +    BaseProcessingInfo,
+ +    MultiModalProcessingInfo,
+ +    PromptReplacement,
+ +    PromptUpdate,
+ +    PromptUpdateDetails,
+ +)
+ +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+ +from vllm.sequence import IntermediateTensors
+ +from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
+ +
+ +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+ +from .pixtral import (
+ +    PATCH_MERGE,
+ +    PatchMerger,
+ +    PixtralImagePixelInputs,
+ +    VisionEncoderArgs,
+ +    VisionLanguageAdapter,
+ +    VisionTransformer,
+ +)
+ +from .utils import init_vllm_registered_model, maybe_prefix
+ +
+ +# Re-use RMSNorm from layernorm module
+ +from vllm.model_executor.layers.layernorm import RMSNorm
+ +
+ +
+ +class WhaleyeProcessorAdapter:
+ +
+ +    def __init__(self, tokenizer: TokenizerLike, image_encoder: ImageEncoder) -> None:
+ +        super().__init__()
+ +        self._tokenizer = tokenizer
+ +        self._image_encoder = image_encoder
+ +
+ +    @property
+ +    def tokenizer(self) -> TokenizerLike:
+ +        return self._tokenizer
+ +
+ +    @property
+ +    def image_processor(self) -> ImageEncoder:
+ +        return self._image_encoder
+ +
+ +    @cached_property
+ +    def image_token_id(self) -> int:
+ +        return self.image_processor.special_ids.img
+ +
+ +    @cached_property
+ +    def image_break_id(self) -> int:
+ +        return self.image_processor.special_ids.img_break
+ +
+ +    @cached_property
+ +    def image_end_id(self) -> int:
+ +        return self.image_processor.special_ids.img_end
+ +
+ +    @cached_property
+ +    def image_size(self) -> int:
+ +        return self.image_processor.mm_config.max_image_size
+ +
+ +    @cached_property
+ +    def patch_size(self) -> int:
+ +        return self.image_processor.mm_config.image_patch_size
+ +
+ +    def __call__(
+ +        self,
+ +        text: TextInput | list[TextInput] | None = None,
+ +        images: ImageInput | list[ImageInput] | None = None,
+ +        return_tensors: str | TensorType | None = None,
+ +        **kwargs,
+ +    ) -> Mapping[str, NestedTensors]:
+ +        if text is None:
+ +            text_list: list[str] = []
+ +        elif isinstance(text, list):
+ +            text_list = list(text)
+ +        else:
+ +            text_list = [text]
+ +
+ +        if images is None:
+ +            images = []
+ +        if not isinstance(images, list):
+ +            images = [images]
+ +
+ +        if not images:
+ +            if not text_list:
+ +                return BatchFeature(dict(input_ids=torch.empty((0, 0), dtype=torch.long)))
+ +
+ +            encoded = [
+ +                self.tokenizer.encode(t, add_special_tokens=False)
+ +                for t in text_list
+ +            ]
+ +            max_len = max(len(ids) for ids in encoded) if encoded else 0
+ +            pad_id = getattr(self.tokenizer, "pad_token_id", 0) or 0
+ +            input_ids = torch.full((len(encoded), max_len), pad_id, dtype=torch.long)
+ +            for i, ids in enumerate(encoded):
+ +                if ids:
+ +                    input_ids[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
+ +
+ +            return BatchFeature(dict(input_ids=input_ids))
+ +
+ +        pixel_values: list[torch.Tensor] = []
+ +        image_sizes: list[tuple[int, int]] = []
+ +
+ +        for image in images:
+ +            if hasattr(image, "media"):
+ +                image = image.media
+ +
+ +            image_inputs = self.image_processor(ImageChunk(image=image))
+ +            processed_image = torch.tensor(image_inputs.image)
+ +            pixel_values.append(processed_image)
+ +            image_sizes.append((processed_image.shape[1], processed_image.shape[2]))
+ +
+ +        input_ids = torch.empty((len(text_list) or 1, 0), dtype=torch.long)
+ +
+ +        return BatchFeature(
+ +            dict(
+ +                input_ids=input_ids,
+ +                pixel_values=pixel_values,
+ +                image_sizes=image_sizes,
+ +            )
+ +        )
+ +
+ +
+ +class WhaleyeProcessingInfo(BaseProcessingInfo):
+ +
+ +    def get_tokenizer(self) -> TokenizerLike:
+ +        return cached_tokenizer_from_config(self.ctx.model_config)
+ +
+ +    @cached_property
+ +    def _vision_config(self):
+ +        vision_cfg = self.ctx.model_config.hf_config.vision_config
+ +        # vision_config may be a dict or a config object depending on how it was loaded
+ +        if isinstance(vision_cfg, dict):
+ +            return vision_cfg
+ +        return vision_cfg.to_dict() if hasattr(vision_cfg, "to_dict") else vision_cfg
+ +
+ +    def _get_vision_value(self, key: str, default=None):
+ +        """Get a value from vision_config, handling both dict and object."""
+ +        vision_cfg = self._vision_config
+ +        if isinstance(vision_cfg, dict):
+ +            return vision_cfg.get(key, default)
+ +        return getattr(vision_cfg, key, default)
+ +
+ +    @cached_property
+ +    def _image_encoder(self) -> ImageEncoder:
+ +        hf_config = self.ctx.model_config.hf_config
+ +
+ +        # Get image_size from vision_config, with fallback to max_image_size
+ +        image_size = self._get_vision_value("max_image_size")
+ +        if image_size is None:
+ +            image_size = getattr(hf_config, "max_image_size", None)
+ +        if image_size is None:
+ +            image_size = self._get_vision_value("image_size")
+ +        image_size = int(image_size)
+ +
+ +        patch_size = int(self._get_vision_value("patch_size"))
+ +
+ +        spatial_merge_size = getattr(hf_config, "spatial_merge_size", None)
+ +        if spatial_merge_size is None:
+ +            spatial_merge_size = self._get_vision_value("spatial_merge_size", 1)
+ +        spatial_merge_size = int(spatial_merge_size)
+ +
+ +        image_config = ImageConfig(
+ +            image_patch_size=patch_size,
+ +            max_image_size=image_size,
+ +            spatial_merge_size=spatial_merge_size,
+ +        )
+ +
+ +        special_ids = SpecialImageIDs(
+ +            img=int(self._get_vision_value("image_token_id")),
+ +            img_break=int(self._get_vision_value("image_break_token_id")),
+ +            img_end=int(self._get_vision_value("image_end_token_id")),
+ +        )
+ +
+ +        return ImageEncoder(image_config=image_config, special_ids=special_ids)
+ +
+ +    def get_hf_processor(self, **kwargs: object) -> WhaleyeProcessorAdapter:
+ +        return WhaleyeProcessorAdapter(self.get_tokenizer(), self._image_encoder)
+ +
+ +    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ +        return {"image": None}
+ +
+ +    def get_num_image_tokens(
+ +        self,
+ +        *,
+ +        image_width: int,
+ +        image_height: int,
+ +        processor: WhaleyeProcessorAdapter | None = None,
+ +    ) -> int:
+ +        if processor is None:
+ +            processor = self.get_hf_processor()
+ +
+ +        ncols, nrows = processor.image_processor._image_to_num_tokens(
+ +            Image.new("RGB", (image_width, image_height))
+ +        )
+ +        return ncols * nrows
+ +
+ +    def get_image_size_with_most_features(self) -> ImageSize:
+ +        cfg = self._image_encoder.image_config
+ +        return ImageSize(width=cfg.max_image_size, height=cfg.max_image_size)
+ +
+ +
+ +class WhaleyeDummyInputsBuilder(BaseDummyInputsBuilder[WhaleyeProcessingInfo]):
+ +
+ +    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ +        return ""
+ +
+ +    def get_dummy_mm_data(
+ +        self,
+ +        seq_len: int,
+ +        mm_counts: Mapping[str, int],
+ +        mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ +    ) -> MultiModalDataDict:
+ +        num_images = mm_counts.get("image", 0)
+ +        target_width, target_height = self.info.get_image_size_with_most_features()
+ +        image_overrides = mm_options.get("image") if mm_options else None
+ +        return {
+ +            "image": self._get_dummy_images(
+ +                width=target_width,
+ +                height=target_height,
+ +                num_images=num_images,
+ +                overrides=image_overrides,
+ +            )
+ +        }
+ +
+ +    def get_dummy_processor_inputs(
+ +        self,
+ +        seq_len: int,
+ +        mm_counts: Mapping[str, int],
+ +        mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ +    ) -> ProcessorInputs:
+ +        num_images = mm_counts.get("image", 0)
+ +        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
+ +
+ +        processor = self.info.get_hf_processor()
+ +        image_token_id = processor.image_token_id
+ +
+ +        dummy_tokens = [image_token_id] * num_images
+ +
+ +        return ProcessorInputs(
+ +            prompt=dummy_tokens,
+ +            mm_data=dummy_mm_data,
+ +            tokenization_kwargs={"truncation": False},
+ +        )
+ +
+ +
+ +class WhaleyeMultiModalProcessor(BaseMultiModalProcessor[WhaleyeProcessingInfo]):
+ +
+ +    def _call_hf_processor(
+ +        self,
+ +        prompt: str,
+ +        mm_data: Mapping[str, object],
+ +        mm_kwargs: Mapping[str, object],
+ +        tok_kwargs: Mapping[str, object],
+ +    ) -> BatchFeature:
+ +        processed_outputs = super()._call_hf_processor(
+ +            prompt=prompt,
+ +            mm_data=mm_data,
+ +            mm_kwargs=mm_kwargs,
+ +            tok_kwargs=tok_kwargs,
+ +        )
+ +
+ +        pixel_values = processed_outputs.get("pixel_values")
+ +        if pixel_values is not None:
+ +            image_sizes = processed_outputs.get("image_sizes")
+ +            if isinstance(pixel_values, list) and image_sizes is not None:
+ +                assert len(pixel_values) == len(image_sizes)
+ +                processed_outputs["images"] = [
+ +                    p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
+ +                ]
+ +            else:
+ +                processed_outputs["images"] = pixel_values
+ +            processed_outputs.pop("pixel_values", None)
+ +
+ +        return processed_outputs
+ +
+ +    def _get_mm_fields_config(
+ +        self,
+ +        hf_inputs: Mapping[str, NestedTensors],
+ +        hf_processor_mm_kwargs: Mapping[str, object],
+ +    ) -> Mapping[str, MultiModalFieldConfig]:
+ +        return dict(images=MultiModalFieldConfig.batched("image"))
+ +
+ +    def _get_prompt_updates(
+ +        self,
+ +        mm_items: MultiModalDataItems,
+ +        hf_processor_mm_kwargs: Mapping[str, object],
+ +        out_mm_kwargs: MultiModalKwargsItems,
+ +    ) -> Sequence[PromptUpdate]:
+ +        processor = self.info.get_hf_processor()
+ +        image_token_id = processor.image_token_id
+ +        image_break_id = processor.image_break_id
+ +        image_end_id = processor.image_end_id
+ +
+ +        def get_replacement(item_idx: int):
+ +            images = mm_items.get_items("image", ImageProcessorItems)
+ +            image_size = images.get_image_size(item_idx)
+ +
+ +            ncols, nrows = processor.image_processor._image_to_num_tokens(
+ +                Image.new("RGB", (image_size.width, image_size.height))
+ +            )
+ +
+ +            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
+ +            tokens[-1] = image_end_id
+ +
+ +            return PromptUpdateDetails.select_token_id(tokens, image_token_id)
+ +
+ +        return [
+ +            PromptReplacement(
+ +                modality="image",
+ +                target=[image_token_id],
+ +                replacement=get_replacement,
+ +            ),
+ +        ]
+ +
+ +    def _cached_apply_hf_processor(
+ +        self,
+ +        prompt: str | list[int],
+ +        mm_data_items: MultiModalDataItems,
+ +        hf_processor_mm_kwargs: Mapping[str, object],
+ +        tokenization_kwargs: Mapping[str, object],
+ +        mm_uuids: MultiModalUUIDDict | None = None,
+ +    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
+ +        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
+ +            prompt=prompt,
+ +            mm_data_items=mm_data_items,
+ +            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+ +            tokenization_kwargs=tokenization_kwargs,
+ +            mm_uuids=mm_uuids,
+ +        )
+ +        return prompt_ids, mm_info, False
+ +
+ +
+ +@MULTIMODAL_REGISTRY.register_processor(
+ +    WhaleyeMultiModalProcessor,
+ +    info=WhaleyeProcessingInfo,
+ +    dummy_inputs=WhaleyeDummyInputsBuilder,
+ +)
+ +class WhaleyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+ +
+ +    @classmethod
+ +    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ +        if modality.startswith("image"):
+ +            return "<|img|>"
+ +        raise ValueError("Only image modality is supported")
+ +
+ +    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ +        super().__init__()
+ +        config = vllm_config.model_config.hf_config
+ +        multimodal_config = vllm_config.model_config.multimodal_config
+ +        self.config = config
+ +        self.multimodal_config = multimodal_config
+ +
+ +        # Build vision encoder args from vision_config
+ +        vision_config = config.vision_config
+ +        # vision_config may be a dict or a config object
+ +        if isinstance(vision_config, dict):
+ +            vision_config_dict = vision_config
+ +        else:
+ +            vision_config_dict = vision_config.to_dict()
+ +        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
+ +        vision_args_dict = {
+ +            key: value
+ +            for key, value in vision_config_dict.items()
+ +            if key in dataclass_fields
+ +        }
+ +        self.vision_args = VisionEncoderArgs(**vision_args_dict)
+ +
+ +        # Initialize DeepSeek V3.2 language model
+ +        # Uses flat config (hf_config itself has all DeepSeek fields at top level)
+ +        self.language_model = init_vllm_registered_model(
+ +            vllm_config=vllm_config,
+ +            hf_config=config,  # flat config with DeepSeek fields
+ +            architectures=["DeepseekV3ForCausalLM"],
+ +            prefix=maybe_prefix(prefix, "language_model"),
+ +        )
+ +
+ +        # Initialize vision components (from Pixtral)
+ +        if multimodal_config.get_limit_per_prompt("image"):
+ +            self.vision_encoder = VisionTransformer(self.vision_args)
+ +            self.pre_mm_projector_norm = (
+ +                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
+ +                if self.vision_args.add_pre_mm_projector_layer_norm
+ +                else None
+ +            )
+ +            self.patch_merger = (
+ +                PatchMerger(
+ +                    vision_encoder_dim=self.vision_args.hidden_size,
+ +                    spatial_merge_size=self.vision_args.spatial_merge_size,
+ +                    use_mlp_bias=False,
+ +                )
+ +                if self.vision_args.mm_projector_id == PATCH_MERGE
+ +                else None
+ +            )
+ +            # Use hidden_size from top-level config (DeepSeek LM hidden size)
+ +            self.vision_language_adapter = VisionLanguageAdapter(
+ +                self.vision_args, dim=config.hidden_size
+ +            )
+ +        else:
+ +            self.vision_encoder = None
+ +            self.pre_mm_projector_norm = None
+ +            self.patch_merger = None
+ +            self.vision_language_adapter = None
+ +
+ +        self.make_empty_intermediate_tensors = (
+ +            self.language_model.make_empty_intermediate_tensors
+ +        )
+ +
+ +    def _parse_and_validate_image_input(
+ +        self, **kwargs: object
+ +    ) -> PixtralImagePixelInputs | None:
+ +        images = kwargs.pop("images", None)
+ +        if images is None:
+ +            return None
+ +
+ +        return PixtralImagePixelInputs(
+ +            type="pixel_values",
+ +            images=images,
+ +        )
+ +
+ +    def _process_image_input(
+ +        self,
+ +        image_input: PixtralImagePixelInputs,
+ +    ) -> tuple[torch.Tensor, ...]:
+ +        assert (
+ +            self.vision_encoder is not None and self.vision_language_adapter is not None
+ +        )
+ +
+ +        images = image_input["images"]
+ +        image_features = self.vision_encoder(images)
+ +        feature_sizes = [image_feature.shape[0] for image_feature in image_features]
+ +        image_features = torch.cat(image_features)
+ +        if self.pre_mm_projector_norm is not None:
+ +            image_features = self.pre_mm_projector_norm(image_features)
+ +        if self.patch_merger is not None:
+ +            patch_size = self.vision_args.patch_size
+ +            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
+ +            img_patch_dims = [
+ +                (img.shape[1] // patch_size, img.shape[2] // patch_size)
+ +                for img in images
+ +            ]
+ +            feature_sizes = [
+ +                feature_size // spatial_merge_size_square
+ +                for feature_size in feature_sizes
+ +            ]
+ +            image_features = self.patch_merger(
+ +                image_features, image_sizes=img_patch_dims
+ +            )
+ +        image_embeds = self.vision_language_adapter(image_features)
+ +        image_embeds = torch.split(image_embeds, feature_sizes)
+ +        return image_embeds
+ +
+ +    def get_language_model(self) -> nn.Module:
+ +        return self.language_model
+ +
+ +    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ +        image_input = self._parse_and_validate_image_input(**kwargs)
+ +        if image_input is None:
+ +            return []
+ +
+ +        return self._process_image_input(image_input)
+ +
+ +    def forward(
+ +        self,
+ +        input_ids: torch.Tensor,
+ +        positions: torch.Tensor,
+ +        intermediate_tensors: IntermediateTensors | None = None,
+ +        inputs_embeds: torch.Tensor | None = None,
+ +        **kwargs: object,
+ +    ) -> torch.Tensor | IntermediateTensors:
+ +        """Run forward pass for Whaleye."""
+ +        if intermediate_tensors is not None:
+ +            inputs_embeds = None
+ +
+ +        hidden_states = self.language_model.model(
+ +            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+ +        )
+ +
+ +        return hidden_states
+ +
+ +    def compute_logits(
+ +        self,
+ +        hidden_states: torch.Tensor,
+ +    ) -> torch.Tensor | None:
+ +        return self.language_model.compute_logits(hidden_states)
+ +
+ +    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+ +        """Load weights for vision components and language model."""
+ +
+ +        def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
+ +            return weight[0].startswith("vision_encoder")
+ +
+ +        def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
+ +            return weight[0].startswith("vision_language_adapter")
+ +
+ +        def is_patch_merger(weight: tuple[str, torch.Tensor]):
+ +            return weight[0].startswith("patch_merger")
+ +
+ +        def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
+ +            return weight[0].startswith("pre_mm_projector_norm")
+ +
+ +        # Get references to parameters for direct loading
+ +        vision_encoder_dict = (
+ +            dict(self.vision_encoder.named_parameters())
+ +            if self.vision_encoder is not None
+ +            else {}
+ +        )
+ +        patch_merger_dict = (
+ +            dict(self.patch_merger.named_parameters())
+ +            if self.patch_merger is not None
+ +            else {}
+ +        )
+ +        pre_mm_projector_norm_dict = (
+ +            dict(self.pre_mm_projector_norm.named_parameters())
+ +            if self.pre_mm_projector_norm is not None
+ +            else {}
+ +        )
+ +        vision_lang_adapter_dict = (
+ +            dict(self.vision_language_adapter.named_parameters())
+ +            if self.vision_language_adapter is not None
+ +            else {}
+ +        )
+ +
+ +        def llm_weights_generator():
+ +            # Single pass over weights
+ +            for name, w in weights:
+ +                if is_vision_encoder_weights((name, w)):
+ +                    if self.vision_encoder is None:
+ +                        continue
+ +                    # Load vision encoder weights directly
+ +                    trimmed_name = ".".join(name.split(".")[1:])
+ +                    param = vision_encoder_dict[trimmed_name]
+ +                    with torch.no_grad():
+ +                        default_weight_loader(param, w)
+ +                elif is_patch_merger((name, w)):
+ +                    if self.patch_merger is None:
+ +                        continue
+ +                    # Load vision patch merger weights directly
+ +                    trimmed_name = ".".join(name.split(".")[1:])
+ +                    param = patch_merger_dict[trimmed_name]
+ +                    with torch.no_grad():
+ +                        default_weight_loader(param, w)
+ +                elif is_pre_mm_projector_norm((name, w)):
+ +                    if self.pre_mm_projector_norm is None:
+ +                        continue
+ +                    # Load vision pre_mm_projector_norm weights directly
+ +                    trimmed_name = ".".join(name.split(".")[1:])
+ +                    param = pre_mm_projector_norm_dict[trimmed_name]
+ +                    with torch.no_grad():
+ +                        default_weight_loader(param, w)
+ +                elif is_vision_lang_adapter_weights((name, w)):
+ +                    if self.vision_language_adapter is None:
+ +                        continue
+ +                    # Load vision-language adapter weights directly
+ +                    trimmed_name = ".".join(name.split(".")[1:])
+ +                    param = vision_lang_adapter_dict[trimmed_name]
+ +                    with torch.no_grad():
+ +                        default_weight_loader(param, w)
+ +                else:
+ +                    # LLM weights: yield them to be loaded
+ +                    # by language_model.load_weights
+ +                    yield (name, w)
+ +
+ +        # Now we call the language model load with the generator
+ +        self.language_model.load_weights(llm_weights_generator())
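With the patch applied, the model goes through vLLM's standard multimodal path: the processor expands each image placeholder into ([IMG] * ncols + [IMG_BREAK]) * nrows with a trailing [IMG_END], and load_weights routes vision_encoder*, patch_merger*, pre_mm_projector_norm*, and vision_language_adapter* tensors to the vision stack while streaming everything else to the DeepSeek language model. A minimal offline-inference sketch (the model id mirrors the registry test entry above, which is marked not available online; the image file and the max_model_len value are placeholders, not values the patch prescribes):

# Minimal sketch: exercise the patched model with one image + a text prompt.
# "umans-ai/Whaleye-V0" is the repo named in tests/models/registry.py above;
# point `model` at a local path if the weights are not published.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="umans-ai/Whaleye-V0", max_model_len=8192)
image = Image.open("example.jpg")

outputs = llm.generate(
    {
        # "<|img|>" is the placeholder returned by get_placeholder_str above.
        "prompt": "<|img|>\nDescribe this image.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)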