jiachen committed on
Commit 0e9d4e8 · 1 Parent(s): 42558c6

promptxrestormer

__pycache__/app.cpython-38.pyc ADDED
Binary file (361 Bytes).
 
app.py CHANGED
@@ -1,7 +1,116 @@
 
+ import numpy as np
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+
+ import torch
+
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+
+ from net.prompt_xrestormer import PromptXRestormer
+ import lightning.pytorch as pl
+
+
+ # crop an image to the largest size that is a multiple of base
+ def crop_img(image, base=64):
+     h = image.shape[0]
+     w = image.shape[1]
+     crop_h = h % base
+     crop_w = w % base
+     return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
+
+ class PromptXRestormerIRModel(pl.LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.net = PromptXRestormer(
+             inp_channels=3,
+             out_channels=3,
+             dim=48,
+             num_blocks=[2, 4, 4, 4],
+             num_refinement_blocks=4,
+             channel_heads=[1, 1, 1, 1],
+             spatial_heads=[1, 2, 4, 8],
+             overlap_ratio=[0.5, 0.5, 0.5, 0.5],
+             ffn_expansion_factor=2.66,
+             bias=False,
+             LayerNorm_type='WithBias',  ## Other option 'BiasFree'
+             dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+             scale=1,
+             prompt=True
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ def np_to_pil(img_np):
+     """
+     Converts an image in np.array format to a PIL image.
+
+     From C x H x W [0..1] to H x W x C [0..255]
+     :param img_np:
+     :return:
+     """
+     ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
+
+     if img_np.shape[0] == 1:
+         ar = ar[0]
+     else:
+         assert img_np.shape[0] == 3, img_np.shape
+         ar = ar.transpose(1, 2, 0)
+
+     return Image.fromarray(ar)
+
+ def torch_to_np(img_var):
+     """
+     Converts an image in torch.Tensor format to np.array.
+
+     From 1 x C x H x W [0..1] to C x H x W [0..1]
+     :param img_var:
+     :return:
+     """
+     return img_var.detach().cpu().numpy()[0]
+
+
+ def restore_image(input_img):
+     np.random.seed(0)
+     torch.manual_seed(0)
+     torch.cuda.set_device(0)
+
+     ckpt_path = "/home/jiachen/MyGradio/ckpt/promptxrestormer_epoch=64-step=578630.ckpt"
+     print("CKPT name : {}".format(ckpt_path))
+
+     net = PromptXRestormerIRModel().load_from_checkpoint(ckpt_path).cuda()
+     net.eval()
+
+     # degraded_path = "/home/jiachen/MyGradio/test_images/rain-070.png"
+
+     # crop_img expects an H x W x C numpy array, so convert the PIL input first
+     degraded_img = crop_img(np.array(input_img.convert('RGB')), base=16)
+     toTensor = ToTensor()
+     degraded_img = toTensor(degraded_img)
+     print(degraded_img.shape)
+
+     with torch.no_grad():
+         degraded_img = degraded_img.unsqueeze(0).cuda()
+
+         _, _, H_old, W_old = degraded_img.shape
+
+         # mirror-pad height and width up to the next multiple of 64
+         h_pad = (H_old // 64 + 1) * 64 - H_old
+         w_pad = (W_old // 64 + 1) * 64 - W_old
+         degrad_img = torch.cat([degraded_img, torch.flip(degraded_img, [2])], 2)[:, :, :H_old + h_pad, :]
+         degrad_img = torch.cat([degrad_img, torch.flip(degrad_img, [3])], 3)[:, :, :, :W_old + w_pad]
+
+         print(degrad_img.shape)
+         restored = net(degrad_img)
+         restored = restored[:, :, :H_old, :W_old]
+
+         # convert back to an H x W x C image for the Gradio output
+         return np_to_pil(torch_to_np(restored))
+
+ demo = gr.Interface(restore_image, gr.Image(type="pil"), "image")
+ demo.launch()
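
A caveat on app.py as committed: restore_image rebuilds the model and reloads the 424 MB checkpoint on every Gradio call. A minimal sketch of loading it once at import time instead, assuming the same ckpt_path and the PromptXRestormerIRModel class above (a suggested refactor, not a definitive implementation):

_MODEL = None

def get_model(ckpt_path):
    # load the LightningModule once and reuse it across requests
    global _MODEL
    if _MODEL is None:
        _MODEL = PromptXRestormerIRModel().load_from_checkpoint(ckpt_path).cuda()
        _MODEL.eval()
    return _MODEL

restore_image would then call get_model(ckpt_path) in place of the per-call construction above.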
ckpt/promptxrestormer_epoch=64-step=578630.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31eeeab21dace516dec55e5d51e97f4ce30c0fcce86ce36b729ca480175e23c7
+ size 424348801
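
The hunk above is a Git LFS pointer, not the weights themselves: the 424 MB checkpoint lives in LFS storage, and a plain clone fetches only this stub. A hedged sketch of fetching the real file programmatically with huggingface_hub (the repo_id and repo_type are assumptions; substitute the actual repo):

from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="jiachen/MyGradio",  # assumption: the Space this commit belongs to
    repo_type="space",           # assumption: checkpoint tracked as an LFS file in a Space
    filename="ckpt/promptxrestormer_epoch=64-step=578630.ckpt",
)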
flagged/log.csv ADDED
@@ -0,0 +1,3 @@
+ name,output,flag,username,timestamp
+ jiachen fdsf,Hello jiachen fdsf!!,,,2024-08-06 11:38:35.977480
+ jiachen fdsf,Hello jiachen fdsf!!,,,2024-08-06 11:38:38.296024
net/__pycache__/prompt_xrestormer.cpython-38.pyc ADDED
Binary file (19.1 kB).
 
net/prompt_xrestormer.py ADDED
@@ -0,0 +1,559 @@
+ import torch
+ import torch.nn as nn
+ from torch import einsum
+ import torch.nn.functional as F
+ from pdb import set_trace as stx
+ import numbers
+ from einops import rearrange
+ import math
+
+ def to(x):
+     return {'device': x.device, 'dtype': x.dtype}
+
+ def pair(x):
+     return (x, x) if not isinstance(x, tuple) else x
+
+ def expand_dim(t, dim, k):
+     t = t.unsqueeze(dim=dim)
+     expand_shape = [-1] * len(t.shape)
+     expand_shape[dim] = k
+     return t.expand(*expand_shape)
+
+ def rel_to_abs(x):
+     b, l, m = x.shape
+     r = (m + 1) // 2
+
+     col_pad = torch.zeros((b, l, 1), **to(x))
+     x = torch.cat((x, col_pad), dim=2)
+     flat_x = rearrange(x, 'b l c -> b (l c)')
+     flat_pad = torch.zeros((b, m - l), **to(x))
+     flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)
+     final_x = flat_x_padded.reshape(b, l + 1, m)
+     final_x = final_x[:, :l, -r:]
+     return final_x
+
+ def relative_logits_1d(q, rel_k):
+     b, h, w, _ = q.shape
+     r = (rel_k.shape[0] + 1) // 2
+
+     logits = einsum('b x y d, r d -> b x y r', q, rel_k)
+     logits = rearrange(logits, 'b x y r -> (b x) y r')
+     logits = rel_to_abs(logits)
+
+     logits = logits.reshape(b, h, w, r)
+     logits = expand_dim(logits, dim=2, k=r)
+     return logits
+
+ class RelPosEmb(nn.Module):
+     def __init__(
+         self,
+         block_size,
+         rel_size,
+         dim_head
+     ):
+         super().__init__()
+         height = width = rel_size
+         scale = dim_head ** -0.5
+
+         self.block_size = block_size
+         self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
+         self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)
+
+     def forward(self, q):
+         block = self.block_size
+
+         q = rearrange(q, 'b (x y) c -> b x y c', x=block)
+         rel_logits_w = relative_logits_1d(q, self.rel_width)
+         rel_logits_w = rearrange(rel_logits_w, 'b x i y j -> b (x y) (i j)')
+
+         q = rearrange(q, 'b x y d -> b y x d')
+         rel_logits_h = relative_logits_1d(q, self.rel_height)
+         rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')
+         return rel_logits_w + rel_logits_h
+
+ ##########################################################################
+ ## Layer Norm
+
+ def to_3d(x):
+     return rearrange(x, 'b c h w -> b (h w) c')
+
+ def to_4d(x, h, w):
+     return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+
+ class BiasFree_LayerNorm(nn.Module):
+     def __init__(self, normalized_shape):
+         super(BiasFree_LayerNorm, self).__init__()
+         if isinstance(normalized_shape, numbers.Integral):
+             normalized_shape = (normalized_shape,)
+         normalized_shape = torch.Size(normalized_shape)
+
+         assert len(normalized_shape) == 1
+
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.normalized_shape = normalized_shape
+
+     def forward(self, x):
+         sigma = x.var(-1, keepdim=True, unbiased=False)
+         return x / torch.sqrt(sigma + 1e-5) * self.weight
+
+ class WithBias_LayerNorm(nn.Module):
+     def __init__(self, normalized_shape):
+         super(WithBias_LayerNorm, self).__init__()
+         if isinstance(normalized_shape, numbers.Integral):
+             normalized_shape = (normalized_shape,)
+         normalized_shape = torch.Size(normalized_shape)
+
+         assert len(normalized_shape) == 1
+
+         self.weight = nn.Parameter(torch.ones(normalized_shape))
+         self.bias = nn.Parameter(torch.zeros(normalized_shape))
+         self.normalized_shape = normalized_shape
+
+     def forward(self, x):
+         mu = x.mean(-1, keepdim=True)
+         sigma = x.var(-1, keepdim=True, unbiased=False)
+         return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
+
+ class LayerNorm(nn.Module):
+     def __init__(self, dim, LayerNorm_type):
+         super(LayerNorm, self).__init__()
+         if LayerNorm_type == 'BiasFree':
+             self.body = BiasFree_LayerNorm(dim)
+         else:
+             self.body = WithBias_LayerNorm(dim)
+
+     def forward(self, x):
+         h, w = x.shape[-2:]
+         return to_4d(self.body(to_3d(x)), h, w)
+
+ ##########################################################################
+ ## Gated-Dconv Feed-Forward Network (GDFN)
+ class FeedForward(nn.Module):
+     def __init__(self, dim, ffn_expansion_factor, bias):
+         super(FeedForward, self).__init__()
+
+         hidden_features = int(dim * ffn_expansion_factor)
+
+         self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
+
+         self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias)
+
+         self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
+
+     def forward(self, x):
+         x = self.project_in(x)
+         x1, x2 = self.dwconv(x).chunk(2, dim=1)
+         x = F.gelu(x1) * x2
+         x = self.project_out(x)
+         return x
+
+
+ ##########################################################################
+ ## Multi-DConv Head Transposed Self-Attention (MDTA)
+ class ChannelAttention(nn.Module):
+     def __init__(self, dim, num_heads, bias):
+         super(ChannelAttention, self).__init__()
+         self.num_heads = num_heads
+         self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+         self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
+         self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
+         self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+
+         qkv = self.qkv_dwconv(self.qkv(x))
+         q, k, v = qkv.chunk(3, dim=1)
+
+         q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+         k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+         v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+         q = torch.nn.functional.normalize(q, dim=-1)
+         k = torch.nn.functional.normalize(k, dim=-1)
+
+         attn = (q @ k.transpose(-2, -1)) * self.temperature
+         attn = attn.softmax(dim=-1)
+
+         out = (attn @ v)
+
+         out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+
+         out = self.project_out(out)
+         return out
+
+ ##########################################################################
+ ## Overlapping Cross-Attention (OCA)
+ class OCAB(nn.Module):
+     def __init__(self, dim, window_size, overlap_ratio, num_heads, dim_head, bias):
+         super(OCAB, self).__init__()
+         self.num_spatial_heads = num_heads
+         self.dim = dim
+         self.window_size = window_size
+         self.overlap_win_size = int(window_size * overlap_ratio) + window_size
+         self.dim_head = dim_head
+         self.inner_dim = self.dim_head * self.num_spatial_heads
+         self.scale = self.dim_head ** -0.5
+
+         self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size - window_size) // 2)
+         self.qkv = nn.Conv2d(self.dim, self.inner_dim * 3, kernel_size=1, bias=bias)
+         self.project_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, bias=bias)
+         self.rel_pos_emb = RelPosEmb(
+             block_size=window_size,
+             rel_size=window_size + (self.overlap_win_size - window_size),
+             dim_head=self.dim_head
+         )
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         qkv = self.qkv(x)
+         qs, ks, vs = qkv.chunk(3, dim=1)
+
+         # spatial attention: queries from non-overlapping windows,
+         # keys/values from enlarged overlapping windows
+         qs = rearrange(qs, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=self.window_size, p2=self.window_size)
+         ks, vs = map(lambda t: self.unfold(t), (ks, vs))
+         ks, vs = map(lambda t: rearrange(t, 'b (c j) i -> (b i) j c', c=self.inner_dim), (ks, vs))
+
+         # split heads
+         qs, ks, vs = map(lambda t: rearrange(t, 'b n (head c) -> (b head) n c', head=self.num_spatial_heads), (qs, ks, vs))
+
+         # attention with relative position bias
+         qs = qs * self.scale
+         spatial_attn = (qs @ ks.transpose(-2, -1))
+         spatial_attn += self.rel_pos_emb(qs)
+         spatial_attn = spatial_attn.softmax(dim=-1)
+
+         out = (spatial_attn @ vs)
+
+         out = rearrange(out, '(b h w head) (p1 p2) c -> b (head c) (h p1) (w p2)', head=self.num_spatial_heads, h=h // self.window_size, w=w // self.window_size, p1=self.window_size, p2=self.window_size)
+
+         # merge heads and project back to dim channels
+         out = self.project_out(out)
+
+         return out
+
+ ##########################################################################
+ class TransformerBlock(nn.Module):
+     def __init__(self, dim, window_size, overlap_ratio, num_channel_heads, num_spatial_heads, spatial_dim_head, ffn_expansion_factor, bias, LayerNorm_type):
+         super(TransformerBlock, self).__init__()
+
+         self.spatial_attn = OCAB(dim, window_size, overlap_ratio, num_spatial_heads, spatial_dim_head, bias)
+         self.channel_attn = ChannelAttention(dim, num_channel_heads, bias)
+
+         self.norm1 = LayerNorm(dim, LayerNorm_type)
+         self.norm2 = LayerNorm(dim, LayerNorm_type)
+         self.norm3 = LayerNorm(dim, LayerNorm_type)
+         self.norm4 = LayerNorm(dim, LayerNorm_type)
+
+         self.channel_ffn = FeedForward(dim, ffn_expansion_factor, bias)
+         self.spatial_ffn = FeedForward(dim, ffn_expansion_factor, bias)
+
+     def forward(self, x):
+         x = x + self.channel_attn(self.norm1(x))
+         x = x + self.channel_ffn(self.norm2(x))
+         x = x + self.spatial_attn(self.norm3(x))
+         x = x + self.spatial_ffn(self.norm4(x))
+         return x
+
+
+ ##########################################################################
+ class ChannelTransformerBlock(nn.Module):
+     def __init__(self, dim, num_channel_heads, ffn_expansion_factor, bias, LayerNorm_type):
+         super(ChannelTransformerBlock, self).__init__()
+
+         self.channel_attn = ChannelAttention(dim, num_channel_heads, bias)
+         self.norm1 = LayerNorm(dim, LayerNorm_type)
+         self.norm2 = LayerNorm(dim, LayerNorm_type)
+
+         self.channel_ffn = FeedForward(dim, ffn_expansion_factor, bias)
+
+     def forward(self, x):
+         x = x + self.channel_attn(self.norm1(x))
+         x = x + self.channel_ffn(self.norm2(x))
+         return x
+
+
+ ##########################################################################
+ ## Overlapped image patch embedding with 3x3 Conv
+ class OverlapPatchEmbed(nn.Module):
+     def __init__(self, in_c=3, embed_dim=48, bias=False):
+         super(OverlapPatchEmbed, self).__init__()
+
+         self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
+
+     def forward(self, x):
+         x = self.proj(x)
+
+         return x
+
+ ##########################################################################
+ ## Resizing modules
+ class Downsample(nn.Module):
+     def __init__(self, n_feat):
+         super(Downsample, self).__init__()
+
+         self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
+                                   nn.PixelUnshuffle(2))
+
+     def forward(self, x):
+         return self.body(x)
+
+
+ class Upsample(nn.Module):
+     def __init__(self, n_feat):
+         super(Upsample, self).__init__()
+
+         self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
+                                   nn.PixelShuffle(2))
+
+     def forward(self, x):
+         return self.body(x)
+
+
+ class SR_Upsample(nn.Sequential):
+     """SR_Upsample module.
+     Args:
+         scale (int): Scale factor. Supported scales: 2^n and 3.
+         num_feat (int): Channel number of features.
+     """
+
+     def __init__(self, scale, num_feat):
+         m = []
+
+         if (scale & (scale - 1)) == 0:  # scale = 2^n
+             for _ in range(int(math.log(scale, 2))):
+                 m.append(nn.Conv2d(num_feat, 4 * num_feat, kernel_size=3, stride=1, padding=1))
+                 m.append(nn.PixelShuffle(2))
+         elif scale == 3:
+             m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+             m.append(nn.PixelShuffle(3))
+         else:
+             raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+         super(SR_Upsample, self).__init__(*m)
+
+
+ ##---------- Prompt Module -----------------------
+ class PromptBlock(nn.Module):
+     def __init__(self, window_size, overlap_ratio, num_channel_heads, num_spatial_heads,
+                  spatial_dim_head, ffn_expansion_factor, bias, LayerNorm_type,
+                  prompt_dim=128, prompt_len=5, prompt_size=96, lin_dim=192,
+                  ):
+         super(PromptBlock, self).__init__()
+
+         # prompt generation
+         self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
+         self.linear_layer = nn.Linear(lin_dim, prompt_len)
+         self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)
+
+         # prompt interaction: channel attention over the concatenated features
+         self.attn = ChannelTransformerBlock(dim=lin_dim + prompt_dim,
+                                             num_channel_heads=num_channel_heads,
+                                             ffn_expansion_factor=ffn_expansion_factor,
+                                             bias=bias, LayerNorm_type=LayerNorm_type)
+         self.conv = nn.Conv2d(prompt_dim + lin_dim, lin_dim, kernel_size=3, stride=1, padding=1, bias=False)
+
+     def forward(self, x):
+         # input x shape is [B, C, H, W]
+         B, C, H, W = x.shape
+         # prompt generation
+         emb = x.mean(dim=(-2, -1))
+         prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
+         prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1, 1, 1, 1, 1).squeeze(1)
+         prompt = torch.sum(prompt, dim=1)
+         prompt = F.interpolate(prompt, (H, W), mode="bilinear", align_corners=True)
+         prompt = self.conv3x3(prompt)
+
+         # x shape [B, C + C_p, H, W]
+         x = torch.cat([x, prompt], 1)
+         x = self.attn(x)
+         x = self.conv(x)
+
+         return x
+
+ ##---------- Prompt Gen Module -----------------------
+ class PromptGenBlock(nn.Module):
+     def __init__(self, prompt_dim=128, prompt_len=5, prompt_size=96, lin_dim=192):
+         super(PromptGenBlock, self).__init__()
+         self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
+         self.linear_layer = nn.Linear(lin_dim, prompt_len)
+         self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         emb = x.mean(dim=(-2, -1))
+         prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
+         prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1, 1, 1, 1, 1).squeeze(1)
+         prompt = torch.sum(prompt, dim=1)
+         prompt = F.interpolate(prompt, (H, W), mode="bilinear")
+         prompt = self.conv3x3(prompt)
+
+         return prompt
+
+
+ ##########################################################################
+
+
+ class PromptXRestormer(nn.Module):
+     def __init__(self,
+                  inp_channels=3,
+                  out_channels=3,
+                  dim=48,
+                  num_blocks=[4, 6, 6, 8],
+                  num_refinement_blocks=4,
+                  channel_heads=[1, 2, 4, 8],
+                  spatial_heads=[2, 2, 3, 4],
+                  overlap_ratio=[0.5, 0.5, 0.5, 0.5],
+                  window_size=8,
+                  spatial_dim_head=16,
+                  bias=False,
+                  ffn_expansion_factor=2.66,
+                  LayerNorm_type='WithBias',  ## Other option 'BiasFree'
+                  dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+                  scale=1,
+                  prompt=True
+                  ):
+
+         super(PromptXRestormer, self).__init__()
+         print("Initializing PromptXRestormer")
+         self.scale = scale
+
+         self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
+         self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, window_size=window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head=spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
+
+         self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
+         self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), window_size=window_size, overlap_ratio=overlap_ratio[1], num_channel_heads=channel_heads[1], num_spatial_heads=spatial_heads[1], spatial_dim_head=spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
+
+         self.down2_3 = Downsample(int(dim*2**1))  ## From Level 2 to Level 3
+         self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), window_size=window_size, overlap_ratio=overlap_ratio[2], num_channel_heads=channel_heads[2], num_spatial_heads=spatial_heads[2], spatial_dim_head=spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
+
+         self.down3_4 = Downsample(int(dim*2**2))  ## From Level 3 to Level 4
+         self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), window_size=window_size, overlap_ratio=overlap_ratio[3], num_channel_heads=channel_heads[3], num_spatial_heads=spatial_heads[3], spatial_dim_head=spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
+
+         self.up4_3 = Upsample(int(dim*2**2))  ## From Level 4 to Level 3
+         self.reduce_chan_level3 = nn.Conv2d(int(dim*2**1) + 192, int(dim*2**2), kernel_size=1, bias=bias)
+         self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), window_size=window_size, overlap_ratio=overlap_ratio[2], num_channel_heads=channel_heads[2], num_spatial_heads=spatial_heads[2], spatial_dim_head=spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])
+
+         self.up3_2 = Upsample(int(dim*2**2))  ## From Level 3 to Level 2
+         self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
+         self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), window_size=window_size, overlap_ratio=overlap_ratio[1], num_channel_heads=channel_heads[1], num_spatial_heads=spatial_heads[1], spatial_dim_head=spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
+
+         self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)
+
+         self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), window_size=window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head=spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
+
+         self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), window_size=window_size, overlap_ratio=overlap_ratio[0], num_channel_heads=channel_heads[0], num_spatial_heads=spatial_heads[0], spatial_dim_head=spatial_dim_head, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
+
+         self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
+
+         self.prompt = prompt
+         if prompt:
+             self.prompt1 = PromptGenBlock(prompt_dim=64, prompt_len=5, prompt_size=64, lin_dim=96)
+             self.prompt2 = PromptGenBlock(prompt_dim=128, prompt_len=5, prompt_size=32, lin_dim=192)
+             self.prompt3 = PromptGenBlock(prompt_dim=320, prompt_len=5, prompt_size=16, lin_dim=384)
+
+             self.noise_level1 = ChannelTransformerBlock(dim=int(dim*2**1) + 64, num_channel_heads=1, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
+             self.reduce_noise_level1 = nn.Conv2d(int(dim*2**1) + 64, int(dim*2**1), kernel_size=1, bias=bias)
+
+             self.noise_level2 = ChannelTransformerBlock(dim=int(dim*2**1) + 224, num_channel_heads=1, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
+             self.reduce_noise_level2 = nn.Conv2d(int(dim*2**1) + 224, int(dim*2**2), kernel_size=1, bias=bias)
+
+             self.noise_level3 = ChannelTransformerBlock(dim=int(dim*2**2) + 512, num_channel_heads=1, ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
+             self.reduce_noise_level3 = nn.Conv2d(int(dim*2**2) + 512, int(dim*2**2), kernel_size=1, bias=bias)
+
+     def forward(self, inp_img):
+
+         if self.scale > 1:
+             inp_img = F.interpolate(inp_img, scale_factor=self.scale, mode='bilinear', align_corners=False)
+
+         inp_enc_level1 = self.patch_embed(inp_img)
+         out_enc_level1 = self.encoder_level1(inp_enc_level1)
+
+         inp_enc_level2 = self.down1_2(out_enc_level1)
+         out_enc_level2 = self.encoder_level2(inp_enc_level2)
+
+         inp_enc_level3 = self.down2_3(out_enc_level2)
+         out_enc_level3 = self.encoder_level3(inp_enc_level3)
+
+         inp_enc_level4 = self.down3_4(out_enc_level3)
+         latent = self.latent(inp_enc_level4)
+
+         if self.prompt:
+             dec3_param = self.prompt3(latent)
+             latent = torch.cat([latent, dec3_param], 1)
+             latent = self.noise_level3(latent)
+             latent = self.reduce_noise_level3(latent)
+
+         inp_dec_level3 = self.up4_3(latent)
+         inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
+         inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
+         out_dec_level3 = self.decoder_level3(inp_dec_level3)
+
+         if self.prompt:
+             dec2_param = self.prompt2(out_dec_level3)
+             out_dec_level3 = torch.cat([out_dec_level3, dec2_param], 1)
+             out_dec_level3 = self.noise_level2(out_dec_level3)
+             out_dec_level3 = self.reduce_noise_level2(out_dec_level3)
+
+         inp_dec_level2 = self.up3_2(out_dec_level3)
+         inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
+         inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
+         out_dec_level2 = self.decoder_level2(inp_dec_level2)
+
+         if self.prompt:
+             dec1_param = self.prompt1(out_dec_level2)
+             out_dec_level2 = torch.cat([out_dec_level2, dec1_param], 1)
+             out_dec_level2 = self.noise_level1(out_dec_level2)
+             out_dec_level2 = self.reduce_noise_level1(out_dec_level2)
+
+         inp_dec_level1 = self.up2_1(out_dec_level2)
+         inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
+         out_dec_level1 = self.decoder_level1(inp_dec_level1)
+
+         out_dec_level1 = self.refinement(out_dec_level1)
+         out_dec_level1 = self.output(out_dec_level1) + inp_img
+
+         return out_dec_level1
+
+ if __name__ == "__main__":
+     model = PromptXRestormer(
+         inp_channels=3,
+         out_channels=3,
+         dim=48,
+         num_blocks=[2, 4, 4, 4],
+         num_refinement_blocks=4,
+         channel_heads=[1, 1, 1, 1],
+         spatial_heads=[1, 2, 4, 8],
+         overlap_ratio=[0.5, 0.5, 0.5, 0.5],
+         ffn_expansion_factor=2.66,
+         bias=False,
+         LayerNorm_type='WithBias',  ## Other option 'BiasFree'
+         dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+         scale=1,
+         prompt=True
+     )
+
+     # quick shape smoke test
+     x = torch.randn(1, 3, 320, 512)
+     y = model(x)
+     print(y.shape)
+     # print('# model_restoration parameters: %.2f M' % (sum(param.numel() for param in model.parameters()) / 1e6))
+
+     # from fvcore.nn import FlopCountAnalysis, flop_count_table
+     # input = torch.randn(1, 3, 64, 64)
+     # flops = FlopCountAnalysis(model, input)
+     # print(flop_count_table(flops))
+     # print(flops.total() / 1e9)
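
For context on the mirror padding in app.py and test.py: the encoder halves the spatial resolution three times (down1_2, down2_3, down3_4), and OCAB partitions every feature map into non-overlapping window_size = 8 windows, so the coarsest level needs H/8 and W/8 divisible by 8, i.e. inputs divisible by 64. A minimal sketch of that constraint, assuming the default window_size = 8:

# why inputs are padded to a multiple of 64 before inference
window_size = 8       # OCAB window at every level
down_factor = 2 ** 3  # three Downsample stages before the latent level
multiple = window_size * down_factor  # = 64

for H in (320, 321):
    pad = (H // multiple + 1) * multiple - H  # same formula app.py uses
    print(H, '->', H + pad)  # 320 -> 384, 321 -> 384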
output.png ADDED
test.py ADDED
@@ -0,0 +1,133 @@
+ import argparse
+ import subprocess
+ from tqdm import tqdm
+ import numpy as np
+
+ import torch
+ from torch.utils.data import DataLoader
+ import os
+ import torch.nn as nn
+
+ # from utils.dataset_utils import DenoiseTestDataset, DerainDehazeDataset
+ # from utils.val_utils import AverageMeter, compute_psnr_ssim
+ # from utils.image_io import save_image_tensor
+
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+
+ import lightning.pytorch as pl
+ import torch.nn.functional as F
+
+ from net.prompt_xrestormer import PromptXRestormer
+ import json
+
+ # crop an image to the largest size that is a multiple of base
+ def crop_img(image, base=64):
+     h = image.shape[0]
+     w = image.shape[1]
+     crop_h = h % base
+     crop_w = w % base
+     return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
+
+ class PromptXRestormerIRModel(pl.LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.net = PromptXRestormer(
+             inp_channels=3,
+             out_channels=3,
+             dim=48,
+             num_blocks=[2, 4, 4, 4],
+             num_refinement_blocks=4,
+             channel_heads=[1, 1, 1, 1],
+             spatial_heads=[1, 2, 4, 8],
+             overlap_ratio=[0.5, 0.5, 0.5, 0.5],
+             ffn_expansion_factor=2.66,
+             bias=False,
+             LayerNorm_type='WithBias',  ## Other option 'BiasFree'
+             dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
+             scale=1,
+             prompt=True
+         )
+         self.loss_fn = nn.L1Loss()
+
+     def forward(self, x):
+         return self.net(x)
+
+ def np_to_pil(img_np):
+     """
+     Converts an image in np.array format to a PIL image.
+
+     From C x H x W [0..1] to H x W x C [0..255]
+     :param img_np:
+     :return:
+     """
+     ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
+
+     if img_np.shape[0] == 1:
+         ar = ar[0]
+     else:
+         assert img_np.shape[0] == 3, img_np.shape
+         ar = ar.transpose(1, 2, 0)
+
+     return Image.fromarray(ar)
+
+ def torch_to_np(img_var):
+     """
+     Converts an image in torch.Tensor format to np.array.
+
+     From 1 x C x H x W [0..1] to C x H x W [0..1]
+     :param img_var:
+     :return:
+     """
+     return img_var.detach().cpu().numpy()[0]
+
+ def save_image_tensor(image_tensor, output_path="output/"):
+     image_np = torch_to_np(image_tensor)
+     p = np_to_pil(image_np)
+     p.save(output_path)
+
+
+ if __name__ == '__main__':
+
+     np.random.seed(0)
+     torch.manual_seed(0)
+     torch.cuda.set_device(0)
+
+     ckpt_path = "/home/jiachen/MyGradio/ckpt/promptxrestormer_epoch=64-step=578630.ckpt"
+     print("CKPT name : {}".format(ckpt_path))
+
+     net = PromptXRestormerIRModel().load_from_checkpoint(ckpt_path).cuda()
+     net.eval()
+
+     degraded_path = "/home/jiachen/MyGradio/test_images/rain-070.png"
+
+     degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16)
+     toTensor = ToTensor()
+     degraded_img = toTensor(degraded_img)
+     print(degraded_img.shape)
+
+     with torch.no_grad():
+         degraded_img = degraded_img.unsqueeze(0).cuda()
+
+         _, _, H_old, W_old = degraded_img.shape
+
+         # mirror-pad height and width up to the next multiple of 64
+         h_pad = (H_old // 64 + 1) * 64 - H_old
+         w_pad = (W_old // 64 + 1) * 64 - W_old
+         degrad_img = torch.cat([degraded_img, torch.flip(degraded_img, [2])], 2)[:, :, :H_old + h_pad, :]
+         degrad_img = torch.cat([degrad_img, torch.flip(degrad_img, [3])], 3)[:, :, :, :W_old + w_pad]
+
+         print(degrad_img.shape)
+         restored = net(degrad_img)
+         restored = restored[:, :, :H_old, :W_old]
+
+         save_image_tensor(restored, "output.png")
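
The commented-out utils.val_utils import above hints at PSNR/SSIM evaluation. A minimal PSNR sketch over the saved result, assuming a same-size ground-truth image exists (the norain-070.png path is hypothetical):

import numpy as np
from PIL import Image

def psnr(a, b, data_range=255.0):
    # peak signal-to-noise ratio between two uint8 images
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)

out = np.array(Image.open("output.png"))
gt = np.array(Image.open("test_images/norain-070.png"))  # hypothetical ground truth
print(psnr(out, gt))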
test_images/rain-070.png ADDED