#!/usr/bin/env python3
"""
Inference script for DiffusionQwen3 model checkpoint.
Usage:
# Interactive chat mode
python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 --mode chat
# Single prompt completion
python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 --prompt "def fibonacci(n):"
# With custom generation parameters
python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 \
--prompt "Write a hello world in Python" \
--steps 128 --temperature 0.0 --max-tokens 256
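
    # Programmatic usage from Python (a minimal sketch; assumes the same
    # checkpoint layout as above)
    from inference import load_model_and_tokenizer, generate
    model, tokenizer, _ = load_model_and_tokenizer("./outputs/pretrain/checkpoint-1000")
    print(generate(model, tokenizer, "def fibonacci(n):", steps=64))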
"""
import argparse
import sys
import os
from typing import Optional, Tuple, List
import torch
import torch.nn.functional as F
import torch.distributions as dists
from transformers import AutoTokenizer, PreTrainedModel
# ============================================================================
# Diffusion Sampling Utilities (adapted from CoDALanguageModel/generation_utils.py)
# ============================================================================
def top_p_logits(logits: torch.Tensor, top_p: float) -> torch.Tensor:
"""Apply nucleus (top-p) filtering to logits."""
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
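    # Shift the removal mask right so the first above-threshold token is kept;
    # the nucleus therefore always contains at least one token.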
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask = torch.zeros_like(logits, dtype=torch.bool)
mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
return logits
def top_k_logits(logits: torch.Tensor, top_k: int) -> torch.Tensor:
"""Apply top-k filtering to logits."""
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
return logits
def sample_tokens(
logits: torch.Tensor,
temperature: float = 0.0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
neg_entropy: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample tokens from logits with optional temperature, top-p, and top-k.
Returns:
confidence: Confidence scores for sampled tokens
x0: Sampled token IDs
"""
if temperature > 0:
logits = logits / temperature
if top_p is not None and top_p < 1.0:
logits = top_p_logits(logits, top_p)
if top_k is not None:
logits = top_k_logits(logits, top_k)
probs = torch.softmax(logits, dim=-1)
if temperature > 0:
try:
x0 = dists.Categorical(probs=probs).sample()
confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
confidence, x0 = probs.max(dim=-1)
else:
confidence, x0 = probs.max(dim=-1)
if neg_entropy:
# Use negative entropy as confidence (for entropy-based sampling)
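        # sum(p * log p) equals -H(p): peaked (low-entropy) distributions score
        # closer to 0 and therefore rank as more confident.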
epsilon = 1e-10
log_probs = torch.log(probs + epsilon)
confidence = torch.sum(probs * log_probs, dim=-1)
return confidence, x0
# ============================================================================
# Diffusion Generation
# ============================================================================
@torch.no_grad()
def diffusion_generate(
model: PreTrainedModel,
input_ids: torch.LongTensor,
mask_token_id: int,
max_new_tokens: int = 128,
steps: int = 128,
temperature: float = 0.0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
alg: str = "entropy",
alg_temp: Optional[float] = 0.1,
eps: float = 1e-3,
verbose: bool = False,
) -> torch.LongTensor:
"""
Generate text using discrete diffusion.
Args:
model: The diffusion language model
input_ids: Input token IDs (prompt) [batch_size, prompt_len]
mask_token_id: Token ID for mask token
max_new_tokens: Maximum number of new tokens to generate
steps: Number of diffusion steps
temperature: Sampling temperature (0 = greedy)
top_p: Nucleus sampling threshold
top_k: Top-k sampling threshold
alg: Sampling algorithm ("origin", "entropy", "maskgit_plus", "topk_margin")
alg_temp: Algorithm-specific temperature for confidence weighting
eps: Small epsilon for numerical stability
verbose: Print progress during generation
Returns:
Generated token sequence [batch_size, prompt_len + max_new_tokens]
"""
device = input_ids.device
batch_size = input_ids.shape[0]
prompt_len = input_ids.shape[1]
total_len = prompt_len + max_new_tokens
# Initialize sequence: prompt + mask tokens for generation
x = F.pad(input_ids, (0, max_new_tokens), value=mask_token_id)
# Create timesteps from 1 to eps
timesteps = torch.linspace(1, eps, steps + 1, device=device)
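    # At step i the mask level drops from t = timesteps[i] to s = timesteps[i + 1],
    # so on average a fraction (1 - s / t) of the remaining masks is revealed.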
for i in range(steps):
mask_index = (x == mask_token_id)
if not mask_index.any():
if verbose:
print(f"Step {i}: No more masked tokens, stopping early")
break
# Forward pass
outputs = model(x, return_logits_only=True)
if hasattr(outputs, 'logits'):
logits = outputs.logits
elif isinstance(outputs, tuple):
logits = outputs[0]
else:
logits = outputs
# Shift logits for next-token prediction
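        # The backbone keeps its next-token head, so the logit at position p
        # scores the token at position p + 1; duplicating position 0 and
        # dropping the last position realigns each logit with its own token.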
logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
# Get logits only for masked positions
mask_logits = logits[mask_index]
t = timesteps[i]
s = timesteps[i + 1]
if alg == "origin":
# Original diffusion: random unmasking with probability 1 - s/t
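            # Each still-masked position is revealed independently; the final
            # step forces p_transfer = 1 so no token stays masked.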
p_transfer = 1 - s / t if i < steps - 1 else 1
            x0 = torch.full_like(x[mask_index], mask_token_id)
transfer_index = torch.rand(*x0.shape, device=device) < p_transfer
_, x0[transfer_index] = sample_tokens(
mask_logits[transfer_index],
temperature=temperature,
top_p=top_p,
top_k=top_k
)
x[mask_index] = x0.clone()
else:
# Confidence-based unmasking algorithms
if alg == "maskgit_plus":
confidence, x0 = sample_tokens(
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
)
elif alg == "topk_margin":
# Margin confidence: difference between top-2 probabilities
probs = F.softmax(mask_logits / (temperature if temperature > 0 else 1), dim=-1)
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
_, x0 = sample_tokens(
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
)
elif alg == "entropy":
confidence, x0 = sample_tokens(
mask_logits, temperature=temperature, top_p=top_p, top_k=top_k,
neg_entropy=True
)
else:
raise ValueError(f"Unknown algorithm: {alg}")
# Determine how many tokens to unmask
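            # All sequences start with the same number of masks and unmask in
            # lockstep, so one per-item count is valid for the whole batch.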
            num_mask_token = mask_index.sum().item() // batch_size
            num_transfer = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
if num_transfer > 0:
# Create full confidence tensor
full_confidence = torch.full_like(x, -torch.inf, dtype=logits.dtype)
full_confidence[mask_index] = confidence
# Select top-k most confident positions to unmask
if alg_temp is None or alg_temp == 0:
_, transfer_index = torch.topk(full_confidence, num_transfer)
else:
# Stochastic selection with temperature
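                    # Lower alg_temp approaches greedy top-k selection; higher
                    # values flatten the choice over candidate positions.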
conf_probs = F.softmax(full_confidence / alg_temp, dim=-1)
transfer_index = torch.multinomial(conf_probs, num_samples=num_transfer)
# Create candidate tensor with predicted tokens
                x_candidate = torch.full_like(x, mask_token_id)
x_candidate[mask_index] = x0.clone()
# Update only selected positions
row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(transfer_index)
x[row_indices, transfer_index] = x_candidate[row_indices, transfer_index]
if verbose and (i + 1) % max(1, steps // 10) == 0:
remaining_masks = (x == mask_token_id).sum().item()
print(f"Step {i+1}/{steps}: {remaining_masks} masked tokens remaining")
return x
# ============================================================================
# Model Loading
# ============================================================================
def load_model_and_tokenizer(
checkpoint_path: str,
device: str = "auto",
torch_dtype: str = "bfloat16",
) -> Tuple[PreTrainedModel, AutoTokenizer, dict]:
"""
Load the diffusion model and tokenizer from checkpoint.
Args:
checkpoint_path: Path to the checkpoint directory
device: Device to load model on ("auto", "cuda", "cpu")
torch_dtype: Data type for model weights
Returns:
model: Loaded model
tokenizer: Loaded tokenizer
config: Model configuration dict
"""
import json
from transformers import Qwen2ForCausalLM, Qwen2Config
# Determine device
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
# Get dtype
dtype_map = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
dtype = dtype_map.get(torch_dtype, torch.bfloat16)
if device == "cpu" and dtype == torch.bfloat16:
print("Warning: bfloat16 on CPU may be slow, using float32")
dtype = torch.float32
print(f"Loading model from {checkpoint_path}...")
print(f" Device: {device}, Dtype: {dtype}")
# Load config
config_path = os.path.join(checkpoint_path, "config.json")
with open(config_path, "r") as f:
config_dict = json.load(f)
# Import and register the model class
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.diffusion_qwen import DiffusionQwen3Model, DiffusionQwen3Config
# Create diffusion config
diff_config = DiffusionQwen3Config(**config_dict)
# Create a Qwen2Config to initialize the base model architecture
qwen_config = Qwen2Config(
vocab_size=diff_config.vocab_size,
hidden_size=diff_config.hidden_size,
intermediate_size=diff_config.intermediate_size,
num_hidden_layers=diff_config.num_hidden_layers,
num_attention_heads=diff_config.num_attention_heads,
num_key_value_heads=diff_config.num_key_value_heads,
max_position_embeddings=diff_config.max_position_embeddings,
rms_norm_eps=diff_config.rms_norm_eps,
rope_theta=diff_config.rope_theta,
hidden_act=diff_config.hidden_act,
attention_dropout=diff_config.attention_dropout,
use_sliding_window=False,
pad_token_id=diff_config.pad_token_id,
bos_token_id=diff_config.bos_token_id,
eos_token_id=diff_config.eos_token_id,
)
# Create DiffusionQwen3Model with proper architecture
model = DiffusionQwen3Model(diff_config)
# Initialize the base Qwen2 model architecture
print(" Initializing model architecture...")
base_model = Qwen2ForCausalLM(qwen_config)
model._init_from_qwen(base_model)
del base_model # Free memory
    # Load state dict (checkpoint may be saved as .bin or .safetensors)
    weights_path = os.path.join(checkpoint_path, "pytorch_model.bin")
    if os.path.exists(weights_path):
        print(f" Loading weights from {weights_path}...")
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    else:
        # torch.load cannot read safetensors files; use the safetensors loader
        weights_path = os.path.join(checkpoint_path, "model.safetensors")
        print(f" Loading weights from {weights_path}...")
        from safetensors.torch import load_file
        state_dict = load_file(weights_path, device="cpu")
# Handle potential key mismatches
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
print(f" Warning: Missing keys ({len(missing)}): {missing[:3]}{'...' if len(missing) > 3 else ''}")
if unexpected:
print(f" Warning: Unexpected keys ({len(unexpected)}): {unexpected[:3]}{'...' if len(unexpected) > 3 else ''}")
# Move to device and set eval mode
model = model.to(device=device, dtype=dtype)
model.eval()
    # Disable causal masking so attention is fully bidirectional, as required
    # for diffusion-style denoising over the whole sequence
model._disable_causal_masking()
# Load tokenizer
print(" Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
# Ensure mask token is set
if tokenizer.mask_token_id is None:
tokenizer.mask_token_id = config_dict.get("mask_token_id", 151665)
print(f" Model loaded successfully!")
print(f" Vocab size: {diff_config.vocab_size}")
print(f" Hidden size: {diff_config.hidden_size}")
print(f" Num layers: {diff_config.num_hidden_layers}")
print(f" Mask token ID: {diff_config.mask_token_id}")
return model, tokenizer, config_dict
# ============================================================================
# Generation Wrapper
# ============================================================================
def generate(
model: PreTrainedModel,
tokenizer: AutoTokenizer,
prompt: str,
max_new_tokens: int = 128,
steps: int = 128,
temperature: float = 0.0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
alg: str = "entropy",
alg_temp: float = 0.1,
verbose: bool = False,
) -> str:
"""
Generate text from a prompt.
Args:
model: The diffusion language model
tokenizer: The tokenizer
prompt: Input prompt text
max_new_tokens: Maximum tokens to generate
steps: Diffusion steps
temperature: Sampling temperature
top_p: Nucleus sampling threshold
top_k: Top-k sampling threshold
alg: Sampling algorithm
alg_temp: Algorithm temperature
verbose: Print progress
Returns:
Generated text (prompt + completion)
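
    Example (a minimal sketch; assumes a loaded model and tokenizer):
        text = generate(model, tokenizer, "def fibonacci(n):", steps=64)
        print(text)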
"""
device = next(model.parameters()).device
# Tokenize prompt
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# Get mask token ID
mask_token_id = getattr(model.config, "mask_token_id", tokenizer.mask_token_id)
if mask_token_id is None:
mask_token_id = 151665 # Default from config
# Generate
output_ids = diffusion_generate(
model=model,
input_ids=input_ids,
mask_token_id=mask_token_id,
max_new_tokens=max_new_tokens,
steps=steps,
temperature=temperature,
top_p=top_p,
top_k=top_k,
alg=alg,
alg_temp=alg_temp,
verbose=verbose,
)
# Filter out mask and pad tokens
output_ids = output_ids[0] # Remove batch dimension
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 151643
output_ids = output_ids[output_ids != mask_token_id]
output_ids = output_ids[output_ids != pad_token_id]
# Decode
generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
return generated_text
def chat_generate(
model: PreTrainedModel,
tokenizer: AutoTokenizer,
messages: List[dict],
max_new_tokens: int = 256,
steps: int = 128,
temperature: float = 0.0,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
alg: str = "entropy",
alg_temp: float = 0.1,
verbose: bool = False,
) -> str:
"""
Generate chat response from conversation history.
Args:
model: The diffusion language model
tokenizer: The tokenizer
messages: List of message dicts with 'role' and 'content'
Other args: Same as generate()
Returns:
Assistant response text
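
    Example (a minimal sketch; assumes a loaded model and tokenizer):
        messages = [{"role": "user", "content": "Explain list comprehensions."}]
        response = chat_generate(model, tokenizer, messages, max_new_tokens=128)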
"""
device = next(model.parameters()).device
# Apply chat template
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Tokenize
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
prompt_len = input_ids.shape[1]
# Get mask token ID
mask_token_id = getattr(model.config, "mask_token_id", tokenizer.mask_token_id)
if mask_token_id is None:
mask_token_id = 151665
# Generate
output_ids = diffusion_generate(
model=model,
input_ids=input_ids,
mask_token_id=mask_token_id,
max_new_tokens=max_new_tokens,
steps=steps,
temperature=temperature,
top_p=top_p,
top_k=top_k,
alg=alg,
alg_temp=alg_temp,
verbose=verbose,
)
# Get only the generated tokens (after prompt)
generated_ids = output_ids[0, prompt_len:]
# Filter out mask and pad tokens
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 151643
generated_ids = generated_ids[generated_ids != mask_token_id]
generated_ids = generated_ids[generated_ids != pad_token_id]
# Decode
response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
return response
# ============================================================================
# Interactive Chat
# ============================================================================
def interactive_chat(
model: PreTrainedModel,
tokenizer: AutoTokenizer,
system_prompt: str = "You are a helpful assistant.",
**gen_kwargs,
):
"""Run interactive chat session."""
print("\n" + "=" * 60)
print("Interactive Chat Mode")
print("=" * 60)
print("Commands:")
print(" /exit or /quit - Exit the chat")
print(" /reset - Reset conversation history")
print(" /system <text> - Set new system prompt")
print("=" * 60 + "\n")
messages = [{"role": "system", "content": system_prompt}]
while True:
try:
user_input = input("\033[92mYou: \033[0m").strip()
except (EOFError, KeyboardInterrupt):
print("\nGoodbye!")
break
if not user_input:
continue
# Handle commands
if user_input.lower() in ["/exit", "/quit"]:
print("Goodbye!")
break
if user_input.lower() == "/reset":
messages = [{"role": "system", "content": system_prompt}]
print("\033[90mConversation reset.\033[0m")
continue
if user_input.lower().startswith("/system "):
system_prompt = user_input[8:].strip()
messages = [{"role": "system", "content": system_prompt}]
print("\033[90mSystem prompt updated.\033[0m")
continue
# Add user message
messages.append({"role": "user", "content": user_input})
# Generate response
print("\033[94mAssistant: \033[0m", end="", flush=True)
try:
response = chat_generate(
model=model,
tokenizer=tokenizer,
messages=messages,
**gen_kwargs,
)
print(response)
messages.append({"role": "assistant", "content": response})
except Exception as e:
print(f"\033[91mError: {e}\033[0m")
messages.pop() # Remove failed user message
# ============================================================================
# Main
# ============================================================================
def main():
parser = argparse.ArgumentParser(
description="Run inference with DiffusionQwen3 model",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Model arguments
parser.add_argument(
"--checkpoint", "-c",
type=str,
default="./outputs/pretrain/checkpoint-1000",
help="Path to model checkpoint directory",
)
parser.add_argument(
"--device",
type=str,
default="auto",
choices=["auto", "cuda", "cpu"],
help="Device to run on",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
help="Model data type",
)
# Generation mode
parser.add_argument(
"--mode", "-m",
type=str,
default="prompt",
choices=["prompt", "chat"],
help="Generation mode: 'prompt' for single completion, 'chat' for interactive",
)
parser.add_argument(
"--prompt", "-p",
type=str,
default=None,
help="Input prompt for single completion mode",
)
parser.add_argument(
"--system",
type=str,
default="You are a helpful assistant.",
help="System prompt for chat mode",
)
# Generation parameters
parser.add_argument("--max-tokens", type=int, default=256, help="Max tokens to generate")
parser.add_argument("--steps", type=int, default=128, help="Diffusion steps")
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
parser.add_argument("--top-p", type=float, default=None, help="Nucleus sampling threshold")
parser.add_argument("--top-k", type=int, default=None, help="Top-k sampling")
parser.add_argument(
"--alg",
type=str,
default="entropy",
choices=["origin", "entropy", "maskgit_plus", "topk_margin"],
help="Diffusion sampling algorithm",
)
parser.add_argument("--alg-temp", type=float, default=0.1, help="Algorithm temperature")
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
args = parser.parse_args()
# Load model
model, tokenizer, config = load_model_and_tokenizer(
args.checkpoint,
device=args.device,
torch_dtype=args.dtype,
)
# Generation kwargs
gen_kwargs = {
"max_new_tokens": args.max_tokens,
"steps": args.steps,
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": args.top_k,
"alg": args.alg,
"alg_temp": args.alg_temp,
"verbose": args.verbose,
}
if args.mode == "chat":
interactive_chat(model, tokenizer, system_prompt=args.system, **gen_kwargs)
else:
# Single prompt mode
if args.prompt is None:
# Default demo prompts
prompts = [
"def fibonacci(n):",
"Write a Python function to check if a number is prime:",
"# Calculate the factorial of a number\ndef factorial(n):",
]
print("\nNo prompt provided. Running demo with sample prompts...\n")
for prompt in prompts:
print("=" * 60)
print(f"Prompt: {prompt}")
print("-" * 60)
result = generate(model, tokenizer, prompt, **gen_kwargs)
print(f"Generated:\n{result}")
print("=" * 60 + "\n")
else:
result = generate(model, tokenizer, args.prompt, **gen_kwargs)
print("\n" + "=" * 60)
print("Generated:")
print("=" * 60)
print(result)
print("=" * 60)
if __name__ == "__main__":
main()