PubAccount commited on
Commit
fbefebb
·
verified ·
1 Parent(s): ead0a80

Update networks/decoder.py

Browse files
Files changed (1) hide show
  1. networks/decoder.py +24 -1
networks/decoder.py CHANGED
@@ -5,9 +5,11 @@ import torch.nn as nn
5
  from torch import Tensor
6
  import torch.nn.functional as F
7
 
 
 
 
8
  class PPM(nn.ModuleList):
9
  """Pooling Pyramid Module used in PSPNet.
10
-
11
  Args:
12
  pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
13
  Module.
@@ -155,3 +157,24 @@ class Decoder(nn.Module):
155
  return torch.cat(psp_outs, dim=1)
156
  else:
157
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from torch import Tensor
6
  import torch.nn.functional as F
7
 
8
+ from mmcv.cnn import ConvModule
9
+ from .height_head import resize
10
+
11
  class PPM(nn.ModuleList):
12
  """Pooling Pyramid Module used in PSPNet.
 
13
  Args:
14
  pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
15
  Module.
 
157
  return torch.cat(psp_outs, dim=1)
158
  else:
159
  return x
160
+
161
+ if __name__ == '__main__':
162
+ # model = Decoder(in_channel=320)
163
+ # input_data = torch.randn(1, 320, 32, 32)
164
+ # res = [torch.randn(1, 512, 32, 32), torch.randn(1, 256, 64, 64), torch.randn(1, 128, 128, 128)]
165
+ # output = model(input_data, res)
166
+ # print(output.shape)
167
+
168
+ model = Decoder(in_channel=320, short_cut_channels=None)
169
+ # input_data = torch.randn(1, 320, 32, 32)
170
+ # output = model(input_data, None)
171
+ # flops, params = get_model_complexity_info(model, (320, 32, 32))
172
+ # print(f"参数量: {params}")
173
+ # print(f"计算量: {flops}")
174
+ # print("-" * 30)
175
+ # print(output.shape)
176
+
177
+ # model = Decoder(in_channel=320, short_cut_channels=None, psp_channel=-1)
178
+ # input_data = torch.randn(2, 320, 32, 32)
179
+ # output = model(input_data, None)
180
+ # print(output.shape)