Spaces:
Sleeping
Sleeping
import argparse
import importlib.util
import os
import re
import shutil
import sys
import time
import unicodedata

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
# Number of lines the most recent visualization frame occupied; the next frame
# moves the cursor up by this amount so it can overwrite the previous one.
_visualize_last_lines: int = 0
def try_import_infer_base(base_path: str, module_name: str = "infer_base"):
    """Import the Python source file at *base_path* as a module and return it.

    The new module is registered in ``sys.modules`` under *module_name*
    before its body runs, following the importlib "importing a source file
    directly" recipe. Code that resolves the module through ``sys.modules``
    while it is still loading fails without that registration. For example,
    ``dataclasses`` does this when processing string annotations under
    ``from __future__ import annotations``.

    Args:
        base_path: Filesystem path of the source file (e.g. ``infer-base.py``).
        module_name: Name the module is registered and executed under.

    Returns:
        The executed module. Returns ``None`` if the file does not exist, or
        if it could not be loaded; a warning is printed in the latter case.
    """
    if not os.path.exists(base_path):
        return None

    previous = sys.modules.get(module_name)
    try:
        spec = importlib.util.spec_from_file_location(module_name, base_path)
        if spec is None or spec.loader is None:
            raise ImportError("no module loader available for this file")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    except Exception as e:
        # Do not leave a half-initialised module registered under this name.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        print(f"Warning: failed to import {base_path}: {e}")
        return None
    return module
def load_finetuned_model(model_path: str, device: str = 'cuda', model_cls=None):
    """Load a saved fine-tuned checkpoint and return ``(model, config)`` in eval mode.

    Args:
        model_path: Checkpoint written with ``torch.save`` holding at least
            ``'config'`` and ``'model_state'``. The optional keys ``'step'`` and
            ``'best_val_loss'`` are printed when present.
        device: Passed to ``torch.load(map_location=...)`` and to ``model.to``.
        model_cls: Class constructed as ``model_cls(config)``. When omitted,
            the module-level name ``DiffusionLLM`` is used (``main()`` binds it
            after importing infer-base.py), which keeps existing callers working.

    Returns:
        ``(model, config)`` with the model moved to *device* and in eval mode.

    Raises:
        RuntimeError: if *model_cls* is omitted and no module-level
            ``DiffusionLLM`` exists.
    """
    if model_cls is None:
        model_cls = globals().get('DiffusionLLM')
        if model_cls is None:
            raise RuntimeError(
                "DiffusionLLM is not available; pass model_cls=... or call main(), "
                "which loads it from infer-base.py"
            )

    print(f"Loading model from {model_path}...")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects
    # (needed for the pickled config). Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    config = checkpoint['config']

    model = model_cls(config)

    state_dict = dict(checkpoint['model_state'])
    # Strip wrapper prefixes that appear when a torch.compile'd (`_orig_mod.`)
    # or DataParallel/DDP (`module.`) model was saved. A prefix is removed only
    # when *every* key carries it, so a real submodule named `module` is untouched.
    stripped = True
    while stripped and state_dict:
        stripped = False
        for prefix in ('_orig_mod.', 'module.'):
            if all(key.startswith(prefix) for key in state_dict):
                state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
                stripped = True

    # Upcast floating-point tensors (e.g. fp16/bf16 checkpoints) to fp32.
    # Integer/bool buffers and non-tensor extra state are kept as-is, so large
    # int64 values are not rounded through float32.
    state_dict = {
        key: value.float() if torch.is_tensor(value) and value.is_floating_point() else value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    num_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"✓ Loaded model: {num_params:.1f}M parameters")
    # Print training info if available
    if 'step' in checkpoint:
        print(f" Trained for {checkpoint['step']} steps")
    if 'best_val_loss' in checkpoint:
        print(f" Best validation loss: {checkpoint['best_val_loss']:.4f}")
    return model, config
@torch.no_grad()
def generate_block_diffusion(
    model,
    tokenizer,
    prompt: str,
    steps: int = 32,
    block_size: int = 32,
    max_new_tokens: int = 128,
    device: str = 'cuda',
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    repetition_penalty: float = 1.2,
    no_repeat_ngram_size: int = 3,
    verbose: bool = True,
    visualize_fn=None,
    parallel_blocks: int = 1,
):
    """Generate a continuation of *prompt* with block-wise masked diffusion.

    New tokens are produced in blocks of ``block_size``. Each block starts
    fully masked and is revealed over ``steps`` denoising steps. A completed
    block is appended to the context used for the next one. When
    ``parallel_blocks > 1``, up to that many full blocks are denoised together
    with one batched forward pass per step. If ``max_new_tokens`` is not a
    multiple of ``block_size``, the remainder is generated as one final,
    shorter block. Runs under ``torch.no_grad()``; no autograd graph is built.

    Args:
        model: Diffusion model (optionally wrapped in ``.module`` /
            ``._orig_mod``) whose config provides ``mask_token_id``.
        tokenizer: Tokenizer used to encode *prompt* and decode the result.
        prompt: Text to continue.
        steps: Denoising steps per block.
        block_size: Tokens per block; must be positive.
        max_new_tokens: Total number of tokens to generate.
        device: Device the prompt ids are moved to.
        temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size:
            Sampling controls applied at every step.
        verbose: Print a short generation plan first.
        visualize_fn: Optional per-step display hook. A callable is invoked
            as ``visualize_fn(tokenizer, context_ids, mask_block, is_masked,
            config, clear=step_idx > 0)``; in parallel mode ``mask_block`` and
            ``is_masked`` are lists with one entry per block. Exceptions it
            raises are ignored. Any other truthy value selects the built-in
            ``visualize_diffusion_state_local``.
        parallel_blocks: Number of full blocks to denoise together.

    Returns:
        The decoded text, prompt included, with special tokens kept.

    Raises:
        ValueError: if ``block_size`` is not positive.
        RuntimeError: if the model's config cannot be found.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    model.eval()
    # Encode prompt
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Look through DataParallel/DDP (`.module`) and torch.compile (`._orig_mod`) wrappers.
    config = model.module.config if hasattr(model, 'module') else getattr(model, 'config', None)
    if hasattr(model, '_orig_mod'):
        config = model._orig_mod.config
    if config is None:
        raise RuntimeError("Could not determine model config")

    num_blocks, remainder = divmod(max(0, max_new_tokens), block_size)
    parallel_blocks = min(parallel_blocks, num_blocks)
    if verbose:
        if remainder:
            print(
                f"Generating {num_blocks} blocks of {block_size} tokens + 1 block of "
                f"{remainder} tokens ({max_new_tokens} max_new_tokens)\n"
            )
        else:
            print(f"Generating {num_blocks} blocks of {block_size} tokens ({max_new_tokens} max_new_tokens)\n")

    sampling_args = (
        temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size,
    )
    context_ids = prompt_ids
    all_generated_tokens = set(prompt_ids[0].tolist())

    blocks_generated = 0
    while blocks_generated < num_blocks:
        current_parallel = min(parallel_blocks, num_blocks - blocks_generated)
        if current_parallel > 1:
            new_blocks = _generate_parallel_blocks(
                model, tokenizer, context_ids, config, device,
                current_parallel, block_size, steps, *sampling_args,
                all_generated_tokens, visualize_fn
            )
            context_ids = torch.cat([context_ids, *new_blocks], dim=1)
            blocks_generated += len(new_blocks)
        else:
            mask_block, _ = _generate_single_block(
                model, tokenizer, context_ids, config, device,
                block_size, steps, *sampling_args,
                all_generated_tokens, visualize_fn
            )
            context_ids = torch.cat([context_ids, mask_block], dim=1)
            blocks_generated += 1

    if remainder:
        # Final partial block so exactly max_new_tokens tokens are produced.
        mask_block, _ = _generate_single_block(
            model, tokenizer, context_ids, config, device,
            remainder, steps, *sampling_args,
            all_generated_tokens, visualize_fn
        )
        context_ids = torch.cat([context_ids, mask_block], dim=1)

    return tokenizer.decode(context_ids[0].tolist(), skip_special_tokens=False)
def _apply_sampling_controls(
    block_logits, context_ids, mask_block, is_masked,
    repetition_penalty, temperature, top_k, top_p,
    no_repeat_ngram_size, block_token_history
):
    """Return sampling-ready logits after temperature, repetition penalty, top-k, top-p and n-gram blocking.

    The input tensor is not modified; a new float32 tensor is returned.

    Args:
        block_logits: 3-D logits indexed ``[batch, position, vocab]``. Callers
            in this file pass a batch of one.
        context_ids: ``(1, L)`` ids preceding the block.
        mask_block: ``(1, block_len)`` current ids of the block.
        is_masked: ``(1, block_len)`` bool tensor, True where still masked.
        repetition_penalty: Values != 1.0 penalise every id that appears in
            the context or at a revealed block position. Positive logits are
            divided by it and negative ones multiplied, element by element.
        temperature: Logit divisor. Values <= 0 are clamped to 1e-5
            (near-greedy) instead of producing inf/NaN or inverting the
            distribution.
        top_k: Keep only the k largest logits per position; 0 disables.
        top_p: Nucleus threshold; 1.0 disables.
        no_repeat_ngram_size: n for n-gram blocking; 0 disables. With n == 1,
            every id already in the context or history is banned.
        block_token_history: Ids revealed so far in this block, in the order
            they were revealed.

    Returns:
        Logits of the same shape. Rows left entirely infinite are reset to
        zeros (a uniform distribution) so softmax stays defined.
    """
    vocab_size = block_logits.shape[-1]

    def _columns(token_ids):
        """Sorted LongTensor of the in-vocabulary ids in *token_ids*, or None if there are none."""
        valid = sorted({t for t in token_ids if 0 <= t < vocab_size})
        if not valid:
            return None
        return torch.tensor(valid, dtype=torch.long, device=block_logits.device)

    # Scaling by a positive temperature first never changes a logit's sign,
    # so the repetition penalty below gives the same result as applying it first.
    logits = block_logits.float() / max(float(temperature), 1e-5)

    if repetition_penalty != 1.0:
        seen = set(context_ids[0].tolist())
        seen.update(mask_block[0][~is_masked[0]].tolist())
        cols = _columns(seen)
        if cols is not None:
            picked = logits[..., cols]
            # Decide per element (not from a block-wide mean), so a negative
            # logit is always pushed further down, never raised.
            logits[..., cols] = torch.where(
                picked > 0, picked / repetition_penalty, picked * repetition_penalty
            )

    if 0 < top_k < vocab_size:
        top_vals, top_idx = torch.topk(logits, top_k, dim=-1)
        logits = torch.full_like(logits, float('-inf')).scatter_(-1, top_idx, top_vals)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove_sorted = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is still kept.
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False
        remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
        logits = logits.masked_fill(remove, float('-inf'))

    n = no_repeat_ngram_size
    if n > 0 and len(block_token_history) >= n - 1:
        prefix_len = n - 1
        history = list(block_token_history)
        # Slice from an explicit start so prefix_len == 0 gives an empty prefix
        # (history[-0:] would be the whole list).
        recent = tuple(history[len(history) - prefix_len:])
        # NOTE(review): the generators in this file append ids in reveal
        # (confidence) order, not positional order, so this n-gram test is a
        # heuristic for diffusion decoding.
        full_history = context_ids[0].tolist() + history
        banned = {
            full_history[i + prefix_len]
            for i in range(len(full_history) - prefix_len)
            if tuple(full_history[i:i + prefix_len]) == recent
        }
        cols = _columns(banned)
        if cols is not None:
            logits[..., cols] = float('-inf')

    # Safety: a row with every logit infinite would make softmax NaN.
    dead_rows = torch.isinf(logits).all(dim=-1)
    if dead_rows.any():
        logits[dead_rows] = 0.0
    return logits
@torch.no_grad()
def _generate_single_block(
    model, tokenizer, context_ids, config, device,
    block_size, steps, temperature, top_k, top_p,
    repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize_fn=None
):
    """Denoise one block of ``block_size`` mask tokens appended after *context_ids*.

    Each step does the following:

    - runs the model on ``context + block``;
    - applies the sampling controls and samples a candidate id for every
      position;
    - reveals the still-masked positions whose sampled id has the highest
      probability.

    ``max(1, block_size // steps)`` positions are revealed per step, and all
    remaining ones on the final step. The loop stops as soon as nothing is
    masked, so ``steps > block_size`` does not waste forward passes.
    ``steps < 1`` is treated as 1, so the block is always fully revealed. The
    mask token itself is never sampled.

    Args:
        context_ids: ``(1, L)`` ids preceding the block.
        config: Provides ``mask_token_id``.
        all_generated_tokens: Set updated in place with every revealed id.
        visualize_fn: Callable hook (exceptions ignored), or any other truthy
            value to use ``visualize_diffusion_state_local``.
        Remaining arguments are passed through to the sampling controls.

    Returns:
        ``(mask_block, block_token_history)``: the ``(1, block_size)`` tensor
        of generated ids, and the ids in the order they were revealed
        (highest confidence first within a step).
    """
    mask_id = config.mask_token_id
    mask_block = torch.full((1, block_size), mask_id, device=device)
    is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
    block_token_history = []

    steps = max(1, int(steps))
    per_step = max(1, block_size // steps)
    # The input length is the same at every step, so one all-ones mask suffices.
    attention_mask = torch.ones(
        1, context_ids.shape[1] + block_size, dtype=torch.float32, device=context_ids.device
    )

    for step_idx in range(steps):
        if not is_masked.any():
            break  # Block fully revealed; further forward passes would change nothing.

        full_input = torch.cat([context_ids, mask_block], dim=1)
        logits, _ = model(full_input, attention_mask=attention_mask)
        block_logits = _apply_sampling_controls(
            logits[:, -block_size:, :], context_ids, mask_block, is_masked,
            repetition_penalty, temperature, top_k, top_p,
            no_repeat_ngram_size, block_token_history
        )

        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0).clamp(min=1e-10)
        if 0 <= mask_id < probs.shape[-1]:
            # Never let a position be "revealed" as the mask token itself.
            probs[..., mask_id] = 0.0
        probs = probs / probs.sum(dim=-1, keepdim=True)

        sampled = torch.multinomial(probs.reshape(-1, probs.size(-1)), num_samples=1).view(1, block_size)
        confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)

        remaining = int(is_masked.sum().item())
        quota = remaining if step_idx == steps - 1 else min(per_step, remaining)
        # Revealed positions get -1 (below any probability) so topk picks masked ones only.
        ranked = confidence.masked_fill(~is_masked, -1.0)
        _, reveal_idx = torch.topk(ranked.view(-1), quota)
        revealed = sampled[0, reveal_idx]
        mask_block[0, reveal_idx] = revealed
        is_masked[0, reveal_idx] = False
        revealed_ids = revealed.tolist()
        block_token_history.extend(revealed_ids)
        all_generated_tokens.update(revealed_ids)

        if callable(visualize_fn):
            try:
                visualize_fn(tokenizer, context_ids, mask_block, is_masked, config, clear=(step_idx > 0))
            except Exception:
                pass  # Visualization is best-effort and must never abort generation.
        elif visualize_fn:
            visualize_diffusion_state_local(tokenizer, context_ids, mask_block, is_masked, config, clear=(step_idx > 0))

    return mask_block, block_token_history
@torch.no_grad()
def _generate_parallel_blocks(
    model, tokenizer, context_ids, config, device,
    num_parallel, block_size, steps, temperature,
    top_k, top_p, repetition_penalty, no_repeat_ngram_size,
    all_generated_tokens, visualize_fn=None
):
    """Denoise ``num_parallel`` consecutive blocks together, one batched forward pass per step.

    Batch row ``b`` holds the context followed by blocks ``0..b``, each in its
    current, partially masked state. Rows are right-padded to the length of
    the last row, and padded positions get attention-mask 0. Row ``b``
    supplies the logits for block ``b``. Each block is therefore conditioned
    on the blocks before it, but not on those after it:

    - row 0: context + [block0]
    - row 1: context + [block0] + [block1]
    - row 2: context + [block0] + [block1] + [block2]

    Per step, every block reveals its ``max(1, block_size // steps)`` most
    confident masked positions (all remaining ones on the final step). The
    loop ends early once every block is fully revealed. ``steps < 1`` is
    treated as 1, and the mask token itself is never sampled.

    Args:
        num_parallel: Number of blocks to generate.
        config: Provides ``mask_token_id`` and, optionally, ``pad_token_id``
            (falls back to 0 when missing or None).
        all_generated_tokens: Set updated in place with every revealed id.
        visualize_fn: Callable hook (exceptions ignored), or any other truthy
            value to use ``visualize_diffusion_state_local``. It receives lists
            of per-block tensors.
        Remaining arguments match ``_generate_single_block``.

    Returns:
        List of ``num_parallel`` tensors of shape ``(1, block_size)``, in order.
    """
    batch_size = num_parallel
    context_len = context_ids.shape[1]
    mask_id = config.mask_token_id

    mask_blocks = torch.full((batch_size, block_size), mask_id, device=device)
    is_masked = torch.ones(batch_size, block_size, dtype=torch.bool, device=device)
    block_token_histories = [[] for _ in range(batch_size)]

    steps = max(1, int(steps))
    per_step = max(1, block_size // steps)

    # The batch layout is identical at every step, so build the padding and
    # attention masks once. Row b is real up to context_len + (b + 1) * block_size.
    max_seq_len = context_len + batch_size * block_size
    row_lengths = context_len + block_size * torch.arange(1, batch_size + 1, device=device)
    is_padding = torch.arange(max_seq_len, device=device).unsqueeze(0) >= row_lengths.unsqueeze(1)
    attention_mask = (~is_padding).to(torch.float32)
    pad_id = getattr(config, 'pad_token_id', None)
    if pad_id is None:
        pad_id = 0
    block_starts = [context_len + b * block_size for b in range(batch_size)]

    for step_idx in range(steps):
        if not is_masked.any():
            break  # Every block is fully revealed; skip the remaining forward passes.

        # Row b = context + blocks 0..b; later positions are overwritten with padding.
        full_input = torch.cat([context_ids[0], mask_blocks.reshape(-1)]).repeat(batch_size, 1)
        full_input.masked_fill_(is_padding, pad_id)
        logits, _ = model(full_input, attention_mask=attention_mask)

        filtered = []
        for b, start in enumerate(block_starts):
            # The penalty/n-gram context for block b also includes the blocks
            # before it (which may still hold mask ids).
            if b == 0:
                extended_context = context_ids
            else:
                extended_context = torch.cat([context_ids, mask_blocks[:b].reshape(1, -1)], dim=1)
            filtered.append(_apply_sampling_controls(
                logits[b:b + 1, start:start + block_size, :],
                extended_context,
                mask_blocks[b:b + 1],
                is_masked[b:b + 1],
                repetition_penalty, temperature, top_k, top_p,
                no_repeat_ngram_size, block_token_histories[b]
            ))
        block_logits = torch.cat(filtered, dim=0)  # (batch, block_size, vocab)

        probs = F.softmax(block_logits, dim=-1)
        probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0).clamp(min=1e-10)
        if 0 <= mask_id < probs.shape[-1]:
            # Never let a position be "revealed" as the mask token itself.
            probs[..., mask_id] = 0.0
        probs = probs / probs.sum(dim=-1, keepdim=True)

        sampled = torch.multinomial(probs.reshape(-1, probs.size(-1)), num_samples=1).view(batch_size, block_size)
        confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
        # Revealed positions get -1 (below any probability) so topk picks masked ones only.
        ranked = confidence.masked_fill(~is_masked, -1.0)

        remaining = is_masked.sum(dim=1).tolist()
        last_step = step_idx == steps - 1
        for b in range(batch_size):
            if remaining[b] == 0:
                continue
            quota = remaining[b] if last_step else min(per_step, remaining[b])
            _, reveal_idx = torch.topk(ranked[b], quota)
            revealed = sampled[b, reveal_idx]
            mask_blocks[b, reveal_idx] = revealed
            is_masked[b, reveal_idx] = False
            revealed_ids = revealed.tolist()
            block_token_histories[b].extend(revealed_ids)
            all_generated_tokens.update(revealed_ids)

        if callable(visualize_fn):
            try:
                block_list = [mask_blocks[b:b + 1] for b in range(batch_size)]
                is_masked_list = [is_masked[b:b + 1] for b in range(batch_size)]
                visualize_fn(tokenizer, context_ids, block_list, is_masked_list, config, clear=(step_idx > 0))
            except Exception:
                pass  # Visualization is best-effort and must never abort generation.
        elif visualize_fn:
            block_list = [mask_blocks[b:b + 1] for b in range(batch_size)]
            is_masked_list = [is_masked[b:b + 1] for b in range(batch_size)]
            visualize_diffusion_state_local(tokenizer, context_ids, block_list, is_masked_list, config, clear=(step_idx > 0))

    return [mask_blocks[b:b + 1] for b in range(batch_size)]
def chat(model, tokenizer, instruction: str, parallel_blocks: int = 1, **kwargs):
    """Answer *instruction* with the diffusion model and extract the assistant reply.

    The instruction is wrapped with ``format_instruct_prompt`` and passed to
    ``generate_block_diffusion``. Every ``<|im_start|>assistant … <|im_end|>``
    span found in the output is then collected. Scanning starts at the
    prompt's own trailing assistant tag, so tags that appear inside the
    user's instruction are not mistaken for model output.

    Args:
        model: Model passed to ``generate_block_diffusion``.
        tokenizer: Tokenizer passed to ``generate_block_diffusion``.
        instruction: User message.
        parallel_blocks: Forwarded to ``generate_block_diffusion``.
        **kwargs: Other ``generate_block_diffusion`` options. ``device``
            defaults to the device of the model's first parameter.

    Returns:
        ``(generated, response)``: the full decoded text (prompt included) and
        the assistant text. Multiple assistant turns are joined by blank
        lines. If no assistant tag is found, the whole output is returned
        with the ChatML tags removed.
    """
    if kwargs.get('device') is None:
        kwargs['device'] = next(model.parameters()).device

    prompt = format_instruct_prompt(instruction)
    generated = generate_block_diffusion(
        model,
        tokenizer,
        prompt=prompt,
        parallel_blocks=parallel_blocks,
        **kwargs
    )

    # Extract all assistant responses using ChatML tags
    start_tag = "<|im_start|>assistant"
    end_tag = "<|im_end|>"
    # If decoding reproduced the prompt verbatim, start at its final assistant
    # tag. Otherwise fall back to scanning the whole text.
    pos = prompt.rfind(start_tag) if generated.startswith(prompt) else 0

    resp_parts = []
    while True:
        start_idx = generated.find(start_tag, pos)
        if start_idx == -1:
            break
        start_idx += len(start_tag)
        end_idx = generated.find(end_tag, start_idx)
        if end_idx == -1:
            resp_parts.append(generated[start_idx:].strip())
            break
        resp_parts.append(generated[start_idx:end_idx].strip())
        pos = end_idx + len(end_tag)

    if resp_parts:
        resp = "\n\n".join(p for p in resp_parts if p)
    else:
        # Fallback if no assistant tags found
        resp = generated.replace("<|im_start|>assistant", "").replace("<|im_end|>", "").strip()
    return generated, resp
def format_instruct_prompt(instruction: str) -> str:
    """Wrap *instruction* in the ChatML-like template used to prompt the model.

    The result is a fixed system turn, then a user turn holding the instruction
    (followed by a newline before its end tag), then an opened assistant turn
    for the model to complete.
    """
    turns = [
        "<|im_start|>system\nAnswer this question truthfully<|im_end|>\n",
        f"<|im_start|>user\n{instruction}\n<|im_end|>\n",
        "<|im_start|>assistant\n",
    ]
    return "".join(turns)
def _terminal_rows(text: str, columns: int) -> int:
    """Return roughly how many terminal rows *text* occupies when wrapped at *columns*.

    ANSI escape sequences take no space. East Asian wide/fullwidth characters
    count as two columns and combining marks as zero; every other character
    (tabs and control characters included) counts as one. A trailing newline
    ends the last row rather than starting a new one.
    """
    plain = re.sub(r'\x1b\[[0-9;?]*[A-Za-z]', '', text)
    columns = max(1, int(columns))
    lines = plain.split('\n')
    if lines and lines[-1] == '':
        lines.pop()
    rows = 0
    for line in lines:
        width = sum(
            0 if unicodedata.combining(ch)
            else 2 if unicodedata.east_asian_width(ch) in ('W', 'F')
            else 1
            for ch in line
        )
        rows += max(1, -(-width // columns))  # ceil(width / columns); empty line = 1 row
    return rows


def visualize_diffusion_state_local(tokenizer, context_ids, mask_blocks, is_masked_list, config, clear=True, block_colors=None):
    """Render the current denoising state in the terminal, redrawing the previous frame in place.

    The frame shows the decoded context (newlines shown as spaces) followed by
    every block. Revealed tokens appear in the block's colour, and masked
    positions as a gray ``██``. With more than one block, a colour legend
    line and a blank line are printed first.

    Args:
        tokenizer: Object with ``decode(ids, skip_special_tokens=...)``.
        context_ids: ``(1, L)`` ids shown before the blocks.
        mask_blocks: A ``(1, block_len)`` id tensor, or a list of them.
        is_masked_list: Matching bool tensor(s); True marks a masked position.
        config: Unused; kept so the signature matches visualize hooks.
        clear: If True, erase the frame drawn by the previous call first, or
            clear the screen if no previous frame is recorded.
        block_colors: ANSI colour prefixes cycled over blocks. Defaults to
            green, cyan, yellow and magenta.

    The number of terminal rows the frame occupies is stored in the
    module-level ``_visualize_last_lines``. It is computed with the terminal
    width, so long, wrapped context lines are fully erased by the next frame.
    """
    global _visualize_last_lines

    DEFAULT_COLORS = ['\033[92m', '\033[96m', '\033[93m', '\033[95m']  # green, cyan, yellow, magenta
    MASK_COLOR = '\033[90m'  # Gray, used for masked positions of every block
    RESET = '\033[0m'

    # Single-block callers pass bare tensors; normalise to lists.
    if not isinstance(mask_blocks, list):
        mask_blocks = [mask_blocks]
        is_masked_list = [is_masked_list]
    if not block_colors:
        block_colors = DEFAULT_COLORS

    try:
        context_text = tokenizer.decode(context_ids[0], skip_special_tokens=True).replace('\n', ' ')
    except Exception:
        context_text = str(context_ids[0].tolist())

    rendered_blocks = []
    for block_idx, (mask_block, is_masked) in enumerate(zip(mask_blocks, is_masked_list)):
        color = block_colors[block_idx % len(block_colors)]
        pieces = []
        # Convert to Python lists once instead of indexing the tensors per token.
        for token_id, masked in zip(mask_block[0].tolist(), is_masked[0].tolist()):
            if masked:
                pieces.append(f'{MASK_COLOR}██{RESET}')
                continue
            try:
                token_text = tokenizer.decode([token_id], skip_special_tokens=False)
            except Exception:
                token_text = str(int(token_id))
            pieces.append(f'{color}{token_text}{RESET}')
        rendered_blocks.append(''.join(pieces))

    legend = ''
    if len(mask_blocks) > 1:
        legend_parts = [
            f'{block_colors[i % len(block_colors)]}Block {i+1}{RESET}'
            for i in range(len(mask_blocks))
        ]
        legend = f"Generating: {' | '.join(legend_parts)}\n\n"
    # Blocks are concatenated directly after the context; the trailing newline
    # leaves the cursor at column 0 just below the frame.
    frame = f"{legend}{context_text}{''.join(rendered_blocks)}\n"

    def _clear_screen():
        """Clear the whole screen, falling back to `cls`/`clear`, then to erasing the current line."""
        try:
            sys.stdout.write('\x1b[2J\x1b[H')
            sys.stdout.flush()
        except Exception:
            try:
                os.system('cls' if os.name == 'nt' else 'clear')
            except Exception:
                sys.stdout.write('\r\033[K')
                sys.stdout.flush()

    if clear and _visualize_last_lines > 0:
        rows = _visualize_last_lines
        try:
            # Go up to the previous frame's first row, erase each of its rows
            # moving down, then return to the top of the cleared region.
            sys.stdout.write(f'\x1b[{rows}A' + '\x1b[2K\x1b[1B' * rows + f'\x1b[{rows}A')
            sys.stdout.flush()
        except Exception:
            _clear_screen()
    elif clear:
        _clear_screen()

    try:
        sys.stdout.write(frame)
        sys.stdout.flush()
    except Exception:
        print(frame, flush=True)

    columns = shutil.get_terminal_size(fallback=(80, 24)).columns
    _visualize_last_lines = _terminal_rows(frame, columns)
def main():
    """Run the interactive command-line chat loop for the fine-tuned diffusion model.

    The function proceeds as follows:

    - imports ``DiffusionLLM`` (and ``ModelConfig`` if present) from the
      sibling ``infer-base.py``;
    - loads the tokenizer and a checkpoint;
    - reads one message per line from stdin and prints the raw generation
      followed by the extracted assistant reply.

    An empty line, EOF or Ctrl-C exits. The checkpoint is ``--model`` when
    given; otherwise ``checkpoints/best_model.pt`` if it exists, else
    ``./checkpoints/model_fp32.pt``.

    Raises:
        RuntimeError: if ``DiffusionLLM`` cannot be loaded from infer-base.py.
    """
    base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "infer-base.py")
    base_mod = try_import_infer_base(base_path)
    if base_mod is None or not hasattr(base_mod, 'DiffusionLLM'):
        raise RuntimeError("DiffusionLLM not found in infer-base.py")
    DiffusionLLM = base_mod.DiffusionLLM

    # load_finetuned_model() looks DiffusionLLM up as a module-level name. Bind
    # it in this module too, so that works when this file is imported rather
    # than only when it runs as __main__.
    globals()['DiffusionLLM'] = DiffusionLLM

    # Workaround for torch.load pickling: the checkpoint presumably references
    # classes as ``__main__.<name>`` (saved from a script), so expose them there.
    main_mod = sys.modules.get('__main__')
    if main_mod is not None:
        try:
            if hasattr(base_mod, 'ModelConfig'):
                setattr(main_mod, 'ModelConfig', base_mod.ModelConfig)
            setattr(main_mod, 'DiffusionLLM', DiffusionLLM)
        except Exception as e:
            print(f"Warning: could not expose model classes on __main__: {e}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", type=str, default=None,
        help="Path to model checkpoint (default: checkpoints/best_model.pt if it exists, "
             "else ./checkpoints/model_fp32.pt)",
    )
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-0.5B", help="Tokenizer model id or path")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--visualize", action="store_true", default=False, help="Enable visualization during generation")
    parser.add_argument("--steps", type=int, default=64)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--parallel_blocks", type=int, default=1, help="Number of blocks to generate in parallel")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.2)
    parser.add_argument("--no_repeat_ngram_size", type=int, default=3)
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    if tokenizer.pad_token is None:
        # set pad token if not present
        tokenizer.pad_token = tokenizer.eos_token

    # An explicit --model always wins; the best-model checkpoint is only the default.
    model_path = args.model
    if model_path is None:
        best_model_path = "checkpoints/best_model.pt"
        if os.path.exists(best_model_path):
            print("Loading best model...")
            model_path = best_model_path
        else:
            model_path = "./checkpoints/model_fp32.pt"
    model, _config = load_finetuned_model(model_path, device)

    # Use the local visualization implementation for consistency
    visualize_fn = visualize_diffusion_state_local if args.visualize else None

    print("Ready. Type a message and press Enter (empty line to quit).\n")
    while True:
        try:
            user_input = input("User: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nExiting.")
            break
        if user_input == "":
            print("Goodbye.")
            break

        raw_output, response = chat(
            model,
            tokenizer,
            user_input,
            steps=args.steps,
            block_size=args.block_size,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            verbose=False,
            visualize_fn=visualize_fn,
            parallel_blocks=args.parallel_blocks,
        )
        print("\nRaw Output:\n")
        print(raw_output)
        print("\nAssistant:\n")
        print(response)
        print("\n" + ("=" * 60) + "\n")


if __name__ == "__main__":
    main()