import os

import torch
from sentence_transformers.models import Module

from .zero_neuron import ZeroNeuron


class SparseEmbedding(Module):
    """Attach a ``ZeroNeuron`` mask and sparsity loss to sentence embeddings.

    This module should be applied last (after Pooling, Normalize, etc.).

    All constructor arguments are stored on the instance under the names
    listed in ``config_keys`` and are forwarded unchanged to ``ZeroNeuron``.
    NOTE(review): the meaning of ``init_mean``, ``init_std``, ``temperature``,
    ``stretch`` and ``eps`` is defined by ``ZeroNeuron``, which is not visible
    here -- see that class for details.
    """

    # Constructor arguments mirrored as instance attributes; presumably the
    # set of values that ``save_config`` persists -- confirm against the
    # base ``Module`` implementation.
    config_keys = ["n_in", "init_mean", "init_std", "temperature", "stretch", "eps"]

    def __init__(
        self,
        n_in: int,
        init_mean: float = 0.5,
        init_std: float = 0.01,
        temperature: float = 1.0,
        stretch: float = 0.1,
        eps: float = 1e-6,
    ) -> None:
        """Store the configuration and build the ``ZeroNeuron`` sub-module.

        Args:
            n_in: Passed to ``ZeroNeuron`` as both ``in_features`` and
                ``out_features``.
            init_mean: Forwarded to ``ZeroNeuron``.
            init_std: Forwarded to ``ZeroNeuron``.
            temperature: Forwarded to ``ZeroNeuron``.
            stretch: Forwarded to ``ZeroNeuron``.
            eps: Forwarded to ``ZeroNeuron``.
        """
        super().__init__()
        self.n_in = n_in
        self.init_mean = init_mean
        self.init_std = init_std
        self.temperature = temperature
        self.stretch = stretch
        self.eps = eps
        # The attribute name (including its spelling) is part of the module's
        # state-dict / public surface, so it is kept as-is.
        self.sparsifyer = ZeroNeuron(
            in_features=n_in,
            out_features=n_in,
            init_mean=init_mean,
            init_std=init_std,
            temperature=temperature,
            stretch=stretch,
            eps=eps,
        )

    def forward(self, features: dict, *args, **kwargs) -> dict:
        """Add ``"mask"`` and ``"sparsity_loss"`` entries to ``features``.

        Args:
            features: Mapping that must contain ``"sentence_embedding"``.
                It is updated in place.
            *args: Accepted and ignored.
            **kwargs: Only ``dim`` is read; it is forwarded to the
                ``ZeroNeuron`` call and defaults to ``None`` when absent.

        Returns:
            The same ``features`` object, with ``"mask"`` set to the output
            of the ``ZeroNeuron`` call and ``"sparsity_loss"`` set to the
            result of its ``l0_norm`` method, both computed from
            ``features["sentence_embedding"]``.
        """
        embedding = features["sentence_embedding"]
        # The mask is computed before the loss, matching the original
        # call order in case ZeroNeuron keeps internal state between calls.
        features["mask"] = self.sparsifyer(embedding, dim=kwargs.get("dim"))
        features["sparsity_loss"] = self.sparsifyer.l0_norm(embedding)
        return features

    def save(self, output_path: str) -> None:
        """Write the module config and the ``ZeroNeuron`` weights to disk.

        Args:
            output_path: Directory that receives the config (via
                ``save_config``) and a ``pytorch_model.bin`` file holding
                ``self.sparsifyer.state_dict()``.

        NOTE(review): no matching weight-loading code is visible in this
        file -- confirm that the saved state dict is restored on load.
        """
        self.save_config(output_path)
        # os.path.join instead of string concatenation: also accepts
        # os.PathLike values and does not hard-code the separator.
        weights_path = os.path.join(output_path, "pytorch_model.bin")
        torch.save(self.sparsifyer.state_dict(), weights_path)