QuantaSparkLabs commited on
Commit
9047576
·
verified ·
1 Parent(s): 59f1228

Fix tie_weights and add torch_fallback support

Browse files
Files changed (1) hide show
  1. modeling_tiny_gpt.py +55 -7
modeling_tiny_gpt.py CHANGED
@@ -6,23 +6,28 @@ import torch.nn.functional as F
6
  from transformers import PreTrainedModel, GenerationMixin
7
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
8
  from .configuration_tiny_gpt import TinyGPTConfig
 
9
  _FLASH2_KERNEL = None
10
  _FLASH3_KERNEL = None
 
11
  def _get_flash2_kernel():
12
  global _FLASH2_KERNEL
13
  if _FLASH2_KERNEL is None:
14
  kernels = importlib.import_module("kernels")
15
  _FLASH2_KERNEL = kernels.get_kernel("kernels-community/flash-attn2", version=1)
16
  return _FLASH2_KERNEL
 
17
  def _get_flash3_kernel():
18
  global _FLASH3_KERNEL
19
  if _FLASH3_KERNEL is None:
20
  kernels = importlib.import_module("kernels")
21
  _FLASH3_KERNEL = kernels.get_kernel("kernels-community/flash-attn3", version=1)
22
  return _FLASH3_KERNEL
 
23
  def _get_sageattn():
24
  module = importlib.import_module("sageattention")
25
  return module.sageattn
 
26
  class CausalSelfAttention(nn.Module):
27
  def __init__(self, config: TinyGPTConfig):
28
  super().__init__()
@@ -74,18 +79,21 @@ class CausalSelfAttention(nn.Module):
74
  self.attention_backend = "torch"
75
  else:
76
  raise
 
77
  def _torch_attention(self, q, k, v, t):
78
  scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
79
  scores = scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
80
  att = F.softmax(scores.float(), dim=-1).to(q.dtype)
81
  att = self.dropout(att)
82
  return att @ v
 
83
  def _sage_attention(self, q, k, v):
84
  if self.sageattn is None or not q.is_cuda:
85
  if self.torch_fallback:
86
  return None
87
  raise RuntimeError("SageAttention requires CUDA + sageattention")
88
  return self.sageattn(q.contiguous(), k.contiguous(), v.contiguous(), tensor_layout="HND", is_causal=True)
 
89
  def _flash2_attention(self, q, k, v):
90
  if self.flash_kernel is None or not q.is_cuda:
91
  if self.torch_fallback:
@@ -97,6 +105,7 @@ class CausalSelfAttention(nn.Module):
97
  dropout_p = self.dropout_p if self.training else 0.0
98
  y = self.flash_kernel.flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)
99
  return y.transpose(1, 2).contiguous()
 
100
  def _flash3_attention(self, q, k, v):
101
  if self.flash_kernel is None or not q.is_cuda:
102
  if self.torch_fallback:
@@ -107,6 +116,7 @@ class CausalSelfAttention(nn.Module):
107
  v = v.transpose(1, 2).contiguous()
108
  y = self.flash_kernel.flash_attn_func(q, k, v, causal=True)
109
  return y.transpose(1, 2).contiguous()
 
110
  def forward(self, x):
111
  b, t, c = x.shape
112
  qkv = self.qkv(x)
@@ -130,18 +140,21 @@ class CausalSelfAttention(nn.Module):
130
  y = self._torch_attention(q, k, v, t)
131
  y = y.transpose(1, 2).contiguous().view(b, t, c)
132
  return self.proj(y)
 
133
  class MLP(nn.Module):
134
  def __init__(self, config: TinyGPTConfig):
135
  super().__init__()
136
  self.fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
137
  self.proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
138
  self.dropout = nn.Dropout(config.dropout)
 
139
  def forward(self, x):
140
  x = self.fc(x)
141
  x = F.gelu(x)
142
  x = self.proj(x)
143
  x = self.dropout(x)
144
  return x
 
145
  class Block(nn.Module):
146
  def __init__(self, config: TinyGPTConfig):
147
  super().__init__()
@@ -149,14 +162,17 @@ class Block(nn.Module):
149
  self.attn = CausalSelfAttention(config)
150
  self.ln2 = nn.LayerNorm(config.n_embd)
151
  self.mlp = MLP(config)
 
152
  def forward(self, x):
153
  x = x + self.attn(self.ln1(x))
154
  x = x + self.mlp(self.ln2(x))
155
  return x
 
156
  class TinyGPTPreTrainedModel(PreTrainedModel):
157
  config_class = TinyGPTConfig
158
  base_model_prefix = "tiny_gpt"
159
  supports_gradient_checkpointing = False
 
160
  def _init_weights(self, module):
161
  if isinstance(module, nn.Linear):
162
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -164,8 +180,10 @@ class TinyGPTPreTrainedModel(PreTrainedModel):
164
  nn.init.zeros_(module.bias)
165
  elif isinstance(module, nn.Embedding):
166
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
167
  class TinyGPTModel(TinyGPTPreTrainedModel):
168
  _tied_weights_keys = ["head.weight"]
 
169
  def __init__(self, config: TinyGPTConfig):
170
  super().__init__(config)
171
  self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
@@ -175,18 +193,24 @@ class TinyGPTModel(TinyGPTPreTrainedModel):
175
  self.ln_f = nn.LayerNorm(config.n_embd)
176
  self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
177
  self.post_init()
178
- self.tie_weights()
 
179
  def get_input_embeddings(self):
180
  return self.tok_emb
 
181
  def set_input_embeddings(self, value):
182
  self.tok_emb = value
183
  self.tie_weights()
 
184
  def get_output_embeddings(self):
185
  return self.head
 
186
  def set_output_embeddings(self, new_embeddings):
187
  self.head = new_embeddings
 
188
  def tie_weights(self, *args, **kwargs):
189
- self._tie_or_clone_weights(self.head, self.tok_emb)
 
190
  def forward(self, input_ids, attention_mask=None, return_dict=True, return_logits=False, **kwargs):
191
  b, t = input_ids.shape
192
  if t > self.config.ctx_len:
@@ -202,29 +226,47 @@ class TinyGPTModel(TinyGPTPreTrainedModel):
202
  return (hidden, logits) if return_logits else (hidden,)
203
  if return_logits:
204
  return hidden, logits
205
- return BaseModelOutputWithPast(last_hidden_state=hidden, past_key_values=None, hidden_states=None, attentions=None)
 
 
 
 
 
 
206
  class TinyGPTForCausalLM(TinyGPTPreTrainedModel, GenerationMixin):
207
  _tied_weights_keys = ["tiny_gpt.head.weight"]
 
208
  def __init__(self, config: TinyGPTConfig):
209
  super().__init__(config)
210
  self.tiny_gpt = TinyGPTModel(config)
211
  self.post_init()
212
- self.tie_weights()
213
  def get_input_embeddings(self):
214
  return self.tiny_gpt.tok_emb
 
215
  def set_input_embeddings(self, value):
216
  self.tiny_gpt.tok_emb = value
217
  self.tie_weights()
 
218
  def get_output_embeddings(self):
219
  return self.tiny_gpt.head
 
220
  def set_output_embeddings(self, new_embeddings):
221
  self.tiny_gpt.head = new_embeddings
 
222
  def tie_weights(self, *args, **kwargs):
223
- self._tie_or_clone_weights(self.tiny_gpt.head, self.tiny_gpt.tok_emb)
 
224
  def prepare_inputs_for_generation(self, input_ids, **kwargs):
225
  return {"input_ids": input_ids}
 
226
  def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True, **kwargs):
227
- hidden, logits = self.tiny_gpt(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, return_logits=True)
 
 
 
 
 
228
  loss = None
229
  if labels is not None:
230
  shift_logits = logits[:, :-1, :].contiguous()
@@ -232,4 +274,10 @@ class TinyGPTForCausalLM(TinyGPTPreTrainedModel, GenerationMixin):
232
  loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)).float(), shift_labels.view(-1))
233
  if not return_dict:
234
  return ((loss, logits) if loss is not None else (logits,))
235
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None)
 
 
 
 
 
 
 
6
  from transformers import PreTrainedModel, GenerationMixin
7
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
8
  from .configuration_tiny_gpt import TinyGPTConfig
9
+
10
  _FLASH2_KERNEL = None
11
  _FLASH3_KERNEL = None
12
+
13
  def _get_flash2_kernel():
14
  global _FLASH2_KERNEL
15
  if _FLASH2_KERNEL is None:
16
  kernels = importlib.import_module("kernels")
17
  _FLASH2_KERNEL = kernels.get_kernel("kernels-community/flash-attn2", version=1)
18
  return _FLASH2_KERNEL
19
+
20
  def _get_flash3_kernel():
21
  global _FLASH3_KERNEL
22
  if _FLASH3_KERNEL is None:
23
  kernels = importlib.import_module("kernels")
24
  _FLASH3_KERNEL = kernels.get_kernel("kernels-community/flash-attn3", version=1)
25
  return _FLASH3_KERNEL
26
+
27
  def _get_sageattn():
28
  module = importlib.import_module("sageattention")
29
  return module.sageattn
30
+
31
  class CausalSelfAttention(nn.Module):
32
  def __init__(self, config: TinyGPTConfig):
33
  super().__init__()
 
79
  self.attention_backend = "torch"
80
  else:
81
  raise
82
+
83
  def _torch_attention(self, q, k, v, t):
84
  scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
85
  scores = scores.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
86
  att = F.softmax(scores.float(), dim=-1).to(q.dtype)
87
  att = self.dropout(att)
88
  return att @ v
89
+
90
  def _sage_attention(self, q, k, v):
91
  if self.sageattn is None or not q.is_cuda:
92
  if self.torch_fallback:
93
  return None
94
  raise RuntimeError("SageAttention requires CUDA + sageattention")
95
  return self.sageattn(q.contiguous(), k.contiguous(), v.contiguous(), tensor_layout="HND", is_causal=True)
96
+
97
  def _flash2_attention(self, q, k, v):
98
  if self.flash_kernel is None or not q.is_cuda:
99
  if self.torch_fallback:
 
105
  dropout_p = self.dropout_p if self.training else 0.0
106
  y = self.flash_kernel.flash_attn_func(q, k, v, dropout_p=dropout_p, causal=True)
107
  return y.transpose(1, 2).contiguous()
108
+
109
  def _flash3_attention(self, q, k, v):
110
  if self.flash_kernel is None or not q.is_cuda:
111
  if self.torch_fallback:
 
116
  v = v.transpose(1, 2).contiguous()
117
  y = self.flash_kernel.flash_attn_func(q, k, v, causal=True)
118
  return y.transpose(1, 2).contiguous()
119
+
120
  def forward(self, x):
121
  b, t, c = x.shape
122
  qkv = self.qkv(x)
 
140
  y = self._torch_attention(q, k, v, t)
141
  y = y.transpose(1, 2).contiguous().view(b, t, c)
142
  return self.proj(y)
143
+
144
  class MLP(nn.Module):
145
  def __init__(self, config: TinyGPTConfig):
146
  super().__init__()
147
  self.fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
148
  self.proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
149
  self.dropout = nn.Dropout(config.dropout)
150
+
151
  def forward(self, x):
152
  x = self.fc(x)
153
  x = F.gelu(x)
154
  x = self.proj(x)
155
  x = self.dropout(x)
156
  return x
157
+
158
  class Block(nn.Module):
159
  def __init__(self, config: TinyGPTConfig):
160
  super().__init__()
 
162
  self.attn = CausalSelfAttention(config)
163
  self.ln2 = nn.LayerNorm(config.n_embd)
164
  self.mlp = MLP(config)
165
+
166
  def forward(self, x):
167
  x = x + self.attn(self.ln1(x))
168
  x = x + self.mlp(self.ln2(x))
169
  return x
170
+
171
  class TinyGPTPreTrainedModel(PreTrainedModel):
172
  config_class = TinyGPTConfig
173
  base_model_prefix = "tiny_gpt"
174
  supports_gradient_checkpointing = False
175
+
176
  def _init_weights(self, module):
177
  if isinstance(module, nn.Linear):
178
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
180
  nn.init.zeros_(module.bias)
181
  elif isinstance(module, nn.Embedding):
182
  nn.init.normal_(module.weight, mean=0.0, std=0.02)
183
+
184
  class TinyGPTModel(TinyGPTPreTrainedModel):
185
  _tied_weights_keys = ["head.weight"]
186
+
187
  def __init__(self, config: TinyGPTConfig):
188
  super().__init__(config)
189
  self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
 
193
  self.ln_f = nn.LayerNorm(config.n_embd)
194
  self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
195
  self.post_init()
196
+ # tie_weights will be called by post_init, but we provide the override below.
197
+
198
  def get_input_embeddings(self):
199
  return self.tok_emb
200
+
201
  def set_input_embeddings(self, value):
202
  self.tok_emb = value
203
  self.tie_weights()
204
+
205
  def get_output_embeddings(self):
206
  return self.head
207
+
208
  def set_output_embeddings(self, new_embeddings):
209
  self.head = new_embeddings
210
+
211
  def tie_weights(self, *args, **kwargs):
212
+ self.head.weight = self.tok_emb.weight
213
+
214
  def forward(self, input_ids, attention_mask=None, return_dict=True, return_logits=False, **kwargs):
215
  b, t = input_ids.shape
216
  if t > self.config.ctx_len:
 
226
  return (hidden, logits) if return_logits else (hidden,)
227
  if return_logits:
228
  return hidden, logits
229
+ return BaseModelOutputWithPast(
230
+ last_hidden_state=hidden,
231
+ past_key_values=None,
232
+ hidden_states=None,
233
+ attentions=None,
234
+ )
235
+
236
  class TinyGPTForCausalLM(TinyGPTPreTrainedModel, GenerationMixin):
237
  _tied_weights_keys = ["tiny_gpt.head.weight"]
238
+
239
  def __init__(self, config: TinyGPTConfig):
240
  super().__init__(config)
241
  self.tiny_gpt = TinyGPTModel(config)
242
  self.post_init()
243
+
244
  def get_input_embeddings(self):
245
  return self.tiny_gpt.tok_emb
246
+
247
  def set_input_embeddings(self, value):
248
  self.tiny_gpt.tok_emb = value
249
  self.tie_weights()
250
+
251
  def get_output_embeddings(self):
252
  return self.tiny_gpt.head
253
+
254
  def set_output_embeddings(self, new_embeddings):
255
  self.tiny_gpt.head = new_embeddings
256
+
257
  def tie_weights(self, *args, **kwargs):
258
+ self.tiny_gpt.head.weight = self.tiny_gpt.tok_emb.weight
259
+
260
  def prepare_inputs_for_generation(self, input_ids, **kwargs):
261
  return {"input_ids": input_ids}
262
+
263
  def forward(self, input_ids, attention_mask=None, labels=None, return_dict=True, **kwargs):
264
+ hidden, logits = self.tiny_gpt(
265
+ input_ids=input_ids,
266
+ attention_mask=attention_mask,
267
+ return_dict=True,
268
+ return_logits=True,
269
+ )
270
  loss = None
271
  if labels is not None:
272
  shift_logits = logits[:, :-1, :].contiguous()
 
274
  loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)).float(), shift_labels.view(-1))
275
  if not return_dict:
276
  return ((loss, logits) if loss is not None else (logits,))
277
+ return CausalLMOutputWithPast(
278
+ loss=loss,
279
+ logits=logits,
280
+ past_key_values=None,
281
+ hidden_states=None,
282
+ attentions=None,
283
+ )