Add text-to-video pipeline tag and project links
#1
by
nielsr
HF Staff
- opened
README.md
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
base_model:
|
| 4 |
- Wan-AI/Wan2.1-T2V-1.3B
|
|
|
|
|
|
|
| 5 |
---
|
|
|
|
| 6 |
<div align="center">
|
| 7 |
|
| 8 |
# Causal Forcing
|
|
@@ -25,17 +27,17 @@ base_model:
|
|
| 25 |
|
| 26 |
</div>
|
| 27 |
</p>
|
| 28 |
-
<h3 align="center"><a href="https://arxiv.org/abs/2602.02214">Paper</a> | <a href="https://thu-ml.github.io/CausalForcing.github.io">Website</a> | <a href="https://huggingface.co/zhuhz22/Causal-Forcing/tree/main">Models</a> | <a href="assets/wechat.jpg">WeChat</a></h3>
|
| 29 |
</p>
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-----
|
|
|
|
| 34 |
Causal Forcing significantly outperforms Self Forcing in **both visual quality and motion dynamics**, while keeping **the same training budget and inference efficiency**—enabling real-time, streaming video generation on a single RTX 4090.
|
| 35 |
|
| 36 |
-----
|
| 37 |
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
## Quick Start
|
| 41 |
|
|
@@ -82,8 +84,6 @@ python inference.py \
|
|
| 82 |
--data_path prompts/demos.txt
|
| 83 |
```
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
## Training
|
| 88 |
|
| 89 |
<details>
|
|
@@ -94,8 +94,6 @@ First download the dataset (we provide a 6K toy dataset here):
|
|
| 94 |
hf download zhuhz22/Causal-Forcing-data --local-dir dataset
|
| 95 |
python utils/merge_and_get_clean.py
|
| 96 |
```
|
| 97 |
-
> If the download gets stuck, Ctrl^C and then resume it.
|
| 98 |
-
|
| 99 |
|
| 100 |
Then train the AR-diffusion model:
|
| 101 |
- Framewise:
|
|
@@ -118,12 +116,8 @@ Then train the AR-diffusion model:
|
|
| 118 |
--logdir logs/ar_diffusion_chunkwise
|
| 119 |
```
|
| 120 |
|
| 121 |
-
> We recommend training no less than 2K steps, and more steps (e.g., 5~10K) will lead to better performance.
|
| 122 |
-
|
| 123 |
-
|
| 124 |
</details>
|
| 125 |
|
| 126 |
-
|
| 127 |
<details>
|
| 128 |
<summary> Stage 2: Causal ODE Initialization (Can skip by using our pretrained checkpoints. Click to expand.)</summary>
|
| 129 |
|
|
@@ -133,40 +127,7 @@ hf download zhuhz22/Causal-Forcing framewise/ar_diffusion.pt --local-dir checkpo
|
|
| 133 |
hf download zhuhz22/Causal-Forcing chunkwise/ar_diffusion.pt --local-dir checkpoints
|
| 134 |
```
|
| 135 |
|
| 136 |
-
In this stage,
|
| 137 |
-
```bash
|
| 138 |
-
# for the frame-wise model
|
| 139 |
-
torchrun --nproc_per_node=8 \
|
| 140 |
-
get_causal_ode_data_framewise.py \
|
| 141 |
-
--generator_ckpt checkpoints/framewise/ar_diffusion.pt \
|
| 142 |
-
--rawdata_path dataset/clean_data \
|
| 143 |
-
--output_folder dataset/ODE6KCausal_framewise_latents
|
| 144 |
-
|
| 145 |
-
python utils/create_lmdb_iterative.py \
|
| 146 |
-
--data_path dataset/ODE6KCausal_framewise_latents \
|
| 147 |
-
--lmdb_path dataset/ODE6KCausal_framewise
|
| 148 |
-
|
| 149 |
-
# for the chunk-wise model
|
| 150 |
-
torchrun --nproc_per_node=8 \
|
| 151 |
-
get_causal_ode_data_chunkwise.py \
|
| 152 |
-
--generator_ckpt checkpoints/chunkwise/ar_diffusion.pt \
|
| 153 |
-
--rawdata_path dataset/clean_data \
|
| 154 |
-
--output_folder dataset/ODE6KCausal_chunkwise_latents
|
| 155 |
-
|
| 156 |
-
python utils/create_lmdb_iterative.py \
|
| 157 |
-
--data_path dataset/ODE6KCausal_chunkwise_latents \
|
| 158 |
-
--lmdb_path dataset/ODE6KCausal_chunkwise
|
| 159 |
-
```
|
| 160 |
-
|
| 161 |
-
Or you can also directly download our prepared dataset (~300G):
|
| 162 |
-
```bash
|
| 163 |
-
hf download zhuhz22/Causal-Forcing-data --local-dir dataset
|
| 164 |
-
python utils/merge_lmdb.py
|
| 165 |
-
```
|
| 166 |
-
> If the download gets stuck, Ctrl^C and then resume it.
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
And then train ODE initialization models:
|
| 170 |
- Frame-wise:
|
| 171 |
```bash
|
| 172 |
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
|
|
@@ -176,37 +137,12 @@ And then train ODE initialization models:
|
|
| 176 |
--config_path configs/causal_ode_framewise.yaml \
|
| 177 |
--logdir logs/causal_ode_framewise
|
| 178 |
```
|
| 179 |
-
- Chunk-wise:
|
| 180 |
-
```bash
|
| 181 |
-
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
|
| 182 |
-
--rdzv_backend=c10d \
|
| 183 |
-
--rdzv_endpoint $MASTER_ADDR \
|
| 184 |
-
train.py \
|
| 185 |
-
--config_path configs/causal_ode_chunkwise.yaml \
|
| 186 |
-
--logdir logs/causal_ode_chunkwise
|
| 187 |
-
```
|
| 188 |
-
|
| 189 |
-
> We recommend training no less than 1K steps, and more steps (e.g., 5~10K) will lead to better performance.
|
| 190 |
-
|
| 191 |
</details>
|
| 192 |
|
| 193 |
### Stage 3: DMD
|
| 194 |
|
| 195 |
> This stage is compatible with Self Forcing training, so you can migrate seamlessly by using our configs and checkpoints.
|
| 196 |
|
| 197 |
-
|
| 198 |
-
First download the dataset:
|
| 199 |
-
```bash
|
| 200 |
-
hf download gdhe17/Self-Forcing vidprom_filtered_extended.txt --local-dir prompts
|
| 201 |
-
```
|
| 202 |
-
If you have skipped Stage 2, you need to download the pretrained checkpoints:
|
| 203 |
-
```bash
|
| 204 |
-
hf download zhuhz22/Causal-Forcing framewise/causal_ode.pt --local-dir checkpoints
|
| 205 |
-
hf download zhuhz22/Causal-Forcing chunkwise/causal_ode.pt --local-dir checkpoints
|
| 206 |
-
```
|
| 207 |
-
|
| 208 |
-
And then train DMD models:
|
| 209 |
-
|
| 210 |
- Frame-wise model:
|
| 211 |
```bash
|
| 212 |
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
|
|
@@ -216,28 +152,12 @@ And then train DMD models:
|
|
| 216 |
--config_path configs/causal_forcing_dmd_framewise.yaml \
|
| 217 |
--logdir logs/causal_forcing_dmd_framewise
|
| 218 |
```
|
| 219 |
-
> We recommend training 500 steps. More than 1K steps will reduce dynamic degree.
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
- Chunk-wise model:
|
| 223 |
-
```bash
|
| 224 |
-
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
|
| 225 |
-
--rdzv_backend=c10d \
|
| 226 |
-
--rdzv_endpoint $MASTER_ADDR \
|
| 227 |
-
train.py \
|
| 228 |
-
--config_path configs/causal_forcing_dmd_chunkwise.yaml \
|
| 229 |
-
--logdir logs/causal_forcing_dmd_chunkwise
|
| 230 |
-
```
|
| 231 |
-
> We recommend training 100~200 steps. More than 1K steps will reduce dynamic degree.
|
| 232 |
-
|
| 233 |
-
Such models are the final models used to generate videos.
|
| 234 |
|
| 235 |
## Acknowledgements
|
| 236 |
This codebase is built on top of the open-source implementation of [CausVid](https://github.com/tianweiy/CausVid), [Self Forcing](https://github.com/guandeh17/Self-Forcing) and the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
|
| 237 |
|
| 238 |
## References
|
| 239 |
-
|
| 240 |
-
```
|
| 241 |
@misc{zhu2026causalforcingautoregressivediffusion,
|
| 242 |
title={Causal Forcing: Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation},
|
| 243 |
author={Hongzhou Zhu and Min Zhao and Guande He and Hang Su and Chongxuan Li and Jun Zhu},
|
|
|
|
| 1 |
---
|
|
|
|
| 2 |
base_model:
|
| 3 |
- Wan-AI/Wan2.1-T2V-1.3B
|
| 4 |
+
license: apache-2.0
|
| 5 |
+
pipeline_tag: text-to-video
|
| 6 |
---
|
| 7 |
+
|
| 8 |
<div align="center">
|
| 9 |
|
| 10 |
# Causal Forcing
|
|
|
|
| 27 |
|
| 28 |
</div>
|
| 29 |
</p>
|
| 30 |
+
<h3 align="center"><a href="https://arxiv.org/abs/2602.02214">Paper</a> | <a href="https://thu-ml.github.io/CausalForcing.github.io">Website</a> | <a href="https://github.com/thu-ml/Causal-Forcing">Code</a> | <a href="https://huggingface.co/zhuhz22/Causal-Forcing/tree/main">Models</a> | <a href="assets/wechat.jpg">WeChat</a></h3>
|
| 31 |
</p>
|
| 32 |
|
|
|
|
|
|
|
| 33 |
-----
|
| 34 |
+
|
| 35 |
Causal Forcing significantly outperforms Self Forcing in **both visual quality and motion dynamics**, while keeping **the same training budget and inference efficiency**—enabling real-time, streaming video generation on a single RTX 4090.
|
| 36 |
|
| 37 |
-----
|
| 38 |
|
| 39 |
+
## Abstract
|
| 40 |
+
To achieve real-time interactive video generation, current methods distill pretrained bidirectional video diffusion models into few-step autoregressive (AR) models, facing an architectural gap when full attention is replaced by causal attention. We propose **Causal Forcing** that uses an AR teacher for ODE initialization, thereby bridging the architectural gap. Empirical results show that our method outperforms all baselines across all metrics, surpassing the SOTA Self Forcing by 19.3% in Dynamic Degree, 8.7% in VisionReward, and 16.7% in Instruction Following.
|
| 41 |
|
| 42 |
## Quick Start
|
| 43 |
|
|
|
|
| 84 |
--data_path prompts/demos.txt
|
| 85 |
```
|
| 86 |
|
|
|
|
|
|
|
| 87 |
## Training
|
| 88 |
|
| 89 |
<details>
|
|
|
|
| 94 |
hf download zhuhz22/Causal-Forcing-data --local-dir dataset
|
| 95 |
python utils/merge_and_get_clean.py
|
| 96 |
```
|
|
|
|
|
|
|
| 97 |
|
| 98 |
Then train the AR-diffusion model:
|
| 99 |
- Framewise:
|
|
|
|
| 116 |
--logdir logs/ar_diffusion_chunkwise
|
| 117 |
```
|
| 118 |
|
|
|
|
|
|
|
|
|
|
| 119 |
</details>
|
| 120 |
|
|
|
|
| 121 |
<details>
|
| 122 |
<summary> Stage 2: Causal ODE Initialization (Can skip by using our pretrained checkpoints. Click to expand.)</summary>
|
| 123 |
|
|
|
|
| 127 |
hf download zhuhz22/Causal-Forcing chunkwise/ar_diffusion.pt --local-dir checkpoints
|
| 128 |
```
|
| 129 |
|
| 130 |
+
In this stage, train ODE initialization models:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
- Frame-wise:
|
| 132 |
```bash
|
| 133 |
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
|
|
|
|
| 137 |
--config_path configs/causal_ode_framewise.yaml \
|
| 138 |
--logdir logs/causal_ode_framewise
|
| 139 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
</details>
|
| 141 |
|
| 142 |
### Stage 3: DMD
|
| 143 |
|
| 144 |
> This stage is compatible with Self Forcing training, so you can migrate seamlessly by using our configs and checkpoints.
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
- Frame-wise model:
|
| 147 |
```bash
|
| 148 |
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
|
|
|
|
| 152 |
--config_path configs/causal_forcing_dmd_framewise.yaml \
|
| 153 |
--logdir logs/causal_forcing_dmd_framewise
|
| 154 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
## Acknowledgements
|
| 157 |
This codebase is built on top of the open-source implementation of [CausVid](https://github.com/tianweiy/CausVid), [Self Forcing](https://github.com/guandeh17/Self-Forcing) and the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
|
| 158 |
|
| 159 |
## References
|
| 160 |
+
```bibtex
|
|
|
|
| 161 |
@misc{zhu2026causalforcingautoregressivediffusion,
|
| 162 |
title={Causal Forcing: Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation},
|
| 163 |
author={Hongzhou Zhu and Min Zhao and Guande He and Hang Su and Chongxuan Li and Jun Zhu},
|