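# Source: diffusionGPT / infer-chat.py
# Uploaded by thejagstudio ("Upload 10 files"), commit 486838c (verified).
# File size: 26.1 kB.
# (The lines above were page chrome -- "raw", "history", "blame" links -- from
# the web file viewer; they are kept as comments so the module parses.)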
import os
import sys
import time
import argparse
import importlib.util
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
# Number of terminal lines recorded by the most recent call to
# visualize_diffusion_state_local(). The next call with clear=True moves the
# cursor up by this many lines and erases them before drawing the new frame.
# A value of 0 means no frame has been drawn yet.
_visualize_last_lines: int = 0
def try_import_infer_base(base_path: str):
    """Dynamically import ``infer-base.py`` as the module ``infer_base``.

    The module is registered in ``sys.modules`` *before* its code runs, as the
    importlib "import a source file directly" recipe requires. Without that
    registration, anything that resolves a class through its ``__module__``
    (``pickle``/``torch.load``, ``dataclasses`` with string annotations) cannot
    find the module.

    Args:
        base_path: Filesystem path of the Python source file to load.

    Returns:
        The loaded module, or ``None`` if the file does not exist or fails to
        import (a warning is printed in the latter case).
    """
    if not os.path.exists(base_path):
        return None

    module_name = "infer_base"
    # Remember any earlier registration so a failed import can be rolled back.
    previous = sys.modules.get(module_name)
    try:
        spec = importlib.util.spec_from_file_location(module_name, base_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"no loader available for {base_path}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module
    except Exception as e:
        # Never leave a half-initialised module behind in sys.modules.
        if previous is not None:
            sys.modules[module_name] = previous
        else:
            sys.modules.pop(module_name, None)
        print(f"Warning: failed to import {base_path}: {e}")
        return None
def load_finetuned_model(model_path: str, device: str = 'cuda', model_cls=None):
    """Load a saved fine-tuned model checkpoint for inference.

    The checkpoint is expected to be a mapping with a ``'config'`` entry and a
    ``'model_state'`` state dict; optional ``'step'`` and ``'best_val_loss'``
    entries are reported when present.

    Args:
        model_path: Path to the checkpoint file.
        device: Device the weights are mapped to and the model is moved onto.
        model_cls: Class used to build the model as ``model_cls(config)``.
            When omitted, a module-level ``DiffusionLLM`` is used if one has
            been published (``main()`` does this when run as a script).

    Returns:
        ``(model, config)`` with the model in eval mode on ``device``.

    Raises:
        RuntimeError: If no model class is supplied and ``DiffusionLLM`` is not
            defined at module level.
    """
    if model_cls is None:
        # Previously this was a bare global reference that raised NameError
        # whenever the module was imported instead of executed as a script.
        model_cls = globals().get('DiffusionLLM')
    if model_cls is None:
        raise RuntimeError(
            "DiffusionLLM is not available: pass model_cls=... or load it from "
            "infer-base.py before calling load_finetuned_model()."
        )

    print(f"Loading model from {model_path}...")
    # SECURITY: weights_only=False unpickles arbitrary Python objects (needed
    # here because the checkpoint stores a config object). Only load
    # checkpoints from sources you trust.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = checkpoint['config']

    model = model_cls(config)

    # Up-cast floating-point weights to fp32; leave integer/bool buffers alone.
    state_dict = {
        name: value.float()
        if torch.is_tensor(value) and torch.is_floating_point(value)
        else value
        for name, value in checkpoint['model_state'].items()
    }
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"✓ Loaded model: {num_params:.1f}M parameters")

    # Print training info if available
    if 'step' in checkpoint:
        print(f" Trained for {checkpoint['step']} steps")
    if 'best_val_loss' in checkpoint:
        print(f" Best validation loss: {checkpoint['best_val_loss']:.4f}")
    return model, config
@torch.no_grad()
def generate_block_diffusion(
    model,
    tokenizer,
    prompt: str,
    steps: int = 32,
    block_size: int = 32,
    max_new_tokens: int = 128,
    device: str = 'cuda',
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3,
    verbose: bool = True,
    visualize_fn=None,
    parallel_blocks: int = 1,
):
    """Generate text using block diffusion with sampling controls.

    Generation appends ``max_new_tokens // block_size`` blocks of
    ``block_size`` tokens to the prompt; a trailing remainder smaller than a
    block is not generated. If ``max_new_tokens`` is positive but smaller than
    ``block_size``, the block is shrunk to ``max_new_tokens`` so the request
    still produces output instead of silently generating nothing.

    Args:
        model: Model exposing ``config`` (directly, via ``.module`` or via
            ``._orig_mod``) and returning ``(logits, _)`` when called.
        tokenizer: Tokenizer providing ``encode(..., return_tensors="pt")`` and
            ``decode``.
        prompt: Text the generated blocks are appended to.
        steps: Denoising steps per block; must be positive.
        block_size: Tokens per block; must be positive.
        max_new_tokens: Upper bound on the number of generated tokens.
        device: Device the prompt ids and working tensors are placed on.
        temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size:
            Sampling controls forwarded to the block generators.
        verbose: Print a one-line summary before generating.
        visualize_fn: Optional callback invoked after every denoising step as
            ``visualize_fn(tokenizer, context_ids, blocks, is_masked, config,
            clear=step_idx > 0)``.
        parallel_blocks: Number of blocks denoised together per round.

    Returns:
        The decoded sequence (prompt plus generated tokens), with special
        tokens kept.

    Raises:
        ValueError: If ``steps`` or ``block_size`` is not positive.
        RuntimeError: If the model config cannot be located.
    """
    # Validate first: steps <= 0 would otherwise return blocks made only of
    # mask tokens, and block_size == 0 would fail with a bare ZeroDivisionError.
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if steps <= 0:
        raise ValueError(f"steps must be positive, got {steps}")

    model.eval()
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Unwrap torch.compile (_orig_mod) and DataParallel-style (.module) wrappers.
    if hasattr(model, '_orig_mod'):
        config = model._orig_mod.config
    elif hasattr(model, 'module'):
        config = model.module.config
    else:
        config = getattr(model, 'config', None)
    if config is None:
        raise RuntimeError("Could not determine model config")

    # A request for fewer tokens than one block must not be a silent no-op.
    if 0 < max_new_tokens < block_size:
        block_size = max_new_tokens
    num_blocks = max(0, max_new_tokens // block_size)
    parallel_blocks = max(1, min(parallel_blocks, num_blocks))

    if verbose:
        print(f"Generating {num_blocks} blocks of {block_size} tokens ({max_new_tokens} max_new_tokens)\n")

    context_ids = prompt_ids
    # Shared with the block generators, which add every token they commit.
    all_generated_tokens = set(prompt_ids[0].tolist())
    blocks_generated = 0

    while blocks_generated < num_blocks:
        current_parallel = min(parallel_blocks, num_blocks - blocks_generated)
        if current_parallel > 1:
            new_blocks = _generate_parallel_blocks(
                model, tokenizer, context_ids, config, device,
                current_parallel, block_size, steps, temperature,
                top_k, top_p, repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize_fn
            )
            context_ids = torch.cat([context_ids, *new_blocks], dim=1)
            blocks_generated += len(new_blocks)
        else:
            mask_block, _ = _generate_single_block(
                model, tokenizer, context_ids, config, device,
                block_size, steps, temperature, top_k, top_p,
                repetition_penalty, no_repeat_ngram_size,
                all_generated_tokens, visualize_fn
            )
            context_ids = torch.cat([context_ids, mask_block], dim=1)
            blocks_generated += 1

    generated_ids = context_ids[0].tolist()
    return tokenizer.decode(generated_ids, skip_special_tokens=False)
def _apply_sampling_controls(
block_logits, context_ids, mask_block, is_masked,
repetition_penalty, temperature, top_k, top_p,
no_repeat_ngram_size, block_token_history
):
"""Apply repetition penalty, temperature, top-k, top-p, and n-gram blocking."""
if repetition_penalty != 1.0:
seen_tokens = set(context_ids[0].tolist())
for i in range(mask_block.shape[1]):
if not is_masked[0, i]:
seen_tokens.add(mask_block[0, i].item())
for token_id in seen_tokens:
if token_id < block_logits.shape[-1]:
avg = block_logits[0, :, token_id].mean()
if avg > 0:
block_logits[:, :, token_id] /= repetition_penalty
else:
block_logits[:, :, token_id] *= repetition_penalty
block_logits = block_logits / temperature
if top_k > 0:
k = min(top_k, block_logits.size(-1))
top_k_logits, top_k_indices = torch.topk(block_logits, k, dim=-1)
filtered = torch.full_like(block_logits, float('-inf'))
filtered.scatter_(-1, top_k_indices, top_k_logits)
block_logits = filtered
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
block_logits[indices_to_remove] = float('-inf')
if no_repeat_ngram_size > 0 and len(block_token_history) >= no_repeat_ngram_size - 1:
recent_ngram = tuple(block_token_history[-(no_repeat_ngram_size - 1):])
full_history = context_ids[0].tolist() + block_token_history
for i in range(len(full_history) - no_repeat_ngram_size + 1):
if tuple(full_history[i:i + no_repeat_ngram_size - 1]) == recent_ngram:
blocked_token = full_history[i + no_repeat_ngram_size - 1]
if blocked_token < block_logits.shape[-1]:
block_logits[:, :, blocked_token] = float('-inf')
# Safety: reset if all logits are -inf
all_inf_mask = torch.isinf(block_logits).all(dim=-1)
if all_inf_mask.any():
block_logits[all_inf_mask] = 0.0
return block_logits
def _generate_single_block(
    model, tokenizer, context_ids, config, device,
    block_size, steps, temperature, top_k, top_p,
    repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize_fn=None
):
    """Generate a single block using diffusion.

    The block starts fully masked. Every step samples a token for each
    position and commits the ``max(1, block_size // steps)`` masked positions
    whose sampled token has the highest probability; the final step commits
    everything that is still masked.

    Args:
        model: Callable returning ``(logits, _)`` for
            ``model(input_ids, attention_mask=...)``.
        tokenizer: Passed through to the visualization callback.
        context_ids: ``(1, context_len)`` token ids preceding the block.
        config: Model config; ``mask_token_id`` is read from it.
        device: Device for the working tensors.
        block_size: Number of tokens in the block.
        steps: Maximum number of denoising steps.
        temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size:
            Sampling controls forwarded to ``_apply_sampling_controls``.
        all_generated_tokens: Set updated in place with every committed token.
        visualize_fn: Optional per-step callback; a truthy non-callable value
            selects ``visualize_diffusion_state_local``.

    Returns:
        ``(mask_block, block_token_history)``: the ``(1, block_size)`` block
        and the committed tokens in commit order.
    """
    mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
    block_token_history = []
    # max(steps, 1) only protects the division; with steps <= 0 the loop
    # below does not run at all.
    tokens_per_step = max(1, block_size // max(steps, 1))

    for step_idx in range(steps):
        # When steps > block_size every position is committed early; further
        # forward passes could not change the result, so stop.
        if not is_masked.any():
            break

        full_input = torch.cat([context_ids, mask_block], dim=1)
        attention_mask = torch.ones_like(full_input, dtype=torch.float32)
        logits, _ = model(full_input, attention_mask=attention_mask)
        block_logits = logits[:, -block_size:, :]
        block_logits = _apply_sampling_controls(
            block_logits, context_ids, mask_block, is_masked,
            repetition_penalty, temperature, top_k, top_p,
            no_repeat_ngram_size, block_token_history
        )

        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        # Keep every row a valid distribution for torch.multinomial.
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(1, block_size)
        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        remaining = int(is_masked.sum().item())
        is_last_step = step_idx == steps - 1
        num_to_unmask = remaining if is_last_step else min(tokens_per_step, remaining)

        if num_to_unmask > 0:
            # Already-committed positions get -1 so topk never selects them.
            masked_confidence = confidence.masked_fill(~is_masked, -1.0)
            _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
            chosen = sampled_tokens[0, top_indices]
            mask_block[0, top_indices] = chosen
            is_masked[0, top_indices] = False
            chosen_tokens = chosen.tolist()
            block_token_history.extend(chosen_tokens)
            all_generated_tokens.update(chosen_tokens)

        if callable(visualize_fn):
            try:
                visualize_fn(tokenizer, context_ids, mask_block, is_masked, config, clear=(step_idx > 0))
            except Exception:
                # Visualization is best-effort; a rendering failure must not
                # abort generation.
                pass
        elif visualize_fn:
            visualize_diffusion_state_local(tokenizer, context_ids, mask_block, is_masked, config, clear=(step_idx > 0))

    return mask_block, block_token_history
def _generate_parallel_blocks(
    model, tokenizer, context_ids, config, device,
    num_parallel, block_size, steps, temperature,
    top_k, top_p, repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize_fn=None
):
    """Generate multiple blocks in parallel using batched computation.

    Batch row ``b`` holds ``context + block_0 + ... + block_b`` followed by
    padding, so each block is conditioned on the context and on the current
    (possibly still partially masked) state of the blocks before it:

    - Block 0: context + [block0]
    - Block 1: context + [block0] + [block1]
    - Block 2: context + [block0] + [block1] + [block2]

    Args:
        model: Callable returning ``(logits, _)`` for
            ``model(input_ids, attention_mask=...)``.
        tokenizer: Passed through to the visualization callback.
        context_ids: ``(1, context_len)`` token ids preceding the blocks.
        config: Model config; ``mask_token_id`` is read, ``pad_token_id`` is
            used for padding when present (0 otherwise).
        device: Device for the working tensors.
        num_parallel: Number of blocks denoised together.
        block_size: Tokens per block.
        steps: Maximum number of denoising steps.
        temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size:
            Sampling controls forwarded to ``_apply_sampling_controls``.
        all_generated_tokens: Set updated in place with every committed token.
        visualize_fn: Optional per-step callback; a truthy non-callable value
            selects ``visualize_diffusion_state_local``.

    Returns:
        A list of ``num_parallel`` tensors of shape ``(1, block_size)``.
    """
    batch_size = num_parallel
    context_len = context_ids.shape[1]
    max_seq_len = context_len + (num_parallel * block_size)

    mask_blocks = torch.full((batch_size, block_size), config.mask_token_id, device=device)
    is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
    block_token_histories = [[] for _ in range(batch_size)]

    # getattr: a config without pad_token_id must not crash generation.
    pad_token = getattr(config, 'pad_token_id', None)
    if pad_token is None:
        pad_token = 0

    # Everything below is identical for every step, so build it once.
    # Row b is real (unpadded) for the context plus blocks 0..b.
    valid_lens = [context_len + (b + 1) * block_size for b in range(batch_size)]
    attention_mask = torch.zeros(batch_size, max_seq_len, device=device)
    for b, valid_len in enumerate(valid_lens):
        attention_mask[b, :valid_len] = 1.0
    base_input = torch.full((batch_size, max_seq_len), pad_token, dtype=torch.long, device=device)
    base_input[:, :context_len] = context_ids[0]

    tokens_per_step = max(1, block_size // max(steps, 1))

    for step_idx in range(steps):
        # All blocks share the same schedule; once nothing is masked, further
        # forward passes cannot change the result.
        if not is_masked.any():
            break

        # Refresh the block regions with the current denoising state.
        full_input = base_input.clone()
        for b, valid_len in enumerate(valid_lens):
            full_input[b, context_len:valid_len] = mask_blocks[:b + 1].reshape(-1)

        # Single forward pass for all blocks.
        logits, _ = model(full_input, attention_mask=attention_mask)

        # Block b's own positions inside row b.
        block_logits = torch.stack(
            [
                logits[b, context_len + b * block_size:context_len + (b + 1) * block_size, :]
                for b in range(batch_size)
            ],
            dim=0,
        )  # (batch, block_size, vocab)

        for b in range(batch_size):
            # Earlier blocks count as context for the repetition controls.
            if b > 0:
                extended_context = torch.cat(
                    [context_ids, mask_blocks[:b].reshape(1, -1)], dim=1
                )
            else:
                extended_context = context_ids
            block_logits[b:b + 1] = _apply_sampling_controls(
                block_logits[b:b + 1],
                extended_context,
                mask_blocks[b:b + 1],
                is_masked[b:b + 1],
                repetition_penalty, temperature, top_k, top_p,
                no_repeat_ngram_size, block_token_histories[b]
            )

        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
        # Keep every row a valid distribution for torch.multinomial.
        probs = probs.clamp(min=1e-10)
        probs = probs / probs.sum(dim=-1, keepdim=True)

        sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
        sampled_tokens = sampled_tokens.view(batch_size, block_size)
        confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)

        # The final step commits everything that is still masked.
        tokens_to_unmask = block_size if step_idx == steps - 1 else tokens_per_step

        for b in range(batch_size):
            remaining = int(is_masked[b].sum().item())
            if remaining == 0:
                continue
            # Already-committed positions get -1 so topk never selects them.
            masked_confidence = confidence[b].masked_fill(~is_masked[b], -1.0)
            num_to_unmask = min(int(tokens_to_unmask), remaining)
            _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
            chosen = sampled_tokens[b, top_indices]
            mask_blocks[b, top_indices] = chosen
            is_masked[b, top_indices] = False
            chosen_tokens = chosen.tolist()
            block_token_histories[b].extend(chosen_tokens)
            all_generated_tokens.update(chosen_tokens)

        if visualize_fn:
            block_list = [mask_blocks[b:b + 1] for b in range(batch_size)]
            is_masked_list = [is_masked[b:b + 1] for b in range(batch_size)]
            if callable(visualize_fn):
                try:
                    visualize_fn(tokenizer, context_ids, block_list, is_masked_list, config, clear=(step_idx > 0))
                except Exception:
                    # Visualization is best-effort; a rendering failure must
                    # not abort generation.
                    pass
            else:
                visualize_diffusion_state_local(tokenizer, context_ids, block_list, is_masked_list, config, clear=(step_idx > 0))

    # Return list of generated blocks
    return [mask_blocks[b:b + 1] for b in range(batch_size)]
def chat(model, tokenizer, instruction: str, parallel_blocks: int = 1, **kwargs):
    """Run one chat turn and pull the assistant text out of the raw output.

    The instruction is wrapped with ``format_instruct_prompt`` and generated
    on the device of the model's first parameter. Every span between
    ``<|im_start|>assistant`` and the next ``<|im_end|>`` (or the end of the
    text) is collected; non-empty spans are joined with a blank line. If no
    assistant tag is present, the tags are simply stripped from the output.

    Returns:
        ``(generated, resp)``: the full decoded text and the extracted reply.
    """
    target_device = next(model.parameters()).device
    full_prompt = format_instruct_prompt(instruction)
    generated = generate_block_diffusion(
        model,
        tokenizer,
        prompt=full_prompt,
        device=target_device,
        parallel_blocks=parallel_blocks,
        **kwargs
    )

    opener = "<|im_start|>assistant"
    closer = "<|im_end|>"

    segments = []
    cursor = generated.find(opener)
    while cursor != -1:
        body_start = cursor + len(opener)
        body_end = generated.find(closer, body_start)
        if body_end == -1:
            # Unterminated turn: take everything up to the end of the text.
            segments.append(generated[body_start:].strip())
            break
        segments.append(generated[body_start:body_end].strip())
        cursor = generated.find(opener, body_end + len(closer))

    if segments:
        resp = "\n\n".join(filter(None, segments))
    else:
        # No assistant tag at all: fall back to removing the markers.
        resp = generated.replace(opener, "").replace(closer, "").strip()
    return generated, resp
def format_instruct_prompt(instruction: str) -> str:
    """Wrap *instruction* in the ChatML-style prompt used for generation.

    The result has a fixed system turn, a user turn containing the
    instruction, and an opened assistant turn for the model to complete.
    """
    turns = [
        "<|im_start|>system\nAnswer this question truthfully<|im_end|>\n",
        f"<|im_start|>user\n{instruction}\n<|im_end|>\n",
        "<|im_start|>assistant\n",
    ]
    return "".join(turns)
def _clear_whole_screen():
    """Best-effort full-screen clear, used when a targeted erase is not possible."""
    import os
    import sys

    try:
        sys.stdout.write('\x1b[2J\x1b[H')
        sys.stdout.flush()
    except Exception:
        try:
            os.system('cls' if os.name == 'nt' else 'clear')
        except Exception:
            sys.stdout.write('\r\033[K')
            sys.stdout.flush()


def _count_terminal_rows(text):
    """Estimate how many terminal rows *text* occupies, including wrapping."""
    import re
    import shutil

    columns = max(1, shutil.get_terminal_size(fallback=(80, 24)).columns)
    # ANSI colour/cursor sequences take no screen space.
    visible = re.sub(r'\x1b\[[0-9;]*[A-Za-z]', '', text)
    lines = visible.split('\n')
    if lines and lines[-1] == '':
        # The text ended with a newline; the empty tail is not a drawn row.
        lines.pop()
    # Ceiling division; an empty line still occupies one row.
    return sum(max(1, -(-len(line) // columns)) for line in lines)


def visualize_diffusion_state_local(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
    """Draw the current denoising state of one or more blocks in the terminal.

    The decoded context is printed followed by every block: still-masked
    positions appear as gray ``██`` placeholders and committed tokens are
    printed in the block's colour. With several blocks a legend line is
    printed first.

    Args:
        tokenizer: Object with ``decode``; used for the context and per token.
        context_ids: Token ids of the context; row 0 is decoded.
        mask_blocks: One block tensor or a list of them; row 0 of each is drawn.
        is_masked_list: Boolean mask(s) matching ``mask_blocks``.
        config: Unused; kept so the signature matches the callback contract.
        clear: If true, erase the previously drawn frame (or the whole screen
            when no frame has been recorded) before drawing.
        block_colors: Optional list of ANSI colour prefixes, cycled per block.

    Side effects:
        Writes to ``sys.stdout`` and stores the number of rows drawn in the
        module-level ``_visualize_last_lines``.
    """
    import sys

    global _visualize_last_lines

    # Default colors for different blocks (green, cyan, yellow, magenta)
    DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']
    MASK_COLOR = '\033[90m'  # Gray for masked tokens
    RESET = '\033[0m'

    # Accept a single block as well as a list of blocks.
    if not isinstance(mask_blocks, list):
        mask_blocks = [mask_blocks]
        is_masked_list = [is_masked_list]
    if block_colors is None:
        block_colors = DEFAULT_COLORS

    # Flatten newlines in the context so it stays on one logical line.
    try:
        context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
    except Exception:
        context_text = str(context_ids[0].tolist())

    rendered_blocks = []
    for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
        color = block_colors[block_idx % len(block_colors)]
        pieces = []
        for token_id, hidden in zip(mask_block[0].tolist(), is_masked[0].tolist()):
            if hidden:
                # Masked positions share one gray placeholder in all blocks.
                pieces.append(f'{MASK_COLOR}██{RESET}')
                continue
            try:
                token_text = tokenizer.decode([token_id], skip_special_tokens=False)
            except Exception:
                token_text = str(int(token_id))
            pieces.append(f'{color}{token_text}{RESET}')
        rendered_blocks.append(''.join(pieces))
    # Blocks are concatenated directly; colour alone distinguishes them.
    blocks_combined = ''.join(rendered_blocks)

    if clear and _visualize_last_lines > 0:
        rows = _visualize_last_lines
        try:
            # Move to the top of the previous frame, erase each of its rows,
            # then return to the top so the new frame overwrites it.
            sys.stdout.write(
                f'\x1b[{rows}A' + '\x1b[2K\x1b[1B' * rows + f'\x1b[{rows}A'
            )
            sys.stdout.flush()
        except Exception:
            _clear_whole_screen()
    elif clear:
        # No previous frame recorded: start from a clean screen.
        _clear_whole_screen()

    legend = ''
    if len(mask_blocks) > 1:
        labels = [
            f'{block_colors[i % len(block_colors)]}Block {i+1}{RESET}'
            for i in range(len(mask_blocks))
        ]
        legend = f"Generating: {' | '.join(labels)}\n\n"

    # Trailing newline leaves the cursor on the row just below the frame.
    out_text = f"{context_text}{blocks_combined}\n"
    frame = legend + out_text
    try:
        sys.stdout.write(frame)
        sys.stdout.flush()
    except Exception:
        print(frame, flush=True)

    # Record exactly the rows drawn (legend, wrapped text, newlines inside
    # tokens). The previous fixed "+ 1" over-counted single-block frames, so
    # each redraw crept one row upward and erased earlier output.
    _visualize_last_lines = _count_terminal_rows(frame)
def main():
    """Interactive chat loop: load the model and tokenizer, then answer prompts.

    ``DiffusionLLM`` (and ``ModelConfig`` when present) are loaded from the
    sibling ``infer-base.py``. An empty input line, EOF or Ctrl-C ends the
    session.

    Raises:
        RuntimeError: If ``infer-base.py`` cannot be imported or does not
            define ``DiffusionLLM``.
    """
    base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "infer-base.py")
    base_mod = try_import_infer_base(base_path)
    if base_mod is None or not hasattr(base_mod, 'DiffusionLLM'):
        raise RuntimeError("DiffusionLLM not found in infer-base.py")
    DiffusionLLM = base_mod.DiffusionLLM

    # load_finetuned_model() looks the class up at module level; publish it
    # here explicitly so this also works when main() is called from an
    # imported module rather than only when the file runs as __main__.
    globals()['DiffusionLLM'] = DiffusionLLM

    # Workaround for torch.load pickling: checkpoints written by a training
    # script reference these classes under __main__.
    main_mod = sys.modules.get('__main__')
    if main_mod is not None:
        try:
            if hasattr(base_mod, 'ModelConfig'):
                setattr(main_mod, 'ModelConfig', base_mod.ModelConfig)
            setattr(main_mod, 'DiffusionLLM', DiffusionLLM)
        except (AttributeError, TypeError):
            pass

    default_model_path = "./checkpoints/model_fp32.pt"
    best_model_path = "checkpoints/best_model.pt"

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default=None,
        help=(
            "Path to model checkpoint (default: "
            f"{best_model_path} if it exists, else {default_model_path})"
        ),
    )
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-0.5B", help="Tokenizer model id or path")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--visualize", action="store_true", default=False, help="Enable visualization during generation")
    parser.add_argument("--steps", type=int, default=64)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--parallel_blocks", type=int, default=1, help="Number of blocks to generate in parallel")
    # Sampling controls; the defaults are the values that used to be hard-coded.
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.2)
    parser.add_argument("--no_repeat_ngram_size", type=int, default=3)
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    if tokenizer.pad_token is None:
        # set pad token if not present
        tokenizer.pad_token = tokenizer.eos_token

    # Load model. An explicit --model always wins; previously an existing
    # best_model.pt silently overrode whatever the user passed.
    if args.model is not None:
        model, config = load_finetuned_model(args.model, device)
    elif os.path.exists(best_model_path):
        print("Loading best model...")
        model, config = load_finetuned_model(best_model_path, device)
    else:
        model, config = load_finetuned_model(default_model_path, device)

    # Use the local visualization implementation for consistency
    visualize_fn = visualize_diffusion_state_local if args.visualize else None

    print("Ready. Type a message and press Enter (empty line to quit).\n")
    while True:
        try:
            user_input = input("User: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting.")
            break
        if user_input == "":
            print("Goodbye.")
            break

        raw_output, response = chat(
            model,
            tokenizer,
            user_input,
            steps=args.steps,
            block_size=args.block_size,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            verbose=False,
            visualize_fn=visualize_fn,
            parallel_blocks=args.parallel_blocks,
        )
        print("\nRaw Output:\n")
        print(raw_output)
        print("\nAssistant:\n")
        print(response)
        print("\n" + ("=" * 60) + "\n")


if __name__ == "__main__":
    main()