gr-lite / modeling_gr_lite.py
pierrexsq's picture
Add HuggingFace-compatible model (config + safetensors + model code)
39d5d4e verified
Raw
History Blame Contribute Delete
5.36 kB
"""GR-Lite: Fashion image retrieval model based on DINOv2 ViT-L/16."""
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling

from .configuration_gr_lite import GRLiteConfig
class GRLiteEmbeddings(nn.Module):
    """Convert an image batch into the token sequence fed to the encoder.

    The output sequence is ``[CLS] + register tokens + patch tokens``. Patch
    tokens come from a convolution whose kernel size and stride both equal
    ``config.patch_size``, i.e. one token per non-overlapping patch.

    NOTE(review): no positional information is added to the tokens in this
    module -- confirm that this is intentional for the checkpoint in use.
    """

    def __init__(self, config: GRLiteConfig):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # Not read by forward(); presumably kept so that checkpoints which
        # contain this key load without "unexpected key" errors -- TODO confirm.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.register_tokens = nn.Parameter(
            torch.zeros(1, config.num_register_tokens, config.hidden_size)
        )
        self.patch_embeddings = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Embed ``pixel_values`` of shape ``[B, C, H, W]``.

        Returns:
            Tensor of shape ``[B, 1 + R + N, D]`` where ``R`` is the number of
            register tokens, ``N`` the number of patches and ``D`` the hidden
            size.
        """
        batch_size = pixel_values.shape[0]

        # Cast the input to the projection's dtype so that a model converted
        # to half / bfloat16 / double still accepts float32 pixel values
        # instead of failing with a dtype mismatch inside the convolution.
        target_dtype = self.patch_embeddings.weight.dtype
        patches = self.patch_embeddings(pixel_values.to(dtype=target_dtype))  # [B, D, H/P, W/P]
        patches = patches.flatten(2).transpose(1, 2)  # [B, N, D]

        # expand() broadcasts the learned tokens over the batch without copying.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        register_tokens = self.register_tokens.expand(batch_size, -1, -1)
        return torch.cat([cls_tokens, register_tokens, patches], dim=1)  # [B, 1+R+N, D]
class GRLiteAttention(nn.Module):
    """Multi-head self-attention with separate q/k/v/output projections.

    The query and value projections use ``config.qkv_bias``; the key
    projection uses ``config.k_bias``; the output projection always has a
    bias.

    Raises:
        ValueError: if ``config.hidden_size`` is not divisible by
            ``config.num_attention_heads``.
    """

    def __init__(self, config: GRLiteConfig):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            # Fail at construction time with a clear message rather than with
            # an opaque reshape error on the first forward pass.
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        # Kept as a public attribute; equals the default scaling used by
        # F.scaled_dot_product_attention (1 / sqrt(head_dim)).
        self.scale = self.head_dim**-0.5
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.k_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)

    def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape ``[B, N, D]`` into ``[B, num_heads, N, head_dim]``."""
        batch_size, seq_len, _ = tensor.shape
        return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x`` of shape ``[B, N, D]``; returns ``[B, N, D]``."""
        batch_size, seq_len, _ = x.shape
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        # Fused softmax(q @ k^T / sqrt(head_dim)) @ v; avoids materialising the
        # full [B, heads, N, N] attention matrix when a fused kernel is available.
        context = F.scaled_dot_product_attention(q, k, v)

        context = context.transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.o_proj(context)
class GRLiteLayerScale(nn.Module):
    """Learnable per-channel gain applied to a residual branch.

    The gain vector ``lambda1`` has ``dim`` entries and starts at one, so the
    module is the identity until trained or loaded from a checkpoint.
    """

    def __init__(self, dim: int):
        super().__init__()
        initial_gain = torch.ones(dim)
        self.lambda1 = nn.Parameter(initial_gain)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale the last dimension of ``x`` element-wise by ``lambda1``."""
        scaled = torch.mul(x, self.lambda1)
        return scaled
class GRLiteMLP(nn.Module):
    """Two-layer feed-forward block: linear -> GELU -> linear.

    Expands from ``config.hidden_size`` to ``config.intermediate_size`` and
    projects back to ``config.hidden_size``.
    """

    def __init__(self, config: GRLiteConfig):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.up_proj = nn.Linear(hidden, inner)
        self.down_proj = nn.Linear(inner, hidden)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transform to the last dimension of ``x``."""
        expanded = self.up_proj(x)
        activated = self.act(expanded)
        return self.down_proj(activated)
class GRLiteLayer(nn.Module):
    """One pre-norm transformer encoder block.

    Two residual branches are applied in sequence, attention then MLP. Each
    branch is ``LayerNorm -> sub-module -> LayerScale`` and its result is
    added back onto the branch input.
    """

    def __init__(self, config: GRLiteConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps

        # Attention branch.
        self.norm1 = nn.LayerNorm(width, eps=eps)
        self.attention = GRLiteAttention(config)
        self.layer_scale1 = GRLiteLayerScale(width)

        # Feed-forward branch.
        self.norm2 = nn.LayerNorm(width, eps=eps)
        self.mlp = GRLiteMLP(config)
        self.layer_scale2 = GRLiteLayerScale(width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual branches on ``x`` and return the updated tokens."""
        attn_branch = self.attention(self.norm1(x))
        x = x + self.layer_scale1(attn_branch)

        mlp_branch = self.mlp(self.norm2(x))
        x = x + self.layer_scale2(mlp_branch)
        return x
class GRLitePreTrainedModel(PreTrainedModel):
    """Base class wiring GR-Lite models into the Transformers loading machinery."""

    config_class = GRLiteConfig
    base_model_prefix = "model"
    # The model consumes images; the inherited default names a text input.
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False
    # Modules that must stay on a single device when the model is sharded
    # (e.g. ``device_map="auto"``): the embeddings concatenate their own
    # parameters with the patch projection output, and each encoder layer has
    # residual connections spanning its sub-modules.
    _no_split_modules = ["GRLiteEmbeddings", "GRLiteLayer"]
class GRLiteModel(GRLitePreTrainedModel):
    """GR-Lite: Fashion image retrieval model (DINOv2 ViT-L/16 backbone).

    Produces L2-normalized image embeddings of size ``config.hidden_size``
    taken from the CLS token after the final LayerNorm.
    """

    def __init__(self, config: GRLiteConfig):
        super().__init__(config)
        self.embeddings = GRLiteEmbeddings(config)
        self.layer = nn.ModuleList(
            [GRLiteLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # NOTE(review): stock Transformers models call ``self.post_init()``
        # here; confirm whether this model should do the same.

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs,
    ):
        """Encode a batch of images.

        Args:
            pixel_values: Image batch passed to the embedding module.
            output_hidden_states: If true, also return the input of every
                encoder layer plus the final normalized output
                (``num_hidden_layers + 1`` tensors). ``None`` defers to
                ``config.output_hidden_states``.
            return_dict: If true, return a ``BaseModelOutputWithPooling``;
                otherwise a plain tuple. ``None`` defers to
                ``config.use_return_dict``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``BaseModelOutputWithPooling`` with ``last_hidden_state`` (all
            tokens after the final norm), ``pooler_output`` (L2-normalized CLS
            token) and ``hidden_states``; or the tuple
            ``(last_hidden_state, pooler_output[, hidden_states])``.
        """
        # Generic Transformers wrappers commonly forward ``None`` for these
        # flags; resolve it from the config instead of treating it as False
        # (which would silently switch the return type to a tuple).
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        x = self.embeddings(pixel_values)

        collected = [] if output_hidden_states else None
        for layer_module in self.layer:
            if collected is not None:
                collected.append(x)
            x = layer_module(x)
        x = self.norm(x)
        if collected is not None:
            collected.append(x)
        hidden_states = tuple(collected) if collected is not None else None

        # CLS token embedding, L2-normalized
        pooled = F.normalize(x[:, 0], p=2, dim=-1)

        if not return_dict:
            return (x, pooled, hidden_states) if output_hidden_states else (x, pooled)
        return BaseModelOutputWithPooling(
            last_hidden_state=x,
            pooler_output=pooled,
            hidden_states=hidden_states,
        )