---
license: mit
tags:
- video-generation
- diffusion
- causal-video
- camera-control
- wan2.2
---
# UCPE Causal-Forcing Checkpoints
Wan2.2-TI2V-5B checkpoints with UCPE camera control, for the causal video-generation pipeline in
[github.com/weijielyu/RayStream_CF](https://github.com/weijielyu/RayStream_CF) (the `cf_ucpe` repo).
All checkpoints are at **704×1280 (720p), 81 frames @ 16 fps, TI2V** with UCPE
camera conditioning (`relray_absmap`, `attn_compress=8`, parallel `cam_self_attn`
branches at every DiT block).
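To make these numbers concrete, here is a small worked calculation of clip length, latent grid and DiT sequence length. It is a sketch that **assumes** the Wan2.2 TI2V-5B compression factors (VAE: 4× temporal, 16× spatial, first frame encoded on its own; DiT patchify 2×2), which are not stated in this README. It ignores the UCPE camera tokens and `attn_compress`.

```python
# Worked arithmetic for the 720p / 81-frame TI2V setting.
# ASSUMPTION: Wan2.2 TI2V-5B VAE = 4x temporal, 16x spatial compression,
# and the DiT patchifies each latent frame 2x2 (4x32x32 overall).
HEIGHT, WIDTH = 704, 1280
NUM_FRAMES, FPS = 81, 16
VAE_T, VAE_S, PATCH = 4, 16, 2


def latent_layout(height, width, num_frames):
    """Return (latent_frames, latent_h, latent_w, dit_tokens_per_latent_frame)."""
    # First frame is encoded alone, then every 4 frames map to one latent frame.
    latent_t = (num_frames - 1) // VAE_T + 1
    latent_h, latent_w = height // VAE_S, width // VAE_S
    tokens_per_frame = (latent_h // PATCH) * (latent_w // PATCH)
    return latent_t, latent_h, latent_w, tokens_per_frame


if __name__ == "__main__":
    t, h, w, tpf = latent_layout(HEIGHT, WIDTH, NUM_FRAMES)
    print(f"clip length: {NUM_FRAMES / FPS:.2f} s")  # 5.06 s
    print(f"latent grid: {t} x {h} x {w}")  # 21 x 44 x 80
    print(f"DiT tokens: {tpf} per latent frame, {t * tpf} total")  # 880, 18480
```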
## Repository layout
```
.
├── README.md # this file
├── wan22_bidirectional_ucpe/ # Wan2.2 bidirectional teacher (DeepSpeed ckpt, ~24 GB)
│ ├── checkpoint/
│ │ ├── mp_rank_00_model_states.pt # ← actual weights (21 GB)
│ │ └── bf16_zero_pp_rank_*.pt # optimizer shards (8 × 213 MB)
│ ├── latest
│ └── zero_to_fp32.py
├── ode_regression_wan21_sf/ # Stage-1: causal student after DF-style ODE regression
│ ├── checkpoint_model_000400/model.pt # 400 steps (~20 GB)
│ └── checkpoint_model_001000/model.pt # 1000 steps (~20 GB)
├── dmd_unfreeze_cam_wan21_sf/ # Stage-2 variant A: DMD with camera branch trainable (lr_cam=10x)
│ ├── checkpoint_model_000500/model.pt # 500 steps (~135 GB, full-resume bundle)
│ └── checkpoint_model_001000/model.pt # 1000 steps (~135 GB)
└── dmd_freeze_cam_wan21_sf/ # Stage-2 variant B: DMD with camera branch frozen
├── checkpoint_model_000500/model.pt # 500 steps (~132 GB)
└── checkpoint_model_001000/model.pt # 1000 steps (~141 GB)
```
Top-level keys in the `cf_ucpe` checkpoint files:

| Path | Top-level keys |
|---|---|
| `ode_regression_wan21_sf/*/model.pt` | `generator` |
| `dmd_*_wan21_sf/*/model.pt` | `generator`, `generator_ema`, `fake_score`, `generator_optimizer`, `critic_optimizer`, `step` |
For inference you only need `generator_ema` (DMD) or `generator` (ODE). Use
`scripts/extract_ema_ckpt.py` in the code repo to drop the other keys and shrink the files.
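If you want to see what that slimming step involves, the sketch below keeps a single weights entry from a `model.pt`. It is **not** `scripts/extract_ema_ckpt.py`. It assumes `model.pt` is a plain dict written by `torch.save` with the keys in the table above, and that keeping the original key name (`generator_ema` or `generator`) is enough for your loader. Check this against `test_ucpe_dmd.py` before relying on it.

```python
"""Sketch: keep only the weights needed for inference from a cf_ucpe model.pt.

Not the repo's scripts/extract_ema_ckpt.py; assumes model.pt is a plain dict
written by torch.save with the top-level keys listed in the table above.
"""


def select_inference_weights(ckpt):
    """Return (key, weights): generator_ema for DMD bundles, else generator (ODE)."""
    for key in ("generator_ema", "generator"):
        if key in ckpt:
            return key, ckpt[key]
    raise KeyError(f"no generator weights found; top-level keys: {sorted(ckpt)}")


def slim_checkpoint(src, dst):
    """Load src, keep one weights entry, save it to dst; returns the kept key."""
    import torch  # imported lazily so select_inference_weights() needs no torch

    # Reads the whole bundle into host RAM (~135 GB for the DMD checkpoints).
    # weights_only=False unpickles arbitrary objects: only use it on trusted files.
    ckpt = torch.load(src, map_location="cpu", weights_only=False)
    key, weights = select_inference_weights(ckpt)
    # Keep the original key so a loader that looks it up by name still finds it.
    torch.save({key: weights}, dst)
    return key
```

Example: `slim_checkpoint("dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/model.pt", "dmd_unfreeze_ema_001000.pt")`. The DMD key is preferred because, per the inference notes in this README, `--use_ema` reads `generator_ema`.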
The Wan2.2 bidirectional checkpoint uses the **DeepSpeed ZeRO** checkpoint layout. Code that loads
it (e.g. `UCPE/scripts/predict_one_sample.py`) reads the weights directly from
`checkpoint/mp_rank_00_model_states.pt`; the `bf16_zero_pp_rank_*` files are optimizer shards.
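The sketch below shows one way to load the teacher weights outside UCPE. It assumes the standard DeepSpeed engine conventions: the `latest` file holds the tag of the newest save (presumably `checkpoint` here), and the model `state_dict` is stored under the `"module"` key of `mp_rank_00_model_states.pt`. UCPE's own loader may also remap key prefixes, so treat the result as raw DeepSpeed weights.

```python
from pathlib import Path


def resolve_model_states(ckpt_dir, default_tag="checkpoint"):
    """Locate mp_rank_00_model_states.pt using DeepSpeed's `latest` tag file."""
    ckpt_dir = Path(ckpt_dir)
    latest = ckpt_dir / "latest"
    # `latest` names the sub-directory of the newest save; fall back to the
    # directory shown in this repo's layout if the file is missing.
    tag = latest.read_text().strip() if latest.exists() else default_tag
    return ckpt_dir / tag / "mp_rank_00_model_states.pt"


def load_teacher_state_dict(ckpt_dir):
    """Return the raw model state_dict from a DeepSpeed checkpoint directory."""
    import torch  # imported lazily so path resolution works without torch

    # ~21 GB read into host RAM; weights_only=False because DeepSpeed pickles
    # extra metadata alongside the tensors (trusted files only).
    states = torch.load(resolve_model_states(ckpt_dir), map_location="cpu", weights_only=False)
    return states["module"]  # DeepSpeed engines save the model under "module"
```

Example: `sd = load_teacher_state_dict("ucpe_checkpoints/wan22_bidirectional_ucpe")`.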
---
## Quick start: download
```bash
huggingface-cli download wlyu/ucpe_checkpoints --local-dir ./ucpe_checkpoints
```
Or pull a specific subfolder:
```bash
huggingface-cli download wlyu/ucpe_checkpoints \
--include 'dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/*' \
--local-dir ./ucpe_checkpoints
```
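The same selective download can be done from Python with `huggingface_hub.snapshot_download`. The helper names below (`allow_patterns`, `download`) are hypothetical. Skipping the teacher's `bf16_zero_pp_rank_*` optimizer shards assumes you only need the weights (inference), not a training resume. The DMD bundles keep their optimizer state inside `model.pt`, so the skip does not help there.

```python
REPO_ID = "wlyu/ucpe_checkpoints"


def allow_patterns(subdir, step=None):
    """Glob for a whole checkpoint directory, or for a single saved step in it."""
    if step is None:
        return [f"{subdir}/*"]
    return [f"{subdir}/checkpoint_model_{step:06d}/*"]


def download(subdir, step=None, local_dir="./ucpe_checkpoints", skip_optimizer_shards=True):
    """Fetch one checkpoint (or one step of it) and return the local folder path."""
    from huggingface_hub import snapshot_download  # pip install huggingface_hub

    # Only the DeepSpeed teacher has separate optimizer shard files.
    ignore = ["*bf16_zero_pp_rank_*"] if skip_optimizer_shards else None
    return snapshot_download(
        repo_id=REPO_ID,
        allow_patterns=allow_patterns(subdir, step),
        ignore_patterns=ignore,
        local_dir=local_dir,
    )
```

For example, `download("dmd_unfreeze_cam_wan21_sf", step=1000)` matches the `--include` command above, and `download("wan22_bidirectional_ucpe")` fetches the teacher without its optimizer shards.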
---
## Training (in `cf_ucpe`)
The full pipeline is **Wan2.2 bidirectional → ODE regression (causal student) → DMD distillation**.
### Stage 1 — ODE regression (DF mode, matches upstream Self-Forcing)
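Each training step runs a single causal forward pass with an independently sampled timestep per frame
block and no `clean_x` input. This diffusion-forcing (DF) mode is enabled with `use_df: true`, which
dispatches to `model.ode_regression_df.ODERegressionDF`.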
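The per-block timestep idea can be sketched as follows. This is an illustration of the sampling pattern only, not `ODERegressionDF`. The 21 latent frames, blocks of 3 frames, and the 4-step list `[1000, 750, 500, 250]` are **assumed** placeholder values; read the real ones from the YAML config.

```python
import random


def sample_blockwise_timesteps(num_latent_frames, frames_per_block, denoising_steps, rng):
    """Diffusion-forcing-style noise levels: one timestep per block, shared by its frames."""
    if num_latent_frames % frames_per_block:
        raise ValueError("latent frames must divide evenly into blocks")
    timesteps = []
    for _ in range(num_latent_frames // frames_per_block):
        t = rng.choice(denoising_steps)  # each block draws its own noise level
        timesteps.extend([t] * frames_per_block)
    return timesteps


# ASSUMED values: 21 latent frames in blocks of 3, hypothetical 4-step schedule.
example = sample_blockwise_timesteps(21, 3, [1000, 750, 500, 250], random.Random(0))
```

Because every block is noised to a different level within one forward pass, the student learns to denoise a block while conditioning on earlier blocks at arbitrary noise levels. In the teacher-forcing alternative, earlier blocks would be fed clean instead (the role `clean_x` plays).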
Run on each of 4 nodes (set `NODE_RANK=0..3`):
```bash
LOG_DIR=output/ucpe_training_720_v2/ode_regression_wan21_sf \
CONFIG=configs/ucpe_ode_regression_720_wan21_sf.yaml \
NODE_RANK=0 MASTER_PORT=36903 MASTER_ADDR=<node0-ip> \
bash scripts/run_ode_regression_720_multinode.sh
```
A checkpoint is saved every 200 steps; about 1000 steps in total is enough.
### Stage 2 — DMD distillation
Distills the causal student against the bidirectional teacher (Wan2.2 + UCPE).
Two variants:
- **`ucpe_causal_forcing_dmd_720_wan21_sf.yaml`**: camera branch trainable, with its learning
  rate scaled by `lr_cam_multiplier=10` (the default in `trainer/distillation.py`). Checkpoints: `dmd_unfreeze_cam_wan21_sf/`.
- **`ucpe_causal_forcing_dmd_720_wan21_sf_freeze.yaml`**: sets `freeze_camera_branch: true`; the
  camera branch still participates in the forward pass but receives no gradient. Checkpoints: `dmd_freeze_cam_wan21_sf/`.
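The two variants differ only in how the camera parameters are optimized. The sketch below shows one common way to express that with optimizer parameter groups. It is **not** the code in `trainer/distillation.py`. Identifying the camera branch by the substring `cam_self_attn` in parameter names is a hypothetical filter; check `model.named_parameters()` for the real names.

```python
from typing import Any, Dict, Iterable, List, Tuple

# HYPOTHETICAL name filter for the camera branch.
CAMERA_KEY = "cam_self_attn"


def build_param_groups(
    named_params: Iterable[Tuple[str, Any]],
    base_lr: float,
    lr_cam_multiplier: float = 10.0,
    freeze_camera_branch: bool = False,
    camera_key: str = CAMERA_KEY,
) -> List[Dict[str, Any]]:
    """Split parameters into optimizer param groups for the two DMD variants."""
    base, cam = [], []
    for name, p in named_params:
        is_cam = camera_key in name
        if is_cam and freeze_camera_branch:
            # Still used in the forward pass; it just never receives a gradient.
            p.requires_grad = False
            continue
        if not p.requires_grad:
            continue  # parameters frozen elsewhere stay out of the optimizer
        (cam if is_cam else base).append(p)
    groups = [
        {"params": base, "lr": base_lr},
        {"params": cam, "lr": base_lr * lr_cam_multiplier},
    ]
    return [g for g in groups if g["params"]]  # drop empty groups
```

With PyTorch this would be used as `torch.optim.AdamW(build_param_groups(model.named_parameters(), base_lr=base_lr))`. Setting `requires_grad = False` (rather than, say, a zero learning rate) also keeps frozen camera parameters out of gradient computation and optimizer state.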
```bash
LOG_DIR=output/ucpe_training_720_v2/dmd_wan21_sf \
CONFIG=configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
NODE_RANK=0 MASTER_PORT=34576 MASTER_ADDR=<node0-ip> \
bash scripts/run_dmd_720_multinode.sh
```
Each step takes about 17 s on 4 nodes × 8 H100s, so the recommended ~1000 steps take roughly 4.7 hours.
---
## Inference
### DMD causal student (few-step, fast)
```bash
python scripts/test_ucpe_dmd.py \
--config_path configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
--checkpoint_path /path/to/dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/model.pt \
--output_folder ./output/test \
--use_ema \
--num_samples 8
```
`--use_ema` is **required** for DMD checkpoints (it loads `generator_ema`). Omit it
for ODE-stage checkpoints, which contain only `generator`.
### Wan2.2 bidirectional teacher (50-step, source-of-truth)
The bidirectional checkpoint was trained with UCPE's DiffSynth-based pipeline, so run it
through UCPE's `scripts/predict_one_sample.py`:
```bash
cd /path/to/UCPE # the UCPE repo, NOT cf_ucpe
HF_HUB_OFFLINE=1 python scripts/predict_one_sample.py \
--video_id <panshot_video_id> \
--ckpt_path /path/to/wan22_bidirectional_ucpe \
--output_path ./bidir.mp4 \
--num_inference_steps 50
```
Select the test-split sample with `--video_id` (recommended) or `--sample_idx`.
---
## Visualization (4-panel comparison)
Generates **GT / camera-trajectory / Wan2.2 bidirectional / DMD** as a 2×2 grid
mp4 for one PanShot test sample:
```bash
# 1. Run all four sources for one sample (writes to output/comparison/<sample_dir>/)
python scripts/compare_inference.py \
--config_path configs/ucpe_causal_forcing_dmd_720_wan21_sf.yaml \
--dmd_ckpt /path/to/dmd_unfreeze_cam_wan21_sf/checkpoint_model_001000/model.pt \
--use_ema \
--sample_idx 0 \
--output_root output/comparison
# 2. Compose the 2x2 grid (renders camera trajectory + ffmpeg stack)
python scripts/compare_grid.py --input_dir output/comparison/0000_<video_id>/
```
Output: `output/comparison/0000_<video_id>/grid.mp4`.
To run a batch on 8 GPUs (samples 0 to 31, about 5 min):
```bash
START=0 END=31 bash scripts/compare_batch_8gpu.sh
```
The trajectory panel draws a 3D camera-frustum gizmo along the actual world-space
camera path. The frustum size scales automatically with the trajectory's bounding box; adjust it with
`--frustum_scale_ratio` (default `1/12`) on `compare_grid.py`, which does not require rerunning inference.
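To reason about what `--frustum_scale_ratio` does, here is a sketch of one plausible auto-scaling rule. It **assumes** the size is the ratio times the diagonal of the path's axis-aligned bounding box. `compare_grid.py` may use a different extent measure (e.g. the longest side), so treat the exact numbers as illustrative.

```python
import math
from typing import Sequence


def frustum_size(positions: Sequence[Sequence[float]], scale_ratio: float = 1 / 12) -> float:
    """HYPOTHETICAL rule: scale_ratio x diagonal of the camera path's bounding box."""
    mins = [min(p[i] for p in positions) for i in range(3)]
    maxs = [max(p[i] for p in positions) for i in range(3)]
    diagonal = math.dist(mins, maxs)
    # A static camera has a zero-size box; fall back to unit scale so the gizmo stays visible.
    return scale_ratio * (diagonal if diagonal > 0 else 1.0)
```

Under this rule a path whose bounding-box diagonal is 5 scene units gets frustums of size 5/12, and halving the ratio halves the gizmo without changing the rendered path.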
---
## Citation / contact
Code: <https://github.com/weijielyu/RayStream_CF>
Author: Weijie Lyu (`weijielyu1@gmail.com`)