yitongl's picture
Add inference code and attention settings for sfp4 checkpoint-750
697fddf verified
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from dataclasses import fields, is_dataclass
from pathlib import Path
from typing import Any
from fastvideo.api.overrides import apply_overrides, parse_cli_overrides
from fastvideo.api.parser import config_to_dict, load_raw_config, parse_config
from fastvideo.api.schema import (
GenerationRequest,
GeneratorConfig,
InputConfig,
OutputConfig,
RequestRuntimeConfig,
SamplingConfig,
)
from fastvideo.configs.sample import SamplingParam
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.utils import shallow_asdict
# Name of the private attribute stored on GenerationRequest objects that holds
# the raw mapping of request fields treated as explicitly supplied.
_EXPLICIT_REQUEST_ATTR = "_fastvideo_explicit_request"

# Field names of each GenerationRequest section; _apply_request_field uses them
# to route a flat legacy key into the section that declares it.
_INPUT_FIELD_NAMES = {f.name for f in fields(InputConfig)}
_SAMPLING_FIELD_NAMES = {f.name for f in fields(SamplingConfig)}
_RUNTIME_FIELD_NAMES = {f.name for f in fields(RequestRuntimeConfig)}
_OUTPUT_FIELD_NAMES = {f.name for f in fields(OutputConfig)}

# Sentinel meaning "no value recorded", distinct from a stored None.
_MISSING = object()

# Legacy request keyword spellings mapped to their current names.
_LEGACY_REQUEST_ALIASES = dict(neg_prompt="negative_prompt")

# Request fields that request_to_pipeline_overrides forwards to the pipeline;
# request_to_sampling_param does not reject them when SamplingParam lacks them.
_REQUEST_PIPELINE_OVERRIDE_FIELDS = frozenset(("embedded_cfg_scale",))
def normalize_generator_config(config: GeneratorConfig | Mapping[str, Any], ) -> GeneratorConfig:
    """Return ``config`` as a :class:`GeneratorConfig`.

    An existing ``GeneratorConfig`` is returned as-is (not copied); any other
    mapping is parsed with ``parse_config``.
    """
    if not isinstance(config, GeneratorConfig):
        config = parse_config(GeneratorConfig, config)
    return config
def load_generator_config_from_file(
    path: str | Path,
    overrides: list[str] | Mapping[str, Any] | None = None,
) -> GeneratorConfig:
    """Load a :class:`GeneratorConfig` from the config file at ``path``.

    Two layouts are accepted:

    * a run/serve config, whose top-level ``generator`` key holds a mapping.
      Overrides are applied to the whole document, so generator settings are
      addressed as ``generator.<field>``, and then the ``generator`` section
      is parsed;
    * a flat generator config. Overrides are applied to the document itself,
      with any leading ``generator.`` prefix removed from each key so the same
      override spelling works for both layouts.

    Args:
        path: Location of the config file, read with ``load_raw_config``.
        overrides: CLI-style override strings (parsed by
            ``parse_cli_overrides``) or a mapping of dotted keys to values.
            ``None`` or an empty collection applies nothing.

    Returns:
        The parsed :class:`GeneratorConfig`.
    """
    raw = load_raw_config(path)
    normalized_overrides = _normalize_overrides(overrides)
    if _looks_like_run_or_serve_config(raw):
        if normalized_overrides:
            raw = apply_overrides(raw, normalized_overrides)
        return parse_config(GeneratorConfig, raw["generator"])
    if normalized_overrides:
        prefix = "generator."
        # Strip the prefix per key, not only when every key carries it. A mix
        # such as {"generator.x": ..., "y": ...} must still target field "x"
        # of the flat config rather than being applied verbatim as
        # "generator.x". If both "generator.x" and "x" are given, the one
        # that comes later in the overrides wins.
        adjusted = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in normalized_overrides.items()
        }
        raw = apply_overrides(raw, adjusted)
    return parse_config(GeneratorConfig, raw)
def legacy_from_pretrained_to_config(
    model_path: str,
    kwargs: Mapping[str, Any],
) -> GeneratorConfig:
    """Translate legacy ``from_pretrained``-style keyword arguments into a
    :class:`GeneratorConfig`.

    Each recognised keyword is routed to its place in the nested config
    (top level, ``engine``, ``engine.parallelism``, ``engine.offload``,
    ``engine.compile``, ``engine.quantization``, ``pipeline`` or
    ``pipeline.components``), possibly under a new name. A string
    ``pipeline_config`` becomes ``pipeline.components.pipeline_config_path``.
    Any other ``pipeline_config`` value, and every unrecognised keyword, is
    deep-copied into ``pipeline.experimental``. Sections left empty are
    omitted, and the assembled mapping is parsed with ``parse_config``.
    """
    raw: dict[str, Any] = {"model_path": model_path}
    buckets: dict[str, dict[str, Any]] = {
        "raw": raw,
        "engine": {},
        "parallelism": {},
        "offload": {},
        "compile": {},
        "quantization": {},
        "pipeline": {},
        "components": {},
        "experimental": {},
    }
    # legacy keyword -> (bucket, key inside that bucket, deep-copy the value?)
    routes: dict[str, tuple[str, str, bool]] = {
        "revision": ("raw", "revision", False),
        "trust_remote_code": ("raw", "trust_remote_code", False),
        "num_gpus": ("engine", "num_gpus", False),
        "distributed_executor_backend": ("engine", "execution_backend", False),
        "enable_stage_verification": ("engine", "enable_stage_verification", False),
        "use_fsdp_inference": ("engine", "use_fsdp_inference", False),
        "disable_autocast": ("engine", "disable_autocast", False),
        "tp_size": ("parallelism", "tp_size", False),
        "sp_size": ("parallelism", "sp_size", False),
        "hsdp_replicate_dim": ("parallelism", "hsdp_replicate_dim", False),
        "hsdp_shard_dim": ("parallelism", "hsdp_shard_dim", False),
        "dist_timeout": ("parallelism", "dist_timeout", False),
        "dit_cpu_offload": ("offload", "dit", False),
        "dit_layerwise_offload": ("offload", "dit_layerwise", False),
        "text_encoder_cpu_offload": ("offload", "text_encoder", False),
        "image_encoder_cpu_offload": ("offload", "image_encoder", False),
        "vae_cpu_offload": ("offload", "vae", False),
        "pin_cpu_memory": ("offload", "pin_cpu_memory", False),
        "enable_torch_compile": ("compile", "enabled", False),
        "torch_compile_kwargs": ("compile", "kwargs", True),
        "override_text_encoder_quant": ("quantization", "text_encoder_quant", False),
        "transformer_quant": ("quantization", "transformer_quant", False),
        "workload_type": ("pipeline", "workload_type", False),
        "lora_path": ("components", "lora_path", False),
        "override_pipeline_cls_name": ("components", "override_pipeline_cls_name", False),
        "override_transformer_cls_name": ("components", "override_transformer_cls_name", False),
        "override_text_encoder_safetensors": ("components", "text_encoder_weights", False),
        "init_weights_from_safetensors": ("components", "transformer_weights", False),
        "init_weights_from_safetensors_2": ("components", "transformer_2_weights", False),
    }
    for legacy_key, legacy_value in kwargs.items():
        if legacy_key == "pipeline_config" and isinstance(legacy_value, str):
            buckets["components"]["pipeline_config_path"] = legacy_value
            continue
        # Unknown keys (and a non-string pipeline_config) land in experimental.
        bucket, target, copy_value = routes.get(legacy_key, ("experimental", legacy_key, True))
        buckets[bucket][target] = deepcopy(legacy_value) if copy_value else legacy_value

    engine = buckets["engine"]
    for child in ("parallelism", "offload", "compile", "quantization"):
        if buckets[child]:
            engine[child] = buckets[child]
    if engine:
        raw["engine"] = engine

    pipeline = buckets["pipeline"]
    for child in ("components", "experimental"):
        if buckets[child]:
            pipeline[child] = buckets[child]
    if pipeline:
        raw["pipeline"] = pipeline
    return parse_config(GeneratorConfig, raw)
def generator_config_to_fastvideo_args(config: GeneratorConfig | Mapping[str, Any], ) -> FastVideoArgs:
    """Convert a generator config into legacy :class:`FastVideoArgs`.

    ``config`` is first normalized with :func:`normalize_generator_config`.
    The adapter cannot express some settings yet; if any of them is set, a
    ``NotImplementedError`` listing them is raised. Engine settings are mapped
    to their legacy keyword names. Optional pipeline, component and
    quantization settings are forwarded only when not ``None``. Finally,
    ``pipeline.profile_overrides`` and then ``pipeline.experimental`` are
    merged in (deep-copied), so their keys take precedence over the mapped
    ones.
    """
    normalized = normalize_generator_config(config)
    pipeline = normalized.pipeline
    components = pipeline.components

    not_yet_supported = [
        name for name, setting in (
            ("pipeline.profile", pipeline.profile),
            ("pipeline.profile_version", pipeline.profile_version),
            ("pipeline.components.config_root", components.config_root),
            ("pipeline.components.vae_weights", components.vae_weights),
            ("pipeline.components.upsampler_weights", components.upsampler_weights),
        ) if setting is not None
    ]
    if not_yet_supported:
        joined = ", ".join(not_yet_supported)
        raise NotImplementedError(f"VideoGenerator compatibility adapter does not support {joined} yet")

    engine = normalized.engine
    parallelism = engine.parallelism
    offload = engine.offload
    kwargs: dict[str, Any] = {
        "model_path": normalized.model_path,
        "revision": normalized.revision,
        "trust_remote_code": normalized.trust_remote_code,
        "num_gpus": engine.num_gpus,
        "distributed_executor_backend": engine.execution_backend,
        "tp_size": parallelism.tp_size,
        "sp_size": parallelism.sp_size,
        "hsdp_replicate_dim": parallelism.hsdp_replicate_dim,
        "hsdp_shard_dim": parallelism.hsdp_shard_dim,
        "dist_timeout": parallelism.dist_timeout,
        "dit_cpu_offload": offload.dit,
        "dit_layerwise_offload": offload.dit_layerwise,
        "text_encoder_cpu_offload": offload.text_encoder,
        "image_encoder_cpu_offload": offload.image_encoder,
        "vae_cpu_offload": offload.vae,
        "pin_cpu_memory": offload.pin_cpu_memory,
        "enable_torch_compile": engine.compile.enabled,
        "torch_compile_kwargs": deepcopy(engine.compile.kwargs),
        "enable_stage_verification": engine.enable_stage_verification,
        "use_fsdp_inference": engine.use_fsdp_inference,
        "disable_autocast": engine.disable_autocast,
    }

    # (legacy keyword, value) pairs forwarded only when the value is set.
    optional: list[tuple[str, Any]] = [("workload_type", pipeline.workload_type)]
    quantization = engine.quantization
    if quantization is not None:
        optional.append(("override_text_encoder_quant", quantization.text_encoder_quant))
        optional.append(("transformer_quant", quantization.transformer_quant))
    optional.extend([
        ("pipeline_config", components.pipeline_config_path),
        ("lora_path", components.lora_path),
        ("override_pipeline_cls_name", components.override_pipeline_cls_name),
        ("override_transformer_cls_name", components.override_transformer_cls_name),
        ("override_text_encoder_safetensors", components.text_encoder_weights),
        ("init_weights_from_safetensors", components.transformer_weights),
        ("init_weights_from_safetensors_2", components.transformer_2_weights),
    ])
    kwargs.update((name, value) for name, value in optional if value is not None)

    kwargs.update(deepcopy(pipeline.profile_overrides))
    kwargs.update(deepcopy(pipeline.experimental))
    return FastVideoArgs.from_kwargs(**kwargs)
def normalize_generation_request(request: GenerationRequest | Mapping[str, Any], ) -> GenerationRequest:
    """Return ``request`` as a :class:`GenerationRequest` with explicit-field metadata.

    A mapping is parsed with ``parse_config``; an existing instance is reused
    (not copied). If the instance does not yet carry the explicit-request
    attribute, the serialized form of the whole request is recorded there, so
    every field is treated as explicitly supplied.
    """
    if isinstance(request, GenerationRequest):
        normalized = request
    else:
        normalized = parse_config(GenerationRequest, request)
    if hasattr(normalized, _EXPLICIT_REQUEST_ATTR):
        return normalized
    setattr(normalized, _EXPLICIT_REQUEST_ATTR, _serialize_generation_request(normalized))
    return normalized
def legacy_generate_call_to_request(
    prompt: str | None,
    sampling_param: SamplingParam | None,
    *,
    mouse_cond: Any | None = None,
    keyboard_cond: Any | None = None,
    grid_sizes: Any | None = None,
    legacy_kwargs: Mapping[str, Any] | None = None,
) -> GenerationRequest:
    """Build a :class:`GenerationRequest` from legacy generate-call arguments.

    The raw request mapping is assembled in this order, with later writes
    overwriting earlier ones:

    1. fields of ``sampling_param`` (except ``prompt``);
    2. ``prompt``, when given;
    3. ``legacy_kwargs``, routed through ``_apply_request_field``;
    4. the ``mouse_cond`` / ``keyboard_cond`` / ``grid_sizes`` inputs that
       are not ``None``.

    The mapping is parsed into the request, and a deep copy of it is stored as
    the request's explicit-field metadata.
    """
    raw = _sampling_param_to_request_raw(sampling_param)
    if prompt is not None:
        raw["prompt"] = prompt
    for legacy_key, legacy_value in (legacy_kwargs or {}).items():
        _apply_request_field(raw, legacy_key, legacy_value)
    conditioning = (
        ("mouse_cond", mouse_cond),
        ("keyboard_cond", keyboard_cond),
        ("grid_sizes", grid_sizes),
    )
    for input_name, input_value in conditioning:
        if input_value is not None:
            raw.setdefault("inputs", {})[input_name] = input_value
    request = parse_config(GenerationRequest, raw)
    setattr(request, _EXPLICIT_REQUEST_ATTR, deepcopy(raw))
    return request
def request_to_sampling_param(
    request: GenerationRequest,
    *,
    model_path: str,
) -> SamplingParam:
    """Build the :class:`SamplingParam` for ``model_path`` with the request's fields applied.

    Starts from ``SamplingParam.from_pretrained(model_path)`` and copies each
    explicit request field onto the matching attribute. If ``SamplingParam``
    has no such attribute, the field is skipped when it is a pipeline-override
    field or still equals its ``GenerationRequest`` default. Otherwise a
    ``ValueError`` is raised. ``__post_init__`` and ``check_sampling_param``
    run after the fields are applied.

    Raises:
        NotImplementedError: if ``request.plan`` or ``request.state`` is set.
        ValueError: for an unsupported non-default request field.
    """
    if request.plan is not None:
        raise NotImplementedError("GenerationRequest.plan is not wired into VideoGenerator yet")
    if request.state is not None:
        raise NotImplementedError("GenerationRequest.state is not wired into VideoGenerator yet")
    params = SamplingParam.from_pretrained(model_path)
    for field_name, field_value in _explicit_request_updates(request).items():
        if not hasattr(params, field_name):
            tolerated = (field_name in _REQUEST_PIPELINE_OVERRIDE_FIELDS
                         or _is_supported_as_default_only(field_name, field_value))
            if not tolerated:
                raise ValueError(f"Request field {field_name!r} is not supported by sampling params for {model_path}")
            continue
        setattr(params, field_name, deepcopy(field_value))
    params.__post_init__()
    params.check_sampling_param()
    return params
def expand_request_prompt_batch(request: GenerationRequest, ) -> list[GenerationRequest]:
    """Split a request whose ``prompt`` is a list into one request per prompt.

    A non-list prompt yields ``[request]`` (the same object). Otherwise each
    entry gets a deep copy of ``request`` carrying that prompt, plus the
    matching element of list-valued ``inputs.image_path`` /
    ``inputs.video_path`` (lengths are validated) and matching explicit-request
    metadata.
    """
    prompts = request.prompt
    if not isinstance(prompts, list):
        return [request]
    expanded: list[GenerationRequest] = []
    for position, single_prompt in enumerate(prompts):
        clone = deepcopy(request)
        clone.prompt = single_prompt
        for input_name in ("image_path", "video_path"):
            _fan_out_batched_input_value(request, clone, input_name, position)
        _fan_out_explicit_request_metadata(request, clone, position, single_prompt)
        expanded.append(clone)
    return expanded
def _looks_like_run_or_serve_config(raw: Mapping[str, Any]) -> bool:
return isinstance(raw.get("generator"), Mapping)
def _normalize_overrides(overrides: list[str] | Mapping[str, Any] | None, ) -> dict[str, Any] | None:
if not overrides:
return None
if isinstance(overrides, list):
return parse_cli_overrides(overrides)
return dict(overrides)
def _sampling_param_to_request_raw(sampling_param: SamplingParam | None, ) -> dict[str, Any]:
if sampling_param is None:
return {}
raw: dict[str, Any] = {}
for key, value in shallow_asdict(sampling_param).items():
if key == "prompt":
continue
_apply_request_field(raw, key, deepcopy(value))
return raw
def _apply_request_field(
    raw: dict[str, Any],
    key: str,
    value: Any,
) -> None:
    """Store ``value`` in ``raw`` under the request section that declares ``key``.

    Legacy aliases are resolved first. ``negative_prompt`` lives at the top
    level. Otherwise the first section whose field names contain the key wins,
    checked in the order inputs, sampling, runtime, output. Unknown keys go
    to ``extensions``. ``raw`` is mutated in place.
    """
    canonical = _LEGACY_REQUEST_ALIASES.get(key, key)
    if canonical == "negative_prompt":
        raw[canonical] = value
        return
    routing = (
        (_INPUT_FIELD_NAMES, "inputs"),
        (_SAMPLING_FIELD_NAMES, "sampling"),
        (_RUNTIME_FIELD_NAMES, "runtime"),
        (_OUTPUT_FIELD_NAMES, "output"),
    )
    section = next((name for names, name in routing if canonical in names), "extensions")
    raw.setdefault(section, {})[canonical] = value
def request_to_pipeline_overrides(request: GenerationRequest) -> dict[str, Any]:
    """Return deep copies of the request's explicit fields that are pipeline overrides."""
    return {
        name: deepcopy(value)
        for name, value in _explicit_request_updates(request).items()
        if name in _REQUEST_PIPELINE_OVERRIDE_FIELDS
    }
def _explicit_request_updates(request: GenerationRequest) -> dict[str, Any]:
    """Flatten the request's explicit fields into a single ``name -> value`` dict.

    Uses the recorded explicit-request mapping when present. Otherwise it
    falls back to serializing the whole request.
    """
    recorded = getattr(request, _EXPLICIT_REQUEST_ATTR, None)
    source = _serialize_generation_request(request) if recorded is None else recorded
    return _extract_request_updates(source)
def _extract_request_updates(raw: Mapping[str, Any]) -> dict[str, Any]:
updates: dict[str, Any] = {}
if "negative_prompt" in raw:
updates["negative_prompt"] = deepcopy(raw["negative_prompt"])
for section_name in ("inputs", "sampling", "runtime", "output"):
section = raw.get(section_name)
if not isinstance(section, Mapping):
continue
for key, value in section.items():
updates[key] = deepcopy(value)
stage_overrides = raw.get("stage_overrides")
if stage_overrides:
updates.update(_flatten_stage_overrides(stage_overrides))
extensions = raw.get("extensions")
if isinstance(extensions, Mapping):
for key, value in extensions.items():
updates[key] = deepcopy(value)
return updates
def _flatten_stage_overrides(stage_overrides: Any) -> dict[str, Any]:
if not isinstance(stage_overrides, Mapping):
raise ValueError("GenerationRequest.stage_overrides must be a mapping")
flattened: dict[str, Any] = {}
for stage_name, overrides in stage_overrides.items():
if not isinstance(overrides, Mapping):
raise ValueError(f"GenerationRequest.stage_overrides.{stage_name} must be a mapping")
for key, value in overrides.items():
if key in flattened and flattened[key] != value:
raise ValueError(f"Conflicting stage override for {key!r} across stages")
flattened[key] = deepcopy(value)
return flattened
def _serialize_generation_request(request: GenerationRequest) -> dict[str, Any]:
    """Return a private deep copy of ``request`` as rendered by ``config_to_dict``."""
    serialized = config_to_dict(request)
    return deepcopy(serialized)
def _fan_out_batched_input_value(
source_request: GenerationRequest,
target_request: GenerationRequest,
field_name: str,
index: int,
) -> None:
value = getattr(source_request.inputs, field_name)
if not isinstance(value, list):
return
_validate_batched_input_length(source_request.prompt, value, field_name)
setattr(target_request.inputs, field_name, deepcopy(value[index]))
def _fan_out_explicit_request_metadata(
    source_request: GenerationRequest,
    target_request: GenerationRequest,
    index: int,
    prompt: str,
) -> None:
    """Give ``target_request`` a per-prompt copy of the source's explicit-request metadata.

    Nothing happens when the source has no recorded metadata. The copy gets
    ``prompt`` as its prompt, and list-valued ``inputs.image_path`` /
    ``inputs.video_path`` are reduced to their ``index``-th element after
    their lengths are validated.
    """
    recorded = getattr(source_request, _EXPLICIT_REQUEST_ATTR, None)
    if recorded is None:
        return
    per_prompt = deepcopy(recorded)
    per_prompt["prompt"] = prompt
    inputs = per_prompt.get("inputs")
    if isinstance(inputs, dict):
        for input_name in ("image_path", "video_path"):
            values = inputs.get(input_name)
            if not isinstance(values, list):
                continue
            _validate_batched_input_length(source_request.prompt, values, input_name)
            inputs[input_name] = deepcopy(values[index])
    setattr(target_request, _EXPLICIT_REQUEST_ATTR, per_prompt)
def _validate_batched_input_length(
prompts: str | list[str] | None,
values: list[Any],
field_name: str,
) -> None:
if not isinstance(prompts, list):
return
if len(values) != len(prompts):
raise ValueError(f"GenerationRequest.inputs.{field_name} must have the same length as request.prompt")
def _is_supported_as_default_only(key: str, value: Any) -> bool:
    """Return True if ``key`` has a recorded request default and ``value`` equals it."""
    default_value = _DEFAULT_REQUEST_UPDATES.get(key, _MISSING)
    if default_value is _MISSING:
        return False
    return _values_equal(value, default_value)
def _collect_non_default_fields(
value: Any,
default: Any,
) -> dict[str, Any]:
if not (is_dataclass(value) and is_dataclass(default)):
return {}
result: dict[str, Any] = {}
for field in fields(value):
current = getattr(value, field.name)
default_value = getattr(default, field.name)
if is_dataclass(current) and is_dataclass(default_value):
nested = _collect_non_default_fields(current, default_value)
if nested:
result[field.name] = nested
continue
if not _values_equal(current, default_value):
result[field.name] = deepcopy(current)
return result
def _values_equal(left: Any, right: Any) -> bool:
if left is right:
return True
try:
return bool(left == right)
except Exception:
return False
# Flattened updates of a default-constructed GenerationRequest. This is built at
# the bottom of the module because it calls _extract_request_updates, defined
# above. _is_supported_as_default_only reads it so that request_to_sampling_param
# can skip fields SamplingParam lacks, as long as they still hold these defaults.
_DEFAULT_REQUEST_UPDATES = _extract_request_updates(config_to_dict(GenerationRequest()))

# Public API of this module. expand_request_prompt_batch is a public
# (non-underscored) function too, so it is exported alongside its siblings.
__all__ = [
    "expand_request_prompt_batch",
    "generator_config_to_fastvideo_args",
    "legacy_from_pretrained_to_config",
    "legacy_generate_call_to_request",
    "load_generator_config_from_file",
    "normalize_generation_request",
    "normalize_generator_config",
    "request_to_pipeline_overrides",
    "request_to_sampling_param",
]