| from .zero_neuron import ZeroNeuron | |
| from sentence_transformers.models import Module | |
| import torch | |
class SparseEmbedding(Module):
    """Attach a learned sparsity mask and its penalty to the sentence embedding.

    This module should be applied last (after Pooling, Normalize, etc.): it
    reads ``features["sentence_embedding"]``, so it must come after whatever
    module produces that key.

    The gating is delegated to a ``ZeroNeuron`` of width ``n_in``.
    NOTE(review): the constructor arguments (init_mean / init_std /
    temperature / stretch / eps) and the ``l0_norm`` method suggest a
    hard-concrete style L0 gate; confirm against ``zero_neuron.py``.

    Args:
        n_in: Passed as both ``in_features`` and ``out_features`` of the gate;
            presumably must equal the sentence-embedding dimension.
        init_mean: Forwarded unchanged to ``ZeroNeuron``.
        init_std: Forwarded unchanged to ``ZeroNeuron``.
        temperature: Forwarded unchanged to ``ZeroNeuron``.
        stretch: Forwarded unchanged to ``ZeroNeuron``.
        eps: Forwarded unchanged to ``ZeroNeuron``.
    """

    # Names of the constructor arguments written out by ``save_config``;
    # each must exist as an attribute of the same name (set in __init__).
    config_keys = ["n_in", "init_mean", "init_std", "temperature", "stretch", "eps"]

    def __init__(
        self,
        n_in: int,
        init_mean: float = 0.5,
        init_std: float = 0.01,
        temperature: float = 1.0,
        stretch: float = 0.1,
        eps: float = 1e-6,
    ):
        super().__init__()
        # Stored under the names listed in ``config_keys``.
        self.n_in = n_in
        self.init_mean = init_mean
        self.init_std = init_std
        self.temperature = temperature
        self.stretch = stretch
        self.eps = eps
        self.sparsifyer = ZeroNeuron(
            in_features=n_in,
            out_features=n_in,
            init_mean=init_mean,
            init_std=init_std,
            temperature=temperature,
            stretch=stretch,
            eps=eps,
        )

    def forward(self, features, *args, **kwargs):
        """Compute the gate mask and sparsity penalty for the sentence embedding.

        Args:
            features: Feature dict from the preceding modules; must contain
                ``"sentence_embedding"``.
            *args: Ignored.
            **kwargs: An optional ``dim`` is forwarded to the gate (``None``
                when absent); other keys are ignored.

        Returns:
            The same ``features`` dict, updated in place with ``"mask"`` (the
            gate's output) and ``"sparsity_loss"`` (``l0_norm`` of the
            embedding).
        """
        embedding = features["sentence_embedding"]
        # NOTE(review): the mask is only stored, not multiplied into
        # ``sentence_embedding`` here; presumably a downstream consumer
        # applies it — confirm.
        features["mask"] = self.sparsifyer(embedding, dim=kwargs.get("dim"))
        features["sparsity_loss"] = self.sparsifyer.l0_norm(embedding)
        return features

    def save(self, output_path: str, *args, **kwargs):
        """Write the module config and the gate's weights into ``output_path``.

        Args:
            output_path: Directory to write into (``str`` or ``os.PathLike``).
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored, so callers that pass extra options
                (e.g. ``safe_serialization``) do not raise ``TypeError``.
                Weights are always written with ``torch.save`` as
                ``pytorch_model.bin``.
        """
        import os  # local import keeps the file's import block unchanged

        self.save_config(output_path)
        # NOTE(review): this saves ``self.sparsifyer``'s own state dict, so its
        # keys carry no ``sparsifyer.`` prefix; a loader must restore it into
        # ``self.sparsifyer``, not into ``self``.
        torch.save(
            self.sparsifyer.state_dict(),
            os.path.join(output_path, "pytorch_model.bin"),
        )