File size: 36,756 Bytes
f9073ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 |
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import softmax
import math
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor
import numpy as np
def coord2dist(x, edge_index, sqrt=False, pos_unmask=None):
    """Return the (squared) distance between the two endpoints of every edge.

    Multi-point node coordinates are first reduced to a single point per node:

    * rank-3 ``x`` (``[N, A, 3]``) with ``pos_unmask``: weighted mean over the
      ``A`` axis using ``pos_unmask`` as per-point weights; the divisor is
      clamped to 1 so nodes whose points are all masked do not divide by 0.
      ``pos_unmask`` is assumed to be ``[N, A]`` — TODO confirm with callers.
    * rank-3 ``x`` without ``pos_unmask``: plain mean over the ``A`` axis
      (previously left unreduced, producing ``[E, 1, 3]`` instead of ``[E, 1]``).
    * rank-2 ``x`` with 9 columns: viewed as three 3-D points and averaged.
    * any other rank-2 ``x``: used as-is.

    Args:
        x: node coordinates, see above.
        edge_index: ``[2, E]`` tensor of (row, col) node indices.
        sqrt: if True return the Euclidean distance, otherwise its square.
        pos_unmask: optional per-point weights, only used for rank-3 ``x``.

    Returns:
        ``[E, 1]`` tensor detached from the autograd graph.
    """
    if x.dim() == 3:
        if pos_unmask is not None:
            weighted = x * pos_unmask.unsqueeze(-1)
            x = weighted.sum(dim=1) / pos_unmask.sum(dim=1, keepdim=True).clamp(min=1)
        else:
            x = x.mean(dim=1)
    elif x.dim() == 2 and x.shape[1] == 9:
        x = x.view(-1, 3, 3).mean(dim=1)

    row, col = edge_index
    coord_diff = x[row] - x[col]
    radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)
    if sqrt:
        radial = radial.sqrt()
    # Distances are used as fixed input features, not optimised through.
    return radial.detach()
def modulate(x, shift, scale):
    """Affinely modulate ``x``: scale it by ``1 + scale`` and offset by ``shift``.

    No unsqueezing is performed, so ``shift`` and ``scale`` must already be
    broadcast-compatible with ``x``.
    """
    gain = 1 + scale
    return shift + x * gain
class TransLayer(MessagePassing):
    """Edge-aware multi-head attention message passing (no FFN, no norm).

    For an edge ``j -> i`` with features ``e_ij`` and per-head channel size C:

    * logit:   ``sum_c(q_i * k_j * tanh(W_e0 e_ij)) / sqrt(C)``
    * weight:  softmax of the logits over the incoming edges of ``i`` (per
      head), followed by dropout
    * message: ``v_j * tanh(W_e1 e_ij) * weight``

    Messages are aggregated with ``aggr`` (``'add'`` unless overridden via
    ``kwargs``) and the concatenated heads pass through a linear projection.

    Args:
        x_channels: input node feature size.
        out_channels: channels per head (C).
        heads: number of attention heads (H).
        dropout: dropout probability on the attention weights.
        edge_dim: edge feature size; must be given, it sizes the edge projections.
        bias: whether the q/k/v and output projections have a bias.
        **kwargs: forwarded to ``MessagePassing``.
    """

    _alpha: OptTensor

    def __init__(self, x_channels: int, out_channels: int,
                 heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
                 bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(TransLayer, self).__init__(node_dim=0, **kwargs)

        self.x_channels = x_channels
        self.in_channels = in_channels = x_channels
        self.out_channels = out_channels
        self.heads = heads
        self.dropout = dropout
        self.edge_dim = edge_dim

        self.lin_key = nn.Linear(in_channels, heads * out_channels, bias=bias)
        self.lin_query = nn.Linear(in_channels, heads * out_channels, bias=bias)
        self.lin_value = nn.Linear(in_channels, heads * out_channels, bias=bias)
        # Edge projections: lin_edge0 gates the attention logits,
        # lin_edge1 gates the values.
        self.lin_edge0 = nn.Linear(edge_dim, heads * out_channels, bias=False)
        self.lin_edge1 = nn.Linear(edge_dim, heads * out_channels, bias=False)
        self.proj = nn.Linear(heads * out_channels, heads * out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear projection of the layer."""
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        self.lin_edge0.reset_parameters()
        self.lin_edge1.reset_parameters()
        self.proj.reset_parameters()

    def forward(self, x: OptTensor,
                edge_index: Adj,
                edge_attr: OptTensor = None
                ) -> Tensor:
        """Run one attention round; returns ``[num_nodes, heads * out_channels]``."""
        H, C = self.heads, self.out_channels

        x_feat = x
        query = self.lin_query(x_feat).view(-1, H, C)
        key = self.lin_key(x_feat).view(-1, H, C)
        value = self.lin_value(x_feat).view(-1, H, C)

        # propagate_type: (query: Tensor, key: Tensor, value: Tensor, edge_attr: OptTensor)
        out_x = self.propagate(edge_index, query=query, key=key, value=value, edge_attr=edge_attr, size=None)

        out_x = out_x.view(-1, self.heads * self.out_channels)
        out_x = self.proj(out_x)
        return out_x

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message of every edge, ``[E, H, C]``."""
        H, C = self.heads, self.out_channels

        edge_attn = torch.tanh(self.lin_edge0(edge_attr).view(-1, H, C))
        alpha = (query_i * key_j * edge_attn).sum(dim=-1) / math.sqrt(C)
        # Normalise over the incoming edges of each target node.
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        msg = value_j * torch.tanh(self.lin_edge1(edge_attr).view(-1, H, C))
        msg = msg * alpha.view(-1, H, 1)
        return msg

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
class TransLayerOptim(MessagePassing):
    """Fused-projection variant of ``TransLayer`` (no FFN, no norm).

    Query/key/value come from a single ``lin_qkv`` projection, and the two
    edge gates come from a single ``lin_edge`` projection. For an edge
    ``j -> i`` the logit is ``sum_c(q_i * k_j * tanh(e_key)) / sqrt(C)``.
    The message is ``v_j * tanh(e_value)`` weighted by the per-head softmax
    over the incoming edges of ``i``.

    Args:
        x_channels: input node feature size.
        out_channels: channels per head (C).
        heads: number of attention heads (H).
        dropout: dropout probability on the attention weights.
        edge_dim: edge feature size; must be given.
        bias: whether the qkv and output projections have a bias.
        **kwargs: forwarded to ``MessagePassing``.
    """

    _alpha: OptTensor

    def __init__(self, x_channels: int, out_channels: int,
                 heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
                 bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(TransLayerOptim, self).__init__(node_dim=0, **kwargs)

        self.x_channels = x_channels
        self.in_channels = in_channels = x_channels
        self.out_channels = out_channels
        self.heads = heads
        self.dropout = dropout
        self.edge_dim = edge_dim

        self.lin_qkv = nn.Linear(in_channels, heads * out_channels * 3, bias=bias)
        self.lin_edge = nn.Linear(edge_dim, heads * out_channels * 2, bias=False)
        self.proj = nn.Linear(heads * out_channels, heads * out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear projection of the layer."""
        self.lin_qkv.reset_parameters()
        self.lin_edge.reset_parameters()
        self.proj.reset_parameters()

    def forward(self, x: OptTensor,
                edge_index: Adj,
                edge_attr: OptTensor = None
                ) -> Tensor:
        """Run one attention round; returns ``[num_nodes, heads * out_channels]``."""
        H, C = self.heads, self.out_channels

        x_feat = x
        # Layout is (head, {q, k, v}, channel).
        qkv = self.lin_qkv(x_feat).view(-1, H, 3, C)
        query, key, value = qkv.unbind(dim=2)

        # propagate_type: (query: Tensor, key: Tensor, value: Tensor, edge_attr: OptTensor)
        out_x = self.propagate(edge_index, query=query, key=key, value=value, edge_attr=edge_attr, size=None)

        out_x = out_x.view(-1, self.heads * self.out_channels)
        out_x = self.proj(out_x)
        return out_x

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message of every edge, ``[E, H, C]``."""
        H, C = self.heads, self.out_channels

        edge_gates = torch.tanh(self.lin_edge(edge_attr)).view(-1, H, 2, C)
        edge_key, edge_value = edge_gates.unbind(dim=2)

        alpha = (query_i * key_j * edge_key).sum(dim=-1) / math.sqrt(C)
        # Normalise over the incoming edges of each target node.
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        msg = value_j * edge_value * alpha.view(-1, H, 1)
        return msg

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
class TransLayerOptimV2(MessagePassing):
    """Attention layer whose keys/values are built per edge (no FFN, no norm).

    For an edge ``j -> i`` the source features ``x_j`` are concatenated with
    ``e_ij`` and passed through ``edge_mlp`` (Linear + GELU). ``lin_kv`` then
    yields the edge key and value. The logit is
    ``sum_c(q_i * k_ij) / sqrt(C)`` with the usual per-head softmax over the
    incoming edges of ``i``.

    Args:
        x_channels: input node feature size.
        out_channels: channels per head (C).
        heads: number of attention heads (H).
        dropout: dropout probability on the attention weights.
        edge_dim: edge feature size; must be given.
        bias: whether the projections have a bias.
        **kwargs: forwarded to ``MessagePassing``.
    """

    _alpha: OptTensor

    def __init__(self, x_channels: int, out_channels: int,
                 heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
                 bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(TransLayerOptimV2, self).__init__(node_dim=0, **kwargs)

        self.x_channels = x_channels
        self.in_channels = in_channels = x_channels
        self.out_channels = out_channels
        self.heads = heads
        self.dropout = dropout
        self.edge_dim = edge_dim

        self.lin_q = nn.Linear(in_channels, heads * out_channels, bias=bias)
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_channels + edge_dim, in_channels, bias=bias),
            nn.GELU(),
        )
        self.lin_kv = nn.Linear(in_channels, heads * out_channels * 2, bias=bias)
        self.proj = nn.Linear(heads * out_channels, heads * out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all learnable projections, including ``edge_mlp``."""
        self.lin_q.reset_parameters()
        self.lin_kv.reset_parameters()
        self.proj.reset_parameters()
        # nn.Sequential has no reset_parameters(); reset its children instead.
        # Done last so lin_q/lin_kv/proj draw the same random numbers as before.
        for module in self.edge_mlp:
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

    def forward(self, x: OptTensor,
                edge_index: Adj,
                edge_attr: OptTensor = None
                ) -> Tensor:
        """Run one attention round; returns ``[num_nodes, heads * out_channels]``."""
        H, C = self.heads, self.out_channels

        x_feat = x
        query = self.lin_q(x_feat).view(-1, H, C)

        # propagate_type: (query: Tensor, x_feat: Tensor, edge_attr: OptTensor)
        out_x = self.propagate(edge_index, query=query, x_feat=x_feat, edge_attr=edge_attr)

        out_x = out_x.view(-1, self.heads * self.out_channels)
        out_x = self.proj(out_x)
        return out_x

    def message(self, query_i: Tensor, x_feat_j: Tensor,
                edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message of every edge, ``[E, H, C]``."""
        H, C = self.heads, self.out_channels

        edge_feat_ij = self.edge_mlp(torch.cat([x_feat_j, edge_attr], dim=-1))  # [E, in_channels]
        edge_key_ij, edge_value_ij = self.lin_kv(edge_feat_ij).view(-1, H, 2, C).unbind(dim=2)  # each [E, H, C]

        alpha_ij = (query_i * edge_key_ij).sum(dim=-1) / math.sqrt(C)  # [E, H]
        # Normalise over the incoming edges of each target node.
        alpha_ij = softmax(alpha_ij, index, ptr, size_i)
        alpha_ij = F.dropout(alpha_ij, p=self.dropout, training=self.training)

        msg = edge_value_ij * alpha_ij.view(-1, H, 1)  # [E, H, C]
        return msg

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
class TransLayerOptimV3(MessagePassing):
    """Attention layer with edge-conditioned queries, keys and values (no FFN, no norm).

    For an edge ``j -> i``:

    * ``q_ij = W_q [x_i, e_ij]``
    * ``(k_ij, v_ij) = W_kv [x_j, e_ij]``
    * the logit is ``sum_c(q_ij * k_ij) / sqrt(C)``

    Logits are softmax-normalised per head over the incoming edges of ``i``,
    and the weighted values are aggregated and projected.

    Args:
        x_channels: input node feature size.
        out_channels: channels per head (C).
        heads: number of attention heads (H).
        dropout: dropout probability on the attention weights.
        edge_dim: edge feature size; must be given.
        bias: whether the projections have a bias.
        **kwargs: forwarded to ``MessagePassing``.
    """

    _alpha: OptTensor

    def __init__(self, x_channels: int, out_channels: int,
                 heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
                 bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(TransLayerOptimV3, self).__init__(node_dim=0, **kwargs)

        self.x_channels = x_channels
        self.in_channels = in_channels = x_channels
        self.out_channels = out_channels
        self.heads = heads
        self.dropout = dropout
        self.edge_dim = edge_dim

        self.lin_q = nn.Linear(in_channels + edge_dim, heads * out_channels, bias=bias)
        self.lin_kv = nn.Linear(in_channels + edge_dim, heads * out_channels * 2, bias=bias)
        self.proj = nn.Linear(heads * out_channels, heads * out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear projection of the layer."""
        self.lin_q.reset_parameters()
        self.lin_kv.reset_parameters()
        self.proj.reset_parameters()

    def forward(self, x: OptTensor,
                edge_index: Adj,
                edge_attr: OptTensor = None,
                edge_mask: OptTensor = None
                ) -> Tensor:
        """Run one attention round; returns ``[num_nodes, heads * out_channels]``.

        ``edge_mask`` is accepted for signature compatibility with
        ``TransLayerOptimV3Mask`` but is ignored by this layer.
        """
        x_feat = x

        # propagate_type: (x_feat: Tensor, edge_attr: OptTensor)
        out_x = self.propagate(edge_index, x_feat=x_feat, edge_attr=edge_attr)

        out_x = out_x.view(-1, self.heads * self.out_channels)
        out_x = self.proj(out_x)
        return out_x

    def message(self, x_feat_i: Tensor, x_feat_j: Tensor,
                edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message of every edge, ``[E, H, C]``."""
        H, C = self.heads, self.out_channels

        query_ij = self.lin_q(torch.cat([x_feat_i, edge_attr], dim=-1)).view(-1, H, C)
        edge_key_ij, edge_value_ij = self.lin_kv(
            torch.cat([x_feat_j, edge_attr], dim=-1)).view(-1, H, 2, C).unbind(dim=2)  # each [E, H, C]

        alpha_ij = (query_ij * edge_key_ij).sum(dim=-1) / math.sqrt(C)  # [E, H]
        # Normalise over the incoming edges of each target node.
        alpha_ij = softmax(alpha_ij, index, ptr, size_i)
        alpha_ij = F.dropout(alpha_ij, p=self.dropout, training=self.training)

        msg = edge_value_ij * alpha_ij.view(-1, H, 1)  # [E, H, C]
        return msg

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
class TransLayerOptimV3Mask(MessagePassing):
    """``TransLayerOptimV3`` with optional per-edge attention masking (no FFN, no norm).

    Edges whose ``edge_mask`` entry is non-zero get the most negative finite
    logit, so after the per-node softmax they receive (numerically) zero
    weight. The exception is a node all of whose incoming edges are masked:
    it attends uniformly among them. ``edge_mask=None`` disables masking.

    Args:
        x_channels: input node feature size.
        out_channels: channels per head (C).
        heads: number of attention heads (H).
        dropout: dropout probability on the attention weights.
        edge_dim: edge feature size; must be given.
        bias: whether the projections have a bias.
        **kwargs: forwarded to ``MessagePassing``.
    """

    _alpha: OptTensor

    def __init__(self, x_channels: int, out_channels: int,
                 heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
                 bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(TransLayerOptimV3Mask, self).__init__(node_dim=0, **kwargs)

        self.x_channels = x_channels
        self.in_channels = in_channels = x_channels
        self.out_channels = out_channels
        self.heads = heads
        self.dropout = dropout
        self.edge_dim = edge_dim

        self.lin_q = nn.Linear(in_channels + edge_dim, heads * out_channels, bias=bias)
        self.lin_kv = nn.Linear(in_channels + edge_dim, heads * out_channels * 2, bias=bias)
        self.proj = nn.Linear(heads * out_channels, heads * out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear projection of the layer."""
        self.lin_q.reset_parameters()
        self.lin_kv.reset_parameters()
        self.proj.reset_parameters()

    def forward(self, x: OptTensor,
                edge_index: Adj,
                edge_attr: OptTensor = None,
                edge_mask: OptTensor = None
                ) -> Tensor:
        """Run one attention round; returns ``[num_nodes, heads * out_channels]``.

        ``edge_mask`` is presumably one boolean/0-1 entry per edge (it is
        reshaped to ``[E, 1]``); non-zero means "masked".
        """
        x_feat = x

        # propagate_type: (x_feat: Tensor, edge_attr: OptTensor, edge_mask: OptTensor)
        out_x = self.propagate(edge_index, x_feat=x_feat, edge_attr=edge_attr, edge_mask=edge_mask)

        out_x = out_x.view(-1, self.heads * self.out_channels)
        out_x = self.proj(out_x)
        return out_x

    def message(self, x_feat_i: Tensor, x_feat_j: Tensor,
                edge_attr: OptTensor, edge_mask: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the masked, attention-weighted message of every edge, ``[E, H, C]``."""
        H, C = self.heads, self.out_channels

        query_ij = self.lin_q(torch.cat([x_feat_i, edge_attr], dim=-1)).view(-1, H, C)
        edge_key_ij, edge_value_ij = self.lin_kv(
            torch.cat([x_feat_j, edge_attr], dim=-1)).view(-1, H, 2, C).unbind(dim=2)  # each [E, H, C]

        alpha_ij = (query_ij * edge_key_ij).sum(dim=-1) / math.sqrt(C)  # [E, H]
        if edge_mask is not None:
            # Overwrite (rather than add to) masked logits: adding finfo.min to a
            # negative logit can overflow to -inf (easily in fp16), and a node
            # whose edges are all -inf gets NaN from the softmax.
            min_dtype = torch.finfo(alpha_ij.dtype).min
            alpha_ij = alpha_ij.masked_fill(edge_mask.view(-1, 1) != 0, min_dtype)
        # Normalise over the incoming edges of each target node.
        alpha_ij = softmax(alpha_ij, index, ptr, size_i)
        alpha_ij = F.dropout(alpha_ij, p=self.dropout, training=self.training)

        msg = edge_value_ij * alpha_ij.view(-1, H, 1)  # [E, H, C]
        return msg

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
class TransLayerOptimV4(MessagePassing):
    """Attention layer with additive edge biases on q, k and v (no FFN, no norm).

    For an edge ``j -> i``, a bias-free projection of ``e_ij`` produces
    ``(eq, ek, ev)``. Attention then uses ``q_i + eq``, ``k_j + ek`` and
    ``v_j + ev`` with the usual scaled dot product and a per-head softmax over
    the incoming edges of ``i``.

    Args:
        x_channels: input node feature size.
        out_channels: channels per head (C).
        heads: number of attention heads (H).
        dropout: dropout probability on the attention weights.
        edge_dim: edge feature size; must be given.
        bias: whether the node qkv and output projections have a bias.
        **kwargs: forwarded to ``MessagePassing``.
    """

    _alpha: OptTensor

    def __init__(self, x_channels: int, out_channels: int,
                 heads: int = 1, dropout: float = 0., edge_dim: Optional[int] = None,
                 bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(TransLayerOptimV4, self).__init__(node_dim=0, **kwargs)

        self.x_channels = x_channels
        self.in_channels = in_channels = x_channels
        self.out_channels = out_channels
        self.heads = heads
        self.dropout = dropout
        self.edge_dim = edge_dim

        self.lin_qkv = nn.Linear(in_channels, heads * out_channels * 3, bias=bias)
        self.lin_qkv_e = nn.Linear(edge_dim, heads * out_channels * 3, bias=False)
        self.proj = nn.Linear(heads * out_channels, heads * out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every linear projection of the layer."""
        self.lin_qkv.reset_parameters()
        self.lin_qkv_e.reset_parameters()
        self.proj.reset_parameters()

    def forward(self, x: OptTensor,
                edge_index: Adj,
                edge_attr: OptTensor = None
                ) -> Tensor:
        """Run one attention round; returns ``[num_nodes, heads * out_channels]``."""
        x_feat = x
        query, key, value = self.lin_qkv(x_feat).view(-1, self.heads, 3, self.out_channels).unbind(dim=2)

        # propagate_type: (query: Tensor, key: Tensor, value: Tensor, edge_attr: OptTensor)
        out_x = self.propagate(edge_index, query=query, key=key, value=value, edge_attr=edge_attr)

        out_x = out_x.view(-1, self.heads * self.out_channels)
        out_x = self.proj(out_x)
        return out_x

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message of every edge, ``[E, H, C]``."""
        H, C = self.heads, self.out_channels

        edge_query_ij, edge_key_ij, edge_value_ij = self.lin_qkv_e(edge_attr).view(-1, H, 3, C).unbind(dim=2)
        query_ij = query_i + edge_query_ij
        key_ij = key_j + edge_key_ij
        value_ij = value_j + edge_value_ij

        alpha_ij = (query_ij * key_ij).sum(dim=-1) / math.sqrt(C)  # [E, H]
        # Normalise over the incoming edges of each target node.
        alpha_ij = softmax(alpha_ij, index, ptr, size_i)
        alpha_ij = F.dropout(alpha_ij, p=self.dropout, training=self.training)

        msg = value_ij * alpha_ij.view(-1, H, 1)  # [E, H, C]
        return msg

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
@torch.jit.script
def gaussian(x, mean, std):
    """Evaluate the normal pdf N(mean, std**2) at ``x`` elementwise (broadcasting)."""
    # Use the exact constant; the previous literal 3.14159 was truncated.
    a = (2 * math.pi) ** 0.5
    return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
class GaussianLayer(nn.Module):
    """Gaussian basis expansion for scalar (distance) features.

    ``forward`` concatenates the input ``x`` with ``K - 1`` learnable Gaussian
    basis responses of ``x`` along the last dimension. An input whose last
    dimension is 1 therefore yields ``K`` output features.

    Args:
        K: total output width (input column plus ``K - 1`` basis functions).
        dist_mask_type: how rows selected by ``x_mask`` are treated:

            * ``'none'``: no masking. This is the default; ``False``/``None``
              are accepted as aliases. The previous default ``False``
              always raised.
            * ``'replace'``: the row is overwritten by a learnable token.
            * ``'add'``: one of two learnable tokens, chosen by the 0/1 mask
              value, is added.

    Raises:
        ValueError: for any other ``dist_mask_type``.
    """

    def __init__(self, K, dist_mask_type='none', *args, **kwargs):
        super().__init__()
        if dist_mask_type is False or dist_mask_type is None:
            dist_mask_type = 'none'
        self.K = K - 1
        self.means = nn.Embedding(1, self.K)
        self.stds = nn.Embedding(1, self.K)
        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)

        self.dist_mask_type = dist_mask_type
        if self.dist_mask_type == 'replace':
            self.mask_token = nn.Parameter(torch.zeros(1, K))
        elif self.dist_mask_type == 'add':
            self.mask_token = nn.Parameter(torch.zeros(2, K))
            nn.init.xavier_normal_(self.mask_token)
        elif self.dist_mask_type != 'none':
            raise ValueError(f'Unknown dist_mask_type {dist_mask_type}')

    def forward(self, x, x_mask=None, *args, **kwargs):
        """Expand ``x`` into basis features and apply the configured mask token.

        Args:
            x: distance features; last dimension presumably 1 — TODO confirm.
            x_mask: row selector. It is a boolean mask for ``'replace'`` and a
                0/1 index per row for ``'add'``. ``None`` means no row is
                masked.
        """
        # Basis parameters are evaluated in float32; the result is cast back
        # to the parameter dtype.
        mean = self.means.weight.float().view(-1)
        std = self.stds.weight.float().view(-1).abs() + 1e-5  # keep std strictly positive
        out = torch.cat([x, gaussian(x, mean, std).type_as(self.means.weight)], dim=-1)

        if self.dist_mask_type == 'none' or x_mask is None:
            return out
        if self.dist_mask_type == 'replace':
            # index_put requires matching dtypes (e.g. under autocast).
            out[x_mask] = self.mask_token.to(out.dtype)
        elif self.dist_mask_type == 'add':
            out = out + self.mask_token[x_mask.long()]
        else:
            raise ValueError(f'Unknown dist_mask_type {self.dist_mask_type}')
        return out
class DMTBlock(nn.Module):
    """Pre-norm graph transformer block with a node stream and an edge stream.

    Node stream: LayerNorm -> edge-aware attention message passing ->
    residual -> LayerNorm -> FFN -> residual.

    Edge stream, when ``pair_update=True``: the attention outputs of both
    endpoints are concatenated with the incoming edge features, projected and
    added residually, then LayerNorm -> FFN -> residual.

    When ``pair_update=False``, edge features are only re-embedded
    (Linear-GELU-Linear-LayerNorm) for use inside attention, and are returned
    unchanged.

    Args:
        node_dim: node hidden size; must be divisible by ``num_heads``.
        edge_dim: edge hidden size.
        num_heads: number of attention heads.
        mlp_ratio: FFN expansion factor.
        act: activation class (instantiated once) used inside the FFNs.
        dropout: dropout for attention weights and FFNs.
        pair_update: whether to update edge features.
        trans_ver: attention variant: ``'v2'``, ``'v3'`` or ``'v4'``. Any
            other value selects ``TransLayerOptim``.

    Raises:
        ValueError: if ``node_dim`` is not divisible by ``num_heads``.
            Otherwise the attention output width ``heads * (node_dim // heads)``
            would not match the residual.
    """

    def __init__(self, node_dim, edge_dim, num_heads,
                 mlp_ratio=4, act=nn.GELU, dropout=0.0, pair_update=True, trans_ver='v3'):
        super().__init__()
        if node_dim % num_heads != 0:
            raise ValueError(
                f'node_dim ({node_dim}) must be divisible by num_heads ({num_heads})')
        self.dropout = dropout
        self.act = act()
        self.pair_update = pair_update

        if not self.pair_update:
            self.edge_emb = nn.Sequential(
                nn.Linear(edge_dim, edge_dim * 2),
                nn.GELU(),
                nn.Linear(edge_dim * 2, edge_dim),
                nn.LayerNorm(edge_dim),
            )

        # Message passing (attention) layer.
        if trans_ver == 'v2':
            layer_cls = TransLayerOptimV2
        elif trans_ver == 'v3':
            layer_cls = TransLayerOptimV3
        elif trans_ver == 'v4':
            layer_cls = TransLayerOptimV4
        else:
            layer_cls = TransLayerOptim
        self.attn_mpnn = layer_cls(node_dim, node_dim // num_heads, num_heads, edge_dim=edge_dim, dropout=dropout)

        # Feed-forward block for nodes.
        self.ff_linear1 = nn.Linear(node_dim, node_dim * mlp_ratio)
        self.ff_linear2 = nn.Linear(node_dim * mlp_ratio, node_dim)

        if pair_update:
            # Projects [h_src, h_dst, e] back to the edge width.
            self.node2edge_lin = nn.Linear(node_dim * 2 + edge_dim, edge_dim)
            # Feed-forward block for edges.
            self.ff_linear3 = nn.Linear(edge_dim, edge_dim * mlp_ratio)
            self.ff_linear4 = nn.Linear(edge_dim * mlp_ratio, edge_dim)

        self.norm1_node = nn.LayerNorm(node_dim, elementwise_affine=True, eps=1e-6)
        self.norm2_node = nn.LayerNorm(node_dim, elementwise_affine=True, eps=1e-6)
        if self.pair_update:
            self.norm1_edge = nn.LayerNorm(edge_dim, elementwise_affine=True, eps=1e-6)
            self.norm2_edge = nn.LayerNorm(edge_dim, elementwise_affine=True, eps=1e-6)

    def _ff_block_node(self, x):
        """Node FFN: Linear -> act -> dropout -> Linear -> dropout."""
        x = F.dropout(self.act(self.ff_linear1(x)), p=self.dropout, training=self.training)
        return F.dropout(self.ff_linear2(x), p=self.dropout, training=self.training)

    def _ff_block_edge(self, x):
        """Edge FFN: Linear -> act -> dropout -> Linear -> dropout."""
        x = F.dropout(self.act(self.ff_linear3(x)), p=self.dropout, training=self.training)
        return F.dropout(self.ff_linear4(x), p=self.dropout, training=self.training)

    def forward(self, h, edge_attr, edge_index):
        """Apply one block.

        Args:
            h: node features, ``[num_nodes, node_dim]``.
            edge_attr: edge features, ``[E, edge_dim]``.
            edge_index: ``[2, E]`` source/target node indices.

        Returns:
            ``(h_out, h_edge_out)``, with the same shapes as ``h`` and
            ``edge_attr``.
        """
        h_in_node = h
        h_in_edge = edge_attr

        h = self.norm1_node(h)
        if self.pair_update:
            edge_attr = self.norm1_edge(edge_attr)
        else:
            edge_attr = self.edge_emb(edge_attr)

        h_node = self.attn_mpnn(h, edge_index, edge_attr)
        h_out = self.node_update(h_in_node, h_node)

        if self.pair_update:
            # Gather the attention outputs (pre-residual) of both endpoints:
            # [E, 2, node_dim] -> [E, 2 * node_dim].
            h_edge = h_node[edge_index.transpose(0, 1)].flatten(1, 2)
            h_edge = torch.cat([h_edge, h_in_edge], dim=-1)
            h_edge_out = self.edge_update(h_in_edge, h_edge)
        else:
            h_edge_out = h_in_edge
        return h_out, h_edge_out

    def node_update(self, h_in_node, h_node):
        """Residual add of the attention output, then a pre-norm FFN with residual."""
        h_node = h_in_node + h_node
        _h_node = self.norm2_node(h_node)
        h_out = h_node + self._ff_block_node(_h_node)
        return h_out

    def edge_update(self, h_in_edge, h_edge):
        """Project the gathered features, add residually, then a pre-norm FFN with residual."""
        h_edge = self.node2edge_lin(h_edge)
        h_edge = h_in_edge + h_edge
        _h_edge = self.norm2_edge(h_edge)
        h_edge_out = h_edge + self._ff_block_edge(_h_edge)
        return h_edge_out
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding lookup table.

    Args:
        d_hid: encoding width.
        n_position: number of table rows (largest supported position + 1).
    """

    def __init__(self, d_hid, n_position=3000):
        super(PositionalEncoding, self).__init__()
        # A buffer, not a parameter: it follows .to()/.cuda() and is saved in
        # the state_dict, but is never trained.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Build the ``[n_position, d_hid]`` float32 sinusoid table.

        The angle at ``(p, j)`` is ``p / 10000 ** (2 * (j // 2) / d_hid)``.
        Even columns take its ``sin`` and odd columns its ``cos``. The table
        is computed in float64 with NumPy broadcasting, replacing the former
        per-element Python loop, and cast to float32.
        """
        positions = np.arange(n_position, dtype=np.float64)[:, None]  # [n_position, 1]
        exponents = 2 * (np.arange(d_hid) // 2) / d_hid  # [d_hid]
        sinusoid_table = positions / np.power(10000, exponents)[None, :]  # [n_position, d_hid]
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        return torch.from_numpy(sinusoid_table).float()

    def forward(self, seq_pos):
        """Look up encodings for integer positions.

        Args:
            seq_pos: integer positions, e.g. ``[sum_i N_i]`` for a batch of sequences.

        Returns:
            ``[*seq_pos.shape, d_hid]`` copy of the matching table rows.
        """
        return self.pos_table[seq_pos].clone().detach()
class NodeEmbed(nn.Module):
    """Embed per-node inputs into ``hidden_size`` features.

    Node features, flattened coordinates and a sinusoidal encoding of the
    sequence position are each projected to ``hidden_size * mlp_ratio``.
    They are summed, optionally together with an LLM or Protenix embedding,
    and mapped back to ``hidden_size`` by a GELU + Linear head.

    Args:
        in_node_features: size of ``x``.
        hidden_size: output width.
        pos_dim: size of the flattened coordinate vector per node.
        mlp_ratio: expansion factor of the intermediate width.
        pos_mask_type: how the coordinate embedding of masked nodes is treated:

            * ``'none'``: no masking.
            * ``'replace'``: masked rows are overwritten by a learnable token.
            * ``'add'``: one of two learnable tokens, chosen by the mask, is
              added.

        llm_embed: if True, ``forward`` needs ``llm_embed`` (of width
            ``hidden_size``), and the Protenix branch is skipped.
        use_protenix_emb: if True (and ``llm_embed`` is False), ``forward``
            needs ``protenix_emb``.
        protenix_hidden_dim: width of ``protenix_emb``.

    Raises:
        ValueError: for an unknown ``pos_mask_type``.
    """

    def __init__(self, in_node_features, hidden_size, pos_dim=72, mlp_ratio=4, pos_mask_type='none', llm_embed=False, use_protenix_emb=True, protenix_hidden_dim=384):
        super().__init__()
        self.x_linear = nn.Linear(in_node_features, hidden_size * mlp_ratio, bias=False)
        self.pos_linear = nn.Linear(pos_dim, hidden_size * mlp_ratio, bias=True)
        self.seq_pos_emb = PositionalEncoding(hidden_size)
        self.seq_pos_linear = nn.Linear(hidden_size, hidden_size * mlp_ratio, bias=False)
        self.llm_embed = llm_embed
        self.use_llm_mlp = False
        self.use_protenix_emb = use_protenix_emb
        if llm_embed:
            if self.use_llm_mlp:
                self.llm_mlp = nn.Sequential(
                    nn.Linear(hidden_size, hidden_size * mlp_ratio),
                    nn.GELU(),
                    nn.Linear(hidden_size * mlp_ratio, hidden_size)
                )
            else:
                self.llm_mlp = nn.Linear(hidden_size, hidden_size * mlp_ratio, bias=False)
        if use_protenix_emb:
            self.protenix_mlp = nn.Linear(protenix_hidden_dim, hidden_size * mlp_ratio, bias=False)
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(hidden_size * mlp_ratio, hidden_size)
        )
        self.pos_mask_type = pos_mask_type
        if pos_mask_type == 'replace':
            self.mask_token = nn.Parameter(torch.zeros(1, hidden_size * mlp_ratio))
            nn.init.normal_(self.mask_token, std=.02)
        elif pos_mask_type == 'add':
            self.mask_token = nn.Parameter(torch.zeros(2, hidden_size * mlp_ratio))
            nn.init.xavier_normal_(self.mask_token)
        elif pos_mask_type == 'none':
            pass
        else:
            raise ValueError(f'Unknown pos_mask_type {pos_mask_type}')

    def forward(self, x, pos, seq_pos, pos_mask=None, llm_embed=None, protenix_emb=None):
        """Return ``[num_nodes, hidden_size]`` node embeddings.

        Raises:
            ValueError: if an input required by the configuration
                (``pos_mask``, ``llm_embed`` or ``protenix_emb``) is None.
                Previously this either failed with an opaque TypeError or,
                for ``pos_mask`` with ``'replace'``, silently masked every
                row.
        """
        if pos.dim() == 3:
            pos = pos.flatten(1, 2)
        x = self.x_linear(x)
        pos = self._mask_pos(self.pos_linear(pos), pos_mask)
        seq_pos = self.seq_pos_linear(self.seq_pos_emb(seq_pos))
        h = x + pos + seq_pos

        if self.llm_embed:
            if llm_embed is None:
                raise ValueError('llm_embed is required when NodeEmbed is built with llm_embed=True')
            if self.use_llm_mlp:
                return self.mlp(h) + self.llm_mlp(llm_embed)
            return self.mlp(h + self.llm_mlp(llm_embed))
        if self.use_protenix_emb:
            if protenix_emb is None:
                raise ValueError('protenix_emb is required when NodeEmbed is built with use_protenix_emb=True')
            return self.mlp(h + self.protenix_mlp(protenix_emb))
        return self.mlp(h)

    def _mask_pos(self, pos, pos_mask):
        """Apply the configured mask token to the projected coordinates ``pos``."""
        if self.pos_mask_type == 'none':
            return pos
        if pos_mask is None:
            raise ValueError(f"pos_mask is required when pos_mask_type='{self.pos_mask_type}'")
        if self.pos_mask_type == 'replace':
            # index_put requires matching dtypes (e.g. under autocast).
            pos[pos_mask] = self.mask_token.to(pos.dtype)
            return pos
        if self.pos_mask_type == 'add':
            return pos + self.mask_token[pos_mask.long()]
        raise ValueError(f'Unknown pos_mask_type {self.pos_mask_type}')
class NodeEmbed_with_struc(nn.Module):
    """``NodeEmbed`` variant that also adds a projected structure embedding.

    Node features, flattened coordinates plus a flattened structure embedding,
    and a sinusoidal sequence-position encoding are each projected to
    ``hidden_size * mlp_ratio``. They are summed (optionally with an LLM
    embedding) and mapped back to ``hidden_size`` by a GELU + Linear head.

    Args:
        in_node_features: size of ``x``.
        hidden_size: output width.
        pos_dim: size of the flattened coordinate vector per node.
        mlp_ratio: expansion factor of the intermediate width.
        pos_mask_type: ``'none'``, ``'replace'`` or ``'add'``; see ``NodeEmbed``.
        llm_embed: if True, ``forward`` needs ``llm_embed`` (width ``hidden_size``).
        struc_emb_dim: per-slot structure embedding width. The flattened
            structure embedding is expected to have ``struc_emb_dim * 24``
            features.

    Raises:
        ValueError: for an unknown ``pos_mask_type``.
    """

    def __init__(self, in_node_features, hidden_size, pos_dim=3, mlp_ratio=4, pos_mask_type='none', llm_embed=False, struc_emb_dim=20):
        super().__init__()
        self.x_linear = nn.Linear(in_node_features, hidden_size * mlp_ratio, bias=False)
        self.pos_linear = nn.Linear(pos_dim, hidden_size * mlp_ratio, bias=True)
        self.struc_linear = nn.Linear(struc_emb_dim * 24, hidden_size * mlp_ratio, bias=False)
        self.seq_pos_emb = PositionalEncoding(hidden_size)
        self.seq_pos_linear = nn.Linear(hidden_size, hidden_size * mlp_ratio, bias=False)
        self.llm_embed = llm_embed
        self.use_llm_mlp = False
        if llm_embed:
            if self.use_llm_mlp:
                self.llm_mlp = nn.Sequential(
                    nn.Linear(hidden_size, hidden_size * mlp_ratio),
                    nn.GELU(),
                    nn.Linear(hidden_size * mlp_ratio, hidden_size)
                )
            else:
                self.llm_mlp = nn.Linear(hidden_size, hidden_size * mlp_ratio, bias=False)
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(hidden_size * mlp_ratio, hidden_size)
        )
        self.pos_mask_type = pos_mask_type
        if pos_mask_type == 'replace':
            self.mask_token = nn.Parameter(torch.zeros(1, hidden_size * mlp_ratio))
            nn.init.normal_(self.mask_token, std=.02)
        elif pos_mask_type == 'add':
            self.mask_token = nn.Parameter(torch.zeros(2, hidden_size * mlp_ratio))
            nn.init.xavier_normal_(self.mask_token)
        elif pos_mask_type == 'none':
            pass
        else:
            raise ValueError(f'Unknown pos_mask_type {pos_mask_type}')

    def forward(self, x, struc_emb, pos, seq_pos, pos_mask=None, llm_embed=None):
        """Return ``[num_nodes, hidden_size]`` node embeddings.

        ``pos`` and ``struc_emb`` are flattened independently. Previously
        ``struc_emb`` was flattened only when ``pos`` had rank 3, which broke
        mixed-rank inputs.

        Raises:
            ValueError: if ``pos_mask`` or ``llm_embed`` is None but required
                by the configuration.
        """
        if pos.dim() == 3:
            pos = pos.flatten(1, 2)
        if struc_emb.dim() >= 3:
            struc_emb = struc_emb.flatten(1, 2)
        x = self.x_linear(x)
        # The structure embedding shares the (maskable) coordinate slot.
        pos = self.pos_linear(pos) + self.struc_linear(struc_emb)
        pos = self._mask_pos(pos, pos_mask)
        seq_pos = self.seq_pos_linear(self.seq_pos_emb(seq_pos))
        h = x + pos + seq_pos

        if self.llm_embed:
            if llm_embed is None:
                raise ValueError('llm_embed is required when NodeEmbed_with_struc is built with llm_embed=True')
            if self.use_llm_mlp:
                return self.mlp(h) + self.llm_mlp(llm_embed)
            return self.mlp(h + self.llm_mlp(llm_embed))
        return self.mlp(h)

    def _mask_pos(self, pos, pos_mask):
        """Apply the configured mask token to the projected coordinates ``pos``."""
        if self.pos_mask_type == 'none':
            return pos
        if pos_mask is None:
            raise ValueError(f"pos_mask is required when pos_mask_type='{self.pos_mask_type}'")
        if self.pos_mask_type == 'replace':
            # index_put requires matching dtypes (e.g. under autocast).
            pos[pos_mask] = self.mask_token.to(pos.dtype)
            return pos
        if self.pos_mask_type == 'add':
            return pos + self.mask_token[pos_mask.long()]
        raise ValueError(f'Unknown pos_mask_type {self.pos_mask_type}')
class DMT(nn.Module):
    """Graph transformer encoder with an auxiliary coordinate-denoising loss.

    ``forward`` does the following:

    1. Embeds nodes and edges. Edges can optionally be augmented with
       Gaussian-basis distance features.
    2. Runs ``configs.n_blocks`` ``DMTBlock``s.
    3. Predicts per-node noise, compared with ``data['noise']`` on nodes not
       flagged by ``data['pos_mask']``.
    4. Mean-pools node features per graph.

    Args:
        configs: namespace-like object. Attributes read here include
            ``hidden_dim``, ``e2n_ratio``, ``n_blocks``, ``n_heads``,
            ``mlp_ratio``, ``dropout``, ``pos_dim``, ``pos_mask_type``,
            ``dist_mask_type``, ``trans_ver``, ``in_res_node_features``,
            ``in_res_edge_features`` and several boolean switches.
    """

    def __init__(self, configs):
        super().__init__()
        self.use_struc_emb = configs.use_struc_emb
        self.disable_dist = configs.disable_dist
        self.new_aa = configs.new_aa
        self.sqrt_dis = configs.sqrt_dis
        edge_dim = configs.hidden_dim // configs.e2n_ratio

        if configs.use_struc_emb:
            self.node_emb = NodeEmbed_with_struc(configs.in_res_node_features, configs.hidden_dim, configs.pos_dim, configs.mlp_ratio, configs.pos_mask_type, configs.enable_llm)
        else:
            self.node_emb = NodeEmbed(configs.in_res_node_features, configs.hidden_dim, configs.pos_dim, configs.mlp_ratio, configs.pos_mask_type, configs.enable_llm, configs.use_protenix_emb)

        if not configs.disable_dist:
            self.dist_mask_type = configs.dist_mask_type
            # Gaussian basis expansion of pairwise distances.
            self.dist_gbf = GaussianLayer(edge_dim, configs.dist_mask_type)
            in_edge_dim = configs.in_res_edge_features + edge_dim
        else:
            in_edge_dim = configs.in_res_edge_features
        self.edge_emb = nn.Sequential(
            nn.Linear(in_edge_dim, 2 * edge_dim),
            nn.GELU(),
            nn.Linear(2 * edge_dim, edge_dim),
        )

        self.blocks = nn.ModuleList()
        for _ in range(configs.n_blocks):
            self.blocks.append(DMTBlock(configs.hidden_dim, edge_dim,
                                        configs.n_heads, mlp_ratio=configs.mlp_ratio, act=nn.GELU, dropout=configs.dropout, pair_update=not configs.not_pair_update, trans_ver=configs.trans_ver))

        self.pooling_mlp = nn.Sequential(
            nn.Linear(configs.hidden_dim, configs.hidden_dim * configs.mlp_ratio),
            nn.GELU(),
            nn.Linear(configs.hidden_dim * configs.mlp_ratio, configs.hidden_dim)
        )
        # 72 outputs per node; reshaped to data.pos.shape in forward.
        self.pred_layer = nn.Sequential(
            nn.Linear(configs.hidden_dim, configs.hidden_dim * configs.mlp_ratio),
            nn.Tanh(),
            nn.Linear(configs.hidden_dim * configs.mlp_ratio, 72)
        )

    def forward(self, data):
        """Encode a batch of graphs.

        Returns:
            ``(graph_h, denoising_loss)``: ``[B, hidden_dim]`` pooled graph
            features and the mean squared noise-prediction error over
            unmasked entries. The loss is zero, still attached to the graph,
            when nothing is unmasked; previously this was NaN.
        """
        assert hasattr(data, 'seq_pos')
        seq_pos = data.seq_pos
        llm_embed = data.get('llm_embed', None)

        if self.use_struc_emb:
            # NOTE(review): this branch embeds ``data.gt_pos`` while the other
            # branch and the distance features use ``data.pos``. Confirm that
            # feeding ground-truth positions here is intended.
            node_h = self.node_emb(data.x, data.struc_emb, data.gt_pos, seq_pos, data['pos_mask'], llm_embed=llm_embed)
        else:
            node_h = self.node_emb(data.x, data.pos, seq_pos, data['pos_mask'], llm_embed=llm_embed, protenix_emb=data.get('protenix_emb', None))

        edge_h = self._embed_edges(data)

        for layer in self.blocks:
            node_h, edge_h = layer(node_h, edge_h, data.edge_index)

        pred_noise = self.pred_layer(node_h).reshape(data.pos.shape)
        unmasked = ~data['pos_mask']
        residual = pred_noise[unmasked] - data['noise'][unmasked]
        if residual.numel() == 0:
            # mean() of an empty tensor is NaN and would poison the total loss.
            denoising_loss = pred_noise.sum() * 0.0
        else:
            denoising_loss = (residual ** 2).mean()

        graph_h = global_mean_pool(node_h, data.batch)  # [B, hidden_dim]
        graph_h = self.pooling_mlp(graph_h)
        return graph_h, denoising_loss

    def _embed_edges(self, data):
        """Embed ``data.edge_attr``, concatenated with distance features unless disabled."""
        if self.disable_dist:
            return self.edge_emb(data.edge_attr)
        # new_aa: average node points only over unmasked positions.
        pos_unmask = ~data.pos_mask if self.new_aa else None
        distance = coord2dist(data.pos, data.edge_index, self.sqrt_dis, pos_unmask)
        edge_mask = None
        if self.dist_mask_type != 'none':
            # An edge is masked if either endpoint is masked.
            edge_mask = data.pos_mask[data.edge_index].any(dim=0)
        dist_emb = self.dist_gbf(distance, edge_mask)
        return self.edge_emb(torch.cat([data.edge_attr, dist_emb], dim=-1))
# class InfoNCELoss(nn.Module):
# def __init__(self, temperature=0.05):
# super().__init__()
# self.temperature = temperature
# def forward(self, z1, z2):
# """
# z1, z2: (B, D) representations of the two views after the projection head
# """
# B = z1.shape[0]
# z = torch.cat([z1, z2], dim=0) # (2B, D)
# sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=-1) # (2B, 2B)
# # Positive indices: i-th with (i + B)%2B
# labels = torch.arange(B, device=z.device)
# labels = torch.cat([labels + B, labels], dim=0)
# # Mask: remove self-similarity
# mask = ~torch.eye(2 * B, dtype=torch.bool, device=z.device)
# sim = sim / self.temperature
# sim_exp = torch.exp(sim) * mask # (2B, 2B), exp(sim) and remove diagonal
# # Denominator
# denom = sim_exp.sum(dim=1) # (2B,)
# # Numerator: select positive pairs
# numerator = torch.exp(sim[torch.arange(2 * B), labels])
# loss = -torch.log(numerator / denom)
# return loss.mean() |