File size: 8,797 Bytes
ebfc6b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""Progress tracking for LTX training.

This module provides a unified progress display for training and validation sampling,
encapsulating all Rich progress bar logic in one place.
"""

from rich.progress import (
    BarColumn,
    Progress,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)


class SamplingContext:
    """Context for validation sampling progress tracking.

    Provides a unified progress display showing current video and denoising step.
    Display format: "Sampling X/Y [β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ] step Z/W"
    The progress bar shows the denoising progress for the current video.
    """

    def __init__(self, progress: Progress | None, task: TaskID | None, num_prompts: int, num_steps: int):
        self._progress = progress
        self._task = task
        self._num_prompts = num_prompts
        self._num_steps = num_steps

    def start_video(self, video_idx: int) -> None:
        """Start tracking a new video (resets step progress)."""
        if self._progress is None or self._task is None:
            return
        # Reset task for new video: completed=0, total=num_steps
        self._progress.reset(self._task, total=self._num_steps)
        self._progress.update(
            self._task,
            completed=0,
            video=f"{video_idx + 1}/{self._num_prompts}",
            info=f"step 0/{self._num_steps}",
        )

    def advance_step(self) -> None:
        """Advance the denoising step by one."""
        if self._progress is None or self._task is None:
            return
        self._progress.advance(self._task)
        completed = int(self._progress.tasks[self._task].completed)
        self._progress.update(self._task, info=f"step {completed}/{self._num_steps}")

    def cleanup(self) -> None:
        """Hide sampling task when done."""
        if self._progress is None or self._task is None:
            return
        self._progress.update(self._task, visible=False)


class StandaloneSamplingProgress:
    """Standalone progress display for inference scripts.

    Unlike SamplingContext (which integrates with TrainingProgress), this class
    manages its own Rich Progress instance for use in standalone inference scripts.

    Usage:
        with StandaloneSamplingProgress(num_steps=30) as ctx:
            for step in range(30):
                # ... denoising step ...
                ctx.advance_step()
    """

    def __init__(self, num_steps: int, description: str = "Generating"):
        """Initialize standalone sampling progress.

        Args:
            num_steps: Total number of denoising steps
            description: Description to show in progress bar
        """
        self._num_steps = num_steps
        self._description = description
        self._progress: Progress | None = None
        self._task: TaskID | None = None

    def __enter__(self) -> "StandaloneSamplingProgress":
        """Start the progress display."""
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )
        self._progress.__enter__()
        self._task = self._progress.add_task(
            self._description,
            total=self._num_steps,
            info=f"step 0/{self._num_steps}",
        )
        return self

    def __exit__(self, *args) -> None:
        """Stop the progress display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    def advance_step(self) -> None:
        """Advance the denoising step by one."""
        if self._progress is None or self._task is None:
            return
        self._progress.advance(self._task)
        completed = int(self._progress.tasks[self._task].completed)
        self._progress.update(self._task, info=f"step {completed}/{self._num_steps}")


class TrainingProgress:
    """Manages Rich progress display for training and validation.

    This class encapsulates all progress bar logic, providing a clean interface
    for the trainer to update progress without dealing with Rich internals.

    Usage:
        with TrainingProgress(enabled=True, total_steps=1000) as progress:
            for step in range(1000):
                # ... training step ...
                progress.update_training(loss=0.1, lr=1e-4, step_time=0.5)

                if should_validate:
                    sampling_ctx = progress.start_sampling(num_prompts=3, num_steps=30)
                    sampler = ValidationSampler(..., sampling_context=sampling_ctx)
                    for prompt_idx, prompt in enumerate(prompts):
                        sampling_ctx.start_video(prompt_idx)
                        sampler.generate(...)
                    sampling_ctx.cleanup()
    """

    def __init__(self, enabled: bool, total_steps: int):
        """Initialize progress tracking.

        Args:
            enabled: Whether to display progress bars (False for non-main processes)
            total_steps: Total number of training steps
        """
        self._enabled = enabled
        self._total_steps = total_steps
        self._train_task: TaskID | None = None

        if not enabled:
            self._progress = None
            return

        # Single Progress instance with flexible columns
        self._progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            TextColumn("{task.fields[video]}", style="magenta"),
            BarColumn(bar_width=40, style="blue"),
            TextColumn("{task.fields[info]}", style="cyan"),
            TimeElapsedColumn(),
            TextColumn("ETA:"),
            TimeRemainingColumn(compact=True),
        )

    def __enter__(self) -> "TrainingProgress":
        """Enter the progress context, starting the live display."""
        if self._progress is not None:
            self._progress.__enter__()
            self._train_task = self._progress.add_task(
                "Training",
                total=self._total_steps,
                video=f"0/{self._total_steps}",
                info="Starting...",
            )
        return self

    def __exit__(self, *args) -> None:
        """Exit the progress context, stopping the live display."""
        if self._progress is not None:
            self._progress.__exit__(*args)

    @property
    def enabled(self) -> bool:
        """Whether progress display is enabled."""
        return self._enabled

    def update_training(
        self,
        *,
        loss: float,
        lr: float,
        step_time: float,
        advance: bool = True,
    ) -> None:
        """Update the training progress display.

        Args:
            loss: Current training loss
            lr: Current learning rate
            step_time: Time taken for this step in seconds
            advance: Whether to advance the progress by one step
        """
        if self._progress is None or self._train_task is None:
            return

        info = f"Loss: {loss:.4f} | LR: {lr:.2e} | {step_time:.2f}s/step"
        self._progress.update(
            self._train_task,
            advance=1 if advance else 0,
            info=info,
        )
        # Update step count in video column
        completed = int(self._progress.tasks[self._train_task].completed)
        self._progress.update(self._train_task, video=f"{completed}/{self._total_steps}")

    def start_sampling(self, num_prompts: int, num_steps: int) -> SamplingContext:
        """Start validation sampling progress tracking.

        Creates a task that shows current video and denoising step progress.
        Format: "Sampling X/Y [β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ] step Z/W"

        Args:
            num_prompts: Number of validation prompts to sample
            num_steps: Number of denoising steps per sample

        Returns:
            SamplingContext for tracking progress (no-op if progress is disabled)
        """
        if self._progress is None:
            # Return a no-op context when progress is disabled
            return SamplingContext(
                progress=None,
                task=None,
                num_prompts=num_prompts,
                num_steps=num_steps,
            )

        task = self._progress.add_task(
            "Sampling",
            total=num_steps,
            completed=0,
            video=f"0/{num_prompts}",
            info=f"step 0/{num_steps}",
        )

        return SamplingContext(
            progress=self._progress,
            task=task,
            num_prompts=num_prompts,
            num_steps=num_steps,
        )