ogutsevda commited on
Commit
8cd7b86
·
verified ·
1 Parent(s): 6177cee

Upload 3 files

Browse files
Files changed (3) hide show
  1. models/acm_gin.py +207 -0
  2. models/edcoder.py +261 -0
  3. models/utils.py +75 -0
models/acm_gin.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ (c) Adaptation of the code from https://github.com/SitaoLuan/ACM-GNN
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from torch import Tensor
9
+ from typing import Union
10
+ from torch_geometric.nn.conv import MessagePassing
11
+ from torch_geometric.nn.inits import reset
12
+ from torch_geometric.typing import OptPairTensor, OptTensor, Size
13
+ from torch_geometric.utils import scatter
14
+
15
+ from .utils import create_activation
16
+
17
+
18
class ACM_GIN(MessagePassing):
    """Adaptive Channel Mixing (ACM) GIN convolution layer.

    Aggregates neighbor features as a degree-normalized weighted mean, builds
    three filter channels from them — low-pass ``(x + agg) / 2``, high-pass
    ``(x - agg) / 2`` and full-pass ``x`` — embeds each channel with its own
    MLP, and mixes the three embeddings with node-wise attention weights.

    Args:
        nn_lowpass / nn_highpass / nn_fullpass: per-channel embedding MLPs.
        nn_lowpass_proj / nn_highpass_proj / nn_fullpass_proj: linear heads
            projecting each channel embedding to a scalar importance score.
        nn_mix: 3 -> 3 linear layer mixing the stacked channel scores.
        T: softmax temperature applied to the channel scores.
    """

    def __init__(
        self,
        nn_lowpass: torch.nn.Module,
        nn_highpass: torch.nn.Module,
        nn_fullpass: torch.nn.Module,
        nn_lowpass_proj: torch.nn.Module,
        nn_highpass_proj: torch.nn.Module,
        nn_fullpass_proj: torch.nn.Module,
        nn_mix: torch.nn.Module,
        T: float = 3.0,
        **kwargs,
    ):
        # GIN-style sum aggregation unless the caller overrides it.
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn_lowpass = nn_lowpass
        self.nn_highpass = nn_highpass
        self.nn_fullpass = nn_fullpass
        self.nn_lowpass_proj = nn_lowpass_proj
        self.nn_highpass_proj = nn_highpass_proj
        self.nn_fullpass_proj = nn_fullpass_proj
        self.nn_mix = nn_mix
        self.sigmoid = torch.nn.Sigmoid()
        self.softmax = torch.nn.Softmax(dim=1)
        self.T = T
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every wrapped sub-network."""
        reset(self.nn_lowpass)
        reset(self.nn_highpass)
        reset(self.nn_fullpass)
        reset(self.nn_lowpass_proj)
        reset(self.nn_highpass_proj)
        reset(self.nn_fullpass_proj)
        reset(self.nn_mix)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Tensor,
        edge_weight: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Compute the mixed low/high/full-pass embedding for each node.

        Returns:
            Tensor of shape ``[num_nodes, out_dim]``.

        Raises:
            ValueError: if the root node features (``x[1]``) are None.
        """
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # BUGFIX: ``message`` and the degree computation below dereference
        # edge_weight, so the declared OptTensor default of None crashed.
        # Default to all-ones weights for unweighted graphs.
        if edge_weight is None:
            edge_weight = torch.ones(
                edge_index.size(1), dtype=x[0].dtype, device=edge_index.device
            )

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)

        # Degree-normalize the aggregated messages (weighted mean over
        # incoming edges); isolated nodes get 0 instead of inf.
        deg = scatter(edge_weight, edge_index[1], 0, out.size(0), reduce="sum")
        deg_inv = 1.0 / deg
        deg_inv.masked_fill_(deg_inv == float("inf"), 0)
        out = deg_inv.view(-1, 1) * out

        x_r = x[1]
        if x_r is None:
            # BUGFIX: the original silently fell through and later raised a
            # NameError on undefined locals; fail explicitly instead.
            raise ValueError("ACM_GIN requires root node features (x[1] is None)")

        out_lowpass = (x_r + out) / 2.0
        out_highpass = (x_r - out) / 2.0

        # Compute embeddings for each filter channel.
        out_lowpass = self.nn_lowpass(out_lowpass)
        out_highpass = self.nn_highpass(out_highpass)
        out_fullpass = self.nn_fullpass(x_r)

        # Node-wise importance weights per channel: sigmoid scores, stacked,
        # mixed by nn_mix and softmax-normalized with temperature T.
        alpha_lowpass = self.sigmoid(self.nn_lowpass_proj(out_lowpass))
        alpha_highpass = self.sigmoid(self.nn_highpass_proj(out_highpass))
        alpha_fullpass = self.sigmoid(self.nn_fullpass_proj(out_fullpass))
        alpha_cat = torch.concat([alpha_lowpass, alpha_highpass, alpha_fullpass], dim=1)
        alpha_cat = self.softmax(self.nn_mix(alpha_cat / self.T))

        # Convex-combine the channels with their attention weights.
        out = alpha_cat[:, 0].view(-1, 1) * out_lowpass
        out = out + alpha_cat[:, 1].view(-1, 1) * out_highpass
        out = out + alpha_cat[:, 2].view(-1, 1) * out_fullpass
        return out

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        # Scale each neighbor feature by its edge weight.
        return edge_weight.view(-1, 1) * x_j

    def __repr__(self) -> str:
        # BUGFIX: the original referenced the non-existent ``self.nn`` and
        # raised AttributeError whenever the layer was printed.
        return (
            f"{self.__class__.__name__}(nn_lowpass={self.nn_lowpass}, "
            f"nn_highpass={self.nn_highpass}, nn_fullpass={self.nn_fullpass})"
        )
100
+
101
+
102
class ACM_GIN_model(nn.Module):
    """Multi-layer ACM-GIN network.

    Args:
        in_dim: input node-feature dimension.
        out_dim: output dimension of the final layer.
        num_layers: number of ACM_GIN convolution layers.
        hidden_dim: hidden width of the per-channel GIN MLPs.
        batchnorm: insert BatchNorm1d inside the MLPs when True.
        activation: activation name resolved by ``create_activation``.
    """

    def __init__(
        self, in_dim, out_dim, num_layers, hidden_dim, batchnorm, activation="relu"
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.gnn_batchnorm = batchnorm
        self.out_dim = out_dim

        self.ACM_convs = nn.ModuleList()
        self.nns_lowpass = nn.ModuleList()
        self.nns_highpass = nn.ModuleList()
        self.nns_fullpass = nn.ModuleList()
        self.nns_lowpass_proj = nn.ModuleList()
        self.nns_highpass_proj = nn.ModuleList()
        self.nns_fullpass_proj = nn.ModuleList()
        self.nns_mix = nn.ModuleList()

        # Kept for backward compatibility; the Sequentials below use fresh
        # activation instances instead of sharing this one.
        self.activation = create_activation(activation)

        for i in range(self.num_layers):
            is_last = i == self.num_layers - 1
            local_input_dim = in_dim if i == 0 else self.hidden_dim
            local_out_dim = self.out_dim if is_last else self.hidden_dim

            # Projection heads computing one scalar importance score per
            # channel from that channel's embedding.
            for channel_proj_module in (
                self.nns_lowpass_proj,
                self.nns_highpass_proj,
                self.nns_fullpass_proj,
            ):
                channel_proj_module.append(nn.Linear(local_out_dim, 1))
            # 3 -> 3 mixing layer acting on the stacked channel scores
            # (attention mechanism over the three filters).
            self.nns_mix.append(nn.Linear(3, 3))

            # Two-layer GIN MLP per channel. BUGFIX: each Sequential gets its
            # own activation module — the original reused one shared instance,
            # which ties parameters across all layers and channels for
            # learnable activations such as PReLU.
            for channel_module in (
                self.nns_lowpass,
                self.nns_highpass,
                self.nns_fullpass,
            ):
                layers = [nn.Linear(local_input_dim, self.hidden_dim)]
                if self.gnn_batchnorm:
                    layers.append(nn.BatchNorm1d(self.hidden_dim))
                layers.append(create_activation(activation))
                layers.append(nn.Linear(self.hidden_dim, local_out_dim))
                if self.gnn_batchnorm:
                    layers.append(nn.BatchNorm1d(local_out_dim))
                layers.append(create_activation(activation))
                channel_module.append(nn.Sequential(*layers))

            self.ACM_convs.append(
                ACM_GIN(
                    nn_lowpass=self.nns_lowpass[i],
                    nn_highpass=self.nns_highpass[i],
                    nn_fullpass=self.nns_fullpass[i],
                    nn_lowpass_proj=self.nns_lowpass_proj[i],
                    nn_highpass_proj=self.nns_highpass_proj[i],
                    nn_fullpass_proj=self.nns_fullpass_proj[i],
                    nn_mix=self.nns_mix[i],
                )
            )

    def reset_parameters(self):
        """Reset every Linear and BatchNorm1d sub-module."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.BatchNorm1d)):
                m.reset_parameters()

    def forward(self, x, edge_index, edge_attr, return_hidden=False):
        """Run the conv stack; optionally also return every layer's output."""
        outs = []
        for conv in self.ACM_convs:
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_attr)
            outs.append(x)
        return (x, outs) if return_hidden else x
202
+
203
+
204
if __name__ == "__main__":
    # Quick smoke test: build a small model and report its trainable
    # parameter count.
    acm_gin = ACM_GIN_model(46, 46, 2, 256, True)
    trainable = sum(p.numel() for p in acm_gin.parameters() if p.requires_grad)
    print(trainable)
    print("")
models/edcoder.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ (c) Adaptation of the code from https://github.com/THUDM/GraphMAE
3
+ """
4
+
5
+ from typing import Optional
6
+ from itertools import chain
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch_geometric.utils import dropout_edge
13
+ from torch_geometric.utils import add_self_loops
14
+
15
+ from .acm_gin import ACM_GIN_model
16
+
17
+
18
def sce_loss(x, y, alpha=3):
    """Scaled cosine error: ``mean((1 - cos(x_i, y_i)) ** alpha)``.

    Both inputs are L2-normalized along the last dimension, so the loss
    depends only on direction, not magnitude.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cos_sim = (x_unit * y_unit).sum(dim=-1)
    return (1 - cos_sim).pow(alpha).mean()
26
+
27
+
28
+ def setup_module(
29
+ m_type,
30
+ in_dim,
31
+ out_dim,
32
+ num_hidden,
33
+ num_layers,
34
+ activation,
35
+ batchnorm,
36
+ ) -> nn.Module:
37
+
38
+ if m_type == "acm_gin":
39
+ mod = ACM_GIN_model(
40
+ int(in_dim),
41
+ int(out_dim),
42
+ num_layers,
43
+ int(num_hidden),
44
+ batchnorm,
45
+ activation=activation,
46
+ )
47
+ else:
48
+ raise NotImplementedError
49
+
50
+ return mod
51
+
52
+
53
class PreModel(nn.Module):
    """GraphMAE-style masked autoencoder for node attributes.

    Masks a fraction of the nodes (optionally replacing a sub-fraction with
    random other nodes' features), encodes the corrupted graph, re-masks the
    latent codes, decodes, and scores the reconstruction of the masked
    attributes with either MSE or scaled-cosine-error loss.

    NOTE(review): several constructor arguments (edge_in_dim, nhead,
    nhead_out, feat_drop, attn_drop, negative_slope, residual, norm) are
    accepted but unused by the "acm_gin" backend; they are kept for interface
    compatibility with other encoder types.
    """

    def __init__(
        self,
        in_dim: int,
        edge_in_dim: int,
        num_hidden: int,
        num_layers: int,
        nhead: int,
        nhead_out: int,
        activation: str,
        feat_drop: float,
        attn_drop: float,
        negative_slope: float,
        residual: bool,
        norm: Optional[str],
        mask_rate: float = 0.3,
        encoder_type: str = "gat",
        decoder_type: str = "gat",
        loss_fn: str = "sce",
        drop_edge_rate: float = 0.0,
        replace_rate: float = 0.1,
        alpha_l: float = 2,
        concat_hidden: bool = False,
        batchnorm=False,
    ):
        super().__init__()
        self._mask_rate = mask_rate
        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._drop_edge_rate = drop_edge_rate
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden

        # Of the masked nodes: this fraction is replaced by other nodes'
        # features, the rest receive the learned mask token.
        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate

        assert num_hidden % nhead == 0
        assert num_hidden % nhead_out == 0

        enc_num_hidden = num_hidden
        dec_in_dim = num_hidden
        dec_num_hidden = num_hidden

        # Build encoder.
        self.encoder = setup_module(
            m_type=encoder_type,
            in_dim=in_dim,
            out_dim=enc_num_hidden,
            num_hidden=enc_num_hidden,
            num_layers=num_layers,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Build decoder for attribute prediction (single layer, back to the
        # input feature dimension).
        self.decoder = setup_module(
            m_type=decoder_type,
            in_dim=dec_in_dim,
            out_dim=in_dim,
            num_hidden=dec_num_hidden,
            num_layers=1,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Learned token added to masked-out node features.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
        if concat_hidden:
            self.encoder_to_decoder = nn.Linear(
                dec_in_dim * num_layers, dec_in_dim, bias=False
            )
        else:
            self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)

        # Setup loss function.
        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    @property
    def output_hidden_dim(self):
        """Width of the encoder output representation."""
        return self._output_hidden_size

    def setup_loss_fn(self, loss_fn, alpha_l):
        """Return the reconstruction criterion ("mse" or "sce")."""
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        else:
            raise NotImplementedError(f"Unsupported loss function: {loss_fn!r}")
        return criterion

    def encoding_mask_noise(self, x, mask_rate=0.3, virtual_node_index=None):
        """Corrupt node features by masking/replacing a random subset.

        Returns:
            (out_x, (mask_nodes, keep_nodes)) — corrupted features plus the
            index tensors of masked and untouched nodes.
        """
        num_nodes = x.shape[0]
        all_indices = torch.arange(num_nodes, device=x.device)

        # Remove virtual node index from masking candidates.
        if virtual_node_index is not None:
            all_indices = all_indices[~torch.isin(all_indices, virtual_node_index)]

        perm = all_indices[torch.randperm(len(all_indices), device=x.device)]

        # Random masking.
        num_mask_nodes = int(mask_rate * len(perm))
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()

        if self._replace_rate > 0:
            num_noise_nodes = int(self._replace_rate * num_mask_nodes)
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[
                perm_mask[: int(self._mask_token_rate * num_mask_nodes)]
            ]
            out_x[token_nodes] = 0.0
            # BUGFIX: when num_noise_nodes == 0 the original sliced
            # ``perm_mask[-0:]`` — i.e. the WHOLE permutation — selecting every
            # mask node as a noise node and then failing to assign an empty
            # replacement tensor to them. Skip the replacement entirely in
            # that case.
            if num_noise_nodes > 0:
                noise_nodes = mask_nodes[perm_mask[-num_noise_nodes:]]
                noise_to_be_chosen = torch.randperm(len(perm), device=x.device)[
                    :num_noise_nodes
                ]
                noise_to_be_chosen = all_indices[noise_to_be_chosen]
                out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token

        return out_x, (mask_nodes, keep_nodes)

    def forward(self, batch):
        """Compute the masked-attribute reconstruction loss for a PyG batch."""
        x, edge_index, edge_attr, virtual_node_index, batch_vec = (
            batch.x,
            batch.edge_index,
            batch.edge_attr,
            getattr(batch, "virtual_node_index", None),
            batch.batch,
        )
        loss = self.mask_attr_prediction(
            x, edge_index, edge_attr, batch_vec, virtual_node_index
        )
        return loss

    def mask_attr_prediction(self, x, edge_index, edge_attr, batch, virtual_node_index):
        """Mask, encode, re-mask, decode, and score the reconstruction."""
        use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(
            x,
            self._mask_rate,
            virtual_node_index,
        )

        if self._drop_edge_rate > 0:
            # dropout_edge returns (edge_index, edge_mask) where edge_mask
            # marks the KEPT edges; use it to subset the edge attributes.
            use_edge_index, edge_mask = dropout_edge(edge_index, self._drop_edge_rate)
            use_edge_attr = edge_attr[edge_mask]
            use_edge_index, use_edge_attr = add_self_loops(
                use_edge_index, use_edge_attr, fill_value="min"
            )
        else:
            use_edge_index = edge_index
            use_edge_attr = edge_attr

        enc_rep, all_hidden = self.encoder(
            use_x, use_edge_index, use_edge_attr, return_hidden=True
        )
        if self._concat_hidden:
            enc_rep = torch.cat(all_hidden, dim=1)

        # ---- attribute reconstruction ----
        rep = self.encoder_to_decoder(enc_rep)

        if self._decoder_type not in ("mlp", "linear"):
            # Re-mask the latent codes so a graph decoder must reconstruct
            # masked nodes from their neighborhoods.
            rep[mask_nodes] = 0

        if self._decoder_type in ("mlp", "linear"):
            recon = self.decoder(rep)
        else:
            recon = self.decoder(rep, use_edge_index, use_edge_attr)

        # Loss is computed on the masked nodes only.
        x_init = x[mask_nodes]
        x_rec = recon[mask_nodes]
        loss = self.criterion(x_rec, x_init)
        return loss

    def embed(self, x, edge_index, edge_attr, batch):
        """Encode a (clean) graph and project into the decoder input space."""
        if self._concat_hidden:
            enc_rep, all_hidden = self.encoder(
                x, edge_index, edge_attr, return_hidden=True
            )
            enc_rep = torch.cat(all_hidden, dim=1)
        else:
            enc_rep = self.encoder(x, edge_index, edge_attr)
        rep = self.encoder_to_decoder(enc_rep)
        return rep

    @property
    def enc_params(self):
        """Encoder parameters (e.g. for a separate optimizer group)."""
        return self.encoder.parameters()

    @property
    def dec_params(self):
        """Decoder-side parameters: projection plus decoder."""
        return chain(self.encoder_to_decoder.parameters(), self.decoder.parameters())
models/utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+
5
+
6
def create_activation(name):
    """Return a fresh activation module for *name* (None -> Identity).

    Raises:
        NotImplementedError: for names outside the supported set.
    """
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "prelu": nn.PReLU,
        "elu": nn.ELU,
        None: nn.Identity,
    }
    if name not in factories:
        raise NotImplementedError(f"{name} is not implemented.")
    return factories[name]()
19
+
20
+
21
def create_norm(name):
    """Return a normalization-layer factory for *name*.

    Unknown names (including None) fall back to ``nn.Identity``.
    """
    if name == "layernorm":
        return nn.LayerNorm
    if name == "batchnorm":
        return nn.BatchNorm1d
    if name == "graphnorm":
        # NOTE(review): NormLayer has no "groupnorm" branch in __init__, so
        # constructing this factory's output raises NotImplementedError —
        # confirm whether "graphnorm" was intended here.
        return partial(NormLayer, norm_type="groupnorm")
    return nn.Identity
30
+
31
+
32
class NormLayer(nn.Module):
    """Normalization layer: BatchNorm1d, LayerNorm, or a per-graph norm.

    For "graphnorm", ``self.norm`` holds the string tag and the forward pass
    normalizes each graph in the batch independently using learned weight,
    bias, and mean-scale parameters.
    """

    def __init__(self, hidden_dim, norm_type):
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # The string itself is stored as a sentinel; forward() detects it
            # by type and runs the manual per-graph normalization instead.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))

            # Learnable scale on the subtracted per-graph mean.
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError

    def forward(self, graph, x):
        """Normalize node features ``x``; ``graph`` supplies batch structure.

        NOTE(review): the graph-norm path reads ``graph.batch_num_nodes``
        (a per-graph node-count sequence, DGL-style) — confirm against the
        caller, since the rest of this repo uses torch_geometric batches.
        """
        tensor = x
        # Module-based norms (BatchNorm1d/LayerNorm) are applied directly;
        # the "graphnorm" sentinel string falls through to the manual path.
        if self.norm is not None and type(self.norm) != str:
            return self.norm(tensor)
        elif self.norm is None:
            # NOTE(review): unreachable given __init__ always sets self.norm;
            # kept as a defensive no-op.
            return tensor

        batch_list = graph.batch_num_nodes
        batch_size = len(batch_list)
        batch_list = torch.Tensor(batch_list).long().to(tensor.device)
        # Graph id of each node, expanded to the feature shape so it can be
        # used as a scatter_add_ index.
        batch_index = (
            torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
        )
        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(
            tensor
        )
        # Per-graph mean of node features, broadcast back to node resolution.
        mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        mean = mean.scatter_add_(0, batch_index, tensor)
        mean = (mean.T / batch_list).T
        mean = mean.repeat_interleave(batch_list, dim=0)

        # Center with a learnable scale on the mean.
        sub = tensor - mean * self.mean_scale

        # Per-graph standard deviation (epsilon-stabilized), broadcast back.
        std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        std = std.scatter_add_(0, batch_index, sub.pow(2))
        std = ((std.T / batch_list).T + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)
        return self.weight * sub / std + self.bias