QuantaSparkLabs committed on
Commit
e55f915
·
verified ·
1 Parent(s): 3d213d5

Update modeling_tiny_gpt.py

Browse files
Files changed (1) hide show
  1. modeling_tiny_gpt.py +68 -63
modeling_tiny_gpt.py CHANGED
@@ -13,20 +13,29 @@ _FLASH3_KERNEL = None
13
  def _get_flash2_kernel():
14
  global _FLASH2_KERNEL
15
  if _FLASH2_KERNEL is None:
16
- kernels = importlib.import_module("kernels")
17
- _FLASH2_KERNEL = kernels.get_kernel("kernels-community/flash-attn2", version=1)
 
 
 
18
  return _FLASH2_KERNEL
19
 
20
  def _get_flash3_kernel():
21
  global _FLASH3_KERNEL
22
  if _FLASH3_KERNEL is None:
23
- kernels = importlib.import_module("kernels")
24
- _FLASH3_KERNEL = kernels.get_kernel("kernels-community/flash-attn3", version=1)
 
 
 
25
  return _FLASH3_KERNEL
26
 
27
  def _get_sageattn():
28
- module = importlib.import_module("sageattention")
29
- return module.sageattn
 
 
 
30
 
31
  class CausalSelfAttention(nn.Module):
32
  def __init__(self, config: TinyGPTConfig):
@@ -36,49 +45,37 @@ class CausalSelfAttention(nn.Module):
36
  self.n_head = int(config.n_head)
37
  self.head_dim = int(config.n_embd // config.n_head)
38
  self.attention_backend = str(getattr(config, "attention_backend", "torch"))
39
- self.torch_fallback = bool(getattr(config, "torch_fallback", False))
40
- self.dropout_p = float(config.dropout)
41
  if self.attention_backend not in ("sage", "torch", "flash2", "flash3"):
42
- raise ValueError("attention_backend must be sage, torch, flash2 or flash3")
43
  if self.attention_backend == "sage" and self.head_dim not in (64, 96, 128):
44
- raise ValueError(f"SageAttention requires head_dim in [64, 96, 128], got {self.head_dim}")
45
  if self.attention_backend == "sage" and self.dropout_p != 0.0:
46
- raise ValueError("SageAttention requires dropout=0.0")
47
  if self.attention_backend == "flash3" and self.dropout_p != 0.0:
48
- raise ValueError("FlashAttention3 requires dropout=0.0")
49
  if self.attention_backend in ("flash2", "flash3") and self.head_dim % 8 != 0:
50
- raise ValueError(f"FlashAttention requires head_dim multiple of 8, got {self.head_dim}")
51
  self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
52
  self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
53
- self.dropout = nn.Dropout(config.dropout)
54
  mask = torch.tril(torch.ones(config.ctx_len, config.ctx_len, dtype=torch.bool))
55
  self.register_buffer("mask", mask.view(1, 1, config.ctx_len, config.ctx_len), persistent=False)
56
  self.sageattn = None
57
  self.flash_kernel = None
58
  if self.attention_backend == "sage":
59
- try:
60
- self.sageattn = _get_sageattn()
61
- except Exception:
62
- if self.torch_fallback:
63
- self.attention_backend = "torch"
64
- else:
65
- raise
66
  if self.attention_backend == "flash2":
67
- try:
68
- self.flash_kernel = _get_flash2_kernel()
69
- except Exception:
70
- if self.torch_fallback:
71
- self.attention_backend = "torch"
72
- else:
73
- raise
74
  if self.attention_backend == "flash3":
75
- try:
76
- self.flash_kernel = _get_flash3_kernel()
77
- except Exception:
78
- if self.torch_fallback:
79
- self.attention_backend = "torch"
80
- else:
81
- raise
82
 
83
  def _torch_attention(self, q, k, v, t):
84
  scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
@@ -88,34 +85,43 @@ class CausalSelfAttention(nn.Module):
88
  return att @ v
89
 
90
  def _sage_attention(self, q, k, v):
91
- if self.sageattn is None or not q.is_cuda:
92
- if self.torch_fallback:
93
- return None
94
- raise RuntimeError("SageAttention requires CUDA + sageattention")
95
- return self.sageattn(q.contiguous(), k.contiguous(), v.contiguous(), tensor_layout="HND", is_causal=True)
 
 
 
96
 
97
  def _flash2_attention(self, q, k, v):
98
- if self.flash_kernel is None or not q.is_cuda:
99
- if self.torch_fallback:
100
- return None
101
- raise RuntimeError("FlashAttention2 requires CUDA + kernels")
102
- q = q.transpose(1, 2).contiguous()
103
- k = k.transpose(1, 2).contiguous()
104
- v = v.transpose(1, 2).contiguous()
105
- dropout_p = self.dropout_p if self.training else 0.0
106
- y = self.flash_kernel.flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)
107
- return y.transpose(1, 2).contiguous()
 
 
 
108
 
109
  def _flash3_attention(self, q, k, v):
110
- if self.flash_kernel is None or not q.is_cuda:
111
- if self.torch_fallback:
112
- return None
113
- raise RuntimeError("FlashAttention3 requires CUDA + kernels")
114
- q = q.transpose(1, 2).contiguous()
115
- k = k.transpose(1, 2).contiguous()
116
- v = v.transpose(1, 2).contiguous()
117
- y = self.flash_kernel.flash_attn_func(q, k, v, causal=True)
118
- return y.transpose(1, 2).contiguous()
 
 
 
119
 
120
  def forward(self, x):
121
  b, t, c = x.shape
@@ -193,14 +199,13 @@ class TinyGPTModel(TinyGPTPreTrainedModel):
193
  self.ln_f = nn.LayerNorm(config.n_embd)
194
  self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
195
  self.post_init()
196
-
197
 
198
  def get_input_embeddings(self):
199
  return self.tok_emb
200
 
201
  def set_input_embeddings(self, value):
202
  self.tok_emb = value
203
- self.tie_weights()
204
 
205
  def get_output_embeddings(self):
206
  return self.head
@@ -246,7 +251,7 @@ class TinyGPTForCausalLM(TinyGPTPreTrainedModel, GenerationMixin):
246
 
247
  def set_input_embeddings(self, value):
248
  self.tiny_gpt.tok_emb = value
249
- self.tie_weights()
250
 
251
  def get_output_embeddings(self):
252
  return self.tiny_gpt.head
@@ -280,4 +285,4 @@ class TinyGPTForCausalLM(TinyGPTPreTrainedModel, GenerationMixin):
280
  past_key_values=None,
281
  hidden_states=None,
282
  attentions=None,
283
- )
 
13
  def _get_flash2_kernel():
14
  global _FLASH2_KERNEL
15
  if _FLASH2_KERNEL is None:
16
+ try:
17
+ kernels = importlib.import_module("kernels")
18
+ _FLASH2_KERNEL = kernels.get_kernel("kernels-community/flash-attn2", version=1)
19
+ except ImportError:
20
+ pass
21
  return _FLASH2_KERNEL
22
 
23
  def _get_flash3_kernel():
24
  global _FLASH3_KERNEL
25
  if _FLASH3_KERNEL is None:
26
+ try:
27
+ kernels = importlib.import_module("kernels")
28
+ _FLASH3_KERNEL = kernels.get_kernel("kernels-community/flash-attn3", version=1)
29
+ except ImportError:
30
+ pass
31
  return _FLASH3_KERNEL
32
 
33
  def _get_sageattn():
34
+ try:
35
+ module = importlib.import_module("sageattention")
36
+ return module.sageattn
37
+ except ImportError:
38
+ return None
39
 
40
  class CausalSelfAttention(nn.Module):
41
  def __init__(self, config: TinyGPTConfig):
 
45
  self.n_head = int(config.n_head)
46
  self.head_dim = int(config.n_embd // config.n_head)
47
  self.attention_backend = str(getattr(config, "attention_backend", "torch"))
48
+ self.torch_fallback = bool(getattr(config, "torch_fallback", True))
49
+ self.dropout_p = float(config.dropout) if hasattr(config, "dropout") else 0.0
50
  if self.attention_backend not in ("sage", "torch", "flash2", "flash3"):
51
+ self.attention_backend = "torch"
52
  if self.attention_backend == "sage" and self.head_dim not in (64, 96, 128):
53
+ self.attention_backend = "torch"
54
  if self.attention_backend == "sage" and self.dropout_p != 0.0:
55
+ self.attention_backend = "torch"
56
  if self.attention_backend == "flash3" and self.dropout_p != 0.0:
57
+ self.attention_backend = "torch"
58
  if self.attention_backend in ("flash2", "flash3") and self.head_dim % 8 != 0:
59
+ self.attention_backend = "torch"
60
  self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
61
  self.proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
62
+ self.dropout = nn.Dropout(self.dropout_p)
63
  mask = torch.tril(torch.ones(config.ctx_len, config.ctx_len, dtype=torch.bool))
64
  self.register_buffer("mask", mask.view(1, 1, config.ctx_len, config.ctx_len), persistent=False)
65
  self.sageattn = None
66
  self.flash_kernel = None
67
  if self.attention_backend == "sage":
68
+ self.sageattn = _get_sageattn()
69
+ if self.sageattn is None and not self.torch_fallback:
70
+ raise RuntimeError("SageAttention requested but not available")
 
 
 
 
71
  if self.attention_backend == "flash2":
72
+ self.flash_kernel = _get_flash2_kernel()
73
+ if self.flash_kernel is None and not self.torch_fallback:
74
+ raise RuntimeError("FlashAttention2 requested but not available")
 
 
 
 
75
  if self.attention_backend == "flash3":
76
+ self.flash_kernel = _get_flash3_kernel()
77
+ if self.flash_kernel is None and not self.torch_fallback:
78
+ raise RuntimeError("FlashAttention3 requested but not available")
 
 
 
 
79
 
80
  def _torch_attention(self, q, k, v, t):
81
  scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
 
85
  return att @ v
86
 
87
  def _sage_attention(self, q, k, v):
88
+ if self.sageattn is None:
89
+ return None
90
+ if not q.is_cuda:
91
+ return None
92
+ try:
93
+ return self.sageattn(q.contiguous(), k.contiguous(), v.contiguous(), tensor_layout="HND", is_causal=True)
94
+ except Exception:
95
+ return None
96
 
97
  def _flash2_attention(self, q, k, v):
98
+ if self.flash_kernel is None:
99
+ return None
100
+ if not q.is_cuda:
101
+ return None
102
+ try:
103
+ q = q.transpose(1, 2).contiguous()
104
+ k = k.transpose(1, 2).contiguous()
105
+ v = v.transpose(1, 2).contiguous()
106
+ dropout_p = self.dropout_p if self.training else 0.0
107
+ y = self.flash_kernel.flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)
108
+ return y.transpose(1, 2).contiguous()
109
+ except Exception:
110
+ return None
111
 
112
  def _flash3_attention(self, q, k, v):
113
+ if self.flash_kernel is None:
114
+ return None
115
+ if not q.is_cuda:
116
+ return None
117
+ try:
118
+ q = q.transpose(1, 2).contiguous()
119
+ k = k.transpose(1, 2).contiguous()
120
+ v = v.transpose(1, 2).contiguous()
121
+ y = self.flash_kernel.flash_attn_func(q, k, v, causal=True)
122
+ return y.transpose(1, 2).contiguous()
123
+ except Exception:
124
+ return None
125
 
126
  def forward(self, x):
127
  b, t, c = x.shape
 
199
  self.ln_f = nn.LayerNorm(config.n_embd)
200
  self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
201
  self.post_init()
 
202
 
203
  def get_input_embeddings(self):
204
  return self.tok_emb
205
 
206
  def set_input_embeddings(self, value):
207
  self.tok_emb = value
208
+ self.head.weight = self.tok_emb.weight
209
 
210
  def get_output_embeddings(self):
211
  return self.head
 
251
 
252
  def set_input_embeddings(self, value):
253
  self.tiny_gpt.tok_emb = value
254
+ self.tiny_gpt.head.weight = self.tiny_gpt.tok_emb.weight
255
 
256
  def get_output_embeddings(self):
257
  return self.tiny_gpt.head
 
285
  past_key_values=None,
286
  hidden_states=None,
287
  attentions=None,
288
+ )