# splade-pt-br / modeling_splade.py
# Uploaded by AxelPCG — SPLADE-PT-BR model v1.0.0 (commit 688ac07, verified)
"""
SPLADE Model for HuggingFace Hub
Adapted from: https://github.com/naver/splade
"""
import torch
from transformers import AutoModelForMaskedLM, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
class SpladeConfig(PretrainedConfig):
    """Configuration for the SPLADE sparse-retrieval model.

    Attributes:
        base_model: HF Hub id of the masked-LM backbone to load.
        aggregation: Pooling applied over the sequence dimension
            ("max" or "sum").
        fp16: Flag carried in the config for downstream consumers;
            not read by the modeling code in this module.
    """
    model_type = "splade"

    def __init__(
        self,
        base_model="neuralmind/bert-base-portuguese-cased",
        aggregation="max",
        fp16=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Store the SPLADE-specific fields verbatim on the config object.
        self.base_model, self.aggregation, self.fp16 = base_model, aggregation, fp16
class Splade(PreTrainedModel):
    """SPLADE model for sparse retrieval.

    Builds vocabulary-sized sparse representations by:
      1. running the backbone's MLM head to get per-token logits,
      2. activating with log(1 + ReLU(logits)),
      3. pooling over the sequence dimension (max by default, else sum).

    Usage:
        from transformers import AutoTokenizer
        from modeling_splade import Splade
        model = Splade.from_pretrained("AxelPCG/splade-pt-br")
        tokenizer = AutoTokenizer.from_pretrained("AxelPCG/splade-pt-br")
        # Encode query
        query_tokens = tokenizer("Qual é a capital do Brasil?", return_tensors="pt")
        with torch.no_grad():
            query_vec = model(q_kwargs=query_tokens)["q_rep"]
    """
    config_class = SpladeConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Backbone with MLM head; defaults to the Portuguese BERT checkpoint
        # when the config predates the `base_model` field.
        backbone_name = getattr(config, 'base_model', 'neuralmind/bert-base-portuguese-cased')
        self.transformer = AutoModelForMaskedLM.from_pretrained(backbone_name)
        self.aggregation = getattr(config, 'aggregation', 'max')
        # NOTE(review): stored but never read anywhere in this module —
        # presumably for callers that decide on half-precision inference.
        self.fp16 = getattr(config, 'fp16', True)

    def encode(self, tokens):
        """Map a tokenized batch to sparse vocabulary-sized vectors.

        Args:
            tokens: dict with at least `input_ids` and `attention_mask`,
                e.g. a tokenizer output with return_tensors="pt".

        Returns:
            Tensor of shape (batch, vocab_size).
        """
        mlm_logits = self.transformer(**tokens).logits  # (bs, seq_len, vocab_size)
        # log(1 + ReLU) yields sparse, non-negative activations.
        activations = torch.log1p(torch.relu(mlm_logits))
        # Zero out padding positions before pooling (safe for max pooling
        # too, since the activations are non-negative).
        activations = activations * tokens["attention_mask"].unsqueeze(-1)
        if self.aggregation == "max":
            return activations.max(dim=1).values
        # Any other configured value falls through to sum pooling.
        return activations.sum(dim=1)

    def forward(self, q_kwargs=None, d_kwargs=None, **kwargs):
        """Encode a query, a document, or both.

        Args:
            q_kwargs: Query tokens (dict with input_ids, attention_mask).
            d_kwargs: Document tokens (dict with input_ids, attention_mask).
            **kwargs: Fallback token dict used only when neither q_kwargs
                nor d_kwargs is given (for compatibility).

        Returns:
            dict with 'q_rep' and/or 'd_rep' keys (or 'rep' for the
            fallback path) containing the sparse vectors.
        """
        reps = {}
        if q_kwargs is not None:
            reps["q_rep"] = self.encode(q_kwargs)
        if d_kwargs is not None:
            reps["d_rep"] = self.encode(d_kwargs)
        # Neither side given: treat the remaining kwargs as a token dict.
        if not reps and kwargs:
            reps["rep"] = self.encode(kwargs)
        return reps