import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models import create_model


class MultiBand_Gate(nn.Module):
    def __init__(self):
        """
        Per-band gating network.

        No input size needs to be specified; the features of each band are
        extracted independently inside this module. Produces ONE scalar gate
        logit per sample — the logits of all bands are softmax-normalized
        later by ``MultiBand_MoE``.
        """
        super(MultiBand_Gate, self).__init__()
        # Convolutional encoder for the single-band input image.
        self.input_encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (B, 1, H, W) -> (B, 8, H, W)
            nn.InstanceNorm2d(8),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((14, 14)),  # downsample to a fixed 14x14 grid
        )
        # Collapse the encoded features into a single gate logit.
        self.gate_network = nn.Sequential(
            nn.Conv2d(8, 4, kernel_size=3, padding=1),  # (B, 8, 14, 14) -> (B, 4, 14, 14)
            nn.InstanceNorm2d(4),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # global average pool -> (B, 4, 1, 1)
            nn.Flatten(),                  # -> (B, 4)
            nn.Linear(4, 1),               # -> (B, 1)
        )

    def forward(self, input_data):
        """
        Forward pass.

        Parameters:
        - input_data: single-band image, shape (B, 1, H, W)
                      (callers feed 224x224 slices — TODO confirm).

        Returns:
        - gate logit of shape (B, 1). NOTE: the previous docstring's claimed
          shape (B, 1, 224, 224) was wrong — this is a scalar gate per sample.
        """
        input_feat = self.input_encoder(input_data)  # -> (B, 8, 14, 14)
        return self.gate_network(input_feat)         # -> (B, 1)


class MultiTask_Gate(nn.Module):
    def __init__(self, input_dim, num_experts, dropout_rate):
        """
        Multi-task gating network: maps a feature map to a softmax
        distribution over experts.

        Parameters:
        - input_dim: channel dimension of the input feature map.
        - num_experts: number of experts to weight.
        - dropout_rate: dropout probability (currently unused, see note below).
        """
        super(MultiTask_Gate, self).__init__()
        # TODO: additional features could be concatenated into this input.
        self.gate = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            # NOTE: dropout here prevented the model from converging, so it
            # is disabled for now; dropout_rate is kept for API compatibility.
            # nn.Dropout(dropout_rate),
            nn.Linear(128, num_experts),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Global-average-pool x to (B, C), then return (B, num_experts) weights."""
        pooled = F.adaptive_avg_pool2d(x, (1, 1))[:, :, 0, 0]  # (B, C, H, W) -> (B, C)
        return self.gate(pooled)


class MultiBand_MoE(nn.Module):
    def __init__(self, model_name='vit_small_patch16_224', pretrained=False,
                 multi_band_experts=12):
        """
        Mixture-of-Experts front end for multi-band input.

        Parameters:
        - model_name: name of the timm backbone to create.
        - pretrained: whether to load pretrained backbone weights.
        - multi_band_experts: number of input bands (= number of gates).
        """
        super(MultiBand_MoE, self).__init__()
        self.backbone = create_model(model_name=model_name,
                                     pretrained=pretrained,
                                     num_classes=0)
        # Re-wire the patch embedding to accept `multi_band_experts` input
        # channels. 384 is the embed dim of vit_small — TODO confirm if
        # other model_name values are ever used.
        self.backbone.patch_embed.proj = nn.Conv2d(
            multi_band_experts, 384,
            kernel_size=(16, 16), stride=(16, 16), bias=False)
        # One independent scalar gate per band.
        self.multi_band_experts = nn.ModuleList(
            [MultiBand_Gate() for _ in range(multi_band_experts)])

    def forward(self, input_data):
        """
        Forward pass.

        Parameters:
        - input_data: multi-band image, shape (B, num_bands, H, W).

        Returns:
        - feature: backbone feature map reshaped to (B, C, H/16, W/16).
        - multi_band_weights: softmax band weights, shape (B, num_bands).
        """
        num_bands = len(self.multi_band_experts)
        # One (B, 1) logit per band, concatenated to (B, num_bands).
        # (Replaces the old pre-allocated zeros buffer + index assignment:
        # same math, no CPU->device hop, no redundant unsqueeze.)
        band_logits = torch.cat(
            [self.multi_band_experts[i](input_data[:, i:i + 1, :, :])
             for i in range(num_bands)],
            dim=1,
        )
        multi_band_weights = F.softmax(band_logits, dim=1)  # (B, num_bands)
        # Scale each band by its gate weight (broadcast over H, W).
        output_data = input_data * multi_band_weights[:, :, None, None]

        feature = self.backbone.forward_features(output_data)
        # Drop the CLS token, then fold the patch sequence back into a
        # square spatial grid. Assumes N is a perfect square (e.g. 196 for
        # a 224x224 input with 16x16 patches) — TODO confirm for other sizes.
        feature = feature[:, 1:, :]
        B, N, C = feature.shape
        H = W = int(N ** 0.5)
        feature = feature.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return feature, multi_band_weights


class ReconTask_MoE(nn.Module):
    def __init__(self, input_dim, num_experts, dropout_rate):
        """
        Reconstruction MoE head: splits the backbone channels into
        `num_experts` groups, mixes them with learned gate weights, and
        decodes the mixture back to image resolution.

        Parameters:
        - input_dim: channel dimension fed to the gating network.
        - num_experts: number of channel-group experts; must divide 384.
        - dropout_rate: forwarded to MultiTask_Gate (currently unused there).
        """
        super(ReconTask_MoE, self).__init__()
        self.gating = MultiTask_Gate(input_dim, num_experts, dropout_rate)
        self.num_experts = num_experts
        # 384 is the backbone's output channel count; each expert owns an
        # equal slice of it. (Attribute name keeps its historical spelling
        # for compatibility.)
        self.embeding_dim = 384 // num_experts
        # Upsampling chain from the 14x14 ViT grid back to 224x224.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.embeding_dim, 64,
                               kernel_size=4, stride=2, padding=1),  # 14x14 -> 28x28
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # -> 56x56
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # -> 112x112
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 12, kernel_size=4, stride=2, padding=1),  # -> 224x224
        )

    def forward(self, x):
        """
        Parameters:
        - x: feature map, shape (B, 384, h, w).

        Returns:
        - output: decoded image, shape (B, 12, 16*h, 16*w).
        - gate_weights: expert weights, shape (B, num_experts).
        """
        gate_weights = self.gating(x)
        # Each expert is a contiguous channel slice of x.
        expert_outputs = [
            x[:, i * self.embeding_dim:(i + 1) * self.embeding_dim, :, :]
            for i in range(self.num_experts)
        ]
        b, c, h, w = expert_outputs[0].shape
        # Weighted sum of expert slices, gate weight broadcast per sample.
        output = sum(gate_weights[:, i].view(b, 1, 1, 1) * expert_outputs[i]
                     for i in range(self.num_experts))
        output = self.decoder(output)
        return output, gate_weights