| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
| import torch |
| import torch.nn as nn |
| from transformers import SiglipImageProcessor, SiglipVisionModel |
| from transformers.utils import ModelOutput |
| import numpy as np |
|
|
# Maps the precision strings accepted by ``load_vision_encoder`` to torch dtypes.
PRECISION_TO_TYPE = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)

# Default checkpoint path per vision-encoder type. Empty in this module, so
# callers must pass an explicit path unless it is populated elsewhere.
VISION_ENCODER_PATH = {}
|
|
|
|
def use_default(value, default):
    """Return ``value`` unless it is ``None``, in which case return ``default``.

    Only ``None`` triggers the fallback; other falsy values such as ``0``,
    ``""`` or ``False`` are returned unchanged.
    """
    if value is None:
        return default
    return value
|
|
|
|
def load_vision_encoder(
    vision_encoder_type,
    vision_encoder_precision=None,
    vision_encoder_path=None,
    logger=None,
    device=None,
):
    """Load a frozen vision encoder.

    Args:
        vision_encoder_type: Encoder family; only ``"siglip"`` is supported.
        vision_encoder_precision: Optional key into ``PRECISION_TO_TYPE``
            (``"fp32"``, ``"fp16"`` or ``"bf16"``) selecting the dtype the
            weights are cast to after loading.
        vision_encoder_path: Checkpoint path or hub id. Defaults to
            ``VISION_ENCODER_PATH[vision_encoder_type]``.
        logger: Optional logger; if given, receives an info message naming the
            checkpoint being loaded.
        device: Optional device the model is moved to after casting.

    Returns:
        ``(vision_encoder, vision_encoder_path)``: the model with
        ``requires_grad`` disabled on all parameters, and the path it was
        loaded from.

    Raises:
        ValueError: If ``vision_encoder_type`` is not supported.
        KeyError: If no path is given and none is registered for the type, or
            if ``vision_encoder_precision`` is not a key of ``PRECISION_TO_TYPE``.
    """
    # Validate the type before the default-path lookup so an unsupported type
    # consistently raises ValueError instead of a KeyError from the registry.
    if vision_encoder_type != "siglip":
        raise ValueError(f"Unsupported vision encoder type: {vision_encoder_type}")

    if vision_encoder_path is None:
        try:
            vision_encoder_path = VISION_ENCODER_PATH[vision_encoder_type]
        except KeyError:
            raise KeyError(
                f"No default path registered for vision encoder type "
                f"'{vision_encoder_type}'; pass vision_encoder_path explicitly."
            ) from None

    # Resolve the dtype up front so an invalid precision fails before the
    # (potentially slow) checkpoint load rather than after it.
    dtype = None
    if vision_encoder_precision is not None:
        dtype = PRECISION_TO_TYPE[vision_encoder_precision]

    if logger is not None:
        logger.info(
            f"Loading {vision_encoder_type} vision encoder from {vision_encoder_path}"
        )

    vision_encoder = SiglipVisionModel.from_pretrained(vision_encoder_path)
    if dtype is not None:
        vision_encoder = vision_encoder.to(dtype=dtype)

    # The encoder is used as a frozen feature extractor.
    vision_encoder.requires_grad_(False)

    if device is not None:
        vision_encoder = vision_encoder.to(device)

    return vision_encoder, vision_encoder_path
|
|
|
|
def load_image_processor(
    processor_type,
    processor_path=None,
    logger=None
):
    """Load the image processor that pairs with a vision encoder.

    Args:
        processor_type: Processor family; only ``"siglip"`` is supported.
        processor_path: Checkpoint path or hub id. Defaults to
            ``VISION_ENCODER_PATH[processor_type]``.
        logger: Optional logger; if given, receives an info message naming the
            checkpoint being loaded.

    Returns:
        ``(processor, processor_path)``: the loaded processor and the path it
        was loaded from.

    Raises:
        ValueError: If ``processor_type`` is not supported.
        KeyError: If no path is given and none is registered for the type.
    """
    # Validate the type before the default-path lookup so an unsupported type
    # consistently raises ValueError instead of a KeyError from the registry.
    if processor_type != "siglip":
        raise ValueError(f"Unsupported processor type: {processor_type}")

    if processor_path is None:
        try:
            processor_path = VISION_ENCODER_PATH[processor_type]
        except KeyError:
            raise KeyError(
                f"No default path registered for processor type "
                f"'{processor_type}'; pass processor_path explicitly."
            ) from None

    if logger is not None:
        logger.info(f"Loading {processor_type} image processor from {processor_path}")

    processor = SiglipImageProcessor.from_pretrained(processor_path)
    return processor, processor_path
|
|
|
|
@dataclass
class VisionEncoderModelOutput(ModelOutput):
    """
    Outputs of ``VisionEncoder.encode_images``.

    Each field is copied from the underlying vision model's output; fields the
    model does not provide are left as ``None``.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
            Pooled representation returned by the vision model, or ``None`` if the
            model output has no ``pooler_output``.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Per-layer hidden states returned by the vision model, or ``None`` if the
            model output has no ``hidden_states``.
    """

    # NOTE: ModelOutput also exposes fields positionally (tuple-style), so the
    # declaration order below is part of the interface; keep it stable.
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
|
|
class VisionEncoder(nn.Module):
    """Frozen vision encoder paired with its image processor.

    Wraps the model returned by ``load_vision_encoder`` and the processor
    returned by ``load_image_processor``, and exposes helpers to encode raw
    images or VAE latents (decoded to images first).
    """

    def __init__(
        self,
        vision_encoder_type: str,
        vision_encoder_precision: Optional[str] = None,
        vision_encoder_path: Optional[str] = None,
        processor_type: Optional[str] = None,
        processor_path: Optional[str] = None,
        output_key: Optional[str] = None,
        logger=None,
        device=None,
    ):
        """Load the encoder and processor.

        Args:
            vision_encoder_type: Encoder family; must contain ``"siglip"``.
            vision_encoder_precision: Optional precision key (``"fp32"``,
                ``"fp16"``, ``"bf16"``) the encoder weights are cast to.
            vision_encoder_path: Encoder checkpoint; see ``load_vision_encoder``.
            processor_type: Processor family; defaults to ``vision_encoder_type``.
            processor_path: Processor checkpoint; defaults to ``vision_encoder_path``.
            output_key: Name of the ``VisionEncoderModelOutput`` field returned by
                ``encode_latents``; defaults to ``"last_hidden_state"``.
            logger: Optional logger forwarded to the loaders.
            device: Optional device the encoder is moved to.

        Raises:
            ValueError: If ``vision_encoder_type`` is not supported.
        """
        super().__init__()
        self.vision_encoder_type = vision_encoder_type
        self.precision = vision_encoder_precision
        self.model_path = vision_encoder_path
        # The processor defaults to the encoder's own type and checkpoint.
        self.processor_type = (
            processor_type if processor_type is not None else vision_encoder_type
        )
        self.processor_path = (
            processor_path if processor_path is not None else vision_encoder_path
        )
        self.logger = logger

        if "siglip" in vision_encoder_type:
            self.output_key = output_key or "last_hidden_state"
        else:
            raise ValueError(f"Unsupported vision encoder type: {vision_encoder_type}")

        self.model, self.model_path = load_vision_encoder(
            vision_encoder_type=self.vision_encoder_type,
            vision_encoder_precision=self.precision,
            vision_encoder_path=self.model_path,
            logger=self.logger,
            device=device,
        )

        self.processor, self.processor_path = load_image_processor(
            processor_type=self.processor_type,
            processor_path=self.processor_path,
            logger=self.logger,
        )

    @property
    def dtype(self):
        """Current dtype of the wrapped model (tracks later ``.to()``/``.half()`` calls)."""
        return self.model.dtype

    @property
    def device(self):
        """Current device of the wrapped model (tracks later ``.to()`` calls)."""
        return self.model.device

    def __repr__(self):
        return f"{self.vision_encoder_type} ({self.precision} - {self.model_path})"

    def encode_latents_to_images(self, latents, vae, reorg_token=False):
        """Decode the first frame of ``latents`` into uint8 RGB images.

        Args:
            latents: Latent tensor. A 5-D tensor is indexed at position 0 of
                dim 2 (assumed to be the frame axis — TODO confirm); any other
                rank is decoded as-is (presumably ``(B, C, H, W)``).
            vae: Decoder exposing ``config.scaling_factor``, ``dtype`` and
                ``decode(x, return_dict=False)``.
            reorg_token: Accepted for API compatibility; currently unused.

        Returns:
            ``np.ndarray`` of dtype ``uint8`` and shape ``(B, H, W, 3)``.

        Raises:
            ValueError: If the decoded images do not have 3 channels.
        """
        first_latents = latents[:, :, 0, ...] if latents.ndim == 5 else latents
        first_latents = 1 / vae.config.scaling_factor * first_latents

        # The result is quantized to uint8 below, so no gradient can flow back
        # through the decoder; skip autograd to avoid holding its activations.
        with torch.no_grad():
            decoded = vae.decode(
                first_latents.unsqueeze(2).to(vae.dtype), return_dict=False
            )[0].cpu()

        images = decoded[:, :, 0, :, :]
        # Rescale from [-1, 1] to [0, 255]. After clamping to [0, 1], the product
        # cannot exceed 255, so no second clamp is needed. The uint8 cast
        # truncates toward zero rather than rounding.
        images = ((images / 2 + 0.5).clamp(0, 1) * 255.0).to(torch.uint8)
        images = images.numpy().transpose(0, 2, 3, 1)

        if images.ndim != 4 or images.shape[3] != 3:
            raise ValueError(
                f"Expected decoded images of shape (B, H, W, 3), got {images.shape}"
            )
        return images

    def encode_images(self, images):
        """Encode images with the vision encoder.

        Args:
            images: One of:
                - a ``torch.Tensor`` of already-preprocessed pixel values, cast to
                  the model's device/dtype and passed as ``pixel_values``;
                - a mapping of model inputs (e.g. the processor's output), passed
                  through unchanged as keyword arguments;
                - anything else the image processor accepts (``np.ndarray``, and
                  presumably PIL images or lists of them), which is preprocessed
                  and moved to the model's device/dtype.

        Returns:
            VisionEncoderModelOutput with the model's ``last_hidden_state`` and,
            when present, its ``pooler_output`` and ``hidden_states``.
        """
        if isinstance(images, torch.Tensor):
            # ``**`` unpacking needs a mapping, so wrap raw pixel tensors under
            # the model's ``pixel_values`` keyword.
            model_inputs = {
                "pixel_values": images.to(device=self.model.device, dtype=self.model.dtype)
            }
        elif hasattr(images, "keys"):
            # Already a mapping of model inputs (what ``**`` accepts).
            model_inputs = images
        else:
            model_inputs = self.processor.preprocess(images=images, return_tensors="pt").to(
                device=self.model.device, dtype=self.model.dtype
            )

        outputs = self.model(**model_inputs)

        return VisionEncoderModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=getattr(outputs, "pooler_output", None),
            hidden_states=getattr(outputs, "hidden_states", None),
        )

    def encode_latents(self, latents, vae, reorg_token=False):
        """Decode ``latents`` to images with ``vae``, then encode them.

        This is the main entry point that replaces ``sigclip_vision_encode``.

        Args:
            latents: Input latent tensor (see ``encode_latents_to_images``).
            vae: VAE used to decode the latents to images.
            reorg_token: Forwarded to ``encode_latents_to_images`` (unused there).

        Returns:
            The ``VisionEncoderModelOutput`` field named by ``self.output_key``
            (``last_hidden_state`` by default).
        """
        images = self.encode_latents_to_images(latents, vae, reorg_token)
        outputs = self.encode_images(images)
        return getattr(outputs, self.output_key)

    def forward(self, images):
        """Encode ``images`` directly; see ``encode_images`` for accepted inputs.

        Returns:
            VisionEncoderModelOutput with encoded features.
        """
        return self.encode_images(images)