linoy
initial commit
ebfc6b3
"""Progress tracking for LTX training.
This module provides a unified progress display for training and validation sampling,
encapsulating all Rich progress bar logic in one place.
"""
from rich.progress import (
BarColumn,
Progress,
TaskID,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
class SamplingContext:
    """Context for validation sampling progress tracking.

    Drives a single Rich task that shows the current video and denoising step.
    Display format: "Sampling X/Y [████████████] step Z/W"
    The progress bar shows the denoising progress for the current video.

    Every method is a no-op when the context was built without a progress
    instance or task id, and also after ``cleanup()`` has been called.
    """

    def __init__(self, progress: Progress | None, task: TaskID | None, num_prompts: int, num_steps: int):
        """Bind the context to a progress task.

        Args:
            progress: Progress instance that owns ``task``, or None to disable display
            task: Id of the task to drive, or None to disable display
            num_prompts: Number of videos sampled in this validation round
            num_steps: Number of denoising steps per video
        """
        self._progress = progress
        self._task = task
        self._num_prompts = num_prompts
        self._num_steps = num_steps
        # Steps completed for the current video. Counted here rather than read
        # back through ``progress.tasks[task]``: a TaskID is an identifier, not
        # a list position, so that lookup breaks once any task is removed.
        self._step = 0

    def start_video(self, video_idx: int) -> None:
        """Start tracking a new video (resets step progress).

        Args:
            video_idx: Zero-based index of the video; displayed one-based
        """
        if self._progress is None or self._task is None:
            return
        self._step = 0
        # Reset task for new video: completed=0, total=num_steps
        self._progress.reset(self._task, total=self._num_steps)
        self._progress.update(
            self._task,
            completed=0,
            video=f"{video_idx + 1}/{self._num_prompts}",
            info=f"step 0/{self._num_steps}",
        )

    def advance_step(self) -> None:
        """Advance the denoising step by one."""
        if self._progress is None or self._task is None:
            return
        self._step += 1
        # One update moves the bar and rewrites the label together, so the two
        # cannot be rendered out of sync between separate calls.
        self._progress.update(
            self._task,
            advance=1,
            info=f"step {self._step}/{self._num_steps}",
        )

    def cleanup(self) -> None:
        """Remove the sampling task from the display when done.

        The task is removed rather than hidden so that repeated validation
        rounds do not pile up hidden tasks inside the shared progress instance.
        Later calls on this context do nothing.
        """
        if self._progress is None or self._task is None:
            return
        self._progress.remove_task(self._task)
        # The task id is no longer valid; drop it so further calls are no-ops.
        self._task = None
class StandaloneSamplingProgress:
    """Self-contained denoising progress bar for inference scripts.

    SamplingContext draws into a Progress it is handed from outside; this class
    creates and owns its Rich Progress instead, so a script can use it on its own.

    Usage:
        with StandaloneSamplingProgress(num_steps=30) as ctx:
            for step in range(30):
                # ... denoising step ...
                ctx.advance_step()
    """

    def __init__(self, num_steps: int, description: str = "Generating"):
        """Store the configuration; the display is only created on ``__enter__``.

        Args:
            num_steps: Total number of denoising steps
            description: Description to show in progress bar
        """
        self._progress: Progress | None = None
        self._task: TaskID | None = None
        self._description = description
        self._num_steps = num_steps

    def __enter__(self) -> "StandaloneSamplingProgress":
        """Create the progress display, start it and register the single task."""
        columns = (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )
        display = Progress(*columns)
        self._progress = display
        display.__enter__()
        self._task = display.add_task(
            self._description,
            total=self._num_steps,
            info=f"step 0/{self._num_steps}",
        )
        return self

    def __exit__(self, *args) -> None:
        """Stop the progress display if it was ever started."""
        display = self._progress
        if display is None:
            return
        display.__exit__(*args)

    def advance_step(self) -> None:
        """Move the bar forward one denoising step and refresh the step label."""
        display, task_id = self._progress, self._task
        if display is not None and task_id is not None:
            display.advance(task_id)
            # Read the count back from the task so the label matches the bar.
            done = int(display.tasks[task_id].completed)
            display.update(task_id, info=f"step {done}/{self._num_steps}")
class TrainingProgress:
    """Rich progress display shared by the training loop and validation sampling.

    All progress bar handling lives here, so the trainer reports numbers through
    a small interface and never touches Rich objects directly.

    Usage:
        with TrainingProgress(enabled=True, total_steps=1000) as progress:
            for step in range(1000):
                # ... training step ...
                progress.update_training(loss=0.1, lr=1e-4, step_time=0.5)
                if should_validate:
                    sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30)
                    sampler = ValidationSampler(..., sampling_context=sampling_ctx)
                    for prompt_idx, prompt in enumerate(prompts):
                        sampling_ctx.start_video(prompt_idx)
                        sampler.generate(...)
                    sampling_ctx.cleanup()
    """

    def __init__(self, enabled: bool, total_steps: int):
        """Initialize progress tracking.

        Args:
            enabled: Whether to display progress bars (False for non-main processes)
            total_steps: Total number of training steps
        """
        self._enabled = enabled
        self._total_steps = total_steps
        self._train_task: TaskID | None = None
        # When disabled no Progress is built at all; every method then short-circuits.
        self._progress = self._build_progress() if enabled else None

    @staticmethod
    def _build_progress() -> Progress:
        """Create the single Progress instance used by training and sampling tasks."""
        columns = (
            TextColumn("[progress.description]{task.description}"),
            TextColumn("{task.fields[video]}", style="magenta"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )
        return Progress(*columns)

    def __enter__(self) -> "TrainingProgress":
        """Start the live display and register the training task."""
        display = self._progress
        if display is None:
            return self
        display.__enter__()
        self._train_task = display.add_task(
            "Training",
            total=self._total_steps,
            video=f"0/{self._total_steps}",
            info="Starting...",
        )
        return self

    def __exit__(self, *args) -> None:
        """Stop the live display if one is active."""
        display = self._progress
        if display is None:
            return
        display.__exit__(*args)

    @property
    def enabled(self) -> bool:
        """Whether progress display is enabled."""
        return self._enabled

    def update_training(
        self,
        *,
        loss: float,
        lr: float,
        step_time: float,
        advance: bool = True,
    ) -> None:
        """Refresh the training row with the latest metrics.

        Args:
            loss: Current training loss
            lr: Current learning rate
            step_time: Time taken for this step in seconds
            advance: Whether to advance the progress by one step
        """
        display, task_id = self._progress, self._train_task
        if display is None or task_id is None:
            return
        stats = f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step"
        display.update(task_id, advance=int(bool(advance)), info=stats)
        # The "video" column doubles as the step counter for the training row.
        done = int(display.tasks[task_id].completed)
        display.update(task_id, video=f"{done}/{self._total_steps}")

    def start_sampling(self, num_prompts: int, num_steps: int) -> SamplingContext:
        """Begin tracking a round of validation sampling.

        Adds a task that shows current video and denoising step progress.
        Format: "Sampling X/Y [████████████] step Z/W"

        Args:
            num_prompts: Number of validation prompts to sample
            num_steps: Number of denoising steps per sample

        Returns:
            SamplingContext for tracking progress (no-op if progress is disabled)
        """
        display = self._progress
        task_id = None
        if display is not None:
            task_id = display.add_task(
                "Sampling",
                total=num_steps,
                completed=0,
                video=f"0/{num_prompts}",
                info=f"step 0/{num_steps}",
            )
        # With progress disabled both handles are None, which makes the context a no-op.
        return SamplingContext(
            progress=display,
            task=task_id,
            num_prompts=num_prompts,
            num_steps=num_steps,
        )