Add standalone inference helper for sfp4 checkpoint-700
Browse files- standalone_inference/README.md +74 -0
- standalone_inference/__pycache__/install_overlay.cpython-313.pyc +0 -0
- standalone_inference/__pycache__/run_inference.cpython-313.pyc +0 -0
- standalone_inference/install_overlay.py +89 -0
- standalone_inference/manifest.sha256 +31 -0
- standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py +270 -0
- standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py +1155 -0
- standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py +250 -0
- standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py +80 -0
- standalone_inference/overlay_files/fastvideo/api/compat.py +503 -0
- standalone_inference/overlay_files/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py +192 -0
- standalone_inference/overlay_files/fastvideo/attention/backends/video_sparse_attn.py +262 -0
- standalone_inference/overlay_files/fastvideo/configs/models/dits/base.py +79 -0
- standalone_inference/overlay_files/fastvideo/configs/pipelines/wan.py +203 -0
- standalone_inference/overlay_files/fastvideo/configs/sample/base.py +292 -0
- standalone_inference/overlay_files/fastvideo/configs/sample/wan.py +154 -0
- standalone_inference/overlay_files/fastvideo/configs/wan_1.3B_t2v_pipeline.json +40 -0
- standalone_inference/overlay_files/fastvideo/entrypoints/cli/generate.py +115 -0
- standalone_inference/overlay_files/fastvideo/entrypoints/video_generator.py +797 -0
- standalone_inference/overlay_files/fastvideo/fastvideo_args.py +1188 -0
- standalone_inference/overlay_files/fastvideo/forward_context.py +100 -0
- standalone_inference/overlay_files/fastvideo/pipelines/basic/wan/__init__.py +0 -0
- standalone_inference/overlay_files/fastvideo/pipelines/basic/wan/wan_pipeline.py +60 -0
- standalone_inference/overlay_files/fastvideo/pipelines/composed_pipeline_base.py +474 -0
- standalone_inference/overlay_files/fastvideo/pipelines/stages/denoising.py +1184 -0
- standalone_inference/overlay_files/fastvideo/platforms/cuda.py +440 -0
- standalone_inference/overlay_files/fastvideo/platforms/interface.py +255 -0
- standalone_inference/overlay_files/fastvideo/train/models/wan/wan.py +680 -0
- standalone_inference/overlay_files/fastvideo/training/training_pipeline.py +1044 -0
- standalone_inference/overlay_files/fastvideo/training/wan_training_pipeline.py +74 -0
- standalone_inference/requirements.txt +5 -0
- standalone_inference/run.sh +22 -0
- standalone_inference/run_inference.py +123 -0
- standalone_inference/training_attention_settings.json +62 -0
standalone_inference/README.md
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Standalone Inference Helper
|
| 2 |
+
|
| 3 |
+
This folder contains a portable inference helper for:
|
| 4 |
+
|
| 5 |
+
`sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive/checkpoint-700`
|
| 6 |
+
|
| 7 |
+
It is not a full vendored copy of Wan or FastVideo. It contains the sparse FP4
|
| 8 |
+
backend overlay and a runner that can be applied to a FastVideo checkout or
|
| 9 |
+
installation so the uploaded checkpoint can be used for normal inference.
|
| 10 |
+
|
| 11 |
+
## Contents
|
| 12 |
+
|
| 13 |
+
- `run_inference.py`: downloads/loads `transformer/diffusion_pytorch_model.safetensors` from `yitongl/sparse_quant_exp` and runs `VideoGenerator`.
|
| 14 |
+
- `run.sh`: convenience wrapper that installs the overlay into `FASTVIDEO_ROOT` and then runs `run_inference.py`.
|
| 15 |
+
- `install_overlay.py`: copies the bundled sparse FP4 backend files into a FastVideo checkout/install.
|
| 16 |
+
- `overlay_files/`: exact runtime source files needed by `SPARSE_FP4_OURS_P_ATTN`.
|
| 17 |
+
- `training_attention_settings.json`: structured settings for the uploaded checkpoint.
|
| 18 |
+
|
| 19 |
+
## Expected Environment
|
| 20 |
+
|
| 21 |
+
- A working FastVideo Python environment.
|
| 22 |
+
- FastVideo dependencies installed, including PyTorch, Triton, safetensors, and
|
| 23 |
+
Hugging Face Hub.
|
| 24 |
+
- Access to the base model `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`.
|
| 25 |
+
- A CUDA GPU supported by the custom Triton kernels.
|
| 26 |
+
|
| 27 |
+
## Usage
|
| 28 |
+
|
| 29 |
+
From a machine with this HF repo downloaded:
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
export FASTVIDEO_ROOT=/path/to/FastVideo
|
| 33 |
+
bash standalone_inference/run.sh \
|
| 34 |
+
--output-path outputs/sfp4_checkpoint_700 \
|
| 35 |
+
--seed 1000
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
The script sets:
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN
|
| 42 |
+
FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
and downloads the uploaded checkpoint-700 transformer weights unless `--weights`
|
| 46 |
+
is provided.
|
| 47 |
+
|
| 48 |
+
To use a local safetensors file:
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
export FASTVIDEO_ROOT=/path/to/FastVideo
|
| 52 |
+
bash standalone_inference/run.sh \
|
| 53 |
+
--weights /path/to/diffusion_pytorch_model.safetensors \
|
| 54 |
+
--prompt "your prompt"
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Attention Semantics
|
| 58 |
+
|
| 59 |
+
- Self-attention uses `SPARSE_FP4_OURS_P_ATTN`.
|
| 60 |
+
- Q/K/V use FP4 fake quantization with STE.
|
| 61 |
+
- VSA tile size is `4 x 4 x 4 = 64` tokens.
|
| 62 |
+
- Selected sparse tiles use group-local P quantization in the Triton kernel.
|
| 63 |
+
- Dropped tiles use tile mean compensation.
|
| 64 |
+
- Cross-attention falls back to dense SDPA and is not sparse/FP4.
|
| 65 |
+
|
| 66 |
+
## Checkpoint
|
| 67 |
+
|
| 68 |
+
The current HF `main` transformer file is checkpoint-700:
|
| 69 |
+
|
| 70 |
+
`transformer/diffusion_pytorch_model.safetensors`
|
| 71 |
+
|
| 72 |
+
Local SHA256 used when preparing this helper:
|
| 73 |
+
|
| 74 |
+
`4595ca81ea7085c15ccf14b738aa9c0fdf2d2786641f49b55e0bc0e99bf042d2`
|
standalone_inference/__pycache__/install_overlay.cpython-313.pyc
ADDED
|
Binary file (4.48 kB). View file
|
|
|
standalone_inference/__pycache__/run_inference.cpython-313.pyc
ADDED
|
Binary file (6.22 kB). View file
|
|
|
standalone_inference/install_overlay.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Install the sparse FP4 checkpoint-700 inference overlay into FastVideo.
|
| 3 |
+
|
| 4 |
+
The checkpoint depends on local FastVideo attention backend changes that are
|
| 5 |
+
not part of a vanilla install. This helper copies the bundled overlay files
|
| 6 |
+
into a FastVideo source checkout or site-packages installation.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import importlib.util
|
| 13 |
+
import shutil
|
| 14 |
+
import sys
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _find_fastvideo_root() -> Path:
|
| 19 |
+
spec = importlib.util.find_spec("fastvideo")
|
| 20 |
+
if spec is None or spec.origin is None:
|
| 21 |
+
raise RuntimeError(
|
| 22 |
+
"Could not import fastvideo. Pass --fastvideo-root explicitly or "
|
| 23 |
+
"activate a FastVideo environment first.")
|
| 24 |
+
return Path(spec.origin).resolve().parents[1]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _iter_overlay_files(overlay_root: Path):
|
| 28 |
+
for path in sorted(overlay_root.rglob("*")):
|
| 29 |
+
if path.is_file() and "__pycache__" not in path.parts:
|
| 30 |
+
yield path
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main() -> int:
    """CLI entry point: copy the bundled overlay files into a FastVideo tree.

    Returns 0 on success.

    Raises:
        RuntimeError: if the bundled ``overlay_files`` directory is missing,
            or the chosen target does not contain a ``fastvideo/`` package.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--fastvideo-root",
        type=Path,
        default=None,
        help="FastVideo repository/install root. Defaults to import location.",
    )
    parser.add_argument(
        "--backup",
        action="store_true",
        help="Write .sfp4_backup copies before overwriting existing files.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print files that would be copied without modifying anything.",
    )
    args = parser.parse_args()

    # The overlay ships next to this script.
    overlay_root = Path(__file__).resolve().parent / "overlay_files"
    if not overlay_root.is_dir():
        raise RuntimeError(f"Missing overlay directory: {overlay_root}")

    if args.fastvideo_root is not None:
        target_root = args.fastvideo_root.resolve()
    else:
        target_root = _find_fastvideo_root()
    if not (target_root / "fastvideo").exists():
        raise RuntimeError(
            f"{target_root} does not look like a FastVideo root: missing fastvideo/")

    installed_count = 0
    for source_path in _iter_overlay_files(overlay_root):
        relative = source_path.relative_to(overlay_root)
        destination = target_root / relative
        print(f"{relative}")
        if args.dry_run:
            continue
        destination.parent.mkdir(parents=True, exist_ok=True)
        if args.backup and destination.exists():
            # Keep at most one backup per file; never clobber an earlier one.
            backup_path = destination.with_suffix(destination.suffix + ".sfp4_backup")
            if not backup_path.exists():
                shutil.copy2(destination, backup_path)
        shutil.copy2(source_path, destination)
        installed_count += 1

    if args.dry_run:
        print(f"Dry run complete for target root: {target_root}")
    else:
        print(f"Installed {installed_count} files into {target_root}")
        print(
            "Use PYTHONPATH='<FastVideo>/fastvideo-kernel/python:"
            "<FastVideo>/fastvideo-kernel:$PYTHONPATH' when running inference.")
    return 0
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())
|
standalone_inference/manifest.sha256
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fb13abe775d8acd0aa59ce47ebad40178e4f2604fd191b6b02c1e34dd1e95cc4 ./README.md
|
| 2 |
+
eb151afbefca213bbf1595e94b40547e1e431e850e6fc4cd187e506eb8e25b2d ./install_overlay.py
|
| 3 |
+
9d1d8dc58aab529270fe31eb1735d6a1382c0c6d36fccca122a8dbffa1b714fd ./overlay_files/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py
|
| 4 |
+
211c7f0445fbe9488250f01fa83457c6620e83bd6f3877db791fd155de93c08b ./overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py
|
| 5 |
+
3f3a407a88612ea17ad65e1b6b9cf6b7b02df56956d8301c4b13bffa92095016 ./overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py
|
| 6 |
+
56f17c602dede53c7c3677058f81274681530f1b83c086d9d1d44c6b51feefbb ./overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py
|
| 7 |
+
58f4ac013e6755336212a7a6c9948b19dab0dafc00f4a3298591598df270cb39 ./overlay_files/fastvideo/api/compat.py
|
| 8 |
+
2b821b0e2e7bdb3581be6312ebbece42380a6ee28a7a982f0cf2dc71fab849c8 ./overlay_files/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py
|
| 9 |
+
a97adcc52d7558c49f418c09395fd1665e988ad290d2276b95f21dfca0f8eb7d ./overlay_files/fastvideo/attention/backends/video_sparse_attn.py
|
| 10 |
+
79ef6f38ec0f5bfe16b2b98327ad2ccd15f3c863dd87fd03affc5dbdaa0a8224 ./overlay_files/fastvideo/configs/models/dits/base.py
|
| 11 |
+
4bda44746a3626551ea9a9380d890f036087092fb99fce2d302642cce14a97ed ./overlay_files/fastvideo/configs/pipelines/wan.py
|
| 12 |
+
5926e29a594db13b116922f131db50631bf8adbf90fe5cec00a5e2f446bfb4ca ./overlay_files/fastvideo/configs/sample/base.py
|
| 13 |
+
d99adcf607d982b38bbb5a70be60bf87f35d0e9f6f50752f3bceb68b34ce46c2 ./overlay_files/fastvideo/configs/sample/wan.py
|
| 14 |
+
49775ce42fd9643c78d8fad4ab8248c1755c7f1524ad771cbd1863d76c513c38 ./overlay_files/fastvideo/configs/wan_1.3B_t2v_pipeline.json
|
| 15 |
+
ae2d8309472b09927da3e450dea52d9715dcabe5d6722fc2917130ae8d85adb4 ./overlay_files/fastvideo/entrypoints/cli/generate.py
|
| 16 |
+
d0466769626e7fd497376c544904d56ba62847745eb52527896d96b99d76ba03 ./overlay_files/fastvideo/entrypoints/video_generator.py
|
| 17 |
+
73afe6b2ebe0f8cfe0a8ec762a7126161621ad97a64ebad628995f4a164b8b0e ./overlay_files/fastvideo/fastvideo_args.py
|
| 18 |
+
ddcab6f4fd33c9813840571b6bf83bbbcea164b564166951ed4301297db6cef0 ./overlay_files/fastvideo/forward_context.py
|
| 19 |
+
e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./overlay_files/fastvideo/pipelines/basic/wan/__init__.py
|
| 20 |
+
deac1e22530a6a41c501629f5e8fce47a7af4e008f321cc8a4d734c5120ef4fe ./overlay_files/fastvideo/pipelines/basic/wan/wan_pipeline.py
|
| 21 |
+
8908223b3ff99cdb3206148a68a730c2a13d554a2fb1316db6f2f9672efac9e8 ./overlay_files/fastvideo/pipelines/composed_pipeline_base.py
|
| 22 |
+
6cfd128e782b7787a27ddd28a5e2d50cb4b0e2e9425d51d9780f14c91e8206f0 ./overlay_files/fastvideo/pipelines/stages/denoising.py
|
| 23 |
+
489388dbdd9e5e3ad24db3012bd9b108794509a9729891d7dd315a102abba828 ./overlay_files/fastvideo/platforms/cuda.py
|
| 24 |
+
c046b1914041b59254bcdfe577aed20d6f007a72632ea1fe1ae92fa678eca760 ./overlay_files/fastvideo/platforms/interface.py
|
| 25 |
+
2456d39ca28019e12bb7ab007774e86348f0582a017bf0e6c91e2a01d654a1a0 ./overlay_files/fastvideo/train/models/wan/wan.py
|
| 26 |
+
bc46e84b732567de6c0325223405daecd1226c623e303be33c7be9b5b7fdec08 ./overlay_files/fastvideo/training/training_pipeline.py
|
| 27 |
+
1d3898fa37e21029df6c37e05dc34ed7805a211c2f87de6642db890e5a8c6f2e ./overlay_files/fastvideo/training/wan_training_pipeline.py
|
| 28 |
+
1b2addfcb414ab65e20034394ee21a8af9ada58220a680b67d3b4233a0952268 ./requirements.txt
|
| 29 |
+
5087bb4ffe5721c41a12d92d8dfe439cd86aa1a5d3b3d259e30ad62711d95081 ./run.sh
|
| 30 |
+
b826c8b059a000af6054ec099c36742d01e6a329ee77bc5936ae7562e9428409 ./run_inference.py
|
| 31 |
+
8ddeea65247d9fa31a4a8a2a5ce5abe068a911ff4d67871453555e1355af8ecf ./training_attention_settings.json
|
standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _use_high_prec_output_for_backward() -> bool:
|
| 9 |
+
value = os.environ.get("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")
|
| 10 |
+
return value.lower() not in ("0", "false", "no", "off")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _map_to_index(block_map: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 14 |
+
if block_map.dim() == 3:
|
| 15 |
+
block_map = block_map.unsqueeze(0)
|
| 16 |
+
if block_map.dim() != 4:
|
| 17 |
+
raise ValueError(
|
| 18 |
+
f"block_map must be [B,H,Q,KV] or [H,Q,KV], got {tuple(block_map.shape)}"
|
| 19 |
+
)
|
| 20 |
+
if block_map.dtype != torch.bool:
|
| 21 |
+
block_map = block_map.to(torch.bool)
|
| 22 |
+
if not block_map.is_cuda:
|
| 23 |
+
raise RuntimeError("block_map must be a CUDA tensor.")
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from fastvideo_kernel.triton_kernels.index import map_to_index as triton_map_to_index
|
| 27 |
+
except Exception as e:
|
| 28 |
+
raise ImportError("Triton map_to_index is required for ours-P Sparse FP4.") from e
|
| 29 |
+
return triton_map_to_index(block_map)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_ours_p_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_ours_p_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """CUDA-only custom op for the ours-P block-sparse attention forward.

    Returns ``(o, M, high_prec_o)``: the output, per-row softmax statistics,
    and a higher-precision output variant (used by the backward setup when
    FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O is enabled).
    """
    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
        triton_block_sparse_attn_forward,
    )

    q2k_idx, q2k_num = _map_to_index(block_map.to(torch.bool))
    return triton_block_sparse_attn_forward(
        q.contiguous(),
        k.contiguous(),
        v.contiguous(),
        q2k_idx,
        q2k_num,
        variable_block_sizes,
        is_qat=True,
    )
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_ours_p_triton")
def _block_sparse_attn_ours_p_triton_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape/dtype propagation for the forward op (meta / tracing path).

    Both outputs mirror q's shape and dtype; M is float32 with one value
    per (batch, head, query) row.
    """
    batch, heads, q_len = q.shape[0], q.shape[1], q.shape[2]
    row_stats = torch.empty((batch, heads, q_len), device=q.device, dtype=torch.float32)
    return torch.empty_like(q), row_stats, torch.empty_like(q)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_ours_p_backward_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_ours_p_backward_triton(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """CUDA-only custom op for the ours-P attention backward.

    Returns ``(dq, dk, dv)`` computed by the Triton backward kernel.
    """
    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
        triton_block_sparse_attn_backward,
    )

    bool_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(bool_map)
    # The kernel also consumes the transposed (KV -> Q) view of the map.
    k2q_idx, k2q_num = _map_to_index(bool_map.transpose(-1, -2).contiguous())

    return triton_block_sparse_attn_backward(
        grad_output.contiguous(),
        q,
        k,
        v,
        o,
        M,
        q2k_idx,
        q2k_num,
        k2q_idx,
        k2q_num,
        variable_block_sizes,
        is_qat=True,
    )
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@torch.library.register_fake(
    "fastvideo_kernel::block_sparse_attn_ours_p_backward_triton"
)
def _block_sparse_attn_ours_p_backward_triton_fake(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Meta implementation: gradients mirror q/k/v shapes and dtypes."""
    return (torch.empty_like(q), torch.empty_like(k), torch.empty_like(v))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _backward_triton(ctx, grad_o, grad_M, grad_high_prec_o):
    """Autograd bridge for the custom op: delegate to the backward op.

    grad_M and grad_high_prec_o are ignored — only the primary output's
    gradient participates in the attention backward.
    """
    q, k, v, o_for_bwd, M, block_map, variable_block_sizes = ctx.saved_tensors
    grads = block_sparse_attn_ours_p_backward_triton(
        grad_o, q, k, v, o_for_bwd, M, block_map, variable_block_sizes
    )
    dq, dk, dv = grads
    # No gradients flow to block_map or variable_block_sizes.
    return dq, dk, dv, None, None
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _setup_context_triton(ctx, inputs, output):
    """Save forward tensors for backward.

    Saves the high-precision output instead of the quantized one when
    FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O is enabled (the default).
    """
    q, k, v, block_map, variable_block_sizes = inputs
    o, M, high_prec_o = output
    if _use_high_prec_output_for_backward():
        saved_output = high_prec_o
    else:
        saved_output = o
    ctx.save_for_backward(q, k, v, saved_output, M, block_map, variable_block_sizes)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# Wire autograd into the custom op: the context hook picks which forward
# output to save, and the backward hook dispatches to the Triton kernel.
block_sparse_attn_ours_p_triton.register_autograd(
    _backward_triton,
    setup_context=_setup_context_triton,
)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class _BlockSparseAttnOursPTileComp(torch.autograd.Function):
    """Block-sparse attention with tile-mean compensation for dropped tiles.

    Unlike the plain custom op, this path also feeds per-tile means of
    q/k/v plus the complement ("dropped") block map into the Triton
    kernels so skipped tiles contribute a mean-based correction in both
    the forward and backward passes.
    """

    @staticmethod
    def forward(ctx, q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes):
        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
            triton_block_sparse_attn_forward,
        )

        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        block_map = block_map.to(torch.bool)
        # Complement map: tiles that the sparse pattern skips.
        dropped_block_map = torch.logical_not(block_map)

        q2k_idx, q2k_num = _map_to_index(block_map)
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)

        o, M, high_prec_o = triton_block_sparse_attn_forward(
            q,
            k,
            v,
            q2k_idx,
            q2k_num,
            variable_block_sizes,
            is_qat=True,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
        )
        saved_output = high_prec_o if _use_high_prec_output_for_backward() else o
        ctx.save_for_backward(
            q,
            k,
            v,
            q_mean,
            k_mean,
            v_mean,
            saved_output,
            M,
            block_map,
            dropped_block_map,
            variable_block_sizes,
        )
        return o, M

    @staticmethod
    def backward(ctx, grad_o, grad_M):
        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
            triton_block_sparse_attn_backward,
        )

        (
            q,
            k,
            v,
            q_mean,
            k_mean,
            v_mean,
            o_for_bwd,
            M,
            block_map,
            dropped_block_map,
            variable_block_sizes,
        ) = ctx.saved_tensors

        # Both kept and dropped maps are needed in Q->KV and KV->Q layouts.
        kept_transposed = block_map.transpose(-1, -2).contiguous()
        dropped_transposed = dropped_block_map.transpose(-1, -2).contiguous()
        q2k_idx, q2k_num = _map_to_index(block_map)
        k2q_idx, k2q_num = _map_to_index(kept_transposed)
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)
        dropped_k2q_idx, dropped_k2q_num = _map_to_index(dropped_transposed)

        dq, dk, dv = triton_block_sparse_attn_backward(
            grad_o.contiguous(),
            q,
            k,
            v,
            o_for_bwd,
            M,
            q2k_idx,
            q2k_num,
            k2q_idx,
            k2q_num,
            variable_block_sizes,
            is_qat=True,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
            dropped_k2q_index=dropped_k2q_idx,
            dropped_k2q_num=dropped_k2q_num,
        )
        # One slot per forward input: only q/k/v are differentiable.
        return dq, dk, dv, None, None, None, None, None
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def block_sparse_attn_ours_p(
|
| 251 |
+
q: torch.Tensor,
|
| 252 |
+
k: torch.Tensor,
|
| 253 |
+
v: torch.Tensor,
|
| 254 |
+
block_map: torch.Tensor,
|
| 255 |
+
variable_block_sizes: torch.Tensor,
|
| 256 |
+
q_mean: torch.Tensor | None = None,
|
| 257 |
+
k_mean: torch.Tensor | None = None,
|
| 258 |
+
v_mean: torch.Tensor | None = None,
|
| 259 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 260 |
+
if (q_mean is not None) or (k_mean is not None) or (v_mean is not None):
|
| 261 |
+
if q_mean is None or k_mean is None or v_mean is None:
|
| 262 |
+
raise ValueError("q_mean, k_mean, and v_mean must be provided together")
|
| 263 |
+
return _BlockSparseAttnOursPTileComp.apply(
|
| 264 |
+
q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
o, M, _ = block_sparse_attn_ours_p_triton(
|
| 268 |
+
q, k, v, block_map, variable_block_sizes
|
| 269 |
+
)
|
| 270 |
+
return o, M
|
standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py
ADDED
|
@@ -0,0 +1,1155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Fused Attention
|
| 3 |
+
===============
|
| 4 |
+
|
| 5 |
+
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
| 6 |
+
(https://tridao.me/publications/flash2/flash2.pdf)
|
| 7 |
+
|
| 8 |
+
Credits: OpenAI kernel team
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import triton
|
| 13 |
+
import triton.language as tl
|
| 14 |
+
from .quant_utils import fake_quantize
|
| 15 |
+
|
| 16 |
+
# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
|
| 17 |
+
import math # small utility needed by the sparse wrapper
|
| 18 |
+
# ──────────────────────────── SPARSE ADDITION END ─────────────────────────────
|
| 19 |
+
|
| 20 |
+
# We don't run auto-tuning every time to keep the tutorial fast. Keeping
|
| 21 |
+
# the code below and commenting out the equivalent parameters is convenient for
|
| 22 |
+
# re-tuning.
|
| 23 |
+
# Autotune search space for the sparse forward kernel. Auto-tuning is not run
# on every invocation (to keep startup fast); widen these candidate lists and
# re-run when re-tuning for a new GPU.
configs = []
for _bm in [64]:
    for _bn in [64]:
        for _stages in [3, 4, 7]:
            for _warps in [4, 8]:
                configs.append(
                    triton.Config({'BLOCK_M': _bm, 'BLOCK_N': _bn},
                                  num_stages=_stages, num_warps=_warps))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
|
| 33 |
+
@triton.autotune(configs, key=["N_CTX_Q", "HEAD_DIM"])
@triton.jit
def _attn_fwd_sparse(
        Q,
        K,
        V,
        QMean,   # per-Q-tile mean vectors, used only when USE_TILE_COMP
        KMean,   # per-KV-tile mean vectors, used only when USE_TILE_COMP
        VMean,   # per-KV-tile mean vectors, used only when USE_TILE_COMP
        sm_scale,
        q2k_index,            # per (b, h, q_tile): list of selected KV tile ids
        q2k_num,              # per (b, h, q_tile): number of valid entries in q2k_index
        max_kv_blks,          # row stride of q2k_index
        dropped_q2k_index,    # per (b, h, q_tile): list of dropped KV tile ids
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,  # valid token count per 64-token KV tile
        M,                     # output: per-row log2-sum-exp statistics
        Out,                   # output: quantized-path attention output
        HighPrecOut,           # output: high-precision reference output
        stride_qz,
        stride_qh,
        stride_qm,
        stride_qk,
        stride_kz,
        stride_kh,
        stride_kn,
        stride_kk,
        stride_vz,
        stride_vh,
        stride_vk,
        stride_vn,
        stride_oz,
        stride_oh,
        stride_om,
        stride_on,
        Z,
        H,
        N_CTX_Q,
        N_CTX_KV,
        HEAD_DIM: tl.constexpr,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        STAGE: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    64x64 block-sparse forward kernel for the independent "ours P quant" path.

    P quantization is group-local: each selected KV tile quantizes
    exp2(logit - tile_row_max), then applies exp2(tile_row_max - online_max)
    after the FP4 PV GEMM. This intentionally differs from the QAT-style
    backend, which quantizes exp2(logit - online_max) directly.
    """

    # ----- program-id mapping -----
    q_blk = tl.program_id(0)   # Q-tile index
    off_hz = tl.program_id(1)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    q_tiles = N_CTX_Q // BLOCK_M
    # Flattened (batch, head, q_tile) index into the sparsity metadata.
    meta_base = ((b * H + h) * q_tiles + q_blk)

    kv_blocks = tl.load(q2k_num + meta_base)  # int32: how many KV tiles to visit
    kv_ptr = q2k_index + meta_base * max_kv_blks  # ptr to this tile's KV id list
    dropped_kv_blocks = tl.load(dropped_q2k_num + meta_base)
    dropped_kv_ptr = dropped_q2k_index + meta_base * max_dropped_kv_blks

    # ----- base pointers (per batch/head) -----
    q_off = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    k_off = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    v_off = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    o_off = (b.to(tl.int64) * stride_oz + h.to(tl.int64) * stride_oh)

    Q_ptr = tl.make_block_ptr(base=Q + q_off,
                              shape=(N_CTX_Q, HEAD_DIM),
                              strides=(stride_qm, stride_qk),
                              offsets=(q_blk * BLOCK_M, 0),
                              block_shape=(BLOCK_M, HEAD_DIM),
                              order=(1, 0))

    # K is addressed transposed (HEAD_DIM x N_CTX_KV) so Q @ K works directly.
    K_base = tl.make_block_ptr(base=K + k_off,
                               shape=(HEAD_DIM, N_CTX_KV),
                               strides=(stride_kk, stride_kn),
                               offsets=(0, 0),
                               block_shape=(HEAD_DIM, BLOCK_N),
                               order=(0, 1))

    # fp8e5 V is stored with a different memory order than other dtypes.
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1,
                                                                              0)
    V_base = tl.make_block_ptr(base=V + v_off,
                               shape=(N_CTX_KV, HEAD_DIM),
                               strides=(stride_vk, stride_vn),
                               offsets=(0, 0),
                               block_shape=(BLOCK_N, HEAD_DIM),
                               order=v_order)

    O_ptr = tl.make_block_ptr(base=Out + o_off,
                              shape=(N_CTX_Q, HEAD_DIM),
                              strides=(stride_om, stride_on),
                              offsets=(q_blk * BLOCK_M, 0),
                              block_shape=(BLOCK_M, HEAD_DIM),
                              order=(1, 0))
    HPO_ptr = tl.make_block_ptr(base=HighPrecOut + o_off,
                                shape=(N_CTX_Q, HEAD_DIM),
                                strides=(stride_om, stride_on),
                                offsets=(q_blk * BLOCK_M, 0),
                                block_shape=(BLOCK_M, HEAD_DIM),
                                order=(1, 0))

    # ----- online-softmax accumulators -----
    offs_m = q_blk * BLOCK_M + tl.arange(0, BLOCK_M)
    m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)  # running row max (log2 domain)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0    # running row sum
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    high_prec_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    qk_scale = sm_scale * 1.44269504  # 1/ln2: softmax computed with exp2
    q = tl.load(Q_ptr)
    offs_d = tl.arange(0, HEAD_DIM)

    # ----- sparse loop over selected K/V tiles -----
    for i in range(0, kv_blocks):
        kv_idx = tl.load(kv_ptr + i).to(tl.int32)
        block_size = tl.load(variable_block_sizes + kv_idx)
        K_ptr = tl.advance(K_base, (0, kv_idx * BLOCK_N))
        V_ptr = tl.advance(V_base, (kv_idx * BLOCK_N, 0))

        k = tl.load(K_ptr)
        mask = tl.arange(0, BLOCK_N) < block_size
        qk = tl.dot(q, k) * qk_scale
        # mask out invalid (padding) columns of a partially-filled tile
        qk = tl.where(mask[None, :], qk, -float("inf"))
        group_m = tl.max(qk, 1)              # per-row max within this tile
        m_ij = tl.maximum(m_i, group_m)      # updated online max

        # Group-local P: normalize against the TILE max, not the online max.
        p_local = tl.math.exp2(qk - group_m[:, None])
        p_local = tl.where(mask[None, :], p_local, 0.0)
        p_comp = tl.math.exp2(group_m - m_ij)  # rescale factor applied post-GEMM
        p_valid = mask[None, :] & (
            tl.full(shape=p_local.shape, value=1.0,
                    dtype=p_local.dtype) == 1.0
        )
        p_quant, high_prec_p = fake_quantize(
            src_tensor=p_local, valid_src_mask=p_valid,
            BLOCK_SIZE_OUT_DIM=BLOCK_M, BLOCK_SIZE_QUANT_DIM=BLOCK_N,
            dst_dtype=tl.bfloat16, use_global_sf=False,
        )
        # Row-sum uses the high-precision P so l_i matches the reference path.
        l_ij = tl.sum(high_prec_p, 1) * p_comp

        alpha = tl.math.exp2(m_i - m_ij)  # downscale previous contributions
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]
        high_prec_acc = high_prec_acc * alpha[:, None]

        v = tl.load(V_ptr)
        acc = acc + tl.dot(
            p_quant.to(tl.bfloat16),
            v.to(tl.bfloat16),
        ) * p_comp[:, None]
        high_prec_acc = high_prec_acc + tl.dot(
            high_prec_p.to(tl.bfloat16),
            v.to(tl.bfloat16),
        ) * p_comp[:, None]
        m_i = m_ij

    if USE_TILE_COMP:
        # Compensation pass: approximate each *dropped* KV tile's contribution
        # with a rank-1 term built from the tile means (q_mean·k_mean score,
        # v_mean value), weighted by the tile's valid token count.
        q_mean_base = (off_hz * q_tiles + q_blk).to(tl.int64) * HEAD_DIM
        q_mean = tl.load(QMean + q_mean_base + offs_d).to(tl.float32)
        kv_tiles = N_CTX_KV // BLOCK_N

        for i in range(0, dropped_kv_blocks):
            kv_idx = tl.load(dropped_kv_ptr + i).to(tl.int32)
            block_size = tl.load(variable_block_sizes + kv_idx).to(tl.float32)
            kv_mean_base = (off_hz * kv_tiles + kv_idx).to(tl.int64) * HEAD_DIM
            k_mean = tl.load(KMean + kv_mean_base + offs_d).to(tl.float32)
            v_mean = tl.load(VMean + kv_mean_base + offs_d).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            m_ij = tl.maximum(m_i, score)
            alpha = tl.math.exp2(m_i - m_ij)
            beta = tl.math.exp2(score - m_ij)

            # One score stands in for the whole tile, hence the block_size weight.
            l_i = l_i * alpha + block_size * beta
            comp = (block_size * beta)[:, None] * v_mean[None, :]
            acc = acc * alpha[:, None] + comp
            high_prec_acc = high_prec_acc * alpha[:, None] + comp
            m_i = m_ij

    # ----- epilogue: finalize softmax and write outputs -----
    m_i += tl.math.log2(l_i)  # store log2-sum-exp for the backward pass
    acc = acc / l_i[:, None]
    high_prec_acc = high_prec_acc / l_i[:, None]
    tl.store(M + off_hz * N_CTX_Q + offs_m, m_i)
    tl.store(O_ptr, acc.to(Out.type.element_ty))
    tl.store(HPO_ptr, high_prec_acc.to(HighPrecOut.type.element_ty))
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ──────────────────────────── SPARSE ADDITION END ─────────────────────────────
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@triton.jit
def _attn_bwd_preprocess(
        O,
        DO,
        Delta,
        Z,
        H,
        N_CTX,
        BLOCK_M: tl.constexpr,
        HEAD_DIM: tl.constexpr):
    """Precompute Delta[b,h,m] = sum_d O[b,h,m,d] * dO[b,h,m,d].

    This row-wise dot product is consumed by the dK/dV and dQ backward
    kernels. Assumes O/DO are laid out contiguously as (B*H, N_CTX, HEAD_DIM).
    """
    row_offs = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    bh_id = tl.program_id(1)  # fused (batch, head)
    dim_offs = tl.arange(0, HEAD_DIM)

    # Load a (BLOCK_M, HEAD_DIM) tile of O and dO for this batch/head.
    tile_base = bh_id * HEAD_DIM * N_CTX
    out_tile = tl.load(O + tile_base + row_offs[:, None] * HEAD_DIM +
                       dim_offs[None, :])
    dout_tile = tl.load(DO + tile_base + row_offs[:, None] * HEAD_DIM +
                        dim_offs[None, :]).to(tl.float32)

    # Reduce along the head dimension and write one value per row.
    row_dot = tl.sum(out_tile * dout_tile, axis=1)
    tl.store(Delta + bh_id * N_CTX + row_offs, row_dot)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# The main inner-loop logic for computing dK and dV.
|
| 258 |
+
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
        dk,
        dv,
        Q,
        k,   # K tile for this program, already resident in SRAM
        v,   # V tile for this program, already resident in SRAM
        QMean,
        KMean,
        VMean,
        sm_scale,
        DO,
        M,   # per-row log2-sum-exp from the forward pass
        D,   # per-row Delta = rowsum(O * dO)
        k2q_index,            # per (b, h, kv_tile): list of selected Q tile ids
        k2q_num,
        max_q_blks,
        dropped_k2q_index,    # per (b, h, kv_tile): list of dropped Q tile ids
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        # shared by Q/K/V/DO.
        stride_tok,
        stride_d,
        H,
        N_CTX_KV,
        BLOCK_M1: tl.constexpr,
        BLOCK_N1: tl.constexpr,
        HEAD_DIM: tl.constexpr,
        # Filled in by the wrapper.
        start_n,
        start_m,
        num_steps,  # NOTE(review): not referenced in this body; loop counts
                    # come from k2q_num/dropped_k2q_num instead
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """Accumulate dK and dV for one KV tile over its selected Q tiles.

    Each 64-token Q tile listed in k2q_index is processed as two BLOCK_M1
    (32-token) halves, hence the `* 2` loop bounds and the `blk_idx // 2` /
    `blk_idx % 2` index arithmetic. When USE_TILE_COMP is set, dropped Q
    tiles contribute an approximate mean-based correction.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    # Q is loaded transposed (HEAD_DIM x BLOCK_M1) for the K @ Q^T GEMM.
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    step_m = BLOCK_M1
    kv_blk = tl.program_id(0)   # KV-tile index
    off_hz = tl.program_id(2)   # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    kv_tiles = N_CTX_KV // BLOCK_N1
    meta_base = ((b * H + h) * kv_tiles + kv_blk)

    q_blocks = tl.load(k2q_num + meta_base)  # int32
    q_ptr = k2q_index + meta_base * max_q_blks  # ptr to list
    dropped_q_blocks = tl.load(dropped_k2q_num + meta_base)
    dropped_q_ptr = dropped_k2q_index + meta_base * max_dropped_q_blks
    block_size = tl.load(variable_block_sizes + kv_blk)  # valid rows in this KV tile
    block_size_f = block_size.to(tl.float32)

    # Two BLOCK_M1 halves per listed 64-token Q tile.
    for blk_idx in range(q_blocks * 2):
        block_sparse_offset = (tl.load(q_ptr + blk_idx // 2).to(tl.int32) * 2 +
                               blk_idx % 2) * step_m
        qT = tl.load(qT_ptrs + block_sparse_offset * stride_tok)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = start_m + block_sparse_offset + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k.to(tl.bfloat16), qT)
        qkT = qkT * sm_scale * 1.44269504
        # Rows beyond this KV tile's valid token count are masked out.
        mask = tl.arange(0, BLOCK_N1) < block_size
        qkT = tl.where(mask[:, None], qkT, -float("inf"))
        group_m = tl.max(qkT, 0)  # per-column (per-query) tile max
        # pT: exact softmax probabilities (exp2 of logit minus final lse).
        pT = tl.math.exp2(qkT - m[None, :])
        pT = tl.where(mask[:, None], pT, 0.0)

        do = tl.load(do_ptrs + block_sparse_offset * stride_tok)
        # Compute dV with group-local P quantization:
        # quantize exp2(logit - tile_col_max), then multiply dO by
        # exp2(tile_col_max - final_lse) to recover the final softmax scale.
        p_local_T = tl.math.exp2(qkT - group_m[None, :])
        p_local_T = tl.where(mask[:, None], p_local_T, 0.0)
        p_comp = tl.math.exp2(group_m - m)
        p_for_quant = tl.trans(p_local_T)
        p_valid = mask[None, :] & (
            tl.full(
                shape=p_for_quant.shape,
                value=1.0,
                dtype=p_for_quant.dtype,
            ) == 1.0
        )
        p_quant, _ = fake_quantize(
            src_tensor=p_for_quant, valid_src_mask=p_valid,
            BLOCK_SIZE_OUT_DIM=BLOCK_M1, BLOCK_SIZE_QUANT_DIM=BLOCK_N1,
            dst_dtype=p_for_quant.dtype, use_global_sf=False,
        )
        # dV += P^T @ (dO * p_comp): the scale folds into dO, not into P.
        dv += tl.dot(
            tl.trans(p_quant.to(tl.bfloat16)),
            (do * p_comp[:, None]).to(tl.bfloat16),
        )
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS; dK uses the exact (unquantized) pT.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.bfloat16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.

    if USE_TILE_COMP:
        # Approximate contribution of dropped Q tiles via tile-mean vectors.
        k_mean = tl.load(KMean + kv_blk * HEAD_DIM + offs_k).to(tl.float32)
        v_mean = tl.load(VMean + kv_blk * HEAD_DIM + offs_k).to(tl.float32)
        qk_scale = sm_scale * 1.44269504

        for blk_idx in range(dropped_q_blocks * 2):
            q_blk_idx = tl.load(dropped_q_ptr + blk_idx // 2).to(tl.int32)
            half = (blk_idx % 2).to(tl.int32)
            block_sparse_offset = (q_blk_idx * 2 + half) * step_m
            offs_m = start_m + block_sparse_offset + tl.arange(0, BLOCK_M1)
            q_mean = tl.load(QMean + q_blk_idx * HEAD_DIM +
                             offs_k).to(tl.float32)
            m = tl.load(M + offs_m)
            do = tl.load(do_ptrs + block_sparse_offset * stride_tok)
            Di = tl.load(D + offs_m)
            # NOTE(review): q_block_size is loaded but not used below; the
            # dk_mean normalization uses block_size_f (KV tile size) instead.
            # Verify whether that is intentional.
            q_block_size = tl.load(variable_block_sizes +
                                   q_blk_idx).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            p = tl.math.exp2(score - m)
            dp = tl.sum(do.to(tl.float32) * v_mean[None, :], axis=1)
            ds = block_size_f * p * (dp - Di)

            # Rank-1 corrections spread uniformly over the KV tile's rows.
            dk_mean = tl.sum(ds[:, None] * q_mean[None, :],
                             axis=0) / block_size_f
            dv_mean = tl.sum(p[:, None] * do.to(tl.float32), axis=0)
            dk += dk_mean[None, :]
            dv += dv_mean[None, :]
    return dk, dv
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# the main inner-loop logic for computing dQ
|
| 395 |
+
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
        dq,
        q,     # Q tile for this program, already resident in SRAM
        K,
        V,
        QMean,
        KMean,
        VMean,
        do,    # dO tile for this program
        m,     # per-row lse, shaped [BLOCK_M2, 1]
        m_vec,  # per-row lse, shaped [BLOCK_M2]
        D,     # per-row Delta = rowsum(O * dO)
        # shared by Q/K/V/DO.
        q2k_index,            # per (b, h, q_tile): list of selected KV tile ids
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,    # per (b, h, q_tile): list of dropped KV tile ids
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M2: tl.constexpr,
        BLOCK_N2: tl.constexpr,
        HEAD_DIM: tl.constexpr,
        # Filled in by the wrapper.
        start_m,
        start_n,
        num_steps,  # NOTE(review): not referenced in this body; loop counts
                    # come from q2k_num/dropped_q2k_num instead
        sm_scale=1.0,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """Accumulate dQ for one Q tile over its selected KV tiles.

    Each 64-token KV tile listed in q2k_index is processed as two BLOCK_N2
    (32-token) halves. When USE_TILE_COMP is set, dropped KV tiles
    contribute an approximate mean-based correction.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    # K and V are addressed transposed (HEAD_DIM x BLOCK_N2).
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    step_n = BLOCK_N2

    q_blk = tl.program_id(0)   # Q-tile index
    off_hz = tl.program_id(2)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    q_tiles = N_CTX // BLOCK_M2
    meta_base = ((b * H + h) * q_tiles + q_blk)

    kv_blocks = tl.load(q2k_num + meta_base)  # int32
    kv_ptr = q2k_index + meta_base * max_kv_blks  # ptr to list
    dropped_kv_blocks = tl.load(dropped_q2k_num + meta_base)
    dropped_kv_ptr = dropped_q2k_index + meta_base * max_dropped_kv_blks

    # Two BLOCK_N2 halves per listed 64-token KV tile.
    for blk_idx in range(kv_blocks * 2):
        kv_idx = tl.load(kv_ptr + blk_idx // 2).to(tl.int32)
        # variable_block_sizes is defined per KV block (tile). Mask must therefore
        # use kv_idx (not q_blk). Also, because we split each 64-token block into
        # two 32-token halves, the mask must account for the half-block offset.
        block_size = tl.load(variable_block_sizes + kv_idx).to(tl.int32)
        half = (blk_idx % 2).to(tl.int32)
        block_sparse_offset = (kv_idx * 2 + half) * step_n * stride_tok
        kT = tl.load(kT_ptrs + block_sparse_offset)
        vT = tl.load(vT_ptrs + block_sparse_offset)
        qk = tl.dot(q, kT)
        qk = qk * sm_scale * 1.44269504
        # Exact softmax probabilities against the forward pass's final lse.
        p = tl.math.exp2(qk - m)
        offs_in_block = half * step_n + tl.arange(0, BLOCK_N2)
        mask = offs_in_block < block_size
        p = tl.where(mask[None, :], p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = ds.to(tl.bfloat16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.

    if USE_TILE_COMP:
        # Approximate contribution of dropped KV tiles via tile-mean vectors.
        q_mean = tl.load(QMean + q_blk * HEAD_DIM + offs_k).to(tl.float32)
        q_block_size = tl.load(variable_block_sizes + q_blk).to(tl.float32)
        qk_scale = sm_scale * 1.44269504
        dq_mean = tl.zeros([HEAD_DIM], dtype=tl.float32)

        for blk_idx in range(dropped_kv_blocks):
            kv_idx = tl.load(dropped_kv_ptr + blk_idx).to(tl.int32)
            block_size = tl.load(variable_block_sizes + kv_idx).to(tl.float32)
            k_mean = tl.load(KMean + kv_idx * HEAD_DIM +
                             offs_k).to(tl.float32)
            v_mean = tl.load(VMean + kv_idx * HEAD_DIM +
                             offs_k).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            p = tl.math.exp2(score - m_vec)
            dp = tl.sum(do.to(tl.float32) * v_mean[None, :], axis=1)
            ds = block_size * p * (dp - Di)
            dq_mean = dq_mean + tl.sum(ds, axis=0) * k_mean

        # The summed rank-1 correction is spread uniformly over this Q tile.
        dq += dq_mean[None, :] / q_block_size
    return dq
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
@triton.jit
def _attn_bwd(
        Q,
        K,
        V,
        sm_scale,
        DO,
        DQ,
        DK,
        DV,
        M,   # per-row log2-sum-exp from the forward pass
        D,   # per-row Delta = rowsum(O * dO)
        q2k_index,
        q2k_num,
        max_kv_blks,
        k2q_index,
        k2q_num,
        max_q_blks,
        variable_block_sizes,
        # shared by Q/K/V/DO.
        stride_z,
        stride_h,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M1: tl.constexpr,
        BLOCK_N1: tl.constexpr,
        BLOCK_M2: tl.constexpr,
        BLOCK_N2: tl.constexpr,
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False):
    """Fused backward kernel: one program computes dK/dV for KV tile `pid`
    and dQ for Q tile `pid` of one (batch, head).

    Assumes square attention (Q and KV share N_CTX and the same strides).
    Tile compensation is disabled here, so the mean tensors and dropped-index
    arguments of the inner helpers are fed dummy placeholders (Q/K/V and the
    regular index arrays), which the helpers ignore when USE_TILE_COMP=False.

    Grid:
        pid0: tile index in [0, N_CTX/BLOCK_N1)
        pid2: fused (batch, head) in [0, B*H)
    """
    # (removed unused local LN2 — nothing in this body converts out of the
    # log2 domain; M is consumed as-is by the inner helpers.)

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    offs_k = tl.arange(0, HEAD_DIM)

    # ----- dK/dV for KV tile `pid` -----
    start_n = pid * BLOCK_N1
    start_m = 0

    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    num_steps = N_CTX // BLOCK_M1

    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,
        Q,
        k,
        v,
        Q,  # QMean placeholder (ignored: USE_TILE_COMP=False)
        K,  # KMean placeholder
        V,  # VMean placeholder
        sm_scale,
        DO,
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        k2q_index,  # dropped_k2q_index placeholder
        k2q_num,    # dropped_k2q_num placeholder
        max_q_blks,  # max_dropped_q_blks placeholder
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,
        start_n,
        start_m,
        num_steps,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=False,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv)

    # Write back dK (fold in the softmax scale once, here).
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk)

    # ----- dQ for Q tile `pid` -----
    start_m = pid * BLOCK_M2
    end_n = 0  # passed as start_n below; sparse indices drive actual offsets

    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    m_vec = tl.load(M + offs_m)
    m = m_vec[:, None]

    num_steps = N_CTX // BLOCK_N2
    dq = _attn_bwd_dq(
        dq,
        q,
        K,
        V,
        Q,  # QMean placeholder (ignored: USE_TILE_COMP=False)
        K,  # KMean placeholder
        V,  # VMean placeholder
        do,
        m,
        m_vec,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        q2k_index,   # dropped_q2k_index placeholder
        q2k_num,     # dropped_q2k_num placeholder
        max_kv_blks,  # max_dropped_kv_blks placeholder
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M2,
        BLOCK_N2,
        HEAD_DIM,
        start_m,
        end_n,
        num_steps,
        sm_scale=sm_scale,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=False,
    )
    # Write back dQ (fold in the softmax scale once, here).
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= sm_scale
    tl.store(dq_ptrs, dq)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
@triton.jit
def _attn_bwd_dkdv_kernel(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        sm_scale,
        DO,
        DK,
        DV,
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        # shared token/dim strides (assumed contiguous along token and dim)
        stride_tok,
        stride_d,
        # batch/head strides (may differ between Q and KV)
        stride_qz,
        stride_qh,
        stride_kz,
        stride_kh,
        stride_vz,
        stride_vh,
        stride_doz,
        stride_doh,
        stride_dkz,
        stride_dkh,
        stride_dvz,
        stride_dvh,
        H,
        N_CTX_Q,
        N_CTX_KV,
        BLOCK_M1: tl.constexpr,
        BLOCK_N1: tl.constexpr,
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    Backward kernel that computes dK and dV for each KV block (64 tokens).
    Grid:
        pid0: kv_blk in [0, N_CTX_KV/BLOCK_N1)
        pid2: fused (batch, head) in [0, B*H)
    """
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    kv_blk = tl.program_id(0)

    # Per-tensor batch/head base offsets (Q and KV may be strided differently).
    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dk_adj = (b.to(tl.int64) * stride_dkz + h.to(tl.int64) * stride_dkh)
    dv_adj = (b.to(tl.int64) * stride_dvz + h.to(tl.int64) * stride_dvh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DK = DK + dk_adj
    DV = DV + dv_adj

    # Mean tensors are laid out per 64-token tile; Q tiles here are
    # 2*BLOCK_M1 tokens, hence the extra // 2.
    q_tiles = N_CTX_Q // BLOCK_M1 // 2
    kv_tiles = N_CTX_KV // BLOCK_N1
    mean_q_adj = (bhid * q_tiles * HEAD_DIM).to(tl.int64)
    mean_kv_adj = (bhid * kv_tiles * HEAD_DIM).to(tl.int64)
    QMean = QMean + mean_q_adj
    KMean = KMean + mean_kv_adj
    VMean = VMean + mean_kv_adj

    # M and D (delta) are always sized by Q length.
    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_n = kv_blk * BLOCK_N1
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    dv_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    num_steps = N_CTX_Q // BLOCK_M1
    dk_acc, dv_acc = _attn_bwd_dkdv(
        dk_acc,
        dv_acc,
        Q,
        k,
        v,
        QMean,
        KMean,
        VMean,
        sm_scale,
        DO,
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_KV,
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,
        HEAD_DIM=HEAD_DIM,
        start_n=start_n,
        start_m=0,
        num_steps=num_steps,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=USE_TILE_COMP,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv_acc)

    # Fold the softmax scale into dK once, at write-back.
    dk_acc *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk_acc)
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
@triton.jit
def _attn_bwd_dq_kernel(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        DO,
        DQ,
        M,   # per-row log2-sum-exp from the forward pass
        D,   # per-row Delta = rowsum(O * dO)
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        # shared token/dim strides (assumed contiguous along token and dim)
        stride_tok,
        stride_d,
        # batch/head strides (may differ between Q and KV)
        stride_qz,
        stride_qh,
        stride_kz,
        stride_kh,
        stride_vz,
        stride_vh,
        stride_doz,
        stride_doh,
        stride_dqz,
        stride_dqh,
        H,
        N_CTX_Q,
        sm_scale,
        BLOCK_M2: tl.constexpr,
        BLOCK_N2: tl.constexpr,
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    Backward kernel that computes dQ for each Q block (64 tokens).
    Grid:
        pid0: q_blk in [0, N_CTX_Q/BLOCK_M2)
        pid2: fused (batch, head) in [0, B*H)
    """
    # (removed unused local LN2 — all math here stays in the log2 domain.)
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    q_blk = tl.program_id(0)

    # Per-tensor batch/head base offsets (Q and KV may be strided differently).
    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dq_adj = (b.to(tl.int64) * stride_dqz + h.to(tl.int64) * stride_dqh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DQ = DQ + dq_adj

    q_tiles = N_CTX_Q // BLOCK_M2
    # NOTE(review): KV tile count is derived from the *Q* context length with
    # a hard-coded 64-token tile size — this assumes self-attention
    # (Tq == Tkv). Confirm against the caller before reusing for cross-attn.
    kv_tiles = N_CTX_Q // 64
    mean_q_adj = (bhid * q_tiles * HEAD_DIM).to(tl.int64)
    mean_kv_adj = (bhid * kv_tiles * HEAD_DIM).to(tl.int64)
    QMean = QMean + mean_q_adj
    KMean = KMean + mean_kv_adj
    VMean = VMean + mean_kv_adj

    # M and D are always sized by Q length.
    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_m = q_blk * BLOCK_M2
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    m_vec = tl.load(M + offs_m)
    m = m_vec[:, None]

    dq_acc = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    num_steps = 0  # unused in _attn_bwd_dq
    dq_acc = _attn_bwd_dq(
        dq_acc,
        q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        do,
        m,
        m_vec,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_Q,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=HEAD_DIM,
        start_m=start_m,
        start_n=0,
        num_steps=num_steps,
        sm_scale=sm_scale,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=USE_TILE_COMP,
    )

    # Fold the softmax scale into dQ once, at write-back.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq_acc *= sm_scale
    tl.store(dq_ptrs, dq_acc)
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
|
| 927 |
+
def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
                                     variable_block_sizes, is_qat=False,
                                     q_mean=None, k_mean=None, v_mean=None,
                                     dropped_q2k_index=None,
                                     dropped_q2k_num=None):
    """Launch the block-sparse attention forward Triton kernel.

    Args:
        q, k, v: (B, H, T, D) query/key/value tensors. Both the query length
            Tq and the key/value length Tkv must be multiples of 64 (the
            kernel tile size).
        q2k_index: per-query-block indices of attended KV blocks; its last
            dim is the padded maximum KV-block count.
        q2k_num: per-query-block counts of valid entries in ``q2k_index``;
            last dim must equal Tq // 64.
        variable_block_sizes: valid token count per KV block; must have
            exactly Tkv // 64 elements.
        is_qat: forwarded to the kernel as the IS_QAT constexpr.
        q_mean, k_mean, v_mean: optional per-tile mean tensors. Supplying
            ``q_mean`` enables tile-compression mode (USE_TILE_COMP), in
            which case the ``dropped_*`` tensors are required.
        dropped_q2k_index, dropped_q2k_num: index/count tensors for the
            mean-approximated ("dropped") KV blocks; required iff means
            are given.

    Returns:
        (o, M, high_prec_o): attention output, per-row softmax statistics
        (float32, shape (B, H, Tq)), and the high-precision output buffer,
        all written by the kernel.
    """
    B, H, Tq, D = q.shape
    Tkv = k.shape[2]
    sm_scale = 1.0 / math.sqrt(D)
    max_kv_blks = q2k_index.shape[-1]
    # Tile-compression mode is implied by the presence of q_mean.
    use_tile_comp = q_mean is not None
    if use_tile_comp:
        assert k_mean is not None and v_mean is not None
        assert dropped_q2k_index is not None and dropped_q2k_num is not None
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        max_dropped_kv_blks = dropped_q2k_index.shape[-1]
    else:
        # The kernel still needs valid tensor arguments even when
        # USE_TILE_COMP=False; alias the main tensors as placeholders.
        q_mean = q
        k_mean = k
        v_mean = v
        dropped_q2k_index = q2k_index
        dropped_q2k_num = q2k_num
        max_dropped_kv_blks = max_kv_blks
    assert Tq % 64 == 0, f"q length must be a multiple of 64, but got {Tq}"
    assert Tkv % 64 == 0, f"kv length must be a multiple of 64, but got {Tkv}"
    # BUGFIX: the failure message previously formatted q2k_num.shape[-2],
    # which does not match the checked dimension and itself raises
    # IndexError for 1-D q2k_num exactly when the assert is about to fire.
    assert q2k_num.shape[
        -1] == Tq // 64, f"shape mismatch, Tq // 64 = {Tq // 64}, q2k_num.shape[-1] = {q2k_num.shape[-1]}"
    assert variable_block_sizes.numel() == Tkv // 64, (
        f"shape mismatch, variable_block_sizes must have length {Tkv // 64}, "
        f"got {variable_block_sizes.numel()}"
    )
    o = torch.empty_like(q)
    high_prec_o = torch.empty_like(q)
    M = torch.empty((B, H, Tq), dtype=torch.float32, device=q.device)

    # One program per (64-row query tile, batch*head) pair.
    grid = lambda _: (triton.cdiv(Tq, 64), B * H, 1)
    _attn_fwd_sparse[grid](q,
                           k,
                           v,
                           q_mean,
                           k_mean,
                           v_mean,
                           sm_scale,
                           q2k_index,
                           q2k_num,
                           max_kv_blks,
                           dropped_q2k_index,
                           dropped_q2k_num,
                           max_dropped_kv_blks,
                           variable_block_sizes,
                           M,
                           o,
                           high_prec_o,
                           q.stride(0),
                           q.stride(1),
                           q.stride(2),
                           q.stride(3),
                           k.stride(0),
                           k.stride(1),
                           k.stride(2),
                           k.stride(3),
                           v.stride(0),
                           v.stride(1),
                           v.stride(2),
                           v.stride(3),
                           o.stride(0),
                           o.stride(1),
                           o.stride(2),
                           o.stride(3),
                           B,
                           H,
                           Tq,
                           Tkv,
                           HEAD_DIM=D,
                           STAGE=3,
                           IS_QAT=is_qat,
                           USE_TILE_COMP=use_tile_comp)

    return o, M, high_prec_o
|
| 1007 |
+
|
| 1008 |
+
|
| 1009 |
+
def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
                                      k2q_index, k2q_num, variable_block_sizes,
                                      is_qat=False, q_mean=None, k_mean=None,
                                      v_mean=None, dropped_q2k_index=None,
                                      dropped_q2k_num=None,
                                      dropped_k2q_index=None,
                                      dropped_k2q_num=None):
    """Launch the block-sparse attention backward Triton kernels.

    Runs three kernels: a preprocess pass computing per-row delta terms,
    a dK/dV pass gridded over KV blocks (driven by the k2q_* mapping), and
    a dQ pass gridded over Q blocks (driven by the q2k_* mapping).

    Args:
        do: gradient of the attention output; must be contiguous.
        q, k, v, o, M: forward inputs/outputs (M is the per-row softmax
            statistic returned by the forward pass).
        q2k_index, q2k_num: query-block -> KV-block sparse mapping.
        k2q_index, k2q_num: KV-block -> query-block sparse mapping (the
            transpose mapping, used by the dK/dV kernel).
        variable_block_sizes: valid token count per KV block.
        is_qat: forwarded as the IS_QAT constexpr.
        q_mean, k_mean, v_mean: optional per-tile means; presence of
            q_mean enables tile-compression mode, in which case all four
            dropped_* tensors are required.

    Returns:
        (dq, dk, dv): gradients with the same shapes as q, k, v.
    """
    assert do.is_contiguous()

    B, H, Tq, D = q.shape
    Tkv = k.shape[2]
    sm_scale = 1.0 / math.sqrt(D)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    BATCH, N_HEAD = q.shape[:2]
    # Tile shapes: (M1, N1) for the dK/dV pass, (M2, N2) for the dQ pass.
    BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
    # NOTE: the classic Flash-style bwd pre-scales K by 1/ln(2); Ours-P mode
    # keeps K unscaled and applies sm_scale inside the bwd kernels instead.
    arg_k = k
    PRE_BLOCK = 64
    assert Tq % PRE_BLOCK == 0
    pre_grid = (Tq // PRE_BLOCK, BATCH * N_HEAD)
    delta = torch.empty_like(M)
    # Preprocess: delta[b, h, m] terms consumed by both gradient kernels.
    _attn_bwd_preprocess[pre_grid](
        o,
        do,
        delta,
        BATCH,
        N_HEAD,
        Tq,
        BLOCK_M=PRE_BLOCK,
        HEAD_DIM=D,
    )

    max_q_blks = k2q_index.shape[-1]
    max_kv_blks = q2k_index.shape[-1]
    # Tile-compression mode is implied by the presence of q_mean.
    use_tile_comp = q_mean is not None
    if use_tile_comp:
        assert k_mean is not None and v_mean is not None
        assert dropped_q2k_index is not None and dropped_q2k_num is not None
        assert dropped_k2q_index is not None and dropped_k2q_num is not None
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        max_dropped_kv_blks = dropped_q2k_index.shape[-1]
        max_dropped_q_blks = dropped_k2q_index.shape[-1]
    else:
        # Placeholder tensors so the kernels always receive valid pointers
        # even when USE_TILE_COMP=False.
        q_mean = q
        k_mean = k
        v_mean = v
        dropped_q2k_index = q2k_index
        dropped_q2k_num = q2k_num
        dropped_k2q_index = k2q_index
        dropped_k2q_num = k2q_num
        max_dropped_kv_blks = max_kv_blks
        max_dropped_q_blks = max_q_blks

    # dK/dV kernel: grid over KV blocks
    grid_kv = (Tkv // BLOCK_N1, 1, BATCH * N_HEAD)
    _attn_bwd_dkdv_kernel[grid_kv](
        q,
        arg_k,
        v,
        q_mean,
        k_mean,
        v_mean,
        sm_scale,
        do,
        dk,
        dv,
        M,
        delta,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        q.stride(2),
        q.stride(3),
        q.stride(0),
        q.stride(1),
        arg_k.stride(0),
        arg_k.stride(1),
        v.stride(0),
        v.stride(1),
        do.stride(0),
        do.stride(1),
        dk.stride(0),
        dk.stride(1),
        dv.stride(0),
        dv.stride(1),
        N_HEAD,
        Tq,
        Tkv,
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,
        HEAD_DIM=D,
        IS_QAT=is_qat,
        USE_TILE_COMP=use_tile_comp,
    )

    # dQ kernel: grid over Q blocks
    grid_q = (Tq // BLOCK_M2, 1, BATCH * N_HEAD)
    _attn_bwd_dq_kernel[grid_q](
        q,
        arg_k,
        v,
        q_mean,
        k_mean,
        v_mean,
        do,
        dq,
        M,
        delta,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        q.stride(2),
        q.stride(3),
        q.stride(0),
        q.stride(1),
        arg_k.stride(0),
        arg_k.stride(1),
        v.stride(0),
        v.stride(1),
        do.stride(0),
        do.stride(1),
        dq.stride(0),
        dq.stride(1),
        N_HEAD,
        Tq,
        sm_scale,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=D,
        IS_QAT=is_qat,
        USE_TILE_COMP=use_tile_comp,
    )

    return dq, dk, dv
|
standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# Adapted from https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
|
| 3 |
+
# and https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
|
| 4 |
+
|
| 5 |
+
import triton
|
| 6 |
+
import triton.language as tl
|
| 7 |
+
try:
|
| 8 |
+
from triton.language.target_info import cuda_capability_geq
|
| 9 |
+
_HAS_CAPABILITY_CHECK = True
|
| 10 |
+
except ImportError:
|
| 11 |
+
cuda_capability_geq = None
|
| 12 |
+
_HAS_CAPABILITY_CHECK = False
|
| 13 |
+
|
| 14 |
+
MXFP_BLOCK_SIZE = tl.constexpr(16)
|
| 15 |
+
|
| 16 |
+
@triton.jit
def _compute_quant_and_scale(
    src_tensor,
    valid_src_mask,
    mx_tensor_dtype: tl.constexpr = tl.uint8,
    use_global_sf=True,
    two_level_quant_P=False,
    IS_BLACKWELL: tl.constexpr = False,
):
    """Quantize a 2-D tile to an MX/NVFP4-style format with per-16-element
    fp8e4m3 block scales and an optional global (or per-row) scale.

    Returns (out_tensor, dequant_scale, s_dec):
      - out_tensor: packed fp4 pairs (uint8) or an fp8 tensor, depending on
        mx_tensor_dtype.
      - dequant_scale: fp8e4m3 per-block decode scales, one per 16 elements.
      - s_dec: the global decode factor (scalar, per-row tensor, or 1.0).
    """
    BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
    BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
    # One fp8 scale per MXFP_BLOCK_SIZE (=16) elements along the quant dim.
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
    # uint8 output means packed e2m1 fp4 pairs (two values per byte).
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8

    is_fp8e4: tl.constexpr = mx_tensor_dtype == tl.float8e4nv
    is_fp8e5: tl.constexpr = mx_tensor_dtype == tl.float8e5
    tl.static_assert(
        is_fp4 or (is_fp8e4 or is_fp8e5),
        "mx_tensor_dtype must be uint8, float8e4nv, or float8e5",
    )

    # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
    f32_tensor = src_tensor.to(tl.float32)
    abs_tensor = tl.abs(f32_tensor)
    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation

    if two_level_quant_P:
        # row max from SageAttn3 paper
        # NOTE(review): this takes the max of the *signed* f32 values, not of
        # abs_tensor as the global-sf branch below does — confirm this is the
        # intended P-matrix behavior (P >= 0 after softmax would make them equal).
        global_max_val = tl.max(f32_tensor, axis=1, keep_dims=True)  # (BLOCK_SIZE_OUT_DIM, 1)
        global_max_val = tl.maximum(global_max_val, 1e-8)
        # 6 * 448 = max e2m1 magnitude * max fp8e4m3 scale: the largest value
        # representable after both levels of scaling.
        s_enc = ((6 * 448) / global_max_val).reshape([BLOCK_SIZE_OUT_DIM, 1, 1])
        s_dec = (1 / s_enc)

    # Group the quant dim into 16-element blocks for per-block maxima.
    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])

    if use_global_sf and not two_level_quant_P:
        # Single global scale over the whole tile.
        global_max_val = tl.max(abs_tensor)
        # Avoid division by zero: if all values are padding (max is 0), use a default scale
        global_max_val = tl.maximum(global_max_val, 1e-8)
        s_enc = (6 * 448) / global_max_val
        s_dec = (1 / s_enc)
    elif not two_level_quant_P and not use_global_sf:
        # No global scale: rely on the per-block fp8 scales alone.
        s_dec = 1.0
        s_enc = 1.0

    max_val = tl.max(abs_tensor, axis=2, keep_dims=True)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) # per block maxima
    # Per-block decode scale so the block max maps to the fp4 max magnitude (6).
    s_dec_b = max_val / 6  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
    # Store the block scale itself in fp8e4m3 (folding in the global encode scale) ...
    s_dec_b_e4m3 = (s_dec_b * s_enc).to(tl.float8e4nv)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
    # ... and derive the effective encode scale from the *rounded* fp8 value so
    # quantization and dequantization stay consistent.
    s_enc_b = 1 / (s_dec_b_e4m3.to(tl.float32) * s_dec)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)

    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    quant_tensor = f32_tensor * s_enc_b

    # Reshape the tensors after scaling
    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0.0)
    dequant_scale = s_dec_b_e4m3.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])

    if is_fp4 and IS_BLACKWELL:
        # Convert scaled values to two f32 lanes and use PTX cvt to e2m1x2 with two f32 operands.
        pairs = tl.reshape(quant_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        lo_f, hi_f = tl.split(pairs)
        lo_f32 = lo_f.to(tl.float32)
        hi_f32 = hi_f.to(tl.float32)

        # Inline PTX: cvt.rn.satfinite.e2m1x2.f32 takes two f32 sources and produces one .b8 packed e2m1x2.
        # NOTE(review): args=[hi_f32, lo_f32] binds hi to $1 and lo to $2 —
        # confirm the resulting nibble order matches the non-Blackwell packing
        # below (evens | odds << 4).
        out_tensor = tl.inline_asm_elementwise(
            """
            {
            .reg .b8 r;
            cvt.rn.satfinite.e2m1x2.f32 r, $1, $2;
            mov.b32 $0, {r, r, r, r};
            }
            """,
            constraints="=r,f,f",
            args=[hi_f32, lo_f32],
            dtype=tl.uint8,
            is_pure=True,
            pack=1,
        )
    elif is_fp4:
        # Software e2m1 conversion via fp32 bit manipulation.
        quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
        signs = quant_tensor & 0x80000000
        exponents = (quant_tensor >> 23) & 0xFF
        mantissas_orig = (quant_tensor & 0x7FFFFF)

        # For RTNE: 0.25 < x < 0.75 maps to 0.5 (denormal); exactly 0.25 maps to 0.0
        E8_BIAS = 127
        E2_BIAS = 1
        # Move implicit bit 1 at the beginning to mantissa for denormals
        is_subnormal = exponents < E8_BIAS
        # sanitize_overflow=False: the subtraction may wrap for huge exponents;
        # those lanes are normal and take the other tl.where branch anyway.
        adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
        mantissas_pre = (0x400000 | (mantissas_orig >> 1))
        mantissas = tl.where(is_subnormal, mantissas_pre >> adjusted_exponents, mantissas_orig)

        # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
        exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

        # Combine sign, exponent, and mantissa, while saturating
        # Round to nearest, ties to even (RTNE): use guard/sticky and LSB to decide increment
        m2bits = mantissas >> 21
        lsb_keep = (m2bits >> 1) & 0x1
        guard = m2bits & 0x1
        IS_SRC_FP32: tl.constexpr = src_tensor.dtype == tl.float32
        if IS_SRC_FP32:
            # fp32 sources: bits shifted out while building the subnormal
            # mantissa also contribute to the sticky bit.
            bit0_dropped = (mantissas_orig & 0x1) != 0
            mask = (1 << tl.minimum(adjusted_exponents, 31)) - 1
            dropped_post = (mantissas_pre & mask) != 0
            sticky = is_subnormal & (bit0_dropped | dropped_post)
            sticky |= ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
        else:
            sticky = ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
        round_inc = guard & (sticky | lsb_keep)
        # Assemble 3-bit magnitude (2-bit exp + 1-bit mantissa), saturating at 0b111 (=6.0).
        e2m1_tmp = tl.minimum((((exponents << 2) | m2bits) + round_inc) >> 1, 0x7)
        e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)

        # Pack adjacent fp4 values two-per-byte: even index in the low nibble.
        e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        evens, odds = tl.split(e2m1_value)
        out_tensor = evens | (odds << 4)
    else:
        # fp8 targets: a plain dtype cast does the rounding.
        out_tensor = quant_tensor.to(mx_tensor_dtype)

    return out_tensor, dequant_scale, s_dec
|
| 140 |
+
|
| 141 |
+
@triton.jit
def _compute_dequant(
    mx_tensor,
    scale,
    s_dec,
    BLOCK_SIZE_OUT_DIM: tl.constexpr,
    BLOCK_SIZE_QUANT_DIM: tl.constexpr,
    dst_dtype: tl.constexpr,
    IS_BLACKWELL: tl.constexpr = False,
):
    """Dequantize an MX/NVFP4 tile back to dst_dtype.

    Inverse of ``_compute_quant_and_scale``: unpacks fp4 pairs (uint8) to
    dst_dtype, then multiplies by the per-block fp8e4m3 ``scale`` and the
    global decode factor ``s_dec``, clamping to dst_dtype's finite range.
    """
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"Block size along quantization block must be a multiple of {MXFP_BLOCK_SIZE=}")
    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor.dtype
    _is_f16: tl.constexpr = dst_dtype == tl.float16
    _is_bf16: tl.constexpr = dst_dtype == tl.bfloat16
    _is_f32: tl.constexpr = dst_dtype == tl.float32
    tl.static_assert(_is_f16 or (_is_bf16 or _is_f32))
    _is_u8: tl.constexpr = mx_tensor_dtype == tl.uint8
    _is_e4: tl.constexpr = mx_tensor_dtype == tl.float8e4nv
    _is_e5: tl.constexpr = mx_tensor_dtype == tl.float8e5
    _is_dst: tl.constexpr = mx_tensor_dtype == dst_dtype
    tl.static_assert(
        _is_u8 or ((_is_e4 or _is_e5) or _is_dst),
        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
    tl.static_assert(scale.dtype == tl.float8e4nv, "scale must be float8e4nv")

    # Determine if we are dealing with fp8 types.
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE

    # Upcast the scale to the destination type.
    if dst_dtype == tl.bfloat16:
        dst_scale = scale.to(tl.bfloat16)
    else:
        # fp8 -> fp32 first; narrow to fp16 afterwards if that is the target.
        dst_scale = scale.to(tl.float32)
        if dst_dtype == tl.float16:
            dst_scale = dst_scale.to(tl.float16)

    # Now upcast the tensor.
    # fp32 targets go through bf16 first; the final cast widens losslessly.
    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
    if IS_BLACKWELL:
        # Hardware path: PTX cvt unpacks one e2m1x2 byte into an f16x2 pair.
        assert is_fp4
        packed_u32 = tl.inline_asm_elementwise(
            asm="""
            {
            .reg .b8 in_8;
            .reg .f16x2 out;
            cvt.u8.u32 in_8, $1;
            cvt.rn.f16x2.e2m1x2 out, in_8;
            mov.b32 $0, out;
            }
            """,
            constraints="=r,r",
            args=[mx_tensor],  # tl.uint8 passed in as a 32-bit reg with value in low 8 bits
            dtype=tl.uint32,
            is_pure=True,
            pack=1,
        )
        # Split the f16x2 result back into its two half-precision lanes.
        lo_u16 = (packed_u32 & 0xFFFF).to(tl.uint16)
        hi_u16 = (packed_u32 >> 16).to(tl.uint16)
        lo_f16 = lo_u16.to(tl.float16, bitcast=True)
        hi_f16 = hi_u16.to(tl.float16, bitcast=True)

        if intermediate_dtype == tl.float16:
            x0, x1 = lo_f16, hi_f16
        else:
            x0 = lo_f16.to(intermediate_dtype)
            x1 = hi_f16.to(intermediate_dtype)

        dst_tensor = tl.interleave(x0, x1)

    else:
        # Software path: rebuild dst-format bits from each e2m1 nibble.
        assert is_fp4
        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15  # exponent bias
        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10  # mantissa bits
        # e2m1
        em0 = mx_tensor & 0x07  # low-nibble exponent+mantissa
        em1 = mx_tensor & 0x70  # high-nibble exponent+mantissa
        # Place mantissa and sign bits into the destination layout.
        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((mx_tensor & 0x08).to(tl.uint16) << 12)
        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((mx_tensor & 0x80).to(tl.uint16) << 8)
        # Three cases:
        # 1) x is normal and non-zero: Correct bias
        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
        # 3) x is zero, do nothing
        dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)

    dst_tensor = dst_tensor.to(dst_dtype)

    # Reshape for proper broadcasting: the scale was stored with a 16-sized "inner" grouping.
    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
    scale = scale.reshape(dst_scale.shape)

    out_tensor = dst_tensor * dst_scale * s_dec  # NVFP4 has the additional global scale factor
    # Clamp to the destination type's largest finite value before the final cast.
    if dst_dtype == tl.float32:
        max_fin = 3.4028234663852886e+38
    elif dst_dtype == tl.bfloat16:
        max_fin = 3.3895313892515355e+38
    else:
        tl.static_assert(dst_dtype == tl.float16)
        max_fin = 65504
    out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin)
    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    out_tensor = out_tensor.to(dst_dtype)
    return out_tensor
|
standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import triton
|
| 2 |
+
import triton.language as tl
|
| 3 |
+
|
| 4 |
+
from .nvfp4_utils import _compute_quant_and_scale, _compute_dequant
|
| 5 |
+
|
| 6 |
+
@triton.jit
def fake_quantize(src_tensor, valid_src_mask, BLOCK_SIZE_OUT_DIM: tl.constexpr,
                  BLOCK_SIZE_QUANT_DIM: tl.constexpr,
                  dst_dtype: tl.constexpr,
                  mx_tensor_dtype: tl.constexpr = tl.uint8,
                  use_global_sf: tl.constexpr = True,
                  two_level_quant_P: tl.constexpr = False):
    """Round-trip a tile through NVFP4 quantization (quantize then dequantize).

    Simulates low-precision storage for QAT-style use: the returned first
    tensor carries the quantization error, the second is the untouched input
    cast to the same dtype for comparison.
    """
    # Preserve the unquantized input so the caller can compare against it.
    original = src_tensor
    quantized, block_scale, global_dec = _compute_quant_and_scale(
        src_tensor=src_tensor,
        valid_src_mask=valid_src_mask,
        mx_tensor_dtype=mx_tensor_dtype,
        use_global_sf=use_global_sf,
        two_level_quant_P=two_level_quant_P)
    dequantized = _compute_dequant(
        mx_tensor=quantized,
        scale=block_scale,
        s_dec=global_dec,
        BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM,
        BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM,
        dst_dtype=dst_dtype)
    return dequantized, original.to(dequantized.dtype)
|
| 26 |
+
|
| 27 |
+
@triton.jit
def fake_quantize_q(Q, fake_Q, stride_z_q, stride_h_q,
                    stride_tok_q, stride_d_q,
                    fake_stride_z_q, fake_stride_h_q,
                    fake_stride_tok_q, fake_stride_d_q,
                    H, N_CTX_Q,
                    BLOCK_M: tl.constexpr,
                    HEAD_DIM: tl.constexpr,
                    use_global_sf: tl.constexpr = True):
    """Fake-quantize one BLOCK_M x HEAD_DIM tile of Q into fake_Q.

    Grid: axis 0 iterates query tiles, axis 1 iterates batch*head pairs.
    Rows past N_CTX_Q are masked out of both the load and the store.
    """
    pid_bh = tl.program_id(1)
    batch_idx = pid_bh // H
    head_idx = pid_bh % H
    # Advance both base pointers to this (batch, head) slice.
    Q += stride_z_q * batch_idx + stride_h_q * head_idx
    fake_Q += fake_stride_z_q * batch_idx + fake_stride_h_q * head_idx

    row_base = tl.program_id(0) * BLOCK_M
    rows = row_base + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, HEAD_DIM)
    row_ok = rows < N_CTX_Q

    tile = tl.load(
        Q + rows[:, None] * stride_tok_q + cols[None, :] * stride_d_q,
        mask=row_ok[:, None], other=0.0)
    tile, _ = fake_quantize(src_tensor=tile,
                            valid_src_mask=row_ok[:, None],
                            BLOCK_SIZE_OUT_DIM=BLOCK_M,
                            BLOCK_SIZE_QUANT_DIM=HEAD_DIM,
                            dst_dtype=tile.dtype,
                            use_global_sf=use_global_sf)
    tl.store(
        fake_Q + rows[:, None] * fake_stride_tok_q + cols[None, :] * fake_stride_d_q,
        tile, mask=row_ok[:, None])
|
| 51 |
+
|
| 52 |
+
@triton.jit
def fake_quantize_kv(K, V, fake_K, fake_V, stride_z_kv, stride_h_kv,
                     stride_tok_kv, stride_d_kv,
                     fake_stride_z_kv, fake_stride_h_kv,
                     fake_stride_tok_kv, fake_stride_d_kv,
                     H, N_CTX_KV,
                     BLOCK_N: tl.constexpr,
                     HEAD_DIM: tl.constexpr,
                     use_global_sf: tl.constexpr = True):
    """Fake-quantize matching BLOCK_N x HEAD_DIM tiles of K and V.

    K and V share one stride set (and fake_K/fake_V another), so a single
    offset computation serves all four pointers. Grid layout mirrors
    ``fake_quantize_q``: axis 0 over KV tiles, axis 1 over batch*head pairs.
    """
    pid_bh = tl.program_id(1)
    batch_idx = pid_bh // H
    head_idx = pid_bh % H
    src_off = stride_z_kv * batch_idx + stride_h_kv * head_idx
    dst_off = fake_stride_z_kv * batch_idx + fake_stride_h_kv * head_idx
    K += src_off
    V += src_off
    fake_K += dst_off
    fake_V += dst_off

    row_base = tl.program_id(0) * BLOCK_N
    rows = row_base + tl.arange(0, BLOCK_N)
    cols = tl.arange(0, HEAD_DIM)
    row_ok = rows < N_CTX_KV

    src_ptrs = rows[:, None] * stride_tok_kv + cols[None, :] * stride_d_kv
    dst_ptrs = rows[:, None] * fake_stride_tok_kv + cols[None, :] * fake_stride_d_kv

    k_tile = tl.load(K + src_ptrs, mask=row_ok[:, None], other=0.0)
    v_tile = tl.load(V + src_ptrs, mask=row_ok[:, None], other=0.0)
    k_fq, _ = fake_quantize(src_tensor=k_tile,
                            valid_src_mask=row_ok[:, None],
                            BLOCK_SIZE_OUT_DIM=BLOCK_N,
                            BLOCK_SIZE_QUANT_DIM=HEAD_DIM,
                            dst_dtype=k_tile.dtype,
                            use_global_sf=use_global_sf)
    v_fq, _ = fake_quantize(src_tensor=v_tile,
                            valid_src_mask=row_ok[:, None],
                            BLOCK_SIZE_OUT_DIM=BLOCK_N,
                            BLOCK_SIZE_QUANT_DIM=HEAD_DIM,
                            dst_dtype=v_tile.dtype,
                            use_global_sf=use_global_sf)
    tl.store(fake_K + dst_ptrs, k_fq, mask=row_ok[:, None])
    tl.store(fake_V + dst_ptrs, v_fq, mask=row_ok[:, None])
|
standalone_inference/overlay_files/fastvideo/api/compat.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from collections.abc import Mapping
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from dataclasses import fields, is_dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from fastvideo.api.overrides import apply_overrides, parse_cli_overrides
|
| 11 |
+
from fastvideo.api.parser import config_to_dict, load_raw_config, parse_config
|
| 12 |
+
from fastvideo.api.schema import (
|
| 13 |
+
GenerationRequest,
|
| 14 |
+
GeneratorConfig,
|
| 15 |
+
InputConfig,
|
| 16 |
+
OutputConfig,
|
| 17 |
+
RequestRuntimeConfig,
|
| 18 |
+
SamplingConfig,
|
| 19 |
+
)
|
| 20 |
+
from fastvideo.configs.sample import SamplingParam
|
| 21 |
+
from fastvideo.fastvideo_args import FastVideoArgs
|
| 22 |
+
from fastvideo.utils import shallow_asdict
|
| 23 |
+
|
| 24 |
+
# NOTE(review): appears to tag request objects that were constructed
# explicitly by the user — confirm semantics at the attribute's use sites.
_EXPLICIT_REQUEST_ATTR = "_fastvideo_explicit_request"
# Field-name sets for each request sub-config; used to route flat
# keyword arguments into the correct dataclass section.
_INPUT_FIELD_NAMES = {field.name for field in fields(InputConfig)}
_SAMPLING_FIELD_NAMES = {field.name for field in fields(SamplingConfig)}
_RUNTIME_FIELD_NAMES = {field.name for field in fields(RequestRuntimeConfig)}
_OUTPUT_FIELD_NAMES = {field.name for field in fields(OutputConfig)}
# Sentinel distinguishing "argument not provided" from an explicit None.
_MISSING = object()
# Deprecated request kwarg names mapped to their current equivalents.
_LEGACY_REQUEST_ALIASES = {
    "neg_prompt": "negative_prompt",
}
# Request-level fields that are allowed to override pipeline-level config.
_REQUEST_PIPELINE_OVERRIDE_FIELDS = frozenset({
    "embedded_cfg_scale",
})
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def normalize_generator_config(config: GeneratorConfig | Mapping[str, Any], ) -> GeneratorConfig:
    """Coerce *config* into a ``GeneratorConfig``.

    An already-constructed ``GeneratorConfig`` passes through untouched;
    a mapping is parsed into one.
    """
    already_typed = isinstance(config, GeneratorConfig)
    return config if already_typed else parse_config(GeneratorConfig, config)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_generator_config_from_file(
    path: str | Path,
    overrides: list[str] | Mapping[str, Any] | None = None,
) -> GeneratorConfig:
    """Load a ``GeneratorConfig`` from *path*, optionally applying overrides.

    For a run/serve-style wrapper config, overrides are addressed from the
    document root and the ``generator`` section is extracted afterwards.
    For a bare generator config, overrides apply directly; if *every*
    override key carries a uniform ``generator.`` prefix, the prefix is
    stripped first so both addressing styles work.
    """
    raw = load_raw_config(path)
    parsed_overrides = _normalize_overrides(overrides)

    if _looks_like_run_or_serve_config(raw):
        # Wrapper config: apply overrides to the whole document, then pull
        # out the generator section.
        if parsed_overrides:
            raw = apply_overrides(raw, parsed_overrides)
        return parse_config(GeneratorConfig, raw["generator"])

    if parsed_overrides:
        prefix = "generator."
        if all(key.startswith(prefix) for key in parsed_overrides):
            # Caller prefixed every key as if this were a wrapper config;
            # strip the prefix so the overrides land on the right fields.
            parsed_overrides = {
                key[len(prefix):]: value
                for key, value in parsed_overrides.items()
            }
        raw = apply_overrides(raw, parsed_overrides)

    return parse_config(GeneratorConfig, raw)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def legacy_from_pretrained_to_config(
    model_path: str,
    kwargs: Mapping[str, Any],
) -> GeneratorConfig:
    """Translate legacy ``from_pretrained(**kwargs)`` arguments to a GeneratorConfig.

    Known legacy keys are routed into the structured config sections (engine,
    parallelism, offload, compile, quantization, pipeline, components); any
    unrecognized key is preserved verbatim under ``pipeline.experimental`` so no
    caller-supplied option is silently dropped.

    Args:
        model_path: Model identifier/path placed at the top of the raw config.
        kwargs: Legacy keyword arguments from the old ``from_pretrained`` API.

    Returns:
        The parsed GeneratorConfig.
    """
    raw: dict[str, Any] = {"model_path": model_path}
    engine: dict[str, Any] = {}
    parallelism: dict[str, Any] = {}
    offload: dict[str, Any] = {}
    compile_config: dict[str, Any] = {}
    pipeline: dict[str, Any] = {}
    components: dict[str, Any] = {}
    quantization: dict[str, Any] = {}
    experimental: dict[str, Any] = {}

    # Keys copied into a section under their original name.
    parallelism_keys = {"tp_size", "sp_size", "hsdp_replicate_dim", "hsdp_shard_dim", "dist_timeout"}
    engine_passthrough_keys = {"enable_stage_verification", "use_fsdp_inference", "disable_autocast"}
    # (target section, renamed key) for 1:1 routes — declarative instead of a
    # 20-branch if/elif chain, so a new alias is a one-line addition.
    renames: dict[str, tuple[dict[str, Any], str]] = {
        "revision": (raw, "revision"),
        "trust_remote_code": (raw, "trust_remote_code"),
        "num_gpus": (engine, "num_gpus"),
        "distributed_executor_backend": (engine, "execution_backend"),
        "dit_cpu_offload": (offload, "dit"),
        "dit_layerwise_offload": (offload, "dit_layerwise"),
        "text_encoder_cpu_offload": (offload, "text_encoder"),
        "image_encoder_cpu_offload": (offload, "image_encoder"),
        "vae_cpu_offload": (offload, "vae"),
        "pin_cpu_memory": (offload, "pin_cpu_memory"),
        "enable_torch_compile": (compile_config, "enabled"),
        "override_text_encoder_quant": (quantization, "text_encoder_quant"),
        "transformer_quant": (quantization, "transformer_quant"),
        "workload_type": (pipeline, "workload_type"),
        "lora_path": (components, "lora_path"),
        "override_pipeline_cls_name": (components, "override_pipeline_cls_name"),
        "override_transformer_cls_name": (components, "override_transformer_cls_name"),
        "override_text_encoder_safetensors": (components, "text_encoder_weights"),
        "init_weights_from_safetensors": (components, "transformer_weights"),
        "init_weights_from_safetensors_2": (components, "transformer_2_weights"),
    }

    for key, value in kwargs.items():
        if key in renames:
            target, new_key = renames[key]
            target[new_key] = value
        elif key in parallelism_keys:
            parallelism[key] = value
        elif key in engine_passthrough_keys:
            engine[key] = value
        elif key == "torch_compile_kwargs":
            # Copied so later mutation of the caller's dict cannot leak in.
            compile_config["kwargs"] = deepcopy(value)
        elif key == "pipeline_config":
            if isinstance(value, str):
                components["pipeline_config_path"] = value
            else:
                # Non-path pipeline configs are not modeled yet; keep verbatim.
                experimental[key] = deepcopy(value)
        else:
            # Unknown option: preserved so nothing is silently dropped.
            experimental[key] = deepcopy(value)

    # Assemble only non-empty sections so the resulting config stays minimal.
    if parallelism:
        engine["parallelism"] = parallelism
    if offload:
        engine["offload"] = offload
    if compile_config:
        engine["compile"] = compile_config
    if quantization:
        engine["quantization"] = quantization
    if engine:
        raw["engine"] = engine

    if components:
        pipeline["components"] = components
    if experimental:
        pipeline["experimental"] = experimental
    if pipeline:
        raw["pipeline"] = pipeline

    return parse_config(GeneratorConfig, raw)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def generator_config_to_fastvideo_args(config: GeneratorConfig | Mapping[str, Any], ) -> FastVideoArgs:
    """Flatten a structured GeneratorConfig into the legacy FastVideoArgs form.

    Raises:
        NotImplementedError: for config fields this adapter cannot express yet.
    """
    normalized = normalize_generator_config(config)
    # Fail loudly on fields that have no FastVideoArgs equivalent instead of
    # silently dropping them.
    unsupported = []
    if normalized.pipeline.profile is not None:
        unsupported.append("pipeline.profile")
    if normalized.pipeline.profile_version is not None:
        unsupported.append("pipeline.profile_version")
    if normalized.pipeline.components.config_root is not None:
        unsupported.append("pipeline.components.config_root")
    if normalized.pipeline.components.vae_weights is not None:
        unsupported.append("pipeline.components.vae_weights")
    if normalized.pipeline.components.upsampler_weights is not None:
        unsupported.append("pipeline.components.upsampler_weights")
    if unsupported:
        joined = ", ".join(unsupported)
        raise NotImplementedError(f"VideoGenerator compatibility adapter does not support {joined} yet")

    engine = normalized.engine
    # Unconditional field-for-field translation of the engine section.
    kwargs: dict[str, Any] = {
        "model_path": normalized.model_path,
        "revision": normalized.revision,
        "trust_remote_code": normalized.trust_remote_code,
        "num_gpus": engine.num_gpus,
        "distributed_executor_backend": engine.execution_backend,
        "tp_size": engine.parallelism.tp_size,
        "sp_size": engine.parallelism.sp_size,
        "hsdp_replicate_dim": engine.parallelism.hsdp_replicate_dim,
        "hsdp_shard_dim": engine.parallelism.hsdp_shard_dim,
        "dist_timeout": engine.parallelism.dist_timeout,
        "dit_cpu_offload": engine.offload.dit,
        "dit_layerwise_offload": engine.offload.dit_layerwise,
        "text_encoder_cpu_offload": engine.offload.text_encoder,
        "image_encoder_cpu_offload": engine.offload.image_encoder,
        "vae_cpu_offload": engine.offload.vae,
        "pin_cpu_memory": engine.offload.pin_cpu_memory,
        "enable_torch_compile": engine.compile.enabled,
        # Copied so FastVideoArgs cannot mutate the config's compile kwargs.
        "torch_compile_kwargs": deepcopy(engine.compile.kwargs),
        "enable_stage_verification": engine.enable_stage_verification,
        "use_fsdp_inference": engine.use_fsdp_inference,
        "disable_autocast": engine.disable_autocast,
    }
    # Optional fields are added only when set so FastVideoArgs defaults apply.
    if normalized.pipeline.workload_type is not None:
        kwargs["workload_type"] = normalized.pipeline.workload_type

    quantization = engine.quantization
    if quantization is not None and quantization.text_encoder_quant is not None:
        kwargs["override_text_encoder_quant"] = quantization.text_encoder_quant
    if quantization is not None and quantization.transformer_quant is not None:
        kwargs["transformer_quant"] = quantization.transformer_quant

    components = normalized.pipeline.components
    if components.pipeline_config_path is not None:
        kwargs["pipeline_config"] = components.pipeline_config_path
    if components.lora_path is not None:
        kwargs["lora_path"] = components.lora_path
    if components.override_pipeline_cls_name is not None:
        kwargs["override_pipeline_cls_name"] = components.override_pipeline_cls_name
    if components.override_transformer_cls_name is not None:
        kwargs["override_transformer_cls_name"] = components.override_transformer_cls_name
    if components.text_encoder_weights is not None:
        kwargs["override_text_encoder_safetensors"] = components.text_encoder_weights
    if components.transformer_weights is not None:
        kwargs["init_weights_from_safetensors"] = components.transformer_weights
    if components.transformer_2_weights is not None:
        kwargs["init_weights_from_safetensors_2"] = components.transformer_2_weights

    # Profile overrides first, then experimental: experimental wins on clashes.
    kwargs.update(deepcopy(normalized.pipeline.profile_overrides))
    kwargs.update(deepcopy(normalized.pipeline.experimental))
    return FastVideoArgs.from_kwargs(**kwargs)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def normalize_generation_request(request: GenerationRequest | Mapping[str, Any], ) -> GenerationRequest:
    """Coerce *request* to a GenerationRequest and record its explicit payload."""
    if isinstance(request, GenerationRequest):
        result = request
    else:
        result = parse_config(GenerationRequest, request)

    # Snapshot the explicit payload once so defaults stay distinguishable later.
    if not hasattr(result, _EXPLICIT_REQUEST_ATTR):
        setattr(result, _EXPLICIT_REQUEST_ATTR, _serialize_generation_request(result))
    return result
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def legacy_generate_call_to_request(
    prompt: str | None,
    sampling_param: SamplingParam | None,
    *,
    mouse_cond: Any | None = None,
    keyboard_cond: Any | None = None,
    grid_sizes: Any | None = None,
    legacy_kwargs: Mapping[str, Any] | None = None,
) -> GenerationRequest:
    """Build a GenerationRequest from a legacy ``generate_video``-style call."""
    payload = _sampling_param_to_request_raw(sampling_param)
    if prompt is not None:
        payload["prompt"] = prompt

    for name, value in (legacy_kwargs or {}).items():
        _apply_request_field(payload, name, value)

    # Conditioning tensors always live under the "inputs" section.
    conditioning = (
        ("mouse_cond", mouse_cond),
        ("keyboard_cond", keyboard_cond),
        ("grid_sizes", grid_sizes),
    )
    for name, value in conditioning:
        if value is not None:
            payload.setdefault("inputs", {})[name] = value

    request = parse_config(GenerationRequest, payload)
    # Keep a copy of exactly what the caller asked for.
    setattr(request, _EXPLICIT_REQUEST_ATTR, deepcopy(payload))
    return request
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def request_to_sampling_param(
    request: GenerationRequest,
    *,
    model_path: str,
) -> SamplingParam:
    """Project a GenerationRequest onto the model's default SamplingParam."""
    if request.plan is not None:
        raise NotImplementedError("GenerationRequest.plan is not wired into VideoGenerator yet")
    if request.state is not None:
        raise NotImplementedError("GenerationRequest.state is not wired into VideoGenerator yet")

    sampling_param = SamplingParam.from_pretrained(model_path)

    for key, value in _explicit_request_updates(request).items():
        if hasattr(sampling_param, key):
            setattr(sampling_param, key, deepcopy(value))
            continue
        # Pipeline-level overrides are applied elsewhere; values equal to the
        # request defaults are harmless no-ops. Anything else is unsupported.
        if key in _REQUEST_PIPELINE_OVERRIDE_FIELDS or _is_supported_as_default_only(key, value):
            continue
        raise ValueError(f"Request field {key!r} is not supported by sampling params for {model_path}")

    # Re-run derived-field initialization and validation after the updates.
    sampling_param.__post_init__()
    sampling_param.check_sampling_param()
    return sampling_param
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def expand_request_prompt_batch(request: GenerationRequest, ) -> list[GenerationRequest]:
    """Fan a batched (list-prompt) request out into one request per prompt."""
    prompts = request.prompt
    if not isinstance(prompts, list):
        # Already a single-prompt request; nothing to expand.
        return [request]

    expanded: list[GenerationRequest] = []
    for position, text in enumerate(prompts):
        clone = deepcopy(request)
        clone.prompt = text
        # Per-prompt media inputs and the explicit-payload snapshot follow the prompt.
        for media_field in ("image_path", "video_path"):
            _fan_out_batched_input_value(request, clone, media_field, position)
        _fan_out_explicit_request_metadata(request, clone, position, text)
        expanded.append(clone)
    return expanded
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _looks_like_run_or_serve_config(raw: Mapping[str, Any]) -> bool:
|
| 304 |
+
return isinstance(raw.get("generator"), Mapping)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _normalize_overrides(overrides: list[str] | Mapping[str, Any] | None, ) -> dict[str, Any] | None:
|
| 308 |
+
if not overrides:
|
| 309 |
+
return None
|
| 310 |
+
if isinstance(overrides, list):
|
| 311 |
+
return parse_cli_overrides(overrides)
|
| 312 |
+
return dict(overrides)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def _sampling_param_to_request_raw(sampling_param: SamplingParam | None, ) -> dict[str, Any]:
|
| 316 |
+
if sampling_param is None:
|
| 317 |
+
return {}
|
| 318 |
+
|
| 319 |
+
raw: dict[str, Any] = {}
|
| 320 |
+
for key, value in shallow_asdict(sampling_param).items():
|
| 321 |
+
if key == "prompt":
|
| 322 |
+
continue
|
| 323 |
+
_apply_request_field(raw, key, deepcopy(value))
|
| 324 |
+
return raw
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def _apply_request_field(
    raw: dict[str, Any],
    key: str,
    value: Any,
) -> None:
    """Route one flat legacy field into the matching request payload section."""
    key = _LEGACY_REQUEST_ALIASES.get(key, key)
    if key == "negative_prompt":
        # Lives at the top level of the payload, not inside a section.
        raw["negative_prompt"] = value
        return
    # First matching section wins; anything unrecognized lands in "extensions".
    routing = (
        ("inputs", _INPUT_FIELD_NAMES),
        ("sampling", _SAMPLING_FIELD_NAMES),
        ("runtime", _RUNTIME_FIELD_NAMES),
        ("output", _OUTPUT_FIELD_NAMES),
    )
    for section, names in routing:
        if key in names:
            raw.setdefault(section, {})[key] = value
            return
    raw.setdefault("extensions", {})[key] = value
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def request_to_pipeline_overrides(request: GenerationRequest) -> dict[str, Any]:
    """Extract request fields that act as pipeline-level overrides."""
    return {
        name: deepcopy(value)
        for name, value in _explicit_request_updates(request).items()
        if name in _REQUEST_PIPELINE_OVERRIDE_FIELDS
    }
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _explicit_request_updates(request: GenerationRequest) -> dict[str, Any]:
    """Return the flat update dict for fields the caller explicitly set."""
    payload = getattr(request, _EXPLICIT_REQUEST_ATTR, None)
    # Requests that never went through normalization fall back to a full dump.
    if payload is None:
        payload = _serialize_generation_request(request)
    return _extract_request_updates(payload)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _extract_request_updates(raw: Mapping[str, Any]) -> dict[str, Any]:
|
| 368 |
+
updates: dict[str, Any] = {}
|
| 369 |
+
if "negative_prompt" in raw:
|
| 370 |
+
updates["negative_prompt"] = deepcopy(raw["negative_prompt"])
|
| 371 |
+
|
| 372 |
+
for section_name in ("inputs", "sampling", "runtime", "output"):
|
| 373 |
+
section = raw.get(section_name)
|
| 374 |
+
if not isinstance(section, Mapping):
|
| 375 |
+
continue
|
| 376 |
+
for key, value in section.items():
|
| 377 |
+
updates[key] = deepcopy(value)
|
| 378 |
+
|
| 379 |
+
stage_overrides = raw.get("stage_overrides")
|
| 380 |
+
if stage_overrides:
|
| 381 |
+
updates.update(_flatten_stage_overrides(stage_overrides))
|
| 382 |
+
|
| 383 |
+
extensions = raw.get("extensions")
|
| 384 |
+
if isinstance(extensions, Mapping):
|
| 385 |
+
for key, value in extensions.items():
|
| 386 |
+
updates[key] = deepcopy(value)
|
| 387 |
+
|
| 388 |
+
return updates
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def _flatten_stage_overrides(stage_overrides: Any) -> dict[str, Any]:
|
| 392 |
+
if not isinstance(stage_overrides, Mapping):
|
| 393 |
+
raise ValueError("GenerationRequest.stage_overrides must be a mapping")
|
| 394 |
+
|
| 395 |
+
flattened: dict[str, Any] = {}
|
| 396 |
+
for stage_name, overrides in stage_overrides.items():
|
| 397 |
+
if not isinstance(overrides, Mapping):
|
| 398 |
+
raise ValueError(f"GenerationRequest.stage_overrides.{stage_name} must be a mapping")
|
| 399 |
+
for key, value in overrides.items():
|
| 400 |
+
if key in flattened and flattened[key] != value:
|
| 401 |
+
raise ValueError(f"Conflicting stage override for {key!r} across stages")
|
| 402 |
+
flattened[key] = deepcopy(value)
|
| 403 |
+
return flattened
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _serialize_generation_request(request: GenerationRequest) -> dict[str, Any]:
    """Dump *request* to a plain dict, detached from the live object."""
    snapshot = config_to_dict(request)
    return deepcopy(snapshot)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def _fan_out_batched_input_value(
    source_request: GenerationRequest,
    target_request: GenerationRequest,
    field_name: str,
    index: int,
) -> None:
    """Copy element *index* of a batched input field onto the target request."""
    batched = getattr(source_request.inputs, field_name)
    if isinstance(batched, list):
        # Batched inputs must line up one-to-one with the batched prompts.
        _validate_batched_input_length(source_request.prompt, batched, field_name)
        setattr(target_request.inputs, field_name, deepcopy(batched[index]))
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _fan_out_explicit_request_metadata(
    source_request: GenerationRequest,
    target_request: GenerationRequest,
    index: int,
    prompt: str,
) -> None:
    """Rewrite the explicit-payload snapshot for one prompt of a batch."""
    payload = getattr(source_request, _EXPLICIT_REQUEST_ATTR, None)
    if payload is None:
        # No explicit snapshot was recorded; nothing to fan out.
        return

    payload = deepcopy(payload)
    payload["prompt"] = prompt
    inputs = payload.get("inputs")
    if isinstance(inputs, dict):
        # Batched media paths are narrowed to this prompt's element.
        for media_field in ("image_path", "video_path"):
            batched = inputs.get(media_field)
            if isinstance(batched, list):
                _validate_batched_input_length(source_request.prompt, batched, media_field)
                inputs[media_field] = deepcopy(batched[index])

    setattr(target_request, _EXPLICIT_REQUEST_ATTR, payload)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def _validate_batched_input_length(
|
| 447 |
+
prompts: str | list[str] | None,
|
| 448 |
+
values: list[Any],
|
| 449 |
+
field_name: str,
|
| 450 |
+
) -> None:
|
| 451 |
+
if not isinstance(prompts, list):
|
| 452 |
+
return
|
| 453 |
+
if len(values) != len(prompts):
|
| 454 |
+
raise ValueError(f"GenerationRequest.inputs.{field_name} must have the same length as request.prompt")
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def _is_supported_as_default_only(key: str, value: Any) -> bool:
    """True when *value* merely restates the request default for *key*."""
    baseline = _DEFAULT_REQUEST_UPDATES.get(key, _MISSING)
    if baseline is _MISSING:
        return False
    return _values_equal(value, baseline)
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def _collect_non_default_fields(
    value: Any,
    default: Any,
) -> dict[str, Any]:
    """Recursively diff two dataclass instances, keeping only changed fields."""
    if not (is_dataclass(value) and is_dataclass(default)):
        # Non-dataclass pairs have no field structure to diff.
        return {}

    changed: dict[str, Any] = {}
    for spec in fields(value):
        current = getattr(value, spec.name)
        baseline = getattr(default, spec.name)
        if is_dataclass(current) and is_dataclass(baseline):
            # Recurse; an empty nested diff contributes nothing.
            nested = _collect_non_default_fields(current, baseline)
            if nested:
                changed[spec.name] = nested
        elif not _values_equal(current, baseline):
            changed[spec.name] = deepcopy(current)
    return changed
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def _values_equal(left: Any, right: Any) -> bool:
|
| 484 |
+
if left is right:
|
| 485 |
+
return True
|
| 486 |
+
try:
|
| 487 |
+
return bool(left == right)
|
| 488 |
+
except Exception:
|
| 489 |
+
return False
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
# Baseline updates produced by a default-constructed request; used by
# _is_supported_as_default_only to recognize values that merely restate defaults.
_DEFAULT_REQUEST_UPDATES = _extract_request_updates(config_to_dict(GenerationRequest()))

# Public API of this compatibility module.
__all__ = [
    "generator_config_to_fastvideo_args",
    "legacy_from_pretrained_to_config",
    "legacy_generate_call_to_request",
    "load_generator_config_from_file",
    "normalize_generation_request",
    "normalize_generator_config",
    "request_to_pipeline_overrides",
    "request_to_sampling_param",
]
|
standalone_inference/overlay_files/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Sparse FP4 Attention backend with the independent ours-P quant kernel."""
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import triton
|
| 9 |
+
|
| 10 |
+
from fastvideo_kernel.triton_kernels.quant_utils import (
|
| 11 |
+
fake_quantize_q,
|
| 12 |
+
fake_quantize_kv,
|
| 13 |
+
)
|
| 14 |
+
from fastvideo_kernel.block_sparse_attn_ours_p import block_sparse_attn_ours_p
|
| 15 |
+
from fastvideo.forward_context import get_forward_context
|
| 16 |
+
|
| 17 |
+
from fastvideo.attention.backends.abstract import (
|
| 18 |
+
AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder,
|
| 19 |
+
)
|
| 20 |
+
from fastvideo.attention.backends.video_sparse_attn import (
|
| 21 |
+
VideoSparseAttentionMetadata,
|
| 22 |
+
VideoSparseAttentionMetadataBuilder,
|
| 23 |
+
VSA_TILE_SIZE,
|
| 24 |
+
)
|
| 25 |
+
from fastvideo.distributed import get_sp_group
|
| 26 |
+
from fastvideo.logger import init_logger
|
| 27 |
+
|
| 28 |
+
logger = init_logger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _dense_sdpa_blhd(query, key, value):
|
| 32 |
+
q = query.transpose(1, 2)
|
| 33 |
+
k = key.transpose(1, 2)
|
| 34 |
+
v = value.transpose(1, 2)
|
| 35 |
+
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
| 36 |
+
return out.transpose(1, 2)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _quantize_qkv_bhld(q, k, v):
    """FP4 fake quantize Q/K/V in BHLD layout, same as attn_qat_train.

    Returns three new tensors (fake_q, fake_k, fake_v) with the same shapes
    and dtypes as the inputs; the inputs themselves are not modified.
    """
    # Shapes: q is (B, H, N_Q, D); k/v are (B, H, N_KV, D) — assumed from the
    # stride/grid usage below; TODO confirm against the kernel signatures.
    H = q.shape[1]
    N_Q = q.shape[2]
    N_KV = k.shape[2]
    D = q.shape[3]
    # Quantization group size along the sequence dimension.
    BLOCK = 32

    fake_q = torch.empty_like(q)
    fake_k = torch.empty_like(k)
    fake_v = torch.empty_like(v)

    # One Triton program per (sequence block, batch*head) pair.
    grid_q = (triton.cdiv(N_Q, BLOCK), q.shape[0] * H, 1)
    grid_kv = (triton.cdiv(N_KV, BLOCK), q.shape[0] * H, 1)

    fake_quantize_q[grid_q](
        q, fake_q,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        fake_q.stride(0), fake_q.stride(1), fake_q.stride(2), fake_q.stride(3),
        H, N_Q, BLOCK_M=BLOCK, HEAD_DIM=D, use_global_sf=False,
    )
    # K and V are quantized together in a single kernel launch.
    # NOTE(review): only k/fake_k strides are passed; the kernel presumably
    # assumes v and fake_v share k's memory layout — confirm.
    fake_quantize_kv[grid_kv](
        k, v, fake_k, fake_v,
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        fake_k.stride(0), fake_k.stride(1), fake_k.stride(2), fake_k.stride(3),
        H, N_KV, BLOCK_N=BLOCK, HEAD_DIM=D, use_global_sf=False,
    )
    return fake_q, fake_k, fake_v
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SparseFP4OursPAttentionBackend(AttentionBackend):
    """Backend registry entry wiring the ours-P FP4 sparse kernel into fastvideo."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # Multiples of 32 from 64 through 256 inclusive.
        return [64 + 32 * step for step in range(7)]

    @staticmethod
    def get_name() -> str:
        return "SPARSE_FP4_OURS_P_ATTN"

    @staticmethod
    def get_impl_cls() -> type["SparseFP4OursPAttentionImpl"]:
        return SparseFP4OursPAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
        # Reuses the VSA metadata (tiling indices, sparsity, block sizes).
        return VideoSparseAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
        return VideoSparseAttentionMetadataBuilder
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SparseFP4OursPAttentionImpl(AttentionImpl):
    """Attention impl: FP4 fake-quantized Q/K/V plus block-sparse top-k attention."""

    def __init__(self, num_heads, head_size, causal, softmax_scale,
                 num_kv_heads=None, prefix="", **extra):
        # NOTE(review): num_heads/head_size/causal/softmax_scale are accepted for
        # interface compatibility but never stored — confirm that is intentional.
        self.prefix = prefix
        self.sp_size = get_sp_group().world_size

    def tile(self, x, num_tiles, tile_partition_indices, non_pad_index):
        """Scatter tokens into padded tile order; padded slots stay zero."""
        # Padded extent per dimension: tile count times tile size.
        t_p = num_tiles[0] * VSA_TILE_SIZE[0]
        h_p = num_tiles[1] * VSA_TILE_SIZE[1]
        w_p = num_tiles[2] * VSA_TILE_SIZE[2]
        out = torch.zeros(
            (x.shape[0], t_p * h_p * w_p, x.shape[-2], x.shape[-1]),
            device=x.device, dtype=x.dtype,
        )
        out[:, non_pad_index] = x[:, tile_partition_indices]
        return out

    def untile(self, x, reverse_tile_partition_indices, non_pad_index):
        """Inverse of tile(): drop padding and restore the original token order."""
        return x[:, non_pad_index][:, reverse_tile_partition_indices]

    def _is_force_dense(self) -> bool:
        # Dense-fallback flag carried on the current forward context.
        ctx = get_forward_context()
        return ctx.force_dense

    def preprocess_qkv(self, qkv, attn_metadata):
        """Reorder QKV into tile order unless running dense or without metadata."""
        if attn_metadata is None or self._is_force_dense():
            return qkv
        return self.tile(qkv, attn_metadata.num_tiles,
                         attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(self, output, attn_metadata):
        """Undo preprocess_qkv()'s tiling on the attention output."""
        if attn_metadata is None or self._is_force_dense():
            return output
        return self.untile(output,
                           attn_metadata.reverse_tile_partition_indices,
                           attn_metadata.non_pad_index)

    def forward(self, query, key, value,
                gate_compress_or_metadata=None, attn_metadata=None):
        """Run FP4 block-sparse attention on BLHD tensors.

        Falls back to dense SDPA when force_dense is set, when no metadata is
        available, or for cross-attention.
        """
        # Handle both call conventions
        if attn_metadata is None and isinstance(
                gate_compress_or_metadata, (VideoSparseAttentionMetadata, type(None))):
            attn_metadata = gate_compress_or_metadata

        # ── force_dense: true dense BF16 SDPA (for teacher in distillation) ──
        ctx = get_forward_context()
        if ctx.force_dense:
            return _dense_sdpa_blhd(query, key, value)

        # Differing sequence lengths imply cross-attention.
        is_cross = query.shape[1] != key.shape[1]

        # ── Cross-attention/no metadata: keep dense. The sparse VSA metadata only
        # applies to tiled video self-attention.
        if attn_metadata is None or is_cross:
            return _dense_sdpa_blhd(query, key, value)

        # ── Self-attention: FP4 quant Q/K/V + block-sparse attention ──
        # BLHD → BHLD
        q = query.transpose(1, 2).contiguous()
        k = key.transpose(1, 2).contiguous()
        v = value.transpose(1, 2).contiguous()

        # Step 1: FP4 fake quantize Q/K/V with STE (straight-through estimator)
        with torch.no_grad():
            fq, fk, fv = _quantize_qkv_bhld(q, k, v)
        # STE: forward uses quantized values, backward passes gradient through as-is
        fq = q + (fq - q).detach()
        fk = k + (fk - k).detach()
        fv = v + (fv - v).detach()

        # Step 2: Build sparse block map
        B, H, S, D = fq.shape
        # Tokens per VSA tile; S is assumed to be an exact multiple after tiling.
        block_elements = math.prod(VSA_TILE_SIZE)
        num_blocks = S // block_elements

        # Keep the top-k key blocks per query block; k derived from the sparsity.
        VSA_sparsity = attn_metadata.VSA_sparsity
        cur_topk = max(1, math.ceil((1 - VSA_sparsity) * num_blocks))
        # NOTE(review): info-level logging on every forward call is very chatty;
        # consider debug level.
        logger.info(f"[SFP4] S={S} num_blocks={num_blocks} sparsity={VSA_sparsity} topk={cur_topk}/{num_blocks}")

        # Mean-pool each block (dividing by its true, unpadded size) so block
        # pairs can be scored cheaply.
        block_sizes = attn_metadata.variable_block_sizes.to(
            device=fq.device, dtype=torch.float32).clamp_min(1)
        block_sizes = block_sizes.view(1, 1, num_blocks, 1)
        q_c = (fq.view(B, H, num_blocks, block_elements, D).float().sum(3) /
               block_sizes).to(fq.dtype)
        k_c = (fk.view(B, H, num_blocks, block_elements, D).float().sum(3) /
               block_sizes).to(fk.dtype)
        v_c = (fv.view(B, H, num_blocks, block_elements, D).float().sum(3) /
               block_sizes).to(fv.dtype)
        scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (D ** 0.5)
        topk_idx = torch.topk(scores, cur_topk, dim=-1).indices
        # Boolean (B, H, num_blocks, num_blocks) mask of block pairs to attend.
        block_map = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)

        # Step 3: Block-sparse attention with independent group-local P quant.
        out, _ = block_sparse_attn_ours_p(fq, fk, fv, block_map,
                                          attn_metadata.variable_block_sizes,
                                          q_c, k_c, v_c)

        return out.transpose(1, 2)  # BHLD → BLHD
|
standalone_inference/overlay_files/fastvideo/attention/backends/video_sparse_attn.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
import functools
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from fastvideo_kernel import video_sparse_attn
|
| 10 |
+
except ImportError:
|
| 11 |
+
video_sparse_attn = None
|
| 12 |
+
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
from fastvideo.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata,
|
| 16 |
+
AttentionMetadataBuilder)
|
| 17 |
+
from fastvideo.distributed import get_sp_group
|
| 18 |
+
from fastvideo.logger import init_logger
|
| 19 |
+
|
| 20 |
+
logger = init_logger(__name__)
|
| 21 |
+
VSA_TILE_SIZE = (4, 4, 4)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Flat token indices grouped tile-by-tile in (t, h, w) tile order.

    The result permutes ``range(T*H*W)`` so that tokens belonging to the same
    spatial tile become contiguous. Partial edge tiles are kept (slicing clamps
    at the boundary), matching ceil-division tiling.
    """
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    grid = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
    chunks = [
        grid[t0:t0 + ts, h0:h0 + hs, w0:w0 + ws].flatten()
        for t0 in range(0, T, ts)
        for h0 in range(0, H, hs)
        for w0 in range(0, W, ws)
    ]
    return torch.cat(chunks, dim=0)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@functools.lru_cache(maxsize=10)
def get_reverse_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Return the inverse of :func:`get_tile_partition_indices` for the same arguments.

    Applying the forward permutation followed by this one restores the
    original flattened (T, H, W) token order.
    """
    forward_permutation = get_tile_partition_indices(dit_seq_shape, tile_size, device)
    return torch.argsort(forward_permutation)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@functools.lru_cache(maxsize=10)
def construct_variable_block_sizes(
    dit_seq_shape: tuple[int, int, int],
    num_tiles: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Count the valid (non-padded) tokens inside every VSA tile after padding.

    Tiles are flattened in the same (t-tile, h-tile, w-tile) order used by
    ``get_tile_partition_indices``, so entry ``i`` gives the true token count
    of padded block ``i``.

    Returns
    -------
    torch.Tensor
        1-D tensor of shape [prod(num_tiles)] (int32, despite the annotation —
        preserved for the downstream kernel).
    """
    dim_t, dim_h, dim_w = dit_seq_shape
    tile_t, tile_h, tile_w = VSA_TILE_SIZE
    tiles_t, tiles_h, tiles_w = num_tiles

    def _per_dim_sizes(dim_len: int, tile: int, n_tiles: int) -> torch.Tensor:
        # Every tile along this axis is full except (possibly) the last one.
        sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
        last = dim_len - (n_tiles - 1) * tile
        # NOTE(review): `last` should always be positive when
        # n_tiles == ceil(dim_len / tile); the fall-back to a full tile
        # mirrors the original defensive behaviour.
        sizes[-1] = last if last > 0 else tile
        return sizes

    sizes_t = _per_dim_sizes(dim_t, tile_t, tiles_t)  # [tiles_t]
    sizes_h = _per_dim_sizes(dim_h, tile_h, tiles_h)  # [tiles_h]
    sizes_w = _per_dim_sizes(dim_w, tile_w, tiles_w)  # [tiles_w]

    # Outer product over the three axes gives voxels-per-tile, flattened in
    # (t, h, w) tile order.
    per_tile = sizes_t[:, None, None] * sizes_h[None, :, None] * sizes_w[None, None, :]
    return per_tile.reshape(-1)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@functools.lru_cache(maxsize=10)
def get_non_pad_index(
    variable_block_sizes: torch.LongTensor,
    max_block_size: int,
):
    """Indices of the real tokens inside a padded fixed-slot block layout.

    The padded layout stores each of the n blocks in a slot of
    ``max_block_size`` tokens; block ``i`` only fills its first
    ``variable_block_sizes[i]`` positions. The returned 1-D index selects
    exactly those filled positions, in block order.

    NOTE: lru_cache keys the tensor argument by identity; this works because
    ``construct_variable_block_sizes`` is itself cached and hands back the
    same tensor object for a repeated shape.
    """
    num_blocks = variable_block_sizes.shape[0]
    device = variable_block_sizes.device
    offsets = torch.arange(max_block_size, device=device)
    slot_starts = torch.arange(num_blocks, device=device) * max_block_size
    padded_positions = slot_starts[:, None] + offsets[None, :]  # [n, max_block_size]
    valid = offsets[None, :] < variable_block_sizes[:, None]    # mask of real tokens
    return padded_positions[valid]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class VideoSparseAttentionBackend(AttentionBackend):
    """Registry entry for the Video Sparse Attention (VSA) backend.

    Purely declarative: exposes the impl / metadata / builder classes and the
    head sizes this backend supports. All logic lives in the returned classes.
    """

    # The attention call may write directly into a caller-provided buffer.
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # Head dims the underlying fastvideo_kernel op is compiled for.
        return [64, 128]

    @staticmethod
    def get_name() -> str:
        return "VIDEO_SPARSE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["VideoSparseAttentionImpl"]:
        return VideoSparseAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
        return VideoSparseAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
        return VideoSparseAttentionMetadataBuilder
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@dataclass
class VideoSparseAttentionMetadata(AttentionMetadata):
    """Per-denoising-step metadata consumed by ``VideoSparseAttentionImpl``.

    Built once per step by ``VideoSparseAttentionMetadataBuilder``; the
    tile/padding index tensors are cached across steps for a given shape.
    """

    current_timestep: int
    # (T, H, W) token-grid shape after patchification. NOTE(review): the
    # builder passes tuples despite the list annotations — annotations kept
    # for compatibility with existing type-checker suppressions.
    dit_seq_shape: list[int]
    # Number of VSA tiles along each of (T, H, W).
    num_tiles: list[int]
    # Total token count: prod(dit_seq_shape).
    total_seq_length: int
    # Permutation making each tile's tokens contiguous, and its inverse.
    tile_partition_indices: torch.LongTensor
    reverse_tile_partition_indices: torch.LongTensor
    # Valid (non-padded) token count per tile, in tile order.
    variable_block_sizes: torch.LongTensor
    # Flat indices of real tokens within the padded fixed-slot layout.
    non_pad_index: torch.LongTensor
    # Fraction of key/value tiles to drop; previously passed by the builder
    # and read by the impl without being declared here — declared explicitly
    # so the contract is visible.
    VSA_sparsity: float
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder):
    """Builds :class:`VideoSparseAttentionMetadata` for one denoising step."""

    def __init__(self) -> None:
        pass

    def prepare(self) -> None:
        # Nothing to reset between steps; index tensors are cached by shape.
        pass

    def build(  # type: ignore
        self,
        current_timestep: int,
        raw_latent_shape: tuple[int, int, int],
        patch_size: tuple[int, int, int],
        VSA_sparsity: float,
        device: torch.device,
        **kwargs: dict[str, Any],
    ) -> VideoSparseAttentionMetadata:
        """Derive tile geometry and (cached) index tensors for this step.

        Args:
            current_timestep: index of the current denoising step.
            raw_latent_shape: latent (T, H, W) before patchification.
            patch_size: DiT patch size per (T, H, W) axis.
            VSA_sparsity: fraction of tiles to prune in the sparse kernel.
            device: device on which to materialize the index tensors.
        """
        # (Removed a no-op `patch_size = patch_size` self-assignment.)
        dit_seq_shape = (raw_latent_shape[0] // patch_size[0], raw_latent_shape[1] // patch_size[1],
                         raw_latent_shape[2] // patch_size[2])

        num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]), math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
                     math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
        total_seq_length = math.prod(dit_seq_shape)

        # All four helpers are lru_cache'd, so repeated steps with the same
        # shape reuse the same tensors (which also keys get_non_pad_index).
        tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
        reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
        variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
        non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))

        return VideoSparseAttentionMetadata(
            current_timestep=current_timestep,
            dit_seq_shape=dit_seq_shape,  # type: ignore
            VSA_sparsity=VSA_sparsity,  # type: ignore
            num_tiles=num_tiles,  # type: ignore
            total_seq_length=total_seq_length,  # type: ignore
            tile_partition_indices=tile_partition_indices,  # type: ignore
            reverse_tile_partition_indices=reverse_tile_partition_indices,
            variable_block_sizes=variable_block_sizes,
            non_pad_index=non_pad_index)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class VideoSparseAttentionImpl(AttentionImpl):
    """Video Sparse Attention: block-sparse attention over padded 3-D token tiles.

    Tokens are reordered tile-by-tile and zero-padded so every tile holds
    exactly prod(VSA_TILE_SIZE) tokens, then handed to the fastvideo_kernel
    ``video_sparse_attn`` op, which attends only over the top-k tiles.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # NOTE(review): num_heads/head_size/causal/softmax_scale are accepted
        # for interface parity with other backends but are not used here —
        # presumably the kernel derives them from the tensors; confirm against
        # fastvideo_kernel before relying on non-default values.
        self.prefix = prefix
        sp_group = get_sp_group()
        # Sequence-parallel world size (stored but unused in this block).
        self.sp_size = sp_group.world_size

    def tile(self, x: torch.Tensor, num_tiles: list[int], tile_partition_indices: torch.LongTensor,
             non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Reorder tokens tile-by-tile and zero-pad every tile to full size.

        Assumes x is 4-D with the sequence on dim 1 (presumably
        [batch, seq, heads, head_dim] — TODO confirm). The padded sequence
        length is prod(num_tiles) * prod(VSA_TILE_SIZE).
        """
        t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
        h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
        w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]

        x_padded = torch.zeros((x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
                               device=x.device,
                               dtype=x.dtype)
        # Gather tokens into tile order, then scatter them into the
        # non-padded slots; padding positions stay zero.
        x_padded[:, non_pad_index] = x[:, tile_partition_indices]
        return x_padded

    def untile(self, x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor,
               non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Inverse of :meth:`tile`: drop padding, then undo the tile permutation."""
        x = x[:, non_pad_index][:, reverse_tile_partition_indices]
        return x

    def preprocess_qkv(
        self,
        qkv: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        # Applied to q, k and v alike before attention.
        return self.tile(qkv, attn_metadata.num_tiles, attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        # Restore the original flattened token order after attention.
        return self.untile(output, attn_metadata.reverse_tile_partition_indices, attn_metadata.non_pad_index)

    def forward(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        gate_compress: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        """Run the sparse-attention kernel on pre-tiled q/k/v.

        Inputs must already have gone through :meth:`preprocess_qkv`.
        gate_compress weights the kernel's compressed-attention branch.
        """
        # Swap dims 1 and 2 into the layout the kernel expects; contiguous()
        # because the kernel requires densely-packed inputs.
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()
        gate_compress = gate_compress.transpose(1, 2).contiguous()

        # Provided by VideoSparseAttentionMetadataBuilder.build().
        VSA_sparsity = attn_metadata.VSA_sparsity

        # Keep the densest (1 - sparsity) fraction of the tiles.
        cur_topk = math.ceil((1 - VSA_sparsity) * (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))

        # video_sparse_attn is None when fastvideo_kernel failed to import.
        if video_sparse_attn is None:
            raise NotImplementedError("video_sparse_attn is not installed")
        hidden_states = video_sparse_attn(query,
                                          key,
                                          value,
                                          attn_metadata.variable_block_sizes,
                                          attn_metadata.variable_block_sizes,
                                          cur_topk,
                                          block_size=VSA_TILE_SIZE,
                                          compress_attn_weight=gate_compress).transpose(1, 2)

        return hidden_states
|
standalone_inference/overlay_files/fastvideo/configs/models/dits/base.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from fastvideo.configs.models.base import ArchConfig, ModelConfig
|
| 6 |
+
from fastvideo.layers.quantization import QuantizationConfig
|
| 7 |
+
from fastvideo.platforms import AttentionBackendEnum
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class DiTArchConfig(ArchConfig):
    """Architecture-level (checkpoint-derived) hyperparameters for a DiT model."""

    # Predicates selecting which submodules FSDP shards / torch.compile wraps.
    _fsdp_shard_conditions: list = field(default_factory=list)
    _compile_conditions: list = field(default_factory=list)
    # Checkpoint <-> FastVideo parameter-name translation tables.
    param_names_mapping: dict = field(default_factory=dict)
    reverse_param_names_mapping: dict = field(default_factory=dict)
    lora_param_names_mapping: dict = field(default_factory=dict)
    # Attention backends this architecture is known to work with.
    _supported_attention_backends: tuple[AttentionBackendEnum,
                                         ...] = (AttentionBackendEnum.SAGE_ATTN, AttentionBackendEnum.FLASH_ATTN,
                                                 AttentionBackendEnum.TORCH_SDPA,
                                                 AttentionBackendEnum.VIDEO_SPARSE_ATTN,
                                                 AttentionBackendEnum.VMOBA_ATTN, AttentionBackendEnum.SAGE_ATTN_THREE,
                                                 AttentionBackendEnum.ATTN_QAT_INFER,
                                                 AttentionBackendEnum.ATTN_QAT_TRAIN, AttentionBackendEnum.SLA_ATTN,
                                                 AttentionBackendEnum.SAGE_SLA_ATTN,
                                                 AttentionBackendEnum.SPARSE_FP4_ATTN,
                                                 AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN)

    hidden_size: int = 0
    num_attention_heads: int = 0
    num_channels_latents: int = 0
    in_channels: int | None = 0
    out_channels: int | None = 0
    patch_size: int | tuple[int, int, int] | None = None
    expand_timesteps: bool = False
    num_layers: int = 0
    ffn_dim: int = 0
    # Layer names excluded from LoRA adaptation.
    exclude_lora_layers: list[str] = field(default_factory=list)
    # NOTE(review): presumably the Wan2.2 two-expert timestep boundary;
    # None for single-expert models — confirm against the pipeline configs.
    boundary_ratio: float | None = None

    def __post_init__(self) -> None:
        # By default, compile exactly the modules that are FSDP-sharded.
        if not self._compile_conditions:
            self._compile_conditions = self._fsdp_shard_conditions.copy()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
class DiTConfig(ModelConfig):
    """User-facing DiT model configuration wrapping a :class:`DiTArchConfig`."""

    arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)

    # FastVideoDiT-specific parameters
    prefix: str = ""
    quant_config: QuantizationConfig | None = None
    expand_timesteps: bool = False
    boundary_ratio: float | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        # Mirror the top-level toggles onto the arch config so model code
        # only ever needs to consult arch_config.
        self.arch_config.expand_timesteps = self.expand_timesteps
        self.arch_config.boundary_ratio = self.boundary_ratio

    @staticmethod
    def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any:
        """Add CLI arguments for DiTConfig fields.

        Args:
            parser: an argparse-style parser (anything with add_argument).
            prefix: CLI namespace prefix; dashes become underscores in dest.

        Returns:
            The same parser, for chaining.
        """
        parser.add_argument(
            f"--{prefix}.prefix",
            type=str,
            dest=f"{prefix.replace('-', '_')}.prefix",
            default=DiTConfig.prefix,
            help="Prefix for the DiT model",
        )

        parser.add_argument(
            f"--{prefix}.quant-config",
            type=str,
            dest=f"{prefix.replace('-', '_')}.quant_config",
            default=None,
            help="Quantization configuration for the DiT model",
        )

        return parser
|
standalone_inference/overlay_files/fastvideo/configs/pipelines/wan.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
from collections.abc import Callable
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
|
| 8 |
+
from fastvideo.configs.models.dits import WanVideoConfig
|
| 9 |
+
from fastvideo.configs.models.dits.matrixgame import MatrixGameWanVideoConfig
|
| 10 |
+
from fastvideo.configs.models.encoders import (BaseEncoderOutput, CLIPVisionConfig, T5Config,
|
| 11 |
+
WAN2_1ControlCLIPVisionConfig)
|
| 12 |
+
from fastvideo.configs.models.vaes import WanVAEConfig
|
| 13 |
+
from fastvideo.configs.pipelines.base import PipelineConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def t5_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
|
| 17 |
+
mask: torch.Tensor = outputs.attention_mask
|
| 18 |
+
hidden_state: torch.Tensor = outputs.last_hidden_state
|
| 19 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
| 20 |
+
assert torch.isnan(hidden_state).sum() == 0
|
| 21 |
+
prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]
|
| 22 |
+
prompt_embeds_tensor: torch.Tensor = torch.stack(
|
| 23 |
+
[torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0)
|
| 24 |
+
return prompt_embeds_tensor
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class WanT2V480PConfig(PipelineConfig):
    """Base configuration for Wan T2V 1.3B pipeline architecture."""

    # WanConfig-specific parameters with defaults
    # DiT
    dit_config: DiTConfig = field(default_factory=WanVideoConfig)
    # VAE
    vae_config: VAEConfig = field(default_factory=WanVAEConfig)
    vae_tiling: bool = False
    vae_sp: bool = False

    # Denoising stage
    flow_shift: float | None = 3.0

    # Text encoding stage: one T5 encoder, post-processed by t5_postprocess_text
    # (parallel tuples — index i of each belongs together).
    text_encoder_configs: tuple[EncoderConfig, ...] = field(default_factory=lambda: (T5Config(), ))
    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
                                  ...] = field(default_factory=lambda: (t5_postprocess_text, ))

    # Precision for each component
    precision: str = "bf16"
    vae_precision: str = "fp32"
    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32", ))

    # self-forcing params
    warp_denoising_step: bool = True

    # WanConfig-specific added parameters

    def __post_init__(self):
        # T2V only decodes latents to pixels: skip loading the VAE encoder.
        self.vae_config.load_encoder = False
        self.vae_config.load_decoder = True
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
class WanT2V720PConfig(WanT2V480PConfig):
    """Base configuration for Wan T2V 14B 720P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Denoising stage (higher shift than the 480P variant).
    flow_shift: float | None = 5.0
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass
class WanI2V480PConfig(WanT2V480PConfig):
    """Base configuration for Wan I2V 14B 480P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Precision for each component
    image_encoder_config: EncoderConfig = field(default_factory=CLIPVisionConfig)
    image_encoder_precision: str = "fp32"

    def __post_init__(self) -> None:
        # I2V conditions on an input image, so the VAE encoder is needed too
        # (unlike the T2V base, which only loads the decoder).
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@dataclass
class WanI2V720PConfig(WanI2V480PConfig):
    """Base configuration for Wan I2V 14B 720P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Denoising stage (higher shift than the 480P variant).
    flow_shift: float | None = 5.0
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@dataclass
class WANV2VConfig(WanI2V480PConfig):
    """Configuration for WAN2.1 1.3B Control pipeline."""

    # Control variant swaps in the control-specific CLIP vision encoder.
    image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)
    # CLIP encoder precision (bf16 here, unlike the fp32 I2V base).
    image_encoder_precision: str = 'bf16'
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclass
class FastWan2_1_T2V_480P_Config(WanT2V480PConfig):
    """Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD"""

    # WanConfig-specific parameters with defaults

    # Denoising stage
    flow_shift: float | None = 8.0
    # Fixed 3-step DMD (distillation) timestep schedule.
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@dataclass
class Wan2_2_TI2V_5B_Config(WanT2V480PConfig):
    """Configuration for the Wan2.2 TI2V 5B pipeline (text+image to video)."""

    flow_shift: float | None = 5.0
    ti2v_task: bool = True
    expand_timesteps: bool = True

    def __post_init__(self) -> None:
        # TI2V also encodes a conditioning image, so both VAE halves load.
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True
        # Propagate the timestep-expansion flag down to the DiT config.
        self.dit_config.expand_timesteps = self.expand_timesteps
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@dataclass
class FastWan2_2_TI2V_5B_Config(Wan2_2_TI2V_5B_Config):
    """FastWan (DMD-distilled) variant of the Wan2.2 TI2V 5B pipeline."""

    flow_shift: float | None = 5.0
    # Fixed 3-step DMD timestep schedule.
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@dataclass
class Wan2_2_T2V_A14B_Config(WanT2V480PConfig):
    """Configuration for the Wan2.2 T2V A14B (two-expert) pipeline."""

    flow_shift: float | None = 12.0
    # Timestep fraction handled by the high-noise expert.
    boundary_ratio: float | None = 0.875

    # self-forcing params
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 750, 500, 250])
    warp_denoising_step: bool = True

    def __post_init__(self) -> None:
        # NOTE(review): does not call super().__post_init__(), so the parent's
        # vae_config.load_encoder/load_decoder setup is skipped here — confirm
        # this is intentional (sibling Wan2_2_I2V_A14B_Config does call super).
        self.dit_config.boundary_ratio = self.boundary_ratio
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@dataclass
class Wan2_2_I2V_A14B_Config(WanI2V480PConfig):
    """Configuration for the Wan2.2 I2V A14B (two-expert) pipeline."""

    flow_shift: float | None = 5.0
    # Timestep fraction handled by the high-noise expert.
    boundary_ratio: float | None = 0.900

    def __post_init__(self) -> None:
        # Keep the I2V VAE setup from the parent, then propagate the boundary.
        super().__post_init__()
        self.dit_config.boundary_ratio = self.boundary_ratio
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# =============================================
|
| 159 |
+
# ============= Causal Self-Forcing =============
|
| 160 |
+
# =============================================
|
| 161 |
+
@dataclass
class SelfForcingWanT2V480PConfig(WanT2V480PConfig):
    """Causal self-forcing variant of the Wan T2V 480P pipeline."""

    is_causal: bool = True
    flow_shift: float | None = 5.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 750, 500, 250])
    warp_denoising_step: bool = True
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@dataclass
class SelfForcingWan2_2_T2V480PConfig(Wan2_2_T2V_A14B_Config):
    """Causal self-forcing variant of the Wan2.2 T2V A14B pipeline."""

    is_causal: bool = True
    flow_shift: float | None = 12.0
    boundary_ratio: float | None = 0.875
    # Denser 8-step DMD schedule than the parent's 4-step one.
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 850, 700, 550, 350, 275, 200, 125])
    warp_denoising_step: bool = True

    def __post_init__(self) -> None:
        # NOTE(review): does not call super().__post_init__(), so the parent's
        # `dit_config.boundary_ratio = self.boundary_ratio` propagation is
        # skipped — confirm whether that is intentional.
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# =============================================
|
| 183 |
+
# ============= Matrix Game ===================
|
| 184 |
+
# =============================================
|
| 185 |
+
@dataclass
class MatrixGameBaseI2V480PConfig(WanI2V480PConfig):
    """Base I2V 480P configuration for the Matrix-Game DiT."""

    dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)
    flow_shift: float | None = 5.0
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@dataclass
class MatrixGameI2V480PConfig(WanI2V480PConfig):
    """Causal Matrix-Game I2V 480P configuration (action-conditioned)."""

    dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)

    image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)

    is_causal: bool = True
    flow_shift: float | None = 5.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 666, 333])
    warp_denoising_step: bool = True
    # Noise level applied to already-generated context frames.
    context_noise: int = 0
    # Frames generated per causal block.
    num_frames_per_block: int = 3
    # sliding_window_num_frames: int = 15
|
standalone_inference/overlay_files/fastvideo/configs/sample/base.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from fastvideo.logger import init_logger
|
| 6 |
+
from fastvideo.utils import StoreBoolean
|
| 7 |
+
|
| 8 |
+
logger = init_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class SamplingParam:
|
| 13 |
+
"""
|
| 14 |
+
Sampling parameters for video generation.
|
| 15 |
+
"""
|
| 16 |
+
# All fields below are copied from ForwardBatch
|
| 17 |
+
data_type: str = "video"
|
| 18 |
+
|
| 19 |
+
# Image inputs
|
| 20 |
+
image_path: str | None = None
|
| 21 |
+
pil_image: Any | None = None
|
| 22 |
+
|
| 23 |
+
# Video inputs
|
| 24 |
+
video_path: str | None = None
|
| 25 |
+
|
| 26 |
+
# Action control inputs (Matrix-Game)
|
| 27 |
+
mouse_cond: Any | None = None # Shape: (B, T, 2)
|
| 28 |
+
keyboard_cond: Any | None = None # Shape: (B, T, K)
|
| 29 |
+
grid_sizes: Any | None = None # Shape: (3,) [F,H,W]
|
| 30 |
+
|
| 31 |
+
# Camera control inputs (HYWorld)
|
| 32 |
+
pose: str | None = None # Camera trajectory: pose string (e.g., 'w-31') or JSON file path
|
| 33 |
+
|
| 34 |
+
# Camera control inputs (LingBotWorld)
|
| 35 |
+
c2ws_plucker_emb: Any | None = None # Plucker embedding: [B, C, F_lat, H_lat, W_lat]
|
| 36 |
+
|
| 37 |
+
# Refine inputs (LongCat 480p->720p upscaling)
|
| 38 |
+
# Path-based refine (load stage1 video from disk, e.g. MP4)
|
| 39 |
+
refine_from: str | None = None # Path to stage1 video (480p output from distill)
|
| 40 |
+
t_thresh: float = 0.5 # Threshold for timestep scheduling in refinement
|
| 41 |
+
spatial_refine_only: bool = False # If True, only spatial (no temporal doubling)
|
| 42 |
+
num_cond_frames: int = 0 # Number of conditioning frames
|
| 43 |
+
# In-memory refine input (for two-stage pipeline where stage1 frames are already in memory)
|
| 44 |
+
# This mirrors LongCat's demo where a list of frames (e.g. np.ndarray or PIL.Image)
|
| 45 |
+
# is passed directly to the refinement pipeline instead of reloading from disk.
|
| 46 |
+
stage1_video: Any | None = None
|
| 47 |
+
|
| 48 |
+
# Text inputs
|
| 49 |
+
prompt: str | list[str] | None = None
|
| 50 |
+
negative_prompt: str | None = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
| 51 |
+
prompt_path: str | None = None
|
| 52 |
+
output_path: str = "outputs/"
|
| 53 |
+
output_video_name: str | None = None
|
| 54 |
+
|
| 55 |
+
# Batch info
|
| 56 |
+
num_videos_per_prompt: int = 1
|
| 57 |
+
seed: int = 1024
|
| 58 |
+
|
| 59 |
+
# Original dimensions (before VAE scaling)
|
| 60 |
+
num_frames: int = 125
|
| 61 |
+
height: int = 720
|
| 62 |
+
width: int = 1280
|
| 63 |
+
height_sr: int = 1072
|
| 64 |
+
width_sr: int = 1920
|
| 65 |
+
fps: int = 24
|
| 66 |
+
|
| 67 |
+
# Denoising parameters
|
| 68 |
+
num_inference_steps: int = 50
|
| 69 |
+
num_inference_steps_sr: int = 50
|
| 70 |
+
guidance_scale: float = 1.0
|
| 71 |
+
guidance_scale_2: float | None = None
|
| 72 |
+
guidance_rescale: float = 0.0
|
| 73 |
+
boundary_ratio: float | None = None
|
| 74 |
+
sigmas: list[float] | None = None
|
| 75 |
+
|
| 76 |
+
# TeaCache parameters
|
| 77 |
+
enable_teacache: bool = False
|
| 78 |
+
|
| 79 |
+
# GEN3C camera control
|
| 80 |
+
trajectory_type: str | None = None
|
| 81 |
+
movement_distance: float | None = None
|
| 82 |
+
camera_rotation: str | None = None
|
| 83 |
+
|
| 84 |
+
# Misc
|
| 85 |
+
save_video: bool = True
|
| 86 |
+
return_frames: bool = True
|
| 87 |
+
return_trajectory_latents: bool = False # returns all latents for each timestep
|
| 88 |
+
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
|
| 89 |
+
|
| 90 |
+
def __post_init__(self) -> None:
|
| 91 |
+
self.data_type = "video" if self.num_frames > 1 else "image"
|
| 92 |
+
|
| 93 |
+
def __getattr__(self, name: str) -> Any:
|
| 94 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
| 95 |
+
|
| 96 |
+
def check_sampling_param(self) -> None:
|
| 97 |
+
if self.prompt_path and not self.prompt_path.endswith(".txt"):
|
| 98 |
+
raise ValueError("prompt_path must be a txt file")
|
| 99 |
+
|
| 100 |
+
def update(self, source_dict: dict[str, Any]) -> None:
|
| 101 |
+
for key, value in source_dict.items():
|
| 102 |
+
if hasattr(self, key):
|
| 103 |
+
setattr(self, key, value)
|
| 104 |
+
else:
|
| 105 |
+
logger.exception("%s has no attribute %s", type(self).__name__, key)
|
| 106 |
+
|
| 107 |
+
self.__post_init__()
|
| 108 |
+
|
| 109 |
+
@classmethod
|
| 110 |
+
def from_pretrained(cls, model_path: str) -> "SamplingParam":
|
| 111 |
+
from fastvideo.registry import get_sampling_param_cls_for_name
|
| 112 |
+
sampling_cls = get_sampling_param_cls_for_name(model_path)
|
| 113 |
+
if sampling_cls is not None:
|
| 114 |
+
sampling_param: SamplingParam = sampling_cls()
|
| 115 |
+
else:
|
| 116 |
+
logger.warning("Couldn't find an optimal sampling param for %s. Using the default sampling param.",
|
| 117 |
+
model_path)
|
| 118 |
+
sampling_param = cls()
|
| 119 |
+
|
| 120 |
+
return sampling_param
|
| 121 |
+
|
| 122 |
+
@staticmethod
|
| 123 |
+
def add_cli_args(parser: Any) -> Any:
|
| 124 |
+
"""Add CLI arguments for SamplingParam fields"""
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--prompt",
|
| 127 |
+
type=str,
|
| 128 |
+
default=SamplingParam.prompt,
|
| 129 |
+
help="Text prompt for video generation",
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
"--negative-prompt",
|
| 133 |
+
type=str,
|
| 134 |
+
default=SamplingParam.negative_prompt,
|
| 135 |
+
help="Negative text prompt for video generation",
|
| 136 |
+
)
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--prompt-path",
|
| 139 |
+
type=str,
|
| 140 |
+
default=SamplingParam.prompt_path,
|
| 141 |
+
help="Path to a text file containing the prompt",
|
| 142 |
+
)
|
| 143 |
+
parser.add_argument(
|
| 144 |
+
"--output-path",
|
| 145 |
+
type=str,
|
| 146 |
+
default=SamplingParam.output_path,
|
| 147 |
+
help="Path to save the generated video",
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--output-video-name",
|
| 151 |
+
type=str,
|
| 152 |
+
default=SamplingParam.output_video_name,
|
| 153 |
+
help="Name of the output video",
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--num-videos-per-prompt",
|
| 157 |
+
type=int,
|
| 158 |
+
default=SamplingParam.num_videos_per_prompt,
|
| 159 |
+
help="Number of videos to generate per prompt",
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--seed",
|
| 163 |
+
type=int,
|
| 164 |
+
default=SamplingParam.seed,
|
| 165 |
+
help="Random seed for generation",
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--num-frames",
|
| 169 |
+
type=int,
|
| 170 |
+
default=SamplingParam.num_frames,
|
| 171 |
+
help="Number of frames to generate",
|
| 172 |
+
)
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--height",
|
| 175 |
+
type=int,
|
| 176 |
+
default=SamplingParam.height,
|
| 177 |
+
help="Height of generated video",
|
| 178 |
+
)
|
| 179 |
+
parser.add_argument(
|
| 180 |
+
"--width",
|
| 181 |
+
type=int,
|
| 182 |
+
default=SamplingParam.width,
|
| 183 |
+
help="Width of generated video",
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--fps",
|
| 187 |
+
type=int,
|
| 188 |
+
default=SamplingParam.fps,
|
| 189 |
+
help="Frames per second for saved video",
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--num-inference-steps",
|
| 193 |
+
type=int,
|
| 194 |
+
default=SamplingParam.num_inference_steps,
|
| 195 |
+
help="Number of denoising steps",
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--guidance-scale",
|
| 199 |
+
type=float,
|
| 200 |
+
default=SamplingParam.guidance_scale,
|
| 201 |
+
help="Classifier-free guidance scale",
|
| 202 |
+
)
|
| 203 |
+
parser.add_argument(
|
| 204 |
+
"--guidance-rescale",
|
| 205 |
+
type=float,
|
| 206 |
+
default=SamplingParam.guidance_rescale,
|
| 207 |
+
help="Guidance rescale factor",
|
| 208 |
+
)
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--boundary-ratio",
|
| 211 |
+
type=float,
|
| 212 |
+
default=SamplingParam.boundary_ratio,
|
| 213 |
+
help="Boundary timestep ratio",
|
| 214 |
+
)
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--save-video",
|
| 217 |
+
action="store_true",
|
| 218 |
+
default=SamplingParam.save_video,
|
| 219 |
+
help="Whether to save the video to disk",
|
| 220 |
+
)
|
| 221 |
+
parser.add_argument(
|
| 222 |
+
"--no-save-video",
|
| 223 |
+
action="store_false",
|
| 224 |
+
dest="save_video",
|
| 225 |
+
help="Don't save the video to disk",
|
| 226 |
+
)
|
| 227 |
+
parser.add_argument(
|
| 228 |
+
"--return-frames",
|
| 229 |
+
action="store_true",
|
| 230 |
+
default=False,
|
| 231 |
+
help="Whether to return the raw frames",
|
| 232 |
+
)
|
| 233 |
+
parser.add_argument(
|
| 234 |
+
"--image-path",
|
| 235 |
+
type=str,
|
| 236 |
+
default=SamplingParam.image_path,
|
| 237 |
+
help="Path to input image for image-to-video generation",
|
| 238 |
+
)
|
| 239 |
+
parser.add_argument(
|
| 240 |
+
"--video-path",
|
| 241 |
+
type=str,
|
| 242 |
+
default=SamplingParam.video_path,
|
| 243 |
+
help="Path to input video for video-to-video generation",
|
| 244 |
+
)
|
| 245 |
+
parser.add_argument(
|
| 246 |
+
"--refine-from",
|
| 247 |
+
type=str,
|
| 248 |
+
default=SamplingParam.refine_from,
|
| 249 |
+
help="Path to stage1 video for refinement (LongCat 480p->720p)",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--t-thresh",
|
| 253 |
+
type=float,
|
| 254 |
+
default=SamplingParam.t_thresh,
|
| 255 |
+
help="Threshold for timestep scheduling in refinement (default: 0.5)",
|
| 256 |
+
)
|
| 257 |
+
parser.add_argument(
|
| 258 |
+
"--spatial-refine-only",
|
| 259 |
+
action=StoreBoolean,
|
| 260 |
+
default=SamplingParam.spatial_refine_only,
|
| 261 |
+
help="Only perform spatial super-resolution (no temporal doubling)",
|
| 262 |
+
)
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--num-cond-frames",
|
| 265 |
+
type=int,
|
| 266 |
+
default=SamplingParam.num_cond_frames,
|
| 267 |
+
help="Number of conditioning frames for refinement",
|
| 268 |
+
)
|
| 269 |
+
parser.add_argument(
|
| 270 |
+
"--moba-config-path",
|
| 271 |
+
type=str,
|
| 272 |
+
default=None,
|
| 273 |
+
help="Path to a JSON file containing V-MoBA specific configurations.",
|
| 274 |
+
)
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--return-trajectory-latents",
|
| 277 |
+
action="store_true",
|
| 278 |
+
default=SamplingParam.return_trajectory_latents,
|
| 279 |
+
help="Whether to return the trajectory",
|
| 280 |
+
)
|
| 281 |
+
parser.add_argument(
|
| 282 |
+
"--return-trajectory-decoded",
|
| 283 |
+
action="store_true",
|
| 284 |
+
default=SamplingParam.return_trajectory_decoded,
|
| 285 |
+
help="Whether to return the decoded trajectory",
|
| 286 |
+
)
|
| 287 |
+
return parser
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
@dataclass
class CacheParams:
    """Configuration for denoising-step caching."""
    # "none" disables caching; other accepted values are defined by the
    # consumers of this config -- NOTE(review): enumerate valid cache types.
    cache_type: str = "none"
standalone_inference/overlay_files/fastvideo/configs/sample/wan.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
from fastvideo.configs.sample.base import SamplingParam
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class WanT2V_1_3B_SamplingParam(SamplingParam):
    """Default sampling parameters for the Wan2.1 T2V 1.3B model (480p)."""
    # Video parameters
    height: int = 480
    width: int = 832
    num_frames: int = 81
    fps: int = 16

    # Denoising stage
    guidance_scale: float = 3.0
    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_inference_steps: int = 50
@dataclass
class WanT2V_14B_SamplingParam(SamplingParam):
    """Default sampling parameters for the Wan2.1 T2V 14B model (720p)."""
    # Video parameters
    height: int = 720
    width: int = 1280
    num_frames: int = 81
    fps: int = 16

    # Denoising stage
    guidance_scale: float = 5.0
    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_inference_steps: int = 50
@dataclass
class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParam):
    """Wan2.1 I2V 14B at 480p; inherits geometry from the 1.3B T2V preset."""
    # Denoising stage
    guidance_scale: float = 5.0
    num_inference_steps: int = 40
@dataclass
class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParam):
    """Wan2.1 I2V 14B at 720p; inherits geometry from the 14B T2V preset."""
    # Denoising stage
    guidance_scale: float = 5.0
    num_inference_steps: int = 40
@dataclass
class FastWanT2V480P_SamplingParam(WanT2V_1_3B_SamplingParam):
    """FastWan distilled T2V preset: few-step sampling at 448x832."""
    # DMD parameters -- kept for reference; re-enabling would need
    # ``dataclasses.field`` imported in this module.
    # dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
    num_inference_steps: int = 3
    num_frames: int = 61
    height: int = 448
    width: int = 832
    fps: int = 16
# =============================================
|
| 61 |
+
# ============= Wan2.1 Fun Models =============
|
| 62 |
+
# =============================================
|
| 63 |
+
@dataclass
class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam):
    """Sampling parameters for Wan2.1 Fun 1.3B InP model."""
    height: int = 480
    width: int = 832
    num_frames: int = 81
    fps: int = 16
    # Upstream Chinese-language negative prompt, kept verbatim.
    negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
    guidance_scale: float = 6.0
    num_inference_steps: int = 50
@dataclass
class Wan2_1_Fun_1_3B_Control_SamplingParam(SamplingParam):
    """Sampling parameters for the Wan2.1 Fun 1.3B Control model."""
    fps: int = 16
    num_frames: int = 49
    # NOTE(review): portrait orientation (height 832 x width 480) -- every
    # other 480p preset in this module uses height=480, width=832; confirm
    # this is intentional and not a swapped pair.
    height: int = 832
    width: int = 480
    guidance_scale: float = 6.0
# =============================================
|
| 85 |
+
# ============= Wan2.2 TI2V Models =============
|
| 86 |
+
# =============================================
|
| 87 |
+
@dataclass
class Wan2_2_Base_SamplingParam(SamplingParam):
    """Base sampling parameters shared by the Wan2.2 model presets."""
    # Upstream Chinese-language negative prompt, kept verbatim.
    negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
@dataclass
class Wan2_2_TI2V_5B_SamplingParam(Wan2_2_Base_SamplingParam):
    """Sampling parameters for Wan2.2 TI2V 5B model."""
    height: int = 704
    width: int = 1280
    num_frames: int = 121
    fps: int = 24
    guidance_scale: float = 5.0
    num_inference_steps: int = 50
@dataclass
class Wan2_2_T2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
    """Wan2.2 T2V A14B: separate CFG scales for the two denoising experts."""
    guidance_scale: float = 4.0  # high_noise
    guidance_scale_2: float = 3.0  # low_noise
    num_inference_steps: int = 40
    fps: int = 16
    # NOTE(will): default boundary timestep is tracked by PipelineConfig, but
    # can be overridden during sampling
@dataclass
class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
    """Wan2.2 I2V A14B: separate CFG scales for the two denoising experts."""
    guidance_scale: float = 3.5  # high_noise
    guidance_scale_2: float = 3.5  # low_noise
    num_inference_steps: int = 40
    fps: int = 16
    # NOTE(will): default boundary timestep is tracked by PipelineConfig, but
    # can be overridden during sampling
@dataclass
class Wan2_2_Fun_A14B_Control_SamplingParam(Wan2_1_Fun_1_3B_Control_SamplingParam):
    """Wan2.2 Fun A14B Control: same as 2.1 Fun Control but 81 frames."""
    num_frames: int = 81
# =============================================
|
| 130 |
+
# ============= Causal Self-Forcing =============
|
| 131 |
+
# =============================================
|
| 132 |
+
@dataclass
class SelfForcingWan2_1_T2V_1_3B_480P_SamplingParam(Wan2_1_Fun_1_3B_InP_SamplingParam):
    """Causal self-forcing Wan2.1 1.3B preset; reuses the Fun InP defaults."""
    pass
@dataclass
class SelfForcingWan2_2_T2V_A14B_480P_SamplingParam(Wan2_2_T2V_A14B_SamplingParam):
    """Causal self-forcing Wan2.2 A14B preset: few-step 448x832 sampling."""
    num_inference_steps: int = 8
    num_frames: int = 81
    height: int = 448
    width: int = 832
    fps: int = 16
@dataclass
class MatrixGame2_SamplingParam(SamplingParam):
    """MatrixGame2 preset: guidance disabled (scale 1.0), few-step sampling."""
    height: int = 352
    width: int = 640
    num_frames: int = 57
    fps: int = 25
    guidance_scale: float = 1.0
    num_inference_steps: int = 3
    negative_prompt: str | None = None
standalone_inference/overlay_files/fastvideo/configs/wan_1.3B_t2v_pipeline.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"embedded_cfg_scale": 6.0,
|
| 3 |
+
"flow_shift": 3,
|
| 4 |
+
"dit_cpu_offload": true,
|
| 5 |
+
"disable_autocast": false,
|
| 6 |
+
"precision": "bf16",
|
| 7 |
+
"vae_precision": "fp32",
|
| 8 |
+
"vae_tiling": false,
|
| 9 |
+
"vae_sp": false,
|
| 10 |
+
"vae_config": {
|
| 11 |
+
"load_encoder": false,
|
| 12 |
+
"load_decoder": true,
|
| 13 |
+
"tile_sample_min_height": 256,
|
| 14 |
+
"tile_sample_min_width": 256,
|
| 15 |
+
"tile_sample_min_num_frames": 16,
|
| 16 |
+
"tile_sample_stride_height": 192,
|
| 17 |
+
"tile_sample_stride_width": 192,
|
| 18 |
+
"tile_sample_stride_num_frames": 12,
|
| 19 |
+
"blend_num_frames": 8,
|
| 20 |
+
"use_tiling": false,
|
| 21 |
+
"use_temporal_tiling": false,
|
| 22 |
+
"use_parallel_tiling": false,
|
| 23 |
+
"use_feature_cache": true
|
| 24 |
+
},
|
| 25 |
+
"dit_config": {
|
| 26 |
+
"prefix": "Wan",
|
| 27 |
+
"quant_config": null
|
| 28 |
+
},
|
| 29 |
+
"text_encoder_precisions": [
|
| 30 |
+
"fp32"
|
| 31 |
+
],
|
| 32 |
+
"text_encoder_configs": [
|
| 33 |
+
{
|
| 34 |
+
"prefix": "t5",
|
| 35 |
+
"quant_config": null,
|
| 36 |
+
"lora_config": null
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"enable_torch_compile": false
|
| 40 |
+
}
|
standalone_inference/overlay_files/fastvideo/entrypoints/cli/generate.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import dataclasses
|
| 6 |
+
import os
|
| 7 |
+
from typing import cast
|
| 8 |
+
|
| 9 |
+
from fastvideo import VideoGenerator
|
| 10 |
+
from fastvideo.configs.sample.base import SamplingParam
|
| 11 |
+
from fastvideo.entrypoints.cli.cli_types import CLISubcommand
|
| 12 |
+
from fastvideo.entrypoints.cli.utils import RaiseNotImplementedAction
|
| 13 |
+
from fastvideo.fastvideo_args import FastVideoArgs
|
| 14 |
+
from fastvideo.logger import init_logger
|
| 15 |
+
from fastvideo.utils import FlexibleArgumentParser
|
| 16 |
+
|
| 17 |
+
logger = init_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GenerateSubcommand(CLISubcommand):
    """The `generate` subcommand for the FastVideo CLI"""

    def __init__(self) -> None:
        self.name = "generate"
        super().__init__()
        # Split CLI keys into generator-init vs per-call generation args.
        self.init_arg_names = self._get_init_arg_names()
        self.generation_arg_names = self._get_generation_arg_names()

    def _get_init_arg_names(self) -> list[str]:
        """Get names of arguments for VideoGenerator initialization"""
        return ["num_gpus", "tp_size", "sp_size", "model_path"]

    def _get_generation_arg_names(self) -> list[str]:
        """Get names of arguments for generate_video method"""
        return [field.name for field in dataclasses.fields(SamplingParam)]

    def cmd(self, args: argparse.Namespace) -> None:
        """Run generation from parsed CLI args.

        Only keys the user explicitly provided (tracked in ``args._provided``
        by the argument parser -- NOTE(review): confirm FlexibleArgumentParser
        sets this) are forwarded, so model/sampling defaults stay in charge.
        """
        excluded_args = ['subparser', 'config', 'dispatch_function']

        provided_args = {}
        for k, v in vars(args).items():
            if (k not in excluded_args and v is not None and hasattr(args, '_provided') and k in args._provided):
                provided_args[k] = v

        # model_path and prompt are forwarded whenever set, even if the
        # parser did not record them as explicitly provided.
        if 'model_path' in vars(args) and args.model_path is not None:
            provided_args['model_path'] = args.model_path

        if 'prompt' in vars(args) and args.prompt is not None:
            provided_args['prompt'] = args.prompt

        merged_args = {**provided_args}

        logger.info('CLI Args: %s', merged_args)

        if 'model_path' not in merged_args or not merged_args['model_path']:
            raise ValueError("model_path must be provided either in config file or via --model-path")

        # Check if either prompt or prompt_txt is provided
        has_prompt = 'prompt' in merged_args and merged_args['prompt']
        has_prompt_txt = 'prompt_txt' in merged_args and merged_args['prompt_txt']

        if not (has_prompt or has_prompt_txt):
            raise ValueError("Either prompt or prompt_txt must be provided")

        if has_prompt and has_prompt_txt:
            raise ValueError("Cannot provide both 'prompt' and 'prompt_txt'. Use only one of them.")

        # Keys that are SamplingParam fields go to generate_video(); the rest
        # configure VideoGenerator construction.
        init_args = {k: v for k, v in merged_args.items() if k not in self.generation_arg_names}
        generation_args = {k: v for k, v in merged_args.items() if k in self.generation_arg_names}
        generation_args.setdefault("return_frames", False)

        model_path = init_args.pop('model_path')
        prompt = generation_args.pop('prompt', None)

        generator = VideoGenerator.from_pretrained(model_path=model_path, **init_args)

        # Call generate_video - it handles both single and batch modes
        generator.generate_video(prompt=prompt, **generation_args)

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the arguments for this command"""
        if args.num_gpus is not None and args.num_gpus <= 0:
            raise ValueError("Number of gpus must be positive")

        if args.config and not os.path.exists(args.config):
            raise ValueError(f"Config file not found: {args.config}")

    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the ``generate`` subparser and all of its options."""
        generate_parser = subparsers.add_parser(
            "generate",
            help="Run inference on a model",
            usage="fastvideo generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]")

        generate_parser.add_argument(
            "--config",
            type=str,
            default='',
            required=False,
            help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional."
        )

        # Engine-level options first, then sampling options.
        generate_parser = FastVideoArgs.add_cli_args(generate_parser)
        generate_parser = SamplingParam.add_cli_args(generate_parser)

        generate_parser.add_argument(
            "--text-encoder-configs",
            action=RaiseNotImplementedAction,
            help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)",
        )

        return cast(FlexibleArgumentParser, generate_parser)
+
|
| 114 |
+
def cmd_init() -> list[CLISubcommand]:
    """Entry point used by the CLI dispatcher to register this subcommand."""
    return [GenerateSubcommand()]
standalone_inference/overlay_files/fastvideo/entrypoints/video_generator.py
ADDED
|
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
VideoGenerator module for FastVideo.
|
| 4 |
+
|
| 5 |
+
This module provides a consolidated interface for generating videos using
|
| 6 |
+
diffusion models.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
import shutil
|
| 12 |
+
import threading
|
| 13 |
+
import time
|
| 14 |
+
import tempfile
|
| 15 |
+
import warnings
|
| 16 |
+
from collections.abc import Mapping
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import imageio
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torchvision
|
| 24 |
+
from einops import rearrange
|
| 25 |
+
|
| 26 |
+
from fastvideo.api.compat import (
|
| 27 |
+
expand_request_prompt_batch,
|
| 28 |
+
generator_config_to_fastvideo_args,
|
| 29 |
+
legacy_from_pretrained_to_config,
|
| 30 |
+
load_generator_config_from_file,
|
| 31 |
+
normalize_generation_request,
|
| 32 |
+
normalize_generator_config,
|
| 33 |
+
request_to_pipeline_overrides,
|
| 34 |
+
request_to_sampling_param,
|
| 35 |
+
)
|
| 36 |
+
from fastvideo.api.results import GenerationResult
|
| 37 |
+
from fastvideo.api.schema import GenerationRequest, GeneratorConfig
|
| 38 |
+
from fastvideo.configs.sample import SamplingParam
|
| 39 |
+
from fastvideo.fastvideo_args import FastVideoArgs
|
| 40 |
+
from fastvideo.logger import init_logger
|
| 41 |
+
from fastvideo.pipelines import ForwardBatch
|
| 42 |
+
from fastvideo.utils import align_to, shallow_asdict
|
| 43 |
+
from fastvideo.worker.executor import Executor
|
| 44 |
+
|
| 45 |
+
logger = init_logger(__name__)
|
| 46 |
+
|
| 47 |
+
_FROM_PRETRAINED_CONVENIENCE_KWARGS = frozenset({
|
| 48 |
+
"num_gpus",
|
| 49 |
+
"revision",
|
| 50 |
+
"trust_remote_code",
|
| 51 |
+
"distributed_executor_backend",
|
| 52 |
+
"tp_size",
|
| 53 |
+
"sp_size",
|
| 54 |
+
"hsdp_replicate_dim",
|
| 55 |
+
"hsdp_shard_dim",
|
| 56 |
+
"dist_timeout",
|
| 57 |
+
"use_fsdp_inference",
|
| 58 |
+
"disable_autocast",
|
| 59 |
+
"enable_stage_verification",
|
| 60 |
+
"dit_cpu_offload",
|
| 61 |
+
"dit_layerwise_offload",
|
| 62 |
+
"text_encoder_cpu_offload",
|
| 63 |
+
"image_encoder_cpu_offload",
|
| 64 |
+
"vae_cpu_offload",
|
| 65 |
+
"pin_cpu_memory",
|
| 66 |
+
"enable_torch_compile",
|
| 67 |
+
"torch_compile_kwargs",
|
| 68 |
+
"transformer_quant",
|
| 69 |
+
})
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _infer_latent_batch_size(batch: ForwardBatch) -> int:
    """Work out how many latent samples this batch will produce.

    The prompt count is taken from ``batch.prompt`` when available, otherwise
    from the leading dimension of the first prompt-embedding tensor, and is
    multiplied by the number of videos requested per prompt.
    """
    prompt = batch.prompt
    if isinstance(prompt, list):
        num_prompts = len(prompt)
    elif prompt is not None:
        num_prompts = 1
    elif batch.prompt_embeds is not None and len(batch.prompt_embeds) > 0:
        num_prompts = batch.prompt_embeds[0].shape[0]
    else:
        raise ValueError("Cannot infer batch size from batch; no prompt or prompt_embeds found")
    # Each prompt yields num_videos_per_prompt independent samples.
    return num_prompts * batch.num_videos_per_prompt
class VideoGenerator:
|
| 86 |
+
"""
|
| 87 |
+
A unified class for generating videos using diffusion models.
|
| 88 |
+
|
| 89 |
+
This class provides a simple interface for video generation with rich
|
| 90 |
+
customization options, similar to popular frameworks like HF Diffusers.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
    def __init__(
        self,
        fastvideo_args: FastVideoArgs,
        executor_class: type[Executor],
        log_stats: bool,
        *,
        log_queue=None,
    ):
        """
        Initialize the video generator.

        Args:
            fastvideo_args: The inference arguments
            executor_class: The executor class to use for inference
            log_stats: Whether to log statistics
            log_queue: Optional multiprocessing.Queue to forward worker logs to
        """
        # Populated by from_config(); stays None when constructed from raw args.
        self.config: GeneratorConfig | None = None
        self.fastvideo_args = fastvideo_args
        # NOTE(review): log_stats is accepted but neither stored nor used
        # here -- confirm whether it should be persisted or dropped.
        self.executor = executor_class(fastvideo_args, log_queue=log_queue)
    @classmethod
    def from_pretrained(
        cls,
        model_path: str | GeneratorConfig | Mapping[str, Any] | None = None,
        **kwargs,
    ) -> "VideoGenerator":
        """
        Create a video generator from a pretrained model.

        Args:
            model_path: Path or identifier for the pretrained model
            pipeline_config: Pipeline config to use for inference
            **kwargs: Additional arguments to customize model loading, set any FastVideoArgs or PipelineConfig attributes here.

        Returns:
            The created video generator

        Priority level: Default pipeline config < User's pipeline config < User's kwargs

        Stable convenience kwargs remain supported here for common engine and
        offload settings. Advanced model- or pipeline-specific options should
        move to VideoGenerator.from_config(...).
        """
        log_queue = kwargs.pop("log_queue", None)
        # Path 1: explicit config=... keyword; must be the only input.
        typed_config = kwargs.pop("config", None)
        if typed_config is not None:
            if model_path is not None:
                raise TypeError("Pass either model_path or config to from_pretrained, not both")
            if kwargs:
                unexpected = ", ".join(sorted(kwargs))
                raise TypeError(f"Unexpected keyword arguments with config: {unexpected}")
            return cls.from_config(typed_config, log_queue=log_queue)

        # Path 2: a typed/dict config passed positionally in place of a path.
        if isinstance(model_path, GeneratorConfig | Mapping):
            if kwargs:
                unexpected = ", ".join(sorted(kwargs))
                raise TypeError(f"Unexpected keyword arguments with typed config: {unexpected}")
            return cls.from_config(model_path, log_queue=log_queue)

        if model_path is None:
            raise TypeError("model_path or config is required")

        # Path 3: legacy string path + kwargs. Kwargs outside the stable
        # convenience set are deprecated but still honored.
        legacy_only_kwargs = sorted(set(kwargs) - _FROM_PRETRAINED_CONVENIENCE_KWARGS)
        if legacy_only_kwargs:
            warnings.warn(
                "VideoGenerator.from_pretrained(...) received legacy-only kwargs "
                f"({', '.join(legacy_only_kwargs)}); prefer VideoGenerator.from_config(...) "
                "for advanced configuration.",
                DeprecationWarning,
                stacklevel=2,
            )
        return cls.from_config(
            legacy_from_pretrained_to_config(model_path, kwargs),
            log_queue=log_queue,
        )
@classmethod
def from_config(
    cls,
    config: GeneratorConfig | Mapping[str, Any],
    *,
    log_queue=None,
) -> "VideoGenerator":
    """Build a generator from a typed config object or a plain mapping.

    The config is normalized, lowered to engine-level ``FastVideoArgs``,
    and the normalized form is kept on ``generator.config`` so callers
    can inspect the effective configuration afterwards.
    """
    resolved = normalize_generator_config(config)
    engine_args = generator_config_to_fastvideo_args(resolved)
    instance = cls.from_fastvideo_args(engine_args, log_queue=log_queue)
    instance.config = resolved
    return instance
|
| 182 |
+
|
| 183 |
+
@classmethod
def from_file(
    cls,
    path: str,
    overrides: list[str] | Mapping[str, Any] | None = None,
    *,
    log_queue=None,
) -> "VideoGenerator":
    """Build a generator from a config file on disk, with optional overrides."""
    file_config = load_generator_config_from_file(path, overrides=overrides)
    return cls.from_config(file_config, log_queue=log_queue)
|
| 195 |
+
|
| 196 |
+
@classmethod
def from_fastvideo_args(
    cls,
    fastvideo_args: FastVideoArgs,
    *,
    log_queue=None,
) -> "VideoGenerator":
    """
    Create a video generator directly from engine-level arguments.

    Args:
        fastvideo_args: The inference arguments
        log_queue: Optional multiprocessing.Queue to forward worker logs to

    Returns:
        The created video generator
    """
    # The executor implementation is selected from the supplied arguments;
    # distributed/parallel initialization is handled elsewhere.
    chosen_executor = Executor.get_class(fastvideo_args)
    return cls(
        fastvideo_args=fastvideo_args,
        executor_class=chosen_executor,
        log_stats=False,  # TODO: implement
        log_queue=log_queue,
    )
|
| 223 |
+
|
| 224 |
+
def generate(
    self,
    request: GenerationRequest | Mapping[str, Any],
    *,
    log_queue=None,
) -> GenerationResult | list[GenerationResult]:
    """
    Generate video or image outputs from a typed inference request.

    Args:
        request: A `GenerationRequest` instance or a mapping that can be
            parsed into one. This is the primary public inference
            entrypoint for the typed API.
        log_queue: Optional multiprocessing.Queue to forward worker logs
            to during this request.

    Returns:
        A single `GenerationResult`, or a list of `GenerationResult`
        objects when the request expands into multiple prompts.
    """
    typed_request = normalize_generation_request(request)
    if not log_queue:
        return self._generate_request_impl(typed_request)

    # Route worker logs to the caller-supplied queue only for the
    # duration of this request.
    self.executor.set_log_queue(log_queue)
    try:
        return self._generate_request_impl(typed_request)
    finally:
        self.executor.clear_log_queue()
|
| 254 |
+
|
| 255 |
+
def generate_video(
    self,
    prompt: str | None = None,
    sampling_param: SamplingParam | None = None,
    # Action control inputs (Matrix-Game)
    mouse_cond: torch.Tensor | None = None,
    keyboard_cond: torch.Tensor | None = None,
    grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
    | None = None,
    **kwargs,
) -> dict[str, Any] | list[dict[str, Any]]:
    """
    Deprecated legacy entrypoint: generate a video for ``prompt``.

    Prefer ``VideoGenerator.generate(request=...)``. Keyword overrides
    (e.g. negative_prompt, output_path, num_inference_steps,
    guidance_scale, num_frames, height, width, fps, seed, save_video,
    return_frames, callback, callback_steps) are forwarded to the
    sampling parameters; a ``log_queue`` kwarg temporarily routes worker
    logs to the given queue.

    Returns:
        A metadata dictionary for single-prompt generation, or a list of
        metadata dictionaries for prompt-file batch generation.
    """
    log_queue = kwargs.pop("log_queue", None)
    warnings.warn(
        "VideoGenerator.generate_video(...) is deprecated; use "
        "VideoGenerator.generate(request=...) instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if not log_queue:
        return self._generate_video_impl(
            prompt=prompt,
            sampling_param=sampling_param,
            mouse_cond=mouse_cond,
            keyboard_cond=keyboard_cond,
            grid_sizes=grid_sizes,
            **kwargs,
        )

    self.executor.set_log_queue(log_queue)
    try:
        return self._generate_video_impl(
            prompt=prompt,
            sampling_param=sampling_param,
            mouse_cond=mouse_cond,
            keyboard_cond=keyboard_cond,
            grid_sizes=grid_sizes,
            **kwargs,
        )
    finally:
        self.executor.clear_log_queue()
|
| 312 |
+
|
| 313 |
+
def _generate_request_impl(
    self,
    request: GenerationRequest,
) -> GenerationResult | list[GenerationResult]:
    """Dispatch a typed request, expanding prompt lists into per-prompt runs."""
    if not isinstance(request.prompt, list):
        return self._generate_single_request(request)

    if request.inputs.prompt_path is not None:
        raise ValueError("request.prompt list cannot be combined with request.inputs.prompt_path")

    collected: list[GenerationResult] = []
    for position, sub_request in enumerate(expand_request_prompt_batch(request)):
        sub_prompt = sub_request.prompt
        outcome = self._generate_single_request(sub_request)
        if isinstance(outcome, list):
            # A sub-request may itself fan out (e.g. prompt file);
            # fold its results in without assigning an index.
            collected.extend(outcome)
            continue
        outcome.prompt_index = position
        if outcome.prompt is None and isinstance(sub_prompt, str):
            outcome.prompt = sub_prompt
        collected.append(outcome)
    return collected
|
| 334 |
+
|
| 335 |
+
def _generate_single_request(
    self,
    request: GenerationRequest,
) -> GenerationResult | list[GenerationResult]:
    """Run one typed request and wrap the legacy dict result into typed form."""
    args = self.fastvideo_args
    overrides = request_to_pipeline_overrides(request)
    if overrides:
        # Deep-copy so per-request pipeline overrides never leak into
        # the generator's shared arguments.
        args = deepcopy(self.fastvideo_args)
        for key, value in overrides.items():
            if not hasattr(args.pipeline_config, key):
                raise ValueError(f"Request field {key!r} is not supported by pipeline config overrides")
            setattr(args.pipeline_config, key, deepcopy(value))

    params = request_to_sampling_param(
        request,
        model_path=self.fastvideo_args.model_path,
    )
    legacy_result = self._generate_video_impl(
        prompt=request.prompt,
        sampling_param=params,
        fastvideo_args=args,
    )
    return self._wrap_legacy_result(legacy_result)
|
| 358 |
+
|
| 359 |
+
def _generate_video_impl(
    self,
    prompt: str | list[str] | None = None,
    sampling_param: SamplingParam | None = None,
    mouse_cond: torch.Tensor | None = None,
    keyboard_cond: torch.Tensor | None = None,
    grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
    | None = None,
    fastvideo_args: FastVideoArgs | None = None,
    **kwargs,
) -> dict[str, Any] | list[np.ndarray] | list[dict[str, Any]]:
    """Internal implementation of generate_video.

    Resolves sampling parameters, then either iterates over a prompt
    file (when ``fastvideo_args.prompt_txt`` or
    ``sampling_param.prompt_path`` is set — returns a list of per-prompt
    result dicts) or generates a single video for ``prompt``.

    Raises:
        FileNotFoundError: The configured prompt file does not exist.
        ValueError: The prompt file is empty, no prompt was provided, or
            ``prompt`` is not a string in single-prompt mode.
    """
    if fastvideo_args is None:
        fastvideo_args = self.fastvideo_args

    # Fall back to the model's default sampling parameters when none given.
    if sampling_param is None:
        sampling_param = SamplingParam.from_pretrained(fastvideo_args.model_path)

    # Add action control inputs (Matrix-Game) to kwargs if provided.
    if mouse_cond is not None:
        kwargs['mouse_cond'] = mouse_cond
    if keyboard_cond is not None:
        kwargs['keyboard_cond'] = keyboard_cond
    if grid_sizes is not None:
        kwargs['grid_sizes'] = grid_sizes

    # Fold all remaining keyword overrides into the sampling parameters.
    sampling_param.update(kwargs)

    # Batch mode: one generation per non-empty line of a prompt file.
    if fastvideo_args.prompt_txt is not None or sampling_param.prompt_path is not None:
        # Request-level prompt_path takes precedence over engine-level prompt_txt.
        prompt_txt_path = sampling_param.prompt_path or fastvideo_args.prompt_txt
        if not prompt_txt_path or not os.path.exists(prompt_txt_path):
            raise FileNotFoundError(f"Prompt text file not found: {prompt_txt_path}")

        # Read prompts from file, skipping blank lines.
        with open(prompt_txt_path, encoding='utf-8') as f:
            prompts = [line.strip() for line in f if line.strip()]

        if not prompts:
            raise ValueError(f"No prompts found in file: {prompt_txt_path}")

        logger.info("Found %d prompts in %s", len(prompts), prompt_txt_path)

        results = []
        for i, batch_prompt in enumerate(prompts):
            logger.info("Processing prompt %d/%d: %s...", i + 1, len(prompts), batch_prompt[:100])
            try:
                # Generate video for this prompt using the same logic below.
                output_path = self._prepare_output_path(sampling_param.output_path, batch_prompt)
                kwargs["output_path"] = output_path
                result = self._generate_single_video(
                    prompt=batch_prompt,
                    sampling_param=sampling_param,
                    fastvideo_args=fastvideo_args,
                    **kwargs,
                )

                # Add prompt info to result
                result["prompt_index"] = i
                result["prompt"] = batch_prompt

                results.append(result)
                logger.info("Successfully generated video for prompt %d", i + 1)

            except Exception as e:
                # Best-effort batch: log the failure and continue with the
                # next prompt instead of aborting the whole run.
                logger.error("Failed to generate video for prompt %d: %s", i + 1, e)
                continue

        logger.info("Completed batch processing. Generated %d videos successfully.", len(results))
        return results

    # Single prompt generation (original behavior)
    if prompt is None:
        raise ValueError("Either prompt or prompt_txt must be provided")
    if not isinstance(prompt, str):
        raise ValueError("Single-prompt generation expects a string prompt")
    output_path = self._prepare_output_path(sampling_param.output_path, prompt)
    kwargs["output_path"] = output_path
    return self._generate_single_video(
        prompt=prompt,
        sampling_param=sampling_param,
        fastvideo_args=fastvideo_args,
        **kwargs,
    )
|
| 443 |
+
|
| 444 |
+
def _is_image_workload(self) -> bool:
|
| 445 |
+
"""Return True when the workload produces a single image (t2i, i2i …)."""
|
| 446 |
+
args = getattr(self, "fastvideo_args", None)
|
| 447 |
+
if args is None:
|
| 448 |
+
return False
|
| 449 |
+
return args.workload_type.value.endswith("2i")
|
| 450 |
+
|
| 451 |
+
def _prepare_output_path(
|
| 452 |
+
self,
|
| 453 |
+
output_path: str,
|
| 454 |
+
prompt: str,
|
| 455 |
+
) -> str:
|
| 456 |
+
"""Build a unique, sanitized output file path.
|
| 457 |
+
|
| 458 |
+
The file extension is chosen automatically based on the workload type:
|
| 459 |
+
``.png`` for image workloads (``t2i``, ``i2i``, …) and ``.mp4`` for
|
| 460 |
+
video workloads.
|
| 461 |
+
|
| 462 |
+
- If ``output_path`` already carries the correct extension, treat it
|
| 463 |
+
as a file path.
|
| 464 |
+
- Otherwise, treat ``output_path`` as a directory and derive the
|
| 465 |
+
filename from the prompt.
|
| 466 |
+
- Invalid filename characters are removed; if the name changes, a
|
| 467 |
+
warning is logged.
|
| 468 |
+
- If the target path already exists, a numeric suffix is appended.
|
| 469 |
+
"""
|
| 470 |
+
target_ext = ".png" if self._is_image_workload() else ".mp4"
|
| 471 |
+
|
| 472 |
+
def _sanitize_filename_component(name: str) -> str:
|
| 473 |
+
# Remove characters invalid on common filesystems, strip spaces/dots
|
| 474 |
+
sanitized = re.sub(r'[\\/:*?"<>|]', '', name)
|
| 475 |
+
sanitized = sanitized.strip().strip('.')
|
| 476 |
+
sanitized = re.sub(r'\s+', ' ', sanitized)
|
| 477 |
+
return sanitized or "output"
|
| 478 |
+
|
| 479 |
+
base_path, extension = os.path.splitext(output_path)
|
| 480 |
+
extension_lower = extension.lower()
|
| 481 |
+
|
| 482 |
+
if extension_lower == target_ext:
|
| 483 |
+
output_dir = os.path.dirname(output_path)
|
| 484 |
+
base_name = os.path.basename(base_path) # filename without extension
|
| 485 |
+
sanitized_base = _sanitize_filename_component(base_name)
|
| 486 |
+
if sanitized_base != base_name:
|
| 487 |
+
logger.warning(
|
| 488 |
+
"The output name '%s' contained invalid characters. "
|
| 489 |
+
"It has been renamed to '%s%s'",
|
| 490 |
+
os.path.basename(output_path),
|
| 491 |
+
sanitized_base,
|
| 492 |
+
target_ext,
|
| 493 |
+
)
|
| 494 |
+
out_name = f"{sanitized_base}{target_ext}"
|
| 495 |
+
else:
|
| 496 |
+
# Treat as directory; inform if an unexpected extension was
|
| 497 |
+
# provided.
|
| 498 |
+
if extension:
|
| 499 |
+
logger.info(
|
| 500 |
+
"Output path '%s' has extension '%s' which does not "
|
| 501 |
+
"match the target '%s'; treating it as a directory",
|
| 502 |
+
output_path,
|
| 503 |
+
extension,
|
| 504 |
+
target_ext,
|
| 505 |
+
)
|
| 506 |
+
output_dir = output_path
|
| 507 |
+
prompt_component = _sanitize_filename_component(prompt[:100])
|
| 508 |
+
out_name = f"{prompt_component}{target_ext}"
|
| 509 |
+
|
| 510 |
+
if output_dir:
|
| 511 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 512 |
+
|
| 513 |
+
new_output_path = os.path.join(output_dir, out_name)
|
| 514 |
+
counter = 1
|
| 515 |
+
while os.path.exists(new_output_path):
|
| 516 |
+
name_part, ext_part = os.path.splitext(out_name)
|
| 517 |
+
new_name = f"{name_part}_{counter}{ext_part}"
|
| 518 |
+
new_output_path = os.path.join(output_dir, new_name)
|
| 519 |
+
counter += 1
|
| 520 |
+
return new_output_path
|
| 521 |
+
|
| 522 |
+
def _generate_single_video(
    self,
    prompt: str,
    sampling_param: SamplingParam | None = None,
    fastvideo_args: FastVideoArgs | None = None,
    **kwargs,
) -> dict[str, Any]:
    """Internal method for single video generation.

    Validates the prompt and dimensions, runs the executor forward pass
    on a worker thread (overlapping the pinned CPU output-buffer
    allocation), converts the output to uint8 frames, optionally saves
    the result (PNG for image workloads; MP4 plus best-effort audio mux
    for video workloads), and returns a metadata dict.

    ``kwargs`` must contain ``output_path``; callers always pass a
    non-None ``sampling_param`` (a None would fail at attribute access
    below).
    """
    if fastvideo_args is None:
        fastvideo_args = self.fastvideo_args

    # Validate inputs
    if not isinstance(prompt, str):
        raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
    prompt = prompt.strip()
    # Copy so per-call mutation below never leaks back to the caller.
    sampling_param = deepcopy(sampling_param)
    output_path = kwargs["output_path"]
    sampling_param.prompt = prompt
    # Process negative prompt
    if sampling_param.negative_prompt is not None:
        sampling_param.negative_prompt = sampling_param.negative_prompt.strip()

    # Validate dimensions
    if (sampling_param.height <= 0 or sampling_param.width <= 0 or sampling_param.num_frames <= 0):
        raise ValueError(f"Height, width, and num_frames must be positive integers, got "
                         f"height={sampling_param.height}, width={sampling_param.width}, "
                         f"num_frames={sampling_param.num_frames}")

    # Calculate sizes (spatial dims rounded up to a multiple of 16 for logging)
    target_height = align_to(sampling_param.height, 16)
    target_width = align_to(sampling_param.width, 16)

    # Calculate latent sizes: 4x temporal and 8x spatial compression factors.
    latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8]
    n_tokens = latents_size[0] * latents_size[1] * latents_size[2]

    # Log parameters
    debug_str = f"""
height: {target_height}
width: {target_width}
video_length: {sampling_param.num_frames}
prompt: {sampling_param.prompt}
image_path: {sampling_param.image_path}
neg_prompt: {sampling_param.negative_prompt}
seed: {sampling_param.seed}
infer_steps: {sampling_param.num_inference_steps}
num_videos_per_prompt: {sampling_param.num_videos_per_prompt}
guidance_scale: {sampling_param.guidance_scale}
n_tokens: {n_tokens}
flow_shift: {fastvideo_args.pipeline_config.flow_shift}
embedded_guidance_scale: {fastvideo_args.pipeline_config.embedded_cfg_scale}
save_video: {sampling_param.save_video}
output_path: {output_path}
""" # type: ignore[attr-defined]
    logger.info(debug_str)

    # Prepare batch
    batch = ForwardBatch(
        **shallow_asdict(sampling_param),
        eta=0.0,
        n_tokens=n_tokens,
        VSA_sparsity=fastvideo_args.VSA_sparsity,
    )

    # Run inference
    start_time = time.perf_counter()

    # Execute forward pass in a new thread for non-blocking tensor
    # allocation. Capture thread exceptions so we can surface the true
    # failure in the main thread instead of later hitting None outputs.
    result_container = {"output_batch": ForwardBatch(data_type=batch.data_type)}
    thread_error: dict[str, BaseException | None] = {"error": None}
    thread_error_traceback: dict[str, str] = {"traceback": ""}

    def execute_forward_thread():
        import traceback
        try:
            result_container["output_batch"] = self.executor.execute_forward(batch, fastvideo_args)
        except BaseException as error:  # noqa: BLE001
            thread_error["error"] = error
            thread_error_traceback["traceback"] = traceback.format_exc()

    thread = threading.Thread(target=execute_forward_thread)
    thread.start()
    # Allocate the (optionally pinned) CPU buffer while the forward pass runs.
    latent_batch_size = _infer_latent_batch_size(batch)
    samples = torch.empty(
        (latent_batch_size, 3, sampling_param.num_frames, sampling_param.height, sampling_param.width),
        device='cpu',
        pin_memory=fastvideo_args.pin_cpu_memory)
    thread.join()

    # Re-raise any worker-thread failure with its original traceback text.
    if thread_error["error"] is not None:
        raise RuntimeError("Forward execution thread failed.\n"
                           f"{thread_error_traceback['traceback']}") from thread_error["error"]

    output_batch = result_container["output_batch"]
    if output_batch.output is None:
        raise RuntimeError("Forward execution returned no output tensor. "
                           "This usually means the executor/pipeline failed earlier.")

    # Fast path: copy into the pre-allocated (possibly pinned) buffer.
    if output_batch.output.shape == samples.shape:
        samples.copy_(output_batch.output)
    else:
        logger.warning("Output shape %s does not match expected shape %s; use slow path", output_batch.output.shape,
                       samples.shape)
        samples = output_batch.output.cpu()
    logging_info = output_batch.logging_info

    gen_time = time.perf_counter() - start_time
    logger.info("Generated successfully in %.2f seconds", gen_time)

    # Process outputs: per-frame grids converted to HWC uint8 arrays.
    # NOTE(review): (x * 255) assumes samples are already in [0, 1] —
    # confirm against the pipeline's output range.
    videos = rearrange(samples, "b c t h w -> t b c h w")
    frames = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=6)
        x = x.permute(1, 2, 0).squeeze(-1)
        x = (x * 255).to(torch.uint8)
        frames.append(x.cpu().numpy())

    # Save output if requested
    if batch.save_video:
        if self._is_image_workload():
            # Image workloads (t2i, i2i, …): save the first frame as PNG.
            imageio.imwrite(output_path, frames[0])
            logger.info("Saved image to %s", output_path)
        else:
            imageio.mimsave(output_path, frames, fps=batch.fps, format="mp4")
            logger.info("Saved video to %s", output_path)
            # Best-effort audio mux; failure keeps the silent video.
            audio = output_batch.extra.get("audio")
            audio_sample_rate = output_batch.extra.get("audio_sample_rate")
            if (audio is not None and audio_sample_rate is not None
                    and not self._mux_audio(output_path, audio, audio_sample_rate)):
                logger.warning("Audio mux failed; saved video without audio.")

    result: dict[str, Any] = {
        "prompts": prompt,
        "samples": samples if batch.return_frames else None,
        "frames": frames if batch.return_frames else None,
        "audio": output_batch.extra.get("audio") if batch.return_frames else None,
        "size": (target_height, target_width, batch.num_frames),
        "generation_time": gen_time,
        "logging_info": logging_info,
        "trajectory": output_batch.trajectory_latents,
        "trajectory_timesteps": output_batch.trajectory_timesteps,
        "trajectory_decoded": output_batch.trajectory_decoded,
        "video_path": output_path if batch.save_video else None,
        "peak_memory_mb": output_batch.extra.get("peak_memory_mb"),
    }

    return result
|
| 673 |
+
|
| 674 |
+
@staticmethod
def _wrap_legacy_result(
        result: dict[str, Any] | list[dict[str, Any]], ) -> GenerationResult | list[GenerationResult]:
    """Convert legacy metadata dict(s) into typed ``GenerationResult`` object(s)."""
    if not isinstance(result, list):
        return GenerationResult.from_legacy_result(result)
    return [GenerationResult.from_legacy_result(entry) for entry in result]
|
| 680 |
+
|
| 681 |
+
@staticmethod
def _unwrap_typed_result(
        result: GenerationResult | list[GenerationResult], ) -> dict[str, Any] | list[dict[str, Any]]:
    """Convert typed result(s) back into the legacy dict representation."""
    if not isinstance(result, list):
        return result.to_legacy_dict()
    return [entry.to_legacy_dict() for entry in result]
|
| 687 |
+
|
| 688 |
+
@staticmethod
def _mux_audio(
    video_path: str,
    audio: torch.Tensor | np.ndarray,
    sample_rate: int,
) -> bool:
    """Mux audio into video using PyAV.

    The waveform is converted to 16-bit PCM, written to a temporary WAV
    file, then encoded as AAC alongside a decode/re-encode copy of the
    video; on success the muxed file replaces ``video_path`` in place.

    Args:
        video_path: Path of an existing video file to add audio to.
        audio: Waveform with values in [-1, 1]; 1-D, or 2-D
            channel-first ``(channels, samples)`` with at most 8 channels.
        sample_rate: Audio sample rate in Hz.

    Returns:
        True on success; False when PyAV is missing, the audio layout is
        rejected, or muxing fails (the original file is left untouched).
    """
    try:
        import av
    except ImportError:
        logger.warning("PyAV not installed; cannot mux audio. "
                       "Install with: pip install av")
        return False

    # Normalize to a float32 numpy array regardless of input type.
    if torch.is_tensor(audio):
        audio_np = audio.detach().cpu().float().numpy()
    else:
        audio_np = np.asarray(audio, dtype=np.float32)

    # Normalize layout to (samples, channels).
    if audio_np.ndim == 1:
        audio_np = audio_np[:, None]
    elif audio_np.ndim == 2:
        if audio_np.shape[0] <= 8 and audio_np.shape[1] > audio_np.shape[0]:
            # Heuristic: a small first axis means channel-first; transpose.
            audio_np = audio_np.T
        else:
            # NOTE(review): 2-D samples-first input is rejected here, and
            # ndim >= 3 falls through unchecked — confirm expected layouts
            # with callers.
            logger.warning("Unexpected audio shape %s; skipping mux.", audio_np.shape)
            return False

    # Convert to 16-bit PCM.
    audio_np = np.clip(audio_np, -1.0, 1.0)
    audio_int16 = (audio_np * 32767.0).astype(np.int16)
    num_channels = audio_int16.shape[1]
    layout = "stereo" if num_channels == 2 else "mono"

    try:
        import wave
        with tempfile.TemporaryDirectory() as tmpdir:
            out_path = os.path.join(tmpdir, "muxed.mp4")
            wav_path = os.path.join(tmpdir, "audio.wav")

            # Write audio to WAV file
            with wave.open(wav_path, "wb") as wav_file:
                wav_file.setnchannels(num_channels)
                wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(audio_int16.tobytes())

            # Open input video and audio
            input_video = av.open(video_path)
            input_audio = av.open(wav_path)

            # Create output with both streams
            output = av.open(out_path, mode="w")

            # Add video stream (copy codec from input)
            in_video_stream = input_video.streams.video[0]
            out_video_stream = output.add_stream(
                codec_name=in_video_stream.codec_context.name,
                rate=in_video_stream.average_rate,
            )
            out_video_stream.width = in_video_stream.width
            out_video_stream.height = in_video_stream.height
            out_video_stream.pix_fmt = in_video_stream.pix_fmt

            # Add audio stream (AAC)
            out_audio_stream = output.add_stream("aac", rate=sample_rate)
            out_audio_stream.layout = layout

            # Remux video (decode and re-encode to be safe)
            for frame in input_video.decode(video=0):
                for packet in out_video_stream.encode(frame):
                    output.mux(packet)
            # Flush the video encoder's buffered packets.
            for packet in out_video_stream.encode():
                output.mux(packet)

            # Encode audio
            for frame in input_audio.decode(audio=0):
                frame.pts = None  # Let encoder assign PTS
                for packet in out_audio_stream.encode(frame):
                    output.mux(packet)
            # Flush the audio encoder's buffered packets.
            for packet in out_audio_stream.encode():
                output.mux(packet)

            input_video.close()
            input_audio.close()
            output.close()
            # Replace the original file with the muxed result.
            shutil.move(out_path, video_path)
            return True
    except Exception as e:
        logger.warning("Audio mux failed: %s", e)
        return False
|
| 778 |
+
|
| 779 |
+
def set_lora_adapter(self, lora_nickname: str, lora_path: str | None = None) -> None:
    """Delegate LoRA adapter selection (and optional loading from ``lora_path``) to the executor."""
    self.executor.set_lora_adapter(lora_nickname, lora_path)
|
| 781 |
+
|
| 782 |
+
def unmerge_lora_weights(self) -> None:
    """
    Use unmerged weights for inference to produce videos that align with
    validation videos generated during training.

    Delegates to the executor; counterpart of ``merge_lora_weights``.
    """
    self.executor.unmerge_lora_weights()
|
| 788 |
+
|
| 789 |
+
def merge_lora_weights(self) -> None:
    """Delegate LoRA weight merging to the executor; counterpart of ``unmerge_lora_weights``."""
    self.executor.merge_lora_weights()
|
| 791 |
+
|
| 792 |
+
def shutdown(self) -> None:
    """
    Shutdown the video generator.

    Shuts down the executor and deletes the reference so its resources
    can be reclaimed; the generator must not be used afterwards.
    """
    self.executor.shutdown()
    # Drop the attribute so any further use fails fast with AttributeError.
    del self.executor
|
standalone_inference/overlay_files/fastvideo/fastvideo_args.py
ADDED
|
@@ -0,0 +1,1188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py
|
| 3 |
+
"""The arguments of FastVideo Inference."""
|
| 4 |
+
import argparse
|
| 5 |
+
import dataclasses
|
| 6 |
+
import json
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
from dataclasses import field
|
| 9 |
+
from enum import Enum
|
| 10 |
+
from typing import Any, TYPE_CHECKING
|
| 11 |
+
|
| 12 |
+
from fastvideo.configs.configs import PreprocessConfig
|
| 13 |
+
from fastvideo.configs.pipelines.base import PipelineConfig
|
| 14 |
+
from fastvideo.configs.utils import clean_cli_args
|
| 15 |
+
from fastvideo.layers.quantization import QUANTIZATION_METHODS, QuantizationMethods
|
| 16 |
+
from fastvideo.logger import init_logger
|
| 17 |
+
from fastvideo.utils import FlexibleArgumentParser, StoreBoolean
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
|
| 20 |
+
from ray.runtime_env import RuntimeEnv
|
| 21 |
+
from ray.util.placement_group import PlacementGroup
|
| 22 |
+
else:
|
| 23 |
+
RuntimeEnv = Any
|
| 24 |
+
PlacementGroup = Any
|
| 25 |
+
|
| 26 |
+
logger = init_logger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ExecutionMode(str, Enum):
    """Pipeline execution modes.

    Subclasses ``str`` so members compare equal to their plain string values,
    which keeps string-based comparisons working for backward compatibility.
    """
    INFERENCE = "inference"
    PREPROCESS = "preprocess"
    FINETUNING = "finetuning"
    DISTILLATION = "distillation"

    @classmethod
    def from_string(cls, value: str) -> "ExecutionMode":
        """Parse *value* (case-insensitive) into an :class:`ExecutionMode`."""
        normalized = value.lower()
        for member in cls:
            if member.value == normalized:
                return member
        raise ValueError(f"Invalid mode: {value}. Must be one of: {', '.join([m.value for m in cls])}") from None

    @classmethod
    def choices(cls) -> list[str]:
        """All mode values as plain strings (suitable for argparse choices)."""
        return [member.value for member in cls]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class WorkloadType(str, Enum):
    """Supported generation workload types.

    Subclasses ``str`` so members compare equal to their plain string values,
    which keeps string-based comparisons working for backward compatibility.
    """
    I2V = "i2v"  # Image to Video
    T2V = "t2v"  # Text to Video
    T2I = "t2i"  # Text to Image
    I2I = "i2i"  # Image to Image

    @classmethod
    def from_string(cls, value: str) -> "WorkloadType":
        """Parse *value* (case-insensitive) into a :class:`WorkloadType`."""
        normalized = value.lower()
        for member in cls:
            if member.value == normalized:
                return member
        raise ValueError(
            f"Invalid workload type: {value}. Must be one of: {', '.join([m.value for m in cls])}") from None

    @classmethod
    def choices(cls) -> list[str]:
        """All workload values as plain strings (suitable for argparse choices)."""
        return [member.value for member in cls]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# args for fastvideo framework
|
| 81 |
+
# args for fastvideo framework
@dataclasses.dataclass
class FastVideoArgs:
    """Top-level runtime arguments for the FastVideo framework.

    Aggregates model paths, execution/workload modes, parallelism settings,
    offloading switches, and nested pipeline/preprocess configs. Instances are
    validated in ``__post_init__`` via ``check_fastvideo_args``.
    """

    # Model and path configuration (for convenience)
    model_path: str

    # Running mode
    mode: ExecutionMode = ExecutionMode.INFERENCE

    # Workload type
    workload_type: WorkloadType = WorkloadType.T2V

    # Distributed executor backend
    distributed_executor_backend: str = "mp"

    # a few attributes for ray related
    ray_placement_group: PlacementGroup | None = None
    ray_runtime_env: RuntimeEnv | None = None

    inference_mode: bool = True  # if False == training mode

    # HuggingFace specific parameters
    trust_remote_code: bool = False
    revision: str | None = None

    # Parallelism
    # -1 values are sentinels resolved in check_fastvideo_args
    # (tp_size -> 1, sp_size/hsdp_shard_dim -> num_gpus).
    num_gpus: int = 1
    tp_size: int = -1
    sp_size: int = -1
    hsdp_replicate_dim: int = 1
    hsdp_shard_dim: int = -1
    dist_timeout: int | None = None  # timeout for torch.distributed

    pipeline_config: PipelineConfig = field(default_factory=PipelineConfig)
    preprocess_config: PreprocessConfig | None = None

    # LoRA parameters
    # (Wenxuan) prefer to keep it here instead of in pipeline config to not make it complicated.
    lora_path: str | None = None
    lora_nickname: str = "default"  # for swapping adapters in the pipeline
    # can restrict layers to adapt, e.g. ["q_proj"]
    # Will adapt only q, k, v, o by default.
    lora_target_modules: list[str] | None = None

    output_type: str = "pil"

    # CPU offload parameters
    dit_cpu_offload: bool = True
    use_fsdp_inference: bool = False
    dit_layerwise_offload: bool = True
    text_encoder_cpu_offload: bool = True
    image_encoder_cpu_offload: bool = True
    vae_cpu_offload: bool = True
    pin_cpu_memory: bool = True

    # Compilation
    enable_torch_compile: bool = False
    torch_compile_kwargs: dict[str, Any] = field(default_factory=dict)

    disable_autocast: bool = False

    # VSA parameters
    VSA_sparsity: float = 0.0  # inference/validation sparsity

    # V-MoBA parameters
    moba_config_path: str | None = None
    moba_config: dict[str, Any] = field(default_factory=dict)

    # Master port for distributed training/inference
    master_port: int | None = None

    # Stage verification
    enable_stage_verification: bool = True

    # Prompt text file for batch processing
    prompt_txt: str | None = None

    # LTX-2 VAE tiling overrides (None means: leave pipeline config untouched)
    ltx2_vae_tiling: bool | None = None
    ltx2_vae_spatial_tile_size_in_pixels: int | None = None
    ltx2_vae_spatial_tile_overlap_in_pixels: int | None = None
    ltx2_vae_temporal_tile_size_in_frames: int | None = None
    ltx2_vae_temporal_tile_overlap_in_frames: int | None = None
    ltx2_initial_latent_path: str | None = None

    # model paths for correct deallocation
    model_paths: dict[str, str] = field(default_factory=dict)
    model_loaded: dict[str, bool] = field(default_factory=lambda: {
        "transformer": True,
        "vae": True,
        "upsampler": True,
    })

    override_text_encoder_safetensors: str | None = None  # path to safetensors file for text encoder override
    override_text_encoder_quant: QuantizationMethods = None
    transformer_quant: QuantizationMethods = None

    override_transformer_cls_name: str | None = None
    init_weights_from_safetensors: str = ""  # path to safetensors file for initial weight loading
    init_weights_from_safetensors_2: str = ""  # path to safetensors file for initial weight loading for transformer_2

    override_pipeline_cls_name: str | None = None

    # # DMD parameters
    # dmd_denoising_steps: List[int] | None = field(default=None)

    # MoE parameters used by Wan2.2
    boundary_ratio: float = 0.875

    @property
    def training_mode(self) -> bool:
        """True when running in training mode (the inverse of inference_mode)."""
        return not self.inference_mode

    def __post_init__(self):
        """Load optional V-MoBA config, apply LTX-2 overrides, then validate."""
        if self.moba_config_path:
            try:
                with open(self.moba_config_path) as f:
                    self.moba_config = json.load(f)
                logger.info("Loaded V-MoBA config from %s", self.moba_config_path)
            except (FileNotFoundError, json.JSONDecodeError) as e:
                # Surface a clear error and re-raise: a bad config path is fatal.
                logger.error("Failed to load V-MoBA config from %s: %s", self.moba_config_path, e)
                raise
        self._apply_ltx2_vae_overrides()
        self.check_fastvideo_args()

    def __getattr__(self, name: str) -> Any:
        # Only called for missing attributes; raise with the standard message
        # so typos on FastVideoArgs fields fail loudly instead of silently.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def _apply_ltx2_vae_overrides(self) -> None:
        """Push the ltx2_* CLI overrides into the nested VAE/pipeline configs.

        Each override is applied only when it was explicitly provided (not
        None) and the target config actually exposes the attribute; setting
        any tile override implicitly enables VAE tiling unless the user set
        ltx2_vae_tiling explicitly.
        """
        if self.pipeline_config is None:
            return
        vae_config = self.pipeline_config.vae_config
        has_any = any(value is not None for value in (
            self.ltx2_vae_spatial_tile_size_in_pixels,
            self.ltx2_vae_spatial_tile_overlap_in_pixels,
            self.ltx2_vae_temporal_tile_size_in_frames,
            self.ltx2_vae_temporal_tile_overlap_in_frames,
        ))
        if self.ltx2_vae_tiling is not None and hasattr(self.pipeline_config, "vae_tiling"):
            self.pipeline_config.vae_tiling = self.ltx2_vae_tiling
        elif has_any and hasattr(self.pipeline_config, "vae_tiling"):
            self.pipeline_config.vae_tiling = True

        if hasattr(vae_config,
                   "ltx2_spatial_tile_size_in_pixels") and self.ltx2_vae_spatial_tile_size_in_pixels is not None:
            vae_config.ltx2_spatial_tile_size_in_pixels = (self.ltx2_vae_spatial_tile_size_in_pixels)
        if hasattr(vae_config,
                   "ltx2_spatial_tile_overlap_in_pixels") and self.ltx2_vae_spatial_tile_overlap_in_pixels is not None:
            vae_config.ltx2_spatial_tile_overlap_in_pixels = (self.ltx2_vae_spatial_tile_overlap_in_pixels)
        if hasattr(vae_config,
                   "ltx2_temporal_tile_size_in_frames") and self.ltx2_vae_temporal_tile_size_in_frames is not None:
            vae_config.ltx2_temporal_tile_size_in_frames = (self.ltx2_vae_temporal_tile_size_in_frames)
        if hasattr(
                vae_config,
                "ltx2_temporal_tile_overlap_in_frames") and self.ltx2_vae_temporal_tile_overlap_in_frames is not None:
            vae_config.ltx2_temporal_tile_overlap_in_frames = (self.ltx2_vae_temporal_tile_overlap_in_frames)

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Register all FastVideoArgs CLI flags on *parser* and return it.

        Also delegates to PipelineConfig.add_cli_args and
        PreprocessConfig.add_cli_args for the nested configs.
        """
        # Model and path configuration
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
        )

        # Running mode
        parser.add_argument(
            "--mode",
            type=str,
            choices=ExecutionMode.choices(),
            default=FastVideoArgs.mode.value,
            help="The mode to run FastVideo",
        )

        # Workload type
        parser.add_argument(
            "--workload-type",
            type=str,
            choices=WorkloadType.choices(),
            default=FastVideoArgs.workload_type.value,
            help="The workload type",
        )

        # distributed_executor_backend
        parser.add_argument(
            "--distributed-executor-backend",
            type=str,
            choices=["mp"],
            default=FastVideoArgs.distributed_executor_backend,
            help="The distributed executor backend to use",
        )

        parser.add_argument(
            "--inference-mode",
            action=StoreBoolean,
            default=FastVideoArgs.inference_mode,
            help="Whether to use inference mode",
        )

        # HuggingFace specific parameters
        parser.add_argument(
            "--trust-remote-code",
            action=StoreBoolean,
            default=FastVideoArgs.trust_remote_code,
            help="Trust remote code when loading HuggingFace models",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=FastVideoArgs.revision,
            help="The specific model version to use (can be a branch name, tag name, or commit id)",
        )

        # Parallelism
        parser.add_argument(
            "--num-gpus",
            type=int,
            default=FastVideoArgs.num_gpus,
            help="The number of GPUs to use.",
        )
        parser.add_argument(
            "--tp-size",
            type=int,
            default=FastVideoArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--sp-size",
            type=int,
            default=FastVideoArgs.sp_size,
            help="The sequence parallelism size.",
        )
        parser.add_argument(
            "--hsdp-replicate-dim",
            type=int,
            default=FastVideoArgs.hsdp_replicate_dim,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--hsdp-shard-dim",
            type=int,
            default=FastVideoArgs.hsdp_shard_dim,
            help="The data parallelism shards.",
        )
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=FastVideoArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )

        # Output type
        parser.add_argument(
            "--output-type",
            type=str,
            default=FastVideoArgs.output_type,
            choices=["pil"],
            help="Output type for the generated video",
        )

        # Prompt text file for batch processing
        parser.add_argument(
            "--prompt-txt",
            type=str,
            default=FastVideoArgs.prompt_txt,
            help="Path to a text file containing prompts (one per line) for batch processing",
        )

        # LTX-2 VAE tiling overrides
        parser.add_argument(
            "--ltx2-vae-tiling",
            action=StoreBoolean,
            default=FastVideoArgs.ltx2_vae_tiling,
            help="Enable LTX-2 VAE tiling overrides.",
        )
        parser.add_argument(
            "--ltx2-vae-spatial-tile-size-in-pixels",
            type=int,
            default=FastVideoArgs.ltx2_vae_spatial_tile_size_in_pixels,
            help="LTX-2 VAE spatial tile size in pixels.",
        )
        parser.add_argument(
            "--ltx2-vae-spatial-tile-overlap-in-pixels",
            type=int,
            default=FastVideoArgs.ltx2_vae_spatial_tile_overlap_in_pixels,
            help="LTX-2 VAE spatial tile overlap in pixels.",
        )
        parser.add_argument(
            "--ltx2-vae-temporal-tile-size-in-frames",
            type=int,
            default=FastVideoArgs.ltx2_vae_temporal_tile_size_in_frames,
            help="LTX-2 VAE temporal tile size in frames.",
        )
        parser.add_argument(
            "--ltx2-vae-temporal-tile-overlap-in-frames",
            type=int,
            default=FastVideoArgs.ltx2_vae_temporal_tile_overlap_in_frames,
            help="LTX-2 VAE temporal tile overlap in frames.",
        )
        parser.add_argument(
            "--ltx2-initial-latent-path",
            type=str,
            default=FastVideoArgs.ltx2_initial_latent_path,
            help="Path to load/save a precomputed LTX-2 initial latent.",
        )

        # LoRA parameters (inference-time adapter loading)
        parser.add_argument(
            "--lora-path",
            type=str,
            default=FastVideoArgs.lora_path,
            help="Path to a LoRA adapter (directory or HF repo id). If set, LoRA will be applied at inference.",
        )
        parser.add_argument(
            "--lora-nickname",
            type=str,
            default=FastVideoArgs.lora_nickname,
            help="Nickname to refer to the loaded LoRA adapter (useful for swapping).",
        )
        parser.add_argument(
            "--lora-target-modules",
            nargs="+",
            type=str,
            default=FastVideoArgs.lora_target_modules,
            help="Optional list of module name substrings to restrict LoRA injection (e.g. q_proj k_proj v_proj).",
        )

        # BSA runtime control (LongCat)
        # NOTE: these flags are consumed by PipelineConfig, not FastVideoArgs
        # fields; from_kwargs filters them out before constructing this class.
        parser.add_argument(
            "--enable-bsa",
            action=StoreBoolean,
            help="Enable Block Sparse Attention (BSA) at runtime (overrides config).",
        )
        parser.add_argument(
            "--bsa-sparsity",
            type=float,
            help="BSA sparsity (e.g., 0.9375).",
        )
        parser.add_argument(
            "--bsa-cdf-threshold",
            type=float,
            help="BSA CDF threshold (optional).",
        )
        parser.add_argument(
            "--bsa-chunk-q",
            nargs=3,
            type=int,
            metavar=("T", "H", "W"),
            help="BSA chunk_3d_shape_q as three ints, e.g., 4 4 4.",
        )
        parser.add_argument(
            "--bsa-chunk-k",
            nargs=3,
            type=int,
            metavar=("T", "H", "W"),
            help="BSA chunk_3d_shape_k as three ints, e.g., 4 4 4.",
        )

        parser.add_argument(
            "--enable-torch-compile",
            action=StoreBoolean,
            default=FastVideoArgs.enable_torch_compile,
            help="Use torch.compile to speed up DiT inference." +
            "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)",
        )
        parser.add_argument(
            "--torch-compile-kwargs",
            type=str,
            default=None,
            help=
            "JSON string of kwargs to pass to torch.compile. Example: '{\"backend\":\"inductor\",\"mode\":\"reduce-overhead\"}'",
        )

        parser.add_argument(
            "--dit-cpu-offload",
            action=StoreBoolean,
            help="Use CPU offload for DiT inference. Enable if run out of memory with FSDP.",
        )
        parser.add_argument(
            "--dit-layerwise-offload",
            action=StoreBoolean,
            help="Enable layerwise CPU offload with async H2D prefetch overlap.",
        )
        parser.add_argument(
            "--use-fsdp-inference",
            action=StoreBoolean,
            help=
            "Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.",
        )
        parser.add_argument(
            "--text-encoder-cpu-offload",
            action=StoreBoolean,
            help="Use CPU offload for text encoder. Enable if run out of memory.",
        )
        parser.add_argument(
            "--image-encoder-cpu-offload",
            action=StoreBoolean,
            help="Use CPU offload for image encoder. Enable if run out of memory.",
        )
        parser.add_argument(
            "--vae-cpu-offload",
            action=StoreBoolean,
            help="Use CPU offload for VAE. Enable if run out of memory.",
        )
        parser.add_argument(
            "--pin-cpu-memory",
            action=StoreBoolean,
            help=
            "Pin memory for CPU offload. Only added as a temp workaround if it throws \"CUDA error: invalid argument\". "
            "Should be enabled in almost all cases",
        )
        parser.add_argument(
            "--disable-autocast",
            action=StoreBoolean,
            help="Disable autocast for denoising loop and vae decoding in pipeline sampling",
        )

        # VSA parameters
        parser.add_argument(
            "--VSA-sparsity",
            type=float,
            default=FastVideoArgs.VSA_sparsity,
            help="Validation sparsity for VSA",
        )

        # Master port for distributed training/inference
        parser.add_argument(
            "--master-port",
            type=int,
            default=FastVideoArgs.master_port,
            help="Master port for distributed training/inference",
        )

        # Stage verification
        parser.add_argument(
            "--enable-stage-verification",
            action=StoreBoolean,
            default=FastVideoArgs.enable_stage_verification,
            help="Enable input/output verification for pipeline stages",
        )
        parser.add_argument(
            "--override-text-encoder-safetensors",
            type=str,
            default=FastVideoArgs.override_text_encoder_safetensors,
            help="Path to safetensors file for text encoder override",
        )
        parser.add_argument(
            "--override-text-encoder-quant",
            type=str,
            choices=QUANTIZATION_METHODS,
            default=FastVideoArgs.override_text_encoder_quant,
            help="Quantization method for text encoder override",
        )
        parser.add_argument(
            "--transformer-quant",
            type=str,
            choices=QUANTIZATION_METHODS,
            default=FastVideoArgs.transformer_quant,
            help="Quantization method for transformer loading",
        )
        parser.add_argument(
            "--override-transformer-cls-name",
            type=str,
            default=FastVideoArgs.override_transformer_cls_name,
            help="Override transformer cls name",
        )
        parser.add_argument(
            "--override-pipeline-cls-name",
            type=str,
            default=FastVideoArgs.override_pipeline_cls_name,
            help="Override pipeline cls name",
        )
        parser.add_argument("--init-weights-from-safetensors",
                            type=str,
                            help="Path to safetensors file for initial weight loading")
        parser.add_argument("--init-weights-from-safetensors-2",
                            type=str,
                            help="Path to safetensors file for initial weight loading")

        # Add pipeline configuration arguments
        PipelineConfig.add_cli_args(parser)

        # Add preprocessing configuration arguments
        PreprocessConfig.add_cli_args(parser)

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
        """Build a FastVideoArgs from a parsed argparse namespace.

        Nested configs are built from the explicitly provided CLI args;
        enum-valued fields are converted from their string form; all other
        fields fall back to their dataclass defaults when missing from *args*.
        """
        provided_args = clean_cli_args(args)
        # Get all fields from the dataclass
        attrs = [attr.name for attr in dataclasses.fields(cls)]

        # Create a dictionary of attribute values, with defaults for missing attributes
        kwargs: dict[str, Any] = {}
        for attr in attrs:
            if attr == 'pipeline_config':
                pipeline_config = PipelineConfig.from_kwargs(provided_args)
                kwargs['pipeline_config'] = pipeline_config
            elif attr == 'preprocess_config':
                preprocess_config = PreprocessConfig.from_kwargs(provided_args)
                kwargs['preprocess_config'] = preprocess_config
            elif attr == 'mode':
                # Convert string to ExecutionMode enum
                mode_value = getattr(args, attr, FastVideoArgs.mode.value)
                kwargs['mode'] = ExecutionMode.from_string(mode_value) if isinstance(mode_value, str) else mode_value
            elif attr == 'torch_compile_kwargs':
                # Parse JSON string for torch.compile kwargs
                torch_compile_kwargs_str = getattr(args, 'torch_compile_kwargs', None)
                if torch_compile_kwargs_str:
                    try:
                        import json
                        kwargs['torch_compile_kwargs'] = json.loads(torch_compile_kwargs_str)
                    except json.JSONDecodeError as e:
                        raise ValueError(f"Invalid JSON for torch_compile_kwargs: {e}") from e
                else:
                    kwargs['torch_compile_kwargs'] = {}
            elif attr == 'workload_type':
                # Convert string to WorkloadType enum
                workload_type_value = getattr(args, 'workload_type', FastVideoArgs.workload_type.value)
                kwargs['workload_type'] = WorkloadType.from_string(workload_type_value) if isinstance(
                    workload_type_value, str) else workload_type_value
            # Use getattr with default value from the dataclass for potentially missing attributes
            else:
                # Get the field to check if it has a default_factory
                field = dataclasses.fields(cls)[next(i for i, f in enumerate(dataclasses.fields(cls))
                                                     if f.name == attr)]
                if field.default_factory is not dataclasses.MISSING:
                    # Use the default_factory to create the default value
                    default_value = field.default_factory()
                else:
                    default_value = getattr(cls, attr, None)
                value = getattr(args, attr, default_value)
                kwargs[attr] = value  # type: ignore

        return cls(**kwargs)  # type: ignore

    @classmethod
    def from_kwargs(cls, **kwargs: Any) -> "FastVideoArgs":
        """Build a FastVideoArgs from keyword arguments.

        Accepts a superset of this class's fields: pipeline- and
        preprocess-specific keys are routed into the nested configs, and
        anything that is not a FastVideoArgs field is dropped.
        """
        # Convert mode string to enum if necessary
        if 'mode' in kwargs and isinstance(kwargs['mode'], str):
            kwargs['mode'] = ExecutionMode.from_string(kwargs['mode'])

        # Convert workload_type string to enum if necessary
        if 'workload_type' in kwargs and isinstance(kwargs['workload_type'], str):
            kwargs['workload_type'] = WorkloadType.from_string(kwargs['workload_type'])

        kwargs['pipeline_config'] = PipelineConfig.from_kwargs(kwargs)
        kwargs['preprocess_config'] = PreprocessConfig.from_kwargs(kwargs)
        # Filter to only FastVideoArgs dataclass fields — pipeline-specific CLI
        # args (e.g. enable_bsa, bsa_sparsity) live in PipelineConfig and must
        # not be forwarded to the FastVideoArgs constructor.
        valid_fields = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in kwargs.items() if k in valid_fields})

    def check_fastvideo_args(self) -> None:
        """Validate inference arguments for consistency"""
        from fastvideo.platforms import current_platform

        # MPS (Apple Silicon) does not support FSDP/layerwise offload paths.
        if current_platform.is_mps():
            self.use_fsdp_inference = False
            self.dit_layerwise_offload = False

        # Layerwise offload supersedes (and conflicts with) the other two
        # offload strategies, so it wins and disables them.
        if self.dit_layerwise_offload:
            if self.use_fsdp_inference:
                logger.warning("dit_layerwise_offload is enabled, automatically disabling use_fsdp_inference.")
                self.use_fsdp_inference = False
            if self.dit_cpu_offload:
                logger.warning("dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload.")
                self.dit_cpu_offload = False

        # Validate mode and inference_mode consistency
        assert isinstance(self.mode, ExecutionMode), f"Mode must be an ExecutionMode enum, got {type(self.mode)}"
        assert self.mode in ExecutionMode.choices(), f"Invalid execution mode: {self.mode}"

        # Validate workload type
        assert isinstance(self.workload_type,
                          WorkloadType), f"Workload type must be a WorkloadType enum, got {type(self.workload_type)}"
        assert self.workload_type in WorkloadType.choices(), f"Invalid workload type: {self.workload_type}"

        if self.mode in [ExecutionMode.DISTILLATION, ExecutionMode.FINETUNING] and self.inference_mode:
            logger.warning("Mode is 'training' but inference_mode is True. Setting inference_mode to False.")
            self.inference_mode = False
        elif self.mode in [ExecutionMode.INFERENCE, ExecutionMode.PREPROCESS] and not self.inference_mode:
            logger.warning("Mode is '%s' but inference_mode is False. Setting inference_mode to True.", self.mode)
            self.inference_mode = True

        if not self.inference_mode:
            assert self.hsdp_replicate_dim != -1, "hsdp_replicate_dim must be set for training"
            assert self.hsdp_shard_dim != -1, "hsdp_shard_dim must be set for training"
            assert self.sp_size != -1, "sp_size must be set for training"

        # Resolve -1 sentinels for inference defaults.
        if self.tp_size == -1:
            self.tp_size = 1
        if self.sp_size == -1:
            self.sp_size = self.num_gpus
        if self.hsdp_shard_dim == -1:
            self.hsdp_shard_dim = self.num_gpus

        assert self.sp_size <= self.num_gpus and self.num_gpus % self.sp_size == 0, "num_gpus must >= and be divisible by sp_size"
        assert self.hsdp_replicate_dim <= self.num_gpus and self.num_gpus % self.hsdp_replicate_dim == 0, "num_gpus must >= and be divisible by hsdp_replicate_dim"
        assert self.hsdp_shard_dim <= self.num_gpus and self.num_gpus % self.hsdp_shard_dim == 0, "num_gpus must >= and be divisible by hsdp_shard_dim"

        if self.num_gpus < max(self.tp_size, self.sp_size):
            self.num_gpus = max(self.tp_size, self.sp_size)

        if self.pipeline_config is None:
            raise ValueError("pipeline_config is not set in FastVideoArgs")

        self.pipeline_config.check_pipeline_config()

        # Add preprocessing config validation if needed
        if self.mode == ExecutionMode.PREPROCESS:
            if self.preprocess_config is None:
                raise ValueError("preprocess_config is not set in FastVideoArgs when mode is PREPROCESS")
            if self.preprocess_config.model_path == "":
                self.preprocess_config.model_path = self.model_path
            if not self.pipeline_config.vae_config.load_encoder:
                # Preprocessing needs the VAE encoder to produce latents.
                self.pipeline_config.vae_config.load_encoder = True
            self.preprocess_config.check_preprocess_config()
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
# Process-wide holder for the active FastVideoArgs; written by
# prepare_fastvideo_args / set_current_fastvideo_args and read by
# get_current_fastvideo_args.
_current_fastvideo_args: "FastVideoArgs | None" = None
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def prepare_fastvideo_args(argv: list[str]) -> FastVideoArgs:
    """
    Build FastVideoArgs from command-line arguments and register it globally.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The parsed inference arguments; also stored in the module-level
        ``_current_fastvideo_args`` so other modules can retrieve it via
        ``get_current_fastvideo_args``.
    """
    global _current_fastvideo_args

    parser = FlexibleArgumentParser()
    FastVideoArgs.add_cli_args(parser)
    namespace = parser.parse_args(argv)

    parsed_args = FastVideoArgs.from_cli_args(namespace)
    _current_fastvideo_args = parsed_args
    return parsed_args
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
@contextmanager
def set_current_fastvideo_args(fastvideo_args: FastVideoArgs):
    """
    Temporarily install ``fastvideo_args`` as the global fastvideo config.

    Used during model initialization. The config is kept in a module-level
    global so that any module (e.g. custom ops deciding how to dispatch) can
    read it; the previous value is restored on exit, even if the body raises.
    """
    global _current_fastvideo_args
    saved_args = _current_fastvideo_args
    try:
        _current_fastvideo_args = fastvideo_args
        yield
    finally:
        _current_fastvideo_args = saved_args
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def get_current_fastvideo_args() -> FastVideoArgs:
    """Return the globally registered FastVideoArgs, raising if none is set."""
    args = _current_fastvideo_args
    # TODO(will): may need to handle this for CI — tests that exercise custom
    # ops/modules directly may never register a config, and currently this
    # raises instead of falling back to a default.
    if args is None:
        raise ValueError("Current fastvideo args is not set.")
    return args
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
@dataclasses.dataclass
class TrainingArgs(FastVideoArgs):
    """
    Training arguments. Inherits from FastVideoArgs and adds training-specific
    arguments. If there are any conflicts, the training arguments will take
    precedence.
    """
    # --- dataset ---
    data_path: str = ""
    dataloader_num_workers: int = 0
    num_height: int = 0
    num_width: int = 0
    num_frames: int = 0

    train_batch_size: int = 0
    num_latent_t: int = 0
    group_frame: bool = False
    group_resolution: bool = False

    # text encoder & vae & diffusion model
    pretrained_model_name_or_path: str = ""

    # DMD model paths - separate paths for each network
    real_score_model_path: str = ""  # path for real score (teacher) model
    fake_score_model_path: str = ""  # path for fake score (critic) model

    # diffusion setting
    ema_decay: float = 0.0
    ema_start_step: int = 0
    training_cfg_rate: float = 0.0
    precondition_outputs: bool = False

    # validation & logs
    validation_dataset_file: str = ""
    validation_preprocessed_path: str = ""
    validation_sampling_steps: str = ""
    validation_guidance_scale: str = ""
    validation_steps: float = 0.0
    log_validation: bool = False
    trackers: list[str] = dataclasses.field(default_factory=list)
    tracker_project_name: str = ""
    wandb_run_name: str = ""
    seed: int = 0
    _loading_teacher_critic_model: bool = False

    # output
    output_dir: str = ""
    checkpoints_total_limit: int = 0
    resume_from_checkpoint: str = ""  # specify the checkpoint folder to resume from

    # optimizer & scheduler
    num_train_epochs: int = 0
    max_train_steps: int = 0
    gradient_accumulation_steps: int = 0
    learning_rate: float = 0.0
    scale_lr: bool = False
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 0
    max_grad_norm: float = 0.0
    enable_gradient_checkpointing_type: str | None = None
    selective_checkpointing: float = 0.0
    mixed_precision: str = ""
    train_sp_batch_size: int = 0
    # NOTE(review): field name contains a typo ("startegy") while the CLI flag
    # below is "--fsdp-sharding-strategy" (argparse dest: fsdp_sharding_strategy),
    # so this field is likely never populated from the CLI — confirm before use.
    fsdp_sharding_startegy: str = ""

    weighting_scheme: str = ""
    logit_mean: float = 0.0
    logit_std: float = 1.0
    mode_scale: float = 0.0

    num_euler_timesteps: int = 0
    lr_num_cycles: int = 0
    lr_power: float = 0.0
    min_lr_ratio: float = 0.5  # minimum learning rate ratio for cosine_with_min_lr scheduler
    not_apply_cfg_solver: bool = False
    distill_cfg: float = 0.0
    scheduler_type: str = ""
    linear_quadratic_threshold: float = 0.0
    linear_range: float = 0.0
    weight_decay: float = 0.0
    betas: str = "0.9,0.999"  # betas for optimizer, format: "beta1,beta2"
    use_ema: bool = False
    multi_phased_distill_schedule: str = ""
    pred_decay_weight: float = 0.0
    pred_decay_type: str = ""
    hunyuan_teacher_disable_cfg: bool = False

    # master_weight_type
    master_weight_type: str = ""

    # VSA training decay parameters
    VSA_decay_rate: float = 0.01  # decay rate -> 0.02
    VSA_decay_interval_steps: int = 1  # decay interval steps -> 50
    VSA_init_sparsity: float = 0.0  # initial sparsity (default 0, ramp from 0)
    VSA_warmup_steps: int = 0  # keep init_sparsity for this many steps before ramping

    # LoRA training parameters
    lora_rank: int | None = None
    lora_alpha: int | None = None
    lora_training: bool = False
    ltx2_first_frame_conditioning_p: float = 0.1

    # distillation args
    generator_update_interval: int = 5
    dfake_gen_update_ratio: int = 5  # self-forcing: how often to train generator vs critic
    min_timestep_ratio: float = 0.2
    max_timestep_ratio: float = 0.98
    real_score_guidance_scale: float = 3.5
    fake_score_learning_rate: float = 0.0  # separate learning rate for fake_score_transformer, if 0.0, use learning_rate
    fake_score_lr_scheduler: str = "constant"  # separate lr scheduler for fake_score_transformer, if not set, use lr_scheduler
    fake_score_betas: str = "0.9,0.999"  # betas for fake score optimizer, format: "beta1,beta2"
    training_state_checkpointing_steps: int = 0  # for resuming training
    weight_only_checkpointing_steps: int = 0  # for inference
    log_visualization: bool = False
    visualization_steps: int = 0
    # simulate generator forward to match inference
    simulate_generator_forward: bool = False
    warp_denoising_step: bool = False
    generator_4bit_attn: bool = False
    generator_4bit_linear: bool = False

    # Self-forcing specific arguments
    num_frame_per_block: int = 3
    independent_first_frame: bool = False
    enable_gradient_masking: bool = True
    gradient_mask_last_n_frames: int = 21
    same_step_across_blocks: bool = False  # Use same exit timestep for all blocks
    last_step_only: bool = False  # Only use the last timestep for training
    context_noise: int = 0  # Context noise level for cache updates

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
        """Build a TrainingArgs from a parsed argparse namespace.

        Fields absent from ``args`` fall back to the dataclass defaults;
        ``pipeline_config``, ``mode`` and ``workload_type`` get special
        construction/coercion.
        """
        provided_args = clean_cli_args(args)
        # Get all fields from the dataclass
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        logger.info(provided_args)
        # Create a dictionary of attribute values, with defaults for missing attributes
        kwargs: dict[str, Any] = {}
        for attr in attrs:
            if attr == 'pipeline_config':
                pipeline_config = PipelineConfig.from_kwargs(provided_args)
                kwargs[attr] = pipeline_config
            elif attr == 'mode':
                # Convert string to ExecutionMode enum
                mode_value = getattr(args, attr, ExecutionMode.FINETUNING.value)
                kwargs[attr] = ExecutionMode.from_string(mode_value) if isinstance(mode_value, str) else mode_value
            elif attr == 'workload_type':
                # Convert string to WorkloadType enum
                workload_type_value = getattr(args, 'workload_type', WorkloadType.T2V.value)
                kwargs[attr] = WorkloadType.from_string(workload_type_value) if isinstance(workload_type_value,
                                                                                          str) else workload_type_value
            # Use getattr with default value from the dataclass for potentially missing attributes
            else:
                # Get the field to check its default value
                # NOTE(review): this re-enumerates dataclasses.fields(cls) for
                # every attribute (O(n^2)); a name->field dict built once would
                # be linear. Left as-is here (documentation-only change).
                field = dataclasses.fields(cls)[next(i for i, f in enumerate(dataclasses.fields(cls))
                                                     if f.name == attr)]

                # Check if the attribute is provided in args
                if hasattr(args, attr):
                    value = getattr(args, attr)
                else:
                    # Use the field's default value
                    if field.default_factory is not dataclasses.MISSING:
                        value = field.default_factory()
                    elif field.default is not dataclasses.MISSING:
                        value = field.default
                    else:
                        # No default value, use None
                        value = None

                kwargs[attr] = value

        return cls(**kwargs)  # type: ignore

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Register all training-specific CLI flags on ``parser`` and return it."""
        parser.add_argument("--data-path", type=str, required=True, help="Path to parquet files")
        parser.add_argument("--dataloader-num-workers", type=int, required=True,
                            help="Number of workers for dataloader")
        parser.add_argument("--num-height", type=int, required=True, help="Number of heights")
        parser.add_argument("--num-width", type=int, required=True, help="Number of widths")
        parser.add_argument("--num-frames", type=int, required=True, help="Number of frames")

        # Training batch and model configuration
        parser.add_argument("--train-batch-size", type=int, required=True, help="Training batch size")
        parser.add_argument("--num-latent-t", type=int, required=True, help="Number of latent time steps")
        parser.add_argument("--group-frame", action=StoreBoolean, help="Whether to group frames during training")
        parser.add_argument("--group-resolution", action=StoreBoolean,
                            help="Whether to group resolutions during training")

        # Model paths
        parser.add_argument("--pretrained-model-name-or-path", type=str, required=True,
                            help="Path to pretrained model or model name")
        parser.add_argument("--dit-model-name-or-path", type=str, required=False,
                            help="Path to DiT model or model name")
        parser.add_argument("--cache-dir", type=str, help="Directory to cache models")

        # DMD model paths - separate paths for each network
        parser.add_argument("--generator-model-path", type=str,
                            help="Path to generator (student) model for DMD distillation")
        parser.add_argument("--real-score-model-path", type=str,
                            help="Path to real score (teacher) model for DMD distillation")
        parser.add_argument("--fake-score-model-path", type=str,
                            help="Path to fake score (critic) model for DMD distillation")

        # Diffusion settings
        parser.add_argument("--ema-decay", type=float, default=0.999, help="EMA decay rate")
        parser.add_argument("--ema-start-step", type=int, default=0, help="Step to start EMA")
        parser.add_argument("--training-cfg-rate", type=float, help="Classifier-free guidance scale")
        parser.add_argument("--precondition-outputs", action=StoreBoolean,
                            help="Whether to precondition the outputs of the model")

        # Validation and logging
        parser.add_argument("--validation-dataset-file", type=str, help="Path to unprocessed validation dataset")
        parser.add_argument("--validation-preprocessed-path", type=str, help="Path to processed validation dataset")
        parser.add_argument("--validation-sampling-steps", type=str, help="Validation sampling steps")
        parser.add_argument("--validation-guidance-scale", type=str, help="Validation guidance scale")
        parser.add_argument("--validation-steps", type=float, help="Number of validation steps")
        parser.add_argument("--log-validation", action=StoreBoolean, help="Whether to log validation results")
        parser.add_argument("--visualization-steps", type=int, help="Number of visualization steps")
        parser.add_argument("--tracker-project-name", type=str, help="Project name for tracking")
        parser.add_argument("--wandb-run-name", type=str, help="Run name for wandb")
        parser.add_argument("--seed", type=int, default=42, help="Seed for deterministic training")

        # Output configuration
        parser.add_argument("--output-dir", type=str, required=True,
                            help="Output directory for checkpoints and logs")
        parser.add_argument("--checkpoints-total-limit", type=int, help="Maximum number of checkpoints to keep")
        parser.add_argument("--training-state-checkpointing-steps", type=int,
                            help="Steps between training state checkpoints (for resuming training)")
        parser.add_argument("--weight-only-checkpointing-steps", type=int,
                            help="Steps between weight-only checkpoints (for inference)")
        parser.add_argument("--resume-from-checkpoint", type=str, help="Path to checkpoint to resume from")
        parser.add_argument("--logging-dir", type=str, help="Directory for logging")

        # Training configuration
        parser.add_argument("--num-train-epochs", type=int, help="Number of training epochs")
        parser.add_argument("--max-train-steps", type=int, help="Maximum number of training steps")
        parser.add_argument("--gradient-accumulation-steps", type=int,
                            help="Number of steps to accumulate gradients")
        parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
        parser.add_argument("--scale-lr", action=StoreBoolean, help="Whether to scale learning rate")
        parser.add_argument("--lr-scheduler", type=str, default="constant", help="Learning rate scheduler type")
        parser.add_argument("--lr-warmup-steps", type=int, default=10,
                            help="Number of warmup steps for learning rate")
        parser.add_argument("--max-grad-norm", type=float, help="Maximum gradient norm")
        parser.add_argument("--enable-gradient-checkpointing-type", type=str,
                            choices=["full", "ops", "block_skip"], default=None,
                            help="Gradient checkpointing type")
        parser.add_argument("--selective-checkpointing", type=float, help="Selective checkpointing threshold")
        parser.add_argument("--mixed-precision", type=str, help="Mixed precision training type")
        parser.add_argument("--train-sp-batch-size", type=int, help="Training spatial parallelism batch size")

        # NOTE(review): dest is fsdp_sharding_strategy, but the dataclass field
        # is spelled fsdp_sharding_startegy — see field note above.
        parser.add_argument("--fsdp-sharding-strategy", type=str, help="FSDP sharding strategy")

        # NOTE: these four flags use underscores (not dashes) in the original.
        parser.add_argument(
            "--weighting_scheme",
            type=str,
            default="uniform",
            choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
        )
        parser.add_argument(
            "--logit_mean",
            type=float,
            default=0.0,
            help="mean to use when using the `'logit_normal'` weighting scheme.",
        )
        parser.add_argument(
            "--logit_std",
            type=float,
            default=1.0,
            help="std to use when using the `'logit_normal'` weighting scheme.",
        )
        parser.add_argument(
            "--mode_scale",
            type=float,
            default=1.29,
            help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
        )

        # Additional training parameters
        parser.add_argument("--num-euler-timesteps", type=int, help="Number of Euler timesteps")
        parser.add_argument("--lr-num-cycles", type=int, help="Number of learning rate cycles")
        parser.add_argument("--lr-power", type=float, help="Learning rate power")
        parser.add_argument("--min-lr-ratio", type=float, default=TrainingArgs.min_lr_ratio,
                            help="Minimum learning rate ratio for cosine_with_min_lr scheduler")
        parser.add_argument("--not-apply-cfg-solver", action=StoreBoolean,
                            help="Whether to not apply CFG solver")
        parser.add_argument("--distill-cfg", type=float, help="Distillation CFG scale")
        parser.add_argument("--scheduler-type", type=str, help="Scheduler type")
        parser.add_argument("--linear-quadratic-threshold", type=float, help="Linear quadratic threshold")
        parser.add_argument("--linear-range", type=float, help="Linear range")
        parser.add_argument("--weight-decay", type=float, help="Weight decay")
        parser.add_argument("--betas", type=str, default=TrainingArgs.betas,
                            help="Betas for optimizer (format: 'beta1,beta2')")
        parser.add_argument("--use-ema", action=StoreBoolean, help="Whether to use EMA")
        parser.add_argument("--multi-phased-distill-schedule", type=str,
                            help="Multi-phased distillation schedule")
        parser.add_argument("--pred-decay-weight", type=float, help="Prediction decay weight")
        parser.add_argument("--pred-decay-type", type=str, help="Prediction decay type")
        parser.add_argument("--hunyuan-teacher-disable-cfg", action=StoreBoolean,
                            help="Whether to disable CFG for Hunyuan teacher")
        parser.add_argument("--master-weight-type", type=str, help="Master weight type")

        # VSA parameters for training with dense to sparse adaption
        parser.add_argument(
            "--VSA-decay-rate",  # decay rate, how much sparsity you want to decay each step
            type=float,
            default=TrainingArgs.VSA_decay_rate,
            help="VSA decay rate")
        parser.add_argument(
            "--VSA-decay-interval-steps",  # how many steps for training with current sparsity
            type=int,
            default=TrainingArgs.VSA_decay_interval_steps,
            help="VSA decay interval steps")
        parser.add_argument(
            "--VSA-init-sparsity",
            type=float,
            default=TrainingArgs.VSA_init_sparsity,
            help="Initial sparsity to start from (default 0)")
        parser.add_argument(
            "--VSA-warmup-steps",
            type=int,
            default=TrainingArgs.VSA_warmup_steps,
            help="Keep init sparsity for N steps before ramping (default 0)")
        parser.add_argument("--lora-training", action=StoreBoolean, help="Whether to use LoRA training")
        parser.add_argument("--lora-rank", type=int, help="LoRA rank")
        parser.add_argument("--lora-alpha", type=int, help="LoRA alpha")
        parser.add_argument(
            "--ltx2-first-frame-conditioning-p",
            type=float,
            default=TrainingArgs.ltx2_first_frame_conditioning_p,
            help="Probability of conditioning on the first frame during LTX-2 training",
        )

        # V-MoBA parameters
        parser.add_argument(
            "--moba-config-path",
            type=str,
            default=None,
            help="Path to a JSON file containing V-MoBA specific configurations.",
        )

        # Distillation arguments
        parser.add_argument("--generator-update-interval", type=int,
                            default=TrainingArgs.generator_update_interval,
                            help="Ratio of student updates to critic updates.")
        parser.add_argument(
            "--dfake-gen-update-ratio",
            type=int,
            default=TrainingArgs.dfake_gen_update_ratio,
            help="Self-forcing: How often to train generator vs critic (train generator every N steps).")
        parser.add_argument("--min-timestep-ratio", type=float,
                            default=TrainingArgs.min_timestep_ratio,
                            help="Minimum step ratio")
        parser.add_argument("--max-timestep-ratio", type=float,
                            default=TrainingArgs.max_timestep_ratio,
                            help="Maximum step ratio")
        parser.add_argument("--real-score-guidance-scale", type=float,
                            default=TrainingArgs.real_score_guidance_scale,
                            help="Teacher guidance scale")
        parser.add_argument("--fake-score-learning-rate", type=float,
                            default=TrainingArgs.fake_score_learning_rate,
                            help="Learning rate for fake score transformer")
        parser.add_argument("--fake-score-betas", type=str,
                            default=TrainingArgs.fake_score_betas,
                            help="Betas for fake score optimizer (format: 'beta1,beta2')")
        parser.add_argument("--fake-score-lr-scheduler", type=str,
                            default=TrainingArgs.fake_score_lr_scheduler,
                            help="Learning rate scheduler for fake score transformer")
        parser.add_argument("--log-visualization", action=StoreBoolean, help="Whether to log visualization")
        parser.add_argument("--simulate-generator-forward", action=StoreBoolean,
                            help="Whether to simulate generator forward to match inference")
        parser.add_argument("--warp-denoising-step", action=StoreBoolean,
                            help="Whether to warp denoising step according to the scheduler time shift")

        # Self-forcing specific arguments
        parser.add_argument("--num-frame-per-block", type=int,
                            default=TrainingArgs.num_frame_per_block,
                            help="Number of frames per block for causal generation")
        parser.add_argument("--independent-first-frame", action=StoreBoolean,
                            help="Whether the first frame is independent in causal generation")
        parser.add_argument("--enable-gradient-masking", action=StoreBoolean,
                            help="Whether to enable frame-level gradient masking")
        parser.add_argument("--gradient-mask-last-n-frames", type=int,
                            default=TrainingArgs.gradient_mask_last_n_frames,
                            help="Number of last frames to enable gradients for")
        # NOTE(review): --validate-cache-structure has no matching dataclass
        # field in this class — presumably consumed elsewhere; confirm.
        parser.add_argument("--validate-cache-structure", action=StoreBoolean,
                            help="Whether to validate KV cache structure (debug flag)")
        parser.add_argument("--same-step-across-blocks", action=StoreBoolean,
                            help="Whether to use the same exit timestep for all blocks")
        parser.add_argument("--last-step-only", action=StoreBoolean,
                            help="Whether to only use the last timestep for training")
        parser.add_argument("--context-noise", type=int,
                            default=TrainingArgs.context_noise,
                            help="Context noise level for cache updates")

        return parser
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
def parse_int_list(value: str) -> list[int]:
    """Parse a comma-separated string of integers into a list.

    Falsy input (empty string) yields an empty list. Blank segments produced
    by stray or trailing commas (e.g. ``"1,2,"``) are skipped instead of
    raising ``ValueError`` — a backward-compatible robustness fix, since such
    input previously crashed on ``int("")``.

    Args:
        value: Comma-separated integers, e.g. ``"1, 2, 3"``. Whitespace
            around each number is ignored.

    Returns:
        The parsed integers, in input order.

    Raises:
        ValueError: If a non-blank segment is not a valid integer.
    """
    if not value:
        return []
    return [int(part) for part in (seg.strip() for seg in value.split(",")) if part]
|
standalone_inference/overlay_files/fastvideo/forward_context.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import TYPE_CHECKING, Optional
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from fastvideo.logger import init_logger
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from fastvideo.attention import AttentionMetadata
|
| 16 |
+
from fastvideo.pipelines import ForwardBatch
|
| 17 |
+
|
| 18 |
+
logger = init_logger(__name__)
|
| 19 |
+
|
| 20 |
+
# Module-level knobs and accumulators for optional per-batchsize forward-time
# profiling, used by set_forward_context() below.
# TODO(will): check if this is needed
# track_batchsize: bool = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL >= 0
track_batchsize: bool = False  # master switch: when False, no timing is recorded
last_logging_time: float = 0  # perf_counter timestamp of the last stats log
forward_start_time: float = 0  # perf_counter timestamp at forward entry
# batchsize_logging_interval: float = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL
batchsize_logging_interval: float = 1000  # seconds between stats log lines
# Maps batchsize -> list of elapsed forward times in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
#
@dataclass
class ForwardContext:
    """Per-forward-pass state shared with all modules via set_forward_context()."""
    # Denoising timestep index for the current forward pass.
    current_timestep: int
    # TODO(will): check this arg
    # copy from vllm_config.compilation_config.static_forward_context
    # attn_layers: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
    # Batch being processed, when the caller provides one.
    forward_batch: Optional["ForwardBatch"] = None
    # When True, attention backends are asked to skip sparse paths — presumably
    # forcing dense attention; confirm against backend implementations.
    force_dense: bool = False


# The currently active context; None outside a set_forward_context() scope.
_forward_context: Optional["ForwardContext"] = None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_forward_context() -> "ForwardContext":
    """Return the currently active forward context."""
    ctx = _forward_context
    assert ctx is not None, ("Forward context is not set. "
                             "Please use `set_forward_context` to set the forward context.")
    return ctx
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# TODO(will): finalize the interface
@contextmanager
def set_forward_context(current_timestep, attn_metadata, forward_batch: Optional["ForwardBatch"] = None, force_dense: bool = False):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    The previous context is saved and restored on exit, so scopes may nest.
    When the module-level ``track_batchsize`` flag is on and ``attn_metadata``
    is provided, the elapsed forward time is recorded per batchsize and
    median stats are logged every ``batchsize_logging_interval`` seconds.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    global _forward_context
    prev_context = _forward_context
    _forward_context = ForwardContext(current_timestep=current_timestep,
                                      attn_metadata=attn_metadata,
                                      forward_batch=forward_batch,
                                      force_dense=force_dense)

    try:
        yield
    finally:
        global last_logging_time, batchsize_logging_interval
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = attn_metadata.num_input_tokens
            now = time.perf_counter()
            # time measurement is in milliseconds
            batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # can be cudagraph / profiling run
                        continue
                    # median of the recorded times for this batchsize
                    medium = torch.quantile(torch.tensor(times), q=0.5).item()
                    medium = round(medium, 2)
                    forward_stats.append((bs, len(times), medium))
                # most-frequently-seen batchsizes first
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"), forward_stats)
        # Restore the previous (possibly None) context even if the body raised.
        _forward_context = prev_context
|
standalone_inference/overlay_files/fastvideo/pipelines/basic/wan/__init__.py
ADDED
|
File without changes
|
standalone_inference/overlay_files/fastvideo/pipelines/basic/wan/wan_pipeline.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
Wan video diffusion pipeline implementation.
|
| 4 |
+
|
| 5 |
+
This module contains an implementation of the Wan video diffusion pipeline
|
| 6 |
+
using the modular pipeline architecture.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from fastvideo.fastvideo_args import FastVideoArgs
|
| 10 |
+
from fastvideo.logger import init_logger
|
| 11 |
+
from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (FlowUniPCMultistepScheduler)
|
| 12 |
+
from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline
|
| 13 |
+
from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, DenoisingStage, InputValidationStage,
|
| 14 |
+
LatentPreparationStage, TextEncodingStage, TimestepPreparationStage)
|
| 15 |
+
|
| 16 |
+
logger = init_logger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class WanPipeline(LoRAPipeline, ComposedPipelineBase):
    """
    Wan video diffusion pipeline with LoRA support.

    Composes the standard text-to-video stage sequence (validation, prompt
    encoding, conditioning, timestep/latent preparation, denoising, decoding)
    around the modules loaded from the model config.
    """

    # Modules that must be present in the model config for this pipeline.
    _required_config_modules = ["text_encoder", "tokenizer", "vae", "transformer", "scheduler"]

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        """Swap in Wan's own flow-matching UniPC scheduler."""
        # We use UniPCMScheduler from Wan2.1 official repo, not the one in diffusers.
        self.modules["scheduler"] = FlowUniPCMultistepScheduler(shift=fastvideo_args.pipeline_config.flow_shift)

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
        """Set up pipeline stages with proper dependency injection."""

        self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

        self.add_stage(stage_name="prompt_encoding_stage",
                       stage=TextEncodingStage(
                           text_encoders=[self.get_module("text_encoder")],
                           tokenizers=[self.get_module("tokenizer")],
                       ))

        self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

        self.add_stage(stage_name="timestep_preparation_stage",
                       stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

        self.add_stage(stage_name="latent_preparation_stage",
                       stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                    transformer=self.get_module("transformer", None)))

        # transformer_2 is optional (looked up with a None default); presumably
        # only present for dual-transformer variants — confirm against loader.
        self.add_stage(stage_name="denoising_stage",
                       stage=DenoisingStage(transformer=self.get_module("transformer"),
                                            transformer_2=self.get_module("transformer_2", None),
                                            scheduler=self.get_module("scheduler"),
                                            vae=self.get_module("vae"),
                                            pipeline=self))

        self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))


# Entry point name the pipeline loader looks up in this module.
EntryClass = WanPipeline
|
standalone_inference/overlay_files/fastvideo/pipelines/composed_pipeline_base.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
Base class for composed pipelines.
|
| 4 |
+
|
| 5 |
+
This module defines the base class for pipelines that are composed of multiple stages.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import os
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
from typing import Any, cast
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from fastvideo.configs.pipelines import PipelineConfig
|
| 16 |
+
from fastvideo.distributed import (maybe_init_distributed_environment_and_model_parallel, get_world_group)
|
| 17 |
+
from fastvideo.distributed.communication_op import (warmup_sequence_parallel_communication)
|
| 18 |
+
from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
|
| 19 |
+
from fastvideo.logger import init_logger
|
| 20 |
+
from fastvideo.profiler import get_or_create_profiler
|
| 21 |
+
from fastvideo.models.loader.component_loader import PipelineComponentLoader
|
| 22 |
+
from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
|
| 23 |
+
from fastvideo.pipelines.stages import PipelineStage
|
| 24 |
+
import fastvideo.envs as envs
|
| 25 |
+
from fastvideo.utils import (maybe_download_model, verify_model_config_and_directory)
|
| 26 |
+
|
| 27 |
+
logger = init_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ComposedPipelineBase(ABC):
|
| 31 |
+
"""
|
| 32 |
+
Base class for pipelines composed of multiple stages.
|
| 33 |
+
|
| 34 |
+
This class provides the framework for creating pipelines by composing multiple
|
| 35 |
+
stages together. Each stage is responsible for a specific part of the diffusion
|
| 36 |
+
process, and the pipeline orchestrates the execution of these stages.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
is_video_pipeline: bool = False # To be overridden by video pipelines
|
| 40 |
+
_required_config_modules: list[str] = []
|
| 41 |
+
_extra_config_module_map: dict[str, str] = {}
|
| 42 |
+
training_args: Any = None
|
| 43 |
+
fastvideo_args: Any = None
|
| 44 |
+
modules: dict[str, Any] = {}
|
| 45 |
+
# do not need to include moe related transformers
|
| 46 |
+
trainable_transformer_names: list[str] = ["transformer"]
|
| 47 |
+
trainable_transformer_modules: dict[str, torch.nn.Module] = {}
|
| 48 |
+
post_init_called: bool = False
|
| 49 |
+
|
| 50 |
+
# TODO(will): args should support both inference args and training args
|
| 51 |
+
def __init__(self,
|
| 52 |
+
model_path: str,
|
| 53 |
+
fastvideo_args: FastVideoArgs | TrainingArgs,
|
| 54 |
+
required_config_modules: list[str] | None = None,
|
| 55 |
+
loaded_modules: dict[str, torch.nn.Module] | None = None):
|
| 56 |
+
"""
|
| 57 |
+
Initialize the pipeline. After __init__, the pipeline should be ready to
|
| 58 |
+
use. The pipeline should be stateless and not hold any batch state.
|
| 59 |
+
"""
|
| 60 |
+
self.fastvideo_args = fastvideo_args
|
| 61 |
+
|
| 62 |
+
self.model_path: str = model_path
|
| 63 |
+
self._stages: list[PipelineStage] = []
|
| 64 |
+
self._stage_name_mapping: dict[str, PipelineStage] = {}
|
| 65 |
+
|
| 66 |
+
if required_config_modules is not None:
|
| 67 |
+
self._required_config_modules = required_config_modules
|
| 68 |
+
|
| 69 |
+
if self._required_config_modules is None:
|
| 70 |
+
raise NotImplementedError("Subclass must set _required_config_modules")
|
| 71 |
+
|
| 72 |
+
maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)
|
| 73 |
+
|
| 74 |
+
# Torch profiler. Enabled and configured through env vars:
|
| 75 |
+
# FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
|
| 76 |
+
trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
|
| 77 |
+
self.profiler_controller = get_or_create_profiler(trace_dir)
|
| 78 |
+
self.profiler = self.profiler_controller.profiler
|
| 79 |
+
|
| 80 |
+
self.local_rank = get_world_group().local_rank
|
| 81 |
+
|
| 82 |
+
# Load modules directly in initialization
|
| 83 |
+
logger.info("Loading pipeline modules...")
|
| 84 |
+
with self.profiler_controller.region("profiler_region_model_loading"):
|
| 85 |
+
self.modules = self.load_modules(fastvideo_args, loaded_modules)
|
| 86 |
+
|
| 87 |
+
def set_trainable(self) -> None:
|
| 88 |
+
# Only train DiT
|
| 89 |
+
if getattr(self.fastvideo_args, "training_mode", False):
|
| 90 |
+
for name, module in self.trainable_transformer_modules.items():
|
| 91 |
+
logger.info("Setting %s to requires_grad=True", name)
|
| 92 |
+
if not isinstance(module, torch.nn.Module):
|
| 93 |
+
logger.info("Skipping %s because it is not a torch.nn.Module", name)
|
| 94 |
+
continue
|
| 95 |
+
module.requires_grad_(True)
|
| 96 |
+
module.train()
|
| 97 |
+
|
| 98 |
+
@staticmethod
|
| 99 |
+
def _compile_with_conditions(
|
| 100 |
+
module: torch.nn.Module,
|
| 101 |
+
compile_kwargs: dict[str, Any],
|
| 102 |
+
) -> int:
|
| 103 |
+
"""Compile submodules that match module._compile_conditions."""
|
| 104 |
+
compile_conditions = getattr(module, "_compile_conditions", None)
|
| 105 |
+
if not compile_conditions:
|
| 106 |
+
return 0
|
| 107 |
+
|
| 108 |
+
compiled_count = 0
|
| 109 |
+
for name, submodule in module.named_modules():
|
| 110 |
+
if not name:
|
| 111 |
+
continue
|
| 112 |
+
if any(cond(name, submodule) for cond in compile_conditions):
|
| 113 |
+
submodule.forward = torch.compile(submodule.forward, **compile_kwargs)
|
| 114 |
+
compiled_count += 1
|
| 115 |
+
return compiled_count
|
| 116 |
+
|
| 117 |
+
def _maybe_compile_pipeline_module(
|
| 118 |
+
self,
|
| 119 |
+
module_name: str,
|
| 120 |
+
fsdp_module_cls: type | None,
|
| 121 |
+
compile_kwargs: dict[str, Any],
|
| 122 |
+
) -> None:
|
| 123 |
+
if module_name not in self.modules:
|
| 124 |
+
return
|
| 125 |
+
|
| 126 |
+
module = self.modules[module_name]
|
| 127 |
+
if fsdp_module_cls is not None and isinstance(module, fsdp_module_cls):
|
| 128 |
+
logger.info(
|
| 129 |
+
"%s is already FSDP-wrapped; skipping torch.compile in pipeline",
|
| 130 |
+
module_name.capitalize(),
|
| 131 |
+
)
|
| 132 |
+
return
|
| 133 |
+
|
| 134 |
+
compiled_count = self._compile_with_conditions(module, compile_kwargs)
|
| 135 |
+
if compiled_count > 0:
|
| 136 |
+
logger.info(
|
| 137 |
+
"Enabled torch.compile for %d submodules in %s via _compile_conditions with kwargs=%s",
|
| 138 |
+
compiled_count,
|
| 139 |
+
module_name,
|
| 140 |
+
compile_kwargs,
|
| 141 |
+
)
|
| 142 |
+
return
|
| 143 |
+
|
| 144 |
+
# Backward-compatible fallback: compile full module if no condition matched.
|
| 145 |
+
logger.info("Enabling torch.compile for %s with kwargs=%s", module_name, compile_kwargs)
|
| 146 |
+
self.modules[module_name] = torch.compile(module, **compile_kwargs)
|
| 147 |
+
|
| 148 |
+
def post_init(self) -> None:
|
| 149 |
+
assert self.fastvideo_args is not None, "fastvideo_args must be set"
|
| 150 |
+
if self.post_init_called:
|
| 151 |
+
return
|
| 152 |
+
self.post_init_called = True
|
| 153 |
+
if self.fastvideo_args.training_mode:
|
| 154 |
+
assert isinstance(self.fastvideo_args, TrainingArgs)
|
| 155 |
+
self.training_args = self.fastvideo_args
|
| 156 |
+
assert self.training_args is not None
|
| 157 |
+
self.initialize_training_pipeline(self.training_args)
|
| 158 |
+
if self.training_args.log_validation:
|
| 159 |
+
self.initialize_validation_pipeline(self.training_args)
|
| 160 |
+
|
| 161 |
+
self.initialize_pipeline(self.fastvideo_args)
|
| 162 |
+
if self.fastvideo_args.enable_torch_compile:
|
| 163 |
+
if self.fastvideo_args.training_mode:
|
| 164 |
+
logger.info("Torch Compile enabled via FSDP loader for training; skipping additional pipeline compile")
|
| 165 |
+
else:
|
| 166 |
+
fsdp_module_cls = None
|
| 167 |
+
try:
|
| 168 |
+
from torch.distributed.fsdp import FSDPModule # type: ignore
|
| 169 |
+
fsdp_module_cls = FSDPModule
|
| 170 |
+
except Exception: # pragma: no cover - FSDP not always available
|
| 171 |
+
fsdp_module_cls = None
|
| 172 |
+
|
| 173 |
+
compile_kwargs = self.fastvideo_args.torch_compile_kwargs or {}
|
| 174 |
+
self._maybe_compile_pipeline_module(
|
| 175 |
+
module_name="transformer",
|
| 176 |
+
fsdp_module_cls=fsdp_module_cls,
|
| 177 |
+
compile_kwargs=compile_kwargs,
|
| 178 |
+
)
|
| 179 |
+
self._maybe_compile_pipeline_module(
|
| 180 |
+
module_name="transformer_2",
|
| 181 |
+
fsdp_module_cls=fsdp_module_cls,
|
| 182 |
+
compile_kwargs=compile_kwargs,
|
| 183 |
+
)
|
| 184 |
+
logger.info("Torch Compile enabled for DiT")
|
| 185 |
+
|
| 186 |
+
if not self.fastvideo_args.training_mode:
|
| 187 |
+
logger.info("Creating pipeline stages...")
|
| 188 |
+
self.create_pipeline_stages(self.fastvideo_args)
|
| 189 |
+
|
| 190 |
+
# Warmup NCCL communicators for sequence parallelism to avoid
|
| 191 |
+
# slow first forward pass due to lazy initialization
|
| 192 |
+
warmup_sequence_parallel_communication()
|
| 193 |
+
|
| 194 |
+
def initialize_training_pipeline(self, training_args: TrainingArgs):
|
| 195 |
+
raise NotImplementedError("if training_mode is True, the pipeline must implement this method")
|
| 196 |
+
|
| 197 |
+
def initialize_validation_pipeline(self, training_args: TrainingArgs):
|
| 198 |
+
raise NotImplementedError("if log_validation is True, the pipeline must implement this method")
|
| 199 |
+
|
| 200 |
+
@classmethod
|
| 201 |
+
def from_pretrained(cls,
|
| 202 |
+
model_path: str,
|
| 203 |
+
device: str | None = None,
|
| 204 |
+
torch_dtype: torch.dtype | None = None,
|
| 205 |
+
pipeline_config: str | PipelineConfig | None = None,
|
| 206 |
+
args: argparse.Namespace | FastVideoArgs | TrainingArgs | None = None,
|
| 207 |
+
required_config_modules: list[str] | None = None,
|
| 208 |
+
loaded_modules: dict[str, torch.nn.Module]
|
| 209 |
+
| None = None,
|
| 210 |
+
**kwargs) -> "ComposedPipelineBase":
|
| 211 |
+
"""
|
| 212 |
+
Load a pipeline from a pretrained model.
|
| 213 |
+
loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
|
| 214 |
+
If provided, loaded_modules will be used instead of loading from config/pretrained weights.
|
| 215 |
+
"""
|
| 216 |
+
if args is None or (isinstance(args, FastVideoArgs) and args.inference_mode):
|
| 217 |
+
|
| 218 |
+
kwargs['model_path'] = model_path
|
| 219 |
+
fastvideo_args = FastVideoArgs.from_kwargs(**kwargs)
|
| 220 |
+
else:
|
| 221 |
+
if isinstance(args, TrainingArgs):
|
| 222 |
+
fastvideo_args = args
|
| 223 |
+
else:
|
| 224 |
+
assert isinstance(args, argparse.Namespace), "training mode expects argparse.Namespace args"
|
| 225 |
+
fastvideo_args = TrainingArgs.from_cli_args(args)
|
| 226 |
+
# TODO(will): fix this so that its not so ugly
|
| 227 |
+
fastvideo_args.model_path = model_path
|
| 228 |
+
for key, value in kwargs.items():
|
| 229 |
+
setattr(fastvideo_args, key, value)
|
| 230 |
+
|
| 231 |
+
fastvideo_args.dit_cpu_offload = False
|
| 232 |
+
# we hijack the precision to be the master weight type so that the
|
| 233 |
+
# model is loaded with the correct precision. Subsequently we will
|
| 234 |
+
# use FSDP2's MixedPrecisionPolicy to set the precision for the
|
| 235 |
+
# fwd, bwd, and other operations' precision.
|
| 236 |
+
assert fastvideo_args.pipeline_config.dit_precision == 'fp32', 'only fp32 is supported for training'
|
| 237 |
+
|
| 238 |
+
logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)
|
| 239 |
+
|
| 240 |
+
pipe = cls(model_path,
|
| 241 |
+
fastvideo_args,
|
| 242 |
+
required_config_modules=required_config_modules,
|
| 243 |
+
loaded_modules=loaded_modules)
|
| 244 |
+
pipe.post_init()
|
| 245 |
+
return pipe
|
| 246 |
+
|
| 247 |
+
def get_module(self, module_name: str, default_value: Any = None) -> Any:
|
| 248 |
+
if module_name not in self.modules:
|
| 249 |
+
return default_value
|
| 250 |
+
return self.modules[module_name]
|
| 251 |
+
|
| 252 |
+
def add_module(self, module_name: str, module: Any):
|
| 253 |
+
self.modules[module_name] = module
|
| 254 |
+
|
| 255 |
+
def __getattr__(self, name: str) -> Any:
|
| 256 |
+
if "_stage_name_mapping" in self.__dict__ and name in self._stage_name_mapping:
|
| 257 |
+
return self._stage_name_mapping[name]
|
| 258 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
| 259 |
+
|
| 260 |
+
def _load_config(self, model_path: str) -> dict[str, Any]:
|
| 261 |
+
model_path = maybe_download_model(self.model_path)
|
| 262 |
+
self.model_path = model_path
|
| 263 |
+
# fastvideo_args.downloaded_model_path = model_path
|
| 264 |
+
logger.info("Model path: %s", model_path)
|
| 265 |
+
config = verify_model_config_and_directory(model_path)
|
| 266 |
+
return cast(dict[str, Any], config)
|
| 267 |
+
|
| 268 |
+
@property
|
| 269 |
+
def required_config_modules(self) -> list[str]:
|
| 270 |
+
"""
|
| 271 |
+
List of modules that are required by the pipeline. The names should match
|
| 272 |
+
the diffusers directory and model_index.json file. These modules will be
|
| 273 |
+
loaded using the PipelineComponentLoader and made available in the
|
| 274 |
+
modules dictionary. Access these modules using the get_module method.
|
| 275 |
+
|
| 276 |
+
class ConcretePipeline(ComposedPipelineBase):
|
| 277 |
+
_required_config_modules = ["vae", "text_encoder", "transformer", "scheduler", "tokenizer"]
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@property
|
| 281 |
+
def required_config_modules(self):
|
| 282 |
+
return self._required_config_modules
|
| 283 |
+
"""
|
| 284 |
+
return self._required_config_modules
|
| 285 |
+
|
| 286 |
+
@property
|
| 287 |
+
def stages(self) -> list[PipelineStage]:
|
| 288 |
+
"""
|
| 289 |
+
List of stages in the pipeline.
|
| 290 |
+
"""
|
| 291 |
+
return self._stages
|
| 292 |
+
|
| 293 |
+
@abstractmethod
|
| 294 |
+
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
|
| 295 |
+
"""
|
| 296 |
+
Create the inference pipeline stages.
|
| 297 |
+
"""
|
| 298 |
+
raise NotImplementedError
|
| 299 |
+
|
| 300 |
+
def create_training_stages(self, training_args: TrainingArgs):
|
| 301 |
+
"""
|
| 302 |
+
Create the training pipeline stages.
|
| 303 |
+
"""
|
| 304 |
+
raise NotImplementedError
|
| 305 |
+
|
| 306 |
+
def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
|
| 307 |
+
"""
|
| 308 |
+
Initialize the pipeline.
|
| 309 |
+
"""
|
| 310 |
+
return
|
| 311 |
+
|
| 312 |
+
def load_modules(self,
|
| 313 |
+
fastvideo_args: FastVideoArgs,
|
| 314 |
+
loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]:
|
| 315 |
+
"""
|
| 316 |
+
Load the modules from the config.
|
| 317 |
+
loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
|
| 318 |
+
If provided, loaded_modules will be used instead of loading from config/pretrained weights.
|
| 319 |
+
"""
|
| 320 |
+
|
| 321 |
+
model_index = self._load_config(self.model_path)
|
| 322 |
+
logger.info("Loading pipeline modules from config: %s", model_index)
|
| 323 |
+
|
| 324 |
+
# remove keys that are not pipeline modules
|
| 325 |
+
model_index.pop("_class_name")
|
| 326 |
+
model_index.pop("_diffusers_version")
|
| 327 |
+
model_index.pop("_name_or_path", None)
|
| 328 |
+
model_index.pop("workload_type", None)
|
| 329 |
+
if "boundary_ratio" in model_index and model_index["boundary_ratio"] is not None:
|
| 330 |
+
logger.info("MoE pipeline detected. Adding transformer_2 to self.required_config_modules...")
|
| 331 |
+
self.required_config_modules.append("transformer_2")
|
| 332 |
+
logger.info("MoE pipeline detected. Setting boundary ratio to %s", model_index["boundary_ratio"])
|
| 333 |
+
fastvideo_args.pipeline_config.dit_config.boundary_ratio = model_index["boundary_ratio"]
|
| 334 |
+
|
| 335 |
+
model_index.pop("boundary_ratio", None)
|
| 336 |
+
# used by Wan2.2 ti2v
|
| 337 |
+
model_index.pop("expand_timesteps", None)
|
| 338 |
+
|
| 339 |
+
# some sanity checks
|
| 340 |
+
assert len(model_index) > 1, "model_index.json must contain at least one pipeline module"
|
| 341 |
+
|
| 342 |
+
for module_name in self.required_config_modules:
|
| 343 |
+
if module_name not in model_index and module_name in self._extra_config_module_map:
|
| 344 |
+
extra_module_value = self._extra_config_module_map[module_name]
|
| 345 |
+
logger.warning(
|
| 346 |
+
"model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.",
|
| 347 |
+
module_name, module_name, extra_module_value)
|
| 348 |
+
if extra_module_value in model_index:
|
| 349 |
+
logger.info("Using module %s for %s", extra_module_value, module_name)
|
| 350 |
+
model_index[module_name] = model_index[extra_module_value]
|
| 351 |
+
continue
|
| 352 |
+
else:
|
| 353 |
+
raise ValueError(
|
| 354 |
+
f"Required module key: {module_name} value: {model_index.get(module_name)} was not found in loaded modules {model_index.keys()}"
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
# all the component models used by the pipeline
|
| 358 |
+
required_modules = self.required_config_modules
|
| 359 |
+
logger.info("Loading required modules: %s", required_modules)
|
| 360 |
+
|
| 361 |
+
modules = {}
|
| 362 |
+
for module_name, module_spec in model_index.items():
|
| 363 |
+
if not isinstance(module_spec, list | tuple):
|
| 364 |
+
logger.info(
|
| 365 |
+
"Skipping non-module config entry %s=%s",
|
| 366 |
+
module_name,
|
| 367 |
+
module_spec,
|
| 368 |
+
)
|
| 369 |
+
continue
|
| 370 |
+
if len(module_spec) < 1:
|
| 371 |
+
logger.warning(
|
| 372 |
+
"Skipping module %s due to invalid empty spec in model_index.json",
|
| 373 |
+
module_name,
|
| 374 |
+
)
|
| 375 |
+
continue
|
| 376 |
+
transformers_or_diffusers = module_spec[0]
|
| 377 |
+
if transformers_or_diffusers is None:
|
| 378 |
+
logger.warning("Module %s in model_index.json has null value, removing from required_config_modules",
|
| 379 |
+
module_name)
|
| 380 |
+
if module_name in self.required_config_modules:
|
| 381 |
+
self.required_config_modules.remove(module_name)
|
| 382 |
+
continue
|
| 383 |
+
if module_name not in required_modules:
|
| 384 |
+
logger.info("Skipping module %s", module_name)
|
| 385 |
+
continue
|
| 386 |
+
if loaded_modules is not None and module_name in loaded_modules:
|
| 387 |
+
logger.info("Using module %s already provided", module_name)
|
| 388 |
+
modules[module_name] = loaded_modules[module_name]
|
| 389 |
+
continue
|
| 390 |
+
|
| 391 |
+
# we load the module from the extra config module map if it exists
|
| 392 |
+
if module_name in self._extra_config_module_map:
|
| 393 |
+
load_module_name = self._extra_config_module_map[module_name]
|
| 394 |
+
else:
|
| 395 |
+
load_module_name = module_name
|
| 396 |
+
|
| 397 |
+
component_model_path = os.path.join(self.model_path, load_module_name)
|
| 398 |
+
module = PipelineComponentLoader.load_module(
|
| 399 |
+
module_name=load_module_name,
|
| 400 |
+
component_model_path=component_model_path,
|
| 401 |
+
transformers_or_diffusers=transformers_or_diffusers,
|
| 402 |
+
fastvideo_args=fastvideo_args,
|
| 403 |
+
)
|
| 404 |
+
logger.info("Loaded module %s from %s", module_name, component_model_path)
|
| 405 |
+
|
| 406 |
+
if module_name in modules:
|
| 407 |
+
logger.warning("Overwriting module %s", module_name)
|
| 408 |
+
modules[module_name] = module
|
| 409 |
+
|
| 410 |
+
# Check if all required modules were loaded
|
| 411 |
+
for module_name in required_modules:
|
| 412 |
+
if module_name not in modules or modules[module_name] is None:
|
| 413 |
+
raise ValueError(
|
| 414 |
+
f"Required module key: {module_name} value: {modules.get(module_name)} was not found in loaded modules {modules.keys()}"
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
return modules
|
| 418 |
+
|
| 419 |
+
def add_stage(self, stage_name: str, stage: PipelineStage):
|
| 420 |
+
assert self.modules is not None, "No modules are registered"
|
| 421 |
+
self._stages.append(stage)
|
| 422 |
+
self._stage_name_mapping[stage_name] = stage
|
| 423 |
+
setattr(self, stage_name, stage)
|
| 424 |
+
|
| 425 |
+
def profile(self, is_start: bool = True):
|
| 426 |
+
if self.profiler is None:
|
| 427 |
+
raise RuntimeError("Profiler is not enabled.")
|
| 428 |
+
if is_start:
|
| 429 |
+
self.profiler.start()
|
| 430 |
+
else:
|
| 431 |
+
self.profiler.stop()
|
| 432 |
+
# only print profiler results on rank 0
|
| 433 |
+
if self.local_rank == 0:
|
| 434 |
+
print(self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
|
| 435 |
+
|
| 436 |
+
# TODO(will): don't hardcode no_grad
|
| 437 |
+
@torch.no_grad()
|
| 438 |
+
def forward(
|
| 439 |
+
self,
|
| 440 |
+
batch: ForwardBatch,
|
| 441 |
+
fastvideo_args: FastVideoArgs,
|
| 442 |
+
) -> ForwardBatch:
|
| 443 |
+
"""
|
| 444 |
+
Generate a video or image using the pipeline.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
batch: The batch to generate from.
|
| 448 |
+
fastvideo_args: The inference arguments.
|
| 449 |
+
Returns:
|
| 450 |
+
ForwardBatch: The batch with the generated video or image.
|
| 451 |
+
"""
|
| 452 |
+
if not self.post_init_called:
|
| 453 |
+
self.post_init()
|
| 454 |
+
|
| 455 |
+
# Execute each stage
|
| 456 |
+
logger.info("Running pipeline stages: %s", self._stage_name_mapping.keys())
|
| 457 |
+
# logger.info("Batch: %s", batch)
|
| 458 |
+
for stage in self.stages:
|
| 459 |
+
batch = stage(batch, fastvideo_args)
|
| 460 |
+
|
| 461 |
+
# Return the output
|
| 462 |
+
return batch
|
| 463 |
+
|
| 464 |
+
def train(self) -> None:
|
| 465 |
+
raise NotImplementedError("if training_mode is True, the pipeline must implement this method")
|
| 466 |
+
|
| 467 |
+
def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch:
|
| 468 |
+
raise NotImplementedError(f"{type(self).__name__} does not support streaming_reset")
|
| 469 |
+
|
| 470 |
+
def streaming_step(self, *args: Any, **kwargs: Any) -> ForwardBatch:
|
| 471 |
+
raise NotImplementedError(f"{type(self).__name__} does not support streaming_step")
|
| 472 |
+
|
| 473 |
+
def streaming_clear(self) -> None:
|
| 474 |
+
raise NotImplementedError(f"{type(self).__name__} does not support streaming_clear")
|
standalone_inference/overlay_files/fastvideo/pipelines/stages/denoising.py
ADDED
|
@@ -0,0 +1,1184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
Denoising stage for diffusion pipelines.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import inspect
|
| 7 |
+
import weakref
|
| 8 |
+
from collections.abc import Iterable
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from tqdm.auto import tqdm
|
| 13 |
+
|
| 14 |
+
from fastvideo.attention import get_attn_backend
|
| 15 |
+
from fastvideo.distributed import (get_local_torch_device, get_world_group)
|
| 16 |
+
from fastvideo.fastvideo_args import FastVideoArgs
|
| 17 |
+
from fastvideo.forward_context import set_forward_context
|
| 18 |
+
from fastvideo.logger import init_logger
|
| 19 |
+
from fastvideo.models.loader.component_loader import TransformerLoader
|
| 20 |
+
from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (FlowMatchEulerDiscreteScheduler)
|
| 21 |
+
from fastvideo.models.utils import pred_noise_to_pred_video
|
| 22 |
+
from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
|
| 23 |
+
from fastvideo.pipelines.stages.base import PipelineStage
|
| 24 |
+
from fastvideo.pipelines.stages.validators import StageValidators as V
|
| 25 |
+
from fastvideo.pipelines.stages.validators import VerificationResult
|
| 26 |
+
from fastvideo.platforms import AttentionBackendEnum
|
| 27 |
+
from fastvideo.utils import dict_to_3d_list, masks_like
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from fastvideo.attention.backends.vmoba import VMOBAAttentionBackend
|
| 31 |
+
from fastvideo.utils import is_vmoba_available
|
| 32 |
+
vmoba_attn_available = is_vmoba_available()
|
| 33 |
+
except ImportError:
|
| 34 |
+
vmoba_attn_available = False
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from fastvideo.attention.backends.video_sparse_attn import (VideoSparseAttentionBackend)
|
| 38 |
+
vsa_available = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
vsa_available = False
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from fastvideo.attention.backends.sparse_fp4_attn import (SparseFP4AttentionBackend)
|
| 44 |
+
except ImportError:
|
| 45 |
+
SparseFP4AttentionBackend = None # type: ignore[assignment]
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
from fastvideo.attention.backends.sparse_fp4_ours_p_attn import (SparseFP4OursPAttentionBackend)
|
| 49 |
+
except ImportError:
|
| 50 |
+
SparseFP4OursPAttentionBackend = None # type: ignore[assignment]
|
| 51 |
+
|
| 52 |
+
sparse_fp4_backends = tuple(
|
| 53 |
+
backend for backend in (
|
| 54 |
+
SparseFP4AttentionBackend,
|
| 55 |
+
SparseFP4OursPAttentionBackend,
|
| 56 |
+
) if backend is not None)
|
| 57 |
+
sparse_fp4_available = bool(sparse_fp4_backends)
|
| 58 |
+
|
| 59 |
+
logger = init_logger(__name__)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class DenoisingStage(PipelineStage):
|
| 63 |
+
"""
|
| 64 |
+
Stage for running the denoising loop in diffusion pipelines.
|
| 65 |
+
|
| 66 |
+
This stage handles the iterative denoising process that transforms
|
| 67 |
+
the initial noise into the final output.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None) -> None:
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.transformer = transformer
|
| 73 |
+
self.transformer_2 = transformer_2
|
| 74 |
+
self.scheduler = scheduler
|
| 75 |
+
self.vae = vae
|
| 76 |
+
self.pipeline = weakref.ref(pipeline) if pipeline else None
|
| 77 |
+
attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
|
| 78 |
+
self.attn_backend = get_attn_backend(
|
| 79 |
+
head_size=attn_head_size,
|
| 80 |
+
dtype=torch.float16, # TODO(will): hack
|
| 81 |
+
supported_attention_backends=(
|
| 82 |
+
AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.BSA_ATTN, AttentionBackendEnum.VMOBA_ATTN,
|
| 83 |
+
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.SAGE_ATTN_THREE,
|
| 84 |
+
AttentionBackendEnum.ATTN_QAT_INFER, AttentionBackendEnum.ATTN_QAT_TRAIN,
|
| 85 |
+
AttentionBackendEnum.SPARSE_FP4_ATTN, AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN) # hack
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Run the denoising loop.

        Args:
            batch: The current batch information.
            fastvideo_args: The inference arguments.

        Returns:
            The batch with denoised latents.
        """
        pipeline = self.pipeline() if self.pipeline else None
        # Lazily load the transformer the first time this stage runs.
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            fastvideo_args.model_loaded["transformer"] = True

        # Prepare extra step kwargs for scheduler
        extra_step_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step,
            {
                "generator": batch.generator,
                "eta": batch.eta
            },
        )

        # Setup precision and autocast settings
        # TODO(will): make the precision configurable for inference
        # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
        target_dtype = torch.bfloat16
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        # Get timesteps and calculate warmup steps
        timesteps = batch.timesteps
        # TODO(will): remove this once we add input/output validation for stages
        if timesteps is None:
            raise ValueError("Timesteps must be provided")
        num_inference_steps = batch.num_inference_steps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Prepare image latents and embeddings for I2V generation
        image_embeds = batch.image_embeds
        if len(image_embeds) > 0:
            assert not torch.isnan(image_embeds[0]).any(), "image_embeds contains nan"
            image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

        # Each *_kwargs dict is filtered against the transformer's forward
        # signature, so optional conditioning is only passed when supported.
        image_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_image": image_embeds,
                "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
            },
        )

        pos_cond_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_2": batch.clip_embedding_pos,
                "encoder_attention_mask": batch.prompt_attention_mask,
            },
        )

        neg_cond_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_2": batch.clip_embedding_neg,
                "encoder_attention_mask": batch.negative_attention_mask,
            },
        )

        action_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "mouse_cond": batch.mouse_cond,
                "keyboard_cond": batch.keyboard_cond,
                "c2ws_plucker_emb": batch.c2ws_plucker_emb,
            },
        )

        camera_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "camera_states": batch.camera_states,
            },
        )

        # Get latents and embeddings
        latents = batch.latents
        prompt_embeds = batch.prompt_embeds
        assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
        if batch.do_classifier_free_guidance:
            neg_prompt_embeds = batch.negative_prompt_embeds
            assert neg_prompt_embeds is not None
            assert not torch.isnan(neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"

        # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
        boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
        if batch.boundary_ratio is not None:
            logger.info("Overriding boundary ratio from %s to %s", boundary_ratio, batch.boundary_ratio)
            boundary_ratio = batch.boundary_ratio

        boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps if boundary_ratio is not None else None
        latent_model_input = latents.to(target_dtype)
        assert latent_model_input.shape[0] == 1, "only support batch size 1"

        if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
            # TI2V directly replaces the first frame of the latent with
            # the image latent instead of appending along the channel dim
            assert batch.image_latent is None, "TI2V task should not have image latents"
            assert self.vae is not None, "VAE is not provided for TI2V task"
            z = self.vae.encode(batch.pil_image).mean.float()
            # Apply the VAE's shift/scale normalization to the image latent.
            if (hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None):
                if isinstance(self.vae.shift_factor, torch.Tensor):
                    z -= self.vae.shift_factor.to(z.device, z.dtype)
                else:
                    z -= self.vae.shift_factor

            if isinstance(self.vae.scaling_factor, torch.Tensor):
                z = z * self.vae.scaling_factor.to(z.device, z.dtype)
            else:
                z = z * self.vae.scaling_factor

            latent_model_input = latent_model_input.squeeze(0)
            _, mask2 = masks_like([latent_model_input], zero=True)

            # mask2[0] is 0 where the image latent is anchored, 1 elsewhere.
            latent_model_input = (1. - mask2[0]) * z + mask2[0] * latent_model_input
            # latent_model_input = latent_model_input.unsqueeze(0)
            latent_model_input = latent_model_input.to(get_local_torch_device())
            latents = latent_model_input
            F = batch.num_frames
            temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
            spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
            patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
            if not isinstance(patch_size, tuple):
                raise ValueError(f"Expected 3D patch_size tuple for denoising, got {patch_size!r}")
            # NOTE(review): z, mask2 and seq_len exist only in this TI2V branch;
            # all later uses are guarded by the same ti2v_task/pil_image condition.
            seq_len = ((F - 1) // temporal_scale + 1) * (batch.height // spatial_scale) * (
                batch.width // spatial_scale) // (patch_size[1] * patch_size[2])

        # Initialize lists for ODE trajectory
        trajectory_timesteps: list[torch.Tensor] = []
        trajectory_latents: list[torch.Tensor] = []

        # Run denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Skip if interrupted
                if hasattr(self, 'interrupt') and self.interrupt:
                    continue

                # Expert selection (Wan2.2): high-noise expert (self.transformer)
                # above the boundary timestep, low-noise expert below. CPU
                # offload swaps the inactive expert out of GPU memory.
                if boundary_timestep is None or t >= boundary_timestep:
                    if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                            and self.transformer_2 is not None
                            and next(self.transformer_2.parameters()).device.type == 'cuda'):
                        self.transformer_2.to('cpu')
                    current_model = self.transformer
                    if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                            and not fastvideo_args.use_fsdp_inference and current_model is not None):
                        transformer_device = next(current_model.parameters()).device.type
                        if transformer_device == 'cpu':
                            current_model.to(get_local_torch_device())
                    current_guidance_scale = batch.guidance_scale
                else:
                    # low-noise stage in wan2.2
                    if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                            and next(self.transformer.parameters()).device.type == 'cuda'):
                        self.transformer.to('cpu')
                    current_model = self.transformer_2
                    if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                            and not fastvideo_args.use_fsdp_inference and current_model is not None):
                        transformer_2_device = next(current_model.parameters()).device.type
                        if transformer_2_device == 'cpu':
                            current_model.to(get_local_torch_device())
                    current_guidance_scale = batch.guidance_scale_2
                assert current_model is not None, "current_model is None"

                # Expand latents for V2V/I2V
                latent_model_input = latents.to(target_dtype)
                if batch.video_latent is not None:
                    latent_model_input = torch.cat([latent_model_input, batch.video_latent,
                                                    torch.zeros_like(latents)],
                                                   dim=1).to(target_dtype)
                elif batch.image_latent is not None:
                    assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                    latent_model_input = torch.cat([latent_model_input, batch.image_latent], dim=1).to(target_dtype)

                assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    # Per-token timesteps: anchored image tokens keep t only where
                    # mask2 is 1; the sequence is padded up to seq_len with t.
                    timestep = torch.stack([t]).to(get_local_torch_device())
                    temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
                    temp_ts = torch.cat([temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep])
                    timestep = temp_ts.unsqueeze(0)
                    t_expand = timestep.repeat(latent_model_input.shape[0], 1)
                else:
                    t_expand = t.repeat(latent_model_input.shape[0])
                t_expand = t_expand.to(get_local_torch_device())

                # Mean-flow models additionally take the next timestep (0 at the end).
                use_meanflow = getattr(self.transformer.config, "use_meanflow", False)
                if use_meanflow:
                    if i == len(timesteps) - 1:
                        timesteps_r = torch.tensor([0.0], device=get_local_torch_device())
                    else:
                        timesteps_r = timesteps[i + 1]
                    timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
                else:
                    timesteps_r = None

                timesteps_r_kwarg = self.prepare_extra_func_kwargs(
                    self.transformer.forward,
                    {
                        "timestep_r": timesteps_r,
                    },
                )

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Prepare inputs for transformer
                guidance_expand = (torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)

                # Predict noise residual
                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                    # Sparse-attention backends need per-step metadata (sparsity
                    # pattern depends on the current timestep).
                    if (vsa_available and self.attn_backend == VideoSparseAttentionBackend) or \
                       (sparse_fp4_available and self.attn_backend in sparse_fp4_backends):
                        self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                        if self.attn_metadata_builder_cls is not None:
                            self.attn_metadata_builder = self.attn_metadata_builder_cls()
                            # TODO(will): clean this up
                            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                                current_timestep=i,  # type: ignore
                                raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                                patch_size=fastvideo_args.pipeline_config.  # type: ignore
                                dit_config.patch_size,  # type: ignore
                                VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                                device=get_local_torch_device(),
                            )
                            assert attn_metadata is not None, "attn_metadata cannot be None"
                        else:
                            attn_metadata = None
                    elif (vmoba_attn_available and self.attn_backend == VMOBAAttentionBackend):
                        self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
                        if self.attn_metadata_builder_cls is not None:
                            self.attn_metadata_builder = self.attn_metadata_builder_cls()
                            # Prepare V-MoBA parameters from config
                            moba_params = fastvideo_args.moba_config.copy()
                            assert batch.raw_latent_shape is not None, "raw_latent_shape must be set for V-MoBA"
                            moba_params.update({
                                "current_timestep": i,
                                "raw_latent_shape": batch.raw_latent_shape[2:5],
                                "patch_size": fastvideo_args.pipeline_config.dit_config.patch_size,
                                "device": get_local_torch_device(),
                            })
                            attn_metadata = self.attn_metadata_builder.build(**moba_params)
                            assert attn_metadata is not None, "attn_metadata cannot be None"
                        else:
                            attn_metadata = None
                    else:
                        attn_metadata = None
                    # TODO(will): finalize the interface. vLLM uses this to
                    # support torch dynamo compilation. They pass in
                    # attn_metadata, vllm_config, and num_tokens. We can pass in
                    # fastvideo_args or training_args, and attn_metadata.
                    batch.is_cfg_negative = False
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                            # fastvideo_args=fastvideo_args
                    ):
                        # Run transformer
                        noise_pred = current_model(
                            latent_model_input,
                            prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **pos_cond_kwargs,
                            **action_kwargs,
                            **camera_kwargs,
                            **timesteps_r_kwarg,
                        )

                    if batch.do_classifier_free_guidance:
                        # Second pass with negative conditioning for CFG.
                        batch.is_cfg_negative = True
                        with set_forward_context(
                                current_timestep=i,
                                attn_metadata=attn_metadata,
                                forward_batch=batch,
                        ):
                            noise_pred_uncond = current_model(
                                latent_model_input,
                                neg_prompt_embeds,
                                t_expand,
                                guidance=guidance_expand,
                                **image_kwargs,
                                **neg_cond_kwargs,
                                **action_kwargs,
                                **camera_kwargs,
                                **timesteps_r_kwarg,
                            )

                        noise_pred_text = noise_pred
                        noise_pred = noise_pred_uncond + current_guidance_scale * (noise_pred_text - noise_pred_uncond)

                        # Apply guidance rescale if needed
                        if batch.guidance_rescale > 0.0:
                            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                            noise_pred = self.rescale_noise_cfg(
                                noise_pred,
                                noise_pred_text,
                                guidance_rescale=batch.guidance_rescale,
                            )
                # Compute the previous noisy sample
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    # Re-anchor the image latent after each scheduler step.
                    latents = latents.squeeze(0)
                    latents = (1. - mask2[0]) * z + mask2[0] * latents
                    # latents = latents.unsqueeze(0)

                # save trajectory latents if needed
                if batch.return_trajectory_latents:
                    trajectory_timesteps.append(t)
                    trajectory_latents.append(latents)

                # Update progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                               (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
                    progress_bar.update()

        trajectory_tensor: torch.Tensor | None = None
        if trajectory_latents:
            trajectory_tensor = torch.stack(trajectory_latents, dim=1)
            trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)
        else:
            trajectory_tensor = None
            trajectory_timesteps_tensor = None

        if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
            batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
            batch.trajectory_latents = trajectory_tensor.cpu()

        # Update batch with final latents
        batch.latents = latents

        # Layerwise offload keeps only a few layers resident; release them all now.
        if fastvideo_args.dit_layerwise_offload:
            mgr = getattr(self.transformer, "_layerwise_offload_manager", None)
            if mgr is not None and getattr(mgr, "enabled", False):
                mgr.release_all()
            if self.transformer_2 is not None:
                mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager", None)
                if mgr2 is not None and getattr(mgr2, "enabled", False):
                    mgr2.release_all()

        # deallocate transformer if on mps
        if torch.backends.mps.is_available():
            logger.info("Memory before deallocating transformer: %s", torch.mps.current_allocated_memory())
            del self.transformer
            if pipeline is not None and "transformer" in pipeline.modules:
                del pipeline.modules["transformer"]
            fastvideo_args.model_loaded["transformer"] = False
            logger.info("Memory after deallocating transformer: %s", torch.mps.current_allocated_memory())

        return batch
|
| 460 |
+
|
| 461 |
+
def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
|
| 462 |
+
"""
|
| 463 |
+
Prepare extra kwargs for the scheduler step / denoise step.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
func: The function to prepare kwargs for.
|
| 467 |
+
kwargs: The kwargs to prepare.
|
| 468 |
+
|
| 469 |
+
Returns:
|
| 470 |
+
The prepared kwargs.
|
| 471 |
+
"""
|
| 472 |
+
extra_step_kwargs = {}
|
| 473 |
+
for k, v in kwargs.items():
|
| 474 |
+
accepts = k in set(inspect.signature(func).parameters.keys())
|
| 475 |
+
if accepts:
|
| 476 |
+
extra_step_kwargs[k] = v
|
| 477 |
+
return extra_step_kwargs
|
| 478 |
+
|
| 479 |
+
def progress_bar(self, iterable: Iterable | None = None, total: int | None = None) -> tqdm:
|
| 480 |
+
"""
|
| 481 |
+
Create a progress bar for the denoising process.
|
| 482 |
+
|
| 483 |
+
Args:
|
| 484 |
+
iterable: The iterable to iterate over.
|
| 485 |
+
total: The total number of items.
|
| 486 |
+
|
| 487 |
+
Returns:
|
| 488 |
+
A tqdm progress bar.
|
| 489 |
+
"""
|
| 490 |
+
local_rank = get_world_group().local_rank
|
| 491 |
+
if local_rank == 0:
|
| 492 |
+
return tqdm(iterable=iterable, total=total)
|
| 493 |
+
else:
|
| 494 |
+
return tqdm(iterable=iterable, total=total, disable=True)
|
| 495 |
+
|
| 496 |
+
def rescale_noise_cfg(self, noise_cfg, noise_pred_text, guidance_rescale=0.0) -> torch.Tensor:
|
| 497 |
+
"""
|
| 498 |
+
Rescale noise prediction according to guidance_rescale.
|
| 499 |
+
|
| 500 |
+
Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
|
| 501 |
+
(https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
|
| 502 |
+
|
| 503 |
+
Args:
|
| 504 |
+
noise_cfg: The noise prediction with guidance.
|
| 505 |
+
noise_pred_text: The text-conditioned noise prediction.
|
| 506 |
+
guidance_rescale: The guidance rescale factor.
|
| 507 |
+
|
| 508 |
+
Returns:
|
| 509 |
+
The rescaled noise prediction.
|
| 510 |
+
"""
|
| 511 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
| 512 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
| 513 |
+
# Rescale the results from guidance (fixes overexposure)
|
| 514 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
| 515 |
+
# Mix with the original results from guidance by factor guidance_rescale
|
| 516 |
+
noise_cfg = (guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg)
|
| 517 |
+
return noise_cfg
|
| 518 |
+
|
| 519 |
+
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
|
| 520 |
+
"""Verify denoising stage inputs."""
|
| 521 |
+
result = VerificationResult()
|
| 522 |
+
result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)])
|
| 523 |
+
result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
|
| 524 |
+
result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
|
| 525 |
+
result.add_check("image_embeds", batch.image_embeds, V.is_list)
|
| 526 |
+
result.add_check("image_latent", batch.image_latent, V.none_or_tensor_with_dims(5))
|
| 527 |
+
result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
|
| 528 |
+
result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
|
| 529 |
+
result.add_check("eta", batch.eta, V.non_negative_float)
|
| 530 |
+
result.add_check("generator", batch.generator, V.generator_or_list_generators)
|
| 531 |
+
result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
|
| 532 |
+
result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
|
| 533 |
+
lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
|
| 534 |
+
return result
|
| 535 |
+
|
| 536 |
+
def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
|
| 537 |
+
"""Verify denoising stage outputs."""
|
| 538 |
+
result = VerificationResult()
|
| 539 |
+
result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
|
| 540 |
+
return result
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class CosmosDenoisingStage(DenoisingStage):
|
| 544 |
+
"""
|
| 545 |
+
Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
|
| 546 |
+
"""
|
| 547 |
+
|
| 548 |
+
def __init__(self, transformer, scheduler, pipeline=None) -> None:
|
| 549 |
+
super().__init__(transformer, scheduler, pipeline)
|
| 550 |
+
|
| 551 |
+
    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Run the Cosmos denoising loop.

        Uses sigma-based preconditioning (c_in/c_skip/c_out derived from the
        scheduler's sigmas) and per-region conditioning masks, then converts
        the model prediction back to noise for the scheduler step.

        Args:
            batch: The current batch information.
            fastvideo_args: The inference arguments.

        Returns:
            The batch with denoised latents.
        """
        pipeline = self.pipeline() if self.pipeline else None
        # Lazily load the transformer the first time this stage runs.
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            fastvideo_args.model_loaded["transformer"] = True

        extra_step_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step,
            {
                "generator": batch.generator,
                "eta": batch.eta
            },
        )

        # Match the transformer's own parameter dtype (unwrap DDP-style .module).
        if hasattr(self.transformer, 'module'):
            transformer_dtype = next(self.transformer.module.parameters()).dtype
        else:
            transformer_dtype = next(self.transformer.parameters()).dtype
        target_dtype = transformer_dtype
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        latents = batch.latents
        num_inference_steps = batch.num_inference_steps
        guidance_scale = batch.guidance_scale

        # Sigma schedule constants pushed into the scheduler config.
        sigma_max = 80.0
        sigma_min = 0.002
        sigma_data = 1.0
        final_sigmas_type = "sigma_min"

        if self.scheduler is not None:
            self.scheduler.register_to_config(
                sigma_max=sigma_max,
                sigma_min=sigma_min,
                sigma_data=sigma_data,
                final_sigmas_type=final_sigmas_type,
            )

        self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
        timesteps = self.scheduler.timesteps

        # With final_sigmas_type == "sigma_min", clamp the last sigma to the
        # penultimate value instead of letting it hit zero.
        if (hasattr(self.scheduler.config, 'final_sigmas_type')
                and self.scheduler.config.final_sigmas_type == "sigma_min" and len(self.scheduler.sigmas) > 1):
            self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]

        conditioning_latents = getattr(batch, 'conditioning_latents', None)
        # The same latents are used for the unconditional branch anchoring.
        unconditioning_latents = conditioning_latents

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if hasattr(self, 'interrupt') and self.interrupt:
                    continue

                # Preconditioning coefficients for the current sigma.
                current_sigma = self.scheduler.sigmas[i]
                current_t = current_sigma / (current_sigma + 1)
                c_in = 1 - current_t
                c_skip = 1 - current_t
                c_out = -current_t

                timestep = current_t.view(1, 1, 1, 1, 1).expand(latents.size(0), -1, latents.size(2), -1,
                                                                -1)  # [B, 1, T, 1, 1]

                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):

                    cond_latent = latents * c_in

                    # Anchor conditioned regions to the clean conditioning latents.
                    if hasattr(
                            batch,
                            'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
                        cond_latent = batch.cond_indicator * conditioning_latents + (1 -
                                                                                     batch.cond_indicator) * cond_latent
                    else:
                        logger.warning(
                            "Step %s: Missing conditioning data - cond_indicator: %s, conditioning_latents: %s", i,
                            hasattr(batch, 'cond_indicator'), conditioning_latents is not None)

                    cond_latent = cond_latent.to(target_dtype)

                    # Conditioned regions run at a near-zero timestep.
                    cond_timestep = timestep
                    if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
                        sigma_conditioning = 0.0001
                        t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
                        cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
                        cond_timestep = cond_timestep.to(target_dtype)

                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=None,
                            forward_batch=batch,
                    ):
                        # Use conditioning masks from CosmosLatentPreparationStage
                        condition_mask = batch.cond_mask.to(target_dtype) if hasattr(batch, 'cond_mask') else None
                        padding_mask = torch.zeros(1,
                                                   1,
                                                   batch.height,
                                                   batch.width,
                                                   device=cond_latent.device,
                                                   dtype=target_dtype)

                        # Fallback if masks not available
                        if condition_mask is None:
                            batch_size, num_channels, num_frames, height, width = cond_latent.shape
                            condition_mask = torch.zeros(batch_size,
                                                         1,
                                                         num_frames,
                                                         height,
                                                         width,
                                                         device=cond_latent.device,
                                                         dtype=target_dtype)

                        noise_pred = self.transformer(
                            hidden_states=cond_latent,
                            timestep=cond_timestep.to(target_dtype),
                            encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
                            fps=24,  # TODO: get fps from batch or config
                            condition_mask=condition_mask,
                            padding_mask=padding_mask,
                            return_dict=False,
                        )[0]

                    # Combine skip/out terms into the denoised sample estimate.
                    cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)

                    if hasattr(
                            batch,
                            'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
                        cond_pred = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_pred

                    if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
                        # Unconditional branch mirrors the conditional one with
                        # negative embeddings and the uncond masks/indicators.
                        uncond_latent = latents * c_in

                        if hasattr(batch, 'uncond_indicator'
                                   ) and batch.uncond_indicator is not None and unconditioning_latents is not None:
                            uncond_latent = batch.uncond_indicator * unconditioning_latents + (
                                1 - batch.uncond_indicator) * uncond_latent

                        with set_forward_context(
                                current_timestep=i,
                                attn_metadata=None,
                                forward_batch=batch,
                        ):
                            uncond_condition_mask = batch.uncond_mask.to(target_dtype) if hasattr(
                                batch, 'uncond_mask') and batch.uncond_mask is not None else condition_mask

                            uncond_timestep = timestep
                            if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None:
                                sigma_conditioning = 0.0001
                                t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
                                uncond_timestep = batch.uncond_indicator * t_conditioning + (
                                    1 - batch.uncond_indicator) * timestep
                                uncond_timestep = uncond_timestep.to(target_dtype)

                            noise_pred_uncond = self.transformer(
                                hidden_states=uncond_latent.to(target_dtype),
                                timestep=uncond_timestep.to(target_dtype),
                                encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
                                fps=24,  # TODO: get fps from batch or config
                                condition_mask=uncond_condition_mask,
                                padding_mask=padding_mask,
                                return_dict=False,
                            )[0]

                        uncond_pred = (c_skip * latents + c_out * noise_pred_uncond.float()).to(target_dtype)

                        if hasattr(batch, 'uncond_indicator'
                                   ) and batch.uncond_indicator is not None and unconditioning_latents is not None:
                            uncond_pred = batch.uncond_indicator * unconditioning_latents + (
                                1 - batch.uncond_indicator) * uncond_pred

                        # Classifier-free guidance in sample space.
                        guidance_diff = cond_pred - uncond_pred
                        final_pred = cond_pred + guidance_scale * guidance_diff
                    else:
                        final_pred = cond_pred

                    # Convert to noise for scheduler step
                    if current_sigma > 1e-8:
                        noise_for_scheduler = (latents - final_pred) / current_sigma
                    else:
                        logger.warning("Step %s: current_sigma too small (%s), using final_pred directly", i, current_sigma)
                        noise_for_scheduler = final_pred

                    if torch.isnan(noise_for_scheduler).sum() > 0:
                        logger.error("Step %s: NaN detected in noise_for_scheduler, sum: %s", i,
                                     noise_for_scheduler.float().sum().item())
                        logger.error("Step %s: latents sum: %s, final_pred sum: %s, current_sigma: %s", i,
                                     latents.float().sum().item(),
                                     final_pred.float().sum().item(), current_sigma)

                    latents = self.scheduler.step(noise_for_scheduler, t, latents, **extra_step_kwargs,
                                                  return_dict=False)[0]

                progress_bar.update()

        batch.latents = latents

        return batch
|
| 753 |
+
|
| 754 |
+
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
|
| 755 |
+
"""Verify Cosmos denoising stage inputs."""
|
| 756 |
+
result = VerificationResult()
|
| 757 |
+
result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
|
| 758 |
+
result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
|
| 759 |
+
result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
|
| 760 |
+
result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
|
| 761 |
+
result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
|
| 762 |
+
result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
|
| 763 |
+
lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
|
| 764 |
+
return result
|
| 765 |
+
|
| 766 |
+
def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
|
| 767 |
+
"""Verify Cosmos denoising stage outputs."""
|
| 768 |
+
result = VerificationResult()
|
| 769 |
+
result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
|
| 770 |
+
return result
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
class Cosmos25DenoisingStage(CosmosDenoisingStage):
    """Denoising stage for Cosmos 2.5 DiT (expects 1D/2D timestep, not 5D)."""

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """Run the Cosmos 2.5 denoising loop and write denoised latents to the batch.

        Handles both unconditioned (T2W) and conditioned (V2W/I2W) runs; the
        mode is inferred from the presence of ``batch.conditioning_latents``.
        """
        # Lazily load the transformer the first time this stage runs.
        pipeline = self.pipeline() if self.pipeline else None
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            fastvideo_args.model_loaded["transformer"] = True

        # Only forward kwargs that scheduler.step's signature actually accepts.
        extra_step_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step,
            {
                "generator": batch.generator,
                "eta": batch.eta
            },
        )

        # Model may be wrapped (exposes ``.module``); read dtype from its parameters.
        if hasattr(self.transformer, 'module'):
            transformer_dtype = next(self.transformer.module.parameters()).dtype
        else:
            transformer_dtype = next(self.transformer.parameters()).dtype
        target_dtype = transformer_dtype
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        latents = batch.latents
        if latents is None:
            raise ValueError("latents must be provided for Cosmos25DenoisingStage")
        guidance_scale = batch.guidance_scale

        # Use the batch's timesteps if supplied, otherwise derive them from the scheduler.
        if batch.timesteps is None:
            self.scheduler.set_timesteps(batch.num_inference_steps, device=latents.device)
            timesteps = self.scheduler.timesteps
        else:
            timesteps = batch.timesteps.to(latents.device)

        cfg = fastvideo_args.pipeline_config

        # FPS conditioning: sample an integer in [16, 32) when the batch has no fps.
        if batch.fps is None:
            gen = batch.generator
            if isinstance(gen, list) and len(gen) > 0:
                gen = gen[0]
            fps_tensor = torch.randint(
                16,
                32,
                (1, ),
                generator=gen if isinstance(gen, torch.Generator) else None,
                device=latents.device,
            ).float().to(dtype=target_dtype)
        else:
            fps_val = batch.fps
            fps_tensor = torch.tensor(
                [fps_val],
                device=latents.device,
                dtype=target_dtype,
            )

        # Work on the first batch element as a 4D tensor.
        # NOTE(review): assumes latents is (1, C, T, H, W) — confirm upstream.
        latents_4d = latents[0]

        # Masks are optional for T2W.
        cond_mask = getattr(batch, "cond_mask", None)
        condition_mask = cond_mask.to(target_dtype) if isinstance(cond_mask, torch.Tensor) else None
        pad_mask = getattr(batch, "padding_mask", None)
        padding_mask = pad_mask.to(target_dtype) if isinstance(pad_mask, torch.Tensor) else None

        # Conditioning fields are attached by latent preparation stage.
        conditioning_latents = getattr(batch, "conditioning_latents", None)
        cond_indicator = getattr(batch, "cond_indicator", None)
        # Infer whether this is a conditioned run (V2W/I2W) purely from the presence
        # of conditioning latents. Avoid carrying explicit mode flags on the batch.
        is_conditioned = (conditioning_latents is not None)

        # Keep the initial noise: conditioned runs derive the GT velocity from it below.
        init_noise_4d = latents_4d.clone()
        if condition_mask is None:
            _, t, h, w = latents_4d.shape
            condition_mask = torch.zeros(1, 1, t, h, w, device=latents.device, dtype=target_dtype)
        if padding_mask is None:
            _, _, h, w = latents_4d.shape
            # NOTE(review): default padding differs by mode (0.0 conditioned,
            # 1.0 unconditioned) — confirm against the model's mask convention.
            padding_default = 0.0 if is_conditioned else 1.0
            padding_mask = torch.full(
                (1, 1, h, w),
                float(padding_default),
                device=latents.device,
                dtype=target_dtype,
            )

        # Scale factor applied to raw scheduler timesteps before feeding the model
        # (presumably mapping [0, 1000] -> [0, 1] — TODO confirm).
        timestep_scale = 0.001

        # Integrate the sampling loop state in float32.
        state_dtype = torch.float32

        # Fixed timestep value assigned to conditioned frames.
        conditional_frame_timestep = 0.1
        latents_4d = latents_4d.to(state_dtype)
        init_noise_4d = init_noise_4d.to(state_dtype)

        clamp_every_step = bool(getattr(cfg, "cosmos25_clamp_every_step", True)) if is_conditioned else False

        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                t_val = float(t)
                if is_conditioned:
                    # Per-frame timestep vector: conditioned frames get the fixed
                    # conditional_frame_timestep, the rest the current step value.
                    t_frames = int(latents_4d.shape[1])
                    timestep = torch.full(
                        (1, t_frames),
                        float(t_val * timestep_scale),
                        device=latents.device,
                        dtype=torch.float32,
                    )
                    if cond_indicator is not None and t_frames > 0:
                        cond_t = cond_indicator[0, 0, :t_frames, 0, 0]
                        cond_mask_t = (cond_t > 0.5)
                        if bool(cond_mask_t.any().item()):
                            timestep[0, cond_mask_t] = float(conditional_frame_timestep)
                else:
                    # Unconditioned: a single scalar timestep shaped (1, 1).
                    # NOTE(review): dtype here is target_dtype while the conditioned
                    # branch uses float32 — confirm the asymmetry is intentional.
                    timestep_val = t_val * timestep_scale
                    timestep = torch.tensor(
                        [[float(timestep_val)]],
                        device=latents.device,
                        dtype=target_dtype,
                    )

                # Conditioned runs: replace x_t with GT x0 on the conditioned frames.
                if (is_conditioned and cond_indicator is not None and conditioning_latents is not None
                        and (clamp_every_step or i == 0)):
                    cond_ind_4d = cond_indicator[0].to(state_dtype)
                    gt_x0 = conditioning_latents[0].to(state_dtype)
                    latents_4d = gt_x0 * cond_ind_4d + latents_4d * (1 - cond_ind_4d)

                model_hidden_states = latents_4d.unsqueeze(0)

                with (
                        set_forward_context(current_timestep=int(t_val), attn_metadata=None, forward_batch=batch),
                        torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled),
                ):
                    # Conditional (positive-prompt) velocity prediction.
                    cond_v = self.transformer(
                        hidden_states=model_hidden_states.to(target_dtype),
                        encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
                        timestep=timestep,
                        fps=fps_tensor,
                        condition_mask=condition_mask,
                        padding_mask=padding_mask,
                        return_dict=False,
                    )[0]

                    if batch.do_classifier_free_guidance and batch.negative_prompt_embeds:
                        # Unconditional (negative-prompt) velocity prediction.
                        uncond_v = self.transformer(
                            hidden_states=model_hidden_states.to(target_dtype),
                            encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
                            timestep=timestep,
                            fps=fps_tensor,
                            condition_mask=condition_mask,
                            padding_mask=padding_mask,
                            return_dict=False,
                        )[0]
                        # NOTE(review): the CFG anchor differs by mode (cond_v for
                        # conditioned, uncond_v otherwise) — confirm this is intended.
                        if is_conditioned:
                            v = cond_v + guidance_scale * (cond_v - uncond_v)
                        else:
                            v = uncond_v + guidance_scale * (cond_v - uncond_v)
                    else:
                        v = cond_v

                # Conditioned runs: replace velocity on conditioned frames with GT velocity.
                if (is_conditioned and cond_indicator is not None and conditioning_latents is not None):
                    cond_ind_4d = cond_indicator[0].to(state_dtype)
                    gt_x0 = conditioning_latents[0].to(state_dtype)
                    gt_v = init_noise_4d.to(state_dtype) - gt_x0
                    v = cond_ind_4d * gt_v + (1 - cond_ind_4d) * v.to(state_dtype)

                # Euler/scheduler update; restore the 4D view afterwards.
                prev = self.scheduler.step(v.unsqueeze(0),
                                           t,
                                           latents_4d.unsqueeze(0),
                                           **extra_step_kwargs,
                                           return_dict=False)[0]
                latents_4d = prev.squeeze(0)

                progress_bar.update()

        # Restore the batch dimension and the model dtype before handing off.
        batch.latents = latents_4d.to(target_dtype).unsqueeze(0)
        return batch
|
| 957 |
+
|
| 958 |
+
|
| 959 |
+
class Cosmos25T2WDenoisingStage(Cosmos25DenoisingStage):
    """Cosmos 2.5 Text2World denoising stage.

    Text2World never uses frame conditioning, so any conditioning fields left
    on the batch by earlier stages are cleared before the base stage runs.
    """

    _CONDITIONING_FIELDS = (
        "conditioning_latents",
        "cond_indicator",
        "uncond_indicator",
    )

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        # Null out stale conditioning so the base stage takes the T2W path.
        present_fields = (n for n in self._CONDITIONING_FIELDS if hasattr(batch, n))
        for field_name in present_fields:
            setattr(batch, field_name, None)
        return super().forward(batch, fastvideo_args)
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
class Cosmos25V2WDenoisingStage(Cosmos25DenoisingStage):
    """Cosmos 2.5 Video2World denoising stage.

    The base stage already detects V2W/I2W from ``batch.conditioning_latents``,
    so this override delegates unchanged.
    """

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        # Pure delegation: the batch's conditioning fields drive the behavior.
        return super().forward(batch, fastvideo_args)
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
class Cosmos25AutoDenoisingStage(PipelineStage):
    """Route Cosmos 2.5 denoising to T2W vs V2W/I2W."""

    def __init__(self, transformer, scheduler) -> None:
        super().__init__()
        self._t2w = Cosmos25T2WDenoisingStage(transformer=transformer, scheduler=scheduler)
        self._v2w = Cosmos25V2WDenoisingStage(transformer=transformer, scheduler=scheduler)

    def pipeline(self):
        return self._v2w.pipeline() if self._v2w.pipeline else None

    def _route(self, batch: ForwardBatch):
        # V2W/I2W when conditioning latents are present, T2W otherwise.
        has_conditioning = getattr(batch, "conditioning_latents", None) is not None
        return self._v2w if has_conditioning else self._t2w

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        return self._route(batch).forward(batch, fastvideo_args)

    def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        return self._route(batch).verify_input(batch, fastvideo_args)

    def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        return self._route(batch).verify_output(batch, fastvideo_args)
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
class DmdDenoisingStage(DenoisingStage):
    """
    Denoising stage for DMD.
    """

    def __init__(self, transformer, scheduler) -> None:
        super().__init__(transformer, scheduler)
        # DMD overrides whatever scheduler was passed in with its own
        # flow-matching Euler scheduler (shift=8.0).
        self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Run the denoising loop.

        Args:
            batch: The current batch information.
            fastvideo_args: The inference arguments.

        Returns:
            The batch with denoised latents.
        """
        # Setup precision and autocast settings
        # TODO(will): make the precision configurable for inference
        # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
        target_dtype = torch.bfloat16
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        # Get timesteps and calculate warmup steps
        timesteps = batch.timesteps

        # TODO(will): remove this once we add input/output validation for stages
        if timesteps is None:
            raise ValueError("Timesteps must be provided")
        num_inference_steps = batch.num_inference_steps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Prepare image latents and embeddings for I2V generation
        image_embeds = batch.image_embeds
        if len(image_embeds) > 0:
            assert torch.isnan(image_embeds[0]).sum() == 0
            image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

        # Only forward kwargs the transformer's forward signature accepts.
        image_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_image": image_embeds,
                "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
            },
        )

        pos_cond_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_2": batch.clip_embedding_pos,
                "encoder_attention_mask": batch.prompt_attention_mask,
            },
        )

        # Get latents and embeddings
        assert batch.latents is not None, "latents must be provided"
        latents = batch.latents

        video_raw_latent_shape = latents.shape
        prompt_embeds = batch.prompt_embeds
        assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
        # DMD samples at a fixed, config-declared set of denoising timesteps
        # (few-step distilled sampling) rather than the scheduler's schedule.
        timesteps = torch.tensor(fastvideo_args.pipeline_config.dmd_denoising_steps,
                                 dtype=torch.long,
                                 device=get_local_torch_device())

        # Run denoising loop
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # Skip if interrupted
                if hasattr(self, 'interrupt') and self.interrupt:
                    continue
                # Expand latents for I2V
                noise_latents = latents.clone()
                latent_model_input = latents.to(target_dtype)

                if batch.image_latent is not None:
                    # Concatenate image conditioning after permuting its layout.
                    # NOTE(review): the permute suggests latents and image_latent
                    # differ in (T, C) axis order — confirm layouts upstream.
                    latent_model_input = torch.cat(
                        [latent_model_input, batch.image_latent.permute(0, 2, 1, 3, 4)], dim=2).to(target_dtype)
                assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"

                # Prepare inputs for transformer
                t_expand = t.repeat(latent_model_input.shape[0])
                # Embedded guidance scale, multiplied by 1000.0 (model conditioning
                # convention — TODO confirm); None disables embedded guidance.
                guidance_expand = (torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)

                # Predict noise residual
                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                    # Sparse-attention backends need per-step attention metadata.
                    if (vsa_available and self.attn_backend == VideoSparseAttentionBackend) or \
                       (sparse_fp4_available and self.attn_backend in sparse_fp4_backends):
                        self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                        if self.attn_metadata_builder_cls is not None:
                            self.attn_metadata_builder = self.attn_metadata_builder_cls()
                            # TODO(will): clean this up
                            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                                current_timestep=i,  # type: ignore
                                raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                                patch_size=fastvideo_args.pipeline_config.  # type: ignore
                                dit_config.patch_size,  # type: ignore
                                VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                                device=get_local_torch_device(),  # type: ignore
                            )  # type: ignore
                            assert attn_metadata is not None, "attn_metadata cannot be None"
                        else:
                            attn_metadata = None
                    else:
                        attn_metadata = None

                    batch.is_cfg_negative = False
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                            # fastvideo_args=fastvideo_args
                    ):
                        # Run transformer
                        pred_noise = self.transformer(
                            latent_model_input.permute(0, 2, 1, 3, 4),
                            prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **pos_cond_kwargs,
                        ).permute(0, 2, 1, 3, 4)

                    # Convert the predicted noise into a clean-video estimate.
                    pred_video = pred_noise_to_pred_video(pred_noise=pred_noise.flatten(0, 1),
                                                          noise_input_latent=noise_latents.flatten(0, 1),
                                                          timestep=t_expand,
                                                          scheduler=self.scheduler).unflatten(0, pred_noise.shape[:2])

                    if i < len(timesteps) - 1:
                        # Re-noise the clean estimate to the next timestep
                        # (DMD-style stepping), using fresh noise per step.
                        next_timestep = timesteps[i + 1] * torch.ones([1], dtype=torch.long, device=pred_video.device)
                        noise_generator = batch.generator[0] if isinstance(batch.generator, list) else batch.generator
                        noise = torch.randn(video_raw_latent_shape, dtype=pred_video.dtype,
                                            generator=noise_generator).to(self.device)
                        latents = self.scheduler.add_noise(pred_video.flatten(0, 1), noise.flatten(0, 1),
                                                           next_timestep).unflatten(0, pred_video.shape[:2])
                    else:
                        latents = pred_video

                # Update progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                               (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
                    progress_bar.update()

        # Gather results if using sequence parallelism
        latents = latents.permute(0, 2, 1, 3, 4)
        # Update batch with final latents
        batch.latents = latents

        return batch
|
standalone_inference/overlay_files/fastvideo/platforms/cuda.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cuda.py
|
| 3 |
+
"""Code inside this file can safely assume cuda platform, e.g. importing
|
| 4 |
+
pynvml. However, it should not initialize cuda context.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from collections.abc import Callable
|
| 9 |
+
from functools import lru_cache, wraps
|
| 10 |
+
from typing import TypeVar
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from typing_extensions import ParamSpec
|
| 14 |
+
|
| 15 |
+
import fastvideo.envs as envs
|
| 16 |
+
from fastvideo.logger import init_logger
|
| 17 |
+
from fastvideo.platforms.interface import (AttentionBackendEnum, DeviceCapability, Platform, PlatformEnum)
|
| 18 |
+
from fastvideo.utils import import_pynvml
|
| 19 |
+
|
| 20 |
+
logger = init_logger(__name__)

# Typing helpers for the with_nvml_context decorator below.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# NVML bindings obtained via the project helper.
# NOTE(review): presumably returns the pynvml module (or a stub) — verify.
pynvml = import_pynvml()  # type: ignore[no-untyped-call]

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical CUDA device index into the physical device id.

    When ``CUDA_VISIBLE_DEVICES`` is set, logical index *device_id* refers to
    the *device_id*-th entry of that list; otherwise logical and physical ids
    coincide.

    Raises:
        RuntimeError: if ``CUDA_VISIBLE_DEVICES`` is set to the empty string,
            which signals that GPU support is disabled.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        # No remapping in effect.
        return device_id
    device_ids = visible.split(",")
    if device_ids == [""]:
        msg = ("CUDA_VISIBLE_DEVICES is set to empty string, which means"
               " GPU support is disabled. If you are using ray, please unset"
               " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
               " worker/actor. "
               "Check https://github.com/vllm-project/vllm/issues/8402 for"
               " more information.")
        raise RuntimeError(msg)
    return int(device_ids[device_id])
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorator that brackets ``fn`` with NVML initialization and shutdown."""

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            # Always release NVML, even when fn raises.
            pynvml.nvmlShutdown()

    return _nvml_scoped
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class CudaPlatformBase(Platform):
|
| 63 |
+
_enum = PlatformEnum.CUDA
|
| 64 |
+
device_name: str = "cuda"
|
| 65 |
+
device_type: str = "cuda"
|
| 66 |
+
dispatch_key: str = "CUDA"
|
| 67 |
+
ray_device_key: str = "GPU"
|
| 68 |
+
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
|
| 69 |
+
|
| 70 |
+
    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        # Abstract in this base class; concrete platform subclasses override it.
        raise NotImplementedError
|
| 73 |
+
|
| 74 |
+
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        # Abstract in this base class; concrete platform subclasses override it.
        raise NotImplementedError
|
| 77 |
+
|
| 78 |
+
    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        # Abstract in this base class; concrete platform subclasses override it.
        raise NotImplementedError
|
| 81 |
+
|
| 82 |
+
    @classmethod
    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
        # Async output processing is declared unsupported under enforce-eager
        # (the warning says it relies on CUDA graphs); otherwise supported.
        if enforce_eager:
            logger.warning("To see benefits of async output processing, enable CUDA "
                           "graph. Since, enforce-eager is enabled, async output "
                           "processor cannot be used")
            return False
        return True
|
| 90 |
+
|
| 91 |
+
    @classmethod
    def is_full_nvlink(cls, device_ids: list[int]) -> bool:
        # Abstract: subclasses decide whether the given devices are fully
        # NVLink-connected (presumably via NVML — see pynvml above).
        raise NotImplementedError
|
| 94 |
+
|
| 95 |
+
    @classmethod
    def log_warnings(cls) -> None:
        # No platform-wide warnings by default; subclasses may override.
        pass
|
| 98 |
+
|
| 99 |
+
    @classmethod
    def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
        # NOTE(review): this resets the peak stats and then reads the peak
        # allocation, so the returned value reflects usage after this reset —
        # confirm that callers expect that semantic.
        torch.cuda.reset_peak_memory_stats(device)
        return float(torch.cuda.max_memory_allocated(device))
|
| 103 |
+
|
| 104 |
+
    @classmethod
    def get_torch_device(cls) -> object:
        """
        Return torch.cuda (the device module for device-generic dispatch).
        """
        return torch.cuda
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
|
| 113 |
+
dtype: torch.dtype) -> str:
|
| 114 |
+
# TODO(will): maybe come up with a more general interface for local attention
|
| 115 |
+
# if distributed is False, we always try to use Flash attn
|
| 116 |
+
|
| 117 |
+
logger.info("Trying FASTVIDEO_ATTENTION_BACKEND=%s", envs.FASTVIDEO_ATTENTION_BACKEND)
|
| 118 |
+
logger.info("Selected backend: %s", selected_backend)
|
| 119 |
+
if selected_backend == AttentionBackendEnum.SAGE_ATTN:
|
| 120 |
+
try:
|
| 121 |
+
from sageattention import sageattn # noqa: F401
|
| 122 |
+
|
| 123 |
+
from fastvideo.attention.backends.sage_attn import ( # noqa: F401
|
| 124 |
+
SageAttentionBackend)
|
| 125 |
+
logger.info("Using Sage Attention backend.")
|
| 126 |
+
|
| 127 |
+
return "fastvideo.attention.backends.sage_attn.SageAttentionBackend"
|
| 128 |
+
except ImportError as e:
|
| 129 |
+
logger.info(e)
|
| 130 |
+
logger.info("Sage Attention backend is not installed. Fall back to Flash Attention.")
|
| 131 |
+
elif selected_backend == AttentionBackendEnum.SAGE_ATTN_THREE:
|
| 132 |
+
try:
|
| 133 |
+
from sageattn3 import sageattn3_blackwell # noqa: F401
|
| 134 |
+
|
| 135 |
+
from fastvideo.attention.backends.sage_attn3 import ( # noqa: F401
|
| 136 |
+
SageAttention3Backend)
|
| 137 |
+
logger.info("Using Sage Attention 3 backend.")
|
| 138 |
+
|
| 139 |
+
return "fastvideo.attention.backends.sage_attn3.SageAttention3Backend"
|
| 140 |
+
except ImportError as e:
|
| 141 |
+
logger.info(e)
|
| 142 |
+
logger.info("Sage Attention 3 backend is not installed. Fall back to Flash Attention.")
|
| 143 |
+
elif selected_backend == AttentionBackendEnum.ATTN_QAT_INFER:
|
| 144 |
+
try:
|
| 145 |
+
from fastvideo.attention.backends.attn_qat_infer import ( # noqa: F401
|
| 146 |
+
AttnQatInferBackend, is_attn_qat_infer_available,
|
| 147 |
+
)
|
| 148 |
+
if not is_attn_qat_infer_available():
|
| 149 |
+
raise ImportError("attn_qat_infer could not be imported.")
|
| 150 |
+
logger.info("Using attn_qat_infer backend.")
|
| 151 |
+
|
| 152 |
+
return "fastvideo.attention.backends.attn_qat_infer.AttnQatInferBackend"
|
| 153 |
+
except ImportError as e:
|
| 154 |
+
logger.info(e)
|
| 155 |
+
logger.info("attn_qat_infer backend is not installed. Fall back to Flash Attention.")
|
| 156 |
+
elif selected_backend == AttentionBackendEnum.ATTN_QAT_TRAIN:
|
| 157 |
+
try:
|
| 158 |
+
from fastvideo_kernel.triton_kernels.attn_qat_train import attention # noqa: F401
|
| 159 |
+
|
| 160 |
+
from fastvideo.attention.backends.attn_qat_train import ( # noqa: F401
|
| 161 |
+
AttnQatTrainBackend)
|
| 162 |
+
logger.info("Using attn_qat_train backend.")
|
| 163 |
+
|
| 164 |
+
return "fastvideo.attention.backends.attn_qat_train.AttnQatTrainBackend"
|
| 165 |
+
except ImportError as e:
|
| 166 |
+
logger.info(e)
|
| 167 |
+
logger.info("attn_qat_train backend is not installed. Fall back to Flash Attention.")
|
| 168 |
+
elif selected_backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN:
|
| 169 |
+
try:
|
| 170 |
+
from fastvideo_kernel import video_sparse_attn # noqa: F401
|
| 171 |
+
|
| 172 |
+
from fastvideo.attention.backends.video_sparse_attn import ( # noqa: F401
|
| 173 |
+
VideoSparseAttentionBackend)
|
| 174 |
+
logger.info("Using Video Sparse Attention backend.")
|
| 175 |
+
|
| 176 |
+
return "fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionBackend"
|
| 177 |
+
except ImportError as e:
|
| 178 |
+
logger.error("Failed to import Video Sparse Attention backend: %s", str(e))
|
| 179 |
+
raise ImportError("The Video Sparse Attention backend is not installed. "
|
| 180 |
+
"To install it, please follow the instructions at: "
|
| 181 |
+
"https://hao-ai-lab.github.io/FastVideo/video_sparse_attention/installation ") from e
|
| 182 |
+
elif selected_backend == AttentionBackendEnum.SPARSE_FP4_ATTN:
|
| 183 |
+
try:
|
| 184 |
+
from fastvideo.attention.backends.sparse_fp4_attn import ( # noqa: F401
|
| 185 |
+
SparseFP4AttentionBackend)
|
| 186 |
+
logger.info("Using Sparse FP4 Attention backend (FP4 quant + VSA).")
|
| 187 |
+
return "fastvideo.attention.backends.sparse_fp4_attn.SparseFP4AttentionBackend"
|
| 188 |
+
except ImportError as e:
|
| 189 |
+
logger.error("Failed to import Sparse FP4 Attention backend: %s", str(e))
|
| 190 |
+
raise ImportError("Sparse FP4 Attention backend is not available.") from e
|
| 191 |
+
elif selected_backend == AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN:
|
| 192 |
+
try:
|
| 193 |
+
from fastvideo.attention.backends.sparse_fp4_ours_p_attn import ( # noqa: F401
|
| 194 |
+
SparseFP4OursPAttentionBackend)
|
| 195 |
+
logger.info(
|
| 196 |
+
"Using Sparse FP4 Ours-P Attention backend (group-local P quant + VSA)."
|
| 197 |
+
)
|
| 198 |
+
return "fastvideo.attention.backends.sparse_fp4_ours_p_attn.SparseFP4OursPAttentionBackend"
|
| 199 |
+
except ImportError as e:
|
| 200 |
+
logger.error("Failed to import Sparse FP4 Ours-P Attention backend: %s", str(e))
|
| 201 |
+
raise ImportError("Sparse FP4 Ours-P Attention backend is not available.") from e
|
| 202 |
+
elif selected_backend == AttentionBackendEnum.BSA_ATTN:
|
| 203 |
+
try:
|
| 204 |
+
from fastvideo.attention.backends.bsa_attn import ( # noqa: F401
|
| 205 |
+
BSAAttentionBackend)
|
| 206 |
+
logger.info("Using BSA Attention backend.")
|
| 207 |
+
|
| 208 |
+
return "fastvideo.attention.backends.bsa_attn.BSAAttentionBackend"
|
| 209 |
+
except ImportError as e:
|
| 210 |
+
logger.error("Failed to import BSA Attention backend: %s", str(e))
|
| 211 |
+
raise ImportError("The BSA Attention backend failed to import.") from e
|
| 212 |
+
elif selected_backend == AttentionBackendEnum.VMOBA_ATTN:
|
| 213 |
+
try:
|
| 214 |
+
from fastvideo_kernel import moba_attn_varlen # noqa: F401
|
| 215 |
+
from fastvideo.attention.backends.vmoba import ( # noqa: F401
|
| 216 |
+
VMOBAAttentionBackend)
|
| 217 |
+
logger.info("Using Video MOBA Attention backend.")
|
| 218 |
+
|
| 219 |
+
return "fastvideo.attention.backends.vmoba.VMOBAAttentionBackend"
|
| 220 |
+
except ImportError as e:
|
| 221 |
+
logger.error("Failed to import Video MoBA Attention backend: %s", str(e))
|
| 222 |
+
raise ImportError("Video MoBA Attention backend is not installed. ") from e
|
| 223 |
+
elif selected_backend == AttentionBackendEnum.SLA_ATTN:
|
| 224 |
+
try:
|
| 225 |
+
from fastvideo.attention.backends.sla import ( # noqa: F401
|
| 226 |
+
SLAAttentionBackend)
|
| 227 |
+
logger.info("Using SLA (Sparse-Linear Attention) backend.")
|
| 228 |
+
|
| 229 |
+
return "fastvideo.attention.backends.sla.SLAAttentionBackend"
|
| 230 |
+
except ImportError as e:
|
| 231 |
+
logger.error("Failed to import SLA Attention backend: %s", str(e))
|
| 232 |
+
raise ImportError("SLA Attention backend is not available. ") from e
|
| 233 |
+
elif selected_backend == AttentionBackendEnum.SAGE_SLA_ATTN:
|
| 234 |
+
try:
|
| 235 |
+
from fastvideo.attention.backends.sla import ( # noqa: F401
|
| 236 |
+
SageSLAAttentionBackend)
|
| 237 |
+
logger.info("Using SageSLA (Quantized Sparse-Linear Attention) backend.")
|
| 238 |
+
|
| 239 |
+
return "fastvideo.attention.backends.sla.SageSLAAttentionBackend"
|
| 240 |
+
except ImportError as e:
|
| 241 |
+
logger.error("Failed to import SageSLA Attention backend: %s", str(e))
|
| 242 |
+
raise ImportError("SageSLA Attention backend requires spas_sage_attn. "
|
| 243 |
+
"Install with: pip install git+https://github.com/thu-ml/SpargeAttn.git") from e
|
| 244 |
+
elif selected_backend == AttentionBackendEnum.TORCH_SDPA:
|
| 245 |
+
logger.info("Using Torch SDPA backend.")
|
| 246 |
+
return "fastvideo.attention.backends.sdpa.SDPABackend"
|
| 247 |
+
elif selected_backend == AttentionBackendEnum.FLASH_ATTN or selected_backend is None:
|
| 248 |
+
pass
|
| 249 |
+
elif selected_backend:
|
| 250 |
+
raise ValueError(f"Invalid attention backend for {cls.device_name}")
|
| 251 |
+
|
| 252 |
+
target_backend = AttentionBackendEnum.FLASH_ATTN
|
| 253 |
+
if not cls.has_device_capability(80):
|
| 254 |
+
logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
|
| 255 |
+
"GPUs.")
|
| 256 |
+
target_backend = AttentionBackendEnum.TORCH_SDPA
|
| 257 |
+
elif dtype not in (torch.float16, torch.bfloat16):
|
| 258 |
+
logger.info("Cannot use FlashAttention-2 backend for dtype other than "
|
| 259 |
+
"torch.float16 or torch.bfloat16.")
|
| 260 |
+
target_backend = AttentionBackendEnum.TORCH_SDPA
|
| 261 |
+
|
| 262 |
+
# FlashAttn is valid for the model, checking if the package is
|
| 263 |
+
# installed.
|
| 264 |
+
if target_backend == AttentionBackendEnum.FLASH_ATTN:
|
| 265 |
+
try:
|
| 266 |
+
import flash_attn # noqa: F401
|
| 267 |
+
|
| 268 |
+
from fastvideo.attention.backends.flash_attn import ( # noqa: F401
|
| 269 |
+
FlashAttentionBackend)
|
| 270 |
+
|
| 271 |
+
supported_sizes = \
|
| 272 |
+
FlashAttentionBackend.get_supported_head_sizes()
|
| 273 |
+
if head_size not in supported_sizes:
|
| 274 |
+
logger.info("Cannot use FlashAttention-2 backend for head size %d.", head_size)
|
| 275 |
+
target_backend = AttentionBackendEnum.TORCH_SDPA
|
| 276 |
+
except ImportError:
|
| 277 |
+
logger.info("Cannot use FlashAttention-2 backend because the "
|
| 278 |
+
"flash_attn package is not found. "
|
| 279 |
+
"Make sure that flash_attn was built and installed "
|
| 280 |
+
"(on by default).")
|
| 281 |
+
target_backend = AttentionBackendEnum.TORCH_SDPA
|
| 282 |
+
|
| 283 |
+
if target_backend == AttentionBackendEnum.TORCH_SDPA:
|
| 284 |
+
logger.info("Using Torch SDPA backend.")
|
| 285 |
+
|
| 286 |
+
return "fastvideo.attention.backends.sdpa.SDPABackend"
|
| 287 |
+
|
| 288 |
+
logger.info("Using Flash Attention backend.")
|
| 289 |
+
|
| 290 |
+
return "fastvideo.attention.backends.flash_attn.FlashAttentionBackend"
|
| 291 |
+
|
| 292 |
+
    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted import path of the CUDA device communicator.

        Used by the distributed layer to lazily construct the communicator
        class for collective operations on CUDA devices.
        """
        return "fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# NVML utils
|
| 298 |
+
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
| 299 |
+
# all the related functions work on real physical device ids.
|
| 300 |
+
# the major benefit of using NVML is that it will not initialize CUDA
|
| 301 |
+
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML (pynvml).

    NVML is not affected by ``CUDA_VISIBLE_DEVICES``; all helpers here work
    on *physical* device ids, hence the ``device_id_to_physical_device_id``
    translation before each query. Using NVML also avoids initializing CUDA.
    Results are memoized with ``lru_cache`` because device properties do not
    change during the process lifetime.
    """

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``, or None on NVML failure."""
        try:
            physical_device_id = device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            # NVML raises RuntimeError-derived errors when the query fails;
            # report "unknown" rather than propagating.
            return None

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Like the base-class check, but degrade to False on NVML errors."""
        try:
            return bool(super().has_device_capability(capability, device_id))
        except RuntimeError:
            return False

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the product name of ``device_id`` (logical id)."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID string of ``device_id`` (logical id)."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return str(pynvml.nvmlDeviceGetUUID(handle))

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total device memory of ``device_id`` in bytes."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Check every unordered pair (i < j); one missing link means not full.
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle,
                            peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception("NVLink detection failed. This is normal if"
                                         " your machine has no NVLink equipped.")
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # NOTE: takes a *physical* device id; callers translate first.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return str(pynvml.nvmlDeviceGetName(handle))

    @classmethod
    @with_nvml_context
    def log_warnings(cls) -> None:
        """Warn about heterogeneous multi-GPU systems without PCI_BUS_ID ordering.

        Without ``CUDA_DEVICE_ORDER=PCI_BUS_ID``, CUDA's default device
        ordering may disagree with NVML's physical ordering on mixed-GPU hosts.
        """
        device_ids: int = pynvml.nvmlDeviceGetCount()
        if device_ids > 1:
            device_names = [cls._get_physical_device_name(i) for i in range(device_ids)]
            if (len(set(device_names)) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform backed by ``torch.cuda`` queries, used when NVML is
    unavailable (e.g. Jetson devices, where pynvml is not supported)."""

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the CUDA compute capability of ``device_id``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the product name of ``device_id``."""
        return str(torch.cuda.get_device_name(device_id))

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total device memory of ``device_id`` in bytes."""
        device_props = torch.cuda.get_device_properties(device_id)
        return int(device_props.total_memory)

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        """NVLink topology cannot be queried without NVML; always report False.

        Fix: this previously called ``logger.exception`` outside of an
        ``except`` block, which appends a spurious "NoneType: None" traceback
        line to the log (``Logger.exception`` is documented for use inside
        exception handlers only). A plain warning is the intended behavior.
        """
        logger.warning("NVLink detection not possible, as context support was"
                       " not found. Assuming no NVLink available.")
        return False
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    try:
        pynvml.nvmlInit()
        nvml_available = True
    except Exception:
        # On Jetson, NVML is not supported.
        nvml_available = False
finally:
    # Shut NVML back down immediately: this probe only tests availability;
    # the NVML-backed platform re-enters NVML per call via with_nvml_context.
    if nvml_available:
        pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

try:
    from sphinx.ext.autodoc.mock import _MockModule

    # During docs builds Sphinx may replace pynvml with a mock module; calling
    # into it would fail, so skip the warning pass in that case.
    if not isinstance(pynvml, _MockModule):
        CudaPlatform.log_warnings()
except ModuleNotFoundError:
    # Sphinx not installed: normal runtime, always emit platform warnings.
    CudaPlatform.log_warnings()
|
standalone_inference/overlay_files/fastvideo/platforms/interface.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import enum
|
| 2 |
+
import random
|
| 3 |
+
from typing import Any, NamedTuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from fastvideo.logger import init_logger
|
| 9 |
+
|
| 10 |
+
logger = init_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AttentionBackendEnum(enum.Enum):
    """Identifiers for the attention implementations FastVideo can select.

    Values are explicit integers equal to what ``enum.auto()`` assigned in
    the original declaration order (1-based), so serialized values and
    comparisons are unchanged.
    """

    FLASH_ATTN = 1
    TORCH_SDPA = 2
    SAGE_ATTN = 3
    SAGE_ATTN_THREE = 4
    ATTN_QAT_INFER = 5
    ATTN_QAT_TRAIN = 6
    VIDEO_SPARSE_ATTN = 7
    BSA_ATTN = 8
    VMOBA_ATTN = 9
    SLA_ATTN = 10
    SAGE_SLA_ATTN = 11
    SPARSE_FP4_ATTN = 12
    SPARSE_FP4_OURS_P_ATTN = 13
    NO_ATTENTION = 14
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PlatformEnum(enum.Enum):
    """Hardware platform families FastVideo can run on.

    Explicit integer values match the original ``enum.auto()`` assignment
    (declaration order, 1-based).
    """

    CUDA = 1
    ROCM = 2
    TPU = 3
    XPU = 4
    CPU = 5
    MPS = 6
    OOT = 7
    UNSPECIFIED = 8
    NPU = 9
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CpuArchEnum(enum.Enum):
    """Host CPU architecture families.

    Explicit integer values match the original ``enum.auto()`` assignment.
    """

    X86 = 1
    ARM = 2
    UNSPECIFIED = 3
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DeviceCapability(NamedTuple):
    """A device compute capability expressed as a ``(major, minor)`` pair."""

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render the capability as a dotted version string, e.g. ``"8.6"``."""
        return "{}.{}".format(self.major, self.minor)

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return 10 * self.major + self.minor
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Platform:
    """Abstract description of the accelerator platform a process runs on.

    Concrete subclasses (CUDA, CPU, ...) set ``_enum`` and override the
    device-query classmethods.
    """

    _enum: PlatformEnum
    device_name: str
    device_type: str

    dispatch_key: str = "CPU"

    # platform-agnostic way to specify the device control environment variable,
    # .e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "FASTVIDEO_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""
    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    supported_quantization: list[str] = []

    additional_env_vars: list[str] = []

    def is_cuda(self) -> bool:
        return self._enum is PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        return self._enum is PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        return self._enum is PlatformEnum.TPU

    def is_xpu(self) -> bool:
        return self._enum is PlatformEnum.XPU

    def is_cpu(self) -> bool:
        return self._enum is PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        return self._enum is PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of :func:`torch.cuda.is_available`."""
        return self.is_cuda() or self.is_rocm()

    def is_mps(self) -> bool:
        return self._enum is PlatformEnum.MPS

    def is_npu(self) -> bool:
        return self._enum is PlatformEnum.NPU

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
                             dtype: torch.dtype) -> str:
        """Get the attention backend class of a device. Base: no backend."""
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of :func:`torch.cuda.get_device_capability`.

        Base implementation reports "unknown" (None).
        """
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
        """
        found = cls.get_device_capability(device_id=device_id)
        if found is None:
            return False
        if isinstance(capability, tuple):
            # NamedTuple comparison: lexicographic on (major, minor).
            return found >= capability
        return found.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def get_torch_device(cls) -> Any:
        """
        Check if the current platform supports torch device.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        Some hardware backends (e.g. TPU) do not support
        `torch.inference_mode`; they override this to fall back to
        `torch.no_grad`.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int | None = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.
        """
        if not cls.supported_quantization:
            # Empty list means "no restriction declared".
            return
        if quant not in cls.supported_quantization:
            raise ValueError(f"{quant} quantization is currently not supported in "
                             f"{cls.device_name}.")

    @classmethod
    def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """Get the CPU architecture of the current platform."""
        return CpuArchEnum.UNSPECIFIED
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class UnspecifiedPlatform(Platform):
    """Fallback platform used when no concrete accelerator has been detected."""
    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""
|
standalone_inference/overlay_files/fastvideo/train/models/wan/wan.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Wan model plugin (per-role instance)."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import gc
|
| 8 |
+
from typing import Any, Literal, TYPE_CHECKING
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import fastvideo.envs as envs
|
| 13 |
+
from fastvideo.configs.sample import SamplingParam
|
| 14 |
+
from fastvideo.distributed import (
|
| 15 |
+
get_sp_group,
|
| 16 |
+
get_world_group,
|
| 17 |
+
)
|
| 18 |
+
from fastvideo.forward_context import set_forward_context
|
| 19 |
+
from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
|
| 20 |
+
FlowMatchEulerDiscreteScheduler, )
|
| 21 |
+
from fastvideo.pipelines import TrainingBatch
|
| 22 |
+
from fastvideo.pipelines.basic.wan.wan_pipeline import (
|
| 23 |
+
WanPipeline, )
|
| 24 |
+
from fastvideo.pipelines.pipeline_batch_info import (
|
| 25 |
+
ForwardBatch, )
|
| 26 |
+
from fastvideo.training.activation_checkpoint import (
|
| 27 |
+
apply_activation_checkpointing, )
|
| 28 |
+
from fastvideo.training.training_utils import (
|
| 29 |
+
compute_density_for_timestep_sampling,
|
| 30 |
+
get_sigmas,
|
| 31 |
+
normalize_dit_input,
|
| 32 |
+
shift_timestep,
|
| 33 |
+
)
|
| 34 |
+
from fastvideo.utils import (
|
| 35 |
+
is_vmoba_available,
|
| 36 |
+
is_vsa_available,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
from fastvideo.train.models.base import ModelBase
|
| 40 |
+
from fastvideo.train.utils.module_state import (
|
| 41 |
+
apply_trainable, )
|
| 42 |
+
from fastvideo.train.utils.moduleloader import (
|
| 43 |
+
load_module_from_path, )
|
| 44 |
+
|
| 45 |
+
if TYPE_CHECKING:
|
| 46 |
+
from fastvideo.train.utils.training_config import (
|
| 47 |
+
TrainingConfig, )
|
| 48 |
+
|
| 49 |
+
# Optional attention-metadata builders. They are resolved once at import time
# so the rest of the module can check availability with a simple `is None`
# test instead of repeating try/except import logic.
VideoSparseAttentionMetadataBuilder: type[Any] | None
VideoMobaAttentionMetadataBuilder: type[Any] | None

try:
    from fastvideo.attention.backends.video_sparse_attn import (
        VideoSparseAttentionMetadataBuilder as _VideoSparseAttentionMetadataBuilder, )
    from fastvideo.attention.backends.vmoba import (
        VideoMobaAttentionMetadataBuilder as _VideoMobaAttentionMetadataBuilder, )
    VideoSparseAttentionMetadataBuilder = _VideoSparseAttentionMetadataBuilder
    VideoMobaAttentionMetadataBuilder = _VideoMobaAttentionMetadataBuilder
except Exception:
    # Broad on purpose: importing these backends can fail for reasons beyond
    # ImportError (e.g. kernel extensions missing on minimal installs). Fall
    # back to None rather than failing module import.
    VideoSparseAttentionMetadataBuilder = None
    VideoMobaAttentionMetadataBuilder = None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class WanModel(ModelBase):
|
| 65 |
+
"""Wan per-role model: owns transformer + noise_scheduler."""
|
| 66 |
+
|
| 67 |
+
_transformer_cls_name: str = "WanTransformer3DModel"
|
| 68 |
+
|
| 69 |
+
    def __init__(
        self,
        *,
        init_from: str,
        training_config: TrainingConfig,
        trainable: bool = True,
        disable_custom_init_weights: bool = False,
        flow_shift: float = 3.0,
        enable_gradient_checkpointing_type: str | None = None,
        transformer_override_safetensor: str | None = None,
    ) -> None:
        """Build a per-role Wan model: transformer plus flow-match scheduler.

        Args:
            init_from: Path the transformer weights are loaded from.
            training_config: Global training configuration object.
            trainable: Whether the transformer's parameters require grad.
            disable_custom_init_weights: Skip the custom weight init pass.
            flow_shift: Sigma shift for the flow-match scheduler; also reused
                as ``timestep_shift`` for timestep sampling.
            enable_gradient_checkpointing_type: Activation-checkpointing mode;
                falls back to ``training_config.model`` when None.
            transformer_override_safetensor: Optional safetensor file whose
                weights override the loaded transformer's.
        """
        self._init_from = str(init_from)
        self._trainable = bool(trainable)

        self.transformer = self._load_transformer(
            init_from=self._init_from,
            trainable=self._trainable,
            disable_custom_init_weights=(disable_custom_init_weights),
            enable_gradient_checkpointing_type=(enable_gradient_checkpointing_type),
            training_config=training_config,
            transformer_override_safetensor=(transformer_override_safetensor),
        )

        self.noise_scheduler = (FlowMatchEulerDiscreteScheduler(shift=float(flow_shift)))

        # Filled by init_preprocessors (student only).
        self.vae: Any = None
        self.training_config: TrainingConfig = training_config
        self.dataloader: Any = None
        self.validator: Any = None
        self.start_step: int = 0

        # Distributed process groups; populated in init_preprocessors.
        self.world_group: Any = None
        self.sp_group: Any = None

        # Negative conditioning; populated lazily by on_train_start.
        self.negative_prompt_embeds: (torch.Tensor | None) = None
        self.negative_prompt_attention_mask: (torch.Tensor | None) = None

        # Timestep mechanics.
        self.timestep_shift: float = float(flow_shift)
        self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps)
        self.min_timestep: int = 0
        self.max_timestep: int = self.num_train_timestep
|
| 114 |
+
|
| 115 |
+
def _load_transformer(
|
| 116 |
+
self,
|
| 117 |
+
*,
|
| 118 |
+
init_from: str,
|
| 119 |
+
trainable: bool,
|
| 120 |
+
disable_custom_init_weights: bool,
|
| 121 |
+
enable_gradient_checkpointing_type: str | None,
|
| 122 |
+
training_config: TrainingConfig,
|
| 123 |
+
transformer_override_safetensor: str | None = None,
|
| 124 |
+
) -> torch.nn.Module:
|
| 125 |
+
transformer = load_module_from_path(
|
| 126 |
+
model_path=init_from,
|
| 127 |
+
module_type="transformer",
|
| 128 |
+
training_config=training_config,
|
| 129 |
+
disable_custom_init_weights=(disable_custom_init_weights),
|
| 130 |
+
override_transformer_cls_name=(self._transformer_cls_name),
|
| 131 |
+
transformer_override_safetensor=(transformer_override_safetensor),
|
| 132 |
+
)
|
| 133 |
+
transformer = apply_trainable(transformer, trainable=trainable)
|
| 134 |
+
# Fall back to training_config.model if not set on the
|
| 135 |
+
# model YAML section directly.
|
| 136 |
+
ckpt_type = (enable_gradient_checkpointing_type or getattr(
|
| 137 |
+
getattr(training_config, "model", None),
|
| 138 |
+
"enable_gradient_checkpointing_type",
|
| 139 |
+
None,
|
| 140 |
+
))
|
| 141 |
+
if trainable and ckpt_type:
|
| 142 |
+
transformer = apply_activation_checkpointing(
|
| 143 |
+
transformer,
|
| 144 |
+
checkpointing_type=ckpt_type,
|
| 145 |
+
)
|
| 146 |
+
return transformer
|
| 147 |
+
|
| 148 |
+
# ------------------------------------------------------------------
|
| 149 |
+
# Lifecycle
|
| 150 |
+
# ------------------------------------------------------------------
|
| 151 |
+
|
| 152 |
+
def init_preprocessors(self, training_config: TrainingConfig) -> None:
|
| 153 |
+
self.vae = load_module_from_path(
|
| 154 |
+
model_path=str(training_config.model_path),
|
| 155 |
+
module_type="vae",
|
| 156 |
+
training_config=training_config,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
self.world_group = get_world_group()
|
| 160 |
+
self.sp_group = get_sp_group()
|
| 161 |
+
|
| 162 |
+
self._init_timestep_mechanics()
|
| 163 |
+
|
| 164 |
+
from fastvideo.dataset.dataloader.schema import (
|
| 165 |
+
pyarrow_schema_t2v, )
|
| 166 |
+
from fastvideo.train.utils.dataloader import (
|
| 167 |
+
build_parquet_t2v_train_dataloader, )
|
| 168 |
+
|
| 169 |
+
text_len = (
|
| 170 |
+
training_config.pipeline_config.text_encoder_configs[ # type: ignore[union-attr]
|
| 171 |
+
0].arch_config.text_len)
|
| 172 |
+
self.dataloader = build_parquet_t2v_train_dataloader(
|
| 173 |
+
training_config.data,
|
| 174 |
+
text_len=int(text_len),
|
| 175 |
+
parquet_schema=pyarrow_schema_t2v,
|
| 176 |
+
)
|
| 177 |
+
self.start_step = 0
|
| 178 |
+
|
| 179 |
+
@property
|
| 180 |
+
def num_train_timesteps(self) -> int:
|
| 181 |
+
return int(self.num_train_timestep)
|
| 182 |
+
|
| 183 |
+
def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
timestep = shift_timestep(
|
| 185 |
+
timestep,
|
| 186 |
+
self.timestep_shift,
|
| 187 |
+
self.num_train_timestep,
|
| 188 |
+
)
|
| 189 |
+
return timestep.clamp(self.min_timestep, self.max_timestep)
|
| 190 |
+
|
| 191 |
+
def on_train_start(self) -> None:
|
| 192 |
+
self.ensure_negative_conditioning()
|
| 193 |
+
|
| 194 |
+
# ------------------------------------------------------------------
|
| 195 |
+
# Runtime primitives
|
| 196 |
+
# ------------------------------------------------------------------
|
| 197 |
+
|
| 198 |
+
    def prepare_batch(
        self,
        raw_batch: dict[str, Any],
        *,
        generator: torch.Generator,
        latents_source: Literal["data", "zeros"] = "data",
    ) -> TrainingBatch:
        """Turn a raw dataloader batch into a fully-populated TrainingBatch.

        Moves text conditioning and latents to the training device/dtype,
        normalizes latents, samples noise/timesteps via
        ``_prepare_dit_inputs``, and builds attention metadata.  Two copies
        of the metadata are kept: ``attn_metadata_vsa`` (as built) and
        ``attn_metadata`` with sparsity forced to 0 (dense path).

        Args:
            raw_batch: Expects keys "text_embedding", "text_attention_mask",
                optionally "info_list" and "vae_latent".
            generator: RNG used for noise and timestep sampling.
            latents_source: "data" reads "vae_latent" from the batch;
                "zeros" fabricates all-zero latents sized from the config.

        Raises:
            ValueError: if "vae_latent" is missing with
                latents_source="data", or latents_source is unknown.
        """
        self.ensure_negative_conditioning()
        assert self.training_config is not None
        tc = self.training_config

        dtype = self._get_training_dtype()
        device = self.device

        training_batch = TrainingBatch()
        encoder_hidden_states = raw_batch["text_embedding"]
        encoder_attention_mask = raw_batch["text_attention_mask"]
        infos = raw_batch.get("info_list")

        if latents_source == "zeros":
            batch_size = encoder_hidden_states.shape[0]
            vae_config = (
                tc.pipeline_config.vae_config.arch_config  # type: ignore[union-attr]
            )
            num_channels = vae_config.z_dim
            spatial_compression_ratio = (vae_config.spatial_compression_ratio)
            # Latent spatial dims are the pixel dims divided by the VAE's
            # spatial compression ratio.
            latent_height = (tc.data.num_height // spatial_compression_ratio)
            latent_width = (tc.data.num_width // spatial_compression_ratio)
            latents = torch.zeros(
                batch_size,
                num_channels,
                tc.data.num_latent_t,
                latent_height,
                latent_width,
                device=device,
                dtype=dtype,
            )
        elif latents_source == "data":
            if "vae_latent" not in raw_batch:
                raise ValueError("vae_latent not found in batch "
                                 "and latents_source='data'")
            latents = raw_batch["vae_latent"]
            # Truncate the temporal dimension to the configured length.
            latents = latents[:, :, :tc.data.num_latent_t]
            latents = latents.to(device, dtype=dtype)
        else:
            raise ValueError(f"Unknown latents_source: "
                             f"{latents_source!r}")

        training_batch.latents = latents
        training_batch.encoder_hidden_states = (encoder_hidden_states.to(device, dtype=dtype))
        training_batch.encoder_attention_mask = (encoder_attention_mask.to(device, dtype=dtype))
        training_batch.infos = infos

        training_batch.latents = normalize_dit_input("wan", training_batch.latents, self.vae)
        training_batch = self._prepare_dit_inputs(training_batch, generator)
        training_batch = self._build_attention_metadata(training_batch)

        # Keep the sparse (VSA) metadata as built, and a dense variant with
        # sparsity forced to zero.  Deepcopy so the two do not alias.
        training_batch.attn_metadata_vsa = copy.deepcopy(training_batch.attn_metadata)
        if training_batch.attn_metadata is not None:
            training_batch.attn_metadata.VSA_sparsity = 0.0  # type: ignore[attr-defined]

        return training_batch
|
| 260 |
+
|
| 261 |
+
def add_noise(
|
| 262 |
+
self,
|
| 263 |
+
clean_latents: torch.Tensor,
|
| 264 |
+
noise: torch.Tensor,
|
| 265 |
+
timestep: torch.Tensor,
|
| 266 |
+
) -> torch.Tensor:
|
| 267 |
+
b, t = clean_latents.shape[:2]
|
| 268 |
+
noisy = self.noise_scheduler.add_noise(
|
| 269 |
+
clean_latents.flatten(0, 1),
|
| 270 |
+
noise.flatten(0, 1),
|
| 271 |
+
timestep,
|
| 272 |
+
).unflatten(0, (b, t))
|
| 273 |
+
return noisy
|
| 274 |
+
|
| 275 |
+
def predict_noise(
|
| 276 |
+
self,
|
| 277 |
+
noisy_latents: torch.Tensor,
|
| 278 |
+
timestep: torch.Tensor,
|
| 279 |
+
batch: TrainingBatch,
|
| 280 |
+
*,
|
| 281 |
+
conditional: bool,
|
| 282 |
+
cfg_uncond: dict[str, Any] | None = None,
|
| 283 |
+
attn_kind: Literal["dense", "vsa"] = "dense",
|
| 284 |
+
force_dense: bool = False,
|
| 285 |
+
) -> torch.Tensor:
|
| 286 |
+
device_type = self.device.type
|
| 287 |
+
dtype = noisy_latents.dtype
|
| 288 |
+
if conditional:
|
| 289 |
+
text_dict = batch.conditional_dict
|
| 290 |
+
if text_dict is None:
|
| 291 |
+
raise RuntimeError("Missing conditional_dict in "
|
| 292 |
+
"TrainingBatch")
|
| 293 |
+
else:
|
| 294 |
+
text_dict = self._get_uncond_text_dict(batch, cfg_uncond=cfg_uncond)
|
| 295 |
+
|
| 296 |
+
if attn_kind == "dense":
|
| 297 |
+
attn_metadata = batch.attn_metadata
|
| 298 |
+
elif attn_kind in ("vsa", "sparse_fp4"):
|
| 299 |
+
attn_metadata = batch.attn_metadata_vsa
|
| 300 |
+
else:
|
| 301 |
+
raise ValueError(f"Unknown attn_kind: {attn_kind!r}")
|
| 302 |
+
|
| 303 |
+
with torch.autocast(device_type, dtype=dtype), set_forward_context(
|
| 304 |
+
current_timestep=batch.timesteps,
|
| 305 |
+
attn_metadata=attn_metadata,
|
| 306 |
+
force_dense=force_dense,
|
| 307 |
+
):
|
| 308 |
+
input_kwargs = (self._build_distill_input_kwargs(noisy_latents, timestep, text_dict))
|
| 309 |
+
transformer = self._get_transformer(timestep)
|
| 310 |
+
pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4)
|
| 311 |
+
return pred_noise
|
| 312 |
+
|
| 313 |
+
def backward(
|
| 314 |
+
self,
|
| 315 |
+
loss: torch.Tensor,
|
| 316 |
+
ctx: Any,
|
| 317 |
+
*,
|
| 318 |
+
grad_accum_rounds: int,
|
| 319 |
+
) -> None:
|
| 320 |
+
timesteps, attn_metadata = ctx
|
| 321 |
+
with set_forward_context(
|
| 322 |
+
current_timestep=timesteps,
|
| 323 |
+
attn_metadata=attn_metadata,
|
| 324 |
+
):
|
| 325 |
+
(loss / max(1, int(grad_accum_rounds))).backward()
|
| 326 |
+
|
| 327 |
+
# ------------------------------------------------------------------
|
| 328 |
+
# Internal helpers
|
| 329 |
+
# ------------------------------------------------------------------
|
| 330 |
+
|
| 331 |
+
def _get_training_dtype(self) -> torch.dtype:
|
| 332 |
+
return torch.bfloat16
|
| 333 |
+
|
| 334 |
+
def _init_timestep_mechanics(self) -> None:
|
| 335 |
+
assert self.training_config is not None
|
| 336 |
+
tc = self.training_config
|
| 337 |
+
flow_shift = tc.pipeline_config.flow_shift
|
| 338 |
+
self.timestep_shift = float(0.0 if flow_shift is None else flow_shift)
|
| 339 |
+
self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps)
|
| 340 |
+
# min/max timestep ratios now come from method_config;
|
| 341 |
+
# default to full range.
|
| 342 |
+
self.min_timestep = 0
|
| 343 |
+
self.max_timestep = self.num_train_timestep
|
| 344 |
+
|
| 345 |
+
    def ensure_negative_conditioning(self) -> None:
        """Compute and cache negative-prompt text conditioning (idempotent).

        Rank 0 encodes the model's default negative prompt with a temporary
        WanPipeline (reusing the already-loaded transformer), then the
        embeddings and attention mask are broadcast to every rank in three
        steps: ndims first, then shapes, then the tensors themselves.
        Results land on ``self.negative_prompt_embeds`` and
        ``self.negative_prompt_attention_mask``.
        """
        # Already cached — nothing to do.
        if self.negative_prompt_embeds is not None:
            return

        assert self.training_config is not None
        tc = self.training_config
        world_group = self.world_group
        device = self.device
        dtype = self._get_training_dtype()

        from fastvideo.train.utils.moduleloader import (
            make_inference_args, )

        neg_embeds: torch.Tensor | None = None
        neg_mask: torch.Tensor | None = None

        # Only rank 0 runs the (expensive) prompt-encoding pipeline.
        if world_group.rank_in_group == 0:
            sampling_param = SamplingParam.from_pretrained(tc.model_path)
            negative_prompt = sampling_param.negative_prompt

            inference_args = make_inference_args(tc, model_path=tc.model_path)

            # Reuse the training transformer so the temporary pipeline does
            # not load a second copy of the DiT.
            prompt_pipeline = WanPipeline.from_pretrained(
                tc.model_path,
                args=inference_args,
                inference_mode=True,
                loaded_modules={"transformer": self.transformer},
                tp_size=tc.distributed.tp_size,
                sp_size=tc.distributed.sp_size,
                num_gpus=tc.distributed.num_gpus,
                pin_cpu_memory=(tc.distributed.pin_cpu_memory),
                dit_cpu_offload=True,
            )

            batch_negative = ForwardBatch(
                data_type="video",
                prompt=negative_prompt,
                prompt_embeds=[],
                prompt_attention_mask=[],
            )
            result_batch = prompt_pipeline.prompt_encoding_stage(  # type: ignore[attr-defined]
                batch_negative,
                inference_args,
            )

            neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype)
            neg_mask = (result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype))

            # Free the temporary pipeline before the collective ops below.
            del prompt_pipeline
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Step 1: broadcast tensor ranks so non-zero ranks can size buffers.
        meta = torch.zeros((2, ), device=device, dtype=torch.int64)
        if world_group.rank_in_group == 0:
            assert neg_embeds is not None
            assert neg_mask is not None
            meta[0] = neg_embeds.ndim
            meta[1] = neg_mask.ndim
        world_group.broadcast(meta, src=0)
        embed_ndim, mask_ndim = (
            int(meta[0].item()),
            int(meta[1].item()),
        )

        # Step 2: broadcast the concrete shapes, padded to max_ndim with -1.
        max_ndim = 8
        embed_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
        mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
        if world_group.rank_in_group == 0:
            assert neg_embeds is not None
            assert neg_mask is not None
            embed_shape[:embed_ndim] = torch.tensor(
                list(neg_embeds.shape),
                device=device,
                dtype=torch.int64,
            )
            mask_shape[:mask_ndim] = torch.tensor(
                list(neg_mask.shape),
                device=device,
                dtype=torch.int64,
            )
        world_group.broadcast(embed_shape, src=0)
        world_group.broadcast(mask_shape, src=0)

        embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist())
        mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist())

        # Step 3: broadcast the tensors into correctly-sized buffers.
        if world_group.rank_in_group != 0:
            neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype)
            neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype)
        assert neg_embeds is not None
        assert neg_mask is not None

        world_group.broadcast(neg_embeds, src=0)
        world_group.broadcast(neg_mask, src=0)

        self.negative_prompt_embeds = neg_embeds
        self.negative_prompt_attention_mask = neg_mask
|
| 443 |
+
|
| 444 |
+
def _sample_timesteps(
|
| 445 |
+
self,
|
| 446 |
+
batch_size: int,
|
| 447 |
+
device: torch.device,
|
| 448 |
+
generator: torch.Generator,
|
| 449 |
+
) -> torch.Tensor:
|
| 450 |
+
assert self.training_config is not None
|
| 451 |
+
tc = self.training_config
|
| 452 |
+
|
| 453 |
+
u = compute_density_for_timestep_sampling(
|
| 454 |
+
weighting_scheme=tc.model.weighting_scheme,
|
| 455 |
+
batch_size=batch_size,
|
| 456 |
+
generator=generator,
|
| 457 |
+
device=device,
|
| 458 |
+
logit_mean=tc.model.logit_mean,
|
| 459 |
+
logit_std=tc.model.logit_std,
|
| 460 |
+
mode_scale=tc.model.mode_scale,
|
| 461 |
+
)
|
| 462 |
+
indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
|
| 463 |
+
return self.noise_scheduler.timesteps[indices.cpu()].to(device=device)
|
| 464 |
+
|
| 465 |
+
    def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Attach backend-specific attention metadata to *training_batch*.

        Selected by ``FASTVIDEO_ATTENTION_BACKEND``: the three sparse
        backends get VSA metadata, VMOBA_ATTN gets VMoBA metadata, and any
        other backend leaves ``attn_metadata`` as None (dense path).

        Raises:
            ImportError: if the selected backend's kernels are unavailable.
        """
        assert self.training_config is not None
        tc = self.training_config
        latents_shape = training_batch.raw_latent_shape
        patch_size = (
            tc.pipeline_config.dit_config.patch_size  # type: ignore[union-attr]
        )
        assert latents_shape is not None
        assert training_batch.timesteps is not None

        if envs.FASTVIDEO_ATTENTION_BACKEND in (
                "VIDEO_SPARSE_ATTN", "SPARSE_FP4_ATTN", "SPARSE_FP4_OURS_P_ATTN",
        ):
            if (not is_vsa_available() or VideoSparseAttentionMetadataBuilder is None):
                raise ImportError(
                    f"FASTVIDEO_ATTENTION_BACKEND is "
                    f"{envs.FASTVIDEO_ATTENTION_BACKEND}, but "
                    f"fastvideo_kernel is not correctly "
                    f"installed or detected.")
            # latents_shape[2:5] keeps the last three latent dims —
            # presumably (T, H, W); confirm against the layout set in
            # _prepare_dit_inputs (shape is recorded before the permute).
            training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder().build(  # type: ignore[misc]
                raw_latent_shape=latents_shape[2:5],
                current_timestep=(training_batch.timesteps),
                patch_size=patch_size,
                VSA_sparsity=tc.vsa_sparsity,
                device=self.device,
            )
        elif (envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN"):
            if (not is_vmoba_available() or VideoMobaAttentionMetadataBuilder is None):
                raise ImportError("FASTVIDEO_ATTENTION_BACKEND is "
                                  "VMOBA_ATTN, but fastvideo_kernel "
                                  "(or flash_attn>=2.7.4) is not "
                                  "correctly installed.")
            # Shallow-copy so the shared config dict is not mutated.
            moba_params = tc.model.moba_config.copy()
            assert training_batch.raw_latent_shape is not None
            moba_params.update({
                "current_timestep": (training_batch.timesteps),
                "raw_latent_shape": (training_batch.raw_latent_shape[2:5]),
                "patch_size": patch_size,
                "device": self.device,
            })
            training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(**
                                                                                     moba_params)  # type: ignore[misc]
        else:
            # Dense attention needs no metadata.
            training_batch.attn_metadata = None

        return training_batch
|
| 511 |
+
|
| 512 |
+
    def _prepare_dit_inputs(
        self,
        training_batch: TrainingBatch,
        generator: torch.Generator,
    ) -> TrainingBatch:
        """Sample noise/timesteps and populate the DiT-facing batch fields.

        Builds the flow-matching input
        ``(1 - sigma) * latents + sigma * noise``, records conditional /
        unconditional text dicts, and finally permutes the latents with
        ``permute(0, 2, 1, 3, 4)`` (raw_latent_shape is recorded before the
        permute).
        """
        assert self.training_config is not None
        tc = self.training_config
        latents = training_batch.latents
        assert isinstance(latents, torch.Tensor)
        batch_size = latents.shape[0]

        noise = torch.randn(
            latents.shape,
            generator=generator,
            device=latents.device,
            dtype=latents.dtype,
        )
        timesteps = self._sample_timesteps(
            batch_size,
            latents.device,
            generator,
        )
        # Keep timesteps identical across a sequence-parallel group.
        if int(tc.distributed.sp_size or 1) > 1:
            self.sp_group.broadcast(timesteps, src=0)

        sigmas = get_sigmas(
            self.noise_scheduler,
            latents.device,
            timesteps,
            n_dim=latents.ndim,
            dtype=latents.dtype,
        )
        # Flow-matching interpolation between clean latents and noise.
        noisy_model_input = ((1.0 - sigmas) * latents + sigmas * noise)

        training_batch.noisy_model_input = (noisy_model_input)
        training_batch.timesteps = timesteps
        training_batch.sigmas = sigmas
        training_batch.noise = noise
        # Pre-permute shape; consumed by _build_attention_metadata.
        training_batch.raw_latent_shape = latents.shape

        training_batch.conditional_dict = {
            "encoder_hidden_states": (training_batch.encoder_hidden_states),
            "encoder_attention_mask": (training_batch.encoder_attention_mask),
        }

        if (self.negative_prompt_embeds is not None and self.negative_prompt_attention_mask is not None):
            neg_embeds = self.negative_prompt_embeds
            neg_mask = (self.negative_prompt_attention_mask)
            # Expand batch-1 cached negatives to the current batch size.
            if (neg_embeds.shape[0] == 1 and batch_size > 1):
                neg_embeds = neg_embeds.expand(batch_size, *neg_embeds.shape[1:]).contiguous()
            if (neg_mask.shape[0] == 1 and batch_size > 1):
                neg_mask = neg_mask.expand(batch_size, *neg_mask.shape[1:]).contiguous()
            training_batch.unconditional_dict = {
                "encoder_hidden_states": neg_embeds,
                "encoder_attention_mask": neg_mask,
            }

        training_batch.latents = (training_batch.latents.permute(0, 2, 1, 3, 4))
        return training_batch
|
| 571 |
+
|
| 572 |
+
def _build_distill_input_kwargs(
|
| 573 |
+
self,
|
| 574 |
+
noise_input: torch.Tensor,
|
| 575 |
+
timestep: torch.Tensor,
|
| 576 |
+
text_dict: dict[str, torch.Tensor] | None,
|
| 577 |
+
) -> dict[str, Any]:
|
| 578 |
+
if text_dict is None:
|
| 579 |
+
raise ValueError("text_dict cannot be None for "
|
| 580 |
+
"Wan distillation")
|
| 581 |
+
return {
|
| 582 |
+
"hidden_states": noise_input.permute(0, 2, 1, 3, 4),
|
| 583 |
+
"encoder_hidden_states": text_dict["encoder_hidden_states"],
|
| 584 |
+
"encoder_attention_mask": text_dict["encoder_attention_mask"],
|
| 585 |
+
"timestep": timestep,
|
| 586 |
+
"return_dict": False,
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
def _get_transformer(self, timestep: torch.Tensor) -> torch.nn.Module:
|
| 590 |
+
return self.transformer
|
| 591 |
+
|
| 592 |
+
    def _get_uncond_text_dict(
        self,
        batch: TrainingBatch,
        *,
        cfg_uncond: dict[str, Any] | None,
    ) -> dict[str, torch.Tensor]:
        """Resolve the unconditional text conditioning for CFG.

        With ``cfg_uncond=None`` the cached negative-prompt dict is used.
        Otherwise ``cfg_uncond`` is validated: ``on_missing`` must be
        "error" or "ignore"; non-text channels are rejected unless their
        policy is "keep" or ``on_missing`` is "ignore"; the "text" policy
        selects negative_prompt (default), keep, or zero ("drop" is
        rejected for Wan).

        Raises:
            RuntimeError: when the needed conditioning dict is missing.
            ValueError: on invalid ``cfg_uncond`` contents.
            TypeError: if conditional text inputs are not tensors (zero
                policy).
        """
        if cfg_uncond is None:
            text_dict = getattr(batch, "unconditional_dict", None)
            if text_dict is None:
                raise RuntimeError("Missing unconditional_dict; "
                                   "ensure_negative_conditioning() "
                                   "may have failed")
            return text_dict

        # Validate the on_missing policy first; it gates channel handling.
        on_missing_raw = cfg_uncond.get("on_missing", "error")
        if not isinstance(on_missing_raw, str):
            raise ValueError("method_config.cfg_uncond.on_missing "
                             "must be a string, got "
                             f"{type(on_missing_raw).__name__}")
        on_missing = on_missing_raw.strip().lower()
        if on_missing not in {"error", "ignore"}:
            raise ValueError("method_config.cfg_uncond.on_missing "
                             "must be one of {error, ignore}, got "
                             f"{on_missing_raw!r}")

        # Non-text channels: "keep" is a no-op; anything else errors unless
        # on_missing=ignore.
        for channel, policy_raw in cfg_uncond.items():
            if channel in {"on_missing", "text"}:
                continue
            if policy_raw is None:
                continue
            if not isinstance(policy_raw, str):
                raise ValueError("method_config.cfg_uncond values "
                                 "must be strings, got "
                                 f"{channel}="
                                 f"{type(policy_raw).__name__}")
            policy = policy_raw.strip().lower()
            if policy == "keep":
                continue
            if on_missing == "ignore":
                continue
            raise ValueError("WanModel does not support "
                             "cfg_uncond channel "
                             f"{channel!r} (policy={policy!r}). "
                             "Set cfg_uncond.on_missing=ignore or "
                             "remove the channel.")

        # Resolve the text policy; absent means negative_prompt.
        text_policy_raw = cfg_uncond.get("text", None)
        if text_policy_raw is None:
            text_policy = "negative_prompt"
        elif not isinstance(text_policy_raw, str):
            raise ValueError("method_config.cfg_uncond.text must be "
                             "a string, got "
                             f"{type(text_policy_raw).__name__}")
        else:
            text_policy = (text_policy_raw.strip().lower())

        if text_policy in {"negative_prompt"}:
            text_dict = getattr(batch, "unconditional_dict", None)
            if text_dict is None:
                raise RuntimeError("Missing unconditional_dict; "
                                   "ensure_negative_conditioning() "
                                   "may have failed")
            return text_dict
        if text_policy == "keep":
            # CFG with the conditional inputs themselves.
            if batch.conditional_dict is None:
                raise RuntimeError("Missing conditional_dict in "
                                   "TrainingBatch")
            return batch.conditional_dict
        if text_policy == "zero":
            # Zero out the conditional text inputs (same shapes/dtypes).
            if batch.conditional_dict is None:
                raise RuntimeError("Missing conditional_dict in "
                                   "TrainingBatch")
            cond = batch.conditional_dict
            enc = cond["encoder_hidden_states"]
            mask = cond["encoder_attention_mask"]
            if not torch.is_tensor(enc) or not torch.is_tensor(mask):
                raise TypeError("conditional_dict must contain "
                                "tensor text inputs")
            return {
                "encoder_hidden_states": (torch.zeros_like(enc)),
                "encoder_attention_mask": (torch.zeros_like(mask)),
            }
        if text_policy == "drop":
            raise ValueError("cfg_uncond.text=drop is not supported "
                             "for Wan. Use "
                             "{negative_prompt, keep, zero}.")
        raise ValueError("cfg_uncond.text must be one of "
                         "{negative_prompt, keep, zero, drop}, got "
                         f"{text_policy_raw!r}")
|
standalone_inference/overlay_files/fastvideo/training/training_pipeline.py
ADDED
|
@@ -0,0 +1,1044 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
from dataclasses import asdict
|
| 3 |
+
from contextlib import AbstractContextManager, nullcontext
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
import tempfile
|
| 8 |
+
import time
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
from collections import deque
|
| 11 |
+
from collections.abc import Iterator
|
| 12 |
+
from typing import Any
|
| 13 |
+
from fastvideo.profiler import profile_region
|
| 14 |
+
import imageio
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
import torchvision
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from torch.utils.data import DataLoader
|
| 21 |
+
from torchdata.stateful_dataloader import StatefulDataLoader
|
| 22 |
+
from tqdm.auto import tqdm
|
| 23 |
+
from diffusers import FlowMatchEulerDiscreteScheduler
|
| 24 |
+
|
| 25 |
+
import fastvideo.envs as envs
|
| 26 |
+
try:
|
| 27 |
+
from fastvideo.attention.backends.video_sparse_attn import (VideoSparseAttentionMetadataBuilder)
|
| 28 |
+
from fastvideo.attention.backends.vmoba import VideoMobaAttentionMetadataBuilder
|
| 29 |
+
except Exception:
|
| 30 |
+
pass
|
| 31 |
+
from fastvideo.configs.sample import SamplingParam
|
| 32 |
+
from fastvideo.dataset import build_parquet_map_style_dataloader
|
| 33 |
+
from fastvideo.dataset.dataloader.schema import pyarrow_schema_t2v
|
| 34 |
+
from fastvideo.dataset.validation_dataset import ValidationDataset
|
| 35 |
+
from fastvideo.distributed import (cleanup_dist_env_and_memory, get_local_torch_device, get_sp_group, get_world_group)
|
| 36 |
+
from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
|
| 37 |
+
from fastvideo.forward_context import set_forward_context
|
| 38 |
+
from fastvideo.logger import init_logger
|
| 39 |
+
from fastvideo.attention.selector import global_force_attn_backend_context_manager
|
| 40 |
+
from fastvideo.pipelines import (ComposedPipelineBase, ForwardBatch, LoRAPipeline, TrainingBatch)
|
| 41 |
+
from fastvideo.platforms import AttentionBackendEnum, current_platform
|
| 42 |
+
from fastvideo.training.activation_checkpoint import (apply_activation_checkpointing)
|
| 43 |
+
from fastvideo.training.trackers import (DummyTracker, TrackerType, initialize_trackers, Trackers)
|
| 44 |
+
from fastvideo.training.training_utils import (clip_grad_norm_while_handling_failing_dtensor_cases,
|
| 45 |
+
compute_density_for_timestep_sampling, count_trainable, get_scheduler,
|
| 46 |
+
get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint,
|
| 47 |
+
swap_fp4_linear, traverse_swap_module)
|
| 48 |
+
from fastvideo.utils import (is_vmoba_available, is_vsa_available, set_random_seed, shallow_asdict)
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
vsa_available = is_vsa_available()
|
| 52 |
+
vmoba_available = is_vmoba_available()
|
| 53 |
+
except Exception:
|
| 54 |
+
vsa_available = False
|
| 55 |
+
vmoba_available = False
|
| 56 |
+
|
| 57 |
+
logger = init_logger(__name__)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TrainingPipeline(LoRAPipeline, ABC):
|
| 61 |
+
"""
|
| 62 |
+
A pipeline for training a model. All training pipelines should inherit from this class.
|
| 63 |
+
All reusable components and code should be implemented in this class.
|
| 64 |
+
"""
|
| 65 |
+
_required_config_modules = ["scheduler", "transformer"]
|
| 66 |
+
validation_pipeline: ComposedPipelineBase
|
| 67 |
+
train_dataloader: StatefulDataLoader
|
| 68 |
+
train_loader_iter: Iterator[dict[str, Any]]
|
| 69 |
+
current_epoch: int = 0
|
| 70 |
+
train_transformer_2: bool = False
|
| 71 |
+
tracker: TrackerType
|
| 72 |
+
|
| 73 |
+
def __init__(self,
|
| 74 |
+
model_path: str,
|
| 75 |
+
fastvideo_args: TrainingArgs,
|
| 76 |
+
required_config_modules: list[str] | None = None,
|
| 77 |
+
loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
|
| 78 |
+
fastvideo_args.inference_mode = False
|
| 79 |
+
self.lora_training = fastvideo_args.lora_training
|
| 80 |
+
if self.lora_training and fastvideo_args.lora_rank is None:
|
| 81 |
+
raise ValueError("lora rank must be set when using lora training")
|
| 82 |
+
|
| 83 |
+
set_random_seed(fastvideo_args.seed) # for lora param init
|
| 84 |
+
super().__init__(model_path, fastvideo_args, required_config_modules, loaded_modules) # type: ignore
|
| 85 |
+
self.tracker = DummyTracker()
|
| 86 |
+
|
| 87 |
+
def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
|
| 88 |
+
raise RuntimeError("create_pipeline_stages should not be called for training pipeline")
|
| 89 |
+
|
| 90 |
+
@staticmethod
|
| 91 |
+
def _should_force_generator_attn_qat_train(fastvideo_args: FastVideoArgs) -> bool:
|
| 92 |
+
if not isinstance(fastvideo_args, TrainingArgs):
|
| 93 |
+
return False
|
| 94 |
+
return (fastvideo_args.generator_4bit_attn or envs.FASTVIDEO_ATTENTION_BACKEND == "ATTN_QAT_TRAIN")
|
| 95 |
+
|
| 96 |
+
def load_modules(self,
|
| 97 |
+
fastvideo_args: FastVideoArgs,
|
| 98 |
+
loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]:
|
| 99 |
+
force_generator_qat = self._should_force_generator_attn_qat_train(fastvideo_args)
|
| 100 |
+
load_context: AbstractContextManager[None] = nullcontext()
|
| 101 |
+
if force_generator_qat:
|
| 102 |
+
logger.info("Forcing generator attention backend to ATTN_QAT_TRAIN during module loading")
|
| 103 |
+
load_context = global_force_attn_backend_context_manager(AttentionBackendEnum.ATTN_QAT_TRAIN)
|
| 104 |
+
|
| 105 |
+
with load_context:
|
| 106 |
+
return super().load_modules(fastvideo_args, loaded_modules)
|
| 107 |
+
|
| 108 |
+
def set_schemas(self) -> None:
|
| 109 |
+
self.train_dataset_schema = pyarrow_schema_t2v
|
| 110 |
+
|
| 111 |
+
    def initialize_training_pipeline(self, training_args: TrainingArgs):
        """Set up distributed state, models, optimizers, schedulers, the
        dataloader, and metric trackers for training.

        Must run after module loading; reads "transformer" (and optionally
        "transformer_2") plus "scheduler" from the loaded modules.
        """
        logger.info("Initializing training pipeline...")
        self.device = get_local_torch_device()
        self.training_args = training_args
        # Distributed topology: world group plus the sequence-parallel subgroup.
        world_group = get_world_group()
        self.world_size = world_group.world_size
        self.global_rank = world_group.rank
        self.sp_group = get_sp_group()
        self.rank_in_sp_group = self.sp_group.rank_in_group
        self.sp_world_size = self.sp_group.world_size
        self.local_rank = world_group.local_rank
        self.transformer = self.get_module("transformer")
        # transformer_2 is optional (MoE / boundary-split setups).
        self.transformer_2 = self.get_module("transformer_2", None)
        self.seed = training_args.seed
        self.set_schemas()

        # Set random seeds for deterministic training; offset by rank so each
        # rank draws distinct randomness.
        assert self.seed is not None, "seed must be set"
        set_random_seed(self.seed + self.global_rank)
        self.transformer.train()
        if training_args.enable_gradient_checkpointing_type is not None:
            self.transformer = apply_activation_checkpointing(
                self.transformer, checkpointing_type=training_args.enable_gradient_checkpointing_type)
            if self.transformer_2 is not None:
                self.transformer_2 = apply_activation_checkpointing(
                    self.transformer_2, checkpointing_type=training_args.enable_gradient_checkpointing_type)

        # Optionally swap nn.Linear layers for the FP4 forward path.
        if training_args.generator_4bit_linear:
            num_swaps = traverse_swap_module(self.transformer, swap_fn=swap_fp4_linear)
            logger.info("Swapped %s linear layers to the FP4 forward path in self.transformer", num_swaps)
        noise_scheduler = self.modules["scheduler"]
        self.set_trainable()
        # Only parameters left trainable by set_trainable() are optimized.
        params_to_optimize = self.transformer.parameters()
        params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
        # Parse betas from string format "beta1,beta2"
        betas_str = training_args.betas
        betas = tuple(float(x.strip()) for x in betas_str.split(","))

        self.optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=training_args.learning_rate,
            betas=betas,
            weight_decay=training_args.weight_decay,
            eps=1e-8,
        )

        self.init_steps = 0
        logger.info("optimizer: %s", self.optimizer)

        self.lr_scheduler = get_scheduler(
            training_args.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=training_args.lr_warmup_steps,
            num_training_steps=training_args.max_train_steps,
            num_cycles=training_args.lr_num_cycles,
            power=training_args.lr_power,
            min_lr_ratio=training_args.min_lr_ratio,
            last_epoch=self.init_steps - 1,
        )
        if self.transformer_2 is not None:
            # Ensure transformer_2 has trainable parameters before creating optimizer
            # NOTE(review): transformer_2 uses hard-coded betas=(0.9, 0.999)
            # while the primary optimizer parses betas from training_args —
            # presumably intentional, but worth confirming.
            params_to_optimize_2 = self.transformer_2.parameters()
            params_to_optimize_2 = list(filter(lambda p: p.requires_grad, params_to_optimize_2))
            self.optimizer_2 = torch.optim.AdamW(
                params_to_optimize_2,
                lr=training_args.learning_rate,
                betas=(0.9, 0.999),
                weight_decay=training_args.weight_decay,
                eps=1e-8,
            )
            self.lr_scheduler_2 = get_scheduler(
                training_args.lr_scheduler,
                optimizer=self.optimizer_2,
                num_warmup_steps=training_args.lr_warmup_steps,
                num_training_steps=training_args.max_train_steps,
                num_cycles=training_args.lr_num_cycles,
                power=training_args.lr_power,
                min_lr_ratio=training_args.min_lr_ratio,
                last_epoch=self.init_steps - 1,
            )

        self.train_dataset, self.train_dataloader = build_parquet_map_style_dataloader(
            training_args.data_path,
            training_args.train_batch_size,
            parquet_schema=self.train_dataset_schema,
            num_data_workers=training_args.dataloader_num_workers,
            cfg_rate=training_args.training_cfg_rate,
            drop_last=True,
            text_padding_length=training_args.pipeline_config.text_encoder_configs[0].arch_config.
            text_len,  # type: ignore[attr-defined]
            seed=self.seed)

        self.noise_scheduler = noise_scheduler
        # boundary_timestep splits the timestep range between the two
        # transformers when a boundary_ratio is configured.
        if self.training_args.boundary_ratio is not None:
            self.boundary_timestep = self.training_args.boundary_ratio * self.noise_scheduler.num_train_timesteps
        else:
            self.boundary_timestep = None

        logger.info("train_dataloader length: %s", len(self.train_dataloader))
        logger.info("train_sp_batch_size: %s", training_args.train_sp_batch_size)
        logger.info("gradient_accumulation_steps: %s", training_args.gradient_accumulation_steps)
        logger.info("sp_size: %s", training_args.sp_size)

        self.num_update_steps_per_epoch = math.ceil(
            len(self.train_dataloader) / training_args.gradient_accumulation_steps * training_args.sp_size /
            training_args.train_sp_batch_size)
        self.num_train_epochs = math.ceil(training_args.max_train_steps / self.num_update_steps_per_epoch)

        # TODO(will): is there a cleaner way to track epochs?
        self.current_epoch = 0

        # Tracker selection: fall back to wandb when a project name is set but
        # no tracker was chosen explicitly; only rank 0 logs.
        trackers = list(training_args.trackers)
        if not trackers and training_args.tracker_project_name:
            trackers.append(Trackers.WANDB.value)
        if self.global_rank != 0:
            trackers = []

        tracker_log_dir = training_args.output_dir or os.getcwd()
        if trackers:
            tracker_log_dir = os.path.join(tracker_log_dir, "tracker")

        tracker_config = asdict(training_args) if trackers else None
        tracker_run_name = training_args.wandb_run_name or None
        project = training_args.tracker_project_name or "fastvideo"
        self.tracker = initialize_trackers(
            trackers,
            experiment_name=project,
            config=tracker_config,
            log_dir=tracker_log_dir,
            run_name=tracker_run_name,
        )
+
    @abstractmethod
    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        """Build the pipeline used by _log_validation; subclass-specific."""
        raise NotImplementedError("Training pipelines must implement this method")
+
def _prepare_training(self, training_batch: TrainingBatch) -> TrainingBatch:
|
| 248 |
+
self.optimizer.zero_grad()
|
| 249 |
+
if self.transformer_2 is not None:
|
| 250 |
+
self.optimizer_2.zero_grad()
|
| 251 |
+
training_batch.total_loss = 0.0
|
| 252 |
+
return training_batch
|
| 253 |
+
|
| 254 |
+
def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch:
|
| 255 |
+
with self.tracker.timed("timing/get_next_batch"):
|
| 256 |
+
batch = next(self.train_loader_iter, None) # type: ignore
|
| 257 |
+
if batch is None:
|
| 258 |
+
self.current_epoch += 1
|
| 259 |
+
logger.info("Starting epoch %s", self.current_epoch)
|
| 260 |
+
# Reset iterator for next epoch
|
| 261 |
+
self.train_loader_iter = iter(self.train_dataloader)
|
| 262 |
+
# Get first batch of new epoch
|
| 263 |
+
batch = next(self.train_loader_iter)
|
| 264 |
+
|
| 265 |
+
latents = batch['vae_latent']
|
| 266 |
+
latents = latents[:, :, :self.training_args.num_latent_t]
|
| 267 |
+
encoder_hidden_states = batch['text_embedding']
|
| 268 |
+
encoder_attention_mask = batch['text_attention_mask']
|
| 269 |
+
infos = batch['info_list']
|
| 270 |
+
|
| 271 |
+
training_batch.latents = latents.to(
|
| 272 |
+
get_local_torch_device(),
|
| 273 |
+
dtype=torch.bfloat16,
|
| 274 |
+
non_blocking=True,
|
| 275 |
+
)
|
| 276 |
+
training_batch.encoder_hidden_states = (encoder_hidden_states.to(
|
| 277 |
+
get_local_torch_device(),
|
| 278 |
+
dtype=torch.bfloat16,
|
| 279 |
+
non_blocking=True,
|
| 280 |
+
))
|
| 281 |
+
training_batch.encoder_attention_mask = (encoder_attention_mask.to(
|
| 282 |
+
get_local_torch_device(),
|
| 283 |
+
dtype=torch.bfloat16,
|
| 284 |
+
non_blocking=True,
|
| 285 |
+
))
|
| 286 |
+
training_batch.infos = infos
|
| 287 |
+
|
| 288 |
+
return training_batch
|
| 289 |
+
|
| 290 |
+
def _normalize_dit_input(self, training_batch: TrainingBatch) -> TrainingBatch:
|
| 291 |
+
# TODO(will): support other models
|
| 292 |
+
with self.tracker.timed("timing/normalize_input"):
|
| 293 |
+
training_batch.latents = normalize_dit_input(
|
| 294 |
+
'wan',
|
| 295 |
+
training_batch.latents,
|
| 296 |
+
self.get_module("vae"),
|
| 297 |
+
)
|
| 298 |
+
return training_batch
|
| 299 |
+
|
| 300 |
+
    def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Sample noise/timesteps and build the noisy flow-matching input.

        Fills noisy_model_input, timesteps, sigmas, noise, and
        raw_latent_shape on the batch. Under sequence parallelism the
        timesteps and noise are broadcast from SP rank 0 so all SP ranks see
        identical values.
        """
        assert self.training_args is not None, "training_args must be set"
        with self.tracker.timed("timing/prepare_dit_inputs"):
            latents = training_batch.latents
            batch_size = latents.shape[0]
            # CUDA generator keeps noise deterministic per rank.
            noise = torch.randn(latents.shape,
                                generator=self.noise_gen_cuda,
                                device=latents.device,
                                dtype=latents.dtype)
            timesteps = self._sample_timesteps(batch_size, latents.device)

            if self.training_args.sp_size > 1:
                # Make sure that the timesteps are the same across all sp processes.
                sp_group = get_sp_group()
                sp_group.broadcast(timesteps, src=0)
                sp_group.broadcast(noise, src=0)
            sigmas = get_sigmas(
                self.noise_scheduler,
                latents.device,
                timesteps,
                n_dim=latents.ndim,
                dtype=latents.dtype,
            )
            # Flow-matching interpolation between clean latents and noise.
            noisy_model_input = (1.0 - sigmas) * training_batch.latents + sigmas * noise

            training_batch.noisy_model_input = noisy_model_input
            training_batch.timesteps = timesteps
            training_batch.sigmas = sigmas
            training_batch.noise = noise
            # Shape recorded before any temporal sharding downstream.
            training_batch.raw_latent_shape = training_batch.latents.shape

        return training_batch
+
    def _sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Sample scheduler timesteps and pick which transformer trains this step.

        Side effect: sets self.train_transformer_2. The per-rank random draw is
        overridden by a broadcast from rank 0 so all ranks agree.
        """
        # Determine which model to train based on the boundary timestep
        if (self.transformer_2 is not None and self.boundary_timestep is not None
                and torch.rand(1, generator=self.noise_random_generator).item() <= self.training_args.boundary_ratio):
            self.train_transformer_2 = True
        else:
            self.train_transformer_2 = False

        # Broadcast the decision to all processes
        decision = torch.tensor(1.0 if self.train_transformer_2 else 0.0, device=self.device)
        dist.broadcast(decision, src=0)
        self.train_transformer_2 = decision.item() == 1.0

        # Sample u from the appropriate range
        u = compute_density_for_timestep_sampling(
            weighting_scheme=self.training_args.weighting_scheme,
            batch_size=batch_size,
            generator=self.noise_random_generator,
            logit_mean=self.training_args.logit_mean,
            logit_std=self.training_args.logit_std,
            mode_scale=self.training_args.mode_scale,
        )

        boundary_ratio = self.training_args.boundary_ratio
        if self.train_transformer_2:
            # Map u into the high-noise band reserved for transformer_2.
            u = (1 - boundary_ratio) + u * boundary_ratio  # min: 1 - boundary_ratio, max: 1
        # elif self.transformer_2 is not None:
        #     u = u * (1 - boundary_ratio)  # min: 0, max: 1 - boundary_ratio
        # else:  # patch for now to align with non-MoE timestep logic
        #     pass

        # Convert the continuous density sample into scheduler timestep indices.
        indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
        return self.noise_scheduler.timesteps[indices].to(device=device)
+
def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
|
| 368 |
+
latents_shape = training_batch.raw_latent_shape
|
| 369 |
+
patch_size = self.training_args.pipeline_config.dit_config.patch_size
|
| 370 |
+
current_vsa_sparsity = training_batch.current_vsa_sparsity
|
| 371 |
+
assert latents_shape is not None
|
| 372 |
+
assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}"
|
| 373 |
+
assert training_batch.timesteps is not None
|
| 374 |
+
if envs.FASTVIDEO_ATTENTION_BACKEND in (
|
| 375 |
+
"VIDEO_SPARSE_ATTN",
|
| 376 |
+
"SPARSE_FP4_ATTN",
|
| 377 |
+
"SPARSE_FP4_OURS_P_ATTN",
|
| 378 |
+
):
|
| 379 |
+
if not vsa_available:
|
| 380 |
+
raise ImportError("FASTVIDEO_ATTENTION_BACKEND is set to VIDEO_SPARSE_ATTN, "
|
| 381 |
+
"but fastvideo_kernel is not correctly installed or detected. "
|
| 382 |
+
"Please ensure fastvideo-kernel is installed.")
|
| 383 |
+
training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder( # type: ignore
|
| 384 |
+
).build( # type: ignore
|
| 385 |
+
raw_latent_shape=latents_shape[2:5],
|
| 386 |
+
current_timestep=training_batch.timesteps,
|
| 387 |
+
patch_size=patch_size,
|
| 388 |
+
VSA_sparsity=current_vsa_sparsity,
|
| 389 |
+
device=get_local_torch_device())
|
| 390 |
+
elif envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
|
| 391 |
+
if not vmoba_available:
|
| 392 |
+
raise ImportError("FASTVIDEO_ATTENTION_BACKEND is set to VMOBA_ATTN, "
|
| 393 |
+
"but fastvideo_kernel (or flash_attn>=2.7.4) is not correctly installed.")
|
| 394 |
+
moba_params = self.training_args.moba_config.copy()
|
| 395 |
+
moba_params.update({
|
| 396 |
+
"current_timestep": training_batch.timesteps,
|
| 397 |
+
"raw_latent_shape": latents_shape[2:5],
|
| 398 |
+
"patch_size": self.training_args.pipeline_config.dit_config.patch_size,
|
| 399 |
+
"device": get_local_torch_device(),
|
| 400 |
+
})
|
| 401 |
+
training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(**moba_params)
|
| 402 |
+
else:
|
| 403 |
+
training_batch.attn_metadata = None
|
| 404 |
+
|
| 405 |
+
return training_batch
|
| 406 |
+
|
| 407 |
+
def _build_input_kwargs(self, training_batch: TrainingBatch) -> TrainingBatch:
|
| 408 |
+
training_batch.input_kwargs = {
|
| 409 |
+
"hidden_states": training_batch.noisy_model_input,
|
| 410 |
+
"encoder_hidden_states": training_batch.encoder_hidden_states,
|
| 411 |
+
"timestep": training_batch.timesteps.to(get_local_torch_device(), dtype=torch.bfloat16),
|
| 412 |
+
"encoder_attention_mask": training_batch.encoder_attention_mask,
|
| 413 |
+
"return_dict": False,
|
| 414 |
+
}
|
| 415 |
+
return training_batch
|
| 416 |
+
|
| 417 |
+
    def _transformer_forward_and_compute_loss(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Run the transformer forward pass, compute the MSE loss, and backprop.

        The loss is divided by gradient_accumulation_steps before backward so
        the accumulated gradient matches a single large batch. The detached
        loss is all-reduced (AVG) across ranks and accumulated on-device into
        training_batch.total_loss.
        """
        # Sanity-check that attention metadata presence matches the backend.
        if vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND in (
                "VIDEO_SPARSE_ATTN",
                "SPARSE_FP4_ATTN",
                "SPARSE_FP4_OURS_P_ATTN",
        ) or vmoba_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
            assert training_batch.attn_metadata is not None
        else:
            assert training_batch.attn_metadata is None
        input_kwargs = training_batch.input_kwargs

        # if 'hunyuan' in self.training_args.model_type:
        #     input_kwargs["guidance"] = torch.tensor(
        #         [1000.0],
        #         device=training_batch.noisy_model_input.device,
        #         dtype=torch.bfloat16)
        # Route to transformer_2 when this step trains the high-noise expert.
        current_model = self.transformer_2 if self.train_transformer_2 else self.transformer

        with self.tracker.timed("timing/forward_backward"), set_forward_context(
                current_timestep=training_batch.current_timestep, attn_metadata=training_batch.attn_metadata):
            model_pred = current_model(**input_kwargs)
            if self.training_args.precondition_outputs:
                assert training_batch.sigmas is not None
                # Precondition: convert velocity prediction back to a latent estimate.
                model_pred = training_batch.noisy_model_input - model_pred * training_batch.sigmas
            assert training_batch.latents is not None
            assert training_batch.noise is not None
            # Target is clean latents (preconditioned) or flow velocity noise - latents.
            target = training_batch.latents if self.training_args.precondition_outputs else training_batch.noise - training_batch.latents

            # make sure no implicit broadcasting happens
            assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}"

            # Scale by accumulation steps so summed gradients are averaged.
            loss = (torch.mean(
                (model_pred.float() - target.float())**2) / self.training_args.gradient_accumulation_steps)

            loss.backward()

        avg_loss = loss.detach().clone()

        # Reduce across ranks without forcing a CPU sync
        with self.tracker.timed("timing/reduce_loss"):
            world_group = get_world_group()
            avg_loss = world_group.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        # Accumulate on GPU; materialize to CPU only once after
        # all gradient-accumulation iterations (see train_one_step).
        training_batch.total_loss += avg_loss

        return training_batch
+
def _clip_grad_norm(self, training_batch: TrainingBatch) -> TrainingBatch:
|
| 466 |
+
max_grad_norm = self.training_args.max_grad_norm
|
| 467 |
+
|
| 468 |
+
# TODO(will): perhaps move this into transformer api so that we can do
|
| 469 |
+
# the following:
|
| 470 |
+
# grad_norm = transformer.clip_grad_norm_(max_grad_norm)
|
| 471 |
+
if max_grad_norm is not None:
|
| 472 |
+
with self.tracker.timed("timing/clip_grad_norm"):
|
| 473 |
+
# Only clip gradients for the model that is currently training
|
| 474 |
+
if self.train_transformer_2 and self.transformer_2 is not None:
|
| 475 |
+
model_parts = [self.transformer_2]
|
| 476 |
+
else:
|
| 477 |
+
model_parts = [self.transformer]
|
| 478 |
+
|
| 479 |
+
grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases(
|
| 480 |
+
[p for m in model_parts for p in m.parameters()],
|
| 481 |
+
max_grad_norm,
|
| 482 |
+
foreach=None,
|
| 483 |
+
)
|
| 484 |
+
assert grad_norm is not float('nan') or grad_norm is not float('inf')
|
| 485 |
+
grad_norm = grad_norm.item() if grad_norm is not None else 0.0
|
| 486 |
+
else:
|
| 487 |
+
grad_norm = 0.0
|
| 488 |
+
training_batch.grad_norm = grad_norm
|
| 489 |
+
return training_batch
|
| 490 |
+
|
| 491 |
+
@profile_region("profiler_region_training_train_one_step")
|
| 492 |
+
def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
|
| 493 |
+
training_batch = self._prepare_training(training_batch)
|
| 494 |
+
|
| 495 |
+
for _ in range(self.training_args.gradient_accumulation_steps):
|
| 496 |
+
training_batch = self._get_next_batch(training_batch)
|
| 497 |
+
|
| 498 |
+
# Normalize DIT input
|
| 499 |
+
training_batch = self._normalize_dit_input(training_batch)
|
| 500 |
+
# Create noisy model input
|
| 501 |
+
training_batch = self._prepare_dit_inputs(training_batch)
|
| 502 |
+
assert training_batch.latents is not None
|
| 503 |
+
assert training_batch.noisy_model_input is not None
|
| 504 |
+
assert training_batch.noise is not None
|
| 505 |
+
|
| 506 |
+
# old sharding code, need to shard latents and noise but not input
|
| 507 |
+
# Shard latents across sp groups
|
| 508 |
+
training_batch.latents = training_batch.latents[:, :, :self.training_args.num_latent_t]
|
| 509 |
+
# shard noisy_model_input to match
|
| 510 |
+
training_batch.noisy_model_input = training_batch.noisy_model_input[:, :, :self.training_args.num_latent_t]
|
| 511 |
+
# shard noise to match latents
|
| 512 |
+
training_batch.noise = training_batch.noise[:, :, :self.training_args.num_latent_t]
|
| 513 |
+
|
| 514 |
+
training_batch = self._build_attention_metadata(training_batch)
|
| 515 |
+
training_batch = self._build_input_kwargs(training_batch)
|
| 516 |
+
|
| 517 |
+
training_batch = self._transformer_forward_and_compute_loss(training_batch)
|
| 518 |
+
|
| 519 |
+
training_batch = self._clip_grad_norm(training_batch)
|
| 520 |
+
|
| 521 |
+
# Only step the optimizer and scheduler for the model that is currently training
|
| 522 |
+
with self.tracker.timed("timing/optimizer_step"):
|
| 523 |
+
if self.train_transformer_2 and self.transformer_2 is not None:
|
| 524 |
+
self.optimizer_2.step()
|
| 525 |
+
self.lr_scheduler_2.step()
|
| 526 |
+
else:
|
| 527 |
+
self.optimizer.step()
|
| 528 |
+
self.lr_scheduler.step()
|
| 529 |
+
|
| 530 |
+
return training_batch
|
| 531 |
+
|
| 532 |
+
def _compute_current_sparsity(self, step: int) -> float:
|
| 533 |
+
"""Compute the VSA sparsity for a given step using the decay schedule."""
|
| 534 |
+
vsa_sparsity = self.training_args.VSA_sparsity
|
| 535 |
+
vsa_decay_rate = self.training_args.VSA_decay_rate
|
| 536 |
+
vsa_decay_interval = self.training_args.VSA_decay_interval_steps
|
| 537 |
+
vsa_init = getattr(self.training_args, 'VSA_init_sparsity', 0.0)
|
| 538 |
+
vsa_warmup = getattr(self.training_args, 'VSA_warmup_steps', 0)
|
| 539 |
+
if step <= vsa_warmup:
|
| 540 |
+
return vsa_init
|
| 541 |
+
ramp_step = step - vsa_warmup
|
| 542 |
+
max_times = int((vsa_sparsity - vsa_init) / vsa_decay_rate) if vsa_decay_rate > 0 else 0
|
| 543 |
+
times = min(ramp_step // vsa_decay_interval, max_times)
|
| 544 |
+
return vsa_init + times * vsa_decay_rate
|
| 545 |
+
|
| 546 |
+
def _resolve_checkpoint_path(self, path: str) -> str | None:
|
| 547 |
+
"""Resolve 'latest' to the most recent checkpoint in output_dir."""
|
| 548 |
+
import glob
|
| 549 |
+
if path == "latest":
|
| 550 |
+
output_dir = self.training_args.output_dir
|
| 551 |
+
ckpt_dirs = sorted(
|
| 552 |
+
glob.glob(os.path.join(output_dir, "checkpoint-*")),
|
| 553 |
+
key=lambda d: int(d.split("-")[-1]) if d.split("-")[-1].isdigit() else 0,
|
| 554 |
+
)
|
| 555 |
+
if ckpt_dirs:
|
| 556 |
+
latest = ckpt_dirs[-1]
|
| 557 |
+
logger.info("Auto-resolved 'latest' to %s", latest)
|
| 558 |
+
return latest
|
| 559 |
+
logger.info("No checkpoints found in %s, starting from scratch", output_dir)
|
| 560 |
+
return None
|
| 561 |
+
return path
|
| 562 |
+
|
| 563 |
+
    def _resume_from_checkpoint(self) -> None:
        """Restore training state from resume_from_checkpoint.

        Tries the full distributed checkpoint first; if that fails but a
        safetensors weights file exists, resumes at the step encoded in the
        checkpoint directory name (weights presumably already loaded elsewhere
        — TODO confirm); otherwise starts from step 0.
        """
        ckpt_path = self._resolve_checkpoint_path(self.training_args.resume_from_checkpoint)
        if ckpt_path is None:
            logger.info("No checkpoint to resume from, starting from step 0")
            return

        safetensors_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model.safetensors")
        # Step number is encoded in the directory name: ".../checkpoint-<step>".
        step = int(os.path.basename(os.path.normpath(ckpt_path)).split('-')[-1])

        resumed_step = load_checkpoint(self.transformer, self.global_rank, ckpt_path,
                                       self.optimizer, self.train_dataloader,
                                       self.lr_scheduler, self.noise_random_generator)
        if resumed_step > 0 or step == 0:
            self.init_steps = resumed_step
            logger.info("Successfully resumed full training state from step %s", resumed_step)
            return

        if os.path.exists(safetensors_path):
            self.init_steps = step
            logger.warning("Distributed checkpoint resume failed; falling back to safetensors weights at step %s",
                           step)
            return

        logger.warning("No usable checkpoint state found at %s; starting from step 0", ckpt_path)
        self.init_steps = 0
+
@profile_region("profiler_region_training_train")
|
| 590 |
+
def train(self) -> None:
|
| 591 |
+
assert self.seed is not None, "seed must be set"
|
| 592 |
+
assert self.training_args is not None, "training_args must be set"
|
| 593 |
+
set_random_seed(self.seed + self.global_rank)
|
| 594 |
+
logger.info('rank: %s: start training', self.global_rank, local_main_process_only=False)
|
| 595 |
+
if not self.post_init_called:
|
| 596 |
+
self.post_init()
|
| 597 |
+
num_trainable_params = count_trainable(self.transformer)
|
| 598 |
+
logger.info("Starting training with %s B trainable parameters", round(num_trainable_params / 1e9, 3))
|
| 599 |
+
|
| 600 |
+
if getattr(self, "transformer_2", None) is not None:
|
| 601 |
+
num_trainable_params = count_trainable(self.transformer_2)
|
| 602 |
+
logger.info("Transformer 2: Starting training with %s B trainable parameters",
|
| 603 |
+
round(num_trainable_params / 1e9, 3))
|
| 604 |
+
|
| 605 |
+
# Set random seeds for deterministic training
|
| 606 |
+
self.noise_random_generator = torch.Generator(device="cpu").manual_seed(self.seed + self.global_rank)
|
| 607 |
+
self.noise_gen_cuda = torch.Generator(device=current_platform.device_name).manual_seed(self.seed +
|
| 608 |
+
self.global_rank)
|
| 609 |
+
self.validation_random_generator = torch.Generator(device="cpu").manual_seed(self.seed + self.global_rank)
|
| 610 |
+
logger.info("Initialized random seeds with seed: %s", self.seed + self.global_rank)
|
| 611 |
+
self.noise_scheduler = FlowMatchEulerDiscreteScheduler()
|
| 612 |
+
|
| 613 |
+
if self.training_args.resume_from_checkpoint:
|
| 614 |
+
self._resume_from_checkpoint()
|
| 615 |
+
|
| 616 |
+
self.train_loader_iter = iter(self.train_dataloader)
|
| 617 |
+
|
| 618 |
+
step_times: deque[float] = deque(maxlen=100)
|
| 619 |
+
|
| 620 |
+
self._log_training_info()
|
| 621 |
+
|
| 622 |
+
# Validation at init uses the sparsity corresponding to init_steps
|
| 623 |
+
saved_sparsity = self.training_args.VSA_sparsity
|
| 624 |
+
self.training_args.VSA_sparsity = self._compute_current_sparsity(self.init_steps)
|
| 625 |
+
self._log_validation(self.transformer, self.training_args, self.init_steps)
|
| 626 |
+
self.training_args.VSA_sparsity = saved_sparsity
|
| 627 |
+
|
| 628 |
+
# Train!
|
| 629 |
+
progress_bar = tqdm(
|
| 630 |
+
range(0, self.training_args.max_train_steps),
|
| 631 |
+
initial=self.init_steps,
|
| 632 |
+
desc="Steps",
|
| 633 |
+
# Only show the progress bar once on each machine.
|
| 634 |
+
disable=self.local_rank > 0,
|
| 635 |
+
)
|
| 636 |
+
for step in range(self.init_steps + 1, self.training_args.max_train_steps + 1):
|
| 637 |
+
start_time = time.perf_counter()
|
| 638 |
+
if vsa_available:
|
| 639 |
+
vsa_sparsity = self.training_args.VSA_sparsity
|
| 640 |
+
vsa_decay_rate = self.training_args.VSA_decay_rate
|
| 641 |
+
vsa_decay_interval_steps = self.training_args.VSA_decay_interval_steps
|
| 642 |
+
vsa_init_sparsity = getattr(self.training_args, 'VSA_init_sparsity', 0.0)
|
| 643 |
+
vsa_warmup_steps = getattr(self.training_args, 'VSA_warmup_steps', 0)
|
| 644 |
+
if step <= vsa_warmup_steps:
|
| 645 |
+
current_vsa_sparsity = vsa_init_sparsity
|
| 646 |
+
else:
|
| 647 |
+
ramp_step = step - vsa_warmup_steps
|
| 648 |
+
max_decay_times = int((vsa_sparsity - vsa_init_sparsity) / vsa_decay_rate)
|
| 649 |
+
current_decay_times = min(ramp_step // vsa_decay_interval_steps, max_decay_times)
|
| 650 |
+
current_vsa_sparsity = vsa_init_sparsity + current_decay_times * vsa_decay_rate
|
| 651 |
+
elif vmoba_available:
|
| 652 |
+
#TODO: add vmoba sparsity scheduling here
|
| 653 |
+
current_vsa_sparsity = 0.0
|
| 654 |
+
else:
|
| 655 |
+
current_vsa_sparsity = 0.0
|
| 656 |
+
|
| 657 |
+
training_batch = TrainingBatch()
|
| 658 |
+
training_batch.current_timestep = step
|
| 659 |
+
training_batch.current_vsa_sparsity = current_vsa_sparsity
|
| 660 |
+
training_batch = self.train_one_step(training_batch)
|
| 661 |
+
|
| 662 |
+
loss = float(training_batch.total_loss)
|
| 663 |
+
grad_norm = training_batch.grad_norm
|
| 664 |
+
|
| 665 |
+
step_time = time.perf_counter() - start_time
|
| 666 |
+
step_times.append(step_time)
|
| 667 |
+
avg_step_time = sum(step_times) / len(step_times)
|
| 668 |
+
|
| 669 |
+
progress_bar.set_postfix({
|
| 670 |
+
"loss": f"{loss:.4f}",
|
| 671 |
+
"step_time": f"{step_time:.2f}s",
|
| 672 |
+
"grad_norm": grad_norm,
|
| 673 |
+
})
|
| 674 |
+
progress_bar.update(1)
|
| 675 |
+
if self.global_rank == 0:
|
| 676 |
+
metrics = {
|
| 677 |
+
"train_loss": loss,
|
| 678 |
+
"learning_rate": self.lr_scheduler.get_last_lr()[0],
|
| 679 |
+
"step_time": step_time,
|
| 680 |
+
"avg_step_time": avg_step_time,
|
| 681 |
+
"grad_norm": grad_norm,
|
| 682 |
+
"vsa_sparsity": current_vsa_sparsity,
|
| 683 |
+
}
|
| 684 |
+
try:
|
| 685 |
+
assert training_batch.raw_latent_shape is not None
|
| 686 |
+
metrics["batch_size"] = int(training_batch.raw_latent_shape[0])
|
| 687 |
+
|
| 688 |
+
patch_size = self.training_args.pipeline_config.dit_config.patch_size
|
| 689 |
+
assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}"
|
| 690 |
+
patch_t, patch_h, patch_w = patch_size
|
| 691 |
+
seq_len = (training_batch.raw_latent_shape[2] // patch_t) * (
|
| 692 |
+
training_batch.raw_latent_shape[3] // patch_h) * (training_batch.raw_latent_shape[4] // patch_w)
|
| 693 |
+
if training_batch.encoder_hidden_states is not None:
|
| 694 |
+
context_len = int(training_batch.encoder_hidden_states.shape[1])
|
| 695 |
+
else:
|
| 696 |
+
context_len = 0
|
| 697 |
+
|
| 698 |
+
metrics["dit_seq_len"] = int(seq_len)
|
| 699 |
+
metrics["context_len"] = context_len
|
| 700 |
+
|
| 701 |
+
arch_config = self.training_args.pipeline_config.dit_config.arch_config
|
| 702 |
+
|
| 703 |
+
metrics["hidden_dim"] = arch_config.hidden_size
|
| 704 |
+
metrics["num_layers"] = arch_config.num_layers
|
| 705 |
+
metrics["ffn_dim"] = arch_config.ffn_dim
|
| 706 |
+
except Exception:
|
| 707 |
+
pass
|
| 708 |
+
|
| 709 |
+
self.tracker.log(metrics, step)
|
| 710 |
+
if step % self.training_args.training_state_checkpointing_steps == 0:
|
| 711 |
+
with self.profiler_controller.region("profiler_region_training_save_checkpoint"):
|
| 712 |
+
save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir, step,
|
| 713 |
+
self.optimizer, self.train_dataloader, self.lr_scheduler,
|
| 714 |
+
self.noise_random_generator,
|
| 715 |
+
self.training_args.checkpoints_total_limit)
|
| 716 |
+
self.transformer.train()
|
| 717 |
+
self.sp_group.barrier()
|
| 718 |
+
|
| 719 |
+
if self.training_args.log_visualization and step % self.training_args.visualization_steps == 0:
|
| 720 |
+
self.visualize_intermediate_latents(training_batch, self.training_args, step)
|
| 721 |
+
|
| 722 |
+
if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
|
| 723 |
+
with self.profiler_controller.region("profiler_region_training_validation"):
|
| 724 |
+
saved_sparsity = self.training_args.VSA_sparsity
|
| 725 |
+
self.training_args.VSA_sparsity = current_vsa_sparsity
|
| 726 |
+
self._log_validation(self.transformer, self.training_args, step)
|
| 727 |
+
self.training_args.VSA_sparsity = saved_sparsity
|
| 728 |
+
gpu_memory_usage = current_platform.get_torch_device().memory_allocated() / 1024**2
|
| 729 |
+
trainable_params = round(count_trainable(self.transformer) / 1e9, 3)
|
| 730 |
+
logger.info("GPU memory usage after validation: %s MB, trainable params: %sB", gpu_memory_usage,
|
| 731 |
+
trainable_params)
|
| 732 |
+
|
| 733 |
+
self.tracker.finish()
|
| 734 |
+
save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir,
|
| 735 |
+
self.training_args.max_train_steps, self.optimizer, self.train_dataloader, self.lr_scheduler,
|
| 736 |
+
self.noise_random_generator, self.training_args.checkpoints_total_limit)
|
| 737 |
+
|
| 738 |
+
if envs.FASTVIDEO_TORCH_PROFILER_DIR:
|
| 739 |
+
logger.info("Stopping profiler...")
|
| 740 |
+
self.profiler_controller.stop()
|
| 741 |
+
logger.info("Profiler stopped.")
|
| 742 |
+
|
| 743 |
+
if get_sp_group():
|
| 744 |
+
cleanup_dist_env_and_memory()
|
| 745 |
+
|
| 746 |
+
def _log_training_info(self) -> None:
|
| 747 |
+
assert self.training_args is not None, "training_args must be set"
|
| 748 |
+
total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps /
|
| 749 |
+
self.training_args.sp_size * self.training_args.train_sp_batch_size)
|
| 750 |
+
logger.info("***** Running training *****")
|
| 751 |
+
logger.info(" Num examples = %s", len(self.train_dataset))
|
| 752 |
+
logger.info(" Dataloader size = %s", len(self.train_dataloader))
|
| 753 |
+
logger.info(" Num Epochs = %s", self.num_train_epochs)
|
| 754 |
+
logger.info(" Resume training from step %s", self.init_steps) # type: ignore
|
| 755 |
+
logger.info(" Instantaneous batch size per device = %s", self.training_args.train_batch_size)
|
| 756 |
+
logger.info(" Total train batch size (w. data & sequence parallel, accumulation) = %s", total_batch_size)
|
| 757 |
+
logger.info(" Gradient Accumulation steps = %s", self.training_args.gradient_accumulation_steps)
|
| 758 |
+
logger.info(" Total optimization steps = %s", self.training_args.max_train_steps)
|
| 759 |
+
logger.info(" Total training parameters per FSDP shard = %s B",
|
| 760 |
+
round(count_trainable(self.transformer) / 1e9, 3))
|
| 761 |
+
# print dtype
|
| 762 |
+
logger.info(" Master weight dtype: %s", self.transformer.parameters().__next__().dtype)
|
| 763 |
+
|
| 764 |
+
gpu_memory_usage = current_platform.get_torch_device().memory_allocated() / 1024**2
|
| 765 |
+
logger.info("GPU memory usage before train_one_step: %s MB", gpu_memory_usage)
|
| 766 |
+
logger.info("VSA validation sparsity: %s", self.training_args.VSA_sparsity)
|
| 767 |
+
|
| 768 |
+
    def _prepare_validation_batch(self, sampling_param: SamplingParam, training_args: TrainingArgs,
                                  validation_batch: dict[str, Any], num_inference_steps: int) -> ForwardBatch:
        """Fill *sampling_param* from one validation sample and wrap it in a ForwardBatch.

        Mutates ``sampling_param`` in place (prompt, resolution, inference
        steps, seed, num_frames) and returns a freshly constructed
        ``ForwardBatch`` carrying the token count and VSA sparsity.
        """
        sampling_param.prompt = validation_batch['prompt']
        sampling_param.height = training_args.num_height
        sampling_param.width = training_args.num_width
        sampling_param.num_inference_steps = num_inference_steps
        sampling_param.data_type = "video"
        # Zero/None guidance scale means "keep the pretrained default".
        if training_args.validation_guidance_scale:
            sampling_param.guidance_scale = float(training_args.validation_guidance_scale)
        assert self.seed is not None
        sampling_param.seed = self.seed

        # NOTE(review): n_tokens is derived from the *current*
        # sampling_param.num_frames (the pretrained default), before
        # num_frames is recomputed below from num_latent_t — confirm this
        # ordering is intentional.
        latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8]
        n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
        temporal_compression_factor = training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
        # Invert the VAE temporal compression to get pixel-space frame count.
        num_frames = (training_args.num_latent_t - 1) * temporal_compression_factor + 1
        sampling_param.num_frames = num_frames
        batch = ForwardBatch(
            **shallow_asdict(sampling_param),
            latents=None,
            generator=self.validation_random_generator,
            n_tokens=n_tokens,
            eta=0.0,
            VSA_sparsity=training_args.VSA_sparsity,
        )

        return batch
|
| 795 |
+
|
| 796 |
+
    @torch.no_grad()
    def _log_validation(self, transformer, training_args, global_step) -> None:
        """
        Generate a validation video and log it to the configured tracker to check the quality during training.

        For each configured sampling-step count, every sequence-parallel (sp)
        group runs inference on the validation prompts; sp-group leaders send
        their frames/captions/audio to global rank 0, which writes MP4 files
        (muxing audio when present) and logs them via ``self.tracker``.
        """
        # Flip the shared args into inference mode; restored at the end.
        training_args.inference_mode = True
        training_args.dit_cpu_offload = False
        if not training_args.log_validation:
            return
        if self.validation_pipeline is None:
            raise ValueError("Validation pipeline is not set")

        logger.info("Starting validation")

        # Create sampling parameters if not provided
        sampling_param = SamplingParam.from_pretrained(training_args.model_path)

        # Prepare validation prompts
        logger.info('rank: %s: fastvideo_args.validation_dataset_file: %s',
                    self.global_rank,
                    training_args.validation_dataset_file,
                    local_main_process_only=False)
        validation_dataset = ValidationDataset(training_args.validation_dataset_file)
        # batch_size=None: the dataset yields ready-made samples one at a time.
        validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0)

        self.transformer.eval()
        if getattr(self, "transformer_2", None) is not None:
            self.transformer_2.eval()

        # "50,100" -> [50, 100]; non-positive entries are dropped.
        validation_steps = training_args.validation_sampling_steps.split(",")
        validation_steps = [int(step) for step in validation_steps]
        validation_steps = [step for step in validation_steps if step > 0]
        # Log validation results for this step
        world_group = get_world_group()
        num_sp_groups = world_group.world_size // self.sp_group.world_size
        # Opt-in env toggle: each rank generates only the first prompt.
        one_prompt_per_rank = os.environ.get(
            "FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK",
            "",
        ).lower() in {"1", "true", "yes", "on"}

        # Process each validation prompt for each validation step
        for num_inference_steps in validation_steps:
            logger.info("rank: %s: num_inference_steps: %s",
                        self.global_rank,
                        num_inference_steps,
                        local_main_process_only=False)
            step_videos: list[np.ndarray] = []
            step_captions: list[str] = []

            step_audio: list[np.ndarray | None] = []
            step_sample_rates: list[int | None] = []

            for prompt_idx, validation_batch in enumerate(validation_dataloader):
                if one_prompt_per_rank and prompt_idx > 0:
                    continue

                batch = self._prepare_validation_batch(sampling_param, training_args, validation_batch,
                                                       num_inference_steps)
                logger.info("rank: %s: rank_in_sp_group: %s, batch.prompt: %s",
                            self.global_rank,
                            self.rank_in_sp_group,
                            batch.prompt,
                            local_main_process_only=False)

                assert batch.prompt is not None and isinstance(batch.prompt, str)
                step_captions.append(batch.prompt)

                # Run validation inference
                output_batch = self.validation_pipeline.forward(batch, training_args)
                samples = output_batch.output.cpu()

                # Capture audio if available
                audio = output_batch.extra.get("audio")
                sample_rate = output_batch.extra.get("audio_sample_rate")

                if audio is not None and torch.is_tensor(audio):
                    audio = audio.detach().cpu().float().numpy()

                step_audio.append(audio)
                step_sample_rates.append(sample_rate)

                # Only the sp-group leader converts frames to images; other
                # ranks just participated in the distributed forward pass.
                if self.rank_in_sp_group != 0:
                    continue

                # Process outputs
                # assumes samples is (batch, channels, time, height, width) — TODO confirm
                video = rearrange(samples, "b c t h w -> t b c h w")
                frames = []
                for x in video:
                    x = torchvision.utils.make_grid(x, nrow=6)
                    x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
                    frames.append((x * 255).numpy().astype(np.uint8))
                step_videos.append(frames)

            # Only sp_group leaders (rank_in_sp_group == 0) need to send their
            # results to global rank 0
            if self.rank_in_sp_group == 0 and self.global_rank == 0:
                # Global rank 0 collects results from all sp_group leaders
                all_videos = step_videos  # Start with own results
                all_captions = step_captions
                all_audios = step_audio
                all_sample_rates = step_sample_rates

                # Receive from other sp_group leaders
                for sp_group_idx in range(1, num_sp_groups):
                    src_rank = sp_group_idx * self.sp_world_size  # Global rank of other sp_group leaders
                    # recv order must mirror the send order in the elif branch below.
                    recv_videos = world_group.recv_object(src=src_rank)
                    recv_captions = world_group.recv_object(src=src_rank)
                    recv_audios = world_group.recv_object(src=src_rank)
                    recv_sample_rates = world_group.recv_object(src=src_rank)

                    all_videos.extend(recv_videos)
                    all_captions.extend(recv_captions)
                    all_audios.extend(recv_audios)
                    all_sample_rates.extend(recv_sample_rates)

                video_filenames = []
                for i, (video, caption, audio, sample_rate) in enumerate(
                        zip(all_videos, all_captions, all_audios, all_sample_rates, strict=True)):
                    os.makedirs(training_args.output_dir, exist_ok=True)
                    filename = os.path.join(
                        training_args.output_dir,
                        f"validation_step_{global_step}_inference_steps_{num_inference_steps}_video_{i}.mp4")
                    imageio.mimsave(filename, video, fps=sampling_param.fps)
                    # Mux audio if available
                    if (audio is not None and sample_rate is not None and not self._mux_audio(
                            filename,
                            audio,
                            sample_rate,
                    )):
                        logger.warning("Audio mux failed for validation video %s; saved video without audio.", filename)
                    video_filenames.append(filename)

                artifacts = []
                for filename, caption in zip(video_filenames, all_captions, strict=True):
                    video_artifact = self.tracker.video(filename, caption=caption)
                    if video_artifact is not None:
                        artifacts.append(video_artifact)
                if artifacts:
                    logs = {f"validation_videos_{num_inference_steps}_steps": artifacts}
                    self.tracker.log_artifacts(logs, global_step)
            elif self.rank_in_sp_group == 0:
                # Other sp_group leaders send their results to global rank 0
                world_group.send_object(step_videos, dst=0)
                world_group.send_object(step_captions, dst=0)
                world_group.send_object(step_audio, dst=0)
                world_group.send_object(step_sample_rates, dst=0)

        world_group.barrier()

        # Re-enable gradients for training
        training_args.inference_mode = False
        self.transformer.train()
        if getattr(self, "transformer_2", None) is not None:
            self.transformer_2.train()
|
| 950 |
+
|
| 951 |
+
    @staticmethod
    def _mux_audio(
        video_path: str,
        audio: torch.Tensor | np.ndarray,
        sample_rate: int,
    ) -> bool:
        """Mux audio into video using PyAV.

        Re-encodes ``video_path`` in place with an added AAC audio track.
        Returns True on success, False on any failure (missing PyAV,
        unexpected audio shape, or any encode/mux error).
        """
        try:
            import av
        except ImportError:
            logger.warning("PyAV not installed; cannot mux audio. "
                           "Install with: pip install av")
            return False

        # Normalize the audio to a float32 numpy array.
        if torch.is_tensor(audio):
            audio_np = audio.detach().cpu().float().numpy()
        else:
            audio_np = np.asarray(audio, dtype=np.float32)

        # Normalize layout to (num_samples, num_channels).
        if audio_np.ndim == 1:
            audio_np = audio_np[:, None]  # mono -> column vector
        elif audio_np.ndim == 2:
            # Heuristic: a small leading dim shorter than the trailing dim is
            # assumed to be (channels, samples) and transposed.
            if audio_np.shape[0] <= 8 and audio_np.shape[1] > audio_np.shape[0]:
                audio_np = audio_np.T
        else:
            logger.warning("Unexpected audio shape %s; skipping mux.", audio_np.shape)
            return False

        # Convert [-1, 1] float samples to 16-bit PCM.
        audio_np = np.clip(audio_np, -1.0, 1.0)
        audio_int16 = (audio_np * 32767.0).astype(np.int16)
        num_channels = audio_int16.shape[1]
        layout = "stereo" if num_channels == 2 else "mono"

        try:
            import wave
            with tempfile.TemporaryDirectory() as tmpdir:
                out_path = os.path.join(tmpdir, "muxed.mp4")
                wav_path = os.path.join(tmpdir, "audio.wav")

                # Write audio to WAV file
                with wave.open(wav_path, "wb") as wav_file:
                    wav_file.setnchannels(num_channels)
                    wav_file.setsampwidth(2)  # 2 bytes/sample = 16-bit PCM
                    wav_file.setframerate(sample_rate)
                    wav_file.writeframes(audio_int16.tobytes())

                # Open input video and audio
                input_video = av.open(video_path)
                input_audio = av.open(wav_path)

                # Create output with both streams
                output = av.open(out_path, mode="w")

                # Add video stream (copy codec from input)
                in_video_stream = input_video.streams.video[0]
                out_video_stream = output.add_stream(
                    codec_name=in_video_stream.codec_context.name,
                    rate=in_video_stream.average_rate,
                )
                out_video_stream.width = in_video_stream.width
                out_video_stream.height = in_video_stream.height
                out_video_stream.pix_fmt = in_video_stream.pix_fmt

                # Add audio stream (AAC)
                out_audio_stream = output.add_stream("aac", rate=sample_rate)
                out_audio_stream.layout = layout

                # Remux video (decode and re-encode to be safe)
                for frame in input_video.decode(video=0):
                    for packet in out_video_stream.encode(frame):
                        output.mux(packet)
                for packet in out_video_stream.encode():  # flush video encoder
                    output.mux(packet)

                # Encode audio
                for frame in input_audio.decode(audio=0):
                    frame.pts = None  # Let encoder assign PTS
                    for packet in out_audio_stream.encode(frame):
                        output.mux(packet)
                for packet in out_audio_stream.encode():  # flush audio encoder
                    output.mux(packet)

                input_video.close()
                input_audio.close()
                output.close()
                # Replace the original file with the muxed result.
                shutil.move(out_path, video_path)
                return True
        except Exception as e:
            logger.warning("Audio mux failed: %s", e)
            return False
|
| 1041 |
+
|
| 1042 |
+
    def visualize_intermediate_latents(self, training_batch: TrainingBatch, training_args: TrainingArgs, step: int):
        """Add visualization data to tracker logging and save frames to disk.

        Abstract in this base pipeline; subclasses that support latent
        visualization must override it.
        """
        raise NotImplementedError("Visualize intermediate latents is not implemented for training pipeline")
|
standalone_inference/overlay_files/fastvideo/training/wan_training_pipeline.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
import sys
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
|
| 5 |
+
from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
|
| 6 |
+
from fastvideo.logger import init_logger
|
| 7 |
+
from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (FlowUniPCMultistepScheduler)
|
| 8 |
+
from fastvideo.pipelines.basic.wan.wan_pipeline import WanPipeline
|
| 9 |
+
from fastvideo.training.training_pipeline import TrainingPipeline
|
| 10 |
+
from fastvideo.utils import is_vsa_available
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
vsa_available = is_vsa_available()
|
| 14 |
+
except Exception:
|
| 15 |
+
vsa_available = False
|
| 16 |
+
|
| 17 |
+
logger = init_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class WanTrainingPipeline(TrainingPipeline):
    """Training pipeline specialization for the Wan text-to-video model."""

    _required_config_modules = ["scheduler", "transformer", "vae"]

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        """Install the flow-matching UniPC scheduler configured by flow_shift."""
        shift = fastvideo_args.pipeline_config.flow_shift
        self.modules["scheduler"] = FlowUniPCMultistepScheduler(shift=shift)

    def create_training_stages(self, training_args: TrainingArgs):
        """No-op hook; may be used in future refactors."""
        pass

    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        """Build an inference-mode WanPipeline sharing this pipeline's transformer."""
        logger.info("Initializing validation pipeline...")

        inference_args = deepcopy(training_args)
        inference_args.inference_mode = True

        self.validation_pipeline = WanPipeline.from_pretrained(
            training_args.model_path,
            args=inference_args,  # type: ignore
            inference_mode=True,
            loaded_modules={"transformer": self.get_module("transformer")},
            tp_size=training_args.tp_size,
            sp_size=training_args.sp_size,
            num_gpus=training_args.num_gpus,
            pin_cpu_memory=training_args.pin_cpu_memory,
            dit_cpu_offload=True)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def main(args) -> None:
    """Build the Wan training pipeline from parsed CLI args and run training.

    Args:
        args: Parsed TrainingArgs/FastVideoArgs namespace; must provide
            ``pretrained_model_name_or_path``.
    """
    logger.info("Starting training pipeline...")

    pipeline = WanTrainingPipeline.from_pretrained(args.pretrained_model_name_or_path, args=args)
    # Dropped dead rebind `args = pipeline.training_args`: nothing after it
    # read the local, so the assignment had no effect.
    pipeline.train()
    logger.info("Training pipeline done")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
if __name__ == "__main__":
    argv = sys.argv  # NOTE(review): captured but never used — possibly leftover debugging.
    from fastvideo.fastvideo_args import TrainingArgs  # re-import; TrainingArgs is already imported at module top
    from fastvideo.utils import FlexibleArgumentParser
    parser = FlexibleArgumentParser()
    # Register both training and inference CLI flags on the same parser.
    parser = TrainingArgs.add_cli_args(parser)
    parser = FastVideoArgs.add_cli_args(parser)
    args = parser.parse_args()
    # Training keeps the DiT on GPU: offload is disabled unconditionally.
    args.dit_cpu_offload = False
    main(args)
|
standalone_inference/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Install FastVideo itself from the upstream project or from your local checkout.
|
| 2 |
+
# This file only lists the extra Python packages directly used by the helper.
|
| 3 |
+
huggingface_hub
|
| 4 |
+
safetensors
|
| 5 |
+
triton
|
standalone_inference/run.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch the standalone sparse-FP4 inference helper against a FastVideo
# checkout pointed to by FASTVIDEO_ROOT: install the overlay files, set up
# PYTHONPATH and the attention-backend env vars, then run run_inference.py.

set -euo pipefail

BUNDLE_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
FASTVIDEO_ROOT="${FASTVIDEO_ROOT:-}"

if [[ -z "${FASTVIDEO_ROOT}" ]]; then
  echo "FASTVIDEO_ROOT is not set."
  echo "Set it to a FastVideo source checkout or installed package root, for example:"
  echo " FASTVIDEO_ROOT=/path/to/FastVideo bash standalone_inference/run.sh"
  exit 1
fi

python "${BUNDLE_ROOT}/install_overlay.py" --fastvideo-root "${FASTVIDEO_ROOT}"

# Prepend the kernel paths. BUGFIX: the original always appended
# ":${PYTHONPATH:-}", leaving a trailing ':' when PYTHONPATH was unset —
# an empty PYTHONPATH entry puts the current directory on sys.path.
KERNEL_PATHS="${FASTVIDEO_ROOT}/fastvideo-kernel/python:${FASTVIDEO_ROOT}/fastvideo-kernel"
if [[ -n "${PYTHONPATH:-}" ]]; then
  export PYTHONPATH="${KERNEL_PATHS}:${PYTHONPATH}"
else
  export PYTHONPATH="${KERNEL_PATHS}"
fi
export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN
export FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1

# Run from the checkout so relative checkpoint/output paths resolve there.
cd "${FASTVIDEO_ROOT}"
python "${BUNDLE_ROOT}/run_inference.py" "$@"
|
standalone_inference/run_inference.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Run Wan T2V inference with the sparse FP4 checkpoint-700 transformer.

CLI defaults mirror the checkpoint's validation settings (448x832, 77
frames, guidance 5.0, flow shift 1.0, VSA sparsity 0.9).
"""

from __future__ import annotations

import argparse
import os
from pathlib import Path


# Default positive prompt used when --prompt is not supplied.
DEFAULT_PROMPT = (
    "In the video, a woman is elegantly showcasing her earrings, bringing "
    "attention to their intricate design with a gentle touch of her fingers. "
    "She is bathed in ambient purple and pink lighting, which casts a soft "
    "glow on her delicate features and enhances the vivid tones of her lipstick "
    "and eye makeup. Her hair is styled to frame her face smoothly, emphasizing "
    "the contours of her jawline and cheekbones. The background features a "
    "blurred neon light, adding an artistic and modern touch to the overall "
    "aesthetic."
)

# Default negative prompt: quality/artifact suppression list used when
# --negative-prompt is not supplied.
DEFAULT_NEGATIVE_PROMPT = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, "
    "works, paintings, images, static, overall gray, worst quality, low quality, "
    "JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn "
    "hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused "
    "fingers, still picture, messy background, three legs, many people in the "
    "background, walking backwards"
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _resolve_weights(repo_id: str, weights: str | None, local_dir: str) -> str:
|
| 33 |
+
if weights:
|
| 34 |
+
path = Path(weights).expanduser()
|
| 35 |
+
if path.exists():
|
| 36 |
+
return str(path.resolve())
|
| 37 |
+
raise FileNotFoundError(f"--weights does not exist: {path}")
|
| 38 |
+
|
| 39 |
+
from huggingface_hub import hf_hub_download
|
| 40 |
+
|
| 41 |
+
path = hf_hub_download(
|
| 42 |
+
repo_id=repo_id,
|
| 43 |
+
filename="transformer/diffusion_pytorch_model.safetensors",
|
| 44 |
+
local_dir=local_dir,
|
| 45 |
+
repo_type="model",
|
| 46 |
+
)
|
| 47 |
+
return str(Path(path).resolve())
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main() -> int:
    """CLI entry point: generate one video with the sparse FP4 transformer.

    Parses arguments, selects the sparse FP4 attention backend via env vars
    (without overriding values already set in the environment), resolves the
    checkpoint weights, then runs FastVideo's VideoGenerator.

    Returns:
        0 on success (the generated result is printed to stdout).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", default="yitongl/sparse_quant_exp")
    parser.add_argument(
        "--model-path",
        default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        help="Base Wan Diffusers model repo/path.",
    )
    parser.add_argument("--weights", default=None)
    parser.add_argument(
        "--local-dir",
        default="checkpoints/hf_download/sparse_quant_exp",
        help="Local Hugging Face download directory for the uploaded weights.",
    )
    parser.add_argument("--prompt", default=DEFAULT_PROMPT)
    parser.add_argument("--negative-prompt", default=DEFAULT_NEGATIVE_PROMPT)
    parser.add_argument("--output-path", default="outputs/sfp4_checkpoint_700")
    parser.add_argument("--height", type=int, default=448)
    parser.add_argument("--width", type=int, default=832)
    parser.add_argument("--num-frames", type=int, default=77)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--fps", type=int, default=16)
    parser.add_argument("--guidance-scale", type=float, default=5.0)
    parser.add_argument("--flow-shift", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=1000)
    parser.add_argument("--vsa-sparsity", type=float, default=0.9)
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--sp-size", type=int, default=1)
    parser.add_argument("--tp-size", type=int, default=1)
    # BUGFIX: the original used action="store_true" with default=True, which
    # made the flag a no-op (it could never be disabled). BooleanOptionalAction
    # keeps --text-encoder-cpu-offload working and adds
    # --no-text-encoder-cpu-offload to turn offloading off.
    parser.add_argument("--text-encoder-cpu-offload",
                        action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--pin-cpu-memory", action="store_true", default=False)
    args = parser.parse_args()

    # setdefault: respect values already exported (e.g. by run.sh).
    os.environ.setdefault("FASTVIDEO_ATTENTION_BACKEND", "SPARSE_FP4_OURS_P_ATTN")
    os.environ.setdefault("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")

    weights_path = _resolve_weights(args.repo_id, args.weights, args.local_dir)

    # Deferred import so env vars above are set before FastVideo reads them.
    from fastvideo import VideoGenerator

    generator = VideoGenerator.from_pretrained(
        model_path=args.model_path,
        num_gpus=args.num_gpus,
        sp_size=args.sp_size,
        tp_size=args.tp_size,
        init_weights_from_safetensors=weights_path,
        dit_cpu_offload=False,
        vae_cpu_offload=False,
        text_encoder_cpu_offload=args.text_encoder_cpu_offload,
        pin_cpu_memory=args.pin_cpu_memory,
        flow_shift=args.flow_shift,
        VSA_sparsity=args.vsa_sparsity,
    )

    result = generator.generate_video(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        output_path=args.output_path,
        save_video=True,
        return_frames=False,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        num_inference_steps=args.num_inference_steps,
        fps=args.fps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
    )
    print(result)
    return 0
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())
|
standalone_inference/training_attention_settings.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"run_name": "sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive",
|
| 3 |
+
"checkpoint": "checkpoint-700",
|
| 4 |
+
"training_method": "legacy_sft_wan_training_pipeline",
|
| 5 |
+
"model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
|
| 6 |
+
"init_weights_from_safetensors": "checkpoints/init/sfp4_v4_sparse06_hpo_on_ours_p_1n_interactive_v2_ckpt2050/transformer/diffusion_pytorch_model.safetensors",
|
| 7 |
+
"environment": {
|
| 8 |
+
"FASTVIDEO_ATTENTION_BACKEND": "SPARSE_FP4_OURS_P_ATTN",
|
| 9 |
+
"FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O": "1",
|
| 10 |
+
"FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK": "1",
|
| 11 |
+
"WANDB_MODE": "online",
|
| 12 |
+
"WANDB_RESUME": "allow"
|
| 13 |
+
},
|
| 14 |
+
"vsa_schedule": {
|
| 15 |
+
"VSA_SPARSITY": 0.9,
|
| 16 |
+
"VSA_INIT_SPARSITY": 0.9,
|
| 17 |
+
"VSA_WARMUP_STEPS": 0,
|
| 18 |
+
"VSA_DECAY_RATE": 0.03,
|
| 19 |
+
"VSA_DECAY_INTERVAL_STEPS": 50,
|
| 20 |
+
"effective_sparsity_from_step_0": 0.9
|
| 21 |
+
},
|
| 22 |
+
"attention_semantics": {
|
| 23 |
+
"selected_backend": "SPARSE_FP4_OURS_P_ATTN",
|
| 24 |
+
"self_attention": {
|
| 25 |
+
"backend_path": "fastvideo/attention/backends/sparse_fp4_ours_p_attn.py",
|
| 26 |
+
"kernel_path": "fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py",
|
| 27 |
+
"tile_size_video": [4, 4, 4],
|
| 28 |
+
"tile_tokens": 64,
|
| 29 |
+
"qkv_quantization": "FP4 fake quantization with STE, no q/k mean subtraction in quantization",
|
| 30 |
+
"block_selection": "top-k blocks from q_c @ k_c tile-mean scores",
|
| 31 |
+
"p_quantization": "group-local exp2(qk - group_max) FP4 fake quantization; compensation multiplies exp2(group_max - running_row_m)",
|
| 32 |
+
"dropped_tile_handling": "tile-level q_mean/k_mean score and mean_v compensation"
|
| 33 |
+
},
|
| 34 |
+
"cross_attention": {
|
| 35 |
+
"backend": "dense_sdpa",
|
| 36 |
+
"reason": "sparse_fp4_ours_p_attn.py treats query_length != key_length as cross attention and returns _dense_sdpa_blhd",
|
| 37 |
+
"quantized": false,
|
| 38 |
+
"sparse": false
|
| 39 |
+
},
|
| 40 |
+
"force_dense": {
|
| 41 |
+
"backend": "dense_sdpa",
|
| 42 |
+
"used_for": "teacher or explicitly forced dense paths, not the normal SFT student self-attention path"
|
| 43 |
+
}
|
| 44 |
+
},
|
| 45 |
+
"validation_and_checkpointing": {
|
| 46 |
+
"save_steps": 50,
|
| 47 |
+
"eval_steps": 50,
|
| 48 |
+
"validation_sampling_steps": 50,
|
| 49 |
+
"validation_guidance_scale": 5.0,
|
| 50 |
+
"checkpoints_total_limit": 5,
|
| 51 |
+
"flow_shift": 1.0
|
| 52 |
+
},
|
| 53 |
+
"training_shape": {
|
| 54 |
+
"num_latent_t": 20,
|
| 55 |
+
"num_frames": 77,
|
| 56 |
+
"height": 448,
|
| 57 |
+
"width": 832,
|
| 58 |
+
"batch_size_per_gpu": 1,
|
| 59 |
+
"sp_size": 1,
|
| 60 |
+
"tp_size": 1
|
| 61 |
+
}
|
| 62 |
+
}
|