import torch
import numpy as np
import torch.nn.functional as F
import time
import re
from collections import Counter

try:
    # Not referenced by the samplers in this module; kept for backward compatibility
    # with code that imports these names from here. Guarded so the samplers, which
    # only need torch, stay importable when `transformers` is not installed.
    from transformers import AutoTokenizer, AutoModel  # noqa: F401
except ImportError:
    AutoTokenizer = AutoModel = None

# Vocabulary ids targeted by the `logits_eos_inf` / `confidence_eos_eot_inf` flags
# (the flag names identify them as the EOS and EOT tokens).
EOS_TOKEN_ID = 126081
EOT_TOKEN_ID = 126348


def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that ``argmax`` over the result draws a sample.

    ``exp(l) / (-log u) ** T`` has the same argmax as ``l + T * Gumbel``, so the
    argmax is a sample from ``softmax(logits / temperature)``.

    With ``temperature == 0`` the input tensor itself is returned (no copy), so
    in-place writes to the result also modify ``logits``.
    """
    if temperature == 0:
        return logits
    # float64 keeps exp() from overflowing for large logits.
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(block_mask_index: torch.Tensor, steps: int) -> torch.Tensor:
    """Spread each row's count of masked positions evenly over ``steps``.

    Returns a ``(B, steps)`` long tensor whose rows sum to the number of True
    entries in the corresponding row of ``block_mask_index``. The first
    ``total % steps`` steps receive one extra token.
    """
    device = block_mask_index.device
    dtype = torch.long
    total = block_mask_index.sum(dim=1)
    base = torch.div(total, steps, rounding_mode='floor')
    rem = total - base * steps
    num_transfer_tokens = base.unsqueeze(1).expand(-1, steps).to(dtype)
    cols = torch.arange(steps, device=device).unsqueeze(0)
    add_mask = cols < rem.unsqueeze(1)
    return num_transfer_tokens + add_mask.to(dtype)


def _token_confidence(logits, x0, remasking, prob_dtype=torch.float64,
                      rand_dtype=torch.float64, suppress_ids=None):
    """Score every position's proposed token ``x0``; higher means "commit sooner".

    - ``"low_confidence"``: softmax probability of ``x0``.
    - ``"top_prob_margin"``: gap between the two largest softmax probabilities.
    - ``"random"``: uniform noise of dtype ``rand_dtype`` (``None`` = torch default).

    ``prob_dtype`` is the dtype the softmax runs in (``None`` = keep the logits'
    dtype). Logits at ``suppress_ids`` are set to ``-inf`` on a copy before the
    softmax, so proposals of those tokens score 0. ``x0`` itself is unchanged, and
    the caller's ``logits`` are never modified.

    Raises:
        NotImplementedError: for an unknown ``remasking`` strategy.
    """
    if remasking == "random":
        return torch.rand(x0.shape, device=x0.device, dtype=rand_dtype)
    if remasking not in ("low_confidence", "top_prob_margin"):
        raise NotImplementedError(remasking)

    scores = logits if prob_dtype is None else logits.to(prob_dtype)
    if suppress_ids:
        if scores is logits:  # .to() may alias; never write into the caller's tensor
            scores = scores.clone()
        scores[..., list(suppress_ids)] = -torch.inf
    p = F.softmax(scores, dim=-1)

    if remasking == "low_confidence":
        return torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
    top2_probs = torch.topk(p, k=2, dim=-1).values
    return top2_probs[..., 0] - top2_probs[..., 1]


# Transfer-position selection (supports "low_confidence", "top_prob_margin", "random").
def get_transfer_index(logits, temperature, remasking, mask_index, x,
                       num_transfer_tokens, threshold=None,
                       confidence_suppress_ids=None):
    """Propose tokens and choose which masked positions to commit this step.

    Args:
        logits: model logits whose last dim is the vocabulary; indexing below
            assumes ``(B, L, V)``.
        temperature: Gumbel sampling temperature (0 = greedy argmax).
        remasking: confidence strategy, see ``_token_confidence``.
        mask_index: bool ``(B, L)``, True where a position may be committed.
        x: current tokens ``(B, L)``; copied through at non-masked positions.
        num_transfer_tokens: per-row count to commit, shape ``(B,)`` or
            ``(B, 1)``. Required when ``threshold`` is None.
        threshold: if given, commit every masked position whose confidence is
            ``>= threshold``, plus the single most confident one per row.
        confidence_suppress_ids: vocabulary ids whose proposals get confidence 0.

    Returns:
        ``(x0, transfer_index)``: proposed tokens (equal to ``x`` outside the
        mask) and a bool mask of positions to commit.

    Raises:
        ValueError: if both ``threshold`` and ``num_transfer_tokens`` are None.
    """
    # 1) Sample proposal x0.
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)

    # 2) Confidence of the chosen tokens.
    x0_p = _token_confidence(logits, x0, remasking,
                             suppress_ids=confidence_suppress_ids)

    # Only masked positions may change; others keep their token and can never win.
    x0 = torch.where(mask_index, x0, x)
    neg_inf = torch.tensor(torch.finfo(x0_p.dtype).min, device=x0_p.device, dtype=x0_p.dtype)
    confidence = torch.where(mask_index, x0_p, neg_inf)

    # 3) Pick positions to transfer.
    if threshold is not None:
        transfer_index = mask_index & (confidence >= threshold)
        # Always commit the most confident position so every call makes progress.
        max_conf_indices = torch.argmax(confidence, dim=1, keepdim=True)
        force_mask = torch.zeros_like(transfer_index).scatter_(1, max_conf_indices, True)
        return x0, (transfer_index | force_mask) & mask_index

    if num_transfer_tokens is None:
        raise ValueError("num_transfer_tokens must be a tensor when threshold is None.")
    if num_transfer_tokens.dim() == 2 and num_transfer_tokens.size(1) == 1:
        num_transfer_tokens = num_transfer_tokens.squeeze(1)
    num_transfer_tokens = num_transfer_tokens.to(dtype=torch.long, device=confidence.device)
    num_transfer_tokens = torch.clamp(num_transfer_tokens, min=0)

    # Vectorised per-row top-k: mark the first k entries in descending order.
    _, idx = torch.sort(confidence, dim=1, descending=True)
    B, L = confidence.shape
    cols = torch.arange(L, device=confidence.device).unsqueeze(0).expand(B, L)
    select_sorted = cols < num_transfer_tokens.unsqueeze(1).expand(B, L)
    transfer_int = torch.zeros(B, L, device=confidence.device, dtype=torch.int8)
    transfer_int = transfer_int.scatter(1, idx, select_sorted.to(torch.int8))
    return x0, transfer_int.bool() & mask_index


# Transfer-position selection, dynamic (confidence-adaptive count) version.
def get_transfer_index_dynamic(logits, temperature, remasking, mask_index, x,
                               num_transfer_tokens, factor=1,
                               confidence_suppress_ids=None):
    """Like ``get_transfer_index``, but choose how many positions to commit from
    the confidences themselves.

    The ``n``-th most confident masked position (0-based ``n``) is kept while its
    confidence is ``>= 1 - factor / (n + 2)``. The most confident position is
    always kept. ``num_transfer_tokens`` is ignored; the per-row count comes
    from ``mask_index``.

    Returns:
        ``(x0, transfer_index)`` as in ``get_transfer_index``.
    """
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)
    # rand_dtype=None keeps the torch-default dtype for "random" remasking.
    x0_p = _token_confidence(logits, x0, remasking, rand_dtype=None,
                             suppress_ids=confidence_suppress_ids)

    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, x0_p, -np.inf)
    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)

    masked_counts = mask_index.sum(dim=1)
    for j in range(confidence.shape[0]):
        num_tokens = int(masked_counts[j].item())
        if num_tokens == 0:
            continue
        sorted_confidence = torch.sort(confidence[j][mask_index[j]], descending=True).values
        top_i = num_tokens
        for i in range(num_tokens):
            thresh = -1 if i == 0 else 1 - factor / (i + 2)
            if sorted_confidence[i] < thresh:
                top_i = i
                break
        top_i = max(top_i, 1)
        select_index = torch.topk(confidence[j], k=top_i).indices
        transfer_index[j, select_index] = True
    return x0, transfer_index


@torch.no_grad()
def generate_standard(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128,
                      temperature=0., cfg_scale=0., remasking='low_confidence', mask_id=126336,
                      logits_eos_inf=False, confidence_eos_eot_inf=False):
    """Fill ``gen_length`` masked positions after ``prompt``, block by block,
    running a full forward pass at every step.

    ``model`` is called as ``model(x, attention_mask=...)`` and must return an
    object with ``.logits``; ``model.device`` is where ``x`` is allocated. Each
    of the ``gen_length // block_length`` blocks gets ``steps // num_blocks``
    steps. At step ``i`` the ``num_transfer_tokens[:, i]`` most confident masked
    positions (never beyond the current block) are committed.

    Args:
        cfg_scale: if > 0, classifier-free guidance. The unconditional pass masks
            the prompt, and logits become ``un + (cfg_scale + 1) * (cond - un)``.
        logits_eos_inf: forbid sampling ``EOS_TOKEN_ID``.
        confidence_eos_eot_inf: proposals of EOS/EOT get confidence 0, so they are
            committed as late as possible.

    Returns:
        long tensor ``(B, prompt_len + gen_length)``.
    """
    prompt_len = prompt.shape[1]
    x = torch.full((prompt.shape[0], prompt_len + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_len] = prompt
    if attention_mask is not None:
        attention_mask = torch.cat(
            [attention_mask,
             torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)],
            dim=-1,
        )
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps = steps // num_blocks

    suppress_ids = (EOS_TOKEN_ID, EOT_TOKEN_ID) if confidence_eos_eot_inf else None

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                # Previously unbound when attention_mask was None (UnboundLocalError).
                attention_mask_ = (None if attention_mask is None
                                   else torch.cat([attention_mask, attention_mask], dim=0))
                logits = model(x_, attention_mask=attention_mask_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits
            if logits_eos_inf:
                logits[:, :, EOS_TOKEN_ID] = -torch.inf

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)
            # Suppression applies to the confidence only (x0 is already sampled). The old
            # chained assignment wrote EOS into the dead noise tensor, so EOS was only
            # suppressed at temperature 0.
            x0_p = _token_confidence(logits, x0, remasking, prob_dtype=None,
                                     rand_dtype=None, suppress_ids=suppress_ids)
            x0_p[:, block_end:] = -np.inf  # never commit beyond the current block

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]
    return x


def _select_transfer(logits, temperature, remasking, mask_index, x, quota,
                     threshold, factor, suppress_ids):
    """Dispatch one decode step to the quota/threshold or the dynamic selector."""
    if factor is None:
        return get_transfer_index(logits, temperature, remasking, mask_index, x,
                                  None if threshold is not None else quota, threshold,
                                  confidence_suppress_ids=suppress_ids)
    return get_transfer_index_dynamic(logits, temperature, remasking, mask_index, x,
                                      None, factor, confidence_suppress_ids=suppress_ids)


@torch.no_grad()
def generate_with_dual_cache(
    model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
    remasking="low_confidence", mask_id=126336, threshold=None, factor=None,
    cfg_scale=0., logits_eos_inf=False, confidence_eos_eot_inf=False, attention_mask=None,
    return_nfe=False,
):
    """Block-wise generation that reuses a KV cache inside each block.

    Per block, one full pass ``model(x, use_cache=True)`` makes the first decode
    step and yields ``past_key_values``. The remaining steps feed only the block
    tokens back together with that cache, passing ``replace_position`` when the
    model accepts it.

    Selection mode: ``factor`` set -> dynamic count; else ``threshold`` set ->
    confidence threshold; else the per-step quotas from
    ``get_num_transfer_tokens`` (``steps // num_blocks`` steps per block).
    Every block is decoded until no ``mask_id`` remains, so threshold/dynamic
    modes may use more than ``steps // num_blocks`` passes per block.

    ``cfg_scale > 0`` falls back to ``generate_standard``, which ignores
    ``threshold`` and ``factor``.

    NOTE(review): ``attention_mask`` is forwarded only on the cfg fallback path;
    the cached path below does not pass it to the model. Confirm whether padded
    batches need it.

    Returns:
        ``x`` of shape ``(B, prompt_len + gen_length)``, or ``(x, nfe)`` when
        ``return_nfe`` is True (``nfe`` = number of model forward passes).
    """
    if cfg_scale > 0:
        print("⚠️ Warning: cfg_scale > 0 is not supported in Dual Cache mode. Falling back to standard generate.")
        out = generate_standard(model, prompt, attention_mask, steps, gen_length, block_length,
                                temperature, cfg_scale, remasking, mask_id,
                                logits_eos_inf, confidence_eos_eot_inf)
        # generate_standard runs exactly one forward pass per step.
        return (out, steps) if return_nfe else out

    B = prompt.shape[0]
    Lp = int(prompt.shape[1])
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    suppress_ids = (EOS_TOKEN_ID, EOT_TOKEN_ID) if confidence_eos_eot_inf else None

    x = torch.full((B, Lp + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :Lp] = prompt
    nfe = 0

    for nb in range(num_blocks):
        s = Lp + nb * block_length
        e = s + block_length
        block_mask_index = (x[:, s:e] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        # 1) Full pass: warms the KV cache and performs step 0 for this block.
        out_full = model(x, use_cache=True)
        past_key_values = out_full.past_key_values
        nfe += 1
        logits_full = out_full.logits
        if logits_eos_inf:
            logits_full[:, :, EOS_TOKEN_ID] = -torch.inf

        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, s:e] = True

        global_mask_index = (x == mask_id)
        global_mask_index[:, e:] = False  # never decode beyond the current block
        x0, transfer_index = _select_transfer(
            logits_full, temperature, remasking, global_mask_index, x,
            num_transfer_tokens[:, 0], threshold, factor, suppress_ids,
        )
        x = torch.where(transfer_index, x0, x)

        # 2) Block-only passes until the block is fully decoded. A fixed step cap
        #    here previously left mask tokens in threshold/dynamic modes, which only
        #    guarantee one commit per step.
        step = 1
        while (x[:, s:e] == mask_id).any():
            try:
                logits_blk = model(
                    x[:, s:e], past_key_values=past_key_values, use_cache=True,
                    replace_position=replace_position,
                ).logits
            except TypeError:
                # Model does not accept `replace_position`; call it without.
                logits_blk = model(x[:, s:e], past_key_values=past_key_values, use_cache=True).logits
            nfe += 1
            if logits_eos_inf:
                logits_blk[:, :, EOS_TOKEN_ID] = -torch.inf

            mask_blk = (x[:, s:e] == mask_id)
            if step < steps_per_block:
                quota = num_transfer_tokens[:, step]
            else:
                # Quotas exhausted (unexpected in quota mode): commit everything left.
                quota = mask_blk.sum(dim=1)
            x0_blk, transfer_idx_blk = _select_transfer(
                logits_blk, temperature, remasking, mask_blk, x[:, s:e],
                quota, threshold, factor, suppress_ids,
            )
            x[:, s:e] = torch.where(transfer_idx_blk, x0_blk, x[:, s:e])
            step += 1

    return (x, nfe) if return_nfe else x


# Alias
generate = generate_standard