File size: 6,840 Bytes
d24fe95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models import create_model


class MultiBand_Gate(nn.Module):
    """Per-band gate: maps one single-channel image to a single raw gate score.

    No input size needs to be specified: the encoder adaptively pools to a
    fixed 14x14 grid, so any (B, 1, H, W) input works.
    """

    def __init__(self):
        """Build the per-band feature encoder and the scalar gate head."""
        super(MultiBand_Gate, self).__init__()

        # Convolutional encoder for the single-band input image.
        self.input_encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (B, 1, H, W) -> (B, 8, H, W)
            nn.InstanceNorm2d(8),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((14, 14)),             # -> (B, 8, 14, 14)
        )

        # Head that turns the pooled features into one gate score per sample.
        self.gate_network = nn.Sequential(
            nn.Conv2d(8, 4, kernel_size=3, padding=1),  # (B, 8, 14, 14) -> (B, 4, 14, 14)
            nn.InstanceNorm2d(4),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),               # global average pool -> (B, 4, 1, 1)
            nn.Flatten(),                                # -> (B, 4)
            nn.Linear(4, 1),                             # -> (B, 1)
        )

    def forward(self, input_data):
        """Compute the raw (un-normalized) gate score for one band.

        Args:
            input_data: single-band image tensor of shape (B, 1, H, W).

        Returns:
            gate_weight: tensor of shape (B, 1); callers normalize across
            bands (e.g. with softmax).
        """
        input_feat = self.input_encoder(input_data)  # -> (B, 8, 14, 14)
        gate_weight = self.gate_network(input_feat)  # -> (B, 1)

        return gate_weight

    
class MultiTask_Gate(nn.Module):
    """Gating head: pools a feature map globally and emits softmax expert weights."""

    def __init__(self, input_dim, num_experts, dropout_rate):
        """Build the gating MLP.

        Args:
            input_dim: channel count of the incoming feature map.
            num_experts: number of experts to produce weights for.
            dropout_rate: dropout probability (currently unused — dropout made
                training fail to converge, so it is left disabled).
        """
        super(MultiTask_Gate, self).__init__()

        # TODO: additional features could be concatenated into this input.
        self.gate = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            nn.Linear(128, num_experts),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return (B, num_experts) softmax weights from a (B, C, H, W) feature map."""
        pooled = F.adaptive_avg_pool2d(x, (1, 1))[:, :, 0, 0]  # global average pooling
        weights = self.gate(pooled)
        return weights
      
      
class MultiBand_MoE(nn.Module):
    """Mixture-of-Experts over input bands.

    One independent gate per band produces a score; the softmax over bands
    re-weights the input channels before a ViT backbone extracts features.
    """

    def __init__(self, model_name='vit_small_patch16_224', pretrained=False, multi_band_experts=12):
        """Build the backbone and the per-band gates.

        Args:
            model_name: timm model name for the feature backbone.
            pretrained: whether to load pretrained backbone weights.
            multi_band_experts: number of input bands (= number of gates).
        """
        super(MultiBand_MoE, self).__init__()

        self.backbone = create_model(model_name=model_name, pretrained=pretrained, num_classes=0)
        # Replace the patch embedding so the backbone accepts all bands as
        # input channels (384 is the ViT-small embedding dimension).
        self.backbone.patch_embed.proj = nn.Conv2d(multi_band_experts, 384, kernel_size=(16, 16), stride=(16, 16), bias=False)

        # One independent gate per band.
        self.multi_band_experts = nn.ModuleList([MultiBand_Gate() for _ in range(multi_band_experts)])

    def forward(self, input_data):
        """Re-weight bands with the gates, then extract backbone features.

        Args:
            input_data: multi-band image of shape (B, num_bands, H, W).

        Returns:
            feature: backbone feature map of shape (B, C, H', W'), where
                H' = W' = sqrt(num_patches).
            multi_band_weights: softmax-normalized per-band weights of
                shape (B, num_bands).
        """
        # Each gate scores only its own band: (B, 1) per band.
        band_scores = [
            gate(input_data[:, i:i + 1, :, :])
            for i, gate in enumerate(self.multi_band_experts)
        ]
        # Stack to (B, num_bands, 1, 1) and normalize across bands.
        multi_band_weights = torch.stack(band_scores, dim=1).unsqueeze(-1)
        multi_band_weights = F.softmax(multi_band_weights, dim=1)

        # Re-weight each band before the backbone sees it.
        output_data = input_data * multi_band_weights

        feature = self.backbone.forward_features(output_data)

        # Drop the class token and fold the patch sequence back into a square
        # 2-D feature map. NOTE(review): assumes the patch count is a perfect
        # square (true for square inputs with square patches) — confirm for
        # non-224 inputs.
        feature = feature[:, 1:, :]
        B, N, C = feature.shape
        H = W = int(N ** 0.5)
        feature = feature.reshape(B, H, W, C)
        feature = feature.permute(0, 3, 1, 2)

        return feature, multi_band_weights[:, :, 0, 0]
        
class ReconTask_MoE(nn.Module):
    """Reconstruction MoE head.

    Splits the backbone feature map into equal per-expert channel slices,
    mixes them with learned gate weights, and decodes the mixture back to a
    12-channel (12-band) image.
    """

    def __init__(self, input_dim, num_experts, dropout_rate):
        """Build the gate and the upsampling decoder.

        Args:
            input_dim: channel count of the incoming feature map.
            num_experts: number of experts; should divide 384 evenly so the
                slices cover all channels.
            dropout_rate: forwarded to the gating network (unused there).
        """
        super(ReconTask_MoE, self).__init__()

        self.gating = MultiTask_Gate(input_dim, num_experts, dropout_rate)

        self.num_experts = num_experts

        # 384 is the backbone's output channel count; each expert owns an
        # equal contiguous slice of it.
        self.embeding_dim = 384 // num_experts

        # Upsampling decoder: each transposed conv doubles the spatial size,
        # e.g. a 14x14 input goes 14 -> 28 -> 56 -> 112 -> 224.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.embeding_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 12, kernel_size=4, stride=2, padding=1),  # 12 output bands
        )

    def forward(self, x):
        """Mix expert channel slices by gate weight, then decode.

        Args:
            x: backbone feature map of shape (B, 384, H, W).

        Returns:
            output: decoded image of shape (B, 12, 16*H, 16*W).
            gate_weights: per-expert mixing weights of shape (B, num_experts).
        """
        gate_weights = self.gating(x)

        # Slice the channel axis into one equal chunk per expert.
        expert_outputs = [
            x[:, i * self.embeding_dim:(i + 1) * self.embeding_dim, :, :]
            for i in range(self.num_experts)
        ]

        # Gate-weighted sum of the expert slices, then decode to image space.
        b = x.size(0)
        output = sum(
            gate_weights[:, i].view(b, 1, 1, 1) * expert_outputs[i]
            for i in range(self.num_experts)
        )
        output = self.decoder(output)

        return output, gate_weights