xixircc committed on
Commit
d9fac04
·
verified ·
1 Parent(s): bd29c5e

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. blaze_face_short_range.tflite +3 -0
  3. face-parsing/.gitattributes +28 -0
  4. face-parsing/README.md +165 -0
  5. face-parsing/config.json +111 -0
  6. face-parsing/demo.png +3 -0
  7. face-parsing/model.safetensors +3 -0
  8. face-parsing/onnx/model.onnx +3 -0
  9. face-parsing/onnx/model_quantized.onnx +3 -0
  10. face-parsing/preprocessor_config.json +23 -0
  11. face-parsing/quantize_config.json +33 -0
  12. models/unet_3d.py +727 -0
  13. models/unet_3d_blocks.py +1121 -0
  14. pretrained_weights/sd-image-variations-diffusers/.gitattributes +32 -0
  15. pretrained_weights/sd-image-variations-diffusers/README.md +226 -0
  16. pretrained_weights/sd-image-variations-diffusers/alias-montage.jpg +3 -0
  17. pretrained_weights/sd-image-variations-diffusers/default-montage.jpg +3 -0
  18. pretrained_weights/sd-image-variations-diffusers/earring.jpg +3 -0
  19. pretrained_weights/sd-image-variations-diffusers/feature_extractor/preprocessor_config.json +28 -0
  20. pretrained_weights/sd-image-variations-diffusers/image_encoder/config.json +23 -0
  21. pretrained_weights/sd-image-variations-diffusers/image_encoder/pytorch_model.bin +3 -0
  22. pretrained_weights/sd-image-variations-diffusers/inputs.jpg +0 -0
  23. pretrained_weights/sd-image-variations-diffusers/model_index.json +29 -0
  24. pretrained_weights/sd-image-variations-diffusers/safety_checker/config.json +181 -0
  25. pretrained_weights/sd-image-variations-diffusers/scheduler/scheduler_config.json +13 -0
  26. pretrained_weights/sd-image-variations-diffusers/unet/config.json +40 -0
  27. pretrained_weights/sd-image-variations-diffusers/unet/diffusion_pytorch_model.bin +3 -0
  28. pretrained_weights/sd-image-variations-diffusers/v1-montage.jpg +3 -0
  29. pretrained_weights/sd-image-variations-diffusers/v2-montage.jpg +3 -0
  30. pretrained_weights/sd-image-variations-diffusers/vae/config.json +30 -0
  31. pretrained_weights/sd-image-variations-diffusers/vae/diffusion_pytorch_model.bin +3 -0
  32. pretrained_weights/stable-video-diffusion-img2vid-xt/.gitattributes +36 -0
  33. pretrained_weights/stable-video-diffusion-img2vid-xt/LICENSE.md +58 -0
  34. pretrained_weights/stable-video-diffusion-img2vid-xt/README.md +99 -0
  35. pretrained_weights/stable-video-diffusion-img2vid-xt/comparison.png +3 -0
  36. pretrained_weights/stable-video-diffusion-img2vid-xt/feature_extractor/preprocessor_config.json +28 -0
  37. pretrained_weights/stable-video-diffusion-img2vid-xt/image_encoder/config.json +23 -0
  38. pretrained_weights/stable-video-diffusion-img2vid-xt/model_index.json +25 -0
  39. pretrained_weights/stable-video-diffusion-img2vid-xt/output_tile.gif +3 -0
  40. pretrained_weights/stable-video-diffusion-img2vid-xt/scheduler/scheduler_config.json +20 -0
  41. pretrained_weights/stable-video-diffusion-img2vid-xt/svd_xt.safetensors +3 -0
  42. pretrained_weights/stable-video-diffusion-img2vid-xt/svd_xt_image_decoder.safetensors +3 -0
  43. pretrained_weights/stable-video-diffusion-img2vid-xt/unet/config.json +38 -0
  44. pretrained_weights/stable-video-diffusion-img2vid-xt/unet/diffusion_pytorch_model.fp16.safetensors +3 -0
  45. pretrained_weights/stable-video-diffusion-img2vid-xt/unet/diffusion_pytorch_model.safetensors +3 -0
  46. pretrained_weights/stable-video-diffusion-img2vid-xt/vae/config.json +24 -0
  47. pretrained_weights/stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.fp16.safetensors +3 -0
  48. pretrained_weights/stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.safetensors +3 -0
  49. pretrained_weights/xnemo_denoising_unet.pth +3 -0
  50. pretrained_weights/xnemo_motion_encoder.pth +3 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ face-parsing/demo.png filter=lfs diff=lfs merge=lfs -text
37
+ pretrained_weights/sd-image-variations-diffusers/alias-montage.jpg filter=lfs diff=lfs merge=lfs -text
38
+ pretrained_weights/sd-image-variations-diffusers/default-montage.jpg filter=lfs diff=lfs merge=lfs -text
39
+ pretrained_weights/sd-image-variations-diffusers/earring.jpg filter=lfs diff=lfs merge=lfs -text
40
+ pretrained_weights/sd-image-variations-diffusers/v1-montage.jpg filter=lfs diff=lfs merge=lfs -text
41
+ pretrained_weights/sd-image-variations-diffusers/v2-montage.jpg filter=lfs diff=lfs merge=lfs -text
42
+ pretrained_weights/stable-video-diffusion-img2vid-xt/comparison.png filter=lfs diff=lfs merge=lfs -text
43
+ pretrained_weights/stable-video-diffusion-img2vid-xt/output_tile.gif filter=lfs diff=lfs merge=lfs -text
blaze_face_short_range.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4578f35940bf5a1a655214a1cce5cab13eba73c1297cd78e1a04c2380b0152f
3
+ size 229746
face-parsing/.gitattributes ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
face-parsing/README.md ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ library_name: transformers
4
+ tags:
5
+ - vision
6
+ - image-segmentation
7
+ - nvidia/mit-b5
8
+ - transformers.js
9
+ - onnx
10
+ datasets:
11
+ - celebamaskhq
12
+ ---
13
+
14
+ # Face Parsing
15
+
16
+ ![example image and output](demo.png)
17
+
18
+ [Semantic segmentation](https://huggingface.co/docs/transformers/tasks/semantic_segmentation) model fine-tuned from [nvidia/mit-b5](https://huggingface.co/nvidia/mit-b5) with [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) for face parsing. For additional options, see the Transformers [Segformer docs](https://huggingface.co/docs/transformers/model_doc/segformer).
19
+
20
+ > ONNX model for web inference contributed by [Xenova](https://huggingface.co/Xenova).
21
+
22
+ ## Usage in Python
23
+
24
+ The exhaustive list of labels can be found in [config.json](https://huggingface.co/jonathandinu/face-parsing/blob/65972ac96180b397f86fda0980bbe68e6ee01b8f/config.json#L30).
25
+
26
+ | id | label | note |
27
+ | :-: | :--------- | :---------------- |
28
+ | 0 | background | |
29
+ | 1 | skin | |
30
+ | 2 | nose | |
31
+ | 3 | eye_g | eyeglasses |
32
+ | 4 | l_eye | left eye |
33
+ | 5 | r_eye | right eye |
34
+ | 6 | l_brow | left eyebrow |
35
+ | 7 | r_brow | right eyebrow |
36
+ | 8 | l_ear | left ear |
37
+ | 9 | r_ear | right ear |
38
+ | 10 | mouth | area between lips |
39
+ | 11 | u_lip | upper lip |
40
+ | 12 | l_lip | lower lip |
41
+ | 13 | hair | |
42
+ | 14 | hat | |
43
+ | 15 | ear_r | earring |
44
+ | 16 | neck_l | necklace |
45
+ | 17 | neck | |
46
+ | 18 | cloth | clothing |
47
+
48
+ ```python
49
import torch
from torch import nn
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

from PIL import Image
import matplotlib.pyplot as plt
import requests

# pick the first available device: CUDA, then MPS, then CPU
if torch.cuda.is_available():
    # Device for NVIDIA or AMD GPUs
    device = "cuda"
elif torch.backends.mps.is_available():
    # Device for Apple Silicon (Metal Performance Shaders)
    device = "mps"
else:
    device = "cpu"

# load models
image_processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
model.to(device)

# expects a PIL.Image or torch.Tensor
url = "https://images.unsplash.com/photo-1539571696357-5a69c17a67c6"
# force 3 channels: the processor normalizes with a 3-element mean/std, so
# grayscale or RGBA inputs would otherwise not match
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# run inference on image; no_grad avoids building an autograd graph we never use
inputs = image_processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits  # shape (batch_size, num_labels, ~height/4, ~width/4)

    # resize output to match input image dimensions
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],  # PIL gives (W, H); interpolate wants (H, W)
        mode="bilinear",
        align_corners=False,
    )

# get label masks
labels = upsampled_logits.argmax(dim=1)[0]

# move to CPU to visualize in matplotlib
labels_viz = labels.cpu().numpy()
plt.imshow(labels_viz)
plt.show()
95
+ ```
96
+
97
+ ## Usage in the browser (Transformers.js)
98
+
99
+ ```js
100
import {
  pipeline,
  env,
} from "https://cdn.jsdelivr.net/npm/@xenova/transformers@2.14.0";

// important to prevent errors since the model files are likely remote on HF hub
env.allowLocalModels = false;

// instantiate image segmentation pipeline with pretrained face parsing model
// (declared with `const`: ES modules are strict, so assigning to an
// undeclared name would throw a ReferenceError)
const model = await pipeline("image-segmentation", "jonathandinu/face-parsing");

// image to segment (same sample image as the Python example)
const url = "https://images.unsplash.com/photo-1539571696357-5a69c17a67c6";

// async inference since it could take a few seconds
const output = await model(url);

// each label is a separate mask object
// [
//   { score: null, label: 'background', mask: transformers.js RawImage { ... }}
//   { score: null, label: 'hair', mask: transformers.js RawImage { ... }}
//    ...
// ]
for (const m of output) {
  // console.log, not print(): in a browser `print` is window.print (print dialog)
  console.log(`Found ${m.label}`);
  m.mask.save(`${m.label}.png`);
}
124
+ ```
125
+
126
+ ### p5.js
127
+
128
+ Since [p5.js](https://p5js.org/) uses an animation loop abstraction, we need to take care when loading the model and making predictions.
129
+
130
+ ```js
131
+ // ...
132
+
133
+ // asynchronously load transformers.js and instantiate model
134
+ async function preload() {
135
+ // load transformers.js library with a dynamic import
136
+ const { pipeline, env } = await import(
137
+ "https://cdn.jsdelivr.net/npm/@xenova/transformers@2.14.0"
138
+ );
139
+
140
+ // important to prevent errors since the model files are remote on HF hub
141
+ env.allowLocalModels = false;
142
+
143
+ // instantiate image segmentation pipeline with pretrained face parsing model
144
+ model = await pipeline("image-segmentation", "jonathandinu/face-parsing");
145
+
146
+ print("face-parsing model loaded");
147
+ }
148
+
149
+ // ...
150
+ ```
151
+
152
+ [full p5.js example](https://editor.p5js.org/jonathan.ai/sketches/wZn15Dvgh)
153
+
154
+ ### Model Description
155
+
156
+ - **Developed by:** [Jonathan Dinu](https://twitter.com/jonathandinu)
157
+ - **Model type:** Transformer-based semantic segmentation image model
158
+ - **License:** non-commercial research and educational purposes
159
+ - **Resources for more information:** Transformers docs on [Segformer](https://huggingface.co/docs/transformers/model_doc/segformer) and/or the [original research paper](https://arxiv.org/abs/2105.15203).
160
+
161
+ ## Limitations and Bias
162
+
163
+ ### Bias
164
+
165
+ While the capabilities of computer vision models are impressive, they can also reinforce or exacerbate social biases. The [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset used for fine-tuning is large but not necessarily diverse or representative. Also, its images are of... just celebrities.
face-parsing/config.json ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "jonathandinu/face-parsing",
3
+ "architectures": [
4
+ "SegformerForSemanticSegmentation"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "classifier_dropout_prob": 0.1,
8
+ "decoder_hidden_size": 768,
9
+ "depths": [
10
+ 3,
11
+ 6,
12
+ 40,
13
+ 3
14
+ ],
15
+ "downsampling_rates": [
16
+ 1,
17
+ 4,
18
+ 8,
19
+ 16
20
+ ],
21
+ "drop_path_rate": 0.1,
22
+ "hidden_act": "gelu",
23
+ "hidden_dropout_prob": 0.0,
24
+ "hidden_sizes": [
25
+ 64,
26
+ 128,
27
+ 320,
28
+ 512
29
+ ],
30
+ "id2label": {
31
+ "0": "background",
32
+ "1": "skin",
33
+ "2": "nose",
34
+ "3": "eye_g",
35
+ "4": "l_eye",
36
+ "5": "r_eye",
37
+ "6": "l_brow",
38
+ "7": "r_brow",
39
+ "8": "l_ear",
40
+ "9": "r_ear",
41
+ "10": "mouth",
42
+ "11": "u_lip",
43
+ "12": "l_lip",
44
+ "13": "hair",
45
+ "14": "hat",
46
+ "15": "ear_r",
47
+ "16": "neck_l",
48
+ "17": "neck",
49
+ "18": "cloth"
50
+ },
51
+ "image_size": 224,
52
+ "initializer_range": 0.02,
53
+ "label2id": {
54
+ "background": 0,
55
+ "skin": 1,
56
+ "nose": 2,
57
+ "eye_g": 3,
58
+ "l_eye": 4,
59
+ "r_eye": 5,
60
+ "l_brow": 6,
61
+ "r_brow": 7,
62
+ "l_ear": 8,
63
+ "r_ear": 9,
64
+ "mouth": 10,
65
+ "u_lip": 11,
66
+ "l_lip": 12,
67
+ "hair": 13,
68
+ "hat": 14,
69
+ "ear_r": 15,
70
+ "neck_l": 16,
71
+ "neck": 17,
72
+ "cloth": 18
73
+ },
74
+ "layer_norm_eps": 1e-06,
75
+ "mlp_ratios": [
76
+ 4,
77
+ 4,
78
+ 4,
79
+ 4
80
+ ],
81
+ "model_type": "segformer",
82
+ "num_attention_heads": [
83
+ 1,
84
+ 2,
85
+ 5,
86
+ 8
87
+ ],
88
+ "num_channels": 3,
89
+ "num_encoder_blocks": 4,
90
+ "patch_sizes": [
91
+ 7,
92
+ 3,
93
+ 3,
94
+ 3
95
+ ],
96
+ "reshape_last_stage": true,
97
+ "semantic_loss_ignore_index": 255,
98
+ "sr_ratios": [
99
+ 8,
100
+ 4,
101
+ 2,
102
+ 1
103
+ ],
104
+ "strides": [
105
+ 4,
106
+ 2,
107
+ 2,
108
+ 2
109
+ ],
110
+ "transformers_version": "4.37.0.dev0"
111
+ }
face-parsing/demo.png ADDED

Git LFS Details

  • SHA256: 31c74d29ab9e45f3401f404f7bfc09e2cf9f5825611f07dc20b25d00eb1cac8a
  • Pointer size: 131 Bytes
  • Size of remote file: 645 kB
face-parsing/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2bec795a8c243db71bd95be538fd62559003566466c71237e45c99b920f4b62
3
+ size 338580732
face-parsing/onnx/model.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d4e67af60ff78184745ebf74cc15163c0adc27d45cdeba31e3a03d1096fb8c3
3
+ size 340316611
face-parsing/onnx/model_quantized.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bab9bfb3cb979f3098ac3b934b1641dbf87f835e0b03c2ca6d88dcf18c83d27
3
+ size 89439678
face-parsing/preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "do_reduce_labels": false,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.485,
8
+ 0.456,
9
+ 0.406
10
+ ],
11
+ "image_processor_type": "SegformerFeatureExtractor",
12
+ "image_std": [
13
+ 0.229,
14
+ 0.224,
15
+ 0.225
16
+ ],
17
+ "resample": 2,
18
+ "rescale_factor": 0.00392156862745098,
19
+ "size": {
20
+ "height": 512,
21
+ "width": 512
22
+ }
23
+ }
face-parsing/quantize_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "per_channel": true,
3
+ "reduce_range": true,
4
+ "per_model_config": {
5
+ "model": {
6
+ "op_types": [
7
+ "Unsqueeze",
8
+ "Shape",
9
+ "Transpose",
10
+ "Sqrt",
11
+ "Gather",
12
+ "Slice",
13
+ "Erf",
14
+ "Div",
15
+ "Reshape",
16
+ "Add",
17
+ "Cast",
18
+ "Sub",
19
+ "Concat",
20
+ "ReduceMean",
21
+ "Mul",
22
+ "Conv",
23
+ "Constant",
24
+ "Resize",
25
+ "Softmax",
26
+ "Pow",
27
+ "Relu",
28
+ "MatMul"
29
+ ],
30
+ "weight_type": "QUInt8"
31
+ }
32
+ }
33
+ }
models/unet_3d.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates.
6
+ #
7
+ # Original file was released under Aniportrait, with the full license text
8
+ # available at https://github.com/Zejun-Yang/AniPortrait/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+ # *************************************************************************
12
+ from collections import OrderedDict
13
+ from dataclasses import dataclass
14
+ import pdb
15
+ from os import PathLike
16
+ from pathlib import Path
17
+ from typing import Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.utils.checkpoint
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.attention_processor import AttentionProcessor
25
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
28
+ from safetensors.torch import load_file
29
+
30
+ from .resnet import InflatedConv3d, InflatedGroupNorm
31
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """Output container for the 3D conditional UNet.

    Declares a single field, ``sample``, annotated as ``torch.FloatTensor``.
    NOTE(review): presumably the tensor returned by the model's forward
    pass -- the forward method is outside this view, confirm there.
    """

    sample: torch.FloatTensor
39
+
40
+
41
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
42
+ _supports_gradient_checkpointing = True
43
+
44
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "DownBlock3D",
    ),
    mid_block_type: str = "UNetMidBlock3DCrossAttn",
    up_block_types: Tuple[str] = (
        "UpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
    ),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: int = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: int = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: int = 1280,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    use_inflated_groupnorm=False,
    # Additional
    use_motion_module=False,
    use_temporal_module=False,
    motion_module_resolutions=(1, 2, 4, 8),
    motion_module_mid_block=False,
    motion_module_decoder_only=False,
    motion_module_type=None,
    temporal_module_type=None,
    motion_module_kwargs={},
    temporal_module_kwargs={},
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
):
    """Build the UNet: input conv, time/class embeddings, down/mid/up blocks, output head.

    Sub-modules are created in this order: ``conv_in``, ``time_proj``,
    ``time_embedding``, ``class_embedding``, ``down_blocks``, ``mid_block``,
    ``up_blocks``, ``conv_norm_out``, ``conv_act``, ``conv_out``.

    Args:
        block_out_channels: Output channel count of each encoder block. The
            decoder uses the reversed sequence; the time-embedding width is
            four times the first entry.
        down_block_types / up_block_types: Block type names handed to
            ``get_down_block`` / ``get_up_block``, one per level.
        mid_block_type: Only ``"UNetMidBlock3DCrossAttn"`` is accepted.
        only_cross_attention / attention_head_dim: A scalar is broadcast to
            one value per down block; a sequence is indexed per block.
        class_embed_type / num_class_embeds: ``None`` with a class count
            gives an ``nn.Embedding``; ``"timestep"`` a ``TimestepEmbedding``;
            ``"identity"`` an ``nn.Identity``; otherwise no class embedding.
        use_motion_module / use_temporal_module: Turned on for a block only
            when its resolution index (``2**level``) is listed in
            ``motion_module_resolutions``. Encoder blocks additionally
            require ``motion_module_decoder_only`` to be false; the mid block
            requires ``motion_module_mid_block``.
        motion_module_kwargs / temporal_module_kwargs: Forwarded unchanged to
            every block. The ``{}`` defaults are shared objects; this
            constructor does not mutate them.

    The remaining arguments are forwarded to the block constructors as-is.

    Raises:
        ValueError: If ``mid_block_type`` is not ``"UNetMidBlock3DCrossAttn"``.
    """
    super().__init__()

    self.sample_size = sample_size
    num_blocks = len(block_out_channels)
    time_embed_dim = block_out_channels[0] * 4

    # input
    self.conv_in = InflatedConv3d(
        in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
    )

    # time
    self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
    timestep_input_dim = block_out_channels[0]

    self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

    # class embedding
    if class_embed_type is None and num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
    elif class_embed_type == "identity":
        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
    else:
        self.class_embedding = None

    self.down_blocks = nn.ModuleList([])
    self.mid_block = None
    self.up_blocks = nn.ModuleList([])

    if isinstance(only_cross_attention, bool):
        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    # Keyword arguments that every down / mid / up block receives unchanged.
    shared_block_kwargs = dict(
        temb_channels=time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        resnet_groups=norm_num_groups,
        cross_attention_dim=cross_attention_dim,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
        resnet_time_scale_shift=resnet_time_scale_shift,
        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
        unet_use_temporal_attention=unet_use_temporal_attention,
        use_inflated_groupnorm=use_inflated_groupnorm,
        motion_module_type=motion_module_type,
        temporal_module_type=temporal_module_type,
        motion_module_kwargs=motion_module_kwargs,
        temporal_module_kwargs=temporal_module_kwargs,
    )

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        res = 2**i
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == num_blocks - 1
        in_encoder_resolutions = (res in motion_module_resolutions) and (
            not motion_module_decoder_only
        )

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block,
            in_channels=input_channel,
            out_channels=output_channel,
            add_downsample=not is_final_block,
            attn_num_head_channels=attention_head_dim[i],
            downsample_padding=downsample_padding,
            only_cross_attention=only_cross_attention[i],
            use_motion_module=use_motion_module and in_encoder_resolutions,
            use_temporal_module=use_temporal_module and in_encoder_resolutions,
            **shared_block_kwargs,
        )
        self.down_blocks.append(down_block)

    # mid
    if mid_block_type == "UNetMidBlock3DCrossAttn":
        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=block_out_channels[-1],
            output_scale_factor=mid_block_scale_factor,
            attn_num_head_channels=attention_head_dim[-1],
            use_motion_module=use_motion_module and motion_module_mid_block,
            use_temporal_module=use_temporal_module and motion_module_mid_block,
            **shared_block_kwargs,
        )
    else:
        raise ValueError(f"unknown mid_block_type : {mid_block_type}")

    # count how many layers upsample the videos
    self.num_upsamplers = 0

    # up
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_attention_head_dim = list(reversed(attention_head_dim))
    only_cross_attention = list(reversed(only_cross_attention))
    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        # Mirror the encoder's ``2**level`` index. The previous hard-coded
        # ``2 ** (3 - i)`` was only correct for exactly four blocks.
        res = 2 ** (num_blocks - 1 - i)
        is_final_block = i == num_blocks - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, num_blocks - 1)]

        # add upsample block for all BUT final layer
        add_upsample = not is_final_block
        if add_upsample:
            self.num_upsamplers += 1

        in_decoder_resolutions = res in motion_module_resolutions

        up_block = get_up_block(
            up_block_type,
            num_layers=layers_per_block + 1,
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            add_upsample=add_upsample,
            attn_num_head_channels=reversed_attention_head_dim[i],
            only_cross_attention=only_cross_attention[i],
            use_motion_module=use_motion_module and in_decoder_resolutions,
            use_temporal_module=use_temporal_module and in_decoder_resolutions,
            **shared_block_kwargs,
        )
        self.up_blocks.append(up_block)

    # out
    norm_cls = InflatedGroupNorm if use_inflated_groupnorm else nn.GroupNorm
    self.conv_norm_out = norm_cls(
        num_channels=block_out_channels[0],
        num_groups=norm_num_groups,
        eps=norm_eps,
    )
    self.conv_act = nn.SiLU()
    self.conv_out = InflatedConv3d(
        block_out_channels[0], out_channels, kernel_size=3, padding=1
    )
277
+
278
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect the attention processors of the model.

    Sub-modules are visited depth-first in registration order. Any child
    whose name contains ``"temporal_transformer"`` is skipped together with
    its whole subtree.

    Returns:
        `dict` mapping ``"<dotted module path>.processor"`` to the
        ``processor`` attribute of every visited module that exposes
        ``set_processor``.
    """

    def _eligible(prefix, parent):
        # Children of ``parent`` that are not temporal transformers, with full paths.
        return [
            (f"{prefix}{child_name}", child)
            for child_name, child in parent.named_children()
            if "temporal_transformer" not in child_name
        ]

    found = {}
    # Explicit stack; entries are pushed reversed so pops come out in pre-order.
    todo = _eligible("", self)[::-1]
    while todo:
        path, node = todo.pop()
        if hasattr(node, "set_processor"):
            found[f"{path}.processor"] = node.processor
        todo.extend(reversed(_eligible(f"{path}.", node)))

    return found
308
+
309
def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    With slicing enabled, attention modules split their input and compute
    attention in several steps, trading a small amount of speed for memory.

    Args:
        slice_size (`str` or `int` or `list(int)`):
            `"auto"` uses half of each layer's `sliceable_head_dim`. `"max"`
            uses a slice size of 1 for every layer. A single number is applied
            to every sliceable layer. A list supplies one value per sliceable
            layer, in depth-first order; `None` entries skip the size check.

    Raises:
        ValueError: If a list of the wrong length is given, or if any size
            exceeds the corresponding layer's `sliceable_head_dim`.
    """

    def _descend(node):
        # Pre-order walk: the node itself, then each child subtree.
        yield node
        for kid in node.children():
            yield from _descend(kid)

    # Every descendant (not `self`) that knows how to slice its attention.
    sliceable_layers = [
        node
        for top in self.children()
        for node in _descend(top)
        if hasattr(node, "set_attention_slice")
    ]
    sliceable_head_dims = [layer.sliceable_head_dim for layer in sliceable_layers]
    layer_count = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between
        # speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = layer_count * [1]

    if not isinstance(slice_size, list):
        slice_size = layer_count * [slice_size]

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Hand each layer its slice size, in the same depth-first order.
    for layer, size in zip(sliceable_layers, slice_size):
        layer.set_attention_slice(size)
379
+
380
def set_use_cross_frame_attention(self, value):
    """Forward `value` to every descendant exposing `set_use_cross_frame_attention`.

    Descendants are visited depth-first, parent before children; the module itself
    (`self`) is not called.
    """
    # Explicit stack instead of recursion; children are pushed reversed so the
    # first child is popped first, which keeps the pre-order of a recursive walk.
    pending = list(self.children())
    pending.reverse()
    while pending:
        node = pending.pop()
        if hasattr(node, "set_use_cross_frame_attention"):
            node.set_use_cross_frame_attention(value)
        below = list(node.children())
        below.reverse()
        pending.extend(below)
391
+
392
+ def _set_gradient_checkpointing(self, module, value=False):
393
+ if hasattr(module, "gradient_checkpointing"):
394
+ module.gradient_checkpointing = value
395
+
396
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
397
def set_attn_processor(
    self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
):
    r"""
    Install attention processor(s) on the sub-modules that expose `set_processor`.

    Sub-modules whose attribute name contains `"temporal_transformer"` are skipped together
    with everything below them.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either one processor, which is given to every visited layer, or a dict keyed by
            `"<dotted.module.path>.processor"`. Entries are popped from the dict as they are
            assigned, so the caller's dict is consumed.

    Raises:
        ValueError: if a dict is passed whose size differs from `len(self.attn_processors)`.
    """
    expected = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != expected:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {expected}. Please make sure to pass {expected} processor classes."
        )

    def _assign(path: str, node: torch.nn.Module, proc):
        if hasattr(node, "set_processor"):
            if isinstance(proc, dict):
                node.set_processor(proc.pop(f"{path}.processor"))
            else:
                node.set_processor(proc)

        for child_name, child in node.named_children():
            if "temporal_transformer" in child_name:
                continue
            _assign(f"{path}.{child_name}", child, proc)

    for top_name, top in self.named_children():
        if "temporal_transformer" in top_name:
            continue
        _assign(top_name, top, processor)
434
+
435
# forward: run the UNet on `sample` -- timestep (and optional class) embedding, conv_in,
# the down blocks, the mid block, the up blocks, then conv_norm_out / conv_act / conv_out.
# Returns UNet3DConditionOutput(sample=...) when return_dict is True, else the 1-tuple (sample,).
# NOTE(review): the docstring below names UNet2DConditionOutput and does not describe
# class_labels, pose_cond_fea, attention_mask, the additional residuals or skip_mm; its
# shape notes are presumably inherited from the 2D model -- confirm against the caller.
+ def forward(
436
+ self,
437
+ sample: torch.FloatTensor,
438
+ timestep: Union[torch.Tensor, float, int],
439
+ encoder_hidden_states: torch.Tensor,
440
+ class_labels: Optional[torch.Tensor] = None,
441
# pose_cond_fea: optional indexable of tensors added to `sample` (index 0 after conv_in,
# later indices in the down path); element shapes are not visible here -- TODO confirm.
+ pose_cond_fea = None,
442
+ attention_mask: Optional[torch.Tensor] = None,
443
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
444
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
445
+ return_dict: bool = True,
446
# skip_mm is passed unchanged to every down/mid/up block; presumably it disables the
# motion/temporal modules inside them -- verify in the block implementations.
+ skip_mm: bool = False,
447
+ ) -> Union[UNet3DConditionOutput, Tuple]:
448
+ r"""
449
+ Args:
450
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
451
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
452
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
453
+ return_dict (`bool`, *optional*, defaults to `True`):
454
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
455
+
456
+ Returns:
457
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
458
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
459
+ returning a tuple, the first element is the sample tensor.
460
+ """
461
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
462
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
463
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
464
+ # on the fly if necessary.
465
+
466
+ default_overall_up_factor = 2**self.num_upsamplers
467
+
468
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
469
+ forward_upsample_size = False
470
+ upsample_size = None
471
+
472
# only the last two dimensions of `sample` are checked for divisibility
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
473
+ logger.info("Forward upsample size to force interpolation output size.")
474
+ forward_upsample_size = True
475
+
476
+ # prepare attention_mask
477
+ if attention_mask is not None:
478
# turn the mask into an additive bias: entries equal to 1 become 0, entries equal to 0
# become -10000; then insert a new axis at position 1
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
479
+ attention_mask = attention_mask.unsqueeze(1)
480
+
481
+ # center input if necessary
482
+ if self.config.center_input_sample:
483
+ sample = 2 * sample - 1.0
484
+
485
+ # time
486
+ timesteps = timestep
487
+ if not torch.is_tensor(timesteps):
488
+ # This would be a good case for the `match` statement (Python 3.10+)
489
# on "mps" devices the 32-bit dtypes are chosen instead of the 64-bit ones
+ is_mps = sample.device.type == "mps"
490
+ if isinstance(timestep, float):
491
+ dtype = torch.float32 if is_mps else torch.float64
492
+ else:
493
+ dtype = torch.int32 if is_mps else torch.int64
494
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
495
# a 0-d tensor gets a leading axis and is moved to the sample's device
+ elif len(timesteps.shape) == 0:
496
+ timesteps = timesteps[None].to(sample.device)
497
+
498
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
499
+ timesteps = timesteps.expand(sample.shape[0])
500
+
501
+ t_emb = self.time_proj(timesteps)
502
+
503
+ # timesteps does not contain any weights and will always return f32 tensors
504
+ # but time_embedding might actually be running in fp16. so we need to cast here.
505
+ # there might be better ways to encapsulate this.
506
+ t_emb = t_emb.to(dtype=self.dtype)
507
+ emb = self.time_embedding(t_emb)
508
# optional class conditioning: the class embedding is summed into the time embedding
+ if self.class_embedding is not None:
509
+ if class_labels is None:
510
+ raise ValueError(
511
+ "class_labels should be provided when num_class_embeds > 0"
512
+ )
513
+
514
+ if self.config.class_embed_type == "timestep":
515
+ class_labels = self.time_proj(class_labels)
516
+
517
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
518
+ emb = emb + class_emb
519
+
520
+ # pre-process
521
+ sample = self.conv_in(sample)
522
# optional conditioning: element 0 of pose_cond_fea is added right after conv_in
+ if pose_cond_fea is not None:
523
+ sample = sample + pose_cond_fea[0]
524
+
525
+ # down
526
# down_block_res_samples accumulates the skip tensors, starting with the conv_in output
+ down_block_res_samples = (sample,)
527
# block_count is the next index into pose_cond_fea (index 0 was consumed above)
+ block_count = 1
528
+ for downsample_block in self.down_blocks:
529
# attention_mask is only handed to blocks flagged has_cross_attention; the other
# branch still receives encoder_hidden_states and skip_mm
+ if (
530
+ hasattr(downsample_block, "has_cross_attention")
531
+ and downsample_block.has_cross_attention
532
+ ):
533
+ sample, res_samples = downsample_block(
534
+ hidden_states=sample,
535
+ temb=emb,
536
+ encoder_hidden_states=encoder_hidden_states,
537
+ attention_mask=attention_mask,
538
+ skip_mm=skip_mm,
539
+ )
540
+ else:
541
+ sample, res_samples = downsample_block(
542
+ hidden_states=sample,
543
+ temb=emb,
544
+ encoder_hidden_states=encoder_hidden_states,
545
+ skip_mm=skip_mm,
546
+ )
547
# NOTE(review): indentation is not visible in this view; presumably the add and the
# block_count increment below run once per down block -- confirm in the real file.
+ if pose_cond_fea is not None:
548
+ sample = sample + pose_cond_fea[block_count]
549
+ block_count += 1
550
+ down_block_res_samples += res_samples
551
+
552
# externally supplied residuals are added element-wise to the collected skip tensors;
# zip() pairs them positionally and stops at the shorter sequence
+ if down_block_additional_residuals is not None:
553
+ new_down_block_res_samples = ()
554
+
555
+ for down_block_res_sample, down_block_additional_residual in zip(
556
+ down_block_res_samples, down_block_additional_residuals
557
+ ):
558
+ down_block_res_sample = (
559
+ down_block_res_sample + down_block_additional_residual
560
+ )
561
+ new_down_block_res_samples += (down_block_res_sample,)
562
+
563
+ down_block_res_samples = new_down_block_res_samples
564
+
565
+ # mid
566
+ sample = self.mid_block(
567
+ sample,
568
+ emb,
569
+ encoder_hidden_states=encoder_hidden_states,
570
+ attention_mask=attention_mask,
571
+ skip_mm=skip_mm,
572
+ )
573
+
574
+ if mid_block_additional_residual is not None:
575
+ sample = sample + mid_block_additional_residual
576
+
577
+ # up
578
+ for i, upsample_block in enumerate(self.up_blocks):
579
+ is_final_block = i == len(self.up_blocks) - 1
580
+
581
# take the last len(upsample_block.resnets) skip tensors for this block and drop them
# from the pending tuple
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
582
+ down_block_res_samples = down_block_res_samples[
583
+ : -len(upsample_block.resnets)
584
+ ]
585
+
586
+ # if we have not reached the final block and need to forward the
587
+ # upsample size, we do it here
588
+ if not is_final_block and forward_upsample_size:
589
# target size is every dimension after the first two of the next pending skip tensor
+ upsample_size = down_block_res_samples[-1].shape[2:]
590
+
591
+ if (
592
+ hasattr(upsample_block, "has_cross_attention")
593
+ and upsample_block.has_cross_attention
594
+ ):
595
+ sample = upsample_block(
596
+ hidden_states=sample,
597
+ temb=emb,
598
+ res_hidden_states_tuple=res_samples,
599
+ encoder_hidden_states=encoder_hidden_states,
600
+ upsample_size=upsample_size,
601
+ attention_mask=attention_mask,
602
+ skip_mm=skip_mm,
603
+ )
604
+ else:
605
+ sample = upsample_block(
606
+ hidden_states=sample,
607
+ temb=emb,
608
+ res_hidden_states_tuple=res_samples,
609
+ upsample_size=upsample_size,
610
+ encoder_hidden_states=encoder_hidden_states,
611
+ skip_mm=skip_mm,
612
+ )
613
+
614
+ # post-process
615
+ sample = self.conv_norm_out(sample)
616
+ sample = self.conv_act(sample)
617
+ sample = self.conv_out(sample)
618
+
619
+ if not return_dict:
620
+ return (sample,)
621
+
622
+ return UNet3DConditionOutput(sample=sample)
623
+
624
@classmethod
def from_pretrained_2d(
    cls,
    pretrained_model_path: PathLike,
    motion_module_path: PathLike,
    subfolder=None,
    unet_additional_kwargs=None,
    mm_zero_proj_out=False,
):
    """Build the 3D UNet from a 2D UNet checkpoint plus optional motion-module weights.

    The 2D `config.json` is loaded and its block types are overridden with the 3D variants,
    the model is instantiated from that config, and the 2D weights (safetensors preferred,
    otherwise the PyTorch weights file) are loaded non-strictly. If `motion_module_path`
    is an existing file, its weights are merged into the state dict first.

    Args:
        pretrained_model_path: Directory that holds `config.json` and the 2D weights.
        motion_module_path: File with motion-module weights (`.pth`, `.pt`, `.ckpt` or
            `.safetensors`). If it is not an existing file, no motion weights are loaded.
        subfolder: Optional sub-directory of `pretrained_model_path` to read from.
        unet_additional_kwargs: Extra keyword arguments for `cls.from_config`; `None`
            means no extra arguments.
        mm_zero_proj_out: When true, `proj_out` entries of the motion-module weights are
            dropped, so those layers keep the values the freshly built model gave them.

    Returns:
        The instantiated model with the merged weights loaded.

    Raises:
        RuntimeError: if `config.json` is missing, or the motion-module file has an
            unknown suffix.
        FileNotFoundError: if no weights file is found in `pretrained_model_path`.
    """
    # Fix: the default None used to be expanded with ** and raised a TypeError.
    if unet_additional_kwargs is None:
        unet_additional_kwargs = {}

    pretrained_model_path = Path(pretrained_model_path)
    motion_module_path = Path(motion_module_path)
    if subfolder is not None:
        pretrained_model_path = pretrained_model_path.joinpath(subfolder)
    logger.info(
        f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
    )

    config_file = pretrained_model_path / "config.json"
    if not (config_file.exists() and config_file.is_file()):
        raise RuntimeError(f"{config_file} does not exist or is not a file")

    # Re-target the 2D config at the 3D block implementations.
    unet_config = cls.load_config(config_file)
    unet_config["_class_name"] = cls.__name__
    unet_config["down_block_types"] = [
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "DownBlock3D",
    ]
    unet_config["up_block_types"] = [
        "UpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
    ]
    unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"

    model = cls.from_config(unet_config, **unet_additional_kwargs)

    # load the vanilla weights
    safetensors_file = pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME)
    torch_file = pretrained_model_path.joinpath(WEIGHTS_NAME)
    if safetensors_file.exists():
        logger.debug(f"loading safeTensors weights from {pretrained_model_path} ...")
        state_dict = load_file(safetensors_file, device="cpu")
    elif torch_file.exists():
        logger.debug(f"loading weights from {pretrained_model_path} ...")
        state_dict = torch.load(torch_file, map_location="cpu", weights_only=True)
    else:
        raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")

    # load the motion module weights
    if motion_module_path.exists() and motion_module_path.is_file():
        suffix = motion_module_path.suffix.lower()
        if suffix in [".pth", ".pt", ".ckpt"]:
            logger.info(f"Load motion module params from {motion_module_path}")
            motion_state_dict = torch.load(
                motion_module_path, map_location="cpu", weights_only=True
            )
        elif suffix == ".safetensors":
            motion_state_dict = load_file(motion_module_path, device="cpu")
        else:
            raise RuntimeError(
                f"unknown file format for motion module weights: {motion_module_path.suffix}"
            )

        # Checkpoint keys use "motion_modules."; they are renamed to this model's
        # "temporal_modules." prefix, and positional-encoder entries are dropped.
        motion_state_dict = {
            k.replace("motion_modules.", "temporal_modules."): v
            for k, v in motion_state_dict.items()
            if "pos_encoder" not in k
        }

        if mm_zero_proj_out:
            logger.info("Zero initialize proj_out layers in motion module...")
            # Dropping the proj_out entries leaves those layers as the model built them.
            motion_state_dict = OrderedDict(
                (k, v) for k, v in motion_state_dict.items() if "proj_out" not in k
            )

        # merge the state dicts
        state_dict.update(motion_state_dict)
    else:
        # Previously silent: make the skipped motion weights visible in the log.
        logger.warning(
            f"motion module weights not found at {motion_module_path}; skipping them"
        )

    # load the weights into the model
    m, u = model.load_state_dict(state_dict, strict=False)
    logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

    # Count parameters per module family in a single pass over the model.
    temporal_count = 0
    motion_count = 0
    for n, p in model.named_parameters():
        if "temporal_modules" in n:
            temporal_count += p.numel()
        if "motion_modules" in n:
            motion_count += p.numel()
    logger.info(
        f"Loaded {motion_count / 1e6}M-parameter motion module, Loaded {temporal_count / 1e6}M-parameter temporal module"
    )

    return model
models/unet_3d_blocks.py ADDED
@@ -0,0 +1,1121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *************************************************************************
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates.
6
+ #
7
+ # Original file was released under Aniportrait, with the full license text
8
+ # available at https://github.com/Zejun-Yang/AniPortrait/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+ # *************************************************************************
12
+ import pdb
13
+ from typing import Dict, Optional
14
+ import torch
15
+ from torch import nn
16
+
17
+ from src.models.motion_module import get_motion_module
18
+
19
+ # from .motion_module import get_motion_module
20
+ from src.models.resnet import Downsample3D, ResnetBlock3D, Upsample3D
21
+ from .transformer_3d import Transformer3DModel
22
+
23
+
24
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    use_temporal_module=None,
    temporal_module_type=None,
    temporal_module_kwargs=None,
):
    """Instantiate the down block named by `down_block_type`.

    A leading `"UNetRes"` prefix on the type name is ignored. Supported names are
    `"DownBlock3D"` and `"CrossAttnDownBlock3D"`; the latter additionally receives the
    attention-related arguments and requires `cross_attention_dim`.

    Raises:
        ValueError: for an unknown block type, or when `cross_attention_dim` is `None`
            for `"CrossAttnDownBlock3D"`.
    """
    if down_block_type.startswith("UNetRes"):
        down_block_type = down_block_type[7:]

    # Arguments both block kinds accept.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        add_downsample=add_downsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        downsample_padding=downsample_padding,
        resnet_time_scale_shift=resnet_time_scale_shift,
        use_inflated_groupnorm=use_inflated_groupnorm,
        use_motion_module=use_motion_module,
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
        use_temporal_module=use_temporal_module,
        temporal_module_type=temporal_module_type,
        temporal_module_kwargs=temporal_module_kwargs,
    )

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    if down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnDownBlock3D"
            )
        return CrossAttnDownBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            **shared_kwargs,
        )

    raise ValueError(f"{down_block_type} does not exist.")
110
+
111
+
112
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    use_temporal_module=None,
    temporal_module_type=None,
    temporal_module_kwargs=None,
):
    """Instantiate the up block named by `up_block_type`.

    A leading `"UNetRes"` prefix on the type name is ignored. Supported names are
    `"UpBlock3D"` and `"CrossAttnUpBlock3D"`; the latter additionally receives the
    attention-related arguments and requires `cross_attention_dim`.

    Raises:
        ValueError: for an unknown block type, or when `cross_attention_dim` is `None`
            for `"CrossAttnUpBlock3D"`.
    """
    if up_block_type.startswith("UNetRes"):
        up_block_type = up_block_type[7:]

    # Arguments both block kinds accept.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        prev_output_channel=prev_output_channel,
        temb_channels=temb_channels,
        add_upsample=add_upsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
        use_inflated_groupnorm=use_inflated_groupnorm,
        use_motion_module=use_motion_module,
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
        use_temporal_module=use_temporal_module,
        temporal_module_type=temporal_module_type,
        temporal_module_kwargs=temporal_module_kwargs,
    )

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    if up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnUpBlock3D"
            )
        return CrossAttnUpBlock3D(
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            **shared_kwargs,
        )

    raise ValueError(f"{up_block_type} does not exist.")
196
+
197
+
198
class UNetMidBlock3DCrossAttn(nn.Module):
    """Middle UNet stage with cross-attention.

    Layout: one leading `ResnetBlock3D`, then `num_layers` repetitions of
    `Transformer3DModel` -> optional motion module -> optional temporal module ->
    `ResnetBlock3D`. Channel count is `in_channels` throughout.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
        **transformer_kwargs,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def build_resnet():
            # Every ResNet block in this stage shares one configuration.
            return ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )

        def build_optional_module(enabled, module_type, module_kwargs):
            # Disabled slots are stored as None so the per-layer lists stay aligned.
            if not enabled:
                return None
            return get_motion_module(
                in_channels=in_channels,
                motion_module_type=module_type,
                motion_module_kwargs=module_kwargs,
            )

        # Sub-modules are created in the same order as before: the leading resnet, then
        # attention / motion module / resnet per layer, and the temporal modules last.
        # there is always at least one resnet
        resnet_blocks = [build_resnet()]
        attention_blocks = []
        motion_blocks = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attention_blocks.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    **transformer_kwargs,
                )
            )
            motion_blocks.append(
                build_optional_module(
                    use_motion_module, motion_module_type, motion_module_kwargs
                )
            )
            resnet_blocks.append(build_resnet())

        self.attentions = nn.ModuleList(attention_blocks)
        self.resnets = nn.ModuleList(resnet_blocks)
        self.motion_modules = nn.ModuleList(motion_blocks)
        self.temporal_modules = nn.ModuleList(
            [
                build_optional_module(
                    use_temporal_module, temporal_module_type, temporal_module_kwargs
                )
                for _ in range(num_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        skip_mm=False,
    ):
        """Run the stage and return the resulting hidden states.

        `encoder_hidden_states` may be a two-element list; the first element then feeds the
        attention layers and the second the motion modules. Otherwise the same value feeds
        both. When `skip_mm` is true, motion and temporal modules are bypassed.
        `attention_mask` is accepted but not used by this block.
        """
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states
        else:
            motion_hidden_states = encoder_hidden_states

        def wrap_for_checkpoint(module, return_dict=None):
            # Adapter so checkpoint() can call the module with positional inputs only.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        hidden_states = self.resnets[0](hidden_states, temb)

        layers = zip(
            self.attentions,
            self.resnets[1:],
            self.motion_modules,
            self.temporal_modules,
        )
        for attn, resnet, motion_module, temporal_module in layers:
            run_motion = (motion_module is not None) and not skip_mm
            run_temporal = (temporal_module is not None) and not skip_mm

            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if run_motion:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        wrap_for_checkpoint(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )
                if run_temporal:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        wrap_for_checkpoint(temporal_module),
                        hidden_states.requires_grad_(),
                        temb,
                        None,
                    )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet), hidden_states, temb
                )
            else:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample
                if run_motion:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=motion_hidden_states
                    )
                if run_temporal:
                    # NOTE(review): debug=True is passed only on this non-checkpointed
                    # path; its effect is defined by the temporal module -- confirm intent.
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None, debug=True
                    )

                hidden_states = resnet(hidden_states, temb)

        return hidden_states
391
+
392
+
393
+ class CrossAttnDownBlock3D(nn.Module):
394
+
395
+ def __init__(
396
+ self,
397
+ in_channels: int,
398
+ out_channels: int,
399
+ temb_channels: int,
400
+ dropout: float = 0.0,
401
+ num_layers: int = 1,
402
+ resnet_eps: float = 1e-6,
403
+ resnet_time_scale_shift: str = "default",
404
+ resnet_act_fn: str = "swish",
405
+ resnet_groups: int = 32,
406
+ resnet_pre_norm: bool = True,
407
+ attn_num_head_channels=1,
408
+ cross_attention_dim=1280,
409
+ output_scale_factor=1.0,
410
+ downsample_padding=1,
411
+ add_downsample=True,
412
+ dual_cross_attention=False,
413
+ use_linear_projection=False,
414
+ only_cross_attention=False,
415
+ upcast_attention=False,
416
+ unet_use_cross_frame_attention=None,
417
+ unet_use_temporal_attention=None,
418
+ use_inflated_groupnorm=None,
419
+ use_motion_module=None,
420
+ motion_module_type=None,
421
+ motion_module_kwargs=None,
422
+ use_temporal_module=None,
423
+ temporal_module_type=None,
424
+ temporal_module_kwargs=None,
425
+ **transformer_kwargs,
426
+ ):
427
+ super().__init__()
428
+ resnets = []
429
+ attentions = []
430
+ motion_modules = []
431
+
432
+ self.has_cross_attention = True
433
+ self.attn_num_head_channels = attn_num_head_channels
434
+
435
+ for i in range(num_layers):
436
+ in_channels = in_channels if i == 0 else out_channels
437
+ resnets.append(
438
+ ResnetBlock3D(
439
+ in_channels=in_channels,
440
+ out_channels=out_channels,
441
+ temb_channels=temb_channels,
442
+ eps=resnet_eps,
443
+ groups=resnet_groups,
444
+ dropout=dropout,
445
+ time_embedding_norm=resnet_time_scale_shift,
446
+ non_linearity=resnet_act_fn,
447
+ output_scale_factor=output_scale_factor,
448
+ pre_norm=resnet_pre_norm,
449
+ use_inflated_groupnorm=use_inflated_groupnorm,
450
+ )
451
+ )
452
+ if dual_cross_attention:
453
+ raise NotImplementedError
454
+ attentions.append(
455
+ Transformer3DModel(
456
+ attn_num_head_channels,
457
+ out_channels // attn_num_head_channels,
458
+ in_channels=out_channels,
459
+ num_layers=1,
460
+ cross_attention_dim=cross_attention_dim,
461
+ norm_num_groups=resnet_groups,
462
+ use_linear_projection=use_linear_projection,
463
+ only_cross_attention=only_cross_attention,
464
+ upcast_attention=upcast_attention,
465
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
466
+ unet_use_temporal_attention=unet_use_temporal_attention,
467
+ **transformer_kwargs,
468
+ )
469
+ )
470
+ motion_modules.append(
471
+ get_motion_module(
472
+ in_channels=out_channels,
473
+ motion_module_type=motion_module_type,
474
+ motion_module_kwargs=motion_module_kwargs,
475
+ )
476
+ if use_motion_module
477
+ else None
478
+ )
479
+
480
+ self.attentions = nn.ModuleList(attentions)
481
+ self.resnets = nn.ModuleList(resnets)
482
+ self.motion_modules = nn.ModuleList(motion_modules)
483
+ self.temporal_modules = nn.ModuleList(
484
+ [
485
+ (
486
+ get_motion_module(
487
+ in_channels=out_channels,
488
+ motion_module_type=temporal_module_type,
489
+ motion_module_kwargs=temporal_module_kwargs,
490
+ )
491
+ if use_temporal_module
492
+ else None
493
+ )
494
+ for _ in range(num_layers)
495
+ ]
496
+ )
497
+
498
+ if add_downsample:
499
+ self.downsamplers = nn.ModuleList(
500
+ [
501
+ Downsample3D(
502
+ out_channels,
503
+ use_conv=True,
504
+ out_channels=out_channels,
505
+ padding=downsample_padding,
506
+ name="op",
507
+ )
508
+ ]
509
+ )
510
+ else:
511
+ self.downsamplers = None
512
+
513
+ self.gradient_checkpointing = False
514
+
515
def forward(
    self,
    hidden_states,
    temb=None,
    encoder_hidden_states=None,
    attention_mask=None,
    skip_mm=False,
):
    """Run every (resnet, attention, motion, temporal) layer, then downsample.

    Args:
        hidden_states: Input handed to the first resnet.
        temb: Forwarded to each resnet, motion module and temporal module.
        encoder_hidden_states: Conditioning for the attention layers. A
            ``list`` is unpacked as ``[attention_states, motion_states]``;
            any other value is used for both roles.
        attention_mask: Accepted for interface compatibility; not read here.
        skip_mm: When truthy, motion and temporal modules are bypassed.

    Returns:
        ``(hidden_states, output_states)``. ``output_states`` is a tuple with
        the result of every layer and, when downsamplers exist, the
        downsampled result appended last.
    """
    if isinstance(encoder_hidden_states, list):
        encoder_hidden_states, motion_hidden_states = encoder_hidden_states
    else:
        motion_hidden_states = encoder_hidden_states

    def make_custom_forward(module, return_dict=None):
        # torch.utils.checkpoint only forwards positional tensors, so the
        # optional ``return_dict`` flag is baked into the closure.
        def custom_forward(*inputs):
            if return_dict is None:
                return module(*inputs)
            return module(*inputs, return_dict=return_dict)

        return custom_forward

    use_checkpoint = self.training and self.gradient_checkpointing
    layer_outputs = []

    layers = zip(
        self.resnets, self.attentions, self.motion_modules, self.temporal_modules
    )
    for resnet, attn, motion_module, temporal_module in layers:
        run_motion = (motion_module is not None) and not skip_mm
        run_temporal = (temporal_module is not None) and not skip_mm

        if use_checkpoint:
            hidden_states = torch.utils.checkpoint.checkpoint(
                make_custom_forward(resnet), hidden_states, temb
            )
            hidden_states = torch.utils.checkpoint.checkpoint(
                make_custom_forward(attn, return_dict=False),
                hidden_states,
                encoder_hidden_states,
            )[0]
            if run_motion:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    make_custom_forward(motion_module),
                    hidden_states,
                    temb,
                    motion_hidden_states,
                )
            if run_temporal:
                # NOTE(review): unlike the sibling blocks in this file, the
                # temporal module is called directly here (no checkpoint)
                # and without ``debug=True`` -- confirm this is intended.
                hidden_states = temporal_module(
                    hidden_states, temb, encoder_hidden_states=None
                )
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            ).sample
            if run_motion:
                hidden_states = motion_module(
                    hidden_states, temb, encoder_hidden_states=motion_hidden_states
                )
            if run_temporal:
                hidden_states = temporal_module(
                    hidden_states, temb, encoder_hidden_states=None, debug=True
                )

        layer_outputs.append(hidden_states)

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)
        layer_outputs.append(hidden_states)

    return hidden_states, tuple(layer_outputs)
606
+
607
+
608
class DownBlock3D(nn.Module):
    """Stack of 3D resnet layers with optional motion/temporal modules and a downsampler.

    Every one of the ``num_layers`` layers owns a ``ResnetBlock3D`` plus an
    optional motion module and an optional temporal module, both built by
    ``get_motion_module``; a disabled slot holds ``None``. When
    ``add_downsample`` is true, a single ``Downsample3D`` runs after the last
    layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
    ):
        """Build the per-layer sub-modules.

        Args:
            in_channels: Input channels of the first resnet; later resnets
                take ``out_channels``.
            out_channels: Output channels of every resnet, motion module,
                temporal module and the downsampler.
            temb_channels: Passed to each resnet as ``temb_channels``.
            dropout, resnet_eps, resnet_time_scale_shift, resnet_act_fn,
            resnet_groups, resnet_pre_norm, output_scale_factor,
            use_inflated_groupnorm: Forwarded unchanged to ``ResnetBlock3D``.
            num_layers: Number of layers in the stack.
            add_downsample: Whether to append a ``Downsample3D``.
            downsample_padding: ``padding`` argument of the downsampler.
            use_motion_module, motion_module_type, motion_module_kwargs:
                Enable and configure the per-layer motion module.
            use_temporal_module, temporal_module_type, temporal_module_kwargs:
                Enable and configure the per-layer temporal module.
        """
        super().__init__()

        resnet_layers = []
        motion_layers = []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if use_motion_module:
                motion_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_layers.append(None)

        # Temporal modules are created after all resnet/motion modules.
        temporal_layers = []
        for _ in range(num_layers):
            if use_temporal_module:
                temporal_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=temporal_module_type,
                        motion_module_kwargs=temporal_module_kwargs,
                    )
                )
            else:
                temporal_layers.append(None)

        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)
        self.temporal_modules = nn.ModuleList(temporal_layers)

        self.downsamplers = None
        if add_downsample:
            downsampler = Downsample3D(
                out_channels,
                use_conv=True,
                out_channels=out_channels,
                padding=downsample_padding,
                name="op",
            )
            self.downsamplers = nn.ModuleList([downsampler])

        self.gradient_checkpointing = False

    @staticmethod
    def _make_custom_forward(module):
        """Wrap ``module`` in a positional-only callable for checkpointing."""

        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, skip_mm=False):
        """Apply all layers and the optional downsampler.

        Args:
            hidden_states: Input handed to the first resnet.
            temb: Forwarded to each resnet, motion module and temporal module.
            encoder_hidden_states: Conditioning for the motion modules. A
                ``list`` is unpacked as a pair and its second item is used;
                any other value is used as is.
            skip_mm: When truthy, motion and temporal modules are bypassed.

        Returns:
            ``(hidden_states, output_states)``. ``output_states`` is a tuple
            with the result of every layer and, when downsamplers exist, the
            downsampled result appended last.
        """
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states
        else:
            motion_hidden_states = encoder_hidden_states

        use_checkpoint = self.training and self.gradient_checkpointing
        layer_outputs = []

        layers = zip(self.resnets, self.motion_modules, self.temporal_modules)
        for resnet, motion_module, temporal_module in layers:
            run_motion = (motion_module is not None) and not skip_mm
            run_temporal = (temporal_module is not None) and not skip_mm

            if use_checkpoint:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._make_custom_forward(resnet), hidden_states, temb
                )
                if run_motion:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        self._make_custom_forward(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )
                if run_temporal:
                    # The input is flagged as requiring grad before it is
                    # handed to the checkpointed temporal module.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        self._make_custom_forward(temporal_module),
                        hidden_states.requires_grad_(),
                        temb,
                        None,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                if run_motion:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=motion_hidden_states
                    )
                if run_temporal:
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None, debug=True
                    )

            layer_outputs.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            layer_outputs.append(hidden_states)

        return hidden_states, tuple(layer_outputs)
763
+
764
+
765
class CrossAttnUpBlock3D(nn.Module):
    """Up-sampling stack of 3D resnet + cross-attention layers.

    Every one of the ``num_layers`` layers owns a ``ResnetBlock3D``, a
    ``Transformer3DModel`` and optional motion/temporal modules built by
    ``get_motion_module`` (a disabled slot holds ``None``). Each layer first
    concatenates one skip tensor onto its input. When ``add_upsample`` is
    true, a single ``Upsample3D`` runs after the last layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_motion_module=None,
        use_inflated_groupnorm=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
        **transformer_kwargs,
    ):
        """Build the per-layer sub-modules.

        Args:
            in_channels: Channel count of the skip tensor consumed by the
                last layer; earlier layers expect ``out_channels``.
            out_channels: Output channels of every sub-module.
            prev_output_channel: Channel count of the tensor entering the
                first layer.
            temb_channels: Passed to each resnet as ``temb_channels``.
            num_layers: Number of layers in the stack.
            attn_num_head_channels: First positional argument of
                ``Transformer3DModel``; ``out_channels`` is integer-divided
                by it to give the second.
            dual_cross_attention: Not supported; a truthy value raises
                ``NotImplementedError`` while building the first layer.
            add_upsample: Whether to append an ``Upsample3D``.
            **transformer_kwargs: Extra keyword arguments forwarded to
                ``Transformer3DModel``.

        The remaining resnet, attention, motion and temporal options are
        forwarded unchanged to the matching constructor.

        Raises:
            NotImplementedError: If ``dual_cross_attention`` is truthy and
                ``num_layers`` is at least 1.
        """
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        resnet_layers = []
        attention_layers = []
        motion_layers = []
        last_idx = num_layers - 1
        for layer_idx in range(num_layers):
            # The last layer consumes the skip tensor with ``in_channels``.
            skip_channels = in_channels if layer_idx == last_idx else out_channels
            incoming_channels = prev_output_channel if layer_idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=incoming_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

            if dual_cross_attention:
                raise NotImplementedError

            attention_layers.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                    **transformer_kwargs,
                )
            )

            if use_motion_module:
                motion_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_layers.append(None)

        # Temporal modules are created after all other per-layer modules.
        temporal_layers = []
        for _ in range(num_layers):
            if use_temporal_module:
                temporal_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=temporal_module_type,
                        motion_module_kwargs=temporal_module_kwargs,
                    )
                )
            else:
                temporal_layers.append(None)

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)
        self.temporal_modules = nn.ModuleList(temporal_layers)

        self.upsamplers = None
        if add_upsample:
            upsampler = Upsample3D(out_channels, use_conv=True, out_channels=out_channels)
            self.upsamplers = nn.ModuleList([upsampler])

        self.gradient_checkpointing = False

    @staticmethod
    def _make_custom_forward(module, return_dict=None):
        """Wrap ``module`` in a positional-only callable for checkpointing.

        ``return_dict`` is passed as a keyword only when it is not ``None``.
        """

        def custom_forward(*inputs):
            if return_dict is None:
                return module(*inputs)
            return module(*inputs, return_dict=return_dict)

        return custom_forward

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
        skip_mm=False,
    ):
        """Apply all layers and the optional upsampler.

        Args:
            hidden_states: Input handed to the first layer.
            res_hidden_states_tuple: Skip tensors; each layer takes the last
                remaining entry and concatenates it along dim 1.
            temb: Forwarded to each resnet, motion module and temporal module.
            encoder_hidden_states: Conditioning for the attention layers. A
                ``list`` is unpacked as ``[attention_states, motion_states]``;
                any other value is used for both roles.
            upsample_size: Second positional argument of each upsampler.
            attention_mask: Accepted for interface compatibility; not read here.
            skip_mm: When truthy, motion and temporal modules are bypassed.

        Returns:
            The hidden states after the last layer and upsampler.
        """
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states
        else:
            motion_hidden_states = encoder_hidden_states

        use_checkpoint = self.training and self.gradient_checkpointing

        layers = zip(
            self.resnets, self.attentions, self.motion_modules, self.temporal_modules
        )
        for resnet, attn, motion_module, temporal_module in layers:
            # Consume the most recent skip tensor.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            run_motion = (motion_module is not None) and not skip_mm
            run_temporal = (temporal_module is not None) and not skip_mm

            if use_checkpoint:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._make_custom_forward(resnet), hidden_states, temb
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._make_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if run_motion:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        self._make_custom_forward(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )
                if run_temporal:
                    # The input is flagged as requiring grad before it is
                    # handed to the checkpointed temporal module.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        self._make_custom_forward(temporal_module),
                        hidden_states.requires_grad_(),
                        temb,
                        None,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample
                if run_motion:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=motion_hidden_states
                    )
                if run_temporal:
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None, debug=True
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
969
+
970
+
971
class UpBlock3D(nn.Module):
    """Up-sampling stack of 3D resnet layers without cross-attention.

    Every one of the ``num_layers`` layers owns a ``ResnetBlock3D`` plus
    optional motion/temporal modules built by ``get_motion_module`` (a
    disabled slot holds ``None``). Each layer first concatenates one skip
    tensor onto its input. When ``add_upsample`` is true, a single
    ``Upsample3D`` runs after the last layer.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_temporal_module=None,
        temporal_module_type=None,
        temporal_module_kwargs=None,
    ):
        """Build the per-layer sub-modules.

        Args:
            in_channels: Channel count of the skip tensor consumed by the
                last layer; earlier layers expect ``out_channels``.
            prev_output_channel: Channel count of the tensor entering the
                first layer.
            out_channels: Output channels of every sub-module.
            temb_channels: Passed to each resnet as ``temb_channels``.
            num_layers: Number of layers in the stack.
            add_upsample: Whether to append an ``Upsample3D``.

        The remaining resnet, motion and temporal options are forwarded
        unchanged to the matching constructor.
        """
        super().__init__()

        resnet_layers = []
        motion_layers = []
        last_idx = num_layers - 1
        for layer_idx in range(num_layers):
            # The last layer consumes the skip tensor with ``in_channels``.
            skip_channels = in_channels if layer_idx == last_idx else out_channels
            incoming_channels = prev_output_channel if layer_idx == 0 else out_channels

            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=incoming_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if use_motion_module:
                motion_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_layers.append(None)

        # Temporal modules are created after all resnet/motion modules.
        temporal_layers = []
        for _ in range(num_layers):
            if use_temporal_module:
                temporal_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=temporal_module_type,
                        motion_module_kwargs=temporal_module_kwargs,
                    )
                )
            else:
                temporal_layers.append(None)

        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)
        self.temporal_modules = nn.ModuleList(temporal_layers)

        self.upsamplers = None
        if add_upsample:
            upsampler = Upsample3D(out_channels, use_conv=True, out_channels=out_channels)
            self.upsamplers = nn.ModuleList([upsampler])

        self.gradient_checkpointing = False

    @staticmethod
    def _make_custom_forward(module):
        """Wrap ``module`` in a positional-only callable for checkpointing."""

        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        upsample_size=None,
        encoder_hidden_states=None,
        skip_mm=False,
    ):
        """Apply all layers and the optional upsampler.

        Args:
            hidden_states: Input handed to the first layer.
            res_hidden_states_tuple: Skip tensors; each layer takes the last
                remaining entry and concatenates it along dim 1.
            temb: Forwarded to each resnet, motion module and temporal module.
            upsample_size: Second positional argument of each upsampler.
            encoder_hidden_states: Conditioning for the motion modules. A
                ``list`` is unpacked as a pair and its second item is used;
                any other value is used as is.
            skip_mm: When truthy, motion and temporal modules are bypassed.

        Returns:
            The hidden states after the last layer and upsampler.
        """
        if isinstance(encoder_hidden_states, list):
            encoder_hidden_states, motion_hidden_states = encoder_hidden_states
        else:
            motion_hidden_states = encoder_hidden_states

        use_checkpoint = self.training and self.gradient_checkpointing

        layers = zip(self.resnets, self.motion_modules, self.temporal_modules)
        for resnet, motion_module, temporal_module in layers:
            # Consume the most recent skip tensor.
            skip_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            run_motion = (motion_module is not None) and not skip_mm
            run_temporal = (temporal_module is not None) and not skip_mm

            if use_checkpoint:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._make_custom_forward(resnet), hidden_states, temb
                )
                if run_motion:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        self._make_custom_forward(motion_module),
                        hidden_states,
                        temb,
                        motion_hidden_states,
                    )
                if run_temporal:
                    # The input is flagged as requiring grad before it is
                    # handed to the checkpointed temporal module.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        self._make_custom_forward(temporal_module),
                        hidden_states.requires_grad_(),
                        temb,
                        None,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                if run_motion:
                    hidden_states = motion_module(
                        hidden_states, temb, encoder_hidden_states=motion_hidden_states
                    )
                if run_temporal:
                    hidden_states = temporal_module(
                        hidden_states, temb, encoder_hidden_states=None, debug=True
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
pretrained_weights/sd-image-variations-diffusers/.gitattributes ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.npy filter=lfs diff=lfs merge=lfs -text
14
+ *.npz filter=lfs diff=lfs merge=lfs -text
15
+ *.onnx filter=lfs diff=lfs merge=lfs -text
16
+ *.ot filter=lfs diff=lfs merge=lfs -text
17
+ *.parquet filter=lfs diff=lfs merge=lfs -text
18
+ *.pb filter=lfs diff=lfs merge=lfs -text
19
+ *.pickle filter=lfs diff=lfs merge=lfs -text
20
+ *.pkl filter=lfs diff=lfs merge=lfs -text
21
+ *.pt filter=lfs diff=lfs merge=lfs -text
22
+ *.pth filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
26
+ *.tflite filter=lfs diff=lfs merge=lfs -text
27
+ *.tgz filter=lfs diff=lfs merge=lfs -text
28
+ *.wasm filter=lfs diff=lfs merge=lfs -text
29
+ *.xz filter=lfs diff=lfs merge=lfs -text
30
+ *.zip filter=lfs diff=lfs merge=lfs -text
31
+ *.zst filter=lfs diff=lfs merge=lfs -text
32
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
pretrained_weights/sd-image-variations-diffusers/README.md ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ thumbnail: "https://repository-images.githubusercontent.com/523487884/fdb03a69-8353-4387-b5fc-0d85f888a63f"
3
+ datasets:
4
+ - ChristophSchuhmann/improved_aesthetics_6plus
5
+ license: creativeml-openrail-m
6
+ tags:
7
+ - stable-diffusion
8
+ - stable-diffusion-diffusers
9
+ - image-to-image
10
+ ---
11
+
12
+ # Stable Diffusion Image Variations Model Card
13
+
14
+ 📣 V2 model released, and blurriness issues fixed! 📣
15
+
16
+ 🧨🎉 Image Variations is now natively supported in 🤗 Diffusers! 🎉🧨
17
+
18
+ ![](https://raw.githubusercontent.com/justinpinkney/stable-diffusion/main/assets/im-vars-thin.jpg)
19
+
20
+ ## Version 2
21
+
22
+ This version of Stable Diffusion has been fine tuned from [CompVis/stable-diffusion-v1-4-original](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original) to accept CLIP image embeddings rather than text embeddings. This allows the creation of "image variations" similar to DALLE-2 using Stable Diffusion. This version of the weights has been ported to huggingface Diffusers; using it with the Diffusers library requires the [Lambda Diffusers repo](https://github.com/LambdaLabsML/lambda-diffusers).
23
+
24
+ This model was trained in two stages and for longer than the original variations model, and gives better image quality and better CLIP-rated similarity compared to the original version.
25
+
26
+ See training details and v1 vs v2 comparison below.
27
+
28
+
29
+ ## Example
30
+
31
+ Make sure you are using a version of Diffusers >=0.8.0 (for older version see the old instructions at the bottom of this model card)
32
+
33
+ ```python
34
+ from diffusers import StableDiffusionImageVariationPipeline
35
+ from PIL import Image
+ from torchvision import transforms
36
+
37
+ device = "cuda:0"
38
+ sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
39
+ "lambdalabs/sd-image-variations-diffusers",
40
+ revision="v2.0",
41
+ )
42
+ sd_pipe = sd_pipe.to(device)
43
+
44
+ im = Image.open("path/to/image.jpg")
45
+ tform = transforms.Compose([
46
+ transforms.ToTensor(),
47
+ transforms.Resize(
48
+ (224, 224),
49
+ interpolation=transforms.InterpolationMode.BICUBIC,
50
+ antialias=False,
51
+ ),
52
+ transforms.Normalize(
53
+ [0.48145466, 0.4578275, 0.40821073],
54
+ [0.26862954, 0.26130258, 0.27577711]),
55
+ ])
56
+ inp = tform(im).to(device).unsqueeze(0)
57
+
58
+ out = sd_pipe(inp, guidance_scale=3)
59
+ out["images"][0].save("result.jpg")
60
+ ```
61
+
62
+ ### The importance of resizing correctly... (or not)
63
+
64
+ Note that due to a bit of an oversight during training, the model expects resized images without anti-aliasing. This turns out to make a big difference, and it is important to do the resizing the same way during inference. When passing a PIL image to the Diffusers pipeline, anti-aliasing will be applied during resize, so it's better to input a tensor which you have prepared manually according to the transform in the example above!
65
+
66
+ Here are examples of images generated without (top) and with (bottom) anti-aliasing during resize. (Input is [this image](https://github.com/SHI-Labs/Versatile-Diffusion/blob/master/assets/ghibli.jpg))
67
+
68
+ ![](alias-montage.jpg)
69
+
70
+ ![](default-montage.jpg)
71
+
72
+ ### V1 vs V2
73
+
74
+ Here's an example of V1 vs V2, version two was trained more carefully and for longer, see the details below. V2-top vs V1-bottom
75
+
76
+ ![](v2-montage.jpg)
77
+
78
+ ![](v1-montage.jpg)
79
+
80
+ Input images:
81
+
82
+ ![](inputs.jpg)
83
+
84
+ One important thing to note is that, due to the longer training, V2 appears to have memorised some common images from the training data, e.g. the previous example of the Girl with a Pearl Earring now almost perfectly reproduces the original rather than creating variations. You can always use v1 by specifying `revision="v1.0"`.
85
+
86
+ V2 output for Girl with a Pearl Earring as input (guidance scale=3)
87
+
88
+ ![](earring.jpg)
89
+
90
+ # Training
91
+
92
+
93
+ **Training Procedure**
94
+ This model is fine tuned from Stable Diffusion v1-3 where the text encoder has been replaced with an image encoder. The training procedure is the same as for Stable Diffusion except for the fact that images are encoded through a ViT-L/14 image-encoder including the final projection layer to the CLIP shared embedding space. The model was trained on LAION improved aesthetics 6plus.
95
+
96
+ - **Hardware:** 8 x A100-40GB GPUs (provided by [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud))
97
+ - **Optimizer:** AdamW
98
+
99
+ - **Stage 1** - Fine tune only CrossAttention layer weights from Stable Diffusion v1.4 model
100
+ - **Steps**: 46,000
101
+ - **Batch:** batch size=4, GPUs=8, Gradient Accumulations=4. Total batch size=128
102
+ - **Learning rate:** warmup to 1e-5 for 10,000 steps and then kept constant
103
+
104
+ - **Stage 2** - Resume from Stage 1 training the whole unet
105
+ - **Steps**: 50,000
106
+ - **Batch:** batch size=4, GPUs=8, Gradient Accumulations=5. Total batch size=160
107
+ - **Learning rate:** warmup to 1e-5 for 5,000 steps and then kept constant
108
+
109
+
110
+ Training was done using a [modified version of the original Stable Diffusion training code](https://github.com/justinpinkney/stable-diffusion).
111
+
112
+
113
+ # Uses
114
+ _The following section is adapted from the [Stable Diffusion model card](https://huggingface.co/CompVis/stable-diffusion-v1-4)_
115
+
116
+ ## Direct Use
117
+ The model is intended for research purposes only. Possible research areas and
118
+ tasks include
119
+
120
+ - Safe deployment of models which have the potential to generate harmful content.
121
+ - Probing and understanding the limitations and biases of generative models.
122
+ - Generation of artworks and use in design and other artistic processes.
123
+ - Applications in educational or creative tools.
124
+ - Research on generative models.
125
+
126
+ Excluded uses are described below.
127
+
128
+ ### Misuse, Malicious Use, and Out-of-Scope Use
129
+
130
+ The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
131
+
132
+ #### Out-of-Scope Use
133
+ The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
134
+
135
+ #### Misuse and Malicious Use
136
+ Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
137
+
138
+ - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
139
+ - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
140
+ - Impersonating individuals without their consent.
141
+ - Sexual content without consent of the people who might see it.
142
+ - Mis- and disinformation
143
+ - Representations of egregious violence and gore
144
+ - Sharing of copyrighted or licensed material in violation of its terms of use.
145
+ - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
146
+
147
+ ## Limitations and Bias
148
+
149
+ ### Limitations
150
+
151
+ - The model does not achieve perfect photorealism
152
+ - The model cannot render legible text
153
+ - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
154
+ - Faces and people in general may not be generated properly.
155
+ - The model was trained mainly with English captions and will not work as well in other languages.
156
+ - The autoencoding part of the model is lossy
157
+ - The model was trained on a large-scale dataset
158
+ [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
159
+ and is not fit for product use without additional safety mechanisms and
160
+ considerations.
161
+ - No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
162
+ The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
163
+
164
+ ### Bias
165
+
166
+ While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
167
+ Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
168
+ which consists of images that are primarily limited to English descriptions.
169
+ Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
170
+ This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
171
+ ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
172
+
173
+ ### Safety Module
174
+
175
+ The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers.
176
+ This checker works by checking model outputs against known hard-coded NSFW concepts.
177
+ The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter.
178
+ Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPModel` *after generation* of the images.
179
+ The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept.
180
+
181
+
182
+ ## Old instructions
183
+
184
+ If you are using a diffusers version <0.8.0 there is no `StableDiffusionImageVariationPipeline`,
185
+ in this case you need to use an older revision (`2ddbd90b14bc5892c19925b15185e561bc8e5d0a`) in conjunction with the lambda-diffusers repo:
186
+
187
+
188
+ First clone [Lambda Diffusers](https://github.com/LambdaLabsML/lambda-diffusers) and install any requirements (in a virtual environment in the example below):
189
+
190
+ ```bash
191
+ git clone https://github.com/LambdaLabsML/lambda-diffusers.git
192
+ cd lambda-diffusers
193
+ python -m venv .venv
194
+ source .venv/bin/activate
195
+ pip install -r requirements.txt
196
+ ```
197
+
198
+ Then run the following python code:
199
+
200
+ ```python
201
+ from pathlib import Path
202
+ from lambda_diffusers import StableDiffusionImageEmbedPipeline
203
+ from PIL import Image
204
+ import torch
205
+
206
+ device = "cuda" if torch.cuda.is_available() else "cpu"
207
+ pipe = StableDiffusionImageEmbedPipeline.from_pretrained(
208
+ "lambdalabs/sd-image-variations-diffusers",
209
+ revision="2ddbd90b14bc5892c19925b15185e561bc8e5d0a",
210
+ )
211
+ pipe = pipe.to(device)
212
+
213
+ im = Image.open("your/input/image/here.jpg")
214
+ num_samples = 4
215
+ image = pipe(num_samples*[im], guidance_scale=3.0)
216
+ image = image["sample"]
217
+
218
+ base_path = Path("outputs/im2im")
219
+ base_path.mkdir(exist_ok=True, parents=True)
220
+ for idx, im in enumerate(image):
221
+ im.save(base_path/f"{idx:06}.jpg")
222
+ ```
223
+
224
+
225
+
226
+ *This model card was written by: Justin Pinkney and is based on the [Stable Diffusion model card](https://huggingface.co/CompVis/stable-diffusion-v1-4).*
pretrained_weights/sd-image-variations-diffusers/alias-montage.jpg ADDED

Git LFS Details

  • SHA256: 785972e472ca53fdc631cbc5cc6e735448c513638adce9049d1963e401a05c7a
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
pretrained_weights/sd-image-variations-diffusers/default-montage.jpg ADDED

Git LFS Details

  • SHA256: bd42b0ee127f0f4df5912eca1f4d479150c9020b2c6136c19c633fa983294aa7
  • Pointer size: 131 Bytes
  • Size of remote file: 148 kB
pretrained_weights/sd-image-variations-diffusers/earring.jpg ADDED

Git LFS Details

  • SHA256: 87b8a0583e481839a98d27068370979b36a6e2bc95aa79ffeeb89cd324d47bb6
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB
pretrained_weights/sd-image-variations-diffusers/feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
pretrained_weights/sd-image-variations-diffusers/image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/home/jpinkney/.cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/ca6f97f838ae1b5bf764f31363a21f388f4d8f3e/image_encoder",
3
+ "architectures": [
4
+ "CLIPVisionModelWithProjection"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "quick_gelu",
9
+ "hidden_size": 1024,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 4096,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 24,
19
+ "patch_size": 14,
20
+ "projection_dim": 768,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.25.1"
23
+ }
pretrained_weights/sd-image-variations-diffusers/image_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89d2aa29b5fdf64f3ad4f45fb4227ea98bc45156bbae673b85be1af7783dbabb
3
+ size 1215993967
pretrained_weights/sd-image-variations-diffusers/inputs.jpg ADDED
pretrained_weights/sd-image-variations-diffusers/model_index.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "StableDiffusionImageVariationPipeline",
3
+ "_diffusers_version": "0.9.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "image_encoder": [
9
+ "transformers",
10
+ "CLIPVisionModelWithProjection"
11
+ ],
12
+ "requires_safety_checker": true,
13
+ "safety_checker": [
14
+ "stable_diffusion",
15
+ "StableDiffusionSafetyChecker"
16
+ ],
17
+ "scheduler": [
18
+ "diffusers",
19
+ "PNDMScheduler"
20
+ ],
21
+ "unet": [
22
+ "diffusers",
23
+ "UNet2DConditionModel"
24
+ ],
25
+ "vae": [
26
+ "diffusers",
27
+ "AutoencoderKL"
28
+ ]
29
+ }
pretrained_weights/sd-image-variations-diffusers/safety_checker/config.json ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": "ca6f97f838ae1b5bf764f31363a21f388f4d8f3e",
3
+ "_name_or_path": "/home/jpinkney/.cache/huggingface/diffusers/models--lambdalabs--sd-image-variations-diffusers/snapshots/ca6f97f838ae1b5bf764f31363a21f388f4d8f3e/safety_checker",
4
+ "architectures": [
5
+ "StableDiffusionSafetyChecker"
6
+ ],
7
+ "initializer_factor": 1.0,
8
+ "logit_scale_init_value": 2.6592,
9
+ "model_type": "clip",
10
+ "projection_dim": 768,
11
+ "text_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "begin_suppress_tokens": null,
18
+ "bos_token_id": 0,
19
+ "chunk_size_feed_forward": 0,
20
+ "cross_attention_hidden_size": null,
21
+ "decoder_start_token_id": null,
22
+ "diversity_penalty": 0.0,
23
+ "do_sample": false,
24
+ "dropout": 0.0,
25
+ "early_stopping": false,
26
+ "encoder_no_repeat_ngram_size": 0,
27
+ "eos_token_id": 2,
28
+ "exponential_decay_length_penalty": null,
29
+ "finetuning_task": null,
30
+ "forced_bos_token_id": null,
31
+ "forced_eos_token_id": null,
32
+ "hidden_act": "quick_gelu",
33
+ "hidden_size": 768,
34
+ "id2label": {
35
+ "0": "LABEL_0",
36
+ "1": "LABEL_1"
37
+ },
38
+ "initializer_factor": 1.0,
39
+ "initializer_range": 0.02,
40
+ "intermediate_size": 3072,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "layer_norm_eps": 1e-05,
48
+ "length_penalty": 1.0,
49
+ "max_length": 20,
50
+ "max_position_embeddings": 77,
51
+ "min_length": 0,
52
+ "model_type": "clip_text_model",
53
+ "no_repeat_ngram_size": 0,
54
+ "num_attention_heads": 12,
55
+ "num_beam_groups": 1,
56
+ "num_beams": 1,
57
+ "num_hidden_layers": 12,
58
+ "num_return_sequences": 1,
59
+ "output_attentions": false,
60
+ "output_hidden_states": false,
61
+ "output_scores": false,
62
+ "pad_token_id": 1,
63
+ "prefix": null,
64
+ "problem_type": null,
65
+ "projection_dim": 512,
66
+ "pruned_heads": {},
67
+ "remove_invalid_values": false,
68
+ "repetition_penalty": 1.0,
69
+ "return_dict": true,
70
+ "return_dict_in_generate": false,
71
+ "sep_token_id": null,
72
+ "suppress_tokens": null,
73
+ "task_specific_params": null,
74
+ "temperature": 1.0,
75
+ "tf_legacy_loss": false,
76
+ "tie_encoder_decoder": false,
77
+ "tie_word_embeddings": true,
78
+ "tokenizer_class": null,
79
+ "top_k": 50,
80
+ "top_p": 1.0,
81
+ "torch_dtype": null,
82
+ "torchscript": false,
83
+ "transformers_version": "4.25.1",
84
+ "typical_p": 1.0,
85
+ "use_bfloat16": false,
86
+ "vocab_size": 49408
87
+ },
88
+ "text_config_dict": {
89
+ "hidden_size": 768,
90
+ "intermediate_size": 3072,
91
+ "num_attention_heads": 12,
92
+ "num_hidden_layers": 12
93
+ },
94
+ "torch_dtype": "float32",
95
+ "transformers_version": null,
96
+ "vision_config": {
97
+ "_name_or_path": "",
98
+ "add_cross_attention": false,
99
+ "architectures": null,
100
+ "attention_dropout": 0.0,
101
+ "bad_words_ids": null,
102
+ "begin_suppress_tokens": null,
103
+ "bos_token_id": null,
104
+ "chunk_size_feed_forward": 0,
105
+ "cross_attention_hidden_size": null,
106
+ "decoder_start_token_id": null,
107
+ "diversity_penalty": 0.0,
108
+ "do_sample": false,
109
+ "dropout": 0.0,
110
+ "early_stopping": false,
111
+ "encoder_no_repeat_ngram_size": 0,
112
+ "eos_token_id": null,
113
+ "exponential_decay_length_penalty": null,
114
+ "finetuning_task": null,
115
+ "forced_bos_token_id": null,
116
+ "forced_eos_token_id": null,
117
+ "hidden_act": "quick_gelu",
118
+ "hidden_size": 1024,
119
+ "id2label": {
120
+ "0": "LABEL_0",
121
+ "1": "LABEL_1"
122
+ },
123
+ "image_size": 224,
124
+ "initializer_factor": 1.0,
125
+ "initializer_range": 0.02,
126
+ "intermediate_size": 4096,
127
+ "is_decoder": false,
128
+ "is_encoder_decoder": false,
129
+ "label2id": {
130
+ "LABEL_0": 0,
131
+ "LABEL_1": 1
132
+ },
133
+ "layer_norm_eps": 1e-05,
134
+ "length_penalty": 1.0,
135
+ "max_length": 20,
136
+ "min_length": 0,
137
+ "model_type": "clip_vision_model",
138
+ "no_repeat_ngram_size": 0,
139
+ "num_attention_heads": 16,
140
+ "num_beam_groups": 1,
141
+ "num_beams": 1,
142
+ "num_channels": 3,
143
+ "num_hidden_layers": 24,
144
+ "num_return_sequences": 1,
145
+ "output_attentions": false,
146
+ "output_hidden_states": false,
147
+ "output_scores": false,
148
+ "pad_token_id": null,
149
+ "patch_size": 14,
150
+ "prefix": null,
151
+ "problem_type": null,
152
+ "projection_dim": 512,
153
+ "pruned_heads": {},
154
+ "remove_invalid_values": false,
155
+ "repetition_penalty": 1.0,
156
+ "return_dict": true,
157
+ "return_dict_in_generate": false,
158
+ "sep_token_id": null,
159
+ "suppress_tokens": null,
160
+ "task_specific_params": null,
161
+ "temperature": 1.0,
162
+ "tf_legacy_loss": false,
163
+ "tie_encoder_decoder": false,
164
+ "tie_word_embeddings": true,
165
+ "tokenizer_class": null,
166
+ "top_k": 50,
167
+ "top_p": 1.0,
168
+ "torch_dtype": null,
169
+ "torchscript": false,
170
+ "transformers_version": "4.25.1",
171
+ "typical_p": 1.0,
172
+ "use_bfloat16": false
173
+ },
174
+ "vision_config_dict": {
175
+ "hidden_size": 1024,
176
+ "intermediate_size": 4096,
177
+ "num_attention_heads": 16,
178
+ "num_hidden_layers": 24,
179
+ "patch_size": 14
180
+ }
181
+ }
pretrained_weights/sd-image-variations-diffusers/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.9.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "set_alpha_to_one": false,
10
+ "skip_prk_steps": true,
11
+ "steps_offset": 1,
12
+ "trained_betas": null
13
+ }
pretrained_weights/sd-image-variations-diffusers/unet/config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.9.0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "cross_attention_dim": 768,
14
+ "down_block_types": [
15
+ "CrossAttnDownBlock2D",
16
+ "CrossAttnDownBlock2D",
17
+ "CrossAttnDownBlock2D",
18
+ "DownBlock2D"
19
+ ],
20
+ "downsample_padding": 1,
21
+ "dual_cross_attention": false,
22
+ "flip_sin_to_cos": true,
23
+ "freq_shift": 0,
24
+ "in_channels": 4,
25
+ "layers_per_block": 2,
26
+ "mid_block_scale_factor": 1,
27
+ "norm_eps": 1e-05,
28
+ "norm_num_groups": 32,
29
+ "num_class_embeds": null,
30
+ "only_cross_attention": false,
31
+ "out_channels": 4,
32
+ "sample_size": 64,
33
+ "up_block_types": [
34
+ "UpBlock2D",
35
+ "CrossAttnUpBlock2D",
36
+ "CrossAttnUpBlock2D",
37
+ "CrossAttnUpBlock2D"
38
+ ],
39
+ "use_linear_projection": false
40
+ }
pretrained_weights/sd-image-variations-diffusers/unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee23e3368e4e7c0e4ef636ed61923609c97fcaa583f8bb416e3e0986d4a0cfc6
3
+ size 3438354725
pretrained_weights/sd-image-variations-diffusers/v1-montage.jpg ADDED

Git LFS Details

  • SHA256: 607396d9a79e0649a898e46f5edfc2a37775eb02ee2d7977c9435db3f9f9db2f
  • Pointer size: 131 Bytes
  • Size of remote file: 613 kB
pretrained_weights/sd-image-variations-diffusers/v2-montage.jpg ADDED

Git LFS Details

  • SHA256: cb189ff32754768525afdc8c7f3fc99b9ab6747a8edd330297148b0332f48e71
  • Pointer size: 131 Bytes
  • Size of remote file: 570 kB
pretrained_weights/sd-image-variations-diffusers/vae/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.9.0",
4
+ "_name_or_path": "stabilityai/sd-vae-ft-mse",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "up_block_types": [
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D"
29
+ ]
30
+ }
pretrained_weights/sd-image-variations-diffusers/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b4889b6b1d4ce7ae320a02dedaeff1780ad77d415ea0d744b476155c6377ddc
3
+ size 334707217
pretrained_weights/stable-video-diffusion-img2vid-xt/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ output_tile.gif filter=lfs diff=lfs merge=lfs -text
pretrained_weights/stable-video-diffusion-img2vid-xt/LICENSE.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ STABILITY AI COMMUNITY LICENSE AGREEMENT
2
+
3
+ Last Updated: July 5, 2024
4
+
5
+ 1. INTRODUCTION
6
+
7
+ This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
8
+
9
+ This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
10
+
11
+ By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity’s behalf.
12
+
13
+ 2. RESEARCH & NON-COMMERCIAL USE LICENSE
14
+
15
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
16
+
17
+ 3. COMMERCIAL USE LICENSE
18
+
19
+ Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business’s or organization’s internal operations.
20
+ If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
21
+
22
+ 4. GENERAL TERMS
23
+
24
+ Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
25
+ a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified.
26
+ b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
27
+ c. Intellectual Property.
28
+ (i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
29
+ (ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
30
+ (iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
31
+ (iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
32
+ (v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI’s existing or prospective technology, products or services (collectively, “Feedback”). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback.
33
+ d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
34
+ e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
35
+ f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
36
+ g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
37
+
38
+ 5. DEFINITIONS
39
+
40
+ “Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
41
+
42
+ "Agreement" means this Stability AI Community License Agreement.
43
+
44
+ “AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time.
45
+
46
+ "Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Model’s output, but do not include the output of any Model.
47
+
48
+ “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
49
+
50
+ “Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability’s Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time.
51
+
52
+ "Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
53
+
54
+ "Software" means Stability AI’s proprietary software made available under this Agreement now or in the future.
55
+
56
+ “Stability AI Materials” means, collectively, Stability’s proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
57
+
58
+ “Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
pretrained_weights/stable-video-diffusion-img2vid-xt/README.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pipeline_tag: image-to-video
3
+ license: other
4
+ license_name: stable-video-diffusion-community
5
+ license_link: LICENSE.md
6
+ ---
7
+
8
+ # Stable Video Diffusion Image-to-Video Model Card
9
+
10
+ <!-- Provide a quick summary of what the model is/does. -->
11
+ ![row01](output_tile.gif)
12
+ Stable Video Diffusion (SVD) Image-to-Video is a diffusion model that takes in a still image as a conditioning frame, and generates a video from it.
13
+
14
+ Please note: For commercial use, please refer to https://stability.ai/license.
15
+
16
+ ## Model Details
17
+
18
+ ### Model Description
19
+
20
+ Stable Video Diffusion (SVD) Image-to-Video is a latent diffusion model trained to generate short video clips conditioned on an input image.
21
+ This model was trained to generate 25 frames at resolution 576x1024 given a context frame of the same size, finetuned from [SVD Image-to-Video [14 frames]](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid).
22
+ We also finetune the widely used [f8-decoder](https://huggingface.co/docs/diffusers/api/models/autoencoderkl#loading-from-the-original-format) for temporal consistency.
23
+ For convenience, we additionally provide the model with the
24
+ standard frame-wise decoder [here](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/svd_xt_image_decoder.safetensors).
25
+
26
+
27
+ - **Developed by:** Stability AI
28
+ - **Funded by:** Stability AI
29
+ - **Model type:** Generative image-to-video model
30
+ - **Finetuned from model:** SVD Image-to-Video [14 frames]
31
+
32
+ ### Model Sources
33
+
34
+ For research purposes, we recommend our `generative-models` GitHub repository (https://github.com/Stability-AI/generative-models),
35
+ which implements the most popular diffusion frameworks (both training and inference).
36
+
37
+ - **Repository:** https://github.com/Stability-AI/generative-models
38
+ - **Paper:** https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets
39
+
40
+
41
+ ## Evaluation
42
+ ![comparison](comparison.png)
43
+ The chart above evaluates user preference for SVD-Image-to-Video over [GEN-2](https://research.runwayml.com/gen2) and [PikaLabs](https://www.pika.art/).
44
+ SVD-Image-to-Video is preferred by human voters in terms of video quality. For details on the user study, we refer to the [research paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).
45
+
46
+ ## Uses
47
+
48
+ ### Direct Use
49
+
50
+ The model is intended for both non-commercial and commercial usage. You can use this model for non-commercial or research purposes under this [license](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE.md). Possible research areas and tasks include:
51
+
52
+ - Research on generative models.
53
+ - Safe deployment of models which have the potential to generate harmful content.
54
+ - Probing and understanding the limitations and biases of generative models.
55
+ - Generation of artworks and use in design and other artistic processes.
56
+ - Applications in educational or creative tools.
57
+
58
+ For commercial use, please refer to https://stability.ai/license.
59
+
60
+ Excluded uses are described below.
61
+
62
+ ### Out-of-Scope Use
63
+
64
+ The model was not trained to produce factual or true representations of people or events,
65
+ and therefore using the model to generate such content is out-of-scope for the abilities of this model.
66
+ The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
67
+
68
+ ## Limitations and Bias
69
+
70
+ ### Limitations
71
+ - The generated videos are rather short (<= 4sec), and the model does not achieve perfect photorealism.
72
+ - The model may generate videos without motion, or very slow camera pans.
73
+ - The model cannot be controlled through text.
74
+ - The model cannot render legible text.
75
+ - Faces and people in general may not be generated properly.
76
+ - The autoencoding part of the model is lossy.
77
+
78
+
79
+ ### Recommendations
80
+
81
+ The model is intended for both non-commercial and commercial usage.
82
+
83
+ ## How to Get Started with the Model
84
+
85
+ Check out https://github.com/Stability-AI/generative-models
86
+
87
+ # Appendix:
88
+
89
+ All considered potential data sources were included for final training, with none held out, as the proposed data filtering methods described in the SVD paper handle the quality control/filtering of the dataset. With regard to safety/NSFW filtering, sources considered were either deemed safe or filtered with the in-house NSFW filters.
90
+ No explicit human labor is involved in training data preparation. However, human evaluation for model outputs and quality was extensively used to evaluate model quality and performance. The evaluations were performed with third-party contractor platforms (Amazon Sagemaker, Amazon Mechanical Turk, Prolific) with fluent English-speaking contractors from various countries, primarily from the USA, UK, and Canada. Each worker was paid $12/hr for the time invested in the evaluation.
91
+ No other third party was involved in the development of this model; the model was fully developed in-house at Stability AI.
92
+ Training the SVD checkpoints required a total of approximately 200,000 A100 80GB hours. The majority of the training occurred on 48 * 8 A100s, while some stages took more/less than that. The resulting CO2 emission is ~19,000 kg CO2 eq., and the energy consumed is ~64,000 kWh.
93
+ The released checkpoints (SVD/SVD-XT) are image-to-video models that generate short videos/animations closely following the given input image. Since the model relies on an existing supplied image, the potential risks of disclosing specific material or novel unsafe content are minimal. This was also evaluated by third-party independent red-teaming services, which agree with our conclusion to a high degree of confidence (>90% in various areas of safety red-teaming). The external evaluations were also performed for trustworthiness, leading to >95% confidence in real, trustworthy videos.
94
+ With the default settings at the time of release, SVD takes ~100s for generation, and SVD-XT takes ~180s on an A100 80GB card. Several optimizations to trade off quality / memory / speed can be done to perform faster inference or inference on lower VRAM cards.
95
+ The information related to the model and its development process and usage protocols can be found in the GitHub repo, associated research paper, and HuggingFace model page/cards.
96
+ The released model inference & demo code has image-level watermarking enabled by default, which can be used to detect the outputs. This is done via the imWatermark Python library.
97
+ The model can be used to generate videos from static initial images. However, we prohibit unlawful, obscene, or misleading uses of the model consistent with the terms of our license and Acceptable Use Policy. For the open-weights release, our training data filtering mitigations alleviate this risk to some extent. These restrictions are explicitly enforced on user-facing interfaces at stablevideo.com, where a warning is issued. We do not take any responsibility for third-party interfaces. Submitting initial images that bypass input filters to tease out offensive or inappropriate content listed above is also prohibited. Safety filtering checks at stablevideo.com run on model inputs and outputs independently. More details on our user-facing interfaces can be found here: https://www.stablevideo.com/faq. Beyond the Acceptable Use Policy and other mitigations and conditions described here, the model is not subject to additional model behavior interventions of the type described in the Foundation Model Transparency Index.
98
+ For stablevideo.com, we store preference data in the form of upvotes/downvotes on user-generated videos, and we have a pairwise ranker that runs while a user generates videos. This usage data is solely used for improving Stability AI’s future image/video models and services. No other third-party entities are given access to the usage data beyond Stability AI and maintainers of stablevideo.com.
99
+ For usage statistics of SVD, we refer interested users to HuggingFace model download/usage statistics as a primary indicator. Third-party applications also have reported model usage statistics. We might also consider releasing aggregate usage statistics of stablevideo.com on reaching some milestones.
pretrained_weights/stable-video-diffusion-img2vid-xt/comparison.png ADDED

Git LFS Details

  • SHA256: 517263334c2011dd28f819b831ccc32a8dd676895429693477b936dc88600d15
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
pretrained_weights/stable-video-diffusion-img2vid-xt/feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
pretrained_weights/stable-video-diffusion-img2vid-xt/image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/image_encoder",
3
+ "architectures": [
4
+ "CLIPVisionModelWithProjection"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "gelu",
9
+ "hidden_size": 1280,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 32,
19
+ "patch_size": 14,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.34.0.dev0"
23
+ }
pretrained_weights/stable-video-diffusion-img2vid-xt/model_index.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "StableVideoDiffusionPipeline",
3
+ "_diffusers_version": "0.24.0.dev0",
4
+ "_name_or_path": "diffusers/svd-xt",
5
+ "feature_extractor": [
6
+ "transformers",
7
+ "CLIPImageProcessor"
8
+ ],
9
+ "image_encoder": [
10
+ "transformers",
11
+ "CLIPVisionModelWithProjection"
12
+ ],
13
+ "scheduler": [
14
+ "diffusers",
15
+ "EulerDiscreteScheduler"
16
+ ],
17
+ "unet": [
18
+ "diffusers",
19
+ "UNetSpatioTemporalConditionModel"
20
+ ],
21
+ "vae": [
22
+ "diffusers",
23
+ "AutoencoderKLTemporalDecoder"
24
+ ]
25
+ }
pretrained_weights/stable-video-diffusion-img2vid-xt/output_tile.gif ADDED

Git LFS Details

  • SHA256: 2340a9809e36fa9634633c7cc5fd256737c620ba47151726c85173512dc5c8ff
  • Pointer size: 133 Bytes
  • Size of remote file: 18.6 MB
pretrained_weights/stable-video-diffusion-img2vid-xt/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EulerDiscreteScheduler",
3
+ "_diffusers_version": "0.24.0.dev0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "interpolation_type": "linear",
9
+ "num_train_timesteps": 1000,
10
+ "prediction_type": "v_prediction",
11
+ "set_alpha_to_one": false,
12
+ "sigma_max": 700.0,
13
+ "sigma_min": 0.002,
14
+ "skip_prk_steps": true,
15
+ "steps_offset": 1,
16
+ "timestep_spacing": "leading",
17
+ "timestep_type": "continuous",
18
+ "trained_betas": null,
19
+ "use_karras_sigmas": true
20
+ }
pretrained_weights/stable-video-diffusion-img2vid-xt/svd_xt.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2652c23d64a1da5f14d55011b9b6dce55f2e72e395719f1cd1f8a079b00a451
3
+ size 9559625980
pretrained_weights/stable-video-diffusion-img2vid-xt/svd_xt_image_decoder.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99aa889bf6d1ca28e026755b83ba37e3072ad79b45dd4c94fae14bee7482263b
3
+ size 9503252964
pretrained_weights/stable-video-diffusion-img2vid-xt/unet/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNetSpatioTemporalConditionModel",
3
+ "_diffusers_version": "0.24.0.dev0",
4
+ "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/unet",
5
+ "addition_time_embed_dim": 256,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "cross_attention_dim": 1024,
13
+ "down_block_types": [
14
+ "CrossAttnDownBlockSpatioTemporal",
15
+ "CrossAttnDownBlockSpatioTemporal",
16
+ "CrossAttnDownBlockSpatioTemporal",
17
+ "DownBlockSpatioTemporal"
18
+ ],
19
+ "in_channels": 8,
20
+ "layers_per_block": 2,
21
+ "num_attention_heads": [
22
+ 5,
23
+ 10,
24
+ 20,
25
+ 20
26
+ ],
27
+ "num_frames": 25,
28
+ "out_channels": 4,
29
+ "projection_class_embeddings_input_dim": 768,
30
+ "sample_size": 96,
31
+ "transformer_layers_per_block": 1,
32
+ "up_block_types": [
33
+ "UpBlockSpatioTemporal",
34
+ "CrossAttnUpBlockSpatioTemporal",
35
+ "CrossAttnUpBlockSpatioTemporal",
36
+ "CrossAttnUpBlockSpatioTemporal"
37
+ ]
38
+ }
pretrained_weights/stable-video-diffusion-img2vid-xt/unet/diffusion_pytorch_model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fbc02e90f37d422f5e3a4aeaee95f6629dc8c45ca211b951626e930daf2bddf
3
+ size 3049435868
pretrained_weights/stable-video-diffusion-img2vid-xt/unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7783d82729af04f26ded4641a5952617fe331fc46add332fb9e47674fecc6ad7
3
+ size 6098682464
pretrained_weights/stable-video-diffusion-img2vid-xt/vae/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKLTemporalDecoder",
3
+ "_diffusers_version": "0.24.0.dev0",
4
+ "_name_or_path": "/home/suraj_huggingface_co/.cache/huggingface/hub/models--diffusers--svd-xt/snapshots/9703ded20c957c340781ee710b75660826deb487/vae",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "out_channels": 3,
22
+ "sample_size": 768,
23
+ "scaling_factor": 0.18215
24
+ }
pretrained_weights/stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.fp16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af602cd0eb4ad6086ec94fbf1438dfb1be5ec9ac03fd0215640854e90d6463a3
3
+ size 195531910
pretrained_weights/stable-video-diffusion-img2vid-xt/vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d92aa595a53d9da9faf594f09910ee869d5d567c8bb0362d5095673c69997d6
3
+ size 391017740
pretrained_weights/xnemo_denoising_unet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ff582dff6e19b08278378cfc244cf7203c6f70e3dcaba492ec39f9abb9be3d2
3
+ size 4927016814
pretrained_weights/xnemo_motion_encoder.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0230c49cebff21fd81c14fc61fc509ab1120b61415d40571e7dc1b9df1fc6b6f
3
+ size 246869630