from typing import Any, Generator

import torch
from transformers import BatchEncoding, Pipeline


class TextDiffusionPipeline(Pipeline):
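    """
    Hugging Face pipeline for iterative masked text diffusion.

    Starting from a sequence whose tail is filled with MASK tokens, the
    pipeline samples tokens and reveals them over a fixed number of steps,
    optionally revealing the highest-confidence positions first
    (`use_confidence`) and optionally resampling already revealed tokens
    (`allow_edits`). Block-wise (semi-autoregressive) generation and
    step-by-step streaming are supported as well.
    """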

    def _sanitize_parameters(
        self,
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None,
        **kwargs
    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
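        """
        Split pipeline call kwargs into preprocess, forward, and postprocess
        kwargs, as the Pipeline API expects.
        """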
        # Generation arguments are consumed by _forward; preprocess only
        # needs max_length, which arrives through **kwargs.
        forward_kwargs = {
            "num_steps": num_steps,
            "allow_edits": allow_edits,
            "use_confidence": use_confidence,
            "stop_token": stop_token,
        }

        preprocess_kwargs = {}
        if "max_length" in kwargs:
            preprocess_kwargs["max_length"] = kwargs["max_length"]

        return preprocess_kwargs, forward_kwargs, {}

    def preprocess(
        self, input_text: str | None, max_length: int | None = None
    ) -> BatchEncoding:
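        """
        Tokenize the prompt and fill the rest of the window with MASK tokens
        so the diffusion process has positions to generate into.
        """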
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        if max_length is None:
            max_length = getattr(self.model.config, "seq_length", 512)

        if input_text is None:
            input_text = ""

        tokenized_text = self.tokenizer.encode(input_text)

        # Short prompts: keep the prompt and fill the tail with MASK tokens.
        if len(tokenized_text) < max_length:
            input_ids = torch.full((1, max_length), self.tokenizer.mask_token_id, dtype=torch.long)
            input_ids[0, :len(tokenized_text)] = torch.tensor(tokenized_text, dtype=torch.long)
            return BatchEncoding({
                "input_ids": input_ids,
                "attention_mask": torch.ones_like(input_ids)
            })

        # Prompt already fills the window: fall back to plain padding/truncation.
        return self.tokenizer(
            input_text,
            return_tensors="pt",
            padding="max_length",
            max_length=max_length,
            truncation=True,
        )

    @torch.no_grad()
    def diffusion_generator(
        self,
        input_ids: torch.Tensor,
        num_steps: int,
        allow_edits: bool = True,
        use_confidence: bool = False
    ) -> Generator[torch.Tensor, None, None]:
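        """
        Yield the full sequence after every denoising step (the initial
        state included).

        With a linear schedule t = 1 - step / num_steps, each step reveals a
        fraction (t_current - t_next) / t_current of the positions that are
        still MASK or PAD, so everything is revealed by the final step. When
        `use_confidence` is set, the most confident sampled tokens are
        revealed first; otherwise positions are picked uniformly at random.
        When `allow_edits` is set, already revealed tokens may be resampled
        at a rate that decays linearly over the schedule.
        """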
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        current_state: torch.Tensor = input_ids.clone()
        yield current_state.clone()

        # Positions that started as MASK or PAD are the only ones the process
        # is ever allowed to write to, including later edits.
        initial_mask = (current_state == self.tokenizer.mask_token_id) | \
                       (current_state == self.tokenizer.pad_token_id)

        for step in range(num_steps):
            # Linear time schedule: t runs from 1 down to 0 over num_steps.
            t_current = 1 - step / num_steps
            t_next = 1 - (step + 1) / num_steps

            output = self.model(input_ids=current_state)
            logits = output.logits

            # Never sample the MASK token itself.
            logits[:, :, self.tokenizer.mask_token_id] = torch.finfo(logits.dtype).min

            probs = torch.softmax(logits, dim=-1)
            dist = torch.distributions.Categorical(probs)
            sampled_ids = dist.sample()

            # Fraction of the still-masked positions to reveal this step;
            # the final step reveals everything that remains.
            if step < num_steps - 1:
                unmasking_prob = (t_current - t_next) / t_current
            else:
                unmasking_prob = 1.0

            remasking_mask: torch.Tensor = (current_state == self.tokenizer.mask_token_id) | \
                                           (current_state == self.tokenizer.pad_token_id)

            if use_confidence:
                # Confidence of each sampled token under the model.
                sample_probs = probs.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)

                if step < num_steps - 1:
                    num_masked = remasking_mask.sum(dim=1, keepdim=True)
                    num_to_unmask = (num_masked.float() * unmasking_prob).ceil().long()
                else:
                    num_to_unmask = remasking_mask.sum(dim=1, keepdim=True)

                # Only still-masked positions compete for unmasking.
                candidate_confidences = sample_probs.clone()
                candidate_confidences[~remasking_mask] = -float('inf')

                unmasking_mask = torch.zeros_like(remasking_mask, dtype=torch.bool)

                # Per row, reveal the num_to_unmask most confident positions.
                max_k = num_to_unmask.max().item()
                if max_k > 0:
                    _, top_indices = candidate_confidences.topk(k=max_k, dim=1)
                    range_tensor = torch.arange(max_k, device=current_state.device).unsqueeze(0)
                    mask_k = range_tensor < num_to_unmask
                    unmasking_mask.scatter_(1, top_indices, mask_k)
            else:
                # Uniformly random unmasking at the scheduled rate.
                unmasking_mask = torch.rand_like(current_state, dtype=torch.float) < unmasking_prob

            update_mask = unmasking_mask & remasking_mask & initial_mask

            if allow_edits:
                # Edit rate decays linearly from 0.1 toward zero over the schedule.
                alpha_t = 0.1 * (1 - step / num_steps)
                edit_mask = torch.rand_like(current_state, dtype=torch.float) < alpha_t

                # Only resample tokens that are currently visible (not
                # MASK/PAD/EOS) and that this run generated in the first
                # place; prompt tokens are never edited.
                is_visible = (current_state != self.tokenizer.mask_token_id) & \
                             (current_state != self.tokenizer.pad_token_id) & \
                             (current_state != self.tokenizer.eos_token_id)
                edit_mask = is_visible & edit_mask & initial_mask

                update_mask = update_mask | edit_mask

            current_state[update_mask] = sampled_ids[update_mask]

            yield current_state.clone()

    @torch.no_grad()
    def _forward(
        self,
        model_inputs: BatchEncoding,
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None
    ) -> dict[str, Any]:
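        """
        Run the full diffusion trajectory and return the final state together
        with every intermediate state. `stop_token` is accepted for parameter
        symmetry; only the streaming methods act on it.
        """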
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        input_ids = model_inputs["input_ids"]
        all_states = list(self.diffusion_generator(
            input_ids=input_ids,
            num_steps=num_steps,
            allow_edits=allow_edits,
            use_confidence=use_confidence
        ))
        final_state = all_states[-1]

        return {"final_state": final_state, "history": all_states}

    @torch.no_grad()
    def stream_generation(
        self,
        input_text: str,
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        max_length: int | None = None,
        stop_token: str | None = None
    ) -> Generator[str, None, None]:
        """
        Public method to stream text generation step-by-step.
        """
        inputs = self.preprocess(input_text, max_length)
        input_ids = inputs["input_ids"].to(self.model.device)

        for step_tensor in self.diffusion_generator(
            input_ids=input_ids,
            num_steps=num_steps,
            allow_edits=allow_edits,
            use_confidence=use_confidence
        ):
            text = self.tokenizer.decode(step_tensor[0], skip_special_tokens=False)
            yield text

            # Truncate at the stop token and end the stream; without this
            # early return, later steps would keep yielding text past the
            # stop token.
            if stop_token is not None and stop_token in text[len(input_text):]:
                text = input_text + text[len(input_text):].split(stop_token)[0]
                yield text
                return

    def postprocess(self, model_outputs: dict[str, Any]) -> dict[str, Any]:
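        """
        Decode the final state and package it with the diffusion history.
        """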
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        final_ids = model_outputs["final_state"]
        return {
            "decoded_texts": self.tokenizer.batch_decode(final_ids, skip_special_tokens=False),
            "history": model_outputs["history"],
            "final_ids": final_ids
        }

    @torch.no_grad()
    def block_diffusion_generator(
        self,
        input_ids: torch.Tensor,
        block_size: int,
        max_length: int,
        num_steps: int,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None
    ) -> Generator[torch.Tensor, None, None]:
        """
        Generator that yields the diffusion states block-by-block.

        Args:
            input_ids (torch.Tensor): Initial input IDs with context.
            block_size (int): Number of tokens to generate in each block.
            max_length (int): Max length of the generated text.
            num_steps (int): Number of diffusion steps per block.
            allow_edits (bool): Whether to allow edits to existing tokens.
            use_confidence (bool): Whether to use confidence-based unmasking.
            stop_token (str | None): Token at which to stop generation early.

        Yields:
            torch.Tensor: The current state of the full sequence after each diffusion step.
        """
        assert num_steps > 0, "num_steps must be greater than 0"
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        max_seq_length = getattr(self.model.config, "seq_length", 512)
        stop_token_id = (
            self.tokenizer.convert_tokens_to_ids(stop_token)
            if stop_token is not None else None
        )

        assert 0 < block_size <= max_seq_length, f"block_size must be in (0, {max_seq_length}]"

        full_sequence = input_ids.clone()
        current_length = input_ids.shape[1]

        while current_length < max_length:
            remaining = max_length - current_length
            this_block_len = min(block_size, remaining)
            if this_block_len <= 0:
                break

            # A fresh block of MASK tokens to denoise.
            mask_block = torch.full(
                (1, this_block_len),
                self.tokenizer.mask_token_id,
                dtype=torch.long,
                device=self.model.device
            )

            # Sliding window: keep as much trailing context as fits alongside
            # the new block. Computing the context length explicitly avoids
            # the `-0` slicing bug when the block fills the whole window.
            context_len = min(full_sequence.shape[1], max_seq_length - this_block_len)
            context = full_sequence[:, full_sequence.shape[1] - context_len:]
            input_ids = torch.cat([context, mask_block], dim=1)

            for step_tensor in self.diffusion_generator(
                input_ids,
                num_steps=num_steps,
                allow_edits=allow_edits,
                use_confidence=use_confidence
            ):
                current_generated_tokens = step_tensor[:, -this_block_len:]
                yield torch.cat([full_sequence, current_generated_tokens], dim=1)

            # Once the block is fully denoised, truncate at the first stop
            # token (if present) and end generation for good.
            if stop_token_id is not None and stop_token_id in current_generated_tokens:
                eos_index = (current_generated_tokens == stop_token_id).nonzero(as_tuple=True)[1]
                current_generated_tokens = current_generated_tokens[:, :eos_index[0]]
                yield torch.cat([full_sequence, current_generated_tokens], dim=1)
                break

            full_sequence = torch.cat([full_sequence, current_generated_tokens], dim=1)
            current_length = full_sequence.shape[1]

    @torch.no_grad()
    def semi_autoregressive_generate(
        self,
        input_text: str,
        block_size: int = 64,
        max_length: int = 256,
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False
    ) -> dict[str, Any]:
        """
        Semi-autoregressive generation: generates text in blocks using the
        diffusion model. Each block is generated by appending MASK tokens to
        the current context and running the diffusion process on the combined
        sequence.

        Args:
            input_text (str): The initial prompt text.
            block_size (int): Number of tokens to generate in each block.
            max_length (int): Max length of the generated text.
            num_steps (int): Number of diffusion steps per block.
            allow_edits (bool): Whether to allow edits to existing tokens.
            use_confidence (bool): Whether to use confidence-based unmasking.

        Returns:
            dict[str, Any]: A dictionary containing the decoded texts, generation history, and final token IDs.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device)
        all_states = list(self.block_diffusion_generator(
            input_ids, block_size, max_length, num_steps,
            allow_edits, use_confidence=use_confidence
        ))
        final_state = all_states[-1]
        return {
            "decoded_texts": self.tokenizer.batch_decode(final_state, skip_special_tokens=False),
            "history": all_states,
            "final_ids": final_state
        }

    @torch.no_grad()
    def stream_semi_autoregressive_generate(
        self,
        input_text: str,
        block_size: int = 64,
        max_length: int = 256,
        num_steps: int = 50,
        allow_edits: bool = True,
        use_confidence: bool = False,
        stop_token: str | None = None
    ) -> Generator[str, None, None]:
        """
        Streams the generation process block-by-block.
        Yields the full decoded text at every diffusion step of every block.

        Args:
            input_text (str): The initial prompt text.
            block_size (int): Number of tokens to generate in each block.
            max_length (int): Max length of the generated text.
            num_steps (int): Number of diffusion steps per block.
            allow_edits (bool): Whether to allow edits to existing tokens.
            use_confidence (bool): Whether to use confidence-based unmasking.
            stop_token (str | None): Token at which to stop generation early.

        Yields:
            str: The current generated text after each diffusion step.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer was not passed to the pipeline!")

        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.model.device)

        for step_tensor in self.block_diffusion_generator(
            input_ids, block_size, max_length, num_steps,
            allow_edits, use_confidence=use_confidence, stop_token=stop_token
        ):
            yield self.tokenizer.decode(step_tensor[0], skip_special_tokens=False)
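

# Minimal usage sketch, kept outside the class. Everything below is an
# assumption for illustration: the checkpoint path is a placeholder, and
# AutoModelForMaskedLM stands in for whatever model class the diffusion
# checkpoint actually ships with.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    model_name = "path/to/masked-diffusion-lm"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    pipe = TextDiffusionPipeline(model=model, tokenizer=tokenizer)

    # One-shot generation through the standard pipeline protocol
    # (preprocess -> _forward -> postprocess).
    result = pipe("Once upon a time", num_steps=50, use_confidence=True)
    print(result["decoded_texts"][0])

    # Streaming semi-autoregressive generation, block by block.
    for text in pipe.stream_semi_autoregressive_generate(
        "Once upon a time", block_size=32, max_length=128, num_steps=25
    ):
        print(text)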