shouryamaanjain committed
Commit cd2f2fc · verified · 1 Parent(s): 2f83783

Upload smj-diffusion checkpoint (step 12000)

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,130 @@
# smj-diffusion

A discrete diffusion language model for code generation, based on the CoDA (Coding LM via Diffusion Adaptation) architecture.

> ⚠️ **Note:** This is an intermediate checkpoint (step 12,000) from an interrupted training run. The model may not be fully trained.

## Model Details

| Property | Value |
|----------|-------|
| **Architecture** | DiffusionQwen3 (Bidirectional Transformer) |
| **Base Model** | Qwen-based architecture |
| **Hidden Size** | 1536 |
| **Layers** | 28 |
| **Attention Heads** | 12 |
| **KV Heads** | 2 (GQA) |
| **Intermediate Size** | 8960 |
| **Max Position Embeddings** | 32,768 |
| **Vocab Size** | 151,666 |
| **Training Checkpoint** | 12,000 steps |

## How Diffusion LMs Work

Unlike autoregressive models that generate tokens left-to-right, this model uses **discrete diffusion**:

1. Start with all `<mask>` tokens in the generation region
2. Iteratively unmask tokens based on model confidence
3. Higher-confidence predictions are revealed first
4. The process repeats until all tokens are generated

This enables **bidirectional context** during generation, potentially improving coherence for code.

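The unmasking loop described above can be sketched in a few lines of plain Python. This is a toy illustration only: `fake_confidences` stands in for a real model forward pass, and the token strings are made up.

```python
import random

MASK = "<mask>"

def fake_confidences(tokens):
    # Stand-in for the model: propose a (token, confidence) pair for every
    # masked position. A real model would run a bidirectional forward pass.
    return {i: (f"tok{i}", random.random())
            for i, t in enumerate(tokens) if t == MASK}

def diffusion_decode(prompt, gen_len=8, steps=4):
    # 1. Start with all <mask> tokens in the generation region.
    tokens = list(prompt) + [MASK] * gen_len
    for step in range(steps):
        guesses = fake_confidences(tokens)
        if not guesses:
            break
        # 2./3. Reveal the most confident fraction of the remaining masks.
        budget = max(1, len(guesses) // (steps - step))
        best = sorted(guesses.items(), key=lambda kv: -kv[1][1])[:budget]
        for pos, (tok, _) in best:
            tokens[pos] = tok
    # 4. After the final step, no masks remain.
    return tokens

print(diffusion_decode(["def", "f", "("], gen_len=6, steps=3))
```

In the real model, the confidence comes from the predicted token distribution (probability, top-2 margin, or negative entropy, depending on the chosen algorithm).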
## Usage

### Installation

```bash
pip install torch transformers
```

### Inference

```python
import torch
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("YOUR_USERNAME/smj-diffusion", trust_remote_code=True)

# Load model (see inference.py for the full diffusion generation logic)
# The model uses the custom DiffusionQwen3Model class
```

For full inference with diffusion sampling, use the included `inference.py` script:

```bash
# Single prompt
python inference.py --checkpoint /path/to/model --prompt "def fibonacci(n):"

# Interactive chat
python inference.py --checkpoint /path/to/model --mode chat

# With custom parameters
python inference.py --checkpoint /path/to/model \
    --prompt "Write a function to sort a list" \
    --steps 128 \
    --temperature 0.0 \
    --max-tokens 256 \
    --alg entropy
```

### Generation Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `steps` | 128 | Number of diffusion denoising steps |
| `temperature` | 0.0 | Sampling temperature (0 = greedy) |
| `top_p` | None | Nucleus sampling threshold |
| `top_k` | None | Top-k sampling threshold |
| `alg` | entropy | Sampling algorithm: `origin`, `entropy`, `maskgit_plus`, `topk_margin` |
| `alg_temp` | 0.1 | Algorithm-specific confidence temperature |

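The difference between the confidence measures behind the `entropy` and `topk_margin` algorithms can be seen on a toy distribution (plain Python for illustration; the real implementation operates on logits tensors):

```python
import math

def confidences(probs):
    # entropy alg: negative entropy of the distribution (higher = more certain)
    neg_entropy = sum(p * math.log(p + 1e-10) for p in probs)
    # topk_margin alg: gap between the two most probable tokens
    top = sorted(probs, reverse=True)
    margin = top[0] - top[1]
    return neg_entropy, margin

peaked = [0.9, 0.05, 0.05]   # model is sure -> unmask early
flat = [0.4, 0.3, 0.3]       # model is unsure -> keep masked for now
print(confidences(peaked))
print(confidences(flat))
```

Both measures rank the peaked distribution above the flat one, so under either algorithm the position the model is sure about is revealed first.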
## Model Architecture

The model is a bidirectional transformer (non-causal attention) trained with discrete diffusion objectives:

```
DiffusionQwen3Model(
  (model): Qwen2Model with bidirectional attention
  (lm_head): Linear(1536, 151666)
)
```

### Training Objective

- **Forward process:** Randomly mask tokens with probability `σ ~ U[ε, 1]`
- **Reverse process:** Predict the original tokens from the masked input
- **Loss weighting:** `1/σ` (ELBO-derived)

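The forward process and loss weighting above can be sketched as follows (a minimal illustration, assuming the mask token ID 151665 from `config.json`; the real training code operates on tensors):

```python
import random

MASK_ID = 151665  # "<|mask|>" from added_tokens.json

def forward_mask(tokens, sigma):
    # Forward process: each token is independently replaced by the mask
    # token with probability sigma.
    return [MASK_ID if random.random() < sigma else t for t in tokens]

eps = 1e-3
sigma = random.uniform(eps, 1.0)   # sigma ~ U[eps, 1]
tokens = list(range(10))
noisy = forward_mask(tokens, sigma)
loss_weight = 1.0 / sigma          # ELBO-derived 1/sigma loss weighting
```

The reverse process is then the model predicting `tokens` from `noisy`, with the loss on masked positions scaled by `loss_weight`.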
## Files

- `pytorch_model.bin` - Model weights
- `config.json` - Model configuration
- `tokenizer.json`, `vocab.json`, `merges.txt` - Tokenizer files
- `inference.py` - Standalone inference script
- `modeling_diffusion_qwen3.py` - Model class definition

## Limitations

- This is a **checkpoint from an interrupted training run**, not a fully trained model
- Performance may be limited compared to fully trained models
- Primarily designed for code generation tasks

## Citation

Based on CoDA by Salesforce AI Research:

```bibtex
@article{coda2024,
  title={CoDA: Coding LM via Diffusion Adaptation},
  author={Salesforce AI Research},
  journal={arXiv preprint},
  year={2024}
}
```

## License

Please refer to the base Qwen model license for usage terms.

added_tokens.json ADDED
@@ -0,0 +1,25 @@
{
  "</tool_call>": 151658,
  "<tool_call>": 151657,
  "<|box_end|>": 151649,
  "<|box_start|>": 151648,
  "<|endoftext|>": 151643,
  "<|file_sep|>": 151664,
  "<|fim_middle|>": 151660,
  "<|fim_pad|>": 151662,
  "<|fim_prefix|>": 151659,
  "<|fim_suffix|>": 151661,
  "<|im_end|>": 151645,
  "<|im_start|>": 151644,
  "<|image_pad|>": 151655,
  "<|mask|>": 151665,
  "<|object_ref_end|>": 151647,
  "<|object_ref_start|>": 151646,
  "<|quad_end|>": 151651,
  "<|quad_start|>": 151650,
  "<|repo_name|>": 151663,
  "<|video_pad|>": 151656,
  "<|vision_end|>": 151653,
  "<|vision_pad|>": 151654,
  "<|vision_start|>": 151652
}
chat_template.jinja ADDED
@@ -0,0 +1,54 @@
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0]['role'] == 'system' %}
        {{- messages[0]['content'] }}
    {%- else %}
        {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
    {%- endif %}
    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0]['role'] == 'system' %}
        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
    {%- else %}
        {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- for message in messages %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {{- '<|im_start|>' + message.role }}
        {%- if message.content %}
            {{- '\n' + message.content }}
        {%- endif %}
        {%- for tool_call in message.tool_calls %}
            {%- if tool_call.function is defined %}
                {%- set tool_call = tool_call.function %}
            {%- endif %}
            {{- '\n<tool_call>\n{"name": "' }}
            {{- tool_call.name }}
            {{- '", "arguments": ' }}
            {{- tool_call.arguments | tojson }}
            {{- '}\n</tool_call>' }}
        {%- endfor %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- message.content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}
config.json ADDED
@@ -0,0 +1,34 @@
{
  "architectures": [
    "DiffusionQwen3Model"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "block_masking_probability": 0.01,
  "bos_token_id": null,
  "dtype": "bfloat16",
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 1536,
  "intermediate_size": 8960,
  "mask_block_sizes": [
    2,
    4,
    8
  ],
  "mask_token_id": 151665,
  "max_position_embeddings": 32768,
  "model_type": "diffusion_qwen3",
  "num_attention_heads": 12,
  "num_hidden_layers": 28,
  "num_key_value_heads": 2,
  "pad_token_id": 151643,
  "prefix_probability": 0.01,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sampling_eps": 0.001,
  "transformers_version": "4.57.3",
  "truncate_probability": 0.01,
  "vocab_size": 151666
}
inference.py ADDED
@@ -0,0 +1,692 @@
#!/usr/bin/env python3
"""
Inference script for DiffusionQwen3 model checkpoint.

Usage:
    # Interactive chat mode
    python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 --mode chat

    # Single prompt completion
    python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 --prompt "def fibonacci(n):"

    # With custom generation parameters
    python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 \
        --prompt "Write a hello world in Python" \
        --steps 128 --temperature 0.0 --max-tokens 256
"""

import argparse
import sys
import os
from typing import Optional, Tuple, List

import torch
import torch.nn.functional as F
import torch.distributions as dists
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig


# ============================================================================
# Diffusion Sampling Utilities (adapted from CoDALanguageModel/generation_utils.py)
# ============================================================================

def top_p_logits(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Apply nucleus (top-p) filtering to logits."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits


def top_k_logits(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Apply top-k filtering to logits."""
    top_k = min(top_k, logits.size(-1))
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(
    logits: torch.Tensor,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    neg_entropy: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample tokens from logits with optional temperature, top-p, and top-k.

    Returns:
        confidence: Confidence scores for sampled tokens
        x0: Sampled token IDs
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1.0:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except RuntimeError:
            # Fall back to greedy if sampling fails (e.g. degenerate probs)
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if neg_entropy:
        # Use negative entropy as confidence (for entropy-based sampling)
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0


# ============================================================================
# Diffusion Generation
# ============================================================================

@torch.no_grad()
def diffusion_generate(
    model: PreTrainedModel,
    input_ids: torch.LongTensor,
    mask_token_id: int,
    max_new_tokens: int = 128,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: Optional[float] = 0.1,
    eps: float = 1e-3,
    verbose: bool = False,
) -> torch.LongTensor:
    """
    Generate text using discrete diffusion.

    Args:
        model: The diffusion language model
        input_ids: Input token IDs (prompt) [batch_size, prompt_len]
        mask_token_id: Token ID for the mask token
        max_new_tokens: Maximum number of new tokens to generate
        steps: Number of diffusion steps
        temperature: Sampling temperature (0 = greedy)
        top_p: Nucleus sampling threshold
        top_k: Top-k sampling threshold
        alg: Sampling algorithm ("origin", "entropy", "maskgit_plus", "topk_margin")
        alg_temp: Algorithm-specific temperature for confidence weighting
        eps: Small epsilon for numerical stability
        verbose: Print progress during generation

    Returns:
        Generated token sequence [batch_size, prompt_len + max_new_tokens]
    """
    device = input_ids.device
    batch_size = input_ids.shape[0]
    prompt_len = input_ids.shape[1]
    total_len = prompt_len + max_new_tokens

    # Initialize sequence: prompt + mask tokens for generation
    x = F.pad(input_ids, (0, max_new_tokens), value=mask_token_id)

    # Create timesteps from 1 to eps
    timesteps = torch.linspace(1, eps, steps + 1, device=device)

    for i in range(steps):
        mask_index = (x == mask_token_id)

        if not mask_index.any():
            if verbose:
                print(f"Step {i}: No more masked tokens, stopping early")
            break

        # Forward pass
        outputs = model(x, return_logits_only=True)
        if hasattr(outputs, 'logits'):
            logits = outputs.logits
        elif isinstance(outputs, tuple):
            logits = outputs[0]
        else:
            logits = outputs

        # Shift logits for next-token prediction
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

        # Get logits only for masked positions
        mask_logits = logits[mask_index]

        t = timesteps[i]
        s = timesteps[i + 1]

        if alg == "origin":
            # Original diffusion: random unmasking with probability 1 - s/t
            p_transfer = 1 - s / t if i < steps - 1 else 1
            x0 = torch.zeros_like(x[mask_index], device=device, dtype=torch.long) + mask_token_id
            transfer_index = torch.rand(*x0.shape, device=device) < p_transfer
            _, x0[transfer_index] = sample_tokens(
                mask_logits[transfer_index],
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
            x[mask_index] = x0.clone()
        else:
            # Confidence-based unmasking algorithms
            if alg == "maskgit_plus":
                confidence, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
                )
            elif alg == "topk_margin":
                # Margin confidence: difference between top-2 probabilities
                probs = F.softmax(mask_logits / (temperature if temperature > 0 else 1), dim=-1)
                sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
                confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
                _, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
                )
            elif alg == "entropy":
                confidence, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k,
                    neg_entropy=True,
                )
            else:
                raise ValueError(f"Unknown algorithm: {alg}")

            # Determine how many tokens to unmask
            num_mask_token = mask_index.sum() / batch_size
            num_transfer = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)

            if num_transfer > 0:
                # Create full confidence tensor
                full_confidence = torch.full_like(x, -torch.inf, dtype=logits.dtype)
                full_confidence[mask_index] = confidence

                # Select top-k most confident positions to unmask
                if alg_temp is None or alg_temp == 0:
                    _, transfer_index = torch.topk(full_confidence, num_transfer)
                else:
                    # Stochastic selection with temperature
                    conf_probs = F.softmax(full_confidence / alg_temp, dim=-1)
                    transfer_index = torch.multinomial(conf_probs, num_samples=num_transfer)

                # Create candidate tensor with predicted tokens
                x_candidate = torch.zeros_like(x, dtype=torch.long) + mask_token_id
                x_candidate[mask_index] = x0.clone()

                # Update only selected positions
                row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(transfer_index)
                x[row_indices, transfer_index] = x_candidate[row_indices, transfer_index]

        if verbose and (i + 1) % max(1, steps // 10) == 0:
            remaining_masks = (x == mask_token_id).sum().item()
            print(f"Step {i+1}/{steps}: {remaining_masks} masked tokens remaining")

    return x


# ============================================================================
# Model Loading
# ============================================================================

def load_model_and_tokenizer(
    checkpoint_path: str,
    device: str = "auto",
    torch_dtype: str = "bfloat16",
) -> Tuple[PreTrainedModel, AutoTokenizer, dict]:
    """
    Load the diffusion model and tokenizer from a checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint directory
        device: Device to load the model on ("auto", "cuda", "cpu")
        torch_dtype: Data type for model weights

    Returns:
        model: Loaded model
        tokenizer: Loaded tokenizer
        config: Model configuration dict
    """
    import json
    from transformers import Qwen2ForCausalLM, Qwen2Config

    # Determine device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Get dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map.get(torch_dtype, torch.bfloat16)
    if device == "cpu" and dtype == torch.bfloat16:
        print("Warning: bfloat16 on CPU may be slow, using float32")
        dtype = torch.float32

    print(f"Loading model from {checkpoint_path}...")
    print(f"  Device: {device}, Dtype: {dtype}")

    # Load config
    config_path = os.path.join(checkpoint_path, "config.json")
    with open(config_path, "r") as f:
        config_dict = json.load(f)

    # Import and register the model class
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from models.diffusion_qwen import DiffusionQwen3Model, DiffusionQwen3Config

    # Create diffusion config
    diff_config = DiffusionQwen3Config(**config_dict)

    # Create a Qwen2Config to initialize the base model architecture
    qwen_config = Qwen2Config(
        vocab_size=diff_config.vocab_size,
        hidden_size=diff_config.hidden_size,
        intermediate_size=diff_config.intermediate_size,
        num_hidden_layers=diff_config.num_hidden_layers,
        num_attention_heads=diff_config.num_attention_heads,
        num_key_value_heads=diff_config.num_key_value_heads,
        max_position_embeddings=diff_config.max_position_embeddings,
        rms_norm_eps=diff_config.rms_norm_eps,
        rope_theta=diff_config.rope_theta,
        hidden_act=diff_config.hidden_act,
        attention_dropout=diff_config.attention_dropout,
        use_sliding_window=False,
        pad_token_id=diff_config.pad_token_id,
        bos_token_id=diff_config.bos_token_id,
        eos_token_id=diff_config.eos_token_id,
    )

    # Create DiffusionQwen3Model with the proper architecture
    model = DiffusionQwen3Model(diff_config)

    # Initialize the base Qwen2 model architecture
    print("  Initializing model architecture...")
    base_model = Qwen2ForCausalLM(qwen_config)
    model._init_from_qwen(base_model)
    del base_model  # Free memory

    # Load state dict
    weights_path = os.path.join(checkpoint_path, "pytorch_model.bin")
    if not os.path.exists(weights_path):
        # Fall back to model.safetensors
        weights_path = os.path.join(checkpoint_path, "model.safetensors")

    print(f"  Loading weights from {weights_path}...")
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

    # Handle potential key mismatches
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"  Warning: Missing keys ({len(missing)}): {missing[:3]}{'...' if len(missing) > 3 else ''}")
    if unexpected:
        print(f"  Warning: Unexpected keys ({len(unexpected)}): {unexpected[:3]}{'...' if len(unexpected) > 3 else ''}")

    # Move to device and set eval mode
    model = model.to(device=device, dtype=dtype)
    model.eval()

    # Disable causal attention for bidirectional context
    model._disable_causal_masking()

    # Load tokenizer
    print("  Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)

    # Ensure the mask token is set
    if tokenizer.mask_token_id is None:
        tokenizer.mask_token_id = config_dict.get("mask_token_id", 151665)

    print("  Model loaded successfully!")
    print(f"  Vocab size: {diff_config.vocab_size}")
    print(f"  Hidden size: {diff_config.hidden_size}")
    print(f"  Num layers: {diff_config.num_hidden_layers}")
    print(f"  Mask token ID: {diff_config.mask_token_id}")

    return model, tokenizer, config_dict


# ============================================================================
# Generation Wrapper
# ============================================================================

def generate(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: float = 0.1,
    verbose: bool = False,
) -> str:
    """
    Generate text from a prompt.

    Args:
        model: The diffusion language model
        tokenizer: The tokenizer
        prompt: Input prompt text
        max_new_tokens: Maximum tokens to generate
        steps: Diffusion steps
        temperature: Sampling temperature
        top_p: Nucleus sampling threshold
        top_k: Top-k sampling threshold
        alg: Sampling algorithm
        alg_temp: Algorithm temperature
        verbose: Print progress

    Returns:
        Generated text (prompt + completion)
    """
    device = next(model.parameters()).device

    # Tokenize prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Get mask token ID
    mask_token_id = getattr(model.config, "mask_token_id", tokenizer.mask_token_id)
    if mask_token_id is None:
        mask_token_id = 151665  # Default from config

    # Generate
    output_ids = diffusion_generate(
        model=model,
        input_ids=input_ids,
        mask_token_id=mask_token_id,
        max_new_tokens=max_new_tokens,
        steps=steps,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        alg=alg,
        alg_temp=alg_temp,
        verbose=verbose,
    )

    # Filter out mask and pad tokens
    output_ids = output_ids[0]  # Remove batch dimension
    pad_token_id = tokenizer.pad_token_id or 151643
    output_ids = output_ids[output_ids != mask_token_id]
    output_ids = output_ids[output_ids != pad_token_id]

    # Decode
    generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)

    return generated_text


def chat_generate(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    messages: List[dict],
    max_new_tokens: int = 256,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: float = 0.1,
    verbose: bool = False,
) -> str:
    """
    Generate a chat response from conversation history.

    Args:
        model: The diffusion language model
        tokenizer: The tokenizer
        messages: List of message dicts with 'role' and 'content'
        Other args: Same as generate()

    Returns:
        Assistant response text
    """
    device = next(model.parameters()).device

    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]

    # Get mask token ID
    mask_token_id = getattr(model.config, "mask_token_id", tokenizer.mask_token_id)
    if mask_token_id is None:
        mask_token_id = 151665

    # Generate
    output_ids = diffusion_generate(
        model=model,
        input_ids=input_ids,
        mask_token_id=mask_token_id,
        max_new_tokens=max_new_tokens,
        steps=steps,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        alg=alg,
        alg_temp=alg_temp,
        verbose=verbose,
    )

    # Keep only the generated tokens (after the prompt)
    generated_ids = output_ids[0, prompt_len:]

    # Filter out mask and pad tokens
    pad_token_id = tokenizer.pad_token_id or 151643
    generated_ids = generated_ids[generated_ids != mask_token_id]
    generated_ids = generated_ids[generated_ids != pad_token_id]

    # Decode
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return response


# ============================================================================
# Interactive Chat
# ============================================================================

def interactive_chat(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    system_prompt: str = "You are a helpful assistant.",
    **gen_kwargs,
):
    """Run an interactive chat session."""
    print("\n" + "=" * 60)
    print("Interactive Chat Mode")
    print("=" * 60)
    print("Commands:")
    print("  /exit or /quit  - Exit the chat")
    print("  /reset          - Reset conversation history")
    print("  /system <text>  - Set a new system prompt")
    print("=" * 60 + "\n")

    messages = [{"role": "system", "content": system_prompt}]

    while True:
        try:
            user_input = input("\033[92mYou: \033[0m").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

        if not user_input:
            continue

        # Handle commands
        if user_input.lower() in ["/exit", "/quit"]:
            print("Goodbye!")
            break

        if user_input.lower() == "/reset":
            messages = [{"role": "system", "content": system_prompt}]
            print("\033[90mConversation reset.\033[0m")
            continue

        if user_input.lower().startswith("/system "):
            system_prompt = user_input[8:].strip()
            messages = [{"role": "system", "content": system_prompt}]
            print("\033[90mSystem prompt updated.\033[0m")
            continue

        # Add user message
        messages.append({"role": "user", "content": user_input})

        # Generate response
        print("\033[94mAssistant: \033[0m", end="", flush=True)
        try:
            response = chat_generate(
                model=model,
                tokenizer=tokenizer,
                messages=messages,
                **gen_kwargs,
            )
            print(response)
            messages.append({"role": "assistant", "content": response})
        except Exception as e:
            print(f"\033[91mError: {e}\033[0m")
            messages.pop()  # Remove the failed user message


# ============================================================================
# Main
# ============================================================================

def main():
    parser = argparse.ArgumentParser(
        description="Run inference with the DiffusionQwen3 model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Model arguments
    parser.add_argument(
        "--checkpoint", "-c",
        type=str,
        default="./outputs/pretrain/checkpoint-1000",
        help="Path to model checkpoint directory",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cuda", "cpu"],
        help="Device to run on",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model data type",
    )

    # Generation mode
    parser.add_argument(
        "--mode", "-m",
        type=str,
        default="prompt",
        choices=["prompt", "chat"],
        help="Generation mode: 'prompt' for single completion, 'chat' for interactive",
    )
    parser.add_argument(
        "--prompt", "-p",
        type=str,
        default=None,
        help="Input prompt for single completion mode",
    )
    parser.add_argument(
        "--system",
        type=str,
        default="You are a helpful assistant.",
        help="System prompt for chat mode",
    )

    # Generation parameters
    parser.add_argument("--max-tokens", type=int, default=256, help="Max tokens to generate")
    parser.add_argument("--steps", type=int, default=128, help="Diffusion steps")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    parser.add_argument("--top-p", type=float, default=None, help="Nucleus sampling threshold")
    parser.add_argument("--top-k", type=int, default=None, help="Top-k sampling")
    parser.add_argument(
        "--alg",
        type=str,
        default="entropy",
        choices=["origin", "entropy", "maskgit_plus", "topk_margin"],
        help="Diffusion sampling algorithm",
    )
    parser.add_argument("--alg-temp", type=float, default=0.1, help="Algorithm temperature")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")

    args = parser.parse_args()

    # Load model
    model, tokenizer, config = load_model_and_tokenizer(
        args.checkpoint,
        device=args.device,
        torch_dtype=args.dtype,
    )

    # Generation kwargs
    gen_kwargs = {
        "max_new_tokens": args.max_tokens,
        "steps": args.steps,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "alg": args.alg,
        "alg_temp": args.alg_temp,
        "verbose": args.verbose,
    }

    if args.mode == "chat":
        interactive_chat(model, tokenizer, system_prompt=args.system, **gen_kwargs)
    else:
        # Single prompt mode
        if args.prompt is None:
            # Default demo prompts
            prompts = [
                "def fibonacci(n):",
                "Write a Python function to check if a number is prime:",
                "# Calculate the factorial of a number\ndef factorial(n):",
            ]
            print("\nNo prompt provided. Running demo with sample prompts...\n")
674
+ for prompt in prompts:
675
+ print("=" * 60)
676
+ print(f"Prompt: {prompt}")
677
+ print("-" * 60)
678
+ result = generate(model, tokenizer, prompt, **gen_kwargs)
679
+ print(f"Generated:\n{result}")
680
+ print("=" * 60 + "\n")
681
+ else:
682
+ result = generate(model, tokenizer, args.prompt, **gen_kwargs)
683
+ print("\n" + "=" * 60)
684
+ print("Generated:")
685
+ print("=" * 60)
686
+ print(result)
687
+ print("=" * 60)
688
+
689
+
690
+ if __name__ == "__main__":
691
+ main()
692
+
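The `--steps` and `--alg` flags above control how many masked positions are revealed per pass and in what order. A toy sketch of confidence-ordered unmasking (pure Python; the `predictions` table stands in for real model logits and is entirely hypothetical):

```python
MASK = "<mask>"

def toy_unmask(tokens, predictions, steps):
    """Reveal the highest-confidence masked position at each step."""
    tokens = list(tokens)
    for _ in range(steps):
        masked = [i for i, tok in enumerate(tokens) if tok == MASK]
        if not masked:
            break
        # the position the stand-in "model" is most confident about
        best = max(masked, key=lambda i: predictions[i][1])
        tokens[best] = predictions[best][0]
    return tokens

# predictions[i] = (predicted_token, confidence) for position i
preds = [("def", 0.9), ("fib", 0.6), ("(n):", 0.8)]
print(toy_unmask([MASK, MASK, MASK], preds, steps=3))
# -> ['def', 'fib', '(n):']  (revealed in confidence order 0, 2, 1)
```

The real sampler re-runs the model after every step so confidences update as context fills in; this sketch keeps them fixed for brevity.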
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_diffusion_qwen3.py ADDED
@@ -0,0 +1,515 @@
1
+ """
2
+ DiffusionQwen3 Model - Converts Qwen3-1.7B AR to Bidirectional Diffusion LLM
3
+
4
+ This module provides:
5
+ 1. DiffusionQwen3Config - Configuration for diffusion-adapted Qwen3
6
+ 2. DiffusionQwen3Model - The main model class with diffusion training/inference
7
+
8
+ Based on CoDA (Coding LM via Diffusion Adaptation) by Salesforce AI Research
9
+ https://arxiv.org/abs/2510.03270
10
+
11
+ CRITICAL: Loss normalization matches CoDA official implementation exactly:
12
+ loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
13
+ NOT dividing by num_masked (which causes gradient explosion)
14
+ """
15
+
16
+ import math
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple, Union, List, Dict, Any
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from transformers import PreTrainedModel, PretrainedConfig
24
+ from transformers import Qwen2ForCausalLM, Qwen2Config, AutoTokenizer
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+
27
+
28
+ @dataclass
29
+ class DiffusionQwen3Config(PretrainedConfig):
30
+ """Configuration for Diffusion-adapted Qwen3 model."""
31
+
32
+ model_type = "diffusion_qwen3"
33
+
34
+ def __init__(
35
+ self,
36
+ # Base Qwen3 config
37
+ vocab_size: int = 151936,
38
+ hidden_size: int = 2048,
39
+ intermediate_size: int = 6144,
40
+ num_hidden_layers: int = 28,
41
+ num_attention_heads: int = 16,
42
+ num_key_value_heads: int = 8,
43
+ head_dim: int = 128,
44
+ max_position_embeddings: int = 40960,
45
+ rms_norm_eps: float = 1e-6,
46
+ rope_theta: float = 1000000.0,
47
+ hidden_act: str = "silu",
48
+ attention_dropout: float = 0.0,
49
+ attention_bias: bool = False,
50
+ tie_word_embeddings: bool = True,
51
+
52
+ # Diffusion-specific config
53
+ mask_token_id: int = 151669,
54
+ pad_token_id: int = 151643,
55
+ bos_token_id: int = 151643,
56
+ eos_token_id: int = 151645,
57
+
58
+ # Diffusion training parameters
59
+ sampling_eps: float = 0.001, # CoDA default: creates 1/t in [1, 1000]
60
+ mask_block_sizes: List[int] = None,
61
+ block_masking_probability: float = 0.01,
62
+ prefix_probability: float = 0.01,
63
+ truncate_probability: float = 0.01,
64
+
65
+ **kwargs
66
+ ):
67
+ super().__init__(
68
+ pad_token_id=pad_token_id,
69
+ bos_token_id=bos_token_id,
70
+ eos_token_id=eos_token_id,
71
+ tie_word_embeddings=tie_word_embeddings,
72
+ **kwargs
73
+ )
74
+
75
+ # Base model config
76
+ self.vocab_size = vocab_size
77
+ self.hidden_size = hidden_size
78
+ self.intermediate_size = intermediate_size
79
+ self.num_hidden_layers = num_hidden_layers
80
+ self.num_attention_heads = num_attention_heads
81
+ self.num_key_value_heads = num_key_value_heads
82
+ self.head_dim = head_dim
83
+ self.max_position_embeddings = max_position_embeddings
84
+ self.rms_norm_eps = rms_norm_eps
85
+ self.rope_theta = rope_theta
86
+ self.hidden_act = hidden_act
87
+ self.attention_dropout = attention_dropout
88
+ self.attention_bias = attention_bias
89
+
90
+ # Diffusion config
91
+ self.mask_token_id = mask_token_id
92
+ self.sampling_eps = sampling_eps
93
+ self.mask_block_sizes = mask_block_sizes or [2, 4, 8]
94
+ self.block_masking_probability = block_masking_probability
95
+ self.prefix_probability = prefix_probability
96
+ self.truncate_probability = truncate_probability
97
+
98
+
99
+ class DiffusionQwen3Model(PreTrainedModel):
100
+ """
101
+ Qwen3 model adapted for discrete diffusion language modeling.
102
+
103
+ Key modifications from standard Qwen3:
104
+ 1. Bidirectional attention (is_causal=False)
105
+ 2. Masked diffusion training objective
106
+ 3. Loss weighted by 1/t (inverse noise level)
107
+ 4. Support for progressive masking (S1/S2/S3)
108
+
109
+ CRITICAL: Loss normalization follows CoDA exactly (line 524 of modeling.py):
110
+ loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
111
+ """
112
+
113
+ config_class = DiffusionQwen3Config
114
+ base_model_prefix = "model"
115
+ supports_gradient_checkpointing = True
116
+ _no_split_modules = ["Qwen2DecoderLayer"]
117
+ _supports_flash_attn_2 = True
118
+ _supports_sdpa = True
119
+
120
+ def __init__(self, config: DiffusionQwen3Config):
121
+ super().__init__(config)
122
+ self.config = config
123
+
124
+ # Initialize the base Qwen2 model (Qwen3 uses Qwen2 architecture in transformers)
125
+ # These are populated by _init_from_qwen, called from from_pretrained_qwen
126
+ self.model = None
127
+ self.lm_head = None
128
+ self.embed_tokens = None
129
+
130
+ # Diffusion parameters
131
+ self.mask_token_id = config.mask_token_id
132
+ self.sampling_eps = config.sampling_eps
133
+
134
+ # Loss function
135
+ self.loss_fn = nn.CrossEntropyLoss(reduction='none')
136
+
137
+ def _init_from_qwen(self, qwen_model: Qwen2ForCausalLM):
138
+ """Initialize from a pretrained Qwen model."""
139
+ # Extract the base model and lm_head
140
+ self.model = qwen_model.model
141
+ self.lm_head = qwen_model.lm_head
142
+ self.embed_tokens = self.model.embed_tokens
143
+
144
+ # Disable causal masking in all attention layers
145
+ self._disable_causal_masking()
146
+
147
+ def _disable_causal_masking(self):
148
+ """Disable causal attention masks for bidirectional attention."""
149
+ for layer in self.model.layers:
150
+ if hasattr(layer.self_attn, 'is_causal'):
151
+ layer.self_attn.is_causal = False
152
+
153
+ def get_input_embeddings(self):
154
+ return self.embed_tokens
155
+
156
+ def set_input_embeddings(self, value):
157
+ self.embed_tokens = value
158
+ self.model.embed_tokens = value
159
+
160
+ def get_output_embeddings(self):
161
+ return self.lm_head
162
+
163
+ def set_output_embeddings(self, new_embeddings):
164
+ self.lm_head = new_embeddings
165
+
166
+ def get_embeds(self, input_ids: torch.LongTensor) -> torch.Tensor:
167
+ """Get token embeddings."""
168
+ return self.embed_tokens(input_ids)
169
+
170
+ def transition(
171
+ self,
172
+ x_0: torch.LongTensor,
173
+ sigma: torch.Tensor,
174
+ maskable_mask: torch.BoolTensor,
175
+ mask_block_size: int = 1,
176
+ ) -> torch.LongTensor:
177
+ """
178
+ Apply noise transition: mask tokens with probability sigma.
179
+
180
+ Args:
181
+ x_0: Original token IDs [batch_size, seq_len]
182
+ sigma: Noise level per sample [batch_size, 1] or [batch_size]
183
+ maskable_mask: Boolean mask of which positions can be masked [batch_size, seq_len]
184
+ mask_block_size: Size of contiguous blocks to mask (1 for individual tokens)
185
+
186
+ Returns:
187
+ x_t: Noisy token IDs with some tokens replaced by mask_token_id
188
+ """
189
+ if sigma.dim() == 1:
190
+ sigma = sigma.unsqueeze(-1)
191
+
192
+ if mask_block_size == 1:
193
+ # Standard per-token masking
194
+ move_indices = (torch.rand_like(x_0, dtype=torch.float) < sigma) & maskable_mask
195
+ x_t = torch.where(move_indices, self.mask_token_id, x_0)
196
+ else:
197
+ # Block masking
198
+ x_t = self._block_masking(x_0, sigma, maskable_mask, mask_block_size)
199
+
200
+ return x_t
201
+
202
+ def _block_masking(
203
+ self,
204
+ x_0: torch.LongTensor,
205
+ sigma: torch.Tensor,
206
+ maskable_mask: torch.BoolTensor,
207
+ mask_block_size: int,
208
+ ) -> torch.LongTensor:
209
+ """Apply block masking for contiguous spans."""
210
+ batch_size, seq_len = x_0.shape
211
+
212
+ if seq_len < mask_block_size:
213
+ return x_0
214
+
215
+ # Calculate number of possible block positions
216
+ num_windows = seq_len - mask_block_size + 1
217
+
218
+ # Create all possible block positions
219
+ window_starts = torch.arange(num_windows, device=x_0.device)
220
+ block_offsets = torch.arange(mask_block_size, device=x_0.device)
221
+ all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0)
222
+
223
+ # Check which blocks are fully maskable
224
+ maskable_blocks = maskable_mask.unsqueeze(1).expand(-1, num_windows, -1)
225
+ maskable_blocks = maskable_blocks.gather(2, all_positions.unsqueeze(0).expand(batch_size, -1, -1))
226
+ fully_maskable = maskable_blocks.all(dim=2)
227
+
228
+ # Scale sigma for block masking (CoDA line 569)
229
+ effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)
230
+
231
+ # Determine which blocks to mask
232
+ should_mask = (torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma) & fully_maskable
233
+
234
+ # Create final mask
235
+ position_indices = torch.arange(seq_len, device=x_0.device).unsqueeze(0).unsqueeze(0)
236
+ all_positions_expanded = all_positions.unsqueeze(0)
237
+ should_mask_expanded = should_mask.unsqueeze(2)
238
+
239
+ position_matches = (position_indices == all_positions_expanded.unsqueeze(3)).any(dim=2)
240
+ should_mask_positions = should_mask_expanded & position_matches
241
+ final_mask = should_mask_positions.any(dim=1)
242
+
243
+ return torch.where(final_mask, self.mask_token_id, x_0)
244
+
245
+ def forward(
246
+ self,
247
+ input_ids: torch.LongTensor,
248
+ attention_mask: Optional[torch.Tensor] = None,
249
+ labels: Optional[torch.LongTensor] = None,
250
+ src_mask: Optional[torch.BoolTensor] = None,
251
+ training_mode: str = "pretrain",
252
+ masking_schedule: Optional[Dict[str, Any]] = None,
253
+ epoch: Optional[int] = None,
254
+ return_logits_only: bool = False,
255
+ **kwargs,
256
+ ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], CausalLMOutputWithPast]:
257
+ """
258
+ Forward pass with diffusion training.
259
+
260
+ Args:
261
+ input_ids: Input token IDs [batch_size, seq_len]
262
+ attention_mask: Attention mask [batch_size, seq_len]
263
+ labels: Target labels (same as input_ids for diffusion)
264
+ src_mask: Source mask for SFT (True = prompt, False = response)
265
+ training_mode: "pretrain", "midtrain", or "sft"
266
+ masking_schedule: Optional override for masking probabilities
267
+ epoch: Current epoch for progressive masking
268
+ return_logits_only: If True, skip diffusion training logic (used by trainer)
269
+
270
+ Returns:
271
+ logits: Model predictions [batch_size, seq_len, vocab_size]
272
+ loss: Diffusion loss (if training and not return_logits_only)
273
+ """
274
+ if not self.training or return_logits_only:
275
+ # Inference mode OR trainer is handling diffusion logic
276
+ hidden_states = self.model(
277
+ input_ids=input_ids,
278
+ attention_mask=attention_mask,
279
+ ).last_hidden_state
280
+ logits = self.lm_head(hidden_states)
281
+ return CausalLMOutputWithPast(logits=logits, loss=None)
282
+
283
+ # Training mode
284
+ batch_size, seq_len = input_ids.shape
285
+
286
+ # Get masking configuration
287
+ if masking_schedule is not None:
288
+ prefix_prob = masking_schedule.get("prefix_probability", 0)
289
+ truncate_prob = masking_schedule.get("truncate_probability", 0)
290
+ block_prob = masking_schedule.get("block_masking_probability", 0)
291
+ mask_block_sizes = masking_schedule.get("mask_block_sizes", self.config.mask_block_sizes)
292
+ else:
293
+ prefix_prob = self.config.prefix_probability
294
+ truncate_prob = self.config.truncate_probability
295
+ block_prob = self.config.block_masking_probability
296
+ mask_block_sizes = self.config.mask_block_sizes
297
+
298
+ # Create maskable_mask based on training mode
299
+ if src_mask is not None:
300
+ # SFT mode: only mask response tokens
301
+ maskable_mask = ~src_mask
302
+ else:
303
+ # Pre-training/mid-training: all tokens maskable
304
+ maskable_mask = torch.ones_like(input_ids, dtype=torch.bool)
305
+
306
+ # Apply S1: Unmaskable prefix
307
+ if prefix_prob > 0:
308
+ maskable_mask = self._apply_prefix_masking(
309
+ input_ids, maskable_mask, prefix_prob
310
+ )
311
+
312
+ # Apply S2: Truncated suffix
313
+ if truncate_prob > 0:
314
+ input_ids, maskable_mask = self._apply_truncate_masking(
315
+ input_ids, maskable_mask, truncate_prob
316
+ )
317
+
318
+ # Sample timesteps and compute sigma
319
+ # CoDA line 475: sigma = (1 - sampling_eps) * rand + sampling_eps
320
+ sampling_eps = self.config.sampling_eps
321
+ t = (1 - sampling_eps) * torch.rand(batch_size, device=input_ids.device) + sampling_eps
322
+ sigma = t
323
+ # CoDA line 476: dsigma = 1 / sigma (for loss weighting)
324
+ dsigma = torch.reciprocal(t)
325
+
326
+ # Select block masking size
327
+ if block_prob > 0 and mask_block_sizes and torch.rand(1).item() < block_prob:
328
+ mask_block_size = mask_block_sizes[torch.randint(len(mask_block_sizes), (1,)).item()]
329
+ else:
330
+ mask_block_size = 1
331
+
332
+ # Apply noise transition
333
+ noisy_input_ids = self.transition(
334
+ input_ids, sigma, maskable_mask, mask_block_size
335
+ )
336
+
337
+ # Track which positions are masked (for loss computation)
338
+ loss_mask = (noisy_input_ids == self.mask_token_id)
339
+
340
+ # Forward pass through model
341
+ hidden_states = self.model(
342
+ input_ids=noisy_input_ids,
343
+ attention_mask=attention_mask,
344
+ ).last_hidden_state
345
+
346
+ logits = self.lm_head(hidden_states)
347
+ logits = logits.float()
348
+
349
+ # =================================================================
350
+ # LOSS COMPUTATION - MATCHES CODA EXACTLY (modeling.py lines 509-524)
351
+ # =================================================================
352
+ # Shift for next-token prediction
353
+ # logits: [batch, seq_len-1, vocab_size]
354
+ # labels: [batch, seq_len-1]
355
+ shift_logits = logits[..., :-1, :].contiguous()
356
+ shift_labels = input_ids[..., 1:].contiguous()
357
+ shift_loss_mask = loss_mask[..., 1:].contiguous()
358
+
359
+ # Cross-entropy loss per token
360
+ loss = self.loss_fn(
361
+ shift_logits.view(-1, self.config.vocab_size),
362
+ shift_labels.view(-1)
363
+ ).view(batch_size, -1)
364
+
365
+ # Zero out loss for non-masked positions
366
+ loss = loss.masked_fill(~shift_loss_mask, 0)
367
+
368
+ # =================================================================
369
+ # CRITICAL: CoDA normalization (line 524)
370
+ # Divide by (batch_size * seq_len), NOT by num_masked!
371
+ # This gives stable gradients regardless of mask ratio
372
+ # =================================================================
373
+ # loss = (dsigma[:, None] * loss).sum() / (batch_size * seq_len)
374
+ loss = (dsigma.unsqueeze(-1) * loss).sum() / (batch_size * seq_len)
375
+
376
+ return logits, loss
377
+
378
+ def _apply_prefix_masking(
379
+ self,
380
+ input_ids: torch.LongTensor,
381
+ maskable_mask: torch.BoolTensor,
382
+ prefix_prob: float,
383
+ ) -> torch.BoolTensor:
384
+ """Apply S1: Random unmaskable prefix."""
385
+ batch_size, seq_len = input_ids.shape
386
+
387
+ # Randomly decide which samples get prefix
388
+ apply_prefix = torch.rand(batch_size, device=input_ids.device) < prefix_prob
389
+
390
+ # Generate random prefix lengths
391
+ prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
392
+
393
+ # Create position indices
394
+ positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
395
+
396
+ # Create prefix mask
397
+ prefix_mask = positions < prefix_lengths.unsqueeze(1)
398
+
399
+ # Apply: set maskable_mask to False for prefix positions
400
+ maskable_mask = maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask)
401
+
402
+ return maskable_mask
403
+
404
+ def _apply_truncate_masking(
405
+ self,
406
+ input_ids: torch.LongTensor,
407
+ maskable_mask: torch.BoolTensor,
408
+ truncate_prob: float,
409
+ ) -> Tuple[torch.LongTensor, torch.BoolTensor]:
410
+ """Apply S2: Random truncated suffix."""
411
+ batch_size, seq_len = input_ids.shape
412
+
413
+ # Randomly decide which samples get truncated
414
+ apply_truncate = torch.rand(batch_size, device=input_ids.device) < truncate_prob
415
+
416
+ # Generate random truncation positions
417
+ truncate_positions = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
418
+
419
+ # Create position indices
420
+ positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
421
+
422
+ # Create truncate mask
423
+ truncate_mask = positions >= truncate_positions.unsqueeze(1)
424
+
425
+ # Apply: replace with pad token and update maskable_mask
426
+ input_ids = torch.where(
427
+ apply_truncate.unsqueeze(1) & truncate_mask,
428
+ self.config.pad_token_id,
429
+ input_ids
430
+ )
431
+ maskable_mask = maskable_mask & (input_ids != self.config.pad_token_id)
432
+
433
+ return input_ids, maskable_mask
434
+
435
+ @classmethod
436
+ def from_pretrained_qwen(
437
+ cls,
438
+ pretrained_model_name_or_path: str = "Qwen/Qwen3-1.7B",
439
+ config: Optional[DiffusionQwen3Config] = None,
440
+ **kwargs
441
+ ) -> "DiffusionQwen3Model":
442
+ """
443
+ Load from a pretrained Qwen3 model and convert to diffusion.
444
+
445
+ Args:
446
+ pretrained_model_name_or_path: HuggingFace model name or path
447
+ config: Optional DiffusionQwen3Config override
448
+ **kwargs: Additional arguments for from_pretrained
449
+
450
+ Returns:
451
+ DiffusionQwen3Model ready for diffusion training
452
+ """
453
+ # Load the base Qwen model
454
+ print(f"Loading base model from {pretrained_model_name_or_path}...")
455
+
456
+ qwen_model = Qwen2ForCausalLM.from_pretrained(
457
+ pretrained_model_name_or_path,
458
+ torch_dtype=kwargs.pop("torch_dtype", torch.bfloat16),
459
+ attn_implementation=kwargs.pop("attn_implementation", "flash_attention_2"),
460
+ **kwargs
461
+ )
462
+
463
+ # Create diffusion config if not provided
464
+ if config is None:
465
+ qwen_config = qwen_model.config
466
+ config = DiffusionQwen3Config(
467
+ vocab_size=qwen_config.vocab_size,
468
+ hidden_size=qwen_config.hidden_size,
469
+ intermediate_size=qwen_config.intermediate_size,
470
+ num_hidden_layers=qwen_config.num_hidden_layers,
471
+ num_attention_heads=qwen_config.num_attention_heads,
472
+ num_key_value_heads=qwen_config.num_key_value_heads,
473
+ max_position_embeddings=qwen_config.max_position_embeddings,
474
+ rms_norm_eps=qwen_config.rms_norm_eps,
475
+ rope_theta=qwen_config.rope_theta,
476
+ )
477
+
478
+ # Create diffusion model and initialize from Qwen
479
+ model = cls(config)
480
+ model._init_from_qwen(qwen_model)
481
+
482
+ print("Converted to DiffusionQwen3Model with bidirectional attention")
483
+ print(f" - Mask token ID: {config.mask_token_id}")
484
+ print(f" - Vocab size: {config.vocab_size}")
485
+ print(f" - Hidden size: {config.hidden_size}")
486
+ print(f" - Num layers: {config.num_hidden_layers}")
487
+
488
+ return model
489
+
490
+
491
+ def prepare_tokenizer(tokenizer_name: str = "Qwen/Qwen3-1.7B") -> AutoTokenizer:
492
+ """
493
+ Prepare tokenizer with mask token for diffusion training.
494
+
495
+ Args:
496
+ tokenizer_name: HuggingFace tokenizer name
497
+
498
+ Returns:
499
+ Tokenizer with mask token added
500
+ """
501
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
502
+
503
+ # Check if mask token already exists
504
+ if tokenizer.mask_token is None:
505
+ # Add mask token (CoDA uses ID 151669)
506
+ tokenizer.add_tokens("<|mask|>", special_tokens=True)
507
+ tokenizer.add_special_tokens(
508
+ {"mask_token": "<|mask|>"},
509
+ replace_additional_special_tokens=False
510
+ )
511
+ print(f"Added mask token: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")
512
+ else:
513
+ print(f"Mask token already exists: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")
514
+
515
+ return tokenizer
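The module docstring above calls the loss normalization critical. A tiny numeric sketch (pure Python, hypothetical constant losses) shows why dividing by `batch_size * seq_len` is stabler than dividing by the number of masked tokens once the 1/t weight is applied:

```python
def coda_norm(per_token_loss, mask, t, seq_len):
    """CoDA-style: weight masked losses by 1/t, divide by total token count."""
    dsigma = 1.0 / t
    total = sum(l for l, m in zip(per_token_loss, mask) if m)
    return dsigma * total / seq_len

def per_mask_norm(per_token_loss, mask, t):
    """Naive alternative: divide by the number of masked tokens."""
    dsigma = 1.0 / t
    masked = [l for l, m in zip(per_token_loss, mask) if m]
    return dsigma * sum(masked) / len(masked)

losses = [2.0] * 100                   # constant per-token CE loss
t = 0.01                               # low noise level
mask = [i == 0 for i in range(100)]    # ~t fraction masked: 1 of 100

print(coda_norm(losses, mask, t, 100))   # 1/t * 2.0 / 100  ->  ~2.0
print(per_mask_norm(losses, mask, t))    # 1/t * 2.0 / 1    ->  ~200.0
```

Under the CoDA rule the 1/t weight is cancelled in expectation by the ~t fraction of masked tokens, so the loss stays on the order of the per-token cross-entropy; per-mask normalization keeps the bare 1/t factor and blows up as t approaches 0, which is the gradient explosion the docstring warns about.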
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47e6306d5cb44f8ea9da0ab55d9f13b581cf8306205bb4c9cb71039ce923c4c3
3
+ size 3086713515
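`_block_masking` in modeling_diffusion_qwen3.py above rescales the noise level to `1 - (1 - sigma) ** (1 / k)` before drawing one Bernoulli per window. A quick numeric check of the intent (hypothetical values; this ignores boundary positions covered by fewer than k windows):

```python
def effective_sigma(sigma, k):
    """Per-window masking rate so a position under k windows masks at ~sigma."""
    return 1 - (1 - sigma) ** (1 / k)

sigma, k = 0.3, 4
eff = effective_sigma(sigma, k)
# an interior position is covered by k windows; it survives only if all k
# independently decide not to mask, i.e. with probability (1 - eff) ** k
survive = (1 - eff) ** k
print(round(survive, 6))  # -> 0.7, i.e. 1 - sigma as intended
```

With k = 1 the formula reduces to `eff == sigma`, matching the per-token path.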
special_tokens_map.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "mask_token": {
25
+ "content": "<|mask|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ "pad_token": {
32
+ "content": "<|endoftext|>",
33
+ "lstrip": false,
34
+ "normalized": false,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ }
38
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a59820ad3f728fff77cf7e4188532fc45e5f80cd0299cde28046bd2b51c64bdf
3
+ size 11422081
tokenizer_config.json ADDED
@@ -0,0 +1,216 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<|mask|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ }
189
+ },
190
+ "additional_special_tokens": [
191
+ "<|im_start|>",
192
+ "<|im_end|>",
193
+ "<|object_ref_start|>",
194
+ "<|object_ref_end|>",
195
+ "<|box_start|>",
196
+ "<|box_end|>",
197
+ "<|quad_start|>",
198
+ "<|quad_end|>",
199
+ "<|vision_start|>",
200
+ "<|vision_end|>",
201
+ "<|vision_pad|>",
202
+ "<|image_pad|>",
203
+ "<|video_pad|>"
204
+ ],
205
+ "bos_token": null,
206
+ "clean_up_tokenization_spaces": false,
207
+ "eos_token": "<|im_end|>",
208
+ "errors": "replace",
209
+ "extra_special_tokens": {},
210
+ "mask_token": "<|mask|>",
211
+ "model_max_length": 131072,
212
+ "pad_token": "<|endoftext|>",
213
+ "split_special_tokens": false,
214
+ "tokenizer_class": "Qwen2Tokenizer",
215
+ "unk_token": null
216
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff