| """ CLIP Model | |
| Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. | |
| """ | |
| import copy | |
| import logging | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from torch.utils.checkpoint import checkpoint | |
| from functools import partial | |
| from .hf_model import HFTextEncoder | |
| from .modified_resnet import ModifiedResNet | |
| from .timm_model import TimmModel | |
| from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\ | |
| text_global_pool | |
| from .utils import to_2tuple | |
| from torchvision.ops import roi_align | |

@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # fraction of patches to drop during training (0 disables; 0.5-0.75 recommended in the paper for optimal results)
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias for final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth

@dataclass
class CLIPTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    embed_cls: bool = False
    pad_id: int = 0
    no_causal_mask: bool = False  # disable causal masking
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'argmax'
    proj_bias: bool = False
    proj_type: str = 'linear'  # control final text projection, 'none' forces no projection
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    # HuggingFace specific text tower config
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # attentional pooling for HF models

def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def get_input_dtype(precision: str):
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype
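

# Illustrative note (not part of the original module): the two helpers above map a precision
# string to the two dtypes used in different places, the norm/weight cast dtype and the input
# tensor dtype. A minimal sketch of how a training loop might consume them, assuming an
# `args.precision` string such as 'amp', 'fp16', or 'pure_bf16':
#
#   cast_dtype = get_cast_dtype(args.precision)    # dtype the towers are built/cast for
#   input_dtype = get_input_dtype(args.precision)  # dtype input image tensors are cast to
#   if input_dtype is not None:
#       images = images.to(dtype=input_dtype)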


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if vision_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
        if vision_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **vision_cfg.act_kwargs)

        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            attentional_pool=vision_cfg.attentional_pool,
            attn_pooler_queries=vision_cfg.attn_pooler_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            pos_embed_type=vision_cfg.pos_embed_type,
            no_ln_pre=vision_cfg.no_ln_pre,
            final_ln_after_pool=vision_cfg.final_ln_after_pool,
            pool_type=vision_cfg.pool_type,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj_type=text_cfg.hf_proj_type,
            pooler_type=text_cfg.hf_pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if text_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
        if text_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **text_cfg.act_kwargs)

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            mlp_ratio=text_cfg.mlp_ratio,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            no_causal_mask=text_cfg.no_causal_mask,
            pad_id=text_cfg.pad_id,
            pool_type=text_cfg.pool_type,
            proj_type=text_cfg.proj_type,
            proj_bias=text_cfg.proj_bias,
            output_tokens=text_cfg.output_tokens,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
    return text


class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
            long_clip: str = "disable",
    ):
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        self.long_clip = long_clip
        if long_clip != "disable":
            # Long-CLIP variants extend the text context length from 77 to 248 tokens
            # (text_cfg is expected to be a dict at this point).
            text_cfg['context_length'] = 248

        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.context_length = text.context_length
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.text_pool_type = text.pool_type
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        if long_clip != "disable":
            if long_clip == "load_from_scratch":
                self.positional_embedding = nn.Parameter(torch.empty(248, text.width))
                self.positional_embedding_res = nn.Parameter(torch.empty(248, text.width))
            elif long_clip == "load_from_clip":
                # keep the 77-token embedding; it is stretched later via load_from_pretrained_short_pe()
                self.positional_embedding = nn.Parameter(torch.empty(77, text.width))
            else:
                raise ValueError('Incorrect parameter for long_clip.')
            # mask1 selects the first 20 (kept) positions, mask2 the interpolated remainder
            self.register_buffer("mask1", torch.zeros(248, 1))
            self.mask1[:20, :] = 1
            self.register_buffer("mask2", torch.zeros(248, 1))
            self.mask2[20:, :] = 1

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def no_weight_decay(self):
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = {'positional_embedding'}
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        if self.long_clip == "disable":
            x = x + self.positional_embedding.to(cast_dtype)
        else:
            # blend the frozen (stretched) embedding with its trainable copy via the position masks
            x = x + (self.positional_embedding * self.mask1).to(cast_dtype) + \
                (self.positional_embedding_res * self.mask2).to(cast_dtype)

        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x = text_global_pool(x, text, self.text_pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                x = self.text_projection(x)
            else:
                x = x @ self.text_projection

        return F.normalize(x, dim=-1) if normalize else x

    def get_logits(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits
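
    # Illustrative usage (not part of the original file): a minimal sketch of scoring a batch of
    # images against a batch of prompts with the method above, assuming `model` is an instance of
    # this CLIP class and `images` / `tokens` are already preprocessed and tokenized tensors:
    #
    #   with torch.no_grad():
    #       image_logits, text_logits = model.get_logits(images, tokens)  # [n_img, n_txt], [n_txt, n_img]
    #       probs = image_logits.softmax(dim=-1)  # per-image distribution over the text prompts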

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward pass that also returns intermediate features.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize_intermediates: Apply final norm layer to all intermediates
            normalize: L2 normalize final features
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale and bias in the output
        Returns:
            Dict of final features, intermediate features, and (optionally) logits / logit scale.
        """
        output = {}

        if intermediates_only:
            # intermediates-only mode disables final feature normalization and logit output
            normalize = False
            output_logits = False

        if output_logits:
            assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            cast_dtype = self.transformer.get_cast_dtype()
            x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
            x = x + self.positional_embedding.to(cast_dtype)
            x, intermediates = self.transformer.forward_intermediates(
                x,
                attn_mask=self.attn_mask,
                indices=text_indices,
            )
            if normalize_intermediates:
                intermediates = [self.ln_final(xi) for xi in intermediates]
            # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens
            output["text_intermediates"] = intermediates

            if not intermediates_only:
                x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
                x = text_global_pool(x, text, self.text_pool_type)
                if self.text_projection is not None:
                    if isinstance(self.text_projection, nn.Linear):
                        x = self.text_projection(x)
                    else:
                        x = x @ self.text_projection
                if normalize:
                    x = F.normalize(x, dim=-1)
                output["text_features"] = x

        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None

        if output_logits:
            image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
            if self.logit_bias is not None:
                image_logits += self.logit_bias
            text_logits = image_logits.T
            output["image_logits"] = image_logits
            output["text_logits"] = text_logits

        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def load_from_pretrained_short_pe(self, keep_len=20):
        """Stretch the pretrained short (77-token) positional embedding to the long context length.

        The first `keep_len` positions are copied unchanged; the remaining positions are linearly
        interpolated 4x, yielding 4 * length - 3 * keep_len positions (248 for length=77, keep_len=20).
        The stretched embedding is stored twice: a frozen `positional_embedding` and a trainable
        copy `positional_embedding_res`.
        """
        cast_dtype = self.transformer.get_cast_dtype()
        device = self.positional_embedding.device
        positional_embedding_pre = self.positional_embedding.type(cast_dtype)

        length, dim = positional_embedding_pre.shape
        positional_embedding_new = torch.zeros([4 * length - 3 * keep_len, dim], dtype=cast_dtype).to(device)
        for i in range(keep_len):
            positional_embedding_new[i] = positional_embedding_pre[i]
        for i in range(length - 1 - keep_len):
            # insert three linearly interpolated positions between each pair of original positions
            positional_embedding_new[4 * i + keep_len] = positional_embedding_pre[i + keep_len]
            positional_embedding_new[4 * i + 1 + keep_len] = 3 * positional_embedding_pre[i + keep_len] / 4 + \
                positional_embedding_pre[i + 1 + keep_len] / 4
            positional_embedding_new[4 * i + 2 + keep_len] = 2 * positional_embedding_pre[i + keep_len] / 4 + \
                2 * positional_embedding_pre[i + 1 + keep_len] / 4
            positional_embedding_new[4 * i + 3 + keep_len] = positional_embedding_pre[i + keep_len] / 4 + \
                3 * positional_embedding_pre[i + 1 + keep_len] / 4

        # extrapolate the final four positions from the last original position
        positional_embedding_new[4 * length - 3 * keep_len - 4] = positional_embedding_pre[length - 1]
        positional_embedding_new[4 * length - 3 * keep_len - 3] = positional_embedding_pre[length - 1] + \
            (positional_embedding_pre[length - 1] - positional_embedding_pre[length - 2]) / 4
        positional_embedding_new[4 * length - 3 * keep_len - 2] = positional_embedding_pre[length - 1] + \
            2 * (positional_embedding_pre[length - 1] - positional_embedding_pre[length - 2]) / 4
        positional_embedding_new[4 * length - 3 * keep_len - 1] = positional_embedding_pre[length - 1] + \
            3 * (positional_embedding_pre[length - 1] - positional_embedding_pre[length - 2]) / 4

        positional_embedding_res = positional_embedding_new.clone()

        self.positional_embedding = nn.Parameter(positional_embedding_new, requires_grad=False)
        self.positional_embedding_res = nn.Parameter(positional_embedding_res, requires_grad=True)
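
    # Illustrative usage (not part of the original file): a minimal sketch of how the
    # "load_from_clip" long-context path might be driven, assuming a standard 77-token
    # checkpoint `short_ctx_state_dict` has already been prepared:
    #
    #   model = CLIP(embed_dim, vision_cfg, text_cfg, long_clip="load_from_clip")
    #   model.load_state_dict(short_ctx_state_dict, strict=False)  # loads the 77-token positional_embedding
    #   model.load_from_pretrained_short_pe(keep_len=20)           # stretches it to 248 positions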

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            **kwargs,
    ):
        if 'batch' in kwargs:
            batch = kwargs['batch']
            used_losses = kwargs.get('used_losses', ["global_itc"])
            last_attn_type = kwargs.get('last_attn_type')
            features = {}

            # ======= Visual Encoder: Global =======
            global_image = batch["global_image"]
            global_pooled, global_patches = self.visual.forward_v2(global_image, last_attn_type)
            features["global_image_pooled"] = global_pooled
            features["global_patches"] = global_patches

            # ======= Visual Encoder: Local =======
            if "local_itc" in used_losses:
                local_imgs = batch['local_images']
                local_pooled, local_patches = self.visual.forward_v2(local_imgs, last_attn_type)
                features["local_image_pooled"] = local_pooled
                features["local_patches"] = local_patches

            if "distill" in used_losses:
                if "subset_images" in batch:
                    subset_imgs = batch['subset_images']
                    subset_pooled, subset_patches = self.visual.forward_v2(subset_imgs, last_attn_type)
                    features["subset_image_pooled"] = subset_pooled
                    features["subset_patches"] = subset_patches

            # ======= Text Encoder =======
            if "global_itc" in used_losses:
                global_text = batch["global_text"]
                features["global_text_pooled"] = self.encode_text(global_text)

            if "local_itc" in used_losses:
                local_texts = batch["local_texts"]
                features["local_text_pooled"] = self.encode_text(local_texts)

            # ======= Logit Scale =======
            features["logit_scale"] = self.logit_scale.exp()
            return features

        # ======= Vanilla CLIP =======
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp(),
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()


class CustomTextCLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            nonscalar_logit_scale: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        self.text.lock(unlocked_layers, freeze_layer_norm)

    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def no_weight_decay(self):
        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
        no_wd = set()
        if hasattr(self.visual, 'no_weight_decay'):
            for n in self.visual.no_weight_decay():
                no_wd.add('visual.' + n)
        if hasattr(self.text, 'no_weight_decay'):
            # NOTE: iterate over the text tower here (the original iterated over visual twice)
            for n in self.text.no_weight_decay():
                no_wd.add('text.' + n)
        return no_wd

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward_intermediates(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            image_indices: Optional[Union[int, List[int]]] = None,
            text_indices: Optional[Union[int, List[int]]] = None,
            stop_early: bool = False,
            normalize: bool = True,
            normalize_intermediates: bool = False,
            intermediates_only: bool = False,
            image_output_fmt: str = 'NCHW',
            image_output_extra_tokens: bool = False,
            text_output_fmt: str = 'NLC',
            text_output_extra_tokens: bool = False,
            output_logits: bool = False,
            output_logit_scale_bias: bool = False,
    ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """ Forward pass that also returns intermediate features.

        Args:
            image: Input image tensor
            text: Input text tensor
            image_indices: For image tower, take last n blocks if int, all if None, select matching indices if sequence
            text_indices: Take last n blocks if int, all if None, select matching indices if sequence
            stop_early: Stop iterating over blocks when last desired intermediate hit
            normalize: L2 normalize final image and text features (if present)
            normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
            intermediates_only: Only return intermediate features, do not return final features
            image_output_fmt: Shape of intermediate image feature outputs
            image_output_extra_tokens: Return both prefix and spatial intermediate tokens
            text_output_fmt: Shape of intermediate text feature outputs
            text_output_extra_tokens: Return both prefix and spatial intermediate tokens
            output_logits: Include logits in output
            output_logit_scale_bias: Include the logit scale and bias in the output
        Returns:
            Dict of final features, intermediate features, and (optionally) logits / logit scale.
        """
        output = {}

        if intermediates_only:
            # intermediates-only mode disables final feature normalization and logit output
            normalize = False
            output_logits = False

        if output_logits:
            assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'

        if image is not None:
            image_output = self.visual.forward_intermediates(
                image,
                indices=image_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=image_output_fmt,
                output_extra_tokens=image_output_extra_tokens,
            )
            if normalize and "image_features" in image_output:
                image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
            output.update(image_output)

        if text is not None:
            text_output = self.text.forward_intermediates(
                text,
                indices=text_indices,
                stop_early=stop_early,
                normalize_intermediates=normalize_intermediates,
                intermediates_only=intermediates_only,
                output_fmt=text_output_fmt,
                output_extra_tokens=text_output_extra_tokens,
            )
            if normalize and "text_features" in text_output:
                text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
            output.update(text_output)

        logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None

        if output_logits:
            image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
            if self.logit_bias is not None:
                image_logits += self.logit_bias
            text_logits = image_logits.T
            output["image_logits"] = image_logits
            output["text_logits"] = text_logits

        if output_logit_scale_bias:
            output["logit_scale"] = logit_scale_exp
            if self.logit_bias is not None:
                output['logit_bias'] = self.logit_bias

        return output

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp(),
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(l, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            attr = getattr(l, "text_projection", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

        if isinstance(l, VisionTransformer):
            # convert vision nn.Parameter projections
            attr = getattr(l, "proj", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
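

# Illustrative usage (not part of the original file): a minimal sketch of converting a model
# in place for fp16 inference, assuming `model` was built with one of the constructors above:
#
#   model = CLIP(embed_dim, vision_cfg, text_cfg)
#   convert_weights_to_lp(model, dtype=torch.float16)  # conv/linear/attention/projection weights -> fp16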


# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    if 'text_projection' in state_dict:
        # old format state_dict, move text tower -> .text
        new_state_dict = {}
        for k, v in state_dict.items():
            if any(k.startswith(p) for p in (
                'text_projection',
                'positional_embedding',
                'token_embedding',
                'transformer',
                'ln_final',
            )):
                k = 'text.' + k
            new_state_dict[k] = v
        return new_state_dict
    return state_dict


def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
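

# Illustrative usage (not part of the original file): a minimal sketch of rebuilding a model
# from an OpenAI JIT checkpoint, assuming `path` points to a downloaded OpenAI CLIP weight file:
#
#   jit_model = torch.jit.load(path, map_location="cpu")
#   model = build_model_from_openai_state_dict(jit_model.state_dict())
#   model = model.to(device)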


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    model.eval()
    image_size = model.visual.image_size
    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,),
        ))
    model.visual.image_size = image_size
    return model


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed
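

# Illustrative note (not part of the original file): for a ViT-B/16 checkpoint trained at 224px
# (14x14 grid + 1 class token = 197 positions) loaded into a model built for 336px (21x21 grid),
# the function above resamples the 14x14 spatial grid to 21x21 with bicubic interpolation while
# carrying the class-token embedding over unchanged:
#
#   resize_pos_embed(state_dict, model)  # edits state_dict['visual.positional_embedding'] in place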


def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    old_pos_embed = state_dict.get('positional_embedding', None)
    if old_pos_embed is None:
        return
    # FIXME add support for text cls_token
    model_pos_embed = getattr(model, 'positional_embedding', None)
    if model_pos_embed is None:
        model_pos_embed = getattr(model.text, 'positional_embedding', None)

    old_num_pos = old_pos_embed.shape[0]
    old_width = old_pos_embed.shape[1]
    num_pos = model_pos_embed.shape[0]
    width = model_pos_embed.shape[1]
    assert old_width == width, 'text pos_embed width changed!'
    if old_num_pos == num_pos:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
    old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
    old_pos_embed = F.interpolate(
        old_pos_embed,
        size=num_pos,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
    new_pos_embed = old_pos_embed
    state_dict['positional_embedding'] = new_pos_embed


def get_model_preprocess_cfg(model):
    module = getattr(model, 'visual', model)
    preprocess_cfg = getattr(module, 'preprocess_cfg', {})
    if not preprocess_cfg:
        # use separate legacy attributes if preprocess_cfg dict not found
        size = getattr(module, 'image_size', None)
        if size is not None:
            preprocess_cfg['size'] = size
        mean = getattr(module, 'image_mean', None)
        if mean is not None:
            preprocess_cfg['mean'] = mean
        std = getattr(module, 'image_std', None)
        if std is not None:
            preprocess_cfg['std'] = std
    return preprocess_cfg


def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    module = getattr(model, 'visual', model)
    module.image_mean = preprocess_cfg['mean']  # legacy attribute, keeping for bwd compat
    module.image_std = preprocess_cfg['std']  # legacy attribute, keeping for bwd compat
    module.preprocess_cfg = copy.deepcopy(preprocess_cfg)  # new attr, package all pp cfg as dict


def get_model_tokenize_cfg(model):
    module = getattr(model, 'text', model)
    cfg = {}
    context_length = getattr(module, 'context_length', None)
    if context_length is not None:
        cfg['context_length'] = context_length
    vocab_size = getattr(module, 'vocab_size', None)
    if vocab_size is not None:
        cfg['vocab_size'] = vocab_size
    return cfg