File size: 9,350 Bytes
ebfc6b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
"""
Compute reference videos for IC-LoRA training.

This script provides a command-line interface for generating reference videos to be used for IC-LoRA training.
Note that it reads and writes to the same file (the output of caption_videos.py),
where it adds the "reference_path" field to the JSON.

Basic usage:
    # Compute reference videos for all videos in a directory
    compute_reference.py videos_dir/ --output videos_dir/captions.json
"""

# Standard library imports
import json
from pathlib import Path
from typing import Dict

# Third-party imports
import cv2
import torch
import torchvision.transforms.functional as TF  # noqa: N812
import typer
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar

# Local imports
from ltx_trainer.video_utils import read_video, save_video

# Silence the transformers progress bars and create the rich console shared by the whole script
disable_progress_bar()
console = Console()


def compute_reference(
    images: torch.Tensor,
) -> torch.Tensor:
    """Compute Canny edge maps for a batch of images.

    Args:
        images: Batch of images tensor of shape [B, C, H, W]. Three-channel input is
            converted to grayscale first. A batch whose maximum value exceeds 1.0 is
            treated as being in the [0, 255] range and is rescaled to [0, 1].

    Returns:
        Float tensor of shape [B, 3, H, W] on the CPU: the edge map produced by
        cv2.Canny (0 for background, 255 for edge pixels) replicated over three channels.
        An empty batch yields an empty tensor of shape [0, 3, H, W].
    """
    # An empty batch has no maximum to inspect and nothing to stack, so return an
    # empty result of the documented shape instead of failing.
    if images.shape[0] == 0:
        height, width = images.shape[-2:]
        return torch.zeros((0, 3, height, width), dtype=torch.float32)

    # Convert to grayscale if needed
    if images.shape[1] == 3:
        images = TF.rgb_to_grayscale(images)

    # Ensure images are in [0, 1] range
    if images.max() > 1.0:
        images = images / 255.0

    # Compute Canny edges frame by frame, since OpenCV works on single 2D uint8 arrays
    edge_masks = []
    for image in images:
        # Drop only the channel axis: a bare squeeze() would also collapse a height
        # or width of 1 and hand OpenCV a 1D array.
        image_np = (image.squeeze(0).cpu().numpy() * 255).astype("uint8")

        # Apply Canny edge detection
        edges = cv2.Canny(
            image_np,
            threshold1=100,
            threshold2=200,
        )

        # Convert back to tensor
        edge_masks.append(torch.from_numpy(edges).float())

    edges = torch.stack(edge_masks)  # [B, H, W]
    return torch.stack([edges] * 3, dim=1)  # Convert to 3-channel: [B, 3, H, W]


def _get_meta_data(
    output_path: Path,
) -> list[dict]:
    """Load the dataset entries stored in the dataset JSON file.

    Args:
        output_path: Path to the dataset JSON file

    Returns:
        The parsed JSON content of the file. The rest of this script iterates it as a
        list of entries that each carry a "media_path" key. An empty list is returned
        when the file does not exist or cannot be read or parsed.
    """
    if not output_path.exists():
        return []

    console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]")

    try:
        with output_path.open("r", encoding="utf-8") as f:
            return json.load(f)

    # OSError covers read failures; ValueError covers both invalid JSON
    # (json.JSONDecodeError) and invalid UTF-8 (UnicodeDecodeError).
    except (OSError, ValueError) as e:
        console.print(f"[bold yellow]Warning: Could not check meta data: {e}[/]")
        return []


def _save_dataset_json(
    reference_paths: Dict[str, str],
    output_path: Path,
) -> None:
    """Add a "reference_path" field to the entries of the dataset JSON file.

    The file is read, every entry whose "media_path" has an entry in
    reference_paths gets its "reference_path" set, and the result is written
    back to the same file. Entries without a reference video are left unchanged.

    Args:
        reference_paths: Dictionary mapping media paths to reference video paths
        output_path: Path to the dataset JSON file, which is updated in place
    """
    with output_path.open("r", encoding="utf-8") as f:
        json_data = json.load(f)

    missing = []
    for item in json_data:
        media_path = item["media_path"]
        reference_path = reference_paths.get(media_path)
        if reference_path is None:
            # process_media removes a media file from the mapping when computing its
            # reference fails; keep such an entry as it is rather than aborting the
            # whole save with a KeyError.
            missing.append(media_path)
            continue
        item["reference_path"] = reference_path

    # Serialize before opening the file for writing, so that a serialization error
    # cannot leave a truncated dataset file behind.
    payload = json.dumps(json_data, indent=2, ensure_ascii=False)
    with output_path.open("w", encoding="utf-8") as f:
        f.write(payload)

    if missing:
        console.print(
            f"[bold yellow]Warning:[/] No reference video for [bold]{len(missing)}[/] media; "
            "their entries were left unchanged."
        )

    console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]")
    console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:")
    console.print("  reference_column='[cyan]reference_path[/]'")
    console.print("  video_column='[cyan]media_path[/]'")


def process_media(
    input_path: Path,
    output_path: Path,
    override: bool,
    batch_size: int = 100,
) -> None:
    """Compute a reference video for every media file listed in the dataset JSON.

    Each reference video is written next to its source file, with "_reference"
    appended to the file stem, and the dataset JSON is then updated through
    _save_dataset_json.

    Args:
        input_path: Base directory that the "media_path" entries of the dataset JSON
            are resolved against
        output_path: Path to the dataset JSON file; it provides the list of media and
            receives the reference paths
        override: Whether to override existing reference video files
        batch_size: Number of frames passed to compute_reference at a time

    Raises:
        FileNotFoundError: If output_path does not exist
        ValueError: If batch_size is smaller than 1
    """
    if not output_path.exists():
        raise FileNotFoundError(
            f"Output file does not exist: {output_path}. This is also the input file for the dataset."
        )
    # A non-positive batch size would make every file fail inside the loop below;
    # reject it once, up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Load the dataset entries
    meta_data = _get_meta_data(output_path)

    base_dir = input_path.resolve()
    console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths")

    def media_path_to_reference_path(media_file: Path) -> Path:
        return media_file.parent / (media_file.stem + "_reference" + media_file.suffix)

    # Keep the media paths exactly as written in the JSON: _save_dataset_json looks
    # the reference paths up by that same string.
    media_paths = [item["media_path"] for item in meta_data]

    console.print(f"Processing [bold]{len(media_paths)}[/] media.")

    # Initialize progress tracking
    progress = Progress(
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    )

    # Process media files
    reference_paths = {}
    computed = 0
    skipped = 0
    failed = 0

    with progress:
        task = progress.add_task("Computing condition on videos", total=len(media_paths))

        for rel_path in media_paths:
            media_file = base_dir / rel_path
            progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]")

            # Location of the reference video on disk, and the path recorded in the
            # JSON (expressed relative to the same base as the media path).
            reference_path = media_path_to_reference_path(media_file)
            reference_paths[rel_path] = str(media_path_to_reference_path(Path(rel_path)))

            if reference_path.resolve().exists() and not override:
                skipped += 1
            else:
                try:
                    video, fps = read_video(media_file)

                    # Process frames in batches
                    condition_frames = [
                        compute_reference(video[i : i + batch_size]) for i in range(0, len(video), batch_size)
                    ]

                    # Concatenate all edge frames and save the edge video
                    all_condition = torch.cat(condition_frames, dim=0)
                    save_video(all_condition, reference_path.resolve(), fps=fps)

                except Exception as e:
                    # Per-file boundary: report the failure and keep going with the remaining files
                    console.print(f"[bold red]Error processing [bold blue]{media_file}[/]: {e}[/]")
                    # Remove the failed file from the mapping handed to _save_dataset_json
                    reference_paths.pop(rel_path, None)
                    failed += 1
                else:
                    computed += 1

            progress.advance(task)

    # Save results
    _save_dataset_json(reference_paths, output_path)

    # Print summary; only files whose reference video was actually computed count as processed
    console.print(
        f"[bold green]✓[/] Processed [bold]{computed}/{len(media_paths)}[/] media successfully.",
    )
    if skipped or failed:
        console.print(f"Skipped [bold]{skipped}[/] media with an existing reference video; [bold]{failed}[/] failed.")


# Typer application that exposes this script's command-line interface
app = typer.Typer(
    help="Compute reference videos for IC-LoRA training.",
    no_args_is_help=True,
    pretty_exceptions_enable=False,
)


@app.command()
def main(
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(  # noqa: B008
        None,
        "--output",
        "-o",
        help="Path to json output file for reference video paths. "
        "This is also the input file for the dataset, the output of compute_captions.py.",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing reference video files",
    ),
    batch_size: int = typer.Option(
        100,
        "--batch-size",
        help="Batch size for processing videos",
    ),
) -> None:
    """Compute reference videos for IC-LoRA training.

    This script generates reference videos (e.g., Canny edge maps) for given videos.
    The reference paths written to the output file use the same base directory as
    its media paths, i.e. INPUT_PATH.

    Examples:
        # Process all videos in a directory
        compute_reference.py videos_dir/ -o videos_dir/captions.json
    """
    # --output defaults to None and Path(None) raises an opaque TypeError, so report
    # the missing option as a usage error instead.
    if output is None:
        raise typer.BadParameter(
            "An existing dataset JSON file must be given with --output / -o.",
            param_hint="--output",
        )

    # Ensure output path is absolute
    output = Path(output).resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    # Verify output path exists
    if not output.exists():
        raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.")

    # Process media files
    process_media(
        input_path=input_path,
        output_path=output,
        override=override,
        batch_size=batch_size,
    )


# Run the Typer application only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    app()