add egt model
Browse files- egt_model/__init__.py +57 -0
- egt_model/collating_egt.py +103 -0
- egt_model/configuration_egt.py +115 -0
- egt_model/modeling_egt.py +256 -0
- share_model.py +15 -0
egt_model/__init__.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available


# Lazy-import table: the configuration is always importable; the modeling
# module is added below only when torch is available.
_import_structure = {
    "configuration_egt": ["EGT_PRETRAINED_CONFIG_ARCHIVE_MAP", "EGTConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # torch is present, so the model classes can be exposed as well.
    _import_structure["modeling_egt"] = [
        "EGT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "EGTForGraphClassification",
        "EGTModel",
        "EGTPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Static type checkers see the real imports so the names resolve.
    from .configuration_egt import EGT_PRETRAINED_CONFIG_ARCHIVE_MAP, EGTConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_egt import (
            EGT_PRETRAINED_MODEL_ARCHIVE_LIST,
            EGTForGraphClassification,
            EGTModel,
            EGTPreTrainedModel,
        )


else:
    import sys

    # At runtime, replace this module with a proxy that imports each submodule
    # on first attribute access (standard transformers lazy-loading pattern).
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
egt_model/collating_egt.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Mapping
|
| 2 |
+
|
| 3 |
+
import dgl
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def convert_to_single_node_emb(x, offset: int = 128):
    """Shift each node-feature column into its own disjoint embedding range.

    Column ``j`` is mapped to ``1 + j * offset + value`` so that all columns can
    share a single embedding table while index 0 stays reserved for padding.
    """
    num_features = x.shape[1] if x.ndim > 1 else 1
    column_shifts = 1 + np.arange(0, num_features * offset, offset, dtype=np.int64)
    return x + column_shifts
|
| 15 |
+
def convert_to_single_edge_emb(x, offset: int = 8):
    """Shift each edge-feature column into its own disjoint embedding range.

    Same scheme as ``convert_to_single_node_emb`` but with a smaller per-column
    offset, matching the smaller edge-feature vocabulary.
    """
    num_features = x.shape[1] if x.ndim > 1 else 1
    column_shifts = 1 + np.arange(0, num_features * offset, offset, dtype=np.int64)
    return x + column_shifts
| 21 |
+
|
| 22 |
+
def preprocess_item(item, keep_features=True):
    """Precompute EGT model inputs for a single graph dict.

    Args:
        item: Mapping with "edge_index" (2, num_edges) and "num_nodes", and
            optionally "edge_attr", "node_feat" and "y".
        keep_features: If False, ignore any provided node/edge features and use
            a constant feature (same embedding for every node/edge).

    Returns:
        The same mapping, augmented with "input_nodes", "attn_edge_type",
        "spatial_pos", "svd_pe" and "labels".
    """
    if keep_features and "edge_attr" in item.keys():  # edge_attr
        edge_attr = np.asarray(item["edge_attr"], dtype=np.int64)
    else:
        edge_attr = np.ones((len(item["edge_index"][0]), 1), dtype=np.int64)  # same embedding for all

    if keep_features and "node_feat" in item.keys():  # input_nodes
        node_feature = np.asarray(item["node_feat"], dtype=np.int64)
    else:
        node_feature = np.ones((item["num_nodes"], 1), dtype=np.int64)  # same embedding for all

    edge_index = np.asarray(item["edge_index"], dtype=np.int64)

    input_nodes = convert_to_single_node_emb(node_feature)
    num_nodes = item["num_nodes"]

    if len(edge_attr.shape) == 1:
        edge_attr = edge_attr[:, None]
    attn_edge_type = np.zeros([num_nodes, num_nodes, edge_attr.shape[-1]], dtype=np.int64)
    attn_edge_type[edge_index[0], edge_index[1]] = convert_to_single_edge_emb(edge_attr)

    # convert to dgl graph for computing shortest path distance and svd encodings
    # BUG FIX: pass num_nodes explicitly. Without it dgl.graph infers the node
    # count from the largest endpoint index, so graphs whose trailing nodes are
    # isolated get undersized spatial_pos/svd_pe matrices that no longer line up
    # with input_nodes (and with the collator's attn_mask, keyed on svd_pe).
    g = dgl.graph((edge_index[0], edge_index[1]), num_nodes=num_nodes)
    shortest_path_result = dgl.shortest_dist(g)
    # Unreachable pairs come back as -1; remap them to the large sentinel 510
    # (the model clamps distances to upto_hop + 1 in its input block anyway).
    shortest_path_result = torch.where(shortest_path_result == -1, 510, shortest_path_result)
    svd_pe = dgl.svd_pe(g, k=8, padding=True, random_flip=True)

    # combine
    item["input_nodes"] = input_nodes
    item["attn_edge_type"] = attn_edge_type
    item["spatial_pos"] = shortest_path_result
    item["svd_pe"] = svd_pe
    if "labels" not in item:
        item["labels"] = item["y"]

    return item
|
| 59 |
+
|
| 60 |
+
class EGTDataCollator:
    """Collate preprocessed graph dicts into padded, batched tensors for EGT.

    Every graph in the batch is padded to the largest node count; "attn_mask"
    marks the real (non-padding) nodes with 1.

    Args:
        on_the_fly_processing: If True, run `preprocess_item` on each feature
            dict at collation time instead of expecting preprocessed inputs.
    """

    def __init__(self, on_the_fly_processing=False):
        self.on_the_fly_processing = on_the_fly_processing

    def __call__(self, features: List[dict]) -> Dict[str, Any]:
        if self.on_the_fly_processing:
            features = [preprocess_item(i) for i in features]

        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]
        batch = {}

        # Pad everything to the largest graph in the batch.
        max_node_num = max(len(i["input_nodes"]) for i in features)
        node_feat_size = len(features[0]["input_nodes"][0])
        edge_feat_size = len(features[0]["attn_edge_type"][0][0])
        svd_pe_size = len(features[0]["svd_pe"][0]) // 2
        batch_size = len(features)

        batch["featm"] = torch.zeros(batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long)
        batch["dm"] = torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long)
        batch["node_feat"] = torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long)
        batch["svd_pe"] = torch.zeros(batch_size, max_node_num, svd_pe_size * 2, dtype=torch.float)
        batch["attn_mask"] = torch.zeros(batch_size, max_node_num, dtype=torch.long)

        for ix, f in enumerate(features):
            for k in ["attn_edge_type", "spatial_pos", "input_nodes", "svd_pe"]:
                f[k] = torch.tensor(f[k])

            # Copy each graph into the top-left corner of its padded slot.
            batch["featm"][ix, : f["attn_edge_type"].shape[0], : f["attn_edge_type"].shape[1], :] = f["attn_edge_type"]
            batch["dm"][ix, : f["spatial_pos"].shape[0], : f["spatial_pos"].shape[1]] = f["spatial_pos"]
            batch["node_feat"][ix, : f["input_nodes"].shape[0], :] = f["input_nodes"]
            batch["svd_pe"][ix, : f["svd_pe"].shape[0], :] = f["svd_pe"]
            batch["attn_mask"][ix, : f["svd_pe"].shape[0]] = 1

        sample = features[0]["labels"]
        if len(sample) == 1:
            # Single task (regression or binary classification): one label per
            # graph, concatenated into a flat (batch,) tensor. The original two
            # branches here were identical, so they are merged.
            batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
        else:  # multi task classification, left to float to keep the NaNs
            # BUG FIX: np.stack takes `axis`, not `dim` (`dim=0` raised TypeError).
            batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))

        return batch
egt_model/configuration_egt.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
+
from transformers.utils import logging
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
logger = logging.get_logger(__name__)

# Maps checkpoint shorthand names to hosted config URLs.
# NOTE(review): the key/URL still point at a Graphormer checkpoint — this looks
# like a copied placeholder; confirm there is a real EGT config to reference.
EGT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    # pcqm4mv1 now deprecated
    "graphormer-base": "https://huggingface.co/clefourrier/graphormer-base-pcqm4mv2/resolve/main/config.json",
    # See all Graphormer models at https://huggingface.co/models?filter=graphormer
}
+
|
| 15 |
+
|
| 16 |
+
class EGTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~EGTModel`]. It is used to instantiate an
    EGT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the EGT
    [graphormer-base-pcqm4mv1](https://huggingface.co/graphormer-base-pcqm4mv1) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        feat_size (`int`, *optional*, defaults to 768):
            Node feature size.
        edge_feat_size (`int`, *optional*, defaults to 64):
            Edge feature size.
        num_heads (`int`, *optional*, defaults to 32):
            Number of attention heads, by which :attr: `feat_size` is divisible.
        num_layers (`int`, *optional*, defaults to 30):
            Number of layers.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability.
        attn_dropout (`float`, *optional*, defaults to 0.3):
            Attention dropout probability.
        activation (`str`, *optional*, defaults to 'ELU'):
            Activation function.
        egt_simple (`bool`, *optional*, defaults to False):
            If `False`, update the edge embedding.
        upto_hop (`int`, *optional*, defaults to 16):
            Maximum distance between nodes in the distance matrices.
        mlp_ratios (`List[float]`, *optional*, defaults to [1., 1.]):
            Ratios of inner dimensions with respect to the input dimension in MLP output block.
        num_virtual_nodes (`int`, *optional*, defaults to 4):
            Number of virtual nodes in EGT model, aggregated to graph embedding in the readout function.
        svd_pe_size (`int`, *optional*, defaults to 8):
            SVD positional encoding size.
        num_classes (`int`, *optional*, defaults to 1):
            Number of target classes or labels, set to n for binary classification of n tasks.

    Example:
    ```python
    >>> from transformers import EGTForGraphClassification, EGTConfig

    >>> # Initializing a default EGT configuration
    >>> configuration = EGTConfig()

    >>> # Initializing a model from that configuration
    >>> model = EGTForGraphClassification(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "egt"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        feat_size: int = 768,
        edge_feat_size: int = 64,
        num_heads: int = 32,
        num_layers: int = 30,
        dropout: float = 0.0,
        attn_dropout: float = 0.3,
        activation: str = "ELU",
        egt_simple: bool = False,
        upto_hop: int = 16,
        # BUG FIX: the default used to be the mutable list [1.0, 1.0], which is
        # shared across all calls; an immutable tuple avoids aliasing while the
        # stored attribute remains a list as before.
        mlp_ratios: List[float] = (1.0, 1.0),
        num_virtual_nodes: int = 4,
        svd_pe_size: int = 8,
        num_classes: int = 1,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.feat_size = feat_size
        self.edge_feat_size = edge_feat_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.activation = activation
        self.egt_simple = egt_simple
        self.upto_hop = upto_hop
        # Always store a fresh list so callers' sequences are never aliased.
        self.mlp_ratios = list(mlp_ratios)
        self.num_virtual_nodes = num_virtual_nodes
        self.svd_pe_size = svd_pe_size
        self.num_classes = num_classes

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
egt_model/modeling_egt.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" PyTorch EGT model."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from dgl.nn import EGTLayer
|
| 9 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 10 |
+
|
| 11 |
+
from transformers.modeling_outputs import (
|
| 12 |
+
BaseModelOutputWithNoAttention,
|
| 13 |
+
SequenceClassifierOutput,
|
| 14 |
+
)
|
| 15 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 16 |
+
from .configuration_egt import EGTConfig
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Embedding-table layout constants. The offsets must match those used by
# collating_egt.convert_to_single_node_emb / convert_to_single_edge_emb
# (128 per node-feature column, 8 per edge-feature column).
# NOTE(review): the feature counts (9 node, 3 edge) look like the OGB molecular
# feature schema — confirm against the datasets actually used.
NODE_FEATURES_OFFSET = 128
NUM_NODE_FEATURES = 9
EDGE_FEATURES_OFFSET = 8
NUM_EDGE_FEATURES = 3
+
|
| 24 |
+
|
| 25 |
+
class VirtualNodes(nn.Module):
    """
    Prepend learnable virtual-node embeddings to the node features, pad the
    edge-feature matrix with matching learnable edge embeddings, and extend the
    additive attention mask so virtual nodes attend everywhere.
    """

    def __init__(self, feat_size, edge_feat_size, num_virtual_nodes=1):
        super().__init__()
        self.feat_size = feat_size
        self.edge_feat_size = edge_feat_size
        self.num_virtual_nodes = num_virtual_nodes

        self.vn_node_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.feat_size))
        self.vn_edge_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.edge_feat_size))
        nn.init.normal_(self.vn_node_embeddings)
        nn.init.normal_(self.vn_edge_embeddings)

    def forward(self, h, e, mask):
        batch_size, num_real = e.shape[0], e.shape[1]
        num_vn = self.num_virtual_nodes

        # Prepend one embedding row per virtual node: (b, N, F) -> (b, V+N, F).
        vn_rows = self.vn_node_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        h = torch.cat([vn_rows, h], dim=1)

        # Edge embeddings for (virtual, real), (real, virtual) and the
        # (virtual, virtual) corner, which averages the row/column embeddings.
        row_emb = self.vn_edge_embeddings.unsqueeze(1)   # (V, 1, E)
        col_emb = self.vn_edge_embeddings.unsqueeze(0)   # (1, V, E)
        corner_emb = 0.5 * (row_emb + col_emb)           # (V, V, E)

        row_emb = row_emb.unsqueeze(0).expand(batch_size, -1, e.shape[2], -1)
        col_emb = col_emb.unsqueeze(0).expand(batch_size, num_real, -1, -1)
        corner_emb = corner_emb.unsqueeze(0).expand(batch_size, -1, -1, -1)

        # Assemble (b, V+N, V+N, E): corner/col on the left, row/e on the right.
        right_part = torch.cat([row_emb, e], dim=1)
        left_part = torch.cat([corner_emb, col_emb], dim=1)
        e = torch.cat([left_part, right_part], dim=2)

        if mask is not None:
            # New rows/columns get 0 (unmasked): virtual nodes see everything.
            mask = F.pad(mask, (num_vn, 0, num_vn, 0), mode="constant", value=0)
        return h, e, mask
+
|
| 63 |
+
|
| 64 |
+
class EGTPreTrainedModel(PreTrainedModel):
    """
    A simple interface for downloading and loading pretrained models.
    """

    config_class = EGTConfig
    base_model_prefix = "egt"
    supports_gradient_checkpointing = True
    # NOTE(review): transformers expects a single `main_input_name` attribute;
    # these two custom names are not read by the library itself — confirm intent.
    main_input_name_nodes = "node_feat"
    main_input_name_edges = "featm"

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle the flag only on the inner encoder; other submodules are left alone.
        if isinstance(module, EGTModel):
            module.gradient_checkpointing = value
+
|
| 79 |
+
|
| 80 |
+
class EGTModel(EGTPreTrainedModel):
    """The EGT model is a graph-encoder model.

    It goes from a graph to its representation. If you want to use the model for a downstream classification task, use
    EGTForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine
    this model with a downstream model of your choice, following the example in EGTForGraphClassification.
    """

    def __init__(self, config: EGTConfig):
        super().__init__(config)

        # Activation is looked up by name on torch.nn (e.g. "ELU" -> nn.ELU()).
        self.activation = getattr(nn, config.activation)()

        # Keyword arguments shared by every dgl.nn.EGTLayer built below.
        self.layer_common_kwargs = {
            "feat_size": config.feat_size,
            "edge_feat_size": config.edge_feat_size,
            "num_heads": config.num_heads,
            "num_virtual_nodes": config.num_virtual_nodes,
            "dropout": config.dropout,
            "attn_dropout": config.attn_dropout,
            "activation": self.activation,
        }
        self.edge_update = not config.egt_simple

        # All layers except the last optionally update the edge embeddings.
        self.EGT_layers = nn.ModuleList(
            [EGTLayer(**self.layer_common_kwargs, edge_update=self.edge_update) for _ in range(config.num_layers - 1)]
        )

        # The final layer never updates edges (its edge output would be unused).
        self.EGT_layers.append(EGTLayer(**self.layer_common_kwargs, edge_update=False))

        self.upto_hop = config.upto_hop
        self.num_virtual_nodes = config.num_virtual_nodes
        self.svd_pe_size = config.svd_pe_size

        # +1 and padding_idx=0 reserve index 0 for padding in both tables.
        self.nodef_embed = nn.Embedding(NUM_NODE_FEATURES * NODE_FEATURES_OFFSET + 1, config.feat_size, padding_idx=0)
        if self.svd_pe_size:
            # svd_pe has 2 * k channels per node (left/right singular vectors).
            self.svd_embed = nn.Linear(self.svd_pe_size * 2, config.feat_size)

        # Distance buckets 0..upto_hop plus one extra bucket for "farther".
        self.dist_embed = nn.Embedding(self.upto_hop + 2, config.edge_feat_size)
        self.featm_embed = nn.Embedding(
            NUM_EDGE_FEATURES * EDGE_FEATURES_OFFSET + 1, config.edge_feat_size, padding_idx=0
        )

        if self.num_virtual_nodes > 0:
            self.vn_layer = VirtualNodes(config.feat_size, config.edge_feat_size, self.num_virtual_nodes)

        self.final_ln_h = nn.LayerNorm(config.feat_size)
        # Output MLP widths: [readout width] + hidden ratios + [num_classes].
        mlp_dims = (
            [config.feat_size * max(self.num_virtual_nodes, 1)]
            + [round(config.feat_size * r) for r in config.mlp_ratios]
            + [config.num_classes]
        )
        self.mlp_layers = nn.ModuleList([nn.Linear(mlp_dims[i], mlp_dims[i + 1]) for i in range(len(mlp_dims) - 1)])
        self.mlp_fn = self.activation

        self._backward_compatibility_gradient_checkpointing()

    def input_block(self, nodef, featm, dm, nodem, svd_pe):
        """Embed raw node/edge inputs into (h, e, mask) for the EGT layers."""
        # Distances beyond upto_hop collapse into the single bucket upto_hop + 1.
        dm = dm.long().clamp(min=0, max=self.upto_hop + 1)  # (b,i,j)

        # Sum the per-column feature embeddings into one vector per node.
        h = self.nodef_embed(nodef).sum(dim=2)  # (b,i,w,h) -> (b,i,h)

        if self.svd_pe_size:
            h = h + self.svd_embed(svd_pe)

        e = self.dist_embed(dm) + self.featm_embed(featm).sum(dim=3)  # (b,i,j,f,e) -> (b,i,j,e)

        # Additive attention mask: 0 where both nodes are real, -1e9 otherwise.
        mask = (nodem[:, :, None] * nodem[:, None, :] - 1) * 1e9

        if self.num_virtual_nodes > 0:
            h, e, mask = self.vn_layer(h, e, mask)
        return h, e, mask

    def final_embedding(self, h, attn_mask):
        """Readout: concatenate virtual-node states, or masked-mean-pool real nodes."""
        h = self.final_ln_h(h)
        if self.num_virtual_nodes > 0:
            # Virtual nodes sit at the front of h; flatten them into one vector.
            h = h[:, : self.num_virtual_nodes].reshape(h.shape[0], -1)
        else:
            nodem = attn_mask.float().unsqueeze(dim=-1)
            # Masked mean; the 1e-9 guards against an all-padding graph.
            h = (h * nodem).sum(dim=1) / (nodem.sum(dim=1) + 1e-9)
        return h

    def output_block(self, h):
        """Apply the output MLP: activation between layers, none after the last."""
        h = self.mlp_layers[0](h)
        for layer in self.mlp_layers[1:]:
            h = layer(self.mlp_fn(h))
        return h

    def forward(
        self,
        node_feat: torch.LongTensor,
        featm: torch.LongTensor,
        dm: torch.LongTensor,
        attn_mask: torch.LongTensor,
        svd_pe: torch.Tensor,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> torch.Tensor:
        """Run the EGT encoder and output MLP.

        Args (shapes as produced by EGTDataCollator):
            node_feat: (batch, nodes, node_features) integer node features.
            featm: (batch, nodes, nodes, edge_features) integer edge features.
            dm: (batch, nodes, nodes) shortest-path distance matrix.
            attn_mask: (batch, nodes); 1 for real nodes, 0 for padding.
            svd_pe: (batch, nodes, 2 * svd_pe_size) SVD positional encodings.
            return_dict: Whether to wrap the result in a model-output dataclass.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        h, e, mask = self.input_block(node_feat, featm, dm, attn_mask, svd_pe)

        for layer in self.EGT_layers[:-1]:
            if self.edge_update:
                h, e = layer(h, e, mask)
            else:
                h = layer(h, e, mask)

        # The last layer was built with edge_update=False, so it returns only h.
        h = self.EGT_layers[-1](h, e, mask)

        h = self.final_embedding(h, attn_mask)

        outputs = self.output_block(h)

        if not return_dict:
            return tuple(x for x in [outputs] if x is not None)
        return BaseModelOutputWithNoAttention(last_hidden_state=outputs)
+
|
| 198 |
+
|
| 199 |
+
class EGTForGraphClassification(EGTPreTrainedModel):
    """
    This model can be used for graph-level classification or regression tasks.

    It can be trained on
    - regression (by setting config.num_classes to 1); there should be one float-type label per graph
    - one task classification (by setting config.num_classes to the number of classes); there should be one integer
      label per graph
    - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a list
      of integer labels for each graph.
    """

    def __init__(self, config: EGTConfig):
        super().__init__(config)
        # The bare encoder produces the logits; this wrapper only adds the loss.
        self.model = EGTModel(config)
        self.num_classes = config.num_classes

        self._backward_compatibility_gradient_checkpointing()

    def forward(
        self,
        node_feat: torch.LongTensor,
        featm: torch.LongTensor,
        dm: torch.LongTensor,
        attn_mask: torch.LongTensor,
        svd_pe: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Compute logits and, when labels are given, the task-appropriate loss."""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_output = self.model(
            node_feat,
            featm,
            dm,
            attn_mask,
            svd_pe,
            return_dict=True,
        )
        logits = encoder_output["last_hidden_state"]

        loss = None
        if labels is not None:
            # NaN labels mark missing tasks; they are excluded from the loss.
            not_nan = ~torch.isnan(labels)

            if self.num_classes == 1:  # regression
                loss = MSELoss()(logits[not_nan].squeeze(), labels[not_nan].squeeze().float())
            elif self.num_classes > 1 and len(labels.shape) == 1:  # One task classification
                loss = CrossEntropyLoss()(logits[not_nan].view(-1, self.num_classes), labels[not_nan].view(-1))
            else:  # Binary multi-task classification
                loss = BCEWithLogitsLoss(reduction="sum")(logits[not_nan], labels[not_nan])

        if return_dict:
            return SequenceClassifierOutput(loss=loss, logits=logits, attentions=None)
        return tuple(x for x in [loss, logits] if x is not None)
share_model.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from egt_model.configuration_egt import EGTConfig
from egt_model.modeling_egt import EGTModel, EGTForGraphClassification

# Register the custom classes so they can be auto-loaded from the Hub
# (via trust_remote_code) once pushed.
EGTConfig.register_for_auto_class()
EGTModel.register_for_auto_class("AutoModel")
# NOTE(review): "AutoModelForGraphClassification" is not an auto class shipped
# by transformers — confirm the intended auto-class name before publishing.
EGTForGraphClassification.register_for_auto_class("AutoModelForGraphClassification")

# Build a fresh model with default hyperparameters.
egt_config = EGTConfig()
egt = EGTForGraphClassification(egt_config)

# Hard-coded local checkpoint path; torch.load here unpickles a full model
# object (not just a state dict), so only run this on a checkpoint you trust.
pretrained_model = torch.load("/home/ubuntu/transformers/egt_model_state")
egt.model.load_state_dict(pretrained_model.state_dict())

# egt.push_to_hub("Zhiteng/egt")
|