File size: 35,570 Bytes
7b526cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
import math
import typing

import einops
from functools import partial
import huggingface_hub
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
import transformers
from functools import lru_cache
from .config import EsoLMConfig

# Process-wide numeric/performance settings applied as a side effect of
# importing this module: allow TF32 for CUDA matmuls, request "high" float32
# matmul precision, and let cuDNN benchmark/select algorithms.
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
# Global TorchInductor (torch.compile backend) options; these affect every
# torch.compile call in the process, not just the ones in this file.
import torch._inductor.config as inductor_cfg
inductor_cfg.triton.cudagraphs = True
inductor_cfg.coordinate_descent_tuning = True

# Flags required to enable jit fusion kernels
# (private torch._C TorchScript executor switches; they are global and
# relevant to the @torch.jit.script helpers defined in this file).
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def _causal_mask(b, h, q_idx, kv_idx):
  """Flex-attention ``mask_mod``: query ``q_idx`` may attend to keys at or
  before its own position.

  Not memoized: the arguments are index tensors (hashed by identity), so a
  cache could never hit and would only keep those tensors alive.

  Args:
    b, h: Batch and head indices (unused).
    q_idx, kv_idx: Query and key positions.

  Returns:
    Boolean tensor, True where attention is allowed.
  """
  return q_idx >= kv_idx


@lru_cache
def _get_causal_mask(seq_len):
  """Return a causal flex-attention BlockMask for ``seq_len`` tokens.

  Memoized per ``seq_len`` (an int), so repeated calls with the same length
  reuse the mask instead of rebuilding it.
  """
  length = seq_len
  return create_block_mask(
    _causal_mask,
    B=None,
    H=None,
    Q_LEN=length,
    KV_LEN=length)


def _bidirectional_mask(b, h, q_idx, kv_idx):
  """Flex-attention ``mask_mod`` that allows every query to attend to every
  key.

  ``q_idx == q_idx`` is used to produce an all-True tensor with the index
  tensor's shape. Not memoized: the arguments are index tensors (hashed by
  identity), so a cache could never hit and would only retain them.

  Args:
    b, h: Batch and head indices (unused).
    q_idx, kv_idx: Query and key positions (``kv_idx`` unused).

  Returns:
    Boolean tensor of all True values.
  """
  return q_idx == q_idx


@lru_cache
def _get_bidirectional_mask(seq_len):
  """Return a full (all-visible) flex-attention BlockMask for ``seq_len``
  tokens, memoized per ``seq_len``."""
  length = seq_len
  return create_block_mask(
    _bidirectional_mask,
    B=None,
    H=None,
    Q_LEN=length,
    KV_LEN=length)


def _mixed_mask(b, h, q_idx, kv_idx, cutoffs):
  """Flex-attention ``mask_mod``: causal attention, except that queries at
  positions ``>= cutoffs[b]`` may attend to every key.

  Not memoized: the arguments include tensors (hashed by identity), so a
  cache could never hit and would only retain them.

  Args:
    b, h: Batch index (used to select the cutoff) and head index (unused).
    q_idx, kv_idx: Query and key positions.
    cutoffs: Tensor indexable by ``b`` giving the per-sample cutoff.

  Returns:
    Boolean tensor, True where attention is allowed.
  """
  causal = q_idx >= kv_idx
  past_cutoff = q_idx >= cutoffs[b]
  return causal | past_cutoff


@lru_cache
def _get_mixed_mask(seq_len, cutoffs):
  """Return a BlockMask for ``_mixed_mask`` with the given per-sample
  ``cutoffs``.

  NOTE(review): memoized on ``(seq_len, cutoffs)``; ``torch.Tensor`` hashes
  by object identity, so reusing the same tensor after mutating it in place
  returns the previously built (stale) mask.
  """
  mask_mod = partial(_mixed_mask, cutoffs=cutoffs)
  return create_block_mask(
    mask_mod,
    B=None,
    H=None,
    Q_LEN=seq_len,
    KV_LEN=seq_len)


def _mixed2_mask(b, h, q_idx, kv_idx, cutoffs):
  """Flex-attention ``mask_mod``: causal attention, plus full attention among
  the positions that are both ``< cutoffs[b]``.

  Not memoized: the arguments include tensors (hashed by identity), so a
  cache could never hit and would only retain them.

  Args:
    b, h: Batch index (used to select the cutoff) and head index (unused).
    q_idx, kv_idx: Query and key positions.
    cutoffs: Tensor indexable by ``b`` giving the per-sample cutoff.

  Returns:
    Boolean tensor, True where attention is allowed.
  """
  causal = q_idx >= kv_idx
  both_before_cutoff = (q_idx < cutoffs[b]) & (kv_idx < cutoffs[b])
  return causal | both_before_cutoff


@lru_cache
def _get_mixed2_mask(seq_len, cutoffs):
  """Return a BlockMask for ``_mixed2_mask`` with the given per-sample
  ``cutoffs``.

  NOTE(review): memoized on ``(seq_len, cutoffs)``; ``torch.Tensor`` hashes
  by object identity, so in-place edits to ``cutoffs`` are not seen by the
  cache.
  """
  mask_mod = partial(_mixed2_mask, cutoffs=cutoffs)
  return create_block_mask(
    mask_mod,
    B=None,
    H=None,
    Q_LEN=seq_len,
    KV_LEN=seq_len)


def _block_diff_mask(b, h, q_idx, kv_idx, block_size=1, n=None):
  """Block-diffusion attention mask over the concatenation ``[x_t, x_0]``.

  Copied directly from BD3LM's codebase: https://github.com/kuleshov-group/bd3lms

  Positions ``< n`` belong to the noised sequence x_t and positions ``>= n``
  to the clean sequence x_0. Each half is split into blocks of
  ``block_size`` tokens. The mask is the union of three masks:
  - **Block Diagonal Mask (M_BD)**: a token attends within its own block of
    the same half (x_t -> x_t or x_0 -> x_0).
  - **Offset Block-Causal Mask (M_OBC)**: an x_t token attends to x_0
    tokens in strictly earlier blocks (conditional context).
  - **Block-Causal Mask (M_BC)**: an x_0 token attends to x_0 tokens in
    the same or earlier blocks.

  Args:
    b, h: Batch and head indices (ignored for mask logic).
    q_idx, kv_idx: Query and key positions in ``[0, 2 * n)``.
    block_size: Number of tokens per block.
    n: Length of x_t, i.e. half of the total query/key length. Must be
      provided; the ``None`` default is not usable.

  Returns:
    A boolean attention mask (True where attention is allowed).
  """

  # Indicate whether token belongs to xt or x0
  x0_flag_q = (q_idx >= n)
  x0_flag_kv = (kv_idx >= n)

  # Compute block indices (relative to the start of each half)
  block_q = torch.where(x0_flag_q == 1,
                        (q_idx - n) // block_size,
                        q_idx // block_size)
  block_kv = torch.where(x0_flag_kv == 1,
                         (kv_idx - n) // block_size,
                         kv_idx // block_size)

  # **1. Block Diagonal Mask (M_BD) **
  block_diagonal = (
    block_q == block_kv) & (x0_flag_q == x0_flag_kv)

  # **2. Offset Block-Causal Mask (M_OBC) **
  offset_block_causal = ((block_q > block_kv)
                          & (x0_flag_kv == 1)
                          & (x0_flag_q == 0))

  # **3. Block-Causal Mask (M_BC) **
  block_causal = (block_q >= block_kv) & (
    x0_flag_kv == 1) & (x0_flag_q == 1)

  # **4. Combine Masks **
  return block_diagonal | offset_block_causal | block_causal


@lru_cache
def _get_seq_mask(seq_len):
  """Return the block-diffusion BlockMask (block size 1) over ``[z_t, x_0]``.

  ``seq_len`` is the length of z_t only; the mask covers ``2 * seq_len``
  queries and keys. Memoized per ``seq_len``.
  """
  total = seq_len * 2
  mask_mod = partial(_block_diff_mask, block_size=1, n=seq_len)
  return create_block_mask(
    mask_mod,
    B=None,
    H=None,
    Q_LEN=total,
    KV_LEN=total)


def _block_diff_mask_prefix_lm(b, h, q_idx, kv_idx, n, cutoffs):
  """Block-diffusion mask (block size 1) plus full attention inside the
  x_0 prefix ``[n, n + cutoffs[b])``.

  Args:
    b, h: Batch index (selects the cutoff) and head index.
    q_idx, kv_idx: Query and key positions in ``[0, 2 * n)``.
    n: Length of the first (x_t) half.
    cutoffs: Tensor indexable by ``b`` giving the per-sample prefix length.
  """
  prefix_end = n + cutoffs[b]
  q_in_prefix = (n <= q_idx) & (q_idx < prefix_end)
  kv_in_prefix = (n <= kv_idx) & (kv_idx < prefix_end)
  base_mask = _block_diff_mask(b, h, q_idx, kv_idx, block_size=1, n=n)
  return base_mask | (q_in_prefix & kv_in_prefix)


@lru_cache
def _get_seq_mask_prefix_lm(seq_len, cutoffs):
  """Return the prefix-LM block-diffusion BlockMask over ``[z_t, x_0]``.

  ``seq_len`` is the length of z_t only; the mask covers ``2 * seq_len``
  positions. NOTE(review): memoized on ``(seq_len, cutoffs)`` and
  ``torch.Tensor`` hashes by identity, so in-place edits to ``cutoffs`` are
  not seen by the cache.
  """
  total = seq_len * 2
  mask_mod = partial(_block_diff_mask_prefix_lm, n=seq_len, cutoffs=cutoffs)
  return create_block_mask(
    mask_mod,
    B=None,
    H=None,
    Q_LEN=total,
    KV_LEN=total)


# Single module-level compiled flex_attention shared by all attention calls.
# dynamic=False: compile for static shapes (new shapes trigger recompiles);
# fullgraph=True: fail instead of graph-breaking; mode='reduce-overhead'
# requests CUDA-graph capture. The commented lines are alternative
# configurations kept for experimentation/debugging.
flex_attention_compiled = torch.compile(flex_attention, dynamic=False, fullgraph=True, mode='reduce-overhead')
# flex_attention_compiled = torch.compile(flex_attention, dynamic=False, fullgraph=True, mode='max-autotune-no-cudagraphs')
# flex_attention_compiled = flex_attention
# flex_attention_compiled = torch.compile(flex_attention, dynamic=True)


def fused_flex_attention(q, k, v, mask=None):
  """Run the module-level compiled flex_attention with an optional
  BlockMask (``None`` means no masking)."""
  attend = flex_attention_compiled
  return attend(q, k, v, block_mask=mask)


def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool) -> torch.Tensor:
  """Compute ``residual + scale * dropout(x + bias)``.

  ``bias`` and ``residual`` are optional and skipped when ``None``.
  Dropout is active only when ``training`` is true.
  """
  if bias is not None:
    x = x + bias
  out = scale * F.dropout(x, p=prob, training=training)
  if residual is not None:
    out = residual + out
  return out


def get_bias_dropout_add_scale(training):
  """Return a 5-argument ``(x, bias, scale, residual, prob)`` version of
  ``bias_dropout_add_scale`` with ``training`` bound."""
  def _apply(x, bias, scale, residual, prob):
    return bias_dropout_add_scale(
      x, bias, scale, residual, prob, training=training)
  return _apply


# Plain-PyTorch adaLN modulation; `modulate_fused` wraps it with TorchScript.
def modulate(x: torch.Tensor,
             shift: torch.Tensor,
             scale: torch.Tensor) -> torch.Tensor:
  """adaLN modulation: scale ``x`` by ``1 + scale``, then add ``shift``."""
  gain = 1 + scale
  return x * gain + shift


@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
  """TorchScript-compiled ``bias_dropout_add_scale`` with dropout enabled."""
  return bias_dropout_add_scale(
    x, bias, scale, residual, prob, training=True)


@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
  """TorchScript-compiled ``bias_dropout_add_scale`` with dropout disabled."""
  return bias_dropout_add_scale(
    x, bias, scale, residual, prob, training=False)


@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
  """TorchScript-compiled ``modulate``: ``x * (1 + scale) + shift``."""
  return modulate(x, shift=shift, scale=scale)


class Rotary(torch.nn.Module):
  """Builds and caches rotary position embedding (cos, sin) tables.

  ``forward`` returns ``(cos, sin)`` shaped
  ``(1, seq_len, 3, 1, 2 * len(inv_freq))``, i.e. (batch, seq_len, qkv,
  head, dim). Slot 2 of the qkv axis is set to cos=1 / sin=0 so the
  rotation is the identity on ``v``. Tables are cached and rebuilt when the
  sequence length or the input's device changes.
  """
  def __init__(self, dim, base=10_000):
    super().__init__()
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    self.register_buffer('inv_freq', inv_freq)
    self.seq_len_cached = None
    self.cos_cached = None
    self.sin_cached = None

  def forward(self, x, seq_dim=1):
    """Return the (cos, sin) tables for ``x.shape[seq_dim]`` positions."""
    seq_len = x.shape[seq_dim]
    # Rebuild on a device change too (e.g. after ``module.to(...)``);
    # keying on seq_len alone returned tables living on the old device.
    if (seq_len != self.seq_len_cached
        or self.cos_cached is None
        or self.cos_cached.device != x.device):
      self.seq_len_cached = seq_len
      inv_freq = self.inv_freq.to(x.device)
      t = torch.arange(seq_len, device=x.device).type_as(inv_freq)
      freqs = torch.einsum("i,j->ij", t, inv_freq)
      emb = torch.cat((freqs, freqs), dim=-1)
      # dims are: batch, seq_len, qkv, head, dim
      self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
      self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
      # This makes the transformation on v an identity.
      self.cos_cached[:, :, 2, :, :].fill_(1.)
      self.sin_cached[:, :, 2, :, :].fill_(0.)

    return self.cos_cached, self.sin_cached


def rotate_half(x, interleaved=False):
  """Copied and refactored from FlashAttention.

  Map feature pairs ``(a, b)`` to ``(-b, a)`` along the last dim. Pairs are
  ``(x[..., :d/2], x[..., d/2:])`` by default, or adjacent elements
  ``(x[..., 0::2], x[..., 1::2])`` when ``interleaved`` is true.
  """
  if not interleaved:
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
  evens = x[..., ::2]
  odds = x[..., 1::2]
  # Stack to (..., d/2, 2) then merge the last two dims back to (..., d).
  return torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
  """
  Copied and refactored from FlashAttention.

  Apply rotary embeddings to the first ``2 * cos.shape[-1]`` features of
  ``x``; remaining features pass through unchanged.
  x: (batch_size, seq_len, nheads, headdim)
  cos, sin: (seq_len, rotary_dim / 2) or (batch_size, seq_len, rotary_dim / 2)
  """
  half_dim = cos.shape[-1]
  ro_dim = half_dim * 2
  assert ro_dim <= x.shape[-1]
  if interleaved:
    # Each frequency is used for two adjacent features: [c0, c0, c1, c1, ...]
    cos_full = cos.repeat_interleave(2, dim=-1)
    sin_full = sin.repeat_interleave(2, dim=-1)
  else:
    # Frequencies tiled over the two halves: [c0, c1, ..., c0, c1, ...]
    cos_full = torch.cat((cos, cos), dim=-1)
    sin_full = torch.cat((sin, sin), dim=-1)
  # Insert a singleton head axis so the tables broadcast over nheads.
  cos_full = cos_full.unsqueeze(-2)
  sin_full = sin_full.unsqueeze(-2)
  rotated = x[..., :ro_dim]
  passthrough = x[..., ro_dim:]
  mixed = rotated * cos_full + rotate_half(rotated, interleaved) * sin_full
  return torch.cat([mixed, passthrough], dim=-1)


def _split_rotary(rotary_cos_sin, dtype):
  """Cast the (cos, sin) tables to ``dtype`` and keep, for batch entry 0
  and the q slot, the first half of the last dim.

  Returns a ``(cos, sin)`` pair of 2-D tensors ``(seq_len, dim // 2)``.
  """
  def _first_half(table):
    table = table.to(dtype)
    return table[0, :, 0, 0, :table.shape[-1] // 2]

  cos_table, sin_table = rotary_cos_sin
  return _first_half(cos_table), _first_half(sin_table)


def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
  """Split packed ``qkv`` (stacked on dim 2) into q, k, v and rotate q and k.

  Uses a single rotary table shared by the whole batch (batch entry 0).
  Runs with CUDA autocast disabled so the rotation happens in ``qkv.dtype``.
  """
  with torch.amp.autocast('cuda', enabled=False):
    cos, sin = _split_rotary(rotary_cos_sin, dtype=qkv.dtype)
    q, k, v = (part.squeeze(dim=2) for part in qkv.chunk(3, dim=2))
    q = apply_rotary_emb_torch(q, cos, sin)
    k = apply_rotary_emb_torch(k, cos, sin)
  return q, k, v


def split_and_apply_rotary_pos_emb_batch(qkv, rotary_cos_sin):
  """Like ``split_and_apply_rotary_pos_emb`` but with a separate rotary
  table per batch entry (the batch axis of cos/sin is kept, not sliced to
  entry 0)."""
  with torch.amp.autocast('cuda', enabled=False):
    cos_table, sin_table = rotary_cos_sin
    cos_table = cos_table.to(qkv.dtype)
    sin_table = sin_table.to(qkv.dtype)
    # Keep the batch axis: (batch, seq_len, dim // 2).
    cos = cos_table[:, :, 0, 0, :cos_table.shape[-1] // 2]
    sin = sin_table[:, :, 0, 0, :sin_table.shape[-1] // 2]
    q, k, v = (part.squeeze(dim=2) for part in qkv.chunk(3, dim=2))
    q = apply_rotary_emb_torch(q, cos, sin)
    k = apply_rotary_emb_torch(k, cos, sin)
  return q, k, v


def flex_attention_multi_headed(q, k, v, mask):
  """Multi-head flex attention on (batch, seq, heads, head_dim) inputs.

  Transposes to the (batch, heads, seq, head_dim) layout flex_attention
  expects, then transposes back and merges heads into
  ``(batch, seq, heads * head_dim)``.
  """
  q, k, v = (t.transpose(1, 2).contiguous() for t in (q, k, v))
  out = fused_flex_attention(q, k, v, mask=mask)
  out = out.transpose(1, 2).contiguous()
  return out.flatten(start_dim=2)

#################################################################################
#                                  Layers                                       #
#################################################################################
class LayerNorm(nn.Module):
  """Bias-free layer norm over the last dim with a learnable gain.

  Normalisation runs in float32 with CUDA autocast disabled; the gain is
  applied afterwards.
  """
  def __init__(self, dim):
    super().__init__()
    self.weight = nn.Parameter(torch.ones([dim]))
    self.dim = dim

  def forward(self, x):
    with torch.amp.autocast('cuda', enabled=False):
      normed = F.layer_norm(x.float(), [self.dim])
    gain = self.weight[None, None, :]
    return normed * gain


def residual_linear(x, W, x_skip, residual_scale):
  """x_skip + residual_scale * W @ x

  Computed as one ``torch.addmm`` over the flattened leading dims.
  ``W`` is ``(dim_out, dim_in)``; the result has ``x``'s leading shape with
  a last dim of ``dim_out``.
  """
  out_features, in_features = W.shape[:2]
  flat_x = x.view(-1, in_features)
  flat_skip = x_skip.view(-1, out_features)
  fused = torch.addmm(flat_skip, flat_x, W.T, alpha=residual_scale)
  return fused.view(*x.shape[:-1], out_features)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################
class TimestepEmbedder(nn.Module):
  """
  Embeds scalar timesteps into vector representations: sinusoidal features
  of size ``frequency_embedding_size`` followed by a two-layer SiLU MLP.
  """
  def __init__(self, hidden_size, frequency_embedding_size=256):
    super().__init__()
    self.mlp = nn.Sequential(
      nn.Linear(frequency_embedding_size, hidden_size, bias=True),
      nn.SiLU(),
      nn.Linear(hidden_size, hidden_size, bias=True))
    self.frequency_embedding_size = frequency_embedding_size

  @staticmethod
  def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) float32 Tensor laid out as [cos(...), sin(...)],
             zero-padded by one column when ``dim`` is odd.
    """
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    steps = torch.arange(
      start=0, end=half, dtype=torch.float32, device=t.device)
    freqs = torch.exp(-math.log(max_period) * steps / half)
    args = t[:, None].float() * freqs[None]
    parts = [torch.cos(args), torch.sin(args)]
    if dim % 2:
      parts.append(torch.zeros_like(args[:, :1]))
    return torch.cat(parts, dim=-1)

  def forward(self, t):
    return self.mlp(
      self.timestep_embedding(t, self.frequency_embedding_size))


class LabelEmbedder(nn.Module):
  """Embeds class labels into vector representations.

  The table has ``num_classes + 1`` rows, so label ``num_classes`` is also
  a valid index (presumably reserved as a "null" label for
  classifier-free guidance — TODO confirm against callers). No label
  dropout is performed here; any dropping/replacement of labels must be
  done by the caller before ``forward``.
  """
  def __init__(self, num_classes, cond_size):
    super().__init__()
    self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
    self.num_classes = num_classes

    # TODO think of initializing with 0.02 std deviation like in original DiT paper

  def forward(self, labels):
    """Look up embeddings for integer ``labels`` in ``[0, num_classes]``."""
    embeddings = self.embedding_table(labels)
    return embeddings
    

#################################################################################
#                                 Core Model                                    #
#################################################################################

class DDiTBlockCausal(nn.Module):
  """Pre-LayerNorm transformer block with causal self-attention (no adaLN).

  ``forward(kv_cache=True)`` is an incremental decoding path that processes
  one new token per call. It keeps the un-rotated keys/values of all tokens
  seen so far in ``past_k`` / ``past_v``.
  """
  def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1):
    super().__init__()
    self.n_heads = n_heads

    self.dim = dim
    self.norm1 = LayerNorm(dim)
    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
    self.attn_out = nn.Linear(dim, dim, bias=False)
    self.dropout1 = nn.Dropout(dropout)

    self.norm2 = LayerNorm(dim)
    self.mlp = nn.Sequential(
      nn.Linear(dim, mlp_ratio * dim, bias=True),
      nn.GELU(approximate='tanh'),
      nn.Linear(mlp_ratio * dim, dim, bias=True))
    self.dropout2 = nn.Dropout(dropout)
    self.dropout = dropout

    self.past_k = None
    self.past_v = None

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def reset_kv_cache(self):
    self.past_k = None
    self.past_v = None

  def _process_and_update_kv(self, k, v):
    """Append the new k/v to the cache and return the full cached k/v."""
    if (self.past_k is not None
        and self.past_v is not None):
      k = torch.cat([self.past_k, k], dim=1)
      v = torch.cat([self.past_v, v], dim=1)
    self.past_k = k
    self.past_v = v
    return k, v

  @torch.no_grad()
  def _attention_with_kv_cache(self, qkv, rotary_cos_sin):
    """Attention for one new token against every cached token.

    Args:
      qkv: Packed projections for a single position,
        ``(batch, 1, 3, heads, head_dim)``.
      rotary_cos_sin: Rotary (cos, sin) tables. They are assumed to cover
        exactly the cached positions plus the new one — TODO confirm.

    Returns:
      ``(batch, 1, dim)`` attention output.
    """
    assert qkv.shape[1] == 1
    q, k, v = qkv.chunk(3, dim=2)
    # The cache holds un-rotated keys; rotary is re-applied to all of them.
    k, v = self._process_and_update_kv(k=k, v=v)
    with torch.amp.autocast('cuda', enabled=False):
      cos, sin = _split_rotary(rotary_cos_sin, q.dtype)
      # The query sits at the last position.
      q = apply_rotary_emb_torch(
        q.squeeze(dim=2), cos[-1:, :], sin[-1:, :])
      k = apply_rotary_emb_torch(k.squeeze(dim=2), cos, sin)
      v = v.squeeze(dim=2)
    scale = q.shape[-1] ** 0.5
    # swap seq_len and num_heads
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    # No causal mask needed: the single query is the newest token.
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    attn_weights = F.softmax(attn_scores, dim=-1)
    x = torch.matmul(attn_weights, v).transpose(1, 2)
    return x.view(x.shape[0], 1, self.dim)

  def forward(self, x, rotary_cos_sin, kv_cache=False, **kwargs):
    """Apply the block.

    Args:
      x: ``(batch, seq, dim)`` hidden states (``seq == 1`` with kv_cache).
      rotary_cos_sin: Rotary (cos, sin) tables from ``Rotary``.
      kv_cache: Use the incremental single-token decoding path.
      **kwargs: Ignored (e.g. ``c``, which the causal block does not use).
    """
    del kwargs
    bias_dropout_scale_fn = self._get_bias_dropout_scale()
    x_skip = x
    x = self.norm1(x)
    qkv = einops.rearrange(
      self.attn_qkv(x),
      'b s (three h d) -> b s three h d',
      three=3,
      h=self.n_heads)

    if kv_cache:
      # Fixed: rotary_cos_sin was previously not forwarded, so this path
      # raised a TypeError for the missing argument.
      x = self._attention_with_kv_cache(qkv.detach(), rotary_cos_sin)
    else:
      q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
      # recreate the mask every time (cheap) to fit different input length
      # different input length can happen during generation
      attn_mask = _get_causal_mask(x.shape[1])
      x = flex_attention_multi_headed(q, k, v, attn_mask)

    scale = torch.ones(1, device=x.device, dtype=x.dtype)
    x = bias_dropout_scale_fn(
      self.attn_out(x), None, scale, x_skip, self.dropout)

    # mlp operation
    x = bias_dropout_scale_fn(
      self.mlp(self.norm2(x)), None, scale, x, self.dropout)
    return x


class DDiTBlock(nn.Module):
  """Transformer block with mask-controlled attention and optional adaLN.

  When ``adaLN`` is true, a zero-initialised linear map of the conditioning
  vector ``c`` produces shift/scale/gate terms for the attention and MLP
  sub-layers, so the block starts out as the identity map.
  ``forward(kv_cache=True)`` runs an incremental decoding path that caches
  keys/values of already-decoded tokens (see ``_attention_with_kv_cache``).
  """
  def __init__(self, dim, n_heads, adaLN,
               cond_dim=None, mlp_ratio=4,
               dropout=0.1):
    super().__init__()
    self.n_heads = n_heads
    self.dim = dim
    self.adaLN = adaLN

    self.norm1 = LayerNorm(dim)
    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
    self.attn_out = nn.Linear(dim, dim, bias=False)
    self.dropout1 = nn.Dropout(dropout)

    self.norm2 = LayerNorm(dim)
    self.mlp = nn.Sequential(
      nn.Linear(dim, mlp_ratio * dim, bias=True),
      nn.GELU(approximate='tanh'),
      nn.Linear(mlp_ratio * dim, dim, bias=True))
    self.dropout2 = nn.Dropout(dropout)
    self.dropout = dropout

    if self.adaLN:
      self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim)
      self.adaLN_modulation.weight.data.zero_()
      self.adaLN_modulation.bias.data.zero_()

    self.past_k = None
    self.past_v = None
    # Large negative additive bias used to mask scores in the kv-cache path.
    self.neg_infinity = -1000000.0

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def reset_kv_cache(self):
    self.past_k = None
    self.past_v = None

  def _process_and_update_kv(self, k, v, num_clean):
    """Return keys/values for attention and update the cache.

    ``k``/``v`` hold the current tokens: the first ``num_clean`` are the
    tokens decoded in the previous iteration, the rest are still masked.
    Only the ``num_clean`` clean entries are appended to the cache. The
    returned k/v are the cached entries followed by all current tokens.
    """
    if num_clean == 0:
      # Nothing new to cache when all current tokens are masks, but the
      # previously cached keys/values must still be attended to.
      if (self.past_k is None
          and self.past_v is None):
        return k, v
      return (torch.cat([self.past_k, k], dim=1),
              torch.cat([self.past_v, v], dim=1))
    if (self.past_k is None
        and self.past_v is None):
      self.past_k = k[:, :num_clean, :, :]
      self.past_v = v[:, :num_clean, :, :]
      return k, v
    k_so_far = torch.cat([self.past_k, k], dim=1)
    v_so_far = torch.cat([self.past_v, v], dim=1)
    # only update the kv cache with kv values from
    # clean tokens generated during the previous
    # iteration
    self.past_k = torch.cat(
      [self.past_k, k[:, :num_clean, :, :]], dim=1)
    self.past_v = torch.cat(
      [self.past_v, v[:, :num_clean, :, :]], dim=1)
    return k_so_far, v_so_far

  @torch.no_grad()
  def _attention_with_kv_cache(self, qkv, rotary_cos_sin,
                               num_clean, num_clean_and_mask):
    """Attention for the incremental decoding path.

    Args:
      qkv: ``(bs, num_clean_and_mask, 3, h, d)`` projections of the current
        tokens (tokens decoded last iteration, then tokens to decode).
      rotary_cos_sin: Rotary (cos, sin) tables. They are assumed to cover
        every cached key plus the current tokens, with the current tokens
        at the end — TODO confirm against callers.
      num_clean: number of tokens decoded in the previous iteration.
      num_clean_and_mask: ``num_clean`` + number of tokens to decode.

    Returns:
      ``(bs, num_clean_and_mask, h * d)`` attention output.
    """
    # num_clean: num gen last
    # num_clean_and_mask: num gen last + num to gen
    assert qkv.shape[1] == num_clean_and_mask
    # qkv shape:
    # [bs, num gen last + num to gen, 3, h, d]
    q, k, v = qkv.chunk(3, dim=2)
    q = q.squeeze(dim=2)
    k = k.squeeze(dim=2)
    v = v.squeeze(dim=2)
    k, v = self._process_and_update_kv(
      k=k, v=v, num_clean=num_clean)
    # new kv shape:
    # [bs,
    #  num gen before last + num gen last + num to gen,
    #  h, d]
    with torch.amp.autocast('cuda', enabled=False):
      cos, sin = rotary_cos_sin
      cos = cos.to(qkv.dtype)
      sin = sin.to(qkv.dtype)
      cos = cos[:,:,0,0,:cos.shape[-1]//2]
      sin = sin[:,:,0,0,:sin.shape[-1]//2]
      # Queries are the last num_clean_and_mask positions.
      cos_part = cos[:, -num_clean_and_mask:]
      sin_part = sin[:, -num_clean_and_mask:]
      q = apply_rotary_emb_torch(q, cos_part, sin_part)
      k = apply_rotary_emb_torch(k, cos, sin)
    scale = q.shape[-1] ** 0.5
    # shapes after transpose:
    # q: [bs, h, num gen last + num to gen, d]
    # k: [bs, h, num gen before last + num gen last + num to gen, d]
    # v: [bs, h, num gen before last + num gen last + num to gen, d]
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    # attn_scores shape:
    # [bs, h,
    #  num gen last + num to gen,
    #  num gen before last + num gen last + num to gen]
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    # Allocate directly on the target device instead of CPU + copy.
    ones = torch.ones(
      num_clean_and_mask, num_clean_and_mask, device=qkv.device)
    # A contains very large negative values above the diagonal
    # - q attends to all v values over "num gen before last"
    # - q attends causally to v values within "num gen last
    #   + num to gen"
    A = self.neg_infinity * torch.triu(ones, diagonal=1)
    A = A.view(1, 1, num_clean_and_mask, num_clean_and_mask)
    attn_scores[:, :, :, -num_clean_and_mask:] += A
    attn_weights = F.softmax(attn_scores, dim=-1)
    # matmul shape: [bs, h, num gen last + num to gen, d]
    # shape after transpose: [bs, num gen last + num to gen, h, d]
    attn_output = torch.matmul(attn_weights, v).transpose(1, 2)
    return einops.rearrange(attn_output, 'b s h d -> b s (h d)')

  def forward(self, x, rotary_cos_sin, c=None, attn_mask=None,
              kv_cache=False, num_clean=None, num_clean_and_mask=None):
    """Apply the block.

    Args:
      x: ``(batch, seq, dim)`` hidden states.
      rotary_cos_sin: Rotary (cos, sin) tables; a leading dim > 1 selects
        per-sample tables.
      c: Conditioning vector; required when ``adaLN`` is true.
      attn_mask: flex-attention BlockMask (``None`` = no masking).
      kv_cache, num_clean, num_clean_and_mask: incremental-decoding
        controls, see ``_attention_with_kv_cache``.
    """
    bias_dropout_scale_fn = self._get_bias_dropout_scale()

    x_skip = x
    x = self.norm1(x)
    if self.adaLN:
      # e.g. self.adaLN_modulation(c): (128, 1536)
      # self.adaLN_modulation(c)[:, None]: (128, 1, 1536)
      # .chunk(6, dim=2) returns 6 tensors of shape (128, 1, 256)
      (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
       gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
      x = modulate_fused(x, shift_msa, scale_msa)

    qkv = einops.rearrange(
      self.attn_qkv(x),
      'b s (three h d) -> b s three h d',
      three=3,
      h=self.n_heads).contiguous()
    if kv_cache:
      x = self._attention_with_kv_cache(
        qkv.detach(), rotary_cos_sin,
        num_clean=num_clean, num_clean_and_mask=num_clean_and_mask)
    else:
      if rotary_cos_sin[0].shape[0] > 1:
        q, k, v = split_and_apply_rotary_pos_emb_batch(qkv, rotary_cos_sin)
      else:
        q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
      x = flex_attention_multi_headed(q, k, v, attn_mask)

    if self.adaLN:
      x = bias_dropout_scale_fn(self.attn_out(x),
                                None,
                                gate_msa,
                                x_skip,
                                self.dropout)
      x = bias_dropout_scale_fn(
        self.mlp(modulate_fused(
          self.norm2(x), shift_mlp, scale_mlp)),
        None, gate_mlp, x, self.dropout)
    else:
      scale = torch.ones(1, device=x.device, dtype=x.dtype)
      x = bias_dropout_scale_fn(
        self.attn_out(x), None, scale, x_skip, self.dropout)
      x = bias_dropout_scale_fn(
        self.mlp(self.norm2(x)), None, scale, x, self.dropout)
    return x


class EmbeddingLayer(nn.Module):
  """Token embedding that accepts hard ids or soft logits.

  A 2-D integer input indexes the ``(vocab_dim, dim)`` table directly. A
  3-D input is treated as logits over the vocabulary. Its softmax mixes the
  embedding rows in float32, and the result is cast back to the input dtype.
  """
  def __init__(self, dim, vocab_dim):
    super().__init__()
    self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
    torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

  def forward(self, x):
    if x.ndim != 2:
      assert x.ndim == 3
      probs = torch.nn.functional.softmax(x, dim=-1).float()
      mixed = torch.einsum("blv,ve->ble", probs, self.embedding.float())
      return mixed.to(x.dtype)
    return self.embedding[x]


class DDiTFinalLayer(nn.Module):
  """Output head: final LayerNorm, optional adaLN shift/scale, then a linear
  projection to ``out_channels``.

  The projection and (if present) the adaLN map are zero-initialised, so
  the head initially outputs zeros.
  """
  def __init__(self, hidden_size, out_channels, cond_dim,
               adaLN):
    super().__init__()
    self.norm_final = LayerNorm(hidden_size)
    self.linear = nn.Linear(hidden_size, out_channels)
    for param in (self.linear.weight, self.linear.bias):
      nn.init.zeros_(param)
    self.adaLN = adaLN
    if adaLN:
      self.adaLN_modulation = nn.Linear(cond_dim,
                                        2 * hidden_size,
                                        bias=True)
      for param in (self.adaLN_modulation.weight,
                    self.adaLN_modulation.bias):
        nn.init.zeros_(param)

  def forward(self, x, c):
    hidden = self.norm_final(x)
    if self.adaLN:
      shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
      hidden = modulate_fused(hidden, shift, scale)
    return self.linear(hidden)


class DiT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
  """Diffusion transformer backbone over discrete tokens.

  If ``config.algo.causal_attention`` is true, it is a plain causal model:
  DDiTBlockCausal blocks and no timestep conditioning. Otherwise it uses
  DDiTBlock blocks with adaLN conditioning on an embedding of ``sigma``.
  """
  def __init__(self, config, vocab_size: int):
    super().__init__()
    # Accept plain dicts (including dict subclasses) as well as OmegaConf.
    if isinstance(config, dict):
      config = omegaconf.OmegaConf.create(config)
    self.causal = config.algo.causal_attention
    self.adaLN = not self.causal
    self.config = config
    self.vocab_size = vocab_size
    dim = config.model.hidden_size
    cond_dim = config.model.cond_dim
    self.vocab_embed = EmbeddingLayer(dim, vocab_size)
    if not self.causal:
      self.sigma_map = TimestepEmbedder(cond_dim)
    self.rotary_dim = dim // config.model.n_heads
    self.rotary_emb = Rotary(self.rotary_dim)

    blocks = []
    for _ in range(config.model.n_blocks):
      if self.causal:
        block = DDiTBlockCausal(
          dim=dim,
          n_heads=config.model.n_heads,
          dropout=config.model.dropout)
      else:
        block = DDiTBlock(
          dim=dim,
          n_heads=config.model.n_heads,
          cond_dim=cond_dim,
          adaLN=self.adaLN,
          dropout=config.model.dropout)
      blocks.append(block)
    self.blocks = nn.ModuleList(blocks)

    self.output_layer = DDiTFinalLayer(
      hidden_size=dim,
      out_channels=vocab_size,
      cond_dim=cond_dim,
      adaLN=self.adaLN)
    self.scale_by_sigma = config.model.scale_by_sigma

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def reset_kv_cache(self):
    for block in self.blocks:
      block.reset_kv_cache()

  def forward(self, x, sigma, x0=None, kv_cache=False):
    """Map token ids (or logits) ``x`` to output logits.

    Args:
      x: Token ids ``(batch, seq)`` or logits ``(batch, seq, vocab)``.
      sigma: Noise level; embedded for conditioning when not causal.
      x0: Must be ``None`` (not supported by this backbone).
      kv_cache: Run only the newest position through the blocks, using
        their key/value caches.
    """
    assert x0 is None
    x = self.vocab_embed(x)
    t_cond = None if self.causal else F.silu(self.sigma_map(sigma))

    # Rotary tables are built for the full sequence, even with kv_cache.
    rotary_cos_sin = self.rotary_emb(x)
    if kv_cache:
      x = x[:, -1:, :]
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
      for block in self.blocks:
        x = block(x, rotary_cos_sin, c=t_cond, kv_cache=kv_cache)
      x = self.output_layer(x, c=t_cond)
    return x


def _get_reverse_indices(indices):
  """
  indices: LongTensor of shape [B, N] representing permutations
  returns: LongTensor of shape [B, N] representing the inverse permutations
  """
  B, N = indices.shape
  reverse_indices = torch.empty_like(indices)
  arange = torch.arange(N, device=indices.device).unsqueeze(0).expand(B, -1)
  reverse_indices.scatter_(1, indices, arange)
  return reverse_indices


class EsoLMDiT(DiT):
  """DiT backbone with a diffusion pass and a sequential pass.

  `forward(zt, sigma)` (x0 is None) runs the diffusion pass;
  `forward(zt, sigma, x0)` runs the sequential pass on `zt` concatenated
  with `x0`. Both passes sort each row so clean (non-mask) tokens come
  first and mask tokens last, permute the rotary tables so every token
  keeps the rotary embedding of its original position, and un-sort the
  outputs back to the original order.
  """

  def __init__(self, config, vocab_size: int, mask_index: int):
    super().__init__(config, vocab_size)
    # Sequential does not mean causal. Requiring the non-causal model also
    # guarantees that:
    # - sigma_map was created
    # - DDiTBlock was used instead of DDiTBlockCausal
    assert not self.causal and self.adaLN
    self.mask_index = mask_index

    # Read from self.config: DiT.__init__ has already converted a plain
    # dict config into an omegaconf object there.
    algo = self.config.algo
    self.diffusion_shuffle = algo.diffusion_shuffle
    self.diffusion_attn_mode = algo.diffusion_attn_mode
    self.sequential_shuffle = algo.sequential_shuffle
    self.sequential_attn_mode = algo.sequential_attn_mode

    # Lazily built attention masks. Each cache records what it was built
    # for, so a different mode or sequence length triggers a rebuild
    # instead of returning a stale mask.
    self.mdlm_mask = None
    self._mdlm_mask_key = None
    self.seq_mask = None
    self._seq_mask_len = None

  def _sort_indices(
    self, indices, shuffle, keep_masks_unshuffled=False):
    """Reorder each row so clean tokens come first and mask tokens last.

    The sort key is `is_mask + offset`. Clean tokens get offsets below
    0.9 and mask tokens get keys >= 1, so an ascending sort puts every
    clean token before every mask token. Within each group, order follows
    the offsets:
    - shuffle=False: offsets increase with position (original order).
    - shuffle=True: random offsets (random order), except that with
      `keep_masks_unshuffled` the mask tokens keep left-to-right order.

    Returns:
      (sorted indices, sort_idx), where sort_idx is the gather index
      that produced the sorted indices.
    """
    masked = (indices == self.mask_index)
    if shuffle:
      offsets = torch.rand(
        indices.shape).to(indices.device) * 0.9
      if keep_masks_unshuffled:
        # induce left-to-right order within masked tokens
        # only for sequential part
        offsets[masked] = torch.linspace(
          0, 1, torch.sum(masked)).to(indices.device)
    else:
      offsets = torch.linspace(
        0, 0.9, indices.shape[1]).to(indices.device)
    sort_idx = (masked + offsets).argsort(descending=False)
    indices = torch.gather(indices, dim=1, index=sort_idx)
    return indices, sort_idx

  def _sort_rotary_cos_sin(self, rotary_cos_sin, sort_idx):
    """Permute rotary cos/sin tables along the sequence axis by sort_idx.

    Sorted position i receives the rotary embedding of original position
    sort_idx[:, i], so tokens keep their positional encoding after sorting.
    """
    # example cos shape: (1, 128, 3, 1, 32)
    # 128 for seq_len, 3 for qkv, 32 for head dim
    cos, sin = rotary_cos_sin
    bs = sort_idx.shape[0]
    cos = cos.expand(bs, -1, -1, -1, -1)
    sin = sin.expand(bs, -1, -1, -1, -1)
    index = sort_idx[:, :, None, None, None].expand(
      -1, -1, 3, -1, self.rotary_dim)
    cos = torch.gather(cos, dim=1, index=index).contiguous()
    sin = torch.gather(sin, dim=1, index=index).contiguous()
    return cos, sin

  def _get_attention_mask(self, seq_len, attn_mode=None,
                          cutoffs=None):
    """Return the attention mask for the diffusion pass.

    'causal' and 'bidirectional' masks depend only on seq_len and are
    cached. 'mixed' (causal over clean tokens, bidirectional over masked
    tokens) and 'mixed2' (bidirectional over clean tokens, causal over
    masked tokens) depend on the per-sample `cutoffs` and are rebuilt on
    every call.
    """
    if attn_mode in ('causal', 'bidirectional'):
      # Both static masks share one cache slot. Key it on (mode, length)
      # so switching modes or lengths never returns the wrong mask.
      key = (attn_mode, seq_len)
      if self.mdlm_mask is None or self._mdlm_mask_key != key:
        if attn_mode == 'causal':
          self.mdlm_mask = _get_causal_mask(seq_len)
        else:
          self.mdlm_mask = _get_bidirectional_mask(seq_len)
        self._mdlm_mask_key = key
      return self.mdlm_mask
    if attn_mode == 'mixed':
      return _get_mixed_mask(seq_len=seq_len,
                             cutoffs=cutoffs)
    if attn_mode == 'mixed2':
      return _get_mixed2_mask(seq_len=seq_len,
                              cutoffs=cutoffs)
    # NOTE(review): unknown modes (including None) return None, as
    # before; presumably the blocks then attend without a mask. Confirm.
    return None

  def _diffusion_features(self, zt, sort_idx=None,
                          attn_mode=None, cutoffs=None):
    """Embed zt for the diffusion pass.

    Masked diffusion: clean (unmasked) tokens are moved to the left and
    mask tokens to the right. If `sort_idx` is given, zt is assumed to be
    sorted already.
    `cutoffs` defaults to the number of clean tokens per row.
    `attn_mode` defaults to `self.diffusion_attn_mode`.

    Returns:
      dict with 'x' (embeddings), 'rotary' (sorted cos/sin),
      'attention' (mask) and 'sorted_indices' (sort_idx).
    """
    if cutoffs is None:
      cutoffs = torch.sum(zt != self.mask_index, dim=1)
    if attn_mode is None:
      attn_mode = self.diffusion_attn_mode
    if sort_idx is None:
      zt, sort_idx = self._sort_indices(
        zt, self.diffusion_shuffle)
    x = self.vocab_embed(zt)
    rotary_cos_sin = self.rotary_emb(x)
    rotary_cos_sin = self._sort_rotary_cos_sin(
      rotary_cos_sin, sort_idx)
    attention_mask = self._get_attention_mask(
      seq_len=zt.shape[1],
      attn_mode=attn_mode,
      cutoffs=cutoffs)
    return {'x': x,
            'rotary': rotary_cos_sin,
            'attention': attention_mask,
            'sorted_indices': sort_idx}

  def _sequential_features(self, zt, x0):
    """Embed [sorted zt ; x0 sorted the same way] for the sequential pass.

    Gap-filling AR with the trick from BD3LM. As in the diffusion pass,
    clean tokens move to the left and mask tokens to the right (mask
    tokens keep their left-to-right order). Both halves of the
    concatenation reuse the same sorted rotary tables.

    Raises:
      ValueError: if `self.sequential_attn_mode` is not 'causal' or
        'mixed'.
    """
    seq_len = zt.shape[1]
    zt, sort_idx = self._sort_indices(
      zt, self.sequential_shuffle,
      keep_masks_unshuffled=True)
    x0 = torch.gather(x0, dim=1, index=sort_idx)
    zt_and_x0 = torch.cat([zt, x0], dim=1)
    cutoffs = torch.sum(zt != self.mask_index, dim=1)
    x = self.vocab_embed(zt_and_x0)
    rotary_cos_sin = self.rotary_emb(x[:, :seq_len])
    cos, sin = self._sort_rotary_cos_sin(
      rotary_cos_sin, sort_idx)
    rotary_cos_sin = (torch.cat([cos, cos], dim=1),
                      torch.cat([sin, sin], dim=1))

    if self.sequential_attn_mode == 'causal':
      # Cache keyed on seq_len so a new length rebuilds the mask.
      if self.seq_mask is None or self._seq_mask_len != seq_len:
        self.seq_mask = _get_seq_mask(seq_len)
        self._seq_mask_len = seq_len
      attention_mask = self.seq_mask
    elif self.sequential_attn_mode == 'mixed':
      attention_mask = _get_seq_mask_prefix_lm(
        seq_len, cutoffs=cutoffs)
    else:
      raise ValueError(
        f'Unknown sequential_attn_mode: {self.sequential_attn_mode!r}')
    return {'x': x,
            'rotary': rotary_cos_sin,
            'attention': attention_mask,
            'sorted_indices': sort_idx}

  def forward(self, zt, sigma, x0=None):
    """Run the diffusion pass (x0 is None) or the sequential pass.

    Returns:
      Tensor of shape (batch, seq_len, vocab_size) in the ORIGINAL token
      order of zt (the internal sort is undone).
    """
    diffusion_mode = x0 is None
    seq_len = zt.shape[1]

    if diffusion_mode:
      features = self._diffusion_features(zt)
    else:
      features = self._sequential_features(zt, x0)
    x = features['x']
    t_cond = F.silu(self.sigma_map(sigma))
    with torch.amp.autocast('cuda', enabled=False):
      for block in self.blocks:
        x = block(x, features['rotary'], c=t_cond,
                  attn_mask=features['attention'])
      x = self.output_layer(x, c=t_cond)

    if not diffusion_mode:
      # Keep only the zt half of the [zt ; x0] concatenation.
      x = x[:, :seq_len]
    sort_idx_reversed = _get_reverse_indices(features['sorted_indices'])
    x = torch.gather(
      x, dim=1,
      index=sort_idx_reversed[:, :, None].expand(
        -1, -1, self.vocab_size))
    return x

  @torch.no_grad()
  def forward_sample(self, zt, sort_idx, attn_mode=None,
                     cutoffs=None, kv_cache=False,
                     last_k_start=None,
                     curr_k_start=None,
                     curr_k_end=None):
    """Run one diffusion sampling step on an already-sorted sequence.

    Args:
      zt: token ids, expected to be sorted as per `sort_idx` (clean
        tokens first, mask tokens after).
      sort_idx: permutation used to sort zt; used to permute the rotary
        embeddings so each token keeps its original position's rotary.
      attn_mode: attention mode forwarded to `_get_attention_mask`;
        required.
      cutoffs: optional number of clean tokens, a scalar or one value per
        sample; broadcast to shape (num_samples,). If None, it is computed
        from zt.
      kv_cache: when True, only the slice [last_k_start, curr_k_end) of
        the sorted embeddings is passed through the blocks:
        - zt and sort_idx still have shape (num_samples, model.length),
          because under a random ordering any position's rotary
          embedding may be needed;
        - the first `curr_k_start - last_k_start` tokens of the slice
          are the clean tokens generated last time. Their kv values are
          meant to be appended to the existing cache;
        - the remaining tokens up to curr_k_end are the mask tokens to
          generate now. No kv values should be built up for them.
      last_k_start, curr_k_start, curr_k_end: slice boundaries in sorted
        order; only used when kv_cache is True.

    Returns:
      Output of the final layer in sorted order. With kv_cache, only the
      positions after the clean tokens (the mask tokens being generated)
      are returned.
    """
    assert attn_mode is not None
    ones = torch.ones(zt.shape[0], device=zt.device)
    if cutoffs is not None:
      cutoffs = cutoffs * ones
      assert cutoffs.ndim == 1
    features = self._diffusion_features(
      zt=zt,
      sort_idx=sort_idx,
      attn_mode=attn_mode,
      cutoffs=cutoffs)
    zeros = torch.zeros(zt.shape[0], device=zt.device)
    t_cond = F.silu(self.sigma_map(zeros))

    x = features['x']
    rotary = features['rotary']
    if kv_cache:
      # expect x to be sorted
      x = x[:, last_k_start:curr_k_end, :]
      # rotary is already sorted here;
      # keep tables up to curr_k_end (looking ahead)
      cos, sin = rotary
      rotary = (cos[:, :curr_k_end], sin[:, :curr_k_end])
      num_clean = curr_k_start - last_k_start
      num_clean_and_mask = curr_k_end - last_k_start
    else:
      num_clean = None
      num_clean_and_mask = None

    with torch.amp.autocast('cuda', enabled=False):
      for block in self.blocks:
        x = block(
          x, rotary, c=t_cond,
          attn_mask=features['attention'],
          kv_cache=kv_cache,
          num_clean=num_clean,
          num_clean_and_mask=num_clean_and_mask)
      x = self.output_layer(x, c=t_cond)

    if kv_cache:
      x = x[:, num_clean:, :]
    return x


class EsoLMHFDiT(nn.Module):
  """Backbone module built from a flat (Hugging Face style) config.

  Registers the same submodule names as the DiT backbone (vocab_embed,
  sigma_map, rotary_emb, blocks, output_layer), always in non-causal adaLN
  form. Reads `hidden_size`, `vocab_size`, `cond_dim`, `n_heads`,
  `n_blocks` and `dropout` directly from `config`.
  """

  def __init__(self, config):
    super().__init__()
    hidden = config.hidden_size
    cond = config.cond_dim
    self.vocab_embed = EmbeddingLayer(hidden, config.vocab_size)
    self.sigma_map = TimestepEmbedder(cond)
    # Per-head dimension used for the rotary tables.
    self.rotary_dim = hidden // config.n_heads
    self.rotary_emb = Rotary(self.rotary_dim)

    self.blocks = nn.ModuleList([
      DDiTBlock(
        dim=hidden,
        n_heads=config.n_heads,
        cond_dim=cond,
        adaLN=True,
        dropout=config.dropout)
      for _ in range(config.n_blocks)])

    self.output_layer = DDiTFinalLayer(
      hidden_size=hidden,
      out_channels=config.vocab_size,
      cond_dim=cond,
      adaLN=True)

  def reset_kv_cache(self):
    """Call `reset_kv_cache()` on every DDiT block, in order."""
    for layer in self.blocks:
      layer.reset_kv_cache()


class EsoLM(transformers.PreTrainedModel):
  """HF-compatible model."""
  config_class = EsoLMConfig
  base_model_prefix = 'esolm'

  def __init__(self, config: EsoLMConfig):
    super().__init__(config)
    self.config = config
    self.backbone = EsoLMHFDiT(config)