Update modeling_deberta.py
Browse files- modeling_deberta.py +4 -10
modeling_deberta.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
""" PyTorch DeBERTa-v2 model."""
|
| 16 |
|
|
|
|
| 17 |
from collections.abc import Sequence
|
| 18 |
from typing import Optional, Tuple, Union
|
| 19 |
|
|
@@ -553,16 +554,9 @@ class DebertaV2Encoder(nn.Module):
|
|
| 553 |
def make_log_bucket_position(relative_pos, bucket_size, max_position):
|
| 554 |
sign = torch.sign(relative_pos)
|
| 555 |
mid = bucket_size // 2
|
| 556 |
-
abs_pos = torch.where(
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
torch.abs(relative_pos),
|
| 560 |
-
)
|
| 561 |
-
log_pos = (
|
| 562 |
-
torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
|
| 563 |
-
)
|
| 564 |
-
bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
|
| 565 |
-
bucket_pos = bucket_pos.clamp(min=-bucket_size+1, max=bucket_size-1)
|
| 566 |
return bucket_pos
|
| 567 |
|
| 568 |
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
""" PyTorch DeBERTa-v2 model."""
|
| 16 |
|
| 17 |
+
import math
|
| 18 |
from collections.abc import Sequence
|
| 19 |
from typing import Optional, Tuple, Union
|
| 20 |
|
|
|
|
| 554 |
def make_log_bucket_position(relative_pos, bucket_size, max_position):
|
| 555 |
sign = torch.sign(relative_pos)
|
| 556 |
mid = bucket_size // 2
|
| 557 |
+
abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos).clamp(max=max_position - 1))
|
| 558 |
+
log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position-1) / mid) * (mid - 1)).int() + mid
|
| 559 |
+
bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
return bucket_pos
|
| 561 |
|
| 562 |
|