File size: 4,699 Bytes
697fddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py

import argparse
import dataclasses
import os
from typing import cast

from fastvideo import VideoGenerator
from fastvideo.configs.sample.base import SamplingParam
from fastvideo.entrypoints.cli.cli_types import CLISubcommand
from fastvideo.entrypoints.cli.utils import RaiseNotImplementedAction
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.logger import init_logger
from fastvideo.utils import FlexibleArgumentParser

logger = init_logger(__name__)


class GenerateSubcommand(CLISubcommand):
    """CLI subcommand exposing video generation as `fastvideo generate`."""

    def __init__(self) -> None:
        # The name is assigned before the base-class initializer runs.
        self.name = "generate"
        super().__init__()
        self.init_arg_names = self._get_init_arg_names()
        self.generation_arg_names = self._get_generation_arg_names()

    def _get_init_arg_names(self) -> list[str]:
        """Return the argument names used when constructing a VideoGenerator."""
        return ["num_gpus", "tp_size", "sp_size", "model_path"]

    def _get_generation_arg_names(self) -> list[str]:
        """Return the field names of SamplingParam, i.e. the generate_video arguments."""
        return [spec.name for spec in dataclasses.fields(SamplingParam)]

    def cmd(self, args: argparse.Namespace) -> None:
        """Build a VideoGenerator from the parsed arguments and run generation.

        Only values whose names appear in ``args._provided`` (when that
        attribute exists) are forwarded; ``model_path`` and ``prompt`` are
        always forwarded when they are not None. Forwarded values are split
        into generator-construction kwargs and ``generate_video`` kwargs
        according to the SamplingParam field names.

        Raises:
            ValueError: if no model path is available, if neither ``prompt``
                nor ``prompt_txt`` is given, or if both are given.
        """
        namespace = vars(args)
        skipped_keys = ('subparser', 'config', 'dispatch_function')
        # An absent `_provided` attribute means nothing was marked as explicitly provided.
        explicitly_provided = getattr(args, '_provided', ())

        cli_values = {
            key: value
            for key, value in namespace.items()
            if key not in skipped_keys and value is not None and key in explicitly_provided
        }

        # model_path and prompt are forwarded whenever they hold a value,
        # regardless of whether they were recorded in `_provided`.
        for always_forwarded in ('model_path', 'prompt'):
            current = namespace.get(always_forwarded)
            if current is not None:
                cli_values[always_forwarded] = current

        logger.info('CLI Args: %s', cli_values)

        if not cli_values.get('model_path'):
            raise ValueError("model_path must be provided either in config file or via --model-path")

        # Exactly one of prompt / prompt_txt has to be supplied.
        prompt_given = bool(cli_values.get('prompt'))
        prompt_txt_given = bool(cli_values.get('prompt_txt'))

        if not prompt_given and not prompt_txt_given:
            raise ValueError("Either prompt or prompt_txt must be provided")

        if prompt_given and prompt_txt_given:
            raise ValueError("Cannot provide both 'prompt' and 'prompt_txt'. Use only one of them.")

        # Route each value to generate_video kwargs when it is a SamplingParam
        # field, and to the generator constructor otherwise.
        init_kwargs = {}
        sampling_kwargs = {}
        for key, value in cli_values.items():
            destination = sampling_kwargs if key in self.generation_arg_names else init_kwargs
            destination[key] = value
        sampling_kwargs.setdefault("return_frames", False)

        model_path = init_kwargs.pop('model_path')
        prompt = sampling_kwargs.pop('prompt', None)

        generator = VideoGenerator.from_pretrained(model_path=model_path, **init_kwargs)

        # generate_video is relied on to handle both single and batch modes.
        generator.generate_video(prompt=prompt, **sampling_kwargs)

    def validate(self, args: argparse.Namespace) -> None:
        """Check the parsed arguments for this command.

        Raises:
            ValueError: if ``num_gpus`` is set but not positive, or if a
                non-empty ``config`` path does not exist on disk.
        """
        gpu_count = args.num_gpus
        if gpu_count is not None and gpu_count <= 0:
            raise ValueError("Number of gpus must be positive")

        config_path = args.config
        if config_path and not os.path.exists(config_path):
            raise ValueError(f"Config file not found: {config_path}")

    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the `generate` parser and its options on ``subparsers``.

        Adds ``--config``, then the options contributed by FastVideoArgs and
        SamplingParam, then the ``--text-encoder-configs`` option, which uses
        RaiseNotImplementedAction. Returns the resulting parser.
        """
        parser = subparsers.add_parser(
            "generate",
            help="Run inference on a model",
            usage="fastvideo generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]",
        )

        parser.add_argument(
            "--config",
            type=str,
            default='',
            required=False,
            help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional.",
        )

        # Each provider adds its own options and hands the parser back.
        for args_provider in (FastVideoArgs, SamplingParam):
            parser = args_provider.add_cli_args(parser)

        parser.add_argument(
            "--text-encoder-configs",
            action=RaiseNotImplementedAction,
            help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)",
        )

        return cast(FlexibleArgumentParser, parser)


def cmd_init() -> list[CLISubcommand]:
    """Return the subcommands defined by this module: one GenerateSubcommand."""
    subcommands: list[CLISubcommand] = [GenerateSubcommand()]
    return subcommands