# DAminoMuta / network.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
from mamba_ssm import Mamba
from utils import FDS
from torchvision.models import resnet18
class MambaModel(nn.Module):
def __init__(self, d_model, max_length=30):
super(MambaModel, self).__init__()
self.linear = nn.Linear(in_features=21, out_features=d_model)
self.pos_encoder = PositionalEncoding(d_model, max_length)
self.mamba = Mamba(d_model=d_model, d_state=32, expand=4)
self.global_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x: torch.Tensor):
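        """
        x: (B, L, 21) per-residue features (presumably one-hot amino-acid encodings).
        Returns a (B, 2 * d_model) embedding: Mamba is run over the sequence in both
        directions, the outputs are concatenated, then average-pooled over the length axis.
        """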
x = self.pos_encoder(self.linear(x))
y = self.mamba(x)
y_flip = self.mamba(x.flip([-2])).flip([-2])
y = torch.cat((y, y_flip), dim=-1)
y = self.global_pool(y.permute(0, 2, 1)).squeeze(-1)
return y
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, dropout_rate=0.1):
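        """
        Plain feed-forward network. hidden_dim may be an int (repeated num_layers times)
        or an explicit list of hidden sizes; each hidden layer is Linear -> ReLU -> Dropout.
        """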
super(MLP, self).__init__()
if isinstance(hidden_dim, int):
hidden_dim = [hidden_dim] * num_layers
layers = []
layers.append(nn.Linear(input_dim, hidden_dim[0]))
layers.append(nn.ReLU())
layers.append(nn.Dropout(dropout_rate))
for i in range(len(hidden_dim) - 1):
layers.append(nn.Linear(hidden_dim[i], hidden_dim[i + 1]))
layers.append(nn.ReLU())
layers.append(nn.Dropout(dropout_rate))
layers.append(nn.Linear(hidden_dim[-1], output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=50):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model) # (max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-torch.log(torch.FloatTensor([10000.0])) / d_model)) # (d_model/2,)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
"""
x: (B, N, d_model)
"""
x = x + self.pe[:, :x.size(1), :]
return x
class MHAModel(nn.Module):
def __init__(self, d_model, max_length=50):
super(MHAModel, self).__init__()
self.linear = nn.Linear(in_features=21, out_features=d_model)
self.pos_encoder = PositionalEncoding(d_model, max_length)
self.self_attn = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
self.global_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x: torch.Tensor):
        # Linear projection + positional encoding
        x = self.pos_encoder(self.linear(x))  # [batch, seq_len, d_model]
        # Forward self-attention
        y, _ = self.self_attn(x, x, x)  # [batch, seq_len, d_model]
        # Backward self-attention
        x_flip = x.flip([-2])  # flip along the sequence dimension
        y_flip, _ = self.self_attn(x_flip, x_flip, x_flip)
        y_flip = y_flip.flip([-2])  # flip back to the original order
        # Concatenate forward and backward results
        y = torch.cat((y, y_flip), dim=-1)  # [batch, seq_len, 2*d_model]
        # Global pooling
        y = self.global_pool(y.permute(0, 2, 1))  # [batch, 2*d_model, 1]
return y.squeeze(-1) # [batch, 2*d_model]
class MLAModel(nn.Module):
def __init__(self, d_model, max_length=50):
super(MLAModel, self).__init__()
self.linear = nn.Linear(in_features=21, out_features=d_model)
self.pos_encoder = PositionalEncoding(d_model, max_length)
self.MLA = MLA(d_model, n_heads=8, max_len=max_length)
self.global_pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x: torch.Tensor):
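        """
        x: (B, L, 21) per-residue features. Returns a (B, 2 * d_model) embedding from
        forward and reversed passes through the MLA block, concatenated and average-pooled.
        """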
x = self.pos_encoder(self.linear(x))
y = self.MLA(x)
y_flip = self.MLA(x.flip([-2])).flip([-2])
y = torch.cat((y, y_flip), dim=-1)
y = self.global_pool(y.permute(0, 2, 1)).squeeze(-1)
return y
class MLA(nn.Module):
def __init__(self, d_model, n_heads, max_len=50, rope_theta=10000.0):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.dh = d_model // n_heads
self.q_proj_dim = d_model // 2
self.kv_proj_dim = (2*d_model) // 3
self.qk_nope_dim = self.dh // 2
self.qk_rope_dim = self.dh // 2
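        # Each head dimension is split in half: a "nope" part that carries content with
        # no positional encoding, and a "rope" part that receives rotary position
        # embeddings (the decoupled RoPE used by multi-head latent attention).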
        ## Q projections
        # low-rank (LoRA-style) down/up projection for queries
self.W_dq = nn.Parameter(0.01*torch.randn((d_model, self.q_proj_dim)))
self.W_uq = nn.Parameter(0.01*torch.randn((self.q_proj_dim, self.d_model)))
self.q_layernorm = nn.LayerNorm(self.q_proj_dim)
        ## KV projections
        # low-rank (LoRA-style) down/up projection for keys/values
self.W_dkv = nn.Parameter(0.01*torch.randn((d_model, self.kv_proj_dim + self.qk_rope_dim)))
self.W_ukv = nn.Parameter(0.01*torch.randn((self.kv_proj_dim,
self.d_model + (self.n_heads * self.qk_nope_dim))))
self.kv_layernorm = nn.LayerNorm(self.kv_proj_dim)
# output projection
self.W_o = nn.Parameter(0.01*torch.randn((d_model, d_model)))
# RoPE
self.max_seq_len = max_len
self.rope_theta = rope_theta
# https://github.com/lucidrains/rotary-embedding-torch/tree/main
        # TODO: visualize emb to make sure it looks OK
        # NOTE: frequencies are computed over the full head dim (self.dh) rather than
        # self.qk_rope_dim; forward() slices out only the first qk_rope_dim // 2 of them.
freqs = 1.0 / (rope_theta ** (torch.arange(0, self.dh, 2).float() / self.dh))
emb = torch.outer(torch.arange(self.max_seq_len).float(), freqs)
cos_cached = emb.cos()[None, None, :, :]
sin_cached = emb.sin()[None, None, :, :]
# https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer
        # These are constants (not learned), so register them as buffers rather than parameters
self.register_buffer("cos_cached", cos_cached)
self.register_buffer("sin_cached", sin_cached)
def apply_rope_x(self, x, cos, sin):
return (x * cos) + (self.rotate_half(x) * sin)
@staticmethod
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x, kv_cache=None, past_length=0):
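        """
        x: (B, S, d_model); returns the attended output of shape (B, S, d_model).
        Attention uses a causal (lower-triangular) mask. When kv_cache is supplied, the
        concatenated cache (compressed_kv) is computed but not returned by this variant,
        so callers cannot carry it forward.
        """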
B, S, D = x.size()
# Q Projections
compressed_q = x @ self.W_dq
compressed_q = self.q_layernorm(compressed_q)
Q = compressed_q @ self.W_uq
Q = Q.view(B, -1, self.n_heads, self.dh).transpose(1,2)
Q, Q_for_rope = torch.split(Q, [self.qk_nope_dim, self.qk_rope_dim], dim=-1)
# Q Decoupled RoPE
cos_q = self.cos_cached[:, :, past_length:past_length+S, :self.qk_rope_dim//2].repeat(1, 1, 1, 2)
sin_q = self.sin_cached[:, :, past_length:past_length+S, :self.qk_rope_dim//2].repeat(1, 1, 1, 2)
Q_for_rope = self.apply_rope_x(Q_for_rope, cos_q, sin_q)
# KV Projections
if kv_cache is None:
compressed_kv = x @ self.W_dkv
KV_for_lora, K_for_rope = torch.split(compressed_kv,
[self.kv_proj_dim, self.qk_rope_dim],
dim=-1)
KV_for_lora = self.kv_layernorm(KV_for_lora)
else:
new_kv = x @ self.W_dkv
compressed_kv = torch.cat([kv_cache, new_kv], dim=1)
new_kv, new_K_for_rope = torch.split(new_kv,
[self.kv_proj_dim, self.qk_rope_dim],
dim=-1)
old_kv, old_K_for_rope = torch.split(kv_cache,
[self.kv_proj_dim, self.qk_rope_dim],
dim=-1)
new_kv = self.kv_layernorm(new_kv)
old_kv = self.kv_layernorm(old_kv)
KV_for_lora = torch.cat([old_kv, new_kv], dim=1)
K_for_rope = torch.cat([old_K_for_rope, new_K_for_rope], dim=1)
KV = KV_for_lora @ self.W_ukv
KV = KV.view(B, -1, self.n_heads, self.dh+self.qk_nope_dim).transpose(1,2)
K, V = torch.split(KV, [self.qk_nope_dim, self.dh], dim=-1)
S_full = K.size(2)
# K Rope
K_for_rope = K_for_rope.view(B, -1, 1, self.qk_rope_dim).transpose(1,2)
cos_k = self.cos_cached[:, :, :S_full, :self.qk_rope_dim//2].repeat(1, 1, 1, 2)
sin_k = self.sin_cached[:, :, :S_full, :self.qk_rope_dim//2].repeat(1, 1, 1, 2)
K_for_rope = self.apply_rope_x(K_for_rope, cos_k, sin_k)
        # broadcast the shared RoPE key across all heads
K_for_rope = K_for_rope.repeat(1, self.n_heads, 1, 1)
        # concatenate the content (nope) and positional (rope) parts for each head
q_heads = torch.cat([Q, Q_for_rope], dim=-1)
k_heads = torch.cat([K, K_for_rope], dim=-1)
v_heads = V # already reshaped before the split
        # build a causal (lower-triangular) attention mask
mask = torch.ones((S,S_full), device=x.device)
mask = torch.tril(mask, diagonal=past_length)
mask = mask[None, None, :, :]
sq_mask = mask == 1
# attention
x = nn.functional.scaled_dot_product_attention(
q_heads, k_heads, v_heads,
attn_mask=sq_mask
)
x = x.transpose(1, 2).reshape(B, S, D)
# apply projection
x = x @ self.W_o.T
return x
class DMutaPeptide(nn.Module):
def __init__(self, q_encoder='lstm', classes=1, channels=128, dir=False, gf=False, fusion='mlp', non_siamese=False):
"""
        Args:
            q_encoder: encoder type; one of 'lstm', 'gru', 'mamba', 'mla', 'mha'
            classes: number of output classes
            channels: channel count; sets the hidden-state dimension
            dir: whether to use the DIR module (FDS feature smoothing)
            gf: whether an extra 1024-d global-feature vector is passed in and encoded by an MLP
            fusion: fusion method; 'mlp' (default, plain concatenation), 'diff'
                (difference of the two sequence embeddings), or 'att' (MultiheadAttention fusion)
            non_siamese: if True, use a separate (non-weight-shared) copy of the encoder
                for the second sequence
"""
super().__init__()
self.classes = classes
self.DIR = dir
self.gf = gf
        self.fusion_method = fusion  # fusion strategy
        self.non_siamese = non_siamese
        # Dimension of the concatenated sequence embeddings: channels * 4
        final_dim = channels * 4
        # Build the sequence encoder
if q_encoder == 'lstm':
self.q_encoder = nn.LSTM(
input_size=21,
hidden_size=channels,
num_layers=2,
                batch_first=True,  # inputs and outputs are (batch, time_step, input_size)
dropout=0.1,
bidirectional=True
)
elif q_encoder == 'gru':
self.q_encoder = nn.GRU(
input_size=21,
hidden_size=channels,
num_layers=2,
                batch_first=True,  # inputs and outputs are (batch, time_step, input_size)
dropout=0.1,
bidirectional=True
)
elif q_encoder == 'mamba':
self.q_encoder = MambaModel(channels, 30)
elif q_encoder == 'mla':
self.q_encoder = MLAModel(channels, 30)
elif q_encoder == 'mha':
self.q_encoder = MHAModel(channels, 30)
else:
raise NotImplementedError
if non_siamese:
self.q_encoder_2 = deepcopy(self.q_encoder)
else:
self.q_encoder_2 = self.q_encoder
if self.fusion_method == 'diff':
final_dim //= 2
if gf:
self.g_encoder = MLP(1024, [512, 256, 128], channels * 2, dropout_rate=0.3)
final_dim += channels * 2
        # If the fusion mode is 'att', fuse the branch vectors with MultiheadAttention
        if self.fusion_method == 'att':
            # Each encoder branch outputs a vector of dimension channels * 2
            embed_dim = channels * 2
self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4 if gf else 2, batch_first=True)
if self.DIR:
self.FDS = FDS(final_dim)
self.fc = nn.Sequential(
nn.Linear(final_dim, 128),
nn.Mish(),
nn.Dropout(0.3),
nn.Linear(128, 64),
nn.Mish(),
nn.Dropout(0.3),
nn.Linear(64, self.classes)
)
def norm(self, x, dim=-1, p=2):
return F.normalize(x, p=p, dim=dim)
def forward(self, x, labels=None, epoch=0):
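        """
        x: a tuple (seq1, seq2) of (B, L, 21) sequence tensors, or (seq1, seq2, gf) when
        self.gf is True, where gf is a (B, 1024) global-feature vector.
        Returns the prediction (shape (B,) when classes == 1, otherwise (B, classes)),
        plus the pre-FDS features when DIR is enabled.
        """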
if self.gf:
seq1, seq2, gf = x
else:
seq1, seq2 = x
fusion = []
        # Encode the two sequences
        if self.q_encoder.__class__.__name__ in ['LSTM', 'GRU']:
            # For LSTM/GRU, take the last time step; its dimension is channels * 2 (bidirectional)
fusion.append(self.norm(self.q_encoder(seq1)[0][:, -1, :]))
fusion.append(self.norm(self.q_encoder_2(seq2)[0][:, -1, :]))
# elif self.q_encoder.__class__.__name__ in ['MambaModel', 'MLAModel', 'MHAModel']:
else:
fusion.append(self.norm(self.q_encoder(seq1)))
fusion.append(self.norm(self.q_encoder_2(seq2)))
if self.gf:
fusion.append(self.g_encoder(gf))
        # Fuse the branch embeddings according to fusion_method
        if self.fusion_method == 'mlp':
            # Default behaviour: concatenate the vectors
            fusion = torch.cat(fusion, dim=-1)
        elif self.fusion_method == 'diff':
            # Use the difference between the two sequence embeddings
            fusion = torch.cat([fusion[1] - fusion[0]] + fusion[2:], dim=-1)
        elif self.fusion_method == 'att':
            # Attention fusion:
            # stack the vectors as "tokens" of shape (batch, n_tokens, embed_dim)
            tokens = torch.stack(fusion, dim=1)
            # self-attention over the tokens (batch_first=True, so input is (batch, seq_len, embed_dim))
            attn_output, _ = self.attn(tokens, tokens, tokens)
            # flatten the attention output to (batch, n_tokens * embed_dim), i.e. (batch, final_dim)
            fusion = attn_output.reshape(attn_output.size(0), -1)
else:
            raise ValueError("Invalid fusion method: choose 'mlp', 'diff', or 'att'.")
        # If the DIR module is enabled, keep the features before FDS smoothing
if self.DIR:
features = fusion
fusion = self.FDS.smooth(fusion, labels, epoch)
pred = self.fc(fusion).squeeze(-1)
if self.DIR:
return pred, features
else:
return pred
class CNNEncoder(nn.Module):
def __init__(self, feature_dim=256, base_channels=16, in_dim=3):
"""
        feature_dim: dimension of the output feature vector
        base_channels: number of channels in the first convolution block
"""
super(CNNEncoder, self).__init__()
        # Convolutional layers
self.conv = nn.Sequential(
nn.Conv2d(in_dim, base_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(base_channels),
# nn.ReLU(inplace=True),
nn.Mish(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(base_channels, base_channels * 2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(base_channels * 2),
# nn.ReLU(inplace=True),
nn.Mish(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(base_channels * 2, base_channels * 4, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(base_channels * 4),
# nn.ReLU(inplace=True),
nn.Mish(inplace=True),
nn.MaxPool2d(kernel_size=2)
)
        # Adaptive pooling to a fixed-size (1x1) feature map
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Fully connected layer maps the pooled conv features to a 1-D feature vector
self.fc = nn.Linear(base_channels * 4, feature_dim)
def forward(self, img):
"""
        img: [B, 3, 1024, 1024] input RGB image tensor
        """
        # Convolution + pooling stages
        fused_conv = self.conv(img)
        pooled = self.adaptive_pool(fused_conv)  # [B, base_channels*4, 1, 1]
        # Flatten and map to the output feature vector through the fully connected layer
flattened = pooled.view(pooled.size(0), -1) # [B, base_channels*4]
feature_vector = self.fc(flattened) # [B, feature_dim]
return feature_vector
class DMutaPeptideCNN(nn.Module):
def __init__(self, q_encoder='cnn', classes=1, channels=16, dir=False, gf=False, side_enc=None, fusion='mlp', non_siamese=False):
"""
        Args:
            q_encoder: image encoder type; one of 'cnn', 'rn18'
            classes: number of output classes
            channels: base channel count of the CNN encoder
            dir: whether to use the DIR module (FDS feature smoothing)
            gf: whether an extra 1024-d global-feature vector is passed in and encoded by an MLP
            side_enc: optional auxiliary sequence encoder ('lstm' or 'mamba') applied to
                the raw sequences alongside the image encoder
            fusion: fusion method; 'mlp' (default, plain concatenation), 'diff', or 'att'
                (MultiheadAttention fusion)
            non_siamese: if True, use separate (non-weight-shared) encoder copies for the
                second input
"""
super().__init__()
self.classes = classes
self.DIR = dir
self.gf = gf
        self.fusion_method = fusion  # fusion strategy
        self.non_siamese = non_siamese
        # Dimension of the two concatenated image embeddings: vector_dim * 2
        vector_dim = 512
        final_dim = vector_dim * 2
        # Build the image encoder
if q_encoder == 'cnn':
self.q_encoder = CNNEncoder(feature_dim=vector_dim, base_channels=channels)
elif q_encoder == 'rn18':
self.q_encoder = resnet18_backbone(pretrained=True)
if non_siamese:
self.q_encoder_2 = deepcopy(self.q_encoder)
else:
self.q_encoder_2 = self.q_encoder
if side_enc:
self.side_enc = True
if side_enc == 'lstm':
self.side_encoder = nn.LSTM(
input_size=21,
hidden_size=256,
num_layers=2,
                    batch_first=True,  # inputs and outputs are (batch, time_step, input_size)
dropout=0.1,
bidirectional=True
)
elif side_enc == 'mamba':
self.side_encoder = MambaModel(256, 30)
else:
raise NotImplementedError
final_dim += vector_dim * 2
if non_siamese:
self.side_encoder_2 = deepcopy(self.side_encoder)
else:
self.side_encoder_2 = self.side_encoder
else:
self.side_enc = False
if self.fusion_method == 'diff':
final_dim //= 2
if gf:
self.g_encoder = MLP(1024, [512, 256, 128], vector_dim, dropout_rate=0.3)
final_dim += vector_dim
        # If the fusion mode is 'att', fuse the branch vectors with MultiheadAttention
        if self.fusion_method == 'att':
            # Each encoder branch outputs a vector of dimension vector_dim
            embed_dim = vector_dim
self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4 if gf else 2, batch_first=True)
if self.DIR:
self.FDS = FDS(final_dim)
self.fc = nn.Sequential(
nn.Linear(final_dim, 128),
nn.Mish(),
nn.Dropout(0.3),
nn.Linear(128, 64),
nn.Mish(),
nn.Dropout(0.3),
nn.Linear(64, self.classes)
)
def norm(self, x, dim=-1, p=2):
return F.normalize(x, p=p, dim=dim)
def forward(self, x, labels=None, epoch=0):
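        """
        x: a tuple (img1, img2), or (img1, img2, gf) when self.gf is True. If a side
        encoder is configured, each img element is itself a pair (image, sequence).
        Returns the prediction (shape (B,) when classes == 1, otherwise (B, classes)),
        plus the pre-FDS features when DIR is enabled.
        """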
if self.gf:
seq1, seq2, gf = x
else:
seq1, seq2 = x
if self.side_enc:
seq1_seq = seq1[1]
seq1 = seq1[0]
seq2_seq = seq2[1]
seq2 = seq2[0]
fusion = []
        # Encode the two inputs (images)
fusion.append(self.norm(self.q_encoder(seq1)))
fusion.append(self.norm(self.q_encoder_2(seq2)))
if self.side_enc:
if self.side_encoder.__class__.__name__ == 'MambaModel':
fusion.append(self.norm(self.side_encoder(seq1_seq)))
fusion.append(self.norm(self.side_encoder_2(seq2_seq)))
# elif self.side_encoder.__class__.__name__ == 'LSTM':
else:
fusion.append(self.norm(self.side_encoder(seq1_seq)[0][:, -1, :]))
fusion.append(self.norm(self.side_encoder_2(seq2_seq)[0][:, -1, :]))
if self.gf:
fusion.append(self.g_encoder(gf))
        # Fuse the branch embeddings according to fusion_method
        if self.fusion_method == 'mlp':
            # Default behaviour: concatenate the vectors
            fusion = torch.cat(fusion, dim=-1)
        elif self.fusion_method == 'diff':
            # Use differences between paired embeddings
            if not self.side_enc:
                fusion = torch.cat([fusion[1] - fusion[0]] + fusion[2:], dim=-1)
            else:
                fusion = torch.cat([fusion[1] - fusion[0], fusion[3] - fusion[2]] + fusion[4:], dim=-1)
        elif self.fusion_method == 'att':
            # Attention fusion:
            # stack the vectors as "tokens" of shape (batch, n_tokens, embed_dim)
            tokens = torch.stack(fusion, dim=1)
            # self-attention over the tokens (batch_first=True, so input is (batch, seq_len, embed_dim))
            attn_output, _ = self.attn(tokens, tokens, tokens)
            # flatten the attention output to (batch, n_tokens * embed_dim), i.e. (batch, final_dim)
fusion = attn_output.reshape(attn_output.size(0), -1)
else:
            raise ValueError("Invalid fusion method: choose 'mlp', 'diff', or 'att'.")
        # If the DIR module is enabled, keep the features before FDS smoothing
if self.DIR:
features = fusion
fusion = self.FDS.smooth(fusion, labels, epoch)
pred = self.fc(fusion).squeeze(-1)
if self.DIR:
return pred, features
else:
return pred
def resnet18_backbone(pretrained=False):
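    """Return a ResNet-18 feature extractor with the classification head removed; outputs a (B, 512) vector."""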
weights = None
if pretrained:
weights = 'IMAGENET1K_V1'
model = resnet18(weights=weights, progress=False)
return torch.nn.Sequential(*list(model.children())[:-1], nn.Flatten())
if __name__ == "__main__":
model = resnet18_backbone(pretrained=True)
print(model)
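
    # Minimal smoke test for DMutaPeptide: a sketch that feeds random tensors in place
    # of the expected per-residue features (assumed here to be one-hot style (B, L, 21)
    # inputs, as implied by the encoders above).
    model = DMutaPeptide(q_encoder='lstm', classes=1, channels=128)
    seq1 = torch.randn(2, 30, 21)  # (batch, length, 21 residue channels)
    seq2 = torch.randn(2, 30, 21)
    with torch.no_grad():
        pred = model((seq1, seq2))
    print(pred.shape)  # expected: torch.Size([2])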