Upload modeling_opt.py
modeling_opt.py (+17 -8)
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """PyTorch OPT model."""
-
+import numpy as np
 from typing import List, Optional, Tuple, Union
 from functools import partial
 import torch
@@ -46,10 +46,12 @@ from transformers.utils import (
 )
 from .configuration_opt import OPTConfig
 
+
 def logit(p, eps=1e-16):
     p = np.clip(p, eps, 1 - eps)
     return -np.log(1 / p - 1)
 
+
 class BaseEnumOptions(Flag):
     def __str__(self):
         return self.name
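The `logit` helper added in this hunk (the reason for the new `numpy` import) is the inverse of the sigmoid: logit(p) = log(p / (1 - p)) = -log(1/p - 1), and the clipping keeps p away from 0 and 1, where the log diverges. A minimal standalone sanity check, with illustrative sample values:

import numpy as np

def logit(p, eps=1e-16):
    # Inverse of the sigmoid: logit(p) = log(p / (1 - p)).
    # Clipping keeps p strictly inside (0, 1) so the log stays finite.
    p = np.clip(p, eps, 1 - eps)
    return -np.log(1 / p - 1)

# sigmoid(logit(p)) should recover p.
p = np.array([0.01, 0.5, 0.99])
assert np.allclose(1.0 / (1.0 + np.exp(-logit(p))), p)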
@@ -198,7 +200,8 @@ class OPTAttention(nn.Module):
 
         if (self.head_dim * self.num_heads) != self.embed_dim:
             raise ValueError(
-                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {
+                    self.embed_dim}"
                 f" and `num_heads`: {self.num_heads})."
             )
         self.scaling = self.head_dim**-0.5
@@ -368,14 +371,16 @@ class OPTAttention(nn.Module):
 
         if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
             raise ValueError(
-                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f"Attention weights should be of size {
+                    (bsz * self.num_heads, tgt_len, src_len)}, but is"
                 f" {attn_weights.size()}"
             )
 
         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                 raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {
+                        attention_mask.size()}"
                 )
             attn_weights = attn_weights.view(
                 bsz, self.num_heads, tgt_len, src_len) + attention_mask
@@ -396,7 +401,8 @@ class OPTAttention(nn.Module):
         if layer_head_mask is not None:
             if layer_head_mask.size() != (self.num_heads,):
                 raise ValueError(
-                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f"Head mask for a single layer should be of size {
+                        (self.num_heads,)}, but is"
                     f" {layer_head_mask.size()}"
                 )
             attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
@@ -430,7 +436,8 @@ class OPTAttention(nn.Module):
 
         if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f"`attn_output` should be of size {
+                    (bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                 f" {attn_output.size()}"
             )
 
@@ -1088,7 +1095,8 @@ class OPTDecoder(OPTPreTrainedModel):
                 batch_size, mask_seq_length, device=inputs_embeds.device)
         elif attention_mask.shape[1] != mask_seq_length:
             raise ValueError(
-                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"The provided attention mask has length {
+                    attention_mask.shape[1]}, but its length should be "
                 f"{mask_seq_length} (sum of the lengths of current and past inputs)"
             )
         causal_attention_mask = _prepare_4d_causal_attention_mask(
@@ -1120,7 +1128,8 @@ class OPTDecoder(OPTPreTrainedModel):
             if attn_mask is not None:
                 if attn_mask.size()[0] != (len(self.layers)):
                     raise ValueError(
-                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f"The `{mask_name}` should be specified for {
+                            len(self.layers)} layers, but it is for"
                         f" {head_mask.size()[0]}."
                     )
 
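The remaining hunks all make the same mechanical edit: long f-strings are rewrapped so the replacement field opens with `{` at the end of one line and the expression continues on the next. Worth flagging for downstream users: this wrapping is only valid syntax on Python 3.12+ (PEP 701); on 3.11 and earlier, a newline inside a single-quoted f-string's braces is a SyntaxError. A version-agnostic way to express the same message uses implicit string concatenation instead, sketched below with placeholder values:

embed_dim, num_heads = 768, 12  # placeholder values for illustration

# Works on any supported Python: keep each replacement field on one
# line and let adjacent string literals concatenate implicitly.
msg = (
    f"embed_dim must be divisible by num_heads (got `embed_dim`: {embed_dim}"
    f" and `num_heads`: {num_heads})."
)
print(msg)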