# ltx-2 / packages / ltx-trainer / scripts / compute_reference.py
# Author: linoy — initial commit (ebfc6b3)
"""
Compute reference videos for IC-LoRA training.
This script provides a command-line interface for generating reference videos to be used for IC-LoRA training.
Note that it reads and writes to the same file (the output of caption_videos.py),
where it adds the "reference_path" field to the JSON.
Basic usage:
# Compute reference videos for all videos in a directory
compute_reference.py videos_dir/ --output videos_dir/captions.json
"""
# Standard library imports
import json
from pathlib import Path
from typing import Dict
# Third-party imports
import cv2
import torch
import torchvision.transforms.functional as TF # noqa: N812
import typer
from rich.console import Console
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from transformers.utils.logging import disable_progress_bar
# Local imports
from ltx_trainer.video_utils import read_video, save_video
# Shared rich console used for all user-facing output in this script
console = Console()
# Silence transformers' own progress bars so they don't interleave with rich's
disable_progress_bar()
def compute_reference(
    images: torch.Tensor,
) -> torch.Tensor:
    """Compute Canny edge detection on a batch of images.

    Args:
        images: Batch of images tensor of shape [B, C, H, W], with C either 3
            (RGB) or 1 (grayscale), and values in [0, 1] or [0, 255].

    Returns:
        Edge map tensor of shape [B, 3, H, W]: the single-channel Canny mask
        replicated across 3 channels, float-valued in {0.0, 255.0}.
    """
    # Convert to grayscale if needed
    if images.shape[1] == 3:
        images = TF.rgb_to_grayscale(images)

    # Ensure images are in [0, 1] range
    if images.max() > 1.0:
        images = images / 255.0

    # OpenCV operates on uint8 numpy arrays, one frame at a time
    edge_masks = []
    for image in images:
        # Drop only the channel dimension (image is [1, H, W]); squeeze(0)
        # avoids collapsing H or W when either happens to equal 1.
        image_np = (image.squeeze(0).cpu().numpy() * 255).astype("uint8")

        # Apply Canny edge detection with fixed hysteresis thresholds
        edges = cv2.Canny(
            image_np,
            threshold1=100,
            threshold2=200,
        )

        # Convert back to tensor
        edge_masks.append(torch.from_numpy(edges).float())

    edges = torch.stack(edge_masks)
    edges = torch.stack([edges] * 3, dim=1)  # Replicate mask to 3 channels
    return edges
def _get_meta_data(
    output_path: Path,
) -> list[dict]:
    """Load the dataset metadata (output of caption_videos.py) if it exists.

    Args:
        output_path: Path to the dataset JSON file (a list of per-sample
            dicts, each containing at least a "media_path" key).

    Returns:
        The parsed list of sample dicts, or an empty list if the file is
        missing or unreadable.
    """
    if not output_path.exists():
        return []

    console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]")
    try:
        with output_path.open("r", encoding="utf-8") as f:
            return json.load(f)
    # Best-effort: a missing/corrupt file just means "no existing metadata"
    except (OSError, json.JSONDecodeError) as e:
        console.print(f"[bold yellow]Warning: Could not check meta data: {e}[/]")
        return []
def _save_dataset_json(
    reference_paths: Dict[str, str],
    output_path: Path,
) -> None:
    """Save dataset json with reference video paths.

    Reads the existing dataset JSON, adds a "reference_path" field to each
    entry whose media was processed successfully, and writes it back in place.
    Entries missing from ``reference_paths`` (e.g. media that failed to
    process and was popped by the caller) are left unchanged rather than
    raising a KeyError.

    Args:
        reference_paths: Dictionary mapping media paths to reference video paths
        output_path: Path to save the output file
    """
    with output_path.open("r", encoding="utf-8") as f:
        json_data = json.load(f)

    for item in json_data:
        # Skip entries whose media failed processing (no reference available)
        reference_path = reference_paths.get(item["media_path"])
        if reference_path is not None:
            item["reference_path"] = reference_path

    with output_path.open("w", encoding="utf-8") as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False)

    console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]")
    console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:")
    console.print("  reference_column='[cyan]reference_path[/]'")
    console.print("  video_column='[cyan]media_path[/]'")
def process_media(
    input_path: Path,
    output_path: Path,
    override: bool,
    batch_size: int = 100,
) -> None:
    """Process videos and images to compute condition on videos.

    Args:
        input_path: Path to input video/image file or directory
        output_path: Path to output reference video file (also the dataset
            JSON listing the media to process)
        override: Whether to override existing reference video files
        batch_size: Number of frames to run through edge detection at once
    """
    if not output_path.exists():
        raise FileNotFoundError(
            f"Output file does not exist: {output_path}. This is also the input file for the dataset."
        )

    # Check for existing reference video files
    meta_data = _get_meta_data(output_path)

    base_dir = input_path.resolve()
    console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths")

    def media_path_to_reference_path(media_file: Path) -> Path:
        # e.g. clips/a.mp4 -> clips/a_reference.mp4
        return media_file.parent / (media_file.stem + "_reference" + media_file.suffix)

    # Dataset paths are stored relative to base_dir
    media_to_process = [base_dir / Path(sample["media_path"]) for sample in meta_data]
    skipped_media = []
    console.print(f"Processing [bold]{len(media_to_process)}[/] media.")

    # Initialize progress tracking
    progress = Progress(
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    )

    # Pre-populate the mapping (relative media path -> relative reference path);
    # entries that fail to process are removed below.
    media_paths = [item["media_path"] for item in meta_data]
    reference_paths = {rel_path: str(media_path_to_reference_path(Path(rel_path))) for rel_path in media_paths}

    with progress:
        task = progress.add_task("Computing condition on videos", total=len(media_to_process))

        for media_file in media_to_process:
            progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]")

            rel_path = str(media_file.resolve().relative_to(base_dir))
            reference_path = media_path_to_reference_path(media_file)
            reference_paths[rel_path] = str(reference_path.relative_to(base_dir))

            if not reference_path.resolve().exists() or override:
                try:
                    video, fps = read_video(media_file)

                    # Process frames in batches to bound memory usage
                    condition_frames = []
                    for i in range(0, len(video), batch_size):
                        batch = video[i : i + batch_size]
                        condition_frames.append(compute_reference(batch))

                    # Concatenate all edge frames and save the edge video
                    all_condition = torch.cat(condition_frames, dim=0)
                    save_video(all_condition, reference_path.resolve(), fps=fps)
                except Exception as e:
                    # Best-effort: log, and drop this sample from the mapping
                    console.print(f"[bold red]Error processing [bold blue]{media_file}[/]: {e}[/]")
                    reference_paths.pop(rel_path)
            else:
                # Reference already exists and override is off
                skipped_media.append(media_file)

            progress.advance(task)

    # Save results
    _save_dataset_json(reference_paths, output_path)

    # Print summary (count excludes skipped, not failed, media)
    total_to_process = len(media_to_process) - len(skipped_media)
    console.print(
        f"[bold green]✓[/] Processed [bold]{total_to_process}/{len(media_to_process)}[/] media successfully.",
    )
# Typer application exposing this script's single `main` command
app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Compute reference videos for IC-LoRA training.",
)
@app.command()
def main(
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(  # noqa: B008
        None,
        "--output",
        "-o",
        help="Path to json output file for reference video paths. "
        "This is also the input file for the dataset, the output of compute_captions.py.",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing reference video files",
    ),
    batch_size: int = typer.Option(
        100,
        "--batch-size",
        help="Batch size for processing videos",
    ),
) -> None:
    """Compute reference videos for IC-LoRA training.

    This script generates reference videos (e.g., Canny edge maps) for given videos.
    The paths in the output file will be relative to the output file's directory.

    Examples:
        # Process all videos in a directory
        compute_reference.py videos_dir/ -o videos_dir/captions.json
    """
    # Fail early with a clear message instead of Path(None) -> TypeError
    if output is None:
        raise typer.BadParameter("Missing required option '--output' / '-o'.")

    # Ensure output path is absolute
    output = Path(output).resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    # Verify output path exists
    if not output.exists():
        raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.")

    # Process media files
    process_media(
        input_path=input_path,
        output_path=output,
        override=override,
        batch_size=batch_size,
    )
# Script entry point: dispatch to the typer app
if __name__ == "__main__":
    app()