File size: 1,509 Bytes
6a48e45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os

import torch
from sentence_transformers.models import Module

from .zero_neuron import ZeroNeuron

class SparseEmbedding(Module):
    """Learnable sparsification head for sentence embeddings.

    Wraps a ``ZeroNeuron`` layer that, given a pooled sentence embedding,
    produces a sparsity mask and an L0-style penalty which are attached to
    the feature dict for downstream use (e.g. in a sparsity loss term).

    This module should be applied last (after Pooling, Normalize, etc.).
    """

    # Constructor keyword arguments persisted by save_config() so the
    # module can be reconstructed on load.
    config_keys = ["n_in", "init_mean", "init_std", "temperature", "stretch", "eps"]

    def __init__(self,
                 n_in: int,
                 init_mean: float = 0.5,
                 init_std: float = 0.01,
                 temperature: float = 1.0,
                 stretch: float = 0.1,
                 eps: float = 1e-6) -> None:
        """Create the sparsification layer.

        Args:
            n_in: Dimensionality of the incoming sentence embedding; the
                ZeroNeuron maps n_in -> n_in, so the mask has this size.
            init_mean: Initial mean for the ZeroNeuron gate parameters.
            init_std: Initial std for the ZeroNeuron gate parameters.
            temperature: Relaxation temperature forwarded to ZeroNeuron.
            stretch: Stretch factor forwarded to ZeroNeuron.
            eps: Numerical-stability constant forwarded to ZeroNeuron.
        """
        super().__init__()  # modern zero-arg form of super(SparseEmbedding, self)
        self.n_in = n_in
        self.init_mean = init_mean
        self.init_std = init_std
        self.temperature = temperature
        self.stretch = stretch
        self.eps = eps
        self.sparsifyer = ZeroNeuron(
            in_features=n_in,
            out_features=n_in,
            init_mean=init_mean,
            init_std=init_std,
            temperature=temperature,
            stretch=stretch,
            eps=eps,
        )

    def forward(self, features, *args, **kwargs):
        """Attach a sparsity mask and an L0 penalty to the feature dict.

        Args:
            features: Feature dict containing "sentence_embedding";
                mutated in place.
            **kwargs: An optional "dim" entry is forwarded to the
                ZeroNeuron call (defaults to None).

        Returns:
            The same dict with "mask" and "sparsity_loss" entries added.

        NOTE(review): the mask is recorded but not multiplied into
        "sentence_embedding" here — presumably the loss or a later stage
        applies it; confirm this is intentional.
        """
        mask = self.sparsifyer(features["sentence_embedding"], dim=kwargs.get("dim", None))
        features["mask"] = mask
        features["sparsity_loss"] = self.sparsifyer.l0_norm(features["sentence_embedding"])

        return features

    def save(self, output_path: str) -> None:
        """Persist the module config and ZeroNeuron weights.

        Args:
            output_path: Directory the config and "pytorch_model.bin"
                weight file are written into.
        """
        self.save_config(output_path)
        # os.path.join instead of "/" concatenation: portable across
        # platforms and robust to a trailing separator in output_path.
        torch.save(self.sparsifyer.state_dict(),
                   os.path.join(output_path, "pytorch_model.bin"))