# vscf_mlff / utils_model.py
# Author: timcryt — "Initial commit" (5fae7ca).
import torch
from torch import nn
import torch_geometric
from torch_geometric.nn.models import DimeNetPlusPlus
class ModelDimeNet(nn.Module):
    """DimeNet++ backbone followed by a linear head that outputs one scalar per row.

    The backbone is ``torch_geometric``'s ``DimeNetPlusPlus``. Its output is fed
    through ``nn.Linear(out_channels, 1)``, so the last dimension of the result
    is 1. Presumably the scalar is a per-molecule energy for the ML force field
    (NOTE(review): inferred from the package name, confirm with the training code).

    All constructor arguments default to the values this model was originally
    hard-coded with. ``ModelDimeNet()`` therefore builds the same architecture,
    with the same ``state_dict`` keys (``net.*``, ``head.*``), as before, and
    existing checkpoints still load.
    """

    def __init__(
        self,
        *,
        hidden_channels: int = 256,
        out_channels: int = 256,
        num_blocks: int = 4,
        num_spherical: int = 8,
        num_radial: int = 8,
        int_emb_size: int = 64,
        basis_emb_size: int = 64,
        out_emb_channels: int = 64,
        **dimenet_kwargs,
    ):
        """Build the DimeNet++ backbone and the scalar output head.

        Args:
            hidden_channels: Hidden embedding size of DimeNet++.
            out_channels: Width of the DimeNet++ output. It is also used as the
                input width of the linear head, so the two always agree.
            num_blocks: Number of DimeNet++ interaction blocks.
            num_spherical: Number of spherical harmonics in the basis.
            num_radial: Number of radial basis functions.
            int_emb_size: Embedding size used inside the interaction blocks.
            basis_emb_size: Embedding size used for the basis transformation.
            out_emb_channels: Embedding size used in the output blocks.
            **dimenet_kwargs: Any other keyword arguments, forwarded unchanged
                to ``DimeNetPlusPlus`` (e.g. ``cutoff``). Their names must
                match that constructor's parameters.
        """
        super().__init__()
        self.net = DimeNetPlusPlus(
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            num_spherical=num_spherical,
            num_radial=num_radial,
            int_emb_size=int_emb_size,
            basis_emb_size=basis_emb_size,
            out_emb_channels=out_emb_channels,
            **dimenet_kwargs,
        )
        # The head's input width is derived from the backbone's output width.
        # The original hard-coded 256 in both places, so editing one alone
        # would have produced a shape mismatch at runtime.
        self.head = nn.Linear(out_channels, 1)

    def forward(
        self,
        atoms: torch.Tensor,
        coords: torch.Tensor,
        batch: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Run DimeNet++ on the inputs and project its output to one scalar.

        Args:
            atoms: Passed as DimeNet++'s first argument ``z``, which per PyG
                docs holds the atomic numbers (assumed shape ``[num_atoms]``;
                confirm against the data loader).
            coords: Passed as DimeNet++'s ``pos``, the atom positions
                (assumed shape ``[num_atoms, 3]``; confirm).
            batch: Passed as DimeNet++'s ``batch``, the graph index of each
                atom. ``None`` is forwarded as-is; per PyG, DimeNet++ then
                treats all atoms as a single graph.

        Returns:
            ``self.head`` applied to the DimeNet++ output, so the last
            dimension is 1. Per PyG docs, DimeNet++ pools atom contributions
            per graph, which would make this ``[num_graphs, 1]`` (confirm).
        """
        emb = self.net(atoms, coords, batch)
        return self.head(emb)