| """ |
| Paraformer model implementation for Hugging Face Transformers. |
| |
| This module implements the Paraformer model for legal document retrieval, |
| based on the paper "Attentive Deep Neural Networks for Legal Document Retrieval". |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import List, Optional, Union, Tuple |
| from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import logging |
|
|
| try: |
| from .configuration_paraformer import ParaformerConfig |
| except ImportError: |
| from configuration_paraformer import ParaformerConfig |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
def sparsemax(input_tensor, dim=-1):
    """
    Sparsemax activation function.

    Projects the scores onto the probability simplex and, unlike softmax,
    can assign exactly zero probability to low-scoring entries.

    Args:
        input_tensor: Input tensor
        dim: Dimension along which to apply sparsemax

    Returns:
        Sparsemax output tensor
    """
    # Sort scores in descending order along the target dimension.
    sorted_input, _ = torch.sort(input_tensor, dim=dim, descending=True)

    # Shifted cumulative sums of the sorted scores.
    input_cumsum = torch.cumsum(sorted_input, dim=dim) - 1

    # k = 1, 2, ..., n reshaped so it broadcasts along `dim`.
    k = torch.arange(1, input_tensor.size(dim) + 1, dtype=input_tensor.dtype, device=input_tensor.device)
    shape = [1] * input_tensor.dim()
    shape[dim] = -1
    k = k.view(shape)

    # Support set: positions where 1 + k * z_(k) > cumulative sum of z_(1..k).
    support = k * sorted_input > input_cumsum

    # Size of the support per row (clamped to avoid division by zero).
    support_size = torch.sum(support.to(input_tensor.dtype), dim=dim, keepdim=True).clamp(min=1)

    # Threshold tau = (sum of supported sorted scores - 1) / support size.
    supported_sum = torch.sum(
        torch.where(support, sorted_input, torch.zeros_like(sorted_input)), dim=dim, keepdim=True
    )
    tau = (supported_sum - 1) / support_size

    # Project: shift by the threshold and clip at zero.
    output = torch.clamp(input_tensor - tau, min=0)

    return output
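
# Illustrative sketch (not used by the model below): on a toy score vector,
# sparsemax assigns exactly zero weight to the low-scoring entries, whereas
# softmax keeps every weight strictly positive.
#
#   scores = torch.tensor([[3.0, 1.0, 0.2, -2.0]])
#   sparsemax(scores, dim=-1)    # tensor([[1., 0., 0., 0.]])
#   F.softmax(scores, dim=-1)    # all four entries > 0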


class ParaformerAttention(nn.Module):
    """
    Attention mechanism for the Paraformer model.

    Implements a general (multiplicative) attention over sentence embeddings,
    with an optional sparsemax activation in place of softmax.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.use_sparsemax = config.use_sparsemax

        if config.attention_type == "general":
            self.attention_weights = nn.Linear(config.hidden_size, 1, bias=False)
        else:
            raise ValueError(f"Unsupported attention type: {config.attention_type}")

    def forward(self, query_embedding, sentence_embeddings, attention_mask=None):
        """
        Apply the attention mechanism.

        Args:
            query_embedding: Query embedding tensor [batch_size, hidden_size]
            sentence_embeddings: Sentence embeddings [batch_size, num_sentences, hidden_size]
            attention_mask: Mask for padding sentences [batch_size, num_sentences]

        Returns:
            attended_output: Weighted combination of sentence embeddings
            attention_weights: Attention weights for interpretability
        """
        batch_size, num_sentences, hidden_size = sentence_embeddings.shape

        # Broadcast the query against every sentence of the article.
        query_expanded = query_embedding.unsqueeze(1).expand(-1, num_sentences, -1)

        # General attention: score each sentence via a learned projection of the
        # element-wise query-sentence interaction.
        combined = query_expanded * sentence_embeddings
        attention_scores = self.attention_weights(combined).squeeze(-1)

        # Mask out padding sentences before normalization.
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))

        # Normalize scores; sparsemax can zero out irrelevant sentences entirely.
        if self.use_sparsemax:
            attention_weights = sparsemax(attention_scores, dim=-1)
        else:
            attention_weights = F.softmax(attention_scores, dim=-1)

        # Pool the sentence embeddings with the attention weights.
        attended_output = torch.sum(attention_weights.unsqueeze(-1) * sentence_embeddings, dim=1)

        return attended_output, attention_weights
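
# Illustrative sketch of the attention module on dummy tensors. `SimpleNamespace`
# stands in for a ParaformerConfig exposing the three fields read above
# (hidden_size, use_sparsemax, attention_type); shapes follow the docstring.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(hidden_size=768, use_sparsemax=True, attention_type="general")
#   attn = ParaformerAttention(cfg)
#   query = torch.randn(2, 768)           # [batch_size, hidden_size]
#   sentences = torch.randn(2, 5, 768)    # [batch_size, num_sentences, hidden_size]
#   pooled, weights = attn(query, sentences)
#   # pooled: [2, 768]; weights: [2, 5], each row sums to 1 (possibly with exact zeros)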


class ParaformerModel(PreTrainedModel):
    """
    Paraformer model for legal document retrieval.

    This model uses a hierarchical approach with an attention mechanism to encode
    legal documents and queries for relevance classification.
    """

    config_class = ParaformerConfig
    base_model_prefix = "paraformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParaformerAttention"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # The SentenceTransformer backbone is loaded lazily (see `sentence_encoder`).
        self._sentence_encoder = None

        # Query-conditioned attention over article sentences.
        self.attention = ParaformerAttention(config)

        # Relevance classification head.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.dropout_prob)

        # Initialize weights and apply any final processing.
        self.post_init()

    @property
    def sentence_encoder(self):
        """Lazily load the SentenceTransformer to avoid meta tensor issues."""
        if self._sentence_encoder is None:
            from sentence_transformers import SentenceTransformer
            self._sentence_encoder = SentenceTransformer(self.config.base_model_name)
        return self._sentence_encoder
    def forward(
        self,
        query_texts: Optional[List[str]] = None,
        article_texts: Optional[List[List[str]]] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """
        Forward pass of the Paraformer model.

        Args:
            query_texts: List of query strings
            article_texts: List of article sentence lists
            labels: Optional labels for training
            return_dict: Whether to return a dictionary

        Returns:
            Model outputs including logits and optional loss
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if query_texts is None or article_texts is None:
            raise ValueError("Both query_texts and article_texts must be provided")

        batch_size = len(query_texts)
        device = next(self.parameters()).device

        # Encode the queries. SentenceTransformer.encode runs without gradients,
        # so the sentence encoder stays frozen; only the attention layer and the
        # classifier receive gradients.
        query_embeddings = self.sentence_encoder.encode(
            query_texts,
            convert_to_tensor=True,
            device=device
        ).clone()

        all_attended_outputs = []
        all_attention_weights = []

        for i, article in enumerate(article_texts):
            if not article:
                # Empty article: fall back to a zero embedding and a dummy weight.
                attended_output = torch.zeros(self.config.hidden_size, device=device)
                attention_weights = torch.zeros(1, device=device)
            else:
                # Encode every sentence of the article.
                sentence_embeddings = self.sentence_encoder.encode(
                    article,
                    convert_to_tensor=True,
                    device=device
                ).clone()

                # [num_sentences, hidden_size] -> [1, num_sentences, hidden_size]
                if sentence_embeddings.dim() == 2:
                    sentence_embeddings = sentence_embeddings.unsqueeze(0)

                # Attend over the article sentences with the matching query.
                attended_output, attention_weights = self.attention(
                    query_embeddings[i:i + 1],
                    sentence_embeddings
                )
                attended_output = attended_output.squeeze(0)
                attention_weights = attention_weights.squeeze(0)

            all_attended_outputs.append(attended_output)
            all_attention_weights.append(attention_weights)

        # [batch_size, hidden_size]
        attended_outputs = torch.stack(all_attended_outputs)

        # Classification head.
        attended_outputs = self.dropout(attended_outputs)
        logits = self.classifier(attended_outputs)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + (all_attention_weights,)
            return ((loss,) + output) if loss is not None else output

        # Articles may contain different numbers of sentences, so the attention
        # weights are returned as a tuple instead of being stacked into one tensor.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=tuple(w.unsqueeze(0) for w in all_attention_weights) if all_attention_weights else None,
        )
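
    # Illustrative sketch of a training-style call (texts are made up). Because
    # the sentence encoder is queried through `.encode()`, gradients only reach
    # the attention layer and the classifier.
    #
    #   outputs = model(
    #       query_texts=["What is the notice period for termination?"],
    #       article_texts=[[
    #           "Either party may terminate the agreement.",
    #           "A notice period of thirty days applies.",
    #       ]],
    #       labels=torch.tensor([1]),
    #       return_dict=True,
    #   )
    #   outputs.loss.backward()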

    def get_relevance_score(self, query: str, article: List[str]) -> float:
        """
        Get the relevance score for a single query-article pair.

        Args:
            query: Query string
            article: List of article sentences

        Returns:
            Relevance score between 0 and 1
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                query_texts=[query],
                article_texts=[article],
                return_dict=True
            )

            # Probability of the "relevant" class (label index 1).
            probabilities = torch.softmax(outputs.logits, dim=-1)
            relevance_score = probabilities[0, 1].item()

        return relevance_score

    def predict_relevance(self, query: str, article: List[str]) -> int:
        """
        Predict binary relevance for a single query-article pair.

        Args:
            query: Query string
            article: List of article sentences

        Returns:
            Binary prediction (0 = not relevant, 1 = relevant)
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                query_texts=[query],
                article_texts=[article],
                return_dict=True
            )

            prediction = torch.argmax(outputs.logits, dim=-1).item()

        return prediction

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
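

# Minimal usage sketch, guarded so it never runs on import. It assumes the
# defaults of ParaformerConfig are mutually consistent (in particular that
# `hidden_size` matches the embedding size of `base_model_name`) and that the
# sentence-transformers package is installed; the texts are made up.
if __name__ == "__main__":
    config = ParaformerConfig()
    model = ParaformerModel(config)

    query = "Is a written notice required to terminate the lease?"
    article = [
        "The lease may be terminated by either party.",
        "Termination requires written notice of thirty days.",
    ]

    score = model.get_relevance_score(query, article)
    label = model.predict_relevance(query, article)
    print(f"relevance score: {score:.3f}, predicted label: {label}")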