Upload modeling_opt.py

modeling_opt.py  CHANGED  (+23 -8)
@@ -17,32 +17,37 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from .activations import ACT2FN
-from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from .modeling_outputs import (
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
 )
-from .modeling_utils import PreTrainedModel
-from .utils import (
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
+
 )
 from .configuration_opt import OPTConfig
 
 
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -128,6 +133,16 @@ def softmax_1(input: torch.Tensor, dim=-1, dtype=torch.float32) -> torch.Tensor:
     output = softmax_n_shifted_zeros(input, 1, dim=dim)
     return output if dtype is None else output.type(dtype=dtype)
 
+def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
+    sm_out = torch.nn.functional.softmax(data, dim=dim, **kw)
+    stretched_out = sm_out * (eta - gamma) + gamma
+    return torch.clip(stretched_out, 0, 1)
+
+
+def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
+    sm_out = softmax_1(data, dim=dim, **kw)
+    stretched_out = sm_out * (eta - gamma) + gamma
+    return torch.clip(stretched_out, 0, 1)
 
 class OPTAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
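Both helpers added in the hunk above apply the same stretch-and-clip transform: the softmax output is scaled by (eta - gamma), shifted by gamma, and then clipped back to [0, 1], so an attention weight can saturate at exactly 0 or exactly 1 instead of only approaching those values. clipped_softmax does this on top of the standard softmax, clipped_softmax1 on top of the softmax_1 defined just above it. A minimal standalone sketch of the effect with the default eta=1.1, gamma=-0.1 (not part of the file; the example logits are made up):

    import torch
    import torch.nn.functional as F

    eta, gamma = 1.1, -0.1
    scores = torch.tensor([[8.0, 0.0, -8.0]])

    plain = F.softmax(scores, dim=-1)                          # ~[0.9997, 3.4e-4, 1.1e-7]; never exactly 0
    clipped = torch.clip(plain * (eta - gamma) + gamma, 0, 1)  # [1., 0., 0.]; small weights clip to exact zeros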
@@ -147,7 +162,7 @@ class OPTAttention(nn.Module):
 
         self.head_dim = self.embed_dim // self.num_heads
         self.is_causal = True
-
+        self.softmax_fn = clipped_softmax1
         if (self.head_dim * self.num_heads) != self.embed_dim:
             raise ValueError(
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {
@@ -251,10 +266,10 @@ class OPTAttention(nn.Module):
 
         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
         if attn_weights.dtype == torch.float16:
-            attn_weights =
+            attn_weights = self.softmax_fn(
                 attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
         else:
-            attn_weights =
+            attn_weights = self.softmax_fn(attn_weights, dim=-1)
 
         if layer_head_mask is not None:
             if layer_head_mask.size() != (self.num_heads,):
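With this last hunk, OPTAttention no longer hard-codes the softmax: the constructor sets self.softmax_fn = clipped_softmax1 and the forward pass normalizes attention weights through that attribute, so the softmax variant becomes a per-module hook. The only contract the call sites impose is a softmax-like signature, invoked as softmax_fn(attn_weights, dim=-1) and, on the fp16 path, as softmax_fn(attn_weights, dim=-1, dtype=torch.float32); both helpers forward extra keyword arguments via **kw, so the fp16 upcast keeps working. A hedged standalone sketch of swapping the hook (the helper is copied from the diff so the snippet runs on its own; the tensor shape, the functools.partial usage, and the alternative eta/gamma values are illustrative, not from the file):

    from functools import partial

    import torch
    import torch.nn.functional as F

    def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
        # copy of the helper added in the diff above
        sm_out = F.softmax(data, dim=dim, **kw)
        stretched_out = sm_out * (eta - gamma) + gamma
        return torch.clip(stretched_out, 0, 1)

    # stand-in for the per-module hook, e.g. decoder_layer.self_attn.softmax_fn in the upstream OPT layout
    softmax_fn = partial(clipped_softmax, eta=1.03, gamma=-0.03)

    attn_weights = torch.randn(2, 4, 4, dtype=torch.float16)  # illustrative attention scores
    if attn_weights.dtype == torch.float16:                   # same guard as in OPTAttention.forward
        attn_weights = softmax_fn(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
    else:
        attn_weights = softmax_fn(attn_weights, dim=-1)
    print(attn_weights.dtype, attn_weights.min().item())      # torch.float16; exact zeros are now possible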