johnson115 committed on
Commit
ab9f2cc
·
verified ·
1 Parent(s): 898cced

Upload 2 files

Browse files
Files changed (2) hide show
  1. down_unet.py +401 -0
  2. up_unet.py +239 -0
down_unet.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import sympy as sp
6
+ import wandb
7
+ from PIL import Image
8
+ from datasets import load_dataset
9
+ from torchvision import transforms
10
+
11
+
12
+
13
# Select the GPU when available; every module and tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Initialize the Weights & Biases run for this training script.
wandb.init(
    # set the wandb project where this run will be logged
    project="unet-try",
)
21
+
22
+ '''
23
+
24
+ conv_block = resnetblock--attentionblock--convblock. input:[B,C,H,W],output:[B,channel_dim,H(+/-)2,W(+/-)2]
25
+
26
+ down block = 2blocks|-->for_skip_connection
27
+ |
28
+ down_sample-->result_after_pool. input:[B,C,H,W],output:[B,channel_dim,(H-4)//2,(W-4)//2]
29
+
30
+ up block = -->concat-->2blocks input:[B,C,H,W],input_skip:[B,C/2,2H,2W],output:[B,C/2,2H+4,2W+4]
31
+ |
32
+ --up_sample
33
+
34
+ LR-----------------------------MSE LOSS--------------------------LR
35
+ |--down block -------------skip connection-----------up block--|
36
+ |--down block up block--|
37
+ |---------------|
38
+ '''
39
+
40
+ # ----------------------------------------------------------------------------------------------------
41
class conv_block(nn.Module):
    """ResNet sub-block -> self-attention sub-block -> channel-changing conv.

    [B, C, H, W] -> [B, channel_dim, H-2, W-2] for use="down" (padding=0),
    or  [B, channel_dim, H+2, W+2] for use="up" (padding=2).
    in_channel must be divisible by 4 (GroupNorm) and by num_heads (attention).
    """

    def __init__(self, in_channel, num_heads, channel_dim, use="down"):
        super(conv_block, self).__init__()

        self.in_channel = in_channel
        self.num_heads = num_heads
        self.channel_dim = channel_dim
        self.use = use

        # GroupNorm over the *input* channel count, split into 4 groups.
        self.GN = nn.GroupNorm(num_groups=4, num_channels=in_channel)
        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.silu = nn.SiLU()
        # FIX: attention_block feeds tensors shaped [B, L, C]; without
        # batch_first=True, nn.MultiheadAttention treats dim 0 as the sequence
        # axis and mixes samples across the batch.
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel,
                                               num_heads=self.num_heads,
                                               batch_first=True)

        if self.use == "down":
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim,
                                   kernel_size=3, stride=1, padding=0, bias=False)
        elif self.use == "up":
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim,
                                   kernel_size=3, stride=1, padding=2, bias=False)

    def resnet_block(self, X):
        """Residual refinement at constant width.

        NOTE(review): the same GN/conv modules are applied twice, so both
        halves share weights — confirm this is intended.
        """
        out = self.GN(X)
        out = self.conv(out)
        out = self.silu(out)

        out = self.GN(out)
        out = self.conv(out)
        out = self.silu(out)

        return out + X

    def attention_block(self, X):
        """Self-attention over the H*W spatial positions, with residual."""
        B, C, H, W = X.size()

        out = self.GN(X)
        out = self.conv(out)

        # [B, C, H, W] -> [B, L, C] where L = H * W.
        out = out.view(B, self.in_channel, H * W).transpose(1, 2)
        out, weights = self.attention(out, out, out)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)

        out = self.conv(out)

        return out + X

    def forward(self, X):
        out = self.resnet_block(X)
        out = self.attention_block(out)
        out = self.conv1(out)

        return out
101
+
102
+
103
+ '''
104
+ model = conv_block(in_channel=4,num_heads=4,channel_dim=64,use="down")
105
+ in_put = torch.randn(1,4,256,256) #注意,在SR3代码中隐藏层是不变的和输入一致
106
+ ouput = model(in_put)
107
+ print(ouput.shape)
108
+ '''
109
+ # -------------------------------------------------------------------------------------------------
110
class SpatialAttention(nn.Module):
    """Spatial attention: 1x1 conv -> per-pixel scores -> reweight the input.

    FIX: the original applied softmax over dim=1, but the attention map has a
    single channel, so softmax over a size-1 dim is identically 1 and the
    module was a no-op (returned x unchanged). Normalize over the spatial
    positions instead.
    """

    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        # Single-channel attention map: [B, 1, H, W].
        attention_map = self.conv(x)
        b, _, h, w = attention_map.shape
        # Softmax over the H*W positions so the scores form a distribution.
        attention_scores = torch.softmax(attention_map.view(b, 1, h * w), dim=-1).view(b, 1, h, w)
        # Broadcast-reweight every channel by the spatial scores.
        return x * attention_scores
123
+
124
class ChannelAttention(nn.Module):
    """SE-style channel attention: global average pool -> bottleneck MLP ->
    per-channel reweighting of the input.

    FIX: the bottleneck width is clamped to max(1, in_channels // r) so small
    channel counts no longer create a zero-width Linear layer (which silently
    zeroed the whole output). Existing calls with in_channels >= 16 behave
    exactly as before.

    NOTE(review): the final activation is ReLU rather than the conventional
    sigmoid — kept as-is to preserve behavior; confirm this is intended.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        hidden = max(1, in_channels // reduction_ratio)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, in_channels, bias=False),
            nn.ReLU()
        )

    def forward(self, x):
        # Global average pool -> channel descriptor [B, C].
        avg_out = self.avg_pool(x).view(x.size(0), -1)
        # Bottleneck MLP produces per-channel attention weights.
        attn = self.fc(avg_out)
        # Reshape to [B, C, 1, 1] so it broadcasts over H and W.
        attn = attn.view(x.size(0), -1, 1, 1)
        return x * attn
143
+
144
+
145
def calculate_attention(X, num_heads, use):
    """Apply channel attention (use="down") or spatial attention (use="up") to X.

    Args:
        X: input tensor [B, C, H, W].
        num_heads: head count for the multi-head attention ("up" path only);
            C must be divisible by it.
        use: "down" or "up"; any other value raises ValueError (previously it
            fell through to a NameError on `out`).

    NOTE(review): fresh attention modules with random weights are created on
    every call, so nothing here is trainable or deterministic across calls —
    confirm this is intended.
    """
    X = X.to(device)
    B, C, H, W = X.size()

    if use == "down":
        # Per-channel reweighting of the whole feature map.
        channel_attention = ChannelAttention(C).to(device)
        out = channel_attention(X)
    elif use == "up":
        # [B, C, H, W] -> [B, L, C] token sequence over spatial positions.
        up = X.view(B, C, H * W).transpose(1, 2)
        # FIX: inputs are batch-first; without batch_first=True the module
        # would treat dim 0 as the sequence axis and mix samples.
        spatial_attention = nn.MultiheadAttention(embed_dim=C, num_heads=num_heads,
                                                  batch_first=True).to(device)
        out, weights = spatial_attention(up, up, up)
        out = out.transpose(1, 2).view(B, C, H, W)
        # Follow with a per-pixel spatial attention pass.
        spatial_attention_module = SpatialAttention(in_channels=C).to(device)
        out = spatial_attention_module(out)
    else:
        raise ValueError(f"use must be 'down' or 'up', got {use!r}")

    return out
166
+
167
+ '''
168
+ # Example usage
169
+ X = torch.randn(1,4,572,572) # Example input tensor
170
+ num_heads = 4
171
+ attention_out = calculate_attention(X, num_heads,use="up")
172
+ print("attention out",attention_out.shape)
173
+ '''
174
+ '''
175
+ X = torch.randn(1, 64, 254, 254)
176
+ output = calculate_attention(X,num_heads=8)
177
+ print("attention", output.shape) # 应该输出 torch.Size([1, 64, 254, 254])
178
+ '''
179
+ # -----------------------------------------------------------------------------------
180
+
181
def generate_positional_encoding(X):
    """Return a sin/cos positional encoding with the same shape and device as X.

    Even channels get sin of the scaled x position, odd channels get cos of
    the scaled y position.

    FIX: the encoding is now allocated on X's device — it was always created
    on the CPU, which caused a device mismatch when downstream code
    concatenated it with CUDA tensors. Also guards the last channel so an odd
    channel count no longer raises IndexError.
    """
    B, C, H, W = X.size()
    pos_encoding = torch.zeros(B, C, H, W, device=X.device)

    # Row/column index grids, both shaped [H, W].
    y_positions = torch.arange(0, H, dtype=torch.float32, device=X.device).unsqueeze(1).repeat(1, W)
    x_positions = torch.arange(0, W, dtype=torch.float32, device=X.device).unsqueeze(0).repeat(H, 1)

    # Scale positions by sqrt of the extent before taking sin/cos.
    y_positions = y_positions / (H ** 0.5)
    x_positions = x_positions / (W ** 0.5)

    # Interleave: sin(x) on even channels, cos(y) on odd channels.
    for i in range(0, C, 2):
        pos_encoding[:, i, :, :] = torch.sin(x_positions)
        if i + 1 < C:  # guard the final channel when C is odd
            pos_encoding[:, i + 1, :, :] = torch.cos(y_positions)

    return pos_encoding
201
+
202
+ '''
203
+ X = torch.randn(1,128, 512, 512)
204
+ # 计算位置编码
205
+ pos_encoding = generate_positional_encoding(X)
206
+ print("Positional Encoding shape:", pos_encoding.shape) # 应该输出 torch.Size([1, 64, 254, 254])
207
+ '''
208
+
209
class down_block(nn.Module):
    """Two conv_blocks followed by a strided-conv downsample.

    forward(X) returns (pooled, skip): `skip` is the pre-pool activation kept
    for the matching up_block. Spatially: (H, W) -> ((H-4)//2, (W-4)//2).
    """

    def __init__(self, in_channel, channel_dim):
        super(down_block, self).__init__()

        self.channel_dim = channel_dim

        self.block1 = conv_block(in_channel=in_channel, num_heads=4,
                                 channel_dim=channel_dim, use="down")
        self.block2 = conv_block(in_channel=channel_dim, num_heads=4,
                                 channel_dim=channel_dim, use="down")

        # A 2x2 stride-2 conv halves the spatial resolution (learned pooling).
        self.down_pool = nn.Conv2d(in_channels=channel_dim, out_channels=channel_dim,
                                   kernel_size=2, stride=2, padding=0, bias=False)

    def forward(self, X):
        hidden = self.block1(X)
        # The second block's output doubles as the skip-connection tensor.
        for_skip_connection = self.block2(hidden)
        result_after_pool = self.down_pool(for_skip_connection)
        return result_after_pool, for_skip_connection
231
+
232
+ '''
233
+ model1 = down_block(in_channel=64,channel_dim=128)
234
+ input = torch.randn(1,64,284,284)
235
+ res,out = model1(input)
236
+ print(res.shape,out.shape)
237
+ '''
238
+ # --------------------------------------------------------------------------------------------------
239
class up_block(nn.Module):
    """Transposed-conv upsample, concat with the skip tensor, then two conv_blocks.

    forward(input, input_skip) returns (out, after_transposed) where
    `after_transposed` is the raw upsampled tensor before concatenation.
    """

    def __init__(self, in_channel):
        super(up_block, self).__init__()
        self.in_channel = in_channel

        # Halves the channel count while doubling the spatial resolution.
        self.block1 = conv_block(in_channel=in_channel * 2, num_heads=4,
                                 channel_dim=in_channel, use="up")
        self.block2 = conv_block(in_channel=in_channel, num_heads=4,
                                 channel_dim=in_channel, use="up")
        self.up_pool = nn.ConvTranspose2d(self.in_channel * 2, self.in_channel,
                                          kernel_size=2, stride=2)

    def forward(self, input, input_skip):
        # Upsample first, then fuse with the skip tensor along channels.
        after_transposed = self.up_pool(input)
        after_cat = torch.cat((after_transposed, input_skip), dim=1)
        out = self.block2(self.block1(after_cat))
        return out, after_transposed
261
+
262
+ '''
263
+ model2 = up_block(in_channel=128)
264
+ input = torch.randn(1,256,140,140)
265
+ input_skip = torch.randn(1,128,280,280)
266
+ out,after = model2(input,input_skip)
267
+ print("up block",out.shape) #torch.Size([1, 128, 284, 284])
268
+ '''
269
+
270
+
271
class down_model(nn.Module):
    """First-stage U-Net autoencoder (3-channel image in, 3-channel image out).

    forward() returns a 17-tuple: the reconstruction plus the 8 attention maps
    and 8 positional encodings computed at each resolution. These extras are
    consumed by the second-stage model in up_unet.py, so the tuple ordering is
    part of the contract — do not reorder.
    """

    def __init__(self):
        super(down_model, self).__init__()

        # 1x1 stem: 3 -> 4 channels.
        self.start_conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, stride=1)

        self.down_block1 = down_block(4, 64)
        self.down_block2 = down_block(64, 128)
        self.down_block3 = down_block(128, 256)
        self.down_block4 = down_block(256, 512)

        # 1x1 bottleneck: 512 -> 1024 channels.
        self.bottle_conv = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)

        self.up_block4 = up_block(512)
        self.up_block3 = up_block(256)
        self.up_block2 = up_block(128)
        self.up_block1 = up_block(64)

        self.final_conv = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1)

    def forward(self, input):
        # NOTE: the input H/W must divide evenly through the four 2x downsamples.

        input = self.start_conv(input)

        # Encoder: each stage returns (pooled, skip); the attention map and
        # positional encoding of every skip tensor are exported for stage two.
        result_after_pool1, for_skip_connection1 = self.down_block1(input)
        attention_out1 = calculate_attention(for_skip_connection1, num_heads=4, use="down")
        pos_encoding1 = generate_positional_encoding(for_skip_connection1)

        result_after_pool2, for_skip_connection2 = self.down_block2(result_after_pool1)
        attention_out2 = calculate_attention(for_skip_connection2, num_heads=4, use="down")
        pos_encoding2 = generate_positional_encoding(for_skip_connection2)

        result_after_pool3, for_skip_connection3 = self.down_block3(result_after_pool2)
        attention_out3 = calculate_attention(for_skip_connection3, num_heads=4, use="down")
        pos_encoding3 = generate_positional_encoding(for_skip_connection3)

        result_after_pool4, for_skip_connection4 = self.down_block4(result_after_pool3)
        attention_out4 = calculate_attention(for_skip_connection4, num_heads=4, use="down")
        pos_encoding4 = generate_positional_encoding(for_skip_connection4)

        result_after_pool4 = self.bottle_conv(result_after_pool4)

        # Decoder: attention maps / encodings are computed on the raw upsampled
        # tensors (after_transposed*), not on the fused outputs.
        out, after_transposed1 = self.up_block4(result_after_pool4, for_skip_connection4)
        attention_out5 = calculate_attention(after_transposed1, num_heads=4, use="up")
        pos_encoding5 = generate_positional_encoding(after_transposed1)

        # NOTE(review): stages 6-8 call .to(device) on these results but stages
        # 1-5 do not — confirm whether the asymmetry is intentional.
        out, after_transposed2 = self.up_block3(out, for_skip_connection3)
        attention_out6 = calculate_attention(after_transposed2, num_heads=4, use="up").to(device)
        pos_encoding6 = generate_positional_encoding(after_transposed2).to(device)

        out, after_transposed3 = self.up_block2(out, for_skip_connection2)
        attention_out7 = calculate_attention(after_transposed3, num_heads=4, use="up").to(device)
        pos_encoding7 = generate_positional_encoding(after_transposed3).to(device)

        out, after_transposed4 = self.up_block1(out, for_skip_connection1)
        attention_out8 = calculate_attention(after_transposed4, num_heads=4, use="up").to(device)
        pos_encoding8 = generate_positional_encoding(after_transposed4).to(device)

        out = self.final_conv(out)

        return out,attention_out1,attention_out2,attention_out3,attention_out4,attention_out5,attention_out6,attention_out7,attention_out8,pos_encoding1,pos_encoding2,pos_encoding3,pos_encoding4,pos_encoding5,pos_encoding6,pos_encoding7,pos_encoding8
348
+
349
+
350
+
351
+
352
+ '''
353
+ all_model = model()
354
+ input = torch.randn(1,4,1024,1024)
355
+ output = all_model(input)
356
+ print(output.shape)
357
+ '''
358
+
359
+
360
# --- Training setup -------------------------------------------------------
all_model = down_model().to(device)
loss_function = nn.MSELoss().to(device)                          # reconstruction loss
optimizer = torch.optim.Adam(all_model.parameters(), lr=1e-6)    # optimizer

epoch = 3
batch_size = 10
image_size = 268  # batches are [10, 3, 268, 268]


ds = load_dataset("bitmind/ffhq-256", split="train")
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),            # Randomly flip (data augmentation)
        transforms.ToTensor(),                        # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),           # Map to (-1, 1)
    ]
)


def transform(examples):
    """Convert the batch's PIL images to normalized RGB tensors."""
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}


ds.set_transform(transform)
dataloader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True)


for i in range(epoch):
    for idx, batch_x in enumerate(dataloader):
        images = batch_x["images"].to(device)
        # FIX: the model returns a 17-tuple (reconstruction + attention maps +
        # positional encodings). The original called .to(device) on the whole
        # tuple, which raises AttributeError; train on the reconstruction only
        # (it is already on `device`).
        output = all_model(images)[0]
        loss = loss_function(output, images)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(all_model.parameters(), 1.)
        optimizer.step()
        print("epoch:", i, "loss:", loss.item())
        # FIX: log a Python float rather than the live (graph-attached) tensor.
        wandb.log({'epoch': i, "batch:": idx, 'loss': loss.item()})

#torch.save(all_model.state_dict(), 'model_weights.pth')
400
+
401
+
up_unet.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import sympy as sp
6
+ import wandb
7
+ from PIL import Image
8
+ from datasets import load_dataset
9
+ from torchvision import transforms
10
+
11
+ from down_unet import down_model
12
+ '''上面的网络需要接受三个信息,上下采样模块需要重写,两次宽高减2后接受三个信息,renet块加入时间信息,'''
13
+
14
class conv_block(nn.Module):
    """ResNet sub-block -> self-attention sub-block -> channel-changing conv.

    [B, C, H, W] -> [B, channel_dim, H-2, W-2] for use="down" (padding=0),
    or  [B, channel_dim, H+2, W+2] for use="up" (padding=2).
    in_channel must be divisible by 4 (GroupNorm) and by num_heads (attention).
    """

    def __init__(self, in_channel, num_heads, channel_dim, use="down"):
        super(conv_block, self).__init__()

        self.in_channel = in_channel
        self.num_heads = num_heads
        self.channel_dim = channel_dim
        self.use = use

        # GroupNorm over the *input* channel count, split into 4 groups.
        self.GN = nn.GroupNorm(num_groups=4, num_channels=in_channel)
        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=in_channel, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.silu = nn.SiLU()
        # FIX: attention_block feeds tensors shaped [B, L, C]; without
        # batch_first=True, nn.MultiheadAttention treats dim 0 as the sequence
        # axis and mixes samples across the batch.
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel,
                                               num_heads=self.num_heads,
                                               batch_first=True)

        if self.use == "down":
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim,
                                   kernel_size=3, stride=1, padding=0, bias=False)
        elif self.use == "up":
            self.conv1 = nn.Conv2d(in_channels=self.in_channel, out_channels=self.channel_dim,
                                   kernel_size=3, stride=1, padding=2, bias=False)

    def resnet_block(self, X):
        """Residual refinement at constant width.

        TODO (per the file's header note): a time embedding is meant to be
        injected here.  NOTE(review): the same GN/conv modules are applied
        twice, so both halves share weights — confirm intended.
        """
        out = self.GN(X)
        out = self.conv(out)
        out = self.silu(out)

        out = self.GN(out)
        out = self.conv(out)
        out = self.silu(out)

        return out + X

    def attention_block(self, X):
        """Self-attention over the H*W spatial positions, with residual."""
        B, C, H, W = X.size()

        out = self.GN(X)
        out = self.conv(out)

        # [B, C, H, W] -> [B, L, C] where L = H * W.
        out = out.view(B, self.in_channel, H * W).transpose(1, 2)
        out, weights = self.attention(out, out, out)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)

        out = self.conv(out)

        return out + X

    def forward(self, X):
        out = self.resnet_block(X)
        out = self.attention_block(out)
        out = self.conv1(out)

        return out
72
+
73
+
74
+
75
class down_block(nn.Module):
    """Down block that fuses a positional encoding into the features and
    cross-attends to an attention map exported by the first-stage model.

    forward(X, attention_out, pos_encoding) -> (pooled, skip); spatially
    (H, W) -> ((H-4)//2, (W-4)//2) at channel_dim channels.
    """

    def __init__(self, in_channel, channel_dim):  # e.g. in_channel=4 -> channel_dim=64
        super(down_block, self).__init__()

        self.channel_dim = channel_dim
        self.in_channel = in_channel

        self.block1 = conv_block(in_channel=self.in_channel, num_heads=4,
                                 channel_dim=self.channel_dim, use="down")
        self.block2 = conv_block(in_channel=self.channel_dim, num_heads=4,
                                 channel_dim=self.channel_dim, use="down")

        # 1x1 conv folding the [features ++ positional-encoding] concat
        # (2 * channel_dim) back down to channel_dim channels.
        self.return_conv = nn.Conv2d(in_channels=self.channel_dim * 2, out_channels=self.channel_dim,
                                     kernel_size=1, stride=1, padding=0, bias=False)

        # FIX: caculate_attention feeds [B, L, C] tensors; batch_first=True
        # keeps nn.MultiheadAttention from treating the batch axis as the
        # sequence axis.
        self.attention = nn.MultiheadAttention(embed_dim=self.channel_dim, num_heads=4,
                                               batch_first=True)

        # Learned 2x downsampling.
        self.down_pool = nn.Conv2d(in_channels=self.channel_dim, out_channels=self.channel_dim,
                                   kernel_size=2, stride=2, padding=0, bias=False)

    def caculate_attention(self, X_q, Y_kv):
        """Cross-attention: queries from X_q, keys/values from Y_kv.

        Both inputs are [B, channel_dim, H, W] of identical spatial size.
        """
        B, C, H, W = X_q.size()

        # [B, C, H, W] -> [B, L, C] with L = H * W.
        X_q = X_q.view(B, self.channel_dim, H * W).transpose(1, 2)
        Y_kv = Y_kv.view(B, self.channel_dim, H * W).transpose(1, 2)

        out, weights = self.attention(X_q, Y_kv, Y_kv)
        out = out.transpose(1, 2).view(B, self.channel_dim, H, W)

        return out

    def forward(self, X, attention_out, pos_encoding):
        out = self.block1(X)
        for_skip_connection = self.block2(out)

        # Fold the positional encoding into the features.
        out = torch.cat((for_skip_connection, pos_encoding), dim=1)
        out = self.return_conv(out)

        # Condition on the first-stage attention map via cross-attention.
        out = self.caculate_attention(X_q=attention_out, Y_kv=out)

        out = self.down_pool(out)

        return out, for_skip_connection
121
+
122
+ '''
123
+ X = torch.randn(1,4,128,128)
124
+ attention_out = torch.randn(1,64,124,124)
125
+ pos_encoding = torch.randn(1,64,124,124)
126
+ model = down_block(4,64,4)
127
+ out = model(X,attention_out,pos_encoding)
128
+ print(out.shape)
129
+ '''
130
+
131
class up_block(nn.Module):
    """Up block of the second-stage U-Net.

    Upsamples `input`, folds in the skip tensor and the positional encoding
    via 1x1 convs, cross-attends to the first-stage attention map, then
    refines with conv_blocks at constant channel count.
    """

    def __init__(self, in_channel):  # in_channel is the channel count after upsampling
        super(up_block, self).__init__()
        self.in_channel = in_channel

        self.block1 = conv_block(in_channel=in_channel * 2, num_heads=4,
                                 channel_dim=in_channel, use="up")
        self.block2 = conv_block(in_channel=in_channel, num_heads=4,
                                 channel_dim=in_channel, use="up")
        self.up_pool = nn.ConvTranspose2d(self.in_channel * 2, self.in_channel,
                                          kernel_size=2, stride=2)

        # 1x1 conv folding a 2C concat back to C channels; applied twice in
        # forward (skip fusion, then positional-encoding fusion) with shared
        # weights.
        self.return_conv = nn.Conv2d(in_channels=self.in_channel * 2, out_channels=self.in_channel,
                                     kernel_size=1, stride=1, padding=0, bias=False)

        # FIX: caculate_attention feeds [B, L, C] tensors; batch_first=True
        # keeps nn.MultiheadAttention from treating the batch axis as the
        # sequence axis.
        self.attention = nn.MultiheadAttention(embed_dim=self.in_channel, num_heads=4,
                                               batch_first=True)

    def caculate_attention(self, X_q, Y_kv):
        """Cross-attention: queries from X_q, keys/values from Y_kv.

        Both inputs are [B, in_channel, H, W] of identical spatial size.
        """
        B, C, H, W = X_q.size()

        X_q = X_q.view(B, self.in_channel, H * W).transpose(1, 2)
        Y_kv = Y_kv.view(B, self.in_channel, H * W).transpose(1, 2)

        out, weights = self.attention(X_q, Y_kv, Y_kv)
        out = out.transpose(1, 2).view(B, self.in_channel, H, W)

        return out

    def forward(self, input, input_skip, attention_out, pos_encoding):
        # Upsample, then fold in the skip tensor and the positional encoding.
        after_transposed = self.up_pool(input)

        after_cat = torch.cat((after_transposed, input_skip), dim=1)
        after_cat = self.return_conv(after_cat)
        after_cat = torch.cat((after_cat, pos_encoding), dim=1)
        after_cat = self.return_conv(after_cat)

        out = self.caculate_attention(X_q=attention_out, Y_kv=after_cat)

        # NOTE(review): block2 is applied twice (shared weights) and block1 is
        # never used — block1 expects 2C inputs, which no longer occur after
        # return_conv. Kept as-is to preserve behavior; confirm whether block1
        # should be removed or the first call should use distinct weights.
        out = self.block2(out)
        out = self.block2(out)

        return out
178
+
179
+ '''
180
+ X = torch.randn(1,128,62,62)
181
+ input_skip = torch.randn(1,64,124,124)
182
+ attention_out = torch.randn(1,64,124,124)
183
+ pos_encoding = torch.randn(1,64,124,124)
184
+ model = up_block(in_channel=64,num_head=4)
185
+ out = model(X,input_skip,attention_out,pos_encoding)
186
+ print(out.shape) # torch.Size([1, 64, 128, 128])
187
+ '''
188
+
189
class up_model(nn.Module):
    """Second-stage U-Net conditioned on the attention maps and positional
    encodings produced by the first-stage model (down_model).
    """

    def __init__(self):
        super(up_model, self).__init__()

        self.down_model = down_model()

        self.start_conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, stride=1)

        self.down_block1 = down_block(4, 64)
        self.down_block2 = down_block(64, 128)
        self.down_block3 = down_block(128, 256)
        self.down_block4 = down_block(256, 512)

        self.bottle_conv = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1, stride=1)

        self.up_block4 = up_block(512)
        self.up_block3 = up_block(256)
        self.up_block2 = up_block(128)
        self.up_block1 = up_block(64)

        self.final_conv = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1)

    def forward(self, input):
        # NOTE: the input H/W must divide evenly through the four 2x downsamples.

        # First stage supplies the reconstruction plus 8 attention maps and
        # 8 positional encodings, indexed from highest resolution (1) down.
        X, attention_out1, attention_out2, attention_out3, attention_out4, attention_out5, attention_out6, attention_out7, attention_out8, pos_encoding1, pos_encoding2, pos_encoding3, pos_encoding4, pos_encoding5, pos_encoding6, pos_encoding7, pos_encoding8 = self.down_model(input)

        input = self.start_conv(input)

        # FIX: the original reused self.down_block1 for all four encoder
        # stages and self.up_block4 for all four decoder stages; the channel
        # counts mismatch after the first stage and the model crashes. Use the
        # stage-specific blocks.
        out, for_skip1 = self.down_block1(input, attention_out8, pos_encoding8)
        out, for_skip2 = self.down_block2(out, attention_out7, pos_encoding7)
        out, for_skip3 = self.down_block3(out, attention_out6, pos_encoding6)
        out, for_skip4 = self.down_block4(out, attention_out5, pos_encoding5)

        out = self.bottle_conv(out)

        out = self.up_block4(out, for_skip4, attention_out4, pos_encoding4)
        out = self.up_block3(out, for_skip3, attention_out3, pos_encoding3)
        out = self.up_block2(out, for_skip2, attention_out2, pos_encoding2)
        out = self.up_block1(out, for_skip1, attention_out1, pos_encoding1)

        out = self.final_conv(out)

        return out
239
+