Update models/restormer_arch.py
Browse files- models/restormer_arch.py +36 -18
models/restormer_arch.py
CHANGED
|
@@ -107,12 +107,11 @@ class Attention(nn.Module):
|
|
| 107 |
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
| 108 |
|
| 109 |
|
| 110 |
-
|
| 111 |
def forward(self, x):
|
| 112 |
b,c,h,w = x.shape
|
| 113 |
|
| 114 |
qkv = self.qkv_dwconv(self.qkv(x))
|
| 115 |
-
q,k,v = qkv.chunk(3, dim=1)
|
| 116 |
|
| 117 |
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
| 118 |
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
|
@@ -200,8 +199,8 @@ class Restormer(nn.Module):
|
|
| 200 |
heads = [1,2,4,8],
|
| 201 |
ffn_expansion_factor = 2.66,
|
| 202 |
bias = False,
|
| 203 |
-
LayerNorm_type = 'WithBias',
|
| 204 |
-
dual_pixel_task = True
|
| 205 |
):
|
| 206 |
|
| 207 |
super(Restormer, self).__init__()
|
|
@@ -257,7 +256,7 @@ class Restormer(nn.Module):
|
|
| 257 |
inp_enc_level4 = self.down3_4(out_enc_level3)
|
| 258 |
latent = self.latent(inp_enc_level4)
|
| 259 |
|
| 260 |
-
|
| 261 |
inp_dec_level3 = self.up4_3(latent)
|
| 262 |
inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
|
| 263 |
inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
|
|
@@ -274,7 +273,8 @@ class Restormer(nn.Module):
|
|
| 274 |
|
| 275 |
out_dec_level1 = self.refinement(out_dec_level1)
|
| 276 |
|
| 277 |
-
|
|
|
|
| 278 |
out_dec_level1 = self.output(out_dec_level1)
|
| 279 |
|
| 280 |
return out_dec_level1
|
|
@@ -283,6 +283,9 @@ class Restormer(nn.Module):
|
|
| 283 |
|
| 284 |
if __name__ == '__main__':
|
| 285 |
from torchtoolbox.tools import summary
|
|
|
|
|
|
|
|
|
|
| 286 |
model = Restormer(
|
| 287 |
inp_channels=6,
|
| 288 |
out_channels=3,
|
|
@@ -293,16 +296,31 @@ if __name__ == '__main__':
|
|
| 293 |
heads = [1,2,4,8],
|
| 294 |
ffn_expansion_factor = 2.66,
|
| 295 |
bias = False,
|
| 296 |
-
LayerNorm_type = 'WithBias',
|
| 297 |
-
dual_pixel_task = True
|
| 298 |
)
|
| 299 |
-
|
| 300 |
-
print(
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
|
| 108 |
|
| 109 |
|
|
|
|
| 110 |
def forward(self, x):
|
| 111 |
b,c,h,w = x.shape
|
| 112 |
|
| 113 |
qkv = self.qkv_dwconv(self.qkv(x))
|
| 114 |
+
q,k,v = qkv.chunk(3, dim=1)
|
| 115 |
|
| 116 |
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
| 117 |
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
|
|
|
|
| 199 |
heads = [1,2,4,8],
|
| 200 |
ffn_expansion_factor = 2.66,
|
| 201 |
bias = False,
|
| 202 |
+
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
|
| 203 |
+
dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
|
| 204 |
):
|
| 205 |
|
| 206 |
super(Restormer, self).__init__()
|
|
|
|
| 256 |
inp_enc_level4 = self.down3_4(out_enc_level3)
|
| 257 |
latent = self.latent(inp_enc_level4)
|
| 258 |
|
| 259 |
+
|
| 260 |
inp_dec_level3 = self.up4_3(latent)
|
| 261 |
inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
|
| 262 |
inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
|
|
|
|
| 273 |
|
| 274 |
out_dec_level1 = self.refinement(out_dec_level1)
|
| 275 |
|
| 276 |
+
if self.dual_pixel_task:
|
| 277 |
+
out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
|
| 278 |
out_dec_level1 = self.output(out_dec_level1)
|
| 279 |
|
| 280 |
return out_dec_level1
|
|
|
|
| 283 |
|
| 284 |
if __name__ == '__main__':
|
| 285 |
from torchtoolbox.tools import summary
|
| 286 |
+
# NOTE: The thop and torchtoolbox imports might require installation (pip install thop torchtoolbox)
|
| 287 |
+
# The summary function from torchtoolbox is used here.
|
| 288 |
+
|
| 289 |
model = Restormer(
|
| 290 |
inp_channels=6,
|
| 291 |
out_channels=3,
|
|
|
|
| 296 |
heads = [1,2,4,8],
|
| 297 |
ffn_expansion_factor = 2.66,
|
| 298 |
bias = False,
|
| 299 |
+
LayerNorm_type = 'WithBias', ## Other option 'BiasFree'
|
| 300 |
+
dual_pixel_task = True ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
|
| 301 |
)
|
| 302 |
+
|
| 303 |
+
print("--- Model Summary (Requires torchtoolbox) ---")
|
| 304 |
+
# print(summary(model,torch.rand((1, 6, 256, 256)))) # Uncomment if torchtoolbox is installed
|
| 305 |
+
|
| 306 |
+
print("\n--- Model Profiling (Requires thop) ---")
|
| 307 |
+
try:
|
| 308 |
+
from thop import profile
|
| 309 |
+
input = torch.rand((1, 6, 256, 256))
|
| 310 |
+
gflops,params = profile(model,inputs=(input,), verbose=False)
|
| 311 |
+
gflops = gflops*2 / 10**9
|
| 312 |
+
params = params / 10**6
|
| 313 |
+
print(f"GFLOPS: {gflops:.4f}")
|
| 314 |
+
print(f"Params (M): {params:.4f}")
|
| 315 |
+
except ImportError:
|
| 316 |
+
print("Note: 'thop' library not found. Skipping GFLOPS/Params calculation.")
|
| 317 |
+
except Exception as e:
|
| 318 |
+
print(f"An error occurred during profiling: {e}")
|
| 319 |
+
|
| 320 |
+
# Example of a simple forward pass test
|
| 321 |
+
try:
|
| 322 |
+
input_tensor = torch.rand((1, 6, 256, 256))
|
| 323 |
+
output_tensor = model(input_tensor)
|
| 324 |
+
print(f"\nForward Pass Test: Input Shape {input_tensor.shape}, Output Shape {output_tensor.shape}")
|
| 325 |
+
except Exception as e:
|
| 326 |
+
print(f"\nForward Pass failed: {e}")
|