zzy1123 committed on
Commit 71153bb · verified · 1 Parent(s): c9236e9

Upload DiffusionLlamaLM

README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
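Since the snippet above is still marked [More Information Needed], here is a minimal loading sketch (the repository id is a placeholder): `trust_remote_code=True` is required because the architecture is wired up via `auto_map` in `config.json`. Note that `modeling_diff_llama.py` imports `xformers` unconditionally and needs the `rotary_emb` (and, for `FusedRMSNorm`, `dropout_layer_norm`) compiled extensions at runtime.

```python
# Hedged sketch -- "<user>/<repo>" is a placeholder for the actual Hub repository id.
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<user>/<repo>"

# auto_map in config.json resolves these to the configuration_diff_llama /
# modeling_diff_llama modules shipped alongside the weights.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).eval()
```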
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "architectures": [
+     "DiffusionLlamaLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_diff_llama.DiffusionLlamaConfig",
+     "AutoModel": "modeling_diff_llama.DiffusionLlamaLM",
+     "AutoModelForCausalLM": "modeling_diff_llama.DiffusionLlamaLM"
+   },
+   "bias": false,
+   "block_size": 2048,
+   "condense_ratio": 1,
+   "dtype": "float32",
+   "eos_token_id": 2,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "mask_token_id": 32000,
+   "mlp_class": "LLaMAMLP",
+   "model_type": "diff_llama",
+   "n_embd": 1024,
+   "n_head": 16,
+   "n_layer": 20,
+   "n_query_groups": 16,
+   "name": "Diff_LLaMA_336M",
+   "norm_class": "FusedRMSNorm",
+   "norm_eps": 1e-05,
+   "pad_token_id": 0,
+   "padded_vocab_size": 32000,
+   "padding_multiple": 64,
+   "parallel_residual": false,
+   "rotary_percentage": 1.0,
+   "shared_attention_norm": false,
+   "transformers_version": "4.57.3",
+   "vocab_size": 32000
+ }
configuration_diff_llama.py ADDED
@@ -0,0 +1,77 @@
+ from typing import Literal, Optional
+
+ from transformers import PretrainedConfig
+
+
+ class DiffusionLlamaConfig(PretrainedConfig):
+     model_type = "diff_llama"
+
+     def __init__(
+         self,
+         block_size: int = 4096,
+         vocab_size: int = 50254,
+         padding_multiple: int = 512,
+         padded_vocab_size: Optional[int] = None,
+         n_layer: int = 16,
+         n_head: int = 32,
+         n_embd: int = 4096,
+         rotary_percentage: float = 0.25,
+         parallel_residual: bool = True,
+         bias: bool = True,
+         n_query_groups: Optional[int] = None,
+         shared_attention_norm: bool = False,
+         norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm",
+         norm_eps: float = 1e-5,
+         mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP",
+         intermediate_size: Optional[int] = None,
+         condense_ratio: int = 1,
+         initializer_range: float = 0.02,
+         **kwargs,
+     ):
+         self.block_size = block_size
+         self.vocab_size = vocab_size
+         self.padding_multiple = padding_multiple
+
+         # Logic from the original Config.__post_init__
+         # 1. Pad the vocabulary size up to a multiple of `padding_multiple`
+         if padded_vocab_size is None:
+             self.padded_vocab_size = self._find_multiple(vocab_size, padding_multiple)
+         else:
+             self.padded_vocab_size = padded_vocab_size
+
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.n_embd = n_embd
+         self.rotary_percentage = rotary_percentage
+         self.parallel_residual = parallel_residual
+         self.bias = bias
+
+         # 2. Default the number of query groups to plain multi-head attention
+         if n_query_groups is not None:
+             self.n_query_groups = n_query_groups
+         else:
+             self.n_query_groups = n_head
+
+         self.shared_attention_norm = shared_attention_norm
+         self.norm_class = norm_class
+         self.norm_eps = norm_eps
+         self.mlp_class = mlp_class
+
+         # 3. Default the MLP width to 4x the embedding size
+         # (LLaMA-style configs usually specify intermediate_size explicitly)
+         if intermediate_size is None:
+             self.intermediate_size = 4 * n_embd
+         else:
+             self.intermediate_size = intermediate_size
+
+         self.condense_ratio = condense_ratio
+         self.initializer_range = initializer_range
+
+         super().__init__(**kwargs)
+
+     @property
+     def head_size(self) -> int:
+         return self.n_embd // self.n_head
+
+     @staticmethod
+     def _find_multiple(n: int, k: int) -> int:
+         assert k > 0
+         if n % k == 0:
+             return n
+         return n + k - (n % k)
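As a cross-check of the derived fields, a small sketch under the exact settings in `config.json` above (assuming the file is importable from a flat local directory; the asserts reflect this checkpoint, not the class defaults):

```python
# Sketch: derived values for the Diff_LLaMA_336M settings in config.json.
from configuration_diff_llama import DiffusionLlamaConfig  # assumes a flat local layout

cfg = DiffusionLlamaConfig(
    block_size=2048, vocab_size=32000, padding_multiple=64,
    n_layer=20, n_head=16, n_embd=1024, n_query_groups=16,
    intermediate_size=4096, rotary_percentage=1.0,
    norm_class="FusedRMSNorm", mlp_class="LLaMAMLP",
    parallel_residual=False, bias=False,
)
assert cfg.head_size == 64               # n_embd // n_head = 1024 // 16
assert cfg.padded_vocab_size == 32000    # 32000 is already a multiple of 64
assert cfg.n_query_groups == cfg.n_head  # 16 groups -> standard multi-head attention
```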
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:671794256bef4dff670845aca8d38e5fa382931f8f96d40028b887ee01a116f8
+ size 1604509704
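The pointer's 1,604,509,704-byte size is consistent with the config: a back-of-the-envelope sketch (layer shapes read off `modeling_diff_llama.py` below) gives ~401M float32 parameters, i.e. ~1.604 GB of tensor data, leaving ~16 KB for the safetensors header.

```python
# Back-of-the-envelope parameter count for Diff_LLaMA_336M (all float32, bias=False).
n_embd, n_layer, inter, vocab = 1024, 20, 4096, 32000

wte     = (vocab + 1) * n_embd                   # embedding has padded_vocab_size + 1 rows (mask token)
lm_head = vocab * n_embd                         # untied output projection
attn    = n_embd * 3 * n_embd + n_embd * n_embd  # fused QKV (n_query_groups == n_head) + out proj
mlp     = 3 * n_embd * inter                     # SwiGLU: two input projections + one output
norms   = 2 * n_embd                             # norm_1 and norm_2 weights per block

total = wte + lm_head + n_layer * (attn + mlp + norms) + n_embd  # + final ln_f
print(total, total * 4)  # 401123328 params -> 1604493312 bytes (+ ~16 KB header)
```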
modeling_diff_llama.py ADDED
@@ -0,0 +1,463 @@
+ import math
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import init
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from einops import rearrange
+ from xformers.ops import SwiGLU
+
+ from .configuration_diff_llama import DiffusionLlamaConfig
+
+ # ===========================================================================
+ # IMPORTS & CHECKS
+ # ===========================================================================
+
+ try:
+     from lightning_utilities.core.imports import RequirementCache
+     FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")
+ except ImportError:
+     # Fallback if lightning_utilities is missing
+     FlashAttention2Available = False
+
+ # Import compiled extensions if available
+ try:
+     import rotary_emb
+ except ImportError:
+     rotary_emb = None
+
+ try:
+     import dropout_layer_norm
+ except ImportError:
+     dropout_layer_norm = None
+
+
+ # ===========================================================================
+ # PART 1: ROTARY EMBEDDING (Autograd Function for Training)
+ # ===========================================================================
+
+ class ApplyRotaryEmb(torch.autograd.Function):
+     @staticmethod
+     @torch.compiler.disable
+     def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
+         """
+         Full forward pass from fused_rotary_embedding.py
+         """
+         batch, seqlen, nheads, headdim = x.shape
+         rotary_seqlen, rotary_dim = cos.shape
+         rotary_dim *= 2
+         assert rotary_dim <= headdim
+         assert seqlen <= rotary_seqlen
+
+         x_ro = x[..., :rotary_dim]
+         x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
+         out = torch.empty_like(x) if not inplace else x
+         out_ro = out[..., :rotary_dim]
+
+         if inplace:
+             o1, o2 = x1, x2
+         else:
+             o1, o2 = (
+                 out_ro.chunk(2, dim=-1)
+                 if not interleaved
+                 else (out_ro[..., ::2], out_ro[..., 1::2])
+             )
+
+         if rotary_emb is None:
+             # Fallback or error if extension is missing but this code path is hit
+             raise ImportError("rotary_emb extension not found. Please install it to use fused rotary embeddings.")
+
+         rotary_emb.apply_rotary(
+             x1, x2,
+             rearrange(cos[:seqlen], "s d -> s 1 d"),
+             rearrange(sin[:seqlen], "s d -> s 1 d"),
+             o1, o2,
+             False,
+         )
+
+         if not inplace and rotary_dim < headdim:
+             out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+
+         ctx.save_for_backward(cos, sin)
+         ctx.interleaved = interleaved
+         ctx.inplace = inplace
+         return out if not inplace else x
+
+     @staticmethod
+     def backward(ctx, do):
+         """
+         Full backward pass from fused_rotary_embedding.py to support training
+         """
+         cos, sin = ctx.saved_tensors
+         _, seqlen, _, headdim = do.shape
+         rotary_dim = cos.shape[-1] * 2
+         inplace = ctx.inplace
+         do_ro = do[..., :rotary_dim]
+
+         do1, do2 = (
+             do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
+         )
+
+         dx = torch.empty_like(do) if not inplace else do
+         if inplace:
+             dx1, dx2 = do1, do2
+         else:
+             dx_ro = dx[..., :rotary_dim]
+             dx1, dx2 = (
+                 dx_ro.chunk(2, dim=-1)
+                 if not ctx.interleaved
+                 else (dx_ro[..., ::2], dx_ro[..., 1::2])
+             )
+
+         rotary_emb.apply_rotary(
+             do1, do2,
+             rearrange(cos[:seqlen], "s d -> s 1 d"),
+             rearrange(sin[:seqlen], "s d -> s 1 d"),
+             dx1, dx2,
+             True,
+         )
+
+         if not inplace and rotary_dim < headdim:
+             dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
+
+         return dx, None, None, None, None
+
+ apply_rotary_emb_func = ApplyRotaryEmb.apply
+
+ def build_rope_cache(
+     seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))
+     seq_idx = torch.arange(seq_len, device=device) / condense_ratio
+     idx_theta = torch.outer(seq_idx, theta)
+     cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
+
+     # bfloat16 is handled first; half precision covers fp16 and int8-quantized runs
+     if dtype == torch.bfloat16:
+         return cos.bfloat16(), sin.bfloat16()
+     if dtype in (torch.float16, torch.int8):
+         return cos.half(), sin.half()
+     return cos, sin
+
+
+ # ===========================================================================
+ # PART 2: NORMALIZATION (Fused RMS Norm)
+ # ===========================================================================
+
+ def maybe_align(x, alignment_in_bytes=16):
+     return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
+
+ class DropoutAddLayerNormFn(torch.autograd.Function):
+     @staticmethod
+     @torch.compiler.disable
+     def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
+         if dropout_layer_norm is None:
+             raise ImportError("dropout_layer_norm extension not found. Cannot use FusedRMSNorm.")
+
+         x0 = maybe_align(x0.contiguous(), 16)
+         residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
+         gamma = maybe_align(gamma.contiguous(), 16)
+
+         zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
+             x0.view((-1, gamma.numel())),
+             residual.view((-1, gamma.numel())) if residual is not None else None,
+             gamma,
+             None, None, None, None, None,  # unused args
+             dropout_p,
+             epsilon,
+             1.0, 0, None,
+             residual_in_fp32,
+             is_rms_norm,
+         )
+
+         # When dropout_p is 0.0 the C++ kernel returns xmat as None (an optimization),
+         # so fall back to the input x0 for the backward pass.
+         if xmat is None:
+             xmat = x0
+
+         ctx.save_for_backward(xmat.view(x0.shape), x0, dmask, gamma, mu, rsigma)
+         ctx.dropout_p = dropout_p
+         ctx.is_rms_norm = is_rms_norm
+         ctx.has_residual = residual is not None
+
+         return zmat.view(x0.shape)
+
+     @staticmethod
+     def backward(ctx, dz, *args):
+         # Full backward implementation for training
+         dz = maybe_align(dz.contiguous(), 16)
+         x, x0, dmask, gamma, mu, rsigma = ctx.saved_tensors
+
+         dx0mat, dresidualmat, dgamma, dbeta, *rest = dropout_layer_norm.dropout_add_ln_bwd(
+             dz.view((-1, gamma.numel())),  # the kernel expects 2D inputs [batch*seq, hidden]
+             None,  # dx
+             x.view((-1, gamma.numel())),
+             x0.view((-1, gamma.numel())) if x0 is not None else None,
+             dmask, mu, rsigma, gamma,
+             None, None, None, None,  # scales
+             ctx.dropout_p,
+             1.0, 0,
+             ctx.has_residual,
+             ctx.is_rms_norm,
+         )
+
+         # Reshape the flattened gradients back to the original input shape
+         dx0 = dx0mat.view(x.shape)
+         dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
+
+         # One gradient (or None) per forward input
+         return (dx0, dresidual, dgamma, None, None, None, None, None, None, None, None, None)
+
+
+ def rms_norm(x, weight, epsilon):
+     return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False, False, True)
+
+ class FusedRMSNorm(torch.nn.Module):
+     def __init__(self, size: int, dim: int = -1, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = torch.nn.Parameter(torch.ones(size))
+         self.dim = dim
+
+     def reset_parameters(self):
+         init.ones_(self.weight)
+
+     def forward(self, x):
+         return rms_norm(x, self.weight, self.eps)
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
+         super().__init__()
+         self.weight = torch.nn.Parameter(torch.ones(size))
+         self.eps = eps
+         self.dim = dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
+         x_normed = x * torch.rsqrt(norm_x + self.eps)
+         return self.weight * x_normed
+
+
+ # ===========================================================================
+ # PART 3: BLOCKS & LAYERS
+ # ===========================================================================
+
+ class GptNeoxMLP(nn.Module):
+     def __init__(self, config: DiffusionLlamaConfig) -> None:
+         super().__init__()
+         self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
+         self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc(x)
+         x = torch.nn.functional.gelu(x)
+         return self.proj(x)
+
+
+ class LLaMAMLP(nn.Module):
+     def __init__(self, config: DiffusionLlamaConfig) -> None:
+         super().__init__()
+         self.swiglu = SwiGLU(config.n_embd, config.intermediate_size, bias=False, _pack_weights=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.swiglu(x)
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, config: DiffusionLlamaConfig) -> None:
+         super().__init__()
+         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
+         self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
+         self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+         self.config = config
+
+     def forward(self, x: torch.Tensor, rope: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+         B, T, C = x.size()
+         qkv = self.attn(x)
+
+         q_per_kv = self.config.n_head // self.config.n_query_groups
+         total_qkv = q_per_kv + 2
+         qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
+
+         q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
+         q = q.reshape(B, T, -1, self.config.head_size)
+         k = k.reshape(B, T, -1, self.config.head_size)
+         v = v.reshape(B, T, -1, self.config.head_size)
+
+         cos, sin = rope
+
+         # Apply Rotary
+         q = apply_rotary_emb_func(q, cos, sin, False, True)
+         k = apply_rotary_emb_func(k, cos, sin, False, True)
+
+         y = self.scaled_dot_product_attention(q, k, v)
+         y = y.reshape(B, T, C)
+         y = self.proj(y)
+         return y
+
+     def scaled_dot_product_attention(self, q, k, v):
+         scale = 1.0 / math.sqrt(self.config.head_size)
+
+         # Use Flash Attention 2 if available and on CUDA
+         if FlashAttention2Available and q.device.type == "cuda" and q.dtype in (torch.float16, torch.bfloat16):
+             from flash_attn import flash_attn_func
+             return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=False)
+
+         # Fallback to SDPA
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+
+         # Handle GQA/MQA broadcast
+         if q.size() != k.size():
+             k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
+             v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
+
+         y = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=None, dropout_p=0.0, scale=scale, is_causal=False
+         )
+         return y.transpose(1, 2)
+
+ class Block(nn.Module):
+     def __init__(self, config: DiffusionLlamaConfig) -> None:
+         super().__init__()
+         # Determine classes dynamically based on config strings
+         if config.norm_class == "RMSNorm":
+             norm_cls = RMSNorm
+         elif config.norm_class == "FusedRMSNorm":
+             norm_cls = FusedRMSNorm
+         else:
+             norm_cls = getattr(torch.nn, config.norm_class)
+
+         mlp_cls = LLaMAMLP if config.mlp_class == "LLaMAMLP" else GptNeoxMLP
+
+         self.norm_1 = norm_cls(config.n_embd, eps=config.norm_eps)
+         self.attn = SelfAttention(config)
+
+         if not config.shared_attention_norm:
+             self.norm_2 = norm_cls(config.n_embd, eps=config.norm_eps)
+
+         self.mlp = mlp_cls(config)
+         self.config = config
+
+     def forward(self, x: torch.Tensor, rope: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+         n_1 = self.norm_1(x)
+         h = self.attn(n_1, rope)
+
+         if self.config.parallel_residual:
+             n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
+             x = x + h + self.mlp(n_2)
+         else:
+             if self.config.shared_attention_norm:
+                 raise NotImplementedError("Shared attention norm not supported with non-parallel residual")
+             x = x + h
+             x = x + self.mlp(self.norm_2(x))
+         return x
+
+
+ # ===========================================================================
+ # PART 4: MAIN MODEL CLASSES
+ # ===========================================================================
+
+ class TransEncoder(nn.Module):
+     def __init__(self, config: DiffusionLlamaConfig) -> None:
+         super().__init__()
+         assert config.padded_vocab_size is not None
+         self.config = config
+
+         if config.norm_class == "RMSNorm":
+             norm_cls = RMSNorm
+         elif config.norm_class == "FusedRMSNorm":
+             norm_cls = FusedRMSNorm
+         else:
+             norm_cls = getattr(torch.nn, config.norm_class)
+
+         self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
+         self.transformer = nn.ModuleDict(
+             dict(
+                 # +1 embedding row so mask_token_id == padded_vocab_size gets its own vector
+                 wte=nn.Embedding(config.padded_vocab_size + 1, config.n_embd),
+                 h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
+                 ln_f=norm_cls(config.n_embd, eps=config.norm_eps),
+             )
+         )
+         self.rope_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+     def forward(self, idx: torch.Tensor) -> torch.Tensor:
+         B, T = idx.size()
+
+         # Build Rope cache if needed (cached in bfloat16 regardless of model dtype)
+         if self.rope_cache is None:
+             self.rope_cache = build_rope_cache(
+                 seq_len=self.config.block_size,
+                 n_elem=int(self.config.rotary_percentage * self.config.head_size),
+                 dtype=torch.bfloat16,
+                 device=idx.device,
+                 condense_ratio=self.config.condense_ratio,
+             )
+
+         # Retrieve and slice cache
+         cos, sin = self.rope_cache
+         cos = cos[:T]
+         sin = sin[:T]
+
+         x = self.transformer.wte(idx)
+         for block in self.transformer.h:
+             x = block(x, (cos, sin))
+
+         x = self.transformer.ln_f(x)
+         return self.lm_head(x)
+
+
+ class DiffusionLlamaLM(PreTrainedModel):
+     config_class = DiffusionLlamaConfig
+     base_model_prefix = "model"
+
+     def __init__(self, config: DiffusionLlamaConfig):
+         super().__init__(config)
+         self.model = TransEncoder(config)
+
+         # Initialize weights via the HF post-init hook (calls _init_weights below)
+         self.post_init()
+
+     def _init_weights(self, module: nn.Module) -> None:
+         """
+         Initialization logic for training.
+         Adapted from the original TransEncoder._init_weights.
+         """
+         n_layer = self.config.n_layer
+
+         if isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
+         elif isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+
+         # Special depth-scaled initialization for SwiGLU / output projections.
+         # In HF _init_weights, `module` is the current submodule; match on type.
+         if isinstance(module, LLaMAMLP):
+             for name, p in module.named_parameters():
+                 if "proj.weight" in name:
+                     nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
+
+         if isinstance(module, SwiGLU):
+             for name, p in module.named_parameters():
+                 if "w3.weight" in name:
+                     nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
+
+         if isinstance(module, SelfAttention):
+             for name, p in module.named_parameters():
+                 if "proj.weight" in name:
+                     nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
+
+     def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         logits = self.model(input_ids)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+         if not return_dict:
+             return ((loss,) + (logits,)) if loss is not None else (logits,)
+
+         return CausalLMOutputWithPast(loss=loss, logits=logits)
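Two implementation details stand out when driving this model: the embedding table has `padded_vocab_size + 1` rows so the mask token (`mask_token_id = 32000` in `config.json`) gets its own vector, and attention runs with `causal=False` / `is_causal=False`, i.e. bidirectionally. A minimal smoke-test sketch, assuming the checkpoint is already loaded as `model` (e.g. via the snippet in the README section above); the 15% masking rate is arbitrary:

```python
# Sketch: feed a partially masked sequence through the bidirectional denoiser.
import torch

MASK_ID = 32000  # == padded_vocab_size; the extra embedding row added in TransEncoder

input_ids = torch.randint(0, 32000, (1, 128))  # stand-in token ids
mask = torch.rand(input_ids.shape) < 0.15      # mask ~15% of positions (arbitrary rate)
noisy_ids = input_ids.masked_fill(mask, MASK_ID)

with torch.no_grad():
    out = model(noisy_ids)                     # CausalLMOutputWithPast
print(out.logits.shape)                        # torch.Size([1, 128, 32000])
```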