codelion commited on
Commit
56202a2
·
verified ·
1 Parent(s): 7d2bb1c

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +5 -5
  2. config.json +7 -2
  3. modeling_dhara.py +760 -0
README.md CHANGED
@@ -162,7 +162,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
162
 
163
  # Load model and tokenizer
164
  tokenizer = AutoTokenizer.from_pretrained("codelion/dhara-70m")
165
- model = AutoModelForCausalLM.from_pretrained("codelion/dhara-70m")
166
 
167
  # Generate text
168
  inputs = tokenizer("The future of AI is", return_tensors="pt")
@@ -235,17 +235,17 @@ for i, output in enumerate(outputs):
235
  ## Citation
236
 
237
  ```bibtex
238
- @article{sharma2025dhara,
239
- title={Dhara: Optimal Architecture for Efficient Diffusion Language Models},
240
  author={Sharma, Asankhaya},
241
  year={2025},
242
- url={https://huggingface.co/codelion/dhara-70m}
243
  }
244
  ```
245
 
246
  ## Related Work
247
 
248
- - [Width vs Depth: The Optimal Architecture for Small Language Models](https://huggingface.co/blog/codelion/optimal-architecture) - Blog post describing this work
249
  - [The 1 Billion Token Challenge: Optimal Dataset Mixing](https://huggingface.co/blog/codelion/optimal-dataset-mixing) - Our previous work on optimal pretraining data
250
  - [GPT-2-70M](https://huggingface.co/codelion/gpt-2-70m) - Our previous model from optimal pretraining experiments
251
 
 
162
 
163
  # Load model and tokenizer
164
  tokenizer = AutoTokenizer.from_pretrained("codelion/dhara-70m")
165
+ model = AutoModelForCausalLM.from_pretrained("codelion/dhara-70m", trust_remote_code=True)
166
 
167
  # Generate text
168
  inputs = tokenizer("The future of AI is", return_tensors="pt")
 
235
  ## Citation
236
 
237
  ```bibtex
238
+ @article{sharma2025optimal,
239
+ title={The Optimal Architecture for Small Language Models},
240
  author={Sharma, Asankhaya},
241
  year={2025},
242
+ url={https://huggingface.co/blog/codelion/optimal-model-architecture}
243
  }
244
  ```
245
 
246
  ## Related Work
247
 
248
+ - [The Optimal Architecture for Small Language Models](https://huggingface.co/blog/codelion/optimal-model-architecture) - Blog post describing this work
249
  - [The 1 Billion Token Challenge: Optimal Dataset Mixing](https://huggingface.co/blog/codelion/optimal-dataset-mixing) - Our previous work on optimal pretraining data
250
  - [GPT-2-70M](https://huggingface.co/codelion/gpt-2-70m) - Our previous model from optimal pretraining experiments
251
 
config.json CHANGED
@@ -1,7 +1,12 @@
1
  {
2
  "architectures": [
3
- "DharaCanonForMaskedDiffusion"
4
  ],
 
 
 
 
 
5
  "attention_dropout": 0.0,
6
  "bos_token_id": 1,
7
  "canon_activation": false,
@@ -18,7 +23,7 @@
18
  "mask_epsilon": 0.001,
19
  "mask_token_id": 50256,
20
  "max_position_embeddings": 2048,
21
- "model_type": "dhara_canon",
22
  "num_attention_heads": 6,
23
  "num_diffusion_steps": 1000,
24
  "num_hidden_layers": 32,
 
1
  {
2
  "architectures": [
3
+ "DharaForMaskedDiffusion"
4
  ],
5
+ "auto_map": {
6
+ "AutoConfig": "modeling_dhara.DharaConfig",
7
+ "AutoModel": "modeling_dhara.DharaForMaskedDiffusion",
8
+ "AutoModelForCausalLM": "modeling_dhara.DharaForMaskedDiffusion"
9
+ },
10
  "attention_dropout": 0.0,
11
  "bos_token_id": 1,
12
  "canon_activation": false,
 
23
  "mask_epsilon": 0.001,
24
  "mask_token_id": 50256,
25
  "max_position_embeddings": 2048,
26
+ "model_type": "dhara",
27
  "num_attention_heads": 6,
28
  "num_diffusion_steps": 1000,
29
  "num_hidden_layers": 32,
modeling_dhara.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Dhara: Diffusion Language Model
4
+
5
+ A diffusion-based language model that combines:
6
+ 1. Masked diffusion training (MDM) with bidirectional attention
7
+ 2. Canon layers for local context mixing via causal depthwise convolutions
8
+ 3. High-throughput parallel token generation
9
+
10
+ Usage:
11
+ from transformers import AutoModel, AutoTokenizer
12
+ model = AutoModel.from_pretrained("codelion/dhara-70m", trust_remote_code=True)
13
+ tokenizer = AutoTokenizer.from_pretrained("codelion/dhara-70m")
14
+ """
15
+
16
+ import math
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from transformers import PreTrainedModel
24
+ from transformers.generation import GenerationMixin
25
+ from transformers.modeling_outputs import BaseModelOutputWithPast, MaskedLMOutput
26
+ from transformers.utils import logging
27
+ from transformers.cache_utils import Cache, DynamicCache
28
+ from transformers import PretrainedConfig
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ # Optional performance imports
33
+ try:
34
+ from flash_attn import flash_attn_func
35
+ FLASH_ATTN_AVAILABLE = True
36
+ except ImportError:
37
+ FLASH_ATTN_AVAILABLE = False
38
+
39
+
40
class DharaConfig(PretrainedConfig):
    """
    Configuration for the Dhara diffusion language model.

    Bundles three groups of hyper-parameters: transformer sizing (hidden
    size, layers, heads, GQA KV heads), Canon depthwise-convolution options,
    and the masked-diffusion noising settings (mask token, epsilon, steps).
    """

    model_type = "dhara"

    def __init__(
        self,
        # Core architecture
        vocab_size: int = 50257,
        hidden_size: int = 384,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 6,
        num_key_value_heads: int = 6,
        intermediate_size: int = 1024,
        head_dim: int = None,
        max_position_embeddings: int = 2048,

        # Model specifics
        hidden_act: str = "silu",
        rms_norm_eps: float = 1e-5,
        rope_theta: float = 10000.0,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = True,
        attention_dropout: float = 0.0,

        # Canon layer parameters
        canon_set: str = "AC",
        canon_kernel: int = 4,
        canon_residual: bool = True,
        canon_activation: bool = False,
        canon_bias: bool = False,

        # Diffusion specific
        mask_token_id: int = 50256,
        mask_epsilon: float = 0.001,
        num_diffusion_steps: int = 1000,

        # Special tokens
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,

        # Performance flags
        use_cache: bool = False,
        use_flash_attention: bool = False,
        use_xformers: bool = False,

        **kwargs
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

        # --- transformer sizing ---
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        # Derive the per-head dim from the hidden size unless given explicitly.
        self.head_dim = head_dim if head_dim else hidden_size // num_attention_heads
        self.max_position_embeddings = max_position_embeddings

        # --- activation / norm / init ---
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout

        # --- Canon depthwise-conv options ---
        self.canon_set = canon_set
        self.canon_kernel = canon_kernel
        self.canon_residual = canon_residual
        self.canon_activation = canon_activation
        self.canon_bias = canon_bias

        # --- masked-diffusion settings ---
        self.mask_token_id = mask_token_id
        self.mask_epsilon = mask_epsilon
        self.num_diffusion_steps = num_diffusion_steps

        # --- performance flags ---
        self.use_cache = use_cache
        self.use_flash_attention = use_flash_attention
        self.use_xformers = use_xformers
129
+
130
+
131
class CanonLayer(nn.Module):
    """Depthwise causal Conv1d mixing a short window of preceding tokens.

    The input is left-padded by ``kernel_size - 1`` and the overhang on the
    right is sliced off, so position t only ever sees positions <= t.
    Optionally applies SiLU and/or adds the input back as a residual.
    """

    def __init__(
        self,
        hidden_size: int,
        kernel_size: int = 4,
        use_residual: bool = True,
        use_activation: bool = False,
        use_bias: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.use_residual = use_residual
        self.use_activation = use_activation

        # groups == channels makes the convolution depthwise (per-channel).
        self.conv = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size,
            padding=kernel_size - 1,
            groups=hidden_size,
            bias=use_bias,
        )
        nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)
        if use_bias:
            nn.init.zeros_(self.conv.bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the causal depthwise conv over a (batch, seq, hidden) input."""
        seq_len = hidden_states.shape[1]
        # Conv1d wants channels-first: (batch, hidden, seq). Truncating to
        # seq_len removes the right-side padding overhang, keeping causality.
        mixed = self.conv(hidden_states.transpose(1, 2))[:, :, :seq_len]
        if self.use_activation:
            mixed = F.silu(mixed)
        mixed = mixed.transpose(1, 2)
        return hidden_states + mixed if self.use_residual else mixed
172
+
173
+
174
class RMSNorm(nn.Module):
    """RMS layer norm: scale by 1/sqrt(mean(x^2) + eps), no mean-centering."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in fp32 for numerical stability, cast back afterwards.
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
188
+
189
+
190
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with precomputed cos/sin tables."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Build cos/sin lookup tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate so the table spans the full head dim; pairs live in two
        # contiguous halves, matching rotate_half's split.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return (cos, sin) tables covering positions [0, seq_len)."""
        # Grow the cache lazily if a longer sequence shows up.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
221
+
222
+
223
def rotate_half(x):
    """Rotate the two halves of the last dim: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
227
+
228
+
229
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate query/key tensors by the RoPE tables at the given positions.

    Gathers per-position rows from the (seq, dim) cos/sin tables, inserts a
    broadcast dim (the heads axis by default), and applies the standard
    x*cos + rotate_half(x)*sin rotation in the query dtype.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim).to(q.dtype)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim).to(q.dtype)
    return (
        (q * cos) + (rotate_half(q) * sin),
        (k * cos) + (rotate_half(k) * sin),
    )
237
+
238
+
239
class DharaMLP(nn.Module):
    """Gated feed-forward block (SwiGLU): down(silu(gate(x)) * up(x))."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        # Gate path is activated, up path is linear; their product is projected back.
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
254
+
255
+
256
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat KV heads n_rep times along the head axis (GQA -> MHA layout).

    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1);
    returns the input unchanged when n_rep == 1.
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
262
+
263
+
264
class DharaAttention(nn.Module):
    """Multi-Head Bidirectional Attention with GQA support.

    Attention is fully bidirectional (is_causal=False): the model is trained
    as a masked-diffusion denoiser, not an autoregressive LM. Grouped-query
    attention is supported via num_key_value_heads < num_attention_heads.
    """

    def __init__(self, config: DharaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        # layer_idx lets the KV cache index its per-layer storage.
        self.layer_idx = layer_idx

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each KV head (GQA replication factor).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = False  # Bidirectional for diffusion

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention; returns (attn_output, attn_weights_or_None, past_key_value).

        attention_mask is an additive mask (0 keep / -inf drop) already
        broadcastable against (batch, heads, q_len, kv_len).
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch, seq, heads*dim) -> (batch, heads, seq, dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total KV length includes any previously cached tokens.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Replicate KV heads so shapes match the query heads (GQA).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if FLASH_ATTN_AVAILABLE and self.config.use_flash_attention:
            # flash_attn_func expects (batch, seq, heads, dim) and fp16/bf16.
            # NOTE(review): this path never applies attention_mask, so padded
            # tokens are attended to — confirm callers only use it unpadded.
            query_states = query_states.transpose(1, 2).contiguous()
            key_states = key_states.transpose(1, 2).contiguous()
            value_states = value_states.transpose(1, 2).contiguous()

            if query_states.dtype not in [torch.float16, torch.bfloat16]:
                query_states = query_states.to(torch.bfloat16)
                key_states = key_states.to(torch.bfloat16)
                value_states = value_states.to(torch.bfloat16)

            attn_output = flash_attn_func(
                query_states, key_states, value_states,
                dropout_p=0.0, causal=False,
            )
            attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        else:
            # Standard scaled dot-product attention.
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                attn_weights = attn_weights + attention_mask
            # Softmax in fp32 for stability, then cast back to the query dtype.
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value_states)
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
357
+
358
+
359
class DharaDecoderLayer(nn.Module):
    """Pre-norm transformer block with optional Canon convolutions.

    Structure: [norm -> canon-A -> attention -> residual] then
    [norm -> canon-C -> MLP -> residual]. Canon "A"/"C" layers are enabled
    by membership in config.canon_set.
    """

    def __init__(self, config: DharaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        canon_kwargs = dict(
            hidden_size=config.hidden_size,
            kernel_size=config.canon_kernel,
            use_residual=config.canon_residual,
            use_activation=config.canon_activation,
            use_bias=config.canon_bias,
        )
        # Canon "A" sits between the first norm and the attention.
        self.canon_a = CanonLayer(**canon_kwargs) if "A" in config.canon_set else None

        self.self_attn = DharaAttention(config=config, layer_idx=layer_idx)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Canon "C" sits between the second norm and the MLP.
        self.canon_c = CanonLayer(**canon_kwargs) if "C" in config.canon_set else None

        self.mlp = DharaMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value=None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one block; returns (hidden_states[, attn_weights][, cache])."""
        # --- attention sub-block ---
        shortcut = hidden_states
        normed = self.input_layernorm(hidden_states)
        if self.canon_a is not None:
            normed = self.canon_a(normed)
        attn_out, attn_weights, present_key_value = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = shortcut + attn_out

        # --- feed-forward sub-block ---
        shortcut = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        if self.canon_c is not None:
            normed = self.canon_c(normed)
        hidden_states = shortcut + self.mlp(normed)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
435
+
436
+
437
class DharaPreTrainedModel(PreTrainedModel):
    """Base class wiring DharaConfig into the HF PreTrainedModel machinery."""

    config_class = DharaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DharaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # Padding embedding row stays exactly zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
456
+
457
+
458
class DharaModel(DharaPreTrainedModel):
    """Dhara base model: embedding, decoder-layer stack, final RMSNorm.

    Attention is bidirectional (the model denoises masked tokens rather than
    predicting left-to-right), so no causal mask is ever built here.
    """

    def __init__(self, config: DharaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DharaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.config = config
        self.mask_token_id = config.mask_token_id
        # Flash attention only when requested in config AND the package imported.
        self._use_flash_attention_2 = config.use_flash_attention and FLASH_ATTN_AVAILABLE

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Encode tokens through the layer stack.

        A 2D (batch, seq) attention_mask is converted to an additive 0/-inf
        4D mask with no causal triangle. Returns last hidden states plus
        optional caches / per-layer states / attentions.
        """
        # Fall back to config defaults for unspecified flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Gradient checkpointing recomputes activations — incompatible with caching.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        past_key_values_length = 0
        if use_cache:
            # Accept either the legacy tuple format or a Cache object.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            # Positions continue after whatever is already cached.
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # flash-attn path: keep the 2D mask only if it actually contains
            # padding; otherwise drop it entirely.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            if attention_mask is not None:
                if attention_mask.dim() == 2:
                    # (batch, seq) keep/pad mask -> additive (batch, 1, q, kv)
                    # mask: 0.0 where attended, -inf where padded. No causal
                    # triangle — attention stays bidirectional.
                    # NOTE(review): shape assumes kv_len == seq_length (no
                    # cached prefix); confirm if caching is used with padding.
                    attention_mask_4d = attention_mask[:, None, None, :].expand(
                        batch_size, 1, seq_length, seq_length
                    ).to(dtype=inputs_embeds.dtype)
                    attention_mask = torch.where(
                        attention_mask_4d == 0,
                        torch.tensor(float('-inf'), dtype=inputs_embeds.dtype, device=attention_mask_4d.device),
                        torch.tensor(0.0, dtype=inputs_embeds.dtype, device=attention_mask_4d.device)
                    )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states, attention_mask, position_ids,
                    past_key_values, output_attentions, use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache slot comes after the attn weights in the output tuple.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Also expose the final (post-norm) hidden state.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            # Convert back to the format the caller originally passed in.
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def add_noise_to_tokens(self, input_ids: torch.LongTensor, t: torch.FloatTensor, eps: float = None):
        """MDM-style masking: Replace tokens with [MASK] based on noise level t.

        t is a per-sample noise level — the unsqueeze/expand below assumes
        shape (batch,) with values in [0, 1]. Each token is independently
        replaced by mask_token_id with probability p_mask = (1-eps)*t + eps,
        so even at t == 0 an eps fraction is masked.

        Returns (noisy_input_ids, corruption_mask, p_mask): which ids were
        masked, where, and with what probability (used for loss weighting).
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        if eps is None:
            eps = getattr(self.config, 'mask_epsilon', 0.001)
        p_mask = (1 - eps) * t + eps
        p_mask = p_mask.unsqueeze(-1).expand(batch_size, seq_len)

        corruption_mask = torch.rand(batch_size, seq_len, device=device) < p_mask
        noisy_input_ids = torch.where(corruption_mask, self.mask_token_id, input_ids)

        return noisy_input_ids, corruption_mask, p_mask
610
+
611
+
612
class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
    """Dhara Model with Masked Diffusion head for training and inference.

    Wraps DharaModel with a vocabulary projection; when
    config.tie_word_embeddings is set, logits are computed directly from the
    input embedding matrix and lm_head is bypassed.
    """
    # Tied to the input embeddings when config.tie_word_embeddings is set.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DharaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.config = config
        self.mask_token_id = config.mask_token_id

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        corruption_mask: Optional[torch.BoolTensor] = None,
        p_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Run the backbone and project to vocabulary logits.

        When labels are given, computes the masked-diffusion loss; this also
        requires corruption_mask and p_mask (as produced by
        add_noise_to_tokens) and raises otherwise.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            # Tied head: project with the embedding matrix; lm_head is unused.
            logits = F.linear(hidden_states, self.model.embed_tokens.weight)
        else:
            logits = self.lm_head(hidden_states)
        # fp32 logits for a numerically stable loss/softmax.
        logits = logits.float()

        loss = None
        if labels is not None:
            loss = self.compute_diffusion_loss(logits, labels, corruption_mask, p_mask)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def compute_diffusion_loss(self, logits, labels, corruption_mask=None, p_mask=None):
        """MDM loss with p_mask importance weighting.

        Per-token cross-entropy is taken only at corrupted positions, each
        weighted by 1/p_mask (importance weighting over noise levels); the
        weighted sum is normalized by the TOTAL number of positions, not
        just the masked ones.

        Raises:
            ValueError: if corruption_mask or p_mask is missing.
        """
        if corruption_mask is None or p_mask is None:
            raise ValueError("MDM requires both corruption_mask and p_mask for loss computation.")

        loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            labels.view(-1),
            reduction='none'
        )
        loss = loss.view(labels.shape)

        # Select masked positions and importance-weight by their mask prob.
        masked_losses = loss[corruption_mask]
        masked_p_mask = p_mask[corruption_mask]
        weighted_losses = masked_losses / masked_p_mask

        total_positions = labels.shape[0] * labels.shape[1]
        return weighted_losses.sum() / total_positions

    def add_noise_to_tokens(self, input_ids: torch.LongTensor, t: torch.FloatTensor, eps: float = None):
        """Delegate to the base model"""
        return self.model.add_noise_to_tokens(input_ids, t, eps)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim inputs/masks against the KV cache for GenerationMixin."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # NOTE(review): seen_tokens and get_max_length() are
                # deprecated/removed in newer transformers Cache APIs —
                # confirm against the pinned transformers version.
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: length is the key tensor's seq dim.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the tokens the cache has not consumed yet.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

            if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length:
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Derive positions from the mask so padding does not shift them.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # inputs_embeds can only be used on the first step (no cache yet).
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        })
        return model_inputs

    def save_pretrained(self, save_directory, **kwargs):
        """Save with safetensors serialization enabled by default."""
        kwargs['safe_serialization'] = kwargs.get('safe_serialization', True)
        return super().save_pretrained(save_directory, **kwargs)