diffusionGPT / pipeline.py
JorgeVanco's picture
Upload folder using huggingface_hub
8c2cc2d verified
from transformers import BatchEncoding, Pipeline
import torch
from typing import Any, Generator
class TextDiffusionPipeline(Pipeline):
    def _sanitize_parameters(
        self,
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None,
        **kwargs
    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
        """
        Route user-supplied kwargs to the pipeline stages.

        Args:
            num_steps: Number of diffusion (denoising) steps. Defaults to 50.
            allow_edits: Whether already-generated tokens may be re-sampled.
            use_confidence: Whether to use confidence-based (top-k) unmasking.
            stop_token: Optional token string at which generation should stop.
                (Fixed annotation: was ``None``, but callers pass strings.)
            **kwargs: May contain ``max_length``, which is forwarded to
                :meth:`preprocess`.

        Returns:
            ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` as
            required by the ``transformers.Pipeline`` contract.
        """
        forward_kwargs = {
            "num_steps": num_steps,
            "allow_edits": allow_edits,
            "use_confidence": use_confidence,
            "stop_token": stop_token,
        }
        preprocess_kwargs = {}
        if "max_length" in kwargs:
            preprocess_kwargs["max_length"] = kwargs["max_length"]
        return preprocess_kwargs, forward_kwargs, {}
def preprocess(self, input_text, max_length=None) -> BatchEncoding | Any:
if self.tokenizer is None:
raise ValueError("Tokenizer was not passed to the pipeline!")
# Standard tokenization
if max_length is None:
# Safely access config if it exists, default to 512
max_length = getattr(self.model.config, "seq_length", 512)
if input_text is None:
input_text = ""
tokenized_text = self.tokenizer.encode(input_text)
if len(tokenized_text) < max_length:
input_ids = torch.full((1, max_length), self.tokenizer.mask_token_id, dtype=torch.long) # type: ignore
input_ids[0, :len(tokenized_text)] = torch.tensor(tokenized_text, dtype=torch.long)
return BatchEncoding({
"input_ids": input_ids,
"attention_mask": torch.ones_like(input_ids)
})
return self.tokenizer(
input_text,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
)
    @torch.no_grad()
    def diffusion_generator(
        self,
        input_ids: torch.Tensor,
        num_steps: int,
        allow_edits: bool = True,
        use_confidence: bool = False
    ) -> Generator[torch.Tensor, None, None]:
        """
        Yield the denoising trajectory of a masked-diffusion generation.

        Starting from ``input_ids`` (prompt tokens followed by MASK/PAD
        positions), runs ``num_steps`` reverse-diffusion steps. Each step
        predicts all tokens with the model, commits ("unmasks") a subset of
        the still-masked positions, and yields the updated state.

        Args:
            input_ids: Token ids; MASK/PAD positions are the ones generated.
                Assumes shape (1, seq_len) — TODO confirm with callers.
            num_steps: Number of denoising steps (assumed > 0; ``t_current``
                is always positive because ``step < num_steps``).
            allow_edits: If True, already-committed generated tokens may be
                re-sampled with a small decaying probability (Seed Diffusion).
            use_confidence: If True, unmask the top-k most confident samples
                instead of randomly chosen positions.

        Yields:
            torch.Tensor: A clone of the current state — once before the
            first step and once after each step (num_steps + 1 tensors).

        Raises:
            ValueError: If no tokenizer was provided to the pipeline.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")
        current_state: torch.Tensor = input_ids.clone()
        yield current_state.clone()  # Yield Step 0 (initial prompt + masks)
        # Positions that may ever be written: those that started as MASK or
        # PAD. Original prompt tokens fall outside this mask and are never
        # modified, even by the editing branch below.
        initial_mask = (current_state == self.tokenizer.mask_token_id) | \
            (current_state == self.tokenizer.pad_token_id)
        for step in range(num_steps):
            # Continuous time runs from t=1 (fully noised) down to t=0.
            t_current = 1 - step / num_steps
            t_next = 1 - (step + 1) / num_steps
            # Predict full text with model
            output = self.model(input_ids=current_state)
            logits = output.logits
            # Forbid sampling the mask token itself (in-place logit masking).
            logits[:, :, self.tokenizer.mask_token_id] = torch.finfo(logits.dtype).min
            # Ancestral sampling: draw a candidate token at every position.
            probs = torch.softmax(logits, dim=-1)
            dist = torch.distributions.Categorical(probs)
            sampled_ids = dist.sample()
            # Calculate Unmasking Probability (Equation 7 https://arxiv.org/pdf/2406.07524)
            # P(unmask | masked) = (alpha_s - alpha_t) / (1 - alpha_t)
            # mapping: alpha_t = (1 - t_current), alpha_s = (1 - t_next)
            # resulting simplified formula: (t_current - t_next) / t_current
            if step < num_steps - 1:
                unmasking_prob = (t_current - t_next) / t_current
            else:
                unmasking_prob = 1.0  # Force unmask at the end
            # Positions that are still MASK/PAD at this step.
            remasking_mask: torch.Tensor = (current_state == self.tokenizer.mask_token_id) | \
                (current_state == self.tokenizer.pad_token_id)  # type: ignore
            if use_confidence:
                # Get the confidence (probability) of the tokens we just sampled
                sample_probs = probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)
                # Determine how many tokens to unmask this step
                if step < num_steps - 1:
                    num_masked = remasking_mask.sum(dim=1, keepdim=True)
                    num_to_unmask = (num_masked.float() * unmasking_prob).ceil().long()
                else:
                    # Final step: commit everything that is still masked.
                    num_to_unmask = remasking_mask.sum(dim=1, keepdim=True)
                # Select Top-K most confident tokens
                # Set confidence of already visible tokens to -inf so they aren't picked
                candidate_confidences = sample_probs.clone()
                candidate_confidences[~remasking_mask] = -float('inf')
                unmasking_mask = torch.zeros_like(remasking_mask, dtype=torch.bool)
                max_k = num_to_unmask.max().item()
                if max_k > 0:
                    _, top_indices = candidate_confidences.topk(k=max_k, dim=1)
                    # Per-row k can differ: keep only the first num_to_unmask[i]
                    # of the max_k top indices in each row.
                    range_tensor = torch.arange(max_k, device=current_state.device).unsqueeze(0)
                    mask_k = range_tensor < num_to_unmask
                    unmasking_mask.scatter_(1, top_indices, mask_k)
            else:
                # Random Unmasking
                unmasking_mask = torch.rand_like(current_state, dtype=torch.float) < unmasking_prob
            update_mask = unmasking_mask & remasking_mask & initial_mask
            if allow_edits:  # Apply Seed Diffusion Editing Logic (Section 3.1 in https://arxiv.org/pdf/2508.02193)
                alpha_t = 0.1 * (1 - step / num_steps)  # alpha_t decreases from 0.1 to 0 (Seed Diffusion)
                edit_mask = torch.rand_like(current_state, dtype=torch.float) < alpha_t
                is_visible = (current_state != self.tokenizer.mask_token_id) & \
                    (current_state != self.tokenizer.pad_token_id) & \
                    (current_state != self.tokenizer.eos_token_id)
                edit_mask = is_visible & edit_mask & initial_mask  # Use initial_mask to avoid editing original prompt
                # Combine both masks
                update_mask = update_mask | edit_mask
            # Update current state
            current_state[update_mask] = sampled_ids[update_mask]
            yield current_state.clone()  # Yield after each step
@torch.no_grad()
def _forward(
self,
model_inputs: torch.Tensor,
num_steps: int = 50,
allow_edits: bool = True,
use_confidence: bool = False,
stop_token: None = None
) -> dict[str, Any]:
if self.tokenizer is None:
raise ValueError("Tokenizer was not passed to the pipeline!")
input_ids = model_inputs["input_ids"]
all_states = list(self.diffusion_generator(input_ids=input_ids, num_steps=num_steps, allow_edits=allow_edits, use_confidence=use_confidence))
final_state = all_states[-1]
return {"final_state": final_state, "history": all_states}
@torch.no_grad()
def stream_generation(
self,
input_text: str,
num_steps: int = 50,
allow_edits: bool = True,
use_confidence: bool = False,
max_length: int | None = None,
stop_token: str | None = None
) -> Generator[str, None, None]:
"""
Public method to stream text generation step-by-step.
"""
# 1. Preprocess
inputs = self.preprocess(input_text, max_length)
input_ids = inputs["input_ids"].to(self.model.device) # type: ignore
# 2. Iterate over generator
for step_tensor in self.diffusion_generator(input_ids=input_ids, num_steps=num_steps, allow_edits=allow_edits, use_confidence=use_confidence):
# Decode current state
text = self.tokenizer.decode(step_tensor[0], skip_special_tokens=False) # type: ignore
yield text
if stop_token is not None and stop_token in text[len(input_text):]:
text = input_text + text[len(input_text):].split(stop_token)[0]
yield text
def postprocess(self, model_outputs) -> list[str] | Any:
if self.tokenizer is None:
raise ValueError("Tokenizer was not passed to the pipeline!")
# Convert final tensor to image/text
final_ids = model_outputs["final_state"]
return {
"decoded_texts": self.tokenizer.batch_decode(final_ids, skip_special_tokens=False),
"history": model_outputs["history"],
"final_ids": final_ids
}
@torch.no_grad()
def block_diffusion_generator(
self, input_ids: torch.Tensor,
block_size: int,
max_length: int,
num_steps: int,
allow_edits: bool = True,
use_confidence: bool = False,
stop_token: str | None = None
) -> Generator[torch.Tensor, None, None]:
"""
Generator that yields the diffusion states block-by-block.
Args:
input_ids (torch.Tensor): Initial input IDs with context.
block_size (int): Number of tokens to generate in each block.
max_length (int): Max length of the generated text.
num_steps (int): Number of diffusion steps per block.
allow_edits (bool): Whether to allow edits to existing tokens.
use_confidence (bool): Whether to use confidence-based unmasking.
stop_token (str | None): Token at which to stop generation early.
Yields:
torch.Tensor: The current state of the full sequence after each diffusion step.
"""
assert num_steps > 0, "num_steps must be greater than 0"
if self.tokenizer is None:
raise ValueError("Tokenizer was not passed to the pipeline!")
max_seq_length = self.model.config.seq_length if hasattr(self.model.config, "seq_length") else 512
stop_token_id = self.tokenizer.convert_tokens_to_ids(stop_token) if stop_token is not None else None
assert block_size > 0 and block_size <= max_seq_length, f"block_size must be in (0, {max_seq_length}]"
full_sequence = input_ids.clone()
current_length = input_ids.shape[1]
while current_length < max_length:
remaining = max_length - current_length
this_block_len = min(block_size, remaining)
if this_block_len <= 0: break
# Append MASK tokens for the new block
mask_block = torch.full(
(1, this_block_len),
self.tokenizer.mask_token_id, # type: ignore
dtype=torch.long,
device=self.model.device
)
# Combine Context + New Masks
input_ids = torch.cat([full_sequence[:, -(max_seq_length - this_block_len):], mask_block], dim=1)
for step_tensor in self.diffusion_generator(
input_ids,
num_steps=num_steps,
allow_edits=allow_edits,
use_confidence=use_confidence
):
current_generated_tokens = step_tensor[:, -this_block_len:]
yield torch.cat([full_sequence, current_generated_tokens], dim=1)
if stop_token_id is not None and stop_token_id in current_generated_tokens:
# Stop if EOS is generated
eos_index = (current_generated_tokens == stop_token_id).nonzero(as_tuple=True)[1] # type: ignore
current_generated_tokens = current_generated_tokens[:, :eos_index[0]]
yield torch.cat([full_sequence, current_generated_tokens], dim=1)
break
# Update full sequence and current length
full_sequence = torch.cat([full_sequence, current_generated_tokens], dim=1)
current_length = full_sequence.shape[1]
@torch.no_grad()
def semi_autoregressive_generate(
self,
input_text: str,
block_size: int = 64,
max_length: int = 256,
num_steps: int = 50,
allow_edits: bool = True,
use_confidence: bool = False
) -> dict[str, Any]:
"""
Semi-Autoregressive Generation:
Generates text in blocks using the diffusion model.
Each block is generated by appending MASK tokens to the current context
and running the diffusion process on the combined sequence.
Args:
input_text (str): The initial prompt text.
block_size (int): Number of tokens to generate in each block.
max_length (int): Max length of the generated text.
num_steps (int): Number of diffusion steps per block.
allow_edits (bool): Whether to allow edits to existing tokens.
use_confidence (bool): Whether to use confidence-based unmasking.
Returns:
dict[str, Any]: A dictionary containing the decoded texts, generation history, and final token IDs.
"""
if self.tokenizer is None: raise ValueError("No tokenizer")
input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device) # type: ignore
all_states = list(self.block_diffusion_generator(input_ids, block_size, max_length, num_steps, allow_edits, use_confidence=use_confidence))
final_state = all_states[-1]
return {
"decoded_texts": self.tokenizer.batch_decode(final_state, skip_special_tokens=False),
"history": all_states,
"final_ids": final_state
}
@torch.no_grad()
def stream_semi_autoregressive_generate(
self,
input_text: str,
block_size: int = 64,
max_length: int = 256,
num_steps: int = 50,
allow_edits: bool = True,
use_confidence: bool = False,
stop_token: str | None = None
) -> Generator[str, None, None]:
"""
Streams the generation process block-by-block.
Yields the full decoded text at every diffusion step of every block.
Args:
input_text (str): The initial prompt text.
block_size (int): Number of tokens to generate in each block.
max_length (int): Max length of the generated text.
num_steps (int): Number of diffusion steps per block.
allow_edits (bool): Whether to allow edits to existing tokens.
use_confidence (bool): Whether to use confidence-based unmasking.
stop_token (None): Token at which to stop generation early.
Yields:
str: The current generated text after each diffusion step.
"""
if self.tokenizer is None: raise ValueError("No tokenizer")
input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device) # type: ignore
for step_tensor in self.block_diffusion_generator(input_ids, block_size, max_length, num_steps, allow_edits, use_confidence=use_confidence, stop_token=stop_token):
# Decode current state
yield self.tokenizer.decode(step_tensor[0], skip_special_tokens=False) # type: ignore