yitongl's picture
Add inference code and attention settings for sfp4 checkpoint-750
697fddf verified
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py
import argparse
import dataclasses
import os
from typing import cast
from fastvideo import VideoGenerator
from fastvideo.configs.sample.base import SamplingParam
from fastvideo.entrypoints.cli.cli_types import CLISubcommand
from fastvideo.entrypoints.cli.utils import RaiseNotImplementedAction
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.logger import init_logger
from fastvideo.utils import FlexibleArgumentParser
logger = init_logger(__name__)
class GenerateSubcommand(CLISubcommand):
"""The `generate` subcommand for the FastVideo CLI"""
def __init__(self) -> None:
self.name = "generate"
super().__init__()
self.init_arg_names = self._get_init_arg_names()
self.generation_arg_names = self._get_generation_arg_names()
def _get_init_arg_names(self) -> list[str]:
"""Get names of arguments for VideoGenerator initialization"""
return ["num_gpus", "tp_size", "sp_size", "model_path"]
def _get_generation_arg_names(self) -> list[str]:
"""Get names of arguments for generate_video method"""
return [field.name for field in dataclasses.fields(SamplingParam)]
def cmd(self, args: argparse.Namespace) -> None:
excluded_args = ['subparser', 'config', 'dispatch_function']
provided_args = {}
for k, v in vars(args).items():
if (k not in excluded_args and v is not None and hasattr(args, '_provided') and k in args._provided):
provided_args[k] = v
if 'model_path' in vars(args) and args.model_path is not None:
provided_args['model_path'] = args.model_path
if 'prompt' in vars(args) and args.prompt is not None:
provided_args['prompt'] = args.prompt
merged_args = {**provided_args}
logger.info('CLI Args: %s', merged_args)
if 'model_path' not in merged_args or not merged_args['model_path']:
raise ValueError("model_path must be provided either in config file or via --model-path")
# Check if either prompt or prompt_txt is provided
has_prompt = 'prompt' in merged_args and merged_args['prompt']
has_prompt_txt = 'prompt_txt' in merged_args and merged_args['prompt_txt']
if not (has_prompt or has_prompt_txt):
raise ValueError("Either prompt or prompt_txt must be provided")
if has_prompt and has_prompt_txt:
raise ValueError("Cannot provide both 'prompt' and 'prompt_txt'. Use only one of them.")
init_args = {k: v for k, v in merged_args.items() if k not in self.generation_arg_names}
generation_args = {k: v for k, v in merged_args.items() if k in self.generation_arg_names}
generation_args.setdefault("return_frames", False)
model_path = init_args.pop('model_path')
prompt = generation_args.pop('prompt', None)
generator = VideoGenerator.from_pretrained(model_path=model_path, **init_args)
# Call generate_video - it handles both single and batch modes
generator.generate_video(prompt=prompt, **generation_args)
def validate(self, args: argparse.Namespace) -> None:
"""Validate the arguments for this command"""
if args.num_gpus is not None and args.num_gpus <= 0:
raise ValueError("Number of gpus must be positive")
if args.config and not os.path.exists(args.config):
raise ValueError(f"Config file not found: {args.config}")
def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
generate_parser = subparsers.add_parser(
"generate",
help="Run inference on a model",
usage="fastvideo generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]")
generate_parser.add_argument(
"--config",
type=str,
default='',
required=False,
help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional."
)
generate_parser = FastVideoArgs.add_cli_args(generate_parser)
generate_parser = SamplingParam.add_cli_args(generate_parser)
generate_parser.add_argument(
"--text-encoder-configs",
action=RaiseNotImplementedAction,
help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)",
)
return cast(FlexibleArgumentParser, generate_parser)
def cmd_init() -> list[CLISubcommand]:
return [GenerateSubcommand()]