jiaxie committed
Commit 0d269c5 · verified · 1 Parent(s): 0fb4414

Delete modeling_hyena.py

Files changed (1)
  1. modeling_hyena.py +0 -574
modeling_hyena.py DELETED
@@ -1,574 +0,0 @@
# -*- coding: utf-8 -*-
"""HyenaDNA custom code port to Hugging Face Hub"""

import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from .configuration_hyena import HyenaConfig
from transformers import PreTrainedModel
from typing import Optional, Tuple, Union
from transformers.modeling_outputs import CausalLMOutput, SequenceClassifierOutput, BaseModelOutputWithNoAttention


def fftconv(u, k, D):
    """
    Apply a convolution through the Fourier domain (via the Convolution Theorem).
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen

    k_f = torch.fft.rfft(k.to(torch.float32), n=fft_size) / fft_size
    u_f = torch.fft.rfft(u.to(dtype=torch.float32), n=fft_size)

    if len(u.shape) > 3:
        k_f = k_f.unsqueeze(1)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen]

    out = y + u * D.unsqueeze(-1)
    return out.to(dtype=u.dtype)


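# A minimal sanity-check sketch for fftconv: zero-padding to 2 * seqlen makes the
# circular convolution above equal to a direct causal convolution, so for small
# random tensors one would expect, e.g.:
#
#     B, d, L = 2, 4, 16
#     u, k, D = torch.randn(B, d, L), torch.randn(d, L), torch.randn(d)
#     y_direct = F.conv1d(F.pad(u, (L - 1, 0)), k.flip(-1).unsqueeze(1), groups=d) + u * D.unsqueeze(-1)
#     assert torch.allclose(fftconv(u, k, D), y_direct, atol=1e-3)

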
@torch.jit.script
def mul_sum(q, y):
    return (q * y).sum(dim=1)


class HyenaSin(nn.Module):
    """The Sin activation function for the Hyena Filter function."""
    def __init__(self, config):
        super().__init__()
        self.freq = nn.Parameter(config.activation_freq * torch.ones(1, config.filter_order)) if config.train_freq else config.activation_freq * torch.ones(1, config.filter_order)

    def forward(self, x):
        return torch.sin(self.freq * x)


class HyenaPositionalEmbedding(nn.Module):
    def __init__(self, config):
        """Complex exponential positional embeddings for Hyena filters."""
        super().__init__()

        self.seq_len = config.max_seq_len
        # The time embedding fed to the filters is normalized so that t_f = 1
        t = torch.linspace(0, 1, self.seq_len)[None, :, None]  # 1, L, 1

        if config.emb_dim > 1:
            bands = (config.emb_dim - 1) // 2
        # To compute the right embeddings we use the "proper" linspace
        t_rescaled = torch.linspace(0, self.seq_len - 1, self.seq_len)[None, :, None]
        w = 2 * math.pi * t_rescaled / self.seq_len  # 1, L, 1

        f = torch.linspace(1e-4, bands - 1, bands)[None, None]

        z = torch.cat([t, torch.cos(-f * w), torch.sin(-f * w)], dim=-1)

        self.register_buffer("z", z)
        self.register_buffer("t", t)

    def forward(self, L):
        return self.z[:, :L], self.t[:, :L]


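# Shape note for HyenaPositionalEmbedding.forward: with emb_dim = 2 * bands + 1,
# `z` stacks the time channel with `bands` cosine and `bands` sine channels, so
# the forward above returns z of shape (1, L, emb_dim) and t of shape (1, L, 1).

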
class HyenaExponentialModulation(nn.Module):
    """The window function applied to the output of the (MLP) filter function."""
    def __init__(
        self,
        d_model,
        fast_decay_pct=0.3,
        slow_decay_pct=1.5,
        target=1e-2,
        modulate: bool = True,
        shift: float = 0.05,
        **kwargs
    ):
        super().__init__()
        self.modulate = modulate
        self.shift = shift
        max_decay = math.log(target) / fast_decay_pct
        min_decay = math.log(target) / slow_decay_pct
        deltas = torch.linspace(min_decay, max_decay, d_model)[None, None]
        self.register_buffer("deltas", deltas)

    def forward(self, t, x):
        if self.modulate:
            decay = torch.exp(-t * self.deltas.abs())
            x = x * (decay + self.shift)
        return x


class HyenaFilter(nn.Module):
    def __init__(
        self,
        config,
        **kwargs
    ):
        """
        Implicit long filter with modulation.

        Args:
            d_model: number of channels in the input
            emb_dim: dimension of the positional encoding; (`emb_dim` - 1) // 2 is the number of bands
            order: width of the FFN
            num_inner_mlps: number of inner linear layers inside filter MLP

        Note:
            filter_dropout is not implemented
        """
        super().__init__()

        self.d_model = config.d_model * (config.hyena_order - 1)
        self.use_bias = config.use_bias
        self.bias = nn.Parameter(torch.randn(self.d_model))
        self.dropout = nn.Dropout(config.hyena_filter_dropout)

        act = HyenaSin(config)
        self.emb_dim = config.emb_dim
        assert self.emb_dim % 2 != 0 and self.emb_dim >= 3, "emb_dim must be odd and greater than or equal to 3 (time, sine and cosine)"
        self.seq_len = config.max_seq_len

        self.pos_emb = HyenaPositionalEmbedding(config)

        self.implicit_filter = nn.Sequential(
            nn.Linear(self.emb_dim, config.filter_order),
            act,
        )
        for i in range(config.num_inner_mlps):
            self.implicit_filter.append(nn.Linear(config.filter_order, config.filter_order))
            self.implicit_filter.append(act)

        self.implicit_filter.append(nn.Linear(config.filter_order, config.d_model, bias=False))

        self.modulation = HyenaExponentialModulation(config.d_model)

        self.normalized = False

    def filter(self, L, *args, **kwargs):
        z, t = self.pos_emb(L)
        h = self.implicit_filter(z.to(dtype=self.implicit_filter[0].weight.dtype))
        h = self.modulation(t, h)
        return h

    def forward(self, x, L, k=None, bias=None, *args, **kwargs):
        if k is None:
            k = self.filter(L)

        # Ensure compatibility with filters that return a tuple
        k = k[0] if type(k) is tuple else k

        y = fftconv(x, k, bias)
        return y


class HyenaOperator(nn.Module):
    def __init__(
        self,
        config,
        **filter_args,
    ):
        r"""
        Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf

        Args:
            d_model (int): Dimension of the input and output embeddings (width of the layer)
            l_max (int): Maximum input sequence length. Defaults to None
            order (int): Depth of the Hyena recurrence. Defaults to 2
            dropout (float): Dropout probability. Defaults to 0.0
            filter_dropout (float): Dropout probability for the filter. Defaults to 0.0
        """
        super().__init__()

        self.d_model = config.d_model
        self.l_max = config.max_seq_len
        self.order = config.hyena_order
        inner_width = config.d_model * (self.order + 1)
        self.dropout = nn.Dropout(config.hyena_dropout)
        self.in_proj = nn.Linear(self.d_model, inner_width)
        self.out_proj = nn.Linear(self.d_model, self.d_model)

        self.short_filter = nn.Conv1d(
            inner_width,
            inner_width,
            config.short_filter_order,
            padding=2,
            groups=inner_width
        )
        self.filter_fn = HyenaFilter(config)

    def forward(self, u):
        # u: (batch, seq_len, d_model)
        l = u.size(-2)
        l_filter = min(l, self.l_max)
        u = self.in_proj(u).transpose(1, 2)  # (batch, (order + 1) * d_model, seq_len)

        uc = self.short_filter(u)[..., :l_filter]
        # order projections x_i and one value v, each of shape (batch, d_model, l_filter)
        *x, v = uc.split(self.d_model, dim=1)

        k = self.filter_fn.filter(l_filter)[0]
        k = k.transpose(0, 1).reshape(self.order - 1, self.d_model, l_filter)
        bias = self.filter_fn.bias.reshape(self.order - 1, self.d_model)

        for o, x_i in enumerate(reversed(x[1:])):
            v = self.dropout(v * x_i)
            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])

        y = (v * x[0]).transpose(1, 2)  # back to (batch, seq_len, d_model)

        y = self.out_proj(y)
        return y

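# A minimal usage sketch for HyenaOperator, assuming a stand-in config object that
# exposes only the attributes read above (the real values come from HyenaConfig):
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(d_model=128, max_seq_len=1024, hyena_order=2,
#                           hyena_dropout=0.0, short_filter_order=3, emb_dim=5,
#                           filter_order=64, num_inner_mlps=2, hyena_filter_dropout=0.0,
#                           activation_freq=10, train_freq=True, use_bias=True)
#     op = HyenaOperator(cfg)
#     y = op(torch.randn(2, 256, 128))  # (batch, seq_len, d_model) -> same shape
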
class HyenaMlp(nn.Module):

    def __init__(self, config):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py
        """
        super().__init__()
        in_features = config.d_model
        hidden_features = config.d_inner
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, config.d_model)

    def forward(self, x):
        y = self.fc1(x)
        y = F.gelu(y, approximate="tanh")
        y = self.fc2(y)
        return y

class HyenaBlock(nn.Module):

    def __init__(self, config):
        """
        From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
        the hidden_states (output of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
        This is for performance reasons: for a post-norm architecture, returning the input allows us
        to fuse the backward of nn.Linear with the residual connection.
        """
        super().__init__()
        self.mixer = HyenaOperator(config)
        self.norm1 = nn.LayerNorm(config.d_model)
        self.mlp = HyenaMlp(config)
        self.norm2 = nn.LayerNorm(config.d_model)

    def forward(self, hidden_states):
        r"""Pass the input through the encoder layer.
        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None; if prenorm, hidden_states = Attn/MLP(LN(residual))
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        residual = hidden_states
        residual = residual.to(torch.float32)
        hyena_normed = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
        hidden_states = self.mixer(hyena_normed)
        residual = hidden_states + residual
        hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
        residual = residual.to(torch.float32)

        hidden_states = self.mlp(hidden_states)
        return hidden_states + residual


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454


class HyenaEmbeddings(nn.Module):

    def __init__(self, config, padding_idx=None):
        """
        If max_position_embeddings <= 0, there are no position embeddings.
        If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
        and then project up to embed_dim.
        """
        super().__init__()
        vocab_size = config.vocab_size
        if vocab_size % config.pad_vocab_size_multiple != 0:
            vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
        self.word_embeddings = nn.Embedding(vocab_size, config.d_model, padding_idx=padding_idx)

    def forward(self, input_ids):
        """
        input_ids: (batch, seqlen)
        """
        embeddings = self.word_embeddings(input_ids)
        return embeddings


class HyenaLMBackbone(nn.Module):

    def __init__(self, config) -> None:
        super().__init__()
        # note max_position_embeddings is 0 for Hyena, and therefore isn't used
        self.embeddings = HyenaEmbeddings(config)
        self.dropout = nn.Dropout(config.embed_dropout)

        self.layers = nn.ModuleList([HyenaBlock(config) for i in range(config.n_layer)])

        self.ln_f = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.gradient_checkpointing = False

    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
        all_hidden_states = []
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(input_ids)
        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
            if output_hidden_states:
                all_hidden_states.append(hidden_states)

        hidden_states = self.ln_f(hidden_states.to(dtype=self.ln_f.weight.dtype))
        if output_hidden_states:
            all_hidden_states.append(hidden_states)

        return hidden_states, all_hidden_states


class HyenaDNAPreTrainedModel(PreTrainedModel):
    config_class = HyenaConfig
    base_model_prefix = "hyena"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HyenaBlock"]
    _skip_keys_device_placement = "past_key_values"
    _keys_to_ignore_on_load_missing = [r"freq"]  # Shared tensors that safetensors merges

    def _init_weights(self, module, initializer_range=0.02):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in self.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))
            # If using GLU activation for now, we scale the std by 2
            elif name in ["output_linear.0.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * self.config.num_layers))


class HyenaDNAModel(HyenaDNAPreTrainedModel):
    def __init__(self, config, **kwargs) -> None:
        super().__init__(config, **kwargs)

        self.backbone = HyenaLMBackbone(config)
        self.config = config

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=None, return_dict=None):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states, all_hidden_states = self.backbone(input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states)
        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states if output_hidden_states else None,
            )
        elif output_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states


class HyenaDNAForCausalLM(HyenaDNAPreTrainedModel):

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.hyena = HyenaDNAModel(config)
        vocab_size = config.vocab_size
        if vocab_size % config.pad_vocab_size_multiple != 0:
            vocab_size += config.pad_vocab_size_multiple - (vocab_size % config.pad_vocab_size_multiple)
        self.vocab_size = vocab_size
        self.lm_head = nn.Linear(config.d_model, vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.hyena.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.hyena.backbone.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.hyena = decoder

    def get_decoder(self):
        return self.hyena

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutput]:

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.hyena(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )


class HyenaDNAForSequenceClassification(HyenaDNAPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.num_labels = kwargs.get("num_labels", config.num_labels)
        self.hyena = HyenaDNAModel(config)
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.hyena.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.hyena.backbone.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.hyena(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
                    logits.device
                )
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=pooled_logits,
            hidden_states=transformer_outputs.hidden_states,
        )
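
For reference, a checkpoint that ships this module as Hub custom code is typically loaded along these lines (the repository id below is a placeholder, and the sketch assumes the repo's config.json maps the auto classes to this file):

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "some-user/hyenadna-checkpoint"  # placeholder repository id, not an actual repo
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("ACGTACGT", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"])  # the forward above takes input_ids / inputs_embeds only
print(outputs.logits.shape)  # (batch, seq_len, padded vocab size)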