| import torch |
| from torch import nn |
| import numpy as np |
|
|
|
|
class SupportSets(nn.Module):
    """A bank of RBF-dipole "warping" functions defined either in a GAN's latent
    space (random antipodal support vectors on spheres of varying radii) or in
    CLIP's text feature space (support vectors fixed to prompt features).

    Each support set consists of `num_support_dipoles` pairs of antipodal
    support vectors with alternating weights (+1, -1); `forward` evaluates the
    (normalized) gradient of the induced RBF mixture at a batch of points.
    """
    def __init__(self, prompt_features=None, num_support_sets=None, num_support_dipoles=None, support_vectors_dim=None,
                 lss_beta=0.5, css_beta=0.5, jung_radius=None):
        """SupportSets class constructor.

        Args:
            prompt_features (torch.Tensor) : CLIP text feature statistics of prompts from the given corpus
                                             (expected shape: (num_prompts, 2, clip_dim))
            num_support_sets (int)         : number of support sets (each one defining a warping function)
            num_support_dipoles (int)      : number of support dipoles per support set (per warping function)
            support_vectors_dim (int)      : dimensionality of support vectors (latent space dimensionality, z_dim)
            lss_beta (float)               : set beta parameter for initializing latent space RBFs' gamma parameters
                                             (0.25 < lss_beta < 1.0)
            css_beta (float)               : set beta parameter for fixing CLIP space RBFs' gamma parameters
                                             (0.25 <= css_beta < 1.0)
            jung_radius (float)            : radius of the minimum enclosing ball of a set of 10K latent codes

        Raises:
            ValueError: if `prompt_features` is None and any of `num_support_sets`,
                `num_support_dipoles`, `support_vectors_dim`, or `jung_radius` is missing.
        """
        super(SupportSets, self).__init__()
        self.prompt_features = prompt_features

        if self.prompt_features is not None:
            ############################################################################################################
            ##                   Support sets in CLIP's text space: one dipole per prompt pair                        ##
            ############################################################################################################
            self.num_support_sets = self.prompt_features.shape[0]
            self.num_support_dipoles = 1
            self.support_vectors_dim = self.prompt_features.shape[2]
            self.css_beta = css_beta

            # Support vectors are the prompt features themselves; they are frozen
            # (requires_grad=False) but kept as an nn.Parameter so they appear in
            # state_dict and move with the module, as in the original design.
            self.SUPPORT_SETS = nn.Parameter(data=torch.ones(self.num_support_sets,
                                                             2 * self.num_support_dipoles * self.support_vectors_dim),
                                             requires_grad=False)
            self.SUPPORT_SETS.data = self.prompt_features.reshape(self.prompt_features.shape[0],
                                                                  self.prompt_features.shape[1] *
                                                                  self.prompt_features.shape[2]).clone()

            # Alternating dipole weights (+1, -1, +1, -1, ...) per support set.
            # Registered as a NON-persistent buffer so it follows .to(device)/.cuda()
            # (a plain tensor attribute would stay on CPU and break forward() on GPU)
            # while keeping state_dict keys identical to the original implementation.
            self.register_buffer(
                "ALPHAS",
                torch.tensor([1.0, -1.0] * self.num_support_dipoles).repeat(self.num_support_sets, 1),
                persistent=False)

            # Fixed (non-trainable) log-gammas: gamma is chosen so that each RBF
            # decays to css_beta at the opposite pole of its dipole.
            self.LOGGAMMA = nn.Parameter(data=torch.ones(self.num_support_sets, 1), requires_grad=False)
            for k in range(self.num_support_sets):
                g = -np.log(self.css_beta) / (self.prompt_features[k, 1] - self.prompt_features[k, 0]).norm() ** 2
                self.LOGGAMMA.data[k] = torch.log(torch.Tensor([g]))

        else:
            ############################################################################################################
            ##            Support sets in the GAN's latent space: random antipodal vectors on spheres                 ##
            ############################################################################################################
            if num_support_sets is None:
                raise ValueError("Number of latent support sets not defined.")
            else:
                self.num_support_sets = num_support_sets
            if num_support_dipoles is None:
                raise ValueError("Number of latent support dipoles not defined.")
            else:
                self.num_support_dipoles = num_support_dipoles
            if support_vectors_dim is None:
                raise ValueError("Latent support vector dimensionality not defined.")
            else:
                self.support_vectors_dim = support_vectors_dim
            if jung_radius is None:
                raise ValueError("Jung radius not given.")
            else:
                self.jung_radius = jung_radius
            self.lss_beta = lss_beta

            # One sphere radius per support set, evenly spread around the Jung radius.
            self.r_min = 0.90 * self.jung_radius
            self.r_max = 1.25 * self.jung_radius
            self.radii = torch.arange(self.r_min, self.r_max, (self.r_max - self.r_min) / self.num_support_sets)

            # Trainable support vectors, initialised as random antipodal pairs
            # (SV, -SV) projected onto the sphere of radius radii[k].
            self.SUPPORT_SETS = nn.Parameter(data=torch.ones(self.num_support_sets,
                                                             2 * self.num_support_dipoles * self.support_vectors_dim))
            SUPPORT_SETS = torch.zeros(self.num_support_sets, 2 * self.num_support_dipoles, self.support_vectors_dim)
            for k in range(self.num_support_sets):
                SV_set = []
                for i in range(self.num_support_dipoles):
                    SV = torch.randn(1, self.support_vectors_dim)
                    SV_set.extend([SV, -SV])
                SV_set = torch.cat(SV_set)
                SV_set = self.radii[k] * SV_set / torch.norm(SV_set, dim=1, keepdim=True)
                SUPPORT_SETS[k, :] = SV_set
            self.SUPPORT_SETS.data = SUPPORT_SETS.reshape(
                self.num_support_sets, 2 * self.num_support_dipoles * self.support_vectors_dim).clone()

            # Alternating dipole weights (+1, -1, ...) — see note in the CLIP branch:
            # registered as a non-persistent buffer so it moves with the module.
            self.register_buffer(
                "ALPHAS",
                torch.tensor([1.0, -1.0] * self.num_support_dipoles).repeat(self.num_support_sets, 1),
                persistent=False)

            # Trainable log-gammas, initialised so each RBF decays to lss_beta at
            # the antipode (distance 2 * radius).
            self.LOGGAMMA = nn.Parameter(data=torch.ones(self.num_support_sets, 1))
            for k in range(self.num_support_sets):
                g = -np.log(self.lss_beta) / ((2 * self.radii[k]) ** 2)
                self.LOGGAMMA.data[k] = torch.log(torch.Tensor([g]))

    def forward(self, support_sets_mask, z):
        """Evaluate the warping directions at a batch of points.

        Args:
            support_sets_mask (torch.Tensor) : selection/mixing weights over the
                support sets, shape (batch, num_support_sets) — typically one-hot
            z (torch.Tensor)                 : query points, shape (batch, support_vectors_dim)

        Returns:
            torch.Tensor: unit-norm gradient of the selected RBF mixture at each
            point, shape (batch, support_vectors_dim).
        """
        # Gather the support vectors of the selected set(s) for each batch item.
        support_sets_batch = torch.matmul(support_sets_mask, self.SUPPORT_SETS)
        support_sets_batch = support_sets_batch.reshape(-1, 2 * self.num_support_dipoles, self.support_vectors_dim)

        # Corresponding dipole weights and (exponentiated) gamma parameters.
        alphas_batch = torch.matmul(support_sets_mask, self.ALPHAS).unsqueeze(dim=2)
        gammas_batch = torch.exp(torch.matmul(support_sets_mask, self.LOGGAMMA).unsqueeze(dim=2))

        # Pairwise differences between each query point and each support vector.
        D = z.unsqueeze(dim=1).repeat(1, 2 * self.num_support_dipoles, 1) - support_sets_batch

        # Gradient of f(z) = sum_i alpha_i * exp(-gamma_i * ||z - sv_i||^2).
        grad_f = -2 * (alphas_batch * gammas_batch *
                       torch.exp(-gammas_batch * (torch.norm(D, dim=2) ** 2).unsqueeze(dim=2)) * D).sum(dim=1)

        # Return unit-norm directions.
        return grad_f / torch.norm(grad_f, dim=1, keepdim=True)
|
|