File size: 843 Bytes
d9df210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from torch import nn
import torch.nn.functional as F

def pad_graph_nodes(mol_enc, g_n_nodes):
    """Scatter concatenated per-graph node embeddings into a padded batch.

    Args:
        mol_enc: tensor of shape (sum_nodes, *feat), typically (sum_nodes, D).
            Node embeddings of all graphs concatenated along dim 0, in the
            same order as ``g_n_nodes``.
        g_n_nodes: sequence of non-negative ints (list or 1D tensor) giving
            the number of nodes per graph (len = B). Must sum to
            ``mol_enc.shape[0]``.

    Returns:
        padded: (B, max_nodes, *feat) tensor with the same dtype/device as
            ``mol_enc``; slots past each graph's node count are zero.
        mask:   (B, max_nodes) bool tensor, True for valid nodes.

    Raises:
        ValueError: if any node count is negative, or the counts do not sum
            to ``mol_enc.shape[0]`` (previously a shortfall silently dropped
            rows and an overshoot could silently broadcast a single row).
    """
    lengths = torch.as_tensor(
        g_n_nodes, dtype=torch.long, device=mol_enc.device
    ).reshape(-1)
    if (lengths < 0).any():
        raise ValueError("g_n_nodes must contain non-negative node counts")
    total = int(lengths.sum())
    if total != mol_enc.shape[0]:
        raise ValueError(
            f"g_n_nodes sums to {total} but mol_enc has {mol_enc.shape[0]} rows"
        )

    batch_size = lengths.numel()
    # max() of an empty tensor raises, so an empty batch pads to width 0.
    max_nodes = int(lengths.max()) if batch_size else 0

    # mask[i, j] is True iff slot j of graph i holds a real node.
    positions = torch.arange(max_nodes, device=mol_enc.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

    padded = mol_enc.new_zeros((batch_size, max_nodes, *mol_enc.shape[1:]))
    # Boolean-mask assignment visits the True slots in row-major order
    # (graph by graph, node by node), which is exactly the concatenation
    # order of mol_enc, so one vectorized write replaces the per-graph loop.
    padded[mask] = mol_enc
    return padded, mask