| """ Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats. |
| """ |
| from typing import Union |
|
|
| import torch |
| import numpy as np |
|
|
| from .model import CLIP, CustomTextCLIP |
| from .transformer import TextTransformer, Transformer |
|
|
|
|
| @torch.no_grad() |
| def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): |
| """ Load weights from .npz checkpoints for official Google big_vision image-text models |
| |
| Currently the SigLIP source models are supported and a CustomTextCLIP destination model |
| w/ timm image encoder. |
| """ |
| from timm.layers import resample_patch_embed, resample_abs_pos_embed |
|
|
    def _n2p(w, t=True):
        """ Convert a big_vision (Flax) numpy array to a torch tensor, transposing to PyTorch layout. """
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            # 1x1x1xC arrays are effectively per-channel vectors, flatten to 1D
            w = w.flatten()
        if t:
            if w.ndim == 4:
                # conv kernel: HWIO -> OIHW
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                # dense kernel: (in, out) -> (out, in)
                w = w.transpose([1, 0])
        return torch.from_numpy(w)
|
|
    w = np.load(checkpoint_path)
    # settings used if patch / position embeddings need resizing to match the destination model
    interpolation = 'bilinear'
    antialias = False
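    # The .npz is a flat dict of arrays keyed by slash-separated paths; an illustrative subset of
    # the keys consumed below:
    #   params/img/embedding/kernel, params/img/pos_embedding, params/txt/Embed_0/embedding,
    #   params/t (logit scale), params/b (logit bias)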
|
|
    def _convert_timm_img(module, prefix):
        """ Copy big_vision ViT weights into a timm VisionTransformer image tower. """
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
            # resample the patch embedding kernel if the patch size differs from the checkpoint
            embed_conv_w = resample_patch_embed(
                embed_conv_w,
                module.patch_embed.proj.weight.shape[-2:],
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.patch_embed.proj.weight.copy_(embed_conv_w)
        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
|
| if module.cls_token is not None: |
| module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) |
|
|
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
        if pos_embed_w.shape != module.pos_embed.shape:
            # resize the position embedding when it does not match the destination model
            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
            pos_embed_w = resample_abs_pos_embed(
                pos_embed_w,
                new_size=module.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        module.pos_embed.copy_(pos_embed_w)
|
|
        # sub-module indices used in the big_vision encoder block naming scheme
        mha_sub, b_sub, ln1_sub = (0, 0, 1)
| for i, block in enumerate(module.blocks.children()): |
| block_prefix = f'{prefix}Transformer/encoderblock_{i}/' |
| mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' |
| block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) |
| block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) |
            # big_vision stores separate q/k/v projections, fuse them into timm's single qkv layer
            block.attn.qkv.weight.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
            block.attn.qkv.bias.copy_(torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
| block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) |
| block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) |
| for r in range(2): |
| getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) |
| getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) |
| block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) |
| block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) |
|
|
| module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) |
| module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) |
|
|
        if module.attn_pool is not None:
            # attention pooling (MAP head), used by SigLIP image towers in place of a class token
            block_prefix = f'{prefix}MAPHead_0/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_0/'
| module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) |
| module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) |
| module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) |
| module.attn_pool.kv.weight.copy_(torch.cat([ |
| _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) |
| module.attn_pool.kv.bias.copy_(torch.cat([ |
| _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) |
| module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) |
| module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) |
| module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) |
| module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) |
| for r in range(2): |
| getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) |
| getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) |
|
|
    def _convert_openclip_transformer(module: Transformer, prefix):
        """ Copy big_vision text encoder blocks into an open_clip Transformer. """
        for i, block in enumerate(module.resblocks.children()):
            block_prefix = f'{prefix}encoderblock_{i}/'
            mha_prefix = block_prefix + 'MultiHeadDotProductAttention_0/'
| block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) |
| block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) |
| block.attn.in_proj_weight.copy_(torch.cat([ |
| _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) |
| block.attn.in_proj_bias.copy_(torch.cat([ |
| _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) |
| block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) |
| block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) |
| block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) |
| block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) |
| block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) |
| block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) |
| block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) |
| block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) |
|
|
    def _convert_openclip_txt(module: TextTransformer, prefix):
        """ Copy big_vision text tower weights into an open_clip TextTransformer. """
| module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) |
| pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) |
| module.positional_embedding.copy_(pos_embed_w) |
| _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') |
| module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) |
| module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) |
| module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) |
| module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) |
|
|
    _convert_timm_img(model.visual.trunk, 'params/img/')
    _convert_openclip_txt(model.text, 'params/txt/')
    # logit scale & bias are stored as single-element arrays in the big_vision checkpoint
    model.logit_bias.copy_(_n2p(w['params/b'])[0])
    model.logit_scale.copy_(_n2p(w['params/t'])[0])
|
|
|
|
| @torch.no_grad() |
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit=True):
    """ Convert an Apple MobileCLIP state_dict to open_clip CustomTextCLIP format.

    The image tower is remapped via the matching timm checkpoint filter (FastViT or ViT-hybrid)
    and the text tower keys are renamed to open_clip conventions.
    """
|
|
    def _convert_timm_img(state_dict):
        # remap image tower keys with the appropriate timm checkpoint filter, then prefix for open_clip
        if fastvit:
            from timm.models.fastvit import checkpoint_filter_fn
        else:
            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
        timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
        return timm_state_dict
|
|
    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
        # rename MobileCLIP text encoder keys to open_clip TextTransformer conventions
        text_dict = {}
| for k, v in state_dict.items(): |
| if not k.startswith(prefix): |
| continue |
| k = k.replace(prefix, '') |
| k = k.replace('projection_layer', 'text_projection') |
| k = k.replace('embedding_layer', 'token_embedding') |
| if k.startswith('positional_embedding.pos_embed.pos_embed'): |
| k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding') |
| v = v.squeeze() |
| k = k.replace('final_layer_norm', 'ln_final') |
| k = k.replace('pre_norm_mha.0', 'ln_1') |
| k = k.replace('pre_norm_mha.1', 'attn') |
| k = k.replace('pre_norm_ffn.0', 'ln_2') |
| k = k.replace('pre_norm_ffn.1', 'mlp.c_fc') |
| k = k.replace('pre_norm_ffn.4', 'mlp.c_proj') |
| k = k.replace('qkv_proj.weight', 'in_proj_weight') |
| k = k.replace('qkv_proj.bias', 'in_proj_bias') |
| k = k.replace('transformer.', 'transformer.resblocks.') |
| text_dict['text.' + k] = v |
| return text_dict |
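    # Illustrative key remapping produced by the replacements above (the source key form is assumed):
    #   'text_encoder.transformer.0.pre_norm_mha.1.qkv_proj.weight'
    #     -> 'text.transformer.resblocks.0.attn.in_proj_weight'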
|
|
| image_dict = _convert_timm_img(state_dict) |
| text_dict = _convert_openclip_txt(state_dict) |
| out_dict = {**image_dict, **text_dict} |
| out_dict['logit_scale'] = state_dict['logit_scale'] |
| return out_dict |
|
|
|
|
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
    """ Convert supported 3rd party state_dicts (currently Apple MobileCLIP) to open_clip format. """
    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
        # MobileCLIP variant with a FastViT image encoder
        state_dict = convert_mobile_clip_state_dict(model, state_dict)
    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
        # MobileCLIP variant with a (non-FastViT) hybrid ViT image encoder
        state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
    return state_dict
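

# Usage sketch (hypothetical checkpoint path; assumes the file stores a raw state_dict):
#   sd = torch.load('mobileclip_s1.pt', map_location='cpu')
#   model.load_state_dict(convert_state_dict(model, sd))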
|
|