Spaces:
Running
Running
Commit
·
9dd777e
1
Parent(s):
7ca0099
Setup the spaces app
Browse files- README.md +8 -5
- protac_splitter/__init__.py +11 -0
- protac_splitter/chemoinformatics.py +487 -0
- protac_splitter/data/__init__.py +0 -0
- protac_splitter/data/curation/__init__.py +11 -0
- protac_splitter/data/curation/bond_adjustments.py +407 -0
- protac_splitter/data/curation/curation.py +894 -0
- protac_splitter/data/curation/mapping_utils.py +77 -0
- protac_splitter/data/curation/substructure_extraction.py +586 -0
- protac_splitter/data/generation/__init__.py +11 -0
- protac_splitter/data/generation/functional_groups.py +400 -0
- protac_splitter/data/generation/generation.py +277 -0
- protac_splitter/display_utils.py +199 -0
- protac_splitter/drawing_utils.py +177 -0
- protac_splitter/evaluation.py +495 -0
- protac_splitter/fixing_functions.py +355 -0
- protac_splitter/graphs/README.md +114 -0
- protac_splitter/graphs/__init__.py +0 -0
- protac_splitter/graphs/e3_clustering.py +321 -0
- protac_splitter/graphs/edge_classifier.py +582 -0
- protac_splitter/graphs/edge_features.py +293 -0
- protac_splitter/graphs/splitting_algorithms.py +512 -0
- protac_splitter/graphs/utils.py +67 -0
- protac_splitter/graphs_utils.py +190 -0
- protac_splitter/llms/__init__.py +0 -0
- protac_splitter/llms/data_utils.py +296 -0
- protac_splitter/llms/evaluation.py +169 -0
- protac_splitter/llms/hf_utils.py +36 -0
- protac_splitter/llms/model_utils.py +256 -0
- protac_splitter/llms/training.py +869 -0
- protac_splitter/llms/training_causal_model.py +87 -0
- protac_splitter/llms/training_mlm_model.py +287 -0
- protac_splitter/llms/training_rl_models.py +406 -0
- protac_splitter/protac_cheminformatics.py +120 -0
- protac_splitter/protac_splitter.py +370 -0
- protac_splitter_app.py +351 -0
- requirements.txt +138 -0
README.md
CHANGED
|
@@ -1,14 +1,17 @@
|
|
| 1 |
---
|
| 2 |
-
title: PROTAC
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.35.0
|
| 8 |
-
|
|
|
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
short_description: App to split given PROTACs into their substructures.
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: PROTAC-Splitter
|
| 3 |
+
emoji: ✂️
|
| 4 |
+
colorFrom: green
|
| 5 |
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.35.0
|
| 8 |
+
python_version: 3.10
|
| 9 |
+
app_file: protac_splitter_app.py
|
| 10 |
pinned: false
|
| 11 |
license: mit
|
| 12 |
short_description: App to split given PROTACs into their substructures.
|
| 13 |
---
|
| 14 |
|
| 15 |
+
# PROTAC-Splitter
|
| 16 |
+
|
| 17 |
+
This repository contains a program to split PROTAC molecules into their substructures.
|
protac_splitter/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" PROTAC Splitter package for splitting PROTAC SMILES into substructures."""
|
| 2 |
+
from protac_splitter.protac_splitter import split_protac
|
| 3 |
+
from protac_splitter.fixing_functions import fix_prediction
|
| 4 |
+
from protac_splitter.graphs.splitting_algorithms import split_protac_graph_based
|
| 5 |
+
from protac_splitter.evaluation import (
|
| 6 |
+
check_reassembly,
|
| 7 |
+
split_prediction,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
__version__ = "1.0.0"
|
| 11 |
+
__author__ = "Stefano Ribes and Anders Källberg"
|
protac_splitter/chemoinformatics.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Chemoinformatics utilities for PROTAC Splitter. """
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List, Union, Optional, Literal
|
| 4 |
+
from multiprocessing import Process, Queue
|
| 5 |
+
from hashlib import sha256
|
| 6 |
+
|
| 7 |
+
from rdkit import Chem
|
| 8 |
+
from rdkit.Chem import rdFingerprintGenerator
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def GetSubstructMatchesWorker(q, mol, substruct, useChirality, maxMatches):
|
| 12 |
+
""" Worker function to get substructure matches in a separate process. """
|
| 13 |
+
q.put(list(mol.GetSubstructMatches(
|
| 14 |
+
substruct,
|
| 15 |
+
useChirality=useChirality,
|
| 16 |
+
maxMatches=maxMatches,
|
| 17 |
+
)))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def GetSubstructMatchesWithTimeout(
|
| 21 |
+
mol: Chem.Mol,
|
| 22 |
+
substruct: Chem.Mol,
|
| 23 |
+
useChirality: bool = True,
|
| 24 |
+
maxMatches: int = 50,
|
| 25 |
+
timeout: Union[int, float] = 10,
|
| 26 |
+
) -> Optional[List[List[int]]]:
|
| 27 |
+
""" Get substructure matches with a timeout.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
mol (Chem.Mol): The molecule to search for substructure matches.
|
| 31 |
+
substruct (Chem.Mol): The substructure to search for in the molecule.
|
| 32 |
+
useChirality (bool, optional): Whether to use chirality in the substructure search. Defaults to True.
|
| 33 |
+
maxMatches (int, optional): The maximum number of matches to return. Defaults to 50.
|
| 34 |
+
timeout (int | float, optional): The timeout in seconds. Defaults to 10.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Optional[List[List[int]]]: A list of lists containing the atom indices of the substructure matches. Returns None if the search times out or failed.
|
| 38 |
+
"""
|
| 39 |
+
q = Queue()
|
| 40 |
+
p = Process(
|
| 41 |
+
target=GetSubstructMatchesWorker,
|
| 42 |
+
args=(q, mol, substruct, useChirality, maxMatches),
|
| 43 |
+
)
|
| 44 |
+
p.start()
|
| 45 |
+
p.join(timeout)
|
| 46 |
+
|
| 47 |
+
if p.is_alive():
|
| 48 |
+
p.terminate()
|
| 49 |
+
p.join()
|
| 50 |
+
return None
|
| 51 |
+
return q.get()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def remove_stereo(smiles: str) -> str:
|
| 55 |
+
"""
|
| 56 |
+
Remove stereochemistry from a SMILES string.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
smiles (str): The input SMILES string.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
str: The SMILES string with stereochemistry removed.
|
| 63 |
+
"""
|
| 64 |
+
try:
|
| 65 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 66 |
+
Chem.rdmolops.RemoveStereochemistry(mol)
|
| 67 |
+
return Chem.MolToSmiles(mol)
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logging.warning(f"Error removing stereochemistry: {e}")
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_mol(smiles: str, remove_stereo: bool = False) -> Chem.Mol:
|
| 74 |
+
"""
|
| 75 |
+
Get a molecule object from a SMILES string.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
smiles (str): The SMILES string representing the molecule.
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Chem.Mol: The molecule object.
|
| 82 |
+
"""
|
| 83 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 84 |
+
if mol is not None and remove_stereo:
|
| 85 |
+
Chem.rdmolops.RemoveStereochemistry(mol)
|
| 86 |
+
return mol
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def canonize_smarts(smarts: str) -> str:
|
| 90 |
+
"""
|
| 91 |
+
Cleans a SMARTS string by converting it to canonical SMARTS representation.
|
| 92 |
+
|
| 93 |
+
NOTE: It might not work for complex patterns: https://github.com/rdkit/rdkit/discussions/6929
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
smarts (str): The input SMARTS string.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
str: The cleaned SMARTS string.
|
| 100 |
+
"""
|
| 101 |
+
mol = Chem.MolFromSmarts(smarts)
|
| 102 |
+
|
| 103 |
+
if mol is None:
|
| 104 |
+
return None
|
| 105 |
+
canonical_smarts = Chem.MolToSmarts(Chem.MolFromSmiles(Chem.MolToSmiles(mol), sanitize=False))
|
| 106 |
+
return canonical_smarts
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def smiles2mol(smiles: str) -> Chem.Mol:
|
| 110 |
+
"""Converts a SMILES string to an RDKit molecule object.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
smiles (str): The input SMILES string.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Chem.Mol: The RDKit molecule object.
|
| 117 |
+
"""
|
| 118 |
+
return Chem.MolFromSmiles(smiles)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def mol2smiles(mol: Chem.Mol) -> str:
|
| 122 |
+
"""Converts an RDKit molecule object to a SMILES string.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
mol (Chem.Mol): The RDKit molecule object.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
str: The SMILES string.
|
| 129 |
+
"""
|
| 130 |
+
return Chem.MolToSmiles(mol)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def canonize_smiles(smiles: str) -> str:
|
| 134 |
+
""" Canonizes a SMILES string by converting it to canonical SMILES representation.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
smiles (str): The input SMILES string.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
str: The canonized SMILES string.
|
| 141 |
+
"""
|
| 142 |
+
if smiles is None:
|
| 143 |
+
return None
|
| 144 |
+
try:
|
| 145 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"Error: {e}")
|
| 148 |
+
return None
|
| 149 |
+
if mol is None:
|
| 150 |
+
return None
|
| 151 |
+
try:
|
| 152 |
+
return Chem.MolToSmiles(mol, canonical=True)
|
| 153 |
+
except:
|
| 154 |
+
return None
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def canonize(x: Union[str, Chem.Mol]) -> Union[str, Chem.Mol]:
|
| 158 |
+
""" Canonizes a SMILES string or RDKit molecule object.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
x: The input SMILES string or RDKit molecule object.
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
str | Chem.Mol: The canonized SMILES string or RDKit molecule object, according to the input type.
|
| 165 |
+
"""
|
| 166 |
+
if x is None:
|
| 167 |
+
return None
|
| 168 |
+
if isinstance(x, str):
|
| 169 |
+
return canonize_smiles(x)
|
| 170 |
+
return Chem.MolFromSmiles(Chem.MolToSmiles(x, canonical=True))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def compute_RDKitFP(
|
| 174 |
+
smiles: Union[str, List[str], List[Chem.Mol]],
|
| 175 |
+
maxPath: int = 7,
|
| 176 |
+
fpSize: int = 2048,
|
| 177 |
+
) -> List[Chem.RDKFingerprint]:
|
| 178 |
+
"""
|
| 179 |
+
Compute RDKit fingerprints for a given list of SMILES strings or RDKit molecules.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
smiles (Union[str, List[str], List[Chem.Mol]]): A single SMILES string or a list of SMILES strings
|
| 183 |
+
or a list of RDKit molecules.
|
| 184 |
+
maxPath (int, optional): The maximum path length for the fingerprints. Defaults to 7.
|
| 185 |
+
fpSize (int, optional): The size of the fingerprint vector. Defaults to 2048.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
List[Chem.RDKFingerprint]: A list of RDKit fingerprints computed from the input SMILES strings or molecules.
|
| 189 |
+
"""
|
| 190 |
+
if isinstance(smiles[0], str):
|
| 191 |
+
mols = [get_mol(smi) for smi in smiles]
|
| 192 |
+
else:
|
| 193 |
+
mols = smiles # assume mols were fed instead
|
| 194 |
+
rdgen = rdFingerprintGenerator.GetRDKitFPGenerator(
|
| 195 |
+
maxPath=maxPath, fpSize=fpSize)
|
| 196 |
+
fps = [rdgen.GetCountFingerprint(mol) for mol in mols]
|
| 197 |
+
return fps
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def remove_dummy_atoms(mol: Union[str, Chem.Mol], canonical=True) -> Union[str, Chem.Mol]:
|
| 201 |
+
"""
|
| 202 |
+
Removes all dummy atoms (attachment points) from a molecule.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
mol: RDKit Mol object with dummy atoms.
|
| 206 |
+
|
| 207 |
+
Returns:
|
| 208 |
+
A new RDKit Mol object without dummy atoms.
|
| 209 |
+
"""
|
| 210 |
+
return_smiles = False
|
| 211 |
+
if isinstance(mol, str):
|
| 212 |
+
return_smiles = True
|
| 213 |
+
mol = Chem.MolFromSmiles(mol)
|
| 214 |
+
|
| 215 |
+
if mol is None:
|
| 216 |
+
return None
|
| 217 |
+
|
| 218 |
+
# Remove all dummy atoms with a query
|
| 219 |
+
mol_no_dummy = Chem.DeleteSubstructs(mol, Chem.MolFromSmarts('[#0]'))
|
| 220 |
+
|
| 221 |
+
if mol_no_dummy is None:
|
| 222 |
+
# --------------------------------------------------------------------------
|
| 223 |
+
# Other approach: editing molecule and removing dummy atoms
|
| 224 |
+
# --------------------------------------------------------------------------
|
| 225 |
+
# Create an editable molecule to remove atoms
|
| 226 |
+
editable_mol = Chem.EditableMol(mol)
|
| 227 |
+
|
| 228 |
+
# List of atoms to remove (dummy atoms have atomic number 0)
|
| 229 |
+
dummy_atoms = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomicNum() == 0]
|
| 230 |
+
|
| 231 |
+
# Remove dummy atoms
|
| 232 |
+
for atom_idx in sorted(dummy_atoms, reverse=True): # Remove from the highest index to avoid index shifts
|
| 233 |
+
editable_mol.RemoveAtom(atom_idx)
|
| 234 |
+
|
| 235 |
+
if editable_mol is None:
|
| 236 |
+
return None
|
| 237 |
+
|
| 238 |
+
# Return the modified molecule
|
| 239 |
+
if return_smiles:
|
| 240 |
+
return Chem.MolToSmiles(editable_mol.GetMol())
|
| 241 |
+
editable_mol = editable_mol.GetMol()
|
| 242 |
+
editable_mol.UpdatePropertyCache()
|
| 243 |
+
return editable_mol
|
| 244 |
+
# --------------------------------------------------------------------------
|
| 245 |
+
|
| 246 |
+
# Return the modified molecule
|
| 247 |
+
if return_smiles:
|
| 248 |
+
return Chem.MolToSmiles(mol_no_dummy, canonical=canonical)
|
| 249 |
+
return mol_no_dummy
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def dummy2query(mol: Chem.Mol) -> Chem.Mol:
|
| 253 |
+
""" Converts dummy atoms to query atoms, so that a molecule with attachment points can be used in HasSubstructMatch.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
mol: The molecule to convert.
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
The molecule with dummy atoms converted to query atoms
|
| 260 |
+
"""
|
| 261 |
+
if mol is None:
|
| 262 |
+
return None
|
| 263 |
+
p = Chem.AdjustQueryParameters.NoAdjustments()
|
| 264 |
+
p.makeDummiesQueries = True
|
| 265 |
+
return Chem.AdjustQueryProperties(mol, p)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def get_substr_match(
|
| 269 |
+
protac_mol: Chem.Mol,
|
| 270 |
+
substr: Chem.Mol,
|
| 271 |
+
max_allowed_fragments: int = 1,
|
| 272 |
+
replace: Literal['core', 'sidechains'] = 'core',
|
| 273 |
+
useChirality: bool = True,
|
| 274 |
+
) -> bool:
|
| 275 |
+
""" Check if a molecule contains a substructure match with a given molecule.
|
| 276 |
+
Compared to RDKit HasSubstructMatch, this function also checks the number of fragments when replacing the substr in the PROTAC.
|
| 277 |
+
|
| 278 |
+
Args:
|
| 279 |
+
protac_mol (Chem.Mol): The PROTAC molecule.
|
| 280 |
+
substr (Chem.Mol): The substructure molecule.
|
| 281 |
+
max_allowed_fragments (int, optional): The maximum number of fragments allowed when replacing the substr in the PROTAC. Defaults to 1. Example when equal to 1: if removing the warhead, a single fragment should remain.
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
bool: True if the PROTAC contains a substructure match with the given molecule and the fragments count is equal, False otherwise.
|
| 285 |
+
"""
|
| 286 |
+
# Count the number of fragments when replacing the substr in the PROTAC
|
| 287 |
+
if replace == 'core':
|
| 288 |
+
fragments = Chem.ReplaceCore(protac_mol, dummy2query(substr), useChirality=useChirality)
|
| 289 |
+
elif replace == 'sidechains':
|
| 290 |
+
fragments = Chem.ReplaceSidechains(protac_mol, dummy2query(substr), useChirality=useChirality)
|
| 291 |
+
else:
|
| 292 |
+
raise ValueError(f"replace argument should be either 'core' or 'sidechains', provided: {replace}")
|
| 293 |
+
# Check if the number of fragments is equal to the max allowed fragments
|
| 294 |
+
if fragments is None:
|
| 295 |
+
return False
|
| 296 |
+
try:
|
| 297 |
+
fragments = Chem.GetMolFrags(fragments, sanitizeFrags=False)
|
| 298 |
+
except Exception as e:
|
| 299 |
+
print(e)
|
| 300 |
+
return False
|
| 301 |
+
return len(fragments) == max_allowed_fragments
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def remove_attach_atom(mol: Chem.Mol, attach_id: int, sanitize: bool = False) -> Chem.Mol:
|
| 305 |
+
""" Removes the atom with the specified attachment id from the molecule.
|
| 306 |
+
|
| 307 |
+
Example:
|
| 308 |
+
|
| 309 |
+
>>> remove_attach_atom(Chem.MolFromSmiles('CC[*:1]'), 1)
|
| 310 |
+
CC
|
| 311 |
+
|
| 312 |
+
There are no checks on the molecule, so it is assumed it is not None.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
mol (Chem.Mol): The molecule.
|
| 316 |
+
attach_id (int): The attachment id of the atom to remove.
|
| 317 |
+
sanitize (bool, optional): Whether to sanitize the molecule after removing the atom. When used in `fix_prediction` function, it is used to "remove" substructures, so there is no need to have them sanitized. Default: False.
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
(Chem.Mol) The molecule with the atom removed.
|
| 321 |
+
"""
|
| 322 |
+
atoms_to_remove = []
|
| 323 |
+
for atom in mol.GetAtoms():
|
| 324 |
+
if atom.GetAtomicNum() == 0: # Dummy atom
|
| 325 |
+
map_num = atom.GetAtomMapNum()
|
| 326 |
+
if map_num == attach_id: # Targeting only [*:attach_id]
|
| 327 |
+
atoms_to_remove.append(atom.GetIdx())
|
| 328 |
+
|
| 329 |
+
# Remove atoms using an EditableMol
|
| 330 |
+
editable_mol = Chem.EditableMol(mol)
|
| 331 |
+
for idx in sorted(atoms_to_remove, reverse=True): # Remove from highest index to avoid shifting
|
| 332 |
+
editable_mol.RemoveAtom(idx)
|
| 333 |
+
|
| 334 |
+
# Convert back to a molecule
|
| 335 |
+
new_mol = editable_mol.GetMol()
|
| 336 |
+
if sanitize:
|
| 337 |
+
Chem.SanitizeMol(new_mol)
|
| 338 |
+
return new_mol
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def get_bond_idx(smi: str, bonds_start_end_atoms: List[List[int]]) -> List[int]:
|
| 342 |
+
"""
|
| 343 |
+
Get the indices of bonds in a molecule that match the given start and end atom indices.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
smi (str): The SMILES representation of the molecule.
|
| 347 |
+
bonds_start_end_atoms (List[List[int]]): A list of lists containing the start and end atom indices of the bonds to search for.
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
List[int]: A list of bond indices that match the given start and end atom indices.
|
| 351 |
+
"""
|
| 352 |
+
mol = Chem.MolFromSmiles(smi)
|
| 353 |
+
|
| 354 |
+
bond_indices = []
|
| 355 |
+
|
| 356 |
+
for bond in mol.GetBonds():
|
| 357 |
+
begin_idx = bond.GetBeginAtomIdx()
|
| 358 |
+
end_idx = bond.GetEndAtomIdx()
|
| 359 |
+
|
| 360 |
+
if [begin_idx, end_idx] in bonds_start_end_atoms or [end_idx, begin_idx] in bonds_start_end_atoms:
|
| 361 |
+
bond_indices.append(bond.GetIdx())
|
| 362 |
+
elif (begin_idx, end_idx) in bonds_start_end_atoms or (end_idx, begin_idx) in bonds_start_end_atoms:
|
| 363 |
+
bond_indices.append(bond.GetIdx())
|
| 364 |
+
|
| 365 |
+
return bond_indices
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def get_mol_id(smiles: str) -> str | None:
|
| 369 |
+
""" Get the Hash of a given SMILES string.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
smiles (str): The SMILES string to hash.
|
| 373 |
+
|
| 374 |
+
Returns:
|
| 375 |
+
str | None: The Hash of the SMILES string. None if the function failed.
|
| 376 |
+
"""
|
| 377 |
+
try:
|
| 378 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 379 |
+
if mol is None:
|
| 380 |
+
return None
|
| 381 |
+
Chem.RemoveStereochemistry(mol)
|
| 382 |
+
except Exception as e:
|
| 383 |
+
logging.warning(f"Error while removing stereochemistry: {e}")
|
| 384 |
+
logging.warning(f"SMILES: {smiles}")
|
| 385 |
+
return None
|
| 386 |
+
|
| 387 |
+
# Get the InChIKey for the molecule
|
| 388 |
+
inchi_key = Chem.MolToInchiKey(mol)
|
| 389 |
+
smiles = Chem.MolToSmiles(mol, canonical=True)
|
| 390 |
+
|
| 391 |
+
# Encode the InChIKey and SMILES to create a unique identifier
|
| 392 |
+
return sha256((inchi_key + smiles).encode()).hexdigest()
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def get_atom_idx_at_attachment(
|
| 396 |
+
protac: Chem.Mol,
|
| 397 |
+
substruct: Chem.Mol,
|
| 398 |
+
linker: Optional[Chem.Mol] = None,
|
| 399 |
+
timeout: Optional[Union[int, float]] = None,
|
| 400 |
+
return_dict: bool = False,
|
| 401 |
+
verbose: int = 0,
|
| 402 |
+
) -> List[int]:
|
| 403 |
+
""" Get the atom index of the attachment point of a substructure in the PROTAC molecule.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
protac: The PROTAC molecule.
|
| 407 |
+
substruct: The substructure of the PROTAC that contains the attachment point, e.g., the POI or E3 ligase.
|
| 408 |
+
linker: The linker molecule.
|
| 409 |
+
verbose: Verbosity level.
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
List[int]: The two atom indices at the attachment point.
|
| 413 |
+
"""
|
| 414 |
+
if linker is None:
|
| 415 |
+
# Get the "other" substructure, i.e., replace side chain of PROTAC using the substruct
|
| 416 |
+
linker = Chem.DeleteSubstructs(protac, remove_dummy_atoms(substruct), useChirality=True)
|
| 417 |
+
if timeout is None:
|
| 418 |
+
timeout = 60
|
| 419 |
+
logging.warning(f'No timeout set when linker is not provided, using default value of {timeout} seconds.')
|
| 420 |
+
|
| 421 |
+
substruct_match = set(protac.GetSubstructMatch(dummy2query(substruct), useChirality=True))
|
| 422 |
+
if verbose:
|
| 423 |
+
print(f'Substruct match: {substruct_match}')
|
| 424 |
+
|
| 425 |
+
linker_no_dummy = remove_dummy_atoms(linker)
|
| 426 |
+
if verbose:
|
| 427 |
+
print(f'Linker without dummy atoms found.')
|
| 428 |
+
|
| 429 |
+
max_matches = 2
|
| 430 |
+
linker_match = set()
|
| 431 |
+
shared_atoms = set()
|
| 432 |
+
|
| 433 |
+
# NOTE: The following is a hacky way to speed up the search for linker
|
| 434 |
+
# matches. In fact, the linker can be quite short, so it might match in
|
| 435 |
+
# multiple places of the PROTAC molecule.
|
| 436 |
+
# If the number of max matches in GetSubstructMatches is low, then the
|
| 437 |
+
# search tends to be faster, but imprecise. However, we are interested in
|
| 438 |
+
# the interesection of the matches, so we can progressively increase the
|
| 439 |
+
# number of max matches until we find a single atom in common.
|
| 440 |
+
while len(shared_atoms) != 1 and max_matches <= 50:
|
| 441 |
+
if timeout is None:
|
| 442 |
+
linker_matches = list(protac.GetSubstructMatches(linker_no_dummy, useChirality=True, maxMatches=max_matches))
|
| 443 |
+
else:
|
| 444 |
+
linker_matches = GetSubstructMatchesWithTimeout(protac, linker_no_dummy, useChirality=True, maxMatches=max_matches, timeout=timeout)
|
| 445 |
+
if verbose:
|
| 446 |
+
print(f'Linker matches: {linker_matches}')
|
| 447 |
+
|
| 448 |
+
if not linker_matches:
|
| 449 |
+
# return None
|
| 450 |
+
linker_match = set()
|
| 451 |
+
shared_atoms = set()
|
| 452 |
+
max_matches += 1
|
| 453 |
+
continue
|
| 454 |
+
|
| 455 |
+
for match in linker_matches:
|
| 456 |
+
shared_atoms = set(match) & set(substruct_match)
|
| 457 |
+
linker_match = match
|
| 458 |
+
if len(shared_atoms) == 1:
|
| 459 |
+
if verbose:
|
| 460 |
+
print(f'Shared atoms: {list(shared_atoms)}')
|
| 461 |
+
break
|
| 462 |
+
|
| 463 |
+
if len(shared_atoms) != 1:
|
| 464 |
+
linker_match = set()
|
| 465 |
+
shared_atoms = set()
|
| 466 |
+
max_matches += 1
|
| 467 |
+
|
| 468 |
+
if not shared_atoms:
|
| 469 |
+
if verbose:
|
| 470 |
+
print('No shared atoms found.')
|
| 471 |
+
return None
|
| 472 |
+
|
| 473 |
+
attachment_idx = list(shared_atoms)
|
| 474 |
+
attachments = {'substruct': attachment_idx[0]}
|
| 475 |
+
|
| 476 |
+
# Get the other atom at the attachment point that is NOT in the linker
|
| 477 |
+
for neighbor in protac.GetAtomWithIdx(attachment_idx[0]).GetNeighbors():
|
| 478 |
+
if neighbor.GetIdx() not in linker_match:
|
| 479 |
+
attachment_idx.append(neighbor.GetIdx())
|
| 480 |
+
attachments['linker'] = neighbor.GetIdx()
|
| 481 |
+
break
|
| 482 |
+
|
| 483 |
+
if return_dict:
|
| 484 |
+
return attachments
|
| 485 |
+
return attachment_idx
|
| 486 |
+
|
| 487 |
+
|
protac_splitter/data/__init__.py
ADDED
|
File without changes
|
protac_splitter/data/curation/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mapping_utils import update_dictionary
|
| 2 |
+
from .curation import (
|
| 3 |
+
split_protacs,
|
| 4 |
+
iterative_protac_splitting,
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'update_dictionary',
|
| 9 |
+
'split_protacs',
|
| 10 |
+
'iterative_protac_splitting',
|
| 11 |
+
]
|
protac_splitter/data/curation/bond_adjustments.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Adjusts amide and ester bonds in PROTAC substructures. """
|
| 2 |
+
from typing import Tuple, Dict
|
| 3 |
+
|
| 4 |
+
from rdkit import Chem
|
| 5 |
+
|
| 6 |
+
from protac_splitter.chemoinformatics import (
|
| 7 |
+
dummy2query,
|
| 8 |
+
canonize,
|
| 9 |
+
)
|
| 10 |
+
from protac_splitter.display_utils import display_mol
|
| 11 |
+
from protac_splitter.evaluation import check_reassembly
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def adjust_amide_bond(
|
| 15 |
+
substruct: Chem.Mol,
|
| 16 |
+
linker: Chem.Mol,
|
| 17 |
+
substruct_attachment_id: int,
|
| 18 |
+
verbose: int = 0,
|
| 19 |
+
) -> Tuple[Chem.Mol, Chem.Mol]:
|
| 20 |
+
"""
|
| 21 |
+
Adjust the amide bond between the substruct and linker substructure.
|
| 22 |
+
Handles the case when neighboring atoms of the amide bond are dummy atoms, which represent attachment points.
|
| 23 |
+
The linker will be modified with the required additional atoms.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
substruct: The substructure of the substruct (protein of interest) that contains the amide bond.
|
| 27 |
+
linker: The linker molecule that connects substruct to the E3 ligase.
|
| 28 |
+
substruct_attachment_id: The attachment point ID in the substruct substructure. E.g., 1 for the POI, as in "[*:1]".
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
Tuple[Chem.Mol, Chem.Mol]: The adjusted substruct and linker molecules, in that order.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
# Pseudo-code of the algorithm:
|
| 35 |
+
"""
|
| 36 |
+
```python
|
| 37 |
+
# Check if the amide bond (N-C=O) is in the substructure
|
| 38 |
+
if "N-C(=O)" in substruct:
|
| 39 |
+
if neighbor("N-C(=O)") == "[*:substruct]":
|
| 40 |
+
# If the neighboring atom of the amide bond is a dummy atom, i.e., attachment point
|
| 41 |
+
mark_protac_as_wrong("[PROTAC]")
|
| 42 |
+
|
| 43 |
+
# Identify the bond to split, i.e., the nitrogen-carbon bond, and split
|
| 44 |
+
"[*:substruct]-[<optional neighboring atom>]-N-[*:tmp]", "[*:tmp]-C(=O)-[rest of the PROTAC]" = split_PROTAC_at("N-C")
|
| 45 |
+
|
| 46 |
+
"[Linker]-N-[*:tmp]" = join("[Linker]-[*:substruct]", "[*:substruct]-N-[*:tmp]")
|
| 47 |
+
|
| 48 |
+
rename_attachment_point("[*:tmp]-C(=O)-[rest of the PROTAC]")
|
| 49 |
+
rename_attachment_point("[Linker]-N-[*:tmp]")
|
| 50 |
+
|
| 51 |
+
elif neighbor(neighbor("N-C(=O)")) == "[*:substruct]":
|
| 52 |
+
# If the second neighbor of athe amide bond is a dummy atom, i.e., attachment point
|
| 53 |
+
mark_protac_as_wrong("[PROTAC]")
|
| 54 |
+
|
| 55 |
+
# Do as above
|
| 56 |
+
# Identify the bond to split, i.e., the nitrogen-carbon bond, and split
|
| 57 |
+
"[*:substruct]-N-[*:tmp]", "[*:tmp]-C(=O)-[rest of the PROTAC]" = split_PROTAC_at("N-C")
|
| 58 |
+
|
| 59 |
+
"[Linker]-N-[*:tmp]" = join("[Linker]-[*:substruct]", "[*:substruct]-N-[*:tmp]")
|
| 60 |
+
|
| 61 |
+
rename_attachment_point("[*:tmp]-C(=O)-[rest of the PROTAC]")
|
| 62 |
+
rename_attachment_point("[Linker]-N-[*:tmp]")
|
| 63 |
+
```
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
# Convert dummy atoms in substruct to query atoms for substructure search
|
| 67 |
+
query_substruct = dummy2query(substruct)
|
| 68 |
+
|
| 69 |
+
# Identify amide bond (N-C=O) in substruct substructure
|
| 70 |
+
amide_pattern = Chem.MolFromSmarts("[NX3][CX3](=[OX1])")
|
| 71 |
+
amide_matches = query_substruct.GetSubstructMatches(amide_pattern, useChirality=True)
|
| 72 |
+
|
| 73 |
+
if not amide_matches:
|
| 74 |
+
return substruct, linker # No amide bond found, return the original substruct
|
| 75 |
+
|
| 76 |
+
side_atom = None
|
| 77 |
+
nitrogen_idx_found, carbonyl_idx_found = None, None
|
| 78 |
+
for match in amide_matches:
|
| 79 |
+
nitrogen_idx, carbonyl_idx = match[0], match[1]
|
| 80 |
+
nitrogen_atom = query_substruct.GetAtomWithIdx(nitrogen_idx)
|
| 81 |
+
carbonyl_atom = query_substruct.GetAtomWithIdx(carbonyl_idx)
|
| 82 |
+
|
| 83 |
+
for amide_atom in [nitrogen_atom, carbonyl_atom]:
|
| 84 |
+
# Check neighboring atoms for attachment points
|
| 85 |
+
# NOTE: The dummy atom representing an attachment point have atomic number 0
|
| 86 |
+
for neighbor in amide_atom.GetNeighbors():
|
| 87 |
+
if neighbor.GetAtomicNum() == 0:
|
| 88 |
+
nitrogen_idx_found = nitrogen_idx
|
| 89 |
+
carbonyl_idx_found = carbonyl_idx
|
| 90 |
+
side_atom = "N" if amide_atom == nitrogen_atom else "C"
|
| 91 |
+
break
|
| 92 |
+
|
| 93 |
+
# If previous search failed, check the neighbors of the neighboring
|
| 94 |
+
# atoms (second-order neighbors)
|
| 95 |
+
if nitrogen_idx_found is None or carbonyl_idx_found is None:
|
| 96 |
+
for neighbor in amide_atom.GetNeighbors():
|
| 97 |
+
for second_neighbor in neighbor.GetNeighbors():
|
| 98 |
+
if second_neighbor.GetIdx() == carbonyl_idx or second_neighbor.GetIdx() == nitrogen_idx:
|
| 99 |
+
continue # Skip the opposite atom from the amide bond
|
| 100 |
+
|
| 101 |
+
if second_neighbor.GetAtomicNum() == 0:
|
| 102 |
+
nitrogen_idx_found = nitrogen_idx
|
| 103 |
+
carbonyl_idx_found = carbonyl_idx
|
| 104 |
+
side_atom = "N" if amide_atom == nitrogen_atom else "C"
|
| 105 |
+
break
|
| 106 |
+
else:
|
| 107 |
+
break
|
| 108 |
+
|
| 109 |
+
if nitrogen_idx_found is None or carbonyl_idx_found is None or side_atom is None:
|
| 110 |
+
return substruct, linker
|
| 111 |
+
|
| 112 |
+
# Split the amide bond and adjust
|
| 113 |
+
dummy_label = 3
|
| 114 |
+
dummy_labels = [(dummy_label, dummy_label)] # The E3 and substruct will have 1 and 2, so we need a third one
|
| 115 |
+
amid_bond_idx = query_substruct.GetBondBetweenAtoms(nitrogen_idx_found, carbonyl_idx_found).GetIdx()
|
| 116 |
+
fragments = Chem.FragmentOnBonds(query_substruct, [amid_bond_idx], addDummies=True, dummyLabels=dummy_labels)
|
| 117 |
+
|
| 118 |
+
# Get the fragments resulting from bond breaking
|
| 119 |
+
try:
|
| 120 |
+
mol_frags = Chem.GetMolFrags(fragments, asMols=True, sanitizeFrags=True)
|
| 121 |
+
except Exception as e:
|
| 122 |
+
print(e)
|
| 123 |
+
return substruct, linker
|
| 124 |
+
|
| 125 |
+
# Identify the "[*:substruct][<optional neighboring atom>]N[3*]" fragment, the other one will be the "truncated" substruct
|
| 126 |
+
amide_fragment_pattern = Chem.MolFromSmarts(f"[*:{substruct_attachment_id}][{side_atom}][{dummy_label}*]")
|
| 127 |
+
amide_fragment = None
|
| 128 |
+
substruct_fixed = None
|
| 129 |
+
|
| 130 |
+
if verbose:
|
| 131 |
+
print(f'Attachment point: *:{substruct_attachment_id}')
|
| 132 |
+
print('Substruct:')
|
| 133 |
+
display_mol(substruct)
|
| 134 |
+
print('Linker:')
|
| 135 |
+
display_mol(linker)
|
| 136 |
+
|
| 137 |
+
for frag in mol_frags:
|
| 138 |
+
if frag.HasSubstructMatch(dummy2query(amide_fragment_pattern)):
|
| 139 |
+
amide_fragment = frag
|
| 140 |
+
if verbose:
|
| 141 |
+
print('Amide fragment:')
|
| 142 |
+
display_mol(frag)
|
| 143 |
+
else:
|
| 144 |
+
if verbose:
|
| 145 |
+
print('Substruct fragment:')
|
| 146 |
+
display_mol(frag)
|
| 147 |
+
substruct_fixed = frag
|
| 148 |
+
|
| 149 |
+
if amide_fragment is None or substruct_fixed is None:
|
| 150 |
+
return substruct, linker
|
| 151 |
+
|
| 152 |
+
# In order for the function to be used "on linkers", we need to make sure
|
| 153 |
+
# that the amide fragment contains the attachment point of the substruct.
|
| 154 |
+
# If not, there's nothing to do.
|
| 155 |
+
if f'[*:{substruct_attachment_id}]' not in Chem.MolToSmiles(amide_fragment, canonical=True):
|
| 156 |
+
return substruct, linker
|
| 157 |
+
|
| 158 |
+
# Rename the "[3*]" attachment point on the amide fragment to "[*:3]"
|
| 159 |
+
amide_fragment_smiles = Chem.MolToSmiles(amide_fragment, canonical=True)
|
| 160 |
+
amide_fragment_smiles = amide_fragment_smiles.replace(f'[{dummy_label}*]', f'[*:{dummy_label}]')
|
| 161 |
+
amide_fragment_smiles = canonize(amide_fragment_smiles)
|
| 162 |
+
amide_fragment = Chem.MolFromSmiles(amide_fragment_smiles)
|
| 163 |
+
|
| 164 |
+
# Use molzip to join the linker and the fragment at the original attachment point
|
| 165 |
+
linker_fixed = Chem.molzip(linker, amide_fragment)
|
| 166 |
+
|
| 167 |
+
# Rename the "[*:3]" attachment point back to the original attachment point on the linker
|
| 168 |
+
linker_fixed_smiles = Chem.MolToSmiles(linker_fixed, canonical=True)
|
| 169 |
+
linker_fixed_smiles = linker_fixed_smiles.replace(f'[*:{dummy_label}]', f'[*:{substruct_attachment_id}]')
|
| 170 |
+
linker_fixed_smiles = canonize(linker_fixed_smiles)
|
| 171 |
+
linker_fixed = Chem.MolFromSmiles(linker_fixed_smiles)
|
| 172 |
+
|
| 173 |
+
# Rename the "[3*]" attachment point back to the original attachment point on the substruct
|
| 174 |
+
substruct_fixed_smiles = Chem.MolToSmiles(substruct_fixed, canonical=True)
|
| 175 |
+
substruct_fixed_smiles = substruct_fixed_smiles.replace(f'[{dummy_label}*]', f'[*:{substruct_attachment_id}]')
|
| 176 |
+
substruct_fixed_smiles = canonize(substruct_fixed_smiles)
|
| 177 |
+
substruct_fixed = Chem.MolFromSmiles(substruct_fixed_smiles)
|
| 178 |
+
|
| 179 |
+
return substruct_fixed, linker_fixed
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def adjust_amide_bonds_in_substructs(
|
| 183 |
+
substructs: Dict[str, str],
|
| 184 |
+
protac_smiles: str,
|
| 185 |
+
poi_attachment_id: int = 1,
|
| 186 |
+
e3_attachment_id: int = 2,
|
| 187 |
+
) -> Dict[str, str]:
|
| 188 |
+
""" Adjusts the amide bonds in the substructures of a PROTAC. Just a wrapper function to apply it to multiple substructures.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
substructs: The substructures of the PROTAC. A dictionary of SMILES with keys 'poi', 'linker', and 'e3'.
|
| 192 |
+
protac_smiles: The SMILES of the PROTAC for checking reassembly.
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
The updated substructures dictionary.
|
| 196 |
+
"""
|
| 197 |
+
poi_mol = Chem.MolFromSmiles(substructs['poi'])
|
| 198 |
+
e3_mol = Chem.MolFromSmiles(substructs['e3'])
|
| 199 |
+
linker_mol = Chem.MolFromSmiles(substructs['linker'])
|
| 200 |
+
|
| 201 |
+
# Fix the amide group on the POI ligand
|
| 202 |
+
poi_mol, linker_mol = adjust_amide_bond(poi_mol, linker_mol, poi_attachment_id)
|
| 203 |
+
poi_smiles = Chem.MolToSmiles(poi_mol, canonical=True)
|
| 204 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
|
| 205 |
+
e3_smiles = substructs['e3']
|
| 206 |
+
if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
|
| 207 |
+
return substructs
|
| 208 |
+
|
| 209 |
+
# Fix the amide group on the E3 binder
|
| 210 |
+
e3_mol, linker_mol = adjust_amide_bond(e3_mol, linker_mol, e3_attachment_id)
|
| 211 |
+
e3_smiles = Chem.MolToSmiles(e3_mol, canonical=True)
|
| 212 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
|
| 213 |
+
if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
|
| 214 |
+
return substructs
|
| 215 |
+
|
| 216 |
+
# Fix the amide group on the linker, E3 side
|
| 217 |
+
linker_mol, e3_mol = adjust_amide_bond(linker_mol, e3_mol, e3_attachment_id)
|
| 218 |
+
e3_smiles = Chem.MolToSmiles(e3_mol, canonical=True)
|
| 219 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
|
| 220 |
+
if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
|
| 221 |
+
return substructs
|
| 222 |
+
|
| 223 |
+
# Fix the amide group on the linker, POI side
|
| 224 |
+
linker_mol, poi_mol = adjust_amide_bond(linker_mol, poi_mol, poi_attachment_id)
|
| 225 |
+
poi_smiles = Chem.MolToSmiles(poi_mol, canonical=True)
|
| 226 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
|
| 227 |
+
if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
|
| 228 |
+
return substructs
|
| 229 |
+
|
| 230 |
+
substructs['poi'] = poi_smiles
|
| 231 |
+
substructs['e3'] = e3_smiles
|
| 232 |
+
substructs['linker'] = linker_smiles
|
| 233 |
+
return substructs
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def adjust_ester_bond(
|
| 237 |
+
substruct: Chem.Mol,
|
| 238 |
+
linker: Chem.Mol,
|
| 239 |
+
substruct_attachment_id: int,
|
| 240 |
+
verbose: int = 0,
|
| 241 |
+
) -> Tuple[Chem.Mol, Chem.Mol]:
|
| 242 |
+
"""
|
| 243 |
+
Adjust the amide bond between the substruct and linker substructure.
|
| 244 |
+
Handles the case when neighboring atoms of the amide bond are dummy atoms, which represent attachment points.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
substruct: The substructure of the substruct (protein of interest) that contains the amide bond.
|
| 248 |
+
linker: The linker molecule that connects substruct to the E3 ligase.
|
| 249 |
+
substruct_attachment_id: The attachment point ID in the substruct substructure. E.g., 1 for the POI, as in "[*:1]".
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
Tuple[Chem.Mol, Chem.Mol]: The adjusted substruct and linker molecules, in that order.
|
| 253 |
+
"""
|
| 254 |
+
# Convert dummy atoms in substruct to query atoms for substructure search
|
| 255 |
+
query_substruct = dummy2query(substruct)
|
| 256 |
+
|
| 257 |
+
# Identify ester group (COOR) in substruct substructure
|
| 258 |
+
ester_pattern = Chem.MolFromSmarts("[OX2][CX3](=[OX1])")
|
| 259 |
+
|
| 260 |
+
ester_matches = query_substruct.GetSubstructMatches(ester_pattern)
|
| 261 |
+
|
| 262 |
+
if not ester_matches:
|
| 263 |
+
return substruct, linker # No amide bond found, return the original substruct
|
| 264 |
+
|
| 265 |
+
side_atom = None
|
| 266 |
+
oxygen_idx_found, carbonyl_idx_found = None, None
|
| 267 |
+
for match in ester_matches:
|
| 268 |
+
oxygen_idx, carbonyl_idx = match[0], match[1]
|
| 269 |
+
oxygen_atom = query_substruct.GetAtomWithIdx(oxygen_idx)
|
| 270 |
+
carbonyl_atom = query_substruct.GetAtomWithIdx(carbonyl_idx)
|
| 271 |
+
|
| 272 |
+
for ester_atom in [oxygen_atom, carbonyl_atom]:
|
| 273 |
+
# Check neighboring atoms for attachment points
|
| 274 |
+
# NOTE: The dummy atom representing an attachment point have atomic number 0
|
| 275 |
+
for neighbor in ester_atom.GetNeighbors():
|
| 276 |
+
if neighbor.GetAtomicNum() == 0:
|
| 277 |
+
oxygen_idx_found = oxygen_idx
|
| 278 |
+
carbonyl_idx_found = carbonyl_idx
|
| 279 |
+
side_atom = "O" if ester_atom == oxygen_atom else "C"
|
| 280 |
+
break
|
| 281 |
+
|
| 282 |
+
# If previous search failed, check the neighbors of the neighboring
|
| 283 |
+
# atoms (second-order neighbors)
|
| 284 |
+
if oxygen_idx_found is None or carbonyl_idx_found is None:
|
| 285 |
+
for neighbor in ester_atom.GetNeighbors():
|
| 286 |
+
for second_neighbor in neighbor.GetNeighbors():
|
| 287 |
+
if second_neighbor.GetIdx() == carbonyl_idx or second_neighbor.GetIdx() == oxygen_idx:
|
| 288 |
+
continue # Skip the opposite atom from the amide bond
|
| 289 |
+
|
| 290 |
+
if second_neighbor.GetAtomicNum() == 0:
|
| 291 |
+
oxygen_idx_found = oxygen_idx
|
| 292 |
+
carbonyl_idx_found = carbonyl_idx
|
| 293 |
+
side_atom = "O" if ester_atom == oxygen_atom else "C"
|
| 294 |
+
break
|
| 295 |
+
else:
|
| 296 |
+
break
|
| 297 |
+
|
| 298 |
+
if oxygen_idx_found is None or carbonyl_idx_found is None or side_atom is None:
|
| 299 |
+
return substruct, linker
|
| 300 |
+
|
| 301 |
+
# Split the amide bond and adjust
|
| 302 |
+
dummy_label = 3
|
| 303 |
+
dummy_labels = [(dummy_label, dummy_label)] # The E3 and substruct will have 1 and 2, so we need a third one
|
| 304 |
+
amid_bond_idx = query_substruct.GetBondBetweenAtoms(oxygen_idx_found, carbonyl_idx_found).GetIdx()
|
| 305 |
+
fragments = Chem.FragmentOnBonds(query_substruct, [amid_bond_idx], addDummies=True, dummyLabels=dummy_labels)
|
| 306 |
+
|
| 307 |
+
# Get the fragments resulting from bond breaking
|
| 308 |
+
try:
|
| 309 |
+
mol_frags = Chem.GetMolFrags(fragments, asMols=True, sanitizeFrags=True)
|
| 310 |
+
except Exception as e:
|
| 311 |
+
if verbose:
|
| 312 |
+
print(e)
|
| 313 |
+
return substruct, linker
|
| 314 |
+
|
| 315 |
+
# Identify the "[*:substruct][<optional neighboring atom>]N[3*]" fragment, the other one will be the "truncated" substruct
|
| 316 |
+
ester_fragment_pattern = Chem.MolFromSmarts(f"[*:{substruct_attachment_id}][{side_atom}][{dummy_label}*]")
|
| 317 |
+
ester_fragment = None
|
| 318 |
+
substruct_fixed = None
|
| 319 |
+
|
| 320 |
+
for frag in mol_frags:
|
| 321 |
+
if frag.HasSubstructMatch(dummy2query(ester_fragment_pattern)):
|
| 322 |
+
ester_fragment = frag
|
| 323 |
+
else:
|
| 324 |
+
substruct_fixed = frag
|
| 325 |
+
|
| 326 |
+
if ester_fragment is None or substruct_fixed is None:
|
| 327 |
+
return substruct, linker
|
| 328 |
+
|
| 329 |
+
# In order for the function to be used "on linkers", we need to make sure
|
| 330 |
+
# that the ester fragment contains the attachment point of the substruct.
|
| 331 |
+
# If not, there's nothing to do.
|
| 332 |
+
if f'[*:{substruct_attachment_id}]' not in Chem.MolToSmiles(ester_fragment, canonical=True):
|
| 333 |
+
return substruct, linker
|
| 334 |
+
|
| 335 |
+
# Rename the "[3*]" attachment point on the amide fragment to "[*:3]"
|
| 336 |
+
ester_fragment_smiles = Chem.MolToSmiles(ester_fragment, canonical=True)
|
| 337 |
+
ester_fragment_smiles = ester_fragment_smiles.replace(f'[{dummy_label}*]', f'[*:{dummy_label}]')
|
| 338 |
+
ester_fragment = Chem.MolFromSmiles(ester_fragment_smiles)
|
| 339 |
+
|
| 340 |
+
# Use molzip to join the linker and the fragment at the original attachment point
|
| 341 |
+
linker_fixed = Chem.molzip(linker, ester_fragment)
|
| 342 |
+
|
| 343 |
+
# Rename the "[*:3]" attachment point back to the original attachment point on the linker
|
| 344 |
+
linker_fixed_smiles = Chem.MolToSmiles(linker_fixed, canonical=True)
|
| 345 |
+
linker_fixed_smiles = linker_fixed_smiles.replace(f'[*:{dummy_label}]', f'[*:{substruct_attachment_id}]')
|
| 346 |
+
linker_fixed = Chem.MolFromSmiles(linker_fixed_smiles)
|
| 347 |
+
|
| 348 |
+
# Rename the "[3*]" attachment point back to the original attachment point on the substruct
|
| 349 |
+
substruct_fixed_smiles = Chem.MolToSmiles(substruct_fixed, canonical=True)
|
| 350 |
+
substruct_fixed_smiles = substruct_fixed_smiles.replace(f'[{dummy_label}*]', f'[*:{substruct_attachment_id}]')
|
| 351 |
+
substruct_fixed = Chem.MolFromSmiles(substruct_fixed_smiles)
|
| 352 |
+
|
| 353 |
+
return substruct_fixed, linker_fixed
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def adjust_ester_bonds_in_substructs(
|
| 357 |
+
substructs: Dict[str, str],
|
| 358 |
+
protac_smiles: str,
|
| 359 |
+
poi_attachment_id: int = 1,
|
| 360 |
+
e3_attachment_id: int = 2,
|
| 361 |
+
) -> Dict[str, str]:
|
| 362 |
+
""" Adjusts the ester bonds in the substructures of a PROTAC. Just a wrapper function to apply it to multiple substructures.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
substructs: The substructures of the PROTAC. A dictionary of SMILES with keys 'poi', 'linker', and 'e3'.
|
| 366 |
+
protac_smiles: The SMILES of the PROTAC for checking reassembly.
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
The updated substructures dictionary.
|
| 370 |
+
"""
|
| 371 |
+
poi_mol = Chem.MolFromSmiles(substructs['poi'])
|
| 372 |
+
e3_mol = Chem.MolFromSmiles(substructs['e3'])
|
| 373 |
+
linker_mol = Chem.MolFromSmiles(substructs['linker'])
|
| 374 |
+
|
| 375 |
+
# Fix the amide group on the POI ligand
|
| 376 |
+
poi_mol, linker_mol = adjust_ester_bond(poi_mol, linker_mol, poi_attachment_id)
|
| 377 |
+
poi_smiles = Chem.MolToSmiles(poi_mol, canonical=True)
|
| 378 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
|
| 379 |
+
e3_smiles = substructs['e3']
|
| 380 |
+
if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
|
| 381 |
+
return substructs
|
| 382 |
+
|
| 383 |
+
# Fix the amide group on the E3 binder
|
| 384 |
+
e3_mol, linker_mol = adjust_ester_bond(e3_mol, linker_mol, e3_attachment_id)
|
| 385 |
+
e3_smiles = Chem.MolToSmiles(e3_mol, canonical=True)
|
| 386 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
|
| 387 |
+
if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
|
| 388 |
+
return substructs
|
| 389 |
+
|
| 390 |
+
# Fix the amide group on the linker, E3 side
|
| 391 |
+
linker_mol, e3_mol = adjust_ester_bond(linker_mol, e3_mol, e3_attachment_id)
|
| 392 |
+
e3_smiles = Chem.MolToSmiles(e3_mol, canonical=True)
|
| 393 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
|
| 394 |
+
if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
|
| 395 |
+
return substructs
|
| 396 |
+
|
| 397 |
+
# Fix the amide group on the linker, POI side
|
| 398 |
+
linker_mol, poi_mol = adjust_ester_bond(linker_mol, poi_mol, poi_attachment_id)
|
| 399 |
+
poi_smiles = Chem.MolToSmiles(poi_mol, canonical=True)
|
| 400 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
|
| 401 |
+
if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
|
| 402 |
+
return substructs
|
| 403 |
+
|
| 404 |
+
substructs['poi'] = poi_smiles
|
| 405 |
+
substructs['e3'] = e3_smiles
|
| 406 |
+
substructs['linker'] = linker_smiles
|
| 407 |
+
return substructs
|
protac_splitter/data/curation/curation.py
ADDED
|
@@ -0,0 +1,894 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Curation utilities for PROTAC Splitter. """
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from typing import Any, Dict, Optional, Union, Callable
|
| 5 |
+
from joblib import Parallel, delayed
|
| 6 |
+
|
| 7 |
+
from rdkit import Chem
|
| 8 |
+
from rdkit.Chem import DataStructs
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from protac_splitter.chemoinformatics import (
|
| 14 |
+
canonize,
|
| 15 |
+
remove_dummy_atoms,
|
| 16 |
+
canonize_smiles,
|
| 17 |
+
get_mol_id,
|
| 18 |
+
get_substr_match,
|
| 19 |
+
)
|
| 20 |
+
from protac_splitter.evaluation import check_reassembly
|
| 21 |
+
from protac_splitter.data.curation.substructure_extraction import (
|
| 22 |
+
get_substructure_from_non_perfect_match,
|
| 23 |
+
get_substructs_from_unmapped_e3_poi,
|
| 24 |
+
get_substructs_from_substr_and_linker,
|
| 25 |
+
get_substructs_from_mapped_linker,
|
| 26 |
+
swap_attachment_points,
|
| 27 |
+
)
|
| 28 |
+
from protac_splitter.data.curation.bond_adjustments import (
|
| 29 |
+
adjust_amide_bonds_in_substructs,
|
| 30 |
+
adjust_ester_bonds_in_substructs,
|
| 31 |
+
)
|
| 32 |
+
from protac_splitter.data.curation.mapping_utils import update_dictionary
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def check_substructs_size(
|
| 36 |
+
protac_mol: Chem.Mol,
|
| 37 |
+
substructs: Dict[str, str],
|
| 38 |
+
size_perc_threshold: float = 0.8,
|
| 39 |
+
) -> bool:
|
| 40 |
+
""" Check the size of the substructures in the PROTAC. If any of them is too big, return False.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
protac_mol: The PROTAC molecule.
|
| 44 |
+
substructs: The substructures to check against.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
False if any of the substructures is too big. True otherwise.
|
| 48 |
+
"""
|
| 49 |
+
num_protac_atoms = protac_mol.GetNumAtoms()
|
| 50 |
+
for key, smiles in substructs.items():
|
| 51 |
+
substruct = Chem.MolFromSmiles(smiles)
|
| 52 |
+
num_substruct_atoms = substruct.GetNumAtoms()
|
| 53 |
+
if num_substruct_atoms / num_protac_atoms > size_perc_threshold:
|
| 54 |
+
# print(f'Error: {key.upper()} is too big in the PROTAC ({num_substruct_atoms} / {num_protac_atoms} = {num_substruct_atoms / num_protac_atoms:.2%} > {size_perc_threshold:.2%})')
|
| 55 |
+
# display_mol(substruct)
|
| 56 |
+
# display_mol(protac_mol)
|
| 57 |
+
return False
|
| 58 |
+
return True
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def check_linker_similarity(
|
| 62 |
+
linker_smiles: str,
|
| 63 |
+
pois: Union[pd.DataFrame, str],
|
| 64 |
+
e3s: Union[pd.DataFrame, str],
|
| 65 |
+
linkers: Optional[Union[pd.DataFrame, str]] = None,
|
| 66 |
+
pois_similarity_threshold: float = 0.7,
|
| 67 |
+
e3s_similarity_threshold: float = 0.7,
|
| 68 |
+
linkers_similarity_threshold: float = 0.6,
|
| 69 |
+
morgan_fp_generator: Optional[Callable] = None,
|
| 70 |
+
) -> bool:
|
| 71 |
+
""" Check the similarity of the linker with all the matching POIs and E3s. If too similar to any of them, return False.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
linker_smiles: The linker SMILES.
|
| 75 |
+
pois: The POI ligands. Must have a 'FP' column with the Morgan fingerprints.
|
| 76 |
+
e3s: The E3 binders. Must have a 'FP' column with the Morgan fingerprints.
|
| 77 |
+
pois_similarity_threshold: The similarity threshold for the POIs.
|
| 78 |
+
e3s_similarity_threshold: The similarity threshold for the E3s.
|
| 79 |
+
morgan_fp_generator: The Morgan fingerprint generator.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
False if the linker is too similar to any of the POIs or E3s. True otherwise.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
# Get the linker fingerprint
|
| 86 |
+
if morgan_fp_generator is None:
|
| 87 |
+
morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(
|
| 88 |
+
radius=2,
|
| 89 |
+
fpSize=2048,
|
| 90 |
+
useBondTypes=True,
|
| 91 |
+
includeChirality=True,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
linker = Chem.MolFromSmiles(linker_smiles)
|
| 95 |
+
linker_fp = morgan_fp_generator.GetFingerprint(linker)
|
| 96 |
+
|
| 97 |
+
# Check the similarity of the linker with the POIs and E3s (use BulkTanimotoSimilarity)
|
| 98 |
+
if isinstance(e3s, str):
|
| 99 |
+
# Create a one-element list with the E3 fingerprint
|
| 100 |
+
e3s_fps = [morgan_fp_generator.GetFingerprint(Chem.MolFromSmiles(e3s))]
|
| 101 |
+
else:
|
| 102 |
+
e3s_fps = e3s['FP'].to_list()
|
| 103 |
+
e3s_similarities = DataStructs.BulkTanimotoSimilarity(linker_fp, e3s_fps)
|
| 104 |
+
if (np.array(e3s_similarities) > e3s_similarity_threshold).any():
|
| 105 |
+
print(f'WARNING: Linker {linker_smiles} is too similar to an E3 binder')
|
| 106 |
+
# display_mol(linker)
|
| 107 |
+
# display_mol(Chem.MolFromSmiles(e3s[e3s_similarities.argmax()]))
|
| 108 |
+
return False
|
| 109 |
+
|
| 110 |
+
# Check if the linker is similar to any of the POIs or E3s
|
| 111 |
+
if isinstance(pois, str):
|
| 112 |
+
# Create a one-element list with the POI fingerprint
|
| 113 |
+
pois_fps = [morgan_fp_generator.GetFingerprint(Chem.MolFromSmiles(pois))]
|
| 114 |
+
else:
|
| 115 |
+
pois_fps = pois['FP'].to_list()
|
| 116 |
+
pois_similarities = DataStructs.BulkTanimotoSimilarity(linker_fp, pois_fps)
|
| 117 |
+
if (np.array(pois_similarities) > pois_similarity_threshold).any():
|
| 118 |
+
# print(f'Error: Linker {linker_smiles} is too similar to a POI ligand')
|
| 119 |
+
# display_mol(linker)
|
| 120 |
+
# display_mol(Chem.MolFromSmiles(pois[pois_similarities.argmax()]))
|
| 121 |
+
return False
|
| 122 |
+
|
| 123 |
+
# Check if the linker is NOT similar to any of the linkers
|
| 124 |
+
if linkers is not None:
|
| 125 |
+
if isinstance(linkers, str):
|
| 126 |
+
# Create a one-element list with the linker fingerprint
|
| 127 |
+
linkers_fps = [morgan_fp_generator.GetFingerprint(Chem.MolFromSmiles(linkers))]
|
| 128 |
+
else:
|
| 129 |
+
linkers_fps = linkers['FP'].to_list()
|
| 130 |
+
linkers_similarities = DataStructs.BulkTanimotoSimilarity(linker_fp, linkers_fps)
|
| 131 |
+
if not (np.array(linkers_similarities) > linkers_similarity_threshold).all():
|
| 132 |
+
print(f'WARNING: Linker {linker_smiles} is too similar to a linker')
|
| 133 |
+
# display_mol(linker)
|
| 134 |
+
# display_mol(Chem.MolFromSmiles(linkers[linkers_similarities.argmax()]))
|
| 135 |
+
return False
|
| 136 |
+
|
| 137 |
+
return True
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def check_substructs_similarity(
|
| 141 |
+
protac: Union[np.ndarray, str, Chem.Mol],
|
| 142 |
+
substructs: Dict[str, str],
|
| 143 |
+
similarity_threshold: float = 0.7,
|
| 144 |
+
similarity_thresholds : Dict[str, float] = None,
|
| 145 |
+
morgan_fp_generator: Optional[Callable] = None,
|
| 146 |
+
) -> bool:
|
| 147 |
+
""" Check the similarity of the PROTAC with the substructures. If too similar to any of them, return False.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
protac: The PROTAC molecule or its SMILES.
|
| 151 |
+
substructs: The substructures to check against.
|
| 152 |
+
similarity_threshold: The similarity threshold.
|
| 153 |
+
similarity_thresholds: The similarity thresholds for the substructures.
|
| 154 |
+
morgan_fp_generator: The Morgan fingerprint generator.
|
| 155 |
+
|
| 156 |
+
Returns:
|
| 157 |
+
False if the PROTAC is too similar to any of the substructures. True otherwise.
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
if morgan_fp_generator is None:
|
| 161 |
+
morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(
|
| 162 |
+
radius=2,
|
| 163 |
+
fpSize=2048,
|
| 164 |
+
useBondTypes=True,
|
| 165 |
+
includeChirality=True,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if isinstance(protac, str):
|
| 169 |
+
protac = Chem.MolFromSmiles(protac)
|
| 170 |
+
protac_fp = morgan_fp_generator.GetFingerprint(protac)
|
| 171 |
+
elif isinstance(protac, Chem.Mol):
|
| 172 |
+
protac_fp = morgan_fp_generator.GetFingerprint(protac)
|
| 173 |
+
else:
|
| 174 |
+
protac_fp = protac
|
| 175 |
+
|
| 176 |
+
for key, smiles in substructs.items():
|
| 177 |
+
substr_fp = morgan_fp_generator.GetFingerprint(Chem.MolFromSmiles(smiles))
|
| 178 |
+
threshold = similarity_thresholds[key] if similarity_thresholds is not None else similarity_threshold
|
| 179 |
+
if DataStructs.TanimotoSimilarity(protac_fp, substr_fp) > threshold:
|
| 180 |
+
print(f'WARNING: {key.upper()} is too similar to the PROTAC, similarity: {DataStructs.TanimotoSimilarity(protac_fp, substr_fp):.4f} > {threshold}')
|
| 181 |
+
# display_mol(Chem.MolFromSmiles(smiles))
|
| 182 |
+
return False
|
| 183 |
+
|
| 184 |
+
return True
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def get_split_row(
|
| 188 |
+
row: pd.Series,
|
| 189 |
+
substructs: Dict[str, str],
|
| 190 |
+
poi_smiles_no_dummy: Optional[str] = None,
|
| 191 |
+
e3_smiles_no_dummy: Optional[str] = None,
|
| 192 |
+
) -> Dict[str, Any]:
|
| 193 |
+
""" Update the fields of a row with the substructures and their IDs.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
row: The input row.
|
| 197 |
+
dictionaries: The dictionaries containing the substructures.
|
| 198 |
+
substructs: The substructures found in the PROTAC.
|
| 199 |
+
poi_smiles_no_dummy: The POI ligand SMILES without the dummy atoms.
|
| 200 |
+
e3_smiles_no_dummy: The E3 binder SMILES without the dummy atoms.
|
| 201 |
+
update_dict_if_ids_not_found: Whether to update the dictionary if the substructure IDs are not found.
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
The updated row.
|
| 205 |
+
"""
|
| 206 |
+
mapped_row = {}
|
| 207 |
+
mapped_row['PROTAC SMILES'] = canonize_smiles(row['SMILES'])
|
| 208 |
+
mapped_row['POI Ligand SMILES with direction'] = substructs['poi']
|
| 209 |
+
mapped_row['E3 Binder SMILES with direction'] = substructs['e3']
|
| 210 |
+
mapped_row['Linker SMILES with direction'] = substructs['linker']
|
| 211 |
+
mapped_row['POI Ligand SMILES'] = remove_dummy_atoms(substructs['poi']) if poi_smiles_no_dummy is None else poi_smiles_no_dummy
|
| 212 |
+
mapped_row['E3 Binder SMILES'] = remove_dummy_atoms(substructs['e3']) if e3_smiles_no_dummy is None else e3_smiles_no_dummy
|
| 213 |
+
mapped_row['Linker SMILES'] = remove_dummy_atoms(substructs['linker'])
|
| 214 |
+
|
| 215 |
+
# Get the IDs and update the dictionaries with new substructures
|
| 216 |
+
mapped_row['PROTAC ID'] = get_mol_id(mapped_row['PROTAC SMILES'])
|
| 217 |
+
mapped_row['POI Ligand ID'] = get_mol_id(mapped_row['POI Ligand SMILES with direction'])
|
| 218 |
+
mapped_row['E3 Binder ID'] = get_mol_id(mapped_row['E3 Binder SMILES with direction'])
|
| 219 |
+
mapped_row['Linker ID'] = get_mol_id(mapped_row['Linker SMILES with direction'])
|
| 220 |
+
|
| 221 |
+
return mapped_row
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def split_single_protac(
|
| 225 |
+
row: pd.Series,
|
| 226 |
+
dictionaries: Dict[str, pd.DataFrame],
|
| 227 |
+
biggest_matches_first: bool = True,
|
| 228 |
+
max_iter_on_linkers: int = 0,
|
| 229 |
+
split_with_substr_and_linker_matching: bool = False,
|
| 230 |
+
similarity_threshold: float = 0.65,
|
| 231 |
+
morgan_radius: Optional[int] = None,
|
| 232 |
+
morgan_fp_size: Optional[int] = None,
|
| 233 |
+
morgan_fp_generator: Optional[Callable] = None,
|
| 234 |
+
poi_attachment_id: int = 1,
|
| 235 |
+
e3_attachment_id: int = 2,
|
| 236 |
+
) -> Dict[str, Any]:
|
| 237 |
+
""" Map a PROTAC row to the substructures in the dictionaries.
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
row: The input row, containing the PROTAC SMILES, ID, and molecule.
|
| 241 |
+
dictionaries: The dictionaries containing the substructures.
|
| 242 |
+
biggest_matches_first: Whether to sort the matches by the number of atoms in the molecule.
|
| 243 |
+
max_iter_on_linkers: The maximum number of iterations to perform on the linkers.
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
The mapped row. None if the mapping was not successful.
|
| 247 |
+
"""
|
| 248 |
+
# # Disable the RDKit warnings that pop up when RDKit fails to create molecules
|
| 249 |
+
# # NOTE: The following is done to avoid warning messages during multiprocessing
|
| 250 |
+
# RDLogger.DisableLog("rdApp.*")
|
| 251 |
+
# blocker = rdBase.BlockLogs()
|
| 252 |
+
|
| 253 |
+
protac_smiles = row['SMILES']
|
| 254 |
+
protac_mol = row['Molecule']
|
| 255 |
+
|
| 256 |
+
if morgan_fp_generator is None:
|
| 257 |
+
morgan_radius = 2 if morgan_radius is None else morgan_radius
|
| 258 |
+
morgan_fp_size = 2048 if morgan_fp_size is None else morgan_fp_size
|
| 259 |
+
morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(
|
| 260 |
+
radius=morgan_radius,
|
| 261 |
+
fpSize=morgan_fp_size,
|
| 262 |
+
useBondTypes=True,
|
| 263 |
+
includeChirality=True,
|
| 264 |
+
)
|
| 265 |
+
else:
|
| 266 |
+
morgan_radius = 'None'
|
| 267 |
+
morgan_fp_size = 'None'
|
| 268 |
+
protac_fp = morgan_fp_generator.GetFingerprint(protac_mol)
|
| 269 |
+
|
| 270 |
+
notes = f'({max_iter_on_linkers=})({split_with_substr_and_linker_matching=})({morgan_radius=})({morgan_fp_size=})'
|
| 271 |
+
|
| 272 |
+
# Get all substructure matches in the POI dictionary
|
| 273 |
+
# poi_matches = dictionaries['POI Ligand']['Molecule'].apply(lambda x: get_substr_match(protac_mol, x, num_allowed_fragments=1))
|
| 274 |
+
poi_matches = dictionaries['POI Ligand']['Molecule'].apply(lambda x: protac_mol.HasSubstructMatch(x))
|
| 275 |
+
pois = dictionaries['POI Ligand'][poi_matches].drop_duplicates(subset=['SMILES'])
|
| 276 |
+
|
| 277 |
+
# Get all substructure matches in the E3 dictionary
|
| 278 |
+
# e3_matches = dictionaries['E3 Binder']['Molecule'].apply(lambda x: get_substr_match(protac_mol, x, num_allowed_fragments=1))
|
| 279 |
+
e3_matches = dictionaries['E3 Binder']['Molecule'].apply(lambda x: protac_mol.HasSubstructMatch(x))
|
| 280 |
+
e3s = dictionaries['E3 Binder'][e3_matches].drop_duplicates(subset=['SMILES'])
|
| 281 |
+
|
| 282 |
+
# # Sort the matches by the number of atoms in the molecule
|
| 283 |
+
# ascending = False if biggest_matches_first else True
|
| 284 |
+
# pois = pois.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=True)
|
| 285 |
+
# e3s = e3s.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=True)
|
| 286 |
+
|
| 287 |
+
# Get the POI median, then re-arrenge the pois dataframe so that the median is the first element
|
| 288 |
+
poi_median = pois['Molecule'].apply(lambda x: x.GetNumAtoms()).median()
|
| 289 |
+
pois = pois.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=True)
|
| 290 |
+
pois = pois.iloc[np.abs(pois['Molecule'].apply(lambda x: x.GetNumAtoms()) - poi_median).argsort()]
|
| 291 |
+
|
| 292 |
+
# Get the E3 median, then re-arrenge the e3s dataframe so that the median is the first element
|
| 293 |
+
e3_median = e3s['Molecule'].apply(lambda x: x.GetNumAtoms()).median()
|
| 294 |
+
e3s = e3s.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=True)
|
| 295 |
+
e3s = e3s.iloc[np.abs(e3s['Molecule'].apply(lambda x: x.GetNumAtoms()) - e3_median).argsort()]
|
| 296 |
+
|
| 297 |
+
# If any of the substructures is not found, get the matching linkers to be
|
| 298 |
+
# used later (do it only once).
|
| 299 |
+
linkers = None
|
| 300 |
+
if len(pois) == 0 or len(e3s) == 0 or split_with_substr_and_linker_matching:
|
| 301 |
+
matches = dictionaries['Linker with direction']['Molecule'].apply(lambda x: get_substr_match(protac_mol, x, num_allowed_fragments=2))
|
| 302 |
+
linkers = dictionaries['Linker with direction'][matches]
|
| 303 |
+
linkers = linkers.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=False)
|
| 304 |
+
|
| 305 |
+
# dummy_attachment_id = 1
|
| 306 |
+
# mapping_found = False
|
| 307 |
+
# for _, linker in linkers.iterrows():
|
| 308 |
+
# if mapping_found:
|
| 309 |
+
# break
|
| 310 |
+
# for _, poi in pois.iterrows():
|
| 311 |
+
# if mapping_found:
|
| 312 |
+
# break
|
| 313 |
+
# for _, e3 in e3s.iterrows():
|
| 314 |
+
# if mapping_found:
|
| 315 |
+
# break
|
| 316 |
+
# # Get the replace side chain
|
| 317 |
+
# e3_mapped = Chem.ReplaceSidechains(protac_mol, e3['Molecule'], useChirality=True)
|
| 318 |
+
# e3_mapped = rename_attachment_id(e3_mapped, dummy_attachment_id, e3_attachment_id)
|
| 319 |
+
# if e3_mapped is None:
|
| 320 |
+
# continue
|
| 321 |
+
|
| 322 |
+
# poi_mapped = Chem.ReplaceSidechains(protac_mol, poi['Molecule'], useChirality=True)
|
| 323 |
+
# poi_mapped = rename_attachment_id(poi_mapped, dummy_attachment_id, poi_attachment_id)
|
| 324 |
+
# if poi_mapped is None:
|
| 325 |
+
# continue
|
| 326 |
+
|
| 327 |
+
# # Join the substructures as fragments
|
| 328 |
+
# protac_candidate = canonize('.'.join([linker['SMILES'], e3_mapped, poi_mapped]))
|
| 329 |
+
# protac_candidate = Chem.MolFromSmiles(protac_candidate)
|
| 330 |
+
# protac_candidate = canonize(Chem.molzip(protac_candidate))
|
| 331 |
+
# if check_reassembly(protac_mol, protac_candidate):
|
| 332 |
+
# print('Found a match!')
|
| 333 |
+
# mapping_found = True
|
| 334 |
+
|
| 335 |
+
# # substructs = {
|
| 336 |
+
# # 'linker': linker['Molecule'],
|
| 337 |
+
# # 'e3': e3['Molecule'],
|
| 338 |
+
# # 'poi': poi['Molecule'],
|
| 339 |
+
# # }
|
| 340 |
+
# # mapped_row = get_split_row(row, dictionaries, substructs, poi['SMILES'], e3['SMILES'])
|
| 341 |
+
# # mapped_row['Notes'] = 'Obtained from matching E3, POI, and Linker found in dictionaries.'
|
| 342 |
+
# # return mapped_row
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# TODO: Add a variable to get mapped ligands even if the checks failed... add a note when it happens
|
| 346 |
+
best_substructs_candidate = None
|
| 347 |
+
|
| 348 |
+
# There were matching E3s and matching POIs: try to recover the linker from
|
| 349 |
+
# an unmapped E3 and an unmapped POI.
|
| 350 |
+
if len(e3s) > 0 and len(pois) > 0:
|
| 351 |
+
for _, poi in pois.iterrows():
|
| 352 |
+
for _, e3 in e3s.iterrows():
|
| 353 |
+
additional_notes = '(matching_poi=True)(matching_e3=True)(matching_linker=None)'
|
| 354 |
+
substructs = get_substructs_from_unmapped_e3_poi(protac_smiles, protac_mol, poi['Molecule'], e3['Molecule'])
|
| 355 |
+
|
| 356 |
+
# If the substructure is not found, try to get it from a non-perfect match
|
| 357 |
+
if substructs is None:
|
| 358 |
+
fixed_poi = get_substructure_from_non_perfect_match(protac_mol, poi['Molecule'], poi_attachment_id)
|
| 359 |
+
fixed_e3 = get_substructure_from_non_perfect_match(protac_mol, e3['Molecule'], e3_attachment_id)
|
| 360 |
+
fixed_poi = poi['Molecule'] if fixed_poi is None else fixed_poi
|
| 361 |
+
fixed_e3 = e3['Molecule'] if fixed_e3 is None else fixed_e3
|
| 362 |
+
if fixed_poi is not None and fixed_e3 is not None:
|
| 363 |
+
substructs = get_substructs_from_unmapped_e3_poi(protac_smiles, protac_mol, fixed_poi, fixed_e3)
|
| 364 |
+
if Chem.MolToSmiles(fixed_e3) != e3['SMILES']:
|
| 365 |
+
additional_notes += '(non_perfect_e3_match=True)'
|
| 366 |
+
else:
|
| 367 |
+
additional_notes += '(non_perfect_e3_match=False)'
|
| 368 |
+
|
| 369 |
+
if Chem.MolToSmiles(fixed_poi) != poi['SMILES']:
|
| 370 |
+
additional_notes += '(non_perfect_poi_match=True)'
|
| 371 |
+
else:
|
| 372 |
+
additional_notes += '(non_perfect_poi_match=False)'
|
| 373 |
+
|
| 374 |
+
if substructs is not None:
|
| 375 |
+
size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
|
| 376 |
+
|
| 377 |
+
# Check if the linker is too similar to any of the matching POIs or E3s (use the bulk Tanimoto similarity)
|
| 378 |
+
if not check_linker_similarity(substructs['linker'], pois, e3s, morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
|
| 379 |
+
best_substructs_candidate = substructs
|
| 380 |
+
continue
|
| 381 |
+
|
| 382 |
+
if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
|
| 383 |
+
best_substructs_candidate = substructs
|
| 384 |
+
# display_mol(protac_mol)
|
| 385 |
+
continue
|
| 386 |
+
|
| 387 |
+
# Fix the bonds close to amide and ester groups, if necessary
|
| 388 |
+
substructs_copy = substructs.copy()
|
| 389 |
+
substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
|
| 390 |
+
# Check and report if any SMILES was changed
|
| 391 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 392 |
+
additional_notes += '(amide_bonds_fixed=True)'
|
| 393 |
+
else:
|
| 394 |
+
additional_notes += '(amide_bonds_fixed=False)'
|
| 395 |
+
|
| 396 |
+
substructs_copy = substructs.copy()
|
| 397 |
+
substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
|
| 398 |
+
# Check and report if any SMILES was changed
|
| 399 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 400 |
+
additional_notes += '(ester_bonds_fixed=True)'
|
| 401 |
+
else:
|
| 402 |
+
additional_notes += '(ester_bonds_fixed=False)'
|
| 403 |
+
|
| 404 |
+
# Add the mapped PROTAC to the final list
|
| 405 |
+
mapped_row = get_split_row(row, substructs)
|
| 406 |
+
mapped_row['Notes'] = notes + additional_notes
|
| 407 |
+
return mapped_row
|
| 408 |
+
|
| 409 |
+
# There were no matching POIs, but some E3s and linkers matched: try to
|
| 410 |
+
# recover the E3 from an unmapped POI and a mapped Linker
|
| 411 |
+
if len(e3s) > 0 and split_with_substr_and_linker_matching: # len(pois) == 0 and
|
| 412 |
+
# NOTE: Only take the largest linker(s) into account
|
| 413 |
+
if max_iter_on_linkers:
|
| 414 |
+
selected_linkers = linkers.iloc[:max_iter_on_linkers, :]
|
| 415 |
+
else:
|
| 416 |
+
selected_linkers = linkers.iloc[:1, :]
|
| 417 |
+
|
| 418 |
+
for _, e3 in e3s.iterrows():
|
| 419 |
+
# Adjust the E3 molecule if it is not a perfect match
|
| 420 |
+
e3_mol_fixed = get_substructure_from_non_perfect_match(protac_mol, e3['Molecule'], e3_attachment_id)
|
| 421 |
+
e3_mol = e3['Molecule'] if e3_mol_fixed is None else e3_mol_fixed
|
| 422 |
+
e3_mol = remove_dummy_atoms(e3_mol)
|
| 423 |
+
if Chem.MolToSmiles(e3_mol) != e3['SMILES']:
|
| 424 |
+
non_perfect_e3_match = True
|
| 425 |
+
else:
|
| 426 |
+
non_perfect_e3_match = False
|
| 427 |
+
|
| 428 |
+
for _, linker in selected_linkers.iterrows():
|
| 429 |
+
additional_notes = f'(matching_poi=False)(matching_e3=True)(matching_linker=True)({non_perfect_e3_match=})'
|
| 430 |
+
|
| 431 |
+
substructs = get_substructs_from_substr_and_linker(
|
| 432 |
+
protac_smiles=protac_smiles,
|
| 433 |
+
protac=protac_mol,
|
| 434 |
+
substr=e3_mol,
|
| 435 |
+
linker=linker['Molecule'],
|
| 436 |
+
attachment_id=e3_attachment_id,
|
| 437 |
+
)
|
| 438 |
+
if substructs is not None:
|
| 439 |
+
size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
|
| 440 |
+
|
| 441 |
+
if not check_linker_similarity(substructs['linker'], substructs['poi'], e3s, morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
|
| 442 |
+
best_substructs_candidate = substructs
|
| 443 |
+
continue
|
| 444 |
+
|
| 445 |
+
if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
|
| 446 |
+
best_substructs_candidate = substructs
|
| 447 |
+
# display_mol(protac_mol)
|
| 448 |
+
continue
|
| 449 |
+
|
| 450 |
+
# Fix the bonds close to amide and ester groups, if necessary
|
| 451 |
+
substructs_copy = substructs.copy()
|
| 452 |
+
substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
|
| 453 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 454 |
+
additional_notes += '(amide_bonds_fixed=True)'
|
| 455 |
+
else:
|
| 456 |
+
additional_notes += '(amide_bonds_fixed=False)'
|
| 457 |
+
substructs_copy = substructs.copy()
|
| 458 |
+
substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
|
| 459 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 460 |
+
additional_notes += '(ester_bonds_fixed=True)'
|
| 461 |
+
else:
|
| 462 |
+
additional_notes += '(ester_bonds_fixed=False)'
|
| 463 |
+
|
| 464 |
+
mapped_row = get_split_row(row, substructs)
|
| 465 |
+
mapped_row['Notes'] = notes + additional_notes
|
| 466 |
+
return mapped_row
|
| 467 |
+
|
| 468 |
+
# Swap the attachment points on the linker and try again
|
| 469 |
+
linker_swapped = swap_attachment_points(linker['SMILES'])
|
| 470 |
+
substructs = get_substructs_from_substr_and_linker(
|
| 471 |
+
protac_smiles=protac_smiles,
|
| 472 |
+
protac=protac_mol,
|
| 473 |
+
substr=e3_mol,
|
| 474 |
+
linker=Chem.MolFromSmiles(linker_swapped),
|
| 475 |
+
attachment_id=e3_attachment_id,
|
| 476 |
+
)
|
| 477 |
+
additional_notes += '(attachment_points_swapped_in_linker=True)'
|
| 478 |
+
if substructs is not None:
|
| 479 |
+
|
| 480 |
+
size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
|
| 481 |
+
|
| 482 |
+
if not check_linker_similarity(substructs['linker'], substructs['poi'], e3s, morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
|
| 483 |
+
continue
|
| 484 |
+
|
| 485 |
+
if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
|
| 486 |
+
# display_mol(protac_mol)
|
| 487 |
+
continue
|
| 488 |
+
|
| 489 |
+
# Fix the bonds close to amide and ester groups, if necessary
|
| 490 |
+
substructs_copy = substructs.copy()
|
| 491 |
+
substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
|
| 492 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 493 |
+
additional_notes += '(amide_bonds_fixed=True)'
|
| 494 |
+
else:
|
| 495 |
+
additional_notes += '(amide_bonds_fixed=False)'
|
| 496 |
+
substructs_copy = substructs.copy()
|
| 497 |
+
substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
|
| 498 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 499 |
+
additional_notes += '(ester_bonds_fixed=True)'
|
| 500 |
+
|
| 501 |
+
mapped_row = get_split_row(row, substructs)
|
| 502 |
+
mapped_row['Notes'] = notes + additional_notes
|
| 503 |
+
return mapped_row
|
| 504 |
+
|
| 505 |
+
# There were no matching E3s, but some POIs and linkers matched: try to
|
| 506 |
+
# recover the POI from an unmapped E3 and a mapped Linker
|
| 507 |
+
if len(pois) > 0 and split_with_substr_and_linker_matching: # and len(e3s) == 0
|
| 508 |
+
# NOTE: Only take the largest linker(s) into account
|
| 509 |
+
if max_iter_on_linkers:
|
| 510 |
+
selected_linkers = linkers.iloc[:max_iter_on_linkers, :]
|
| 511 |
+
else:
|
| 512 |
+
selected_linkers = linkers.iloc[:1, :]
|
| 513 |
+
|
| 514 |
+
for _, poi in pois.iterrows():
|
| 515 |
+
poi_mol = get_substructure_from_non_perfect_match(protac_mol, poi['Molecule'], poi_attachment_id)
|
| 516 |
+
poi_mol = poi['Molecule'] if poi_mol is None else poi_mol
|
| 517 |
+
poi_mol = remove_dummy_atoms(poi_mol)
|
| 518 |
+
if Chem.MolToSmiles(poi_mol) != poi['SMILES']:
|
| 519 |
+
non_perfect_poi_match = True
|
| 520 |
+
else:
|
| 521 |
+
non_perfect_poi_match = False
|
| 522 |
+
|
| 523 |
+
for _, linker in selected_linkers.iterrows():
|
| 524 |
+
additional_notes = f'(matching_poi=True)(matching_e3=False)(matching_linker=True)({non_perfect_poi_match=})'
|
| 525 |
+
|
| 526 |
+
substructs = get_substructs_from_substr_and_linker(
|
| 527 |
+
protac_smiles=protac_smiles,
|
| 528 |
+
protac=protac_mol,
|
| 529 |
+
substr=poi_mol,
|
| 530 |
+
linker=linker['Molecule'],
|
| 531 |
+
attachment_id=poi_attachment_id,
|
| 532 |
+
)
|
| 533 |
+
if substructs is not None:
|
| 534 |
+
size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
|
| 535 |
+
|
| 536 |
+
if not check_linker_similarity(substructs['linker'], pois, substructs['e3'], morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
|
| 537 |
+
best_substructs_candidate = substructs
|
| 538 |
+
continue
|
| 539 |
+
|
| 540 |
+
if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
|
| 541 |
+
best_substructs_candidate = substructs
|
| 542 |
+
# display_mol(protac_mol)
|
| 543 |
+
continue
|
| 544 |
+
|
| 545 |
+
# Fix the bonds close to amide and ester groups, if necessary
|
| 546 |
+
substructs_copy = substructs.copy()
|
| 547 |
+
substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
|
| 548 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 549 |
+
additional_notes += '(amide_bonds_fixed=True)'
|
| 550 |
+
substructs_copy = substructs.copy()
|
| 551 |
+
substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
|
| 552 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 553 |
+
additional_notes += '(ester_bonds_fixed=True)'
|
| 554 |
+
|
| 555 |
+
mapped_row = get_split_row(row, substructs)
|
| 556 |
+
mapped_row['Notes'] = notes + additional_notes
|
| 557 |
+
return mapped_row
|
| 558 |
+
|
| 559 |
+
# Swap the attachment points on the linker and try again
|
| 560 |
+
linker_swapped = swap_attachment_points(linker['SMILES'])
|
| 561 |
+
substructs = get_substructs_from_substr_and_linker(
|
| 562 |
+
protac_smiles=protac_smiles,
|
| 563 |
+
protac=protac_mol,
|
| 564 |
+
substr=poi_mol,
|
| 565 |
+
linker=Chem.MolFromSmiles(linker_swapped),
|
| 566 |
+
attachment_id=poi_attachment_id,
|
| 567 |
+
)
|
| 568 |
+
additional_notes += '(attachment_points_swapped_in_linker=True)'
|
| 569 |
+
if substructs is not None:
|
| 570 |
+
|
| 571 |
+
size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
|
| 572 |
+
|
| 573 |
+
if not check_linker_similarity(substructs['linker'], substructs['poi'], e3s, morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
|
| 574 |
+
best_substructs_candidate = substructs
|
| 575 |
+
continue
|
| 576 |
+
|
| 577 |
+
if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
|
| 578 |
+
best_substructs_candidate = substructs
|
| 579 |
+
# display_mol(protac_mol)
|
| 580 |
+
continue
|
| 581 |
+
|
| 582 |
+
# Fix the bonds close to amide and ester groups, if necessary
|
| 583 |
+
substructs_copy = substructs.copy()
|
| 584 |
+
substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
|
| 585 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 586 |
+
additional_notes += '(amide_bonds_fixed=True)'
|
| 587 |
+
substructs_copy = substructs.copy()
|
| 588 |
+
substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
|
| 589 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 590 |
+
additional_notes += '(ester_bonds_fixed=True)'
|
| 591 |
+
|
| 592 |
+
mapped_row = get_split_row(row, substructs)
|
| 593 |
+
mapped_row['Notes'] = notes + additional_notes
|
| 594 |
+
return mapped_row
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
# Get all substructure matches in the Linker with direction dictionary
|
| 598 |
+
# NOTE: This code is repeated here for performance reasons, to avoid
|
| 599 |
+
# calculating the matches if not needed.
|
| 600 |
+
if linkers is None and max_iter_on_linkers:
|
| 601 |
+
matches = dictionaries['Linker with direction']['Molecule'].apply(lambda x: get_substr_match(protac_mol, x, num_allowed_fragments=2))
|
| 602 |
+
linkers = dictionaries['Linker with direction'][matches]
|
| 603 |
+
# Sort all the matches by the number of atoms in the linker, the biggest first
|
| 604 |
+
linkers = linkers.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=False)
|
| 605 |
+
|
| 606 |
+
# for j, (_, linker) in enumerate(linkers.iterrows()):
|
| 607 |
+
# additional_notes = '(matching_poi=False)(matching_e3=False)(matching_linker=True)'
|
| 608 |
+
# if j >= max_iter_on_linkers or max_iter_on_linkers == 0:
|
| 609 |
+
# return None
|
| 610 |
+
|
| 611 |
+
for j in range(max_iter_on_linkers):
|
| 612 |
+
additional_notes = '(matching_poi=False)(matching_e3=False)(matching_linker=True)'
|
| 613 |
+
linker = linkers.iloc[j, :]
|
| 614 |
+
substructs = get_substructs_from_mapped_linker(protac_smiles, linker['SMILES'])
|
| 615 |
+
|
| 616 |
+
if substructs is not None:
|
| 617 |
+
if not check_linker_similarity(substructs['linker'], substructs['poi'], substructs['e3'], morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
|
| 618 |
+
best_substructs_candidate = substructs
|
| 619 |
+
continue
|
| 620 |
+
|
| 621 |
+
size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
|
| 622 |
+
if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
|
| 623 |
+
best_substructs_candidate = substructs
|
| 624 |
+
# display_mol(protac_mol)
|
| 625 |
+
continue
|
| 626 |
+
|
| 627 |
+
# Fix the bonds close to amide and ester groups, if necessary
|
| 628 |
+
substructs_copy = substructs.copy()
|
| 629 |
+
substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
|
| 630 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 631 |
+
additional_notes += '(amide_bonds_fixed=True)'
|
| 632 |
+
substructs_copy = substructs.copy()
|
| 633 |
+
substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
|
| 634 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 635 |
+
additional_notes += '(ester_bonds_fixed=True)'
|
| 636 |
+
|
| 637 |
+
if not check_substructs_size(protac_mol, substructs, size_perc_threshold=0.95):
|
| 638 |
+
best_substructs_candidate = substructs
|
| 639 |
+
continue
|
| 640 |
+
|
| 641 |
+
mapped_row = get_split_row(row, substructs)
|
| 642 |
+
mapped_row['Notes'] = notes + additional_notes
|
| 643 |
+
return mapped_row
|
| 644 |
+
|
| 645 |
+
# If we are here, it means that the substructures found in the above loops
|
| 646 |
+
# failed the similarity checks. We add a note and return the best
|
| 647 |
+
# substructure candidate found.
|
| 648 |
+
if best_substructs_candidate is not None:
|
| 649 |
+
substructs_copy = substructs.copy()
|
| 650 |
+
substructs = adjust_amide_bonds_in_substructs(best_substructs_candidate, protac_smiles)
|
| 651 |
+
if substructs['linker'] != best_substructs_candidate['linker']:
|
| 652 |
+
notes += '(amide_bonds_fixed=True)'
|
| 653 |
+
substructs_copy = substructs.copy()
|
| 654 |
+
substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
|
| 655 |
+
if substructs['linker'] != substructs_copy['linker']:
|
| 656 |
+
notes += '(ester_bonds_fixed=True)'
|
| 657 |
+
mapped_row = get_split_row(row, substructs)
|
| 658 |
+
mapped_row['Notes'] = notes + '(similarity_checks_failed=True)'
|
| 659 |
+
return mapped_row
|
| 660 |
+
|
| 661 |
+
return None
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def split_protacs(
|
| 665 |
+
protac_df: pd.DataFrame,
|
| 666 |
+
dictionaries: Dict[str, pd.DataFrame],
|
| 667 |
+
max_iter_on_linkers: int = 0,
|
| 668 |
+
split_with_substr_and_linker_matching: bool = False,
|
| 669 |
+
biggest_matches_first: bool = True,
|
| 670 |
+
update_dict_if_ids_not_found: bool = False,
|
| 671 |
+
use_multiprocessing: bool = False,
|
| 672 |
+
) -> pd.DataFrame:
|
| 673 |
+
""" Maps PROTACs to their substructures.
|
| 674 |
+
|
| 675 |
+
Args:
|
| 676 |
+
protac_df: The input PROTAC dataframe.
|
| 677 |
+
dictionaries: The input dictionaries.
|
| 678 |
+
max_iter_on_linkers: The maximum number of matching linkers to iterate over. If zero, there will be no attempt to match linkers in the dictionary. If negative, iterate over all matched linkers. Default is 0.
|
| 679 |
+
biggest_matches_first: Whether to sort the matches by the number of atoms in the molecule. Default is True.
|
| 680 |
+
update_dict_if_ids_not_found: DEPRECATED. Whether to update the dictionary if the substructure IDs are not found. Default is False.
|
| 681 |
+
use_multiprocessing: Whether to use multiprocessing. Default is False.
|
| 682 |
+
|
| 683 |
+
Returns:
|
| 684 |
+
The mapped PROTAC dataframe.
|
| 685 |
+
"""
|
| 686 |
+
# if use_multiprocessing:
|
| 687 |
+
# global split_single_protac
|
| 688 |
+
|
| 689 |
+
# with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
|
| 690 |
+
# results = pool.map(partial(split_single_protac, dictionaries=dictionaries, biggest_matches_first=biggest_matches_first, max_iter_on_linkers=max_iter_on_linkers), protac_df.copy().to_dict(orient='records'))
|
| 691 |
+
|
| 692 |
+
# mapped_protacs = pd.DataFrame(results)
|
| 693 |
+
# mapped_protacs = mapped_protacs.dropna(subset=['POI Ligand SMILES with direction', 'E3 Binder SMILES with direction', 'Linker SMILES with direction'])
|
| 694 |
+
# return mapped_protacs
|
| 695 |
+
|
| 696 |
+
if use_multiprocessing:
|
| 697 |
+
# TODO: The following does run in parallel, but it gives wrong results. I don't know why. I will have to investigate further.
|
| 698 |
+
results = Parallel(n_jobs=-1)(delayed(split_single_protac)(row, dictionaries=dictionaries, biggest_matches_first=biggest_matches_first, max_iter_on_linkers=max_iter_on_linkers) for _, row in protac_df.iterrows())
|
| 699 |
+
mapped_protacs = pd.DataFrame([r for r in results if r is not None])
|
| 700 |
+
return mapped_protacs
|
| 701 |
+
|
| 702 |
+
mapped_protacs = []
|
| 703 |
+
for i, row in (pbar := tqdm(protac_df.iterrows(), total=len(protac_df))):
|
| 704 |
+
pbar.set_description(f'PROTAC n.{i:4d}')
|
| 705 |
+
|
| 706 |
+
r = split_single_protac(
|
| 707 |
+
row,
|
| 708 |
+
dictionaries,
|
| 709 |
+
biggest_matches_first=biggest_matches_first,
|
| 710 |
+
max_iter_on_linkers=max_iter_on_linkers,
|
| 711 |
+
split_with_substr_and_linker_matching=split_with_substr_and_linker_matching,
|
| 712 |
+
)
|
| 713 |
+
if r is not None:
|
| 714 |
+
mapped_protacs.append(r)
|
| 715 |
+
tmp = pd.DataFrame(mapped_protacs)
|
| 716 |
+
pbar.set_postfix({'len_mapped': len(tmp), 'perc_mapped': f'{len(tmp) / len(protac_df):.1%}'})
|
| 717 |
+
|
| 718 |
+
mapped_protacs = pd.DataFrame(mapped_protacs)
|
| 719 |
+
return mapped_protacs
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def parse_notes(notes: str) -> Dict[str, Any]:
|
| 723 |
+
# Define the regex pattern to match key-value pairs within parentheses
|
| 724 |
+
pattern = r'\(([^=]+)=([^\)]+)\)'
|
| 725 |
+
|
| 726 |
+
# Find all matches in the string
|
| 727 |
+
matches = re.findall(pattern, notes)
|
| 728 |
+
|
| 729 |
+
# Initialize an empty dictionary to store the parsed key-value pairs
|
| 730 |
+
parsed_dict = {}
|
| 731 |
+
|
| 732 |
+
# Iterate over the matches and add them to the dictionary
|
| 733 |
+
for key, value in matches:
|
| 734 |
+
# Convert the value to the appropriate type (int, bool, None, or str)
|
| 735 |
+
if value.isdigit():
|
| 736 |
+
parsed_dict[key] = int(value)
|
| 737 |
+
elif value.lower() == 'true':
|
| 738 |
+
parsed_dict[key] = True
|
| 739 |
+
elif value.lower() == 'false':
|
| 740 |
+
parsed_dict[key] = False
|
| 741 |
+
elif value.lower() == 'none':
|
| 742 |
+
parsed_dict[key] = None
|
| 743 |
+
else:
|
| 744 |
+
parsed_dict[key] = value
|
| 745 |
+
|
| 746 |
+
return parsed_dict
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def iterative_protac_splitting(
|
| 750 |
+
dictionaries: Dict[str, pd.DataFrame],
|
| 751 |
+
data_dir: str,
|
| 752 |
+
) -> Dict[str, pd.DataFrame]:
|
| 753 |
+
""" Map PROTACs to their substructures in an iterative way.
|
| 754 |
+
|
| 755 |
+
Args:
|
| 756 |
+
dictionaries: The input dictionaries. The same format as the output of the `update_dictionary` function.
|
| 757 |
+
data_dir: The directory where the output data is stored.
|
| 758 |
+
|
| 759 |
+
Returns:
|
| 760 |
+
The final mapped PROTAC dataframe.
|
| 761 |
+
"""
|
| 762 |
+
|
| 763 |
+
final_df = None
|
| 764 |
+
non_mapped_protacs = dictionaries['PROTAC'].copy()
|
| 765 |
+
|
| 766 |
+
start_from_beginning = True # Re-map all PROTACs ignoring loading previous results
|
| 767 |
+
step = -1
|
| 768 |
+
max_iter_on_linkers = 0
|
| 769 |
+
split_with_substr_and_linker_matching = False
|
| 770 |
+
|
| 771 |
+
while True:
|
| 772 |
+
if max_iter_on_linkers == -1 or non_mapped_protacs.empty or step >= 50:
|
| 773 |
+
break
|
| 774 |
+
|
| 775 |
+
if max_iter_on_linkers == 5:
|
| 776 |
+
max_iter_on_linkers = -1 # Iterate over all linkers
|
| 777 |
+
|
| 778 |
+
step += 1
|
| 779 |
+
print('-' * 100)
|
| 780 |
+
print(f'Step n.{step}')
|
| 781 |
+
print(f'Max iterations on linkers: {max_iter_on_linkers}')
|
| 782 |
+
print(f'Map with substr and linker matching: {split_with_substr_and_linker_matching}')
|
| 783 |
+
print('-' * 50)
|
| 784 |
+
|
| 785 |
+
step_filename = os.path.join(data_dir, f'mapped_protacs_{step=}.csv')
|
| 786 |
+
final_filename = os.path.join(data_dir, 'mapped_protacs.csv')
|
| 787 |
+
non_mapped_filename = os.path.join(data_dir, 'non_mapped_protacs.csv')
|
| 788 |
+
|
| 789 |
+
if os.path.exists(step_filename) and not start_from_beginning:
|
| 790 |
+
# Check if all lines of the file are empty
|
| 791 |
+
with open(step_filename, 'r') as f:
|
| 792 |
+
lines = f.readlines()
|
| 793 |
+
if all([len(line.strip()) == 0 for line in lines]):
|
| 794 |
+
mapped_protacs = pd.DataFrame()
|
| 795 |
+
else:
|
| 796 |
+
mapped_protacs = pd.read_csv(step_filename)
|
| 797 |
+
else:
|
| 798 |
+
mapped_protacs = split_protacs(
|
| 799 |
+
non_mapped_protacs,
|
| 800 |
+
dictionaries=dictionaries,
|
| 801 |
+
split_with_substr_and_linker_matching=split_with_substr_and_linker_matching,
|
| 802 |
+
max_iter_on_linkers=max_iter_on_linkers,
|
| 803 |
+
biggest_matches_first=False,
|
| 804 |
+
use_multiprocessing=False,
|
| 805 |
+
)
|
| 806 |
+
# Add a string at the end of the strings in the 'Notes' column
|
| 807 |
+
if not mapped_protacs.empty:
|
| 808 |
+
mapped_protacs['Notes'] = mapped_protacs['Notes'].apply(lambda x: f'{x}({step=})')
|
| 809 |
+
mapped_protacs.to_csv(step_filename, index=False)
|
| 810 |
+
|
| 811 |
+
# Update the final dataframe and save it to file
|
| 812 |
+
if final_df is None:
|
| 813 |
+
final_df = mapped_protacs
|
| 814 |
+
else:
|
| 815 |
+
final_df = pd.concat([final_df, mapped_protacs], axis=0).drop_duplicates(subset=['PROTAC SMILES'])
|
| 816 |
+
final_df.to_csv(final_filename, index=False)
|
| 817 |
+
print(f'All mapped PROTACs saved to: {final_filename}')
|
| 818 |
+
|
| 819 |
+
# Reporting information
|
| 820 |
+
mapped_perc = len(mapped_protacs) / len(non_mapped_protacs)
|
| 821 |
+
total_mapped_perc = len(final_df) / len(dictionaries['PROTAC'])
|
| 822 |
+
print(f'Number of mapped PROTACs: {len(mapped_protacs)} ({mapped_perc:.2%})')
|
| 823 |
+
print(f'Total num. of mapped PROTACs: {len(final_df)} ({total_mapped_perc:.2%})')
|
| 824 |
+
print('-' * 50)
|
| 825 |
+
print(final_df['Notes'].value_counts())
|
| 826 |
+
print('-' * 50)
|
| 827 |
+
|
| 828 |
+
# Get the non-mapped PROTACs yet and save them to file
|
| 829 |
+
non_mapped_protacs = dictionaries['PROTAC'][~dictionaries['PROTAC']['SMILES'].isin(final_df['PROTAC SMILES'])].copy()
|
| 830 |
+
non_mapped_protacs[['SMILES', 'ID']].to_csv(non_mapped_filename, index=False)
|
| 831 |
+
print(f'Non-mapped PROTACs saved to: {non_mapped_filename}')
|
| 832 |
+
|
| 833 |
+
# Control logic for breaking the loop
|
| 834 |
+
if mapped_protacs.empty:
|
| 835 |
+
if max_iter_on_linkers == 0 and not split_with_substr_and_linker_matching:
|
| 836 |
+
split_with_substr_and_linker_matching = True
|
| 837 |
+
continue
|
| 838 |
+
else:
|
| 839 |
+
max_iter_on_linkers += 1
|
| 840 |
+
continue
|
| 841 |
+
else:
|
| 842 |
+
# Using only the linker to map the PROTACs can be unreliable, so if we
|
| 843 |
+
# found new PROTACs, we should the max_iter_on_linkers to zero and try
|
| 844 |
+
# to map the PROTACs again with the newly found substructures.
|
| 845 |
+
max_iter_on_linkers = 0
|
| 846 |
+
split_with_substr_and_linker_matching = False
|
| 847 |
+
|
| 848 |
+
# Update all dictionaries with the substructures of the mapped PROTACs
|
| 849 |
+
smiles_list = mapped_protacs['Linker SMILES with direction'].unique()
|
| 850 |
+
smiles_list = [canonize(smiles) for smiles in smiles_list]
|
| 851 |
+
dictionaries['Linker with direction'] = update_dictionary(dictionaries['Linker with direction'], smiles_list)
|
| 852 |
+
|
| 853 |
+
# Avoid adding POIs that are in the E3 dictionary!
|
| 854 |
+
smiles_list = mapped_protacs['POI Ligand SMILES'].unique()
|
| 855 |
+
smiles_list = [canonize(smiles) for smiles in smiles_list]
|
| 856 |
+
smiles_list = [s for s in smiles_list if s not in dictionaries['E3 Binder']['SMILES'].values]
|
| 857 |
+
|
| 858 |
+
smiles_list = [remove_dummy_atoms(s) for s in smiles_list if s is not None]
|
| 859 |
+
|
| 860 |
+
# Use Tanimoto similarity to prevent adding POIs too similar to E3s
|
| 861 |
+
similarity_threshold = 0.5
|
| 862 |
+
radius = 2
|
| 863 |
+
nbits = 2048
|
| 864 |
+
morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=nbits, useBondTypes=True, includeChirality=True)
|
| 865 |
+
|
| 866 |
+
pois_to_add = []
|
| 867 |
+
for poi_smiles in smiles_list:
|
| 868 |
+
poi_mol = Chem.MolFromSmiles(poi_smiles)
|
| 869 |
+
poi_fp = morgan_fp_generator.GetFingerprint(poi_mol)
|
| 870 |
+
similarities = DataStructs.BulkTanimotoSimilarity(poi_fp, dictionaries['E3 Binder']['FP'].to_list())
|
| 871 |
+
skip_poi = False
|
| 872 |
+
for sim in similarities:
|
| 873 |
+
if sim >= similarity_threshold:
|
| 874 |
+
skip_poi = True
|
| 875 |
+
break
|
| 876 |
+
if not skip_poi:
|
| 877 |
+
pois_to_add.append(poi_smiles)
|
| 878 |
+
|
| 879 |
+
dictionaries['POI Ligand'] = update_dictionary(dictionaries['POI Ligand'], smiles_list)
|
| 880 |
+
|
| 881 |
+
# Avoid adding E3s that are in the POI dictionary!
|
| 882 |
+
smiles_list = mapped_protacs['E3 Binder SMILES'].unique()
|
| 883 |
+
smiles_list = [canonize(smiles) for smiles in smiles_list]
|
| 884 |
+
smiles_list = [s for s in smiles_list if s not in dictionaries['POI Ligand']['SMILES'].values]
|
| 885 |
+
smiles_list = [remove_dummy_atoms(s) for s in smiles_list if s is not None]
|
| 886 |
+
dictionaries['E3 Binder'] = update_dictionary(dictionaries['E3 Binder'], smiles_list)
|
| 887 |
+
|
| 888 |
+
# Save all dictionaries to file
|
| 889 |
+
for key, dictionary in dictionaries.items():
|
| 890 |
+
filename = os.path.join(data_dir, f'dictionary_{key.lower().replace(" ", "_")}.csv')
|
| 891 |
+
dictionary[['ID', 'SMILES']].to_csv(filename, index=False)
|
| 892 |
+
print(f'Dictionary saved to: {filename}')
|
| 893 |
+
|
| 894 |
+
return dictionaries
|
protac_splitter/data/curation/mapping_utils.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rdkit import Chem
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
from protac_splitter.chemoinformatics import (
|
| 5 |
+
canonize_smiles,
|
| 6 |
+
remove_stereo,
|
| 7 |
+
get_mol_id,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
def update_dictionary(
|
| 11 |
+
dictionary: pd.DataFrame,
|
| 12 |
+
substr_to_add: list,
|
| 13 |
+
morgan_fp_generator = None,
|
| 14 |
+
verbose: int = 0,
|
| 15 |
+
) -> pd.DataFrame:
|
| 16 |
+
""" Updates a dictionary with a list of additional substructures.
|
| 17 |
+
|
| 18 |
+
The dictionary is a dataframe with columns 'SMILES', 'Molecule', 'ID', and 'FP'.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
dictionary: The input dictionary dataframe.
|
| 22 |
+
substr_to_add: The list of additional substructures.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
The updated dictionary dataframe.
|
| 26 |
+
"""
|
| 27 |
+
# Canonize the SMILES strings
|
| 28 |
+
substr_to_add = [canonize_smiles(smiles) for smiles in substr_to_add if smiles is not None]
|
| 29 |
+
substr_to_add = list(set(substr_to_add))
|
| 30 |
+
|
| 31 |
+
# Remove entries already in the dictionary
|
| 32 |
+
for smiles in substr_to_add:
|
| 33 |
+
if not dictionary.empty and smiles in dictionary[f'SMILES'].unique().tolist():
|
| 34 |
+
if verbose > 1:
|
| 35 |
+
print(f'\tWARNING. SMILES already in the dictionary: {smiles}')
|
| 36 |
+
# Remove it from the list
|
| 37 |
+
substr_to_add.remove(smiles)
|
| 38 |
+
|
| 39 |
+
new_entries = []
|
| 40 |
+
for smiles in substr_to_add:
|
| 41 |
+
try:
|
| 42 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
if verbose:
|
| 45 |
+
print(e)
|
| 46 |
+
mol = None
|
| 47 |
+
# Remove entries that result in invalid molecules
|
| 48 |
+
if mol is None:
|
| 49 |
+
continue
|
| 50 |
+
new_entries.append({
|
| 51 |
+
'SMILES': smiles,
|
| 52 |
+
'Molecule': mol,
|
| 53 |
+
'ID': get_mol_id(smiles),
|
| 54 |
+
})
|
| 55 |
+
# Try adding its no-stereochemistry version as well
|
| 56 |
+
smiles_nostereo = remove_stereo(smiles)
|
| 57 |
+
if smiles_nostereo is not None and smiles_nostereo != smiles:
|
| 58 |
+
mol_nostereo = Chem.MolFromSmiles(smiles_nostereo)
|
| 59 |
+
if mol_nostereo is not None:
|
| 60 |
+
new_entries.append({
|
| 61 |
+
'SMILES': canonize_smiles(smiles_nostereo),
|
| 62 |
+
'Molecule': mol_nostereo,
|
| 63 |
+
'ID': get_mol_id(smiles_nostereo),
|
| 64 |
+
})
|
| 65 |
+
new_entries = pd.DataFrame(new_entries).drop_duplicates()
|
| 66 |
+
|
| 67 |
+
if len(new_entries) > 0:
|
| 68 |
+
# Add fingerprints to the new entries
|
| 69 |
+
if morgan_fp_generator is None:
|
| 70 |
+
morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048, useBondTypes=True, includeChirality=True)
|
| 71 |
+
|
| 72 |
+
new_entries['FP'] = new_entries['Molecule'].apply(lambda x: morgan_fp_generator.GetFingerprint(x) if x is not None else None)
|
| 73 |
+
if verbose:
|
| 74 |
+
print(f'Number of substructures added to the dictionary: {len(new_entries)}')
|
| 75 |
+
|
| 76 |
+
# Return the updated dictionary
|
| 77 |
+
return pd.concat([dictionary, pd.DataFrame(new_entries)], axis=0).drop_duplicates(subset='SMILES').reset_index(drop=True)
|
protac_splitter/data/curation/substructure_extraction.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Any, Dict, List, Optional, Union
|
| 3 |
+
from collections import Counter
|
| 4 |
+
|
| 5 |
+
from rdkit import Chem
|
| 6 |
+
from rdkit.Chem import Draw
|
| 7 |
+
|
| 8 |
+
from protac_splitter.chemoinformatics import (
|
| 9 |
+
dummy2query,
|
| 10 |
+
remove_dummy_atoms,
|
| 11 |
+
canonize,
|
| 12 |
+
canonize_smiles,
|
| 13 |
+
GetSubstructMatchesWithTimeout,
|
| 14 |
+
)
|
| 15 |
+
from protac_splitter.display_utils import (
|
| 16 |
+
safe_display,
|
| 17 |
+
display_mol,
|
| 18 |
+
)
|
| 19 |
+
from protac_splitter.evaluation import check_reassembly
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_substructs_from_mapped_linker(
|
| 23 |
+
protac_smiles: str,
|
| 24 |
+
linker_smiles: str,
|
| 25 |
+
e3_attachment_id: int = 2,
|
| 26 |
+
poi_attachment_id: int = 1,
|
| 27 |
+
verbose: int = 0,
|
| 28 |
+
) -> Dict[str, str]:
|
| 29 |
+
""" Get the substructures of a PROTAC molecule from a mapped linker SMILES.
|
| 30 |
+
|
| 31 |
+
This function will return the substructures given a linker with
|
| 32 |
+
directionality, _i.e._, with the two attachment points mapped.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
protac_smiles: The SMILES of the PROTAC molecule.
|
| 36 |
+
linker_smiles: The SMILES of the linker molecule. Must have attachment points.
|
| 37 |
+
verbose: Verbosity level.
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
A dictionary with the substructure names as keys ('e3', 'linker', and 'poi') and their SMILES as values. None if the matching fails.
|
| 41 |
+
"""
|
| 42 |
+
protac_smiles = canonize_smiles(protac_smiles)
|
| 43 |
+
linker_smiles = canonize_smiles(linker_smiles)
|
| 44 |
+
|
| 45 |
+
protac_mol = Chem.MolFromSmiles(protac_smiles)
|
| 46 |
+
linker_mol = Chem.MolFromSmiles(linker_smiles)
|
| 47 |
+
|
| 48 |
+
# Check if the linker is a substructure of the PROTAC
|
| 49 |
+
if not protac_mol.HasSubstructMatch(dummy2query(linker_mol), useChirality=True):
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
# Split the big molecule into the two fragments
|
| 53 |
+
frags = Chem.ReplaceCore(protac_mol, dummy2query(linker_mol), labelByIndex=True, replaceDummies=False)
|
| 54 |
+
if frags is None:
|
| 55 |
+
return None
|
| 56 |
+
try:
|
| 57 |
+
frags = Chem.GetMolFrags(frags, asMols=True, sanitizeFrags=True)
|
| 58 |
+
except Exception as e:
|
| 59 |
+
# print(e)
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
if verbose:
|
| 63 |
+
safe_display(protac_mol)
|
| 64 |
+
safe_display(linker_mol)
|
| 65 |
+
|
| 66 |
+
# The linker has a map number at its attachment points: the following is a
|
| 67 |
+
# dictionary that maps the atom index of the attachment points to their
|
| 68 |
+
# respective map numbers, i.e., the attachment IDs.
|
| 69 |
+
linker_idx2map = {}
|
| 70 |
+
for atom in linker_mol.GetAtoms():
|
| 71 |
+
if atom.GetAtomicNum() == 0:
|
| 72 |
+
linker_idx2map[atom.GetIdx()] = atom.GetAtomMapNum()
|
| 73 |
+
if verbose:
|
| 74 |
+
print(f'linker indexes: {linker_idx2map}')
|
| 75 |
+
print('-' * 80)
|
| 76 |
+
|
| 77 |
+
substructs = {'linker': linker_smiles}
|
| 78 |
+
|
| 79 |
+
# After splitting the PROTAC with ReplaceCore, the fragments will have as
|
| 80 |
+
# attachment points the same atom indexes as the linker. We can then use the
|
| 81 |
+
# map numbers from the linker to identify the attachment points in the
|
| 82 |
+
# PROTAC fragments and assign the correct map number to them, i.e., the
|
| 83 |
+
# attachment ID.
|
| 84 |
+
for i, side_mol in enumerate(frags):
|
| 85 |
+
|
| 86 |
+
side_smiles = Chem.MolToSmiles(side_mol, canonical=True)
|
| 87 |
+
|
| 88 |
+
# Use a regex to get the number in the pattern, e.g., [9*], in the SMILES
|
| 89 |
+
attachment_point = re.findall(r'\[(\d+)\*\]', side_smiles)
|
| 90 |
+
if attachment_point:
|
| 91 |
+
attachment_point = int(attachment_point[0])
|
| 92 |
+
else:
|
| 93 |
+
attachment_point = None
|
| 94 |
+
|
| 95 |
+
if verbose:
|
| 96 |
+
print(f'Side {i + 1} SMILES: {side_smiles}')
|
| 97 |
+
print(f'Attachment point: {attachment_point}')
|
| 98 |
+
safe_display(side_mol)
|
| 99 |
+
|
| 100 |
+
# Get the map from the linker
|
| 101 |
+
linker_attachment_point = linker_idx2map.get(attachment_point, None)
|
| 102 |
+
|
| 103 |
+
# Modify the SMILES to include the map number
|
| 104 |
+
if linker_attachment_point is not None:
|
| 105 |
+
side_smiles = re.sub(r'\[(\d+)\*\]', f'[*:{linker_attachment_point}]', side_smiles)
|
| 106 |
+
if f'[*:{e3_attachment_id}]' in side_smiles:
|
| 107 |
+
substructs['e3'] = canonize_smiles(side_smiles)
|
| 108 |
+
elif f'[*:{poi_attachment_id}]' in side_smiles:
|
| 109 |
+
substructs['poi'] = canonize_smiles(side_smiles)
|
| 110 |
+
|
| 111 |
+
if verbose:
|
| 112 |
+
print(f'Modified SMILES: {side_smiles}')
|
| 113 |
+
safe_display(Chem.MolFromSmiles(side_smiles))
|
| 114 |
+
|
| 115 |
+
# Canonize the substructures SMILES
|
| 116 |
+
substructs = {k: canonize_smiles(v) for k, v in substructs.items()}
|
| 117 |
+
|
| 118 |
+
# Check that the reassembled PROTAC matches the original PROTAC
|
| 119 |
+
if not check_reassembly(protac_smiles, '.'.join(substructs.values())):
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
return substructs
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_attachment_bonds(mol: Chem.Mol, match_atoms: List[int]) -> List[int]:
|
| 126 |
+
""" Get the bonds to break to separate the substructure from the PROTAC or R-groups molecule.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
mol: The molecule to break, i.e., the PROTAC.
|
| 130 |
+
match_atoms: The atoms matched in the PROTAC molecule, from the GetSubstructMatch function.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
List[int]: The bond indices to break.
|
| 134 |
+
"""
|
| 135 |
+
bonds_to_break = []
|
| 136 |
+
for idx in match_atoms:
|
| 137 |
+
atom = mol.GetAtomWithIdx(idx)
|
| 138 |
+
# Skip non-heavy atoms
|
| 139 |
+
if atom.GetAtomicNum() == 1:
|
| 140 |
+
continue
|
| 141 |
+
for bond in atom.GetBonds():
|
| 142 |
+
neighbor_idx = bond.GetOtherAtomIdx(idx)
|
| 143 |
+
# Skip if the neighbor atom if non-heavy
|
| 144 |
+
if mol.GetAtomWithIdx(neighbor_idx).GetAtomicNum() == 1:
|
| 145 |
+
continue
|
| 146 |
+
if neighbor_idx not in match_atoms:
|
| 147 |
+
bonds_to_break.append(bond.GetIdx())
|
| 148 |
+
# If more than one bond is found, e.g., if the substructure is
|
| 149 |
+
# connected to the PROTAC/R-groups in multiple places like in a
|
| 150 |
+
# ring, reset list of bonds and go to the next atom.
|
| 151 |
+
if len(bonds_to_break) > 1:
|
| 152 |
+
bonds_to_break = []
|
| 153 |
+
break
|
| 154 |
+
return bonds_to_break
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_substructs_from_unmapped_e3_poi(
|
| 158 |
+
protac_smiles: str,
|
| 159 |
+
mol_protac: Chem.Mol,
|
| 160 |
+
mol_poi: Chem.Mol,
|
| 161 |
+
mol_e3: Chem.Mol,
|
| 162 |
+
poi_attachment_id: int = 1,
|
| 163 |
+
e3_attachment_id: int = 2,
|
| 164 |
+
verbose: int = 0,
|
| 165 |
+
stats: Counter = None,
|
| 166 |
+
) -> Optional[Dict[str, str]]:
|
| 167 |
+
""" Get the matches of the POI, E3, and linker in the PROTAC molecule.
|
| 168 |
+
|
| 169 |
+
This function will return the substructures given a PROTAC and its unmapped
|
| 170 |
+
POI and E3 ligand substructures, _i.e._, they do not need to have the
|
| 171 |
+
attachment points in their SMILES strings.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
mol_protac: The PROTAC molecule.
|
| 175 |
+
mol_poi: The POI ligand molecule. Must NOT contain the attachment point.
|
| 176 |
+
mol_e3: The E3 binder molecule. Must NOT contain the attachment point.
|
| 177 |
+
verbose: The verbosity level.
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Dict: The matches of the POI, E3, and linker in the PROTAC molecule. None if no match is found.
|
| 181 |
+
"""
|
| 182 |
+
if verbose:
|
| 183 |
+
safe_display(mol_protac)
|
| 184 |
+
|
| 185 |
+
poi_match = mol_protac.GetSubstructMatch(mol_poi, useChirality=True)
|
| 186 |
+
|
| 187 |
+
# Get bonds to break to separate the POI ligand
|
| 188 |
+
bonds_to_break_poi = get_attachment_bonds(mol_protac, poi_match)
|
| 189 |
+
|
| 190 |
+
# Return if no bonds are found
|
| 191 |
+
if len(bonds_to_break_poi) != 1:
|
| 192 |
+
if stats is not None:
|
| 193 |
+
stats['multiple POI attachment bonds'] += 1
|
| 194 |
+
if verbose:
|
| 195 |
+
print('ERROR: Multiple POI attachment bonds')
|
| 196 |
+
return None
|
| 197 |
+
|
| 198 |
+
# Break the bonds to isolate the POI ligand
|
| 199 |
+
frag_mol_poi = Chem.FragmentOnBonds(mol_protac, bonds_to_break_poi, addDummies=True, dummyLabels=[(poi_attachment_id, poi_attachment_id)])
|
| 200 |
+
|
| 201 |
+
# Get the fragments resulting from bond breaking
|
| 202 |
+
try:
|
| 203 |
+
frags = Chem.GetMolFrags(frag_mol_poi, asMols=True, sanitizeFrags=True)
|
| 204 |
+
except Exception as e:
|
| 205 |
+
print(e)
|
| 206 |
+
return None
|
| 207 |
+
|
| 208 |
+
# Identify the POI ligand fragment
|
| 209 |
+
poi_fragment = None
|
| 210 |
+
for frag in frags:
|
| 211 |
+
if frag.HasSubstructMatch(mol_poi):
|
| 212 |
+
poi_fragment = frag
|
| 213 |
+
break
|
| 214 |
+
if poi_fragment is None:
|
| 215 |
+
if stats is not None:
|
| 216 |
+
stats['POI fragment not found'] += 1
|
| 217 |
+
if verbose:
|
| 218 |
+
print('ERROR: POI fragment not found')
|
| 219 |
+
return None
|
| 220 |
+
|
| 221 |
+
# Combine the remaining fragments to get the R-groups
|
| 222 |
+
# TODO: Check that the length of frags is 1, otherwise, there are multiple fragments
|
| 223 |
+
r_group_mol = [frag for frag in frags if frag != poi_fragment]
|
| 224 |
+
if len(r_group_mol) != 1:
|
| 225 |
+
if stats is not None:
|
| 226 |
+
stats['multiple POI fragments'] += 1
|
| 227 |
+
if verbose:
|
| 228 |
+
for frag in frags:
|
| 229 |
+
safe_display(frag)
|
| 230 |
+
print('ERROR: Multiple POI fragments')
|
| 231 |
+
return None
|
| 232 |
+
r_group_mol = r_group_mol[0]
|
| 233 |
+
|
| 234 |
+
if verbose:
|
| 235 |
+
print('POI:', Chem.MolToSmiles(poi_fragment, canonical=True))
|
| 236 |
+
safe_display(poi_fragment)
|
| 237 |
+
|
| 238 |
+
e3_match = r_group_mol.GetSubstructMatch(mol_e3, useChirality=True)
|
| 239 |
+
|
| 240 |
+
# Get bonds to break to isolate the E3 binder
|
| 241 |
+
bonds_to_break_e3 = get_attachment_bonds(r_group_mol, e3_match)
|
| 242 |
+
|
| 243 |
+
# Return if no bonds are found
|
| 244 |
+
if len(bonds_to_break_e3) != 1:
|
| 245 |
+
if stats is not None:
|
| 246 |
+
stats['multiple E3 attachment bonds'] += 1
|
| 247 |
+
if verbose:
|
| 248 |
+
safe_display(r_group_mol)
|
| 249 |
+
print('ERROR: Multiple E3 attachment bonds')
|
| 250 |
+
return None
|
| 251 |
+
|
| 252 |
+
# Break the bonds to isolate the E3 binder
|
| 253 |
+
frag_mol_e3 = Chem.FragmentOnBonds(r_group_mol, bonds_to_break_e3, addDummies=True, dummyLabels=[(e3_attachment_id, e3_attachment_id)])
|
| 254 |
+
|
| 255 |
+
# Get fragments after breaking bonds in R-groups
|
| 256 |
+
try:
|
| 257 |
+
frags = Chem.GetMolFrags(frag_mol_e3, asMols=True, sanitizeFrags=True)
|
| 258 |
+
except Exception as e:
|
| 259 |
+
print(e)
|
| 260 |
+
return None
|
| 261 |
+
|
| 262 |
+
# Identify the E3 binder fragment
|
| 263 |
+
e3_fragment = None
|
| 264 |
+
for frag in frags:
|
| 265 |
+
if frag.HasSubstructMatch(mol_e3):
|
| 266 |
+
e3_fragment = frag
|
| 267 |
+
break
|
| 268 |
+
if e3_fragment is None:
|
| 269 |
+
if stats is not None:
|
| 270 |
+
stats['E3 fragment not found'] += 1
|
| 271 |
+
if verbose:
|
| 272 |
+
print('ERROR: E3 fragment not found')
|
| 273 |
+
return None
|
| 274 |
+
|
| 275 |
+
if verbose:
|
| 276 |
+
print('E3:', Chem.MolToSmiles(e3_fragment, canonical=True))
|
| 277 |
+
safe_display(e3_fragment)
|
| 278 |
+
|
| 279 |
+
# The remaining fragment is the linker
|
| 280 |
+
# TODO: Check that the length of frags is 1, otherwise, there are multiple fragments
|
| 281 |
+
linker_mol = [frag for frag in frags if frag != e3_fragment]
|
| 282 |
+
if len(linker_mol) != 1:
|
| 283 |
+
if stats is not None:
|
| 284 |
+
stats['multiple E3 fragments'] += 1
|
| 285 |
+
if verbose:
|
| 286 |
+
for frag in frags:
|
| 287 |
+
safe_display(frag)
|
| 288 |
+
print('ERROR: Multiple E3 fragments')
|
| 289 |
+
return None
|
| 290 |
+
linker_mol = linker_mol[0]
|
| 291 |
+
|
| 292 |
+
poi_smiles = Chem.MolToSmiles(poi_fragment, canonical=True).replace(f'[{poi_attachment_id}*]', f'[*:{poi_attachment_id}]')
|
| 293 |
+
e3_smiles = Chem.MolToSmiles(e3_fragment, canonical=True).replace(f'[{e3_attachment_id}*]', f'[*:{e3_attachment_id}]')
|
| 294 |
+
linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True).replace(f'[{poi_attachment_id}*]', f'[*:{poi_attachment_id}]').replace(f'[{e3_attachment_id}*]', f'[*:{e3_attachment_id}]')
|
| 295 |
+
|
| 296 |
+
# Get the substructure names and canonize their SMILES
|
| 297 |
+
substructs = {'poi': poi_smiles, 'e3': e3_smiles, 'linker': linker_smiles}
|
| 298 |
+
substructs = {k: canonize_smiles(v) for k, v in substructs.items()}
|
| 299 |
+
|
| 300 |
+
if verbose:
|
| 301 |
+
print('Linker:', Chem.MolToSmiles(linker_mol, canonical=True))
|
| 302 |
+
safe_display(linker_mol)
|
| 303 |
+
|
| 304 |
+
# Check that the reassembled PROTAC matches the original PROTAC
|
| 305 |
+
if check_reassembly(protac_smiles, '.'.join(substructs.values())):
|
| 306 |
+
return substructs
|
| 307 |
+
|
| 308 |
+
if stats is not None:
|
| 309 |
+
stats['reassembling failed'] += 1
|
| 310 |
+
if verbose:
|
| 311 |
+
print('ERROR: Reassembling failed')
|
| 312 |
+
return None
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def get_substructure_from_non_perfect_match(
|
| 316 |
+
protac_mol: Chem.Mol,
|
| 317 |
+
substruct_mol: Chem.Mol,
|
| 318 |
+
attachment_id: int,
|
| 319 |
+
verbose: int = 0,
|
| 320 |
+
) -> Chem.Mol:
|
| 321 |
+
""" Extract the correct substructure from a PROTAC molecule, given the
|
| 322 |
+
SMILES of a wrong substructure resulting in many fragments and matches.
|
| 323 |
+
|
| 324 |
+
Sometimes the substructure we have is not a _perfect_ substructure of the
|
| 325 |
+
PROTAC, _i.e._, it will generate more than two fragments when trying to
|
| 326 |
+
replace the PROTAC core with it. In this case, this function will perform
|
| 327 |
+
the following steps:
|
| 328 |
+
|
| 329 |
+
1. Get the largest fragment by trying to replace the PROTAC core with the
|
| 330 |
+
substructure. This largest fragment will be the other substructure plus
|
| 331 |
+
the linker.
|
| 332 |
+
2. We can now remove the largest fragment from the PROTAC to get the
|
| 333 |
+
"original" substructure without the smaller dangling fragments.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
protac_mol (Chem.Mol): The PROTAC molecule.
|
| 337 |
+
substruct_smiles (Chem.Mol): The molecule of the wrong substructure, either the POI ligand or the E3 binder.
|
| 338 |
+
attachment_id (int): The attachment ID.
|
| 339 |
+
|
| 340 |
+
Returns:
|
| 341 |
+
Chem.Mol: The extracted substructure molecule. If failing, it will return None.
|
| 342 |
+
"""
|
| 343 |
+
# Remove the substructure, even if there are "dangling" fragments, to obtain: PROTAC - substruct = (POI + Linker) + remainders
|
| 344 |
+
linker_and_other_mol = Chem.DeleteSubstructs(protac_mol, substruct_mol, useChirality=True)
|
| 345 |
+
|
| 346 |
+
# Get the largest fragment, i.e., the PROTAC - substruct = POI + Linker
|
| 347 |
+
try:
|
| 348 |
+
fragments = Chem.GetMolFrags(linker_and_other_mol, asMols=True)
|
| 349 |
+
except Exception as e:
|
| 350 |
+
if verbose:
|
| 351 |
+
print(e)
|
| 352 |
+
return None
|
| 353 |
+
|
| 354 |
+
if len(fragments) == 1:
|
| 355 |
+
if verbose:
|
| 356 |
+
print("WARNING. There are no small fragments, there's only one fragment.")
|
| 357 |
+
|
| 358 |
+
if not fragments:
|
| 359 |
+
if verbose:
|
| 360 |
+
print('ERROR. No fragments found.')
|
| 361 |
+
return None
|
| 362 |
+
largest_fragment = max(fragments, key=lambda x: x.GetNumAtoms())
|
| 363 |
+
|
| 364 |
+
# Get the match of the largest fragment in the PROTAC molecule
|
| 365 |
+
largest_match = protac_mol.GetSubstructMatch(largest_fragment, useChirality=True)
|
| 366 |
+
|
| 367 |
+
# Get bonds to break to isolate the substructure, i.e., the opposite of the POI + Linker
|
| 368 |
+
bonds_to_break = get_attachment_bonds(protac_mol, largest_match)
|
| 369 |
+
|
| 370 |
+
if len(bonds_to_break) != 1:
|
| 371 |
+
if verbose:
|
| 372 |
+
print(f'ERROR. The bond to break is not a single one: {bonds_to_break}')
|
| 373 |
+
return None
|
| 374 |
+
|
| 375 |
+
# Break the bonds to isolate the substructure
|
| 376 |
+
frag_mol_substruct = Chem.FragmentOnBonds(protac_mol, bonds_to_break, addDummies=True, dummyLabels=[(attachment_id, attachment_id)])
|
| 377 |
+
|
| 378 |
+
# Get fragments after breaking bonds, i.e., the POI + Linker and the substructure without "remainders"
|
| 379 |
+
try:
|
| 380 |
+
frags = Chem.GetMolFrags(frag_mol_substruct, asMols=True, sanitizeFrags=True)
|
| 381 |
+
except Exception as e:
|
| 382 |
+
if verbose:
|
| 383 |
+
print(e)
|
| 384 |
+
return None
|
| 385 |
+
|
| 386 |
+
# Get the smallest between the substructure and the POI+Linker fragments
|
| 387 |
+
substruct_mol = min(frags, key=lambda x: x.GetNumAtoms())
|
| 388 |
+
substruct_smiles = Chem.MolToSmiles(substruct_mol, canonical=True).replace(f'[{attachment_id}*]', f'[*:{attachment_id}]')
|
| 389 |
+
substruct_mol = Chem.MolFromSmiles(canonize(substruct_smiles))
|
| 390 |
+
|
| 391 |
+
# Check that the substructure matches in the PROTAC molecule
|
| 392 |
+
if not protac_mol.HasSubstructMatch(dummy2query(substruct_mol), useChirality=True):
|
| 393 |
+
if verbose:
|
| 394 |
+
print('ERROR. Substructure does not match in PROTAC molecule:')
|
| 395 |
+
print('PROTAC molecule:')
|
| 396 |
+
safe_display(protac_mol)
|
| 397 |
+
print('Substructure molecule:')
|
| 398 |
+
safe_display(substruct_mol)
|
| 399 |
+
return None
|
| 400 |
+
|
| 401 |
+
return substruct_mol
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def get_mapped_substr_from_protac(
|
| 405 |
+
protac: Chem.Mol,
|
| 406 |
+
substr: Chem.Mol,
|
| 407 |
+
attachment_id: int = 1,
|
| 408 |
+
) -> Optional[Chem.Mol]:
|
| 409 |
+
""" Get the mapped substructure from a PROTAC molecule and an unmapped substructure.
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
protac: The PROTAC molecule.
|
| 413 |
+
substr: The unmapped substructure.
|
| 414 |
+
attachment_id: The attachment point ID to be assigned to the substructure.
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
The mapped substructure molecule. None if the function fails to find the substructure.
|
| 418 |
+
"""
|
| 419 |
+
num_matches = len(protac.GetSubstructMatches(substr, useChirality=True))
|
| 420 |
+
if num_matches != 1:
|
| 421 |
+
return None
|
| 422 |
+
other_substr = Chem.ReplaceCore(protac, substr, labelByIndex=False, replaceDummies=False)
|
| 423 |
+
if other_substr is None:
|
| 424 |
+
return None
|
| 425 |
+
mapped_substr = Chem.ReplaceCore(protac, remove_dummy_atoms(other_substr), labelByIndex=False, replaceDummies=False)
|
| 426 |
+
if mapped_substr is None:
|
| 427 |
+
return None
|
| 428 |
+
mapped_smiles = Chem.MolToSmiles(mapped_substr, canonical=True)
|
| 429 |
+
# Replace "[1*]" or "[2*]" with the correct attachment point with a regex
|
| 430 |
+
mapped_smiles = re.sub(r'\[(\d+)\*\]', f'[*:{attachment_id}]', mapped_smiles)
|
| 431 |
+
mapped_smiles = canonize(mapped_smiles)
|
| 432 |
+
if mapped_smiles is None:
|
| 433 |
+
return None
|
| 434 |
+
return Chem.MolFromSmiles(mapped_smiles)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def get_substructs_from_substr_and_linker(
|
| 438 |
+
protac_smiles: str,
|
| 439 |
+
protac: Chem.Mol,
|
| 440 |
+
substr: Chem.Mol,
|
| 441 |
+
linker: Chem.Mol,
|
| 442 |
+
attachment_id: int = 1,
|
| 443 |
+
poi_attachment_id: int = 1,
|
| 444 |
+
e3_attachment_id: int = 2,
|
| 445 |
+
verbose: int = 0,
|
| 446 |
+
stats: Counter = None,
|
| 447 |
+
) -> Optional[Dict[str, str]]:
|
| 448 |
+
""" Get the substructures of a PROTAC molecule from an unmapped substructure and linker.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
protac_smiles: The SMILES of the PROTAC molecule.
|
| 452 |
+
protac: The RDKit molecule object of the PROTAC.
|
| 453 |
+
substr: The RDKit molecule object of the currently matching substructure. Should be UNMAPPED.
|
| 454 |
+
linker: The RDKit molecule object of the linker.
|
| 455 |
+
attachment_id: The attachment point ID of the currently matching substructure.
|
| 456 |
+
verbose: The verbosity level.
|
| 457 |
+
|
| 458 |
+
Returns:
|
| 459 |
+
Dict: The substructures of the PROTAC molecule. None if the function fails to find the substructures.
|
| 460 |
+
"""
|
| 461 |
+
if attachment_id not in [poi_attachment_id, e3_attachment_id]:
|
| 462 |
+
raise ValueError('Attachment ID must be either 1 or 2')
|
| 463 |
+
|
| 464 |
+
if substr is None:
|
| 465 |
+
return None
|
| 466 |
+
|
| 467 |
+
subr_matches = list(protac.GetSubstructMatches(substr, useChirality=True))
|
| 468 |
+
if len(subr_matches) != 1:
|
| 469 |
+
if stats is not None:
|
| 470 |
+
stats['multiple substructure matches'] += 1
|
| 471 |
+
if verbose:
|
| 472 |
+
print('ERROR: Multiple substructure matches')
|
| 473 |
+
return None
|
| 474 |
+
subr_match = subr_matches[0]
|
| 475 |
+
|
| 476 |
+
mapped_substr = get_mapped_substr_from_protac(protac, substr, attachment_id)
|
| 477 |
+
if mapped_substr is None:
|
| 478 |
+
if stats is not None:
|
| 479 |
+
stats['mapped substructure not found'] += 1
|
| 480 |
+
if verbose:
|
| 481 |
+
print('ERROR: Mapped substructure not found')
|
| 482 |
+
return None
|
| 483 |
+
|
| 484 |
+
linker_matches = protac.GetSubstructMatches(remove_dummy_atoms(linker), useChirality=True)
|
| 485 |
+
for linker_match in linker_matches:
|
| 486 |
+
# Check that the intersection between the substructure and the linker
|
| 487 |
+
# matches is only one atom, i.e., the attachment point
|
| 488 |
+
if len(set(subr_match).intersection(linker_match)) == 1:
|
| 489 |
+
linker_match = linker_match
|
| 490 |
+
break
|
| 491 |
+
|
| 492 |
+
# Based on the linker match found, remove it from the PROTAC
|
| 493 |
+
emol = Chem.EditableMol(protac)
|
| 494 |
+
|
| 495 |
+
# Remove atoms in descending order of their indices
|
| 496 |
+
for idx in sorted(linker_match, reverse=True):
|
| 497 |
+
emol.RemoveAtom(idx)
|
| 498 |
+
# Get the modified molecule
|
| 499 |
+
try:
|
| 500 |
+
protac_fragments = emol.GetMol()
|
| 501 |
+
except Exception as e:
|
| 502 |
+
if verbose:
|
| 503 |
+
print(e)
|
| 504 |
+
return None
|
| 505 |
+
try:
|
| 506 |
+
Chem.SanitizeMol(protac_fragments)
|
| 507 |
+
except Exception as e:
|
| 508 |
+
if verbose:
|
| 509 |
+
print(e)
|
| 510 |
+
return None
|
| 511 |
+
if verbose:
|
| 512 |
+
img = Draw.MolToImage(protac_fragments, highlightAtoms=linker_match, size=(800, 300))
|
| 513 |
+
safe_display(img)
|
| 514 |
+
|
| 515 |
+
# Get the fragments after removing the linker
|
| 516 |
+
try:
|
| 517 |
+
fragments = Chem.GetMolFrags(protac_fragments, asMols=True, sanitizeFrags=True)
|
| 518 |
+
except Exception as e:
|
| 519 |
+
if verbose:
|
| 520 |
+
print(e)
|
| 521 |
+
return None
|
| 522 |
+
|
| 523 |
+
if len(fragments) != 2:
|
| 524 |
+
if stats is not None:
|
| 525 |
+
stats['multiple fragments after removing the linker'] += 1
|
| 526 |
+
if verbose:
|
| 527 |
+
for frag in fragments:
|
| 528 |
+
safe_display(frag)
|
| 529 |
+
print('ERROR: Multiple fragments after removing the linker')
|
| 530 |
+
return None
|
| 531 |
+
|
| 532 |
+
substructs = {}
|
| 533 |
+
substructs['linker'] = Chem.MolToSmiles(linker, canonical=True)
|
| 534 |
+
for frag in fragments:
|
| 535 |
+
if frag.HasSubstructMatch(substr, useChirality=True):
|
| 536 |
+
label = 'e3' if attachment_id == e3_attachment_id else 'poi'
|
| 537 |
+
substructs[label] = Chem.MolToSmiles(mapped_substr, canonical=True)
|
| 538 |
+
# Replace "[1*]" or "[2*]" with the correct attachment point with a regex
|
| 539 |
+
substructs[label] = re.sub(r'\[(\d+)\*\]', f'[*:{attachment_id}]', substructs[label])
|
| 540 |
+
if verbose:
|
| 541 |
+
print(f'Found {label.capitalize()} fragment.')
|
| 542 |
+
img = Draw.MolToImage(Chem.MolFromSmiles(substructs[label]), size=(800, 300))
|
| 543 |
+
safe_display(img)
|
| 544 |
+
else:
|
| 545 |
+
label = 'e3' if attachment_id == poi_attachment_id else 'poi'
|
| 546 |
+
other_attachment_id = e3_attachment_id if label == 'e3' else poi_attachment_id
|
| 547 |
+
|
| 548 |
+
other_substr = get_mapped_substr_from_protac(protac, frag, other_attachment_id)
|
| 549 |
+
if other_substr is None:
|
| 550 |
+
return None
|
| 551 |
+
substructs[label] = Chem.MolToSmiles(other_substr, canonical=True)
|
| 552 |
+
|
| 553 |
+
if verbose:
|
| 554 |
+
print(f'Found {label.capitalize()} fragment.')
|
| 555 |
+
img = Draw.MolToImage(Chem.MolFromSmiles(substructs[label]), size=(800, 300))
|
| 556 |
+
safe_display(img)
|
| 557 |
+
# Canonicalize the SMILES strings
|
| 558 |
+
substructs = {k: canonize(v) for k, v in substructs.items()}
|
| 559 |
+
|
| 560 |
+
# Check that the reassembled PROTAC matches the original PROTAC
|
| 561 |
+
if not check_reassembly(protac_smiles, '.'.join(substructs.values()), stats, verbose):
|
| 562 |
+
return None
|
| 563 |
+
|
| 564 |
+
return substructs
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def swap_attachment_points(
|
| 568 |
+
s: str,
|
| 569 |
+
poi_attachment_id: int = 1,
|
| 570 |
+
e3_attachment_id: int = 2,
|
| 571 |
+
) -> str:
|
| 572 |
+
""" Swaps the attachment points in a SMARTS string.
|
| 573 |
+
|
| 574 |
+
Args:
|
| 575 |
+
s: The input SMARTS string.
|
| 576 |
+
|
| 577 |
+
Returns:
|
| 578 |
+
The SMARTS string with the attachment points swapped.
|
| 579 |
+
"""
|
| 580 |
+
tmp_e3_id = '^^^^E3^^^^'
|
| 581 |
+
tmp_poi_id = '^^^^POI^^^^'
|
| 582 |
+
s = s.replace(f'[*:{poi_attachment_id}]', f'[*:{tmp_poi_id}]')
|
| 583 |
+
s = s.replace(f'[*:{e3_attachment_id}]', f'[*:{tmp_e3_id}]')
|
| 584 |
+
s = s.replace(f'[*:{tmp_poi_id}]', f'[*:{e3_attachment_id}]')
|
| 585 |
+
s = s.replace(f'[*:{tmp_e3_id}]', f'[*:{poi_attachment_id}]')
|
| 586 |
+
return canonize(s)
|
protac_splitter/data/generation/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .generation import generate_protacs
|
| 2 |
+
from .functional_groups import (
|
| 3 |
+
get_functional_group_at_attachment,
|
| 4 |
+
get_functional_groups_distributions,
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'generate_protacs',
|
| 9 |
+
'get_functional_group_at_attachment',
|
| 10 |
+
'get_functional_groups_distributions',
|
| 11 |
+
]
|
protac_splitter/data/generation/functional_groups.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Union
|
| 2 |
+
from collections import defaultdict, Counter
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.Chem import Draw
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from protac_splitter.chemoinformatics import (
|
| 11 |
+
get_atom_idx_at_attachment,
|
| 12 |
+
canonize_smarts,
|
| 13 |
+
)
|
| 14 |
+
from protac_splitter.display_utils import (
|
| 15 |
+
safe_display,
|
| 16 |
+
display_mol,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_functional_group_at_attachment(
|
| 21 |
+
protac: Chem.Mol,
|
| 22 |
+
substruct: Chem.Mol,
|
| 23 |
+
linker: Chem.Mol,
|
| 24 |
+
n_hops: int = 1,
|
| 25 |
+
timeout: Optional[Union[int, float]] = None,
|
| 26 |
+
return_dict: bool = False,
|
| 27 |
+
verbose: int = 0,
|
| 28 |
+
) -> Union[str, Dict[str, str]]:
|
| 29 |
+
""" Get the functional group at the attachment point of a substructure in the PROTAC molecule.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
protac: The PROTAC molecule.
|
| 33 |
+
substruct: The substructure of the PROTAC that contains the attachment point, e.g., the POI or E3 ligase.
|
| 34 |
+
linker: The linker molecule.
|
| 35 |
+
n_hops: The number of hops to consider for the neighborhood.
|
| 36 |
+
timeout: The timeout for the substructure search.
|
| 37 |
+
return_dict: Whether to return the functional groups as a dictionary.
|
| 38 |
+
verbose: Verbosity level.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
str | Dict[str, str]: The SMARTS of the functional group at the attachment point. If return_dict is True, a dictionary with the SMARTS of the functional groups at the attachment point and at the "two sides" of the attachment point (keys: 'attachment', 'substruct', 'linker').
|
| 42 |
+
"""
|
| 43 |
+
protac = Chem.AddHs(protac)
|
| 44 |
+
substruct = Chem.AddHs(substruct)
|
| 45 |
+
|
| 46 |
+
if linker is not None:
|
| 47 |
+
linker = Chem.AddHs(linker)
|
| 48 |
+
|
| 49 |
+
attachment_idxs = get_atom_idx_at_attachment(
|
| 50 |
+
protac=protac,
|
| 51 |
+
substruct=substruct,
|
| 52 |
+
linker=linker,
|
| 53 |
+
timeout=timeout,
|
| 54 |
+
return_dict=True,
|
| 55 |
+
verbose=0,
|
| 56 |
+
)
|
| 57 |
+
# Get all neighboring atoms that are n_hops away from the attachment point
|
| 58 |
+
if attachment_idxs is None:
|
| 59 |
+
return None
|
| 60 |
+
if len(attachment_idxs) != 2:
|
| 61 |
+
return None
|
| 62 |
+
if verbose:
|
| 63 |
+
print(f'Attachment points: {attachment_idxs}')
|
| 64 |
+
img = Draw.MolToImage(protac, highlightAtoms=attachment_idxs.values(), size=(800, 500))
|
| 65 |
+
safe_display(img)
|
| 66 |
+
print('Neighbors:')
|
| 67 |
+
|
| 68 |
+
# Recursively find neighbors at n_hops distance
|
| 69 |
+
neighborhood = set([protac.GetAtomWithIdx(idx) for idx in attachment_idxs.values()])
|
| 70 |
+
def find_neighbors(atom, hops, excluded_atom_idx=None):
|
| 71 |
+
if hops <= 0:
|
| 72 |
+
return
|
| 73 |
+
for neighbor in atom.GetNeighbors():
|
| 74 |
+
if excluded_atom_idx is not None and neighbor.GetIdx() == excluded_atom_idx:
|
| 75 |
+
neighborhood.add(neighbor)
|
| 76 |
+
continue
|
| 77 |
+
neighborhood.add(neighbor)
|
| 78 |
+
find_neighbors(neighbor, hops - 1)
|
| 79 |
+
|
| 80 |
+
for idx in attachment_idxs.values():
|
| 81 |
+
find_neighbors(protac.GetAtomWithIdx(idx), n_hops)
|
| 82 |
+
|
| 83 |
+
# Display the neighborhood
|
| 84 |
+
if verbose:
|
| 85 |
+
print(f'Neighbors at {n_hops} hops:')
|
| 86 |
+
# Get options to display all hydrogen atoms
|
| 87 |
+
options = Draw.DrawingOptions()
|
| 88 |
+
# Add a legend to the image
|
| 89 |
+
options.legend = 'Neighbors at attachment points'
|
| 90 |
+
img = Draw.MolToImage(protac, highlightAtoms=[a.GetIdx() for a in neighborhood], size=(800, 500), options=options)
|
| 91 |
+
safe_display(img)
|
| 92 |
+
|
| 93 |
+
# # NOTE: The following is an overkill, there is an RDKit function to extract a substructure
|
| 94 |
+
# neighborhood_mol = extract_atoms_as_molecule(protac, [a.GetIdx() for a in neighborhood])
|
| 95 |
+
# neighborhood_smarts = canonize_smarts(Chem.MolToSmarts(neighborhood_mol))
|
| 96 |
+
|
| 97 |
+
# Extract the SMARTS given the atom indices of the neighborhood
|
| 98 |
+
neighborhood_idxs = [a.GetIdx() for a in neighborhood]
|
| 99 |
+
neighborhood_smarts = Chem.MolFragmentToSmarts(protac, neighborhood_idxs)
|
| 100 |
+
neighborhood_smarts = canonize_smarts(neighborhood_smarts)
|
| 101 |
+
|
| 102 |
+
if verbose:
|
| 103 |
+
print(neighborhood_smarts)
|
| 104 |
+
display_mol(Chem.MolFromSmarts(neighborhood_smarts), display_svg=False)
|
| 105 |
+
|
| 106 |
+
if return_dict:
|
| 107 |
+
smarts = {}
|
| 108 |
+
smarts['attachment'] = neighborhood_smarts
|
| 109 |
+
# Get the SMARTS at the attachment point and at its "two sides"
|
| 110 |
+
for side, idx in attachment_idxs.items():
|
| 111 |
+
# NOTE: We know that attachment_idxs is a dictionary with two keys,
|
| 112 |
+
# 'susbtruct' and 'linker', so we can directly use the other key
|
| 113 |
+
other_side = 'linker' if side == 'substruct' else 'substruct'
|
| 114 |
+
excluded_atom_idx = attachment_idxs[other_side]
|
| 115 |
+
neighborhood = {protac.GetAtomWithIdx(idx)}
|
| 116 |
+
find_neighbors(protac.GetAtomWithIdx(idx), n_hops, excluded_atom_idx=excluded_atom_idx)
|
| 117 |
+
|
| 118 |
+
# Get the atom indices of the neighborhood
|
| 119 |
+
neighborhood_idxs = [a.GetIdx() for a in neighborhood]
|
| 120 |
+
|
| 121 |
+
# Copy the PROTAC molecule and set the excluded_atom_idx to a dummy
|
| 122 |
+
p = Chem.Mol(protac)
|
| 123 |
+
p.GetAtomWithIdx(excluded_atom_idx).SetAtomicNum(0)
|
| 124 |
+
|
| 125 |
+
# Extract the SMARTS from the copied PROTAC given the indeces
|
| 126 |
+
s = Chem.MolFragmentToSmarts(p, neighborhood_idxs)
|
| 127 |
+
smarts[other_side] = canonize_smarts(s)
|
| 128 |
+
|
| 129 |
+
return smarts
|
| 130 |
+
|
| 131 |
+
return neighborhood_smarts
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_functional_group_at_attachment_side(
|
| 135 |
+
substruct: Chem.Mol,
|
| 136 |
+
attachment_id: Optional[int] = None,
|
| 137 |
+
n_hops: int = 2,
|
| 138 |
+
add_Hs: bool = True,
|
| 139 |
+
) -> Optional[str]:
|
| 140 |
+
""" Get the functional group at the attachment point of a substructure in the PROTAC molecule.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
substruct: The substructure of the PROTAC that contains the attachment point, e.g., the POI or E3 ligase.
|
| 144 |
+
attachment_id: The attachment point ID in the substructure. E.g., 1 for the POI, as in "[*:1]".
|
| 145 |
+
n_hops: The number of hops to consider for the neighborhood. Default is 2.
|
| 146 |
+
add_Hs: Whether to add hydrogens to the substructure.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
str: The SMARTS of the functional group at the attachment point. None if failed.
|
| 150 |
+
"""
|
| 151 |
+
if add_Hs:
|
| 152 |
+
substruct = Chem.AddHs(substruct)
|
| 153 |
+
|
| 154 |
+
# Get the atom index of the attachment point, i.e., a dummy atom
|
| 155 |
+
attachment_idx2map = {}
|
| 156 |
+
for atom in substruct.GetAtoms():
|
| 157 |
+
if atom.GetAtomicNum() == 0:
|
| 158 |
+
# Get the mapped atom index
|
| 159 |
+
attachment_idx2map[atom.GetIdx()] = atom.GetAtomMapNum()
|
| 160 |
+
|
| 161 |
+
if not attachment_idx2map:
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
# If we are dealing with a linker, get the specific attachment point
|
| 165 |
+
if attachment_id is not None:
|
| 166 |
+
attachment_idx = [k for k, v in attachment_idx2map.items() if v == attachment_id]
|
| 167 |
+
if not attachment_idx:
|
| 168 |
+
return None
|
| 169 |
+
attachment_idx = attachment_idx[0]
|
| 170 |
+
else:
|
| 171 |
+
attachment_idx = list(attachment_idx2map.keys())[0]
|
| 172 |
+
|
| 173 |
+
neighborhood = {substruct.GetAtomWithIdx(attachment_idx)}
|
| 174 |
+
def find_neighbors(atom, hops):
|
| 175 |
+
if hops <= 0:
|
| 176 |
+
return
|
| 177 |
+
for neighbor in atom.GetNeighbors():
|
| 178 |
+
neighborhood.add(neighbor)
|
| 179 |
+
find_neighbors(neighbor, hops - 1)
|
| 180 |
+
|
| 181 |
+
find_neighbors(substruct.GetAtomWithIdx(attachment_idx), n_hops)
|
| 182 |
+
neighborhood_idxs = [a.GetIdx() for a in neighborhood]
|
| 183 |
+
|
| 184 |
+
neighborhood_smarts = Chem.MolFragmentToSmarts(substruct, neighborhood_idxs)
|
| 185 |
+
if neighborhood_smarts:
|
| 186 |
+
return canonize_smarts(neighborhood_smarts)
|
| 187 |
+
|
| 188 |
+
return None
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_functional_groups_distributions(
|
| 192 |
+
df: pd.DataFrame,
|
| 193 |
+
get_side_chain_info: bool = False,
|
| 194 |
+
timeout: Optional[Union[int, float]] = None,
|
| 195 |
+
filename_distributions: Optional[str] = None,
|
| 196 |
+
filename_mappings: Optional[str] = None,
|
| 197 |
+
filename_df_with_functional_groups: Optional[str] = None,
|
| 198 |
+
load_from_file: bool = True,
|
| 199 |
+
verbose: int = 0,
|
| 200 |
+
) -> Dict[str, Dict[str, set]]:
|
| 201 |
+
""" Get the distributions of functional groups at attachment points in a dataframe of PROTACs.
|
| 202 |
+
|
| 203 |
+
The input dataframe should contain the following columns:
|
| 204 |
+
- 'PROTAC SMILES': The SMILES of the PROTAC.
|
| 205 |
+
- 'POI Ligand SMILES with direction': The SMILES of the POI ligand.
|
| 206 |
+
- 'Linker SMILES with direction': The SMILES of the linker.
|
| 207 |
+
- 'E3 Binder SMILES with direction': The SMILES of the E3 binder.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
df: The DataFrame containing the PROTACs.
|
| 211 |
+
get_side_chain_info: Whether to get the side chain information along with the functional groups at the attachment points.
|
| 212 |
+
timeout: The timeout for the substructure search. Default is None.
|
| 213 |
+
verbose: Verbosity level.
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
Dict[str, Dict[str, set]]: The distributions of functional groups at attachment points in PROTACs.
|
| 217 |
+
"""
|
| 218 |
+
smarts_counter = Counter()
|
| 219 |
+
e3_smarts_counter = Counter()
|
| 220 |
+
poi_smarts_counter = Counter()
|
| 221 |
+
substr_smarts_counter = {
|
| 222 |
+
'poi2linker': defaultdict(Counter),
|
| 223 |
+
'linker2poi': defaultdict(Counter),
|
| 224 |
+
'e32linker': defaultdict(Counter),
|
| 225 |
+
'linker2e3': defaultdict(Counter),
|
| 226 |
+
}
|
| 227 |
+
# Assign to each functional group the list of substructures that appear in the df
|
| 228 |
+
poi_substr2fg = defaultdict(set)
|
| 229 |
+
e3_substr2fg = defaultdict(set)
|
| 230 |
+
# Assign to each substructure the list of functional groups that appear in the df
|
| 231 |
+
poi_fg_2_substr = defaultdict(set)
|
| 232 |
+
e3_fg_2_substr = defaultdict(set)
|
| 233 |
+
substr_fg_2_linker = defaultdict(set)
|
| 234 |
+
|
| 235 |
+
linker2fg = defaultdict(dict)
|
| 236 |
+
|
| 237 |
+
if load_from_file:
|
| 238 |
+
if filename_distributions is not None and filename_mappings is not None:
|
| 239 |
+
with open(filename_distributions, 'r') as f:
|
| 240 |
+
fg_distr = json.load(f)
|
| 241 |
+
with open(filename_mappings, 'r') as f:
|
| 242 |
+
fg_mappings = json.load(f)
|
| 243 |
+
ret = {}
|
| 244 |
+
ret.update(fg_distr)
|
| 245 |
+
ret.update(fg_mappings)
|
| 246 |
+
return ret
|
| 247 |
+
else:
|
| 248 |
+
print(f'WARNING: No filename provided to load the mappings from. The functional groups will be recomputed.')
|
| 249 |
+
|
| 250 |
+
df_with_functional_groups = []
|
| 251 |
+
|
| 252 |
+
for i, row in tqdm(df.iterrows(), total=len(df)):
|
| 253 |
+
protac_smiles = row['PROTAC SMILES']
|
| 254 |
+
poi_smiles = row['POI Ligand SMILES with direction']
|
| 255 |
+
linker_smiles = row['Linker SMILES with direction']
|
| 256 |
+
e3_smiles = row['E3 Binder SMILES with direction']
|
| 257 |
+
|
| 258 |
+
protac = Chem.MolFromSmiles(protac_smiles)
|
| 259 |
+
poi = Chem.MolFromSmiles(poi_smiles)
|
| 260 |
+
e3 = Chem.MolFromSmiles(e3_smiles)
|
| 261 |
+
linker = Chem.MolFromSmiles(linker_smiles)
|
| 262 |
+
|
| 263 |
+
if None in [protac, poi, e3, linker]:
|
| 264 |
+
print(f'WARNING: Could not parse the following SMILES:')
|
| 265 |
+
print(f'PROTAC: {protac_smiles}')
|
| 266 |
+
print(f'POI: {poi_smiles}')
|
| 267 |
+
print(f'Linker: {linker_smiles}')
|
| 268 |
+
print(f'E3: {e3_smiles}')
|
| 269 |
+
print('-' * 80)
|
| 270 |
+
|
| 271 |
+
# We have a bit of care with the linker, as it can be empty
|
| 272 |
+
try:
|
| 273 |
+
_ = Chem.molzip(Chem.MolFromSmiles('.'.join([poi_smiles, linker_smiles, e3_smiles])))
|
| 274 |
+
except:
|
| 275 |
+
print(f'WARNING: The linker might be empty: {linker_smiles}')
|
| 276 |
+
linker = None
|
| 277 |
+
|
| 278 |
+
if linker is not None:
|
| 279 |
+
fg_poi = get_functional_group_at_attachment(protac, poi, linker, timeout=timeout, return_dict=get_side_chain_info)
|
| 280 |
+
fg_e3 = get_functional_group_at_attachment(protac, e3, linker, timeout=timeout, return_dict=get_side_chain_info)
|
| 281 |
+
else:
|
| 282 |
+
# If the linker is empty, then we use the other side as the linker
|
| 283 |
+
fg_poi = get_functional_group_at_attachment(protac, poi, e3, return_dict=get_side_chain_info)
|
| 284 |
+
fg_e3 = get_functional_group_at_attachment(protac, e3, poi, return_dict=get_side_chain_info)
|
| 285 |
+
|
| 286 |
+
if get_side_chain_info:
|
| 287 |
+
if fg_poi is not None:
|
| 288 |
+
smarts_counter.update([fg_poi['attachment']])
|
| 289 |
+
poi_smarts_counter.update([fg_poi['substruct']])
|
| 290 |
+
substr_smarts_counter['poi2linker'][fg_poi['substruct']].update([fg_poi['linker']])
|
| 291 |
+
substr_smarts_counter['linker2poi'][fg_poi['linker']].update([fg_poi['substruct']])
|
| 292 |
+
linker2fg[linker_smiles]['poi'] = fg_poi['attachment']
|
| 293 |
+
|
| 294 |
+
poi_substr2fg[poi_smiles].append(fg_poi['attachment'])
|
| 295 |
+
poi_fg_2_substr[fg_poi['attachment']].update([poi_smiles])
|
| 296 |
+
|
| 297 |
+
if fg_e3 is not None:
|
| 298 |
+
smarts_counter.update([fg_e3['attachment']])
|
| 299 |
+
e3_smarts_counter.update([fg_e3['substruct']])
|
| 300 |
+
substr_smarts_counter['e32linker'][fg_e3['substruct']].update([fg_e3['linker']])
|
| 301 |
+
substr_smarts_counter['linker2e3'][fg_e3['linker']].update([fg_e3['substruct']])
|
| 302 |
+
linker2fg[linker_smiles]['e3'] = fg_e3['attachment']
|
| 303 |
+
|
| 304 |
+
e3_substr2fg[e3_smiles].update(fg_e3['attachment'])
|
| 305 |
+
e3_fg_2_substr[fg_e3['attachment']].update([e3_smiles])
|
| 306 |
+
else:
|
| 307 |
+
if fg_poi is not None:
|
| 308 |
+
smarts_counter.update([fg_poi])
|
| 309 |
+
poi_smarts_counter.update([fg_poi])
|
| 310 |
+
poi_substr2fg[poi_smiles].update([fg_poi])
|
| 311 |
+
poi_fg_2_substr[fg_poi].update([poi_smiles])
|
| 312 |
+
substr_fg_2_linker[fg_poi].update([linker_smiles])
|
| 313 |
+
if fg_e3 is not None:
|
| 314 |
+
smarts_counter.update([fg_e3])
|
| 315 |
+
e3_smarts_counter.update([fg_e3])
|
| 316 |
+
e3_substr2fg[e3_smiles].update([fg_e3])
|
| 317 |
+
e3_fg_2_substr[fg_e3].update([e3_smiles])
|
| 318 |
+
substr_fg_2_linker[fg_e3].update([linker_smiles])
|
| 319 |
+
|
| 320 |
+
# Update the DataFrame with the functional groups
|
| 321 |
+
if fg_poi is not None:
|
| 322 |
+
row['POI Ligand Functional Group'] = fg_poi
|
| 323 |
+
if fg_e3 is not None:
|
| 324 |
+
row['E3 Binder Functional Group'] = fg_e3
|
| 325 |
+
df_with_functional_groups.append(row)
|
| 326 |
+
|
| 327 |
+
# Normalize all the counts to probability distributions
|
| 328 |
+
fg_distr = {k: v / smarts_counter.total() for k, v in smarts_counter.items()}
|
| 329 |
+
e3_fg_distr = {k: v / e3_smarts_counter.total() for k, v in e3_smarts_counter.items()}
|
| 330 |
+
poi_fg_distr = {k: v / poi_smarts_counter.total() for k, v in poi_smarts_counter.items()}
|
| 331 |
+
|
| 332 |
+
# Sort the probability distributions
|
| 333 |
+
fg_distr = dict(sorted(fg_distr.items(), key=lambda x: x[1], reverse=True))
|
| 334 |
+
e3_fg_distr = dict(sorted(e3_fg_distr.items(), key=lambda x: x[1], reverse=True))
|
| 335 |
+
poi_fg_distr = dict(sorted(poi_fg_distr.items(), key=lambda x: x[1], reverse=True))
|
| 336 |
+
|
| 337 |
+
if not get_side_chain_info:
|
| 338 |
+
ret = {
|
| 339 |
+
'fg_distr': fg_distr,
|
| 340 |
+
'e3_fg_distr': e3_fg_distr,
|
| 341 |
+
'poi_fg_distr': poi_fg_distr,
|
| 342 |
+
'poi_fg_2_substr': poi_fg_2_substr,
|
| 343 |
+
'e3_fg_2_substr': e3_fg_2_substr,
|
| 344 |
+
'substr_fg_2_linker': substr_fg_2_linker,
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
# Normalize the linker-to-substructure to probability distributions
|
| 348 |
+
if get_side_chain_info:
|
| 349 |
+
side_fg_distr = defaultdict(dict)
|
| 350 |
+
for direction, smarts2counter in substr_smarts_counter.items():
|
| 351 |
+
for smarts, counter in smarts2counter.items():
|
| 352 |
+
side_fg_distr[direction][smarts] = {k: v / counter.total() for k, v in counter.items()}
|
| 353 |
+
side_fg_distr[direction][smarts] = dict(sorted(side_fg_distr[direction][smarts].items(), key=lambda x: x[1], reverse=True))
|
| 354 |
+
|
| 355 |
+
if verbose:
|
| 356 |
+
# Display the top 5 functional groups
|
| 357 |
+
print('-' * 80)
|
| 358 |
+
print(f'{"-".join(direction.upper().split("2"))}:')
|
| 359 |
+
print('-' * len(direction) + '-' * 2)
|
| 360 |
+
for i, (smarts, probs) in enumerate(side_fg_distr[direction].items()):
|
| 361 |
+
if i >= 5:
|
| 362 |
+
break
|
| 363 |
+
print(f'{smarts}:')
|
| 364 |
+
for j, (sma, prob) in enumerate(probs.items()):
|
| 365 |
+
if j >= 5:
|
| 366 |
+
break
|
| 367 |
+
print(f'\t{prob:.2%} -> {sma}')
|
| 368 |
+
ret = {
|
| 369 |
+
'fg_distr': fg_distr,
|
| 370 |
+
'e3_fg_distr': e3_fg_distr,
|
| 371 |
+
'poi_fg_distr': poi_fg_distr,
|
| 372 |
+
'poi_fg_2_substr': poi_fg_2_substr,
|
| 373 |
+
'e3_fg_2_substr': e3_fg_2_substr,
|
| 374 |
+
'substr_fg_2_linker': substr_fg_2_linker,
|
| 375 |
+
'side_fg_distr': side_fg_distr,
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
if filename_distributions is not None:
|
| 379 |
+
# Save to JSON file
|
| 380 |
+
distributions = {k: v for k, v in ret.items() if 'distr' in k}
|
| 381 |
+
with open(filename_distributions, 'w') as f:
|
| 382 |
+
json.dump(distributions, f, indent=4)
|
| 383 |
+
print(f'Functional group distributions saved to: {filename_distributions}')
|
| 384 |
+
|
| 385 |
+
if filename_mappings is not None:
|
| 386 |
+
# Convert sets to lists to make the data serializable
|
| 387 |
+
fg_mappings = {k: {sk: list(s) for sk, s in v.items()} for k, v in ret.items() if 'distr' not in k}
|
| 388 |
+
|
| 389 |
+
with open(filename_mappings, 'w') as f:
|
| 390 |
+
json.dump(fg_mappings, f, indent=4)
|
| 391 |
+
print(f'Functional group mappings saved to: {filename_mappings}')
|
| 392 |
+
|
| 393 |
+
df_with_functional_groups = pd.DataFrame(df_with_functional_groups)
|
| 394 |
+
ret['dataframe'] = df_with_functional_groups
|
| 395 |
+
|
| 396 |
+
if filename_df_with_functional_groups is not None:
|
| 397 |
+
df_with_functional_groups.to_csv(filename_df_with_functional_groups, index=False)
|
| 398 |
+
print(f'DataFrame with functional groups saved to: {filename_df_with_functional_groups}')
|
| 399 |
+
|
| 400 |
+
return ret
|
protac_splitter/data/generation/generation.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
|
| 10 |
+
from protac_splitter.evaluation import check_reassembly
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def generate_protacs(
|
| 14 |
+
poi_fg_distr: Dict[str, float],
|
| 15 |
+
e3_fg_distr: Dict[str, float],
|
| 16 |
+
substr_fg_2_linker: Dict[str, List[str]],
|
| 17 |
+
poi_fg_2_substr: Dict[str, List[str]],
|
| 18 |
+
e3_fg_2_substr: Dict[str, List[str]],
|
| 19 |
+
num_samples: int,
|
| 20 |
+
random_state: int = 42,
|
| 21 |
+
batch_size: int = 1000,
|
| 22 |
+
max_workers: int = 4,
|
| 23 |
+
original_df: Optional[pd.DataFrame] = None,
|
| 24 |
+
filename_generated_df: Optional[str] = None,
|
| 25 |
+
base_data_dir: Optional[str] = None,
|
| 26 |
+
cover_all_smiles: bool = False,
|
| 27 |
+
) -> pd.DataFrame:
|
| 28 |
+
""" Generate PROTACs given the distributions of functional groups at attachment points.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
poi_fg_distr: The distribution of functional groups at the POI attachment point.
|
| 32 |
+
e3_fg_distr: The distribution of functional groups at the E3 attachment point.
|
| 33 |
+
substr_fg_2_linker: The mapping of functional groups to linkers.
|
| 34 |
+
poi_fg_2_substr: The mapping of functional groups to POI substrates.
|
| 35 |
+
e3_fg_2_substr: The mapping of functional groups to E3 substrates.
|
| 36 |
+
num_samples: The number of PROTACs to generate.
|
| 37 |
+
random_state: The random state for reproducibility.
|
| 38 |
+
batch_size: The batch size for generating PROTACs.
|
| 39 |
+
max_workers: The maximum number of workers for the ThreadPoolExecutor.
|
| 40 |
+
original_df: The original DataFrame containing the PROTACs. Must have a
|
| 41 |
+
column named 'PROTAC SMILES' containing the strings to
|
| 42 |
+
avoid generating. The check is done on strings, so make
|
| 43 |
+
sure to canonize/standardize the SMILES strings.
|
| 44 |
+
filename_generated_df: The filename to save the generated PROTACs.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
pd.DataFrame: The DataFrame containing the generated PROTACs.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
np.random.seed(random_state)
|
| 51 |
+
final_df = pd.DataFrame()
|
| 52 |
+
total_batches = int(np.ceil(num_samples / batch_size))
|
| 53 |
+
|
| 54 |
+
def generate_protac_batch(batch_size: int, random_state: int) -> List[dict]:
|
| 55 |
+
np.random.seed(random_state)
|
| 56 |
+
|
| 57 |
+
# Sample functional groups for POI and E3
|
| 58 |
+
poi_fgs = np.random.choice(list(poi_fg_distr.keys()), size=batch_size, p=list(poi_fg_distr.values()))
|
| 59 |
+
e3_fgs = np.random.choice(list(e3_fg_distr.keys()), size=batch_size, p=list(e3_fg_distr.values()))
|
| 60 |
+
|
| 61 |
+
# Map functional groups to corresponding substrates
|
| 62 |
+
# NOTE: When size argument is specified, the output is a numpy array.
|
| 63 |
+
# NOTE: If the functional group is not in the dictionary, the output is an empty numpy array.
|
| 64 |
+
poi_samples = [
|
| 65 |
+
np.random.choice(poi_fg_2_substr.get(fg, []), size=1 if fg in poi_fg_2_substr and poi_fg_2_substr[fg] else 0)
|
| 66 |
+
for fg in poi_fgs
|
| 67 |
+
]
|
| 68 |
+
e3_samples = [
|
| 69 |
+
np.random.choice(e3_fg_2_substr.get(fg, []), size=1 if fg in e3_fg_2_substr and e3_fg_2_substr[fg] else 0)
|
| 70 |
+
for fg in e3_fgs
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
generated_protacs = []
|
| 74 |
+
|
| 75 |
+
for poi_smiles, poi_fg, e3_smiles, e3_fg in zip(poi_samples, poi_fgs, e3_samples, e3_fgs):
|
| 76 |
+
# Check if poi_smiles and e3_smiles are not an empty numpy array
|
| 77 |
+
if poi_smiles.size == 0 or e3_smiles.size == 0:
|
| 78 |
+
continue
|
| 79 |
+
|
| 80 |
+
# Convert the numpy arrays to strings
|
| 81 |
+
poi_smiles, e3_smiles = poi_smiles[0], e3_smiles[0]
|
| 82 |
+
|
| 83 |
+
linkers = set(substr_fg_2_linker.get(poi_fg, [])) & set(substr_fg_2_linker.get(e3_fg, []))
|
| 84 |
+
if not linkers:
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
linker_smiles = np.random.choice(list(linkers))
|
| 88 |
+
|
| 89 |
+
# Get the PROTAC by combining the POI, linker, and E3
|
| 90 |
+
ligands_smiles = '.'.join([poi_smiles, linker_smiles, e3_smiles])
|
| 91 |
+
protac = Chem.MolFromSmiles(ligands_smiles)
|
| 92 |
+
|
| 93 |
+
if protac is None:
|
| 94 |
+
continue
|
| 95 |
+
try:
|
| 96 |
+
protac = Chem.molzip(protac)
|
| 97 |
+
except:
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
# Sanitize molecule
|
| 101 |
+
try:
|
| 102 |
+
zero_on_success = Chem.SanitizeMol(protac, catchErrors=True)
|
| 103 |
+
if zero_on_success != 0:
|
| 104 |
+
continue
|
| 105 |
+
protac_smiles = Chem.MolToSmiles(protac, canonical=True)
|
| 106 |
+
except:
|
| 107 |
+
continue
|
| 108 |
+
|
| 109 |
+
if original_df is not None and protac_smiles in original_df['PROTAC SMILES'].values:
|
| 110 |
+
continue
|
| 111 |
+
|
| 112 |
+
# Check if PROTAC can be reassembled
|
| 113 |
+
if not check_reassembly(protac_smiles, ligands_smiles):
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
generated_protacs.append({
|
| 117 |
+
'PROTAC SMILES': protac_smiles,
|
| 118 |
+
'POI Ligand SMILES with direction': poi_smiles,
|
| 119 |
+
'Linker SMILES with direction': linker_smiles,
|
| 120 |
+
'E3 Binder SMILES with direction': e3_smiles,
|
| 121 |
+
'POI Ligand Functional Group': poi_fg,
|
| 122 |
+
'E3 Binder Functional Group': e3_fg,
|
| 123 |
+
})
|
| 124 |
+
|
| 125 |
+
return generated_protacs
|
| 126 |
+
|
| 127 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 128 |
+
futures = []
|
| 129 |
+
for i in tqdm(range(total_batches), desc="Generating Batches"):
|
| 130 |
+
futures.append(executor.submit(generate_protac_batch, batch_size, random_state + i))
|
| 131 |
+
|
| 132 |
+
for i, future in tqdm(enumerate(futures), desc="Processing Results", total=total_batches):
|
| 133 |
+
generated_batch = future.result()
|
| 134 |
+
if generated_batch:
|
| 135 |
+
batch_df = pd.DataFrame(generated_batch)
|
| 136 |
+
final_df = pd.concat([final_df, batch_df]).drop_duplicates()
|
| 137 |
+
if i % 100 == 0:
|
| 138 |
+
if base_data_dir:
|
| 139 |
+
batch_df.to_csv(os.path.join(base_data_dir, f'generated_protacs_batch={i}.csv'), index=False)
|
| 140 |
+
else:
|
| 141 |
+
batch_df.to_csv(f'generated_protacs_batch={i}.csv', index=False)
|
| 142 |
+
if filename_generated_df:
|
| 143 |
+
final_df.to_csv(filename_generated_df, index=False)
|
| 144 |
+
if len(final_df) >= num_samples:
|
| 145 |
+
break
|
| 146 |
+
|
| 147 |
+
if not final_df.empty:
|
| 148 |
+
generated_pois = set(final_df['POI Ligand SMILES with direction'].unique())
|
| 149 |
+
generated_e3s = set(final_df['E3 Binder SMILES with direction'].unique())
|
| 150 |
+
generated_linkers = set(final_df['Linker SMILES with direction'].unique())
|
| 151 |
+
else:
|
| 152 |
+
generated_pois = set()
|
| 153 |
+
generated_e3s = set()
|
| 154 |
+
generated_linkers = set()
|
| 155 |
+
|
| 156 |
+
# Check how we covered the available substructures
|
| 157 |
+
avail_pois = set()
|
| 158 |
+
avail_e3s = set()
|
| 159 |
+
avail_linkers = set()
|
| 160 |
+
for fg in poi_fg_2_substr:
|
| 161 |
+
avail_pois.update(set(poi_fg_2_substr[fg]))
|
| 162 |
+
for fg in e3_fg_2_substr:
|
| 163 |
+
avail_e3s.update(set(e3_fg_2_substr[fg]))
|
| 164 |
+
for fg in substr_fg_2_linker:
|
| 165 |
+
avail_linkers.update(set(substr_fg_2_linker[fg]))
|
| 166 |
+
|
| 167 |
+
e3_coverage = len(generated_e3s) / len(avail_e3s)
|
| 168 |
+
poi_coverage = len(generated_pois) / len(avail_pois)
|
| 169 |
+
linker_coverage = len(generated_linkers) / len(avail_linkers)
|
| 170 |
+
|
| 171 |
+
print(f"POI coverage: {poi_coverage:.3%}")
|
| 172 |
+
print(f"E3 coverage: {e3_coverage:.3%}")
|
| 173 |
+
print(f"Linker coverage: {linker_coverage:.3%}")
|
| 174 |
+
|
| 175 |
+
# Get the "leftover" ligands
|
| 176 |
+
leftover_pois = avail_pois - generated_pois
|
| 177 |
+
leftover_e3s = avail_e3s - generated_e3s
|
| 178 |
+
leftover_linkers = avail_linkers - generated_linkers
|
| 179 |
+
|
| 180 |
+
covering_df = []
|
| 181 |
+
|
| 182 |
+
with tqdm(total=len(leftover_pois) + len(leftover_e3s) + len(leftover_linkers), desc="Covering Leftover Ligands") as pbar:
|
| 183 |
+
while True:
|
| 184 |
+
if not cover_all_smiles:
|
| 185 |
+
break
|
| 186 |
+
|
| 187 |
+
# Randomly select a POI, E3, and linker
|
| 188 |
+
if not leftover_pois:
|
| 189 |
+
pois_to_sample = avail_pois
|
| 190 |
+
else:
|
| 191 |
+
pois_to_sample = leftover_pois
|
| 192 |
+
if not leftover_e3s:
|
| 193 |
+
e3s_to_sample = avail_e3s
|
| 194 |
+
else:
|
| 195 |
+
e3s_to_sample = leftover_e3s
|
| 196 |
+
if not leftover_linkers:
|
| 197 |
+
linkers_to_sample = avail_linkers
|
| 198 |
+
else:
|
| 199 |
+
linkers_to_sample = leftover_linkers
|
| 200 |
+
|
| 201 |
+
poi_smiles = np.random.choice(list(pois_to_sample))
|
| 202 |
+
e3_smiles = np.random.choice(list(e3s_to_sample))
|
| 203 |
+
linker_smiles = np.random.choice(list(linkers_to_sample))
|
| 204 |
+
|
| 205 |
+
# Get the PROTAC by combining the POI, linker, and E3
|
| 206 |
+
ligands_smiles = '.'.join([poi_smiles, linker_smiles, e3_smiles])
|
| 207 |
+
protac = Chem.MolFromSmiles(ligands_smiles)
|
| 208 |
+
if protac is None:
|
| 209 |
+
continue
|
| 210 |
+
try:
|
| 211 |
+
protac = Chem.molzip(protac)
|
| 212 |
+
except:
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
# Sanitize molecule
|
| 216 |
+
try:
|
| 217 |
+
zero_on_success = Chem.SanitizeMol(protac, catchErrors=True)
|
| 218 |
+
if zero_on_success != 0:
|
| 219 |
+
continue
|
| 220 |
+
protac_smiles = Chem.MolToSmiles(protac, canonical=True)
|
| 221 |
+
except:
|
| 222 |
+
continue
|
| 223 |
+
|
| 224 |
+
if original_df is not None and protac_smiles in original_df['PROTAC SMILES'].values:
|
| 225 |
+
continue
|
| 226 |
+
|
| 227 |
+
# Check if PROTAC can be reassembled
|
| 228 |
+
if not check_reassembly(protac_smiles, ligands_smiles):
|
| 229 |
+
continue
|
| 230 |
+
|
| 231 |
+
covering_df.append({
|
| 232 |
+
'PROTAC SMILES': protac_smiles,
|
| 233 |
+
'POI Ligand SMILES with direction': poi_smiles,
|
| 234 |
+
'Linker SMILES with direction': linker_smiles,
|
| 235 |
+
'E3 Binder SMILES with direction': e3_smiles,
|
| 236 |
+
'POI Ligand Functional Group': None,
|
| 237 |
+
'E3 Binder Functional Group': None,
|
| 238 |
+
})
|
| 239 |
+
|
| 240 |
+
generated_pois.add(poi_smiles)
|
| 241 |
+
generated_e3s.add(e3_smiles)
|
| 242 |
+
generated_linkers.add(linker_smiles)
|
| 243 |
+
|
| 244 |
+
ligands_added = 0
|
| 245 |
+
if poi_smiles in leftover_pois:
|
| 246 |
+
leftover_pois.remove(poi_smiles)
|
| 247 |
+
ligands_added += 1
|
| 248 |
+
if e3_smiles in leftover_e3s:
|
| 249 |
+
leftover_e3s.remove(e3_smiles)
|
| 250 |
+
ligands_added += 1
|
| 251 |
+
if linker_smiles in leftover_linkers:
|
| 252 |
+
leftover_linkers.remove(linker_smiles)
|
| 253 |
+
ligands_added += 1
|
| 254 |
+
|
| 255 |
+
e3_coverage = len(generated_e3s) / len(avail_e3s)
|
| 256 |
+
poi_coverage = len(generated_pois) / len(avail_pois)
|
| 257 |
+
linker_coverage = len(generated_linkers) / len(avail_linkers)
|
| 258 |
+
|
| 259 |
+
# Update the pbar and write the coverage
|
| 260 |
+
pbar.update(ligands_added)
|
| 261 |
+
pbar.set_postfix({
|
| 262 |
+
'POI': f"{poi_coverage:.2%}",
|
| 263 |
+
'E3': f"{e3_coverage:.2%}",
|
| 264 |
+
'Linker': f"{linker_coverage:.2%}",
|
| 265 |
+
})
|
| 266 |
+
|
| 267 |
+
if not leftover_pois and not leftover_e3s and not leftover_linkers:
|
| 268 |
+
break
|
| 269 |
+
|
| 270 |
+
final_df = pd.concat([final_df, pd.DataFrame(covering_df)]).drop_duplicates()
|
| 271 |
+
|
| 272 |
+
# Save to file if specified
|
| 273 |
+
if filename_generated_df:
|
| 274 |
+
final_df.to_csv(filename_generated_df, index=False)
|
| 275 |
+
print(f"Generated PROTACs saved to: {filename_generated_df}")
|
| 276 |
+
|
| 277 |
+
return final_df
|
protac_splitter/display_utils.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
from rdkit import Chem
|
| 6 |
+
from rdkit.Chem import Draw
|
| 7 |
+
|
| 8 |
+
if 'ipykernel' in sys.modules:
|
| 9 |
+
from IPython.display import SVG
|
| 10 |
+
|
| 11 |
+
from .chemoinformatics import get_atom_idx_at_attachment, canonize
|
| 12 |
+
|
| 13 |
+
def safe_display(*args):
|
| 14 |
+
"""Displays content only if running in a Jupyter notebook."""
|
| 15 |
+
if 'ipykernel' in sys.modules:
|
| 16 |
+
display(*args)
|
| 17 |
+
else:
|
| 18 |
+
print(*args)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def display_mol(
|
| 22 |
+
mol: Chem.Mol,
|
| 23 |
+
w: int = 800,
|
| 24 |
+
h: int = 300,
|
| 25 |
+
legend: Optional[str] = None,
|
| 26 |
+
use_smiles_as_legend: bool = True,
|
| 27 |
+
display_svg: bool = True,
|
| 28 |
+
):
|
| 29 |
+
""" Display a molecule in a Jupyter notebook. Useful for having """
|
| 30 |
+
if mol is None:
|
| 31 |
+
print('Molecule is None')
|
| 32 |
+
return None
|
| 33 |
+
if use_smiles_as_legend and legend is None:
|
| 34 |
+
legend = Chem.MolToSmiles(mol)
|
| 35 |
+
if display_svg:
|
| 36 |
+
mol.SetProp("_Name", Chem.MolToSmiles(mol, canonical=True))
|
| 37 |
+
d = Draw.rdMolDraw2D.MolDraw2DSVG(w, h, noFreetype=True)
|
| 38 |
+
font_path = '/System/Library/Fonts/Supplemental/Arial.ttf'
|
| 39 |
+
if os.path.exists(font_path):
|
| 40 |
+
d.fontFile = font_path
|
| 41 |
+
d.DrawMolecule(mol, legend=legend)
|
| 42 |
+
d.FinishDrawing()
|
| 43 |
+
svg = d.GetDrawingText()
|
| 44 |
+
# Check if in Jupyter notebook
|
| 45 |
+
if sys.modules.get('ipykernel', None):
|
| 46 |
+
from IPython.display import SVG
|
| 47 |
+
safe_display(SVG(svg))
|
| 48 |
+
else:
|
| 49 |
+
img = Draw.MolToImage(mol, size=(w, h))
|
| 50 |
+
safe_display(img)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_mapped_protac_img(
|
| 54 |
+
protac_smiles: str,
|
| 55 |
+
poi_smiles: str,
|
| 56 |
+
linker_smiles: str,
|
| 57 |
+
e3_smiles: str,
|
| 58 |
+
w: int = 1000,
|
| 59 |
+
h: int = 1000,
|
| 60 |
+
useSVG: bool = False,
|
| 61 |
+
display_image: bool = False,
|
| 62 |
+
legend: Optional[str] = None,
|
| 63 |
+
show_bond_indices: bool = False,
|
| 64 |
+
):
|
| 65 |
+
""" Display a PROTAC molecule with the POI, linker, and E3 ligase highlighted.
|
| 66 |
+
|
| 67 |
+
If `useSVG` is True, then the POI-Linker bond is highlighted in purple, whereas the E3-Linker bond is highlighted in green.
|
| 68 |
+
If `useSVG` is False, then both splitting points are highlighted in purple.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
protac_smiles: The SMILES string of the PROTAC.
|
| 72 |
+
poi_smiles: The SMILES string of the POI.
|
| 73 |
+
linker_smiles: The SMILES string of the linker.
|
| 74 |
+
e3_smiles: The SMILES string of the E3 ligase.
|
| 75 |
+
w: The width of the image.
|
| 76 |
+
h: The height of the image.
|
| 77 |
+
useSVG: Whether to use SVG format.
|
| 78 |
+
display_image: Whether to display the image.
|
| 79 |
+
legend: The legend to display.
|
| 80 |
+
show_bond_indices: Whether to show bond indices in the image.
|
| 81 |
+
"""
|
| 82 |
+
protac_smiles = canonize(protac_smiles)
|
| 83 |
+
e3_smiles = canonize(e3_smiles)
|
| 84 |
+
poi_smiles = canonize(poi_smiles)
|
| 85 |
+
linker_smiles = canonize(linker_smiles)
|
| 86 |
+
|
| 87 |
+
# Check if any of the canonicalized SMILES is None
|
| 88 |
+
if None in [protac_smiles, e3_smiles, poi_smiles, linker_smiles]:
|
| 89 |
+
return None
|
| 90 |
+
|
| 91 |
+
protac_mol = Chem.MolFromSmiles(protac_smiles)
|
| 92 |
+
e3_mol = Chem.MolFromSmiles(e3_smiles)
|
| 93 |
+
poi_mol = Chem.MolFromSmiles(poi_smiles)
|
| 94 |
+
linker_mol = Chem.MolFromSmiles(linker_smiles)
|
| 95 |
+
|
| 96 |
+
if None in [protac_mol, e3_mol, poi_mol, linker_mol]:
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
if linker_smiles in ['[*:1][*:2]', '[*:2][*:1]']:
|
| 100 |
+
print('WARNING. Linker is empty.')
|
| 101 |
+
poi_attachment_idx = get_atom_idx_at_attachment(protac_mol, poi_mol, e3_mol)
|
| 102 |
+
e3_attachment_idx = get_atom_idx_at_attachment(protac_mol, e3_mol, poi_mol)
|
| 103 |
+
else:
|
| 104 |
+
poi_attachment_idx = get_atom_idx_at_attachment(protac_mol, poi_mol, linker_mol)
|
| 105 |
+
e3_attachment_idx = get_atom_idx_at_attachment(protac_mol, e3_mol, linker_mol)
|
| 106 |
+
|
| 107 |
+
cyan = (0, 1, 1, 0.5)
|
| 108 |
+
red = (1, 0, 0, 0.5)
|
| 109 |
+
green = (0, 1, 0, 0.5)
|
| 110 |
+
blue = (0, 0, 1, 0.5)
|
| 111 |
+
purple = (1, 0, 1, 0.3)
|
| 112 |
+
|
| 113 |
+
highlight_atoms = []
|
| 114 |
+
highlight_bonds = []
|
| 115 |
+
atom_colors = {}
|
| 116 |
+
bond_colors = {}
|
| 117 |
+
|
| 118 |
+
if poi_attachment_idx is not None:
|
| 119 |
+
if len(poi_attachment_idx) != 2:
|
| 120 |
+
if linker_smiles in ['[*:1][*:2]', '[*:2][*:1]']:
|
| 121 |
+
print(f'WARNING. Linker is empty, no highlighting will be showed for the POI.')
|
| 122 |
+
else:
|
| 123 |
+
print(f'WARNING. POI attachment points must be only two, got instead: {poi_attachment_idx}')
|
| 124 |
+
else:
|
| 125 |
+
poi_bond_idx = protac_mol.GetBondBetweenAtoms(*poi_attachment_idx).GetIdx()
|
| 126 |
+
highlight_atoms += poi_attachment_idx
|
| 127 |
+
highlight_bonds.append(poi_bond_idx)
|
| 128 |
+
atom_colors[poi_attachment_idx[0]] = purple
|
| 129 |
+
atom_colors[poi_attachment_idx[1]] = purple
|
| 130 |
+
bond_colors[poi_bond_idx] = purple
|
| 131 |
+
|
| 132 |
+
if e3_attachment_idx is not None:
|
| 133 |
+
if len(e3_attachment_idx) != 2:
|
| 134 |
+
if linker_smiles in ['[*:1][*:2]', '[*:2][*:1]']:
|
| 135 |
+
print(f'WARNING. Linker is empty, no highlighting will be showed for the E3.')
|
| 136 |
+
else:
|
| 137 |
+
print(f'WARNING. E3 attachment points must be only two, got instead: {e3_attachment_idx}')
|
| 138 |
+
else:
|
| 139 |
+
e3_bond_idx = protac_mol.GetBondBetweenAtoms(*e3_attachment_idx).GetIdx()
|
| 140 |
+
highlight_atoms += e3_attachment_idx
|
| 141 |
+
highlight_bonds.append(e3_bond_idx)
|
| 142 |
+
atom_colors[e3_attachment_idx[0]] = green
|
| 143 |
+
atom_colors[e3_attachment_idx[1]] = green
|
| 144 |
+
bond_colors[e3_bond_idx] = green
|
| 145 |
+
|
| 146 |
+
if useSVG:
|
| 147 |
+
drawer = Draw.rdMolDraw2D.MolDraw2DSVG(w, h, noFreetype=True)
|
| 148 |
+
options = drawer.drawOptions()
|
| 149 |
+
options.fontFile = '/System/Library/Fonts/Supplemental/Arial.ttf'
|
| 150 |
+
|
| 151 |
+
if legend is None:
|
| 152 |
+
# legend = '.'.join([e3_smiles, linker_smiles, poi_smiles])
|
| 153 |
+
legend = ""
|
| 154 |
+
|
| 155 |
+
drawer.DrawMolecule(
|
| 156 |
+
protac_mol,
|
| 157 |
+
legend=legend,
|
| 158 |
+
highlightAtoms=highlight_atoms,
|
| 159 |
+
highlightBonds=highlight_bonds,
|
| 160 |
+
highlightAtomColors=atom_colors,
|
| 161 |
+
highlightBondColors=bond_colors,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# Add bond indices as text in the center of each bond
|
| 165 |
+
if show_bond_indices:
|
| 166 |
+
# Needs coordinates; ensure 2D coords present
|
| 167 |
+
Chem.rdDepictor.Compute2DCoords(protac_mol)
|
| 168 |
+
for bond in protac_mol.GetBonds():
|
| 169 |
+
idx = bond.GetIdx()
|
| 170 |
+
begin = bond.GetBeginAtomIdx()
|
| 171 |
+
end = bond.GetEndAtomIdx()
|
| 172 |
+
begin_pos = drawer.GetDrawCoords(begin)
|
| 173 |
+
end_pos = drawer.GetDrawCoords(end)
|
| 174 |
+
mid_y = (begin_pos.y + end_pos.y) / 2
|
| 175 |
+
mid_x = (begin_pos.x + end_pos.x) / 2
|
| 176 |
+
drawer.DrawString(f"{idx}", Chem.rdGeometry.Point2D(mid_x, mid_y), rawCoords=True)
|
| 177 |
+
|
| 178 |
+
drawer.FinishDrawing()
|
| 179 |
+
svg_text = drawer.GetDrawingText()
|
| 180 |
+
|
| 181 |
+
if display_image:
|
| 182 |
+
safe_display(SVG(svg_text))
|
| 183 |
+
|
| 184 |
+
return svg_text
|
| 185 |
+
else:
|
| 186 |
+
img = Draw.MolToImage(
|
| 187 |
+
protac_mol,
|
| 188 |
+
size=(w, h),
|
| 189 |
+
highlightColor=purple,
|
| 190 |
+
highlightAtoms=highlight_atoms,
|
| 191 |
+
highlightBonds=highlight_bonds,
|
| 192 |
+
highlightAtomColors=atom_colors,
|
| 193 |
+
highlightBondColors=bond_colors,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
if display_image:
|
| 197 |
+
safe_display(img)
|
| 198 |
+
|
| 199 |
+
return img
|
protac_splitter/drawing_utils.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from rdkit import Chem, DataStructs
|
| 5 |
+
from rdkit.Chem import (
|
| 6 |
+
AllChem,
|
| 7 |
+
Draw,
|
| 8 |
+
rdFMCS,
|
| 9 |
+
rdMolAlign,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def save_as_svg(svg_content, filename, num_mols):
|
| 14 |
+
"""Save SVG content to a file."""
|
| 15 |
+
with open(filename, 'w') as file:
|
| 16 |
+
data = str(svg_content.data)
|
| 17 |
+
data = data.replace('1500', str(500*num_mols))
|
| 18 |
+
file.write(data)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def align_molecules_2D(ref_mol, to_align_mol):
|
| 22 |
+
AllChem.Compute2DCoords(ref_mol)
|
| 23 |
+
AllChem.Compute2DCoords(to_align_mol)
|
| 24 |
+
# Find the maximum common substructure and use it to align molecules
|
| 25 |
+
mcs = rdFMCS.FindMCS([ref_mol, to_align_mol])
|
| 26 |
+
mcs_mol = Chem.MolFromSmarts(mcs.smartsString)
|
| 27 |
+
ref_match = ref_mol.GetSubstructMatch(mcs_mol)
|
| 28 |
+
align_match = to_align_mol.GetSubstructMatch(mcs_mol)
|
| 29 |
+
atom_map = list(zip(align_match, ref_match))
|
| 30 |
+
rdMolAlign.AlignMol(to_align_mol, ref_mol, atomMap=atom_map)
|
| 31 |
+
return to_align_mol
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def align_molecules_by_coordinates(ref_mol, to_align_mol):
|
| 35 |
+
# Find the maximum common substructure
|
| 36 |
+
AllChem.Compute2DCoords(to_align_mol)
|
| 37 |
+
mcs = rdFMCS.FindMCS([ref_mol, to_align_mol])
|
| 38 |
+
mcs_mol = Chem.MolFromSmarts(mcs.smartsString)
|
| 39 |
+
ref_match = ref_mol.GetSubstructMatch(mcs_mol)
|
| 40 |
+
align_match = to_align_mol.GetSubstructMatch(mcs_mol)
|
| 41 |
+
|
| 42 |
+
# Copy the coordinates from the reference molecule to the molecule to be aligned
|
| 43 |
+
ref_conf = ref_mol.GetConformer()
|
| 44 |
+
align_conf = to_align_mol.GetConformer()
|
| 45 |
+
for ref_idx, align_idx in zip(ref_match, align_match):
|
| 46 |
+
ref_pos = ref_conf.GetAtomPosition(ref_idx)
|
| 47 |
+
align_conf.SetAtomPosition(align_idx, ref_pos)
|
| 48 |
+
|
| 49 |
+
return to_align_mol
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def draw_molecule_to_svg(mol, size=(500, 500), scale=1.0):
|
| 53 |
+
drawer = Draw.rdMolDraw2D.MolDraw2DSVG(size[0], size[1])
|
| 54 |
+
drawer.drawOptions().fixedBondLength = scale
|
| 55 |
+
drawer.DrawMolecule(mol)
|
| 56 |
+
drawer.FinishDrawing()
|
| 57 |
+
svg = drawer.GetDrawingText()
|
| 58 |
+
svg = re.sub(r'\<\?xml.*?\?\>', '', svg) # Remove XML declaration
|
| 59 |
+
svg = svg.replace('<svg', '<g').replace(
|
| 60 |
+
'</svg>', '</g>') # Replace svg tags with g tags
|
| 61 |
+
return svg
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def combine_svgs(svgs, output_filename, dimensions=None, size=(500, 500), xy_shifts=None):
|
| 65 |
+
if dimensions is None:
|
| 66 |
+
dimensions = (len(svgs), 1)
|
| 67 |
+
if xy_shifts is None:
|
| 68 |
+
xy_shifts = [(0, 0) for i in range(dimensions[0]*dimensions[1])]
|
| 69 |
+
|
| 70 |
+
width, height = size
|
| 71 |
+
grid_width, grid_height = dimensions
|
| 72 |
+
# Include only one XML declaration and the opening <svg> tag
|
| 73 |
+
combined_svg = f'<?xml version="1.0" standalone="no"?>\n'
|
| 74 |
+
combined_svg += f'<svg width="{grid_width * width}px" height="{grid_height * height}px" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">\n'
|
| 75 |
+
|
| 76 |
+
# Arrange SVGs in a grid
|
| 77 |
+
for i, (svg, xy_shift) in enumerate(zip(svgs, xy_shifts)):
|
| 78 |
+
x = (i % grid_width) * width
|
| 79 |
+
y = (i // grid_width) * height
|
| 80 |
+
combined_svg += f'<g transform="translate({x+xy_shift[0]},{y-xy_shift[1]})">{svg}</g>\n'
|
| 81 |
+
|
| 82 |
+
combined_svg += '</svg>'
|
| 83 |
+
with open(output_filename, 'w') as file:
|
| 84 |
+
file.write(combined_svg)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def draw_molecule_with_highlighted_bonds(mol, bonds_to_highlight):
|
| 88 |
+
"""
|
| 89 |
+
Draws a molecule with specified atoms and bonds highlighted.
|
| 90 |
+
|
| 91 |
+
Parameters:
|
| 92 |
+
- smiles (str): SMILES string for the molecule.
|
| 93 |
+
- atoms_to_highlight (set): Set of atom indices to highlight.
|
| 94 |
+
- bonds_to_highlight (list): List of bond indices to highlight.
|
| 95 |
+
- highlight_bond_colors (dict): Dictionary mapping bond indices to colors.
|
| 96 |
+
"""
|
| 97 |
+
# Create molecule from SMILES
|
| 98 |
+
|
| 99 |
+
# Initialize drawer
|
| 100 |
+
d2d = Draw.rdMolDraw2D.MolDraw2DSVG(350*2, 300*2)
|
| 101 |
+
|
| 102 |
+
# Set drawing options
|
| 103 |
+
d2d.drawOptions().useBWAtomPalette()
|
| 104 |
+
d2d.drawOptions().continuousHighlight = False
|
| 105 |
+
d2d.drawOptions().highlightBondWidthMultiplier = 24
|
| 106 |
+
d2d.drawOptions().setHighlightColour((0, 0, 1))
|
| 107 |
+
d2d.drawOptions().fillHighlights = False
|
| 108 |
+
|
| 109 |
+
# Draw the molecule with highlights
|
| 110 |
+
d2d.DrawMolecule(mol,
|
| 111 |
+
highlightAtoms=[],
|
| 112 |
+
highlightBonds=bonds_to_highlight)
|
| 113 |
+
d2d.FinishDrawing()
|
| 114 |
+
|
| 115 |
+
# Convert drawing to image and display
|
| 116 |
+
svg = d2d.GetDrawingText()
|
| 117 |
+
svg = svg.replace('svg:', '')
|
| 118 |
+
|
| 119 |
+
return svg
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def align_mol_2D_ver2(template, query):
|
| 123 |
+
mcs = rdFMCS.FindMCS([template, query])
|
| 124 |
+
patt = Chem.MolFromSmarts(mcs.smartsString)
|
| 125 |
+
|
| 126 |
+
query_match = query.GetSubstructMatch(patt)
|
| 127 |
+
template_match = template.GetSubstructMatch(patt)
|
| 128 |
+
|
| 129 |
+
rms = AllChem.AlignMol(query, template, atomMap=list(
|
| 130 |
+
zip(query_match, template_match)))
|
| 131 |
+
return template, query
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def transform_molecule(mol, degrees, translate_x=0, translate_y=0, flip_x_axis=False):
|
| 135 |
+
"""Apply rotation, translation, and optionally flip the molecule."""
|
| 136 |
+
radians = np.deg2rad(degrees)
|
| 137 |
+
rotation_matrix = np.array([
|
| 138 |
+
[np.cos(radians), -np.sin(radians), 0],
|
| 139 |
+
[np.sin(radians), np.cos(radians), 0],
|
| 140 |
+
[0, 0, 1]
|
| 141 |
+
])
|
| 142 |
+
AllChem.Compute2DCoords(mol)
|
| 143 |
+
|
| 144 |
+
conf = mol.GetConformer()
|
| 145 |
+
for i in range(conf.GetNumAtoms()):
|
| 146 |
+
pos = np.array(conf.GetAtomPosition(i))
|
| 147 |
+
new_pos = np.dot(rotation_matrix, pos)
|
| 148 |
+
new_pos[0] += translate_x # Translate along the x-axis
|
| 149 |
+
new_pos[1] += translate_y # Translate along the y-axis
|
| 150 |
+
if flip_x_axis:
|
| 151 |
+
new_pos[1] = -new_pos[1] # Flip along the x-axis
|
| 152 |
+
conf.SetAtomPosition(i, new_pos)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def tailored_framework_example(mol_ms):
|
| 156 |
+
# remove lone atoms
|
| 157 |
+
# define all atoms to be atom number 1
|
| 158 |
+
# define all bonds to be single bonds
|
| 159 |
+
|
| 160 |
+
mol_ms_w = Chem.RWMol(mol_ms)
|
| 161 |
+
atom_idx_to_remove = []
|
| 162 |
+
for atom in mol_ms_w.GetAtoms():
|
| 163 |
+
# lone atom. Need to remove it to create the generic framework.
|
| 164 |
+
if atom.GetDegree() == 1:
|
| 165 |
+
atom_idx_to_remove.append(atom.GetIdx())
|
| 166 |
+
continue
|
| 167 |
+
atom.SetAtomicNum(0)
|
| 168 |
+
|
| 169 |
+
for bond in mol_ms_w.GetBonds():
|
| 170 |
+
bond.SetBondType(Chem.rdchem.BondType.SINGLE)
|
| 171 |
+
|
| 172 |
+
atom_idx_to_remove.sort(reverse=True)
|
| 173 |
+
for atom_idx in atom_idx_to_remove:
|
| 174 |
+
mol_ms_w.RemoveAtom(atom_idx)
|
| 175 |
+
|
| 176 |
+
mol_ms_new = mol_ms_w.GetMol()
|
| 177 |
+
return mol_ms_new
|
protac_splitter/evaluation.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Evaluation functions for the protac_splitter package. They need to be generic to accomodate predictions coming from different models. """
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import re
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Tuple, Any, Dict, Optional, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from rdkit import Chem, RDLogger
|
| 10 |
+
from rdkit.Chem import DataStructs
|
| 11 |
+
|
| 12 |
+
# Disable RDKit logging: when checking SMILES validity, we suppress warnings
|
| 13 |
+
RDLogger.DisableLog("rdApp.*")
|
| 14 |
+
|
| 15 |
+
from .chemoinformatics import (
|
| 16 |
+
canonize,
|
| 17 |
+
canonize_smiles,
|
| 18 |
+
remove_stereo,
|
| 19 |
+
get_substr_match,
|
| 20 |
+
)
|
| 21 |
+
from .protac_cheminformatics import reassemble_protac
|
| 22 |
+
from .graphs_utils import (
|
| 23 |
+
get_smiles2graph_edit_distance,
|
| 24 |
+
get_smiles2graph_edit_distance_norm,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def is_valid_smiles(
|
| 29 |
+
smiles: Optional[str],
|
| 30 |
+
return_mol: bool = False,
|
| 31 |
+
) -> Union[bool, Tuple[bool, Chem.Mol]]:
|
| 32 |
+
""" Check if a SMILES is valid, i.e., it can be parsed by RDKit.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
smiles (Optional[str]): The SMILES to check.
|
| 36 |
+
return_mol (bool): If True, return the RDKit molecule object, i.e., `(is_valid, mol)`.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
bool | Tuple[bool, Chem.Mol]: True if the SMILES is valid, False otherwise. If return_mol is True, also return the RDKit molecule object.
|
| 40 |
+
"""
|
| 41 |
+
if smiles is None:
|
| 42 |
+
return False
|
| 43 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 44 |
+
if return_mol:
|
| 45 |
+
return mol is not None, mol
|
| 46 |
+
return mol is not None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def has_three_substructures(smiles: Optional[str]) -> bool:
|
| 50 |
+
""" Check if a PROTAC SMILES has three substructures. """
|
| 51 |
+
if smiles is None:
|
| 52 |
+
return False
|
| 53 |
+
return smiles.count(".") == 2
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def has_all_attachment_points(smiles: Optional[str]) -> bool:
|
| 57 |
+
""" Check if a PROTAC SMILES has all attachment points, i.e., [*:1] and [*:2], two each. """
|
| 58 |
+
if smiles is None:
|
| 59 |
+
return False
|
| 60 |
+
return smiles.count("[*:1]") == 2 and smiles.count("[*:2]") == 2
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def split_prediction(
|
| 64 |
+
pred: str,
|
| 65 |
+
poi_attachment_id: int = 1,
|
| 66 |
+
e3_attachment_id: int = 2,
|
| 67 |
+
) -> Optional[dict[str, str]]:
|
| 68 |
+
""" Split a PROTAC SMILES prediction into its three substructures.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
pred (str): The SMILES of the PROTAC molecule.
|
| 72 |
+
poi_attachment_id (int): The attachment point ID for the POI substructure.
|
| 73 |
+
e3_attachment_id (int): The attachment point ID for the E3 substructure.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
dict[str, str] | None: A dictionary (with keys: 'e3', 'linker', 'poi') containing the SMILES notations for the POI, linker, and E3 substructures, or None if the prediction is invalid
|
| 77 |
+
"""
|
| 78 |
+
ret = {k: None for k in ['poi', 'linker', 'e3']}
|
| 79 |
+
if pred is None:
|
| 80 |
+
return ret
|
| 81 |
+
|
| 82 |
+
ligands = pred.split('.')
|
| 83 |
+
if len(ligands) != 3:
|
| 84 |
+
return ret
|
| 85 |
+
|
| 86 |
+
for ligand in ligands:
|
| 87 |
+
if f'[*:{poi_attachment_id}]' in ligand and f'[*:{e3_attachment_id}]' not in ligand:
|
| 88 |
+
ret['poi'] = ligand
|
| 89 |
+
elif f'[*:{e3_attachment_id}]' in ligand and f'[*:{poi_attachment_id}]' not in ligand:
|
| 90 |
+
ret['e3'] = ligand
|
| 91 |
+
elif f'[*:{poi_attachment_id}]' in ligand and f'[*:{e3_attachment_id}]' in ligand:
|
| 92 |
+
ret['linker'] = ligand
|
| 93 |
+
return ret
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def rename_attachment_id(mol: Union[str, Chem.Mol], old_id: int, new_id: int) -> Union[str, Chem.Mol]:
|
| 97 |
+
""" Rename an attachment point ID in a molecule.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
mol: The input molecule.
|
| 101 |
+
old_id: The old attachment point ID.
|
| 102 |
+
new_id: The new attachment point ID.
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
The renamed molecule.
|
| 106 |
+
"""
|
| 107 |
+
return_str = False
|
| 108 |
+
if isinstance(mol, Chem.Mol):
|
| 109 |
+
mol = Chem.MolToSmiles(mol, canonical=True)
|
| 110 |
+
return_str = True
|
| 111 |
+
# Regex-replace the patterns "[*:old_id]" or "[old_id*]" with "[*:new_id]"
|
| 112 |
+
mol = re.sub(rf'\[\*:{old_id}\]', f'[*:{new_id}]', mol)
|
| 113 |
+
mol = re.sub(rf'\[{old_id}\*\]', f'[*:{new_id}]', mol)
|
| 114 |
+
mol = canonize_smiles(mol)
|
| 115 |
+
if mol is None:
|
| 116 |
+
return None
|
| 117 |
+
mol = Chem.MolFromSmiles(mol)
|
| 118 |
+
if return_str:
|
| 119 |
+
return Chem.MolToSmiles(mol, canonical=True)
|
| 120 |
+
return mol
|
| 121 |
+
|
| 122 |
+
def at_least_two_ligands_correct(
|
| 123 |
+
protac_smiles: str,
|
| 124 |
+
ligands_smiles: str,
|
| 125 |
+
) -> bool:
|
| 126 |
+
""" Check if at least two ligands are correct. """
|
| 127 |
+
# Check if there is at least one "." in the ligands SMILES
|
| 128 |
+
if "." not in ligands_smiles:
|
| 129 |
+
return False
|
| 130 |
+
ligands = ligands_smiles.split(".")
|
| 131 |
+
return True
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def check_reassembly(
|
| 135 |
+
protac_smiles: str,
|
| 136 |
+
ligands_smiles: str,
|
| 137 |
+
stats: Optional[Dict[str, int]] = None,
|
| 138 |
+
linker_can_be_null: bool = False,
|
| 139 |
+
poi_attachment_id: int = 1,
|
| 140 |
+
e3_attachment_id: int = 2,
|
| 141 |
+
verbose: int = 0,
|
| 142 |
+
return_reassembled_smiles: bool = False,
|
| 143 |
+
) -> bool:
|
| 144 |
+
"""Check if the reassembled PROTAC matches the original PROTAC SMILES.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
protac_smiles (str): The original PROTAC SMILES.
|
| 148 |
+
ligands_smiles (str): The SMILES of the joined PROTAC ligands, separated by a "." (dot).
|
| 149 |
+
stats (Optional[Dict[str, int]]): A dictionary to store statistics about the reassembly process.
|
| 150 |
+
linker_can_be_null (bool): If False, the linker cannot be empty, and if so, a None will be returned. If True, a special check is performed to rename the E3 and WH attchament points to assemble them together.
|
| 151 |
+
poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]". Default is 1.
|
| 152 |
+
e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]". Default is 2.
|
| 153 |
+
verbose (int): The verbosity
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
bool: True if the reassembled PROTAC matches the original PROTAC SMILES, False otherwise. None if it failed.
|
| 157 |
+
"""
|
| 158 |
+
ligands_smiles = canonize_smiles(ligands_smiles)
|
| 159 |
+
if ligands_smiles is None:
|
| 160 |
+
if verbose:
|
| 161 |
+
logging.error('Ligand could be canonicalized.')
|
| 162 |
+
return (False, None) if return_reassembled_smiles else False
|
| 163 |
+
|
| 164 |
+
null_linker_e3 = f'[*:{e3_attachment_id}][*:{poi_attachment_id}]'
|
| 165 |
+
null_linker_poi = f'[*:{poi_attachment_id}][*:{e3_attachment_id}]'
|
| 166 |
+
linker_is_null = False
|
| 167 |
+
if null_linker_e3 in ligands_smiles or null_linker_poi in ligands_smiles:
|
| 168 |
+
# If the linker is empty, remove the linker atoms
|
| 169 |
+
ligands_smiles = ligands_smiles.replace(null_linker_poi, '')
|
| 170 |
+
ligands_smiles = ligands_smiles.replace(null_linker_e3, '')
|
| 171 |
+
ligands_smiles = ligands_smiles.replace('..', '.')
|
| 172 |
+
ligands_smiles = ligands_smiles.rstrip('.')
|
| 173 |
+
ligands_smiles = ligands_smiles.lstrip('.')
|
| 174 |
+
ligands_smiles = canonize_smiles(ligands_smiles)
|
| 175 |
+
linker_is_null = True
|
| 176 |
+
|
| 177 |
+
if linker_can_be_null or linker_is_null:
|
| 178 |
+
if len(ligands_smiles.split('.')) == 2:
|
| 179 |
+
# Replace the attachment points with a third one (they will be joined later)
|
| 180 |
+
ligands_smiles = rename_attachment_id(ligands_smiles, e3_attachment_id, max([poi_attachment_id, e3_attachment_id]) + 1)
|
| 181 |
+
ligands_smiles = rename_attachment_id(ligands_smiles, poi_attachment_id, max([poi_attachment_id, e3_attachment_id]) + 1)
|
| 182 |
+
|
| 183 |
+
ligands_mol = Chem.MolFromSmiles(ligands_smiles)
|
| 184 |
+
if ligands_mol is None:
|
| 185 |
+
if verbose:
|
| 186 |
+
logging.error('ligands_mol is None')
|
| 187 |
+
return (False, None) if return_reassembled_smiles else False
|
| 188 |
+
|
| 189 |
+
try:
|
| 190 |
+
reassembled_mol = Chem.molzip(ligands_mol)
|
| 191 |
+
if reassembled_mol is None:
|
| 192 |
+
if stats is not None:
|
| 193 |
+
stats['molzip failed'] += 1
|
| 194 |
+
if verbose:
|
| 195 |
+
logging.error(f'molzip failed')
|
| 196 |
+
return (False, None) if return_reassembled_smiles else False
|
| 197 |
+
except:
|
| 198 |
+
if stats is not None:
|
| 199 |
+
stats['molzip failed (exception)'] += 1
|
| 200 |
+
if verbose:
|
| 201 |
+
logging.error(f'molzip failed (exception)')
|
| 202 |
+
return (False, None) if return_reassembled_smiles else False
|
| 203 |
+
|
| 204 |
+
try:
|
| 205 |
+
reassembled_smiles = canonize(Chem.MolToSmiles(reassembled_mol))
|
| 206 |
+
if reassembled_smiles is None:
|
| 207 |
+
if stats is not None:
|
| 208 |
+
stats['MolToSmiles of reassembled failed'] += 1
|
| 209 |
+
if verbose:
|
| 210 |
+
logging.error('MolToSmiles of reassembled failed')
|
| 211 |
+
return (False, None) if return_reassembled_smiles else False
|
| 212 |
+
except:
|
| 213 |
+
if stats is not None:
|
| 214 |
+
stats['MolToSmiles of reassembled failed'] += 1
|
| 215 |
+
if verbose:
|
| 216 |
+
logging.error('MolToSmiles of reassembled failed')
|
| 217 |
+
return (False, None) if return_reassembled_smiles else False
|
| 218 |
+
|
| 219 |
+
is_equal = canonize(protac_smiles) == reassembled_smiles
|
| 220 |
+
|
| 221 |
+
return (is_equal, reassembled_smiles) if return_reassembled_smiles else is_equal
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def check_substructs(
|
| 225 |
+
protac_smiles: str,
|
| 226 |
+
poi_smiles: str = None,
|
| 227 |
+
linker_smiles: str = None,
|
| 228 |
+
e3_smiles: str = None,
|
| 229 |
+
return_bond_types: bool = False,
|
| 230 |
+
poi_attachment_id: int = 1,
|
| 231 |
+
e3_attachment_id: int = 2,
|
| 232 |
+
pred: str = None,
|
| 233 |
+
) -> Union[bool, Tuple[bool, dict[str, str]]]:
|
| 234 |
+
""" DEPRECATED.
|
| 235 |
+
|
| 236 |
+
Check if the reassembled PROTAC is correct.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
protac_smiles (str): The SMILES of the PROTAC molecule.
|
| 240 |
+
poi_smiles (str): The SMILES of the POI ligand.
|
| 241 |
+
linker_smiles (str): The SMILES of the linker.
|
| 242 |
+
e3_smiles (str): The SMILES of the E3 binder.
|
| 243 |
+
return_bond_types (bool): If True, return the bond types used for the reassembly.
|
| 244 |
+
poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]".
|
| 245 |
+
e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]".
|
| 246 |
+
pred (str): The SMILES of the predicted PROTAC molecule.
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
bool | Tuple[bool, dict[str, str]]: True if the reassembled PROTAC is correct, False otherwise. If return_bond_types is True, also return the bond types used for the reassembly.
|
| 250 |
+
"""
|
| 251 |
+
def get_failed_return():
|
| 252 |
+
if return_bond_types:
|
| 253 |
+
return False, {}
|
| 254 |
+
return False
|
| 255 |
+
|
| 256 |
+
# Make some checks before starting and fail if necessary
|
| 257 |
+
all_subs_none = all(v is None for v in [poi_smiles, linker_smiles, e3_smiles])
|
| 258 |
+
any_subs_none = any(v is None for v in [poi_smiles, linker_smiles, e3_smiles])
|
| 259 |
+
|
| 260 |
+
if pred is not None and all_subs_none:
|
| 261 |
+
# Split the prediction into the substructures
|
| 262 |
+
pred_substructs = split_prediction(pred, poi_attachment_id, e3_attachment_id)
|
| 263 |
+
if any(v is None for v in pred_substructs.values()):
|
| 264 |
+
return get_failed_return()
|
| 265 |
+
poi_smiles = pred_substructs['poi']
|
| 266 |
+
linker_smiles = pred_substructs['linker']
|
| 267 |
+
e3_smiles = pred_substructs['e3']
|
| 268 |
+
elif pred is None and any_subs_none:
|
| 269 |
+
return get_failed_return()
|
| 270 |
+
elif pred is None and all_subs_none:
|
| 271 |
+
logging.warning("Arguments 'pred' and 'poi_smiles', 'linker_smiles', 'e3_smiles' cannot be all None.")
|
| 272 |
+
return get_failed_return()
|
| 273 |
+
|
| 274 |
+
if f"[*:{poi_attachment_id}]" in e3_smiles:
|
| 275 |
+
return get_failed_return()
|
| 276 |
+
if f"[*:{e3_attachment_id}]" in poi_smiles:
|
| 277 |
+
return get_failed_return()
|
| 278 |
+
if f"[*:{poi_attachment_id}]" not in linker_smiles:
|
| 279 |
+
return get_failed_return()
|
| 280 |
+
if f"[*:{e3_attachment_id}]" not in linker_smiles:
|
| 281 |
+
return get_failed_return()
|
| 282 |
+
|
| 283 |
+
correct_substructs = False
|
| 284 |
+
protac_mol = Chem.MolFromSmiles(protac_smiles)
|
| 285 |
+
protac_inchi = Chem.MolToInchi(protac_mol)
|
| 286 |
+
protac_smiles_canon = canonize_smiles(protac_smiles)
|
| 287 |
+
bond_types = {}
|
| 288 |
+
bonds = ['single', 'double', 'triple']
|
| 289 |
+
# for e3_bond_type, poi_bond_type in itertools.product([bonds, bonds]):
|
| 290 |
+
for e3_bond_type in bonds:
|
| 291 |
+
for poi_bond_type in bonds:
|
| 292 |
+
try:
|
| 293 |
+
assmbl_smiles, assmbl_mol = reassemble_protac(
|
| 294 |
+
poi_smiles,
|
| 295 |
+
linker_smiles,
|
| 296 |
+
e3_smiles,
|
| 297 |
+
e3_bond_type,
|
| 298 |
+
poi_bond_type,
|
| 299 |
+
poi_attachment_id,
|
| 300 |
+
e3_attachment_id,
|
| 301 |
+
)
|
| 302 |
+
if assmbl_mol is not None:
|
| 303 |
+
# If either the InChI or SMILES of the reassembled PROTAC is
|
| 304 |
+
# the same as the original PROTAC, then the reassembly is
|
| 305 |
+
# correct.
|
| 306 |
+
if protac_inchi == Chem.MolToInchi(assmbl_mol):
|
| 307 |
+
correct_substructs = True
|
| 308 |
+
bond_types['e3_bond_type'] = e3_bond_type
|
| 309 |
+
bond_types['poi_bond_type'] = poi_bond_type
|
| 310 |
+
break
|
| 311 |
+
if protac_smiles_canon == canonize_smiles(assmbl_smiles):
|
| 312 |
+
correct_substructs = True
|
| 313 |
+
bond_types['e3_bond_type'] = e3_bond_type
|
| 314 |
+
bond_types['poi_bond_type'] = poi_bond_type
|
| 315 |
+
break
|
| 316 |
+
except:
|
| 317 |
+
continue
|
| 318 |
+
if return_bond_types:
|
| 319 |
+
return correct_substructs, bond_types
|
| 320 |
+
return correct_substructs
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def score_prediction(
|
| 324 |
+
protac_smiles: str,
|
| 325 |
+
label_smiles: str,
|
| 326 |
+
pred_smiles: str,
|
| 327 |
+
rouge = None,
|
| 328 |
+
poi_attachment_id: int = 1,
|
| 329 |
+
e3_attachment_id: int = 2,
|
| 330 |
+
fpgen = Chem.rdFingerprintGenerator.GetMorganGenerator(radius=11, fpSize=2048),
|
| 331 |
+
compute_rdkit_metrics: bool = False,
|
| 332 |
+
compute_graph_metrics: bool = False,
|
| 333 |
+
graph_edit_kwargs: Dict[str, Any] = {},
|
| 334 |
+
) -> dict[str, float]:
|
| 335 |
+
""" Score a PROTAC SMILES prediction.
|
| 336 |
+
|
| 337 |
+
Args:
|
| 338 |
+
protac_smiles (str): The SMILES of the PROTAC molecule.
|
| 339 |
+
label_smiles (str): The SMILES of the ground truth PROTAC molecule.
|
| 340 |
+
pred_smiles (str): The SMILES of the predicted PROTAC molecule.
|
| 341 |
+
rouge (Rouge | None): The Rouge object to use for scoring. If None, do not compute Rouge scores. Example: `rouge = evaluate.load("rouge")`
|
| 342 |
+
poi_attachment_id (int): The attachment point ID for the POI substructure.
|
| 343 |
+
e3_attachment_id (int): The attachment point ID for the E3 substructure.
|
| 344 |
+
|
| 345 |
+
Returns:
|
| 346 |
+
dict[str, float]: A dictionary containing the scores for the prediction
|
| 347 |
+
"""
|
| 348 |
+
protac_mol = Chem.MolFromSmiles(protac_smiles)
|
| 349 |
+
protac_num_atoms = protac_mol.GetNumHeavyAtoms()
|
| 350 |
+
|
| 351 |
+
scores = {
|
| 352 |
+
'has_three_substructures': has_three_substructures(pred_smiles),
|
| 353 |
+
'has_all_attachment_points': has_all_attachment_points(pred_smiles),
|
| 354 |
+
'num_fragments': 0 if pred_smiles is None else pred_smiles.count('.') + 1,
|
| 355 |
+
'tanimoto_similarity': 0.0, # Default value
|
| 356 |
+
'valid': False,
|
| 357 |
+
'reassembly': False,
|
| 358 |
+
'reassembly_nostereo': False,
|
| 359 |
+
'heavy_atoms_difference': protac_num_atoms,
|
| 360 |
+
'heavy_atoms_difference_norm': 1.0,
|
| 361 |
+
'all_ligands_equal': False,
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
pred_substructs = split_prediction(pred_smiles, poi_attachment_id, e3_attachment_id)
|
| 365 |
+
|
| 366 |
+
# Compute metrics for the "entire" predicted PROTAC molecule
|
| 367 |
+
if None not in list(pred_substructs.values()):
|
| 368 |
+
e3_nostereo = remove_stereo(pred_substructs['e3'])
|
| 369 |
+
linker_nostereo = remove_stereo(pred_substructs['linker'])
|
| 370 |
+
poi_nostereo = remove_stereo(pred_substructs['poi'])
|
| 371 |
+
if None not in [e3_nostereo, linker_nostereo, poi_nostereo]:
|
| 372 |
+
pred_nostereo = f"{e3_nostereo}.{linker_nostereo}.{poi_nostereo}"
|
| 373 |
+
scores['reassembly_nostereo'] = check_reassembly(remove_stereo(protac_smiles), pred_nostereo)
|
| 374 |
+
|
| 375 |
+
scores['valid'] = is_valid_smiles(pred_smiles)
|
| 376 |
+
is_equal, reassembled_smiles = check_reassembly(protac_smiles, pred_smiles, return_reassembled_smiles=True)
|
| 377 |
+
scores['reassembly'] = is_equal
|
| 378 |
+
|
| 379 |
+
# Get the number of heavy atoms difference between the reassembled PROTAC and the ground truth PROTAC
|
| 380 |
+
if reassembled_smiles is not None:
|
| 381 |
+
reassembled_mol = Chem.MolFromSmiles(reassembled_smiles)
|
| 382 |
+
if reassembled_mol is not None:
|
| 383 |
+
scores['heavy_atoms_difference'] -= reassembled_mol.GetNumHeavyAtoms()
|
| 384 |
+
scores['heavy_atoms_difference_norm'] = scores['heavy_atoms_difference'] / protac_num_atoms
|
| 385 |
+
|
| 386 |
+
if scores['valid'] and compute_rdkit_metrics and fpgen is not None:
|
| 387 |
+
# Get Tanimoto similarity between the predicted PROTAC and the ground truth PROTAC
|
| 388 |
+
pred_mol = Chem.MolFromSmiles(pred_smiles)
|
| 389 |
+
label_mol = Chem.MolFromSmiles(label_smiles)
|
| 390 |
+
pred_fp = fpgen.GetFingerprint(pred_mol)
|
| 391 |
+
label_fp = fpgen.GetFingerprint(label_mol)
|
| 392 |
+
scores['tanimoto_similarity'] = DataStructs.TanimotoSimilarity(pred_fp, label_fp)
|
| 393 |
+
|
| 394 |
+
if rouge is not None:
|
| 395 |
+
rouge_output = rouge.compute(predictions=[pred_smiles], references=[label_smiles])
|
| 396 |
+
scores.update({k: v for k, v in rouge_output.items()})
|
| 397 |
+
|
| 398 |
+
# Compute metrics for each substructure
|
| 399 |
+
label_substructs = split_prediction(label_smiles, poi_attachment_id, e3_attachment_id)
|
| 400 |
+
|
| 401 |
+
# Set default values
|
| 402 |
+
for sub in ['e3', 'poi', 'linker']:
|
| 403 |
+
scores[f'{sub}_valid'] = False
|
| 404 |
+
scores[f'{sub}_equal'] = False
|
| 405 |
+
scores[f'{sub}_has_attachment_point(s)'] = False
|
| 406 |
+
scores[f'{sub}_tanimoto_similarity'] = 0.0
|
| 407 |
+
|
| 408 |
+
# NOTE: The graph edit distance can be very high and dependant on the
|
| 409 |
+
# graphs, but when the molecule is not valid, then we cannot compute it.
|
| 410 |
+
# Because of that, we instead set it to something very large, in case we
|
| 411 |
+
# need to sum the eval metrics.
|
| 412 |
+
scores[f'{sub}_graph_edit_distance'] = 1e64
|
| 413 |
+
scores[f'{sub}_graph_edit_distance_norm'] = 1.0
|
| 414 |
+
scores[f'{sub}_heavy_atoms_difference'] = 0
|
| 415 |
+
try:
|
| 416 |
+
scores[f'{sub}_heavy_atoms_difference'] = Chem.MolFromSmiles(label_substructs[sub]).GetNumHeavyAtoms()
|
| 417 |
+
except:
|
| 418 |
+
logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")
|
| 419 |
+
scores[f'{sub}_heavy_atoms_difference_norm'] = 1.0
|
| 420 |
+
|
| 421 |
+
# Calculate metrics for each substructure
|
| 422 |
+
for sub in ['e3', 'poi', 'linker']:
|
| 423 |
+
# Skip if the predicted substructure is None from `split_prediction`
|
| 424 |
+
pred_sub = pred_substructs[sub]
|
| 425 |
+
label_sub = label_substructs[sub]
|
| 426 |
+
if pred_sub is None:
|
| 427 |
+
continue
|
| 428 |
+
if label_sub is None:
|
| 429 |
+
logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")
|
| 430 |
+
continue
|
| 431 |
+
|
| 432 |
+
# Check if the predicted substructure is a valid RDKit molecule
|
| 433 |
+
sub_valid, sub_mol = is_valid_smiles(pred_sub, return_mol=True)
|
| 434 |
+
scores[f'{sub}_valid'] = sub_valid
|
| 435 |
+
|
| 436 |
+
if sub_mol is None:
|
| 437 |
+
continue
|
| 438 |
+
|
| 439 |
+
# Check if the predicted substructure has the correct attachment point(s)
|
| 440 |
+
if sub == 'e3':
|
| 441 |
+
if f'[*:{e3_attachment_id}]' in pred_sub and f'[*:{poi_attachment_id}]' not in pred_sub:
|
| 442 |
+
scores[f'{sub}_has_attachment_point(s)'] = True
|
| 443 |
+
elif sub == 'poi':
|
| 444 |
+
if f'[*:{poi_attachment_id}]' in pred_sub and f'[*:{e3_attachment_id}]' not in pred_sub:
|
| 445 |
+
scores[f'{sub}_has_attachment_point(s)'] = True
|
| 446 |
+
elif sub == 'linker':
|
| 447 |
+
if f'[*:{poi_attachment_id}]' in pred_sub and f'[*:{e3_attachment_id}]' in pred_sub:
|
| 448 |
+
scores[f'{sub}_has_attachment_point(s)'] = True
|
| 449 |
+
|
| 450 |
+
# Check if the predicted substructure InChI is the same as the ground truth substructure InChI
|
| 451 |
+
if scores[f'{sub}_valid']:
|
| 452 |
+
# scores[f'{sub}_equal'] = Chem.MolToInchi(sub_mol) == Chem.MolToInchi(Chem.MolFromSmiles(label_sub))
|
| 453 |
+
canon_pred = canonize_smiles(pred_sub)
|
| 454 |
+
canon_label = canonize_smiles(label_sub)
|
| 455 |
+
scores[f'{sub}_equal'] = canon_pred == canon_label
|
| 456 |
+
|
| 457 |
+
# Compute graph-related metrics
|
| 458 |
+
if scores[f'{sub}_valid'] and compute_graph_metrics:
|
| 459 |
+
scores[f'{sub}_graph_edit_distance'] = get_smiles2graph_edit_distance(pred_sub, label_sub, **graph_edit_kwargs)
|
| 460 |
+
scores[f'{sub}_graph_edit_distance_norm'] = get_smiles2graph_edit_distance_norm(
|
| 461 |
+
smi1=pred_sub,
|
| 462 |
+
smi2=label_sub,
|
| 463 |
+
ged_G1_G2=scores[f'{sub}_graph_edit_distance'],
|
| 464 |
+
**graph_edit_kwargs,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
# Get the number of heavy atoms difference between the predicted substructure and the ground truth substructure
|
| 468 |
+
if scores[f'{sub}_valid']:
|
| 469 |
+
pred_mol = Chem.MolFromSmiles(pred_sub)
|
| 470 |
+
label_mol = Chem.MolFromSmiles(label_sub)
|
| 471 |
+
if label_mol is None:
|
| 472 |
+
logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")
|
| 473 |
+
continue
|
| 474 |
+
scores[f'{sub}_heavy_atoms_difference'] -= pred_mol.GetNumHeavyAtoms()
|
| 475 |
+
scores[f'{sub}_heavy_atoms_difference_norm'] = scores[f'{sub}_heavy_atoms_difference'] / label_mol.GetNumHeavyAtoms()
|
| 476 |
+
|
| 477 |
+
# Get Tanimoto similarity b/w the predicted substructure and the ground truth
|
| 478 |
+
if scores[f'{sub}_valid'] and compute_rdkit_metrics:
|
| 479 |
+
pred_mol = Chem.MolFromSmiles(pred_sub)
|
| 480 |
+
label_mol = Chem.MolFromSmiles(label_sub)
|
| 481 |
+
if label_mol is None:
|
| 482 |
+
logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")
|
| 483 |
+
continue
|
| 484 |
+
pred_fp = fpgen.GetFingerprint(pred_mol)
|
| 485 |
+
label_fp = fpgen.GetFingerprint(label_mol)
|
| 486 |
+
scores[f'{sub}_tanimoto_similarity'] = DataStructs.TanimotoSimilarity(pred_fp, label_fp)
|
| 487 |
+
|
| 488 |
+
# Compute Rouge scores
|
| 489 |
+
if rouge is not None:
|
| 490 |
+
rouge_output = rouge.compute(predictions=[pred_sub], references=[label_sub])
|
| 491 |
+
scores.update({f'{sub}_{k}': v for k, v in rouge_output.items()})
|
| 492 |
+
|
| 493 |
+
scores['all_ligands_equal'] = all([scores[f'{sub}_equal'] for sub in ['e3', 'poi', 'linker']])
|
| 494 |
+
|
| 495 |
+
return scores
|
protac_splitter/fixing_functions.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from rdkit import Chem
|
| 5 |
+
|
| 6 |
+
from protac_splitter.chemoinformatics import (
|
| 7 |
+
canonize,
|
| 8 |
+
dummy2query,
|
| 9 |
+
remove_attach_atom,
|
| 10 |
+
remove_dummy_atoms,
|
| 11 |
+
)
|
| 12 |
+
from protac_splitter.evaluation import (
|
| 13 |
+
split_prediction,
|
| 14 |
+
check_reassembly,
|
| 15 |
+
)
|
| 16 |
+
from protac_splitter.data.curation.substructure_extraction import get_attachment_bonds
|
| 17 |
+
|
| 18 |
+
def fix_tetrahedral_centers_ligand(
|
| 19 |
+
protac_mol: Chem.Mol,
|
| 20 |
+
ligand_smiles: str,
|
| 21 |
+
attachment_id: int = 1,
|
| 22 |
+
) -> Optional[str]:
|
| 23 |
+
""" Fixes the tetrahedral centers of a ligand in a PROTAC molecule.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
protac_mol (Chem.Mol): The RDKit molecule object of the PROTAC.
|
| 27 |
+
ligand_smiles (str): The SMILES of the ligand to fix.
|
| 28 |
+
attachment_id (int): The attachment point id of the ligand. Default is 1.
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
A string containing the fixed ligand SMILES, or None if the fixing process failed.
|
| 32 |
+
"""
|
| 33 |
+
ligand_mol = Chem.MolFromSmiles(ligand_smiles)
|
| 34 |
+
if ligand_mol is None:
|
| 35 |
+
logging.error(f"Invalid ligand SMILES: {ligand_smiles}")
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
ligand_mol = remove_dummy_atoms(ligand_mol)
|
| 39 |
+
ligand_match = protac_mol.GetSubstructMatch(ligand_mol, useChirality=False) # useChirality=True
|
| 40 |
+
|
| 41 |
+
# Get bonds to break to separate the ligand
|
| 42 |
+
bonds_to_break = get_attachment_bonds(protac_mol, ligand_match)
|
| 43 |
+
|
| 44 |
+
# Return if no bonds are found
|
| 45 |
+
if len(bonds_to_break) != 1:
|
| 46 |
+
logging.error('ERROR: Multiple attachment bonds')
|
| 47 |
+
return None
|
| 48 |
+
|
| 49 |
+
# Break the bonds to isolate the ligand
|
| 50 |
+
frag_ligand_mol = Chem.FragmentOnBonds(protac_mol, bonds_to_break, addDummies=True, dummyLabels=[(attachment_id, attachment_id)])
|
| 51 |
+
|
| 52 |
+
# Get the fragments resulting from bond breaking
|
| 53 |
+
try:
|
| 54 |
+
frags = Chem.GetMolFrags(frag_ligand_mol, asMols=True, sanitizeFrags=True)
|
| 55 |
+
except Exception as e:
|
| 56 |
+
logging.error(e)
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
# Identify the ligand fragment
|
| 60 |
+
ligand_fragment = None
|
| 61 |
+
for frag in frags:
|
| 62 |
+
if frag.HasSubstructMatch(ligand_mol):
|
| 63 |
+
ligand_fragment = frag
|
| 64 |
+
break
|
| 65 |
+
if ligand_fragment is None:
|
| 66 |
+
logging.error('ERROR: POI fragment not found')
|
| 67 |
+
|
| 68 |
+
ligand_fixed = Chem.MolToSmiles(ligand_fragment)
|
| 69 |
+
ligand_fixed = canonize(ligand_fixed.replace(f'[{attachment_id}*]', f'[*:{attachment_id}]'))
|
| 70 |
+
return ligand_fixed
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def fix_prediction(
|
| 74 |
+
protac_smiles: str,
|
| 75 |
+
pred_smiles: str,
|
| 76 |
+
poi_attachment_id: int = 1,
|
| 77 |
+
e3_attachment_id: int = 2,
|
| 78 |
+
remove_stereochemistry: bool = False,
|
| 79 |
+
verbose: int = 0,
|
| 80 |
+
) -> Optional[str]:
|
| 81 |
+
""" Fixes a prediction by replacing the substructure that does not match the PROTAC with the rest of the PROTAC.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
protac_smiles (str): The SMILES of the PROTAC.
|
| 85 |
+
pred_smiles (str): The SMILES of the prediction.
|
| 86 |
+
poi_attachment_id (int): The attachment point id of the POI. Default is 1.
|
| 87 |
+
e3_attachment_id (int): The attachment point id of the E3 ligase. Default is 2.
|
| 88 |
+
verbose (int): The verbosity level. Default is 0.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
A string containing the fixed predictions, or None if the fixing process failed.
|
| 92 |
+
"""
|
| 93 |
+
protac_mol = Chem.MolFromSmiles(protac_smiles)
|
| 94 |
+
if protac_mol is None:
|
| 95 |
+
logging.warning(f"Invalid PROTAC SMILES: {protac_smiles}")
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
substructs = split_prediction(pred_smiles)
|
| 99 |
+
|
| 100 |
+
# If there are at least two None values, there's nothing we can do to fix it
|
| 101 |
+
if sum(v is None for v in substructs.values()) >= 2:
|
| 102 |
+
logging.warning(f'Unable to continue, more than two substructures are not valid for given input: "{pred_smiles}"')
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
# Get molecules of PROTAC and substructures
|
| 106 |
+
substructs = {k: {'smiles': v, 'mol': Chem.MolFromSmiles(v) if v is not None else v} for k, v in substructs.items()}
|
| 107 |
+
|
| 108 |
+
# Check if renaming the attachment points might already fix the prediction
|
| 109 |
+
for sub in ['poi', 'e3', 'both']:
|
| 110 |
+
if sub == 'e3':
|
| 111 |
+
if substructs['e3']['smiles'] is None:
|
| 112 |
+
continue
|
| 113 |
+
e3_attempt = substructs['e3']['smiles'].replace(f'[*:{poi_attachment_id}]', f'[*:{e3_attachment_id}]')
|
| 114 |
+
poi_attempt = substructs['poi']['smiles']
|
| 115 |
+
if sub == 'poi':
|
| 116 |
+
if substructs['poi']['smiles'] is None:
|
| 117 |
+
continue
|
| 118 |
+
e3_attempt = substructs['e3']['smiles']
|
| 119 |
+
poi_attempt = substructs['poi']['smiles'].replace(f'[*:{e3_attachment_id}]', f'[*:{poi_attachment_id}]')
|
| 120 |
+
else:
|
| 121 |
+
if substructs['e3']['smiles'] is None or substructs['poi']['smiles'] is None:
|
| 122 |
+
continue
|
| 123 |
+
e3_attempt = substructs['e3']['smiles'].replace(f'[*:{e3_attachment_id}]', f'[*:{poi_attachment_id}]')
|
| 124 |
+
poi_attempt = substructs['poi']['smiles'].replace(f'[*:{poi_attachment_id}]', f'[*:{e3_attachment_id}]')
|
| 125 |
+
|
| 126 |
+
protac_attempt = f"{e3_attempt}.{substructs['linker']['smiles']}.{poi_attempt}"
|
| 127 |
+
if check_reassembly(protac_smiles, protac_attempt):
|
| 128 |
+
logging.info(f'Input works when renaming attachment points in {sub.title()} substruct. SMILES: "{protac_attempt}"')
|
| 129 |
+
return protac_attempt
|
| 130 |
+
|
| 131 |
+
# Check if swapping the POI and E3 attachments in the linker might already fix the prediction
|
| 132 |
+
if substructs['linker']['smiles'] is None:
|
| 133 |
+
continue
|
| 134 |
+
linker_attempt = substructs['linker']['smiles']
|
| 135 |
+
linker_attempt = linker_attempt.replace(f'[*:{poi_attachment_id}]', f'[*:DUMMY]')
|
| 136 |
+
linker_attempt = linker_attempt.replace(f'[*:{e3_attachment_id}]', f'[*:{poi_attachment_id}]')
|
| 137 |
+
linker_attempt = linker_attempt.replace(f'[*:DUMMY]', f'[*:{e3_attachment_id}]')
|
| 138 |
+
|
| 139 |
+
# Try with the original POI and E3 substructures
|
| 140 |
+
protac_attempt = f"{substructs['e3']['smiles']}.{linker_attempt}.{substructs['poi']['smiles']}"
|
| 141 |
+
if check_reassembly(protac_smiles, protac_attempt):
|
| 142 |
+
logging.info(f'Input works when swapping POI and E3 attachment points in the linker. Fixed SMILES: "{protac_attempt}"')
|
| 143 |
+
return protac_attempt
|
| 144 |
+
|
| 145 |
+
# Try with the swapped POI and E3 substructures
|
| 146 |
+
protac_attempt = f"{e3_attempt}.{linker_attempt}.{poi_attempt}"
|
| 147 |
+
if check_reassembly(protac_smiles, protac_attempt):
|
| 148 |
+
logging.info(f'Input works when swapping POI and E3 attachment points in the linker and in {sub.title()} substruct. Fixed SMILES: "{protac_attempt}"')
|
| 149 |
+
return protac_attempt
|
| 150 |
+
|
| 151 |
+
# Check if removing stereochemistry results in a valid prediction
|
| 152 |
+
if remove_stereochemistry:
|
| 153 |
+
Chem.RemoveStereochemistry(protac_mol)
|
| 154 |
+
protac_smiles = Chem.MolToSmiles(protac_mol, canonical=True)
|
| 155 |
+
for k, v in substructs.items():
|
| 156 |
+
if v['mol'] is not None:
|
| 157 |
+
Chem.RemoveStereochemistry(v['mol'])
|
| 158 |
+
substructs[k]['smiles'] = Chem.MolToSmiles(v['mol'], canonical=True)
|
| 159 |
+
|
| 160 |
+
if all(v['mol'] is not None for v in substructs.values()):
|
| 161 |
+
if check_reassembly(
|
| 162 |
+
protac_smiles,
|
| 163 |
+
'.'.join([v['smiles'] for v in substructs.values()]),
|
| 164 |
+
):
|
| 165 |
+
logging.info(f'Input works when removing stereochemistry. SMILES: "{pred_smiles}"')
|
| 166 |
+
return f"{substructs['e3']['smiles']}.{substructs['linker']['smiles']}.{substructs['poi']['smiles']}"
|
| 167 |
+
|
| 168 |
+
# Check if any of the substructures is NOT a substructure of the PROTAC, if
|
| 169 |
+
# so, we mark it as the wrong substructure to fix.
|
| 170 |
+
num_matches = 0
|
| 171 |
+
wrong_substruct = None
|
| 172 |
+
for sub in ['poi', 'linker', 'e3']:
|
| 173 |
+
if substructs[sub]['mol'] is None:
|
| 174 |
+
substructs[sub]['match'] = False
|
| 175 |
+
wrong_substruct = sub
|
| 176 |
+
elif protac_mol.HasSubstructMatch(dummy2query(substructs[sub]['mol'])):
|
| 177 |
+
substructs[sub]['match'] = True
|
| 178 |
+
num_matches += 1
|
| 179 |
+
else:
|
| 180 |
+
substructs[sub]['match'] = False
|
| 181 |
+
wrong_substruct = sub
|
| 182 |
+
|
| 183 |
+
if num_matches < 2:
|
| 184 |
+
logging.warning(f'Prediction does not contain at least two matching substructures of the PROTAC. Num matches: {num_matches}. Prediction SMILES: "{pred_smiles}"')
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
# If the wrong substructure is still matching in the PROTAC, we need to a
|
| 188 |
+
# more complex approach to fix the prediction (see below).
|
| 189 |
+
def remove_substructure(mol, substructure, attachment_id, replaceDummies=False):
|
| 190 |
+
if mol is None or substructure is None:
|
| 191 |
+
return None
|
| 192 |
+
smaller_mol = Chem.ReplaceCore(
|
| 193 |
+
mol,
|
| 194 |
+
substructure,
|
| 195 |
+
labelByIndex=False,
|
| 196 |
+
replaceDummies=replaceDummies,
|
| 197 |
+
)
|
| 198 |
+
if smaller_mol is None:
|
| 199 |
+
logging.warning(f'Failed to remove substructure from prediction SMILES: "{pred_smiles}"')
|
| 200 |
+
return None
|
| 201 |
+
smaller_smiles = Chem.MolToSmiles(smaller_mol, canonical=True)
|
| 202 |
+
smaller_smiles = smaller_smiles.replace('[1*]', f'[*:{attachment_id}]')
|
| 203 |
+
smaller_smiles = smaller_smiles.replace('[2*]', f'[*:{attachment_id}]')
|
| 204 |
+
smaller_mol = canonize(Chem.MolFromSmiles(smaller_smiles))
|
| 205 |
+
return smaller_mol
|
| 206 |
+
|
| 207 |
+
# If we still have 3 matches: for each substructure, we progressively remove
|
| 208 |
+
# the other substructures, then we check if the resulting molecule is valid
|
| 209 |
+
# and has only one fragment.
|
| 210 |
+
if num_matches == 3:
|
| 211 |
+
wrong_substruct = None
|
| 212 |
+
for sub in ['poi', 'linker', 'e3']:
|
| 213 |
+
removed_mol = Chem.MolFromSmiles(protac_smiles)
|
| 214 |
+
|
| 215 |
+
# Put the current substructure at the end of the list [poi, e3, linker]
|
| 216 |
+
sub_names = ['poi', 'e3', 'linker']
|
| 217 |
+
sub_names.remove(sub)
|
| 218 |
+
sub_names.append(sub)
|
| 219 |
+
# The linker often matches in many parts of the PROTAC, so we remove
|
| 220 |
+
# it when checking the E3 and POI substructures.
|
| 221 |
+
if sub != 'linker':
|
| 222 |
+
sub_names.remove('linker')
|
| 223 |
+
|
| 224 |
+
for s in sub_names:
|
| 225 |
+
attachment_id = poi_attachment_id if s == 'poi' else e3_attachment_id
|
| 226 |
+
removed_mol = remove_substructure(
|
| 227 |
+
removed_mol,
|
| 228 |
+
dummy2query(substructs[s]['mol']),
|
| 229 |
+
attachment_id=attachment_id,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Check if resulting molecule is None, if so, it is the wrong one
|
| 233 |
+
if removed_mol is None:
|
| 234 |
+
substructs[sub]['match'] = False
|
| 235 |
+
wrong_substruct = sub
|
| 236 |
+
num_matches -= 1
|
| 237 |
+
break
|
| 238 |
+
|
| 239 |
+
# Count the number of fragments in the removed molecule
|
| 240 |
+
num_fragments = Chem.GetMolFrags(removed_mol, asMols=True, sanitizeFrags=False)
|
| 241 |
+
if len(num_fragments) > 1:
|
| 242 |
+
substructs[sub]['match'] = False
|
| 243 |
+
wrong_substruct = sub
|
| 244 |
+
num_matches -= 1
|
| 245 |
+
break
|
| 246 |
+
|
| 247 |
+
if num_matches == 3:
|
| 248 |
+
logging.warning(f'Prediction already contains all matching substructures of the PROTAC. Prediction SMILES: "{pred_smiles}"')
|
| 249 |
+
return None
|
| 250 |
+
|
| 251 |
+
# Get the order in which to remove the substructures and get the final one
|
| 252 |
+
# as the fixed molecule.
|
| 253 |
+
if wrong_substruct == 'linker':
|
| 254 |
+
poi_atoms = substructs['poi']['mol'].GetNumAtoms()
|
| 255 |
+
e3_atoms = substructs['e3']['mol'].GetNumAtoms()
|
| 256 |
+
order = ['poi', 'e3'] if poi_atoms > e3_atoms else ['e3', 'poi']
|
| 257 |
+
else:
|
| 258 |
+
if wrong_substruct == 'poi':
|
| 259 |
+
order = ['e3', 'linker']
|
| 260 |
+
else:
|
| 261 |
+
order = ['poi', 'linker']
|
| 262 |
+
|
| 263 |
+
logging.debug(f'Wrong substructure: {wrong_substruct.upper()}. Order: {order}')
|
| 264 |
+
|
| 265 |
+
fixed_mol = protac_mol
|
| 266 |
+
for sub in order:
|
| 267 |
+
logging.debug(f'Removing substructure {sub.upper()} from PROTAC.')
|
| 268 |
+
|
| 269 |
+
if 'linker' not in order:
|
| 270 |
+
fixed_attach_id = poi_attachment_id if sub == 'poi' else e3_attachment_id
|
| 271 |
+
else:
|
| 272 |
+
fixed_attach_id = poi_attachment_id if 'e3' in order else e3_attachment_id
|
| 273 |
+
|
| 274 |
+
if sub == 'linker':
|
| 275 |
+
attach_id = poi_attachment_id if wrong_substruct == 'poi' else e3_attachment_id
|
| 276 |
+
fixed_attach_id = poi_attachment_id if wrong_substruct == 'poi' else e3_attachment_id
|
| 277 |
+
query_mol = remove_attach_atom(substructs[sub]['mol'], attach_id)
|
| 278 |
+
replaceDummies = True
|
| 279 |
+
else:
|
| 280 |
+
query_mol = dummy2query(substructs[sub]['mol'])
|
| 281 |
+
replaceDummies = False
|
| 282 |
+
|
| 283 |
+
if verbose:
|
| 284 |
+
# display(Draw.MolToImage(fixed_mol, legend=f"Starting molecule", size=(800, 300)))
|
| 285 |
+
# display(Draw.MolToImage(query_mol, legend=f"Molecule {sub.upper()} to remove", size=(800, 300)))
|
| 286 |
+
pass
|
| 287 |
+
|
| 288 |
+
fixed_mol_tmp = remove_substructure(
|
| 289 |
+
fixed_mol,
|
| 290 |
+
query_mol,
|
| 291 |
+
attachment_id=fixed_attach_id,
|
| 292 |
+
replaceDummies=replaceDummies,
|
| 293 |
+
)
|
| 294 |
+
if fixed_mol_tmp is None:
|
| 295 |
+
logging.debug(f'Failed to replace substructure "{sub}" in prediction SMILES: "{pred_smiles}"')
|
| 296 |
+
continue
|
| 297 |
+
|
| 298 |
+
fixed_mol = fixed_mol_tmp
|
| 299 |
+
|
| 300 |
+
# If there are multiple fragments, keep the biggest one
|
| 301 |
+
fragments = Chem.GetMolFrags(fixed_mol, asMols=True)
|
| 302 |
+
if len(fragments) > 1:
|
| 303 |
+
logging.debug(f'Fixed molecule contains more than one fragment. Keeping the biggest one.')
|
| 304 |
+
max_frag = max(fragments, key=lambda x: x.GetNumAtoms())
|
| 305 |
+
fixed_mol = max_frag
|
| 306 |
+
|
| 307 |
+
# Get the SMILES of the fixed molecule
|
| 308 |
+
fixed_smiles = Chem.MolToSmiles(canonize(fixed_mol), canonical=True)
|
| 309 |
+
substructs[wrong_substruct]['smiles'] = fixed_smiles
|
| 310 |
+
|
| 311 |
+
if verbose:
|
| 312 |
+
# display(Draw.MolToImage(fixed_mol, legend=f"{wrong_substruct.upper()} fixed molecule: {fixed_smiles}", size=(800, 300)))
|
| 313 |
+
pass
|
| 314 |
+
|
| 315 |
+
# Concatenate the substructures check if the re-assembly is correct
|
| 316 |
+
fixed_pred_smiles = f"{substructs['e3']['smiles']}.{substructs['linker']['smiles']}.{substructs['poi']['smiles']}"
|
| 317 |
+
|
| 318 |
+
if not check_reassembly(
|
| 319 |
+
protac_smiles,
|
| 320 |
+
fixed_pred_smiles,
|
| 321 |
+
):
|
| 322 |
+
# logging.warning(f"Failed to fix prediction, re-assembly check failed. Generated fixed prediction (failing): {fixed_pred_smiles}")
|
| 323 |
+
# return None
|
| 324 |
+
|
| 325 |
+
# Check if by flipping the tetrahedral centers of the ligands we can
|
| 326 |
+
# still fix the prediction.
|
| 327 |
+
protac_mol = canonize(Chem.MolFromSmiles(protac_smiles))
|
| 328 |
+
chiral_centers = Chem.FindMolChiralCenters(
|
| 329 |
+
protac_mol,
|
| 330 |
+
includeUnassigned=True,
|
| 331 |
+
useLegacyImplementation=False,
|
| 332 |
+
)
|
| 333 |
+
if not chiral_centers:
|
| 334 |
+
logging.warning(f"Failed to fix prediction, re-assembly check failed. Generated fixed prediction (failing): {fixed_pred_smiles}")
|
| 335 |
+
return None
|
| 336 |
+
|
| 337 |
+
# Attempt to fix the tetrahedral centers of the ligands
|
| 338 |
+
e3_fixed = fix_tetrahedral_centers_ligand(protac_mol, substructs['e3']['smiles'], attachment_id=e3_attachment_id)
|
| 339 |
+
poi_fixed = fix_tetrahedral_centers_ligand(protac_mol, substructs['poi']['smiles'], attachment_id=poi_attachment_id)
|
| 340 |
+
if e3_fixed is None or poi_fixed is None:
|
| 341 |
+
logging.warning(f"Failed to fix prediction, re-assembly check failed. Generated fixed prediction (failing): {fixed_pred_smiles}")
|
| 342 |
+
return None
|
| 343 |
+
|
| 344 |
+
# Update the substructures with the fixed ligands and check re-assembly
|
| 345 |
+
substructs['e3']['smiles'] = e3_fixed
|
| 346 |
+
substructs['poi']['smiles'] = poi_fixed
|
| 347 |
+
fixed_pred_smiles = f"{substructs['e3']['smiles']}.{substructs['linker']['smiles']}.{substructs['poi']['smiles']}"
|
| 348 |
+
if not check_reassembly(
|
| 349 |
+
protac_smiles,
|
| 350 |
+
fixed_pred_smiles,
|
| 351 |
+
):
|
| 352 |
+
logging.warning(f"Failed to fix prediction, re-assembly check failed. Generated fixed prediction (failing): {fixed_pred_smiles}")
|
| 353 |
+
return None
|
| 354 |
+
|
| 355 |
+
return fixed_pred_smiles
|
protac_splitter/graphs/README.md
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Graph-Based PROTAC-Splitter
|
| 2 |
+
|
| 3 |
+
## Heuristic Betweenness Centrality
|
| 4 |
+
|
| 5 |
+
```python
|
| 6 |
+
idx = 3765
|
| 7 |
+
for i in range(10):
|
| 8 |
+
# sample = held_out_df.sample(n=1, random_state=42 + i).iloc[0]
|
| 9 |
+
sample = held_out_df.iloc[i]
|
| 10 |
+
# sample = held_out_df.iloc[i]
|
| 11 |
+
protac_smiles = sample['PROTAC SMILES']
|
| 12 |
+
wh_smiles = sample['POI Ligand SMILES with direction']
|
| 13 |
+
lk_smiles = sample['Linker SMILES with direction']
|
| 14 |
+
e3_smiles = sample['E3 Binder SMILES with direction']
|
| 15 |
+
|
| 16 |
+
protac = Chem.MolFromSmiles(protac_smiles)
|
| 17 |
+
wh = Chem.MolFromSmiles(wh_smiles)
|
| 18 |
+
lk = Chem.MolFromSmiles(lk_smiles)
|
| 19 |
+
e3 = Chem.MolFromSmiles(e3_smiles)
|
| 20 |
+
|
| 21 |
+
# display_mol(Chem.MolFromSmiles(protac_smiles), w=1500, h=600)
|
| 22 |
+
get_mapped_protac_img(protac_smiles, wh_smiles, lk_smiles, e3_smiles, w=1500, h=600, display_image=True, useSVG=False)
|
| 23 |
+
# wh_edge = get_atom_idx_at_attachment(protac, wh, lk)
|
| 24 |
+
# e3_edge = get_atom_idx_at_attachment(protac, e3, lk)
|
| 25 |
+
|
| 26 |
+
ret = nx_split(protac_smiles, representative_e3s_fp, morgan_fp_generator, use_capacity_weight=False, betweenness_threshold=0.4)
|
| 27 |
+
e3_smiles = ret['e3']
|
| 28 |
+
wh_smiles = ret['poi']
|
| 29 |
+
linker_smiles = ret['linker']
|
| 30 |
+
top_nodes = ret['top_nodes']
|
| 31 |
+
centrality = ret['centrality']
|
| 32 |
+
|
| 33 |
+
# display_mol(Chem.MolFromSmiles(e3_smiles), w=800, h=400, legend="E3")
|
| 34 |
+
# display_mol(Chem.MolFromSmiles(linker_smiles), w=800, h=400, legend="Linker")
|
| 35 |
+
# display_mol(Chem.MolFromSmiles(wh_smiles), w=800, h=400, legend="WH")
|
| 36 |
+
|
| 37 |
+
display_mol(Chem.MolFromSmiles('.'.join([wh_smiles, linker_smiles, e3_smiles])), w=800, h=400, legend="Graph-based split")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
display(Draw.MolToImage(
|
| 41 |
+
protac,
|
| 42 |
+
size=(1500, 400),
|
| 43 |
+
highlightColor=(1, 0, 1, 0.3), # Light purple
|
| 44 |
+
highlightAtoms=top_nodes, # Highlight the top nodes
|
| 45 |
+
legend=f"Graph nodes: {top_nodes} (Betweenness centrality: {centrality[top_nodes[0]]:.3f})",
|
| 46 |
+
))
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## Graph Edge Classifier Example
|
| 51 |
+
|
| 52 |
+
Example of how to use the GraphEdgeClassifier to train a model on a dataset of PROTACs and their ligands, and then predict edges in new PROTACs.
|
| 53 |
+
|
| 54 |
+
```python
|
| 55 |
+
label_cols = [c for c in train_set.columns if c.startswith("label_")]
|
| 56 |
+
train_set = sets["train"].dropna(subset=label_cols)
|
| 57 |
+
train_set = train_set[(train_set["label_e3_split"] + train_set["label_wh_split"]) <= 1]
|
| 58 |
+
X_train = train_set.drop(columns=label_cols)
|
| 59 |
+
|
| 60 |
+
graph_features = [c for c in X_train.columns if c.startswith("graph_")]
|
| 61 |
+
# graph_features = [
|
| 62 |
+
# "graph_betweenness",
|
| 63 |
+
# "graph_degree",
|
| 64 |
+
# "graph_degree_r2",
|
| 65 |
+
# "graph_degree_r3",
|
| 66 |
+
# ]
|
| 67 |
+
categorical_features = ["chem_bond_type", "chem_atom_u", "chem_atom_v"]
|
| 68 |
+
fingerprint_features = [c for c in X_train.columns if c.startswith("chem_mol_fp_")]
|
| 69 |
+
|
| 70 |
+
# Instantiate and train
|
| 71 |
+
clf = GraphEdgeClassifier(
|
| 72 |
+
graph_features=graph_features,
|
| 73 |
+
categorical_features=categorical_features,
|
| 74 |
+
fingerprint_features=fingerprint_features,
|
| 75 |
+
use_descriptors=False,
|
| 76 |
+
use_fingerprints=False,
|
| 77 |
+
binary=True,
|
| 78 |
+
)
|
| 79 |
+
y_train = train_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(train_set)
|
| 80 |
+
|
| 81 |
+
clf.fit(X_train, y_train)
|
| 82 |
+
clf.save("../models/edge_classifier_bin.joblib")
|
| 83 |
+
print(f"Model saved to ../models/edge_classifier_bin.joblib")
|
| 84 |
+
|
| 85 |
+
label_cols = [c for c in train_set.columns if c.startswith("label_")]
|
| 86 |
+
train_set = sets["train"].dropna(subset=label_cols)
|
| 87 |
+
train_set = train_set[(train_set["label_e3_split"] + train_set["label_wh_split"]) <= 1]
|
| 88 |
+
X_train = train_set.drop(columns=label_cols)
|
| 89 |
+
|
| 90 |
+
graph_features = [c for c in X_train.columns if c.startswith("graph_")]
|
| 91 |
+
# graph_features = [
|
| 92 |
+
# "graph_betweenness",
|
| 93 |
+
# "graph_degree",
|
| 94 |
+
# "graph_degree_r2",
|
| 95 |
+
# "graph_degree_r3",
|
| 96 |
+
# ]
|
| 97 |
+
categorical_features = ["chem_bond_type", "chem_atom_u", "chem_atom_v"]
|
| 98 |
+
fingerprint_features = [c for c in X_train.columns if c.startswith("chem_mol_fp_")]
|
| 99 |
+
|
| 100 |
+
# Instantiate and train
|
| 101 |
+
clf = GraphEdgeClassifier(
|
| 102 |
+
graph_features=graph_features,
|
| 103 |
+
categorical_features=categorical_features,
|
| 104 |
+
fingerprint_features=fingerprint_features,
|
| 105 |
+
use_descriptors=False,
|
| 106 |
+
use_fingerprints=False,
|
| 107 |
+
binary=False,
|
| 108 |
+
)
|
| 109 |
+
y_train = train_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(train_set)
|
| 110 |
+
|
| 111 |
+
clf.fit(X_train, y_train)
|
| 112 |
+
clf.save("../models/edge_classifier.joblib")
|
| 113 |
+
print(f"Model saved to ../models/edge_classifier.joblib")
|
| 114 |
+
```
|
protac_splitter/graphs/__init__.py
ADDED
|
File without changes
|
protac_splitter/graphs/e3_clustering.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple, Any, Dict
|
| 2 |
+
import functools
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from sklearn.cluster import AgglomerativeClustering, KMeans
|
| 8 |
+
from scipy.stats import skew
|
| 9 |
+
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
|
| 10 |
+
|
| 11 |
+
from rdkit import Chem, DataStructs
|
| 12 |
+
from rdkit.Chem import rdFingerprintGenerator
|
| 13 |
+
|
| 14 |
+
from protac_splitter.graphs.utils import get_fp, numpy_to_rdkit_fp
|
| 15 |
+
from protac_splitter.chemoinformatics import remove_dummy_atoms
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_umap_clusters_fp(fp_list: List[str], n_clusters: int = 7) -> np.ndarray:
|
| 19 |
+
"""
|
| 20 |
+
Cluster a list of SMILES strings using the umap clustering algorithm.
|
| 21 |
+
From Scaffold Splits Overestimate Virtual Screening Performance
|
| 22 |
+
https://arxiv.org/abs/2406.00873
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
fp_list (List[str]): List of SMILES strings.
|
| 26 |
+
n_clusters (int): The number of clusters to use for clustering.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
np.ndarray: Array of cluster labels corresponding to each SMILES string in the input list.
|
| 30 |
+
"""
|
| 31 |
+
ac = AgglomerativeClustering(n_clusters=n_clusters)
|
| 32 |
+
ac.fit_predict(np.stack(fp_list))
|
| 33 |
+
return ac.labels_
|
| 34 |
+
|
| 35 |
+
def get_kmeans_clusters_fp(fp_list: List[str], n_clusters: int = 10, return_centroids: bool = False) -> np.ndarray:
|
| 36 |
+
"""
|
| 37 |
+
Cluster a list of SMILES strings using the KMeans clustering algorithm.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
fp_list (List[str]): List of SMILES strings.
|
| 41 |
+
n_clusters (int): The number of clusters to use for clustering.
|
| 42 |
+
return_centroids (bool): If True, return the cluster centroids as well.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
np.ndarray: Array of cluster labels corresponding to each SMILES string in the input list.
|
| 46 |
+
"""
|
| 47 |
+
km = KMeans(n_clusters=n_clusters, n_init='auto', random_state=42, max_iter=1000)
|
| 48 |
+
if return_centroids:
|
| 49 |
+
km.fit(np.stack(fp_list))
|
| 50 |
+
return km.labels_, km.cluster_centers_
|
| 51 |
+
return km.fit_predict(np.stack(fp_list))
|
| 52 |
+
|
| 53 |
+
def evaluate_clusters(X: np.array, clusters: np.ndarray) -> Dict[str, float]:
|
| 54 |
+
""" Compute clustering metrics and assess cluster size distribution.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
X (np.array): The input data used for clustering.
|
| 58 |
+
clusters (np.ndarray): The cluster labels for each data point in X.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Dict[str, float]: A dictionary containing various clustering metrics:
|
| 62 |
+
- silhouette: Silhouette score of the clustering.
|
| 63 |
+
- davies_bouldin: Davies-Bouldin index of the clustering.
|
| 64 |
+
- calinski_harabasz: Calinski-Harabasz index of the clustering.
|
| 65 |
+
- avg_cluster_size: Average size of clusters.
|
| 66 |
+
- avg_cluster_data_ratio: Ratio of average cluster size to total data size.
|
| 67 |
+
- std_cluster_size: Standard deviation of cluster sizes.
|
| 68 |
+
- min_cluster_size: Minimum size of clusters.
|
| 69 |
+
- median_cluster_size: Median size of clusters.
|
| 70 |
+
- max_cluster_size: Maximum size of clusters.
|
| 71 |
+
- cluster_size_skewness: Skewness of cluster sizes indicating imbalance.
|
| 72 |
+
- num_clusters: Number of unique clusters found.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
unique_clusters = list(set(clusters))
|
| 76 |
+
|
| 77 |
+
if len(unique_clusters) < 2: # Avoid single-cluster issues
|
| 78 |
+
return {
|
| 79 |
+
"silhouette": -1,
|
| 80 |
+
"davies_bouldin": float("inf"),
|
| 81 |
+
"calinski_harabasz": -1,
|
| 82 |
+
"avg_cluster_size": len(X),
|
| 83 |
+
"avg_cluster_data_ratio": 1,
|
| 84 |
+
"std_cluster_size": 0,
|
| 85 |
+
"min_cluster_size": len(X),
|
| 86 |
+
"median_cluster_size": len(X),
|
| 87 |
+
"max_cluster_size": len(X),
|
| 88 |
+
"cluster_size_skewness": 0,
|
| 89 |
+
"num_clusters": 1,
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
# Compute standard clustering metrics
|
| 93 |
+
silhouette = silhouette_score(X, clusters)
|
| 94 |
+
davies_bouldin = davies_bouldin_score(X, clusters)
|
| 95 |
+
calinski_harabasz = calinski_harabasz_score(X, clusters)
|
| 96 |
+
|
| 97 |
+
# Compute cluster size statistics
|
| 98 |
+
cluster_sizes = [len(np.where(clusters == i)[0]) for i in np.unique(clusters)]
|
| 99 |
+
avg_cluster_size = np.mean(cluster_sizes)
|
| 100 |
+
avg_cluster_data_ratio = avg_cluster_size / len(X)
|
| 101 |
+
std_cluster_size = np.std(cluster_sizes)
|
| 102 |
+
median_cluster_size = np.median(cluster_sizes)
|
| 103 |
+
min_cluster_size = np.min(cluster_sizes)
|
| 104 |
+
max_cluster_size = np.max(cluster_sizes)
|
| 105 |
+
cluster_size_skewness = skew(cluster_sizes, nan_policy="omit") # Indicates imbalance in cluster sizes
|
| 106 |
+
|
| 107 |
+
return {
|
| 108 |
+
"silhouette": silhouette,
|
| 109 |
+
"davies_bouldin": davies_bouldin,
|
| 110 |
+
"calinski_harabasz": calinski_harabasz,
|
| 111 |
+
"avg_cluster_size": avg_cluster_size,
|
| 112 |
+
"avg_cluster_data_ratio": avg_cluster_data_ratio,
|
| 113 |
+
"std_cluster_size": std_cluster_size,
|
| 114 |
+
"min_cluster_size": min_cluster_size,
|
| 115 |
+
"median_cluster_size": median_cluster_size,
|
| 116 |
+
"max_cluster_size": max_cluster_size,
|
| 117 |
+
"cluster_size_skewness": cluster_size_skewness,
|
| 118 |
+
"num_clusters": len(unique_clusters),
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
def get_representative_e3s(
|
| 122 |
+
train_df: pd.DataFrame,
|
| 123 |
+
fp_generator: Optional[Any] = None,
|
| 124 |
+
n_clusters_candidates: List[int] = [10, 25, 50, 100, 150],
|
| 125 |
+
e3_column: str = 'E3 Binder SMILES with direction',
|
| 126 |
+
) -> Tuple[List[str], List[Any], int, pd.DataFrame]:
|
| 127 |
+
"""
|
| 128 |
+
Get representative E3 ligands from a DataFrame of training data by clustering their fingerprints.
|
| 129 |
+
This function computes Morgan fingerprints for unique E3 ligands, clusters them using KMeans and UMAP,
|
| 130 |
+
evaluates the clusters using silhouette, Davies-Bouldin, and Calinski-Harabasz scores, and identifies
|
| 131 |
+
the optimal number of clusters based on these metrics.
|
| 132 |
+
It returns the representative E3 ligands, their fingerprints, the best number of clusters, and a DataFrame
|
| 133 |
+
containing the clustering metrics.
|
| 134 |
+
|
| 135 |
+
Parameters:
|
| 136 |
+
train_df (pd.DataFrame): DataFrame containing training data with E3 ligands.
|
| 137 |
+
fp_generator (Optional[Any]): RDKit fingerprint generator. If None, a default Morgan fingerprint generator with 1024 bits and radius 6 is used.
|
| 138 |
+
n_clusters_candidates (List[int]): List of candidate numbers of clusters to evaluate.
|
| 139 |
+
e3_column (str): The column name in the DataFrame that contains the E3 ligand SMILES strings.
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
Tuple[List[str], List[Any], int, pd.DataFrame]: A tuple containing:
|
| 143 |
+
- List of representative E3 ligand SMILES strings.
|
| 144 |
+
- List of RDKit fingerprints corresponding to the representative E3 ligands.
|
| 145 |
+
- The best number of clusters determined from the clustering metrics.
|
| 146 |
+
- DataFrame containing clustering metrics for each candidate number of clusters.
|
| 147 |
+
"""
|
| 148 |
+
if e3_column not in train_df.columns:
|
| 149 |
+
raise ValueError(f"Column '{e3_column}' not found in the DataFrame.")
|
| 150 |
+
|
| 151 |
+
if fp_generator is None:
|
| 152 |
+
fp_generator = rdFingerprintGenerator.GetMorganGenerator(
|
| 153 |
+
radius=16,
|
| 154 |
+
fpSize=1024,
|
| 155 |
+
useBondTypes=True,
|
| 156 |
+
includeChirality=True,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
fp_dict = {}
|
| 160 |
+
for smi in tqdm(train_df[e3_column].unique()):
|
| 161 |
+
fp = get_fp(remove_dummy_atoms(smi), fp_generator)
|
| 162 |
+
if fp is not None:
|
| 163 |
+
fp_dict[smi] = fp
|
| 164 |
+
|
| 165 |
+
fp_list = list(fp_dict.values())
|
| 166 |
+
fp2smiles = {fp.tobytes(): smi for smi, fp in fp_dict.items() if fp is not None}
|
| 167 |
+
|
| 168 |
+
centroids_dict = {}
|
| 169 |
+
clusters_dict = {}
|
| 170 |
+
metrics_df = []
|
| 171 |
+
for n_clusters in tqdm(n_clusters_candidates, desc="Clustering and evaluating"):
|
| 172 |
+
clusters, centroids = get_kmeans_clusters_fp(fp_list, n_clusters=n_clusters, return_centroids=True)
|
| 173 |
+
metrics = evaluate_clusters(fp_list, clusters)
|
| 174 |
+
clusters_dict[f'kmeans_n{n_clusters}'] = clusters.copy()
|
| 175 |
+
centroids_dict[n_clusters] = centroids.copy()
|
| 176 |
+
|
| 177 |
+
metrics['num_clusters'] = n_clusters
|
| 178 |
+
metrics['cluster_algorithm'] = 'kmeans'
|
| 179 |
+
metrics_df.append(metrics.copy())
|
| 180 |
+
|
| 181 |
+
clusters = get_umap_clusters_fp(fp_list, n_clusters=n_clusters)
|
| 182 |
+
metrics = evaluate_clusters(fp_list, clusters)
|
| 183 |
+
clusters_dict[f'umap_n{n_clusters}'] = clusters.copy()
|
| 184 |
+
|
| 185 |
+
metrics['num_clusters'] = n_clusters
|
| 186 |
+
metrics['cluster_algorithm'] = 'umap'
|
| 187 |
+
metrics_df.append(metrics.copy())
|
| 188 |
+
|
| 189 |
+
metrics_df = pd.DataFrame(metrics_df)
|
| 190 |
+
|
| 191 |
+
# Get the sweet spot for the number of clusters
|
| 192 |
+
# Flip davies_bouldin so that all metrics are to be maximized
|
| 193 |
+
metrics_df['-davies_bouldin'] = -metrics_df['davies_bouldin']
|
| 194 |
+
|
| 195 |
+
# Normalize all three metrics (by group if you want per algorithm)
|
| 196 |
+
metrics = ['silhouette', '-davies_bouldin', 'calinski_harabasz']
|
| 197 |
+
df_norm = metrics_df.copy()
|
| 198 |
+
df_norm[metrics] = df_norm.groupby('cluster_algorithm')[metrics].transform(
|
| 199 |
+
lambda x: (x - x.min()) / (x.max() - x.min())
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Measure divergence: standard deviation of normalized metrics per row
|
| 203 |
+
df_norm['metric_divergence'] = df_norm[metrics].std(axis=1)
|
| 204 |
+
|
| 205 |
+
# Pick the point with lowest divergence, possibly applying constraints (e.g. not too many clusters)
|
| 206 |
+
sweet_spots = df_norm.loc[df_norm.groupby('cluster_algorithm')['metric_divergence'].idxmin()]
|
| 207 |
+
|
| 208 |
+
best_n_clusters = sweet_spots[['num_clusters']]['num_clusters'].unique()[0]
|
| 209 |
+
|
| 210 |
+
# Get the centroids of the clusters
|
| 211 |
+
centroids = centroids_dict[best_n_clusters]
|
| 212 |
+
|
| 213 |
+
# Get the cluster labels for the centroids
|
| 214 |
+
clusters = np.array(clusters_dict[f'kmeans_n{n_clusters}'])
|
| 215 |
+
representative_e3s = []
|
| 216 |
+
representative_e3s_fp = []
|
| 217 |
+
for label, centroid in enumerate(centroids):
|
| 218 |
+
# Isolate the FP with the same label as the centroid
|
| 219 |
+
fp_cluster = np.array(fp_list)[clusters == label]
|
| 220 |
+
# Get the closest FP for the centroid, use euclidean distance
|
| 221 |
+
distances = np.linalg.norm(fp_cluster - centroid, axis=1)
|
| 222 |
+
closest_fp = np.argmin(distances)
|
| 223 |
+
# To get the SMILES from the FP, use the fp2smiles dictionary
|
| 224 |
+
closest_smiles = fp2smiles[fp_cluster[closest_fp].tobytes()]
|
| 225 |
+
# Append the closest SMILES to the representative_e3s list
|
| 226 |
+
representative_e3s.append(closest_smiles)
|
| 227 |
+
representative_e3s_fp.append(fp_cluster[closest_fp])
|
| 228 |
+
|
| 229 |
+
# Convert the representative E3s to RDKit fingerprints
|
| 230 |
+
representative_e3s_fp = [numpy_to_rdkit_fp(fp) for fp in representative_e3s_fp]
|
| 231 |
+
|
| 232 |
+
return representative_e3s, representative_e3s_fp, best_n_clusters, metrics_df
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
DEFAULT_REPRESENTATIVE_E3S = [
|
| 236 |
+
'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)CN[*:2])cc1',
|
| 237 |
+
'O=C1CCC(N2Cc3c(N=[*:2])cccc3C2=O)C(=O)N1',
|
| 238 |
+
'CC(=O)NC(C(=O)N1CC(O)CC1C(=O)[*:2])C(C)(C)C',
|
| 239 |
+
'CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](Oc2ccccc2[*:2])C[C@H]1C(=O)N[C@@H]1CCCc2ccccc21)C1CCCCC1',
|
| 240 |
+
'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(NC(=O)CCO[*:2])C(C)(C)C)cc1',
|
| 241 |
+
'O=C1CCC(N2Cc3ccc([*:2])cc3C2=O)C(=O)N1',
|
| 242 |
+
'COc1ccc(C2=N[C@@H](c3ccc(Cl)cc3)[C@@H](c3ccc(Cl)cc3)N2C(=O)N2CCN(CC(=O)[*:2])C(=O)C2)c(OC(C)C)c1',
|
| 243 |
+
'CC(NC(=O)C1CC(O)CN1C(=O)C(N[*:2])C(C)(C)C)c1ccc(C2CC2)cc1',
|
| 244 |
+
'CCOc1cc(C(C)(C)C)ccc1C1=NC(c2ccc(Cl)cc2)C(c2ccc(Cl)cc2)N1C(=O)N1CCN(CCCC[*:2])CC1',
|
| 245 |
+
'CNC(C)C(=O)NC(C(=O)N1CCCC1c1cncc(C(=O)c2cccc([*:2])c2)c1)C1CCCCC1',
|
| 246 |
+
'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1nc(C(=O)c2ccc([*:2])cc2)cs1)C1CCCCC1',
|
| 247 |
+
'O=C1CCC(N2C(=O)c3cccc(OC[*:2])c3C2=O)C(=O)N1',
|
| 248 |
+
'CCOc1cc(C(C)(C)C)ccc1C1=NC(c2ccc(Cl)cc2)C(c2ccc(Cl)cc2)N1C(=O)N1CCN([*:2])CC1',
|
| 249 |
+
'Cc1ncsc1-c1ccc(CNC(=O)[C@H]2C[C@H](O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
|
| 250 |
+
'Cc1ncsc1-c1ccc([C@H](C)NC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)cc1',
|
| 251 |
+
'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1cncc(C(=O)c2cccc([*:2])c2)c1)C1CCCCC1',
|
| 252 |
+
'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)c(OC2CCNCC2)c1',
|
| 253 |
+
'CNC(C)C(=O)NC(C(=O)N1CC(Oc2ccc([*:2])cc2)CC1C(=O)NC1CCCc2ccccc21)C1CCCCC1',
|
| 254 |
+
'C[C@H](NC(=O)[C@@H]1C[C@@H](O)CN1C(=O)[C@@H](N[*:2])C(C)(C)C)c1ccc(C(C)(C)C)cc1',
|
| 255 |
+
'CNC(C)C(=O)NC(C(=O)N1CCCC1c1nc(C(=O)c2ccc([*:2])cc2)cs1)C1CCCCC1',
|
| 256 |
+
'CC(=O)NC(C(=O)N1CC(O)CC1C(=O)NCc1ccc(-c2scnc2C)cc1[*:2])C(C)(C)C',
|
| 257 |
+
'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)C2(F)CC2)C(C)(C)C)c([*:2])c1',
|
| 258 |
+
'CCOc1cc(C(C)(C)C)ccc1C1=NC(C)(c2ccc(Cl)cc2)C(C)(c2ccc(Cl)cc2)N1C(=O)N1CCN(CC(=O)[*:2])CC1',
|
| 259 |
+
'COc1ccc(C(=O)[*:2])cc1N1CCC(=O)NC1=O',
|
| 260 |
+
'CN[C@@H](C)C(=O)N[C@H](C(=O)N[C@H]1C[C@H]2CC[C@@H]1N(CCc1ccc([*:2])cc1)C2)C1CCCCC1',
|
| 261 |
+
'CNC(C)C(=O)NC(C(=O)N1CC(N[*:2])CC1C(=O)NC1CCCc2ccccc21)C1CCCCC1',
|
| 262 |
+
'CN[C@@H](C)C(=O)N[C@@H](CCCCN[*:2])C(=O)N1CCC[C@H]1C(=O)Nc1snnc1-c1ccccc1',
|
| 263 |
+
'CNC(C)C(=O)NC(C(=O)NC1CC2CCC1N(CCc1cccc([*:2])c1)C2)C1CCCCC1',
|
| 264 |
+
'O=C1CCC(N2C(=O)c3ccc(N[*:2])cc3C2=O)C(=O)N1',
|
| 265 |
+
'CNC(C)C(=O)NC(C(=O)N1CC(NC(=O)CC[*:2])CC1C(=O)Nc1c(F)cccc1F)C(C)(C)C',
|
| 266 |
+
'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@H](N[*:2])C(C)(C)C)cc1',
|
| 267 |
+
'Cc1nc[nH]c1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
|
| 268 |
+
'Cc1ncsc1-c1ccc(C(C)NC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
|
| 269 |
+
'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)cc1',
|
| 270 |
+
'O=C1CCC(c2cccc([*:2])c2)C(=O)N1',
|
| 271 |
+
'CC(=O)N[C@H](C(=O)N1C[C@@H](O)C[C@@H]1C(=O)N[C@@H](CC(=O)N1CCC([*:2])CC1)c1ccccc1)C(C)C',
|
| 272 |
+
'O=C(CCl)[*:2]',
|
| 273 |
+
'CC[C@@H](NC(=O)[C@@H]1C[C@H](N[*:2])CN1C(=O)[C@@H](NC(=O)[C@H](C)NC)C(C)(C)C)c1ccccc1',
|
| 274 |
+
'CN[C@H](C)C(=O)N[C@@H]1CCO[C@@H]2CC(C)(C)[C@H](C(=O)N[C@@H]3CCCc4cc([*:2])ccc43)N2C1=O',
|
| 275 |
+
'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1nc(C(=O)c2ccc(F)cc2)cs1)C1CCN(C[*:2])CC1',
|
| 276 |
+
'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
|
| 277 |
+
'CNC(C)C(=O)NC(CCCCN[*:2])C(=O)N1CCCC1C(=O)Nc1snnc1-c1ccccc1',
|
| 278 |
+
'O=C1CCC(N2C(=O)c3cccc([*:2])c3C2=O)C(=O)O1',
|
| 279 |
+
'COc1ccc(C2=N[C@@H](c3ccc(Cl)cc3)[C@@H](c3ccc(Cl)cc3)N2C(=O)N2CCN(CC(=O)[*:2])C(=O)C2)cc1OC(C)C',
|
| 280 |
+
'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)c(OC2CCNCC2)c1',
|
| 281 |
+
'CNC(C)C(=O)NC(C(=O)N1CCCC1c1cncc(-n2ccc3c(C(=O)[*:2])cccc32)c1)C(C)C',
|
| 282 |
+
'CCN1CCN(Cc2ccc(NC(=O)c3cccc(-c4ccc5nc(N[*:2])sc5n4)c3)cc2C(F)(F)F)CC1',
|
| 283 |
+
'CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](NC(=O)CC[*:2])C[C@H]1C(=O)Nc1c(F)cccc1F)C(C)(C)C',
|
| 284 |
+
'CNC(C)C(=O)NC(C(=O)N1CCCC1C(=O)NC(C(=O)[*:2])C(c1ccccc1)c1ccccc1)C1CCCCC1',
|
| 285 |
+
'CC(=O)NCC(C(=O)N1CC(O)CC1C(=O)NC(CC(=O)N1CCC(N2CCC([*:2])CC2)CC1)c1ccccc1)C(C)C',
|
| 286 |
+
]
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@functools.lru_cache(maxsize=1, typed=False)
|
| 290 |
+
def get_representative_e3s_fp(
|
| 291 |
+
e3_list: Optional[List[str]] = None,
|
| 292 |
+
fp_generator: Optional[Any] = None,
|
| 293 |
+
verbose: int = 0,
|
| 294 |
+
) -> List[DataStructs.ExplicitBitVect]:
|
| 295 |
+
"""
|
| 296 |
+
Generate Morgan fingerprints for a list of E3 ligands. If no list is provided,
|
| 297 |
+
it uses a default list of representative E3 ligands.
|
| 298 |
+
|
| 299 |
+
Parameters:
|
| 300 |
+
e3_list (Optional[List[str]]): List of SMILES strings for E3 ligands. If None, uses a default list.
|
| 301 |
+
fp_generator (Optional[Any]): RDKit fingerprint generator. If None, a default Morgan fingerprint generator is used.
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
List[DataStructs.ExplicitBitVect]: List of RDKit Morgan fingerprints for the E3 ligands.
|
| 305 |
+
"""
|
| 306 |
+
representative_e3s_fp = []
|
| 307 |
+
if verbose > 0:
|
| 308 |
+
iterable = tqdm(e3_list or DEFAULT_REPRESENTATIVE_E3S, desc="Generating fingerprints for E3 ligands")
|
| 309 |
+
else:
|
| 310 |
+
iterable = e3_list or DEFAULT_REPRESENTATIVE_E3S
|
| 311 |
+
for smi in iterable:
|
| 312 |
+
# Get the Morgan fingerprint for the SMILES string
|
| 313 |
+
fp = get_fp(remove_dummy_atoms(smi), fp_generator, return_np=False)
|
| 314 |
+
if fp is not None:
|
| 315 |
+
representative_e3s_fp.append(fp)
|
| 316 |
+
else:
|
| 317 |
+
print(f"Warning: Invalid SMILES string '{smi}' encountered, skipping.")
|
| 318 |
+
if not representative_e3s_fp:
|
| 319 |
+
raise ValueError("No valid E3 ligands found in the provided list.")
|
| 320 |
+
return representative_e3s_fp
|
| 321 |
+
|
protac_splitter/graphs/edge_classifier.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional, List, Dict, Union, Any, Literal
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
from sklearn.base import BaseEstimator, ClassifierMixin
|
| 8 |
+
from sklearn.compose import ColumnTransformer
|
| 9 |
+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
| 10 |
+
from sklearn.decomposition import TruncatedSVD
|
| 11 |
+
from imblearn.over_sampling import SMOTE
|
| 12 |
+
from imblearn.pipeline import Pipeline as ImbPipeline
|
| 13 |
+
from sklearn.pipeline import Pipeline
|
| 14 |
+
from sklearn.metrics import classification_report
|
| 15 |
+
from sklearn.metrics import confusion_matrix
|
| 16 |
+
from xgboost import XGBClassifier
|
| 17 |
+
import optuna
|
| 18 |
+
from optuna.samplers import QMCSampler
|
| 19 |
+
from sklearn.metrics import accuracy_score, f1_score
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
import seaborn as sns
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
HAS_VISUALIZATION = True
|
| 25 |
+
except ImportError:
|
| 26 |
+
HAS_VISUALIZATION = False
|
| 27 |
+
|
| 28 |
+
from .edge_features import extract_edge_features, get_edge_features
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class GraphEdgeClassifier(BaseEstimator, ClassifierMixin):
|
| 32 |
+
"""
|
| 33 |
+
Edge-level graph classifier for PROTACs with integrated pipeline building.
|
| 34 |
+
"""
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
graph_features: List[str],
|
| 38 |
+
categorical_features: Optional[List[str]] = None,
|
| 39 |
+
descriptor_features: Optional[List[str]] = None,
|
| 40 |
+
fingerprint_features: Optional[List[str]] = None,
|
| 41 |
+
use_descriptors: bool = True,
|
| 42 |
+
use_fingerprints: bool = True,
|
| 43 |
+
scaler_graph: Literal["passthrough", "standard"] = "passthrough",
|
| 44 |
+
scaler_desc: Literal["passthrough", "standard"] = "passthrough",
|
| 45 |
+
use_svd_fp: bool = True,
|
| 46 |
+
n_svd_components: int = 100,
|
| 47 |
+
binary: bool = False,
|
| 48 |
+
smote_k_neighbors: Optional[int] = 5,
|
| 49 |
+
xgb_params: Optional[dict] = None,
|
| 50 |
+
n_bits: int = 512,
|
| 51 |
+
radius: int = 6,
|
| 52 |
+
descriptor_names: Optional[List[str]] = None
|
| 53 |
+
):
|
| 54 |
+
self.graph_features = graph_features
|
| 55 |
+
self.categorical_features = categorical_features
|
| 56 |
+
self.descriptor_features = descriptor_features
|
| 57 |
+
self.fingerprint_features = fingerprint_features
|
| 58 |
+
self.use_descriptors = use_descriptors
|
| 59 |
+
self.use_fingerprints = use_fingerprints
|
| 60 |
+
self.scaler_graph = scaler_graph
|
| 61 |
+
self.scaler_desc = scaler_desc
|
| 62 |
+
self.use_svd_fp = use_svd_fp
|
| 63 |
+
self.n_svd_components = n_svd_components
|
| 64 |
+
self.binary = binary
|
| 65 |
+
self.smote_k_neighbors = smote_k_neighbors
|
| 66 |
+
self.xgb_params = xgb_params or {}
|
| 67 |
+
self.n_bits = n_bits
|
| 68 |
+
self.radius = radius
|
| 69 |
+
self.descriptor_names = descriptor_names or [
|
| 70 |
+
"MolWt", "HeavyAtomCount", "NumHAcceptors", "NumHDonors",
|
| 71 |
+
"TPSA", "NumRotatableBonds", "RingCount", "MolLogP"
|
| 72 |
+
]
|
| 73 |
+
self.pipeline = self._build_pipeline()
|
| 74 |
+
|
| 75 |
+
def _build_pipeline(self):
|
| 76 |
+
transformers = []
|
| 77 |
+
if self.categorical_features:
|
| 78 |
+
transformers.append(("cat", OneHotEncoder(handle_unknown="ignore"), self.categorical_features))
|
| 79 |
+
if self.scaler_graph == "standard":
|
| 80 |
+
transformers.append(("num", StandardScaler(), self.graph_features))
|
| 81 |
+
else:
|
| 82 |
+
transformers.append(("num", "passthrough", self.graph_features))
|
| 83 |
+
|
| 84 |
+
if self.use_descriptors and self.descriptor_features:
|
| 85 |
+
desc_block = (
|
| 86 |
+
("desc", StandardScaler(), self.descriptor_features)
|
| 87 |
+
if self.scaler_desc == "standard"
|
| 88 |
+
else ("desc", "passthrough", self.descriptor_features)
|
| 89 |
+
)
|
| 90 |
+
transformers.append(desc_block)
|
| 91 |
+
|
| 92 |
+
if self.use_fingerprints and self.fingerprint_features:
|
| 93 |
+
if self.use_svd_fp:
|
| 94 |
+
fp_block = ("fp",
|
| 95 |
+
ImbPipeline([
|
| 96 |
+
("svd", TruncatedSVD(n_components=self.n_svd_components, random_state=42))
|
| 97 |
+
]),
|
| 98 |
+
self.fingerprint_features)
|
| 99 |
+
else:
|
| 100 |
+
fp_block = ("fp", "passthrough", self.fingerprint_features)
|
| 101 |
+
transformers.append(fp_block)
|
| 102 |
+
|
| 103 |
+
preprocessor = ColumnTransformer(transformers)
|
| 104 |
+
|
| 105 |
+
# Define the classifier
|
| 106 |
+
classifier = XGBClassifier(
|
| 107 |
+
random_state=42,
|
| 108 |
+
eval_metric="logloss" if self.binary else "mlogloss",
|
| 109 |
+
objective="binary:logistic" if self.binary else "multi:softprob",
|
| 110 |
+
**self.xgb_params
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if self.smote_k_neighbors is not None:
|
| 114 |
+
return ImbPipeline([
|
| 115 |
+
("preprocess", preprocessor),
|
| 116 |
+
("smote", SMOTE(random_state=42, k_neighbors=self.smote_k_neighbors)),
|
| 117 |
+
("clf", classifier)
|
| 118 |
+
])
|
| 119 |
+
else:
|
| 120 |
+
return Pipeline([
|
| 121 |
+
("preprocess", preprocessor),
|
| 122 |
+
("clf", classifier)
|
| 123 |
+
])
|
| 124 |
+
|
| 125 |
+
def fit(self, X: pd.DataFrame, y: pd.Series):
|
| 126 |
+
self.pipeline.fit(X, y)
|
| 127 |
+
return self
|
| 128 |
+
|
| 129 |
+
def predict(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
|
| 130 |
+
X_proc = self._ensure_features(X)
|
| 131 |
+
return self.pipeline.predict(X_proc)
|
| 132 |
+
|
| 133 |
+
def predict_proba(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
|
| 134 |
+
X_proc = self._ensure_features(X)
|
| 135 |
+
return self.pipeline.predict_proba(X_proc)
|
| 136 |
+
|
| 137 |
+
def save(self, path: Union[str, Path]):
|
| 138 |
+
joblib.dump(self, str(path))
|
| 139 |
+
|
| 140 |
+
@classmethod
|
| 141 |
+
def load(cls, path: Union[str, Path]) -> "GraphEdgeClassifier":
|
| 142 |
+
return joblib.load(str(path))
|
| 143 |
+
|
| 144 |
+
@staticmethod
|
| 145 |
+
def extract_graph_features(
|
| 146 |
+
protac_smiles: Union[str, List[str]],
|
| 147 |
+
wh_smiles: Optional[Union[str, List[str]]] = None,
|
| 148 |
+
lk_smiles: Optional[Union[str, List[str]]] = None,
|
| 149 |
+
e3_smiles: Optional[Union[str, List[str]]] = None,
|
| 150 |
+
n_bits: int = 512,
|
| 151 |
+
radius: int = 6,
|
| 152 |
+
descriptor_names: Optional[List[str]] = None,
|
| 153 |
+
verbose: int = 0
|
| 154 |
+
) -> pd.DataFrame:
|
| 155 |
+
if any(x is None for x in [wh_smiles, lk_smiles, e3_smiles]):
|
| 156 |
+
# Get features from PROTAC only, for inference
|
| 157 |
+
return extract_edge_features(
|
| 158 |
+
protac_smiles=protac_smiles,
|
| 159 |
+
n_bits=n_bits,
|
| 160 |
+
radius=radius,
|
| 161 |
+
descriptor_names=descriptor_names,
|
| 162 |
+
)
|
| 163 |
+
else:
|
| 164 |
+
# Get features and labels from all components, for training
|
| 165 |
+
return get_edge_features(
|
| 166 |
+
protac_smiles=protac_smiles,
|
| 167 |
+
wh_smiles=wh_smiles,
|
| 168 |
+
lk_smiles=lk_smiles,
|
| 169 |
+
e3_smiles=e3_smiles,
|
| 170 |
+
n_bits=n_bits,
|
| 171 |
+
radius=radius,
|
| 172 |
+
descriptor_names=descriptor_names,
|
| 173 |
+
verbose=verbose
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
@staticmethod
|
| 177 |
+
def build_multiclass_target(
|
| 178 |
+
df: pd.DataFrame,
|
| 179 |
+
poi_attachment_id: int = 1,
|
| 180 |
+
e3_attachment_id: int = 2,
|
| 181 |
+
) -> pd.Series:
|
| 182 |
+
"""
|
| 183 |
+
Returns multiclass target: 0 = no split, 1 = E3 split, 2 = WH split
|
| 184 |
+
"""
|
| 185 |
+
assert ((df["label_e3_split"] + df["label_wh_split"]) <= 1).all()
|
| 186 |
+
y = (
|
| 187 |
+
df["label_wh_split"] * poi_attachment_id +
|
| 188 |
+
df["label_e3_split"] * e3_attachment_id
|
| 189 |
+
)
|
| 190 |
+
return y.astype("int32")
|
| 191 |
+
|
| 192 |
+
def _ensure_features(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> pd.DataFrame:
|
| 193 |
+
""" Filter out features/columns that are are not used in the pipeline. """
|
| 194 |
+
required_columns = (
|
| 195 |
+
(self.graph_features or []) +
|
| 196 |
+
(self.categorical_features or []) +
|
| 197 |
+
(self.descriptor_features or []) +
|
| 198 |
+
(self.fingerprint_features or [])
|
| 199 |
+
)
|
| 200 |
+
# If input is a DataFrame with SMILES, assume already featurized
|
| 201 |
+
if isinstance(X, pd.DataFrame):
|
| 202 |
+
Xf = X
|
| 203 |
+
elif isinstance(X, list) and isinstance(X[0], dict):
|
| 204 |
+
Xf = pd.DataFrame(X)
|
| 205 |
+
else:
|
| 206 |
+
raise ValueError("Provide either a DataFrame or list of feature dicts. Use extract_graph_features for SMILES.")
|
| 207 |
+
missing = set(required_columns) - set(Xf.columns)
|
| 208 |
+
if missing:
|
| 209 |
+
raise ValueError(f"Input data missing required columns: {missing}")
|
| 210 |
+
return Xf[required_columns].copy()
|
| 211 |
+
|
| 212 |
+
def predict_proba_from_smiles(
|
| 213 |
+
self,
|
| 214 |
+
protac_smiles: Union[str, List[str]],
|
| 215 |
+
wh_smiles: Union[str, List[str]],
|
| 216 |
+
lk_smiles: Union[str, List[str]],
|
| 217 |
+
e3_smiles: Union[str, List[str]],
|
| 218 |
+
verbose: int = 0,
|
| 219 |
+
):
|
| 220 |
+
features = self.extract_graph_features(
|
| 221 |
+
protac_smiles, wh_smiles, lk_smiles, e3_smiles,
|
| 222 |
+
n_bits=self.n_bits,
|
| 223 |
+
radius=self.radius,
|
| 224 |
+
descriptor_names=self.descriptor_names,
|
| 225 |
+
verbose=verbose
|
| 226 |
+
)
|
| 227 |
+
Xf = self._ensure_features(features)
|
| 228 |
+
return self.pipeline.predict_proba(Xf)
|
| 229 |
+
|
| 230 |
+
def predict_from_smiles(
|
| 231 |
+
self,
|
| 232 |
+
protac_smiles: Union[str, List[str]],
|
| 233 |
+
wh_smiles: Union[str, List[str]],
|
| 234 |
+
lk_smiles: Union[str, List[str]],
|
| 235 |
+
e3_smiles: Union[str, List[str]],
|
| 236 |
+
top_n: int = 1,
|
| 237 |
+
return_array: bool = True,
|
| 238 |
+
verbose: int = 0,
|
| 239 |
+
) -> Union[pd.DataFrame, np.ndarray]:
|
| 240 |
+
"""
|
| 241 |
+
For binary classification:
|
| 242 |
+
For each SMILES, return the top_n edge chem_bond_idx indices among those predicted as class 1,
|
| 243 |
+
sorted by predicted probability. If not enough edges are class 1, pad with -1.
|
| 244 |
+
For multiclass:
|
| 245 |
+
For each SMILES, return the chem_bond_idx with highest probability for class 1 (E3 split)
|
| 246 |
+
and for class 2 (WH split). Shape: (num_smiles, 2).
|
| 247 |
+
If no edge is predicted as that class, value is -1.
|
| 248 |
+
"""
|
| 249 |
+
features = self.extract_graph_features(
|
| 250 |
+
protac_smiles, wh_smiles, lk_smiles, e3_smiles,
|
| 251 |
+
n_bits=self.n_bits,
|
| 252 |
+
radius=self.radius,
|
| 253 |
+
descriptor_names=self.descriptor_names,
|
| 254 |
+
verbose=verbose
|
| 255 |
+
)
|
| 256 |
+
Xf = self._ensure_features(features)
|
| 257 |
+
pred_proba = self.pipeline.predict_proba(Xf)
|
| 258 |
+
pred_label = self.pipeline.predict(Xf)
|
| 259 |
+
features = features.copy()
|
| 260 |
+
features["pred_label"] = pred_label
|
| 261 |
+
features["pred_proba"] = pred_proba[:, 1] if pred_proba.shape[1] > 1 else pred_proba[:, 0]
|
| 262 |
+
|
| 263 |
+
unique_smiles = pd.Series(features["chem_mol_smiles"]).drop_duplicates().tolist()
|
| 264 |
+
groupby = features.groupby("chem_mol_smiles")
|
| 265 |
+
|
| 266 |
+
results = []
|
| 267 |
+
|
| 268 |
+
if return_array:
|
| 269 |
+
if pred_proba.shape[1] == 2: # Binary case
|
| 270 |
+
for mol_smiles in unique_smiles:
|
| 271 |
+
group = groupby.get_group(mol_smiles)
|
| 272 |
+
# Only consider edges predicted as label 1
|
| 273 |
+
edges_class1 = group[group["pred_label"] == 1]
|
| 274 |
+
# If none, pad with -1
|
| 275 |
+
if len(edges_class1) == 0:
|
| 276 |
+
results.append(np.full(top_n, -1))
|
| 277 |
+
continue
|
| 278 |
+
# Sort by proba, take top_n
|
| 279 |
+
top_edges = edges_class1.nlargest(top_n, "pred_proba")
|
| 280 |
+
idxs = top_edges["chem_bond_idx"].to_numpy()
|
| 281 |
+
if len(idxs) < top_n:
|
| 282 |
+
idxs = np.pad(idxs, (0, top_n - len(idxs)), constant_values=-1)
|
| 283 |
+
results.append(idxs[:top_n])
|
| 284 |
+
return np.vstack(results)
|
| 285 |
+
else: # Multiclass case
|
| 286 |
+
for mol_smiles in unique_smiles:
|
| 287 |
+
group = groupby.get_group(mol_smiles)
|
| 288 |
+
# For class 1
|
| 289 |
+
class1_idx = -1
|
| 290 |
+
if (group["pred_label"] == 1).any():
|
| 291 |
+
# Take the edge with highest class-1 probability
|
| 292 |
+
mask = group["pred_label"] == 1
|
| 293 |
+
idx1 = group.loc[mask, "pred_proba"].idxmax()
|
| 294 |
+
class1_idx = group.loc[idx1, "chem_bond_idx"]
|
| 295 |
+
# For class 2
|
| 296 |
+
class2_idx = -1
|
| 297 |
+
if (group["pred_label"] == 2).any():
|
| 298 |
+
mask = group["pred_label"] == 2
|
| 299 |
+
idx2 = group.loc[mask, "pred_proba"].idxmax()
|
| 300 |
+
class2_idx = group.loc[idx2, "chem_bond_idx"]
|
| 301 |
+
results.append([class1_idx, class2_idx])
|
| 302 |
+
return np.array(results, dtype=int)
|
| 303 |
+
else:
|
| 304 |
+
return features
|
| 305 |
+
|
| 306 |
+
def get_classification_report(y_true, y_pred, labels):
|
| 307 |
+
report = classification_report(y_true, y_pred, target_names=labels, output_dict=True)
|
| 308 |
+
df_report = pd.DataFrame(report).transpose().round(2)
|
| 309 |
+
print(df_report)
|
| 310 |
+
return df_report
|
| 311 |
+
|
| 312 |
+
def plot_confusion_matrix(y_true, y_pred, labels):
|
| 313 |
+
cm = confusion_matrix(y_true, y_pred)
|
| 314 |
+
if HAS_VISUALIZATION:
|
| 315 |
+
plt.figure(figsize=(8, 6))
|
| 316 |
+
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
|
| 317 |
+
plt.xlabel("Predicted")
|
| 318 |
+
plt.ylabel("True")
|
| 319 |
+
plt.title("Confusion Matrix")
|
| 320 |
+
plt.show()
|
| 321 |
+
else:
|
| 322 |
+
print("Visualization libraries not available. Skipping confusion matrix plot.")
|
| 323 |
+
print("Confusion Matrix:")
|
| 324 |
+
print(cm)
|
| 325 |
+
|
| 326 |
+
def get_classification_report_and_plot(y_true, y_pred, labels):
|
| 327 |
+
report = get_classification_report(y_true, y_pred, labels)
|
| 328 |
+
plot_confusion_matrix(y_true, y_pred, labels)
|
| 329 |
+
return report
|
| 330 |
+
|
| 331 |
+
def train_edge_classifier(
|
| 332 |
+
train_df: pd.DataFrame,
|
| 333 |
+
val_df: Optional[pd.DataFrame] = None,
|
| 334 |
+
test_df: Optional[pd.DataFrame] = None,
|
| 335 |
+
model_filename: Optional[Union[str, Path]] = None,
|
| 336 |
+
edge_classifier_kwargs: Optional[Dict[str, Any]] = None,
|
| 337 |
+
cache_dir: Optional[Union[str, Path]] = None,
|
| 338 |
+
return_reports: bool = True,
|
| 339 |
+
plot_confusion_matrix: bool = False,
|
| 340 |
+
) -> GraphEdgeClassifier:
|
| 341 |
+
"""
|
| 342 |
+
Train an edge-level graph classifier for PROTACs.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
train_df (pd.DataFrame): Training data with columns:
|
| 346 |
+
- 'PROTAC SMILES'
|
| 347 |
+
- 'POI Ligand SMILES with direction'
|
| 348 |
+
- 'Linker SMILES with direction'
|
| 349 |
+
- 'E3 Binder SMILES with direction'
|
| 350 |
+
val_df (Optional[pd.DataFrame]): Validation data, same format as train_df.
|
| 351 |
+
test_df (Optional[pd.DataFrame]): Test data, same format as train_df.
|
| 352 |
+
model_filename (Optional[Union[str, Path]]): Path to save the trained model.
|
| 353 |
+
edge_classifier_kwargs (Optional[Dict[str, Any]]): Additional parameters for GraphEdgeClassifier.
|
| 354 |
+
return_reports (bool): Whether to return classification reports for validation and test sets.
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
GraphEdgeClassifier: Trained edge classifier instance.
|
| 358 |
+
"""
|
| 359 |
+
sets = {}
|
| 360 |
+
for set_name, df in [
|
| 361 |
+
("train", train_df),
|
| 362 |
+
("val", val_df),
|
| 363 |
+
("test", test_df),
|
| 364 |
+
]:
|
| 365 |
+
if cache_dir is not None:
|
| 366 |
+
cache_path = Path(cache_dir) / f"{set_name}.csv"
|
| 367 |
+
if cache_path.exists():
|
| 368 |
+
print(f"Loading cached features for {set_name} from {cache_path}")
|
| 369 |
+
sets[set_name] = pd.read_csv(cache_path)
|
| 370 |
+
continue
|
| 371 |
+
else:
|
| 372 |
+
print(f"Cache not found for {set_name}, extracting features...")
|
| 373 |
+
|
| 374 |
+
if df is None or df.empty:
|
| 375 |
+
continue
|
| 376 |
+
|
| 377 |
+
print(f"Set: {set_name}, size: {len(df):,}")
|
| 378 |
+
if 'PROTAC SMILES' not in df.columns or \
|
| 379 |
+
'POI Ligand SMILES with direction' not in df.columns or \
|
| 380 |
+
'Linker SMILES with direction' not in df.columns or \
|
| 381 |
+
'E3 Binder SMILES with direction' not in df.columns:
|
| 382 |
+
raise ValueError(f"DataFrame for {set_name} is missing required columns: 'PROTAC SMILES', 'POI Ligand SMILES with direction', 'Linker SMILES with direction', 'E3 Binder SMILES with direction'.")
|
| 383 |
+
|
| 384 |
+
sets[set_name] = GraphEdgeClassifier.extract_graph_features(
|
| 385 |
+
df['PROTAC SMILES'].tolist(),
|
| 386 |
+
df['POI Ligand SMILES with direction'].tolist(),
|
| 387 |
+
df['Linker SMILES with direction'].tolist(),
|
| 388 |
+
df['E3 Binder SMILES with direction'].tolist(),
|
| 389 |
+
verbose=1,
|
| 390 |
+
)
|
| 391 |
+
# Drop rows with label_e3_split + label_wh_split > 1
|
| 392 |
+
sets[set_name] = sets[set_name][(sets[set_name]["label_e3_split"] + sets[set_name]["label_wh_split"]) <= 1]
|
| 393 |
+
print(f"Set: {set_name}, size: {len(sets[set_name]):,}")
|
| 394 |
+
if cache_dir is not None:
|
| 395 |
+
cache_path = Path(cache_dir) / f"{set_name}.csv"
|
| 396 |
+
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
| 397 |
+
sets[set_name].to_csv(cache_path, index=False)
|
| 398 |
+
print(f"Saved {set_name} features to {cache_path}")
|
| 399 |
+
|
| 400 |
+
train_set = sets["train"]
|
| 401 |
+
label_cols = [c for c in train_set.columns if c.startswith("label_")]
|
| 402 |
+
train_set = train_set.dropna(subset=label_cols)
|
| 403 |
+
train_set = train_set[(train_set["label_e3_split"] + train_set["label_wh_split"]) <= 1]
|
| 404 |
+
X_train = train_set.drop(columns=label_cols)
|
| 405 |
+
|
| 406 |
+
# Instantiate and train
|
| 407 |
+
clf = GraphEdgeClassifier(**edge_classifier_kwargs or {
|
| 408 |
+
"graph_features": [c for c in X_train.columns if c.startswith("graph_")],
|
| 409 |
+
"categorical_features": ["chem_bond_type", "chem_atom_u", "chem_atom_v"],
|
| 410 |
+
"fingerprint_features": [c for c in X_train.columns if c.startswith("chem_mol_fp_")],
|
| 411 |
+
"use_descriptors": False,
|
| 412 |
+
"use_fingerprints": True,
|
| 413 |
+
"n_svd_components": 50,
|
| 414 |
+
"binary": True,
|
| 415 |
+
"smote_k_neighbors": 10,
|
| 416 |
+
"xgb_params": {
|
| 417 |
+
"max_depth": 6,
|
| 418 |
+
"learning_rate": 0.3,
|
| 419 |
+
"alpha": 0.1, # Default: 0
|
| 420 |
+
"lambda": 0.5, # Default: 1
|
| 421 |
+
"gamma": 0.1, # Default: 0
|
| 422 |
+
},
|
| 423 |
+
})
|
| 424 |
+
|
| 425 |
+
# Prepare target variable according to classification type
|
| 426 |
+
if clf.binary:
|
| 427 |
+
y_train = train_set["label_is_split"].astype("int32")
|
| 428 |
+
else:
|
| 429 |
+
y_train = GraphEdgeClassifier.build_multiclass_target(train_set)
|
| 430 |
+
|
| 431 |
+
print(f"Training set size: {len(X_train):,}, labels: {y_train.unique()}")
|
| 432 |
+
clf.fit(X_train, y_train)
|
| 433 |
+
print("Training complete.")
|
| 434 |
+
|
| 435 |
+
if model_filename is not None:
|
| 436 |
+
clf.save(model_filename)
|
| 437 |
+
print(f"Model saved to {model_filename}")
|
| 438 |
+
|
| 439 |
+
target_labels = ["No Split", "Split"] if clf.binary else ["No Split", "WH-Linker", "E3-Linker"]
|
| 440 |
+
|
| 441 |
+
report = None
|
| 442 |
+
if "val" in sets:
|
| 443 |
+
# Get validation data
|
| 444 |
+
val_set = sets["val"].dropna(subset=label_cols)
|
| 445 |
+
val_set = val_set[(val_set["label_e3_split"] + val_set["label_wh_split"]) <= 1]
|
| 446 |
+
X_val = val_set.drop(columns=label_cols)
|
| 447 |
+
y_val = val_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(val_set)
|
| 448 |
+
y_pred = clf.predict(X_val)
|
| 449 |
+
if plot_confusion_matrix:
|
| 450 |
+
report = get_classification_report_and_plot(y_val, y_pred, target_labels)
|
| 451 |
+
else:
|
| 452 |
+
report = get_classification_report(y_val, y_pred, target_labels)
|
| 453 |
+
print(f"Validation set classification report:\n{report.to_markdown(index=False)}")
|
| 454 |
+
|
| 455 |
+
if "test" in sets:
|
| 456 |
+
# Get test data
|
| 457 |
+
test_set = sets["test"].dropna(subset=label_cols)
|
| 458 |
+
test_set = test_set[(test_set["label_e3_split"] + test_set["label_wh_split"]) <= 1]
|
| 459 |
+
X_test = test_set.drop(columns=label_cols)
|
| 460 |
+
y_test = test_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(test_set)
|
| 461 |
+
y_pred = clf.predict(X_test)
|
| 462 |
+
if plot_confusion_matrix:
|
| 463 |
+
report = get_classification_report_and_plot(y_test, y_pred, target_labels)
|
| 464 |
+
else:
|
| 465 |
+
report = get_classification_report(y_test, y_pred, target_labels)
|
| 466 |
+
print(f"Test set classification report:\n{report.to_markdown(index=False)}")
|
| 467 |
+
|
| 468 |
+
if return_reports:
|
| 469 |
+
return clf, report
|
| 470 |
+
else:
|
| 471 |
+
return clf
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def objective(trial, train_df, val_df):
|
| 475 |
+
# HP space
|
| 476 |
+
max_depth = trial.suggest_int("max_depth", 3, 10)
|
| 477 |
+
learning_rate = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
|
| 478 |
+
alpha = trial.suggest_float("alpha", 0.0, 2.0)
|
| 479 |
+
reg_lambda = trial.suggest_float("lambda", 0.0, 2.0)
|
| 480 |
+
gamma = trial.suggest_float("gamma", 0.0, 1.0)
|
| 481 |
+
n_svd_components = trial.suggest_int("n_svd_components", 16, 128)
|
| 482 |
+
smote_k_neighbors = trial.suggest_int("smote_k_neighbors", 3, 15)
|
| 483 |
+
use_descriptors = trial.suggest_categorical("use_descriptors", [False, True])
|
| 484 |
+
use_fingerprints = trial.suggest_categorical("use_fingerprints", [True, False])
|
| 485 |
+
|
| 486 |
+
edge_classifier_kwargs = {
|
| 487 |
+
"graph_features": None, # Will be set in train_edge_classifier
|
| 488 |
+
"categorical_features": None,
|
| 489 |
+
"fingerprint_features": None,
|
| 490 |
+
"use_descriptors": use_descriptors,
|
| 491 |
+
"use_fingerprints": use_fingerprints,
|
| 492 |
+
"n_svd_components": n_svd_components,
|
| 493 |
+
"binary": True,
|
| 494 |
+
"smote_k_neighbors": smote_k_neighbors,
|
| 495 |
+
"xgb_params": {
|
| 496 |
+
"max_depth": max_depth,
|
| 497 |
+
"learning_rate": learning_rate,
|
| 498 |
+
"alpha": alpha,
|
| 499 |
+
"lambda": reg_lambda,
|
| 500 |
+
"gamma": gamma,
|
| 501 |
+
},
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
_, val_report = train_edge_classifier(
|
| 505 |
+
train_df=train_df,
|
| 506 |
+
val_df=val_df,
|
| 507 |
+
edge_classifier_kwargs=edge_classifier_kwargs,
|
| 508 |
+
return_reports=True,
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# Evaluate metrics on validation set
|
| 512 |
+
# Assume val_report has columns: ['Label', 'precision', 'recall', 'f1-score', 'support']
|
| 513 |
+
# and that the binary positive class is "Split" or "1"
|
| 514 |
+
try:
|
| 515 |
+
f1_1 = float(val_report[val_report["Label"].isin(["Split", 1, "1"])]["f1-score"])
|
| 516 |
+
except Exception:
|
| 517 |
+
f1_1 = 0.0
|
| 518 |
+
try:
|
| 519 |
+
acc = float(val_report[val_report["Label"] == "accuracy"]["f1-score"])
|
| 520 |
+
except Exception:
|
| 521 |
+
acc = 0.0
|
| 522 |
+
|
| 523 |
+
# Multi-objective: prioritize F1 for minority class, but keep accuracy
|
| 524 |
+
# Adjust weight depending on task (here equal)
|
| 525 |
+
score = 0.5 * acc + 0.5 * f1_1
|
| 526 |
+
return score
|
| 527 |
+
|
| 528 |
+
def run_optuna_search(
|
| 529 |
+
train_df: pd.DataFrame,
|
| 530 |
+
val_df: pd.DataFrame,
|
| 531 |
+
n_trials: int = 50,
|
| 532 |
+
study_name: str = "edge_classifier_hp_search",
|
| 533 |
+
study_dir: str = "./optuna_studies",
|
| 534 |
+
seed: int = 42,
|
| 535 |
+
) -> Any:
|
| 536 |
+
import os
|
| 537 |
+
os.makedirs(study_dir, exist_ok=True)
|
| 538 |
+
study_path = f"sqlite:///{os.path.join(study_dir, study_name)}.db"
|
| 539 |
+
|
| 540 |
+
study = optuna.create_study(
|
| 541 |
+
study_name=study_name,
|
| 542 |
+
direction="maximize",
|
| 543 |
+
sampler=QMCSampler(seed=seed, qmc_type="sobol"),
|
| 544 |
+
storage=study_path,
|
| 545 |
+
load_if_exists=True,
|
| 546 |
+
)
|
| 547 |
+
func = lambda trial: objective(trial, train_df, val_df)
|
| 548 |
+
study.optimize(func, n_trials=n_trials, show_progress_bar=True)
|
| 549 |
+
|
| 550 |
+
print("Best trial:")
|
| 551 |
+
print(study.best_trial)
|
| 552 |
+
|
| 553 |
+
# Train classifier with best HP and return it
|
| 554 |
+
best_params = study.best_trial.params
|
| 555 |
+
edge_classifier_kwargs = {
|
| 556 |
+
"graph_features": None,
|
| 557 |
+
"categorical_features": None,
|
| 558 |
+
"fingerprint_features": None,
|
| 559 |
+
"use_descriptors": best_params["use_descriptors"],
|
| 560 |
+
"use_fingerprints": best_params["use_fingerprints"],
|
| 561 |
+
"n_svd_components": best_params["n_svd_components"],
|
| 562 |
+
"binary": True,
|
| 563 |
+
"smote_k_neighbors": best_params["smote_k_neighbors"],
|
| 564 |
+
"xgb_params": {
|
| 565 |
+
"max_depth": best_params["max_depth"],
|
| 566 |
+
"learning_rate": best_params["learning_rate"],
|
| 567 |
+
"alpha": best_params["alpha"],
|
| 568 |
+
"lambda": best_params["lambda"],
|
| 569 |
+
"gamma": best_params["gamma"],
|
| 570 |
+
},
|
| 571 |
+
}
|
| 572 |
+
clf, _ = train_edge_classifier(
|
| 573 |
+
train_df=train_df,
|
| 574 |
+
val_df=val_df,
|
| 575 |
+
edge_classifier_kwargs=edge_classifier_kwargs,
|
| 576 |
+
return_reports=True,
|
| 577 |
+
)
|
| 578 |
+
study_file = os.path.join(study_dir, f"{study_name}_study.pkl")
|
| 579 |
+
import joblib
|
| 580 |
+
joblib.dump(study, study_file)
|
| 581 |
+
print(f"Optuna study saved to {study_file}")
|
| 582 |
+
return clf, study
|
protac_splitter/graphs/edge_features.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List
|
| 2 |
+
|
| 3 |
+
from rdkit import Chem
|
| 4 |
+
from rdkit.Chem import AllChem, Descriptors, Draw
|
| 5 |
+
import networkx as nx
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from protac_splitter.chemoinformatics import get_atom_idx_at_attachment
|
| 11 |
+
from protac_splitter.display_utils import safe_display, get_mapped_protac_img
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def bond_capacity(bond: Chem.Bond) -> int:
|
| 15 |
+
""" Calculate the capacity of a bond based on its type and properties.
|
| 16 |
+
Parameters:
|
| 17 |
+
bond (Chem.Bond): The bond object from RDKit.
|
| 18 |
+
Returns:
|
| 19 |
+
int: The capacity of the bond, where higher values indicate less preference for cutting.
|
| 20 |
+
"""
|
| 21 |
+
# High capacity for aromatic and ring bonds to avoid cutting them
|
| 22 |
+
if bond.GetIsAromatic() or bond.IsInRing():
|
| 23 |
+
return 1000 # very high capacity: avoid cutting aromatic bonds
|
| 24 |
+
elif bond.GetBondType() == Chem.BondType.SINGLE:
|
| 25 |
+
return 1 # low capacity: prefer to cut here
|
| 26 |
+
elif bond.GetBondType() == Chem.BondType.DOUBLE:
|
| 27 |
+
return 10 # medium penalty
|
| 28 |
+
elif bond.GetBondType() == Chem.BondType.TRIPLE:
|
| 29 |
+
return 20 # stronger penalty
|
| 30 |
+
else:
|
| 31 |
+
return 50 # fallback for unknown/rare types
|
| 32 |
+
|
| 33 |
+
def smiles_to_nx(
|
| 34 |
+
smiles: str,
|
| 35 |
+
use_capacity: bool = False,
|
| 36 |
+
) -> nx.Graph:
|
| 37 |
+
""" Convert a SMILES string to a NetworkX graph.
|
| 38 |
+
Parameters:
|
| 39 |
+
smiles (str): The SMILES string to convert.
|
| 40 |
+
use_capacity (bool): Whether to use bond capacity as edge weights.
|
| 41 |
+
Returns:
|
| 42 |
+
nx.Graph: The NetworkX graph representation of the molecule.
|
| 43 |
+
"""
|
| 44 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 45 |
+
if mol is None:
|
| 46 |
+
raise ValueError(f"Input SMILES could not be parsed: {smiles}")
|
| 47 |
+
# Canonicalize the SMILES
|
| 48 |
+
mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol, canonical=True))
|
| 49 |
+
if mol is None:
|
| 50 |
+
raise ValueError(f"Input SMILES could not be canonicalized: {smiles}")
|
| 51 |
+
# Convert SMILES to NetworkX graph
|
| 52 |
+
G = nx.Graph()
|
| 53 |
+
if use_capacity:
|
| 54 |
+
for bond in mol.GetBonds():
|
| 55 |
+
capacity = bond_capacity(bond)
|
| 56 |
+
G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), capacity=capacity)
|
| 57 |
+
else:
|
| 58 |
+
for bond in mol.GetBonds():
|
| 59 |
+
G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
|
| 60 |
+
return G
|
| 61 |
+
|
| 62 |
+
def extract_edge_features(
|
| 63 |
+
protac_smiles: str,
|
| 64 |
+
e3_split_pair: Tuple[int, int] = None,
|
| 65 |
+
wh_split_pair: Tuple[int, int] = None,
|
| 66 |
+
n_bits: int = 512,
|
| 67 |
+
radius: int = 6,
|
| 68 |
+
descriptor_names: List[str] = None,
|
| 69 |
+
fp_as_string: bool = False,
|
| 70 |
+
) -> pd.DataFrame:
|
| 71 |
+
"""Extract features from the edges of a PROTAC molecule represented as a SMILES string.
|
| 72 |
+
|
| 73 |
+
Parameters:
|
| 74 |
+
protac_smiles (str): SMILES representation of the PROTAC molecule.
|
| 75 |
+
e3_split_pair (Tuple[int, int]): Indices of the E3 split pair.
|
| 76 |
+
wh_split_pair (Tuple[int, int]): Indices of the warhead split pair.
|
| 77 |
+
n_bits (int): Number of bits for Morgan fingerprints.
|
| 78 |
+
radius (int): Radius for Morgan fingerprints.
|
| 79 |
+
descriptor_names (List[str]): List of RDKit descriptor names to compute.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
pd.DataFrame: DataFrame containing edge features.
|
| 83 |
+
"""
|
| 84 |
+
mol = Chem.MolFromSmiles(protac_smiles)
|
| 85 |
+
if mol is None:
|
| 86 |
+
raise ValueError(f"Input SMILES could not be parsed: {protac_smiles}")
|
| 87 |
+
# Canonicalize the SMILES
|
| 88 |
+
mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol, canonical=True))
|
| 89 |
+
if mol is None:
|
| 90 |
+
raise ValueError(f"Input SMILES could not be canonicalized: {protac_smiles}")
|
| 91 |
+
|
| 92 |
+
# Step 1: Convert SMILES to NetworkX
|
| 93 |
+
G = smiles_to_nx(protac_smiles, use_capacity=False)
|
| 94 |
+
|
| 95 |
+
num_nodes = G.number_of_nodes()
|
| 96 |
+
num_edges = G.number_of_edges()
|
| 97 |
+
|
| 98 |
+
# Step 2: Create line graph and compute betweenness + degree
|
| 99 |
+
LG = nx.line_graph(G)
|
| 100 |
+
line_betweenness = nx.betweenness_centrality(LG, endpoints=True)
|
| 101 |
+
betweenness = nx.betweenness_centrality(G, endpoints=True)
|
| 102 |
+
|
| 103 |
+
# Compute k-hop degrees (number of nodes within 2, 3 hops)
|
| 104 |
+
# TODO: Shall I get the degree of the node in the line graph or the original graph?
|
| 105 |
+
line_degree = dict(LG.degree())
|
| 106 |
+
line_degree_r2 = {}
|
| 107 |
+
line_degree_r3 = {}
|
| 108 |
+
for node in LG.nodes():
|
| 109 |
+
# Nodes within radius 2 and 3 (excluding the center node)
|
| 110 |
+
neighbors_r2 = nx.single_source_shortest_path_length(LG, node, cutoff=2)
|
| 111 |
+
neighbors_r3 = nx.single_source_shortest_path_length(LG, node, cutoff=3)
|
| 112 |
+
line_degree_r2[node] = len([n for n, d in neighbors_r2.items() if d == 2])
|
| 113 |
+
line_degree_r3[node] = len([n for n, d in neighbors_r3.items() if d == 3])
|
| 114 |
+
|
| 115 |
+
degree = dict(G.degree())
|
| 116 |
+
degree_r2 = {}
|
| 117 |
+
degree_r3 = {}
|
| 118 |
+
for node in G.nodes():
|
| 119 |
+
# Nodes within radius 2 and 3 (excluding the center node)
|
| 120 |
+
neighbors_r2 = nx.single_source_shortest_path_length(G, node, cutoff=2)
|
| 121 |
+
neighbors_r3 = nx.single_source_shortest_path_length(G, node, cutoff=3)
|
| 122 |
+
degree_r2[node] = len([n for n, d in neighbors_r2.items() if d == 2])
|
| 123 |
+
degree_r3[node] = len([n for n, d in neighbors_r3.items() if d == 3])
|
| 124 |
+
|
| 125 |
+
if e3_split_pair is not None and wh_split_pair is not None:
|
| 126 |
+
true_split_edges = {frozenset(e3_split_pair), frozenset(wh_split_pair)}
|
| 127 |
+
|
| 128 |
+
# Get molecular characteristics, i.e., Morgan fingerprints and descriptors
|
| 129 |
+
# Generate Morgan fingerprint
|
| 130 |
+
fp_bitvec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
|
| 131 |
+
fp = np.zeros((n_bits,), dtype=np.float32)
|
| 132 |
+
AllChem.DataStructs.ConvertToNumpyArray(fp_bitvec, fp)
|
| 133 |
+
if fp_as_string:
|
| 134 |
+
fp = {"chem_mol_fp": "".join([str(int(bit)) for bit in fp])}
|
| 135 |
+
else:
|
| 136 |
+
fp = {f"chem_mol_fp_{i}": bool(fp[i]) for i in range(n_bits)}
|
| 137 |
+
# Generate RDKit descriptors
|
| 138 |
+
descriptor_func_names = descriptor_names or [
|
| 139 |
+
"MolWt", "HeavyAtomCount", "NumHAcceptors", "NumHDonors",
|
| 140 |
+
"TPSA", "NumRotatableBonds", "RingCount", "MolLogP"
|
| 141 |
+
]
|
| 142 |
+
functions = [getattr(Descriptors, name) for name in descriptor_func_names]
|
| 143 |
+
descriptors = {f"chem_mol_desc_{name}": func(mol) for name, func in zip(descriptor_func_names, functions)}
|
| 144 |
+
|
| 145 |
+
# Step 3: Gather edge features
|
| 146 |
+
# NOTE: Only consider bridge nodes
|
| 147 |
+
edge_features = []
|
| 148 |
+
for (u, v) in nx.bridges(G):
|
| 149 |
+
bond = mol.GetBondBetweenAtoms(u, v)
|
| 150 |
+
|
| 151 |
+
# Avoid reporting the same edge twice (i.e., swap u and v if needed) and
|
| 152 |
+
# ensure to find the node pair in the line graph
|
| 153 |
+
node = (u, v) if (u, v) in LG else (v, u)
|
| 154 |
+
node_key = node if node in line_betweenness else (v, u)
|
| 155 |
+
|
| 156 |
+
features = {
|
| 157 |
+
"graph_num_nodes": num_nodes,
|
| 158 |
+
"graph_num_edges": num_edges,
|
| 159 |
+
"graph_betweenness": line_betweenness.get(node_key, 0.0),
|
| 160 |
+
"graph_degree": line_degree.get(node_key, 0),
|
| 161 |
+
"graph_degree_r2": line_degree_r2.get(node_key, 0),
|
| 162 |
+
"graph_degree_r3": line_degree_r3.get(node_key, 0),
|
| 163 |
+
"graph_node_u_degree": degree.get(u, 0),
|
| 164 |
+
"graph_node_u_degree_r2": degree_r2.get(u, 0),
|
| 165 |
+
"graph_node_u_degree_r3": degree_r3.get(u, 0),
|
| 166 |
+
"graph_node_v_degree": degree.get(v, 0),
|
| 167 |
+
"graph_node_v_degree_r2": degree_r2.get(v, 0),
|
| 168 |
+
"graph_node_v_degree_r3": degree_r3.get(v, 0),
|
| 169 |
+
"graph_node_u_betweenness": betweenness.get(u, 0.0),
|
| 170 |
+
"graph_node_v_betweenness": betweenness.get(v, 0.0),
|
| 171 |
+
"chem_bond_idx": bond.GetIdx(),
|
| 172 |
+
"chem_bond_type": str(bond.GetBondType()),
|
| 173 |
+
"chem_atom_u": mol.GetAtomWithIdx(u).GetSymbol(),
|
| 174 |
+
"chem_atom_v": mol.GetAtomWithIdx(v).GetSymbol(),
|
| 175 |
+
"chem_is_aromatic": bond.GetIsAromatic(),
|
| 176 |
+
"chem_is_in_ring": bond.IsInRing(),
|
| 177 |
+
"chem_mol_smiles": protac_smiles,
|
| 178 |
+
"chem_mol_n_bits": n_bits,
|
| 179 |
+
"chem_mol_radius": radius,
|
| 180 |
+
}
|
| 181 |
+
# Add RDKit descriptors and Morgan fingerprint
|
| 182 |
+
features.update(fp)
|
| 183 |
+
features.update(descriptors)
|
| 184 |
+
|
| 185 |
+
# Add E3 and warhead split labels
|
| 186 |
+
if e3_split_pair is not None and wh_split_pair is not None:
|
| 187 |
+
features.update({
|
| 188 |
+
"label_is_split": frozenset([u, v]) in true_split_edges,
|
| 189 |
+
"label_e3_split": frozenset([u, v]) == frozenset(e3_split_pair),
|
| 190 |
+
"label_wh_split": frozenset([u, v]) == frozenset(wh_split_pair),
|
| 191 |
+
})
|
| 192 |
+
|
| 193 |
+
# Append the features to the list of edge features
|
| 194 |
+
edge_features.append(features)
|
| 195 |
+
|
| 196 |
+
df = pd.DataFrame(edge_features)
|
| 197 |
+
|
| 198 |
+
# Identify columns with int64 dtype
|
| 199 |
+
int64_cols = df.select_dtypes(include=['int64']).columns
|
| 200 |
+
|
| 201 |
+
# Create a dictionary mapping these columns to int32
|
| 202 |
+
dtype_mapping = {col: np.int32 for col in int64_cols}
|
| 203 |
+
|
| 204 |
+
# Apply the type conversion
|
| 205 |
+
df = df.astype(dtype_mapping)
|
| 206 |
+
|
| 207 |
+
return df
|
| 208 |
+
|
| 209 |
+
def get_edge_features(
|
| 210 |
+
protac_smiles: str | List[str],
|
| 211 |
+
wh_smiles: str | List[str],
|
| 212 |
+
lk_smiles: str | List[str],
|
| 213 |
+
e3_smiles: str | List[str],
|
| 214 |
+
n_bits: int = 512,
|
| 215 |
+
radius: int = 6,
|
| 216 |
+
descriptor_names: List[str] = None,
|
| 217 |
+
fp_as_string: bool = False,
|
| 218 |
+
verbose: int = 0,
|
| 219 |
+
) -> pd.DataFrame:
|
| 220 |
+
"""Get edge features for a given PROTAC molecule and its components.
|
| 221 |
+
|
| 222 |
+
Parameters:
|
| 223 |
+
protac_smiles (str | List[str]): SMILES representation of the PROTAC molecule.
|
| 224 |
+
wh_smiles (str | List[str]): SMILES representation of the warhead.
|
| 225 |
+
lk_smiles (str | List[str]): SMILES representation of the linker.
|
| 226 |
+
e3_smiles (str | List[str]): SMILES representation of the E3 binder.
|
| 227 |
+
n_bits (int): Number of bits for Morgan fingerprints.
|
| 228 |
+
radius (int): Radius for Morgan fingerprints.
|
| 229 |
+
descriptor_names (List[str]): List of RDKit descriptor names to compute.
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
pd.DataFrame: DataFrame containing edge features.
|
| 233 |
+
"""
|
| 234 |
+
if isinstance(protac_smiles, str):
|
| 235 |
+
protac_smiles = [protac_smiles]
|
| 236 |
+
if isinstance(wh_smiles, str):
|
| 237 |
+
wh_smiles = [wh_smiles]
|
| 238 |
+
if isinstance(lk_smiles, str):
|
| 239 |
+
lk_smiles = [lk_smiles]
|
| 240 |
+
if isinstance(e3_smiles, str):
|
| 241 |
+
e3_smiles = [e3_smiles]
|
| 242 |
+
|
| 243 |
+
iterables = zip(protac_smiles, wh_smiles, lk_smiles, e3_smiles)
|
| 244 |
+
iterables = tqdm(iterables, desc="Extracting edge features", total=len(protac_smiles), disable=verbose == 0)
|
| 245 |
+
features_list = []
|
| 246 |
+
for protac_smi, wh_smi, lk_smi, e3_smi in iterables:
|
| 247 |
+
if verbose > 1:
|
| 248 |
+
get_mapped_protac_img(protac_smi, wh_smi, lk_smi, e3_smi, w=1500, h=600, display_image=True, useSVG=True)
|
| 249 |
+
|
| 250 |
+
# Convert SMILES to RDKit molecules
|
| 251 |
+
protac = Chem.MolFromSmiles(protac_smi)
|
| 252 |
+
wh = Chem.MolFromSmiles(wh_smi)
|
| 253 |
+
lk = Chem.MolFromSmiles(lk_smi)
|
| 254 |
+
e3 = Chem.MolFromSmiles(e3_smi)
|
| 255 |
+
if protac is None or wh is None or lk is None or e3 is None:
|
| 256 |
+
raise ValueError(f"Invalid SMILES string: {protac}, {wh}, {lk}, {e3}")
|
| 257 |
+
|
| 258 |
+
# Get the attachment points
|
| 259 |
+
wh_edge = get_atom_idx_at_attachment(protac, wh, lk)
|
| 260 |
+
e3_edge = get_atom_idx_at_attachment(protac, e3, lk)
|
| 261 |
+
|
| 262 |
+
# Extract features
|
| 263 |
+
features = extract_edge_features(
|
| 264 |
+
protac_smi,
|
| 265 |
+
e3_split_pair=e3_edge,
|
| 266 |
+
wh_split_pair=wh_edge,
|
| 267 |
+
n_bits=n_bits,
|
| 268 |
+
radius=radius,
|
| 269 |
+
descriptor_names=descriptor_names,
|
| 270 |
+
fp_as_string=fp_as_string,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
if verbose > 1:
|
| 274 |
+
# Randomly sample and display 5 edges
|
| 275 |
+
sample_edges = features.sample(n=5, random_state=42)
|
| 276 |
+
# Display the sampled edges
|
| 277 |
+
for _, row in sample_edges.iterrows():
|
| 278 |
+
bond = protac.GetBondWithIdx(row['chem_bond_idx'])
|
| 279 |
+
u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
|
| 280 |
+
safe_display(Draw.MolToImage(
|
| 281 |
+
protac,
|
| 282 |
+
size=(1500, 400),
|
| 283 |
+
highlightColor=(1, 0, 1, 0.3), # Light purple
|
| 284 |
+
highlightAtoms=[u, v], # Highlight the two atoms
|
| 285 |
+
legend=f"Graph nodes: {u}, {v} (Betweenness centrality: {row['graph_betweenness']:.3f})",
|
| 286 |
+
))
|
| 287 |
+
# print(row[[c for c in features.columns if c.startswith('graph_')] + ['chem_atom_u', 'chem_atom_v', 'chem_is_in_ring']])
|
| 288 |
+
print(row)
|
| 289 |
+
|
| 290 |
+
# Append the features to the list
|
| 291 |
+
features_list.append(features)
|
| 292 |
+
|
| 293 |
+
return pd.concat(features_list, ignore_index=True)
|
protac_splitter/graphs/splitting_algorithms.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Dict, Any, Optional, List, Union
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from joblib import Parallel, delayed
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import networkx as nx
|
| 8 |
+
from rdkit import Chem, DataStructs
|
| 9 |
+
from rdkit.Chem import rdFingerprintGenerator
|
| 10 |
+
|
| 11 |
+
from .edge_classifier import GraphEdgeClassifier
|
| 12 |
+
from .e3_clustering import get_representative_e3s_fp
|
| 13 |
+
from .utils import average_tanimoto_distance
|
| 14 |
+
from protac_splitter.data.curation.bond_adjustments import (
|
| 15 |
+
adjust_amide_bonds_in_substructs,
|
| 16 |
+
adjust_ester_bonds_in_substructs
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
def bond_capacity(bond: Chem.Bond) -> int:
|
| 20 |
+
if bond.GetIsAromatic() or bond.IsInRing():
|
| 21 |
+
return 1000 # very high capacity: avoid cutting aromatic bonds
|
| 22 |
+
elif bond.GetBondType() == Chem.BondType.SINGLE:
|
| 23 |
+
return 1 # low capacity: prefer to cut here
|
| 24 |
+
elif bond.GetBondType() == Chem.BondType.DOUBLE:
|
| 25 |
+
return 10 # medium penalty
|
| 26 |
+
elif bond.GetBondType() == Chem.BondType.TRIPLE:
|
| 27 |
+
return 20 # stronger penalty
|
| 28 |
+
else:
|
| 29 |
+
return 50 # fallback for unknown/rare types
|
| 30 |
+
|
| 31 |
+
def smiles_to_nx(smiles: str) -> nx.Graph:
|
| 32 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 33 |
+
G = nx.Graph()
|
| 34 |
+
for bond in mol.GetBonds():
|
| 35 |
+
capacity = bond_capacity(bond)
|
| 36 |
+
G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), capacity=capacity)
|
| 37 |
+
return G
|
| 38 |
+
|
| 39 |
+
def extract_attachment_point(smiles):
|
| 40 |
+
"""
|
| 41 |
+
Extracts the number X from the pattern [X*] in a SMILES string.
|
| 42 |
+
|
| 43 |
+
Parameters:
|
| 44 |
+
smiles (str): The SMILES string containing the attachment point.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
str or None: The extracted number as a string, or None if not found.
|
| 48 |
+
"""
|
| 49 |
+
match = re.search(r'\[(\d+)\*\]', smiles)
|
| 50 |
+
return match.group(1) if match else None
|
| 51 |
+
|
| 52 |
+
def split_protac_with_betweenness_centrality(
|
| 53 |
+
protac_smiles: str,
|
| 54 |
+
representative_e3s_fp: List[DataStructs.ExplicitBitVect] = None,
|
| 55 |
+
morgan_fp_generator: Optional[Any] = None,
|
| 56 |
+
use_capacity_weight: bool = False,
|
| 57 |
+
betweenness_threshold: float = 0.4,
|
| 58 |
+
) -> Dict[str, str]:
|
| 59 |
+
"""
|
| 60 |
+
Split the PROTAC molecule into two parts using the NetworkX library.
|
| 61 |
+
|
| 62 |
+
Parameters:
|
| 63 |
+
protac_smiles (str): The SMILES string of the PROTAC molecule.
|
| 64 |
+
representative_e3s_fp (list): List of representative E3 ligands fingerprints.
|
| 65 |
+
morgan_fp_generator: RDKit Morgan fingerprint generator (should be the same as the one that generated the E3 fingerprints).
|
| 66 |
+
use_capacity_weight (bool): Whether to use bond capacity as weight for the graph.
|
| 67 |
+
betweenness_threshold (float): Threshold for betweenness centrality to consider a node as a candidate for splitting.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
dict: A dictionary containing the E3 ligand, warhead, linker, top nodes, and max centrality score.
|
| 71 |
+
"""
|
| 72 |
+
if morgan_fp_generator is None:
|
| 73 |
+
# Create a default Morgan fingerprint generator
|
| 74 |
+
morgan_fp_generator = rdFingerprintGenerator.GetMorganGenerator(
|
| 75 |
+
radius=16,
|
| 76 |
+
fpSize=1024,
|
| 77 |
+
useBondTypes=True,
|
| 78 |
+
includeChirality=True,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
if representative_e3s_fp is None:
|
| 82 |
+
# Get the representative E3 ligands fingerprints
|
| 83 |
+
representative_e3s_fp = get_representative_e3s_fp(fp_generator=morgan_fp_generator)
|
| 84 |
+
|
| 85 |
+
# -----------------------------------
|
| 86 |
+
# Deterministic graph-based algorithm
|
| 87 |
+
# -----------------------------------
|
| 88 |
+
protac = Chem.MolFromSmiles(protac_smiles)
|
| 89 |
+
if protac is None:
|
| 90 |
+
raise ValueError(f"Invalid SMILES string: {protac_smiles}")
|
| 91 |
+
|
| 92 |
+
G = smiles_to_nx(protac_smiles)
|
| 93 |
+
|
| 94 |
+
# Compute betweenness centrality
|
| 95 |
+
weight = 'capacity' if use_capacity_weight else None
|
| 96 |
+
centrality = nx.betweenness_centrality(G, normalized=True, endpoints=True, weight=weight)
|
| 97 |
+
|
| 98 |
+
# Get the two nodes with the highest betweenness centrality
|
| 99 |
+
sorted_nodes = sorted(centrality.items(), key=lambda x: x[1], reverse=True)
|
| 100 |
+
|
| 101 |
+
# Get the list of bridges in the graph
|
| 102 |
+
bridges = list(nx.bridges(G))
|
| 103 |
+
|
| 104 |
+
# Get the top two nodes
|
| 105 |
+
top_nodes = [n for n, _ in sorted_nodes if n in bridges][:2]
|
| 106 |
+
|
| 107 |
+
# Get the top nodes with the highest betweenness centrality that are not in
|
| 108 |
+
# a ring, but are adjacent to the top nodes or have a high betweenness
|
| 109 |
+
for node, score in sorted_nodes:
|
| 110 |
+
# Check if the node is in a ring in the protac molecule
|
| 111 |
+
atom = protac.GetAtomWithIdx(node)
|
| 112 |
+
if not atom.IsInRing():
|
| 113 |
+
# Check if the atom is adjacent to any of the top nodes, if so, add it to the list
|
| 114 |
+
for neighbor in G.neighbors(node):
|
| 115 |
+
if neighbor in top_nodes:
|
| 116 |
+
top_nodes.append(node)
|
| 117 |
+
break
|
| 118 |
+
if score > betweenness_threshold:
|
| 119 |
+
top_nodes.append(node)
|
| 120 |
+
|
| 121 |
+
# If a node as only top nodes as neighbors, add it to the list
|
| 122 |
+
for node in G.nodes():
|
| 123 |
+
if node not in top_nodes:
|
| 124 |
+
neighbors = list(G.neighbors(node))
|
| 125 |
+
if all(neighbor in top_nodes for neighbor in neighbors):
|
| 126 |
+
top_nodes.append(node)
|
| 127 |
+
|
| 128 |
+
# Get all paths between the top nodes, e.g., rings
|
| 129 |
+
for i in range(len(top_nodes)):
|
| 130 |
+
for j in range(i + 1, len(top_nodes)):
|
| 131 |
+
node1 = top_nodes[i]
|
| 132 |
+
node2 = top_nodes[j]
|
| 133 |
+
|
| 134 |
+
for path in nx.all_simple_paths(G, node1, node2):
|
| 135 |
+
for node in path:
|
| 136 |
+
if node not in top_nodes:
|
| 137 |
+
top_nodes.append(node)
|
| 138 |
+
|
| 139 |
+
# Remove duplicates
|
| 140 |
+
top_nodes = list(set(top_nodes))
|
| 141 |
+
|
| 142 |
+
# Loop over the top nodes and find the nodes that have a neighbor outside
|
| 143 |
+
# the top nodes
|
| 144 |
+
edge_nodes = set()
|
| 145 |
+
for top_node in top_nodes:
|
| 146 |
+
for neighbor in G.neighbors(top_node):
|
| 147 |
+
if neighbor not in top_nodes:
|
| 148 |
+
edge_nodes.update([(top_node, neighbor)])
|
| 149 |
+
break
|
| 150 |
+
|
| 151 |
+
# Get molecule fragment from the top nodes
|
| 152 |
+
bonds = [protac.GetBondBetweenAtoms(i, j) for (i, j) in edge_nodes]
|
| 153 |
+
bonds_idx = [bond.GetIdx() for bond in bonds if bond is not None]
|
| 154 |
+
|
| 155 |
+
# Try any pair of indexes, if the number of resulting fragments is not 3,
|
| 156 |
+
# then do not consider them as candidates for splitting
|
| 157 |
+
candidate_bonds = []
|
| 158 |
+
for i in range(len(bonds_idx)):
|
| 159 |
+
for j in range(i + 1, len(bonds_idx)):
|
| 160 |
+
bond1 = bonds_idx[i]
|
| 161 |
+
bond2 = bonds_idx[j]
|
| 162 |
+
|
| 163 |
+
# Get the fragments
|
| 164 |
+
fragments = Chem.FragmentOnBonds(protac, [bond1, bond2])
|
| 165 |
+
|
| 166 |
+
# Check if there are 3 fragments
|
| 167 |
+
if Chem.MolToSmiles(fragments).count(".") == 2:
|
| 168 |
+
frag_lens = []
|
| 169 |
+
avg_len = 0
|
| 170 |
+
for frag in Chem.GetMolFrags(fragments, asMols=True):
|
| 171 |
+
frag_len = frag.GetNumAtoms()
|
| 172 |
+
frag_lens.append(frag_len)
|
| 173 |
+
avg_len += frag_len
|
| 174 |
+
avg_len /= 3
|
| 175 |
+
|
| 176 |
+
# Calculate the standard deviation of the fragment lengths
|
| 177 |
+
len_std = 0
|
| 178 |
+
for frag_len in frag_lens:
|
| 179 |
+
len_std += (frag_len - avg_len) ** 2
|
| 180 |
+
len_std = (len_std / 3) ** 0.5
|
| 181 |
+
candidate_bonds.append(((bond1, bond2), len_std))
|
| 182 |
+
|
| 183 |
+
# Sort the candidate bonds by distance to average (smallest first)
|
| 184 |
+
candidate_bonds = sorted(candidate_bonds, key=lambda x: x[1])
|
| 185 |
+
|
| 186 |
+
ligands = None
|
| 187 |
+
while ligands is None and len(candidate_bonds) > 0:
|
| 188 |
+
bonds_idx = candidate_bonds[0][0]
|
| 189 |
+
try:
|
| 190 |
+
ligands = Chem.FragmentOnBonds(protac, bonds_idx, addDummies=True, dummyLabels=[(1, 1), (2, 2)])
|
| 191 |
+
except Exception as e:
|
| 192 |
+
print(f"Error fragmenting the molecule: {e}")
|
| 193 |
+
candidate_bonds.pop(0)
|
| 194 |
+
|
| 195 |
+
# If no candidate bonds were found, return None
|
| 196 |
+
if ligands is None:
|
| 197 |
+
print(f"No candidate bonds found for splitting PROTAC: {protac_smiles}")
|
| 198 |
+
return {'e3': None, 'poi': None, 'linker': None, 'top_nodes': None, 'centrality': None}
|
| 199 |
+
|
| 200 |
+
# Get the linker
|
| 201 |
+
substructures = []
|
| 202 |
+
for ligand in Chem.GetMolFrags(ligands, asMols=True):
|
| 203 |
+
ligand_smiles = Chem.MolToSmiles(ligand, canonical=True)
|
| 204 |
+
if ligand_smiles.count("*") == 2:
|
| 205 |
+
linker_smiles = ligand_smiles
|
| 206 |
+
else:
|
| 207 |
+
substructures.append(ligand_smiles)
|
| 208 |
+
|
| 209 |
+
sub1_dist = average_tanimoto_distance(substructures[0], representative_e3s_fp, morgan_fp_generator)
|
| 210 |
+
sub2_dist = average_tanimoto_distance(substructures[1], representative_e3s_fp, morgan_fp_generator)
|
| 211 |
+
if sub1_dist < sub2_dist:
|
| 212 |
+
e3_smiles = substructures[0]
|
| 213 |
+
wh_smiles = substructures[1]
|
| 214 |
+
else:
|
| 215 |
+
e3_smiles = substructures[1]
|
| 216 |
+
wh_smiles = substructures[0]
|
| 217 |
+
|
| 218 |
+
# Get the attachment point using a regex, e.g., should return 1 if [1*] is in the SMILES
|
| 219 |
+
e3_attach_point = extract_attachment_point(e3_smiles)
|
| 220 |
+
e3_smiles = e3_smiles.replace(f"[{e3_attach_point}*]", "[*:2]")
|
| 221 |
+
linker_smiles = linker_smiles.replace(f"[{e3_attach_point}*]", "[*:2]")
|
| 222 |
+
|
| 223 |
+
wh_attach_point = extract_attachment_point(wh_smiles)
|
| 224 |
+
wh_smiles = wh_smiles.replace(f"[{wh_attach_point}*]", "[*:1]")
|
| 225 |
+
linker_smiles = linker_smiles.replace(f"[{wh_attach_point}*]", "[*:1]")
|
| 226 |
+
return {'e3': e3_smiles, 'poi': wh_smiles, 'linker': linker_smiles, 'top_nodes': top_nodes, 'centrality': centrality}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def split_protac_with_edge_classifier(
|
| 230 |
+
protac_smiles: str,
|
| 231 |
+
pipeline: Union[str, Path],
|
| 232 |
+
representative_e3s_fp: Optional[List[np.array]] = None,
|
| 233 |
+
morgan_fp_generator: Optional[Any] = None,
|
| 234 |
+
) -> Dict[str, str]:
|
| 235 |
+
""" Split the PROTAC molecule into two parts using the pretrained edge classifier.
|
| 236 |
+
|
| 237 |
+
Parameters:
|
| 238 |
+
protac_smiles (str): The SMILES string of the PROTAC molecule.
|
| 239 |
+
pipeline (Union[str, Path]): Path to the trained GraphEdgeClassifier model.
|
| 240 |
+
representative_e3s_fp (Optional[List[np.array]]): Precomputed fingerprints of representative E3 ligands.
|
| 241 |
+
morgan_fp_generator (Optional[Any]): RDKit Morgan fingerprint generator (should be the same as the one that generated the E3 fingerprints).
|
| 242 |
+
|
| 243 |
+
Returns:
|
| 244 |
+
dict: A dictionary containing the E3 ligand, warhead, linker, and bonds_idx
|
| 245 |
+
"""
|
| 246 |
+
if morgan_fp_generator is None:
|
| 247 |
+
# Create a default Morgan fingerprint generator
|
| 248 |
+
morgan_fp_generator = rdFingerprintGenerator.GetMorganGenerator(
|
| 249 |
+
radius=16,
|
| 250 |
+
fpSize=1024,
|
| 251 |
+
useBondTypes=True,
|
| 252 |
+
includeChirality=True,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
if representative_e3s_fp is None:
|
| 256 |
+
# Get the representative E3 ligands fingerprints
|
| 257 |
+
representative_e3s_fp = get_representative_e3s_fp(fp_generator=morgan_fp_generator)
|
| 258 |
+
|
| 259 |
+
protac = Chem.MolFromSmiles(protac_smiles)
|
| 260 |
+
if protac is None:
|
| 261 |
+
raise ValueError(f"Invalid SMILES string: {protac_smiles}")
|
| 262 |
+
|
| 263 |
+
if isinstance(pipeline, str):
|
| 264 |
+
pipeline = GraphEdgeClassifier.load(pipeline)
|
| 265 |
+
|
| 266 |
+
# TODO: Get the top-n bonds, if splitting results in more than 3 ligands,
|
| 267 |
+
# test other pairs of bonds, then repeat until we get 3 ligands exactly.
|
| 268 |
+
bonds_idx = pipeline.predict_from_smiles(
|
| 269 |
+
protac_smiles,
|
| 270 |
+
wh_smiles=None,
|
| 271 |
+
lk_smiles=None,
|
| 272 |
+
e3_smiles=None,
|
| 273 |
+
top_n=2,
|
| 274 |
+
return_array=True,
|
| 275 |
+
).flatten().tolist()
|
| 276 |
+
# print(f"Predicted bonds: {bonds_idx}")
|
| 277 |
+
|
| 278 |
+
if -1 in bonds_idx:
|
| 279 |
+
bonds_idx = [bond for bond in bonds_idx if bond != -1]
|
| 280 |
+
# Randomly select a bond index from the PROTAC molecule
|
| 281 |
+
# that is not in the predicted bonds
|
| 282 |
+
for _ in range(2 - len(bonds_idx)):
|
| 283 |
+
bond = np.random.choice([bond.GetIdx() for bond in protac.GetBonds() if bond.GetIdx() not in bonds_idx and not bond.IsInRing()])
|
| 284 |
+
bonds_idx.append(int(bond))
|
| 285 |
+
|
| 286 |
+
ligands = Chem.FragmentOnBonds(protac, bonds_idx, addDummies=True, dummyLabels=[(1, 1), (2, 2)])
|
| 287 |
+
|
| 288 |
+
# Get the linker
|
| 289 |
+
substructures = []
|
| 290 |
+
for ligand in Chem.GetMolFrags(ligands, asMols=True):
|
| 291 |
+
ligand_smiles = Chem.MolToSmiles(ligand, canonical=True)
|
| 292 |
+
if ligand_smiles.count("*") == 2:
|
| 293 |
+
linker_smiles = ligand_smiles
|
| 294 |
+
else:
|
| 295 |
+
substructures.append(ligand_smiles)
|
| 296 |
+
|
| 297 |
+
if not pipeline.binary:
|
| 298 |
+
e3_smiles = substructures[0]
|
| 299 |
+
wh_smiles = substructures[1]
|
| 300 |
+
# NOTE: The classifier was trained on the following labels assignment:
|
| 301 |
+
e3_attach_point = 1
|
| 302 |
+
wh_attach_point = 2
|
| 303 |
+
else:
|
| 304 |
+
if representative_e3s_fp is None or morgan_fp_generator is None:
|
| 305 |
+
raise ValueError("For pipeline trained on binary classification, representative_e3s_fp and morgan_fp_generator must be provided.")
|
| 306 |
+
sub1_dist = average_tanimoto_distance(substructures[0], representative_e3s_fp, morgan_fp_generator)
|
| 307 |
+
sub2_dist = average_tanimoto_distance(substructures[1], representative_e3s_fp, morgan_fp_generator)
|
| 308 |
+
if sub1_dist < sub2_dist:
|
| 309 |
+
e3_smiles = substructures[0]
|
| 310 |
+
wh_smiles = substructures[1]
|
| 311 |
+
else:
|
| 312 |
+
e3_smiles = substructures[1]
|
| 313 |
+
wh_smiles = substructures[0]
|
| 314 |
+
# Get the attachment point using a regex, e.g., should return 1 if [1*] is in the SMILES
|
| 315 |
+
e3_attach_point = extract_attachment_point(e3_smiles)
|
| 316 |
+
wh_attach_point = extract_attachment_point(wh_smiles)
|
| 317 |
+
|
| 318 |
+
e3_smiles = e3_smiles.replace(f"[{e3_attach_point}*]", "[*:2]")
|
| 319 |
+
linker_smiles = linker_smiles.replace(f"[{e3_attach_point}*]", "[*:2]")
|
| 320 |
+
|
| 321 |
+
wh_smiles = wh_smiles.replace(f"[{wh_attach_point}*]", "[*:1]")
|
| 322 |
+
linker_smiles = linker_smiles.replace(f"[{wh_attach_point}*]", "[*:1]")
|
| 323 |
+
return {'e3': e3_smiles, 'poi': wh_smiles, 'linker': linker_smiles, "bonds_idx": bonds_idx}
|
| 324 |
+
|
| 325 |
+
def split_protac_graph_based(
|
| 326 |
+
protac_smiles: str,
|
| 327 |
+
use_classifier: bool = False,
|
| 328 |
+
classifier: Optional['GraphEdgeClassifier'] = None,
|
| 329 |
+
representative_e3s_fp: Optional[List[Any]] = None,
|
| 330 |
+
morgan_fp_generator: Optional[Any] = None,
|
| 331 |
+
use_capacity_weight: bool = False,
|
| 332 |
+
betweenness_threshold: float = 0.4,
|
| 333 |
+
) -> Dict[str, str]:
|
| 334 |
+
"""
|
| 335 |
+
Splits a PROTAC molecule using either ML classifier or deterministic betweenness centrality.
|
| 336 |
+
Returns a dictionary with e3, poi, linker, bonds_idx.
|
| 337 |
+
"""
|
| 338 |
+
|
| 339 |
+
if representative_e3s_fp is None:
|
| 340 |
+
if morgan_fp_generator is None:
|
| 341 |
+
# Create a default Morgan fingerprint generator
|
| 342 |
+
morgan_fp_generator = rdFingerprintGenerator.GetMorganGenerator(
|
| 343 |
+
radius=16,
|
| 344 |
+
fpSize=1024,
|
| 345 |
+
useBondTypes=True,
|
| 346 |
+
includeChirality=True,
|
| 347 |
+
)
|
| 348 |
+
# Get the representative E3 ligands fingerprints
|
| 349 |
+
representative_e3s_fp = get_representative_e3s_fp(fp_generator=morgan_fp_generator)
|
| 350 |
+
|
| 351 |
+
if use_classifier:
|
| 352 |
+
ret = split_protac_with_edge_classifier(
|
| 353 |
+
protac_smiles=protac_smiles,
|
| 354 |
+
pipeline=classifier,
|
| 355 |
+
representative_e3s_fp=representative_e3s_fp,
|
| 356 |
+
morgan_fp_generator=morgan_fp_generator,
|
| 357 |
+
)
|
| 358 |
+
else:
|
| 359 |
+
ret = split_protac_with_betweenness_centrality(
|
| 360 |
+
protac_smiles=protac_smiles,
|
| 361 |
+
representative_e3s_fp=representative_e3s_fp,
|
| 362 |
+
morgan_fp_generator=morgan_fp_generator,
|
| 363 |
+
use_capacity_weight=use_capacity_weight,
|
| 364 |
+
betweenness_threshold=betweenness_threshold,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
substructs = {
|
| 368 |
+
"e3": ret["e3"],
|
| 369 |
+
"poi": ret["poi"],
|
| 370 |
+
"linker": ret["linker"],
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
# If all of the substructures are not None, fix the amide and ester bonds
|
| 374 |
+
if all(x is not None for x in substructs.values()):
|
| 375 |
+
substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
|
| 376 |
+
substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
|
| 377 |
+
ret["e3"] = substructs["e3"]
|
| 378 |
+
ret["poi"] = substructs["poi"]
|
| 379 |
+
ret["linker"] = substructs["linker"]
|
| 380 |
+
|
| 381 |
+
return ret
|
| 382 |
+
|
| 383 |
+
def split_protac_with_graphs_wrapper(
|
| 384 |
+
protac_smiles: List[str],
|
| 385 |
+
use_classifier: bool = False,
|
| 386 |
+
classifier: Optional['GraphEdgeClassifier'] = None,
|
| 387 |
+
representative_e3s: Optional[List[Any]] = None,
|
| 388 |
+
representative_e3s_fp: Optional[List[Any]] = None,
|
| 389 |
+
morgan_fp_generator: Optional[Any] = None,
|
| 390 |
+
use_capacity_weight: bool = False,
|
| 391 |
+
betweenness_threshold: float = 0.4,
|
| 392 |
+
) -> List[Dict[str, str]]:
|
| 393 |
+
""" Wrapper function to apply split_protac_graph_based over a list of PROTAC SMILES.
|
| 394 |
+
|
| 395 |
+
Parameters:
|
| 396 |
+
protac_smiles (List[str]): List of SMILES strings of PROTAC molecules.
|
| 397 |
+
use_classifier (bool): Whether to use a classifier for splitting.
|
| 398 |
+
classifier (Optional[GraphEdgeClassifier]): Classifier to use if use_classifier is True.
|
| 399 |
+
representative_e3s_fp (Optional[List[Any]]): Precomputed fingerprints of representative E3 ligands.
|
| 400 |
+
morgan_fp_generator (Optional[Any]): RDKit Morgan fingerprint generator.
|
| 401 |
+
use_capacity_weight (bool): Whether to use bond capacity as weight for the graph.
|
| 402 |
+
betweenness_threshold (float): Threshold for betweenness centrality to consider a node as a candidate for splitting.
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
List[Dict[str, str]]: List of dictionaries containing the split results for each PROTAC molecule.
|
| 406 |
+
"""
|
| 407 |
+
if morgan_fp_generator is None and (representative_e3s is None or representative_e3s_fp is None):
|
| 408 |
+
# Create a default Morgan fingerprint generator
|
| 409 |
+
morgan_fp_generator = rdFingerprintGenerator.GetMorganGenerator(
|
| 410 |
+
radius=16,
|
| 411 |
+
fpSize=1024,
|
| 412 |
+
useBondTypes=True,
|
| 413 |
+
includeChirality=True,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
if representative_e3s is None and representative_e3s_fp is None:
|
| 417 |
+
# Get the representative E3 ligands fingerprints
|
| 418 |
+
representative_e3s_fp = get_representative_e3s_fp(fp_generator=morgan_fp_generator)
|
| 419 |
+
elif representative_e3s is not None and representative_e3s_fp is None:
|
| 420 |
+
# Convert representative E3 ligands to fingerprints
|
| 421 |
+
representative_e3s_fp = get_representative_e3s_fp(e3_list=representative_e3s, fp_generator=morgan_fp_generator)
|
| 422 |
+
|
| 423 |
+
# Load the classifier if it is a string or Path
|
| 424 |
+
if use_classifier and classifier is not None and isinstance(classifier, (str, Path)):
|
| 425 |
+
classifier = GraphEdgeClassifier.load(classifier)
|
| 426 |
+
|
| 427 |
+
return [
|
| 428 |
+
split_protac_graph_based(
|
| 429 |
+
protac_smiles=smi,
|
| 430 |
+
use_classifier=use_classifier,
|
| 431 |
+
classifier=classifier,
|
| 432 |
+
representative_e3s_fp=representative_e3s_fp,
|
| 433 |
+
morgan_fp_generator=morgan_fp_generator,
|
| 434 |
+
use_capacity_weight=use_capacity_weight,
|
| 435 |
+
betweenness_threshold=betweenness_threshold,
|
| 436 |
+
) for smi in protac_smiles
|
| 437 |
+
]
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def split_protac_with_graphs_parallel(
|
| 441 |
+
protac_smiles: List[str],
|
| 442 |
+
use_classifier: bool = False,
|
| 443 |
+
classifier: Optional['GraphEdgeClassifier'] = None,
|
| 444 |
+
representative_e3s: Optional[List[Any]] = None,
|
| 445 |
+
representative_e3s_fp: Optional[List[Any]] = None,
|
| 446 |
+
morgan_fp_generator: Optional[Any] = None,
|
| 447 |
+
use_capacity_weight: bool = False,
|
| 448 |
+
betweenness_threshold: float = 0.4,
|
| 449 |
+
n_jobs: int = 1,
|
| 450 |
+
batch_size: int = 1,
|
| 451 |
+
) -> List[Dict[str, str]]:
|
| 452 |
+
""" Splits a list of PROTAC molecules using either ML classifier or deterministic betweenness centrality.
|
| 453 |
+
|
| 454 |
+
Parameters:
|
| 455 |
+
protac_smiles (List[str]): List of SMILES strings of PROTAC molecules.
|
| 456 |
+
use_classifier (bool): Whether to use a classifier for splitting.
|
| 457 |
+
classifier (Optional[GraphEdgeClassifier]): Classifier to use if use_classifier is True.
|
| 458 |
+
representative_e3s (Optional[List[Any]]): List of representative E3 ligands. If None, uses precomputed fingerprints.
|
| 459 |
+
representative_e3s_fp (Optional[List[Any]]): Precomputed fingerprints of representative E3 ligands.
|
| 460 |
+
morgan_fp_generator (Optional[Any]): RDKit Morgan fingerprint generator.
|
| 461 |
+
use_capacity_weight (bool): Whether to use bond capacity as weight for the graph.
|
| 462 |
+
betweenness_threshold (float): Threshold for betweenness centrality to consider a node as a candidate for splitting.
|
| 463 |
+
n_jobs (int): Number of parallel jobs to run. If 1, runs sequentially.
|
| 464 |
+
batch_size (int): Size of each batch for parallel processing.
|
| 465 |
+
"""
|
| 466 |
+
# Load the classifier if it is a string or Path
|
| 467 |
+
if use_classifier and classifier is not None and isinstance(classifier, (str, Path)):
|
| 468 |
+
classifier = GraphEdgeClassifier.load(classifier)
|
| 469 |
+
|
| 470 |
+
if n_jobs < 1:
|
| 471 |
+
raise ValueError("n_jobs must be a positive integer.")
|
| 472 |
+
if n_jobs == 1:
|
| 473 |
+
# If n_jobs is 1, run the function sequentially
|
| 474 |
+
return split_protac_with_graphs_wrapper(
|
| 475 |
+
protac_smiles=protac_smiles,
|
| 476 |
+
use_classifier=use_classifier,
|
| 477 |
+
classifier=classifier,
|
| 478 |
+
representative_e3s=representative_e3s,
|
| 479 |
+
representative_e3s_fp=representative_e3s_fp,
|
| 480 |
+
morgan_fp_generator=morgan_fp_generator,
|
| 481 |
+
use_capacity_weight=use_capacity_weight,
|
| 482 |
+
betweenness_threshold=betweenness_threshold,
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
# Raise a warning if the n_jobs > 1 and the fingerprint generator is provided
|
| 486 |
+
if morgan_fp_generator is not None:
|
| 487 |
+
print("Warning: Using a custom Morgan fingerprint generator with n_jobs > 1 may be un-pickleable.")
|
| 488 |
+
|
| 489 |
+
# Split the SMILES list into batches
|
| 490 |
+
smiles_batches = [protac_smiles[i:i + batch_size] for i in range(0, len(protac_smiles), batch_size)]
|
| 491 |
+
|
| 492 |
+
# Ensure all SMILES are processed, even if the last batch is smaller than batch_size
|
| 493 |
+
smiles_batches = [protac_smiles[i:i + batch_size] for i in range(0, len(protac_smiles), batch_size)]
|
| 494 |
+
# Remove any empty batches (shouldn't happen, but for safety)
|
| 495 |
+
smiles_batches = [batch for batch in smiles_batches if batch]
|
| 496 |
+
|
| 497 |
+
# Run each batch in parallel
|
| 498 |
+
results = Parallel(n_jobs=n_jobs)(
|
| 499 |
+
delayed(split_protac_with_graphs_wrapper)(
|
| 500 |
+
protac_smiles=batch,
|
| 501 |
+
use_classifier=use_classifier,
|
| 502 |
+
classifier=classifier,
|
| 503 |
+
representative_e3s=representative_e3s,
|
| 504 |
+
representative_e3s_fp=representative_e3s_fp,
|
| 505 |
+
morgan_fp_generator=morgan_fp_generator,
|
| 506 |
+
use_capacity_weight=use_capacity_weight,
|
| 507 |
+
betweenness_threshold=betweenness_threshold,
|
| 508 |
+
) for batch in smiles_batches
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
# Flatten the list of lists into a single list
|
| 512 |
+
return [item for batch_result in results for item in batch_result]
|
protac_splitter/graphs/utils.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Optional, List
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from rdkit import Chem, DataStructs
|
| 5 |
+
from rdkit.Chem import rdFingerprintGenerator
|
| 6 |
+
|
| 7 |
+
def get_fp(
|
| 8 |
+
smiles: str,
|
| 9 |
+
fp_generator: Optional[Any] = None,
|
| 10 |
+
return_np: bool = True,
|
| 11 |
+
) -> Optional[np.ndarray]:
|
| 12 |
+
"""
|
| 13 |
+
Get the Morgan fingerprint of a molecule from its SMILES representation.
|
| 14 |
+
|
| 15 |
+
Parameters:
|
| 16 |
+
smiles (str): The SMILES string of the molecule.
|
| 17 |
+
fp_generator (Any, optional): The fingerprint generator to use. If None, a default generator is used.
|
| 18 |
+
return_np (bool): Whether to return the fingerprint as a NumPy array. Defaults to True.
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
Optional[np.ndarray]: The Morgan fingerprint of the molecule as a NumPy array, or None if the SMILES is invalid.
|
| 22 |
+
"""
|
| 23 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 24 |
+
if mol is None:
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
if fp_generator is None:
|
| 28 |
+
fp_generator = rdFingerprintGenerator.GetMorganGenerator(
|
| 29 |
+
radius=16,
|
| 30 |
+
fpSize=1024,
|
| 31 |
+
useBondTypes=True,
|
| 32 |
+
includeChirality=True,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
if return_np:
|
| 36 |
+
return fp_generator.GetFingerprintAsNumPy(mol)
|
| 37 |
+
else:
|
| 38 |
+
return fp_generator.GetFingerprint(mol)
|
| 39 |
+
|
| 40 |
+
def average_tanimoto_distance(
|
| 41 |
+
smiles: str,
|
| 42 |
+
fingerprints: List[DataStructs.ExplicitBitVect],
|
| 43 |
+
morgan_fp_generator: Optional[Any] = None,
|
| 44 |
+
) -> float:
|
| 45 |
+
"""
|
| 46 |
+
Compute the average Tanimoto distance between a query SMILES and a list of RDKit fingerprints.
|
| 47 |
+
|
| 48 |
+
Parameters:
|
| 49 |
+
smiles (str): SMILES string of the query molecule.
|
| 50 |
+
fingerprints (list): List of RDKit fingerprint objects (e.g., ExplicitBitVect).
|
| 51 |
+
morgan_fp_generator: RDKit Morgan fingerprint generator.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
float: Average Tanimoto distance (1 - similarity) between the query and the fingerprints.
|
| 55 |
+
"""
|
| 56 |
+
query_fp = get_fp(smiles, morgan_fp_generator, return_np=False)
|
| 57 |
+
if query_fp is None:
|
| 58 |
+
raise ValueError(f"Invalid SMILES string: {smiles}")
|
| 59 |
+
distances = DataStructs.BulkTanimotoSimilarity(query_fp, fingerprints, returnDistance=True)
|
| 60 |
+
|
| 61 |
+
return np.array(distances).mean()
|
| 62 |
+
|
| 63 |
+
def numpy_to_rdkit_fp(arr: np.ndarray) -> DataStructs.ExplicitBitVect:
|
| 64 |
+
"""
|
| 65 |
+
Convert a NumPy array to an RDKit ExplicitBitVect.
|
| 66 |
+
"""
|
| 67 |
+
return DataStructs.CreateFromBitString(''.join(arr.astype(str)))
|
protac_splitter/graphs_utils.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from numba import njit
|
| 2 |
+
import numpy as np
|
| 3 |
+
import networkx as nx
|
| 4 |
+
from rdkit import Chem
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def mol2graph(mol: Chem.Mol) -> nx.Graph:
|
| 8 |
+
""" Convert an RDKit molecule to a NetworkX graph.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
mol (Chem.Mol): The RDKit molecule to convert.
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
nx.Graph: The NetworkX graph representation of the molecule.
|
| 15 |
+
"""
|
| 16 |
+
# NOTE: https://github.com/maxhodak/keras-molecules/pull/32/files
|
| 17 |
+
# TODO: Double check this implementation too: https://gist.github.com/jhjensen2/6450138cda3ab796a30850610843cfff
|
| 18 |
+
if mol is None:
|
| 19 |
+
return nx.empty_graph()
|
| 20 |
+
G = nx.Graph()
|
| 21 |
+
for atom in mol.GetAtoms():
|
| 22 |
+
# Skip non-heavy atoms
|
| 23 |
+
if atom.GetAtomicNum() != 0:
|
| 24 |
+
G.add_node(atom.GetIdx(), label=atom.GetSymbol())
|
| 25 |
+
for bond in mol.GetBonds():
|
| 26 |
+
# Skip bonds to non-heavy atoms
|
| 27 |
+
if bond.GetBeginAtom().GetAtomicNum() == 0 or bond.GetEndAtom().GetAtomicNum() == 0:
|
| 28 |
+
continue
|
| 29 |
+
G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), label=bond.GetBondType())
|
| 30 |
+
return G
|
| 31 |
+
|
| 32 |
+
def smiles2graph(smiles: str) -> nx.Graph:
|
| 33 |
+
""" Convert a SMILES string to a NetworkX graph.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
smiles (str): The SMILES string to convert.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
nx.Graph: The NetworkX graph representation of the molecule.
|
| 40 |
+
"""
|
| 41 |
+
return mol2graph(Chem.MolFromSmiles(smiles))
|
| 42 |
+
|
| 43 |
+
def get_smiles2graph_edit_distance(smi1: str, smi2: str, **kwargs) -> float:
|
| 44 |
+
""" Compute the graph edit distance between two SMILES strings.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
smi1 (str): The first SMILES string.
|
| 48 |
+
smi2 (str): The second SMILES string.
|
| 49 |
+
**kwargs: Additional keyword arguments for `nx.graph_edit_distance`.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
float: The graph edit distance between the two SMILES strings.
|
| 53 |
+
"""
|
| 54 |
+
ged = nx.graph_edit_distance(smiles2graph(smi1), smiles2graph(smi2), **kwargs)
|
| 55 |
+
return ged if ged is not None else np.inf
|
| 56 |
+
|
| 57 |
+
def get_mol2graph_edit_distance(mol1: str, mol2: str, **kwargs) -> float:
|
| 58 |
+
""" Compute the graph edit distance between two RDKit molecules.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
mol1 (Chem.Mol): The first RDKit molecule.
|
| 62 |
+
mol2 (Chem.Mol): The second RDKit molecule.
|
| 63 |
+
**kwargs: Additional keyword arguments for `nx.graph_edit_distance`.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
float: The graph edit distance between the two RDKit molecules.
|
| 67 |
+
"""
|
| 68 |
+
ged = nx.graph_edit_distance(mol2graph(mol1), mol2graph(mol2), **kwargs)
|
| 69 |
+
return ged if ged is not None else np.inf
|
| 70 |
+
|
| 71 |
+
def get_smiles2graph_edit_distance_norm(
|
| 72 |
+
smi1: str,
|
| 73 |
+
smi2: str,
|
| 74 |
+
ged_G1_G2: None,
|
| 75 |
+
eps: float = 1e-9,
|
| 76 |
+
**kwargs,
|
| 77 |
+
) -> float:
|
| 78 |
+
""" Compute the normalized graph edit distance between two SMILES strings.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
smi1 (str): The first SMILES string.
|
| 82 |
+
smi2 (str): The second SMILES string.
|
| 83 |
+
ged_G1_G2 (float): The graph edit distance between the two graphs. If None, it will be computed using `nx.graph_edit_distance`.
|
| 84 |
+
eps (float): A small value to avoid division by zero.
|
| 85 |
+
**kwargs: Additional keyword arguments for `nx.graph_edit_distance`.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
float: The normalized graph edit distance between the two SMILES strings.
|
| 89 |
+
"""
|
| 90 |
+
G1 = smiles2graph(smi1)
|
| 91 |
+
G2 = smiles2graph(smi2)
|
| 92 |
+
G0 = nx.empty_graph()
|
| 93 |
+
ged_G1_G2 = ged_G1_G2 if ged_G1_G2 is not None else nx.graph_edit_distance(G1, G2, **kwargs)
|
| 94 |
+
ged_G1_G0 = nx.graph_edit_distance(G1, G0, **kwargs)
|
| 95 |
+
ged_G2_G0 = nx.graph_edit_distance(G2, G0, **kwargs)
|
| 96 |
+
if None in [ged_G1_G2, ged_G1_G0, ged_G2_G0]:
|
| 97 |
+
return np.inf
|
| 98 |
+
return ged_G1_G2 / (ged_G1_G0 + ged_G2_G0 + eps)
|
| 99 |
+
|
| 100 |
+
def smiles2adjacency_matrix(smiles: str) -> np.ndarray:
|
| 101 |
+
return nx.adjacency_matrix(smiles2graph(smiles)).todense()
|
| 102 |
+
|
| 103 |
+
def build_label_mapping(G1, G2):
|
| 104 |
+
labels = set()
|
| 105 |
+
for G in [G1, G2]:
|
| 106 |
+
for node in G.nodes():
|
| 107 |
+
labels.add(G.nodes[node]['label'])
|
| 108 |
+
label_to_int = {label: idx for idx, label in enumerate(sorted(labels))}
|
| 109 |
+
return label_to_int
|
| 110 |
+
|
| 111 |
+
def preprocess_graph(G, label_to_int):
|
| 112 |
+
n = G.number_of_nodes()
|
| 113 |
+
adj = np.zeros((n, n), dtype=np.int32)
|
| 114 |
+
labels = np.zeros(n, dtype=np.int32)
|
| 115 |
+
node_id_to_idx = {}
|
| 116 |
+
for idx, node in enumerate(G.nodes()):
|
| 117 |
+
node_id_to_idx[node] = idx
|
| 118 |
+
label = G.nodes[node]['label']
|
| 119 |
+
labels[idx] = label_to_int[label]
|
| 120 |
+
for u, v in G.edges():
|
| 121 |
+
idx_u = node_id_to_idx[u]
|
| 122 |
+
idx_v = node_id_to_idx[v]
|
| 123 |
+
adj[idx_u, idx_v] = 1
|
| 124 |
+
adj[idx_v, idx_u] = 1 # Assuming undirected graph
|
| 125 |
+
return adj, labels
|
| 126 |
+
|
| 127 |
+
@njit
|
| 128 |
+
def compute_cost_matrix(labels1, labels2, degrees1, degrees2):
|
| 129 |
+
n1 = labels1.shape[0]
|
| 130 |
+
n2 = labels2.shape[0]
|
| 131 |
+
C = np.zeros((n1, n2), dtype=np.float64)
|
| 132 |
+
for i in range(n1):
|
| 133 |
+
for j in range(n2):
|
| 134 |
+
label_cost = 0.0 if labels1[i] == labels2[j] else 1.0
|
| 135 |
+
neighborhood_cost = abs(degrees1[i] - degrees2[j])
|
| 136 |
+
C[i, j] = label_cost + neighborhood_cost
|
| 137 |
+
return C
|
| 138 |
+
|
| 139 |
+
@njit
|
| 140 |
+
def greedy_assignment(C):
|
| 141 |
+
n1, n2 = C.shape
|
| 142 |
+
assigned_cols = np.full(n2, False)
|
| 143 |
+
row_ind = np.full(n1, -1, dtype=np.int32)
|
| 144 |
+
for i in range(n1):
|
| 145 |
+
min_cost = np.inf
|
| 146 |
+
min_j = -1
|
| 147 |
+
for j in range(n2):
|
| 148 |
+
if not assigned_cols[j] and C[i, j] < min_cost:
|
| 149 |
+
min_cost = C[i, j]
|
| 150 |
+
min_j = j
|
| 151 |
+
if min_j != -1:
|
| 152 |
+
row_ind[i] = min_j
|
| 153 |
+
assigned_cols[min_j] = True
|
| 154 |
+
return row_ind
|
| 155 |
+
|
| 156 |
+
@njit
|
| 157 |
+
def compute_total_cost(C, row_ind, n1, n2, c_node_del, c_node_ins):
|
| 158 |
+
total_cost = 0.0
|
| 159 |
+
assigned_cols = np.full(n2, False)
|
| 160 |
+
for i in range(n1):
|
| 161 |
+
j = row_ind[i]
|
| 162 |
+
if j != -1:
|
| 163 |
+
total_cost += C[i, j]
|
| 164 |
+
assigned_cols[j] = True
|
| 165 |
+
else:
|
| 166 |
+
total_cost += c_node_del
|
| 167 |
+
for j in range(n2):
|
| 168 |
+
if not assigned_cols[j]:
|
| 169 |
+
total_cost += c_node_ins
|
| 170 |
+
return total_cost
|
| 171 |
+
|
| 172 |
+
def approximate_graph_edit_distance(adj1, labels1, adj2, labels2, c_node_del=1.0, c_node_ins=1.0):
|
| 173 |
+
degrees1 = adj1.sum(axis=1)
|
| 174 |
+
degrees2 = adj2.sum(axis=1)
|
| 175 |
+
C = compute_cost_matrix(labels1, labels2, degrees1, degrees2)
|
| 176 |
+
row_ind = greedy_assignment(C)
|
| 177 |
+
total_cost = compute_total_cost(C, row_ind, labels1.shape[0], labels2.shape[0], c_node_del, c_node_ins)
|
| 178 |
+
return total_cost
|
| 179 |
+
|
| 180 |
+
def get_approximate_ged(G1, G2):
|
| 181 |
+
label_to_int = build_label_mapping(G1, G2)
|
| 182 |
+
adj1, labels1 = preprocess_graph(G1, label_to_int)
|
| 183 |
+
adj2, labels2 = preprocess_graph(G2, label_to_int)
|
| 184 |
+
cost = approximate_graph_edit_distance(adj1, labels1, adj2, labels2)
|
| 185 |
+
return cost
|
| 186 |
+
|
| 187 |
+
def get_smiles2graph_edit_distance_approx(smi1: str, smi2: str) -> float:
|
| 188 |
+
G1 = smiles2graph(smi1)
|
| 189 |
+
G2 = smiles2graph(smi2)
|
| 190 |
+
return get_approximate_ged(G1, G2)
|
protac_splitter/llms/__init__.py
ADDED
|
File without changes
|
protac_splitter/llms/data_utils.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from datasets import load_dataset, concatenate_datasets, Dataset
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from rdkit import Chem
|
| 10 |
+
|
| 11 |
+
from protac_splitter.evaluation import split_prediction
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def randomize_smiles_dataset(
|
| 15 |
+
batch: dict,
|
| 16 |
+
repeat: int = 1,
|
| 17 |
+
prob: float = 0.5,
|
| 18 |
+
apply_to_text: bool = True,
|
| 19 |
+
apply_to_labels: bool = False,
|
| 20 |
+
) -> dict:
|
| 21 |
+
""" Randomize SMILES in a batch of data.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
batch (dict): Batch of data with "text" and "labels" keys.
|
| 25 |
+
repeat (int, optional): Number of times to repeat the randomization. Defaults to 1.
|
| 26 |
+
prob (float, optional): Probability of randomizing SMILES. Defaults to 0.5.
|
| 27 |
+
apply_to_text (bool, optional): Whether to apply randomization to text. Defaults to True.
|
| 28 |
+
apply_to_labels (bool, optional): Whether to apply randomization to labels. Defaults to False.
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
dict: Randomized batch of data.
|
| 32 |
+
"""
|
| 33 |
+
new_texts, new_labels = [], []
|
| 34 |
+
for text, label in zip(batch["text"], batch["labels"]):
|
| 35 |
+
try:
|
| 36 |
+
mol_text = Chem.MolFromSmiles(text)
|
| 37 |
+
mol_label = Chem.MolFromSmiles(label)
|
| 38 |
+
except Exception:
|
| 39 |
+
logging.error("Failed to convert SMILES to Mol!")
|
| 40 |
+
new_texts.append(text)
|
| 41 |
+
new_labels.append(label)
|
| 42 |
+
continue
|
| 43 |
+
|
| 44 |
+
if random.random() < prob:
|
| 45 |
+
if apply_to_text:
|
| 46 |
+
rand_texts = [Chem.MolToSmiles(mol_text, canonical=False, doRandom=True) for _ in range(repeat)]
|
| 47 |
+
else:
|
| 48 |
+
rand_texts = [text] * repeat
|
| 49 |
+
|
| 50 |
+
if apply_to_labels:
|
| 51 |
+
rand_labels = [Chem.MolToSmiles(mol_label, canonical=False, doRandom=True) for _ in range(repeat)]
|
| 52 |
+
else:
|
| 53 |
+
rand_labels = [label] * repeat
|
| 54 |
+
|
| 55 |
+
new_texts.extend(rand_texts)
|
| 56 |
+
new_labels.extend(rand_labels)
|
| 57 |
+
else:
|
| 58 |
+
new_texts.append(text)
|
| 59 |
+
new_labels.append(label)
|
| 60 |
+
|
| 61 |
+
return {"text": new_texts, "labels": new_labels}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def process_data_to_model_inputs(
|
| 65 |
+
batch,
|
| 66 |
+
tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
|
| 67 |
+
encoder_max_length: int = 512,
|
| 68 |
+
decoder_max_length: int = 512,
|
| 69 |
+
):
|
| 70 |
+
if isinstance(tokenizer, str):
|
| 71 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 72 |
+
# tokenize the inputs and labels
|
| 73 |
+
inputs = tokenizer(batch["text"], truncation=True, max_length=encoder_max_length)
|
| 74 |
+
outputs = tokenizer(batch["labels"], truncation=True, max_length=decoder_max_length)
|
| 75 |
+
batch["input_ids"] = inputs.input_ids
|
| 76 |
+
batch["attention_mask"] = inputs.attention_mask
|
| 77 |
+
batch["labels"] = outputs.input_ids.copy()
|
| 78 |
+
|
| 79 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 80 |
+
# batch["input_ids"] = batch["input_ids"].to(device)
|
| 81 |
+
# batch["attention_mask"] = batch["attention_mask"].to(device)
|
| 82 |
+
# batch["labels"] = batch["labels"].to(device)
|
| 83 |
+
|
| 84 |
+
# Because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`.
|
| 85 |
+
# We have to make sure that the PAD token is ignored when calculating the loss.
|
| 86 |
+
# NOTE: Check the `ignore_index` argument in nn.CrossEntropyLoss.
|
| 87 |
+
# NOTE: The following is already done in the DataCollatorForSeq2Seq
|
| 88 |
+
# batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]
|
| 89 |
+
return batch
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_fragments_in_labels(labels: str, linkers_only_as_labels: bool = True) -> list[str]:
|
| 93 |
+
""" Get the fragments in the labels.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
labels (str): The labels.
|
| 97 |
+
linkers_only_as_labels (bool, optional): Whether to get only the linkers in the labels. Defaults to True.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
list[str]: The fragments in the labels.
|
| 101 |
+
"""
|
| 102 |
+
ligands = split_prediction(labels)
|
| 103 |
+
if linkers_only_as_labels:
|
| 104 |
+
return ligands.get("linker", None)
|
| 105 |
+
if None in ligands.values():
|
| 106 |
+
return None
|
| 107 |
+
return f"{ligands['e3']}.{ligands['poi']}"
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def load_tokenized_dataset(
|
| 111 |
+
dataset_dir: str,
|
| 112 |
+
dataset_config: str = 'default',
|
| 113 |
+
tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
|
| 114 |
+
batch_size: int = 512,
|
| 115 |
+
encoder_max_length: int = 512,
|
| 116 |
+
decoder_max_length: int = 512,
|
| 117 |
+
token: Optional[str] = None,
|
| 118 |
+
num_proc_map: int = 1,
|
| 119 |
+
randomize_smiles: bool = False,
|
| 120 |
+
randomize_smiles_prob: float = 0.5,
|
| 121 |
+
randomize_smiles_repeat: int = 1,
|
| 122 |
+
randomize_text: bool = True,
|
| 123 |
+
randomize_labels: bool = False,
|
| 124 |
+
cache_dir: Optional[str] = None,
|
| 125 |
+
all_fragments_as_labels: bool = True,
|
| 126 |
+
linkers_only_as_labels: bool = False,
|
| 127 |
+
causal_language_modeling: bool = False,
|
| 128 |
+
train_size_ratio: float = 1.0,
|
| 129 |
+
) -> Dataset:
|
| 130 |
+
""" Load dataset and tokenize it.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
dataset_dir (str): The directory of the dataset or the name of the data on the Hugging Face Hub.
|
| 134 |
+
dataset_config (str, optional): The configuration of the dataset. Defaults to 'default'.
|
| 135 |
+
tokenizer (AutoTokenizer | str, optional): The tokenizer to use for tokenization. If a string, the tokenizer will be loaded using `AutoTokenizer.from_pretrained(tokenizer)`. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
|
| 136 |
+
batch_size (int, optional): The batch size for tokenization. Defaults to 512.
|
| 137 |
+
encoder_max_length (int, optional): The maximum length of the encoder input sequence. Defaults to 512.
|
| 138 |
+
decoder_max_length (int, optional): The maximum length of the decoder input sequence. Defaults to 512.
|
| 139 |
+
token (Optional[str], optional): The Hugging Face API token. Defaults to None.
|
| 140 |
+
num_proc_map (int, optional): The number of processes to use for mapping. Defaults to 1.
|
| 141 |
+
randomize_smiles (bool, optional): Whether to randomize SMILES. Defaults to False.
|
| 142 |
+
randomize_smiles_prob (float, optional): The probability of randomizing SMILES. Defaults to 0.5.
|
| 143 |
+
randomize_smiles_repeat (int, optional): The number of times to repeat the randomization. Defaults to 1.
|
| 144 |
+
randomize_text (bool, optional): Whether to randomize text. Defaults to True.
|
| 145 |
+
randomize_labels (bool, optional): Whether to randomize labels. Defaults to False.
|
| 146 |
+
cache_dir (Optional[str], optional): The directory to cache the dataset. Defaults to None.
|
| 147 |
+
all_fragments_as_labels (bool, optional): Whether to get all fragments in the labels. Defaults to True.
|
| 148 |
+
linkers_only_as_labels (bool, optional): Whether to get only the linkers in the labels. Defaults to False.
|
| 149 |
+
causal_language_modeling (bool, optional): Whether to use causal language modeling. Defaults to False.
|
| 150 |
+
train_size_ratio (float, optional): The ratio of the training dataset to use. Defaults to 1.0.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
Dataset: The tokenized dataset.
|
| 154 |
+
"""
|
| 155 |
+
if isinstance(tokenizer, str):
|
| 156 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 157 |
+
if os.path.exists(dataset_dir):
|
| 158 |
+
# NOTE: We need a different argument to load a dataset from disk:
|
| 159 |
+
dataset = load_dataset(
|
| 160 |
+
dataset_dir,
|
| 161 |
+
data_dir=dataset_config,
|
| 162 |
+
)
|
| 163 |
+
print(f"Dataset loaded from disk at: \"{dataset_dir}\". Length: {dataset.num_rows}")
|
| 164 |
+
else:
|
| 165 |
+
dataset = load_dataset(
|
| 166 |
+
dataset_dir,
|
| 167 |
+
dataset_config,
|
| 168 |
+
token=token,
|
| 169 |
+
cache_dir=cache_dir,
|
| 170 |
+
)
|
| 171 |
+
print(f"Dataset loaded from hub. Length: {dataset.num_rows}")
|
| 172 |
+
|
| 173 |
+
if train_size_ratio < 1.0 and train_size_ratio > 0:
|
| 174 |
+
# Reduce the size of the training dataset but just selecting a fraction of the samples
|
| 175 |
+
dataset["train"] = dataset["train"].select(range(int(train_size_ratio * dataset["train"].num_rows)))
|
| 176 |
+
print(f"Reduced training dataset size to {train_size_ratio}. Length: {dataset.num_rows}")
|
| 177 |
+
elif train_size_ratio > 1.0 or train_size_ratio < 0:
|
| 178 |
+
raise ValueError("train_size_ratio must be between 0 and 1.")
|
| 179 |
+
|
| 180 |
+
if not all_fragments_as_labels:
|
| 181 |
+
dataset = dataset.map(
|
| 182 |
+
lambda x: {
|
| 183 |
+
"text": x["text"],
|
| 184 |
+
"labels": get_fragments_in_labels(x["labels"], linkers_only_as_labels),
|
| 185 |
+
},
|
| 186 |
+
batched=False,
|
| 187 |
+
num_proc=num_proc_map,
|
| 188 |
+
load_from_cache_file=True,
|
| 189 |
+
desc="Getting fragments in labels",
|
| 190 |
+
)
|
| 191 |
+
# Filter out the samples with None labels
|
| 192 |
+
dataset = dataset.filter(lambda x: x["labels"] is not None)
|
| 193 |
+
|
| 194 |
+
if linkers_only_as_labels:
|
| 195 |
+
print(f"Set labels to linkers only. Length: {dataset.num_rows}")
|
| 196 |
+
else:
|
| 197 |
+
print(f"Set labels to E3 and WH only. Length: {dataset.num_rows}")
|
| 198 |
+
|
| 199 |
+
if randomize_smiles:
|
| 200 |
+
dataset["train"] = dataset["train"].map(
|
| 201 |
+
randomize_smiles_dataset,
|
| 202 |
+
batched=True,
|
| 203 |
+
batch_size=batch_size,
|
| 204 |
+
fn_kwargs={
|
| 205 |
+
"repeat": randomize_smiles_repeat,
|
| 206 |
+
"prob": randomize_smiles_prob,
|
| 207 |
+
"apply_to_text": randomize_text,
|
| 208 |
+
"apply_to_labels": randomize_labels,
|
| 209 |
+
},
|
| 210 |
+
num_proc=num_proc_map,
|
| 211 |
+
load_from_cache_file=True,
|
| 212 |
+
desc="Randomizing SMILES",
|
| 213 |
+
)
|
| 214 |
+
print(f"Randomized SMILES in dataset. Length: {dataset.num_rows}")
|
| 215 |
+
|
| 216 |
+
if causal_language_modeling:
|
| 217 |
+
dataset = dataset.map(
|
| 218 |
+
lambda x: {
|
| 219 |
+
"text": x["text"] + "." + x["labels"],
|
| 220 |
+
"labels": x["labels"],
|
| 221 |
+
},
|
| 222 |
+
batched=False,
|
| 223 |
+
num_proc=num_proc_map,
|
| 224 |
+
load_from_cache_file=True,
|
| 225 |
+
desc="Setting labels to text",
|
| 226 |
+
)
|
| 227 |
+
print(f"Appended labels to text. Length: {dataset.num_rows}")
|
| 228 |
+
|
| 229 |
+
# NOTE: Remove the "labels" column if causal language modeling, since the
|
| 230 |
+
# DataCollatorForLM will automatically set the labels to the input_ids.
|
| 231 |
+
dataset = dataset.map(
|
| 232 |
+
process_data_to_model_inputs,
|
| 233 |
+
batched=True,
|
| 234 |
+
batch_size=batch_size,
|
| 235 |
+
remove_columns=["text", "labels"] if causal_language_modeling else ["text"],
|
| 236 |
+
fn_kwargs={
|
| 237 |
+
"tokenizer": tokenizer,
|
| 238 |
+
"encoder_max_length": encoder_max_length,
|
| 239 |
+
"decoder_max_length": decoder_max_length,
|
| 240 |
+
},
|
| 241 |
+
num_proc=num_proc_map,
|
| 242 |
+
load_from_cache_file=True,
|
| 243 |
+
desc="Tokenizing dataset",
|
| 244 |
+
)
|
| 245 |
+
print(f"Tokenized dataset. Length: {dataset.num_rows}")
|
| 246 |
+
|
| 247 |
+
return dataset
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def load_trl_dataset(
|
| 251 |
+
tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
|
| 252 |
+
token: Optional[str] = None,
|
| 253 |
+
max_length: int = 512,
|
| 254 |
+
dataset_name: str = "ailab-bio/PROTAC-Splitter-Dataset",
|
| 255 |
+
ds_config: str = "standard",
|
| 256 |
+
ds_unalabeled: Optional[str] = None,
|
| 257 |
+
) -> Dataset:
|
| 258 |
+
if isinstance(tokenizer, str):
|
| 259 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 260 |
+
# Load training data
|
| 261 |
+
train_dataset = load_dataset(
|
| 262 |
+
dataset_name,
|
| 263 |
+
ds_config,
|
| 264 |
+
split="train",
|
| 265 |
+
token=token,
|
| 266 |
+
)
|
| 267 |
+
train_dataset = train_dataset.rename_column("text", "query")
|
| 268 |
+
train_dataset = train_dataset.remove_columns(["labels"])
|
| 269 |
+
|
| 270 |
+
if ds_unalabeled is not None:
|
| 271 |
+
# Load un-labelled data
|
| 272 |
+
unlabeled_dataset = load_dataset(
|
| 273 |
+
dataset_name,
|
| 274 |
+
ds_unalabeled,
|
| 275 |
+
split="train",
|
| 276 |
+
token=token,
|
| 277 |
+
)
|
| 278 |
+
unlabeled_dataset = unlabeled_dataset.rename_column("text", "query")
|
| 279 |
+
unlabeled_dataset = unlabeled_dataset.remove_columns(["labels"])
|
| 280 |
+
# Concatenate datasets row-wise
|
| 281 |
+
dataset = concatenate_datasets([train_dataset, unlabeled_dataset])
|
| 282 |
+
else:
|
| 283 |
+
dataset = train_dataset
|
| 284 |
+
|
| 285 |
+
def tokenize(sample, tokenizer, max_length=512):
|
| 286 |
+
input_ids = tokenizer.encode(sample["query"], padding="max_length", max_length=max_length)
|
| 287 |
+
return {"input_ids": input_ids, "query": sample["query"]}
|
| 288 |
+
|
| 289 |
+
return dataset.map(lambda x: tokenize(x, tokenizer, max_length), batched=False)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def data_collator_for_trl(batch):
|
| 293 |
+
return {
|
| 294 |
+
"input_ids": [torch.tensor(x["input_ids"]) for x in batch],
|
| 295 |
+
"query": [x["query"] for x in batch],
|
| 296 |
+
}
|
protac_splitter/llms/evaluation.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
from transformers import AutoTokenizer, EvalPrediction
|
| 4 |
+
import numpy as np
|
| 5 |
+
from rdkit import Chem, DataStructs
|
| 6 |
+
import evaluate
|
| 7 |
+
import multiprocessing as mp
|
| 8 |
+
import datetime
|
| 9 |
+
|
| 10 |
+
from protac_splitter.evaluation import (
|
| 11 |
+
# is_valid_smiles,
|
| 12 |
+
# has_three_substructures,
|
| 13 |
+
# has_all_attachment_points,
|
| 14 |
+
# check_substructs,
|
| 15 |
+
score_prediction,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
def process_predictions(args) -> list:
|
| 19 |
+
""" Process one iteration of the prediction scoring.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
args (tuple): Tuple of arguments for the scoring function.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
dict: The scores for the prediction.
|
| 26 |
+
"""
|
| 27 |
+
pred_smiles, protac_smiles, label_smiles, fpgen, compute_rdkit_metrics, compute_graph_metrics = args
|
| 28 |
+
scores = []
|
| 29 |
+
for protac, pred, label in zip(protac_smiles, pred_smiles, label_smiles):
|
| 30 |
+
scores.append(score_prediction(
|
| 31 |
+
protac_smiles=protac,
|
| 32 |
+
label_smiles=label,
|
| 33 |
+
pred_smiles=pred,
|
| 34 |
+
fpgen=fpgen,
|
| 35 |
+
compute_rdkit_metrics=compute_rdkit_metrics,
|
| 36 |
+
compute_graph_metrics=compute_graph_metrics,
|
| 37 |
+
graph_edit_kwargs={"timeout": 0.05},
|
| 38 |
+
))
|
| 39 |
+
return scores
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def decode_and_get_metrics(
|
| 43 |
+
pred: EvalPrediction,
|
| 44 |
+
tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
|
| 45 |
+
rouge = None, # Optional[evaluate.metrics.rouge.Rouge] = None,
|
| 46 |
+
fpgen = None, # Optional[Chem.rdFingerprintGenerator] = None,
|
| 47 |
+
compute_rdkit_metrics: bool = False,
|
| 48 |
+
compute_graph_metrics: bool = True,
|
| 49 |
+
num_proc: int = 1,
|
| 50 |
+
batch_size: int = 128,
|
| 51 |
+
use_nan_for_missing: bool = True,
|
| 52 |
+
causal_language_modeling: bool = False,
|
| 53 |
+
) -> dict[str, float]:
|
| 54 |
+
""" Compute metrics for tokenized PROTAC predictions.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
pred (transformers.EvalPrediction): The predictions from the model.
|
| 58 |
+
rouge (Rouge): The Rouge object to use for scoring. Example: `rouge = evaluate.load("rouge")`
|
| 59 |
+
tokenizer (AutoTokenizer | str): The tokenizer to use for decoding the predictions. If a string, the tokenizer will be loaded using `AutoTokenizer.from_pretrained(tokenizer)`. Default: "seyonec/ChemBERTa-zinc-base-v1"
|
| 60 |
+
fpgen (Chem.rdFingerprintGenerator): The fingerprint generator to use for computing the Tanimoto similarity. Default: `Chem.rdFingerprintGenerator.GetMorganGenerator(radius=8, fpSize=2048)`
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
dict[str, float]: A dictionary containing the scores for the predictions
|
| 64 |
+
"""
|
| 65 |
+
print(f"[{datetime.datetime.now()}] Starting decode_and_get_metrics (protac_splitter/llms/evaluation.py)")
|
| 66 |
+
|
| 67 |
+
if causal_language_modeling:
|
| 68 |
+
# NOTE: For causal language models, we only care about perplexity, so we
|
| 69 |
+
# only need the eval_loss, which is automatically added.
|
| 70 |
+
return {}
|
| 71 |
+
|
| 72 |
+
if isinstance(tokenizer, str):
|
| 73 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 74 |
+
|
| 75 |
+
labels_ids = pred.label_ids
|
| 76 |
+
pred_ids = pred.predictions
|
| 77 |
+
input_ids = pred.inputs
|
| 78 |
+
|
| 79 |
+
if causal_language_modeling:
|
| 80 |
+
# The prediction logits will be of shape: (batch_size, sequence_length, vocabulary_size)
|
| 81 |
+
# So we need to get the argmax of the last dimension to get the
|
| 82 |
+
# predicted token IDs.
|
| 83 |
+
# NOTE: Not exactly the same as what would happen during generation, but
|
| 84 |
+
# hopefully it's close enough to assess model performance during
|
| 85 |
+
# training.
|
| 86 |
+
pred_ids = np.argmax(pred_ids, axis=-1)
|
| 87 |
+
|
| 88 |
+
# Replace -100 in the IDs with the tokenizer pad token id
|
| 89 |
+
# NOTE: Check the `ignore_index` argument in nn.CrossEntropyLoss.
|
| 90 |
+
# TODO: Understand why this needs to be done to the inputs as well
|
| 91 |
+
ignore_index = -100
|
| 92 |
+
labels_ids[labels_ids == ignore_index] = tokenizer.pad_token_id
|
| 93 |
+
pred_ids[pred_ids == ignore_index] = tokenizer.pad_token_id
|
| 94 |
+
|
| 95 |
+
# Get strings from IDs
|
| 96 |
+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
| 97 |
+
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
| 98 |
+
|
| 99 |
+
if not causal_language_modeling:
|
| 100 |
+
input_ids[input_ids == ignore_index] = tokenizer.pad_token_id
|
| 101 |
+
input_str = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
| 102 |
+
else:
|
| 103 |
+
# NOTE: For causal language models, i.e., decoder only, the input PROTAC
|
| 104 |
+
# is in the label. Therefore, we need to decode the label to get the
|
| 105 |
+
# input. The label looks something like "PROTAC.E3.Linker.WH", so we
|
| 106 |
+
# need to split it and get the last (three) parts.
|
| 107 |
+
input_str = [str(s.split('.')[0]) for s in label_str]
|
| 108 |
+
label_str = ['.'.join(s.split('.')[1:]) for s in label_str]
|
| 109 |
+
pred_str = ['.'.join(s.split('.')[1:]) if '.' in s else s for s in pred_str]
|
| 110 |
+
|
| 111 |
+
# Get scores
|
| 112 |
+
if num_proc == 1:
|
| 113 |
+
scores = process_predictions((
|
| 114 |
+
pred_str, input_str, label_str, fpgen, compute_rdkit_metrics, compute_graph_metrics
|
| 115 |
+
))
|
| 116 |
+
else:
|
| 117 |
+
# Use pools to process batches of predictions
|
| 118 |
+
with mp.Pool(processes=num_proc) as pool:
|
| 119 |
+
scores = []
|
| 120 |
+
for i in range(0, len(pred_str), batch_size):
|
| 121 |
+
scores += pool.map(process_predictions, [
|
| 122 |
+
(pred_str[i:i+batch_size], input_str[i:i+batch_size], label_str[i:i+batch_size], fpgen, compute_rdkit_metrics, compute_graph_metrics)
|
| 123 |
+
])
|
| 124 |
+
# Flatten the list of scores
|
| 125 |
+
scores = [s for ls in scores for s in ls]
|
| 126 |
+
|
| 127 |
+
# Aggregate scores
|
| 128 |
+
scores_labels = set()
|
| 129 |
+
for s in scores:
|
| 130 |
+
scores_labels.update(s.keys())
|
| 131 |
+
|
| 132 |
+
aggregated_scores = {}
|
| 133 |
+
for k in scores_labels:
|
| 134 |
+
values = np.array([s.get(k, np.nan) for s in scores], dtype=float)
|
| 135 |
+
|
| 136 |
+
# If values is all NaN, set the aggregated score to NaN and continue
|
| 137 |
+
if np.all(np.isnan(values)):
|
| 138 |
+
aggregated_scores[k] = None
|
| 139 |
+
continue
|
| 140 |
+
|
| 141 |
+
# Compute average, excluding `NaN` values if necessary
|
| 142 |
+
if use_nan_for_missing:
|
| 143 |
+
aggregated_scores[k] = np.nanmean(values)
|
| 144 |
+
else:
|
| 145 |
+
valid_values = values[~np.isnan(values)]
|
| 146 |
+
aggregated_scores[k] = np.mean(valid_values) if valid_values.size > 0 else float('nan')
|
| 147 |
+
|
| 148 |
+
# Get Rouge score
|
| 149 |
+
if rouge is not None:
|
| 150 |
+
rouge_output = rouge.compute(predictions=pred_str, references=label_str)
|
| 151 |
+
aggregated_scores.update({k: v for k, v in rouge_output.items()})
|
| 152 |
+
|
| 153 |
+
# TODO
|
| 154 |
+
# # Get tanimoto score
|
| 155 |
+
# pred_str = np.array(pred_str)[valid_smiles == 1]
|
| 156 |
+
# label_str = np.array(label_str)[valid_smiles == 1]
|
| 157 |
+
# if len(pred_str) == 0:
|
| 158 |
+
# scores['tanimoto'] = 0.0
|
| 159 |
+
# return scores
|
| 160 |
+
# pred_mols = [Chem.MolFromSmiles(s) for s in pred_str]
|
| 161 |
+
# label_mols = [Chem.MolFromSmiles(s) for s in label_str]
|
| 162 |
+
# pred_fps = [fpgen.GetFingerprint(m) for m in pred_mols]
|
| 163 |
+
# label_fps = [fpgen.GetFingerprint(m) for m in label_mols]
|
| 164 |
+
# tanimoto = [DataStructs.TanimotoSimilarity(l, p) for l, p in zip(label_fps, pred_fps)]
|
| 165 |
+
# scores['tanimoto'] = np.array(tanimoto).mean()
|
| 166 |
+
|
| 167 |
+
print(f"[{datetime.datetime.now()}] Done with decode_and_get_metrics (protac_splitter/llms/evaluation.py)")
|
| 168 |
+
|
| 169 |
+
return aggregated_scores
|
protac_splitter/llms/hf_utils.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Hugging Face Hub utilities for repository management and file uploads. """
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import huggingface_hub as hf
|
| 5 |
+
from huggingface_hub import repo_info
|
| 6 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
| 7 |
+
|
| 8 |
+
def repo_exists(repo_id: str, token: Optional[str] = None) -> bool:
|
| 9 |
+
""" Checks if a Hugging Face repository exists. """
|
| 10 |
+
try:
|
| 11 |
+
print(repo_info(repo_id, token=token))
|
| 12 |
+
return True
|
| 13 |
+
except RepositoryNotFoundError:
|
| 14 |
+
return False
|
| 15 |
+
|
| 16 |
+
def create_hf_repository(**kwargs):
|
| 17 |
+
"""Creates a new Hugging Face repository."""
|
| 18 |
+
api = hf.HfApi()
|
| 19 |
+
return api.create_repo(**kwargs)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def delete_hf_repository(**kwargs):
|
| 23 |
+
"""Creates a new Hugging Face repository."""
|
| 24 |
+
print(f'Deleting repository {kwargs["repo_id"]}.')
|
| 25 |
+
api = hf.HfApi()
|
| 26 |
+
return api.delete_repo(**kwargs)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def upload_single_file(**kwargs):
|
| 30 |
+
"""Uploads a single file to a Hugging Face repository."""
|
| 31 |
+
try:
|
| 32 |
+
api = hf.HfApi()
|
| 33 |
+
api.upload_file(**kwargs)
|
| 34 |
+
except Exception as e:
|
| 35 |
+
print(e)
|
| 36 |
+
print("WARNING. Best parameters NOT pushed to the hub.")
|
protac_splitter/llms/model_utils.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Hugging Face utilities for model loading and pipeline creation. """
|
| 2 |
+
from typing import Optional, List, Dict, Union
|
| 3 |
+
from datasets import Dataset
|
| 4 |
+
from transformers import (
|
| 5 |
+
AutoTokenizer,
|
| 6 |
+
EncoderDecoderModel,
|
| 7 |
+
AutoModelForCausalLM,
|
| 8 |
+
pipeline,
|
| 9 |
+
GenerationConfig,
|
| 10 |
+
)
|
| 11 |
+
from transformers.pipelines.pt_utils import KeyDataset
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_encoder_decoder_model(
|
| 17 |
+
pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
|
| 18 |
+
pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
|
| 19 |
+
max_length: Optional[int] = 512,
|
| 20 |
+
tie_encoder_decoder: bool = False,
|
| 21 |
+
) -> EncoderDecoderModel:
|
| 22 |
+
""" Get the EncoderDecoderModel model for the PROTAC splitter.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
pretrained_encoder (str): The pretrained model to use for the encoder. Default: "seyonec/ChemBERTa-zinc-base-v1"
|
| 26 |
+
pretrained_decoder (str): The pretrained model to use for the decoder. Default: "seyonec/ChemBERTa-zinc-base-v1"
|
| 27 |
+
max_length (int): The maximum length of the input sequence. Default: 512
|
| 28 |
+
tie_encoder_decoder (bool): Whether to tie the encoder and decoder weights. Default: False
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
EncoderDecoderModel: The EncoderDecoderModel model for the PROTAC splitter
|
| 32 |
+
"""
|
| 33 |
+
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
|
| 34 |
+
pretrained_encoder,
|
| 35 |
+
pretrained_decoder,
|
| 36 |
+
tie_encoder_decoder=tie_encoder_decoder,
|
| 37 |
+
)
|
| 38 |
+
print(f"Number of parameters: {bert2bert.num_parameters():,}")
|
| 39 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
|
| 40 |
+
# Tokenizer-related configs
|
| 41 |
+
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
| 42 |
+
bert2bert.config.eos_token_id = tokenizer.sep_token_id
|
| 43 |
+
bert2bert.config.pad_token_id = tokenizer.pad_token_id
|
| 44 |
+
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
| 45 |
+
# Generation configs
|
| 46 |
+
# NOTE: See full list of configurations can be found here: https://huggingface.co/docs/transformers/v4.33.3/en/main_classes/text_generation#transformers.GenerationConfig
|
| 47 |
+
bert2bert.encoder.config.max_length = max_length
|
| 48 |
+
bert2bert.decoder.config.max_length = max_length
|
| 49 |
+
|
| 50 |
+
def setup_gen(config):
|
| 51 |
+
config.do_sample = True
|
| 52 |
+
config.num_beams = 5
|
| 53 |
+
config.top_k = 20
|
| 54 |
+
config.max_length = 512
|
| 55 |
+
# config.max_new_tokens = 512
|
| 56 |
+
return config
|
| 57 |
+
|
| 58 |
+
bert2bert.config = setup_gen(bert2bert.config)
|
| 59 |
+
bert2bert.encoder.config = setup_gen(bert2bert.encoder.config)
|
| 60 |
+
bert2bert.decoder.config = setup_gen(bert2bert.decoder.config)
|
| 61 |
+
bert2bert.decoder.config.is_decoder = True
|
| 62 |
+
bert2bert.generation_config = setup_gen(bert2bert.generation_config)
|
| 63 |
+
|
| 64 |
+
# bert2bert.config.do_sample = True
|
| 65 |
+
# bert2bert.config.num_beams = 5
|
| 66 |
+
# bert2bert.config.top_k = 20
|
| 67 |
+
# bert2bert.config.max_length=512
|
| 68 |
+
# bert2bert.config.max_new_tokens=512
|
| 69 |
+
|
| 70 |
+
# bert2bert.generation_config.max_new_tokens = 512
|
| 71 |
+
# bert2bert.generation_config.min_new_tokens = 512
|
| 72 |
+
|
| 73 |
+
# bert2bert.config.max_new_tokens = 514
|
| 74 |
+
# bert2bert.config.early_stopping = True
|
| 75 |
+
# bert2bert.config.length_penalty = 2.0
|
| 76 |
+
# # bert2bert.config.no_repeat_ngram_size = 3 # Default: 0
|
| 77 |
+
|
| 78 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 79 |
+
bert2bert.to(device)
|
| 80 |
+
|
| 81 |
+
return bert2bert
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_causal_model(
|
| 85 |
+
pretrained_model: str = "seyonec/ChemBERTa-zinc-base-v1",
|
| 86 |
+
max_length: Optional[int] = 512,
|
| 87 |
+
) -> AutoModelForCausalLM:
|
| 88 |
+
""" Get the causal language model for the PROTAC splitter.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
pretrained_model (str): The pretrained model to use for the causal language model. Default: "seyonec/ChemBERTa-zinc-base-v1"
|
| 92 |
+
max_length (int): The maximum length of the input sequence. Default: 512
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
AutoModelForCausalLM: The causal language model for the PROTAC splitter
|
| 96 |
+
"""
|
| 97 |
+
model = AutoModelForCausalLM.from_pretrained(pretrained_model, is_decoder=True)
|
| 98 |
+
# model.is_decoder = True # It might not be necessary, but it's good to be explicit
|
| 99 |
+
|
| 100 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 101 |
+
model.to(device)
|
| 102 |
+
|
| 103 |
+
return model
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# REF: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/generation/configuration_utils.py#L71
|
| 107 |
+
GENERATION_STRATEGY_PARAMS = {
|
| 108 |
+
"greedy": {"num_beams": 1, "do_sample": False},
|
| 109 |
+
"contrastive_search": {"penalty_alpha": 0.1, "top_k": 10},
|
| 110 |
+
"multinomial_sampling": {"num_beams": 1, "do_sample": True},
|
| 111 |
+
"beam_search_decoding": {"num_beams": 5, "do_sample": False, "num_return_sequences": 5},
|
| 112 |
+
"beam_search_multinomial_sampling": {"num_beams": 5, "do_sample": True, "num_return_sequences": 5},
|
| 113 |
+
"diverse_beam_search_decoding": {"num_beams": 5, "num_beam_groups": 5, "diversity_penalty": 1.0, "num_return_sequences": 5},
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
def avail_generation_strategies() -> List[str]:
|
| 117 |
+
""" Get the available generation strategies. """
|
| 118 |
+
return list(GENERATION_STRATEGY_PARAMS.keys())
|
| 119 |
+
|
| 120 |
+
def get_generation_config(generation_strategy: str) -> GenerationConfig:
|
| 121 |
+
""" Get the generation config for the given generation strategy. """
|
| 122 |
+
return GenerationConfig(
|
| 123 |
+
max_length=512,
|
| 124 |
+
max_new_tokens=512,
|
| 125 |
+
**GENERATION_STRATEGY_PARAMS[generation_strategy],
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def get_pipeline(
|
| 129 |
+
model_name: str,
|
| 130 |
+
token: str,
|
| 131 |
+
is_causal_language_model: bool,
|
| 132 |
+
generation_strategy: Optional[str] = None,
|
| 133 |
+
num_return_sequences: int = 1,
|
| 134 |
+
device: Optional[Union[int, str]] = None,
|
| 135 |
+
) -> pipeline:
|
| 136 |
+
""" Get the pipeline for the given model name and generation strategy.
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
"""
|
| 141 |
+
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
| 142 |
+
if is_causal_language_model and generation_strategy is None:
|
| 143 |
+
print('Loading pipeline for causal language models...')
|
| 144 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, padding_side='left')
|
| 145 |
+
return pipeline(
|
| 146 |
+
"text-generation",
|
| 147 |
+
model=model_name,
|
| 148 |
+
tokenizer=tokenizer,
|
| 149 |
+
token=token,
|
| 150 |
+
device=device,
|
| 151 |
+
num_return_sequences=num_return_sequences,
|
| 152 |
+
)
|
| 153 |
+
if is_causal_language_model and generation_strategy is not None:
|
| 154 |
+
print('Loading pipeline for causal language models...')
|
| 155 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, padding_side='left')
|
| 156 |
+
return pipeline(
|
| 157 |
+
"text-generation",
|
| 158 |
+
model=model_name,
|
| 159 |
+
tokenizer=tokenizer,
|
| 160 |
+
token=token,
|
| 161 |
+
device=device,
|
| 162 |
+
generation_config=get_generation_config(generation_strategy),
|
| 163 |
+
)
|
| 164 |
+
if not is_causal_language_model and generation_strategy is None:
|
| 165 |
+
print('Loading pipeline for sequence-to-sequence models...')
|
| 166 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
|
| 167 |
+
return pipeline(
|
| 168 |
+
"text2text-generation",
|
| 169 |
+
model=model_name,
|
| 170 |
+
tokenizer=tokenizer,
|
| 171 |
+
token=token,
|
| 172 |
+
device=device,
|
| 173 |
+
)
|
| 174 |
+
if not is_causal_language_model and generation_strategy is not None:
|
| 175 |
+
print('Loading pipeline for sequence-to-sequence models...')
|
| 176 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
|
| 177 |
+
return pipeline(
|
| 178 |
+
"text2text-generation",
|
| 179 |
+
model=model_name,
|
| 180 |
+
tokenizer=tokenizer,
|
| 181 |
+
token=token,
|
| 182 |
+
device=device,
|
| 183 |
+
generation_config=get_generation_config(generation_strategy),
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def run_causal_pipeline(
|
| 187 |
+
pipe: pipeline,
|
| 188 |
+
test_ds: Dataset,
|
| 189 |
+
batch_size: int,
|
| 190 |
+
smiles_column: str = 'prompt',
|
| 191 |
+
) -> List[Dict[str, str]]:
|
| 192 |
+
""" Run the pipeline for causal language models and return the predictions.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
pipe (pipeline): The pipeline object to use for generating predictions.
|
| 196 |
+
test_ds (Dataset): The test dataset to generate predictions for.
|
| 197 |
+
batch_size (int): The batch size to use for generating predictions.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
List[Dict[str, str]]: A list of dictionaries containing the predictions.
|
| 201 |
+
"""
|
| 202 |
+
preds = []
|
| 203 |
+
for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds) // batch_size):
|
| 204 |
+
generated_text = [p['generated_text'] for p in pred]
|
| 205 |
+
# Remove the prompt from the generated text
|
| 206 |
+
generated_text = ['.'.join(t.split('.')[1:]) for t in generated_text]
|
| 207 |
+
# Add the predictions to the list
|
| 208 |
+
p = {f'pred_n{i}': t for i, t in enumerate(generated_text)}
|
| 209 |
+
preds.append(p)
|
| 210 |
+
return preds
|
| 211 |
+
|
| 212 |
+
def run_seq2seq_pipeline(
|
| 213 |
+
pipe: pipeline,
|
| 214 |
+
test_ds: Dataset,
|
| 215 |
+
batch_size: int,
|
| 216 |
+
smiles_column: str = 'text',
|
| 217 |
+
) -> List[Dict[str, str]]:
|
| 218 |
+
""" Run the pipeline for sequence-to-sequence models and return the predictions.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
pipe (pipeline): The pipeline object to use for generating predictions.
|
| 222 |
+
test_ds (Dataset): The test dataset to generate predictions for.
|
| 223 |
+
batch_size (int): The batch size to use for generating predictions.
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
List[Dict[str, str]]: A list of dictionaries containing the predictions.
|
| 227 |
+
"""
|
| 228 |
+
preds = []
|
| 229 |
+
for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds) // batch_size):
|
| 230 |
+
p = {f'pred_n{i}': p['generated_text'] for i, p in enumerate(pred)}
|
| 231 |
+
preds.append(p)
|
| 232 |
+
return preds
|
| 233 |
+
|
| 234 |
+
def run_pipeline(
|
| 235 |
+
pipe: pipeline,
|
| 236 |
+
test_ds: Dataset,
|
| 237 |
+
batch_size: int,
|
| 238 |
+
is_causal_language_model: bool,
|
| 239 |
+
smiles_column: str = 'text',
|
| 240 |
+
) -> List[Dict[str, str]]:
|
| 241 |
+
""" Run the pipeline and return the predictions.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
pipe (pipeline): The pipeline object to use for generating predictions.
|
| 245 |
+
test_ds (Dataset): The test dataset to generate predictions for.
|
| 246 |
+
batch_size (int): The batch size to use for generating predictions.
|
| 247 |
+
is_causal_language_model (bool): Whether the model is a causal language model or not.
|
| 248 |
+
smiles_column (str): The column name in the dataset that contains the SMILES strings. Default: 'text'
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
List[Dict[str, str]]: A list of dictionaries containing the beam-size predictions in the format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
|
| 252 |
+
"""
|
| 253 |
+
if is_causal_language_model:
|
| 254 |
+
return run_causal_pipeline(pipe, test_ds, batch_size, smiles_column)
|
| 255 |
+
else:
|
| 256 |
+
return run_seq2seq_pipeline(pipe, test_ds, batch_size, smiles_column)
|
protac_splitter/llms/training.py
ADDED
|
@@ -0,0 +1,869 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, Dict, Any, Callable, Tuple, Union
|
| 3 |
+
from functools import partial
|
| 4 |
+
import subprocess
|
| 5 |
+
import copy
|
| 6 |
+
import datetime
|
| 7 |
+
import logging
|
| 8 |
+
import math
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
import huggingface_hub as hf
|
| 14 |
+
from transformers import (
|
| 15 |
+
Trainer,
|
| 16 |
+
TrainingArguments,
|
| 17 |
+
Seq2SeqTrainer,
|
| 18 |
+
Seq2SeqTrainingArguments,
|
| 19 |
+
DataCollatorForSeq2Seq,
|
| 20 |
+
DataCollatorForLanguageModeling,
|
| 21 |
+
AutoTokenizer,
|
| 22 |
+
GenerationConfig,
|
| 23 |
+
TrainerCallback,
|
| 24 |
+
set_seed,
|
| 25 |
+
)
|
| 26 |
+
from accelerate.utils import write_basic_config
|
| 27 |
+
from accelerate import Accelerator
|
| 28 |
+
|
| 29 |
+
import optuna
|
| 30 |
+
from optuna.samplers import QMCSampler
|
| 31 |
+
from optuna.pruners import (
|
| 32 |
+
BasePruner,
|
| 33 |
+
HyperbandPruner,
|
| 34 |
+
ThresholdPruner,
|
| 35 |
+
PatientPruner,
|
| 36 |
+
MedianPruner,
|
| 37 |
+
)
|
| 38 |
+
from optuna.study._study_direction import StudyDirection
|
| 39 |
+
|
| 40 |
+
from .data_utils import load_tokenized_dataset
|
| 41 |
+
from .evaluation import decode_and_get_metrics
|
| 42 |
+
from .hf_utils import (
|
| 43 |
+
create_hf_repository,
|
| 44 |
+
delete_hf_repository,
|
| 45 |
+
repo_exists,
|
| 46 |
+
upload_single_file,
|
| 47 |
+
)
|
| 48 |
+
from .model_utils import get_encoder_decoder_model, get_causal_model
|
| 49 |
+
|
| 50 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use GPU with index 0
|
| 51 |
+
# logging.basicConfig(level=logging.DEBUG)
|
| 52 |
+
|
| 53 |
+
class PrintStepCallback(TrainerCallback):
|
| 54 |
+
|
| 55 |
+
def on_init_end(self, args, state, control, **kwargs):
|
| 56 |
+
print(f"[{datetime.datetime.now()}] Initialization complete. Training is starting.")
|
| 57 |
+
|
| 58 |
+
def on_step_begin(self, args, state, control, **kwargs):
|
| 59 |
+
if state.global_step % args.logging_steps == 0:
|
| 60 |
+
print(f"[{datetime.datetime.now()}] Global step: {state.global_step:,}")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class ScoreMetric:
|
| 64 |
+
|
| 65 |
+
def __init__(self):
|
| 66 |
+
self.batch_scores = []
|
| 67 |
+
|
| 68 |
+
def update(self, scores):
|
| 69 |
+
self.batch_scores.append(scores)
|
| 70 |
+
|
| 71 |
+
def compute(self):
|
| 72 |
+
all_labels = set()
|
| 73 |
+
for scores in self.batch_scores:
|
| 74 |
+
all_labels.update(scores.keys())
|
| 75 |
+
|
| 76 |
+
aggregate_scores = {}
|
| 77 |
+
for k in all_labels:
|
| 78 |
+
scores = [s.get(k, np.nan) for s in self.batch_scores]
|
| 79 |
+
print(f"{k}: {np.nanmean(scores):.4f}")
|
| 80 |
+
aggregate_scores[k] = np.nanmean(scores)
|
| 81 |
+
|
| 82 |
+
self.batch_scores = []
|
| 83 |
+
return aggregate_scores
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
score_metric = ScoreMetric()
|
| 87 |
+
hp_score_metric = ScoreMetric()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class WrappedEarlyStoppingPruner(BasePruner):
|
| 91 |
+
"""
|
| 92 |
+
Pruner that wraps another pruner and checks if the trial should be pruned.
|
| 93 |
+
It first evaluates the wrapped pruner and, if the wrapped pruner suggests
|
| 94 |
+
pruning, prune. Otherwise, evaluates based on a patience threshold with a
|
| 95 |
+
tolerance (min_delta) and eventually prunes.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
wrapped_pruner:
|
| 99 |
+
Wrapped pruner to check first. Pruning is only applied if this pruner recommends it.
|
| 100 |
+
patience:
|
| 101 |
+
Number of steps to wait for an improvement before pruning.
|
| 102 |
+
min_delta:
|
| 103 |
+
Minimum improvement required to reset patience.
|
| 104 |
+
n_warmup_steps:
|
| 105 |
+
Number of initial steps to skip the patience check.
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
wrapped_pruner: BasePruner,
|
| 111 |
+
patience: int,
|
| 112 |
+
min_delta: float = 0.0,
|
| 113 |
+
n_warmup_steps: int = 0,
|
| 114 |
+
) -> None:
|
| 115 |
+
if wrapped_pruner is None or not isinstance(wrapped_pruner, BasePruner):
|
| 116 |
+
raise ValueError(f"wrapped_pruner must be an instance of BasePruner but got {wrapped_pruner}.")
|
| 117 |
+
if patience < 0:
|
| 118 |
+
raise ValueError(f"patience cannot be negative but got {patience}.")
|
| 119 |
+
if min_delta < 0:
|
| 120 |
+
raise ValueError(f"min_delta cannot be negative but got {min_delta}.")
|
| 121 |
+
if n_warmup_steps < 0:
|
| 122 |
+
raise ValueError(f"n_warmup_steps cannot be negative but got {n_warmup_steps}.")
|
| 123 |
+
|
| 124 |
+
self._wrapped_pruner = wrapped_pruner
|
| 125 |
+
self._patience = patience
|
| 126 |
+
self._min_delta = min_delta
|
| 127 |
+
self._n_warmup_steps = n_warmup_steps
|
| 128 |
+
|
| 129 |
+
def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
|
| 130 |
+
step = trial.last_step
|
| 131 |
+
if step is None:
|
| 132 |
+
return False
|
| 133 |
+
|
| 134 |
+
intermediate_values = trial.intermediate_values
|
| 135 |
+
steps = np.asarray(list(intermediate_values.keys()))
|
| 136 |
+
|
| 137 |
+
# If there are insufficient steps or we are still in the warmup phase, do not prune.
|
| 138 |
+
if steps.size <= self._patience + 1 or step < self._n_warmup_steps:
|
| 139 |
+
return False
|
| 140 |
+
|
| 141 |
+
# First, check the wrapped pruner. If it suggests pruning, prune.
|
| 142 |
+
if self._wrapped_pruner.prune(study, trial):
|
| 143 |
+
return True
|
| 144 |
+
|
| 145 |
+
steps.sort()
|
| 146 |
+
|
| 147 |
+
# This is the score patience steps ago
|
| 148 |
+
steps_before_patience = steps[: -self._patience - 1]
|
| 149 |
+
scores_before_patience = np.asarray(
|
| 150 |
+
list(intermediate_values[step] for step in steps_before_patience)
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# And these are the scores after that
|
| 154 |
+
steps_after_patience = steps[-self._patience - 1 :]
|
| 155 |
+
scores_after_patience = np.asarray(
|
| 156 |
+
list(intermediate_values[step] for step in steps_after_patience)
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
direction = study.direction
|
| 160 |
+
if direction == StudyDirection.MINIMIZE:
|
| 161 |
+
should_prune = np.nanmin(scores_before_patience) + self._min_delta < np.nanmin(
|
| 162 |
+
scores_after_patience
|
| 163 |
+
)
|
| 164 |
+
else:
|
| 165 |
+
should_prune = np.nanmax(scores_before_patience) - self._min_delta > np.nanmax(
|
| 166 |
+
scores_after_patience
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
return should_prune
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_lr_scheduler_kwargs(lr_scheduler_type: str) -> Dict[str, Any]:
|
| 173 |
+
""" Returns the default learning rate scheduler kwargs for a given type.
|
| 174 |
+
|
| 175 |
+
Reference: https://huggingface.co/docs/timm/en/reference/schedulers
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
lr_scheduler_type (str): The type of the learning rate scheduler.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
Dict[str, Any]: The default learning rate scheduler kwargs.
|
| 182 |
+
"""
|
| 183 |
+
if lr_scheduler_type == "cosine":
|
| 184 |
+
return {}
|
| 185 |
+
elif lr_scheduler_type == "cosine_with_restarts":
|
| 186 |
+
return {"num_cycles": 3}
|
| 187 |
+
elif lr_scheduler_type == "cosine_with_min_lr":
|
| 188 |
+
return {}
|
| 189 |
+
elif lr_scheduler_type == "polynomial":
|
| 190 |
+
return {"power": 1.0}
|
| 191 |
+
elif lr_scheduler_type == "reduce_lr_on_plateau":
|
| 192 |
+
return {"min_lr": 1e-6}
|
| 193 |
+
else:
|
| 194 |
+
raise ValueError(f"Unknown learning rate scheduler type: '{lr_scheduler_type}'")
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_best_hyperparameters(
|
| 198 |
+
model_init: Callable,
|
| 199 |
+
tokenizer: AutoTokenizer,
|
| 200 |
+
data_collator: Union[DataCollatorForSeq2Seq, DataCollatorForLanguageModeling],
|
| 201 |
+
compute_metrics: Callable,
|
| 202 |
+
dataset_tokenized: Dict[str, Any],
|
| 203 |
+
training_args: Dict[str, Any],
|
| 204 |
+
num_optuna_trials: int,
|
| 205 |
+
lr_scheduler_type: Optional[str] = None,
|
| 206 |
+
causal_language_modeling: bool = False,
|
| 207 |
+
all_fragments_as_labels: bool = True,
|
| 208 |
+
linkers_only_as_labels: bool = False,
|
| 209 |
+
) -> Tuple[float, Dict[str, Any], Dict[str, Any]]:
|
| 210 |
+
"""Runs an Optuna hyperparameter search to find the best hyperparameters.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
model_init (Callable): The model initialization function.
|
| 214 |
+
tokenizer (AutoTokenizer): The tokenizer.
|
| 215 |
+
data_collator (DataCollatorForSeq2Seq): The data collator.
|
| 216 |
+
compute_metrics (Callable): The compute metrics function.
|
| 217 |
+
dataset_tokenized (Dict[str, Any]): The tokenized dataset.
|
| 218 |
+
training_args (Dict[str, Any]): The training arguments.
|
| 219 |
+
num_optuna_trials (int): The number of Optuna trials.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
Tuple[float, Dict[str, Any], Dict[str, Any]]: The best objective, the best hyperparameters, and the best training arguments.
|
| 223 |
+
"""
|
| 224 |
+
def optuna_hp_space(trial):
|
| 225 |
+
# NOTE: Tuning generation config is not implemented yet, please refer to this issue: https://github.com/huggingface/transformers/issues/33755
|
| 226 |
+
# Suggest hparams "shared" across all scheduler types
|
| 227 |
+
# learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
|
| 228 |
+
# warmup_ratio = trial.suggest_float("warmup_ratio", 0.01, 0.1, step=0.01)
|
| 229 |
+
|
| 230 |
+
# Restrict learning rate closer to best-performing values
|
| 231 |
+
learning_rate = trial.suggest_float("learning_rate", 5e-6, 2e-4, log=True) # Previously 1e-6 to 1e-3
|
| 232 |
+
|
| 233 |
+
# Slightly adjust warmup ratio to avoid extreme values
|
| 234 |
+
warmup_ratio = trial.suggest_float("warmup_ratio", 0.02, 0.06, step=0.01) # Previously 0.01 to 0.1
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# NOTE: We might want to use QMCSampler instead of TPESampler, which
|
| 238 |
+
# doesn't support categorical parameters. Categories can be encoded as
|
| 239 |
+
# integers and then decoded back to the original categories.
|
| 240 |
+
|
| 241 |
+
# NOTE: According to the GitHub code, the number of training and warmup
|
| 242 |
+
# steps for the scheduler types are automatically set, we don't need to
|
| 243 |
+
# pass them in the lr_scheduler_kwargs.
|
| 244 |
+
|
| 245 |
+
if lr_scheduler_type is None:
|
| 246 |
+
lr_scheduler_types = ["cosine", "cosine_with_restarts", "reduce_lr_on_plateau"] # "cosine_with_min_lr", "polynomial"
|
| 247 |
+
suggested_lr_sched = trial.suggest_int("lr_scheduler_type", 0, len(lr_scheduler_types) - 1)
|
| 248 |
+
suggested_lr_sched = lr_scheduler_types[suggested_lr_sched]
|
| 249 |
+
lr_scheduler_kwargs = get_lr_scheduler_kwargs(lr_scheduler_type)
|
| 250 |
+
elif lr_scheduler_type == "cosine":
|
| 251 |
+
lr_scheduler_kwargs = {
|
| 252 |
+
"num_cycles": trial.suggest_float("num_cycles", 0.5, 10, step=0.5),
|
| 253 |
+
}
|
| 254 |
+
elif lr_scheduler_type == "cosine_with_restarts":
|
| 255 |
+
lr_scheduler_kwargs = {
|
| 256 |
+
"num_cycles": trial.suggest_int("num_cycles", 1, 10, step=1),
|
| 257 |
+
}
|
| 258 |
+
elif lr_scheduler_type == "reduce_lr_on_plateau":
|
| 259 |
+
lr_scheduler_kwargs = {
|
| 260 |
+
"min_lr": trial.suggest_float("min_lr", 1e-10, 1e-8, log=True), # Previously 1e-12 to 1e-9
|
| 261 |
+
"factor": trial.suggest_float("factor", 0.8, 0.98, step=0.01), # Previously 0.1 to 0.99
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
return {
|
| 265 |
+
"lr_scheduler_kwargs": lr_scheduler_kwargs,
|
| 266 |
+
"lr_scheduler_type": lr_scheduler_type if lr_scheduler_type is not None else suggested_lr_sched,
|
| 267 |
+
"learning_rate": learning_rate,
|
| 268 |
+
"warmup_ratio": warmup_ratio,
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
if causal_language_modeling:
|
| 272 |
+
def compute_objective(metrics: Dict[str, float]):
|
| 273 |
+
# NOTE: We want to minimize the model perplexity, which is the
|
| 274 |
+
# exponential of the negative log-likelihood loss. Optuna is setup
|
| 275 |
+
# to maximize the objective, so we return the negative perplexity.
|
| 276 |
+
return -math.exp(metrics["eval_loss"])
|
| 277 |
+
else:
|
| 278 |
+
if all_fragments_as_labels:
|
| 279 |
+
def compute_objective(metrics: Dict[str, float]):
|
| 280 |
+
# NOTE: Having a higher eval_reassembly score should also correspond
|
| 281 |
+
# to a low eval loss, so we just focus on the reassembly score.
|
| 282 |
+
return metrics["eval_all_ligands_equal"]
|
| 283 |
+
else:
|
| 284 |
+
if linkers_only_as_labels:
|
| 285 |
+
def compute_objective(metrics: Dict[str, float]):
|
| 286 |
+
return metrics["eval_linker_equal"]
|
| 287 |
+
else:
|
| 288 |
+
def compute_objective(metrics: Dict[str, float]):
|
| 289 |
+
return metrics["eval_e3_equal"] + metrics["eval_poi_equal"]
|
| 290 |
+
|
| 291 |
+
def hp_name(trial: Any) -> str:
|
| 292 |
+
trial_name = f"trial-number={trial.number}"
|
| 293 |
+
for hparam, value in trial.params.items():
|
| 294 |
+
# Check if the value is a float and round it to 3 decimals
|
| 295 |
+
if hparam == "learning_rate":
|
| 296 |
+
value = f"{value:.1e}"
|
| 297 |
+
elif isinstance(value, float):
|
| 298 |
+
value = f"{value:.3f}"
|
| 299 |
+
trial_name += f"-{hparam}={value}"
|
| 300 |
+
return trial_name
|
| 301 |
+
|
| 302 |
+
# Override the training steps
|
| 303 |
+
hp_training_args = copy.deepcopy(training_args)
|
| 304 |
+
hp_training_args["num_train_epochs"] = -1
|
| 305 |
+
hp_training_args["max_steps"] = 10_000
|
| 306 |
+
hp_training_args["eval_steps"] = 2500
|
| 307 |
+
hp_training_args["eval_delay"] = 5000 # TODO: Double check if this is needed
|
| 308 |
+
hp_training_args["logging_steps"] = 500
|
| 309 |
+
hp_training_args["save_steps"] = 5000
|
| 310 |
+
if not causal_language_modeling:
|
| 311 |
+
# Use greedy decoding for the evaluation during HP search
|
| 312 |
+
hp_training_args["generation_config"] = GenerationConfig(
|
| 313 |
+
max_length=512,
|
| 314 |
+
max_new_tokens=512,
|
| 315 |
+
do_sample=False,
|
| 316 |
+
num_beams=1,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
print("Hyperparameter search training arguments:")
|
| 320 |
+
for k, v in hp_training_args.items():
|
| 321 |
+
if 'token' in k:
|
| 322 |
+
continue
|
| 323 |
+
print(f" - {k}: {v}")
|
| 324 |
+
|
| 325 |
+
if causal_language_modeling:
|
| 326 |
+
TrainerClass = Trainer
|
| 327 |
+
TrainingArgumentsClass = TrainingArguments
|
| 328 |
+
else:
|
| 329 |
+
TrainerClass = Seq2SeqTrainer
|
| 330 |
+
TrainingArgumentsClass = Seq2SeqTrainingArguments
|
| 331 |
+
|
| 332 |
+
# Setup a "fake" Trainer for the hyperparameter search
|
| 333 |
+
trainer = TrainerClass(
|
| 334 |
+
model_init=model_init,
|
| 335 |
+
tokenizer=tokenizer,
|
| 336 |
+
data_collator=data_collator,
|
| 337 |
+
args=TrainingArgumentsClass(**hp_training_args),
|
| 338 |
+
compute_metrics=compute_metrics,
|
| 339 |
+
train_dataset=dataset_tokenized["train"],
|
| 340 |
+
eval_dataset=dataset_tokenized["validation"],
|
| 341 |
+
callbacks=[PrintStepCallback],
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Setup the Optuna pruner and sampler
|
| 345 |
+
max_warmup_ratio = 0.1
|
| 346 |
+
pruner = WrappedEarlyStoppingPruner(
|
| 347 |
+
MedianPruner(
|
| 348 |
+
n_startup_trials=0,
|
| 349 |
+
interval_steps=1,
|
| 350 |
+
n_warmup_steps=int(max_warmup_ratio * hp_training_args["max_steps"]),
|
| 351 |
+
),
|
| 352 |
+
patience=5, # Check every 5000 training steps
|
| 353 |
+
min_delta=0.01,
|
| 354 |
+
n_warmup_steps=int(max_warmup_ratio * hp_training_args["max_steps"]),
|
| 355 |
+
)
|
| 356 |
+
sampler = QMCSampler(scramble=True, seed=42)
|
| 357 |
+
|
| 358 |
+
# NOTE: The Trainer will return a BestRun object, not the Optuna trial
|
| 359 |
+
best_run = trainer.hyperparameter_search(
|
| 360 |
+
direction="maximize",
|
| 361 |
+
backend="optuna",
|
| 362 |
+
hp_space=optuna_hp_space,
|
| 363 |
+
hp_name=hp_name,
|
| 364 |
+
n_trials=num_optuna_trials,
|
| 365 |
+
compute_objective=compute_objective, # Default: Will sum over all metrics but loss
|
| 366 |
+
sampler=sampler,
|
| 367 |
+
pruner=pruner,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# Set the best hyperparameters in the original Trainer arguments
|
| 371 |
+
try:
|
| 372 |
+
print("-" * 80)
|
| 373 |
+
print(f"Best trial objective: {best_run.objective:.4f}. Summary: {best_run.run_summary}")
|
| 374 |
+
except Exception as e:
|
| 375 |
+
print(e)
|
| 376 |
+
print("WARNING. Best trial objective could not be printed.")
|
| 377 |
+
|
| 378 |
+
return best_run, hp_training_args
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def train_model(
|
| 382 |
+
model_id: str,
|
| 383 |
+
ds_name: str,
|
| 384 |
+
ds_config: str = 'default',
|
| 385 |
+
learning_rate: float = 5e-5,
|
| 386 |
+
max_steps: int = -1,
|
| 387 |
+
num_train_epochs: int = 40,
|
| 388 |
+
batch_size: int = 128,
|
| 389 |
+
batch_size_tokenizer: int = 512,
|
| 390 |
+
gradient_accumulation_steps: int = 4,
|
| 391 |
+
hub_token: Optional[str] = None,
|
| 392 |
+
organization: Optional[str] = None,
|
| 393 |
+
output_dir: str = "./models/",
|
| 394 |
+
tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
|
| 395 |
+
pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
|
| 396 |
+
pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
|
| 397 |
+
encoder_max_length: int = 512,
|
| 398 |
+
decoder_max_length: int = 512,
|
| 399 |
+
tie_encoder_decoder: bool = False,
|
| 400 |
+
delete_repo_if_exists: bool = False,
|
| 401 |
+
delete_local_repo_if_exists: bool = False,
|
| 402 |
+
training_args: Optional[Dict[str, Any]] = None,
|
| 403 |
+
resume_from_checkpoint: Optional[str] = None,
|
| 404 |
+
num_optuna_trials: int = 0,
|
| 405 |
+
num_proc_map: int = 1,
|
| 406 |
+
per_device_train_batch_size: Optional[int] = None,
|
| 407 |
+
per_device_eval_batch_size: Optional[int] = None,
|
| 408 |
+
lr_scheduler_type: Optional[str] = None,
|
| 409 |
+
cache_dir: Optional[str] = None,
|
| 410 |
+
randomize_smiles: bool = False,
|
| 411 |
+
randomize_smiles_prob: float = 0.0,
|
| 412 |
+
all_fragments_as_labels: bool = True,
|
| 413 |
+
linkers_only_as_labels: bool = False,
|
| 414 |
+
warmup_ratio: Optional[float] = None,
|
| 415 |
+
num_cycles: Optional[int] = None,
|
| 416 |
+
warmup_steps: Optional[int] = None,
|
| 417 |
+
causal_language_modeling: bool = False,
|
| 418 |
+
train_size_ratio: float = 1.0,
|
| 419 |
+
training_args_bin: Optional[str] = None,
|
| 420 |
+
):
|
| 421 |
+
"""Trains a model on a given dataset.
|
| 422 |
+
|
| 423 |
+
Args:
|
| 424 |
+
model_id (str): The name of the model to be trained.
|
| 425 |
+
ds_name (str): The name of the dataset to be used for training.
|
| 426 |
+
ds_config (str, optional): The name of the dataset configuration to be used for training. Defaults to 'default'.
|
| 427 |
+
learning_rate (float, optional): The learning rate. Defaults to 5e-5.
|
| 428 |
+
max_steps (int, optional): The maximum number of training steps. Defaults to -1.
|
| 429 |
+
num_train_epochs (int, optional): The number of training epochs. Defaults to 40.
|
| 430 |
+
batch_size (int, optional): The batch size. Defaults to 128.
|
| 431 |
+
batch_size_tokenizer (int, optional): The batch size for the tokenizer. Defaults to 512.
|
| 432 |
+
gradient_accumulation_steps (int, optional): The number of gradient accumulation steps. Defaults to 4.
|
| 433 |
+
hub_token (Optional[str], optional): The Hugging Face token. Defaults to None.
|
| 434 |
+
organization (Optional[str], optional): The Hugging Face organization. Defaults to None.
|
| 435 |
+
output_dir (str, optional): The output directory. Defaults to "./models/".
|
| 436 |
+
tokenizer (AutoTokenizer | str, optional): The tokenizer. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
|
| 437 |
+
pretrained_encoder (str, optional): The name of the pretrained encoder. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
|
| 438 |
+
pretrained_decoder (str, optional): The name of the pretrained decoder. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
|
| 439 |
+
encoder_max_length (int, optional): The maximum length of the encoder. Defaults to 256.
|
| 440 |
+
decoder_max_length (int, optional): The maximum length of the decoder. Defaults to 256.
|
| 441 |
+
delete_repo_if_exists (bool, optional): Whether to delete the repository first. Defaults to False.
|
| 442 |
+
training_args (Optional[Seq2SeqTrainingArguments], optional): The training arguments. Defaults to None.
|
| 443 |
+
resume_from_checkpoint (Optional[str], optional): The checkpoint to resume training from. Defaults to None.
|
| 444 |
+
num_optuna_trials (int, optional): The number of Optuna trials. Defaults to 0, i.e., no Optuna hyperparameter search.
|
| 445 |
+
"""
|
| 446 |
+
set_seed(42)
|
| 447 |
+
|
| 448 |
+
# if torch.cuda.is_available():
|
| 449 |
+
# write_basic_config(mixed_precision='fp16')
|
| 450 |
+
accelerator = Accelerator()
|
| 451 |
+
accelerator.print(f"Accelerator state from the current environment:\n{accelerator.state}")
|
| 452 |
+
|
| 453 |
+
# Check if resume_from_checkpoint exists and it's a file
|
| 454 |
+
if resume_from_checkpoint is not None:
|
| 455 |
+
# Check if the checkpoint exists: it can be either a file or a directory
|
| 456 |
+
if not os.path.exists(resume_from_checkpoint):
|
| 457 |
+
raise ValueError(f"Checkpoint file '{resume_from_checkpoint}' does not exist.")
|
| 458 |
+
|
| 459 |
+
if hub_token is not None:
|
| 460 |
+
hf.login(token=hub_token)
|
| 461 |
+
|
| 462 |
+
# Setup output directory and Hugging Face repository
|
| 463 |
+
output_dir += f"/{model_id}"
|
| 464 |
+
if organization is not None:
|
| 465 |
+
hub_model_id = f"{organization}/{model_id}"
|
| 466 |
+
if delete_local_repo_if_exists and os.path.exists(output_dir):
|
| 467 |
+
subprocess.run(["rm", "-rf", output_dir])
|
| 468 |
+
if not os.path.exists(output_dir):
|
| 469 |
+
print(f"Local repository '{output_dir}' deleted.")
|
| 470 |
+
else:
|
| 471 |
+
print(f"Local repository '{output_dir}' could not be deleted.")
|
| 472 |
+
return
|
| 473 |
+
if delete_repo_if_exists and repo_exists(hub_model_id, token=hub_token):
|
| 474 |
+
delete_hf_repository(repo_id=hub_model_id, token=hub_token, missing_ok=True)
|
| 475 |
+
print(f"Repository '{hub_model_id}' deleted.")
|
| 476 |
+
|
| 477 |
+
repo_url = create_hf_repository(
|
| 478 |
+
repo_id=hub_model_id,
|
| 479 |
+
repo_type="model",
|
| 480 |
+
exist_ok=True,
|
| 481 |
+
private=True,
|
| 482 |
+
token=hub_token,
|
| 483 |
+
)
|
| 484 |
+
print(f"Repository '{hub_model_id}' created at URL: {repo_url}")
|
| 485 |
+
else:
|
| 486 |
+
hub_model_id = None
|
| 487 |
+
print(f"Hub model ID: {hub_model_id}")
|
| 488 |
+
|
| 489 |
+
if isinstance(tokenizer, str):
|
| 490 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 491 |
+
elif tokenizer is None:
|
| 492 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
|
| 493 |
+
|
| 494 |
+
# Load the tokenized dataset
|
| 495 |
+
print("Loading tokenized dataset.")
|
| 496 |
+
dataset_tokenized = load_tokenized_dataset(
|
| 497 |
+
ds_name,
|
| 498 |
+
ds_config,
|
| 499 |
+
tokenizer,
|
| 500 |
+
batch_size_tokenizer,
|
| 501 |
+
encoder_max_length,
|
| 502 |
+
decoder_max_length,
|
| 503 |
+
token=hub_token,
|
| 504 |
+
num_proc_map=num_proc_map,
|
| 505 |
+
cache_dir=cache_dir,
|
| 506 |
+
randomize_smiles=randomize_smiles,
|
| 507 |
+
randomize_smiles_prob=randomize_smiles_prob,
|
| 508 |
+
all_fragments_as_labels=all_fragments_as_labels,
|
| 509 |
+
linkers_only_as_labels=linkers_only_as_labels,
|
| 510 |
+
causal_language_modeling=causal_language_modeling,
|
| 511 |
+
train_size_ratio=train_size_ratio,
|
| 512 |
+
)
|
| 513 |
+
print("Dataset loaded.")
|
| 514 |
+
|
| 515 |
+
if causal_language_modeling:
|
| 516 |
+
# Setup the model for `model_init` in the Trainer
|
| 517 |
+
model_lambda = lambda: get_causal_model(
|
| 518 |
+
pretrained_model=pretrained_decoder,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Setup the data collator, which will efficiently pad the inputs and targets
|
| 522 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 523 |
+
tokenizer,
|
| 524 |
+
mlm=False,
|
| 525 |
+
pad_to_multiple_of=8, # Default: None, Original: 8
|
| 526 |
+
)
|
| 527 |
+
else:
|
| 528 |
+
# Precompute a "length" column for the dataset using the map function
|
| 529 |
+
def add_length(x):
|
| 530 |
+
x["length"] = len(x["input_ids"])
|
| 531 |
+
return x
|
| 532 |
+
dataset_tokenized = dataset_tokenized.map(
|
| 533 |
+
add_length,
|
| 534 |
+
num_proc=num_proc_map,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
# Setup the model for `model_init` in the Trainer
|
| 538 |
+
model_lambda = lambda: get_encoder_decoder_model(
|
| 539 |
+
pretrained_encoder=pretrained_encoder,
|
| 540 |
+
pretrained_decoder=pretrained_decoder,
|
| 541 |
+
max_length=encoder_max_length,
|
| 542 |
+
tie_encoder_decoder=tie_encoder_decoder,
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
# Setup the data collator, which will efficiently pad the inputs and targets
|
| 546 |
+
data_collator = DataCollatorForSeq2Seq(
|
| 547 |
+
tokenizer,
|
| 548 |
+
model=model_lambda(),
|
| 549 |
+
pad_to_multiple_of=32, # Default: None, Original: 8
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
# Setup the training arguments
|
| 553 |
+
if per_device_train_batch_size is None:
|
| 554 |
+
per_device_train_batch_size = batch_size // gradient_accumulation_steps
|
| 555 |
+
if per_device_eval_batch_size is None:
|
| 556 |
+
per_device_eval_batch_size = batch_size // gradient_accumulation_steps
|
| 557 |
+
if training_args is None:
|
| 558 |
+
training_args = {
|
| 559 |
+
"output_dir": output_dir,
|
| 560 |
+
# Optimizer-related configs
|
| 561 |
+
"learning_rate": learning_rate,
|
| 562 |
+
"optim": "adamw_torch",
|
| 563 |
+
"lr_scheduler_type": "cosine" if lr_scheduler_type is None else lr_scheduler_type,
|
| 564 |
+
"lr_scheduler_kwargs": get_lr_scheduler_kwargs(lr_scheduler_type),
|
| 565 |
+
# "warmup_steps": int(0.08 * 10_000), # NOTE: ChemFormer: 8000
|
| 566 |
+
# "warmup_ratio": warmup_ratio,
|
| 567 |
+
"adam_beta1": 0.9, # NOTE: ChemFormer: 0.9
|
| 568 |
+
"adam_beta2": 0.999, # NOTE: ChemFormer: 0.999
|
| 569 |
+
"adam_epsilon": 1e-8, # Default: 1e-8
|
| 570 |
+
# Batch size, device, and performance optimizations configs
|
| 571 |
+
"batch_eval_metrics": False, # Default: False
|
| 572 |
+
"group_by_length": True,
|
| 573 |
+
"per_device_train_batch_size": per_device_train_batch_size,
|
| 574 |
+
"per_device_eval_batch_size": per_device_eval_batch_size,
|
| 575 |
+
"gradient_accumulation_steps": gradient_accumulation_steps,
|
| 576 |
+
"auto_find_batch_size": True,
|
| 577 |
+
"fp16": True if torch.cuda.is_available() else False,
|
| 578 |
+
"fp16_full_eval" : True, # Enable full BF16 evaluation for efficiency
|
| 579 |
+
"half_precision_backend" : "auto", # Let Hugging Face decide the best backend. Default: "auto"
|
| 580 |
+
"use_cpu": False, # Default: False
|
| 581 |
+
"dataloader_num_workers": 8, # Default: 0 (main process only)
|
| 582 |
+
"dataloader_prefetch_factor": None, # Default: None
|
| 583 |
+
# Evaluation and checkpointing configs
|
| 584 |
+
"max_steps": max_steps,
|
| 585 |
+
"num_train_epochs": num_train_epochs,
|
| 586 |
+
"save_steps": 20_000, # NOTE: 200
|
| 587 |
+
"save_strategy": "steps",
|
| 588 |
+
"eval_steps": 20_000, # NOTE: 500
|
| 589 |
+
"eval_delay": max(int(max(max_steps, num_train_epochs) * 0.7), 0), # Default: None
|
| 590 |
+
"eval_strategy": "steps", # NOTE: "evaluation_strategy" is deprecated.
|
| 591 |
+
"save_total_limit": 2, # This will save both the best and the last trainer checkpoint
|
| 592 |
+
"load_best_model_at_end": True,
|
| 593 |
+
"metric_for_best_model": "all_ligands_equal",
|
| 594 |
+
"include_inputs_for_metrics": True,
|
| 595 |
+
"eval_on_start": False, # Default: False
|
| 596 |
+
# Logging configs
|
| 597 |
+
"log_level": "debug",
|
| 598 |
+
"logging_steps": 5000,
|
| 599 |
+
"disable_tqdm": True,
|
| 600 |
+
"report_to": ["tensorboard"],
|
| 601 |
+
"save_only_model": False, # Default: False
|
| 602 |
+
# Hub information configs
|
| 603 |
+
"push_to_hub": hub_model_id is not None, # NOTE: Also manually done further down
|
| 604 |
+
"push_to_hub_model_id": model_id,
|
| 605 |
+
"push_to_hub_organization": organization,
|
| 606 |
+
"hub_model_id": hub_model_id,
|
| 607 |
+
"hub_token": hub_token,
|
| 608 |
+
"hub_strategy": "checkpoint", # NOTE: Allows to resume training from last checkpoint
|
| 609 |
+
"hub_private_repo": True,
|
| 610 |
+
# Other configs
|
| 611 |
+
"seed": 42,
|
| 612 |
+
"data_seed": 42,
|
| 613 |
+
}
|
| 614 |
+
if 'num_cycles' in training_args["lr_scheduler_kwargs"] and num_cycles is not None:
|
| 615 |
+
training_args["lr_scheduler_kwargs"]["num_cycles"] = num_cycles
|
| 616 |
+
if warmup_ratio is not None:
|
| 617 |
+
training_args["warmup_ratio"] = warmup_ratio
|
| 618 |
+
if warmup_steps is not None:
|
| 619 |
+
training_args["warmup_steps"] = warmup_steps
|
| 620 |
+
|
| 621 |
+
# Add Generation configs
|
| 622 |
+
if causal_language_modeling:
|
| 623 |
+
training_args["metric_for_best_model"] = "eval_loss"
|
| 624 |
+
else:
|
| 625 |
+
generation_config = GenerationConfig(
|
| 626 |
+
max_length=512,
|
| 627 |
+
max_new_tokens=512,
|
| 628 |
+
do_sample=True,
|
| 629 |
+
num_beams=5,
|
| 630 |
+
temperature=1.0,
|
| 631 |
+
)
|
| 632 |
+
training_args["generation_config"] = generation_config
|
| 633 |
+
training_args["predict_with_generate"] = True
|
| 634 |
+
training_args["generation_config"] = generation_config
|
| 635 |
+
training_args["generation_max_length"] = 512
|
| 636 |
+
|
| 637 |
+
print("Training arguments:")
|
| 638 |
+
for k, v in training_args.items():
|
| 639 |
+
if 'token' in k:
|
| 640 |
+
continue
|
| 641 |
+
print(f" - {k}: {v}")
|
| 642 |
+
|
| 643 |
+
# Modify the training arguments with Optuna hyperparameter search
|
| 644 |
+
if num_optuna_trials > 0:
|
| 645 |
+
# Setup the compute_metrics function for the hyperparameter search
|
| 646 |
+
hp_compute_metrics = partial(
|
| 647 |
+
decode_and_get_metrics,
|
| 648 |
+
tokenizer=tokenizer,
|
| 649 |
+
compute_rdkit_metrics=False,
|
| 650 |
+
compute_graph_metrics=False,
|
| 651 |
+
num_proc=num_proc_map,
|
| 652 |
+
causal_language_modeling=causal_language_modeling,
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
# Run the HP search (and update the training_args accordingly)
|
| 656 |
+
best_run, hp_training_args = get_best_hyperparameters(
|
| 657 |
+
model_init=model_lambda,
|
| 658 |
+
tokenizer=tokenizer,
|
| 659 |
+
data_collator=data_collator,
|
| 660 |
+
compute_metrics=hp_compute_metrics,
|
| 661 |
+
dataset_tokenized=dataset_tokenized,
|
| 662 |
+
training_args=copy.deepcopy(training_args),
|
| 663 |
+
lr_scheduler_type=lr_scheduler_type,
|
| 664 |
+
num_optuna_trials=num_optuna_trials,
|
| 665 |
+
causal_language_modeling=causal_language_modeling,
|
| 666 |
+
all_fragments_as_labels=all_fragments_as_labels,
|
| 667 |
+
linkers_only_as_labels=linkers_only_as_labels,
|
| 668 |
+
)
|
| 669 |
+
best_objective = best_run.objective
|
| 670 |
+
best_trial_number = best_run.run_id
|
| 671 |
+
best_hparams = best_run.hyperparameters
|
| 672 |
+
|
| 673 |
+
# Save to output directory the best hyperparameters
|
| 674 |
+
with open(f"{output_dir}/best_hyperparameters.md", "w") as f:
|
| 675 |
+
f.write(f"Number of Optuna trials: {num_optuna_trials}\n\n")
|
| 676 |
+
f.write(f"Best trial objective: {best_objective:.4f} (best trial number: {best_trial_number})\n\n")
|
| 677 |
+
|
| 678 |
+
f.write("Best hyperparameters:\n")
|
| 679 |
+
for hparam, value in best_hparams.items():
|
| 680 |
+
f.write(f"- {hparam}: {value}\n")
|
| 681 |
+
f.write("\n")
|
| 682 |
+
|
| 683 |
+
f.write("Training arguments:\n")
|
| 684 |
+
for hparam, value in hp_training_args.items():
|
| 685 |
+
if "token" in hparam:
|
| 686 |
+
continue
|
| 687 |
+
elif isinstance(value, str):
|
| 688 |
+
if 'hf_' in value:
|
| 689 |
+
continue
|
| 690 |
+
f.write(f"- {hparam}: {value}\n")
|
| 691 |
+
|
| 692 |
+
# Open the file and remove any line that might contain the token
|
| 693 |
+
with open(f"{output_dir}/best_hyperparameters.md", "r") as f:
|
| 694 |
+
lines = f.readlines()
|
| 695 |
+
with open(f"{output_dir}/best_hyperparameters.md", "w") as f:
|
| 696 |
+
for line in lines:
|
| 697 |
+
if "hf_" in line:
|
| 698 |
+
continue
|
| 699 |
+
f.write(line)
|
| 700 |
+
print(f"Best hyperparameters saved to '{output_dir}/best_hyperparameters.md'.")
|
| 701 |
+
|
| 702 |
+
if hub_model_id is not None:
|
| 703 |
+
upload_single_file(
|
| 704 |
+
path_or_fileobj=f"{output_dir}/best_hyperparameters.md",
|
| 705 |
+
path_in_repo="best_hyperparameters.md",
|
| 706 |
+
repo_id=hub_model_id,
|
| 707 |
+
token=hub_token,
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# Save the best_hparams to a JSON file
|
| 711 |
+
with open(f"{output_dir}/best_hyperparameters.json", "w") as f:
|
| 712 |
+
json.dump(best_hparams, f, indent=4)
|
| 713 |
+
print(f"Best hyperparameters saved to '{output_dir}/best_hyperparameters.json'.")
|
| 714 |
+
|
| 715 |
+
if hub_model_id is not None:
|
| 716 |
+
upload_single_file(
|
| 717 |
+
path_or_fileobj=f"{output_dir}/best_hyperparameters.json",
|
| 718 |
+
path_in_repo="best_hyperparameters.json",
|
| 719 |
+
repo_id=hub_model_id,
|
| 720 |
+
token=hub_token,
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
# Update the training arguments with the best hyperparameters
|
| 724 |
+
hp_specific_args = [
|
| 725 |
+
"num_train_epochs",
|
| 726 |
+
"max_steps",
|
| 727 |
+
"eval_steps",
|
| 728 |
+
"eval_delay",
|
| 729 |
+
"logging_steps",
|
| 730 |
+
"save_steps",
|
| 731 |
+
"generation_config",
|
| 732 |
+
]
|
| 733 |
+
for k, v in hp_training_args.items():
|
| 734 |
+
# Skip the specific arguments set/modifed by the HP search
|
| 735 |
+
if k in hp_specific_args:
|
| 736 |
+
continue
|
| 737 |
+
training_args[k] = v
|
| 738 |
+
|
| 739 |
+
# Update the num_cycles according to the original max_steps
|
| 740 |
+
lr_scheduler_kwargs = hp_training_args["lr_scheduler_kwargs"]
|
| 741 |
+
|
| 742 |
+
if "num_cycles" in lr_scheduler_kwargs:
|
| 743 |
+
hp_num_cycles = lr_scheduler_kwargs["num_cycles"]
|
| 744 |
+
hp_max_steps = hp_training_args["max_steps"]
|
| 745 |
+
|
| 746 |
+
# Adjust/scale the max_cycles according to the number of steps
|
| 747 |
+
if hp_max_steps > 0:
|
| 748 |
+
hp_cycle_ratio = hp_num_cycles / hp_max_steps
|
| 749 |
+
num_cycles = int(hp_cycle_ratio * max_steps)
|
| 750 |
+
training_args["lr_scheduler_kwargs"]["num_cycles"] = num_cycles
|
| 751 |
+
print(f"Adjusted number of cycles: {num_cycles}")
|
| 752 |
+
|
| 753 |
+
# Adjust the warmup steps according to the original max_steps
|
| 754 |
+
if "warmup_ratio" in hp_training_args:
|
| 755 |
+
hp_warmup_ratio = hp_training_args["warmup_ratio"]
|
| 756 |
+
hp_max_steps = hp_training_args["max_steps"]
|
| 757 |
+
warmup_steps = int(hp_warmup_ratio * hp_max_steps)
|
| 758 |
+
warmup_ratio = warmup_steps / max_steps
|
| 759 |
+
training_args["warmup_steps"] = warmup_steps
|
| 760 |
+
training_args["warmup_ratio"] = warmup_ratio
|
| 761 |
+
|
| 762 |
+
print("Training arguments updated with the best hyperparameters:")
|
| 763 |
+
for k, v in training_args.items():
|
| 764 |
+
if 'token' in k:
|
| 765 |
+
continue
|
| 766 |
+
print(f" - {k}: {v}")
|
| 767 |
+
print("-" * 80)
|
| 768 |
+
print("Starting training with the best hyperparameters.")
|
| 769 |
+
print("-" * 80)
|
| 770 |
+
|
| 771 |
+
# rouge = evaluate.load("rouge") # , cache_dir="/mimer/NOBACKUP/groups/naiss2023-6-290/stefano/.cache/huggingface/evaluate/")
|
| 772 |
+
# fpgen = Chem.rdFingerprintGenerator.GetMorganGenerator(
|
| 773 |
+
# radius=11,
|
| 774 |
+
# fpSize=1024,
|
| 775 |
+
# )
|
| 776 |
+
rouge = None
|
| 777 |
+
fpgen = None
|
| 778 |
+
compute_metrics = partial(
|
| 779 |
+
decode_and_get_metrics,
|
| 780 |
+
tokenizer=tokenizer,
|
| 781 |
+
rouge=rouge,
|
| 782 |
+
fpgen=fpgen,
|
| 783 |
+
compute_rdkit_metrics=False,
|
| 784 |
+
compute_graph_metrics=True,
|
| 785 |
+
num_proc=max(1, num_proc_map - 2), # NOTE: Use 2 less process for the metrics, since there will be a timeout logic
|
| 786 |
+
causal_language_modeling=causal_language_modeling,
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
if training_args_bin is not None:
|
| 790 |
+
print(f"Loading training arguments from: {training_args_bin}.")
|
| 791 |
+
# Load training arguments from a binary file and update model-specific arguments
|
| 792 |
+
args = torch.load(training_args_bin)
|
| 793 |
+
args.output_dir = output_dir
|
| 794 |
+
args.overwrite_output_dir = True if delete_local_repo_if_exists else False
|
| 795 |
+
args.push_to_hub_model_id = model_id
|
| 796 |
+
args.push_to_hub_organization = organization
|
| 797 |
+
args.hub_model_id = hub_model_id
|
| 798 |
+
args.hub_token = hub_token
|
| 799 |
+
# Print all the training arguments
|
| 800 |
+
print("Training arguments loaded:")
|
| 801 |
+
for k, v in args.__dict__.items():
|
| 802 |
+
if 'token' in k:
|
| 803 |
+
continue
|
| 804 |
+
print(f" - {k}: {v}")
|
| 805 |
+
else:
|
| 806 |
+
if causal_language_modeling:
|
| 807 |
+
args = TrainingArguments(**training_args)
|
| 808 |
+
else:
|
| 809 |
+
args = Seq2SeqTrainingArguments(**training_args)
|
| 810 |
+
|
| 811 |
+
if causal_language_modeling:
|
| 812 |
+
TrainerClass = Trainer
|
| 813 |
+
else:
|
| 814 |
+
TrainerClass = Seq2SeqTrainer
|
| 815 |
+
|
| 816 |
+
# Setup the Trainer and start training (no Optuna hyperparameter search)
|
| 817 |
+
trainer = TrainerClass(
|
| 818 |
+
model_init=model_lambda,
|
| 819 |
+
tokenizer=tokenizer,
|
| 820 |
+
data_collator=data_collator,
|
| 821 |
+
args=args,
|
| 822 |
+
compute_metrics=compute_metrics,
|
| 823 |
+
train_dataset=dataset_tokenized["train"],
|
| 824 |
+
eval_dataset=dataset_tokenized["test"],
|
| 825 |
+
)
|
| 826 |
+
if resume_from_checkpoint is not None:
|
| 827 |
+
trainer.train(
|
| 828 |
+
resume_from_checkpoint=resume_from_checkpoint,
|
| 829 |
+
)
|
| 830 |
+
else:
|
| 831 |
+
trainer.train()
|
| 832 |
+
print("-" * 80)
|
| 833 |
+
print("Training completed.")
|
| 834 |
+
print("-" * 80)
|
| 835 |
+
|
| 836 |
+
if causal_language_modeling:
|
| 837 |
+
tasks = ["Text Generation"]
|
| 838 |
+
else:
|
| 839 |
+
tasks = ["Text2Text Generation", "question-answering"]
|
| 840 |
+
|
| 841 |
+
tokenizer.save_pretrained(output_dir)
|
| 842 |
+
|
| 843 |
+
if hub_model_id is not None:
|
| 844 |
+
print("Pushing model to Hugging Face Hub.")
|
| 845 |
+
print("-" * 80)
|
| 846 |
+
trainer.push_to_hub(
|
| 847 |
+
commit_message="Initial version",
|
| 848 |
+
model_name=hub_model_id,
|
| 849 |
+
license="mit",
|
| 850 |
+
finetuned_from=f"{pretrained_encoder}",
|
| 851 |
+
tasks=tasks,
|
| 852 |
+
tags=["PROTAC", "cheminformatics"],
|
| 853 |
+
dataset=[ds_name],
|
| 854 |
+
dataset_args=[ds_config],
|
| 855 |
+
)
|
| 856 |
+
tokenizer.push_to_hub(
|
| 857 |
+
repo_id=hub_model_id,
|
| 858 |
+
commit_message="Upload tokenizer",
|
| 859 |
+
private=True,
|
| 860 |
+
token=hub_token,
|
| 861 |
+
tags=["PROTAC", "cheminformatics"],
|
| 862 |
+
)
|
| 863 |
+
else:
|
| 864 |
+
print("Pushing model to local directory.")
|
| 865 |
+
print("-" * 80)
|
| 866 |
+
trainer.save_model(output_dir)
|
| 867 |
+
tokenizer.save_pretrained(output_dir)
|
| 868 |
+
print(f"Model saved to '{output_dir}'.")
|
| 869 |
+
print("All done.")
|
protac_splitter/llms/training_causal_model.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
from typing import Dict, Any
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import TrainerCallback
|
| 6 |
+
from trl import SFTTrainer
|
| 7 |
+
from rdkit import Chem
|
| 8 |
+
|
| 9 |
+
from protac_splitter.llms.data_utils import load_tokenized_dataset
|
| 10 |
+
from protac_splitter.llms.model_utils import get_model
|
| 11 |
+
|
| 12 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use GPU if available
|
| 13 |
+
|
| 14 |
+
# Placeholder for a scoring function that evaluates the generated SMILES
|
| 15 |
+
def score_function(smiles1, predicted_smiles):
|
| 16 |
+
""" Evaluates the generated SMILES sequence based on validity. """
|
| 17 |
+
mol = Chem.MolFromSmiles(predicted_smiles)
|
| 18 |
+
return 1 if mol else 0 # Returns 1 if valid, 0 if invalid
|
| 19 |
+
|
| 20 |
+
# Custom Trainer subclass to integrate SMILES evaluation
|
| 21 |
+
class CustomSFTTrainer(SFTTrainer):
|
| 22 |
+
def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
|
| 23 |
+
if eval_dataset is None:
|
| 24 |
+
eval_dataset = self.eval_dataset
|
| 25 |
+
|
| 26 |
+
# Generate predictions
|
| 27 |
+
predictions = self.predict(eval_dataset)
|
| 28 |
+
generated_texts = self.tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True)
|
| 29 |
+
|
| 30 |
+
total_score = 0
|
| 31 |
+
total_samples = len(generated_texts)
|
| 32 |
+
|
| 33 |
+
for i, example in enumerate(eval_dataset):
|
| 34 |
+
input_text = example["text"] # Full input: "Smiles1 Smiles2.Smiles3.Smiles4"
|
| 35 |
+
smiles1 = input_text.split(" ")[0] # Extract Smiles1 (the prompt)
|
| 36 |
+
|
| 37 |
+
# Remove the prompt from the generated text to get the predicted completion
|
| 38 |
+
predicted_completion = generated_texts[i].removeprefix(smiles1).strip()
|
| 39 |
+
|
| 40 |
+
# Compute custom score
|
| 41 |
+
score = score_function(smiles1, predicted_completion)
|
| 42 |
+
total_score += score
|
| 43 |
+
|
| 44 |
+
# Compute average score
|
| 45 |
+
average_score = total_score / total_samples if total_samples > 0 else 0
|
| 46 |
+
|
| 47 |
+
# Log metrics
|
| 48 |
+
metrics = {f"{metric_key_prefix}_average_score": average_score}
|
| 49 |
+
self.log(metrics)
|
| 50 |
+
|
| 51 |
+
return metrics
|
| 52 |
+
|
| 53 |
+
def train():
|
| 54 |
+
""" Main training function """
|
| 55 |
+
model = get_model() # Load the model
|
| 56 |
+
tokenizer = model.tokenizer # Get tokenizer from model
|
| 57 |
+
|
| 58 |
+
# Load dataset
|
| 59 |
+
dataset = load_tokenized_dataset()
|
| 60 |
+
|
| 61 |
+
# Training arguments
|
| 62 |
+
training_args = {
|
| 63 |
+
"output_dir": "./trained_model",
|
| 64 |
+
"evaluation_strategy": "steps",
|
| 65 |
+
"save_strategy": "steps",
|
| 66 |
+
"logging_steps": 100,
|
| 67 |
+
"save_steps": 500,
|
| 68 |
+
"num_train_epochs": 3,
|
| 69 |
+
"per_device_train_batch_size": 8,
|
| 70 |
+
"per_device_eval_batch_size": 8,
|
| 71 |
+
"learning_rate": 5e-5,
|
| 72 |
+
"save_total_limit": 2,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
# Initialize custom trainer
|
| 76 |
+
trainer = CustomSFTTrainer(
|
| 77 |
+
model=model,
|
| 78 |
+
args=training_args,
|
| 79 |
+
train_dataset=dataset["train"],
|
| 80 |
+
eval_dataset=dataset["validation"],
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Train model
|
| 84 |
+
trainer.train()
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
train()
|
protac_splitter/llms/training_mlm_model.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Train a masked language model (MLM) using an encoder-decoder architecture. """
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional, Dict, Any, Union
|
| 4 |
+
import subprocess
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import huggingface_hub as hf
|
| 8 |
+
from transformers import (
|
| 9 |
+
Trainer,
|
| 10 |
+
TrainingArguments,
|
| 11 |
+
DataCollatorForLanguageModeling,
|
| 12 |
+
AutoTokenizer,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from protac_splitter.llms.data_utils import load_tokenized_dataset
|
| 16 |
+
from protac_splitter.llms.hf_utils import (
|
| 17 |
+
create_hf_repository,
|
| 18 |
+
delete_hf_repository,
|
| 19 |
+
repo_exists,
|
| 20 |
+
)
|
| 21 |
+
from protac_splitter.llms.model_utils import get_encoder_decoder_model
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compute_metrics_for_mlm(pred) -> Dict[str, float]:
|
| 25 |
+
"""Compute metrics for MLM predictions, i.e., perplexity."""
|
| 26 |
+
logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
|
| 27 |
+
labels = pred.label_ids
|
| 28 |
+
|
| 29 |
+
# Convert to torch tensors
|
| 30 |
+
logits = torch.tensor(logits)
|
| 31 |
+
labels = torch.tensor(labels)
|
| 32 |
+
|
| 33 |
+
# Compute masked loss
|
| 34 |
+
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
|
| 35 |
+
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
|
| 36 |
+
|
| 37 |
+
return {
|
| 38 |
+
"perplexity": torch.exp(loss).item(),
|
| 39 |
+
"loss": loss.item()
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def train_mlm_model(
|
| 44 |
+
model_id: str,
|
| 45 |
+
ds_name: str,
|
| 46 |
+
ds_config: str = 'default',
|
| 47 |
+
learning_rate: float = 5e-5,
|
| 48 |
+
max_steps: int = -1,
|
| 49 |
+
num_train_epochs: int = 40,
|
| 50 |
+
batch_size: int = 128,
|
| 51 |
+
batch_size_tokenizer: int = 512,
|
| 52 |
+
gradient_accumulation_steps: int = 4,
|
| 53 |
+
hub_token: Optional[str] = None,
|
| 54 |
+
organization: Optional[str] = None,
|
| 55 |
+
output_dir: str = "./models/",
|
| 56 |
+
tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
|
| 57 |
+
pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
|
| 58 |
+
pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
|
| 59 |
+
encoder_max_length: int = 512,
|
| 60 |
+
decoder_max_length: int = 512,
|
| 61 |
+
tie_encoder_decoder: bool = False,
|
| 62 |
+
delete_repo_if_exists: bool = False,
|
| 63 |
+
delete_local_repo_if_exists: bool = False,
|
| 64 |
+
training_args: Optional[Dict[str, Any]] = None,
|
| 65 |
+
resume_from_checkpoint: Optional[str] = None,
|
| 66 |
+
num_proc_map: int = 1,
|
| 67 |
+
per_device_batch_size: Optional[int] = None,
|
| 68 |
+
lr_scheduler_type: Optional[str] = None,
|
| 69 |
+
mlm_probability: float = 0.15,
|
| 70 |
+
randomize_smiles: bool = False,
|
| 71 |
+
randomize_smiles_prob: float = 0.5,
|
| 72 |
+
randomize_smiles_repeat: int = 1,
|
| 73 |
+
):
|
| 74 |
+
"""
|
| 75 |
+
Trains a masked language model (MLM) using an encoder-decoder architecture.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
model_id (str): The name of the model to be trained.
|
| 79 |
+
ds_name (str): The name of the dataset to use for training.
|
| 80 |
+
ds_config (str): The configuration of the dataset to use. Default: 'default'.
|
| 81 |
+
learning_rate (float): The learning rate for training. Default: 5e-5.
|
| 82 |
+
max_steps (int): The maximum number of training steps. Default: -1.
|
| 83 |
+
num_train_epochs (int): The number of training epochs. Default: 40.
|
| 84 |
+
batch_size (int): The total batch size. Default: 128.
|
| 85 |
+
batch_size_tokenizer (int): The batch size for the tokenizer. Default: 512.
|
| 86 |
+
gradient_accumulation_steps (int): The number of gradient accumulation steps. Default: 4.
|
| 87 |
+
hub_token (str): The Hugging Face token for authentication. Default: None.
|
| 88 |
+
organization (str): The organization to push the model to. Default: None.
|
| 89 |
+
output_dir (str): The output directory for the model. Default: "./models/".
|
| 90 |
+
tokenizer (AutoTokenizer | str): The tokenizer to use for training. Default: "seyonec/ChemBERTa-zinc-base-v1".
|
| 91 |
+
pretrained_encoder (str): The pretrained encoder model to use. Default: "seyonec/ChemBERTa-zinc-base-v1".
|
| 92 |
+
pretrained_decoder (str): The pretrained decoder model to use. Default: "seyonec/ChemBERTa-zinc-base-v1".
|
| 93 |
+
encoder_max_length (int): The maximum length of the encoder input. Default: 512.
|
| 94 |
+
decoder_max_length (int): The maximum length of the decoder input. Default: 512.
|
| 95 |
+
tie_encoder_decoder (bool): Whether to tie the encoder and decoder weights. Default: False.
|
| 96 |
+
delete_repo_if_exists (bool): Whether to delete the repository if it already exists. Default: False.
|
| 97 |
+
delete_local_repo_if_exists (bool): Whether to delete the local repository if it already exists. Default: False.
|
| 98 |
+
training_args (Dict[str, Any]): The training arguments for the Trainer. Default: None.
|
| 99 |
+
resume_from_checkpoint (str): The checkpoint to resume training from. Default: None.
|
| 100 |
+
num_optuna_trials (int): The number of Optuna hyperparameter search trials. Default: 0.
|
| 101 |
+
num_proc_map (int): The number of processes to use for mapping. Default: 1.
|
| 102 |
+
per_device_batch_size (int): The batch size per device. If defined, it will overwrite batch_size. Default: None.
|
| 103 |
+
lr_scheduler_type (str): The learning rate scheduler type. Default: None.
|
| 104 |
+
mlm_probability (float): The probability of masking tokens in the input. Default: 0.15.
|
| 105 |
+
randomize_smiles (bool): Whether to randomize SMILES strings. Default: False.
|
| 106 |
+
randomize_smiles_prob (float): The probability of randomizing SMILES strings. Default: 0.5.
|
| 107 |
+
randomize_smiles_repeat (int): The number of times to repeat randomizing SMILES strings. Default: 1.
|
| 108 |
+
"""
|
| 109 |
+
# Check if resume_from_checkpoint exists and it's a file
|
| 110 |
+
if resume_from_checkpoint is not None:
|
| 111 |
+
# Check if the checkpoint exists: it can be either a file or a directory
|
| 112 |
+
if not os.path.exists(resume_from_checkpoint):
|
| 113 |
+
raise ValueError(f"Checkpoint file '{resume_from_checkpoint}' does not exist.")
|
| 114 |
+
|
| 115 |
+
if hub_token is not None:
|
| 116 |
+
hf.login(token=hub_token)
|
| 117 |
+
|
| 118 |
+
# Setup output directory and Hugging Face repository
|
| 119 |
+
output_dir += f"/{model_id}"
|
| 120 |
+
if organization is not None:
|
| 121 |
+
hub_model_id = f"{organization}/{model_id}"
|
| 122 |
+
if delete_repo_if_exists and repo_exists(hub_model_id, token=hub_token):
|
| 123 |
+
delete_hf_repository(repo_id=hub_model_id, token=hub_token)
|
| 124 |
+
if not repo_exists(hub_model_id, token=hub_token):
|
| 125 |
+
print(f"Repository '{hub_model_id}' deleted.")
|
| 126 |
+
else:
|
| 127 |
+
print(f"Repository '{hub_model_id}' could not be deleted.")
|
| 128 |
+
return
|
| 129 |
+
if delete_local_repo_if_exists and os.path.exists(output_dir):
|
| 130 |
+
subprocess.run(["rm", "-rf", output_dir])
|
| 131 |
+
if not os.path.exists(output_dir):
|
| 132 |
+
print(f"Local repository '{output_dir}' deleted.")
|
| 133 |
+
else:
|
| 134 |
+
print(f"Local repository '{output_dir}' could not be deleted.")
|
| 135 |
+
return
|
| 136 |
+
repo_url = create_hf_repository(
|
| 137 |
+
repo_id=hub_model_id,
|
| 138 |
+
repo_type="model",
|
| 139 |
+
exist_ok=True,
|
| 140 |
+
private=True,
|
| 141 |
+
token=hub_token,
|
| 142 |
+
)
|
| 143 |
+
print(f"Repository '{hub_model_id}' created at URL: {repo_url}")
|
| 144 |
+
else:
|
| 145 |
+
hub_model_id = None
|
| 146 |
+
print(f"Hub model ID: {hub_model_id}")
|
| 147 |
+
|
| 148 |
+
if isinstance(tokenizer, str):
|
| 149 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
| 150 |
+
elif tokenizer is None:
|
| 151 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
|
| 152 |
+
|
| 153 |
+
# Set the pad token to the end of the sequence, required for MLM training
|
| 154 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 155 |
+
|
| 156 |
+
# Load the tokenized dataset
|
| 157 |
+
print("Loading tokenized dataset.")
|
| 158 |
+
dataset_tokenized = load_tokenized_dataset(
|
| 159 |
+
ds_name,
|
| 160 |
+
ds_config,
|
| 161 |
+
tokenizer,
|
| 162 |
+
batch_size_tokenizer,
|
| 163 |
+
encoder_max_length,
|
| 164 |
+
decoder_max_length,
|
| 165 |
+
token=hub_token,
|
| 166 |
+
num_proc_map=num_proc_map,
|
| 167 |
+
randomize_smiles=randomize_smiles,
|
| 168 |
+
randomize_smiles_prob=randomize_smiles_prob,
|
| 169 |
+
randomize_smiles_repeat=randomize_smiles_repeat,
|
| 170 |
+
randomize_text=True,
|
| 171 |
+
randomize_labels=False,
|
| 172 |
+
)
|
| 173 |
+
# Remove "labels" column from the dataset
|
| 174 |
+
dataset_tokenized = dataset_tokenized.remove_columns(["labels"])
|
| 175 |
+
print("Dataset loaded.")
|
| 176 |
+
|
| 177 |
+
# Setup the model for `model_init` in the Trainer
|
| 178 |
+
bert2bert = lambda: get_encoder_decoder_model(
|
| 179 |
+
pretrained_encoder=pretrained_encoder,
|
| 180 |
+
pretrained_decoder=pretrained_decoder,
|
| 181 |
+
max_length=encoder_max_length,
|
| 182 |
+
tie_encoder_decoder=tie_encoder_decoder,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# Setup the data collator
|
| 186 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 187 |
+
tokenizer,
|
| 188 |
+
mlm=True,
|
| 189 |
+
mlm_probability=mlm_probability,
|
| 190 |
+
pad_to_multiple_of=8,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# Setup the training arguments
|
| 194 |
+
if per_device_batch_size is None:
|
| 195 |
+
per_device_batch_size = batch_size // gradient_accumulation_steps
|
| 196 |
+
if training_args is None:
|
| 197 |
+
training_args = {
|
| 198 |
+
"output_dir": output_dir,
|
| 199 |
+
# Optimizer-related configs
|
| 200 |
+
"learning_rate": learning_rate,
|
| 201 |
+
"optim": "adamw_torch",
|
| 202 |
+
"lr_scheduler_type": "cosine" if lr_scheduler_type is None else lr_scheduler_type,
|
| 203 |
+
"warmup_steps": 8000, # NOTE: ChemFormer: 8000
|
| 204 |
+
# "warmup_ratio": 0,
|
| 205 |
+
"adam_beta1": 0.9, # NOTE: ChemFormer: 0.9
|
| 206 |
+
"adam_beta2": 0.999, # NOTE: ChemFormer: 0.999
|
| 207 |
+
"adam_epsilon": 1e-8, # Default: 1e-8
|
| 208 |
+
# Batch size, device, and performance optimizations configs
|
| 209 |
+
# "torch_compile": True,
|
| 210 |
+
"group_by_length": True,
|
| 211 |
+
"per_device_train_batch_size": per_device_batch_size,
|
| 212 |
+
"per_device_eval_batch_size": per_device_batch_size,
|
| 213 |
+
"gradient_accumulation_steps": gradient_accumulation_steps,
|
| 214 |
+
"auto_find_batch_size": True,
|
| 215 |
+
"fp16": True if torch.cuda.is_available() else False,
|
| 216 |
+
# Evaluation and checkpointing configs
|
| 217 |
+
"max_steps": max_steps,
|
| 218 |
+
"num_train_epochs": num_train_epochs,
|
| 219 |
+
"save_steps": 1000, # NOTE: 200
|
| 220 |
+
"save_strategy": "steps",
|
| 221 |
+
"eval_steps": 1000, # NOTE: 500
|
| 222 |
+
"evaluation_strategy": "steps",
|
| 223 |
+
"save_total_limit": 1,
|
| 224 |
+
"load_best_model_at_end": True,
|
| 225 |
+
"metric_for_best_model": "perplexity",
|
| 226 |
+
"include_inputs_for_metrics": True,
|
| 227 |
+
# Logging configs
|
| 228 |
+
"log_level": "warning",
|
| 229 |
+
"logging_steps": 500,
|
| 230 |
+
"disable_tqdm": True,
|
| 231 |
+
"report_to": ["tensorboard"],
|
| 232 |
+
"save_only_model": False, # Default: False
|
| 233 |
+
# Hub information configs
|
| 234 |
+
"push_to_hub": True, # NOTE: Also manually done further down
|
| 235 |
+
"push_to_hub_model_id": model_id,
|
| 236 |
+
"push_to_hub_organization": organization,
|
| 237 |
+
"hub_model_id": hub_model_id,
|
| 238 |
+
"hub_token": hub_token,
|
| 239 |
+
"hub_strategy": "checkpoint", # NOTE: Allows to resume training from last checkpoint
|
| 240 |
+
"hub_private_repo": True,
|
| 241 |
+
# Other configs
|
| 242 |
+
"seed": 42,
|
| 243 |
+
"data_seed": 42,
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
# Setup the Trainer and start training (no Optuna hyperparameter search)
|
| 247 |
+
trainer = Trainer(
|
| 248 |
+
model_init=bert2bert,
|
| 249 |
+
tokenizer=tokenizer,
|
| 250 |
+
data_collator=data_collator,
|
| 251 |
+
args=TrainingArguments(**training_args),
|
| 252 |
+
compute_metrics=compute_metrics_for_mlm,
|
| 253 |
+
train_dataset=dataset_tokenized["train"],
|
| 254 |
+
eval_dataset=dataset_tokenized["validation"],
|
| 255 |
+
)
|
| 256 |
+
if resume_from_checkpoint is not None:
|
| 257 |
+
trainer.train(
|
| 258 |
+
resume_from_checkpoint=resume_from_checkpoint,
|
| 259 |
+
)
|
| 260 |
+
else:
|
| 261 |
+
trainer.train()
|
| 262 |
+
print("-" * 80)
|
| 263 |
+
print("Training completed.")
|
| 264 |
+
print("-" * 80)
|
| 265 |
+
|
| 266 |
+
if hub_model_id is not None:
|
| 267 |
+
print("Pushing model to Hugging Face Hub.")
|
| 268 |
+
print("-" * 80)
|
| 269 |
+
tokenizer.save_pretrained(output_dir)
|
| 270 |
+
trainer.push_to_hub(
|
| 271 |
+
commit_message="Initial version",
|
| 272 |
+
model_name=hub_model_id,
|
| 273 |
+
license="mit",
|
| 274 |
+
finetuned_from=f"{pretrained_encoder}",
|
| 275 |
+
tasks=["Text2Text Generation", "question-answering"],
|
| 276 |
+
tags=["PROTAC", "cheminformatics"],
|
| 277 |
+
dataset=[ds_name],
|
| 278 |
+
dataset_args=[ds_config],
|
| 279 |
+
)
|
| 280 |
+
tokenizer.push_to_hub(
|
| 281 |
+
repo_id=hub_model_id,
|
| 282 |
+
commit_message="Upload tokenizer",
|
| 283 |
+
private=True,
|
| 284 |
+
token=hub_token,
|
| 285 |
+
tags=["PROTAC", "cheminformatics"],
|
| 286 |
+
)
|
| 287 |
+
print("All done.")
|
protac_splitter/llms/training_rl_models.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Train a PPO and DPO model for PROTAC-Splitter using Hugging Face
|
| 2 |
+
Transformers and TRL. This is a work in progress code, so it's not tested nor
|
| 3 |
+
used in the package.
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional, Literal
|
| 6 |
+
from functools import partial
|
| 7 |
+
import os
|
| 8 |
+
import subprocess
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import evaluate
|
| 12 |
+
import huggingface_hub as hf
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from datasets import load_dataset
|
| 15 |
+
from rdkit import Chem
|
| 16 |
+
from transformers import (
|
| 17 |
+
AutoTokenizer,
|
| 18 |
+
TrainingArguments,
|
| 19 |
+
EncoderDecoderModel,
|
| 20 |
+
AutoConfig,
|
| 21 |
+
)
|
| 22 |
+
from trl import (
|
| 23 |
+
AutoModelForSeq2SeqLMWithValueHead,
|
| 24 |
+
PPOConfig,
|
| 25 |
+
PPOTrainer,
|
| 26 |
+
DPOTrainer,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from protac_splitter.llms.data_utils import (
|
| 30 |
+
load_trl_dataset,
|
| 31 |
+
data_collator_for_trl,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
from protac_splitter.llms.hf_utils import (
|
| 35 |
+
create_hf_repository,
|
| 36 |
+
delete_hf_repository,
|
| 37 |
+
repo_exists,
|
| 38 |
+
)
|
| 39 |
+
from protac_splitter.llms.evaluation import decode_and_get_metrics
|
| 40 |
+
from protac_splitter.evaluation import check_substructs, split_prediction
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def clean_text(text: str) -> str:
|
| 44 |
+
""" Cleans the text by removing special tokens. """
|
| 45 |
+
return text.replace("<s>", "").replace("</s>", "")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def reward_function(
|
| 49 |
+
query: str,
|
| 50 |
+
response: str,
|
| 51 |
+
) -> float:
|
| 52 |
+
""" Reward function for the RL-based models.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
query (str): The query SMILES string.
|
| 56 |
+
response (str): The response SMILES string.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
float: The reward value.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
substructs = split_prediction(response)
|
| 63 |
+
if substructs is None:
|
| 64 |
+
return torch.Tensor(-1.)
|
| 65 |
+
|
| 66 |
+
if not check_substructs(
|
| 67 |
+
protac_smiles=query,
|
| 68 |
+
poi_smiles=substructs['poi'],
|
| 69 |
+
linker_smiles=substructs['linker'],
|
| 70 |
+
e3_smiles=substructs['e3'],
|
| 71 |
+
return_bond_types=False,
|
| 72 |
+
poi_attachment_id=1,
|
| 73 |
+
e3_attachment_id=2,
|
| 74 |
+
):
|
| 75 |
+
return torch.Tensor(0.)
|
| 76 |
+
|
| 77 |
+
return torch.Tensor(1.)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def train_ppo_model(
|
| 81 |
+
model_id: str = "PROTAC-Splitter-PPO-standard_rand_recombined-ChemBERTa-zinc-base",
|
| 82 |
+
organization: str = 'ailab-bio',
|
| 83 |
+
output_dir: str = "./models/",
|
| 84 |
+
max_steps: int = 2000,
|
| 85 |
+
ppo_epochs: int = 5,
|
| 86 |
+
batch_size: int = 128,
|
| 87 |
+
hub_token: Optional[str] = None,
|
| 88 |
+
pretrained_model_name: str = "ailab-bio/PROTAC-Splitter-standard_rand_recombined-ChemBERTa-zinc-base",
|
| 89 |
+
max_length: int = 512,
|
| 90 |
+
delete_repo_if_exists: bool = False,
|
| 91 |
+
delete_local_repo_if_exists: bool = False,
|
| 92 |
+
ds_name: str = "ailab-bio/PROTAC-Splitter-Dataset",
|
| 93 |
+
ds_config: str = "standard",
|
| 94 |
+
):
|
| 95 |
+
""" Trains a PPO model on a given dataset.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
model_id (str, optional): The name of the model to be trained. Defaults to "PROTAC-Splitter-PPO-standard_rand_recombined-ChemBERTa-zinc-base".
|
| 99 |
+
organization (str, optional): The organization name. Defaults to 'ailab-bio'.
|
| 100 |
+
output_dir (str, optional): The output directory. Defaults to "./models/".
|
| 101 |
+
max_steps (int, optional): The maximum number of training steps. Defaults to 2000.
|
| 102 |
+
ppo_epochs (int, optional): The number of PPO epochs. Defaults to 4.
|
| 103 |
+
batch_size (int, optional): The batch size. Defaults to 128.
|
| 104 |
+
hub_token (Optional[str], optional): The Hugging Face token. Defaults to None.
|
| 105 |
+
pretrained_model_name (str, optional): The name of the pretrained model. Defaults to "ailab-bio/PROTAC-Splitter-standard_rand_recombined-ChemBERTa-zinc-base".
|
| 106 |
+
max_length (int, optional): The maximum length of the input sequence. Defaults to 512.
|
| 107 |
+
delete_repo_first (bool, optional): Whether to delete the repository first. Defaults to False.
|
| 108 |
+
"""
|
| 109 |
+
if ppo_epochs < 1:
|
| 110 |
+
raise ValueError(f"ppo_epochs must be >= 1, got {ppo_epochs}.")
|
| 111 |
+
if hub_token is not None:
|
| 112 |
+
hf.login(token=hub_token)
|
| 113 |
+
|
| 114 |
+
# Setup output directory and Hugging Face repository
|
| 115 |
+
output_dir += f"/{model_id}"
|
| 116 |
+
if organization is not None:
|
| 117 |
+
hub_model_id = f"{organization}/{model_id}"
|
| 118 |
+
if delete_repo_if_exists and repo_exists(hub_model_id, token=hub_token):
|
| 119 |
+
delete_hf_repository(repo_id=hub_model_id, token=hub_token)
|
| 120 |
+
if not repo_exists(hub_model_id, token=hub_token):
|
| 121 |
+
print(f"Repository '{hub_model_id}' deleted.")
|
| 122 |
+
else:
|
| 123 |
+
print(f"Repository '{hub_model_id}' could not be deleted.")
|
| 124 |
+
return
|
| 125 |
+
if delete_local_repo_if_exists and os.path.exists(output_dir):
|
| 126 |
+
subprocess.run(["rm", "-rf", output_dir])
|
| 127 |
+
if not os.path.exists(output_dir):
|
| 128 |
+
print(f"Local repository '{output_dir}' deleted.")
|
| 129 |
+
else:
|
| 130 |
+
print(f"Local repository '{output_dir}' could not be deleted.")
|
| 131 |
+
return
|
| 132 |
+
repo_url = create_hf_repository(
|
| 133 |
+
repo_id=hub_model_id,
|
| 134 |
+
repo_type="model",
|
| 135 |
+
exist_ok=True,
|
| 136 |
+
private=True,
|
| 137 |
+
token=hub_token,
|
| 138 |
+
)
|
| 139 |
+
print(f"Repository '{hub_model_id}' created at URL: {repo_url}")
|
| 140 |
+
else:
|
| 141 |
+
hub_model_id = None
|
| 142 |
+
print(f"Hub model ID: {hub_model_id}")
|
| 143 |
+
|
| 144 |
+
# Load pretrained model
|
| 145 |
+
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
|
| 146 |
+
pretrained_model_name,
|
| 147 |
+
max_length=max_length,
|
| 148 |
+
)
|
| 149 |
+
ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
|
| 150 |
+
pretrained_model_name,
|
| 151 |
+
max_length=max_length,
|
| 152 |
+
)
|
| 153 |
+
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
| 154 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 155 |
+
|
| 156 |
+
# Get dataset
|
| 157 |
+
train_dataset = load_trl_dataset(
|
| 158 |
+
tokenizer=tokenizer,
|
| 159 |
+
token=hub_token,
|
| 160 |
+
max_length=max_length,
|
| 161 |
+
dataset_name=ds_name,
|
| 162 |
+
ds_config=ds_config,
|
| 163 |
+
).shuffle(seed=42).flatten_indices()
|
| 164 |
+
|
| 165 |
+
# Setup PPO trainer
|
| 166 |
+
hub_configs = {
|
| 167 |
+
"repo_id": hub_model_id,
|
| 168 |
+
"commit_message": "Initial version",
|
| 169 |
+
"private": True,
|
| 170 |
+
}
|
| 171 |
+
ppo_config = PPOConfig(
|
| 172 |
+
# Learning parameters
|
| 173 |
+
learning_rate=1e-5,
|
| 174 |
+
steps=max_steps, # Default: 20_000
|
| 175 |
+
ppo_epochs=ppo_epochs, # Default: 4
|
| 176 |
+
batch_size=batch_size, # Default: 256
|
| 177 |
+
gradient_accumulation_steps=1, # Default: 1
|
| 178 |
+
optimize_device_cache=True,
|
| 179 |
+
# PPO parameters
|
| 180 |
+
init_kl_coef=1.0,
|
| 181 |
+
adap_kl_ctrl=True,
|
| 182 |
+
target=0.5,
|
| 183 |
+
horizon=1000,
|
| 184 |
+
cliprange=0.1,
|
| 185 |
+
early_stopping=True,
|
| 186 |
+
target_kl=0.5,
|
| 187 |
+
max_grad_norm=1.0,
|
| 188 |
+
use_score_scaling=True,
|
| 189 |
+
use_score_norm=True,
|
| 190 |
+
whiten_rewards=True,
|
| 191 |
+
# Logging parameters
|
| 192 |
+
# NOTE: Check this guide for more information about the logged metrics:
|
| 193 |
+
# https://huggingface.co/docs/trl/v0.10.1/logging
|
| 194 |
+
model_name=hub_model_id,
|
| 195 |
+
push_to_hub_if_best_kwargs=hub_configs,
|
| 196 |
+
log_with="tensorboard", # ["wandb", LoggerType.TENSORBOARD],
|
| 197 |
+
project_kwargs={"logging_dir": output_dir},
|
| 198 |
+
seed=42,
|
| 199 |
+
)
|
| 200 |
+
ppo_trainer = PPOTrainer(
|
| 201 |
+
model=model,
|
| 202 |
+
ref_model=ref_model,
|
| 203 |
+
num_shared_layers=0,
|
| 204 |
+
config=ppo_config,
|
| 205 |
+
tokenizer=tokenizer,
|
| 206 |
+
dataset=train_dataset,
|
| 207 |
+
data_collator=data_collator_for_trl,
|
| 208 |
+
# lr_scheduler=torch.optim.lr_scheduler.LRScheduler, # NOTE: It must be that, CosineAnnealingLR is not supported
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
# Training Loop
|
| 212 |
+
generation_kwargs = {
|
| 213 |
+
"do_sample": True,
|
| 214 |
+
"num_beams": 5,
|
| 215 |
+
"top_k": 20,
|
| 216 |
+
"max_length": 512,
|
| 217 |
+
"pad_token_id": tokenizer.eos_token_id,
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader), total=len(ppo_trainer.dataloader)):
|
| 221 |
+
query_tensors = batch["input_ids"]
|
| 222 |
+
|
| 223 |
+
# Get response from SFTModel
|
| 224 |
+
response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
|
| 225 |
+
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
|
| 226 |
+
|
| 227 |
+
# Compute reward score
|
| 228 |
+
rewards = [reward_function(clean_text(q), clean_text(r)) for q, r in zip(batch["query"], batch["response"])]
|
| 229 |
+
rewards = [torch.tensor(r) for r in rewards]
|
| 230 |
+
|
| 231 |
+
# Run PPO step
|
| 232 |
+
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
|
| 233 |
+
ppo_trainer.log_stats(stats, batch, rewards)
|
| 234 |
+
|
| 235 |
+
# Save model and tokenizer
|
| 236 |
+
ppo_trainer.push_to_hub(**hub_configs)
|
| 237 |
+
tokenizer.push_to_hub(**hub_configs)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def train_dpo_model(
|
| 241 |
+
model_name: str = "ailab-bio/PROTAC-Splitter-DPO",
|
| 242 |
+
output_dir: str = "./models/",
|
| 243 |
+
beta: float = 0.1,
|
| 244 |
+
loss_type: Literal["sigmoid", "hinge"] = "sigmoid",
|
| 245 |
+
learning_rate: float = 5e-5,
|
| 246 |
+
max_steps: int = 2000,
|
| 247 |
+
num_train_epochs: int = -1,
|
| 248 |
+
batch_size: int = 128,
|
| 249 |
+
gradient_accumulation_steps: int = 4,
|
| 250 |
+
resume_from_checkpoint: bool = False,
|
| 251 |
+
hub_token: Optional[str] = None,
|
| 252 |
+
pretrained_model_name: str = "ailab-bio/PROTAC-Splitter_untied_80-20-split",
|
| 253 |
+
pretrained_ref_model_name: str = "ailab-bio/PROTAC-Splitter_untied_80-20-split",
|
| 254 |
+
max_length: int = None,
|
| 255 |
+
delete_repo_first: bool = False,
|
| 256 |
+
optuna_search: bool = False,
|
| 257 |
+
):
|
| 258 |
+
""" Trains a DPO model on a given dataset.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
model_name (str, optional): The name of the model to be trained. Defaults to "ailab-bio/PROTAC-Splitter-DPO".
|
| 262 |
+
max_steps (int, optional): The maximum number of training steps. Defaults to 2000.
|
| 263 |
+
"""
|
| 264 |
+
if hub_token is not None:
|
| 265 |
+
hf.login(token=hub_token)
|
| 266 |
+
if delete_repo_first and not resume_from_checkpoint:
|
| 267 |
+
delete_hf_repository(repo_id=model_name, token=hub_token)
|
| 268 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 269 |
+
pretrained_model_name,
|
| 270 |
+
token=hub_token,
|
| 271 |
+
)
|
| 272 |
+
if tokenizer.pad_token is None:
|
| 273 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 274 |
+
# Get train and eval datasets
|
| 275 |
+
dataset = load_dataset(
|
| 276 |
+
"ailab-bio/PROTAC-Substructures-DPO",
|
| 277 |
+
token=hub_token,
|
| 278 |
+
)
|
| 279 |
+
# Setup models
|
| 280 |
+
def model_init():
|
| 281 |
+
return EncoderDecoderModel.from_pretrained(
|
| 282 |
+
pretrained_model_name,
|
| 283 |
+
token=hub_token,
|
| 284 |
+
)
|
| 285 |
+
model_ref = EncoderDecoderModel.from_pretrained(
|
| 286 |
+
pretrained_ref_model_name,
|
| 287 |
+
token=hub_token,
|
| 288 |
+
)
|
| 289 |
+
# Setup training arguments
|
| 290 |
+
per_device_batch_size = batch_size // gradient_accumulation_steps
|
| 291 |
+
training_args = TrainingArguments(
|
| 292 |
+
output_dir=output_dir,
|
| 293 |
+
# Optimizer-related configs
|
| 294 |
+
learning_rate=learning_rate,
|
| 295 |
+
optim="adamw_torch",
|
| 296 |
+
lr_scheduler_type="cosine", # Default: "linear"
|
| 297 |
+
# Batch size and device configs
|
| 298 |
+
per_device_train_batch_size=per_device_batch_size,
|
| 299 |
+
per_device_eval_batch_size=per_device_batch_size,
|
| 300 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 301 |
+
auto_find_batch_size=True,
|
| 302 |
+
# torch_compile=True,
|
| 303 |
+
fp16=True,
|
| 304 |
+
# Evaluation and checkpointing configs
|
| 305 |
+
evaluation_strategy="steps", # TODO: Why is it not working? "steps",
|
| 306 |
+
max_steps=max_steps,
|
| 307 |
+
num_train_epochs=num_train_epochs,
|
| 308 |
+
eval_steps=100,
|
| 309 |
+
save_steps=200,
|
| 310 |
+
# eval_steps=7500,
|
| 311 |
+
# warmup_steps=2000,
|
| 312 |
+
save_strategy="steps",
|
| 313 |
+
save_total_limit=1,
|
| 314 |
+
load_best_model_at_end=True,
|
| 315 |
+
# metric_for_best_model="valid_smiles",
|
| 316 |
+
# Logging configs
|
| 317 |
+
log_level="info",
|
| 318 |
+
logging_steps=50,
|
| 319 |
+
disable_tqdm=True,
|
| 320 |
+
# Hub information configs
|
| 321 |
+
push_to_hub=True, # NOTE: Done manually further down
|
| 322 |
+
hub_token=hub_token,
|
| 323 |
+
hub_model_id=model_name,
|
| 324 |
+
hub_strategy="checkpoint", # NOTE: Allows to resume training from last checkpoint
|
| 325 |
+
hub_private_repo=True,
|
| 326 |
+
# Other configs
|
| 327 |
+
remove_unused_columns=False,
|
| 328 |
+
seed=42,
|
| 329 |
+
data_seed=42,
|
| 330 |
+
)
|
| 331 |
+
# Setup Matrics
|
| 332 |
+
# TODO: The metric is not working because the predictions include rewards,
|
| 333 |
+
# or something like that, i.e., real values, which cannot be decoded by the
|
| 334 |
+
# tokenizer. Skipping for now and using the default one.
|
| 335 |
+
rouge = evaluate.load("rouge")
|
| 336 |
+
fpgen = Chem.rdFingerprintGenerator.GetMorganGenerator(
|
| 337 |
+
radius=8,
|
| 338 |
+
fpSize=2048,
|
| 339 |
+
)
|
| 340 |
+
metric = partial(
|
| 341 |
+
decode_and_get_metrics,
|
| 342 |
+
rouge=rouge,
|
| 343 |
+
tokenizer=tokenizer,
|
| 344 |
+
fpgen=fpgen,
|
| 345 |
+
)
|
| 346 |
+
# Setup trainer and start training
|
| 347 |
+
if max_length is None:
|
| 348 |
+
max_length = AutoConfig.from_pretrained(
|
| 349 |
+
pretrained_model_name,
|
| 350 |
+
token=hub_token,
|
| 351 |
+
).max_length
|
| 352 |
+
# max_length = model.config.max_length
|
| 353 |
+
dpo_trainer = DPOTrainer(
|
| 354 |
+
model=model_init(),
|
| 355 |
+
ref_model=model_ref,
|
| 356 |
+
beta=beta,
|
| 357 |
+
loss_type=loss_type,
|
| 358 |
+
train_dataset=dataset["train"],
|
| 359 |
+
eval_dataset=dataset["test"],
|
| 360 |
+
tokenizer=tokenizer,
|
| 361 |
+
model_init=model_init if optuna_search else None,
|
| 362 |
+
# compute_metrics=metric,
|
| 363 |
+
max_length=max_length,
|
| 364 |
+
max_prompt_length=max_length,
|
| 365 |
+
max_target_length=max_length,
|
| 366 |
+
is_encoder_decoder=True,
|
| 367 |
+
padding_value=tokenizer.pad_token_id,
|
| 368 |
+
truncation_mode="keep_start",
|
| 369 |
+
args=training_args,
|
| 370 |
+
)
|
| 371 |
+
if optuna_search and False:
|
| 372 |
+
# TODO: This is not working because the training arguments do NOT
|
| 373 |
+
# include the beta parameter...
|
| 374 |
+
def optuna_hp_space(trial):
|
| 375 |
+
return {
|
| 376 |
+
"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
|
| 377 |
+
"per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64, 128]),
|
| 378 |
+
"beta": trial.suggest_float("beta", 0.1, 0.5),
|
| 379 |
+
}
|
| 380 |
+
best_trials = dpo_trainer.hyperparameter_search(
|
| 381 |
+
direction=["minimize"],
|
| 382 |
+
backend="optuna",
|
| 383 |
+
hp_space=optuna_hp_space,
|
| 384 |
+
n_trials=20,
|
| 385 |
+
# compute_objective=compute_objective,
|
| 386 |
+
)
|
| 387 |
+
print("-" * 80)
|
| 388 |
+
print(f"Best trials:\n{best_trials}")
|
| 389 |
+
print("-" * 80)
|
| 390 |
+
else:
|
| 391 |
+
if resume_from_checkpoint:
|
| 392 |
+
resume_from_checkpoint = "last-checkpoint"
|
| 393 |
+
else:
|
| 394 |
+
resume_from_checkpoint = None
|
| 395 |
+
dpo_trainer.train(
|
| 396 |
+
resume_from_checkpoint=resume_from_checkpoint,
|
| 397 |
+
)
|
| 398 |
+
dpo_trainer.push_to_hub(
|
| 399 |
+
commit_message="Initial version",
|
| 400 |
+
model_name=model_name,
|
| 401 |
+
license="mit",
|
| 402 |
+
finetuned_from=pretrained_model_name,
|
| 403 |
+
tasks=["Text2Text Generation"],
|
| 404 |
+
tags=["PROTAC", "cheminformatics"],
|
| 405 |
+
dataset="ailab-bio/PROTAC-Substructures-DPO",
|
| 406 |
+
)
|
protac_splitter/protac_cheminformatics.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
from typing import List, Tuple, Callable, Any, Union, Dict, Optional, Literal
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.Chem import AllChem
|
| 8 |
+
from rdkit.Chem import rdchem
|
| 9 |
+
from rdkit import RDLogger
|
| 10 |
+
from rdkit.Chem import CanonSmiles
|
| 11 |
+
|
| 12 |
+
from .chemoinformatics import (
|
| 13 |
+
canonize,
|
| 14 |
+
smiles2mol,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
RDLogger.DisableLog("rdApp.*")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@lru_cache(maxsize=None)
|
| 21 |
+
def get_mol(smiles: str) -> rdchem.Mol:
|
| 22 |
+
return Chem.MolFromSmiles(smiles)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def find_atom_idx_of_map_atoms(
|
| 26 |
+
mol: rdchem.Mol,
|
| 27 |
+
find_poi: True,
|
| 28 |
+
find_e3: True,
|
| 29 |
+
poi_attachment_id: int = 1,
|
| 30 |
+
e3_attachment_id: int = 2,
|
| 31 |
+
) -> Union[int, Tuple[int, int]]:
|
| 32 |
+
""" Find the indices of the attachment points in the given molecule.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
mol (rdkit.Chem.rdchem.Mol): The molecule.
|
| 36 |
+
find_poi (bool): Whether to find the POI attachment point.
|
| 37 |
+
find_e3 (bool): Whether to find the E3 attachment point.
|
| 38 |
+
poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]".
|
| 39 |
+
e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]".
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
int | Tuple[int, int]: The index of the attachment point for the POI ligand if find_poi is True, the index of the attachment point for the E3 binder if find_e3 is True, or a tuple containing POI and E3 indices (in this order) if both find_poi and find_e3 are True.
|
| 43 |
+
"""
|
| 44 |
+
if find_poi and find_e3:
|
| 45 |
+
poi_idx = None
|
| 46 |
+
e3_idx = None
|
| 47 |
+
for atom in mol.GetAtoms():
|
| 48 |
+
if atom.GetAtomMapNum() == poi_attachment_id:
|
| 49 |
+
poi_idx = atom.GetIdx()
|
| 50 |
+
elif atom.GetAtomMapNum() == e3_attachment_id:
|
| 51 |
+
e3_idx = atom.GetIdx()
|
| 52 |
+
if poi_idx is not None and e3_idx is not None:
|
| 53 |
+
break
|
| 54 |
+
return poi_idx, e3_idx
|
| 55 |
+
elif find_poi:
|
| 56 |
+
for atom in mol.GetAtoms():
|
| 57 |
+
if atom.GetAtomMapNum() == poi_attachment_id:
|
| 58 |
+
return atom.GetIdx()
|
| 59 |
+
elif find_e3:
|
| 60 |
+
for atom in mol.GetAtoms():
|
| 61 |
+
if atom.GetAtomMapNum() == e3_attachment_id:
|
| 62 |
+
return atom.GetIdx()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def reassemble_protac(
|
| 66 |
+
ligands_smiles: Optional[str] = None,
|
| 67 |
+
poi_smiles: Optional[str] = None,
|
| 68 |
+
linker_smiles: Optional[str] = None,
|
| 69 |
+
e3_smiles: Optional[str] = None,
|
| 70 |
+
e3_bond_type: Literal['single', 'double', 'triple', 'rand_uniform'] = 'single',
|
| 71 |
+
poi_bond_type: Literal['single', 'double', 'triple', 'rand_uniform'] = 'single',
|
| 72 |
+
poi_attachment_id: int = 1,
|
| 73 |
+
e3_attachment_id: int = 2,
|
| 74 |
+
rand_generator = None,
|
| 75 |
+
) -> Tuple[str, Chem.rdchem.Mol]:
|
| 76 |
+
""" Reassemble a PROTAC molecule from its substructures. The SMILES must contain attachment points.
|
| 77 |
+
|
| 78 |
+
In case the bond type cannot be formed an error will be raised.
|
| 79 |
+
|
| 80 |
+
Example of usage:
|
| 81 |
+
|
| 82 |
+
```python
|
| 83 |
+
e3_smiles = '[*:2]NC(C(=O)N1CC(O)CC1C(=O)NCc1ccc(-c2scnc2C)cc1)C(C)(C)C'
|
| 84 |
+
linker_smiles = '[*:2]C(=O)CCCCCCCCCC[*:1]'
|
| 85 |
+
poi_smiles = '[*:1]CN1CCN(c2ccc(Nc3ncc4c(C)cc(=O)n(-c5cccc(NC(=O)C=C)c5)c4n3)c(OC)c2)CC1'
|
| 86 |
+
|
| 87 |
+
merged_smiles, _ = reassemble_protac(poi_smiles, linker_smiles, e3_smiles, 'single', 'single')
|
| 88 |
+
print(merged_smiles)
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
poi_smiles (str): The SMILES notation for the POI ligand.
|
| 93 |
+
linker_smiles (str): The SMILES notation for the linker.
|
| 94 |
+
e3_smiles (str): The SMILES notation for the E3 binder.
|
| 95 |
+
e3_bond_type (str): The type of bond to be added between the E3 binder and the linker. Can be 'single', 'double', 'triple', or 'rand_uniform'.
|
| 96 |
+
poi_bond_type (str): The type of bond to be added between the POI ligand and the linker. Can be 'single', 'double', 'triple', or 'rand_uniform'.
|
| 97 |
+
poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]".
|
| 98 |
+
e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]".
|
| 99 |
+
rand_generator: A random number generator for 'rand_uniform' bond types. Defaults to None, i.e., standard library random.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Tuple[str, Chem.rdchem.Mol]: The SMILES notation and RDKit molecule object for the reassembled PROTAC molecule.
|
| 103 |
+
"""
|
| 104 |
+
if ligands_smiles is None:
|
| 105 |
+
if None in [poi_smiles, linker_smiles, e3_smiles]:
|
| 106 |
+
raise ValueError("Missing substructures SMILES: either provide ligands_smiles or all of poi_smiles, linker_smiles, and e3_smiles")
|
| 107 |
+
ligands_smiles = f'{e3_smiles}.{linker_smiles}.{poi_smiles}'
|
| 108 |
+
if None in [poi_smiles, linker_smiles, e3_smiles]:
|
| 109 |
+
if ligands_smiles is None:
|
| 110 |
+
raise ValueError("Missing substructures SMILES: either provide ligands_smiles or all of poi_smiles, linker_smiles, and e3_smiles")
|
| 111 |
+
|
| 112 |
+
ligands_mol = canonize(smiles2mol(ligands_smiles))
|
| 113 |
+
if ligands_mol is None:
|
| 114 |
+
return None, None
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
protac_mol = Chem.molzip(ligands_mol)
|
| 118 |
+
except ValueError as e:
|
| 119 |
+
logging.error(f"Failed to reassemble PROTAC: {e}")
|
| 120 |
+
return None, None
|
protac_splitter/protac_splitter.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
from typing import Union, Optional, Dict, List
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from datasets import Dataset
|
| 8 |
+
import pandas as pd
|
| 9 |
+
|
| 10 |
+
from protac_splitter.chemoinformatics import canonize
|
| 11 |
+
from protac_splitter.fixing_functions import fix_prediction
|
| 12 |
+
from protac_splitter.llms.model_utils import get_pipeline, run_pipeline
|
| 13 |
+
from protac_splitter.graphs.e3_clustering import get_representative_e3s_fp
|
| 14 |
+
from protac_splitter.graphs.edge_classifier import GraphEdgeClassifier
|
| 15 |
+
from protac_splitter.graphs.splitting_algorithms import split_protac_graph_based
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def load_graph_edge_classifier_from_cache(
|
| 19 |
+
cache_dir: Union[str, Path] = "~/.cache/protac_splitter",
|
| 20 |
+
model_filename: str = "PROTAC-Splitter-XGBoost.joblib",
|
| 21 |
+
download_url: str = "https://docs.google.com/uc?export=download&id=1bb9i5_L_-re3QYPc7tSiCtVNEEbNIzAC",
|
| 22 |
+
) -> GraphEdgeClassifier:
|
| 23 |
+
"""
|
| 24 |
+
Loads the GraphEdgeClassifier model from a local cache directory.
|
| 25 |
+
If the model file is not found, downloads it from the specified URL.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
cache_dir (str or Path): Directory to cache the model file.
|
| 29 |
+
model_filename (str): Name of the model file.
|
| 30 |
+
download_url (str): URL to download the model if not present.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
GraphEdgeClassifier: Loaded classifier.
|
| 34 |
+
"""
|
| 35 |
+
cache_dir = Path(os.path.expanduser(cache_dir))
|
| 36 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 37 |
+
model_path = cache_dir / model_filename
|
| 38 |
+
|
| 39 |
+
if not model_path.exists():
|
| 40 |
+
response = requests.get(download_url, stream=True)
|
| 41 |
+
response.raise_for_status()
|
| 42 |
+
expected_size = int(response.headers.get("Content-Length", -1))
|
| 43 |
+
|
| 44 |
+
with open(model_path, "wb") as f:
|
| 45 |
+
for chunk in response.iter_content(chunk_size=1024*1024):
|
| 46 |
+
if chunk:
|
| 47 |
+
f.write(chunk)
|
| 48 |
+
|
| 49 |
+
if expected_size != -1:
|
| 50 |
+
actual = model_path.stat().st_size
|
| 51 |
+
if actual != expected_size:
|
| 52 |
+
raise RuntimeError(f"Download incomplete: got {actual}, expected {expected_size}")
|
| 53 |
+
|
| 54 |
+
# Optional checksum:
|
| 55 |
+
# NOTE: Uncomment the following for debugging
|
| 56 |
+
import hashlib
|
| 57 |
+
h = hashlib.sha256(model_path.read_bytes()).hexdigest()
|
| 58 |
+
h_orig = "513621f4dc2ff7ec819a222bc7311afb8b6e6e89d6d694dd2906e695a50086dd"
|
| 59 |
+
if h != h_orig:
|
| 60 |
+
raise RuntimeError(
|
| 61 |
+
f"Downloaded model checksum mismatch: got {h}, expected {h_orig}. "
|
| 62 |
+
"Please delete the model file and try again."
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
return GraphEdgeClassifier.load(model_path)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def split_protac(
|
| 69 |
+
protac_smiles: Union[str, List, pd.DataFrame],
|
| 70 |
+
use_transformer: bool = False,
|
| 71 |
+
use_xgboost: bool = True,
|
| 72 |
+
fix_predictions: bool = True,
|
| 73 |
+
protac_smiles_col: str = "text",
|
| 74 |
+
batch_size: int = 1,
|
| 75 |
+
beam_size: int = 5,
|
| 76 |
+
device: Optional[Union[int, str]] = None,
|
| 77 |
+
num_proc: int = 1,
|
| 78 |
+
verbose: int = 0,
|
| 79 |
+
) -> Union[Dict[str, str], List[Dict[str, str]]]:
|
| 80 |
+
""" Split a PROTAC SMILES into the two ligands and the linker.
|
| 81 |
+
|
| 82 |
+
If `use_transformer` and `use_xgboost` are both True, the Transformer model
|
| 83 |
+
will run first, and XGBost will be used as a fallback for predictions that
|
| 84 |
+
fail re-assembly and fixing. If both `use_transformer` and `use_xgboost`
|
| 85 |
+
are False, a fully heuristic-based algorithm will be used for splitting.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
protac_smiles (str, list, or pd.DataFrame): The PROTAC SMILES to split.
|
| 89 |
+
If a DataFrame is provided, it must contain a column named `protac_smiles_col`.
|
| 90 |
+
use_transformer (bool): Whether to use the transformer model for splitting.
|
| 91 |
+
use_xgboost (bool): Whether to use the XGBoost model for splitting.
|
| 92 |
+
fix_predictions (bool): Whether to fix the predictions using deterministic cheminformatics rules. Only used if `use_transformer` is True.
|
| 93 |
+
protac_smiles_col (str): The name of the column containing the PROTAC SMILES in the DataFrame.
|
| 94 |
+
batch_size (int): Batch size for processing. Only used if `use_transformer` is True.
|
| 95 |
+
beam_size (int): Number of beam search predictions to generate. Only used if `use_transformer` is True. Higher values may yield better results but increase computation time.
|
| 96 |
+
device (int or str, optional): Device to run the Transformer model on. Defaults to None will attempt to run on GPU if available, otherwise CPU.
|
| 97 |
+
num_proc (int): Number of processes to use for parallel processing. Useful for large datasets of PROTACs to split.
|
| 98 |
+
verbose (int): Verbosity level.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Union[Dict[str, str], List[Dict[str, str]]]: Depending on the input type, returns:
|
| 102 |
+
- If a single string is provided, returns a dictionary with format: `{protac_smiles_col: protac_smiles, "default_pred_n0": e3l.linker.warhead, "model_name": Transformer|XGBoost|Heuristic}`.
|
| 103 |
+
- If a list of strings is provided, returns a list of dictionaries with the same format as above.
|
| 104 |
+
- If a DataFrame is provided, returns a DataFrame with columns: `protac_smiles_col`, `default_pred_n0`, and `model_name`. The `default_pred_n0` column contains the predicted split strings in the format `e3.linker.warhead`.
|
| 105 |
+
"""
|
| 106 |
+
if use_xgboost:
|
| 107 |
+
representative_e3s_fp = get_representative_e3s_fp()
|
| 108 |
+
xgboost_model = load_graph_edge_classifier_from_cache()
|
| 109 |
+
|
| 110 |
+
# Generate a Dataset from the input PROTAC SMILES
|
| 111 |
+
if isinstance(protac_smiles, str):
|
| 112 |
+
protac_smiles_canon = canonize(protac_smiles)
|
| 113 |
+
if protac_smiles_canon is None:
|
| 114 |
+
raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
|
| 115 |
+
ds = Dataset.from_dict({protac_smiles_col: [protac_smiles_canon]})
|
| 116 |
+
elif isinstance(protac_smiles, list):
|
| 117 |
+
# Canonize and check if all PROTAC SMILES are valid
|
| 118 |
+
protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
|
| 119 |
+
if None in protac_smiles_canon:
|
| 120 |
+
wrong_protacs = [protac for protac, canon in zip(protac_smiles, protac_smiles_canon) if canon is None]
|
| 121 |
+
raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
|
| 122 |
+
ds = Dataset.from_dict({protac_smiles_col: protac_smiles_canon})
|
| 123 |
+
elif isinstance(protac_smiles, pd.DataFrame):
|
| 124 |
+
# Check if the DataFrame contains a columns named `protac_smiles_col`
|
| 125 |
+
if protac_smiles_col not in protac_smiles.columns:
|
| 126 |
+
raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
|
| 127 |
+
# Canonize and check if all PROTAC SMILES are valid
|
| 128 |
+
protac_smiles_canon = protac_smiles[protac_smiles_col].apply(canonize)
|
| 129 |
+
if protac_smiles_canon.isnull().any():
|
| 130 |
+
wrong_protacs = protac_smiles[protac_smiles_canon.isnull()]
|
| 131 |
+
raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
|
| 132 |
+
ds = Dataset.from_pandas(protac_smiles_canon.to_frame(name=protac_smiles_col))
|
| 133 |
+
|
| 134 |
+
if use_transformer:
|
| 135 |
+
pipe = get_pipeline(
|
| 136 |
+
model_name="ailab-bio/PROTAC-Splitter-EncoderDecoder-lr_reduce-rand-smiles",
|
| 137 |
+
token=os.environ.get("HF_TOKEN", None),
|
| 138 |
+
is_causal_language_model=False,
|
| 139 |
+
num_return_sequences=beam_size,
|
| 140 |
+
device=device,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# preds will be a list of dictionaries, each containing the
|
| 144 |
+
# beam-size predictions for each input PROTAC SMILES. Format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
|
| 145 |
+
preds = run_pipeline(
|
| 146 |
+
pipe,
|
| 147 |
+
ds,
|
| 148 |
+
batch_size,
|
| 149 |
+
is_causal_language_model=False,
|
| 150 |
+
smiles_column=protac_smiles_col,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Turn the predictions into a DataFrame and then into a Dataset
|
| 154 |
+
preds_df = pd.DataFrame(preds)
|
| 155 |
+
preds_df[protac_smiles_col] = ds[protac_smiles_col]
|
| 156 |
+
preds_ds = Dataset.from_pandas(preds_df)
|
| 157 |
+
|
| 158 |
+
def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
|
| 159 |
+
"""Fix the predictions for each row."""
|
| 160 |
+
protac = row[protac_smiles_col]
|
| 161 |
+
if fix_predictions:
|
| 162 |
+
preds = {k: fix_prediction(protac, v, verbose=verbose) for k, v in row.items() if k.startswith("pred_")}
|
| 163 |
+
else:
|
| 164 |
+
preds = {k: v for k, v in row.items() if k.startswith("pred_")}
|
| 165 |
+
|
| 166 |
+
# If all preds are None, we attempt to use the XGBoost model
|
| 167 |
+
if all(v is None for v in preds.values()):
|
| 168 |
+
if use_xgboost:
|
| 169 |
+
pred = split_protac_graph_based(
|
| 170 |
+
protac_smiles=protac,
|
| 171 |
+
use_classifier=True,
|
| 172 |
+
classifier=xgboost_model,
|
| 173 |
+
representative_e3s_fp=representative_e3s_fp,
|
| 174 |
+
)
|
| 175 |
+
return {
|
| 176 |
+
protac_smiles_col: protac,
|
| 177 |
+
"default_pred_n0": f"{pred['e3']}.{pred['linker']}.{pred['poi']}",
|
| 178 |
+
"model_name": "XGBoost",
|
| 179 |
+
}
|
| 180 |
+
else:
|
| 181 |
+
# If no predictions are valid, we return None for the default prediction
|
| 182 |
+
return {
|
| 183 |
+
protac_smiles_col: protac,
|
| 184 |
+
"default_pred_n0": None,
|
| 185 |
+
"model_name": "Transformer",
|
| 186 |
+
}
|
| 187 |
+
else:
|
| 188 |
+
# Select the non-None prediction with the lowest beam index
|
| 189 |
+
# NOTE: The HF predictions comes in lists, with the first
|
| 190 |
+
# element being the one with the highest likelihood.
|
| 191 |
+
for i in range(beam_size):
|
| 192 |
+
key = f"pred_n{i}"
|
| 193 |
+
if preds[key] is not None:
|
| 194 |
+
return {
|
| 195 |
+
protac_smiles_col: protac,
|
| 196 |
+
"default_pred_n0": preds[key],
|
| 197 |
+
"model_name": "Transformer",
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
# Map the function over the Dataset to fix the predictions and/or
|
| 201 |
+
# replace them with the XGBoost fallback predictions if they fail.
|
| 202 |
+
if fix_predictions or use_xgboost:
|
| 203 |
+
preds_ds = preds_ds.map(
|
| 204 |
+
mapping_func,
|
| 205 |
+
num_proc=1 if use_xgboost else num_proc, # Using XGBoost IN a map function might not be thread-safe
|
| 206 |
+
desc=f"{'Fixing predictions' if fix_predictions else ''}{' and ' if fix_predictions and use_xgboost else ''}{'Replacing predictions with XGBoost fallback' if use_xgboost else ''}",
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
elif use_xgboost:
|
| 210 |
+
# Use the XGBoost model only
|
| 211 |
+
def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
|
| 212 |
+
"""Split the PROTAC SMILES using the XGBoost model."""
|
| 213 |
+
protac = row[protac_smiles_col]
|
| 214 |
+
pred = split_protac_graph_based(
|
| 215 |
+
protac_smiles=protac,
|
| 216 |
+
use_classifier=True,
|
| 217 |
+
classifier=xgboost_model,
|
| 218 |
+
representative_e3s_fp=representative_e3s_fp,
|
| 219 |
+
)
|
| 220 |
+
if all(v is None for v in pred.values()):
|
| 221 |
+
split = None
|
| 222 |
+
else:
|
| 223 |
+
split = f"{pred['e3']}.{pred['linker']}.{pred['poi']}"
|
| 224 |
+
return {
|
| 225 |
+
protac_smiles_col: protac,
|
| 226 |
+
"default_pred_n0": split,
|
| 227 |
+
"model_name": "XGBoost",
|
| 228 |
+
}
|
| 229 |
+
preds_ds = ds.map(
|
| 230 |
+
mapping_func,
|
| 231 |
+
num_proc=1,
|
| 232 |
+
desc="Splitting PROTAC SMILES using XGBoost model",
|
| 233 |
+
)
|
| 234 |
+
else:
|
| 235 |
+
# If neither transformer nor XGBoost is used, we use the heuristic-based
|
| 236 |
+
# algorithm, that does not require any model.
|
| 237 |
+
def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
|
| 238 |
+
"""Split the PROTAC SMILES using the heuristic-based algorithm."""
|
| 239 |
+
protac = row[protac_smiles_col]
|
| 240 |
+
pred = split_protac_graph_based(
|
| 241 |
+
protac_smiles=protac,
|
| 242 |
+
use_classifier=False,
|
| 243 |
+
)
|
| 244 |
+
if all(v is None for v in pred.values()):
|
| 245 |
+
split = None
|
| 246 |
+
else:
|
| 247 |
+
split = f"{pred['e3']}.{pred['linker']}.{pred['poi']}"
|
| 248 |
+
return {
|
| 249 |
+
protac_smiles_col: protac,
|
| 250 |
+
"default_pred_n0": split,
|
| 251 |
+
"model_name": "Heuristic",
|
| 252 |
+
}
|
| 253 |
+
preds_ds = ds.map(
|
| 254 |
+
mapping_func,
|
| 255 |
+
num_proc=num_proc,
|
| 256 |
+
desc="Splitting PROTAC SMILES using heuristic-based algorithm",
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if isinstance(protac_smiles, str):
|
| 260 |
+
# If the input was a single string, we return the first prediction
|
| 261 |
+
return preds_ds[0]
|
| 262 |
+
elif isinstance(protac_smiles, pd.DataFrame):
|
| 263 |
+
# If the input was a DataFrame, we return a dataframe with the predictions
|
| 264 |
+
return preds_ds.to_pandas()
|
| 265 |
+
elif isinstance(protac_smiles, list):
|
| 266 |
+
# Convert the Dataset to a list of dictionaries
|
| 267 |
+
return [row for row in preds_ds]
|
| 268 |
+
|
| 269 |
+
# if tokenizer is None:
|
| 270 |
+
# if verbose:
|
| 271 |
+
# print(f"Loading tokenizer...")
|
| 272 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
|
| 273 |
+
|
| 274 |
+
# if pipe is None:
|
| 275 |
+
# if verbose:
|
| 276 |
+
# print("Loading pipeline for \"default\" predictions...")
|
| 277 |
+
# pipe = pipeline(
|
| 278 |
+
# "text2text-generation",
|
| 279 |
+
# model=model_name,
|
| 280 |
+
# tokenizer=tokenizer,
|
| 281 |
+
# device="cuda" if torch.cuda.is_available() else "cpu",
|
| 282 |
+
# token=hf_token,
|
| 283 |
+
# num_return_sequences=beam_size,
|
| 284 |
+
# )
|
| 285 |
+
|
| 286 |
+
# if isinstance(protac_smiles, str):
|
| 287 |
+
# protac_smiles_canon = canonize(protac_smiles)
|
| 288 |
+
# if protac_smiles_canon is None:
|
| 289 |
+
# raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
|
| 290 |
+
# pred = pipe(protac_smiles_canon)
|
| 291 |
+
# pred = {f"default_pred_n{i}": pred[i]["generated_text"] for i in range(len(pred))}
|
| 292 |
+
# if fix_predictions:
|
| 293 |
+
# p_fixed = {k: fix_prediction(protac_smiles_canon, v, verbose=verbose) for k, v in pred.items()}
|
| 294 |
+
# # For each prediction, if the fixed prediction is not None, we
|
| 295 |
+
# # replace the original prediction with the fixed one.
|
| 296 |
+
# for k, v in p_fixed.items():
|
| 297 |
+
# if v is not None:
|
| 298 |
+
# pred[k] = v
|
| 299 |
+
# preds = [pred]
|
| 300 |
+
|
| 301 |
+
# if isinstance(protac_smiles, list):
|
| 302 |
+
# # Canonize and check if all PROTAC SMILES are valid
|
| 303 |
+
# protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
|
| 304 |
+
# if None in protac_smiles_canon:
|
| 305 |
+
# wrong_protacs = [protac for protac, canon in zip(protac_smiles, protac_smiles_canon) if canon is None]
|
| 306 |
+
# raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
|
| 307 |
+
|
| 308 |
+
# # Get the predictions for all PROTAC SMILES
|
| 309 |
+
# preds = pipe(protac_smiles_canon, batch_size=batch_size)
|
| 310 |
+
# preds = [{f"default_pred_n{i}": p["generated_text"] for i, p in enumerate(pred)} for pred in preds]
|
| 311 |
+
|
| 312 |
+
# if fix_predictions:
|
| 313 |
+
# for i, (protac, pred) in enumerate(zip(protac_smiles_canon, preds)):
|
| 314 |
+
# p_fixed = {k: fix_prediction(protac, v, verbose=verbose) for k, v in pred.items()}
|
| 315 |
+
# # For each prediction, if the fixed prediction is not None, we
|
| 316 |
+
# # replace the original prediction with the fixed one.
|
| 317 |
+
# for k, v in p_fixed.items():
|
| 318 |
+
# if v is not None:
|
| 319 |
+
# preds[i][k] = v
|
| 320 |
+
|
| 321 |
+
# if isinstance(protac_smiles, pd.DataFrame):
|
| 322 |
+
# # Check if the DataFrame contains a columns named `protac_smiles_col`
|
| 323 |
+
# if protac_smiles_col not in protac_smiles.columns:
|
| 324 |
+
# raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
|
| 325 |
+
|
| 326 |
+
# # Canonize and check if all PROTAC SMILES are valid
|
| 327 |
+
# protac_smiles_canon = protac_smiles.apply(lambda x: canonize(x[protac_smiles_col]), axis=1)
|
| 328 |
+
|
| 329 |
+
# # Check if there are invalid PROTAC SMILES
|
| 330 |
+
# if protac_smiles_canon.isnull().any():
|
| 331 |
+
# wrong_protacs = protac_smiles[protac_smiles_canon.isnull()]
|
| 332 |
+
# raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
|
| 333 |
+
|
| 334 |
+
# # Convert the Series to a DataFrame
|
| 335 |
+
# protac_smiles_canon = pd.DataFrame(protac_smiles_canon, columns=[protac_smiles_col])
|
| 336 |
+
|
| 337 |
+
# # Convert the DataFrame to a Dataset
|
| 338 |
+
# dataset = Dataset.from_pandas(protac_smiles_canon)
|
| 339 |
+
# preds = []
|
| 340 |
+
# for pred in tqdm(pipe(KeyDataset(dataset, protac_smiles_col), batch_size=batch_size), total=len(dataset) // batch_size, desc="Generating predictions"):
|
| 341 |
+
# p = {f"default_pred_n{i}": pred[i]["generated_text"] for i in range(len(pred))}
|
| 342 |
+
# preds.append(p)
|
| 343 |
+
|
| 344 |
+
# if fix_predictions:
|
| 345 |
+
# for i, (protac, pred) in tqdm(enumerate(zip(protac_smiles_canon, preds)), desc="Fixing predictions", total=len(preds)):
|
| 346 |
+
# p_fixed = {k: fix_prediction(protac, v, verbose=verbose) for k, v in pred.items()}
|
| 347 |
+
# # For each prediction, if the fixed prediction is not None, we
|
| 348 |
+
# # replace the original prediction with the fixed one.
|
| 349 |
+
# for k, v in p_fixed.items():
|
| 350 |
+
# if v is not None:
|
| 351 |
+
# pred[k] = v
|
| 352 |
+
|
| 353 |
+
# if return_check_reassembly:
|
| 354 |
+
# if isinstance(protac_smiles_canon, str):
|
| 355 |
+
# protac_smiles_list = [protac_smiles_canon]
|
| 356 |
+
# elif isinstance(protac_smiles_canon, list):
|
| 357 |
+
# protac_smiles_list = protac_smiles_canon
|
| 358 |
+
# elif isinstance(protac_smiles_canon, pd.DataFrame):
|
| 359 |
+
# protac_smiles_list = protac_smiles_canon[protac_smiles_col].tolist()
|
| 360 |
+
|
| 361 |
+
# print("Checking re-assembly...")
|
| 362 |
+
# for protac, pred in zip(protac_smiles_list, preds):
|
| 363 |
+
# for i in range(beam_size):
|
| 364 |
+
# pred[f"reassembly_correct_n{i}"] = check_reassembly(protac, pred[f"default_pred_n{i}"])
|
| 365 |
+
|
| 366 |
+
# # Just take the first prediction if the input was a string
|
| 367 |
+
# if isinstance(protac_smiles, str):
|
| 368 |
+
# preds = preds[0]
|
| 369 |
+
|
| 370 |
+
# return preds
|
protac_splitter_app.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PROTAC Splitter Web Application
|
| 3 |
+
|
| 4 |
+
This script provides a web interface for splitting PROTAC molecules into their
|
| 5 |
+
constituent parts: E3 ligase binder, linker, and protein-of-interest (POI)
|
| 6 |
+
ligand (warhead).
|
| 7 |
+
|
| 8 |
+
The app uses the protac_splitter library to perform the splitting and offers
|
| 9 |
+
two main modes of operation:
|
| 10 |
+
1. Single SMILES processing
|
| 11 |
+
2. Batch processing via CSV file upload
|
| 12 |
+
|
| 13 |
+
Users can select which models to use:
|
| 14 |
+
- XGBoost model (default): Fast graph-based edge classification model
|
| 15 |
+
- Transformer model: More accurate but slower deep learning model
|
| 16 |
+
- If neither is selected, a rule-based splitting algorithm is used
|
| 17 |
+
|
| 18 |
+
Author: Stefano Ribes
|
| 19 |
+
Date: 2025-06
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import logging
|
| 23 |
+
import tempfile
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Union
|
| 26 |
+
|
| 27 |
+
from PIL import Image
|
| 28 |
+
import gradio as gr
|
| 29 |
+
import pandas as pd
|
| 30 |
+
from rdkit import Chem
|
| 31 |
+
from rdkit.Chem import Draw
|
| 32 |
+
|
| 33 |
+
from protac_splitter import split_protac
|
| 34 |
+
from protac_splitter.display_utils import get_mapped_protac_img
|
| 35 |
+
|
| 36 |
+
def save_svg_to_tempfile(svg_string: str, suffix: str = ".svg") -> Union[str, Path]:
|
| 37 |
+
"""
|
| 38 |
+
Write an SVG string to a temporary file and return its filesystem path.
|
| 39 |
+
"""
|
| 40 |
+
# Create a named temporary file that persists after closing
|
| 41 |
+
tmp_file = tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding="utf-8")
|
| 42 |
+
logging.debug(f"Saving SVG to temporary file: {tmp_file.name}")
|
| 43 |
+
try:
|
| 44 |
+
tmp_file.write(svg_string)
|
| 45 |
+
tmp_file.flush()
|
| 46 |
+
return Path(tmp_file.name)
|
| 47 |
+
finally:
|
| 48 |
+
tmp_file.close()
|
| 49 |
+
|
| 50 |
+
def process_single_smiles(protac_smiles: str, use_transformer: bool = False, use_xgboost: bool = True, beam_size: int = 5) -> tuple:
|
| 51 |
+
"""
|
| 52 |
+
Process a single SMILES string and generate PROTAC fragment predictions
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
protac_smiles: The SMILES string of the PROTAC molecule
|
| 56 |
+
use_transformer: Whether to use the transformer model for prediction
|
| 57 |
+
use_xgboost: Whether to use the XGBoost model for prediction
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Tuple containing input image, output images, SMILES texts and status message
|
| 61 |
+
"""
|
| 62 |
+
if not protac_smiles:
|
| 63 |
+
raise gr.Error("Please provide a valid PROTAC SMILES string.", duration=5)
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
results = split_protac(
|
| 67 |
+
protac_smiles,
|
| 68 |
+
use_transformer=use_transformer,
|
| 69 |
+
use_xgboost=use_xgboost,
|
| 70 |
+
fix_predictions=True, # Always apply fixes to predictions
|
| 71 |
+
beam_size=beam_size, # Use beam search width for Transformer model
|
| 72 |
+
verbose=1
|
| 73 |
+
)
|
| 74 |
+
except Exception as e:
|
| 75 |
+
exception_message = str(e)
|
| 76 |
+
if exception_message.startswith("Invalid PROTAC SMILES"):
|
| 77 |
+
raise gr.Error("The input SMILES string is not valid (couldn't be parsed by RDKit).", duration=5)
|
| 78 |
+
else:
|
| 79 |
+
raise gr.Error(f"An error occurred while processing the input SMILES: {exception_message}", duration=10)
|
| 80 |
+
|
| 81 |
+
valid_molecules = []
|
| 82 |
+
pred_key = f'default_pred_n0'
|
| 83 |
+
valid_molecules.append(results[pred_key])
|
| 84 |
+
|
| 85 |
+
# Generate images and corresponding SMILES text
|
| 86 |
+
images = []
|
| 87 |
+
smiles_texts = []
|
| 88 |
+
input_mol = Chem.MolFromSmiles(protac_smiles)
|
| 89 |
+
|
| 90 |
+
if input_mol is not None:
|
| 91 |
+
input_img = Draw.MolToImage(input_mol, legend="", size=(1000, 200))
|
| 92 |
+
else:
|
| 93 |
+
input_img = Image.new('RGB', (1000, 1000))
|
| 94 |
+
|
| 95 |
+
splits = {}
|
| 96 |
+
for smiles in results[pred_key].split("."):
|
| 97 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 98 |
+
if mol:
|
| 99 |
+
if "[*:1]" in smiles and "[*:2]" in smiles:
|
| 100 |
+
legend = "Linker"
|
| 101 |
+
splits['linker'] = smiles
|
| 102 |
+
elif "[*:1]" in smiles:
|
| 103 |
+
legend = "Warhead"
|
| 104 |
+
splits['poi'] = smiles
|
| 105 |
+
elif "[*:2]" in smiles:
|
| 106 |
+
legend = "E3 Ligase Ligand"
|
| 107 |
+
splits['e3'] = smiles
|
| 108 |
+
|
| 109 |
+
img = Draw.MolToImage(mol, legend="", size=(1000, 1000))
|
| 110 |
+
images.append(img)
|
| 111 |
+
smiles_texts.append(f"{legend}: {smiles}")
|
| 112 |
+
smiles_texts = "\n".join(smiles_texts)
|
| 113 |
+
|
| 114 |
+
use_svg = False
|
| 115 |
+
input_img = get_mapped_protac_img(
|
| 116 |
+
protac_smiles=protac_smiles,
|
| 117 |
+
poi_smiles=splits.get('poi', ''),
|
| 118 |
+
linker_smiles=splits.get('linker', ''),
|
| 119 |
+
e3_smiles=splits.get('e3', ''),
|
| 120 |
+
w=1000,
|
| 121 |
+
h=500,
|
| 122 |
+
legend=None,
|
| 123 |
+
useSVG=use_svg,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
if use_svg:
|
| 127 |
+
input_img = save_svg_to_tempfile(input_img)
|
| 128 |
+
logging.debug(f"Returning processed image path: {input_img}")
|
| 129 |
+
|
| 130 |
+
return input_img, list(images), smiles_texts
|
| 131 |
+
|
| 132 |
+
def process_csv(
|
| 133 |
+
file: gr.File,
|
| 134 |
+
smiles_col: str,
|
| 135 |
+
use_transformer: bool = False,
|
| 136 |
+
use_xgboost: bool = True,
|
| 137 |
+
beam_size: int = 5,
|
| 138 |
+
batch_size: int = 4,
|
| 139 |
+
num_proc: int = 2,
|
| 140 |
+
# NOTE: `pr` is a progress tracker, it is used to track the progress but
|
| 141 |
+
# it is not used in this function. Do not remove it.
|
| 142 |
+
pr: gr.Progress = gr.Progress(track_tqdm=True),
|
| 143 |
+
) -> Path:
|
| 144 |
+
"""
|
| 145 |
+
Process a CSV file containing PROTAC SMILES
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
file: Uploaded CSV file
|
| 149 |
+
smiles_col: Name of the column containing SMILES strings
|
| 150 |
+
use_transformer: Whether to use the transformer model for prediction
|
| 151 |
+
use_xgboost: Whether to use the XGBoost model for prediction
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
Path to output CSV file with predictions
|
| 155 |
+
"""
|
| 156 |
+
df = pd.read_csv(file.name)
|
| 157 |
+
if smiles_col not in df.columns:
|
| 158 |
+
# Use Gradio's error message instead of raising an exception
|
| 159 |
+
raise gr.Error(f"Column \"{smiles_col}\" is not in the provided CSV file.", duration=5)
|
| 160 |
+
|
| 161 |
+
try:
|
| 162 |
+
results = split_protac(
|
| 163 |
+
df,
|
| 164 |
+
use_transformer=use_transformer,
|
| 165 |
+
use_xgboost=use_xgboost,
|
| 166 |
+
protac_smiles_col=smiles_col,
|
| 167 |
+
fix_predictions=True,
|
| 168 |
+
batch_size=batch_size,
|
| 169 |
+
num_proc=num_proc,
|
| 170 |
+
beam_size=beam_size, # Use beam search width for Transformer model
|
| 171 |
+
verbose=1
|
| 172 |
+
)
|
| 173 |
+
except Exception as e:
|
| 174 |
+
exception_message = str(e)
|
| 175 |
+
if exception_message.startswith("Invalid PROTAC SMILES"):
|
| 176 |
+
raise gr.Error("One or more of the input SMILES are not valid (couldn't be parsed by RDKit).", duration=5)
|
| 177 |
+
else:
|
| 178 |
+
raise gr.Error(f"An error occurred while processing: {exception_message}", duration=10)
|
| 179 |
+
|
| 180 |
+
output_df = pd.DataFrame(results)
|
| 181 |
+
|
| 182 |
+
# Create a temporary output file
|
| 183 |
+
output_file = str(Path(tempfile.gettempdir()) / "split_preds.csv")
|
| 184 |
+
logging.debug(f"Saving predictions to temporary file: {output_file}")
|
| 185 |
+
output_df.to_csv(output_file, index=False)
|
| 186 |
+
logging.debug(f"Output DataFrame saved to: {output_file}")
|
| 187 |
+
|
| 188 |
+
return output_file
|
| 189 |
+
|
| 190 |
+
def create_interface():
|
| 191 |
+
"""
|
| 192 |
+
Create and return the Gradio interface for the PROTAC Splitter app
|
| 193 |
+
|
| 194 |
+
The interface includes two tabs:
|
| 195 |
+
1. Single SMILES Input - For processing individual PROTAC SMILES
|
| 196 |
+
2. CSV Upload - For batch processing of multiple PROTAC SMILES
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
gr.Blocks: The Gradio interface
|
| 200 |
+
"""
|
| 201 |
+
with gr.Blocks() as demo:
|
| 202 |
+
header = """# PROTAC-Splitter Web Application
|
| 203 |
+
|
| 204 |
+
Upload a CSV file or enter a single SMILES string to predict PROTAC substructures.
|
| 205 |
+
|
| 206 |
+
Warheads and E3 ligase ligands connections to the linker are marked with dummy atoms, _i.e._, attachment points, as follows:
|
| 207 |
+
|
| 208 |
+
- Warhead: `[*:1]`
|
| 209 |
+
- E3 Ligase ligand: `[*:2]`
|
| 210 |
+
|
| 211 |
+
"""
|
| 212 |
+
gr.Markdown(header)
|
| 213 |
+
|
| 214 |
+
# Model selection section - common to both tabs
|
| 215 |
+
model_selection = """## Model Selection
|
| 216 |
+
|
| 217 |
+
You can choose which model to use for splitting PROTAC molecules:
|
| 218 |
+
|
| 219 |
+
- **XGBoost model** (default): Fast graph-based edge classification model
|
| 220 |
+
- **Transformer model**: More accurate but slower deep learning model
|
| 221 |
+
- If both are selected, the Transformer model will be used first, then if it fails, the XGBoost model will be used.
|
| 222 |
+
- If no model is selected, splitting will be done using graph-based heuristics, with no AI model involved.
|
| 223 |
+
|
| 224 |
+
For fast splitting, we reccommend using the XGBoost model only, which is fast and efficient for most cases. The Transformer model might be more accurate but it is slower, especially for processing large CSV files.
|
| 225 |
+
"""
|
| 226 |
+
gr.Markdown(model_selection)
|
| 227 |
+
with gr.Row():
|
| 228 |
+
with gr.Column(scale=2):
|
| 229 |
+
with gr.Row():
|
| 230 |
+
use_xgboost = gr.Checkbox(label="Use XGBoost model", value=True)
|
| 231 |
+
use_transformer = gr.Checkbox(label="Use Transformer model", value=False)
|
| 232 |
+
|
| 233 |
+
# Performance configuration section
|
| 234 |
+
performance_configs = """### Performance Configurations
|
| 235 |
+
|
| 236 |
+
Change the following parameters to optimize performance based on your machine's capabilities. Particularly useful when processing large CSV files or when using the Transformer model.
|
| 237 |
+
For single SMILES processing, the default values should work well in most cases.
|
| 238 |
+
"""
|
| 239 |
+
gr.Markdown(performance_configs)
|
| 240 |
+
with gr.Column(scale=1):
|
| 241 |
+
# Add a num_proc input
|
| 242 |
+
with gr.Row():
|
| 243 |
+
num_proc = gr.Number(
|
| 244 |
+
label="Number of Processes",
|
| 245 |
+
value=2,
|
| 246 |
+
minimum=1,
|
| 247 |
+
maximum=8,
|
| 248 |
+
step=1,
|
| 249 |
+
info="Number of processes to use for parallel processing. Higher values may improve performance but require more memory."
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# Add a number input for beam_size if Transformer model is selected
|
| 253 |
+
with gr.Row():
|
| 254 |
+
# Only show beam size input if Transformer model is selected
|
| 255 |
+
beam_size = gr.Number(
|
| 256 |
+
label="Beam Search Width",
|
| 257 |
+
value=5,
|
| 258 |
+
minimum=1,
|
| 259 |
+
maximum=10,
|
| 260 |
+
step=1,
|
| 261 |
+
info="Width of the beam search for the Transformer model. Higher values may improve accuracy but increase processing time.",
|
| 262 |
+
visible=use_transformer.value # Initially hidden, will be shown if Transformer is selected
|
| 263 |
+
)
|
| 264 |
+
# Add a dynamic visibility condition to show/hide beam_size based on Transformer model selection
|
| 265 |
+
use_transformer.change(
|
| 266 |
+
lambda x: gr.update(visible=x),
|
| 267 |
+
inputs=[use_transformer],
|
| 268 |
+
outputs=[beam_size]
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
# Add a batch size input for Transformer model if selected
|
| 272 |
+
with gr.Row():
|
| 273 |
+
batch_size = gr.Number(
|
| 274 |
+
label="Batch Size",
|
| 275 |
+
value=4,
|
| 276 |
+
minimum=1,
|
| 277 |
+
maximum=64,
|
| 278 |
+
step=1,
|
| 279 |
+
info="Batch size for processing. Higher values may improve performance, especially on GPU machines, but require more memory.",
|
| 280 |
+
visible=use_transformer.value # Initially hidden, will be shown if Transformer is selected
|
| 281 |
+
)
|
| 282 |
+
use_transformer.change(
|
| 283 |
+
lambda x: gr.update(visible=x),
|
| 284 |
+
inputs=[use_transformer],
|
| 285 |
+
outputs=[batch_size]
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# Single SMILES Input tab
|
| 289 |
+
gr.Markdown("## Specify Inputs")
|
| 290 |
+
with gr.Tab("Single SMILES Input"):
|
| 291 |
+
# Input area
|
| 292 |
+
smiles_input = gr.Textbox(
|
| 293 |
+
label="Enter SMILES String",
|
| 294 |
+
placeholder="E.g., CC(C)(C)S(=O)(=O)c1cc2c(Nc3ccc4scnc4c3)ccnc2cc1OCCOCCOCCOCCOCC(=O)Nc1cccc2c1CN(C1CCC(=O)NC1=O)C2=O",
|
| 295 |
+
# value="CC(C)(C)S(=O)(=O)c1cc2c(Nc3ccc4scnc4c3)ccnc2cc1OCCOCCOCCOCCOCC(=O)Nc1cccc2c1CN(C1CCC(=O)NC1=O)C2=O",
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
submit_smiles = gr.Button("Process SMILES")
|
| 299 |
+
|
| 300 |
+
# Output area
|
| 301 |
+
smiles_input_image = gr.Image(label="Input PROTAC", type="filepath") # Use None to allow SVG input
|
| 302 |
+
smiles_output_images = gr.Gallery(label="Valid Splits", columns=3)
|
| 303 |
+
smiles_output_texts = gr.Textbox(label="SMILES of the Splits", interactive=False, lines=3)
|
| 304 |
+
|
| 305 |
+
# Connect the button click event to the processing function
|
| 306 |
+
submit_smiles.click(
|
| 307 |
+
process_single_smiles,
|
| 308 |
+
inputs=[smiles_input, use_transformer, use_xgboost, beam_size],
|
| 309 |
+
outputs=[smiles_input_image, smiles_output_images, smiles_output_texts]
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# CSV file processing tab
|
| 313 |
+
with gr.Tab("Upload CSV"):
|
| 314 |
+
# File upload area
|
| 315 |
+
file_input = gr.File(label="Upload CSV File")
|
| 316 |
+
smiles_column = gr.Textbox(
|
| 317 |
+
label="Column Name for PROTAC SMILES",
|
| 318 |
+
placeholder="E.g., \"PROTAC SMILES\"",
|
| 319 |
+
# value="PROTAC SMILES",
|
| 320 |
+
)
|
| 321 |
+
submit_csv = gr.Button("Process CSV")
|
| 322 |
+
|
| 323 |
+
# Output file download area
|
| 324 |
+
download_output = gr.File(label="Download Predictions")
|
| 325 |
+
|
| 326 |
+
# Connect the button click event to the processing function
|
| 327 |
+
submit_csv.click(
|
| 328 |
+
process_csv,
|
| 329 |
+
inputs=[file_input, smiles_column, use_transformer, use_xgboost, beam_size, batch_size, num_proc],
|
| 330 |
+
outputs=[download_output]
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
csv_notes = f"""**Note:** The output CSV will contain the following columns:
|
| 334 |
+
|
| 335 |
+
- `{smiles_column}`: The original PROTAC SMILES string
|
| 336 |
+
- `default_pred_n0`: The predicted SMILES strings for the splits
|
| 337 |
+
- `model_name`: The model used for the prediction
|
| 338 |
+
"""
|
| 339 |
+
gr.Markdown(csv_notes)
|
| 340 |
+
|
| 341 |
+
return demo
|
| 342 |
+
|
| 343 |
+
# Create the Gradio interface
|
| 344 |
+
# NOTE: `demo` must be a global variable, so to make the Gradio’s hot-reload system work.
|
| 345 |
+
# NOTE: Launch the app with `gradio scripts/protac_splitter_app.py` to develop it.
|
| 346 |
+
demo = create_interface()
|
| 347 |
+
|
| 348 |
+
if __name__ == "__main__":
|
| 349 |
+
# Set logging level to DEBUG for detailed output
|
| 350 |
+
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 351 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.3.0
|
| 2 |
+
aiofiles==24.1.0
|
| 3 |
+
aiohappyeyeballs==2.6.1
|
| 4 |
+
aiohttp==3.12.13
|
| 5 |
+
aiosignal==1.3.2
|
| 6 |
+
alembic==1.16.2
|
| 7 |
+
annotated-types==0.7.0
|
| 8 |
+
anyio==4.9.0
|
| 9 |
+
asttokens==3.0.0
|
| 10 |
+
attrs==25.3.0
|
| 11 |
+
cairocffi==1.7.1
|
| 12 |
+
CairoSVG==2.8.2
|
| 13 |
+
certifi==2025.6.15
|
| 14 |
+
cffi==1.17.1
|
| 15 |
+
charset-normalizer==3.4.2
|
| 16 |
+
click==8.2.1
|
| 17 |
+
colorlog==6.9.0
|
| 18 |
+
contourpy==1.3.2
|
| 19 |
+
cssselect2==0.8.0
|
| 20 |
+
cycler==0.12.1
|
| 21 |
+
datasets==3.0.0
|
| 22 |
+
decorator==5.2.1
|
| 23 |
+
defusedxml==0.7.1
|
| 24 |
+
dill==0.3.8
|
| 25 |
+
docstring_parser==0.16
|
| 26 |
+
evaluate==0.4.3
|
| 27 |
+
executing==2.2.0
|
| 28 |
+
fastapi==0.115.14
|
| 29 |
+
ffmpy==0.6.0
|
| 30 |
+
filelock==3.18.0
|
| 31 |
+
fonttools==4.58.4
|
| 32 |
+
frozenlist==1.7.0
|
| 33 |
+
fsspec==2024.6.1
|
| 34 |
+
gradio==5.35.0
|
| 35 |
+
gradio_client==1.10.4
|
| 36 |
+
groovy==0.1.2
|
| 37 |
+
h11==0.16.0
|
| 38 |
+
hf-xet==1.1.5
|
| 39 |
+
httpcore==1.0.9
|
| 40 |
+
httpx==0.28.1
|
| 41 |
+
huggingface-hub==0.33.1
|
| 42 |
+
idna==3.10
|
| 43 |
+
imbalanced-learn==0.13.0
|
| 44 |
+
imblearn==0.0
|
| 45 |
+
iniconfig==2.1.0
|
| 46 |
+
ipython==9.4.0
|
| 47 |
+
ipython_pygments_lexers==1.1.1
|
| 48 |
+
jedi==0.19.2
|
| 49 |
+
Jinja2==3.1.6
|
| 50 |
+
joblib==1.5.1
|
| 51 |
+
jsonargparse==4.40.0
|
| 52 |
+
kiwisolver==1.4.8
|
| 53 |
+
lightning-utilities==0.14.3
|
| 54 |
+
llvmlite==0.44.0
|
| 55 |
+
Mako==1.3.10
|
| 56 |
+
markdown-it-py==3.0.0
|
| 57 |
+
MarkupSafe==3.0.2
|
| 58 |
+
matplotlib==3.10.3
|
| 59 |
+
matplotlib-inline==0.1.7
|
| 60 |
+
mdurl==0.1.2
|
| 61 |
+
mpmath==1.3.0
|
| 62 |
+
multidict==6.6.3
|
| 63 |
+
multiprocess==0.70.16
|
| 64 |
+
networkx==3.1
|
| 65 |
+
numba==0.61.0
|
| 66 |
+
numpy==1.26.4
|
| 67 |
+
optuna==4.2.0
|
| 68 |
+
ordered-set==4.1.0
|
| 69 |
+
orjson==3.10.18
|
| 70 |
+
packaging==25.0
|
| 71 |
+
pandas==2.2.2
|
| 72 |
+
parso==0.8.4
|
| 73 |
+
pexpect==4.9.0
|
| 74 |
+
pillow==11.3.0
|
| 75 |
+
pluggy==1.6.0
|
| 76 |
+
prompt_toolkit==3.0.51
|
| 77 |
+
propcache==0.3.2
|
| 78 |
+
psutil==7.0.0
|
| 79 |
+
ptyprocess==0.7.0
|
| 80 |
+
pure_eval==0.2.3
|
| 81 |
+
pyarrow==20.0.0
|
| 82 |
+
pycparser==2.22
|
| 83 |
+
pydantic==2.11.7
|
| 84 |
+
pydantic_core==2.33.2
|
| 85 |
+
pydub==0.25.1
|
| 86 |
+
Pygments==2.19.2
|
| 87 |
+
PyLaTeX==1.4.2
|
| 88 |
+
pyparsing==3.2.3
|
| 89 |
+
pytest==8.4.1
|
| 90 |
+
python-dateutil==2.9.0.post0
|
| 91 |
+
python-multipart==0.0.20
|
| 92 |
+
pytz==2025.2
|
| 93 |
+
PyYAML==6.0.2
|
| 94 |
+
rdkit==2024.9.4
|
| 95 |
+
regex==2024.11.6
|
| 96 |
+
requests==2.32.4
|
| 97 |
+
rich==14.0.0
|
| 98 |
+
ruff==0.12.1
|
| 99 |
+
safehttpx==0.1.6
|
| 100 |
+
safetensors==0.5.3
|
| 101 |
+
scikit-learn==1.6.1
|
| 102 |
+
scipy==1.14.1
|
| 103 |
+
seaborn==0.13.2
|
| 104 |
+
semantic-version==2.10.0
|
| 105 |
+
setuptools==80.9.0
|
| 106 |
+
shellingham==1.5.4
|
| 107 |
+
shtab==1.7.2
|
| 108 |
+
six==1.17.0
|
| 109 |
+
sklearn-compat==0.1.3
|
| 110 |
+
sniffio==1.3.1
|
| 111 |
+
SQLAlchemy==2.0.41
|
| 112 |
+
stack-data==0.6.3
|
| 113 |
+
starlette==0.46.2
|
| 114 |
+
sympy==1.13.1
|
| 115 |
+
threadpoolctl==3.6.0
|
| 116 |
+
tinycss2==1.4.0
|
| 117 |
+
tokenizers==0.19.1
|
| 118 |
+
tomlkit==0.13.3
|
| 119 |
+
torch==2.6.0
|
| 120 |
+
torchmetrics==1.7.3
|
| 121 |
+
tqdm==4.67.1
|
| 122 |
+
traitlets==5.14.3
|
| 123 |
+
transformers==4.44.2
|
| 124 |
+
trl==0.10.1
|
| 125 |
+
typeguard==4.4.4
|
| 126 |
+
typer==0.16.0
|
| 127 |
+
typing-inspection==0.4.1
|
| 128 |
+
typing_extensions==4.14.0
|
| 129 |
+
tyro==0.9.25
|
| 130 |
+
tzdata==2025.2
|
| 131 |
+
urllib3==2.5.0
|
| 132 |
+
uvicorn==0.35.0
|
| 133 |
+
wcwidth==0.2.13
|
| 134 |
+
webencodings==0.5.1
|
| 135 |
+
websockets==15.0.1
|
| 136 |
+
xgboost==3.0.1
|
| 137 |
+
xxhash==3.5.0
|
| 138 |
+
yarl==1.20.1
|