import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from mamba_ssm import Mamba
from utils import FDS
from torchvision.models import resnet18


class MambaModel(nn.Module):
    def __init__(self, d_model, max_length=30):
        super(MambaModel, self).__init__()
        self.linear = nn.Linear(in_features=21, out_features=d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_length)
        self.mamba = Mamba(d_model=d_model, d_state=32, expand=4)
        self.global_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor):
        x = self.pos_encoder(self.linear(x))
        y = self.mamba(x)
        y_flip = self.mamba(x.flip([-2])).flip([-2])
        y = torch.cat((y, y_flip), dim=-1)
        y = self.global_pool(y.permute(0, 2, 1)).squeeze(-1)
        return y


class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout_rate=0.1):
        super(MLP, self).__init__()
        if isinstance(hidden_dim, int):
            hidden_dim = [hidden_dim] * num_layers
        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim[0]))
        layers.append(nn.ReLU())
        layers.append(nn.Dropout(dropout_rate))
        for i in range(len(hidden_dim) - 1):
            layers.append(nn.Linear(hidden_dim[i], hidden_dim[i + 1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
        layers.append(nn.Linear(hidden_dim[-1], output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=50):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-torch.log(torch.FloatTensor([10000.0])) / d_model))  # (d_model/2,)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        x: (B, N, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return x


class MHAModel(nn.Module):
    def __init__(self, d_model, max_length=50):
        super(MHAModel, self).__init__()
        self.linear = nn.Linear(in_features=21, out_features=d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_length)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
        self.global_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor):
        # linear projection + positional encoding
        x = self.pos_encoder(self.linear(x))  # [batch, seq_len, d_model]
        # forward-direction self-attention
        y, _ = self.self_attn(x, x, x)  # [batch, seq_len, d_model]
        # reversed-direction self-attention
        x_flip = x.flip([-2])  # flip along the sequence dimension
        y_flip, _ = self.self_attn(x_flip, x_flip, x_flip)
        y_flip = y_flip.flip([-2])  # flip back to the original order
        # concatenate forward and reversed outputs
        y = torch.cat((y, y_flip), dim=-1)  # [batch, seq_len, 2*d_model]
        # global pooling
        y = self.global_pool(y.permute(0, 2, 1))  # [batch, 2*d_model, 1]
        return y.squeeze(-1)  # [batch, 2*d_model]


class MLAModel(nn.Module):
    def __init__(self, d_model, max_length=50):
        super(MLAModel, self).__init__()
        self.linear = nn.Linear(in_features=21, out_features=d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_length)
        self.MLA = MLA(d_model, n_heads=8, max_len=max_length)
        self.global_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x: torch.Tensor):
        x = self.pos_encoder(self.linear(x))
        y = self.MLA(x)
        y_flip = self.MLA(x.flip([-2])).flip([-2])
        y = torch.cat((y, y_flip), dim=-1)
        y = self.global_pool(y.permute(0, 2, 1)).squeeze(-1)
        return y


class MLA(nn.Module):
    def __init__(self, d_model, n_heads, max_len=50, rope_theta=10000.0):
        super().__init__()
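        # Descriptive note (added for clarity; the dimension choices below are
        # this file's own): this module follows the Multi-head Latent Attention
        # scheme in the spirit of DeepSeek-V2. Queries and keys/values pass
        # through low-rank ("LoRA-style") down/up projections so the KV path
        # can be compressed, and positional information travels in a decoupled
        # RoPE slice. Each per-head dimension `dh` is split into a content
        # half (qk_nope_dim) and a rotary half (qk_rope_dim).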
        self.d_model = d_model
        self.n_heads = n_heads
        self.dh = d_model // n_heads
        self.q_proj_dim = d_model // 2
        self.kv_proj_dim = (2 * d_model) // 3
        self.qk_nope_dim = self.dh // 2
        self.qk_rope_dim = self.dh // 2

        ## Q projections
        # LoRA-style low-rank factorization
        self.W_dq = nn.Parameter(0.01 * torch.randn((d_model, self.q_proj_dim)))
        self.W_uq = nn.Parameter(0.01 * torch.randn((self.q_proj_dim, self.d_model)))
        self.q_layernorm = nn.LayerNorm(self.q_proj_dim)

        ## KV projections
        # LoRA-style low-rank factorization
        self.W_dkv = nn.Parameter(0.01 * torch.randn((d_model, self.kv_proj_dim + self.qk_rope_dim)))
        self.W_ukv = nn.Parameter(0.01 * torch.randn((self.kv_proj_dim,
                                                      self.d_model + (self.n_heads * self.qk_nope_dim))))
        self.kv_layernorm = nn.LayerNorm(self.kv_proj_dim)

        # output projection
        self.W_o = nn.Parameter(0.01 * torch.randn((d_model, d_model)))

        # RoPE
        self.max_seq_len = max_len
        self.rope_theta = rope_theta

        # https://github.com/lucidrains/rotary-embedding-torch/tree/main
        # visualize emb later to make sure it looks ok
        # we use self.dh here instead of self.qk_rope_dim because it works better
        freqs = 1.0 / (rope_theta ** (torch.arange(0, self.dh, 2).float() / self.dh))
        emb = torch.outer(torch.arange(self.max_seq_len).float(), freqs)
        cos_cached = emb.cos()[None, None, :, :]
        sin_cached = emb.sin()[None, None, :, :]

        # https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer
        # these are constants rather than trainable parameters, so register_buffer is used
        self.register_buffer("cos_cached", cos_cached)
        self.register_buffer("sin_cached", sin_cached)

    def apply_rope_x(self, x, cos, sin):
        return (x * cos) + (self.rotate_half(x) * sin)

    @staticmethod
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, kv_cache=None, past_length=0):
        B, S, D = x.size()

        # Q projections
        compressed_q = x @ self.W_dq
        compressed_q = self.q_layernorm(compressed_q)
        Q = compressed_q @ self.W_uq
        Q = Q.view(B, -1, self.n_heads, self.dh).transpose(1, 2)
        Q, Q_for_rope = torch.split(Q, [self.qk_nope_dim, self.qk_rope_dim], dim=-1)

        # Q decoupled RoPE
        cos_q = self.cos_cached[:, :, past_length:past_length + S, :self.qk_rope_dim // 2].repeat(1, 1, 1, 2)
        sin_q = self.sin_cached[:, :, past_length:past_length + S, :self.qk_rope_dim // 2].repeat(1, 1, 1, 2)
        Q_for_rope = self.apply_rope_x(Q_for_rope, cos_q, sin_q)

        # KV projections
        if kv_cache is None:
            compressed_kv = x @ self.W_dkv
            KV_for_lora, K_for_rope = torch.split(compressed_kv,
                                                  [self.kv_proj_dim, self.qk_rope_dim],
                                                  dim=-1)
            KV_for_lora = self.kv_layernorm(KV_for_lora)
        else:
            new_kv = x @ self.W_dkv
            compressed_kv = torch.cat([kv_cache, new_kv], dim=1)
            new_kv, new_K_for_rope = torch.split(new_kv,
                                                 [self.kv_proj_dim, self.qk_rope_dim],
                                                 dim=-1)
            old_kv, old_K_for_rope = torch.split(kv_cache,
                                                 [self.kv_proj_dim, self.qk_rope_dim],
                                                 dim=-1)
            new_kv = self.kv_layernorm(new_kv)
            old_kv = self.kv_layernorm(old_kv)
            KV_for_lora = torch.cat([old_kv, new_kv], dim=1)
            K_for_rope = torch.cat([old_K_for_rope, new_K_for_rope], dim=1)

        KV = KV_for_lora @ self.W_ukv
        KV = KV.view(B, -1, self.n_heads, self.dh + self.qk_nope_dim).transpose(1, 2)
        K, V = torch.split(KV, [self.qk_nope_dim, self.dh], dim=-1)
        S_full = K.size(2)

        # K RoPE
        K_for_rope = K_for_rope.view(B, -1, 1, self.qk_rope_dim).transpose(1, 2)
        cos_k = self.cos_cached[:, :, :S_full, :self.qk_rope_dim // 2].repeat(1, 1, 1, 2)
        sin_k = self.sin_cached[:, :, :S_full, :self.qk_rope_dim // 2].repeat(1, 1, 1, 2)
        K_for_rope = self.apply_rope_x(K_for_rope, cos_k, sin_k)

        # apply the position encoding to each head
        K_for_rope = K_for_rope.repeat(1, self.n_heads, 1, 1)

        # split into multiple heads
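        # Shape notes (added; derived from the projections above):
        #   q_heads: (B, n_heads, S,      qk_nope_dim + qk_rope_dim) = (B, n_heads, S, dh)
        #   k_heads: (B, n_heads, S_full, qk_nope_dim + qk_rope_dim)
        #   v_heads: (B, n_heads, S_full, dh)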
        q_heads = torch.cat([Q, Q_for_rope], dim=-1)
        k_heads = torch.cat([K, K_for_rope], dim=-1)
        v_heads = V  # already reshaped before the split

        # build the causal attention mask
        mask = torch.ones((S, S_full), device=x.device)
        mask = torch.tril(mask, diagonal=past_length)
        mask = mask[None, None, :, :]
        sq_mask = mask == 1

        # attention
        x = nn.functional.scaled_dot_product_attention(
            q_heads, k_heads, v_heads,
            attn_mask=sq_mask
        )

        x = x.transpose(1, 2).reshape(B, S, D)

        # output projection
        x = x @ self.W_o.T

        return x


class DMutaPeptide(nn.Module):
    def __init__(self, q_encoder='lstm', classes=1, channels=128, dir=False, gf=False,
                 fusion='mlp', non_siamese=False):
        """
        Args:
            q_encoder: encoder type; one of 'lstm', 'gru', 'mamba', 'mla', 'mha'
            classes: number of output classes
            channels: channel count; sets the hidden-state dimension
            dir: whether to use the DIR module (FDS feature smoothing)
            gf: whether to fuse global features through an extra MLP encoder
            fusion: fusion method; 'mlp' (default, plain concatenation),
                'diff' (difference of the two encodings), or 'att'
                (MultiheadAttention-based fusion)
            non_siamese: if True, use separate (non-weight-shared) encoders
                for the two sequences
        """
        super().__init__()
        self.classes = classes
        self.DIR = dir
        self.gf = gf
        self.fusion_method = fusion
        self.non_siamese = non_siamese

        # dimension after concatenating the two encodings: channels * 4
        final_dim = channels * 4

        # initialize the sequence encoder
        if q_encoder == 'lstm':
            self.q_encoder = nn.LSTM(
                input_size=21,
                hidden_size=channels,
                num_layers=2,
                batch_first=True,  # inputs/outputs are (batch, time_step, input_size)
                dropout=0.1,
                bidirectional=True
            )
        elif q_encoder == 'gru':
            self.q_encoder = nn.GRU(
                input_size=21,
                hidden_size=channels,
                num_layers=2,
                batch_first=True,  # inputs/outputs are (batch, time_step, input_size)
                dropout=0.1,
                bidirectional=True
            )
        elif q_encoder == 'mamba':
            self.q_encoder = MambaModel(channels, 30)
        elif q_encoder == 'mla':
            self.q_encoder = MLAModel(channels, 30)
        elif q_encoder == 'mha':
            self.q_encoder = MHAModel(channels, 30)
        else:
            raise NotImplementedError

        if non_siamese:
            self.q_encoder_2 = deepcopy(self.q_encoder)
        else:
            self.q_encoder_2 = self.q_encoder

        if self.fusion_method == 'diff':
            final_dim //= 2

        if gf:
            self.g_encoder = MLP(1024, [512, 256, 128], channels * 2, dropout_rate=0.3)
            final_dim += channels * 2

        # for 'att' fusion, combine the encoded vectors with MultiheadAttention
        if self.fusion_method == 'att':
            # each fused vector has dimension channels * 2
            embed_dim = channels * 2
            self.attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                              num_heads=4 if gf else 2,
                                              batch_first=True)

        if self.DIR:
            self.FDS = FDS(final_dim)

        self.fc = nn.Sequential(
            nn.Linear(final_dim, 128),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(64, self.classes)
        )

    def norm(self, x, dim=-1, p=2):
        return F.normalize(x, p=p, dim=dim)

    def forward(self, x, labels=None, epoch=0):
        if self.gf:
            seq1, seq2, gf = x
        else:
            seq1, seq2 = x

        fusion = []
        # encode the two sequences
        if self.q_encoder.__class__.__name__ in ['LSTM', 'GRU']:
            # for LSTM/GRU, take the output at the last time step; its
            # dimension is channels * 2 (bidirectional)
            fusion.append(self.norm(self.q_encoder(seq1)[0][:, -1, :]))
            fusion.append(self.norm(self.q_encoder_2(seq2)[0][:, -1, :]))
        # elif self.q_encoder.__class__.__name__ in ['MambaModel', 'MLAModel', 'MHAModel']:
        else:
            fusion.append(self.norm(self.q_encoder(seq1)))
            fusion.append(self.norm(self.q_encoder_2(seq2)))
        if self.gf:
            fusion.append(self.g_encoder(gf))

        # combine according to fusion_method
        if self.fusion_method == 'mlp':
            # original behavior: concatenate the vectors
            fusion = torch.cat(fusion, dim=-1)
        elif self.fusion_method == 'diff':
            fusion = torch.cat([fusion[1] - fusion[0]] + fusion[2:], dim=-1)
        elif self.fusion_method == 'att':
            # attention fusion:
            # stack the vectors as "tokens" of shape (batch, num_tokens, embed_dim)
            tokens = torch.stack(fusion, dim=1)  # embed_dim is channels * 2
            # self-attention via MultiheadAttention; with batch_first=True
            # the input shape is (batch, seq_len, embed_dim)
            attn_output, _ = self.attn(tokens, tokens, tokens)
            # flatten the attention output to shape
            # (batch, num_tokens * embed_dim), i.e. (batch, final_dim)
            fusion = attn_output.reshape(attn_output.size(0), -1)
        else:
            raise ValueError("Invalid fusion method: choose 'mlp', 'diff', or 'att'.")

        # if the DIR module is enabled, keep the pre-FDS feature representation
        if self.DIR:
            features = fusion
            fusion = self.FDS.smooth(fusion, labels, epoch)
        pred = self.fc(fusion).squeeze(-1)

        if self.DIR:
            return pred, features
        else:
            return pred


class CNNEncoder(nn.Module):
    def __init__(self, feature_dim=256, base_channels=16, in_dim=3):
        """
        feature_dim: dimension of the output 1-D feature vector
        base_channels: channel count of the base conv block
        """
        super(CNNEncoder, self).__init__()
        # convolutional layers
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, base_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(base_channels),
            # nn.ReLU(inplace=True),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(base_channels, base_channels * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(base_channels * 2),
            # nn.ReLU(inplace=True),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(base_channels * 2, base_channels * 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(base_channels * 4),
            # nn.ReLU(inplace=True),
            nn.Mish(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        # adaptive pooling to a fixed-size (1x1) feature map
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        # fully connected layer maps conv features to a 1-D feature vector
        self.fc = nn.Linear(base_channels * 4, feature_dim)

    def forward(self, img):
        """
        img: [B, 3, 1024, 1024] input RGB image tensor
        """
        # convolution and pooling
        fused_conv = self.conv(img)
        pooled = self.adaptive_pool(fused_conv)  # [B, base_channels*4, 1, 1]
        # flatten and project through the FC layer to the feature vector
        flattened = pooled.view(pooled.size(0), -1)  # [B, base_channels*4]
        feature_vector = self.fc(flattened)  # [B, feature_dim]
        return feature_vector


class DMutaPeptideCNN(nn.Module):
    def __init__(self, q_encoder='cnn', classes=1, channels=16, dir=False, gf=False,
                 side_enc=None, fusion='mlp', non_siamese=False):
        """
        Args:
            q_encoder: image encoder type; one of 'cnn', 'rn18'
            classes: number of output classes
            channels: base channel count of the CNN encoder
            dir: whether to use the DIR module (FDS feature smoothing)
            gf: whether to fuse global features through an extra MLP encoder
            side_enc: optional sequence side-encoder ('lstm' or 'mamba')
                used alongside the image encoder
            fusion: fusion method; 'mlp' (default, plain concatenation),
                'diff', or 'att' (MultiheadAttention-based fusion)
            non_siamese: if True, use separate (non-weight-shared) encoders
        """
        super().__init__()
        self.classes = classes
        self.DIR = dir
        self.gf = gf
        self.fusion_method = fusion
        self.non_siamese = non_siamese

        # each image encoder outputs a vector_dim-dimensional vector,
        # so the concatenated dimension is vector_dim * 2
        vector_dim = 512
        final_dim = vector_dim * 2

        # initialize the image encoder
        if q_encoder == 'cnn':
            self.q_encoder = CNNEncoder(feature_dim=vector_dim, base_channels=channels)
        elif q_encoder == 'rn18':
            self.q_encoder = resnet18_backbone(pretrained=True)
        else:
            raise NotImplementedError

        if non_siamese:
            self.q_encoder_2 = deepcopy(self.q_encoder)
        else:
            self.q_encoder_2 = self.q_encoder

        if side_enc:
            self.side_enc = True
            if side_enc == 'lstm':
                self.side_encoder = nn.LSTM(
                    input_size=21,
                    hidden_size=256,
                    num_layers=2,
                    batch_first=True,  # inputs/outputs are (batch, time_step, input_size)
                    dropout=0.1,
                    bidirectional=True
                )
            elif side_enc == 'mamba':
                self.side_encoder = MambaModel(256, 30)
            else:
                raise NotImplementedError
            final_dim += vector_dim * 2
            if non_siamese:
                self.side_encoder_2 = deepcopy(self.side_encoder)
            else:
                self.side_encoder_2 = self.side_encoder
        else:
            self.side_enc = False

        if self.fusion_method == 'diff':
            final_dim //= 2

        if gf:
            self.g_encoder = MLP(1024, [512, 256, 128], vector_dim, dropout_rate=0.3)
            final_dim += vector_dim

        # for 'att' fusion, combine the encoded vectors with MultiheadAttention
        if self.fusion_method == 'att':
            # each fused vector has dimension vector_dim
            embed_dim = vector_dim
            self.attn = nn.MultiheadAttention(embed_dim=embed_dim,
                                              num_heads=4 if gf else 2,
                                              batch_first=True)

        if self.DIR:
            self.FDS = FDS(final_dim)

        self.fc = nn.Sequential(
            nn.Linear(final_dim, 128),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(64, self.classes)
        )

    def norm(self, x, dim=-1, p=2):
        return F.normalize(x, p=p, dim=dim)

    def forward(self, x, labels=None, epoch=0):
        if self.gf:
            seq1, seq2, gf = x
        else:
            seq1, seq2 = x
        if self.side_enc:
            seq1_seq = seq1[1]
            seq1 = seq1[0]
            seq2_seq = seq2[1]
            seq2 = seq2[0]

        fusion = []
        # encode the two inputs
        fusion.append(self.norm(self.q_encoder(seq1)))
        fusion.append(self.norm(self.q_encoder_2(seq2)))
        if self.side_enc:
            if self.side_encoder.__class__.__name__ == 'MambaModel':
                fusion.append(self.norm(self.side_encoder(seq1_seq)))
                fusion.append(self.norm(self.side_encoder_2(seq2_seq)))
            # elif self.side_encoder.__class__.__name__ == 'LSTM':
            else:
                fusion.append(self.norm(self.side_encoder(seq1_seq)[0][:, -1, :]))
                fusion.append(self.norm(self.side_encoder_2(seq2_seq)[0][:, -1, :]))
        if self.gf:
            fusion.append(self.g_encoder(gf))

        # combine according to fusion_method
        if self.fusion_method == 'mlp':
            # original behavior: concatenate the vectors
            fusion = torch.cat(fusion, dim=-1)
        elif self.fusion_method == 'diff':
            if not self.side_enc:
                fusion = torch.cat([fusion[1] - fusion[0]] + fusion[2:], dim=-1)
            else:
                fusion = torch.cat([fusion[1] - fusion[0], fusion[3] - fusion[2]] + fusion[4:], dim=-1)
        elif self.fusion_method == 'att':
            # attention fusion:
            # stack the vectors as "tokens" of shape (batch, num_tokens, embed_dim)
            tokens = torch.stack(fusion, dim=1)  # embed_dim is vector_dim
            # self-attention via MultiheadAttention; with batch_first=True
            # the input shape is (batch, seq_len, embed_dim)
            attn_output, _ = self.attn(tokens, tokens, tokens)
            # flatten the attention output to (batch, num_tokens * embed_dim),
            # i.e. (batch, final_dim)
            fusion = attn_output.reshape(attn_output.size(0), -1)
        else:
            raise ValueError("Invalid fusion method: choose 'mlp', 'diff', or 'att'.")

        # if the DIR module is enabled, keep the pre-FDS feature representation
        if self.DIR:
            features = fusion
            fusion = self.FDS.smooth(fusion, labels, epoch)
        pred = self.fc(fusion).squeeze(-1)

        if self.DIR:
            return pred, features
        else:
            return pred


def resnet18_backbone(pretrained=False):
    weights = None
    if pretrained:
        weights = 'IMAGENET1K_V1'
    model = resnet18(weights=weights, progress=False)
    return nn.Sequential(*list(model.children())[:-1], nn.Flatten())


if __name__ == "__main__":
    model = resnet18_backbone(pretrained=True)
    print(model)
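

# Minimal shape sanity-check for the sequence model (an added sketch, not part
# of the original training pipeline). It assumes peptide inputs encoded as
# (batch, seq_len=30, 21) one-hot tensors; `smoke_test` is a hypothetical
# helper name. Note that importing this module already requires `mamba_ssm`
# and `utils.FDS` to be available, even though the MHA variant used here does
# not touch them.
def smoke_test():
    model = DMutaPeptide(q_encoder='mha', channels=128, fusion='att')
    seq1 = torch.randn(4, 30, 21)  # wild-type sequence, one-hot over 21 letters
    seq2 = torch.randn(4, 30, 21)  # mutant sequence
    pred = model((seq1, seq2))     # regression output, shape (4,)
    assert pred.shape == (4,)
    return pred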