# TraDo-4B-Instruct / configuration_sdar.py
# exdysa's picture
# Update configuration_sdar.py
# 6daf6b2 verified
# raw
# history blame
# 7.93 kB
# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/Gen-Verse/dLLM-RL
# adapted from SDAR https://github.com/JetAstra/SDAR/blob/main/generate.py
import torch
from torch.nn import functional as F
from transformers.cache_utils import DynamicCache
def top_k_logits(logits, k):
    """Keep the ``k`` largest logits along the last dimension.

    Every entry strictly smaller than the k-th largest value of its row is
    replaced with ``-inf``; entries tied with the k-th largest value are kept.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        k: Number of entries to keep. ``k <= 0`` disables filtering and
            returns ``logits`` unchanged. Values larger than the size of the
            last dimension are clamped to that size.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    if k <= 0:
        return logits
    # torch.topk raises when k exceeds the dimension size; asking for more
    # entries than exist simply means "keep everything".
    k = min(k, logits.size(-1))
    kth_largest = torch.topk(logits, k).values[..., -1, None]
    return logits.masked_fill(logits < kth_largest, float("-inf"))
def top_p_logits(logits, p):
    """Apply nucleus (top-p) filtering along the last dimension.

    Entries are ranked by descending logit. An entry is dropped (set to
    ``-inf``) when the cumulative softmax mass of the entries ranked before
    it, including its immediate predecessor, already exceeds ``p``. The
    top-ranked entry is always kept.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        p: Cumulative probability threshold.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    ranked_logits, ranked_positions = torch.sort(logits, descending=True)
    running_mass = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)
    over_budget = running_mass > p
    # Shift the flags one rank to the right so the entry that first crosses
    # the threshold survives and rank 0 is never dropped.
    never_drop_first = torch.zeros_like(over_budget[..., :1])
    drop_ranked = torch.cat([never_drop_first, over_budget[..., :-1]], dim=-1)
    # Map the per-rank flags back to the original vocabulary positions.
    drop = torch.zeros_like(logits, dtype=torch.bool).scatter(-1, ranked_positions, drop_ranked)
    return logits.masked_fill(drop, float("-inf"))
def sample_with_temperature_topk_topp(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Sample one token per position and report its probability.

    The logits are flattened to two dimensions, optionally divided by
    ``temperature``, optionally filtered with top-k and then top-p, and
    turned into probabilities with a softmax. One token is drawn per row
    with ``torch.multinomial``.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        temperature: Divisor applied to the logits when different from 1.0.
        top_k: Top-k filtering is applied when greater than 0.
        top_p: Top-p filtering is applied when less than 1.0.

    Returns:
        A ``(token, token_prob)`` pair, each shaped like ``logits`` without
        its last dimension. ``token_prob`` is the probability of the drawn
        token under the filtered distribution.
    """
    *leading_dims, vocab = logits.shape
    flat_logits = logits.reshape(-1, vocab)  # one row per position

    if temperature != 1.0:
        flat_logits = flat_logits / temperature
    if top_k > 0:
        flat_logits = top_k_logits(flat_logits, top_k)
    if top_p < 1.0:
        flat_logits = top_p_logits(flat_logits, top_p)

    probs = F.softmax(flat_logits, dim=-1)
    assert probs.dim() == 2

    drawn = torch.multinomial(probs, num_samples=1)  # one column per row
    drawn_prob = probs.gather(-1, drawn)
    return drawn.view(*leading_dims), drawn_prob.view(*leading_dims)
def get_num_transfer_tokens(block_length, steps):
    """Split ``block_length`` tokens as evenly as possible over ``steps``.

    Every step receives ``block_length // steps`` tokens and the first
    ``block_length % steps`` steps receive one extra, so the entries sum to
    ``block_length``.

    Args:
        block_length: Total number of tokens to distribute.
        steps: Number of steps to distribute them over.

    Returns:
        An int64 tensor of length ``steps`` with the per-step token counts.
    """
    per_step, extra = divmod(block_length, steps)
    # The leading `extra` steps absorb the remainder, one token each.
    gets_extra = torch.arange(steps) < extra
    return gets_extra.to(torch.int64) + per_step
@torch.no_grad()
def block_diffusion_generate(
    model,
    prompt,
    mask_id,
    gen_length=128,
    block_length=8,
    denoising_steps=8,
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    remasking_strategy="low_confidence_dynamic",
    confidence_threshold=0.85,
    stopping_criteria_idx=None,
):
    """Generate tokens block by block, iteratively unmasking each block.

    The sequence is laid out as ``num_blocks`` blocks of ``block_length``
    positions. All positions after the prompt start as ``mask_id``. Blocks
    fully covered by the prompt are run through the model once (prefill).
    Each remaining block is then denoised: the model is called on the block,
    a token is sampled for every position, and a subset of the still-masked
    positions is filled in according to ``remasking_strategy``. Once a block
    has no masked position left, one more forward pass is made with
    ``store_kv=True`` before moving on to the next block.

    Args:
        model: Callable model exposing ``eval()`` and ``device``. It is
            called with ``attention_mask``, ``position_ids``,
            ``past_key_values``, ``use_cache`` and ``store_kv`` keyword
            arguments, and must return an object with ``logits`` or with
            ``last_hidden_state`` (in which case ``model.lm_head`` is used).
        prompt: Mapping with an ``"input_ids"`` tensor. The working buffer is
            allocated with a batch size of 1.
        mask_id: Token id that marks a not-yet-generated position.
        gen_length: Number of tokens to generate after the prompt; the total
            length is rounded up to a multiple of ``block_length``.
        block_length: Number of positions per block.
        denoising_steps: Number of unmasking steps the per-step quota is
            spread over (see ``get_num_transfer_tokens``).
        temperature: Sampling temperature.
        top_k: Top-k filtering parameter for sampling (0 disables it).
        top_p: Top-p filtering parameter for sampling (1.0 disables it).
        remasking_strategy: One of ``"sequential"`` (fill the next masked
            positions left to right), ``"low_confidence_static"`` (fill the
            quota of most confident masked positions) or
            ``"low_confidence_dynamic"`` (fill every masked position whose
            confidence exceeds ``confidence_threshold`` if there are at least
            a quota of them, otherwise fall back to the static rule).
        confidence_threshold: Threshold used by ``"low_confidence_dynamic"``.
        stopping_criteria_idx: Optional iterable of token ids; generation
            stops after the first block whose completion leaves any of them
            in the generated region.

    Returns:
        A ``(1, total_length)`` long tensor containing the prompt followed by
        the generated tokens. Positions in blocks that were never decoded
        (because of an early stop) still hold ``mask_id``.

    Raises:
        ValueError: If ``remasking_strategy`` is unknown, or if the model
            output provides neither logits nor hidden states usable with
            ``model.lm_head``.
    """
    model.eval()
    input_ids = prompt["input_ids"]
    prompt_length = input_ids.shape[1]
    past_key_values = DynamicCache()

    num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
    total_length = num_blocks * block_length

    # Block-causal mask: a position attends to every position in its own
    # block and in all earlier blocks.
    block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=model.device))
    block_diffusion_attention_mask = (
        block_mask.repeat_interleave(block_length, dim=0).repeat_interleave(block_length, dim=1).unsqueeze(0)
    )
    position_ids = torch.arange(total_length, device=model.device).unsqueeze(0)

    x = torch.full((1, total_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_length] = input_ids

    # Only blocks entirely covered by the prompt are prefilled; a partially
    # covered block is handled by the decode loop below.
    prefill_blocks = prompt_length // block_length
    prefill_length = prefill_blocks * block_length

    # Prefill stage
    if prefill_length > 0:
        cur_x = x[:, :prefill_length]
        cur_attn_mask = block_diffusion_attention_mask[:, :prefill_length, :prefill_length]
        if cur_attn_mask.dim() == 3:
            cur_attn_mask = cur_attn_mask[:, None, :, :]
        cur_position_ids = position_ids[:, :prefill_length]
        model(
            cur_x,
            attention_mask=cur_attn_mask,
            position_ids=cur_position_ids,
            past_key_values=past_key_values,
            use_cache=True,
            store_kv=True,
        )

    num_transfer_tokens = get_num_transfer_tokens(block_length, denoising_steps)

    # Decode stage
    for num_block in range(prefill_blocks, num_blocks):
        block_start = num_block * block_length
        block_end = block_start + block_length

        cur_x = x[:, block_start:block_end].clone()
        cur_attn_mask = block_diffusion_attention_mask[:, block_start:block_end, :block_end]
        if cur_attn_mask.dim() == 3:
            cur_attn_mask = cur_attn_mask[:, None, :, :]
        cur_position_ids = position_ids[:, block_start:block_end]

        # One extra iteration so the finished block gets its final
        # store_kv=True pass even when every denoising step was used.
        for step in range(denoising_steps + 1):
            mask_index = cur_x == mask_id
            if mask_index.sum() == 0:
                # Store kv cache
                model(
                    cur_x,
                    attention_mask=cur_attn_mask,
                    position_ids=cur_position_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                    store_kv=True,
                )
                break

            # Denoising
            output = model(
                cur_x,
                attention_mask=cur_attn_mask,
                position_ids=cur_position_ids,
                past_key_values=past_key_values,
                use_cache=True,
                store_kv=False,
            )

            # Extract logits from the output - handle both CausalLMOutputWithPast and BaseModelOutputWithPast
            if hasattr(output, "logits") and output.logits is not None:
                logits = output.logits
            elif hasattr(output, "last_hidden_state"):
                # If logits don't exist but we have hidden states, compute logits from the model's lm_head
                # This can happen if the model returns BaseModelOutputWithPast instead of CausalLMOutputWithPast
                if hasattr(model, "lm_head"):
                    hidden_states = output.last_hidden_state
                    logits = model.lm_head(hidden_states)
                else:
                    raise ValueError("Model output does not contain logits and model does not have lm_head to compute them.")
            else:
                raise ValueError(f"Unexpected model output type: {type(output)}. Expected CausalLMOutputWithPast or BaseModelOutputWithPast with logits or last_hidden_state.")

            # Sampling
            x0, x0_p = sample_with_temperature_topk_topp(logits, temperature=temperature, top_k=top_k, top_p=top_p)

            # Number of positions to fill at this step, as a plain int.
            quota = int(num_transfer_tokens[step])

            # Sampling strategy
            if remasking_strategy == "sequential":
                transfer_index = torch.zeros_like(x0, dtype=torch.bool)
                for j in range(cur_x.shape[0]):
                    if mask_index[j].any():
                        first_mask_index = mask_index[j].nonzero(as_tuple=True)[0].min().item()
                        transfer_index[j, first_mask_index : first_mask_index + quota] = True
                    else:
                        raise ValueError("No mask tokens found in the current block.")
            elif remasking_strategy == "low_confidence_static":
                # Already-filled positions get -inf so masked ones rank first.
                confidence = torch.where(mask_index, x0_p, -torch.inf)
                transfer_index = torch.zeros_like(x0, dtype=torch.bool)
                for j in range(confidence.shape[0]):
                    _, idx = torch.topk(confidence[j], quota)
                    transfer_index[j, idx] = True
            elif remasking_strategy == "low_confidence_dynamic":
                confidence = torch.where(mask_index, x0_p, -torch.inf)
                transfer_index = torch.zeros_like(x0, dtype=torch.bool)
                for j in range(confidence.shape[0]):
                    high_conf_mask = confidence[j] > confidence_threshold
                    num_high_confidence = high_conf_mask.sum()
                    if num_high_confidence >= quota:
                        transfer_index[j] = high_conf_mask
                    else:
                        _, idx = torch.topk(confidence[j], quota)
                        transfer_index[j, idx] = True
            else:
                raise ValueError(f"Unknown remasking strategy: {remasking_strategy}")

            # When fewer masked positions remain than the quota, topk also
            # returns -inf (already filled) positions. Restrict the update to
            # masked positions so prompt tokens and previously decoded tokens
            # are never overwritten with fresh samples.
            transfer_index &= mask_index
            cur_x[transfer_index] = x0[transfer_index]

        x[:, block_start:block_end] = cur_x

        if stopping_criteria_idx is not None and any(stop_idx in x[:, prompt_length:] for stop_idx in stopping_criteria_idx):
            break

    return x