File size: 25,360 Bytes
6e7d4ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 |
"""
Geometric Vector Perceptron (GVP) implementation adapted from:
https://github.com/drorlab/gvp-pytorch/blob/main/gvp/__init__.py
"""
import copy
import warnings
import torch, functools
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add, scatter_mean
def tuple_sum(*args):
    '''
    Elementwise sum of any number of (s, V) tuples.

    Position k of the result is the sum of position k across all inputs,
    e.g. ``tuple_sum((s1, V1), (s2, V2)) == (s1 + s2, V1 + V2)``.
    '''
    return tuple(sum(components) for components in zip(*args))
def tuple_cat(*args, dim=-1):
    '''
    Concatenates any number of (s, V) tuples elementwise.

    :param dim: concatenation axis, interpreted relative to the scalar
                tensors. It is first made non-negative w.r.t. the scalar
                rank and then reused as-is for the vector tensors, so the
                default `dim=-1` acts as `dim=-2` on the vector channels.
    :return: tuple (cat of scalars, cat of vectors)
    '''
    axis = dim % args[0][0].dim()
    scalars = [pair[0] for pair in args]
    vectors = [pair[1] for pair in args]
    return torch.cat(scalars, dim=axis), torch.cat(vectors, dim=axis)
def tuple_index(x, idx):
    '''
    Applies the same index to both members of an (s, V) tuple.

    :param x: tuple (s, V)
    :param idx: any object which can be used to index into a `torch.Tensor`
    :return: tuple (s[idx], V[idx])
    '''
    scalars, vectors = x[0], x[1]
    return scalars[idx], vectors[idx]
def randn(n, dims, device="cpu"):
    '''
    Draws a random (s, V) tuple with standard-normal entries.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)
    :param device: device on which to allocate the tensors
    :return: (s, V) with s.shape = (n, n_scalar) and
             V.shape = (n, n_vector, 3)
    '''
    # Scalars are sampled before vectors (fixes the RNG consumption order).
    scalars = torch.randn(n, dims[0], device=device)
    vectors = torch.randn(n, dims[1], 3, device=device)
    return scalars, vectors
def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
'''
L2 norm of tensor clamped above a minimum value `eps`.
:param sqrt: if `False`, returns the square of the L2 norm
'''
out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
return torch.sqrt(out) if sqrt else out
def _split(x, nv):
'''
Splits a merged representation of (s, V) back into a tuple.
Should be used only with `_merge(s, V)` and only if the tuple
representation cannot be used.
:param x: the `torch.Tensor` returned from `_merge`
:param nv: the number of vector channels in the input to `_merge`
'''
v = torch.reshape(x[..., -3 * nv:], x.shape[:-1] + (nv, 3))
s = x[..., :-3 * nv]
return s, v
def _merge(s, v):
'''
Merges a tuple (s, V) into a single `torch.Tensor`, where the
vector channels are flattened and appended to the scalar channels.
Should be used only if the tuple representation cannot be used.
Use `_split(x, nv)` to reverse.
'''
v = torch.reshape(v, v.shape[:-2] + (3 * v.shape[-2],))
return torch.cat([s, v], -1)
class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional
                  (defaults to max(n_vector_in, n_vector_out))
    :param activations: tuple of functions (scalar_act, vector_act);
                        either entry may be `None` to skip that activation
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, in_dims, out_dims, h_dim=None,
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        super(GVP, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if self.vector_gate:
                    self.wsv = nn.Linear(self.so, self.vo)
        else:
            self.ws = nn.Linear(self.si, self.so)
        self.scalar_act, self.vector_act = activations
        # Empty parameter; only used to look up the module's device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        '''
        if self.vi:
            s, v = x
            # nn.Linear acts on the last axis, so move the channel axis there.
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
            # Norm over the 3 spatial components of each hidden channel.
            vn = _norm_no_nan(vh, axis=-2)
            s = self.ws(torch.cat([s, vn], -1))
            if self.vo:
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
                if self.vector_gate:
                    # Gate computed from the scalar output before scalar_act.
                    if self.vector_act:
                        gate = self.wsv(self.vector_act(s))
                    else:
                        gate = self.wsv(s)
                    v = v * torch.sigmoid(gate).unsqueeze(-1)
                elif self.vector_act:
                    v = v * self.vector_act(
                        _norm_no_nan(v, axis=-1, keepdims=True))
        else:
            s = self.ws(x)
            if self.vo:
                # No vector input: emit zero vectors. Match s's leading dims
                # and dtype so half/double-precision models stay consistent.
                v = torch.zeros(*s.shape[:-1], self.vo, 3,
                                dtype=s.dtype,
                                device=self.dummy_param.device)
        if self.scalar_act:
            s = self.scalar_act(s)
        return (s, v) if self.vo else s
class _VDropout(nn.Module):
'''
Vector channel dropout where the elements of each
vector channel are dropped together.
'''
def __init__(self, drop_rate):
super(_VDropout, self).__init__()
self.drop_rate = drop_rate
self.dummy_param = nn.Parameter(torch.empty(0))
def forward(self, x):
'''
:param x: `torch.Tensor` corresponding to vector channels
'''
device = self.dummy_param.device
if not self.training:
return x
mask = torch.bernoulli(
(1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
).unsqueeze(-1)
x = mask * x / (1 - self.drop_rate)
return x
class Dropout(nn.Module):
    '''
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.

    A bare tensor is treated as scalar channels only and goes through the
    scalar dropout.

    :param drop_rate: drop probability shared by the scalar and vector dropout
    '''
    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        :return: same structure as `x`
        '''
        # isinstance (not `type(x) is`) so tensor subclasses such as
        # nn.Parameter are not mistaken for an (s, V) tuple and unpacked.
        if isinstance(x, torch.Tensor):
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v)
class LayerNorm(nn.Module):
    '''
    Combined LayerNorm for tuples (s, V).

    Scalars go through a standard `nn.LayerNorm`; vectors go through
    `VectorLayerNorm`. When the vector channel count is 0, the module takes
    and returns a bare scalar tensor instead of a tuple.

    :param dims: tuple (n_scalar, n_vector)
    :param learnable_vector_weight: whether `VectorLayerNorm` gets a
                                    learnable per-channel scale
    '''
    def __init__(self, dims, learnable_vector_weight=False):
        super(LayerNorm, self).__init__()
        n_scalar, n_vector = dims
        self.s, self.v = n_scalar, n_vector
        self.scalar_norm = nn.LayerNorm(n_scalar)
        if n_vector > 0:
            self.vector_norm = VectorLayerNorm(n_vector, learnable_vector_weight)
        else:
            self.vector_norm = None

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        :return: normalized features with the same structure as `x`
        '''
        if self.v:
            s, v = x
            return self.scalar_norm(s), self.vector_norm(v)
        # Scalar-only features arrive as a bare tensor.
        return self.scalar_norm(x)
class VectorLayerNorm(nn.Module):
    """
    Equivariant normalization of vector-valued features inspired by:
    Liao, Yi-Lun, and Tess Smidt.
    "Equiformer: Equivariant graph attention transformer for 3d atomistic graphs."
    arXiv preprint arXiv:2206.11990 (2022).
    Section 4.1, "Layer Normalization"

    :param n_channels: number of vector channels c
    :param learnable_weight: whether to learn a per-channel scale gamma
    """
    def __init__(self, n_channels, learnable_weight=True):
        super(VectorLayerNorm, self).__init__()
        if learnable_weight:
            self.gamma = nn.Parameter(torch.ones(1, n_channels, 1))  # (1, c, 1)
        else:
            self.gamma = None

    def forward(self, x):
        """
        Computes LN(x) = ( x / RMS( L2-norm(x) ) ) * gamma

        :param x: input tensor (n, c, 3)
        :return: layer normalized vector feature
        """
        sq_norms = _norm_no_nan(x, axis=-1, keepdims=True, sqrt=False)  # (n, c, 1)
        # Root-mean-square of the channel norms, shared by all channels.
        rms = torch.mean(sq_norms, dim=-2, keepdim=True).sqrt()  # (n, 1, 1)
        normalized = x / rms  # (n, c, 3)
        return normalized if self.gamma is None else normalized * self.gamma
class GVPConv(MessagePassing):
    '''
    Graph convolution / message passing with Geometric Vector Perceptrons.
    Takes in a graph with node and edge embeddings,
    and returns new node embeddings.

    This does NOT do residual updates and pointwise feedforward layers
    ---see `GVPConvLayer`.

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
    :param module_list: preconstructed message function, overrides n_layers
    :param aggr: should be "add" if some incoming edges are masked, as in
                 a masked autoregressive decoder architecture, otherwise "mean"
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    :param update_edge_attr: whether to compute an updated edge representation.
        With the default message function the update has `edge_dims`, so it
        can be added residually to the input edge features. If `edge_dims`
        equals `out_dims`, or a custom `module_list` is given, the edge
        function is a copy of the message function.
    '''
    def __init__(self, in_dims, out_dims, edge_dims,
                 n_layers=3, module_list=None, aggr="mean",
                 activations=(F.relu, torch.sigmoid), vector_gate=False,
                 update_edge_attr=False):
        super(GVPConv, self).__init__(aggr=aggr)
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims
        self.update_edge_attr = update_edge_attr

        GVP_ = functools.partial(GVP,
                                 activations=activations,
                                 vector_gate=vector_gate)
        # Both functions consume the concatenation (node, edge, node).
        cat_dims = (2 * self.si + self.se, 2 * self.vi + self.ve)

        def build_layers(final_dims):
            # `n_layers` GVPs mapping cat_dims -> final_dims; last one is linear.
            if n_layers == 1:
                return [GVP_(cat_dims, final_dims, activations=(None, None))]
            layers = [GVP_(cat_dims, final_dims)]
            layers += [GVP_(final_dims, final_dims) for _ in range(n_layers - 2)]
            layers.append(GVP_(final_dims, final_dims, activations=(None, None)))
            return layers

        use_default = not module_list
        if use_default:
            module_list = build_layers((self.so, self.vo))
        self.message_func = nn.Sequential(*module_list)

        if not self.update_edge_attr:
            self.edge_func = None
        elif not use_default or (self.se, self.ve) == (self.so, self.vo):
            # Output dims already match (or cannot be rebuilt from a custom
            # module list): copy the message function, initial weights included.
            self.edge_func = copy.deepcopy(self.message_func)
        else:
            # Node and edge dims differ: a copy of the message function would
            # output node dims and could not be added to the edge features.
            self.edge_func = nn.Sequential(*build_layers((self.se, self.ve)))

    def forward(self, x, edge_index, edge_attr):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :return: aggregated messages as a tuple (s, V); if `update_edge_attr`,
                 a pair (messages, updated edge tuple)
        '''
        x_s, x_v = x
        # propagate() handles per-node tensors, so the vector channels are
        # flattened to (n_nodes, 3 * n_vector) and unflattened in `message`.
        x_v_flat = x_v.reshape(x_v.shape[0], 3 * x_v.shape[1])
        message = self.propagate(edge_index,
                                 s=x_s,
                                 v=x_v_flat,
                                 edge_attr=edge_attr)
        if not self.update_edge_attr:
            return _split(message, self.vo)
        # NOTE(review): here `_i` is gathered with edge_index[0] and `_j` with
        # edge_index[1]; PyG's default flow uses the opposite convention in
        # `message`, so the edge function sees the node order reversed.
        # Kept as-is so trained checkpoints keep their meaning.
        s_i, s_j = x_s[edge_index[0]], x_s[edge_index[1]]
        v_i, v_j = x_v_flat[edge_index[0]], x_v_flat[edge_index[1]]
        edge_out = self.edge_attr(s_i, v_i, s_j, v_j, edge_attr)
        return _split(message, self.vo), edge_out

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        '''
        Builds one message per edge from (s_j, V_j), the edge features and
        (s_i, V_i), and returns it merged into a single tensor for aggregation.
        '''
        v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message)
        return _merge(*message)

    def edge_attr(self, s_i, v_i, s_j, v_j, edge_attr):
        '''
        Computes the updated edge representation (as a tuple (s, V)) from
        flattened endpoint features and the current edge features.
        '''
        v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        return self.edge_func(message)
class GVPConvLayer(nn.Module):
    '''
    Full graph convolution / message passing layer with
    Geometric Vector Perceptrons. Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward
    network to node embeddings, and returns updated node embeddings.

    To only compute the aggregated messages, see `GVPConv`.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
           with a different set of input node embeddings for messages
           where src >= dst
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    :param update_edge_attr: whether to compute an updated edge representation
    :param ln_vector_weight: whether to include a learnable weight in the vector
                             layer norm
    '''
    def __init__(self, node_dims, edge_dims,
                 n_message=3, n_feedforward=2, drop_rate=.1,
                 autoregressive=False,
                 activations=(F.relu, torch.sigmoid), vector_gate=False,
                 update_edge_attr=False, ln_vector_weight=False):
        super(GVPConvLayer, self).__init__()
        assert not (update_edge_attr and autoregressive), "Not implemented"
        self.update_edge_attr = update_edge_attr
        self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
                            aggr="add" if autoregressive else "mean",
                            activations=activations, vector_gate=vector_gate,
                            update_edge_attr=update_edge_attr)
        GVP_ = functools.partial(GVP,
                                 activations=activations,
                                 vector_gate=vector_gate)
        self.norm = nn.ModuleList([LayerNorm(node_dims, ln_vector_weight)
                                   for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        def get_feedforward(n_dims):
            # Pointwise GVP stack n_dims -> hidden -> n_dims; last GVP is linear.
            ff_func = []
            if n_feedforward == 1:
                ff_func.append(GVP_(n_dims, n_dims, activations=(None, None)))
            else:
                hid_dims = 4 * n_dims[0], 2 * n_dims[1]
                ff_func.append(GVP_(n_dims, hid_dims))
                for _ in range(n_feedforward - 2):
                    ff_func.append(GVP_(hid_dims, hid_dims))
                ff_func.append(GVP_(hid_dims, n_dims, activations=(None, None)))
            return nn.Sequential(*ff_func)

        self.ff_func = get_feedforward(node_dims)
        if self.update_edge_attr:
            self.edge_norm = nn.ModuleList([LayerNorm(edge_dims, ln_vector_weight)
                                            for _ in range(2)])
            self.edge_dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
            self.edge_ff = get_feedforward(edge_dims)

    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
                If not `None`, will be used as src node embeddings
                for forming messages where src >= dst. The current node
                embeddings `x` will still be the base of the update and the
                pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
                dim of node embeddings (s, V). If not `None`, only
                these nodes will be updated; the input tensors in `x` are
                not modified (updated copies are returned).
        :return: updated node tuple (s, V); if `update_edge_attr`, a pair
                 (node tuple, edge tuple)
        '''
        if autoregressive_x is not None:
            src, dst = edge_index
            mask = src < dst
            edge_index_forward = edge_index[:, mask]
            edge_index_backward = edge_index[:, ~mask]
            edge_attr_forward = tuple_index(edge_attr, mask)
            edge_attr_backward = tuple_index(edge_attr, ~mask)

            dh = tuple_sum(
                self.conv(x, edge_index_forward, edge_attr_forward),
                self.conv(autoregressive_x, edge_index_backward,
                          edge_attr_backward)
            )
            # conv uses aggr="add" here; divide by the in-degree (at least 1)
            # to turn the summed messages into a mean.
            count = scatter_add(torch.ones_like(dst), dst,
                                dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
        else:
            dh = self.conv(x, edge_index, edge_attr)

        if self.update_edge_attr:
            dh, de = dh
            # Edge residual updates use their own dropout modules.
            edge_attr = self.edge_norm[0](
                tuple_sum(edge_attr, self.edge_dropout[0](de)))
            de = self.edge_ff(edge_attr)
            edge_attr = self.edge_norm[1](
                tuple_sum(edge_attr, self.edge_dropout[1](de)))

        if node_mask is not None:
            x_full = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

        if node_mask is not None:
            # Scatter the updated rows into copies: writing into the caller's
            # tensors in place would break autograd and mutate the input.
            s_full, v_full = x_full[0].clone(), x_full[1].clone()
            s_full[node_mask], v_full[node_mask] = x[0], x[1]
            x = s_full, v_full

        return (x, edge_attr) if self.update_edge_attr else x
################################################################################
def _normalize(tensor, dim=-1, eps=1e-8):
'''
Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
'''
return torch.nan_to_num(
torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True) + eps))
def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
'''
From https://github.com/jingraham/neurips19-graph-protein-design
Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
That is, if `D` has shape [...dims], then the returned tensor will have
shape [...dims, D_count].
'''
D_mu = torch.linspace(D_min, D_max, D_count, device=device)
D_mu = D_mu.view([1, -1])
D_sigma = (D_max - D_min) / D_count
D_expand = torch.unsqueeze(D, -1)
RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
return RBF
class GVPModel(torch.nn.Module):
    """
    GVP-GNN model
    inspired by: https://github.com/drorlab/gvp-pytorch/blob/main/gvp/models.py
    and: https://github.com/drorlab/gvp-pytorch/blob/82af6b22eaf8311c15733117b0071408d24ed877/gvp/atom3d.py#L115

    :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors)
    :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
    :param node_out_nf: number of scalar node channels in the output graph
                        (the output always has one vector channel)
    :param edge_in_nf: number of extra scalar edge features passed as `edge_attr`
    :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers,
                       tuple (s, V)
    :param edge_out_nf: number of scalar edge channels in the output graph
                        (only used if `update_edge_attr`)
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    :param vector_gate: use vector gates in all GVPs
    :param reflection_equiv: bool, use reflection-sensitive feature based on the
                             cross product if False
    :param d_max: upper end of the distance range covered by the radial basis
                  functions (centres are spread evenly over [0, d_max])
    :param num_rbf: number of radial basis functions used to embed edge lengths
    :param update_edge_attr: bool, update edge attributes at each layer in a
                             learnable way
    """
    def __init__(self, node_in_dim, node_h_dim, node_out_nf,
                 edge_in_nf, edge_h_dim, edge_out_nf,
                 num_layers=3, drop_rate=0.1, vector_gate=False,
                 reflection_equiv=True, d_max=20.0, num_rbf=16,
                 update_edge_attr=False):
        super(GVPModel, self).__init__()
        self.reflection_equiv = reflection_equiv
        self.update_edge_attr = update_edge_attr
        self.d_max = d_max
        self.num_rbf = num_rbf

        if not isinstance(node_in_dim, tuple):
            node_in_dim = (node_in_dim, 0)

        # Edge scalars: both endpoint features, RBF-embedded distance and any
        # extra edge features; edge vectors: the unit displacement (+ the
        # cross-product feature when reflections must be distinguished).
        edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
        if not self.reflection_equiv:
            edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)

        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim, learnable_vector_weight=True),
            GVP(node_in_dim, node_h_dim, activations=(None, None),
                vector_gate=vector_gate),
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim, learnable_vector_weight=True),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None),
                vector_gate=vector_gate),
        )

        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate,
                         update_edge_attr=self.update_edge_attr,
                         activations=(F.relu, None), vector_gate=vector_gate,
                         ln_vector_weight=True)
            for _ in range(num_layers))

        self.W_v_out = nn.Sequential(
            LayerNorm(node_h_dim, learnable_vector_weight=True),
            GVP(node_h_dim, (node_out_nf, 1), activations=(None, None),
                vector_gate=vector_gate),
        )
        self.W_e_out = nn.Sequential(
            LayerNorm(edge_h_dim, learnable_vector_weight=True),
            GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None),
                vector_gate=vector_gate)
        ) if self.update_edge_attr else None

    def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None):
        """
        :param h: scalar node features, indexed by node along dim 0
        :param x: node coordinates, (n_nodes, 3)
        :param edge_index: (row, col) node indices of each edge
        :param batch_mask: graph index of each node; only needed when
                           `reflection_equiv` is False. If `None`, all nodes
                           are treated as a single graph.
        :param edge_attr: optional extra scalar edge features, appended to
                          the scalar edge features along dim 1
        :return: scalar and vector-valued edge features
        """
        row, col = edge_index
        coord_diff = x[row] - x[col]
        dist = coord_diff.norm(dim=-1)
        rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf,
                   device=x.device)

        edge_s = torch.cat([h[row], h[col], rbf], dim=1)
        edge_v = _normalize(coord_diff).unsqueeze(-2)

        if edge_attr is not None:
            edge_s = torch.cat([edge_s, edge_attr], dim=1)

        if not self.reflection_equiv:
            if batch_mask is None:
                batch_mask = torch.zeros(x.size(0), dtype=torch.long,
                                         device=x.device)
            # Cross product of positions relative to each graph's centroid.
            mean = scatter_mean(x, batch_mask, dim=0,
                                dim_size=int(batch_mask.max()) + 1)
            cross = torch.cross(x[row] - mean[batch_mask[row]],
                                x[col] - mean[batch_mask[col]], dim=1)
            cross = _normalize(cross).unsqueeze(-2)
            edge_v = torch.cat([edge_v, cross], dim=-2)

        return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)

    def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
        """
        :param h: scalar node features
        :param x: node coordinates, (n_nodes, 3)
        :param edge_index: (row, col) node indices of each edge
        :param v: optional vector node features; if given, the node input is
                  the tuple (h, v)
        :param batch_mask: graph index of each node (see `edge_features`)
        :param edge_attr: optional extra scalar edge features
        :return: (scalar node output, vector node output with the single
                 vector channel squeezed, edge output). The edge output is
                 `W_e_out(...)` if `update_edge_attr`, otherwise the
                 `edge_attr` argument returned unchanged.
        """
        h_v = h if v is None else (h, v)
        h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)

        h_v = self.W_v(h_v)
        h_e = self.W_e(h_e)

        for layer in self.layers:
            h_v = layer(h_v, edge_index, edge_attr=h_e)
            if self.update_edge_attr:
                h_v, h_e = h_v

        h, vel = self.W_v_out(h_v)

        if self.update_edge_attr:
            edge_attr = self.W_e_out(h_e)

        return h, vel.squeeze(-2), edge_attr
|