| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class MutationPredictorCNN(nn.Module):
    """CNN that scores a mutation's pathogenicity from a flat feature vector.

    Three ``Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d`` stages are followed
    by a global average pool.  The pooled 256-d vector is concatenated with a
    32-d projection of the mutation-type features and passed through a
    three-layer MLP that ends in a sigmoid (classification output).  A second
    sigmoid head reads the pooled 256-d vector alone (importance output).

    With the default arguments every layer has the same name, registration
    order and shape as the original hard-coded model (``conv1`` 11->64 k=7,
    ``conv2`` 64->128 k=5, ``conv3`` 128->256 k=3, ``mut_fc`` 12->32,
    ``fc1`` 288->128, ``fc2`` 128->64, ``fc3`` 64->1,
    ``importance_head`` 256->1).  The original author's note states that this
    architecture matches the trained checkpoint weights.

    Args:
        seq_len: Number of sequence positions per channel (default 99).
        seq_channels: Number of per-position channels fed to ``conv1``
            (default 11: 10 feature channels plus the difference mask).
        mut_type_dim: Width of the trailing mutation-type block (default 12).

    Raises:
        ValueError: If ``seq_len`` is too short to survive the three
            stride-2 pooling stages, or a channel/feature count is < 1.
    """

    def __init__(self, seq_len: int = 99, seq_channels: int = 11,
                 mut_type_dim: int = 12) -> None:
        super().__init__()

        # Each MaxPool1d(kernel_size=2, stride=2) halves the length (floor),
        # so at least 8 positions are needed to have >= 1 left after three.
        if seq_len < 8:
            raise ValueError(
                f"seq_len must be >= 8 to survive three 2x poolings, got {seq_len}"
            )
        if seq_channels < 1 or mut_type_dim < 1:
            raise ValueError(
                "seq_channels and mut_type_dim must be positive, got "
                f"{seq_channels} and {mut_type_dim}"
            )

        # Plain ints: not parameters/buffers, so they never enter state_dict.
        self.seq_len = seq_len
        self.seq_channels = seq_channels
        self.mut_type_dim = mut_type_dim
        # Column where the sequence block ends and the mutation block starts
        # (1089 with the defaults), and the total expected width (1101).
        self.seq_end = seq_channels * seq_len
        self.input_dim = self.seq_end + mut_type_dim

        # Convolutional trunk; the paddings keep the length unchanged so only
        # the pooling stages shrink it.
        self.conv1 = nn.Conv1d(seq_channels, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Shared by all three stages (it has no parameters).
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Projection of the mutation-type features.
        self.mut_fc = nn.Linear(mut_type_dim, 32)

        # Collapses whatever length is left to 1, giving a 256-d vector.
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)

        # Classifier MLP over [256 conv features | 32 mutation features].
        self.fc1 = nn.Linear(256 + 32, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

        # Importance head reads the pooled conv features only.
        self.importance_head = nn.Linear(256, 1)

    def forward(self, x: torch.Tensor):
        """Run the network on a batch of flat feature vectors.

        Args:
            x: Tensor of shape ``(batch, input_dim)``; with the defaults that
                is ``(batch, 1101)`` laid out as
                ``[0:990]``     sequence features (99*10),
                ``[990:1089]``  difference mask (99),
                ``[1089:1101]`` mutation type (12).
                Columns beyond ``input_dim`` are ignored by the slicing.

        Returns:
            Tuple ``(cls, importance)``, each of shape ``(batch, 1)`` and
            each passed through a sigmoid.
        """
        batch_size = x.size(0)

        # Trailing block: mutation-type features.
        mut_type = x[:, self.seq_end:self.input_dim]

        # The leading block is split into `seq_channels` consecutive runs of
        # `seq_len` values; run i becomes Conv1d input channel i.
        # NOTE(review): this presumes the encoder writes the sequence
        # features channel-major (one run of seq_len values per feature)
        # rather than position-major -- confirm against the feature encoder.
        x_seq = x[:, :self.seq_end].view(
            batch_size, self.seq_channels, self.seq_len
        )

        # Three conv stages; length 99 -> 49 -> 24 -> 12 with the defaults.
        x_conv = self.pool(F.relu(self.bn1(self.conv1(x_seq))))
        x_conv = self.pool(F.relu(self.bn2(self.conv2(x_conv))))
        x_conv = self.pool(F.relu(self.bn3(self.conv3(x_conv))))

        # Global average pool: (batch, 256, L) -> (batch, 256, 1) -> (batch, 256).
        x_conv = self.adaptive_pool(x_conv)
        conv_features = x_conv.view(batch_size, 256)

        mut_features = F.relu(self.mut_fc(mut_type))

        # (batch, 256) ++ (batch, 32) -> (batch, 288)
        combined = torch.cat([conv_features, mut_features], dim=1)

        hidden = F.relu(self.fc1(combined))
        hidden = F.relu(self.fc2(hidden))
        cls = torch.sigmoid(self.fc3(hidden))

        importance = torch.sigmoid(self.importance_head(conv_features))

        return cls, importance
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model, push one random batch through it, and
    # print the resulting shapes next to the shapes the checkpoint expects.
    print("Testing MutationPredictorCNN...")

    net = MutationPredictorCNN()

    # Two random samples with the documented flat width of 1101 features.
    sample_batch = torch.randn(2, 1101)
    cls_out, importance_out = net(sample_batch)

    print(f"Input shape: {sample_batch.shape}")
    print(f"Classification output shape: {cls_out.shape}")
    print(f"Importance output shape: {importance_out.shape}")

    print("\nModel parameter shapes (should match checkpoint):")
    for name, param in net.named_parameters():
        print(f"{name:30s}: {str(param.shape):20s}")

    # Reference values are only printed for visual comparison with the
    # listing above; nothing here checks them programmatically.
    print("\nExpected parameter shapes from checkpoint:")
    expected_lines = (
        "conv1.weight : torch.Size([64, 11, 7])",
        "conv3.weight : torch.Size([256, 128, 3])",
        "mut_fc.weight : torch.Size([32, 12])",
        "fc1.weight : torch.Size([128, 288])",
        "importance_head.weight : torch.Size([1, 256])",
    )
    for line in expected_lines:
        print(line)
| |