from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel

# Support both package-relative and flat-script execution.
try:
    from .configuration_f2p_decoder import F2PDecoderConfig
    from .decoder import GeneralDecoder
except ImportError:
    from configuration_f2p_decoder import F2PDecoderConfig
    from decoder import GeneralDecoder


@dataclass
class F2PDecoderOutput(ModelOutput):
    """Output container for :class:`F2PDecoderModel`.

    Fields follow the usual ``ModelOutput`` convention:
    ``reconstruction`` is the de-normalized image tensor, ``logits`` the raw
    per-patch decoder output, and ``hidden_states``/``attentions`` are only
    populated when the corresponding ``output_*`` flags are set.
    """

    # De-normalized reconstructed image (see F2PDecoderModel.forward).
    reconstruction: torch.FloatTensor = None
    # Raw patch-level decoder predictions before unpatchify.
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class F2PDecoderModel(PreTrainedModel):
    """Feature-to-pixel decoder for SigLIP2 patch features."""

    config_class = F2PDecoderConfig
    base_model_prefix = "f2p_decoder"
    main_input_name = "hidden_states"
    supports_gradient_checkpointing = True

    def __init__(self, config: F2PDecoderConfig):
        """Build the decoder and register image-normalization statistics.

        The mean/std are stored as buffers (not parameters) so they move with
        ``.to(device)`` and are persisted in the state dict without receiving
        gradients. They are shaped ``(1, C, 1, 1)`` for NCHW broadcasting.
        """
        super().__init__(config)
        image_mean = torch.tensor(config.image_mean, dtype=torch.float32).view(
            1, config.num_channels, 1, 1
        )
        image_std = torch.tensor(config.image_std, dtype=torch.float32).view(
            1, config.num_channels, 1, 1
        )
        self.register_buffer("image_mean", image_mean)
        self.register_buffer("image_std", image_std)
        self.decoder = GeneralDecoder(config, num_patches=config.num_patches)

    def _set_gradient_checkpointing(self, module, value: bool = False):
        # Legacy transformers hook: toggle checkpointing on the inner decoder.
        if isinstance(module, GeneralDecoder):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: Optional[torch.Tensor] = None,
        zs: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Decode SigLIP2 patch features back into pixel space.

        Args:
            hidden_states: SigLIP2 patch features. ``zs`` is accepted as an
                alias for backward compatibility; ``hidden_states`` wins when
                both are given.
            zs: Deprecated alias for ``hidden_states``.
            output_attentions / output_hidden_states / return_dict: Standard
                transformers flags; ``None`` falls back to the config.

        Returns:
            ``F2PDecoderOutput`` when ``return_dict`` resolves to True,
            otherwise just the reconstructed image tensor.

        Raises:
            ValueError: If neither ``hidden_states`` nor ``zs`` is provided.
        """
        if hidden_states is None:
            hidden_states = zs
        if hidden_states is None:
            raise ValueError("Pass SigLIP2 features as hidden_states or zs.")

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        # Fix: fall back to the config like the other flags do. Previously a
        # bare forward() (return_dict=None) always took the tensor-only path,
        # inconsistent with the transformers use_return_dict convention.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        decoder_output = self.decoder(
            hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            drop_cls_token=self.config.drop_cls_token,
        )

        # Assemble patches into an image, then undo dataset normalization so
        # the reconstruction is in the original pixel value range.
        reconstruction = self.decoder.unpatchify(decoder_output.logits)
        reconstruction = reconstruction * self.image_std + self.image_mean

        if return_dict:
            return F2PDecoderOutput(
                reconstruction=reconstruction,
                logits=decoder_output.logits,
                hidden_states=decoder_output.hidden_states,
                attentions=decoder_output.attentions,
            )
        return reconstruction

    @torch.no_grad()
    def infer(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Gradient-free convenience wrapper returning only the image tensor."""
        return self.forward(hidden_states, return_dict=False)