# f2p_decoder/modeling_f2p_decoder.py
# Uploaded by toilaluan via huggingface_hub (commit 09b2c2d, verified).
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
try:
from .configuration_f2p_decoder import F2PDecoderConfig
from .decoder import GeneralDecoder
except ImportError:
from configuration_f2p_decoder import F2PDecoderConfig
from decoder import GeneralDecoder
@dataclass
class F2PDecoderOutput(ModelOutput):
    """Output of ``F2PDecoderModel.forward`` (a ``transformers`` ``ModelOutput``)."""

    # Reconstructed image after un-normalization; produced by
    # decoder.unpatchify then de-standardized with image_mean/image_std.
    # Assumed (batch, num_channels, H, W) — TODO confirm against GeneralDecoder.unpatchify.
    reconstruction: torch.FloatTensor = None
    # Raw decoder logits (pre-unpatchify patch predictions).
    logits: torch.FloatTensor = None
    # Intermediate decoder hidden states; populated only when output_hidden_states=True.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Decoder attention maps; populated only when output_attentions=True.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class F2PDecoderModel(PreTrainedModel):
    """Feature-to-pixel decoder: maps SigLIP2 patch features back to images.

    Wraps a ``GeneralDecoder`` and adds de-normalization of its output using
    the dataset ``image_mean`` / ``image_std`` stored in the config.
    """

    config_class = F2PDecoderConfig
    base_model_prefix = "f2p_decoder"
    main_input_name = "hidden_states"
    supports_gradient_checkpointing = True

    def __init__(self, config: F2PDecoderConfig):
        """Build the decoder and register per-channel normalization buffers."""
        super().__init__(config)
        stat_shape = (1, config.num_channels, 1, 1)
        # Buffers (not parameters) so they follow .to(device/dtype) and are
        # saved in the state dict without receiving gradients.
        self.register_buffer(
            "image_mean",
            torch.tensor(config.image_mean, dtype=torch.float32).view(*stat_shape),
        )
        self.register_buffer(
            "image_std",
            torch.tensor(config.image_std, dtype=torch.float32).view(*stat_shape),
        )
        self.decoder = GeneralDecoder(config, num_patches=config.num_patches)

    def _set_gradient_checkpointing(self, module, value=False):
        # Legacy transformers hook: forward the checkpointing flag to the
        # inner decoder module.
        if isinstance(module, GeneralDecoder):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: Optional[torch.Tensor] = None,
        zs: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Decode SigLIP2 features into pixel space.

        ``zs`` is an accepted alias for ``hidden_states``; exactly one of the
        two must be provided. Returns an ``F2PDecoderOutput`` when
        ``return_dict`` is truthy, otherwise the reconstruction tensor alone.

        Raises:
            ValueError: if neither ``hidden_states`` nor ``zs`` is given.
        """
        features = zs if hidden_states is None else hidden_states
        if features is None:
            raise ValueError("Pass SigLIP2 features as hidden_states or zs.")

        # Fall back to the config-level defaults when not given explicitly.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        decoded = self.decoder(
            features,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            drop_cls_token=self.config.drop_cls_token,
        )

        pixels = self.decoder.unpatchify(decoded.logits)
        # Undo the dataset standardization to recover the original value range.
        pixels = pixels * self.image_std + self.image_mean

        if not return_dict:
            return pixels
        return F2PDecoderOutput(
            reconstruction=pixels,
            logits=decoded.logits,
            hidden_states=decoded.hidden_states,
            attentions=decoded.attentions,
        )

    @torch.no_grad()
    def infer(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Gradient-free convenience wrapper returning only the reconstruction."""
        return self.forward(hidden_states, return_dict=False)