File size: 15,666 Bytes
8c2cc2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
from transformers import BatchEncoding, Pipeline
import torch
from typing import Any, Generator

class TextDiffusionPipeline(Pipeline):
    def _sanitize_parameters(
        self,
        num_steps: int | None = None,
        allow_edits: bool | None = None,
        use_confidence: bool | None = None,
        stop_token: str | None = None,
        **kwargs
    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
        """Route user keyword arguments to the preprocess / forward / postprocess stages.

        The base ``Pipeline`` calls this hook both at construction time and on
        every ``__call__`` and merges the call-time dictionaries *over* the
        construction-time ones. Only options that were explicitly supplied are
        therefore returned: emitting concrete defaults here would silently
        overwrite values configured when the pipeline was built. Options that
        are never supplied fall back to the defaults declared on ``_forward``.

        Args:
            num_steps: Number of diffusion steps, or ``None`` if not supplied.
            allow_edits: Whether already generated tokens may be resampled,
                or ``None`` if not supplied.
            use_confidence: Whether to unmask by confidence instead of at
                random, or ``None`` if not supplied.
            stop_token: Stop token forwarded to ``_forward``, or ``None`` if
                not supplied.
            **kwargs: Remaining options; only ``max_length`` is recognised and
                is routed to ``preprocess``.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple.
            ``postprocess_kwargs`` is always empty.
        """
        supplied = {
            "num_steps": num_steps,
            "allow_edits": allow_edits,
            "use_confidence": use_confidence,
            "stop_token": stop_token,
        }
        # ``None`` means "not given"; note that an explicit ``False`` is kept.
        forward_kwargs = {
            name: value for name, value in supplied.items() if value is not None
        }

        preprocess_kwargs = {}
        if "max_length" in kwargs:
            preprocess_kwargs["max_length"] = kwargs["max_length"]

        return preprocess_kwargs, forward_kwargs, {}
    
    def preprocess(self, input_text, max_length=None) -> BatchEncoding | Any:
        """Tokenize the prompt and pad it out to ``max_length`` positions.

        When the prompt is shorter than ``max_length`` the result is a single
        row whose leading positions hold the prompt token ids and whose
        remaining positions hold the tokenizer's mask token id; the attention
        mask is all ones. Otherwise the tokenizer itself is asked for a
        truncated, ``max_length``-padded encoding.

        Args:
            input_text: Prompt text; ``None`` is treated as the empty string.
            max_length: Target sequence length. When ``None`` it is read from
                ``self.model.config.seq_length``, falling back to 512.

        Returns:
            A ``BatchEncoding`` (or whatever the tokenizer call returns) with
            at least ``input_ids`` and ``attention_mask``.

        Raises:
            ValueError: If the pipeline has no tokenizer.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        # Only consult the model config when the caller gave no explicit length.
        target_len = (
            max_length
            if max_length is not None
            else getattr(self.model.config, "seq_length", 512)
        )
        prompt = "" if input_text is None else input_text
        prompt_ids = tokenizer.encode(prompt)

        # Prompt already fills (or overflows) the window: let the tokenizer
        # truncate and pad it.
        if len(prompt_ids) >= target_len:
            return tokenizer(
                prompt,
                return_tensors="pt",
                padding="max_length",
                max_length=target_len,
                truncation=True,
            )

        # Start from an all-mask canvas and write the prompt at the front.
        canvas = torch.full((1, target_len), tokenizer.mask_token_id, dtype=torch.long) # type: ignore
        canvas[0, :len(prompt_ids)] = torch.tensor(prompt_ids, dtype=torch.long)
        return BatchEncoding({
            "input_ids": canvas,
            "attention_mask": torch.ones_like(canvas)
        })
        
    @torch.no_grad()
    def diffusion_generator(
        self,
        input_ids: torch.Tensor,
        num_steps: int,
        allow_edits: bool = True,
        use_confidence: bool = False
    ) -> Generator[torch.Tensor, None, None]:
        """Iteratively replace mask/pad tokens with sampled tokens, yielding every state.

        A clone of the untouched input is yielded first ("step 0"), followed by
        one clone per diffusion step, i.e. ``num_steps + 1`` tensors in total.
        Only positions that held the tokenizer's mask or pad token in
        ``input_ids`` are ever rewritten; everything else (the prompt) is kept.

        Each step runs the model on the current state, samples one token per
        position from the softmax of the logits (with the mask token's logit
        pushed to the dtype minimum), and then decides which positions accept
        their sample:

        * still-masked positions are unmasked either at random with
          probability ``(t_current - t_next) / t_current`` or, when
          ``use_confidence`` is set, by taking the per-row top-k sampled-token
          probabilities; the final step unmasks everything that is left;
        * when ``allow_edits`` is set, already visible positions (not mask,
          pad or eos) that were generated by this call are additionally
          resampled with probability ``0.1 * (1 - step / num_steps)``.

        Args:
            input_ids: Token ids to start from. The code reduces over
                ``dim=1`` and indexes logits as ``[:, :, vocab]``, so a
                ``(batch, seq_len)`` layout is assumed — TODO confirm with
                callers.
            num_steps: Number of diffusion steps to run.
            allow_edits: Enable resampling of already generated tokens.
            use_confidence: Use confidence-based instead of random unmasking.

        Yields:
            A clone of the current token-id tensor after each step.

        Raises:
            ValueError: If the pipeline has no tokenizer.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")
        
        # Work on a copy so the caller's tensor is never modified.
        current_state: torch.Tensor = input_ids.clone()
        yield current_state.clone() # Yield Step 0
        
        # Positions that hold a mask or pad token in the input. Every update
        # below is AND-ed with this mask, so the original prompt is never touched.
        initial_mask = (current_state == self.tokenizer.mask_token_id) | \
                       (current_state == self.tokenizer.pad_token_id)
        
        for step in range(num_steps):
            # Time runs from 1 (first step) down to 0 (after the last step).
            t_current = 1 - step / num_steps
            t_next = 1 - (step + 1) / num_steps
            
            # Predict full text with model
            output = self.model(input_ids=current_state)
            logits = output.logits
            
            # Push the mask token's logit to the dtype minimum so it is
            # (effectively) never sampled. This writes into the model output in place.
            logits[:, :, self.tokenizer.mask_token_id] = torch.finfo(logits.dtype).min
            
            # Ancestral sampling logic: draw one candidate token for every position
            probs = torch.softmax(logits, dim=-1)
            dist = torch.distributions.Categorical(probs)
            sampled_ids = dist.sample()
            
            # Calculate Unmasking Probability (Equation 7 https://arxiv.org/pdf/2406.07524)
            # P(unmask | masked) = (alpha_s - alpha_t) / (1 - alpha_t)
            # mapping: alpha_t = (1 - t_current), alpha_s = (1 - t_next)
            # resulting simplified formula: (t_current - t_next) / t_current
            if step < num_steps - 1:
                unmasking_prob = (t_current - t_next) / t_current
            else:
                unmasking_prob = 1.0 # Force unmask at the end
            
            # Positions that are still mask/pad in the *current* state.
            remasking_mask: torch.Tensor = (current_state == self.tokenizer.mask_token_id) | \
                             (current_state == self.tokenizer.pad_token_id) # type: ignore
            
            if use_confidence:
                # Get the confidence (probability) of the tokens we just sampled
                sample_probs = probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)
                
                # Determine how many tokens to unmask this step (per row);
                # on the last step all remaining masked positions are taken.
                if step < num_steps - 1:
                    num_masked = remasking_mask.sum(dim=1, keepdim=True)
                    num_to_unmask = (num_masked.float() * unmasking_prob).ceil().long()
                else:
                    num_to_unmask = remasking_mask.sum(dim=1, keepdim=True)
                
                # Select Top-K most confident tokens
                # Set confidence of already visible tokens to -inf so they aren't picked
                candidate_confidences = sample_probs.clone()
                candidate_confidences[~remasking_mask] = -float('inf')
                
                unmasking_mask = torch.zeros_like(remasking_mask, dtype=torch.bool)
                
                # topk needs one k for the whole batch: take the largest per-row
                # quota, then keep only the first num_to_unmask hits of each row.
                max_k = num_to_unmask.max().item()
                if max_k > 0:
                    _, top_indices = candidate_confidences.topk(k=max_k, dim=1)
                    range_tensor = torch.arange(max_k, device=current_state.device).unsqueeze(0)
                    mask_k = range_tensor < num_to_unmask
                    unmasking_mask.scatter_(1, top_indices, mask_k)

            else:
                # Random Unmasking: each position independently with unmasking_prob
                unmasking_mask = torch.rand_like(current_state, dtype=torch.float) < unmasking_prob
                        
            # Accept the sample only where the position was selected, is still
            # masked, and was mask/pad in the input.
            update_mask = unmasking_mask & remasking_mask & initial_mask
            
            if allow_edits: # Apply Seed Diffusion Editing Logic (Section 3.1 in https://arxiv.org/pdf/2508.02193)
                alpha_t = 0.1 * (1 - step / num_steps)  # alpha_t decreases from 0.1 to 0 (Seed Diffusion)
                
                edit_mask = torch.rand_like(current_state, dtype=torch.float) < alpha_t
                
                # Visible = already holds a real token (not mask, pad or eos).
                is_visible = (current_state != self.tokenizer.mask_token_id) & \
                             (current_state != self.tokenizer.pad_token_id) & \
                             (current_state != self.tokenizer.eos_token_id)
                edit_mask = is_visible & edit_mask & initial_mask # Use initial_mask to avoid editing original prompt
                
                # Combine both masks
                update_mask = update_mask | edit_mask

            # Update current state
            current_state[update_mask] = sampled_ids[update_mask]
            
            yield current_state.clone() # Yield after each step
    
    @torch.no_grad()
    def _forward(
        self,
        model_inputs: BatchEncoding | dict[str, torch.Tensor],
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None
    ) -> dict[str, Any]:
        """Run the full diffusion process and collect every intermediate state.

        Args:
            model_inputs: Mapping produced by ``preprocess``; only its
                ``"input_ids"`` entry is read.
            num_steps: Number of diffusion steps.
            allow_edits: Whether already generated tokens may be resampled.
            use_confidence: Use confidence-based instead of random unmasking.
            stop_token: Accepted so the forward kwargs can carry it, but it is
                not applied here; the returned states are never truncated.

        Returns:
            A dict with ``"final_state"`` (the token ids after the last step)
            and ``"history"`` (the list of all yielded states, starting with
            the untouched input).

        Raises:
            ValueError: If the pipeline has no tokenizer.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        history = list(
            self.diffusion_generator(
                input_ids=model_inputs["input_ids"],
                num_steps=num_steps,
                allow_edits=allow_edits,
                use_confidence=use_confidence,
            )
        )
        # The generator always yields the initial state first, so the history
        # is never empty and indexing the last element is safe.
        return {"final_state": history[-1], "history": history}
    
    @torch.no_grad()
    def stream_generation(
        self,
        input_text: str,
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        max_length: int | None = None, 
        stop_token: str | None = None
    ) -> Generator[str, None, None]:
        """Stream the decoded text of every diffusion state.

        The prompt is preprocessed, moved to the model's device and fed to
        ``diffusion_generator``; the first row of each yielded state is decoded
        (special tokens kept) and yielded as a string.

        Once the last step has been emitted and ``stop_token`` is given, the
        part of the final text after the first ``len(input_text)`` characters
        is searched for ``stop_token``. If it is found, one extra string is
        yielded: ``input_text`` followed by that tail cut at the first
        occurrence of ``stop_token``.

        Args:
            input_text: The prompt.
            num_steps: Number of diffusion steps.
            allow_edits: Whether already generated tokens may be resampled.
            use_confidence: Use confidence-based instead of random unmasking.
            max_length: Sequence length handed to ``preprocess``.
            stop_token: Optional string at which the final text is cut.

        Yields:
            The decoded text after each step (and possibly the truncated text).
        """
        encoded = self.preprocess(input_text, max_length)
        start_ids = encoded["input_ids"].to(self.model.device) # type: ignore

        states = self.diffusion_generator(
            input_ids=start_ids,
            num_steps=num_steps,
            allow_edits=allow_edits,
            use_confidence=use_confidence,
        )
        for state in states:
            text = self.tokenizer.decode(state[0], skip_special_tokens=False) # type: ignore
            yield text

        if stop_token is None:
            return

        # Everything past the prompt's character length counts as generated text.
        continuation = text[len(input_text):]
        if stop_token in continuation:
            text = input_text + continuation.split(stop_token)[0]
            yield text
        
    def postprocess(self, model_outputs) -> list[str] | Any:
        """Decode the final token ids and bundle them with the step history.

        Args:
            model_outputs: Mapping with ``"final_state"`` and ``"history"``
                entries, as returned by ``_forward``.

        Returns:
            A dict with ``"decoded_texts"`` (the batch-decoded final state,
            special tokens kept), ``"history"`` (passed through unchanged) and
            ``"final_ids"`` (the final state itself).

        Raises:
            ValueError: If the pipeline has no tokenizer.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        final_ids = model_outputs["final_state"]
        decoded = tokenizer.batch_decode(final_ids, skip_special_tokens=False)
        return dict(
            decoded_texts=decoded,
            history=model_outputs["history"],
            final_ids=final_ids,
        )
        
    @torch.no_grad()
    def block_diffusion_generator(
        self, input_ids: torch.Tensor,
        block_size: int,
        max_length: int,
        num_steps: int,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None
    ) -> Generator[torch.Tensor, None, None]:
        """
        Generator that yields the diffusion states block-by-block.

        Each round appends a block of mask tokens to the most recent context
        (at most ``max_seq_length - block`` tokens, where ``max_seq_length`` is
        ``model.config.seq_length`` or 512), runs ``diffusion_generator`` on
        that window and yields the full sequence with the block's current
        tokens after every step. Nothing is yielded when ``input_ids`` already
        has ``max_length`` or more tokens.

        When ``stop_token`` is given and its id appears in a finished block,
        one more state is yielded with the block cut just before the first
        occurrence, and generation ends.

        Args:
            input_ids (torch.Tensor): Initial input IDs with context. Blocks
                are created with a single row, so a batch of one is assumed.
            block_size (int): Number of tokens to generate in each block.
            max_length (int): Max length of the generated text.
            num_steps (int): Number of diffusion steps per block.
            allow_edits (bool): Whether to allow edits to existing tokens.
            use_confidence (bool): Whether to use confidence-based unmasking.
            stop_token (str | None): Token at which to stop generation early.
        Yields:
            torch.Tensor: The current state of the full sequence after each diffusion step.
        Raises:
            AssertionError: If ``num_steps`` or ``block_size`` is out of range.
            ValueError: If the pipeline has no tokenizer.
        """
        assert num_steps > 0, "num_steps must be greater than 0"
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        max_seq_length = getattr(self.model.config, "seq_length", 512)
        stop_token_id = self.tokenizer.convert_tokens_to_ids(stop_token) if stop_token is not None else None

        assert block_size > 0 and block_size <= max_seq_length, f"block_size must be in (0, {max_seq_length}]"

        full_sequence = input_ids.clone()
        current_length = input_ids.shape[1]
        while current_length < max_length:
            # The loop condition guarantees at least one token is still missing.
            this_block_len = min(block_size, max_length - current_length)

            # Append MASK tokens for the new block
            mask_block = torch.full(
                (1, this_block_len),
                self.tokenizer.mask_token_id, # type: ignore
                dtype=torch.long,
                device=self.model.device
            )

            # Keep only as much trailing context as fits next to the block.
            # A plain ``[:, -context_len:]`` is wrong for ``context_len == 0``:
            # ``-0`` is ``0``, which would select the WHOLE sequence instead of
            # nothing and overflow the model's maximum sequence length.
            context_len = max_seq_length - this_block_len
            if context_len > 0:
                context = full_sequence[:, -context_len:]
            else:
                context = full_sequence[:, :0]

            # Combine Context + New Masks
            window = torch.cat([context, mask_block], dim=1)

            current_generated_tokens = mask_block
            for step_tensor in self.diffusion_generator(
                window,
                num_steps=num_steps,
                allow_edits=allow_edits,
                use_confidence=use_confidence
            ):
                current_generated_tokens = step_tensor[:, -this_block_len:]
                yield torch.cat([full_sequence, current_generated_tokens], dim=1)

            if stop_token_id is not None:
                # Column indices of every stop-token hit in the finished block.
                stop_columns = (current_generated_tokens == stop_token_id).nonzero(as_tuple=True)[1] # type: ignore
                if stop_columns.numel() > 0:
                    # Cut the block just before the first hit and finish.
                    cut = int(stop_columns[0])
                    yield torch.cat([full_sequence, current_generated_tokens[:, :cut]], dim=1)
                    break

            # Update full sequence and current length
            full_sequence = torch.cat([full_sequence, current_generated_tokens], dim=1)
            current_length = full_sequence.shape[1]
        
    
    @torch.no_grad()
    def semi_autoregressive_generate(
        self, 
        input_text: str, 
        block_size: int = 64, 
        max_length: int = 256, 
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None
    ) -> dict[str, Any]:
        """
        Semi-Autoregressive Generation:
        Generates text in blocks using the diffusion model.
        Each block is generated by appending MASK tokens to the current context
        and running the diffusion process on the combined sequence.

        If the encoded prompt already has ``max_length`` or more tokens, no
        block is generated: the history is empty and the prompt's own token
        IDs are returned as the final state.

        Args:
            input_text (str): The initial prompt text.
            block_size (int): Number of tokens to generate in each block.
            max_length (int): Max length of the generated text.
            num_steps (int): Number of diffusion steps per block.
            allow_edits (bool): Whether to allow edits to existing tokens.
            use_confidence (bool): Whether to use confidence-based unmasking.
            stop_token (str | None): Token at which to stop generation early.
        Returns:
            dict[str, Any]: A dictionary containing the decoded texts, generation history, and final token IDs.
        Raises:
            ValueError: If the pipeline has no tokenizer.
        """
        if self.tokenizer is None:
            raise ValueError("No tokenizer")

        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device) # type: ignore
        all_states = list(
            self.block_diffusion_generator(
                input_ids,
                block_size,
                max_length,
                num_steps,
                allow_edits=allow_edits,
                use_confidence=use_confidence,
                stop_token=stop_token,
            )
        )
        # The block generator yields nothing when the prompt leaves no room for
        # new tokens; fall back to the prompt instead of failing on [-1].
        final_state = all_states[-1] if all_states else input_ids
        return {
            "decoded_texts": self.tokenizer.batch_decode(final_state, skip_special_tokens=False),
            "history": all_states,
            "final_ids": final_state
        }
    
    @torch.no_grad()
    def stream_semi_autoregressive_generate(
        self, 
        input_text: str, 
        block_size: int = 64, 
        max_length: int = 256, 
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None
    ) -> Generator[str, None, None]:
        """
        Stream block-wise generation as decoded text.

        The prompt is encoded, moved to the model's device and handed to
        ``block_diffusion_generator``. The first row of every state it yields
        is decoded with special tokens kept.
        Args:
            input_text (str): The initial prompt text.
            block_size (int): Number of tokens to generate in each block.
            max_length (int): Max length of the generated text.
            num_steps (int): Number of diffusion steps per block.
            allow_edits (bool): Whether to allow edits to existing tokens.
            use_confidence (bool): Whether to use confidence-based unmasking.
            stop_token (str | None): Token at which to stop generation early.
        Yields:
            str: The current generated text after each diffusion step.
        Raises:
            ValueError: If the pipeline has no tokenizer.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("No tokenizer")

        prompt_ids = tokenizer.encode(input_text, return_tensors="pt").to(self.model.device) # type: ignore
        states = self.block_diffusion_generator(
            prompt_ids,
            block_size,
            max_length,
            num_steps,
            allow_edits,
            use_confidence=use_confidence,
            stop_token=stop_token,
        )
        for state in states:
            # Only the first row is decoded.
            yield tokenizer.decode(state[0], skip_special_tokens=False) # type: ignore