import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional

from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel


# ---------------- CONFIG ---------------- #

class MMNLIConfig(PretrainedConfig):
    model_type = "mmnli"

    def __init__(
        self,
        embedding_dim: int = 1024,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        activation: str = "TANH",
        norm_emb: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.hidden_dims = hidden_dims if hidden_dims is not None else [3072, 1536]
        self.dropout = dropout
        self.activation = activation
        self.norm_emb = norm_emb
        self.output_dim = 3  # entailment, contradiction, neutral


# ---------------- CORE MODEL ---------------- #

ACTIVATIONS = {"TANH": nn.Tanh, "RELU": nn.ReLU}


class MMNLICore(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        hidden_dims: List[int],
        dropout: float,
        activation: str,
        norm_emb: bool,
        output_dim: int = 3,
    ):
        super().__init__()
        self.norm_emb = norm_emb
        if activation not in ACTIVATIONS:
            raise ValueError(f"Unrecognized activation: {activation}")

        # Input: concatenation of [p, h, p*h, |p-h|] => 4 * embedding_dim
        input_dim = embedding_dim * 4

        modules: List[nn.Module] = []
        if dropout > 0:
            modules.append(nn.Dropout(p=dropout))
        nprev = input_dim
        for h in hidden_dims:
            modules.append(nn.Linear(nprev, h))
            modules.append(ACTIVATIONS[activation]())
            if dropout > 0:
                modules.append(nn.Dropout(p=dropout))
            nprev = h

        # Final classifier layer: 3-way softmax. Note the head outputs class
        # probabilities, not logits, so pair it with a loss that expects
        # probabilities rather than nn.CrossEntropyLoss.
        modules.append(nn.Linear(nprev, output_dim))
        modules.append(nn.Softmax(dim=-1))
        self.mlp = nn.Sequential(*modules)

    def _norm(self, emb: Optional[Tensor]) -> Optional[Tensor]:
        # L2-normalize along the embedding dimension when norm_emb is set.
        return F.normalize(emb, dim=-1) if (emb is not None and self.norm_emb) else emb

    def featurize(self, premise: Tensor, hypothesis: Tensor) -> Tensor:
        # InferSent-style interaction features: both embeddings, their
        # element-wise product, and their absolute difference.
        return torch.cat(
            [premise, hypothesis, premise * hypothesis, torch.abs(premise - hypothesis)],
            dim=-1,
        )


# ---------------- HF MODEL WRAPPER ---------------- #

class MMNLIModel(PreTrainedModel):
    config_class = MMNLIConfig

    def __init__(self, config: MMNLIConfig):
        super().__init__(config)
        self.core = MMNLICore(
            embedding_dim=config.embedding_dim,
            hidden_dims=config.hidden_dims,
            dropout=config.dropout,
            activation=config.activation,
            norm_emb=config.norm_emb,
            output_dim=config.output_dim,
        )

    def forward(self, premise: Tensor, hypothesis: Tensor) -> Tensor:
        # Optionally L2-normalize the input sentence embeddings, build the
        # [p, h, p*h, |p-h|] feature vector, and classify with the MLP head.
        premise = self.core._norm(premise)
        hypothesis = self.core._norm(hypothesis)
        proc = self.core.featurize(premise, hypothesis)
        return self.core.mlp(proc)
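

# ---------------- USAGE SKETCH ---------------- #
# A minimal smoke test, assuming premise/hypothesis are precomputed sentence
# embeddings of shape (batch_size, embedding_dim) from an external encoder;
# the encoder is not part of this module, and the random inputs below are
# illustrative only, not from the original source.

if __name__ == "__main__":
    config = MMNLIConfig(embedding_dim=1024, hidden_dims=[3072, 1536])
    model = MMNLIModel(config).eval()

    batch_size = 4
    premise = torch.randn(batch_size, config.embedding_dim)
    hypothesis = torch.randn(batch_size, config.embedding_dim)

    with torch.no_grad():
        probs = model(premise, hypothesis)  # (batch_size, 3)

    print(probs.shape)        # torch.Size([4, 3])
    print(probs.sum(dim=-1))  # ~1.0 per row, since the head ends in Softmax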