import math
from copy import deepcopy
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor, device
from torch.nn import Parameter
from torch.nn.init import normal_, xavier_uniform_

from .G2PTL_utils import *
from transformers.modeling_utils import ModuleUtilsMixin
from fairseq import utils
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    LayerNorm,
)

def init_params(module, n_layers):
    """Initialize Linear and Embedding weights with a depth-scaled normal distribution."""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers))
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
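
# Hedged usage sketch: init_params is designed to be passed to Module.apply with
# n_layers fixed. The MLP and n_layers=12 below are illustrative, not from the source.
def _demo_init_params():
    import functools
    mlp = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
    mlp.apply(functools.partial(init_params, n_layers=12))
    return mlp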

@torch.jit.script
def softmax_dropout(input, dropout_prob: float, is_training: bool):
    return F.dropout(F.softmax(input, -1), dropout_prob, is_training)
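
# Minimal sanity-check sketch (assumed shapes; meant to be called manually, not at import
# time): with is_training=False the dropout is a no-op, so the result is a plain softmax
# over the last dimension.
def _demo_softmax_dropout():
    scores = torch.randn(2, 4, 8, 8)  # e.g. (n_graph, num_heads, n_node, n_node)
    probs = softmax_dropout(scores, 0.1, False)
    assert torch.allclose(probs, F.softmax(scores, dim=-1))
    return probs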

class SelfMultiheadAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        scaling_factor=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = (self.head_dim * scaling_factor) ** -0.5

        self.linear_q = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
        self.linear_k = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
        self.linear_v = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)

    def forward(
        self,
        query: Tensor,
        attn_bias: Tensor = None,
    ) -> Tensor:
        n_graph, n_node, embed_dim = query.size()

        # Project to (n_graph, num_heads, n_node, head_dim); queries are pre-scaled.
        q = self.linear_q(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2) * self.scaling
        k = self.linear_k(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.linear_v(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # (n_graph, num_heads, n_node, n_node) attention logits plus the additive graph bias.
        attn_weights = torch.matmul(q, k.transpose(2, 3))
        attn_weights = attn_weights + attn_bias
        attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)

        attn = torch.matmul(attn_probs, v)
        attn = attn.transpose(1, 2).contiguous().view(n_graph, -1, embed_dim)
        attn = self.out_proj(attn)
        return attn
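
# Hedged usage sketch (shapes are illustrative, not taken from the surrounding code):
# the additive bias carries one (n_node, n_node) map per head and per graph, and the
# output has the same shape as the input node states.
def _demo_self_multihead_attention():
    attn = SelfMultiheadAttention(embed_dim=768, num_heads=48)
    x = torch.randn(2, 10, 768)        # (n_graph, n_node, embed_dim)
    bias = torch.zeros(2, 48, 10, 10)  # (n_graph, num_heads, n_node, n_node)
    out = attn(x, attn_bias=bias)
    assert out.shape == x.shape
    return out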

class Graphormer3DEncoderLayer(nn.Module):
    """
    Implements a Graphormer-3D Encoder Layer.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.self_attn = SelfMultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout)
        self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
        self.final_layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward(self, x: Tensor, attn_bias: Tensor = None):
        # Pre-LayerNorm self-attention block with a residual connection.
        residual = x
        x = self.self_attn_layer_norm(x)
        x = self.self_attn(query=x, attn_bias=attn_bias)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x

        # Pre-LayerNorm feed-forward block with a residual connection.
        residual = x
        x = self.final_layer_norm(x)
        x = F.gelu(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        return x
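
# Hedged usage sketch (the layer sizes below are illustrative): one encoder layer maps
# node states of shape (n_graph, n_node, embedding_dim) to the same shape, given the
# per-head additive attention bias.
def _demo_encoder_layer():
    layer = Graphormer3DEncoderLayer(embedding_dim=768, ffn_embedding_dim=3072, num_attention_heads=48)
    x = torch.randn(2, 10, 768)
    bias = torch.zeros(2, 48, 10, 10)
    y = layer(x, attn_bias=bias)
    assert y.shape == x.shape
    return y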

class Graphormer3D(BaseFairseqModel):
    def __init__(self):
        super().__init__()
        self.atom_types = 64
        self.edge_types = 64 * 64
        self.embed_dim = 768
        self.layer_nums = 12
        self.ffn_embed_dim = 768
        self.blocks = 4
        self.attention_heads = 48
        self.input_dropout = 0.0
        self.dropout = 0.1
        self.attention_dropout = 0.1
        self.activation_dropout = 0.0
        self.node_loss_weight = 15
        self.min_node_loss_weight = 1
        self.eng_loss_weight = 1
        self.num_kernel = 128
        self.atom_encoder = nn.Embedding(self.atom_types, self.embed_dim, padding_idx=0)
        self.edge_embedding = nn.Embedding(32, self.attention_heads, padding_idx=0)
        self.input_dropout = nn.Dropout(0.1)  # replaces the float set above
        self.layers = nn.ModuleList(
            [
                Graphormer3DEncoderLayer(
                    self.embed_dim,
                    self.ffn_embed_dim,
                    num_attention_heads=self.attention_heads,
                    dropout=self.dropout,
                    attention_dropout=self.attention_dropout,
                    activation_dropout=self.activation_dropout,
                )
                for _ in range(self.layer_nums)
            ]
        )
        # Note: this re-assignment overrides the atom_encoder defined above.
        self.atom_encoder = nn.Embedding(512 * 9 + 1, self.embed_dim, padding_idx=0)
        self.edge_encoder = nn.Embedding(512 * 3 + 1, self.attention_heads, padding_idx=0)
        self.edge_type = 'multi_hop'
        if self.edge_type == 'multi_hop':
            self.edge_dis_encoder = nn.Embedding(16 * self.attention_heads * self.attention_heads, 1)
        self.spatial_pos_encoder = nn.Embedding(512, self.attention_heads, padding_idx=0)
        self.in_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
        self.out_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
        self.node_position_ids_encoder = nn.Embedding(10, self.embed_dim, padding_idx=0)

        self.final_ln: Callable[[Tensor], Tensor] = nn.LayerNorm(self.embed_dim)

        self.engergy_proj: Callable[[Tensor], Tensor] = NonLinear(self.embed_dim, 1)
        self.energe_agg_factor: Callable[[Tensor], Tensor] = nn.Embedding(3, 1)
        nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01)

        self.graph_token = nn.Embedding(1, 768)
        self.graph_token_virtual_distance = nn.Embedding(1, self.attention_heads)

        K = self.num_kernel

        self.gbf: Callable[[Tensor, Tensor], Tensor] = GaussianLayer(K, self.edge_types)
        self.bias_proj: Callable[[Tensor], Tensor] = NonLinear(K, self.attention_heads)
        self.edge_proj: Callable[[Tensor], Tensor] = nn.Linear(K, self.embed_dim)
        self.node_proc: Callable[[Tensor, Tensor, Tensor], Tensor] = NodeTaskHead(self.embed_dim, self.attention_heads)

    def forward(self, node_feature, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids):
        """
        node_feature: text embedding, shape: (n_graph, n_node, embed_dim)
        spatial_pos: The shortest path length between nodes in the graph, shape: (n_graph, n_node, n_node)
        in_degree: The in-degree of nodes in the graph, shape: (n_graph, n_node)
        out_degree: The out-degree of nodes in the graph, shape: (n_graph, n_node)
        edge_type_matrix: The edge type of edges in the graph, shape: (n_graph, n_node, n_node)
        edge_input: The edge types along the shortest path route between nodes, shape: (n_graph, n_node, n_node, multi_hop_max_dist)
        node_position_ids: node position ids, shape: (n_graph, n_node)
        """
        attn_edge_type = self.edge_embedding(edge_type_matrix)
        edge_input = self.edge_embedding(edge_input)
        n_graph, n_node = node_feature.size()[:2]
        spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)

        if self.edge_type == 'multi_hop':
            spatial_pos_ = spatial_pos.clone()
            spatial_pos_[spatial_pos_ == 0] = 1
            spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
            max_dist = edge_input.size(-2)
            edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.attention_heads)
            edge_input_flat = torch.bmm(edge_input_flat, self.edge_dis_encoder.weight.reshape(-1, self.attention_heads, self.attention_heads)[:max_dist, :, :])
            edge_input = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.attention_heads).permute(1, 2, 3, 0, 4)
            # Sum the per-hop biases along the path and normalize by the path length.
            edge_input = (edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
        else:
            # Not reached: edge_type is fixed to 'multi_hop' in __init__.
            edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)

        graph_attn_bias = spatial_pos_bias + edge_input
        node_position_embedding = self.node_position_ids_encoder(node_position_ids)
        node_position_embedding = node_position_embedding.contiguous().view(n_graph, n_node, self.embed_dim)
        node_feature = node_feature + self.in_degree_encoder(in_degree) + \
            self.out_degree_encoder(out_degree) + node_position_embedding

        output = self.input_dropout(node_feature)
        for enc_layer in self.layers:
            output = enc_layer(output, graph_attn_bias)
        output = self.final_ln(output)

        return output
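
# Hedged end-to-end sketch with made-up toy inputs: all sizes and index ranges below are
# illustrative assumptions, chosen only to satisfy the embedding table sizes defined above.
def _demo_graphormer3d():
    model = Graphormer3D().eval()
    n_graph, n_node, max_dist = 2, 6, 4
    node_feature = torch.randn(n_graph, n_node, 768)                        # per-node text embeddings
    spatial_pos = torch.randint(1, 6, (n_graph, n_node, n_node))            # shortest-path lengths
    in_degree = torch.randint(0, 5, (n_graph, n_node))
    out_degree = torch.randint(0, 5, (n_graph, n_node))
    edge_type_matrix = torch.randint(0, 32, (n_graph, n_node, n_node))      # edge-type ids
    edge_input = torch.randint(0, 32, (n_graph, n_node, n_node, max_dist))  # edge types along paths
    node_position_ids = torch.randint(0, 10, (n_graph, n_node))
    with torch.no_grad():
        out = model(node_feature, spatial_pos, in_degree, out_degree,
                    edge_type_matrix, edge_input, node_position_ids)
    assert out.shape == (n_graph, n_node, 768)
    return out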

@torch.jit.script
def gaussian(x, mean, std):
    # Normal probability density of x under N(mean, std), evaluated elementwise.
    pi = 3.14159
    a = (2 * pi) ** 0.5
    return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)

class GaussianLayer(nn.Module):
    """Expands scalar pairwise distances into K Gaussian basis features, after a learned
    per-edge-type affine transform of the distance."""

    def __init__(self, K=128, edge_types=1024):
        super().__init__()
        self.K = K
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x, edge_types):
        mul = self.mul(edge_types)
        bias = self.bias(edge_types)
        x = mul * x.unsqueeze(-1) + bias
        x = x.expand(-1, -1, -1, self.K)
        mean = self.means.weight.float().view(-1)
        std = self.stds.weight.float().view(-1).abs() + 1e-5
        return gaussian(x.float(), mean, std).type_as(self.means.weight)
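
# Hedged usage sketch (toy shapes): a (n_graph, n_node, n_node) distance matrix and
# matching edge-type ids yield K Gaussian features per node pair.
def _demo_gaussian_layer():
    gbf = GaussianLayer(K=128, edge_types=64 * 64)
    dist = torch.rand(2, 6, 6) * 5.0
    etype = torch.randint(0, 64 * 64, (2, 6, 6))
    feats = gbf(dist, etype)
    assert feats.shape == (2, 6, 6, 128)
    return feats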

class RBF(nn.Module):
    """Radial-basis-function variant of GaussianLayer with learnable means and temperatures."""

    def __init__(self, K, edge_types):
        super().__init__()
        self.K = K
        self.means = nn.parameter.Parameter(torch.empty(K))
        self.temps = nn.parameter.Parameter(torch.empty(K))
        self.mul: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
        self.bias: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
        nn.init.uniform_(self.means, 0, 3)
        nn.init.uniform_(self.temps, 0.1, 10)
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x: Tensor, edge_types):
        mul = self.mul(edge_types)
        bias = self.bias(edge_types)
        x = mul * x.unsqueeze(-1) + bias
        mean = self.means.float()
        temp = self.temps.float().abs()
        return ((x - mean).square() * (-temp)).exp().type_as(self.means)
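
# Hedged sketch: RBF takes the same (distance, edge_type) inputs as GaussianLayer and
# returns K features per node pair; it is not referenced elsewhere in this file.
def _demo_rbf():
    rbf = RBF(K=128, edge_types=64 * 64)
    dist = torch.rand(2, 6, 6) * 5.0
    etype = torch.randint(0, 64 * 64, (2, 6, 6))
    feats = rbf(dist, etype)
    assert feats.shape == (2, 6, 6, 128)
    return feats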

class NonLinear(nn.Module):
    """Two-layer MLP with a GELU activation."""

    def __init__(self, input, output_size, hidden=None):
        super(NonLinear, self).__init__()
        if hidden is None:
            hidden = input
        self.layer1 = nn.Linear(input, hidden)
        self.layer2 = nn.Linear(hidden, output_size)

    def forward(self, x):
        x = F.gelu(self.layer1(x))
        x = self.layer2(x)
        return x
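
# Hedged usage sketch: projecting K = 128 kernel features down to 48 attention heads,
# mirroring how bias_proj is constructed above (sizes are illustrative).
def _demo_nonlinear():
    proj = NonLinear(128, 48)
    x = torch.randn(2, 6, 6, 128)
    y = proj(x)
    assert y.shape == (2, 6, 6, 48)
    return y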

class NodeTaskHead(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.q_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
        self.k_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
        self.v_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads
        self.scaling = (embed_dim // num_heads) ** -0.5
        self.force_proj1: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
        self.force_proj2: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
        self.force_proj3: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)

    def forward(
        self,
        query: Tensor,
        attn_bias: Tensor,
        delta_pos: Tensor,
    ) -> Tensor:
        bsz, n_node, _ = query.size()
        q = (self.q_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2) * self.scaling)
        k = self.k_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
        v = self.v_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
        attn = q @ k.transpose(-1, -2)
        attn_probs = softmax_dropout(attn.view(-1, n_node, n_node) + attn_bias, 0.1, self.training).view(bsz, self.num_heads, n_node, n_node)
        # Weight the pairwise displacement vectors by the attention probabilities,
        # then project each spatial component to a scalar force per node.
        rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as(attn_probs)
        rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3)
        x = rot_attn_probs @ v.unsqueeze(2)
        x = x.permute(0, 3, 2, 1, 4).contiguous().view(bsz, n_node, 3, -1)
        f1 = self.force_proj1(x[:, :, 0, :]).view(bsz, n_node, 1)
        f2 = self.force_proj2(x[:, :, 1, :]).view(bsz, n_node, 1)
        f3 = self.force_proj3(x[:, :, 2, :]).view(bsz, n_node, 1)
        cur_force = torch.cat([f1, f2, f3], dim=-1).float()
        return cur_force
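
# Hedged usage sketch (toy shapes): node states, an additive attention bias flattened over
# heads, and pairwise 3-D displacement vectors yield one 3-D force vector per node.
def _demo_node_task_head():
    head = NodeTaskHead(embed_dim=768, num_heads=48)
    bsz, n_node = 2, 6
    x = torch.randn(bsz, n_node, 768)
    bias = torch.zeros(bsz * 48, n_node, n_node)
    delta_pos = torch.randn(bsz, n_node, n_node, 3)
    forces = head(x, bias, delta_pos)
    assert forces.shape == (bsz, n_node, 3)
    return forces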