|
|
""" |
|
|
SPLADE Model for HuggingFace Hub |
|
|
Adapted from: https://github.com/naver/splade |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForMaskedLM, PreTrainedModel, PretrainedConfig |
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
|
|
|
|
|
|
class SpladeConfig(PretrainedConfig):
    """Configuration container for the SPLADE sparse-retrieval model.

    Attributes stored on the config:
        base_model:  Hub id of the masked-LM backbone to load.
        aggregation: Sequence pooling strategy; "max" selects max-pooling,
                     anything else falls back to sum-pooling in the model.
        fp16:        Flag recorded on the config (intended precision hint).
    """

    model_type = "splade"

    def __init__(
        self,
        base_model="neuralmind/bert-base-portuguese-cased",
        aggregation="max",
        fp16=True,
        **kwargs,
    ):
        # Let the HF base class consume any shared config kwargs first.
        super().__init__(**kwargs)
        self.base_model = base_model
        self.aggregation = aggregation
        self.fp16 = fp16
|
|
|
|
|
|
|
|
class Splade(PreTrainedModel):
    """SPLADE model for sparse retrieval.

    Sparse vocabulary-sized representations are produced by:
    1. Running the backbone's MLM head to obtain vocabulary-sized logits
    2. Applying log(1 + ReLU(logits))
    3. Pooling over the sequence length (max by default)

    Usage:
        from transformers import AutoTokenizer
        from modeling_splade import Splade

        model = Splade.from_pretrained("AxelPCG/splade-pt-br")
        tokenizer = AutoTokenizer.from_pretrained("AxelPCG/splade-pt-br")

        # Encode query
        query_tokens = tokenizer("Qual é a capital do Brasil?", return_tensors="pt")
        with torch.no_grad():
            query_vec = model(q_kwargs=query_tokens)["q_rep"]
    """

    config_class = SpladeConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # getattr fallbacks keep older configs (missing these fields) loadable.
        backbone = getattr(config, 'base_model', 'neuralmind/bert-base-portuguese-cased')
        self.transformer = AutoModelForMaskedLM.from_pretrained(backbone)
        self.aggregation = getattr(config, 'aggregation', 'max')
        self.fp16 = getattr(config, 'fp16', True)

    def encode(self, tokens):
        """Turn a tokenized batch into sparse vocabulary-sized vectors."""
        mlm_logits = self.transformer(**tokens).logits

        # SPLADE activation: log-saturated ReLU -> sparse, non-negative weights.
        activations = torch.log1p(torch.relu(mlm_logits))

        # Zero out padding positions so they cannot contribute to pooling
        # (safe for max-pooling too, since activations are non-negative).
        pad_mask = tokens["attention_mask"].unsqueeze(-1)
        activations = activations * pad_mask

        if self.aggregation == "max":
            return activations.max(dim=1).values
        # Any non-"max" setting sums over the sequence dimension.
        return activations.sum(dim=1)

    def forward(self, q_kwargs=None, d_kwargs=None, **kwargs):
        """
        Forward pass supporting both query and document encoding.

        Args:
            q_kwargs: Query tokens (dict with input_ids, attention_mask)
            d_kwargs: Document tokens (dict with input_ids, attention_mask)
            **kwargs: Tokenizer output passed directly (compatibility path)

        Returns:
            dict with 'q_rep' and/or 'd_rep' keys containing sparse vectors;
            or {'rep': ...} when tokens arrive via **kwargs only.
        """
        reps = {}

        if q_kwargs is not None:
            reps["q_rep"] = self.encode(q_kwargs)

        if d_kwargs is not None:
            reps["d_rep"] = self.encode(d_kwargs)

        # Fallback: treat raw keyword arguments as a single tokenized batch.
        if not reps and kwargs:
            reps["rep"] = self.encode(kwargs)

        return reps
|
|
|