lsnu's picture
Add files using upload-large-folder tool
5ce8761 verified
# Adapted from https://github.com/openai/CLIP/blob/main/clip/model.py
import torch
from torch import nn
import clip
from clip.model import ModifiedResNet
class CLIPTransform(nn.Module):
    """Per-channel normalization module, defaulting to the CLIP statistics.

    The statistics are stored as buffers named ``mean`` and ``std`` with shape
    ``(1, C, 1, 1)``, so they follow the module across ``.to(...)`` calls and
    broadcast against a batch of images laid out as ``(N, C, H, W)``.

    Args:
        mean: Optional per-channel means (sequence of floats). Defaults to
            ``CLIP_MEAN``.
        std: Optional per-channel standard deviations (sequence of floats).
            Defaults to ``CLIP_STD``.
    """

    # Default statistics; tuples so the shared defaults cannot be mutated.
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, mean=None, std=None):
        super().__init__()
        mean = self.CLIP_MEAN if mean is None else mean
        std = self.CLIP_STD if std is None else std
        # Reshape to (1, C, 1, 1) so the statistics broadcast over batch,
        # height and width.
        self.register_buffer("mean", torch.tensor(mean).reshape(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std).reshape(1, -1, 1, 1))

    def forward(self, img):
        """Return ``(img - mean) / std`` using the stored buffers."""
        return (img - self.mean) / self.std
def load_clip():
    """Build a multi-scale feature backbone from the pretrained CLIP RN50.

    The visual tower's configuration (blocks per stage, embedding size and
    head count) is read from the checkpoint's state dict, a
    ``ModifiedResNetFeatures`` is constructed with it, and the CLIP visual
    weights are copied in.

    Returns:
        tuple: ``(backbone, normalize)`` where ``backbone`` is the
        ``ModifiedResNetFeatures`` carrying the CLIP visual weights and
        ``normalize`` is a ``CLIPTransform`` module.
    """
    # Load on the CPU: only the state dict is used, and it is copied into a
    # backbone that is constructed on the default (CPU) device below, so there
    # is no reason to place the whole CLIP model on a GPU. The preprocessing
    # pipeline returned by clip.load is not needed here.
    clip_model, _ = clip.load("RN50", device="cpu")
    state_dict = clip_model.state_dict()

    # Blocks per stage = number of distinct block indices (third dotted
    # component of the key) found under each "visual.layer{stage}" prefix.
    layers = tuple(
        len({
            key.split(".")[2]
            for key in state_dict
            if key.startswith(f"visual.layer{stage}")
        })
        for stage in (1, 2, 3, 4)
    )
    output_dim = state_dict["text_projection"].shape[1]
    # Head count is derived from the width of the first residual block's conv.
    heads = state_dict["visual.layer1.0.conv1.weight"].shape[0] * 32 // 64

    backbone = ModifiedResNetFeatures(layers, output_dim, heads)
    backbone.load_state_dict(clip_model.visual.state_dict())
    # NOTE(review): the backbone is returned in nn.Module's default training
    # mode; callers that want frozen BatchNorm statistics should call .eval().

    normalize = CLIPTransform()
    return backbone, normalize
class ModifiedResNetFeatures(ModifiedResNet):
    """``ModifiedResNet`` whose forward pass returns intermediate feature maps.

    Construction is delegated unchanged to ``ModifiedResNet``; only
    ``forward`` differs, returning the stem output and the output of each of
    ``layer1`` .. ``layer4`` in a dict keyed ``"res1"`` .. ``"res5"``.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__(layers, output_dim, heads, input_resolution, width)

    def forward(self, x: torch.Tensor):
        """Run the stem and the four residual stages, collecting each output.

        Args:
            x: Input tensor; it is cast to the dtype of ``conv1``'s weight
                before any layer is applied.

        Returns:
            dict: ``"res1"`` is the stem output (before ``avgpool``), and
            ``"res2"`` .. ``"res5"`` are the outputs of ``layer1`` ..
            ``layer4`` respectively.
        """
        feats = x.type(self.conv1.weight.dtype)

        # Three conv -> batch-norm -> ReLU units make up the stem.
        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, norm, act in stem:
            feats = act(norm(conv(feats)))

        outputs = {"res1": feats}

        # The stem output is average-pooled before entering the first stage.
        feats = self.avgpool(feats)
        stages = (
            ("res2", self.layer1),
            ("res3", self.layer2),
            ("res4", self.layer3),
            ("res5", self.layer4),
        )
        for key, stage in stages:
            feats = stage(feats)
            outputs[key] = feats

        return outputs