# sparse_embedding.py
# Provenance: uploaded by kaamd via huggingface_hub (commit 6a48e45, verified)
import os

import torch
from sentence_transformers.models import Module

from .zero_neuron import ZeroNeuron
class SparseEmbedding(Module):
    """Learned sparsification gate for sentence embeddings.

    This module should be applied last in a sentence-transformers pipeline
    (after Pooling, Normalize, etc.). It wraps a :class:`ZeroNeuron` layer
    that produces a per-dimension mask over the pooled sentence embedding
    and an L0-style sparsity loss term.
    """

    # Hyperparameters persisted by save_config() and restored on load.
    config_keys = ["n_in", "init_mean", "init_std", "temperature", "stretch", "eps"]

    def __init__(self,
                 n_in: int,
                 init_mean: float = 0.5,
                 init_std: float = 0.01,
                 temperature: float = 1.0,
                 stretch: float = 0.1,
                 eps: float = 1e-6):
        """Build the sparsification layer.

        Args:
            n_in: Dimensionality of the incoming sentence embedding; the
                ZeroNeuron gate maps n_in -> n_in.
            init_mean: Mean used to initialize the gate parameters.
            init_std: Std-dev used to initialize the gate parameters.
            temperature: Temperature of the gate's relaxation.
            stretch: Stretch factor of the gate's relaxation interval.
            eps: Numerical-stability constant passed to the gate.
        """
        super().__init__()
        self.n_in = n_in
        self.init_mean = init_mean
        self.init_std = init_std
        self.temperature = temperature
        self.stretch = stretch
        self.eps = eps
        # NOTE(review): attribute name "sparsifyer" (sic) is kept as-is —
        # renaming it would invalidate previously saved state dicts.
        self.sparsifyer = ZeroNeuron(
            in_features=n_in,
            out_features=n_in,
            init_mean=init_mean,
            init_std=init_std,
            temperature=temperature,
            stretch=stretch,
            eps=eps
        )

    def forward(self, features: dict, *args, **kwargs) -> dict:
        """Attach a sparsity mask and sparsity loss to the feature dict.

        Args:
            features: Feature dict containing a "sentence_embedding" tensor.
            **kwargs: An optional "dim" entry is forwarded to the gate;
                defaults to None when absent.

        Returns:
            The same dict, with "mask" (gate output) and "sparsity_loss"
            (L0 norm of the gate over the embedding) added in place.
        """
        mask = self.sparsifyer(features["sentence_embedding"], dim=kwargs.get("dim", None))
        features["mask"] = mask
        features["sparsity_loss"] = self.sparsifyer.l0_norm(features["sentence_embedding"])
        return features

    def save(self, output_path: str) -> None:
        """Persist the module config and the gate's weights to *output_path*.

        Writes the config via save_config() plus a "pytorch_model.bin"
        state-dict file inside the directory.
        """
        self.save_config(output_path)
        # os.path.join instead of string concatenation: portable across
        # platforms and robust to trailing separators in output_path.
        torch.save(self.sparsifyer.state_dict(), os.path.join(output_path, "pytorch_model.bin"))