Spaces:
Sleeping
Sleeping
Update networks/decoder.py
Browse files- networks/decoder.py +24 -1
networks/decoder.py
CHANGED
|
@@ -5,9 +5,11 @@ import torch.nn as nn
|
|
| 5 |
from torch import Tensor
|
| 6 |
import torch.nn.functional as F
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
class PPM(nn.ModuleList):
|
| 9 |
"""Pooling Pyramid Module used in PSPNet.
|
| 10 |
-
|
| 11 |
Args:
|
| 12 |
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
| 13 |
Module.
|
|
@@ -155,3 +157,24 @@ class Decoder(nn.Module):
|
|
| 155 |
return torch.cat(psp_outs, dim=1)
|
| 156 |
else:
|
| 157 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from torch import Tensor
|
| 6 |
import torch.nn.functional as F
|
| 7 |
|
| 8 |
+
from mmcv.cnn import ConvModule
|
| 9 |
+
from .height_head import resize
|
| 10 |
+
|
| 11 |
class PPM(nn.ModuleList):
|
| 12 |
"""Pooling Pyramid Module used in PSPNet.
|
|
|
|
| 13 |
Args:
|
| 14 |
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
| 15 |
Module.
|
|
|
|
| 157 |
return torch.cat(psp_outs, dim=1)
|
| 158 |
else:
|
| 159 |
return x
|
| 160 |
+
|
| 161 |
+
if __name__ == '__main__':
|
| 162 |
+
# model = Decoder(in_channel=320)
|
| 163 |
+
# input_data = torch.randn(1, 320, 32, 32)
|
| 164 |
+
# res = [torch.randn(1, 512, 32, 32), torch.randn(1, 256, 64, 64), torch.randn(1, 128, 128, 128)]
|
| 165 |
+
# output = model(input_data, res)
|
| 166 |
+
# print(output.shape)
|
| 167 |
+
|
| 168 |
+
model = Decoder(in_channel=320, short_cut_channels=None)
|
| 169 |
+
# input_data = torch.randn(1, 320, 32, 32)
|
| 170 |
+
# output = model(input_data, None)
|
| 171 |
+
# flops, params = get_model_complexity_info(model, (320, 32, 32))
|
| 172 |
+
# print(f"参数量: {params}")
|
| 173 |
+
# print(f"计算量: {flops}")
|
| 174 |
+
# print("-" * 30)
|
| 175 |
+
# print(output.shape)
|
| 176 |
+
|
| 177 |
+
# model = Decoder(in_channel=320, short_cut_channels=None, psp_channel=-1)
|
| 178 |
+
# input_data = torch.randn(2, 320, 32, 32)
|
| 179 |
+
# output = model(input_data, None)
|
| 180 |
+
# print(output.shape)
|