Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Compute reference videos for IC-LoRA training. | |
| This script provides a command-line interface for generating reference videos to be used for IC-LoRA training. | |
| Note that it reads and writes to the same file (the output of caption_videos.py), | |
| where it adds the "reference_path" field to the JSON. | |
| Basic usage: | |
| # Compute reference videos for all videos in a directory | |
| compute_reference.py videos_dir/ --output videos_dir/captions.json | |
| """ | |
| # Standard library imports | |
| import json | |
| from pathlib import Path | |
| from typing import Dict | |
| # Third-party imports | |
| import cv2 | |
| import torch | |
| import torchvision.transforms.functional as TF # noqa: N812 | |
| import typer | |
| from rich.console import Console | |
| from rich.progress import ( | |
| BarColumn, | |
| MofNCompleteColumn, | |
| Progress, | |
| SpinnerColumn, | |
| TextColumn, | |
| TimeElapsedColumn, | |
| TimeRemainingColumn, | |
| ) | |
| from transformers.utils.logging import disable_progress_bar | |
| # Local imports | |
| from ltx_trainer.video_utils import read_video, save_video | |
# Silence the transformers progress bars, then create the Rich console shared by this script
disable_progress_bar()
console = Console()
def compute_reference(
    images: torch.Tensor,
    low_threshold: int = 100,
    high_threshold: int = 200,
) -> torch.Tensor:
    """Compute Canny edge maps for a batch of images.

    Args:
        images: Batch of images tensor of shape [B, C, H, W] with C == 3 (RGB) or
            C == 1 (grayscale). Values may be in [0, 1] or [0, 255]; the range is
            detected from the maximum value of the batch.
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``

    Returns:
        Float tensor of shape [B, 3, H, W] on the CPU. The single-channel edge map
        is replicated over three channels; edge pixels are 255.0, all others 0.0.

    Raises:
        ValueError: If ``images`` is not 4-dimensional or has an unsupported
            number of channels.
    """
    if images.ndim != 4:
        raise ValueError(f"Expected images of shape [B, C, H, W], got {tuple(images.shape)}")

    channels, height, width = images.shape[1:]
    if channels not in (1, 3):
        raise ValueError(f"Expected 1 or 3 channels, got {channels}")

    # An empty batch has no maximum to inspect and nothing to stack
    if images.shape[0] == 0:
        return torch.zeros((0, 3, height, width), dtype=torch.float32)

    # Convert to grayscale if needed
    if channels == 3:
        images = TF.rgb_to_grayscale(images)

    # Bring the values to [0, 255]; a batch whose maximum exceeds 1 is assumed to be there already
    gray = images.float()
    if not bool(gray.max() > 1.0):
        gray = gray * 255.0

    # Round instead of truncating so 8-bit inputs keep their exact grey levels
    gray = gray.round().clamp(0, 255).to(torch.uint8).cpu()

    # Compute Canny edges frame by frame (OpenCV works on single 2D uint8 arrays)
    edge_masks = []
    for frame in gray:
        # Index the channel explicitly: squeeze() would also drop H or W when they are 1
        frame_np = frame[0].numpy()
        frame_edges = cv2.Canny(
            frame_np,
            threshold1=low_threshold,
            threshold2=high_threshold,
        )
        edge_masks.append(torch.from_numpy(frame_edges).float())

    edges = torch.stack(edge_masks)
    # Replicate the edge map over three channels
    return torch.stack([edges] * 3, dim=1)
| def _get_meta_data( | |
| output_path: Path, | |
| ) -> Dict[str, str]: | |
| """Get set of existing reference video paths without loading the actual files. | |
| Args: | |
| output_path: Path to the reference video paths file | |
| Returns: | |
| Dictionary mapping media paths to reference video paths | |
| """ | |
| if not output_path.exists(): | |
| return {} | |
| console.print(f"[bold blue]Reading meta data from [cyan]{output_path}[/]...[/]") | |
| try: | |
| with output_path.open("r", encoding="utf-8") as f: | |
| json_data = json.load(f) | |
| return json_data | |
| except Exception as e: | |
| console.print(f"[bold yellow]Warning: Could not check meta data: {e}[/]") | |
| return {} | |
def _save_dataset_json(
    reference_paths: Dict[str, str],
    output_path: Path,
) -> None:
    """Add reference video paths to the dataset JSON file in place.

    The file at ``output_path`` is read, every entry whose "media_path" appears in
    ``reference_paths`` receives a "reference_path" field, and the file is rewritten.
    Entries without a known reference path are left unchanged and reported.

    Args:
        reference_paths: Dictionary mapping media paths to reference video paths
        output_path: Path of the dataset JSON file to update
    """
    with output_path.open("r", encoding="utf-8") as f:
        json_data = json.load(f)

    missing = []
    for item in json_data:
        media_path = item["media_path"]
        reference_path = reference_paths.get(media_path)
        if reference_path is None:
            # No reference video is known for this entry (e.g. its processing failed)
            missing.append(media_path)
            continue
        item["reference_path"] = reference_path

    # Serialize before opening the file for writing so that a serialization
    # error cannot leave a truncated dataset file behind
    payload = json.dumps(json_data, indent=2, ensure_ascii=False)
    with output_path.open("w", encoding="utf-8") as f:
        f.write(payload)

    if missing:
        console.print(
            f"[bold yellow]Warning:[/] No reference video path for [bold]{len(missing)}[/] media; "
            "these entries were left unchanged."
        )

    console.print(f"[bold green]✓[/] Reference video paths saved to [cyan]{output_path}[/]")
    console.print("[bold yellow]Note:[/] Use these files with ImageOrVideoDataset by setting:")
    console.print(" reference_column='[cyan]reference_path[/]'")
    console.print(" video_column='[cyan]media_path[/]'")
def process_media(
    input_path: Path,
    output_path: Path,
    override: bool,
    batch_size: int = 100,
) -> None:
    """Compute a reference video for every media file listed in the dataset JSON.

    The dataset JSON at ``output_path`` is read, a reference video named
    ``<stem>_reference<suffix>`` is written next to each listed media file, and
    the JSON is then updated with the reference paths.

    Args:
        input_path: Directory that the "media_path" entries are resolved against
        output_path: Path to the dataset JSON file, which is read and then updated
        override: Whether to override existing reference video files
        batch_size: Number of frames passed to ``compute_reference`` at a time

    Raises:
        FileNotFoundError: If ``output_path`` does not exist
        ValueError: If ``batch_size`` is smaller than 1
    """
    if not output_path.exists():
        raise FileNotFoundError(
            f"Output file does not exist: {output_path}. This is also the input file for the dataset."
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Read the dataset entries
    meta_data = _get_meta_data(output_path)

    base_dir = input_path.resolve()
    console.print(f"Using [bold blue]{base_dir}[/] as base directory for relative paths")

    def media_path_to_reference_path(media_file: Path) -> Path:
        """Return the sibling path ``<stem>_reference<suffix>`` of a media file."""
        return media_file.parent / (media_file.stem + "_reference" + media_file.suffix)

    # The media paths exactly as written in the JSON are the lookup keys used when saving
    media_paths = [item["media_path"] for item in meta_data]
    media_files = [base_dir / Path(rel_path) for rel_path in media_paths]
    console.print(f"Processing [bold]{len(media_files)}[/] media.")

    # Initialize progress tracking
    progress = Progress(
        SpinnerColumn(),
        TextColumn("{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("•"),
        TimeRemainingColumn(),
        console=console,
    )

    reference_paths = {rel_path: str(media_path_to_reference_path(Path(rel_path))) for rel_path in media_paths}
    skipped_media = []
    failed_media = []

    # Process media files
    with progress:
        task = progress.add_task("Computing condition on videos", total=len(media_files))

        for rel_path, media_file in zip(media_paths, media_files):
            progress.update(task, description=f"Processing [bold blue]{media_file.name}[/]")
            reference_path = media_path_to_reference_path(media_file)

            if reference_path.resolve().exists() and not override:
                skipped_media.append(media_file)
                progress.advance(task)
                continue

            try:
                video, fps = read_video(media_file)

                # Process frames in batches
                condition_frames = [
                    compute_reference(video[start : start + batch_size])
                    for start in range(0, len(video), batch_size)
                ]

                # Concatenate all edge frames and save the edge video
                all_condition = torch.cat(condition_frames, dim=0)
                save_video(all_condition, reference_path.resolve(), fps=fps)
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole run
                console.print(f"[bold red]Error processing [bold blue]{media_file}[/]: {e}[/]")
                # Do not record a reference path for media without a reference video
                reference_paths.pop(rel_path, None)
                failed_media.append(media_file)

            progress.advance(task)

    # Save results
    _save_dataset_json(reference_paths, output_path)

    # Print summary
    total_processed = len(media_files) - len(skipped_media) - len(failed_media)
    console.print(
        f"[bold green]✓[/] Processed [bold]{total_processed}/{len(media_files)}[/] media successfully.",
    )
    if skipped_media:
        console.print(f"Skipped [bold]{len(skipped_media)}[/] media with existing reference videos.")
    if failed_media:
        console.print(f"[bold red]Failed to process [bold]{len(failed_media)}[/] media.[/]")
app = typer.Typer(
    pretty_exceptions_enable=False,
    no_args_is_help=True,
    help="Compute reference videos for IC-LoRA training.",
)


# Register main as the command of the Typer app; without this app() has nothing to run
@app.command()
def main(
    input_path: Path = typer.Argument(  # noqa: B008
        ...,
        help="Path to input video/image file or directory containing media files",
        exists=True,
    ),
    output: Path | None = typer.Option(  # noqa: B008
        None,
        "--output",
        "-o",
        help="Path to json output file for reference video paths. "
        "This is also the input file for the dataset, the output of caption_videos.py.",
    ),
    override: bool = typer.Option(
        False,
        "--override",
        help="Whether to override existing reference video files",
    ),
    batch_size: int = typer.Option(
        100,
        "--batch-size",
        min=1,
        help="Batch size for processing videos",
    ),
) -> None:
    """Compute reference videos for IC-LoRA training.

    This script generates reference videos (e.g., Canny edge maps) for given videos.
    The media paths in the JSON file are resolved relative to INPUT_PATH, and each
    reference path is written next to its media path.

    Examples:

        # Process all videos in a directory
        compute_reference.py videos_dir/ -o videos_dir/captions.json
    """
    # The JSON file is both read and updated, so it cannot be omitted
    if output is None:
        raise typer.BadParameter(
            "An output JSON file is required; it is also the input file for the dataset.",
            param_hint="--output",
        )

    # Ensure output path is absolute
    output = output.resolve()
    console.print(f"Output will be saved to [bold blue]{output}[/]")

    # Verify output path exists
    if not output.exists():
        raise FileNotFoundError(f"Output file does not exist: {output}. This is also the input file for the dataset.")

    # Process media files
    process_media(
        input_path=input_path,
        output_path=output,
        override=override,
        batch_size=batch_size,
    )


if __name__ == "__main__":
    app()