Siran-Li committed on
Commit
3a2194a
·
verified ·
1 Parent(s): 799ee64

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +137 -0
model.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ model.py — MATCHA contrastive model architecture.
3
+
4
+ ContrastiveModel wraps a pretrained language model backbone and adds a
5
+ SenseNetwork that decomposes word embeddings into multiple "sense" vectors,
6
+ followed by a learned transformation and mean-pooling to produce a single
7
+ sentence embedding for contrastive learning.
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from transformers.pytorch_utils import Conv1D
13
+ from transformers.activations import ACT2FN
14
+ from typing import Optional, Tuple
15
+
16
+
17
class ContrastiveModel(nn.Module):
    """Top-level model: backbone word embeddings -> SenseNetwork -> projection.

    The forward pass looks up token embeddings in the backbone's embedding
    layer, expands them into sense vectors with a SenseNetwork, multiplies
    the result by a learned square matrix and mean-pools over the sense and
    sequence axes to obtain one vector per input sequence.

    Args:
        contxtl_model: Pretrained HuggingFace model used only for its embedding layer.
        config: SimpleNamespace with model_type, n_embd, num_senses, etc.

    Raises:
        ValueError: If ``config.model_type`` is not one of the supported types.
    """

    # Backbones whose embedding layer is obtained via get_input_embeddings().
    _INPUT_EMBEDDING_MODEL_TYPES = ('gpt2', 'gpt_neo', 'roberta', 'xlm-roberta')
    # Backbones whose embedding layer is read from ``.model.embed_tokens``.
    _EMBED_TOKENS_MODEL_TYPES = ('mistral',)

    def __init__(self, contxtl_model, config):
        super().__init__()

        # Fail fast on an unknown backbone type. Without this check the
        # embedding layer would silently stay unset and the first forward
        # call would fail with an unrelated-looking AttributeError.
        model_type = config.model_type
        supported = self._INPUT_EMBEDDING_MODEL_TYPES + self._EMBED_TOKENS_MODEL_TYPES
        if model_type not in supported:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected one of {list(supported)}"
            )

        self.sense_network = SenseNetwork(config)
        self.contxtl_model = contxtl_model

        # Extract the word embedding layer from the backbone
        if model_type in self._INPUT_EMBEDDING_MODEL_TYPES:
            self.word_embeddings = self.contxtl_model.get_input_embeddings()
        else:
            self.word_embeddings = self.contxtl_model.model.embed_tokens

        # Learnable transformation applied to sense vectors before pooling
        self.transformation_matrix = nn.Parameter(torch.randn(config.n_embd, config.n_embd))

    def get_model_output(self, input_ids):
        """Compute multi-sense embeddings from token IDs.

        Args:
            input_ids: Tensor of token IDs, shape (bs, s).

        Returns:
            Tensor of sense vectors, shape (bs, nv, s, d).
        """
        sense_input_embeds = self.word_embeddings(input_ids)  # (bs, s, d)
        senses = self.sense_network(sense_input_embeds)  # (bs, nv, s, d)
        return senses

    def forward(self, input_ids):
        """Produce a single sentence embedding by mean-pooling transformed senses.

        Args:
            input_ids: Tensor of token IDs, shape (bs, s).

        Returns:
            embedding: Tensor of shape (bs, d)

        Raises:
            AssertionError: If ``input_ids`` has a floating-point or complex
                dtype and contains NaN values.
        """
        # Integer IDs can never be NaN, so only non-integer inputs are scanned.
        # This skips a full-tensor reduction on every ordinary call, and the
        # explicit raise keeps the check active under ``python -O``.
        if (input_ids.is_floating_point() or input_ids.is_complex()) and torch.isnan(input_ids).any():
            raise AssertionError("Input IDs contain NaN values")

        senses = self.get_model_output(input_ids)  # (bs, nv, s, d)
        transformed_senses = senses @ self.transformation_matrix  # (bs, nv, s, d)
        # Average over both the sense axis and the sequence axis.
        embedding = transformed_senses.mean(dim=(1, 2))  # (bs, d)
        return embedding
57
+
58
+
59
class MLP(nn.Module):
    """Feed-forward block: linear -> activation -> linear -> dropout.

    Uses HuggingFace's Conv1D (equivalent to a linear layer applied
    along the last dimension) for compatibility with GPT-2 style configs.

    Args:
        embed_dim: Size of the last dimension of the input.
        intermediate_dim: Size of the hidden layer between the two projections.
        out_dim: Size of the last dimension of the output.
        config: Object providing ``activation_function`` (a key of ACT2FN)
            and ``resid_pdrop`` (dropout probability applied to the output).
    """

    def __init__(self, embed_dim, intermediate_dim, out_dim, config):
        super().__init__()
        # Conv1D takes (output_features, input_features), in that order.
        self.c_fc = Conv1D(intermediate_dim, embed_dim)
        self.c_proj = Conv1D(out_dim, intermediate_dim)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transformation along the last dimension.

        Args:
            hidden_states: Tensor whose last dimension has size ``embed_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``out_dim``.
        """
        projected = self.c_proj(self.act(self.c_fc(hidden_states)))
        return self.dropout(projected)
79
+
80
+
81
class NoMixBlock(nn.Module):
    """Attention-free transformer-style block (no mixing across tokens).

    Two residual additions, each followed by a layer normalization; the only
    learned transformation in between is an MLP, so every token position is
    processed independently of the others.

    Args:
        config: Object providing ``n_embd``, ``layer_norm_epsilon`` and
            ``resid_pdrop`` (plus whatever the inner MLP reads from it).
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        eps = config.layer_norm_epsilon
        drop_p = config.resid_pdrop

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        # Hidden layer of the MLP is four times the embedding width.
        self.mlp = MLP(width, width * 4, width, config)
        self.resid_dropout1 = nn.Dropout(drop_p)
        self.resid_dropout2 = nn.Dropout(drop_p)

    def forward(self, hidden_states, residual):
        """Run the block.

        Args:
            hidden_states: Current hidden states.
            residual: Residual stream that ``hidden_states`` is added onto.

        Returns:
            The layer-normalized residual stream after the MLP sub-layer.
        """
        # First sub-layer: fold the incoming hidden states into the stream.
        stream = self.resid_dropout1(hidden_states) + residual
        normed = self.ln_1(stream)

        # Second sub-layer: per-token MLP added back onto the stream.
        stream = self.resid_dropout2(self.mlp(normed)) + stream
        return self.ln_2(stream)
103
+
104
+
105
class SenseNetwork(nn.Module):
    """Expand each token embedding into several sense vectors.

    A token vector of size d is passed through a NoMixBlock and then an MLP
    whose output has size num_senses * d; that output is split into
    num_senses vectors of size d.

    Input:  (bs, s, d)
    Output: (bs, num_senses, s, d)

    Args:
        config: Object providing ``num_senses``, ``n_embd``, ``embd_pdrop``,
            ``layer_norm_epsilon`` and ``sense_intermediate_scale`` (plus
            whatever NoMixBlock and MLP read from it).
        device: Accepted but not used by this constructor.
        dtype: Accepted but not used by this constructor.
    """

    def __init__(self, config, device=None, dtype=None):
        super().__init__()
        self.num_senses = config.num_senses
        self.n_embd = config.n_embd

        self.dropout = nn.Dropout(config.embd_pdrop)
        self.block = NoMixBlock(config)
        self.ln = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)
        # Final projection widens d -> num_senses * d through a hidden layer
        # of size sense_intermediate_scale * d.
        self.final_mlp = MLP(
            config.n_embd,
            config.sense_intermediate_scale * config.n_embd,
            config.n_embd * config.num_senses,
            config,
        )

    def forward(self, input_embeds):
        """Map token embeddings (bs, s, d) to sense vectors (bs, num_senses, s, d)."""
        dropped = self.dropout(input_embeds)
        normed = self.ln(dropped)
        features = self.block(normed, dropped)
        flat_senses = self.final_mlp(features)  # (bs, s, num_senses * d)

        batch, seq_len, _ = flat_senses.shape
        # Split the last axis into (num_senses, d), then move the sense axis
        # in front of the sequence axis.
        per_token = flat_senses.reshape(batch, seq_len, self.num_senses, self.n_embd)
        return per_token.transpose(1, 2)