import torch
import torch.nn as nn


# TODO: this is a simple model; refactor it so it can be configured entirely from a config file
class DeepDTA(nn.Module):
    """
    Drug-target affinity model from DeepDTA.

    Embeds tokenized drug and protein sequences, encodes each with a
    caller-supplied 1D-CNN, concatenates the two representations, and
    projects them through a fully connected head.
    """

    def __init__(
        self,
        drug_cnn: nn.Module,
        protein_cnn: nn.Module,
        num_features_drug: int,
        num_features_protein: int,
        embed_dim: int,
    ):
        super().__init__()
        # sequence encoders (1D conv stacks) supplied by the caller
        self.drug_cnn = drug_cnn
        self.protein_cnn = protein_cnn
        # token embeddings for the drug (SMILES) and protein alphabets
        self.drug_embedding = nn.Embedding(num_features_drug, embed_dim)
        self.protein_embedding = nn.Embedding(num_features_protein, embed_dim)
        # fully connected head over the concatenated features;
        # LazyLinear infers its input width on the first forward pass
        self.fc = nn.Sequential(
            nn.LazyLinear(1024),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, v_d, v_p):
        # v_d: (batch, drug_len) token ids; v_p: (batch, protein_len) token ids
        v_d = self.drug_embedding(v_d.long())
        v_d = self.drug_cnn(v_d)
        v_p = self.protein_embedding(v_p.long())
        v_p = self.protein_cnn(v_p)
        # concatenate the per-sample feature vectors and map to a 512-d joint representation
        v_f = torch.cat([v_d, v_p], dim=1)
        v_f = self.fc(v_f)
        return v_f
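# Since the CNN encoders are injected by the caller, here is a minimal usage
# sketch. Flatten1DCNN, the vocabulary sizes, and the sequence lengths below
# are assumptions for illustration, not the encoders this repo actually uses;
# the only hard requirement is that each encoder maps the embedded sequence
# to a flat per-sample feature vector so torch.cat along dim 1 is valid.


class Flatten1DCNN(nn.Module):
    """Hypothetical stand-in encoder: takes (batch, seq_len, embed_dim),
    convolves over the sequence axis, then global-max-pools to a flat vector."""

    def __init__(self, embed_dim: int, num_filters: int = 32):
        super().__init__()
        self.conv = nn.Conv1d(embed_dim, num_filters, kernel_size=8)

    def forward(self, x):
        x = x.transpose(1, 2)        # (batch, embed_dim, seq_len) for Conv1d
        x = torch.relu(self.conv(x))
        return x.max(dim=2).values   # global max pool -> (batch, num_filters)


if __name__ == "__main__":
    model = DeepDTA(
        drug_cnn=Flatten1DCNN(embed_dim=128),
        protein_cnn=Flatten1DCNN(embed_dim=128),
        num_features_drug=64,       # SMILES vocabulary size (assumed)
        num_features_protein=26,    # amino-acid vocabulary size (assumed)
        embed_dim=128,
    )
    drug_tokens = torch.randint(0, 64, (4, 100))      # (batch, drug_len)
    protein_tokens = torch.randint(0, 26, (4, 1000))  # (batch, protein_len)
    features = model(drug_tokens, protein_tokens)     # (4, 512) joint features
    print(features.shape)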