"""
Compute reference videos for IC-LoRA training.

This script provides a command-line interface for generating reference videos to be used for IC-LoRA training.
Note that it reads from and writes to the same file (the output of caption_videos.py),
adding a "reference_path" field to each entry in the JSON.

Basic usage:
    # Compute reference videos for all videos in a directory
    compute_reference.py videos_dir/ --output videos_dir/captions.json
"""
|
|
| |
| import json |
| from pathlib import Path |
| from typing import Dict |
|
|
| |
| import cv2 |
| import torch |
| import torchvision.transforms.functional as TF |
| import typer |
| from rich.console import Console |
| from rich.progress import ( |
| BarColumn, |
| MofNCompleteColumn, |
| Progress, |
| SpinnerColumn, |
| TextColumn, |
| TimeElapsedColumn, |
| TimeRemainingColumn, |
| ) |
| from transformers.utils.logging import disable_progress_bar |
|
|
| |
| from ltx_trainer.video_utils import read_video, save_video |
|
|
| |
# Shared rich console used for all status output in this script.
console = Console()
# Turn off the progress bars of the transformers library.
disable_progress_bar()
|
|
|
|
def compute_reference(
    images: torch.Tensor,
    low_threshold: float = 100,
    high_threshold: float = 200,
) -> torch.Tensor:
    """Compute Canny edge maps for a batch of images.

    Args:
        images: Batch of images of shape [B, C, H, W]. Input with C == 3 is
            converted to grayscale first. If the batch maximum exceeds 1.0 the
            values are assumed to be in [0, 255] and are rescaled to [0, 1].
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        Float tensor of shape [B, 3, H, W]: the single-channel edge map of each
        image replicated over three channels. Values are those emitted by
        ``cv2.Canny`` (an 8-bit map converted to float), not rescaled to [0, 1].
    """
    # Canny operates on single-channel images.
    if images.shape[1] == 3:
        images = TF.rgb_to_grayscale(images)

    # Heuristic range detection: anything above 1.0 is treated as 0-255 data.
    if images.max() > 1.0:
        images = images / 255.0

    edge_masks = []
    for image in images:
        # Drop only the channel axis. A bare squeeze() would also drop H or W
        # when either equals 1 and hand OpenCV an array of the wrong rank.
        image_np = (image.squeeze(0).cpu().numpy() * 255).astype("uint8")

        edge_map = cv2.Canny(
            image_np,
            threshold1=low_threshold,
            threshold2=high_threshold,
        )

        edge_masks.append(torch.from_numpy(edge_map).float())

    # [B, H, W] -> [B, 3, H, W] by repeating the edge map on a new channel axis.
    stacked = torch.stack(edge_masks)
    return torch.stack([stacked] * 3, dim=1)
|
|
|
|
| def _get_meta_data( |
| output_path: Path, |
| ) -> Dict[str, str]: |
| """Get set of existing reference video paths without loading the actual files. |
| Args: |
| output_path: Path to the reference video paths file |
| Returns: |
| Dictionary mapping media paths to reference video paths |
| """ |
| if not output_path.exists(): |
| return {} |
|
|
| console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]") |
|
|
| try: |
| with output_path.open("r", encoding="utf-8") as f: |
| json_data = json.load(f) |
| return json_data |
|
|
| except Exception as e: |
| console.print(f"[bold yellow]Warning: Could not check meta data: {e}[/]") |
| return {} |
|
|
|
|
def _save_dataset_json(
    reference_paths: Dict[str, str],
    output_path: Path,
) -> None:
    """Add a "reference_path" field to the entries of the dataset JSON file.

    The file is read, every entry whose "media_path" has a reference video in
    ``reference_paths`` gets its "reference_path" set, and the result is
    written back to the same file.

    Args:
        reference_paths: Dictionary mapping media paths to reference video paths.
            Media paths without an entry here are left unchanged in the file
            instead of aborting the whole save.
        output_path: Path of the dataset JSON file to update in place
    """
    with output_path.open("r", encoding="utf-8") as f:
        json_data = json.load(f)

    missing = []
    for item in json_data:
        media_path = item["media_path"]
        reference_path = reference_paths.get(media_path)
        if reference_path is None:
            # No reference is available for this media; keep the entry as it is.
            missing.append(media_path)
            continue
        item["reference_path"] = reference_path

    with output_path.open("w", encoding="utf-8") as f:
        json.dump(json_data, f, indent=2, ensure_ascii=False)

    if missing:
        console.print(
            f"[bold yellow]Warning:[/] No reference video available for [bold]{len(missing)}[/] media; "
            "their entries were left unchanged."
        )

    console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]")
    console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:")
    console.print("  reference_column='[cyan]reference_path[/]'")
    console.print("  video_column='[cyan]media_path[/]'")
|
|
|
|
def _write_reference_video(media_file: Path, reference_file: Path, batch_size: int) -> None:
    """Read one video, compute its reference frames batch by batch and save them."""
    video, fps = read_video(media_file)

    # Work in batches of frames so compute_reference never sees the whole video at once.
    condition_frames = [
        compute_reference(video[start : start + batch_size]) for start in range(0, len(video), batch_size)
    ]

    save_video(torch.cat(condition_frames, dim=0), reference_file, fps=fps)


def process_media(
    input_path: Path,
    output_path: Path,
    override: bool,
    batch_size: int = 100,
) -> None:
    """Compute reference videos for every media entry listed in the dataset JSON.

    Each "media_path" in the JSON is resolved against ``input_path``. Its
    reference video is written next to it as ``<stem>_reference<suffix>``, and
    the dataset JSON is then updated with the reference paths.

    Args:
        input_path: Base directory that the media paths in the JSON are relative to
        output_path: Path to the dataset JSON file (read and then updated in place)
        override: Whether to recompute reference videos that already exist on disk
        batch_size: Number of frames passed to ``compute_reference`` at a time

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        FileNotFoundError: If ``output_path`` does not exist.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if not output_path.exists():
        raise FileNotFoundError(
            f"Output file does not exist: {output_path}. This is also the input file for the dataset."
        )

    meta_data = _get_meta_data(output_path)

    base_dir = input_path.resolve()
    console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths")

    def media_path_to_reference_path(media_file: Path) -> Path:
        return media_file.parent / (media_file.stem + "_reference" + media_file.suffix)

    media_paths = [sample["media_path"] for sample in meta_data]

    # Keyed by the exact "media_path" strings stored in the JSON, because those
    # strings are what the entries are looked up by when the file is saved.
    reference_paths = {
        media_path: str(media_path_to_reference_path(Path(media_path))) for media_path in media_paths
    }

    console.print(f"Processing [bold]{len(media_paths)}[/] media.")

    progress = Progress(
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    )

    skipped_media = []
    failed_media = []

    with progress:
        task = progress.add_task("Computing condition on videos", total=len(media_paths))

        for media_path in media_paths:
            media_file = base_dir / media_path
            reference_file = media_path_to_reference_path(media_file).resolve()
            progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]")

            if reference_file.exists() and not override:
                skipped_media.append(media_file)
            else:
                try:
                    _write_reference_video(media_file, reference_file, batch_size)
                except Exception as e:
                    # Best effort per file: report the failure and carry on with
                    # the remaining media instead of aborting the whole run.
                    console.print(f"[bold red]Error processing [bold blue]{media_file}[/]: {e}[/]")
                    failed_media.append(media_file)
                    reference_paths.pop(media_path, None)

            progress.advance(task)

    _save_dataset_json(reference_paths, output_path)

    # Failures and skips are reported separately and are not counted as processed.
    processed = len(media_paths) - len(skipped_media) - len(failed_media)
    console.print(
        f"[bold green]✓[/] Processed [bold]{processed}/{len(media_paths)}[/] media successfully.",
    )
    if skipped_media:
        console.print(
            f"[bold yellow]Skipped [bold]{len(skipped_media)}[/] media with an existing reference video.[/]"
        )
    if failed_media:
        console.print(f"[bold red]Failed to process [bold]{len(failed_media)}[/] media.[/]")
|
|
|
|
# Typer application object; `main` is registered on it through @app.command().
app = typer.Typer(
    pretty_exceptions_enable=False,  # do not use Typer's pretty exception rendering
    no_args_is_help=True,  # show the help text when called without arguments
    help="Compute reference videos for IC-LoRA training.",
)
|
|
|
|
@app.command()
def main(
    input_path: Path = typer.Argument(
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(
        None,
        "--output",
        "-o",
        help="Path to json output file for reference video paths. "
        "This is also the input file for the dataset, the output of caption_videos.py.",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing reference video files",
    ),
    batch_size: int = typer.Option(
        100,
        "--batch-size",
        min=1,
        help="Batch size for processing videos",
    ),
) -> None:
    """Compute reference videos for IC-LoRA training.

    This script generates reference videos (e.g., Canny edge maps) for given videos.
    The reference paths written to the output file are relative to INPUT_PATH.

    Examples:

        # Process all videos in a directory

        compute_reference.py videos_dir/ -o videos_dir/captions.json
    """
    # The option defaults to None, but the dataset JSON is needed both as input
    # and as output, so fail with a usage error instead of a TypeError.
    if output is None:
        raise typer.BadParameter(
            "Please pass the dataset JSON file to update.",
            param_hint="--output",
        )

    output = Path(output).resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    if not output.exists():
        raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.")

    process_media(
        input_path=input_path,
        output_path=output,
        override=override,
        batch_size=batch_size,
    )
|
|
|
|
# Run the Typer CLI when this file is executed as a script.
if __name__ == "__main__":
    app()
|
|