# CSD-CLIP model definition (source recovered from a Hugging Face Space page;
# "Running on Zero" banner text removed).
| import torch | |
| import torch.nn as nn | |
| import clip | |
| import copy | |
| from torch.autograd import Function | |
| from collections import OrderedDict | |
| from torchvision import transforms | |
def convert_state_dict(state_dict):
    """Strip the leading ``module.`` prefix (added by DataParallel/DDP wrappers)
    from every key of *state_dict*.

    Args:
        state_dict: mapping of parameter name -> tensor.

    Returns:
        OrderedDict with the same values, in the same order, prefix-free keys.
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith("module."):
            # Remove only the leading prefix. str.replace("module.", "")
            # would also clobber any later occurrence of "module." in the key
            # (e.g. "module.module.w" -> "w" instead of "module.w").
            k = k[len("module."):]
        new_state_dict[k] = v
    return new_state_dict
def convert_weights_float(model: nn.Module):
    """Cast the fp16 parameters CLIP ships with back to fp32, in place."""

    def _to_fp32(module):
        # Standard layers: weight plus optional bias.
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.float()
            if module.bias is not None:
                module.bias.data = module.bias.data.float()

        # Attention layers keep several projection tensors as raw attributes.
        if isinstance(module, nn.MultiheadAttention):
            attention_attrs = [
                "in_proj_weight",
                "q_proj_weight",
                "k_proj_weight",
                "v_proj_weight",
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]
            for attr_name in attention_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.float()

        # CLIP-specific projection matrices stored as bare Parameters.
        for proj_name in ("text_projection", "proj"):
            proj = getattr(module, proj_name, None)
            if proj is not None:
                proj.data = proj.data.float()

    model.apply(_to_fp32)
class ReverseLayerF(Function):
    """Gradient-reversal layer (Ganin & Lempitsky, DANN).

    Identity in the forward pass; multiplies the incoming gradient by
    ``-alpha`` in the backward pass.
    """

    # @staticmethod is required: custom autograd Functions must define
    # forward/backward as staticmethods, and PyTorch raises a RuntimeError
    # from .apply() when they are plain methods (as in the original).
    @staticmethod
    def forward(ctx, x, alpha):
        # Stash alpha for backward; view_as returns x unchanged while keeping
        # the op visible to autograd.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient; no gradient w.r.t. alpha.
        output = grad_output.neg() * ctx.alpha
        return output, None
| ## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py | |
class ProjectionHead(nn.Module):
    """Projection MLP: Linear -> GELU -> Linear -> Dropout, with a residual
    connection from the first projection, then LayerNorm.

    Adapted from
    https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
    """

    def __init__(self, embedding_dim, projection_dim, dropout=0):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        # First projection doubles as the residual shortcut.
        shortcut = self.projection(x)
        out = self.dropout(self.fc(self.gelu(shortcut)))
        return self.layer_norm(out + shortcut)
def init_weights(m):  # TODO: do we need init for layernorm?
    """Xavier-uniform init for Linear weights, near-zero normal biases.

    No-op for every other module type; intended for use with Module.apply.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.normal_(m.bias, std=1e-6)
class CSD_CLIP(nn.Module):
    """CLIP visual backbone with separate style and content projection heads."""

    def __init__(self, name="vit_large", content_proj_head="default", model_path=None):
        super().__init__()
        self.content_proj_head = content_proj_head

        if name == "vit_large":
            clipmodel, _ = clip.load("ViT-L/14" if model_path is None else model_path)
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
        elif name == "vit_base":
            clipmodel, _ = clip.load("ViT-B/16" if model_path is None else model_path)
            self.backbone = clipmodel.visual
            self.embedding_dim = 768
            # NOTE(review): feat_dim is only set in this branch; the "custom"
            # head with name="vit_large" would raise AttributeError — confirm
            # whether that combination is ever used.
            self.feat_dim = 512
        else:
            raise Exception("This model is not implemented")

        # CLIP loads in fp16; training/inference here expects fp32.
        convert_weights_float(self.backbone)

        # Keep independent copies of CLIP's projection for each head, then
        # detach the backbone's own projection so it emits raw features.
        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        if content_proj_head == "custom":
            self.last_layer_content = ProjectionHead(self.embedding_dim, self.feat_dim)
            self.last_layer_content.apply(init_weights)
        else:
            self.last_layer_content = copy.deepcopy(self.backbone.proj)
        self.backbone.proj = None

    def dtype(self):
        # Plain method (not a property), matching the original; call model.dtype().
        return self.backbone.conv1.weight.dtype

    def forward(self, input_data, alpha=None):
        """Return (raw backbone feature, content embedding, style embedding).

        When *alpha* is given, the content branch receives the feature through
        a gradient-reversal layer (adversarial content training).
        """
        feature = self.backbone(input_data)

        if alpha is not None:
            reverse_feature = ReverseLayerF.apply(feature, alpha)
        else:
            reverse_feature = feature

        style_output = feature @ self.last_layer_style
        style_output = nn.functional.normalize(style_output, dim=1, p=2)

        if self.content_proj_head == "custom":
            content_output = self.last_layer_content(reverse_feature)
        else:
            content_output = reverse_feature @ self.last_layer_content
        content_output = nn.functional.normalize(content_output, dim=1, p=2)

        return feature, content_output, style_output
def create_model_and_transforms(model_path="models/csd_clip.pth"):
    """Build a ViT-L/14 CSD_CLIP, load weights from *model_path*, and return
    ``(model, preprocess, preprocess)`` with CLIP's image preprocessing.
    """
    model = CSD_CLIP("vit_large", "default")

    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints; strict=False silently ignores missing/unexpected keys.
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = convert_state_dict(checkpoint["model_state_dict"])
    model.load_state_dict(state_dict, strict=False)

    # CLIP's published normalization constants.
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )
    preprocess = transforms.Compose(
        [
            transforms.Resize(
                size=224,
                interpolation=transforms.functional.InterpolationMode.BICUBIC,
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )
    return model, preprocess, preprocess