# Uploaded by zongzhex, commit 06acd95 ("Add source code").
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import copy
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import partial
from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
Attention,
VisionTransformer,
TextTransformer,
text_global_pool,
lock_text_tower,
to_2tuple,
)
@dataclass
class CLIPVisionCfg:
    """Configuration for the CLIP image tower, consumed by ``_build_vision_tower``.

    Tower selection: a set ``timm_model_name`` builds a timm model; otherwise a tuple/list
    ``layers`` builds a ModifiedResNet; otherwise a VisionTransformer is built.
    """
    layers: Union[Tuple[int, int, int, int], int] = 12  # int: ViT depth; tuple/list: ModifiedResNet per-stage block counts
    width: int = 768
    head_width: int = 64  # ViT heads = width // head_width (ModifiedResNet: width * 32 // head_width)
    mlp_ratio: float = 4.0
    patch_size: Optional[int] = 16  # ViT patch size; None for ResNet configs built from OpenAI state dicts
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # fraction of patches to drop during training (0 disables); 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None  # extra kwargs bound into the activation layer via functools.partial
    norm_kwargs: Optional[dict] = None  # extra kwargs bound into the norm layer via functools.partial (ignored when empty)

    # Custom attention block settings
    block_type: Optional[str] = None  # attention block type ('default', 'custom'), auto-selects 'custom' if any below features enabled
    qk_norm: bool = False  # apply layer norm to q and k in attention
    scaled_cosine_attn: bool = False  # use scaled cosine attention
    scale_heads: bool = False  # learnable head-specific scale applied to attention logits
    scale_attn_inner: bool = False  # apply layer norm on attention context, before output projection
    scale_attn: bool = False  # apply layer norm after full attention block
    scale_fc: bool = False  # apply layer norm in MLP block

    # timm backbone settings (only used when timm_model_name is set)
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
@dataclass
class CLIPTextCfg:
    """Configuration for the CLIP text tower, consumed by ``_build_text_tower``.

    When ``hf_model_name`` is set a HuggingFace text encoder is built (the ``hf_*`` and
    ``special_tokens_to_add`` fields apply); otherwise a built-in TextTransformer is built
    from the remaining fields.
    """
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_mode: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    embed_cls: bool = False
    pad_id: int = 0
    eos_id: int = 2  # only used when pool_type == 'eos'; must match the tokenizer's EOS id
    no_causal_mask: bool = False  # disable causal masking
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'argmax'
    proj_bias: bool = False
    proj_type: str = 'linear'  # control final text projection, 'none' forces no projection
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None  # extra kwargs bound into the activation layer via functools.partial
    norm_kwargs: Optional[dict] = None  # extra kwargs bound into the norm layer via functools.partial (ignored when empty)

    # Custom attention block settings
    block_type: Optional[str] = None  # attention block type ('default', 'custom'), auto-selects 'custom' if any custom features enabled
    qk_norm: bool = False  # apply layer norm to q and k in attention
    scaled_cosine_attn: bool = False  # use scaled cosine attention
    scale_heads: bool = False  # learnable head-specific scale applied to attention logits
    scale_attn_inner: bool = False  # apply layer norm on attention context, before output projection
    scale_attn: bool = False  # apply layer norm after full attention block
    scale_fc: bool = False  # apply layer norm in MLP block

    # HuggingFace specific text tower config
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # pooler passed to HFTextEncoder as pooler_type
    special_tokens_to_add: Optional[dict] = None  # special tokens to add to tokenizer (e.g., for Pythia)
def get_cast_dtype(precision: str):
    """Return the cast dtype for a precision name.

    Only the mixed-precision names ``'bf16'`` and ``'fp16'`` map to a dtype
    (``torch.bfloat16`` / ``torch.float16``); any other value, including the
    ``'pure_*'`` variants, yields ``None``.
    """
    return {'bf16': torch.bfloat16, 'fp16': torch.float16}.get(precision)
def get_input_dtype(precision: str):
    """Return the input dtype for a precision name.

    Both the mixed (``'bf16'`` / ``'fp16'``) and pure (``'pure_bf16'`` / ``'pure_fp16'``)
    names map to the matching half-precision dtype; anything else yields ``None``.
    """
    lookup = {
        'bf16': torch.bfloat16,
        'pure_bf16': torch.bfloat16,
        'fp16': torch.float16,
        'pure_fp16': torch.float16,
    }
    return lookup.get(precision)
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Instantiate the image tower described by ``vision_cfg``.

    Selection order:
      1. ``timm_model_name`` set -> ``TimmModel`` (timm models always use native GELU,
         regardless of ``quick_gelu``);
      2. ``layers`` is a tuple/list -> ``ModifiedResNet``;
      3. otherwise -> ``VisionTransformer``.

    Args:
        embed_dim: Output embedding dimension of the tower.
        vision_cfg: A ``CLIPVisionCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` (ViT branch only).
        cast_dtype: When fp16/bf16, the ViT uses ``LayerNormFp32`` instead of ``LayerNorm``.

    Returns:
        The constructed vision tower module.
    """
    cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

    if cfg.timm_model_name:
        from .timm_model import TimmModel
        return TimmModel(
            cfg.timm_model_name,
            pretrained=cfg.timm_model_pretrained,
            pool=cfg.timm_pool,
            proj=cfg.timm_proj,
            proj_bias=cfg.timm_proj_bias,
            drop=cfg.timm_drop,
            drop_path=cfg.timm_drop_path,
            patch_drop=cfg.patch_dropout if cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=cfg.image_size,
        )

    if isinstance(cfg.layers, (tuple, list)):
        from .modified_resnet import ModifiedResNet
        return ModifiedResNet(
            layers=cfg.layers,
            output_dim=embed_dim,
            heads=cfg.width * 32 // cfg.head_width,
            image_size=cfg.image_size,
            width=cfg.width,
        )

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    activation = QuickGELU if quick_gelu else nn.GELU
    if cfg.act_kwargs is not None:
        activation = partial(activation, **cfg.act_kwargs)

    # Low-precision casts keep LayerNorm computation in fp32.
    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    norm = LayerNormFp32 if low_precision else LayerNorm
    if cfg.norm_kwargs:
        norm = partial(norm, **cfg.norm_kwargs)

    return VisionTransformer(
        image_size=cfg.image_size,
        patch_size=cfg.patch_size,
        width=cfg.width,
        layers=cfg.layers,
        heads=cfg.width // cfg.head_width,
        mlp_ratio=cfg.mlp_ratio,
        ls_init_value=cfg.ls_init_value,
        patch_dropout=cfg.patch_dropout,
        attentional_pool=cfg.attentional_pool,
        attn_pooler_queries=cfg.attn_pooler_queries,
        attn_pooler_heads=cfg.attn_pooler_heads,
        pos_embed_type=cfg.pos_embed_type,
        no_ln_pre=cfg.no_ln_pre,
        final_ln_after_pool=cfg.final_ln_after_pool,
        pool_type=cfg.pool_type,
        output_tokens=cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=activation,
        norm_layer=norm,
        block_type=cfg.block_type,
        qk_norm=cfg.qk_norm,
        scaled_cosine_attn=cfg.scaled_cosine_attn,
        scale_heads=cfg.scale_heads,
        scale_attn_inner=cfg.scale_attn_inner,
        scale_attn=cfg.scale_attn,
        scale_fc=cfg.scale_fc,
    )
def _add_hf_special_tokens(text, text_cfg: CLIPTextCfg, special_tokens_cfg: dict):
    """Add ``special_tokens_cfg`` to an HF text tower's tokenizer vocab and resize its embeddings.

    Sets on ``text``: ``original_vocab_size`` (vocab size before adding tokens),
    ``vocab_size`` / ``config.vocab_size`` (tokenizer size after adding),
    ``coca_cls_token_id`` when ``'<coca_cls>'`` is among ``additional_special_tokens``, and
    ``pad_token_id`` / ``config.pad_token_id`` when a ``pad_token`` is configured.
    """
    from transformers import AutoTokenizer

    # Load tokenizer from local cache only (ensures consistency with get_tokenizer()).
    # get_tokenizer() is called first and downloads/caches; we just reuse that exact version.
    tokenizer = AutoTokenizer.from_pretrained(text_cfg.hf_model_name, local_files_only=True)

    # Store original vocab size before adding new tokens.
    # This is needed to unfreeze new token embeddings after locking.
    original_vocab_size = len(tokenizer)
    text.original_vocab_size = original_vocab_size
    tokenizer.add_special_tokens(special_tokens_cfg)

    # Resize model embeddings to accommodate new tokens.
    # pad_to_multiple_of=64 ensures optimal Tensor Core performance for embedding lookups.
    new_vocab_size = len(tokenizer)
    text.transformer.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64)

    # Store token IDs for use in forward pass.
    for token in special_tokens_cfg.get('additional_special_tokens', ()):
        if token == '<coca_cls>':
            text.coca_cls_token_id = tokenizer.convert_tokens_to_ids(token)
    if 'pad_token' in special_tokens_cfg:
        text.config.pad_token_id = tokenizer.pad_token_id
        text.pad_token_id = tokenizer.pad_token_id

    # NOTE(review): the embedding table may be padded beyond new_vocab_size
    # (pad_to_multiple_of=64); vocab_size records the tokenizer size, not the table size.
    text.config.vocab_size = new_vocab_size
    text.vocab_size = new_vocab_size

    logging.info("Added special tokens to %s:", text_cfg.hf_model_name)
    logging.info(" Original vocab size: %s", original_vocab_size)
    logging.info(" New vocab size: %s", new_vocab_size)
    logging.info(" Added %s new tokens", new_vocab_size - original_vocab_size)
    # These ids are only assigned above when the matching tokens are configured, so read
    # them defensively rather than assuming the encoder pre-defines the attributes.
    coca_cls_token_id = getattr(text, 'coca_cls_token_id', None)
    if coca_cls_token_id is not None:
        logging.info(" CoCa CLS token ID: %s", coca_cls_token_id)
    pad_token_id = getattr(text, 'pad_token_id', None)
    if pad_token_id is not None:
        logging.info(" Pad token ID: %s", pad_token_id)


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text tower described by ``text_cfg``.

    With ``hf_model_name`` set, builds an ``HFTextEncoder`` and, if
    ``special_tokens_to_add`` is configured, extends its vocabulary. Otherwise builds a
    ``TextTransformer``.

    Args:
        embed_dim: Output embedding dimension of the tower.
        text_cfg: A ``CLIPTextCfg`` or a dict of its fields.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` (TextTransformer only).
        cast_dtype: When fp16/bf16, uses ``LayerNormFp32`` (TextTransformer only).

    Returns:
        The constructed text tower module.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        from .hf_model import HFTextEncoder
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj_type=text_cfg.hf_proj_type,
            pooler_type=text_cfg.hf_pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
        # Handle special tokens if configured (e.g., for Pythia)
        special_tokens_cfg = getattr(text_cfg, 'special_tokens_to_add', None)
        if special_tokens_cfg:
            _add_hf_special_tokens(text, text_cfg, special_tokens_cfg)
        return text

    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if text_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
    if text_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **text_cfg.act_kwargs)

    return TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        mlp_ratio=text_cfg.mlp_ratio,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=text_cfg.embed_cls,
        no_causal_mask=text_cfg.no_causal_mask,
        pad_id=text_cfg.pad_id,
        eos_id=text_cfg.eos_id,
        pool_type=text_cfg.pool_type,
        proj_type=text_cfg.proj_type,
        proj_bias=text_cfg.proj_bias,
        output_tokens=text_cfg.output_tokens,
        act_layer=act_layer,
        norm_layer=norm_layer,
        block_type=text_cfg.block_type,
        qk_norm=text_cfg.qk_norm,
        scaled_cosine_attn=text_cfg.scaled_cosine_attn,
        scale_heads=text_cfg.scale_heads,
        scale_attn_inner=text_cfg.scale_attn_inner,
        scale_attn=text_cfg.scale_attn,
        scale_fc=text_cfg.scale_fc,
    )
class CLIP(nn.Module):
    """Dual-encoder CLIP model whose text tower is flattened onto this module.

    The image tower lives under ``self.visual``. The text tower built by ``_build_text_tower``
    is not kept as a submodule; its parts (``transformer``, ``token_embedding``,
    ``positional_embedding``, ``ln_final``, ``text_projection``, ``attn_mask``, pool type and
    EOS id) are copied onto this module and used directly by ``encode_text``.
    NOTE(review): this presumably requires a TextTransformer-style tower; HF text towers are
    handled by CustomTextCLIP — confirm.
    """
    output_dict: torch.jit.Final[bool]  # TorchScript constant: forward() returns a dict when True, else a tuple

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """
        Args:
            embed_dim: Shared embedding dimension of both towers.
            vision_cfg: Image tower config (``CLIPVisionCfg`` or dict).
            text_cfg: Text tower config (``CLIPTextCfg`` or dict).
            quick_gelu: Use QuickGELU activations.
            init_logit_scale: Initial value of the log-space logit scale (``exp`` is applied when used).
            init_logit_bias: Initial logit bias; ``None`` disables the bias parameter.
            nonscalar_logit_scale: Give logit scale/bias shape ``[1]`` instead of a 0-d scalar.
            cast_dtype: Low-precision dtype passed to the tower builders.
            output_dict: Make ``forward`` return a dict instead of a tuple.
        """
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Hoist the text tower's parts onto this module (flat, non-`text.` parameter layout).
        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.context_length = text.context_length
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.text_pool_type = text.pool_type
        self.text_eos_id = text.eos_id
        # persistent=False: the mask is not saved in / loaded from the state dict.
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        # 0-d scalar parameters by default, shape [1] when nonscalar_logit_scale.
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the flattened text tower, leaving the last ``unlocked_layers`` trainable."""
        assert freeze_layer_norm, 'Unfreezing LayerNorm is not supported. LayerNorm treated like other weights.'
        # Calls the module-level lock_text_tower helper imported from .transformer, not this method.
        lock_text_tower(self, unlocked_layers)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable gradient checkpointing in both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return names of parameters to exclude from weight decay."""
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = {'positional_embedding'}
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        """Embed images with the vision tower; L2-normalize along the last dim if ``normalize``."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Embed token ids: embed + position, transformer, final LN, pool, then project.

        ``normalize`` L2-normalizes the result along the last dim.
        """
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x = text_global_pool(x, text, self.text_pool_type, eos_token_id=getattr(self, "text_eos_id", None))
        # text_projection may be an nn.Linear module or a raw projection tensor (matmul).
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                x = self.text_projection(x)
            else:
                x = x @ self.text_projection

        return F.normalize(x, dim=-1) if normalize else x

    def get_logits(self, image, text):
        """Return ``(image_logits, text_logits)``: scaled (and optionally biased) cosine similarities."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit (image tower only here)
            normalize_intermediates: Apply final norm layer to all intermediates
            normalize: L2 Normalize final features
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict of outputs. Image entries are whatever ``self.visual.forward_intermediates``
            returns (``'image_features'`` L2-normalized when ``normalize``); text entries are
            ``'text_intermediates'`` and, unless ``intermediates_only``, ``'text_features'``.
            Optionally ``'image_logits'`` / ``'text_logits'`` and ``'logit_scale'`` / ``'logit_bias'``.
        """
        output = {}
        if intermediates_only:
            # intermediates_only disables final feature normalization and logit output
            normalize = False
            output_logits = False
        if output_logits:
            assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            cast_dtype = self.transformer.get_cast_dtype()
            x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
            x = x + self.positional_embedding.to(cast_dtype)
            # NOTE: stop_early is not forwarded to the text transformer.
            x, intermediates = self.transformer.forward_intermediates(
                x,
                attn_mask=self.attn_mask,
                indices=text_indices
            )
            if normalize_intermediates:
                intermediates = [self.ln_final(xi) for xi in intermediates]

            # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens
            output["text_intermediates"] = intermediates

            if not intermediates_only:
                # Same pooling/projection tail as encode_text.
                x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
                x = text_global_pool(x, text, self.text_pool_type, eos_token_id=getattr(self, "text_eos_id", None))
                if self.text_projection is not None:
                    if isinstance(self.text_projection, nn.Linear):
                        x = self.text_projection(x)
                    else:
                        x = x @ self.text_projection
                if normalize:
                    x = F.normalize(x, dim=-1)
                output["text_features"] = x

        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None

        if output_logits:
            image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
            if self.logit_bias is not None:
                image_logits += self.logit_bias
            text_logits = image_logits.T
            output["image_logits"] = image_logits
            output["text_logits"] = text_logits

        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode and L2-normalize whichever inputs are given (missing ones yield ``None``).

        Returns a dict (``output_dict=True``) or a tuple
        ``(image_features, text_features, logit_scale.exp()[, logit_bias])``.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()
class CustomTextCLIP(nn.Module):
    """Dual-encoder CLIP model that keeps the text tower as a ``self.text`` submodule.

    Unlike ``CLIP``, the text tower (TextTransformer or HFTextEncoder) is used as-is through
    its own ``forward`` / ``lock`` / ``forward_intermediates`` methods.
    """
    output_dict: torch.jit.Final[bool]  # TorchScript constant: forward() returns a dict when True, else a tuple

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """
        Args:
            embed_dim: Shared embedding dimension of both towers.
            vision_cfg: Image tower config (``CLIPVisionCfg`` or dict).
            text_cfg: Text tower config (``CLIPTextCfg`` or dict).
            quick_gelu: Use QuickGELU activations.
            init_logit_scale: Initial value of the log-space logit scale (``exp`` is applied when used).
            init_logit_bias: Initial logit bias; ``None`` disables the bias parameter.
            nonscalar_logit_scale: Give logit scale/bias shape ``[1]`` instead of a 0-d scalar.
            cast_dtype: Low-precision dtype passed to the tower builders.
            output_dict: Make ``forward`` return a dict instead of a tuple.
        """
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        # 0-d scalar parameters by default, shape [1] when nonscalar_logit_scale.
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Delegate freezing to the text tower's own ``lock`` method."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable gradient checkpointing in both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return names of parameters to exclude from weight decay, collected from both towers."""
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = set()
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        if hasattr(self.text, 'no_weight_decay'):
            for n in self.text.no_weight_decay():
                no_wd.add('text.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        """Embed images with the vision tower; L2-normalize along the last dim if ``normalize``."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Embed text with the text tower; L2-normalize along the last dim if ``normalize``."""
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        """Return ``(image_logits, text_logits)``: scaled (and optionally biased) cosine similarities."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize: L2 Normalize final image and text features (if present)
            normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale bias in the output
        Returns:
            Dict merging the outputs of ``self.visual.forward_intermediates`` and
            ``self.text.forward_intermediates`` (``'image_features'`` / ``'text_features'``
            L2-normalized when ``normalize`` and present), plus optionally ``'image_logits'`` /
            ``'text_logits'`` and ``'logit_scale'`` / ``'logit_bias'``.
        """
        output = {}
        if intermediates_only:
            # intermediates_only disables final feature normalization and logit output
            normalize = False
            output_logits = False
        if output_logits:
            assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_output = self.text.forward_intermediates(
                text,
                indices=text_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=text_output_fmt,
                output_extra_tokens=text_output_extra_tokens,
            )
            if normalize and "text_features" in text_output:
                text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
            output.update(text_output)

        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None

        if output_logits:
            image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
            if self.logit_bias is not None:
                image_logits += self.logit_bias
            text_logits = image_logits.T
            output["image_logits"] = image_logits
            output["text_logits"] = text_logits

        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode and L2-normalize whichever inputs are given (missing ones yield ``None``).

        Returns a dict (``output_dict=True``) or a tuple
        ``(image_features, text_features, logit_scale.exp()[, logit_bias])``.
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16), in place.

    Visits every submodule via ``model.apply`` and casts, by reassigning ``.data``:
      * weight and bias of ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Linear``;
      * the input-projection weights/biases and ``bias_k`` / ``bias_v`` of
        ``nn.MultiheadAttention`` / ``Attention`` (whichever exist);
      * the raw-tensor ``text_projection`` of ``CLIP`` / ``TextTransformer`` and ``proj`` of
        ``VisionTransformer``.
    Modules of other types (e.g. norm layers, embeddings) are left untouched.

    Args:
        model: Module to convert in place.
        dtype: Target dtype, e.g. ``torch.float16`` or ``torch.bfloat16``.
    """
    attn_attrs = (
        "in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
        "in_proj_bias", "bias_k", "bias_v",
    )

    def _cast_tensor(tensor):
        # Projections may be a raw nn.Parameter, None, or a submodule such as nn.Linear.
        # Only tensors are cast here: an nn.Linear has no `.data` attribute, and it is
        # converted by the Linear branch when model.apply visits it.
        if isinstance(tensor, torch.Tensor):
            tensor.data = tensor.data.to(dtype)

    def _convert_weights(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.to(dtype)
            if module.bias is not None:
                module.bias.data = module.bias.data.to(dtype)

        if isinstance(module, (nn.MultiheadAttention, Attention)):
            for attr in attn_attrs:
                _cast_tensor(getattr(module, attr, None))

        if isinstance(module, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            _cast_tensor(getattr(module, "text_projection", None))

        if isinstance(module, VisionTransformer):
            # convert vision nn.Parameter projections
            _cast_tensor(getattr(module, "proj", None))

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    """Re-key a flat-layout (``CLIP``) state dict into the nested ``text.`` layout (``CustomTextCLIP``).

    A state dict with a top-level ``'text_projection'`` key is treated as the old format: every
    key starting with a text-tower prefix gets a ``'text.'`` prefix in a new dict (insertion
    order preserved, input not modified). Any other state dict is returned as the same object.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + key if key.startswith(text_prefixes) else key): value
        for key, value in state_dict.items()
    }
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Build and load a CLIP model from an OpenAI-format state dict, inferring sizes from tensor shapes.

    A ViT image tower is assumed when ``visual.proj`` is present, otherwise a ModifiedResNet.
    The text tower's width, heads (``width // 64``), depth, context length and vocab size are
    read from the text weights.

    Args:
        state_dict: OpenAI CLIP state dict. It is not modified: the non-parameter entries
            (``input_resolution``, ``context_length``, ``vocab_size``) are dropped from a
            shallow copy before loading.
        quick_gelu: Use QuickGELU activations (OpenAI models were trained with QuickGELU).
        cast_dtype: Passed through to ``CLIP``.

    Returns:
        The loaded model in eval mode, after ``convert_weights_to_fp16``.
    """
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = sum(
            1 for k in state_dict if k.startswith("visual.") and k.endswith(".attn.in_proj_weight"))
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # Number of distinct residual blocks in each of the four ResNet stages.
        vision_layers = tuple(
            len({k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}")}) for b in (1, 2, 3, 4))
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len({k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")})

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # Drop non-parameter entries from a shallow copy so the caller's dict is not mutated.
    # copy.copy keeps the mapping type (and an OrderedDict's instance attributes).
    state_dict = copy.copy(state_dict)
    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """TorchScript-trace ``forward``, ``encode_text`` and ``encode_image`` of a CLIP-style model.

    Example inputs are an all-ones image batch ``(batch_size, 3, H, W)`` sized from
    ``model.visual.image_size`` (an int or an ``(H, W)`` pair) and an all-zero int token
    batch ``(batch_size, model.context_length)``.

    Args:
        model: Model to trace; it is switched to eval mode in place.
        batch_size: Batch size of the example inputs.
        device: Device the example inputs are created on.

    Returns:
        The traced module, with ``visual.image_size`` set to the original value.
    """
    model.eval()
    image_size = model.visual.image_size
    # The vision tower may store image_size as a single int or as an (H, W) pair.
    if isinstance(image_size, (tuple, list)):
        height, width = image_size
    else:
        height = width = image_size
    example_images = torch.ones((batch_size, 3, height, width), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    # Re-assign the original image_size on the traced module's vision tower.
    model.visual.image_size = image_size
    return model
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Resample ``visual.positional_embedding`` in ``state_dict`` to the model's patch grid, in place.

    Does nothing when the checkpoint has no such entry, the vision tower has no ``grid_size``,
    or the sequence length already matches. Otherwise the class-token row is kept and the
    remaining rows (assumed to form a square grid) are resized with ``F.interpolate``.

    Args:
        state_dict: Checkpoint state dict; modified in place.
        model: Model whose ``visual.grid_size`` is the target grid.
        interpolation: ``F.interpolate`` mode.
        antialias: Passed to ``F.interpolate``.
    """
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return

    grid_size = to_2tuple(model.visual.grid_size)
    num_prefix = 1  # FIXME detect different token configs (ie no class token, or more)
    if grid_size[0] * grid_size[1] + num_prefix == old_pos_embed.shape[0]:
        return

    if num_prefix:
        prefix_embed = old_pos_embed[:num_prefix]
        grid_embed = old_pos_embed[num_prefix:]
    else:
        prefix_embed, grid_embed = None, old_pos_embed

    old_grid_size = to_2tuple(int(math.sqrt(len(grid_embed))))
    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)

    # (N, C) -> (1, C, H, W) so the grid can be resized like an image, then back to (N', C).
    grid_embed = grid_embed.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    grid_embed = F.interpolate(
        grid_embed,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    grid_embed = grid_embed.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]

    if prefix_embed is None:
        state_dict['visual.positional_embedding'] = grid_embed
    else:
        state_dict['visual.positional_embedding'] = torch.cat([prefix_embed, grid_embed], dim=0)
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    """Resample the text positional embedding in ``state_dict`` to the model's length, in place.

    Uses ``'positional_embedding'`` (flat CLIP layout) if present, else
    ``'text.positional_embedding'`` (nested ``text.`` layout). When its number of positions
    differs from the model's, it is resized along the position axis with ``F.interpolate``
    and written back under the same key. Nothing happens when either embedding is missing.

    Args:
        state_dict: Checkpoint state dict; modified in place.
        model: Model providing ``positional_embedding`` on itself or on ``model.text``.
        interpolation: ``F.interpolate`` mode for the 1-D resample.
        antialias: Passed to ``F.interpolate``.

    Raises:
        AssertionError: If the embedding width differs between checkpoint and model.
    """
    pos_embed_key = 'positional_embedding' if 'positional_embedding' in state_dict else 'text.positional_embedding'
    old_pos_embed = state_dict.get(pos_embed_key, None)
    if old_pos_embed is None:
        return

    # FIXME add support for text cls_token
    model_pos_embed = getattr(model, 'positional_embedding', None)
    if model_pos_embed is None:
        model_pos_embed = getattr(getattr(model, 'text', None), 'positional_embedding', None)
    if model_pos_embed is None:
        # The model exposes no positional_embedding (neither directly nor on .text), so there
        # is no target length; leave the checkpoint untouched.
        return

    old_num_pos, old_width = old_pos_embed.shape[0], old_pos_embed.shape[1]
    num_pos, width = model_pos_embed.shape[0], model_pos_embed.shape[1]
    assert old_width == width, 'text pos_embed width changed!'
    if old_num_pos == num_pos:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
    # (num_pos, width) -> (1, width, num_pos): interpolate along the position axis.
    resized = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
    resized = F.interpolate(
        resized,
        size=num_pos,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    state_dict[pos_embed_key] = resized.permute(0, 2, 1)[0]
def get_model_preprocess_cfg(model):
    """Return the image preprocessing config of ``model`` (or of ``model.visual`` if present).

    A non-empty ``preprocess_cfg`` dict stored on the module is returned as-is. Otherwise a
    new dict is built from the legacy ``image_size`` / ``image_mean`` / ``image_std``
    attributes, keyed ``'size'`` / ``'mean'`` / ``'std'``; missing or ``None`` attributes are
    skipped.
    """
    module = getattr(model, 'visual', model)
    preprocess_cfg = getattr(module, 'preprocess_cfg', {})
    if preprocess_cfg:
        return preprocess_cfg

    # Use separate legacy attributes if the preprocess_cfg dict is not found. Build a fresh
    # dict so an empty (or None) preprocess_cfg stored on the module is never mutated.
    legacy_cfg = {}
    for attr, key in (('image_size', 'size'), ('image_mean', 'mean'), ('image_std', 'std')):
        value = getattr(module, attr, None)
        if value is not None:
            legacy_cfg[key] = value
    return legacy_cfg
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    """Attach ``preprocess_cfg`` to ``model.visual`` (or to ``model`` when it has no ``visual``).

    Sets the legacy ``image_mean`` / ``image_std`` attributes from the ``'mean'`` / ``'std'``
    entries (raising ``KeyError`` if either is missing) and stores a deep copy of the whole
    dict as ``preprocess_cfg``.
    """
    target = getattr(model, 'visual', model)
    # legacy attributes, keeping for bwd compat
    for attr_name, cfg_key in (('image_mean', 'mean'), ('image_std', 'std')):
        setattr(target, attr_name, preprocess_cfg[cfg_key])
    # new attr, package all pp cfg as dict
    target.preprocess_cfg = copy.deepcopy(preprocess_cfg)
def get_model_tokenize_cfg(model):
    """Return tokenizer-related settings of ``model.text`` (or of ``model`` when it has no ``text``).

    The result contains ``'context_length'`` and/or ``'vocab_size'``, each only when the
    attribute exists and is not ``None``.
    """
    text_module = getattr(model, 'text', model)
    tokenize_cfg = {}
    for name in ('context_length', 'vocab_size'):
        value = getattr(text_module, name, None)
        if value is not None:
            tokenize_cfg[name] = value
    return tokenize_cfg