File size: 4,779 Bytes
f986766
 
d989594
f986766
 
d989594
 
 
 
f986766
 
 
 
d989594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f986766
 
d989594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import torch.nn as nn
import torch.nn.functional as F

class MutationPredictorCNN(nn.Module):
    """Mutation pathogenicity predictor.

    A 1-D CNN over a flat 1101-dimensional feature vector, laid out as:

    * ``[0:990)``    — per-position sequence features (reshaped to 11 x 99),
    * ``[990:1089)`` — difference mask (99 values, folded into the 11 channels),
    * ``[1089:1101)``— one-hot-style mutation-type vector (12 values).

    Layer names and construction order are fixed to match the trained
    checkpoint's ``state_dict`` keys — do not rename or reorder them.
    """

    def __init__(self):
        super().__init__()

        # Three conv+batchnorm stages over the (batch, 11, 99) sequence view.
        # Kernel sizes 7/5/3 with "same" padding match the checkpoint weights.
        self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)

        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Shared halving pool applied after every conv stage: 99 -> 49 -> 24 -> 12.
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        # Embeds the 12-dim mutation-type vector into 32 features.
        self.mut_fc = nn.Linear(12, 32)

        # Collapses the remaining spatial axis so the conv branch always
        # yields exactly 256 features, regardless of its length.
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)

        # Classifier head over the concatenated 256 + 32 = 288 features.
        self.fc1 = nn.Linear(288, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

        # Secondary head: one scalar importance score from the conv features.
        self.importance_head = nn.Linear(256, 1)

    def forward(self, x):
        """Run the network on a batch of flat feature vectors.

        Args:
            x: ``(batch, 1101)`` float tensor (layout documented on the class).

        Returns:
            Tuple ``(cls, importance)`` of ``(batch, 1)`` sigmoid outputs:
            pathogenicity probability and an importance score.
        """
        n = x.size(0)

        # Split off the trailing 12-dim mutation-type slice ...
        mutation = x[:, 1089:1101]
        # ... and view the first 1089 values as 11 channels of length 99.
        seq = x[:, :1089].view(n, 11, 99)

        # conv -> batchnorm -> relu -> halve, three times (99 -> 49 -> 24 -> 12).
        h = seq
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3)):
            h = self.pool(F.relu(bn(conv(h))))

        # Global average over positions: (batch, 256, L) -> (batch, 256).
        pooled = self.adaptive_pool(h).view(n, 256)

        # Mutation-type embedding: (batch, 12) -> (batch, 32).
        mut_emb = F.relu(self.mut_fc(mutation))

        # (batch, 288) joint representation for the classifier head.
        features = torch.cat([pooled, mut_emb], dim=1)

        hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
        cls = torch.sigmoid(self.fc3(hidden))

        # Importance is computed from the conv features alone.
        importance = torch.sigmoid(self.importance_head(pooled))

        return cls, importance


if __name__ == "__main__":
    # Test the model
    print("Testing MutationPredictorCNN...")
    
    model = MutationPredictorCNN()
    
    # Test input (batch_size=2, features=1101)
    test_input = torch.randn(2, 1101)
    
    cls, importance = model(test_input)
    
    print(f"Input shape: {test_input.shape}")
    print(f"Classification output shape: {cls.shape}")
    print(f"Importance output shape: {importance.shape}")
    
    print("\nModel parameter shapes (should match checkpoint):")
    for name, param in model.named_parameters():
        print(f"{name:30s}: {str(param.shape):20s}")
    
    print("\nExpected parameter shapes from checkpoint:")
    print("conv1.weight                  : torch.Size([64, 11, 7])")
    print("conv3.weight                  : torch.Size([256, 128, 3])")
    print("mut_fc.weight                 : torch.Size([32, 12])")
    print("fc1.weight                    : torch.Size([128, 288])")
    print("importance_head.weight        : torch.Size([1, 256])")