Commit: Simplified the model by always computing batch-first
File changed: modeling_norbert.py (+40 lines, -44 lines)
modeling_norbert.py
CHANGED
|
@@ -139,10 +139,10 @@ class Attention(nn.Module):
|
|
| 139 |
return bucket_pos
|
| 140 |
|
| 141 |
def forward(self, hidden_states, attention_mask, relative_embedding):
|
| 142 |
-
|
| 143 |
query_len = key_len
|
| 144 |
|
| 145 |
-
#
|
| 146 |
if self.position_indices.size(0) < query_len:
|
| 147 |
position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
|
| 148 |
- torch.arange(query_len, dtype=torch.long).unsqueeze(0)
|
|
@@ -150,55 +150,52 @@ class Attention(nn.Module):
|
|
| 150 |
position_indices = self.config.position_bucket_size - 1 + position_indices
|
| 151 |
self.position_indices = position_indices.to(hidden_states.device)
|
| 152 |
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
-
#
|
| 156 |
-
query,
|
| 157 |
-
|
|
|
|
| 158 |
|
| 159 |
-
#
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
| 163 |
|
| 164 |
-
#
|
| 165 |
-
|
|
|
|
| 166 |
|
| 167 |
-
#
|
| 168 |
-
|
| 169 |
-
query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2) # [2T-1, H, D]
|
| 170 |
|
| 171 |
-
#
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
-
#
|
| 176 |
-
|
| 177 |
|
| 178 |
-
#
|
| 179 |
-
|
| 180 |
-
|
| 181 |
|
| 182 |
-
#
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
attention_scores.add_(attention_p_c)
|
| 190 |
-
|
| 191 |
-
attention_scores = attention_scores.masked_fill(attention_mask, float('-inf'))
|
| 192 |
-
attention_probs = F.softmax(attention_scores, dim=-1)
|
| 193 |
-
|
| 194 |
-
attention_probs = self.dropout(attention_probs)
|
| 195 |
-
context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D]
|
| 196 |
-
context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size) # shape: [Q, B, H*D]
|
| 197 |
-
context = self.out_proj(context)
|
| 198 |
-
context = self.post_layer_norm(context)
|
| 199 |
-
context = self.dropout(context)
|
| 200 |
|
| 201 |
-
return
|
| 202 |
|
| 203 |
|
| 204 |
class Embedding(nn.Module):
|
|
@@ -281,9 +278,8 @@ class NorbertModel(NorbertPreTrainedModel):
|
|
| 281 |
attention_mask = ~attention_mask.bool()
|
| 282 |
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 283 |
|
| 284 |
-
static_embeddings, relative_embedding = self.embedding(input_ids
|
| 285 |
contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
|
| 286 |
-
contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
|
| 287 |
last_layer = contextualized_embeddings[-1]
|
| 288 |
contextualized_embeddings = [contextualized_embeddings[0]] + [
|
| 289 |
contextualized_embeddings[i] - contextualized_embeddings[i - 1]
|
|
|
|
| 139 |
return bucket_pos
|
| 140 |
|
| 141 |
def forward(self, hidden_states, attention_mask, relative_embedding):
|
| 142 |
+
batch_size, key_len, _ = hidden_states.size()
|
| 143 |
query_len = key_len
|
| 144 |
|
| 145 |
+
# Recompute position_indices if sequence length exceeds the precomputed size
|
| 146 |
if self.position_indices.size(0) < query_len:
|
| 147 |
position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
|
| 148 |
- torch.arange(query_len, dtype=torch.long).unsqueeze(0)
|
|
|
|
| 150 |
position_indices = self.config.position_bucket_size - 1 + position_indices
|
| 151 |
self.position_indices = position_indices.to(hidden_states.device)
|
| 152 |
|
| 153 |
+
# Pre-LN and project query/key/value.
|
| 154 |
+
hidden_states = self.pre_layer_norm(hidden_states) # shape: [B, T, D]
|
| 155 |
+
query, key = self.in_proj_qk(hidden_states).chunk(2, dim=-1) # shape: [B, T, D]
|
| 156 |
+
value = self.in_proj_v(hidden_states) # shape: [B, T, D]
|
| 157 |
|
| 158 |
+
# Reshape to [B, num_heads, T, head_size]
|
| 159 |
+
query = query.reshape(B, T, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_q, head_size]
|
| 160 |
+
key = key.reshape(B, T, self.num_heads, self.head_size).permute(0, 2, 3, 1) # shape: [B, num_heads, head_size, T_k]
|
| 161 |
+
value = value.view(B, T, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_k, head_size]
|
| 162 |
|
| 163 |
+
# Compute relative positional contributions
|
| 164 |
+
pos = self.in_proj_qk(self.dropout(relative_embedding)) # shape: [2*position_bucket_size - 1, 2D]
|
| 165 |
+
query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2) # shape: [2*position_bucket_size - 1, num_heads, head_size]
|
| 166 |
+
query_pos = query_pos.permute(1, 0, 2) # shape: [num_heads, 2*position_bucket_size - 1, head_size]
|
| 167 |
+
key_pos = key_pos.permute(1, 0, 2) # shape: [num_heads, 2*position_bucket_size - 1, head_size]
|
| 168 |
|
| 169 |
+
# Scale the keys
|
| 170 |
+
key = key * self.scale
|
| 171 |
+
key_pos = key_pos * self.scale
|
| 172 |
|
| 173 |
+
# Compute standard content-to-content attention scores
|
| 174 |
+
attention_c_to_c = torch.matmul(query, key) # shape: [B, num_heads, T_q, T_k]
|
|
|
|
| 175 |
|
| 176 |
+
# Compute content-to-position and position-to-content attention scores
|
| 177 |
+
position_indices = self.position_indices[:query_len, :key_len].unsqueeze(0).unsqueeze(0).expand(batch_size, self.num_heads, query_len, key_len)
|
| 178 |
+
attention_c_to_p = torch.matmul(query, key_pos.unsqueeze(0)) # [B, num_heads, T-q, 2*position_bucket_size - 1]
|
| 179 |
+
attention_p_to_c = torch.matmul(query_pos.unsqueeze(0), key) # [B, num_heads, 2*position_bucket_size - 1, T_k]
|
| 180 |
+
attention_c_to_p = attention_c_to_p.gather(3, position_indices) # shape: [B, num_heads, T_q, T_k]
|
| 181 |
+
attention_p_to_c = attention_p_to_c.gather(2, position_indices) # shape: [B, num_heads, T_q, T_k]
|
| 182 |
|
| 183 |
+
# Full attention score
|
| 184 |
+
attention_scores = attention_c_to_c + attention_c_to_p + attention_p_to_c # shape: [B, num_heads, T_q, T_k]
|
| 185 |
|
| 186 |
+
# Masked softmax
|
| 187 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float('-inf')) # shape: [B, num_heads, T_q, T_k]
|
| 188 |
+
attention_probs = F.softmax(attention_scores, dim=-1) # shape: [B, num_heads, T_q, T_k]
|
| 189 |
|
| 190 |
+
# Collect the weighted-averaged values
|
| 191 |
+
attention_probs = self.dropout(attention_probs) # shape: [B, num_heads, T_q, T_k]
|
| 192 |
+
output = torch.matmul(attention_probs, value) # shape: [B, num_heads, T_q, head_size]
|
| 193 |
+
output = output.transpose(1, 2).flatten(2, 3) # shape: [B, T_q, D]
|
| 194 |
+
output = self.out_proj(output)
|
| 195 |
+
output = self.post_layer_norm(output)
|
| 196 |
+
output = self.dropout(output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
+
return output, attention_probs.detach()
|
| 199 |
|
| 200 |
|
| 201 |
class Embedding(nn.Module):
|
|
|
|
| 278 |
attention_mask = ~attention_mask.bool()
|
| 279 |
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 280 |
|
| 281 |
+
static_embeddings, relative_embedding = self.embedding(input_ids)
|
| 282 |
contextualized_embeddings, attention_probs = self.transformer(static_embeddings, attention_mask, relative_embedding)
|
|
|
|
| 283 |
last_layer = contextualized_embeddings[-1]
|
| 284 |
contextualized_embeddings = [contextualized_embeddings[0]] + [
|
| 285 |
contextualized_embeddings[i] - contextualized_embeddings[i - 1]
|