# yitongl's picture
# Add inference code and attention settings for sfp4 checkpoint-750
# 697fddf verified
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any
from fastvideo.logger import init_logger
from fastvideo.utils import StoreBoolean
# Module-level logger for this file; SamplingParam uses it to report unknown
# keys in update() and missing model-specific defaults in from_pretrained().
logger = init_logger(__name__)
@dataclass
class SamplingParam:
"""
Sampling parameters for video generation.
"""
# All fields below are copied from ForwardBatch
data_type: str = "video"
# Image inputs
image_path: str | None = None
pil_image: Any | None = None
# Video inputs
video_path: str | None = None
# Action control inputs (Matrix-Game)
mouse_cond: Any | None = None # Shape: (B, T, 2)
keyboard_cond: Any | None = None # Shape: (B, T, K)
grid_sizes: Any | None = None # Shape: (3,) [F,H,W]
# Camera control inputs (HYWorld)
pose: str | None = None # Camera trajectory: pose string (e.g., 'w-31') or JSON file path
# Camera control inputs (LingBotWorld)
c2ws_plucker_emb: Any | None = None # Plucker embedding: [B, C, F_lat, H_lat, W_lat]
# Refine inputs (LongCat 480p->720p upscaling)
# Path-based refine (load stage1 video from disk, e.g. MP4)
refine_from: str | None = None # Path to stage1 video (480p output from distill)
t_thresh: float = 0.5 # Threshold for timestep scheduling in refinement
spatial_refine_only: bool = False # If True, only spatial (no temporal doubling)
num_cond_frames: int = 0 # Number of conditioning frames
# In-memory refine input (for two-stage pipeline where stage1 frames are already in memory)
# This mirrors LongCat's demo where a list of frames (e.g. np.ndarray or PIL.Image)
# is passed directly to the refinement pipeline instead of reloading from disk.
stage1_video: Any | None = None
# Text inputs
prompt: str | list[str] | None = None
negative_prompt: str | None = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
prompt_path: str | None = None
output_path: str = "outputs/"
output_video_name: str | None = None
# Batch info
num_videos_per_prompt: int = 1
seed: int = 1024
# Original dimensions (before VAE scaling)
num_frames: int = 125
height: int = 720
width: int = 1280
height_sr: int = 1072
width_sr: int = 1920
fps: int = 24
# Denoising parameters
num_inference_steps: int = 50
num_inference_steps_sr: int = 50
guidance_scale: float = 1.0
guidance_scale_2: float | None = None
guidance_rescale: float = 0.0
boundary_ratio: float | None = None
sigmas: list[float] | None = None
# TeaCache parameters
enable_teacache: bool = False
# GEN3C camera control
trajectory_type: str | None = None
movement_distance: float | None = None
camera_rotation: str | None = None
# Misc
save_video: bool = True
return_frames: bool = True
return_trajectory_latents: bool = False # returns all latents for each timestep
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
def __post_init__(self) -> None:
self.data_type = "video" if self.num_frames > 1 else "image"
def __getattr__(self, name: str) -> Any:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def check_sampling_param(self) -> None:
if self.prompt_path and not self.prompt_path.endswith(".txt"):
raise ValueError("prompt_path must be a txt file")
def update(self, source_dict: dict[str, Any]) -> None:
for key, value in source_dict.items():
if hasattr(self, key):
setattr(self, key, value)
else:
logger.exception("%s has no attribute %s", type(self).__name__, key)
self.__post_init__()
@classmethod
def from_pretrained(cls, model_path: str) -> "SamplingParam":
from fastvideo.registry import get_sampling_param_cls_for_name
sampling_cls = get_sampling_param_cls_for_name(model_path)
if sampling_cls is not None:
sampling_param: SamplingParam = sampling_cls()
else:
logger.warning("Couldn't find an optimal sampling param for %s. Using the default sampling param.",
model_path)
sampling_param = cls()
return sampling_param
@staticmethod
def add_cli_args(parser: Any) -> Any:
"""Add CLI arguments for SamplingParam fields"""
parser.add_argument(
"--prompt",
type=str,
default=SamplingParam.prompt,
help="Text prompt for video generation",
)
parser.add_argument(
"--negative-prompt",
type=str,
default=SamplingParam.negative_prompt,
help="Negative text prompt for video generation",
)
parser.add_argument(
"--prompt-path",
type=str,
default=SamplingParam.prompt_path,
help="Path to a text file containing the prompt",
)
parser.add_argument(
"--output-path",
type=str,
default=SamplingParam.output_path,
help="Path to save the generated video",
)
parser.add_argument(
"--output-video-name",
type=str,
default=SamplingParam.output_video_name,
help="Name of the output video",
)
parser.add_argument(
"--num-videos-per-prompt",
type=int,
default=SamplingParam.num_videos_per_prompt,
help="Number of videos to generate per prompt",
)
parser.add_argument(
"--seed",
type=int,
default=SamplingParam.seed,
help="Random seed for generation",
)
parser.add_argument(
"--num-frames",
type=int,
default=SamplingParam.num_frames,
help="Number of frames to generate",
)
parser.add_argument(
"--height",
type=int,
default=SamplingParam.height,
help="Height of generated video",
)
parser.add_argument(
"--width",
type=int,
default=SamplingParam.width,
help="Width of generated video",
)
parser.add_argument(
"--fps",
type=int,
default=SamplingParam.fps,
help="Frames per second for saved video",
)
parser.add_argument(
"--num-inference-steps",
type=int,
default=SamplingParam.num_inference_steps,
help="Number of denoising steps",
)
parser.add_argument(
"--guidance-scale",
type=float,
default=SamplingParam.guidance_scale,
help="Classifier-free guidance scale",
)
parser.add_argument(
"--guidance-rescale",
type=float,
default=SamplingParam.guidance_rescale,
help="Guidance rescale factor",
)
parser.add_argument(
"--boundary-ratio",
type=float,
default=SamplingParam.boundary_ratio,
help="Boundary timestep ratio",
)
parser.add_argument(
"--save-video",
action="store_true",
default=SamplingParam.save_video,
help="Whether to save the video to disk",
)
parser.add_argument(
"--no-save-video",
action="store_false",
dest="save_video",
help="Don't save the video to disk",
)
parser.add_argument(
"--return-frames",
action="store_true",
default=False,
help="Whether to return the raw frames",
)
parser.add_argument(
"--image-path",
type=str,
default=SamplingParam.image_path,
help="Path to input image for image-to-video generation",
)
parser.add_argument(
"--video-path",
type=str,
default=SamplingParam.video_path,
help="Path to input video for video-to-video generation",
)
parser.add_argument(
"--refine-from",
type=str,
default=SamplingParam.refine_from,
help="Path to stage1 video for refinement (LongCat 480p->720p)",
)
parser.add_argument(
"--t-thresh",
type=float,
default=SamplingParam.t_thresh,
help="Threshold for timestep scheduling in refinement (default: 0.5)",
)
parser.add_argument(
"--spatial-refine-only",
action=StoreBoolean,
default=SamplingParam.spatial_refine_only,
help="Only perform spatial super-resolution (no temporal doubling)",
)
parser.add_argument(
"--num-cond-frames",
type=int,
default=SamplingParam.num_cond_frames,
help="Number of conditioning frames for refinement",
)
parser.add_argument(
"--moba-config-path",
type=str,
default=None,
help="Path to a JSON file containing V-MoBA specific configurations.",
)
parser.add_argument(
"--return-trajectory-latents",
action="store_true",
default=SamplingParam.return_trajectory_latents,
help="Whether to return the trajectory",
)
parser.add_argument(
"--return-trajectory-decoded",
action="store_true",
default=SamplingParam.return_trajectory_decoded,
help="Whether to return the decoded trajectory",
)
return parser
@dataclass
class CacheParams:
cache_type: str = "none"