Spaces:
Runtime error
Runtime error
| import torch | |
| import clip | |
class CLIPLoss(torch.nn.Module):
    """CLIP-based dissimilarity between images and tokenized text.

    The loss is ``1 - logits / 100``, where ``logits`` is the first output of
    the CLIP model. NOTE(review): the division by 100 presumably undoes CLIP's
    logit scale (``logit_scale.exp()`` is ~100 for the released ``ViT-B/32``
    weights), making the result roughly ``1 - cosine_similarity``. Confirm
    this before swapping in a different CLIP checkpoint.
    """

    def __init__(self, opts, *, device=None):
        """Load CLIP ViT-B/32 and build the layers that resize its input.

        Args:
            opts: options object; only ``opts.stylegan_size`` is read.
            device: device to load CLIP on. ``None`` (the default) uses
                ``"cuda"`` when it is available and otherwise falls back to
                ``"cpu"``. The previous version always requested CUDA and
                failed on CPU-only hosts.
        """
        super().__init__()
        self.model, self.preprocess = clip.load(
            "ViT-B/32", device=self._resolve_device(device)
        )
        # Upsample by 7, then average-pool with kernel (and stride)
        # stylegan_size // 32. For a stylegan_size x stylegan_size input
        # whose side is a multiple of 32, the output side is
        # 7 * stylegan_size / (stylegan_size / 32) = 224. This presumably
        # targets CLIP ViT-B/32's 224x224 input resolution (TODO confirm).
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)

    @staticmethod
    def _resolve_device(device=None):
        """Return *device* unchanged, or the best available default if None."""
        if device is not None:
            return device
        return "cuda" if torch.cuda.is_available() else "cpu"

    def forward(self, image, text):
        """Return ``1 - CLIP(image, text)[0] / 100`` after resizing *image*.

        Args:
            image: image batch, presumably ``(N, C, stylegan_size,
                stylegan_size)`` generator output (TODO confirm value range).
                No CLIP mean/std normalization is applied here.
            text: text tokens, presumably produced by ``clip.tokenize`` and
                already on the model's device.

        Returns:
            The first element returned by ``self.model(image, text)``, mapped
            through ``1 - x / 100``. Larger CLIP logits give a smaller loss.
        """
        image = self.avg_pool(self.upsample(image))
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity