""" CLIP Model Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ import copy import logging import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint from functools import partial from .transformer import ( LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer, text_global_pool, lock_text_tower, to_2tuple, ) @dataclass class CLIPVisionCfg: layers: Union[Tuple[int, int, int, int], int] = 12 width: int = 768 head_width: int = 64 mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) attn_pooler_queries: int = 256 # n_queries for attentional pooler attn_pooler_heads: int = 8 # n heads for attentional_pooling no_ln_pre: bool = False # disable pre transformer LayerNorm pos_embed_type: str = 'learnable' final_ln_after_pool: bool = False # apply final LayerNorm after pooling pool_type: str = 'tok' output_tokens: bool = False act_kwargs: Optional[dict] = None norm_kwargs: Optional[dict] = None # Custom attention block settings block_type: Optional[str] = None # attention block type ('default', 'custom'), auto-selects 'custom' if any below features enabled qk_norm: bool = False # apply layer norm to q and k in attention scaled_cosine_attn: bool = False # use scaled cosine attention scale_heads: bool = False # learnable head-specific scale applied to attention logits scale_attn_inner: bool = False # apply layer norm on attention context, before output projection scale_attn: bool = False # apply layer norm after full attention block scale_fc: bool = False # apply layer norm in MLP block timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection timm_drop: float = 0. # head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth @dataclass class CLIPTextCfg: context_length: int = 77 vocab_size: int = 49408 hf_tokenizer_name: Optional[str] = None tokenizer_mode: Optional[str] = None tokenizer_kwargs: Optional[dict] = None width: int = 512 heads: int = 8 layers: int = 12 mlp_ratio: float = 4.0 ls_init_value: Optional[float] = None # layer scale initial value embed_cls: bool = False pad_id: int = 0 eos_id: int = 2 # only used for when pool_type == 'eos', must match tokenizer eos no_causal_mask: bool = False # disable causal masking final_ln_after_pool: bool = False # apply final LayerNorm after pooling pool_type: str = 'argmax' proj_bias: bool = False proj_type: str = 'linear' # control final text projection, 'none' forces no projection output_tokens: bool = False act_kwargs: dict = None norm_kwargs: dict = None # Custom attention block settings block_type: Optional[str] = None # attention block type ('default', 'custom'), auto-selects 'custom' if any custom features enabled qk_norm: bool = False # apply layer norm to q and k in attention scaled_cosine_attn: bool = False # use scaled cosine attention scale_heads: bool = False # learnable head-specific scale applied to attention logits scale_attn_inner: bool = False # apply layer norm on attention context, before output projection scale_attn: bool = False # apply layer norm after full attention block scale_fc: bool = False # apply layer norm in MLP block # HuggingFace specific text tower config hf_model_name: Optional[str] = None hf_model_pretrained: bool = True hf_proj_type: str = 'mlp' hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models special_tokens_to_add: Optional[dict] = None # special tokens to add to tokenizer (e.g., for Pythia) def get_cast_dtype(precision: str): cast_dtype = None if precision == 'bf16': cast_dtype = torch.bfloat16 elif precision == 'fp16': cast_dtype = torch.float16 return cast_dtype def get_input_dtype(precision: str): input_dtype = None if precision in ('bf16', 'pure_bf16'): input_dtype = torch.bfloat16 elif precision in ('fp16', 'pure_fp16'): input_dtype = torch.float16 return input_dtype def _build_vision_tower( embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more # memory efficient in recent PyTorch releases (>= 1.10). # NOTE: timm models always use native GELU regardless of quick_gelu flag. act_layer = QuickGELU if quick_gelu else nn.GELU if vision_cfg.timm_model_name: from .timm_model import TimmModel visual = TimmModel( vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, drop=vision_cfg.timm_drop, drop_path=vision_cfg.timm_drop_path, patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, embed_dim=embed_dim, image_size=vision_cfg.image_size, ) elif isinstance(vision_cfg.layers, (tuple, list)): from .modified_resnet import ModifiedResNet vision_heads = vision_cfg.width * 32 // vision_cfg.head_width visual = ModifiedResNet( layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width, ) else: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm if vision_cfg.norm_kwargs: norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) if vision_cfg.act_kwargs is not None: act_layer = partial(act_layer, **vision_cfg.act_kwargs) visual = VisionTransformer( image_size=vision_cfg.image_size, patch_size=vision_cfg.patch_size, width=vision_cfg.width, layers=vision_cfg.layers, heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, attentional_pool=vision_cfg.attentional_pool, attn_pooler_queries=vision_cfg.attn_pooler_queries, attn_pooler_heads=vision_cfg.attn_pooler_heads, pos_embed_type=vision_cfg.pos_embed_type, no_ln_pre=vision_cfg.no_ln_pre, final_ln_after_pool=vision_cfg.final_ln_after_pool, pool_type=vision_cfg.pool_type, output_tokens=vision_cfg.output_tokens, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, block_type=vision_cfg.block_type, qk_norm=vision_cfg.qk_norm, scaled_cosine_attn=vision_cfg.scaled_cosine_attn, scale_heads=vision_cfg.scale_heads, scale_attn_inner=vision_cfg.scale_attn_inner, scale_attn=vision_cfg.scale_attn, scale_fc=vision_cfg.scale_fc, ) return visual def _build_text_tower( embed_dim: int, text_cfg: CLIPTextCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) if text_cfg.hf_model_name: from .hf_model import HFTextEncoder text = HFTextEncoder( text_cfg.hf_model_name, output_dim=embed_dim, proj_type=text_cfg.hf_proj_type, pooler_type=text_cfg.hf_pooler_type, pretrained=text_cfg.hf_model_pretrained, output_tokens=text_cfg.output_tokens, ) # Handle special tokens if configured (e.g., for Pythia) special_tokens_cfg = getattr(text_cfg, 'special_tokens_to_add', None) if special_tokens_cfg: from transformers import AutoTokenizer import logging # Load tokenizer from local cache only (ensures consistency with get_tokenizer()) # get_tokenizer() is called first and downloads/caches, we just reuse that exact version tokenizer = AutoTokenizer.from_pretrained( text_cfg.hf_model_name, local_files_only=True ) # Store original vocab size before adding new tokens # This is needed to unfreeze new token embeddings after locking original_vocab_size = len(tokenizer) text.original_vocab_size = original_vocab_size tokenizer.add_special_tokens(special_tokens_cfg) # Resize model embeddings to accommodate new tokens # pad_to_multiple_of=64 ensures optimal Tensor Core performance for embedding lookups new_vocab_size = len(tokenizer) text.transformer.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64) # Store token IDs for use in forward pass if 'additional_special_tokens' in special_tokens_cfg: for token in special_tokens_cfg['additional_special_tokens']: if token == '': text.coca_cls_token_id = tokenizer.convert_tokens_to_ids(token) if 'pad_token' in special_tokens_cfg: text.config.pad_token_id = tokenizer.pad_token_id text.pad_token_id = tokenizer.pad_token_id text.config.vocab_size = new_vocab_size text.vocab_size = new_vocab_size logging.info(f"Added special tokens to {text_cfg.hf_model_name}:") logging.info(f" Original vocab size: {original_vocab_size}") logging.info(f" New vocab size: {new_vocab_size}") logging.info(f" Added {new_vocab_size - original_vocab_size} new tokens") if text.coca_cls_token_id is not None: logging.info(f" CoCa CLS token ID: {text.coca_cls_token_id}") if text.pad_token_id is not None: logging.info(f" Pad token ID: {text.pad_token_id}") else: act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm if text_cfg.norm_kwargs: norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) if text_cfg.act_kwargs is not None: act_layer = partial(act_layer, **text_cfg.act_kwargs) text = TextTransformer( context_length=text_cfg.context_length, vocab_size=text_cfg.vocab_size, width=text_cfg.width, heads=text_cfg.heads, layers=text_cfg.layers, mlp_ratio=text_cfg.mlp_ratio, ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, embed_cls=text_cfg.embed_cls, no_causal_mask=text_cfg.no_causal_mask, pad_id=text_cfg.pad_id, eos_id=text_cfg.eos_id, pool_type=text_cfg.pool_type, proj_type=text_cfg.proj_type, proj_bias=text_cfg.proj_bias, output_tokens=text_cfg.output_tokens, act_layer=act_layer, norm_layer=norm_layer, block_type=text_cfg.block_type, qk_norm=text_cfg.qk_norm, scaled_cosine_attn=text_cfg.scaled_cosine_attn, scale_heads=text_cfg.scale_heads, scale_attn_inner=text_cfg.scale_attn_inner, scale_attn=text_cfg.scale_attn, scale_fc=text_cfg.scale_fc, ) return text class CLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, nonscalar_logit_scale: bool = False, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): super().__init__() self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.transformer = text.transformer self.context_length = text.context_length self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection self.text_pool_type = text.pool_type self.text_eos_id = text.eos_id self.register_buffer('attn_mask', text.attn_mask, persistent=False) lshape = [1] if nonscalar_logit_scale else [] self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) if init_logit_bias is not None: self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) else: self.logit_bias = None def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): assert freeze_layer_norm, 'Unfreezing LayerNorm is not supported. LayerNorm treated like other weights.' lock_text_tower(self, unlocked_layers) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.transformer.grad_checkpointing = enable @torch.jit.ignore def no_weight_decay(self): # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default no_wd = {'positional_embedding'} if hasattr(self.visual, 'no_weight_decay'): for n in self.visual.no_weight_decay(): no_wd.add('visual.' + n) return no_wd def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x = self.transformer(x, attn_mask=self.attn_mask) x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] x = text_global_pool(x, text, self.text_pool_type, eos_token_id=getattr(self, "text_eos_id", None)) if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): x = self.text_projection(x) else: x = x @ self.text_projection return F.normalize(x, dim=-1) if normalize else x def get_logits(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) image_logits = self.logit_scale.exp() * image_features @ text_features.T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T return image_logits, text_logits def forward_intermediates( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, image_indices: Optional[Union[int, List[int]]] = None, text_indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, normalize: bool = True, normalize_intermediates: bool = False, intermediates_only: bool = False, image_output_fmt: str = 'NCHW', image_output_extra_tokens: bool = False, text_output_fmt: str = 'NLC', text_output_extra_tokens: bool = False, output_logits: bool = False, output_logit_scale_bias: bool = False, ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: image: Input image tensor text: Input text tensor image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence text_indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit normalize_intermediates: Apply final norm layer to all intermediates normalize: L2 Normalize final features intermediates_only: Only return intermediate features, do not return final features image_output_fmt: Shape of intermediate image feature outputs image_output_extra_tokens: Return both prefix and spatial intermediate tokens text_output_fmt: Shape of intermediate text feature outputs (ignored for this model) text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model) output_logits: Include logits in output output_logit_scale_bias: Include the logit scale bias in the output Returns: """ output = {} if intermediates_only: # intermediates only disables final feature normalization, and include logits normalize = False output_logits = False if output_logits: assert image is not None and text is not None, 'Both image and text inputs are required to compute logits' if image is not None: image_output = self.visual.forward_intermediates( image, indices=image_indices, stop_early=stop_early, normalize_intermediates=normalize_intermediates, intermediates_only=intermediates_only, output_fmt=image_output_fmt, output_extra_tokens=image_output_extra_tokens, ) if normalize and "image_features" in image_output: image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1) output.update(image_output) if text is not None: cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x, intermediates = self.transformer.forward_intermediates( x, attn_mask=self.attn_mask, indices=text_indices ) if normalize_intermediates: intermediates = [self.ln_final(xi) for xi in intermediates] # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens output["text_intermediates"] = intermediates if not intermediates_only: x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] x = text_global_pool(x, text, self.text_pool_type, eos_token_id=getattr(self, "text_eos_id", None)) if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): x = self.text_projection(x) else: x = x @ self.text_projection if normalize: x = F.normalize(x, dim=-1) output["text_features"] = x logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None if output_logits: image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T output["image_logits"] = image_logits output["text_logits"] = text_logits if output_logit_scale_bias: output["logit_scale"] = logit_scale_exp if self.logit_bias is not None: output['logit_bias'] = self.logit_bias return output def forward( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, ): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None if self.output_dict: out_dict = { "image_features": image_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } if self.logit_bias is not None: out_dict['logit_bias'] = self.logit_bias return out_dict if self.logit_bias is not None: return image_features, text_features, self.logit_scale.exp(), self.logit_bias return image_features, text_features, self.logit_scale.exp() class CustomTextCLIP(nn.Module): output_dict: torch.jit.Final[bool] def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, nonscalar_logit_scale: bool = False, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): super().__init__() self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.context_length = self.text.context_length self.vocab_size = self.text.vocab_size lshape = [1] if nonscalar_logit_scale else [] self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) if init_logit_bias is not None: self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) else: self.logit_bias = None def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): self.text.lock(unlocked_layers, freeze_layer_norm) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) @torch.jit.ignore def no_weight_decay(self): # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default no_wd = set() if hasattr(self.visual, 'no_weight_decay'): for n in self.visual.no_weight_decay(): no_wd.add('visual.' + n) if hasattr(self.text, 'no_weight_decay'): for n in self.text.no_weight_decay(): no_wd.add('text.' + n) return no_wd def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): features = self.text(text) return F.normalize(features, dim=-1) if normalize else features def get_logits(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) image_logits = self.logit_scale.exp() * image_features @ text_features.T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T return image_logits, text_logits def forward_intermediates( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, image_indices: Optional[Union[int, List[int]]] = None, text_indices: Optional[Union[int, List[int]]] = None, stop_early: bool = False, normalize: bool = True, normalize_intermediates: bool = False, intermediates_only: bool = False, image_output_fmt: str = 'NCHW', image_output_extra_tokens: bool = False, text_output_fmt: str = 'NLC', text_output_extra_tokens: bool = False, output_logits: bool = False, output_logit_scale_bias: bool = False, ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. Args: image: Input image tensor text: Input text tensor image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence text_indices: Take last n blocks if int, all if None, select matching indices if sequence stop_early: Stop iterating over blocks when last desired intermediate hit normalize: L2 Normalize final image and text features (if present) normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible) intermediates_only: Only return intermediate features, do not return final features image_output_fmt: Shape of intermediate image feature outputs image_output_extra_tokens: Return both prefix and spatial intermediate tokens text_output_fmt: Shape of intermediate text feature outputs text_output_extra_tokens: Return both prefix and spatial intermediate tokens output_logits: Include logits in output output_logit_scale_bias: Include the logit scale bias in the output Returns: """ output = {} if intermediates_only: # intermediates only disables final feature normalization, and include logits normalize = False output_logits = False if output_logits: assert image is not None and text is not None, 'Both image and text inputs are required to compute logits' if image is not None: image_output = self.visual.forward_intermediates( image, indices=image_indices, stop_early=stop_early, normalize_intermediates=normalize_intermediates, intermediates_only=intermediates_only, output_fmt=image_output_fmt, output_extra_tokens=image_output_extra_tokens, ) if normalize and "image_features" in image_output: image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1) output.update(image_output) if text is not None: text_output = self.text.forward_intermediates( text, indices=text_indices, stop_early=stop_early, normalize_intermediates=normalize_intermediates, intermediates_only=intermediates_only, output_fmt=text_output_fmt, output_extra_tokens=text_output_extra_tokens, ) if normalize and "text_features" in text_output: text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1) output.update(text_output) logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None if output_logits: image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T output["image_logits"] = image_logits output["text_logits"] = text_logits if output_logit_scale_bias: output["logit_scale"] = logit_scale_exp if self.logit_bias is not None: output['logit_bias'] = self.logit_bias return output def forward( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, ): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None if self.output_dict: out_dict = { "image_features": image_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } if self.logit_bias is not None: out_dict['logit_bias'] = self.logit_bias return out_dict if self.logit_bias is not None: return image_features, text_features, self.logit_scale.exp(), self.logit_bias return image_features, text_features, self.logit_scale.exp() def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): """Convert applicable model parameters to low-precision (bf16 or fp16)""" def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.to(dtype) if l.bias is not None: l.bias.data = l.bias.data.to(dtype) if isinstance(l, (nn.MultiheadAttention, Attention)): for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr, None) if tensor is not None: tensor.data = tensor.data.to(dtype) if isinstance(l, (CLIP, TextTransformer)): # convert text nn.Parameter projections attr = getattr(l, "text_projection", None) if attr is not None: attr.data = attr.data.to(dtype) if isinstance(l, VisionTransformer): # convert vision nn.Parameter projections attr = getattr(l, "proj", None) if attr is not None: attr.data = attr.data.to(dtype) model.apply(_convert_weights) convert_weights_to_fp16 = convert_weights_to_lp # backwards compat # used to maintain checkpoint compatibility def convert_to_custom_text_state_dict(state_dict: dict): if 'text_projection' in state_dict: # old format state_dict, move text tower -> .text new_state_dict = {} for k, v in state_dict.items(): if any(k.startswith(p) for p in ( 'text_projection', 'positional_embedding', 'token_embedding', 'transformer', 'ln_final', )): k = 'text.' + k new_state_dict[k] = v return new_state_dict return state_dict def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, cast_dtype=torch.float16, ): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len( [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_size = vision_patch_size * grid_size else: counts: list = [ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_size = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) vision_cfg = CLIPVisionCfg( layers=vision_layers, width=vision_width, patch_size=vision_patch_size, image_size=image_size, ) text_cfg = CLIPTextCfg( context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers, ) model = CLIP( embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg, quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU cast_dtype=cast_dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 model.load_state_dict(state_dict) return model.eval() def trace_model(model, batch_size=256, device=torch.device('cpu')): model.eval() image_size = model.visual.image_size example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) model = torch.jit.trace_module( model, inputs=dict( forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,) )) model.visual.image_size = image_size return model def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): # Rescale the grid of position embeddings when loading from state_dict old_pos_embed = state_dict.get('visual.positional_embedding', None) if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): return grid_size = to_2tuple(model.visual.grid_size) extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) new_seq_len = grid_size[0] * grid_size[1] + extra_tokens if new_seq_len == old_pos_embed.shape[0]: return if extra_tokens: pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] else: pos_emb_tok, pos_emb_img = None, old_pos_embed old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) pos_emb_img = F.interpolate( pos_emb_img, size=grid_size, mode=interpolation, antialias=antialias, align_corners=False, ) pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] if pos_emb_tok is not None: new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) else: new_pos_embed = pos_emb_img state_dict['visual.positional_embedding'] = new_pos_embed def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False): pos_embed_key = 'positional_embedding' if 'positional_embedding' in state_dict else 'text.positional_embedding' old_pos_embed = state_dict.get(pos_embed_key, None) if old_pos_embed is None: return # FIXME add support for text cls_token model_pos_embed = getattr(model, 'positional_embedding', None) if model_pos_embed is None: model_pos_embed = getattr(model.text, 'positional_embedding', None) old_num_pos = old_pos_embed.shape[0] old_width = old_pos_embed.shape[1] num_pos = model_pos_embed.shape[0] width = model_pos_embed.shape[1] assert old_width == width, 'text pos_embed width changed!' if old_num_pos == num_pos: return logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos) old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1) old_pos_embed = F.interpolate( old_pos_embed, size=num_pos, mode=interpolation, antialias=antialias, align_corners=False, ) old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] new_pos_embed = old_pos_embed state_dict[pos_embed_key] = new_pos_embed def get_model_preprocess_cfg(model): module = getattr(model, 'visual', model) preprocess_cfg = getattr(module, 'preprocess_cfg', {}) if not preprocess_cfg: # use separate legacy attributes if preprocess_cfg dict not found size = getattr(module, 'image_size') if size is not None: preprocess_cfg['size'] = size mean = getattr(module, 'image_mean', None) if mean is not None: preprocess_cfg['mean'] = mean std = getattr(module, 'image_std', None) if std is not None: preprocess_cfg['std'] = std return preprocess_cfg def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): module = getattr(model, 'visual', model) module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict def get_model_tokenize_cfg(model): module = getattr(model, 'text', model) cfg = {} context_length = getattr(module, 'context_length', None) if context_length is not None: cfg['context_length'] = context_length vocab_size = getattr(module, 'vocab_size', None) if vocab_size is not None: cfg['vocab_size'] = vocab_size return cfg