Instructions to use xixircc/MetaRigCapture with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use xixircc/MetaRigCapture with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("xixircc/MetaRigCapture", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +8 -0
- blaze_face_short_range.tflite +3 -0
- face-parsing/.gitattributes +28 -0
- face-parsing/README.md +165 -0
- face-parsing/config.json +111 -0
- face-parsing/demo.png +3 -0
- face-parsing/model.safetensors +3 -0
- face-parsing/onnx/model.onnx +3 -0
- face-parsing/onnx/model_quantized.onnx +3 -0
- face-parsing/preprocessor_config.json +23 -0
- face-parsing/quantize_config.json +33 -0
- models/unet_3d.py +727 -0
- models/unet_3d_blocks.py +1121 -0
- pretrained_weights/sd-image-variations-diffusers/.gitattributes +32 -0
- pretrained_weights/sd-image-variations-diffusers/README.md +226 -0
- pretrained_weights/sd-image-variations-diffusers/alias-montage.jpg +3 -0
- pretrained_weights/sd-image-variations-diffusers/default-montage.jpg +3 -0
- pretrained_weights/sd-image-variations-diffusers/earring.jpg +3 -0
- pretrained_weights/sd-image-variations-diffusers/feature_extractor/preprocessor_config.json +28 -0
- pretrained_weights/sd-image-variations-diffusers/image_encoder/config.json +23 -0
- pretrained_weights/sd-image-variations-diffusers/image_encoder/pytorch_model.bin +3 -0
- pretrained_weights/sd-image-variations-diffusers/inputs.jpg +0 -0
- pretrained_weights/sd-image-variations-diffusers/model_index.json +29 -0
- pretrained_weights/sd-image-variations-diffusers/safety_checker/config.json +181 -0
- pretrained_weights/sd-image-variations-diffusers/scheduler/scheduler_config.json +13 -0
- pretrained_weights/sd-image-variations-diffusers/unet/config.json +40 -0
- pretrained_weights/sd-image-variations-diffusers/unet/diffusion_pytorch_model.bin +3 -0
- pretrained_weights/sd-image-variations-diffusers/v1-montage.jpg +3 -0
- pretrained_weights/sd-image-variations-diffusers/v2-montage.jpg +3 -0
- pretrained_weights/sd-image-variations-diffusers/vae/config.json +30 -0
- pretrained_weights/sd-image-variations-diffusers/vae/diffusion_pytorch_model.bin +3 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/.gitattributes +36 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/LICENSE.md +58 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/README.md +99 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/comparison.png +3 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/feature_extractor/preprocessor_config.json +28 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/image_encoder/config.json +23 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/model_index.json +25 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/output_tile.gif +3 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/scheduler/scheduler_config.json +20 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/svd_xt.safetensors +3 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/svd_xt_image_decoder.safetensors +3 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/unet/config.json +38 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/unet/diffusion_pytorch_model.fp16.safetensors +3 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/unet/diffusion_pytorch_model.safetensors +3 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/vae/config.json +24 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.fp16.safetensors +3 -0
- pretrained_weights/stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.safetensors +3 -0
- pretrained_weights/xnemo_denoising_unet.pth +3 -0
- pretrained_weights/xnemo_motion_encoder.pth +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
face-parsing/demo.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
pretrained_weights/sd-image-variations-diffusers/alias-montage.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
pretrained_weights/sd-image-variations-diffusers/default-montage.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
pretrained_weights/sd-image-variations-diffusers/earring.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
pretrained_weights/sd-image-variations-diffusers/v1-montage.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
pretrained_weights/sd-image-variations-diffusers/v2-montage.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
pretrained_weights/stable-video-diffusion-img2vid-xt/comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
pretrained_weights/stable-video-diffusion-img2vid-xt/output_tile.gif filter=lfs diff=lfs merge=lfs -text
|
blaze_face_short_range.tflite
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f
|
| 3 |
+
size 229746
|
face-parsing/.gitattributes
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
face-parsing/README.md
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
library_name: transformers
|
| 4 |
+
tags:
|
| 5 |
+
- vision
|
| 6 |
+
- image-segmentation
|
| 7 |
+
- nvidia/mit-b5
|
| 8 |
+
- transformers.js
|
| 9 |
+
- onnx
|
| 10 |
+
datasets:
|
| 11 |
+
- celebamaskhq
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# Face Parsing
|
| 15 |
+
|
| 16 |
+

|
| 17 |
+
|
| 18 |
+
[Semantic segmentation](https://huggingface.co/docs/transformers/tasks/semantic_segmentation) model fine-tuned from [nvidia/mit-b5](https://huggingface.co/nvidia/mit-b5) with [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) for face parsing. For additional options, see the Transformers [Segformer docs](https://huggingface.co/docs/transformers/model_doc/segformer).
|
| 19 |
+
|
| 20 |
+
> ONNX model for web inference contributed by [Xenova](https://huggingface.co/Xenova).
|
| 21 |
+
|
| 22 |
+
## Usage in Python
|
| 23 |
+
|
| 24 |
+
Exhaustive list of labels can be extracted from [config.json](https://huggingface.co/jonathandinu/face-parsing/blob/65972ac96180b397f86fda0980bbe68e6ee01b8f/config.json#L30).
|
| 25 |
+
|
| 26 |
+
| id | label | note |
|
| 27 |
+
| :-: | :--------- | :---------------- |
|
| 28 |
+
| 0 | background | |
|
| 29 |
+
| 1 | skin | |
|
| 30 |
+
| 2 | nose | |
|
| 31 |
+
| 3 | eye_g | eyeglasses |
|
| 32 |
+
| 4 | l_eye | left eye |
|
| 33 |
+
| 5 | r_eye | right eye |
|
| 34 |
+
| 6 | l_brow | left eyebrow |
|
| 35 |
+
| 7 | r_brow | right eyebrow |
|
| 36 |
+
| 8 | l_ear | left ear |
|
| 37 |
+
| 9 | r_ear | right ear |
|
| 38 |
+
| 10 | mouth | area between lips |
|
| 39 |
+
| 11 | u_lip | upper lip |
|
| 40 |
+
| 12 | l_lip | lower lip |
|
| 41 |
+
| 13 | hair | |
|
| 42 |
+
| 14 | hat | |
|
| 43 |
+
| 15 | ear_r | earring |
|
| 44 |
+
| 16 | neck_l | necklace |
|
| 45 |
+
| 17 | neck | |
|
| 46 |
+
| 18 | cloth | clothing |
|
| 47 |
+
|
| 48 |
+
```python
|
| 49 |
+
import torch
|
| 50 |
+
from torch import nn
|
| 51 |
+
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
|
| 52 |
+
|
| 53 |
+
from PIL import Image
|
| 54 |
+
import matplotlib.pyplot as plt
|
| 55 |
+
import requests
|
| 56 |
+
|
| 57 |
+
# convenience expression for automatically determining device
|
| 58 |
+
device = (
|
| 59 |
+
"cuda"
|
| 60 |
+
# Device for NVIDIA or AMD GPUs
|
| 61 |
+
if torch.cuda.is_available()
|
| 62 |
+
else "mps"
|
| 63 |
+
# Device for Apple Silicon (Metal Performance Shaders)
|
| 64 |
+
if torch.backends.mps.is_available()
|
| 65 |
+
else "cpu"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# load models
|
| 69 |
+
image_processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
|
| 70 |
+
model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
|
| 71 |
+
model.to(device)
|
| 72 |
+
|
| 73 |
+
# expects a PIL.Image or torch.Tensor
|
| 74 |
+
url = "https://images.unsplash.com/photo-1539571696357-5a69c17a67c6"
|
| 75 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
| 76 |
+
|
| 77 |
+
# run inference on image
|
| 78 |
+
inputs = image_processor(images=image, return_tensors="pt").to(device)
|
| 79 |
+
outputs = model(**inputs)
|
| 80 |
+
logits = outputs.logits # shape (batch_size, num_labels, ~height/4, ~width/4)
|
| 81 |
+
|
| 82 |
+
# resize output to match input image dimensions
|
| 83 |
+
upsampled_logits = nn.functional.interpolate(logits,
|
| 84 |
+
size=image.size[::-1], # H x W
|
| 85 |
+
mode='bilinear',
|
| 86 |
+
align_corners=False)
|
| 87 |
+
|
| 88 |
+
# get label masks
|
| 89 |
+
labels = upsampled_logits.argmax(dim=1)[0]
|
| 90 |
+
|
| 91 |
+
# move to CPU to visualize in matplotlib
|
| 92 |
+
labels_viz = labels.cpu().numpy()
|
| 93 |
+
plt.imshow(labels_viz)
|
| 94 |
+
plt.show()
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
## Usage in the browser (Transformers.js)
|
| 98 |
+
|
| 99 |
+
```js
|
| 100 |
+
import {
|
| 101 |
+
pipeline,
|
| 102 |
+
env,
|
| 103 |
+
} from "https://cdn.jsdelivr.net/npm/@xenova/transformers@2.14.0";
|
| 104 |
+
|
| 105 |
+
// important to prevent errors since the model files are likely remote on HF hub
|
| 106 |
+
env.allowLocalModels = false;
|
| 107 |
+
|
| 108 |
+
// instantiate image segmentation pipeline with pretrained face parsing model
|
| 109 |
+
model = await pipeline("image-segmentation", "jonathandinu/face-parsing");
|
| 110 |
+
|
| 111 |
+
// async inference since it could take a few seconds
|
| 112 |
+
const output = await model(url);
|
| 113 |
+
|
| 114 |
+
// each label is a separate mask object
|
| 115 |
+
// [
|
| 116 |
+
// { score: null, label: 'background', mask: transformers.js RawImage { ... }}
|
| 117 |
+
// { score: null, label: 'hair', mask: transformers.js RawImage { ... }}
|
| 118 |
+
// ...
|
| 119 |
+
// ]
|
| 120 |
+
for (const m of output) {
|
| 121 |
+
print(`Found ${m.label}`);
|
| 122 |
+
m.mask.save(`${m.label}.png`);
|
| 123 |
+
}
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### p5.js
|
| 127 |
+
|
| 128 |
+
Since [p5.js](https://p5js.org/) uses an animation loop abstraction, we need to take care loading the model and making predictions.
|
| 129 |
+
|
| 130 |
+
```js
|
| 131 |
+
// ...
|
| 132 |
+
|
| 133 |
+
// asynchronously load transformers.js and instantiate model
|
| 134 |
+
async function preload() {
|
| 135 |
+
// load transformers.js library with a dynamic import
|
| 136 |
+
const { pipeline, env } = await import(
|
| 137 |
+
"https://cdn.jsdelivr.net/npm/@xenova/transformers@2.14.0"
|
| 138 |
+
);
|
| 139 |
+
|
| 140 |
+
// important to prevent errors since the model files are remote on HF hub
|
| 141 |
+
env.allowLocalModels = false;
|
| 142 |
+
|
| 143 |
+
// instantiate image segmentation pipeline with pretrained face parsing model
|
| 144 |
+
model = await pipeline("image-segmentation", "jonathandinu/face-parsing");
|
| 145 |
+
|
| 146 |
+
print("face-parsing model loaded");
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// ...
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
[full p5.js example](https://editor.p5js.org/jonathan.ai/sketches/wZn15Dvgh)
|
| 153 |
+
|
| 154 |
+
### Model Description
|
| 155 |
+
|
| 156 |
+
- **Developed by:** [Jonathan Dinu](https://twitter.com/jonathandinu)
|
| 157 |
+
- **Model type:** Transformer-based semantic segmentation image model
|
| 158 |
+
- **License:** non-commercial research and educational purposes
|
| 159 |
+
- **Resources for more information:** Transformers docs on [Segformer](https://huggingface.co/docs/transformers/model_doc/segformer) and/or the [original research paper](https://arxiv.org/abs/2105.15203).
|
| 160 |
+
|
| 161 |
+
## Limitations and Bias
|
| 162 |
+
|
| 163 |
+
### Bias
|
| 164 |
+
|
| 165 |
+
While the capabilities of computer vision models are impressive, they can also reinforce or exacerbate social biases. The [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset used for fine-tuning is large but not necessarily perfectly diverse or representative. Also, they are images of.... just celebrities.
|
face-parsing/config.json
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "jonathandinu/face-parsing",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"SegformerForSemanticSegmentation"
|
| 5 |
+
],
|
| 6 |
+
"attention_probs_dropout_prob": 0.0,
|
| 7 |
+
"classifier_dropout_prob": 0.1,
|
| 8 |
+
"decoder_hidden_size": 768,
|
| 9 |
+
"depths": [
|
| 10 |
+
3,
|
| 11 |
+
6,
|
| 12 |
+
40,
|
| 13 |
+
3
|
| 14 |
+
],
|
| 15 |
+
"downsampling_rates": [
|
| 16 |
+
1,
|
| 17 |
+
4,
|
| 18 |
+
8,
|
| 19 |
+
16
|
| 20 |
+
],
|
| 21 |
+
"drop_path_rate": 0.1,
|
| 22 |
+
"hidden_act": "gelu",
|
| 23 |
+
"hidden_dropout_prob": 0.0,
|
| 24 |
+
"hidden_sizes": [
|
| 25 |
+
64,
|
| 26 |
+
128,
|
| 27 |
+
320,
|
| 28 |
+
512
|
| 29 |
+
],
|
| 30 |
+
"id2label": {
|
| 31 |
+
"0": "background",
|
| 32 |
+
"1": "skin",
|
| 33 |
+
"2": "nose",
|
| 34 |
+
"3": "eye_g",
|
| 35 |
+
"4": "l_eye",
|
| 36 |
+
"5": "r_eye",
|
| 37 |
+
"6": "l_brow",
|
| 38 |
+
"7": "r_brow",
|
| 39 |
+
"8": "l_ear",
|
| 40 |
+
"9": "r_ear",
|
| 41 |
+
"10": "mouth",
|
| 42 |
+
"11": "u_lip",
|
| 43 |
+
"12": "l_lip",
|
| 44 |
+
"13": "hair",
|
| 45 |
+
"14": "hat",
|
| 46 |
+
"15": "ear_r",
|
| 47 |
+
"16": "neck_l",
|
| 48 |
+
"17": "neck",
|
| 49 |
+
"18": "cloth"
|
| 50 |
+
},
|
| 51 |
+
"image_size": 224,
|
| 52 |
+
"initializer_range": 0.02,
|
| 53 |
+
"label2id": {
|
| 54 |
+
"background": 0,
|
| 55 |
+
"skin": 1,
|
| 56 |
+
"nose": 2,
|
| 57 |
+
"eye_g": 3,
|
| 58 |
+
"l_eye": 4,
|
| 59 |
+
"r_eye": 5,
|
| 60 |
+
"l_brow": 6,
|
| 61 |
+
"r_brow": 7,
|
| 62 |
+
"l_ear": 8,
|
| 63 |
+
"r_ear": 9,
|
| 64 |
+
"mouth": 10,
|
| 65 |
+
"u_lip": 11,
|
| 66 |
+
"l_lip": 12,
|
| 67 |
+
"hair": 13,
|
| 68 |
+
"hat": 14,
|
| 69 |
+
"ear_r": 15,
|
| 70 |
+
"neck_l": 16,
|
| 71 |
+
"neck": 17,
|
| 72 |
+
"cloth": 18
|
| 73 |
+
},
|
| 74 |
+
"layer_norm_eps": 1e-06,
|
| 75 |
+
"mlp_ratios": [
|
| 76 |
+
4,
|
| 77 |
+
4,
|
| 78 |
+
4,
|
| 79 |
+
4
|
| 80 |
+
],
|
| 81 |
+
"model_type": "segformer",
|
| 82 |
+
"num_attention_heads": [
|
| 83 |
+
1,
|
| 84 |
+
2,
|
| 85 |
+
5,
|
| 86 |
+
8
|
| 87 |
+
],
|
| 88 |
+
"num_channels": 3,
|
| 89 |
+
"num_encoder_blocks": 4,
|
| 90 |
+
"patch_sizes": [
|
| 91 |
+
7,
|
| 92 |
+
3,
|
| 93 |
+
3,
|
| 94 |
+
3
|
| 95 |
+
],
|
| 96 |
+
"reshape_last_stage": true,
|
| 97 |
+
"semantic_loss_ignore_index": 255,
|
| 98 |
+
"sr_ratios": [
|
| 99 |
+
8,
|
| 100 |
+
4,
|
| 101 |
+
2,
|
| 102 |
+
1
|
| 103 |
+
],
|
| 104 |
+
"strides": [
|
| 105 |
+
4,
|
| 106 |
+
2,
|
| 107 |
+
2,
|
| 108 |
+
2
|
| 109 |
+
],
|
| 110 |
+
"transformers_version": "4.37.0.dev0"
|
| 111 |
+
}
|
face-parsing/demo.png
ADDED
|
Git LFS Details
|
face-parsing/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2bec795a8c243db71bd95be538fd62559003566466c71237e45c99b920f4b62
|
| 3 |
+
size 338580732
|
face-parsing/onnx/model.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d4e67af60ff78184745ebf74cc15163c0adc27d45cdeba31e3a03d1096fb8c3
|
| 3 |
+
size 340316611
|
face-parsing/onnx/model_quantized.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5bab9bfb3cb979f3098ac3b934b1641dbf87f835e0b03c2ca6d88dcf18c83d27
|
| 3 |
+
size 89439678
|
face-parsing/preprocessor_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_normalize": true,
|
| 3 |
+
"do_reduce_labels": false,
|
| 4 |
+
"do_rescale": true,
|
| 5 |
+
"do_resize": true,
|
| 6 |
+
"image_mean": [
|
| 7 |
+
0.485,
|
| 8 |
+
0.456,
|
| 9 |
+
0.406
|
| 10 |
+
],
|
| 11 |
+
"image_processor_type": "SegformerFeatureExtractor",
|
| 12 |
+
"image_std": [
|
| 13 |
+
0.229,
|
| 14 |
+
0.224,
|
| 15 |
+
0.225
|
| 16 |
+
],
|
| 17 |
+
"resample": 2,
|
| 18 |
+
"rescale_factor": 0.00392156862745098,
|
| 19 |
+
"size": {
|
| 20 |
+
"height": 512,
|
| 21 |
+
"width": 512
|
| 22 |
+
}
|
| 23 |
+
}
|
face-parsing/quantize_config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"per_channel": true,
|
| 3 |
+
"reduce_range": true,
|
| 4 |
+
"per_model_config": {
|
| 5 |
+
"model": {
|
| 6 |
+
"op_types": [
|
| 7 |
+
"Unsqueeze",
|
| 8 |
+
"Shape",
|
| 9 |
+
"Transpose",
|
| 10 |
+
"Sqrt",
|
| 11 |
+
"Gather",
|
| 12 |
+
"Slice",
|
| 13 |
+
"Erf",
|
| 14 |
+
"Div",
|
| 15 |
+
"Reshape",
|
| 16 |
+
"Add",
|
| 17 |
+
"Cast",
|
| 18 |
+
"Sub",
|
| 19 |
+
"Concat",
|
| 20 |
+
"ReduceMean",
|
| 21 |
+
"Mul",
|
| 22 |
+
"Conv",
|
| 23 |
+
"Constant",
|
| 24 |
+
"Resize",
|
| 25 |
+
"Softmax",
|
| 26 |
+
"Pow",
|
| 27 |
+
"Relu",
|
| 28 |
+
"MatMul"
|
| 29 |
+
],
|
| 30 |
+
"weight_type": "QUInt8"
|
| 31 |
+
}
|
| 32 |
+
}
|
| 33 |
+
}
|
models/unet_3d.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# *************************************************************************
|
| 2 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by ByteDance Ltd. and/or its affiliates.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under Aniportrait, with the full license text
|
| 8 |
+
# available at https://github.com/Zejun-Yang/AniPortrait/blob/main/LICENSE.
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
# *************************************************************************
|
| 12 |
+
from collections import OrderedDict
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
import pdb
|
| 15 |
+
from os import PathLike
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.models.attention_processor import AttentionProcessor
|
| 25 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 26 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 27 |
+
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
|
| 28 |
+
from safetensors.torch import load_file
|
| 29 |
+
|
| 30 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
| 31 |
+
from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class UNet3DConditionOutput(BaseOutput):
|
| 38 |
+
sample: torch.FloatTensor
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
| 42 |
+
_supports_gradient_checkpointing = True
|
| 43 |
+
|
| 44 |
+
@register_to_config
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
sample_size: Optional[int] = None,
|
| 48 |
+
in_channels: int = 4,
|
| 49 |
+
out_channels: int = 4,
|
| 50 |
+
center_input_sample: bool = False,
|
| 51 |
+
flip_sin_to_cos: bool = True,
|
| 52 |
+
freq_shift: int = 0,
|
| 53 |
+
down_block_types: Tuple[str] = (
|
| 54 |
+
"CrossAttnDownBlock3D",
|
| 55 |
+
"CrossAttnDownBlock3D",
|
| 56 |
+
"CrossAttnDownBlock3D",
|
| 57 |
+
"DownBlock3D",
|
| 58 |
+
),
|
| 59 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
| 60 |
+
up_block_types: Tuple[str] = (
|
| 61 |
+
"UpBlock3D",
|
| 62 |
+
"CrossAttnUpBlock3D",
|
| 63 |
+
"CrossAttnUpBlock3D",
|
| 64 |
+
"CrossAttnUpBlock3D",
|
| 65 |
+
),
|
| 66 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
| 67 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
| 68 |
+
layers_per_block: int = 2,
|
| 69 |
+
downsample_padding: int = 1,
|
| 70 |
+
mid_block_scale_factor: float = 1,
|
| 71 |
+
act_fn: str = "silu",
|
| 72 |
+
norm_num_groups: int = 32,
|
| 73 |
+
norm_eps: float = 1e-5,
|
| 74 |
+
cross_attention_dim: int = 1280,
|
| 75 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
| 76 |
+
dual_cross_attention: bool = False,
|
| 77 |
+
use_linear_projection: bool = False,
|
| 78 |
+
class_embed_type: Optional[str] = None,
|
| 79 |
+
num_class_embeds: Optional[int] = None,
|
| 80 |
+
upcast_attention: bool = False,
|
| 81 |
+
resnet_time_scale_shift: str = "default",
|
| 82 |
+
use_inflated_groupnorm=False,
|
| 83 |
+
# Additional
|
| 84 |
+
use_motion_module=False,
|
| 85 |
+
use_temporal_module=False,
|
| 86 |
+
motion_module_resolutions=(1, 2, 4, 8),
|
| 87 |
+
motion_module_mid_block=False,
|
| 88 |
+
motion_module_decoder_only=False,
|
| 89 |
+
motion_module_type=None,
|
| 90 |
+
temporal_module_type=None,
|
| 91 |
+
motion_module_kwargs={},
|
| 92 |
+
temporal_module_kwargs={},
|
| 93 |
+
unet_use_cross_frame_attention=None,
|
| 94 |
+
unet_use_temporal_attention=None,
|
| 95 |
+
):
|
| 96 |
+
super().__init__()
|
| 97 |
+
|
| 98 |
+
self.sample_size = sample_size
|
| 99 |
+
time_embed_dim = block_out_channels[0] * 4
|
| 100 |
+
|
| 101 |
+
# input
|
| 102 |
+
self.conv_in = InflatedConv3d(
|
| 103 |
+
in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# time
|
| 107 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
| 108 |
+
timestep_input_dim = block_out_channels[0]
|
| 109 |
+
|
| 110 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 111 |
+
|
| 112 |
+
# class embedding
|
| 113 |
+
if class_embed_type is None and num_class_embeds is not None:
|
| 114 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 115 |
+
elif class_embed_type == "timestep":
|
| 116 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 117 |
+
elif class_embed_type == "identity":
|
| 118 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
| 119 |
+
else:
|
| 120 |
+
self.class_embedding = None
|
| 121 |
+
|
| 122 |
+
self.down_blocks = nn.ModuleList([])
|
| 123 |
+
self.mid_block = None
|
| 124 |
+
self.up_blocks = nn.ModuleList([])
|
| 125 |
+
|
| 126 |
+
if isinstance(only_cross_attention, bool):
|
| 127 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
| 128 |
+
|
| 129 |
+
if isinstance(attention_head_dim, int):
|
| 130 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
| 131 |
+
|
| 132 |
+
# down
|
| 133 |
+
output_channel = block_out_channels[0]
|
| 134 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 135 |
+
res = 2**i
|
| 136 |
+
input_channel = output_channel
|
| 137 |
+
output_channel = block_out_channels[i]
|
| 138 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 139 |
+
|
| 140 |
+
down_block = get_down_block(
|
| 141 |
+
down_block_type,
|
| 142 |
+
num_layers=layers_per_block,
|
| 143 |
+
in_channels=input_channel,
|
| 144 |
+
out_channels=output_channel,
|
| 145 |
+
temb_channels=time_embed_dim,
|
| 146 |
+
add_downsample=not is_final_block,
|
| 147 |
+
resnet_eps=norm_eps,
|
| 148 |
+
resnet_act_fn=act_fn,
|
| 149 |
+
resnet_groups=norm_num_groups,
|
| 150 |
+
cross_attention_dim=cross_attention_dim,
|
| 151 |
+
attn_num_head_channels=attention_head_dim[i],
|
| 152 |
+
downsample_padding=downsample_padding,
|
| 153 |
+
dual_cross_attention=dual_cross_attention,
|
| 154 |
+
use_linear_projection=use_linear_projection,
|
| 155 |
+
only_cross_attention=only_cross_attention[i],
|
| 156 |
+
upcast_attention=upcast_attention,
|
| 157 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 158 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 159 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 160 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 161 |
+
use_motion_module=use_motion_module
|
| 162 |
+
and (res in motion_module_resolutions)
|
| 163 |
+
and (not motion_module_decoder_only),
|
| 164 |
+
use_temporal_module=use_temporal_module
|
| 165 |
+
and (res in motion_module_resolutions)
|
| 166 |
+
and (not motion_module_decoder_only),
|
| 167 |
+
motion_module_type=motion_module_type,
|
| 168 |
+
temporal_module_type=temporal_module_type,
|
| 169 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 170 |
+
temporal_module_kwargs=temporal_module_kwargs
|
| 171 |
+
)
|
| 172 |
+
self.down_blocks.append(down_block)
|
| 173 |
+
|
| 174 |
+
# mid
|
| 175 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
| 176 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
| 177 |
+
in_channels=block_out_channels[-1],
|
| 178 |
+
temb_channels=time_embed_dim,
|
| 179 |
+
resnet_eps=norm_eps,
|
| 180 |
+
resnet_act_fn=act_fn,
|
| 181 |
+
output_scale_factor=mid_block_scale_factor,
|
| 182 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 183 |
+
cross_attention_dim=cross_attention_dim,
|
| 184 |
+
attn_num_head_channels=attention_head_dim[-1],
|
| 185 |
+
resnet_groups=norm_num_groups,
|
| 186 |
+
dual_cross_attention=dual_cross_attention,
|
| 187 |
+
use_linear_projection=use_linear_projection,
|
| 188 |
+
upcast_attention=upcast_attention,
|
| 189 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 190 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 191 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 192 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
| 193 |
+
use_temporal_module=use_temporal_module and motion_module_mid_block,
|
| 194 |
+
motion_module_type=motion_module_type,
|
| 195 |
+
temporal_module_type=temporal_module_type,
|
| 196 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 197 |
+
temporal_module_kwargs=temporal_module_kwargs,
|
| 198 |
+
)
|
| 199 |
+
else:
|
| 200 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
| 201 |
+
|
| 202 |
+
# count how many layers upsample the videos
|
| 203 |
+
self.num_upsamplers = 0
|
| 204 |
+
|
| 205 |
+
# up
|
| 206 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 207 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
| 208 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
| 209 |
+
output_channel = reversed_block_out_channels[0]
|
| 210 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 211 |
+
res = 2 ** (3 - i)
|
| 212 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 213 |
+
|
| 214 |
+
prev_output_channel = output_channel
|
| 215 |
+
output_channel = reversed_block_out_channels[i]
|
| 216 |
+
input_channel = reversed_block_out_channels[
|
| 217 |
+
min(i + 1, len(block_out_channels) - 1)
|
| 218 |
+
]
|
| 219 |
+
|
| 220 |
+
# add upsample block for all BUT final layer
|
| 221 |
+
if not is_final_block:
|
| 222 |
+
add_upsample = True
|
| 223 |
+
self.num_upsamplers += 1
|
| 224 |
+
else:
|
| 225 |
+
add_upsample = False
|
| 226 |
+
|
| 227 |
+
up_block = get_up_block(
|
| 228 |
+
up_block_type,
|
| 229 |
+
num_layers=layers_per_block + 1,
|
| 230 |
+
in_channels=input_channel,
|
| 231 |
+
out_channels=output_channel,
|
| 232 |
+
prev_output_channel=prev_output_channel,
|
| 233 |
+
temb_channels=time_embed_dim,
|
| 234 |
+
add_upsample=add_upsample,
|
| 235 |
+
resnet_eps=norm_eps,
|
| 236 |
+
resnet_act_fn=act_fn,
|
| 237 |
+
resnet_groups=norm_num_groups,
|
| 238 |
+
cross_attention_dim=cross_attention_dim,
|
| 239 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
| 240 |
+
dual_cross_attention=dual_cross_attention,
|
| 241 |
+
use_linear_projection=use_linear_projection,
|
| 242 |
+
only_cross_attention=only_cross_attention[i],
|
| 243 |
+
upcast_attention=upcast_attention,
|
| 244 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 245 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 246 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 247 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 248 |
+
use_motion_module=use_motion_module
|
| 249 |
+
and (res in motion_module_resolutions),
|
| 250 |
+
use_temporal_module=use_temporal_module
|
| 251 |
+
and (res in motion_module_resolutions),
|
| 252 |
+
motion_module_type=motion_module_type,
|
| 253 |
+
temporal_module_type=temporal_module_type,
|
| 254 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 255 |
+
temporal_module_kwargs=temporal_module_kwargs,
|
| 256 |
+
)
|
| 257 |
+
self.up_blocks.append(up_block)
|
| 258 |
+
prev_output_channel = output_channel
|
| 259 |
+
|
| 260 |
+
# out
|
| 261 |
+
if use_inflated_groupnorm:
|
| 262 |
+
self.conv_norm_out = InflatedGroupNorm(
|
| 263 |
+
num_channels=block_out_channels[0],
|
| 264 |
+
num_groups=norm_num_groups,
|
| 265 |
+
eps=norm_eps,
|
| 266 |
+
)
|
| 267 |
+
else:
|
| 268 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 269 |
+
num_channels=block_out_channels[0],
|
| 270 |
+
num_groups=norm_num_groups,
|
| 271 |
+
eps=norm_eps,
|
| 272 |
+
)
|
| 273 |
+
self.conv_act = nn.SiLU()
|
| 274 |
+
self.conv_out = InflatedConv3d(
|
| 275 |
+
block_out_channels[0], out_channels, kernel_size=3, padding=1
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
@property
|
| 279 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 280 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 281 |
+
r"""
|
| 282 |
+
Returns:
|
| 283 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 284 |
+
indexed by its weight name.
|
| 285 |
+
"""
|
| 286 |
+
# set recursively
|
| 287 |
+
processors = {}
|
| 288 |
+
|
| 289 |
+
def fn_recursive_add_processors(
|
| 290 |
+
name: str,
|
| 291 |
+
module: torch.nn.Module,
|
| 292 |
+
processors: Dict[str, AttentionProcessor],
|
| 293 |
+
):
|
| 294 |
+
if hasattr(module, "set_processor"):
|
| 295 |
+
processors[f"{name}.processor"] = module.processor
|
| 296 |
+
|
| 297 |
+
for sub_name, child in module.named_children():
|
| 298 |
+
if "temporal_transformer" not in sub_name:
|
| 299 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 300 |
+
|
| 301 |
+
return processors
|
| 302 |
+
|
| 303 |
+
for name, module in self.named_children():
|
| 304 |
+
if "temporal_transformer" not in name:
|
| 305 |
+
fn_recursive_add_processors(name, module, processors)
|
| 306 |
+
|
| 307 |
+
return processors
|
| 308 |
+
|
| 309 |
+
def set_attention_slice(self, slice_size):
|
| 310 |
+
r"""
|
| 311 |
+
Enable sliced attention computation.
|
| 312 |
+
|
| 313 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
| 314 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
| 315 |
+
|
| 316 |
+
Args:
|
| 317 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
| 318 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
| 319 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
| 320 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
| 321 |
+
must be a multiple of `slice_size`.
|
| 322 |
+
"""
|
| 323 |
+
sliceable_head_dims = []
|
| 324 |
+
|
| 325 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
| 326 |
+
if hasattr(module, "set_attention_slice"):
|
| 327 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
| 328 |
+
|
| 329 |
+
for child in module.children():
|
| 330 |
+
fn_recursive_retrieve_slicable_dims(child)
|
| 331 |
+
|
| 332 |
+
# retrieve number of attention layers
|
| 333 |
+
for module in self.children():
|
| 334 |
+
fn_recursive_retrieve_slicable_dims(module)
|
| 335 |
+
|
| 336 |
+
num_slicable_layers = len(sliceable_head_dims)
|
| 337 |
+
|
| 338 |
+
if slice_size == "auto":
|
| 339 |
+
# half the attention head size is usually a good trade-off between
|
| 340 |
+
# speed and memory
|
| 341 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
| 342 |
+
elif slice_size == "max":
|
| 343 |
+
# make smallest slice possible
|
| 344 |
+
slice_size = num_slicable_layers * [1]
|
| 345 |
+
|
| 346 |
+
slice_size = (
|
| 347 |
+
num_slicable_layers * [slice_size]
|
| 348 |
+
if not isinstance(slice_size, list)
|
| 349 |
+
else slice_size
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
if len(slice_size) != len(sliceable_head_dims):
|
| 353 |
+
raise ValueError(
|
| 354 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
| 355 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
for i in range(len(slice_size)):
|
| 359 |
+
size = slice_size[i]
|
| 360 |
+
dim = sliceable_head_dims[i]
|
| 361 |
+
if size is not None and size > dim:
|
| 362 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
| 363 |
+
|
| 364 |
+
# Recursively walk through all the children.
|
| 365 |
+
# Any children which exposes the set_attention_slice method
|
| 366 |
+
# gets the message
|
| 367 |
+
def fn_recursive_set_attention_slice(
|
| 368 |
+
module: torch.nn.Module, slice_size: List[int]
|
| 369 |
+
):
|
| 370 |
+
if hasattr(module, "set_attention_slice"):
|
| 371 |
+
module.set_attention_slice(slice_size.pop())
|
| 372 |
+
|
| 373 |
+
for child in module.children():
|
| 374 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
| 375 |
+
|
| 376 |
+
reversed_slice_size = list(reversed(slice_size))
|
| 377 |
+
for module in self.children():
|
| 378 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
| 379 |
+
|
| 380 |
+
def set_use_cross_frame_attention(self, value):
|
| 381 |
+
|
| 382 |
+
def fn_recursive_set_use_cf_att(module: torch.nn.Module, value):
|
| 383 |
+
if hasattr(module, "set_use_cross_frame_attention"):
|
| 384 |
+
module.set_use_cross_frame_attention(value)
|
| 385 |
+
|
| 386 |
+
for child in module.children():
|
| 387 |
+
fn_recursive_set_use_cf_att(child, value)
|
| 388 |
+
|
| 389 |
+
for module in self.children():
|
| 390 |
+
fn_recursive_set_use_cf_att(module, value)
|
| 391 |
+
|
| 392 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 393 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 394 |
+
module.gradient_checkpointing = value
|
| 395 |
+
|
| 396 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 397 |
+
def set_attn_processor(
|
| 398 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
|
| 399 |
+
):
|
| 400 |
+
r"""
|
| 401 |
+
Sets the attention processor to use to compute attention.
|
| 402 |
+
|
| 403 |
+
Parameters:
|
| 404 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 405 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 406 |
+
for **all** `Attention` layers.
|
| 407 |
+
|
| 408 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 409 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 410 |
+
|
| 411 |
+
"""
|
| 412 |
+
count = len(self.attn_processors.keys())
|
| 413 |
+
|
| 414 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 415 |
+
raise ValueError(
|
| 416 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 417 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 421 |
+
if hasattr(module, "set_processor"):
|
| 422 |
+
if not isinstance(processor, dict):
|
| 423 |
+
module.set_processor(processor)
|
| 424 |
+
else:
|
| 425 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 426 |
+
|
| 427 |
+
for sub_name, child in module.named_children():
|
| 428 |
+
if "temporal_transformer" not in sub_name:
|
| 429 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 430 |
+
|
| 431 |
+
for name, module in self.named_children():
|
| 432 |
+
if "temporal_transformer" not in name:
|
| 433 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 434 |
+
|
| 435 |
+
def forward(
|
| 436 |
+
self,
|
| 437 |
+
sample: torch.FloatTensor,
|
| 438 |
+
timestep: Union[torch.Tensor, float, int],
|
| 439 |
+
encoder_hidden_states: torch.Tensor,
|
| 440 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 441 |
+
pose_cond_fea = None,
|
| 442 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 443 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 444 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
| 445 |
+
return_dict: bool = True,
|
| 446 |
+
skip_mm: bool = False,
|
| 447 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
| 448 |
+
r"""
|
| 449 |
+
Args:
|
| 450 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
| 451 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
| 452 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
| 453 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 454 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
| 455 |
+
|
| 456 |
+
Returns:
|
| 457 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
| 458 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
| 459 |
+
returning a tuple, the first element is the sample tensor.
|
| 460 |
+
"""
|
| 461 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
| 462 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
| 463 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
| 464 |
+
# on the fly if necessary.
|
| 465 |
+
|
| 466 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
| 467 |
+
|
| 468 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 469 |
+
forward_upsample_size = False
|
| 470 |
+
upsample_size = None
|
| 471 |
+
|
| 472 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
| 473 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
| 474 |
+
forward_upsample_size = True
|
| 475 |
+
|
| 476 |
+
# prepare attention_mask
|
| 477 |
+
if attention_mask is not None:
|
| 478 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 479 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 480 |
+
|
| 481 |
+
# center input if necessary
|
| 482 |
+
if self.config.center_input_sample:
|
| 483 |
+
sample = 2 * sample - 1.0
|
| 484 |
+
|
| 485 |
+
# time
|
| 486 |
+
timesteps = timestep
|
| 487 |
+
if not torch.is_tensor(timesteps):
|
| 488 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 489 |
+
is_mps = sample.device.type == "mps"
|
| 490 |
+
if isinstance(timestep, float):
|
| 491 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 492 |
+
else:
|
| 493 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 494 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 495 |
+
elif len(timesteps.shape) == 0:
|
| 496 |
+
timesteps = timesteps[None].to(sample.device)
|
| 497 |
+
|
| 498 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 499 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 500 |
+
|
| 501 |
+
t_emb = self.time_proj(timesteps)
|
| 502 |
+
|
| 503 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 504 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 505 |
+
# there might be better ways to encapsulate this.
|
| 506 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
| 507 |
+
emb = self.time_embedding(t_emb)
|
| 508 |
+
if self.class_embedding is not None:
|
| 509 |
+
if class_labels is None:
|
| 510 |
+
raise ValueError(
|
| 511 |
+
"class_labels should be provided when num_class_embeds > 0"
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
if self.config.class_embed_type == "timestep":
|
| 515 |
+
class_labels = self.time_proj(class_labels)
|
| 516 |
+
|
| 517 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
| 518 |
+
emb = emb + class_emb
|
| 519 |
+
|
| 520 |
+
# pre-process
|
| 521 |
+
sample = self.conv_in(sample)
|
| 522 |
+
if pose_cond_fea is not None:
|
| 523 |
+
sample = sample + pose_cond_fea[0]
|
| 524 |
+
|
| 525 |
+
# down
|
| 526 |
+
down_block_res_samples = (sample,)
|
| 527 |
+
block_count = 1
|
| 528 |
+
for downsample_block in self.down_blocks:
|
| 529 |
+
if (
|
| 530 |
+
hasattr(downsample_block, "has_cross_attention")
|
| 531 |
+
and downsample_block.has_cross_attention
|
| 532 |
+
):
|
| 533 |
+
sample, res_samples = downsample_block(
|
| 534 |
+
hidden_states=sample,
|
| 535 |
+
temb=emb,
|
| 536 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 537 |
+
attention_mask=attention_mask,
|
| 538 |
+
skip_mm=skip_mm,
|
| 539 |
+
)
|
| 540 |
+
else:
|
| 541 |
+
sample, res_samples = downsample_block(
|
| 542 |
+
hidden_states=sample,
|
| 543 |
+
temb=emb,
|
| 544 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 545 |
+
skip_mm=skip_mm,
|
| 546 |
+
)
|
| 547 |
+
if pose_cond_fea is not None:
|
| 548 |
+
sample = sample + pose_cond_fea[block_count]
|
| 549 |
+
block_count += 1
|
| 550 |
+
down_block_res_samples += res_samples
|
| 551 |
+
|
| 552 |
+
if down_block_additional_residuals is not None:
|
| 553 |
+
new_down_block_res_samples = ()
|
| 554 |
+
|
| 555 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
| 556 |
+
down_block_res_samples, down_block_additional_residuals
|
| 557 |
+
):
|
| 558 |
+
down_block_res_sample = (
|
| 559 |
+
down_block_res_sample + down_block_additional_residual
|
| 560 |
+
)
|
| 561 |
+
new_down_block_res_samples += (down_block_res_sample,)
|
| 562 |
+
|
| 563 |
+
down_block_res_samples = new_down_block_res_samples
|
| 564 |
+
|
| 565 |
+
# mid
|
| 566 |
+
sample = self.mid_block(
|
| 567 |
+
sample,
|
| 568 |
+
emb,
|
| 569 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 570 |
+
attention_mask=attention_mask,
|
| 571 |
+
skip_mm=skip_mm,
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
if mid_block_additional_residual is not None:
|
| 575 |
+
sample = sample + mid_block_additional_residual
|
| 576 |
+
|
| 577 |
+
# up
|
| 578 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 579 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 580 |
+
|
| 581 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 582 |
+
down_block_res_samples = down_block_res_samples[
|
| 583 |
+
: -len(upsample_block.resnets)
|
| 584 |
+
]
|
| 585 |
+
|
| 586 |
+
# if we have not reached the final block and need to forward the
|
| 587 |
+
# upsample size, we do it here
|
| 588 |
+
if not is_final_block and forward_upsample_size:
|
| 589 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 590 |
+
|
| 591 |
+
if (
|
| 592 |
+
hasattr(upsample_block, "has_cross_attention")
|
| 593 |
+
and upsample_block.has_cross_attention
|
| 594 |
+
):
|
| 595 |
+
sample = upsample_block(
|
| 596 |
+
hidden_states=sample,
|
| 597 |
+
temb=emb,
|
| 598 |
+
res_hidden_states_tuple=res_samples,
|
| 599 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 600 |
+
upsample_size=upsample_size,
|
| 601 |
+
attention_mask=attention_mask,
|
| 602 |
+
skip_mm=skip_mm,
|
| 603 |
+
)
|
| 604 |
+
else:
|
| 605 |
+
sample = upsample_block(
|
| 606 |
+
hidden_states=sample,
|
| 607 |
+
temb=emb,
|
| 608 |
+
res_hidden_states_tuple=res_samples,
|
| 609 |
+
upsample_size=upsample_size,
|
| 610 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 611 |
+
skip_mm=skip_mm,
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
# post-process
|
| 615 |
+
sample = self.conv_norm_out(sample)
|
| 616 |
+
sample = self.conv_act(sample)
|
| 617 |
+
sample = self.conv_out(sample)
|
| 618 |
+
|
| 619 |
+
if not return_dict:
|
| 620 |
+
return (sample,)
|
| 621 |
+
|
| 622 |
+
return UNet3DConditionOutput(sample=sample)
|
| 623 |
+
|
| 624 |
+
@classmethod
|
| 625 |
+
def from_pretrained_2d(
|
| 626 |
+
cls,
|
| 627 |
+
pretrained_model_path: PathLike,
|
| 628 |
+
motion_module_path: PathLike,
|
| 629 |
+
subfolder=None,
|
| 630 |
+
unet_additional_kwargs=None,
|
| 631 |
+
mm_zero_proj_out=False,
|
| 632 |
+
):
|
| 633 |
+
pretrained_model_path = Path(pretrained_model_path)
|
| 634 |
+
motion_module_path = Path(motion_module_path)
|
| 635 |
+
if subfolder is not None:
|
| 636 |
+
pretrained_model_path = pretrained_model_path.joinpath(subfolder)
|
| 637 |
+
logger.info(
|
| 638 |
+
f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
config_file = pretrained_model_path / "config.json"
|
| 642 |
+
if not (config_file.exists() and config_file.is_file()):
|
| 643 |
+
raise RuntimeError(f"{config_file} does not exist or is not a file")
|
| 644 |
+
|
| 645 |
+
unet_config = cls.load_config(config_file)
|
| 646 |
+
unet_config["_class_name"] = cls.__name__
|
| 647 |
+
unet_config["down_block_types"] = [
|
| 648 |
+
"CrossAttnDownBlock3D",
|
| 649 |
+
"CrossAttnDownBlock3D",
|
| 650 |
+
"CrossAttnDownBlock3D",
|
| 651 |
+
"DownBlock3D",
|
| 652 |
+
]
|
| 653 |
+
unet_config["up_block_types"] = [
|
| 654 |
+
"UpBlock3D",
|
| 655 |
+
"CrossAttnUpBlock3D",
|
| 656 |
+
"CrossAttnUpBlock3D",
|
| 657 |
+
"CrossAttnUpBlock3D",
|
| 658 |
+
]
|
| 659 |
+
unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
|
| 660 |
+
|
| 661 |
+
model = cls.from_config(unet_config, **unet_additional_kwargs)
|
| 662 |
+
# load the vanilla weights
|
| 663 |
+
if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
|
| 664 |
+
logger.debug(
|
| 665 |
+
f"loading safeTensors weights from {pretrained_model_path} ..."
|
| 666 |
+
)
|
| 667 |
+
state_dict = load_file(
|
| 668 |
+
pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
|
| 672 |
+
logger.debug(f"loading weights from {pretrained_model_path} ...")
|
| 673 |
+
state_dict = torch.load(
|
| 674 |
+
pretrained_model_path.joinpath(WEIGHTS_NAME),
|
| 675 |
+
map_location="cpu",
|
| 676 |
+
weights_only=True,
|
| 677 |
+
)
|
| 678 |
+
else:
|
| 679 |
+
raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
|
| 680 |
+
|
| 681 |
+
# load the motion module weights
|
| 682 |
+
if motion_module_path.exists() and motion_module_path.is_file():
|
| 683 |
+
if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
|
| 684 |
+
logger.info(f"Load motion module params from {motion_module_path}")
|
| 685 |
+
motion_state_dict = torch.load(
|
| 686 |
+
motion_module_path, map_location="cpu", weights_only=True
|
| 687 |
+
)
|
| 688 |
+
elif motion_module_path.suffix.lower() == ".safetensors":
|
| 689 |
+
motion_state_dict = load_file(motion_module_path, device="cpu")
|
| 690 |
+
else:
|
| 691 |
+
raise RuntimeError(
|
| 692 |
+
f"unknown file format for motion module weights: {motion_module_path.suffix}"
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
motion_state_dict = {
|
| 696 |
+
k.replace('motion_modules.', 'temporal_modules.'): v for k, v in motion_state_dict.items() if not "pos_encoder" in k
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
if mm_zero_proj_out:
|
| 700 |
+
logger.info(f"Zero initialize proj_out layers in motion module...")
|
| 701 |
+
new_motion_state_dict = OrderedDict()
|
| 702 |
+
for k in motion_state_dict:
|
| 703 |
+
if "proj_out" in k:
|
| 704 |
+
continue
|
| 705 |
+
new_motion_state_dict[k] = motion_state_dict[k]
|
| 706 |
+
motion_state_dict = new_motion_state_dict
|
| 707 |
+
|
| 708 |
+
# merge the state dicts
|
| 709 |
+
state_dict.update(motion_state_dict)
|
| 710 |
+
|
| 711 |
+
# load the weights into the model
|
| 712 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
| 713 |
+
logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
| 714 |
+
|
| 715 |
+
params = [
|
| 716 |
+
p.numel() if "temporal_modules" in n else 0
|
| 717 |
+
for n, p in model.named_parameters()
|
| 718 |
+
]
|
| 719 |
+
mm_params = [
|
| 720 |
+
p.numel() if "motion_modules" in n else 0
|
| 721 |
+
for n, p in model.named_parameters()
|
| 722 |
+
]
|
| 723 |
+
logger.info(
|
| 724 |
+
f"Loaded {sum(mm_params) / 1e6}M-parameter motion module, Loaded {sum(params) / 1e6}M-parameter temporal module"
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
return model
|
models/unet_3d_blocks.py
ADDED
|
@@ -0,0 +1,1121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# *************************************************************************
|
| 2 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# This file has been modified by ByteDance Ltd. and/or its affiliates.
|
| 6 |
+
#
|
| 7 |
+
# Original file was released under Aniportrait, with the full license text
|
| 8 |
+
# available at https://github.com/Zejun-Yang/AniPortrait/blob/main/LICENSE.
|
| 9 |
+
#
|
| 10 |
+
# This modified file is released under the same license.
|
| 11 |
+
# *************************************************************************
|
| 12 |
+
import pdb
|
| 13 |
+
from typing import Dict, Optional
|
| 14 |
+
import torch
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
from src.models.motion_module import get_motion_module
|
| 18 |
+
|
| 19 |
+
# from .motion_module import get_motion_module
|
| 20 |
+
from src.models.resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
| 21 |
+
from .transformer_3d import Transformer3DModel
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_down_block(
|
| 25 |
+
down_block_type,
|
| 26 |
+
num_layers,
|
| 27 |
+
in_channels,
|
| 28 |
+
out_channels,
|
| 29 |
+
temb_channels,
|
| 30 |
+
add_downsample,
|
| 31 |
+
resnet_eps,
|
| 32 |
+
resnet_act_fn,
|
| 33 |
+
attn_num_head_channels,
|
| 34 |
+
resnet_groups=None,
|
| 35 |
+
cross_attention_dim=None,
|
| 36 |
+
downsample_padding=None,
|
| 37 |
+
dual_cross_attention=False,
|
| 38 |
+
use_linear_projection=False,
|
| 39 |
+
only_cross_attention=False,
|
| 40 |
+
upcast_attention=False,
|
| 41 |
+
resnet_time_scale_shift="default",
|
| 42 |
+
unet_use_cross_frame_attention=None,
|
| 43 |
+
unet_use_temporal_attention=None,
|
| 44 |
+
use_inflated_groupnorm=None,
|
| 45 |
+
use_motion_module=None,
|
| 46 |
+
motion_module_type=None,
|
| 47 |
+
motion_module_kwargs=None,
|
| 48 |
+
use_temporal_module=None,
|
| 49 |
+
temporal_module_type=None,
|
| 50 |
+
temporal_module_kwargs=None,
|
| 51 |
+
):
|
| 52 |
+
down_block_type = (
|
| 53 |
+
down_block_type[7:]
|
| 54 |
+
if down_block_type.startswith("UNetRes")
|
| 55 |
+
else down_block_type
|
| 56 |
+
)
|
| 57 |
+
if down_block_type == "DownBlock3D":
|
| 58 |
+
return DownBlock3D(
|
| 59 |
+
num_layers=num_layers,
|
| 60 |
+
in_channels=in_channels,
|
| 61 |
+
out_channels=out_channels,
|
| 62 |
+
temb_channels=temb_channels,
|
| 63 |
+
add_downsample=add_downsample,
|
| 64 |
+
resnet_eps=resnet_eps,
|
| 65 |
+
resnet_act_fn=resnet_act_fn,
|
| 66 |
+
resnet_groups=resnet_groups,
|
| 67 |
+
downsample_padding=downsample_padding,
|
| 68 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 69 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 70 |
+
use_motion_module=use_motion_module,
|
| 71 |
+
motion_module_type=motion_module_type,
|
| 72 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 73 |
+
use_temporal_module=use_temporal_module,
|
| 74 |
+
temporal_module_type=temporal_module_type,
|
| 75 |
+
temporal_module_kwargs=temporal_module_kwargs,
|
| 76 |
+
)
|
| 77 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
| 78 |
+
if cross_attention_dim is None:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock3D"
|
| 81 |
+
)
|
| 82 |
+
return CrossAttnDownBlock3D(
|
| 83 |
+
num_layers=num_layers,
|
| 84 |
+
in_channels=in_channels,
|
| 85 |
+
out_channels=out_channels,
|
| 86 |
+
temb_channels=temb_channels,
|
| 87 |
+
add_downsample=add_downsample,
|
| 88 |
+
resnet_eps=resnet_eps,
|
| 89 |
+
resnet_act_fn=resnet_act_fn,
|
| 90 |
+
resnet_groups=resnet_groups,
|
| 91 |
+
downsample_padding=downsample_padding,
|
| 92 |
+
cross_attention_dim=cross_attention_dim,
|
| 93 |
+
attn_num_head_channels=attn_num_head_channels,
|
| 94 |
+
dual_cross_attention=dual_cross_attention,
|
| 95 |
+
use_linear_projection=use_linear_projection,
|
| 96 |
+
only_cross_attention=only_cross_attention,
|
| 97 |
+
upcast_attention=upcast_attention,
|
| 98 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 99 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 100 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 101 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 102 |
+
use_motion_module=use_motion_module,
|
| 103 |
+
motion_module_type=motion_module_type,
|
| 104 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 105 |
+
use_temporal_module=use_temporal_module,
|
| 106 |
+
temporal_module_type=temporal_module_type,
|
| 107 |
+
temporal_module_kwargs=temporal_module_kwargs,
|
| 108 |
+
)
|
| 109 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_up_block(
|
| 113 |
+
up_block_type,
|
| 114 |
+
num_layers,
|
| 115 |
+
in_channels,
|
| 116 |
+
out_channels,
|
| 117 |
+
prev_output_channel,
|
| 118 |
+
temb_channels,
|
| 119 |
+
add_upsample,
|
| 120 |
+
resnet_eps,
|
| 121 |
+
resnet_act_fn,
|
| 122 |
+
attn_num_head_channels,
|
| 123 |
+
resnet_groups=None,
|
| 124 |
+
cross_attention_dim=None,
|
| 125 |
+
dual_cross_attention=False,
|
| 126 |
+
use_linear_projection=False,
|
| 127 |
+
only_cross_attention=False,
|
| 128 |
+
upcast_attention=False,
|
| 129 |
+
resnet_time_scale_shift="default",
|
| 130 |
+
unet_use_cross_frame_attention=None,
|
| 131 |
+
unet_use_temporal_attention=None,
|
| 132 |
+
use_inflated_groupnorm=None,
|
| 133 |
+
use_motion_module=None,
|
| 134 |
+
motion_module_type=None,
|
| 135 |
+
motion_module_kwargs=None,
|
| 136 |
+
use_temporal_module=None,
|
| 137 |
+
temporal_module_type=None,
|
| 138 |
+
temporal_module_kwargs=None,
|
| 139 |
+
):
|
| 140 |
+
up_block_type = (
|
| 141 |
+
up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
| 142 |
+
)
|
| 143 |
+
if up_block_type == "UpBlock3D":
|
| 144 |
+
return UpBlock3D(
|
| 145 |
+
num_layers=num_layers,
|
| 146 |
+
in_channels=in_channels,
|
| 147 |
+
out_channels=out_channels,
|
| 148 |
+
prev_output_channel=prev_output_channel,
|
| 149 |
+
temb_channels=temb_channels,
|
| 150 |
+
add_upsample=add_upsample,
|
| 151 |
+
resnet_eps=resnet_eps,
|
| 152 |
+
resnet_act_fn=resnet_act_fn,
|
| 153 |
+
resnet_groups=resnet_groups,
|
| 154 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 155 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 156 |
+
use_motion_module=use_motion_module,
|
| 157 |
+
motion_module_type=motion_module_type,
|
| 158 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 159 |
+
use_temporal_module=use_temporal_module,
|
| 160 |
+
temporal_module_type=temporal_module_type,
|
| 161 |
+
temporal_module_kwargs=temporal_module_kwargs,
|
| 162 |
+
)
|
| 163 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
| 164 |
+
if cross_attention_dim is None:
|
| 165 |
+
raise ValueError(
|
| 166 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock3D"
|
| 167 |
+
)
|
| 168 |
+
return CrossAttnUpBlock3D(
|
| 169 |
+
num_layers=num_layers,
|
| 170 |
+
in_channels=in_channels,
|
| 171 |
+
out_channels=out_channels,
|
| 172 |
+
prev_output_channel=prev_output_channel,
|
| 173 |
+
temb_channels=temb_channels,
|
| 174 |
+
add_upsample=add_upsample,
|
| 175 |
+
resnet_eps=resnet_eps,
|
| 176 |
+
resnet_act_fn=resnet_act_fn,
|
| 177 |
+
resnet_groups=resnet_groups,
|
| 178 |
+
cross_attention_dim=cross_attention_dim,
|
| 179 |
+
attn_num_head_channels=attn_num_head_channels,
|
| 180 |
+
dual_cross_attention=dual_cross_attention,
|
| 181 |
+
use_linear_projection=use_linear_projection,
|
| 182 |
+
only_cross_attention=only_cross_attention,
|
| 183 |
+
upcast_attention=upcast_attention,
|
| 184 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 185 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 186 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 187 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 188 |
+
use_motion_module=use_motion_module,
|
| 189 |
+
motion_module_type=motion_module_type,
|
| 190 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 191 |
+
use_temporal_module=use_temporal_module,
|
| 192 |
+
temporal_module_type=temporal_module_type,
|
| 193 |
+
temporal_module_kwargs=temporal_module_kwargs,
|
| 194 |
+
)
|
| 195 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
| 199 |
+
|
| 200 |
+
def __init__(
|
| 201 |
+
self,
|
| 202 |
+
in_channels: int,
|
| 203 |
+
temb_channels: int,
|
| 204 |
+
dropout: float = 0.0,
|
| 205 |
+
num_layers: int = 1,
|
| 206 |
+
resnet_eps: float = 1e-6,
|
| 207 |
+
resnet_time_scale_shift: str = "default",
|
| 208 |
+
resnet_act_fn: str = "swish",
|
| 209 |
+
resnet_groups: int = 32,
|
| 210 |
+
resnet_pre_norm: bool = True,
|
| 211 |
+
attn_num_head_channels=1,
|
| 212 |
+
output_scale_factor=1.0,
|
| 213 |
+
cross_attention_dim=1280,
|
| 214 |
+
dual_cross_attention=False,
|
| 215 |
+
use_linear_projection=False,
|
| 216 |
+
upcast_attention=False,
|
| 217 |
+
unet_use_cross_frame_attention=None,
|
| 218 |
+
unet_use_temporal_attention=None,
|
| 219 |
+
use_inflated_groupnorm=None,
|
| 220 |
+
use_motion_module=None,
|
| 221 |
+
motion_module_type=None,
|
| 222 |
+
motion_module_kwargs=None,
|
| 223 |
+
use_temporal_module=None,
|
| 224 |
+
temporal_module_type=None,
|
| 225 |
+
temporal_module_kwargs=None,
|
| 226 |
+
**transformer_kwargs,
|
| 227 |
+
):
|
| 228 |
+
super().__init__()
|
| 229 |
+
|
| 230 |
+
self.has_cross_attention = True
|
| 231 |
+
self.attn_num_head_channels = attn_num_head_channels
|
| 232 |
+
resnet_groups = (
|
| 233 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# there is always at least one resnet
|
| 237 |
+
resnets = [
|
| 238 |
+
ResnetBlock3D(
|
| 239 |
+
in_channels=in_channels,
|
| 240 |
+
out_channels=in_channels,
|
| 241 |
+
temb_channels=temb_channels,
|
| 242 |
+
eps=resnet_eps,
|
| 243 |
+
groups=resnet_groups,
|
| 244 |
+
dropout=dropout,
|
| 245 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 246 |
+
non_linearity=resnet_act_fn,
|
| 247 |
+
output_scale_factor=output_scale_factor,
|
| 248 |
+
pre_norm=resnet_pre_norm,
|
| 249 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 250 |
+
)
|
| 251 |
+
]
|
| 252 |
+
attentions = []
|
| 253 |
+
motion_modules = []
|
| 254 |
+
|
| 255 |
+
for _ in range(num_layers):
|
| 256 |
+
if dual_cross_attention:
|
| 257 |
+
raise NotImplementedError
|
| 258 |
+
attentions.append(
|
| 259 |
+
Transformer3DModel(
|
| 260 |
+
attn_num_head_channels,
|
| 261 |
+
in_channels // attn_num_head_channels,
|
| 262 |
+
in_channels=in_channels,
|
| 263 |
+
num_layers=1,
|
| 264 |
+
cross_attention_dim=cross_attention_dim,
|
| 265 |
+
norm_num_groups=resnet_groups,
|
| 266 |
+
use_linear_projection=use_linear_projection,
|
| 267 |
+
upcast_attention=upcast_attention,
|
| 268 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 269 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 270 |
+
**transformer_kwargs
|
| 271 |
+
)
|
| 272 |
+
)
|
| 273 |
+
motion_modules.append(
|
| 274 |
+
get_motion_module(
|
| 275 |
+
in_channels=in_channels,
|
| 276 |
+
motion_module_type=motion_module_type,
|
| 277 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 278 |
+
)
|
| 279 |
+
if use_motion_module
|
| 280 |
+
else None
|
| 281 |
+
)
|
| 282 |
+
resnets.append(
|
| 283 |
+
ResnetBlock3D(
|
| 284 |
+
in_channels=in_channels,
|
| 285 |
+
out_channels=in_channels,
|
| 286 |
+
temb_channels=temb_channels,
|
| 287 |
+
eps=resnet_eps,
|
| 288 |
+
groups=resnet_groups,
|
| 289 |
+
dropout=dropout,
|
| 290 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 291 |
+
non_linearity=resnet_act_fn,
|
| 292 |
+
output_scale_factor=output_scale_factor,
|
| 293 |
+
pre_norm=resnet_pre_norm,
|
| 294 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 295 |
+
)
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
self.attentions = nn.ModuleList(attentions)
|
| 299 |
+
self.resnets = nn.ModuleList(resnets)
|
| 300 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
| 301 |
+
self.temporal_modules = nn.ModuleList(
|
| 302 |
+
[
|
| 303 |
+
(
|
| 304 |
+
get_motion_module(
|
| 305 |
+
in_channels=in_channels,
|
| 306 |
+
motion_module_type=temporal_module_type,
|
| 307 |
+
motion_module_kwargs=temporal_module_kwargs,
|
| 308 |
+
)
|
| 309 |
+
if use_temporal_module
|
| 310 |
+
else None
|
| 311 |
+
)
|
| 312 |
+
for _ in range(num_layers)
|
| 313 |
+
]
|
| 314 |
+
)
|
| 315 |
+
self.gradient_checkpointing = False
|
| 316 |
+
|
| 317 |
+
def forward(
|
| 318 |
+
self,
|
| 319 |
+
hidden_states,
|
| 320 |
+
temb=None,
|
| 321 |
+
encoder_hidden_states=None,
|
| 322 |
+
attention_mask=None,
|
| 323 |
+
skip_mm=False,
|
| 324 |
+
):
|
| 325 |
+
if isinstance(encoder_hidden_states, list):
|
| 326 |
+
encoder_hidden_states, motion_hidden_states = encoder_hidden_states
|
| 327 |
+
else:
|
| 328 |
+
motion_hidden_states = encoder_hidden_states
|
| 329 |
+
|
| 330 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
| 331 |
+
for attn, resnet, motion_module, temporal_module in zip(
|
| 332 |
+
self.attentions, self.resnets[1:], self.motion_modules, self.temporal_modules
|
| 333 |
+
):
|
| 334 |
+
if self.training and self.gradient_checkpointing:
|
| 335 |
+
def create_custom_forward(module, return_dict=None):
|
| 336 |
+
def custom_forward(*inputs):
|
| 337 |
+
if return_dict is not None:
|
| 338 |
+
return module(*inputs, return_dict=return_dict)
|
| 339 |
+
else:
|
| 340 |
+
return module(*inputs)
|
| 341 |
+
|
| 342 |
+
return custom_forward
|
| 343 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 344 |
+
create_custom_forward(attn, return_dict=False),
|
| 345 |
+
hidden_states,
|
| 346 |
+
encoder_hidden_states,
|
| 347 |
+
)[0]
|
| 348 |
+
if (motion_module is not None) and not skip_mm:
|
| 349 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 350 |
+
create_custom_forward(motion_module),
|
| 351 |
+
hidden_states,
|
| 352 |
+
temb,
|
| 353 |
+
motion_hidden_states,
|
| 354 |
+
)
|
| 355 |
+
if (temporal_module is not None) and not skip_mm:
|
| 356 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 357 |
+
create_custom_forward(temporal_module),
|
| 358 |
+
hidden_states.requires_grad_(),
|
| 359 |
+
temb,
|
| 360 |
+
None,
|
| 361 |
+
)
|
| 362 |
+
# hidden_states = (
|
| 363 |
+
# temporal_module(hidden_states, temb, encoder_hidden_states=None)
|
| 364 |
+
# if (temporal_module is not None) and not skip_mm
|
| 365 |
+
# else hidden_states
|
| 366 |
+
# )
|
| 367 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 368 |
+
create_custom_forward(resnet), hidden_states, temb
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
hidden_states = attn(
|
| 372 |
+
hidden_states,
|
| 373 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 374 |
+
).sample
|
| 375 |
+
hidden_states = (
|
| 376 |
+
motion_module(
|
| 377 |
+
hidden_states, temb, encoder_hidden_states=motion_hidden_states
|
| 378 |
+
)
|
| 379 |
+
if (motion_module is not None) and not skip_mm
|
| 380 |
+
else hidden_states
|
| 381 |
+
)
|
| 382 |
+
hidden_states = (
|
| 383 |
+
temporal_module(hidden_states, temb, encoder_hidden_states=None, debug=True)
|
| 384 |
+
if (temporal_module is not None) and not skip_mm
|
| 385 |
+
else hidden_states
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
hidden_states = resnet(hidden_states, temb)
|
| 389 |
+
|
| 390 |
+
return hidden_states
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class CrossAttnDownBlock3D(nn.Module):
|
| 394 |
+
|
| 395 |
+
def __init__(
|
| 396 |
+
self,
|
| 397 |
+
in_channels: int,
|
| 398 |
+
out_channels: int,
|
| 399 |
+
temb_channels: int,
|
| 400 |
+
dropout: float = 0.0,
|
| 401 |
+
num_layers: int = 1,
|
| 402 |
+
resnet_eps: float = 1e-6,
|
| 403 |
+
resnet_time_scale_shift: str = "default",
|
| 404 |
+
resnet_act_fn: str = "swish",
|
| 405 |
+
resnet_groups: int = 32,
|
| 406 |
+
resnet_pre_norm: bool = True,
|
| 407 |
+
attn_num_head_channels=1,
|
| 408 |
+
cross_attention_dim=1280,
|
| 409 |
+
output_scale_factor=1.0,
|
| 410 |
+
downsample_padding=1,
|
| 411 |
+
add_downsample=True,
|
| 412 |
+
dual_cross_attention=False,
|
| 413 |
+
use_linear_projection=False,
|
| 414 |
+
only_cross_attention=False,
|
| 415 |
+
upcast_attention=False,
|
| 416 |
+
unet_use_cross_frame_attention=None,
|
| 417 |
+
unet_use_temporal_attention=None,
|
| 418 |
+
use_inflated_groupnorm=None,
|
| 419 |
+
use_motion_module=None,
|
| 420 |
+
motion_module_type=None,
|
| 421 |
+
motion_module_kwargs=None,
|
| 422 |
+
use_temporal_module=None,
|
| 423 |
+
temporal_module_type=None,
|
| 424 |
+
temporal_module_kwargs=None,
|
| 425 |
+
**transformer_kwargs,
|
| 426 |
+
):
|
| 427 |
+
super().__init__()
|
| 428 |
+
resnets = []
|
| 429 |
+
attentions = []
|
| 430 |
+
motion_modules = []
|
| 431 |
+
|
| 432 |
+
self.has_cross_attention = True
|
| 433 |
+
self.attn_num_head_channels = attn_num_head_channels
|
| 434 |
+
|
| 435 |
+
for i in range(num_layers):
|
| 436 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 437 |
+
resnets.append(
|
| 438 |
+
ResnetBlock3D(
|
| 439 |
+
in_channels=in_channels,
|
| 440 |
+
out_channels=out_channels,
|
| 441 |
+
temb_channels=temb_channels,
|
| 442 |
+
eps=resnet_eps,
|
| 443 |
+
groups=resnet_groups,
|
| 444 |
+
dropout=dropout,
|
| 445 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 446 |
+
non_linearity=resnet_act_fn,
|
| 447 |
+
output_scale_factor=output_scale_factor,
|
| 448 |
+
pre_norm=resnet_pre_norm,
|
| 449 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 450 |
+
)
|
| 451 |
+
)
|
| 452 |
+
if dual_cross_attention:
|
| 453 |
+
raise NotImplementedError
|
| 454 |
+
attentions.append(
|
| 455 |
+
Transformer3DModel(
|
| 456 |
+
attn_num_head_channels,
|
| 457 |
+
out_channels // attn_num_head_channels,
|
| 458 |
+
in_channels=out_channels,
|
| 459 |
+
num_layers=1,
|
| 460 |
+
cross_attention_dim=cross_attention_dim,
|
| 461 |
+
norm_num_groups=resnet_groups,
|
| 462 |
+
use_linear_projection=use_linear_projection,
|
| 463 |
+
only_cross_attention=only_cross_attention,
|
| 464 |
+
upcast_attention=upcast_attention,
|
| 465 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 466 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 467 |
+
**transformer_kwargs,
|
| 468 |
+
)
|
| 469 |
+
)
|
| 470 |
+
motion_modules.append(
|
| 471 |
+
get_motion_module(
|
| 472 |
+
in_channels=out_channels,
|
| 473 |
+
motion_module_type=motion_module_type,
|
| 474 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 475 |
+
)
|
| 476 |
+
if use_motion_module
|
| 477 |
+
else None
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
self.attentions = nn.ModuleList(attentions)
|
| 481 |
+
self.resnets = nn.ModuleList(resnets)
|
| 482 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
| 483 |
+
self.temporal_modules = nn.ModuleList(
|
| 484 |
+
[
|
| 485 |
+
(
|
| 486 |
+
get_motion_module(
|
| 487 |
+
in_channels=out_channels,
|
| 488 |
+
motion_module_type=temporal_module_type,
|
| 489 |
+
motion_module_kwargs=temporal_module_kwargs,
|
| 490 |
+
)
|
| 491 |
+
if use_temporal_module
|
| 492 |
+
else None
|
| 493 |
+
)
|
| 494 |
+
for _ in range(num_layers)
|
| 495 |
+
]
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
if add_downsample:
|
| 499 |
+
self.downsamplers = nn.ModuleList(
|
| 500 |
+
[
|
| 501 |
+
Downsample3D(
|
| 502 |
+
out_channels,
|
| 503 |
+
use_conv=True,
|
| 504 |
+
out_channels=out_channels,
|
| 505 |
+
padding=downsample_padding,
|
| 506 |
+
name="op",
|
| 507 |
+
)
|
| 508 |
+
]
|
| 509 |
+
)
|
| 510 |
+
else:
|
| 511 |
+
self.downsamplers = None
|
| 512 |
+
|
| 513 |
+
self.gradient_checkpointing = False
|
| 514 |
+
|
| 515 |
+
def forward(
|
| 516 |
+
self,
|
| 517 |
+
hidden_states,
|
| 518 |
+
temb=None,
|
| 519 |
+
encoder_hidden_states=None,
|
| 520 |
+
attention_mask=None,
|
| 521 |
+
skip_mm=False
|
| 522 |
+
):
|
| 523 |
+
if isinstance(encoder_hidden_states, list):
|
| 524 |
+
encoder_hidden_states, motion_hidden_states = encoder_hidden_states
|
| 525 |
+
else:
|
| 526 |
+
motion_hidden_states = encoder_hidden_states
|
| 527 |
+
|
| 528 |
+
output_states = ()
|
| 529 |
+
|
| 530 |
+
for i, (resnet, attn, motion_module, temporal_module) in enumerate(
|
| 531 |
+
zip(self.resnets, self.attentions, self.motion_modules, self.temporal_modules)
|
| 532 |
+
):
|
| 533 |
+
|
| 534 |
+
# self.gradient_checkpointing = False
|
| 535 |
+
if self.training and self.gradient_checkpointing:
|
| 536 |
+
|
| 537 |
+
def create_custom_forward(module, return_dict=None):
|
| 538 |
+
def custom_forward(*inputs):
|
| 539 |
+
if return_dict is not None:
|
| 540 |
+
return module(*inputs, return_dict=return_dict)
|
| 541 |
+
else:
|
| 542 |
+
return module(*inputs)
|
| 543 |
+
|
| 544 |
+
return custom_forward
|
| 545 |
+
|
| 546 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 547 |
+
create_custom_forward(resnet), hidden_states, temb
|
| 548 |
+
)
|
| 549 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 550 |
+
create_custom_forward(attn, return_dict=False),
|
| 551 |
+
hidden_states,
|
| 552 |
+
encoder_hidden_states,
|
| 553 |
+
)[0]
|
| 554 |
+
|
| 555 |
+
# add motion module
|
| 556 |
+
if (motion_module is not None) and not skip_mm:
|
| 557 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 558 |
+
create_custom_forward(motion_module),
|
| 559 |
+
hidden_states,
|
| 560 |
+
temb,
|
| 561 |
+
motion_hidden_states,
|
| 562 |
+
)
|
| 563 |
+
if (temporal_module is not None) and not skip_mm:
|
| 564 |
+
# hidden_states = torch.utils.checkpoint.checkpoint(
|
| 565 |
+
# create_custom_forward(temporal_module),
|
| 566 |
+
# hidden_states.requires_grad_(),
|
| 567 |
+
# temb,
|
| 568 |
+
# None,
|
| 569 |
+
# )
|
| 570 |
+
hidden_states = (
|
| 571 |
+
temporal_module(hidden_states, temb, encoder_hidden_states=None)
|
| 572 |
+
if (temporal_module is not None) and not skip_mm
|
| 573 |
+
else hidden_states
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
else:
|
| 577 |
+
hidden_states = resnet(hidden_states, temb)
|
| 578 |
+
hidden_states = attn(
|
| 579 |
+
hidden_states,
|
| 580 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 581 |
+
).sample
|
| 582 |
+
|
| 583 |
+
# add motion module
|
| 584 |
+
hidden_states = (
|
| 585 |
+
motion_module(
|
| 586 |
+
hidden_states, temb, encoder_hidden_states=motion_hidden_states
|
| 587 |
+
)
|
| 588 |
+
if (motion_module is not None) and not skip_mm
|
| 589 |
+
else hidden_states
|
| 590 |
+
)
|
| 591 |
+
hidden_states = (
|
| 592 |
+
temporal_module(hidden_states, temb, encoder_hidden_states=None, debug=True)
|
| 593 |
+
if (temporal_module is not None) and not skip_mm
|
| 594 |
+
else hidden_states
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
output_states += (hidden_states,)
|
| 598 |
+
|
| 599 |
+
if self.downsamplers is not None:
|
| 600 |
+
for downsampler in self.downsamplers:
|
| 601 |
+
hidden_states = downsampler(hidden_states)
|
| 602 |
+
|
| 603 |
+
output_states += (hidden_states,)
|
| 604 |
+
|
| 605 |
+
return hidden_states, output_states
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
class DownBlock3D(nn.Module):
|
| 609 |
+
|
| 610 |
+
def __init__(
|
| 611 |
+
self,
|
| 612 |
+
in_channels: int,
|
| 613 |
+
out_channels: int,
|
| 614 |
+
temb_channels: int,
|
| 615 |
+
dropout: float = 0.0,
|
| 616 |
+
num_layers: int = 1,
|
| 617 |
+
resnet_eps: float = 1e-6,
|
| 618 |
+
resnet_time_scale_shift: str = "default",
|
| 619 |
+
resnet_act_fn: str = "swish",
|
| 620 |
+
resnet_groups: int = 32,
|
| 621 |
+
resnet_pre_norm: bool = True,
|
| 622 |
+
output_scale_factor=1.0,
|
| 623 |
+
add_downsample=True,
|
| 624 |
+
downsample_padding=1,
|
| 625 |
+
use_inflated_groupnorm=None,
|
| 626 |
+
use_motion_module=None,
|
| 627 |
+
motion_module_type=None,
|
| 628 |
+
motion_module_kwargs=None,
|
| 629 |
+
use_temporal_module=None,
|
| 630 |
+
temporal_module_type=None,
|
| 631 |
+
temporal_module_kwargs=None,
|
| 632 |
+
):
|
| 633 |
+
super().__init__()
|
| 634 |
+
resnets = []
|
| 635 |
+
motion_modules = []
|
| 636 |
+
|
| 637 |
+
# use_motion_module = False
|
| 638 |
+
for i in range(num_layers):
|
| 639 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 640 |
+
resnets.append(
|
| 641 |
+
ResnetBlock3D(
|
| 642 |
+
in_channels=in_channels,
|
| 643 |
+
out_channels=out_channels,
|
| 644 |
+
temb_channels=temb_channels,
|
| 645 |
+
eps=resnet_eps,
|
| 646 |
+
groups=resnet_groups,
|
| 647 |
+
dropout=dropout,
|
| 648 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 649 |
+
non_linearity=resnet_act_fn,
|
| 650 |
+
output_scale_factor=output_scale_factor,
|
| 651 |
+
pre_norm=resnet_pre_norm,
|
| 652 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 653 |
+
)
|
| 654 |
+
)
|
| 655 |
+
motion_modules.append(
|
| 656 |
+
get_motion_module(
|
| 657 |
+
in_channels=out_channels,
|
| 658 |
+
motion_module_type=motion_module_type,
|
| 659 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 660 |
+
)
|
| 661 |
+
if use_motion_module
|
| 662 |
+
else None
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
self.resnets = nn.ModuleList(resnets)
|
| 666 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
| 667 |
+
self.temporal_modules = nn.ModuleList(
|
| 668 |
+
[
|
| 669 |
+
(
|
| 670 |
+
get_motion_module(
|
| 671 |
+
in_channels=out_channels,
|
| 672 |
+
motion_module_type=temporal_module_type,
|
| 673 |
+
motion_module_kwargs=temporal_module_kwargs,
|
| 674 |
+
)
|
| 675 |
+
if use_temporal_module
|
| 676 |
+
else None
|
| 677 |
+
)
|
| 678 |
+
for _ in range(num_layers)
|
| 679 |
+
]
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
if add_downsample:
|
| 683 |
+
self.downsamplers = nn.ModuleList(
|
| 684 |
+
[
|
| 685 |
+
Downsample3D(
|
| 686 |
+
out_channels,
|
| 687 |
+
use_conv=True,
|
| 688 |
+
out_channels=out_channels,
|
| 689 |
+
padding=downsample_padding,
|
| 690 |
+
name="op",
|
| 691 |
+
)
|
| 692 |
+
]
|
| 693 |
+
)
|
| 694 |
+
else:
|
| 695 |
+
self.downsamplers = None
|
| 696 |
+
|
| 697 |
+
self.gradient_checkpointing = False
|
| 698 |
+
|
| 699 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, skip_mm=False):
|
| 700 |
+
output_states = ()
|
| 701 |
+
if isinstance(encoder_hidden_states, list):
|
| 702 |
+
encoder_hidden_states, motion_hidden_states = encoder_hidden_states
|
| 703 |
+
else:
|
| 704 |
+
motion_hidden_states = encoder_hidden_states
|
| 705 |
+
for resnet, motion_module, temporal_module in zip(
|
| 706 |
+
self.resnets, self.motion_modules, self.temporal_modules
|
| 707 |
+
):
|
| 708 |
+
# print(f"DownBlock3D {self.gradient_checkpointing = }")
|
| 709 |
+
if self.training and self.gradient_checkpointing:
|
| 710 |
+
|
| 711 |
+
def create_custom_forward(module):
|
| 712 |
+
def custom_forward(*inputs):
|
| 713 |
+
return module(*inputs)
|
| 714 |
+
|
| 715 |
+
return custom_forward
|
| 716 |
+
|
| 717 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 718 |
+
create_custom_forward(resnet), hidden_states, temb
|
| 719 |
+
)
|
| 720 |
+
if (motion_module is not None) and not skip_mm:
|
| 721 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 722 |
+
create_custom_forward(motion_module),
|
| 723 |
+
hidden_states,
|
| 724 |
+
temb,
|
| 725 |
+
motion_hidden_states,
|
| 726 |
+
)
|
| 727 |
+
|
| 728 |
+
if (temporal_module is not None) and not skip_mm:
|
| 729 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 730 |
+
create_custom_forward(temporal_module),
|
| 731 |
+
hidden_states.requires_grad_(),
|
| 732 |
+
temb,
|
| 733 |
+
None,
|
| 734 |
+
)
|
| 735 |
+
else:
|
| 736 |
+
hidden_states = resnet(hidden_states, temb)
|
| 737 |
+
|
| 738 |
+
# add motion module
|
| 739 |
+
hidden_states = (
|
| 740 |
+
motion_module(
|
| 741 |
+
hidden_states, temb, encoder_hidden_states=motion_hidden_states
|
| 742 |
+
)
|
| 743 |
+
if (motion_module is not None) and not skip_mm
|
| 744 |
+
else hidden_states
|
| 745 |
+
)
|
| 746 |
+
hidden_states = (
|
| 747 |
+
temporal_module(
|
| 748 |
+
hidden_states, temb, encoder_hidden_states=None, debug=True
|
| 749 |
+
)
|
| 750 |
+
if (temporal_module is not None) and not skip_mm
|
| 751 |
+
else hidden_states
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
output_states += (hidden_states,)
|
| 755 |
+
|
| 756 |
+
if self.downsamplers is not None:
|
| 757 |
+
for downsampler in self.downsamplers:
|
| 758 |
+
hidden_states = downsampler(hidden_states)
|
| 759 |
+
|
| 760 |
+
output_states += (hidden_states,)
|
| 761 |
+
|
| 762 |
+
return hidden_states, output_states
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
class CrossAttnUpBlock3D(nn.Module):
|
| 766 |
+
|
| 767 |
+
def __init__(
|
| 768 |
+
self,
|
| 769 |
+
in_channels: int,
|
| 770 |
+
out_channels: int,
|
| 771 |
+
prev_output_channel: int,
|
| 772 |
+
temb_channels: int,
|
| 773 |
+
dropout: float = 0.0,
|
| 774 |
+
num_layers: int = 1,
|
| 775 |
+
resnet_eps: float = 1e-6,
|
| 776 |
+
resnet_time_scale_shift: str = "default",
|
| 777 |
+
resnet_act_fn: str = "swish",
|
| 778 |
+
resnet_groups: int = 32,
|
| 779 |
+
resnet_pre_norm: bool = True,
|
| 780 |
+
attn_num_head_channels=1,
|
| 781 |
+
cross_attention_dim=1280,
|
| 782 |
+
output_scale_factor=1.0,
|
| 783 |
+
add_upsample=True,
|
| 784 |
+
dual_cross_attention=False,
|
| 785 |
+
use_linear_projection=False,
|
| 786 |
+
only_cross_attention=False,
|
| 787 |
+
upcast_attention=False,
|
| 788 |
+
unet_use_cross_frame_attention=None,
|
| 789 |
+
unet_use_temporal_attention=None,
|
| 790 |
+
use_motion_module=None,
|
| 791 |
+
use_inflated_groupnorm=None,
|
| 792 |
+
motion_module_type=None,
|
| 793 |
+
motion_module_kwargs=None,
|
| 794 |
+
use_temporal_module=None,
|
| 795 |
+
temporal_module_type=None,
|
| 796 |
+
temporal_module_kwargs=None,
|
| 797 |
+
**transformer_kwargs,
|
| 798 |
+
):
|
| 799 |
+
super().__init__()
|
| 800 |
+
resnets = []
|
| 801 |
+
attentions = []
|
| 802 |
+
motion_modules = []
|
| 803 |
+
|
| 804 |
+
self.has_cross_attention = True
|
| 805 |
+
self.attn_num_head_channels = attn_num_head_channels
|
| 806 |
+
|
| 807 |
+
for i in range(num_layers):
|
| 808 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
| 809 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
| 810 |
+
|
| 811 |
+
resnets.append(
|
| 812 |
+
ResnetBlock3D(
|
| 813 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
| 814 |
+
out_channels=out_channels,
|
| 815 |
+
temb_channels=temb_channels,
|
| 816 |
+
eps=resnet_eps,
|
| 817 |
+
groups=resnet_groups,
|
| 818 |
+
dropout=dropout,
|
| 819 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 820 |
+
non_linearity=resnet_act_fn,
|
| 821 |
+
output_scale_factor=output_scale_factor,
|
| 822 |
+
pre_norm=resnet_pre_norm,
|
| 823 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 824 |
+
)
|
| 825 |
+
)
|
| 826 |
+
if dual_cross_attention:
|
| 827 |
+
raise NotImplementedError
|
| 828 |
+
attentions.append(
|
| 829 |
+
Transformer3DModel(
|
| 830 |
+
attn_num_head_channels,
|
| 831 |
+
out_channels // attn_num_head_channels,
|
| 832 |
+
in_channels=out_channels,
|
| 833 |
+
num_layers=1,
|
| 834 |
+
cross_attention_dim=cross_attention_dim,
|
| 835 |
+
norm_num_groups=resnet_groups,
|
| 836 |
+
use_linear_projection=use_linear_projection,
|
| 837 |
+
only_cross_attention=only_cross_attention,
|
| 838 |
+
upcast_attention=upcast_attention,
|
| 839 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
| 840 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
| 841 |
+
**transformer_kwargs,
|
| 842 |
+
)
|
| 843 |
+
)
|
| 844 |
+
motion_modules.append(
|
| 845 |
+
get_motion_module(
|
| 846 |
+
in_channels=out_channels,
|
| 847 |
+
motion_module_type=motion_module_type,
|
| 848 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 849 |
+
)
|
| 850 |
+
if use_motion_module
|
| 851 |
+
else None
|
| 852 |
+
)
|
| 853 |
+
|
| 854 |
+
self.attentions = nn.ModuleList(attentions)
|
| 855 |
+
self.resnets = nn.ModuleList(resnets)
|
| 856 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
| 857 |
+
self.temporal_modules = nn.ModuleList(
|
| 858 |
+
[
|
| 859 |
+
(
|
| 860 |
+
get_motion_module(
|
| 861 |
+
in_channels=out_channels,
|
| 862 |
+
motion_module_type=temporal_module_type,
|
| 863 |
+
motion_module_kwargs=temporal_module_kwargs,
|
| 864 |
+
)
|
| 865 |
+
if use_temporal_module
|
| 866 |
+
else None
|
| 867 |
+
)
|
| 868 |
+
for _ in range(num_layers)
|
| 869 |
+
]
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
if add_upsample:
|
| 873 |
+
self.upsamplers = nn.ModuleList(
|
| 874 |
+
[Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
|
| 875 |
+
)
|
| 876 |
+
else:
|
| 877 |
+
self.upsamplers = None
|
| 878 |
+
|
| 879 |
+
self.gradient_checkpointing = False
|
| 880 |
+
|
| 881 |
+
def forward(
|
| 882 |
+
self,
|
| 883 |
+
hidden_states,
|
| 884 |
+
res_hidden_states_tuple,
|
| 885 |
+
temb=None,
|
| 886 |
+
encoder_hidden_states=None,
|
| 887 |
+
upsample_size=None,
|
| 888 |
+
attention_mask=None,
|
| 889 |
+
skip_mm=False,
|
| 890 |
+
):
|
| 891 |
+
if isinstance(encoder_hidden_states, list):
|
| 892 |
+
encoder_hidden_states, motion_hidden_states = encoder_hidden_states
|
| 893 |
+
else:
|
| 894 |
+
motion_hidden_states = encoder_hidden_states
|
| 895 |
+
for i, (resnet, attn, motion_module, temporal_module) in enumerate(
|
| 896 |
+
zip(self.resnets, self.attentions, self.motion_modules, self.temporal_modules)
|
| 897 |
+
):
|
| 898 |
+
# pop res hidden states
|
| 899 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 900 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 901 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 902 |
+
|
| 903 |
+
if self.training and self.gradient_checkpointing:
|
| 904 |
+
|
| 905 |
+
def create_custom_forward(module, return_dict=None):
|
| 906 |
+
def custom_forward(*inputs):
|
| 907 |
+
if return_dict is not None:
|
| 908 |
+
return module(*inputs, return_dict=return_dict)
|
| 909 |
+
else:
|
| 910 |
+
return module(*inputs)
|
| 911 |
+
|
| 912 |
+
return custom_forward
|
| 913 |
+
|
| 914 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 915 |
+
create_custom_forward(resnet), hidden_states, temb
|
| 916 |
+
)
|
| 917 |
+
# hidden_states = attn(
|
| 918 |
+
# hidden_states,
|
| 919 |
+
# encoder_hidden_states=encoder_hidden_states,
|
| 920 |
+
# ).sample
|
| 921 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 922 |
+
create_custom_forward(attn, return_dict=False),
|
| 923 |
+
hidden_states,
|
| 924 |
+
encoder_hidden_states,
|
| 925 |
+
)[0]
|
| 926 |
+
if (motion_module is not None) and not skip_mm:
|
| 927 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 928 |
+
create_custom_forward(motion_module),
|
| 929 |
+
hidden_states,
|
| 930 |
+
temb,
|
| 931 |
+
motion_hidden_states,
|
| 932 |
+
)
|
| 933 |
+
if (temporal_module is not None) and not skip_mm:
|
| 934 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 935 |
+
create_custom_forward(temporal_module),
|
| 936 |
+
hidden_states.requires_grad_(),
|
| 937 |
+
temb,
|
| 938 |
+
None,
|
| 939 |
+
)
|
| 940 |
+
|
| 941 |
+
else:
|
| 942 |
+
hidden_states = resnet(hidden_states, temb)
|
| 943 |
+
hidden_states = attn(
|
| 944 |
+
hidden_states,
|
| 945 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 946 |
+
).sample
|
| 947 |
+
|
| 948 |
+
# add motion module
|
| 949 |
+
hidden_states = (
|
| 950 |
+
motion_module(
|
| 951 |
+
hidden_states, temb, encoder_hidden_states=motion_hidden_states
|
| 952 |
+
)
|
| 953 |
+
if (motion_module is not None) and not skip_mm
|
| 954 |
+
else hidden_states
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
# add temporal_module
|
| 958 |
+
hidden_states = (
|
| 959 |
+
temporal_module(hidden_states, temb, encoder_hidden_states=None, debug=True)
|
| 960 |
+
if (temporal_module is not None) and not skip_mm
|
| 961 |
+
else hidden_states
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
if self.upsamplers is not None:
|
| 965 |
+
for upsampler in self.upsamplers:
|
| 966 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 967 |
+
|
| 968 |
+
return hidden_states
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
class UpBlock3D(nn.Module):
|
| 972 |
+
|
| 973 |
+
def __init__(
|
| 974 |
+
self,
|
| 975 |
+
in_channels: int,
|
| 976 |
+
prev_output_channel: int,
|
| 977 |
+
out_channels: int,
|
| 978 |
+
temb_channels: int,
|
| 979 |
+
dropout: float = 0.0,
|
| 980 |
+
num_layers: int = 1,
|
| 981 |
+
resnet_eps: float = 1e-6,
|
| 982 |
+
resnet_time_scale_shift: str = "default",
|
| 983 |
+
resnet_act_fn: str = "swish",
|
| 984 |
+
resnet_groups: int = 32,
|
| 985 |
+
resnet_pre_norm: bool = True,
|
| 986 |
+
output_scale_factor=1.0,
|
| 987 |
+
add_upsample=True,
|
| 988 |
+
use_inflated_groupnorm=None,
|
| 989 |
+
use_motion_module=None,
|
| 990 |
+
motion_module_type=None,
|
| 991 |
+
motion_module_kwargs=None,
|
| 992 |
+
use_temporal_module=None,
|
| 993 |
+
temporal_module_type=None,
|
| 994 |
+
temporal_module_kwargs=None,
|
| 995 |
+
):
|
| 996 |
+
super().__init__()
|
| 997 |
+
resnets = []
|
| 998 |
+
motion_modules = []
|
| 999 |
+
|
| 1000 |
+
# use_motion_module = False
|
| 1001 |
+
for i in range(num_layers):
|
| 1002 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
| 1003 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
| 1004 |
+
|
| 1005 |
+
resnets.append(
|
| 1006 |
+
ResnetBlock3D(
|
| 1007 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
| 1008 |
+
out_channels=out_channels,
|
| 1009 |
+
temb_channels=temb_channels,
|
| 1010 |
+
eps=resnet_eps,
|
| 1011 |
+
groups=resnet_groups,
|
| 1012 |
+
dropout=dropout,
|
| 1013 |
+
time_embedding_norm=resnet_time_scale_shift,
|
| 1014 |
+
non_linearity=resnet_act_fn,
|
| 1015 |
+
output_scale_factor=output_scale_factor,
|
| 1016 |
+
pre_norm=resnet_pre_norm,
|
| 1017 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
| 1018 |
+
)
|
| 1019 |
+
)
|
| 1020 |
+
motion_modules.append(
|
| 1021 |
+
get_motion_module(
|
| 1022 |
+
in_channels=out_channels,
|
| 1023 |
+
motion_module_type=motion_module_type,
|
| 1024 |
+
motion_module_kwargs=motion_module_kwargs,
|
| 1025 |
+
)
|
| 1026 |
+
if use_motion_module
|
| 1027 |
+
else None
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
self.resnets = nn.ModuleList(resnets)
|
| 1031 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
| 1032 |
+
self.temporal_modules = nn.ModuleList(
|
| 1033 |
+
[
|
| 1034 |
+
(
|
| 1035 |
+
get_motion_module(
|
| 1036 |
+
in_channels=out_channels,
|
| 1037 |
+
motion_module_type=temporal_module_type,
|
| 1038 |
+
motion_module_kwargs=temporal_module_kwargs,
|
| 1039 |
+
)
|
| 1040 |
+
if use_temporal_module
|
| 1041 |
+
else None
|
| 1042 |
+
)
|
| 1043 |
+
for _ in range(num_layers)
|
| 1044 |
+
]
|
| 1045 |
+
)
|
| 1046 |
+
|
| 1047 |
+
if add_upsample:
|
| 1048 |
+
self.upsamplers = nn.ModuleList(
|
| 1049 |
+
[Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
|
| 1050 |
+
)
|
| 1051 |
+
else:
|
| 1052 |
+
self.upsamplers = None
|
| 1053 |
+
|
| 1054 |
+
self.gradient_checkpointing = False
|
| 1055 |
+
|
| 1056 |
+
def forward(
|
| 1057 |
+
self,
|
| 1058 |
+
hidden_states,
|
| 1059 |
+
res_hidden_states_tuple,
|
| 1060 |
+
temb=None,
|
| 1061 |
+
upsample_size=None,
|
| 1062 |
+
encoder_hidden_states=None,
|
| 1063 |
+
skip_mm=False,
|
| 1064 |
+
):
|
| 1065 |
+
if isinstance(encoder_hidden_states, list):
|
| 1066 |
+
encoder_hidden_states, motion_hidden_states = encoder_hidden_states
|
| 1067 |
+
else:
|
| 1068 |
+
motion_hidden_states = encoder_hidden_states
|
| 1069 |
+
for resnet, motion_module, temporal_module in zip(self.resnets, self.motion_modules, self.temporal_modules):
|
| 1070 |
+
# pop res hidden states
|
| 1071 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 1072 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 1073 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 1074 |
+
|
| 1075 |
+
# print(f"UpBlock3D {self.gradient_checkpointing = }")
|
| 1076 |
+
if self.training and self.gradient_checkpointing:
|
| 1077 |
+
|
| 1078 |
+
def create_custom_forward(module):
|
| 1079 |
+
def custom_forward(*inputs):
|
| 1080 |
+
return module(*inputs)
|
| 1081 |
+
|
| 1082 |
+
return custom_forward
|
| 1083 |
+
|
| 1084 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1085 |
+
create_custom_forward(resnet), hidden_states, temb
|
| 1086 |
+
)
|
| 1087 |
+
if (motion_module is not None) and not skip_mm:
|
| 1088 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1089 |
+
create_custom_forward(motion_module),
|
| 1090 |
+
hidden_states,
|
| 1091 |
+
temb,
|
| 1092 |
+
motion_hidden_states,
|
| 1093 |
+
)
|
| 1094 |
+
if (temporal_module is not None) and not skip_mm:
|
| 1095 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1096 |
+
create_custom_forward(temporal_module),
|
| 1097 |
+
hidden_states.requires_grad_(),
|
| 1098 |
+
temb,
|
| 1099 |
+
None,
|
| 1100 |
+
)
|
| 1101 |
+
|
| 1102 |
+
else:
|
| 1103 |
+
hidden_states = resnet(hidden_states, temb)
|
| 1104 |
+
hidden_states = (
|
| 1105 |
+
motion_module(
|
| 1106 |
+
hidden_states, temb, encoder_hidden_states=motion_hidden_states
|
| 1107 |
+
)
|
| 1108 |
+
if (motion_module is not None) and not skip_mm
|
| 1109 |
+
else hidden_states
|
| 1110 |
+
)
|
| 1111 |
+
hidden_states = (
|
| 1112 |
+
temporal_module(hidden_states, temb, encoder_hidden_states=None, debug=True)
|
| 1113 |
+
if (temporal_module is not None) and not skip_mm
|
| 1114 |
+
else hidden_states
|
| 1115 |
+
)
|
| 1116 |
+
|
| 1117 |
+
if self.upsamplers is not None:
|
| 1118 |
+
for upsampler in self.upsamplers:
|
| 1119 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 1120 |
+
|
| 1121 |
+
return hidden_states
|
pretrained_weights/sd-image-variations-diffusers/.gitattributes
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
pretrained_weights/sd-image-variations-diffusers/README.md
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
thumbnail: "https://repository-images.githubusercontent.com/523487884/fdb03a69-8353-4387-b5fc-0d85f888a63f"
|
| 3 |
+
datasets:
|
| 4 |
+
- ChristophSchuhmann/improved_aesthetics_6plus
|
| 5 |
+
license: creativeml-openrail-m
|
| 6 |
+
tags:
|
| 7 |
+
- stable-diffusion
|
| 8 |
+
- stable-diffusion-diffusers
|
| 9 |
+
- image-to-image
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Stable Diffusion Image Variations Model Card
|
| 13 |
+
|
| 14 |
+
📣 V2 model released, and blurriness issues fixed! 📣
|
| 15 |
+
|
| 16 |
+
🧨🎉 Image Variations is now natively supported in 🤗 Diffusers! 🎉🧨
|
| 17 |
+
|
| 18 |
+

|
| 19 |
+
|
| 20 |
+
## Version 2
|
| 21 |
+
|
| 22 |
+
This version of Stable Diffusion has been fine tuned from [CompVis/stable-diffusion-v1-4-original](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) to accept CLIP image embedding rather than text embeddings. This allows the creation of "image variations" similar to DALLE-2 using Stable Diffusion. This version of the weights has been ported to huggingface Diffusers, to use this with the Diffusers library requires the [Lambda Diffusers repo](https://github.com/LambdaLabsML/lambda-diffusers).
|
| 23 |
+
|
| 24 |
+
This model was trained in two stages and longer than the original variations model and gives better image quality and better CLIP rated similarity compared to the original version
|
| 25 |
+
|
| 26 |
+
See training details and v1 vs v2 comparison below.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## Example
|
| 30 |
+
|
| 31 |
+
Make sure you are using a version of Diffusers >=0.8.0 (for older version see the old instructions at the bottom of this model card)
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
from diffusers import StableDiffusionImageVariationPipeline
|
| 35 |
+
from PIL import Image
|
| 36 |
+
|
| 37 |
+
device = "cuda:0"
|
| 38 |
+
sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
| 39 |
+
"lambdalabs/sd-image-variations-diffusers",
|
| 40 |
+
revision="v2.0",
|
| 41 |
+
)
|
| 42 |
+
sd_pipe = sd_pipe.to(device)
|
| 43 |
+
|
| 44 |
+
im = Image.open("path/to/image.jpg")
|
| 45 |
+
tform = transforms.Compose([
|
| 46 |
+
transforms.ToTensor(),
|
| 47 |
+
transforms.Resize(
|
| 48 |
+
(224, 224),
|
| 49 |
+
interpolation=transforms.InterpolationMode.BICUBIC,
|
| 50 |
+
antialias=False,
|
| 51 |
+
),
|
| 52 |
+
transforms.Normalize(
|
| 53 |
+
[0.48145466, 0.4578275, 0.40821073],
|
| 54 |
+
[0.26862954, 0.26130258, 0.27577711]),
|
| 55 |
+
])
|
| 56 |
+
inp = tform(im).to(device).unsqueeze(0)
|
| 57 |
+
|
| 58 |
+
out = sd_pipe(inp, guidance_scale=3)
|
| 59 |
+
out["images"][0].save("result.jpg")
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### The importance of resizing correctly... (or not)
|
| 63 |
+
|
| 64 |
+
Note that due a bit of an oversight during training, the model expects resized images without anti-aliasing. This turns out to make a big difference and is important to do the resizing the same way during inference. When passing a PIL image to the Diffusers pipeline antialiasing will be applied during resize, so it's better to input a tensor which you have prepared manually according to the transfrom in the example above!
|
| 65 |
+
|
| 66 |
+
Here are examples of images generated without (top) and with (bottom) anti-aliasing during resize. (Input is [this image](https://github.com/SHI-Labs/Versatile-Diffusion/blob/master/assets/ghibli.jpg))
|
| 67 |
+
|
| 68 |
+

|
| 69 |
+
|
| 70 |
+

|
| 71 |
+
|
| 72 |
+
### V1 vs V2
|
| 73 |
+
|
| 74 |
+
Here's an example of V1 vs V2, version two was trained more carefully and for longer, see the details below. V2-top vs V1-bottom
|
| 75 |
+
|
| 76 |
+

|
| 77 |
+
|
| 78 |
+

|
| 79 |
+
|
| 80 |
+
Input images:
|
| 81 |
+
|
| 82 |
+

|
| 83 |
+
|
| 84 |
+
One important thing to note is that due to the longer training V2 appears to have memorised some common images from the training data, e.g. now the previous example of the Girl with a Pearl Earring almosts perfectly reproduce the original rather than creating variations. You can always use v1 by specifiying `revision="v1.0"`.
|
| 85 |
+
|
| 86 |
+
v2 output for girl with a pearl earing as input (guidance scale=3)
|
| 87 |
+
|
| 88 |
+

|
| 89 |
+
|
| 90 |
+
# Training
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
**Training Procedure**
|
| 94 |
+
This model is fine tuned from Stable Diffusion v1-3 where the text encoder has been replaced with an image encoder. The training procedure is the same as for Stable Diffusion except for the fact that images are encoded through a ViT-L/14 image-encoder including the final projection layer to the CLIP shared embedding space. The model was trained on LAION improved aesthetics 6plus.
|
| 95 |
+
|
| 96 |
+
- **Hardware:** 8 x A100-40GB GPUs (provided by [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud))
|
| 97 |
+
- **Optimizer:** AdamW
|
| 98 |
+
|
| 99 |
+
- **Stage 1** - Fine tune only CrossAttention layer weights from Stable Diffusion v1.4 model
|
| 100 |
+
- **Steps**: 46,000
|
| 101 |
+
- **Batch:** batch size=4, GPUs=8, Gradient Accumulations=4. Total batch size=128
|
| 102 |
+
- **Learning rate:** warmup to 1e-5 for 10,000 steps and then kept constant
|
| 103 |
+
|
| 104 |
+
- **Stage 2** - Resume from Stage 1 training the whole unet
|
| 105 |
+
- **Steps**: 50,000
|
| 106 |
+
- **Batch:** batch size=4, GPUs=8, Gradient Accumulations=5. Total batch size=160
|
| 107 |
+
- **Learning rate:** warmup to 1e-5 for 5,000 steps and then kept constant
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
Training was done using a [modified version of the original Stable Diffusion training code](https://github.com/justinpinkney/stable-diffusion).
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# Uses
|
| 114 |
+
_The following section is adapted from the [Stable Diffusion model card](https://huggingface.co/CompVis/stable-diffusion-v1-4)_
|
| 115 |
+
|
| 116 |
+
## Direct Use
|
| 117 |
+
The model is intended for research purposes only. Possible research areas and
|
| 118 |
+
tasks include
|
| 119 |
+
|
| 120 |
+
- Safe deployment of models which have the potential to generate harmful content.
|
| 121 |
+
- Probing and understanding the limitations and biases of generative models.
|
| 122 |
+
- Generation of artworks and use in design and other artistic processes.
|
| 123 |
+
- Applications in educational or creative tools.
|
| 124 |
+
- Research on generative models.
|
| 125 |
+
|
| 126 |
+
Excluded uses are described below.
|
| 127 |
+
|
| 128 |
+
### Misuse, Malicious Use, and Out-of-Scope Use
|
| 129 |
+
|
| 130 |
+
The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
| 131 |
+
|
| 132 |
+
#### Out-of-Scope Use
|
| 133 |
+
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
| 134 |
+
|
| 135 |
+
#### Misuse and Malicious Use
|
| 136 |
+
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
| 137 |
+
|
| 138 |
+
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
| 139 |
+
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
| 140 |
+
- Impersonating individuals without their consent.
|
| 141 |
+
- Sexual content without consent of the people who might see it.
|
| 142 |
+
- Mis- and disinformation
|
| 143 |
+
- Representations of egregious violence and gore
|
| 144 |
+
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
| 145 |
+
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
| 146 |
+
|
| 147 |
+
## Limitations and Bias
|
| 148 |
+
|
| 149 |
+
### Limitations
|
| 150 |
+
|
| 151 |
+
- The model does not achieve perfect photorealism
|
| 152 |
+
- The model cannot render legible text
|
| 153 |
+
- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
|
| 154 |
+
- Faces and people in general may not be generated properly.
|
| 155 |
+
- The model was trained mainly with English captions and will not work as well in other languages.
|
| 156 |
+
- The autoencoding part of the model is lossy
|
| 157 |
+
- The model was trained on a large-scale dataset
|
| 158 |
+
[LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
|
| 159 |
+
and is not fit for product use without additional safety mechanisms and
|
| 160 |
+
considerations.
|
| 161 |
+
- No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
|
| 162 |
+
The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
|
| 163 |
+
|
| 164 |
+
### Bias
|
| 165 |
+
|
| 166 |
+
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
| 167 |
+
Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
|
| 168 |
+
which consists of images that are primarily limited to English descriptions.
|
| 169 |
+
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
|
| 170 |
+
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
|
| 171 |
+
ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
|
| 172 |
+
|
| 173 |
+
### Safety Module
|
| 174 |
+
|
| 175 |
+
The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers.
|
| 176 |
+
This checker works by checking model outputs against known hard-coded NSFW concepts.
|
| 177 |
+
The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter.
|
| 178 |
+
Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPModel` *after generation* of the images.
|
| 179 |
+
The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept.
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
## Old instructions
|
| 183 |
+
|
| 184 |
+
If you are using a diffusers version <0.8.0 there is no `StableDiffusionImageVariationPipeline`,
|
| 185 |
+
in this case you need to use an older revision (`2ddbd90b14bc5892c19925b15185e561bc8e5d0a`) in conjunction with the lambda-diffusers repo:
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
First clone [Lambda Diffusers](https://github.com/LambdaLabsML/lambda-diffusers) and install any requirements (in a virtual environment in the example below):
|
| 189 |
+
|
| 190 |
+
```bash
|
| 191 |
+
git clone https://github.com/LambdaLabsML/lambda-diffusers.git
|
| 192 |
+
cd lambda-diffusers
|
| 193 |
+
python -m venv .venv
|
| 194 |
+
source .venv/bin/activate
|
| 195 |
+
pip install -r requirements.txt
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
Then run the following python code:
|
| 199 |
+
|
| 200 |
+
```python
|
| 201 |
+
from pathlib import Path
|
| 202 |
+
from lambda_diffusers import StableDiffusionImageEmbedPipeline
|
| 203 |
+
from PIL import Image
|
| 204 |
+
import torch
|
| 205 |
+
|
| 206 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 207 |
+
pipe = StableDiffusionImageEmbedPipeline.from_pretrained(
|
| 208 |
+
"lambdalabs/sd-image-variations-diffusers",
|
| 209 |
+
revision="2ddbd90b14bc5892c19925b15185e561bc8e5d0a",
|
| 210 |
+
)
|
| 211 |
+
pipe = pipe.to(device)
|
| 212 |
+
|
| 213 |
+
im = Image.open("your/input/image/here.jpg")
|
| 214 |
+
num_samples = 4
|
| 215 |
+
image = pipe(num_samples*[im], guidance_scale=3.0)
|
| 216 |
+
image = image["sample"]
|
| 217 |
+
|
| 218 |
+
base_path = Path("outputs/im2im")
|
| 219 |
+
base_path.mkdir(exist_ok=True, parents=True)
|
| 220 |
+
for idx, im in enumerate(image):
|
| 221 |
+
im.save(base_path/f"{idx:06}.jpg")
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
*This model card was written by: Justin Pinkney and is based on the [Stable Diffusion model card](https://huggingface.co/CompVis/stable-diffusion-v1-4).*
|
pretrained_weights/sd-image-variations-diffusers/alias-montage.jpg
ADDED
|
Git LFS Details
|
pretrained_weights/sd-image-variations-diffusers/default-montage.jpg
ADDED
|
Git LFS Details
|
pretrained_weights/sd-image-variations-diffusers/earring.jpg
ADDED
|
Git LFS Details
|
pretrained_weights/sd-image-variations-diffusers/feature_extractor/preprocessor_config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": {
|
| 3 |
+
"height": 224,
|
| 4 |
+
"width": 224
|
| 5 |
+
},
|
| 6 |
+
"do_center_crop": true,
|
| 7 |
+
"do_convert_rgb": true,
|
| 8 |
+
"do_normalize": true,
|
| 9 |
+
"do_rescale": true,
|
| 10 |
+
"do_resize": true,
|
| 11 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
| 12 |
+
"image_mean": [
|
| 13 |
+
0.48145466,
|
| 14 |
+
0.4578275,
|
| 15 |
+
0.40821073
|
| 16 |
+
],
|
| 17 |
+
"image_processor_type": "CLIPImageProcessor",
|
| 18 |
+
"image_std": [
|
| 19 |
+
0.26862954,
|
| 20 |
+
0.26130258,
|
| 21 |
+
0.27577711
|
| 22 |
+
],
|
| 23 |
+
"resample": 3,
|
| 24 |
+
"rescale_factor": 0.00392156862745098,
|
| 25 |
+
"size": {
|
| 26 |
+
"shortest_edge": 224
|
| 27 |
+
}
|
| 28 |
+
}
|
pretrained_weights/sd-image-variations-diffusers/image_encoder/config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "/home/jpinkney/.cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/ca6f97f838ae1b5bf764f31363a21f388f4d8f3e/image_encoder",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CLIPVisionModelWithProjection"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"hidden_act": "quick_gelu",
|
| 9 |
+
"hidden_size": 1024,
|
| 10 |
+
"image_size": 224,
|
| 11 |
+
"initializer_factor": 1.0,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 4096,
|
| 14 |
+
"layer_norm_eps": 1e-05,
|
| 15 |
+
"model_type": "clip_vision_model",
|
| 16 |
+
"num_attention_heads": 16,
|
| 17 |
+
"num_channels": 3,
|
| 18 |
+
"num_hidden_layers": 24,
|
| 19 |
+
"patch_size": 14,
|
| 20 |
+
"projection_dim": 768,
|
| 21 |
+
"torch_dtype": "float32",
|
| 22 |
+
"transformers_version": "4.25.1"
|
| 23 |
+
}
|
pretrained_weights/sd-image-variations-diffusers/image_encoder/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:89d2aa29b5fdf64f3ad4f45fb4227ea98bc45156bbae673b85be1af7783dbabb
|
| 3 |
+
size 1215993967
|
pretrained_weights/sd-image-variations-diffusers/inputs.jpg
ADDED
|
pretrained_weights/sd-image-variations-diffusers/model_index.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "StableDiffusionImageVariationPipeline",
|
| 3 |
+
"_diffusers_version": "0.9.0",
|
| 4 |
+
"feature_extractor": [
|
| 5 |
+
"transformers",
|
| 6 |
+
"CLIPImageProcessor"
|
| 7 |
+
],
|
| 8 |
+
"image_encoder": [
|
| 9 |
+
"transformers",
|
| 10 |
+
"CLIPVisionModelWithProjection"
|
| 11 |
+
],
|
| 12 |
+
"requires_safety_checker": true,
|
| 13 |
+
"safety_checker": [
|
| 14 |
+
"stable_diffusion",
|
| 15 |
+
"StableDiffusionSafetyChecker"
|
| 16 |
+
],
|
| 17 |
+
"scheduler": [
|
| 18 |
+
"diffusers",
|
| 19 |
+
"PNDMScheduler"
|
| 20 |
+
],
|
| 21 |
+
"unet": [
|
| 22 |
+
"diffusers",
|
| 23 |
+
"UNet2DConditionModel"
|
| 24 |
+
],
|
| 25 |
+
"vae": [
|
| 26 |
+
"diffusers",
|
| 27 |
+
"AutoencoderKL"
|
| 28 |
+
]
|
| 29 |
+
}
|
pretrained_weights/sd-image-variations-diffusers/safety_checker/config.json
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_commit_hash": "ca6f97f838ae1b5bf764f31363a21f388f4d8f3e",
|
| 3 |
+
"_name_or_path": "/home/jpinkney/.cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/ca6f97f838ae1b5bf764f31363a21f388f4d8f3e/safety_checker",
|
| 4 |
+
"architectures": [
|
| 5 |
+
"StableDiffusionSafetyChecker"
|
| 6 |
+
],
|
| 7 |
+
"initializer_factor": 1.0,
|
| 8 |
+
"logit_scale_init_value": 2.6592,
|
| 9 |
+
"model_type": "clip",
|
| 10 |
+
"projection_dim": 768,
|
| 11 |
+
"text_config": {
|
| 12 |
+
"_name_or_path": "",
|
| 13 |
+
"add_cross_attention": false,
|
| 14 |
+
"architectures": null,
|
| 15 |
+
"attention_dropout": 0.0,
|
| 16 |
+
"bad_words_ids": null,
|
| 17 |
+
"begin_suppress_tokens": null,
|
| 18 |
+
"bos_token_id": 0,
|
| 19 |
+
"chunk_size_feed_forward": 0,
|
| 20 |
+
"cross_attention_hidden_size": null,
|
| 21 |
+
"decoder_start_token_id": null,
|
| 22 |
+
"diversity_penalty": 0.0,
|
| 23 |
+
"do_sample": false,
|
| 24 |
+
"dropout": 0.0,
|
| 25 |
+
"early_stopping": false,
|
| 26 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 27 |
+
"eos_token_id": 2,
|
| 28 |
+
"exponential_decay_length_penalty": null,
|
| 29 |
+
"finetuning_task": null,
|
| 30 |
+
"forced_bos_token_id": null,
|
| 31 |
+
"forced_eos_token_id": null,
|
| 32 |
+
"hidden_act": "quick_gelu",
|
| 33 |
+
"hidden_size": 768,
|
| 34 |
+
"id2label": {
|
| 35 |
+
"0": "LABEL_0",
|
| 36 |
+
"1": "LABEL_1"
|
| 37 |
+
},
|
| 38 |
+
"initializer_factor": 1.0,
|
| 39 |
+
"initializer_range": 0.02,
|
| 40 |
+
"intermediate_size": 3072,
|
| 41 |
+
"is_decoder": false,
|
| 42 |
+
"is_encoder_decoder": false,
|
| 43 |
+
"label2id": {
|
| 44 |
+
"LABEL_0": 0,
|
| 45 |
+
"LABEL_1": 1
|
| 46 |
+
},
|
| 47 |
+
"layer_norm_eps": 1e-05,
|
| 48 |
+
"length_penalty": 1.0,
|
| 49 |
+
"max_length": 20,
|
| 50 |
+
"max_position_embeddings": 77,
|
| 51 |
+
"min_length": 0,
|
| 52 |
+
"model_type": "clip_text_model",
|
| 53 |
+
"no_repeat_ngram_size": 0,
|
| 54 |
+
"num_attention_heads": 12,
|
| 55 |
+
"num_beam_groups": 1,
|
| 56 |
+
"num_beams": 1,
|
| 57 |
+
"num_hidden_layers": 12,
|
| 58 |
+
"num_return_sequences": 1,
|
| 59 |
+
"output_attentions": false,
|
| 60 |
+
"output_hidden_states": false,
|
| 61 |
+
"output_scores": false,
|
| 62 |
+
"pad_token_id": 1,
|
| 63 |
+
"prefix": null,
|
| 64 |
+
"problem_type": null,
|
| 65 |
+
"projection_dim": 512,
|
| 66 |
+
"pruned_heads": {},
|
| 67 |
+
"remove_invalid_values": false,
|
| 68 |
+
"repetition_penalty": 1.0,
|
| 69 |
+
"return_dict": true,
|
| 70 |
+
"return_dict_in_generate": false,
|
| 71 |
+
"sep_token_id": null,
|
| 72 |
+
"suppress_tokens": null,
|
| 73 |
+
"task_specific_params": null,
|
| 74 |
+
"temperature": 1.0,
|
| 75 |
+
"tf_legacy_loss": false,
|
| 76 |
+
"tie_encoder_decoder": false,
|
| 77 |
+
"tie_word_embeddings": true,
|
| 78 |
+
"tokenizer_class": null,
|
| 79 |
+
"top_k": 50,
|
| 80 |
+
"top_p": 1.0,
|
| 81 |
+
"torch_dtype": null,
|
| 82 |
+
"torchscript": false,
|
| 83 |
+
"transformers_version": "4.25.1",
|
| 84 |
+
"typical_p": 1.0,
|
| 85 |
+
"use_bfloat16": false,
|
| 86 |
+
"vocab_size": 49408
|
| 87 |
+
},
|
| 88 |
+
"text_config_dict": {
|
| 89 |
+
"hidden_size": 768,
|
| 90 |
+
"intermediate_size": 3072,
|
| 91 |
+
"num_attention_heads": 12,
|
| 92 |
+
"num_hidden_layers": 12
|
| 93 |
+
},
|
| 94 |
+
"torch_dtype": "float32",
|
| 95 |
+
"transformers_version": null,
|
| 96 |
+
"vision_config": {
|
| 97 |
+
"_name_or_path": "",
|
| 98 |
+
"add_cross_attention": false,
|
| 99 |
+
"architectures": null,
|
| 100 |
+
"attention_dropout": 0.0,
|
| 101 |
+
"bad_words_ids": null,
|
| 102 |
+
"begin_suppress_tokens": null,
|
| 103 |
+
"bos_token_id": null,
|
| 104 |
+
"chunk_size_feed_forward": 0,
|
| 105 |
+
"cross_attention_hidden_size": null,
|
| 106 |
+
"decoder_start_token_id": null,
|
| 107 |
+
"diversity_penalty": 0.0,
|
| 108 |
+
"do_sample": false,
|
| 109 |
+
"dropout": 0.0,
|
| 110 |
+
"early_stopping": false,
|
| 111 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 112 |
+
"eos_token_id": null,
|
| 113 |
+
"exponential_decay_length_penalty": null,
|
| 114 |
+
"finetuning_task": null,
|
| 115 |
+
"forced_bos_token_id": null,
|
| 116 |
+
"forced_eos_token_id": null,
|
| 117 |
+
"hidden_act": "quick_gelu",
|
| 118 |
+
"hidden_size": 1024,
|
| 119 |
+
"id2label": {
|
| 120 |
+
"0": "LABEL_0",
|
| 121 |
+
"1": "LABEL_1"
|
| 122 |
+
},
|
| 123 |
+
"image_size": 224,
|
| 124 |
+
"initializer_factor": 1.0,
|
| 125 |
+
"initializer_range": 0.02,
|
| 126 |
+
"intermediate_size": 4096,
|
| 127 |
+
"is_decoder": false,
|
| 128 |
+
"is_encoder_decoder": false,
|
| 129 |
+
"label2id": {
|
| 130 |
+
"LABEL_0": 0,
|
| 131 |
+
"LABEL_1": 1
|
| 132 |
+
},
|
| 133 |
+
"layer_norm_eps": 1e-05,
|
| 134 |
+
"length_penalty": 1.0,
|
| 135 |
+
"max_length": 20,
|
| 136 |
+
"min_length": 0,
|
| 137 |
+
"model_type": "clip_vision_model",
|
| 138 |
+
"no_repeat_ngram_size": 0,
|
| 139 |
+
"num_attention_heads": 16,
|
| 140 |
+
"num_beam_groups": 1,
|
| 141 |
+
"num_beams": 1,
|
| 142 |
+
"num_channels": 3,
|
| 143 |
+
"num_hidden_layers": 24,
|
| 144 |
+
"num_return_sequences": 1,
|
| 145 |
+
"output_attentions": false,
|
| 146 |
+
"output_hidden_states": false,
|
| 147 |
+
"output_scores": false,
|
| 148 |
+
"pad_token_id": null,
|
| 149 |
+
"patch_size": 14,
|
| 150 |
+
"prefix": null,
|
| 151 |
+
"problem_type": null,
|
| 152 |
+
"projection_dim": 512,
|
| 153 |
+
"pruned_heads": {},
|
| 154 |
+
"remove_invalid_values": false,
|
| 155 |
+
"repetition_penalty": 1.0,
|
| 156 |
+
"return_dict": true,
|
| 157 |
+
"return_dict_in_generate": false,
|
| 158 |
+
"sep_token_id": null,
|
| 159 |
+
"suppress_tokens": null,
|
| 160 |
+
"task_specific_params": null,
|
| 161 |
+
"temperature": 1.0,
|
| 162 |
+
"tf_legacy_loss": false,
|
| 163 |
+
"tie_encoder_decoder": false,
|
| 164 |
+
"tie_word_embeddings": true,
|
| 165 |
+
"tokenizer_class": null,
|
| 166 |
+
"top_k": 50,
|
| 167 |
+
"top_p": 1.0,
|
| 168 |
+
"torch_dtype": null,
|
| 169 |
+
"torchscript": false,
|
| 170 |
+
"transformers_version": "4.25.1",
|
| 171 |
+
"typical_p": 1.0,
|
| 172 |
+
"use_bfloat16": false
|
| 173 |
+
},
|
| 174 |
+
"vision_config_dict": {
|
| 175 |
+
"hidden_size": 1024,
|
| 176 |
+
"intermediate_size": 4096,
|
| 177 |
+
"num_attention_heads": 16,
|
| 178 |
+
"num_hidden_layers": 24,
|
| 179 |
+
"patch_size": 14
|
| 180 |
+
}
|
| 181 |
+
}
|
pretrained_weights/sd-image-variations-diffusers/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "PNDMScheduler",
|
| 3 |
+
"_diffusers_version": "0.9.0",
|
| 4 |
+
"beta_end": 0.012,
|
| 5 |
+
"beta_schedule": "scaled_linear",
|
| 6 |
+
"beta_start": 0.00085,
|
| 7 |
+
"clip_sample": false,
|
| 8 |
+
"num_train_timesteps": 1000,
|
| 9 |
+
"set_alpha_to_one": false,
|
| 10 |
+
"skip_prk_steps": true,
|
| 11 |
+
"steps_offset": 1,
|
| 12 |
+
"trained_betas": null
|
| 13 |
+
}
|
pretrained_weights/sd-image-variations-diffusers/unet/config.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "UNet2DConditionModel",
|
| 3 |
+
"_diffusers_version": "0.9.0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"attention_head_dim": 8,
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
320,
|
| 8 |
+
640,
|
| 9 |
+
1280,
|
| 10 |
+
1280
|
| 11 |
+
],
|
| 12 |
+
"center_input_sample": false,
|
| 13 |
+
"cross_attention_dim": 768,
|
| 14 |
+
"down_block_types": [
|
| 15 |
+
"CrossAttnDownBlock2D",
|
| 16 |
+
"CrossAttnDownBlock2D",
|
| 17 |
+
"CrossAttnDownBlock2D",
|
| 18 |
+
"DownBlock2D"
|
| 19 |
+
],
|
| 20 |
+
"downsample_padding": 1,
|
| 21 |
+
"dual_cross_attention": false,
|
| 22 |
+
"flip_sin_to_cos": true,
|
| 23 |
+
"freq_shift": 0,
|
| 24 |
+
"in_channels": 4,
|
| 25 |
+
"layers_per_block": 2,
|
| 26 |
+
"mid_block_scale_factor": 1,
|
| 27 |
+
"norm_eps": 1e-05,
|
| 28 |
+
"norm_num_groups": 32,
|
| 29 |
+
"num_class_embeds": null,
|
| 30 |
+
"only_cross_attention": false,
|
| 31 |
+
"out_channels": 4,
|
| 32 |
+
"sample_size": 64,
|
| 33 |
+
"up_block_types": [
|
| 34 |
+
"UpBlock2D",
|
| 35 |
+
"CrossAttnUpBlock2D",
|
| 36 |
+
"CrossAttnUpBlock2D",
|
| 37 |
+
"CrossAttnUpBlock2D"
|
| 38 |
+
],
|
| 39 |
+
"use_linear_projection": false
|
| 40 |
+
}
|
pretrained_weights/sd-image-variations-diffusers/unet/diffusion_pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ee23e3368e4e7c0e4ef636ed61923609c97fcaa583f8bb416e3e0986d4a0cfc6
|
| 3 |
+
size 3438354725
|
pretrained_weights/sd-image-variations-diffusers/v1-montage.jpg
ADDED
|
Git LFS Details
|
pretrained_weights/sd-image-variations-diffusers/v2-montage.jpg
ADDED
|
Git LFS Details
|
pretrained_weights/sd-image-variations-diffusers/vae/config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.9.0",
|
| 4 |
+
"_name_or_path": "stabilityai/sd-vae-ft-mse",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
128,
|
| 8 |
+
256,
|
| 9 |
+
512,
|
| 10 |
+
512
|
| 11 |
+
],
|
| 12 |
+
"down_block_types": [
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D",
|
| 16 |
+
"DownEncoderBlock2D"
|
| 17 |
+
],
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 4,
|
| 20 |
+
"layers_per_block": 2,
|
| 21 |
+
"norm_num_groups": 32,
|
| 22 |
+
"out_channels": 3,
|
| 23 |
+
"sample_size": 256,
|
| 24 |
+
"up_block_types": [
|
| 25 |
+
"UpDecoderBlock2D",
|
| 26 |
+
"UpDecoderBlock2D",
|
| 27 |
+
"UpDecoderBlock2D",
|
| 28 |
+
"UpDecoderBlock2D"
|
| 29 |
+
]
|
| 30 |
+
}
|
pretrained_weights/sd-image-variations-diffusers/vae/diffusion_pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1b4889b6b1d4ce7ae320a02dedaeff1780ad77d415ea0d744b476155c6377ddc
|
| 3 |
+
size 334707217
|
pretrained_weights/stable-video-diffusion-img2vid-xt/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
output_tile.gif filter=lfs diff=lfs merge=lfs -text
|
pretrained_weights/stable-video-diffusion-img2vid-xt/LICENSE.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
STABILITY AI COMMUNITY LICENSE AGREEMENT
|
| 2 |
+
|
| 3 |
+
Last Updated: July 5, 2024
|
| 4 |
+
|
| 5 |
+
1. INTRODUCTION
|
| 6 |
+
|
| 7 |
+
This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
|
| 8 |
+
|
| 9 |
+
This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
|
| 10 |
+
|
| 11 |
+
By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity’s behalf.
|
| 12 |
+
|
| 13 |
+
2. RESEARCH & NON-COMMERCIAL USE LICENSE
|
| 14 |
+
|
| 15 |
+
Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
|
| 16 |
+
|
| 17 |
+
3. COMMERCIAL USE LICENSE
|
| 18 |
+
|
| 19 |
+
Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business’s or organization’s internal operations.
|
| 20 |
+
If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
|
| 21 |
+
|
| 22 |
+
4. GENERAL TERMS
|
| 23 |
+
|
| 24 |
+
Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
|
| 25 |
+
a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified.
|
| 26 |
+
b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
|
| 27 |
+
c. Intellectual Property.
|
| 28 |
+
(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
|
| 29 |
+
(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
|
| 30 |
+
(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
|
| 31 |
+
(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
|
| 32 |
+
(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI’s existing or prospective technology, products or services (collectively, “Feedback”). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback.
|
| 33 |
+
d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
|
| 34 |
+
e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
| 35 |
+
f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
|
| 36 |
+
g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
|
| 37 |
+
|
| 38 |
+
5. DEFINITIONS
|
| 39 |
+
|
| 40 |
+
“Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
|
| 41 |
+
|
| 42 |
+
"Agreement" means this Stability AI Community License Agreement.
|
| 43 |
+
|
| 44 |
+
“AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time.
|
| 45 |
+
|
| 46 |
+
"Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Model’s output, but do not include the output of any Model.
|
| 47 |
+
|
| 48 |
+
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
|
| 49 |
+
|
| 50 |
+
“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability’s Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time.
|
| 51 |
+
|
| 52 |
+
"Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
|
| 53 |
+
|
| 54 |
+
"Software" means Stability AI’s proprietary software made available under this Agreement now or in the future.
|
| 55 |
+
|
| 56 |
+
“Stability AI Materials” means, collectively, Stability’s proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
|
| 57 |
+
|
| 58 |
+
“Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
|
pretrained_weights/stable-video-diffusion-img2vid-xt/README.md
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
pipeline_tag: image-to-video
|
| 3 |
+
license: other
|
| 4 |
+
license_name: stable-video-diffusion-community
|
| 5 |
+
license_link: LICENSE.md
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
# Stable Video Diffusion Image-to-Video Model Card
|
| 9 |
+
|
| 10 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 11 |
+

|
| 12 |
+
Stable Video Diffusion (SVD) Image-to-Video is a diffusion model that takes in a still image as a conditioning frame, and generates a video from it.
|
| 13 |
+
|
| 14 |
+
Please note: For commercial use, please refer to https://stability.ai/license.
|
| 15 |
+
|
| 16 |
+
## Model Details
|
| 17 |
+
|
| 18 |
+
### Model Description
|
| 19 |
+
|
| 20 |
+
(SVD) Image-to-Video is a latent diffusion model trained to generate short video clips from an image conditioning.
|
| 21 |
+
This model was trained to generate 25 frames at resolution 576x1024 given a context frame of the same size, finetuned from [SVD Image-to-Video [14 frames]](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid).
|
| 22 |
+
We also finetune the widely used [f8-decoder](https://huggingface.co/docs/diffusers/api/models/autoencoderkl#loading-from-the-original-format) for temporal consistency.
|
| 23 |
+
For convenience, we additionally provide the model with the
|
| 24 |
+
standard frame-wise decoder [here](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/svd_xt_image_decoder.safetensors).
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
- **Developed by:** Stability AI
|
| 28 |
+
- **Funded by:** Stability AI
|
| 29 |
+
- **Model type:** Generative image-to-video model
|
| 30 |
+
- **Finetuned from model:** SVD Image-to-Video [14 frames]
|
| 31 |
+
|
| 32 |
+
### Model Sources
|
| 33 |
+
|
| 34 |
+
For research purposes, we recommend our `generative-models` Github repository (https://github.com/Stability-AI/generative-models),
|
| 35 |
+
which implements the most popular diffusion frameworks (both training and inference).
|
| 36 |
+
|
| 37 |
+
- **Repository:** https://github.com/Stability-AI/generative-models
|
| 38 |
+
- **Paper:** https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
## Evaluation
|
| 42 |
+

|
| 43 |
+
The chart above evaluates user preference for SVD-Image-to-Video over [GEN-2](https://research.runwayml.com/gen2) and [PikaLabs](https://www.pika.art/).
|
| 44 |
+
SVD-Image-to-Video is preferred by human voters in terms of video quality. For details on the user study, we refer to the [research paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets)
|
| 45 |
+
|
| 46 |
+
## Uses
|
| 47 |
+
|
| 48 |
+
### Direct Use
|
| 49 |
+
|
| 50 |
+
The model is intended for both non-commercial and commercial usage. You can use this model for non-commercial or research purposes under this [license](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE.md). Possible research areas and tasks include
|
| 51 |
+
|
| 52 |
+
- Research on generative models.
|
| 53 |
+
- Safe deployment of models which have the potential to generate harmful content.
|
| 54 |
+
- Probing and understanding the limitations and biases of generative models.
|
| 55 |
+
- Generation of artworks and use in design and other artistic processes.
|
| 56 |
+
- Applications in educational or creative tools.
|
| 57 |
+
|
| 58 |
+
For commercial use, please refer to https://stability.ai/license.
|
| 59 |
+
|
| 60 |
+
Excluded uses are described below.
|
| 61 |
+
|
| 62 |
+
### Out-of-Scope Use
|
| 63 |
+
|
| 64 |
+
The model was not trained to be factual or true representations of people or events,
|
| 65 |
+
and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
| 66 |
+
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
|
| 67 |
+
|
| 68 |
+
## Limitations and Bias
|
| 69 |
+
|
| 70 |
+
### Limitations
|
| 71 |
+
- The generated videos are rather short (<= 4sec), and the model does not achieve perfect photorealism.
|
| 72 |
+
- The model may generate videos without motion, or very slow camera pans.
|
| 73 |
+
- The model cannot be controlled through text.
|
| 74 |
+
- The model cannot render legible text.
|
| 75 |
+
- Faces and people in general may not be generated properly.
|
| 76 |
+
- The autoencoding part of the model is lossy.
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
### Recommendations
|
| 80 |
+
|
| 81 |
+
The model is intended for both non-commercial and commercial usage.
|
| 82 |
+
|
| 83 |
+
## How to Get Started with the Model
|
| 84 |
+
|
| 85 |
+
Check out https://github.com/Stability-AI/generative-models
|
| 86 |
+
|
| 87 |
+
# Appendix:
|
| 88 |
+
|
| 89 |
+
All considered potential data sources were included for final training, with none held out as the proposed data filtering methods described in the SVD paper handle the quality control/filtering of the dataset. With regards to safety/NSFW filtering, sources considered were either deemed safe or filtered with the in-house NSFW filters.
|
| 90 |
+
No explicit human labor is involved in training data preparation. However, human evaluation for model outputs and quality was extensively used to evaluate model quality and performance. The evaluations were performed with third-party contractor platforms (Amazon Sagemaker, Amazon Mechanical Turk, Prolific) with fluent English-speaking contractors from various countries, primarily from the USA, UK, and Canada. Each worker was paid $12/hr for the time invested in the evaluation.
|
| 91 |
+
No other third party was involved in the development of this model; the model was fully developed in-house at Stability AI.
|
| 92 |
+
Training the SVD checkpoints required a total of approximately 200,000 A100 80GB hours. The majority of the training occurred on 48 * 8 A100s, while some stages took more/less than that. The resulting CO2 emission is ~19,000kg CO2 eq., and energy consumed is ~64000 kWh.
|
| 93 |
+
The released checkpoints (SVD/SVD-XT) are image-to-video models that generate short videos/animations closely following the given input image. Since the model relies on an existing supplied image, the potential risks of disclosing specific material or novel unsafe content are minimal. This was also evaluated by third-party independent red-teaming services, which agree with our conclusion to a high degree of confidence (>90% in various areas of safety red-teaming). The external evaluations were also performed for trustworthiness, leading to >95% confidence in real, trustworthy videos.
|
| 94 |
+
With the default settings at the time of release, SVD takes ~100s for generation, and SVD-XT takes ~180s on an A100 80GB card. Several optimizations to trade off quality / memory / speed can be done to perform faster inference or inference on lower VRAM cards.
|
| 95 |
+
The information related to the model and its development process and usage protocols can be found in the GitHub repo, associated research paper, and HuggingFace model page/cards.
|
| 96 |
+
The released model inference & demo code has image-level watermarking enabled by default, which can be used to detect the outputs. This is done via the imWatermark Python library.
|
| 97 |
+
The model can be used to generate videos from static initial images. However, we prohibit unlawful, obscene, or misleading uses of the model consistent with the terms of our license and Acceptable Use Policy. For the open-weights release, our training data filtering mitigations alleviate this risk to some extent. These restrictions are explicitly enforced on user-facing interfaces at stablevideo.com, where a warning is issued. We do not take any responsibility for third-party interfaces. Submitting initial images that bypass input filters to tease out offensive or inappropriate content listed above is also prohibited. Safety filtering checks at stablevideo.com run on model inputs and outputs independently. More details on our user-facing interfaces can be found here: https://www.stablevideo.com/faq. Beyond the Acceptable Use Policy and other mitigations and conditions described here, the model is not subject to additional model behavior interventions of the type described in the Foundation Model Transparency Index.
|
| 98 |
+
For stablevideo.com, we store preference data in the form of upvotes/downvotes on user-generated videos, and we have a pairwise ranker that runs while a user generates videos. This usage data is solely used for improving Stability AI’s future image/video models and services. No other third-party entities are given access to the usage data beyond Stability AI and maintainers of stablevideo.com.
|
| 99 |
+
For usage statistics of SVD, we refer interested users to HuggingFace model download/usage statistics as a primary indicator. Third-party applications also have reported model usage statistics. We might also consider releasing aggregate usage statistics of stablevideo.com on reaching some milestones.
|
pretrained_weights/stable-video-diffusion-img2vid-xt/comparison.png
ADDED
|
Git LFS Details
|
pretrained_weights/stable-video-diffusion-img2vid-xt/feature_extractor/preprocessor_config.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": {
|
| 3 |
+
"height": 224,
|
| 4 |
+
"width": 224
|
| 5 |
+
},
|
| 6 |
+
"do_center_crop": true,
|
| 7 |
+
"do_convert_rgb": true,
|
| 8 |
+
"do_normalize": true,
|
| 9 |
+
"do_rescale": true,
|
| 10 |
+
"do_resize": true,
|
| 11 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
| 12 |
+
"image_mean": [
|
| 13 |
+
0.48145466,
|
| 14 |
+
0.4578275,
|
| 15 |
+
0.40821073
|
| 16 |
+
],
|
| 17 |
+
"image_processor_type": "CLIPImageProcessor",
|
| 18 |
+
"image_std": [
|
| 19 |
+
0.26862954,
|
| 20 |
+
0.26130258,
|
| 21 |
+
0.27577711
|
| 22 |
+
],
|
| 23 |
+
"resample": 3,
|
| 24 |
+
"rescale_factor": 0.00392156862745098,
|
| 25 |
+
"size": {
|
| 26 |
+
"shortest_edge": 224
|
| 27 |
+
}
|
| 28 |
+
}
|
pretrained_weights/stable-video-diffusion-img2vid-xt/image_encoder/config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/image_encoder",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CLIPVisionModelWithProjection"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"dropout": 0.0,
|
| 8 |
+
"hidden_act": "gelu",
|
| 9 |
+
"hidden_size": 1280,
|
| 10 |
+
"image_size": 224,
|
| 11 |
+
"initializer_factor": 1.0,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 5120,
|
| 14 |
+
"layer_norm_eps": 1e-05,
|
| 15 |
+
"model_type": "clip_vision_model",
|
| 16 |
+
"num_attention_heads": 16,
|
| 17 |
+
"num_channels": 3,
|
| 18 |
+
"num_hidden_layers": 32,
|
| 19 |
+
"patch_size": 14,
|
| 20 |
+
"projection_dim": 1024,
|
| 21 |
+
"torch_dtype": "float16",
|
| 22 |
+
"transformers_version": "4.34.0.dev0"
|
| 23 |
+
}
|
pretrained_weights/stable-video-diffusion-img2vid-xt/model_index.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "StableVideoDiffusionPipeline",
|
| 3 |
+
"_diffusers_version": "0.24.0.dev0",
|
| 4 |
+
"_name_or_path": "diffusers/svd-xt",
|
| 5 |
+
"feature_extractor": [
|
| 6 |
+
"transformers",
|
| 7 |
+
"CLIPImageProcessor"
|
| 8 |
+
],
|
| 9 |
+
"image_encoder": [
|
| 10 |
+
"transformers",
|
| 11 |
+
"CLIPVisionModelWithProjection"
|
| 12 |
+
],
|
| 13 |
+
"scheduler": [
|
| 14 |
+
"diffusers",
|
| 15 |
+
"EulerDiscreteScheduler"
|
| 16 |
+
],
|
| 17 |
+
"unet": [
|
| 18 |
+
"diffusers",
|
| 19 |
+
"UNetSpatioTemporalConditionModel"
|
| 20 |
+
],
|
| 21 |
+
"vae": [
|
| 22 |
+
"diffusers",
|
| 23 |
+
"AutoencoderKLTemporalDecoder"
|
| 24 |
+
]
|
| 25 |
+
}
|
pretrained_weights/stable-video-diffusion-img2vid-xt/output_tile.gif
ADDED
|
Git LFS Details
|
pretrained_weights/stable-video-diffusion-img2vid-xt/scheduler/scheduler_config.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "EulerDiscreteScheduler",
|
| 3 |
+
"_diffusers_version": "0.24.0.dev0",
|
| 4 |
+
"beta_end": 0.012,
|
| 5 |
+
"beta_schedule": "scaled_linear",
|
| 6 |
+
"beta_start": 0.00085,
|
| 7 |
+
"clip_sample": false,
|
| 8 |
+
"interpolation_type": "linear",
|
| 9 |
+
"num_train_timesteps": 1000,
|
| 10 |
+
"prediction_type": "v_prediction",
|
| 11 |
+
"set_alpha_to_one": false,
|
| 12 |
+
"sigma_max": 700.0,
|
| 13 |
+
"sigma_min": 0.002,
|
| 14 |
+
"skip_prk_steps": true,
|
| 15 |
+
"steps_offset": 1,
|
| 16 |
+
"timestep_spacing": "leading",
|
| 17 |
+
"timestep_type": "continuous",
|
| 18 |
+
"trained_betas": null,
|
| 19 |
+
"use_karras_sigmas": true
|
| 20 |
+
}
|
pretrained_weights/stable-video-diffusion-img2vid-xt/svd_xt.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2652c23d64a1da5f14d55011b9b6dce55f2e72e395719f1cd1f8a079b00a451
|
| 3 |
+
size 9559625980
|
pretrained_weights/stable-video-diffusion-img2vid-xt/svd_xt_image_decoder.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:99aa889bf6d1ca28e026755b83ba37e3072ad79b45dd4c94fae14bee7482263b
|
| 3 |
+
size 9503252964
|
pretrained_weights/stable-video-diffusion-img2vid-xt/unet/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "UNetSpatioTemporalConditionModel",
|
| 3 |
+
"_diffusers_version": "0.24.0.dev0",
|
| 4 |
+
"_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/unet",
|
| 5 |
+
"addition_time_embed_dim": 256,
|
| 6 |
+
"block_out_channels": [
|
| 7 |
+
320,
|
| 8 |
+
640,
|
| 9 |
+
1280,
|
| 10 |
+
1280
|
| 11 |
+
],
|
| 12 |
+
"cross_attention_dim": 1024,
|
| 13 |
+
"down_block_types": [
|
| 14 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 15 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 16 |
+
"CrossAttnDownBlockSpatioTemporal",
|
| 17 |
+
"DownBlockSpatioTemporal"
|
| 18 |
+
],
|
| 19 |
+
"in_channels": 8,
|
| 20 |
+
"layers_per_block": 2,
|
| 21 |
+
"num_attention_heads": [
|
| 22 |
+
5,
|
| 23 |
+
10,
|
| 24 |
+
20,
|
| 25 |
+
20
|
| 26 |
+
],
|
| 27 |
+
"num_frames": 25,
|
| 28 |
+
"out_channels": 4,
|
| 29 |
+
"projection_class_embeddings_input_dim": 768,
|
| 30 |
+
"sample_size": 96,
|
| 31 |
+
"transformer_layers_per_block": 1,
|
| 32 |
+
"up_block_types": [
|
| 33 |
+
"UpBlockSpatioTemporal",
|
| 34 |
+
"CrossAttnUpBlockSpatioTemporal",
|
| 35 |
+
"CrossAttnUpBlockSpatioTemporal",
|
| 36 |
+
"CrossAttnUpBlockSpatioTemporal"
|
| 37 |
+
]
|
| 38 |
+
}
|
pretrained_weights/stable-video-diffusion-img2vid-xt/unet/diffusion_pytorch_model.fp16.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9fbc02e90f37d422f5e3a4aeaee95f6629dc8c45ca211b951626e930daf2bddf
|
| 3 |
+
size 3049435868
|
pretrained_weights/stable-video-diffusion-img2vid-xt/unet/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7783d82729af04f26ded4641a5952617fe331fc46add332fb9e47674fecc6ad7
|
| 3 |
+
size 6098682464
|
pretrained_weights/stable-video-diffusion-img2vid-xt/vae/config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKLTemporalDecoder",
|
| 3 |
+
"_diffusers_version": "0.24.0.dev0",
|
| 4 |
+
"_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/vae",
|
| 5 |
+
"block_out_channels": [
|
| 6 |
+
128,
|
| 7 |
+
256,
|
| 8 |
+
512,
|
| 9 |
+
512
|
| 10 |
+
],
|
| 11 |
+
"down_block_types": [
|
| 12 |
+
"DownEncoderBlock2D",
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D"
|
| 16 |
+
],
|
| 17 |
+
"force_upcast": true,
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 4,
|
| 20 |
+
"layers_per_block": 2,
|
| 21 |
+
"out_channels": 3,
|
| 22 |
+
"sample_size": 768,
|
| 23 |
+
"scaling_factor": 0.18215
|
| 24 |
+
}
|
pretrained_weights/stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.fp16.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:af602cd0eb4ad6086ec94fbf1438dfb1be5ec9ac03fd0215640854e90d6463a3
|
| 3 |
+
size 195531910
|
pretrained_weights/stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d92aa595a53d9da9faf594f09910ee869d5d567c8bb0362d5095673c69997d6
|
| 3 |
+
size 391017740
|
pretrained_weights/xnemo_denoising_unet.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0ff582dff6e19b08278378cfc244cf7203c6f70e3dcaba492ec39f9abb9be3d2
|
| 3 |
+
size 4927016814
|
pretrained_weights/xnemo_motion_encoder.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0230c49cebff21fd81c14fc61fc509ab1120b61415d40571e7dc1b9df1fc6b6f
|
| 3 |
+
size 246869630
|