import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models import create_model
class MultiBand_Gate(nn.Module):
    """Per-band gating network.

    Encodes a single-band image and emits one scalar gate logit per
    sample. No sizes need to be passed in: adaptive pooling makes the
    encoder spatial-size agnostic.
    """

    def __init__(self):
        super().__init__()
        # Feature encoder: (B, 1, H, W) -> (B, 8, 14, 14).
        self.input_encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.InstanceNorm2d(8),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((14, 14)),
        )
        # Gate head: (B, 8, 14, 14) -> (B, 1) scalar logit.
        self.gate_network = nn.Sequential(
            nn.Conv2d(8, 4, kernel_size=3, padding=1),
            nn.InstanceNorm2d(4),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),  # global average pool -> (B, 4, 1, 1)
            nn.Flatten(),                  # -> (B, 4)
            nn.Linear(4, 1),               # -> (B, 1)
        )

    def forward(self, input_data):
        """Return a (B, 1) gate logit for a (B, 1, H, W) band image.

        Parameters:
        - input_data: single-band image tensor of shape (B, 1, H, W).

        Returns:
        - gate logit tensor of shape (B, 1) (unnormalized; callers
          softmax across bands).
        """
        features = self.input_encoder(input_data)
        return self.gate_network(features)
class MultiTask_Gate(nn.Module):
    """Task-level gate: pools a feature map and softmaxes over experts.

    Parameters:
    - input_dim: channel count of the incoming feature map.
    - num_experts: number of experts to produce weights for.
    - dropout_rate: accepted for interface compatibility; the dropout
      layer is currently disabled (see note below).
    """

    def __init__(self, input_dim, num_experts, dropout_rate):
        super().__init__()
        # TODO: extra features could be concatenated into this input.
        self.gate = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            # nn.Dropout(dropout_rate),  # NOTE: dropout prevented convergence; disabled for now.
            nn.Linear(128, num_experts),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Map (B, C, H, W) features to (B, num_experts) softmax weights."""
        # Global average pool (B, C, H, W) -> (B, C) before the MLP gate.
        pooled = F.adaptive_avg_pool2d(x, (1, 1))[:, :, 0, 0]
        return self.gate(pooled)
class MultiBand_MoE(nn.Module):
    """Mixture-of-experts over input bands feeding a ViT backbone.

    Each band has its own MultiBand_Gate producing a scalar logit; the
    softmaxed logits re-weight the bands before the backbone consumes
    them.

    Parameters:
    - model_name: timm model to build as the backbone.
    - pretrained: whether to load pretrained backbone weights.
    - multi_band_experts: number of input bands (one gate per band).
    """

    def __init__(self, model_name='vit_small_patch16_224', pretrained=False, multi_band_experts=12):
        super().__init__()
        self.backbone = create_model(model_name=model_name, pretrained=pretrained, num_classes=0)
        # Swap the patch embedding so the backbone accepts `multi_band_experts`
        # input channels. NOTE(review): 384 matches vit_small's embed dim —
        # assumes model_name stays a vit_small variant; confirm before changing it.
        self.backbone.patch_embed.proj = nn.Conv2d(
            multi_band_experts, 384, kernel_size=(16, 16), stride=(16, 16), bias=False
        )
        self.multi_band_experts = nn.ModuleList(
            MultiBand_Gate() for _ in range(multi_band_experts)
        )

    def forward(self, input_data):
        """Gate the bands, run the backbone, and return a spatial feature map.

        Parameters:
        - input_data: tensor of shape (B, num_bands, H, W).

        Returns:
        - feature: backbone feature map of shape (B, C, H', W').
        - band weights: softmaxed per-band weights of shape (B, num_bands).
        """
        batch = input_data.size(0)
        num_bands = len(self.multi_band_experts)
        weights = torch.zeros((batch, num_bands, 1, 1)).to(input_data.device)
        for band, expert in enumerate(self.multi_band_experts):
            logit = expert(input_data[:, band:band + 1, :, :])  # (B, 1)
            weights[:, band, :, :] = logit.unsqueeze(1)
        weights = F.softmax(weights, dim=1)
        gated = input_data * weights
        feature = self.backbone.forward_features(gated)
        # Drop the CLS token, then fold the patch sequence back into a square
        # spatial map. TODO(original): verify this reshape is correct.
        feature = feature[:, 1:, :]
        B, N, C = feature.shape
        side = int(N ** 0.5)
        feature = feature.reshape(B, side, side, C).permute(0, 3, 1, 2)
        return feature, weights[:, :, 0, 0]
class ReconTask_MoE(nn.Module):
    """Reconstruction head mixing channel-chunk "experts" before decoding.

    The 384 backbone channels are split into `num_experts` equal chunks;
    a MultiTask_Gate yields per-sample mixing weights, and the weighted
    sum of chunks is upsampled by a transposed-conv decoder to 12 bands.

    Parameters:
    - input_dim: channel count of the incoming feature map (the gate's input).
    - num_experts: how many equal channel chunks to mix.
    - dropout_rate: forwarded to the gate (currently unused there).
    """

    def __init__(self, input_dim, num_experts, dropout_rate):
        super().__init__()
        self.gating = MultiTask_Gate(input_dim, num_experts, dropout_rate)
        self.num_experts = num_experts
        # 384 is the backbone's output channel count, split evenly per expert.
        self.embeding_dim = int(384 / num_experts)
        # Decoder: 4 stride-2 transposed convs, each doubling the spatial size
        # (e.g. 14x14 -> 224x224), ending at 12 output bands.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.embeding_dim, 64, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 12, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        """Mix channel chunks by gate weight and decode.

        Parameters:
        - x: feature map of shape (B, 384, H, W) with input_dim == 384.

        Returns:
        - decoded output of shape (B, 12, 16*H, 16*W).
        - gate weights of shape (B, num_experts).
        """
        gate_weights = self.gating(x)
        batch = x.size(0)
        step = self.embeding_dim
        # Weighted sum of the per-expert channel chunks.
        mixed = sum(
            gate_weights[:, i].view(batch, 1, 1, 1) * x[:, i * step:(i + 1) * step, :, :]
            for i in range(self.num_experts)
        )
        return self.decoder(mixed), gate_weights
|