| """ |
| Terminal visualization for RND1 generation. |
| |
| This module provides real-time visualization of the diffusion denoising process, |
| showing token evolution and generation progress in the terminal using rich |
| formatting when available. |
| """ |
|
|
| import torch |
| from typing import Optional |
| from tqdm import tqdm |
|
|
| try: |
| from rich.console import Console |
| from rich.live import Live |
| from rich.text import Text |
| from rich.panel import Panel |
| from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn |
| from rich.layout import Layout |
| RICH_AVAILABLE = True |
| except ImportError: |
| RICH_AVAILABLE = False |
|
|
|
|
class TerminalVisualizer:
    """
    Rich-based visualization for diffusion process with live updates.

    Shows the token denoising process during diffusion-based language
    generation: masked positions are highlighted, decoded tokens are colored,
    and a progress bar tracks the step count and the number of remaining
    masks. When ``rich`` is not importable (or visualization is disabled) the
    class degrades to a plain ``tqdm`` progress bar.

    Typical call order is ``start_visualization`` -> ``update_step`` (once per
    step) -> ``stop_visualization``.
    """

    # Decoded strings that are never rendered in the live text panel.
    _HIDDEN_TOKEN_STRINGS = frozenset(
        {"<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"}
    )

    def __init__(self, tokenizer, show_visualization: bool = True):
        """
        Initialize the terminal visualizer.

        Args:
            tokenizer: The tokenizer for decoding tokens to text. Only its
                ``decode`` method is used.
            show_visualization: Whether to show the rich visualization. If
                rich is not installed a warning is printed and the simple
                progress bar is used instead.
        """
        self.tokenizer = tokenizer
        if show_visualization and not RICH_AVAILABLE:
            print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
        self.show_visualization = show_visualization and RICH_AVAILABLE

        # Every piece of state exists in both modes so no method has to guess
        # which attributes were created.
        self.console = Console() if self.show_visualization else None
        self.live = None
        self.progress = None
        self.progress_task = None
        self.layout = None
        self.pbar = None

        self.current_tokens = None
        self.mask_positions = None
        self.total_steps = 0
        self.current_step = 0

    @staticmethod
    def _count_masks(maskable) -> int:
        """Return the number of set entries in ``maskable`` (0 when it is None)."""
        return maskable.sum().item() if maskable is not None else 0

    def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
        """
        Start the visualization.

        Args:
            initial_tokens: Initial token IDs (possibly masked)
            mask_positions: Boolean mask indicating which positions are masked
            total_steps: Total number of diffusion steps
        """
        if not self.show_visualization:
            self.pbar = tqdm(total=total_steps, desc="Diffusion")
            return

        # Stop a display left running by an earlier call so two Live
        # instances never compete for the terminal.
        if self.live is not None:
            self.live.stop()
            self.live = None

        self.current_tokens = initial_tokens.clone()
        self.mask_positions = mask_positions
        self.total_steps = total_steps
        self.current_step = 0

        self.layout = Layout()
        self.layout.split_column(
            Layout(name="header", size=3),
            Layout(name="text", ratio=1),
            Layout(name="progress", size=3),
        )

        self.progress = Progress(
            TextColumn("[bold blue]Diffusion"),
            BarColumn(),
            MofNCompleteColumn(),
            TextColumn("•"),
            TextColumn("[cyan]Masks: {task.fields[masks]}"),
            TimeRemainingColumn(),
        )
        self.progress_task = self.progress.add_task(
            "Generating",
            total=total_steps,
            masks=self._count_masks(mask_positions),
        )

        self.live = Live(self.layout, console=self.console, refresh_per_second=4)
        self.live.start()
        self._update_display()

    def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
                    entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
        """
        Update visualization for current step.

        Args:
            tokens: Current token IDs
            maskable: Boolean mask of remaining masked positions, or None
                (treated as zero masks remaining)
            step: Current step number
            entropy: Optional entropy scores for each position (accepted for
                interface compatibility; not used by the display)
            confidence: Optional confidence scores for each position (accepted
                for interface compatibility; not used by the display)
        """
        masks_remaining = self._count_masks(maskable)

        if not self.show_visualization:
            # Compare against None explicitly: a progress-bar object can be
            # falsy (tqdm derives truthiness from its total), and such a bar
            # must still be advanced.
            if self.pbar is not None:
                self.pbar.update(1)
                self.pbar.set_postfix({'masks': masks_remaining})
            return

        self.current_tokens = tokens.clone()
        self.mask_positions = maskable
        self.current_step = step

        # Tolerate being called before start_visualization(), mirroring the
        # fallback path above.
        if self.progress is not None:
            self.progress.update(
                self.progress_task,
                advance=1,
                masks=masks_remaining,
            )

        self._update_display()

    def _update_display(self):
        """Refresh the header, text and progress panels of the live layout."""
        if self.live is None:
            return

        header = Text("🎭 RND1-Base Generation", style="bold magenta", justify="center")
        self.layout["header"].update(Panel(header, border_style="bright_blue"))

        text_display = self._format_text_with_masks()
        self.layout["text"].update(
            Panel(
                text_display,
                title="[bold]Generated Text",
                subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
                border_style="cyan",
            )
        )

        self.layout["progress"].update(Panel(self.progress))

    # The return annotation is a string on purpose: annotations are evaluated
    # when the class body runs, and ``Text`` does not exist when rich is not
    # installed. An unquoted annotation would break the tqdm fallback at
    # import time.
    def _format_text_with_masks(self) -> "Text":
        """
        Format text with colored masks.

        Only the first sequence is rendered when the stored tensors have more
        than one dimension. Masked positions are drawn as ``[MASK]`` with a
        style that alternates with the parity of the current step; other
        tokens are decoded one by one, green in the first half of the sequence
        and cyan in the second. Tokens that fail to decode are skipped.

        Returns:
            Rich Text object with formatted tokens
        """
        text = Text()

        if self.current_tokens is None:
            return text

        token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
        mask_flags = self.mask_positions
        if mask_flags is not None and mask_flags.dim() > 1:
            mask_flags = mask_flags[0]

        # Convert once instead of calling .item() / indexing a tensor per token.
        ids = token_ids.tolist()
        flags = mask_flags.tolist() if mask_flags is not None else []

        midpoint = len(ids) // 2
        mask_style = "bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red"
        hidden = self._HIDDEN_TOKEN_STRINGS

        for i, token_id in enumerate(ids):
            if i < len(flags) and flags[i]:
                text.append("[MASK]", style=mask_style)
                continue

            try:
                token_str = self.tokenizer.decode([token_id], skip_special_tokens=False)
                if token_str not in hidden:
                    text.append(token_str, style="green" if i < midpoint else "cyan")
            except Exception:
                # Best-effort display: one undecodable token must not abort
                # generation. KeyboardInterrupt/SystemExit still propagate.
                continue

        return text

    def stop_visualization(self):
        """Stop the visualization and display final result."""
        if not self.show_visualization:
            if self.pbar is not None:
                self.pbar.close()
                self.pbar = None
            print("\n✨ Generation complete!\n")
            return

        if self.live is not None:
            self.live.stop()
            self.live = None

        self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")

        if self.current_tokens is None:
            return

        try:
            token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
            final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
            # Wrap in Text so the generated output is shown literally; a bare
            # str would be parsed as rich console markup.
            body = Text(final_text)
        except Exception:
            # Showing the final panel is best-effort; the completion message
            # above has already been printed.
            return

        self.console.print(Panel(
            body,
            title="[bold]Final Generated Text",
            border_style="green",
            padding=(1, 2),
        ))
|
|
|
|
class SimpleProgressBar:
    """
    Minimal tqdm-backed progress reporter.

    Intended as the plain fallback for environments where the rich library
    is not installed: it advances one tick per diffusion step and shows the
    number of masks still remaining as a postfix.
    """

    def __init__(self, total_steps: int):
        """
        Create the underlying tqdm bar.

        Args:
            total_steps: Number of steps the bar counts up to
        """
        self.pbar = tqdm(desc="Diffusion", total=total_steps)

    def update(self, masks_remaining: int = 0):
        """
        Advance the bar by one step and refresh the mask counter.

        Args:
            masks_remaining: How many masked positions are left
        """
        bar = self.pbar
        bar.update(1)
        postfix = {'masks': masks_remaining}
        bar.set_postfix(postfix)

    def close(self):
        """Close the bar, then print the completion message."""
        bar = self.pbar
        bar.close()
        print("\n✨ Generation complete!\n")