noblebarkrr commited on
Commit
51c079b
·
verified ·
1 Parent(s): 9e87102

Delete models/bs_roformer/attend_sage.py

Browse files
Files changed (1) hide show
  1. models/bs_roformer/attend_sage.py +0 -147
models/bs_roformer/attend_sage.py DELETED
@@ -1,147 +0,0 @@
1
- from functools import wraps
2
- from packaging import version
3
- from collections import namedtuple
4
-
5
- import os
6
- import torch
7
- from torch import nn, einsum
8
- import torch.nn.functional as F
9
-
10
- from einops import rearrange, reduce
11
-
12
-
13
- def _print_once(msg):
14
- printed = False
15
-
16
- def inner():
17
- nonlocal printed
18
- if not printed:
19
- print(msg)
20
- printed = True
21
-
22
- return inner
23
-
24
-
25
- # Проверяем доступность SageAttention
26
- try:
27
- from sageattention import sageattn
28
- _has_sage_attention = True
29
- except ImportError:
30
- _has_sage_attention = False
31
- _print_sage_not_found = _print_once(
32
- "SageAttention not found. Will fall back to PyTorch SDPA (if available) or manual einsum."
33
- )
34
- _print_sage_not_found()
35
-
36
-
37
- def exists(val):
38
- return val is not None
39
-
40
-
41
- def default(v, d):
42
- return v if exists(v) else d
43
-
44
-
45
- class Attend(nn.Module):
46
- def __init__(self, dropout=0.0, flash=False, scale=None):
47
- super().__init__()
48
- self.scale = scale
49
- self.dropout = dropout
50
-
51
- self.use_sage = flash and _has_sage_attention
52
- self.use_pytorch_sdpa = False
53
- self._sdpa_checked = False
54
- self.flash = flash
55
-
56
- # Инициализируем сообщения
57
- self._init_messages = False
58
-
59
- if flash and not self.use_sage:
60
- if not self._sdpa_checked:
61
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
62
- self.use_pytorch_sdpa = True
63
- self._sdpa_checked = True
64
-
65
- self.attn_dropout = nn.Dropout(dropout)
66
-
67
- def _print_init_messages(self):
68
- """Печатаем сообщения инициализации один раз"""
69
- if self._init_messages:
70
- return
71
-
72
- if self.flash:
73
- if self.use_sage:
74
- print_once = _print_once("Using SageAttention backend.")
75
- print_once()
76
- elif self.use_pytorch_sdpa:
77
- print_once = _print_once(
78
- "Using PyTorch SDPA backend (FlashAttention-2, Memory-Efficient, or Math)."
79
- )
80
- print_once()
81
- else:
82
- print_once = _print_once(
83
- "Flash attention requested but Pytorch < 2.0 and SageAttention not found. Falling back to einsum."
84
- )
85
- print_once()
86
-
87
- self._init_messages = True
88
-
89
- def forward(self, q, k, v):
90
- q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
91
-
92
- # Печатаем сообщения инициализации при первом вызове
93
- self._print_init_messages()
94
-
95
- # Пробуем SageAttention если доступен
96
- if self.use_sage and self.flash:
97
- try:
98
- # Исправленный вызов: убрали повторный try-except
99
- out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
100
- return out
101
- except Exception as e:
102
- print(f"SageAttention failed with error: {e}. Falling back.")
103
- self.use_sage = False
104
- if not self._sdpa_checked:
105
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
106
- self.use_pytorch_sdpa = True
107
- print_once = _print_once(
108
- "Falling back to PyTorch SDPA."
109
- )
110
- print_once()
111
- else:
112
- print_once = _print_once("Falling back to einsum.")
113
- print_once()
114
- self._sdpa_checked = True
115
-
116
- # Пробуем PyTorch SDPA если доступен
117
- if self.use_pytorch_sdpa and self.flash:
118
- try:
119
- # Для PyTorch >= 2.0
120
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
121
- with torch.backends.cuda.sdp_kernel(
122
- enable_flash=True, enable_math=True, enable_mem_efficient=True
123
- ):
124
- out = F.scaled_dot_product_attention(
125
- q,
126
- k,
127
- v,
128
- attn_mask=None,
129
- dropout_p=self.dropout if self.training else 0.0,
130
- is_causal=False,
131
- )
132
- return out
133
- except Exception as e:
134
- print(f"PyTorch SDPA failed with error: {e}. Falling back to einsum.")
135
- self.use_pytorch_sdpa = False
136
-
137
- # Fallback на einsum (работает в PyTorch 1.13+)
138
- scale = default(self.scale, q.shape[-1] ** -0.5)
139
-
140
- sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
141
-
142
- attn = sim.softmax(dim=-1)
143
- attn = self.attn_dropout(attn)
144
-
145
- out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
146
-
147
- return out