#!/usr/bin/env python3
"""
Inference script for DiffusionQwen3 model checkpoint.

Usage:
    # Interactive chat mode
    python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 --mode chat

    # Single prompt completion
    python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 --prompt "def fibonacci(n):"

    # With custom generation parameters
    python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 \
        --prompt "Write a hello world in Python" \
        --steps 128 --temperature 0.0 --max-tokens 256
"""

import argparse
import os
import sys
from typing import List, Optional, Tuple

import torch
import torch.distributions as dists
import torch.nn.functional as F
from transformers import AutoTokenizer, PreTrainedModel


# ============================================================================
# Diffusion Sampling Utilities (adapted from CoDALanguageModel/generation_utils.py)
# ============================================================================

def top_p_logits(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Apply nucleus (top-p) filtering to logits."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Mask tokens whose cumulative probability exceeds top_p, shifted right so
    # the first token above the threshold is always kept.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits


def top_k_logits(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Apply top-k filtering to logits."""
    top_k = min(top_k, logits.size(-1))
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(
    logits: torch.Tensor,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    neg_entropy: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample tokens from logits with optional temperature, top-p, and top-k.

    Returns:
        confidence: Confidence scores for the sampled tokens
        x0: Sampled token IDs
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1.0:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Fall back to greedy decoding if sampling fails (e.g. degenerate probs)
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if neg_entropy:
        # Use negative entropy as confidence (for entropy-based sampling)
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
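
# Illustrative sanity check for the sampling utilities above. This is a
# minimal sketch and is not called anywhere in the inference path: with
# temperature=0 the sampler is greedy, so the argmax token of each row
# should come back together with its softmax probability as the confidence.
def _demo_sample_tokens():
    logits = torch.tensor([[2.0, 1.0, 0.1, -1.0],
                           [0.0, 0.0, 0.0, 5.0]])
    confidence, x0 = sample_tokens(logits, temperature=0.0)
    assert x0.tolist() == [0, 3]         # greedy picks the per-row argmax
    assert torch.all(confidence <= 1.0)  # confidences are probabilities here
    # With neg_entropy=True the confidence is -H(p) <= 0; values closer to 0
    # indicate a more peaked (more certain) predictive distribution.
    neg_ent, _ = sample_tokens(logits, neg_entropy=True)
    assert torch.all(neg_ent <= 0)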

# ============================================================================
# Diffusion Generation
# ============================================================================

@torch.no_grad()
def diffusion_generate(
    model: PreTrainedModel,
    input_ids: torch.LongTensor,
    mask_token_id: int,
    max_new_tokens: int = 128,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: Optional[float] = 0.1,
    eps: float = 1e-3,
    verbose: bool = False,
) -> torch.LongTensor:
    """
    Generate text using discrete diffusion.

    Args:
        model: The diffusion language model
        input_ids: Input token IDs (prompt) [batch_size, prompt_len]
        mask_token_id: Token ID of the mask token
        max_new_tokens: Maximum number of new tokens to generate
        steps: Number of diffusion steps
        temperature: Sampling temperature (0 = greedy)
        top_p: Nucleus sampling threshold
        top_k: Top-k sampling threshold
        alg: Sampling algorithm ("origin", "entropy", "maskgit_plus", "topk_margin")
        alg_temp: Algorithm-specific temperature for confidence weighting
        eps: Terminal timestep of the schedule (the grid runs from 1 down to eps)
        verbose: Print progress during generation

    Returns:
        Generated token sequence [batch_size, prompt_len + max_new_tokens]
    """
    device = input_ids.device
    batch_size = input_ids.shape[0]

    # Initialize sequence: prompt followed by mask tokens for the completion
    x = F.pad(input_ids, (0, max_new_tokens), value=mask_token_id)

    # Linear timestep grid from 1 down to eps
    timesteps = torch.linspace(1, eps, steps + 1, device=device)

    for i in range(steps):
        mask_index = (x == mask_token_id)
        if not mask_index.any():
            if verbose:
                print(f"Step {i}: No more masked tokens, stopping early")
            break

        # Forward pass
        outputs = model(x, return_logits_only=True)
        if hasattr(outputs, "logits"):
            logits = outputs.logits
        elif isinstance(outputs, tuple):
            logits = outputs[0]
        else:
            logits = outputs

        # Shift logits so position i holds the prediction for token i
        # (next-token style; position 0 reuses its own logits)
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

        # Keep logits only for the masked positions
        mask_logits = logits[mask_index]

        t = timesteps[i]
        s = timesteps[i + 1]

        if alg == "origin":
            # Original diffusion: random unmasking with probability 1 - s/t
            p_transfer = 1 - s / t if i < steps - 1 else 1
            x0 = torch.zeros_like(x[mask_index], dtype=torch.long) + mask_token_id
            transfer_index = torch.rand(*x0.shape, device=device) < p_transfer
            _, x0[transfer_index] = sample_tokens(
                mask_logits[transfer_index], temperature=temperature, top_p=top_p, top_k=top_k
            )
            x[mask_index] = x0.clone()
        else:
            # Confidence-based unmasking algorithms
            if alg == "maskgit_plus":
                confidence, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
                )
            elif alg == "topk_margin":
                # Margin confidence: difference between the top-2 probabilities
                probs = F.softmax(mask_logits / (temperature if temperature > 0 else 1), dim=-1)
                sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
                confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
                _, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
                )
            elif alg == "entropy":
                confidence, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, neg_entropy=True
                )
            else:
                raise ValueError(f"Unknown algorithm: {alg}")

            # Determine how many tokens to unmask this step (per sequence)
            num_mask_token = mask_index.sum() / batch_size
            num_transfer = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)

            if num_transfer > 0:
                # Confidence over the full sequence; -inf at unmasked positions
                full_confidence = torch.full_like(x, -torch.inf, dtype=logits.dtype)
                full_confidence[mask_index] = confidence

                # Select the most confident positions to unmask
                if alg_temp is None or alg_temp == 0:
                    _, transfer_index = torch.topk(full_confidence, num_transfer)
                else:
                    # Stochastic selection with temperature
                    conf_probs = F.softmax(full_confidence / alg_temp, dim=-1)
                    transfer_index = torch.multinomial(conf_probs, num_samples=num_transfer)

                # Candidate tensor holding the predicted tokens
                x_candidate = torch.zeros_like(x, dtype=torch.long) + mask_token_id
                x_candidate[mask_index] = x0.clone()

                # Update only the selected positions
                row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(transfer_index)
                x[row_indices, transfer_index] = x_candidate[row_indices, transfer_index]

        if verbose and (i + 1) % max(1, steps // 10) == 0:
            remaining_masks = (x == mask_token_id).sum().item()
            print(f"Step {i+1}/{steps}: {remaining_masks} masked tokens remaining")

    return x
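
# Worked example of the unmasking schedule implemented above (illustrative
# only, not called during inference). On the linear grid t_1 .. t_{steps+1},
# each step reveals a fraction 1 - s/t of the still-masked tokens, and the
# final step reveals whatever remains.
def _demo_unmask_schedule(steps: int = 8, num_masked: int = 64, eps: float = 1e-3):
    timesteps = torch.linspace(1, eps, steps + 1)
    remaining = num_masked
    for i in range(steps):
        t, s = timesteps[i], timesteps[i + 1]
        num_transfer = int(remaining * (1 - s / t)) if i < steps - 1 else remaining
        print(f"step {i}: unmask {num_transfer} of {remaining}")
        remaining -= num_transfer
    assert remaining == 0  # every position is revealed by the last step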

# ============================================================================
# Model Loading
# ============================================================================

def load_model_and_tokenizer(
    checkpoint_path: str,
    device: str = "auto",
    torch_dtype: str = "bfloat16",
) -> Tuple[PreTrainedModel, AutoTokenizer, dict]:
    """
    Load the diffusion model and tokenizer from a checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint directory
        device: Device to load the model on ("auto", "cuda", "cpu")
        torch_dtype: Data type for model weights

    Returns:
        model: Loaded model
        tokenizer: Loaded tokenizer
        config: Model configuration dict
    """
    import json

    from transformers import Qwen2Config, Qwen2ForCausalLM

    # Determine device
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Resolve dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map.get(torch_dtype, torch.bfloat16)
    if device == "cpu" and dtype == torch.bfloat16:
        print("Warning: bfloat16 on CPU may be slow, using float32")
        dtype = torch.float32

    print(f"Loading model from {checkpoint_path}...")
    print(f"  Device: {device}, Dtype: {dtype}")

    # Load config
    config_path = os.path.join(checkpoint_path, "config.json")
    with open(config_path, "r") as f:
        config_dict = json.load(f)

    # Import the model class from the local package
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from models.diffusion_qwen import DiffusionQwen3Config, DiffusionQwen3Model

    # Create diffusion config
    diff_config = DiffusionQwen3Config(**config_dict)

    # Create a Qwen2Config to initialize the base model architecture
    qwen_config = Qwen2Config(
        vocab_size=diff_config.vocab_size,
        hidden_size=diff_config.hidden_size,
        intermediate_size=diff_config.intermediate_size,
        num_hidden_layers=diff_config.num_hidden_layers,
        num_attention_heads=diff_config.num_attention_heads,
        num_key_value_heads=diff_config.num_key_value_heads,
        max_position_embeddings=diff_config.max_position_embeddings,
        rms_norm_eps=diff_config.rms_norm_eps,
        rope_theta=diff_config.rope_theta,
        hidden_act=diff_config.hidden_act,
        attention_dropout=diff_config.attention_dropout,
        use_sliding_window=False,
        pad_token_id=diff_config.pad_token_id,
        bos_token_id=diff_config.bos_token_id,
        eos_token_id=diff_config.eos_token_id,
    )

    # Create the DiffusionQwen3Model with the proper architecture
    model = DiffusionQwen3Model(diff_config)

    # Initialize the base Qwen2 model architecture
    print("  Initializing model architecture...")
    base_model = Qwen2ForCausalLM(qwen_config)
    model._init_from_qwen(base_model)
    del base_model  # Free memory

    # Locate the weights file
    weights_path = os.path.join(checkpoint_path, "pytorch_model.bin")
    if not os.path.exists(weights_path):
        # Fall back to model.safetensors
        weights_path = os.path.join(checkpoint_path, "model.safetensors")

    print(f"  Loading weights from {weights_path}...")
    if weights_path.endswith(".safetensors"):
        # .safetensors files cannot be read by torch.load
        from safetensors.torch import load_file
        state_dict = load_file(weights_path)
    else:
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

    # Handle potential key mismatches
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"  Warning: Missing keys ({len(missing)}): {missing[:3]}{'...' if len(missing) > 3 else ''}")
    if unexpected:
        print(f"  Warning: Unexpected keys ({len(unexpected)}): {unexpected[:3]}{'...' if len(unexpected) > 3 else ''}")

    # Move to device and set eval mode
    model = model.to(device=device, dtype=dtype)
    model.eval()

    # Disable causal attention: the diffusion model is bidirectional
    model._disable_causal_masking()

    # Load tokenizer
    print("  Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)

    # Ensure the mask token is set
    if tokenizer.mask_token_id is None:
        tokenizer.mask_token_id = config_dict.get("mask_token_id", 151665)

    print("  Model loaded successfully!")
    print(f"  Vocab size: {diff_config.vocab_size}")
    print(f"  Hidden size: {diff_config.hidden_size}")
    print(f"  Num layers: {diff_config.num_hidden_layers}")
    print(f"  Mask token ID: {diff_config.mask_token_id}")

    return model, tokenizer, config_dict
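
# Expected checkpoint directory layout for load_model_and_tokenizer. This is
# an assumption inferred from the paths probed above; the exact tokenizer
# files depend on how the checkpoint was saved:
#
#   checkpoint-1000/
#       config.json            # DiffusionQwen3Config fields, incl. mask_token_id
#       pytorch_model.bin      # or model.safetensors
#       tokenizer_config.json  # plus tokenizer.json / vocab files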

# ============================================================================
# Generation Wrapper
# ============================================================================

def generate(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: float = 0.1,
    verbose: bool = False,
) -> str:
    """
    Generate text from a prompt.

    Args:
        model: The diffusion language model
        tokenizer: The tokenizer
        prompt: Input prompt text
        max_new_tokens: Maximum tokens to generate
        steps: Diffusion steps
        temperature: Sampling temperature
        top_p: Nucleus sampling threshold
        top_k: Top-k sampling threshold
        alg: Sampling algorithm
        alg_temp: Algorithm temperature
        verbose: Print progress

    Returns:
        Generated text (prompt + completion)
    """
    device = next(model.parameters()).device

    # Tokenize prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Get mask token ID
    mask_token_id = getattr(model.config, "mask_token_id", tokenizer.mask_token_id)
    if mask_token_id is None:
        mask_token_id = 151665  # Default from config

    # Generate
    output_ids = diffusion_generate(
        model=model,
        input_ids=input_ids,
        mask_token_id=mask_token_id,
        max_new_tokens=max_new_tokens,
        steps=steps,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        alg=alg,
        alg_temp=alg_temp,
        verbose=verbose,
    )

    # Filter out mask and pad tokens
    output_ids = output_ids[0]  # Remove batch dimension
    pad_token_id = tokenizer.pad_token_id or 151643
    output_ids = output_ids[output_ids != mask_token_id]
    output_ids = output_ids[output_ids != pad_token_id]

    # Decode
    generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return generated_text
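
# Minimal usage sketch for generate(), assuming a checkpoint saved by the
# training script (the checkpoint path below matches the module docstring):
#
#   model, tokenizer, _ = load_model_and_tokenizer("./outputs/pretrain/checkpoint-1000")
#   text = generate(model, tokenizer, "def fibonacci(n):",
#                   max_new_tokens=64, steps=64, temperature=0.0)
#   print(text)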

def chat_generate(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    messages: List[dict],
    max_new_tokens: int = 256,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: float = 0.1,
    verbose: bool = False,
) -> str:
    """
    Generate a chat response from the conversation history.

    Args:
        model: The diffusion language model
        tokenizer: The tokenizer
        messages: List of message dicts with 'role' and 'content'
        Other args: Same as generate()

    Returns:
        Assistant response text
    """
    device = next(model.parameters()).device

    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]

    # Get mask token ID
    mask_token_id = getattr(model.config, "mask_token_id", tokenizer.mask_token_id)
    if mask_token_id is None:
        mask_token_id = 151665

    # Generate
    output_ids = diffusion_generate(
        model=model,
        input_ids=input_ids,
        mask_token_id=mask_token_id,
        max_new_tokens=max_new_tokens,
        steps=steps,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        alg=alg,
        alg_temp=alg_temp,
        verbose=verbose,
    )

    # Keep only the generated tokens (after the prompt)
    generated_ids = output_ids[0, prompt_len:]

    # Filter out mask and pad tokens
    pad_token_id = tokenizer.pad_token_id or 151643
    generated_ids = generated_ids[generated_ids != mask_token_id]
    generated_ids = generated_ids[generated_ids != pad_token_id]

    # Decode
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return response
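
# The messages argument follows the standard chat-template format, e.g.:
#
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Write a hello world in Python."},
#   ]
#   reply = chat_generate(model, tokenizer, messages, max_new_tokens=256)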

# ============================================================================
# Interactive Chat
# ============================================================================

def interactive_chat(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    system_prompt: str = "You are a helpful assistant.",
    **gen_kwargs,
):
    """Run an interactive chat session."""
    print("\n" + "=" * 60)
    print("Interactive Chat Mode")
    print("=" * 60)
    print("Commands:")
    print("  /exit or /quit   - Exit the chat")
    print("  /reset           - Reset conversation history")
    print("  /system <prompt> - Set a new system prompt")
    print("=" * 60 + "\n")

    messages = [{"role": "system", "content": system_prompt}]

    while True:
        try:
            user_input = input("\033[92mYou: \033[0m").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

        if not user_input:
            continue

        # Handle commands
        if user_input.lower() in ["/exit", "/quit"]:
            print("Goodbye!")
            break
        if user_input.lower() == "/reset":
            messages = [{"role": "system", "content": system_prompt}]
            print("\033[90mConversation reset.\033[0m")
            continue
        if user_input.lower().startswith("/system "):
            system_prompt = user_input[8:].strip()
            messages = [{"role": "system", "content": system_prompt}]
            print("\033[90mSystem prompt updated.\033[0m")
            continue

        # Add user message
        messages.append({"role": "user", "content": user_input})

        # Generate response
        print("\033[94mAssistant: \033[0m", end="", flush=True)
        try:
            response = chat_generate(
                model=model,
                tokenizer=tokenizer,
                messages=messages,
                **gen_kwargs,
            )
            print(response)
            messages.append({"role": "assistant", "content": response})
        except Exception as e:
            print(f"\033[91mError: {e}\033[0m")
            messages.pop()  # Remove the failed user message


# ============================================================================
# Main
# ============================================================================

def main():
    parser = argparse.ArgumentParser(
        description="Run inference with DiffusionQwen3 model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Model arguments
    parser.add_argument(
        "--checkpoint", "-c",
        type=str,
        default="./outputs/pretrain/checkpoint-1000",
        help="Path to model checkpoint directory",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cuda", "cpu"],
        help="Device to run on",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model data type",
    )

    # Generation mode
    parser.add_argument(
        "--mode", "-m",
        type=str,
        default="prompt",
        choices=["prompt", "chat"],
        help="Generation mode: 'prompt' for single completion, 'chat' for interactive",
    )
    parser.add_argument(
        "--prompt", "-p",
        type=str,
        default=None,
        help="Input prompt for single completion mode",
    )
    parser.add_argument(
        "--system",
        type=str,
        default="You are a helpful assistant.",
        help="System prompt for chat mode",
    )

    # Generation parameters
    parser.add_argument("--max-tokens", type=int, default=256, help="Max tokens to generate")
    parser.add_argument("--steps", type=int, default=128, help="Diffusion steps")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    parser.add_argument("--top-p", type=float, default=None, help="Nucleus sampling threshold")
    parser.add_argument("--top-k", type=int, default=None, help="Top-k sampling")
    parser.add_argument(
        "--alg",
        type=str,
        default="entropy",
        choices=["origin", "entropy", "maskgit_plus", "topk_margin"],
        help="Diffusion sampling algorithm",
    )
    parser.add_argument("--alg-temp", type=float, default=0.1, help="Algorithm temperature")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")

    args = parser.parse_args()

    # Load model
    model, tokenizer, config = load_model_and_tokenizer(
        args.checkpoint,
        device=args.device,
        torch_dtype=args.dtype,
    )

    # Generation kwargs
    gen_kwargs = {
        "max_new_tokens": args.max_tokens,
        "steps": args.steps,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "alg": args.alg,
        "alg_temp": args.alg_temp,
        "verbose": args.verbose,
    }

    if args.mode == "chat":
        interactive_chat(model, tokenizer, system_prompt=args.system, **gen_kwargs)
    else:
        # Single prompt mode
        if args.prompt is None:
            # Default demo prompts
            prompts = [
                "def fibonacci(n):",
                "Write a Python function to check if a number is prime:",
                "# Calculate the factorial of a number\ndef factorial(n):",
            ]
            print("\nNo prompt provided. Running demo with sample prompts...\n")
            for prompt in prompts:
                print("=" * 60)
                print(f"Prompt: {prompt}")
                print("-" * 60)
                result = generate(model, tokenizer, prompt, **gen_kwargs)
                print(f"Generated:\n{result}")
                print("=" * 60 + "\n")
        else:
            result = generate(model, tokenizer, args.prompt, **gen_kwargs)
            print("\n" + "=" * 60)
            print("Generated:")
            print("=" * 60)
            print(result)
            print("=" * 60)


if __name__ == "__main__":
    main()