import torch
from torch import nn
import torch_geometric
from torch_geometric.nn.models import DimeNetPlusPlus


class ModelDimeNet(nn.Module):
    """DimeNet++ encoder followed by a linear head.

    ``self.net`` is PyG's ``DimeNetPlusPlus``; its output is passed through
    ``self.head`` (an ``nn.Linear``) to produce ``out_dim`` values per row
    (a single value by default).

    Every constructor argument defaults to the value this class previously
    hard-coded, so ``ModelDimeNet()`` builds the same architecture, with the
    same parameter names in its ``state_dict``, as before.
    """

    def __init__(
        self,
        *,
        hidden_channels: int = 256,
        out_channels: int = 256,
        num_blocks: int = 4,
        num_spherical: int = 8,
        num_radial: int = 8,
        int_emb_size: int = 64,
        basis_emb_size: int = 64,
        out_emb_channels: int = 64,
        out_dim: int = 1,
        **dimenet_kwargs,
    ) -> None:
        """Build the encoder and head.

        Args:
            hidden_channels, out_channels, num_blocks, num_spherical,
            num_radial, int_emb_size, basis_emb_size, out_emb_channels:
                Forwarded unchanged to ``DimeNetPlusPlus``.
            out_dim: Output width of the linear head (default 1).
            **dimenet_kwargs: Any further ``DimeNetPlusPlus`` keyword
                arguments (e.g. ``cutoff``), forwarded as-is.
        """
        super().__init__()
        self.net = DimeNetPlusPlus(
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            num_spherical=num_spherical,
            num_radial=num_radial,
            int_emb_size=int_emb_size,
            basis_emb_size=basis_emb_size,
            out_emb_channels=out_emb_channels,
            **dimenet_kwargs,
        )
        # The head's input width must equal the encoder's out_channels; derive
        # it instead of repeating the literal so the two cannot drift apart.
        self.head = nn.Linear(out_channels, out_dim)

    def forward(
        self,
        atoms: torch.Tensor,
        coords: torch.Tensor,
        batch: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode a (batch of) molecular graph(s) and apply the linear head.

        Args:
            atoms: Passed as DimeNet++'s ``z`` argument (presumably per-atom
                atomic numbers — confirm against callers).
            coords: Passed as DimeNet++'s ``pos`` argument (atom positions).
            batch: Optional graph-assignment vector passed as DimeNet++'s
                ``batch``; ``None`` is forwarded as-is (PyG treats it as a
                single graph).

        Returns:
            ``self.head`` applied to the encoder output; the last dimension is
            ``out_dim``.
        """
        # NOTE(review): assumes DimeNetPlusPlus returns one out_channels-wide
        # vector per graph (its sum readout) — confirm for the installed PyG.
        emb = self.net(atoms, coords, batch)
        return self.head(emb)