ogutsevda committed on
Commit
8536369
·
verified ·
1 Parent(s): 2eabbb9

Delete models

Browse files
Files changed (4) hide show
  1. models/__init__.py +0 -55
  2. models/acm_gin.py +0 -207
  3. models/edcoder.py +0 -260
  4. models/utils.py +0 -75
models/__init__.py DELETED
@@ -1,55 +0,0 @@
1
- """
2
- (c) Adaptation of the code from https://github.com/THUDM/GraphMAE
3
- """
4
-
5
- from .edcoder import PreModel
6
-
7
-
8
def build_model(args):
    """Instantiate a ``PreModel`` graph autoencoder from a parsed argument namespace.

    All hyperparameters are read directly off ``args``; the feature and hidden
    dimensions are coerced to ``int`` before construction, mirroring the
    model's expectations.
    """
    return PreModel(
        in_dim=int(args.num_features),
        edge_in_dim=int(args.num_edge_features),
        num_hidden=int(args.num_hidden),
        num_layers=args.num_layers,
        nhead=args.num_heads,
        nhead_out=args.num_out_heads,
        activation=args.activation,
        feat_drop=args.in_drop,
        attn_drop=args.attn_drop,
        negative_slope=args.negative_slope,
        residual=args.residual,
        encoder_type=args.encoder,
        decoder_type=args.decoder,
        mask_rate=args.mask_rate,
        norm=args.norm,
        loss_fn=args.loss_fn,
        drop_edge_rate=args.drop_edge_rate,
        replace_rate=args.replace_rate,
        alpha_l=args.alpha_l,
        concat_hidden=args.concat_hidden,
        batchnorm=args.batchnorm,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/acm_gin.py DELETED
@@ -1,207 +0,0 @@
1
- """
2
- (c) Adaptation of the code from https://github.com/SitaoLuan/ACM-GNN
3
- """
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from torch import Tensor
9
- from typing import Union
10
- from torch_geometric.nn.conv import MessagePassing
11
- from torch_geometric.nn.inits import reset
12
- from torch_geometric.typing import OptPairTensor, OptTensor, Size
13
- from torch_geometric.utils import scatter
14
-
15
- from models.utils import create_activation
16
-
17
-
18
class ACM_GIN(MessagePassing):
    """Adaptive Channel Mixing GIN convolution (ACM-GNN style).

    Combines three filter channels of the input signal:
      * low-pass:  ``(x + agg) / 2``
      * high-pass: ``(x - agg) / 2``
      * full-pass: ``x`` itself
    where ``agg`` is the degree-normalized, edge-weighted sum of neighbor
    features.  Each channel is embedded by its own MLP, per-node importance
    scores come from per-channel linear projections, and the channels are
    mixed with softmax attention weights sharpened by temperature ``T``.
    """

    def __init__(
        self,
        nn_lowpass: torch.nn.Module,
        nn_highpass: torch.nn.Module,
        nn_fullpass: torch.nn.Module,
        nn_lowpass_proj: torch.nn.Module,
        nn_highpass_proj: torch.nn.Module,
        nn_fullpass_proj: torch.nn.Module,
        nn_mix: torch.nn.Module,
        T: float = 3.0,
        **kwargs,
    ):
        # Default aggregation is a plain sum; normalization happens in forward.
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn_lowpass = nn_lowpass
        self.nn_highpass = nn_highpass
        self.nn_fullpass = nn_fullpass
        self.nn_lowpass_proj = nn_lowpass_proj
        self.nn_highpass_proj = nn_highpass_proj
        self.nn_fullpass_proj = nn_fullpass_proj
        self.nn_mix = nn_mix
        self.sigmoid = torch.nn.Sigmoid()
        self.softmax = torch.nn.Softmax(dim=1)
        self.T = T
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every sub-network in place."""
        reset(self.nn_lowpass)
        reset(self.nn_highpass)
        reset(self.nn_fullpass)
        reset(self.nn_lowpass_proj)
        reset(self.nn_highpass_proj)
        reset(self.nn_fullpass_proj)
        reset(self.nn_mix)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Tensor,
        edge_weight: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Run one ACM-GIN convolution and return the mixed node embeddings."""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # BUG FIX: edge_weight is declared optional, but the original crashed
        # on None (both message() and scatter() require a tensor).  Fall back
        # to unit weights so the unweighted case is well defined.
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=size)

        # Degree-normalize the aggregated messages (weighted mean over neighbors);
        # isolated nodes (degree 0) are zeroed instead of producing inf.
        deg = scatter(edge_weight, edge_index[1], 0, out.size(0), reduce="sum")
        deg_inv = 1.0 / deg
        deg_inv.masked_fill_(deg_inv == float("inf"), 0)
        out = deg_inv.view(-1, 1) * out

        x_r = x[1]
        # x is always packed as (x, x) above, so x_r is present in practice;
        # the guard mirrors the upstream GINConv pattern.
        if x_r is not None:
            out_lowpass = (x_r + out) / 2.0
            out_highpass = (x_r - out) / 2.0

            # compute embeddings for each filter
            out_lowpass = self.nn_lowpass(out_lowpass)
            out_highpass = self.nn_highpass(out_highpass)
            out_fullpass = self.nn_fullpass(x_r)
            # compute importance weights per filter (temperature-scaled softmax
            # over the three sigmoid scores)
            alpha_lowpass = self.sigmoid(self.nn_lowpass_proj(out_lowpass))
            alpha_highpass = self.sigmoid(self.nn_highpass_proj(out_highpass))
            alpha_fullpass = self.sigmoid(self.nn_fullpass_proj(out_fullpass))
            alpha_cat = torch.concat([alpha_lowpass, alpha_highpass, alpha_fullpass], dim=1)
            alpha_cat = self.softmax(self.nn_mix(alpha_cat / self.T))

            # Convex-like combination of the three channel embeddings.
            out = alpha_cat[:, 0].view(-1, 1) * out_lowpass
            out = out + alpha_cat[:, 1].view(-1, 1) * out_highpass
            out = out + alpha_cat[:, 2].view(-1, 1) * out_fullpass

        return out

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        # Scale each neighbor feature by its edge weight before aggregation.
        return edge_weight.view(-1, 1) * x_j

    def __repr__(self) -> str:
        # BUG FIX: the original referenced the non-existent attribute
        # ``self.nn`` (this class stores seven named sub-networks instead),
        # so repr() raised AttributeError.
        return (
            f"{self.__class__.__name__}(nn_lowpass={self.nn_lowpass}, "
            f"nn_highpass={self.nn_highpass}, nn_fullpass={self.nn_fullpass})"
        )
100
-
101
-
102
class ACM_GIN_model(nn.Module):
    """A stack of ``ACM_GIN`` convolutions with per-channel MLPs.

    Hidden layers map to ``hidden_dim``; the last layer maps to ``out_dim``.
    Each layer owns three channel MLPs (low/high/full pass), three score
    projections, and a 3x3 mixing layer.
    """

    def __init__(
        self, in_dim, out_dim, num_layers, hidden_dim, batchnorm, activation="relu"
    ):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.gnn_batchnorm = batchnorm
        self.out_dim = out_dim

        self.ACM_convs = nn.ModuleList()
        self.nns_lowpass = nn.ModuleList()
        self.nns_highpass = nn.ModuleList()
        self.nns_fullpass = nn.ModuleList()
        self.nns_lowpass_proj = nn.ModuleList()
        self.nns_highpass_proj = nn.ModuleList()
        self.nns_fullpass_proj = nn.ModuleList()
        self.nns_mix = nn.ModuleList()

        # Single shared activation module, reused inside every channel MLP.
        self.activation = create_activation(activation)

        for layer_idx in range(self.num_layers):
            is_last = layer_idx == self.num_layers - 1
            # projection modules to compute per-channel importance weights
            proj_in = self.out_dim if is_last else self.hidden_dim
            for proj_list in (
                self.nns_lowpass_proj,
                self.nns_highpass_proj,
                self.nns_fullpass_proj,
            ):
                proj_list.append(nn.Linear(proj_in, 1))
            # weights mixing module acting as a small attention mechanism
            self.nns_mix.append(nn.Linear(3, 3))

            # GIN embedding scheme per channel
            layer_in = in_dim if layer_idx == 0 else self.hidden_dim
            layer_out = self.out_dim if is_last else self.hidden_dim
            for channel_list in (
                self.nns_lowpass,
                self.nns_highpass,
                self.nns_fullpass,
            ):
                channel_list.append(self._make_channel_mlp(layer_in, layer_out))

            self.ACM_convs.append(
                ACM_GIN(
                    nn_lowpass=self.nns_lowpass[layer_idx],
                    nn_highpass=self.nns_highpass[layer_idx],
                    nn_fullpass=self.nns_fullpass[layer_idx],
                    nn_lowpass_proj=self.nns_lowpass_proj[layer_idx],
                    nn_highpass_proj=self.nns_highpass_proj[layer_idx],
                    nn_fullpass_proj=self.nns_fullpass_proj[layer_idx],
                    nn_mix=self.nns_mix[layer_idx],
                )
            )

    def _make_channel_mlp(self, in_features, out_features):
        """Build the two-layer MLP for one filter channel (batch-normed if enabled)."""
        if self.gnn_batchnorm:
            return nn.Sequential(
                nn.Linear(in_features, self.hidden_dim),
                nn.BatchNorm1d(self.hidden_dim),
                self.activation,
                nn.Linear(self.hidden_dim, out_features),
                nn.BatchNorm1d(out_features),
                self.activation,
            )
        return nn.Sequential(
            nn.Linear(in_features, self.hidden_dim),
            self.activation,
            nn.Linear(self.hidden_dim, out_features),
            self.activation,
        )

    def reset_parameters(self):
        """Re-initialize every Linear / BatchNorm1d submodule in place."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.BatchNorm1d)):
                module.reset_parameters()

    def forward(self, x, edge_index, edge_attr, return_hidden=False):
        """Apply all layers sequentially; optionally return every layer's output."""
        hidden_states = []
        for conv in self.ACM_convs:
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_attr)
            hidden_states.append(x)
        return (x, hidden_states) if return_hidden else x
202
-
203
-
204
if __name__ == "__main__":
    # Smoke test: build a small model and report its trainable parameter count.
    model = ACM_GIN_model(46, 46, 2, 256, True)
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(n_trainable)
    print("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/edcoder.py DELETED
@@ -1,260 +0,0 @@
1
- """
2
- (c) Adaptation of the code from https://github.com/THUDM/GraphMAE
3
- """
4
-
5
- from typing import Optional
6
- from itertools import chain
7
- from functools import partial
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- from .acm_gin import ACM_GIN_model
13
- from torch_geometric.utils import dropout_edge
14
- from torch_geometric.utils import add_self_loops
15
-
16
-
17
def sce_loss(x, y, alpha=3):
    """Scaled cosine error between prediction ``x`` and target ``y``.

    Both inputs are L2-normalized along the last dimension; the per-row error
    ``(1 - cos_sim)`` is raised to the power ``alpha`` (down-weighting easy,
    well-aligned rows) and averaged over all rows.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cos_sim = (x_unit * y_unit).sum(dim=-1)
    return (1.0 - cos_sim).pow(alpha).mean()
25
-
26
-
27
- def setup_module(
28
- m_type,
29
- in_dim,
30
- out_dim,
31
- num_hidden,
32
- num_layers,
33
- activation,
34
- batchnorm,
35
- ) -> nn.Module:
36
-
37
- if m_type == "acm_gin":
38
- mod = ACM_GIN_model(
39
- int(in_dim),
40
- int(out_dim),
41
- num_layers,
42
- int(num_hidden),
43
- batchnorm,
44
- activation=activation,
45
- )
46
- else:
47
- raise NotImplementedError
48
-
49
- return mod
50
-
51
-
52
class PreModel(nn.Module):
    """Masked graph autoencoder for node-attribute reconstruction (GraphMAE-style).

    A fraction of node features is masked (set to a learnable token) or
    replaced by other nodes' features, the corrupted graph is encoded,
    the representation is projected and re-masked, and a decoder reconstructs
    the original features of the masked nodes.  The reconstruction loss is
    the training objective.

    Note: several constructor arguments (edge_in_dim, nhead, nhead_out,
    feat_drop, attn_drop, negative_slope, residual, norm) are accepted but
    not used by the currently visible code paths — presumably kept for
    interface compatibility with other encoder types; confirm before removal.
    """

    def __init__(
        self,
        in_dim: int,
        edge_in_dim: int,
        num_hidden: int,
        num_layers: int,
        nhead: int,
        nhead_out: int,
        activation: str,
        feat_drop: float,
        attn_drop: float,
        negative_slope: float,
        residual: bool,
        norm: Optional[str],
        mask_rate: float = 0.3,
        encoder_type: str = "gat",
        decoder_type: str = "gat",
        loss_fn: str = "sce",
        drop_edge_rate: float = 0.0,
        replace_rate: float = 0.1,
        alpha_l: float = 2,
        concat_hidden: bool = False,
        batchnorm=False,
    ):
        super(PreModel, self).__init__()
        self._mask_rate = mask_rate
        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._drop_edge_rate = drop_edge_rate
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden

        # Masked nodes are split: a _mask_token_rate fraction gets the learnable
        # mask token, a _replace_rate fraction gets another node's features.
        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate

        # Head counts must divide the hidden size (relevant for attention encoders).
        assert num_hidden % nhead == 0
        assert num_hidden % nhead_out == 0

        enc_num_hidden = num_hidden
        enc_nhead = 1

        dec_in_dim = num_hidden
        dec_num_hidden = num_hidden

        # Build encoder
        self.encoder = setup_module(
            m_type=encoder_type,
            in_dim=in_dim,
            out_dim=enc_num_hidden,
            num_hidden=enc_num_hidden,
            num_layers=num_layers,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Build decoder for attribute prediction (single layer, maps back to in_dim)
        self.decoder = setup_module(
            m_type=decoder_type,
            in_dim=dec_in_dim,
            out_dim=in_dim,
            num_hidden=dec_num_hidden,
            num_layers=1,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Learnable token added to masked nodes' (zeroed) features.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
        if concat_hidden:
            # Concatenated per-layer encoder outputs must be projected back down.
            self.encoder_to_decoder = nn.Linear(
                dec_in_dim * num_layers, dec_in_dim, bias=False
            )
        else:
            self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)

        # Setup loss function
        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    @property
    def output_hidden_dim(self):
        # Hidden size of the encoder output (before any concat projection).
        return self._output_hidden_size

    def setup_loss_fn(self, loss_fn, alpha_l):
        """Return the reconstruction criterion: MSE or scaled cosine error."""
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        else:
            raise NotImplementedError
        return criterion

    def encoding_mask_noise(self, x, mask_rate=0.3, virtual_node_index=None):
        """Corrupt node features for masked-autoencoder training.

        Selects ``mask_rate`` of the nodes (excluding any virtual nodes);
        most of them are zeroed and given the learnable mask token, the rest
        are overwritten with features copied from random nodes.

        Returns:
            (out_x, (mask_nodes, keep_nodes)): corrupted features plus the
            index tensors of masked and untouched nodes.
        """
        num_nodes = x.shape[0]
        all_indices = torch.arange(num_nodes, device=x.device)

        # Remove virtual node index from masking candidates
        if virtual_node_index is not None:
            all_indices = all_indices[~torch.isin(all_indices, virtual_node_index)]

        perm = all_indices[torch.randperm(len(all_indices), device=x.device)]

        # random masking
        num_mask_nodes = int(mask_rate * len(perm))
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()

        if self._replace_rate > 0:
            num_noise_nodes = int(self._replace_rate * num_mask_nodes)
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[
                perm_mask[: int(self._mask_token_rate * num_mask_nodes)]
            ]
            # NOTE(review): if int(replace_rate * num_mask_nodes) == 0 the
            # slice below is [-0:] == [0:], i.e. ALL masked nodes become
            # noise nodes — confirm this rounding edge case is intended
            # (same behavior as the upstream GraphMAE code).
            noise_nodes = mask_nodes[
                perm_mask[-int(self._replace_rate * num_mask_nodes) :]
            ]
            noise_to_be_chosen = torch.randperm(len(perm), device=x.device)[
                :num_noise_nodes
            ]
            noise_to_be_chosen = all_indices[noise_to_be_chosen]

            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        # Token nodes carry the learnable mask embedding on top of zeros.
        out_x[token_nodes] += self.enc_mask_token

        return out_x, (mask_nodes, keep_nodes)

    def forward(self, batch):
        """Compute the masked-reconstruction loss for a PyG-style batch object."""
        # ---- attribute reconstruction ----
        # NOTE: ``batch`` is rebound here from the batch object to its
        # node-to-graph assignment vector (batch.batch).
        x, edge_index, edge_attr, virtual_node_index, batch = (
            batch.x,
            batch.edge_index,
            batch.edge_attr,
            getattr(batch, "virtual_node_index", None),
            batch.batch,
        )
        loss = self.mask_attr_prediction(
            x, edge_index, edge_attr, batch, virtual_node_index
        )
        return loss

    def mask_attr_prediction(self, x, edge_index, edge_attr, batch, virtual_node_index):
        """Mask features, optionally drop edges, encode, decode, and score."""
        use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(
            x,
            self._mask_rate,
            virtual_node_index,
        )

        if self._drop_edge_rate > 0:
            # dropout_edge returns (edge_index, edge_mask); the mask is used
            # to select the matching edge attributes — presumably True marks
            # retained edges (verify against the installed PyG version).
            use_edge_index, masked_edges = dropout_edge(
                edge_index, self._drop_edge_rate
            )
            use_edge_attr = edge_attr[masked_edges]
            use_edge_index, use_edge_attr = add_self_loops(
                use_edge_index, use_edge_attr, fill_value="min"
            )
        else:
            use_edge_index = edge_index
            use_edge_attr = edge_attr

        enc_rep, all_hidden = self.encoder(
            use_x, use_edge_index, use_edge_attr, return_hidden=True
        )
        if self._concat_hidden:
            enc_rep = torch.cat(all_hidden, dim=1)

        # ---- attribute reconstruction ----
        rep = self.encoder_to_decoder(enc_rep)

        if self._decoder_type not in ("mlp", "linear"):
            # * remask, re-mask: graph decoders see zeros at masked positions
            rep[mask_nodes] = 0

        if self._decoder_type in ("mlp", "linear"):
            recon = self.decoder(rep)
        else:
            recon = self.decoder(rep, use_edge_index, use_edge_attr)

        # Loss is computed only on the masked nodes.
        x_init = x[mask_nodes]
        x_rec = recon[mask_nodes]

        loss = self.criterion(x_rec, x_init)

        return loss

    def embed(self, x, edge_index, edge_attr, batch):
        """Encode a (non-corrupted) graph and project through encoder_to_decoder."""
        if self._concat_hidden:
            enc_rep, all_hidden = self.encoder(
                x, edge_index, edge_attr, return_hidden=True
            )
            enc_rep = torch.cat(all_hidden, dim=1)
        else:
            enc_rep = self.encoder(x, edge_index, edge_attr)
        rep = self.encoder_to_decoder(enc_rep)
        return rep

    @property
    def enc_params(self):
        # Encoder-only parameters (e.g. for a separate optimizer group).
        return self.encoder.parameters()

    @property
    def dec_params(self):
        # Projection + decoder parameters, chained into one iterator.
        return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/utils.py DELETED
@@ -1,75 +0,0 @@
1
- import torch.nn as nn
2
- import torch
3
- from functools import partial
4
-
5
-
6
def create_activation(name):
    """Return a fresh activation module for the given name.

    Supported names: "relu", "gelu", "prelu", "elu", and ``None`` (identity).
    Any other name raises NotImplementedError.
    """
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "prelu": nn.PReLU,
        "elu": nn.ELU,
        None: nn.Identity,
    }
    if name not in factories:
        raise NotImplementedError(f"{name} is not implemented.")
    return factories[name]()
19
-
20
-
21
def create_norm(name):
    """Map a norm name to a constructor taking the feature dimension.

    Returns the ``nn.LayerNorm`` / ``nn.BatchNorm1d`` class itself, a
    ``NormLayer`` factory for "graphnorm", or ``nn.Identity`` for any other
    value (including None).
    """
    if name == "layernorm":
        return nn.LayerNorm
    elif name == "batchnorm":
        return nn.BatchNorm1d
    elif name == "graphnorm":
        # BUG FIX: the original passed norm_type="groupnorm", a tag that
        # NormLayer.__init__ in this module does not accept (it raises
        # NotImplementedError), so the "graphnorm" option was unusable.
        # NormLayer implements graph normalization under the "graphnorm" tag.
        return partial(NormLayer, norm_type="graphnorm")
    else:
        return nn.Identity
30
-
31
-
32
class NormLayer(nn.Module):
    """Normalization layer supporting batch, layer, and per-graph ("graph") norm.

    For "batchnorm"/"layernorm", ``self.norm`` holds the wrapped module and
    forward simply applies it.  For "graphnorm", ``self.norm`` holds the
    string tag and forward computes a per-graph mean/std normalization with
    learnable ``weight``, ``bias`` and ``mean_scale`` parameters.
    """

    def __init__(self, hidden_dim, norm_type):
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # The string tag itself marks graph-norm mode; parameters below
            # implement y = weight * (x - mean_scale*mean) / std + bias.
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))

            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError

    def forward(self, graph, x):
        # ``graph`` is only used in graph-norm mode; it is assumed to expose
        # ``batch_num_nodes`` (DGL-style: nodes per graph in the batch) —
        # TODO confirm against the caller's graph object.
        tensor = x
        # Module-backed norms (batch/layer) are applied directly.
        if self.norm is not None and type(self.norm) != str:
            return self.norm(tensor)
        elif self.norm is None:
            # Defensive no-op branch; as constructed, self.norm is never None.
            return tensor

        # ---- graph norm path ----
        batch_list = graph.batch_num_nodes
        batch_size = len(batch_list)
        batch_list = torch.Tensor(batch_list).long().to(tensor.device)
        # Node -> graph assignment vector, expanded to the feature shape so it
        # can index scatter_add_ below.
        batch_index = (
            torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
        )
        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(
            tensor
        )
        # Per-graph feature mean, broadcast back to one row per node.
        mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        mean = mean.scatter_add_(0, batch_index, tensor)
        mean = (mean.T / batch_list).T
        mean = mean.repeat_interleave(batch_list, dim=0)

        # Center with a learnable scale on the subtracted mean.
        sub = tensor - mean * self.mean_scale

        # Per-graph std of the centered values (eps=1e-6 for stability).
        std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        std = std.scatter_add_(0, batch_index, sub.pow(2))
        std = ((std.T / batch_list).T + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)
        return self.weight * sub / std + self.bias