|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
class MutationPredictorCNN(nn.Module):
    """Mutation pathogenicity predictor CNN.

    Architecture matches the trained checkpoint weights: the submodule
    attribute names, their registration order and every layer size must stay
    as they are, otherwise the checkpoint's ``state_dict`` will not load.

    The network has two branches that are fused before the classifier:

    * a sequence branch of three ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d``
      stages followed by global average pooling (256 features);
    * a mutation-type branch, a single ``Linear(12, 32)`` with ReLU.

    The concatenated 288 features feed a three-layer classifier head. A
    separate importance head reads the 256 convolutional features only.
    """

    # Layout of the flat per-sample input vector consumed by ``forward``.
    SEQ_CHANNELS = 11        # 10 sequence-feature rows + 1 difference-mask row
    SEQ_LENGTH = 99          # positions per row
    SEQ_FEATURES = SEQ_CHANNELS * SEQ_LENGTH              # 1089
    MUT_TYPE_FEATURES = 12   # trailing mutation-type block
    INPUT_FEATURES = SEQ_FEATURES + MUT_TYPE_FEATURES     # 1101

    def __init__(self):
        """Build all layers with the sizes stored in the trained checkpoint."""
        super().__init__()

        # Sequence branch. Each padding equals (kernel_size - 1) / 2, so the
        # convolutions keep the sequence length; only the pooling shortens it.
        self.conv1 = nn.Conv1d(self.SEQ_CHANNELS, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # One parameter-free pooling module shared by all three conv stages.
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Mutation-type branch.
        self.mut_fc = nn.Linear(self.MUT_TYPE_FEATURES, 32)

        # Collapses whatever sequence length is left to one value per channel.
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)

        # Classifier head: 288 = 256 conv features + 32 mutation-type features.
        self.fc1 = nn.Linear(288, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

        # Importance head, driven by the convolutional features alone.
        self.importance_head = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor):
        """Run the forward pass.

        Args:
            x: Input tensor (batch, 1101)
                [0:990]     - sequence features (99*10)
                [990:1089]  - difference mask (99)
                [1089:1101] - mutation type (12)

                The first 1089 values are reshaped to (batch, 11, 99), i.e.
                every consecutive run of 99 values becomes one input channel.

        Returns:
            cls: Classification output (batch, 1), sigmoid-activated.
            importance: Importance score (batch, 1), sigmoid-activated.

        Raises:
            ValueError: If ``x`` is not a 2-D tensor with exactly 1101
                columns. Without this check a wider input would be silently
                truncated by the slicing below, and a narrower one would fail
                deep inside a layer with an unrelated shape message.
        """
        if x.dim() != 2 or x.size(1) != self.INPUT_FEATURES:
            raise ValueError(
                f"expected input of shape (batch, {self.INPUT_FEATURES}), "
                f"got {tuple(x.shape)}"
            )

        batch_size = x.size(0)

        # Split the flat vector into its two logical parts.
        mut_type = x[:, self.SEQ_FEATURES:]
        x_seq = x[:, :self.SEQ_FEATURES].reshape(
            batch_size, self.SEQ_CHANNELS, self.SEQ_LENGTH
        )

        # Sequence branch; pooling halves the length each time (99 -> 49 -> 24 -> 12).
        x_conv = self.pool(F.relu(self.bn1(self.conv1(x_seq))))
        x_conv = self.pool(F.relu(self.bn2(self.conv2(x_conv))))
        x_conv = self.pool(F.relu(self.bn3(self.conv3(x_conv))))

        # Global average pooling: (batch, 256, 12) -> (batch, 256, 1) -> (batch, 256).
        conv_features = self.adaptive_pool(x_conv).flatten(1)

        # Mutation-type branch.
        mut_features = F.relu(self.mut_fc(mut_type))

        # Fuse both branches and classify.
        combined = torch.cat([conv_features, mut_features], dim=1)
        hidden = F.relu(self.fc1(combined))
        hidden = F.relu(self.fc2(hidden))
        cls = torch.sigmoid(self.fc3(hidden))

        # The importance score ignores the mutation-type branch.
        importance = torch.sigmoid(self.importance_head(conv_features))

        return cls, importance
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model, push one random batch through it and print
    # the resulting shapes next to the shapes expected from the checkpoint.
    print("Testing MutationPredictorCNN...")

    net = MutationPredictorCNN()
    sample = torch.randn(2, 1101)
    cls_out, importance_out = net(sample)

    shape_report = (
        ("Input shape", sample.shape),
        ("Classification output shape", cls_out.shape),
        ("Importance output shape", importance_out.shape),
    )
    for label, shape in shape_report:
        print(f"{label}: {shape}")

    print("\nModel parameter shapes (should match checkpoint):")
    for param_name, tensor in net.named_parameters():
        print(f"{param_name:30s}: {str(tensor.shape):20s}")

    # Reference shapes for eyeballing against the listing printed above.
    expected_lines = (
        "conv1.weight : torch.Size([64, 11, 7])",
        "conv3.weight : torch.Size([256, 128, 3])",
        "mut_fc.weight : torch.Size([32, 12])",
        "fc1.weight : torch.Size([128, 288])",
        "importance_head.weight : torch.Size([1, 256])",
    )
    print("\nExpected parameter shapes from checkpoint:")
    for line in expected_lines:
        print(line)
|
|
|