import torch
from torch import nn, Tensor

from .blocks import LayerNorm, Transformer


class CLIPTextEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
    ) -> None:
        super().__init__()
        self.context_length = context_length
        self.vocab_size = vocab_size

        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )
        self.ln_final = LayerNorm(transformer_width)

        # positional_embedding and text_projection are allocated with torch.empty;
        # they are expected to be initialized (or loaded from a pretrained
        # checkpoint) by the caller before use.
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

    def build_attention_mask(self):
        # PyTorch attention masks are additive: fill the mask with -inf, then
        # zero out the lower triangle (including the diagonal) so each token
        # can only attend to itself and earlier positions (a causal mask).
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        # Infer the working dtype (e.g. float16 vs float32) from the transformer weights.
        return self.transformer.resblocks[0].attn.in_proj_weight.dtype

    def forward(self, text: Tensor) -> Tensor:
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # Take the features at the EOT position (the highest token id in each
        # sequence) and project them into the joint embedding space.
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
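
# A minimal smoke-test sketch (not part of the original module); the
# hyperparameters below are illustrative, roughly matching the CLIP ViT-B/32
# text encoder. Because of the relative import above, run it with
# `python -m <package>.<module>` rather than as a standalone script.
if __name__ == "__main__":
    encoder = CLIPTextEncoder(
        embed_dim=512,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
    )
    # The parameters created with torch.empty must be filled before the output
    # is meaningful; random values are enough for a shape check.
    nn.init.normal_(encoder.positional_embedding, std=0.01)
    nn.init.normal_(encoder.text_projection, std=512 ** -0.5)

    tokens = torch.randint(0, encoder.vocab_size, (2, encoder.context_length))  # dummy tokenized text
    with torch.no_grad():
        features = encoder(tokens)
    print(features.shape)  # expected: torch.Size([2, 512])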