Bailan-Alex's picture
Upload folder using huggingface_hub
4f2b2f4 verified
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import torch.distributed as dist
def add_gumbel_noise(logits, temperature):
    """
    Perturb ``logits`` so that ``argmax`` over the last dim draws a sample.

    Implements the Gumbel-max trick: with ``u ~ Uniform[0, 1)``,
    ``argmax(logits - temperature * log(-log(u)))`` is distributed as
    ``softmax(logits / temperature)``. This is argmax-equivalent to the
    classic ``exp(logits) / (-log(u)) ** temperature`` form, but it is
    evaluated in log space. In float32, ``exp`` overflows to ``inf`` for
    logits above roughly 88.7, which makes tokens tie and biases the sample.

    Args:
        logits: Tensor of unnormalized scores; the last dim is the category dim.
        temperature: Sampling temperature. ``0.0`` disables the noise.

    Returns:
        When ``temperature == 0`` the input tensor itself, so ``argmax`` is
        greedy. Otherwise a float32 tensor of the same shape holding perturbed
        log-scores. Only its ``argmax`` along the last dim is meaningful.
    """
    if temperature == 0.0:
        return logits  # Skip noise when temperature is 0 (greedy decoding).
    # float32 rather than float64 is a deliberate speed/precision trade-off.
    logits = logits.to(torch.float32)
    noise = torch.rand_like(logits, dtype=torch.float32)
    gumbel_noise = -torch.log(noise)
    # log(exp(l) / g**T) == l - T * log(g): same argmax, no overflow.
    # A draw of u == 0 gives g == inf and a score of -inf, so that token
    # can never win, matching the exp-space limit.
    return logits - temperature * torch.log(gumbel_noise)
def get_num_transfer_tokens(mask_index, steps):
    """
    Spread each row's masked-token count as evenly as possible over ``steps``.

    Args:
        mask_index: Per-position mask indicator, summed along dim 1.
            Presumably a 2-D (batch, length) boolean tensor; TODO confirm.
        steps: Number of denoising steps to distribute the tokens over.

    Returns:
        int64 tensor of shape (rows, steps). Entry ``[b, s]`` is how many
        tokens row ``b`` commits at step ``s``. Every step gets
        ``count // steps`` tokens, and the first ``count % steps`` steps get
        one extra, so each row sums to its masked count.
    """
    counts = mask_index.sum(dim=1, keepdim=True)
    per_step = counts // steps
    leftover = counts % steps

    # Row vector of step ids broadcast against each row's leftover count:
    # True for the leading steps that receive one extra token.
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)
    bonus = (step_ids < leftover).to(per_step.dtype)

    schedule = per_step + bonus
    return schedule.to(torch.int64)
@torch.no_grad()
def generate_mdm(
    model,
    prompt,
    tokenizer,
    steps,
    gen_length,
    block_length,
    temperature,
    cfg_scale,
    remasking,
    mask_id: int = 126336,
):
    """
    Generate ``gen_length`` tokens after ``prompt`` by iterative unmasking.

    The generation region starts fully masked. It is split into
    ``gen_length // block_length`` blocks that are filled left to right. Each
    block gets ``max(1, steps // num_blocks)`` steps. At every step the model
    predicts all positions. Among still-masked positions up to the end of the
    current block, the ones with the highest confidence are committed. The
    per-step count comes from ``get_num_transfer_tokens``.

    Args:
        model: Callable whose result has a ``.logits`` tensor; it must also
            expose ``.device``. Logits are presumably (batch, seq_len, vocab);
            TODO confirm against the model.
        prompt: (batch, prompt_len) tensor of token ids, or ``None`` to
            generate for a batch of one with no prompt.
        tokenizer: Not used by this function; kept for interface compatibility.
        steps: Total number of denoising steps across all blocks.
        gen_length: Number of tokens to generate.
        block_length: Block size; must be positive and divide ``gen_length``.
        temperature: Gumbel-max sampling temperature; ``0.0`` is greedy.
        cfg_scale: Classifier-free guidance scale; ``0.0`` disables guidance.
        remasking: ``"low_confidence"`` ranks predictions by their softmax
            probability; ``"random"`` ranks them randomly.
        mask_id: Token id used for masked positions.

    Returns:
        (batch, prompt_len + gen_length) long tensor of token ids.

    Raises:
        AssertionError: If ``block_length`` is not positive or does not
            divide ``gen_length``.
        NotImplementedError: For an unknown ``remasking`` strategy.
    """
    if prompt is not None:
        batch_size, prompt_len = prompt.shape[0], prompt.shape[1]
    else:
        batch_size, prompt_len = 1, 0

    x = torch.full(
        (batch_size, prompt_len + gen_length),
        mask_id,
        dtype=torch.long,
        device=model.device,
    )
    if prompt is not None:
        x[:, :prompt_len] = prompt  # slice assignment already copies
    prompt_index = x != mask_id

    # get rank iff in the distributed setting (progress bar only on rank 0)
    rank = dist.get_rank() if dist.is_initialized() else 0

    assert block_length > 0 and gen_length % block_length == 0, (
        "block_length must be positive and evenly divide gen_length"
    )
    num_blocks = gen_length // block_length
    steps_per_block = max(1, steps // num_blocks)
    neg_inf = torch.tensor(-np.inf, device=x.device)

    # Mixed precision only applies to CUDA tensors. Enabling a "cuda" autocast
    # on a CPU-only host would just warn and disable itself.
    with torch.autocast(device_type="cuda", enabled=x.device.type == "cuda"):
        for num_block in tqdm(range(num_blocks), disable=(rank != 0)):
            start_idx = prompt_len + num_block * block_length
            end_idx = start_idx + block_length

            block_mask_index = x[:, start_idx:end_idx] == mask_id
            num_transfer_tokens = get_num_transfer_tokens(
                block_mask_index, steps_per_block
            )

            for i in range(steps_per_block):
                mask_index = x == mask_id

                if cfg_scale > 0.0:
                    # Unconditional input: same sequence with the prompt
                    # masked out. Both variants share one batched forward pass.
                    un_x = x.clone()
                    un_x[prompt_index] = mask_id
                    logits = model(torch.cat([x, un_x], dim=0)).logits
                    cond_logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (cond_logits - un_logits)
                else:
                    logits = model(x).logits

                # Sample a candidate token for every position.
                x0 = torch.argmax(add_gumbel_noise(logits, temperature), dim=-1)

                if remasking == "low_confidence":
                    # Explicit float32 so confidences do not depend on the
                    # model's (possibly reduced-precision) output dtype.
                    p = F.softmax(logits.to(torch.float32), dim=-1)
                    x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
                elif remasking == "random":
                    x0_p = torch.rand(x0.shape, device=x0.device)
                else:
                    raise NotImplementedError(remasking)

                # Never commit positions in blocks that have not started yet.
                x0_p[:, end_idx:] = -np.inf

                # Already-decoded positions keep their token and can never be
                # selected (their confidence is -inf).
                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, neg_inf)

                # Commit this step's quota of most-confident predictions per row.
                for j in range(confidence.shape[0]):
                    num_tokens = num_transfer_tokens[j, i].item()
                    if num_tokens > 0:
                        _, select_indices = torch.topk(confidence[j], k=num_tokens)
                        x[j, select_indices] = x0[j, select_indices]

    return x