"""
Terminal visualization for RND1 generation.

This module provides real-time visualization of the diffusion denoising process,
showing token evolution and generation progress in the terminal using rich
formatting when available.
"""
| |
|
| | import torch |
| | from typing import Optional |
| | from tqdm import tqdm |
| |
|
| | try: |
| | from rich.console import Console |
| | from rich.live import Live |
| | from rich.text import Text |
| | from rich.panel import Panel |
| | from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn |
| | from rich.layout import Layout |
| | RICH_AVAILABLE = True |
| | except ImportError: |
| | RICH_AVAILABLE = False |
| |
|
| |
|
class TerminalVisualizer:
    """
    Rich-based visualization for diffusion process with live updates.

    Provides real-time visualization of the token denoising process during
    diffusion-based language generation, with colored highlighting of masked
    positions and progress tracking. When rich cannot be imported (or the
    visualization is disabled) a plain tqdm progress bar is used instead.
    """

    # Decoded strings that are never appended to the live text panel.
    _HIDDEN_TOKENS = frozenset(
        {"<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"}
    )

    def __init__(self, tokenizer, show_visualization: bool = True):
        """
        Initialize the terminal visualizer.

        Args:
            tokenizer: The tokenizer for decoding tokens to text
            show_visualization: Whether to show visualization (requires rich)
        """
        self.tokenizer = tokenizer

        if show_visualization and not RICH_AVAILABLE:
            print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
        self.show_visualization = show_visualization and RICH_AVAILABLE

        # Every attribute is created up front so that no method can hit an
        # AttributeError depending on which display mode is active.
        self.console = Console() if self.show_visualization else None
        self.live = None
        self.progress = None
        self.progress_task = None
        self.layout = None
        self.pbar = None

        self.current_tokens = None
        self.mask_positions = None
        self.total_steps = 0
        self.current_step = 0

    def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
        """
        Start the visualization.

        In fallback mode only a tqdm bar is created. Otherwise a three-row
        layout (header, text, progress) is built and a live display is started.

        Args:
            initial_tokens: Initial token IDs (possibly masked)
            mask_positions: Boolean mask indicating which positions are masked
            total_steps: Total number of diffusion steps
        """
        if not self.show_visualization:
            # Close a bar left over from an earlier run before replacing it.
            if self.pbar is not None:
                self.pbar.close()
            self.pbar = tqdm(total=total_steps, desc="Diffusion")
            return

        # Stop a live display left over from an earlier run before starting
        # a new one.
        if self.live is not None:
            self.live.stop()
            self.live = None

        self.current_tokens = initial_tokens.clone()
        self.mask_positions = mask_positions
        self.total_steps = total_steps
        self.current_step = 0

        self.layout = Layout()
        self.layout.split_column(
            Layout(name="header", size=3),
            Layout(name="text", ratio=1),
            Layout(name="progress", size=3),
        )

        self.progress = Progress(
            TextColumn("[bold blue]Diffusion"),
            BarColumn(),
            MofNCompleteColumn(),
            TextColumn("•"),
            TextColumn("[cyan]Masks: {task.fields[masks]}"),
            TimeRemainingColumn(),
        )
        self.progress_task = self.progress.add_task(
            "Generating",
            total=total_steps,
            masks=mask_positions.sum().item(),
        )

        self.live = Live(self.layout, console=self.console, refresh_per_second=4)
        self.live.start()
        self._update_display()

    def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
                    entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
        """
        Update visualization for current step.

        Each call advances the progress indicator by exactly one, independent
        of the value of ``step``; ``step`` is only shown in the panel subtitle.

        Args:
            tokens: Current token IDs
            maskable: Boolean mask of remaining masked positions
            step: Current step number
            entropy: Optional entropy scores for each position (currently unused)
            confidence: Optional confidence scores for each position (currently unused)
        """
        masks_remaining = maskable.sum().item() if maskable is not None else 0

        if not self.show_visualization:
            # Compare against None rather than relying on truthiness: the bar
            # object may evaluate as falsy even though it exists.
            if self.pbar is not None:
                self.pbar.update(1)
                self.pbar.set_postfix({'masks': masks_remaining})
            return

        self.current_tokens = tokens.clone()
        self.mask_positions = maskable
        self.current_step = step

        # The progress widget only exists once start_visualization() has run.
        if self.progress is not None:
            self.progress.update(
                self.progress_task,
                advance=1,
                masks=masks_remaining,
            )

        self._update_display()

    def _update_display(self):
        """Update the live display (no-op until a live display is running)."""
        if self.live is None:
            return

        header = Text("🎭 RND1-Base Generation", style="bold magenta", justify="center")
        self.layout["header"].update(Panel(header, border_style="bright_blue"))

        text_display = self._format_text_with_masks()
        self.layout["text"].update(
            Panel(
                text_display,
                title="[bold]Generated Text",
                subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
                border_style="cyan",
            )
        )

        self.layout["progress"].update(Panel(self.progress))

    # The return annotation is a string on purpose: ``Text`` is undefined when
    # rich is not installed, and an eagerly evaluated annotation would raise
    # NameError while the class is being defined, defeating the fallback.
    def _format_text_with_masks(self) -> "Text":
        """
        Format text with colored masks.

        Masked positions are shown as ``[MASK]`` in a style that alternates
        with the parity of the current step. Other tokens are decoded one at a
        time; the first half of the positions is green, the second half cyan.
        Tokens in ``_HIDDEN_TOKENS`` and tokens that fail to decode are skipped.

        Returns:
            Rich Text object with formatted tokens
        """
        text = Text()

        if self.current_tokens is None:
            return text

        # Only the first row is displayed when the tensors have a batch axis.
        token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
        mask_flags = self.mask_positions
        if mask_flags is not None and mask_flags.dim() > 1:
            mask_flags = mask_flags[0]

        # Convert once instead of calling .item() for every position.
        ids = token_ids.tolist()
        flags = mask_flags.tolist() if mask_flags is not None else []

        midpoint = len(ids) // 2
        mask_style = "bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red"

        for i, token_id in enumerate(ids):
            if i < len(flags) and flags[i]:
                text.append("[MASK]", style=mask_style)
                continue

            try:
                token_str = self.tokenizer.decode([token_id], skip_special_tokens=False)
                if token_str not in self._HIDDEN_TOKENS:
                    text.append(token_str, style="green" if i < midpoint else "cyan")
            except Exception:
                # Best effort: an undecodable token must not break the display.
                continue

        return text

    def stop_visualization(self):
        """
        Stop the visualization and display final result.

        Safe to call more than once: the progress bar / live display handle is
        cleared after it has been closed.
        """
        if not self.show_visualization:
            if self.pbar is not None:
                self.pbar.close()
                self.pbar = None
                print("\n✨ Generation complete!\n")
            return

        if self.live is not None:
            self.live.stop()
            self.live = None

        self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")

        if self.current_tokens is None:
            return

        try:
            token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
            final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)

            # Wrap in Text so that square brackets in the generated output are
            # shown literally instead of being parsed as console markup.
            self.console.print(Panel(
                Text(final_text),
                title="[bold]Final Generated Text",
                border_style="green",
                padding=(1, 2),
            ))
        except Exception as exc:
            # Showing the final text is best effort; report instead of crashing.
            self.console.print(Text(f"Could not render final text: {exc}", style="dim"))
| |
|
| |
|
class SimpleProgressBar:
    """
    Minimal tqdm-backed progress indicator.

    Used as a fallback for tracking diffusion steps when the rich library
    is not installed.
    """

    def __init__(self, total_steps: int):
        """
        Create the underlying tqdm bar.

        Args:
            total_steps: Number of steps the bar counts up to
        """
        self.pbar = tqdm(desc="Diffusion", total=total_steps)

    def update(self, masks_remaining: int = 0):
        """
        Advance the bar by one step and refresh the mask counter.

        Args:
            masks_remaining: How many masked positions are still left
        """
        bar = self.pbar
        bar.update(1)
        postfix = {'masks': masks_remaining}
        bar.set_postfix(postfix)

    def close(self):
        """Close the bar and announce that generation has finished."""
        self.pbar.close()
        completion_message = "\n✨ Generation complete!\n"
        print(completion_message)