cnbg commited on
Commit
235b048
·
1 Parent(s): c1f699a

add egt model

Browse files
egt_model/__init__.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Package init for the EGT model, using the standard transformers lazy-import layout."""
from typing import TYPE_CHECKING

# NOTE(review): is_tokenizers_available is imported but never used in this file.
from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available


# Mapping consumed by _LazyModule: submodule name -> public symbols it exports.
_import_structure = {
    "configuration_egt": ["EGT_PRETRAINED_CONFIG_ARCHIVE_MAP", "EGTConfig"],
}

# The modeling submodule is only registered when torch is installed.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_egt"] = [
        "EGT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "EGTForGraphClassification",
        "EGTModel",
        "EGTPreTrainedModel",
    ]


if TYPE_CHECKING:
    # Static type checkers see the real (eager) imports...
    from .configuration_egt import EGT_PRETRAINED_CONFIG_ARCHIVE_MAP, EGTConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_egt import (
            EGT_PRETRAINED_MODEL_ARCHIVE_LIST,
            EGTForGraphClassification,
            EGTModel,
            EGTPreTrainedModel,
        )


else:
    # ...while at runtime the module object is swapped for a lazy loader that
    # imports submodules on first attribute access.
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
egt_model/collating_egt.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Mapping
2
+
3
+ import dgl
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
def convert_to_single_node_emb(x, offset: int = 128):
    """Shift each node-feature column into its own embedding-id range.

    Column ``k`` is shifted by ``1 + k * offset`` so that all columns can share
    a single embedding table; index 0 stays free for padding.
    """
    num_columns = 1 if len(x.shape) <= 1 else x.shape[1]
    per_column_shift = np.arange(num_columns, dtype=np.int64) * offset + 1
    return x + per_column_shift
13
+
14
+
15
def convert_to_single_edge_emb(x, offset: int = 8):
    """Shift each edge-feature column into its own embedding-id range.

    Same scheme as ``convert_to_single_node_emb`` but with a smaller per-column
    stride, since edge features have fewer distinct values.
    """
    num_columns = 1 if len(x.shape) <= 1 else x.shape[1]
    per_column_shift = np.arange(num_columns, dtype=np.int64) * offset + 1
    return x + per_column_shift
20
+
21
+
22
def preprocess_item(item, keep_features=True):
    """Turn one raw graph dict into the dense features the EGT collator expects.

    Args:
        item: mapping with at least "edge_index" and "num_nodes"; may also
            contain "edge_attr", "node_feat" and "y".
        keep_features: when False (or the keys are absent), constant dummy
            features are substituted so every node/edge shares one embedding.

    Returns:
        The same dict, extended in place with "input_nodes", "attn_edge_type",
        "spatial_pos", "svd_pe" and "labels".
    """
    if keep_features and "edge_attr" in item.keys():  # edge_attr
        edge_attr = np.asarray(item["edge_attr"], dtype=np.int64)
    else:
        edge_attr = np.ones((len(item["edge_index"][0]), 1), dtype=np.int64)  # same embedding for all

    if keep_features and "node_feat" in item.keys():  # input_nodes
        node_feature = np.asarray(item["node_feat"], dtype=np.int64)
    else:
        node_feature = np.ones((item["num_nodes"], 1), dtype=np.int64)  # same embedding for all

    edge_index = np.asarray(item["edge_index"], dtype=np.int64)

    # Shift each feature column into its own id range so one embedding table serves all columns.
    input_nodes = convert_to_single_node_emb(node_feature)
    num_nodes = item["num_nodes"]

    if len(edge_attr.shape) == 1:
        edge_attr = edge_attr[:, None]
    # Dense (num_nodes, num_nodes, edge_feat) matrix; entries stay 0 where no edge exists.
    attn_edge_type = np.zeros([num_nodes, num_nodes, edge_attr.shape[-1]], dtype=np.int64)
    attn_edge_type[edge_index[0], edge_index[1]] = convert_to_single_edge_emb(edge_attr)

    # convert to dgl graph for computing shortest path distance and svd encodings
    # NOTE(review): dgl.graph infers the node count from edge_index here, so
    # trailing isolated nodes would be dropped — confirm inputs never have them.
    g = dgl.graph((edge_index[0], edge_index[1]))
    shortest_path_result = dgl.shortest_dist(g)
    # Unreachable pairs are reported as -1; map them to the sentinel distance 510.
    shortest_path_result = torch.where(shortest_path_result == -1, 510, shortest_path_result)
    # NOTE(review): random_flip=True makes the SVD encodings nondeterministic
    # across calls — confirm this is intended at evaluation time.
    svd_pe = dgl.svd_pe(g, k=8, padding=True, random_flip=True)

    # combine
    item["input_nodes"] = input_nodes
    item["attn_edge_type"] = attn_edge_type
    item["spatial_pos"] = shortest_path_result
    item["svd_pe"] = svd_pe
    if "labels" not in item:
        item["labels"] = item["y"]

    return item
58
+
59
+
60
class EGTDataCollator:
    """Collate pre-processed graph items into padded dense batch tensors.

    Every graph in the batch is padded to the largest node count, producing:
    ``featm`` (dense edge-feature matrix), ``dm`` (distance matrix),
    ``node_feat``, ``svd_pe`` (SVD positional encodings), ``attn_mask``
    (1 for real nodes, 0 for padding) and ``labels``.
    """

    def __init__(self, on_the_fly_processing=False):
        # When True, each raw item is run through `preprocess_item` at collate time.
        self.on_the_fly_processing = on_the_fly_processing

    def __call__(self, features: List[dict]) -> Dict[str, Any]:
        if self.on_the_fly_processing:
            features = [preprocess_item(i) for i in features]

        # Accept plain objects as well as mappings.
        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]
        batch = {}

        # Sizes are read from the first item; all items must share feature widths.
        max_node_num = max(len(i["input_nodes"]) for i in features)
        node_feat_size = len(features[0]["input_nodes"][0])
        edge_feat_size = len(features[0]["attn_edge_type"][0][0])
        svd_pe_size = len(features[0]["svd_pe"][0]) // 2
        batch_size = len(features)

        # Zero-initialized buffers; zero doubles as the padding value everywhere.
        batch["featm"] = torch.zeros(batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long)
        batch["dm"] = torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long)
        batch["node_feat"] = torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long)
        batch["svd_pe"] = torch.zeros(batch_size, max_node_num, svd_pe_size * 2, dtype=torch.float)
        batch["attn_mask"] = torch.zeros(batch_size, max_node_num, dtype=torch.long)

        for ix, f in enumerate(features):
            for k in ["attn_edge_type", "spatial_pos", "input_nodes", "svd_pe"]:
                f[k] = torch.tensor(f[k])

            # Copy each graph into the top-left corner of its padded slot.
            batch["featm"][ix, : f["attn_edge_type"].shape[0], : f["attn_edge_type"].shape[1], :] = f["attn_edge_type"]
            batch["dm"][ix, : f["spatial_pos"].shape[0], : f["spatial_pos"].shape[1]] = f["spatial_pos"]
            batch["node_feat"][ix, : f["input_nodes"].shape[0], :] = f["input_nodes"]
            batch["svd_pe"][ix, : f["svd_pe"].shape[0], :] = f["svd_pe"]
            batch["attn_mask"][ix, : f["svd_pe"].shape[0]] = 1

        sample = features[0]["labels"]
        if len(sample) == 1:
            # Single task (regression or binary classification): one label per
            # graph. The original had two byte-identical branches here; merged.
            batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
        else:
            # Multi-task classification: keep float dtype so NaNs (missing labels) survive.
            # BUG FIX: np.stack takes `axis=`, not `dim=` — the original raised
            # TypeError on every multi-task batch.
            batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))

        return batch
egt_model/configuration_egt.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List, Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
5
+
6
+
7
logger = logging.get_logger(__name__)

# NOTE(review): this map was copied from Graphormer — both the key and the URL
# point at a graphormer-base checkpoint config, not an EGT one. Confirm or
# replace before relying on it for pretrained-config resolution.
EGT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    # pcqm4mv1 now deprecated
    "graphormer-base": "https://huggingface.co/clefourrier/graphormer-base-pcqm4mv2/resolve/main/config.json",
    # See all Graphormer models at https://huggingface.co/models?filter=graphormer
}
14
+
15
+
16
class EGTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`~EGTModel`]. It is used to instantiate an
    EGT model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        feat_size (`int`, *optional*, defaults to 768):
            Node feature size.
        edge_feat_size (`int`, *optional*, defaults to 64):
            Edge feature size.
        num_heads (`int`, *optional*, defaults to 32):
            Number of attention heads, by which :attr:`feat_size` is divisible.
        num_layers (`int`, *optional*, defaults to 30):
            Number of layers.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability.
        attn_dropout (`float`, *optional*, defaults to 0.3):
            Attention dropout probability.
        activation (`str`, *optional*, defaults to `"ELU"`):
            Activation function; must be the name of a class in `torch.nn`.
        egt_simple (`bool`, *optional*, defaults to `False`):
            If `False`, update the edge embedding.
        upto_hop (`int`, *optional*, defaults to 16):
            Maximum distance between nodes in the distance matrices.
        mlp_ratios (`List[float]`, *optional*, defaults to `[1.0, 1.0]`):
            Ratios of inner dimensions with respect to the input dimension in MLP output block.
        num_virtual_nodes (`int`, *optional*, defaults to 4):
            Number of virtual nodes in EGT model, aggregated to graph embedding in the readout function.
        svd_pe_size (`int`, *optional*, defaults to 8):
            SVD positional encoding size.
        num_classes (`int`, *optional*, defaults to 1):
            Number of target classes or labels, set to n for binary classification of n tasks.

    Example:
    ```python
    >>> from transformers import EGTForGraphClassification, EGTConfig

    >>> # Initializing a default EGT configuration
    >>> configuration = EGTConfig()

    >>> # Initializing a model from that configuration
    >>> model = EGTForGraphClassification(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "egt"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        feat_size: int = 768,
        edge_feat_size: int = 64,
        num_heads: int = 32,
        num_layers: int = 30,
        dropout: float = 0.0,
        attn_dropout: float = 0.3,
        activation: str = "ELU",
        egt_simple: bool = False,
        upto_hop: int = 16,
        mlp_ratios: Optional[List[float]] = None,
        num_virtual_nodes: int = 4,
        svd_pe_size: int = 8,
        num_classes: int = 1,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.feat_size = feat_size
        self.edge_feat_size = edge_feat_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.activation = activation
        self.egt_simple = egt_simple
        self.upto_hop = upto_hop
        # BUG FIX: the original used a mutable list default ([1.0, 1.0]), which
        # is shared across every EGTConfig instance; use a None sentinel instead.
        self.mlp_ratios = [1.0, 1.0] if mlp_ratios is None else mlp_ratios
        self.num_virtual_nodes = num_virtual_nodes
        self.svd_pe_size = svd_pe_size
        self.num_classes = num_classes

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
egt_model/modeling_egt.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch EGT model."""
2
+
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from dgl.nn import EGTLayer
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutputWithNoAttention,
13
+ SequenceClassifierOutput,
14
+ )
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from .configuration_egt import EGTConfig
17
+
18
+
19
# Per-column id offsets; must match the converters in collating_egt.py
# (convert_to_single_node_emb offset=128, convert_to_single_edge_emb offset=8).
NODE_FEATURES_OFFSET = 128
# NOTE(review): 9 node / 3 edge feature columns — presumably the OGB molecule
# featurization; confirm against the dataset in use.
NUM_NODE_FEATURES = 9
EDGE_FEATURES_OFFSET = 8
NUM_EDGE_FEATURES = 3
23
+
24
+
25
class VirtualNodes(nn.Module):
    """Prepend learnable virtual nodes to the node and edge feature tensors.

    The virtual-node embeddings are inserted at the front of the node axis,
    the edge matrix is grown accordingly (virtual-to-virtual entries use the
    mean of the row and column embeddings), and the additive attention mask
    is zero-padded so virtual nodes are never masked out.
    """

    def __init__(self, feat_size, edge_feat_size, num_virtual_nodes=1):
        super().__init__()
        self.feat_size = feat_size
        self.edge_feat_size = edge_feat_size
        self.num_virtual_nodes = num_virtual_nodes

        # One normally-initialized embedding row per virtual node.
        self.vn_node_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.feat_size))
        self.vn_edge_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.edge_feat_size))
        nn.init.normal_(self.vn_node_embeddings)
        nn.init.normal_(self.vn_edge_embeddings)

    def forward(self, h, e, mask):
        batch_size = e.shape[0]
        num_vn = self.num_virtual_nodes

        # Nodes: place the virtual-node embeddings in front of the real ones.
        vn_nodes = self.vn_node_embeddings.unsqueeze(0).expand(h.shape[0], -1, -1)
        h = torch.cat([vn_nodes, h], dim=1)

        # Edges: build the three border pieces of the enlarged (V+n, V+n) matrix.
        row_emb = self.vn_edge_embeddings.unsqueeze(1)  # (V, 1, f)
        col_emb = self.vn_edge_embeddings.unsqueeze(0)  # (1, V, f)
        corner_emb = 0.5 * (row_emb + col_emb)          # (V, V, f)

        top_rows = row_emb.unsqueeze(0).expand(batch_size, -1, e.shape[2], -1)
        left_cols = col_emb.unsqueeze(0).expand(batch_size, e.shape[1], -1, -1)
        corner = corner_emb.unsqueeze(0).expand(batch_size, -1, -1, -1)

        right_part = torch.cat([top_rows, e], dim=1)
        left_part = torch.cat([corner, left_cols], dim=1)
        e = torch.cat([left_part, right_part], dim=2)

        # Mask: virtual nodes are always attendable (additive mask value 0).
        if mask is not None:
            mask = F.pad(mask, (num_vn, 0, num_vn, 0), mode="constant", value=0)
        return h, e, mask
62
+
63
+
64
class EGTPreTrainedModel(PreTrainedModel):
    """
    A simple interface for downloading and loading pretrained models.
    """

    config_class = EGTConfig
    base_model_prefix = "egt"
    supports_gradient_checkpointing = True
    # NOTE(review): transformers convention is a single `main_input_name`;
    # these two custom attributes are not read by the base class — confirm
    # they are consumed somewhere else.
    main_input_name_nodes = "node_feat"
    main_input_name_edges = "featm"

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle gradient checkpointing on the inner encoder module only.
        if isinstance(module, EGTModel):
            module.gradient_checkpointing = value
78
+
79
+
80
class EGTModel(EGTPreTrainedModel):
    """The EGT model is a graph-encoder model.

    It goes from a graph to its representation. If you want to use the model for a downstream classification task, use
    EGTForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine
    this model with a downstream model of your choice, following the example in EGTForGraphClassification.

    NOTE(review): forward() also runs the readout and output MLP, so the
    returned "last_hidden_state" actually carries MLP outputs (logits), not
    per-node hidden states — confirm this naming is intended.
    """

    def __init__(self, config: EGTConfig):
        super().__init__(config)

        # Instantiate the activation class by name from torch.nn (e.g. "ELU" -> nn.ELU()).
        self.activation = getattr(nn, config.activation)()

        # Keyword arguments shared by every EGTLayer.
        self.layer_common_kwargs = {
            "feat_size": config.feat_size,
            "edge_feat_size": config.edge_feat_size,
            "num_heads": config.num_heads,
            "num_virtual_nodes": config.num_virtual_nodes,
            "dropout": config.dropout,
            "attn_dropout": config.attn_dropout,
            "activation": self.activation,
        }
        self.edge_update = not config.egt_simple

        # All layers but the last optionally refine the edge embeddings.
        self.EGT_layers = nn.ModuleList(
            [EGTLayer(**self.layer_common_kwargs, edge_update=self.edge_update) for _ in range(config.num_layers - 1)]
        )

        # The final layer never updates edges (its edge output would be unused).
        self.EGT_layers.append(EGTLayer(**self.layer_common_kwargs, edge_update=False))

        self.upto_hop = config.upto_hop
        self.num_virtual_nodes = config.num_virtual_nodes
        self.svd_pe_size = config.svd_pe_size

        # Node-feature embedding table; index 0 is the padding id.
        self.nodef_embed = nn.Embedding(NUM_NODE_FEATURES * NODE_FEATURES_OFFSET + 1, config.feat_size, padding_idx=0)
        if self.svd_pe_size:
            # Projects the 2*k SVD positional encoding onto the node feature size.
            self.svd_embed = nn.Linear(self.svd_pe_size * 2, config.feat_size)

        # Hop-distance embedding; distances are clamped to upto_hop + 1 in input_block.
        self.dist_embed = nn.Embedding(self.upto_hop + 2, config.edge_feat_size)
        self.featm_embed = nn.Embedding(
            NUM_EDGE_FEATURES * EDGE_FEATURES_OFFSET + 1, config.edge_feat_size, padding_idx=0
        )

        if self.num_virtual_nodes > 0:
            self.vn_layer = VirtualNodes(config.feat_size, config.edge_feat_size, self.num_virtual_nodes)

        self.final_ln_h = nn.LayerNorm(config.feat_size)
        # Output MLP dims: [feat * max(num_vn, 1)] -> feat * each mlp_ratio -> num_classes.
        mlp_dims = (
            [config.feat_size * max(self.num_virtual_nodes, 1)]
            + [round(config.feat_size * r) for r in config.mlp_ratios]
            + [config.num_classes]
        )
        self.mlp_layers = nn.ModuleList([nn.Linear(mlp_dims[i], mlp_dims[i + 1]) for i in range(len(mlp_dims) - 1)])
        self.mlp_fn = self.activation

        self._backward_compatibility_gradient_checkpointing()

    def input_block(self, nodef, featm, dm, nodem, svd_pe):
        """Embed node/edge inputs and build the additive attention mask."""
        dm = dm.long().clamp(min=0, max=self.upto_hop + 1)  # (b,i,j)

        h = self.nodef_embed(nodef).sum(dim=2)  # (b,i,w,h) -> (b,i,h)

        if self.svd_pe_size:
            h = h + self.svd_embed(svd_pe)

        e = self.dist_embed(dm) + self.featm_embed(featm).sum(dim=3)  # (b,i,j,f,e) -> (b,i,j,e)

        # Additive mask: 0 where both nodes are real, -1e9 where either is padding.
        mask = (nodem[:, :, None] * nodem[:, None, :] - 1) * 1e9

        if self.num_virtual_nodes > 0:
            h, e, mask = self.vn_layer(h, e, mask)
        return h, e, mask

    def final_embedding(self, h, attn_mask):
        """Readout: concatenate virtual-node states, or masked mean-pool without them."""
        h = self.final_ln_h(h)
        if self.num_virtual_nodes > 0:
            # Flatten the leading virtual-node states into one graph vector.
            h = h[:, : self.num_virtual_nodes].reshape(h.shape[0], -1)
        else:
            nodem = attn_mask.float().unsqueeze(dim=-1)
            # Mean over real nodes only; epsilon guards an all-padding graph.
            h = (h * nodem).sum(dim=1) / (nodem.sum(dim=1) + 1e-9)
        return h

    def output_block(self, h):
        """Apply the output MLP: activation between layers, none before the first or after the last."""
        h = self.mlp_layers[0](h)
        for layer in self.mlp_layers[1:]:
            h = layer(self.mlp_fn(h))
        return h

    def forward(
        self,
        node_feat: torch.LongTensor,
        featm: torch.LongTensor,
        dm: torch.LongTensor,
        attn_mask: torch.LongTensor,
        svd_pe: torch.Tensor,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> torch.Tensor:
        """Encode the batch of graphs and return the readout-MLP output."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        h, e, mask = self.input_block(node_feat, featm, dm, attn_mask, svd_pe)

        for layer in self.EGT_layers[:-1]:
            if self.edge_update:
                h, e = layer(h, e, mask)
            else:
                h = layer(h, e, mask)

        # Last layer was built with edge_update=False, so it returns h only.
        h = self.EGT_layers[-1](h, e, mask)

        h = self.final_embedding(h, attn_mask)

        outputs = self.output_block(h)

        if not return_dict:
            return tuple(x for x in [outputs] if x is not None)
        return BaseModelOutputWithNoAttention(last_hidden_state=outputs)
197
+
198
+
199
class EGTForGraphClassification(EGTPreTrainedModel):
    """
    This model can be used for graph-level classification or regression tasks.

    It can be trained on
    - regression (by setting config.num_classes to 1); there should be one float-type label per graph
    - one task classification (by setting config.num_classes to the number of classes); there should be one integer
    label per graph
    - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a list
    of integer labels for each graph.
    """

    def __init__(self, config: EGTConfig):
        super().__init__(config)
        self.model = EGTModel(config)
        self.num_classes = config.num_classes

        self._backward_compatibility_gradient_checkpointing()

    def forward(
        self,
        node_feat: torch.LongTensor,
        featm: torch.LongTensor,
        dm: torch.LongTensor,
        attn_mask: torch.LongTensor,
        svd_pe: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Run the encoder and, when `labels` is given, compute the task loss.

        The loss is chosen from `num_classes` and the label shape: MSE for
        regression, cross-entropy for single-task classification, summed
        BCE-with-logits for multi-task binary classification. NaN labels are
        masked out before any loss is computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The inner model returns the classifier logits under "last_hidden_state".
        logits = self.model(
            node_feat,
            featm,
            dm,
            attn_mask,
            svd_pe,
            return_dict=True,
        )["last_hidden_state"]

        loss = None
        if labels is not None:
            mask = ~torch.isnan(labels)  # ignore missing (NaN) labels

            if self.num_classes == 1:  # regression
                loss_fct = MSELoss()
                # NOTE(review): .squeeze() yields 0-d tensors when the masked
                # batch has a single element; MSELoss accepts that, but confirm.
                loss = loss_fct(logits[mask].squeeze(), labels[mask].squeeze().float())
            elif self.num_classes > 1 and len(labels.shape) == 1:  # One task classification
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits[mask].view(-1, self.num_classes), labels[mask].view(-1))
            else:  # Binary multi-task classification
                loss_fct = BCEWithLogitsLoss(reduction="sum")
                loss = loss_fct(logits[mask], labels[mask])

        if not return_dict:
            return tuple(x for x in [loss, logits] if x is not None)
        return SequenceClassifierOutput(loss=loss, logits=logits, attentions=None)
share_model.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from egt_model.configuration_egt import EGTConfig
from egt_model.modeling_egt import EGTModel, EGTForGraphClassification

# Register the custom classes so AutoConfig/AutoModel can resolve them from the Hub.
EGTConfig.register_for_auto_class()
EGTModel.register_for_auto_class("AutoModel")
# NOTE(review): mainline transformers has no "AutoModelForGraphClassification"
# auto class — confirm this string resolves in the targeted version.
EGTForGraphClassification.register_for_auto_class("AutoModelForGraphClassification")

egt_config = EGTConfig()
egt = EGTForGraphClassification(egt_config)

# NOTE(review): hard-coded absolute path, and torch.load unpickles arbitrary
# objects — only run against a trusted checkpoint file.
pretrained_model = torch.load("/home/ubuntu/transformers/egt_model_state")
# The checkpoint stores a full module; copy its weights into the inner encoder.
egt.model.load_state_dict(pretrained_model.state_dict())

# egt.push_to_hub("Zhiteng/egt")