Prior2DSM / src /dinov3 /eval /text /text_tower.py
osherr's picture
Upload 222 files
bc90483 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This software may be used and distributed in accordance with
# the terms of the DINOv3 License Agreement.
import logging
from typing import Optional
import torch
from dinov3.eval.text.text_transformer import TextTransformer
from dinov3.layers import CausalSelfAttentionBlock
from torch import nn
logger = logging.getLogger("dinov3")
class TextHead(nn.Module):
    """Head applied to the text backbone's token sequence.

    The head is, in order: an optional stack of ``CausalSelfAttentionBlock``
    modules, a final normalization and an optional bias-free linear
    projection. Any stage that is not requested is an ``nn.Identity``.

    Args:
        input_dim: Feature width of the incoming tokens; also the width of
            the head blocks and of the final ``LayerNorm``.
        embed_dim: Feature width produced by the linear projection.
        num_heads: ``num_heads`` forwarded to every head block.
        num_blocks: Number of head blocks. With ``0`` (or less) the block
            stack is a single ``nn.Identity`` and no ``LayerNorm`` is used.
        block_drop_prob: ``dropout_prob`` forwarded to every head block.
        is_causal: ``is_causal`` forwarded to every head block.
        use_linear_projection: Force the linear projection even when
            ``input_dim == embed_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        num_heads: int,
        num_blocks: int,
        block_drop_prob: float,
        is_causal: bool,
        use_linear_projection: bool,
    ):
        super().__init__()

        if num_blocks > 0:
            logger.info(f"Adding {num_blocks} text tower transformer head blocks")
            stack = [
                CausalSelfAttentionBlock(
                    dim=input_dim,
                    num_heads=num_heads,
                    is_causal=is_causal,
                    dropout_prob=block_drop_prob,
                )
                for _ in range(num_blocks)
            ]
            final_norm = nn.LayerNorm(input_dim)
        else:
            # No head blocks requested: both stages are pass-through.
            stack = [nn.Identity()]
            final_norm = nn.Identity()

        # ``ln_final`` is registered before ``blocks`` so the submodule (and
        # therefore parameter / state-dict) ordering is unchanged.
        self.ln_final = final_norm
        self.blocks = nn.ModuleList(stack)
        self.num_blocks = num_blocks

        if input_dim != embed_dim or use_linear_projection:
            logger.info(
                f"Text tower : Using a linear projection from {input_dim} to {embed_dim}"
            )
            projection = nn.Linear(input_dim, embed_dim, bias=False)
        else:
            projection = nn.Identity()
        self.linear_projection = projection

    def init_weights(self):
        """Re-initialize the head blocks, the final norm and the projection.

        The projection weight is drawn from a normal distribution whose
        standard deviation is ``in_features ** -0.5``. Identity stages are
        left alone.
        """
        if self.num_blocks > 0:
            for head_block in self.blocks:
                head_block.init_weights()
            self.ln_final.reset_parameters()

        projection = self.linear_projection
        if isinstance(projection, nn.Linear):
            nn.init.normal_(projection.weight, std=projection.in_features**-0.5)

    def forward(self, text_tokens: torch.Tensor) -> torch.Tensor:
        """Run the tokens through the blocks, the final norm and the projection."""
        hidden = text_tokens
        for head_block in self.blocks:
            hidden = head_block(hidden)
        return self.linear_projection(self.ln_final(hidden))
class TextTower(nn.Module):
    """Text encoder: a backbone, a ``TextHead`` and a token pooling step.

    Args:
        backbone: Module mapping token indices to a token sequence. It must
            expose ``embed_dim`` and ``num_heads`` attributes (both are read
            to size the head) and an ``init_weights()`` method.
        freeze_backbone: Stored on the instance as ``freeze_backbone``; this
            class does not act on the flag itself.
        embed_dim: Output feature width of the head.
        num_head_blocks: Number of transformer blocks in the head.
        head_blocks_is_causal: ``is_causal`` flag for the head blocks.
        head_blocks_block_drop_prob: Dropout probability for the head blocks.
        tokens_pooler_type: One of ``"first"``, ``"last"`` or ``"argmax"``;
            validated lazily, on the first ``forward`` call.
        use_linear_projection: Force the head's linear projection even when
            the backbone width already equals ``embed_dim``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        freeze_backbone: bool,
        embed_dim: int,
        num_head_blocks: int,
        head_blocks_is_causal: bool,
        head_blocks_block_drop_prob: float,
        tokens_pooler_type: str,
        use_linear_projection: bool,
    ):
        super().__init__()
        self.backbone = backbone
        self.freeze_backbone = freeze_backbone
        backbone_out_dim = backbone.embed_dim
        logger.info(f"Text backbone embedding dimension: {backbone_out_dim}")
        self.head = TextHead(
            backbone_out_dim,
            embed_dim,
            self.backbone.num_heads,
            num_head_blocks,
            head_blocks_block_drop_prob,
            head_blocks_is_causal,
            use_linear_projection,
        )
        self.tokens_pooler_type = tokens_pooler_type

    def init_weights(self):
        """Initialize the backbone first, then the head."""
        self.backbone.init_weights()
        self.head.init_weights()

    def forward(self, token_indices: torch.Tensor) -> torch.Tensor:
        """Encode token indices and pool the sequence to one vector per sample.

        Args:
            token_indices: Token ids fed to the backbone. For the
                ``"argmax"`` pooler the last dimension is also used to pick
                the pooled position.

        Returns:
            The pooled features: the head output indexed at one position
            along dimension 1 for every sample.

        Raises:
            ValueError: If ``tokens_pooler_type`` is not a supported value.
        """
        text_tokens = self.backbone(token_indices)
        text_tokens = self.head(text_tokens)
        if self.tokens_pooler_type == "first":
            features = text_tokens[:, 0]
        elif self.tokens_pooler_type == "last":
            features = text_tokens[:, -1]
        elif self.tokens_pooler_type == "argmax":
            assert token_indices is not None
            # Pick, per sample, the position holding the largest token id
            # (presumably the end-of-text token; confirm against the tokenizer).
            features = text_tokens[
                torch.arange(text_tokens.shape[0]), token_indices.argmax(dim=-1)
            ]
        else:
            # The attribute is ``tokens_pooler_type``; the previous
            # ``self.pooler_type`` raised AttributeError instead of this error.
            raise ValueError(
                f"Unknown text tokens pooler type: {self.tokens_pooler_type}"
            )
        return features
def build_text_backbone(
    cfg,
) -> torch.nn.Module:
    """Instantiate a ``TextTransformer`` from the fields of ``cfg``.

    Args:
        cfg: Config object exposing ``context_length``, ``vocab_size``,
            ``dim``, ``num_heads``, ``num_layers``, ``ffn_ratio``,
            ``is_causal``, ``ls_init_value``, ``dropout_prob`` and
            ``model_name`` (the latter is only used for logging).

    Returns:
        The newly constructed text transformer.
    """
    logger.info("Setting up a text transformer")
    # Gather the constructor arguments first; fields are read in the same
    # order the constructor call lists them.
    transformer_kwargs = {
        "context_length": cfg.context_length,
        "vocab_size": cfg.vocab_size,
        "dim": cfg.dim,
        "num_heads": cfg.num_heads,
        "num_layers": cfg.num_layers,
        "ffn_ratio": cfg.ffn_ratio,
        "is_causal": cfg.is_causal,
        "ls_init_value": cfg.ls_init_value,
        "dropout_prob": cfg.dropout_prob,
    }
    transformer = TextTransformer(**transformer_kwargs)
    logger.info(f"Setting upa custom text transformer {cfg.model_name}")
    return transformer
def build_text_model(
    embed_dim: int,
    backbone_model_config: str,
    freeze_backbone: bool,
    num_head_blocks: int,
    head_blocks_is_causal: bool,
    head_blocks_drop_prob: float,
    tokens_pooler_type: str,
    use_linear_projection: bool,
    backbone: Optional[nn.Module] = None,
):
    """Assemble a ``TextTower`` around a given or freshly built backbone.

    Args:
        embed_dim: Output feature width of the tower's head.
        backbone_model_config: Path passed to ``OmegaConf.load`` to build a
            backbone; only consulted when ``backbone`` is ``None``.
        freeze_backbone: Forwarded to ``TextTower``.
        num_head_blocks: Number of transformer blocks in the head.
        head_blocks_is_causal: ``is_causal`` flag for the head blocks.
        head_blocks_drop_prob: Dropout probability for the head blocks.
        tokens_pooler_type: Pooling strategy name forwarded to ``TextTower``.
        use_linear_projection: Forwarded to ``TextTower``.
        backbone: Ready-made backbone; takes precedence over the config.

    Returns:
        The constructed ``TextTower``.

    Raises:
        RuntimeError: If both ``backbone`` and ``backbone_model_config`` are
            ``None``.
    """
    if backbone is None:
        if backbone_model_config is None:
            raise RuntimeError(
                "Failed to create, text backbone, either backbone or backbone_model_config should be not None"
            )
        # Imported lazily so omegaconf is only needed on this path.
        from omegaconf import OmegaConf

        backbone = build_text_backbone(OmegaConf.load(backbone_model_config))

    return TextTower(
        backbone=backbone,
        freeze_backbone=freeze_backbone,
        embed_dim=embed_dim,
        num_head_blocks=num_head_blocks,
        head_blocks_is_causal=head_blocks_is_causal,
        head_blocks_block_drop_prob=head_blocks_drop_prob,
        tokens_pooler_type=tokens_pooler_type,
        use_linear_projection=use_linear_projection,
    )