import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from dataclasses import dataclass
import os
import math

# ============== Model Architecture ==============

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return self.weight * x

class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE) with NTK extrapolation."""

    def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        self.inv_freq = None
        self._cache = {}

    def _update_freqs(self, device):
        base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.inv_freq = inv_freq

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[-2]
        if self.inv_freq is None or self.inv_freq.device != x.device:
            self._update_freqs(x.device)
        cache_key = (seq_len, x.device, x.dtype)
        if cache_key in self._cache:
            return self._cache[cache_key]
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        self._cache[cache_key] = (cos, sin)
        if len(self._cache) > 10:
            self._cache.pop(next(iter(self._cache)))
        return cos, sin

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embeddings to Q and K."""
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
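
# Shape sketch for the RoPE path above (illustrative sizes, not part of the
# original script): the rotary cache returns cos/sin of shape (1, 1, seq, dim),
# which broadcast against q/k of shape (batch, heads, seq, head_dim).
def _rope_shape_sketch():
    rope = RotaryEmbedding(dim=16)
    q = torch.randn(2, 4, 8, 16)   # (batch, heads, seq, head_dim)
    k = torch.randn(2, 4, 8, 16)
    cos, sin = rope(q, seq_len=8)  # each (1, 1, 8, 16)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape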

class DiffusionAttention(nn.Module):
    """Multi-head attention with GQA and Flash Attention support."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.use_flash_attn = config.use_flash_attn
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
        bsz, q_len, _ = hidden_states.size()
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        cos, sin = freqs_cis
        cos = cos[:, :, :q_len, :]
        sin = sin[:, :, :q_len, :]
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        if past_kv is not None:
            cache_k, cache_v = past_kv
            k = torch.cat([cache_k, k], dim=2)
            v = torch.cat([cache_v, v], dim=2)
        current_kv = (k, v)
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)
        attn_mask = None
        if attention_mask is not None:
            attn_mask = attention_mask[:, None, None, :].to(dtype=q.dtype)
            attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min
        output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
        return self.o_proj(output), current_kv
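
# Worked example of the GQA expansion above (illustrative sizes, not part of the
# original script): with 16 query heads and 4 key/value heads, each K/V head is
# shared by 16 // 4 = 4 query heads, so repeat_interleave grows K/V from 4 to 16
# heads before scaled_dot_product_attention.
def _gqa_expansion_sketch():
    kv = torch.randn(1, 4, 8, 16)                 # (batch, kv_heads, seq, head_dim)
    kv_expanded = kv.repeat_interleave(4, dim=1)  # (batch, 16, seq, head_dim)
    assert kv_expanded.shape == (1, 16, 8, 16)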

class MLP(nn.Module):
    """Gated MLP with SiLU activation."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class BlockDiffusionBlock(nn.Module):
    """Transformer block with pre-norm, attention, and MLP."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = DiffusionAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.use_activation_checkpointing = config.use_activation_checkpointing

    def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)

    def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_out, new_kv = self.self_attn(hidden_states, freqs_cis, attention_mask, past_kv)
        hidden_states = residual + attn_out
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + self.mlp(hidden_states)
        return hidden_states, new_kv

@dataclass
class ModelConfig:
    """Model architecture configuration."""
    vocab_size: int = 151936
    hidden_size: int = 1024
    intermediate_size: int = 2816
    num_hidden_layers: int = 16
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    max_position_embeddings: int = 16384
    rms_norm_eps: float = 1e-6
    rope_theta: float = 100000.0
    pad_token_id: int = 0
    mask_token_id: int = 1
    use_flash_attn: bool = True
    use_activation_checkpointing: bool = False
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0

class DiffusionLLM(nn.Module):
    """Complete diffusion language model."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        pad_idx = config.pad_token_id if config.pad_token_id < config.vocab_size else None
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)
        self.layers = nn.ModuleList([BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rotary_emb = RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
            base=config.rope_theta,
        )
        # Tie the output projection to the input embedding
        self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids, attention_mask=None, past_key_values=None):
        bsz, seqlen = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)
        freqs_cis = self.rotary_emb(hidden_states, seq_len=seqlen)
        if past_key_values is None:
            past_key_values = [None] * len(self.layers)
        new_kvs = []
        for i, layer in enumerate(self.layers):
            hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, past_key_values[i])
            new_kvs.append(kv)
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits, new_kvs

    def get_num_params(self, trainable_only=True):
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        else:
            return sum(p.numel() for p in self.parameters())
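
# A minimal forward-pass sketch (the tiny hyperparameters below are illustrative
# and are not the trained checkpoint's): builds a small DiffusionLLM on CPU and
# checks that the logits have shape (batch, seq, vocab).
def _smoke_test_forward():
    cfg = ModelConfig()
    cfg.vocab_size = 128
    cfg.hidden_size = 64
    cfg.intermediate_size = 128
    cfg.num_hidden_layers = 2
    cfg.num_attention_heads = 4
    cfg.num_key_value_heads = 2
    cfg.max_position_embeddings = 64
    model = DiffusionLLM(cfg)
    input_ids = torch.randint(0, cfg.vocab_size, (1, 16))
    logits, _ = model(input_ids, attention_mask=torch.ones(1, 16))
    assert logits.shape == (1, 16, cfg.vocab_size)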

# ============== Inference Functions ==============

def load_model(model_path: str, device: str = 'cuda'):
    """Load a saved model (fp16 or fp32) for inference."""
    print(f"Loading model from {model_path}...")
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = checkpoint['config']
    model = DiffusionLLM(config)
    state_dict = checkpoint['model_state']
    state_dict = {k: v.float() for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    num_params = model.get_num_params() / 1e6
    file_size = os.path.getsize(model_path) / 1e6
    print(f"✓ Model loaded: {num_params:.1f}M params from {file_size:.1f} MB file")
    return model, config

def visualize_diffusion_state(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
    """Visualize the current state of diffusion generation with multiple blocks.

    Args:
        mask_blocks: Either a single block tensor (1, block_size) or list of block tensors
        is_masked_list: Either a single mask tensor (1, block_size) or list of mask tensors
        block_colors: List of ANSI color codes for each block. If None, uses defaults.
    """
    import sys
    import os

    # Default colors for different blocks (green, cyan, yellow, magenta)
    DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']
    MASK_COLOR = '\033[90m'  # Gray for masked tokens
    RESET = '\033[0m'

    # Normalize inputs to lists
    if not isinstance(mask_blocks, list):
        mask_blocks = [mask_blocks]
        is_masked_list = [is_masked_list]
    if block_colors is None:
        block_colors = DEFAULT_COLORS

    # Decode context (prompt + previously generated blocks) and replace newlines
    context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')

    # Build visualization for all blocks
    all_blocks_text = []
    for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
        color = block_colors[block_idx % len(block_colors)]
        block_tokens = mask_block[0].tolist()
        block_color_tokens = []
        for i, token_id in enumerate(block_tokens):
            if is_masked[0, i]:
                # Render still-masked tokens as gray blocks
                block_color_tokens.append(f'{MASK_COLOR}██{RESET}')
            else:
                # Decode individual token; use block color for revealed tokens
                token_text = tokenizer.decode([token_id], skip_special_tokens=False)
                block_color_tokens.append(f'{color}{token_text}{RESET}')
        all_blocks_text.append(''.join(block_color_tokens))

    # Join all blocks
    blocks_combined = ''.join(all_blocks_text)

    # Clear the terminal before redrawing
    if clear:
        clear_cmd = 'cls' if os.name == 'nt' else 'clear'
        try:
            os.system(clear_cmd)
        except Exception:
            sys.stdout.write('\r\033[K')

    # Print legend for parallel blocks
    if len(mask_blocks) > 1:
        legend_parts = []
        for i in range(len(mask_blocks)):
            color = block_colors[i % len(block_colors)]
            legend_parts.append(f'{color}Block {i+1}{RESET}')
        print(f"Generating: {' | '.join(legend_parts)}\n")

    # Print the full context with colored blocks
    print(f"{context_text}{blocks_combined}", flush=True)

def demo_visualize_truncation():
    """Demo for visualize_diffusion_state without a full model.

    Simulates streaming output so you can check visually that lines are not
    duplicated when the content exceeds the terminal width.
    """
    import time

    class MockTokenizer:
        def __init__(self):
            # Map token id to token text (simple ASCII characters and spaces)
            self.vocab = {i: chr(65 + (i % 26)) for i in range(256)}
            self.vocab[32] = ' '
            self.eos_token = '\n'
            self.pad_token = ' '

        def decode(self, ids, skip_special_tokens=True):
            # ids can be a tensor or a list
            if isinstance(ids, torch.Tensor):
                ids = ids.tolist()
            if isinstance(ids, (list, tuple)):
                return ''.join(self.vocab.get(int(i) % 256, '?') for i in ids)
            return str(ids)

    tok = MockTokenizer()

    # Create a long context and a long block so the output exceeds a typical
    # 80-column terminal width
    long_context_ids = torch.tensor([[i % 26 + 65 for i in range(120)]], dtype=torch.long)
    block_size = 32
    mask_block = torch.full((1, block_size), 32, dtype=torch.long)  # spaces
    is_masked = torch.ones(1, block_size, dtype=torch.bool)
    for i in range(0, block_size, 3):
        is_masked[0, i] = False
        mask_block[0, i] = 65 + (i % 26)

    print('\nRunning demo: long prompt + block to test truncation\n')
    for i in range(8):
        visualize_diffusion_state(tok, long_context_ids, [mask_block], [is_masked], ModelConfig(), clear=(i > 0))
        # Rotate some tokens to simulate diffusion
        mask_block = torch.roll(mask_block, shifts=1, dims=1)
        time.sleep(0.08)
    print('\n\nDemo completed.')

@torch.no_grad()
def generate_block_diffusion(
    model,
    tokenizer,
    prompt: str,
    steps: int = 16,
    block_size: int = 64,
    max_new_tokens: int = 256,
    device: str = 'cuda',
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3,
    visualize: bool = False,
    parallel_blocks: int = 1,  # Number of blocks to generate in parallel
):
    """Generate text using block diffusion with proper sampling and repetition control.

    Args:
        visualize: If True, stream output in real-time showing the diffusion effect.
        parallel_blocks: Number of blocks to generate in parallel (1-4 recommended).
    """
    model.eval()
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    config = model.module.config if hasattr(model, 'module') else model.config
    if hasattr(model, '_orig_mod'):
        config = model._orig_mod.config
    num_blocks = max_new_tokens // block_size
    parallel_blocks = min(parallel_blocks, num_blocks)  # Can't parallelize more than total blocks
    if not visualize:
        if parallel_blocks > 1:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each ({parallel_blocks} blocks in parallel)...")
        else:
            print(f"Generating {num_blocks} blocks of {block_size} tokens each...")
    else:
        print("\n\033[94mStarting diffusion generation...\033[0m\n")
        print(prompt, end='', flush=True)
    context_ids = prompt_ids
    all_generated_tokens = set(prompt_ids[0].tolist())

    # Process blocks in batches of parallel_blocks
    blocks_generated = 0
    while blocks_generated < num_blocks:
        # Determine how many blocks to generate this iteration
        current_parallel = min(parallel_blocks, num_blocks - blocks_generated)
        if current_parallel > 1:
            # Parallel block generation
            generated_blocks = _generate_parallel_blocks(
                model, tokenizer, context_ids, config, device,
                current_parallel, block_size, steps, temperature,
                top_k, top_p, repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )
            # Concatenate all generated blocks to context
            for block in generated_blocks:
                context_ids = torch.cat([context_ids, block], dim=1)
                all_generated_tokens.update(block[0].tolist())
            if not visualize:
                print(f" Blocks {blocks_generated + 1}-{blocks_generated + current_parallel}/{num_blocks} complete")
            blocks_generated += current_parallel
        else:
            # Single block generation (original logic)
            mask_block, block_token_history = _generate_single_block(
                model, tokenizer, context_ids, config, device,
                block_size, steps, temperature, top_k, top_p,
                repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize
            )
            context_ids = torch.cat([context_ids, mask_block], dim=1)
            all_generated_tokens.update(mask_block[0].tolist())
            if not visualize:
                print(f" Block {blocks_generated + 1}/{num_blocks} complete")
            blocks_generated += 1

    if visualize:
        # Final newline after visualization
        print("\n")
    generated_ids = context_ids[0].tolist()
    return tokenizer.decode(generated_ids, skip_special_tokens=True)

def _generate_single_block(
    model, tokenizer, context_ids, config, device,
    block_size, steps, temperature, top_k, top_p,
    repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize
):
    """Generate a single block using diffusion."""
    mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
    block_token_history = []
    for step_idx in range(steps):
        full_input = torch.cat([context_ids, mask_block], dim=1)
        attention_mask = torch.ones_like(full_input, dtype=torch.float32)
        logits, _ = model(full_input, attention_mask=attention_mask)
        block_logits = logits[:, -block_size:, :]
        block_logits = _apply_sampling_controls(
            block_logits, context_ids, mask_block, is_masked,
            repetition_penalty, temperature, top_k, top_p,
            no_repeat_ngram_size, block_token_history
        )
        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)
        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(1, block_size)
        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
        tokens_to_unmask = max(1, block_size // steps)
        if step_idx == steps - 1:
            tokens_to_unmask = is_masked.sum().item()
        if tokens_to_unmask > 0 and is_masked.sum() > 0:
            masked_confidence = confidence.clone()
            masked_confidence[~is_masked] = -1.0
            num_to_unmask = min(tokens_to_unmask, is_masked.sum().item())
            _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
            for idx in top_indices:
                mask_block[0, idx] = sampled_tokens[0, idx]
                is_masked[0, idx] = False
                block_token_history.append(sampled_tokens[0, idx].item())
                all_generated_tokens.add(sampled_tokens[0, idx].item())
        if visualize:
            visualize_diffusion_state(tokenizer, context_ids, [mask_block], [is_masked], config, clear=(step_idx > 0))
    return mask_block, block_token_history
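
# Unmasking-schedule sketch for the single-block loop above (illustrative numbers,
# not part of the original script): with block_size=64 and steps=16, each
# denoising step reveals max(1, 64 // 16) = 4 of the highest-confidence tokens,
# and the final step reveals whatever is still masked.
def _unmask_schedule_sketch():
    block_size, steps = 64, 16
    per_step = max(1, block_size // steps)  # 4 tokens revealed per step
    assert per_step == 4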

def _generate_parallel_blocks(
    model, tokenizer, context_ids, config, device,
    num_parallel, block_size, steps, temperature,
    top_k, top_p, repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize
):
    """Generate multiple blocks in parallel using batched computation.

    Each block sees all previous blocks in the sequence, maintaining proper order:
    - Block 0: context + [block0]
    - Block 1: context + [block0] + [block1]
    - Block 2: context + [block0] + [block1] + [block2]
    - etc.

    This ensures sequential coherence while still benefiting from batched computation.
    """
    batch_size = num_parallel
    context_len = context_ids.shape[1]
    # Initialize mask blocks for all parallel blocks
    # Shape: (num_parallel, block_size)
    mask_blocks = torch.full((batch_size, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
    block_token_histories = [[] for _ in range(batch_size)]

    for step_idx in range(steps):
        # Build inputs with proper sequential structure.
        # Each batch item has context + all blocks up to and including its own position:
        # block i sees context + block_0 + block_1 + ... + block_i.
        # Create padded inputs - each batch item has a different length, so pad
        # to the longest sequence (which is the last block's).
        max_seq_len = context_len + (num_parallel * block_size)

        # Build full input for each batch item
        full_inputs = []
        attention_masks = []
        for b in range(batch_size):
            # This block sees: context + all previous blocks + its own block
            seq_parts = [context_ids[0]]  # Start with context
            # Add all blocks from 0 to b (inclusive)
            for prev_b in range(b + 1):
                seq_parts.append(mask_blocks[prev_b])
            # Concatenate to form this batch item's input
            batch_input = torch.cat(seq_parts, dim=0)  # (seq_len,)
            current_len = batch_input.shape[0]
            # Pad to max_seq_len
            padding_needed = max_seq_len - current_len
            if padding_needed > 0:
                padding = torch.full((padding_needed,), config.pad_token_id, device=device)
                batch_input = torch.cat([batch_input, padding], dim=0)
            full_inputs.append(batch_input)
            # Create attention mask (1 for real tokens, 0 for padding)
            attn_mask = torch.zeros(max_seq_len, device=device)
            attn_mask[:current_len] = 1.0
            attention_masks.append(attn_mask)

        # Stack into batched tensors
        full_input = torch.stack(full_inputs, dim=0)          # (batch, max_seq_len)
        attention_mask = torch.stack(attention_masks, dim=0)  # (batch, max_seq_len)

        # Single forward pass for all blocks
        logits, _ = model(full_input, attention_mask=attention_mask)

        # Extract logits for each block's position
        # Block b's logits are at positions [context_len + b*block_size : context_len + (b+1)*block_size]
        block_logits_list = []
        for b in range(batch_size):
            start_pos = context_len + (b * block_size)
            end_pos = start_pos + block_size
            block_logits_list.append(logits[b, start_pos:end_pos, :])
        block_logits = torch.stack(block_logits_list, dim=0)  # (batch, block_size, vocab)

        # Apply sampling controls per batch item
        for b in range(batch_size):
            # Build context that includes previous blocks for repetition penalty
            extended_context = context_ids
            if b > 0:
                prev_blocks = torch.cat([mask_blocks[pb:pb+1] for pb in range(b)], dim=1)
                extended_context = torch.cat([context_ids, prev_blocks], dim=1)
            block_logits[b:b+1] = _apply_sampling_controls(
                block_logits[b:b+1],
                extended_context,
                mask_blocks[b:b+1],
                is_masked[b:b+1],
                repetition_penalty, temperature, top_k, top_p,
                no_repeat_ngram_size, block_token_histories[b]
            )

        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        # Sample for all batch items
        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(batch_size, block_size)
        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        tokens_to_unmask = max(1, block_size // steps)
        if step_idx == steps - 1:
            tokens_to_unmask = block_size  # Unmask all remaining

        # Unmask for each batch item
        for b in range(batch_size):
            if is_masked[b].sum() > 0:
                masked_confidence = confidence[b].clone()
                masked_confidence[~is_masked[b]] = -1.0
                num_to_unmask = min(tokens_to_unmask, is_masked[b].sum().item())
                if num_to_unmask > 0:
                    _, top_indices = torch.topk(masked_confidence, num_to_unmask)
                    for idx in top_indices:
                        mask_blocks[b, idx] = sampled_tokens[b, idx]
                        is_masked[b, idx] = False
                        block_token_histories[b].append(sampled_tokens[b, idx].item())

        if visualize:
            # Visualize all blocks with different colors
            block_list = [mask_blocks[b:b+1] for b in range(batch_size)]
            is_masked_list = [is_masked[b:b+1] for b in range(batch_size)]
            visualize_diffusion_state(
                tokenizer, context_ids, block_list, is_masked_list,
                config, clear=(step_idx > 0)
            )

    # Return list of generated blocks
    return [mask_blocks[b:b+1] for b in range(batch_size)]
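
# Position-layout sketch for the parallel path above (illustrative numbers, not
# part of the original script): with a 10-token context, block_size=4, and 3
# parallel blocks, every batch item is padded to 10 + 3*4 = 22 tokens and block b
# occupies positions [10 + 4*b, 10 + 4*(b+1)), the slice taken from the logits.
def _parallel_layout_sketch():
    context_len, block_size, num_parallel = 10, 4, 3
    max_seq_len = context_len + num_parallel * block_size
    slices = [(context_len + b * block_size, context_len + (b + 1) * block_size)
              for b in range(num_parallel)]
    assert max_seq_len == 22
    assert slices == [(10, 14), (14, 18), (18, 22)]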

def _apply_sampling_controls(
    block_logits, context_ids, mask_block, is_masked,
    repetition_penalty, temperature, top_k, top_p,
    no_repeat_ngram_size, block_token_history
):
    """Apply repetition penalty, temperature, top-k, top-p, and n-gram blocking."""
    if repetition_penalty != 1.0:
        seen_tokens = set(context_ids[0].tolist())
        for i in range(mask_block.shape[1]):
            if not is_masked[0, i]:
                seen_tokens.add(mask_block[0, i].item())
        for token_id in seen_tokens:
            if token_id < block_logits.shape[-1]:
                if block_logits[0, :, token_id].mean() > 0:
                    block_logits[:, :, token_id] /= repetition_penalty
                else:
                    block_logits[:, :, token_id] *= repetition_penalty
    block_logits = block_logits / temperature
    if top_k > 0:
        top_k_logits, top_k_indices = torch.topk(block_logits, top_k, dim=-1)
        block_logits = torch.full_like(block_logits, float('-inf'))
        block_logits.scatter_(-1, top_k_indices, top_k_logits)
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        block_logits[indices_to_remove] = float('-inf')
    if no_repeat_ngram_size > 0 and len(block_token_history) >= no_repeat_ngram_size - 1:
        recent_ngram = tuple(block_token_history[-(no_repeat_ngram_size - 1):])
        full_history = context_ids[0].tolist() + block_token_history
        for i in range(len(full_history) - no_repeat_ngram_size + 1):
            if tuple(full_history[i:i + no_repeat_ngram_size - 1]) == recent_ngram:
                blocked_token = full_history[i + no_repeat_ngram_size - 1]
                if blocked_token < block_logits.shape[-1]:
                    block_logits[:, :, blocked_token] = float('-inf')
    # Safety check: if all logits are -inf, reset to a uniform distribution
    all_inf_mask = torch.isinf(block_logits).all(dim=-1)
    if all_inf_mask.any():
        block_logits[all_inf_mask] = 0.0
    return block_logits
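
# A minimal sketch of the sampling controls above (the toy vocabulary size and
# parameter values are illustrative, not the script's defaults): filters a
# (1, block_size, vocab) logits tensor with repetition penalty, temperature,
# top-k, and top-p, then renormalizes into probabilities.
def _sampling_controls_sketch():
    logits = torch.randn(1, 1, 10)               # (batch, block_size=1, vocab=10)
    context = torch.tensor([[2, 5]])             # previously seen token ids
    block = torch.zeros(1, 1, dtype=torch.long)  # current block (still masked)
    masked = torch.ones(1, 1, dtype=torch.bool)
    filtered = _apply_sampling_controls(
        logits, context, block, masked,
        repetition_penalty=1.2, temperature=0.8, top_k=3, top_p=0.9,
        no_repeat_ngram_size=0, block_token_history=[],
    )
    probs = F.softmax(filtered, dim=-1)          # only the surviving tokens carry mass
    assert probs.shape == (1, 1, 10)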

# ============== Main Entry Point ==============

def main():
    """Main inference function."""
    # Configuration
    model_path = "../extra-final-boss/checkpoints/model_fp32.pt"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Allow a quick demo mode to test visualization without loading the model
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == 'demo':
        demo_visualize_truncation()
        return

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    model, config = load_model(model_path, device)

    # Generate text
    print("\n" + "=" * 50)
    print("Text Generation")
    print("=" * 50)
    prompt = "Barack Obama was born in "
    print(f"Prompt: {prompt}\n")

    # Set visualize=True to see the real-time diffusion effect
    visualize = True
    parallel_blocks = 4  # Generate 2-4 blocks in parallel for speedup
    generated = generate_block_diffusion(
        model,
        tokenizer,
        prompt=prompt,
        steps=64,
        block_size=64,
        max_new_tokens=512,
        device=device,
        temperature=1.0,
        top_k=40,
        top_p=0.9,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        visualize=visualize,
        parallel_blocks=parallel_blocks,
    )
    print(f"\nGenerated text:\n{generated}")


if __name__ == "__main__":
    main()