nileshhanotia committed on
Commit
d989594
·
verified ·
1 Parent(s): 766ebcf

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +118 -64
model.py CHANGED
@@ -1,74 +1,128 @@
1
  import torch
2
  import torch.nn as nn
 
3
 
4
  class MutationPredictorCNN(nn.Module):
 
 
 
 
5
 
6
  def __init__(self):
7
  super().__init__()
8
 
9
- # convolution backbone (MUST match training)
10
- self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3)
11
- self.bn1 = nn.BatchNorm1d(64)
12
-
13
- self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
14
- self.bn2 = nn.BatchNorm1d(128)
15
-
16
- self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
17
- self.bn3 = nn.BatchNorm1d(256)
18
-
19
- self.global_pool = nn.AdaptiveAvgPool1d(1)
20
-
21
- # mutation branch
22
- self.mut_fc = nn.Linear(12, 32)
23
-
24
- # explainability head
25
- self.importance_head = nn.Linear(256, 1)
26
-
27
- # classifier head
28
- self.fc1 = nn.Linear(288, 128)
29
- self.fc2 = nn.Linear(128, 64)
30
- self.fc3 = nn.Linear(64, 1)
31
-
32
- self.relu = nn.ReLU()
33
- self.dropout = nn.Dropout(0.4)
 
 
 
 
 
 
 
 
 
34
 
35
  def forward(self, x):
36
-
37
- seq_flat = x[:, :1089]
38
- mut_one = x[:, 1089:1101]
39
-
40
- h = self.relu(self.bn1(self.conv1(seq_flat.view(-1, 11, 99))))
41
- h = self.relu(self.bn2(self.conv2(h)))
42
-
43
- conv_out = self.relu(self.bn3(self.conv3(h)))
44
-
45
- # full explainability map
46
- importance_map = torch.sigmoid(
47
- self.importance_head(conv_out.permute(0,2,1))
48
- ).squeeze(-1)
49
-
50
- # mutation position
51
- mut_pos = x[:, 990:1089].argmax(dim=1).clamp(0,98)
52
-
53
- mut_feat = conv_out[
54
- torch.arange(conv_out.size(0)),
55
- :,
56
- mut_pos
57
- ]
58
-
59
- importance_score = torch.sigmoid(
60
- self.importance_head(mut_feat)
61
- )
62
-
63
- pooled = self.global_pool(conv_out).squeeze(-1)
64
-
65
- mut_vec = self.relu(self.mut_fc(mut_one))
66
-
67
- combined = torch.cat([pooled, mut_vec], dim=1)
68
-
69
- out = self.dropout(self.relu(self.fc1(combined)))
70
- out = self.dropout(self.relu(self.fc2(out)))
71
-
72
- logit = self.fc3(out)
73
-
74
- return logit, importance_score, importance_map
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
 
5
  class MutationPredictorCNN(nn.Module):
6
+ """
7
+ Mutation Pathogenicity Predictor CNN
8
+ Architecture matches the trained checkpoint weights
9
+ """
10
 
11
  def __init__(self):
12
  super().__init__()
13
 
14
+ # Convolutional layers - CORRECTED kernel sizes to match checkpoint
15
+ self.conv1 = nn.Conv1d(11, 64, kernel_size=7, padding=3) # Changed from 5 to 7
16
+ self.bn1 = nn.BatchNorm1d(64)
17
+
18
+ self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2) # Correct
19
+ self.bn2 = nn.BatchNorm1d(128)
20
+
21
+ self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1) # Changed from 5 to 3
22
+ self.bn3 = nn.BatchNorm1d(256)
23
+
24
+ # Pooling layers to reduce dimensions
25
+ self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
26
+
27
+ # Mutation type processing - CORRECTED
28
+ self.mut_fc = nn.Linear(12, 32) # Changed from (256, 1) to (12, 32)
29
+
30
+ # After 3 pooling layers: 99 -> 49 -> 24 -> 12
31
+ # Conv output: 256 channels * 12 positions = 3072
32
+ # Mutation features: 32
33
+ # Total: 3072 + 32 = 3104
34
+ # BUT checkpoint shows fc1 is (128, 288)
35
+ # So we need adaptive pooling to get 256 features from conv
36
+
37
+ self.adaptive_pool = nn.AdaptiveAvgPool1d(1) # Pool to single value per channel
38
+
39
+ # Fully connected layers - CORRECTED to match checkpoint
40
+ # Input: 256 (conv) + 32 (mutation) = 288
41
+ self.fc1 = nn.Linear(288, 128) # Changed from (256*99, 256) to (288, 128)
42
+ self.fc2 = nn.Linear(128, 64) # Changed from (256, 64) to (128, 64)
43
+ self.fc3 = nn.Linear(64, 1) # Correct
44
+
45
+ # Importance head - CORRECTED
46
+ # Takes conv features (256) and outputs single importance score
47
+ self.importance_head = nn.Linear(256, 1) # Changed from (256*99, 99) to (256, 1)
48
 
49
  def forward(self, x):
50
+ """
51
+ Forward pass
52
+
53
+ Args:
54
+ x: Input tensor (batch, 1101)
55
+ [0:990] - sequence features (99*10)
56
+ [990:1089] - difference mask (99)
57
+ [1089:1101] - mutation type (12)
58
+
59
+ Returns:
60
+ cls: Classification output (batch, 1)
61
+ importance: Importance score (batch, 1)
62
+ """
63
+ batch_size = x.size(0)
64
+
65
+ # Extract mutation type for separate processing
66
+ mut_type = x[:, 1089:1101] # Last 12 dimensions
67
+
68
+ # Reshape remaining features for CNN
69
+ # First 1089 features -> reshape to (batch, 11, 99)
70
+ x_seq = x[:, :1089].view(batch_size, 11, 99)
71
+
72
+ # Convolutional layers with pooling
73
+ x_conv = F.relu(self.bn1(self.conv1(x_seq)))
74
+ x_conv = self.pool(x_conv) # 99 -> 49
75
+
76
+ x_conv = F.relu(self.bn2(self.conv2(x_conv)))
77
+ x_conv = self.pool(x_conv) # 49 -> 24
78
+
79
+ x_conv = F.relu(self.bn3(self.conv3(x_conv)))
80
+ x_conv = self.pool(x_conv) # 24 -> 12
81
+
82
+ # Adaptive pooling to get fixed 256 features
83
+ x_conv = self.adaptive_pool(x_conv) # (batch, 256, 1)
84
+ conv_features = x_conv.view(batch_size, 256) # (batch, 256)
85
+
86
+ # Process mutation type
87
+ mut_features = F.relu(self.mut_fc(mut_type)) # (batch, 32)
88
+
89
+ # Concatenate features
90
+ combined = torch.cat([conv_features, mut_features], dim=1) # (batch, 288)
91
+
92
+ # Classification branch
93
+ x = F.relu(self.fc1(combined))
94
+ x = F.relu(self.fc2(x))
95
+ cls = torch.sigmoid(self.fc3(x)) # (batch, 1)
96
+
97
+ # Importance branch (uses conv features)
98
+ importance = torch.sigmoid(self.importance_head(conv_features)) # (batch, 1)
99
+
100
+ return cls, importance
101
+
102
+
103
+ if __name__ == "__main__":
104
+ # Test the model
105
+ print("Testing MutationPredictorCNN...")
106
+
107
+ model = MutationPredictorCNN()
108
+
109
+ # Test input (batch_size=2, features=1101)
110
+ test_input = torch.randn(2, 1101)
111
+
112
+ cls, importance = model(test_input)
113
+
114
+ print(f"Input shape: {test_input.shape}")
115
+ print(f"Classification output shape: {cls.shape}")
116
+ print(f"Importance output shape: {importance.shape}")
117
+
118
+ print("\nModel parameter shapes (should match checkpoint):")
119
+ for name, param in model.named_parameters():
120
+ print(f"{name:30s}: {str(param.shape):20s}")
121
+
122
+ print("\nExpected parameter shapes from checkpoint:")
123
+ print("conv1.weight : torch.Size([64, 11, 7])")
124
+ print("conv3.weight : torch.Size([256, 128, 3])")
125
+ print("mut_fc.weight : torch.Size([32, 12])")
126
+ print("fc1.weight : torch.Size([128, 288])")
127
+ print("importance_head.weight : torch.Size([1, 256])")
128
+