imbue2025 committed on
Commit
d3d5ea7
·
verified ·
1 Parent(s): a12b533

Create model_arch.py

Browse files
Files changed (1) hide show
  1. model_arch.py +303 -0
model_arch.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple, List, Dict
6
+ import math
7
+
8
+
9
+ @dataclass
10
+ class EmbeddingConfig:
11
+ """Configuration for embedding models"""
12
+ vocab_size: int
13
+ hidden_size: int = 384
14
+ n_layer: int = 6
15
+ n_head: int = 6
16
+ n_kv_head: int = 2
17
+ intermediate_size: int = 1024
18
+ max_seq_len: int = 512
19
+ dropout: float = 0.1
20
+ rms_norm_eps: float = 1e-6
21
+ use_cache: bool = False
22
+ # Embedding-specific
23
+ embedding_dim: int = 384 # Output embedding dimension
24
+ pooling_method: str = "mean" # "mean", "cls", "attention"
25
+ normalize_embeddings: bool = True
26
+ use_temperature_scaling: bool = True
27
+ temperature: float = 0.05
28
+
29
+
30
+ class RoPE(nn.Module):
31
+ """Rotary Position Embedding"""
32
+ def __init__(self, dim: int, base: int = 10000):
33
+ super().__init__()
34
+ self.dim = dim
35
+ self.base = base
36
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
37
+ self.register_buffer("inv_freq", inv_freq)
38
+
39
+ def _rotate_half(self, vector):
40
+ vector1 = vector[..., :vector.shape[-1] // 2]
41
+ vector2 = vector[..., vector.shape[-1] // 2:]
42
+ return torch.cat((-vector2, vector1), dim=-1)
43
+
44
+ def forward(self, q, k, position_ids):
45
+ inv_freq = self.inv_freq.to(dtype=q.dtype)
46
+ freqs = (position_ids.unsqueeze(-1) * inv_freq.unsqueeze(0)).to(q.device)
47
+ emb = torch.cat([freqs, freqs], dim=-1)
48
+ cos = emb.cos().unsqueeze(1).to(q.dtype)
49
+ sin = emb.sin().unsqueeze(1).to(q.dtype)
50
+ q_rot = (q * cos) + (self._rotate_half(q) * sin)
51
+ k_rot = (k * cos) + (self._rotate_half(k) * sin)
52
+ return q_rot, k_rot
53
+
54
+
55
+ class RMSNorm(nn.Module):
56
+ """Root Mean Square Layer Normalization"""
57
+ def __init__(self, dim: int, eps: float = 1e-6):
58
+ super().__init__()
59
+ self.eps = eps
60
+ self.weight = nn.Parameter(torch.ones(dim))
61
+
62
+ def forward(self, x):
63
+ var = torch.mean(x ** 2, dim=-1, keepdim=True)
64
+ x_normed = x * torch.rsqrt(var + self.eps)
65
+ return self.weight * x_normed
66
+
67
+
68
+ class SwiGLU(nn.Module):
69
+ """Gated Linear Unit with SiLU activation"""
70
+ def __init__(self, dim, hidden_dim):
71
+ super().__init__()
72
+ self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
73
+ self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
74
+
75
+ def forward(self, x):
76
+ fused_output = self.gate_up_proj(x)
77
+ gate_output, up_output = fused_output.chunk(2, dim=-1)
78
+ return self.down_proj(F.silu(gate_output) * up_output)
79
+
80
+
81
+ class GroupedQueryAttention(nn.Module):
82
+ """Grouped Query Attention mechanism"""
83
+ def __init__(self, config: EmbeddingConfig):
84
+ super().__init__()
85
+ self.num_heads = config.n_head
86
+ self.num_kv_heads = config.n_kv_head
87
+ # Validate head configuration
88
+ if self.num_heads <= 0:
89
+ raise ValueError(f"n_head must be > 0, got {self.num_heads}")
90
+ if self.num_kv_heads <= 0:
91
+ raise ValueError(f"n_kv_head must be > 0, got {self.num_kv_heads}")
92
+ if self.num_heads % self.num_kv_heads != 0:
93
+ raise ValueError(
94
+ f"n_head ({self.num_heads}) must be divisible by n_kv_head ({self.num_kv_heads})"
95
+ )
96
+ if config.hidden_size % self.num_heads != 0:
97
+ raise ValueError(
98
+ f"hidden_size ({config.hidden_size}) must be divisible by n_head ({self.num_heads}). "
99
+ "Choose hidden_size that is multiple of n_head or set n_head accordingly."
100
+ )
101
+
102
+ self.num_head_groups = self.num_heads // self.num_kv_heads
103
+ self.head_dim = config.hidden_size // self.num_heads
104
+ self.dropout = config.dropout
105
+
106
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
107
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
108
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
109
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
110
+
111
+ self.rotary_emb = RoPE(self.head_dim)
112
+ self.q_norm = RMSNorm(self.head_dim)
113
+ self.k_norm = RMSNorm(self.head_dim)
114
+
115
+ def forward(self, x, position_ids, attention_mask=None):
116
+ B, L, D = x.shape
117
+ q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim)
118
+ k = self.k_proj(x).reshape(B, L, self.num_kv_heads, self.head_dim)
119
+ v = self.v_proj(x).reshape(B, L, self.num_kv_heads, self.head_dim)
120
+
121
+ q, k = self.q_norm(q), self.k_norm(k)
122
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
123
+
124
+ q, k = self.rotary_emb(q, k, position_ids)
125
+
126
+ # Expand KV heads for grouped attention
127
+ if self.num_head_groups > 1:
128
+ k = k.unsqueeze(2).expand(-1, -1, self.num_head_groups, -1, -1).reshape(B, self.num_heads, -1, self.head_dim)
129
+ v = v.unsqueeze(2).expand(-1, -1, self.num_head_groups, -1, -1).reshape(B, self.num_heads, -1, self.head_dim)
130
+
131
+ # Compute attention scores
132
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
133
+
134
+ if attention_mask is not None:
135
+ scores = scores.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
136
+
137
+ attn_weights = F.softmax(scores, dim=-1)
138
+ attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
139
+
140
+ out = torch.matmul(attn_weights, v)
141
+ out = out.transpose(1, 2).contiguous().reshape(B, L, D)
142
+ out = self.o_proj(out)
143
+
144
+ return out
145
+
146
+
147
+ class EmbeddingTransformerLayer(nn.Module):
148
+ """Single Transformer layer for embedding model"""
149
+ def __init__(self, config: EmbeddingConfig):
150
+ super().__init__()
151
+ self.attention = GroupedQueryAttention(config)
152
+ self.mlp = SwiGLU(config.hidden_size, config.intermediate_size)
153
+ self.norm1 = RMSNorm(config.hidden_size, config.rms_norm_eps)
154
+ self.norm2 = RMSNorm(config.hidden_size, config.rms_norm_eps)
155
+ self.dropout = nn.Dropout(config.dropout)
156
+
157
+ def forward(self, x, position_ids, attention_mask=None):
158
+ # Pre-norm architecture
159
+ normed_x = self.norm1(x)
160
+ attn_out = self.attention(normed_x, position_ids, attention_mask)
161
+ x = x + self.dropout(attn_out)
162
+
163
+ normed_x = self.norm2(x)
164
+ mlp_out = self.mlp(normed_x)
165
+ x = x + self.dropout(mlp_out)
166
+
167
+ return x
168
+
169
+
170
+ class EmbeddingEncoder(nn.Module):
171
+ """Transformer-based encoder for generating embeddings"""
172
+ def __init__(self, config: EmbeddingConfig):
173
+ super().__init__()
174
+ self.config = config
175
+ self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
176
+ self.position_ids = torch.arange(config.max_seq_len, dtype=torch.long).unsqueeze(0)
177
+
178
+ self.layers = nn.ModuleList([
179
+ EmbeddingTransformerLayer(config) for _ in range(config.n_layer)
180
+ ])
181
+
182
+ self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
183
+ self.embedding_proj = None
184
+ if config.embedding_dim != config.hidden_size:
185
+ self.embedding_proj = nn.Linear(config.hidden_size, config.embedding_dim, bias=False)
186
+
187
+ def forward(self, input_ids, attention_mask=None):
188
+ """
189
+ Args:
190
+ input_ids: (batch_size, seq_len)
191
+ attention_mask: (batch_size, seq_len) - 1 for valid tokens, 0 for padding
192
+
193
+ Returns:
194
+ embeddings: (batch_size, embedding_dim)
195
+ hidden_states: (batch_size, seq_len, hidden_size)
196
+ """
197
+ B, L = input_ids.shape
198
+ device = input_ids.device
199
+
200
+ # Token embedding
201
+ # Sanity check: ensure token ids are within embedding range to avoid CUDA OOB
202
+ if input_ids.numel() > 0:
203
+ max_id = int(input_ids.max().item())
204
+ min_id = int(input_ids.min().item())
205
+ vocab_size = self.token_embedding.num_embeddings
206
+ if min_id < 0 or max_id >= vocab_size:
207
+ raise ValueError(
208
+ f"Input token id out of range: found ids in [{min_id}, {max_id}] but "
209
+ f"embedding vocab_size={vocab_size}. Ensure tokenizer and model vocab sizes match."
210
+ )
211
+ x = self.token_embedding(input_ids)
212
+
213
+ # Position IDs - ensure buffer is long enough
214
+ if L > self.position_ids.size(1):
215
+ # Extend position IDs if needed
216
+ new_position_ids = torch.arange(L, dtype=torch.long, device=device).unsqueeze(0)
217
+ position_ids = new_position_ids
218
+ else:
219
+ position_ids = self.position_ids[:, :L].to(device)
220
+
221
+ # Create attention mask if not provided
222
+ if attention_mask is None:
223
+ attention_mask = torch.ones_like(input_ids)
224
+
225
+ # Transformer layers
226
+ for layer in self.layers:
227
+ x = layer(x, position_ids, attention_mask)
228
+
229
+ # Final normalization
230
+ hidden_states = self.norm(x)
231
+
232
+ # Pooling
233
+ if self.config.pooling_method == "mean":
234
+ # Mean pooling with masking
235
+ mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
236
+ sum_embeddings = (hidden_states * mask_expanded).sum(1)
237
+ sum_mask = mask_expanded.sum(1).clamp(min=1e-9)
238
+ embeddings = sum_embeddings / sum_mask
239
+ elif self.config.pooling_method == "cls":
240
+ # Use CLS token (first token)
241
+ embeddings = hidden_states[:, 0, :]
242
+ elif self.config.pooling_method == "attention":
243
+ # Attention-weighted pooling
244
+ attn_weights = F.softmax(
245
+ torch.ones(1, L, device=device) * attention_mask.float().unsqueeze(0),
246
+ dim=-1
247
+ )
248
+ embeddings = torch.matmul(attn_weights, hidden_states).squeeze(1)
249
+ else:
250
+ raise ValueError(f"Unknown pooling method: {self.config.pooling_method}")
251
+
252
+ # Projection to embedding dimension
253
+ if self.embedding_proj is not None:
254
+ embeddings = self.embedding_proj(embeddings)
255
+
256
+ # Normalize embeddings
257
+ if self.config.normalize_embeddings:
258
+ embeddings = F.normalize(embeddings, p=2, dim=1)
259
+
260
+ return embeddings, hidden_states
261
+
262
+
263
+ class DualEmbeddingModel(nn.Module):
264
+ """Dual-encoder architecture for symmetric similarity learning"""
265
+ def __init__(self, config: EmbeddingConfig):
266
+ super().__init__()
267
+ self.config = config
268
+ self.encoder = EmbeddingEncoder(config)
269
+ if config.use_temperature_scaling:
270
+ self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / config.temperature))
271
+ else:
272
+ self.logit_scale = None
273
+
274
+ def forward(self, input_ids_1, input_ids_2=None, attention_mask_1=None, attention_mask_2=None):
275
+ """
276
+ Args:
277
+ input_ids_1: (batch_size, seq_len)
278
+ input_ids_2: (batch_size, seq_len) - if None, returns only embeddings for input_ids_1
279
+ attention_mask_1: (batch_size, seq_len)
280
+ attention_mask_2: (batch_size, seq_len)
281
+
282
+ Returns:
283
+ embeddings_1: (batch_size, embedding_dim)
284
+ embeddings_2: (batch_size, embedding_dim) or None
285
+ """
286
+ embeddings_1, _ = self.encoder(input_ids_1, attention_mask_1)
287
+
288
+ if input_ids_2 is not None:
289
+ embeddings_2, _ = self.encoder(input_ids_2, attention_mask_2)
290
+ return embeddings_1, embeddings_2
291
+
292
+ return embeddings_1, None
293
+
294
+ def compute_similarity(self, embeddings_1, embeddings_2):
295
+ """Compute cosine similarity between embeddings"""
296
+ # embeddings should already be normalized if normalize_embeddings=True
297
+ similarity = torch.matmul(embeddings_1, embeddings_2.t())
298
+
299
+ if self.logit_scale is not None:
300
+ logit_scale = self.logit_scale.exp()
301
+ similarity = similarity * logit_scale
302
+
303
+ return similarity