Spaces:
Running
on
Zero
Running
on
Zero
| """Progress tracking for LTX training. | |
| This module provides a unified progress display for training and validation sampling, | |
| encapsulating all Rich progress bar logic in one place. | |
| """ | |
| from rich.progress import ( | |
| BarColumn, | |
| Progress, | |
| TaskID, | |
| TextColumn, | |
| TimeElapsedColumn, | |
| TimeRemainingColumn, | |
| ) | |
class SamplingContext:
    """Context for validation sampling progress tracking.

    Provides a unified progress display showing the current video and denoising step.

    Display format: "Sampling X/Y [━━━━━━━━━━━━] step Z/W"

    The progress bar shows the denoising progress for the current video. When
    constructed with ``progress=None`` or ``task=None`` every method is a no-op, so
    callers can use the same code path whether or not progress display is enabled.
    """

    def __init__(
        self,
        progress: "Progress | None",
        task: "TaskID | None",
        num_prompts: int,
        num_steps: int,
    ):
        """Initialize the sampling context.

        Args:
            progress: Shared Rich ``Progress`` instance, or ``None`` to disable tracking
            task: ID of the sampling task inside ``progress``, or ``None`` to disable tracking
            num_prompts: Number of videos (validation prompts) to be sampled
            num_steps: Number of denoising steps per video
        """
        self._progress = progress
        self._task = task
        self._num_prompts = num_prompts
        self._num_steps = num_steps
        # Denoising steps completed for the current video. Tracked locally instead of
        # being read back via ``progress.tasks[task]``: ``Progress.tasks`` is a list in
        # insertion order, so indexing it with a TaskID is only correct while no task
        # has ever been removed from the shared Progress instance.
        self._step = 0

    def start_video(self, video_idx: int) -> None:
        """Start tracking a new video (resets step progress).

        Args:
            video_idx: Zero-based index of the video; displayed one-based as "N/num_prompts"
        """
        if self._progress is None or self._task is None:
            return
        self._step = 0
        # Reset task for new video: completed=0, total=num_steps
        self._progress.reset(self._task, total=self._num_steps)
        self._progress.update(
            self._task,
            completed=0,
            video=f"{video_idx + 1}/{self._num_prompts}",
            info=f"step 0/{self._num_steps}",
        )

    def advance_step(self) -> None:
        """Advance the denoising step by one and refresh the "step Z/W" label."""
        if self._progress is None or self._task is None:
            return
        self._step += 1
        self._progress.update(
            self._task,
            advance=1,
            info=f"step {self._step}/{self._num_steps}",
        )

    def cleanup(self) -> None:
        """Remove the sampling task from the display when sampling is done.

        The task is removed rather than merely hidden, so repeated validation runs do
        not accumulate hidden tasks in the shared ``Progress`` instance. After cleanup,
        every method on this context (including ``cleanup`` itself) is a no-op.
        """
        if self._progress is None or self._task is None:
            return
        self._progress.remove_task(self._task)
        # Drop the ID so a second cleanup (or a late advance_step) does not touch a
        # task that no longer exists.
        self._task = None
class StandaloneSamplingProgress:
    """Self-contained denoising progress bar for standalone inference scripts.

    ``SamplingContext`` draws into a ``Progress`` owned by ``TrainingProgress``; this
    class instead creates its own Rich ``Progress`` on ``__enter__`` and stops it on
    ``__exit__``. Outside the ``with`` block, ``advance_step`` does nothing.

    Usage:
        with StandaloneSamplingProgress(num_steps=30) as ctx:
            for step in range(30):
                # ... denoising step ...
                ctx.advance_step()
    """

    def __init__(self, num_steps: int, description: str = "Generating"):
        """Store the display settings; nothing is rendered until the ``with`` block starts.

        Args:
            num_steps: Total number of denoising steps
            description: Label rendered at the left of the progress bar
        """
        self._task: TaskID | None = None
        self._progress: Progress | None = None
        self._description = description
        self._num_steps = num_steps

    def __enter__(self) -> "StandaloneSamplingProgress":
        """Build the Rich display, start rendering it, and register the step task."""
        columns = (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )
        self._progress = Progress(*columns)
        self._progress.__enter__()
        initial_info = f"step 0/{self._num_steps}"
        self._task = self._progress.add_task(
            self._description, total=self._num_steps, info=initial_info
        )
        return self

    def __exit__(self, *args) -> None:
        """Stop the live display, forwarding the exception info (if any) to Rich."""
        progress = self._progress
        if progress is None:
            return
        progress.__exit__(*args)

    def advance_step(self) -> None:
        """Move the bar forward by one denoising step and refresh the "step N/M" label."""
        progress, task_id = self._progress, self._task
        # TaskID 0 is falsy, so both guards must be explicit ``is not None`` checks.
        if progress is not None and task_id is not None:
            progress.advance(task_id)
            done = int(progress.tasks[task_id].completed)
            progress.update(task_id, info=f"step {done}/{self._num_steps}")
class TrainingProgress:
    """Manages Rich progress display for training and validation.

    This class encapsulates all progress bar logic, providing a clean interface
    for the trainer to update progress without dealing with Rich internals.
    With ``enabled=False`` no Rich objects are created: the display methods are
    no-ops and ``start_sampling`` returns a no-op ``SamplingContext``.

    Usage:
        with TrainingProgress(enabled=True, total_steps=1000) as progress:
            for step in range(1000):
                # ... training step ...
                progress.update_training(loss=0.1, lr=1e-4, step_time=0.5)
                if should_validate:
                    sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30)
                    sampler = ValidationSampler(..., sampling_context=sampling_ctx)
                    for prompt_idx, prompt in enumerate(prompts):
                        sampling_ctx.start_video(prompt_idx)
                        sampler.generate(...)
                    sampling_ctx.cleanup()
    """

    def __init__(self, enabled: bool, total_steps: int):
        """Initialize progress tracking.

        Args:
            enabled: Whether to display progress bars (False for non-main processes)
            total_steps: Total number of training steps
        """
        self._enabled = enabled
        self._total_steps = total_steps
        self._train_task: TaskID | None = None
        self._progress: Progress | None = None
        if not enabled:
            return

        # Single Progress instance shared by the training task and all sampling tasks.
        # The "video" and "info" columns read per-task fields, so every task added by
        # this class sets both fields.
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            TextColumn("{task.fields[video]}", style="magenta"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )

    def __enter__(self) -> "TrainingProgress":
        """Enter the progress context, starting the live display and adding the training task."""
        if self._progress is not None:
            self._progress.__enter__()
            # For training, the "video" column doubles as the "completed/total" step counter.
            self._train_task = self._progress.add_task(
                "Training",
                total=self._total_steps,
                video=f"0/{self._total_steps}",
                info="Starting...",
            )
        return self

    def __exit__(self, *args) -> None:
        """Exit the progress context, stopping the live display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    def enabled(self) -> bool:
        """Whether progress display is enabled.

        NOTE(review): this is a plain method, not a property, so ``progress.enabled``
        without parentheses is a bound method and therefore always truthy.
        """
        return self._enabled

    def update_training(
        self,
        *,
        loss: float,
        lr: float,
        step_time: float,
        advance: bool = True,
    ) -> None:
        """Update the training progress display.

        Args:
            loss: Current training loss
            lr: Current learning rate
            step_time: Time taken for this step in seconds
            advance: Whether to advance the progress by one step
        """
        if self._progress is None or self._train_task is None:
            return
        step = 1 if advance else 0
        # Find the task by ID: ``Progress.tasks`` is a list in insertion order, so
        # indexing it with a TaskID is only correct while no task has been removed.
        task = next(t for t in self._progress.tasks if t.id == self._train_task)
        completed = int(task.completed) + step
        info = f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step"
        # One update advances the bar, sets the metrics, and refreshes the step counter
        # shown in the "video" column.
        self._progress.update(
            self._train_task,
            advance=step,
            info=info,
            video=f"{completed}/{self._total_steps}",
        )

    def start_sampling(self, num_prompts: int, num_steps: int) -> "SamplingContext":
        """Start validation sampling progress tracking.

        Creates a task that shows current video and denoising step progress.
        Format: "Sampling X/Y [━━━━━━━━━━━━] step Z/W"

        Args:
            num_prompts: Number of validation prompts to sample
            num_steps: Number of denoising steps per sample

        Returns:
            SamplingContext for tracking progress (no-op if progress is disabled)
        """
        if self._progress is None:
            # Return a no-op context when progress is disabled
            return SamplingContext(
                progress=None,
                task=None,
                num_prompts=num_prompts,
                num_steps=num_steps,
            )

        task = self._progress.add_task(
            "Sampling",
            total=num_steps,
            completed=0,
            video=f"0/{num_prompts}",
            info=f"step 0/{num_steps}",
        )
        return SamplingContext(
            progress=self._progress,
            task=task,
            num_prompts=num_prompts,
            num_steps=num_steps,
        )