| """Progress tracking for LTX training. |
| This module provides a unified progress display for training and validation sampling, |
| encapsulating all Rich progress bar logic in one place. |
| """ |
|
|
| from rich.progress import ( |
| BarColumn, |
| Progress, |
| TaskID, |
| TextColumn, |
| TimeElapsedColumn, |
| TimeRemainingColumn, |
| ) |
|
|
|
|
class SamplingContext:
    """Context for validation sampling progress tracking.

    Drives a single Rich task that shows which video is currently being
    generated and how far its denoising loop has progressed.

    Display format: "Sampling X/Y [━━━━━━━━━━━━] step Z/W"

    The progress bar shows the denoising progress for the current video only;
    it is reset every time ``start_video`` is called. When either ``progress``
    or ``task`` is ``None`` every method is a silent no-op, so callers never
    need to check whether progress display is enabled.
    """

    def __init__(self, progress: "Progress | None", task: "TaskID | None", num_prompts: int, num_steps: int):
        """Bind the context to an existing progress task.

        Args:
            progress: Progress display that owns ``task``, or None to disable tracking
            task: ID of the task to drive, or None to disable tracking
            num_prompts: Total number of videos that will be sampled
            num_steps: Number of denoising steps per video
        """
        self._progress = progress
        self._task = task
        self._num_prompts = num_prompts
        self._num_steps = num_steps
        # Denoising steps completed for the current video. Kept locally so the
        # step label never depends on looking the task up in ``Progress.tasks``
        # (a list, whose positions only match task IDs while no task has ever
        # been removed from the display).
        self._step = 0

    def start_video(self, video_idx: int) -> None:
        """Start tracking a new video (resets step progress).

        Args:
            video_idx: Zero-based index of the video; displayed one-based
        """
        if self._progress is None or self._task is None:
            return

        self._step = 0
        self._progress.reset(self._task, total=self._num_steps)
        self._progress.update(
            self._task,
            completed=0,
            video=f"{video_idx + 1}/{self._num_prompts}",
            info=f"step 0/{self._num_steps}",
        )

    def advance_step(self) -> None:
        """Advance the denoising step by one and refresh the step label."""
        if self._progress is None or self._task is None:
            return

        self._step += 1
        # One update call both advances the bar and rewrites the label.
        self._progress.update(
            self._task,
            advance=1,
            info=f"step {self._step}/{self._num_steps}",
        )

    def cleanup(self) -> None:
        """Remove the sampling task from the display when done.

        The task is removed rather than merely hidden so that repeated
        validation rounds do not accumulate dead tasks in the progress
        display. After cleanup the context is inert: further calls to any
        method do nothing.
        """
        if self._progress is None or self._task is None:
            return

        # Forget the ID before removing so a second cleanup (or a late
        # advance_step) cannot touch a task that no longer exists.
        task, self._task = self._task, None
        self._progress.remove_task(task)
|
|
|
|
class StandaloneSamplingProgress:
    """Self-contained denoising progress bar for inference scripts.

    SamplingContext drives a task that lives inside a TrainingProgress display;
    this class instead creates and owns its own Rich Progress instance, so it
    can be used on its own in standalone inference scripts.

    Usage:
        with StandaloneSamplingProgress(num_steps=30) as ctx:
            for step in range(30):
                # ... denoising step ...
                ctx.advance_step()
    """

    def __init__(self, num_steps: int, description: str = "Generating"):
        """Store the configuration; nothing is displayed until the context is entered.

        Args:
            num_steps: Total number of denoising steps
            description: Description to show in progress bar
        """
        self._progress: Progress | None = None
        self._task: TaskID | None = None
        self._description = description
        self._num_steps = num_steps

    def _step_label(self, done: int) -> str:
        """Return the "step Z/W" text shown next to the bar."""
        return f"step {done}/{self._num_steps}"

    def __enter__(self) -> "StandaloneSamplingProgress":
        """Build the Rich display, start it, and register the single task."""
        columns = (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )
        display = self._progress = Progress(*columns)
        display.__enter__()
        self._task = display.add_task(
            self._description,
            total=self._num_steps,
            info=self._step_label(0),
        )
        return self

    def __exit__(self, *args) -> None:
        """Stop the progress display (no-op if it was never started)."""
        display = self._progress
        if display is None:
            return
        display.__exit__(*args)

    def advance_step(self) -> None:
        """Advance the denoising step by one and refresh the step label.

        Does nothing when called before the context has been entered.
        """
        display, task_id = self._progress, self._task
        if display is None or task_id is None:
            return

        display.advance(task_id)
        # Read the count back from the task so the label mirrors the bar.
        done = int(display.tasks[task_id].completed)
        display.update(task_id, info=self._step_label(done))
|
|
|
|
class TrainingProgress:
    """Manages Rich progress display for training and validation.

    This class encapsulates all progress bar logic, providing a clean interface
    for the trainer to update progress without dealing with Rich internals.
    When constructed with ``enabled=False`` no Rich objects are created and
    every method is a no-op.

    Usage:
        with TrainingProgress(enabled=True, total_steps=1000) as progress:
            for step in range(1000):
                # ... training step ...
                progress.update_training(loss=0.1, lr=1e-4, step_time=0.5)
                if should_validate:
                    sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30)
                    sampler = ValidationSampler(..., sampling_context=sampling_ctx)
                    for prompt_idx, prompt in enumerate(prompts):
                        sampling_ctx.start_video(prompt_idx)
                        sampler.generate(...)
                    sampling_ctx.cleanup()
    """

    def __init__(self, enabled: bool, total_steps: int):
        """Initialize progress tracking.

        Args:
            enabled: Whether to display progress bars (False for non-main processes)
            total_steps: Total number of training steps
        """
        self._enabled = enabled
        self._total_steps = total_steps
        self._train_task: TaskID | None = None
        # Steps advanced so far. Tracked locally so the "X/Y" counter never
        # requires looking the task up in ``Progress.tasks``, which is a list
        # whose positions only coincide with task IDs while no task has ever
        # been removed from the display.
        self._completed_steps = 0
        self._progress: Progress | None = None

        if not enabled:
            return

        # The "video" column doubles as the step counter for the training task
        # and as the video counter for sampling tasks.
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            TextColumn("{task.fields[video]}", style="magenta"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )

    def __enter__(self) -> "TrainingProgress":
        """Enter the progress context, starting the live display."""
        if self._progress is not None:
            self._progress.__enter__()
            self._completed_steps = 0
            self._train_task = self._progress.add_task(
                "Training",
                total=self._total_steps,
                video=f"0/{self._total_steps}",
                info="Starting...",
            )
        return self

    def __exit__(self, *args) -> None:
        """Exit the progress context, stopping the live display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    @property
    def enabled(self) -> bool:
        """Whether progress display is enabled."""
        return self._enabled

    def update_training(
        self,
        *,
        loss: float,
        lr: float,
        step_time: float,
        advance: bool = True,
    ) -> None:
        """Update the training progress display.

        Does nothing when progress is disabled or the context has not been
        entered yet.

        Args:
            loss: Current training loss
            lr: Current learning rate
            step_time: Time taken for this step in seconds
            advance: Whether to advance the progress by one step
        """
        if self._progress is None or self._train_task is None:
            return

        if advance:
            self._completed_steps += 1

        # A single update advances the bar and rewrites both text columns.
        self._progress.update(
            self._train_task,
            advance=1 if advance else 0,
            info=f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step",
            video=f"{self._completed_steps}/{self._total_steps}",
        )

    def start_sampling(self, num_prompts: int, num_steps: int) -> "SamplingContext":
        """Start validation sampling progress tracking.

        Creates a task that shows current video and denoising step progress.
        Format: "Sampling X/Y [━━━━━━━━━━━━] step Z/W"

        Each call adds a new task to the display; call ``cleanup()`` on the
        returned context once sampling has finished.

        Args:
            num_prompts: Number of validation prompts to sample
            num_steps: Number of denoising steps per sample

        Returns:
            SamplingContext for tracking progress (no-op if progress is disabled)
        """
        if self._progress is None:
            # Disabled: hand back an inert context so callers need no branching.
            return SamplingContext(
                progress=None,
                task=None,
                num_prompts=num_prompts,
                num_steps=num_steps,
            )

        task = self._progress.add_task(
            "Sampling",
            total=num_steps,
            completed=0,
            video=f"0/{num_prompts}",
            info=f"step 0/{num_steps}",
        )

        return SamplingContext(
            progress=self._progress,
            task=task,
            num_prompts=num_prompts,
            num_steps=num_steps,
        )
|
|