# graphist/models/edcoder.py  (uploaded via Hugging Face, commit 8cd7b86)
"""
(c) Adaptation of the code from https://github.com/THUDM/GraphMAE
"""
from typing import Optional
from itertools import chain
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dropout_edge
from torch_geometric.utils import add_self_loops
from .acm_gin import ACM_GIN_model
def sce_loss(x, y, alpha=3):
    """Scaled cosine error: mean of (1 - cos_sim(x_i, y_i)) ** alpha."""
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cos_sim = (x_unit * y_unit).sum(dim=-1)
    # alpha > 1 down-weights pairs that are already well aligned.
    return torch.pow(1.0 - cos_sim, alpha).mean()
def setup_module(
    m_type,
    in_dim,
    out_dim,
    num_hidden,
    num_layers,
    activation,
    batchnorm,
) -> nn.Module:
    """Build an encoder/decoder module of the requested type.

    Only ``"acm_gin"`` is supported; any other ``m_type`` raises
    ``NotImplementedError``.
    """
    if m_type != "acm_gin":
        raise NotImplementedError
    # Dims are coerced to int to tolerate config values parsed as strings/floats.
    return ACM_GIN_model(
        int(in_dim),
        int(out_dim),
        num_layers,
        int(num_hidden),
        batchnorm,
        activation=activation,
    )
class PreModel(nn.Module):
    """Masked graph autoencoder for self-supervised pre-training (GraphMAE-style).

    Randomly masks node features, encodes the corrupted graph, and trains a
    decoder to reconstruct the original features at the masked positions.
    """

    def __init__(
        self,
        in_dim: int,
        edge_in_dim: int,
        num_hidden: int,
        num_layers: int,
        nhead: int,
        nhead_out: int,
        activation: str,
        feat_drop: float,
        attn_drop: float,
        negative_slope: float,
        residual: bool,
        norm: Optional[str],
        mask_rate: float = 0.3,
        encoder_type: str = "gat",
        decoder_type: str = "gat",
        loss_fn: str = "sce",
        drop_edge_rate: float = 0.0,
        replace_rate: float = 0.1,
        alpha_l: float = 2,
        concat_hidden: bool = False,
        batchnorm=False,
    ):
        # NOTE(review): edge_in_dim, nhead/nhead_out, feat_drop, attn_drop,
        # negative_slope, residual and norm are accepted for config
        # compatibility but are not used by the acm_gin encoder/decoder.
        super().__init__()
        self._mask_rate = mask_rate
        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._drop_edge_rate = drop_edge_rate
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden
        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate

        assert num_hidden % nhead == 0
        assert num_hidden % nhead_out == 0
        enc_num_hidden = num_hidden
        dec_in_dim = num_hidden
        dec_num_hidden = num_hidden

        # Encoder: corrupted node features -> latent node representations.
        self.encoder = setup_module(
            m_type=encoder_type,
            in_dim=in_dim,
            out_dim=enc_num_hidden,
            num_hidden=enc_num_hidden,
            num_layers=num_layers,
            activation=activation,
            batchnorm=batchnorm,
        )
        # Decoder: single layer mapping latent representations back to in_dim.
        self.decoder = setup_module(
            m_type=decoder_type,
            in_dim=dec_in_dim,
            out_dim=in_dim,
            num_hidden=dec_num_hidden,
            num_layers=1,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Learnable token added onto zeroed (masked) node features.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
        if concat_hidden:
            # When hidden states of all layers are concatenated, the
            # projection input width grows by a factor of num_layers.
            self.encoder_to_decoder = nn.Linear(
                dec_in_dim * num_layers, dec_in_dim, bias=False
            )
        else:
            self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)

        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    @property
    def output_hidden_dim(self):
        """Width of the representations produced by :meth:`embed`."""
        return self._output_hidden_size

    def setup_loss_fn(self, loss_fn, alpha_l):
        """Return the reconstruction criterion: ``"mse"`` or ``"sce"``."""
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        else:
            raise NotImplementedError
        return criterion

    def encoding_mask_noise(self, x, mask_rate=0.3, virtual_node_index=None):
        """Corrupt node features for masked reconstruction.

        A random ``mask_rate`` fraction of the (non-virtual) nodes is
        selected for masking.  Of those, a ``self._replace_rate`` fraction
        is replaced with features copied from random other nodes; the rest
        are zeroed and shifted by the learnable ``enc_mask_token``.

        Returns ``(out_x, (mask_nodes, keep_nodes))``.
        """
        num_nodes = x.shape[0]
        all_indices = torch.arange(num_nodes, device=x.device)
        # Virtual nodes must never be masked; drop them from the candidates.
        if virtual_node_index is not None:
            all_indices = all_indices[~torch.isin(all_indices, virtual_node_index)]
        perm = all_indices[torch.randperm(len(all_indices), device=x.device)]

        # Random masking over the shuffled candidate nodes.
        num_mask_nodes = int(mask_rate * len(perm))
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()
        if self._replace_rate > 0:
            num_noise_nodes = int(self._replace_rate * num_mask_nodes)
            # Partition mask_nodes EXACTLY into token-masked and
            # noise-replaced sets.  The original slicing with
            # `perm_mask[-int(...):]` selected ALL masked nodes as noise
            # when the count truncated to 0 (shape-mismatch crash), and
            # double truncation could leave some masked nodes in neither set.
            num_token_nodes = num_mask_nodes - num_noise_nodes
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            noise_nodes = mask_nodes[perm_mask[num_token_nodes:]]
            # Source rows for the noise replacement, drawn from the
            # non-virtual candidate set.
            noise_to_be_chosen = torch.randperm(len(perm), device=x.device)[
                :num_noise_nodes
            ]
            noise_to_be_chosen = all_indices[noise_to_be_chosen]

            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token
        return out_x, (mask_nodes, keep_nodes)

    def forward(self, batch):
        """Compute the masked-attribute reconstruction loss for a PyG batch."""
        x = batch.x
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr
        virtual_node_index = getattr(batch, "virtual_node_index", None)
        # Local name chosen so the Batch object is not shadowed by the
        # node-to-graph assignment vector `batch.batch`.
        node_to_graph = batch.batch
        return self.mask_attr_prediction(
            x, edge_index, edge_attr, node_to_graph, virtual_node_index
        )

    def mask_attr_prediction(self, x, edge_index, edge_attr, batch, virtual_node_index):
        """Mask -> encode -> (re-mask) -> decode -> loss on masked nodes only."""
        use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(
            x,
            self._mask_rate,
            virtual_node_index,
        )

        if self._drop_edge_rate > 0:
            # dropout_edge returns (edge_index, edge_mask); edge_mask indexes
            # the KEPT edges, so it also selects the surviving edge attrs.
            use_edge_index, edge_mask = dropout_edge(
                edge_index, self._drop_edge_rate
            )
            use_edge_attr = edge_attr[edge_mask]
            use_edge_index, use_edge_attr = add_self_loops(
                use_edge_index, use_edge_attr, fill_value="min"
            )
        else:
            use_edge_index = edge_index
            use_edge_attr = edge_attr

        enc_rep, all_hidden = self.encoder(
            use_x, use_edge_index, use_edge_attr, return_hidden=True
        )
        if self._concat_hidden:
            enc_rep = torch.cat(all_hidden, dim=1)

        # ---- attribute reconstruction ----
        rep = self.encoder_to_decoder(enc_rep)

        if self._decoder_type in ("mlp", "linear"):
            recon = self.decoder(rep)
        else:
            # Re-mask before a graph decoder so it cannot simply copy the
            # encoder output at the masked positions.
            rep[mask_nodes] = 0
            recon = self.decoder(rep, use_edge_index, use_edge_attr)

        # Loss is evaluated on the masked nodes only.
        loss = self.criterion(recon[mask_nodes], x[mask_nodes])
        return loss

    def embed(self, x, edge_index, edge_attr, batch):
        """Return (projected) node representations for downstream use; no masking."""
        if self._concat_hidden:
            enc_rep, all_hidden = self.encoder(
                x, edge_index, edge_attr, return_hidden=True
            )
            enc_rep = torch.cat(all_hidden, dim=1)
        else:
            enc_rep = self.encoder(x, edge_index, edge_attr)
        return self.encoder_to_decoder(enc_rep)

    @property
    def enc_params(self):
        """Encoder parameters (e.g. for a separate optimizer group)."""
        return self.encoder.parameters()

    @property
    def dec_params(self):
        """Projection + decoder parameters."""
        return chain(self.encoder_to_decoder.parameters(), self.decoder.parameters())