"""
SPLADE Model for HuggingFace Hub

Adapted from: https://github.com/naver/splade
"""

import torch
from transformers import AutoModelForMaskedLM, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput


class SpladeConfig(PretrainedConfig):
    """Configuration class for SPLADE model.

    Args:
        base_model: Name or path of the checkpoint handed to
            ``AutoModelForMaskedLM.from_pretrained`` when the model is built.
        aggregation: Pooling applied over the sequence dimension. ``"max"``
            selects max-pooling; any other value selects sum-pooling.
        fp16: Stored on the config and on the model. Nothing in this file
            reads it, so it has no effect on the computation here.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "splade"

    def __init__(
        self,
        base_model="neuralmind/bert-base-portuguese-cased",
        aggregation="max",
        fp16=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.aggregation = aggregation
        self.fp16 = fp16


class Splade(PreTrainedModel):
    """
    SPLADE model for sparse retrieval.

    This model produces sparse representations by:
    1. Using a MLM head to get vocabulary-sized logits
    2. Applying log(1 + ReLU(logits))
    3. Pooling over sequence length (max by default, otherwise sum)

    Usage:
        from transformers import AutoTokenizer
        from modeling_splade import Splade

        model = Splade.from_pretrained("AxelPCG/splade-pt-br")
        tokenizer = AutoTokenizer.from_pretrained("AxelPCG/splade-pt-br")

        # Encode query
        query_tokens = tokenizer("Qual é a capital do Brasil?", return_tensors="pt")
        with torch.no_grad():
            query_vec = model(q_kwargs=query_tokens)["q_rep"]
    """

    config_class = SpladeConfig

    def __init__(self, config):
        """Build the model from ``config``.

        Missing config attributes fall back to the same defaults that
        ``SpladeConfig`` declares.
        """
        super().__init__(config)
        self.config = config

        # Load base BERT model with MLM head
        base_model = getattr(config, 'base_model', 'neuralmind/bert-base-portuguese-cased')
        self.transformer = AutoModelForMaskedLM.from_pretrained(base_model)
        self.aggregation = getattr(config, 'aggregation', 'max')
        self.fp16 = getattr(config, 'fp16', True)

    def encode(self, tokens) -> torch.Tensor:
        """Encode tokens to sparse representation.

        Args:
            tokens: Mapping of keyword arguments for the underlying masked-LM
                model (e.g. ``input_ids`` and, optionally, ``attention_mask``).

        Returns:
            Tensor obtained by reducing the per-position activations over
            dimension 1, i.e. one vocabulary-sized vector per batch item.
        """
        # Get MLM logits
        out = self.transformer(**tokens)
        logits = out.logits  # shape (bs, seq_len, vocab_size)

        # Apply log(1 + ReLU(x)); the result is non-negative everywhere.
        activations = torch.log1p(torch.relu(logits))

        # Zero out padded positions. The mask is optional: when the caller
        # did not supply one, every position takes part in the pooling
        # instead of failing with a KeyError.
        attention_mask = tokens.get("attention_mask")
        if attention_mask is not None:
            # Cast so the product keeps the activations' dtype.
            mask = attention_mask.unsqueeze(-1).to(activations.dtype)
            activations = activations * mask

        # Aggregate (max or sum). Because activations are >= 0, positions
        # zeroed by the mask can never win the max.
        if self.aggregation == "max":
            return torch.max(activations, dim=1).values
        return torch.sum(activations, dim=1)

    def forward(self, q_kwargs=None, d_kwargs=None, **kwargs) -> dict:
        """
        Forward pass supporting both query and document encoding.

        Args:
            q_kwargs: Query tokens (dict with input_ids, attention_mask)
            d_kwargs: Document tokens (dict with input_ids, attention_mask)
            **kwargs: Used as the token dict only when neither q_kwargs nor
                d_kwargs is given; the result is then stored under 'rep'.

        Returns:
            dict with 'q_rep' and/or 'd_rep' keys containing sparse vectors,
            or a single 'rep' key for the kwargs fallback. The dict is empty
            when no inputs at all are provided.
        """
        output = {}

        if q_kwargs is not None:
            output["q_rep"] = self.encode(q_kwargs)

        if d_kwargs is not None:
            output["d_rep"] = self.encode(d_kwargs)

        # If neither q_kwargs nor d_kwargs, use kwargs directly
        if not output and kwargs:
            output["rep"] = self.encode(kwargs)

        return output