caixiaoshun committed on
Commit
020b1da
·
verified ·
1 Parent(s): 7ea70d9

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +274 -0
model.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Each of the H*W locations of a (B, C, H, W) tensor is treated as a token
    of dimension C; scaled dot-product attention runs across tokens with
    ``n_head`` heads, and the result is reshaped back to (B, C, H, W).
    """

    def __init__(self, n_head, dim):
        super().__init__()
        assert dim % n_head == 0
        self.qkv_proj = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)
        self.n_head = n_head
        self.head_dim = dim // self.n_head

    def forward(self, x: torch.Tensor):
        b, c, h, w = x.shape
        tokens = h * w
        # Flatten the spatial grid into a sequence: (B, H*W, C).
        seq = x.reshape(b, c, tokens).transpose(-1, -2)
        q, k, v = torch.chunk(self.qkv_proj(seq), chunks=3, dim=-1)

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.reshape(b, tokens, self.n_head, self.head_dim).transpose(1, 2)

        attn = F.scaled_dot_product_attention(
            split_heads(q), split_heads(k), split_heads(v)
        )
        # Merge heads: (B, n_head, T, head_dim) -> (B, T, C).
        merged = attn.transpose(1, 2).reshape(b, tokens, c)
        projected = self.out_proj(merged)
        # Restore the spatial layout: (B, T, C) -> (B, C, H, W).
        return projected.transpose(-1, -2).reshape(b, c, h, w)
34
+
35
+
36
class TimePositionEmbedding(nn.Module):
    """Fixed sinusoidal position table for diffusion timesteps.

    Precomputes a (seq_len, dim) table where even columns hold
    sin(t * base^(-2i/dim)) and odd columns hold the matching cosine,
    then serves rows indexed by integer timestep.
    """

    def __init__(self, seq_len=1000, dim=320):
        super().__init__()
        base = 10000
        # Inverse frequency of each sin/cos pair: base^(-2i/dim).
        inv_freq = base ** (-torch.arange(0, dim, step=2).float() / dim)
        angles = torch.arange(seq_len).unsqueeze(1) * inv_freq.unsqueeze(0)
        pe = torch.zeros(seq_len, dim)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)
        # Buffer (not a parameter): moves with .to(device) but is never trained.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, time):
        # Accept any integer-tensor shape; look up one row per timestep.
        return self.pe[time.reshape(-1)]
52
+
53
+
54
class TimeEmbedding(nn.Module):
    """Expands a timestep embedding to 4x width with a two-layer SiLU MLP."""

    def __init__(self, dim):
        super().__init__()
        hidden = dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, x):
        return self.mlp(x)
63
+
64
+
65
class ResidualBlock(nn.Module):
    """GroupNorm/conv residual block with additive timestep conditioning.

    The time embedding is projected to ``out_channel`` and added as a
    per-channel bias between the two convolutions. A 1x1 convolution adapts
    the skip connection when the channel count changes.
    """

    def __init__(self, in_channel, out_channel, time_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channel)
        self.norm2 = nn.GroupNorm(32, out_channel)
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
        if in_channel == out_channel:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)

    def forward(self, x, time):
        skip = x
        h = F.silu(self.conv1(self.norm1(x)))
        # Broadcast the projected time embedding over the spatial dims.
        h = h + self.time_proj(time)[:, :, None, None]
        h = F.silu(self.conv2(self.norm2(h)))
        return h + self.residual_conv(skip)
85
+
86
+
87
class DownSampler(nn.Module):
    """Halves the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, in_channel):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel, in_channel, kernel_size=3, stride=2, padding=1
        )

    def forward(self, x):
        return self.conv(x)
96
+
97
+
98
class UpSampler(nn.Module):
    """Doubles the spatial resolution (nearest-neighbor) then refines with a 3x3 conv."""

    def __init__(self, in_channel):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channel, in_channel, kernel_size=3, stride=1, padding=1
        )
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x):
        return self.conv(self.up(x))
109
+
110
+
111
class SwitchSequential(nn.Sequential):
    """Sequential container that routes the time embedding only to the
    modules that accept it (ResidualBlock); every other module receives
    the feature tensor alone."""

    def forward(self, x, time):
        for layer in self:
            x = layer(x, time) if isinstance(layer, ResidualBlock) else layer(x)
        return x
119
+
120
+
121
class Unet(nn.Module):
    """Diffusion U-Net for 3-channel images.

    Encoder/decoder with skip connections: encoder features are saved
    *before* each downsampling step and concatenated back in after the
    matching upsampling step. Timestep conditioning flows into every
    ResidualBlock via a sinusoidal embedding expanded by a small MLP.

    Args:
        time_dim: width of the sinusoidal timestep embedding; the MLP
            expands it to ``time_dim * 4`` before it reaches the blocks.
        n_head: number of heads in the Attention layers.

    Input ``x`` is (B, 3, H, W) with H and W divisible by 8 (three
    downsamplings); ``time`` holds integer timesteps in [0, 1000), the
    default table length of TimePositionEmbedding.
    """

    def __init__(self, time_dim=320, n_head=8):
        super().__init__()
        # Timestep embedding: sinusoidal table followed by a 4x-wide MLP.
        # Bug fix: both submodules previously hard-coded dim=320 and
        # silently ignored the ``time_dim`` argument; they now follow it.
        self.time_position_embedding = TimePositionEmbedding(dim=time_dim)
        self.time_proj = TimeEmbedding(dim=time_dim)
        time_dim = time_dim * 4  # width seen by every ResidualBlock

        # ---------------- Encoder: save pre-downsample features as skips ----------------
        self.down_blocks = nn.ModuleList(
            [
                # 128 channels at resolution H
                SwitchSequential(
                    nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1),
                    ResidualBlock(64, 128, time_dim=time_dim),
                    ResidualBlock(128, 128, time_dim=time_dim),
                ),
                # 256 channels at H/2
                SwitchSequential(
                    ResidualBlock(128, 256, time_dim=time_dim),
                    ResidualBlock(256, 256, time_dim=time_dim),
                ),
                # 512 channels at H/4
                SwitchSequential(
                    ResidualBlock(256, 512, time_dim=time_dim),
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                # bottom: 512 channels at H/8 (no further downsampling)
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
            ]
        )
        self.down_samplers = nn.ModuleList(
            [
                DownSampler(128),  # H -> H/2
                DownSampler(256),  # H/2 -> H/4
                DownSampler(512),  # H/4 -> H/8
            ]
        )

        # ---------------- Bottleneck: three identical attention/residual stacks ----------------
        self.mid_blocks = nn.ModuleList(
            [
                SwitchSequential(
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                )
                for _ in range(3)
            ]
        )

        # ---------------- Decoder: upsample first, then concat the matching skip ----------------
        # up_blocks[0]: extra processing at the bottom (no concat)
        # up_blocks[1]: at H/4, concat skip@H/4 (512 ch), keeps 512
        # up_blocks[2]: at H/2, concat skip@H/2 (256 ch), outputs 256
        # up_blocks[3]: at H,   concat skip@H   (128 ch), outputs 64
        self.up_blocks = nn.ModuleList(
            [
                SwitchSequential(  # H/8, 512 -> 512 (no concat)
                    ResidualBlock(512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(  # H/4, (512 + 512) -> 512
                    ResidualBlock(512 + 512, 512, time_dim=time_dim),
                    Attention(n_head, 512),
                    ResidualBlock(512, 512, time_dim=time_dim),
                ),
                SwitchSequential(  # H/2, (512 + 256) -> 256
                    ResidualBlock(512 + 256, 256, time_dim=time_dim),
                    ResidualBlock(256, 256, time_dim=time_dim),
                    Attention(n_head, 256),
                    ResidualBlock(256, 256, time_dim=time_dim),
                ),
                SwitchSequential(  # H, (256 + 128) -> 64
                    ResidualBlock(256 + 128, 64, time_dim=time_dim),
                    ResidualBlock(64, 64, time_dim=time_dim),
                ),
            ]
        )
        # Upsamplers matched to each decoder stage's input channel count:
        # 512@H/8 -> 512@H/4, then 512@H/2, then 256@H.
        self.up_samplers = nn.ModuleList(
            [
                UpSampler(512),  # H/8 -> H/4
                UpSampler(512),  # H/4 -> H/2
                UpSampler(256),  # H/2 -> H
            ]
        )

        self.head = nn.Conv2d(64, 3, kernel_size=3, padding=1, stride=1)

    def forward(self, x, time):
        # Timestep embedding shared by all residual blocks.
        t = self.time_proj(self.time_position_embedding(time))

        # -------- Encoder: keep each pre-downsample feature as a skip --------
        # (The bottom block's output was previously stored too but never
        # used by the decoder, so it is no longer kept.)
        skips = []
        for i, block in enumerate(self.down_blocks):
            x = block(x, t)
            if i < len(self.down_samplers):
                skips.append(x)
                x = self.down_samplers[i](x)

        # -------- Bottleneck --------
        for block in self.mid_blocks:
            x = block(x, t)

        # -------- Decoder --------
        # Bottom pass first (no concat), still at H/8 with 512 channels.
        x = self.up_blocks[0](x, t)

        # Stage 1: H/8 -> H/4, concat skip@H/4 (skips[2]).
        x = self.up_samplers[0](x)
        x = torch.cat([x, skips[2]], dim=1)  # (512 + 512)@H/4
        x = self.up_blocks[1](x, t)

        # Stage 2: H/4 -> H/2, concat skip@H/2 (skips[1]).
        x = self.up_samplers[1](x)
        x = torch.cat([x, skips[1]], dim=1)  # (512 + 256)@H/2
        x = self.up_blocks[2](x, t)

        # Stage 3: H/2 -> H, concat skip@H (skips[0]).
        x = self.up_samplers[2](x)
        x = torch.cat([x, skips[0]], dim=1)  # (256 + 128)@H
        x = self.up_blocks[3](x, t)

        # Project back to 3 image channels.
        return self.head(x)
266
+
267
+
268
if __name__ == "__main__":
    # Smoke test: one forward pass on a random batch.
    net = Unet()
    images = torch.randn(2, 3, 64, 64)
    timesteps = torch.randint(0, 1000, (2,))
    prediction = net(images, timesteps)
    print(prediction.shape)
    # torch.save({"model": net.state_dict()}, "unet.pt")