SceneWorks committed on
Commit
e19a025
·
verified ·
1 Parent(s): f147ef2

sc-5445: turnkey MLX model card (converted vae/t5/clip safetensors)

Files changed (1)
  1. README.md +22 -12
README.md CHANGED
@@ -20,18 +20,28 @@ base_model: zai-org/SCAIL-2
20
  > Capabilities (from upstream): character animation from a reference image + driving video, **cross-identity character replacement**, zero-shot animal-driving, end-to-end *and* pose-rendered driving, and (experimental) multi-reference. Image output is `num_frames == 1`.
21
 
22
  ## What changed vs. upstream
23
- - The main DiT (`model/1/fsdp2_rank_0000_checkpoint.pt`, an FSDP2/SAT checkpoint) was key-remapped to the `SCAIL2Model` parameter naming using the upstream `convert.py` contract (fused `query_key_value`→`q`/`k`/`v`, `key_value`→`k`/`v`, `clip_feature_key_value_list`→`k_img`/`v_img`) and cast **fp32 → bf16**, then saved as a single safetensors. The remap is bit-faithful (validated: 987 source keys → 1307 model keys, an exact key+shape match against `SCAIL2Model.from_config(config-14b.json)`).
24
- - The text encoder (UMT5-XXL), VAE (Wan2.1), and image encoder (open-CLIP XLM-RoBERTa ViT-H/14) are the **stock upstream files**, included here so the snapshot is self-contained.
25
-
26
- ## Contents
27
- | file | source | notes |
28
- |---|---|---|
29
- | `dit.safetensors` | converted | SCAIL-2 14B DiT, **bf16**, 1307 tensors (~31 GB) |
30
- | `config.json` | upstream `configs/config-14b.json` | `model_type: i2v`, `dim 5120`, `ffn 13824`, `40` layers/heads, `in_dim 20`, `mask_dim 28`, `out_dim 16` |
31
- | `Wan2.1_VAE.pth` | upstream, stock | z16 VAE, stride (4,8,8) |
32
- | `umt5-xxl/` | upstream, stock | UMT5-XXL encoder (bf16) + tokenizer |
33
- | `models_clip_open-clip-xlm-roberta-large-vit-huge-14-onlyvisual.pth` | upstream, stock | reference-image visual encoder (1280-dim) |
34
- | `bias-aware-dpo-lora.pt` | upstream, stock | optional Bias-Aware DPO refinement LoRA |
35
 
36
  ## Architecture (summary)
37
  Wan2.1-14B **I2V** dense DiT. Conditioning is a **token-axis packed** stream (reference + video + pose, patch-embedded by three Conv3d stems with additive 28-channel color-coded mask embeddings, then concatenated into one self-attention sequence) plus a **per-source RoPE** with integer T/H/W shifts (the `replace_flag` flips the reference H-shift, toggling animation vs. replacement). The reference image is encoded by the CLIP visual tower and injected via Wan-I2V image cross-attention. Sampling is plain CFG (guide 5.0), flow-matching UniPC/DPM++.
 
20
  > Capabilities (from upstream): character animation from a reference image + driving video, **cross-identity character replacement**, zero-shot animal-driving, end-to-end *and* pose-rendered driving, and (experimental) multi-reference. Image output is `num_frames == 1`.
21
 
22
  ## What changed vs. upstream
23
+ Every component is repackaged to the safetensors layout the SceneWorks Rust/MLX loaders consume, with no PyTorch at runtime:
24
+ - **DiT** (`model/1/fsdp2_rank_0000_checkpoint.pt`, an FSDP2/SAT checkpoint) was key-remapped to the `SCAIL2Model` parameter naming using the upstream `convert.py` contract (fused `query_key_value`→`q`/`k`/`v`, `key_value`→`k`/`v`, `clip_feature_key_value_list`→`k_img`/`v_img`) and cast **fp32 → bf16** → `dit.safetensors`. Bit-faithful (987 source keys → 1307 model keys; exact key+shape match against `SCAIL2Model.from_config(config-14b.json)`).
25
+ - **VAE** (`Wan2.1_VAE.pth`, the stock Wan2.1 z16 VAE) → `vae.safetensors` (**f32**, channels-last conv transpose, keys unchanged; this is the `sanitize_wan_vae_weights` contract shared with Bernini/wan). Loaded by `mlx_gen_wan::WanVae`.
26
+ - **Text encoder** (`umt5-xxl/models_t5_umt5-xxl-enc-bf16.pth`, stock UMT5-XXL) → `t5_encoder.safetensors` (**bf16**, sole rename `.ffn.gate.0.`→`.ffn.gate_proj.`). Loaded by `mlx_gen_wan::Umt5Encoder` with `tokenizer.json`.
27
+ - **Image encoder** (`models_clip_...onlyvisual.pth`, open-CLIP XLM-RoBERTa ViT-H/14) → `clip.safetensors` (**f32**, de-prefixed `visual.*` keys). Loaded by `mlx_gen_scail2::ScailClip` (32-layer visual tower, `use_31_block` penultimate features).
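The fused-projection split described in the DiT bullet can be sketched as follows. This is a minimal illustration, not the upstream `convert.py`: it assumes each fused tensor stacks its parts as equal chunks along axis 0, in the order listed in the rename contract. `FUSED_SPLITS`, `split_rows`, and `remap_fused` are hypothetical names, and tensors are plain nested lists so the sketch has no dependencies.

```python
from typing import Any, Dict, List, Sequence, Tuple

# Hypothetical table modeled on the renames listed in this card:
# fused module name -> names of the split projections, in stacking order (assumed).
FUSED_SPLITS: Dict[str, Tuple[str, ...]] = {
    "query_key_value": ("q", "k", "v"),
    "key_value": ("k", "v"),
    "clip_feature_key_value_list": ("k_img", "v_img"),
}


def split_rows(rows: Sequence[Any], parts: int) -> List[List[Any]]:
    """Split a tensor (list of rows, or a flat bias list) into equal chunks along axis 0."""
    if len(rows) % parts != 0:
        raise ValueError(f"axis 0 of length {len(rows)} is not divisible by {parts}")
    step = len(rows) // parts
    return [list(rows[i * step:(i + 1) * step]) for i in range(parts)]


def remap_fused(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new state dict with every fused projection split into separate keys."""
    out: Dict[str, Any] = {}
    for key, tensor in state.items():
        parts = key.split(".")
        # Match whole path components so "key_value" never matches "query_key_value".
        fused = next((p for p in parts if p in FUSED_SPLITS), None)
        if fused is None:
            out[key] = tensor  # untouched parameter, copied through
            continue
        names = FUSED_SPLITS[fused]
        idx = parts.index(fused)
        for name, chunk in zip(names, split_rows(tensor, len(names))):
            out[".".join(parts[:idx] + [name] + parts[idx + 1:])] = chunk
    return out
```

Under these assumptions a layer that holds all three fused modules, each with a weight and a bias, gains 8 keys after the split. Forty such layers would add 320 keys, which is consistent with the reported 987 → 1307 key count, although the real checkpoint layout may differ.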
28
+
29
+ The converted VAE/UMT5 are byte-size-identical (modulo safetensors header) to Bernini/wan's already-validated Wan2.1 VAE + umt5-xxl safetensors, confirming SCAIL-2 ships the stock components.
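A "modulo safetensors header" size comparison can be reproduced with the standard library alone. A safetensors file begins with an 8-byte little-endian length, followed by that many bytes of JSON header, followed by the raw tensor bytes. The helper names below are hypothetical, and the sketch compares only payload size, dtypes, and shapes, not tensor values, so it is a sanity check rather than proof of identical weights.

```python
import json
import struct
from pathlib import Path
from typing import Any, Dict, Tuple


def read_header(path: str) -> Tuple[int, Dict[str, Any]]:
    """Return (header_length_in_bytes, parsed JSON header) of a safetensors file."""
    with open(path, "rb") as f:
        (length,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(length).decode("utf-8"))
    return length, header


def payload_size(path: str) -> int:
    """Size of the raw tensor bytes, i.e. file size minus length prefix and header."""
    length, _ = read_header(path)
    return Path(path).stat().st_size - 8 - length


def same_payload_layout(a: str, b: str) -> bool:
    """True if both files hold the same tensor names, dtypes, shapes, and payload size."""
    def layout(header: Dict[str, Any]) -> Dict[str, Tuple[str, Tuple[int, ...]]]:
        # "__metadata__" is free-form and may legitimately differ between converters.
        return {
            name: (info["dtype"], tuple(info["shape"]))
            for name, info in header.items()
            if name != "__metadata__"
        }

    _, header_a = read_header(a)
    _, header_b = read_header(b)
    return payload_size(a) == payload_size(b) and layout(header_a) == layout(header_b)
```

For example, `same_payload_layout("vae.safetensors", "other/vae.safetensors")` would compare this snapshot's VAE against another local copy; the second path is a placeholder.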
30
+
31
+ ## Contents (turnkey MLX snapshot)
32
+ | file | source | loader | notes |
33
+ |---|---|---|---|
34
+ | `dit.safetensors` | converted | `Scail2Dit` | SCAIL-2 14B DiT, **bf16**, 1307 tensors (~31 GB) |
35
+ | `vae.safetensors` | converted | `WanVae` | Wan2.1 z16 VAE, **f32**, stride (4,8,8) (~0.5 GB) |
36
+ | `t5_encoder.safetensors` | converted | `Umt5Encoder` | UMT5-XXL encoder, **bf16** (~11 GB) |
37
+ | `clip.safetensors` | converted | `ScailClip` | open-CLIP ViT-H/14 visual tower, **f32**, 1280-dim (~2.5 GB) |
38
+ | `tokenizer.json` | upstream, stock | `load_tokenizer` | UMT5-XXL HF tokenizer (root copy) |
39
+ | `config.json` | upstream `configs/config-14b.json` | `Scail2Config` | `model_type: i2v`, `dim 5120`, `ffn 13824`, `40` layers/heads, `in_dim 20`, `mask_dim 28`, `out_dim 16` |
40
+ | `bias-aware-dpo-lora.pt` | upstream, stock | (sc-5451) | optional Bias-Aware DPO refinement LoRA |
41
+
42
+ Quantization (Q4/Q8) is applied at **load time** by the SceneWorks worker (`mlx_gen` `.quantize()`); this snapshot ships the dense bf16/f32 weights.
43
+
44
+ The raw upstream pickles (`Wan2.1_VAE.pth`, `umt5-xxl/models_t5_umt5-xxl-enc-bf16.pth`, `models_clip_...onlyvisual.pth`) remain in this repo for provenance; the Rust loaders use only the converted safetensors above, so a lean SceneWorks pull can skip them.
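One way to express such a lean pull is a glob filter over the repo's file list. The sketch below is an assumption about how a client might do it, not SceneWorks' actual downloader: `SKIP_PATTERNS` and `lean_file_list` are hypothetical names, and the pattern only targets the `.pth` pickles named in this card, so the optional `.pt` LoRA is kept.

```python
from fnmatch import fnmatch
from typing import Iterable, List, Sequence

# Assumed pattern: the raw upstream pickles named in this card all end in ".pth".
SKIP_PATTERNS: Sequence[str] = ("*.pth",)


def lean_file_list(files: Iterable[str],
                   skip_patterns: Sequence[str] = SKIP_PATTERNS) -> List[str]:
    """Drop every file whose path matches one of the skip patterns."""
    return [
        name for name in files
        if not any(fnmatch(name, pattern) for pattern in skip_patterns)
    ]


if __name__ == "__main__":
    repo_files = [
        "dit.safetensors",
        "vae.safetensors",
        "t5_encoder.safetensors",
        "clip.safetensors",
        "tokenizer.json",
        "config.json",
        "bias-aware-dpo-lora.pt",
        "Wan2.1_VAE.pth",
        "umt5-xxl/models_t5_umt5-xxl-enc-bf16.pth",
        "models_clip_open-clip-xlm-roberta-large-vit-huge-14-onlyvisual.pth",
    ]
    for name in lean_file_list(repo_files):
        print(name)
```

The same patterns could be passed as `ignore_patterns` to `huggingface_hub.snapshot_download` to skip those files when downloading the snapshot.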
45
 
46
  ## Architecture (summary)
47
  Wan2.1-14B **I2V** dense DiT. Conditioning is a **token-axis packed** stream (reference + video + pose, patch-embedded by three Conv3d stems with additive 28-channel color-coded mask embeddings, then concatenated into one self-attention sequence) plus a **per-source RoPE** with integer T/H/W shifts (the `replace_flag` flips the reference H-shift, toggling animation vs. replacement). The reference image is encoded by the CLIP visual tower and injected via Wan-I2V image cross-attention. Sampling is plain CFG (guide 5.0), flow-matching UniPC/DPM++.