Upload modeling_bert.py
Browse files- modeling_bert.py +14 -0
modeling_bert.py
CHANGED
|
@@ -21,6 +21,7 @@ import warnings
|
|
| 21 |
from dataclasses import dataclass
|
| 22 |
from typing import List, Optional, Tuple, Union
|
| 23 |
from functools import partial
|
|
|
|
| 24 |
import torch
|
| 25 |
import torch.utils.checkpoint
|
| 26 |
from packaging import version
|
|
@@ -56,6 +57,18 @@ from transformers.utils import (
|
|
| 56 |
)
|
| 57 |
from .configuration_bert import BertConfig
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def softmax_n_shifted_zeros(input: torch.Tensor, n: int, dim=-1) -> torch.Tensor:
|
| 61 |
"""
|
|
@@ -91,6 +104,7 @@ def clipped_softmax(data, dim=1, eta=1.1, gamma=-0.1, **kw):
|
|
| 91 |
return torch.clip(stretched_out, 0, 1)
|
| 92 |
|
| 93 |
|
|
|
|
| 94 |
def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
|
| 95 |
sm_out = softmax_1(data, dim=dim, **kw)
|
| 96 |
stretched_out = sm_out * (eta - gamma) + gamma
|
|
|
|
| 21 |
from dataclasses import dataclass
|
| 22 |
from typing import List, Optional, Tuple, Union
|
| 23 |
from functools import partial
|
| 24 |
+
from enum import Flag, auto
|
| 25 |
import torch
|
| 26 |
import torch.utils.checkpoint
|
| 27 |
from packaging import version
|
|
|
|
| 57 |
)
|
| 58 |
from .configuration_bert import BertConfig
|
| 59 |
|
| 60 |
+
class BaseEnumOptions(Flag):
    """Base class for option enums that print as, and can list, their member names."""

    def __str__(self):
        """Return the bare member name rather than the ``ClassName.member`` form."""
        # A value without a declared name has ``name`` set to None; returning None
        # from __str__ raises TypeError, so defer to Flag's own formatting instead.
        if self.name is None:
            return Flag.__str__(self)
        return self.name

    @classmethod
    def list_names(cls):
        """Return the names of all declared members (aliases included), in definition order."""
        # Iterating a Flag class directly (``for m in cls``) yields only single-bit
        # members on Python 3.11+, silently dropping zero-valued and multi-bit
        # members. The member mapping lists every declared name on all versions.
        return list(cls.__members__)
|
| 67 |
+
class AttentionGateType(BaseEnumOptions):
    """Enumerates the attention gate types by name."""

    # NOTE(review): the base class derives from enum.Flag, so these integers act as
    # bit masks. Because 3 == 1 | 2, `conditional_per_token` compares equal to
    # `unconditional_per_head | conditional_per_head`, and `none` (0) is falsy.
    # On Python 3.11+ iterating this class yields only the single-bit members
    # (values 1 and 2). Confirm whether mutually exclusive options (distinct powers
    # of two, or a plain Enum) were intended.
    none = 0
    unconditional_per_head = 1
    conditional_per_head = 2
    conditional_per_token = 3
|
| 72 |
|
| 73 |
def softmax_n_shifted_zeros(input: torch.Tensor, n: int, dim=-1) -> torch.Tensor:
|
| 74 |
"""
|
|
|
|
| 104 |
return torch.clip(stretched_out, 0, 1)
|
| 105 |
|
| 106 |
|
| 107 |
+
|
| 108 |
def clipped_softmax1(data, dim=1, eta=1.1, gamma=-0.1, **kw):
|
| 109 |
sm_out = softmax_1(data, dim=dim, **kw)
|
| 110 |
stretched_out = sm_out * (eta - gamma) + gamma
|