# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A CLI to run CausalVideoTokenizer on plain videos based on torch.jit.

Usage:
    python3 -m cosmos_predict1.tokenizer.inference.video_cli \
        --video_pattern 'path/to/video/samples/*.mp4' \
        --output_dir ./reconstructions \
        --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
        --checkpoint_dec ./checkpoints/<model-name>/decoder.jit

    Optionally, you can run the model in pure PyTorch mode:
    python3 -m cosmos_predict1.tokenizer.inference.video_cli \
        --video_pattern 'path/to/video/samples/*.mp4' \
        --mode=torch \
        --tokenizer_type=CV \
        --temporal_compression=4 \
        --spatial_compression=8 \
        --checkpoint_enc ./checkpoints/<model-name>/encoder.jit \
        --checkpoint_dec ./checkpoints/<model-name>/decoder.jit
"""
import os
import sys
from argparse import ArgumentParser, Namespace
from typing import Any

import numpy as np
from loguru import logger as logging

from cosmos_predict1.tokenizer.inference.utils import (
    get_filepaths,
    get_output_filepath,
    read_video,
    resize_video,
    write_video,
)
from cosmos_predict1.tokenizer.inference.video_lib import CausalVideoTokenizer
from cosmos_predict1.tokenizer.networks import TokenizerConfigs
| def _parse_args() -> tuple[Namespace, dict[str, Any]]: | |
| parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.") | |
| parser.add_argument( | |
| "--video_pattern", | |
| type=str, | |
| default="path/to/videos/*.mp4", | |
| help="Glob pattern.", | |
| ) | |
| parser.add_argument( | |
| "--checkpoint", | |
| type=str, | |
| default=None, | |
| help="JIT full Autoencoder model filepath.", | |
| ) | |
| parser.add_argument( | |
| "--checkpoint_enc", | |
| type=str, | |
| default=None, | |
| help="JIT Encoder model filepath.", | |
| ) | |
| parser.add_argument( | |
| "--checkpoint_dec", | |
| type=str, | |
| default=None, | |
| help="JIT Decoder model filepath.", | |
| ) | |
| parser.add_argument( | |
| "--tokenizer_type", | |
| type=str, | |
| default=None, | |
| choices=[ | |
| "CV8x8x8-720p", | |
| "DV8x16x16-720p", | |
| "CV4x8x8-360p", | |
| "DV4x8x8-360p", | |
| ], | |
| help="Specifies the tokenizer type.", | |
| ) | |
| parser.add_argument( | |
| "--mode", | |
| type=str, | |
| choices=["torch", "jit"], | |
| default="jit", | |
| help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", | |
| ) | |
| parser.add_argument( | |
| "--short_size", | |
| type=int, | |
| default=None, | |
| help="The size to resample inputs. None, by default.", | |
| ) | |
| parser.add_argument( | |
| "--temporal_window", | |
| type=int, | |
| default=17, | |
| help="The temporal window to operate at a time.", | |
| ) | |
| parser.add_argument( | |
| "--dtype", | |
| type=str, | |
| default="bfloat16", | |
| help="Sets the precision, default bfloat16.", | |
| ) | |
| parser.add_argument( | |
| "--device", | |
| type=str, | |
| default="cuda", | |
| help="Device for invoking the model.", | |
| ) | |
| parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") | |
| parser.add_argument( | |
| "--output_fps", | |
| type=float, | |
| default=24.0, | |
| help="Output frames-per-second (FPS).", | |
| ) | |
| parser.add_argument( | |
| "--save_input", | |
| action="store_true", | |
| help="If on, the input video will be be outputted too.", | |
| ) | |
| args = parser.parse_args() | |
| return args | |
# Parse CLI arguments at import time; `args` is a module-level global that
# _run_eval() reads below.
logging.info("Initializes args ...")
args = _parse_args()
# The pure-PyTorch backend looks up its network config by tokenizer type,
# so `--tokenizer_type` is mandatory when --mode=torch.
if args.mode == "torch" and args.tokenizer_type is None:
    logging.error("`torch` backend requires `--tokenizer_type` to be specified.")
    sys.exit(1)
def _run_eval() -> None:
    """Run the autoencoder (encode + decode round trip) over every matched video.

    Reads the module-level ``args`` Namespace. For each video matching
    ``args.video_pattern``, the reconstruction is written via
    ``get_output_filepath``/``write_video``; with ``--save_input`` the
    (possibly resized) input video is saved alongside it.
    """
    # At least one checkpoint is required: either the full autoencoder JIT,
    # or the encoder/decoder JIT pair.
    if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None:
        logging.warning(
            "Aborting. Provide both encoder and decoder JIT checkpoints, or the full autoencoder JIT model."
        )
        return
    if args.mode == "torch":
        # Native-PyTorch backend: map e.g. "CV4x8x8-360p" -> "CV4x8x8_360p"
        # to look up the network configuration enum member.
        _type = args.tokenizer_type.replace("-", "_")
        _config = TokenizerConfigs[_type].value
    else:
        _config = None
    # NOTE(review): the JIT checkpoints are loaded in both modes (the torch
    # backend restores weights from them too), hence this log is unconditional.
    logging.info(
        f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
    )
    autoencoder = CausalVideoTokenizer(
        checkpoint=args.checkpoint,
        checkpoint_enc=args.checkpoint_enc,
        checkpoint_dec=args.checkpoint_dec,
        tokenizer_config=_config,
        device=args.device,
        dtype=args.dtype,
    )

    logging.info(f"Looking for files matching video_pattern={args.video_pattern} ...")
    filepaths = get_filepaths(args.video_pattern)
    logging.info(f"Found {len(filepaths)} videos from {args.video_pattern}.")

    for filepath in filepaths:
        logging.info(f"Reading video {filepath} ...")
        video = read_video(filepath)
        # resize_video presumably no-ops when short_size is None — TODO confirm.
        video = resize_video(video, short_size=args.short_size)

        logging.info("Invoking the autoencoder model ...")
        # Add a leading batch dimension; the tokenizer operates on batched video.
        batch_video = video[np.newaxis, ...]
        output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0]
        logging.info("Constructing output filepath ...")
        output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
        logging.info(f"Outputting {output_filepath} ...")
        write_video(output_filepath, output_video, fps=args.output_fps)

        if args.save_input:
            # Save the (resized) input next to the reconstruction,
            # e.g. foo.mp4 -> foo_input.mp4.
            ext = os.path.splitext(output_filepath)[-1]
            input_filepath = output_filepath.replace(ext, "_input" + ext)
            write_video(input_filepath, video, fps=args.output_fps)
def main() -> None:
    """CLI entry point: run the evaluation loop over all matched videos."""
    _run_eval()
# Script entry point; delegating to main() keeps the module importable.
if __name__ == "__main__":
    main()