| import torch | |
| from torch import nn | |
| import torch_geometric | |
| from torch_geometric.nn.models import DimeNetPlusPlus | |
class ModelDimeNet(nn.Module):
    """DimeNet++ backbone followed by a single linear head.

    ``ModelDimeNet()`` with no arguments builds exactly the original
    configuration:

    - hidden/out channels 256, 4 interaction blocks;
    - 8 spherical / 8 radial basis functions;
    - 64-wide interaction, basis and output embeddings;
    - a head producing one value.

    Every hyperparameter can now be overridden by keyword.
    """

    def __init__(
        self,
        *,
        hidden_channels: int = 256,
        out_channels: int = 256,
        num_blocks: int = 4,
        num_spherical: int = 8,
        num_radial: int = 8,
        int_emb_size: int = 64,
        basis_emb_size: int = 64,
        out_emb_channels: int = 64,
        out_dim: int = 1,
        **dimenet_kwargs,
    ):
        """Build the DimeNet++ backbone and the linear head.

        Args:
            hidden_channels, out_channels, num_blocks, num_spherical,
            num_radial, int_emb_size, basis_emb_size, out_emb_channels:
                Forwarded unchanged to ``DimeNetPlusPlus``. Defaults
                reproduce the original hard-coded values.
            out_dim: Number of outputs produced by the head (originally 1).
            **dimenet_kwargs: Any further ``DimeNetPlusPlus`` keyword
                arguments (e.g. ``cutoff``, ``max_num_neighbors``). They are
                passed through as-is.
        """
        super().__init__()
        self.net = DimeNetPlusPlus(
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            num_spherical=num_spherical,
            num_radial=num_radial,
            int_emb_size=int_emb_size,
            basis_emb_size=basis_emb_size,
            out_emb_channels=out_emb_channels,
            **dimenet_kwargs,
        )
        # The head consumes the backbone's output, so its input width must
        # follow out_channels instead of repeating the literal 256.
        self.head = nn.Linear(out_channels, out_dim)

    def forward(
        self,
        atoms: torch.Tensor,
        coords: torch.Tensor,
        batch: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Run the backbone on (atoms, coords, batch) and apply the head.

        NOTE(review): per the PyG ``DimeNetPlusPlus.forward(z, pos, batch)``
        signature, the inputs presumably are:

        - ``atoms``: atomic numbers;
        - ``coords``: 3-D positions;
        - ``batch``: a per-atom graph index.

        With ``batch`` the backbone presumably returns one pooled embedding
        per graph; with ``batch=None`` it sums over all atoms. Confirm both
        against the installed PyG version.
        """
        emb = self.net(atoms, coords, batch)
        return self.head(emb)