# AstroM3 / ExampleCode / example1 / model / JplusEncoder.py
# (uploaded by lvjiameng, "Upload 21 files", commit d24fe95 verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models import create_model
class MultiBand_Gate(nn.Module):
    """Per-band gate: maps a single-channel image to one scalar gate logit.

    No input-size argument is needed; features are extracted internally
    from the image alone. (A PSF input mentioned in an earlier revision's
    comments is no longer part of the interface.)
    """

    def __init__(self):
        super().__init__()
        # Lightweight conv encoder for one band:
        # (B, 1, H, W) -> (B, 8, 14, 14) via adaptive pooling.
        self.input_encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.InstanceNorm2d(8),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((14, 14)),
        )
        # Collapse the encoded features to a single gate logit:
        # (B, 8, 14, 14) -> (B, 1).
        self.gate_network = nn.Sequential(
            nn.Conv2d(8, 4, kernel_size=3, padding=1),
            nn.InstanceNorm2d(4),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(4, 1),
        )

    def forward(self, input_data):
        """Compute the gate logit for one band.

        Args:
            input_data: single-band image tensor of shape (B, 1, H, W).

        Returns:
            Tensor of shape (B, 1): one unnormalized gate value per sample.
        """
        encoded = self.input_encoder(input_data)
        return self.gate_network(encoded)
class MultiTask_Gate(nn.Module):
    """Gating head producing a softmax distribution over experts.

    Globally average-pools a feature map, then maps the pooled vector
    through a small MLP to per-expert mixing weights.
    """

    def __init__(self, input_dim, num_experts, dropout_rate):
        """
        Args:
            input_dim: channel count of the incoming feature map.
            num_experts: number of experts to weigh.
            dropout_rate: kept for interface compatibility; dropout is
                currently disabled because it hurt convergence.
        """
        super().__init__()
        # TODO: additional features could be injected into this input.
        self.gate = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            # nn.Dropout(dropout_rate),  # disabled: prevented convergence
            nn.Linear(128, num_experts),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return expert weights of shape (B, num_experts) for feature map x."""
        # Global average pool (B, C, H, W) -> (B, C), then gate.
        pooled = F.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)
        return self.gate(pooled)
class MultiBand_MoE(nn.Module):
    """Mixture-of-Experts front end for multi-band image input.

    Each band gets its own ``MultiBand_Gate`` expert that scores that band;
    the softmax-normalized scores reweight the bands before they enter a
    ViT backbone whose patch embedding is rewired to accept
    ``multi_band_experts`` input channels.
    """

    def __init__(self, model_name='vit_small_patch16_224', pretrained=False, multi_band_experts=12):
        """
        Args:
            model_name: timm model name for the backbone.
            pretrained: whether to load pretrained backbone weights.
            multi_band_experts: number of input bands (= number of gates).
        """
        super().__init__()
        self.backbone = create_model(model_name=model_name, pretrained=pretrained, num_classes=0)
        # Re-plumb the patch embedding so the backbone accepts one channel
        # per band; 384 is the ViT-small embedding width.
        self.backbone.patch_embed.proj = nn.Conv2d(
            multi_band_experts, 384, kernel_size=(16, 16), stride=(16, 16), bias=False
        )
        self.multi_band_experts = nn.ModuleList(
            MultiBand_Gate() for _ in range(multi_band_experts)
        )

    def forward(self, input_data):
        """Reweight bands by their gates, then extract backbone features.

        Args:
            input_data: tensor of shape (B, num_bands, H, W).

        Returns:
            feature: patch-token feature map of shape (B, C, H', W'),
                obtained by reshaping tokens into a square grid.
            band weights: per-band softmax weights, shape (B, num_bands).
        """
        # Score each band with its own gate: list of (B, 1) logits.
        logits = [
            expert(input_data[:, idx:idx + 1, :, :])
            for idx, expert in enumerate(self.multi_band_experts)
        ]
        # Stack to (B, num_bands, 1, 1) and normalize across bands.
        band_weights = F.softmax(torch.stack(logits, dim=1).unsqueeze(-1), dim=1)
        weighted = input_data * band_weights

        tokens = self.backbone.forward_features(weighted)
        # TODO(review): reshape of patch tokens below assumes a square
        # token grid; kept from the original, correctness unverified.
        tokens = tokens[:, 1:, :]  # drop the class token
        batch, num_tokens, channels = tokens.shape
        side = int(num_tokens ** 0.5)
        feature = tokens.reshape(batch, side, side, channels).permute(0, 3, 1, 2)
        return feature, band_weights[:, :, 0, 0]
class ReconTask_MoE(nn.Module):
    """Reconstruction head that mixes channel-sliced experts and decodes.

    The 384-channel backbone feature map is split channel-wise into
    ``num_experts`` equal slices ("experts"); a gating network blends the
    slices, and a transposed-conv decoder upsamples the blend 16x into a
    12-channel output image.
    """

    def __init__(self, input_dim, num_experts, dropout_rate):
        """
        Args:
            input_dim: channel count fed to the gating network.
            num_experts: number of channel slices to mix (should divide 384).
            dropout_rate: forwarded to the gating network.
        """
        super().__init__()
        self.gating = MultiTask_Gate(input_dim, num_experts, dropout_rate)
        self.num_experts = num_experts
        self.embeding_dim = int(384 / num_experts)  # channels per expert slice; 384 = backbone width
        # Four stride-2 transposed convs upsample 16x, e.g. 14x14 -> 224x224.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.embeding_dim, 64, kernel_size=4, stride=2, padding=1),  # 2x
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 4x
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 8x
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 12, kernel_size=4, stride=2, padding=1),  # 16x, 12 bands out
        )

    def forward(self, x):
        """Blend expert channel slices by gate weight and decode.

        Args:
            x: feature map of shape (B, 384, H, W).

        Returns:
            output: decoded tensor of shape (B, 12, 16*H, 16*W).
            gate_weights: mixing weights of shape (B, num_experts).
        """
        gate_weights = self.gating(x)
        step = self.embeding_dim
        batch = x.size(0)
        # Weighted sum of the channel slices, broadcast per sample.
        mixed = torch.zeros_like(x[:, :step, :, :])
        for k in range(self.num_experts):
            slice_k = x[:, k * step:(k + 1) * step, :, :]
            mixed = mixed + gate_weights[:, k].view(batch, 1, 1, 1) * slice_k
        return self.decoder(mixed), gate_weights