mohamed12ahmed commited on
Commit
03d31ed
·
verified ·
1 Parent(s): bac2bc4

Update models/restormer_arch.py

Browse files
Files changed (1) hide show
  1. models/restormer_arch.py +36 -18
models/restormer_arch.py CHANGED
@@ -107,12 +107,11 @@ class Attention(nn.Module):
107
  self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
108
 
109
 
110
-
111
  def forward(self, x):
112
  b,c,h,w = x.shape
113
 
114
  qkv = self.qkv_dwconv(self.qkv(x))
115
- q,k,v = qkv.chunk(3, dim=1)
116
 
117
  q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
118
  k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
@@ -200,8 +199,8 @@ class Restormer(nn.Module):
200
  heads = [1,2,4,8],
201
  ffn_expansion_factor = 2.66,
202
  bias = False,
203
- LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
204
- dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
205
  ):
206
 
207
  super(Restormer, self).__init__()
@@ -257,7 +256,7 @@ class Restormer(nn.Module):
257
  inp_enc_level4 = self.down3_4(out_enc_level3)
258
  latent = self.latent(inp_enc_level4)
259
 
260
-
261
  inp_dec_level3 = self.up4_3(latent)
262
  inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
263
  inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
@@ -274,7 +273,8 @@ class Restormer(nn.Module):
274
 
275
  out_dec_level1 = self.refinement(out_dec_level1)
276
 
277
- out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
 
278
  out_dec_level1 = self.output(out_dec_level1)
279
 
280
  return out_dec_level1
@@ -283,6 +283,9 @@ class Restormer(nn.Module):
283
 
284
  if __name__ == '__main__':
285
  from torchtoolbox.tools import summary
 
 
 
286
  model = Restormer(
287
  inp_channels=6,
288
  out_channels=3,
@@ -293,16 +296,31 @@ if __name__ == '__main__':
293
  heads = [1,2,4,8],
294
  ffn_expansion_factor = 2.66,
295
  bias = False,
296
- LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
297
- dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
298
  )
299
- # model = Restormer(num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 4, 8], channels=[48, 96, 192, 384], num_refinement=4, expansion_factor=2.66)
300
- print(summary(model,torch.rand((1, 6, 256, 256))))
301
-
302
- from thop import profile
303
- input = torch.rand((1, 6, 256, 256))
304
- gflops,params = profile(model,inputs=(input,))
305
- gflops = gflops*2 / 10**9
306
- params = params / 10**6
307
- print(gflops,'==============')
308
- print(params,'==============')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
108
 
109
 
 
110
  def forward(self, x):
111
  b,c,h,w = x.shape
112
 
113
  qkv = self.qkv_dwconv(self.qkv(x))
114
+ q,k,v = qkv.chunk(3, dim=1)
115
 
116
  q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
117
  k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
 
199
  heads = [1,2,4,8],
200
  ffn_expansion_factor = 2.66,
201
  bias = False,
202
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
203
+ dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
204
  ):
205
 
206
  super(Restormer, self).__init__()
 
256
  inp_enc_level4 = self.down3_4(out_enc_level3)
257
  latent = self.latent(inp_enc_level4)
258
 
259
+
260
  inp_dec_level3 = self.up4_3(latent)
261
  inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
262
  inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
 
273
 
274
  out_dec_level1 = self.refinement(out_dec_level1)
275
 
276
+ if self.dual_pixel_task:
277
+ out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
278
  out_dec_level1 = self.output(out_dec_level1)
279
 
280
  return out_dec_level1
 
283
 
284
  if __name__ == '__main__':
285
  from torchtoolbox.tools import summary
286
+ # NOTE: The thop and torchtoolbox imports might require installation (pip install thop torchtoolbox)
287
+ # The summary function from torchtoolbox is used here.
288
+
289
  model = Restormer(
290
  inp_channels=6,
291
  out_channels=3,
 
296
  heads = [1,2,4,8],
297
  ffn_expansion_factor = 2.66,
298
  bias = False,
299
+ LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
300
+ dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
301
  )
302
+
303
+ print("--- Model Summary (Requires torchtoolbox) ---")
304
+ # print(summary(model,torch.rand((1, 6, 256, 256)))) # Uncomment if torchtoolbox is installed
305
+
306
+ print("\n--- Model Profiling (Requires thop) ---")
307
+ try:
308
+ from thop import profile
309
+ input = torch.rand((1, 6, 256, 256))
310
+ gflops,params = profile(model,inputs=(input,), verbose=False)
311
+ gflops = gflops*2 / 10**9
312
+ params = params / 10**6
313
+ print(f"GFLOPS: {gflops:.4f}")
314
+ print(f"Params (M): {params:.4f}")
315
+ except ImportError:
316
+ print("Note: 'thop' library not found. Skipping GFLOPS/Params calculation.")
317
+ except Exception as e:
318
+ print(f"An error occurred during profiling: {e}")
319
+
320
+ # Example of a simple forward pass test
321
+ try:
322
+ input_tensor = torch.rand((1, 6, 256, 256))
323
+ output_tensor = model(input_tensor)
324
+ print(f"\nForward Pass Test: Input Shape {input_tensor.shape}, Output Shape {output_tensor.shape}")
325
+ except Exception as e:
326
+ print(f"\nForward Pass failed: {e}")