"""GR-Lite: Fashion image retrieval model based on DINOv2 ViT-L/16."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling

from .configuration_gr_lite import GRLiteConfig


class GRLiteEmbeddings(nn.Module):
    """Turn an image batch into the token sequence ``[CLS, registers..., patches...]``.

    NOTE(review): no positional embedding (learned or rotary) is added here or
    anywhere else in this module, so the encoder receives patch tokens without
    spatial position information. Confirm this matches how the checkpoint was
    trained.
    """

    def __init__(self, config: GRLiteConfig):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # Never read by forward(); presumably kept so checkpoint keys load
        # without "unexpected key" warnings — TODO confirm.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.register_tokens = nn.Parameter(
            torch.zeros(1, config.num_register_tokens, config.hidden_size)
        )
        # Non-overlapping patch projection: kernel == stride == patch_size.
        self.patch_embeddings = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` of shape ``[B, C, H, W]``.

        Returns a tensor of shape ``[B, 1 + R + N, D]``, where ``R`` is the
        number of register tokens and ``N = (H // P) * (W // P)``. Pixels beyond
        a whole multiple of the patch size ``P`` are dropped by the strided conv.
        """
        batch_size = pixel_values.shape[0]
        # Cast to the conv weights' dtype so half-precision (fp16/bf16) models
        # accept the float32 tensors image preprocessing usually produces;
        # without this, Conv2d raises a dtype-mismatch error.
        target_dtype = self.patch_embeddings.weight.dtype
        x = self.patch_embeddings(pixel_values.to(dtype=target_dtype))  # [B, D, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)  # [B, N, D]
        cls = self.cls_token.expand(batch_size, -1, -1)
        reg = self.register_tokens.expand(batch_size, -1, -1)
        return torch.cat([cls, reg, x], dim=1)  # [B, 1+R+N, D]


class GRLiteAttention(nn.Module):
    """Multi-head self-attention with separate query/key/value projections."""

    def __init__(self, config: GRLiteConfig):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            # Fail at construction instead of with an opaque reshape error in forward().
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim**-0.5
        # The key projection's bias is controlled separately (config.k_bias)
        # from the query/value projections (config.qkv_bias).
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.k_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)

    def _split_heads(self, t: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
        """Reshape ``[B, N, D]`` into ``[B, num_heads, N, head_dim]``."""
        return t.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``[B, N, D]``; returns ``[B, N, D]``."""
        B, N, _ = x.shape
        q = self._split_heads(self.q_proj(x), B, N)
        k = self._split_heads(self.k_proj(x), B, N)
        v = self._split_heads(self.v_proj(x), B, N)

        if hasattr(F, "scaled_dot_product_attention"):
            # Fused kernel (torch >= 2.0) avoids materialising the N x N
            # attention matrix where possible. Its default scale is
            # 1/sqrt(head_dim), identical to self.scale.
            out = F.scaled_dot_product_attention(q, k, v)
        else:
            # Fallback for older torch: explicit softmax(q k^T * scale) v.
            attn = (q @ k.transpose(-2, -1)) * self.scale
            out = attn.softmax(dim=-1) @ v

        out = out.transpose(1, 2).reshape(B, N, -1)  # merge heads back to [B, N, D]
        return self.o_proj(out)


class GRLiteLayerScale(nn.Module):
    """Learnable per-channel scaling applied to a residual branch."""

    def __init__(self, dim: int):
        super().__init__()
        self.lambda1 = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.lambda1


class GRLiteMLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear."""

    def __init__(self, config: GRLiteConfig):
        super().__init__()
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))


class GRLiteLayer(nn.Module):
    """Pre-norm transformer block with layer-scaled attention and MLP residuals."""

    def __init__(self, config: GRLiteConfig):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = GRLiteAttention(config)
        self.layer_scale1 = GRLiteLayerScale(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = GRLiteMLP(config)
        self.layer_scale2 = GRLiteLayerScale(config.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.layer_scale1(self.attention(self.norm1(x)))
        x = x + self.layer_scale2(self.mlp(self.norm2(x)))
        return x


class GRLitePreTrainedModel(PreTrainedModel):
    """Base class wiring GR-Lite into the transformers ``PreTrainedModel`` API."""

    config_class = GRLiteConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False


class GRLiteModel(GRLitePreTrainedModel):
    """GR-Lite: Fashion image retrieval model (DINOv2 ViT-L/16 backbone).

    Produces L2-normalized 1024-dim image embeddings.

    ``pooler_output`` is the final CLS token after the closing LayerNorm,
    L2-normalized along the feature axis; its size is ``config.hidden_size``.
    """

    def __init__(self, config: GRLiteConfig):
        super().__init__(config)
        self.embeddings = GRLiteEmbeddings(config)
        self.layer = nn.ModuleList(
            [GRLiteLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        **kwargs,
    ):
        """Encode a batch of images.

        Args:
            pixel_values: Image tensor of shape ``[B, C, H, W]``.
            output_hidden_states: If True, also return a tuple containing the
                input to every layer, followed by the final *normalized*
                sequence output (the last entry equals ``last_hidden_state``).
                The un-normalized output of the final layer is not included.
            return_dict: Return a ``BaseModelOutputWithPooling`` if True,
                otherwise a plain tuple.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``BaseModelOutputWithPooling(last_hidden_state, pooler_output,
            hidden_states)``, or the tuple ``(last_hidden_state, pooler_output)``
            with ``hidden_states`` appended when ``output_hidden_states`` is True.
        """
        x = self.embeddings(pixel_values)

        hidden_states = () if output_hidden_states else None
        for layer_module in self.layer:
            if output_hidden_states:
                hidden_states += (x,)
            x = layer_module(x)

        x = self.norm(x)
        if output_hidden_states:
            hidden_states += (x,)

        # CLS token embedding, L2-normalized
        pooled = F.normalize(x[:, 0], p=2, dim=-1)

        if not return_dict:
            return (x, pooled, hidden_states) if output_hidden_states else (x, pooled)

        return BaseModelOutputWithPooling(
            last_hidden_state=x,
            pooler_output=pooled,
            hidden_states=hidden_states,
        )