# SPDX-License-Identifier: Apache-2.0
"""Adapters between the typed ``fastvideo.api`` schema and the legacy
``VideoGenerator`` interfaces (``FastVideoArgs`` keyword arguments and
``SamplingParam``)."""
from __future__ import annotations

from collections.abc import Mapping
from copy import deepcopy
from dataclasses import fields, is_dataclass
from pathlib import Path
from typing import Any

from fastvideo.api.overrides import apply_overrides, parse_cli_overrides
from fastvideo.api.parser import config_to_dict, load_raw_config, parse_config
from fastvideo.api.schema import (
    GenerationRequest,
    GeneratorConfig,
    InputConfig,
    OutputConfig,
    RequestRuntimeConfig,
    SamplingConfig,
)
from fastvideo.configs.sample import SamplingParam
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.utils import shallow_asdict

# Attribute stashed on GenerationRequest instances holding the raw request
# mapping that ``_explicit_request_updates`` turns into field updates.
_EXPLICIT_REQUEST_ATTR = "_fastvideo_explicit_request"

# Field names of each GenerationRequest section; used to route flat legacy
# keys (e.g. SamplingParam attributes) into the matching nested section.
_INPUT_FIELD_NAMES = {field.name for field in fields(InputConfig)}
_SAMPLING_FIELD_NAMES = {field.name for field in fields(SamplingConfig)}
_RUNTIME_FIELD_NAMES = {field.name for field in fields(RequestRuntimeConfig)}
_OUTPUT_FIELD_NAMES = {field.name for field in fields(OutputConfig)}

# Sentinel distinguishing "key absent" from a stored ``None``.
_MISSING = object()

# Legacy request keyword -> current GenerationRequest field name.
_LEGACY_REQUEST_ALIASES = {
    "neg_prompt": "negative_prompt",
}

# Request fields returned by ``request_to_pipeline_overrides`` and skipped
# (when SamplingParam has no such attribute) by ``request_to_sampling_param``.
_REQUEST_PIPELINE_OVERRIDE_FIELDS = frozenset({
    "embedded_cfg_scale",
})

# Optional prefix on override keys that address the generator section.
_GENERATOR_OVERRIDE_PREFIX = "generator."

# Legacy ``from_pretrained`` kwargs copied verbatim into ``engine.parallelism``.
_LEGACY_PARALLELISM_KWARGS = frozenset({
    "tp_size",
    "sp_size",
    "hsdp_replicate_dim",
    "hsdp_shard_dim",
    "dist_timeout",
})

# Legacy ``from_pretrained`` kwargs copied verbatim into ``engine``.
_LEGACY_ENGINE_KWARGS = frozenset({
    "enable_stage_verification",
    "use_fsdp_inference",
    "disable_autocast",
})


def normalize_generator_config(config: GeneratorConfig | Mapping[str, Any], ) -> GeneratorConfig:
    """Return ``config`` as a :class:`GeneratorConfig`.

    An existing ``GeneratorConfig`` is returned as-is (not copied); a mapping
    is parsed with ``parse_config``.
    """
    if isinstance(config, GeneratorConfig):
        return config
    return parse_config(GeneratorConfig, config)


def load_generator_config_from_file(
    path: str | Path,
    overrides: list[str] | tuple[str, ...] | Mapping[str, Any] | None = None,
) -> GeneratorConfig:
    """Load a :class:`GeneratorConfig` from a config file and apply overrides.

    Two file layouts are accepted:

    * a run/serve config whose top level contains a ``generator`` mapping.
      Overrides are applied to the whole document, so they must address
      ``generator.*`` paths explicitly, and then the ``generator`` section is
      parsed;
    * a flat generator config. Overrides may be written relative to the
      generator (``engine.num_gpus``) or with a redundant ``generator.``
      prefix, which is stripped from each key that has it.

    Args:
        path: Config file, read with ``load_raw_config``.
        overrides: CLI-style override strings (list or tuple) or a mapping of
            dotted keys to values. ``None`` or empty means no overrides.

    Returns:
        The parsed generator config.
    """
    raw = load_raw_config(path)
    normalized_overrides = _normalize_overrides(overrides)

    if _looks_like_run_or_serve_config(raw):
        if normalized_overrides:
            raw = apply_overrides(raw, normalized_overrides)
        return parse_config(GeneratorConfig, raw["generator"])

    if normalized_overrides:
        # The flat file *is* the generator section. Strip the prefix per key
        # so prefixed and unprefixed override keys can be mixed freely.
        adjusted = {_strip_generator_prefix(key): value for key, value in normalized_overrides.items()}
        raw = apply_overrides(raw, adjusted)
    return parse_config(GeneratorConfig, raw)


def legacy_from_pretrained_to_config(
    model_path: str,
    kwargs: Mapping[str, Any],
) -> GeneratorConfig:
    """Translate legacy ``from_pretrained`` keyword arguments into a GeneratorConfig.

    Known kwargs are routed to their typed location: top level, ``engine``,
    ``engine.parallelism``/``offload``/``compile``/``quantization``,
    ``pipeline`` or ``pipeline.components``. Unrecognized kwargs and a
    non-string ``pipeline_config`` are deep-copied into
    ``pipeline.experimental``, which ``generator_config_to_fastvideo_args``
    merges back into the legacy kwargs. Empty sections are omitted.

    Args:
        model_path: Stored as the config's ``model_path``.
        kwargs: Legacy keyword arguments.

    Returns:
        The parsed generator config.
    """
    raw: dict[str, Any] = {"model_path": model_path}
    engine: dict[str, Any] = {}
    parallelism: dict[str, Any] = {}
    offload: dict[str, Any] = {}
    compile_config: dict[str, Any] = {}
    pipeline: dict[str, Any] = {}
    components: dict[str, Any] = {}
    quantization: dict[str, Any] = {}
    experimental: dict[str, Any] = {}

    for key, value in kwargs.items():
        if key == "revision":
            raw["revision"] = value
        elif key == "trust_remote_code":
            raw["trust_remote_code"] = value
        elif key == "num_gpus":
            engine["num_gpus"] = value
        elif key == "distributed_executor_backend":
            engine["execution_backend"] = value
        elif key in _LEGACY_PARALLELISM_KWARGS:
            parallelism[key] = value
        elif key == "dit_cpu_offload":
            offload["dit"] = value
        elif key == "dit_layerwise_offload":
            offload["dit_layerwise"] = value
        elif key == "text_encoder_cpu_offload":
            offload["text_encoder"] = value
        elif key == "image_encoder_cpu_offload":
            offload["image_encoder"] = value
        elif key == "vae_cpu_offload":
            offload["vae"] = value
        elif key == "pin_cpu_memory":
            offload["pin_cpu_memory"] = value
        elif key == "enable_torch_compile":
            compile_config["enabled"] = value
        elif key == "torch_compile_kwargs":
            compile_config["kwargs"] = deepcopy(value)
        elif key in _LEGACY_ENGINE_KWARGS:
            engine[key] = value
        elif key == "override_text_encoder_quant":
            quantization["text_encoder_quant"] = value
        elif key == "transformer_quant":
            quantization["transformer_quant"] = value
        elif key == "workload_type":
            pipeline["workload_type"] = value
        elif key == "lora_path":
            components["lora_path"] = value
        elif key == "override_pipeline_cls_name":
            components["override_pipeline_cls_name"] = value
        elif key == "override_transformer_cls_name":
            components["override_transformer_cls_name"] = value
        elif key == "pipeline_config":
            # Only a string path has a typed slot; any other value is passed
            # through untouched via ``experimental``.
            if isinstance(value, str):
                components["pipeline_config_path"] = value
            else:
                experimental[key] = deepcopy(value)
        elif key == "override_text_encoder_safetensors":
            components["text_encoder_weights"] = value
        elif key == "init_weights_from_safetensors":
            components["transformer_weights"] = value
        elif key == "init_weights_from_safetensors_2":
            components["transformer_2_weights"] = value
        else:
            experimental[key] = deepcopy(value)

    if parallelism:
        engine["parallelism"] = parallelism
    if offload:
        engine["offload"] = offload
    if compile_config:
        engine["compile"] = compile_config
    if quantization:
        engine["quantization"] = quantization
    if engine:
        raw["engine"] = engine
    if components:
        pipeline["components"] = components
    if experimental:
        pipeline["experimental"] = experimental
    if pipeline:
        raw["pipeline"] = pipeline
    return parse_config(GeneratorConfig, raw)


def generator_config_to_fastvideo_args(config: GeneratorConfig | Mapping[str, Any], ) -> FastVideoArgs:
    """Build legacy :class:`FastVideoArgs` from a generator config.

    Mirrors ``legacy_from_pretrained_to_config``: typed fields are mapped back
    to their legacy kwarg names, and optional fields are passed only when they
    are not ``None``. ``pipeline.profile_overrides`` and then
    ``pipeline.experimental`` are merged last, so their keys win over the
    typed fields.

    Raises:
        NotImplementedError: If the config sets a field this adapter cannot
            express yet (profiles, ``config_root``, VAE/upsampler weights).
    """
    normalized = normalize_generator_config(config)
    pipeline = normalized.pipeline
    components = pipeline.components

    unsupported = [
        name for name, value in (
            ("pipeline.profile", pipeline.profile),
            ("pipeline.profile_version", pipeline.profile_version),
            ("pipeline.components.config_root", components.config_root),
            ("pipeline.components.vae_weights", components.vae_weights),
            ("pipeline.components.upsampler_weights", components.upsampler_weights),
        ) if value is not None
    ]
    if unsupported:
        joined = ", ".join(unsupported)
        raise NotImplementedError(f"VideoGenerator compatibility adapter does not support {joined} yet")

    engine = normalized.engine
    kwargs: dict[str, Any] = {
        "model_path": normalized.model_path,
        "revision": normalized.revision,
        "trust_remote_code": normalized.trust_remote_code,
        "num_gpus": engine.num_gpus,
        "distributed_executor_backend": engine.execution_backend,
        "tp_size": engine.parallelism.tp_size,
        "sp_size": engine.parallelism.sp_size,
        "hsdp_replicate_dim": engine.parallelism.hsdp_replicate_dim,
        "hsdp_shard_dim": engine.parallelism.hsdp_shard_dim,
        "dist_timeout": engine.parallelism.dist_timeout,
        "dit_cpu_offload": engine.offload.dit,
        "dit_layerwise_offload": engine.offload.dit_layerwise,
        "text_encoder_cpu_offload": engine.offload.text_encoder,
        "image_encoder_cpu_offload": engine.offload.image_encoder,
        "vae_cpu_offload": engine.offload.vae,
        "pin_cpu_memory": engine.offload.pin_cpu_memory,
        "enable_torch_compile": engine.compile.enabled,
        "torch_compile_kwargs": deepcopy(engine.compile.kwargs),
        "enable_stage_verification": engine.enable_stage_verification,
        "use_fsdp_inference": engine.use_fsdp_inference,
        "disable_autocast": engine.disable_autocast,
    }

    quantization = engine.quantization
    optional_kwargs = {
        "workload_type": pipeline.workload_type,
        "override_text_encoder_quant": None if quantization is None else quantization.text_encoder_quant,
        "transformer_quant": None if quantization is None else quantization.transformer_quant,
        "pipeline_config": components.pipeline_config_path,
        "lora_path": components.lora_path,
        "override_pipeline_cls_name": components.override_pipeline_cls_name,
        "override_transformer_cls_name": components.override_transformer_cls_name,
        "override_text_encoder_safetensors": components.text_encoder_weights,
        "init_weights_from_safetensors": components.transformer_weights,
        "init_weights_from_safetensors_2": components.transformer_2_weights,
    }
    # Optional settings are forwarded only when set (not None).
    kwargs.update({key: value for key, value in optional_kwargs.items() if value is not None})

    kwargs.update(deepcopy(pipeline.profile_overrides))
    kwargs.update(deepcopy(pipeline.experimental))
    return FastVideoArgs.from_kwargs(**kwargs)


def normalize_generation_request(request: GenerationRequest | Mapping[str, Any], ) -> GenerationRequest:
    """Return ``request`` as a :class:`GenerationRequest` with explicit metadata.

    A mapping is parsed; an existing request is reused (not copied). If the
    request has no stashed explicit-request mapping yet, its full
    serialization is attached under ``_EXPLICIT_REQUEST_ATTR``. Note that this
    mutates a passed-in request object.
    """
    normalized = (request if isinstance(request, GenerationRequest) else parse_config(GenerationRequest, request))
    if not hasattr(normalized, _EXPLICIT_REQUEST_ATTR):
        setattr(normalized, _EXPLICIT_REQUEST_ATTR, _serialize_generation_request(normalized))
    return normalized


def legacy_generate_call_to_request(
    prompt: str | None,
    sampling_param: SamplingParam | None,
    *,
    mouse_cond: Any | None = None,
    keyboard_cond: Any | None = None,
    grid_sizes: Any | None = None,
    legacy_kwargs: Mapping[str, Any] | None = None,
) -> GenerationRequest:
    """Convert a legacy ``generate`` call into a :class:`GenerationRequest`.

    Fields are layered in this order, with later layers winning:
    ``sampling_param`` attributes (except ``prompt``), then ``prompt``, then
    ``legacy_kwargs``, then the non-``None`` conditioning inputs. The raw
    mapping built this way is stashed on the result as its explicit-request
    metadata.
    """
    raw = _sampling_param_to_request_raw(sampling_param)
    if prompt is not None:
        raw["prompt"] = prompt
    for key, value in (legacy_kwargs or {}).items():
        _apply_request_field(raw, key, value)
    conditioning = (
        ("mouse_cond", mouse_cond),
        ("keyboard_cond", keyboard_cond),
        ("grid_sizes", grid_sizes),
    )
    for input_name, input_value in conditioning:
        if input_value is not None:
            raw.setdefault("inputs", {})[input_name] = input_value
    normalized = parse_config(GenerationRequest, raw)
    setattr(normalized, _EXPLICIT_REQUEST_ATTR, deepcopy(raw))
    return normalized


def request_to_sampling_param(
    request: GenerationRequest,
    *,
    model_path: str,
) -> SamplingParam:
    """Build a :class:`SamplingParam` for ``model_path`` from ``request``.

    Starts from ``SamplingParam.from_pretrained(model_path)`` and applies the
    request's recorded field updates (see ``_explicit_request_updates``).
    Updates with no matching SamplingParam attribute are skipped if they are
    pipeline-level fields or equal the GenerationRequest default.

    Raises:
        NotImplementedError: If ``request.plan`` or ``request.state`` is set.
        ValueError: If some other update has no SamplingParam attribute.
    """
    if request.plan is not None:
        raise NotImplementedError("GenerationRequest.plan is not wired into VideoGenerator yet")
    if request.state is not None:
        raise NotImplementedError("GenerationRequest.state is not wired into VideoGenerator yet")

    sampling_param = SamplingParam.from_pretrained(model_path)
    updates = _explicit_request_updates(request)
    for key, value in updates.items():
        if hasattr(sampling_param, key):
            setattr(sampling_param, key, deepcopy(value))
        elif key in _REQUEST_PIPELINE_OVERRIDE_FIELDS or _is_supported_as_default_only(key, value):
            continue
        else:
            raise ValueError(f"Request field {key!r} is not supported by sampling params for {model_path}")
    # Re-run post-init after the setattr calls (presumably to refresh derived
    # fields), then validate the result.
    sampling_param.__post_init__()
    sampling_param.check_sampling_param()
    return sampling_param


def expand_request_prompt_batch(request: GenerationRequest, ) -> list[GenerationRequest]:
    """Split a request with a list ``prompt`` into one request per prompt.

    A non-list prompt yields ``[request]`` unchanged. List-valued
    ``inputs.image_path`` / ``inputs.video_path`` (and their counterparts in
    the stashed explicit-request mapping) are fanned out by index and must
    have the same length as the prompt list.
    """
    if not isinstance(request.prompt, list):
        return [request]

    requests: list[GenerationRequest] = []
    for index, prompt in enumerate(request.prompt):
        single_request = deepcopy(request)
        single_request.prompt = prompt
        _fan_out_batched_input_value(request, single_request, "image_path", index)
        _fan_out_batched_input_value(request, single_request, "video_path", index)
        _fan_out_explicit_request_metadata(request, single_request, index, prompt)
        requests.append(single_request)
    return requests


def _looks_like_run_or_serve_config(raw: Mapping[str, Any]) -> bool:
    """Return True if ``raw`` has a nested ``generator`` mapping."""
    return isinstance(raw.get("generator"), Mapping)


def _strip_generator_prefix(key: str) -> str:
    """Drop a leading ``generator.`` from an override key, if present."""
    if key.startswith(_GENERATOR_OVERRIDE_PREFIX):
        return key[len(_GENERATOR_OVERRIDE_PREFIX):]
    return key


def _normalize_overrides(
    overrides: list[str] | tuple[str, ...] | Mapping[str, Any] | None, ) -> dict[str, Any] | None:
    """Turn CLI-style override strings or a mapping into a plain dict.

    Returns ``None`` when there are no overrides. A list or tuple is parsed
    with ``parse_cli_overrides``; any other value is treated as a mapping.
    """
    if not overrides:
        return None
    if isinstance(overrides, (list, tuple)):
        return parse_cli_overrides(list(overrides))
    return dict(overrides)


def _sampling_param_to_request_raw(sampling_param: SamplingParam | None, ) -> dict[str, Any]:
    """Convert SamplingParam attributes (except ``prompt``) into a request mapping.

    Values are deep-copied and routed with ``_apply_request_field``. Returns
    an empty dict for ``None``.
    """
    if sampling_param is None:
        return {}
    raw: dict[str, Any] = {}
    for key, value in shallow_asdict(sampling_param).items():
        if key == "prompt":
            continue
        _apply_request_field(raw, key, deepcopy(value))
    return raw


def _apply_request_field(
    raw: dict[str, Any],
    key: str,
    value: Any,
) -> None:
    """Store a flat legacy field in the matching section of a request mapping.

    Legacy aliases are resolved first. ``negative_prompt`` goes to the top
    level. Other keys go to the first section (inputs, sampling, runtime,
    output) that declares them. Anything else goes to ``extensions``.
    Mutates ``raw`` in place.
    """
    key = _LEGACY_REQUEST_ALIASES.get(key, key)
    if key == "negative_prompt":
        raw["negative_prompt"] = value
        return
    for section_name, section_fields in (
        ("inputs", _INPUT_FIELD_NAMES),
        ("sampling", _SAMPLING_FIELD_NAMES),
        ("runtime", _RUNTIME_FIELD_NAMES),
        ("output", _OUTPUT_FIELD_NAMES),
    ):
        if key in section_fields:
            raw.setdefault(section_name, {})[key] = value
            return
    raw.setdefault("extensions", {})[key] = value


def request_to_pipeline_overrides(request: GenerationRequest) -> dict[str, Any]:
    """Return the request's updates for pipeline-level fields, deep-copied."""
    overrides: dict[str, Any] = {}
    for key, value in _explicit_request_updates(request).items():
        if key in _REQUEST_PIPELINE_OVERRIDE_FIELDS:
            overrides[key] = deepcopy(value)
    return overrides


def _explicit_request_updates(request: GenerationRequest) -> dict[str, Any]:
    """Flatten the request's stashed mapping (or its serialization) into updates."""
    raw = getattr(request, _EXPLICIT_REQUEST_ATTR, None)
    if raw is None:
        raw = _serialize_generation_request(request)
    return _extract_request_updates(raw)


def _extract_request_updates(raw: Mapping[str, Any]) -> dict[str, Any]:
    """Flatten a serialized request mapping into ``field -> value`` updates.

    Sources are merged in this order, later ones winning on name clashes:
    ``negative_prompt``, then the inputs/sampling/runtime/output sections,
    then the flattened ``stage_overrides``, then ``extensions``. The top-level
    ``prompt`` is not included. All values are deep-copied.
    """
    updates: dict[str, Any] = {}
    if "negative_prompt" in raw:
        updates["negative_prompt"] = deepcopy(raw["negative_prompt"])
    for section_name in ("inputs", "sampling", "runtime", "output"):
        section = raw.get(section_name)
        if not isinstance(section, Mapping):
            continue
        for key, value in section.items():
            updates[key] = deepcopy(value)
    stage_overrides = raw.get("stage_overrides")
    if stage_overrides:
        updates.update(_flatten_stage_overrides(stage_overrides))
    extensions = raw.get("extensions")
    if isinstance(extensions, Mapping):
        for key, value in extensions.items():
            updates[key] = deepcopy(value)
    return updates


def _flatten_stage_overrides(stage_overrides: Any) -> dict[str, Any]:
    """Merge per-stage override mappings into a single flat mapping.

    Raises:
        ValueError: If ``stage_overrides`` or one of its entries is not a
            mapping, or if two stages set the same key to unequal values.
    """
    if not isinstance(stage_overrides, Mapping):
        raise ValueError("GenerationRequest.stage_overrides must be a mapping")
    flattened: dict[str, Any] = {}
    for stage_name, overrides in stage_overrides.items():
        if not isinstance(overrides, Mapping):
            raise ValueError(f"GenerationRequest.stage_overrides.{stage_name} must be a mapping")
        for key, value in overrides.items():
            # ``_values_equal`` tolerates values whose comparison result has
            # no single truth value instead of raising from ``!=``.
            if key in flattened and not _values_equal(flattened[key], value):
                raise ValueError(f"Conflicting stage override for {key!r} across stages")
            flattened[key] = deepcopy(value)
    return flattened


def _serialize_generation_request(request: GenerationRequest) -> dict[str, Any]:
    """Return a deep-copied dict serialization of ``request``."""
    return deepcopy(config_to_dict(request))
def _fan_out_batched_input_value(
    source_request: GenerationRequest,
    target_request: GenerationRequest,
    field_name: str,
    index: int,
) -> None:
    """Copy element ``index`` of a list-valued input onto ``target_request``.

    Does nothing unless ``source_request.inputs.<field_name>`` is a list. In
    that case its length is first checked against the prompt list.
    """
    batched = getattr(source_request.inputs, field_name)
    if isinstance(batched, list):
        _validate_batched_input_length(source_request.prompt, batched, field_name)
        setattr(target_request.inputs, field_name, deepcopy(batched[index]))


def _fan_out_explicit_request_metadata(
    source_request: GenerationRequest,
    target_request: GenerationRequest,
    index: int,
    prompt: str,
) -> None:
    """Give ``target_request`` a per-prompt copy of the source's explicit mapping.

    The copy has ``prompt`` replaced. List-valued ``inputs.image_path`` /
    ``inputs.video_path`` entries are narrowed to element ``index``. Nothing
    happens if the source carries no stashed mapping.
    """
    explicit = getattr(source_request, _EXPLICIT_REQUEST_ATTR, None)
    if explicit is None:
        return

    per_prompt = deepcopy(explicit)
    per_prompt["prompt"] = prompt

    explicit_inputs = per_prompt.get("inputs")
    batched_names = ("image_path", "video_path") if isinstance(explicit_inputs, dict) else ()
    for name in batched_names:
        candidate = explicit_inputs.get(name)
        if not isinstance(candidate, list):
            continue
        _validate_batched_input_length(source_request.prompt, candidate, name)
        explicit_inputs[name] = deepcopy(candidate[index])

    setattr(target_request, _EXPLICIT_REQUEST_ATTR, per_prompt)


def _validate_batched_input_length(
    prompts: str | list[str] | None,
    values: list[Any],
    field_name: str,
) -> None:
    """Require ``values`` to match the prompt count when ``prompts`` is a list.

    Raises:
        ValueError: If ``prompts`` is a list of a different length.
    """
    if isinstance(prompts, list) and len(values) != len(prompts):
        raise ValueError(f"GenerationRequest.inputs.{field_name} must have the same length as request.prompt")


def _is_supported_as_default_only(key: str, value: Any) -> bool:
    """Return True if ``key`` is a known request field whose ``value`` is its default."""
    try:
        default_value = _DEFAULT_REQUEST_UPDATES[key]
    except KeyError:
        return False
    return _values_equal(value, default_value)


def _collect_non_default_fields(
    value: Any,
    default: Any,
) -> dict[str, Any]:
    """Recursively collect dataclass fields of ``value`` that differ from ``default``.

    Nested dataclass pairs become nested dicts, included only when non-empty.
    Differing leaf values are deep-copied. Returns ``{}`` unless both
    arguments are dataclasses.
    """
    if not is_dataclass(value) or not is_dataclass(default):
        return {}

    differences: dict[str, Any] = {}
    for dc_field in fields(value):
        name = dc_field.name
        current = getattr(value, name)
        baseline = getattr(default, name)
        if is_dataclass(current) and is_dataclass(baseline):
            nested = _collect_non_default_fields(current, baseline)
            if nested:
                differences[name] = nested
        elif not _values_equal(current, baseline):
            differences[name] = deepcopy(current)
    return differences


def _values_equal(left: Any, right: Any) -> bool:
    """Best-effort equality: identity first, then ``==``.

    Any exception raised while comparing or converting the result to bool is
    treated as "not equal" (e.g. array-like values whose ``==`` result has no
    single truth value).
    """
    if left is not right:
        try:
            return bool(left == right)
        except Exception:
            return False
    return True


# Updates produced by a default-constructed GenerationRequest; consulted by
# ``_is_supported_as_default_only``.
_DEFAULT_REQUEST_UPDATES = _extract_request_updates(
    config_to_dict(GenerationRequest()))

__all__ = [
    "generator_config_to_fastvideo_args",
    "legacy_from_pretrained_to_config",
    "legacy_generate_call_to_request",
    "load_generator_config_from_file",
    "normalize_generation_request",
    "normalize_generator_config",
    "request_to_pipeline_overrides",
    "request_to_sampling_param",
]