| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn import Parameter, Sequential, ModuleList, Linear |
|
|
| from rdkit import Chem |
| from rdkit.Chem import AllChem |
|
|
| from transformers import PretrainedConfig |
| from transformers import PreTrainedModel |
| from transformers import AutoModel |
|
|
| from torch_geometric.data import Data |
| from torch_geometric.loader import DataLoader |
| from torch_geometric.utils import remove_self_loops, add_self_loops, sort_edge_index |
| from torch_scatter import scatter |
| from torch_geometric.nn import global_add_pool, radius |
| from torch_sparse import SparseTensor |
|
|
| from transmxm_model.configuration_transmxm import TransmxmConfig |
|
|
| from tqdm import tqdm |
| import numpy as np |
| import pandas as pd |
| from typing import List |
| import math |
| import inspect |
| from operator import itemgetter |
| from collections import OrderedDict |
| from math import sqrt, pi as PI |
| from scipy.optimize import brentq |
| from scipy import special as sp |
|
|
| try: |
| import sympy as sym |
| except ImportError: |
| sym = None |
|
|
|
|
|
|
| class SmilesDataset(torch.utils.data.Dataset): |
| def __init__(self, smiles): |
| self.smiles_list = smiles |
| self.data_list = [] |
|
|
|
|
| def __len__(self): |
| return len(self.data_list) |
|
|
| def __getitem__(self, idx): |
| return self.data_list[idx] |
|
|
| def get_data(self, smiles): |
| self.smiles_list = smiles |
| |
| |
|         types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'S': 4}  # only these elements are supported (matches the 5-row embedding table) |
|
|
| for i in range(len(self.smiles_list)): |
| |
| |
| mol = Chem.MolFromSmiles(self.smiles_list[i]) |
| if mol is None: |
|                 print("Failed to create an RDKit Mol from SMILES:", self.smiles_list[i]) |
| else: |
|
|
|                 mol3d = Chem.AddHs(mol)  # add explicit hydrogens before 3D embedding |
| if mol3d is None: |
|                     print("Failed to add hydrogens for:", self.smiles_list[i]) |
| else: |
| AllChem.EmbedMolecule(mol3d, randomSeed=1) |
|
|
| N = mol3d.GetNumAtoms() |
| |
| if mol3d.GetNumConformers() > 0: |
| conformer = mol3d.GetConformer() |
| pos = conformer.GetPositions() |
| pos = torch.tensor(pos, dtype=torch.float) |
|
|
| type_idx = [] |
| |
| |
| |
| |
| |
| for atom in mol3d.GetAtoms(): |
| type_idx.append(types[atom.GetSymbol()]) |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| row, col, edge_type = [], [], [] |
| for bond in mol3d.GetBonds(): |
| start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() |
| row += [start, end] |
| col += [end, start] |
| |
|
|
| edge_index = torch.tensor([row, col], dtype=torch.long) |
| |
| |
|
|
| perm = (edge_index[0] * N + edge_index[1]).argsort() |
| edge_index = edge_index[:, perm] |
| |
| |
| |
| |
| |
|
|
| x = torch.tensor(type_idx).to(torch.float) |
|
|
| |
|
|
| data = Data(x=x, pos=pos, edge_index=edge_index, smiles=self.smiles_list[i]) |
|
|
| self.data_list.append(data) |
| else: |
|                         print("Failed to generate a 3D conformer for:", self.smiles_list[i]) |
| return self.data_list |
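|
|
| # Illustrative usage sketch (the SMILES below are placeholders; only H/C/N/O/S atoms are supported): |
| # |
| #     dataset = SmilesDataset(smiles=[]) |
| #     data_list = dataset.get_data(["CCO", "c1ccc(N)cc1"])   # list of torch_geometric Data objects |
| #     loader = DataLoader(data_list, batch_size=2)           # batches carry x, pos, edge_index and smiles |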
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| import copy |
| from typing import Optional |
|
|
| from torch import Tensor |
|
|
|
|
| class PositionEmbeddingSine(nn.Module): |
| """ |
|     Sinusoidal position embedding, very similar to the one used in the "Attention Is All You Need" |
|     paper, adapted here to 1D sequences (the DETR version generalizes the same idea to images). |
| """ |
| def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): |
| super().__init__() |
| self.num_pos_feats = num_pos_feats |
| self.temperature = temperature |
| self.normalize = normalize |
| if scale is not None and normalize is False: |
| raise ValueError("normalize should be True if scale is passed") |
| if scale is None: |
| scale = 2 * math.pi |
| self.scale = scale |
|
|
| def forward(self, x, mask): |
|         """ |
|         Args: |
|             x: torch.tensor, (batch_size, L, d) |
|             mask: torch.tensor, (batch_size, L), with 1 as valid |
| |
|         Returns: |
|             pos_x: torch.tensor, (batch_size, L, num_pos_feats) sinusoidal position encoding |
|         """ |
| assert mask is not None |
| x_embed = mask.cumsum(1, dtype=torch.float32) |
| if self.normalize: |
| eps = 1e-6 |
| x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale |
|
|
| dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) |
| |
| dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats) |
| pos_x = x_embed[:, :, None] / dim_t |
| pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) |
| |
| return pos_x |
|
|
| def build_position_encoding(x): |
| N_steps = x |
| pos_embed = PositionEmbeddingSine(N_steps, normalize=True) |
|
|
| return pos_embed |
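|
|
| # Illustrative sketch (hypothetical shapes): a 128-dim sinusoidal encoding for 2 sequences of length 10. |
| # |
| #     pos_embed = build_position_encoding(128) |
| #     feats = torch.randn(2, 10, 128) |
| #     valid = torch.ones(2, 10)              # 1 marks valid positions |
| #     pe = pos_embed(feats, valid)           # -> (2, 10, 128) |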
|
|
|
|
| class Transformer(nn.Module): |
|
|
| def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, |
| num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, |
| activation="relu", normalize_before=False): |
| super().__init__() |
|
|
| |
| encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, |
| dropout, activation, normalize_before) |
| encoder_norm = nn.LayerNorm(d_model) if normalize_before else None |
| self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) |
|
|
| self._reset_parameters() |
|
|
| self.d_model = d_model |
| self.nhead = nhead |
|
|
| def _reset_parameters(self): |
| for p in self.parameters(): |
| if p.dim() > 1: |
| nn.init.xavier_uniform_(p) |
|
|
| def forward(self, src, mask, att_mask, pos_embed): |
|         """ |
|         Args: |
|             src: (batch_size, L, d) |
|             mask: (batch_size, L), key padding mask (True marks padded positions) |
|             att_mask: (L, L) attention mask for the encoder self-attention |
|             pos_embed: (batch_size, L, d), same shape as src |
| |
|         Returns: |
|             memory: (batch_size, L, d) encoder output |
|         """ |
| src = src.permute(1, 0, 2) |
| pos_embed = pos_embed.permute(1, 0, 2) |
|
|
| memory = self.encoder( |
| src, |
| mask=att_mask, |
| src_key_padding_mask=mask, |
| pos=pos_embed |
| ) |
|
|
| memory = memory.transpose(0, 1) |
| return memory |
|
|
|
|
| class TransformerEncoder(nn.Module): |
|
|
| def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False): |
| super().__init__() |
| self.layers = _get_clones(encoder_layer, num_layers) |
| self.num_layers = num_layers |
| self.norm = norm |
| self.return_intermediate = return_intermediate |
|
|
| def forward(self, src, |
| mask: Optional[Tensor] = None, |
| src_key_padding_mask: Optional[Tensor] = None, |
| pos: Optional[Tensor] = None): |
| output = src |
|
|
| intermediate = [] |
|
|
| for layer in self.layers: |
| output = layer(output, src_mask=mask, |
| src_key_padding_mask=src_key_padding_mask, pos=pos) |
| if self.return_intermediate: |
| intermediate.append(output) |
|
|
| if self.norm is not None: |
| output = self.norm(output) |
|
|
| if self.return_intermediate: |
| return torch.stack(intermediate) |
|
|
| return output |
|
|
|
|
| class TransformerEncoderLayer(nn.Module): |
|
|
| def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, |
| activation="relu", normalize_before=False): |
| super().__init__() |
| self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
| |
| self.linear1 = nn.Linear(d_model, dim_feedforward) |
| self.dropout = nn.Dropout(dropout) |
| self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
|
| self.norm1 = nn.LayerNorm(d_model) |
| self.norm2 = nn.LayerNorm(d_model) |
| self.dropout1 = nn.Dropout(dropout) |
| self.dropout2 = nn.Dropout(dropout) |
|
|
| self.activation = _get_activation_fn(activation) |
| self.normalize_before = normalize_before |
|
|
| def with_pos_embed(self, tensor, pos: Optional[Tensor]): |
| return tensor if pos is None else tensor + pos |
|
|
| def forward_post(self, |
| src, |
| src_mask: Optional[Tensor] = None, |
| src_key_padding_mask: Optional[Tensor] = None, |
| pos: Optional[Tensor] = None): |
| q = k = self.with_pos_embed(src, pos) |
| src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, |
| key_padding_mask=src_key_padding_mask)[0] |
| src = src + self.dropout1(src2) |
| src = self.norm1(src) |
| src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) |
| src = src + self.dropout2(src2) |
| src = self.norm2(src) |
| return src |
|
|
| def forward_pre(self, src, |
| src_mask: Optional[Tensor] = None, |
| src_key_padding_mask: Optional[Tensor] = None, |
| pos: Optional[Tensor] = None): |
| src2 = self.norm1(src) |
| q = k = self.with_pos_embed(src2, pos) |
| src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, |
| key_padding_mask=src_key_padding_mask)[0] |
| src = src + self.dropout1(src2) |
| src2 = self.norm2(src) |
| src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) |
| src = src + self.dropout2(src2) |
| return src |
|
|
| def forward(self, src, |
| src_mask: Optional[Tensor] = None, |
| src_key_padding_mask: Optional[Tensor] = None, |
| pos: Optional[Tensor] = None): |
| if self.normalize_before: |
| return self.forward_pre(src, src_mask, src_key_padding_mask, pos) |
| return self.forward_post(src, src_mask, src_key_padding_mask, pos) |
|
|
|
|
| def _get_clones(module, N): |
| return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) |
|
|
|
|
| def build_transformer(x): |
| return Transformer( |
| d_model=x, |
| dropout=0.5, |
| nhead=8, |
| dim_feedforward=1024, |
| num_encoder_layers=2, |
| normalize_before=True, |
| ) |
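|
|
| # Illustrative sketch (hypothetical shapes): running the 2-layer pre-norm encoder built above |
| # over a batch of 2 sequences of length 10. |
| # |
| #     model = build_transformer(128) |
| #     src = torch.randn(2, 10, 128) |
| #     pad_mask = torch.zeros(2, 10, dtype=torch.bool)        # True marks padded positions |
| #     pos = build_position_encoding(128)(src, ~pad_mask)     # sinusoidal positions for the valid tokens |
| #     out = model(src, pad_mask, None, pos)                  # -> (2, 10, 128) |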
|
|
|
|
| def _get_activation_fn(activation): |
| """Return an activation function given a string""" |
| if activation == "relu": |
| return F.relu |
| if activation == "gelu": |
| return F.gelu |
| if activation == "glu": |
| return F.glu |
|     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") |
|
|
|
|
|
|
| class EMA: |
| def __init__(self, model, decay): |
| self.decay = decay |
| self.shadow = {} |
| self.original = {} |
|
|
| |
| for name, param in model.named_parameters(): |
| if param.requires_grad: |
| self.shadow[name] = param.data.clone() |
|
|
| def __call__(self, model, num_updates=99999): |
| decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates)) |
| for name, param in model.named_parameters(): |
| if param.requires_grad: |
| assert name in self.shadow |
| new_average = \ |
| (1.0 - decay) * param.data + decay * self.shadow[name] |
| self.shadow[name] = new_average.clone() |
|
|
| def assign(self, model): |
| for name, param in model.named_parameters(): |
| if param.requires_grad: |
| assert name in self.shadow |
| self.original[name] = param.data.clone() |
| param.data = self.shadow[name] |
|
|
| def resume(self, model): |
| for name, param in model.named_parameters(): |
| if param.requires_grad: |
| assert name in self.shadow |
| param.data = self.original[name] |
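|
|
| # Illustrative sketch of the intended EMA workflow (train_step, eval_fn and loader are hypothetical): |
| # |
| #     ema = EMA(model, decay=0.999) |
| #     for step, batch in enumerate(loader): |
| #         train_step(model, batch) |
| #         ema(model, num_updates=step)       # update the shadow (averaged) weights |
| #     ema.assign(model)                      # swap the averaged weights in for evaluation |
| #     eval_fn(model) |
| #     ema.resume(model)                      # restore the original weights before training resumes |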
|
|
|
|
| def MLP(channels): |
| return Sequential(*[ |
| Sequential(Linear(channels[i - 1], channels[i]), SiLU()) |
| for i in range(1, len(channels))]) |
|
|
|
|
| class Res(nn.Module): |
| def __init__(self, dim): |
| super(Res, self).__init__() |
|
|
| self.mlp = MLP([dim, dim, dim]) |
|
|
| def forward(self, m): |
| m1 = self.mlp(m) |
| m_out = m1 + m |
| return m_out |
|
|
|
|
| def compute_idx(pos, edge_index): |
|
|
| pos_i = pos[edge_index[0]] |
| pos_j = pos[edge_index[1]] |
|
|
| d_ij = torch.norm(abs(pos_j - pos_i), dim=-1, keepdim=False).unsqueeze(-1) + 1e-5 |
| v_ji = (pos_i - pos_j) / d_ij |
|
|
| unique, counts = torch.unique(edge_index[0], sorted=True, return_counts=True) |
|     full_index = torch.arange(0, edge_index[0].size(0), device=edge_index.device).int() |
| |
|
|
| |
| repeat = torch.repeat_interleave(counts, counts) |
| counts_repeat1 = torch.repeat_interleave(full_index, repeat) |
|
|
| |
| split = torch.split(full_index, counts.tolist()) |
| index2 = list(edge_index[0].data.cpu().numpy()) |
| counts_repeat2 = torch.cat(itemgetter(*index2)(split), dim=0) |
|
|
| |
| v1 = v_ji[counts_repeat1.long()] |
| v2 = v_ji[counts_repeat2.long()] |
|
|
| angle = (v1*v2).sum(-1).unsqueeze(-1) |
| angle = torch.clamp(angle, min=-1.0, max=1.0) + 1e-6 + 1.0 |
|
|
| return counts_repeat1.long(), counts_repeat2.long(), angle |
|
|
|
|
| def Jn(r, n): |
| return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) |
|
|
|
|
| def Jn_zeros(n, k): |
| zerosj = np.zeros((n, k), dtype='float32') |
| zerosj[0] = np.arange(1, k + 1) * np.pi |
| points = np.arange(1, k + n) * np.pi |
| racines = np.zeros(k + n - 1, dtype='float32') |
| for i in range(1, n): |
| for j in range(k + n - 1 - i): |
| foo = brentq(Jn, points[j], points[j + 1], (i, )) |
| racines[j] = foo |
| points = racines |
| zerosj[i][:k] = racines[:k] |
|
|
| return zerosj |
|
|
|
|
| def spherical_bessel_formulas(n): |
| x = sym.symbols('x') |
|
|
| f = [sym.sin(x) / x] |
| a = sym.sin(x) / x |
| for i in range(1, n): |
| b = sym.diff(a, x) / x |
| f += [sym.simplify(b * (-x)**i)] |
| a = sym.simplify(b) |
| return f |
|
|
|
|
| def bessel_basis(n, k): |
| zeros = Jn_zeros(n, k) |
| normalizer = [] |
| for order in range(n): |
| normalizer_tmp = [] |
| for i in range(k): |
| normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2] |
| normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5 |
| normalizer += [normalizer_tmp] |
|
|
| f = spherical_bessel_formulas(n) |
| x = sym.symbols('x') |
| bess_basis = [] |
| for order in range(n): |
| bess_basis_tmp = [] |
| for i in range(k): |
| bess_basis_tmp += [ |
| sym.simplify(normalizer[order][i] * |
| f[order].subs(x, zeros[order, i] * x)) |
| ] |
| bess_basis += [bess_basis_tmp] |
| return bess_basis |
|
|
|
|
| def sph_harm_prefactor(k, m): |
|     return ((2 * k + 1) * math.factorial(k - abs(m)) / |
|             (4 * np.pi * math.factorial(k + abs(m))))**0.5 |
|
|
|
|
| def associated_legendre_polynomials(k, zero_m_only=True): |
| z = sym.symbols('z') |
| P_l_m = [[0] * (j + 1) for j in range(k)] |
|
|
| P_l_m[0][0] = 1 |
|     if k > 1: |
|         P_l_m[1][0] = z |
|
|
| for j in range(2, k): |
| P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] - |
| (j - 1) * P_l_m[j - 2][0]) / j) |
| if not zero_m_only: |
| for i in range(1, k): |
| P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1]) |
| if i + 1 < k: |
| P_l_m[i + 1][i] = sym.simplify( |
| (2 * i + 1) * z * P_l_m[i][i]) |
| for j in range(i + 2, k): |
| P_l_m[j][i] = sym.simplify( |
| ((2 * j - 1) * z * P_l_m[j - 1][i] - |
| (i + j - 1) * P_l_m[j - 2][i]) / (j - i)) |
|
|
| return P_l_m |
|
|
|
|
| def real_sph_harm(k, zero_m_only=True, spherical_coordinates=True): |
| if not zero_m_only: |
| S_m = [0] |
| C_m = [1] |
| for i in range(1, k): |
| x = sym.symbols('x') |
| y = sym.symbols('y') |
| S_m += [x * S_m[i - 1] + y * C_m[i - 1]] |
| C_m += [x * C_m[i - 1] - y * S_m[i - 1]] |
|
|
| P_l_m = associated_legendre_polynomials(k, zero_m_only) |
| if spherical_coordinates: |
| theta = sym.symbols('theta') |
| z = sym.symbols('z') |
| for i in range(len(P_l_m)): |
| for j in range(len(P_l_m[i])): |
| if type(P_l_m[i][j]) != int: |
| P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) |
| if not zero_m_only: |
| phi = sym.symbols('phi') |
| for i in range(len(S_m)): |
| S_m[i] = S_m[i].subs(x, |
| sym.sin(theta) * sym.cos(phi)).subs( |
| y, |
| sym.sin(theta) * sym.sin(phi)) |
| for i in range(len(C_m)): |
| C_m[i] = C_m[i].subs(x, |
| sym.sin(theta) * sym.cos(phi)).subs( |
| y, |
| sym.sin(theta) * sym.sin(phi)) |
|
|
| Y_func_l_m = [['0'] * (2 * j + 1) for j in range(k)] |
| for i in range(k): |
| Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) |
|
|
| if not zero_m_only: |
| for i in range(1, k): |
| for j in range(1, i + 1): |
| Y_func_l_m[i][j] = sym.simplify( |
| 2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) |
| for i in range(1, k): |
| for j in range(1, i + 1): |
| Y_func_l_m[i][-j] = sym.simplify( |
| 2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) |
|
|
| return Y_func_l_m |
|
|
|
|
| class BesselBasisLayer(torch.nn.Module): |
| def __init__(self, num_radial, cutoff, envelope_exponent=6): |
| super(BesselBasisLayer, self).__init__() |
| self.cutoff = cutoff |
| self.envelope = Envelope(envelope_exponent) |
|
|
| self.freq = torch.nn.Parameter(torch.Tensor(num_radial)) |
|
|
| self.reset_parameters() |
|
|
| def reset_parameters(self): |
| |
| |
| |
|
|
| |
| tmp_tensor = torch.arange(1, self.freq.numel() + 1, dtype=self.freq.dtype, device=self.freq.device) |
|
|
| |
| self.freq.data = torch.mul(tmp_tensor, PI) |
|
|
| def forward(self, dist): |
| dist = dist.unsqueeze(-1) / self.cutoff |
| return self.envelope(dist) * (self.freq * dist).sin() |
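|
|
| # Illustrative sketch (hypothetical values): expanding 8 pairwise distances into 16 radial basis functions. |
| # |
| #     rbf_layer = BesselBasisLayer(num_radial=16, cutoff=5.0) |
| #     d = torch.rand(8) * 5.0                # distances inside the cutoff |
| #     rbf = rbf_layer(d)                     # -> (8, 16) |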
|
|
|
|
| class SiLU(nn.Module): |
| def __init__(self): |
| super().__init__() |
|
|
| def forward(self, input): |
| return silu(input) |
|
|
|
|
| def silu(input): |
| return input * torch.sigmoid(input) |
|
|
|
|
| class Envelope(torch.nn.Module): |
| def __init__(self, exponent): |
| super(Envelope, self).__init__() |
| self.p = exponent |
| self.a = -(self.p + 1) * (self.p + 2) / 2 |
| self.b = self.p * (self.p + 2) |
| self.c = -self.p * (self.p + 1) / 2 |
|
|
| def forward(self, x): |
| p, a, b, c = self.p, self.a, self.b, self.c |
| x_pow_p0 = x.pow(p) |
| x_pow_p1 = x_pow_p0 * x |
| env_val = 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p1 * x |
|
|
| zero = torch.zeros_like(x) |
| return torch.where(x < 1, env_val, zero) |
|
|
|
|
| class SphericalBasisLayer(torch.nn.Module): |
| def __init__(self, num_spherical, num_radial, cutoff=5.0, |
| envelope_exponent=5): |
| super(SphericalBasisLayer, self).__init__() |
| assert num_radial <= 64 |
| self.num_spherical = num_spherical |
| self.num_radial = num_radial |
| self.cutoff = cutoff |
| self.envelope = Envelope(envelope_exponent) |
|
|
| bessel_forms = bessel_basis(num_spherical, num_radial) |
| sph_harm_forms = real_sph_harm(num_spherical) |
| self.sph_funcs = [] |
| self.bessel_funcs = [] |
|
|
| x, theta = sym.symbols('x theta') |
| modules = {'sin': torch.sin, 'cos': torch.cos} |
| for i in range(num_spherical): |
| if i == 0: |
| sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0) |
| self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1) |
| else: |
| sph = sym.lambdify([theta], sph_harm_forms[i][0], modules) |
| self.sph_funcs.append(sph) |
| for j in range(num_radial): |
| bessel = sym.lambdify([x], bessel_forms[i][j], modules) |
| self.bessel_funcs.append(bessel) |
|
|
| def forward(self, dist, angle, idx_kj): |
| dist = dist / self.cutoff |
| rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1) |
| rbf = self.envelope(dist).unsqueeze(-1) * rbf |
|
|
| cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1) |
|
|
| n, k = self.num_spherical, self.num_radial |
| out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k) |
| return out |
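|
|
| # Illustrative sketch (hypothetical values): combined radial/angular features for 4 triplets |
| # built from 10 edges (requires sympy). |
| # |
| #     sbf_layer = SphericalBasisLayer(num_spherical=7, num_radial=6, cutoff=5.0) |
| #     d = torch.rand(10) * 5.0               # edge distances |
| #     angle = torch.rand(4) * PI             # triplet angles |
| #     idx_kj = torch.randint(0, 10, (4,))    # index of the k->j edge for each triplet |
| #     out = sbf_layer(d, angle, idx_kj)      # -> (4, 7 * 6) |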
|
|
|
|
|
|
| msg_special_args = set([ |
| 'edge_index', |
| 'edge_index_i', |
| 'edge_index_j', |
| 'size', |
| 'size_i', |
| 'size_j', |
| ]) |
|
|
| aggr_special_args = set([ |
| 'index', |
| 'dim_size', |
| ]) |
|
|
| update_special_args = set([]) |
|
|
|
|
| class MessagePassing(torch.nn.Module): |
| r"""Base class for creating message passing layers |
| |
| .. math:: |
| \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, |
| \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} |
| \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right), |
| |
| where :math:`\square` denotes a differentiable, permutation invariant |
| function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}` |
| and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as |
| MLPs. |
| See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/ |
| create_gnn.html>`__ for the accompanying tutorial. |
| |
| Args: |
| aggr (string, optional): The aggregation scheme to use |
| (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`). |
| (default: :obj:`"add"`) |
| flow (string, optional): The flow direction of message passing |
| (:obj:`"source_to_target"` or :obj:`"target_to_source"`). |
| (default: :obj:`"source_to_target"`) |
| node_dim (int, optional): The axis along which to propagate. |
| (default: :obj:`0`) |
| """ |
| def __init__(self, aggr='add', flow='target_to_source', node_dim=0): |
| super(MessagePassing, self).__init__() |
|
|
| self.aggr = aggr |
| assert self.aggr in ['add', 'mean', 'max'] |
|
|
| self.flow = flow |
| assert self.flow in ['source_to_target', 'target_to_source'] |
|
|
| self.node_dim = node_dim |
| assert self.node_dim >= 0 |
|
|
| self.__msg_params__ = inspect.signature(self.message).parameters |
| self.__msg_params__ = OrderedDict(self.__msg_params__) |
|
|
| self.__aggr_params__ = inspect.signature(self.aggregate).parameters |
| self.__aggr_params__ = OrderedDict(self.__aggr_params__) |
| self.__aggr_params__.popitem(last=False) |
|
|
| self.__update_params__ = inspect.signature(self.update).parameters |
| self.__update_params__ = OrderedDict(self.__update_params__) |
| self.__update_params__.popitem(last=False) |
|
|
| msg_args = set(self.__msg_params__.keys()) - msg_special_args |
| aggr_args = set(self.__aggr_params__.keys()) - aggr_special_args |
| update_args = set(self.__update_params__.keys()) - update_special_args |
|
|
| self.__args__ = set().union(msg_args, aggr_args, update_args) |
|
|
| def __set_size__(self, size, index, tensor): |
| if not torch.is_tensor(tensor): |
| pass |
| elif size[index] is None: |
| size[index] = tensor.size(self.node_dim) |
| elif size[index] != tensor.size(self.node_dim): |
| raise ValueError( |
| (f'Encountered node tensor with size ' |
| f'{tensor.size(self.node_dim)} in dimension {self.node_dim}, ' |
| f'but expected size {size[index]}.')) |
|
|
| def __collect__(self, edge_index, size, kwargs): |
| i, j = (0, 1) if self.flow == "target_to_source" else (1, 0) |
| ij = {"_i": i, "_j": j} |
|
|
| out = {} |
| for arg in self.__args__: |
| if arg[-2:] not in ij.keys(): |
| out[arg] = kwargs.get(arg, inspect.Parameter.empty) |
| else: |
| idx = ij[arg[-2:]] |
| data = kwargs.get(arg[:-2], inspect.Parameter.empty) |
|
|
| if data is inspect.Parameter.empty: |
| out[arg] = data |
| continue |
|
|
| if isinstance(data, tuple) or isinstance(data, list): |
| assert len(data) == 2 |
| self.__set_size__(size, 1 - idx, data[1 - idx]) |
| data = data[idx] |
|
|
| if not torch.is_tensor(data): |
| out[arg] = data |
| continue |
|
|
| self.__set_size__(size, idx, data) |
| out[arg] = data.index_select(self.node_dim, edge_index[idx]) |
|
|
| size[0] = size[1] if size[0] is None else size[0] |
| size[1] = size[0] if size[1] is None else size[1] |
|
|
| |
| out['edge_index'] = edge_index |
| out['edge_index_i'] = edge_index[i] |
| out['edge_index_j'] = edge_index[j] |
| out['size'] = size |
| out['size_i'] = size[i] |
| out['size_j'] = size[j] |
|
|
| |
| out['index'] = out['edge_index_i'] |
| out['dim_size'] = out['size_i'] |
|
|
| return out |
|
|
| def __distribute__(self, params, kwargs): |
| out = {} |
| for key, param in params.items(): |
| data = kwargs[key] |
| if data is inspect.Parameter.empty: |
| if param.default is inspect.Parameter.empty: |
| raise TypeError(f'Required parameter {key} is empty.') |
| data = param.default |
| out[key] = data |
| return out |
|
|
| def propagate(self, edge_index, size=None, **kwargs): |
| r"""The initial call to start propagating messages. |
| |
| Args: |
| edge_index (Tensor): The indices of a general (sparse) assignment |
| matrix with shape :obj:`[N, M]` (can be directed or |
| undirected). |
| size (list or tuple, optional): The size :obj:`[N, M]` of the |
| assignment matrix. If set to :obj:`None`, the size will be |
| automatically inferred and assumed to be quadratic. |
| (default: :obj:`None`) |
| **kwargs: Any additional data which is needed to construct and |
| aggregate messages, and to update node embeddings. |
| """ |
|
|
| size = [None, None] if size is None else size |
| size = [size, size] if isinstance(size, int) else size |
| size = size.tolist() if torch.is_tensor(size) else size |
| size = list(size) if isinstance(size, tuple) else size |
| assert isinstance(size, list) |
| assert len(size) == 2 |
|
|
| kwargs = self.__collect__(edge_index, size, kwargs) |
|
|
| msg_kwargs = self.__distribute__(self.__msg_params__, kwargs) |
|
|
| m = self.message(**msg_kwargs) |
| aggr_kwargs = self.__distribute__(self.__aggr_params__, kwargs) |
| m = self.aggregate(m, **aggr_kwargs) |
|
|
| update_kwargs = self.__distribute__(self.__update_params__, kwargs) |
| m = self.update(m, **update_kwargs) |
|
|
| return m |
|
|
| def message(self, x_j): |
| r"""Constructs messages to node :math:`i` in analogy to |
| :math:`\phi_{\mathbf{\Theta}}` for each edge in |
| :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and |
| :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`. |
| Can take any argument which was initially passed to :meth:`propagate`. |
| In addition, tensors passed to :meth:`propagate` can be mapped to the |
| respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or |
| :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`. |
| """ |
|
|
| return x_j |
|
|
| def aggregate(self, inputs, index, dim_size): |
| r"""Aggregates messages from neighbors as |
| :math:`\square_{j \in \mathcal{N}(i)}`. |
| |
| By default, delegates call to scatter functions that support |
| "add", "mean" and "max" operations specified in :meth:`__init__` by |
| the :obj:`aggr` argument. |
| """ |
|
|
| return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) |
|
|
| def update(self, inputs): |
| r"""Updates node embeddings in analogy to |
| :math:`\gamma_{\mathbf{\Theta}}` for each node |
| :math:`i \in \mathcal{V}`. |
| Takes in the output of aggregation as first argument and any argument |
| which was initially passed to :meth:`propagate`. |
| """ |
|
|
| return inputs |
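|
|
| # Illustrative sketch: a minimal subclass of the MessagePassing base above that simply sums |
| # neighbour features (SumConv is hypothetical, not part of the model): |
| # |
| #     class SumConv(MessagePassing): |
| #         def __init__(self): |
| #             super(SumConv, self).__init__(aggr='add') |
| # |
| #         def forward(self, x, edge_index): |
| #             return self.propagate(edge_index, x=x) |
| # |
| #         def message(self, x_j): |
| #             return x_j |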
|
|
| class TransMXMNet(nn.Module): |
| def __init__(self, dim=128, n_layer=6, cutoff=5.0, num_spherical=7, num_radial=6, envelope_exponent=5): |
| super(TransMXMNet, self).__init__() |
|
|
| self.dim = dim |
| self.n_layer = n_layer |
| self.cutoff = cutoff |
|
|
| self.embeddings = nn.Parameter(torch.ones((5, self.dim))) |
|
|
| self.rbf_l = BesselBasisLayer(16, 5, envelope_exponent) |
| self.rbf_g = BesselBasisLayer(16, self.cutoff, envelope_exponent) |
| self.sbf = SphericalBasisLayer(num_spherical, num_radial, 5, envelope_exponent) |
|
|
| self.rbf_g_mlp = MLP([16, self.dim]) |
| self.rbf_l_mlp = MLP([16, self.dim]) |
|
|
| self.sbf_1_mlp = MLP([num_spherical * num_radial, self.dim]) |
| self.sbf_2_mlp = MLP([num_spherical * num_radial, self.dim]) |
|
|
| self.global_layers = torch.nn.ModuleList() |
| for layer in range(self.n_layer): |
| self.global_layers.append(Global_MP(self.dim)) |
|
|
| self.local_layers = torch.nn.ModuleList() |
| for layer in range(self.n_layer): |
| self.local_layers.append(Local_MP(self.dim)) |
|
|
| self.pos_embed = build_position_encoding(self.dim) |
| self.transformer = build_transformer(self.dim) |
|
|
| self.init() |
|
|
| def init(self): |
| stdv = math.sqrt(3) |
| self.embeddings.data.uniform_(-stdv, stdv) |
|
|
| def indices(self, edge_index, num_nodes): |
| row, col = edge_index |
|
|
| value = torch.arange(row.size(0), device=row.device) |
| adj_t = SparseTensor(row=col, col=row, value=value, |
| sparse_sizes=(num_nodes, num_nodes)) |
|
|
| |
| adj_t_row = adj_t[row] |
| num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) |
|
|
| idx_i = col.repeat_interleave(num_triplets) |
| idx_j = row.repeat_interleave(num_triplets) |
| idx_k = adj_t_row.storage.col() |
| mask = idx_i != idx_k |
| idx_i_1, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] |
|
|
| idx_kj = adj_t_row.storage.value()[mask] |
| idx_ji_1 = adj_t_row.storage.row()[mask] |
|
|
| |
| adj_t_col = adj_t[col] |
|
|
| num_pairs = adj_t_col.set_value(None).sum(dim=1).to(torch.long) |
| idx_i_2 = row.repeat_interleave(num_pairs) |
| idx_j1 = col.repeat_interleave(num_pairs) |
| idx_j2 = adj_t_col.storage.col() |
|
|
| idx_ji_2 = adj_t_col.storage.row() |
| idx_jj = adj_t_col.storage.value() |
|
|
| return idx_i_1, idx_j, idx_k, idx_kj, idx_ji_1, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 |
|
|
|
|
| def forward_features(self, data): |
| x = data.x |
| edge_index = data.edge_index |
| pos = data.pos |
| batch = data.batch |
| |
| h = torch.index_select(self.embeddings, 0, x.long()).unsqueeze(0) |
| data_len = torch.bincount(batch) |
| |
| diff_tensor = torch.diff(data_len) |
| indices = torch.nonzero(diff_tensor) + 1 |
| indices[0] = 0 |
|
|
|         att_mask = torch.zeros(len(batch), len(batch), device=x.device) |
|
|
| att_mask[indices[0]:, indices[0]:] = 1 |
| i = 0 |
| for i in range(0, h.size(0) - 1): |
| att_mask[indices[i]:indices[i + 1], indices[i]:indices[i + 1]] = 1 |
| att_mask[indices[i]:indices[-1], indices[i]:indices[-1]] = 1 |
|
|
|         mask = torch.ones(1, len(batch), dtype=torch.bool, device=x.device) |
|
|
|         pos_h = self.pos_embed(h, mask) |
| memory = self.transformer(h, ~mask, att_mask, pos_h) |
| h = memory.squeeze(0) |
|
|
|         # Local layer -------------------------------------------------------------------------- |
| |
| edge_index_l, _ = remove_self_loops(edge_index) |
| j_l, i_l = edge_index_l |
| dist_l = (pos[i_l] - pos[j_l]).pow(2).sum(dim=-1).sqrt() |
|
|
|         # Global layer ------------------------------------------------------------------------- |
| |
| |
| row, col = radius(pos, pos, self.cutoff, batch, batch, max_num_neighbors=500) |
| edge_index_g = torch.stack([row, col], dim=0) |
| edge_index_g, _ = remove_self_loops(edge_index_g) |
| j_g, i_g = edge_index_g |
| dist_g = (pos[i_g] - pos[j_g]).pow(2).sum(dim=-1).sqrt() |
|
|
| |
| idx_i_1, idx_j, idx_k, idx_kj, idx_ji, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 = self.indices(edge_index_l, num_nodes=h.size(0)) |
|
|
| |
| pos_ji_1, pos_kj = pos[idx_j] - pos[idx_i_1], pos[idx_k] - pos[idx_j] |
| a = (pos_ji_1 * pos_kj).sum(dim=-1) |
| b = torch.cross(pos_ji_1, pos_kj).norm(dim=-1) |
| angle_1 = torch.atan2(b, a) |
|
|
| |
| pos_ji_2, pos_jj = pos[idx_j1] - pos[idx_i_2], pos[idx_j2] - pos[idx_j1] |
| a = (pos_ji_2 * pos_jj).sum(dim=-1) |
| b = torch.cross(pos_ji_2, pos_jj).norm(dim=-1) |
| angle_2 = torch.atan2(b, a) |
|
|
| |
| rbf_g = self.rbf_g(dist_g) |
| rbf_l = self.rbf_l(dist_l) |
| sbf_1 = self.sbf(dist_l, angle_1, idx_kj) |
| sbf_2 = self.sbf(dist_l, angle_2, idx_jj) |
|
|
| rbf_g = self.rbf_g_mlp(rbf_g) |
| rbf_l = self.rbf_l_mlp(rbf_l) |
| sbf_1 = self.sbf_1_mlp(sbf_1) |
| sbf_2 = self.sbf_2_mlp(sbf_2) |
|
|
| |
| node_sum = 0 |
|
|
| for layer in range(self.n_layer): |
| h = self.global_layers[layer](h, rbf_g, edge_index_g) |
| h, t = self.local_layers[layer](h, rbf_l, sbf_1, sbf_2, idx_kj, idx_ji, idx_jj, idx_ji_2, edge_index_l) |
| node_sum += t |
|
|
| |
| output = global_add_pool(node_sum, batch) |
| return output.view(-1) |
|
|
| def loss(self, pred, label): |
| pred, label = pred.reshape(-1), label.reshape(-1) |
| return F.mse_loss(pred, label) |
|
|
|
|
| class Global_MP(MessagePassing): |
|
|
| def __init__(self, dim): |
| super(Global_MP, self).__init__() |
| self.dim = dim |
|
|
| self.h_mlp = MLP([self.dim, self.dim]) |
|
|
| self.res1 = Res(self.dim) |
| self.res2 = Res(self.dim) |
| self.res3 = Res(self.dim) |
| self.mlp = MLP([self.dim, self.dim]) |
|
|
| self.x_edge_mlp = MLP([self.dim * 3, self.dim]) |
| self.linear = nn.Linear(self.dim, self.dim, bias=False) |
|
|
| def forward(self, h, edge_attr, edge_index): |
| edge_index, _ = add_self_loops(edge_index, num_nodes=h.size(0)) |
|
|
| res_h = h |
|
|
| |
| h = self.h_mlp(h) |
|
|
| |
| h = self.propagate(edge_index, x=h, num_nodes=h.size(0), edge_attr=edge_attr) |
|
|
| |
| h = self.res1(h) |
| h = self.mlp(h) + res_h |
| h = self.res2(h) |
| h = self.res3(h) |
|
|
| |
| h = self.propagate(edge_index, x=h, num_nodes=h.size(0), edge_attr=edge_attr) |
|
|
| return h |
|
|
| def message(self, x_i, x_j, edge_attr, edge_index, num_nodes): |
| num_edge = edge_attr.size()[0] |
|
|
| x_edge = torch.cat((x_i[:num_edge], x_j[:num_edge], edge_attr), -1) |
| x_edge = self.x_edge_mlp(x_edge) |
|
|
| x_j = torch.cat((self.linear(edge_attr) * x_edge, x_j[num_edge:]), dim=0) |
|
|
| return x_j |
|
|
| def update(self, aggr_out): |
| return aggr_out |
|
|
|
|
| class Local_MP(torch.nn.Module): |
| def __init__(self, dim): |
| super(Local_MP, self).__init__() |
| self.dim = dim |
|
|
| self.h_mlp = MLP([self.dim, self.dim]) |
|
|
| self.mlp_kj = MLP([3 * self.dim, self.dim]) |
| self.mlp_ji_1 = MLP([3 * self.dim, self.dim]) |
| self.mlp_ji_2 = MLP([self.dim, self.dim]) |
| self.mlp_jj = MLP([self.dim, self.dim]) |
|
|
| self.mlp_sbf1 = MLP([self.dim, self.dim, self.dim]) |
| self.mlp_sbf2 = MLP([self.dim, self.dim, self.dim]) |
| self.lin_rbf1 = nn.Linear(self.dim, self.dim, bias=False) |
| self.lin_rbf2 = nn.Linear(self.dim, self.dim, bias=False) |
|
|
| self.res1 = Res(self.dim) |
| self.res2 = Res(self.dim) |
| self.res3 = Res(self.dim) |
|
|
| self.lin_rbf_out = nn.Linear(self.dim, self.dim, bias=False) |
|
|
| self.h_mlp = MLP([self.dim, self.dim]) |
|
|
| self.y_mlp = MLP([self.dim, self.dim, self.dim, self.dim]) |
| self.y_W = nn.Linear(self.dim, 1) |
|
|
| def forward(self, h, rbf, sbf1, sbf2, idx_kj, idx_ji_1, idx_jj, idx_ji_2, edge_index, num_nodes=None): |
| res_h = h |
|
|
| |
| h = self.h_mlp(h) |
|
|
| |
| j, i = edge_index |
| m = torch.cat([h[i], h[j], rbf], dim=-1) |
|
|
| m_kj = self.mlp_kj(m) |
| m_kj = m_kj * self.lin_rbf1(rbf) |
| m_kj = m_kj[idx_kj] * self.mlp_sbf1(sbf1) |
| m_kj = scatter(m_kj, idx_ji_1, dim=0, dim_size=m.size(0), reduce='add') |
|
|
| m_ji_1 = self.mlp_ji_1(m) |
|
|
| m = m_ji_1 + m_kj |
|
|
| |
| m_jj = self.mlp_jj(m) |
| m_jj = m_jj * self.lin_rbf2(rbf) |
| m_jj = m_jj[idx_jj] * self.mlp_sbf2(sbf2) |
| m_jj = scatter(m_jj, idx_ji_2, dim=0, dim_size=m.size(0), reduce='add') |
|
|
| m_ji_2 = self.mlp_ji_2(m) |
|
|
| m = m_ji_2 + m_jj |
|
|
| |
| m = self.lin_rbf_out(rbf) * m |
| h = scatter(m, i, dim=0, dim_size=h.size(0), reduce='add') |
|
|
| |
| h = self.res1(h) |
| h = self.h_mlp(h) + res_h |
| h = self.res2(h) |
| h = self.res3(h) |
|
|
| |
| y = self.y_mlp(h) |
| y = self.y_W(y) |
|
|
| return h, y |
|
|
|
|
| class TransmxmConfig(PretrainedConfig): |
| model_type = "transmxm" |
|
|
| def __init__( |
| self, |
| dim: int=128, |
| n_layer: int=6, |
| cutoff: float=5.0, |
| num_spherical: int=7, |
| num_radial: int=6, |
| envelope_exponent: int=5, |
| |
| smiles: List[str] = None, |
| processor_class: str = "SmilesProcessor", |
| **kwargs, |
| ): |
|
|
| self.dim = dim |
| self.n_layer = n_layer |
| self.cutoff = cutoff |
| self.num_spherical = num_spherical |
| self.num_radial = num_radial |
| self.envelope_exponent = envelope_exponent |
|
|
| self.smiles = smiles |
| self.processor_class = processor_class |
|
|
|
|
| super().__init__(**kwargs) |
|
|
|
|
|
|
| class TransmxmModel(PreTrainedModel): |
| config_class = TransmxmConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
|
|
| self.backbone = TransMXMNet( |
| dim=config.dim, |
| n_layer=config.n_layer, |
| cutoff=config.cutoff, |
| num_spherical=config.num_spherical, |
| num_radial=config.num_radial, |
| envelope_exponent=config.envelope_exponent, |
| ) |
| self.process = SmilesDataset( |
| smiles=config.smiles, |
| ) |
|
|
| self.model = None |
| self.dataset = None |
| self.output = None |
| self.data_loader = None |
| self.pred_data = None |
|
|
| def forward(self, tensor): |
| return self.backbone.forward_features(tensor) |
|
|
| def SmilesProcessor(self, smiles): |
| return self.process.get_data(smiles) |
|
|
|
|
| def predict_smiles(self, smiles, device: str='cpu', result_dir: str='./', **kwargs): |
|
|
|
|
| batch_size = kwargs.pop('batch_size', 1) |
| shuffle = kwargs.pop('shuffle', False) |
| drop_last = kwargs.pop('drop_last', False) |
| num_workers = kwargs.pop('num_workers', 0) |
|
|
| self.model = AutoModel.from_pretrained("Huhujingjing/custom-transmxm", trust_remote_code=True).to(device) |
| self.model.eval() |
|
|
| self.dataset = self.process.get_data(smiles) |
| self.output = "" |
| self.output += ("predicted samples num: {}\n".format(len(self.dataset))) |
| self.output +=("predicted samples:{}\n".format(self.dataset[0])) |
| self.data_loader = DataLoader(self.dataset, |
| batch_size=batch_size, |
| shuffle=shuffle, |
| drop_last=drop_last, |
| num_workers=num_workers |
| ) |
| self.pred_data = { |
| 'smiles': [], |
| 'pred': [] |
| } |
|
|
| for batch in tqdm(self.data_loader): |
| batch = batch.to(device) |
| with torch.no_grad(): |
| self.pred_data['smiles'] += batch['smiles'] |
| self.pred_data['pred'] += self.model(batch).cpu().tolist() |
|
|
|         pred = torch.tensor(self.pred_data['pred']).reshape(-1) |
|         self.pred_data['pred'] = pred.tolist() |
| pred_df = pd.DataFrame(self.pred_data) |
| pred_df['pred'] = pred_df['pred'].apply(lambda x: round(x, 2)) |
| self.output +=('-' * 40 + '\n'+'predicted result: \n'+'{}\n'.format(pred_df)) |
| self.output +=('-' * 40) |
|
|
| pred_df.to_csv(os.path.join(result_dir, 'prediction.csv'), index=False) |
| self.output +=('\nsave predicted result to {}\n'.format(os.path.join(result_dir, 'prediction.csv'))) |
|
|
| return self.output |
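|
|
| # Illustrative usage (downloads the hosted weights from the Hub; the SMILES below are placeholders): |
| # |
| #     model = TransmxmModel(TransmxmConfig()) |
| #     report = model.predict_smiles(["CCO", "CC(=O)O"], device='cpu', batch_size=2) |
| #     print(report)                          # also writes ./prediction.csv |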
|
|
|
|
| if __name__ == "__main__": |
|
|
| transmxm_config = TransmxmConfig.from_pretrained("custom-transmxm") |
|
|
| transmxmd = TransmxmModel(transmxm_config) |
|     # load the trained weights into the backbone (assumes model.pt stores the TransMXMNet state_dict) |
|     transmxmd.backbone.load_state_dict(torch.load(r'G:\Trans_MXM\runs\model.pt')) |
| transmxmd.save_pretrained("custom-transmxm") |
|
|
|
|