|
|
|
|
|
""" |
|
|
Inference script for DiffusionQwen3 model checkpoint. |
|
|
|
|
|
Usage: |
|
|
# Interactive chat mode |
|
|
python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 --mode chat |
|
|
|
|
|
# Single prompt completion |
|
|
python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 --prompt "def fibonacci(n):" |
|
|
|
|
|
# With custom generation parameters |
|
|
python inference.py --checkpoint ./outputs/pretrain/checkpoint-1000 \ |
|
|
--prompt "Write a hello world in Python" \ |
|
|
--steps 128 --temperature 0.0 --max-tokens 256 |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
import os |
|
|
from typing import Optional, Tuple, List |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.distributions as dists |
|
|
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def top_p_logits(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering: keep the smallest set of logits whose
    probability mass reaches ``top_p``; set all others to the dtype minimum.

    The token that first crosses the threshold is always retained, so at
    least one logit survives.
    """
    ranked_vals, ranked_idx = torch.sort(logits, descending=True)
    cum_mass = torch.cumsum(F.softmax(ranked_vals, dim=-1), dim=-1)
    drop = cum_mass > top_p
    # Shift right by one so the first token to cross the threshold is kept.
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = 0
    # Map the sorted-order removal flags back to the original token order.
    removal_mask = torch.zeros_like(logits, dtype=torch.bool)
    removal_mask.scatter_(-1, ranked_idx, drop)
    return logits.masked_fill(removal_mask, torch.finfo(logits.dtype).min)
|
|
|
|
|
|
|
|
def top_k_logits(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Top-k filtering: keep the ``top_k`` largest logits along the last
    dimension and push everything else down to the dtype minimum."""
    k = min(top_k, logits.size(-1))  # never ask topk for more than we have
    # Value of the k-th largest logit, broadcast for comparison.
    kth_value = torch.topk(logits, k)[0][..., -1, None]
    floor = torch.finfo(logits.dtype).min
    return logits.masked_fill(logits < kth_value, floor)
|
|
|
|
|
|
|
|
def sample_tokens(
    logits: torch.Tensor,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    neg_entropy: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample tokens from logits with optional temperature, top-p, and top-k.

    Args:
        logits: Unnormalized scores [..., vocab_size]
        temperature: Sampling temperature; 0 means greedy argmax
        top_p: Nucleus sampling threshold (applied when < 1.0)
        top_k: Top-k filtering threshold
        neg_entropy: If True, replace the confidence with the negative
            entropy of the (filtered) distribution, sum(p * log p)

    Returns:
        confidence: Confidence scores for sampled tokens (probability of the
            sampled token, or negative entropy when ``neg_entropy`` is set)
        x0: Sampled token IDs
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1.0:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)

    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        # BUGFIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit. Catch only the errors Categorical/sample can raise
        # on degenerate probabilities (e.g. all logits filtered to -inf),
        # falling back to a greedy argmax in that case.
        except (RuntimeError, ValueError):
            confidence, x0 = probs.max(dim=-1)
    else:
        # Greedy decoding: most probable token and its probability.
        confidence, x0 = probs.max(dim=-1)

    if neg_entropy:
        # Negative entropy as confidence: higher (closer to 0) means a more
        # peaked distribution. Epsilon guards log(0).
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def diffusion_generate(
    model: PreTrainedModel,
    input_ids: torch.LongTensor,
    mask_token_id: int,
    max_new_tokens: int = 128,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: Optional[float] = 0.1,
    eps: float = 1e-3,
    verbose: bool = False,
) -> torch.LongTensor:
    """
    Generate text using discrete diffusion.

    The sequence starts as [prompt | mask ... mask] and is iteratively
    denoised: at each step the model scores every position, and a subset of
    masked positions is committed to sampled tokens until none remain.

    Args:
        model: The diffusion language model
        input_ids: Input token IDs (prompt) [batch_size, prompt_len]
        mask_token_id: Token ID for mask token
        max_new_tokens: Maximum number of new tokens to generate
        steps: Number of diffusion steps
        temperature: Sampling temperature (0 = greedy)
        top_p: Nucleus sampling threshold
        top_k: Top-k sampling threshold
        alg: Sampling algorithm ("origin", "entropy", "maskgit_plus", "topk_margin")
        alg_temp: Algorithm-specific temperature for confidence weighting
        eps: Small epsilon for numerical stability
        verbose: Print progress during generation

    Returns:
        Generated token sequence [batch_size, prompt_len + max_new_tokens]

    Raises:
        ValueError: If ``alg`` is not one of the supported algorithms.
    """
    device = input_ids.device
    batch_size = input_ids.shape[0]
    prompt_len = input_ids.shape[1]
    total_len = prompt_len + max_new_tokens  # NOTE(review): unused below; kept as-is

    # Initial state: prompt followed by max_new_tokens mask tokens.
    x = F.pad(input_ids, (0, max_new_tokens), value=mask_token_id)

    # Noise schedule: steps+1 levels linearly spaced from 1 down to eps.
    timesteps = torch.linspace(1, eps, steps + 1, device=device)

    for i in range(steps):
        mask_index = (x == mask_token_id)

        # Every position decoded — nothing left to denoise.
        if not mask_index.any():
            if verbose:
                print(f"Step {i}: No more masked tokens, stopping early")
            break

        # Forward pass; tolerate ModelOutput, tuple, or raw-tensor returns.
        outputs = model(x, return_logits_only=True)
        if hasattr(outputs, 'logits'):
            logits = outputs.logits
        elif isinstance(outputs, tuple):
            logits = outputs[0]
        else:
            logits = outputs

        # Shift logits one position right (duplicating the first), so the
        # prediction used for position j is the one emitted at j-1.
        # NOTE(review): this presumes the model was trained on shifted
        # (next-token) targets, as in Dream-style diffusion LMs — confirm.
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

        # Only currently-masked positions participate in sampling.
        mask_logits = logits[mask_index]

        t = timesteps[i]      # current noise level
        s = timesteps[i + 1]  # next (lower) noise level

        if alg == "origin":
            # Original diffusion schedule: each masked position is unmasked
            # independently with probability 1 - s/t (all on the last step).
            p_transfer = 1 - s / t if i < steps - 1 else 1
            x0 = torch.zeros_like(x[mask_index], device=device, dtype=torch.long) + mask_token_id
            transfer_index = torch.rand(*x0.shape, device=device) < p_transfer
            _, x0[transfer_index] = sample_tokens(
                mask_logits[transfer_index],
                temperature=temperature,
                top_p=top_p,
                top_k=top_k
            )
            x[mask_index] = x0.clone()
        else:
            # Confidence-based algorithms: sample a candidate token for every
            # masked position, then commit only the most confident ones.
            if alg == "maskgit_plus":
                # Confidence = probability of the sampled token.
                confidence, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
                )
            elif alg == "topk_margin":
                # Confidence = margin between the top-2 probabilities.
                probs = F.softmax(mask_logits / (temperature if temperature > 0 else 1), dim=-1)
                sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
                confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
                _, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k
                )
            elif alg == "entropy":
                # Confidence = negative entropy (more peaked => higher).
                confidence, x0 = sample_tokens(
                    mask_logits, temperature=temperature, top_p=top_p, top_k=top_k,
                    neg_entropy=True
                )
            else:
                raise ValueError(f"Unknown algorithm: {alg}")

            # How many positions to commit this step, per row on average.
            # NOTE(review): dividing the total mask count by batch_size
            # assumes rows carry similar numbers of masks — verify for
            # ragged batches.
            num_mask_token = mask_index.sum() / batch_size
            num_transfer = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)

            if num_transfer > 0:
                # Scatter per-mask confidences back to full [batch, seq] shape;
                # unmasked positions get -inf so topk never selects them.
                full_confidence = torch.full_like(x, -torch.inf, dtype=logits.dtype)
                full_confidence[mask_index] = confidence

                if alg_temp is None or alg_temp == 0:
                    # Deterministic: highest-confidence positions per row.
                    _, transfer_index = torch.topk(full_confidence, num_transfer)
                else:
                    # Stochastic: sample positions ∝ softmax(confidence / alg_temp).
                    conf_probs = F.softmax(full_confidence / alg_temp, dim=-1)
                    transfer_index = torch.multinomial(conf_probs, num_samples=num_transfer)

                # Candidate sequence holding sampled tokens at masked slots.
                x_candidate = torch.zeros_like(x, dtype=torch.long) + mask_token_id
                x_candidate[mask_index] = x0.clone()

                # Commit the selected positions, row by row.
                row_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(transfer_index)
                x[row_indices, transfer_index] = x_candidate[row_indices, transfer_index]

        # Progress report roughly every 10% of steps.
        if verbose and (i + 1) % max(1, steps // 10) == 0:
            remaining_masks = (x == mask_token_id).sum().item()
            print(f"Step {i+1}/{steps}: {remaining_masks} masked tokens remaining")

    return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model_and_tokenizer(
    checkpoint_path: str,
    device: str = "auto",
    torch_dtype: str = "bfloat16",
) -> Tuple[PreTrainedModel, AutoTokenizer, dict]:
    """
    Load the diffusion model and tokenizer from checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint directory
        device: Device to load model on ("auto", "cuda", "cpu")
        torch_dtype: Data type for model weights ("float32", "float16", "bfloat16")

    Returns:
        model: Loaded model, in eval mode with causal masking disabled
        tokenizer: Loaded tokenizer
        config: Model configuration dict

    Raises:
        FileNotFoundError: If no weights file exists in the checkpoint directory.
        RuntimeError: If a safetensors checkpoint is found but safetensors
            is not installed.
    """
    import json
    from transformers import Qwen2ForCausalLM, Qwen2Config

    # Resolve "auto" to a concrete device.
    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Map the dtype string; unknown names fall back to bfloat16.
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map.get(torch_dtype, torch.bfloat16)
    if device == "cpu" and dtype == torch.bfloat16:
        print("Warning: bfloat16 on CPU may be slow, using float32")
        dtype = torch.float32

    print(f"Loading model from {checkpoint_path}...")
    print(f"  Device: {device}, Dtype: {dtype}")

    # Read the checkpoint's config.json.
    config_path = os.path.join(checkpoint_path, "config.json")
    with open(config_path, "r") as f:
        config_dict = json.load(f)

    # Project-local model definition lives next to this script.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from models.diffusion_qwen import DiffusionQwen3Model, DiffusionQwen3Config

    diff_config = DiffusionQwen3Config(**config_dict)

    # Mirror the diffusion config into a Qwen2Config so a base Qwen2 model
    # can be built and used to initialize the diffusion model's weights.
    qwen_config = Qwen2Config(
        vocab_size=diff_config.vocab_size,
        hidden_size=diff_config.hidden_size,
        intermediate_size=diff_config.intermediate_size,
        num_hidden_layers=diff_config.num_hidden_layers,
        num_attention_heads=diff_config.num_attention_heads,
        num_key_value_heads=diff_config.num_key_value_heads,
        max_position_embeddings=diff_config.max_position_embeddings,
        rms_norm_eps=diff_config.rms_norm_eps,
        rope_theta=diff_config.rope_theta,
        hidden_act=diff_config.hidden_act,
        attention_dropout=diff_config.attention_dropout,
        use_sliding_window=False,
        pad_token_id=diff_config.pad_token_id,
        bos_token_id=diff_config.bos_token_id,
        eos_token_id=diff_config.eos_token_id,
    )

    model = DiffusionQwen3Model(diff_config)

    print("  Initializing model architecture...")
    base_model = Qwen2ForCausalLM(qwen_config)
    model._init_from_qwen(base_model)
    del base_model  # free the temporary base model before loading weights

    # Locate weights: prefer pytorch_model.bin, fall back to safetensors.
    weights_path = os.path.join(checkpoint_path, "pytorch_model.bin")
    if not os.path.exists(weights_path):
        weights_path = os.path.join(checkpoint_path, "model.safetensors")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(
            f"No model weights found in {checkpoint_path} "
            "(looked for pytorch_model.bin and model.safetensors)"
        )

    print(f"  Loading weights from {weights_path}...")
    if weights_path.endswith(".safetensors"):
        # BUGFIX: the original code passed the safetensors file to
        # torch.load(), which cannot parse the safetensors format and
        # would crash. Use the safetensors loader for that path.
        try:
            from safetensors.torch import load_file
        except ImportError as err:
            raise RuntimeError(
                "Checkpoint is in safetensors format but the safetensors "
                "package is not installed"
            ) from err
        state_dict = load_file(weights_path, device="cpu")
    else:
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)

    # strict=False: tolerate key mismatches, but surface them as warnings.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"  Warning: Missing keys ({len(missing)}): {missing[:3]}{'...' if len(missing) > 3 else ''}")
    if unexpected:
        print(f"  Warning: Unexpected keys ({len(unexpected)}): {unexpected[:3]}{'...' if len(unexpected) > 3 else ''}")

    model = model.to(device=device, dtype=dtype)
    model.eval()

    # NOTE(review): presumably switches the attention to bidirectional for
    # diffusion decoding — method is project-local, confirm its contract.
    model._disable_causal_masking()

    print("  Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)

    # Make sure the tokenizer knows the diffusion mask token id.
    if tokenizer.mask_token_id is None:
        tokenizer.mask_token_id = config_dict.get("mask_token_id", 151665)

    print(f"  Model loaded successfully!")
    print(f"  Vocab size: {diff_config.vocab_size}")
    print(f"  Hidden size: {diff_config.hidden_size}")
    print(f"  Num layers: {diff_config.num_hidden_layers}")
    print(f"  Mask token ID: {diff_config.mask_token_id}")

    return model, tokenizer, config_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: float = 0.1,
    verbose: bool = False,
) -> str:
    """
    Generate text from a prompt.

    Args:
        model: The diffusion language model
        tokenizer: The tokenizer
        prompt: Input prompt text
        max_new_tokens: Maximum tokens to generate
        steps: Diffusion steps
        temperature: Sampling temperature (0 = greedy)
        top_p: Nucleus sampling threshold
        top_k: Top-k sampling threshold
        alg: Sampling algorithm (see diffusion_generate)
        alg_temp: Algorithm temperature
        verbose: Print progress

    Returns:
        Generated text (prompt + completion), with mask and pad tokens removed
    """
    device = next(model.parameters()).device

    # Tokenize the prompt and move it to the model's device.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Resolve the mask token id: model config first, then tokenizer, then
    # the project's hard-coded default.
    mask_token_id = getattr(model.config, "mask_token_id", tokenizer.mask_token_id)
    if mask_token_id is None:
        mask_token_id = 151665  # NOTE(review): DiffusionQwen3 default — confirm vs. checkpoint config

    output_ids = diffusion_generate(
        model=model,
        input_ids=input_ids,
        mask_token_id=mask_token_id,
        max_new_tokens=max_new_tokens,
        steps=steps,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        alg=alg,
        alg_temp=alg_temp,
        verbose=verbose,
    )

    # Single-prompt call: take the first (only) batch row.
    output_ids = output_ids[0]
    # BUGFIX: `tokenizer.pad_token_id or 151643` silently replaces a
    # legitimate pad id of 0 (0 is falsy); compare against None explicitly.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 151643
    # Drop any residual mask tokens and padding before decoding.
    output_ids = output_ids[output_ids != mask_token_id]
    output_ids = output_ids[output_ids != pad_token_id]

    generated_text = tokenizer.decode(output_ids, skip_special_tokens=True)

    return generated_text
|
|
|
|
|
|
|
|
def chat_generate(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    messages: List[dict],
    max_new_tokens: int = 256,
    steps: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    alg: str = "entropy",
    alg_temp: float = 0.1,
    verbose: bool = False,
) -> str:
    """
    Generate chat response from conversation history.

    Args:
        model: The diffusion language model
        tokenizer: The tokenizer
        messages: List of message dicts with 'role' and 'content'
        Other args: Same as generate()

    Returns:
        Assistant response text (only the newly generated turn, stripped)
    """
    device = next(model.parameters()).device

    # Render the conversation with the tokenizer's chat template, ending
    # with the assistant-turn prefix so the model completes the reply.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]

    # Resolve the mask token id: model config first, then tokenizer, then
    # the project's hard-coded default.
    mask_token_id = getattr(model.config, "mask_token_id", tokenizer.mask_token_id)
    if mask_token_id is None:
        mask_token_id = 151665  # NOTE(review): DiffusionQwen3 default — confirm vs. checkpoint config

    output_ids = diffusion_generate(
        model=model,
        input_ids=input_ids,
        mask_token_id=mask_token_id,
        max_new_tokens=max_new_tokens,
        steps=steps,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        alg=alg,
        alg_temp=alg_temp,
        verbose=verbose,
    )

    # Keep only the newly generated portion (after the rendered prompt).
    generated_ids = output_ids[0, prompt_len:]

    # BUGFIX: `tokenizer.pad_token_id or 151643` silently replaces a
    # legitimate pad id of 0 (0 is falsy); compare against None explicitly.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 151643
    # Drop residual mask tokens and padding before decoding.
    generated_ids = generated_ids[generated_ids != mask_token_id]
    generated_ids = generated_ids[generated_ids != pad_token_id]

    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def interactive_chat(
    model: PreTrainedModel,
    tokenizer: AutoTokenizer,
    system_prompt: str = "You are a helpful assistant.",
    **gen_kwargs,
):
    """Run an interactive chat session on stdin/stdout.

    Recognizes the /exit, /quit, /reset and /system commands; any other
    input becomes a user turn. Remaining keyword arguments are forwarded
    to chat_generate().
    """
    divider = "=" * 60
    print("\n" + divider)
    print("Interactive Chat Mode")
    print(divider)
    print("Commands:")
    print("  /exit or /quit - Exit the chat")
    print("  /reset - Reset conversation history")
    print("  /system <text> - Set new system prompt")
    print(divider + "\n")

    messages = [{"role": "system", "content": system_prompt}]

    while True:
        try:
            line = input("\033[92mYou: \033[0m").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session gracefully.
            print("\nGoodbye!")
            break

        if not line:
            continue

        lowered = line.lower()

        if lowered in ["/exit", "/quit"]:
            print("Goodbye!")
            break

        if lowered == "/reset":
            messages = [{"role": "system", "content": system_prompt}]
            print("\033[90mConversation reset.\033[0m")
            continue

        if lowered.startswith("/system "):
            # New system prompt also resets the history.
            system_prompt = line[8:].strip()
            messages = [{"role": "system", "content": system_prompt}]
            print("\033[90mSystem prompt updated.\033[0m")
            continue

        messages.append({"role": "user", "content": line})

        print("\033[94mAssistant: \033[0m", end="", flush=True)
        try:
            reply = chat_generate(
                model=model,
                tokenizer=tokenizer,
                messages=messages,
                **gen_kwargs,
            )
        except Exception as exc:
            # Generation failed: report the error and drop the user turn so
            # the history stays consistent for the next attempt.
            print(f"\033[91mError: {exc}\033[0m")
            messages.pop()
        else:
            print(reply)
            messages.append({"role": "assistant", "content": reply})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, load the checkpoint, then dispatch
    to interactive chat or single-prompt completion."""
    arg_parser = argparse.ArgumentParser(
        description="Run inference with DiffusionQwen3 model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Model / environment options.
    arg_parser.add_argument(
        "--checkpoint", "-c", type=str,
        default="./outputs/pretrain/checkpoint-1000",
        help="Path to model checkpoint directory",
    )
    arg_parser.add_argument(
        "--device", type=str, default="auto",
        choices=["auto", "cuda", "cpu"], help="Device to run on",
    )
    arg_parser.add_argument(
        "--dtype", type=str, default="bfloat16",
        choices=["float32", "float16", "bfloat16"], help="Model data type",
    )

    # Mode and prompt options.
    arg_parser.add_argument(
        "--mode", "-m", type=str, default="prompt",
        choices=["prompt", "chat"],
        help="Generation mode: 'prompt' for single completion, 'chat' for interactive",
    )
    arg_parser.add_argument(
        "--prompt", "-p", type=str, default=None,
        help="Input prompt for single completion mode",
    )
    arg_parser.add_argument(
        "--system", type=str, default="You are a helpful assistant.",
        help="System prompt for chat mode",
    )

    # Sampling / diffusion options.
    arg_parser.add_argument("--max-tokens", type=int, default=256, help="Max tokens to generate")
    arg_parser.add_argument("--steps", type=int, default=128, help="Diffusion steps")
    arg_parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    arg_parser.add_argument("--top-p", type=float, default=None, help="Nucleus sampling threshold")
    arg_parser.add_argument("--top-k", type=int, default=None, help="Top-k sampling")
    arg_parser.add_argument(
        "--alg", type=str, default="entropy",
        choices=["origin", "entropy", "maskgit_plus", "topk_margin"],
        help="Diffusion sampling algorithm",
    )
    arg_parser.add_argument("--alg-temp", type=float, default=0.1, help="Algorithm temperature")
    arg_parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")

    args = arg_parser.parse_args()

    model, tokenizer, config = load_model_and_tokenizer(
        args.checkpoint,
        device=args.device,
        torch_dtype=args.dtype,
    )

    # Shared generation parameters for both modes.
    gen_kwargs = dict(
        max_new_tokens=args.max_tokens,
        steps=args.steps,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        alg=args.alg,
        alg_temp=args.alg_temp,
        verbose=args.verbose,
    )

    if args.mode == "chat":
        interactive_chat(model, tokenizer, system_prompt=args.system, **gen_kwargs)
        return

    banner = "=" * 60
    if args.prompt is not None:
        # Single explicit prompt.
        result = generate(model, tokenizer, args.prompt, **gen_kwargs)
        print("\n" + banner)
        print("Generated:")
        print(banner)
        print(result)
        print(banner)
    else:
        # No prompt given: run a small built-in demo set.
        demo_prompts = [
            "def fibonacci(n):",
            "Write a Python function to check if a number is prime:",
            "# Calculate the factorial of a number\ndef factorial(n):",
        ]
        print("\nNo prompt provided. Running demo with sample prompts...\n")
        for demo in demo_prompts:
            print(banner)
            print(f"Prompt: {demo}")
            print("-" * 60)
            result = generate(model, tokenizer, demo, **gen_kwargs)
            print(f"Generated:\n{result}")
            print(banner + "\n")
|
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|
|