| """ |
| model.py — MATCHA contrastive model architecture. |
| |
| ContrastiveModel wraps a pretrained language model backbone and adds a |
| SenseNetwork that decomposes word embeddings into multiple "sense" vectors, |
| followed by a learned transformation and mean-pooling to produce a single |
| sentence embedding for contrastive learning. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers.pytorch_utils import Conv1D |
| from transformers.activations import ACT2FN |
| from typing import Optional, Tuple |
|
|
|
|
class ContrastiveModel(nn.Module):
    """Top-level model: backbone word embeddings -> SenseNetwork -> projection.

    Args:
        contxtl_model: Pretrained HuggingFace model; within this class only
            its input-embedding layer is used.
        config: Namespace-like object providing ``model_type`` and ``n_embd``
            here, plus whatever ``SenseNetwork`` reads from it.

    Raises:
        ValueError: If ``config.model_type`` is not one of the supported
            backbones, so that a missing ``word_embeddings`` attribute is
            reported at construction time instead of on the first forward.
    """

    # Backbones whose embedding layer is reachable via get_input_embeddings().
    _GET_INPUT_EMBEDDINGS_TYPES = ('gpt2', 'gpt_neo', 'roberta', 'xlm-roberta')
    # Backbones whose embedding layer is read from `.model.embed_tokens`.
    _EMBED_TOKENS_TYPES = ('mistral',)

    def __init__(self, contxtl_model, config):
        super().__init__()

        # Validate before building any sub-network so a bad config fails fast.
        supported = self._GET_INPUT_EMBEDDINGS_TYPES + self._EMBED_TOKENS_TYPES
        if config.model_type not in supported:
            raise ValueError(
                f"Unsupported model_type {config.model_type!r}; "
                f"expected one of {supported}"
            )

        self.sense_network = SenseNetwork(config)
        self.contxtl_model = contxtl_model

        if config.model_type in self._GET_INPUT_EMBEDDINGS_TYPES:
            self.word_embeddings = self.contxtl_model.get_input_embeddings()
        else:
            self.word_embeddings = self.contxtl_model.model.embed_tokens

        # Learned (d, d) matrix applied to every sense vector before pooling.
        self.transformation_matrix = nn.Parameter(
            torch.randn(config.n_embd, config.n_embd)
        )

    def get_model_output(self, input_ids):
        """Compute multi-sense embeddings from token IDs.

        Looks up the backbone's word embeddings for ``input_ids`` and passes
        them through the sense network, returning its output unchanged.
        """
        sense_input_embeds = self.word_embeddings(input_ids)
        return self.sense_network(sense_input_embeds)

    def forward(self, input_ids):
        """Produce a single sentence embedding by mean-pooling transformed senses.

        The sense tensor is right-multiplied by ``transformation_matrix`` and
        then averaged over dimensions 1 and 2 (senses and sequence positions).
        NOTE(review): the mean runs over every sequence position; no attention
        mask is taken, so padded positions (if any) contribute to the average.

        Args:
            input_ids: Integer tensor of token IDs.

        Returns:
            embedding: Tensor of shape (bs, d)

        Raises:
            TypeError: If ``input_ids`` has a floating-point dtype.
        """
        # Integer tensors cannot hold NaN, so a NaN assertion on token IDs can
        # never fire for valid input and forces a tensor->bool conversion.
        # A dtype check catches the only case where NaN could appear.
        if input_ids.is_floating_point():
            raise TypeError(
                f"input_ids must be an integer tensor, got dtype {input_ids.dtype}"
            )

        senses = self.get_model_output(input_ids)
        transformed_senses = senses @ self.transformation_matrix
        return transformed_senses.mean(dim=(1, 2))
|
|
|
|
class MLP(nn.Module):
    """Feed-forward block: linear -> activation -> linear -> dropout.

    Uses HuggingFace's Conv1D (equivalent to a linear layer applied
    along the last dimension) for compatibility with GPT-2 style configs.

    Args:
        embed_dim: Size of the last dimension of the input.
        intermediate_dim: Size of the hidden layer between the two projections.
        out_dim: Size of the last dimension of the output.
        config: Object providing ``activation_function`` (a key of ``ACT2FN``)
            and ``resid_pdrop`` (dropout probability applied to the output).
    """

    def __init__(self, embed_dim, intermediate_dim, out_dim, config):
        super().__init__()
        # Conv1D takes (out_features, in_features), the reverse of nn.Linear.
        self.c_fc = Conv1D(intermediate_dim, embed_dim)
        self.c_proj = Conv1D(out_dim, intermediate_dim)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block along the last dimension.

        Args:
            hidden_states: Input tensor; it is passed directly to the first
                projection, so it must be a tensor (not an optional tuple).

        Returns:
            The projected tensor after activation, second projection and dropout.
        """
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
|
|
|
|
class NoMixBlock(nn.Module):
    """Attention-free transformer-style block (no mixing across tokens).

    Two residual additions, each followed by a LayerNorm; the only learned
    transformation between them is an MLP with a 4x-wide hidden layer.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        eps = config.layer_norm_epsilon
        drop_p = config.resid_pdrop

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = MLP(
            embed_dim=width,
            intermediate_dim=width * 4,
            out_dim=width,
            config=config,
        )
        self.resid_dropout1 = nn.Dropout(drop_p)
        self.resid_dropout2 = nn.Dropout(drop_p)

    def forward(self, hidden_states, residual):
        """Run both residual sub-layers and return the final normalized state.

        Args:
            hidden_states: Current hidden states; dropout is applied to them
                before they are added to ``residual``.
            residual: Running residual stream to add onto.

        Returns:
            The output of ``ln_2`` applied to the updated residual stream.
        """
        # First sub-layer: fold the incoming states into the residual stream.
        stream = self.resid_dropout1(hidden_states) + residual

        # Second sub-layer: MLP on the normalized stream, added back in.
        mlp_out = self.mlp(self.ln_1(stream))
        stream = self.resid_dropout2(mlp_out) + stream

        return self.ln_2(stream)
|
|
|
|
class SenseNetwork(nn.Module):
    """Expand each token embedding into ``num_senses`` sense vectors.

    Pipeline: dropout -> LayerNorm -> NoMixBlock -> MLP whose output width is
    ``num_senses * n_embd``; the result is reshaped so the sense axis comes
    right after the batch axis.

    Input:  (bs, s, d)
    Output: (bs, num_senses, s, d)
    """

    def __init__(self, config, device=None, dtype=None):
        # `device` and `dtype` are accepted but not used by this constructor.
        super().__init__()
        embed_width = config.n_embd
        self.num_senses = config.num_senses
        self.n_embd = embed_width

        self.dropout = nn.Dropout(config.embd_pdrop)
        self.block = NoMixBlock(config)
        self.ln = nn.LayerNorm(embed_width, eps=config.layer_norm_epsilon)
        self.final_mlp = MLP(
            embed_dim=embed_width,
            intermediate_dim=config.sense_intermediate_scale * config.n_embd,
            out_dim=config.n_embd * config.num_senses,
            config=config,
        )

    def forward(self, input_embeds):
        """Map token embeddings (bs, s, d) to senses (bs, num_senses, s, d)."""
        skip = self.dropout(input_embeds)
        normed = self.ln(skip)
        mixed = self.block(normed, skip)
        flat_senses = self.final_mlp(mixed)

        # Split the widened last axis into (num_senses, d), then move the
        # sense axis in front of the sequence axis.
        batch, seq_len, _ = flat_senses.shape
        per_token = flat_senses.reshape(batch, seq_len, self.num_senses, self.n_embd)
        return per_token.transpose(1, 2)
|
|