""" Compute reference videos for IC-LoRA training. This script provides a command-line interface for generating reference videos to be used for IC-LoRA training. Note that it reads and writes to the same file (the output of caption_videos.py), where it adds the "reference_path" field to the JSON. Basic usage: # Compute reference videos for all videos in a directory compute_reference.py videos_dir/ --output videos_dir/captions.json """ # Standard library imports import json from pathlib import Path from typing import Dict # Third-party imports import cv2 import torch import torchvision.transforms.functional as TF # noqa: N812 import typer from rich.console import Console from rich.progress import ( BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn, ) from transformers.utils.logging import disable_progress_bar # Local imports from ltx_trainer.video_utils import read_video, save_video # Initialize console and disable progress bars console = Console() disable_progress_bar() def compute_reference( images: torch.Tensor, ) -> torch.Tensor: """Compute Canny edge detection on a batch of images. Args: images: Batch of images tensor of shape [B, C, H, W] Returns: Binary edge masks tensor of shape [B, H, W] """ # Convert to grayscale if needed if images.shape[1] == 3: images = TF.rgb_to_grayscale(images) # Ensure images are in [0, 1] range if images.max() > 1.0: images = images / 255.0 # Compute Canny edges edge_masks = [] for image in images: # Convert to numpy for OpenCV image_np = (image.squeeze().cpu().numpy() * 255).astype("uint8") # Apply Canny edge detection edges = cv2.Canny( image_np, threshold1=100, threshold2=200, ) # Convert back to tensor edge_mask = torch.from_numpy(edges).float() edge_masks.append(edge_mask) edges = torch.stack(edge_masks) edges = torch.stack([edges] * 3, dim=1) # Convert to 3-channel return edges def _get_meta_data( output_path: Path, ) -> Dict[str, str]: """Get set of existing reference video paths without loading the actual files. Args: output_path: Path to the reference video paths file Returns: Dictionary mapping media paths to reference video paths """ if not output_path.exists(): return {} console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]") try: with output_path.open("r", encoding="utf-8") as f: json_data = json.load(f) return json_data except Exception as e: console.print(f"[bold yellow]Warning: Could not check meta data: {e}[/]") return {} def _save_dataset_json( reference_paths: Dict[str, str], output_path: Path, ) -> None: """Save dataset json with reference video paths. Args: reference_paths: Dictionary mapping media paths to reference video paths output_path: Path to save the output file """ with output_path.open("r", encoding="utf-8") as f: json_data = json.load(f) new_json_data = json_data.copy() for i, item in enumerate(json_data): media_path = item["media_path"] reference_path = reference_paths[media_path] new_json_data[i]["reference_path"] = reference_path with output_path.open("w", encoding="utf-8") as f: json.dump(new_json_data, f, indent=2, ensure_ascii=False) console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]") console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:") console.print(" reference_column='[cyan]reference_path[/]'") console.print(" video_column='[cyan]media_path[/]'") def process_media( input_path: Path, output_path: Path, override: bool, batch_size: int = 100, ) -> None: """Process videos and images to compute condition on videos. Args: input_path: Path to input video/image file or directory output_path: Path to output reference video file override: Whether to override existing reference video files """ if not output_path.exists(): raise FileNotFoundError( f"Output file does not exist: {output_path}. This is also the input file for the dataset." ) # Check for existing reference video files meta_data = _get_meta_data(output_path) base_dir = input_path.resolve() console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths") # Filter media files media_to_process = [] skipped_media = [] def media_path_to_reference_path(media_file: Path) -> Path: return media_file.parent / (media_file.stem + "_reference" + media_file.suffix) media_files = [base_dir / Path(sample["media_path"]) for sample in meta_data] for media_file in media_files: reference_path = media_path_to_reference_path(media_file) media_to_process.append(media_file) console.print(f"Processing [bold]{len(media_to_process)}[/] media.") # Initialize progress tracking progress = Progress( SpinnerColumn(), TextColumn("{task.description}"), BarColumn(bar_width=40), MofNCompleteColumn(), TimeElapsedColumn(), TextColumn("•"), TimeRemainingColumn(), console=console, ) # Process media files media_paths = [item["media_path"] for item in meta_data] reference_paths = {rel_path: str(media_path_to_reference_path(Path(rel_path))) for rel_path in media_paths} with progress: task = progress.add_task("Computing condition on videos", total=len(media_to_process)) for media_file in media_to_process: progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]") rel_path = str(media_file.resolve().relative_to(base_dir)) reference_path = media_path_to_reference_path(media_file) reference_paths[rel_path] = str(reference_path.relative_to(base_dir)) if not reference_path.resolve().exists() or override: try: video, fps = read_video(media_file) # Process frames in batches condition_frames = [] for i in range(0, len(video), batch_size): batch = video[i : i + batch_size] condition_batch = compute_reference(batch) condition_frames.append(condition_batch) # Concatenate all edge frames all_condition = torch.cat(condition_frames, dim=0) # Save the edge video save_video(all_condition, reference_path.resolve(), fps=fps) except Exception as e: console.print(f"[bold red]Error processing [bold blue]{media_file}[/]: {e}[/]") reference_paths.pop(rel_path) else: skipped_media.append(media_file) progress.advance(task) # Save results _save_dataset_json(reference_paths, output_path) # Print summary total_to_process = len(media_files) - len(skipped_media) console.print( f"[bold green]✓[/] Processed [bold]{total_to_process}/{len(media_files)}[/] media successfully.", ) app = typer.Typer( pretty_exceptions_enable=False, no_args_is_help=True, help="Compute reference videos for IC-LoRA training.", ) @app.command() def main( input_path: Path = typer.Argument( # noqa: B008 ..., help="Path to input video/image file or directory containing media files", exists=True, ), output: Path | None = typer.Option( # noqa: B008 None, "--output", "-o", help="Path to json output file for reference video paths. " "This is also the input file for the dataset, the output of compute_captions.py.", ), override: bool = typer.Option( False, "--override", help="Whether to override existing reference video files", ), batch_size: int = typer.Option( 100, "--batch-size", help="Batch size for processing videos", ), ) -> None: """Compute reference videos for IC-LoRA training. This script generates reference videos (e.g., Canny edge maps) for given videos. The paths in the output file will be relative to the output file's directory. Examples: # Process all videos in a directory compute_reference.py videos_dir/ -o videos_dir/captions.json """ # Ensure output path is absolute output = Path(output).resolve() console.print(f"Output will be saved to [bold blue]{output}[/]") # Verify output path exists if not output.exists(): raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.") # Process media files process_media( input_path=input_path, output_path=output, override=override, batch_size=batch_size, ) if __name__ == "__main__": app()