Update models/monet.py
Browse files
- models/monet.py +17 -6
models/monet.py
CHANGED
|
@@ -69,7 +69,7 @@ class MAL(nn.Module):
|
|
| 69 |
Multi-view Attention Learning (MAL) module
|
| 70 |
"""
|
| 71 |
|
| 72 |
-
def __init__(self, in_dim=768, feature_num=4, feature_size=28):
|
| 73 |
super().__init__()
|
| 74 |
|
| 75 |
self.channel_attention = Attention_Block(in_dim * feature_num) # Channel-wise self attention
|
|
@@ -82,9 +82,14 @@ class MAL(nn.Module):
|
|
| 82 |
|
| 83 |
self.feature_num = feature_num
|
| 84 |
self.in_dim = in_dim
|
|
|
|
| 85 |
|
| 86 |
def forward(self, features):
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
for index, _ in enumerate(features):
|
| 89 |
feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
|
| 90 |
features = feature
|
|
@@ -118,7 +123,7 @@ class SaveOutput:
|
|
| 118 |
|
| 119 |
|
| 120 |
class MoNet(nn.Module):
|
| 121 |
-
def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
|
| 122 |
super().__init__()
|
| 123 |
self.img_size = img_size
|
| 124 |
self.input_size = img_size // patch_size
|
|
@@ -136,10 +141,10 @@ class MoNet(nn.Module):
|
|
| 136 |
|
| 137 |
self.MALs = nn.ModuleList()
|
| 138 |
for _ in range(config.mal_num):
|
| 139 |
-
self.MALs.append(MAL())
|
| 140 |
|
| 141 |
# Image Quality Score Regression
|
| 142 |
-
self.fusion_wam = MAL(feature_num=config.mal_num)
|
| 143 |
self.block = Block(dim_mlp, 12)
|
| 144 |
self.cnn = nn.Sequential(
|
| 145 |
nn.Conv2d(dim_mlp, 256, 5),
|
|
@@ -163,6 +168,8 @@ class MoNet(nn.Module):
|
|
| 163 |
nn.Sigmoid()
|
| 164 |
)
|
| 165 |
|
|
|
|
|
|
|
| 166 |
def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
|
| 167 |
x1 = save_output.outputs[block_index[0]][:, 1:]
|
| 168 |
x2 = save_output.outputs[block_index[1]][:, 1:]
|
|
@@ -182,7 +189,11 @@ class MoNet(nn.Module):
|
|
| 182 |
x = x.permute(1, 0, 2, 3, 4) # bs, 4, 768, 28 * 28
|
| 183 |
|
| 184 |
# Different Opinion Features (DOF)
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
for index, _ in enumerate(self.MALs):
|
| 187 |
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
|
| 188 |
DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # 3, bs, 768, 28, 28
|
|
|
|
| 69 |
Multi-view Attention Learning (MAL) module
|
| 70 |
"""
|
| 71 |
|
| 72 |
+
def __init__(self, in_dim=768, feature_num=4, feature_size=28, is_gpu=True):
|
| 73 |
super().__init__()
|
| 74 |
|
| 75 |
self.channel_attention = Attention_Block(in_dim * feature_num) # Channel-wise self attention
|
|
|
|
| 82 |
|
| 83 |
self.feature_num = feature_num
|
| 84 |
self.in_dim = in_dim
|
| 85 |
+
self.is_gpu = is_gpu
|
| 86 |
|
| 87 |
def forward(self, features):
|
| 88 |
+
if self.is_gpu:
|
| 89 |
+
feature = torch.tensor([]).cuda()
|
| 90 |
+
else:
|
| 91 |
+
feature = torch.tensor([])
|
| 92 |
+
|
| 93 |
for index, _ in enumerate(features):
|
| 94 |
feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
|
| 95 |
features = feature
|
|
|
|
| 123 |
|
| 124 |
|
| 125 |
class MoNet(nn.Module):
|
| 126 |
+
def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224, is_gpu=True):
|
| 127 |
super().__init__()
|
| 128 |
self.img_size = img_size
|
| 129 |
self.input_size = img_size // patch_size
|
|
|
|
| 141 |
|
| 142 |
self.MALs = nn.ModuleList()
|
| 143 |
for _ in range(config.mal_num):
|
| 144 |
+
self.MALs.append(MAL(is_gpu=is_gpu))
|
| 145 |
|
| 146 |
# Image Quality Score Regression
|
| 147 |
+
self.fusion_wam = MAL(feature_num=config.mal_num, is_gpu=is_gpu)
|
| 148 |
self.block = Block(dim_mlp, 12)
|
| 149 |
self.cnn = nn.Sequential(
|
| 150 |
nn.Conv2d(dim_mlp, 256, 5),
|
|
|
|
| 168 |
nn.Sigmoid()
|
| 169 |
)
|
| 170 |
|
| 171 |
+
self.is_gpu = is_gpu
|
| 172 |
+
|
| 173 |
def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
|
| 174 |
x1 = save_output.outputs[block_index[0]][:, 1:]
|
| 175 |
x2 = save_output.outputs[block_index[1]][:, 1:]
|
|
|
|
| 189 |
x = x.permute(1, 0, 2, 3, 4) # bs, 4, 768, 28 * 28
|
| 190 |
|
| 191 |
# Different Opinion Features (DOF)
|
| 192 |
+
if self.is_gpu:
|
| 193 |
+
DOF = torch.tensor([]).cuda()
|
| 194 |
+
else:
|
| 195 |
+
DOF = torch.tensor([])
|
| 196 |
+
|
| 197 |
for index, _ in enumerate(self.MALs):
|
| 198 |
DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
|
| 199 |
DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size) # 3, bs, 768, 28, 28
|