| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from timm.models import create_model |
|
|
|
|
class MultiBand_Gate(nn.Module):
    """Per-band gating network.

    Encodes a single-band image and produces one un-normalised gate
    logit per sample. The caller is expected to apply softmax across
    bands (see how the weights are normalised downstream).
    """

    def __init__(self):
        """Build the encoder and gate head.

        No input size needs to be specified: adaptive pooling
        normalises the spatial dimensions internally.
        """
        super(MultiBand_Gate, self).__init__()

        # Shallow conv encoder: (B, 1, H, W) -> (B, 8, 14, 14).
        self.input_encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.InstanceNorm2d(8),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((14, 14))
        )

        # Gate head: (B, 8, 14, 14) -> scalar logit (B, 1).
        self.gate_network = nn.Sequential(
            nn.Conv2d(8, 4, kernel_size=3, padding=1),
            nn.InstanceNorm2d(4),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(4, 1),
        )

    def forward(self, input_data):
        """Compute the gate logit for one band.

        Args:
            input_data: tensor of shape (B, 1, H, W) holding one band
                per sample.

        Returns:
            Gate logit of shape (B, 1). Un-normalised; softmax over
            bands is applied by the caller, not here.
        """
        input_feat = self.input_encoder(input_data)
        return self.gate_network(input_feat)
|
|
| |
class MultiTask_Gate(nn.Module):
    """Soft gating over experts.

    Global-average-pools a feature map to a vector and maps it to a
    softmax distribution over ``num_experts``.
    """

    def __init__(self, input_dim, num_experts, dropout_rate):
        """Build the gating MLP.

        Args:
            input_dim: channel dimension of the input feature map.
            num_experts: number of experts to produce weights for.
            dropout_rate: dropout probability inside the gate. (This
                parameter was previously accepted but never used; it is
                now applied between the hidden and output layers.)
        """
        super(MultiTask_Gate, self).__init__()

        self.gate = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            # Fix: dropout_rate was silently ignored before.
            nn.Dropout(dropout_rate),
            nn.Linear(128, num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Pool x from (B, C, H, W) to (B, C), then gate.

        Returns:
            (B, num_experts) softmax weights; each row sums to 1.
        """
        x = F.adaptive_avg_pool2d(x, (1, 1))[:, :, 0, 0]
        return self.gate(x)
| |
| |
class MultiBand_MoE(nn.Module):
    """ViT backbone with a learned per-band gate on the input channels.

    Each input band is scored by its own ``MultiBand_Gate``; the scores
    are softmax-normalised across bands and used to re-weight the bands
    before they enter the backbone.
    """

    def __init__(self, model_name='vit_small_patch16_224', pretrained=False, multi_band_experts=12):
        """Build the backbone and one gate per band.

        Args:
            model_name: timm model name for the backbone.
            pretrained: whether to load pretrained backbone weights.
            multi_band_experts: number of input bands / per-band gates.
        """
        super(MultiBand_MoE, self).__init__()

        self.backbone = create_model(model_name=model_name, pretrained=pretrained, num_classes=0)
        # Replace the patch embedding so the backbone accepts
        # `multi_band_experts` input channels. NOTE(review): 384 is the
        # embed dim of vit_small; a different model_name would need a
        # different value — confirm before changing the default.
        self.backbone.patch_embed.proj = nn.Conv2d(multi_band_experts, 384, kernel_size=(16, 16), stride=(16, 16), bias=False)

        self.multi_band_experts = nn.ModuleList([MultiBand_Gate() for _ in range(multi_band_experts)])

    def forward(self, input_data):
        """Gate the bands, run the backbone, reshape tokens to a grid.

        Args:
            input_data: (B, num_bands, H, W) multi-band image.

        Returns:
            feature: (B, C, H', W') spatial feature map from the
                backbone (CLS token removed, patch tokens re-gridded).
            multi_band_weights: (B, num_bands) softmax band weights.
        """
        # One gate logit per band. Building the tensor with cat instead
        # of in-place indexed writes avoids the pre-allocated buffer and
        # the explicit .to(device) of the original loop.
        logits = [gate(input_data[:, i:i + 1, :, :])  # each (B, 1)
                  for i, gate in enumerate(self.multi_band_experts)]
        multi_band_weights = F.softmax(torch.cat(logits, dim=1), dim=1)  # (B, num_bands)

        # Re-weight each band; weights broadcast over H and W.
        output_data = input_data * multi_band_weights[:, :, None, None]

        feature = self.backbone.forward_features(output_data)

        # Drop the CLS token, then fold the patch-token sequence back
        # into a square spatial grid (assumes N is a perfect square).
        feature = feature[:, 1:, :]
        B, N, C = feature.shape
        H = W = int(N ** 0.5)
        feature = feature.reshape(B, H, W, C)
        feature = feature.permute(0, 3, 1, 2)

        return feature, multi_band_weights
| |
class ReconTask_MoE(nn.Module):
    """Reconstruction head with channel-slice experts.

    Splits a 384-channel feature map into ``num_experts`` channel
    slices, mixes them with weights from a ``MultiTask_Gate``, and
    decodes the mixture with transposed convolutions (16x spatial
    upsampling down to 12 output channels).
    """

    def __init__(self, input_dim, num_experts, dropout_rate):
        """Build the gate and the transposed-conv decoder.

        Args:
            input_dim: channel dimension fed to the gating network.
            num_experts: number of channel-slice experts; must divide
                384 evenly.
            dropout_rate: dropout rate forwarded to the gate.

        Raises:
            ValueError: if 384 is not divisible by num_experts. (The
                original used int(384 / num_experts), which silently
                dropped the trailing channels in that case.)
        """
        super(ReconTask_MoE, self).__init__()

        if 384 % num_experts != 0:
            raise ValueError(f"num_experts ({num_experts}) must divide 384 evenly")

        self.gating = MultiTask_Gate(input_dim, num_experts, dropout_rate)
        self.num_experts = num_experts
        # Channels handled by each expert slice. (Attribute name kept
        # as-is — including the typo — for caller compatibility.)
        self.embeding_dim = 384 // num_experts

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.embeding_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 12, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        """Mix expert channel slices and decode.

        Args:
            x: (B, 384, H, W) feature map.

        Returns:
            output: decoded map of shape (B, 12, 16*H, 16*W).
            gate_weights: (B, num_experts) expert mixing weights.
        """
        gate_weights = self.gating(x)  # (B, num_experts)

        # Split the channel dimension into one equal chunk per expert.
        expert_outputs = torch.split(x, self.embeding_dim, dim=1)

        b = x.size(0)
        # Weighted sum of slices; each weight broadcasts over C, H, W.
        output = sum(gate_weights[:, i].view(b, 1, 1, 1) * expert_outputs[i]
                     for i in range(self.num_experts))
        output = self.decoder(output)

        return output, gate_weights
|
|