Stable-DiffCoder-8B-Instruct / modeling_stable_diffcoder.py
Seas0's picture
Update to support transformers v5.3.0
658a484 verified
raw
history blame
11.3 kB
# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, DynamicCache
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.generation.utils import GenerationConfig
class StableDiffcoderForCausalLM(LlamaForCausalLM):
    """Llama causal LM with block-wise masked-diffusion decoding.

    ``generate_block`` appends ``gen_length`` placeholder tokens (``mask_id``)
    to the prompt and fills them block by block. Inside a block, tokens are
    committed over several forward passes under a block-causal attention mask
    (bidirectional within a block, causal across blocks). Once a block is
    finished, its keys/values are written to a ``DynamicCache``, so later
    blocks only feed their own tokens through the model. Only batch size 1 is
    supported.
    """

    def _get_num_transfer_tokens(self, mask_map, steps):
        """Spread the block's masked-token count as evenly as possible over ``steps``.

        Args:
            mask_map: Boolean tensor marking masked positions. The whole
                tensor is summed, so only batch size 1 is meaningful.
            steps: Number of denoising steps available for the block.

        Returns:
            A ``(steps,)`` ``torch.long`` tensor on ``mask_map.device`` whose
            entries sum to the number of masked positions. The first
            ``mask_num % steps`` entries are one larger than the others.
        """
        # Only bs == 1 is supported for now
        mask_num = int(mask_map.sum().item())
        base, remainder = divmod(mask_num, steps)
        schedule = torch.full(
            (steps,), fill_value=base, device=mask_map.device, dtype=torch.long
        )
        schedule[:remainder] += 1
        return schedule

    def _make_block_causal_mask(
        self, seq_len, block_size=2, device=None, dtype=torch.bfloat16
    ):
        """Build an additive block-causal attention mask.

        Query ``i`` may attend key ``j`` iff
        ``j // block_size <= i // block_size``: every token sees all earlier
        blocks plus its entire own block.

        Args:
            seq_len: Total sequence length (prompt + generation).
            block_size: Width of each diffusion block.
            device: Device for the returned tensor.
            dtype: Floating dtype of the returned tensor. It should match the
                dtype the attention scores are computed in.

        Returns:
            A ``(1, 1, seq_len, seq_len)`` tensor holding ``0`` where attention
            is allowed and ``-inf`` where it is blocked.
        """
        # ceil(seq_len / block_size)
        num_blocks = (seq_len + block_size - 1) // block_size
        # Kronecker product:
        # global_mask = block_wise_causal_mask ⊗ per_block_all_ones
        block_mask = torch.tril(
            torch.ones((num_blocks, num_blocks), dtype=torch.bool, device=device)
        )
        local_block = torch.ones(
            (block_size, block_size), dtype=torch.bool, device=device
        )
        allowed = block_mask.kron(local_block)[:seq_len, :seq_len]
        # [x] [ ] [ ] [ )
        # [x] [x] [ ] [ )
        # [x] [x] [x] [ )
        # [x] [x] [x] [x)
        # Additive form: 0 leaves a score untouched and -inf removes it. The
        # diagonal blocks are always allowed, so no row is fully masked, which
        # would otherwise make softmax produce NaN.
        attention_mask = torch.zeros_like(allowed, dtype=dtype)
        attention_mask.masked_fill_(~allowed, float("-inf"))
        return attention_mask[None, None]

    def _get_transfer_index(
        self,
        logits,
        temperature,
        remasking,
        mask_index,
        x,
        num_transfer_token,
        threshold=None,
        shift=False,
    ):
        """Choose which masked positions of a block to commit in this step.

        Only batch size 1 is handled: selection is done on row 0.

        Args:
            logits: Per-position vocabulary logits for the block
                (``(1, L, vocab)`` when called from ``generate_block``).
            temperature: Gumbel sampling temperature; ``0`` means greedy argmax.
            remasking: ``"low_confidence"`` ranks candidates by the softmax
                probability of their predicted token; ``"random"`` ranks them
                by uniform noise.
            mask_index: ``(1, L)`` bool tensor, True where the block is still
                masked.
            x: ``(1, L)`` current token ids of the block.
            num_transfer_token: Without a threshold, exactly this many tokens
                are committed. With a threshold, it is the minimum number
                committed (``None`` or values below 1 mean a minimum of 1).
            threshold: Optional confidence threshold. When given, every masked
                candidate with confidence ``>= threshold`` is committed in
                addition to the minimum above.
            shift: If True, the prediction for position ``i`` is taken from
                position ``i - 1``.

        Returns:
            ``(x0, transfer_map)``. ``x0`` holds the predictions at masked
            positions and the original ids elsewhere. ``transfer_map`` is a
            bool tensor marking the positions to commit.

        Raises:
            NotImplementedError: For an unknown ``remasking`` strategy.
        """

        def add_gumbel_noise(logits, temperature):
            # exp(l) / (-log u)^T has the same argmax as l + T * Gumbel noise,
            # i.e. it samples from softmax(l / T). float64 gives exp() more
            # headroom than the model dtype.
            if temperature == 0:
                return logits
            logits = logits.to(torch.float64)
            noise = torch.rand_like(logits, dtype=torch.float64)
            gumbel_noise = (-torch.log(noise)) ** temperature
            return logits.exp() / gumbel_noise

        x0 = torch.argmax(add_gumbel_noise(logits, temperature), dim=-1)  # b, l
        if shift:
            # Position 0 keeps its current token and gets an all-zero logit row.
            x0 = torch.cat([x[:, :1], x0[:, :-1]], dim=-1)
            pad = torch.zeros_like(logits[:, :1])
            logits = torch.cat([pad, logits[:, :-1]], dim=1)

        if remasking == "low_confidence":
            probs = F.softmax(logits.to(torch.float64), dim=-1)
            x0_p = torch.gather(probs, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
        elif remasking == "random":
            x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
        else:
            raise NotImplementedError(remasking)

        # Only masked positions are candidates. The others keep their token and
        # get -inf confidence, so they rank below every masked candidate.
        x0 = torch.where(mask_index, x0, x)
        confidence = torch.where(mask_index, x0_p, float("-inf"))
        transfer_map = torch.zeros_like(x0, dtype=torch.bool)

        if threshold is None:
            _, select_index = torch.topk(confidence[0], k=int(num_transfer_token))
            transfer_map[0, select_index] = True
            return x0, transfer_map

        num_masked = int(mask_index[0].sum().item())
        _, select_index = torch.topk(confidence[0], k=num_masked)
        transfer_map[0, select_index] = True
        min_keep = 1 if num_transfer_token is None else max(1, int(num_transfer_token))
        # topk returns candidates in descending confidence order. The first
        # `min_keep` are always committed (guaranteeing progress); the rest
        # must reach the threshold.
        tail = select_index[min_keep:]
        transfer_map[0, tail[confidence[0, tail] < threshold]] = False
        return x0, transfer_map

    @torch.no_grad()
    def generate_block(
        self,
        input_ids: torch.LongTensor,
        steps=128,
        gen_length=128,
        block_length=4,
        temperature=0.0,
        remasking="low_confidence",
        tokenizer=None,
        mask_id=5,
        threshold=0.95,
        shift=False,
        eos_id=None,
    ):
        """Generate ``gen_length`` tokens after ``input_ids``, one block at a time.

        Args:
            input_ids: ``(1, prompt_len)`` prompt token ids (batch size 1 only).
            steps: Total denoising-step budget, split evenly across the
                ``gen_length // block_length`` generation blocks.
            gen_length: Number of tokens to generate. Must be a multiple of
                ``block_length``.
            block_length: Diffusion block width.
            temperature: Gumbel sampling temperature (``0`` = greedy).
            remasking: ``"low_confidence"`` or ``"random"``; controls how
                candidate tokens are ranked for commitment.
            tokenizer: Unused; kept for API compatibility.
            mask_id: Token id used as the mask placeholder.
            threshold: Confidence threshold for committing extra tokens per
                step. ``None`` follows the per-step schedule exactly.
            shift: Currently ignored; predictions are never shifted here.
            eos_id: If given, generation stops after the first block that
                produces this id. The output is truncated to that block, and
                every position after the first generated ``eos_id`` is set
                to ``eos_id``.

        Returns:
            ``(x, nfe)``: token ids including the prompt, and the number of
            model forward passes spent on generation (prompt prefill excluded).

        Raises:
            AssertionError: If ``gen_length`` is not a multiple of
                ``block_length``, ``steps`` is not a multiple of the block
                count, or the batch size is not 1.
        """
        x = torch.cat(
            [
                input_ids,
                torch.full(
                    (input_ids.shape[0], gen_length),
                    mask_id,
                    dtype=torch.long,
                    device=input_ids.device,
                ),
            ],
            dim=1,
        )
        # check the validity of block count
        assert gen_length % block_length == 0, (
            "gen_length must be divisible by block_length"
        )
        gen_blocks = gen_length // block_length
        if gen_blocks == 0:
            # gen_length == 0: nothing to generate (and avoids `steps % 0`).
            return x, 0
        # check the validity of sampling steps
        assert steps % gen_blocks == 0, (
            "steps must be divisible by the number of generation blocks"
        )
        steps_per_block = steps // gen_blocks
        # check bs == 1
        assert x.shape[0] == 1, (
            "Only batch size of 1 is supported for block-wise generation currently."
        )

        prompt_length = input_ids.shape[1]
        # Mask slots needed to complete the block the prompt ends in; this is
        # 0 when the prompt already ends on a block boundary.
        head = (-prompt_length) % block_length
        gen_block_list = [block_length] * gen_blocks
        if head > 0:
            # Unaligned prompt: the first generated block only completes the
            # prompt's last block, and the final block is shortened so the
            # sizes still sum to gen_length and end on the sequence boundary.
            gen_block_list = [head] + gen_block_list
            gen_block_list[-1] = block_length - head

        # The mask must match the attention compute dtype (SDPA requirement).
        block_diffusion_attention_mask = self._make_block_causal_mask(
            prompt_length + gen_length,
            block_length,
            self.device,
            dtype=self.dtype,
        )

        # TODO: better cache initialization method
        past_key_values = DynamicCache()
        nfe = 0
        # Prefill only the block-aligned part of the prompt. An unaligned
        # prompt tail is fed again together with the first generated block.
        prefill_length = prompt_length - prompt_length % block_length
        if prefill_length > 0:
            self(
                x[:, :prefill_length],
                past_key_values=past_key_values,
                attention_mask=block_diffusion_attention_mask[
                    :, :, :prefill_length, :prefill_length
                ],
                use_cache=True,
            )

        last_block_id = len(gen_block_list) - 1
        generated = 0  # generated positions covered by the blocks seen so far
        for block_id, block_size in enumerate(gen_block_list):
            block_start = prompt_length + generated if block_id > 0 else prefill_length
            generated += block_size
            block_end = prompt_length + generated
            block = x[:, block_start:block_end]  # view: writes land in x
            block_attention_mask = block_diffusion_attention_mask[
                ..., block_start:block_end, :block_end
            ]
            cache_position = torch.arange(block_start, block_end, device=x.device)

            # sampling noise schedule
            schedule = self._get_num_transfer_tokens(block == mask_id, steps_per_block)
            for token_count in schedule.tolist():
                if token_count == 0:
                    continue
                nfe += 1
                mask_map = block == mask_id
                logits = self(
                    block,
                    attention_mask=block_attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    cache_position=cache_position,
                ).logits
                # The forward appended this still-unfinished block to the
                # cache; roll it back to the committed prefix.
                past_key_values.crop(block_start)
                x0, transfer_map = self._get_transfer_index(
                    logits,
                    temperature,
                    remasking,
                    mask_map,
                    block,
                    # In threshold mode this is the minimum to commit, so the
                    # block always finishes within its step budget.
                    token_count,
                    threshold,
                    shift=False,
                )
                block[transfer_map] = x0[transfer_map]
                if not (block == mask_id).any():
                    break

            # Look for EOS only in generated positions: the first block may
            # still contain the unaligned tail of the prompt.
            gen_start = max(block_start, prompt_length)
            if eos_id is not None and (x[:, gen_start:block_end] == eos_id).any():
                x = x[:, :block_end]
                eos_hits = (x[0, prompt_length:] == eos_id).nonzero(as_tuple=True)[0]
                eos_pos = prompt_length + eos_hits[0].item()
                # fill the rest of the sequence with eos_id
                x[0, eos_pos + 1 :] = eos_id
                return x, nfe

            # Commit the finished block to the KV cache for the next blocks.
            # The last block has no successor, so that pass would be wasted.
            if block_id < last_block_id:
                nfe += 1
                self(
                    block,
                    attention_mask=block_attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    cache_position=cache_position,
                )
        return x, nfe

    @torch.no_grad()
    def generate(
        self,
        input_ids=None,
        generation_config: GenerationConfig = None,
        **kwargs,
    ):
        """Run block-wise diffusion generation and return only the token ids.

        ``generation_config`` is accepted for signature compatibility with the
        standard ``generate`` API but is not used. All decoding options are
        forwarded as keyword arguments to ``generate_block``.

        Raises:
            ValueError: If ``input_ids`` is not provided.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        output_ids, _nfe = self.generate_block(
            input_ids=input_ids,
            **kwargs,
        )
        return output_ids