|
|
"""Progress tracking for LTX training. |
|
|
|
|
|
This module provides a unified progress display for training and validation sampling, |
|
|
encapsulating all Rich progress bar logic in one place. |
|
|
""" |
|
|
|
|
|
from rich.progress import ( |
|
|
BarColumn, |
|
|
Progress, |
|
|
TaskID, |
|
|
TextColumn, |
|
|
TimeElapsedColumn, |
|
|
TimeRemainingColumn, |
|
|
) |
|
|
|
|
|
|
|
|
class SamplingContext:
    """Context for validation sampling progress tracking.

    Provides a unified progress display showing the current video and denoising step.

    Display format: "Sampling X/Y [━━━━━━━━━━━━] step Z/W"

    The progress bar shows the denoising progress for the current video. When
    constructed with ``progress=None`` or ``task=None`` every method is a no-op,
    so the same calling code works when progress display is disabled.
    """

    def __init__(self, progress: Progress | None, task: TaskID | None, num_prompts: int, num_steps: int):
        """Wrap an existing Rich task.

        Args:
            progress: Shared Rich ``Progress`` instance, or None to disable display
            task: ID of the sampling task inside ``progress``, or None to disable display
            num_prompts: Total number of videos that will be sampled
            num_steps: Number of denoising steps per video
        """
        self._progress = progress
        self._task = task
        self._num_prompts = num_prompts
        self._num_steps = num_steps
        # Denoising steps completed for the current video. Tracked locally rather
        # than read back from ``Progress.tasks[task_id]``, which would treat a
        # TaskID as a list index (an undocumented Rich implementation detail).
        self._completed = 0

    def start_video(self, video_idx: int) -> None:
        """Start tracking a new video (resets step progress).

        Args:
            video_idx: Zero-based index of the video about to be sampled
        """
        if self._progress is None or self._task is None:
            return

        self._completed = 0
        self._progress.reset(self._task, total=self._num_steps)
        self._progress.update(
            self._task,
            completed=0,
            video=f"{video_idx + 1}/{self._num_prompts}",
            info=f"step 0/{self._num_steps}",
        )

    def advance_step(self) -> None:
        """Advance the denoising step by one and refresh the step label."""
        if self._progress is None or self._task is None:
            return

        self._progress.advance(self._task)
        self._completed += 1
        self._progress.update(self._task, info=f"step {self._completed}/{self._num_steps}")

    def cleanup(self) -> None:
        """Remove the sampling task from the display when done.

        The task is removed rather than merely hidden: a new sampling task is
        created for every validation run, so hidden tasks would otherwise pile up
        in the shared ``Progress`` for the whole training run. Afterwards this
        context becomes a no-op, so calling ``cleanup`` twice is safe.
        """
        if self._progress is None or self._task is None:
            return

        self._progress.remove_task(self._task)
        self._task = None
|
|
|
|
|
|
|
|
class StandaloneSamplingProgress:
    """Standalone progress display for inference scripts.

    Unlike SamplingContext (which integrates with TrainingProgress), this class
    manages its own Rich Progress instance for use in standalone inference scripts.

    Usage:
        with StandaloneSamplingProgress(num_steps=30) as ctx:
            for step in range(30):
                # ... denoising step ...
                ctx.advance_step()
    """

    def __init__(self, num_steps: int, description: str = "Generating"):
        """Initialize standalone sampling progress.

        Args:
            num_steps: Total number of denoising steps
            description: Description to show in progress bar
        """
        self._num_steps = num_steps
        self._description = description
        self._progress: Progress | None = None
        self._task: TaskID | None = None
        # Steps completed so far. Tracked locally instead of indexing
        # ``Progress.tasks`` with a TaskID, which relies on a Rich internal detail.
        self._completed = 0

    def __enter__(self) -> "StandaloneSamplingProgress":
        """Create and start the progress display, adding a single denoising task."""
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )
        self._progress.__enter__()
        self._completed = 0
        self._task = self._progress.add_task(
            self._description,
            total=self._num_steps,
            info=f"step 0/{self._num_steps}",
        )
        return self

    def __exit__(self, *args) -> None:
        """Stop the progress display, forwarding exception info to Rich."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    def advance_step(self) -> None:
        """Advance the denoising step by one (no-op if the display was never started)."""
        if self._progress is None or self._task is None:
            return

        self._progress.advance(self._task)
        self._completed += 1
        self._progress.update(self._task, info=f"step {self._completed}/{self._num_steps}")
|
|
|
|
|
|
|
|
class TrainingProgress:
    """Manages Rich progress display for training and validation.

    This class encapsulates all progress bar logic, providing a clean interface
    for the trainer to update progress without dealing with Rich internals.

    Usage:
        with TrainingProgress(enabled=True, total_steps=1000) as progress:
            for step in range(1000):
                # ... training step ...
                progress.update_training(loss=0.1, lr=1e-4, step_time=0.5)

                if should_validate:
                    sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30)
                    sampler = ValidationSampler(..., sampling_context=sampling_ctx)
                    for prompt_idx, prompt in enumerate(prompts):
                        sampling_ctx.start_video(prompt_idx)
                        sampler.generate(...)
                    sampling_ctx.cleanup()
    """

    def __init__(self, enabled: bool, total_steps: int):
        """Initialize progress tracking.

        Args:
            enabled: Whether to display progress bars (False for non-main processes)
            total_steps: Total number of training steps
        """
        self._enabled = enabled
        self._total_steps = total_steps
        self._train_task: TaskID | None = None
        # Training steps completed so far. Tracked locally instead of indexing
        # ``Progress.tasks`` with a TaskID, which relies on a Rich internal detail.
        self._completed = 0

        if not enabled:
            self._progress = None
            return

        # All tasks share these columns, so every task added to this Progress must
        # supply both the ``video`` and ``info`` fields. The training task reuses
        # the ``video`` column to show "completed/total" training steps.
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            TextColumn("{task.fields[video]}", style="magenta"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )

    def __enter__(self) -> "TrainingProgress":
        """Enter the progress context, starting the live display and the training task."""
        if self._progress is not None:
            self._progress.__enter__()
            self._completed = 0
            self._train_task = self._progress.add_task(
                "Training",
                total=self._total_steps,
                video=f"0/{self._total_steps}",
                info="Starting...",
            )
        return self

    def __exit__(self, *args) -> None:
        """Exit the progress context, stopping the live display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    @property
    def enabled(self) -> bool:
        """Whether progress display is enabled."""
        return self._enabled

    def update_training(
        self,
        *,
        loss: float,
        lr: float,
        step_time: float,
        advance: bool = True,
    ) -> None:
        """Update the training progress display.

        No-op when progress is disabled or the context has not been entered.

        Args:
            loss: Current training loss
            lr: Current learning rate
            step_time: Time taken for this step in seconds
            advance: Whether to advance the progress by one step
        """
        if self._progress is None or self._train_task is None:
            return

        if advance:
            self._completed += 1

        info = f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step"
        # One update call sets the bar position, metrics text and step counter together.
        self._progress.update(
            self._train_task,
            advance=1 if advance else 0,
            info=info,
            video=f"{self._completed}/{self._total_steps}",
        )

    def start_sampling(self, num_prompts: int, num_steps: int) -> SamplingContext:
        """Start validation sampling progress tracking.

        Creates a task that shows current video and denoising step progress.
        Format: "Sampling X/Y [━━━━━━━━━━━━] step Z/W"

        Args:
            num_prompts: Number of validation prompts to sample
            num_steps: Number of denoising steps per sample

        Returns:
            SamplingContext for tracking progress (no-op if progress is disabled)
        """
        if self._progress is None:
            # Disabled display: hand back a context whose methods all do nothing.
            return SamplingContext(
                progress=None,
                task=None,
                num_prompts=num_prompts,
                num_steps=num_steps,
            )

        task = self._progress.add_task(
            "Sampling",
            total=num_steps,
            completed=0,
            video=f"0/{num_prompts}",
            info=f"step 0/{num_steps}",
        )

        return SamplingContext(
            progress=self._progress,
            task=task,
            num_prompts=num_prompts,
            num_steps=num_steps,
        )
|
|
|