davda54 committed on
Commit ae26905 · verified · 1 Parent(s): 80119d2

Simplified the model by always computing batch-first

Files changed (1)
1. modeling_norbert.py +40 -44
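Concretely, "batch-first" means every hidden-state tensor keeps the layout [batch, sequence, hidden] instead of [sequence, batch, hidden], so the transposes at the model boundary disappear. A minimal shape-only sketch of the two conventions (toy sizes and illustrative names, not the model's API):

import torch

batch_size, seq_len, hidden_size = 2, 5, 8
hidden_states = torch.randn(batch_size, seq_len, hidden_size)  # batch-first: [B, T, D]

# A sequence-first implementation has to transpose on the way in and out...
seq_first = hidden_states.transpose(0, 1)                      # [T, B, D]
restored = seq_first.transpose(0, 1)                           # [B, T, D] again

# ...while batch-first code consumes the input as-is, so both transposes disappear.
assert torch.equal(restored, hidden_states)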
modeling_norbert.py CHANGED
@@ -139,10 +139,10 @@ class Attention(nn.Module):
         return bucket_pos
 
     def forward(self, hidden_states, attention_mask, relative_embedding):
-        key_len, batch_size, _ = hidden_states.size()
+        batch_size, key_len, _ = hidden_states.size()
         query_len = key_len
 
-        # Ensure position indices are large enough
+        # Recompute position_indices if the sequence length exceeds the precomputed size
         if self.position_indices.size(0) < query_len:
             position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
                 - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
@@ -150,55 +150,52 @@ class Attention(nn.Module):
             position_indices = self.config.position_bucket_size - 1 + position_indices
             self.position_indices = position_indices.to(hidden_states.device)
 
-        hidden_states = self.pre_layer_norm(hidden_states)
-
-        # QKV linear projections
-        query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2)  # shape: [T, B, D]
-        value = self.in_proj_v(hidden_states)  # shape: [T, B, D]
-
-        # Reshape and transpose for attention computation
-        query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
-        key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
-        value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
-
-        # Content-based attention
-        attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
-
-        # Positional embeddings
-        pos = self.in_proj_qk(self.dropout(relative_embedding))  # [2T-1, 2D]
-        query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)  # [2T-1, H, D]
-
-        # Reshape query and key for positional attention
-        query = query.view(batch_size, self.num_heads, query_len, self.head_size)  # [B, H, Q, D]
-        key = key.view(batch_size, self.num_heads, key_len, self.head_size)  # [B, H, K, D]
-
-        # Get relative position indices
-        rel_pos = self.position_indices[:query_len, :key_len]  # [Q, K]
-
-        # Select positional embeddings based on relative positions
-        key_pos_rel = key_pos[rel_pos]  # [Q, K, H, D]
-        query_pos_rel = query_pos[rel_pos]  # [Q, K, H, D]
-
-        # Compute disentangled attention scores
-        attention_c_p = torch.einsum("bhqd,qkhd->bhqk", query, key_pos_rel * self.scale)  # [B, H, Q, K]
-        attention_p_c = torch.einsum("bhkd,qkhd->bhqk", key * self.scale, query_pos_rel)  # [B, H, Q, K]
-
-        # Combine attention scores
-        attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)  # [B, H, Q, K]
-        attention_scores.add_(attention_c_p)
-        attention_scores.add_(attention_p_c)
-
-        attention_scores = attention_scores.masked_fill(attention_mask, float('-inf'))
-        attention_probs = F.softmax(attention_scores, dim=-1)
-
-        attention_probs = self.dropout(attention_probs)
-        context = torch.bmm(attention_probs.flatten(0, 1), value)  # shape: [B*H, Q, D]
-        context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size)  # shape: [Q, B, H*D]
-        context = self.out_proj(context)
-        context = self.post_layer_norm(context)
-        context = self.dropout(context)
-
-        return context, attention_probs.detach()
+        # Pre-LN and project query/key/value
+        hidden_states = self.pre_layer_norm(hidden_states)  # shape: [B, T, D]
+        query, key = self.in_proj_qk(hidden_states).chunk(2, dim=-1)  # shape: [B, T, D]
+        value = self.in_proj_v(hidden_states)  # shape: [B, T, D]
+
+        # Reshape to the multi-head layout (the key is pre-transposed for the matmul)
+        query = query.reshape(batch_size, query_len, self.num_heads, self.head_size).transpose(1, 2)  # shape: [B, num_heads, T_q, head_size]
+        key = key.reshape(batch_size, key_len, self.num_heads, self.head_size).permute(0, 2, 3, 1)  # shape: [B, num_heads, head_size, T_k]
+        value = value.view(batch_size, key_len, self.num_heads, self.head_size).transpose(1, 2)  # shape: [B, num_heads, T_k, head_size]
+
+        # Compute relative positional contributions
+        pos = self.in_proj_qk(self.dropout(relative_embedding))  # shape: [2*position_bucket_size - 1, 2D]
+        query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)  # shape: [2*position_bucket_size - 1, num_heads, head_size]
+        query_pos = query_pos.permute(1, 0, 2)  # shape: [num_heads, 2*position_bucket_size - 1, head_size]
+        key_pos = key_pos.permute(1, 2, 0)  # shape: [num_heads, head_size, 2*position_bucket_size - 1]
+
+        # Scale the keys
+        key = key * self.scale
+        key_pos = key_pos * self.scale
+
+        # Compute standard content-to-content attention scores
+        attention_c_to_c = torch.matmul(query, key)  # shape: [B, num_heads, T_q, T_k]
+
+        # Compute content-to-position and position-to-content attention scores
+        position_indices = self.position_indices[:query_len, :key_len].unsqueeze(0).unsqueeze(0).expand(batch_size, self.num_heads, query_len, key_len)
+        attention_c_to_p = torch.matmul(query, key_pos.unsqueeze(0))  # shape: [B, num_heads, T_q, 2*position_bucket_size - 1]
+        attention_p_to_c = torch.matmul(query_pos.unsqueeze(0), key)  # shape: [B, num_heads, 2*position_bucket_size - 1, T_k]
+        attention_c_to_p = attention_c_to_p.gather(3, position_indices)  # shape: [B, num_heads, T_q, T_k]
+        attention_p_to_c = attention_p_to_c.gather(2, position_indices)  # shape: [B, num_heads, T_q, T_k]
+
+        # Full attention score
+        attention_scores = attention_c_to_c + attention_c_to_p + attention_p_to_c  # shape: [B, num_heads, T_q, T_k]
+
+        # Masked softmax
+        attention_scores = attention_scores.masked_fill(attention_mask, float('-inf'))  # shape: [B, num_heads, T_q, T_k]
+        attention_probs = F.softmax(attention_scores, dim=-1)  # shape: [B, num_heads, T_q, T_k]
+
+        # Collect the weighted-averaged values
+        attention_probs = self.dropout(attention_probs)  # shape: [B, num_heads, T_q, T_k]
+        output = torch.matmul(attention_probs, value)  # shape: [B, num_heads, T_q, head_size]
+        output = output.transpose(1, 2).flatten(2, 3)  # shape: [B, T_q, D]
+        output = self.out_proj(output)
+        output = self.post_layer_norm(output)
+        output = self.dropout(output)
+
+        return output, attention_probs.detach()
 
 
 class Embedding(nn.Module):
@@ -281,9 +278,8 @@ class NorbertModel(NorbertPreTrainedModel):
         attention_mask = ~attention_mask.bool()
         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
 
-        static_embeddings, relative_embedding = self.embedding(input_ids.t())
+        static_embeddings, relative_embedding = self.embedding(input_ids)
         contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
-        contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
         last_layer = contextualized_embeddings[-1]
         contextualized_embeddings = [contextualized_embeddings[0]] + [
             contextualized_embeddings[i] - contextualized_embeddings[i - 1]
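The rewritten attention keeps the DeBERTa-style disentangled scores but computes the content-to-position and position-to-content terms once against all 2*position_bucket_size - 1 buckets and then gathers the score belonging to each (query, key) pair. A self-contained shape check under assumed toy sizes (names are illustrative, not the module's attributes):

import torch

B, H, T, d, P = 2, 4, 6, 16, 32            # batch, heads, sequence length, head size, position_bucket_size
query = torch.randn(B, H, T, d)            # [B, num_heads, T_q, head_size]
key = torch.randn(B, H, d, T)              # [B, num_heads, head_size, T_k], pre-transposed for matmul
query_pos = torch.randn(H, 2 * P - 1, d)   # positional queries, one row per relative-position bucket
key_pos = torch.randn(H, d, 2 * P - 1)     # positional keys, pre-transposed

# Bucketed relative position of every (query, key) pair, values in [0, 2P - 2]
relative_position = torch.randint(0, 2 * P - 1, (T, T))
index = relative_position.unsqueeze(0).unsqueeze(0).expand(B, H, T, T)

content_to_content = torch.matmul(query, key)                    # [B, H, T_q, T_k]
content_to_position = torch.matmul(query, key_pos.unsqueeze(0))  # [B, H, T_q, 2P - 1]
position_to_content = torch.matmul(query_pos.unsqueeze(0), key)  # [B, H, 2P - 1, T_k]

# For every (q, k) pair, keep only the score of its own relative-position bucket
content_to_position = content_to_position.gather(3, index)       # [B, H, T_q, T_k]
position_to_content = position_to_content.gather(2, index)       # [B, H, T_q, T_k]

scores = content_to_content + content_to_position + position_to_content
assert scores.shape == (B, H, T, T)

One observable difference from the previous einsum formulation: the old code materialized key_pos[rel_pos] and query_pos[rel_pos] tensors of shape [T_q, T_k, num_heads, head_size], whereas the gather variant only builds score matrices of shape [..., T_q, 2*position_bucket_size - 1] and [..., 2*position_bucket_size - 1, T_k] before gathering.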