File size: 537 Bytes
5fae7ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import Optional

import torch
from torch import nn
import torch_geometric
from torch_geometric.nn.models import DimeNetPlusPlus

class ModelDimeNet(nn.Module):
    """DimeNet++ encoder followed by a single linear head.

    The DimeNet++ output (``out_channels`` wide) is projected by one
    ``nn.Linear`` layer to ``num_targets`` values. All constructor arguments
    are keyword-only. Their defaults equal the values this model previously
    hard-coded, so ``ModelDimeNet()`` builds the same network as before.

    Args:
        hidden_channels, out_channels, num_blocks, num_spherical, num_radial,
        int_emb_size, basis_emb_size, out_emb_channels: Forwarded unchanged
            to ``DimeNetPlusPlus``; see the PyG documentation for their
            meaning. ``out_channels`` also sets the head's input width, so
            the two layers always agree.
        num_targets: Output width of the linear head (default ``1``).
        **dimenet_kwargs: Any further keyword arguments, forwarded unchanged
            to ``DimeNetPlusPlus`` (consult the PyG docs for accepted names).
    """

    def __init__(
        self,
        *,
        hidden_channels: int = 256,
        out_channels: int = 256,
        num_blocks: int = 4,
        num_spherical: int = 8,
        num_radial: int = 8,
        int_emb_size: int = 64,
        basis_emb_size: int = 64,
        out_emb_channels: int = 64,
        num_targets: int = 1,
        **dimenet_kwargs,
    ) -> None:
        super().__init__()

        self.net = DimeNetPlusPlus(
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            num_spherical=num_spherical,
            num_radial=num_radial,
            int_emb_size=int_emb_size,
            basis_emb_size=basis_emb_size,
            out_emb_channels=out_emb_channels,
            **dimenet_kwargs,
        )
        # Tie the head's input width to the encoder's output width instead of
        # repeating a magic number, so changing out_channels cannot break it.
        self.head = nn.Linear(out_channels, num_targets)

    def forward(
        self,
        atoms: torch.Tensor,
        coords: torch.Tensor,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode the input with DimeNet++ and apply the linear head.

        Args:
            atoms: Passed as DimeNet++'s first positional argument (``z``).
                Presumably per-atom atomic numbers; confirm against the caller.
            coords: Passed as DimeNet++'s second positional argument
                (``pos``). Presumably per-atom 3-D positions; confirm against
                the caller.
            batch: Optional per-atom graph-assignment tensor, forwarded to
                DimeNet++. Defaults to ``None``; per the PyG docs, DimeNet++
                then treats all atoms as a single graph.

        Returns:
            ``self.head`` applied to the DimeNet++ output. The last dimension
            has size ``num_targets``.
        """
        embedding = self.net(atoms, coords, batch)
        return self.head(embedding)