import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class NonLinearHead(nn.Module):
    """Head for simple classification tasks."""

    def __init__(
        self,
        input_dim,
        out_dim,
        hidden=None,
    ):
        super().__init__()
        hidden = input_dim if not hidden else hidden
        self.linear1 = nn.Linear(input_dim, hidden)
        self.linear2 = nn.Linear(hidden, out_dim)
        self.activation_fn = F.gelu

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)
        x = self.linear2(x)
        return x
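
# Usage sketch (an addition, not part of the original module): a quick shape
# check for NonLinearHead. The dimensions below are illustrative assumptions.
def _demo_nonlinear_head():
    head = NonLinearHead(input_dim=32, out_dim=8, hidden=64)
    x = torch.randn(2, 10, 32)  # [batch, nodes, input_dim]
    out = head(x)               # [batch, nodes, out_dim]
    assert out.shape == (2, 10, 8)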
class GaussianLayer(nn.Module):
    """Expands scalar pairwise features into Gaussian basis functions."""

    def __init__(self, num_distance=25, K=16, edge_dim=1024):
        super().__init__()
        self.K = K
        # num_distance * K Gaussian kernels in total (e.g. 25 * 16 = 400)
        self.means = nn.Embedding(1, num_distance * self.K)
        self.stds = nn.Embedding(1, num_distance * self.K)
        # per-pair affine transform of the input, conditioned on edge features
        self.mul = nn.Linear(edge_dim, num_distance)
        self.bias = nn.Linear(edge_dim, num_distance)
        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x, edge_feat):
        mul = self.mul(edge_feat).type_as(x)
        bias = self.bias(edge_feat).type_as(x)
        x = mul * x + bias                                     # [B, N, N, num_distance]
        x = x.unsqueeze(-1)                                    # [B, N, N, num_distance, 1]
        x = x.expand(-1, -1, -1, -1, self.K)                   # [B, N, N, num_distance, K]
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2], -1)  # [B, N, N, num_distance*K]
        mean = self.means.weight.float().view(-1)
        std = self.stds.weight.float().view(-1).abs() + 1e-5
        return gaussian(x.float(), mean, std).type_as(self.means.weight)
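
# Usage sketch (an addition, not part of the original module): GaussianLayer
# maps each of num_distance scalar pairwise features to K Gaussian basis
# values, conditioned on an edge feature vector. Shapes are illustrative
# assumptions only.
def _demo_gaussian_layer():
    layer = GaussianLayer(num_distance=2, K=16, edge_dim=8)
    x = torch.randn(2, 5, 5, 2)          # [B, N, N, num_distance] pairwise scalars
    edge_feat = torch.randn(2, 5, 5, 8)  # [B, N, N, edge_dim]
    out = layer(x, edge_feat)            # [B, N, N, num_distance*K]
    assert out.shape == (2, 5, 5, 32)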
class GaussianEncoder(nn.Module):
    def __init__(self, kernel_num, feat_dim, num_head, use_dist=1, use_product=1):
        super().__init__()
        self.num_distance = 0
        self.use_dist = use_dist
        self.use_product = use_product
        if use_dist:
            self.num_distance += 1
        if use_product:
            self.num_distance += 1
        self.gbf = GaussianLayer(self.num_distance, kernel_num, feat_dim)
        self.node_gate = nn.Linear(feat_dim, 1)
        self.gbf_proj = NonLinearHead(
            input_dim=kernel_num * self.num_distance,
            out_dim=num_head,
            hidden=128,
        )
        self.centrality_proj = NonLinearHead(
            input_dim=kernel_num * self.num_distance,
            out_dim=feat_dim,
            hidden=1024,
        )

    def get_encoding_features(self, dist, et, pair_mask=None, get_bias=True):
        n_node = dist.size(-2)
        gbf_feature = self.gbf(dist, et)  # [B, N, N, num_distance*kernel_num]
        if pair_mask is not None:
            centrality_encoding = gbf_feature * pair_mask.unsqueeze(-1)
        else:
            centrality_encoding = gbf_feature
        # sum over neighbours, then project to the embedding dimension
        centrality_encoding = self.centrality_proj(centrality_encoding.sum(dim=-2))  # [B, N, feat_dim]
        graph_attn_bias = self.gbf_proj(gbf_feature)  # [B, N, N, num_head]
        if get_bias:
            graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()  # [B, num_head, N, N]
            graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)  # [B*num_head, N, N]
        return graph_attn_bias, centrality_encoding

    def build_pairwise_product_dist(self, coords, node_feat):
        dist = _get_dist(coords, coords)
        coords = coords * self.node_gate(node_feat)
        pretext = coords[:, :, None] + coords[:, None, :]
        A = torch.einsum('bijd,bjd->bij', pretext, coords)
        B = torch.einsum('bid,bjd->bij', coords, coords)
        product = A * B
        product, dist = product[..., None], dist[..., None]
        geo_feat = torch.empty_like(product)[..., 0:0]  # zero-width placeholder to cat onto
        if self.use_dist:
            geo_feat = torch.cat([geo_feat, dist], dim=-1)
        if self.use_product:
            geo_feat = torch.cat([geo_feat, product], dim=-1)
        return geo_feat

    def forward(self, coords, node_feat, pair_mask=None, get_bias=True):
        geo_feat = self.build_pairwise_product_dist(coords, node_feat)
        edge_feat = node_feat[:, :, None, :] - node_feat[:, None, :, :]
        graph_attn_bias, centrality_encoding = self.get_encoding_features(
            geo_feat, edge_feat, pair_mask=pair_mask, get_bias=get_bias
        )
        return centrality_encoding, graph_attn_bias
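
# Usage sketch (an addition, not part of the original module): GaussianEncoder
# turns raw 3D coordinates plus node features into a per-node centrality
# encoding and a per-head pairwise attention bias. The hyper-parameters below
# are small illustrative assumptions, not values from any original config.
def _demo_gaussian_encoder():
    enc = GaussianEncoder(kernel_num=16, feat_dim=8, num_head=4)
    coords = torch.randn(2, 5, 3)     # [B, N, 3]
    node_feat = torch.randn(2, 5, 8)  # [B, N, feat_dim]
    centrality, bias = enc(coords, node_feat)
    assert centrality.shape == (2, 5, 8)  # [B, N, feat_dim]
    assert bias.shape == (8, 5, 5)        # [B*num_head, N, N]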
class SelfMultiheadAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.1,
        bias=True,
        scaling_factor=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = (self.head_dim * scaling_factor) ** -0.5
        self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        query,
        key_padding_mask: Optional[Tensor] = None,
        attn_bias: Optional[Tensor] = None,
        return_attn: bool = False,
    ) -> Tensor:
        bsz, tgt_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        q, k, v = self.in_proj(query).chunk(3, dim=-1)
        q = (
            q.view(bsz, tgt_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
            .view(bsz * self.num_heads, -1, self.head_dim)
            * self.scaling
        )
        k = (
            k.view(bsz, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
            .view(bsz * self.num_heads, -1, self.head_dim)
        )
        v = (
            v.view(bsz, -1, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
            .view(bsz * self.num_heads, -1, self.head_dim)
        )
        src_len = k.size(1)
        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None
        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights.masked_fill_(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        # Apply the pairwise attention bias on both paths. (The original only
        # added it when return_attn was True, which silently dropped the bias
        # on the default path.)
        if attn_bias is not None:
            attn_weights = attn_weights + attn_bias
        attn = F.dropout(F.softmax(attn_weights, dim=-1), p=self.dropout, training=self.training)
        o = torch.bmm(attn, v)
        assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        o = (
            o.view(bsz, self.num_heads, tgt_len, self.head_dim)
            .transpose(1, 2)
            .contiguous()
            .view(bsz, tgt_len, embed_dim)
        )
        o = self.out_proj(o)
        if not return_attn:
            return o
        return o, attn_weights, attn
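
# Usage sketch (an addition, not part of the original module): attn_bias is
# expected flattened to [bsz*num_heads, tgt_len, src_len], matching how
# TransformerEncoderWithPair passes it. Shapes are illustrative assumptions.
def _demo_self_attention():
    attn = SelfMultiheadAttention(embed_dim=16, num_heads=4)
    x = torch.randn(2, 5, 16)        # [B, L, embed_dim]
    bias = torch.zeros(2 * 4, 5, 5)  # [B*num_heads, L, L]
    out, weights, probs = attn(x, attn_bias=bias, return_attn=True)
    assert out.shape == (2, 5, 16)
    assert weights.shape == (8, 5, 5)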
class TransformerEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        ffn_embed_dim: int = 3072,
        attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        activation_fn: str = "gelu",
        post_ln=False,
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        # resolve the activation by name (the original ignored the argument
        # and hardcoded F.gelu)
        self.activation_fn = getattr(F, activation_fn)
        self.self_attn = SelfMultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            dropout=attention_dropout,
        )
        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        self.post_ln = post_ln

    def forward(
        self,
        x: torch.Tensor,
        attn_bias: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        return_attn: bool = False,
    ) -> torch.Tensor:
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules, similar to the original Transformer implementation.
        """
        residual = x
        if not self.post_ln:
            x = self.self_attn_layer_norm(x)
        x = self.self_attn(
            query=x,
            key_padding_mask=padding_mask,
            attn_bias=attn_bias,
            return_attn=return_attn,
        )
        if return_attn:
            x, attn_weights, attn_probs = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.self_attn_layer_norm(x)

        residual = x
        if not self.post_ln:
            x = self.final_layer_norm(x)
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if self.post_ln:
            x = self.final_layer_norm(x)
        if not return_attn:
            return x
        return x, attn_weights, attn_probs
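
# Usage sketch (an addition, not part of the original module): one pre-LN
# encoder layer with a zero attention bias. Dimensions are illustrative
# assumptions.
def _demo_encoder_layer():
    layer = TransformerEncoderLayer(embed_dim=16, ffn_embed_dim=32, attention_heads=4)
    x = torch.randn(2, 5, 16)
    bias = torch.zeros(2 * 4, 5, 5)  # [B*num_heads, L, L]
    out = layer(x, attn_bias=bias)
    assert out.shape == (2, 5, 16)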
class TransformerEncoderWithPair(nn.Module):
    def __init__(
        self,
        encoder_layers: int = 6,
        embed_dim: int = 768,
        ffn_embed_dim: int = 3072,
        attention_heads: int = 8,
        emb_dropout: float = 0.1,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        max_seq_len: int = 256,
        activation_fn: str = "gelu",
        post_ln: bool = False,
        no_final_head_layer_norm: bool = False,
    ) -> None:
        super().__init__()
        self.emb_dropout = emb_dropout
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads
        self.emb_layer_norm = nn.LayerNorm(self.embed_dim)
        if not post_ln:
            self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        else:
            self.final_layer_norm = None
        if not no_final_head_layer_norm:
            self.final_head_layer_norm = nn.LayerNorm(attention_heads)
        else:
            self.final_head_layer_norm = None
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    embed_dim=self.embed_dim,
                    ffn_embed_dim=ffn_embed_dim,
                    attention_heads=attention_heads,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    post_ln=post_ln,
                )
                for _ in range(encoder_layers)
            ]
        )

    def forward(
        self,
        emb: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        bsz = emb.size(0)
        seq_len = emb.size(1)
        x = self.emb_layer_norm(emb)
        x = F.dropout(x, p=self.emb_dropout, training=self.training)
        # account for padding while computing the representation
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
        input_attn_mask = attn_mask
        input_padding_mask = padding_mask

        def fill_attn_mask(attn_mask, padding_mask, fill_val=float("-inf")):
            if attn_mask is not None and padding_mask is not None:
                # merge key_padding_mask and attn_mask
                attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len)
                attn_mask.masked_fill_(
                    padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    fill_val,
                )
                attn_mask = attn_mask.view(-1, seq_len, seq_len)
                padding_mask = None
            return attn_mask, padding_mask

        assert attn_mask is not None
        attn_mask, padding_mask = fill_attn_mask(attn_mask, padding_mask)

        for i in range(len(self.layers)):
            x, attn_mask, _ = self.layers[i](
                x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=True
            )

        def norm_loss(x, eps=1e-10, tolerance=1.0):
            x = x.float()
            max_norm = x.shape[-1] ** 0.5
            norm = torch.sqrt(torch.sum(x**2, dim=-1) + eps)
            error = F.relu((norm - max_norm).abs() - tolerance)
            return error

        def masked_mean(mask, value, dim=-1, eps=1e-10):
            return (
                torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
            ).mean()

        x_norm = norm_loss(x)
        if input_padding_mask is not None:
            token_mask = 1.0 - input_padding_mask.float()
        else:
            token_mask = torch.ones_like(x_norm)
        x_norm = masked_mean(token_mask, x_norm)

        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)

        # positions masked to -inf in both tensors yield NaN here; the
        # fill_attn_mask call below overwrites those padded positions with 0
        delta_pair_repr = attn_mask - input_attn_mask
        delta_pair_repr, _ = fill_attn_mask(delta_pair_repr, input_padding_mask, 0)
        attn_mask = (
            attn_mask.view(bsz, -1, seq_len, seq_len).permute(0, 2, 3, 1).contiguous()
        )
        delta_pair_repr = (
            delta_pair_repr.view(bsz, -1, seq_len, seq_len)
            .permute(0, 2, 3, 1)
            .contiguous()
        )

        pair_mask = token_mask[..., None] * token_mask[..., None, :]
        delta_pair_repr_norm = norm_loss(delta_pair_repr)
        delta_pair_repr_norm = masked_mean(
            pair_mask, delta_pair_repr_norm, dim=(-1, -2)
        )

        if self.final_head_layer_norm is not None:
            delta_pair_repr = self.final_head_layer_norm(delta_pair_repr)

        return x, attn_mask, delta_pair_repr, x_norm, delta_pair_repr_norm
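
# Usage sketch (an addition, not part of the original module): the stacked
# encoder requires an attention bias (it asserts attn_mask is not None) and
# returns per-token representations, the final attention map, the delta pair
# representation, and two norm regularization terms. Sizes are illustrative
# assumptions.
def _demo_encoder_with_pair():
    enc = TransformerEncoderWithPair(
        encoder_layers=2, embed_dim=16, ffn_embed_dim=32, attention_heads=4
    )
    emb = torch.randn(2, 5, 16)              # [B, L, embed_dim]
    attn_mask = torch.zeros(2 * 4, 5, 5)     # [B*num_heads, L, L]
    padding_mask = torch.zeros(2, 5).bool()  # no padded positions
    x, attn, delta, x_norm, delta_norm = enc(
        emb, attn_mask=attn_mask, padding_mask=padding_mask
    )
    assert x.shape == (2, 5, 16)
    assert attn.shape == (2, 5, 5, 4)   # [B, L, L, num_heads]
    assert delta.shape == (2, 5, 5, 4)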
def gaussian(x, mean, std):
    """Gaussian probability density evaluated element-wise."""
    a = (2 * math.pi) ** 0.5
    return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
def _get_dist(A, B):
    """Pairwise Euclidean distances between two point sets: [B, L, 3] -> [B, L, L]."""
    return torch.sqrt(torch.sum((A[..., None, :] - B[..., None, :, :]) ** 2, -1) + 1e-6)
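
# Smoke-test runner for the usage sketches above (an addition, not part of the
# original module). Runs each demo forward pass once on CPU.
if __name__ == "__main__":
    torch.manual_seed(0)
    _demo_nonlinear_head()
    _demo_gaussian_layer()
    _demo_gaussian_encoder()
    _demo_self_attention()
    _demo_encoder_layer()
    _demo_encoder_with_pair()
    print("all demo forward passes succeeded")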