DiffICM / 1_feature_extractor /models_clip.py
Qiyp's picture
code of stage1 & 3, remove large files
1633fcc
from collections import OrderedDict
from typing import Tuple, Union, Callable
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import trunc_normal_
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on ``module``'s subtree.

    Child names are dotted paths built from ``name`` (e.g. ``"blocks.0.mlp"``).
    With ``depth_first=True`` each module is visited after its children
    (post-order); otherwise it is visited before them (pre-order). The root
    itself is visited only when ``include_root`` is true; every descendant
    is always visited.

    Returns:
        ``module`` itself, to allow chaining.
    """
    if include_root and not depth_first:
        # Pre-order: visit this module before descending.
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        # Post-order: visit this module after all of its children.
        fn(module=module, name=name)
    return module
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1 convs) with anti-aliased striding.

    Every convolution runs at stride 1. When ``stride > 1`` the spatial
    downsampling is done by an ``AvgPool2d(stride)`` placed after the 3x3
    conv on the main path, and in front of the 1x1 projection on the
    shortcut path. The block outputs ``planes * expansion`` channels.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # Submodules are registered in this exact order: attribute names are
        # the state_dict keys, and creation order fixes random initialization.
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        # Main-path downsampling (a no-op when stride is 1).
        self.avgpool = nn.Identity() if stride <= 1 else nn.AvgPool2d(stride)

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # Shortcut: avgpool first, then a stride-1 1x1 conv + BN. The
            # "-1"/"0"/"1" child names are kept so state_dict keys stay stable.
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(out_planes)),
            ]))

    def forward(self, x: torch.Tensor):
        h = self.relu1(self.bn1(self.conv1(x)))
        h = self.relu2(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(self.avgpool(h)))
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return self.relu3(h)
class AttentionPool2d(nn.Module):
    """Pool an (N, C, H, W) feature map into one vector per sample via attention.

    The spatial mean of the map is prepended as an extra token. Learned
    positional embeddings are added to all ``spacial_dim ** 2 + 1`` tokens,
    and multi-head attention runs with that mean token as the only query.
    The result is projected to ``output_dim`` (defaults to ``embed_dim``).
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # Creation order matters for reproducible init; names are state_dict keys.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # NCHW -> (HW)NC: sequence-first layout for multi_head_attention_forward.
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)
        # Prepend the spatial mean as the query token -> (HW+1)NC.
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        # Drop the length-1 query axis -> (N, output_dim).
        return pooled.squeeze(0)
class ModifiedResNet(nn.Module):
    """ResNet image encoder with CLIP-style changes relative to torchvision's:

    - a 3-convolution stem followed by average pooling (no max pool);
    - anti-aliased downsampling: stride > 1 is realized by an avgpool in front
      of a stride-1 convolution instead of a strided convolution;
    - an attention-based pooling head (``AttentionPool2d``) instead of global
      average pooling.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution
        half = width // 2

        # 3-layer stem; conv1 halves the resolution, avgpool halves it again.
        self.conv1 = nn.Conv2d(3, half, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(half)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(half, half, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(half)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(half, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # Residual stages. _inplanes is updated by _make_layer as each stage is built.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        # layer4 emits width * 8 * Bottleneck.expansion (= width * 32) channels.
        # Total downsampling is 32x (stem 4x, layers 2-4 2x each), which fixes
        # the attention pool's spatial grid.
        embed_dim = width * 32
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) first block, then ``blocks - 1`` stride-1 blocks."""
        first = Bottleneck(self._inplanes, planes, stride)
        self._inplanes = planes * Bottleneck.expansion
        rest = [Bottleneck(self._inplanes, planes) for _ in range(1, blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        # Match the input to the weights' dtype (e.g. fp16 weights).
        x = x.type(self.conv1.weight.dtype)
        stem_layers = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, bn, act in stem_layers:
            x = act(bn(conv(x)))
        x = self.avgpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalizes in float32 and returns the input dtype.

    This lets fp16 activations be normalized with float32 statistics.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a 4x-wide MLP.

    Each sub-layer is wrapped in a residual connection. ``nn.MultiheadAttention``
    is built with its default (sequence-first) layout.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        hidden = d_model * 4
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, hidden)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Stored as a plain attribute (not a registered buffer); attention()
        # moves it to the input's dtype/device on every call.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attended

    def forward(self, x: torch.Tensor):
        h = x + self.attention(self.ln_1(x))
        return h + self.mlp(self.ln_2(h))
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks applied to batch-first input.

    ``forward`` takes ``(N, L, D)`` tokens and returns the same layout; the
    blocks themselves run on sequence-first ``(L, N, D)`` tensors.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*(ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)))

    def forward(self, x: torch.Tensor):
        seq_first = x.permute(1, 0, 2)  # NLD -> LND
        seq_first = self.resblocks(seq_first)
        return seq_first.permute(1, 0, 2)  # LND -> NLD
class VisionTransformer(nn.Module):
    """CLIP-style ViT image encoder that returns token features as a dict.

    A stride-``patch_size`` convolution turns each non-overlapping patch into
    a ``width``-dim token. A learnable class token is prepended and learned
    positional embeddings are added. The sequence is then run through a
    ``Transformer``. Patches can optionally be replaced by a learned
    ``mask_token`` (see ``prepare_tokens_with_masks``). There is no final
    projection: ``forward`` returns class/patch token features.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        scale = width ** -0.5
        # Both embeddings below are overwritten by init_weights().
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # One position per cell of the (input_resolution // patch_size)^2 patch
        # grid, plus one for the class token. Presumably inputs must therefore
        # be input_resolution x input_resolution.
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        # Learned embedding substituted for masked patches.
        self.mask_token = nn.Parameter(torch.zeros(1, width))
        self.ln_post = LayerNorm(width)
        self.embed_dim = width
        self.patch_size = patch_size
        self.init_weights()

    def init_weights(self):
        """Re-initialize the embeddings and every nn.Linear submodule."""
        trunc_normal_(self.positional_embedding, std=0.02)
        nn.init.normal_(self.class_embedding, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Embed an image batch into the token sequence fed to the transformer.

        Args:
            x: image batch for ``conv1`` (3 input channels).
            masks: optional boolean tensor selecting patches to replace with
                ``mask_token``; presumably shaped ``(batch, grid ** 2)`` —
                TODO confirm against callers.

        Returns:
            ``ln_pre``-normalized tokens ``[batch, grid ** 2 + 1, width]``,
            with the class token first.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
        # Broadcast the class embedding to one token per sample.
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        return x

    def _format_output(self, x, masks):
        """Package transformer output tokens ``[batch, 1 + patches, width]`` into the feature dict."""
        return {
            "x_norm_clstoken": self.ln_post(x[:, 0]),
            # NOTE(review): despite the key name, patch tokens are NOT passed
            # through ln_post. Kept as-is because downstream code may depend on
            # the raw values; confirm whether this is intended.
            "x_norm_patchtokens": x[:, 1:],
            "x_prenorm": x,
            "masks": masks,
        }

    def forward_features_list(self, x_list, masks_list=None):
        """Encode several image batches (e.g. multiple crops) independently.

        Args:
            x_list: list of image batches.
            masks_list: per-batch masks, aligned with ``x_list``. ``None``
                means no masking for any batch.

        Returns:
            One feature dict per input batch (see ``forward``).
        """
        if masks_list is None:
            # zip(x_list, None) would raise TypeError; treat as "no masks".
            masks_list = [None] * len(x_list)
        tokens = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        encoded = [self.transformer(t) for t in tokens]
        return [self._format_output(x, masks) for x, masks in zip(encoded, masks_list)]

    def forward(self, x: torch.Tensor, masks=None):
        """Encode images and return token features.

        Args:
            x: an image batch, or a list of batches (dispatched to
                ``forward_features_list``).
            masks: optional patch mask(s); a list aligned with ``x`` when
                ``x`` is a list.

        Returns:
            A dict with keys ``x_norm_clstoken`` (ln_post-normalized class
            token), ``x_norm_patchtokens``, ``x_prenorm`` (all tokens) and
            ``masks``. When ``x`` is a list, a list of such dicts is returned.
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks)
        x = self.prepare_tokens_with_masks(x, masks)
        x = self.transformer(x)
        return self._format_output(x, masks)
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """Initialize an ``nn.Linear`` the way the original timm ViT code does.

    Weights are drawn from a truncated normal with std 0.02 and biases (if
    any) are zeroed. Any other module type is left untouched. ``name`` is
    accepted for ``named_apply`` compatibility and is not used.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
def _load_teacher_weights(model, teacher_path):
    """Load a checkpoint into ``model`` non-strictly and print key mismatches.

    The checkpoint may be a raw state dict, or a dict wrapping it under a
    ``'state_dict'`` or ``'model'`` key (checked in that order). Loading uses
    ``strict=False``, so missing/unexpected keys are printed, not raised.

    Returns:
        ``(missing_keys, unexpected_keys)`` as reported by ``load_state_dict``.
    """
    # NOTE: torch.load is pickle-based; only point this at trusted checkpoints.
    checkpoint = torch.load(teacher_path, map_location='cpu')
    if 'state_dict' in checkpoint:
        pretrained_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        pretrained_dict = checkpoint['model']
    else:
        pretrained_dict = checkpoint
    missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, strict=False)
    print('missing_keys: ', missing_keys)
    print('unexpected_keys: ', unexpected_keys)
    return missing_keys, unexpected_keys


def _build_vit(patch_size, teacher_path, width, layers, heads):
    """Build a 224px VisionTransformer, optionally loading ``teacher_path`` weights."""
    model = VisionTransformer(
        input_resolution=224,
        patch_size=patch_size,
        width=width,
        layers=layers,
        heads=heads,
    )
    if teacher_path is not None:
        _load_teacher_weights(model, teacher_path)
    return model


def vit_small(patch_size=14, teacher_path=None):
    """ViT-S (width 384, 12 layers, 6 heads) at 224px; loads ``teacher_path`` if given."""
    return _build_vit(patch_size, teacher_path, width=384, layers=12, heads=6)


def vit_base(patch_size=14, teacher_path=None):
    """ViT-B (width 768, 12 layers, 12 heads) at 224px; loads ``teacher_path`` if given."""
    return _build_vit(patch_size, teacher_path, width=768, layers=12, heads=12)


def vit_large(patch_size=14, teacher_path=None):
    """ViT-L (width 1024, 24 layers, 16 heads) at 224px; loads ``teacher_path`` if given."""
    return _build_vit(patch_size, teacher_path, width=1024, layers=24, heads=16)
if __name__ == "__main__":
    # Scratch/debugging entry point.
    #
    # What actually executes:
    #   * the imports below (clip, open_clip and fvcore must be installed, or
    #     this raises ImportError);
    #   * building an argument parser that defines no arguments, so
    #     parse_args() rejects any command-line argument.
    #
    # Everything else is commented-out experiments:
    #   * printing CLIP / open_clip visual towers and their parameter names;
    #   * copying a visual tower's parameters (except 'proj') into a state
    #     dict saved for VisionTransformer;
    #   * counting parameters / FLOPs with fvcore.
    import argparse
    import clip
    import open_clip
    from fvcore.nn import FlopCountAnalysis, parameter_count_table
    # NOTE(review): the description mentions "resnet Training", which does not
    # describe this module.
    parser = argparse.ArgumentParser(description='PyTorch resnet Training')
    args = parser.parse_args()
    # with torch.no_grad():
    # print(clip.available_models())
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    # model, preprocess = clip.load('ViT-L/14', device)
    # print(model.visual)
    # model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion400m_e32')
    # model = model.to('cuda')
    # for k,v in model.visual.named_parameters():
    # print(k, v.shape)
    # self_model = VisionTransformer(
    # input_resolution=224,
    # patch_size=32,
    # width=768,
    # layers=12,
    # heads=12
    # )
    # print(self_model)
    # for k,v in self_model.named_parameters():
    # print(k, v.shape)
    # NOTE(review): the indentation of the snippet below was lost. If the
    # second `new_ckpt[k] = v` sat outside the `if`, the 'proj' filter would
    # have no effect.
    # new_ckpt = OrderedDict()
    # for k,v in model.visual.named_parameters():
    # if 'proj' != k:
    # print(k)
    # new_ckpt[k] = v
    # new_ckpt[k] = v
    # torch.save(new_ckpt, '/home/qw/yitian/TA-KD/clip_model/clip_l_14_400m.pth')
    # model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion400m_e32')
    # print(model.visual)
    # NOTE(review): clip_base_32 / clip_base_14 are not defined in this file
    # as shown.
    # model = clip_base_32()
    # model = clip_base_14()
    # print(parameter_count_table(model))
    # tensor = torch.rand(1, 3, 224, 224)
    # flops = FlopCountAnalysis(model, tensor)
    # print("FLOPs: ", flops.total()/1e9)