File size: 9,338 Bytes
1146a67 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 | # 细粒度显存管理方案
本文档介绍如何为模型编写合理的细粒度显存管理方案,以及如何将 `DiffSynth-Studio` 中的显存管理功能用于外部的其他代码库,在阅读本文档前,请先阅读文档[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。
## 20B 模型需要多少显存?
以 Qwen-Image 的 DiT 模型为例,这一模型的参数量达到了 20B,以下代码会加载这一模型并进行推理,需要约 40G 显存,这个模型在显存较小的消费级 GPU 上显然是无法运行的。
```python
from diffsynth.core import load_model
from diffsynth.models.qwen_image_dit import QwenImageDiT
from modelscope import snapshot_download
import torch
snapshot_download(
model_id="Qwen/Qwen-Image",
local_dir="models/Qwen/Qwen-Image",
allow_file_pattern="transformer/*"
)
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
"height": 1024,
"width": 1024,
}
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
output = model(**inputs)
```
## 编写细粒度显存管理方案
为了编写细粒度的显存管理方案,我们需用 `print(model)` 观察和分析模型结构:
```
QwenImageDiT(
(pos_embed): QwenEmbedRope()
(time_text_embed): TimestepEmbeddings(
(time_proj): TemporalTimesteps()
(timestep_embedder): DiffusersCompatibleTimestepProj(
(linear_1): Linear(in_features=256, out_features=3072, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=3072, out_features=3072, bias=True)
)
)
(txt_norm): RMSNorm()
(img_in): Linear(in_features=64, out_features=3072, bias=True)
(txt_in): Linear(in_features=3584, out_features=3072, bias=True)
(transformer_blocks): ModuleList(
(0-59): 60 x QwenImageTransformerBlock(
(img_mod): Sequential(
(0): SiLU()
(1): Linear(in_features=3072, out_features=18432, bias=True)
)
(img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(attn): QwenDoubleStreamAttention(
(to_q): Linear(in_features=3072, out_features=3072, bias=True)
(to_k): Linear(in_features=3072, out_features=3072, bias=True)
(to_v): Linear(in_features=3072, out_features=3072, bias=True)
(norm_q): RMSNorm()
(norm_k): RMSNorm()
(add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
(add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
(add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
(norm_added_q): RMSNorm()
(norm_added_k): RMSNorm()
(to_out): Sequential(
(0): Linear(in_features=3072, out_features=3072, bias=True)
)
(to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
)
(img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(img_mlp): QwenFeedForward(
(net): ModuleList(
(0): ApproximateGELU(
(proj): Linear(in_features=3072, out_features=12288, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=12288, out_features=3072, bias=True)
)
)
(txt_mod): Sequential(
(0): SiLU()
(1): Linear(in_features=3072, out_features=18432, bias=True)
)
(txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
(txt_mlp): QwenFeedForward(
(net): ModuleList(
(0): ApproximateGELU(
(proj): Linear(in_features=3072, out_features=12288, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=12288, out_features=3072, bias=True)
)
)
)
)
(norm_out): AdaLayerNorm(
(linear): Linear(in_features=3072, out_features=6144, bias=True)
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
)
(proj_out): Linear(in_features=3072, out_features=64, bias=True)
)
```
在显存管理中,我们只关心包含参数的 Layer。在这个模型结构中,`QwenEmbedRope`、`TemporalTimesteps`、`SiLU` 等 Layer 都是不包含参数的,`LayerNorm` 也因为设置了 `elementwise_affine=False` 不包含参数。包含参数的 Layer 只有 `Linear` 和 `RMSNorm`。
`diffsynth.core.vram` 中提供了两个用于替换的模块用于显存管理:
* `AutoWrappedLinear`: 用于替换 `Linear` 层
* `AutoWrappedModule`: 用于替换其他任意层
编写一个 `module_map`,将模型中的 `Linear` 和 `RMSNorm` 映射到对应的模块上:
```python
module_map={
torch.nn.Linear: AutoWrappedLinear,
RMSNorm: AutoWrappedModule,
}
```
此外,还需要提供 `vram_config` 与 `vram_limit`,这两个参数在[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式)中已有介绍。
调用 `enable_vram_management` 即可启用显存管理,注意此时模型加载时的 `device` 为 `cpu`,与 `offload_device` 一致:
```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
import torch
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
"height": 1024,
"width": 1024,
}
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
enable_vram_management(
model,
module_map={
torch.nn.Linear: AutoWrappedLinear,
RMSNorm: AutoWrappedModule,
},
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
},
vram_limit=0,
)
with torch.no_grad():
output = model(**inputs)
```
以上代码只需要 2G 显存就可以运行 20B 模型的 `forward`。
## Disk Offload
[Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload) 是特殊的显存管理方案,需在模型加载过程中启用,而非模型加载完毕后。通常,在以上代码能够顺利运行的前提下,Disk Offload 可以直接启用:
```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
import torch
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
"height": 1024,
"width": 1024,
}
model = load_model(
QwenImageDiT,
model_path,
module_map={
torch.nn.Linear: AutoWrappedLinear,
RMSNorm: AutoWrappedModule,
},
vram_config={
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": "disk",
"onload_device": "disk",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
},
vram_limit=0,
)
with torch.no_grad():
output = model(**inputs)
```
Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。
如果出现非 Disk Offload 能正常运行但 Disk Offload 不能正常运行的情况,请在 GitHub 上给我们提 issue。
## 写入默认配置
为了让用户能够更方便地使用显存管理功能,我们将细粒度显存管理的配置写在 `diffsynth/configs/vram_management_module_maps.py` 中,上述模型的配置信息为:
```python
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
}
```
|