yitongl commited on
Commit
1d0c0cc
·
verified ·
1 Parent(s): 383bb79

Add standalone inference helper for sfp4 checkpoint-700

Browse files
Files changed (34) hide show
  1. standalone_inference/README.md +74 -0
  2. standalone_inference/__pycache__/install_overlay.cpython-313.pyc +0 -0
  3. standalone_inference/__pycache__/run_inference.cpython-313.pyc +0 -0
  4. standalone_inference/install_overlay.py +89 -0
  5. standalone_inference/manifest.sha256 +31 -0
  6. standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py +270 -0
  7. standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py +1155 -0
  8. standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py +250 -0
  9. standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py +80 -0
  10. standalone_inference/overlay_files/fastvideo/api/compat.py +503 -0
  11. standalone_inference/overlay_files/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py +192 -0
  12. standalone_inference/overlay_files/fastvideo/attention/backends/video_sparse_attn.py +262 -0
  13. standalone_inference/overlay_files/fastvideo/configs/models/dits/base.py +79 -0
  14. standalone_inference/overlay_files/fastvideo/configs/pipelines/wan.py +203 -0
  15. standalone_inference/overlay_files/fastvideo/configs/sample/base.py +292 -0
  16. standalone_inference/overlay_files/fastvideo/configs/sample/wan.py +154 -0
  17. standalone_inference/overlay_files/fastvideo/configs/wan_1.3B_t2v_pipeline.json +40 -0
  18. standalone_inference/overlay_files/fastvideo/entrypoints/cli/generate.py +115 -0
  19. standalone_inference/overlay_files/fastvideo/entrypoints/video_generator.py +797 -0
  20. standalone_inference/overlay_files/fastvideo/fastvideo_args.py +1188 -0
  21. standalone_inference/overlay_files/fastvideo/forward_context.py +100 -0
  22. standalone_inference/overlay_files/fastvideo/pipelines/basic/wan/__init__.py +0 -0
  23. standalone_inference/overlay_files/fastvideo/pipelines/basic/wan/wan_pipeline.py +60 -0
  24. standalone_inference/overlay_files/fastvideo/pipelines/composed_pipeline_base.py +474 -0
  25. standalone_inference/overlay_files/fastvideo/pipelines/stages/denoising.py +1184 -0
  26. standalone_inference/overlay_files/fastvideo/platforms/cuda.py +440 -0
  27. standalone_inference/overlay_files/fastvideo/platforms/interface.py +255 -0
  28. standalone_inference/overlay_files/fastvideo/train/models/wan/wan.py +680 -0
  29. standalone_inference/overlay_files/fastvideo/training/training_pipeline.py +1044 -0
  30. standalone_inference/overlay_files/fastvideo/training/wan_training_pipeline.py +74 -0
  31. standalone_inference/requirements.txt +5 -0
  32. standalone_inference/run.sh +22 -0
  33. standalone_inference/run_inference.py +123 -0
  34. standalone_inference/training_attention_settings.json +62 -0
standalone_inference/README.md ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standalone Inference Helper
2
+
3
+ This folder contains a portable inference helper for:
4
+
5
+ `sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive/checkpoint-700`
6
+
7
+ It is not a full vendored copy of Wan or FastVideo. It contains the sparse FP4
8
+ backend overlay and a runner that can be applied to a FastVideo checkout or
9
+ installation so the uploaded checkpoint can be used for normal inference.
10
+
11
+ ## Contents
12
+
13
+ - `run_inference.py`: downloads/loads `transformer/diffusion_pytorch_model.safetensors` from `yitongl/sparse_quant_exp` and runs `VideoGenerator`.
14
+ - `run.sh`: convenience wrapper that installs the overlay into `FASTVIDEO_ROOT` and then runs `run_inference.py`.
15
+ - `install_overlay.py`: copies the bundled sparse FP4 backend files into a FastVideo checkout/install.
16
+ - `overlay_files/`: exact runtime source files needed by `SPARSE_FP4_OURS_P_ATTN`.
17
+ - `training_attention_settings.json`: structured settings for the uploaded checkpoint.
18
+
19
+ ## Expected Environment
20
+
21
+ - A working FastVideo Python environment.
22
+ - FastVideo dependencies installed, including PyTorch, Triton, safetensors, and
23
+ Hugging Face Hub.
24
+ - Access to the base model `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`.
25
+ - A CUDA GPU supported by the custom Triton kernels.
26
+
27
+ ## Usage
28
+
29
+ From a machine with this HF repo downloaded:
30
+
31
+ ```bash
32
+ export FASTVIDEO_ROOT=/path/to/FastVideo
33
+ bash standalone_inference/run.sh \
34
+ --output-path outputs/sfp4_checkpoint_700 \
35
+ --seed 1000
36
+ ```
37
+
38
+ The script sets:
39
+
40
+ ```bash
41
+ FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN
42
+ FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1
43
+ ```
44
+
45
+ and downloads the uploaded checkpoint-700 transformer weights unless `--weights`
46
+ is provided.
47
+
48
+ To use a local safetensors file:
49
+
50
+ ```bash
51
+ export FASTVIDEO_ROOT=/path/to/FastVideo
52
+ bash standalone_inference/run.sh \
53
+ --weights /path/to/diffusion_pytorch_model.safetensors \
54
+ --prompt "your prompt"
55
+ ```
56
+
57
+ ## Attention Semantics
58
+
59
+ - Self-attention uses `SPARSE_FP4_OURS_P_ATTN`.
60
+ - Q/K/V use FP4 fake quantization with STE.
61
+ - VSA tile size is `4 x 4 x 4 = 64` tokens.
62
+ - Selected sparse tiles use group-local P quantization in the Triton kernel.
63
+ - Dropped tiles use tile mean compensation.
64
+ - Cross-attention falls back to dense SDPA and is not sparse/FP4.
65
+
66
+ ## Checkpoint
67
+
68
+ The current HF `main` transformer file is checkpoint-700:
69
+
70
+ `transformer/diffusion_pytorch_model.safetensors`
71
+
72
+ Local SHA256 used when preparing this helper:
73
+
74
+ `4595ca81ea7085c15ccf14b738aa9c0fdf2d2786641f49b55e0bc0e99bf042d2`
standalone_inference/__pycache__/install_overlay.cpython-313.pyc ADDED
Binary file (4.48 kB). View file
 
standalone_inference/__pycache__/run_inference.cpython-313.pyc ADDED
Binary file (6.22 kB). View file
 
standalone_inference/install_overlay.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Install the sparse FP4 checkpoint-700 inference overlay into FastVideo.
3
+
4
+ The checkpoint depends on local FastVideo attention backend changes that are
5
+ not part of a vanilla install. This helper copies the bundled overlay files
6
+ into a FastVideo source checkout or site-packages installation.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import importlib.util
13
+ import shutil
14
+ import sys
15
+ from pathlib import Path
16
+
17
+
18
+ def _find_fastvideo_root() -> Path:
19
+ spec = importlib.util.find_spec("fastvideo")
20
+ if spec is None or spec.origin is None:
21
+ raise RuntimeError(
22
+ "Could not import fastvideo. Pass --fastvideo-root explicitly or "
23
+ "activate a FastVideo environment first.")
24
+ return Path(spec.origin).resolve().parents[1]
25
+
26
+
27
+ def _iter_overlay_files(overlay_root: Path):
28
+ for path in sorted(overlay_root.rglob("*")):
29
+ if path.is_file() and "__pycache__" not in path.parts:
30
+ yield path
31
+
32
+
33
def main() -> int:
    """CLI entry point: copy the bundled overlay into a FastVideo tree.

    Parses --fastvideo-root/--backup/--dry-run, validates the bundle and
    the target, then copies every overlay file into place.

    Returns:
        0 on success.

    Raises:
        RuntimeError: if the overlay bundle is missing or the target does
            not look like a FastVideo root.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--fastvideo-root",
        type=Path,
        default=None,
        help="FastVideo repository/install root. Defaults to import location.",
    )
    parser.add_argument(
        "--backup",
        action="store_true",
        help="Write .sfp4_backup copies before overwriting existing files.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print files that would be copied without modifying anything.",
    )
    opts = parser.parse_args()

    here = Path(__file__).resolve().parent
    overlay_root = here / "overlay_files"
    if not overlay_root.is_dir():
        raise RuntimeError(f"Missing overlay directory: {overlay_root}")

    if opts.fastvideo_root:
        target_root = opts.fastvideo_root.resolve()
    else:
        target_root = _find_fastvideo_root()
    # Sanity check: a FastVideo root must contain the fastvideo/ package dir.
    if not (target_root / "fastvideo").exists():
        raise RuntimeError(
            f"{target_root} does not look like a FastVideo root: missing fastvideo/")

    copied = 0
    for src in _iter_overlay_files(overlay_root):
        rel = src.relative_to(overlay_root)
        dst = target_root / rel
        print(f"{rel}")
        if opts.dry_run:
            continue
        dst.parent.mkdir(parents=True, exist_ok=True)
        if opts.backup and dst.exists():
            # Keep the first backup only; never overwrite an earlier one.
            backup_path = dst.with_suffix(dst.suffix + ".sfp4_backup")
            if not backup_path.exists():
                shutil.copy2(dst, backup_path)
        shutil.copy2(src, dst)
        copied += 1

    if opts.dry_run:
        print(f"Dry run complete for target root: {target_root}")
    else:
        print(f"Installed {copied} files into {target_root}")
        print(
            "Use PYTHONPATH='<FastVideo>/fastvideo-kernel/python:"
            "<FastVideo>/fastvideo-kernel:$PYTHONPATH' when running inference.")
    return 0
86
+
87
+
88
if __name__ == "__main__":
    # Propagate main()'s integer status as the process exit code.
    sys.exit(main())
standalone_inference/manifest.sha256 ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fb13abe775d8acd0aa59ce47ebad40178e4f2604fd191b6b02c1e34dd1e95cc4 ./README.md
2
+ eb151afbefca213bbf1595e94b40547e1e431e850e6fc4cd187e506eb8e25b2d ./install_overlay.py
3
+ 9d1d8dc58aab529270fe31eb1735d6a1382c0c6d36fccca122a8dbffa1b714fd ./overlay_files/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py
4
+ 211c7f0445fbe9488250f01fa83457c6620e83bd6f3877db791fd155de93c08b ./overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py
5
+ 3f3a407a88612ea17ad65e1b6b9cf6b7b02df56956d8301c4b13bffa92095016 ./overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py
6
+ 56f17c602dede53c7c3677058f81274681530f1b83c086d9d1d44c6b51feefbb ./overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py
7
+ 58f4ac013e6755336212a7a6c9948b19dab0dafc00f4a3298591598df270cb39 ./overlay_files/fastvideo/api/compat.py
8
+ 2b821b0e2e7bdb3581be6312ebbece42380a6ee28a7a982f0cf2dc71fab849c8 ./overlay_files/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py
9
+ a97adcc52d7558c49f418c09395fd1665e988ad290d2276b95f21dfca0f8eb7d ./overlay_files/fastvideo/attention/backends/video_sparse_attn.py
10
+ 79ef6f38ec0f5bfe16b2b98327ad2ccd15f3c863dd87fd03affc5dbdaa0a8224 ./overlay_files/fastvideo/configs/models/dits/base.py
11
+ 4bda44746a3626551ea9a9380d890f036087092fb99fce2d302642cce14a97ed ./overlay_files/fastvideo/configs/pipelines/wan.py
12
+ 5926e29a594db13b116922f131db50631bf8adbf90fe5cec00a5e2f446bfb4ca ./overlay_files/fastvideo/configs/sample/base.py
13
+ d99adcf607d982b38bbb5a70be60bf87f35d0e9f6f50752f3bceb68b34ce46c2 ./overlay_files/fastvideo/configs/sample/wan.py
14
+ 49775ce42fd9643c78d8fad4ab8248c1755c7f1524ad771cbd1863d76c513c38 ./overlay_files/fastvideo/configs/wan_1.3B_t2v_pipeline.json
15
+ ae2d8309472b09927da3e450dea52d9715dcabe5d6722fc2917130ae8d85adb4 ./overlay_files/fastvideo/entrypoints/cli/generate.py
16
+ d0466769626e7fd497376c544904d56ba62847745eb52527896d96b99d76ba03 ./overlay_files/fastvideo/entrypoints/video_generator.py
17
+ 73afe6b2ebe0f8cfe0a8ec762a7126161621ad97a64ebad628995f4a164b8b0e ./overlay_files/fastvideo/fastvideo_args.py
18
+ ddcab6f4fd33c9813840571b6bf83bbbcea164b564166951ed4301297db6cef0 ./overlay_files/fastvideo/forward_context.py
19
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./overlay_files/fastvideo/pipelines/basic/wan/__init__.py
20
+ deac1e22530a6a41c501629f5e8fce47a7af4e008f321cc8a4d734c5120ef4fe ./overlay_files/fastvideo/pipelines/basic/wan/wan_pipeline.py
21
+ 8908223b3ff99cdb3206148a68a730c2a13d554a2fb1316db6f2f9672efac9e8 ./overlay_files/fastvideo/pipelines/composed_pipeline_base.py
22
+ 6cfd128e782b7787a27ddd28a5e2d50cb4b0e2e9425d51d9780f14c91e8206f0 ./overlay_files/fastvideo/pipelines/stages/denoising.py
23
+ 489388dbdd9e5e3ad24db3012bd9b108794509a9729891d7dd315a102abba828 ./overlay_files/fastvideo/platforms/cuda.py
24
+ c046b1914041b59254bcdfe577aed20d6f007a72632ea1fe1ae92fa678eca760 ./overlay_files/fastvideo/platforms/interface.py
25
+ 2456d39ca28019e12bb7ab007774e86348f0582a017bf0e6c91e2a01d654a1a0 ./overlay_files/fastvideo/train/models/wan/wan.py
26
+ bc46e84b732567de6c0325223405daecd1226c623e303be33c7be9b5b7fdec08 ./overlay_files/fastvideo/training/training_pipeline.py
27
+ 1d3898fa37e21029df6c37e05dc34ed7805a211c2f87de6642db890e5a8c6f2e ./overlay_files/fastvideo/training/wan_training_pipeline.py
28
+ 1b2addfcb414ab65e20034394ee21a8af9ada58220a680b67d3b4233a0952268 ./requirements.txt
29
+ 5087bb4ffe5721c41a12d92d8dfe439cd86aa1a5d3b3d259e30ad62711d95081 ./run.sh
30
+ b826c8b059a000af6054ec099c36742d01e6a329ee77bc5936ae7562e9428409 ./run_inference.py
31
+ 8ddeea65247d9fa31a4a8a2a5ce5abe068a911ff4d67871453555e1355af8ecf ./training_attention_settings.json
standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ import torch
6
+
7
+
8
+ def _use_high_prec_output_for_backward() -> bool:
9
+ value = os.environ.get("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")
10
+ return value.lower() not in ("0", "false", "no", "off")
11
+
12
+
13
+ def _map_to_index(block_map: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
14
+ if block_map.dim() == 3:
15
+ block_map = block_map.unsqueeze(0)
16
+ if block_map.dim() != 4:
17
+ raise ValueError(
18
+ f"block_map must be [B,H,Q,KV] or [H,Q,KV], got {tuple(block_map.shape)}"
19
+ )
20
+ if block_map.dtype != torch.bool:
21
+ block_map = block_map.to(torch.bool)
22
+ if not block_map.is_cuda:
23
+ raise RuntimeError("block_map must be a CUDA tensor.")
24
+
25
+ try:
26
+ from fastvideo_kernel.triton_kernels.index import map_to_index as triton_map_to_index
27
+ except Exception as e:
28
+ raise ImportError("Triton map_to_index is required for ours-P Sparse FP4.") from e
29
+ return triton_map_to_index(block_map)
30
+
31
+
32
@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_ours_p_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_ours_p_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Block-sparse ours-P attention forward, registered as a CUDA custom op.

    Args:
        q, k, v: attention inputs; made contiguous before the kernel call.
        block_map: boolean tile-selection map, converted to (index, count)
            lists by ``_map_to_index``.
        variable_block_sizes: per-KV-tile valid token counts.

    Returns:
        The ``(o, M, high_prec_o)`` triple from
        ``triton_block_sparse_attn_forward``: quantized-path output, per-row
        softmax statistics, and a higher-precision output copy.
    """
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)

    # Imported lazily so registering the op does not require Triton at
    # module import time.
    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
        triton_block_sparse_attn_forward,
    )

    return triton_block_sparse_attn_forward(
        q, k, v, q2k_idx, q2k_num, variable_block_sizes, is_qat=True
    )
57
+
58
+
59
@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_ours_p_triton")
def _block_sparse_attn_ours_p_triton_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation: shape/dtype propagation only.

    Both outputs mirror ``q``; ``M`` holds one float32 value per query row,
    i.e. the leading three dims of ``q``.
    """
    o = torch.empty_like(q)
    high_prec_o = torch.empty_like(q)
    # M is always float32 regardless of the input dtype.
    M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    return o, M, high_prec_o
71
+
72
+
73
@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_ours_p_backward_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_ours_p_backward_triton(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Block-sparse ours-P attention backward, registered as a CUDA custom op.

    Args:
        grad_output: gradient w.r.t. the forward output; made contiguous.
        q, k, v: forward inputs.
        o: forward output selected for backward (quantized or high-precision,
            depending on the caller's setup-context choice).
        M: per-row softmax statistics from the forward pass.
        block_map: boolean tile-selection map; both the q->k and (transposed)
            k->q index lists are derived from it here.
        variable_block_sizes: per-KV-tile valid token counts.

    Returns:
        ``(dq, dk, dv)``.
    """
    grad_output = grad_output.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)
    # The transposed map drives the dK/dV side of the backward kernel.
    k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())

    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
        triton_block_sparse_attn_backward,
    )

    return triton_block_sparse_attn_backward(
        grad_output,
        q,
        k,
        v,
        o,
        M,
        q2k_idx,
        q2k_num,
        k2q_idx,
        k2q_num,
        variable_block_sizes,
        is_qat=True,
    )
111
+
112
+
113
@torch.library.register_fake(
    "fastvideo_kernel::block_sparse_attn_ours_p_backward_triton"
)
def _block_sparse_attn_ours_p_backward_triton_fake(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) implementation: gradients mirror q/k/v shapes and dtypes."""
    return torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
127
+
128
+
129
def _backward_triton(ctx, grad_o, grad_M, grad_high_prec_o):
    """Autograd backward hook for the registered forward op.

    Only ``grad_o`` is propagated; ``grad_M`` and ``grad_high_prec_o`` are
    ignored. The trailing two ``None``s correspond to ``block_map`` and
    ``variable_block_sizes``, which receive no gradient.
    """
    q, k, v, o_for_bwd, M, block_map, variable_block_sizes = ctx.saved_tensors
    dq, dk, dv = block_sparse_attn_ours_p_backward_triton(
        grad_o, q, k, v, o_for_bwd, M, block_map, variable_block_sizes
    )
    return dq, dk, dv, None, None
135
+
136
+
137
def _setup_context_triton(ctx, inputs, output):
    """Save forward tensors for backward.

    Depending on FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O, backward sees either
    the quantized output ``o`` or the high-precision output.
    """
    q, k, v, block_map, variable_block_sizes = inputs
    o, M, high_prec_o = output
    o_for_bwd = high_prec_o if _use_high_prec_output_for_backward() else o
    ctx.save_for_backward(q, k, v, o_for_bwd, M, block_map, variable_block_sizes)
142
+
143
+
144
# Wire autograd onto the custom op: _setup_context_triton picks which output
# tensor backward sees; _backward_triton produces dq/dk/dv.
block_sparse_attn_ours_p_triton.register_autograd(
    _backward_triton, setup_context=_setup_context_triton
)
147
+
148
+
149
class _BlockSparseAttnOursPTileComp(torch.autograd.Function):
    """Block-sparse ours-P attention with mean compensation for dropped tiles.

    Unlike the plain custom-op path, this variant also takes per-tile mean
    vectors (q_mean/k_mean/v_mean) so that KV tiles excluded by ``block_map``
    still contribute an approximate term, in both forward and backward.
    """

    @staticmethod
    def forward(ctx, q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes):
        """Run the forward kernel with both selected and dropped tile lists.

        Returns ``(o, M)``; the high-precision output is kept only for
        backward (when enabled via the environment flag).
        """
        # Contiguity is required because the Triton kernels index with raw
        # strides.
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        block_map = block_map.to(torch.bool)
        # Dropped tiles are exactly the complement of the selected map.
        dropped_block_map = torch.logical_not(block_map)

        q2k_idx, q2k_num = _map_to_index(block_map)
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
            triton_block_sparse_attn_forward,
        )

        o, M, high_prec_o = triton_block_sparse_attn_forward(
            q,
            k,
            v,
            q2k_idx,
            q2k_num,
            variable_block_sizes,
            is_qat=True,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
        )
        o_for_bwd = high_prec_o if _use_high_prec_output_for_backward() else o
        # NOTE: backward() unpacks these in exactly this order — keep in sync.
        ctx.save_for_backward(
            q,
            k,
            v,
            q_mean,
            k_mean,
            v_mean,
            o_for_bwd,
            M,
            block_map,
            dropped_block_map,
            variable_block_sizes,
        )
        return o, M

    @staticmethod
    def backward(ctx, grad_o, grad_M):
        """Propagate ``grad_o`` to q/k/v; ``grad_M`` is ignored.

        The mean tensors and the maps get no gradient (the final five
        ``None``s) — presumably the means are computed outside the autograd
        graph; verify against the caller if that ever changes.
        """
        (
            q,
            k,
            v,
            q_mean,
            k_mean,
            v_mean,
            o_for_bwd,
            M,
            block_map,
            dropped_block_map,
            variable_block_sizes,
        ) = ctx.saved_tensors

        # Index lists are recomputed here rather than saved, trading a small
        # amount of work for less saved-tensor memory.
        q2k_idx, q2k_num = _map_to_index(block_map)
        k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)
        dropped_k2q_idx, dropped_k2q_num = _map_to_index(
            dropped_block_map.transpose(-1, -2).contiguous()
        )

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
            triton_block_sparse_attn_backward,
        )

        dq, dk, dv = triton_block_sparse_attn_backward(
            grad_o.contiguous(),
            q,
            k,
            v,
            o_for_bwd,
            M,
            q2k_idx,
            q2k_num,
            k2q_idx,
            k2q_num,
            variable_block_sizes,
            is_qat=True,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
            dropped_k2q_index=dropped_k2q_idx,
            dropped_k2q_num=dropped_k2q_num,
        )
        return dq, dk, dv, None, None, None, None, None
248
+
249
+
250
def block_sparse_attn_ours_p(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
    q_mean: torch.Tensor | None = None,
    k_mean: torch.Tensor | None = None,
    v_mean: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dispatch block-sparse ours-P attention.

    When all three mean tensors are given, the tile-compensation autograd
    path is used; when none are given, the plain registered custom op runs.
    Supplying only some of the means is an error.

    Returns:
        ``(o, M)`` — attention output and per-row softmax statistics.

    Raises:
        ValueError: if the mean tensors are only partially provided.
    """
    means = (q_mean, k_mean, v_mean)
    if any(m is not None for m in means):
        if not all(m is not None for m in means):
            raise ValueError("q_mean, k_mean, and v_mean must be provided together")
        return _BlockSparseAttnOursPTileComp.apply(
            q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes
        )

    out, stats, _ = block_sparse_attn_ours_p_triton(
        q, k, v, block_map, variable_block_sizes
    )
    return out, stats
standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py ADDED
@@ -0,0 +1,1155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fused Attention
3
+ ===============
4
+
5
+ This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
6
+ (https://tridao.me/publications/flash2/flash2.pdf)
7
+
8
+ Credits: OpenAI kernel team
9
+ """
10
+
11
+ import torch
12
+ import triton
13
+ import triton.language as tl
14
+ from .quant_utils import fake_quantize
15
+
16
+ # ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
17
+ import math # small utility needed by the sparse wrapper
18
+ # ──────────────────────────── SPARSE ADDITION END ─────────────────────────────
19
+
20
+ # We don't run auto-tuning every time to keep the tutorial fast. Keeping
21
+ # the code below and commenting out the equivalent parameters is convenient for
22
+ # re-tuning.
23
# Fixed autotune search space: 64x64 tiles with a handful of stage/warp
# variants. Deliberately small; widen these lists when re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
    for BM in [64]\
    for BN in [64]\
    for s in [3, 4, 7]\
    for w in [4, 8]\
]
30
+
31
+
32
+ # ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
33
@triton.autotune(configs, key=["N_CTX_Q", "HEAD_DIM"])
@triton.jit
def _attn_fwd_sparse(
    Q,
    K,
    V,
    QMean,
    KMean,
    VMean,
    sm_scale,  # softmax scale applied to QK^T (multiplied by 1/ln2 below)
    q2k_index,
    q2k_num,
    max_kv_blks,  # row stride of q2k_index
    dropped_q2k_index,
    dropped_q2k_num,
    max_dropped_kv_blks,  # row stride of dropped_q2k_index
    variable_block_sizes,
    M,
    Out,  # quantized-path output
    HighPrecOut,  # high-precision output (same layout as Out)
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vk,
    stride_vn,
    stride_oz,
    stride_oh,
    stride_om,
    stride_on,
    Z,
    H,
    N_CTX_Q,
    N_CTX_KV,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    IS_QAT: tl.constexpr = False,
    USE_TILE_COMP: tl.constexpr = False):
    """
    64x64 block-sparse forward kernel for the independent "ours P quant" path.

    P quantization is group-local: each selected KV tile quantizes
    exp2(logit - tile_row_max), then applies exp2(tile_row_max - online_max)
    after the FP4 PV GEMM. This intentionally differs from the QAT-style
    backend, which quantizes exp2(logit - online_max) directly.

    Each program handles one BLOCK_M query tile of one (batch, head) pair
    and streams its selected KV tiles with an online-softmax accumulator;
    when USE_TILE_COMP is on, dropped tiles contribute a mean-based term.
    """

    # ----- program-id mapping -----
    q_blk = tl.program_id(0)  # Q-tile index
    off_hz = tl.program_id(1)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    q_tiles = N_CTX_Q // BLOCK_M
    # Flat row index into the [B, H, q_tiles, *] metadata tensors.
    meta_base = ((b * H + h) * q_tiles + q_blk)

    kv_blocks = tl.load(q2k_num + meta_base)  # int32
    kv_ptr = q2k_index + meta_base * max_kv_blks  # ptr to list
    dropped_kv_blocks = tl.load(dropped_q2k_num + meta_base)
    dropped_kv_ptr = dropped_q2k_index + meta_base * max_dropped_kv_blks

    # ----- base pointers -----
    # int64 offsets avoid 32-bit overflow for large batch*head extents.
    q_off = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    k_off = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    v_off = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    o_off = (b.to(tl.int64) * stride_oz + h.to(tl.int64) * stride_oh)

    Q_ptr = tl.make_block_ptr(base=Q + q_off,
                              shape=(N_CTX_Q, HEAD_DIM),
                              strides=(stride_qm, stride_qk),
                              offsets=(q_blk * BLOCK_M, 0),
                              block_shape=(BLOCK_M, HEAD_DIM),
                              order=(1, 0))

    # K is addressed transposed (HEAD_DIM x N_CTX_KV) for the QK^T dot.
    K_base = tl.make_block_ptr(base=K + k_off,
                               shape=(HEAD_DIM, N_CTX_KV),
                               strides=(stride_kk, stride_kn),
                               offsets=(0, 0),
                               block_shape=(HEAD_DIM, BLOCK_N),
                               order=(0, 1))

    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_base = tl.make_block_ptr(base=V + v_off,
                               shape=(N_CTX_KV, HEAD_DIM),
                               strides=(stride_vk, stride_vn),
                               offsets=(0, 0),
                               block_shape=(BLOCK_N, HEAD_DIM),
                               order=v_order)

    O_ptr = tl.make_block_ptr(base=Out + o_off,
                              shape=(N_CTX_Q, HEAD_DIM),
                              strides=(stride_om, stride_on),
                              offsets=(q_blk * BLOCK_M, 0),
                              block_shape=(BLOCK_M, HEAD_DIM),
                              order=(1, 0))
    HPO_ptr = tl.make_block_ptr(base=HighPrecOut + o_off,
                                shape=(N_CTX_Q, HEAD_DIM),
                                strides=(stride_om, stride_on),
                                offsets=(q_blk * BLOCK_M, 0),
                                block_shape=(BLOCK_M, HEAD_DIM),
                                order=(1, 0))

    # ----- accumulators -----
    offs_m = q_blk * BLOCK_M + tl.arange(0, BLOCK_M)
    m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)  # online row max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0  # online row sum
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    high_prec_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    qk_scale = sm_scale * 1.44269504  # 1/ln2: work in base-2 exponentials
    q = tl.load(Q_ptr)
    offs_d = tl.arange(0, HEAD_DIM)

    # ----- sparse loop over valid K/V tiles -----
    for i in range(0, kv_blocks):
        kv_idx = tl.load(kv_ptr + i).to(tl.int32)
        block_size = tl.load(variable_block_sizes + kv_idx)
        K_ptr = tl.advance(K_base, (0, kv_idx * BLOCK_N))
        V_ptr = tl.advance(V_base, (kv_idx * BLOCK_N, 0))

        k = tl.load(K_ptr)
        mask = tl.arange(0, BLOCK_N) < block_size
        qk = tl.dot(q, k) * qk_scale
        # mask out invalid columns
        qk = tl.where(mask[None, :], qk, -float("inf"))
        group_m = tl.max(qk, 1)  # per-row max within this tile only
        m_ij = tl.maximum(m_i, group_m)

        # Group-local P: normalize by the tile's own row max before
        # quantization; the exp2(group_m - m_ij) correction is applied
        # after the PV dot (p_comp).
        p_local = tl.math.exp2(qk - group_m[:, None])
        p_local = tl.where(mask[None, :], p_local, 0.0)
        p_comp = tl.math.exp2(group_m - m_ij)
        p_valid = mask[None, :] & (
            tl.full(shape=p_local.shape, value=1.0,
                    dtype=p_local.dtype) == 1.0
        )
        p_quant, high_prec_p = fake_quantize(
            src_tensor=p_local, valid_src_mask=p_valid,
            BLOCK_SIZE_OUT_DIM=BLOCK_M, BLOCK_SIZE_QUANT_DIM=BLOCK_N,
            dst_dtype=tl.bfloat16, use_global_sf=False,
        )
        # Row sum uses the high-precision probabilities.
        l_ij = tl.sum(high_prec_p, 1) * p_comp

        alpha = tl.math.exp2(m_i - m_ij)  # rescale previous accumulators
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]
        high_prec_acc = high_prec_acc * alpha[:, None]

        v = tl.load(V_ptr)
        acc = acc + tl.dot(
            p_quant.to(tl.bfloat16),
            v.to(tl.bfloat16),
        ) * p_comp[:, None]
        high_prec_acc = high_prec_acc + tl.dot(
            high_prec_p.to(tl.bfloat16),
            v.to(tl.bfloat16),
        ) * p_comp[:, None]
        m_i = m_ij

    if USE_TILE_COMP:
        # Dropped tiles: approximate each with its mean vectors.
        q_mean_base = (off_hz * q_tiles + q_blk).to(tl.int64) * HEAD_DIM
        q_mean = tl.load(QMean + q_mean_base + offs_d).to(tl.float32)
        kv_tiles = N_CTX_KV // BLOCK_N

        for i in range(0, dropped_kv_blocks):
            kv_idx = tl.load(dropped_kv_ptr + i).to(tl.int32)
            block_size = tl.load(variable_block_sizes + kv_idx).to(tl.float32)
            kv_mean_base = (off_hz * kv_tiles + kv_idx).to(tl.int64) * HEAD_DIM
            k_mean = tl.load(KMean + kv_mean_base + offs_d).to(tl.float32)
            v_mean = tl.load(VMean + kv_mean_base + offs_d).to(tl.float32)

            # One scalar logit per dropped tile, weighted by its token count.
            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            m_ij = tl.maximum(m_i, score)
            alpha = tl.math.exp2(m_i - m_ij)
            beta = tl.math.exp2(score - m_ij)

            l_i = l_i * alpha + block_size * beta
            comp = (block_size * beta)[:, None] * v_mean[None, :]
            acc = acc * alpha[:, None] + comp
            high_prec_acc = high_prec_acc * alpha[:, None] + comp
            m_i = m_ij

    # ----- epilogue -----
    # Fold the normalizer into M (log-sum-exp form) and normalize outputs.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    high_prec_acc = high_prec_acc / l_i[:, None]
    tl.store(M + off_hz * N_CTX_Q + offs_m, m_i)
    tl.store(O_ptr, acc.to(Out.type.element_ty))
    tl.store(HPO_ptr, high_prec_acc.to(HighPrecOut.type.element_ty))
+ tl.store(HPO_ptr, high_prec_acc.to(HighPrecOut.type.element_ty))
228
+
229
+
230
+ # ──────────────────────────── SPARSE ADDITION END ─────────────────────────────
231
+
232
+
233
@triton.jit
def _attn_bwd_preprocess(
        O,
        DO,  #
        Delta,  #
        Z,
        H,
        N_CTX,  #
        BLOCK_M: tl.constexpr,
        HEAD_DIM: tl.constexpr  #
):
    """Compute Delta[row] = sum_d(O[row, d] * dO[row, d]) for one Q tile.

    Delta is the per-row dot product of the forward output and its gradient,
    consumed later by the dK/dV and dQ kernels (as the `D` argument).
    Grid: (N_CTX // BLOCK_M, B * H). O/DO/Delta are assumed contiguous with
    layout (batch*head, token, dim) flattened as used in the index math below.
    NOTE: Z and H are accepted for signature compatibility but unused here.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # Load one (BLOCK_M, HEAD_DIM) tile of O and dO for this (batch, head).
    o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM +
                off_n[None, :])
    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM +
                 off_n[None, :]).to(tl.float32)
    # Row-wise reduction over the head dimension.
    delta = tl.sum(o * do, axis=1)
    # Write one Delta value per query row.
    tl.store(Delta + off_hz * N_CTX + off_m, delta)
255
+
256
+
257
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        k,
        v,
        QMean,
        KMean,
        VMean,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        # shared by Q/K/V/DO.
        stride_tok,
        stride_d,  #
        H,
        N_CTX_KV,
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,  #
        # Filled in by the wrapper.
        start_n,
        start_m,
        num_steps,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """Accumulate dK/dV for one KV tile over its list of contributing Q tiles.

    The sparsity metadata (k2q_index/k2q_num) lists, per (batch, head, kv
    tile), which 64-token Q tiles attended to this KV tile; each is processed
    as two BLOCK_M1-row halves. When USE_TILE_COMP is set, a second loop adds
    a mean-based correction for Q tiles that were dropped from the sparse set.
    `M` holds the forward per-row softmax statistics (base-2 log-sum-exp) and
    `D` the per-row O·dO dot products from the preprocess kernel.
    NOTE(review): IS_QAT and num_steps are accepted but unused in this body.
    Returns the updated (dk, dv) accumulators.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    # Q is loaded transposed (HEAD_DIM, BLOCK_M1) so qkT = k @ qT directly.
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    step_m = BLOCK_M1
    kv_blk = tl.program_id(0)  # KV-tile index (was mislabeled "Q-tile index")
    off_hz = tl.program_id(2)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    kv_tiles = N_CTX_KV // BLOCK_N1
    meta_base = ((b * H + h) * kv_tiles + kv_blk)

    q_blocks = tl.load(k2q_num + meta_base)  # int32 count of kept Q tiles
    q_ptr = k2q_index + meta_base * max_q_blks  # ptr to list
    dropped_q_blocks = tl.load(dropped_k2q_num + meta_base)
    dropped_q_ptr = dropped_k2q_index + meta_base * max_dropped_q_blks
    # Number of valid tokens inside this 64-token KV tile.
    block_size = tl.load(variable_block_sizes + kv_blk)
    block_size_f = block_size.to(tl.float32)

    # Each listed 64-token Q tile is visited as two BLOCK_M1-row halves.
    for blk_idx in range(q_blocks * 2):
        block_sparse_offset = (tl.load(q_ptr + blk_idx // 2).to(tl.int32) * 2 +
                               blk_idx % 2) * step_m
        qT = tl.load(qT_ptrs + block_sparse_offset * stride_tok)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = start_m + block_sparse_offset + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k.to(tl.bfloat16), qT)
        # 1.44269504 = 1/ln(2): rescale so exp2 implements exp(sm_scale * qk).
        qkT = qkT * sm_scale * 1.44269504
        # Mask out padding rows beyond this KV tile's valid token count.
        mask = tl.arange(0, BLOCK_N1) < block_size
        qkT = tl.where(mask[:, None], qkT, -float("inf"))
        group_m = tl.max(qkT, 0)
        pT = tl.math.exp2(qkT - m[None, :])
        pT = tl.where(mask[:, None], pT, 0.0)

        do = tl.load(do_ptrs + block_sparse_offset * stride_tok)
        # Compute dV with group-local P quantization:
        # quantize exp2(logit - tile_col_max), then multiply dO by
        # exp2(tile_col_max - final_lse) to recover the final softmax scale.
        p_local_T = tl.math.exp2(qkT - group_m[None, :])
        p_local_T = tl.where(mask[:, None], p_local_T, 0.0)
        p_comp = tl.math.exp2(group_m - m)
        p_for_quant = tl.trans(p_local_T)
        p_valid = mask[None, :] & (
            tl.full(
                shape=p_for_quant.shape,
                value=1.0,
                dtype=p_for_quant.dtype,
            ) == 1.0
        )
        # fake_quantize is defined elsewhere in this module; presumably it
        # mirrors the forward pass's quantized-P path — confirm against its
        # definition.
        p_quant, _ = fake_quantize(
            src_tensor=p_for_quant, valid_src_mask=p_valid,
            BLOCK_SIZE_OUT_DIM=BLOCK_M1, BLOCK_SIZE_QUANT_DIM=BLOCK_N1,
            dst_dtype=p_for_quant.dtype, use_global_sf=False,
        )
        dv += tl.dot(
            tl.trans(p_quant.to(tl.bfloat16)),
            (do * p_comp[:, None]).to(tl.bfloat16),
        )
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.bfloat16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.

    if USE_TILE_COMP:
        # Mean-based compensation for Q tiles dropped from the sparse set:
        # approximate each dropped Q tile by its per-tile mean vector.
        k_mean = tl.load(KMean + kv_blk * HEAD_DIM + offs_k).to(tl.float32)
        v_mean = tl.load(VMean + kv_blk * HEAD_DIM + offs_k).to(tl.float32)
        qk_scale = sm_scale * 1.44269504

        for blk_idx in range(dropped_q_blocks * 2):
            q_blk_idx = tl.load(dropped_q_ptr + blk_idx // 2).to(tl.int32)
            half = (blk_idx % 2).to(tl.int32)
            block_sparse_offset = (q_blk_idx * 2 + half) * step_m
            offs_m = start_m + block_sparse_offset + tl.arange(0, BLOCK_M1)
            q_mean = tl.load(QMean + q_blk_idx * HEAD_DIM +
                             offs_k).to(tl.float32)
            m = tl.load(M + offs_m)
            do = tl.load(do_ptrs + block_sparse_offset * stride_tok)
            Di = tl.load(D + offs_m)
            # Loaded but unused below; kept for parity with the dQ path.
            q_block_size = tl.load(variable_block_sizes +
                                   q_blk_idx).to(tl.float32)

            # Single scalar logit per row from the tile means.
            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            p = tl.math.exp2(score - m)
            dp = tl.sum(do.to(tl.float32) * v_mean[None, :], axis=1)
            # block_size_f weights the correction by the KV tile's valid count.
            ds = block_size_f * p * (dp - Di)

            dk_mean = tl.sum(ds[:, None] * q_mean[None, :],
                             axis=0) / block_size_f
            dv_mean = tl.sum(p[:, None] * do.to(tl.float32), axis=0)
            dk += dk_mean[None, :]
            dv += dv_mean[None, :]
    return dk, dv
392
+
393
+
394
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
        dq,
        q,
        K,
        V,  #
        QMean,
        KMean,
        VMean,
        do,
        m,
        m_vec,
        D,
        # shared by Q/K/V/DO.
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        # Filled in by the wrapper.
        start_m,
        start_n,
        num_steps,
        sm_scale=1.0,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """Accumulate dQ for one Q tile over its list of kept KV tiles.

    q2k_index/q2k_num list, per (batch, head, q tile), the 64-token KV tiles
    this Q tile attends to; each is processed as two BLOCK_N2-column halves.
    `m`/`m_vec` are the forward softmax statistics ((BLOCK_M2, 1) and
    (BLOCK_M2,) views of the same values) and `D` (= delta) the per-row O·dO
    products. When USE_TILE_COMP is set, a mean-based correction is added for
    dropped KV tiles. NOTE(review): IS_QAT and num_steps are unused here.
    Returns the updated dq accumulator.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    # K and V are loaded transposed: (HEAD_DIM, BLOCK_N2) tiles.
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    step_n = BLOCK_N2

    q_blk = tl.program_id(0)  # Q-tile index
    off_hz = tl.program_id(2)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    q_tiles = N_CTX // BLOCK_M2
    meta_base = ((b * H + h) * q_tiles + q_blk)

    kv_blocks = tl.load(q2k_num + meta_base)  # int32 count of kept KV tiles
    kv_ptr = q2k_index + meta_base * max_kv_blks  # ptr to list
    dropped_kv_blocks = tl.load(dropped_q2k_num + meta_base)
    dropped_kv_ptr = dropped_q2k_index + meta_base * max_dropped_kv_blks

    # Each listed 64-token KV tile is visited as two BLOCK_N2-column halves.
    for blk_idx in range(kv_blocks * 2):
        kv_idx = tl.load(kv_ptr + blk_idx // 2).to(tl.int32)
        # variable_block_sizes is defined per KV block (tile). Mask must therefore
        # use kv_idx (not q_blk). Also, because we split each 64-token block into
        # two 32-token halves, the mask must account for the half-block offset.
        block_size = tl.load(variable_block_sizes + kv_idx).to(tl.int32)
        half = (blk_idx % 2).to(tl.int32)
        block_sparse_offset = (kv_idx * 2 + half) * step_n * stride_tok
        kT = tl.load(kT_ptrs + block_sparse_offset)
        vT = tl.load(vT_ptrs + block_sparse_offset)
        qk = tl.dot(q, kT)
        # 1.44269504 = 1/ln(2): exp2 then implements exp(sm_scale * qk).
        qk = qk * sm_scale * 1.44269504
        p = tl.math.exp2(qk - m)
        offs_in_block = half * step_n + tl.arange(0, BLOCK_N2)
        mask = offs_in_block < block_size
        p = tl.where(mask[None, :], p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = ds.to(tl.bfloat16)
        # Compute dQ. sm_scale is applied once by the caller at write-back
        # (dq *= sm_scale), so it is intentionally absent here.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.

    if USE_TILE_COMP:
        # Mean-based correction for KV tiles dropped from the sparse set.
        q_mean = tl.load(QMean + q_blk * HEAD_DIM + offs_k).to(tl.float32)
        q_block_size = tl.load(variable_block_sizes + q_blk).to(tl.float32)
        qk_scale = sm_scale * 1.44269504
        dq_mean = tl.zeros([HEAD_DIM], dtype=tl.float32)

        for blk_idx in range(dropped_kv_blocks):
            kv_idx = tl.load(dropped_kv_ptr + blk_idx).to(tl.int32)
            block_size = tl.load(variable_block_sizes + kv_idx).to(tl.float32)
            k_mean = tl.load(KMean + kv_idx * HEAD_DIM +
                             offs_k).to(tl.float32)
            v_mean = tl.load(VMean + kv_idx * HEAD_DIM +
                             offs_k).to(tl.float32)

            # One scalar logit per row from the tile means, weighted by the
            # dropped tile's valid token count.
            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            p = tl.math.exp2(score - m_vec)
            dp = tl.sum(do.to(tl.float32) * v_mean[None, :], axis=1)
            ds = block_size * p * (dp - Di)
            dq_mean = dq_mean + tl.sum(ds, axis=0) * k_mean

        # Distribute the mean correction evenly over this Q tile's rows.
        dq += dq_mean[None, :] / q_block_size
    return dq
500
+
501
+
502
@triton.jit
def _attn_bwd(
        Q,
        K,
        V,
        sm_scale,  #
        DO,  #
        DQ,
        DK,
        DV,  #
        M,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        k2q_index,
        k2q_num,
        max_q_blks,
        variable_block_sizes,
        # shared by Q/K/V/DO.
        stride_z,
        stride_h,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False):
    """Fused backward kernel: one program computes dK/dV for KV tile `pid`
    and dQ for Q tile `pid` (square attention: N_CTX_Q == N_CTX_KV == N_CTX).

    Requires Q/K/V/DO/DQ/DK/DV to share z/h/token/dim strides. Tile-mean
    compensation is disabled (USE_TILE_COMP=False); Q/K/V are passed as
    placeholders for the unused QMean/KMean/VMean arguments, and the kept-tile
    index lists double as the (unused) dropped-tile lists.
    """
    LN2 = 0.6931471824645996  # = ln(2); NOTE(review): currently unused.

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, HEAD_DIM)

    start_n = pid * BLOCK_N1
    start_m = 0

    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    num_steps = N_CTX // BLOCK_M1

    dk, dv = _attn_bwd_dkdv(  #
        dk,
        dv,  #
        Q,
        k,
        v,
        Q,  # QMean placeholder (unused: USE_TILE_COMP=False)
        K,  # KMean placeholder
        V,  # VMean placeholder
        sm_scale,  #
        DO,  #
        M,
        D,  #
        k2q_index,
        k2q_num,
        max_q_blks,
        k2q_index,  # dropped_k2q_index placeholder
        k2q_num,  # dropped_k2q_num placeholder
        max_q_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,  #
        start_n,
        start_m,
        num_steps,  #
        IS_QAT=IS_QAT,
        USE_TILE_COMP=False,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv)

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk)

    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    end_n = 0

    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    # Forward softmax statistics for this Q tile (row vector + column view).
    m_vec = tl.load(M + offs_m)
    m = m_vec[:, None]

    num_steps = N_CTX // BLOCK_N2
    dq = _attn_bwd_dq(
        dq,
        q,
        K,
        V,  #
        Q,  # QMean placeholder (unused: USE_TILE_COMP=False)
        K,  # KMean placeholder
        V,  # VMean placeholder
        do,
        m,
        m_vec,
        D,  #
        q2k_index,
        q2k_num,
        max_kv_blks,
        q2k_index,  # dropped_q2k_index placeholder
        q2k_num,  # dropped_q2k_num placeholder
        max_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M2,
        BLOCK_N2,
        HEAD_DIM,  #
        start_m,
        end_n,
        num_steps,  #
        sm_scale=sm_scale,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=False,
    )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= sm_scale
    tl.store(dq_ptrs, dq)
662
+
663
+
664
@triton.jit
def _attn_bwd_dkdv_kernel(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        sm_scale,  #
        DO,  #
        DK,
        DV,  #
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        # shared token/dim strides (assumed contiguous along token and dim)
        stride_tok,
        stride_d,  #
        # batch/head strides (may differ between Q and KV)
        stride_qz,
        stride_qh,
        stride_kz,
        stride_kh,
        stride_vz,
        stride_vh,
        stride_doz,
        stride_doh,
        stride_dkz,
        stride_dkh,
        stride_dvz,
        stride_dvh,
        H,
        N_CTX_Q,
        N_CTX_KV,
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    Backward kernel that computes dK and dV for each KV block (64 tokens).
    Grid:
        pid0: kv_blk in [0, N_CTX_KV/BLOCK_N1)
        pid2: fused (batch, head) in [0, B*H)

    Offsets each tensor to its (batch, head) slice, loads this KV tile once,
    then delegates the accumulation to _attn_bwd_dkdv. dK is scaled by
    sm_scale at write-back (dV is not — P already carries the softmax scale).
    """
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    kv_blk = tl.program_id(0)

    # Per-tensor (batch, head) base offsets; Q and KV may use different
    # batch/head strides.
    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dk_adj = (b.to(tl.int64) * stride_dkz + h.to(tl.int64) * stride_dkh)
    dv_adj = (b.to(tl.int64) * stride_dvz + h.to(tl.int64) * stride_dvh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DK = DK + dk_adj
    DV = DV + dv_adj

    # Mean tensors are laid out as (B*H, tiles, HEAD_DIM) contiguously.
    # NOTE(review): q_tiles counts 2*BLOCK_M1-token tiles (BLOCK_M1 is a
    # half-tile), hence the extra // 2.
    q_tiles = N_CTX_Q // BLOCK_M1 // 2
    kv_tiles = N_CTX_KV // BLOCK_N1
    mean_q_adj = (bhid * q_tiles * HEAD_DIM).to(tl.int64)
    mean_kv_adj = (bhid * kv_tiles * HEAD_DIM).to(tl.int64)
    QMean = QMean + mean_q_adj
    KMean = KMean + mean_kv_adj
    VMean = VMean + mean_kv_adj

    # M and D (delta) are always sized by Q length.
    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_n = kv_blk * BLOCK_N1
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    dv_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    num_steps = N_CTX_Q // BLOCK_M1
    dk_acc, dv_acc = _attn_bwd_dkdv(
        dk_acc,
        dv_acc,
        Q,
        k,
        v,
        QMean,
        KMean,
        VMean,
        sm_scale,
        DO,
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_KV,
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,
        HEAD_DIM=HEAD_DIM,
        start_n=start_n,
        start_m=0,
        num_steps=num_steps,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=USE_TILE_COMP,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv_acc)

    # Apply the softmax temperature once at write-back.
    dk_acc *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk_acc)
798
+
799
+
800
@triton.jit
def _attn_bwd_dq_kernel(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        DO,  #
        DQ,
        M,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        # shared token/dim strides (assumed contiguous along token and dim)
        stride_tok,
        stride_d,  #
        # batch/head strides (may differ between Q and KV)
        stride_qz,
        stride_qh,
        stride_kz,
        stride_kh,
        stride_vz,
        stride_vh,
        stride_doz,
        stride_doh,
        stride_dqz,
        stride_dqh,
        H,
        N_CTX_Q,
        sm_scale,
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    Backward kernel that computes dQ for each Q block (64 tokens).
    Grid:
        pid0: q_blk in [0, N_CTX_Q/BLOCK_M2)
        pid2: fused (batch, head) in [0, B*H)

    Offsets tensors to their (batch, head) slice, loads this Q tile, dO and
    the forward softmax stats, and delegates accumulation to _attn_bwd_dq.
    dQ is scaled by sm_scale at write-back.
    """
    LN2 = 0.6931471824645996  # = ln(2); NOTE(review): currently unused.
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    q_blk = tl.program_id(0)

    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dq_adj = (b.to(tl.int64) * stride_dqz + h.to(tl.int64) * stride_dqh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DQ = DQ + dq_adj

    q_tiles = N_CTX_Q // BLOCK_M2
    # NOTE(review): assumes 64-token KV tiles and N_CTX_KV == N_CTX_Q here —
    # confirm for cross-attention use.
    kv_tiles = N_CTX_Q // 64
    mean_q_adj = (bhid * q_tiles * HEAD_DIM).to(tl.int64)
    mean_kv_adj = (bhid * kv_tiles * HEAD_DIM).to(tl.int64)
    QMean = QMean + mean_q_adj
    KMean = KMean + mean_kv_adj
    VMean = VMean + mean_kv_adj

    # M and D (delta) are sized by Q length.
    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_m = q_blk * BLOCK_M2
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    m_vec = tl.load(M + offs_m)
    m = m_vec[:, None]

    dq_acc = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    num_steps = 0  # unused in _attn_bwd_dq
    dq_acc = _attn_bwd_dq(
        dq_acc,
        q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        do,
        m,
        m_vec,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_Q,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=HEAD_DIM,
        start_m=start_m,
        start_n=0,
        num_steps=num_steps,
        sm_scale=sm_scale,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=USE_TILE_COMP,
    )

    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq_acc *= sm_scale
    tl.store(dq_ptrs, dq_acc)
924
+
925
+
926
+ # ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
927
def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
                                     variable_block_sizes, is_qat=False,
                                     q_mean=None, k_mean=None, v_mean=None,
                                     dropped_q2k_index=None,
                                     dropped_q2k_num=None):
    """Launch the block-sparse attention forward kernel.

    Args:
        q, k, v: (B, H, T, D) tensors; Tq and Tkv must be multiples of 64.
        q2k_index / q2k_num: per-(batch, head, q-tile) list of kept KV tiles
            and the number of valid entries in that list.
        variable_block_sizes: valid token count of each 64-token KV tile
            (Tkv // 64 entries).
        is_qat: forwarded to the kernel as IS_QAT.
        q_mean / k_mean / v_mean: optional per-tile mean vectors. When given,
            the kernel compensates for dropped KV tiles (USE_TILE_COMP) and
            dropped_q2k_index / dropped_q2k_num must also be provided.

    Returns:
        (o, M, high_prec_o): attention output, per-row base-2 log-sum-exp
        statistics (B, H, Tq), and the high-precision (non-quantized-P)
        output used as a reference path.
    """
    B, H, Tq, D = q.shape
    Tkv = k.shape[2]
    sm_scale = 1.0 / math.sqrt(D)
    max_kv_blks = q2k_index.shape[-1]
    use_tile_comp = q_mean is not None
    if use_tile_comp:
        assert k_mean is not None and v_mean is not None
        assert dropped_q2k_index is not None and dropped_q2k_num is not None
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        max_dropped_kv_blks = dropped_q2k_index.shape[-1]
    else:
        # No tile compensation: pass valid dummy pointers so the kernel
        # signature stays uniform; USE_TILE_COMP=False keeps them unused.
        q_mean = q
        k_mean = k
        v_mean = v
        dropped_q2k_index = q2k_index
        dropped_q2k_num = q2k_num
        max_dropped_kv_blks = max_kv_blks
    assert Tq % 64 == 0, f"q length must be a multiple of 64, but got {Tq}"
    assert Tkv % 64 == 0, f"kv length must be a multiple of 64, but got {Tkv}"
    # Fixed: the failure message previously printed q2k_num.shape[-2] even
    # though the check inspects shape[-1], yielding a misleading diagnostic.
    assert q2k_num.shape[
        -1] == Tq // 64, f"shape mismatch, Tq // 64 = {Tq // 64}, q2k_num.shape[-1] = {q2k_num.shape[-1]}"
    assert variable_block_sizes.numel() == Tkv // 64, (
        f"shape mismatch, variable_block_sizes must have length {Tkv // 64}, "
        f"got {variable_block_sizes.numel()}"
    )
    o = torch.empty_like(q)
    high_prec_o = torch.empty_like(q)
    M = torch.empty((B, H, Tq), dtype=torch.float32, device=q.device)

    # One program per 64-token Q tile per (batch, head).
    grid = lambda _: (triton.cdiv(Tq, 64), B * H, 1)
    _attn_fwd_sparse[grid](q,
                           k,
                           v,
                           q_mean,
                           k_mean,
                           v_mean,
                           sm_scale,
                           q2k_index,
                           q2k_num,
                           max_kv_blks,
                           dropped_q2k_index,
                           dropped_q2k_num,
                           max_dropped_kv_blks,
                           variable_block_sizes,
                           M,
                           o,
                           high_prec_o,
                           q.stride(0),
                           q.stride(1),
                           q.stride(2),
                           q.stride(3),
                           k.stride(0),
                           k.stride(1),
                           k.stride(2),
                           k.stride(3),
                           v.stride(0),
                           v.stride(1),
                           v.stride(2),
                           v.stride(3),
                           o.stride(0),
                           o.stride(1),
                           o.stride(2),
                           o.stride(3),
                           B,
                           H,
                           Tq,
                           Tkv,
                           HEAD_DIM=D,
                           STAGE=3,
                           IS_QAT=is_qat,
                           USE_TILE_COMP=use_tile_comp)

    return o, M, high_prec_o
1007
+
1008
+
1009
def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
                                      k2q_index, k2q_num, variable_block_sizes,
                                      is_qat=False, q_mean=None, k_mean=None,
                                      v_mean=None, dropped_q2k_index=None,
                                      dropped_q2k_num=None,
                                      dropped_k2q_index=None,
                                      dropped_k2q_num=None):
    """Launch the block-sparse attention backward pass.

    Runs three kernels: the delta preprocess (O·dO per row), the dK/dV kernel
    (one program per KV tile) and the dQ kernel (one program per Q tile).

    Args:
        do: gradient of the output; must be contiguous.
        q, k, v, o, M: forward inputs/outputs ((B, H, T, D) tensors; M holds
            the forward per-row base-2 log-sum-exp statistics).
        q2k_index / q2k_num, k2q_index / k2q_num: kept-tile lists in both
            directions (Q→KV and KV→Q).
        variable_block_sizes: valid token count of each 64-token KV tile.
        is_qat: forwarded to the kernels as IS_QAT.
        q_mean / k_mean / v_mean and dropped_* lists: optional tile-mean
            compensation inputs; all must be provided together.

    Returns:
        (dq, dk, dv) with the same shapes as (q, k, v).

    Cleanups vs. previous revision: removed the unused local RCP_LN2 and the
    redundant re-unpack of batch/head counts (B, H already hold them).
    """
    assert do.is_contiguous()

    B, H, Tq, D = q.shape
    Tkv = k.shape[2]
    sm_scale = 1.0 / math.sqrt(D)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
    # Ours-P mode keeps K unscaled and applies sm_scale inside the bwd kernels.
    arg_k = k
    PRE_BLOCK = 64
    assert Tq % PRE_BLOCK == 0
    pre_grid = (Tq // PRE_BLOCK, B * H)
    delta = torch.empty_like(M)
    _attn_bwd_preprocess[pre_grid](
        o,
        do,  #
        delta,  #
        B,
        H,
        Tq,  #
        BLOCK_M=PRE_BLOCK,
        HEAD_DIM=D  #
    )

    max_q_blks = k2q_index.shape[-1]
    max_kv_blks = q2k_index.shape[-1]
    use_tile_comp = q_mean is not None
    if use_tile_comp:
        assert k_mean is not None and v_mean is not None
        assert dropped_q2k_index is not None and dropped_q2k_num is not None
        assert dropped_k2q_index is not None and dropped_k2q_num is not None
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        max_dropped_kv_blks = dropped_q2k_index.shape[-1]
        max_dropped_q_blks = dropped_k2q_index.shape[-1]
    else:
        # No tile compensation: pass valid dummy pointers; the kernels leave
        # them untouched when USE_TILE_COMP is False.
        q_mean = q
        k_mean = k
        v_mean = v
        dropped_q2k_index = q2k_index
        dropped_q2k_num = q2k_num
        dropped_k2q_index = k2q_index
        dropped_k2q_num = k2q_num
        max_dropped_kv_blks = max_kv_blks
        max_dropped_q_blks = max_q_blks

    # dK/dV kernel: grid over KV blocks
    grid_kv = (Tkv // BLOCK_N1, 1, B * H)
    _attn_bwd_dkdv_kernel[grid_kv](
        q,
        arg_k,
        v,
        q_mean,
        k_mean,
        v_mean,
        sm_scale,
        do,
        dk,
        dv,
        M,
        delta,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        q.stride(2),
        q.stride(3),
        q.stride(0),
        q.stride(1),
        arg_k.stride(0),
        arg_k.stride(1),
        v.stride(0),
        v.stride(1),
        do.stride(0),
        do.stride(1),
        dk.stride(0),
        dk.stride(1),
        dv.stride(0),
        dv.stride(1),
        H,
        Tq,
        Tkv,
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,
        HEAD_DIM=D,
        IS_QAT=is_qat,
        USE_TILE_COMP=use_tile_comp,
    )

    # dQ kernel: grid over Q blocks
    grid_q = (Tq // BLOCK_M2, 1, B * H)
    _attn_bwd_dq_kernel[grid_q](
        q,
        arg_k,
        v,
        q_mean,
        k_mean,
        v_mean,
        do,
        dq,
        M,
        delta,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        q.stride(2),
        q.stride(3),
        q.stride(0),
        q.stride(1),
        arg_k.stride(0),
        arg_k.stride(1),
        v.stride(0),
        v.stride(1),
        do.stride(0),
        do.stride(1),
        dq.stride(0),
        dq.stride(1),
        H,
        Tq,
        sm_scale,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=D,
        IS_QAT=is_qat,
        USE_TILE_COMP=use_tile_comp,
    )

    return dq, dk, dv
standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Adapted from https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
3
+ # and https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
4
+
5
+ import triton
6
+ import triton.language as tl
7
+ try:
8
+ from triton.language.target_info import cuda_capability_geq
9
+ _HAS_CAPABILITY_CHECK = True
10
+ except ImportError:
11
+ cuda_capability_geq = None
12
+ _HAS_CAPABILITY_CHECK = False
13
+
14
+ MXFP_BLOCK_SIZE = tl.constexpr(16)
15
+
16
@triton.jit
def _compute_quant_and_scale(
    src_tensor,
    valid_src_mask,
    mx_tensor_dtype: tl.constexpr = tl.uint8,
    use_global_sf=True,
    two_level_quant_P=False,
    IS_BLACKWELL: tl.constexpr = False,
):
    """Quantize a (OUT_DIM, QUANT_DIM) tile to NVFP4/FP8 with per-16 scales.

    Each group of MXFP_BLOCK_SIZE (=16) values along the last axis gets one
    e4m3 decode scale. An optional outer scale is applied first: globally over
    the tile (use_global_sf) or per-row (two_level_quant_P). For the fp4
    target (mx_tensor_dtype == tl.uint8) two e2m1 values are packed per byte,
    either via a Blackwell PTX cvt instruction or a manual RTNE bit-level
    conversion.

    Returns:
        (out_tensor, dequant_scale, s_dec): packed/casted values, the e4m3
        per-group decode scales, and the outer decode scale (1.0 when no
        outer scale is used).
    """
    BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
    BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8

    is_fp8e4: tl.constexpr = mx_tensor_dtype == tl.float8e4nv
    is_fp8e5: tl.constexpr = mx_tensor_dtype == tl.float8e5
    tl.static_assert(
        is_fp4 or (is_fp8e4 or is_fp8e5),
        "mx_tensor_dtype must be uint8, float8e4nv, or float8e5",
    )

    # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
    f32_tensor = src_tensor.to(tl.float32)
    abs_tensor = tl.abs(f32_tensor)
    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation

    if two_level_quant_P:
        # row max from SageAttn3 paper
        global_max_val = tl.max(f32_tensor, axis=1, keep_dims=True)  # (BLOCK_SIZE_OUT_DIM, 1)
        global_max_val = tl.maximum(global_max_val, 1e-8)
        # 6 * 448: max e2m1 magnitude (6) times max e4m3 magnitude (448).
        s_enc = ((6 * 448) / global_max_val).reshape([BLOCK_SIZE_OUT_DIM, 1, 1])
        s_dec = (1 / s_enc)

    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])

    if use_global_sf and not two_level_quant_P:
        global_max_val = tl.max(abs_tensor)
        # Avoid division by zero: if all values are padding (max is 0), use a default scale
        global_max_val = tl.maximum(global_max_val, 1e-8)
        s_enc = (6 * 448) / global_max_val
        s_dec = (1 / s_enc)
    elif not two_level_quant_P and not use_global_sf:
        # No outer scale: encode/decode factors are identity.
        s_dec = 1.0
        s_enc = 1.0

    max_val = tl.max(abs_tensor, axis=2, keep_dims=True)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) # per block maxima
    # Per-group decode scale maps the group max onto the e2m1 max value 6.
    s_dec_b = max_val / 6  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
    s_dec_b_e4m3 = (s_dec_b * s_enc).to(tl.float8e4nv)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
    # Invert the *rounded* e4m3 scale so quantization matches dequantization.
    s_enc_b = 1 / (s_dec_b_e4m3.to(tl.float32) * s_dec)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)

    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    quant_tensor = f32_tensor * s_enc_b

    # Reshape the tensors after scaling
    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0.0)
    dequant_scale = s_dec_b_e4m3.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])

    if is_fp4 and IS_BLACKWELL:
        # Convert scaled values to two f32 lanes and use PTX cvt to e2m1x2 with two f32 operands.
        pairs = tl.reshape(quant_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        lo_f, hi_f = tl.split(pairs)
        lo_f32 = lo_f.to(tl.float32)
        hi_f32 = hi_f.to(tl.float32)

        # Inline PTX: cvt.rn.satfinite.e2m1x2.f32 takes two f32 sources and produces one .b8 packed e2m1x2.
        out_tensor = tl.inline_asm_elementwise(
            """
            {
                .reg .b8 r;
                cvt.rn.satfinite.e2m1x2.f32 r, $1, $2;
                mov.b32 $0, {r, r, r, r};
            }
            """,
            constraints="=r,f,f",
            args=[hi_f32, lo_f32],
            dtype=tl.uint8,
            is_pure=True,
            pack=1,
        )
    elif is_fp4:
        # Manual fp32 -> e2m1 conversion with round-to-nearest-even, operating
        # directly on the IEEE-754 bit pattern.
        quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
        signs = quant_tensor & 0x80000000
        exponents = (quant_tensor >> 23) & 0xFF
        mantissas_orig = (quant_tensor & 0x7FFFFF)

        # For RTNE: 0.25 < x < 0.75 maps to 0.5 (denormal); exactly 0.25 maps to 0.0
        E8_BIAS = 127
        E2_BIAS = 1
        # Move implicit bit 1 at the beginning to mantissa for denormals
        is_subnormal = exponents < E8_BIAS
        adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
        mantissas_pre = (0x400000 | (mantissas_orig >> 1))
        mantissas = tl.where(is_subnormal, mantissas_pre >> adjusted_exponents, mantissas_orig)

        # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
        exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

        # Combine sign, exponent, and mantissa, while saturating
        # Round to nearest, ties to even (RTNE): use guard/sticky and LSB to decide increment
        m2bits = mantissas >> 21
        lsb_keep = (m2bits >> 1) & 0x1
        guard = m2bits & 0x1
        IS_SRC_FP32: tl.constexpr = src_tensor.dtype == tl.float32
        if IS_SRC_FP32:
            # fp32 sources carry extra mantissa bits; fold the bits dropped by
            # the subnormal shift into the sticky flag so rounding stays RTNE.
            bit0_dropped = (mantissas_orig & 0x1) != 0
            mask = (1 << tl.minimum(adjusted_exponents, 31)) - 1
            dropped_post = (mantissas_pre & mask) != 0
            sticky = is_subnormal & (bit0_dropped | dropped_post)
            sticky |= ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
        else:
            sticky = ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
        round_inc = guard & (sticky | lsb_keep)
        e2m1_tmp = tl.minimum((((exponents << 2) | m2bits) + round_inc) >> 1, 0x7)
        e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)

        # Pack two e2m1 nibbles per byte: even index -> low nibble.
        e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        evens, odds = tl.split(e2m1_value)
        out_tensor = evens | (odds << 4)
    else:
        # FP8 targets: a plain dtype cast suffices.
        out_tensor = quant_tensor.to(mx_tensor_dtype)

    return out_tensor, dequant_scale, s_dec
+
141
@triton.jit
def _compute_dequant(
    mx_tensor,
    scale,
    s_dec,
    BLOCK_SIZE_OUT_DIM: tl.constexpr,
    BLOCK_SIZE_QUANT_DIM: tl.constexpr,
    dst_dtype: tl.constexpr,
    IS_BLACKWELL: tl.constexpr = False,
):
    """Dequantize an MX/NVFP4-quantized tile back to ``dst_dtype``.

    Args:
        mx_tensor: quantized values; ``tl.uint8`` means two packed fp4 e2m1
            values per byte, otherwise a float8 or the destination dtype.
        scale: per-group decode scale, must be ``tl.float8e4nv``.
        s_dec: additional global (tensor-level) decode scale factor.
        BLOCK_SIZE_OUT_DIM / BLOCK_SIZE_QUANT_DIM: tile shape; the quant dim
            must be a multiple of MXFP_BLOCK_SIZE (one scale per group).
        dst_dtype: one of float16 / bfloat16 / float32.
        IS_BLACKWELL: use the hardware e2m1x2 -> f16x2 PTX conversion instead
            of the software bit-manipulation path.

    Returns:
        The dequantized [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM] tile in
        ``dst_dtype``, clamped to the destination type's finite range.
    """
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"Block size along quantization block must be a multiple of {MXFP_BLOCK_SIZE=}")
    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor.dtype
    _is_f16: tl.constexpr = dst_dtype == tl.float16
    _is_bf16: tl.constexpr = dst_dtype == tl.bfloat16
    _is_f32: tl.constexpr = dst_dtype == tl.float32
    tl.static_assert(_is_f16 or (_is_bf16 or _is_f32))
    _is_u8: tl.constexpr = mx_tensor_dtype == tl.uint8
    _is_e4: tl.constexpr = mx_tensor_dtype == tl.float8e4nv
    _is_e5: tl.constexpr = mx_tensor_dtype == tl.float8e5
    _is_dst: tl.constexpr = mx_tensor_dtype == dst_dtype
    tl.static_assert(
        _is_u8 or ((_is_e4 or _is_e5) or _is_dst),
        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
    tl.static_assert(scale.dtype == tl.float8e4nv, "scale must be float8e4nv")

    # uint8 payload means packed fp4 (two e2m1 values per byte).
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE

    # Upcast the scale to the destination type.
    if dst_dtype == tl.bfloat16:
        dst_scale = scale.to(tl.bfloat16)
    else:
        dst_scale = scale.to(tl.float32)
        # float8e4nv has no direct cast to float16 here; go through float32.
        if dst_dtype == tl.float16:
            dst_scale = dst_scale.to(tl.float16)

    # Now upcast the tensor. For a float32 destination we widen via bfloat16
    # first (fp4 fits losslessly in bf16), then cast up at the end.
    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
    if IS_BLACKWELL:
        # Hardware path: only the packed-fp4 encoding is supported here.
        assert is_fp4
        # cvt.rn.f16x2.e2m1x2 expands one packed byte into two float16 lanes.
        packed_u32 = tl.inline_asm_elementwise(
            asm="""
            {
                .reg .b8 in_8;
                .reg .f16x2 out;
                cvt.u8.u32 in_8, $1;
                cvt.rn.f16x2.e2m1x2 out, in_8;
                mov.b32 $0, out;
            }
            """,
            constraints="=r,r",
            args=[mx_tensor],  # tl.uint8 passed in as a 32-bit reg with value in low 8 bits
            dtype=tl.uint32,
            is_pure=True,
            pack=1,
        )
        # Split the f16x2 result into its low/high 16-bit halves and bitcast.
        lo_u16 = (packed_u32 & 0xFFFF).to(tl.uint16)
        hi_u16 = (packed_u32 >> 16).to(tl.uint16)
        lo_f16 = lo_u16.to(tl.float16, bitcast=True)
        hi_f16 = hi_u16.to(tl.float16, bitcast=True)

        if intermediate_dtype == tl.float16:
            x0, x1 = lo_f16, hi_f16
        else:
            x0 = lo_f16.to(intermediate_dtype)
            x1 = hi_f16.to(intermediate_dtype)

        # Restore the original element order (even/odd lanes were packed).
        dst_tensor = tl.interleave(x0, x1)

    else:
        # Software path: decode the two e2m1 nibbles of each byte with
        # integer bit manipulation. Only the packed-fp4 encoding is handled.
        assert is_fp4
        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15  # exponent bias
        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10  # mantissa bits
        # e2m1: low nibble = element 0, high nibble = element 1.
        em0 = mx_tensor & 0x07
        em1 = mx_tensor & 0x70
        # Place exponent+mantissa and sign bits into the destination layout.
        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((mx_tensor & 0x08).to(tl.uint16) << 12)
        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((mx_tensor & 0x80).to(tl.uint16) << 8)
        # Three cases:
        # 1) x is normal and non-zero: Correct bias
        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
        # 3) x is zero, do nothing
        dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)

    dst_tensor = dst_tensor.to(dst_dtype)

    # Reshape for proper broadcasting: one scale entry per MXFP_BLOCK_SIZE-wide
    # group along the quantization dimension.
    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
    # NOTE(review): `scale` is reshaped but never used after this point in
    # this function — looks dead; confirm before removing.
    scale = scale.reshape(dst_scale.shape)

    out_tensor = dst_tensor * dst_scale * s_dec  # NVFP4 has the additional global scale factor
    # Clamp to the destination type's largest finite value so the scale
    # multiply cannot produce inf in the output.
    if dst_dtype == tl.float32:
        max_fin = 3.4028234663852886e+38
    elif dst_dtype == tl.bfloat16:
        max_fin = 3.3895313892515355e+38
    else:
        tl.static_assert(dst_dtype == tl.float16)
        max_fin = 65504
    out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin)
    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    out_tensor = out_tensor.to(dst_dtype)
    return out_tensor
standalone_inference/overlay_files/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import triton
2
+ import triton.language as tl
3
+
4
+ from .nvfp4_utils import _compute_quant_and_scale, _compute_dequant
5
+
6
@triton.jit
def fake_quantize(src_tensor, valid_src_mask, BLOCK_SIZE_OUT_DIM: tl.constexpr,
                  BLOCK_SIZE_QUANT_DIM: tl.constexpr,
                  dst_dtype: tl.constexpr,
                  mx_tensor_dtype: tl.constexpr = tl.uint8,
                  use_global_sf: tl.constexpr = True,
                  two_level_quant_P: tl.constexpr = False):
    """Quantize a tile to the MX/NVFP4 format and immediately dequantize it.

    Returns a pair: the round-tripped (quantization-noise-injected) tile, and
    the original tile cast to the round-tripped tile's dtype.
    """
    original = src_tensor
    packed, block_scale, global_dec = _compute_quant_and_scale(
        src_tensor=src_tensor,
        valid_src_mask=valid_src_mask,
        mx_tensor_dtype=mx_tensor_dtype,
        use_global_sf=use_global_sf,
        two_level_quant_P=two_level_quant_P)
    roundtripped = _compute_dequant(
        mx_tensor=packed,
        scale=block_scale,
        s_dec=global_dec,
        BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM,
        BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM,
        dst_dtype=dst_dtype)
    return roundtripped, original.to(roundtripped.dtype)
26
+
27
@triton.jit
def fake_quantize_q(Q, fake_Q, stride_z_q, stride_h_q,
                    stride_tok_q, stride_d_q,
                    fake_stride_z_q, fake_stride_h_q,
                    fake_stride_tok_q, fake_stride_d_q,
                    H, N_CTX_Q,
                    BLOCK_M: tl.constexpr,
                    HEAD_DIM: tl.constexpr,
                    use_global_sf: tl.constexpr = True):
    """Fake-quantize one BLOCK_M x HEAD_DIM tile of Q into fake_Q.

    Grid axis 0 tiles the token dimension; axis 1 enumerates (batch, head)
    pairs flattened as batch * H + head.
    """
    head_batch = tl.program_id(1)
    batch_idx = head_batch // H
    head_idx = head_batch % H
    # Advance both base pointers to this (batch, head) slice.
    Q += stride_z_q * batch_idx + stride_h_q * head_idx
    fake_Q += fake_stride_z_q * batch_idx + fake_stride_h_q * head_idx

    row_start = tl.program_id(0) * BLOCK_M
    rows = row_start + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, HEAD_DIM)
    in_bounds = rows < N_CTX_Q

    q_tile = tl.load(Q + rows[:, None] * stride_tok_q + cols[None, :] * stride_d_q,
                     mask=in_bounds[:, None], other=0.0)
    q_tile, _ = fake_quantize(src_tensor=q_tile,
                              valid_src_mask=in_bounds[:, None],
                              BLOCK_SIZE_OUT_DIM=BLOCK_M,
                              BLOCK_SIZE_QUANT_DIM=HEAD_DIM,
                              dst_dtype=q_tile.dtype,
                              use_global_sf=use_global_sf)
    tl.store(fake_Q + rows[:, None] * fake_stride_tok_q + cols[None, :] * fake_stride_d_q,
             q_tile, mask=in_bounds[:, None])
51
+
52
@triton.jit
def fake_quantize_kv(K, V, fake_K, fake_V, stride_z_kv, stride_h_kv,
                     stride_tok_kv, stride_d_kv,
                     fake_stride_z_kv, fake_stride_h_kv,
                     fake_stride_tok_kv, fake_stride_d_kv,
                     H, N_CTX_KV,
                     BLOCK_N: tl.constexpr,
                     HEAD_DIM: tl.constexpr,
                     use_global_sf: tl.constexpr = True):
    """Fake-quantize one BLOCK_N x HEAD_DIM tile of K and V into fake_K/fake_V.

    K and V share one set of source strides; the fake buffers use their own
    strides. Grid axis 1 enumerates flattened (batch, head) pairs.
    """
    head_batch = tl.program_id(1)
    batch_idx = head_batch // H
    head_idx = head_batch % H
    src_offset = stride_z_kv * batch_idx + stride_h_kv * head_idx
    dst_offset = fake_stride_z_kv * batch_idx + fake_stride_h_kv * head_idx
    K += src_offset
    V += src_offset
    fake_K += dst_offset
    fake_V += dst_offset

    row_start = tl.program_id(0) * BLOCK_N
    rows = row_start + tl.arange(0, BLOCK_N)
    cols = tl.arange(0, HEAD_DIM)
    in_bounds = rows < N_CTX_KV

    src_ptrs = rows[:, None] * stride_tok_kv + cols[None, :] * stride_d_kv
    k_tile = tl.load(K + src_ptrs, mask=in_bounds[:, None], other=0.0)
    v_tile = tl.load(V + src_ptrs, mask=in_bounds[:, None], other=0.0)
    k_fq, _ = fake_quantize(src_tensor=k_tile,
                            valid_src_mask=in_bounds[:, None],
                            BLOCK_SIZE_OUT_DIM=BLOCK_N,
                            BLOCK_SIZE_QUANT_DIM=HEAD_DIM,
                            dst_dtype=k_tile.dtype,
                            use_global_sf=use_global_sf)
    v_fq, _ = fake_quantize(src_tensor=v_tile,
                            valid_src_mask=in_bounds[:, None],
                            BLOCK_SIZE_OUT_DIM=BLOCK_N,
                            BLOCK_SIZE_QUANT_DIM=HEAD_DIM,
                            dst_dtype=v_tile.dtype,
                            use_global_sf=use_global_sf)
    dst_ptrs = rows[:, None] * fake_stride_tok_kv + cols[None, :] * fake_stride_d_kv
    tl.store(fake_K + dst_ptrs, k_fq, mask=in_bounds[:, None])
    tl.store(fake_V + dst_ptrs, v_fq, mask=in_bounds[:, None])
standalone_inference/overlay_files/fastvideo/api/compat.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from __future__ import annotations
3
+
4
+ from collections.abc import Mapping
5
+ from copy import deepcopy
6
+ from dataclasses import fields, is_dataclass
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from fastvideo.api.overrides import apply_overrides, parse_cli_overrides
11
+ from fastvideo.api.parser import config_to_dict, load_raw_config, parse_config
12
+ from fastvideo.api.schema import (
13
+ GenerationRequest,
14
+ GeneratorConfig,
15
+ InputConfig,
16
+ OutputConfig,
17
+ RequestRuntimeConfig,
18
+ SamplingConfig,
19
+ )
20
+ from fastvideo.configs.sample import SamplingParam
21
+ from fastvideo.fastvideo_args import FastVideoArgs
22
+ from fastvideo.utils import shallow_asdict
23
+
24
+ _EXPLICIT_REQUEST_ATTR = "_fastvideo_explicit_request"
25
+ _INPUT_FIELD_NAMES = {field.name for field in fields(InputConfig)}
26
+ _SAMPLING_FIELD_NAMES = {field.name for field in fields(SamplingConfig)}
27
+ _RUNTIME_FIELD_NAMES = {field.name for field in fields(RequestRuntimeConfig)}
28
+ _OUTPUT_FIELD_NAMES = {field.name for field in fields(OutputConfig)}
29
+ _MISSING = object()
30
+ _LEGACY_REQUEST_ALIASES = {
31
+ "neg_prompt": "negative_prompt",
32
+ }
33
+ _REQUEST_PIPELINE_OVERRIDE_FIELDS = frozenset({
34
+ "embedded_cfg_scale",
35
+ })
36
+
37
+
38
def normalize_generator_config(config: GeneratorConfig | Mapping[str, Any], ) -> GeneratorConfig:
    """Coerce *config* into a GeneratorConfig, parsing a raw mapping if needed."""
    if isinstance(config, GeneratorConfig):
        return config
    parsed = parse_config(GeneratorConfig, config)
    return parsed
42
+
43
+
44
def load_generator_config_from_file(
    path: str | Path,
    overrides: list[str] | Mapping[str, Any] | None = None,
) -> GeneratorConfig:
    """Load a GeneratorConfig from a config file, applying optional overrides.

    Two document shapes are accepted: a run/serve config carrying a top-level
    ``generator`` section, or a bare generator config. For a bare config,
    overrides written with a uniform ``generator.`` prefix are unprefixed
    before being applied.
    """
    raw = load_raw_config(path)
    parsed_overrides = _normalize_overrides(overrides)

    if _looks_like_run_or_serve_config(raw):
        # Run/serve document: overrides address the whole document, then we
        # parse only the generator section.
        if parsed_overrides:
            raw = apply_overrides(raw, parsed_overrides)
        return parse_config(GeneratorConfig, raw["generator"])

    if parsed_overrides:
        prefix = "generator."
        if all(key.startswith(prefix) for key in parsed_overrides):
            parsed_overrides = {
                key[len(prefix):]: value
                for key, value in parsed_overrides.items()
            }
        raw = apply_overrides(raw, parsed_overrides)

    return parse_config(GeneratorConfig, raw)
63
+
64
+
65
def legacy_from_pretrained_to_config(
    model_path: str,
    kwargs: Mapping[str, Any],
) -> GeneratorConfig:
    """Translate legacy ``from_pretrained``-style keyword arguments into the
    structured GeneratorConfig schema.

    Known legacy keys are routed to their new nested sections (engine,
    parallelism, offload, compile, quantization, pipeline, components); any
    unrecognized key is preserved under ``pipeline.experimental`` so older
    call sites keep working.

    Args:
        model_path: model identifier/path; always stored at the top level.
        kwargs: the legacy keyword arguments to translate.

    Returns:
        The parsed GeneratorConfig built from the translated mapping.
    """
    # Accumulators for each nested section of the new schema; only non-empty
    # sections are attached to the final raw mapping.
    raw: dict[str, Any] = {"model_path": model_path}
    engine: dict[str, Any] = {}
    parallelism: dict[str, Any] = {}
    offload: dict[str, Any] = {}
    compile_config: dict[str, Any] = {}
    pipeline: dict[str, Any] = {}
    components: dict[str, Any] = {}
    quantization: dict[str, Any] = {}
    experimental: dict[str, Any] = {}

    for key, value in kwargs.items():
        if key == "revision":
            raw["revision"] = value
        elif key == "trust_remote_code":
            raw["trust_remote_code"] = value
        elif key == "num_gpus":
            engine["num_gpus"] = value
        elif key == "distributed_executor_backend":
            # Renamed in the new schema.
            engine["execution_backend"] = value
        elif key in {"tp_size", "sp_size", "hsdp_replicate_dim", "hsdp_shard_dim", "dist_timeout"}:
            parallelism[key] = value
        elif key == "dit_cpu_offload":
            offload["dit"] = value
        elif key == "dit_layerwise_offload":
            offload["dit_layerwise"] = value
        elif key == "text_encoder_cpu_offload":
            offload["text_encoder"] = value
        elif key == "image_encoder_cpu_offload":
            offload["image_encoder"] = value
        elif key == "vae_cpu_offload":
            offload["vae"] = value
        elif key == "pin_cpu_memory":
            offload["pin_cpu_memory"] = value
        elif key == "enable_torch_compile":
            compile_config["enabled"] = value
        elif key == "torch_compile_kwargs":
            # Deep-copied so later mutation of the caller's dict cannot leak in.
            compile_config["kwargs"] = deepcopy(value)
        elif key in {"enable_stage_verification", "use_fsdp_inference", "disable_autocast"}:
            engine[key] = value
        elif key == "override_text_encoder_quant":
            quantization["text_encoder_quant"] = value
        elif key == "transformer_quant":
            quantization["transformer_quant"] = value
        elif key == "workload_type":
            pipeline["workload_type"] = value
        elif key == "lora_path":
            components["lora_path"] = value
        elif key == "override_pipeline_cls_name":
            components["override_pipeline_cls_name"] = value
        elif key == "override_transformer_cls_name":
            components["override_transformer_cls_name"] = value
        elif key == "pipeline_config":
            # A string is a path to a pipeline config; anything else (e.g. an
            # in-memory config object) is stashed under experimental.
            if isinstance(value, str):
                components["pipeline_config_path"] = value
            else:
                experimental[key] = deepcopy(value)
        elif key == "override_text_encoder_safetensors":
            components["text_encoder_weights"] = value
        elif key == "init_weights_from_safetensors":
            components["transformer_weights"] = value
        elif key == "init_weights_from_safetensors_2":
            components["transformer_2_weights"] = value
        else:
            # Unknown legacy key: preserve under pipeline.experimental.
            experimental[key] = deepcopy(value)

    # Attach only the sections that actually received values.
    if parallelism:
        engine["parallelism"] = parallelism
    if offload:
        engine["offload"] = offload
    if compile_config:
        engine["compile"] = compile_config
    if quantization:
        engine["quantization"] = quantization
    if engine:
        raw["engine"] = engine

    if components:
        pipeline["components"] = components
    if experimental:
        pipeline["experimental"] = experimental
    if pipeline:
        raw["pipeline"] = pipeline

    return parse_config(GeneratorConfig, raw)
153
+
154
+
155
def generator_config_to_fastvideo_args(config: GeneratorConfig | Mapping[str, Any], ) -> FastVideoArgs:
    """Convert a GeneratorConfig (or its raw mapping) into FastVideoArgs.

    Fields the compatibility adapter cannot express in FastVideoArgs are
    rejected with NotImplementedError rather than silently dropped.

    Args:
        config: a GeneratorConfig instance or a raw mapping parseable as one.

    Returns:
        FastVideoArgs built via ``FastVideoArgs.from_kwargs``.

    Raises:
        NotImplementedError: if any unsupported field is set (listed below).
    """
    normalized = normalize_generator_config(config)
    # Collect every unsupported-but-set field so the error names all of them
    # at once instead of failing one at a time.
    unsupported = []
    if normalized.pipeline.profile is not None:
        unsupported.append("pipeline.profile")
    if normalized.pipeline.profile_version is not None:
        unsupported.append("pipeline.profile_version")
    if normalized.pipeline.components.config_root is not None:
        unsupported.append("pipeline.components.config_root")
    if normalized.pipeline.components.vae_weights is not None:
        unsupported.append("pipeline.components.vae_weights")
    if normalized.pipeline.components.upsampler_weights is not None:
        unsupported.append("pipeline.components.upsampler_weights")
    if unsupported:
        joined = ", ".join(unsupported)
        raise NotImplementedError(f"VideoGenerator compatibility adapter does not support {joined} yet")

    engine = normalized.engine
    # Unconditional mappings: these FastVideoArgs kwargs always exist on the
    # structured config.
    kwargs: dict[str, Any] = {
        "model_path": normalized.model_path,
        "revision": normalized.revision,
        "trust_remote_code": normalized.trust_remote_code,
        "num_gpus": engine.num_gpus,
        "distributed_executor_backend": engine.execution_backend,
        "tp_size": engine.parallelism.tp_size,
        "sp_size": engine.parallelism.sp_size,
        "hsdp_replicate_dim": engine.parallelism.hsdp_replicate_dim,
        "hsdp_shard_dim": engine.parallelism.hsdp_shard_dim,
        "dist_timeout": engine.parallelism.dist_timeout,
        "dit_cpu_offload": engine.offload.dit,
        "dit_layerwise_offload": engine.offload.dit_layerwise,
        "text_encoder_cpu_offload": engine.offload.text_encoder,
        "image_encoder_cpu_offload": engine.offload.image_encoder,
        "vae_cpu_offload": engine.offload.vae,
        "pin_cpu_memory": engine.offload.pin_cpu_memory,
        "enable_torch_compile": engine.compile.enabled,
        # Deep copy so FastVideoArgs cannot mutate the config's dict.
        "torch_compile_kwargs": deepcopy(engine.compile.kwargs),
        "enable_stage_verification": engine.enable_stage_verification,
        "use_fsdp_inference": engine.use_fsdp_inference,
        "disable_autocast": engine.disable_autocast,
    }
    if normalized.pipeline.workload_type is not None:
        kwargs["workload_type"] = normalized.pipeline.workload_type

    # Optional quantization overrides.
    quantization = engine.quantization
    if quantization is not None and quantization.text_encoder_quant is not None:
        kwargs["override_text_encoder_quant"] = quantization.text_encoder_quant
    if quantization is not None and quantization.transformer_quant is not None:
        kwargs["transformer_quant"] = quantization.transformer_quant

    # Optional component overrides, mapped back to their legacy kwarg names.
    components = normalized.pipeline.components
    if components.pipeline_config_path is not None:
        kwargs["pipeline_config"] = components.pipeline_config_path
    if components.lora_path is not None:
        kwargs["lora_path"] = components.lora_path
    if components.override_pipeline_cls_name is not None:
        kwargs["override_pipeline_cls_name"] = components.override_pipeline_cls_name
    if components.override_transformer_cls_name is not None:
        kwargs["override_transformer_cls_name"] = components.override_transformer_cls_name
    if components.text_encoder_weights is not None:
        kwargs["override_text_encoder_safetensors"] = components.text_encoder_weights
    if components.transformer_weights is not None:
        kwargs["init_weights_from_safetensors"] = components.transformer_weights
    if components.transformer_2_weights is not None:
        kwargs["init_weights_from_safetensors_2"] = components.transformer_2_weights

    # Pass-through escape hatches; applied last so they win over the mapped
    # fields above.
    kwargs.update(deepcopy(normalized.pipeline.profile_overrides))
    kwargs.update(deepcopy(normalized.pipeline.experimental))
    return FastVideoArgs.from_kwargs(**kwargs)
224
+
225
+
226
def normalize_generation_request(request: GenerationRequest | Mapping[str, Any], ) -> GenerationRequest:
    """Coerce *request* into a GenerationRequest and stamp it with a snapshot
    of its explicitly-provided fields (used later to distinguish user-set
    values from defaults)."""
    if isinstance(request, GenerationRequest):
        normalized = request
    else:
        normalized = parse_config(GenerationRequest, request)

    if not hasattr(normalized, _EXPLICIT_REQUEST_ATTR):
        snapshot = _serialize_generation_request(normalized)
        setattr(normalized, _EXPLICIT_REQUEST_ATTR, snapshot)
    return normalized
232
+
233
+
234
def legacy_generate_call_to_request(
    prompt: str | None,
    sampling_param: SamplingParam | None,
    *,
    mouse_cond: Any | None = None,
    keyboard_cond: Any | None = None,
    grid_sizes: Any | None = None,
    legacy_kwargs: Mapping[str, Any] | None = None,
) -> GenerationRequest:
    """Build a GenerationRequest from a legacy ``generate_video`` call.

    The sampling param is serialized first, then legacy kwargs are routed to
    their schema sections, and the optional conditioning tensors are attached
    under ``inputs``. The raw mapping is recorded on the request as the
    explicit-field snapshot.
    """
    raw = _sampling_param_to_request_raw(sampling_param)
    if prompt is not None:
        raw["prompt"] = prompt

    for key, value in (legacy_kwargs or {}).items():
        _apply_request_field(raw, key, value)

    # Optional conditioning inputs: only attach what the caller supplied.
    conditioning = (
        ("mouse_cond", mouse_cond),
        ("keyboard_cond", keyboard_cond),
        ("grid_sizes", grid_sizes),
    )
    for name, cond in conditioning:
        if cond is not None:
            raw.setdefault("inputs", {})[name] = cond

    normalized = parse_config(GenerationRequest, raw)
    setattr(normalized, _EXPLICIT_REQUEST_ATTR, deepcopy(raw))
    return normalized
260
+
261
+
262
def request_to_sampling_param(
    request: GenerationRequest,
    *,
    model_path: str,
) -> SamplingParam:
    """Materialize a SamplingParam for *model_path* from a request's
    explicitly-set fields.

    Raises:
        NotImplementedError: if the request carries a plan or state, which
            this adapter cannot forward.
        ValueError: if an explicitly-set field is neither a sampling-param
            attribute, a pipeline-level override, nor a restated default.
    """
    if request.plan is not None:
        raise NotImplementedError("GenerationRequest.plan is not wired into VideoGenerator yet")
    if request.state is not None:
        raise NotImplementedError("GenerationRequest.state is not wired into VideoGenerator yet")

    sampling_param = SamplingParam.from_pretrained(model_path)

    for key, value in _explicit_request_updates(request).items():
        if hasattr(sampling_param, key):
            setattr(sampling_param, key, deepcopy(value))
            continue
        # Pipeline-level overrides and fields that merely restate defaults
        # are handled elsewhere / harmless, so they are skipped here.
        if key in _REQUEST_PIPELINE_OVERRIDE_FIELDS or _is_supported_as_default_only(key, value):
            continue
        raise ValueError(f"Request field {key!r} is not supported by sampling params for {model_path}")

    sampling_param.__post_init__()
    sampling_param.check_sampling_param()
    return sampling_param
286
+
287
+
288
def expand_request_prompt_batch(request: GenerationRequest, ) -> list[GenerationRequest]:
    """Fan a batched request (list-valued ``prompt``) out into one request
    per prompt, slicing batched image/video inputs to match.

    A request whose prompt is not a list is returned unchanged, as a
    singleton list.
    """
    prompts = request.prompt
    if not isinstance(prompts, list):
        return [request]

    expanded: list[GenerationRequest] = []
    for index, single_prompt in enumerate(prompts):
        clone = deepcopy(request)
        clone.prompt = single_prompt
        for field_name in ("image_path", "video_path"):
            _fan_out_batched_input_value(request, clone, field_name, index)
        _fan_out_explicit_request_metadata(request, clone, index, single_prompt)
        expanded.append(clone)
    return expanded
301
+
302
+
303
+ def _looks_like_run_or_serve_config(raw: Mapping[str, Any]) -> bool:
304
+ return isinstance(raw.get("generator"), Mapping)
305
+
306
+
307
+ def _normalize_overrides(overrides: list[str] | Mapping[str, Any] | None, ) -> dict[str, Any] | None:
308
+ if not overrides:
309
+ return None
310
+ if isinstance(overrides, list):
311
+ return parse_cli_overrides(overrides)
312
+ return dict(overrides)
313
+
314
+
315
+ def _sampling_param_to_request_raw(sampling_param: SamplingParam | None, ) -> dict[str, Any]:
316
+ if sampling_param is None:
317
+ return {}
318
+
319
+ raw: dict[str, Any] = {}
320
+ for key, value in shallow_asdict(sampling_param).items():
321
+ if key == "prompt":
322
+ continue
323
+ _apply_request_field(raw, key, deepcopy(value))
324
+ return raw
325
+
326
+
327
def _apply_request_field(
    raw: dict[str, Any],
    key: str,
    value: Any,
) -> None:
    """Place one (key, value) pair into the correct section of a raw
    GenerationRequest mapping.

    Legacy aliases are resolved first; ``negative_prompt`` stays top-level;
    known fields go to their declared section; everything else lands in
    ``extensions``.
    """
    key = _LEGACY_REQUEST_ALIASES.get(key, key)
    if key == "negative_prompt":
        raw["negative_prompt"] = value
        return

    # Sections are checked in declaration order; field-name sets are disjoint
    # enough that the first match wins.
    sections = (
        ("inputs", _INPUT_FIELD_NAMES),
        ("sampling", _SAMPLING_FIELD_NAMES),
        ("runtime", _RUNTIME_FIELD_NAMES),
        ("output", _OUTPUT_FIELD_NAMES),
    )
    for section_name, field_names in sections:
        if key in field_names:
            raw.setdefault(section_name, {})[key] = value
            return

    raw.setdefault("extensions", {})[key] = value
349
+
350
+
351
def request_to_pipeline_overrides(request: GenerationRequest) -> dict[str, Any]:
    """Extract the explicitly-set request fields that must be applied as
    pipeline-level overrides rather than sampling params."""
    return {
        key: deepcopy(value)
        for key, value in _explicit_request_updates(request).items()
        if key in _REQUEST_PIPELINE_OVERRIDE_FIELDS
    }
357
+
358
+
359
def _explicit_request_updates(request: GenerationRequest) -> dict[str, Any]:
    """Return the flattened explicitly-set fields of *request*, falling back
    to a fresh serialization when no explicit snapshot was recorded."""
    snapshot = getattr(request, _EXPLICIT_REQUEST_ATTR, None)
    source = _serialize_generation_request(request) if snapshot is None else snapshot
    return _extract_request_updates(source)
365
+
366
+
367
+ def _extract_request_updates(raw: Mapping[str, Any]) -> dict[str, Any]:
368
+ updates: dict[str, Any] = {}
369
+ if "negative_prompt" in raw:
370
+ updates["negative_prompt"] = deepcopy(raw["negative_prompt"])
371
+
372
+ for section_name in ("inputs", "sampling", "runtime", "output"):
373
+ section = raw.get(section_name)
374
+ if not isinstance(section, Mapping):
375
+ continue
376
+ for key, value in section.items():
377
+ updates[key] = deepcopy(value)
378
+
379
+ stage_overrides = raw.get("stage_overrides")
380
+ if stage_overrides:
381
+ updates.update(_flatten_stage_overrides(stage_overrides))
382
+
383
+ extensions = raw.get("extensions")
384
+ if isinstance(extensions, Mapping):
385
+ for key, value in extensions.items():
386
+ updates[key] = deepcopy(value)
387
+
388
+ return updates
389
+
390
+
391
+ def _flatten_stage_overrides(stage_overrides: Any) -> dict[str, Any]:
392
+ if not isinstance(stage_overrides, Mapping):
393
+ raise ValueError("GenerationRequest.stage_overrides must be a mapping")
394
+
395
+ flattened: dict[str, Any] = {}
396
+ for stage_name, overrides in stage_overrides.items():
397
+ if not isinstance(overrides, Mapping):
398
+ raise ValueError(f"GenerationRequest.stage_overrides.{stage_name} must be a mapping")
399
+ for key, value in overrides.items():
400
+ if key in flattened and flattened[key] != value:
401
+ raise ValueError(f"Conflicting stage override for {key!r} across stages")
402
+ flattened[key] = deepcopy(value)
403
+ return flattened
404
+
405
+
406
def _serialize_generation_request(request: GenerationRequest) -> dict[str, Any]:
    """Snapshot *request* as a plain dict, fully decoupled from the original."""
    serialized = config_to_dict(request)
    return deepcopy(serialized)
408
+
409
+
410
+ def _fan_out_batched_input_value(
411
+ source_request: GenerationRequest,
412
+ target_request: GenerationRequest,
413
+ field_name: str,
414
+ index: int,
415
+ ) -> None:
416
+ value = getattr(source_request.inputs, field_name)
417
+ if not isinstance(value, list):
418
+ return
419
+ _validate_batched_input_length(source_request.prompt, value, field_name)
420
+ setattr(target_request.inputs, field_name, deepcopy(value[index]))
421
+
422
+
423
def _fan_out_explicit_request_metadata(
    source_request: GenerationRequest,
    target_request: GenerationRequest,
    index: int,
    prompt: str,
) -> None:
    """Rewrite the explicit-field snapshot of *target_request* for one entry
    of a prompt batch: fix the prompt and slice batched image/video inputs.

    A source request without a snapshot is left untouched.
    """
    snapshot = getattr(source_request, _EXPLICIT_REQUEST_ATTR, None)
    if snapshot is None:
        return

    snapshot = deepcopy(snapshot)
    snapshot["prompt"] = prompt
    inputs_section = snapshot.get("inputs")
    if isinstance(inputs_section, dict):
        for field_name in ("image_path", "video_path"):
            batched = inputs_section.get(field_name)
            if isinstance(batched, list):
                _validate_batched_input_length(source_request.prompt, batched, field_name)
                inputs_section[field_name] = deepcopy(batched[index])

    setattr(target_request, _EXPLICIT_REQUEST_ATTR, snapshot)
444
+
445
+
446
+ def _validate_batched_input_length(
447
+ prompts: str | list[str] | None,
448
+ values: list[Any],
449
+ field_name: str,
450
+ ) -> None:
451
+ if not isinstance(prompts, list):
452
+ return
453
+ if len(values) != len(prompts):
454
+ raise ValueError(f"GenerationRequest.inputs.{field_name} must have the same length as request.prompt")
455
+
456
+
457
def _is_supported_as_default_only(key: str, value: Any) -> bool:
    """True when *key* is a known request field and *value* merely restates
    its default — such fields are safe to ignore."""
    default_value = _DEFAULT_REQUEST_UPDATES.get(key, _MISSING)
    if default_value is _MISSING:
        return False
    return _values_equal(value, default_value)
460
+
461
+
462
+ def _collect_non_default_fields(
463
+ value: Any,
464
+ default: Any,
465
+ ) -> dict[str, Any]:
466
+ if not (is_dataclass(value) and is_dataclass(default)):
467
+ return {}
468
+
469
+ result: dict[str, Any] = {}
470
+ for field in fields(value):
471
+ current = getattr(value, field.name)
472
+ default_value = getattr(default, field.name)
473
+ if is_dataclass(current) and is_dataclass(default_value):
474
+ nested = _collect_non_default_fields(current, default_value)
475
+ if nested:
476
+ result[field.name] = nested
477
+ continue
478
+ if not _values_equal(current, default_value):
479
+ result[field.name] = deepcopy(current)
480
+ return result
481
+
482
+
483
+ def _values_equal(left: Any, right: Any) -> bool:
484
+ if left is right:
485
+ return True
486
+ try:
487
+ return bool(left == right)
488
+ except Exception:
489
+ return False
490
+
491
+
492
# Baseline of request-level field values produced by a default-constructed
# GenerationRequest; _is_supported_as_default_only() compares against this so
# request fields that merely restate defaults can be silently accepted.
_DEFAULT_REQUEST_UPDATES = _extract_request_updates(config_to_dict(GenerationRequest()))

# Public API of this compatibility-adapter module.
__all__ = [
    "generator_config_to_fastvideo_args",
    "legacy_from_pretrained_to_config",
    "legacy_generate_call_to_request",
    "load_generator_config_from_file",
    "normalize_generation_request",
    "normalize_generator_config",
    "request_to_pipeline_overrides",
    "request_to_sampling_param",
]
standalone_inference/overlay_files/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Sparse FP4 Attention backend with the independent ours-P quant kernel."""
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+
10
+ from fastvideo_kernel.triton_kernels.quant_utils import (
11
+ fake_quantize_q,
12
+ fake_quantize_kv,
13
+ )
14
+ from fastvideo_kernel.block_sparse_attn_ours_p import block_sparse_attn_ours_p
15
+ from fastvideo.forward_context import get_forward_context
16
+
17
+ from fastvideo.attention.backends.abstract import (
18
+ AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder,
19
+ )
20
+ from fastvideo.attention.backends.video_sparse_attn import (
21
+ VideoSparseAttentionMetadata,
22
+ VideoSparseAttentionMetadataBuilder,
23
+ VSA_TILE_SIZE,
24
+ )
25
+ from fastvideo.distributed import get_sp_group
26
+ from fastvideo.logger import init_logger
27
+
28
+ logger = init_logger(__name__)
29
+
30
+
31
+ def _dense_sdpa_blhd(query, key, value):
32
+ q = query.transpose(1, 2)
33
+ k = key.transpose(1, 2)
34
+ v = value.transpose(1, 2)
35
+ out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
36
+ return out.transpose(1, 2)
37
+
38
+
39
+ def _quantize_qkv_bhld(q, k, v):
40
+ """FP4 fake quantize Q/K/V in BHLD layout, same as attn_qat_train."""
41
+ H = q.shape[1]
42
+ N_Q = q.shape[2]
43
+ N_KV = k.shape[2]
44
+ D = q.shape[3]
45
+ BLOCK = 32
46
+
47
+ fake_q = torch.empty_like(q)
48
+ fake_k = torch.empty_like(k)
49
+ fake_v = torch.empty_like(v)
50
+
51
+ grid_q = (triton.cdiv(N_Q, BLOCK), q.shape[0] * H, 1)
52
+ grid_kv = (triton.cdiv(N_KV, BLOCK), q.shape[0] * H, 1)
53
+
54
+ fake_quantize_q[grid_q](
55
+ q, fake_q,
56
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
57
+ fake_q.stride(0), fake_q.stride(1), fake_q.stride(2), fake_q.stride(3),
58
+ H, N_Q, BLOCK_M=BLOCK, HEAD_DIM=D, use_global_sf=False,
59
+ )
60
+ fake_quantize_kv[grid_kv](
61
+ k, v, fake_k, fake_v,
62
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
63
+ fake_k.stride(0), fake_k.stride(1), fake_k.stride(2), fake_k.stride(3),
64
+ H, N_KV, BLOCK_N=BLOCK, HEAD_DIM=D, use_global_sf=False,
65
+ )
66
+ return fake_q, fake_k, fake_v
67
+
68
+
69
+ class SparseFP4OursPAttentionBackend(AttentionBackend):
70
+ accept_output_buffer: bool = True
71
+
72
+ @staticmethod
73
+ def get_supported_head_sizes() -> list[int]:
74
+ return [64, 96, 128, 160, 192, 224, 256]
75
+
76
+ @staticmethod
77
+ def get_name() -> str:
78
+ return "SPARSE_FP4_OURS_P_ATTN"
79
+
80
+ @staticmethod
81
+ def get_impl_cls() -> type["SparseFP4OursPAttentionImpl"]:
82
+ return SparseFP4OursPAttentionImpl
83
+
84
+ @staticmethod
85
+ def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
86
+ return VideoSparseAttentionMetadata
87
+
88
+ @staticmethod
89
+ def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
90
+ return VideoSparseAttentionMetadataBuilder
91
+
92
+
93
+ class SparseFP4OursPAttentionImpl(AttentionImpl):
94
+
95
+ def __init__(self, num_heads, head_size, causal, softmax_scale,
96
+ num_kv_heads=None, prefix="", **extra):
97
+ self.prefix = prefix
98
+ self.sp_size = get_sp_group().world_size
99
+
100
+ def tile(self, x, num_tiles, tile_partition_indices, non_pad_index):
101
+ t_p = num_tiles[0] * VSA_TILE_SIZE[0]
102
+ h_p = num_tiles[1] * VSA_TILE_SIZE[1]
103
+ w_p = num_tiles[2] * VSA_TILE_SIZE[2]
104
+ out = torch.zeros(
105
+ (x.shape[0], t_p * h_p * w_p, x.shape[-2], x.shape[-1]),
106
+ device=x.device, dtype=x.dtype,
107
+ )
108
+ out[:, non_pad_index] = x[:, tile_partition_indices]
109
+ return out
110
+
111
+ def untile(self, x, reverse_tile_partition_indices, non_pad_index):
112
+ return x[:, non_pad_index][:, reverse_tile_partition_indices]
113
+
114
+ def _is_force_dense(self) -> bool:
115
+ ctx = get_forward_context()
116
+ return ctx.force_dense
117
+
118
+ def preprocess_qkv(self, qkv, attn_metadata):
119
+ if attn_metadata is None or self._is_force_dense():
120
+ return qkv
121
+ return self.tile(qkv, attn_metadata.num_tiles,
122
+ attn_metadata.tile_partition_indices,
123
+ attn_metadata.non_pad_index)
124
+
125
+ def postprocess_output(self, output, attn_metadata):
126
+ if attn_metadata is None or self._is_force_dense():
127
+ return output
128
+ return self.untile(output,
129
+ attn_metadata.reverse_tile_partition_indices,
130
+ attn_metadata.non_pad_index)
131
+
132
+ def forward(self, query, key, value,
133
+ gate_compress_or_metadata=None, attn_metadata=None):
134
+ # Handle both call conventions
135
+ if attn_metadata is None and isinstance(
136
+ gate_compress_or_metadata, (VideoSparseAttentionMetadata, type(None))):
137
+ attn_metadata = gate_compress_or_metadata
138
+
139
+ # ── force_dense: true dense BF16 SDPA (for teacher in distillation) ──
140
+ ctx = get_forward_context()
141
+ if ctx.force_dense:
142
+ return _dense_sdpa_blhd(query, key, value)
143
+
144
+ is_cross = query.shape[1] != key.shape[1]
145
+
146
+ # ── Cross-attention/no metadata: keep dense. The sparse VSA metadata only
147
+ # applies to tiled video self-attention.
148
+ if attn_metadata is None or is_cross:
149
+ return _dense_sdpa_blhd(query, key, value)
150
+
151
+ # ── Self-attention: FP4 quant Q/K/V + block-sparse attention ──
152
+ # BLHD → BHLD
153
+ q = query.transpose(1, 2).contiguous()
154
+ k = key.transpose(1, 2).contiguous()
155
+ v = value.transpose(1, 2).contiguous()
156
+
157
+ # Step 1: FP4 fake quantize Q/K/V with STE (straight-through estimator)
158
+ with torch.no_grad():
159
+ fq, fk, fv = _quantize_qkv_bhld(q, k, v)
160
+ # STE: forward uses quantized values, backward passes gradient through as-is
161
+ fq = q + (fq - q).detach()
162
+ fk = k + (fk - k).detach()
163
+ fv = v + (fv - v).detach()
164
+
165
+ # Step 2: Build sparse block map
166
+ B, H, S, D = fq.shape
167
+ block_elements = math.prod(VSA_TILE_SIZE)
168
+ num_blocks = S // block_elements
169
+
170
+ VSA_sparsity = attn_metadata.VSA_sparsity
171
+ cur_topk = max(1, math.ceil((1 - VSA_sparsity) * num_blocks))
172
+ logger.info(f"[SFP4] S={S} num_blocks={num_blocks} sparsity={VSA_sparsity} topk={cur_topk}/{num_blocks}")
173
+
174
+ block_sizes = attn_metadata.variable_block_sizes.to(
175
+ device=fq.device, dtype=torch.float32).clamp_min(1)
176
+ block_sizes = block_sizes.view(1, 1, num_blocks, 1)
177
+ q_c = (fq.view(B, H, num_blocks, block_elements, D).float().sum(3) /
178
+ block_sizes).to(fq.dtype)
179
+ k_c = (fk.view(B, H, num_blocks, block_elements, D).float().sum(3) /
180
+ block_sizes).to(fk.dtype)
181
+ v_c = (fv.view(B, H, num_blocks, block_elements, D).float().sum(3) /
182
+ block_sizes).to(fv.dtype)
183
+ scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (D ** 0.5)
184
+ topk_idx = torch.topk(scores, cur_topk, dim=-1).indices
185
+ block_map = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)
186
+
187
+ # Step 3: Block-sparse attention with independent group-local P quant.
188
+ out, _ = block_sparse_attn_ours_p(fq, fk, fv, block_map,
189
+ attn_metadata.variable_block_sizes,
190
+ q_c, k_c, v_c)
191
+
192
+ return out.transpose(1, 2) # BHLD → BLHD
standalone_inference/overlay_files/fastvideo/attention/backends/video_sparse_attn.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import functools
3
+ import math
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+
8
+ try:
9
+ from fastvideo_kernel import video_sparse_attn
10
+ except ImportError:
11
+ video_sparse_attn = None
12
+
13
+ from typing import Any
14
+
15
+ from fastvideo.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata,
16
+ AttentionMetadataBuilder)
17
+ from fastvideo.distributed import get_sp_group
18
+ from fastvideo.logger import init_logger
19
+
20
+ logger = init_logger(__name__)
21
+ VSA_TILE_SIZE = (4, 4, 4)
22
+
23
+
24
@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Flat token permutation that groups a (T, H, W) grid into tiles.

    Tiles are enumerated in (t, h, w) order; edge tiles may be partial
    (tensor slicing clamps automatically). Cached per (shape, tile, device);
    callers must treat the result as read-only.
    """
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    grid = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
    chunks = [
        grid[t0:t0 + ts, h0:h0 + hs, w0:w0 + ws].flatten()
        for t0 in range(0, T, ts)
        for h0 in range(0, H, hs)
        for w0 in range(0, W, ws)
    ]
    return torch.cat(chunks, dim=0)
41
+
42
+
43
@functools.lru_cache(maxsize=10)
def get_reverse_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Inverse permutation of :func:`get_tile_partition_indices`."""
    forward_order = get_tile_partition_indices(dit_seq_shape, tile_size, device)
    return torch.argsort(forward_order)
50
+
51
+
52
@functools.lru_cache(maxsize=10)
def construct_variable_block_sizes(
    dit_seq_shape: tuple[int, int, int],
    num_tiles: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """Count the valid (non-padded) tokens inside every VSA tile.

    Tiles are flattened in (t-tile, h-tile, w-tile) order, matching the
    layout produced by the tile partition indices.

    Returns
    -------
    torch.Tensor  # shape: [n_t * n_h * n_w], dtype torch.int
    """

    def _axis_sizes(dim_len: int, tile: int, n_tiles: int) -> torch.Tensor:
        # Every tile is full-sized except possibly the trailing one.
        sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
        tail = dim_len - (n_tiles - 1) * tile
        sizes[-1] = tail if tail > 0 else tile
        return sizes

    t_sizes, h_sizes, w_sizes = (
        _axis_sizes(dim_len, tile, n)
        for dim_len, tile, n in zip(dit_seq_shape, VSA_TILE_SIZE, num_tiles)
    )

    # Outer-product of the per-axis tile sizes gives voxels per tile.
    per_tile = (t_sizes[:, None, None]
                * h_sizes[None, :, None]
                * w_sizes[None, None, :])
    return per_tile.reshape(-1)
92
+
93
+
94
@functools.lru_cache(maxsize=10)
def get_non_pad_index(
    variable_block_sizes: torch.LongTensor,
    max_block_size: int,
):
    """Flat indices of real tokens in a padded (n_blocks * max_block_size) layout.

    Cached by tensor identity (``lru_cache`` hashes tensors by id); this is
    safe here because callers pass the cached output of
    ``construct_variable_block_sizes``.
    """
    n_blocks = variable_block_sizes.shape[0]
    device = variable_block_sizes.device
    offsets = torch.arange(max_block_size, device=device)
    block_starts = torch.arange(n_blocks, device=device) * max_block_size
    padded_index = block_starts[:, None] + offsets[None, :]
    valid = offsets[None, :] < variable_block_sizes[:, None]
    return padded_index[valid]
105
+
106
+
107
class VideoSparseAttentionBackend(AttentionBackend):
    """Registry entry for the VSA (video sparse attention) backend."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # Head dims supported by the VSA kernel.
        return [64, 128]

    @staticmethod
    def get_name() -> str:
        return "VIDEO_SPARSE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["VideoSparseAttentionImpl"]:
        return VideoSparseAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
        return VideoSparseAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
        return VideoSparseAttentionMetadataBuilder
130
+
131
+
132
@dataclass
class VideoSparseAttentionMetadata(AttentionMetadata):
    """Per-denoising-step metadata consumed by the VSA attention impls."""

    # Denoising timestep this metadata was built for.
    current_timestep: int
    # Patchified latent grid (T, H, W).
    dit_seq_shape: list[int]
    # Number of VSA tiles along each axis.
    num_tiles: list[int]
    # prod(dit_seq_shape): token count before tile padding.
    total_seq_length: int
    # Permutation mapping raster token order -> tile order.
    tile_partition_indices: torch.LongTensor
    # Inverse permutation of tile_partition_indices.
    reverse_tile_partition_indices: torch.LongTensor
    # Valid (non-padded) token count per tile.
    variable_block_sizes: torch.LongTensor
    # Flat indices of real tokens within the padded tile layout.
    non_pad_index: torch.LongTensor
    # NOTE(review): the builder also passes VSA_sparsity=...; it is not
    # declared here, so it presumably lives on the AttentionMetadata base
    # class — confirm, otherwise construction raises TypeError.
142
+
143
+
144
class VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder):
    """Builds per-step :class:`VideoSparseAttentionMetadata` from the latent shape."""

    def __init__(self) -> None:
        pass

    def prepare(self) -> None:
        # No per-iteration state to reset.
        pass

    def build(  # type: ignore
        self,
        current_timestep: int,
        raw_latent_shape: tuple[int, int, int],
        patch_size: tuple[int, int, int],
        VSA_sparsity: float,
        device: torch.device,
        **kwargs: dict[str, Any],
    ) -> VideoSparseAttentionMetadata:
        """Derive tile layout and index permutations for the patchified grid.

        Args:
            current_timestep: Denoising step index, stored on the metadata.
            raw_latent_shape: (T, H, W) latent grid before patchification.
            patch_size: DiT patch size per axis (integer division is applied).
            VSA_sparsity: Fraction of key tiles to drop; forwarded to metadata.
            device: Device on which all index tensors are created.
        """
        # (Fixed: removed a dead no-op `patch_size = patch_size` statement.)
        dit_seq_shape = (raw_latent_shape[0] // patch_size[0],
                         raw_latent_shape[1] // patch_size[1],
                         raw_latent_shape[2] // patch_size[2])

        num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]),
                     math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
                     math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
        total_seq_length = math.prod(dit_seq_shape)

        tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
        reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
        variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
        non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))

        # NOTE(review): VSA_sparsity is not a declared field on the local
        # VideoSparseAttentionMetadata dataclass; this relies on the base
        # AttentionMetadata accepting it — confirm.
        return VideoSparseAttentionMetadata(
            current_timestep=current_timestep,
            dit_seq_shape=dit_seq_shape,  # type: ignore
            VSA_sparsity=VSA_sparsity,  # type: ignore
            num_tiles=num_tiles,  # type: ignore
            total_seq_length=total_seq_length,  # type: ignore
            tile_partition_indices=tile_partition_indices,  # type: ignore
            reverse_tile_partition_indices=reverse_tile_partition_indices,
            variable_block_sizes=variable_block_sizes,
            non_pad_index=non_pad_index)
184
+
185
+
186
class VideoSparseAttentionImpl(AttentionImpl):
    """Attention impl that dispatches to the fastvideo_kernel VSA kernel.

    Inputs/outputs use the (B, L, H, D) layout; the kernel call works in
    (B, H, L, D).
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # Constructor args beyond prefix are accepted for interface
        # compatibility but not stored by this impl.
        self.prefix = prefix
        sp_group = get_sp_group()
        self.sp_size = sp_group.world_size

    def tile(self, x: torch.Tensor, num_tiles: list[int], tile_partition_indices: torch.LongTensor,
             non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Scatter tokens into tile order, zero-padding partial tiles."""
        t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
        h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
        w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]

        x_padded = torch.zeros((x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
                               device=x.device,
                               dtype=x.dtype)
        x_padded[:, non_pad_index] = x[:, tile_partition_indices]
        return x_padded

    def untile(self, x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor,
               non_pad_index: torch.LongTensor) -> torch.Tensor:
        """Inverse of tile(): drop padding slots and restore raster order."""
        x = x[:, non_pad_index][:, reverse_tile_partition_indices]
        return x

    def preprocess_qkv(
        self,
        qkv: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        # Reorder QKV into the padded tile layout the VSA kernel expects.
        return self.tile(qkv, attn_metadata.num_tiles, attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        # Undo the tiling applied in preprocess_qkv.
        return self.untile(output, attn_metadata.reverse_tile_partition_indices, attn_metadata.non_pad_index)

    def forward(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        gate_compress: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        """Run video sparse attention; raises if the kernel is not installed."""
        # BLHD -> BHLD for the kernel.
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()
        gate_compress = gate_compress.transpose(1, 2).contiguous()

        VSA_sparsity = attn_metadata.VSA_sparsity

        # Number of key tiles kept per query tile at the requested sparsity.
        cur_topk = math.ceil((1 - VSA_sparsity) * (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))

        if video_sparse_attn is None:
            raise NotImplementedError("video_sparse_attn is not installed")
        hidden_states = video_sparse_attn(query,
                                          key,
                                          value,
                                          attn_metadata.variable_block_sizes,
                                          attn_metadata.variable_block_sizes,
                                          cur_topk,
                                          block_size=VSA_TILE_SIZE,
                                          compress_attn_weight=gate_compress).transpose(1, 2)

        return hidden_states
standalone_inference/overlay_files/fastvideo/configs/models/dits/base.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from dataclasses import dataclass, field
3
+ from typing import Any
4
+
5
+ from fastvideo.configs.models.base import ArchConfig, ModelConfig
6
+ from fastvideo.layers.quantization import QuantizationConfig
7
+ from fastvideo.platforms import AttentionBackendEnum
8
+
9
+
10
@dataclass
class DiTArchConfig(ArchConfig):
    """Architecture-level hyper-parameters shared by all DiT variants."""

    # Predicates selecting which submodules are FSDP-sharded / compiled.
    _fsdp_shard_conditions: list = field(default_factory=list)
    _compile_conditions: list = field(default_factory=list)
    # Weight-name remapping tables for checkpoint load/save and LoRA.
    param_names_mapping: dict = field(default_factory=dict)
    reverse_param_names_mapping: dict = field(default_factory=dict)
    lora_param_names_mapping: dict = field(default_factory=dict)
    # Attention backends this architecture can run with.
    _supported_attention_backends: tuple[AttentionBackendEnum,
                                         ...] = (AttentionBackendEnum.SAGE_ATTN, AttentionBackendEnum.FLASH_ATTN,
                                                 AttentionBackendEnum.TORCH_SDPA,
                                                 AttentionBackendEnum.VIDEO_SPARSE_ATTN,
                                                 AttentionBackendEnum.VMOBA_ATTN, AttentionBackendEnum.SAGE_ATTN_THREE,
                                                 AttentionBackendEnum.ATTN_QAT_INFER,
                                                 AttentionBackendEnum.ATTN_QAT_TRAIN, AttentionBackendEnum.SLA_ATTN,
                                                 AttentionBackendEnum.SAGE_SLA_ATTN,
                                                 AttentionBackendEnum.SPARSE_FP4_ATTN,
                                                 AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN)

    hidden_size: int = 0
    num_attention_heads: int = 0
    num_channels_latents: int = 0
    in_channels: int | None = 0
    out_channels: int | None = 0
    patch_size: int | tuple[int, int, int] | None = None
    expand_timesteps: bool = False
    num_layers: int = 0
    ffn_dim: int = 0
    # Layer names excluded from LoRA adaptation.
    exclude_lora_layers: list[str] = field(default_factory=list)
    boundary_ratio: float | None = None

    def __post_init__(self) -> None:
        # Default the compile conditions to the FSDP shard conditions.
        if not self._compile_conditions:
            self._compile_conditions = self._fsdp_shard_conditions.copy()
43
+
44
+
45
@dataclass
class DiTConfig(ModelConfig):
    """Top-level DiT model configuration wrapping a :class:`DiTArchConfig`."""

    arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)

    # FastVideoDiT-specific parameters
    prefix: str = ""
    quant_config: QuantizationConfig | None = None
    expand_timesteps: bool = False
    boundary_ratio: float | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        # Mirror model-level flags down onto the architecture config.
        self.arch_config.expand_timesteps = self.expand_timesteps
        self.arch_config.boundary_ratio = self.boundary_ratio

    @staticmethod
    def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any:
        """Add CLI arguments for DiTConfig fields"""
        parser.add_argument(
            f"--{prefix}.prefix",
            type=str,
            dest=f"{prefix.replace('-', '_')}.prefix",
            default=DiTConfig.prefix,
            help="Prefix for the DiT model",
        )

        parser.add_argument(
            f"--{prefix}.quant-config",
            type=str,
            dest=f"{prefix.replace('-', '_')}.quant_config",
            default=None,
            help="Quantization configuration for the DiT model",
        )

        return parser
standalone_inference/overlay_files/fastvideo/configs/pipelines/wan.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from collections.abc import Callable
3
+ from dataclasses import dataclass, field
4
+
5
+ import torch
6
+
7
+ from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
8
+ from fastvideo.configs.models.dits import WanVideoConfig
9
+ from fastvideo.configs.models.dits.matrixgame import MatrixGameWanVideoConfig
10
+ from fastvideo.configs.models.encoders import (BaseEncoderOutput, CLIPVisionConfig, T5Config,
11
+ WAN2_1ControlCLIPVisionConfig)
12
+ from fastvideo.configs.models.vaes import WanVAEConfig
13
+ from fastvideo.configs.pipelines.base import PipelineConfig
14
+
15
+
16
+ def t5_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
17
+ mask: torch.Tensor = outputs.attention_mask
18
+ hidden_state: torch.Tensor = outputs.last_hidden_state
19
+ seq_lens = mask.gt(0).sum(dim=1).long()
20
+ assert torch.isnan(hidden_state).sum() == 0
21
+ prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]
22
+ prompt_embeds_tensor: torch.Tensor = torch.stack(
23
+ [torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0)
24
+ return prompt_embeds_tensor
25
+
26
+
27
@dataclass
class WanT2V480PConfig(PipelineConfig):
    """Base configuration for Wan T2V 1.3B pipeline architecture."""

    # WanConfig-specific parameters with defaults
    # DiT
    dit_config: DiTConfig = field(default_factory=WanVideoConfig)
    # VAE
    vae_config: VAEConfig = field(default_factory=WanVAEConfig)
    vae_tiling: bool = False
    vae_sp: bool = False

    # Denoising stage
    flow_shift: float | None = 3.0

    # Text encoding stage
    text_encoder_configs: tuple[EncoderConfig, ...] = field(default_factory=lambda: (T5Config(), ))
    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
                                  ...] = field(default_factory=lambda: (t5_postprocess_text, ))

    # Precision for each component
    precision: str = "bf16"
    vae_precision: str = "fp32"
    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32", ))

    # self-forcing params
    warp_denoising_step: bool = True

    # WanConfig-specific added parameters

    def __post_init__(self) -> None:
        # T2V only decodes latents; the VAE encoder is never needed.
        self.vae_config.load_encoder = False
        self.vae_config.load_decoder = True
60
+
61
+
62
@dataclass
class WanT2V720PConfig(WanT2V480PConfig):
    """Base configuration for Wan T2V 14B 720P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Denoising stage: higher flow shift for the 720P model.
    flow_shift: float | None = 5.0
70
+
71
+
72
@dataclass
class WanI2V480PConfig(WanT2V480PConfig):
    """Base configuration for Wan I2V 14B 480P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Precision for each component
    image_encoder_config: EncoderConfig = field(default_factory=CLIPVisionConfig)
    image_encoder_precision: str = "fp32"

    def __post_init__(self) -> None:
        # I2V encodes the conditioning image, so both VAE halves are loaded
        # (deliberately overriding the T2V parent, which disables the encoder).
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True
85
+
86
+
87
@dataclass
class WanI2V720PConfig(WanI2V480PConfig):
    """Base configuration for Wan I2V 14B 720P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Denoising stage: higher flow shift for the 720P model.
    flow_shift: float | None = 5.0
95
+
96
+
97
@dataclass
class WANV2VConfig(WanI2V480PConfig):
    """Configuration for WAN2.1 1.3B Control pipeline."""

    # Control variant uses the WAN2.1-control CLIP vision encoder.
    image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)
    # CLIP encoder precision
    image_encoder_precision: str = 'bf16'
104
+
105
+
106
@dataclass
class FastWan2_1_T2V_480P_Config(WanT2V480PConfig):
    """Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD"""

    # WanConfig-specific parameters with defaults

    # Denoising stage: DMD distillation uses a fixed 3-step timestep schedule.
    flow_shift: float | None = 8.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
115
+
116
+
117
@dataclass
class Wan2_2_TI2V_5B_Config(WanT2V480PConfig):
    """Wan 2.2 TI2V 5B configuration (text+image-to-video task)."""

    flow_shift: float | None = 5.0
    ti2v_task: bool = True
    expand_timesteps: bool = True

    def __post_init__(self) -> None:
        # Overrides (does not call) the T2V parent's __post_init__: TI2V needs
        # the VAE encoder as well, and mirrors expand_timesteps onto the DiT.
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True
        self.dit_config.expand_timesteps = self.expand_timesteps
127
+
128
+
129
@dataclass
class FastWan2_2_TI2V_5B_Config(Wan2_2_TI2V_5B_Config):
    """FastWan (DMD-distilled) variant of the Wan 2.2 TI2V 5B configuration."""

    flow_shift: float | None = 5.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
133
+
134
+
135
@dataclass
class Wan2_2_T2V_A14B_Config(WanT2V480PConfig):
    """Wan 2.2 T2V A14B configuration (dual-expert model with boundary split)."""

    flow_shift: float | None = 12.0
    boundary_ratio: float | None = 0.875

    # self-forcing params
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 750, 500, 250])
    warp_denoising_step: bool = True

    def __post_init__(self) -> None:
        # NOTE(review): unlike Wan2_2_I2V_A14B_Config, this override does not
        # call super().__post_init__(), so the parent's VAE
        # load_encoder/load_decoder assignments never run here — confirm
        # whether that is intentional.
        self.dit_config.boundary_ratio = self.boundary_ratio
146
+
147
+
148
@dataclass
class Wan2_2_I2V_A14B_Config(WanI2V480PConfig):
    """Wan 2.2 I2V A14B configuration (dual-expert model with boundary split)."""

    flow_shift: float | None = 5.0
    boundary_ratio: float | None = 0.900

    def __post_init__(self) -> None:
        # Keep the I2V parent's VAE setup, then propagate the expert boundary.
        super().__post_init__()
        self.dit_config.boundary_ratio = self.boundary_ratio
156
+
157
+
158
+ # =============================================
159
+ # ============= Causal Self-Forcing =============
160
+ # =============================================
161
@dataclass
class SelfForcingWanT2V480PConfig(WanT2V480PConfig):
    """Causal self-forcing variant of the Wan T2V 480P configuration."""

    is_causal: bool = True
    flow_shift: float | None = 5.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 750, 500, 250])
    warp_denoising_step: bool = True
167
+
168
+
169
@dataclass
class SelfForcingWan2_2_T2V480PConfig(Wan2_2_T2V_A14B_Config):
    """Causal self-forcing variant of the Wan 2.2 T2V A14B configuration."""

    is_causal: bool = True
    flow_shift: float | None = 12.0
    boundary_ratio: float | None = 0.875
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 850, 700, 550, 350, 275, 200, 125])
    warp_denoising_step: bool = True

    def __post_init__(self) -> None:
        # NOTE(review): this override replaces the parent's __post_init__, so
        # self.dit_config.boundary_ratio is never set here (the parent does
        # that) — confirm whether super().__post_init__() should be called.
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True
180
+
181
+
182
+ # =============================================
183
+ # ============= Matrix Game ===================
184
+ # =============================================
185
@dataclass
class MatrixGameBaseI2V480PConfig(WanI2V480PConfig):
    """Matrix-Game base I2V 480P configuration (MatrixGame DiT, no causality)."""

    dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)
    flow_shift: float | None = 5.0
189
+
190
+
191
+ @dataclass
192
+ class MatrixGameI2V480PConfig(WanI2V480PConfig):
193
+ dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)
194
+
195
+ image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)
196
+
197
+ is_causal: bool = True
198
+ flow_shift: float | None = 5.0
199
+ dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 666, 333])
200
+ warp_denoising_step: bool = True
201
+ context_noise: int = 0
202
+ num_frames_per_block: int = 3
203
+ # sliding_window_num_frames: int = 15
standalone_inference/overlay_files/fastvideo/configs/sample/base.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from dataclasses import dataclass
3
+ from typing import Any
4
+
5
+ from fastvideo.logger import init_logger
6
+ from fastvideo.utils import StoreBoolean
7
+
8
+ logger = init_logger(__name__)
9
+
10
+
11
+ @dataclass
12
+ class SamplingParam:
13
+ """
14
+ Sampling parameters for video generation.
15
+ """
16
+ # All fields below are copied from ForwardBatch
17
+ data_type: str = "video"
18
+
19
+ # Image inputs
20
+ image_path: str | None = None
21
+ pil_image: Any | None = None
22
+
23
+ # Video inputs
24
+ video_path: str | None = None
25
+
26
+ # Action control inputs (Matrix-Game)
27
+ mouse_cond: Any | None = None # Shape: (B, T, 2)
28
+ keyboard_cond: Any | None = None # Shape: (B, T, K)
29
+ grid_sizes: Any | None = None # Shape: (3,) [F,H,W]
30
+
31
+ # Camera control inputs (HYWorld)
32
+ pose: str | None = None # Camera trajectory: pose string (e.g., 'w-31') or JSON file path
33
+
34
+ # Camera control inputs (LingBotWorld)
35
+ c2ws_plucker_emb: Any | None = None # Plucker embedding: [B, C, F_lat, H_lat, W_lat]
36
+
37
+ # Refine inputs (LongCat 480p->720p upscaling)
38
+ # Path-based refine (load stage1 video from disk, e.g. MP4)
39
+ refine_from: str | None = None # Path to stage1 video (480p output from distill)
40
+ t_thresh: float = 0.5 # Threshold for timestep scheduling in refinement
41
+ spatial_refine_only: bool = False # If True, only spatial (no temporal doubling)
42
+ num_cond_frames: int = 0 # Number of conditioning frames
43
+ # In-memory refine input (for two-stage pipeline where stage1 frames are already in memory)
44
+ # This mirrors LongCat's demo where a list of frames (e.g. np.ndarray or PIL.Image)
45
+ # is passed directly to the refinement pipeline instead of reloading from disk.
46
+ stage1_video: Any | None = None
47
+
48
+ # Text inputs
49
+ prompt: str | list[str] | None = None
50
+ negative_prompt: str | None = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
51
+ prompt_path: str | None = None
52
+ output_path: str = "outputs/"
53
+ output_video_name: str | None = None
54
+
55
+ # Batch info
56
+ num_videos_per_prompt: int = 1
57
+ seed: int = 1024
58
+
59
+ # Original dimensions (before VAE scaling)
60
+ num_frames: int = 125
61
+ height: int = 720
62
+ width: int = 1280
63
+ height_sr: int = 1072
64
+ width_sr: int = 1920
65
+ fps: int = 24
66
+
67
+ # Denoising parameters
68
+ num_inference_steps: int = 50
69
+ num_inference_steps_sr: int = 50
70
+ guidance_scale: float = 1.0
71
+ guidance_scale_2: float | None = None
72
+ guidance_rescale: float = 0.0
73
+ boundary_ratio: float | None = None
74
+ sigmas: list[float] | None = None
75
+
76
+ # TeaCache parameters
77
+ enable_teacache: bool = False
78
+
79
+ # GEN3C camera control
80
+ trajectory_type: str | None = None
81
+ movement_distance: float | None = None
82
+ camera_rotation: str | None = None
83
+
84
+ # Misc
85
+ save_video: bool = True
86
+ return_frames: bool = True
87
+ return_trajectory_latents: bool = False # returns all latents for each timestep
88
+ return_trajectory_decoded: bool = False # returns decoded latents for each timestep
89
+
90
+ def __post_init__(self) -> None:
91
+ self.data_type = "video" if self.num_frames > 1 else "image"
92
+
93
+ def __getattr__(self, name: str) -> Any:
94
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
95
+
96
+ def check_sampling_param(self) -> None:
97
+ if self.prompt_path and not self.prompt_path.endswith(".txt"):
98
+ raise ValueError("prompt_path must be a txt file")
99
+
100
+ def update(self, source_dict: dict[str, Any]) -> None:
101
+ for key, value in source_dict.items():
102
+ if hasattr(self, key):
103
+ setattr(self, key, value)
104
+ else:
105
+ logger.exception("%s has no attribute %s", type(self).__name__, key)
106
+
107
+ self.__post_init__()
108
+
109
+ @classmethod
110
+ def from_pretrained(cls, model_path: str) -> "SamplingParam":
111
+ from fastvideo.registry import get_sampling_param_cls_for_name
112
+ sampling_cls = get_sampling_param_cls_for_name(model_path)
113
+ if sampling_cls is not None:
114
+ sampling_param: SamplingParam = sampling_cls()
115
+ else:
116
+ logger.warning("Couldn't find an optimal sampling param for %s. Using the default sampling param.",
117
+ model_path)
118
+ sampling_param = cls()
119
+
120
+ return sampling_param
121
+
122
+ @staticmethod
123
+ def add_cli_args(parser: Any) -> Any:
124
+ """Add CLI arguments for SamplingParam fields"""
125
+ parser.add_argument(
126
+ "--prompt",
127
+ type=str,
128
+ default=SamplingParam.prompt,
129
+ help="Text prompt for video generation",
130
+ )
131
+ parser.add_argument(
132
+ "--negative-prompt",
133
+ type=str,
134
+ default=SamplingParam.negative_prompt,
135
+ help="Negative text prompt for video generation",
136
+ )
137
+ parser.add_argument(
138
+ "--prompt-path",
139
+ type=str,
140
+ default=SamplingParam.prompt_path,
141
+ help="Path to a text file containing the prompt",
142
+ )
143
+ parser.add_argument(
144
+ "--output-path",
145
+ type=str,
146
+ default=SamplingParam.output_path,
147
+ help="Path to save the generated video",
148
+ )
149
+ parser.add_argument(
150
+ "--output-video-name",
151
+ type=str,
152
+ default=SamplingParam.output_video_name,
153
+ help="Name of the output video",
154
+ )
155
+ parser.add_argument(
156
+ "--num-videos-per-prompt",
157
+ type=int,
158
+ default=SamplingParam.num_videos_per_prompt,
159
+ help="Number of videos to generate per prompt",
160
+ )
161
+ parser.add_argument(
162
+ "--seed",
163
+ type=int,
164
+ default=SamplingParam.seed,
165
+ help="Random seed for generation",
166
+ )
167
+ parser.add_argument(
168
+ "--num-frames",
169
+ type=int,
170
+ default=SamplingParam.num_frames,
171
+ help="Number of frames to generate",
172
+ )
173
+ parser.add_argument(
174
+ "--height",
175
+ type=int,
176
+ default=SamplingParam.height,
177
+ help="Height of generated video",
178
+ )
179
+ parser.add_argument(
180
+ "--width",
181
+ type=int,
182
+ default=SamplingParam.width,
183
+ help="Width of generated video",
184
+ )
185
+ parser.add_argument(
186
+ "--fps",
187
+ type=int,
188
+ default=SamplingParam.fps,
189
+ help="Frames per second for saved video",
190
+ )
191
+ parser.add_argument(
192
+ "--num-inference-steps",
193
+ type=int,
194
+ default=SamplingParam.num_inference_steps,
195
+ help="Number of denoising steps",
196
+ )
197
+ parser.add_argument(
198
+ "--guidance-scale",
199
+ type=float,
200
+ default=SamplingParam.guidance_scale,
201
+ help="Classifier-free guidance scale",
202
+ )
203
+ parser.add_argument(
204
+ "--guidance-rescale",
205
+ type=float,
206
+ default=SamplingParam.guidance_rescale,
207
+ help="Guidance rescale factor",
208
+ )
209
+ parser.add_argument(
210
+ "--boundary-ratio",
211
+ type=float,
212
+ default=SamplingParam.boundary_ratio,
213
+ help="Boundary timestep ratio",
214
+ )
215
+ parser.add_argument(
216
+ "--save-video",
217
+ action="store_true",
218
+ default=SamplingParam.save_video,
219
+ help="Whether to save the video to disk",
220
+ )
221
+ parser.add_argument(
222
+ "--no-save-video",
223
+ action="store_false",
224
+ dest="save_video",
225
+ help="Don't save the video to disk",
226
+ )
227
+ parser.add_argument(
228
+ "--return-frames",
229
+ action="store_true",
230
+ default=False,
231
+ help="Whether to return the raw frames",
232
+ )
233
+ parser.add_argument(
234
+ "--image-path",
235
+ type=str,
236
+ default=SamplingParam.image_path,
237
+ help="Path to input image for image-to-video generation",
238
+ )
239
+ parser.add_argument(
240
+ "--video-path",
241
+ type=str,
242
+ default=SamplingParam.video_path,
243
+ help="Path to input video for video-to-video generation",
244
+ )
245
+ parser.add_argument(
246
+ "--refine-from",
247
+ type=str,
248
+ default=SamplingParam.refine_from,
249
+ help="Path to stage1 video for refinement (LongCat 480p->720p)",
250
+ )
251
+ parser.add_argument(
252
+ "--t-thresh",
253
+ type=float,
254
+ default=SamplingParam.t_thresh,
255
+ help="Threshold for timestep scheduling in refinement (default: 0.5)",
256
+ )
257
+ parser.add_argument(
258
+ "--spatial-refine-only",
259
+ action=StoreBoolean,
260
+ default=SamplingParam.spatial_refine_only,
261
+ help="Only perform spatial super-resolution (no temporal doubling)",
262
+ )
263
+ parser.add_argument(
264
+ "--num-cond-frames",
265
+ type=int,
266
+ default=SamplingParam.num_cond_frames,
267
+ help="Number of conditioning frames for refinement",
268
+ )
269
+ parser.add_argument(
270
+ "--moba-config-path",
271
+ type=str,
272
+ default=None,
273
+ help="Path to a JSON file containing V-MoBA specific configurations.",
274
+ )
275
+ parser.add_argument(
276
+ "--return-trajectory-latents",
277
+ action="store_true",
278
+ default=SamplingParam.return_trajectory_latents,
279
+ help="Whether to return the trajectory",
280
+ )
281
+ parser.add_argument(
282
+ "--return-trajectory-decoded",
283
+ action="store_true",
284
+ default=SamplingParam.return_trajectory_decoded,
285
+ help="Whether to return the decoded trajectory",
286
+ )
287
+ return parser
288
+
289
+
290
@dataclass
class CacheParams:
    """Caching configuration used during sampling."""
    # Cache strategy identifier; "none" presumably disables caching — confirm
    # against the pipeline stages that consume this.
    cache_type: str = "none"
standalone_inference/overlay_files/fastvideo/configs/sample/wan.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from dataclasses import dataclass
3
+
4
+ from fastvideo.configs.sample.base import SamplingParam
5
+
6
+
7
@dataclass
class WanT2V_1_3B_SamplingParam(SamplingParam):
    """Default sampling parameters for Wan2.1 T2V 1.3B (480P)."""
    # Video parameters
    height: int = 480
    width: int = 832
    num_frames: int = 81
    fps: int = 16

    # Denoising stage
    guidance_scale: float = 3.0
    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_inference_steps: int = 50
19
+
20
+
21
@dataclass
class WanT2V_14B_SamplingParam(SamplingParam):
    """Default sampling parameters for Wan2.1 T2V 14B (720P)."""
    # Video parameters
    height: int = 720
    width: int = 1280
    num_frames: int = 81
    fps: int = 16

    # Denoising stage
    guidance_scale: float = 5.0
    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_inference_steps: int = 50
33
+
34
+
35
@dataclass
class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParam):
    """Wan2.1 I2V 14B 480P: inherits 480P video defaults, stronger guidance."""
    # Denoising stage
    guidance_scale: float = 5.0
    num_inference_steps: int = 40
40
+
41
+
42
@dataclass
class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParam):
    """Wan2.1 I2V 14B 720P: inherits 720P video defaults, fewer steps."""
    # Denoising stage
    guidance_scale: float = 5.0
    num_inference_steps: int = 40
47
+
48
+
49
@dataclass
class FastWanT2V480P_SamplingParam(WanT2V_1_3B_SamplingParam):
    """FastWan distilled T2V 480P: few-step (3) generation at 448x832."""
    # DMD parameters
    # dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
    num_inference_steps: int = 3
    num_frames: int = 61
    height: int = 448
    width: int = 832
    fps: int = 16
58
+
59
+
60
+ # =============================================
61
+ # ============= Wan2.1 Fun Models =============
62
+ # =============================================
63
@dataclass
class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam):
    """Sampling parameters for Wan2.1 Fun 1.3B InP model."""
    height: int = 480
    width: int = 832
    num_frames: int = 81
    fps: int = 16
    # Chinese negative prompt shipped with the Fun model family.
    negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
    guidance_scale: float = 6.0
    num_inference_steps: int = 50
73
+
74
+
75
@dataclass
class Wan2_1_Fun_1_3B_Control_SamplingParam(SamplingParam):
    """Sampling parameters for the Wan2.1 Fun 1.3B Control model."""
    fps: int = 16
    num_frames: int = 49
    # NOTE(review): portrait orientation (832 tall x 480 wide) — sibling
    # configs use 480x832 landscape; confirm this swap is intentional.
    height: int = 832
    width: int = 480
    guidance_scale: float = 6.0
82
+
83
+
84
+ # =============================================
85
+ # ============= Wan2.2 TI2V Models =============
86
+ # =============================================
87
@dataclass
class Wan2_2_Base_SamplingParam(SamplingParam):
    """Shared sampling defaults for the Wan2.2 model family (common negative prompt)."""
    negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
91
+
92
+
93
@dataclass
class Wan2_2_TI2V_5B_SamplingParam(Wan2_2_Base_SamplingParam):
    """Sampling parameters for Wan2.2 TI2V 5B model."""
    height: int = 704
    width: int = 1280
    num_frames: int = 121
    fps: int = 24
    guidance_scale: float = 5.0
    num_inference_steps: int = 50
102
+
103
+
104
@dataclass
class Wan2_2_T2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
    """Sampling parameters for Wan2.2 T2V A14B (dual-expert) model."""
    guidance_scale: float = 4.0  # high_noise
    guidance_scale_2: float = 3.0  # low_noise
    num_inference_steps: int = 40
    fps: int = 16
    # NOTE(will): default boundary timestep is tracked by PipelineConfig, but
    # can be overridden during sampling
112
+
113
+
114
@dataclass
class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
    """Sampling parameters for Wan2.2 I2V A14B (dual-expert) model."""
    guidance_scale: float = 3.5  # high_noise
    guidance_scale_2: float = 3.5  # low_noise
    num_inference_steps: int = 40
    fps: int = 16
    # NOTE(will): default boundary timestep is tracked by PipelineConfig, but
    # can be overridden during sampling
122
+
123
+
124
@dataclass
class Wan2_2_Fun_A14B_Control_SamplingParam(Wan2_1_Fun_1_3B_Control_SamplingParam):
    """Wan2.2 Fun A14B Control: same defaults as 2.1 Control but 81 frames."""
    num_frames: int = 81
127
+
128
+
129
+ # =============================================
130
+ # ============= Causal Self-Forcing =============
131
+ # =============================================
132
@dataclass
class SelfForcingWan2_1_T2V_1_3B_480P_SamplingParam(Wan2_1_Fun_1_3B_InP_SamplingParam):
    """Self-Forcing Wan2.1 T2V 1.3B 480P sampling defaults.

    NOTE(review): inherits from the Fun InP T2V-style defaults despite being
    a plain T2V model — presumably only for the shared 480P video defaults;
    confirm.
    """
    pass
135
+
136
+
137
@dataclass
class SelfForcingWan2_2_T2V_A14B_480P_SamplingParam(Wan2_2_T2V_A14B_SamplingParam):
    """Self-Forcing Wan2.2 T2V A14B 480P: few-step (8) generation at 448x832."""
    num_inference_steps: int = 8
    num_frames: int = 81
    height: int = 448
    width: int = 832
    fps: int = 16
144
+
145
+
146
@dataclass
class MatrixGame2_SamplingParam(SamplingParam):
    """Matrix-Game 2 sampling defaults: few-step, guidance-free generation."""
    height: int = 352
    width: int = 640
    num_frames: int = 57
    fps: int = 25
    # guidance_scale 1.0 means classifier-free guidance is effectively off.
    guidance_scale: float = 1.0
    num_inference_steps: int = 3
    negative_prompt: str | None = None
standalone_inference/overlay_files/fastvideo/configs/wan_1.3B_t2v_pipeline.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embedded_cfg_scale": 6.0,
3
+ "flow_shift": 3,
4
+ "dit_cpu_offload": true,
5
+ "disable_autocast": false,
6
+ "precision": "bf16",
7
+ "vae_precision": "fp32",
8
+ "vae_tiling": false,
9
+ "vae_sp": false,
10
+ "vae_config": {
11
+ "load_encoder": false,
12
+ "load_decoder": true,
13
+ "tile_sample_min_height": 256,
14
+ "tile_sample_min_width": 256,
15
+ "tile_sample_min_num_frames": 16,
16
+ "tile_sample_stride_height": 192,
17
+ "tile_sample_stride_width": 192,
18
+ "tile_sample_stride_num_frames": 12,
19
+ "blend_num_frames": 8,
20
+ "use_tiling": false,
21
+ "use_temporal_tiling": false,
22
+ "use_parallel_tiling": false,
23
+ "use_feature_cache": true
24
+ },
25
+ "dit_config": {
26
+ "prefix": "Wan",
27
+ "quant_config": null
28
+ },
29
+ "text_encoder_precisions": [
30
+ "fp32"
31
+ ],
32
+ "text_encoder_configs": [
33
+ {
34
+ "prefix": "t5",
35
+ "quant_config": null,
36
+ "lora_config": null
37
+ }
38
+ ],
39
+ "enable_torch_compile": false
40
+ }
standalone_inference/overlay_files/fastvideo/entrypoints/cli/generate.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py
3
+
4
+ import argparse
5
+ import dataclasses
6
+ import os
7
+ from typing import cast
8
+
9
+ from fastvideo import VideoGenerator
10
+ from fastvideo.configs.sample.base import SamplingParam
11
+ from fastvideo.entrypoints.cli.cli_types import CLISubcommand
12
+ from fastvideo.entrypoints.cli.utils import RaiseNotImplementedAction
13
+ from fastvideo.fastvideo_args import FastVideoArgs
14
+ from fastvideo.logger import init_logger
15
+ from fastvideo.utils import FlexibleArgumentParser
16
+
17
+ logger = init_logger(__name__)
18
+
19
+
20
class GenerateSubcommand(CLISubcommand):
    """The `generate` subcommand for the FastVideo CLI"""

    def __init__(self) -> None:
        self.name = "generate"
        super().__init__()
        # Partition of CLI flags: which go to VideoGenerator init vs
        # generate_video kwargs.
        self.init_arg_names = self._get_init_arg_names()
        self.generation_arg_names = self._get_generation_arg_names()

    def _get_init_arg_names(self) -> list[str]:
        """Get names of arguments for VideoGenerator initialization"""
        return ["num_gpus", "tp_size", "sp_size", "model_path"]

    def _get_generation_arg_names(self) -> list[str]:
        """Get names of arguments for generate_video method"""
        # Every SamplingParam field is a valid generate_video kwarg.
        return [field.name for field in dataclasses.fields(SamplingParam)]

    def cmd(self, args: argparse.Namespace) -> None:
        """Run generation from parsed CLI args."""
        excluded_args = ['subparser', 'config', 'dispatch_function']

        # Only forward args the user explicitly set; args._provided is
        # presumably populated by the argument parser — confirm against
        # FlexibleArgumentParser.
        provided_args = {}
        for k, v in vars(args).items():
            if (k not in excluded_args and v is not None and hasattr(args, '_provided') and k in args._provided):
                provided_args[k] = v

        # model_path and prompt are forwarded even when not in _provided
        # (they may come from a config file).
        if 'model_path' in vars(args) and args.model_path is not None:
            provided_args['model_path'] = args.model_path

        if 'prompt' in vars(args) and args.prompt is not None:
            provided_args['prompt'] = args.prompt

        merged_args = {**provided_args}

        logger.info('CLI Args: %s', merged_args)

        if 'model_path' not in merged_args or not merged_args['model_path']:
            raise ValueError("model_path must be provided either in config file or via --model-path")

        # Check if either prompt or prompt_txt is provided
        has_prompt = 'prompt' in merged_args and merged_args['prompt']
        has_prompt_txt = 'prompt_txt' in merged_args and merged_args['prompt_txt']

        if not (has_prompt or has_prompt_txt):
            raise ValueError("Either prompt or prompt_txt must be provided")

        if has_prompt and has_prompt_txt:
            raise ValueError("Cannot provide both 'prompt' and 'prompt_txt'. Use only one of them.")

        # Split flags between generator construction and the generate call.
        init_args = {k: v for k, v in merged_args.items() if k not in self.generation_arg_names}
        generation_args = {k: v for k, v in merged_args.items() if k in self.generation_arg_names}
        # CLI runs default to not keeping raw frames in memory.
        generation_args.setdefault("return_frames", False)

        model_path = init_args.pop('model_path')
        prompt = generation_args.pop('prompt', None)

        generator = VideoGenerator.from_pretrained(model_path=model_path, **init_args)

        # Call generate_video - it handles both single and batch modes
        generator.generate_video(prompt=prompt, **generation_args)

    def validate(self, args: argparse.Namespace) -> None:
        """Validate the arguments for this command"""
        if args.num_gpus is not None and args.num_gpus <= 0:
            raise ValueError("Number of gpus must be positive")

        if args.config and not os.path.exists(args.config):
            raise ValueError(f"Config file not found: {args.config}")

    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Build the `generate` argparse subparser (FastVideoArgs + SamplingParam flags)."""
        generate_parser = subparsers.add_parser(
            "generate",
            help="Run inference on a model",
            usage="fastvideo generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]")

        generate_parser.add_argument(
            "--config",
            type=str,
            default='',
            required=False,
            help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional."
        )

        generate_parser = FastVideoArgs.add_cli_args(generate_parser)
        generate_parser = SamplingParam.add_cli_args(generate_parser)

        generate_parser.add_argument(
            "--text-encoder-configs",
            action=RaiseNotImplementedAction,
            help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)",
        )

        return cast(FlexibleArgumentParser, generate_parser)
112
+
113
+
114
+ def cmd_init() -> list[CLISubcommand]:
115
+ return [GenerateSubcommand()]
standalone_inference/overlay_files/fastvideo/entrypoints/video_generator.py ADDED
@@ -0,0 +1,797 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ VideoGenerator module for FastVideo.
4
+
5
+ This module provides a consolidated interface for generating videos using
6
+ diffusion models.
7
+ """
8
+
9
+ import os
10
+ import re
11
+ import shutil
12
+ import threading
13
+ import time
14
+ import tempfile
15
+ import warnings
16
+ from collections.abc import Mapping
17
+ from copy import deepcopy
18
+ from typing import Any
19
+
20
+ import imageio
21
+ import numpy as np
22
+ import torch
23
+ import torchvision
24
+ from einops import rearrange
25
+
26
+ from fastvideo.api.compat import (
27
+ expand_request_prompt_batch,
28
+ generator_config_to_fastvideo_args,
29
+ legacy_from_pretrained_to_config,
30
+ load_generator_config_from_file,
31
+ normalize_generation_request,
32
+ normalize_generator_config,
33
+ request_to_pipeline_overrides,
34
+ request_to_sampling_param,
35
+ )
36
+ from fastvideo.api.results import GenerationResult
37
+ from fastvideo.api.schema import GenerationRequest, GeneratorConfig
38
+ from fastvideo.configs.sample import SamplingParam
39
+ from fastvideo.fastvideo_args import FastVideoArgs
40
+ from fastvideo.logger import init_logger
41
+ from fastvideo.pipelines import ForwardBatch
42
+ from fastvideo.utils import align_to, shallow_asdict
43
+ from fastvideo.worker.executor import Executor
44
+
45
+ logger = init_logger(__name__)
46
+
47
# Keyword arguments still accepted by VideoGenerator.from_pretrained without a
# DeprecationWarning (engine/parallelism/offload convenience toggles). Any
# other kwarg triggers the legacy-kwargs warning in from_pretrained.
_FROM_PRETRAINED_CONVENIENCE_KWARGS = frozenset({
    "num_gpus",
    "revision",
    "trust_remote_code",
    "distributed_executor_backend",
    "tp_size",
    "sp_size",
    "hsdp_replicate_dim",
    "hsdp_shard_dim",
    "dist_timeout",
    "use_fsdp_inference",
    "disable_autocast",
    "enable_stage_verification",
    "dit_cpu_offload",
    "dit_layerwise_offload",
    "text_encoder_cpu_offload",
    "image_encoder_cpu_offload",
    "vae_cpu_offload",
    "pin_cpu_memory",
    "enable_torch_compile",
    "torch_compile_kwargs",
    "transformer_quant",
})
70
+
71
+
72
def _infer_latent_batch_size(batch: ForwardBatch) -> int:
    """Return the effective latent batch size for *batch*.

    The per-prompt count comes from ``batch.prompt`` (list length, or 1 for a
    single prompt) or, failing that, from the leading dimension of the first
    prompt embedding; it is then multiplied by ``num_videos_per_prompt``.

    Raises:
        ValueError: if neither a prompt nor prompt embeddings are present.
    """
    prompt = batch.prompt
    if isinstance(prompt, list):
        num_prompts = len(prompt)
    elif prompt is not None:
        num_prompts = 1
    else:
        embeds = batch.prompt_embeds
        if embeds is None or len(embeds) == 0:
            raise ValueError("Cannot infer batch size from batch; no prompt or prompt_embeds found")
        num_prompts = embeds[0].shape[0]
    return num_prompts * batch.num_videos_per_prompt
83
+
84
+
85
+ class VideoGenerator:
86
+ """
87
+ A unified class for generating videos using diffusion models.
88
+
89
+ This class provides a simple interface for video generation with rich
90
+ customization options, similar to popular frameworks like HF Diffusers.
91
+ """
92
+
93
    def __init__(
        self,
        fastvideo_args: FastVideoArgs,
        executor_class: type[Executor],
        log_stats: bool,
        *,
        log_queue=None,
    ):
        """
        Initialize the video generator.

        Args:
            fastvideo_args: The inference arguments
            executor_class: The executor class to use for inference
            log_stats: Whether to log statistics
            log_queue: Optional multiprocessing.Queue to forward worker logs to
        """
        # Typed config is attached later by from_config(); stays None when the
        # generator is constructed directly.
        self.config: GeneratorConfig | None = None
        self.fastvideo_args = fastvideo_args
        self.executor = executor_class(fastvideo_args, log_queue=log_queue)
113
+
114
    @classmethod
    def from_pretrained(
        cls,
        model_path: str | GeneratorConfig | Mapping[str, Any] | None = None,
        **kwargs,
    ) -> "VideoGenerator":
        """
        Create a video generator from a pretrained model.

        Args:
            model_path: Path or identifier for the pretrained model, or a
                typed/mapping config (routed to from_config).
            **kwargs: Additional arguments to customize model loading; only
                the stable convenience kwargs are supported without a
                deprecation warning. A `config` kwarg routes to from_config
                and is mutually exclusive with model_path.

        Returns:
            The created video generator

        Stable convenience kwargs remain supported here for common engine and
        offload settings. Advanced model- or pipeline-specific options should
        move to VideoGenerator.from_config(...).
        """
        log_queue = kwargs.pop("log_queue", None)
        typed_config = kwargs.pop("config", None)
        # Route 1: explicit typed config kwarg (exclusive with model_path).
        if typed_config is not None:
            if model_path is not None:
                raise TypeError("Pass either model_path or config to from_pretrained, not both")
            if kwargs:
                unexpected = ", ".join(sorted(kwargs))
                raise TypeError(f"Unexpected keyword arguments with config: {unexpected}")
            return cls.from_config(typed_config, log_queue=log_queue)

        # Route 2: model_path is itself a typed config / mapping.
        if isinstance(model_path, GeneratorConfig | Mapping):
            if kwargs:
                unexpected = ", ".join(sorted(kwargs))
                raise TypeError(f"Unexpected keyword arguments with typed config: {unexpected}")
            return cls.from_config(model_path, log_queue=log_queue)

        if model_path is None:
            raise TypeError("model_path or config is required")

        # Route 3: legacy string path + kwargs; warn on kwargs outside the
        # stable convenience set.
        legacy_only_kwargs = sorted(set(kwargs) - _FROM_PRETRAINED_CONVENIENCE_KWARGS)
        if legacy_only_kwargs:
            warnings.warn(
                "VideoGenerator.from_pretrained(...) received legacy-only kwargs "
                f"({', '.join(legacy_only_kwargs)}); prefer VideoGenerator.from_config(...) "
                "for advanced configuration.",
                DeprecationWarning,
                stacklevel=2,
            )
        return cls.from_config(
            legacy_from_pretrained_to_config(model_path, kwargs),
            log_queue=log_queue,
        )
169
+
170
    @classmethod
    def from_config(
        cls,
        config: GeneratorConfig | Mapping[str, Any],
        *,
        log_queue=None,
    ) -> "VideoGenerator":
        """
        Create a video generator from a typed GeneratorConfig (or a mapping
        that can be normalized into one).

        Args:
            config: GeneratorConfig instance or plain mapping.
            log_queue: Optional multiprocessing.Queue to forward worker logs to.

        Returns:
            The created video generator, with `.config` set to the normalized
            config.
        """
        normalized = normalize_generator_config(config)
        fastvideo_args = generator_config_to_fastvideo_args(normalized)
        generator = cls.from_fastvideo_args(fastvideo_args, log_queue=log_queue)
        generator.config = normalized
        return generator
182
+
183
    @classmethod
    def from_file(
        cls,
        path: str,
        overrides: list[str] | Mapping[str, Any] | None = None,
        *,
        log_queue=None,
    ) -> "VideoGenerator":
        """
        Create a video generator from a generator-config file on disk.

        Args:
            path: Path to the config file.
            overrides: Optional override strings or mapping applied on top of
                the file contents.
            log_queue: Optional multiprocessing.Queue to forward worker logs to.
        """
        return cls.from_config(
            load_generator_config_from_file(path, overrides=overrides),
            log_queue=log_queue,
        )
195
+
196
    @classmethod
    def from_fastvideo_args(
        cls,
        fastvideo_args: FastVideoArgs,
        *,
        log_queue=None,
    ) -> "VideoGenerator":
        """
        Create a video generator with the specified arguments.

        Args:
            fastvideo_args: The inference arguments
            log_queue: Optional multiprocessing.Queue to forward worker logs to

        Returns:
            The created video generator
        """
        # Initialize distributed environment if needed
        # initialize_distributed_and_parallelism(fastvideo_args)

        # Executor backend (e.g. multiprocessing) is selected from the args.
        executor_class = Executor.get_class(fastvideo_args)
        return cls(
            fastvideo_args=fastvideo_args,
            executor_class=executor_class,
            log_stats=False,  # TODO: implement
            log_queue=log_queue,
        )
223
+
224
    def generate(
        self,
        request: GenerationRequest | Mapping[str, Any],
        *,
        log_queue=None,
    ) -> GenerationResult | list[GenerationResult]:
        """
        Generate video or image outputs from a typed inference request.

        Args:
            request: A `GenerationRequest` instance or a mapping that can be
                parsed into one. This is the primary public inference
                entrypoint for the typed API.
            log_queue: Optional multiprocessing.Queue to forward worker logs to
                during this request.

        Returns:
            A `GenerationResult` for single-request generation, or a list of
            `GenerationResult` objects when the request expands into multiple
            prompts.
        """
        normalized_request = normalize_generation_request(request)
        if log_queue:
            self.executor.set_log_queue(log_queue)

        try:
            return self._generate_request_impl(normalized_request)
        finally:
            # Always detach the per-request log queue, even on failure.
            if log_queue:
                self.executor.clear_log_queue()
254
+
255
+ def generate_video(
256
+ self,
257
+ prompt: str | None = None,
258
+ sampling_param: SamplingParam | None = None,
259
+ # Action control inputs (Matrix-Game)
260
+ mouse_cond: torch.Tensor | None = None,
261
+ keyboard_cond: torch.Tensor | None = None,
262
+ grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
263
+ | None = None,
264
+ **kwargs,
265
+ ) -> dict[str, Any] | list[dict[str, Any]]:
266
+ """
267
+ Generate a video based on the given prompt.
268
+
269
+ Args:
270
+ prompt: The prompt to use for generation (optional if prompt_txt is provided)
271
+ negative_prompt: The negative prompt to use (overrides the one in fastvideo_args)
272
+ output_path: Path to save the video (overrides the one in fastvideo_args)
273
+ prompt_path: Path to prompt file
274
+ save_video: Whether to save the video to disk
275
+ return_frames: Whether to include raw frames in the result dict
276
+ num_inference_steps: Number of denoising steps (overrides fastvideo_args)
277
+ guidance_scale: Classifier-free guidance scale (overrides fastvideo_args)
278
+ num_frames: Number of frames to generate (overrides fastvideo_args)
279
+ height: Height of generated video (overrides fastvideo_args)
280
+ width: Width of generated video (overrides fastvideo_args)
281
+ fps: Frames per second for saved video (overrides fastvideo_args)
282
+ seed: Random seed for generation (overrides fastvideo_args)
283
+ callback: Callback function called after each step
284
+ callback_steps: Number of steps between each callback
285
+
286
+ Returns:
287
+ A metadata dictionary for single-prompt generation, or a list of
288
+ metadata dictionaries for prompt-file batch generation.
289
+ """
290
+ log_queue = kwargs.pop("log_queue", None)
291
+ warnings.warn(
292
+ "VideoGenerator.generate_video(...) is deprecated; use "
293
+ "VideoGenerator.generate(request=...) instead.",
294
+ DeprecationWarning,
295
+ stacklevel=2,
296
+ )
297
+ if log_queue:
298
+ self.executor.set_log_queue(log_queue)
299
+
300
+ try:
301
+ return self._generate_video_impl(
302
+ prompt=prompt,
303
+ sampling_param=sampling_param,
304
+ mouse_cond=mouse_cond,
305
+ keyboard_cond=keyboard_cond,
306
+ grid_sizes=grid_sizes,
307
+ **kwargs,
308
+ )
309
+ finally:
310
+ if log_queue:
311
+ self.executor.clear_log_queue()
312
+
313
+ def _generate_request_impl(
314
+ self,
315
+ request: GenerationRequest,
316
+ ) -> GenerationResult | list[GenerationResult]:
317
+ if isinstance(request.prompt, list):
318
+ if request.inputs.prompt_path is not None:
319
+ raise ValueError("request.prompt list cannot be combined with request.inputs.prompt_path")
320
+ results: list[GenerationResult] = []
321
+ for index, single_request in enumerate(expand_request_prompt_batch(request)):
322
+ prompt = single_request.prompt
323
+ wrapped = self._generate_single_request(single_request)
324
+ if isinstance(wrapped, list):
325
+ results.extend(wrapped)
326
+ continue
327
+ wrapped.prompt_index = index
328
+ if wrapped.prompt is None and isinstance(prompt, str):
329
+ wrapped.prompt = prompt
330
+ results.append(wrapped)
331
+ return results
332
+
333
+ return self._generate_single_request(request)
334
+
335
+ def _generate_single_request(
336
+ self,
337
+ request: GenerationRequest,
338
+ ) -> GenerationResult | list[GenerationResult]:
339
+ fastvideo_args = self.fastvideo_args
340
+ pipeline_overrides = request_to_pipeline_overrides(request)
341
+ if pipeline_overrides:
342
+ fastvideo_args = deepcopy(self.fastvideo_args)
343
+ for key, value in pipeline_overrides.items():
344
+ if not hasattr(fastvideo_args.pipeline_config, key):
345
+ raise ValueError(f"Request field {key!r} is not supported by pipeline config overrides")
346
+ setattr(fastvideo_args.pipeline_config, key, deepcopy(value))
347
+
348
+ sampling_param = request_to_sampling_param(
349
+ request,
350
+ model_path=self.fastvideo_args.model_path,
351
+ )
352
+ result = self._generate_video_impl(
353
+ prompt=request.prompt,
354
+ sampling_param=sampling_param,
355
+ fastvideo_args=fastvideo_args,
356
+ )
357
+ return self._wrap_legacy_result(result)
358
+
359
    def _generate_video_impl(
        self,
        prompt: str | list[str] | None = None,
        sampling_param: SamplingParam | None = None,
        mouse_cond: torch.Tensor | None = None,
        keyboard_cond: torch.Tensor | None = None,
        grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
        | None = None,
        fastvideo_args: FastVideoArgs | None = None,
        **kwargs,
    ) -> dict[str, Any] | list[np.ndarray] | list[dict[str, Any]]:
        """Internal implementation of generate_video.

        Dispatches between single-prompt generation and prompt-file batch
        generation (when ``fastvideo_args.prompt_txt`` or
        ``sampling_param.prompt_path`` is set). Returns one metadata dict for
        a single prompt, or a list of metadata dicts for a batch.
        """
        if fastvideo_args is None:
            fastvideo_args = self.fastvideo_args

        # Handle batch processing from text file
        if sampling_param is None:
            sampling_param = SamplingParam.from_pretrained(fastvideo_args.model_path)

        # Add action control inputs to kwargs if provided
        if mouse_cond is not None:
            kwargs['mouse_cond'] = mouse_cond
        if keyboard_cond is not None:
            kwargs['keyboard_cond'] = keyboard_cond
        if grid_sizes is not None:
            kwargs['grid_sizes'] = grid_sizes

        # Fold all remaining keyword overrides into the sampling params.
        sampling_param.update(kwargs)

        if fastvideo_args.prompt_txt is not None or sampling_param.prompt_path is not None:
            # prompt_path on the sampling params takes precedence over the
            # global prompt_txt argument.
            prompt_txt_path = sampling_param.prompt_path or fastvideo_args.prompt_txt
            if not prompt_txt_path or not os.path.exists(prompt_txt_path):
                raise FileNotFoundError(f"Prompt text file not found: {prompt_txt_path}")

            # Read prompts from file (one per non-empty line)
            with open(prompt_txt_path, encoding='utf-8') as f:
                prompts = [line.strip() for line in f if line.strip()]

            if not prompts:
                raise ValueError(f"No prompts found in file: {prompt_txt_path}")

            logger.info("Found %d prompts in %s", len(prompts), prompt_txt_path)

            results = []
            for i, batch_prompt in enumerate(prompts):
                logger.info("Processing prompt %d/%d: %s...", i + 1, len(prompts), batch_prompt[:100])
                try:
                    # Generate video for this prompt using the same logic below
                    output_path = self._prepare_output_path(sampling_param.output_path, batch_prompt)
                    kwargs["output_path"] = output_path
                    result = self._generate_single_video(
                        prompt=batch_prompt,
                        sampling_param=sampling_param,
                        fastvideo_args=fastvideo_args,
                        **kwargs,
                    )

                    # Add prompt info to result
                    result["prompt_index"] = i
                    result["prompt"] = batch_prompt

                    results.append(result)
                    logger.info("Successfully generated video for prompt %d", i + 1)

                except Exception as e:
                    # Best-effort batch mode: a failed prompt is logged and
                    # skipped rather than aborting the remaining prompts.
                    logger.error("Failed to generate video for prompt %d: %s", i + 1, e)
                    continue

            logger.info("Completed batch processing. Generated %d videos successfully.", len(results))
            return results

        # Single prompt generation (original behavior)
        if prompt is None:
            raise ValueError("Either prompt or prompt_txt must be provided")
        if not isinstance(prompt, str):
            raise ValueError("Single-prompt generation expects a string prompt")
        output_path = self._prepare_output_path(sampling_param.output_path, prompt)
        kwargs["output_path"] = output_path
        return self._generate_single_video(
            prompt=prompt,
            sampling_param=sampling_param,
            fastvideo_args=fastvideo_args,
            **kwargs,
        )
443
+
444
+ def _is_image_workload(self) -> bool:
445
+ """Return True when the workload produces a single image (t2i, i2i …)."""
446
+ args = getattr(self, "fastvideo_args", None)
447
+ if args is None:
448
+ return False
449
+ return args.workload_type.value.endswith("2i")
450
+
451
+ def _prepare_output_path(
452
+ self,
453
+ output_path: str,
454
+ prompt: str,
455
+ ) -> str:
456
+ """Build a unique, sanitized output file path.
457
+
458
+ The file extension is chosen automatically based on the workload type:
459
+ ``.png`` for image workloads (``t2i``, ``i2i``, …) and ``.mp4`` for
460
+ video workloads.
461
+
462
+ - If ``output_path`` already carries the correct extension, treat it
463
+ as a file path.
464
+ - Otherwise, treat ``output_path`` as a directory and derive the
465
+ filename from the prompt.
466
+ - Invalid filename characters are removed; if the name changes, a
467
+ warning is logged.
468
+ - If the target path already exists, a numeric suffix is appended.
469
+ """
470
+ target_ext = ".png" if self._is_image_workload() else ".mp4"
471
+
472
+ def _sanitize_filename_component(name: str) -> str:
473
+ # Remove characters invalid on common filesystems, strip spaces/dots
474
+ sanitized = re.sub(r'[\\/:*?"<>|]', '', name)
475
+ sanitized = sanitized.strip().strip('.')
476
+ sanitized = re.sub(r'\s+', ' ', sanitized)
477
+ return sanitized or "output"
478
+
479
+ base_path, extension = os.path.splitext(output_path)
480
+ extension_lower = extension.lower()
481
+
482
+ if extension_lower == target_ext:
483
+ output_dir = os.path.dirname(output_path)
484
+ base_name = os.path.basename(base_path) # filename without extension
485
+ sanitized_base = _sanitize_filename_component(base_name)
486
+ if sanitized_base != base_name:
487
+ logger.warning(
488
+ "The output name '%s' contained invalid characters. "
489
+ "It has been renamed to '%s%s'",
490
+ os.path.basename(output_path),
491
+ sanitized_base,
492
+ target_ext,
493
+ )
494
+ out_name = f"{sanitized_base}{target_ext}"
495
+ else:
496
+ # Treat as directory; inform if an unexpected extension was
497
+ # provided.
498
+ if extension:
499
+ logger.info(
500
+ "Output path '%s' has extension '%s' which does not "
501
+ "match the target '%s'; treating it as a directory",
502
+ output_path,
503
+ extension,
504
+ target_ext,
505
+ )
506
+ output_dir = output_path
507
+ prompt_component = _sanitize_filename_component(prompt[:100])
508
+ out_name = f"{prompt_component}{target_ext}"
509
+
510
+ if output_dir:
511
+ os.makedirs(output_dir, exist_ok=True)
512
+
513
+ new_output_path = os.path.join(output_dir, out_name)
514
+ counter = 1
515
+ while os.path.exists(new_output_path):
516
+ name_part, ext_part = os.path.splitext(out_name)
517
+ new_name = f"{name_part}_{counter}{ext_part}"
518
+ new_output_path = os.path.join(output_dir, new_name)
519
+ counter += 1
520
+ return new_output_path
521
+
522
    def _generate_single_video(
        self,
        prompt: str,
        sampling_param: SamplingParam | None = None,
        fastvideo_args: FastVideoArgs | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Internal method for single video generation.

        Validates the prompt/dimensions, runs the executor forward pass on a
        worker thread (overlapping the pinned-CPU output allocation), then
        converts, saves (video/image, optionally muxing audio), and returns
        the legacy result metadata dict. Expects ``kwargs["output_path"]`` to
        be set by the caller.
        """
        if fastvideo_args is None:
            fastvideo_args = self.fastvideo_args

        # Validate inputs
        if not isinstance(prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
        prompt = prompt.strip()
        # NOTE(review): assumes sampling_param is not None here; callers
        # always pass one via _generate_video_impl — confirm.
        sampling_param = deepcopy(sampling_param)
        output_path = kwargs["output_path"]
        sampling_param.prompt = prompt
        # Process negative prompt
        if sampling_param.negative_prompt is not None:
            sampling_param.negative_prompt = sampling_param.negative_prompt.strip()

        # Validate dimensions
        if (sampling_param.height <= 0 or sampling_param.width <= 0 or sampling_param.num_frames <= 0):
            raise ValueError(f"Height, width, and num_frames must be positive integers, got "
                             f"height={sampling_param.height}, width={sampling_param.width}, "
                             f"num_frames={sampling_param.num_frames}")

        # Calculate sizes (spatial dims aligned to multiples of 16)
        target_height = align_to(sampling_param.height, 16)
        target_width = align_to(sampling_param.width, 16)

        # Calculate latent sizes (4x temporal, 8x spatial compression —
        # presumably matching the VAE; verify against the pipeline config)
        latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8]
        n_tokens = latents_size[0] * latents_size[1] * latents_size[2]

        # Log parameters
        debug_str = f"""
            height: {target_height}
            width: {target_width}
            video_length: {sampling_param.num_frames}
            prompt: {sampling_param.prompt}
            image_path: {sampling_param.image_path}
            neg_prompt: {sampling_param.negative_prompt}
            seed: {sampling_param.seed}
            infer_steps: {sampling_param.num_inference_steps}
            num_videos_per_prompt: {sampling_param.num_videos_per_prompt}
            guidance_scale: {sampling_param.guidance_scale}
            n_tokens: {n_tokens}
            flow_shift: {fastvideo_args.pipeline_config.flow_shift}
            embedded_guidance_scale: {fastvideo_args.pipeline_config.embedded_cfg_scale}
            save_video: {sampling_param.save_video}
            output_path: {output_path}
        """  # type: ignore[attr-defined]
        logger.info(debug_str)

        # Prepare batch
        batch = ForwardBatch(
            **shallow_asdict(sampling_param),
            eta=0.0,
            n_tokens=n_tokens,
            VSA_sparsity=fastvideo_args.VSA_sparsity,
        )

        # Run inference
        start_time = time.perf_counter()

        # Execute forward pass in a new thread for non-blocking tensor
        # allocation. Capture thread exceptions so we can surface the true
        # failure in the main thread instead of later hitting None outputs.
        result_container = {"output_batch": ForwardBatch(data_type=batch.data_type)}
        thread_error: dict[str, BaseException | None] = {"error": None}
        thread_error_traceback: dict[str, str] = {"traceback": ""}

        def execute_forward_thread():
            import traceback
            try:
                result_container["output_batch"] = self.executor.execute_forward(batch, fastvideo_args)
            except BaseException as error:  # noqa: BLE001
                thread_error["error"] = error
                thread_error_traceback["traceback"] = traceback.format_exc()

        thread = threading.Thread(target=execute_forward_thread)
        thread.start()
        # Allocate the (optionally pinned) CPU destination tensor while the
        # forward pass runs on the worker thread.
        latent_batch_size = _infer_latent_batch_size(batch)
        samples = torch.empty(
            (latent_batch_size, 3, sampling_param.num_frames, sampling_param.height, sampling_param.width),
            device='cpu',
            pin_memory=fastvideo_args.pin_cpu_memory)
        thread.join()

        if thread_error["error"] is not None:
            raise RuntimeError("Forward execution thread failed.\n"
                               f"{thread_error_traceback['traceback']}") from thread_error["error"]

        output_batch = result_container["output_batch"]
        if output_batch.output is None:
            raise RuntimeError("Forward execution returned no output tensor. "
                               "This usually means the executor/pipeline failed earlier.")

        # Fast path copies into the pre-allocated (pinned) tensor; the slow
        # path falls back to a plain .cpu() transfer on shape mismatch.
        if output_batch.output.shape == samples.shape:
            samples.copy_(output_batch.output)
        else:
            logger.warning("Output shape %s does not match expected shape %s; use slow path", output_batch.output.shape,
                           samples.shape)
            samples = output_batch.output.cpu()
        logging_info = output_batch.logging_info

        gen_time = time.perf_counter() - start_time
        logger.info("Generated successfully in %.2f seconds", gen_time)

        # Process outputs: (b, c, t, h, w) -> per-timestep uint8 RGB frames
        videos = rearrange(samples, "b c t h w -> t b c h w")
        frames = []
        for x in videos:
            x = torchvision.utils.make_grid(x, nrow=6)
            x = x.permute(1, 2, 0).squeeze(-1)
            x = (x * 255).to(torch.uint8)
            frames.append(x.cpu().numpy())

        # Save output if requested
        if batch.save_video:
            if self._is_image_workload():
                # Image workloads (t2i, i2i, …): save the first frame as PNG.
                imageio.imwrite(output_path, frames[0])
                logger.info("Saved image to %s", output_path)
            else:
                imageio.mimsave(output_path, frames, fps=batch.fps, format="mp4")
                logger.info("Saved video to %s", output_path)
                # Best-effort audio mux; a failure keeps the silent video.
                audio = output_batch.extra.get("audio")
                audio_sample_rate = output_batch.extra.get("audio_sample_rate")
                if (audio is not None and audio_sample_rate is not None
                        and not self._mux_audio(output_path, audio, audio_sample_rate)):
                    logger.warning("Audio mux failed; saved video without audio.")

        result: dict[str, Any] = {
            "prompts": prompt,
            "samples": samples if batch.return_frames else None,
            "frames": frames if batch.return_frames else None,
            "audio": output_batch.extra.get("audio") if batch.return_frames else None,
            "size": (target_height, target_width, batch.num_frames),
            "generation_time": gen_time,
            "logging_info": logging_info,
            "trajectory": output_batch.trajectory_latents,
            "trajectory_timesteps": output_batch.trajectory_timesteps,
            "trajectory_decoded": output_batch.trajectory_decoded,
            "video_path": output_path if batch.save_video else None,
            "peak_memory_mb": output_batch.extra.get("peak_memory_mb"),
        }

        return result
673
+
674
+ @staticmethod
675
+ def _wrap_legacy_result(
676
+ result: dict[str, Any] | list[dict[str, Any]], ) -> GenerationResult | list[GenerationResult]:
677
+ if isinstance(result, list):
678
+ return [GenerationResult.from_legacy_result(item) for item in result]
679
+ return GenerationResult.from_legacy_result(result)
680
+
681
+ @staticmethod
682
+ def _unwrap_typed_result(
683
+ result: GenerationResult | list[GenerationResult], ) -> dict[str, Any] | list[dict[str, Any]]:
684
+ if isinstance(result, list):
685
+ return [item.to_legacy_dict() for item in result]
686
+ return result.to_legacy_dict()
687
+
688
    @staticmethod
    def _mux_audio(
        video_path: str,
        audio: torch.Tensor | np.ndarray,
        sample_rate: int,
    ) -> bool:
        """Mux audio into video using PyAV.

        Re-encodes the already-saved video at ``video_path`` together with the
        audio track (AAC) into a temporary file, then replaces the original.
        Returns True on success and False on any failure (missing PyAV,
        unexpected audio shape, encode errors); callers treat False as "keep
        the silent video".
        """
        try:
            import av
        except ImportError:
            logger.warning("PyAV not installed; cannot mux audio. "
                           "Install with: pip install av")
            return False

        # Normalize the audio to a float32 numpy array.
        if torch.is_tensor(audio):
            audio_np = audio.detach().cpu().float().numpy()
        else:
            audio_np = np.asarray(audio, dtype=np.float32)

        # Canonicalize to (num_samples, num_channels).
        if audio_np.ndim == 1:
            audio_np = audio_np[:, None]
        elif audio_np.ndim == 2:
            # Heuristic: a (channels, samples) layout has a small first dim.
            if audio_np.shape[0] <= 8 and audio_np.shape[1] > audio_np.shape[0]:
                audio_np = audio_np.T
        else:
            logger.warning("Unexpected audio shape %s; skipping mux.", audio_np.shape)
            return False

        # Convert to 16-bit PCM for the intermediate WAV file.
        audio_np = np.clip(audio_np, -1.0, 1.0)
        audio_int16 = (audio_np * 32767.0).astype(np.int16)
        num_channels = audio_int16.shape[1]
        layout = "stereo" if num_channels == 2 else "mono"

        try:
            import wave
            with tempfile.TemporaryDirectory() as tmpdir:
                out_path = os.path.join(tmpdir, "muxed.mp4")
                wav_path = os.path.join(tmpdir, "audio.wav")

                # Write audio to WAV file
                with wave.open(wav_path, "wb") as wav_file:
                    wav_file.setnchannels(num_channels)
                    wav_file.setsampwidth(2)
                    wav_file.setframerate(sample_rate)
                    wav_file.writeframes(audio_int16.tobytes())

                # Open input video and audio
                input_video = av.open(video_path)
                input_audio = av.open(wav_path)

                # Create output with both streams
                output = av.open(out_path, mode="w")

                # Add video stream (copy codec from input)
                in_video_stream = input_video.streams.video[0]
                out_video_stream = output.add_stream(
                    codec_name=in_video_stream.codec_context.name,
                    rate=in_video_stream.average_rate,
                )
                out_video_stream.width = in_video_stream.width
                out_video_stream.height = in_video_stream.height
                out_video_stream.pix_fmt = in_video_stream.pix_fmt

                # Add audio stream (AAC)
                out_audio_stream = output.add_stream("aac", rate=sample_rate)
                out_audio_stream.layout = layout

                # Remux video (decode and re-encode to be safe)
                for frame in input_video.decode(video=0):
                    for packet in out_video_stream.encode(frame):
                        output.mux(packet)
                # Flush any buffered video packets.
                for packet in out_video_stream.encode():
                    output.mux(packet)

                # Encode audio
                for frame in input_audio.decode(audio=0):
                    frame.pts = None  # Let encoder assign PTS
                    for packet in out_audio_stream.encode(frame):
                        output.mux(packet)
                # Flush any buffered audio packets.
                for packet in out_audio_stream.encode():
                    output.mux(packet)

                input_video.close()
                input_audio.close()
                output.close()
                # Replace the original only after a fully successful mux.
                shutil.move(out_path, video_path)
                return True
        except Exception as e:
            logger.warning("Audio mux failed: %s", e)
            return False
778
+
779
+ def set_lora_adapter(self, lora_nickname: str, lora_path: str | None = None) -> None:
780
+ self.executor.set_lora_adapter(lora_nickname, lora_path)
781
+
782
+ def unmerge_lora_weights(self) -> None:
783
+ """
784
+ Use unmerged weights for inference to produce videos that align with
785
+ validation videos generated during training.
786
+ """
787
+ self.executor.unmerge_lora_weights()
788
+
789
+ def merge_lora_weights(self) -> None:
790
+ self.executor.merge_lora_weights()
791
+
792
+ def shutdown(self) -> None:
793
+ """
794
+ Shutdown the video generator.
795
+ """
796
+ self.executor.shutdown()
797
+ del self.executor
standalone_inference/overlay_files/fastvideo/fastvideo_args.py ADDED
@@ -0,0 +1,1188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py
3
+ """The arguments of FastVideo Inference."""
4
+ import argparse
5
+ import dataclasses
6
+ import json
7
+ from contextlib import contextmanager
8
+ from dataclasses import field
9
+ from enum import Enum
10
+ from typing import Any, TYPE_CHECKING
11
+
12
+ from fastvideo.configs.configs import PreprocessConfig
13
+ from fastvideo.configs.pipelines.base import PipelineConfig
14
+ from fastvideo.configs.utils import clean_cli_args
15
+ from fastvideo.layers.quantization import QUANTIZATION_METHODS, QuantizationMethods
16
+ from fastvideo.logger import init_logger
17
+ from fastvideo.utils import FlexibleArgumentParser, StoreBoolean
18
+
19
+ if TYPE_CHECKING:
20
+ from ray.runtime_env import RuntimeEnv
21
+ from ray.util.placement_group import PlacementGroup
22
+ else:
23
+ RuntimeEnv = Any
24
+ PlacementGroup = Any
25
+
26
+ logger = init_logger(__name__)
27
+
28
+
29
class ExecutionMode(str, Enum):
    """
    Enumeration of pipeline execution modes.

    Inherits from str so members compare equal to plain strings, keeping
    older string-based callers working.
    """
    INFERENCE = "inference"
    PREPROCESS = "preprocess"
    FINETUNING = "finetuning"
    DISTILLATION = "distillation"

    @classmethod
    def from_string(cls, value: str) -> "ExecutionMode":
        """Parse a case-insensitive string into an ExecutionMode."""
        try:
            return cls(value.lower())
        except ValueError:
            valid = ', '.join([m.value for m in cls])
            raise ValueError(f"Invalid mode: {value}. Must be one of: {valid}") from None

    @classmethod
    def choices(cls) -> list[str]:
        """All valid mode strings, suitable for argparse choices."""
        return [member.value for member in cls]
52
+
53
+
54
class WorkloadType(str, Enum):
    """
    Enumeration of supported workload types.

    Inherits from str so members compare equal to plain strings, keeping
    older string-based callers working.
    """
    I2V = "i2v"  # Image to Video
    T2V = "t2v"  # Text to Video
    T2I = "t2i"  # Text to Image
    I2I = "i2i"  # Image to Image

    @classmethod
    def from_string(cls, value: str) -> "WorkloadType":
        """Parse a case-insensitive string into a WorkloadType."""
        try:
            return cls(value.lower())
        except ValueError:
            valid = ', '.join([m.value for m in cls])
            raise ValueError(
                f"Invalid workload type: {value}. Must be one of: {valid}") from None

    @classmethod
    def choices(cls) -> list[str]:
        """All valid workload strings, suitable for argparse choices."""
        return [member.value for member in cls]
78
+
79
+
80
+ # args for fastvideo framework
81
+ @dataclasses.dataclass
82
+ class FastVideoArgs:
83
+ # Model and path configuration (for convenience)
84
+ model_path: str
85
+
86
+ # Running mode
87
+ mode: ExecutionMode = ExecutionMode.INFERENCE
88
+
89
+ # Workload type
90
+ workload_type: WorkloadType = WorkloadType.T2V
91
+
92
+ # Distributed executor backend
93
+ distributed_executor_backend: str = "mp"
94
+
95
+ # a few attributes for ray related
96
+ ray_placement_group: PlacementGroup | None = None
97
+ ray_runtime_env: RuntimeEnv | None = None
98
+
99
+ inference_mode: bool = True # if False == training mode
100
+
101
+ # HuggingFace specific parameters
102
+ trust_remote_code: bool = False
103
+ revision: str | None = None
104
+
105
+ # Parallelism
106
+ num_gpus: int = 1
107
+ tp_size: int = -1
108
+ sp_size: int = -1
109
+ hsdp_replicate_dim: int = 1
110
+ hsdp_shard_dim: int = -1
111
+ dist_timeout: int | None = None # timeout for torch.distributed
112
+
113
+ pipeline_config: PipelineConfig = field(default_factory=PipelineConfig)
114
+ preprocess_config: PreprocessConfig | None = None
115
+
116
+ # LoRA parameters
117
+ # (Wenxuan) prefer to keep it here instead of in pipeline config to not make it complicated.
118
+ lora_path: str | None = None
119
+ lora_nickname: str = "default" # for swapping adapters in the pipeline
120
+ # can restrict layers to adapt, e.g. ["q_proj"]
121
+ # Will adapt only q, k, v, o by default.
122
+ lora_target_modules: list[str] | None = None
123
+
124
+ output_type: str = "pil"
125
+
126
+ # CPU offload parameters
127
+ dit_cpu_offload: bool = True
128
+ use_fsdp_inference: bool = False
129
+ dit_layerwise_offload: bool = True
130
+ text_encoder_cpu_offload: bool = True
131
+ image_encoder_cpu_offload: bool = True
132
+ vae_cpu_offload: bool = True
133
+ pin_cpu_memory: bool = True
134
+
135
+ # Compilation
136
+ enable_torch_compile: bool = False
137
+ torch_compile_kwargs: dict[str, Any] = field(default_factory=dict)
138
+
139
+ disable_autocast: bool = False
140
+
141
+ # VSA parameters
142
+ VSA_sparsity: float = 0.0 # inference/validation sparsity
143
+
144
+ # V-MoBA parameters
145
+ moba_config_path: str | None = None
146
+ moba_config: dict[str, Any] = field(default_factory=dict)
147
+
148
+ # Master port for distributed training/inference
149
+ master_port: int | None = None
150
+
151
+ # Stage verification
152
+ enable_stage_verification: bool = True
153
+
154
+ # Prompt text file for batch processing
155
+ prompt_txt: str | None = None
156
+
157
+ # LTX-2 VAE tiling overrides
158
+ ltx2_vae_tiling: bool | None = None
159
+ ltx2_vae_spatial_tile_size_in_pixels: int | None = None
160
+ ltx2_vae_spatial_tile_overlap_in_pixels: int | None = None
161
+ ltx2_vae_temporal_tile_size_in_frames: int | None = None
162
+ ltx2_vae_temporal_tile_overlap_in_frames: int | None = None
163
+ ltx2_initial_latent_path: str | None = None
164
+
165
+ # model paths for correct deallocation
166
+ model_paths: dict[str, str] = field(default_factory=dict)
167
+ model_loaded: dict[str, bool] = field(default_factory=lambda: {
168
+ "transformer": True,
169
+ "vae": True,
170
+ "upsampler": True,
171
+ })
172
+
173
+ override_text_encoder_safetensors: str | None = None # path to safetensors file for text encoder override
174
+ override_text_encoder_quant: QuantizationMethods = None
175
+ transformer_quant: QuantizationMethods = None
176
+
177
+ override_transformer_cls_name: str | None = None
178
+ init_weights_from_safetensors: str = "" # path to safetensors file for initial weight loading
179
+ init_weights_from_safetensors_2: str = "" # path to safetensors file for initial weight loading for transformer_2
180
+
181
+ override_pipeline_cls_name: str | None = None
182
+
183
+ # # DMD parameters
184
+ # dmd_denoising_steps: List[int] | None = field(default=None)
185
+
186
+ # MoE parameters used by Wan2.2
187
+ boundary_ratio: float = 0.875
188
+
189
+ @property
190
+ def training_mode(self) -> bool:
191
+ return not self.inference_mode
192
+
193
    def __post_init__(self):
        """Load the optional V-MoBA config, apply LTX-2 overrides, then validate."""
        if self.moba_config_path:
            try:
                with open(self.moba_config_path) as f:
                    self.moba_config = json.load(f)
                logger.info("Loaded V-MoBA config from %s", self.moba_config_path)
            except (FileNotFoundError, json.JSONDecodeError) as e:
                # Log and re-raise: a requested-but-broken config is fatal.
                logger.error("Failed to load V-MoBA config from %s: %s", self.moba_config_path, e)
                raise
        self._apply_ltx2_vae_overrides()
        # Final argument consistency checks (defined elsewhere in this class).
        self.check_fastvideo_args()
204
+
205
+ def __getattr__(self, name: str) -> Any:
206
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
207
+
208
    def _apply_ltx2_vae_overrides(self) -> None:
        """Push CLI-provided LTX-2 VAE tiling overrides into the pipeline/VAE configs."""
        if self.pipeline_config is None:
            return
        vae_config = self.pipeline_config.vae_config
        # True when any tile size/overlap override was supplied.
        has_any = any(value is not None for value in (
            self.ltx2_vae_spatial_tile_size_in_pixels,
            self.ltx2_vae_spatial_tile_overlap_in_pixels,
            self.ltx2_vae_temporal_tile_size_in_frames,
            self.ltx2_vae_temporal_tile_overlap_in_frames,
        ))
        # The explicit tiling flag wins; otherwise any tile override
        # implicitly enables tiling.
        if self.ltx2_vae_tiling is not None and hasattr(self.pipeline_config, "vae_tiling"):
            self.pipeline_config.vae_tiling = self.ltx2_vae_tiling
        elif has_any and hasattr(self.pipeline_config, "vae_tiling"):
            self.pipeline_config.vae_tiling = True

        # Each override is applied only when the VAE config exposes the knob
        # (non-LTX-2 VAE configs lack these attributes).
        if hasattr(vae_config,
                   "ltx2_spatial_tile_size_in_pixels") and self.ltx2_vae_spatial_tile_size_in_pixels is not None:
            vae_config.ltx2_spatial_tile_size_in_pixels = (self.ltx2_vae_spatial_tile_size_in_pixels)
        if hasattr(vae_config,
                   "ltx2_spatial_tile_overlap_in_pixels") and self.ltx2_vae_spatial_tile_overlap_in_pixels is not None:
            vae_config.ltx2_spatial_tile_overlap_in_pixels = (self.ltx2_vae_spatial_tile_overlap_in_pixels)
        if hasattr(vae_config,
                   "ltx2_temporal_tile_size_in_frames") and self.ltx2_vae_temporal_tile_size_in_frames is not None:
            vae_config.ltx2_temporal_tile_size_in_frames = (self.ltx2_vae_temporal_tile_size_in_frames)
        if hasattr(
                vae_config,
                "ltx2_temporal_tile_overlap_in_frames") and self.ltx2_vae_temporal_tile_overlap_in_frames is not None:
            vae_config.ltx2_temporal_tile_overlap_in_frames = (self.ltx2_vae_temporal_tile_overlap_in_frames)
236
+
237
+ @staticmethod
238
+ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
239
+ # Model and path configuration
240
+ parser.add_argument(
241
+ "--model-path",
242
+ type=str,
243
+ help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
244
+ )
245
+
246
+ # Running mode
247
+ parser.add_argument(
248
+ "--mode",
249
+ type=str,
250
+ choices=ExecutionMode.choices(),
251
+ default=FastVideoArgs.mode.value,
252
+ help="The mode to run FastVideo",
253
+ )
254
+
255
+ # Workload type
256
+ parser.add_argument(
257
+ "--workload-type",
258
+ type=str,
259
+ choices=WorkloadType.choices(),
260
+ default=FastVideoArgs.workload_type.value,
261
+ help="The workload type",
262
+ )
263
+
264
+ # distributed_executor_backend
265
+ parser.add_argument(
266
+ "--distributed-executor-backend",
267
+ type=str,
268
+ choices=["mp"],
269
+ default=FastVideoArgs.distributed_executor_backend,
270
+ help="The distributed executor backend to use",
271
+ )
272
+
273
+ parser.add_argument(
274
+ "--inference-mode",
275
+ action=StoreBoolean,
276
+ default=FastVideoArgs.inference_mode,
277
+ help="Whether to use inference mode",
278
+ )
279
+
280
+ # HuggingFace specific parameters
281
+ parser.add_argument(
282
+ "--trust-remote-code",
283
+ action=StoreBoolean,
284
+ default=FastVideoArgs.trust_remote_code,
285
+ help="Trust remote code when loading HuggingFace models",
286
+ )
287
+ parser.add_argument(
288
+ "--revision",
289
+ type=str,
290
+ default=FastVideoArgs.revision,
291
+ help="The specific model version to use (can be a branch name, tag name, or commit id)",
292
+ )
293
+
294
+ # Parallelism
295
+ parser.add_argument(
296
+ "--num-gpus",
297
+ type=int,
298
+ default=FastVideoArgs.num_gpus,
299
+ help="The number of GPUs to use.",
300
+ )
301
+ parser.add_argument(
302
+ "--tp-size",
303
+ type=int,
304
+ default=FastVideoArgs.tp_size,
305
+ help="The tensor parallelism size.",
306
+ )
307
+ parser.add_argument(
308
+ "--sp-size",
309
+ type=int,
310
+ default=FastVideoArgs.sp_size,
311
+ help="The sequence parallelism size.",
312
+ )
313
+ parser.add_argument(
314
+ "--hsdp-replicate-dim",
315
+ type=int,
316
+ default=FastVideoArgs.hsdp_replicate_dim,
317
+ help="The data parallelism size.",
318
+ )
319
+ parser.add_argument(
320
+ "--hsdp-shard-dim",
321
+ type=int,
322
+ default=FastVideoArgs.hsdp_shard_dim,
323
+ help="The data parallelism shards.",
324
+ )
325
+ parser.add_argument(
326
+ "--dist-timeout",
327
+ type=int,
328
+ default=FastVideoArgs.dist_timeout,
329
+ help="Set timeout for torch.distributed initialization.",
330
+ )
331
+
332
+ # Output type
333
+ parser.add_argument(
334
+ "--output-type",
335
+ type=str,
336
+ default=FastVideoArgs.output_type,
337
+ choices=["pil"],
338
+ help="Output type for the generated video",
339
+ )
340
+
341
+ # Prompt text file for batch processing
342
+ parser.add_argument(
343
+ "--prompt-txt",
344
+ type=str,
345
+ default=FastVideoArgs.prompt_txt,
346
+ help="Path to a text file containing prompts (one per line) for batch processing",
347
+ )
348
+
349
+ # LTX-2 VAE tiling overrides
350
+ parser.add_argument(
351
+ "--ltx2-vae-tiling",
352
+ action=StoreBoolean,
353
+ default=FastVideoArgs.ltx2_vae_tiling,
354
+ help="Enable LTX-2 VAE tiling overrides.",
355
+ )
356
+ parser.add_argument(
357
+ "--ltx2-vae-spatial-tile-size-in-pixels",
358
+ type=int,
359
+ default=FastVideoArgs.ltx2_vae_spatial_tile_size_in_pixels,
360
+ help="LTX-2 VAE spatial tile size in pixels.",
361
+ )
362
+ parser.add_argument(
363
+ "--ltx2-vae-spatial-tile-overlap-in-pixels",
364
+ type=int,
365
+ default=FastVideoArgs.ltx2_vae_spatial_tile_overlap_in_pixels,
366
+ help="LTX-2 VAE spatial tile overlap in pixels.",
367
+ )
368
+ parser.add_argument(
369
+ "--ltx2-vae-temporal-tile-size-in-frames",
370
+ type=int,
371
+ default=FastVideoArgs.ltx2_vae_temporal_tile_size_in_frames,
372
+ help="LTX-2 VAE temporal tile size in frames.",
373
+ )
374
+ parser.add_argument(
375
+ "--ltx2-vae-temporal-tile-overlap-in-frames",
376
+ type=int,
377
+ default=FastVideoArgs.ltx2_vae_temporal_tile_overlap_in_frames,
378
+ help="LTX-2 VAE temporal tile overlap in frames.",
379
+ )
380
+ parser.add_argument(
381
+ "--ltx2-initial-latent-path",
382
+ type=str,
383
+ default=FastVideoArgs.ltx2_initial_latent_path,
384
+ help="Path to load/save a precomputed LTX-2 initial latent.",
385
+ )
386
+
387
+ # LoRA parameters (inference-time adapter loading)
388
+ parser.add_argument(
389
+ "--lora-path",
390
+ type=str,
391
+ default=FastVideoArgs.lora_path,
392
+ help="Path to a LoRA adapter (directory or HF repo id). If set, LoRA will be applied at inference.",
393
+ )
394
+ parser.add_argument(
395
+ "--lora-nickname",
396
+ type=str,
397
+ default=FastVideoArgs.lora_nickname,
398
+ help="Nickname to refer to the loaded LoRA adapter (useful for swapping).",
399
+ )
400
+ parser.add_argument(
401
+ "--lora-target-modules",
402
+ nargs="+",
403
+ type=str,
404
+ default=FastVideoArgs.lora_target_modules,
405
+ help="Optional list of module name substrings to restrict LoRA injection (e.g. q_proj k_proj v_proj).",
406
+ )
407
+
408
+ # BSA runtime control (LongCat)
409
+ parser.add_argument(
410
+ "--enable-bsa",
411
+ action=StoreBoolean,
412
+ help="Enable Block Sparse Attention (BSA) at runtime (overrides config).",
413
+ )
414
+ parser.add_argument(
415
+ "--bsa-sparsity",
416
+ type=float,
417
+ help="BSA sparsity (e.g., 0.9375).",
418
+ )
419
+ parser.add_argument(
420
+ "--bsa-cdf-threshold",
421
+ type=float,
422
+ help="BSA CDF threshold (optional).",
423
+ )
424
+ parser.add_argument(
425
+ "--bsa-chunk-q",
426
+ nargs=3,
427
+ type=int,
428
+ metavar=("T", "H", "W"),
429
+ help="BSA chunk_3d_shape_q as three ints, e.g., 4 4 4.",
430
+ )
431
+ parser.add_argument(
432
+ "--bsa-chunk-k",
433
+ nargs=3,
434
+ type=int,
435
+ metavar=("T", "H", "W"),
436
+ help="BSA chunk_3d_shape_k as three ints, e.g., 4 4 4.",
437
+ )
438
+
439
+ parser.add_argument(
440
+ "--enable-torch-compile",
441
+ action=StoreBoolean,
442
+ default=FastVideoArgs.enable_torch_compile,
443
+ help="Use torch.compile to speed up DiT inference." +
444
+ "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)",
445
+ )
446
+ parser.add_argument(
447
+ "--torch-compile-kwargs",
448
+ type=str,
449
+ default=None,
450
+ help=
451
+ "JSON string of kwargs to pass to torch.compile. Example: '{\"backend\":\"inductor\",\"mode\":\"reduce-overhead\"}'",
452
+ )
453
+
454
+ parser.add_argument(
455
+ "--dit-cpu-offload",
456
+ action=StoreBoolean,
457
+ help="Use CPU offload for DiT inference. Enable if run out of memory with FSDP.",
458
+ )
459
+ parser.add_argument(
460
+ "--dit-layerwise-offload",
461
+ action=StoreBoolean,
462
+ help="Enable layerwise CPU offload with async H2D prefetch overlap.",
463
+ )
464
+ parser.add_argument(
465
+ "--use-fsdp-inference",
466
+ action=StoreBoolean,
467
+ help=
468
+ "Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.",
469
+ )
470
+ parser.add_argument(
471
+ "--text-encoder-cpu-offload",
472
+ action=StoreBoolean,
473
+ help="Use CPU offload for text encoder. Enable if run out of memory.",
474
+ )
475
+ parser.add_argument(
476
+ "--image-encoder-cpu-offload",
477
+ action=StoreBoolean,
478
+ help="Use CPU offload for image encoder. Enable if run out of memory.",
479
+ )
480
+ parser.add_argument(
481
+ "--vae-cpu-offload",
482
+ action=StoreBoolean,
483
+ help="Use CPU offload for VAE. Enable if run out of memory.",
484
+ )
485
+ parser.add_argument(
486
+ "--pin-cpu-memory",
487
+ action=StoreBoolean,
488
+ help=
489
+ "Pin memory for CPU offload. Only added as a temp workaround if it throws \"CUDA error: invalid argument\". "
490
+ "Should be enabled in almost all cases",
491
+ )
492
+ parser.add_argument(
493
+ "--disable-autocast",
494
+ action=StoreBoolean,
495
+ help="Disable autocast for denoising loop and vae decoding in pipeline sampling",
496
+ )
497
+
498
+ # VSA parameters
499
+ parser.add_argument(
500
+ "--VSA-sparsity",
501
+ type=float,
502
+ default=FastVideoArgs.VSA_sparsity,
503
+ help="Validation sparsity for VSA",
504
+ )
505
+
506
+ # Master port for distributed training/inference
507
+ parser.add_argument(
508
+ "--master-port",
509
+ type=int,
510
+ default=FastVideoArgs.master_port,
511
+ help="Master port for distributed training/inference",
512
+ )
513
+
514
+ # Stage verification
515
+ parser.add_argument(
516
+ "--enable-stage-verification",
517
+ action=StoreBoolean,
518
+ default=FastVideoArgs.enable_stage_verification,
519
+ help="Enable input/output verification for pipeline stages",
520
+ )
521
+ parser.add_argument(
522
+ "--override-text-encoder-safetensors",
523
+ type=str,
524
+ default=FastVideoArgs.override_text_encoder_safetensors,
525
+ help="Path to safetensors file for text encoder override",
526
+ )
527
+ parser.add_argument(
528
+ "--override-text-encoder-quant",
529
+ type=str,
530
+ choices=QUANTIZATION_METHODS,
531
+ default=FastVideoArgs.override_text_encoder_quant,
532
+ help="Quantization method for text encoder override",
533
+ )
534
+ parser.add_argument(
535
+ "--transformer-quant",
536
+ type=str,
537
+ choices=QUANTIZATION_METHODS,
538
+ default=FastVideoArgs.transformer_quant,
539
+ help="Quantization method for transformer loading",
540
+ )
541
+ parser.add_argument(
542
+ "--override-transformer-cls-name",
543
+ type=str,
544
+ default=FastVideoArgs.override_transformer_cls_name,
545
+ help="Override transformer cls name",
546
+ )
547
+ parser.add_argument(
548
+ "--override-pipeline-cls-name",
549
+ type=str,
550
+ default=FastVideoArgs.override_pipeline_cls_name,
551
+ help="Override pipeline cls name",
552
+ )
553
+ parser.add_argument("--init-weights-from-safetensors",
554
+ type=str,
555
+ help="Path to safetensors file for initial weight loading")
556
+ parser.add_argument("--init-weights-from-safetensors-2",
557
+ type=str,
558
+ help="Path to safetensors file for initial weight loading")
559
+
560
+ # Add pipeline configuration arguments
561
+ PipelineConfig.add_cli_args(parser)
562
+
563
+ # Add preprocessing configuration arguments
564
+ PreprocessConfig.add_cli_args(parser)
565
+
566
+ return parser
567
+
568
+ @classmethod
569
+ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
570
+ provided_args = clean_cli_args(args)
571
+ # Get all fields from the dataclass
572
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
573
+
574
+ # Create a dictionary of attribute values, with defaults for missing attributes
575
+ kwargs: dict[str, Any] = {}
576
+ for attr in attrs:
577
+ if attr == 'pipeline_config':
578
+ pipeline_config = PipelineConfig.from_kwargs(provided_args)
579
+ kwargs['pipeline_config'] = pipeline_config
580
+ elif attr == 'preprocess_config':
581
+ preprocess_config = PreprocessConfig.from_kwargs(provided_args)
582
+ kwargs['preprocess_config'] = preprocess_config
583
+ elif attr == 'mode':
584
+ # Convert string to ExecutionMode enum
585
+ mode_value = getattr(args, attr, FastVideoArgs.mode.value)
586
+ kwargs['mode'] = ExecutionMode.from_string(mode_value) if isinstance(mode_value, str) else mode_value
587
+ elif attr == 'torch_compile_kwargs':
588
+ # Parse JSON string for torch.compile kwargs
589
+ torch_compile_kwargs_str = getattr(args, 'torch_compile_kwargs', None)
590
+ if torch_compile_kwargs_str:
591
+ try:
592
+ import json
593
+ kwargs['torch_compile_kwargs'] = json.loads(torch_compile_kwargs_str)
594
+ except json.JSONDecodeError as e:
595
+ raise ValueError(f"Invalid JSON for torch_compile_kwargs: {e}") from e
596
+ else:
597
+ kwargs['torch_compile_kwargs'] = {}
598
+ elif attr == 'workload_type':
599
+ # Convert string to WorkloadType enum
600
+ workload_type_value = getattr(args, 'workload_type', FastVideoArgs.workload_type.value)
601
+ kwargs['workload_type'] = WorkloadType.from_string(workload_type_value) if isinstance(
602
+ workload_type_value, str) else workload_type_value
603
+ # Use getattr with default value from the dataclass for potentially missing attributes
604
+ else:
605
+ # Get the field to check if it has a default_factory
606
+ field = dataclasses.fields(cls)[next(i for i, f in enumerate(dataclasses.fields(cls))
607
+ if f.name == attr)]
608
+ if field.default_factory is not dataclasses.MISSING:
609
+ # Use the default_factory to create the default value
610
+ default_value = field.default_factory()
611
+ else:
612
+ default_value = getattr(cls, attr, None)
613
+ value = getattr(args, attr, default_value)
614
+ kwargs[attr] = value # type: ignore
615
+
616
+ return cls(**kwargs) # type: ignore
617
+
618
+ @classmethod
619
+ def from_kwargs(cls, **kwargs: Any) -> "FastVideoArgs":
620
+ # Convert mode string to enum if necessary
621
+ if 'mode' in kwargs and isinstance(kwargs['mode'], str):
622
+ kwargs['mode'] = ExecutionMode.from_string(kwargs['mode'])
623
+
624
+ # Convert workload_type string to enum if necessary
625
+ if 'workload_type' in kwargs and isinstance(kwargs['workload_type'], str):
626
+ kwargs['workload_type'] = WorkloadType.from_string(kwargs['workload_type'])
627
+
628
+ kwargs['pipeline_config'] = PipelineConfig.from_kwargs(kwargs)
629
+ kwargs['preprocess_config'] = PreprocessConfig.from_kwargs(kwargs)
630
+ # Filter to only FastVideoArgs dataclass fields — pipeline-specific CLI
631
+ # args (e.g. enable_bsa, bsa_sparsity) live in PipelineConfig and must
632
+ # not be forwarded to the FastVideoArgs constructor.
633
+ valid_fields = {f.name for f in dataclasses.fields(cls)}
634
+ return cls(**{k: v for k, v in kwargs.items() if k in valid_fields})
635
+
636
    def check_fastvideo_args(self) -> None:
        """Validate inference arguments for consistency.

        Mutates the config in place: normalizes mode/inference_mode, resolves
        ``-1`` parallelism placeholders, and validates the nested pipeline (and,
        for PREPROCESS mode, preprocess) configs. Raises AssertionError or
        ValueError on inconsistent settings.
        """
        from fastvideo.platforms import current_platform

        # FSDP/layerwise offload are not supported on Apple MPS.
        if current_platform.is_mps():
            self.use_fsdp_inference = False
            self.dit_layerwise_offload = False

        # Layerwise offload is mutually exclusive with FSDP and whole-model offload.
        if self.dit_layerwise_offload:
            if self.use_fsdp_inference:
                logger.warning("dit_layerwise_offload is enabled, automatically disabling use_fsdp_inference.")
                self.use_fsdp_inference = False
            if self.dit_cpu_offload:
                logger.warning("dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload.")
                self.dit_cpu_offload = False

        # Validate mode and inference_mode consistency
        assert isinstance(self.mode, ExecutionMode), f"Mode must be an ExecutionMode enum, got {type(self.mode)}"
        assert self.mode in ExecutionMode.choices(), f"Invalid execution mode: {self.mode}"

        # Validate workload type
        assert isinstance(self.workload_type,
                          WorkloadType), f"Workload type must be a WorkloadType enum, got {type(self.workload_type)}"
        assert self.workload_type in WorkloadType.choices(), f"Invalid workload type: {self.workload_type}"

        # The execution mode wins over an inconsistent inference_mode flag.
        if self.mode in [ExecutionMode.DISTILLATION, ExecutionMode.FINETUNING] and self.inference_mode:
            logger.warning("Mode is 'training' but inference_mode is True. Setting inference_mode to False.")
            self.inference_mode = False
        elif self.mode in [ExecutionMode.INFERENCE, ExecutionMode.PREPROCESS] and not self.inference_mode:
            logger.warning("Mode is '%s' but inference_mode is False. Setting inference_mode to True.", self.mode)
            self.inference_mode = True

        # Training requires all parallelism dims to be set explicitly (-1 = unset).
        if not self.inference_mode:
            assert self.hsdp_replicate_dim != -1, "hsdp_replicate_dim must be set for training"
            assert self.hsdp_shard_dim != -1, "hsdp_shard_dim must be set for training"
            assert self.sp_size != -1, "sp_size must be set for training"

        # Resolve -1 placeholders to defaults (inference path reaches here with them).
        if self.tp_size == -1:
            self.tp_size = 1
        if self.sp_size == -1:
            self.sp_size = self.num_gpus
        if self.hsdp_shard_dim == -1:
            self.hsdp_shard_dim = self.num_gpus

        assert self.sp_size <= self.num_gpus and self.num_gpus % self.sp_size == 0, "num_gpus must >= and be divisible by sp_size"
        assert self.hsdp_replicate_dim <= self.num_gpus and self.num_gpus % self.hsdp_replicate_dim == 0, "num_gpus must >= and be divisible by hsdp_replicate_dim"
        assert self.hsdp_shard_dim <= self.num_gpus and self.num_gpus % self.hsdp_shard_dim == 0, "num_gpus must >= and be divisible by hsdp_shard_dim"

        # Grow num_gpus to cover the requested parallelism if it was set too low.
        if self.num_gpus < max(self.tp_size, self.sp_size):
            self.num_gpus = max(self.tp_size, self.sp_size)

        if self.pipeline_config is None:
            raise ValueError("pipeline_config is not set in FastVideoArgs")

        self.pipeline_config.check_pipeline_config()

        # Add preprocessing config validation if needed
        if self.mode == ExecutionMode.PREPROCESS:
            if self.preprocess_config is None:
                raise ValueError("preprocess_config is not set in FastVideoArgs when mode is PREPROCESS")
            if self.preprocess_config.model_path == "":
                self.preprocess_config.model_path = self.model_path
            # Preprocessing always needs the VAE encoder loaded.
            if not self.pipeline_config.vae_config.load_encoder:
                self.pipeline_config.vae_config.load_encoder = True
            self.preprocess_config.check_preprocess_config()
701
+
702
+
703
# Process-global holder for the active FastVideoArgs; written by
# prepare_fastvideo_args() / set_current_fastvideo_args() and read by
# get_current_fastvideo_args().
_current_fastvideo_args = None
704
+
705
+
706
def prepare_fastvideo_args(argv: list[str]) -> FastVideoArgs:
    """Parse command-line arguments into a FastVideoArgs and make it current.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The inference arguments.
    """
    global _current_fastvideo_args

    arg_parser = FlexibleArgumentParser()
    FastVideoArgs.add_cli_args(arg_parser)
    namespace = arg_parser.parse_args(argv)

    parsed_args = FastVideoArgs.from_cli_args(namespace)
    _current_fastvideo_args = parsed_args
    return parsed_args
724
+
725
+
726
@contextmanager
def set_current_fastvideo_args(fastvideo_args: FastVideoArgs):
    """
    Temporarily set the current fastvideo config.
    Used during model initialization.
    We save the current fastvideo config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the fastvideo config to determine how to dispatch.
    """
    global _current_fastvideo_args
    previous_args = _current_fastvideo_args
    _current_fastvideo_args = fastvideo_args
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        _current_fastvideo_args = previous_args
742
+
743
+
744
def get_current_fastvideo_args() -> FastVideoArgs:
    """Return the process-global FastVideoArgs, raising if none is set."""
    current = _current_fastvideo_args
    if current is None:
        # in ci, usually when we test custom ops/modules directly,
        # we don't set the fastvideo config. In that case, we set a default
        # config.
        # TODO(will): may need to handle this for CI.
        raise ValueError("Current fastvideo args is not set.")
    return current
752
+
753
+
754
@dataclasses.dataclass
class TrainingArgs(FastVideoArgs):
    """
    Training arguments. Inherits from FastVideoArgs and adds training-specific
    arguments. If there are any conflicts, the training arguments will take
    precedence.
    """
    # Dataset / dataloader
    data_path: str = ""
    dataloader_num_workers: int = 0
    num_height: int = 0
    num_width: int = 0
    num_frames: int = 0

    train_batch_size: int = 0
    num_latent_t: int = 0
    group_frame: bool = False
    group_resolution: bool = False

    # text encoder & vae & diffusion model
    pretrained_model_name_or_path: str = ""

    # DMD model paths - separate paths for each network
    real_score_model_path: str = ""  # path for real score (teacher) model
    fake_score_model_path: str = ""  # path for fake score (critic) model

    # diffusion setting
    ema_decay: float = 0.0
    ema_start_step: int = 0
    training_cfg_rate: float = 0.0
    precondition_outputs: bool = False

    # validation & logs
    validation_dataset_file: str = ""
    validation_preprocessed_path: str = ""
    validation_sampling_steps: str = ""
    validation_guidance_scale: str = ""
    validation_steps: float = 0.0
    log_validation: bool = False
    trackers: list[str] = dataclasses.field(default_factory=list)
    tracker_project_name: str = ""
    wandb_run_name: str = ""
    seed: int = 0
    _loading_teacher_critic_model: bool = False

    # output
    output_dir: str = ""
    checkpoints_total_limit: int = 0
    resume_from_checkpoint: str = ""  # specify the checkpoint folder to resume from

    # optimizer & scheduler
    num_train_epochs: int = 0
    max_train_steps: int = 0
    gradient_accumulation_steps: int = 0
    learning_rate: float = 0.0
    scale_lr: bool = False
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 0
    max_grad_norm: float = 0.0
    enable_gradient_checkpointing_type: str | None = None
    selective_checkpointing: float = 0.0
    mixed_precision: str = ""
    train_sp_batch_size: int = 0
    # NOTE: 'startegy' misspelling kept — renaming would break existing callers/configs.
    fsdp_sharding_startegy: str = ""

    # flow-matching loss weighting (see --weighting_scheme CLI choices)
    weighting_scheme: str = ""
    logit_mean: float = 0.0
    logit_std: float = 1.0
    mode_scale: float = 0.0

    num_euler_timesteps: int = 0
    lr_num_cycles: int = 0
    lr_power: float = 0.0
    min_lr_ratio: float = 0.5  # minimum learning rate ratio for cosine_with_min_lr scheduler
    not_apply_cfg_solver: bool = False
    distill_cfg: float = 0.0
    scheduler_type: str = ""
    linear_quadratic_threshold: float = 0.0
    linear_range: float = 0.0
    weight_decay: float = 0.0
    betas: str = "0.9,0.999"  # betas for optimizer, format: "beta1,beta2"
    use_ema: bool = False
    multi_phased_distill_schedule: str = ""
    pred_decay_weight: float = 0.0
    pred_decay_type: str = ""
    hunyuan_teacher_disable_cfg: bool = False

    # master_weight_type
    master_weight_type: str = ""

    # VSA training decay parameters
    VSA_decay_rate: float = 0.01  # decay rate -> 0.02
    VSA_decay_interval_steps: int = 1  # decay interval steps -> 50
    VSA_init_sparsity: float = 0.0  # initial sparsity (default 0, ramp from 0)
    VSA_warmup_steps: int = 0  # keep init_sparsity for this many steps before ramping

    # LoRA training parameters
    lora_rank: int | None = None
    lora_alpha: int | None = None
    lora_training: bool = False
    ltx2_first_frame_conditioning_p: float = 0.1

    # distillation args
    generator_update_interval: int = 5
    dfake_gen_update_ratio: int = 5  # self-forcing: how often to train generator vs critic
    min_timestep_ratio: float = 0.2
    max_timestep_ratio: float = 0.98
    real_score_guidance_scale: float = 3.5
    fake_score_learning_rate: float = 0.0  # separate learning rate for fake_score_transformer, if 0.0, use learning_rate
    fake_score_lr_scheduler: str = "constant"  # separate lr scheduler for fake_score_transformer, if not set, use lr_scheduler
    fake_score_betas: str = "0.9,0.999"  # betas for fake score optimizer, format: "beta1,beta2"
    training_state_checkpointing_steps: int = 0  # for resuming training
    weight_only_checkpointing_steps: int = 0  # for inference
    log_visualization: bool = False
    visualization_steps: int = 0
    # simulate generator forward to match inference
    simulate_generator_forward: bool = False
    warp_denoising_step: bool = False
    generator_4bit_attn: bool = False
    generator_4bit_linear: bool = False

    # Self-forcing specific arguments
    num_frame_per_block: int = 3
    independent_first_frame: bool = False
    enable_gradient_masking: bool = True
    gradient_mask_last_n_frames: int = 21
    same_step_across_blocks: bool = False  # Use same exit timestep for all blocks
    last_step_only: bool = False  # Only use the last timestep for training
    context_noise: int = 0  # Context noise level for cache updates
883
+ @classmethod
884
+ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
885
+ provided_args = clean_cli_args(args)
886
+ # Get all fields from the dataclass
887
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
888
+ logger.info(provided_args)
889
+ # Create a dictionary of attribute values, with defaults for missing attributes
890
+ kwargs: dict[str, Any] = {}
891
+ for attr in attrs:
892
+ if attr == 'pipeline_config':
893
+ pipeline_config = PipelineConfig.from_kwargs(provided_args)
894
+ kwargs[attr] = pipeline_config
895
+ elif attr == 'mode':
896
+ # Convert string to ExecutionMode enum
897
+ mode_value = getattr(args, attr, ExecutionMode.FINETUNING.value)
898
+ kwargs[attr] = ExecutionMode.from_string(mode_value) if isinstance(mode_value, str) else mode_value
899
+ elif attr == 'workload_type':
900
+ # Convert string to WorkloadType enum
901
+ workload_type_value = getattr(args, 'workload_type', WorkloadType.T2V.value)
902
+ kwargs[attr] = WorkloadType.from_string(workload_type_value) if isinstance(workload_type_value,
903
+ str) else workload_type_value
904
+ # Use getattr with default value from the dataclass for potentially missing attributes
905
+ else:
906
+ # Get the field to check its default value
907
+ field = dataclasses.fields(cls)[next(i for i, f in enumerate(dataclasses.fields(cls))
908
+ if f.name == attr)]
909
+
910
+ # Check if the attribute is provided in args
911
+ if hasattr(args, attr):
912
+ value = getattr(args, attr)
913
+ else:
914
+ # Use the field's default value
915
+ if field.default_factory is not dataclasses.MISSING:
916
+ value = field.default_factory()
917
+ elif field.default is not dataclasses.MISSING:
918
+ value = field.default
919
+ else:
920
+ # No default value, use None
921
+ value = None
922
+
923
+ kwargs[attr] = value
924
+
925
+ return cls(**kwargs) # type: ignore
926
+
927
+ @staticmethod
928
+ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
929
+ parser.add_argument("--data-path", type=str, required=True, help="Path to parquet files")
930
+ parser.add_argument("--dataloader-num-workers",
931
+ type=int,
932
+ required=True,
933
+ help="Number of workers for dataloader")
934
+ parser.add_argument("--num-height", type=int, required=True, help="Number of heights")
935
+ parser.add_argument("--num-width", type=int, required=True, help="Number of widths")
936
+ parser.add_argument("--num-frames", type=int, required=True, help="Number of frames")
937
+
938
+ # Training batch and model configuration
939
+ parser.add_argument("--train-batch-size", type=int, required=True, help="Training batch size")
940
+ parser.add_argument("--num-latent-t", type=int, required=True, help="Number of latent time steps")
941
+ parser.add_argument("--group-frame", action=StoreBoolean, help="Whether to group frames during training")
942
+ parser.add_argument("--group-resolution",
943
+ action=StoreBoolean,
944
+ help="Whether to group resolutions during training")
945
+
946
+ # Model paths
947
+ parser.add_argument("--pretrained-model-name-or-path",
948
+ type=str,
949
+ required=True,
950
+ help="Path to pretrained model or model name")
951
+ parser.add_argument("--dit-model-name-or-path",
952
+ type=str,
953
+ required=False,
954
+ help="Path to DiT model or model name")
955
+ parser.add_argument("--cache-dir", type=str, help="Directory to cache models")
956
+
957
+ # DMD model paths - separate paths for each network
958
+ parser.add_argument("--generator-model-path",
959
+ type=str,
960
+ help="Path to generator (student) model for DMD distillation")
961
+ parser.add_argument("--real-score-model-path",
962
+ type=str,
963
+ help="Path to real score (teacher) model for DMD distillation")
964
+ parser.add_argument("--fake-score-model-path",
965
+ type=str,
966
+ help="Path to fake score (critic) model for DMD distillation")
967
+
968
+ # Diffusion settings
969
+ parser.add_argument("--ema-decay", type=float, default=0.999, help="EMA decay rate")
970
+ parser.add_argument("--ema-start-step", type=int, default=0, help="Step to start EMA")
971
+ parser.add_argument("--training-cfg-rate", type=float, help="Classifier-free guidance scale")
972
+ parser.add_argument("--precondition-outputs",
973
+ action=StoreBoolean,
974
+ help="Whether to precondition the outputs of the model")
975
+
976
+ # Validation and logging
977
+ parser.add_argument("--validation-dataset-file", type=str, help="Path to unprocessed validation dataset")
978
+ parser.add_argument("--validation-preprocessed-path", type=str, help="Path to processed validation dataset")
979
+ parser.add_argument("--validation-sampling-steps", type=str, help="Validation sampling steps")
980
+ parser.add_argument("--validation-guidance-scale", type=str, help="Validation guidance scale")
981
+ parser.add_argument("--validation-steps", type=float, help="Number of validation steps")
982
+ parser.add_argument("--log-validation", action=StoreBoolean, help="Whether to log validation results")
983
+ parser.add_argument("--visualization-steps", type=int, help="Number of visualization steps")
984
+ parser.add_argument("--tracker-project-name", type=str, help="Project name for tracking")
985
+ parser.add_argument("--wandb-run-name", type=str, help="Run name for wandb")
986
+ parser.add_argument("--seed", type=int, default=42, help="Seed for deterministic training")
987
+
988
+ # Output configuration
989
+ parser.add_argument("--output-dir", type=str, required=True, help="Output directory for checkpoints and logs")
990
+ parser.add_argument("--checkpoints-total-limit", type=int, help="Maximum number of checkpoints to keep")
991
+ parser.add_argument("--training-state-checkpointing-steps",
992
+ type=int,
993
+ help="Steps between training state checkpoints (for resuming training)")
994
+ parser.add_argument("--weight-only-checkpointing-steps",
995
+ type=int,
996
+ help="Steps between weight-only checkpoints (for inference)")
997
+ parser.add_argument("--resume-from-checkpoint", type=str, help="Path to checkpoint to resume from")
998
+ parser.add_argument("--logging-dir", type=str, help="Directory for logging")
999
+
1000
+ # Training configuration
1001
+ parser.add_argument("--num-train-epochs", type=int, help="Number of training epochs")
1002
+ parser.add_argument("--max-train-steps", type=int, help="Maximum number of training steps")
1003
+ parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of steps to accumulate gradients")
1004
+ parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
1005
+ parser.add_argument("--scale-lr", action=StoreBoolean, help="Whether to scale learning rate")
1006
+ parser.add_argument("--lr-scheduler", type=str, default="constant", help="Learning rate scheduler type")
1007
+ parser.add_argument("--lr-warmup-steps", type=int, default=10, help="Number of warmup steps for learning rate")
1008
+ parser.add_argument("--max-grad-norm", type=float, help="Maximum gradient norm")
1009
+ parser.add_argument("--enable-gradient-checkpointing-type",
1010
+ type=str,
1011
+ choices=["full", "ops", "block_skip"],
1012
+ default=None,
1013
+ help="Gradient checkpointing type")
1014
+ parser.add_argument("--selective-checkpointing", type=float, help="Selective checkpointing threshold")
1015
+ parser.add_argument("--mixed-precision", type=str, help="Mixed precision training type")
1016
+ parser.add_argument("--train-sp-batch-size", type=int, help="Training spatial parallelism batch size")
1017
+
1018
+ parser.add_argument("--fsdp-sharding-strategy", type=str, help="FSDP sharding strategy")
1019
+
1020
+ parser.add_argument(
1021
+ "--weighting_scheme",
1022
+ type=str,
1023
+ default="uniform",
1024
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
1025
+ )
1026
+ parser.add_argument(
1027
+ "--logit_mean",
1028
+ type=float,
1029
+ default=0.0,
1030
+ help="mean to use when using the `'logit_normal'` weighting scheme.",
1031
+ )
1032
+ parser.add_argument(
1033
+ "--logit_std",
1034
+ type=float,
1035
+ default=1.0,
1036
+ help="std to use when using the `'logit_normal'` weighting scheme.",
1037
+ )
1038
+ parser.add_argument(
1039
+ "--mode_scale",
1040
+ type=float,
1041
+ default=1.29,
1042
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
1043
+ )
1044
+
1045
+ # Additional training parameters
1046
+ parser.add_argument("--num-euler-timesteps", type=int, help="Number of Euler timesteps")
1047
+ parser.add_argument("--lr-num-cycles", type=int, help="Number of learning rate cycles")
1048
+ parser.add_argument("--lr-power", type=float, help="Learning rate power")
1049
+ parser.add_argument("--min-lr-ratio",
1050
+ type=float,
1051
+ default=TrainingArgs.min_lr_ratio,
1052
+ help="Minimum learning rate ratio for cosine_with_min_lr scheduler")
1053
+ parser.add_argument("--not-apply-cfg-solver", action=StoreBoolean, help="Whether to not apply CFG solver")
1054
+ parser.add_argument("--distill-cfg", type=float, help="Distillation CFG scale")
1055
+ parser.add_argument("--scheduler-type", type=str, help="Scheduler type")
1056
+ parser.add_argument("--linear-quadratic-threshold", type=float, help="Linear quadratic threshold")
1057
+ parser.add_argument("--linear-range", type=float, help="Linear range")
1058
+ parser.add_argument("--weight-decay", type=float, help="Weight decay")
1059
+ parser.add_argument("--betas",
1060
+ type=str,
1061
+ default=TrainingArgs.betas,
1062
+ help="Betas for optimizer (format: 'beta1,beta2')")
1063
+ parser.add_argument("--use-ema", action=StoreBoolean, help="Whether to use EMA")
1064
+ parser.add_argument("--multi-phased-distill-schedule", type=str, help="Multi-phased distillation schedule")
1065
+ parser.add_argument("--pred-decay-weight", type=float, help="Prediction decay weight")
1066
+ parser.add_argument("--pred-decay-type", type=str, help="Prediction decay type")
1067
+ parser.add_argument("--hunyuan-teacher-disable-cfg",
1068
+ action=StoreBoolean,
1069
+ help="Whether to disable CFG for Hunyuan teacher")
1070
+ parser.add_argument("--master-weight-type", type=str, help="Master weight type")
1071
+
1072
+ # VSA parameters for training with dense to sparse adaption
1073
+ parser.add_argument(
1074
+ "--VSA-decay-rate", # decay rate, how much sparsity you want to decay each step
1075
+ type=float,
1076
+ default=TrainingArgs.VSA_decay_rate,
1077
+ help="VSA decay rate")
1078
+ parser.add_argument(
1079
+ "--VSA-decay-interval-steps", # how many steps for training with current sparsity
1080
+ type=int,
1081
+ default=TrainingArgs.VSA_decay_interval_steps,
1082
+ help="VSA decay interval steps")
1083
+ parser.add_argument(
1084
+ "--VSA-init-sparsity",
1085
+ type=float,
1086
+ default=TrainingArgs.VSA_init_sparsity,
1087
+ help="Initial sparsity to start from (default 0)")
1088
+ parser.add_argument(
1089
+ "--VSA-warmup-steps",
1090
+ type=int,
1091
+ default=TrainingArgs.VSA_warmup_steps,
1092
+ help="Keep init sparsity for N steps before ramping (default 0)")
1093
+ parser.add_argument("--lora-training", action=StoreBoolean, help="Whether to use LoRA training")
1094
+ parser.add_argument("--lora-rank", type=int, help="LoRA rank")
1095
+ parser.add_argument("--lora-alpha", type=int, help="LoRA alpha")
1096
+ parser.add_argument(
1097
+ "--ltx2-first-frame-conditioning-p",
1098
+ type=float,
1099
+ default=TrainingArgs.ltx2_first_frame_conditioning_p,
1100
+ help="Probability of conditioning on the first frame during LTX-2 training",
1101
+ )
1102
+
1103
+ # V-MoBA parameters
1104
+ parser.add_argument(
1105
+ "--moba-config-path",
1106
+ type=str,
1107
+ default=None,
1108
+ help="Path to a JSON file containing V-MoBA specific configurations.",
1109
+ )
1110
+
1111
+ # Distillation arguments
1112
+ parser.add_argument("--generator-update-interval",
1113
+ type=int,
1114
+ default=TrainingArgs.generator_update_interval,
1115
+ help="Ratio of student updates to critic updates.")
1116
+ parser.add_argument(
1117
+ "--dfake-gen-update-ratio",
1118
+ type=int,
1119
+ default=TrainingArgs.dfake_gen_update_ratio,
1120
+ help="Self-forcing: How often to train generator vs critic (train generator every N steps).")
1121
+ parser.add_argument("--min-timestep-ratio",
1122
+ type=float,
1123
+ default=TrainingArgs.min_timestep_ratio,
1124
+ help="Minimum step ratio")
1125
+ parser.add_argument("--max-timestep-ratio",
1126
+ type=float,
1127
+ default=TrainingArgs.max_timestep_ratio,
1128
+ help="Maximum step ratio")
1129
+ parser.add_argument("--real-score-guidance-scale",
1130
+ type=float,
1131
+ default=TrainingArgs.real_score_guidance_scale,
1132
+ help="Teacher guidance scale")
1133
+ parser.add_argument("--fake-score-learning-rate",
1134
+ type=float,
1135
+ default=TrainingArgs.fake_score_learning_rate,
1136
+ help="Learning rate for fake score transformer")
1137
+ parser.add_argument("--fake-score-betas",
1138
+ type=str,
1139
+ default=TrainingArgs.fake_score_betas,
1140
+ help="Betas for fake score optimizer (format: 'beta1,beta2')")
1141
+ parser.add_argument("--fake-score-lr-scheduler",
1142
+ type=str,
1143
+ default=TrainingArgs.fake_score_lr_scheduler,
1144
+ help="Learning rate scheduler for fake score transformer")
1145
+ parser.add_argument("--log-visualization", action=StoreBoolean, help="Whether to log visualization")
1146
+ parser.add_argument("--simulate-generator-forward",
1147
+ action=StoreBoolean,
1148
+ help="Whether to simulate generator forward to match inference")
1149
+ parser.add_argument("--warp-denoising-step",
1150
+ action=StoreBoolean,
1151
+ help="Whether to warp denoising step according to the scheduler time shift")
1152
+
1153
+ # Self-forcing specific arguments
1154
+ parser.add_argument("--num-frame-per-block",
1155
+ type=int,
1156
+ default=TrainingArgs.num_frame_per_block,
1157
+ help="Number of frames per block for causal generation")
1158
+ parser.add_argument("--independent-first-frame",
1159
+ action=StoreBoolean,
1160
+ help="Whether the first frame is independent in causal generation")
1161
+ parser.add_argument("--enable-gradient-masking",
1162
+ action=StoreBoolean,
1163
+ help="Whether to enable frame-level gradient masking")
1164
+ parser.add_argument("--gradient-mask-last-n-frames",
1165
+ type=int,
1166
+ default=TrainingArgs.gradient_mask_last_n_frames,
1167
+ help="Number of last frames to enable gradients for")
1168
+ parser.add_argument("--validate-cache-structure",
1169
+ action=StoreBoolean,
1170
+ help="Whether to validate KV cache structure (debug flag)")
1171
+ parser.add_argument("--same-step-across-blocks",
1172
+ action=StoreBoolean,
1173
+ help="Whether to use the same exit timestep for all blocks")
1174
+ parser.add_argument("--last-step-only",
1175
+ action=StoreBoolean,
1176
+ help="Whether to only use the last timestep for training")
1177
+ parser.add_argument("--context-noise",
1178
+ type=int,
1179
+ default=TrainingArgs.context_noise,
1180
+ help="Context noise level for cache updates")
1181
+
1182
+ return parser
1183
+
1184
+
1185
def parse_int_list(value: str) -> list[int]:
    """Parse a comma-separated string of integers into a list.

    Falsy input (empty string) yields an empty list. Blank segments —
    e.g. from a trailing or doubled comma such as ``"1,2,"`` — are
    skipped instead of raising ``ValueError`` from ``int("")``.

    Args:
        value: Comma-separated integers, e.g. ``"1, 2, 3"``.

    Returns:
        The parsed integers, in order of appearance.
    """
    if not value:
        return []
    # strip() tolerates spaces around each number; the truthiness filter
    # skips empty segments so inputs like "1,,2" or "1,2," do not crash.
    return [int(part) for part in (x.strip() for x in value.split(",")) if part]
standalone_inference/overlay_files/fastvideo/forward_context.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py
3
+
4
+ import time
5
+ from collections import defaultdict
6
+ from contextlib import contextmanager
7
+ from dataclasses import dataclass
8
+ from typing import TYPE_CHECKING, Optional
9
+
10
+ import torch
11
+
12
+ from fastvideo.logger import init_logger
13
+
14
+ if TYPE_CHECKING:
15
+ from fastvideo.attention import AttentionMetadata
16
+ from fastvideo.pipelines import ForwardBatch
17
+
18
+ logger = init_logger(__name__)
19
+
20
# TODO(will): check if this is needed
# track_batchsize: bool = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL >= 0
# Module-level state used by set_forward_context() to optionally record
# per-batchsize forward-pass timings. Tracking is disabled by default.
track_batchsize: bool = False
# perf_counter() timestamps (seconds) of the last stats log and the start
# of the current forward pass, respectively.
last_logging_time: float = 0
forward_start_time: float = 0
# batchsize_logging_interval: float = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL
# Minimum number of seconds between emitted timing summaries.
batchsize_logging_interval: float = 1000
# Maps batchsize -> list of recorded forward-pass durations in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)
28
+
29
+
30
#
@dataclass
class ForwardContext:
    """Per-forward-pass context made available to model internals.

    An instance is installed as the module-global context by
    set_forward_context() for the duration of one model forward pass and
    retrieved by layers via get_forward_context().
    """
    # Denoising timestep index of the current forward pass.
    current_timestep: int
    # TODO(will): check this arg
    # copy from vllm_config.compilation_config.static_forward_context
    # attn_layers: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
    # Optional batch payload, when a stage wants access to pipeline inputs.
    forward_batch: Optional["ForwardBatch"] = None
    # NOTE(review): presumably tells attention backends to skip sparse
    # attention and run dense — confirm against backend usage.
    force_dense: bool = False
41
+
42
+
43
# Process-global slot holding the active ForwardContext (None when no
# forward pass is in flight). Managed exclusively by set_forward_context().
_forward_context: Optional["ForwardContext"] = None


def get_forward_context() -> "ForwardContext":
    """Return the ForwardContext of the forward pass currently in flight.

    Raises:
        AssertionError: if called outside a `set_forward_context` scope.
            NOTE: the check uses `assert`, so it is stripped under `python -O`.
    """
    assert _forward_context is not None, ("Forward context is not set. "
                                          "Please use `set_forward_context` to set the forward context.")
    return _forward_context
51
+
52
+
53
# TODO(will): finalize the interface
@contextmanager
def set_forward_context(current_timestep, attn_metadata, forward_batch: Optional["ForwardBatch"] = None, force_dense: bool = False):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    Installs a ForwardContext into the module-global slot for the duration
    of the `with` body and restores the previous context on exit (so nested
    usage behaves like a stack). Optionally records per-batchsize forward
    timing when the module-level `track_batchsize` flag is on.
    """
    global forward_start_time
    # Only time the pass when tracking is enabled and there is metadata
    # to derive a batchsize from.
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    global _forward_context
    # Remember the previous context so it can be restored in `finally`.
    prev_context = _forward_context
    _forward_context = ForwardContext(current_timestep=current_timestep,
                                      attn_metadata=attn_metadata,
                                      forward_batch=forward_batch,
                                      force_dense=force_dense)

    try:
        yield
    finally:
        global last_logging_time, batchsize_logging_interval
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = attn_metadata.num_input_tokens
            now = time.perf_counter()
            # time measurement is in milliseconds
            batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)
            # Emit a summary at most once per batchsize_logging_interval seconds.
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # can be cudagraph / profiling run
                        continue
                    # Median of recorded times, rounded to 2 decimal places.
                    medium = torch.quantile(torch.tensor(times), q=0.5).item()
                    medium = round(medium, 2)
                    forward_stats.append((bs, len(times), medium))
                # Most frequently seen batchsizes first.
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"), forward_stats)
        # Restore whatever context was active before entering this scope.
        _forward_context = prev_context
standalone_inference/overlay_files/fastvideo/pipelines/basic/wan/__init__.py ADDED
File without changes
standalone_inference/overlay_files/fastvideo/pipelines/basic/wan/wan_pipeline.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Wan video diffusion pipeline implementation.
4
+
5
+ This module contains an implementation of the Wan video diffusion pipeline
6
+ using the modular pipeline architecture.
7
+ """
8
+
9
+ from fastvideo.fastvideo_args import FastVideoArgs
10
+ from fastvideo.logger import init_logger
11
+ from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (FlowUniPCMultistepScheduler)
12
+ from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline
13
+ from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, DenoisingStage, InputValidationStage,
14
+ LatentPreparationStage, TextEncodingStage, TimestepPreparationStage)
15
+
16
+ logger = init_logger(__name__)
17
+
18
+
19
class WanPipeline(LoRAPipeline, ComposedPipelineBase):
    """
    Wan video diffusion pipeline with LoRA support.

    Composes the standard text-to-video stages (validation, text encoding,
    conditioning, timestep/latent preparation, denoising, decoding) on top
    of the modules declared in the model's model_index.json.
    """

    # Module names that must resolve from the model directory's
    # model_index.json before the pipeline can run.
    _required_config_modules = ["text_encoder", "tokenizer", "vae", "transformer", "scheduler"]

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        """Replace the loaded scheduler with the Wan-official UniPC scheduler."""
        # We use UniPCMScheduler from Wan2.1 official repo, not the one in diffusers.
        self.modules["scheduler"] = FlowUniPCMultistepScheduler(shift=fastvideo_args.pipeline_config.flow_shift)

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
        """Set up pipeline stages with proper dependency injection.

        NOTE: registration order matters — ComposedPipelineBase.forward()
        executes stages in the order they are added here.
        """

        self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

        self.add_stage(stage_name="prompt_encoding_stage",
                       stage=TextEncodingStage(
                           text_encoders=[self.get_module("text_encoder")],
                           tokenizers=[self.get_module("tokenizer")],
                       ))

        self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

        self.add_stage(stage_name="timestep_preparation_stage",
                       stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

        self.add_stage(stage_name="latent_preparation_stage",
                       stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                    transformer=self.get_module("transformer", None)))

        # transformer_2 is optional; present only for MoE checkpoints
        # (see ComposedPipelineBase.load_modules boundary_ratio handling).
        self.add_stage(stage_name="denoising_stage",
                       stage=DenoisingStage(transformer=self.get_module("transformer"),
                                            transformer_2=self.get_module("transformer_2", None),
                                            scheduler=self.get_module("scheduler"),
                                            vae=self.get_module("vae"),
                                            pipeline=self))

        self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))


# Entry point class discovered by the pipeline loader.
EntryClass = WanPipeline
+ EntryClass = WanPipeline
standalone_inference/overlay_files/fastvideo/pipelines/composed_pipeline_base.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Base class for composed pipelines.
4
+
5
+ This module defines the base class for pipelines that are composed of multiple stages.
6
+ """
7
+
8
+ import argparse
9
+ import os
10
+ from abc import ABC, abstractmethod
11
+ from typing import Any, cast
12
+
13
+ import torch
14
+
15
+ from fastvideo.configs.pipelines import PipelineConfig
16
+ from fastvideo.distributed import (maybe_init_distributed_environment_and_model_parallel, get_world_group)
17
+ from fastvideo.distributed.communication_op import (warmup_sequence_parallel_communication)
18
+ from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
19
+ from fastvideo.logger import init_logger
20
+ from fastvideo.profiler import get_or_create_profiler
21
+ from fastvideo.models.loader.component_loader import PipelineComponentLoader
22
+ from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
23
+ from fastvideo.pipelines.stages import PipelineStage
24
+ import fastvideo.envs as envs
25
+ from fastvideo.utils import (maybe_download_model, verify_model_config_and_directory)
26
+
27
+ logger = init_logger(__name__)
28
+
29
+
30
class ComposedPipelineBase(ABC):
    """
    Base class for pipelines composed of multiple stages.

    This class provides the framework for creating pipelines by composing multiple
    stages together. Each stage is responsible for a specific part of the diffusion
    process, and the pipeline orchestrates the execution of these stages.
    """

    is_video_pipeline: bool = False  # To be overridden by video pipelines
    # NOTE(review): these are class-level (shared) defaults. __init__ rebinds
    # `modules`, `fastvideo_args`, etc. per instance, but the mutable dict
    # defaults are shared across instances until rebound — confirm subclasses
    # always assign `trainable_transformer_modules` per instance.
    _required_config_modules: list[str] = []
    _extra_config_module_map: dict[str, str] = {}
    training_args: Any = None
    fastvideo_args: Any = None
    modules: dict[str, Any] = {}
    # do not need to include moe related transformers
    trainable_transformer_names: list[str] = ["transformer"]
    trainable_transformer_modules: dict[str, torch.nn.Module] = {}
    post_init_called: bool = False

    # TODO(will): args should support both inference args and training args
    def __init__(self,
                 model_path: str,
                 fastvideo_args: FastVideoArgs | TrainingArgs,
                 required_config_modules: list[str] | None = None,
                 loaded_modules: dict[str, torch.nn.Module] | None = None):
        """
        Initialize the pipeline. After __init__, the pipeline should be ready to
        use. The pipeline should be stateless and not hold any batch state.

        Args:
            model_path: Local path or model id of the pretrained pipeline.
            fastvideo_args: Inference or training configuration.
            required_config_modules: Overrides the class-level required module list.
            loaded_modules: Pre-built modules to use instead of loading from disk.
        """
        self.fastvideo_args = fastvideo_args

        self.model_path: str = model_path
        self._stages: list[PipelineStage] = []
        self._stage_name_mapping: dict[str, PipelineStage] = {}

        if required_config_modules is not None:
            self._required_config_modules = required_config_modules

        if self._required_config_modules is None:
            raise NotImplementedError("Subclass must set _required_config_modules")

        # Distributed/model-parallel groups must exist before any module loads.
        maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

        # Torch profiler. Enabled and configured through env vars:
        # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
        trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
        self.profiler_controller = get_or_create_profiler(trace_dir)
        self.profiler = self.profiler_controller.profiler

        self.local_rank = get_world_group().local_rank

        # Load modules directly in initialization
        logger.info("Loading pipeline modules...")
        with self.profiler_controller.region("profiler_region_model_loading"):
            self.modules = self.load_modules(fastvideo_args, loaded_modules)

    def set_trainable(self) -> None:
        """Enable gradients and train() mode on the trainable transformer modules."""
        # Only train DiT
        if getattr(self.fastvideo_args, "training_mode", False):
            for name, module in self.trainable_transformer_modules.items():
                logger.info("Setting %s to requires_grad=True", name)
                if not isinstance(module, torch.nn.Module):
                    logger.info("Skipping %s because it is not a torch.nn.Module", name)
                    continue
                module.requires_grad_(True)
                module.train()

    @staticmethod
    def _compile_with_conditions(
        module: torch.nn.Module,
        compile_kwargs: dict[str, Any],
    ) -> int:
        """Compile submodules that match module._compile_conditions.

        Returns:
            Number of submodules whose forward was wrapped with torch.compile.
        """
        compile_conditions = getattr(module, "_compile_conditions", None)
        if not compile_conditions:
            return 0

        compiled_count = 0
        for name, submodule in module.named_modules():
            # named_modules() yields the root module itself with an empty name;
            # skip it so only true submodules are compiled here.
            if not name:
                continue
            if any(cond(name, submodule) for cond in compile_conditions):
                submodule.forward = torch.compile(submodule.forward, **compile_kwargs)
                compiled_count += 1
        return compiled_count

    def _maybe_compile_pipeline_module(
        self,
        module_name: str,
        fsdp_module_cls: type | None,
        compile_kwargs: dict[str, Any],
    ) -> None:
        """torch.compile one pipeline module, honoring FSDP and per-module conditions.

        Skips silently when the module is absent; skips compiling when it is
        already FSDP-wrapped (FSDP handles compilation elsewhere).
        """
        if module_name not in self.modules:
            return

        module = self.modules[module_name]
        if fsdp_module_cls is not None and isinstance(module, fsdp_module_cls):
            logger.info(
                "%s is already FSDP-wrapped; skipping torch.compile in pipeline",
                module_name.capitalize(),
            )
            return

        compiled_count = self._compile_with_conditions(module, compile_kwargs)
        if compiled_count > 0:
            logger.info(
                "Enabled torch.compile for %d submodules in %s via _compile_conditions with kwargs=%s",
                compiled_count,
                module_name,
                compile_kwargs,
            )
            return

        # Backward-compatible fallback: compile full module if no condition matched.
        logger.info("Enabling torch.compile for %s with kwargs=%s", module_name, compile_kwargs)
        self.modules[module_name] = torch.compile(module, **compile_kwargs)

    def post_init(self) -> None:
        """One-time setup after module loading: training/validation init,
        optional torch.compile, stage creation, and SP communication warmup.

        Idempotent — repeated calls are no-ops once post_init_called is set.
        """
        assert self.fastvideo_args is not None, "fastvideo_args must be set"
        if self.post_init_called:
            return
        self.post_init_called = True
        if self.fastvideo_args.training_mode:
            assert isinstance(self.fastvideo_args, TrainingArgs)
            self.training_args = self.fastvideo_args
            assert self.training_args is not None
            self.initialize_training_pipeline(self.training_args)
            if self.training_args.log_validation:
                self.initialize_validation_pipeline(self.training_args)

        self.initialize_pipeline(self.fastvideo_args)
        if self.fastvideo_args.enable_torch_compile:
            if self.fastvideo_args.training_mode:
                logger.info("Torch Compile enabled via FSDP loader for training; skipping additional pipeline compile")
            else:
                fsdp_module_cls = None
                try:
                    from torch.distributed.fsdp import FSDPModule  # type: ignore
                    fsdp_module_cls = FSDPModule
                except Exception:  # pragma: no cover - FSDP not always available
                    fsdp_module_cls = None

                compile_kwargs = self.fastvideo_args.torch_compile_kwargs or {}
                self._maybe_compile_pipeline_module(
                    module_name="transformer",
                    fsdp_module_cls=fsdp_module_cls,
                    compile_kwargs=compile_kwargs,
                )
                self._maybe_compile_pipeline_module(
                    module_name="transformer_2",
                    fsdp_module_cls=fsdp_module_cls,
                    compile_kwargs=compile_kwargs,
                )
                logger.info("Torch Compile enabled for DiT")

        if not self.fastvideo_args.training_mode:
            logger.info("Creating pipeline stages...")
            self.create_pipeline_stages(self.fastvideo_args)

        # Warmup NCCL communicators for sequence parallelism to avoid
        # slow first forward pass due to lazy initialization
        warmup_sequence_parallel_communication()

    def initialize_training_pipeline(self, training_args: TrainingArgs):
        """Hook for training subclasses; required when training_mode is True."""
        raise NotImplementedError("if training_mode is True, the pipeline must implement this method")

    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        """Hook for training subclasses; required when log_validation is True."""
        raise NotImplementedError("if log_validation is True, the pipeline must implement this method")

    @classmethod
    def from_pretrained(cls,
                        model_path: str,
                        device: str | None = None,
                        torch_dtype: torch.dtype | None = None,
                        pipeline_config: str | PipelineConfig | None = None,
                        args: argparse.Namespace | FastVideoArgs | TrainingArgs | None = None,
                        required_config_modules: list[str] | None = None,
                        loaded_modules: dict[str, torch.nn.Module]
                        | None = None,
                        **kwargs) -> "ComposedPipelineBase":
        """
        Load a pipeline from a pretrained model.
        loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
        If provided, loaded_modules will be used instead of loading from config/pretrained weights.
        """
        # Inference path: build FastVideoArgs from kwargs. Training path:
        # reuse/construct TrainingArgs and force training-specific settings.
        if args is None or (isinstance(args, FastVideoArgs) and args.inference_mode):

            kwargs['model_path'] = model_path
            fastvideo_args = FastVideoArgs.from_kwargs(**kwargs)
        else:
            if isinstance(args, TrainingArgs):
                fastvideo_args = args
            else:
                assert isinstance(args, argparse.Namespace), "training mode expects argparse.Namespace args"
                fastvideo_args = TrainingArgs.from_cli_args(args)
            # TODO(will): fix this so that its not so ugly
            fastvideo_args.model_path = model_path
            for key, value in kwargs.items():
                setattr(fastvideo_args, key, value)

            fastvideo_args.dit_cpu_offload = False
            # we hijack the precision to be the master weight type so that the
            # model is loaded with the correct precision. Subsequently we will
            # use FSDP2's MixedPrecisionPolicy to set the precision for the
            # fwd, bwd, and other operations' precision.
            assert fastvideo_args.pipeline_config.dit_precision == 'fp32', 'only fp32 is supported for training'

        logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)

        pipe = cls(model_path,
                   fastvideo_args,
                   required_config_modules=required_config_modules,
                   loaded_modules=loaded_modules)
        pipe.post_init()
        return pipe

    def get_module(self, module_name: str, default_value: Any = None) -> Any:
        """Return a loaded module by name, or default_value when absent."""
        if module_name not in self.modules:
            return default_value
        return self.modules[module_name]

    def add_module(self, module_name: str, module: Any):
        """Register (or replace) a module under the given name."""
        self.modules[module_name] = module

    def __getattr__(self, name: str) -> Any:
        # Expose registered stages as attributes (e.g. self.denoising_stage).
        # Only called when normal attribute lookup fails.
        if "_stage_name_mapping" in self.__dict__ and name in self._stage_name_mapping:
            return self._stage_name_mapping[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def _load_config(self, model_path: str) -> dict[str, Any]:
        """Download the model if needed and return its validated model_index config.

        Side effect: rebinds self.model_path to the local (downloaded) path.
        """
        model_path = maybe_download_model(self.model_path)
        self.model_path = model_path
        # fastvideo_args.downloaded_model_path = model_path
        logger.info("Model path: %s", model_path)
        config = verify_model_config_and_directory(model_path)
        return cast(dict[str, Any], config)

    @property
    def required_config_modules(self) -> list[str]:
        """
        List of modules that are required by the pipeline. The names should match
        the diffusers directory and model_index.json file. These modules will be
        loaded using the PipelineComponentLoader and made available in the
        modules dictionary. Access these modules using the get_module method.

        class ConcretePipeline(ComposedPipelineBase):
            _required_config_modules = ["vae", "text_encoder", "transformer", "scheduler", "tokenizer"]


            @property
            def required_config_modules(self):
                return self._required_config_modules
        """
        return self._required_config_modules

    @property
    def stages(self) -> list[PipelineStage]:
        """
        List of stages in the pipeline.
        """
        return self._stages

    @abstractmethod
    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        """
        Create the inference pipeline stages.
        """
        raise NotImplementedError

    def create_training_stages(self, training_args: TrainingArgs):
        """
        Create the training pipeline stages.
        """
        raise NotImplementedError

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        """
        Initialize the pipeline.
        """
        return

    def load_modules(self,
                     fastvideo_args: FastVideoArgs,
                     loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]:
        """
        Load the modules from the config.
        loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
        If provided, loaded_modules will be used instead of loading from config/pretrained weights.
        """

        model_index = self._load_config(self.model_path)
        logger.info("Loading pipeline modules from config: %s", model_index)

        # remove keys that are not pipeline modules
        model_index.pop("_class_name")
        model_index.pop("_diffusers_version")
        model_index.pop("_name_or_path", None)
        model_index.pop("workload_type", None)
        # A non-null boundary_ratio marks a MoE checkpoint with a second
        # transformer used past the boundary timestep.
        if "boundary_ratio" in model_index and model_index["boundary_ratio"] is not None:
            logger.info("MoE pipeline detected. Adding transformer_2 to self.required_config_modules...")
            self.required_config_modules.append("transformer_2")
            logger.info("MoE pipeline detected. Setting boundary ratio to %s", model_index["boundary_ratio"])
            fastvideo_args.pipeline_config.dit_config.boundary_ratio = model_index["boundary_ratio"]

        model_index.pop("boundary_ratio", None)
        # used by Wan2.2 ti2v
        model_index.pop("expand_timesteps", None)

        # some sanity checks
        assert len(model_index) > 1, "model_index.json must contain at least one pipeline module"

        # Resolve any required module missing from model_index via the
        # alias map (_extra_config_module_map); fail fast otherwise.
        for module_name in self.required_config_modules:
            if module_name not in model_index and module_name in self._extra_config_module_map:
                extra_module_value = self._extra_config_module_map[module_name]
                logger.warning(
                    "model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.",
                    module_name, module_name, extra_module_value)
                if extra_module_value in model_index:
                    logger.info("Using module %s for %s", extra_module_value, module_name)
                    model_index[module_name] = model_index[extra_module_value]
                    continue
                else:
                    raise ValueError(
                        f"Required module key: {module_name} value: {model_index.get(module_name)} was not found in loaded modules {model_index.keys()}"
                    )

        # all the component models used by the pipeline
        required_modules = self.required_config_modules
        logger.info("Loading required modules: %s", required_modules)

        modules = {}
        for module_name, module_spec in model_index.items():
            # Each module entry is expected to be a (library, class) pair;
            # anything else is configuration, not a loadable component.
            if not isinstance(module_spec, list | tuple):
                logger.info(
                    "Skipping non-module config entry %s=%s",
                    module_name,
                    module_spec,
                )
                continue
            if len(module_spec) < 1:
                logger.warning(
                    "Skipping module %s due to invalid empty spec in model_index.json",
                    module_name,
                )
                continue
            transformers_or_diffusers = module_spec[0]
            if transformers_or_diffusers is None:
                logger.warning("Module %s in model_index.json has null value, removing from required_config_modules",
                               module_name)
                if module_name in self.required_config_modules:
                    self.required_config_modules.remove(module_name)
                continue
            if module_name not in required_modules:
                logger.info("Skipping module %s", module_name)
                continue
            if loaded_modules is not None and module_name in loaded_modules:
                logger.info("Using module %s already provided", module_name)
                modules[module_name] = loaded_modules[module_name]
                continue

            # we load the module from the extra config module map if it exists
            if module_name in self._extra_config_module_map:
                load_module_name = self._extra_config_module_map[module_name]
            else:
                load_module_name = module_name

            component_model_path = os.path.join(self.model_path, load_module_name)
            module = PipelineComponentLoader.load_module(
                module_name=load_module_name,
                component_model_path=component_model_path,
                transformers_or_diffusers=transformers_or_diffusers,
                fastvideo_args=fastvideo_args,
            )
            logger.info("Loaded module %s from %s", module_name, component_model_path)

            if module_name in modules:
                logger.warning("Overwriting module %s", module_name)
            modules[module_name] = module

        # Check if all required modules were loaded
        for module_name in required_modules:
            if module_name not in modules or modules[module_name] is None:
                raise ValueError(
                    f"Required module key: {module_name} value: {modules.get(module_name)} was not found in loaded modules {modules.keys()}"
                )

        return modules

    def add_stage(self, stage_name: str, stage: PipelineStage):
        """Append a stage; order of calls defines execution order in forward()."""
        assert self.modules is not None, "No modules are registered"
        self._stages.append(stage)
        self._stage_name_mapping[stage_name] = stage
        setattr(self, stage_name, stage)

    def profile(self, is_start: bool = True):
        """Start or stop the torch profiler; prints a summary on rank 0 at stop."""
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            # only print profiler results on rank 0
            if self.local_rank == 0:
                print(self.profiler.key_averages().table(sort_by="self_cuda_time_total"))

    # TODO(will): don't hardcode no_grad
    @torch.no_grad()
    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Generate a video or image using the pipeline.

        Args:
            batch: The batch to generate from.
            fastvideo_args: The inference arguments.
        Returns:
            ForwardBatch: The batch with the generated video or image.
        """
        if not self.post_init_called:
            self.post_init()

        # Execute each stage
        logger.info("Running pipeline stages: %s", self._stage_name_mapping.keys())
        # logger.info("Batch: %s", batch)
        for stage in self.stages:
            batch = stage(batch, fastvideo_args)

        # Return the output
        return batch

    def train(self) -> None:
        """Training entry point; required when training_mode is True."""
        raise NotImplementedError("if training_mode is True, the pipeline must implement this method")

    def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch:
        """Optional streaming API; override in pipelines that support it."""
        raise NotImplementedError(f"{type(self).__name__} does not support streaming_reset")

    def streaming_step(self, *args: Any, **kwargs: Any) -> ForwardBatch:
        """Optional streaming API; override in pipelines that support it."""
        raise NotImplementedError(f"{type(self).__name__} does not support streaming_step")

    def streaming_clear(self) -> None:
        """Optional streaming API; override in pipelines that support it."""
        raise NotImplementedError(f"{type(self).__name__} does not support streaming_clear")
standalone_inference/overlay_files/fastvideo/pipelines/stages/denoising.py ADDED
@@ -0,0 +1,1184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Denoising stage for diffusion pipelines.
4
+ """
5
+
6
+ import inspect
7
+ import weakref
8
+ from collections.abc import Iterable
9
+ from typing import Any
10
+
11
+ import torch
12
+ from tqdm.auto import tqdm
13
+
14
+ from fastvideo.attention import get_attn_backend
15
+ from fastvideo.distributed import (get_local_torch_device, get_world_group)
16
+ from fastvideo.fastvideo_args import FastVideoArgs
17
+ from fastvideo.forward_context import set_forward_context
18
+ from fastvideo.logger import init_logger
19
+ from fastvideo.models.loader.component_loader import TransformerLoader
20
+ from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (FlowMatchEulerDiscreteScheduler)
21
+ from fastvideo.models.utils import pred_noise_to_pred_video
22
+ from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
23
+ from fastvideo.pipelines.stages.base import PipelineStage
24
+ from fastvideo.pipelines.stages.validators import StageValidators as V
25
+ from fastvideo.pipelines.stages.validators import VerificationResult
26
+ from fastvideo.platforms import AttentionBackendEnum
27
+ from fastvideo.utils import dict_to_3d_list, masks_like
28
+
29
+ try:
30
+ from fastvideo.attention.backends.vmoba import VMOBAAttentionBackend
31
+ from fastvideo.utils import is_vmoba_available
32
+ vmoba_attn_available = is_vmoba_available()
33
+ except ImportError:
34
+ vmoba_attn_available = False
35
+
36
+ try:
37
+ from fastvideo.attention.backends.video_sparse_attn import (VideoSparseAttentionBackend)
38
+ vsa_available = True
39
+ except ImportError:
40
+ vsa_available = False
41
+
42
+ try:
43
+ from fastvideo.attention.backends.sparse_fp4_attn import (SparseFP4AttentionBackend)
44
+ except ImportError:
45
+ SparseFP4AttentionBackend = None # type: ignore[assignment]
46
+
47
+ try:
48
+ from fastvideo.attention.backends.sparse_fp4_ours_p_attn import (SparseFP4OursPAttentionBackend)
49
+ except ImportError:
50
+ SparseFP4OursPAttentionBackend = None # type: ignore[assignment]
51
+
52
+ sparse_fp4_backends = tuple(
53
+ backend for backend in (
54
+ SparseFP4AttentionBackend,
55
+ SparseFP4OursPAttentionBackend,
56
+ ) if backend is not None)
57
+ sparse_fp4_available = bool(sparse_fp4_backends)
58
+
59
+ logger = init_logger(__name__)
60
+
61
+
62
+ class DenoisingStage(PipelineStage):
63
+ """
64
+ Stage for running the denoising loop in diffusion pipelines.
65
+
66
+ This stage handles the iterative denoising process that transforms
67
+ the initial noise into the final output.
68
+ """
69
+
70
+ def __init__(self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None) -> None:
+ """Store pipeline components and resolve an attention backend sized to the transformer's per-head dim."""
72
+ super().__init__()
73
+ self.transformer = transformer
74
+ self.transformer_2 = transformer_2
75
+ self.scheduler = scheduler
76
+ self.vae = vae
77
+ # Weak reference avoids a pipeline <-> stage reference cycle.
+ self.pipeline = weakref.ref(pipeline) if pipeline else None
78
+ attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
79
+ self.attn_backend = get_attn_backend(
80
+ head_size=attn_head_size,
81
+ dtype=torch.float16, # TODO(will): hack
82
+ supported_attention_backends=(
83
+ AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.BSA_ATTN, AttentionBackendEnum.VMOBA_ATTN,
84
+ AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.SAGE_ATTN_THREE,
85
+ AttentionBackendEnum.ATTN_QAT_INFER, AttentionBackendEnum.ATTN_QAT_TRAIN,
86
+ AttentionBackendEnum.SPARSE_FP4_ATTN, AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN) # hack
87
+ )
87
+
88
+ def forward(
89
+ self,
90
+ batch: ForwardBatch,
91
+ fastvideo_args: FastVideoArgs,
92
+ ) -> ForwardBatch:
93
+ """
94
+ Run the denoising loop.
95
+
96
+ Args:
97
+ batch: The current batch information.
98
+ fastvideo_args: The inference arguments.
99
+
100
+ Returns:
101
+ The batch with denoised latents.
102
+ """
103
+ pipeline = self.pipeline() if self.pipeline else None
104
+ if not fastvideo_args.model_loaded["transformer"]:
105
+ loader = TransformerLoader()
106
+ self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
107
+ if pipeline:
108
+ pipeline.add_module("transformer", self.transformer)
109
+ fastvideo_args.model_loaded["transformer"] = True
110
+
111
+ # Prepare extra step kwargs for scheduler
112
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
113
+ self.scheduler.step,
114
+ {
115
+ "generator": batch.generator,
116
+ "eta": batch.eta
117
+ },
118
+ )
119
+
120
+ # Setup precision and autocast settings
121
+ # TODO(will): make the precision configurable for inference
122
+ # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
123
+ target_dtype = torch.bfloat16
124
+ autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast
125
+
126
+ # Get timesteps and calculate warmup steps
127
+ timesteps = batch.timesteps
128
+ # TODO(will): remove this once we add input/output validation for stages
129
+ if timesteps is None:
130
+ raise ValueError("Timesteps must be provided")
131
+ num_inference_steps = batch.num_inference_steps
132
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
133
+
134
+ # Prepare image latents and embeddings for I2V generation
135
+ image_embeds = batch.image_embeds
136
+ if len(image_embeds) > 0:
137
+ assert not torch.isnan(image_embeds[0]).any(), "image_embeds contains nan"
138
+ image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]
139
+
140
+ image_kwargs = self.prepare_extra_func_kwargs(
141
+ self.transformer.forward,
142
+ {
143
+ "encoder_hidden_states_image": image_embeds,
144
+ "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
145
+ },
146
+ )
147
+
148
+ pos_cond_kwargs = self.prepare_extra_func_kwargs(
149
+ self.transformer.forward,
150
+ {
151
+ "encoder_hidden_states_2": batch.clip_embedding_pos,
152
+ "encoder_attention_mask": batch.prompt_attention_mask,
153
+ },
154
+ )
155
+
156
+ neg_cond_kwargs = self.prepare_extra_func_kwargs(
157
+ self.transformer.forward,
158
+ {
159
+ "encoder_hidden_states_2": batch.clip_embedding_neg,
160
+ "encoder_attention_mask": batch.negative_attention_mask,
161
+ },
162
+ )
163
+
164
+ action_kwargs = self.prepare_extra_func_kwargs(
165
+ self.transformer.forward,
166
+ {
167
+ "mouse_cond": batch.mouse_cond,
168
+ "keyboard_cond": batch.keyboard_cond,
169
+ "c2ws_plucker_emb": batch.c2ws_plucker_emb,
170
+ },
171
+ )
172
+
173
+ camera_kwargs = self.prepare_extra_func_kwargs(
174
+ self.transformer.forward,
175
+ {
176
+ "camera_states": batch.camera_states,
177
+ },
178
+ )
179
+
180
+ # Get latents and embeddings
181
+ latents = batch.latents
182
+ prompt_embeds = batch.prompt_embeds
183
+ assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
184
+ if batch.do_classifier_free_guidance:
185
+ neg_prompt_embeds = batch.negative_prompt_embeds
186
+ assert neg_prompt_embeds is not None
187
+ assert not torch.isnan(neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"
188
+
189
+ # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
190
+ boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
191
+ if batch.boundary_ratio is not None:
192
+ logger.info("Overriding boundary ratio from %s to %s", boundary_ratio, batch.boundary_ratio)
193
+ boundary_ratio = batch.boundary_ratio
194
+
195
+ boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps if boundary_ratio is not None else None
196
+ latent_model_input = latents.to(target_dtype)
197
+ assert latent_model_input.shape[0] == 1, "only support batch size 1"
198
+
199
+ if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
200
+ # TI2V directly replaces the first frame of the latent with
201
+ # the image latent instead of appending along the channel dim
202
+ assert batch.image_latent is None, "TI2V task should not have image latents"
203
+ assert self.vae is not None, "VAE is not provided for TI2V task"
204
+ z = self.vae.encode(batch.pil_image).mean.float()
205
+ if (hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None):
206
+ if isinstance(self.vae.shift_factor, torch.Tensor):
207
+ z -= self.vae.shift_factor.to(z.device, z.dtype)
208
+ else:
209
+ z -= self.vae.shift_factor
210
+
211
+ if isinstance(self.vae.scaling_factor, torch.Tensor):
212
+ z = z * self.vae.scaling_factor.to(z.device, z.dtype)
213
+ else:
214
+ z = z * self.vae.scaling_factor
215
+
216
+ latent_model_input = latent_model_input.squeeze(0)
217
+ _, mask2 = masks_like([latent_model_input], zero=True)
218
+
219
+ latent_model_input = (1. - mask2[0]) * z + mask2[0] * latent_model_input
220
+ # latent_model_input = latent_model_input.unsqueeze(0)
221
+ latent_model_input = latent_model_input.to(get_local_torch_device())
222
+ latents = latent_model_input
223
+ F = batch.num_frames
224
+ temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
225
+ spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
226
+ patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
227
+ if not isinstance(patch_size, tuple):
228
+ raise ValueError(f"Expected 3D patch_size tuple for denoising, got {patch_size!r}")
229
+ seq_len = ((F - 1) // temporal_scale + 1) * (batch.height // spatial_scale) * (
230
+ batch.width // spatial_scale) // (patch_size[1] * patch_size[2])
231
+
232
+ # Initialize lists for ODE trajectory
233
+ trajectory_timesteps: list[torch.Tensor] = []
234
+ trajectory_latents: list[torch.Tensor] = []
235
+
236
+ # Run denoising loop
237
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
238
+ for i, t in enumerate(timesteps):
239
+ # Skip if interrupted
240
+ if hasattr(self, 'interrupt') and self.interrupt:
241
+ continue
242
+
243
+ if boundary_timestep is None or t >= boundary_timestep:
244
+ if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
245
+ and self.transformer_2 is not None
246
+ and next(self.transformer_2.parameters()).device.type == 'cuda'):
247
+ self.transformer_2.to('cpu')
248
+ current_model = self.transformer
249
+ if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
250
+ and not fastvideo_args.use_fsdp_inference and current_model is not None):
251
+ transformer_device = next(current_model.parameters()).device.type
252
+ if transformer_device == 'cpu':
253
+ current_model.to(get_local_torch_device())
254
+ current_guidance_scale = batch.guidance_scale
255
+ else:
256
+ # low-noise stage in wan2.2
257
+ if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
258
+ and next(self.transformer.parameters()).device.type == 'cuda'):
259
+ self.transformer.to('cpu')
260
+ current_model = self.transformer_2
261
+ if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
262
+ and not fastvideo_args.use_fsdp_inference and current_model is not None):
263
+ transformer_2_device = next(current_model.parameters()).device.type
264
+ if transformer_2_device == 'cpu':
265
+ current_model.to(get_local_torch_device())
266
+ current_guidance_scale = batch.guidance_scale_2
267
+ assert current_model is not None, "current_model is None"
268
+
269
+ # Expand latents for V2V/I2V
270
+ latent_model_input = latents.to(target_dtype)
271
+ if batch.video_latent is not None:
272
+ latent_model_input = torch.cat([latent_model_input, batch.video_latent,
273
+ torch.zeros_like(latents)],
274
+ dim=1).to(target_dtype)
275
+ elif batch.image_latent is not None:
276
+ assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
277
+ latent_model_input = torch.cat([latent_model_input, batch.image_latent], dim=1).to(target_dtype)
278
+
279
+ assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"
280
+ if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
281
+ timestep = torch.stack([t]).to(get_local_torch_device())
282
+ temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
283
+ temp_ts = torch.cat([temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep])
284
+ timestep = temp_ts.unsqueeze(0)
285
+ t_expand = timestep.repeat(latent_model_input.shape[0], 1)
286
+ else:
287
+ t_expand = t.repeat(latent_model_input.shape[0])
288
+ t_expand = t_expand.to(get_local_torch_device())
289
+
290
+ use_meanflow = getattr(self.transformer.config, "use_meanflow", False)
291
+ if use_meanflow:
292
+ if i == len(timesteps) - 1:
293
+ timesteps_r = torch.tensor([0.0], device=get_local_torch_device())
294
+ else:
295
+ timesteps_r = timesteps[i + 1]
296
+ timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
297
+ else:
298
+ timesteps_r = None
299
+
300
+ timesteps_r_kwarg = self.prepare_extra_func_kwargs(
301
+ self.transformer.forward,
302
+ {
303
+ "timestep_r": timesteps_r,
304
+ },
305
+ )
306
+
307
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
308
+
309
+ # Prepare inputs for transformer
310
+ guidance_expand = (torch.tensor(
311
+ [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
312
+ dtype=torch.float32,
313
+ device=get_local_torch_device(),
314
+ ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)
315
+
316
+ # Predict noise residual
317
+ with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
318
+ if (vsa_available and self.attn_backend == VideoSparseAttentionBackend) or \
319
+ (sparse_fp4_available and self.attn_backend in sparse_fp4_backends):
320
+ self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
321
+
322
+ if self.attn_metadata_builder_cls is not None:
323
+ self.attn_metadata_builder = self.attn_metadata_builder_cls()
324
+ # TODO(will): clean this up
325
+ attn_metadata = self.attn_metadata_builder.build( # type: ignore
326
+ current_timestep=i, # type: ignore
327
+ raw_latent_shape=batch.raw_latent_shape[2:5], # type: ignore
328
+ patch_size=fastvideo_args.pipeline_config. # type: ignore
329
+ dit_config.patch_size, # type: ignore
330
+ VSA_sparsity=fastvideo_args.VSA_sparsity, # type: ignore
331
+ device=get_local_torch_device(),
332
+ )
333
+ assert attn_metadata is not None, "attn_metadata cannot be None"
334
+ else:
335
+ attn_metadata = None
336
+ elif (vmoba_attn_available and self.attn_backend == VMOBAAttentionBackend):
337
+ self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
338
+ if self.attn_metadata_builder_cls is not None:
339
+ self.attn_metadata_builder = self.attn_metadata_builder_cls()
340
+ # Prepare V-MoBA parameters from config
341
+ moba_params = fastvideo_args.moba_config.copy()
342
+ assert batch.raw_latent_shape is not None, "raw_latent_shape must be set for V-MoBA"
343
+ moba_params.update({
344
+ "current_timestep": i,
345
+ "raw_latent_shape": batch.raw_latent_shape[2:5],
346
+ "patch_size": fastvideo_args.pipeline_config.dit_config.patch_size,
347
+ "device": get_local_torch_device(),
348
+ })
349
+ attn_metadata = self.attn_metadata_builder.build(**moba_params)
350
+ assert attn_metadata is not None, "attn_metadata cannot be None"
351
+ else:
352
+ attn_metadata = None
353
+ else:
354
+ attn_metadata = None
355
+ # TODO(will): finalize the interface. vLLM uses this to
356
+ # support torch dynamo compilation. They pass in
357
+ # attn_metadata, vllm_config, and num_tokens. We can pass in
358
+ # fastvideo_args or training_args, and attn_metadata.
359
+ batch.is_cfg_negative = False
360
+ with set_forward_context(
361
+ current_timestep=i,
362
+ attn_metadata=attn_metadata,
363
+ forward_batch=batch,
364
+ # fastvideo_args=fastvideo_args
365
+ ):
366
+ # Run transformer
367
+ noise_pred = current_model(
368
+ latent_model_input,
369
+ prompt_embeds,
370
+ t_expand,
371
+ guidance=guidance_expand,
372
+ **image_kwargs,
373
+ **pos_cond_kwargs,
374
+ **action_kwargs,
375
+ **camera_kwargs,
376
+ **timesteps_r_kwarg,
377
+ )
378
+
379
+ if batch.do_classifier_free_guidance:
380
+ batch.is_cfg_negative = True
381
+ with set_forward_context(
382
+ current_timestep=i,
383
+ attn_metadata=attn_metadata,
384
+ forward_batch=batch,
385
+ ):
386
+ noise_pred_uncond = current_model(
387
+ latent_model_input,
388
+ neg_prompt_embeds,
389
+ t_expand,
390
+ guidance=guidance_expand,
391
+ **image_kwargs,
392
+ **neg_cond_kwargs,
393
+ **action_kwargs,
394
+ **camera_kwargs,
395
+ **timesteps_r_kwarg,
396
+ )
397
+
398
+ noise_pred_text = noise_pred
399
+ noise_pred = noise_pred_uncond + current_guidance_scale * (noise_pred_text - noise_pred_uncond)
400
+
401
+ # Apply guidance rescale if needed
402
+ if batch.guidance_rescale > 0.0:
403
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
404
+ noise_pred = self.rescale_noise_cfg(
405
+ noise_pred,
406
+ noise_pred_text,
407
+ guidance_rescale=batch.guidance_rescale,
408
+ )
409
+ # Compute the previous noisy sample
410
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
411
+ if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
412
+ latents = latents.squeeze(0)
413
+ latents = (1. - mask2[0]) * z + mask2[0] * latents
414
+ # latents = latents.unsqueeze(0)
415
+
416
+ # save trajectory latents if needed
417
+ if batch.return_trajectory_latents:
418
+ trajectory_timesteps.append(t)
419
+ trajectory_latents.append(latents)
420
+
421
+ # Update progress bar
422
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
423
+ (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
424
+ progress_bar.update()
425
+
426
+ trajectory_tensor: torch.Tensor | None = None
427
+ if trajectory_latents:
428
+ trajectory_tensor = torch.stack(trajectory_latents, dim=1)
429
+ trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)
430
+ else:
431
+ trajectory_tensor = None
432
+ trajectory_timesteps_tensor = None
433
+
434
+ if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
435
+ batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
436
+ batch.trajectory_latents = trajectory_tensor.cpu()
437
+
438
+ # Update batch with final latents
439
+ batch.latents = latents
440
+
441
+ if fastvideo_args.dit_layerwise_offload:
442
+ mgr = getattr(self.transformer, "_layerwise_offload_manager", None)
443
+ if mgr is not None and getattr(mgr, "enabled", False):
444
+ mgr.release_all()
445
+ if self.transformer_2 is not None:
446
+ mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager", None)
447
+ if mgr2 is not None and getattr(mgr2, "enabled", False):
448
+ mgr2.release_all()
449
+
450
+ # deallocate transformer if on mps
451
+ if torch.backends.mps.is_available():
452
+ logger.info("Memory before deallocating transformer: %s", torch.mps.current_allocated_memory())
453
+ del self.transformer
454
+ if pipeline is not None and "transformer" in pipeline.modules:
455
+ del pipeline.modules["transformer"]
456
+ fastvideo_args.model_loaded["transformer"] = False
457
+ logger.info("Memory after deallocating transformer: %s", torch.mps.current_allocated_memory())
458
+
459
+ return batch
460
+
461
+ def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
462
+ """
463
+ Prepare extra kwargs for the scheduler step / denoise step.
464
+
465
+ Args:
466
+ func: The function to prepare kwargs for.
467
+ kwargs: The kwargs to prepare.
468
+
469
+ Returns:
470
+ The prepared kwargs.
471
+ """
472
+ extra_step_kwargs = {}
473
+ # Keep only entries that func's signature actually accepts.
+ # NOTE(review): inspect.signature(func) is re-parsed for every key; hoisting it above the loop would avoid repeated parsing.
+ for k, v in kwargs.items():
474
+ accepts = k in set(inspect.signature(func).parameters.keys())
475
+ if accepts:
476
+ extra_step_kwargs[k] = v
477
+ return extra_step_kwargs
478
+
479
+ def progress_bar(self, iterable: Iterable | None = None, total: int | None = None) -> tqdm:
480
+ """
481
+ Create a progress bar for the denoising process.
482
+
483
+ Args:
484
+ iterable: The iterable to iterate over.
485
+ total: The total number of items.
486
+
487
+ Returns:
488
+ A tqdm progress bar.
489
+ """
490
+ # Only local rank 0 renders the bar; all other ranks get a disabled tqdm so multi-GPU logs stay clean.
+ local_rank = get_world_group().local_rank
491
+ if local_rank == 0:
492
+ return tqdm(iterable=iterable, total=total)
493
+ else:
494
+ return tqdm(iterable=iterable, total=total, disable=True)
495
+
496
+ def rescale_noise_cfg(self, noise_cfg, noise_pred_text, guidance_rescale=0.0) -> torch.Tensor:
497
+ """
498
+ Rescale noise prediction according to guidance_rescale.
499
+
500
+ Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
501
+ (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
502
+
503
+ Args:
504
+ noise_cfg: The noise prediction with guidance.
505
+ noise_pred_text: The text-conditioned noise prediction.
506
+ guidance_rescale: The guidance rescale factor.
507
+
508
+ Returns:
509
+ The rescaled noise prediction.
510
+ """
511
+ # Per-sample std over all non-batch dims (dims 1..ndim-1), kept for broadcasting.
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
512
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
513
+ # Rescale the results from guidance (fixes overexposure)
514
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
515
+ # Mix with the original results from guidance by factor guidance_rescale
516
+ noise_cfg = (guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg)
517
+ return noise_cfg
518
+
519
+ def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
520
+ """Verify denoising stage inputs."""
521
+ result = VerificationResult()
522
+ result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)])
523
+ result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
524
+ result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
525
+ result.add_check("image_embeds", batch.image_embeds, V.is_list)
526
+ result.add_check("image_latent", batch.image_latent, V.none_or_tensor_with_dims(5))
527
+ result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
528
+ result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
529
+ result.add_check("eta", batch.eta, V.non_negative_float)
530
+ result.add_check("generator", batch.generator, V.generator_or_list_generators)
531
+ result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
532
+ # negative_prompt_embeds is only required when classifier-free guidance is enabled.
+ result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
533
+ lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
534
+ return result
535
+
536
+ def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
537
+ """Verify denoising stage outputs."""
538
+ result = VerificationResult()
539
+ # Final latents must remain a 5-D tensor (presumably B, C, T, H, W — consistent with the channel-dim cats in forward).
+ result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
540
+ return result
541
+
542
+
543
+ class CosmosDenoisingStage(DenoisingStage):
544
+ """
545
+ Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
546
+ """
547
+
548
+ def __init__(self, transformer, scheduler, pipeline=None) -> None:
+ # Reuse DenoisingStage setup; Cosmos passes no transformer_2/vae, so the parent defaults (None) apply.
549
+ super().__init__(transformer, scheduler, pipeline)
550
+
551
+ def forward(
552
+ self,
553
+ batch: ForwardBatch,
554
+ fastvideo_args: FastVideoArgs,
555
+ ) -> ForwardBatch:
556
+ pipeline = self.pipeline() if self.pipeline else None
557
+ if not fastvideo_args.model_loaded["transformer"]:
558
+ loader = TransformerLoader()
559
+ self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
560
+ if pipeline:
561
+ pipeline.add_module("transformer", self.transformer)
562
+ fastvideo_args.model_loaded["transformer"] = True
563
+
564
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
565
+ self.scheduler.step,
566
+ {
567
+ "generator": batch.generator,
568
+ "eta": batch.eta
569
+ },
570
+ )
571
+
572
+ if hasattr(self.transformer, 'module'):
573
+ transformer_dtype = next(self.transformer.module.parameters()).dtype
574
+ else:
575
+ transformer_dtype = next(self.transformer.parameters()).dtype
576
+ target_dtype = transformer_dtype
577
+ autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast
578
+
579
+ latents = batch.latents
580
+ num_inference_steps = batch.num_inference_steps
581
+ guidance_scale = batch.guidance_scale
582
+
583
+ sigma_max = 80.0
584
+ sigma_min = 0.002
585
+ sigma_data = 1.0
586
+ final_sigmas_type = "sigma_min"
587
+
588
+ if self.scheduler is not None:
589
+ self.scheduler.register_to_config(
590
+ sigma_max=sigma_max,
591
+ sigma_min=sigma_min,
592
+ sigma_data=sigma_data,
593
+ final_sigmas_type=final_sigmas_type,
594
+ )
595
+
596
+ self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
597
+ timesteps = self.scheduler.timesteps
598
+
599
+ if (hasattr(self.scheduler.config, 'final_sigmas_type')
600
+ and self.scheduler.config.final_sigmas_type == "sigma_min" and len(self.scheduler.sigmas) > 1):
601
+ self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
602
+
603
+ conditioning_latents = getattr(batch, 'conditioning_latents', None)
604
+ unconditioning_latents = conditioning_latents
605
+
606
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
607
+ for i, t in enumerate(timesteps):
608
+ if hasattr(self, 'interrupt') and self.interrupt:
609
+ continue
610
+
611
+ current_sigma = self.scheduler.sigmas[i]
612
+ current_t = current_sigma / (current_sigma + 1)
613
+ c_in = 1 - current_t
614
+ c_skip = 1 - current_t
615
+ c_out = -current_t
616
+
617
+ timestep = current_t.view(1, 1, 1, 1, 1).expand(latents.size(0), -1, latents.size(2), -1,
618
+ -1) # [B, 1, T, 1, 1]
619
+
620
+ with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
621
+
622
+ cond_latent = latents * c_in
623
+
624
+ if hasattr(
625
+ batch,
626
+ 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
627
+ cond_latent = batch.cond_indicator * conditioning_latents + (1 -
628
+ batch.cond_indicator) * cond_latent
629
+ else:
630
+ logger.warning(
631
+ "Step %s: Missing conditioning data - cond_indicator: %s, conditioning_latents: %s", i,
632
+ hasattr(batch, 'cond_indicator'), conditioning_latents is not None)
633
+
634
+ cond_latent = cond_latent.to(target_dtype)
635
+
636
+ cond_timestep = timestep
637
+ if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
638
+ sigma_conditioning = 0.0001
639
+ t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
640
+ cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
641
+ cond_timestep = cond_timestep.to(target_dtype)
642
+
643
+ with set_forward_context(
644
+ current_timestep=i,
645
+ attn_metadata=None,
646
+ forward_batch=batch,
647
+ ):
648
+ # Use conditioning masks from CosmosLatentPreparationStage
649
+ condition_mask = batch.cond_mask.to(target_dtype) if hasattr(batch, 'cond_mask') else None
650
+ padding_mask = torch.zeros(1,
651
+ 1,
652
+ batch.height,
653
+ batch.width,
654
+ device=cond_latent.device,
655
+ dtype=target_dtype)
656
+
657
+ # Fallback if masks not available
658
+ if condition_mask is None:
659
+ batch_size, num_channels, num_frames, height, width = cond_latent.shape
660
+ condition_mask = torch.zeros(batch_size,
661
+ 1,
662
+ num_frames,
663
+ height,
664
+ width,
665
+ device=cond_latent.device,
666
+ dtype=target_dtype)
667
+
668
+ noise_pred = self.transformer(
669
+ hidden_states=cond_latent,
670
+ timestep=cond_timestep.to(target_dtype),
671
+ encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
672
+ fps=24, # TODO: get fps from batch or config
673
+ condition_mask=condition_mask,
674
+ padding_mask=padding_mask,
675
+ return_dict=False,
676
+ )[0]
677
+
678
+ cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)
679
+
680
+ if hasattr(
681
+ batch,
682
+ 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
683
+ cond_pred = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_pred
684
+
685
+ if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
686
+ uncond_latent = latents * c_in
687
+
688
+ if hasattr(batch, 'uncond_indicator'
689
+ ) and batch.uncond_indicator is not None and unconditioning_latents is not None:
690
+ uncond_latent = batch.uncond_indicator * unconditioning_latents + (
691
+ 1 - batch.uncond_indicator) * uncond_latent
692
+
693
+ with set_forward_context(
694
+ current_timestep=i,
695
+ attn_metadata=None,
696
+ forward_batch=batch,
697
+ ):
698
+ uncond_condition_mask = batch.uncond_mask.to(target_dtype) if hasattr(
699
+ batch, 'uncond_mask') and batch.uncond_mask is not None else condition_mask
700
+
701
+ uncond_timestep = timestep
702
+ if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None:
703
+ sigma_conditioning = 0.0001
704
+ t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
705
+ uncond_timestep = batch.uncond_indicator * t_conditioning + (
706
+ 1 - batch.uncond_indicator) * timestep
707
+ uncond_timestep = uncond_timestep.to(target_dtype)
708
+
709
+ noise_pred_uncond = self.transformer(
710
+ hidden_states=uncond_latent.to(target_dtype),
711
+ timestep=uncond_timestep.to(target_dtype),
712
+ encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
713
+ fps=24, # TODO: get fps from batch or config
714
+ condition_mask=uncond_condition_mask,
715
+ padding_mask=padding_mask,
716
+ return_dict=False,
717
+ )[0]
718
+
719
+ uncond_pred = (c_skip * latents + c_out * noise_pred_uncond.float()).to(target_dtype)
720
+
721
+ if hasattr(batch, 'uncond_indicator'
722
+ ) and batch.uncond_indicator is not None and unconditioning_latents is not None:
723
+ uncond_pred = batch.uncond_indicator * unconditioning_latents + (
724
+ 1 - batch.uncond_indicator) * uncond_pred
725
+
726
+ guidance_diff = cond_pred - uncond_pred
727
+ final_pred = cond_pred + guidance_scale * guidance_diff
728
+ else:
729
+ final_pred = cond_pred
730
+
731
+ # Convert to noise for scheduler step
732
+ if current_sigma > 1e-8:
733
+ noise_for_scheduler = (latents - final_pred) / current_sigma
734
+ else:
735
+ logger.warning("Step %s: current_sigma too small (%s), using final_pred directly", i, current_sigma)
736
+ noise_for_scheduler = final_pred
737
+
738
+ if torch.isnan(noise_for_scheduler).sum() > 0:
739
+ logger.error("Step %s: NaN detected in noise_for_scheduler, sum: %s", i,
740
+ noise_for_scheduler.float().sum().item())
741
+ logger.error("Step %s: latents sum: %s, final_pred sum: %s, current_sigma: %s", i,
742
+ latents.float().sum().item(),
743
+ final_pred.float().sum().item(), current_sigma)
744
+
745
+ latents = self.scheduler.step(noise_for_scheduler, t, latents, **extra_step_kwargs,
746
+ return_dict=False)[0]
747
+
748
+ progress_bar.update()
749
+
750
+ batch.latents = latents
751
+
752
+ return batch
753
+
754
+ def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
755
+ """Verify Cosmos denoising stage inputs."""
756
+ result = VerificationResult()
757
+ result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
758
+ result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
759
+ result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
760
+ result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
761
+ result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
762
+ result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
763
+ lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
764
+ return result
765
+
766
+ def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
767
+ """Verify Cosmos denoising stage outputs."""
768
+ result = VerificationResult()
769
+ result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
770
+ return result
771
+
772
+
773
class Cosmos25DenoisingStage(CosmosDenoisingStage):
    """Denoising stage for Cosmos 2.5 DiT (expects 1D/2D timestep, not 5D).

    Runs a flow-matching denoising loop over ``batch.latents``. Whether the
    run is conditioned (V2W/I2W) is inferred purely from the presence of
    ``batch.conditioning_latents``; no explicit mode flag is consulted.
    """

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """Denoise ``batch.latents`` in place and return the batch.

        Args:
            batch: Forward batch carrying latents, prompt embeddings and
                optional conditioning fields attached by earlier stages.
            fastvideo_args: Global inference configuration.

        Returns:
            The same batch with ``batch.latents`` replaced by the denoised
            5-D latent tensor.
        """
        # Lazily load the transformer on first use and register it with the
        # owning pipeline (if any) so later stages can reuse it.
        pipeline = self.pipeline() if self.pipeline else None
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            fastvideo_args.model_loaded["transformer"] = True

        # Only pass generator/eta if the scheduler's step() accepts them.
        extra_step_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step,
            {
                "generator": batch.generator,
                "eta": batch.eta
            },
        )

        # Model inputs are cast to the transformer's parameter dtype;
        # `module` is present when the model is wrapped (e.g. DDP/FSDP).
        if hasattr(self.transformer, 'module'):
            transformer_dtype = next(self.transformer.module.parameters()).dtype
        else:
            transformer_dtype = next(self.transformer.parameters()).dtype
        target_dtype = transformer_dtype
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        latents = batch.latents
        if latents is None:
            raise ValueError("latents must be provided for Cosmos25DenoisingStage")
        guidance_scale = batch.guidance_scale

        # Use scheduler-derived timesteps unless the batch provides its own.
        if batch.timesteps is None:
            self.scheduler.set_timesteps(batch.num_inference_steps, device=latents.device)
            timesteps = self.scheduler.timesteps
        else:
            timesteps = batch.timesteps.to(latents.device)

        cfg = fastvideo_args.pipeline_config

        # FPS conditioning: sample uniformly from [16, 32) when not given.
        if batch.fps is None:
            gen = batch.generator
            if isinstance(gen, list) and len(gen) > 0:
                gen = gen[0]
            fps_tensor = torch.randint(
                16,
                32,
                (1, ),
                generator=gen if isinstance(gen, torch.Generator) else None,
                device=latents.device,
            ).float().to(dtype=target_dtype)
        else:
            fps_val = batch.fps
            fps_tensor = torch.tensor(
                [fps_val],
                device=latents.device,
                dtype=target_dtype,
            )

        # Work on the first (and assumed only) batch element as a 4-D tensor;
        # re-batched via unsqueeze(0) when calling the model and at the end.
        latents_4d = latents[0]

        # Masks are optional for T2W.
        cond_mask = getattr(batch, "cond_mask", None)
        condition_mask = cond_mask.to(target_dtype) if isinstance(cond_mask, torch.Tensor) else None
        pad_mask = getattr(batch, "padding_mask", None)
        padding_mask = pad_mask.to(target_dtype) if isinstance(pad_mask, torch.Tensor) else None

        # Conditioning fields are attached by latent preparation stage.
        conditioning_latents = getattr(batch, "conditioning_latents", None)
        cond_indicator = getattr(batch, "cond_indicator", None)
        # Infer whether this is a conditioned run (V2W/I2W) purely from the presence
        # of conditioning latents. Avoid carrying explicit mode flags on the batch.
        is_conditioned = (conditioning_latents is not None)

        # Keep the initial noise around: conditioned runs derive the ground
        # truth velocity from it (gt_v = noise - gt_x0).
        init_noise_4d = latents_4d.clone()
        if condition_mask is None:
            _, t, h, w = latents_4d.shape
            condition_mask = torch.zeros(1, 1, t, h, w, device=latents.device, dtype=target_dtype)
        if padding_mask is None:
            _, _, h, w = latents_4d.shape
            # NOTE(review): default padding differs between conditioned (0.0)
            # and unconditioned (1.0) runs — presumably matching the values
            # the model was trained with; confirm against the model config.
            padding_default = 0.0 if is_conditioned else 1.0
            padding_mask = torch.full(
                (1, 1, h, w),
                float(padding_default),
                device=latents.device,
                dtype=target_dtype,
            )

        # Scheduler timesteps (~0..1000) are rescaled to the model's 0..1 range.
        timestep_scale = 0.001

        # The denoising state is kept in fp32 for numerical stability.
        state_dtype = torch.float32

        # Fixed small timestep fed to the model for already-clean conditioned frames.
        conditional_frame_timestep = 0.1
        latents_4d = latents_4d.to(state_dtype)
        init_noise_4d = init_noise_4d.to(state_dtype)

        # Whether to re-clamp conditioned frames to ground truth every step
        # (config-controlled) or only on the first step.
        clamp_every_step = bool(getattr(cfg, "cosmos25_clamp_every_step", True)) if is_conditioned else False

        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                t_val = float(t)
                if is_conditioned:
                    # Per-frame timestep vector: conditioned frames get the
                    # fixed small timestep, others the current scheduler time.
                    t_frames = int(latents_4d.shape[1])
                    timestep = torch.full(
                        (1, t_frames),
                        float(t_val * timestep_scale),
                        device=latents.device,
                        dtype=torch.float32,
                    )
                    if cond_indicator is not None and t_frames > 0:
                        cond_t = cond_indicator[0, 0, :t_frames, 0, 0]
                        cond_mask_t = (cond_t > 0.5)
                        if bool(cond_mask_t.any().item()):
                            timestep[0, cond_mask_t] = float(conditional_frame_timestep)
                else:
                    # Unconditioned: a single scalar timestep for all frames.
                    timestep_val = t_val * timestep_scale
                    timestep = torch.tensor(
                        [[float(timestep_val)]],
                        device=latents.device,
                        dtype=target_dtype,
                    )

                # Conditioned runs: replace x_t with GT x0 on the conditioned frames.
                if (is_conditioned and cond_indicator is not None and conditioning_latents is not None
                        and (clamp_every_step or i == 0)):
                    cond_ind_4d = cond_indicator[0].to(state_dtype)
                    gt_x0 = conditioning_latents[0].to(state_dtype)
                    latents_4d = gt_x0 * cond_ind_4d + latents_4d * (1 - cond_ind_4d)

                model_hidden_states = latents_4d.unsqueeze(0)

                with (
                        set_forward_context(current_timestep=int(t_val), attn_metadata=None, forward_batch=batch),
                        torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled),
                ):
                    # Conditional velocity prediction.
                    cond_v = self.transformer(
                        hidden_states=model_hidden_states.to(target_dtype),
                        encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
                        timestep=timestep,
                        fps=fps_tensor,
                        condition_mask=condition_mask,
                        padding_mask=padding_mask,
                        return_dict=False,
                    )[0]

                    if batch.do_classifier_free_guidance and batch.negative_prompt_embeds:
                        # Unconditional (negative-prompt) velocity prediction.
                        uncond_v = self.transformer(
                            hidden_states=model_hidden_states.to(target_dtype),
                            encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
                            timestep=timestep,
                            fps=fps_tensor,
                            condition_mask=condition_mask,
                            padding_mask=padding_mask,
                            return_dict=False,
                        )[0]
                        # NOTE(review): CFG anchor differs by mode — conditioned
                        # runs extrapolate from cond_v, unconditioned from
                        # uncond_v (standard CFG). Presumably intentional;
                        # confirm against the reference Cosmos 2.5 sampler.
                        if is_conditioned:
                            v = cond_v + guidance_scale * (cond_v - uncond_v)
                        else:
                            v = uncond_v + guidance_scale * (cond_v - uncond_v)
                    else:
                        v = cond_v

                # Conditioned runs: replace velocity on conditioned frames with GT velocity.
                if (is_conditioned and cond_indicator is not None and conditioning_latents is not None):
                    cond_ind_4d = cond_indicator[0].to(state_dtype)
                    gt_x0 = conditioning_latents[0].to(state_dtype)
                    gt_v = init_noise_4d.to(state_dtype) - gt_x0
                    v = cond_ind_4d * gt_v + (1 - cond_ind_4d) * v.to(state_dtype)

                # Scheduler step expects batched tensors; un-batch the result.
                prev = self.scheduler.step(v.unsqueeze(0),
                                           t,
                                           latents_4d.unsqueeze(0),
                                           **extra_step_kwargs,
                                           return_dict=False)[0]
                latents_4d = prev.squeeze(0)

                progress_bar.update()

        batch.latents = latents_4d.to(target_dtype).unsqueeze(0)
        return batch
957
+
958
+
959
class Cosmos25T2WDenoisingStage(Cosmos25DenoisingStage):
    """Cosmos 2.5 Text2World denoising stage.

    Text2World is unconditioned: any conditioning attributes left on the
    batch by an earlier stage are cleared before the shared denoising loop
    runs, so the base class treats the run as pure T2W.
    """

    _CONDITIONING_FIELDS = (
        "conditioning_latents",
        "cond_indicator",
        "uncond_indicator",
    )

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        # Null out stale conditioning state so the base stage infers T2W.
        for field in self._CONDITIONING_FIELDS:
            if hasattr(batch, field):
                setattr(batch, field, None)
        return super().forward(batch, fastvideo_args)
977
+
978
+
979
class Cosmos25V2WDenoisingStage(Cosmos25DenoisingStage):
    """Cosmos 2.5 Video2World denoising stage.

    Behaviour is identical to the base Cosmos 2.5 stage; the subclass exists
    so pipelines can name the V2W path explicitly.
    """

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        # The base class detects conditioning from the batch contents itself.
        return super().forward(batch, fastvideo_args)
988
+
989
+
990
class Cosmos25AutoDenoisingStage(PipelineStage):
    """Dispatch Cosmos 2.5 denoising to the T2W or V2W/I2W stage.

    The choice is made per call: a batch carrying conditioning latents is
    handled by the Video2World stage, anything else by Text2World.
    """

    def __init__(self, transformer, scheduler) -> None:
        super().__init__()
        self._t2w = Cosmos25T2WDenoisingStage(transformer=transformer, scheduler=scheduler)
        self._v2w = Cosmos25V2WDenoisingStage(transformer=transformer, scheduler=scheduler)

    def pipeline(self):
        return self._v2w.pipeline() if self._v2w.pipeline else None

    def _route(self, batch: ForwardBatch) -> Cosmos25DenoisingStage:
        # Conditioned batches (V2W/I2W) carry conditioning latents.
        has_conditioning = getattr(batch, "conditioning_latents", None) is not None
        return self._v2w if has_conditioning else self._t2w

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        return self._route(batch).forward(batch, fastvideo_args)

    def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        return self._route(batch).verify_input(batch, fastvideo_args)

    def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        return self._route(batch).verify_output(batch, fastvideo_args)
1022
+
1023
+
1024
class DmdDenoisingStage(DenoisingStage):
    """
    Denoising stage for DMD (Distribution Matching Distillation).

    Uses a fixed set of distilled denoising steps from the pipeline config
    and a FlowMatchEulerDiscreteScheduler with shift=8.0, regardless of the
    scheduler passed in.
    """

    def __init__(self, transformer, scheduler) -> None:
        super().__init__(transformer, scheduler)
        # DMD always uses this scheduler; the constructor argument is
        # intentionally overridden.
        self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Run the denoising loop.

        Args:
            batch: The current batch information.
            fastvideo_args: The inference arguments.

        Returns:
            The batch with denoised latents.
        """
        # Setup precision and autocast settings
        # TODO(will): make the precision configurable for inference
        # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
        target_dtype = torch.bfloat16
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        # Get timesteps and calculate warmup steps
        timesteps = batch.timesteps

        # TODO(will): remove this once we add input/output validation for stages
        if timesteps is None:
            raise ValueError("Timesteps must be provided")
        num_inference_steps = batch.num_inference_steps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Prepare image latents and embeddings for I2V generation
        image_embeds = batch.image_embeds
        if len(image_embeds) > 0:
            assert torch.isnan(image_embeds[0]).sum() == 0
            image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

        # Only forward kwargs the transformer's forward() actually accepts.
        image_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_image": image_embeds,
                "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
            },
        )

        pos_cond_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_2": batch.clip_embedding_pos,
                "encoder_attention_mask": batch.prompt_attention_mask,
            },
        )

        # Get latents and embeddings
        assert batch.latents is not None, "latents must be provided"
        latents = batch.latents

        video_raw_latent_shape = latents.shape
        prompt_embeds = batch.prompt_embeds
        assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
        # DMD uses the distilled step schedule from config, replacing the
        # batch-provided timesteps read above (those only size the warmup calc).
        timesteps = torch.tensor(fastvideo_args.pipeline_config.dmd_denoising_steps,
                                 dtype=torch.long,
                                 device=get_local_torch_device())

        # Run denoising loop
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # Skip if interrupted
                if hasattr(self, 'interrupt') and self.interrupt:
                    continue
                # Expand latents for I2V
                noise_latents = latents.clone()
                latent_model_input = latents.to(target_dtype)

                if batch.image_latent is not None:
                    # Concatenate image conditioning latents along the channel dim.
                    latent_model_input = torch.cat(
                        [latent_model_input, batch.image_latent.permute(0, 2, 1, 3, 4)], dim=2).to(target_dtype)
                assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"

                # Prepare inputs for transformer
                t_expand = t.repeat(latent_model_input.shape[0])
                # Embedded CFG scale is multiplied by 1000 as the model expects.
                guidance_expand = (torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)

                # Predict noise residual
                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                    # Build sparse-attention metadata only for the backends
                    # that need it (VSA or the sparse FP4 family).
                    if (vsa_available and self.attn_backend == VideoSparseAttentionBackend) or \
                       (sparse_fp4_available and self.attn_backend in sparse_fp4_backends):
                        self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                        if self.attn_metadata_builder_cls is not None:
                            self.attn_metadata_builder = self.attn_metadata_builder_cls()
                            # TODO(will): clean this up
                            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                                current_timestep=i,  # type: ignore
                                raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                                patch_size=fastvideo_args.pipeline_config.  # type: ignore
                                dit_config.patch_size,  # type: ignore
                                VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                                device=get_local_torch_device(),  # type: ignore
                            )  # type: ignore
                            assert attn_metadata is not None, "attn_metadata cannot be None"
                        else:
                            attn_metadata = None
                    else:
                        attn_metadata = None

                    batch.is_cfg_negative = False
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                            # fastvideo_args=fastvideo_args
                    ):
                        # Run transformer (model works in [B, C, T, H, W] order,
                        # hence the permutes around the call).
                        pred_noise = self.transformer(
                            latent_model_input.permute(0, 2, 1, 3, 4),
                            prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **pos_cond_kwargs,
                        ).permute(0, 2, 1, 3, 4)

                    # Convert predicted noise to the clean-video estimate.
                    pred_video = pred_noise_to_pred_video(pred_noise=pred_noise.flatten(0, 1),
                                                          noise_input_latent=noise_latents.flatten(0, 1),
                                                          timestep=t_expand,
                                                          scheduler=self.scheduler).unflatten(0, pred_noise.shape[:2])

                    if i < len(timesteps) - 1:
                        # Re-noise the clean estimate to the next distilled step.
                        next_timestep = timesteps[i + 1] * torch.ones([1], dtype=torch.long, device=pred_video.device)
                        noise_generator = batch.generator[0] if isinstance(batch.generator, list) else batch.generator
                        noise = torch.randn(video_raw_latent_shape, dtype=pred_video.dtype,
                                            generator=noise_generator).to(self.device)
                        latents = self.scheduler.add_noise(pred_video.flatten(0, 1), noise.flatten(0, 1),
                                                           next_timestep).unflatten(0, pred_video.shape[:2])
                    else:
                        latents = pred_video

                # Update progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                               (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
                    progress_bar.update()

        # Gather results if using sequence parallelism
        latents = latents.permute(0, 2, 1, 3, 4)
        # Update batch with final latents
        batch.latents = latents

        return batch
standalone_inference/overlay_files/fastvideo/platforms/cuda.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cuda.py
3
+ """Code inside this file can safely assume cuda platform, e.g. importing
4
+ pynvml. However, it should not initialize cuda context.
5
+ """
6
+
7
+ import os
8
+ from collections.abc import Callable
9
+ from functools import lru_cache, wraps
10
+ from typing import TypeVar
11
+
12
+ import torch
13
+ from typing_extensions import ParamSpec
14
+
15
+ import fastvideo.envs as envs
16
+ from fastvideo.logger import init_logger
17
+ from fastvideo.platforms.interface import (AttentionBackendEnum, DeviceCapability, Platform, PlatformEnum)
18
+ from fastvideo.utils import import_pynvml
19
+
20
+ logger = init_logger(__name__)
21
+
22
+ _P = ParamSpec("_P")
23
+ _R = TypeVar("_R")
24
+
25
+ pynvml = import_pynvml() # type: ignore[no-untyped-call]
26
+
27
+ # pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
28
+ # see https://github.com/huggingface/diffusers/issues/9704 for details
29
+ torch.backends.cuda.enable_cudnn_sdp(False)
30
+
31
+
32
+ def device_id_to_physical_device_id(device_id: int) -> int:
33
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
34
+ device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
35
+ if device_ids == [""]:
36
+ msg = ("CUDA_VISIBLE_DEVICES is set to empty string, which means"
37
+ " GPU support is disabled. If you are using ray, please unset"
38
+ " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
39
+ " worker/actor. "
40
+ "Check https://github.com/vllm-project/vllm/issues/8402 for"
41
+ " more information.")
42
+ raise RuntimeError(msg)
43
+ physical_device_id = device_ids[device_id]
44
+ return int(physical_device_id)
45
+ else:
46
+ return device_id
47
+
48
+
49
+ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
50
+
51
+ @wraps(fn)
52
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
53
+ pynvml.nvmlInit()
54
+ try:
55
+ return fn(*args, **kwargs)
56
+ finally:
57
+ pynvml.nvmlShutdown()
58
+
59
+ return wrapper
60
+
61
+
62
+ class CudaPlatformBase(Platform):
63
+ _enum = PlatformEnum.CUDA
64
+ device_name: str = "cuda"
65
+ device_type: str = "cuda"
66
+ dispatch_key: str = "CUDA"
67
+ ray_device_key: str = "GPU"
68
+ device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
69
+
70
+ @classmethod
71
+ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
72
+ raise NotImplementedError
73
+
74
+ @classmethod
75
+ def get_device_name(cls, device_id: int = 0) -> str:
76
+ raise NotImplementedError
77
+
78
+ @classmethod
79
+ def get_device_total_memory(cls, device_id: int = 0) -> int:
80
+ raise NotImplementedError
81
+
82
+ @classmethod
83
+ def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
84
+ if enforce_eager:
85
+ logger.warning("To see benefits of async output processing, enable CUDA "
86
+ "graph. Since, enforce-eager is enabled, async output "
87
+ "processor cannot be used")
88
+ return False
89
+ return True
90
+
91
+ @classmethod
92
+ def is_full_nvlink(cls, device_ids: list[int]) -> bool:
93
+ raise NotImplementedError
94
+
95
+ @classmethod
96
+ def log_warnings(cls) -> None:
97
+ pass
98
+
99
+ @classmethod
100
+ def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
101
+ torch.cuda.reset_peak_memory_stats(device)
102
+ return float(torch.cuda.max_memory_allocated(device))
103
+
104
+ @classmethod
105
+ def get_torch_device(cls) -> object:
106
+ """
107
+ Return torch.cuda
108
+ """
109
+ return torch.cuda
110
+
111
+ @classmethod
112
+ def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
113
+ dtype: torch.dtype) -> str:
114
+ # TODO(will): maybe come up with a more general interface for local attention
115
+ # if distributed is False, we always try to use Flash attn
116
+
117
+ logger.info("Trying FASTVIDEO_ATTENTION_BACKEND=%s", envs.FASTVIDEO_ATTENTION_BACKEND)
118
+ logger.info("Selected backend: %s", selected_backend)
119
+ if selected_backend == AttentionBackendEnum.SAGE_ATTN:
120
+ try:
121
+ from sageattention import sageattn # noqa: F401
122
+
123
+ from fastvideo.attention.backends.sage_attn import ( # noqa: F401
124
+ SageAttentionBackend)
125
+ logger.info("Using Sage Attention backend.")
126
+
127
+ return "fastvideo.attention.backends.sage_attn.SageAttentionBackend"
128
+ except ImportError as e:
129
+ logger.info(e)
130
+ logger.info("Sage Attention backend is not installed. Fall back to Flash Attention.")
131
+ elif selected_backend == AttentionBackendEnum.SAGE_ATTN_THREE:
132
+ try:
133
+ from sageattn3 import sageattn3_blackwell # noqa: F401
134
+
135
+ from fastvideo.attention.backends.sage_attn3 import ( # noqa: F401
136
+ SageAttention3Backend)
137
+ logger.info("Using Sage Attention 3 backend.")
138
+
139
+ return "fastvideo.attention.backends.sage_attn3.SageAttention3Backend"
140
+ except ImportError as e:
141
+ logger.info(e)
142
+ logger.info("Sage Attention 3 backend is not installed. Fall back to Flash Attention.")
143
+ elif selected_backend == AttentionBackendEnum.ATTN_QAT_INFER:
144
+ try:
145
+ from fastvideo.attention.backends.attn_qat_infer import ( # noqa: F401
146
+ AttnQatInferBackend, is_attn_qat_infer_available,
147
+ )
148
+ if not is_attn_qat_infer_available():
149
+ raise ImportError("attn_qat_infer could not be imported.")
150
+ logger.info("Using attn_qat_infer backend.")
151
+
152
+ return "fastvideo.attention.backends.attn_qat_infer.AttnQatInferBackend"
153
+ except ImportError as e:
154
+ logger.info(e)
155
+ logger.info("attn_qat_infer backend is not installed. Fall back to Flash Attention.")
156
+ elif selected_backend == AttentionBackendEnum.ATTN_QAT_TRAIN:
157
+ try:
158
+ from fastvideo_kernel.triton_kernels.attn_qat_train import attention # noqa: F401
159
+
160
+ from fastvideo.attention.backends.attn_qat_train import ( # noqa: F401
161
+ AttnQatTrainBackend)
162
+ logger.info("Using attn_qat_train backend.")
163
+
164
+ return "fastvideo.attention.backends.attn_qat_train.AttnQatTrainBackend"
165
+ except ImportError as e:
166
+ logger.info(e)
167
+ logger.info("attn_qat_train backend is not installed. Fall back to Flash Attention.")
168
+ elif selected_backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN:
169
+ try:
170
+ from fastvideo_kernel import video_sparse_attn # noqa: F401
171
+
172
+ from fastvideo.attention.backends.video_sparse_attn import ( # noqa: F401
173
+ VideoSparseAttentionBackend)
174
+ logger.info("Using Video Sparse Attention backend.")
175
+
176
+ return "fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionBackend"
177
+ except ImportError as e:
178
+ logger.error("Failed to import Video Sparse Attention backend: %s", str(e))
179
+ raise ImportError("The Video Sparse Attention backend is not installed. "
180
+ "To install it, please follow the instructions at: "
181
+ "https://hao-ai-lab.github.io/FastVideo/video_sparse_attention/installation ") from e
182
+ elif selected_backend == AttentionBackendEnum.SPARSE_FP4_ATTN:
183
+ try:
184
+ from fastvideo.attention.backends.sparse_fp4_attn import ( # noqa: F401
185
+ SparseFP4AttentionBackend)
186
+ logger.info("Using Sparse FP4 Attention backend (FP4 quant + VSA).")
187
+ return "fastvideo.attention.backends.sparse_fp4_attn.SparseFP4AttentionBackend"
188
+ except ImportError as e:
189
+ logger.error("Failed to import Sparse FP4 Attention backend: %s", str(e))
190
+ raise ImportError("Sparse FP4 Attention backend is not available.") from e
191
+ elif selected_backend == AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN:
192
+ try:
193
+ from fastvideo.attention.backends.sparse_fp4_ours_p_attn import ( # noqa: F401
194
+ SparseFP4OursPAttentionBackend)
195
+ logger.info(
196
+ "Using Sparse FP4 Ours-P Attention backend (group-local P quant + VSA)."
197
+ )
198
+ return "fastvideo.attention.backends.sparse_fp4_ours_p_attn.SparseFP4OursPAttentionBackend"
199
+ except ImportError as e:
200
+ logger.error("Failed to import Sparse FP4 Ours-P Attention backend: %s", str(e))
201
+ raise ImportError("Sparse FP4 Ours-P Attention backend is not available.") from e
202
+ elif selected_backend == AttentionBackendEnum.BSA_ATTN:
203
+ try:
204
+ from fastvideo.attention.backends.bsa_attn import ( # noqa: F401
205
+ BSAAttentionBackend)
206
+ logger.info("Using BSA Attention backend.")
207
+
208
+ return "fastvideo.attention.backends.bsa_attn.BSAAttentionBackend"
209
+ except ImportError as e:
210
+ logger.error("Failed to import BSA Attention backend: %s", str(e))
211
+ raise ImportError("The BSA Attention backend failed to import.") from e
212
+ elif selected_backend == AttentionBackendEnum.VMOBA_ATTN:
213
+ try:
214
+ from fastvideo_kernel import moba_attn_varlen # noqa: F401
215
+ from fastvideo.attention.backends.vmoba import ( # noqa: F401
216
+ VMOBAAttentionBackend)
217
+ logger.info("Using Video MOBA Attention backend.")
218
+
219
+ return "fastvideo.attention.backends.vmoba.VMOBAAttentionBackend"
220
+ except ImportError as e:
221
+ logger.error("Failed to import Video MoBA Attention backend: %s", str(e))
222
+ raise ImportError("Video MoBA Attention backend is not installed. ") from e
223
+ elif selected_backend == AttentionBackendEnum.SLA_ATTN:
224
+ try:
225
+ from fastvideo.attention.backends.sla import ( # noqa: F401
226
+ SLAAttentionBackend)
227
+ logger.info("Using SLA (Sparse-Linear Attention) backend.")
228
+
229
+ return "fastvideo.attention.backends.sla.SLAAttentionBackend"
230
+ except ImportError as e:
231
+ logger.error("Failed to import SLA Attention backend: %s", str(e))
232
+ raise ImportError("SLA Attention backend is not available. ") from e
233
+ elif selected_backend == AttentionBackendEnum.SAGE_SLA_ATTN:
234
+ try:
235
+ from fastvideo.attention.backends.sla import ( # noqa: F401
236
+ SageSLAAttentionBackend)
237
+ logger.info("Using SageSLA (Quantized Sparse-Linear Attention) backend.")
238
+
239
+ return "fastvideo.attention.backends.sla.SageSLAAttentionBackend"
240
+ except ImportError as e:
241
+ logger.error("Failed to import SageSLA Attention backend: %s", str(e))
242
+ raise ImportError("SageSLA Attention backend requires spas_sage_attn. "
243
+ "Install with: pip install git+https://github.com/thu-ml/SpargeAttn.git") from e
244
+ elif selected_backend == AttentionBackendEnum.TORCH_SDPA:
245
+ logger.info("Using Torch SDPA backend.")
246
+ return "fastvideo.attention.backends.sdpa.SDPABackend"
247
+ elif selected_backend == AttentionBackendEnum.FLASH_ATTN or selected_backend is None:
248
+ pass
249
+ elif selected_backend:
250
+ raise ValueError(f"Invalid attention backend for {cls.device_name}")
251
+
252
+ target_backend = AttentionBackendEnum.FLASH_ATTN
253
+ if not cls.has_device_capability(80):
254
+ logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
255
+ "GPUs.")
256
+ target_backend = AttentionBackendEnum.TORCH_SDPA
257
+ elif dtype not in (torch.float16, torch.bfloat16):
258
+ logger.info("Cannot use FlashAttention-2 backend for dtype other than "
259
+ "torch.float16 or torch.bfloat16.")
260
+ target_backend = AttentionBackendEnum.TORCH_SDPA
261
+
262
+ # FlashAttn is valid for the model, checking if the package is
263
+ # installed.
264
+ if target_backend == AttentionBackendEnum.FLASH_ATTN:
265
+ try:
266
+ import flash_attn # noqa: F401
267
+
268
+ from fastvideo.attention.backends.flash_attn import ( # noqa: F401
269
+ FlashAttentionBackend)
270
+
271
+ supported_sizes = \
272
+ FlashAttentionBackend.get_supported_head_sizes()
273
+ if head_size not in supported_sizes:
274
+ logger.info("Cannot use FlashAttention-2 backend for head size %d.", head_size)
275
+ target_backend = AttentionBackendEnum.TORCH_SDPA
276
+ except ImportError:
277
+ logger.info("Cannot use FlashAttention-2 backend because the "
278
+ "flash_attn package is not found. "
279
+ "Make sure that flash_attn was built and installed "
280
+ "(on by default).")
281
+ target_backend = AttentionBackendEnum.TORCH_SDPA
282
+
283
+ if target_backend == AttentionBackendEnum.TORCH_SDPA:
284
+ logger.info("Using Torch SDPA backend.")
285
+
286
+ return "fastvideo.attention.backends.sdpa.SDPABackend"
287
+
288
+ logger.info("Using Flash Attention backend.")
289
+
290
+ return "fastvideo.attention.backends.flash_attn.FlashAttentionBackend"
291
+
292
+ @classmethod
293
+ def get_device_communicator_cls(cls) -> str:
294
+ return "fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
295
+
296
+
297
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML (pynvml).

    Every query decorates with ``with_nvml_context`` so NVML is
    initialized/shut down around the call, and caches results with
    ``lru_cache`` keyed on (cls, device_id).  Because of the decorator
    order, cached hits skip NVML entirely.
    """

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        """Return the compute capability of ``device_id``, or None on failure."""
        try:
            # Map logical id (subject to CUDA_VISIBLE_DEVICES) to physical id,
            # since NVML only understands physical ids.
            physical_device_id = device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """Like the base implementation, but never raises on NVML errors."""
        try:
            return bool(super().has_device_capability(capability, device_id))
        except RuntimeError:
            return False

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by NVML for ``device_id``."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Return the NVML UUID string of ``device_id``."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return str(pynvml.nvmlDeviceGetUUID(handle))

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return total device memory in bytes for ``device_id``."""
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        # Check every unordered pair; any missing NVLink (or query failure)
        # means the set is not fully connected.
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle,
                            peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception("NVLink detection failed. This is normal if"
                                         " your machine has no NVLink equipped.")
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        # NOTE: takes a *physical* NVML index, not a logical CUDA ordinal.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return str(pynvml.nvmlDeviceGetName(handle))

    @classmethod
    @with_nvml_context
    def log_warnings(cls) -> None:
        """Warn when heterogeneous GPUs are visible without PCI_BUS_ID ordering."""
        device_ids: int = pynvml.nvmlDeviceGetCount()
        if device_ids > 1:
            device_names = [cls._get_physical_device_name(i) for i in range(device_ids)]
            if (len(set(device_names)) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )
393
+
394
+
395
class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform backed by ``torch.cuda`` queries.

    Used when NVML is unavailable (e.g. on Jetson, where pynvml.nvmlInit
    fails); see the autodetection logic at the bottom of this module.
    """

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the CUDA compute capability of ``device_id``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name of ``device_id``."""
        return str(torch.cuda.get_device_name(device_id))

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` in bytes."""
        device_props = torch.cuda.get_device_properties(device_id)
        return int(device_props.total_memory)

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        """Without NVML, NVLink topology cannot be queried; assume absent.

        Fix: the original used ``logger.exception`` here even though no
        exception is active, which appends a spurious ``NoneType: None``
        traceback to the log; ``logger.warning`` carries the same message
        without the bogus traceback.
        """
        logger.warning("NVLink detection not possible, as context support was"
                       " not found. Assuming no NVLink available.")
        return False
416
+
417
+
418
# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    try:
        # Probe NVML by initializing it once; success means the NVML-backed
        # platform can be used.
        pynvml.nvmlInit()
        nvml_available = True
    except Exception:
        # On Jetson, NVML is not supported.
        nvml_available = False
finally:
    # Shut down the probe session; platform methods re-enter NVML via
    # the with_nvml_context decorator as needed.
    if nvml_available:
        pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

try:
    # Under a Sphinx docs build, pynvml may be a mock module; skip the
    # device-warning pass in that case to avoid touching fake handles.
    from sphinx.ext.autodoc.mock import _MockModule

    if not isinstance(pynvml, _MockModule):
        CudaPlatform.log_warnings()
except ModuleNotFoundError:
    # Sphinx not installed: normal runtime, emit warnings as usual.
    CudaPlatform.log_warnings()
standalone_inference/overlay_files/fastvideo/platforms/interface.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ import random
3
+ from typing import Any, NamedTuple
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from fastvideo.logger import init_logger
9
+
10
+ logger = init_logger(__name__)
11
+
12
+
13
class AttentionBackendEnum(enum.Enum):
    """Selectable attention backend implementations.

    Values are written out explicitly but reproduce exactly what
    ``enum.auto()`` would assign (consecutive integers starting at 1),
    so serialized values remain stable.
    """

    FLASH_ATTN = 1
    TORCH_SDPA = 2
    SAGE_ATTN = 3
    SAGE_ATTN_THREE = 4
    ATTN_QAT_INFER = 5
    ATTN_QAT_TRAIN = 6
    VIDEO_SPARSE_ATTN = 7
    BSA_ATTN = 8
    VMOBA_ATTN = 9
    SLA_ATTN = 10
    SAGE_SLA_ATTN = 11
    SPARSE_FP4_ATTN = 12
    SPARSE_FP4_OURS_P_ATTN = 13
    NO_ATTENTION = 14
28
+
29
+
30
class PlatformEnum(enum.Enum):
    """Hardware platform families recognized by FastVideo.

    Explicit values mirror the original ``enum.auto()`` numbering
    (1-based, in declaration order).
    """

    CUDA = 1
    ROCM = 2
    TPU = 3
    XPU = 4
    CPU = 5
    MPS = 6
    OOT = 7
    UNSPECIFIED = 8
    NPU = 9
40
+
41
+
42
class CpuArchEnum(enum.Enum):
    """CPU instruction-set families (explicit values match ``enum.auto()``)."""

    X86 = 1
    ARM = 2
    UNSPECIFIED = 3
46
+
47
+
48
class DeviceCapability(NamedTuple):
    """A device compute capability as a (major, minor) version pair."""

    major: int
    minor: int

    def as_version_str(self) -> str:
        """Render the capability as ``"<major>.<minor>"``."""
        return ".".join((str(self.major), str(self.minor)))

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return 10 * self.major + self.minor
63
+
64
+
65
class Platform:
    """Abstract base describing a hardware platform (CUDA, ROCm, TPU, ...).

    Concrete subclasses set ``_enum`` / ``device_name`` / ``device_type``
    and override the device-query classmethods; most base implementations
    are stubs that return a neutral value or raise ``NotImplementedError``.
    """

    _enum: PlatformEnum
    device_name: str
    device_type: str

    dispatch_key: str = "CPU"

    # platform-agnostic way to specify the device control environment variable,
    # .e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
    device_control_env_var: str = "FASTVIDEO_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""
    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    # NOTE(review): mutable class-level defaults are shared across all
    # subclasses that do not override them — intentional here since they
    # are treated as read-only registries.
    supported_quantization: list[str] = []

    additional_env_vars: list[str] = []

    def is_cuda(self) -> bool:
        """True iff this platform is NVIDIA CUDA."""
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        """True iff this platform is AMD ROCm."""
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        """True iff this platform is a TPU."""
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        """True iff this platform is an Intel XPU."""
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        """True iff this platform is plain CPU."""
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        """True iff this platform is an out-of-tree (plugin) platform."""
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of :func:`torch.cuda.is_available`."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_mps(self) -> bool:
        """True iff this platform is Apple MPS."""
        return self._enum == PlatformEnum.MPS

    def is_npu(self) -> bool:
        """True iff this platform is an NPU."""
        return self._enum == PlatformEnum.NPU

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
                             dtype: torch.dtype) -> str:
        """Get the attention backend class of a device."""
        # Base class has no opinion; subclasses return a dotted class path.
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of :func:`torch.cuda.get_device_capability`."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        # Tuple comparison uses NamedTuple's lexicographic ordering;
        # integer comparison uses the <major><minor> encoding.
        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def get_torch_device(cls) -> Any:
        """
        Check if the current platform supports torch device.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int | None = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
        the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.
        """
        if cls.supported_quantization and \
            quant not in cls.supported_quantization:
            raise ValueError(f"{quant} quantization is currently not supported in "
                             f"{cls.device_name}.")

    @classmethod
    def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """Get the CPU architecture of the current platform."""
        return CpuArchEnum.UNSPECIFIED
251
+
252
+
253
class UnspecifiedPlatform(Platform):
    """Fallback platform used before (or without) hardware detection."""

    _enum = PlatformEnum.UNSPECIFIED
    # No concrete torch device type is associated with this placeholder.
    device_type = ""
standalone_inference/overlay_files/fastvideo/train/models/wan/wan.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Wan model plugin (per-role instance)."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import copy
7
+ import gc
8
+ from typing import Any, Literal, TYPE_CHECKING
9
+
10
+ import torch
11
+
12
+ import fastvideo.envs as envs
13
+ from fastvideo.configs.sample import SamplingParam
14
+ from fastvideo.distributed import (
15
+ get_sp_group,
16
+ get_world_group,
17
+ )
18
+ from fastvideo.forward_context import set_forward_context
19
+ from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
20
+ FlowMatchEulerDiscreteScheduler, )
21
+ from fastvideo.pipelines import TrainingBatch
22
+ from fastvideo.pipelines.basic.wan.wan_pipeline import (
23
+ WanPipeline, )
24
+ from fastvideo.pipelines.pipeline_batch_info import (
25
+ ForwardBatch, )
26
+ from fastvideo.training.activation_checkpoint import (
27
+ apply_activation_checkpointing, )
28
+ from fastvideo.training.training_utils import (
29
+ compute_density_for_timestep_sampling,
30
+ get_sigmas,
31
+ normalize_dit_input,
32
+ shift_timestep,
33
+ )
34
+ from fastvideo.utils import (
35
+ is_vmoba_available,
36
+ is_vsa_available,
37
+ )
38
+
39
+ from fastvideo.train.models.base import ModelBase
40
+ from fastvideo.train.utils.module_state import (
41
+ apply_trainable, )
42
+ from fastvideo.train.utils.moduleloader import (
43
+ load_module_from_path, )
44
+
45
+ if TYPE_CHECKING:
46
+ from fastvideo.train.utils.training_config import (
47
+ TrainingConfig, )
48
+
49
+ VideoSparseAttentionMetadataBuilder: type[Any] | None
50
+ VideoMobaAttentionMetadataBuilder: type[Any] | None
51
+
52
+ try:
53
+ from fastvideo.attention.backends.video_sparse_attn import (
54
+ VideoSparseAttentionMetadataBuilder as _VideoSparseAttentionMetadataBuilder, )
55
+ from fastvideo.attention.backends.vmoba import (
56
+ VideoMobaAttentionMetadataBuilder as _VideoMobaAttentionMetadataBuilder, )
57
+ VideoSparseAttentionMetadataBuilder = _VideoSparseAttentionMetadataBuilder
58
+ VideoMobaAttentionMetadataBuilder = _VideoMobaAttentionMetadataBuilder
59
+ except Exception:
60
+ VideoSparseAttentionMetadataBuilder = None
61
+ VideoMobaAttentionMetadataBuilder = None
62
+
63
+
64
+ class WanModel(ModelBase):
65
+ """Wan per-role model: owns transformer + noise_scheduler."""
66
+
67
+ _transformer_cls_name: str = "WanTransformer3DModel"
68
+
69
    def __init__(
        self,
        *,
        init_from: str,
        training_config: TrainingConfig,
        trainable: bool = True,
        disable_custom_init_weights: bool = False,
        flow_shift: float = 3.0,
        enable_gradient_checkpointing_type: str
        | None = None,
        transformer_override_safetensor: str
        | None = None,
    ) -> None:
        """Build the Wan role model: transformer + flow-match scheduler.

        Args:
            init_from: path to load the transformer weights from.
            training_config: global training configuration object.
            trainable: whether the transformer's parameters require grad.
            disable_custom_init_weights: forwarded to the module loader.
            flow_shift: shift for the flow-match scheduler; also seeds
                ``timestep_shift`` (may be overridden later by
                ``_init_timestep_mechanics``).
            enable_gradient_checkpointing_type: checkpointing flavor, or
                None to fall back to ``training_config.model``.
            transformer_override_safetensor: optional weight override file.
        """
        self._init_from = str(init_from)
        self._trainable = bool(trainable)

        self.transformer = self._load_transformer(
            init_from=self._init_from,
            trainable=self._trainable,
            disable_custom_init_weights=(disable_custom_init_weights),
            enable_gradient_checkpointing_type=(enable_gradient_checkpointing_type),
            training_config=training_config,
            transformer_override_safetensor=(transformer_override_safetensor),
        )

        self.noise_scheduler = (FlowMatchEulerDiscreteScheduler(shift=float(flow_shift)))

        # Filled by init_preprocessors (student only).
        self.vae: Any = None
        self.training_config: TrainingConfig = training_config
        self.dataloader: Any = None
        self.validator: Any = None
        self.start_step: int = 0

        # Process groups; populated in init_preprocessors.
        self.world_group: Any = None
        self.sp_group: Any = None

        # CFG negative conditioning; populated lazily by
        # ensure_negative_conditioning.
        self.negative_prompt_embeds: (torch.Tensor | None) = None
        self.negative_prompt_attention_mask: (torch.Tensor | None) = None

        # Timestep mechanics.
        self.timestep_shift: float = float(flow_shift)
        self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps)
        self.min_timestep: int = 0
        self.max_timestep: int = self.num_train_timestep
114
+
115
    def _load_transformer(
        self,
        *,
        init_from: str,
        trainable: bool,
        disable_custom_init_weights: bool,
        enable_gradient_checkpointing_type: str | None,
        training_config: TrainingConfig,
        transformer_override_safetensor: str | None = None,
    ) -> torch.nn.Module:
        """Load the transformer, set trainability, and optionally wrap it
        with activation checkpointing.

        Checkpointing is applied only when the module is trainable AND a
        checkpointing type is configured (explicitly or via
        ``training_config.model``).
        """
        transformer = load_module_from_path(
            model_path=init_from,
            module_type="transformer",
            training_config=training_config,
            disable_custom_init_weights=(disable_custom_init_weights),
            override_transformer_cls_name=(self._transformer_cls_name),
            transformer_override_safetensor=(transformer_override_safetensor),
        )
        transformer = apply_trainable(transformer, trainable=trainable)
        # Fall back to training_config.model if not set on the
        # model YAML section directly.
        ckpt_type = (enable_gradient_checkpointing_type or getattr(
            getattr(training_config, "model", None),
            "enable_gradient_checkpointing_type",
            None,
        ))
        if trainable and ckpt_type:
            transformer = apply_activation_checkpointing(
                transformer,
                checkpointing_type=ckpt_type,
            )
        return transformer
147
+
148
+ # ------------------------------------------------------------------
149
+ # Lifecycle
150
+ # ------------------------------------------------------------------
151
+
152
    def init_preprocessors(self, training_config: TrainingConfig) -> None:
        """Initialize VAE, process groups, timestep mechanics and the
        parquet T2V dataloader (called on the student role only, per the
        comment in ``__init__``)."""
        self.vae = load_module_from_path(
            model_path=str(training_config.model_path),
            module_type="vae",
            training_config=training_config,
        )

        self.world_group = get_world_group()
        self.sp_group = get_sp_group()

        self._init_timestep_mechanics()

        # Imported locally to avoid import cycles / optional deps at
        # module import time.
        from fastvideo.dataset.dataloader.schema import (
            pyarrow_schema_t2v, )
        from fastvideo.train.utils.dataloader import (
            build_parquet_t2v_train_dataloader, )

        # Text sequence length comes from the first text-encoder config.
        text_len = (
            training_config.pipeline_config.text_encoder_configs[  # type: ignore[union-attr]
                0].arch_config.text_len)
        self.dataloader = build_parquet_t2v_train_dataloader(
            training_config.data,
            text_len=int(text_len),
            parquet_schema=pyarrow_schema_t2v,
        )
        self.start_step = 0
178
+
179
    @property
    def num_train_timesteps(self) -> int:
        """Total number of scheduler training timesteps, as a plain int."""
        return int(self.num_train_timestep)
182
+
183
+ def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
184
+ timestep = shift_timestep(
185
+ timestep,
186
+ self.timestep_shift,
187
+ self.num_train_timestep,
188
+ )
189
+ return timestep.clamp(self.min_timestep, self.max_timestep)
190
+
191
    def on_train_start(self) -> None:
        """Training-start hook: materialize the negative-prompt
        conditioning so CFG is available from the first step."""
        self.ensure_negative_conditioning()
193
+
194
+ # ------------------------------------------------------------------
195
+ # Runtime primitives
196
+ # ------------------------------------------------------------------
197
+
198
    def prepare_batch(
        self,
        raw_batch: dict[str, Any],
        *,
        generator: torch.Generator,
        latents_source: Literal["data", "zeros"] = "data",
    ) -> TrainingBatch:
        """Turn a raw dataloader batch into a fully-prepared TrainingBatch.

        Args:
            raw_batch: dict with at least ``text_embedding`` and
                ``text_attention_mask``; ``vae_latent`` is required when
                ``latents_source == "data"``.
            generator: RNG used for noise/timestep sampling downstream.
            latents_source: take latents from the data, or build an
                all-zeros latent tensor sized from the config.

        Returns:
            A TrainingBatch with latents, text conditioning, noise/timesteps
            (via ``_prepare_dit_inputs``) and attention metadata. A deep copy
            of the metadata is kept in ``attn_metadata_vsa`` while the main
            ``attn_metadata`` gets sparsity 0.0 (dense pass).
        """
        self.ensure_negative_conditioning()
        assert self.training_config is not None
        tc = self.training_config

        dtype = self._get_training_dtype()
        device = self.device

        training_batch = TrainingBatch()
        encoder_hidden_states = raw_batch["text_embedding"]
        encoder_attention_mask = raw_batch["text_attention_mask"]
        infos = raw_batch.get("info_list")

        if latents_source == "zeros":
            # Derive the latent grid from the VAE compression config.
            batch_size = encoder_hidden_states.shape[0]
            vae_config = (
                tc.pipeline_config.vae_config.arch_config  # type: ignore[union-attr]
            )
            num_channels = vae_config.z_dim
            spatial_compression_ratio = (vae_config.spatial_compression_ratio)
            latent_height = (tc.data.num_height // spatial_compression_ratio)
            latent_width = (tc.data.num_width // spatial_compression_ratio)
            latents = torch.zeros(
                batch_size,
                num_channels,
                tc.data.num_latent_t,
                latent_height,
                latent_width,
                device=device,
                dtype=dtype,
            )
        elif latents_source == "data":
            if "vae_latent" not in raw_batch:
                raise ValueError("vae_latent not found in batch "
                                 "and latents_source='data'")
            latents = raw_batch["vae_latent"]
            # Truncate the temporal axis to the configured latent length.
            latents = latents[:, :, :tc.data.num_latent_t]
            latents = latents.to(device, dtype=dtype)
        else:
            raise ValueError(f"Unknown latents_source: "
                             f"{latents_source!r}")

        training_batch.latents = latents
        # NOTE(review): the attention mask is cast to the training float
        # dtype here rather than kept integral — presumably consumed as a
        # float mask downstream; confirm against the pipeline.
        training_batch.encoder_hidden_states = (encoder_hidden_states.to(device, dtype=dtype))
        training_batch.encoder_attention_mask = (encoder_attention_mask.to(device, dtype=dtype))
        training_batch.infos = infos

        training_batch.latents = normalize_dit_input("wan", training_batch.latents, self.vae)
        training_batch = self._prepare_dit_inputs(training_batch, generator)
        training_batch = self._build_attention_metadata(training_batch)

        # Keep the sparse (VSA) metadata aside, then force the primary
        # metadata to dense (sparsity 0.0) for the regular forward pass.
        training_batch.attn_metadata_vsa = copy.deepcopy(training_batch.attn_metadata)
        if training_batch.attn_metadata is not None:
            training_batch.attn_metadata.VSA_sparsity = 0.0  # type: ignore[attr-defined]

        return training_batch
260
+
261
+ def add_noise(
262
+ self,
263
+ clean_latents: torch.Tensor,
264
+ noise: torch.Tensor,
265
+ timestep: torch.Tensor,
266
+ ) -> torch.Tensor:
267
+ b, t = clean_latents.shape[:2]
268
+ noisy = self.noise_scheduler.add_noise(
269
+ clean_latents.flatten(0, 1),
270
+ noise.flatten(0, 1),
271
+ timestep,
272
+ ).unflatten(0, (b, t))
273
+ return noisy
274
+
275
    def predict_noise(
        self,
        noisy_latents: torch.Tensor,
        timestep: torch.Tensor,
        batch: TrainingBatch,
        *,
        conditional: bool,
        cfg_uncond: dict[str, Any] | None = None,
        attn_kind: Literal["dense", "vsa", "sparse_fp4"] = "dense",
        force_dense: bool = False,
    ) -> torch.Tensor:
        """Run the transformer to predict noise for ``noisy_latents``.

        Args:
            conditional: use ``batch.conditional_dict`` (raises if absent)
                or the unconditional/CFG text dict.
            attn_kind: selects which attention metadata is installed in the
                forward context; "vsa" and "sparse_fp4" both use the sparse
                metadata copy. (Annotation widened to include "sparse_fp4",
                which the body already accepts.)
            force_dense: forwarded to the forward context.

        Returns:
            Transformer output permuted to (B, C, T, H, W) layout —
            presumably from (B, T, C, H, W); confirm against the
            transformer's output convention.
        """
        device_type = self.device.type
        dtype = noisy_latents.dtype
        if conditional:
            text_dict = batch.conditional_dict
            if text_dict is None:
                raise RuntimeError("Missing conditional_dict in "
                                   "TrainingBatch")
        else:
            text_dict = self._get_uncond_text_dict(batch, cfg_uncond=cfg_uncond)

        if attn_kind == "dense":
            attn_metadata = batch.attn_metadata
        elif attn_kind in ("vsa", "sparse_fp4"):
            attn_metadata = batch.attn_metadata_vsa
        else:
            raise ValueError(f"Unknown attn_kind: {attn_kind!r}")

        with torch.autocast(device_type, dtype=dtype), set_forward_context(
                current_timestep=batch.timesteps,
                attn_metadata=attn_metadata,
                force_dense=force_dense,
        ):
            input_kwargs = (self._build_distill_input_kwargs(noisy_latents, timestep, text_dict))
            transformer = self._get_transformer(timestep)
            pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4)
        return pred_noise
312
+
313
+ def backward(
314
+ self,
315
+ loss: torch.Tensor,
316
+ ctx: Any,
317
+ *,
318
+ grad_accum_rounds: int,
319
+ ) -> None:
320
+ timesteps, attn_metadata = ctx
321
+ with set_forward_context(
322
+ current_timestep=timesteps,
323
+ attn_metadata=attn_metadata,
324
+ ):
325
+ (loss / max(1, int(grad_accum_rounds))).backward()
326
+
327
+ # ------------------------------------------------------------------
328
+ # Internal helpers
329
+ # ------------------------------------------------------------------
330
+
331
    def _get_training_dtype(self) -> torch.dtype:
        """Compute dtype used throughout training (bfloat16)."""
        return torch.bfloat16
333
+
334
+ def _init_timestep_mechanics(self) -> None:
335
+ assert self.training_config is not None
336
+ tc = self.training_config
337
+ flow_shift = tc.pipeline_config.flow_shift
338
+ self.timestep_shift = float(0.0 if flow_shift is None else flow_shift)
339
+ self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps)
340
+ # min/max timestep ratios now come from method_config;
341
+ # default to full range.
342
+ self.min_timestep = 0
343
+ self.max_timestep = self.num_train_timestep
344
+
345
    def ensure_negative_conditioning(self) -> None:
        """Lazily compute and cache the negative-prompt embeddings.

        Rank 0 of the world group encodes the negative prompt through a
        throwaway WanPipeline's text-encoding stage; the resulting tensors
        are then broadcast to all ranks in three phases (ndim, shape,
        data) so every rank ends up with identical
        ``negative_prompt_embeds`` / ``negative_prompt_attention_mask``.
        Idempotent: returns immediately once populated.
        """
        if self.negative_prompt_embeds is not None:
            return

        assert self.training_config is not None
        tc = self.training_config
        world_group = self.world_group
        device = self.device
        dtype = self._get_training_dtype()

        from fastvideo.train.utils.moduleloader import (
            make_inference_args, )

        neg_embeds: torch.Tensor | None = None
        neg_mask: torch.Tensor | None = None

        if world_group.rank_in_group == 0:
            sampling_param = SamplingParam.from_pretrained(tc.model_path)
            negative_prompt = sampling_param.negative_prompt

            inference_args = make_inference_args(tc, model_path=tc.model_path)

            # Reuse the already-loaded transformer so only the text
            # encoder/tokenizer need to be materialized here.
            prompt_pipeline = WanPipeline.from_pretrained(
                tc.model_path,
                args=inference_args,
                inference_mode=True,
                loaded_modules={"transformer": self.transformer},
                tp_size=tc.distributed.tp_size,
                sp_size=tc.distributed.sp_size,
                num_gpus=tc.distributed.num_gpus,
                pin_cpu_memory=(tc.distributed.pin_cpu_memory),
                dit_cpu_offload=True,
            )

            batch_negative = ForwardBatch(
                data_type="video",
                prompt=negative_prompt,
                prompt_embeds=[],
                prompt_attention_mask=[],
            )
            result_batch = prompt_pipeline.prompt_encoding_stage(  # type: ignore[attr-defined]
                batch_negative,
                inference_args,
            )

            neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype)
            neg_mask = (result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype))

            # Free the temporary pipeline before the broadcasts.
            del prompt_pipeline
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Phase 1: broadcast tensor ranks so other ranks can size buffers.
        meta = torch.zeros((2, ), device=device, dtype=torch.int64)
        if world_group.rank_in_group == 0:
            assert neg_embeds is not None
            assert neg_mask is not None
            meta[0] = neg_embeds.ndim
            meta[1] = neg_mask.ndim
        world_group.broadcast(meta, src=0)
        embed_ndim, mask_ndim = (
            int(meta[0].item()),
            int(meta[1].item()),
        )

        # Phase 2: broadcast shapes (padded to a fixed max rank of 8).
        max_ndim = 8
        embed_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
        mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
        if world_group.rank_in_group == 0:
            assert neg_embeds is not None
            assert neg_mask is not None
            embed_shape[:embed_ndim] = torch.tensor(
                list(neg_embeds.shape),
                device=device,
                dtype=torch.int64,
            )
            mask_shape[:mask_ndim] = torch.tensor(
                list(neg_mask.shape),
                device=device,
                dtype=torch.int64,
            )
        world_group.broadcast(embed_shape, src=0)
        world_group.broadcast(mask_shape, src=0)

        embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist())
        mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist())

        # Phase 3: non-root ranks allocate receive buffers, then broadcast
        # the actual tensor payloads.
        if world_group.rank_in_group != 0:
            neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype)
            neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype)
        assert neg_embeds is not None
        assert neg_mask is not None

        world_group.broadcast(neg_embeds, src=0)
        world_group.broadcast(neg_mask, src=0)

        self.negative_prompt_embeds = neg_embeds
        self.negative_prompt_attention_mask = neg_mask
443
+
444
    def _sample_timesteps(
        self,
        batch_size: int,
        device: torch.device,
        generator: torch.Generator,
    ) -> torch.Tensor:
        """Sample one training timestep per batch element.

        Draws u ~ density (per the configured weighting scheme), maps it to
        a scheduler index, and returns the corresponding scheduler
        timesteps on ``device``.
        """
        assert self.training_config is not None
        tc = self.training_config

        u = compute_density_for_timestep_sampling(
            weighting_scheme=tc.model.weighting_scheme,
            batch_size=batch_size,
            generator=generator,
            device=device,
            logit_mean=tc.model.logit_mean,
            logit_std=tc.model.logit_std,
            mode_scale=tc.model.mode_scale,
        )
        indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
        # Scheduler timesteps live on CPU; index there, then move to device.
        return self.noise_scheduler.timesteps[indices.cpu()].to(device=device)
464
+
465
    def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Attach attention metadata matching FASTVIDEO_ATTENTION_BACKEND.

        VSA-family backends (VIDEO_SPARSE_ATTN / SPARSE_FP4_ATTN /
        SPARSE_FP4_OURS_P_ATTN) all use the VideoSparseAttention builder;
        VMOBA_ATTN uses the VMoBA builder; any other backend leaves
        ``attn_metadata`` as None (dense attention). Raises ImportError when
        the selected backend's kernels are not installed.
        """
        assert self.training_config is not None
        tc = self.training_config
        latents_shape = training_batch.raw_latent_shape
        patch_size = (
            tc.pipeline_config.dit_config.patch_size  # type: ignore[union-attr]
        )
        assert latents_shape is not None
        assert training_batch.timesteps is not None

        if envs.FASTVIDEO_ATTENTION_BACKEND in (
                "VIDEO_SPARSE_ATTN", "SPARSE_FP4_ATTN", "SPARSE_FP4_OURS_P_ATTN",
        ):
            if (not is_vsa_available() or VideoSparseAttentionMetadataBuilder is None):
                raise ImportError(
                    f"FASTVIDEO_ATTENTION_BACKEND is "
                    f"{envs.FASTVIDEO_ATTENTION_BACKEND}, but "
                    f"fastvideo_kernel is not correctly "
                    f"installed or detected.")
            # latents_shape[2:5] is the (T, H, W) latent grid.
            training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder().build(  # type: ignore[misc]
                raw_latent_shape=latents_shape[2:5],
                current_timestep=(training_batch.timesteps),
                patch_size=patch_size,
                VSA_sparsity=tc.vsa_sparsity,
                device=self.device,
            )
        elif (envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN"):
            if (not is_vmoba_available() or VideoMobaAttentionMetadataBuilder is None):
                raise ImportError("FASTVIDEO_ATTENTION_BACKEND is "
                                  "VMOBA_ATTN, but fastvideo_kernel "
                                  "(or flash_attn>=2.7.4) is not "
                                  "correctly installed.")
            # Copy so the shared config dict is not mutated in place.
            moba_params = tc.model.moba_config.copy()
            assert training_batch.raw_latent_shape is not None
            moba_params.update({
                "current_timestep": (training_batch.timesteps),
                "raw_latent_shape": (training_batch.raw_latent_shape[2:5]),
                "patch_size": patch_size,
                "device": self.device,
            })
            training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(**
                                                                                     moba_params)  # type: ignore[misc]
        else:
            training_batch.attn_metadata = None

        return training_batch
511
+
512
    def _prepare_dit_inputs(
        self,
        training_batch: TrainingBatch,
        generator: torch.Generator,
    ) -> TrainingBatch:
        """Build the noisy DiT inputs and conditioning dicts for one step.

        Samples noise and timesteps, mixes latents/noise with the flow-match
        sigmas, and fills the conditional (and, when negative embeddings are
        cached, unconditional) text dicts on ``training_batch``.
        """
        assert self.training_config is not None
        tc = self.training_config
        latents = training_batch.latents
        assert isinstance(latents, torch.Tensor)
        batch_size = latents.shape[0]

        noise = torch.randn(
            latents.shape,
            generator=generator,
            device=latents.device,
            dtype=latents.dtype,
        )
        timesteps = self._sample_timesteps(
            batch_size,
            latents.device,
            generator,
        )
        # Keep timesteps identical across the sequence-parallel group.
        # NOTE(review): only timesteps are broadcast here; the sibling
        # trainer pipeline also broadcasts ``noise`` — confirm whether the
        # asymmetry is intentional.
        if int(tc.distributed.sp_size or 1) > 1:
            self.sp_group.broadcast(timesteps, src=0)

        sigmas = get_sigmas(
            self.noise_scheduler,
            latents.device,
            timesteps,
            n_dim=latents.ndim,
            dtype=latents.dtype,
        )
        # Flow-matching interpolation between clean latents and noise.
        noisy_model_input = ((1.0 - sigmas) * latents + sigmas * noise)

        training_batch.noisy_model_input = (noisy_model_input)
        training_batch.timesteps = timesteps
        training_batch.sigmas = sigmas
        training_batch.noise = noise
        training_batch.raw_latent_shape = latents.shape

        training_batch.conditional_dict = {
            "encoder_hidden_states": (training_batch.encoder_hidden_states),
            "encoder_attention_mask": (training_batch.encoder_attention_mask),
        }

        # Broadcast-expand cached (batch 1) negative embeddings to the
        # current batch size for CFG-style unconditional passes.
        if (self.negative_prompt_embeds is not None and self.negative_prompt_attention_mask is not None):
            neg_embeds = self.negative_prompt_embeds
            neg_mask = (self.negative_prompt_attention_mask)
            if (neg_embeds.shape[0] == 1 and batch_size > 1):
                neg_embeds = neg_embeds.expand(batch_size, *neg_embeds.shape[1:]).contiguous()
            if (neg_mask.shape[0] == 1 and batch_size > 1):
                neg_mask = neg_mask.expand(batch_size, *neg_mask.shape[1:]).contiguous()
            training_batch.unconditional_dict = {
                "encoder_hidden_states": neg_embeds,
                "encoder_attention_mask": neg_mask,
            }

        # Swap axes 1 and 2 of the stored latents (layout expected by the
        # transformer; see _build_distill_input_kwargs, which applies the
        # same permute to the noisy input).
        training_batch.latents = (training_batch.latents.permute(0, 2, 1, 3, 4))
        return training_batch
571
+
572
+ def _build_distill_input_kwargs(
573
+ self,
574
+ noise_input: torch.Tensor,
575
+ timestep: torch.Tensor,
576
+ text_dict: dict[str, torch.Tensor] | None,
577
+ ) -> dict[str, Any]:
578
+ if text_dict is None:
579
+ raise ValueError("text_dict cannot be None for "
580
+ "Wan distillation")
581
+ return {
582
+ "hidden_states": noise_input.permute(0, 2, 1, 3, 4),
583
+ "encoder_hidden_states": text_dict["encoder_hidden_states"],
584
+ "encoder_attention_mask": text_dict["encoder_attention_mask"],
585
+ "timestep": timestep,
586
+ "return_dict": False,
587
+ }
588
+
589
+ def _get_transformer(self, timestep: torch.Tensor) -> torch.nn.Module:
590
+ return self.transformer
591
+
592
    def _get_uncond_text_dict(
        self,
        batch: TrainingBatch,
        *,
        cfg_uncond: dict[str, Any] | None,
    ) -> dict[str, torch.Tensor]:
        """Resolve the unconditional text dict for CFG from ``cfg_uncond``.

        With ``cfg_uncond is None`` the cached negative-prompt conditioning
        is used. Otherwise the config is validated: ``on_missing`` controls
        whether unsupported channels raise or are ignored, and ``text``
        selects between {negative_prompt, keep, zero} (``drop`` is rejected).

        Raises:
            RuntimeError: if the required conditioning dict is missing.
            ValueError: on malformed or unsupported config entries.
            TypeError: if text-policy "zero" finds non-tensor conditioning.
        """
        # Fast path: no per-method CFG config -> cached negative prompt.
        if cfg_uncond is None:
            text_dict = getattr(batch, "unconditional_dict", None)
            if text_dict is None:
                raise RuntimeError("Missing unconditional_dict; "
                                   "ensure_negative_conditioning() "
                                   "may have failed")
            return text_dict

        on_missing_raw = cfg_uncond.get("on_missing", "error")
        if not isinstance(on_missing_raw, str):
            raise ValueError("method_config.cfg_uncond.on_missing "
                             "must be a string, got "
                             f"{type(on_missing_raw).__name__}")
        on_missing = on_missing_raw.strip().lower()
        if on_missing not in {"error", "ignore"}:
            raise ValueError("method_config.cfg_uncond.on_missing "
                             "must be one of {error, ignore}, got "
                             f"{on_missing_raw!r}")

        # Validate every non-text channel: only "keep" (or None) is
        # supported by this model; anything else is an error unless
        # on_missing=ignore.
        for channel, policy_raw in cfg_uncond.items():
            if channel in {"on_missing", "text"}:
                continue
            if policy_raw is None:
                continue
            if not isinstance(policy_raw, str):
                raise ValueError("method_config.cfg_uncond values "
                                 "must be strings, got "
                                 f"{channel}="
                                 f"{type(policy_raw).__name__}")
            policy = policy_raw.strip().lower()
            if policy == "keep":
                continue
            if on_missing == "ignore":
                continue
            raise ValueError("WanModel does not support "
                             "cfg_uncond channel "
                             f"{channel!r} (policy={policy!r}). "
                             "Set cfg_uncond.on_missing=ignore or "
                             "remove the channel.")

        # Resolve the text policy; default is the cached negative prompt.
        text_policy_raw = cfg_uncond.get("text", None)
        if text_policy_raw is None:
            text_policy = "negative_prompt"
        elif not isinstance(text_policy_raw, str):
            raise ValueError("method_config.cfg_uncond.text must be "
                             "a string, got "
                             f"{type(text_policy_raw).__name__}")
        else:
            text_policy = (text_policy_raw.strip().lower())

        if text_policy in {"negative_prompt"}:
            text_dict = getattr(batch, "unconditional_dict", None)
            if text_dict is None:
                raise RuntimeError("Missing unconditional_dict; "
                                   "ensure_negative_conditioning() "
                                   "may have failed")
            return text_dict
        if text_policy == "keep":
            # "keep": reuse the conditional text as the unconditional input.
            if batch.conditional_dict is None:
                raise RuntimeError("Missing conditional_dict in "
                                   "TrainingBatch")
            return batch.conditional_dict
        if text_policy == "zero":
            # "zero": zero out the conditional text tensors.
            if batch.conditional_dict is None:
                raise RuntimeError("Missing conditional_dict in "
                                   "TrainingBatch")
            cond = batch.conditional_dict
            enc = cond["encoder_hidden_states"]
            mask = cond["encoder_attention_mask"]
            if not torch.is_tensor(enc) or not torch.is_tensor(mask):
                raise TypeError("conditional_dict must contain "
                                "tensor text inputs")
            return {
                "encoder_hidden_states": (torch.zeros_like(enc)),
                "encoder_attention_mask": (torch.zeros_like(mask)),
            }
        if text_policy == "drop":
            raise ValueError("cfg_uncond.text=drop is not supported "
                             "for Wan. Use "
                             "{negative_prompt, keep, zero}.")
        raise ValueError("cfg_uncond.text must be one of "
                         "{negative_prompt, keep, zero, drop}, got "
                         f"{text_policy_raw!r}")
standalone_inference/overlay_files/fastvideo/training/training_pipeline.py ADDED
@@ -0,0 +1,1044 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from dataclasses import asdict
3
+ from contextlib import AbstractContextManager, nullcontext
4
+ import math
5
+ import os
6
+ import shutil
7
+ import tempfile
8
+ import time
9
+ from abc import ABC, abstractmethod
10
+ from collections import deque
11
+ from collections.abc import Iterator
12
+ from typing import Any
13
+ from fastvideo.profiler import profile_region
14
+ import imageio
15
+ import numpy as np
16
+ import torch
17
+ import torch.distributed as dist
18
+ import torchvision
19
+ from einops import rearrange
20
+ from torch.utils.data import DataLoader
21
+ from torchdata.stateful_dataloader import StatefulDataLoader
22
+ from tqdm.auto import tqdm
23
+ from diffusers import FlowMatchEulerDiscreteScheduler
24
+
25
+ import fastvideo.envs as envs
26
+ try:
27
+ from fastvideo.attention.backends.video_sparse_attn import (VideoSparseAttentionMetadataBuilder)
28
+ from fastvideo.attention.backends.vmoba import VideoMobaAttentionMetadataBuilder
29
+ except Exception:
30
+ pass
31
+ from fastvideo.configs.sample import SamplingParam
32
+ from fastvideo.dataset import build_parquet_map_style_dataloader
33
+ from fastvideo.dataset.dataloader.schema import pyarrow_schema_t2v
34
+ from fastvideo.dataset.validation_dataset import ValidationDataset
35
+ from fastvideo.distributed import (cleanup_dist_env_and_memory, get_local_torch_device, get_sp_group, get_world_group)
36
+ from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
37
+ from fastvideo.forward_context import set_forward_context
38
+ from fastvideo.logger import init_logger
39
+ from fastvideo.attention.selector import global_force_attn_backend_context_manager
40
+ from fastvideo.pipelines import (ComposedPipelineBase, ForwardBatch, LoRAPipeline, TrainingBatch)
41
+ from fastvideo.platforms import AttentionBackendEnum, current_platform
42
+ from fastvideo.training.activation_checkpoint import (apply_activation_checkpointing)
43
+ from fastvideo.training.trackers import (DummyTracker, TrackerType, initialize_trackers, Trackers)
44
+ from fastvideo.training.training_utils import (clip_grad_norm_while_handling_failing_dtensor_cases,
45
+ compute_density_for_timestep_sampling, count_trainable, get_scheduler,
46
+ get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint,
47
+ swap_fp4_linear, traverse_swap_module)
48
+ from fastvideo.utils import (is_vmoba_available, is_vsa_available, set_random_seed, shallow_asdict)
49
+
50
+ try:
51
+ vsa_available = is_vsa_available()
52
+ vmoba_available = is_vmoba_available()
53
+ except Exception:
54
+ vsa_available = False
55
+ vmoba_available = False
56
+
57
+ logger = init_logger(__name__)
58
+
59
+
60
+ class TrainingPipeline(LoRAPipeline, ABC):
61
+ """
62
+ A pipeline for training a model. All training pipelines should inherit from this class.
63
+ All reusable components and code should be implemented in this class.
64
+ """
65
+ _required_config_modules = ["scheduler", "transformer"]
66
+ validation_pipeline: ComposedPipelineBase
67
+ train_dataloader: StatefulDataLoader
68
+ train_loader_iter: Iterator[dict[str, Any]]
69
+ current_epoch: int = 0
70
+ train_transformer_2: bool = False
71
+ tracker: TrackerType
72
+
73
+ def __init__(self,
74
+ model_path: str,
75
+ fastvideo_args: TrainingArgs,
76
+ required_config_modules: list[str] | None = None,
77
+ loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
78
+ fastvideo_args.inference_mode = False
79
+ self.lora_training = fastvideo_args.lora_training
80
+ if self.lora_training and fastvideo_args.lora_rank is None:
81
+ raise ValueError("lora rank must be set when using lora training")
82
+
83
+ set_random_seed(fastvideo_args.seed) # for lora param init
84
+ super().__init__(model_path, fastvideo_args, required_config_modules, loaded_modules) # type: ignore
85
+ self.tracker = DummyTracker()
86
+
87
+ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
88
+ raise RuntimeError("create_pipeline_stages should not be called for training pipeline")
89
+
90
+ @staticmethod
91
+ def _should_force_generator_attn_qat_train(fastvideo_args: FastVideoArgs) -> bool:
92
+ if not isinstance(fastvideo_args, TrainingArgs):
93
+ return False
94
+ return (fastvideo_args.generator_4bit_attn or envs.FASTVIDEO_ATTENTION_BACKEND == "ATTN_QAT_TRAIN")
95
+
96
+ def load_modules(self,
97
+ fastvideo_args: FastVideoArgs,
98
+ loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]:
99
+ force_generator_qat = self._should_force_generator_attn_qat_train(fastvideo_args)
100
+ load_context: AbstractContextManager[None] = nullcontext()
101
+ if force_generator_qat:
102
+ logger.info("Forcing generator attention backend to ATTN_QAT_TRAIN during module loading")
103
+ load_context = global_force_attn_backend_context_manager(AttentionBackendEnum.ATTN_QAT_TRAIN)
104
+
105
+ with load_context:
106
+ return super().load_modules(fastvideo_args, loaded_modules)
107
+
108
    def set_schemas(self) -> None:
        """Select the parquet schema for the training dataloader.

        The base pipeline uses the text-to-video schema; override in
        subclasses to train on a different data layout.
        """
        self.train_dataset_schema = pyarrow_schema_t2v
110
+
111
    def initialize_training_pipeline(self, training_args: TrainingArgs):
        """One-time setup for training: distributed state, models,
        optimizers, LR schedulers, dataloader, scheduler, and trackers.

        Must be called before the first training step. Mutates many
        attributes on ``self`` (device, ranks, transformer(s), optimizer(s),
        dataloader, tracker, ...).
        """
        logger.info("Initializing training pipeline...")
        self.device = get_local_torch_device()
        self.training_args = training_args
        world_group = get_world_group()
        self.world_size = world_group.world_size
        self.global_rank = world_group.rank
        self.sp_group = get_sp_group()
        self.rank_in_sp_group = self.sp_group.rank_in_group
        self.sp_world_size = self.sp_group.world_size
        self.local_rank = world_group.local_rank
        self.transformer = self.get_module("transformer")
        # Optional second transformer (MoE-style boundary training).
        self.transformer_2 = self.get_module("transformer_2", None)
        self.seed = training_args.seed
        self.set_schemas()

        # Set random seeds for deterministic training; offset by rank so
        # each process draws independent noise.
        assert self.seed is not None, "seed must be set"
        set_random_seed(self.seed + self.global_rank)
        self.transformer.train()
        if training_args.enable_gradient_checkpointing_type is not None:
            self.transformer = apply_activation_checkpointing(
                self.transformer, checkpointing_type=training_args.enable_gradient_checkpointing_type)
            if self.transformer_2 is not None:
                self.transformer_2 = apply_activation_checkpointing(
                    self.transformer_2, checkpointing_type=training_args.enable_gradient_checkpointing_type)

        # Optionally swap nn.Linear layers for the FP4 quantized forward path.
        if training_args.generator_4bit_linear:
            num_swaps = traverse_swap_module(self.transformer, swap_fn=swap_fp4_linear)
            logger.info("Swapped %s linear layers to the FP4 forward path in self.transformer", num_swaps)
        noise_scheduler = self.modules["scheduler"]
        self.set_trainable()
        params_to_optimize = self.transformer.parameters()
        params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
        # Parse betas from string format "beta1,beta2"
        betas_str = training_args.betas
        betas = tuple(float(x.strip()) for x in betas_str.split(","))

        self.optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=training_args.learning_rate,
            betas=betas,
            weight_decay=training_args.weight_decay,
            eps=1e-8,
        )

        self.init_steps = 0
        logger.info("optimizer: %s", self.optimizer)

        self.lr_scheduler = get_scheduler(
            training_args.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=training_args.lr_warmup_steps,
            num_training_steps=training_args.max_train_steps,
            num_cycles=training_args.lr_num_cycles,
            power=training_args.lr_power,
            min_lr_ratio=training_args.min_lr_ratio,
            last_epoch=self.init_steps - 1,
        )
        if self.transformer_2 is not None:
            # Ensure transformer_2 has trainable parameters before creating optimizer
            # NOTE(review): optimizer_2 hard-codes betas=(0.9, 0.999) instead
            # of the parsed ``betas`` above — confirm this is intentional.
            params_to_optimize_2 = self.transformer_2.parameters()
            params_to_optimize_2 = list(filter(lambda p: p.requires_grad, params_to_optimize_2))
            self.optimizer_2 = torch.optim.AdamW(
                params_to_optimize_2,
                lr=training_args.learning_rate,
                betas=(0.9, 0.999),
                weight_decay=training_args.weight_decay,
                eps=1e-8,
            )
            self.lr_scheduler_2 = get_scheduler(
                training_args.lr_scheduler,
                optimizer=self.optimizer_2,
                num_warmup_steps=training_args.lr_warmup_steps,
                num_training_steps=training_args.max_train_steps,
                num_cycles=training_args.lr_num_cycles,
                power=training_args.lr_power,
                min_lr_ratio=training_args.min_lr_ratio,
                last_epoch=self.init_steps - 1,
            )

        self.train_dataset, self.train_dataloader = build_parquet_map_style_dataloader(
            training_args.data_path,
            training_args.train_batch_size,
            parquet_schema=self.train_dataset_schema,
            num_data_workers=training_args.dataloader_num_workers,
            cfg_rate=training_args.training_cfg_rate,
            drop_last=True,
            text_padding_length=training_args.pipeline_config.text_encoder_configs[0].arch_config.
            text_len,  # type: ignore[attr-defined]
            seed=self.seed)

        self.noise_scheduler = noise_scheduler
        # Boundary timestep splits the denoising range between the two
        # transformers; None disables the split.
        if self.training_args.boundary_ratio is not None:
            self.boundary_timestep = self.training_args.boundary_ratio * self.noise_scheduler.num_train_timesteps
        else:
            self.boundary_timestep = None

        logger.info("train_dataloader length: %s", len(self.train_dataloader))
        logger.info("train_sp_batch_size: %s", training_args.train_sp_batch_size)
        logger.info("gradient_accumulation_steps: %s", training_args.gradient_accumulation_steps)
        logger.info("sp_size: %s", training_args.sp_size)

        self.num_update_steps_per_epoch = math.ceil(
            len(self.train_dataloader) / training_args.gradient_accumulation_steps * training_args.sp_size /
            training_args.train_sp_batch_size)
        self.num_train_epochs = math.ceil(training_args.max_train_steps / self.num_update_steps_per_epoch)

        # TODO(will): is there a cleaner way to track epochs?
        self.current_epoch = 0

        # Trackers run only on the global rank-0 process.
        trackers = list(training_args.trackers)
        if not trackers and training_args.tracker_project_name:
            trackers.append(Trackers.WANDB.value)
        if self.global_rank != 0:
            trackers = []

        tracker_log_dir = training_args.output_dir or os.getcwd()
        if trackers:
            tracker_log_dir = os.path.join(tracker_log_dir, "tracker")

        tracker_config = asdict(training_args) if trackers else None
        tracker_run_name = training_args.wandb_run_name or None
        project = training_args.tracker_project_name or "fastvideo"
        self.tracker = initialize_trackers(
            trackers,
            experiment_name=project,
            config=tracker_config,
            log_dir=tracker_log_dir,
            run_name=tracker_run_name,
        )
242
+
243
    @abstractmethod
    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        """Build the pipeline used for validation runs.

        Abstract: concrete training pipelines must implement this.
        """
        raise NotImplementedError("Training pipelines must implement this method")
246
+
247
+ def _prepare_training(self, training_batch: TrainingBatch) -> TrainingBatch:
248
+ self.optimizer.zero_grad()
249
+ if self.transformer_2 is not None:
250
+ self.optimizer_2.zero_grad()
251
+ training_batch.total_loss = 0.0
252
+ return training_batch
253
+
254
    def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Pull the next dataloader batch onto the local device (bf16).

        When the iterator is exhausted, advances the epoch counter, resets
        the iterator, and takes the first batch of the new epoch. Latents
        are truncated to ``num_latent_t`` along dim 2.
        """
        with self.tracker.timed("timing/get_next_batch"):
            batch = next(self.train_loader_iter, None)  # type: ignore
            if batch is None:
                self.current_epoch += 1
                logger.info("Starting epoch %s", self.current_epoch)
                # Reset iterator for next epoch
                self.train_loader_iter = iter(self.train_dataloader)
                # Get first batch of new epoch
                batch = next(self.train_loader_iter)

            latents = batch['vae_latent']
            # Keep only the first num_latent_t latent frames.
            latents = latents[:, :, :self.training_args.num_latent_t]
            encoder_hidden_states = batch['text_embedding']
            encoder_attention_mask = batch['text_attention_mask']
            infos = batch['info_list']

            # non_blocking=True overlaps H2D copies with compute.
            training_batch.latents = latents.to(
                get_local_torch_device(),
                dtype=torch.bfloat16,
                non_blocking=True,
            )
            training_batch.encoder_hidden_states = (encoder_hidden_states.to(
                get_local_torch_device(),
                dtype=torch.bfloat16,
                non_blocking=True,
            ))
            training_batch.encoder_attention_mask = (encoder_attention_mask.to(
                get_local_torch_device(),
                dtype=torch.bfloat16,
                non_blocking=True,
            ))
            training_batch.infos = infos

        return training_batch
289
+
290
+ def _normalize_dit_input(self, training_batch: TrainingBatch) -> TrainingBatch:
291
+ # TODO(will): support other models
292
+ with self.tracker.timed("timing/normalize_input"):
293
+ training_batch.latents = normalize_dit_input(
294
+ 'wan',
295
+ training_batch.latents,
296
+ self.get_module("vae"),
297
+ )
298
+ return training_batch
299
+
300
    def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Sample noise/timesteps and build the noisy model input.

        Uses the flow-matching interpolation
        ``(1 - sigma) * latents + sigma * noise`` and records noise,
        timesteps, sigmas, and the raw latent shape on the batch.
        """
        assert self.training_args is not None, "training_args must be set"
        with self.tracker.timed("timing/prepare_dit_inputs"):
            latents = training_batch.latents
            batch_size = latents.shape[0]
            noise = torch.randn(latents.shape,
                                generator=self.noise_gen_cuda,
                                device=latents.device,
                                dtype=latents.dtype)
            timesteps = self._sample_timesteps(batch_size, latents.device)

            if self.training_args.sp_size > 1:
                # Make sure that the timesteps are the same across all sp processes.
                sp_group = get_sp_group()
                sp_group.broadcast(timesteps, src=0)
                sp_group.broadcast(noise, src=0)
            sigmas = get_sigmas(
                self.noise_scheduler,
                latents.device,
                timesteps,
                n_dim=latents.ndim,
                dtype=latents.dtype,
            )
            # Flow-matching mix of clean latents and noise.
            noisy_model_input = (1.0 - sigmas) * training_batch.latents + sigmas * noise

            training_batch.noisy_model_input = noisy_model_input
            training_batch.timesteps = timesteps
            training_batch.sigmas = sigmas
            training_batch.noise = noise
            training_batch.raw_latent_shape = training_batch.latents.shape

        return training_batch
332
+
333
    def _sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Sample training timesteps, choosing which transformer to train.

        When a second transformer and a boundary are configured, a coin flip
        (broadcast from rank 0) decides whether this step trains
        transformer_2; in that case the sampled density is remapped into the
        high-timestep band [1 - boundary_ratio, 1].
        """
        # Determine which model to train based on the boundary timestep.
        # NOTE(review): every rank draws this random number but only rank 0's
        # decision survives the broadcast below — the other draws are wasted
        # (and advance the generator); confirm this is intentional.
        if (self.transformer_2 is not None and self.boundary_timestep is not None
                and torch.rand(1, generator=self.noise_random_generator).item() <= self.training_args.boundary_ratio):
            self.train_transformer_2 = True
        else:
            self.train_transformer_2 = False

        # Broadcast the decision to all processes
        decision = torch.tensor(1.0 if self.train_transformer_2 else 0.0, device=self.device)
        dist.broadcast(decision, src=0)
        self.train_transformer_2 = decision.item() == 1.0

        # Sample u from the appropriate range
        u = compute_density_for_timestep_sampling(
            weighting_scheme=self.training_args.weighting_scheme,
            batch_size=batch_size,
            generator=self.noise_random_generator,
            logit_mean=self.training_args.logit_mean,
            logit_std=self.training_args.logit_std,
            mode_scale=self.training_args.mode_scale,
        )

        # boundary_ratio is guaranteed non-None here when train_transformer_2
        # is True (it requires boundary_timestep, set only from a ratio).
        boundary_ratio = self.training_args.boundary_ratio
        if self.train_transformer_2:
            u = (1 - boundary_ratio) + u * boundary_ratio  # min: 1 - boundary_ratio, max: 1
        # elif self.transformer_2 is not None:
        #     u = u * (1 - boundary_ratio)  # min: 0, max: 1 - boundary_ratio
        # else:  # patch for now to align with non-MoE timestep logic
        #     pass

        indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
        return self.noise_scheduler.timesteps[indices].to(device=device)
366
+
367
    def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Attach backend-specific attention metadata to the batch.

        VSA-family backends get VideoSparseAttention metadata built with the
        batch's current sparsity; VMOBA gets VideoMoba metadata; any other
        backend leaves ``attn_metadata`` as None.

        Raises:
            ImportError: if the selected backend's kernel package is missing.
        """
        latents_shape = training_batch.raw_latent_shape
        patch_size = self.training_args.pipeline_config.dit_config.patch_size
        current_vsa_sparsity = training_batch.current_vsa_sparsity
        assert latents_shape is not None
        assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}"
        assert training_batch.timesteps is not None
        # All three FP4/VSA variants share the VideoSparseAttention metadata.
        if envs.FASTVIDEO_ATTENTION_BACKEND in (
                "VIDEO_SPARSE_ATTN",
                "SPARSE_FP4_ATTN",
                "SPARSE_FP4_OURS_P_ATTN",
        ):
            if not vsa_available:
                raise ImportError("FASTVIDEO_ATTENTION_BACKEND is set to VIDEO_SPARSE_ATTN, "
                                  "but fastvideo_kernel is not correctly installed or detected. "
                                  "Please ensure fastvideo-kernel is installed.")
            # latents_shape[2:5] drops batch/channel dims, leaving (T, H, W).
            training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder(  # type: ignore
            ).build(  # type: ignore
                raw_latent_shape=latents_shape[2:5],
                current_timestep=training_batch.timesteps,
                patch_size=patch_size,
                VSA_sparsity=current_vsa_sparsity,
                device=get_local_torch_device())
        elif envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
            if not vmoba_available:
                raise ImportError("FASTVIDEO_ATTENTION_BACKEND is set to VMOBA_ATTN, "
                                  "but fastvideo_kernel (or flash_attn>=2.7.4) is not correctly installed.")
            # Copy so the per-step additions don't mutate the config dict.
            moba_params = self.training_args.moba_config.copy()
            moba_params.update({
                "current_timestep": training_batch.timesteps,
                "raw_latent_shape": latents_shape[2:5],
                "patch_size": self.training_args.pipeline_config.dit_config.patch_size,
                "device": get_local_torch_device(),
            })
            training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(**moba_params)
        else:
            # Dense attention backends need no extra metadata.
            training_batch.attn_metadata = None

        return training_batch
406
+
407
+ def _build_input_kwargs(self, training_batch: TrainingBatch) -> TrainingBatch:
408
+ training_batch.input_kwargs = {
409
+ "hidden_states": training_batch.noisy_model_input,
410
+ "encoder_hidden_states": training_batch.encoder_hidden_states,
411
+ "timestep": training_batch.timesteps.to(get_local_torch_device(), dtype=torch.bfloat16),
412
+ "encoder_attention_mask": training_batch.encoder_attention_mask,
413
+ "return_dict": False,
414
+ }
415
+ return training_batch
416
+
417
    def _transformer_forward_and_compute_loss(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Run the forward pass, compute the MSE flow loss, and backprop.

        The loss is scaled by 1/gradient_accumulation_steps, averaged across
        ranks, and accumulated into ``training_batch.total_loss`` (kept on
        GPU to avoid a CPU sync per micro-step).
        """
        # Sanity-check that metadata presence matches the active backend.
        if vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND in (
                "VIDEO_SPARSE_ATTN",
                "SPARSE_FP4_ATTN",
                "SPARSE_FP4_OURS_P_ATTN",
        ) or vmoba_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
            assert training_batch.attn_metadata is not None
        else:
            assert training_batch.attn_metadata is None
        input_kwargs = training_batch.input_kwargs

        # if 'hunyuan' in self.training_args.model_type:
        #     input_kwargs["guidance"] = torch.tensor(
        #         [1000.0],
        #         device=training_batch.noisy_model_input.device,
        #         dtype=torch.bfloat16)
        # Route to the second transformer when boundary training selected it.
        current_model = self.transformer_2 if self.train_transformer_2 else self.transformer

        with self.tracker.timed("timing/forward_backward"), set_forward_context(
                current_timestep=training_batch.current_timestep, attn_metadata=training_batch.attn_metadata):
            model_pred = current_model(**input_kwargs)
            if self.training_args.precondition_outputs:
                assert training_batch.sigmas is not None
                # Preconditioning: model predicts velocity; convert to x0.
                model_pred = training_batch.noisy_model_input - model_pred * training_batch.sigmas
            assert training_batch.latents is not None
            assert training_batch.noise is not None
            # Target is x0 when preconditioned, otherwise the flow velocity.
            target = training_batch.latents if self.training_args.precondition_outputs else training_batch.noise - training_batch.latents

            # make sure no implicit broadcasting happens
            assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}"

            loss = (torch.mean(
                (model_pred.float() - target.float())**2) / self.training_args.gradient_accumulation_steps)

            loss.backward()

        avg_loss = loss.detach().clone()

        # Reduce across ranks without forcing a CPU sync
        with self.tracker.timed("timing/reduce_loss"):
            world_group = get_world_group()
            avg_loss = world_group.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        # Accumulate on GPU; materialize to CPU only once after
        # all gradient-accumulation iterations (see train_one_step).
        training_batch.total_loss += avg_loss

        return training_batch
464
+
465
+ def _clip_grad_norm(self, training_batch: TrainingBatch) -> TrainingBatch:
466
+ max_grad_norm = self.training_args.max_grad_norm
467
+
468
+ # TODO(will): perhaps move this into transformer api so that we can do
469
+ # the following:
470
+ # grad_norm = transformer.clip_grad_norm_(max_grad_norm)
471
+ if max_grad_norm is not None:
472
+ with self.tracker.timed("timing/clip_grad_norm"):
473
+ # Only clip gradients for the model that is currently training
474
+ if self.train_transformer_2 and self.transformer_2 is not None:
475
+ model_parts = [self.transformer_2]
476
+ else:
477
+ model_parts = [self.transformer]
478
+
479
+ grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases(
480
+ [p for m in model_parts for p in m.parameters()],
481
+ max_grad_norm,
482
+ foreach=None,
483
+ )
484
+ assert grad_norm is not float('nan') or grad_norm is not float('inf')
485
+ grad_norm = grad_norm.item() if grad_norm is not None else 0.0
486
+ else:
487
+ grad_norm = 0.0
488
+ training_batch.grad_norm = grad_norm
489
+ return training_batch
490
+
491
    @profile_region("profiler_region_training_train_one_step")
    def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
        """Run one optimizer step: accumulate gradients over
        ``gradient_accumulation_steps`` micro-batches, clip, and step the
        optimizer/scheduler of whichever transformer is being trained.
        """
        training_batch = self._prepare_training(training_batch)

        for _ in range(self.training_args.gradient_accumulation_steps):
            training_batch = self._get_next_batch(training_batch)

            # Normalize DIT input
            training_batch = self._normalize_dit_input(training_batch)
            # Create noisy model input
            training_batch = self._prepare_dit_inputs(training_batch)
            assert training_batch.latents is not None
            assert training_batch.noisy_model_input is not None
            assert training_batch.noise is not None

            # old sharding code, need to shard latents and noise but not input
            # Shard latents across sp groups
            training_batch.latents = training_batch.latents[:, :, :self.training_args.num_latent_t]
            # shard noisy_model_input to match
            training_batch.noisy_model_input = training_batch.noisy_model_input[:, :, :self.training_args.num_latent_t]
            # shard noise to match latents
            training_batch.noise = training_batch.noise[:, :, :self.training_args.num_latent_t]

            training_batch = self._build_attention_metadata(training_batch)
            training_batch = self._build_input_kwargs(training_batch)

            # Forward + backward; loss is accumulated on training_batch.
            training_batch = self._transformer_forward_and_compute_loss(training_batch)

        training_batch = self._clip_grad_norm(training_batch)

        # Only step the optimizer and scheduler for the model that is currently training
        with self.tracker.timed("timing/optimizer_step"):
            if self.train_transformer_2 and self.transformer_2 is not None:
                self.optimizer_2.step()
                self.lr_scheduler_2.step()
            else:
                self.optimizer.step()
                self.lr_scheduler.step()

        return training_batch
531
+
532
+ def _compute_current_sparsity(self, step: int) -> float:
533
+ """Compute the VSA sparsity for a given step using the decay schedule."""
534
+ vsa_sparsity = self.training_args.VSA_sparsity
535
+ vsa_decay_rate = self.training_args.VSA_decay_rate
536
+ vsa_decay_interval = self.training_args.VSA_decay_interval_steps
537
+ vsa_init = getattr(self.training_args, 'VSA_init_sparsity', 0.0)
538
+ vsa_warmup = getattr(self.training_args, 'VSA_warmup_steps', 0)
539
+ if step <= vsa_warmup:
540
+ return vsa_init
541
+ ramp_step = step - vsa_warmup
542
+ max_times = int((vsa_sparsity - vsa_init) / vsa_decay_rate) if vsa_decay_rate > 0 else 0
543
+ times = min(ramp_step // vsa_decay_interval, max_times)
544
+ return vsa_init + times * vsa_decay_rate
545
+
546
+ def _resolve_checkpoint_path(self, path: str) -> str | None:
547
+ """Resolve 'latest' to the most recent checkpoint in output_dir."""
548
+ import glob
549
+ if path == "latest":
550
+ output_dir = self.training_args.output_dir
551
+ ckpt_dirs = sorted(
552
+ glob.glob(os.path.join(output_dir, "checkpoint-*")),
553
+ key=lambda d: int(d.split("-")[-1]) if d.split("-")[-1].isdigit() else 0,
554
+ )
555
+ if ckpt_dirs:
556
+ latest = ckpt_dirs[-1]
557
+ logger.info("Auto-resolved 'latest' to %s", latest)
558
+ return latest
559
+ logger.info("No checkpoints found in %s, starting from scratch", output_dir)
560
+ return None
561
+ return path
562
+
563
def _resume_from_checkpoint(self) -> None:
    """Restore training state from ``training_args.resume_from_checkpoint``.

    Resolution order:
      1. Try a full resume (model + optimizer + dataloader + scheduler +
         RNG) via ``load_checkpoint``.
      2. If that fails but plain safetensors weights exist in the checkpoint
         directory, fast-forward ``init_steps`` to the checkpoint's step
         number (assumes the safetensors weights were applied elsewhere,
         e.g. at model init — TODO confirm).
      3. Otherwise start from step 0.
    """
    ckpt_path = self._resolve_checkpoint_path(self.training_args.resume_from_checkpoint)
    if ckpt_path is None:
        logger.info("No checkpoint to resume from, starting from step 0")
        return

    safetensors_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model.safetensors")
    # Checkpoint directories are named "checkpoint-<step>".
    step = int(os.path.basename(os.path.normpath(ckpt_path)).split('-')[-1])

    resumed_step = load_checkpoint(self.transformer, self.global_rank, ckpt_path,
                                   self.optimizer, self.train_dataloader,
                                   self.lr_scheduler, self.noise_random_generator)
    # resumed_step > 0 means the full state loaded; step == 0 means there is
    # nothing to fast-forward to anyway.
    if resumed_step > 0 or step == 0:
        self.init_steps = resumed_step
        logger.info("Successfully resumed full training state from step %s", resumed_step)
        return

    if os.path.exists(safetensors_path):
        self.init_steps = step
        logger.warning("Distributed checkpoint resume failed; falling back to safetensors weights at step %s",
                       step)
        return

    logger.warning("No usable checkpoint state found at %s; starting from step 0", ckpt_path)
    self.init_steps = 0
588
+
589
@profile_region("profiler_region_training_train")
def train(self) -> None:
    """Main training loop.

    Seeds RNGs, optionally resumes from a checkpoint, then iterates
    training steps with per-step VSA sparsity scheduling, periodic
    checkpointing, visualization, and validation, and saves a final
    checkpoint at the end.
    """
    assert self.seed is not None, "seed must be set"
    assert self.training_args is not None, "training_args must be set"
    set_random_seed(self.seed + self.global_rank)
    logger.info('rank: %s: start training', self.global_rank, local_main_process_only=False)
    if not self.post_init_called:
        self.post_init()
    num_trainable_params = count_trainable(self.transformer)
    logger.info("Starting training with %s B trainable parameters", round(num_trainable_params / 1e9, 3))

    if getattr(self, "transformer_2", None) is not None:
        num_trainable_params = count_trainable(self.transformer_2)
        logger.info("Transformer 2: Starting training with %s B trainable parameters",
                    round(num_trainable_params / 1e9, 3))

    # Set random seeds for deterministic training
    self.noise_random_generator = torch.Generator(device="cpu").manual_seed(self.seed + self.global_rank)
    self.noise_gen_cuda = torch.Generator(device=current_platform.device_name).manual_seed(self.seed +
                                                                                          self.global_rank)
    self.validation_random_generator = torch.Generator(device="cpu").manual_seed(self.seed + self.global_rank)
    logger.info("Initialized random seeds with seed: %s", self.seed + self.global_rank)
    self.noise_scheduler = FlowMatchEulerDiscreteScheduler()

    if self.training_args.resume_from_checkpoint:
        self._resume_from_checkpoint()

    self.train_loader_iter = iter(self.train_dataloader)

    step_times: deque[float] = deque(maxlen=100)

    self._log_training_info()

    # Validation at init uses the sparsity corresponding to init_steps
    saved_sparsity = self.training_args.VSA_sparsity
    self.training_args.VSA_sparsity = self._compute_current_sparsity(self.init_steps)
    self._log_validation(self.transformer, self.training_args, self.init_steps)
    self.training_args.VSA_sparsity = saved_sparsity

    # Train!
    progress_bar = tqdm(
        range(0, self.training_args.max_train_steps),
        initial=self.init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=self.local_rank > 0,
    )
    for step in range(self.init_steps + 1, self.training_args.max_train_steps + 1):
        start_time = time.perf_counter()
        if vsa_available:
            # CONSOLIDATED: delegate to _compute_current_sparsity so the
            # warmup/decay schedule lives in one place. The previous inline
            # copy of this logic divided by VSA_decay_rate without the
            # helper's `rate > 0` guard (ZeroDivisionError for a zero rate).
            current_vsa_sparsity = self._compute_current_sparsity(step)
        elif vmoba_available:
            #TODO: add vmoba sparsity scheduling here
            current_vsa_sparsity = 0.0
        else:
            current_vsa_sparsity = 0.0

        training_batch = TrainingBatch()
        training_batch.current_timestep = step
        training_batch.current_vsa_sparsity = current_vsa_sparsity
        training_batch = self.train_one_step(training_batch)

        loss = float(training_batch.total_loss)
        grad_norm = training_batch.grad_norm

        step_time = time.perf_counter() - start_time
        step_times.append(step_time)
        avg_step_time = sum(step_times) / len(step_times)

        progress_bar.set_postfix({
            "loss": f"{loss:.4f}",
            "step_time": f"{step_time:.2f}s",
            "grad_norm": grad_norm,
        })
        progress_bar.update(1)
        if self.global_rank == 0:
            metrics = {
                "train_loss": loss,
                "learning_rate": self.lr_scheduler.get_last_lr()[0],
                "step_time": step_time,
                "avg_step_time": avg_step_time,
                "grad_norm": grad_norm,
                "vsa_sparsity": current_vsa_sparsity,
            }
            # Shape/arch metrics are best-effort: any missing attribute
            # simply drops them from this step's log.
            try:
                assert training_batch.raw_latent_shape is not None
                metrics["batch_size"] = int(training_batch.raw_latent_shape[0])

                patch_size = self.training_args.pipeline_config.dit_config.patch_size
                assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}"
                patch_t, patch_h, patch_w = patch_size
                seq_len = (training_batch.raw_latent_shape[2] // patch_t) * (
                    training_batch.raw_latent_shape[3] // patch_h) * (training_batch.raw_latent_shape[4] // patch_w)
                if training_batch.encoder_hidden_states is not None:
                    context_len = int(training_batch.encoder_hidden_states.shape[1])
                else:
                    context_len = 0

                metrics["dit_seq_len"] = int(seq_len)
                metrics["context_len"] = context_len

                arch_config = self.training_args.pipeline_config.dit_config.arch_config

                metrics["hidden_dim"] = arch_config.hidden_size
                metrics["num_layers"] = arch_config.num_layers
                metrics["ffn_dim"] = arch_config.ffn_dim
            except Exception:
                pass

            self.tracker.log(metrics, step)
        if step % self.training_args.training_state_checkpointing_steps == 0:
            with self.profiler_controller.region("profiler_region_training_save_checkpoint"):
                save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir, step,
                                self.optimizer, self.train_dataloader, self.lr_scheduler,
                                self.noise_random_generator,
                                self.training_args.checkpoints_total_limit)
                self.transformer.train()
            self.sp_group.barrier()

        if self.training_args.log_visualization and step % self.training_args.visualization_steps == 0:
            self.visualize_intermediate_latents(training_batch, self.training_args, step)

        if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
            with self.profiler_controller.region("profiler_region_training_validation"):
                # Validate at the sparsity currently in effect, then restore.
                saved_sparsity = self.training_args.VSA_sparsity
                self.training_args.VSA_sparsity = current_vsa_sparsity
                self._log_validation(self.transformer, self.training_args, step)
                self.training_args.VSA_sparsity = saved_sparsity
            gpu_memory_usage = current_platform.get_torch_device().memory_allocated() / 1024**2
            trainable_params = round(count_trainable(self.transformer) / 1e9, 3)
            logger.info("GPU memory usage after validation: %s MB, trainable params: %sB", gpu_memory_usage,
                        trainable_params)

    self.tracker.finish()
    # Final checkpoint at max_train_steps.
    save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir,
                    self.training_args.max_train_steps, self.optimizer, self.train_dataloader, self.lr_scheduler,
                    self.noise_random_generator, self.training_args.checkpoints_total_limit)

    if envs.FASTVIDEO_TORCH_PROFILER_DIR:
        logger.info("Stopping profiler...")
        self.profiler_controller.stop()
        logger.info("Profiler stopped.")

    if get_sp_group():
        cleanup_dist_env_and_memory()
745
+
746
def _log_training_info(self) -> None:
    """Log a one-time summary of the run configuration before training."""
    assert self.training_args is not None, "training_args must be set"
    args = self.training_args
    total_batch_size = (self.world_size * args.gradient_accumulation_steps /
                        args.sp_size * args.train_sp_batch_size)
    logger.info("***** Running training *****")
    logger.info(" Num examples = %s", len(self.train_dataset))
    logger.info(" Dataloader size = %s", len(self.train_dataloader))
    logger.info(" Num Epochs = %s", self.num_train_epochs)
    logger.info(" Resume training from step %s", self.init_steps)  # type: ignore
    logger.info(" Instantaneous batch size per device = %s", args.train_batch_size)
    logger.info(" Total train batch size (w. data & sequence parallel, accumulation) = %s", total_batch_size)
    logger.info(" Gradient Accumulation steps = %s", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %s", args.max_train_steps)
    logger.info(" Total training parameters per FSDP shard = %s B",
                round(count_trainable(self.transformer) / 1e9, 3))
    # Report the master weight dtype via the first parameter.
    logger.info(" Master weight dtype: %s", next(self.transformer.parameters()).dtype)

    gpu_memory_usage = current_platform.get_torch_device().memory_allocated() / 1024**2
    logger.info("GPU memory usage before train_one_step: %s MB", gpu_memory_usage)
    logger.info("VSA validation sparsity: %s", args.VSA_sparsity)
767
+
768
def _prepare_validation_batch(self, sampling_param: SamplingParam, training_args: TrainingArgs,
                              validation_batch: dict[str, Any], num_inference_steps: int) -> ForwardBatch:
    """Populate ``sampling_param`` from the validation sample + training args
    and wrap everything into a ForwardBatch for the validation pipeline.

    Note: ``sampling_param`` is mutated in place and reused across prompts.
    """
    sampling_param.prompt = validation_batch['prompt']
    sampling_param.height = training_args.num_height
    sampling_param.width = training_args.num_width
    sampling_param.num_inference_steps = num_inference_steps
    sampling_param.data_type = "video"
    if training_args.validation_guidance_scale:
        sampling_param.guidance_scale = float(training_args.validation_guidance_scale)
    assert self.seed is not None
    sampling_param.seed = self.seed

    # NOTE(review): n_tokens is computed from the *current*
    # sampling_param.num_frames, but num_frames is overwritten from
    # num_latent_t just below — confirm the two are meant to differ.
    latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8]
    n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
    temporal_compression_factor = training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
    num_frames = (training_args.num_latent_t - 1) * temporal_compression_factor + 1
    sampling_param.num_frames = num_frames
    batch = ForwardBatch(
        **shallow_asdict(sampling_param),
        latents=None,
        generator=self.validation_random_generator,
        n_tokens=n_tokens,
        eta=0.0,
        VSA_sparsity=training_args.VSA_sparsity,
    )

    return batch
795
+
796
@torch.no_grad()
def _log_validation(self, transformer, training_args, global_step) -> None:
    """
    Generate a validation video and log it to the configured tracker to check the quality during training.

    Each SP-group leader renders its prompts; leaders then forward their
    videos/captions/audio to global rank 0, which writes mp4 files (muxing
    audio when present) and logs them as tracker artifacts.
    """
    training_args.inference_mode = True
    training_args.dit_cpu_offload = False
    if not training_args.log_validation:
        return
    if self.validation_pipeline is None:
        raise ValueError("Validation pipeline is not set")

    logger.info("Starting validation")

    # Create sampling parameters if not provided
    sampling_param = SamplingParam.from_pretrained(training_args.model_path)

    # Prepare validation prompts
    logger.info('rank: %s: fastvideo_args.validation_dataset_file: %s',
                self.global_rank,
                training_args.validation_dataset_file,
                local_main_process_only=False)
    validation_dataset = ValidationDataset(training_args.validation_dataset_file)
    validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0)

    self.transformer.eval()
    if getattr(self, "transformer_2", None) is not None:
        self.transformer_2.eval()

    # "validation_sampling_steps" is a comma-separated list; keep positives.
    validation_steps = training_args.validation_sampling_steps.split(",")
    validation_steps = [int(step) for step in validation_steps]
    validation_steps = [step for step in validation_steps if step > 0]
    # Log validation results for this step
    world_group = get_world_group()
    num_sp_groups = world_group.world_size // self.sp_group.world_size
    # Opt-in env toggle: each rank renders only its first prompt.
    one_prompt_per_rank = os.environ.get(
        "FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK",
        "",
    ).lower() in {"1", "true", "yes", "on"}

    # Process each validation prompt for each validation step
    for num_inference_steps in validation_steps:
        logger.info("rank: %s: num_inference_steps: %s",
                    self.global_rank,
                    num_inference_steps,
                    local_main_process_only=False)
        step_videos: list[np.ndarray] = []
        step_captions: list[str] = []

        step_audio: list[np.ndarray | None] = []
        step_sample_rates: list[int | None] = []

        for prompt_idx, validation_batch in enumerate(validation_dataloader):
            if one_prompt_per_rank and prompt_idx > 0:
                continue

            batch = self._prepare_validation_batch(sampling_param, training_args, validation_batch,
                                                   num_inference_steps)
            logger.info("rank: %s: rank_in_sp_group: %s, batch.prompt: %s",
                        self.global_rank,
                        self.rank_in_sp_group,
                        batch.prompt,
                        local_main_process_only=False)

            assert batch.prompt is not None and isinstance(batch.prompt, str)
            step_captions.append(batch.prompt)

            # Run validation inference
            output_batch = self.validation_pipeline.forward(batch, training_args)
            samples = output_batch.output.cpu()

            # Capture audio if available
            audio = output_batch.extra.get("audio")
            sample_rate = output_batch.extra.get("audio_sample_rate")

            if audio is not None and torch.is_tensor(audio):
                audio = audio.detach().cpu().float().numpy()

            step_audio.append(audio)
            step_sample_rates.append(sample_rate)

            # Only SP-group leaders materialize frames.
            if self.rank_in_sp_group != 0:
                continue

            # Process outputs
            video = rearrange(samples, "b c t h w -> t b c h w")
            frames = []
            for x in video:
                x = torchvision.utils.make_grid(x, nrow=6)
                x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
                frames.append((x * 255).numpy().astype(np.uint8))
            step_videos.append(frames)

        # Only sp_group leaders (rank_in_sp_group == 0) need to send their
        # results to global rank 0
        if self.rank_in_sp_group == 0 and self.global_rank == 0:
            # Global rank 0 collects results from all sp_group leaders
            all_videos = step_videos  # Start with own results
            all_captions = step_captions
            all_audios = step_audio
            all_sample_rates = step_sample_rates

            # Receive from other sp_group leaders
            for sp_group_idx in range(1, num_sp_groups):
                src_rank = sp_group_idx * self.sp_world_size  # Global rank of other sp_group leaders
                # Recv order must mirror the send order in the elif below.
                recv_videos = world_group.recv_object(src=src_rank)
                recv_captions = world_group.recv_object(src=src_rank)
                recv_audios = world_group.recv_object(src=src_rank)
                recv_sample_rates = world_group.recv_object(src=src_rank)

                all_videos.extend(recv_videos)
                all_captions.extend(recv_captions)
                all_audios.extend(recv_audios)
                all_sample_rates.extend(recv_sample_rates)

            video_filenames = []
            for i, (video, caption, audio, sample_rate) in enumerate(
                    zip(all_videos, all_captions, all_audios, all_sample_rates, strict=True)):
                os.makedirs(training_args.output_dir, exist_ok=True)
                filename = os.path.join(
                    training_args.output_dir,
                    f"validation_step_{global_step}_inference_steps_{num_inference_steps}_video_{i}.mp4")
                imageio.mimsave(filename, video, fps=sampling_param.fps)
                # Mux audio if available
                if (audio is not None and sample_rate is not None and not self._mux_audio(
                        filename,
                        audio,
                        sample_rate,
                )):
                    logger.warning("Audio mux failed for validation video %s; saved video without audio.", filename)
                video_filenames.append(filename)

            artifacts = []
            for filename, caption in zip(video_filenames, all_captions, strict=True):
                video_artifact = self.tracker.video(filename, caption=caption)
                if video_artifact is not None:
                    artifacts.append(video_artifact)
            if artifacts:
                logs = {f"validation_videos_{num_inference_steps}_steps": artifacts}
                self.tracker.log_artifacts(logs, global_step)
        elif self.rank_in_sp_group == 0:
            # Other sp_group leaders send their results to global rank 0
            world_group.send_object(step_videos, dst=0)
            world_group.send_object(step_captions, dst=0)
            world_group.send_object(step_audio, dst=0)
            world_group.send_object(step_sample_rates, dst=0)

        world_group.barrier()

    # Re-enable gradients for training
    training_args.inference_mode = False
    self.transformer.train()
    if getattr(self, "transformer_2", None) is not None:
        self.transformer_2.train()
950
+
951
@staticmethod
def _mux_audio(
    video_path: str,
    audio: torch.Tensor | np.ndarray,
    sample_rate: int,
) -> bool:
    """Mux audio into video using PyAV.

    Args:
        video_path: Existing video file; replaced in place on success.
        audio: Float waveform in [-1, 1], either 1-D (samples,) or 2-D.
            A 2-D array with a short first axis (<= 8) is treated as
            (channels, samples), otherwise as (samples, channels).
        sample_rate: Audio sample rate in Hz.

    Returns:
        True when audio was muxed into the video, False otherwise (PyAV
        missing, unexpected audio shape, or any encoding error).
    """
    try:
        import av
    except ImportError:
        logger.warning("PyAV not installed; cannot mux audio. "
                       "Install with: pip install av")
        return False

    if torch.is_tensor(audio):
        audio_np = audio.detach().cpu().float().numpy()
    else:
        audio_np = np.asarray(audio, dtype=np.float32)

    # Normalize layout to (samples, channels).
    if audio_np.ndim == 1:
        audio_np = audio_np[:, None]
    elif audio_np.ndim == 2:
        if audio_np.shape[0] <= 8 and audio_np.shape[1] > audio_np.shape[0]:
            audio_np = audio_np.T
    else:
        logger.warning("Unexpected audio shape %s; skipping mux.", audio_np.shape)
        return False

    audio_np = np.clip(audio_np, -1.0, 1.0)
    audio_int16 = (audio_np * 32767.0).astype(np.int16)
    num_channels = audio_int16.shape[1]
    layout = "stereo" if num_channels == 2 else "mono"

    try:
        import wave
        with tempfile.TemporaryDirectory() as tmpdir:
            out_path = os.path.join(tmpdir, "muxed.mp4")
            wav_path = os.path.join(tmpdir, "audio.wav")

            # Write audio to WAV file
            with wave.open(wav_path, "wb") as wav_file:
                wav_file.setnchannels(num_channels)
                wav_file.setsampwidth(2)
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(audio_int16.tobytes())

            # BUGFIX: open all containers as context managers so they are
            # closed even when encoding raises part-way; the previous
            # version leaked the open handles on the error path.
            with av.open(video_path) as input_video, \
                    av.open(wav_path) as input_audio, \
                    av.open(out_path, mode="w") as output:
                # Add video stream (copy codec from input)
                in_video_stream = input_video.streams.video[0]
                out_video_stream = output.add_stream(
                    codec_name=in_video_stream.codec_context.name,
                    rate=in_video_stream.average_rate,
                )
                out_video_stream.width = in_video_stream.width
                out_video_stream.height = in_video_stream.height
                out_video_stream.pix_fmt = in_video_stream.pix_fmt

                # Add audio stream (AAC)
                out_audio_stream = output.add_stream("aac", rate=sample_rate)
                out_audio_stream.layout = layout

                # Remux video (decode and re-encode to be safe)
                for frame in input_video.decode(video=0):
                    for packet in out_video_stream.encode(frame):
                        output.mux(packet)
                for packet in out_video_stream.encode():
                    output.mux(packet)

                # Encode audio
                for frame in input_audio.decode(audio=0):
                    frame.pts = None  # Let encoder assign PTS
                    for packet in out_audio_stream.encode(frame):
                        output.mux(packet)
                for packet in out_audio_stream.encode():
                    output.mux(packet)

            # Containers are flushed/closed here; replace the original file.
            shutil.move(out_path, video_path)
            return True
    except Exception as e:
        logger.warning("Audio mux failed: %s", e)
        return False
1041
+
1042
def visualize_intermediate_latents(self, training_batch: TrainingBatch, training_args: TrainingArgs, step: int):
    """Add visualization data to tracker logging and save frames to disk."""
    # Intentionally unimplemented here: pipelines that enable
    # log_visualization must override this hook in a subclass.
    raise NotImplementedError("Visualize intermediate latents is not implemented for training pipeline")
standalone_inference/overlay_files/fastvideo/training/wan_training_pipeline.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import sys
3
+ from copy import deepcopy
4
+
5
+ from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
6
+ from fastvideo.logger import init_logger
7
+ from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (FlowUniPCMultistepScheduler)
8
+ from fastvideo.pipelines.basic.wan.wan_pipeline import WanPipeline
9
+ from fastvideo.training.training_pipeline import TrainingPipeline
10
+ from fastvideo.utils import is_vsa_available
11
+
12
# Probing for the optional VSA kernels can raise on machines without the
# dependency; treat any failure as "not available".
try:
    vsa_available = is_vsa_available()
except Exception:
    vsa_available = False

logger = init_logger(__name__)
18
+
19
+
20
class WanTrainingPipeline(TrainingPipeline):
    """
    A training pipeline for Wan.
    """
    _required_config_modules = ["scheduler", "transformer", "vae"]

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        """Install the UniPC flow-matching scheduler with the configured shift."""
        shift = fastvideo_args.pipeline_config.flow_shift
        self.modules["scheduler"] = FlowUniPCMultistepScheduler(shift=shift)

    def create_training_stages(self, training_args: TrainingArgs):
        """
        May be used in future refactors.
        """
        pass

    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        """Build an inference-mode WanPipeline that reuses this pipeline's transformer."""
        logger.info("Initializing validation pipeline...")

        # Run validation on a copy of the args so training settings stay intact.
        inference_args = deepcopy(training_args)
        inference_args.inference_mode = True

        self.validation_pipeline = WanPipeline.from_pretrained(
            training_args.model_path,
            args=inference_args,  # type: ignore
            inference_mode=True,
            loaded_modules={
                "transformer": self.get_module("transformer"),
            },
            tp_size=training_args.tp_size,
            sp_size=training_args.sp_size,
            num_gpus=training_args.num_gpus,
            pin_cpu_memory=training_args.pin_cpu_memory,
            dit_cpu_offload=True)
54
+
55
+
56
def main(args) -> None:
    """Build the Wan training pipeline from parsed CLI args and run training.

    Args:
        args: Parsed argument namespace; must provide
            ``pretrained_model_name_or_path`` plus the TrainingArgs fields.
    """
    logger.info("Starting training pipeline...")

    pipeline = WanTrainingPipeline.from_pretrained(args.pretrained_model_name_or_path, args=args)
    # Removed dead code: the previous `args = pipeline.training_args`
    # reassignment was never read afterwards.
    pipeline.train()
    logger.info("Training pipeline done")
63
+
64
+
65
+ if __name__ == "__main__":
66
+ argv = sys.argv
67
+ from fastvideo.fastvideo_args import TrainingArgs
68
+ from fastvideo.utils import FlexibleArgumentParser
69
+ parser = FlexibleArgumentParser()
70
+ parser = TrainingArgs.add_cli_args(parser)
71
+ parser = FastVideoArgs.add_cli_args(parser)
72
+ args = parser.parse_args()
73
+ args.dit_cpu_offload = False
74
+ main(args)
standalone_inference/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Install FastVideo itself from the upstream project or from your local checkout.
2
+ # This file only lists the extra Python packages directly used by the helper.
3
+ huggingface_hub
4
+ safetensors
5
+ triton
standalone_inference/run.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch the standalone sparse-FP4 inference helper against a FastVideo checkout.

set -euo pipefail

# Directory containing this script (the standalone_inference bundle).
BUNDLE_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
FASTVIDEO_ROOT="${FASTVIDEO_ROOT:-}"

if [[ -z "${FASTVIDEO_ROOT}" ]]; then
    echo "FASTVIDEO_ROOT is not set."
    echo "Set it to a FastVideo source checkout or installed package root, for example:"
    echo "  FASTVIDEO_ROOT=/path/to/FastVideo bash standalone_inference/run.sh"
    exit 1
fi

# Copy the bundled overlay files into the FastVideo tree.
python "${BUNDLE_ROOT}/install_overlay.py" --fastvideo-root "${FASTVIDEO_ROOT}"

# Make the bundled kernels importable and select the sparse-FP4 backend.
export PYTHONPATH="${FASTVIDEO_ROOT}/fastvideo-kernel/python:${FASTVIDEO_ROOT}/fastvideo-kernel:${PYTHONPATH:-}"
export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN
export FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1

# Run from the checkout root so relative paths resolve; forward all CLI args.
cd "${FASTVIDEO_ROOT}"
python "${BUNDLE_ROOT}/run_inference.py" "$@"
standalone_inference/run_inference.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Run Wan T2V inference with the sparse FP4 checkpoint-700 transformer."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import os
8
+ from pathlib import Path
9
+
10
+
11
# Default prompt/negative prompt used when the CLI does not override them.
DEFAULT_PROMPT = (
    "In the video, a woman is elegantly showcasing her earrings, bringing "
    "attention to their intricate design with a gentle touch of her fingers. "
    "She is bathed in ambient purple and pink lighting, which casts a soft "
    "glow on her delicate features and enhances the vivid tones of her lipstick "
    "and eye makeup. Her hair is styled to frame her face smoothly, emphasizing "
    "the contours of her jawline and cheekbones. The background features a "
    "blurred neon light, adding an artistic and modern touch to the overall "
    "aesthetic."
)

DEFAULT_NEGATIVE_PROMPT = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, "
    "works, paintings, images, static, overall gray, worst quality, low quality, "
    "JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn "
    "hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused "
    "fingers, still picture, messy background, three legs, many people in the "
    "background, walking backwards"
)
30
+
31
+
32
+ def _resolve_weights(repo_id: str, weights: str | None, local_dir: str) -> str:
33
+ if weights:
34
+ path = Path(weights).expanduser()
35
+ if path.exists():
36
+ return str(path.resolve())
37
+ raise FileNotFoundError(f"--weights does not exist: {path}")
38
+
39
+ from huggingface_hub import hf_hub_download
40
+
41
+ path = hf_hub_download(
42
+ repo_id=repo_id,
43
+ filename="transformer/diffusion_pytorch_model.safetensors",
44
+ local_dir=local_dir,
45
+ repo_type="model",
46
+ )
47
+ return str(Path(path).resolve())
48
+
49
+
50
def main() -> int:
    """CLI entry point: resolve checkpoint weights, build a VideoGenerator,
    and render a single video with the sparse-FP4 attention backend.

    Returns:
        Process exit code (0 on success).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", default="yitongl/sparse_quant_exp")
    parser.add_argument(
        "--model-path",
        default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        help="Base Wan Diffusers model repo/path.",
    )
    parser.add_argument("--weights", default=None)
    parser.add_argument(
        "--local-dir",
        default="checkpoints/hf_download/sparse_quant_exp",
        help="Local Hugging Face download directory for the uploaded weights.",
    )
    parser.add_argument("--prompt", default=DEFAULT_PROMPT)
    parser.add_argument("--negative-prompt", default=DEFAULT_NEGATIVE_PROMPT)
    parser.add_argument("--output-path", default="outputs/sfp4_checkpoint_700")
    parser.add_argument("--height", type=int, default=448)
    parser.add_argument("--width", type=int, default=832)
    parser.add_argument("--num-frames", type=int, default=77)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--fps", type=int, default=16)
    parser.add_argument("--guidance-scale", type=float, default=5.0)
    parser.add_argument("--flow-shift", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=1000)
    parser.add_argument("--vsa-sparsity", type=float, default=0.9)
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--sp-size", type=int, default=1)
    parser.add_argument("--tp-size", type=int, default=1)
    # BUGFIX: the old `action="store_true", default=True` made this flag a
    # no-op — it was always True and could never be disabled from the CLI.
    # BooleanOptionalAction keeps `--text-encoder-cpu-offload` working and
    # adds `--no-text-encoder-cpu-offload` to turn offload off.
    parser.add_argument("--text-encoder-cpu-offload",
                        action=argparse.BooleanOptionalAction,
                        default=True)
    parser.add_argument("--pin-cpu-memory", action="store_true", default=False)
    args = parser.parse_args()

    # Select the sparse-FP4 backend unless the caller already chose one.
    os.environ.setdefault("FASTVIDEO_ATTENTION_BACKEND", "SPARSE_FP4_OURS_P_ATTN")
    os.environ.setdefault("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")

    weights_path = _resolve_weights(args.repo_id, args.weights, args.local_dir)

    # Import after the environment is configured so the backend choice sticks.
    from fastvideo import VideoGenerator

    generator = VideoGenerator.from_pretrained(
        model_path=args.model_path,
        num_gpus=args.num_gpus,
        sp_size=args.sp_size,
        tp_size=args.tp_size,
        init_weights_from_safetensors=weights_path,
        dit_cpu_offload=False,
        vae_cpu_offload=False,
        text_encoder_cpu_offload=args.text_encoder_cpu_offload,
        pin_cpu_memory=args.pin_cpu_memory,
        flow_shift=args.flow_shift,
        VSA_sparsity=args.vsa_sparsity,
    )

    result = generator.generate_video(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        output_path=args.output_path,
        save_video=True,
        return_frames=False,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        num_inference_steps=args.num_inference_steps,
        fps=args.fps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
    )
    print(result)
    return 0
120
+
121
+
122
+ if __name__ == "__main__":
123
+ raise SystemExit(main())
standalone_inference/training_attention_settings.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "run_name": "sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive",
3
+ "checkpoint": "checkpoint-700",
4
+ "training_method": "legacy_sft_wan_training_pipeline",
5
+ "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
6
+ "init_weights_from_safetensors": "checkpoints/init/sfp4_v4_sparse06_hpo_on_ours_p_1n_interactive_v2_ckpt2050/transformer/diffusion_pytorch_model.safetensors",
7
+ "environment": {
8
+ "FASTVIDEO_ATTENTION_BACKEND": "SPARSE_FP4_OURS_P_ATTN",
9
+ "FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O": "1",
10
+ "FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK": "1",
11
+ "WANDB_MODE": "online",
12
+ "WANDB_RESUME": "allow"
13
+ },
14
+ "vsa_schedule": {
15
+ "VSA_SPARSITY": 0.9,
16
+ "VSA_INIT_SPARSITY": 0.9,
17
+ "VSA_WARMUP_STEPS": 0,
18
+ "VSA_DECAY_RATE": 0.03,
19
+ "VSA_DECAY_INTERVAL_STEPS": 50,
20
+ "effective_sparsity_from_step_0": 0.9
21
+ },
22
+ "attention_semantics": {
23
+ "selected_backend": "SPARSE_FP4_OURS_P_ATTN",
24
+ "self_attention": {
25
+ "backend_path": "fastvideo/attention/backends/sparse_fp4_ours_p_attn.py",
26
+ "kernel_path": "fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py",
27
+ "tile_size_video": [4, 4, 4],
28
+ "tile_tokens": 64,
29
+ "qkv_quantization": "FP4 fake quantization with STE, no q/k mean subtraction in quantization",
30
+ "block_selection": "top-k blocks from q_c @ k_c tile-mean scores",
31
+ "p_quantization": "group-local exp2(qk - group_max) FP4 fake quantization; compensation multiplies exp2(group_max - running_row_m)",
32
+ "dropped_tile_handling": "tile-level q_mean/k_mean score and mean_v compensation"
33
+ },
34
+ "cross_attention": {
35
+ "backend": "dense_sdpa",
36
+ "reason": "sparse_fp4_ours_p_attn.py treats query_length != key_length as cross attention and returns _dense_sdpa_blhd",
37
+ "quantized": false,
38
+ "sparse": false
39
+ },
40
+ "force_dense": {
41
+ "backend": "dense_sdpa",
42
+ "used_for": "teacher or explicitly forced dense paths, not the normal SFT student self-attention path"
43
+ }
44
+ },
45
+ "validation_and_checkpointing": {
46
+ "save_steps": 50,
47
+ "eval_steps": 50,
48
+ "validation_sampling_steps": 50,
49
+ "validation_guidance_scale": 5.0,
50
+ "checkpoints_total_limit": 5,
51
+ "flow_shift": 1.0
52
+ },
53
+ "training_shape": {
54
+ "num_latent_t": 20,
55
+ "num_frames": 77,
56
+ "height": 448,
57
+ "width": 832,
58
+ "batch_size_per_gpu": 1,
59
+ "sp_size": 1,
60
+ "tp_size": 1
61
+ }
62
+ }