File size: 3,433 Bytes
688ac07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
SPLADE Model for HuggingFace Hub
Adapted from: https://github.com/naver/splade
"""

import torch
from transformers import AutoModelForMaskedLM, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput


class SpladeConfig(PretrainedConfig):
    """Configuration for the SPLADE sparse-retrieval model.

    Attributes:
        base_model: HF Hub id of the underlying masked-LM checkpoint.
        aggregation: pooling over the sequence axis, "max" or "sum".
        fp16: whether half precision is intended (stored on the config;
            not applied anywhere in this module).
    """
    model_type = "splade"

    def __init__(
        self,
        base_model="neuralmind/bert-base-portuguese-cased",
        aggregation="max",
        fp16=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        # SPLADE-specific knobs kept as plain attributes so the model can
        # read them back with getattr(config, ...).
        self.fp16 = fp16
        self.aggregation = aggregation
        self.base_model = base_model


class Splade(PreTrainedModel):
    """
    SPLADE model for sparse retrieval.

    This model produces sparse representations by:
    1. Using a MLM head to get vocabulary-sized logits
    2. Applying log(1 + ReLU(logits))
    3. Max-pooling over sequence length

    Usage:
        from transformers import AutoTokenizer
        from modeling_splade import Splade

        model = Splade.from_pretrained("AxelPCG/splade-pt-br")
        tokenizer = AutoTokenizer.from_pretrained("AxelPCG/splade-pt-br")

        # Encode query
        query_tokens = tokenizer("Qual é a capital do Brasil?", return_tensors="pt")
        with torch.no_grad():
            query_vec = model(q_kwargs=query_tokens)["q_rep"]
    """
    config_class = SpladeConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # getattr with defaults keeps older configs (missing these fields)
        # loadable; the underlying checkpoint must carry an MLM head.
        checkpoint = getattr(config, 'base_model', 'neuralmind/bert-base-portuguese-cased')
        self.transformer = AutoModelForMaskedLM.from_pretrained(checkpoint)
        self.aggregation = getattr(config, 'aggregation', 'max')
        self.fp16 = getattr(config, 'fp16', True)

    def encode(self, tokens):
        """Turn tokenized text into a vocabulary-sized sparse vector.

        Args:
            tokens: dict with at least ``input_ids`` and ``attention_mask``.

        Returns:
            Tensor of shape (batch, vocab_size) after log1p/ReLU and
            mask-aware pooling over the sequence dimension.
        """
        # MLM logits: (batch, seq_len, vocab_size)
        logits = self.transformer(**tokens).logits

        # log(1 + ReLU(x)) keeps activations non-negative and dampens
        # large logits, then zero out padding positions via the mask.
        sparse = torch.log1p(torch.relu(logits))
        sparse = sparse * tokens["attention_mask"].unsqueeze(-1)

        # Pool over the sequence axis; any aggregation other than "max"
        # falls through to sum (matching the config default semantics).
        if self.aggregation == "max":
            return sparse.max(dim=1).values
        return sparse.sum(dim=1)

    def forward(self, q_kwargs=None, d_kwargs=None, **kwargs):
        """
        Forward pass supporting both query and document encoding.

        Args:
            q_kwargs: Query tokens (dict with input_ids, attention_mask)
            d_kwargs: Document tokens (dict with input_ids, attention_mask)
            **kwargs: Additional arguments (for compatibility)

        Returns:
            dict with 'q_rep' and/or 'd_rep' keys containing sparse vectors
        """
        reps = {}

        if q_kwargs is not None:
            reps["q_rep"] = self.encode(q_kwargs)

        if d_kwargs is not None:
            reps["d_rep"] = self.encode(d_kwargs)

        # Fallback: tokens passed directly as keyword arguments.
        if not reps and kwargs:
            reps["rep"] = self.encode(kwargs)

        return reps