Create collating_graphormer.pyx
Browse files- collating_graphormer.pyx +134 -0
collating_graphormer.pyx
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation and HuggingFace
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
from typing import Any, Dict, List, Mapping
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from ...utils import is_cython_available, requires_backends
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
if is_cython_available():
|
| 13 |
+
import pyximport
|
| 14 |
+
|
| 15 |
+
pyximport.install(setup_args={"include_dirs": np.get_include()})
|
| 16 |
+
from . import algos_graphormer # noqa E402
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def convert_to_single_emb(x, offset: int = 512):
    """Return `x` with a distinct additive shift applied to each feature column.

    Column `j` of `x` is increased by `1 + j * offset`. Presumably this keeps the ids of
    different feature columns in disjoint ranges so they can share a single embedding
    table (inferred from the function name; confirm against the model's embedding size).

    Args:
        x: Array exposing `.shape` and supporting `+` with a NumPy array. When it has more
            than one dimension, `x.shape[1]` is taken as the number of feature columns;
            otherwise a single feature is assumed.
        offset: Distance between the shifts of two consecutive feature columns.

    Returns:
        A new array holding the shifted values; `x` itself is not modified.
    """
    if len(x.shape) > 1:
        n_features = x.shape[1]
    else:
        n_features = 1

    # One start value per feature column: 0, offset, 2 * offset, ...
    column_starts = np.arange(0, n_features * offset, offset, dtype=np.int64)

    # The extra +1 moves every id up by one on top of the per-column start.
    return x + (column_starts + 1)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def preprocess_item(item, keep_features=True):
    """Compute the Graphormer input arrays for one graph and store them on `item`.

    Reads `item["edge_index"]` and `item["num_nodes"]`, and additionally `item["edge_attr"]` /
    `item["node_feat"]` when those keys exist and `keep_features` is true. Writes the keys
    `input_nodes`, `attn_bias`, `attn_edge_type`, `spatial_pos`, `in_degree`, `out_degree`
    and `input_edges`; if `labels` is missing it is copied from `item["y"]`.

    Args:
        item: Mapping-like graph example; it is updated in place.
        keep_features: When false, node and edge features found on `item` are ignored and
            replaced by constant ones, so that all nodes (resp. edges) share one embedding.

    Returns:
        The same `item` object, with the keys listed above added.

    Raises:
        Whatever `requires_backends` raises when the "cython" backend is unavailable.
    """
    requires_backends(preprocess_item, ["cython"])

    # Edge features: the provided ones, or a constant column (same embedding for all edges).
    use_edge_features = keep_features and "edge_attr" in item.keys()
    if use_edge_features:
        edge_features = np.asarray(item["edge_attr"], dtype=np.int64)
    else:
        edge_features = np.ones((len(item["edge_index"][0]), 1), dtype=np.int64)

    # Node features: the provided ones, or a constant column (same embedding for all nodes).
    use_node_features = keep_features and "node_feat" in item.keys()
    if use_node_features:
        node_features = np.asarray(item["node_feat"], dtype=np.int64)
    else:
        node_features = np.ones((item["num_nodes"], 1), dtype=np.int64)

    edges = np.asarray(item["edge_index"], dtype=np.int64)

    node_ids = convert_to_single_emb(node_features) + 1
    n_nodes = item["num_nodes"]

    # A flat edge feature vector is treated as a single feature column.
    if edge_features.ndim == 1:
        edge_features = edge_features[:, None]

    # Dense [n_nodes, n_nodes, n_edge_features] table of edge ids; 0 where there is no edge.
    edge_type = np.zeros([n_nodes, n_nodes, edge_features.shape[-1]], dtype=np.int64)
    sources, targets = edges[0], edges[1]
    edge_type[sources, targets] = convert_to_single_emb(edge_features) + 1

    # Boolean node adjacency matrix [n_nodes, n_nodes].
    adjacency = np.zeros([n_nodes, n_nodes], dtype=bool)
    adjacency[sources, targets] = True

    # Cython helpers; presumably all-pairs shortest-path lengths and the matching path
    # table, then the edge features gathered along each path -- confirm in algos_graphormer.
    distances, path = algos_graphormer.floyd_warshall(adjacency)
    longest = np.amax(distances)
    edge_paths = algos_graphormer.gen_edge_input(longest, path, edge_type)

    # One extra row/column compared to the node count (for the graph token).
    bias = np.zeros([n_nodes + 1, n_nodes + 1], dtype=np.single)

    # Store the results; the trailing +1 shifts ids by one so that 0 is left for padding.
    # Note that node ids were already shifted once above, in addition to this shift.
    item["input_nodes"] = node_ids + 1
    item["attn_bias"] = bias
    item["attn_edge_type"] = edge_type
    item["spatial_pos"] = distances.astype(np.int64) + 1
    item["in_degree"] = adjacency.sum(axis=1).reshape(-1) + 1
    item["out_degree"] = item["in_degree"]  # same values reused: graph is treated as undirected
    item["input_edges"] = edge_paths + 1
    if "labels" not in item:
        item["labels"] = item["y"]

    return item
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class GraphormerDataCollator:
    """Collate preprocessed graph examples into one zero-padded batch of tensors.

    Every example must provide the keys `attn_bias`, `attn_edge_type`, `spatial_pos`,
    `in_degree`, `input_nodes`, `input_edges` and `labels` (as written by `preprocess_item`),
    or `on_the_fly_processing` must be enabled so they are computed at collation time.

    Args:
        spatial_pos_max: Node pairs whose `spatial_pos` value is greater than or equal to this
            threshold get an attention bias of `-inf`.
        on_the_fly_processing: When true, `preprocess_item` is applied to every example
            before collation.

    Raises:
        ImportError: On construction, if Cython is not available.
    """

    # Per-example entries that are converted to tensors and padded into the batch.
    _PADDED_KEYS = ("attn_bias", "attn_edge_type", "spatial_pos", "in_degree", "input_nodes", "input_edges")

    def __init__(self, spatial_pos_max=20, on_the_fly_processing=False):
        if not is_cython_available():
            raise ImportError("Graphormer preprocessing needs Cython (pyximport)")

        self.spatial_pos_max = spatial_pos_max
        self.on_the_fly_processing = on_the_fly_processing

    def __call__(self, features: List[dict]) -> Dict[str, Any]:
        """Build the padded batch for `features`.

        The per-example arrays are copied into local tensors, so the examples passed in are
        left untouched (apart from the keys `preprocess_item` adds when on-the-fly processing
        is enabled) and can safely be collated again.

        Args:
            features: Non-empty list of examples, either mappings or objects whose attributes
                (via `vars`) hold the required keys.

        Returns:
            Dict with `attn_bias`, `attn_edge_type`, `spatial_pos`, `in_degree`, `input_nodes`,
            `input_edges`, `out_degree` (the same tensor object as `in_degree`) and `labels`.
        """
        if self.on_the_fly_processing:
            features = [preprocess_item(i) for i in features]

        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]

        # Batch-wide padded sizes; the feature widths are taken from the first example.
        batch_size = len(features)
        max_node_num = max(len(f["input_nodes"]) for f in features)
        node_feat_size = len(features[0]["input_nodes"][0])
        edge_feat_size = len(features[0]["attn_edge_type"][0][0])
        max_dist = max(len(f["input_edges"][0][0]) for f in features)
        edge_input_size = len(features[0]["input_edges"][0][0][0])

        batch = {
            # attn_bias has one extra row/column compared to the node count.
            "attn_bias": torch.zeros(batch_size, max_node_num + 1, max_node_num + 1, dtype=torch.float),
            "attn_edge_type": torch.zeros(
                batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long
            ),
            "spatial_pos": torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long),
            "in_degree": torch.zeros(batch_size, max_node_num, dtype=torch.long),
            "input_nodes": torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long),
            "input_edges": torch.zeros(
                batch_size, max_node_num, max_node_num, max_dist, edge_input_size, dtype=torch.long
            ),
        }

        for ix, f in enumerate(features):
            # torch.tensor copies its input, so the -inf masking below never reaches the
            # caller's arrays, and the caller's example keeps its original value types.
            tensors = {key: torch.tensor(f[key]) for key in self._PADDED_KEYS}

            attn_bias = tensors["attn_bias"]
            spatial_pos = tensors["spatial_pos"]
            # Mask node pairs that are too far apart; a no-op when no pair qualifies.
            too_far = spatial_pos >= self.spatial_pos_max
            attn_bias[1:, 1:][too_far] = float("-inf")

            attn_edge_type = tensors["attn_edge_type"]
            in_degree = tensors["in_degree"]
            input_nodes = tensors["input_nodes"]
            input_edges = tensors["input_edges"]

            # Copy each example into the top-left corner of its padded slot.
            batch["attn_bias"][ix, : attn_bias.shape[0], : attn_bias.shape[1]] = attn_bias
            batch["attn_edge_type"][ix, : attn_edge_type.shape[0], : attn_edge_type.shape[1], :] = attn_edge_type
            batch["spatial_pos"][ix, : spatial_pos.shape[0], : spatial_pos.shape[1]] = spatial_pos
            batch["in_degree"][ix, : in_degree.shape[0]] = in_degree
            batch["input_nodes"][ix, : input_nodes.shape[0], :] = input_nodes
            batch["input_edges"][
                ix, : input_edges.shape[0], : input_edges.shape[1], : input_edges.shape[2], :
            ] = input_edges

        batch["out_degree"] = batch["in_degree"]

        labels = [f["labels"] for f in features]
        if len(features[0]["labels"]) == 1:
            # Single task (regression or binary classification): one flat vector of labels.
            stacked = np.concatenate(labels)
        else:
            # Multi-task classification: one row per example; dtype is left as produced by
            # NumPy so that NaN entries are kept.
            stacked = np.stack(labels, axis=0)
        batch["labels"] = torch.from_numpy(stacked)

        return batch
|