Upload modelling_RW.py

modelling_RW.py  CHANGED  +18 -4
@@ -276,9 +276,23 @@ class Attention(nn.Module):
         key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
         value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
-        attn_output = F.scaled_dot_product_attention(
-            query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-        )
+        if torch.__version__ < "2.0.0":
+            mask = torch.ones(q_length, q_length, device=query_layer_.device)
+            mask = torch.tril(mask)
+            mask = (1.0 - mask) * -10000
+            mask = mask.repeat(batch_size, 1, 1, 1)
+
+            scores = torch.matmul(query_layer_, key_layer_.transpose(-2, -1))
+            scores = scores / math.sqrt(float(self.head_dim))
+            scores = scores + mask.type_as(scores)
+
+            probs = nn.Softmax(dim=-1)(scores)
+
+            attn_output = probs @ value_layer_
+        else:
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+            )
 
         x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
         x = x.permute(0, 2, 1, 3)
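The hunk above routes attention through a manual masked-softmax path when torch is older than 2.0.0, and through F.scaled_dot_product_attention otherwise. A minimal standalone sketch of that equivalence follows; the tensor sizes (batch, heads, q_len, head_dim) are illustrative assumptions, not values from the model config, and running the comparison itself requires torch >= 2.0.

# Standalone sketch (assumed shapes) comparing the manual fallback
# with F.scaled_dot_product_attention.
import math
import torch
import torch.nn.functional as F

batch, heads, q_len, head_dim = 2, 4, 8, 16  # illustrative sizes
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, q_len, head_dim)
v = torch.randn(batch, heads, q_len, head_dim)

# Additive causal mask, as in the torch < 2.0 branch of the diff:
# 0 on/below the diagonal, -10000 above it.
mask = torch.tril(torch.ones(q_len, q_len))
mask = (1.0 - mask) * -10000

scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
scores = scores + mask  # broadcasts over the batch and head dims
manual = torch.softmax(scores, dim=-1) @ v

sdpa = F.scaled_dot_product_attention(q, k, v, None, 0.0, is_causal=True)
print(torch.allclose(manual, sdpa, atol=1e-5))  # expected: True

The -10000 additive mask is a common stand-in for -inf: after softmax, the masked positions' weights underflow to zero, so both branches produce matching outputs up to floating-point error.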
@@ -945,7 +959,7 @@ class RWForTokenClassification(RWPreTrainedModel):
         else:
             classifier_dropout = 0.1
         self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
         # Initialize weights and apply final processing
         self.post_init()
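The second hunk sets the token-classification head's output width to config.num_labels, so each token's hidden state is projected to one logit per label. As a shape sanity check, a minimal sketch follows; hidden_size and num_labels here are made-up values, not Falcon's actual config.

# Minimal shape check for the per-token classification head (assumed sizes).
import torch
import torch.nn as nn

hidden_size, num_labels = 32, 5  # illustrative, not Falcon's config
dropout = nn.Dropout(0.1)
classifier = nn.Linear(hidden_size, num_labels)

hidden_states = torch.randn(2, 7, hidden_size)  # (batch, seq_len, hidden)
logits = classifier(dropout(hidden_states))
print(logits.shape)  # torch.Size([2, 7, 5]): one logit vector per token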
|