| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from copy import deepcopy | |
| from clip.model import VisionTransformer | |
| from typing import Tuple | |
class CSD(nn.Module):
    """CSD (Contrastive Style Descriptors) vision model.

    Wraps a CLIP ``VisionTransformer`` backbone whose single output
    projection is duplicated into two independent heads — one projecting
    backbone features into a *style* embedding space, one into a
    *content* embedding space. The backbone's own projection is then
    removed so it returns unprojected features.

    Defaults correspond to the CLIP ViT-L/14 configuration.
    """

    def __init__(
        self,
        vit_input_resolution: int = 224,
        vit_patch_size: int = 14,
        vit_width: int = 1024,
        # BUGFIX: default was 768 (the output dim, copied by mistake);
        # CLIP ViT-L/14 — matching the other defaults here — has 24
        # transformer layers. A 768-layer ViT is not a plausible config.
        vit_layers: int = 24,
        vit_heads: int = 16,
        vit_output_dim: int = 768,
    ) -> None:
        super().__init__()
        self.backbone = VisionTransformer(
            input_resolution=vit_input_resolution,
            patch_size=vit_patch_size,
            width=vit_width,
            layers=vit_layers,
            heads=vit_heads,
            output_dim=vit_output_dim,
        )
        # Clone the backbone's projection matrix (width x output_dim)
        # into two independent heads, then detach the projection from
        # the backbone so forward() can project twice.
        self.last_layer_style = deepcopy(self.backbone.proj)
        self.last_layer_content = deepcopy(self.backbone.proj)
        self.backbone.proj = None

    def forward(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute backbone features plus style/content embeddings.

        Args:
            pixel_values: Batch of preprocessed images; presumably shape
                (batch, 3, vit_input_resolution, vit_input_resolution) —
                TODO confirm against the caller's preprocessing.

        Returns:
            Tuple ``(features, style_output, content_output)`` where the
            two projected outputs are L2-normalized along dim 1.
        """
        features = self.backbone(pixel_values)
        style_output = features @ self.last_layer_style
        style_output = F.normalize(style_output, dim=1, p=2)
        content_output = features @ self.last_layer_content
        content_output = F.normalize(content_output, dim=1, p=2)
        return features, style_output, content_output