from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
    MixClsHead,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower


@dataclass
class ClassHeadCfg(CLIPTextCfg):
    """Text-tower config extended with the classification-head settings."""

    # Forwarded to MixClsHead as ``mlp_ratio``.
    cls_mlp_ratio: int = 4
    # Forwarded to MixClsHead as ``layers``.
    cls_layers: int = 1


def _build_cls_head(
        width,
        embed_dim,
        clshead_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the ``MixClsHead`` that sits on top of the vision tower.

    Args:
        width: Passed to ``MixClsHead`` as ``width``.
        embed_dim: Passed to ``MixClsHead`` as ``embed_dim``.
        clshead_cfg: A ``ClassHeadCfg`` instance, or a dict of keyword
            arguments from which one is built.
        quick_gelu: Use ``QuickGELU`` instead of ``nn.GELU`` as activation.
        cast_dtype: When ``torch.float16`` or ``torch.bfloat16``,
            ``LayerNormFp32`` is used as the norm layer; otherwise
            ``LayerNorm``.

    Returns:
        The constructed ``MixClsHead`` module, whose ``output_dim`` is the
        config's ``vocab_size``.
    """
    if isinstance(clshead_cfg, dict):
        clshead_cfg = ClassHeadCfg(**clshead_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    head = MixClsHead(
        width=width,
        embed_dim=embed_dim,
        layers=clshead_cfg.cls_layers,
        mlp_ratio=clshead_cfg.cls_mlp_ratio,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_dim=clshead_cfg.vocab_size,
    )
    return head


class Classifier(nn.Module):
    """Vision tower + text tower + a ``MixClsHead`` applied to image features.

    ``forward`` returns a dict bundling the encoded features with the token
    ids (as ``labels``), the exponentiated logit scale, and two float64
    buffers (``cap_fq``, ``num_samples``) that this module only allocates;
    they are never updated in the code visible here.
    """

    def __init__(
            self,
            embed_dim,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            init_logit_scale: float = np.log(1 / 0.07),
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
    ):
        """Build both towers, the classification head and the buffers.

        Args:
            embed_dim: Embedding size passed to the text tower and the head.
            text_cfg: Text config; a dict is converted to ``ClassHeadCfg``.
                It is also reused as the head config.
            vision_cfg: Vision config; a dict is converted to
                ``CLIPVisionCfg``.
            init_logit_scale: Initial value of the learnable ``logit_scale``
                parameter (``forward`` returns its exponential).
            quick_gelu: Forwarded to the tower and head builders.
            cast_dtype: Forwarded to the tower and head builders.
        """
        super().__init__()
        if isinstance(text_cfg, dict):
            text_cfg = ClassHeadCfg(**text_cfg)
        if isinstance(vision_cfg, dict):
            vision_cfg = CLIPVisionCfg(**vision_cfg)

        # NOTE(review): the first argument is 0 here rather than embed_dim --
        # presumably so the tower emits unprojected features of size
        # vision_cfg.width for the head to consume; confirm against
        # _build_vision_tower.
        self.visual = _build_vision_tower(0, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)

        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

        self.head = _build_cls_head(
            vision_cfg.width,
            embed_dim,
            clshead_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # Zero-initialised float64 buffers: one slot per vocabulary entry and
        # a single scalar-like counter. Presumably token-frequency statistics
        # maintained by the training loop/loss -- TODO confirm.
        self.register_buffer(
            "cap_fq", torch.zeros([1, self.vocab_size], dtype=torch.float64)
        )
        self.register_buffer(
            "num_samples", torch.zeros([1, 1], dtype=torch.float64)
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on the vision and text towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, images, normalize=False, return_all=False):
        """Run the vision tower and the classification head on ``images``.

        Args:
            images: Input passed directly to the vision tower.
            normalize: L2-normalise the head's feature output along the last
                dimension.
            return_all: Also return the head's second output.

        Returns:
            The image features, or ``(image_features, logits)`` when
            ``return_all`` is true. ``logits`` is never normalised.
        """
        tower_out = self.visual(images)
        image_features, logits = self.head(tower_out)
        if normalize:
            image_features = F.normalize(image_features, dim=-1)
        if return_all:
            return image_features, logits
        return image_features

    def encode_text(self, text, normalize=False):
        """Run the text tower, optionally L2-normalising the last dimension."""
        features = self.text(text)
        if normalize:
            features = F.normalize(features, dim=-1)
        return features

    def forward(self, image=None, text=None):
        """Encode whichever of ``image`` / ``text`` is provided.

        Args:
            image: Optional input for the vision tower.
            text: Optional input for the text tower; a copy of it is also
                returned as ``labels``.

        Returns:
            A dict with keys ``cap_fq``, ``num_samples``, ``image_features``,
            ``text_features``, ``labels`` and ``logit_scale``.
            ``image_features`` is the ``(features, logits)`` pair from
            ``encode_image`` (or ``None`` without ``image``);
            ``text_features`` and ``labels`` are ``None`` without ``text``.
        """
        image_features = None
        if image is not None:
            image_features = self.encode_image(image, normalize=True, return_all=True)

        text_features = None
        labels = None
        if text is not None:
            text_features = self.encode_text(text, normalize=True)
            # Cloned so downstream code can modify the labels without
            # touching the caller's tensor. Guarded by the None check:
            # cloning unconditionally crashed image-only calls.
            labels = text.clone()

        return {
            "cap_fq": self.cap_fq,
            "num_samples": self.num_samples,
            "image_features": image_features,
            "text_features": text_features,
            "labels": labels,
            "logit_scale": self.logit_scale.exp(),
        }