noblebarkrr committed on
Commit
9e87102
·
verified ·
1 Parent(s): 44bc4cc

Delete models/bs_roformer/attend_sw.py

Browse files
Files changed (1) hide show
  1. models/bs_roformer/attend_sw.py +0 -88
models/bs_roformer/attend_sw.py DELETED
@@ -1,88 +0,0 @@
1
- import logging
2
- import os
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from packaging import version
7
- from torch import Tensor, einsum, nn
8
- from torch.nn.attention import SDPBackend, sdpa_kernel
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
- class Attend(nn.Module):
14
- def __init__(self, dropout: float = 0.0, flash: bool = False, scale=None):
15
- super().__init__()
16
- self.scale = scale
17
- self.dropout = dropout
18
- self.attn_dropout = nn.Dropout(dropout)
19
-
20
- self.flash = flash
21
- assert not (
22
- flash and version.parse(torch.__version__) < version.parse("2.0.0")
23
- ), "expected pytorch >= 2.0.0 to use flash attention"
24
-
25
- self.cpu_backends = [
26
- SDPBackend.FLASH_ATTENTION,
27
- SDPBackend.EFFICIENT_ATTENTION,
28
- SDPBackend.MATH,
29
- ]
30
- self.cuda_backends: list | None = None
31
-
32
- if not torch.cuda.is_available() or not flash:
33
- return
34
-
35
- device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
36
- device_version = version.parse(
37
- f"{device_properties.major}.{device_properties.minor}"
38
- )
39
-
40
- if device_version >= version.parse("8.0"):
41
- if os.name == "nt":
42
- cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
43
- logger.info(f"windows detected, {cuda_backends=}")
44
- else:
45
- cuda_backends = [SDPBackend.FLASH_ATTENTION]
46
- logger.info(f"gpu compute capability >= 8.0, {cuda_backends=}")
47
- else:
48
- cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
49
- logger.info(f"gpu compute capability < 8.0, {cuda_backends=}")
50
-
51
- self.cuda_backends = cuda_backends
52
-
53
- def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
54
- _, _heads, _q_len, _, _k_len, is_cuda, _device = (
55
- *q.shape,
56
- k.shape[-2],
57
- q.is_cuda,
58
- q.device,
59
- ) # type: ignore
60
-
61
- if self.scale is not None:
62
- default_scale = q.shape[-1] ** -0.5
63
- q = q * (self.scale / default_scale)
64
-
65
- backends = self.cuda_backends if is_cuda else self.cpu_backends
66
- with sdpa_kernel(backends=backends): # type: ignore
67
- out = F.scaled_dot_product_attention(
68
- q, k, v, dropout_p=self.dropout if self.training else 0.0
69
- )
70
-
71
- return out
72
-
73
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
74
- _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device
75
-
76
- scale = self.scale or q.shape[-1] ** -0.5
77
-
78
- if self.flash:
79
- return self.flash_attn(q, k, v)
80
-
81
- sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
82
-
83
- attn = sim.softmax(dim=-1)
84
- attn = self.attn_dropout(attn)
85
-
86
- out = einsum("b h i j, b h j d -> b h i d", attn, v)
87
-
88
- return out