Spaces:
Runtime error
Runtime error
| import torch | |
| from data.features import get_atom_feature_dims, get_bond_feature_dims | |
| full_atom_feature_dims = get_atom_feature_dims() | |
| full_bond_feature_dims = get_bond_feature_dims() | |
| class AtomEncoder(torch.nn.Module): | |
| def __init__(self, emb_dim): | |
| super(AtomEncoder, self).__init__() | |
| self.atom_embedding_list = torch.nn.ModuleList() | |
| for i, dim in enumerate(full_atom_feature_dims): | |
| emb = torch.nn.Embedding(dim, emb_dim) | |
| torch.nn.init.xavier_uniform_(emb.weight.data) | |
| self.atom_embedding_list.append(emb) | |
| def forward(self, x): | |
| x_embedding = 0 | |
| for i in range(x.shape[1]): | |
| x_embedding += self.atom_embedding_list[i](x[:,i]) | |
| return x_embedding | |
| class BondEncoder(torch.nn.Module): | |
| def __init__(self, emb_dim): | |
| super(BondEncoder, self).__init__() | |
| self.bond_embedding_list = torch.nn.ModuleList() | |
| for i, dim in enumerate(full_bond_feature_dims): | |
| emb = torch.nn.Embedding(dim, emb_dim) | |
| torch.nn.init.xavier_uniform_(emb.weight.data) | |
| self.bond_embedding_list.append(emb) | |
| def forward(self, edge_attr): | |
| bond_embedding = 0 | |
| for i in range(edge_attr.shape[1]): | |
| bond_embedding += self.bond_embedding_list[i](edge_attr[:,i]) | |
| return bond_embedding | |
| if __name__ == '__main__': | |
| from loader import GraphClassificationPygDataset | |
| dataset = GraphClassificationPygDataset(name = 'tox21') | |
| atom_enc = AtomEncoder(100) | |
| bond_enc = BondEncoder(100) | |
| print(atom_enc(dataset[0].x)) | |
| print(bond_enc(dataset[0].edge_attr)) |