| | """ |
| | (c) Adaptation of the code from https://github.com/THUDM/GraphMAE |
| | """ |
| |
|
| | from typing import Optional |
| | from itertools import chain |
| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch_geometric.utils import dropout_edge |
| | from torch_geometric.utils import add_self_loops |
| |
|
| | from .acm_gin import ACM_GIN_model |
| |
|
| |
|
def sce_loss(x, y, alpha=3):
    """Scaled cosine error between row-wise L2-normalized `x` and `y`.

    Each row pair contributes (1 - cos_sim) ** alpha; the mean over rows
    is returned. Larger `alpha` down-weights rows that are already well
    reconstructed.
    """
    x_hat = F.normalize(x, p=2, dim=-1)
    y_hat = F.normalize(y, p=2, dim=-1)

    cos_sim = (x_hat * y_hat).sum(dim=-1)
    return (1 - cos_sim).pow(alpha).mean()
| |
|
| |
|
| | def setup_module( |
| | m_type, |
| | in_dim, |
| | out_dim, |
| | num_hidden, |
| | num_layers, |
| | activation, |
| | batchnorm, |
| | ) -> nn.Module: |
| |
|
| | if m_type == "acm_gin": |
| | mod = ACM_GIN_model( |
| | int(in_dim), |
| | int(out_dim), |
| | num_layers, |
| | int(num_hidden), |
| | batchnorm, |
| | activation=activation, |
| | ) |
| | else: |
| | raise NotImplementedError |
| |
|
| | return mod |
| |
|
| |
|
class PreModel(nn.Module):
    """Masked graph autoencoder for self-supervised pre-training (GraphMAE-style).

    Node features are randomly masked (optionally with a fraction replaced by
    other nodes' features), encoded by a GNN, projected back to the decoder
    input space, re-masked, decoded, and the masked node features are
    reconstructed under the configured loss.

    NOTE(review): several constructor arguments (edge_in_dim, nhead, nhead_out,
    feat_drop, attn_drop, negative_slope, residual, norm) are accepted but not
    used by the currently supported "acm_gin" encoder/decoder — kept for
    interface compatibility with the upstream code.
    """

    def __init__(
        self,
        in_dim: int,
        edge_in_dim: int,
        num_hidden: int,
        num_layers: int,
        nhead: int,
        nhead_out: int,
        activation: str,
        feat_drop: float,
        attn_drop: float,
        negative_slope: float,
        residual: bool,
        norm: Optional[str],
        mask_rate: float = 0.3,
        encoder_type: str = "gat",
        decoder_type: str = "gat",
        loss_fn: str = "sce",
        drop_edge_rate: float = 0.0,
        replace_rate: float = 0.1,
        alpha_l: float = 2,
        concat_hidden: bool = False,
        batchnorm=False,
    ):
        super(PreModel, self).__init__()
        self._mask_rate = mask_rate
        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._drop_edge_rate = drop_edge_rate
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden

        # Of the masked nodes: `replace_rate` get another node's features,
        # the rest (`mask_token_rate`) get the learned mask token.
        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate

        assert num_hidden % nhead == 0
        assert num_hidden % nhead_out == 0

        enc_num_hidden = num_hidden
        enc_nhead = 1

        dec_in_dim = num_hidden
        dec_num_hidden = num_hidden

        # Encoder: full depth, maps raw features to the hidden space.
        self.encoder = setup_module(
            m_type=encoder_type,
            in_dim=in_dim,
            out_dim=enc_num_hidden,
            num_hidden=enc_num_hidden,
            num_layers=num_layers,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Decoder: single layer, maps hidden space back to raw feature space.
        self.decoder = setup_module(
            m_type=decoder_type,
            in_dim=dec_in_dim,
            out_dim=in_dim,
            num_hidden=dec_num_hidden,
            num_layers=1,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Learnable [MASK] token added to masked nodes' (zeroed) features.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
        if concat_hidden:
            # Encoder output is the concatenation of all layer outputs.
            self.encoder_to_decoder = nn.Linear(
                dec_in_dim * num_layers, dec_in_dim, bias=False
            )
        else:
            self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)

        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    @property
    def output_hidden_dim(self):
        # Size of the encoder output / decoder input representation.
        return self._output_hidden_size

    def setup_loss_fn(self, loss_fn, alpha_l):
        """Return the reconstruction criterion: "mse" or "sce" (scaled cosine error)."""
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        else:
            raise NotImplementedError
        return criterion

    def encoding_mask_noise(self, x, mask_rate=0.3, virtual_node_index=None):
        """Randomly mask node features.

        A `mask_rate` fraction of (non-virtual) nodes is selected. Of those,
        `_mask_token_rate` get their features zeroed and the mask token added,
        and `_replace_rate` get another random node's features instead.

        Returns (out_x, (mask_nodes, keep_nodes)).
        """
        num_nodes = x.shape[0]
        all_indices = torch.arange(num_nodes, device=x.device)

        # Virtual nodes (e.g. graph-level super-nodes) are never masked.
        if virtual_node_index is not None:
            all_indices = all_indices[~torch.isin(all_indices, virtual_node_index)]

        perm = all_indices[torch.randperm(len(all_indices), device=x.device)]

        num_mask_nodes = int(mask_rate * len(perm))
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()

        if self._replace_rate > 0:
            num_noise_nodes = int(self._replace_rate * num_mask_nodes)
            num_token_nodes = int(self._mask_token_rate * num_mask_nodes)
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            out_x[token_nodes] = 0.0
            # BUG FIX: the original used perm_mask[-num_noise_nodes:], which for
            # num_noise_nodes == 0 selects ALL masked nodes (since [-0:] is the
            # full slice) and then fails assigning a 0-row gather to them.
            # Guard the degenerate case explicitly.
            if num_noise_nodes > 0:
                noise_nodes = mask_nodes[perm_mask[-num_noise_nodes:]]
                noise_to_be_chosen = all_indices[
                    torch.randperm(len(perm), device=x.device)[:num_noise_nodes]
                ]
                out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token

        return out_x, (mask_nodes, keep_nodes)

    def forward(self, batch):
        """Compute the masked-feature reconstruction loss for a PyG batch."""
        x, edge_index, edge_attr, virtual_node_index, batch = (
            batch.x,
            batch.edge_index,
            batch.edge_attr,
            getattr(batch, "virtual_node_index", None),
            batch.batch,
        )
        loss = self.mask_attr_prediction(
            x, edge_index, edge_attr, batch, virtual_node_index
        )
        return loss

    def mask_attr_prediction(self, x, edge_index, edge_attr, batch, virtual_node_index):
        """Mask -> encode -> project -> re-mask -> decode -> reconstruction loss."""
        use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(
            x,
            self._mask_rate,
            virtual_node_index,
        )

        if self._drop_edge_rate > 0:
            # dropout_edge returns the kept edges plus a boolean mask over the
            # original edges; the mask selects the surviving edges' attributes.
            use_edge_index, masked_edges = dropout_edge(
                edge_index, self._drop_edge_rate
            )
            use_edge_attr = edge_attr[masked_edges]
            use_edge_index, use_edge_attr = add_self_loops(
                use_edge_index, use_edge_attr, fill_value="min"
            )
        else:
            use_edge_index = edge_index
            use_edge_attr = edge_attr

        # Encoder is expected to return (final_rep, per-layer hidden list)
        # when return_hidden=True.
        enc_rep, all_hidden = self.encoder(
            use_x, use_edge_index, use_edge_attr, return_hidden=True
        )
        if self._concat_hidden:
            enc_rep = torch.cat(all_hidden, dim=1)

        rep = self.encoder_to_decoder(enc_rep)

        if self._decoder_type not in ("mlp", "linear"):
            # Re-mask: GNN decoders must reconstruct masked nodes purely from
            # their neighbors' representations.
            rep[mask_nodes] = 0

        if self._decoder_type in ("mlp", "linear"):
            recon = self.decoder(rep)
        else:
            recon = self.decoder(rep, use_edge_index, use_edge_attr)

        # Loss is computed on masked nodes only.
        x_init = x[mask_nodes]
        x_rec = recon[mask_nodes]

        loss = self.criterion(x_rec, x_init)

        return loss

    def embed(self, x, edge_index, edge_attr, batch):
        """Return projected node embeddings (no masking). `batch` is unused here."""
        if self._concat_hidden:
            enc_rep, all_hidden = self.encoder(
                x, edge_index, edge_attr, return_hidden=True
            )
            enc_rep = torch.cat(all_hidden, dim=1)
        else:
            enc_rep = self.encoder(x, edge_index, edge_attr)
        rep = self.encoder_to_decoder(enc_rep)
        return rep

    @property
    def enc_params(self):
        # Encoder parameters, e.g. for a separate optimizer group.
        return self.encoder.parameters()

    @property
    def dec_params(self):
        # Projection + decoder parameters.
        return chain(self.encoder_to_decoder.parameters(), self.decoder.parameters())
| |
|