# Fine-Grained VRAM Management

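This document explains how to write a sound fine-grained VRAM management scheme for a model, and how to use the VRAM management features of `DiffSynth-Studio` in other, external codebases. Before reading it, please read [VRAM Management](/docs/zh/Pipeline_Usage/VRAM_management.md).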

## How Much VRAM Does a 20B Model Need?

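Take the DiT model of Qwen-Image as an example. It has 20B parameters. The following code loads the model and runs inference, and it needs about 40 GB of VRAM. This model clearly cannot run on consumer GPUs with less VRAM than that.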

```python
from diffsynth.core import load_model
from diffsynth.models.qwen_image_dit import QwenImageDiT
from modelscope import snapshot_download
import torch

# Download only the DiT weights (the transformer/ folder) of Qwen-Image.
snapshot_download(
    model_id="Qwen/Qwen-Image",
    local_dir="models/Qwen/Qwen-Image",
    allow_file_pattern="transformer/*"
)
# The DiT weights are split into 9 safetensors shards.
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
# Random placeholder inputs, used only to run one forward pass.
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}

# Load the whole model directly onto the GPU.
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    output = model(**inputs)
```
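The ~40 GB figure follows from the parameter count. In `torch.bfloat16` each parameter takes 2 bytes, so the weights alone need roughly 20B × 2 bytes. The sketch below is only back-of-the-envelope arithmetic and ignores activations and CUDA runtime overhead. The helper names `weight_memory_gb` and `weight_memory_gib` are illustrative and are not part of DiffSynth-Studio.

```python
# Bytes per element for a few common parameter dtypes.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Memory for the weights alone, in decimal gigabytes (1 GB = 10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 10**9


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """The same quantity in binary gibibytes (1 GiB = 2**30 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


num_params = 20 * 10**9  # "20B" parameters, rounded
print(f"bfloat16: {weight_memory_gb(num_params, 'bfloat16'):.1f} GB "
      f"({weight_memory_gib(num_params, 'bfloat16'):.1f} GiB)")
print(f"float32:  {weight_memory_gb(num_params, 'float32'):.1f} GB")
```

Loading the same weights in `float32` would double the requirement. This is why the examples in this document keep every stage in `bfloat16`.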

## Writing a Fine-Grained VRAM Management Scheme

To write a fine-grained VRAM management scheme, we need to inspect and analyze the model structure with `print(model)`:

```
QwenImageDiT(
  (pos_embed): QwenEmbedRope()
  (time_text_embed): TimestepEmbeddings(
    (time_proj): TemporalTimesteps()
    (timestep_embedder): DiffusersCompatibleTimestepProj(
      (linear_1): Linear(in_features=256, out_features=3072, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)
    )
  )
  (txt_norm): RMSNorm()
  (img_in): Linear(in_features=64, out_features=3072, bias=True)
  (txt_in): Linear(in_features=3584, out_features=3072, bias=True)
  (transformer_blocks): ModuleList(
    (0-59): 60 x QwenImageTransformerBlock(
      (img_mod): Sequential(
        (0): SiLU()
        (1): Linear(in_features=3072, out_features=18432, bias=True)
      )
      (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (attn): QwenDoubleStreamAttention(
        (to_q): Linear(in_features=3072, out_features=3072, bias=True)
        (to_k): Linear(in_features=3072, out_features=3072, bias=True)
        (to_v): Linear(in_features=3072, out_features=3072, bias=True)
        (norm_q): RMSNorm()
        (norm_k): RMSNorm()
        (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (norm_added_q): RMSNorm()
        (norm_added_k): RMSNorm()
        (to_out): Sequential(
          (0): Linear(in_features=3072, out_features=3072, bias=True)
        )
        (to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
      )
      (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (img_mlp): QwenFeedForward(
        (net): ModuleList(
          (0): ApproximateGELU(
            (proj): Linear(in_features=3072, out_features=12288, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=12288, out_features=3072, bias=True)
        )
      )
      (txt_mod): Sequential(
        (0): SiLU()
        (1): Linear(in_features=3072, out_features=18432, bias=True)
      )
      (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (txt_mlp): QwenFeedForward(
        (net): ModuleList(
          (0): ApproximateGELU(
            (proj): Linear(in_features=3072, out_features=12288, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=12288, out_features=3072, bias=True)
        )
      )
    )
  )
  (norm_out): AdaLayerNorm(
    (linear): Linear(in_features=3072, out_features=6144, bias=True)
    (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
  )
  (proj_out): Linear(in_features=3072, out_features=64, bias=True)
)
```

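For VRAM management, we only care about layers that contain parameters. In this model, layers such as `QwenEmbedRope`, `TemporalTimesteps`, and `SiLU` have no parameters. `LayerNorm` has none either, because it is created with `elementwise_affine=False`. The only layers with parameters are `Linear` and `RMSNorm`. `diffsynth.core.vram` provides two replacement modules for VRAM management:

* `AutoWrappedLinear`: replaces `Linear`
* `AutoWrappedModule`: replaces any other layer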
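The printed structure is enough to check where the parameters live. The sketch below is plain arithmetic with illustrative helper names. It adds up the weights and biases of every `Linear` layer listed above. It leaves out the `RMSNorm` weights, whose shapes `print(model)` does not show, so the result slightly underestimates the true count.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of torch.nn.Linear: a weight matrix plus an optional bias."""
    return in_features * out_features + (out_features if bias else 0)


# (in_features, out_features) of every Linear in one QwenImageTransformerBlock,
# copied from the print(model) output above (all of them have bias=True).
BLOCK_LINEARS = [
    (3072, 18432),                              # img_mod.1
    (3072, 3072), (3072, 3072), (3072, 3072),   # attn.to_q, to_k, to_v
    (3072, 3072), (3072, 3072), (3072, 3072),   # attn.add_q_proj, add_k_proj, add_v_proj
    (3072, 3072), (3072, 3072),                 # attn.to_out.0, to_add_out
    (3072, 12288), (12288, 3072),               # img_mlp.net.0.proj, img_mlp.net.2
    (3072, 18432),                              # txt_mod.1
    (3072, 12288), (12288, 3072),               # txt_mlp.net.0.proj, txt_mlp.net.2
]

# Linear layers outside transformer_blocks.
OUTER_LINEARS = [
    (256, 3072), (3072, 3072),  # time_text_embed.timestep_embedder.linear_1, linear_2
    (64, 3072),                 # img_in
    (3584, 3072),               # txt_in
    (3072, 6144),               # norm_out.linear
    (3072, 64),                 # proj_out
]

NUM_BLOCKS = 60  # "(0-59): 60 x QwenImageTransformerBlock"

per_block = sum(linear_params(i, o) for i, o in BLOCK_LINEARS)
total = NUM_BLOCKS * per_block + sum(linear_params(i, o) for i, o in OUTER_LINEARS)

# Largest single weight matrix; bfloat16 uses 2 bytes per element.
largest = max(i * o for i, o in BLOCK_LINEARS + OUTER_LINEARS)

print(f"per block: {per_block:,} parameters")
print(f"total Linear parameters: {total / 1e9:.2f}B")
print(f"largest Linear weight in bfloat16: {largest * 2 / 1e6:.0f} MB")
```

The `Linear` layers alone account for about 20.4B parameters, so they hold essentially the whole model. The largest single layer (`3072 -> 18432`) takes only about 113 MB in `bfloat16`. This suggests why bringing weights onto the GPU one layer at a time can cut VRAM usage so sharply.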

Write a `module_map` that maps the model's `Linear` and `RMSNorm` layers to the corresponding modules:

```python
module_map={
    torch.nn.Linear: AutoWrappedLinear,
    RMSNorm: AutoWrappedModule,
}
```
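Conceptually, a `module_map` is a type-based dispatch table. Every submodule whose type appears as a key is swapped for an instance of the wrapper class it maps to, and all other submodules are searched recursively. The plain-Python sketch below illustrates that idea on a toy module tree. All `Toy*` classes and `replace_by_type` are stand-ins invented for this example. This is not how `enable_vram_management` is actually implemented.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyModule:
    """Stand-in for torch.nn.Module: holds named child modules."""
    children: dict = field(default_factory=dict)


class ToyLinear(ToyModule):
    """Has parameters."""


class ToyRMSNorm(ToyModule):
    """Has parameters."""


class ToySiLU(ToyModule):
    """No parameters, so it is left out of module_map."""


@dataclass
class ToyWrapper(ToyModule):
    """Stand-in wrapper that keeps a reference to the layer it replaces."""
    inner: Optional[ToyModule] = None


class ToyWrappedLinear(ToyWrapper):
    """Plays the role of AutoWrappedLinear."""


class ToyWrappedModule(ToyWrapper):
    """Plays the role of AutoWrappedModule."""


def replace_by_type(module: ToyModule, module_map: dict) -> int:
    """Recursively replace children whose exact type is a key of module_map.

    Returns how many layers were wrapped. Wrapped layers are not descended into.
    """
    replaced = 0
    for name, child in list(module.children.items()):
        wrapper_cls = module_map.get(type(child))
        if wrapper_cls is not None:
            module.children[name] = wrapper_cls(inner=child)
            replaced += 1
        else:
            replaced += replace_by_type(child, module_map)
    return replaced


# A tiny tree shaped like part of a QwenImageTransformerBlock.
block = ToyModule(children={
    "img_mod": ToyModule(children={"0": ToySiLU(), "1": ToyLinear()}),
    "attn": ToyModule(children={"to_q": ToyLinear(), "norm_q": ToyRMSNorm()}),
})
count = replace_by_type(block, {ToyLinear: ToyWrappedLinear, ToyRMSNorm: ToyWrappedModule})
print(count, type(block.children["attn"].children["norm_q"]).__name__)
```

Only the three layers with parameters are wrapped. The parameter-free `ToySiLU` is left untouched, which mirrors why `SiLU` and `LayerNorm(elementwise_affine=False)` do not appear in the `module_map`.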

In addition, you need to provide `vram_config` and `vram_limit`. Both parameters are introduced in [VRAM Management](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式).
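Both `vram_config` dictionaries in this document follow the same pattern. There are four stages (`offload`, `onload`, `preparing`, `computation`), and each stage has a `*_dtype` key and a `*_device` key. The hypothetical helper below only makes that pattern explicit and checks that no stage is missing or misspelled. It is not part of DiffSynth-Studio, and it uses strings in place of `torch.bfloat16` so that it runs without PyTorch. For what each stage means, refer to the VRAM Management document linked above.

```python
STAGES = ("offload", "onload", "preparing", "computation")


def make_vram_config(**stages: tuple) -> dict:
    """Build a flat vram_config dict from one (dtype, device) pair per stage."""
    missing = [s for s in STAGES if s not in stages]
    unknown = [s for s in stages if s not in STAGES]
    if missing or unknown:
        raise ValueError(f"missing stages: {missing}, unknown stages: {unknown}")
    config = {}
    for stage in STAGES:
        dtype, device = stages[stage]
        config[f"{stage}_dtype"] = dtype
        config[f"{stage}_device"] = device
    return config


# Same values as the enable_vram_management example that follows:
# offload/onload on the CPU, preparing/computation on the GPU.
cpu_offload = make_vram_config(
    offload=("bfloat16", "cpu"),
    onload=("bfloat16", "cpu"),
    preparing=("bfloat16", "cuda"),
    computation=("bfloat16", "cuda"),
)
print(cpu_offload)
```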

Call `enable_vram_management` to enable VRAM management. Note that the model is now loaded with `device` set to `cpu`, which matches `offload_device`:

```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
import torch

# Assumes the weights were already downloaded by the previous example.
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}

# Load onto the CPU first: device="cpu" matches offload_device in vram_config.
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
enable_vram_management(
    model,
    # Wrap every layer type that holds parameters.
    module_map={
        torch.nn.Linear: AutoWrappedLinear,
        RMSNorm: AutoWrappedModule,
    },
    vram_config={
        "offload_dtype": torch.bfloat16,
        "offload_device": "cpu",
        "onload_dtype": torch.bfloat16,
        "onload_device": "cpu",
        "preparing_dtype": torch.bfloat16,
        "preparing_device": "cuda",
        "computation_dtype": torch.bfloat16,
        "computation_device": "cuda",
    },
    vram_limit=0,
)
with torch.no_grad():
    output = model(**inputs)
```

The code above can run the `forward` pass of the 20B model with only about 2 GB of VRAM.

## Disk Offload

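[Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme. It must be enabled while the model is being loaded, not after loading has finished. Usually, if the code above runs without problems, Disk Offload can be enabled directly: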

```python
from diffsynth.core import load_model, AutoWrappedLinear, AutoWrappedModule
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
import torch

prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}

# Disk Offload is configured while loading: module_map, vram_config and vram_limit
# are passed to load_model instead of to enable_vram_management.
model = load_model(
    QwenImageDiT,
    model_path,
    module_map={
        torch.nn.Linear: AutoWrappedLinear,
        RMSNorm: AutoWrappedModule,
    },
    vram_config={
        "offload_dtype": "disk",
        "offload_device": "disk",
        "onload_dtype": "disk",
        "onload_device": "disk",
        "preparing_dtype": torch.bfloat16,
        "preparing_device": "cuda",
        "computation_dtype": torch.bfloat16,
        "computation_device": "cuda",
    },
    vram_limit=0,
)
with torch.no_grad():
    output = model(**inputs)
```

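Disk Offload is a highly specialized VRAM management scheme. It supports only `.safetensors` files and does not support binary formats such as `.bin`, `.pth`, or `.ckpt`. It also does not support [state dict converters](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换) that reshape tensors.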
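One plausible reason for the `.safetensors` requirement is that the format lets a reader locate an individual tensor without deserializing the whole file. This is our reading and is not stated by DiffSynth-Studio. A safetensors file begins with an 8-byte little-endian unsigned integer giving the length of a JSON header. That header records each tensor's `dtype`, `shape`, and `data_offsets`, measured from the end of the header. Pickle-based checkpoints, by contrast, generally have to be unpickled to find out which tensors they contain. The sketch below writes a tiny file in this layout by hand and reads back a single tensor's raw bytes by seeking. It uses only the standard library, and `write_safetensors` and `read_tensor_bytes` are illustrative helpers. It is a simplified writer: it adds no header padding and no `__metadata__` entry.

```python
import json
import os
import struct
import tempfile


def write_safetensors(path, tensors):
    """Write raw buffers in a simplified safetensors layout.

    tensors maps name -> (dtype string, shape, raw little-endian bytes).
    """
    header, offset, chunks = {}, 0, []
    for name, (dtype, shape, data) in tensors.items():
        header[name] = {"dtype": dtype, "shape": list(shape),
                        "data_offsets": [offset, offset + len(data)]}
        chunks.append(data)
        offset += len(data)
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte little-endian header length
        f.write(header_bytes)
        f.write(b"".join(chunks))


def read_tensor_bytes(path, name):
    """Read one tensor by seeking to its offsets, without reading the other tensors."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        info = header[name]
        begin, end = info["data_offsets"]
        f.seek(8 + header_len + begin)
        return info["dtype"], info["shape"], f.read(end - begin)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.safetensors")
    write_safetensors(path, {
        "a.weight": ("F32", (2,), struct.pack("<2f", 1.0, 2.0)),
        "b.weight": ("F32", (3,), struct.pack("<3f", 3.0, 4.0, 5.0)),
    })
    dtype, shape, raw = read_tensor_bytes(path, "b.weight")
    print(dtype, shape, struct.unpack("<3f", raw))
```

Reading `b.weight` touches only the header and that tensor's byte range. A converter that reshapes tensors would instead need the full tensor in memory before it could be used, which may be why such converters are excluded.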

If a model runs correctly without Disk Offload but fails with Disk Offload enabled, please open an issue on GitHub.

## Writing the Default Configuration

To make VRAM management easier to use, we keep the fine-grained VRAM management configurations in `diffsynth/configs/vram_management_module_maps.py`. The configuration for the model above is:

```python
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
    "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
}
```
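The default configuration stores classes as dotted import paths, so it can be written entirely as plain strings. The sketch below shows one way such a string map could be turned back into a `module_map` of class objects, like the one passed to `enable_vram_management` earlier, using `importlib`. `resolve_class` and `resolve_module_map` are hypothetical helpers, not DiffSynth-Studio functions. The demo maps standard-library classes so that it runs without DiffSynth-Studio installed; the pairings themselves are meaningless and only exercise the lookup.

```python
import importlib


def resolve_class(dotted_path: str):
    """Import 'package.module.ClassName' and return the class object."""
    module_name, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def resolve_module_map(string_map: dict) -> dict:
    """Turn {'layer.path': 'wrapper.path'} into {LayerClass: WrapperClass}."""
    return {resolve_class(k): resolve_class(v) for k, v in string_map.items()}


# Stand-in entries using standard-library classes. With DiffSynth-Studio installed,
# the strings from vram_management_module_maps.py could be resolved the same way.
demo = resolve_module_map({
    "collections.OrderedDict": "collections.Counter",
    "json.decoder.JSONDecoder": "json.encoder.JSONEncoder",
})
print(demo)
```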