# Source snapshot: commit 5fae7ca (537 bytes).
import torch
from torch import nn
import torch_geometric
from torch_geometric.nn.models import DimeNetPlusPlus
class ModelDimeNet(nn.Module):
    """DimeNet++ encoder followed by a linear regression head.

    ``self.net`` is a ``torch_geometric`` ``DimeNetPlusPlus`` encoder whose
    ``out_channels``-wide output is projected to ``num_targets`` values by
    ``self.head``.  Every constructor argument defaults to the value this
    model previously hard-coded, so ``ModelDimeNet()`` builds exactly the
    same network as before.

    Args:
        hidden_channels: Hidden embedding size of DimeNet++.
        out_channels: Output size of DimeNet++; also the head's input width.
        num_blocks: Number of DimeNet++ interaction blocks.
        int_emb_size: DimeNet++ interaction embedding size.
        basis_emb_size: DimeNet++ basis embedding size.
        out_emb_channels: DimeNet++ output-block embedding size.
        num_spherical: Number of spherical harmonics.
        num_radial: Number of radial basis functions.
        num_targets: Output width of the linear head (previously fixed at 1).
        **dimenet_kwargs: Any further keyword arguments (e.g. ``cutoff``)
            forwarded unchanged to ``DimeNetPlusPlus``.
    """

    def __init__(
        self,
        hidden_channels: int = 256,
        out_channels: int = 256,
        num_blocks: int = 4,
        int_emb_size: int = 64,
        basis_emb_size: int = 64,
        out_emb_channels: int = 64,
        num_spherical: int = 8,
        num_radial: int = 8,
        num_targets: int = 1,
        **dimenet_kwargs,
    ) -> None:
        super().__init__()
        self.net = DimeNetPlusPlus(
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            int_emb_size=int_emb_size,
            basis_emb_size=basis_emb_size,
            out_emb_channels=out_emb_channels,
            num_spherical=num_spherical,
            num_radial=num_radial,
            **dimenet_kwargs,
        )
        # Derive the head's input width from the encoder's output width so
        # the two can never drift apart (the original repeated the literal 256).
        self.head = nn.Linear(out_channels, num_targets)

    def forward(self, atoms, coords, batch=None) -> torch.Tensor:
        """Encode the molecule(s) with DimeNet++ and apply the linear head.

        Args:
            atoms: Passed positionally as DimeNet++'s ``z`` argument
                (atomic numbers, per the PyG documentation).
            coords: Passed positionally as DimeNet++'s ``pos`` argument
                (atom positions).
            batch: Optional graph-assignment vector, passed as DimeNet++'s
                ``batch``.  ``None`` (now the default) is forwarded as-is;
                DimeNet++ then treats the input as a single graph.

        Returns:
            ``self.head`` applied to the DimeNet++ embedding, so the last
            dimension has size ``num_targets``.
        """
        emb = self.net(atoms, coords, batch)
        return self.head(emb)