"""PyTorch EGT model."""

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import EGTLayer
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.modeling_outputs import (
    BaseModelOutputWithNoAttention,
    SequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from .configuration_egt import EGTConfig
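
# Each of the NUM_NODE_FEATURES (resp. NUM_EDGE_FEATURES) categorical feature columns is
# assumed to arrive already offset into its own index range of width NODE_FEATURES_OFFSET
# (resp. EDGE_FEATURES_OFFSET), e.g. by the data collator, so that a single embedding table
# can be shared across all columns; index 0 is reserved for padding.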
NODE_FEATURES_OFFSET = 128
NUM_NODE_FEATURES = 9
EDGE_FEATURES_OFFSET = 8
NUM_EDGE_FEATURES = 3


class VirtualNodes(nn.Module):
    """
    Generate node and edge features for virtual nodes in the graph
    and pad the corresponding matrices.
    """

    def __init__(self, feat_size, edge_feat_size, num_virtual_nodes=1):
        super().__init__()
        self.feat_size = feat_size
        self.edge_feat_size = edge_feat_size
        self.num_virtual_nodes = num_virtual_nodes

        self.vn_node_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.feat_size))
        self.vn_edge_embeddings = nn.Parameter(torch.empty(num_virtual_nodes, self.edge_feat_size))
        nn.init.normal_(self.vn_node_embeddings)
        nn.init.normal_(self.vn_edge_embeddings)
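
    # Shapes (inferred from the caller in EGTModel.input_block, not a documented API):
    #   h:    (batch, num_nodes, feat_size) node features
    #   e:    (batch, num_nodes, num_nodes, edge_feat_size) pairwise edge features
    #   mask: (batch, num_nodes, num_nodes) additive attention bias, or None
    # forward() prepends `num_virtual_nodes` learned rows/columns, so each output is
    # padded to num_nodes + num_virtual_nodes along every node dimension.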
    def forward(self, h, e, mask):
        node_emb = self.vn_node_embeddings.unsqueeze(0).expand(h.shape[0], -1, -1)
        h = torch.cat([node_emb, h], dim=1)

        e_shape = e.shape
        edge_emb_row = self.vn_edge_embeddings.unsqueeze(1)
        edge_emb_col = self.vn_edge_embeddings.unsqueeze(0)
        edge_emb_box = 0.5 * (edge_emb_row + edge_emb_col)

        edge_emb_row = edge_emb_row.unsqueeze(0).expand(e_shape[0], -1, e_shape[2], -1)
        edge_emb_col = edge_emb_col.unsqueeze(0).expand(e_shape[0], e_shape[1], -1, -1)
        edge_emb_box = edge_emb_box.unsqueeze(0).expand(e_shape[0], -1, -1, -1)

        # Prepend the virtual-node rows, columns and the virtual-to-virtual corner block
        # to the pairwise edge-feature tensor.
        e = torch.cat([edge_emb_row, e], dim=1)
        e_col_box = torch.cat([edge_emb_box, edge_emb_col], dim=1)
        e = torch.cat([e_col_box, e], dim=2)

        if mask is not None:
            # Virtual-node positions get a bias of 0, i.e. they are never masked.
            mask = F.pad(mask, (self.num_virtual_nodes, 0, self.num_virtual_nodes, 0), mode="constant", value=0)
        return h, e, mask


class EGTPreTrainedModel(PreTrainedModel):
    """
    A simple interface for downloading and loading pretrained models.
    """

    config_class = EGTConfig
    base_model_prefix = "egt"
    supports_gradient_checkpointing = True
    main_input_name_nodes = "node_feat"
    main_input_name_edges = "featm"

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, EGTModel):
            module.gradient_checkpointing = value


class EGTModel(EGTPreTrainedModel):
    """The EGT model is a graph-encoder model.

    It maps a graph to its representation. If you want to use the model for a downstream classification task, use
    EGTForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine
    this model with a downstream model of your choice, following the example in EGTForGraphClassification.
    """
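
    # Expected input shapes (inferred from `input_block` below, not a documented API),
    # for a batch of B graphs padded to N nodes:
    #   node_feat: (B, N, NUM_NODE_FEATURES) long, featm: (B, N, N, NUM_EDGE_FEATURES) long,
    #   dm:        (B, N, N) long hop distances, attn_mask: (B, N) with 1 for real nodes,
    #   svd_pe:    (B, N, 2 * config.svd_pe_size) float SVD positional encodings.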

    def __init__(self, config: EGTConfig):
        super().__init__(config)

        self.activation = getattr(nn, config.activation)()

        self.layer_common_kwargs = {
            "feat_size": config.feat_size,
            "edge_feat_size": config.edge_feat_size,
            "num_heads": config.num_heads,
            "num_virtual_nodes": config.num_virtual_nodes,
            "dropout": config.dropout,
            "attn_dropout": config.attn_dropout,
            "activation": self.activation,
        }
        self.edge_update = not config.egt_simple

        self.EGT_layers = nn.ModuleList(
            [EGTLayer(**self.layer_common_kwargs, edge_update=self.edge_update) for _ in range(config.num_layers - 1)]
        )
        self.EGT_layers.append(EGTLayer(**self.layer_common_kwargs, edge_update=False))

        self.upto_hop = config.upto_hop
        self.num_virtual_nodes = config.num_virtual_nodes
        self.svd_pe_size = config.svd_pe_size

        self.nodef_embed = nn.Embedding(NUM_NODE_FEATURES * NODE_FEATURES_OFFSET + 1, config.feat_size, padding_idx=0)
        if self.svd_pe_size:
            self.svd_embed = nn.Linear(self.svd_pe_size * 2, config.feat_size)

        self.dist_embed = nn.Embedding(self.upto_hop + 2, config.edge_feat_size)
        self.featm_embed = nn.Embedding(
            NUM_EDGE_FEATURES * EDGE_FEATURES_OFFSET + 1, config.edge_feat_size, padding_idx=0
        )

        if self.num_virtual_nodes > 0:
            self.vn_layer = VirtualNodes(config.feat_size, config.edge_feat_size, self.num_virtual_nodes)

        self.final_ln_h = nn.LayerNorm(config.feat_size)
        mlp_dims = (
            [config.feat_size * max(self.num_virtual_nodes, 1)]
            + [round(config.feat_size * r) for r in config.mlp_ratios]
            + [config.num_classes]
        )
        self.mlp_layers = nn.ModuleList([nn.Linear(mlp_dims[i], mlp_dims[i + 1]) for i in range(len(mlp_dims) - 1)])
        self.mlp_fn = self.activation

        self._backward_compatibility_gradient_checkpointing()

    def input_block(self, nodef, featm, dm, nodem, svd_pe):
        dm = dm.long().clamp(min=0, max=self.upto_hop + 1)

        h = self.nodef_embed(nodef).sum(dim=2)

        if self.svd_pe_size:
            h = h + self.svd_embed(svd_pe)

        e = self.dist_embed(dm) + self.featm_embed(featm).sum(dim=3)

        # Additive attention bias: 0 where both endpoints are real nodes, -1e9 where
        # either one is padding.
        mask = (nodem[:, :, None] * nodem[:, None, :] - 1) * 1e9

        if self.num_virtual_nodes > 0:
            h, e, mask = self.vn_layer(h, e, mask)
        return h, e, mask

    def final_embedding(self, h, attn_mask):
        h = self.final_ln_h(h)
        if self.num_virtual_nodes > 0:
            # Read the graph representation off the virtual-node positions.
            h = h[:, : self.num_virtual_nodes].reshape(h.shape[0], -1)
        else:
            # Masked mean pooling over the real nodes.
            nodem = attn_mask.float().unsqueeze(dim=-1)
            h = (h * nodem).sum(dim=1) / (nodem.sum(dim=1) + 1e-9)
        return h

    def output_block(self, h):
        h = self.mlp_layers[0](h)
        for layer in self.mlp_layers[1:]:
            h = layer(self.mlp_fn(h))
        return h
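
    # Forward pass: embed the raw inputs (input_block), run the stacked EGT layers
    # (the final layer keeps the edge features fixed), pool a graph-level vector
    # (final_embedding) and apply the MLP head (output_block).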
    def forward(
        self,
        node_feat: torch.LongTensor,
        featm: torch.LongTensor,
        dm: torch.LongTensor,
        attn_mask: torch.LongTensor,
        svd_pe: torch.Tensor,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithNoAttention]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        h, e, mask = self.input_block(node_feat, featm, dm, attn_mask, svd_pe)

        for layer in self.EGT_layers[:-1]:
            if self.edge_update:
                h, e = layer(h, e, mask)
            else:
                h = layer(h, e, mask)

        # The last layer was built with edge_update=False, so it returns node features only.
        h = self.EGT_layers[-1](h, e, mask)

        h = self.final_embedding(h, attn_mask)

        outputs = self.output_block(h)

        if not return_dict:
            return tuple(x for x in [outputs] if x is not None)
        return BaseModelOutputWithNoAttention(last_hidden_state=outputs)


class EGTForGraphClassification(EGTPreTrainedModel):
    """
    This model can be used for graph-level classification or regression tasks.

    It can be trained on:
    - regression (by setting config.num_classes to 1); there should be one float-type label per graph
    - single-task classification (by setting config.num_classes to the number of classes); there should be one
      integer label per graph
    - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a
      list of integer labels for each graph.
    """
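
    # Label conventions, as inferred from the loss computation in `forward` below
    # (not a documented API):
    #   - regression (num_classes == 1):     float labels of shape (batch_size,)
    #   - single-task classification:        integer class indices of shape (batch_size,)
    #   - multi-task binary classification:  labels of shape (batch_size, num_classes);
    #     NaN entries mark missing task labels and are masked out of the loss.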

    def __init__(self, config: EGTConfig):
        super().__init__(config)
        self.model = EGTModel(config)
        self.num_classes = config.num_classes

        self._backward_compatibility_gradient_checkpointing()

    def forward(
        self,
        node_feat: torch.LongTensor,
        featm: torch.LongTensor,
        dm: torch.LongTensor,
        attn_mask: torch.LongTensor,
        svd_pe: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **unused,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        logits = self.model(
            node_feat,
            featm,
            dm,
            attn_mask,
            svd_pe,
            return_dict=True,
        )["last_hidden_state"]

        loss = None
        if labels is not None:
            mask = ~torch.isnan(labels)

            if self.num_classes == 1:
                loss_fct = MSELoss()
                loss = loss_fct(logits[mask].squeeze(), labels[mask].squeeze().float())
            elif self.num_classes > 1 and len(labels.shape) == 1:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits[mask].view(-1, self.num_classes), labels[mask].view(-1))
            else:
                loss_fct = BCEWithLogitsLoss(reduction="sum")
                loss = loss_fct(logits[mask], labels[mask])

        if not return_dict:
            return tuple(x for x in [loss, logits] if x is not None)
        return SequenceClassifierOutput(loss=loss, logits=logits, attentions=None)