metadata
base_model:
  - Wan-AI/Wan2.1-T2V-1.3B
license: apache-2.0
pipeline_tag: text-to-video

Causal Forcing

Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation

Tsinghua University & Shengshu & UT Austin

Paper | Website | Code | Models


Causal Forcing significantly outperforms Self Forcing in both visual quality and motion dynamics, while keeping the same training budget and inference efficiency—enabling real-time, streaming video generation on a single RTX 4090.


Abstract

To achieve real-time interactive video generation, current methods distill pretrained bidirectional video diffusion models into few-step autoregressive (AR) models, but they face an architectural gap when full attention is replaced by causal attention. We propose Causal Forcing, which uses an AR teacher for ODE initialization and thereby bridges this gap. Empirical results show that our method outperforms all baselines across all metrics, surpassing the SOTA Self Forcing by 19.3% in Dynamic Degree, 8.7% in VisionReward, and 16.7% in Instruction Following.

Quick Start

The inference environment is identical to that of Self Forcing, so you can migrate directly by using our configs and model.

Installation

conda create -n causal_forcing python=3.10 -y
conda activate causal_forcing
pip install -r requirements.txt
pip install git+https://github.com/openai/CLIP.git
pip install flash-attn --no-build-isolation
python setup.py develop

Download checkpoints

hf download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
hf download Wan-AI/Wan2.1-T2V-14B --local-dir wan_models/Wan2.1-T2V-14B
hf download zhuhz22/Causal-Forcing chunkwise/causal_forcing.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing framewise/causal_forcing.pt --local-dir checkpoints
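
After the downloads finish, a quick existence check can catch a mistyped path before inference. The sketch below is not part of the repository: `missing_paths` is a hypothetical helper, and the listed paths are simply the `--local-dir` targets and checkpoint files named in the commands above.

```shell
# Hypothetical helper (not in the repo): print every argument that does not exist on disk.
missing_paths() {
  for p in "$@"; do
    [ -e "$p" ] || printf '%s\n' "$p"
  done
}

# Paths taken from the download commands above.
missing="$(missing_paths \
  wan_models/Wan2.1-T2V-1.3B \
  wan_models/Wan2.1-T2V-14B \
  checkpoints/chunkwise/causal_forcing.pt \
  checkpoints/framewise/causal_forcing.pt)"

if [ -n "$missing" ]; then
  echo "Missing paths:"
  echo "$missing"
else
  echo "All expected paths are present."
fi
```

Run it from the repository root, since all paths are relative.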

CLI Inference

We open-source both the frame-wise and chunk-wise models; the former is a setting that Self Forcing has chosen not to release.

Frame-wise model (higher dynamic degree and more expressive):

python inference.py \
  --config_path configs/causal_forcing_dmd_framewise.yaml \
  --output_folder output/framewise \
  --checkpoint_path checkpoints/framewise/causal_forcing.pt \
  --data_path prompts/demos.txt \
  --use_ema
    # Note: this frame-wise config is not included in Self Forcing; if you use its framework, migrate this config as well.

Chunk-wise model (more stable):

python inference.py \
  --config_path configs/causal_forcing_dmd_chunkwise.yaml \
  --output_folder output/chunkwise \
  --checkpoint_path checkpoints/chunkwise/causal_forcing.pt \
  --data_path prompts/demos.txt
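
The two commands above differ only in the variant name and in the `--use_ema` flag, which appears only in the frame-wise command. As a minimal sketch, assuming you want to generate with both models in one go, a small wrapper can build each command; `inference_cmd` is a hypothetical helper, not something the repository provides, and it only prints the commands (a dry run).

```shell
# Hypothetical wrapper (not in the repo): print the inference command for one variant.
inference_cmd() {
  variant="$1"   # "framewise" or "chunkwise"
  case "$variant" in
    framewise) extra=" --use_ema" ;;   # only the frame-wise command above passes --use_ema
    chunkwise) extra="" ;;
    *) echo "unknown variant: $variant" >&2; return 1 ;;
  esac
  printf 'python inference.py --config_path configs/causal_forcing_dmd_%s.yaml --output_folder output/%s --checkpoint_path checkpoints/%s/causal_forcing.pt --data_path prompts/demos.txt%s\n' \
    "$variant" "$variant" "$variant" "$extra"
}

# Dry run: print both commands without executing them.
for v in framewise chunkwise; do
  inference_cmd "$v"
done
```

To actually run a variant, pipe the printed command to a shell, for example `inference_cmd chunkwise | sh`.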

Training

Stage 1: Autoregressive Diffusion Training (can be skipped by using our pretrained checkpoints)

First, download the dataset (we provide a 6K toy dataset) and merge it:

hf download zhuhz22/Causal-Forcing-data --local-dir dataset
python utils/merge_and_get_clean.py

Then train the AR-diffusion model:

  • Frame-wise:
    torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
      --rdzv_backend=c10d \
      --rdzv_endpoint $MASTER_ADDR \
      train.py \
      --config_path configs/ar_diffusion_tf_framewise.yaml \
      --logdir logs/ar_diffusion_framewise
    
  • Chunk-wise:
    torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
      --rdzv_backend=c10d \
      --rdzv_endpoint $MASTER_ADDR \
      train.py \
      --config_path configs/ar_diffusion_tf_chunkwise.yaml \
      --logdir logs/ar_diffusion_chunkwise
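
The training commands launch 8 nodes with 8 processes each, that is 64 worker processes in total (typically one per GPU), and they expect `MASTER_ADDR` to be set on every node. The following is a minimal sketch of preparing that variable; `rdzv_endpoint` is a hypothetical helper and `node0.example.com` is a placeholder host name, not a value from this repository. torchrun accepts the endpoint as `host[:port]` and falls back to port 29400 when none is given.

```shell
# Hypothetical helper (not in the repo): build a "host:port" rendezvous endpoint.
rdzv_endpoint() {
  host="$1"
  port="${2:-29400}"   # 29400 is torchrun's default rendezvous port
  printf '%s:%s\n' "$host" "$port"
}

# Values copied from the torchrun commands above.
NNODES=8
NPROC_PER_NODE=8
TOTAL_PROCS=$((NNODES * NPROC_PER_NODE))
echo "total worker processes: $TOTAL_PROCS"

# Placeholder host: replace with an address of the first node that all nodes can reach.
MASTER_ADDR="$(rdzv_endpoint node0.example.com)"
export MASTER_ADDR
echo "MASTER_ADDR=$MASTER_ADDR"
```

Every node must export the same `MASTER_ADDR` and run the same `torchrun` command for the rendezvous to complete.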
    
Stage 2: Causal ODE Initialization (can be skipped by using our pretrained checkpoints)

If you skipped Stage 1, download the pretrained AR-diffusion checkpoints first:

hf download zhuhz22/Causal-Forcing framewise/ar_diffusion.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/ar_diffusion.pt --local-dir checkpoints

In this stage, train the ODE initialization model:

  • Frame-wise:
    torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
      --rdzv_backend=c10d \
      --rdzv_endpoint $MASTER_ADDR \
      train.py \
      --config_path configs/causal_ode_framewise.yaml \
      --logdir logs/causal_ode_framewise
    

Stage 3: DMD

This stage is compatible with Self Forcing training, so you can migrate seamlessly by using our configs and checkpoints.

  • Frame-wise model:
    torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
      --rdzv_backend=c10d \
      --rdzv_endpoint $MASTER_ADDR \
      train.py \
      --config_path configs/causal_forcing_dmd_framewise.yaml \
      --logdir logs/causal_forcing_dmd_framewise
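
For orientation, the three frame-wise stages use one config each. The lookup below only restates the config paths from the commands above; `stage_config` is a hypothetical helper, and the sketch assumes nothing about how each stage locates the previous stage's checkpoint.

```shell
# Hypothetical helper (not in the repo): map a stage number to its frame-wise config.
stage_config() {
  case "$1" in
    1) echo "configs/ar_diffusion_tf_framewise.yaml" ;;      # AR diffusion training
    2) echo "configs/causal_ode_framewise.yaml" ;;           # causal ODE initialization
    3) echo "configs/causal_forcing_dmd_framewise.yaml" ;;   # DMD
    *) echo "unknown stage: $1" >&2; return 1 ;;
  esac
}

# Print the stage-to-config mapping in training order.
for stage in 1 2 3; do
  echo "stage $stage -> $(stage_config "$stage")"
done
```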
    

Acknowledgements

This codebase is built on top of the open-source implementations of CausVid and Self Forcing, and on the Wan2.1 repo.

References

@misc{zhu2026causalforcingautoregressivediffusion,
      title={Causal Forcing: Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation}, 
      author={Hongzhou Zhu and Min Zhao and Guande He and Hang Su and Chongxuan Li and Jun Zhu},
      year={2026},
      eprint={2602.02214},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2602.02214}, 
}