import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_
from sklearn.cluster import KMeans

class MLPLayers(nn.Module):
    """Stack of fully connected layers: Dropout -> Linear -> (BatchNorm1d) -> activation.

    Args:
        layers (list of int): layer sizes; e.g. [64, 32, 16] builds 64->32 and 32->16.
        dropout (float): dropout probability applied before every linear layer.
        activation (str or nn.Module subclass): activation inserted between layers;
            the final layer is left without an activation.
        bn (bool): whether to add BatchNorm1d after each linear layer.
    """

    def __init__(self, layers, dropout=0.0, activation="relu", bn=False):
        super(MLPLayers, self).__init__()
        self.layers = layers
        self.dropout = dropout
        self.activation = activation
        self.use_bn = bn

        mlp_modules = []
        for idx, (input_size, output_size) in enumerate(
            zip(self.layers[:-1], self.layers[1:])
        ):
            mlp_modules.append(nn.Dropout(p=self.dropout))
            mlp_modules.append(nn.Linear(input_size, output_size))
            if self.use_bn:
                mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
            activation_func = activation_layer(self.activation, output_size)
            # Skip the activation after the final layer so the output stays unbounded.
            if activation_func is not None and idx != (len(self.layers) - 2):
                mlp_modules.append(activation_func)

        self.mlp_layers = nn.Sequential(*mlp_modules)
        self.apply(self.init_weights)

    def init_weights(self, module):
        # Xavier (Glorot) normal initialization for linear weights, zeros for biases.
        if isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    def forward(self, input_feature):
        return self.mlp_layers(input_feature)

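
# Usage sketch (illustrative, not part of the original module): a two-layer MLP
# mapping 64-dim features to 16-dim outputs.
#   mlp = MLPLayers([64, 32, 16], dropout=0.1, activation="relu", bn=True)
#   out = mlp(torch.randn(8, 64))  # shape: (8, 16)
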
def activation_layer(activation_name="relu", emb_dim=None):
    """Build an activation module from a name (or an nn.Module subclass).

    `emb_dim` is accepted for interface compatibility but is unused here.
    """
    if activation_name is None:
        activation = None
    elif isinstance(activation_name, str):
        if activation_name.lower() == "sigmoid":
            activation = nn.Sigmoid()
        elif activation_name.lower() == "tanh":
            activation = nn.Tanh()
        elif activation_name.lower() == "relu":
            activation = nn.ReLU()
        elif activation_name.lower() == "leakyrelu":
            activation = nn.LeakyReLU()
        elif activation_name.lower() == "none":
            activation = None
        else:
            raise NotImplementedError(
                "activation function {} is not implemented".format(activation_name)
            )
    elif issubclass(activation_name, nn.Module):
        activation = activation_name()
    else:
        raise NotImplementedError(
            "activation function {} is not implemented".format(activation_name)
        )

    return activation

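
# Example (illustrative): both calls return an nn.ReLU instance.
#   act = activation_layer("relu")
#   act = activation_layer(nn.ReLU)
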
def kmeans(samples, num_clusters, num_iters=10):
    """Run scikit-learn KMeans on a tensor and return the cluster centers
    as a tensor on the same device as `samples`."""
    device = samples.device
    # scikit-learn works on NumPy arrays, so detach and move the data to the CPU.
    x = samples.cpu().detach().numpy()

    cluster = KMeans(n_clusters=num_clusters, max_iter=num_iters).fit(x)

    # Note: cluster_centers_ is float64, so the returned tensor is float64 too.
    centers = cluster.cluster_centers_
    tensor_centers = torch.from_numpy(centers).to(device)

    return tensor_centers

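
# Example (illustrative): a 32-entry codebook from 1000 embeddings of dim 64.
#   codebook = kmeans(torch.randn(1000, 64), num_clusters=32)  # shape: (32, 64)
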
@torch.no_grad()
def sinkhorn_algorithm(distances, epsilon, sinkhorn_iterations):
    """Sinkhorn-Knopp normalization: turn a (B, K) distance matrix into a soft
    assignment matrix whose rows and columns are approximately balanced.

    Smaller `epsilon` yields sharper (closer to one-hot) assignments.
    """
    # Gibbs kernel: small distances become large scores.
    Q = torch.exp(-distances / epsilon)

    B = Q.shape[0]  # number of samples
    K = Q.shape[1]  # number of clusters / prototypes

    # Normalize so that all entries of Q sum to 1.
    sum_Q = Q.sum(-1, keepdim=True).sum(-2, keepdim=True)
    Q /= sum_Q

    for it in range(sinkhorn_iterations):
        # Row step: each sample's assignments sum to 1/B.
        Q /= torch.sum(Q, dim=1, keepdim=True)
        Q /= B

        # Column step: each cluster receives total mass 1/K.
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= K

    # Rescale so each row approximately sums to 1, i.e. rows are
    # per-sample assignment distributions.
    Q *= B
    return Q
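
# Example (illustrative): balanced soft assignments of 8 samples to 4 prototypes.
#   d = torch.cdist(torch.randn(8, 16), torch.randn(4, 16))  # (8, 4) distances
#   q = sinkhorn_algorithm(d, epsilon=0.5, sinkhorn_iterations=3)
#   q.sum(dim=1)  # each row sums to ~1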