| | import argparse |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from Bio.PDB import PDBParser |
| | from rdkit import Chem |
| | import pandas as pd |
| | import random |
| | from torch_scatter import scatter_mean |
| | from openbabel import openbabel |
| | openbabel.obErrorLog.StopLogging() |
| |
|
| | import utils |
| | from lightning_modules import LigandPocketDDPM |
| | from constants import FLOAT_TYPE, INT_TYPE |
| | from analysis.molecule_builder import build_molecule, process_molecule |
| | from analysis.metrics import MoleculeProperties |
| |
|
| |
|
def prepare_from_sdf_files(sdf_files, atom_encoder):
    """Load atom coordinates and one-hot atom types from SDF files.

    Only the first molecule of each file is read, and it is read without
    RDKit sanitization. The per-file results are concatenated along the
    atom dimension in the order the files are given.

    Args:
        sdf_files: Iterable of paths (anything accepted by ``str``) to SDF
            files.
        atom_encoder: Mapping from element symbol to integer class index.

    Returns:
        Tuple ``(coords, one_hot)``: the float coordinates of all atoms and
        the matching one-hot type matrix with ``len(atom_encoder)`` classes.
    """
    n_types = len(atom_encoder)
    coords_per_file = []
    one_hot_per_file = []

    for sdf_path in sdf_files:
        supplier = Chem.SDMolSupplier(str(sdf_path), sanitize=False)
        first_mol = supplier[0]

        positions = first_mol.GetConformer().GetPositions()
        coords_per_file.append(torch.from_numpy(positions).float())

        type_ids = torch.tensor(
            [atom_encoder[atom.GetSymbol()] for atom in first_mol.GetAtoms()]
        )
        one_hot_per_file.append(F.one_hot(type_ids, num_classes=n_types))

    all_coords = torch.cat(coords_per_file, dim=0)
    all_one_hot = torch.cat(one_hot_per_file, dim=0)
    return all_coords, all_one_hot
| |
|
| |
|
def prepare_ligands_from_mols(mols, atom_encoder, device='cpu'):
    """Pack a list of RDKit molecules into one batched ligand dictionary.

    Args:
        mols: Sequence of molecules, each exposing ``GetConformer()`` and
            ``GetAtoms()``.
        atom_encoder: Mapping from element symbol to integer class index.
        device: Device the resulting tensors are moved to.

    Returns:
        Dict with the keys
        ``'x'`` (concatenated atom coordinates),
        ``'one_hot'`` (concatenated one-hot atom types),
        ``'size'`` (number of atoms of each molecule) and
        ``'mask'`` (for every atom, the index of its molecule in ``mols``).
    """
    n_types = len(atom_encoder)
    xs, one_hots, batch_ids, n_atoms = [], [], [], []

    for mol_idx, mol in enumerate(mols):
        positions = mol.GetConformer().GetPositions()
        type_ids = torch.tensor(
            [atom_encoder[atom.GetSymbol()] for atom in mol.GetAtoms()],
            dtype=INT_TYPE,
        )
        count = len(type_ids)

        xs.append(torch.tensor(positions, dtype=FLOAT_TYPE))
        one_hots.append(F.one_hot(type_ids, num_classes=n_types))
        # Every atom carries the index of the molecule it belongs to.
        batch_ids.append(torch.full((count,), mol_idx, dtype=INT_TYPE))
        n_atoms.append(count)

    return {
        'x': torch.cat(xs, dim=0).to(device),
        'one_hot': torch.cat(one_hots, dim=0).to(device),
        'size': torch.tensor(n_atoms, dtype=INT_TYPE).to(device),
        'mask': torch.cat(batch_ids, dim=0).to(device),
    }
| |
|
| |
|
def prepare_ligand_from_pdb(biopython_atoms, atom_encoder):
    """Convert Biopython atoms into coordinate and one-hot type tensors.

    Args:
        biopython_atoms: Sequence of atoms exposing ``get_coord()`` and an
            ``element`` attribute. It is iterated twice, so it must not be a
            one-shot generator.
        atom_encoder: Mapping from element symbol (capitalized, e.g. ``'Cl'``)
            to integer class index.

    Returns:
        Tuple ``(coord, one_hot)`` with one row per atom.
    """
    positions = np.array([atom.get_coord() for atom in biopython_atoms])
    coord = torch.tensor(positions, dtype=FLOAT_TYPE)

    # The element string is passed through str.capitalize() so that it
    # matches the encoder's keys.
    type_ids = torch.tensor(
        [atom_encoder[atom.element.capitalize()] for atom in biopython_atoms]
    )
    one_hot = F.one_hot(type_ids, num_classes=len(atom_encoder))

    return coord, one_hot
| |
|
| |
|
def prepare_substructure(ref_ligand, fix_atoms, pdb_model, lig_type_encoder=None):
    """Build coordinates and one-hot types of the atoms that are kept fixed.

    Two input styles are supported, selected by the first entry of
    ``fix_atoms``:

    * If it ends with ``.sdf``, every entry is treated as an SDF file and the
      atoms are read with ``prepare_from_sdf_files``.
    * Otherwise the entries are atom names of the reference ligand, which is
      looked up in ``pdb_model`` via ``ref_ligand`` given as ``"<chain>:<resi>"``.

    Args:
        ref_ligand: Reference ligand identifier ``"<chain>:<resi>"``. Only
            used when ``fix_atoms`` holds atom names.
        fix_atoms: Non-empty sequence of SDF file paths or of atom names.
        pdb_model: Biopython model indexable by chain id.
        lig_type_encoder: Optional mapping from element symbol to class index.
            When omitted, the function falls back to the module-level
            ``model.lig_type_encoder`` (the original behaviour), which only
            exists when this file is run as a script.

    Returns:
        Tuple ``(coord, one_hot)`` for the fixed atoms.
    """
    if lig_type_encoder is None:
        # Backward compatibility: the original implementation read the
        # encoder from the global ``model`` created in the __main__ section.
        lig_type_encoder = model.lig_type_encoder

    # str() lets callers pass pathlib.Path objects as well as plain strings.
    if str(fix_atoms[0]).endswith(".sdf"):
        coord, one_hot = prepare_from_sdf_files(fix_atoms, lig_type_encoder)
    else:
        chain, resi = ref_ligand.split(':')
        ligand = utils.get_residue_with_resi(pdb_model[chain], int(resi))

        # Build the lookup set once instead of once per atom.
        wanted_names = set(fix_atoms)
        fixed_atoms = [a for a in ligand.get_atoms()
                       if a.get_name() in wanted_names]
        coord, one_hot = prepare_ligand_from_pdb(fixed_atoms, lig_type_encoder)

    return coord, one_hot
| |
|
| |
|
def diversify_ligands(model, pocket, mols, timesteps,
                      sanitize=False,
                      largest_frag=False,
                      relax_iter=0):
    """
    Generate variations of the given ligands inside a pocket.

    The ligands are batched, passed through ``model.ddpm.diversify`` and the
    outputs are translated so that the pocket's centre of mass is the same as
    before the call. Each output ligand is then rebuilt as a molecule and
    post-processed; molecules for which post-processing returns ``None`` are
    dropped, so the result may be shorter than ``mols``.

    Parameters:
        model: Model providing ``lig_type_encoder``, ``device``, ``x_dims``,
            ``dataset_info`` and ``ddpm.diversify``.
        pocket: Dict with at least the keys ``'x'`` and ``'mask'``.
        mols: List of RDKit molecule objects to be diversified.
        timesteps: Passed to ``model.ddpm.diversify`` as ``noising_steps``.
        sanitize: Forwarded to ``process_molecule`` (default: False).
        largest_frag: Forwarded to ``process_molecule`` (default: False).
        relax_iter: Forwarded to ``process_molecule`` (default: 0).

    Returns:
        A list of diversified RDKit molecule objects.
    """
    ligand = prepare_ligands_from_mols(
        mols, model.lig_type_encoder, device=model.device)

    lig_mask = ligand['mask']
    pocket_mask = pocket['mask']

    # Remember where each pocket was centred before the model moves it.
    com_before = scatter_mean(pocket['x'], pocket['mask'], dim=0)

    out_lig, out_pocket, _, _ = model.ddpm.diversify(
        ligand, pocket, noising_steps=timesteps)

    # Shift pocket and ligand back so the pocket centre of mass is restored.
    com_after = scatter_mean(out_pocket[:, :model.x_dims], pocket_mask, dim=0)
    shift = com_before - com_after
    out_pocket[:, :model.x_dims] += shift[pocket_mask]
    out_lig[:, :model.x_dims] += shift[lig_mask]

    # Split the ligand output into coordinates and discrete atom types.
    coords = out_lig[:, :model.x_dims].detach().cpu()
    type_ids = out_lig[:, model.x_dims:].argmax(1).detach().cpu()

    coords_per_mol = utils.batch_to_list(coords, lig_mask)
    types_per_mol = utils.batch_to_list(type_ids, lig_mask)

    molecules = []
    for mol_coords, mol_types in zip(coords_per_mol, types_per_mol):
        built = build_molecule(mol_coords, mol_types, model.dataset_info,
                               add_coords=True)
        processed = process_molecule(built,
                                     add_hydrogens=False,
                                     sanitize=sanitize,
                                     relax_iter=relax_iter,
                                     largest_frag=largest_frag)
        if processed is None:
            continue
        molecules.append(processed)

    return molecules
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: evolutionary optimisation of a reference ligand.
    # Each generation takes the best-scoring molecules of the previous one,
    # replicates them up to the population size, diversifies them with the
    # model and records every resulting molecule in `buffer`.
    # NOTE: kept flat (not wrapped in a main() function) so that `model`
    # stays a module-level global, which prepare_substructure() falls back to.

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=Path, default='checkpoints/crossdocked_fullatom_cond.ckpt')
    parser.add_argument('--pdbfile', type=str, default='example/5ndu.pdb')
    parser.add_argument('--ref_ligand', type=str, default='example/5ndu_linked_mols.sdf')
    # A tuple (not a set) keeps the order of choices in --help deterministic.
    parser.add_argument('--objective', type=str, default='sa', choices=('qed', 'sa'))
    parser.add_argument('--timesteps', type=int, default=100)
    parser.add_argument('--population_size', type=int, default=100)
    parser.add_argument('--evolution_steps', type=int, default=10)
    parser.add_argument('--top_k', type=int, default=7)
    parser.add_argument('--outfile', type=Path, default='output.sdf')
    parser.add_argument('--relax', action='store_true')

    args = parser.parse_args()

    population_size = args.population_size
    evolution_steps = args.evolution_steps
    top_k = args.top_k

    # Reject values that would otherwise fail deep inside the loop
    # (division by zero / sampling from an empty list).
    if population_size < 1:
        parser.error("--population_size must be at least 1")
    if top_k < 1:
        parser.error("--top_k must be at least 1")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the model
    model = LigandPocketDDPM.load_from_checkpoint(
        args.checkpoint, map_location=device)
    model = model.to(device)

    # Prepare the pocket once; it is repeated `population_size` times, so
    # every call to diversify_ligands must receive exactly that many ligands.
    pdb_model = PDBParser(QUIET=True).get_structure('', args.pdbfile)[0]
    residues = utils.get_pocket_from_ligand(pdb_model, args.ref_ligand)
    pocket = model.prepare_pocket(residues, repeats=population_size)

    if args.objective == 'qed':
        objective_function = MoleculeProperties().calculate_qed
    elif args.objective == 'sa':
        objective_function = MoleculeProperties().calculate_sa
    else:
        # Unreachable through argparse `choices`; kept as a safeguard.
        raise ValueError(f"Objective function {args.objective} not recognized.")

    ref_mol = Chem.SDMolSupplier(args.ref_ligand)[0]
    if ref_mol is None:
        raise ValueError(f"Could not read a molecule from {args.ref_ligand}.")

    # History of every molecule seen, seeded with the reference ligand.
    columns = ['generation', 'score', 'fate', 'mol', 'smiles']
    buffer = pd.DataFrame([{'generation': 0,
                            'score': objective_function(ref_mol),
                            'fate': 'initial',
                            'mol': ref_mol,
                            'smiles': Chem.MolToSmiles(ref_mol)}],
                          columns=columns)

    # Defined up front so the final write also works with 0 evolution steps.
    molecules = [ref_mol]

    for generation_idx in range(evolution_steps):

        if generation_idx == 0:
            parents = buffer['mol'].tolist()
        else:
            # Molecules produced by the previous iteration carry the label
            # `generation_idx` (they were stored as generation_idx - 1 + 1).
            previous_gen = buffer[buffer['generation'] == generation_idx]
            top_idx = previous_gen['score'].astype(float).nlargest(top_k).index
            parents = buffer.loc[top_idx, 'mol'].tolist()

            # Only the selected molecules survive; the rest stay 'purged'.
            buffer.loc[top_idx, 'fate'] = 'survived'

        if not parents:
            raise RuntimeError(
                f"Generation {generation_idx} has no valid molecules to continue from.")

        # `parents` is ordered best-first, so truncating keeps the best ones
        # when the population is smaller than the number of parents.
        parents = parents[:population_size]

        # Replicate the parents evenly, then pad randomly up to the target size.
        molecules = parents * (population_size // len(parents))
        molecules += random.choices(parents, k=population_size - len(molecules))

        assert len(molecules) == population_size, f"Wrong number of molecules: {len(molecules)} when it should be {population_size}"
        print(f"Generation {generation_idx}, mean score: {np.mean([objective_function(mol) for mol in molecules])}")

        # May return fewer molecules than it was given.
        molecules = diversify_ligands(model,
                                      pocket,
                                      molecules,
                                      timesteps=args.timesteps,
                                      sanitize=True,
                                      relax_iter=(200 if args.relax else 0))

        # Record the new generation in a single concat; DataFrame.append was
        # removed in pandas 2.0 and row-wise appending is quadratic.
        new_rows = [{'generation': generation_idx + 1,
                     'score': objective_function(mol),
                     'fate': 'purged',
                     'mol': mol,
                     'smiles': Chem.MolToSmiles(mol)}
                    for mol in molecules]
        if new_rows:
            buffer = pd.concat([buffer, pd.DataFrame(new_rows, columns=columns)],
                               ignore_index=True)

    # Write the molecules of the last generation.
    utils.write_sdf_file(args.outfile, molecules)

    # DataFrame.drop returns a new frame; write that one so the molecule
    # objects do not end up in the CSV.
    buffer.drop(columns=['mol']).to_csv(args.outfile.with_suffix('.csv'))
| |
|