ribesstefano committed on
Commit 9dd777e · 1 Parent(s): 7ca0099

Setup the spaces app

Files changed (37)
  1. README.md +8 -5
  2. protac_splitter/__init__.py +11 -0
  3. protac_splitter/chemoinformatics.py +487 -0
  4. protac_splitter/data/__init__.py +0 -0
  5. protac_splitter/data/curation/__init__.py +11 -0
  6. protac_splitter/data/curation/bond_adjustments.py +407 -0
  7. protac_splitter/data/curation/curation.py +894 -0
  8. protac_splitter/data/curation/mapping_utils.py +77 -0
  9. protac_splitter/data/curation/substructure_extraction.py +586 -0
  10. protac_splitter/data/generation/__init__.py +11 -0
  11. protac_splitter/data/generation/functional_groups.py +400 -0
  12. protac_splitter/data/generation/generation.py +277 -0
  13. protac_splitter/display_utils.py +199 -0
  14. protac_splitter/drawing_utils.py +177 -0
  15. protac_splitter/evaluation.py +495 -0
  16. protac_splitter/fixing_functions.py +355 -0
  17. protac_splitter/graphs/README.md +114 -0
  18. protac_splitter/graphs/__init__.py +0 -0
  19. protac_splitter/graphs/e3_clustering.py +321 -0
  20. protac_splitter/graphs/edge_classifier.py +582 -0
  21. protac_splitter/graphs/edge_features.py +293 -0
  22. protac_splitter/graphs/splitting_algorithms.py +512 -0
  23. protac_splitter/graphs/utils.py +67 -0
  24. protac_splitter/graphs_utils.py +190 -0
  25. protac_splitter/llms/__init__.py +0 -0
  26. protac_splitter/llms/data_utils.py +296 -0
  27. protac_splitter/llms/evaluation.py +169 -0
  28. protac_splitter/llms/hf_utils.py +36 -0
  29. protac_splitter/llms/model_utils.py +256 -0
  30. protac_splitter/llms/training.py +869 -0
  31. protac_splitter/llms/training_causal_model.py +87 -0
  32. protac_splitter/llms/training_mlm_model.py +287 -0
  33. protac_splitter/llms/training_rl_models.py +406 -0
  34. protac_splitter/protac_cheminformatics.py +120 -0
  35. protac_splitter/protac_splitter.py +370 -0
  36. protac_splitter_app.py +351 -0
  37. requirements.txt +138 -0
README.md CHANGED
@@ -1,14 +1,17 @@
  ---
- title: PROTAC Splitter
- emoji: 👁
- colorFrom: gray
+ title: PROTAC-Splitter
+ emoji: ✂️
+ colorFrom: green
  colorTo: indigo
  sdk: gradio
  sdk_version: 5.35.0
- app_file: app.py
+ python_version: 3.10
+ app_file: protac_splitter_app.py
  pinned: false
  license: mit
  short_description: App to split given PROTACs into their substructures.
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # PROTAC-Splitter
+
+ This repository contains a program to split PROTAC molecules into their substructures.
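
For illustration, a minimal usage sketch of the package added in this commit. It assumes the repository is on the Python path and that `split_protac` (re-exported in `protac_splitter/__init__.py` below) accepts a PROTAC SMILES string; its exact signature and return format are not shown in this diff, so treat this as a sketch rather than the app's actual entry point:

    # Hypothetical usage sketch; split_protac's real signature may differ.
    from protac_splitter import split_protac

    protac_smiles = 'CC1=CC=CC=C1'  # placeholder: a real PROTAC SMILES goes here
    substructures = split_protac(protac_smiles)
    print(substructures)  # expected: POI ligand, linker, and E3 binder substructures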
protac_splitter/__init__.py ADDED
@@ -0,0 +1,11 @@
""" PROTAC Splitter package for splitting PROTAC SMILES into substructures."""
from protac_splitter.protac_splitter import split_protac
from protac_splitter.fixing_functions import fix_prediction
from protac_splitter.graphs.splitting_algorithms import split_protac_graph_based
from protac_splitter.evaluation import (
    check_reassembly,
    split_prediction,
)

__version__ = "1.0.0"
__author__ = "Stefano Ribes and Anders Källberg"
protac_splitter/chemoinformatics.py ADDED
@@ -0,0 +1,487 @@
""" Chemoinformatics utilities for PROTAC Splitter. """
import logging
from typing import List, Union, Optional, Literal
from multiprocessing import Process, Queue
from hashlib import sha256

from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator


def GetSubstructMatchesWorker(q, mol, substruct, useChirality, maxMatches):
    """ Worker function to get substructure matches in a separate process. """
    q.put(list(mol.GetSubstructMatches(
        substruct,
        useChirality=useChirality,
        maxMatches=maxMatches,
    )))


def GetSubstructMatchesWithTimeout(
    mol: Chem.Mol,
    substruct: Chem.Mol,
    useChirality: bool = True,
    maxMatches: int = 50,
    timeout: Union[int, float] = 10,
) -> Optional[List[List[int]]]:
    """ Get substructure matches with a timeout.

    Args:
        mol (Chem.Mol): The molecule to search for substructure matches.
        substruct (Chem.Mol): The substructure to search for in the molecule.
        useChirality (bool, optional): Whether to use chirality in the substructure search. Defaults to True.
        maxMatches (int, optional): The maximum number of matches to return. Defaults to 50.
        timeout (int | float, optional): The timeout in seconds. Defaults to 10.

    Returns:
        Optional[List[List[int]]]: A list of lists containing the atom indices of the substructure matches. Returns None if the search times out or fails.
    """
    q = Queue()
    p = Process(
        target=GetSubstructMatchesWorker,
        args=(q, mol, substruct, useChirality, maxMatches),
    )
    p.start()
    p.join(timeout)

    if p.is_alive():
        p.terminate()
        p.join()
        return None
    return q.get()


def remove_stereo(smiles: str) -> Optional[str]:
    """
    Remove stereochemistry from a SMILES string.

    Args:
        smiles (str): The input SMILES string.

    Returns:
        Optional[str]: The SMILES string with stereochemistry removed, or None on failure.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        Chem.rdmolops.RemoveStereochemistry(mol)
        return Chem.MolToSmiles(mol)
    except Exception as e:
        logging.warning(f"Error removing stereochemistry: {e}")
        return None


def get_mol(smiles: str, remove_stereo: bool = False) -> Chem.Mol:
    """
    Get a molecule object from a SMILES string.

    Args:
        smiles (str): The SMILES string representing the molecule.
        remove_stereo (bool, optional): Whether to strip stereochemistry from the parsed molecule. Defaults to False.

    Returns:
        Chem.Mol: The molecule object, or None if parsing fails.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None and remove_stereo:
        Chem.rdmolops.RemoveStereochemistry(mol)
    return mol


def canonize_smarts(smarts: str) -> Optional[str]:
    """
    Cleans a SMARTS string by converting it to a canonical SMARTS representation.

    NOTE: It might not work for complex patterns: https://github.com/rdkit/rdkit/discussions/6929

    Args:
        smarts (str): The input SMARTS string.

    Returns:
        Optional[str]: The cleaned SMARTS string, or None if parsing fails.
    """
    mol = Chem.MolFromSmarts(smarts)

    if mol is None:
        return None
    canonical_smarts = Chem.MolToSmarts(Chem.MolFromSmiles(Chem.MolToSmiles(mol), sanitize=False))
    return canonical_smarts


def smiles2mol(smiles: str) -> Chem.Mol:
    """Converts a SMILES string to an RDKit molecule object.

    Args:
        smiles (str): The input SMILES string.

    Returns:
        Chem.Mol: The RDKit molecule object.
    """
    return Chem.MolFromSmiles(smiles)


def mol2smiles(mol: Chem.Mol) -> str:
    """Converts an RDKit molecule object to a SMILES string.

    Args:
        mol (Chem.Mol): The RDKit molecule object.

    Returns:
        str: The SMILES string.
    """
    return Chem.MolToSmiles(mol)


def canonize_smiles(smiles: str) -> Optional[str]:
    """ Canonizes a SMILES string by converting it to its canonical SMILES representation.

    Args:
        smiles (str): The input SMILES string.

    Returns:
        Optional[str]: The canonized SMILES string, or None on failure.
    """
    if smiles is None:
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception as e:
        print(f"Error: {e}")
        return None
    if mol is None:
        return None
    try:
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        return None


def canonize(x: Union[str, Chem.Mol]) -> Union[str, Chem.Mol]:
    """ Canonizes a SMILES string or RDKit molecule object.

    Args:
        x: The input SMILES string or RDKit molecule object.

    Returns:
        str | Chem.Mol: The canonized SMILES string or RDKit molecule object, according to the input type.
    """
    if x is None:
        return None
    if isinstance(x, str):
        return canonize_smiles(x)
    return Chem.MolFromSmiles(Chem.MolToSmiles(x, canonical=True))


def compute_RDKitFP(
    smiles: Union[str, List[str], List[Chem.Mol]],
    maxPath: int = 7,
    fpSize: int = 2048,
) -> List:
    """
    Compute RDKit fingerprints for a given list of SMILES strings or RDKit molecules.

    Args:
        smiles (Union[str, List[str], List[Chem.Mol]]): A single SMILES string, a list of SMILES strings,
            or a list of RDKit molecules.
        maxPath (int, optional): The maximum path length for the fingerprints. Defaults to 7.
        fpSize (int, optional): The size of the fingerprint vector. Defaults to 2048.

    Returns:
        List: A list of RDKit count fingerprints computed from the input SMILES strings or molecules.
    """
    if isinstance(smiles, str):
        smiles = [smiles]  # Wrap a single SMILES string into a list
    if isinstance(smiles[0], str):
        mols = [get_mol(smi) for smi in smiles]
    else:
        mols = smiles  # assume mols were fed instead
    rdgen = rdFingerprintGenerator.GetRDKitFPGenerator(
        maxPath=maxPath, fpSize=fpSize)
    fps = [rdgen.GetCountFingerprint(mol) for mol in mols]
    return fps


def remove_dummy_atoms(mol: Union[str, Chem.Mol], canonical: bool = True) -> Union[str, Chem.Mol]:
    """
    Removes all dummy atoms (attachment points) from a molecule.

    Args:
        mol: An RDKit Mol object or SMILES string with dummy atoms.
        canonical: Whether to return a canonical SMILES when the input is a SMILES string. Defaults to True.

    Returns:
        A new RDKit Mol object (or SMILES string, matching the input type) without dummy atoms.
    """
    return_smiles = False
    if isinstance(mol, str):
        return_smiles = True
        mol = Chem.MolFromSmiles(mol)

    if mol is None:
        return None

    # Remove all dummy atoms with a query
    mol_no_dummy = Chem.DeleteSubstructs(mol, Chem.MolFromSmarts('[#0]'))

    if mol_no_dummy is None:
        # --------------------------------------------------------------------------
        # Fallback approach: edit the molecule and remove the dummy atoms directly
        # --------------------------------------------------------------------------
        # Create an editable molecule to remove atoms
        editable_mol = Chem.EditableMol(mol)

        # List of atoms to remove (dummy atoms have atomic number 0)
        dummy_atoms = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomicNum() == 0]

        # Remove dummy atoms
        for atom_idx in sorted(dummy_atoms, reverse=True):  # Remove from the highest index to avoid index shifts
            editable_mol.RemoveAtom(atom_idx)

        if editable_mol is None:
            return None

        # Return the modified molecule
        if return_smiles:
            return Chem.MolToSmiles(editable_mol.GetMol())
        editable_mol = editable_mol.GetMol()
        editable_mol.UpdatePropertyCache()
        return editable_mol
        # --------------------------------------------------------------------------

    # Return the modified molecule
    if return_smiles:
        return Chem.MolToSmiles(mol_no_dummy, canonical=canonical)
    return mol_no_dummy


def dummy2query(mol: Chem.Mol) -> Chem.Mol:
    """ Converts dummy atoms to query atoms, so that a molecule with attachment points can be used in HasSubstructMatch.

    Args:
        mol: The molecule to convert.

    Returns:
        The molecule with dummy atoms converted to query atoms.
    """
    if mol is None:
        return None
    p = Chem.AdjustQueryParameters.NoAdjustments()
    p.makeDummiesQueries = True
    return Chem.AdjustQueryProperties(mol, p)


def get_substr_match(
    protac_mol: Chem.Mol,
    substr: Chem.Mol,
    max_allowed_fragments: int = 1,
    replace: Literal['core', 'sidechains'] = 'core',
    useChirality: bool = True,
) -> bool:
    """ Check if a molecule contains a substructure match with a given molecule.
    Compared to RDKit's HasSubstructMatch, this function also checks the number of fragments left when replacing the substr in the PROTAC.

    Args:
        protac_mol (Chem.Mol): The PROTAC molecule.
        substr (Chem.Mol): The substructure molecule.
        max_allowed_fragments (int, optional): The maximum number of fragments allowed when replacing the substr in the PROTAC. Defaults to 1. Example when equal to 1: if removing the warhead, a single fragment should remain.
        replace (str, optional): Whether to replace the 'core' or the 'sidechains' of the PROTAC. Defaults to 'core'.
        useChirality (bool, optional): Whether to use chirality in the substructure match. Defaults to True.

    Returns:
        bool: True if the PROTAC contains a substructure match with the given molecule and the fragment count matches, False otherwise.
    """
    # Count the number of fragments when replacing the substr in the PROTAC
    if replace == 'core':
        fragments = Chem.ReplaceCore(protac_mol, dummy2query(substr), useChirality=useChirality)
    elif replace == 'sidechains':
        fragments = Chem.ReplaceSidechains(protac_mol, dummy2query(substr), useChirality=useChirality)
    else:
        raise ValueError(f"replace argument should be either 'core' or 'sidechains', provided: {replace}")
    # Check if the number of fragments is equal to the max allowed fragments
    if fragments is None:
        return False
    try:
        fragments = Chem.GetMolFrags(fragments, sanitizeFrags=False)
    except Exception as e:
        print(e)
        return False
    return len(fragments) == max_allowed_fragments


def remove_attach_atom(mol: Chem.Mol, attach_id: int, sanitize: bool = False) -> Chem.Mol:
    """ Removes the atom with the specified attachment id from the molecule.

    Example:

        >>> remove_attach_atom(Chem.MolFromSmiles('CC[*:1]'), 1)
        CC

    There are no checks on the molecule, so it is assumed it is not None.

    Args:
        mol (Chem.Mol): The molecule.
        attach_id (int): The attachment id of the atom to remove.
        sanitize (bool, optional): Whether to sanitize the molecule after removing the atom. When used in the `fix_prediction` function, it is used to "remove" substructures, so there is no need to have them sanitized. Default: False.

    Returns:
        (Chem.Mol) The molecule with the atom removed.
    """
    atoms_to_remove = []
    for atom in mol.GetAtoms():
        if atom.GetAtomicNum() == 0:  # Dummy atom
            map_num = atom.GetAtomMapNum()
            if map_num == attach_id:  # Targeting only [*:attach_id]
                atoms_to_remove.append(atom.GetIdx())

    # Remove atoms using an EditableMol
    editable_mol = Chem.EditableMol(mol)
    for idx in sorted(atoms_to_remove, reverse=True):  # Remove from highest index to avoid shifting
        editable_mol.RemoveAtom(idx)

    # Convert back to a molecule
    new_mol = editable_mol.GetMol()
    if sanitize:
        Chem.SanitizeMol(new_mol)
    return new_mol


def get_bond_idx(smi: str, bonds_start_end_atoms: List[List[int]]) -> List[int]:
    """
    Get the indices of bonds in a molecule that match the given start and end atom indices.

    Args:
        smi (str): The SMILES representation of the molecule.
        bonds_start_end_atoms (List[List[int]]): A list of lists containing the start and end atom indices of the bonds to search for.

    Returns:
        List[int]: A list of bond indices that match the given start and end atom indices.
    """
    mol = Chem.MolFromSmiles(smi)

    bond_indices = []

    for bond in mol.GetBonds():
        begin_idx = bond.GetBeginAtomIdx()
        end_idx = bond.GetEndAtomIdx()

        if [begin_idx, end_idx] in bonds_start_end_atoms or [end_idx, begin_idx] in bonds_start_end_atoms:
            bond_indices.append(bond.GetIdx())
        elif (begin_idx, end_idx) in bonds_start_end_atoms or (end_idx, begin_idx) in bonds_start_end_atoms:
            bond_indices.append(bond.GetIdx())

    return bond_indices


def get_mol_id(smiles: str) -> str | None:
    """ Get the hash of a given SMILES string.

    Args:
        smiles (str): The SMILES string to hash.

    Returns:
        str | None: The hash of the SMILES string. None if the function failed.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        Chem.RemoveStereochemistry(mol)
    except Exception as e:
        logging.warning(f"Error while removing stereochemistry: {e}")
        logging.warning(f"SMILES: {smiles}")
        return None

    # Get the InChIKey for the molecule
    inchi_key = Chem.MolToInchiKey(mol)
    smiles = Chem.MolToSmiles(mol, canonical=True)

    # Encode the InChIKey and SMILES to create a unique identifier
    return sha256((inchi_key + smiles).encode()).hexdigest()


def get_atom_idx_at_attachment(
    protac: Chem.Mol,
    substruct: Chem.Mol,
    linker: Optional[Chem.Mol] = None,
    timeout: Optional[Union[int, float]] = None,
    return_dict: bool = False,
    verbose: int = 0,
) -> List[int]:
    """ Get the atom index of the attachment point of a substructure in the PROTAC molecule.

    Args:
        protac: The PROTAC molecule.
        substruct: The substructure of the PROTAC that contains the attachment point, e.g., the POI or E3 ligase.
        linker: The linker molecule.
        timeout: Timeout in seconds for the substructure searches. If None and no linker is given, a default of 60 seconds is used.
        return_dict: Whether to return a dictionary mapping 'substruct' and 'linker' to their attachment atom indices instead of a list.
        verbose: Verbosity level.

    Returns:
        List[int]: The two atom indices at the attachment point.
    """
    if linker is None:
        # Get the "other" substructure, i.e., replace the side chain of the PROTAC using the substruct
        linker = Chem.DeleteSubstructs(protac, remove_dummy_atoms(substruct), useChirality=True)
        if timeout is None:
            timeout = 60
            logging.warning(f'No timeout set when linker is not provided, using default value of {timeout} seconds.')

    substruct_match = set(protac.GetSubstructMatch(dummy2query(substruct), useChirality=True))
    if verbose:
        print(f'Substruct match: {substruct_match}')

    linker_no_dummy = remove_dummy_atoms(linker)
    if verbose:
        print('Linker without dummy atoms found.')

    max_matches = 2
    linker_match = set()
    shared_atoms = set()

    # NOTE: The following is a hacky way to speed up the search for linker
    # matches. In fact, the linker can be quite short, so it might match in
    # multiple places of the PROTAC molecule.
    # If the number of max matches in GetSubstructMatches is low, then the
    # search tends to be faster, but imprecise. However, we are interested in
    # the intersection of the matches, so we can progressively increase the
    # number of max matches until we find a single atom in common.
    while len(shared_atoms) != 1 and max_matches <= 50:
        if timeout is None:
            linker_matches = list(protac.GetSubstructMatches(linker_no_dummy, useChirality=True, maxMatches=max_matches))
        else:
            linker_matches = GetSubstructMatchesWithTimeout(protac, linker_no_dummy, useChirality=True, maxMatches=max_matches, timeout=timeout)
        if verbose:
            print(f'Linker matches: {linker_matches}')

        if not linker_matches:
            # return None
            linker_match = set()
            shared_atoms = set()
            max_matches += 1
            continue

        for match in linker_matches:
            shared_atoms = set(match) & set(substruct_match)
            linker_match = match
            if len(shared_atoms) == 1:
                if verbose:
                    print(f'Shared atoms: {list(shared_atoms)}')
                break

        if len(shared_atoms) != 1:
            linker_match = set()
            shared_atoms = set()
            max_matches += 1

    if not shared_atoms:
        if verbose:
            print('No shared atoms found.')
        return None

    attachment_idx = list(shared_atoms)
    attachments = {'substruct': attachment_idx[0]}

    # Get the other atom at the attachment point that is NOT in the linker
    for neighbor in protac.GetAtomWithIdx(attachment_idx[0]).GetNeighbors():
        if neighbor.GetIdx() not in linker_match:
            attachment_idx.append(neighbor.GetIdx())
            attachments['linker'] = neighbor.GetIdx()
            break

    if return_dict:
        return attachments
    return attachment_idx
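
The timeout helper above wraps RDKit's `GetSubstructMatches` in a separate process, which matters because highly symmetric or repetitive queries can make substructure enumeration very slow. A minimal, runnable sketch of calling it with toy molecules (on spawn-based platforms such as Windows and macOS, `multiprocessing` requires the call to live under a `__main__` guard):

    from rdkit import Chem
    from protac_splitter.chemoinformatics import GetSubstructMatchesWithTimeout

    if __name__ == '__main__':
        mol = Chem.MolFromSmiles('CCOC(=O)c1ccccc1')
        query = Chem.MolFromSmiles('c1ccccc1')
        # Returns a list of atom-index tuples, or None if the search
        # exceeds the timeout (here, 5 seconds).
        matches = GetSubstructMatchesWithTimeout(mol, query, timeout=5)
        print(matches)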
protac_splitter/data/__init__.py ADDED
File without changes
protac_splitter/data/curation/__init__.py ADDED
@@ -0,0 +1,11 @@
from .mapping_utils import update_dictionary
from .curation import (
    split_protacs,
    iterative_protac_splitting,
)

__all__ = [
    'update_dictionary',
    'split_protacs',
    'iterative_protac_splitting',
]
protac_splitter/data/curation/bond_adjustments.py ADDED
@@ -0,0 +1,407 @@
""" Adjusts amide and ester bonds in PROTAC substructures. """
from typing import Tuple, Dict

from rdkit import Chem

from protac_splitter.chemoinformatics import (
    dummy2query,
    canonize,
)
from protac_splitter.display_utils import display_mol
from protac_splitter.evaluation import check_reassembly


def adjust_amide_bond(
    substruct: Chem.Mol,
    linker: Chem.Mol,
    substruct_attachment_id: int,
    verbose: int = 0,
) -> Tuple[Chem.Mol, Chem.Mol]:
    """
    Adjust the amide bond between the substruct and linker substructures.
    Handles the case when neighboring atoms of the amide bond are dummy atoms, which represent attachment points.
    The linker will be modified with the required additional atoms.

    Args:
        substruct: The substructure (e.g., the POI ligand) that contains the amide bond.
        linker: The linker molecule that connects the substruct to the E3 ligase.
        substruct_attachment_id: The attachment point ID in the substruct substructure. E.g., 1 for the POI, as in "[*:1]".
        verbose: Verbosity level.

    Returns:
        Tuple[Chem.Mol, Chem.Mol]: The adjusted substruct and linker molecules, in that order.
    """

    # Pseudo-code of the algorithm:
    """
    ```python
    # Check if the amide bond (N-C=O) is in the substructure
    if "N-C(=O)" in substruct:
        if neighbor("N-C(=O)") == "[*:substruct]":
            # If the neighboring atom of the amide bond is a dummy atom, i.e., attachment point
            mark_protac_as_wrong("[PROTAC]")

            # Identify the bond to split, i.e., the nitrogen-carbon bond, and split
            "[*:substruct]-[<optional neighboring atom>]-N-[*:tmp]", "[*:tmp]-C(=O)-[rest of the PROTAC]" = split_PROTAC_at("N-C")

            "[Linker]-N-[*:tmp]" = join("[Linker]-[*:substruct]", "[*:substruct]-N-[*:tmp]")

            rename_attachment_point("[*:tmp]-C(=O)-[rest of the PROTAC]")
            rename_attachment_point("[Linker]-N-[*:tmp]")

        elif neighbor(neighbor("N-C(=O)")) == "[*:substruct]":
            # If the second-order neighbor of the amide bond is a dummy atom, i.e., attachment point
            mark_protac_as_wrong("[PROTAC]")

            # Do as above
            # Identify the bond to split, i.e., the nitrogen-carbon bond, and split
            "[*:substruct]-N-[*:tmp]", "[*:tmp]-C(=O)-[rest of the PROTAC]" = split_PROTAC_at("N-C")

            "[Linker]-N-[*:tmp]" = join("[Linker]-[*:substruct]", "[*:substruct]-N-[*:tmp]")

            rename_attachment_point("[*:tmp]-C(=O)-[rest of the PROTAC]")
            rename_attachment_point("[Linker]-N-[*:tmp]")
    ```
    """

    # Convert dummy atoms in substruct to query atoms for substructure search
    query_substruct = dummy2query(substruct)

    # Identify amide bond (N-C=O) in the substruct substructure
    amide_pattern = Chem.MolFromSmarts("[NX3][CX3](=[OX1])")
    amide_matches = query_substruct.GetSubstructMatches(amide_pattern, useChirality=True)

    if not amide_matches:
        return substruct, linker  # No amide bond found, return the original substruct

    side_atom = None
    nitrogen_idx_found, carbonyl_idx_found = None, None
    for match in amide_matches:
        nitrogen_idx, carbonyl_idx = match[0], match[1]
        nitrogen_atom = query_substruct.GetAtomWithIdx(nitrogen_idx)
        carbonyl_atom = query_substruct.GetAtomWithIdx(carbonyl_idx)

        for amide_atom in [nitrogen_atom, carbonyl_atom]:
            # Check neighboring atoms for attachment points
            # NOTE: The dummy atoms representing attachment points have atomic number 0
            for neighbor in amide_atom.GetNeighbors():
                if neighbor.GetAtomicNum() == 0:
                    nitrogen_idx_found = nitrogen_idx
                    carbonyl_idx_found = carbonyl_idx
                    side_atom = "N" if amide_atom == nitrogen_atom else "C"
                    break

            # If the previous search failed, check the neighbors of the neighboring
            # atoms (second-order neighbors)
            if nitrogen_idx_found is None or carbonyl_idx_found is None:
                for neighbor in amide_atom.GetNeighbors():
                    for second_neighbor in neighbor.GetNeighbors():
                        if second_neighbor.GetIdx() == carbonyl_idx or second_neighbor.GetIdx() == nitrogen_idx:
                            continue  # Skip the opposite atom from the amide bond

                        if second_neighbor.GetAtomicNum() == 0:
                            nitrogen_idx_found = nitrogen_idx
                            carbonyl_idx_found = carbonyl_idx
                            side_atom = "N" if amide_atom == nitrogen_atom else "C"
                            break
                    else:
                        break

    if nitrogen_idx_found is None or carbonyl_idx_found is None or side_atom is None:
        return substruct, linker

    # Split the amide bond and adjust
    dummy_label = 3
    dummy_labels = [(dummy_label, dummy_label)]  # The E3 and substruct will have 1 and 2, so we need a third one
    amide_bond_idx = query_substruct.GetBondBetweenAtoms(nitrogen_idx_found, carbonyl_idx_found).GetIdx()
    fragments = Chem.FragmentOnBonds(query_substruct, [amide_bond_idx], addDummies=True, dummyLabels=dummy_labels)

    # Get the fragments resulting from bond breaking
    try:
        mol_frags = Chem.GetMolFrags(fragments, asMols=True, sanitizeFrags=True)
    except Exception as e:
        print(e)
        return substruct, linker

    # Identify the "[*:substruct][N or C][3*]" fragment; the other one will be the "truncated" substruct
    amide_fragment_pattern = Chem.MolFromSmarts(f"[*:{substruct_attachment_id}][{side_atom}][{dummy_label}*]")
    amide_fragment = None
    substruct_fixed = None

    if verbose:
        print(f'Attachment point: *:{substruct_attachment_id}')
        print('Substruct:')
        display_mol(substruct)
        print('Linker:')
        display_mol(linker)

    for frag in mol_frags:
        if frag.HasSubstructMatch(dummy2query(amide_fragment_pattern)):
            amide_fragment = frag
            if verbose:
                print('Amide fragment:')
                display_mol(frag)
        else:
            if verbose:
                print('Substruct fragment:')
                display_mol(frag)
            substruct_fixed = frag

    if amide_fragment is None or substruct_fixed is None:
        return substruct, linker

    # In order for the function to be usable "on linkers", we need to make sure
    # that the amide fragment contains the attachment point of the substruct.
    # If not, there's nothing to do.
    if f'[*:{substruct_attachment_id}]' not in Chem.MolToSmiles(amide_fragment, canonical=True):
        return substruct, linker

    # Rename the "[3*]" attachment point on the amide fragment to "[*:3]"
    amide_fragment_smiles = Chem.MolToSmiles(amide_fragment, canonical=True)
    amide_fragment_smiles = amide_fragment_smiles.replace(f'[{dummy_label}*]', f'[*:{dummy_label}]')
    amide_fragment_smiles = canonize(amide_fragment_smiles)
    amide_fragment = Chem.MolFromSmiles(amide_fragment_smiles)

    # Use molzip to join the linker and the fragment at the original attachment point
    linker_fixed = Chem.molzip(linker, amide_fragment)

    # Rename the "[*:3]" attachment point back to the original attachment point on the linker
    linker_fixed_smiles = Chem.MolToSmiles(linker_fixed, canonical=True)
    linker_fixed_smiles = linker_fixed_smiles.replace(f'[*:{dummy_label}]', f'[*:{substruct_attachment_id}]')
    linker_fixed_smiles = canonize(linker_fixed_smiles)
    linker_fixed = Chem.MolFromSmiles(linker_fixed_smiles)

    # Rename the "[3*]" attachment point back to the original attachment point on the substruct
    substruct_fixed_smiles = Chem.MolToSmiles(substruct_fixed, canonical=True)
    substruct_fixed_smiles = substruct_fixed_smiles.replace(f'[{dummy_label}*]', f'[*:{substruct_attachment_id}]')
    substruct_fixed_smiles = canonize(substruct_fixed_smiles)
    substruct_fixed = Chem.MolFromSmiles(substruct_fixed_smiles)

    return substruct_fixed, linker_fixed


def adjust_amide_bonds_in_substructs(
    substructs: Dict[str, str],
    protac_smiles: str,
    poi_attachment_id: int = 1,
    e3_attachment_id: int = 2,
) -> Dict[str, str]:
    """ Adjusts the amide bonds in the substructures of a PROTAC. Just a wrapper function to apply it to multiple substructures.

    Args:
        substructs: The substructures of the PROTAC. A dictionary of SMILES with keys 'poi', 'linker', and 'e3'.
        protac_smiles: The SMILES of the PROTAC for checking reassembly.
        poi_attachment_id: The attachment point ID of the POI ligand. Defaults to 1.
        e3_attachment_id: The attachment point ID of the E3 binder. Defaults to 2.

    Returns:
        The updated substructures dictionary.
    """
    poi_mol = Chem.MolFromSmiles(substructs['poi'])
    e3_mol = Chem.MolFromSmiles(substructs['e3'])
    linker_mol = Chem.MolFromSmiles(substructs['linker'])

    # Fix the amide group on the POI ligand
    poi_mol, linker_mol = adjust_amide_bond(poi_mol, linker_mol, poi_attachment_id)
    poi_smiles = Chem.MolToSmiles(poi_mol, canonical=True)
    linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
    e3_smiles = substructs['e3']
    if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
        return substructs

    # Fix the amide group on the E3 binder
    e3_mol, linker_mol = adjust_amide_bond(e3_mol, linker_mol, e3_attachment_id)
    e3_smiles = Chem.MolToSmiles(e3_mol, canonical=True)
    linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
    if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
        return substructs

    # Fix the amide group on the linker, E3 side
    linker_mol, e3_mol = adjust_amide_bond(linker_mol, e3_mol, e3_attachment_id)
    e3_smiles = Chem.MolToSmiles(e3_mol, canonical=True)
    linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
    if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
        return substructs

    # Fix the amide group on the linker, POI side
    linker_mol, poi_mol = adjust_amide_bond(linker_mol, poi_mol, poi_attachment_id)
    poi_smiles = Chem.MolToSmiles(poi_mol, canonical=True)
    linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
    if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
        return substructs

    substructs['poi'] = poi_smiles
    substructs['e3'] = e3_smiles
    substructs['linker'] = linker_smiles
    return substructs


def adjust_ester_bond(
    substruct: Chem.Mol,
    linker: Chem.Mol,
    substruct_attachment_id: int,
    verbose: int = 0,
) -> Tuple[Chem.Mol, Chem.Mol]:
    """
    Adjust the ester bond between the substruct and linker substructures.
    Handles the case when neighboring atoms of the ester bond are dummy atoms, which represent attachment points.

    Args:
        substruct: The substructure (e.g., the POI ligand) that contains the ester bond.
        linker: The linker molecule that connects the substruct to the E3 ligase.
        substruct_attachment_id: The attachment point ID in the substruct substructure. E.g., 1 for the POI, as in "[*:1]".
        verbose: Verbosity level.

    Returns:
        Tuple[Chem.Mol, Chem.Mol]: The adjusted substruct and linker molecules, in that order.
    """
    # Convert dummy atoms in substruct to query atoms for substructure search
    query_substruct = dummy2query(substruct)

    # Identify ester group (COOR) in the substruct substructure
    ester_pattern = Chem.MolFromSmarts("[OX2][CX3](=[OX1])")

    ester_matches = query_substruct.GetSubstructMatches(ester_pattern)

    if not ester_matches:
        return substruct, linker  # No ester bond found, return the original substruct

    side_atom = None
    oxygen_idx_found, carbonyl_idx_found = None, None
    for match in ester_matches:
        oxygen_idx, carbonyl_idx = match[0], match[1]
        oxygen_atom = query_substruct.GetAtomWithIdx(oxygen_idx)
        carbonyl_atom = query_substruct.GetAtomWithIdx(carbonyl_idx)

        for ester_atom in [oxygen_atom, carbonyl_atom]:
            # Check neighboring atoms for attachment points
            # NOTE: The dummy atoms representing attachment points have atomic number 0
            for neighbor in ester_atom.GetNeighbors():
                if neighbor.GetAtomicNum() == 0:
                    oxygen_idx_found = oxygen_idx
                    carbonyl_idx_found = carbonyl_idx
                    side_atom = "O" if ester_atom == oxygen_atom else "C"
                    break

            # If the previous search failed, check the neighbors of the neighboring
            # atoms (second-order neighbors)
            if oxygen_idx_found is None or carbonyl_idx_found is None:
                for neighbor in ester_atom.GetNeighbors():
                    for second_neighbor in neighbor.GetNeighbors():
                        if second_neighbor.GetIdx() == carbonyl_idx or second_neighbor.GetIdx() == oxygen_idx:
                            continue  # Skip the opposite atom from the ester bond

                        if second_neighbor.GetAtomicNum() == 0:
                            oxygen_idx_found = oxygen_idx
                            carbonyl_idx_found = carbonyl_idx
                            side_atom = "O" if ester_atom == oxygen_atom else "C"
                            break
                    else:
                        break

    if oxygen_idx_found is None or carbonyl_idx_found is None or side_atom is None:
        return substruct, linker

    # Split the ester bond and adjust
    dummy_label = 3
    dummy_labels = [(dummy_label, dummy_label)]  # The E3 and substruct will have 1 and 2, so we need a third one
    ester_bond_idx = query_substruct.GetBondBetweenAtoms(oxygen_idx_found, carbonyl_idx_found).GetIdx()
    fragments = Chem.FragmentOnBonds(query_substruct, [ester_bond_idx], addDummies=True, dummyLabels=dummy_labels)

    # Get the fragments resulting from bond breaking
    try:
        mol_frags = Chem.GetMolFrags(fragments, asMols=True, sanitizeFrags=True)
    except Exception as e:
        if verbose:
            print(e)
        return substruct, linker

    # Identify the "[*:substruct][O or C][3*]" fragment; the other one will be the "truncated" substruct
    ester_fragment_pattern = Chem.MolFromSmarts(f"[*:{substruct_attachment_id}][{side_atom}][{dummy_label}*]")
    ester_fragment = None
    substruct_fixed = None

    for frag in mol_frags:
        if frag.HasSubstructMatch(dummy2query(ester_fragment_pattern)):
            ester_fragment = frag
        else:
            substruct_fixed = frag

    if ester_fragment is None or substruct_fixed is None:
        return substruct, linker

    # In order for the function to be usable "on linkers", we need to make sure
    # that the ester fragment contains the attachment point of the substruct.
    # If not, there's nothing to do.
    if f'[*:{substruct_attachment_id}]' not in Chem.MolToSmiles(ester_fragment, canonical=True):
        return substruct, linker

    # Rename the "[3*]" attachment point on the ester fragment to "[*:3]"
    ester_fragment_smiles = Chem.MolToSmiles(ester_fragment, canonical=True)
    ester_fragment_smiles = ester_fragment_smiles.replace(f'[{dummy_label}*]', f'[*:{dummy_label}]')
    ester_fragment = Chem.MolFromSmiles(ester_fragment_smiles)

    # Use molzip to join the linker and the fragment at the original attachment point
    linker_fixed = Chem.molzip(linker, ester_fragment)

    # Rename the "[*:3]" attachment point back to the original attachment point on the linker
    linker_fixed_smiles = Chem.MolToSmiles(linker_fixed, canonical=True)
    linker_fixed_smiles = linker_fixed_smiles.replace(f'[*:{dummy_label}]', f'[*:{substruct_attachment_id}]')
    linker_fixed = Chem.MolFromSmiles(linker_fixed_smiles)

    # Rename the "[3*]" attachment point back to the original attachment point on the substruct
    substruct_fixed_smiles = Chem.MolToSmiles(substruct_fixed, canonical=True)
    substruct_fixed_smiles = substruct_fixed_smiles.replace(f'[{dummy_label}*]', f'[*:{substruct_attachment_id}]')
    substruct_fixed = Chem.MolFromSmiles(substruct_fixed_smiles)

    return substruct_fixed, linker_fixed


def adjust_ester_bonds_in_substructs(
    substructs: Dict[str, str],
    protac_smiles: str,
    poi_attachment_id: int = 1,
    e3_attachment_id: int = 2,
) -> Dict[str, str]:
    """ Adjusts the ester bonds in the substructures of a PROTAC. Just a wrapper function to apply it to multiple substructures.

    Args:
        substructs: The substructures of the PROTAC. A dictionary of SMILES with keys 'poi', 'linker', and 'e3'.
        protac_smiles: The SMILES of the PROTAC for checking reassembly.
        poi_attachment_id: The attachment point ID of the POI ligand. Defaults to 1.
        e3_attachment_id: The attachment point ID of the E3 binder. Defaults to 2.

    Returns:
        The updated substructures dictionary.
    """
    poi_mol = Chem.MolFromSmiles(substructs['poi'])
    e3_mol = Chem.MolFromSmiles(substructs['e3'])
    linker_mol = Chem.MolFromSmiles(substructs['linker'])

    # Fix the ester group on the POI ligand
    poi_mol, linker_mol = adjust_ester_bond(poi_mol, linker_mol, poi_attachment_id)
    poi_smiles = Chem.MolToSmiles(poi_mol, canonical=True)
    linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
    e3_smiles = substructs['e3']
    if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
        return substructs

    # Fix the ester group on the E3 binder
    e3_mol, linker_mol = adjust_ester_bond(e3_mol, linker_mol, e3_attachment_id)
    e3_smiles = Chem.MolToSmiles(e3_mol, canonical=True)
    linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
    if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
        return substructs

    # Fix the ester group on the linker, E3 side
    linker_mol, e3_mol = adjust_ester_bond(linker_mol, e3_mol, e3_attachment_id)
    e3_smiles = Chem.MolToSmiles(e3_mol, canonical=True)
    linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
    if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
        return substructs

    # Fix the ester group on the linker, POI side
    linker_mol, poi_mol = adjust_ester_bond(linker_mol, poi_mol, poi_attachment_id)
    poi_smiles = Chem.MolToSmiles(poi_mol, canonical=True)
    linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True)
    if not check_reassembly(protac_smiles, '.'.join([poi_smiles, linker_smiles, e3_smiles])):
        return substructs

    substructs['poi'] = poi_smiles
    substructs['e3'] = e3_smiles
    substructs['linker'] = linker_smiles
    return substructs
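
Both adjustment functions rely on RDKit's `Chem.molzip` to re-join fragments at matching dummy-atom map numbers after the amide or ester bond has been re-cut. A minimal, self-contained sketch of that mechanism with toy fragments (not real PROTAC substructures):

    from rdkit import Chem

    # Two fragments that share the attachment label [*:3] are joined there:
    # the atoms bonded to the matching dummies become bonded to each other.
    linker = Chem.MolFromSmiles('CC[*:3]')
    fragment = Chem.MolFromSmiles('[*:3]NC(C)=O')
    combined = Chem.molzip(linker, fragment)
    print(Chem.MolToSmiles(combined))  # CCNC(C)=O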
protac_splitter/data/curation/curation.py ADDED
@@ -0,0 +1,894 @@
""" Curation utilities for PROTAC Splitter. """
import os
import re
from typing import Any, Dict, Optional, Union, Callable
from joblib import Parallel, delayed

from rdkit import Chem
from rdkit import DataStructs
import pandas as pd
import numpy as np
from tqdm import tqdm

from protac_splitter.chemoinformatics import (
    canonize,
    remove_dummy_atoms,
    canonize_smiles,
    get_mol_id,
    get_substr_match,
)
from protac_splitter.evaluation import check_reassembly
from protac_splitter.data.curation.substructure_extraction import (
    get_substructure_from_non_perfect_match,
    get_substructs_from_unmapped_e3_poi,
    get_substructs_from_substr_and_linker,
    get_substructs_from_mapped_linker,
    swap_attachment_points,
)
from protac_splitter.data.curation.bond_adjustments import (
    adjust_amide_bonds_in_substructs,
    adjust_ester_bonds_in_substructs,
)
from protac_splitter.data.curation.mapping_utils import update_dictionary


def check_substructs_size(
    protac_mol: Chem.Mol,
    substructs: Dict[str, str],
    size_perc_threshold: float = 0.8,
) -> bool:
    """ Check the size of the substructures in the PROTAC. If any of them is too big, return False.

    Args:
        protac_mol: The PROTAC molecule.
        substructs: The substructures to check against.
        size_perc_threshold: The maximum allowed fraction of PROTAC atoms that a single substructure may cover. Defaults to 0.8.

    Returns:
        False if any of the substructures is too big. True otherwise.
    """
    num_protac_atoms = protac_mol.GetNumAtoms()
    for key, smiles in substructs.items():
        substruct = Chem.MolFromSmiles(smiles)
        num_substruct_atoms = substruct.GetNumAtoms()
        if num_substruct_atoms / num_protac_atoms > size_perc_threshold:
            # print(f'Error: {key.upper()} is too big in the PROTAC ({num_substruct_atoms} / {num_protac_atoms} = {num_substruct_atoms / num_protac_atoms:.2%} > {size_perc_threshold:.2%})')
            # display_mol(substruct)
            # display_mol(protac_mol)
            return False
    return True


def check_linker_similarity(
    linker_smiles: str,
    pois: Union[pd.DataFrame, str],
    e3s: Union[pd.DataFrame, str],
    linkers: Optional[Union[pd.DataFrame, str]] = None,
    pois_similarity_threshold: float = 0.7,
    e3s_similarity_threshold: float = 0.7,
    linkers_similarity_threshold: float = 0.6,
    morgan_fp_generator: Optional[Callable] = None,
) -> bool:
    """ Check the similarity of the linker with all the matching POIs and E3s. If it is too similar to any of them, return False.

    Args:
        linker_smiles: The linker SMILES.
        pois: The POI ligands. Must have a 'FP' column with the Morgan fingerprints.
        e3s: The E3 binders. Must have a 'FP' column with the Morgan fingerprints.
        linkers: The known linkers to compare against, if any. Must have a 'FP' column with the Morgan fingerprints.
        pois_similarity_threshold: The similarity threshold for the POIs.
        e3s_similarity_threshold: The similarity threshold for the E3s.
        linkers_similarity_threshold: The similarity threshold for the linkers.
        morgan_fp_generator: The Morgan fingerprint generator.

    Returns:
        False if the linker is too similar to any of the POIs or E3s. True otherwise.
    """

    # Get the linker fingerprint
    if morgan_fp_generator is None:
        morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(
            radius=2,
            fpSize=2048,
            useBondTypes=True,
            includeChirality=True,
        )

    linker = Chem.MolFromSmiles(linker_smiles)
    linker_fp = morgan_fp_generator.GetFingerprint(linker)

    # Check the similarity of the linker with the POIs and E3s (use BulkTanimotoSimilarity)
    if isinstance(e3s, str):
        # Create a one-element list with the E3 fingerprint
        e3s_fps = [morgan_fp_generator.GetFingerprint(Chem.MolFromSmiles(e3s))]
    else:
        e3s_fps = e3s['FP'].to_list()
    e3s_similarities = DataStructs.BulkTanimotoSimilarity(linker_fp, e3s_fps)
    if (np.array(e3s_similarities) > e3s_similarity_threshold).any():
        print(f'WARNING: Linker {linker_smiles} is too similar to an E3 binder')
        # display_mol(linker)
        # display_mol(Chem.MolFromSmiles(e3s[e3s_similarities.argmax()]))
        return False

    # Check if the linker is similar to any of the POIs
    if isinstance(pois, str):
        # Create a one-element list with the POI fingerprint
        pois_fps = [morgan_fp_generator.GetFingerprint(Chem.MolFromSmiles(pois))]
    else:
        pois_fps = pois['FP'].to_list()
    pois_similarities = DataStructs.BulkTanimotoSimilarity(linker_fp, pois_fps)
    if (np.array(pois_similarities) > pois_similarity_threshold).any():
        # print(f'Error: Linker {linker_smiles} is too similar to a POI ligand')
        # display_mol(linker)
        # display_mol(Chem.MolFromSmiles(pois[pois_similarities.argmax()]))
        return False

    # Check if the linker is NOT similar to any of the known linkers
    if linkers is not None:
        if isinstance(linkers, str):
            # Create a one-element list with the linker fingerprint
            linkers_fps = [morgan_fp_generator.GetFingerprint(Chem.MolFromSmiles(linkers))]
        else:
            linkers_fps = linkers['FP'].to_list()
        linkers_similarities = DataStructs.BulkTanimotoSimilarity(linker_fp, linkers_fps)
        if not (np.array(linkers_similarities) > linkers_similarity_threshold).all():
            print(f'WARNING: Linker {linker_smiles} is too similar to a linker')
            # display_mol(linker)
            # display_mol(Chem.MolFromSmiles(linkers[linkers_similarities.argmax()]))
            return False

    return True


def check_substructs_similarity(
    protac: Union[np.ndarray, str, Chem.Mol],
    substructs: Dict[str, str],
    similarity_threshold: float = 0.7,
    similarity_thresholds: Optional[Dict[str, float]] = None,
    morgan_fp_generator: Optional[Callable] = None,
) -> bool:
    """ Check the similarity of the PROTAC with the substructures. If it is too similar to any of them, return False.

    Args:
        protac: The PROTAC molecule, its SMILES, or a precomputed fingerprint.
        substructs: The substructures to check against.
        similarity_threshold: The similarity threshold.
        similarity_thresholds: Per-substructure similarity thresholds, keyed like `substructs`. Overrides `similarity_threshold` when given.
        morgan_fp_generator: The Morgan fingerprint generator.

    Returns:
        False if the PROTAC is too similar to any of the substructures. True otherwise.
    """

    if morgan_fp_generator is None:
        morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(
            radius=2,
            fpSize=2048,
            useBondTypes=True,
            includeChirality=True,
        )

    if isinstance(protac, str):
        protac = Chem.MolFromSmiles(protac)
        protac_fp = morgan_fp_generator.GetFingerprint(protac)
    elif isinstance(protac, Chem.Mol):
        protac_fp = morgan_fp_generator.GetFingerprint(protac)
    else:
        protac_fp = protac

    for key, smiles in substructs.items():
        substr_fp = morgan_fp_generator.GetFingerprint(Chem.MolFromSmiles(smiles))
        threshold = similarity_thresholds[key] if similarity_thresholds is not None else similarity_threshold
        if DataStructs.TanimotoSimilarity(protac_fp, substr_fp) > threshold:
            print(f'WARNING: {key.upper()} is too similar to the PROTAC, similarity: {DataStructs.TanimotoSimilarity(protac_fp, substr_fp):.4f} > {threshold}')
            # display_mol(Chem.MolFromSmiles(smiles))
            return False

    return True


def get_split_row(
    row: pd.Series,
    substructs: Dict[str, str],
    poi_smiles_no_dummy: Optional[str] = None,
    e3_smiles_no_dummy: Optional[str] = None,
) -> Dict[str, Any]:
    """ Update the fields of a row with the substructures and their IDs.

    Args:
        row: The input row.
        substructs: The substructures found in the PROTAC.
        poi_smiles_no_dummy: The POI ligand SMILES without the dummy atoms.
        e3_smiles_no_dummy: The E3 binder SMILES without the dummy atoms.

    Returns:
        The updated row.
    """
    mapped_row = {}
    mapped_row['PROTAC SMILES'] = canonize_smiles(row['SMILES'])
    mapped_row['POI Ligand SMILES with direction'] = substructs['poi']
    mapped_row['E3 Binder SMILES with direction'] = substructs['e3']
    mapped_row['Linker SMILES with direction'] = substructs['linker']
    mapped_row['POI Ligand SMILES'] = remove_dummy_atoms(substructs['poi']) if poi_smiles_no_dummy is None else poi_smiles_no_dummy
    mapped_row['E3 Binder SMILES'] = remove_dummy_atoms(substructs['e3']) if e3_smiles_no_dummy is None else e3_smiles_no_dummy
    mapped_row['Linker SMILES'] = remove_dummy_atoms(substructs['linker'])

    # Get the IDs and update the dictionaries with new substructures
    mapped_row['PROTAC ID'] = get_mol_id(mapped_row['PROTAC SMILES'])
    mapped_row['POI Ligand ID'] = get_mol_id(mapped_row['POI Ligand SMILES with direction'])
    mapped_row['E3 Binder ID'] = get_mol_id(mapped_row['E3 Binder SMILES with direction'])
    mapped_row['Linker ID'] = get_mol_id(mapped_row['Linker SMILES with direction'])

    return mapped_row


def split_single_protac(
    row: pd.Series,
    dictionaries: Dict[str, pd.DataFrame],
    biggest_matches_first: bool = True,
    max_iter_on_linkers: int = 0,
    split_with_substr_and_linker_matching: bool = False,
    similarity_threshold: float = 0.65,
    morgan_radius: Optional[int] = None,
    morgan_fp_size: Optional[int] = None,
    morgan_fp_generator: Optional[Callable] = None,
    poi_attachment_id: int = 1,
    e3_attachment_id: int = 2,
) -> Dict[str, Any]:
    """ Map a PROTAC row to the substructures in the dictionaries.

    Args:
        row: The input row, containing the PROTAC SMILES, ID, and molecule.
        dictionaries: The dictionaries containing the substructures.
        biggest_matches_first: Whether to sort the matches by the number of atoms in the molecule.
        max_iter_on_linkers: The maximum number of iterations to perform on the linkers.
        split_with_substr_and_linker_matching: Whether to also collect linker matches for splitting via substructure-and-linker matching.
        similarity_threshold: The similarity threshold used when checking the linker against POIs and E3s.
        morgan_radius: The radius of the Morgan fingerprint generator (defaults to 2 when a generator is created internally).
        morgan_fp_size: The size of the Morgan fingerprints (defaults to 2048 when a generator is created internally).
        morgan_fp_generator: An optional pre-built Morgan fingerprint generator.
        poi_attachment_id: The attachment point ID of the POI ligand.
        e3_attachment_id: The attachment point ID of the E3 binder.

    Returns:
        The mapped row. None if the mapping was not successful.
    """
    # # Disable the RDKit warnings that pop up when RDKit fails to create molecules
    # # NOTE: The following is done to avoid warning messages during multiprocessing
    # RDLogger.DisableLog("rdApp.*")
    # blocker = rdBase.BlockLogs()

    protac_smiles = row['SMILES']
    protac_mol = row['Molecule']

    if morgan_fp_generator is None:
        morgan_radius = 2 if morgan_radius is None else morgan_radius
        morgan_fp_size = 2048 if morgan_fp_size is None else morgan_fp_size
        morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(
            radius=morgan_radius,
            fpSize=morgan_fp_size,
            useBondTypes=True,
            includeChirality=True,
        )
    else:
        morgan_radius = 'None'
        morgan_fp_size = 'None'
    protac_fp = morgan_fp_generator.GetFingerprint(protac_mol)

    notes = f'({max_iter_on_linkers=})({split_with_substr_and_linker_matching=})({morgan_radius=})({morgan_fp_size=})'

    # Get all substructure matches in the POI dictionary
    # poi_matches = dictionaries['POI Ligand']['Molecule'].apply(lambda x: get_substr_match(protac_mol, x, max_allowed_fragments=1))
    poi_matches = dictionaries['POI Ligand']['Molecule'].apply(lambda x: protac_mol.HasSubstructMatch(x))
    pois = dictionaries['POI Ligand'][poi_matches].drop_duplicates(subset=['SMILES'])

    # Get all substructure matches in the E3 dictionary
    # e3_matches = dictionaries['E3 Binder']['Molecule'].apply(lambda x: get_substr_match(protac_mol, x, max_allowed_fragments=1))
    e3_matches = dictionaries['E3 Binder']['Molecule'].apply(lambda x: protac_mol.HasSubstructMatch(x))
    e3s = dictionaries['E3 Binder'][e3_matches].drop_duplicates(subset=['SMILES'])

    # # Sort the matches by the number of atoms in the molecule
    # ascending = False if biggest_matches_first else True
    # pois = pois.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=True)
    # e3s = e3s.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=True)

    # Get the POI median size, then re-arrange the pois dataframe so that the median is the first element
    poi_median = pois['Molecule'].apply(lambda x: x.GetNumAtoms()).median()
    pois = pois.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=True)
    pois = pois.iloc[np.abs(pois['Molecule'].apply(lambda x: x.GetNumAtoms()) - poi_median).argsort()]

    # Get the E3 median size, then re-arrange the e3s dataframe so that the median is the first element
    e3_median = e3s['Molecule'].apply(lambda x: x.GetNumAtoms()).median()
    e3s = e3s.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=True)
    e3s = e3s.iloc[np.abs(e3s['Molecule'].apply(lambda x: x.GetNumAtoms()) - e3_median).argsort()]

    # If any of the substructures is not found, get the matching linkers to be
    # used later (do it only once).
    linkers = None
    if len(pois) == 0 or len(e3s) == 0 or split_with_substr_and_linker_matching:
        matches = dictionaries['Linker with direction']['Molecule'].apply(lambda x: get_substr_match(protac_mol, x, max_allowed_fragments=2))
        linkers = dictionaries['Linker with direction'][matches]
        linkers = linkers.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=False)

    # dummy_attachment_id = 1
    # mapping_found = False
    # for _, linker in linkers.iterrows():
    #     if mapping_found:
    #         break
    #     for _, poi in pois.iterrows():
    #         if mapping_found:
    #             break
    #         for _, e3 in e3s.iterrows():
    #             if mapping_found:
    #                 break
    #             # Get the replaced side chain
    #             e3_mapped = Chem.ReplaceSidechains(protac_mol, e3['Molecule'], useChirality=True)
    #             e3_mapped = rename_attachment_id(e3_mapped, dummy_attachment_id, e3_attachment_id)
    #             if e3_mapped is None:
    #                 continue
    #
    #             poi_mapped = Chem.ReplaceSidechains(protac_mol, poi['Molecule'], useChirality=True)
    #             poi_mapped = rename_attachment_id(poi_mapped, dummy_attachment_id, poi_attachment_id)
    #             if poi_mapped is None:
    #                 continue
    #
    #             # Join the substructures as fragments
    #             protac_candidate = canonize('.'.join([linker['SMILES'], e3_mapped, poi_mapped]))
    #             protac_candidate = Chem.MolFromSmiles(protac_candidate)
    #             protac_candidate = canonize(Chem.molzip(protac_candidate))
    #             if check_reassembly(protac_mol, protac_candidate):
    #                 print('Found a match!')
    #                 mapping_found = True
    #
    #                 # substructs = {
    #                 #     'linker': linker['Molecule'],
    #                 #     'e3': e3['Molecule'],
    #                 #     'poi': poi['Molecule'],
    #                 # }
    #                 # mapped_row = get_split_row(row, dictionaries, substructs, poi['SMILES'], e3['SMILES'])
    #                 # mapped_row['Notes'] = 'Obtained from matching E3, POI, and Linker found in dictionaries.'
    #                 # return mapped_row

    # TODO: Add a variable to get mapped ligands even if the checks failed... add a note when it happens
    best_substructs_candidate = None

    # There were matching E3s and matching POIs: try to recover the linker from
    # an unmapped E3 and an unmapped POI.
    if len(e3s) > 0 and len(pois) > 0:
        for _, poi in pois.iterrows():
            for _, e3 in e3s.iterrows():
                additional_notes = '(matching_poi=True)(matching_e3=True)(matching_linker=None)'
                substructs = get_substructs_from_unmapped_e3_poi(protac_smiles, protac_mol, poi['Molecule'], e3['Molecule'])

                # If the substructure is not found, try to get it from a non-perfect match
                if substructs is None:
                    fixed_poi = get_substructure_from_non_perfect_match(protac_mol, poi['Molecule'], poi_attachment_id)
                    fixed_e3 = get_substructure_from_non_perfect_match(protac_mol, e3['Molecule'], e3_attachment_id)
                    fixed_poi = poi['Molecule'] if fixed_poi is None else fixed_poi
                    fixed_e3 = e3['Molecule'] if fixed_e3 is None else fixed_e3
                    if fixed_poi is not None and fixed_e3 is not None:
                        substructs = get_substructs_from_unmapped_e3_poi(protac_smiles, protac_mol, fixed_poi, fixed_e3)
                        if Chem.MolToSmiles(fixed_e3) != e3['SMILES']:
                            additional_notes += '(non_perfect_e3_match=True)'
                        else:
                            additional_notes += '(non_perfect_e3_match=False)'

                        if Chem.MolToSmiles(fixed_poi) != poi['SMILES']:
                            additional_notes += '(non_perfect_poi_match=True)'
                        else:
                            additional_notes += '(non_perfect_poi_match=False)'

                if substructs is not None:
                    size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)

                    # Check if the linker is too similar to any of the matching POIs or E3s (use the bulk Tanimoto similarity)
                    if not check_linker_similarity(substructs['linker'], pois, e3s, morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
379
+ best_substructs_candidate = substructs
380
+ continue
381
+
382
+ if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
383
+ best_substructs_candidate = substructs
384
+ # display_mol(protac_mol)
385
+ continue
386
+
387
+ # Fix the bonds close to amide and ester groups, if necessary
388
+ substructs_copy = substructs.copy()
389
+ substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
390
+ # Check and report if any SMILES was changed
391
+ if substructs['linker'] != substructs_copy['linker']:
392
+ additional_notes += '(amide_bonds_fixed=True)'
393
+ else:
394
+ additional_notes += '(amide_bonds_fixed=False)'
395
+
396
+ substructs_copy = substructs.copy()
397
+ substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
398
+ # Check and report if any SMILES was changed
399
+ if substructs['linker'] != substructs_copy['linker']:
400
+ additional_notes += '(ester_bonds_fixed=True)'
401
+ else:
402
+ additional_notes += '(ester_bonds_fixed=False)'
403
+
404
+ # Add the mapped PROTAC to the final list
405
+ mapped_row = get_split_row(row, substructs)
406
+ mapped_row['Notes'] = notes + additional_notes
407
+ return mapped_row
408
+
409
+ # Some E3s matched (POIs may or may not have): try to recover the POI
+ # from a matched E3 and a matched linker.
+ if len(e3s) > 0 and split_with_substr_and_linker_matching:
412
+ # NOTE: Only take the largest linker(s) into account
413
+ if max_iter_on_linkers:
414
+ selected_linkers = linkers.iloc[:max_iter_on_linkers, :]
415
+ else:
416
+ selected_linkers = linkers.iloc[:1, :]
417
+
418
+ for _, e3 in e3s.iterrows():
419
+ # Adjust the E3 molecule if it is not a perfect match
420
+ e3_mol_fixed = get_substructure_from_non_perfect_match(protac_mol, e3['Molecule'], e3_attachment_id)
421
+ e3_mol = e3['Molecule'] if e3_mol_fixed is None else e3_mol_fixed
422
+ e3_mol = remove_dummy_atoms(e3_mol)
423
+ if Chem.MolToSmiles(e3_mol) != e3['SMILES']:
424
+ non_perfect_e3_match = True
425
+ else:
426
+ non_perfect_e3_match = False
427
+
428
+ for _, linker in selected_linkers.iterrows():
429
+ additional_notes = f'(matching_poi=False)(matching_e3=True)(matching_linker=True)({non_perfect_e3_match=})'
430
+
431
+ substructs = get_substructs_from_substr_and_linker(
432
+ protac_smiles=protac_smiles,
433
+ protac=protac_mol,
434
+ substr=e3_mol,
435
+ linker=linker['Molecule'],
436
+ attachment_id=e3_attachment_id,
437
+ )
438
+ if substructs is not None:
439
+ size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
440
+
441
+ if not check_linker_similarity(substructs['linker'], substructs['poi'], e3s, morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
442
+ best_substructs_candidate = substructs
443
+ continue
444
+
445
+ if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
446
+ best_substructs_candidate = substructs
447
+ # display_mol(protac_mol)
448
+ continue
449
+
450
+ # Fix the bonds close to amide and ester groups, if necessary
451
+ substructs_copy = substructs.copy()
452
+ substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
453
+ if substructs['linker'] != substructs_copy['linker']:
454
+ additional_notes += '(amide_bonds_fixed=True)'
455
+ else:
456
+ additional_notes += '(amide_bonds_fixed=False)'
457
+ substructs_copy = substructs.copy()
458
+ substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
459
+ if substructs['linker'] != substructs_copy['linker']:
460
+ additional_notes += '(ester_bonds_fixed=True)'
461
+ else:
462
+ additional_notes += '(ester_bonds_fixed=False)'
463
+
464
+ mapped_row = get_split_row(row, substructs)
465
+ mapped_row['Notes'] = notes + additional_notes
466
+ return mapped_row
467
+
468
+ # Swap the attachment points on the linker and try again
469
+ linker_swapped = swap_attachment_points(linker['SMILES'])
470
+ substructs = get_substructs_from_substr_and_linker(
471
+ protac_smiles=protac_smiles,
472
+ protac=protac_mol,
473
+ substr=e3_mol,
474
+ linker=Chem.MolFromSmiles(linker_swapped),
475
+ attachment_id=e3_attachment_id,
476
+ )
477
+ additional_notes += '(attachment_points_swapped_in_linker=True)'
478
+ if substructs is not None:
479
+
480
+ size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
481
+
482
+ if not check_linker_similarity(substructs['linker'], substructs['poi'], e3s, morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
483
+ continue
484
+
485
+ if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
486
+ # display_mol(protac_mol)
487
+ continue
488
+
489
+ # Fix the bonds close to amide and ester groups, if necessary
490
+ substructs_copy = substructs.copy()
491
+ substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
492
+ if substructs['linker'] != substructs_copy['linker']:
493
+ additional_notes += '(amide_bonds_fixed=True)'
494
+ else:
495
+ additional_notes += '(amide_bonds_fixed=False)'
496
+ substructs_copy = substructs.copy()
497
+ substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
498
+ if substructs['linker'] != substructs_copy['linker']:
499
+ additional_notes += '(ester_bonds_fixed=True)'
500
+
501
+ mapped_row = get_split_row(row, substructs)
502
+ mapped_row['Notes'] = notes + additional_notes
503
+ return mapped_row
504
+
505
+ # Some POIs matched (E3s may or may not have): try to recover the E3
+ # from a matched POI and a matched linker.
+ if len(pois) > 0 and split_with_substr_and_linker_matching:
508
+ # NOTE: Only take the largest linker(s) into account
509
+ if max_iter_on_linkers:
510
+ selected_linkers = linkers.iloc[:max_iter_on_linkers, :]
511
+ else:
512
+ selected_linkers = linkers.iloc[:1, :]
513
+
514
+ for _, poi in pois.iterrows():
515
+ poi_mol = get_substructure_from_non_perfect_match(protac_mol, poi['Molecule'], poi_attachment_id)
516
+ poi_mol = poi['Molecule'] if poi_mol is None else poi_mol
517
+ poi_mol = remove_dummy_atoms(poi_mol)
518
+ if Chem.MolToSmiles(poi_mol) != poi['SMILES']:
519
+ non_perfect_poi_match = True
520
+ else:
521
+ non_perfect_poi_match = False
522
+
523
+ for _, linker in selected_linkers.iterrows():
524
+ additional_notes = f'(matching_poi=True)(matching_e3=False)(matching_linker=True)({non_perfect_poi_match=})'
525
+
526
+ substructs = get_substructs_from_substr_and_linker(
527
+ protac_smiles=protac_smiles,
528
+ protac=protac_mol,
529
+ substr=poi_mol,
530
+ linker=linker['Molecule'],
531
+ attachment_id=poi_attachment_id,
532
+ )
533
+ if substructs is not None:
534
+ size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
535
+
536
+ if not check_linker_similarity(substructs['linker'], pois, substructs['e3'], morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
537
+ best_substructs_candidate = substructs
538
+ continue
539
+
540
+ if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
541
+ best_substructs_candidate = substructs
542
+ # display_mol(protac_mol)
543
+ continue
544
+
545
+ # Fix the bonds close to amide and ester groups, if necessary
546
+ substructs_copy = substructs.copy()
547
+ substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
548
+ if substructs['linker'] != substructs_copy['linker']:
549
+ additional_notes += '(amide_bonds_fixed=True)'
550
+ substructs_copy = substructs.copy()
551
+ substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
552
+ if substructs['linker'] != substructs_copy['linker']:
553
+ additional_notes += '(ester_bonds_fixed=True)'
554
+
555
+ mapped_row = get_split_row(row, substructs)
556
+ mapped_row['Notes'] = notes + additional_notes
557
+ return mapped_row
558
+
559
+ # Swap the attachment points on the linker and try again
560
+ linker_swapped = swap_attachment_points(linker['SMILES'])
561
+ substructs = get_substructs_from_substr_and_linker(
562
+ protac_smiles=protac_smiles,
563
+ protac=protac_mol,
564
+ substr=poi_mol,
565
+ linker=Chem.MolFromSmiles(linker_swapped),
566
+ attachment_id=poi_attachment_id,
567
+ )
568
+ additional_notes += '(attachment_points_swapped_in_linker=True)'
569
+ if substructs is not None:
570
+
571
+ size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
572
+
573
+ if not check_linker_similarity(substructs['linker'], pois, substructs['e3'], morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
574
+ best_substructs_candidate = substructs
575
+ continue
576
+
577
+ if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
578
+ best_substructs_candidate = substructs
579
+ # display_mol(protac_mol)
580
+ continue
581
+
582
+ # Fix the bonds close to amide and ester groups, if necessary
583
+ substructs_copy = substructs.copy()
584
+ substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
585
+ if substructs['linker'] != substructs_copy['linker']:
586
+ additional_notes += '(amide_bonds_fixed=True)'
587
+ substructs_copy = substructs.copy()
588
+ substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
589
+ if substructs['linker'] != substructs_copy['linker']:
590
+ additional_notes += '(ester_bonds_fixed=True)'
591
+
592
+ mapped_row = get_split_row(row, substructs)
593
+ mapped_row['Notes'] = notes + additional_notes
594
+ return mapped_row
595
+
596
+
597
+ # Get all substructure matches in the Linker with direction dictionary
598
+ # NOTE: This code is repeated here for performance reasons, to avoid
599
+ # calculating the matches if not needed.
600
+ if linkers is None and max_iter_on_linkers:
601
+ matches = dictionaries['Linker with direction']['Molecule'].apply(lambda x: get_substr_match(protac_mol, x, num_allowed_fragments=2))
602
+ linkers = dictionaries['Linker with direction'][matches]
603
+ # Sort all the matches by the number of atoms in the linker, the biggest first
604
+ linkers = linkers.sort_values(by='Molecule', key=lambda s: s.apply(lambda m: m.GetNumAtoms()), ascending=False)
605
+
606
+ # for j, (_, linker) in enumerate(linkers.iterrows()):
607
+ # additional_notes = '(matching_poi=False)(matching_e3=False)(matching_linker=True)'
608
+ # if j >= max_iter_on_linkers or max_iter_on_linkers == 0:
609
+ # return None
610
+
611
+ # NOTE: A negative max_iter_on_linkers means "iterate over all matched linkers".
+ num_linkers = 0 if linkers is None else (len(linkers) if max_iter_on_linkers < 0 else min(max_iter_on_linkers, len(linkers)))
+ for j in range(num_linkers):
612
+ additional_notes = '(matching_poi=False)(matching_e3=False)(matching_linker=True)'
613
+ linker = linkers.iloc[j, :]
614
+ substructs = get_substructs_from_mapped_linker(protac_smiles, linker['SMILES'])
615
+
616
+ if substructs is not None:
617
+ if not check_linker_similarity(substructs['linker'], substructs['poi'], substructs['e3'], morgan_fp_generator=morgan_fp_generator, e3s_similarity_threshold=similarity_threshold, pois_similarity_threshold=similarity_threshold):
618
+ best_substructs_candidate = substructs
619
+ continue
620
+
621
+ size_check = check_substructs_size(protac_mol, substructs, size_perc_threshold=0.7)
622
+ if not size_check and not check_substructs_similarity(protac_fp, substructs, similarity_threshold=similarity_threshold, morgan_fp_generator=morgan_fp_generator):
623
+ best_substructs_candidate = substructs
624
+ # display_mol(protac_mol)
625
+ continue
626
+
627
+ # Fix the bonds close to amide and ester groups, if necessary
628
+ substructs_copy = substructs.copy()
629
+ substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
630
+ if substructs['linker'] != substructs_copy['linker']:
631
+ additional_notes += '(amide_bonds_fixed=True)'
632
+ substructs_copy = substructs.copy()
633
+ substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
634
+ if substructs['linker'] != substructs_copy['linker']:
635
+ additional_notes += '(ester_bonds_fixed=True)'
636
+
637
+ if not check_substructs_size(protac_mol, substructs, size_perc_threshold=0.95):
638
+ best_substructs_candidate = substructs
639
+ continue
640
+
641
+ mapped_row = get_split_row(row, substructs)
642
+ mapped_row['Notes'] = notes + additional_notes
643
+ return mapped_row
644
+
645
+ # If we are here, it means that the substructures found in the above loops
646
+ # failed the similarity checks. We add a note and return the best
647
+ # substructure candidate found.
648
+ if best_substructs_candidate is not None:
649
+ substructs = adjust_amide_bonds_in_substructs(best_substructs_candidate, protac_smiles)
651
+ if substructs['linker'] != best_substructs_candidate['linker']:
652
+ notes += '(amide_bonds_fixed=True)'
653
+ substructs_copy = substructs.copy()
654
+ substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
655
+ if substructs['linker'] != substructs_copy['linker']:
656
+ notes += '(ester_bonds_fixed=True)'
657
+ mapped_row = get_split_row(row, substructs)
658
+ mapped_row['Notes'] = notes + '(similarity_checks_failed=True)'
659
+ return mapped_row
660
+
661
+ return None
662
+
663
+
664
+ def split_protacs(
665
+ protac_df: pd.DataFrame,
666
+ dictionaries: Dict[str, pd.DataFrame],
667
+ max_iter_on_linkers: int = 0,
668
+ split_with_substr_and_linker_matching: bool = False,
669
+ biggest_matches_first: bool = True,
670
+ update_dict_if_ids_not_found: bool = False,
671
+ use_multiprocessing: bool = False,
672
+ ) -> pd.DataFrame:
673
+ """ Maps PROTACs to their substructures.
674
+
675
+ Args:
676
+ protac_df: The input PROTAC dataframe.
677
+ dictionaries: The input dictionaries.
678
+ max_iter_on_linkers: The maximum number of matching linkers to iterate over. If zero, there will be no attempt to match linkers in the dictionary. If negative, iterate over all matched linkers. Default is 0.
+ split_with_substr_and_linker_matching: Whether to also try recovering a missing POI/E3 by matching a known substructure together with a dictionary linker. Default is False.
679
+ biggest_matches_first: Whether to sort the matches by the number of atoms in the molecule. Default is True.
680
+ update_dict_if_ids_not_found: DEPRECATED. Whether to update the dictionary if the substructure IDs are not found. Default is False.
681
+ use_multiprocessing: Whether to use multiprocessing. Default is False.
682
+
683
+ Returns:
684
+ The mapped PROTAC dataframe.
685
+ """
686
+ # if use_multiprocessing:
687
+ # global split_single_protac
688
+
689
+ # with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
690
+ # results = pool.map(partial(split_single_protac, dictionaries=dictionaries, biggest_matches_first=biggest_matches_first, max_iter_on_linkers=max_iter_on_linkers), protac_df.copy().to_dict(orient='records'))
691
+
692
+ # mapped_protacs = pd.DataFrame(results)
693
+ # mapped_protacs = mapped_protacs.dropna(subset=['POI Ligand SMILES with direction', 'E3 Binder SMILES with direction', 'Linker SMILES with direction'])
694
+ # return mapped_protacs
695
+
696
+ if use_multiprocessing:
697
+ # TODO: The following does run in parallel, but it gives wrong results. I don't know why. I will have to investigate further.
698
+ results = Parallel(n_jobs=-1)(delayed(split_single_protac)(row, dictionaries=dictionaries, biggest_matches_first=biggest_matches_first, max_iter_on_linkers=max_iter_on_linkers) for _, row in protac_df.iterrows())
699
+ mapped_protacs = pd.DataFrame([r for r in results if r is not None])
700
+ return mapped_protacs
701
+
702
+ mapped_protacs = []
703
+ for i, row in (pbar := tqdm(protac_df.iterrows(), total=len(protac_df))):
704
+ pbar.set_description(f'PROTAC n.{i:4d}')
705
+
706
+ r = split_single_protac(
707
+ row,
708
+ dictionaries,
709
+ biggest_matches_first=biggest_matches_first,
710
+ max_iter_on_linkers=max_iter_on_linkers,
711
+ split_with_substr_and_linker_matching=split_with_substr_and_linker_matching,
712
+ )
713
+ if r is not None:
714
+ mapped_protacs.append(r)
715
+ tmp = pd.DataFrame(mapped_protacs)
716
+ pbar.set_postfix({'len_mapped': len(tmp), 'perc_mapped': f'{len(tmp) / len(protac_df):.1%}'})
717
+
718
+ mapped_protacs = pd.DataFrame(mapped_protacs)
719
+ return mapped_protacs
720
+
721
+
722
+ def parse_notes(notes: str) -> Dict[str, Any]:
+ """ Parse a Notes string of the form '(key1=value1)(key2=value2)...' into a dictionary. """
723
+ # Define the regex pattern to match key-value pairs within parentheses
724
+ pattern = r'\(([^=]+)=([^\)]+)\)'
725
+
726
+ # Find all matches in the string
727
+ matches = re.findall(pattern, notes)
728
+
729
+ # Initialize an empty dictionary to store the parsed key-value pairs
730
+ parsed_dict = {}
731
+
732
+ # Iterate over the matches and add them to the dictionary
733
+ for key, value in matches:
734
+ # Convert the value to the appropriate type (int, bool, None, or str)
735
+ if value.isdigit():
736
+ parsed_dict[key] = int(value)
737
+ elif value.lower() == 'true':
738
+ parsed_dict[key] = True
739
+ elif value.lower() == 'false':
740
+ parsed_dict[key] = False
741
+ elif value.lower() == 'none':
742
+ parsed_dict[key] = None
743
+ else:
744
+ parsed_dict[key] = value
745
+
746
+ return parsed_dict
747
+
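A quick round-trip of the Notes format (hypothetical values):

```python
notes = '(max_iter_on_linkers=0)(matching_poi=True)(morgan_radius=None)(step=3)'
print(parse_notes(notes))
# {'max_iter_on_linkers': 0, 'matching_poi': True, 'morgan_radius': None, 'step': 3}
```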
748
+
749
+ def iterative_protac_splitting(
750
+ dictionaries: Dict[str, pd.DataFrame],
751
+ data_dir: str,
752
+ ) -> Dict[str, pd.DataFrame]:
753
+ """ Map PROTACs to their substructures in an iterative way.
754
+
755
+ Args:
756
+ dictionaries: The input dictionaries. The same format as the output of the `update_dictionary` function.
757
+ data_dir: The directory where the output data is stored.
758
+
759
+ Returns:
760
+ The updated dictionaries, extended with the substructures of the mapped PROTACs. The mapped PROTACs themselves are saved to 'mapped_protacs.csv' in data_dir.
761
+ """
762
+
763
+ final_df = None
764
+ non_mapped_protacs = dictionaries['PROTAC'].copy()
765
+
766
+ start_from_beginning = True # Re-map all PROTACs ignoring loading previous results
767
+ step = -1
768
+ max_iter_on_linkers = 0
769
+ split_with_substr_and_linker_matching = False
770
+
771
+ while True:
772
+ if max_iter_on_linkers == -1 or non_mapped_protacs.empty or step >= 50:
773
+ break
774
+
775
+ if max_iter_on_linkers == 5:
776
+ max_iter_on_linkers = -1 # Iterate over all linkers
777
+
778
+ step += 1
779
+ print('-' * 100)
780
+ print(f'Step n.{step}')
781
+ print(f'Max iterations on linkers: {max_iter_on_linkers}')
782
+ print(f'Map with substr and linker matching: {split_with_substr_and_linker_matching}')
783
+ print('-' * 50)
784
+
785
+ step_filename = os.path.join(data_dir, f'mapped_protacs_{step=}.csv')
786
+ final_filename = os.path.join(data_dir, 'mapped_protacs.csv')
787
+ non_mapped_filename = os.path.join(data_dir, 'non_mapped_protacs.csv')
788
+
789
+ if os.path.exists(step_filename) and not start_from_beginning:
790
+ # Check if all lines of the file are empty
791
+ with open(step_filename, 'r') as f:
792
+ lines = f.readlines()
793
+ if all([len(line.strip()) == 0 for line in lines]):
794
+ mapped_protacs = pd.DataFrame()
795
+ else:
796
+ mapped_protacs = pd.read_csv(step_filename)
797
+ else:
798
+ mapped_protacs = split_protacs(
799
+ non_mapped_protacs,
800
+ dictionaries=dictionaries,
801
+ split_with_substr_and_linker_matching=split_with_substr_and_linker_matching,
802
+ max_iter_on_linkers=max_iter_on_linkers,
803
+ biggest_matches_first=False,
804
+ use_multiprocessing=False,
805
+ )
806
+ # Add a string at the end of the strings in the 'Notes' column
807
+ if not mapped_protacs.empty:
808
+ mapped_protacs['Notes'] = mapped_protacs['Notes'].apply(lambda x: f'{x}({step=})')
809
+ mapped_protacs.to_csv(step_filename, index=False)
810
+
811
+ # Update the final dataframe and save it to file
812
+ if final_df is None:
813
+ final_df = mapped_protacs
814
+ else:
815
+ final_df = pd.concat([final_df, mapped_protacs], axis=0).drop_duplicates(subset=['PROTAC SMILES'])
816
+ final_df.to_csv(final_filename, index=False)
817
+ print(f'All mapped PROTACs saved to: {final_filename}')
818
+
819
+ # Reporting information
820
+ mapped_perc = len(mapped_protacs) / len(non_mapped_protacs)
821
+ total_mapped_perc = len(final_df) / len(dictionaries['PROTAC'])
822
+ print(f'Number of mapped PROTACs: {len(mapped_protacs)} ({mapped_perc:.2%})')
823
+ print(f'Total num. of mapped PROTACs: {len(final_df)} ({total_mapped_perc:.2%})')
824
+ print('-' * 50)
825
+ print(final_df['Notes'].value_counts())
826
+ print('-' * 50)
827
+
828
+ # Get the non-mapped PROTACs yet and save them to file
829
+ non_mapped_protacs = dictionaries['PROTAC'][~dictionaries['PROTAC']['SMILES'].isin(final_df['PROTAC SMILES'])].copy()
830
+ non_mapped_protacs[['SMILES', 'ID']].to_csv(non_mapped_filename, index=False)
831
+ print(f'Non-mapped PROTACs saved to: {non_mapped_filename}')
832
+
833
+ # Control logic for breaking the loop
834
+ if mapped_protacs.empty:
835
+ if max_iter_on_linkers == 0 and not split_with_substr_and_linker_matching:
836
+ split_with_substr_and_linker_matching = True
837
+ continue
838
+ else:
839
+ max_iter_on_linkers += 1
840
+ continue
841
+ else:
842
+ # Using only the linker to map the PROTACs can be unreliable, so if we
843
+ # found new PROTACs, we should set max_iter_on_linkers back to zero and try
844
+ # to map the PROTACs again with the newly found substructures.
845
+ max_iter_on_linkers = 0
846
+ split_with_substr_and_linker_matching = False
847
+
848
+ # Update all dictionaries with the substructures of the mapped PROTACs
849
+ smiles_list = mapped_protacs['Linker SMILES with direction'].unique()
850
+ smiles_list = [canonize(smiles) for smiles in smiles_list]
851
+ dictionaries['Linker with direction'] = update_dictionary(dictionaries['Linker with direction'], smiles_list)
852
+
853
+ # Avoid adding POIs that are in the E3 dictionary!
854
+ smiles_list = mapped_protacs['POI Ligand SMILES'].unique()
855
+ smiles_list = [canonize(smiles) for smiles in smiles_list]
856
+ smiles_list = [s for s in smiles_list if s not in dictionaries['E3 Binder']['SMILES'].values]
857
+
858
+ smiles_list = [remove_dummy_atoms(s) for s in smiles_list if s is not None]
859
+
860
+ # Use Tanimoto similarity to prevent adding POIs too similar to E3s
861
+ similarity_threshold = 0.5
862
+ radius = 2
863
+ nbits = 2048
864
+ morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=nbits, useBondTypes=True, includeChirality=True)
865
+
866
+ pois_to_add = []
867
+ for poi_smiles in smiles_list:
868
+ poi_mol = Chem.MolFromSmiles(poi_smiles)
869
+ poi_fp = morgan_fp_generator.GetFingerprint(poi_mol)
870
+ similarities = DataStructs.BulkTanimotoSimilarity(poi_fp, dictionaries['E3 Binder']['FP'].to_list())
871
+ skip_poi = False
872
+ for sim in similarities:
873
+ if sim >= similarity_threshold:
874
+ skip_poi = True
875
+ break
876
+ if not skip_poi:
877
+ pois_to_add.append(poi_smiles)
878
+
879
+ dictionaries['POI Ligand'] = update_dictionary(dictionaries['POI Ligand'], pois_to_add)
880
+
881
+ # Avoid adding E3s that are in the POI dictionary!
882
+ smiles_list = mapped_protacs['E3 Binder SMILES'].unique()
883
+ smiles_list = [canonize(smiles) for smiles in smiles_list]
884
+ smiles_list = [s for s in smiles_list if s not in dictionaries['POI Ligand']['SMILES'].values]
885
+ smiles_list = [remove_dummy_atoms(s) for s in smiles_list if s is not None]
886
+ dictionaries['E3 Binder'] = update_dictionary(dictionaries['E3 Binder'], smiles_list)
887
+
888
+ # Save all dictionaries to file
889
+ for key, dictionary in dictionaries.items():
890
+ filename = os.path.join(data_dir, f'dictionary_{key.lower().replace(" ", "_")}.csv')
891
+ dictionary[['ID', 'SMILES']].to_csv(filename, index=False)
892
+ print(f'Dictionary saved to: {filename}')
893
+
894
+ return dictionaries
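The Tanimoto screen used in the loop above, shown in isolation (toy molecules; the 0.5 threshold mirrors the hard-coded value):

```python
from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator

gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)
probe = gen.GetFingerprint(Chem.MolFromSmiles('Oc1ccccc1'))
refs = [gen.GetFingerprint(Chem.MolFromSmiles(s)) for s in ('Nc1ccccc1', 'CCCC')]
sims = DataStructs.BulkTanimotoSimilarity(probe, refs)
keep = all(sim < 0.5 for sim in sims)  # keep the probe only if it is unlike every reference
print(sims, keep)
```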
protac_splitter/data/curation/mapping_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ from rdkit import Chem
2
+ import pandas as pd
3
+
4
+ from protac_splitter.chemoinformatics import (
5
+ canonize_smiles,
6
+ remove_stereo,
7
+ get_mol_id,
8
+ )
9
+
10
+ def update_dictionary(
11
+ dictionary: pd.DataFrame,
12
+ substr_to_add: list,
13
+ morgan_fp_generator = None,
14
+ verbose: int = 0,
15
+ ) -> pd.DataFrame:
16
+ """ Updates a dictionary with a list of additional substructures.
17
+
18
+ The dictionary is a dataframe with columns 'SMILES', 'Molecule', 'ID', and 'FP'.
19
+
20
+ Args:
21
+ dictionary: The input dictionary dataframe.
22
+ substr_to_add: The list of additional substructures.
23
+
24
+ Returns:
25
+ The updated dictionary dataframe.
26
+ """
27
+ # Canonize the SMILES strings
28
+ substr_to_add = [canonize_smiles(smiles) for smiles in substr_to_add if smiles is not None]
29
+ substr_to_add = list(set(substr_to_add))
30
+
31
+ # Remove entries already in the dictionary
+ # NOTE: Filter with a list comprehension instead of calling remove() while
+ # iterating over the list, which would silently skip elements.
+ existing_smiles = set() if dictionary.empty else set(dictionary['SMILES'].unique().tolist())
+ if verbose > 1:
+ for smiles in substr_to_add:
+ if smiles in existing_smiles:
+ print(f'\tWARNING. SMILES already in the dictionary: {smiles}')
+ substr_to_add = [smiles for smiles in substr_to_add if smiles not in existing_smiles]
38
+
39
+ new_entries = []
40
+ for smiles in substr_to_add:
41
+ try:
42
+ mol = Chem.MolFromSmiles(smiles)
43
+ except Exception as e:
44
+ if verbose:
45
+ print(e)
46
+ mol = None
47
+ # Remove entries that result in invalid molecules
48
+ if mol is None:
49
+ continue
50
+ new_entries.append({
51
+ 'SMILES': smiles,
52
+ 'Molecule': mol,
53
+ 'ID': get_mol_id(smiles),
54
+ })
55
+ # Try adding its no-stereochemistry version as well
56
+ smiles_nostereo = remove_stereo(smiles)
57
+ if smiles_nostereo is not None and smiles_nostereo != smiles:
58
+ mol_nostereo = Chem.MolFromSmiles(smiles_nostereo)
59
+ if mol_nostereo is not None:
60
+ new_entries.append({
61
+ 'SMILES': canonize_smiles(smiles_nostereo),
62
+ 'Molecule': mol_nostereo,
63
+ 'ID': get_mol_id(smiles_nostereo),
64
+ })
65
+ new_entries = pd.DataFrame(new_entries).drop_duplicates()
66
+
67
+ if len(new_entries) > 0:
68
+ # Add fingerprints to the new entries
69
+ if morgan_fp_generator is None:
70
+ morgan_fp_generator = Chem.rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048, useBondTypes=True, includeChirality=True)
71
+
72
+ new_entries['FP'] = new_entries['Molecule'].apply(lambda x: morgan_fp_generator.GetFingerprint(x) if x is not None else None)
73
+ if verbose:
74
+ print(f'Number of substructures added to the dictionary: {len(new_entries)}')
75
+
76
+ # Return the updated dictionary
77
+ return pd.concat([dictionary, new_entries], axis=0).drop_duplicates(subset='SMILES').reset_index(drop=True)
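A minimal usage sketch (the column layout follows the docstring above):

```python
import pandas as pd

# Start from an empty dictionary dataframe with the expected columns.
e3_dict = pd.DataFrame(columns=['SMILES', 'Molecule', 'ID', 'FP'])
e3_dict = update_dictionary(e3_dict, ['CC(C)(C)c1ccc(O)cc1'], verbose=1)
print(e3_dict[['ID', 'SMILES']])
```

When an input SMILES carries stereochemistry, its stereo-stripped variant is inserted as an additional row.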
protac_splitter/data/curation/substructure_extraction.py ADDED
@@ -0,0 +1,586 @@
1
+ import re
2
+ from typing import Any, Dict, List, Optional, Union
3
+ from collections import Counter
4
+
5
+ from rdkit import Chem
6
+ from rdkit.Chem import Draw
7
+
8
+ from protac_splitter.chemoinformatics import (
9
+ dummy2query,
10
+ remove_dummy_atoms,
11
+ canonize,
12
+ canonize_smiles,
13
+ GetSubstructMatchesWithTimeout,
14
+ )
15
+ from protac_splitter.display_utils import (
16
+ safe_display,
17
+ display_mol,
18
+ )
19
+ from protac_splitter.evaluation import check_reassembly
20
+
21
+
22
+ def get_substructs_from_mapped_linker(
23
+ protac_smiles: str,
24
+ linker_smiles: str,
25
+ e3_attachment_id: int = 2,
26
+ poi_attachment_id: int = 1,
27
+ verbose: int = 0,
28
+ ) -> Dict[str, str]:
29
+ """ Get the substructures of a PROTAC molecule from a mapped linker SMILES.
30
+
31
+ This function will return the substructures given a linker with
32
+ directionality, _i.e._, with the two attachment points mapped.
33
+
34
+ Args:
35
+ protac_smiles: The SMILES of the PROTAC molecule.
36
+ linker_smiles: The SMILES of the linker molecule. Must have atom-mapped attachment points, e.g., [*:1] and [*:2].
37
+ verbose: Verbosity level.
38
+
39
+ Returns:
40
+ A dictionary with the substructure names as keys ('e3', 'linker', and 'poi') and their SMILES as values. None if the matching fails.
41
+ """
42
+ protac_smiles = canonize_smiles(protac_smiles)
43
+ linker_smiles = canonize_smiles(linker_smiles)
44
+
45
+ protac_mol = Chem.MolFromSmiles(protac_smiles)
46
+ linker_mol = Chem.MolFromSmiles(linker_smiles)
47
+
48
+ # Check if the linker is a substructure of the PROTAC
49
+ if not protac_mol.HasSubstructMatch(dummy2query(linker_mol), useChirality=True):
50
+ return None
51
+
52
+ # Split the big molecule into the two fragments
53
+ frags = Chem.ReplaceCore(protac_mol, dummy2query(linker_mol), labelByIndex=True, replaceDummies=False)
54
+ if frags is None:
55
+ return None
56
+ try:
57
+ frags = Chem.GetMolFrags(frags, asMols=True, sanitizeFrags=True)
58
+ except Exception as e:
59
+ # print(e)
60
+ return None
61
+
62
+ if verbose:
63
+ safe_display(protac_mol)
64
+ safe_display(linker_mol)
65
+
66
+ # The linker has a map number at its attachment points: the following is a
67
+ # dictionary that maps the atom index of the attachment points to their
68
+ # respective map numbers, i.e., the attachment IDs.
69
+ linker_idx2map = {}
70
+ for atom in linker_mol.GetAtoms():
71
+ if atom.GetAtomicNum() == 0:
72
+ linker_idx2map[atom.GetIdx()] = atom.GetAtomMapNum()
73
+ if verbose:
74
+ print(f'linker indexes: {linker_idx2map}')
75
+ print('-' * 80)
76
+
77
+ substructs = {'linker': linker_smiles}
78
+
79
+ # After splitting the PROTAC with ReplaceCore, the fragments will have as
80
+ # attachment points the same atom indexes as the linker. We can then use the
81
+ # map numbers from the linker to identify the attachment points in the
82
+ # PROTAC fragments and assign the correct map number to them, i.e., the
83
+ # attachment ID.
84
+ for i, side_mol in enumerate(frags):
85
+
86
+ side_smiles = Chem.MolToSmiles(side_mol, canonical=True)
87
+
88
+ # Use a regex to get the number in the pattern, e.g., [9*], in the SMILES
89
+ attachment_point = re.findall(r'\[(\d+)\*\]', side_smiles)
90
+ if attachment_point:
91
+ attachment_point = int(attachment_point[0])
92
+ else:
93
+ attachment_point = None
94
+
95
+ if verbose:
96
+ print(f'Side {i + 1} SMILES: {side_smiles}')
97
+ print(f'Attachment point: {attachment_point}')
98
+ safe_display(side_mol)
99
+
100
+ # Get the map from the linker
101
+ linker_attachment_point = linker_idx2map.get(attachment_point, None)
102
+
103
+ # Modify the SMILES to include the map number
104
+ if linker_attachment_point is not None:
105
+ side_smiles = re.sub(r'\[(\d+)\*\]', f'[*:{linker_attachment_point}]', side_smiles)
106
+ if f'[*:{e3_attachment_id}]' in side_smiles:
107
+ substructs['e3'] = canonize_smiles(side_smiles)
108
+ elif f'[*:{poi_attachment_id}]' in side_smiles:
109
+ substructs['poi'] = canonize_smiles(side_smiles)
110
+
111
+ if verbose:
112
+ print(f'Modified SMILES: {side_smiles}')
113
+ safe_display(Chem.MolFromSmiles(side_smiles))
114
+
115
+ # Canonize the substructures SMILES
116
+ substructs = {k: canonize_smiles(v) for k, v in substructs.items()}
117
+
118
+ # Check that the reassembled PROTAC matches the original PROTAC
119
+ if not check_reassembly(protac_smiles, '.'.join(substructs.values())):
120
+ return None
121
+
122
+ return substructs
123
+
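The `[n*]` to `[*:id]` relabeling above is plain string surgery on the fragment SMILES; a small sketch with a made-up fragment:

```python
import re

side = '[9*]c1ccc(CN)cc1'                      # fragment SMILES as emitted by ReplaceCore
side = re.sub(r'\[(\d+)\*\]', '[*:2]', side)   # tag it with the E3 attachment ID
print(side)                                    # [*:2]c1ccc(CN)cc1
```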
124
+
125
+ def get_attachment_bonds(mol: Chem.Mol, match_atoms: List[int]) -> List[int]:
126
+ """ Get the bonds to break to separate the substructure from the PROTAC or R-groups molecule.
127
+
128
+ Args:
129
+ mol: The molecule to break, i.e., the PROTAC.
130
+ match_atoms: The atoms matched in the PROTAC molecule, from the GetSubstructMatch function.
131
+
132
+ Returns:
133
+ List[int]: The bond indices to break.
134
+ """
135
+ bonds_to_break = []
136
+ for idx in match_atoms:
137
+ atom = mol.GetAtomWithIdx(idx)
138
+ # Skip non-heavy atoms
139
+ if atom.GetAtomicNum() == 1:
140
+ continue
141
+ for bond in atom.GetBonds():
142
+ neighbor_idx = bond.GetOtherAtomIdx(idx)
143
+ # Skip if the neighbor atom is a hydrogen
144
+ if mol.GetAtomWithIdx(neighbor_idx).GetAtomicNum() == 1:
145
+ continue
146
+ if neighbor_idx not in match_atoms:
147
+ bonds_to_break.append(bond.GetIdx())
148
+ # If more than one bond is found, e.g., if the substructure is
149
+ # connected to the PROTAC/R-groups in multiple places like in a
150
+ # ring, reset list of bonds and go to the next atom.
151
+ if len(bonds_to_break) > 1:
152
+ bonds_to_break = []
153
+ break
154
+ return bonds_to_break
155
+
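A toy check of the boundary-bond logic above (ethylbenzene, with the phenyl ring as the matched "substructure"):

```python
from rdkit import Chem

mol = Chem.MolFromSmiles('CCc1ccccc1')
ring = Chem.MolFromSmiles('c1ccccc1')
match = mol.GetSubstructMatch(ring)   # indices of the six ring atoms
print(get_attachment_bonds(mol, match))
# One bond crosses the match boundary, so a single bond index is returned.
```

If the match touches the rest of the molecule through more than one bond, the list is reset, which the callers above treat as a failed split.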
156
+
157
+ def get_substructs_from_unmapped_e3_poi(
158
+ protac_smiles: str,
159
+ mol_protac: Chem.Mol,
160
+ mol_poi: Chem.Mol,
161
+ mol_e3: Chem.Mol,
162
+ poi_attachment_id: int = 1,
163
+ e3_attachment_id: int = 2,
164
+ verbose: int = 0,
165
+ stats: Counter = None,
166
+ ) -> Optional[Dict[str, str]]:
167
+ """ Get the matches of the POI, E3, and linker in the PROTAC molecule.
168
+
169
+ This function will return the substructures given a PROTAC and its unmapped
170
+ POI and E3 ligand substructures, _i.e._, they do not need to have the
171
+ attachment points in their SMILES strings.
172
+
173
+ Args:
174
+ mol_protac: The PROTAC molecule.
175
+ mol_poi: The POI ligand molecule. Must NOT contain the attachment point.
176
+ mol_e3: The E3 binder molecule. Must NOT contain the attachment point.
177
+ verbose: The verbosity level.
178
+
179
+ Returns:
180
+ Dict: The matches of the POI, E3, and linker in the PROTAC molecule. None if no match is found.
181
+ """
182
+ if verbose:
183
+ safe_display(mol_protac)
184
+
185
+ poi_match = mol_protac.GetSubstructMatch(mol_poi, useChirality=True)
186
+
187
+ # Get bonds to break to separate the POI ligand
188
+ bonds_to_break_poi = get_attachment_bonds(mol_protac, poi_match)
189
+
190
+ # Return if no bonds are found
191
+ if len(bonds_to_break_poi) != 1:
192
+ if stats is not None:
193
+ stats['multiple POI attachment bonds'] += 1
194
+ if verbose:
195
+ print('ERROR: Multiple POI attachment bonds')
196
+ return None
197
+
198
+ # Break the bonds to isolate the POI ligand
199
+ frag_mol_poi = Chem.FragmentOnBonds(mol_protac, bonds_to_break_poi, addDummies=True, dummyLabels=[(poi_attachment_id, poi_attachment_id)])
200
+
201
+ # Get the fragments resulting from bond breaking
202
+ try:
203
+ frags = Chem.GetMolFrags(frag_mol_poi, asMols=True, sanitizeFrags=True)
204
+ except Exception as e:
205
+ print(e)
206
+ return None
207
+
208
+ # Identify the POI ligand fragment
209
+ poi_fragment = None
210
+ for frag in frags:
211
+ if frag.HasSubstructMatch(mol_poi):
212
+ poi_fragment = frag
213
+ break
214
+ if poi_fragment is None:
215
+ if stats is not None:
216
+ stats['POI fragment not found'] += 1
217
+ if verbose:
218
+ print('ERROR: POI fragment not found')
219
+ return None
220
+
221
+ # Combine the remaining fragments to get the R-groups
222
+ # NOTE: There must be exactly one remaining fragment; this is verified below.
223
+ r_group_mol = [frag for frag in frags if frag != poi_fragment]
224
+ if len(r_group_mol) != 1:
225
+ if stats is not None:
226
+ stats['multiple POI fragments'] += 1
227
+ if verbose:
228
+ for frag in frags:
229
+ safe_display(frag)
230
+ print('ERROR: Multiple POI fragments')
231
+ return None
232
+ r_group_mol = r_group_mol[0]
233
+
234
+ if verbose:
235
+ print('POI:', Chem.MolToSmiles(poi_fragment, canonical=True))
236
+ safe_display(poi_fragment)
237
+
238
+ e3_match = r_group_mol.GetSubstructMatch(mol_e3, useChirality=True)
239
+
240
+ # Get bonds to break to isolate the E3 binder
241
+ bonds_to_break_e3 = get_attachment_bonds(r_group_mol, e3_match)
242
+
243
+ # Return if no bonds are found
244
+ if len(bonds_to_break_e3) != 1:
245
+ if stats is not None:
246
+ stats['multiple E3 attachment bonds'] += 1
247
+ if verbose:
248
+ safe_display(r_group_mol)
249
+ print('ERROR: Multiple E3 attachment bonds')
250
+ return None
251
+
252
+ # Break the bonds to isolate the E3 binder
253
+ frag_mol_e3 = Chem.FragmentOnBonds(r_group_mol, bonds_to_break_e3, addDummies=True, dummyLabels=[(e3_attachment_id, e3_attachment_id)])
254
+
255
+ # Get fragments after breaking bonds in R-groups
256
+ try:
257
+ frags = Chem.GetMolFrags(frag_mol_e3, asMols=True, sanitizeFrags=True)
258
+ except Exception as e:
259
+ print(e)
260
+ return None
261
+
262
+ # Identify the E3 binder fragment
263
+ e3_fragment = None
264
+ for frag in frags:
265
+ if frag.HasSubstructMatch(mol_e3):
266
+ e3_fragment = frag
267
+ break
268
+ if e3_fragment is None:
269
+ if stats is not None:
270
+ stats['E3 fragment not found'] += 1
271
+ if verbose:
272
+ print('ERROR: E3 fragment not found')
273
+ return None
274
+
275
+ if verbose:
276
+ print('E3:', Chem.MolToSmiles(e3_fragment, canonical=True))
277
+ safe_display(e3_fragment)
278
+
279
+ # The remaining fragment is the linker
280
+ # NOTE: There must be exactly one remaining fragment, i.e., the linker; this is verified below.
281
+ linker_mol = [frag for frag in frags if frag != e3_fragment]
282
+ if len(linker_mol) != 1:
283
+ if stats is not None:
284
+ stats['multiple E3 fragments'] += 1
285
+ if verbose:
286
+ for frag in frags:
287
+ safe_display(frag)
288
+ print('ERROR: Multiple E3 fragments')
289
+ return None
290
+ linker_mol = linker_mol[0]
291
+
292
+ poi_smiles = Chem.MolToSmiles(poi_fragment, canonical=True).replace(f'[{poi_attachment_id}*]', f'[*:{poi_attachment_id}]')
293
+ e3_smiles = Chem.MolToSmiles(e3_fragment, canonical=True).replace(f'[{e3_attachment_id}*]', f'[*:{e3_attachment_id}]')
294
+ linker_smiles = Chem.MolToSmiles(linker_mol, canonical=True).replace(f'[{poi_attachment_id}*]', f'[*:{poi_attachment_id}]').replace(f'[{e3_attachment_id}*]', f'[*:{e3_attachment_id}]')
295
+
296
+ # Get the substructure names and canonize their SMILES
297
+ substructs = {'poi': poi_smiles, 'e3': e3_smiles, 'linker': linker_smiles}
298
+ substructs = {k: canonize_smiles(v) for k, v in substructs.items()}
299
+
300
+ if verbose:
301
+ print('Linker:', Chem.MolToSmiles(linker_mol, canonical=True))
302
+ safe_display(linker_mol)
303
+
304
+ # Check that the reassembled PROTAC matches the original PROTAC
305
+ if check_reassembly(protac_smiles, '.'.join(substructs.values())):
306
+ return substructs
307
+
308
+ if stats is not None:
309
+ stats['reassembling failed'] += 1
310
+ if verbose:
311
+ print('ERROR: Reassembling failed')
312
+ return None
313
+
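The splitting primitive used above is `Chem.FragmentOnBonds` with labelled dummy atoms; a self-contained sketch on a toy ligand-linker bond:

```python
from rdkit import Chem

mol = Chem.MolFromSmiles('c1ccccc1CCN')   # pretend the phenyl ring is a ligand
bond = mol.GetBondBetweenAtoms(5, 6)      # the ring-to-chain bond
frags = Chem.FragmentOnBonds(mol, [bond.GetIdx()], addDummies=True, dummyLabels=[(1, 1)])
print(Chem.MolToSmiles(frags))            # e.g. '[1*]CCN.[1*]c1ccccc1'
```

`Chem.GetMolFrags` then separates the dot-disconnected pieces into individual molecules, exactly as done for the POI and E3 sides above.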
314
+
315
+ def get_substructure_from_non_perfect_match(
316
+ protac_mol: Chem.Mol,
317
+ substruct_mol: Chem.Mol,
318
+ attachment_id: int,
319
+ verbose: int = 0,
320
+ ) -> Chem.Mol:
321
+ """ Extract the correct substructure from a PROTAC molecule, given the
322
+ molecule of a wrong substructure resulting in many fragments and matches.
323
+
324
+ Sometimes the substructure we have is not a _perfect_ substructure of the
325
+ PROTAC, _i.e._, it will generate more than two fragments when trying to
326
+ replace the PROTAC core with it. In this case, this function will perform
327
+ the following steps:
328
+
329
+ 1. Get the largest fragment by trying to replace the PROTAC core with the
330
+ substructure. This largest fragment will be the other substructure plus
331
+ the linker.
332
+ 2. We can now remove the largest fragment from the PROTAC to get the
333
+ "original" substructure without the smaller dangling fragments.
334
+
335
+ Args:
336
+ protac_mol (Chem.Mol): The PROTAC molecule.
337
+ substruct_mol (Chem.Mol): The molecule of the wrong substructure, either the POI ligand or the E3 binder.
338
+ attachment_id (int): The attachment ID.
339
+
340
+ Returns:
341
+ Chem.Mol: The extracted substructure molecule. If failing, it will return None.
342
+ """
343
+ # Remove the substructure, even if there are "dangling" fragments, to obtain: PROTAC - substruct = (POI + Linker) + remainders
344
+ linker_and_other_mol = Chem.DeleteSubstructs(protac_mol, substruct_mol, useChirality=True)
345
+
346
+ # Get the largest fragment, i.e., the PROTAC - substruct = POI + Linker
347
+ try:
348
+ fragments = Chem.GetMolFrags(linker_and_other_mol, asMols=True)
349
+ except Exception as e:
350
+ if verbose:
351
+ print(e)
352
+ return None
353
+
354
+ if len(fragments) == 1:
355
+ if verbose:
356
+ print("WARNING. There are no small fragments, there's only one fragment.")
357
+
358
+ if not fragments:
359
+ if verbose:
360
+ print('ERROR. No fragments found.')
361
+ return None
362
+ largest_fragment = max(fragments, key=lambda x: x.GetNumAtoms())
363
+
364
+ # Get the match of the largest fragment in the PROTAC molecule
365
+ largest_match = protac_mol.GetSubstructMatch(largest_fragment, useChirality=True)
366
+
367
+ # Get bonds to break to isolate the substructure, i.e., the opposite of the POI + Linker
368
+ bonds_to_break = get_attachment_bonds(protac_mol, largest_match)
369
+
370
+ if len(bonds_to_break) != 1:
371
+ if verbose:
372
+ print(f'ERROR. The bond to break is not a single one: {bonds_to_break}')
373
+ return None
374
+
375
+ # Break the bonds to isolate the substructure
376
+ frag_mol_substruct = Chem.FragmentOnBonds(protac_mol, bonds_to_break, addDummies=True, dummyLabels=[(attachment_id, attachment_id)])
377
+
378
+ # Get fragments after breaking bonds, i.e., the POI + Linker and the substructure without "remainders"
379
+ try:
380
+ frags = Chem.GetMolFrags(frag_mol_substruct, asMols=True, sanitizeFrags=True)
381
+ except Exception as e:
382
+ if verbose:
383
+ print(e)
384
+ return None
385
+
386
+ # Get the smallest between the substructure and the POI+Linker fragments
387
+ substruct_mol = min(frags, key=lambda x: x.GetNumAtoms())
388
+ substruct_smiles = Chem.MolToSmiles(substruct_mol, canonical=True).replace(f'[{attachment_id}*]', f'[*:{attachment_id}]')
389
+ substruct_mol = Chem.MolFromSmiles(canonize(substruct_smiles))
390
+
391
+ # Check that the substructure matches in the PROTAC molecule
392
+ if not protac_mol.HasSubstructMatch(dummy2query(substruct_mol), useChirality=True):
393
+ if verbose:
394
+ print('ERROR. Substructure does not match in PROTAC molecule:')
395
+ print('PROTAC molecule:')
396
+ safe_display(protac_mol)
397
+ print('Substructure molecule:')
398
+ safe_display(substruct_mol)
399
+ return None
400
+
401
+ return substruct_mol
402
+
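`Chem.DeleteSubstructs`, used in the first step, removes every matched atom even when that leaves dangling pieces; in miniature:

```python
from rdkit import Chem

mol = Chem.MolFromSmiles('CCOC(=O)c1ccccc1')   # ethyl benzoate
patt = Chem.MolFromSmiles('c1ccccc1')
print(Chem.MolToSmiles(Chem.DeleteSubstructs(mol, patt)))   # e.g. 'CCOC=O'
```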
403
+
404
+ def get_mapped_substr_from_protac(
405
+ protac: Chem.Mol,
406
+ substr: Chem.Mol,
407
+ attachment_id: int = 1,
408
+ ) -> Optional[Chem.Mol]:
409
+ """ Get the mapped substructure from a PROTAC molecule and an unmapped substructure.
410
+
411
+ Args:
412
+ protac: The PROTAC molecule.
413
+ substr: The unmapped substructure.
414
+ attachment_id: The attachment point ID to be assigned to the substructure.
415
+
416
+ Returns:
417
+ The mapped substructure molecule. None if the function fails to find the substructure.
418
+ """
419
+ num_matches = len(protac.GetSubstructMatches(substr, useChirality=True))
420
+ if num_matches != 1:
421
+ return None
422
+ other_substr = Chem.ReplaceCore(protac, substr, labelByIndex=False, replaceDummies=False)
423
+ if other_substr is None:
424
+ return None
425
+ mapped_substr = Chem.ReplaceCore(protac, remove_dummy_atoms(other_substr), labelByIndex=False, replaceDummies=False)
426
+ if mapped_substr is None:
427
+ return None
428
+ mapped_smiles = Chem.MolToSmiles(mapped_substr, canonical=True)
429
+ # Replace "[1*]" or "[2*]" with the correct attachment point with a regex
430
+ mapped_smiles = re.sub(r'\[(\d+)\*\]', f'[*:{attachment_id}]', mapped_smiles)
431
+ mapped_smiles = canonize(mapped_smiles)
432
+ if mapped_smiles is None:
433
+ return None
434
+ return Chem.MolFromSmiles(mapped_smiles)
435
+
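The double-`ReplaceCore` trick above can be seen on a toy molecule: the first call strips the known core to leave the side chain, and a second call strips that side chain from the parent to leave the core with a labelled attachment point:

```python
from rdkit import Chem

mol = Chem.MolFromSmiles('OCCc1ccccc1')   # stand-in for the PROTAC
core = Chem.MolFromSmiles('c1ccccc1')     # unmapped substructure
side = Chem.ReplaceCore(mol, core, labelByIndex=False, replaceDummies=False)
print(Chem.MolToSmiles(side))             # e.g. '[1*]CCO'
```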
436
+
437
+ def get_substructs_from_substr_and_linker(
438
+ protac_smiles: str,
439
+ protac: Chem.Mol,
440
+ substr: Chem.Mol,
441
+ linker: Chem.Mol,
442
+ attachment_id: int = 1,
443
+ poi_attachment_id: int = 1,
444
+ e3_attachment_id: int = 2,
445
+ verbose: int = 0,
446
+ stats: Counter = None,
447
+ ) -> Optional[Dict[str, str]]:
448
+ """ Get the substructures of a PROTAC molecule from an unmapped substructure and linker.
449
+
450
+ Args:
451
+ protac_smiles: The SMILES of the PROTAC molecule.
452
+ protac: The RDKit molecule object of the PROTAC.
453
+ substr: The RDKit molecule object of the currently matching substructure. Should be UNMAPPED.
454
+ linker: The RDKit molecule object of the linker.
455
+ attachment_id: The attachment point ID of the currently matching substructure.
456
+ verbose: The verbosity level.
457
+
458
+ Returns:
459
+ Dict: The substructures of the PROTAC molecule. None if the function fails to find the substructures.
460
+ """
461
+ if attachment_id not in [poi_attachment_id, e3_attachment_id]:
462
+ raise ValueError('Attachment ID must be either 1 or 2')
463
+
464
+ if substr is None:
465
+ return None
466
+
467
+ subr_matches = list(protac.GetSubstructMatches(substr, useChirality=True))
468
+ if len(subr_matches) != 1:
469
+ if stats is not None:
470
+ stats['multiple substructure matches'] += 1
471
+ if verbose:
472
+ print('ERROR: Multiple substructure matches')
473
+ return None
474
+ subr_match = subr_matches[0]
475
+
476
+ mapped_substr = get_mapped_substr_from_protac(protac, substr, attachment_id)
477
+ if mapped_substr is None:
478
+ if stats is not None:
479
+ stats['mapped substructure not found'] += 1
480
+ if verbose:
481
+ print('ERROR: Mapped substructure not found')
482
+ return None
483
+
484
+ linker_matches = protac.GetSubstructMatches(remove_dummy_atoms(linker), useChirality=True)
+ # Select the linker match whose intersection with the substructure match
+ # is exactly one atom, i.e., the shared attachment-point atom.
+ linker_match = None
+ for match in linker_matches:
+ if len(set(subr_match).intersection(match)) == 1:
+ linker_match = match
+ break
+ if linker_match is None:
+ if verbose:
+ print('ERROR: No linker match sharing a single attachment atom')
+ return None
491
+
492
+ # Based on the linker match found, remove it from the PROTAC
493
+ emol = Chem.EditableMol(protac)
494
+
495
+ # Remove atoms in descending order of their indices
496
+ for idx in sorted(linker_match, reverse=True):
497
+ emol.RemoveAtom(idx)
498
+ # Get the modified molecule
499
+ try:
500
+ protac_fragments = emol.GetMol()
501
+ except Exception as e:
502
+ if verbose:
503
+ print(e)
504
+ return None
505
+ try:
506
+ Chem.SanitizeMol(protac_fragments)
507
+ except Exception as e:
508
+ if verbose:
509
+ print(e)
510
+ return None
511
+ if verbose:
512
+ img = Draw.MolToImage(protac_fragments, highlightAtoms=linker_match, size=(800, 300))
513
+ safe_display(img)
514
+
515
+ # Get the fragments after removing the linker
516
+ try:
517
+ fragments = Chem.GetMolFrags(protac_fragments, asMols=True, sanitizeFrags=True)
518
+ except Exception as e:
519
+ if verbose:
520
+ print(e)
521
+ return None
522
+
523
+ if len(fragments) != 2:
524
+ if stats is not None:
525
+ stats['multiple fragments after removing the linker'] += 1
526
+ if verbose:
527
+ for frag in fragments:
528
+ safe_display(frag)
529
+ print('ERROR: Multiple fragments after removing the linker')
530
+ return None
531
+
532
+ substructs = {}
533
+ substructs['linker'] = Chem.MolToSmiles(linker, canonical=True)
534
+ for frag in fragments:
535
+ if frag.HasSubstructMatch(substr, useChirality=True):
536
+ label = 'e3' if attachment_id == e3_attachment_id else 'poi'
537
+ substructs[label] = Chem.MolToSmiles(mapped_substr, canonical=True)
538
+ # Replace "[1*]" or "[2*]" with the correct attachment point with a regex
539
+ substructs[label] = re.sub(r'\[(\d+)\*\]', f'[*:{attachment_id}]', substructs[label])
540
+ if verbose:
541
+ print(f'Found {label.capitalize()} fragment.')
542
+ img = Draw.MolToImage(Chem.MolFromSmiles(substructs[label]), size=(800, 300))
543
+ safe_display(img)
544
+ else:
545
+ label = 'e3' if attachment_id == poi_attachment_id else 'poi'
546
+ other_attachment_id = e3_attachment_id if label == 'e3' else poi_attachment_id
547
+
548
+ other_substr = get_mapped_substr_from_protac(protac, frag, other_attachment_id)
549
+ if other_substr is None:
550
+ return None
551
+ substructs[label] = Chem.MolToSmiles(other_substr, canonical=True)
552
+
553
+ if verbose:
554
+ print(f'Found {label.capitalize()} fragment.')
555
+ img = Draw.MolToImage(Chem.MolFromSmiles(substructs[label]), size=(800, 300))
556
+ safe_display(img)
557
+ # Canonicalize the SMILES strings
558
+ substructs = {k: canonize(v) for k, v in substructs.items()}
559
+
560
+ # Check that the reassembled PROTAC matches the original PROTAC
561
+ if not check_reassembly(protac_smiles, '.'.join(substructs.values()), stats, verbose):
562
+ return None
563
+
564
+ return substructs
565
+
566
+
567
+ def swap_attachment_points(
568
+ s: str,
569
+ poi_attachment_id: int = 1,
570
+ e3_attachment_id: int = 2,
571
+ ) -> str:
572
+ """ Swaps the attachment points in a SMARTS string.
573
+
574
+ Args:
575
+ s: The input SMILES (or SMARTS) string containing the attachment points.
576
+
577
+ Returns:
578
+ The string with the attachment points swapped, canonicalized.
579
+ """
580
+ tmp_e3_id = '^^^^E3^^^^'
581
+ tmp_poi_id = '^^^^POI^^^^'
582
+ s = s.replace(f'[*:{poi_attachment_id}]', f'[*:{tmp_poi_id}]')
583
+ s = s.replace(f'[*:{e3_attachment_id}]', f'[*:{tmp_e3_id}]')
584
+ s = s.replace(f'[*:{tmp_poi_id}]', f'[*:{e3_attachment_id}]')
585
+ s = s.replace(f'[*:{tmp_e3_id}]', f'[*:{poi_attachment_id}]')
586
+ return canonize(s)
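Illustration of the placeholder-token swap (output shown up to SMILES canonicalization):

```python
linker = '[*:1]NC(=O)CCOCC[*:2]'
print(swap_attachment_points(linker))
# The POI-side tag [*:1] and the E3-side tag [*:2] trade places:
# canonical form of '[*:2]NC(=O)CCOCC[*:1]'
```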
protac_splitter/data/generation/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from .generation import generate_protacs
2
+ from .functional_groups import (
3
+ get_functional_group_at_attachment,
4
+ get_functional_groups_distributions,
5
+ )
6
+
7
+ __all__ = [
8
+ 'generate_protacs',
9
+ 'get_functional_group_at_attachment',
10
+ 'get_functional_groups_distributions',
11
+ ]
protac_splitter/data/generation/functional_groups.py ADDED
@@ -0,0 +1,400 @@
1
+ from typing import Dict, Optional, Union
2
+ from collections import defaultdict, Counter
3
+ import json
4
+
5
+ import pandas as pd
6
+ from rdkit import Chem
7
+ from rdkit.Chem import Draw
8
+ from tqdm import tqdm
9
+
10
+ from protac_splitter.chemoinformatics import (
11
+ get_atom_idx_at_attachment,
12
+ canonize_smarts,
13
+ )
14
+ from protac_splitter.display_utils import (
15
+ safe_display,
16
+ display_mol,
17
+ )
18
+
19
+
20
+ def get_functional_group_at_attachment(
21
+ protac: Chem.Mol,
22
+ substruct: Chem.Mol,
23
+ linker: Chem.Mol,
24
+ n_hops: int = 1,
25
+ timeout: Optional[Union[int, float]] = None,
26
+ return_dict: bool = False,
27
+ verbose: int = 0,
28
+ ) -> Union[str, Dict[str, str]]:
29
+ """ Get the functional group at the attachment point of a substructure in the PROTAC molecule.
30
+
31
+ Args:
32
+ protac: The PROTAC molecule.
33
+ substruct: The substructure of the PROTAC that contains the attachment point, e.g., the POI ligand or the E3 binder.
34
+ linker: The linker molecule.
35
+ n_hops: The number of hops to consider for the neighborhood.
36
+ timeout: The timeout for the substructure search.
37
+ return_dict: Whether to return the functional groups as a dictionary.
38
+ verbose: Verbosity level.
39
+
40
+ Returns:
41
+ str | Dict[str, str]: The SMARTS of the functional group at the attachment point. If return_dict is True, a dictionary with the SMARTS of the functional groups at the attachment point and at the "two sides" of the attachment point (keys: 'attachment', 'substruct', 'linker').
42
+ """
43
+ protac = Chem.AddHs(protac)
44
+ substruct = Chem.AddHs(substruct)
45
+
46
+ if linker is not None:
47
+ linker = Chem.AddHs(linker)
48
+
49
+ attachment_idxs = get_atom_idx_at_attachment(
50
+ protac=protac,
51
+ substruct=substruct,
52
+ linker=linker,
53
+ timeout=timeout,
54
+ return_dict=True,
55
+ verbose=0,
56
+ )
57
+ # Get all neighboring atoms that are n_hops away from the attachment point
58
+ if attachment_idxs is None:
59
+ return None
60
+ if len(attachment_idxs) != 2:
61
+ return None
62
+ if verbose:
63
+ print(f'Attachment points: {attachment_idxs}')
64
+ img = Draw.MolToImage(protac, highlightAtoms=attachment_idxs.values(), size=(800, 500))
65
+ safe_display(img)
66
+ print('Neighbors:')
67
+
68
+ # Recursively find neighbors at n_hops distance
69
+ neighborhood = set([protac.GetAtomWithIdx(idx) for idx in attachment_idxs.values()])
70
+ def find_neighbors(atom, hops, excluded_atom_idx=None):
71
+ if hops <= 0:
72
+ return
73
+ for neighbor in atom.GetNeighbors():
74
+ if excluded_atom_idx is not None and neighbor.GetIdx() == excluded_atom_idx:
75
+ neighborhood.add(neighbor)
76
+ continue
77
+ neighborhood.add(neighbor)
78
+ find_neighbors(neighbor, hops - 1)
79
+
80
+ for idx in attachment_idxs.values():
81
+ find_neighbors(protac.GetAtomWithIdx(idx), n_hops)
82
+
83
+ # Display the neighborhood
84
+ if verbose:
85
+ print(f'Neighbors at {n_hops} hops:')
86
+ # Get options to display all hydrogen atoms
87
+ options = Draw.DrawingOptions()
88
+ # Add a legend to the image
89
+ options.legend = 'Neighbors at attachment points'
90
+ img = Draw.MolToImage(protac, highlightAtoms=[a.GetIdx() for a in neighborhood], size=(800, 500), options=options)
91
+ safe_display(img)
92
+
93
+ # # NOTE: The following is an overkill, there is an RDKit function to extract a substructure
94
+ # neighborhood_mol = extract_atoms_as_molecule(protac, [a.GetIdx() for a in neighborhood])
95
+ # neighborhood_smarts = canonize_smarts(Chem.MolToSmarts(neighborhood_mol))
96
+
97
+ # Extract the SMARTS given the atom indices of the neighborhood
98
+ neighborhood_idxs = [a.GetIdx() for a in neighborhood]
99
+ neighborhood_smarts = Chem.MolFragmentToSmarts(protac, neighborhood_idxs)
100
+ neighborhood_smarts = canonize_smarts(neighborhood_smarts)
101
+
102
+ if verbose:
103
+ print(neighborhood_smarts)
104
+ display_mol(Chem.MolFromSmarts(neighborhood_smarts), display_svg=False)
105
+
106
+ if return_dict:
107
+ smarts = {}
108
+ smarts['attachment'] = neighborhood_smarts
109
+ # Get the SMARTS at the attachment point and at its "two sides"
110
+ for side, idx in attachment_idxs.items():
111
+ # NOTE: We know that attachment_idxs is a dictionary with two keys,
112
+ # 'susbtruct' and 'linker', so we can directly use the other key
113
+ other_side = 'linker' if side == 'substruct' else 'substruct'
114
+ excluded_atom_idx = attachment_idxs[other_side]
115
+ neighborhood = {protac.GetAtomWithIdx(idx)}
116
+ find_neighbors(protac.GetAtomWithIdx(idx), n_hops, excluded_atom_idx=excluded_atom_idx)
117
+
118
+ # Get the atom indices of the neighborhood
119
+ neighborhood_idxs = [a.GetIdx() for a in neighborhood]
120
+
121
+ # Copy the PROTAC molecule and set the excluded_atom_idx to a dummy
122
+ p = Chem.Mol(protac)
123
+ p.GetAtomWithIdx(excluded_atom_idx).SetAtomicNum(0)
124
+
125
+ # Extract the SMARTS from the copied PROTAC given the indeces
126
+ s = Chem.MolFragmentToSmarts(p, neighborhood_idxs)
127
+ smarts[other_side] = canonize_smarts(s)
128
+
129
+ return smarts
130
+
131
+ return neighborhood_smarts
132
+
133
+
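
A call sketch for the function above, assuming `protac_mol`, `poi_mol`, and `linker_mol` are RDKit molecules built elsewhere (e.g., from the dataframe columns used further below):

    fg = get_functional_group_at_attachment(
        protac=protac_mol,
        substruct=poi_mol,
        linker=linker_mol,
        n_hops=1,
        return_dict=True,
    )
    if fg is not None:
        # SMARTS around the attachment bond, plus its substructure and linker sides
        print(fg['attachment'], fg['substruct'], fg['linker'])
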
+ def get_functional_group_at_attachment_side(
+     substruct: Chem.Mol,
+     attachment_id: Optional[int] = None,
+     n_hops: int = 2,
+     add_Hs: bool = True,
+ ) -> Optional[str]:
+     """ Get the functional group at the attachment point of a substructure in the PROTAC molecule.
+
+     Args:
+         substruct: The substructure of the PROTAC that contains the attachment point, e.g., the POI or E3 ligase.
+         attachment_id: The attachment point ID in the substructure. E.g., 1 for the POI, as in "[*:1]".
+         n_hops: The number of hops to consider for the neighborhood. Default is 2.
+         add_Hs: Whether to add hydrogens to the substructure.
+
+     Returns:
+         str: The SMARTS of the functional group at the attachment point. None if failed.
+     """
+     if add_Hs:
+         substruct = Chem.AddHs(substruct)
+
+     # Get the atom index of the attachment point, i.e., a dummy atom
+     attachment_idx2map = {}
+     for atom in substruct.GetAtoms():
+         if atom.GetAtomicNum() == 0:
+             # Get the mapped atom index
+             attachment_idx2map[atom.GetIdx()] = atom.GetAtomMapNum()
+
+     if not attachment_idx2map:
+         return None
+
+     # If we are dealing with a linker, get the specific attachment point
+     if attachment_id is not None:
+         attachment_idx = [k for k, v in attachment_idx2map.items() if v == attachment_id]
+         if not attachment_idx:
+             return None
+         attachment_idx = attachment_idx[0]
+     else:
+         attachment_idx = list(attachment_idx2map.keys())[0]
+
+     neighborhood = {substruct.GetAtomWithIdx(attachment_idx)}
+     def find_neighbors(atom, hops):
+         if hops <= 0:
+             return
+         for neighbor in atom.GetNeighbors():
+             neighborhood.add(neighbor)
+             find_neighbors(neighbor, hops - 1)
+
+     find_neighbors(substruct.GetAtomWithIdx(attachment_idx), n_hops)
+     neighborhood_idxs = [a.GetIdx() for a in neighborhood]
+
+     neighborhood_smarts = Chem.MolFragmentToSmarts(substruct, neighborhood_idxs)
+     if neighborhood_smarts:
+         return canonize_smarts(neighborhood_smarts)
+
+     return None
+
+
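
A quick sketch of the side-specific variant above; the fragment is a toy example, not a real warhead:

    from rdkit import Chem

    poi_fragment = Chem.MolFromSmiles('c1ccccc1C(=O)N[*:1]')
    # SMARTS of the 2-hop environment around the dummy atom labelled [*:1]
    print(get_functional_group_at_attachment_side(poi_fragment, attachment_id=1, n_hops=2))
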
+ def get_functional_groups_distributions(
+     df: pd.DataFrame,
+     get_side_chain_info: bool = False,
+     timeout: Optional[Union[int, float]] = None,
+     filename_distributions: Optional[str] = None,
+     filename_mappings: Optional[str] = None,
+     filename_df_with_functional_groups: Optional[str] = None,
+     load_from_file: bool = True,
+     verbose: int = 0,
+ ) -> Dict[str, Dict[str, set]]:
+     """ Get the distributions of functional groups at attachment points in a dataframe of PROTACs.
+
+     The input dataframe should contain the following columns:
+     - 'PROTAC SMILES': The SMILES of the PROTAC.
+     - 'POI Ligand SMILES with direction': The SMILES of the POI ligand.
+     - 'Linker SMILES with direction': The SMILES of the linker.
+     - 'E3 Binder SMILES with direction': The SMILES of the E3 binder.
+
+     Args:
+         df: The DataFrame containing the PROTACs.
+         get_side_chain_info: Whether to get the side chain information along with the functional groups at the attachment points.
+         timeout: The timeout for the substructure search. Default is None.
+         filename_distributions: If given, the JSON file to save the distributions to (and to load them from).
+         filename_mappings: If given, the JSON file to save the mappings to (and to load them from).
+         filename_df_with_functional_groups: If given, the CSV file to save the annotated DataFrame to.
+         load_from_file: Whether to load previously computed distributions and mappings from file instead of recomputing them.
+         verbose: Verbosity level.
+
+     Returns:
+         Dict[str, Dict[str, set]]: The distributions of functional groups at attachment points in PROTACs.
+     """
+     smarts_counter = Counter()
+     e3_smarts_counter = Counter()
+     poi_smarts_counter = Counter()
+     substr_smarts_counter = {
+         'poi2linker': defaultdict(Counter),
+         'linker2poi': defaultdict(Counter),
+         'e32linker': defaultdict(Counter),
+         'linker2e3': defaultdict(Counter),
+     }
+     # Map each POI/E3 substructure to the set of functional groups that appear in the df
+     poi_substr2fg = defaultdict(set)
+     e3_substr2fg = defaultdict(set)
+     # Map each functional group to the set of substructures that appear in the df
+     poi_fg_2_substr = defaultdict(set)
+     e3_fg_2_substr = defaultdict(set)
+     substr_fg_2_linker = defaultdict(set)
+
+     linker2fg = defaultdict(dict)
+
+     if load_from_file:
+         if filename_distributions is not None and filename_mappings is not None:
+             with open(filename_distributions, 'r') as f:
+                 fg_distr = json.load(f)
+             with open(filename_mappings, 'r') as f:
+                 fg_mappings = json.load(f)
+             ret = {}
+             ret.update(fg_distr)
+             ret.update(fg_mappings)
+             return ret
+         else:
+             print('WARNING: No filename provided to load the mappings from. The functional groups will be recomputed.')
+
+     df_with_functional_groups = []
+
+     for i, row in tqdm(df.iterrows(), total=len(df)):
+         protac_smiles = row['PROTAC SMILES']
+         poi_smiles = row['POI Ligand SMILES with direction']
+         linker_smiles = row['Linker SMILES with direction']
+         e3_smiles = row['E3 Binder SMILES with direction']
+
+         protac = Chem.MolFromSmiles(protac_smiles)
+         poi = Chem.MolFromSmiles(poi_smiles)
+         e3 = Chem.MolFromSmiles(e3_smiles)
+         linker = Chem.MolFromSmiles(linker_smiles)
+
+         if None in [protac, poi, e3, linker]:
+             print('WARNING: Could not parse the following SMILES:')
+             print(f'PROTAC: {protac_smiles}')
+             print(f'POI: {poi_smiles}')
+             print(f'Linker: {linker_smiles}')
+             print(f'E3: {e3_smiles}')
+             print('-' * 80)
+             # Skip the row: the calls below would fail on a None molecule
+             continue
+
+         # We have to take a bit of care with the linker, as it can be empty
+         try:
+             _ = Chem.molzip(Chem.MolFromSmiles('.'.join([poi_smiles, linker_smiles, e3_smiles])))
+         except Exception:
+             print(f'WARNING: The linker might be empty: {linker_smiles}')
+             linker = None
+
+         if linker is not None:
+             fg_poi = get_functional_group_at_attachment(protac, poi, linker, timeout=timeout, return_dict=get_side_chain_info)
+             fg_e3 = get_functional_group_at_attachment(protac, e3, linker, timeout=timeout, return_dict=get_side_chain_info)
+         else:
+             # If the linker is empty, then we use the other side as the linker
+             fg_poi = get_functional_group_at_attachment(protac, poi, e3, return_dict=get_side_chain_info)
+             fg_e3 = get_functional_group_at_attachment(protac, e3, poi, return_dict=get_side_chain_info)
+
+         if get_side_chain_info:
+             if fg_poi is not None:
+                 smarts_counter.update([fg_poi['attachment']])
+                 poi_smarts_counter.update([fg_poi['substruct']])
+                 substr_smarts_counter['poi2linker'][fg_poi['substruct']].update([fg_poi['linker']])
+                 substr_smarts_counter['linker2poi'][fg_poi['linker']].update([fg_poi['substruct']])
+                 linker2fg[linker_smiles]['poi'] = fg_poi['attachment']
+
+                 # NOTE: poi_substr2fg maps to a set, so use add() rather than append()
+                 poi_substr2fg[poi_smiles].add(fg_poi['attachment'])
+                 poi_fg_2_substr[fg_poi['attachment']].update([poi_smiles])
+
+             if fg_e3 is not None:
+                 smarts_counter.update([fg_e3['attachment']])
+                 e3_smarts_counter.update([fg_e3['substruct']])
+                 substr_smarts_counter['e32linker'][fg_e3['substruct']].update([fg_e3['linker']])
+                 substr_smarts_counter['linker2e3'][fg_e3['linker']].update([fg_e3['substruct']])
+                 linker2fg[linker_smiles]['e3'] = fg_e3['attachment']
+
+                 # NOTE: add the whole SMARTS string; update() on a string would add its single characters
+                 e3_substr2fg[e3_smiles].add(fg_e3['attachment'])
+                 e3_fg_2_substr[fg_e3['attachment']].update([e3_smiles])
+         else:
+             if fg_poi is not None:
+                 smarts_counter.update([fg_poi])
+                 poi_smarts_counter.update([fg_poi])
+                 poi_substr2fg[poi_smiles].update([fg_poi])
+                 poi_fg_2_substr[fg_poi].update([poi_smiles])
+                 substr_fg_2_linker[fg_poi].update([linker_smiles])
+             if fg_e3 is not None:
+                 smarts_counter.update([fg_e3])
+                 e3_smarts_counter.update([fg_e3])
+                 e3_substr2fg[e3_smiles].update([fg_e3])
+                 e3_fg_2_substr[fg_e3].update([e3_smiles])
+                 substr_fg_2_linker[fg_e3].update([linker_smiles])
+
+         # Update the DataFrame with the functional groups
+         if fg_poi is not None:
+             row['POI Ligand Functional Group'] = fg_poi
+         if fg_e3 is not None:
+             row['E3 Binder Functional Group'] = fg_e3
+         df_with_functional_groups.append(row)
+
+     # Normalize all the counts to probability distributions
+     fg_distr = {k: v / smarts_counter.total() for k, v in smarts_counter.items()}
+     e3_fg_distr = {k: v / e3_smarts_counter.total() for k, v in e3_smarts_counter.items()}
+     poi_fg_distr = {k: v / poi_smarts_counter.total() for k, v in poi_smarts_counter.items()}
+
+     # Sort the probability distributions
+     fg_distr = dict(sorted(fg_distr.items(), key=lambda x: x[1], reverse=True))
+     e3_fg_distr = dict(sorted(e3_fg_distr.items(), key=lambda x: x[1], reverse=True))
+     poi_fg_distr = dict(sorted(poi_fg_distr.items(), key=lambda x: x[1], reverse=True))
+
+     if not get_side_chain_info:
+         ret = {
+             'fg_distr': fg_distr,
+             'e3_fg_distr': e3_fg_distr,
+             'poi_fg_distr': poi_fg_distr,
+             'poi_fg_2_substr': poi_fg_2_substr,
+             'e3_fg_2_substr': e3_fg_2_substr,
+             'substr_fg_2_linker': substr_fg_2_linker,
+         }
+
+     # Normalize the linker-to-substructure counts to probability distributions
+     if get_side_chain_info:
+         side_fg_distr = defaultdict(dict)
+         for direction, smarts2counter in substr_smarts_counter.items():
+             for smarts, counter in smarts2counter.items():
+                 side_fg_distr[direction][smarts] = {k: v / counter.total() for k, v in counter.items()}
+                 side_fg_distr[direction][smarts] = dict(sorted(side_fg_distr[direction][smarts].items(), key=lambda x: x[1], reverse=True))
+
+             if verbose:
+                 # Display the top 5 functional groups
+                 print('-' * 80)
+                 print(f'{"-".join(direction.upper().split("2"))}:')
+                 print('-' * len(direction) + '-' * 2)
+                 for i, (smarts, probs) in enumerate(side_fg_distr[direction].items()):
+                     if i >= 5:
+                         break
+                     print(f'{smarts}:')
+                     for j, (sma, prob) in enumerate(probs.items()):
+                         if j >= 5:
+                             break
+                         print(f'\t{prob:.2%} -> {sma}')
+         ret = {
+             'fg_distr': fg_distr,
+             'e3_fg_distr': e3_fg_distr,
+             'poi_fg_distr': poi_fg_distr,
+             'poi_fg_2_substr': poi_fg_2_substr,
+             'e3_fg_2_substr': e3_fg_2_substr,
+             'substr_fg_2_linker': substr_fg_2_linker,
+             'side_fg_distr': side_fg_distr,
+         }
+
+     if filename_distributions is not None:
+         # Save to JSON file
+         distributions = {k: v for k, v in ret.items() if 'distr' in k}
+         with open(filename_distributions, 'w') as f:
+             json.dump(distributions, f, indent=4)
+         print(f'Functional group distributions saved to: {filename_distributions}')
+
+     if filename_mappings is not None:
+         # Convert sets to lists to make the data serializable
+         fg_mappings = {k: {sk: list(s) for sk, s in v.items()} for k, v in ret.items() if 'distr' not in k}
+
+         with open(filename_mappings, 'w') as f:
+             json.dump(fg_mappings, f, indent=4)
+         print(f'Functional group mappings saved to: {filename_mappings}')
+
+     df_with_functional_groups = pd.DataFrame(df_with_functional_groups)
+     ret['dataframe'] = df_with_functional_groups
+
+     if filename_df_with_functional_groups is not None:
+         df_with_functional_groups.to_csv(filename_df_with_functional_groups, index=False)
+         print(f'DataFrame with functional groups saved to: {filename_df_with_functional_groups}')
+
+     return ret
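
A usage sketch for the distribution extraction above, assuming `df` follows the column naming listed in the docstring; the JSON paths are placeholders:

    stats = get_functional_groups_distributions(
        df,
        filename_distributions='fg_distributions.json',
        filename_mappings='fg_mappings.json',
        load_from_file=False,  # force recomputation on the first run
    )
    # The distributions are sorted, so the first entries are the most common ones
    print(list(stats['fg_distr'].items())[:5])
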
protac_splitter/data/generation/generation.py ADDED
@@ -0,0 +1,277 @@
+ import os
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Dict, List, Optional
+
+ import pandas as pd
+ import numpy as np
+ from tqdm import tqdm
+ from rdkit import Chem
+
+ from protac_splitter.evaluation import check_reassembly
+
+
+ def generate_protacs(
+     poi_fg_distr: Dict[str, float],
+     e3_fg_distr: Dict[str, float],
+     substr_fg_2_linker: Dict[str, List[str]],
+     poi_fg_2_substr: Dict[str, List[str]],
+     e3_fg_2_substr: Dict[str, List[str]],
+     num_samples: int,
+     random_state: int = 42,
+     batch_size: int = 1000,
+     max_workers: int = 4,
+     original_df: Optional[pd.DataFrame] = None,
+     filename_generated_df: Optional[str] = None,
+     base_data_dir: Optional[str] = None,
+     cover_all_smiles: bool = False,
+ ) -> pd.DataFrame:
+     """ Generate PROTACs given the distributions of functional groups at attachment points.
+
+     Args:
+         poi_fg_distr: The distribution of functional groups at the POI attachment point.
+         e3_fg_distr: The distribution of functional groups at the E3 attachment point.
+         substr_fg_2_linker: The mapping of functional groups to linkers.
+         poi_fg_2_substr: The mapping of functional groups to POI substructures.
+         e3_fg_2_substr: The mapping of functional groups to E3 substructures.
+         num_samples: The number of PROTACs to generate.
+         random_state: The random state for reproducibility.
+         batch_size: The batch size for generating PROTACs.
+         max_workers: The maximum number of workers for the ThreadPoolExecutor.
+         original_df: The original DataFrame containing the PROTACs. Must have a
+             column named 'PROTAC SMILES' containing the strings to avoid
+             generating. The check is done on strings, so make sure to
+             canonize/standardize the SMILES strings.
+         filename_generated_df: The filename to save the generated PROTACs.
+         base_data_dir: The directory where intermediate per-batch CSV files are saved.
+         cover_all_smiles: Whether to keep sampling until every available POI, E3, and linker appears in at least one generated PROTAC.
+
+     Returns:
+         pd.DataFrame: The DataFrame containing the generated PROTACs.
+     """
+     np.random.seed(random_state)
+     final_df = pd.DataFrame()
+     total_batches = int(np.ceil(num_samples / batch_size))
+
+     def generate_protac_batch(batch_size: int, random_state: int) -> List[dict]:
+         np.random.seed(random_state)
+
+         # Sample functional groups for POI and E3
+         poi_fgs = np.random.choice(list(poi_fg_distr.keys()), size=batch_size, p=list(poi_fg_distr.values()))
+         e3_fgs = np.random.choice(list(e3_fg_distr.keys()), size=batch_size, p=list(e3_fg_distr.values()))
+
+         # Map functional groups to corresponding substructures
+         # NOTE: When the size argument is specified, the output is a numpy array.
+         # NOTE: If the functional group is not in the dictionary, the output is an empty numpy array.
+         poi_samples = [
+             np.random.choice(poi_fg_2_substr.get(fg, []), size=1 if fg in poi_fg_2_substr and poi_fg_2_substr[fg] else 0)
+             for fg in poi_fgs
+         ]
+         e3_samples = [
+             np.random.choice(e3_fg_2_substr.get(fg, []), size=1 if fg in e3_fg_2_substr and e3_fg_2_substr[fg] else 0)
+             for fg in e3_fgs
+         ]
+
+         generated_protacs = []
+
+         for poi_smiles, poi_fg, e3_smiles, e3_fg in zip(poi_samples, poi_fgs, e3_samples, e3_fgs):
+             # Check that poi_smiles and e3_smiles are not empty numpy arrays
+             if poi_smiles.size == 0 or e3_smiles.size == 0:
+                 continue
+
+             # Convert the numpy arrays to strings
+             poi_smiles, e3_smiles = poi_smiles[0], e3_smiles[0]
+
+             linkers = set(substr_fg_2_linker.get(poi_fg, [])) & set(substr_fg_2_linker.get(e3_fg, []))
+             if not linkers:
+                 continue
+
+             linker_smiles = np.random.choice(list(linkers))
+
+             # Get the PROTAC by combining the POI, linker, and E3
+             ligands_smiles = '.'.join([poi_smiles, linker_smiles, e3_smiles])
+             protac = Chem.MolFromSmiles(ligands_smiles)
+
+             if protac is None:
+                 continue
+             try:
+                 protac = Chem.molzip(protac)
+             except Exception:
+                 continue
+
+             # Sanitize molecule
+             try:
+                 zero_on_success = Chem.SanitizeMol(protac, catchErrors=True)
+                 if zero_on_success != 0:
+                     continue
+                 protac_smiles = Chem.MolToSmiles(protac, canonical=True)
+             except Exception:
+                 continue
+
+             if original_df is not None and protac_smiles in original_df['PROTAC SMILES'].values:
+                 continue
+
+             # Check if the PROTAC can be reassembled
+             if not check_reassembly(protac_smiles, ligands_smiles):
+                 continue
+
+             generated_protacs.append({
+                 'PROTAC SMILES': protac_smiles,
+                 'POI Ligand SMILES with direction': poi_smiles,
+                 'Linker SMILES with direction': linker_smiles,
+                 'E3 Binder SMILES with direction': e3_smiles,
+                 'POI Ligand Functional Group': poi_fg,
+                 'E3 Binder Functional Group': e3_fg,
+             })
+
+         return generated_protacs
+
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         futures = []
+         for i in tqdm(range(total_batches), desc="Generating Batches"):
+             futures.append(executor.submit(generate_protac_batch, batch_size, random_state + i))
+
+         for i, future in tqdm(enumerate(futures), desc="Processing Results", total=total_batches):
+             generated_batch = future.result()
+             if generated_batch:
+                 batch_df = pd.DataFrame(generated_batch)
+                 final_df = pd.concat([final_df, batch_df]).drop_duplicates()
+                 if i % 100 == 0:
+                     if base_data_dir:
+                         batch_df.to_csv(os.path.join(base_data_dir, f'generated_protacs_batch={i}.csv'), index=False)
+                     else:
+                         batch_df.to_csv(f'generated_protacs_batch={i}.csv', index=False)
+                     if filename_generated_df:
+                         final_df.to_csv(filename_generated_df, index=False)
+             if len(final_df) >= num_samples:
+                 break
+
+     if not final_df.empty:
+         generated_pois = set(final_df['POI Ligand SMILES with direction'].unique())
+         generated_e3s = set(final_df['E3 Binder SMILES with direction'].unique())
+         generated_linkers = set(final_df['Linker SMILES with direction'].unique())
+     else:
+         generated_pois = set()
+         generated_e3s = set()
+         generated_linkers = set()
+
+     # Check how well we covered the available substructures
+     avail_pois = set()
+     avail_e3s = set()
+     avail_linkers = set()
+     for fg in poi_fg_2_substr:
+         avail_pois.update(set(poi_fg_2_substr[fg]))
+     for fg in e3_fg_2_substr:
+         avail_e3s.update(set(e3_fg_2_substr[fg]))
+     for fg in substr_fg_2_linker:
+         avail_linkers.update(set(substr_fg_2_linker[fg]))
+
+     e3_coverage = len(generated_e3s) / len(avail_e3s)
+     poi_coverage = len(generated_pois) / len(avail_pois)
+     linker_coverage = len(generated_linkers) / len(avail_linkers)
+
+     print(f"POI coverage: {poi_coverage:.3%}")
+     print(f"E3 coverage: {e3_coverage:.3%}")
+     print(f"Linker coverage: {linker_coverage:.3%}")
+
+     # Get the "leftover" ligands
+     leftover_pois = avail_pois - generated_pois
+     leftover_e3s = avail_e3s - generated_e3s
+     leftover_linkers = avail_linkers - generated_linkers
+
+     covering_df = []
+
+     with tqdm(total=len(leftover_pois) + len(leftover_e3s) + len(leftover_linkers), desc="Covering Leftover Ligands") as pbar:
+         while True:
+             if not cover_all_smiles:
+                 break
+
+             # Randomly select a POI, E3, and linker, preferring the leftover ones
+             if not leftover_pois:
+                 pois_to_sample = avail_pois
+             else:
+                 pois_to_sample = leftover_pois
+             if not leftover_e3s:
+                 e3s_to_sample = avail_e3s
+             else:
+                 e3s_to_sample = leftover_e3s
+             if not leftover_linkers:
+                 linkers_to_sample = avail_linkers
+             else:
+                 linkers_to_sample = leftover_linkers
+
+             poi_smiles = np.random.choice(list(pois_to_sample))
+             e3_smiles = np.random.choice(list(e3s_to_sample))
+             linker_smiles = np.random.choice(list(linkers_to_sample))
+
+             # Get the PROTAC by combining the POI, linker, and E3
+             ligands_smiles = '.'.join([poi_smiles, linker_smiles, e3_smiles])
+             protac = Chem.MolFromSmiles(ligands_smiles)
+             if protac is None:
+                 continue
+             try:
+                 protac = Chem.molzip(protac)
+             except Exception:
+                 continue
+
+             # Sanitize molecule
+             try:
+                 zero_on_success = Chem.SanitizeMol(protac, catchErrors=True)
+                 if zero_on_success != 0:
+                     continue
+                 protac_smiles = Chem.MolToSmiles(protac, canonical=True)
+             except Exception:
+                 continue
+
+             if original_df is not None and protac_smiles in original_df['PROTAC SMILES'].values:
+                 continue
+
+             # Check if the PROTAC can be reassembled
+             if not check_reassembly(protac_smiles, ligands_smiles):
+                 continue
+
+             covering_df.append({
+                 'PROTAC SMILES': protac_smiles,
+                 'POI Ligand SMILES with direction': poi_smiles,
+                 'Linker SMILES with direction': linker_smiles,
+                 'E3 Binder SMILES with direction': e3_smiles,
+                 'POI Ligand Functional Group': None,
+                 'E3 Binder Functional Group': None,
+             })
+
+             generated_pois.add(poi_smiles)
+             generated_e3s.add(e3_smiles)
+             generated_linkers.add(linker_smiles)
+
+             ligands_added = 0
+             if poi_smiles in leftover_pois:
+                 leftover_pois.remove(poi_smiles)
+                 ligands_added += 1
+             if e3_smiles in leftover_e3s:
+                 leftover_e3s.remove(e3_smiles)
+                 ligands_added += 1
+             if linker_smiles in leftover_linkers:
+                 leftover_linkers.remove(linker_smiles)
+                 ligands_added += 1
+
+             e3_coverage = len(generated_e3s) / len(avail_e3s)
+             poi_coverage = len(generated_pois) / len(avail_pois)
+             linker_coverage = len(generated_linkers) / len(avail_linkers)
+
+             # Update the pbar and write the coverage
+             pbar.update(ligands_added)
+             pbar.set_postfix({
+                 'POI': f"{poi_coverage:.2%}",
+                 'E3': f"{e3_coverage:.2%}",
+                 'Linker': f"{linker_coverage:.2%}",
+             })
+
+             if not leftover_pois and not leftover_e3s and not leftover_linkers:
+                 break
+
+     final_df = pd.concat([final_df, pd.DataFrame(covering_df)]).drop_duplicates()
+
+     # Save to file if specified
+     if filename_generated_df:
+         final_df.to_csv(filename_generated_df, index=False)
+         print(f"Generated PROTACs saved to: {filename_generated_df}")
+
+     return final_df
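
A sketch of wiring the two steps together, assuming `stats` holds the output of `get_functional_groups_distributions` with list-valued mappings:

    generated = generate_protacs(
        poi_fg_distr=stats['poi_fg_distr'],
        e3_fg_distr=stats['e3_fg_distr'],
        substr_fg_2_linker=stats['substr_fg_2_linker'],
        poi_fg_2_substr=stats['poi_fg_2_substr'],
        e3_fg_2_substr=stats['e3_fg_2_substr'],
        num_samples=10_000,
        original_df=df,  # skip PROTACs already present in the source data
    )
    print(len(generated), generated['PROTAC SMILES'].iloc[0])
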
protac_splitter/display_utils.py ADDED
@@ -0,0 +1,199 @@
+ import os
+ import sys
+ from typing import Optional
+
+ from rdkit import Chem
+ from rdkit.Chem import Draw, rdDepictor
+ from rdkit.Geometry import Point2D
+
+ if 'ipykernel' in sys.modules:
+     from IPython.display import SVG
+
+ from .chemoinformatics import get_atom_idx_at_attachment, canonize
+
+
+ def safe_display(*args):
+     """Displays content only if running in a Jupyter notebook."""
+     if 'ipykernel' in sys.modules:
+         from IPython.display import display
+         display(*args)
+     else:
+         print(*args)
+
+
+ def display_mol(
+     mol: Chem.Mol,
+     w: int = 800,
+     h: int = 300,
+     legend: Optional[str] = None,
+     use_smiles_as_legend: bool = True,
+     display_svg: bool = True,
+ ):
+     """ Display a molecule in a Jupyter notebook. Useful for having the canonical SMILES shown as the image legend. """
+     if mol is None:
+         print('Molecule is None')
+         return None
+     if use_smiles_as_legend and legend is None:
+         legend = Chem.MolToSmiles(mol)
+     if display_svg:
+         mol.SetProp("_Name", Chem.MolToSmiles(mol, canonical=True))
+         d = Draw.rdMolDraw2D.MolDraw2DSVG(w, h, noFreetype=True)
+         font_path = '/System/Library/Fonts/Supplemental/Arial.ttf'
+         if os.path.exists(font_path):
+             d.drawOptions().fontFile = font_path
+         d.DrawMolecule(mol, legend=legend)
+         d.FinishDrawing()
+         svg = d.GetDrawingText()
+         # Check if in a Jupyter notebook
+         if sys.modules.get('ipykernel', None):
+             from IPython.display import SVG
+             safe_display(SVG(svg))
+     else:
+         img = Draw.MolToImage(mol, size=(w, h))
+         safe_display(img)
+
+
+ def get_mapped_protac_img(
+     protac_smiles: str,
+     poi_smiles: str,
+     linker_smiles: str,
+     e3_smiles: str,
+     w: int = 1000,
+     h: int = 1000,
+     useSVG: bool = False,
+     display_image: bool = False,
+     legend: Optional[str] = None,
+     show_bond_indices: bool = False,
+ ):
+     """ Display a PROTAC molecule with the POI, linker, and E3 ligase attachment bonds highlighted.
+
+     If `useSVG` is True, then the POI-Linker bond is highlighted in purple, whereas the E3-Linker bond is highlighted in green.
+     If `useSVG` is False, then both splitting points are highlighted in purple.
+
+     Args:
+         protac_smiles: The SMILES string of the PROTAC.
+         poi_smiles: The SMILES string of the POI.
+         linker_smiles: The SMILES string of the linker.
+         e3_smiles: The SMILES string of the E3 ligase.
+         w: The width of the image.
+         h: The height of the image.
+         useSVG: Whether to use SVG format.
+         display_image: Whether to display the image.
+         legend: The legend to display.
+         show_bond_indices: Whether to show bond indices in the image.
+     """
+     protac_smiles = canonize(protac_smiles)
+     e3_smiles = canonize(e3_smiles)
+     poi_smiles = canonize(poi_smiles)
+     linker_smiles = canonize(linker_smiles)
+
+     # Check if any of the canonicalized SMILES is None
+     if None in [protac_smiles, e3_smiles, poi_smiles, linker_smiles]:
+         return None
+
+     protac_mol = Chem.MolFromSmiles(protac_smiles)
+     e3_mol = Chem.MolFromSmiles(e3_smiles)
+     poi_mol = Chem.MolFromSmiles(poi_smiles)
+     linker_mol = Chem.MolFromSmiles(linker_smiles)
+
+     if None in [protac_mol, e3_mol, poi_mol, linker_mol]:
+         return None
+
+     if linker_smiles in ['[*:1][*:2]', '[*:2][*:1]']:
+         print('WARNING. Linker is empty.')
+         poi_attachment_idx = get_atom_idx_at_attachment(protac_mol, poi_mol, e3_mol)
+         e3_attachment_idx = get_atom_idx_at_attachment(protac_mol, e3_mol, poi_mol)
+     else:
+         poi_attachment_idx = get_atom_idx_at_attachment(protac_mol, poi_mol, linker_mol)
+         e3_attachment_idx = get_atom_idx_at_attachment(protac_mol, e3_mol, linker_mol)
+
+     cyan = (0, 1, 1, 0.5)
+     red = (1, 0, 0, 0.5)
+     green = (0, 1, 0, 0.5)
+     blue = (0, 0, 1, 0.5)
+     purple = (1, 0, 1, 0.3)
+
+     highlight_atoms = []
+     highlight_bonds = []
+     atom_colors = {}
+     bond_colors = {}
+
+     if poi_attachment_idx is not None:
+         if len(poi_attachment_idx) != 2:
+             if linker_smiles in ['[*:1][*:2]', '[*:2][*:1]']:
+                 print('WARNING. Linker is empty, no highlighting will be shown for the POI.')
+             else:
+                 print(f'WARNING. POI attachment points must be only two, got instead: {poi_attachment_idx}')
+         else:
+             poi_bond_idx = protac_mol.GetBondBetweenAtoms(*poi_attachment_idx).GetIdx()
+             highlight_atoms += poi_attachment_idx
+             highlight_bonds.append(poi_bond_idx)
+             atom_colors[poi_attachment_idx[0]] = purple
+             atom_colors[poi_attachment_idx[1]] = purple
+             bond_colors[poi_bond_idx] = purple
+
+     if e3_attachment_idx is not None:
+         if len(e3_attachment_idx) != 2:
+             if linker_smiles in ['[*:1][*:2]', '[*:2][*:1]']:
+                 print('WARNING. Linker is empty, no highlighting will be shown for the E3.')
+             else:
+                 print(f'WARNING. E3 attachment points must be only two, got instead: {e3_attachment_idx}')
+         else:
+             e3_bond_idx = protac_mol.GetBondBetweenAtoms(*e3_attachment_idx).GetIdx()
+             highlight_atoms += e3_attachment_idx
+             highlight_bonds.append(e3_bond_idx)
+             atom_colors[e3_attachment_idx[0]] = green
+             atom_colors[e3_attachment_idx[1]] = green
+             bond_colors[e3_bond_idx] = green
+
+     if useSVG:
+         drawer = Draw.rdMolDraw2D.MolDraw2DSVG(w, h, noFreetype=True)
+         options = drawer.drawOptions()
+         font_path = '/System/Library/Fonts/Supplemental/Arial.ttf'
+         if os.path.exists(font_path):
+             options.fontFile = font_path
+
+         if legend is None:
+             # legend = '.'.join([e3_smiles, linker_smiles, poi_smiles])
+             legend = ""
+
+         drawer.DrawMolecule(
+             protac_mol,
+             legend=legend,
+             highlightAtoms=highlight_atoms,
+             highlightBonds=highlight_bonds,
+             highlightAtomColors=atom_colors,
+             highlightBondColors=bond_colors,
+         )
+
+         # Add bond indices as text in the center of each bond
+         if show_bond_indices:
+             # Needs coordinates; ensure 2D coords are present
+             rdDepictor.Compute2DCoords(protac_mol)
+             for bond in protac_mol.GetBonds():
+                 idx = bond.GetIdx()
+                 begin = bond.GetBeginAtomIdx()
+                 end = bond.GetEndAtomIdx()
+                 begin_pos = drawer.GetDrawCoords(begin)
+                 end_pos = drawer.GetDrawCoords(end)
+                 mid_y = (begin_pos.y + end_pos.y) / 2
+                 mid_x = (begin_pos.x + end_pos.x) / 2
+                 drawer.DrawString(f"{idx}", Point2D(mid_x, mid_y), rawCoords=True)
+
+         drawer.FinishDrawing()
+         svg_text = drawer.GetDrawingText()
+
+         if display_image:
+             safe_display(SVG(svg_text))
+
+         return svg_text
+     else:
+         img = Draw.MolToImage(
+             protac_mol,
+             size=(w, h),
+             highlightColor=purple,
+             highlightAtoms=highlight_atoms,
+             highlightBonds=highlight_bonds,
+             highlightAtomColors=atom_colors,
+             highlightBondColors=bond_colors,
+         )
+
+         if display_image:
+             safe_display(img)
+
+         return img
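
A rendering sketch for the helper above; the four SMILES are assumed to share consistent '[*:1]'/'[*:2]' attachment points:

    svg = get_mapped_protac_img(
        protac_smiles=protac_smiles,
        poi_smiles=poi_smiles,
        linker_smiles=linker_smiles,
        e3_smiles=e3_smiles,
        useSVG=True,
    )
    if svg is not None:
        with open('protac_highlighted.svg', 'w') as f:
            f.write(svg)
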
protac_splitter/drawing_utils.py ADDED
@@ -0,0 +1,177 @@
+ import re
+
+ import numpy as np
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem import (
+     AllChem,
+     Draw,
+     rdFMCS,
+     rdMolAlign,
+ )
+ from rdkit.Geometry import Point3D
+
+
+ def save_as_svg(svg_content, filename, num_mols):
+     """Save SVG content (an IPython SVG object) to a file, widening the canvas to fit num_mols molecules."""
+     with open(filename, 'w') as file:
+         data = str(svg_content.data)
+         # NOTE: assumes the drawing was created with a default width of 1500
+         data = data.replace('1500', str(500 * num_mols))
+         file.write(data)
+
+
+ def align_molecules_2D(ref_mol, to_align_mol):
+     AllChem.Compute2DCoords(ref_mol)
+     AllChem.Compute2DCoords(to_align_mol)
+     # Find the maximum common substructure and use it to align the molecules
+     mcs = rdFMCS.FindMCS([ref_mol, to_align_mol])
+     mcs_mol = Chem.MolFromSmarts(mcs.smartsString)
+     ref_match = ref_mol.GetSubstructMatch(mcs_mol)
+     align_match = to_align_mol.GetSubstructMatch(mcs_mol)
+     atom_map = list(zip(align_match, ref_match))
+     rdMolAlign.AlignMol(to_align_mol, ref_mol, atomMap=atom_map)
+     return to_align_mol
+
+
+ def align_molecules_by_coordinates(ref_mol, to_align_mol):
+     # Find the maximum common substructure
+     AllChem.Compute2DCoords(to_align_mol)
+     mcs = rdFMCS.FindMCS([ref_mol, to_align_mol])
+     mcs_mol = Chem.MolFromSmarts(mcs.smartsString)
+     ref_match = ref_mol.GetSubstructMatch(mcs_mol)
+     align_match = to_align_mol.GetSubstructMatch(mcs_mol)
+
+     # Copy the coordinates from the reference molecule to the molecule to be aligned
+     ref_conf = ref_mol.GetConformer()
+     align_conf = to_align_mol.GetConformer()
+     for ref_idx, align_idx in zip(ref_match, align_match):
+         ref_pos = ref_conf.GetAtomPosition(ref_idx)
+         align_conf.SetAtomPosition(align_idx, ref_pos)
+
+     return to_align_mol
+
+
+ def draw_molecule_to_svg(mol, size=(500, 500), scale=1.0):
+     drawer = Draw.rdMolDraw2D.MolDraw2DSVG(size[0], size[1])
+     drawer.drawOptions().fixedBondLength = scale
+     drawer.DrawMolecule(mol)
+     drawer.FinishDrawing()
+     svg = drawer.GetDrawingText()
+     svg = re.sub(r'\<\?xml.*?\?\>', '', svg)  # Remove the XML declaration
+     svg = svg.replace('<svg', '<g').replace('</svg>', '</g>')  # Replace svg tags with g tags
+     return svg
+
+
+ def combine_svgs(svgs, output_filename, dimensions=None, size=(500, 500), xy_shifts=None):
+     if dimensions is None:
+         dimensions = (len(svgs), 1)
+     if xy_shifts is None:
+         xy_shifts = [(0, 0) for i in range(dimensions[0] * dimensions[1])]
+
+     width, height = size
+     grid_width, grid_height = dimensions
+     # Include only one XML declaration and the opening <svg> tag
+     combined_svg = '<?xml version="1.0" standalone="no"?>\n'
+     combined_svg += f'<svg width="{grid_width * width}px" height="{grid_height * height}px" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">\n'
+
+     # Arrange the SVGs in a grid
+     for i, (svg, xy_shift) in enumerate(zip(svgs, xy_shifts)):
+         x = (i % grid_width) * width
+         y = (i // grid_width) * height
+         combined_svg += f'<g transform="translate({x + xy_shift[0]},{y - xy_shift[1]})">{svg}</g>\n'
+
+     combined_svg += '</svg>'
+     with open(output_filename, 'w') as file:
+         file.write(combined_svg)
+
+
+ def draw_molecule_with_highlighted_bonds(mol, bonds_to_highlight):
+     """
+     Draws a molecule with the specified bonds highlighted and returns the SVG text.
+
+     Parameters:
+     - mol (Chem.Mol): The molecule to draw.
+     - bonds_to_highlight (list): List of bond indices to highlight.
+     """
+     # Initialize drawer
+     d2d = Draw.rdMolDraw2D.MolDraw2DSVG(350 * 2, 300 * 2)
+
+     # Set drawing options
+     d2d.drawOptions().useBWAtomPalette()
+     d2d.drawOptions().continuousHighlight = False
+     d2d.drawOptions().highlightBondWidthMultiplier = 24
+     d2d.drawOptions().setHighlightColour((0, 0, 1))
+     d2d.drawOptions().fillHighlights = False
+
+     # Draw the molecule with highlights
+     d2d.DrawMolecule(mol,
+                      highlightAtoms=[],
+                      highlightBonds=bonds_to_highlight)
+     d2d.FinishDrawing()
+
+     # Convert the drawing to SVG text
+     svg = d2d.GetDrawingText()
+     svg = svg.replace('svg:', '')
+
+     return svg
+
+
+ def align_mol_2D_ver2(template, query):
+     mcs = rdFMCS.FindMCS([template, query])
+     patt = Chem.MolFromSmarts(mcs.smartsString)
+
+     query_match = query.GetSubstructMatch(patt)
+     template_match = template.GetSubstructMatch(patt)
+
+     rms = AllChem.AlignMol(query, template, atomMap=list(zip(query_match, template_match)))
+     return template, query
+
+
+ def transform_molecule(mol, degrees, translate_x=0, translate_y=0, flip_x_axis=False):
+     """Apply rotation, translation, and optionally flip the molecule."""
+     radians = np.deg2rad(degrees)
+     rotation_matrix = np.array([
+         [np.cos(radians), -np.sin(radians), 0],
+         [np.sin(radians), np.cos(radians), 0],
+         [0, 0, 1]
+     ])
+     AllChem.Compute2DCoords(mol)
+
+     conf = mol.GetConformer()
+     for i in range(conf.GetNumAtoms()):
+         pos = np.array(conf.GetAtomPosition(i))
+         new_pos = np.dot(rotation_matrix, pos)
+         new_pos[0] += translate_x  # Translate along the x-axis
+         new_pos[1] += translate_y  # Translate along the y-axis
+         if flip_x_axis:
+             new_pos[1] = -new_pos[1]  # Flip along the x-axis
+         # Convert back to a Point3D, since the conformer does not accept numpy arrays
+         conf.SetAtomPosition(i, Point3D(float(new_pos[0]), float(new_pos[1]), float(new_pos[2])))
+
+
+ def tailored_framework_example(mol_ms):
+     # Build a generic framework: remove lone atoms, set all atoms to dummies,
+     # and set all bonds to single bonds.
+     mol_ms_w = Chem.RWMol(mol_ms)
+     atom_idx_to_remove = []
+     for atom in mol_ms_w.GetAtoms():
+         # Lone atom: need to remove it to create the generic framework
+         if atom.GetDegree() == 1:
+             atom_idx_to_remove.append(atom.GetIdx())
+             continue
+         atom.SetAtomicNum(0)
+
+     for bond in mol_ms_w.GetBonds():
+         bond.SetBondType(Chem.rdchem.BondType.SINGLE)
+
+     # Remove atoms from the highest index down, so that earlier indices stay valid
+     atom_idx_to_remove.sort(reverse=True)
+     for atom_idx in atom_idx_to_remove:
+         mol_ms_w.RemoveAtom(atom_idx)
+
+     mol_ms_new = mol_ms_w.GetMol()
+     return mol_ms_new
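
A sketch of chaining the drawing helpers above into a 2x1 grid (the molecules are placeholders):

    from rdkit import Chem

    mols = [Chem.MolFromSmiles(s) for s in ('c1ccccc1O', 'c1ccccc1N')]
    svgs = [draw_molecule_to_svg(m, size=(500, 500)) for m in mols]
    combine_svgs(svgs, 'grid.svg', dimensions=(2, 1), size=(500, 500))
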
protac_splitter/evaluation.py ADDED
@@ -0,0 +1,495 @@
+ """ Evaluation functions for the protac_splitter package. They need to be generic to accomodate predictions coming from different models. """
2
+
3
+ import math
4
+ import re
5
+ import logging
6
+ from typing import Tuple, Any, Dict, Optional, Union
7
+
8
+ import numpy as np
9
+ from rdkit import Chem, RDLogger
10
+ from rdkit.Chem import DataStructs
11
+
12
+ # Disable RDKit logging: when checking SMILES validity, we suppress warnings
13
+ RDLogger.DisableLog("rdApp.*")
14
+
15
+ from .chemoinformatics import (
16
+ canonize,
17
+ canonize_smiles,
18
+ remove_stereo,
19
+ get_substr_match,
20
+ )
21
+ from .protac_cheminformatics import reassemble_protac
22
+ from .graphs_utils import (
23
+ get_smiles2graph_edit_distance,
24
+ get_smiles2graph_edit_distance_norm,
25
+ )
26
+
27
+
28
+ def is_valid_smiles(
29
+ smiles: Optional[str],
30
+ return_mol: bool = False,
31
+ ) -> Union[bool, Tuple[bool, Chem.Mol]]:
32
+ """ Check if a SMILES is valid, i.e., it can be parsed by RDKit.
33
+
34
+ Args:
35
+ smiles (Optional[str]): The SMILES to check.
36
+ return_mol (bool): If True, return the RDKit molecule object, i.e., `(is_valid, mol)`.
37
+
38
+ Returns:
39
+ bool | Tuple[bool, Chem.Mol]: True if the SMILES is valid, False otherwise. If return_mol is True, also return the RDKit molecule object.
40
+ """
41
+ if smiles is None:
42
+ return False
43
+ mol = Chem.MolFromSmiles(smiles)
44
+ if return_mol:
45
+ return mol is not None, mol
46
+ return mol is not None
47
+
48
+
49
+ def has_three_substructures(smiles: Optional[str]) -> bool:
50
+ """ Check if a PROTAC SMILES has three substructures. """
51
+ if smiles is None:
52
+ return False
53
+ return smiles.count(".") == 2
54
+
55
+
56
+ def has_all_attachment_points(smiles: Optional[str]) -> bool:
57
+ """ Check if a PROTAC SMILES has all attachment points, i.e., [*:1] and [*:2], two each. """
58
+ if smiles is None:
59
+ return False
60
+ return smiles.count("[*:1]") == 2 and smiles.count("[*:2]") == 2
61
+
62
+
63
+ def split_prediction(
64
+ pred: str,
65
+ poi_attachment_id: int = 1,
66
+ e3_attachment_id: int = 2,
67
+ ) -> Optional[dict[str, str]]:
68
+ """ Split a PROTAC SMILES prediction into its three substructures.
69
+
70
+ Args:
71
+ pred (str): The SMILES of the PROTAC molecule.
72
+ poi_attachment_id (int): The attachment point ID for the POI substructure.
73
+ e3_attachment_id (int): The attachment point ID for the E3 substructure.
74
+
75
+ Returns:
76
+ dict[str, str] | None: A dictionary (with keys: 'e3', 'linker', 'poi') containing the SMILES notations for the POI, linker, and E3 substructures, or None if the prediction is invalid
77
+ """
78
+ ret = {k: None for k in ['poi', 'linker', 'e3']}
79
+ if pred is None:
80
+ return ret
81
+
82
+ ligands = pred.split('.')
83
+ if len(ligands) != 3:
84
+ return ret
85
+
86
+ for ligand in ligands:
87
+ if f'[*:{poi_attachment_id}]' in ligand and f'[*:{e3_attachment_id}]' not in ligand:
88
+ ret['poi'] = ligand
89
+ elif f'[*:{e3_attachment_id}]' in ligand and f'[*:{poi_attachment_id}]' not in ligand:
90
+ ret['e3'] = ligand
91
+ elif f'[*:{poi_attachment_id}]' in ligand and f'[*:{e3_attachment_id}]' in ligand:
92
+ ret['linker'] = ligand
93
+ return ret
94
+
95
+
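
A concrete example of the splitting convention above, using toy fragments:

    pred = 'c1ccccc1[*:1].[*:1]CCOCC[*:2].[*:2]C1CCNCC1'
    parts = split_prediction(pred)
    # parts['poi']    -> 'c1ccccc1[*:1]'     (only [*:1])
    # parts['linker'] -> '[*:1]CCOCC[*:2]'   (both attachment points)
    # parts['e3']     -> '[*:2]C1CCNCC1'     (only [*:2])
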
+ def rename_attachment_id(mol: Union[str, Chem.Mol], old_id: int, new_id: int) -> Union[str, Chem.Mol]:
+     """ Rename an attachment point ID in a molecule.
+
+     Args:
+         mol: The input molecule, given as a SMILES string or an RDKit molecule.
+         old_id: The old attachment point ID.
+         new_id: The new attachment point ID.
+
+     Returns:
+         The renamed molecule, of the same type as the input. None if canonicalization failed.
+     """
+     # Keep track of the input type, so that the same type can be returned
+     return_str = isinstance(mol, str)
+     if isinstance(mol, Chem.Mol):
+         mol = Chem.MolToSmiles(mol, canonical=True)
+     # Regex-replace the patterns "[*:old_id]" or "[old_id*]" with "[*:new_id]"
+     mol = re.sub(rf'\[\*:{old_id}\]', f'[*:{new_id}]', mol)
+     mol = re.sub(rf'\[{old_id}\*\]', f'[*:{new_id}]', mol)
+     mol = canonize_smiles(mol)
+     if mol is None:
+         return None
+     if return_str:
+         return mol
+     return Chem.MolFromSmiles(mol)
+
+
+ def at_least_two_ligands_correct(
+     protac_smiles: str,
+     ligands_smiles: str,
+ ) -> bool:
+     """ Check if at least two ligands are correct.
+
+     NOTE: Currently a stub: it only verifies that the prediction contains
+     more than one fragment; the per-ligand comparison is not implemented yet.
+     """
+     # Check if there is at least one "." in the ligands SMILES
+     if "." not in ligands_smiles:
+         return False
+     return True
+
+
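
A small example of the renaming helper above (string in, string out, up to canonicalization):

    print(rename_attachment_id('[*:2]C1CCNCC1', old_id=2, new_id=3))  # e.g. '[*:3]C1CCNCC1'
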
+ def check_reassembly(
+     protac_smiles: str,
+     ligands_smiles: str,
+     stats: Optional[Dict[str, int]] = None,
+     linker_can_be_null: bool = False,
+     poi_attachment_id: int = 1,
+     e3_attachment_id: int = 2,
+     verbose: int = 0,
+     return_reassembled_smiles: bool = False,
+ ) -> bool:
+     """Check if the reassembled PROTAC matches the original PROTAC SMILES.
+
+     Args:
+         protac_smiles (str): The original PROTAC SMILES.
+         ligands_smiles (str): The SMILES of the joined PROTAC ligands, separated by a "." (dot).
+         stats (Optional[Dict[str, int]]): A dictionary to store statistics about the reassembly process.
+         linker_can_be_null (bool): If False, the linker cannot be empty, and if so, a None will be returned. If True, a special check is performed to rename the E3 and warhead attachment points in order to assemble them together.
+         poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]". Default is 1.
+         e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]". Default is 2.
+         verbose (int): The verbosity level.
+         return_reassembled_smiles (bool): If True, also return the reassembled SMILES, i.e., `(is_equal, reassembled_smiles)`.
+
+     Returns:
+         bool: True if the reassembled PROTAC matches the original PROTAC SMILES, False otherwise.
+     """
+     ligands_smiles = canonize_smiles(ligands_smiles)
+     if ligands_smiles is None:
+         if verbose:
+             logging.error('Ligands could not be canonicalized.')
+         return (False, None) if return_reassembled_smiles else False
+
+     null_linker_e3 = f'[*:{e3_attachment_id}][*:{poi_attachment_id}]'
+     null_linker_poi = f'[*:{poi_attachment_id}][*:{e3_attachment_id}]'
+     linker_is_null = False
+     if null_linker_e3 in ligands_smiles or null_linker_poi in ligands_smiles:
+         # If the linker is empty, remove the linker atoms
+         ligands_smiles = ligands_smiles.replace(null_linker_poi, '')
+         ligands_smiles = ligands_smiles.replace(null_linker_e3, '')
+         ligands_smiles = ligands_smiles.replace('..', '.')
+         ligands_smiles = ligands_smiles.rstrip('.')
+         ligands_smiles = ligands_smiles.lstrip('.')
+         ligands_smiles = canonize_smiles(ligands_smiles)
+         if ligands_smiles is None:
+             return (False, None) if return_reassembled_smiles else False
+         linker_is_null = True
+
+     if linker_can_be_null or linker_is_null:
+         if len(ligands_smiles.split('.')) == 2:
+             # Replace the attachment points with a third one (the two ligands will be joined directly)
+             ligands_smiles = rename_attachment_id(ligands_smiles, e3_attachment_id, max([poi_attachment_id, e3_attachment_id]) + 1)
+             ligands_smiles = rename_attachment_id(ligands_smiles, poi_attachment_id, max([poi_attachment_id, e3_attachment_id]) + 1)
+
+     ligands_mol = Chem.MolFromSmiles(ligands_smiles)
+     if ligands_mol is None:
+         if verbose:
+             logging.error('ligands_mol is None')
+         return (False, None) if return_reassembled_smiles else False
+
+     try:
+         reassembled_mol = Chem.molzip(ligands_mol)
+         if reassembled_mol is None:
+             if stats is not None:
+                 stats['molzip failed'] += 1
+             if verbose:
+                 logging.error('molzip failed')
+             return (False, None) if return_reassembled_smiles else False
+     except Exception:
+         if stats is not None:
+             stats['molzip failed (exception)'] += 1
+         if verbose:
+             logging.error('molzip failed (exception)')
+         return (False, None) if return_reassembled_smiles else False
+
+     try:
+         reassembled_smiles = canonize(Chem.MolToSmiles(reassembled_mol))
+         if reassembled_smiles is None:
+             if stats is not None:
+                 stats['MolToSmiles of reassembled failed'] += 1
+             if verbose:
+                 logging.error('MolToSmiles of reassembled failed')
+             return (False, None) if return_reassembled_smiles else False
+     except Exception:
+         if stats is not None:
+             stats['MolToSmiles of reassembled failed'] += 1
+         if verbose:
+             logging.error('MolToSmiles of reassembled failed')
+         return (False, None) if return_reassembled_smiles else False
+
+     is_equal = canonize(protac_smiles) == reassembled_smiles
+
+     return (is_equal, reassembled_smiles) if return_reassembled_smiles else is_equal
+
+
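
A round-trip sketch for the check above, reusing the toy fragments from the `split_prediction` example:

    protac = 'c1ccccc1CCOCCC1CCNCC1'
    ligands = 'c1ccccc1[*:1].[*:1]CCOCC[*:2].[*:2]C1CCNCC1'
    assert check_reassembly(protac, ligands)  # molzip rebuilds the original PROTAC
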
+ def check_substructs(
+     protac_smiles: str,
+     poi_smiles: str = None,
+     linker_smiles: str = None,
+     e3_smiles: str = None,
+     return_bond_types: bool = False,
+     poi_attachment_id: int = 1,
+     e3_attachment_id: int = 2,
+     pred: str = None,
+ ) -> Union[bool, Tuple[bool, dict[str, str]]]:
+     """ DEPRECATED.
+
+     Check if the reassembled PROTAC is correct.
+
+     Args:
+         protac_smiles (str): The SMILES of the PROTAC molecule.
+         poi_smiles (str): The SMILES of the POI ligand.
+         linker_smiles (str): The SMILES of the linker.
+         e3_smiles (str): The SMILES of the E3 binder.
+         return_bond_types (bool): If True, return the bond types used for the reassembly.
+         poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]".
+         e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]".
+         pred (str): The SMILES of the predicted PROTAC molecule.
+
+     Returns:
+         bool | Tuple[bool, dict[str, str]]: True if the reassembled PROTAC is correct, False otherwise. If return_bond_types is True, also return the bond types used for the reassembly.
+     """
+     def get_failed_return():
+         if return_bond_types:
+             return False, {}
+         return False
+
+     # Make some checks before starting and fail if necessary
+     all_subs_none = all(v is None for v in [poi_smiles, linker_smiles, e3_smiles])
+     any_subs_none = any(v is None for v in [poi_smiles, linker_smiles, e3_smiles])
+
+     if pred is not None and all_subs_none:
+         # Split the prediction into the substructures
+         pred_substructs = split_prediction(pred, poi_attachment_id, e3_attachment_id)
+         if any(v is None for v in pred_substructs.values()):
+             return get_failed_return()
+         poi_smiles = pred_substructs['poi']
+         linker_smiles = pred_substructs['linker']
+         e3_smiles = pred_substructs['e3']
+     elif pred is None and any_subs_none:
+         return get_failed_return()
+     elif pred is None and all_subs_none:
+         logging.warning("Arguments 'pred' and 'poi_smiles', 'linker_smiles', 'e3_smiles' cannot all be None.")
+         return get_failed_return()
+
+     if f"[*:{poi_attachment_id}]" in e3_smiles:
+         return get_failed_return()
+     if f"[*:{e3_attachment_id}]" in poi_smiles:
+         return get_failed_return()
+     if f"[*:{poi_attachment_id}]" not in linker_smiles:
+         return get_failed_return()
+     if f"[*:{e3_attachment_id}]" not in linker_smiles:
+         return get_failed_return()
+
+     correct_substructs = False
+     protac_mol = Chem.MolFromSmiles(protac_smiles)
+     protac_inchi = Chem.MolToInchi(protac_mol)
+     protac_smiles_canon = canonize_smiles(protac_smiles)
+     bond_types = {}
+     bonds = ['single', 'double', 'triple']
+     # for e3_bond_type, poi_bond_type in itertools.product(bonds, bonds):
+     for e3_bond_type in bonds:
+         for poi_bond_type in bonds:
+             try:
+                 assmbl_smiles, assmbl_mol = reassemble_protac(
+                     poi_smiles,
+                     linker_smiles,
+                     e3_smiles,
+                     e3_bond_type,
+                     poi_bond_type,
+                     poi_attachment_id,
+                     e3_attachment_id,
+                 )
+                 if assmbl_mol is not None:
+                     # If either the InChI or the SMILES of the reassembled
+                     # PROTAC is the same as the original PROTAC, then the
+                     # reassembly is correct.
+                     if protac_inchi == Chem.MolToInchi(assmbl_mol):
+                         correct_substructs = True
+                         bond_types['e3_bond_type'] = e3_bond_type
+                         bond_types['poi_bond_type'] = poi_bond_type
+                         break
+                     if protac_smiles_canon == canonize_smiles(assmbl_smiles):
+                         correct_substructs = True
+                         bond_types['e3_bond_type'] = e3_bond_type
+                         bond_types['poi_bond_type'] = poi_bond_type
+                         break
+             except Exception:
+                 continue
+         # Also stop the outer loop once a matching bond-type pair was found
+         if correct_substructs:
+             break
+     if return_bond_types:
+         return correct_substructs, bond_types
+     return correct_substructs
+
+
323
+ def score_prediction(
324
+ protac_smiles: str,
325
+ label_smiles: str,
326
+ pred_smiles: str,
327
+ rouge = None,
328
+ poi_attachment_id: int = 1,
329
+ e3_attachment_id: int = 2,
330
+ fpgen = Chem.rdFingerprintGenerator.GetMorganGenerator(radius=11, fpSize=2048),
331
+ compute_rdkit_metrics: bool = False,
332
+ compute_graph_metrics: bool = False,
333
+ graph_edit_kwargs: Dict[str, Any] = {},
334
+ ) -> dict[str, float]:
335
+ """ Score a PROTAC SMILES prediction.
336
+
337
+ Args:
338
+ protac_smiles (str): The SMILES of the PROTAC molecule.
339
+ label_smiles (str): The SMILES of the ground truth PROTAC molecule.
340
+ pred_smiles (str): The SMILES of the predicted PROTAC molecule.
341
+ rouge (Rouge | None): The Rouge object to use for scoring. If None, do not compute Rouge scores. Example: `rouge = evaluate.load("rouge")`
342
+ poi_attachment_id (int): The attachment point ID for the POI substructure.
343
+ e3_attachment_id (int): The attachment point ID for the E3 substructure.
344
+
345
+ Returns:
346
+ dict[str, float]: A dictionary containing the scores for the prediction
347
+ """
348
+ protac_mol = Chem.MolFromSmiles(protac_smiles)
349
+ protac_num_atoms = protac_mol.GetNumHeavyAtoms()
350
+
351
+ scores = {
352
+ 'has_three_substructures': has_three_substructures(pred_smiles),
353
+ 'has_all_attachment_points': has_all_attachment_points(pred_smiles),
354
+ 'num_fragments': 0 if pred_smiles is None else pred_smiles.count('.') + 1,
355
+ 'tanimoto_similarity': 0.0, # Default value
356
+ 'valid': False,
357
+ 'reassembly': False,
358
+ 'reassembly_nostereo': False,
359
+ 'heavy_atoms_difference': protac_num_atoms,
360
+ 'heavy_atoms_difference_norm': 1.0,
361
+ 'all_ligands_equal': False,
362
+ }
363
+
364
+ pred_substructs = split_prediction(pred_smiles, poi_attachment_id, e3_attachment_id)
365
+
366
+ # Compute metrics for the "entire" predicted PROTAC molecule
367
+ if None not in list(pred_substructs.values()):
368
+ e3_nostereo = remove_stereo(pred_substructs['e3'])
369
+ linker_nostereo = remove_stereo(pred_substructs['linker'])
370
+ poi_nostereo = remove_stereo(pred_substructs['poi'])
371
+ if None not in [e3_nostereo, linker_nostereo, poi_nostereo]:
372
+ pred_nostereo = f"{e3_nostereo}.{linker_nostereo}.{poi_nostereo}"
373
+ scores['reassembly_nostereo'] = check_reassembly(remove_stereo(protac_smiles), pred_nostereo)
374
+
375
+ scores['valid'] = is_valid_smiles(pred_smiles)
376
+ is_equal, reassembled_smiles = check_reassembly(protac_smiles, pred_smiles, return_reassembled_smiles=True)
377
+ scores['reassembly'] = is_equal
378
+
379
+ # Get the number of heavy atoms difference between the reassembled PROTAC and the ground truth PROTAC
380
+ if reassembled_smiles is not None:
381
+ reassembled_mol = Chem.MolFromSmiles(reassembled_smiles)
382
+ if reassembled_mol is not None:
383
+ scores['heavy_atoms_difference'] -= reassembled_mol.GetNumHeavyAtoms()
384
+ scores['heavy_atoms_difference_norm'] = scores['heavy_atoms_difference'] / protac_num_atoms
385
+
386
+ if scores['valid'] and compute_rdkit_metrics and fpgen is not None:
387
+ # Get Tanimoto similarity between the predicted PROTAC and the ground truth PROTAC
388
+ pred_mol = Chem.MolFromSmiles(pred_smiles)
389
+ label_mol = Chem.MolFromSmiles(label_smiles)
390
+ pred_fp = fpgen.GetFingerprint(pred_mol)
391
+ label_fp = fpgen.GetFingerprint(label_mol)
392
+ scores['tanimoto_similarity'] = DataStructs.TanimotoSimilarity(pred_fp, label_fp)
393
+
394
+ if rouge is not None:
395
+ rouge_output = rouge.compute(predictions=[pred_smiles], references=[label_smiles])
396
+ scores.update({k: v for k, v in rouge_output.items()})
397
+
398
+ # Compute metrics for each substructure
399
+ label_substructs = split_prediction(label_smiles, poi_attachment_id, e3_attachment_id)
400
+
401
+ # Set default values
402
+ for sub in ['e3', 'poi', 'linker']:
403
+ scores[f'{sub}_valid'] = False
404
+ scores[f'{sub}_equal'] = False
405
+ scores[f'{sub}_has_attachment_point(s)'] = False
406
+ scores[f'{sub}_tanimoto_similarity'] = 0.0
407
+
408
+ # NOTE: The graph edit distance can be very high and dependent on the
409
+ # graphs; when the molecule is not valid, it cannot be computed at all.
410
+ # Because of that, we instead default it to something very large, in case
411
+ # the eval metrics need to be summed.
412
+ scores[f'{sub}_graph_edit_distance'] = 1e64
413
+ scores[f'{sub}_graph_edit_distance_norm'] = 1.0
414
+ scores[f'{sub}_heavy_atoms_difference'] = 0
415
+ try:
416
+ scores[f'{sub}_heavy_atoms_difference'] = Chem.MolFromSmiles(label_substructs[sub]).GetNumHeavyAtoms()
417
+ except Exception:
418
+ logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")
419
+ scores[f'{sub}_heavy_atoms_difference_norm'] = 1.0
420
+
421
+ # Calculate metrics for each substructure
422
+ for sub in ['e3', 'poi', 'linker']:
423
+ # Skip if the predicted substructure is None from `split_prediction`
424
+ pred_sub = pred_substructs[sub]
425
+ label_sub = label_substructs[sub]
426
+ if pred_sub is None:
427
+ continue
428
+ if label_sub is None:
429
+ logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")
430
+ continue
431
+
432
+ # Check if the predicted substructure is a valid RDKit molecule
433
+ sub_valid, sub_mol = is_valid_smiles(pred_sub, return_mol=True)
434
+ scores[f'{sub}_valid'] = sub_valid
435
+
436
+ if sub_mol is None:
437
+ continue
438
+
439
+ # Check if the predicted substructure has the correct attachment point(s)
440
+ if sub == 'e3':
441
+ if f'[*:{e3_attachment_id}]' in pred_sub and f'[*:{poi_attachment_id}]' not in pred_sub:
442
+ scores[f'{sub}_has_attachment_point(s)'] = True
443
+ elif sub == 'poi':
444
+ if f'[*:{poi_attachment_id}]' in pred_sub and f'[*:{e3_attachment_id}]' not in pred_sub:
445
+ scores[f'{sub}_has_attachment_point(s)'] = True
446
+ elif sub == 'linker':
447
+ if f'[*:{poi_attachment_id}]' in pred_sub and f'[*:{e3_attachment_id}]' in pred_sub:
448
+ scores[f'{sub}_has_attachment_point(s)'] = True
449
+
450
+ # Check if the predicted substructure's canonical SMILES matches the ground truth substructure
451
+ if scores[f'{sub}_valid']:
452
+ # scores[f'{sub}_equal'] = Chem.MolToInchi(sub_mol) == Chem.MolToInchi(Chem.MolFromSmiles(label_sub))
453
+ canon_pred = canonize_smiles(pred_sub)
454
+ canon_label = canonize_smiles(label_sub)
455
+ scores[f'{sub}_equal'] = canon_pred == canon_label
456
+
457
+ # Compute graph-related metrics
458
+ if scores[f'{sub}_valid'] and compute_graph_metrics:
459
+ scores[f'{sub}_graph_edit_distance'] = get_smiles2graph_edit_distance(pred_sub, label_sub, **graph_edit_kwargs)
460
+ scores[f'{sub}_graph_edit_distance_norm'] = get_smiles2graph_edit_distance_norm(
461
+ smi1=pred_sub,
462
+ smi2=label_sub,
463
+ ged_G1_G2=scores[f'{sub}_graph_edit_distance'],
464
+ **graph_edit_kwargs,
465
+ )
466
+
467
+ # Get the number of heavy atoms difference between the predicted substructure and the ground truth substructure
468
+ if scores[f'{sub}_valid']:
469
+ pred_mol = Chem.MolFromSmiles(pred_sub)
470
+ label_mol = Chem.MolFromSmiles(label_sub)
471
+ if label_mol is None:
472
+ logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")
473
+ continue
474
+ scores[f'{sub}_heavy_atoms_difference'] -= pred_mol.GetNumHeavyAtoms()
475
+ scores[f'{sub}_heavy_atoms_difference_norm'] = scores[f'{sub}_heavy_atoms_difference'] / label_mol.GetNumHeavyAtoms()
476
+
477
+ # Get Tanimoto similarity b/w the predicted substructure and the ground truth
478
+ if scores[f'{sub}_valid'] and compute_rdkit_metrics:
479
+ pred_mol = Chem.MolFromSmiles(pred_sub)
480
+ label_mol = Chem.MolFromSmiles(label_sub)
481
+ if label_mol is None:
482
+ logging.warning(f"WARNING: {sub} substructure is None in the label: '{label_smiles}' - PROTAC: '{protac_smiles}'")
483
+ continue
484
+ pred_fp = fpgen.GetFingerprint(pred_mol)
485
+ label_fp = fpgen.GetFingerprint(label_mol)
486
+ scores[f'{sub}_tanimoto_similarity'] = DataStructs.TanimotoSimilarity(pred_fp, label_fp)
487
+
488
+ # Compute Rouge scores
489
+ if rouge is not None:
490
+ rouge_output = rouge.compute(predictions=[pred_sub], references=[label_sub])
491
+ scores.update({f'{sub}_{k}': v for k, v in rouge_output.items()})
492
+
493
+ scores['all_ligands_equal'] = all([scores[f'{sub}_equal'] for sub in ['e3', 'poi', 'linker']])
494
+
495
+ return scores
protac_splitter/fixing_functions.py ADDED
@@ -0,0 +1,355 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ from rdkit import Chem
5
+
6
+ from protac_splitter.chemoinformatics import (
7
+ canonize,
8
+ dummy2query,
9
+ remove_attach_atom,
10
+ remove_dummy_atoms,
11
+ )
12
+ from protac_splitter.evaluation import (
13
+ split_prediction,
14
+ check_reassembly,
15
+ )
16
+ from protac_splitter.data.curation.substructure_extraction import get_attachment_bonds
17
+
18
+ def fix_tetrahedral_centers_ligand(
19
+ protac_mol: Chem.Mol,
20
+ ligand_smiles: str,
21
+ attachment_id: int = 1,
22
+ ) -> Optional[str]:
23
+ """ Fixes the tetrahedral centers of a ligand in a PROTAC molecule.
24
+
25
+ Args:
26
+ protac_mol (Chem.Mol): The RDKit molecule object of the PROTAC.
27
+ ligand_smiles (str): The SMILES of the ligand to fix.
28
+ attachment_id (int): The attachment point id of the ligand. Default is 1.
29
+
30
+ Returns:
31
+ A string containing the fixed ligand SMILES, or None if the fixing process failed.
32
+ """
33
+ ligand_mol = Chem.MolFromSmiles(ligand_smiles)
34
+ if ligand_mol is None:
35
+ logging.error(f"Invalid ligand SMILES: {ligand_smiles}")
36
+ return None
37
+
38
+ ligand_mol = remove_dummy_atoms(ligand_mol)
39
+ ligand_match = protac_mol.GetSubstructMatch(ligand_mol, useChirality=False) # useChirality=True
40
+
41
+ # Get bonds to break to separate the ligand
42
+ bonds_to_break = get_attachment_bonds(protac_mol, ligand_match)
43
+
44
+ # Return if no bonds are found
45
+ if len(bonds_to_break) != 1:
46
+ logging.error('ERROR: Multiple attachment bonds')
47
+ return None
48
+
49
+ # Break the bonds to isolate the ligand
50
+ frag_ligand_mol = Chem.FragmentOnBonds(protac_mol, bonds_to_break, addDummies=True, dummyLabels=[(attachment_id, attachment_id)])
51
+
52
+ # Get the fragments resulting from bond breaking
53
+ try:
54
+ frags = Chem.GetMolFrags(frag_ligand_mol, asMols=True, sanitizeFrags=True)
55
+ except Exception as e:
56
+ logging.error(e)
57
+ return None
58
+
59
+ # Identify the ligand fragment
60
+ ligand_fragment = None
61
+ for frag in frags:
62
+ if frag.HasSubstructMatch(ligand_mol):
63
+ ligand_fragment = frag
64
+ break
65
+ if ligand_fragment is None:
66
+ logging.error('ERROR: Ligand fragment not found')
+ return None
67
+
68
+ ligand_fixed = Chem.MolToSmiles(ligand_fragment)
69
+ ligand_fixed = canonize(ligand_fixed.replace(f'[{attachment_id}*]', f'[*:{attachment_id}]'))
70
+ return ligand_fixed
71
+
72
+
73
+ def fix_prediction(
74
+ protac_smiles: str,
75
+ pred_smiles: str,
76
+ poi_attachment_id: int = 1,
77
+ e3_attachment_id: int = 2,
78
+ remove_stereochemistry: bool = False,
79
+ verbose: int = 0,
80
+ ) -> Optional[str]:
81
+ """ Fixes a prediction by replacing the substructure that does not match the PROTAC with the rest of the PROTAC.
82
+
83
+ Args:
84
+ protac_smiles (str): The SMILES of the PROTAC.
85
+ pred_smiles (str): The SMILES of the prediction.
86
+ poi_attachment_id (int): The attachment point id of the POI. Default is 1.
87
+ e3_attachment_id (int): The attachment point id of the E3 ligase. Default is 2.
88
+ remove_stereochemistry (bool): If True, also attempt the fix after removing stereochemistry. Default is False.
+ verbose (int): The verbosity level. Default is 0.
89
+
90
+ Returns:
91
+ A string containing the fixed prediction, or None if the fixing process failed.
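+
+ Example (illustrative sketch; assumes valid input SMILES):
+ >>> fixed = fix_prediction(protac_smiles, pred_smiles)
+ >>> if fixed is not None:
+ ... assert check_reassembly(protac_smiles, fixed)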
92
+ """
93
+ protac_mol = Chem.MolFromSmiles(protac_smiles)
94
+ if protac_mol is None:
95
+ logging.warning(f"Invalid PROTAC SMILES: {protac_smiles}")
96
+ return None
97
+
98
+ substructs = split_prediction(pred_smiles, poi_attachment_id, e3_attachment_id)
99
+
100
+ # If there are at least two None values, there's nothing we can do to fix it
101
+ if sum(v is None for v in substructs.values()) >= 2:
102
+ logging.warning(f'Unable to continue, more than two substructures are not valid for given input: "{pred_smiles}"')
103
+ return None
104
+
105
+ # Get molecules of PROTAC and substructures
106
+ substructs = {k: {'smiles': v, 'mol': Chem.MolFromSmiles(v) if v is not None else v} for k, v in substructs.items()}
107
+
108
+ # Check if renaming the attachment points might already fix the prediction
109
+ for sub in ['poi', 'e3', 'both']:
110
+ if sub == 'e3':
111
+ if substructs['e3']['smiles'] is None:
112
+ continue
113
+ e3_attempt = substructs['e3']['smiles'].replace(f'[*:{poi_attachment_id}]', f'[*:{e3_attachment_id}]')
114
+ poi_attempt = substructs['poi']['smiles']
115
+ elif sub == 'poi':
116
+ if substructs['poi']['smiles'] is None:
117
+ continue
118
+ e3_attempt = substructs['e3']['smiles']
119
+ poi_attempt = substructs['poi']['smiles'].replace(f'[*:{e3_attachment_id}]', f'[*:{poi_attachment_id}]')
120
+ else:
121
+ if substructs['e3']['smiles'] is None or substructs['poi']['smiles'] is None:
122
+ continue
123
+ e3_attempt = substructs['e3']['smiles'].replace(f'[*:{e3_attachment_id}]', f'[*:{poi_attachment_id}]')
124
+ poi_attempt = substructs['poi']['smiles'].replace(f'[*:{poi_attachment_id}]', f'[*:{e3_attachment_id}]')
125
+
126
+ protac_attempt = f"{e3_attempt}.{substructs['linker']['smiles']}.{poi_attempt}"
127
+ if check_reassembly(protac_smiles, protac_attempt):
128
+ logging.info(f'Input works when renaming attachment points in {sub.title()} substruct. SMILES: "{protac_attempt}"')
129
+ return protac_attempt
130
+
131
+ # Check if swapping the POI and E3 attachments in the linker might already fix the prediction
132
+ if substructs['linker']['smiles'] is None:
133
+ continue
134
+ linker_attempt = substructs['linker']['smiles']
135
+ linker_attempt = linker_attempt.replace(f'[*:{poi_attachment_id}]', f'[*:DUMMY]')
136
+ linker_attempt = linker_attempt.replace(f'[*:{e3_attachment_id}]', f'[*:{poi_attachment_id}]')
137
+ linker_attempt = linker_attempt.replace(f'[*:DUMMY]', f'[*:{e3_attachment_id}]')
138
+
139
+ # Try with the original POI and E3 substructures
140
+ protac_attempt = f"{substructs['e3']['smiles']}.{linker_attempt}.{substructs['poi']['smiles']}"
141
+ if check_reassembly(protac_smiles, protac_attempt):
142
+ logging.info(f'Input works when swapping POI and E3 attachment points in the linker. Fixed SMILES: "{protac_attempt}"')
143
+ return protac_attempt
144
+
145
+ # Try with the swapped POI and E3 substructures
146
+ protac_attempt = f"{e3_attempt}.{linker_attempt}.{poi_attempt}"
147
+ if check_reassembly(protac_smiles, protac_attempt):
148
+ logging.info(f'Input works when swapping POI and E3 attachment points in the linker and in {sub.title()} substruct. Fixed SMILES: "{protac_attempt}"')
149
+ return protac_attempt
150
+
151
+ # Check if removing stereochemistry results in a valid prediction
152
+ if remove_stereochemistry:
153
+ Chem.RemoveStereochemistry(protac_mol)
154
+ protac_smiles = Chem.MolToSmiles(protac_mol, canonical=True)
155
+ for k, v in substructs.items():
156
+ if v['mol'] is not None:
157
+ Chem.RemoveStereochemistry(v['mol'])
158
+ substructs[k]['smiles'] = Chem.MolToSmiles(v['mol'], canonical=True)
159
+
160
+ if all(v['mol'] is not None for v in substructs.values()):
161
+ if check_reassembly(
162
+ protac_smiles,
163
+ '.'.join([v['smiles'] for v in substructs.values()]),
164
+ ):
165
+ logging.info(f'Input works when removing stereochemistry. SMILES: "{pred_smiles}"')
166
+ return f"{substructs['e3']['smiles']}.{substructs['linker']['smiles']}.{substructs['poi']['smiles']}"
167
+
168
+ # Check if any of the substructures is NOT a substructure of the PROTAC, if
169
+ # so, we mark it as the wrong substructure to fix.
170
+ num_matches = 0
171
+ wrong_substruct = None
172
+ for sub in ['poi', 'linker', 'e3']:
173
+ if substructs[sub]['mol'] is None:
174
+ substructs[sub]['match'] = False
175
+ wrong_substruct = sub
176
+ elif protac_mol.HasSubstructMatch(dummy2query(substructs[sub]['mol'])):
177
+ substructs[sub]['match'] = True
178
+ num_matches += 1
179
+ else:
180
+ substructs[sub]['match'] = False
181
+ wrong_substruct = sub
182
+
183
+ if num_matches < 2:
184
+ logging.warning(f'Prediction does not contain at least two matching substructures of the PROTAC. Num matches: {num_matches}. Prediction SMILES: "{pred_smiles}"')
185
+ return None
186
+
187
+ # If the wrong substructure still matches somewhere in the PROTAC, we need a
188
+ # more complex approach to fix the prediction (see below).
189
+ def remove_substructure(mol, substructure, attachment_id, replaceDummies=False):
190
+ if mol is None or substructure is None:
191
+ return None
192
+ smaller_mol = Chem.ReplaceCore(
193
+ mol,
194
+ substructure,
195
+ labelByIndex=False,
196
+ replaceDummies=replaceDummies,
197
+ )
198
+ if smaller_mol is None:
199
+ logging.warning(f'Failed to remove substructure from prediction SMILES: "{pred_smiles}"')
200
+ return None
201
+ smaller_smiles = Chem.MolToSmiles(smaller_mol, canonical=True)
202
+ smaller_smiles = smaller_smiles.replace('[1*]', f'[*:{attachment_id}]')
203
+ smaller_smiles = smaller_smiles.replace('[2*]', f'[*:{attachment_id}]')
204
+ smaller_mol = canonize(Chem.MolFromSmiles(smaller_smiles))
205
+ return smaller_mol
206
+
207
+ # If we still have 3 matches: for each substructure, we progressively remove
208
+ # the other substructures, then we check if the resulting molecule is valid
209
+ # and has only one fragment.
210
+ if num_matches == 3:
211
+ wrong_substruct = None
212
+ for sub in ['poi', 'linker', 'e3']:
213
+ removed_mol = Chem.MolFromSmiles(protac_smiles)
214
+
215
+ # Put the current substructure at the end of the list [poi, e3, linker]
216
+ sub_names = ['poi', 'e3', 'linker']
217
+ sub_names.remove(sub)
218
+ sub_names.append(sub)
219
+ # The linker often matches in many parts of the PROTAC, so we remove
220
+ # it when checking the E3 and POI substructures.
221
+ if sub != 'linker':
222
+ sub_names.remove('linker')
223
+
224
+ for s in sub_names:
225
+ attachment_id = poi_attachment_id if s == 'poi' else e3_attachment_id
226
+ removed_mol = remove_substructure(
227
+ removed_mol,
228
+ dummy2query(substructs[s]['mol']),
229
+ attachment_id=attachment_id,
230
+ )
231
+
232
+ # Check if resulting molecule is None, if so, it is the wrong one
233
+ if removed_mol is None:
234
+ substructs[sub]['match'] = False
235
+ wrong_substruct = sub
236
+ num_matches -= 1
237
+ break
238
+
239
+ # Count the number of fragments in the removed molecule
240
+ fragments = Chem.GetMolFrags(removed_mol, asMols=True, sanitizeFrags=False)
241
+ if len(fragments) > 1:
242
+ substructs[sub]['match'] = False
243
+ wrong_substruct = sub
244
+ num_matches -= 1
245
+ break
246
+
247
+ if num_matches == 3:
248
+ logging.warning(f'Prediction already contains all matching substructures of the PROTAC. Prediction SMILES: "{pred_smiles}"')
249
+ return None
250
+
251
+ # Get the order in which to remove the substructures and get the final one
252
+ # as the fixed molecule.
253
+ if wrong_substruct == 'linker':
254
+ poi_atoms = substructs['poi']['mol'].GetNumAtoms()
255
+ e3_atoms = substructs['e3']['mol'].GetNumAtoms()
256
+ order = ['poi', 'e3'] if poi_atoms > e3_atoms else ['e3', 'poi']
257
+ else:
258
+ if wrong_substruct == 'poi':
259
+ order = ['e3', 'linker']
260
+ else:
261
+ order = ['poi', 'linker']
262
+
263
+ logging.debug(f'Wrong substructure: {wrong_substruct.upper()}. Order: {order}')
264
+
265
+ fixed_mol = protac_mol
266
+ for sub in order:
267
+ logging.debug(f'Removing substructure {sub.upper()} from PROTAC.')
268
+
269
+ if 'linker' not in order:
270
+ fixed_attach_id = poi_attachment_id if sub == 'poi' else e3_attachment_id
271
+ else:
272
+ fixed_attach_id = poi_attachment_id if 'e3' in order else e3_attachment_id
273
+
274
+ if sub == 'linker':
275
+ attach_id = poi_attachment_id if wrong_substruct == 'poi' else e3_attachment_id
276
+ fixed_attach_id = poi_attachment_id if wrong_substruct == 'poi' else e3_attachment_id
277
+ query_mol = remove_attach_atom(substructs[sub]['mol'], attach_id)
278
+ replaceDummies = True
279
+ else:
280
+ query_mol = dummy2query(substructs[sub]['mol'])
281
+ replaceDummies = False
282
+
283
+ if verbose:
284
+ # display(Draw.MolToImage(fixed_mol, legend=f"Starting molecule", size=(800, 300)))
285
+ # display(Draw.MolToImage(query_mol, legend=f"Molecule {sub.upper()} to remove", size=(800, 300)))
286
+ pass
287
+
288
+ fixed_mol_tmp = remove_substructure(
289
+ fixed_mol,
290
+ query_mol,
291
+ attachment_id=fixed_attach_id,
292
+ replaceDummies=replaceDummies,
293
+ )
294
+ if fixed_mol_tmp is None:
295
+ logging.debug(f'Failed to replace substructure "{sub}" in prediction SMILES: "{pred_smiles}"')
296
+ continue
297
+
298
+ fixed_mol = fixed_mol_tmp
299
+
300
+ # If there are multiple fragments, keep the biggest one
301
+ fragments = Chem.GetMolFrags(fixed_mol, asMols=True)
302
+ if len(fragments) > 1:
303
+ logging.debug(f'Fixed molecule contains more than one fragment. Keeping the biggest one.')
304
+ max_frag = max(fragments, key=lambda x: x.GetNumAtoms())
305
+ fixed_mol = max_frag
306
+
307
+ # Get the SMILES of the fixed molecule
308
+ fixed_smiles = Chem.MolToSmiles(canonize(fixed_mol), canonical=True)
309
+ substructs[wrong_substruct]['smiles'] = fixed_smiles
310
+
311
+ if verbose:
312
+ # display(Draw.MolToImage(fixed_mol, legend=f"{wrong_substruct.upper()} fixed molecule: {fixed_smiles}", size=(800, 300)))
313
+ pass
314
+
315
+ # Concatenate the substructures check if the re-assembly is correct
316
+ fixed_pred_smiles = f"{substructs['e3']['smiles']}.{substructs['linker']['smiles']}.{substructs['poi']['smiles']}"
317
+
318
+ if not check_reassembly(
319
+ protac_smiles,
320
+ fixed_pred_smiles,
321
+ ):
322
+ # logging.warning(f"Failed to fix prediction, re-assembly check failed. Generated fixed prediction (failing): {fixed_pred_smiles}")
323
+ # return None
324
+
325
+ # Check if by flipping the tetrahedral centers of the ligands we can
326
+ # still fix the prediction.
327
+ protac_mol = canonize(Chem.MolFromSmiles(protac_smiles))
328
+ chiral_centers = Chem.FindMolChiralCenters(
329
+ protac_mol,
330
+ includeUnassigned=True,
331
+ useLegacyImplementation=False,
332
+ )
333
+ if not chiral_centers:
334
+ logging.warning(f"Failed to fix prediction, re-assembly check failed. Generated fixed prediction (failing): {fixed_pred_smiles}")
335
+ return None
336
+
337
+ # Attempt to fix the tetrahedral centers of the ligands
338
+ e3_fixed = fix_tetrahedral_centers_ligand(protac_mol, substructs['e3']['smiles'], attachment_id=e3_attachment_id)
339
+ poi_fixed = fix_tetrahedral_centers_ligand(protac_mol, substructs['poi']['smiles'], attachment_id=poi_attachment_id)
340
+ if e3_fixed is None or poi_fixed is None:
341
+ logging.warning(f"Failed to fix prediction, re-assembly check failed. Generated fixed prediction (failing): {fixed_pred_smiles}")
342
+ return None
343
+
344
+ # Update the substructures with the fixed ligands and check re-assembly
345
+ substructs['e3']['smiles'] = e3_fixed
346
+ substructs['poi']['smiles'] = poi_fixed
347
+ fixed_pred_smiles = f"{substructs['e3']['smiles']}.{substructs['linker']['smiles']}.{substructs['poi']['smiles']}"
348
+ if not check_reassembly(
349
+ protac_smiles,
350
+ fixed_pred_smiles,
351
+ ):
352
+ logging.warning(f"Failed to fix prediction, re-assembly check failed. Generated fixed prediction (failing): {fixed_pred_smiles}")
353
+ return None
354
+
355
+ return fixed_pred_smiles
protac_splitter/graphs/README.md ADDED
@@ -0,0 +1,114 @@
1
+ # Graph-Based PROTAC-Splitter
2
+
3
+ ## Heuristic Betweenness Centrality
4
+
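+ A usage sketch of the betweenness-centrality splitting heuristic. It assumes a `held_out_df` DataFrame with the standard SMILES columns, plus `representative_e3s_fp` and `morgan_fp_generator` prepared beforehand (e.g., via `get_representative_e3s_fp`); `nx_split` and the drawing helpers come from this package.
+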
5
+ ```python
6
+ for i in range(10):
7
+ # sample = held_out_df.sample(n=1, random_state=42 + i).iloc[0]
8
+ sample = held_out_df.iloc[i]
11
+ protac_smiles = sample['PROTAC SMILES']
12
+ wh_smiles = sample['POI Ligand SMILES with direction']
13
+ lk_smiles = sample['Linker SMILES with direction']
14
+ e3_smiles = sample['E3 Binder SMILES with direction']
15
+
16
+ protac = Chem.MolFromSmiles(protac_smiles)
17
+ wh = Chem.MolFromSmiles(wh_smiles)
18
+ lk = Chem.MolFromSmiles(lk_smiles)
19
+ e3 = Chem.MolFromSmiles(e3_smiles)
20
+
21
+ # display_mol(Chem.MolFromSmiles(protac_smiles), w=1500, h=600)
22
+ get_mapped_protac_img(protac_smiles, wh_smiles, lk_smiles, e3_smiles, w=1500, h=600, display_image=True, useSVG=False)
23
+ # wh_edge = get_atom_idx_at_attachment(protac, wh, lk)
24
+ # e3_edge = get_atom_idx_at_attachment(protac, e3, lk)
25
+
26
+ ret = nx_split(protac_smiles, representative_e3s_fp, morgan_fp_generator, use_capacity_weight=False, betweenness_threshold=0.4)
27
+ e3_smiles = ret['e3']
28
+ wh_smiles = ret['poi']
29
+ linker_smiles = ret['linker']
30
+ top_nodes = ret['top_nodes']
31
+ centrality = ret['centrality']
32
+
33
+ # display_mol(Chem.MolFromSmiles(e3_smiles), w=800, h=400, legend="E3")
34
+ # display_mol(Chem.MolFromSmiles(linker_smiles), w=800, h=400, legend="Linker")
35
+ # display_mol(Chem.MolFromSmiles(wh_smiles), w=800, h=400, legend="WH")
36
+
37
+ display_mol(Chem.MolFromSmiles('.'.join([wh_smiles, linker_smiles, e3_smiles])), w=800, h=400, legend="Graph-based split")
38
+
39
+
40
+ display(Draw.MolToImage(
41
+ protac,
42
+ size=(1500, 400),
43
+ highlightColor=(1, 0, 1, 0.3), # Light purple
44
+ highlightAtoms=top_nodes, # Highlight the top nodes
45
+ legend=f"Graph nodes: {top_nodes} (Betweenness centrality: {centrality[top_nodes[0]]:.3f})",
46
+ ))
47
+ ```
48
+
49
+
50
+ ## Graph Edge Classifier Example
51
+
52
+ Example of how to use the GraphEdgeClassifier to train a model on a dataset of PROTACs and their ligands, and then predict split edges in new PROTACs. The first snippet trains a binary (split / no-split) edge classifier; the second trains the multiclass variant (no split / WH-linker / E3-linker).
53
+
54
+ ```python
55
+ label_cols = [c for c in sets["train"].columns if c.startswith("label_")]
56
+ train_set = sets["train"].dropna(subset=label_cols)
57
+ train_set = train_set[(train_set["label_e3_split"] + train_set["label_wh_split"]) <= 1]
58
+ X_train = train_set.drop(columns=label_cols)
59
+
60
+ graph_features = [c for c in X_train.columns if c.startswith("graph_")]
61
+ # graph_features = [
62
+ # "graph_betweenness",
63
+ # "graph_degree",
64
+ # "graph_degree_r2",
65
+ # "graph_degree_r3",
66
+ # ]
67
+ categorical_features = ["chem_bond_type", "chem_atom_u", "chem_atom_v"]
68
+ fingerprint_features = [c for c in X_train.columns if c.startswith("chem_mol_fp_")]
69
+
70
+ # Instantiate and train
71
+ clf = GraphEdgeClassifier(
72
+ graph_features=graph_features,
73
+ categorical_features=categorical_features,
74
+ fingerprint_features=fingerprint_features,
75
+ use_descriptors=False,
76
+ use_fingerprints=False,
77
+ binary=True,
78
+ )
79
+ y_train = train_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(train_set)
80
+
81
+ clf.fit(X_train, y_train)
82
+ clf.save("../models/edge_classifier_bin.joblib")
83
+ print(f"Model saved to ../models/edge_classifier_bin.joblib")
84
+
85
+ label_cols = [c for c in sets["train"].columns if c.startswith("label_")]
86
+ train_set = sets["train"].dropna(subset=label_cols)
87
+ train_set = train_set[(train_set["label_e3_split"] + train_set["label_wh_split"]) <= 1]
88
+ X_train = train_set.drop(columns=label_cols)
89
+
90
+ graph_features = [c for c in X_train.columns if c.startswith("graph_")]
91
+ # graph_features = [
92
+ # "graph_betweenness",
93
+ # "graph_degree",
94
+ # "graph_degree_r2",
95
+ # "graph_degree_r3",
96
+ # ]
97
+ categorical_features = ["chem_bond_type", "chem_atom_u", "chem_atom_v"]
98
+ fingerprint_features = [c for c in X_train.columns if c.startswith("chem_mol_fp_")]
99
+
100
+ # Instantiate and train
101
+ clf = GraphEdgeClassifier(
102
+ graph_features=graph_features,
103
+ categorical_features=categorical_features,
104
+ fingerprint_features=fingerprint_features,
105
+ use_descriptors=False,
106
+ use_fingerprints=False,
107
+ binary=False,
108
+ )
109
+ y_train = train_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(train_set)
110
+
111
+ clf.fit(X_train, y_train)
112
+ clf.save("../models/edge_classifier.joblib")
113
+ print(f"Model saved to ../models/edge_classifier.joblib")
114
+ ```
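+
+ Once saved, the classifier can be reloaded for inference. The snippet below is an illustrative sketch: `predict_from_smiles` featurizes against the component SMILES as well, and returns candidate split-bond indices padded with `-1` when fewer than `top_n` split bonds are found. The `*_list` variables are placeholders for your own data.
+
+ ```python
+ clf = GraphEdgeClassifier.load("../models/edge_classifier_bin.joblib")
+
+ # Top-2 candidate split bonds per PROTAC (binary model)
+ bond_idxs = clf.predict_from_smiles(
+     protac_smiles_list,
+     wh_smiles_list,
+     lk_smiles_list,
+     e3_smiles_list,
+     top_n=2,
+ )
+ ```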
protac_splitter/graphs/__init__.py ADDED
File without changes
protac_splitter/graphs/e3_clustering.py ADDED
@@ -0,0 +1,321 @@
1
+ from typing import List, Optional, Tuple, Any, Dict
2
+ import functools
3
+
4
+ import pandas as pd
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from sklearn.cluster import AgglomerativeClustering, KMeans
8
+ from scipy.stats import skew
9
+ from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
10
+
11
+ from rdkit import Chem, DataStructs
12
+ from rdkit.Chem import rdFingerprintGenerator
13
+
14
+ from protac_splitter.graphs.utils import get_fp, numpy_to_rdkit_fp
15
+ from protac_splitter.chemoinformatics import remove_dummy_atoms
16
+
17
+
18
+ def get_umap_clusters_fp(fp_list: List[np.ndarray], n_clusters: int = 7) -> np.ndarray:
19
+ """
20
+ Cluster a list of molecular fingerprints using agglomerative clustering.
21
+ From Scaffold Splits Overestimate Virtual Screening Performance
22
+ https://arxiv.org/abs/2406.00873
23
+
24
+ Args:
25
+ fp_list (List[np.ndarray]): List of fingerprint vectors (e.g., numpy arrays).
26
+ n_clusters (int): The number of clusters to use for clustering.
27
+
28
+ Returns:
29
+ np.ndarray: Array of cluster labels corresponding to each fingerprint in the input list.
30
+ """
31
+ ac = AgglomerativeClustering(n_clusters=n_clusters)
32
+ ac.fit_predict(np.stack(fp_list))
33
+ return ac.labels_
34
+
35
+ def get_kmeans_clusters_fp(fp_list: List[np.ndarray], n_clusters: int = 10, return_centroids: bool = False) -> np.ndarray:
36
+ """
37
+ Cluster a list of molecular fingerprints using the KMeans clustering algorithm.
38
+
39
+ Args:
40
+ fp_list (List[np.ndarray]): List of fingerprint vectors.
41
+ n_clusters (int): The number of clusters to use for clustering.
42
+ return_centroids (bool): If True, also return the cluster centroids.
43
+
44
+ Returns:
45
+ np.ndarray: Array of cluster labels corresponding to each fingerprint. If return_centroids is True, a tuple of (labels, centroids) is returned instead.
46
+ """
47
+ km = KMeans(n_clusters=n_clusters, n_init='auto', random_state=42, max_iter=1000)
48
+ if return_centroids:
49
+ km.fit(np.stack(fp_list))
50
+ return km.labels_, km.cluster_centers_
51
+ return km.fit_predict(np.stack(fp_list))
52
+
53
+ def evaluate_clusters(X: np.ndarray, clusters: np.ndarray) -> Dict[str, float]:
54
+ """ Compute clustering metrics and assess cluster size distribution.
55
+
56
+ Args:
57
+ X (np.array): The input data used for clustering.
58
+ clusters (np.ndarray): The cluster labels for each data point in X.
59
+
60
+ Returns:
61
+ Dict[str, float]: A dictionary containing various clustering metrics:
62
+ - silhouette: Silhouette score of the clustering.
63
+ - davies_bouldin: Davies-Bouldin index of the clustering.
64
+ - calinski_harabasz: Calinski-Harabasz index of the clustering.
65
+ - avg_cluster_size: Average size of clusters.
66
+ - avg_cluster_data_ratio: Ratio of average cluster size to total data size.
67
+ - std_cluster_size: Standard deviation of cluster sizes.
68
+ - min_cluster_size: Minimum size of clusters.
69
+ - median_cluster_size: Median size of clusters.
70
+ - max_cluster_size: Maximum size of clusters.
71
+ - cluster_size_skewness: Skewness of cluster sizes indicating imbalance.
72
+ - num_clusters: Number of unique clusters found.
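+
+ Example (illustrative):
+ >>> X = np.random.rand(100, 16)
+ >>> labels = KMeans(n_clusters=5, n_init='auto', random_state=0).fit_predict(X)
+ >>> metrics = evaluate_clusters(X, labels)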
73
+ """
74
+
75
+ unique_clusters = list(set(clusters))
76
+
77
+ if len(unique_clusters) < 2: # Avoid single-cluster issues
78
+ return {
79
+ "silhouette": -1,
80
+ "davies_bouldin": float("inf"),
81
+ "calinski_harabasz": -1,
82
+ "avg_cluster_size": len(X),
83
+ "avg_cluster_data_ratio": 1,
84
+ "std_cluster_size": 0,
85
+ "min_cluster_size": len(X),
86
+ "median_cluster_size": len(X),
87
+ "max_cluster_size": len(X),
88
+ "cluster_size_skewness": 0,
89
+ "num_clusters": 1,
90
+ }
91
+
92
+ # Compute standard clustering metrics
93
+ silhouette = silhouette_score(X, clusters)
94
+ davies_bouldin = davies_bouldin_score(X, clusters)
95
+ calinski_harabasz = calinski_harabasz_score(X, clusters)
96
+
97
+ # Compute cluster size statistics
98
+ cluster_sizes = [len(np.where(clusters == i)[0]) for i in np.unique(clusters)]
99
+ avg_cluster_size = np.mean(cluster_sizes)
100
+ avg_cluster_data_ratio = avg_cluster_size / len(X)
101
+ std_cluster_size = np.std(cluster_sizes)
102
+ median_cluster_size = np.median(cluster_sizes)
103
+ min_cluster_size = np.min(cluster_sizes)
104
+ max_cluster_size = np.max(cluster_sizes)
105
+ cluster_size_skewness = skew(cluster_sizes, nan_policy="omit") # Indicates imbalance in cluster sizes
106
+
107
+ return {
108
+ "silhouette": silhouette,
109
+ "davies_bouldin": davies_bouldin,
110
+ "calinski_harabasz": calinski_harabasz,
111
+ "avg_cluster_size": avg_cluster_size,
112
+ "avg_cluster_data_ratio": avg_cluster_data_ratio,
113
+ "std_cluster_size": std_cluster_size,
114
+ "min_cluster_size": min_cluster_size,
115
+ "median_cluster_size": median_cluster_size,
116
+ "max_cluster_size": max_cluster_size,
117
+ "cluster_size_skewness": cluster_size_skewness,
118
+ "num_clusters": len(unique_clusters),
119
+ }
120
+
121
+ def get_representative_e3s(
122
+ train_df: pd.DataFrame,
123
+ fp_generator: Optional[Any] = None,
124
+ n_clusters_candidates: List[int] = [10, 25, 50, 100, 150],
125
+ e3_column: str = 'E3 Binder SMILES with direction',
126
+ ) -> Tuple[List[str], List[Any], int, pd.DataFrame]:
127
+ """
128
+ Get representative E3 ligands from a DataFrame of training data by clustering their fingerprints.
129
+ This function computes Morgan fingerprints for unique E3 ligands, clusters them using KMeans and UMAP,
130
+ evaluates the clusters using silhouette, Davies-Bouldin, and Calinski-Harabasz scores, and identifies
131
+ the optimal number of clusters based on these metrics.
132
+ It returns the representative E3 ligands, their fingerprints, the best number of clusters, and a DataFrame
133
+ containing the clustering metrics.
134
+
135
+ Parameters:
136
+ train_df (pd.DataFrame): DataFrame containing training data with E3 ligands.
137
+ fp_generator (Optional[Any]): RDKit fingerprint generator. If None, a default Morgan fingerprint generator with 1024 bits and radius 16 is used.
138
+ n_clusters_candidates (List[int]): List of candidate numbers of clusters to evaluate.
139
+ e3_column (str): The column name in the DataFrame that contains the E3 ligand SMILES strings.
140
+
141
+ Returns:
142
+ Tuple[List[str], List[Any], int, pd.DataFrame]: A tuple containing:
143
+ - List of representative E3 ligand SMILES strings.
144
+ - List of RDKit fingerprints corresponding to the representative E3 ligands.
145
+ - The best number of clusters determined from the clustering metrics.
146
+ - DataFrame containing clustering metrics for each candidate number of clusters.
147
+ """
148
+ if e3_column not in train_df.columns:
149
+ raise ValueError(f"Column '{e3_column}' not found in the DataFrame.")
150
+
151
+ if fp_generator is None:
152
+ fp_generator = rdFingerprintGenerator.GetMorganGenerator(
153
+ radius=16,
154
+ fpSize=1024,
155
+ useBondTypes=True,
156
+ includeChirality=True,
157
+ )
158
+
159
+ fp_dict = {}
160
+ for smi in tqdm(train_df[e3_column].unique()):
161
+ fp = get_fp(remove_dummy_atoms(smi), fp_generator)
162
+ if fp is not None:
163
+ fp_dict[smi] = fp
164
+
165
+ fp_list = list(fp_dict.values())
166
+ fp2smiles = {fp.tobytes(): smi for smi, fp in fp_dict.items() if fp is not None}
167
+
168
+ centroids_dict = {}
169
+ clusters_dict = {}
170
+ metrics_df = []
171
+ for n_clusters in tqdm(n_clusters_candidates, desc="Clustering and evaluating"):
172
+ clusters, centroids = get_kmeans_clusters_fp(fp_list, n_clusters=n_clusters, return_centroids=True)
173
+ metrics = evaluate_clusters(fp_list, clusters)
174
+ clusters_dict[f'kmeans_n{n_clusters}'] = clusters.copy()
175
+ centroids_dict[n_clusters] = centroids.copy()
176
+
177
+ metrics['num_clusters'] = n_clusters
178
+ metrics['cluster_algorithm'] = 'kmeans'
179
+ metrics_df.append(metrics.copy())
180
+
181
+ clusters = get_umap_clusters_fp(fp_list, n_clusters=n_clusters)
182
+ metrics = evaluate_clusters(fp_list, clusters)
183
+ clusters_dict[f'umap_n{n_clusters}'] = clusters.copy()
184
+
185
+ metrics['num_clusters'] = n_clusters
186
+ metrics['cluster_algorithm'] = 'umap'
187
+ metrics_df.append(metrics.copy())
188
+
189
+ metrics_df = pd.DataFrame(metrics_df)
190
+
191
+ # Get the sweet spot for the number of clusters
192
+ # Flip davies_bouldin so that all metrics are to be maximized
193
+ metrics_df['-davies_bouldin'] = -metrics_df['davies_bouldin']
194
+
195
+ # Normalize all three metrics (by group if you want per algorithm)
196
+ metrics = ['silhouette', '-davies_bouldin', 'calinski_harabasz']
197
+ df_norm = metrics_df.copy()
198
+ df_norm[metrics] = df_norm.groupby('cluster_algorithm')[metrics].transform(
199
+ lambda x: (x - x.min()) / (x.max() - x.min())
200
+ )
201
+
202
+ # Measure divergence: standard deviation of normalized metrics per row
203
+ df_norm['metric_divergence'] = df_norm[metrics].std(axis=1)
204
+
205
+ # Pick the point with lowest divergence, possibly applying constraints (e.g. not too many clusters)
206
+ sweet_spots = df_norm.loc[df_norm.groupby('cluster_algorithm')['metric_divergence'].idxmin()]
207
+
208
+ best_n_clusters = sweet_spots['num_clusters'].unique()[0]
209
+
210
+ # Get the centroids of the clusters
211
+ centroids = centroids_dict[best_n_clusters]
212
+
213
+ # Get the cluster labels computed for the selected number of clusters
214
+ clusters = np.array(clusters_dict[f'kmeans_n{best_n_clusters}'])
215
+ representative_e3s = []
216
+ representative_e3s_fp = []
217
+ for label, centroid in enumerate(centroids):
218
+ # Isolate the FP with the same label as the centroid
219
+ fp_cluster = np.array(fp_list)[clusters == label]
220
+ # Get the closest FP for the centroid, use euclidean distance
221
+ distances = np.linalg.norm(fp_cluster - centroid, axis=1)
222
+ closest_fp = np.argmin(distances)
223
+ # To get the SMILES from the FP, use the fp2smiles dictionary
224
+ closest_smiles = fp2smiles[fp_cluster[closest_fp].tobytes()]
225
+ # Append the closest SMILES to the representative_e3s list
226
+ representative_e3s.append(closest_smiles)
227
+ representative_e3s_fp.append(fp_cluster[closest_fp])
228
+
229
+ # Convert the representative E3s to RDKit fingerprints
230
+ representative_e3s_fp = [numpy_to_rdkit_fp(fp) for fp in representative_e3s_fp]
231
+
232
+ return representative_e3s, representative_e3s_fp, best_n_clusters, metrics_df
233
+
234
+
235
+ DEFAULT_REPRESENTATIVE_E3S = [
236
+ 'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)CN[*:2])cc1',
237
+ 'O=C1CCC(N2Cc3c(N=[*:2])cccc3C2=O)C(=O)N1',
238
+ 'CC(=O)NC(C(=O)N1CC(O)CC1C(=O)[*:2])C(C)(C)C',
239
+ 'CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](Oc2ccccc2[*:2])C[C@H]1C(=O)N[C@@H]1CCCc2ccccc21)C1CCCCC1',
240
+ 'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(NC(=O)CCO[*:2])C(C)(C)C)cc1',
241
+ 'O=C1CCC(N2Cc3ccc([*:2])cc3C2=O)C(=O)N1',
242
+ 'COc1ccc(C2=N[C@@H](c3ccc(Cl)cc3)[C@@H](c3ccc(Cl)cc3)N2C(=O)N2CCN(CC(=O)[*:2])C(=O)C2)c(OC(C)C)c1',
243
+ 'CC(NC(=O)C1CC(O)CN1C(=O)C(N[*:2])C(C)(C)C)c1ccc(C2CC2)cc1',
244
+ 'CCOc1cc(C(C)(C)C)ccc1C1=NC(c2ccc(Cl)cc2)C(c2ccc(Cl)cc2)N1C(=O)N1CCN(CCCC[*:2])CC1',
245
+ 'CNC(C)C(=O)NC(C(=O)N1CCCC1c1cncc(C(=O)c2cccc([*:2])c2)c1)C1CCCCC1',
246
+ 'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1nc(C(=O)c2ccc([*:2])cc2)cs1)C1CCCCC1',
247
+ 'O=C1CCC(N2C(=O)c3cccc(OC[*:2])c3C2=O)C(=O)N1',
248
+ 'CCOc1cc(C(C)(C)C)ccc1C1=NC(c2ccc(Cl)cc2)C(c2ccc(Cl)cc2)N1C(=O)N1CCN([*:2])CC1',
249
+ 'Cc1ncsc1-c1ccc(CNC(=O)[C@H]2C[C@H](O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
250
+ 'Cc1ncsc1-c1ccc([C@H](C)NC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)cc1',
251
+ 'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1cncc(C(=O)c2cccc([*:2])c2)c1)C1CCCCC1',
252
+ 'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)c(OC2CCNCC2)c1',
253
+ 'CNC(C)C(=O)NC(C(=O)N1CC(Oc2ccc([*:2])cc2)CC1C(=O)NC1CCCc2ccccc21)C1CCCCC1',
254
+ 'C[C@H](NC(=O)[C@@H]1C[C@@H](O)CN1C(=O)[C@@H](N[*:2])C(C)(C)C)c1ccc(C(C)(C)C)cc1',
255
+ 'CNC(C)C(=O)NC(C(=O)N1CCCC1c1nc(C(=O)c2ccc([*:2])cc2)cs1)C1CCCCC1',
256
+ 'CC(=O)NC(C(=O)N1CC(O)CC1C(=O)NCc1ccc(-c2scnc2C)cc1[*:2])C(C)(C)C',
257
+ 'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](NC(=O)C2(F)CC2)C(C)(C)C)c([*:2])c1',
258
+ 'CCOc1cc(C(C)(C)C)ccc1C1=NC(C)(c2ccc(Cl)cc2)C(C)(c2ccc(Cl)cc2)N1C(=O)N1CCN(CC(=O)[*:2])CC1',
259
+ 'COc1ccc(C(=O)[*:2])cc1N1CCC(=O)NC1=O',
260
+ 'CN[C@@H](C)C(=O)N[C@H](C(=O)N[C@H]1C[C@H]2CC[C@@H]1N(CCc1ccc([*:2])cc1)C2)C1CCCCC1',
261
+ 'CNC(C)C(=O)NC(C(=O)N1CC(N[*:2])CC1C(=O)NC1CCCc2ccccc21)C1CCCCC1',
262
+ 'CN[C@@H](C)C(=O)N[C@@H](CCCCN[*:2])C(=O)N1CCC[C@H]1C(=O)Nc1snnc1-c1ccccc1',
263
+ 'CNC(C)C(=O)NC(C(=O)NC1CC2CCC1N(CCc1cccc([*:2])c1)C2)C1CCCCC1',
264
+ 'O=C1CCC(N2C(=O)c3ccc(N[*:2])cc3C2=O)C(=O)N1',
265
+ 'CNC(C)C(=O)NC(C(=O)N1CC(NC(=O)CC[*:2])CC1C(=O)Nc1c(F)cccc1F)C(C)(C)C',
266
+ 'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@H](N[*:2])C(C)(C)C)cc1',
267
+ 'Cc1nc[nH]c1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
268
+ 'Cc1ncsc1-c1ccc(C(C)NC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
269
+ 'Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O)[C@@H](N[*:2])C(C)(C)C)cc1',
270
+ 'O=C1CCC(c2cccc([*:2])c2)C(=O)N1',
271
+ 'CC(=O)N[C@H](C(=O)N1C[C@@H](O)C[C@@H]1C(=O)N[C@@H](CC(=O)N1CCC([*:2])CC1)c1ccccc1)C(C)C',
272
+ 'O=C(CCl)[*:2]',
273
+ 'CC[C@@H](NC(=O)[C@@H]1C[C@H](N[*:2])CN1C(=O)[C@@H](NC(=O)[C@H](C)NC)C(C)(C)C)c1ccccc1',
274
+ 'CN[C@H](C)C(=O)N[C@@H]1CCO[C@@H]2CC(C)(C)[C@H](C(=O)N[C@@H]3CCCc4cc([*:2])ccc43)N2C1=O',
275
+ 'CN[C@@H](C)C(=O)N[C@H](C(=O)N1CCC[C@H]1c1nc(C(=O)c2ccc(F)cc2)cs1)C1CCN(C[*:2])CC1',
276
+ 'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)cc1',
277
+ 'CNC(C)C(=O)NC(CCCCN[*:2])C(=O)N1CCCC1C(=O)Nc1snnc1-c1ccccc1',
278
+ 'O=C1CCC(N2C(=O)c3cccc([*:2])c3C2=O)C(=O)O1',
279
+ 'COc1ccc(C2=N[C@@H](c3ccc(Cl)cc3)[C@@H](c3ccc(Cl)cc3)N2C(=O)N2CCN(CC(=O)[*:2])C(=O)C2)cc1OC(C)C',
280
+ 'Cc1ncsc1-c1ccc(CNC(=O)C2CC(O)CN2C(=O)C(N[*:2])C(C)(C)C)c(OC2CCNCC2)c1',
281
+ 'CNC(C)C(=O)NC(C(=O)N1CCCC1c1cncc(-n2ccc3c(C(=O)[*:2])cccc32)c1)C(C)C',
282
+ 'CCN1CCN(Cc2ccc(NC(=O)c3cccc(-c4ccc5nc(N[*:2])sc5n4)c3)cc2C(F)(F)F)CC1',
283
+ 'CN[C@@H](C)C(=O)N[C@H](C(=O)N1C[C@@H](NC(=O)CC[*:2])C[C@H]1C(=O)Nc1c(F)cccc1F)C(C)(C)C',
284
+ 'CNC(C)C(=O)NC(C(=O)N1CCCC1C(=O)NC(C(=O)[*:2])C(c1ccccc1)c1ccccc1)C1CCCCC1',
285
+ 'CC(=O)NCC(C(=O)N1CC(O)CC1C(=O)NC(CC(=O)N1CCC(N2CCC([*:2])CC2)CC1)c1ccccc1)C(C)C',
286
+ ]
287
+
288
+
289
+ @functools.lru_cache(maxsize=1, typed=False)
290
+ def get_representative_e3s_fp(
291
+ e3_list: Optional[List[str]] = None,
292
+ fp_generator: Optional[Any] = None,
293
+ verbose: int = 0,
294
+ ) -> List[DataStructs.ExplicitBitVect]:
295
+ """
296
+ Generate Morgan fingerprints for a list of E3 ligands. If no list is provided,
297
+ it uses a default list of representative E3 ligands.
298
+
299
+ Parameters:
300
+ e3_list (Optional[List[str]]): SMILES strings for E3 ligands. If None, the built-in DEFAULT_REPRESENTATIVE_E3S list is used. Note that, because this function is wrapped in functools.lru_cache, the argument must be hashable (pass a tuple rather than a list, or None).
301
+ fp_generator (Optional[Any]): RDKit fingerprint generator. If None, a default Morgan fingerprint generator is used.
+ verbose (int): If > 0, show a tqdm progress bar while generating fingerprints.
302
+
303
+ Returns:
304
+ List[DataStructs.ExplicitBitVect]: List of RDKit Morgan fingerprints for the E3 ligands.
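+
+ Example (illustrative; uses the built-in default list and caches the result):
+ >>> fps = get_representative_e3s_fp()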
305
+ """
306
+ representative_e3s_fp = []
307
+ if verbose > 0:
308
+ iterable = tqdm(e3_list or DEFAULT_REPRESENTATIVE_E3S, desc="Generating fingerprints for E3 ligands")
309
+ else:
310
+ iterable = e3_list or DEFAULT_REPRESENTATIVE_E3S
311
+ for smi in iterable:
312
+ # Get the Morgan fingerprint for the SMILES string
313
+ fp = get_fp(remove_dummy_atoms(smi), fp_generator, return_np=False)
314
+ if fp is not None:
315
+ representative_e3s_fp.append(fp)
316
+ else:
317
+ print(f"Warning: Invalid SMILES string '{smi}' encountered, skipping.")
318
+ if not representative_e3s_fp:
319
+ raise ValueError("No valid E3 ligands found in the provided list.")
320
+ return representative_e3s_fp
321
+
protac_splitter/graphs/edge_classifier.py ADDED
@@ -0,0 +1,582 @@
1
+ import joblib
2
+ from pathlib import Path
3
+ from typing import Optional, List, Dict, Union, Any, Literal
4
+
5
+ import pandas as pd
6
+ import numpy as np
7
+ from sklearn.base import BaseEstimator, ClassifierMixin
8
+ from sklearn.compose import ColumnTransformer
9
+ from sklearn.preprocessing import StandardScaler, OneHotEncoder
10
+ from sklearn.decomposition import TruncatedSVD
11
+ from imblearn.over_sampling import SMOTE
12
+ from imblearn.pipeline import Pipeline as ImbPipeline
13
+ from sklearn.pipeline import Pipeline
14
+ from sklearn.metrics import classification_report
15
+ from sklearn.metrics import confusion_matrix
16
+ from xgboost import XGBClassifier
17
+ import optuna
18
+ from optuna.samplers import QMCSampler
19
+ from sklearn.metrics import accuracy_score, f1_score
20
+
21
+ try:
22
+ import seaborn as sns
23
+ import matplotlib.pyplot as plt
24
+ HAS_VISUALIZATION = True
25
+ except ImportError:
26
+ HAS_VISUALIZATION = False
27
+
28
+ from .edge_features import extract_edge_features, get_edge_features
29
+
30
+
31
+ class GraphEdgeClassifier(BaseEstimator, ClassifierMixin):
32
+ """
33
+ Edge-level graph classifier for PROTACs with integrated pipeline building.
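+
+ The underlying pipeline one-hot encodes the categorical bond/atom features,
+ optionally standard-scales the numeric graph and descriptor features,
+ optionally compresses fingerprint bits with TruncatedSVD, optionally
+ rebalances classes with SMOTE, and fits an XGBoost classifier (binary
+ split/no-split or multiclass no-split/WH-linker/E3-linker).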
34
+ """
35
+ def __init__(
36
+ self,
37
+ graph_features: List[str],
38
+ categorical_features: Optional[List[str]] = None,
39
+ descriptor_features: Optional[List[str]] = None,
40
+ fingerprint_features: Optional[List[str]] = None,
41
+ use_descriptors: bool = True,
42
+ use_fingerprints: bool = True,
43
+ scaler_graph: Literal["passthrough", "standard"] = "passthrough",
44
+ scaler_desc: Literal["passthrough", "standard"] = "passthrough",
45
+ use_svd_fp: bool = True,
46
+ n_svd_components: int = 100,
47
+ binary: bool = False,
48
+ smote_k_neighbors: Optional[int] = 5,
49
+ xgb_params: Optional[dict] = None,
50
+ n_bits: int = 512,
51
+ radius: int = 6,
52
+ descriptor_names: Optional[List[str]] = None
53
+ ):
54
+ self.graph_features = graph_features
55
+ self.categorical_features = categorical_features
56
+ self.descriptor_features = descriptor_features
57
+ self.fingerprint_features = fingerprint_features
58
+ self.use_descriptors = use_descriptors
59
+ self.use_fingerprints = use_fingerprints
60
+ self.scaler_graph = scaler_graph
61
+ self.scaler_desc = scaler_desc
62
+ self.use_svd_fp = use_svd_fp
63
+ self.n_svd_components = n_svd_components
64
+ self.binary = binary
65
+ self.smote_k_neighbors = smote_k_neighbors
66
+ self.xgb_params = xgb_params or {}
67
+ self.n_bits = n_bits
68
+ self.radius = radius
69
+ self.descriptor_names = descriptor_names or [
70
+ "MolWt", "HeavyAtomCount", "NumHAcceptors", "NumHDonors",
71
+ "TPSA", "NumRotatableBonds", "RingCount", "MolLogP"
72
+ ]
73
+ self.pipeline = self._build_pipeline()
74
+
75
+ def _build_pipeline(self):
76
+ transformers = []
77
+ if self.categorical_features:
78
+ transformers.append(("cat", OneHotEncoder(handle_unknown="ignore"), self.categorical_features))
79
+ if self.scaler_graph == "standard":
80
+ transformers.append(("num", StandardScaler(), self.graph_features))
81
+ else:
82
+ transformers.append(("num", "passthrough", self.graph_features))
83
+
84
+ if self.use_descriptors and self.descriptor_features:
85
+ desc_block = (
86
+ ("desc", StandardScaler(), self.descriptor_features)
87
+ if self.scaler_desc == "standard"
88
+ else ("desc", "passthrough", self.descriptor_features)
89
+ )
90
+ transformers.append(desc_block)
91
+
92
+ if self.use_fingerprints and self.fingerprint_features:
93
+ if self.use_svd_fp:
94
+ fp_block = ("fp",
95
+ ImbPipeline([
96
+ ("svd", TruncatedSVD(n_components=self.n_svd_components, random_state=42))
97
+ ]),
98
+ self.fingerprint_features)
99
+ else:
100
+ fp_block = ("fp", "passthrough", self.fingerprint_features)
101
+ transformers.append(fp_block)
102
+
103
+ preprocessor = ColumnTransformer(transformers)
104
+
105
+ # Define the classifier
106
+ classifier = XGBClassifier(
107
+ random_state=42,
108
+ eval_metric="logloss" if self.binary else "mlogloss",
109
+ objective="binary:logistic" if self.binary else "multi:softprob",
110
+ **self.xgb_params
111
+ )
112
+
113
+ if self.smote_k_neighbors is not None:
114
+ return ImbPipeline([
115
+ ("preprocess", preprocessor),
116
+ ("smote", SMOTE(random_state=42, k_neighbors=self.smote_k_neighbors)),
117
+ ("clf", classifier)
118
+ ])
119
+ else:
120
+ return Pipeline([
121
+ ("preprocess", preprocessor),
122
+ ("clf", classifier)
123
+ ])
124
+
125
+ def fit(self, X: pd.DataFrame, y: pd.Series):
126
+ self.pipeline.fit(X, y)
127
+ return self
128
+
129
+ def predict(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
130
+ X_proc = self._ensure_features(X)
131
+ return self.pipeline.predict(X_proc)
132
+
133
+ def predict_proba(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> Any:
134
+ X_proc = self._ensure_features(X)
135
+ return self.pipeline.predict_proba(X_proc)
136
+
137
+ def save(self, path: Union[str, Path]):
138
+ joblib.dump(self, str(path))
139
+
140
+ @classmethod
141
+ def load(cls, path: Union[str, Path]) -> "GraphEdgeClassifier":
142
+ return joblib.load(str(path))
143
+
144
+ @staticmethod
145
+ def extract_graph_features(
146
+ protac_smiles: Union[str, List[str]],
147
+ wh_smiles: Optional[Union[str, List[str]]] = None,
148
+ lk_smiles: Optional[Union[str, List[str]]] = None,
149
+ e3_smiles: Optional[Union[str, List[str]]] = None,
150
+ n_bits: int = 512,
151
+ radius: int = 6,
152
+ descriptor_names: Optional[List[str]] = None,
153
+ verbose: int = 0
154
+ ) -> pd.DataFrame:
155
+ if any(x is None for x in [wh_smiles, lk_smiles, e3_smiles]):
156
+ # Get features from PROTAC only, for inference
157
+ return extract_edge_features(
158
+ protac_smiles=protac_smiles,
159
+ n_bits=n_bits,
160
+ radius=radius,
161
+ descriptor_names=descriptor_names,
162
+ )
163
+ else:
164
+ # Get features and labels from all components, for training
165
+ return get_edge_features(
166
+ protac_smiles=protac_smiles,
167
+ wh_smiles=wh_smiles,
168
+ lk_smiles=lk_smiles,
169
+ e3_smiles=e3_smiles,
170
+ n_bits=n_bits,
171
+ radius=radius,
172
+ descriptor_names=descriptor_names,
173
+ verbose=verbose
174
+ )
175
+
176
+ @staticmethod
177
+ def build_multiclass_target(
178
+ df: pd.DataFrame,
179
+ poi_attachment_id: int = 1,
180
+ e3_attachment_id: int = 2,
181
+ ) -> pd.Series:
182
+ """
183
+ Returns multiclass target: 0 = no split, 1 = WH split, 2 = E3 split (with the default attachment ids poi=1, e3=2).
184
+ """
185
+ assert ((df["label_e3_split"] + df["label_wh_split"]) <= 1).all()
186
+ y = (
187
+ df["label_wh_split"] * poi_attachment_id +
188
+ df["label_e3_split"] * e3_attachment_id
189
+ )
190
+ return y.astype("int32")
191
+
192
+ def _ensure_features(self, X: Union[pd.DataFrame, List[Dict], List[str]]) -> pd.DataFrame:
193
+ """ Filter out features/columns that are are not used in the pipeline. """
194
+ required_columns = (
195
+ (self.graph_features or []) +
196
+ (self.categorical_features or []) +
197
+ (self.descriptor_features or []) +
198
+ (self.fingerprint_features or [])
199
+ )
200
+ # If input is a DataFrame with SMILES, assume already featurized
201
+ if isinstance(X, pd.DataFrame):
202
+ Xf = X
203
+ elif isinstance(X, list) and isinstance(X[0], dict):
204
+ Xf = pd.DataFrame(X)
205
+ else:
206
+ raise ValueError("Provide either a DataFrame or list of feature dicts. Use extract_graph_features for SMILES.")
207
+ missing = set(required_columns) - set(Xf.columns)
208
+ if missing:
209
+ raise ValueError(f"Input data missing required columns: {missing}")
210
+ return Xf[required_columns].copy()
211
+
212
+ def predict_proba_from_smiles(
213
+ self,
214
+ protac_smiles: Union[str, List[str]],
215
+ wh_smiles: Union[str, List[str]],
216
+ lk_smiles: Union[str, List[str]],
217
+ e3_smiles: Union[str, List[str]],
218
+ verbose: int = 0,
219
+ ):
220
+ features = self.extract_graph_features(
221
+ protac_smiles, wh_smiles, lk_smiles, e3_smiles,
222
+ n_bits=self.n_bits,
223
+ radius=self.radius,
224
+ descriptor_names=self.descriptor_names,
225
+ verbose=verbose
226
+ )
227
+ Xf = self._ensure_features(features)
228
+ return self.pipeline.predict_proba(Xf)
229
+
230
+ def predict_from_smiles(
231
+ self,
232
+ protac_smiles: Union[str, List[str]],
233
+ wh_smiles: Union[str, List[str]],
234
+ lk_smiles: Union[str, List[str]],
235
+ e3_smiles: Union[str, List[str]],
236
+ top_n: int = 1,
237
+ return_array: bool = True,
238
+ verbose: int = 0,
239
+ ) -> Union[pd.DataFrame, np.ndarray]:
240
+ """
241
+ For binary classification:
242
+ For each SMILES, return the top_n edge chem_bond_idx indices among those predicted as class 1,
243
+ sorted by predicted probability. If not enough edges are class 1, pad with -1.
244
+ For multiclass:
245
+ For each SMILES, return the chem_bond_idx with highest probability for class 1 (E3 split)
246
+ and for class 2 (WH split). Shape: (num_smiles, 2).
247
+ If no edge is predicted as that class, value is -1.
248
+ """
249
+ features = self.extract_graph_features(
250
+ protac_smiles, wh_smiles, lk_smiles, e3_smiles,
251
+ n_bits=self.n_bits,
252
+ radius=self.radius,
253
+ descriptor_names=self.descriptor_names,
254
+ verbose=verbose
255
+ )
256
+ Xf = self._ensure_features(features)
257
+ pred_proba = self.pipeline.predict_proba(Xf)
258
+ pred_label = self.pipeline.predict(Xf)
259
+ features = features.copy()
260
+ features["pred_label"] = pred_label
261
+ features["pred_proba"] = pred_proba[:, 1] if pred_proba.shape[1] > 1 else pred_proba[:, 0]
262
+
263
+ unique_smiles = pd.Series(features["chem_mol_smiles"]).drop_duplicates().tolist()
264
+ groupby = features.groupby("chem_mol_smiles")
265
+
266
+ results = []
267
+
268
+ if return_array:
269
+ if pred_proba.shape[1] == 2: # Binary case
270
+ for mol_smiles in unique_smiles:
271
+ group = groupby.get_group(mol_smiles)
272
+ # Only consider edges predicted as label 1
273
+ edges_class1 = group[group["pred_label"] == 1]
274
+ # If none, pad with -1
275
+ if len(edges_class1) == 0:
276
+ results.append(np.full(top_n, -1))
277
+ continue
278
+ # Sort by proba, take top_n
279
+ top_edges = edges_class1.nlargest(top_n, "pred_proba")
280
+ idxs = top_edges["chem_bond_idx"].to_numpy()
281
+ if len(idxs) < top_n:
282
+ idxs = np.pad(idxs, (0, top_n - len(idxs)), constant_values=-1)
283
+ results.append(idxs[:top_n])
284
+ return np.vstack(results)
285
+ else: # Multiclass case
286
+ for mol_smiles in unique_smiles:
287
+ group = groupby.get_group(mol_smiles)
288
+ # For class 1
289
+ class1_idx = -1
290
+ if (group["pred_label"] == 1).any():
291
+ # Take the edge with highest class-1 probability
292
+ mask = group["pred_label"] == 1
293
+ idx1 = group.loc[mask, "pred_proba"].idxmax()
294
+ class1_idx = group.loc[idx1, "chem_bond_idx"]
295
+ # For class 2
296
+ class2_idx = -1
297
+ if (group["pred_label"] == 2).any():
298
+ mask = group["pred_label"] == 2
299
+ idx2 = group.loc[mask, "pred_proba"].idxmax()
300
+ class2_idx = group.loc[idx2, "chem_bond_idx"]
301
+ results.append([class1_idx, class2_idx])
302
+ return np.array(results, dtype=int)
303
+ else:
304
+ return features
305
+
306
+ def get_classification_report(y_true, y_pred, labels):
307
+ report = classification_report(y_true, y_pred, target_names=labels, output_dict=True)
308
+ df_report = pd.DataFrame(report).transpose().round(2)
309
+ print(df_report)
310
+ return df_report
311
+
312
+ def plot_confusion_matrix(y_true, y_pred, labels):
313
+ cm = confusion_matrix(y_true, y_pred)
314
+ if HAS_VISUALIZATION:
315
+ plt.figure(figsize=(8, 6))
316
+ sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
317
+ plt.xlabel("Predicted")
318
+ plt.ylabel("True")
319
+ plt.title("Confusion Matrix")
320
+ plt.show()
321
+ else:
322
+ print("Visualization libraries not available. Skipping confusion matrix plot.")
323
+ print("Confusion Matrix:")
324
+ print(cm)
325
+
326
+ def get_classification_report_and_plot(y_true, y_pred, labels):
327
+ report = get_classification_report(y_true, y_pred, labels)
328
+ plot_confusion_matrix(y_true, y_pred, labels)
329
+ return report
330
+
331
+def train_edge_classifier(
+    train_df: pd.DataFrame,
+    val_df: Optional[pd.DataFrame] = None,
+    test_df: Optional[pd.DataFrame] = None,
+    model_filename: Optional[Union[str, Path]] = None,
+    edge_classifier_kwargs: Optional[Dict[str, Any]] = None,
+    cache_dir: Optional[Union[str, Path]] = None,
+    return_reports: bool = True,
+    show_confusion_matrix: bool = False,
+) -> GraphEdgeClassifier:
+    """
+    Train an edge-level graph classifier for PROTACs.
+
+    Args:
+        train_df (pd.DataFrame): Training data with columns:
+            - 'PROTAC SMILES'
+            - 'POI Ligand SMILES with direction'
+            - 'Linker SMILES with direction'
+            - 'E3 Binder SMILES with direction'
+        val_df (Optional[pd.DataFrame]): Validation data, same format as train_df.
+        test_df (Optional[pd.DataFrame]): Test data, same format as train_df.
+        model_filename (Optional[Union[str, Path]]): Path to save the trained model.
+        edge_classifier_kwargs (Optional[Dict[str, Any]]): Additional parameters for GraphEdgeClassifier. Feature-list entries left as None are filled from the training feature columns.
+        cache_dir (Optional[Union[str, Path]]): Directory for caching extracted features as CSV files.
+        return_reports (bool): Whether to also return the classification report of the last evaluated set.
+        show_confusion_matrix (bool): Whether to plot a confusion matrix for the validation and test sets.
+
+    Returns:
+        GraphEdgeClassifier: Trained edge classifier instance. If return_reports is True, a (classifier, report) tuple is returned instead.
+    """
+    sets = {}
+    for set_name, df in [
+        ("train", train_df),
+        ("val", val_df),
+        ("test", test_df),
+    ]:
+        if cache_dir is not None:
+            cache_path = Path(cache_dir) / f"{set_name}.csv"
+            if cache_path.exists():
+                print(f"Loading cached features for {set_name} from {cache_path}")
+                sets[set_name] = pd.read_csv(cache_path)
+                continue
+            else:
+                print(f"Cache not found for {set_name}, extracting features...")
+
+        if df is None or df.empty:
+            continue
+
+        print(f"Set: {set_name}, size: {len(df):,}")
+        required_columns = [
+            'PROTAC SMILES',
+            'POI Ligand SMILES with direction',
+            'Linker SMILES with direction',
+            'E3 Binder SMILES with direction',
+        ]
+        missing = [c for c in required_columns if c not in df.columns]
+        if missing:
+            raise ValueError(f"DataFrame for {set_name} is missing required columns: {missing}")
+
+        sets[set_name] = GraphEdgeClassifier.extract_graph_features(
+            df['PROTAC SMILES'].tolist(),
+            df['POI Ligand SMILES with direction'].tolist(),
+            df['Linker SMILES with direction'].tolist(),
+            df['E3 Binder SMILES with direction'].tolist(),
+            verbose=1,
+        )
+        # Drop rows where label_e3_split + label_wh_split > 1
+        sets[set_name] = sets[set_name][(sets[set_name]["label_e3_split"] + sets[set_name]["label_wh_split"]) <= 1]
+        print(f"Set: {set_name}, size: {len(sets[set_name]):,}")
+        if cache_dir is not None:
+            cache_path = Path(cache_dir) / f"{set_name}.csv"
+            cache_path.parent.mkdir(parents=True, exist_ok=True)
+            sets[set_name].to_csv(cache_path, index=False)
+            print(f"Saved {set_name} features to {cache_path}")
+
+    train_set = sets["train"]
+    label_cols = [c for c in train_set.columns if c.startswith("label_")]
+    train_set = train_set.dropna(subset=label_cols)
+    train_set = train_set[(train_set["label_e3_split"] + train_set["label_wh_split"]) <= 1]
+    X_train = train_set.drop(columns=label_cols)
+
+    # Instantiate and train. Feature lists left as None (e.g., by the Optuna
+    # objective below) are derived from the training feature columns.
+    edge_classifier_kwargs = edge_classifier_kwargs or {
+        "use_descriptors": False,
+        "use_fingerprints": True,
+        "n_svd_components": 50,
+        "binary": True,
+        "smote_k_neighbors": 10,
+        "xgb_params": {
+            "max_depth": 6,
+            "learning_rate": 0.3,
+            "alpha": 0.1,   # Default: 0
+            "lambda": 0.5,  # Default: 1
+            "gamma": 0.1,   # Default: 0
+        },
+    }
+    if edge_classifier_kwargs.get("graph_features") is None:
+        edge_classifier_kwargs["graph_features"] = [c for c in X_train.columns if c.startswith("graph_")]
+    if edge_classifier_kwargs.get("categorical_features") is None:
+        edge_classifier_kwargs["categorical_features"] = ["chem_bond_type", "chem_atom_u", "chem_atom_v"]
+    if edge_classifier_kwargs.get("fingerprint_features") is None:
+        edge_classifier_kwargs["fingerprint_features"] = [c for c in X_train.columns if c.startswith("chem_mol_fp_")]
+    clf = GraphEdgeClassifier(**edge_classifier_kwargs)
+
+    # Prepare target variable according to classification type
+    if clf.binary:
+        y_train = train_set["label_is_split"].astype("int32")
+    else:
+        y_train = GraphEdgeClassifier.build_multiclass_target(train_set)
+
+    print(f"Training set size: {len(X_train):,}, labels: {y_train.unique()}")
+    clf.fit(X_train, y_train)
+    print("Training complete.")
+
+    if model_filename is not None:
+        clf.save(model_filename)
+        print(f"Model saved to {model_filename}")
+
+    target_labels = ["No Split", "Split"] if clf.binary else ["No Split", "WH-Linker", "E3-Linker"]
+
+    report = None
+    if "val" in sets:
+        # Get validation data
+        val_set = sets["val"].dropna(subset=label_cols)
+        val_set = val_set[(val_set["label_e3_split"] + val_set["label_wh_split"]) <= 1]
+        X_val = val_set.drop(columns=label_cols)
+        y_val = val_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(val_set)
+        y_pred = clf.predict(X_val)
+        if show_confusion_matrix:
+            report = get_classification_report_and_plot(y_val, y_pred, target_labels)
+        else:
+            report = get_classification_report(y_val, y_pred, target_labels)
+        print(f"Validation set classification report:\n{report.to_markdown()}")
+
+    if "test" in sets:
+        # Get test data
+        test_set = sets["test"].dropna(subset=label_cols)
+        test_set = test_set[(test_set["label_e3_split"] + test_set["label_wh_split"]) <= 1]
+        X_test = test_set.drop(columns=label_cols)
+        y_test = test_set["label_is_split"].astype("int32") if clf.binary else GraphEdgeClassifier.build_multiclass_target(test_set)
+        y_pred = clf.predict(X_test)
+        if show_confusion_matrix:
+            report = get_classification_report_and_plot(y_test, y_pred, target_labels)
+        else:
+            report = get_classification_report(y_test, y_pred, target_labels)
+        print(f"Test set classification report:\n{report.to_markdown()}")
+
+    if return_reports:
+        return clf, report
+    else:
+        return clf
+
+
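For reference, a minimal usage sketch of `train_edge_classifier`; the CSV paths are hypothetical placeholders, and the input files must contain the four SMILES columns listed in the docstring:

```python
import pandas as pd

train_df = pd.read_csv("train.csv")  # placeholder path
val_df = pd.read_csv("val.csv")      # placeholder path

clf, val_report = train_edge_classifier(
    train_df=train_df,
    val_df=val_df,
    model_filename="edge_classifier.joblib",  # hypothetical output path
    cache_dir="./feature_cache",  # features are extracted once and reused
    return_reports=True,
)
```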
+def objective(trial, train_df, val_df):
+    # HP space
+    max_depth = trial.suggest_int("max_depth", 3, 10)
+    learning_rate = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
+    alpha = trial.suggest_float("alpha", 0.0, 2.0)
+    reg_lambda = trial.suggest_float("lambda", 0.0, 2.0)
+    gamma = trial.suggest_float("gamma", 0.0, 1.0)
+    n_svd_components = trial.suggest_int("n_svd_components", 16, 128)
+    smote_k_neighbors = trial.suggest_int("smote_k_neighbors", 3, 15)
+    use_descriptors = trial.suggest_categorical("use_descriptors", [False, True])
+    use_fingerprints = trial.suggest_categorical("use_fingerprints", [True, False])
+
+    edge_classifier_kwargs = {
+        "graph_features": None,  # Filled in by train_edge_classifier
+        "categorical_features": None,
+        "fingerprint_features": None,
+        "use_descriptors": use_descriptors,
+        "use_fingerprints": use_fingerprints,
+        "n_svd_components": n_svd_components,
+        "binary": True,
+        "smote_k_neighbors": smote_k_neighbors,
+        "xgb_params": {
+            "max_depth": max_depth,
+            "learning_rate": learning_rate,
+            "alpha": alpha,
+            "lambda": reg_lambda,
+            "gamma": gamma,
+        },
+    }
+
+    _, val_report = train_edge_classifier(
+        train_df=train_df,
+        val_df=val_df,
+        edge_classifier_kwargs=edge_classifier_kwargs,
+        return_reports=True,
+    )
+
+    # Evaluate metrics on the validation set. val_report is the transposed
+    # classification_report DataFrame: class names in the index, metric columns
+    # ['precision', 'recall', 'f1-score', 'support']. The binary positive
+    # class is "Split".
+    try:
+        f1_1 = float(val_report.loc["Split", "f1-score"])
+    except Exception:
+        f1_1 = 0.0
+    try:
+        acc = float(val_report.loc["accuracy", "f1-score"])
+    except Exception:
+        acc = 0.0
+
+    # Multi-objective: prioritize F1 for the minority class, but keep accuracy.
+    # Adjust the weights depending on the task (here they are equal).
+    score = 0.5 * acc + 0.5 * f1_1
+    return score
+
+def run_optuna_search(
+    train_df: pd.DataFrame,
+    val_df: pd.DataFrame,
+    n_trials: int = 50,
+    study_name: str = "edge_classifier_hp_search",
+    study_dir: str = "./optuna_studies",
+    seed: int = 42,
+) -> Any:
+    """Run an Optuna hyperparameter search and retrain a classifier with the best trial."""
+    import os
+    os.makedirs(study_dir, exist_ok=True)
+    study_path = f"sqlite:///{os.path.join(study_dir, study_name)}.db"
+
+    study = optuna.create_study(
+        study_name=study_name,
+        direction="maximize",
+        sampler=QMCSampler(seed=seed, qmc_type="sobol"),
+        storage=study_path,
+        load_if_exists=True,
+    )
+    func = lambda trial: objective(trial, train_df, val_df)
+    study.optimize(func, n_trials=n_trials, show_progress_bar=True)
+
+    print("Best trial:")
+    print(study.best_trial)
+
+    # Train a classifier with the best hyperparameters and return it
+    best_params = study.best_trial.params
+    edge_classifier_kwargs = {
+        "graph_features": None,
+        "categorical_features": None,
+        "fingerprint_features": None,
+        "use_descriptors": best_params["use_descriptors"],
+        "use_fingerprints": best_params["use_fingerprints"],
+        "n_svd_components": best_params["n_svd_components"],
+        "binary": True,
+        "smote_k_neighbors": best_params["smote_k_neighbors"],
+        "xgb_params": {
+            "max_depth": best_params["max_depth"],
+            "learning_rate": best_params["learning_rate"],
+            "alpha": best_params["alpha"],
+            "lambda": best_params["lambda"],
+            "gamma": best_params["gamma"],
+        },
+    }
+    clf, _ = train_edge_classifier(
+        train_df=train_df,
+        val_df=val_df,
+        edge_classifier_kwargs=edge_classifier_kwargs,
+        return_reports=True,
+    )
+    study_file = os.path.join(study_dir, f"{study_name}_study.pkl")
+    import joblib
+    joblib.dump(study, study_file)
+    print(f"Optuna study saved to {study_file}")
+    return clf, study
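A minimal usage sketch for the search; the CSV paths are hypothetical placeholders:

```python
import pandas as pd

train_df = pd.read_csv("train.csv")  # placeholder path
val_df = pd.read_csv("val.csv")      # placeholder path

# Resumes the study from the SQLite storage if it already exists.
clf, study = run_optuna_search(train_df, val_df, n_trials=25)
print(study.best_trial.params)
```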
protac_splitter/graphs/edge_features.py ADDED
@@ -0,0 +1,293 @@
+from typing import Tuple, List
+
+from rdkit import Chem
+from rdkit.Chem import AllChem, Descriptors, Draw
+import networkx as nx
+import pandas as pd
+import numpy as np
+from tqdm import tqdm
+
+from protac_splitter.chemoinformatics import get_atom_idx_at_attachment
+from protac_splitter.display_utils import safe_display, get_mapped_protac_img
+
+
+def bond_capacity(bond: Chem.Bond) -> int:
+    """ Calculate the capacity of a bond based on its type and properties.
+
+    Parameters:
+        bond (Chem.Bond): The bond object from RDKit.
+
+    Returns:
+        int: The capacity of the bond, where higher values indicate less preference for cutting.
+    """
+    # High capacity for aromatic and ring bonds to avoid cutting them
+    if bond.GetIsAromatic() or bond.IsInRing():
+        return 1000  # very high capacity: avoid cutting aromatic bonds
+    elif bond.GetBondType() == Chem.BondType.SINGLE:
+        return 1  # low capacity: prefer to cut here
+    elif bond.GetBondType() == Chem.BondType.DOUBLE:
+        return 10  # medium penalty
+    elif bond.GetBondType() == Chem.BondType.TRIPLE:
+        return 20  # stronger penalty
+    else:
+        return 50  # fallback for unknown/rare types
+
+def smiles_to_nx(
+    smiles: str,
+    use_capacity: bool = False,
+) -> nx.Graph:
+    """ Convert a SMILES string to a NetworkX graph.
+
+    Parameters:
+        smiles (str): The SMILES string to convert.
+        use_capacity (bool): Whether to store bond capacities as edge attributes.
+
+    Returns:
+        nx.Graph: The NetworkX graph representation of the molecule.
+    """
+    mol = Chem.MolFromSmiles(smiles)
+    if mol is None:
+        raise ValueError(f"Input SMILES could not be parsed: {smiles}")
+    # Canonicalize the SMILES
+    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol, canonical=True))
+    if mol is None:
+        raise ValueError(f"Input SMILES could not be canonicalized: {smiles}")
+    # Convert the molecule to a NetworkX graph
+    G = nx.Graph()
+    if use_capacity:
+        for bond in mol.GetBonds():
+            capacity = bond_capacity(bond)
+            G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), capacity=capacity)
+    else:
+        for bond in mol.GetBonds():
+            G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
+    return G
+
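For instance, a minimal sketch of the capacity-annotated graph; ethanol is an arbitrary example molecule:

```python
# Inspect the capacity-annotated graph of a small molecule.
G = smiles_to_nx("CCO", use_capacity=True)  # ethanol: two single bonds
for u, v, data in G.edges(data=True):
    print(u, v, data)  # e.g. 0 1 {'capacity': 1}
```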
+def extract_edge_features(
+    protac_smiles: str,
+    e3_split_pair: Tuple[int, int] = None,
+    wh_split_pair: Tuple[int, int] = None,
+    n_bits: int = 512,
+    radius: int = 6,
+    descriptor_names: List[str] = None,
+    fp_as_string: bool = False,
+) -> pd.DataFrame:
+    """Extract features from the edges of a PROTAC molecule represented as a SMILES string.
+
+    Parameters:
+        protac_smiles (str): SMILES representation of the PROTAC molecule.
+        e3_split_pair (Tuple[int, int]): Indices of the atoms at the E3 split bond.
+        wh_split_pair (Tuple[int, int]): Indices of the atoms at the warhead split bond.
+        n_bits (int): Number of bits for Morgan fingerprints.
+        radius (int): Radius for Morgan fingerprints.
+        descriptor_names (List[str]): List of RDKit descriptor names to compute.
+        fp_as_string (bool): Whether to store the fingerprint as a single bit string instead of one boolean column per bit.
+
+    Returns:
+        pd.DataFrame: DataFrame containing edge features.
+    """
+    mol = Chem.MolFromSmiles(protac_smiles)
+    if mol is None:
+        raise ValueError(f"Input SMILES could not be parsed: {protac_smiles}")
+    # Canonicalize the SMILES
+    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol, canonical=True))
+    if mol is None:
+        raise ValueError(f"Input SMILES could not be canonicalized: {protac_smiles}")
+
+    # Step 1: Convert SMILES to NetworkX
+    G = smiles_to_nx(protac_smiles, use_capacity=False)
+
+    num_nodes = G.number_of_nodes()
+    num_edges = G.number_of_edges()
+
+    # Step 2: Create line graph and compute betweenness + degree
+    LG = nx.line_graph(G)
+    line_betweenness = nx.betweenness_centrality(LG, endpoints=True)
+    betweenness = nx.betweenness_centrality(G, endpoints=True)
+
+    # Compute k-hop degrees (number of nodes at exactly 2 and 3 hops)
+    # TODO: Shall I get the degree of the node in the line graph or the original graph?
+    line_degree = dict(LG.degree())
+    line_degree_r2 = {}
+    line_degree_r3 = {}
+    for node in LG.nodes():
+        # Nodes at distance 2 and 3 (excluding the center node)
+        neighbors_r2 = nx.single_source_shortest_path_length(LG, node, cutoff=2)
+        neighbors_r3 = nx.single_source_shortest_path_length(LG, node, cutoff=3)
+        line_degree_r2[node] = len([n for n, d in neighbors_r2.items() if d == 2])
+        line_degree_r3[node] = len([n for n, d in neighbors_r3.items() if d == 3])
+
+    degree = dict(G.degree())
+    degree_r2 = {}
+    degree_r3 = {}
+    for node in G.nodes():
+        # Nodes at distance 2 and 3 (excluding the center node)
+        neighbors_r2 = nx.single_source_shortest_path_length(G, node, cutoff=2)
+        neighbors_r3 = nx.single_source_shortest_path_length(G, node, cutoff=3)
+        degree_r2[node] = len([n for n, d in neighbors_r2.items() if d == 2])
+        degree_r3[node] = len([n for n, d in neighbors_r3.items() if d == 3])
+
+    if e3_split_pair is not None and wh_split_pair is not None:
+        true_split_edges = {frozenset(e3_split_pair), frozenset(wh_split_pair)}
+
+    # Get molecular characteristics, i.e., Morgan fingerprints and descriptors
+    # Generate Morgan fingerprint
+    fp_bitvec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
+    fp = np.zeros((n_bits,), dtype=np.float32)
+    AllChem.DataStructs.ConvertToNumpyArray(fp_bitvec, fp)
+    if fp_as_string:
+        fp = {"chem_mol_fp": "".join([str(int(bit)) for bit in fp])}
+    else:
+        fp = {f"chem_mol_fp_{i}": bool(fp[i]) for i in range(n_bits)}
+    # Generate RDKit descriptors
+    descriptor_func_names = descriptor_names or [
+        "MolWt", "HeavyAtomCount", "NumHAcceptors", "NumHDonors",
+        "TPSA", "NumRotatableBonds", "RingCount", "MolLogP"
+    ]
+    functions = [getattr(Descriptors, name) for name in descriptor_func_names]
+    descriptors = {f"chem_mol_desc_{name}": func(mol) for name, func in zip(descriptor_func_names, functions)}
+
+    # Step 3: Gather edge features
+    # NOTE: Only consider bridge bonds (cutting a non-bridge bond cannot split the molecule)
+    edge_features = []
+    for (u, v) in nx.bridges(G):
+        bond = mol.GetBondBetweenAtoms(u, v)
+
+        # Avoid reporting the same edge twice (i.e., swap u and v if needed) and
+        # ensure to find the node pair in the line graph
+        node = (u, v) if (u, v) in LG else (v, u)
+        node_key = node if node in line_betweenness else (v, u)
+
+        features = {
+            "graph_num_nodes": num_nodes,
+            "graph_num_edges": num_edges,
+            "graph_betweenness": line_betweenness.get(node_key, 0.0),
+            "graph_degree": line_degree.get(node_key, 0),
+            "graph_degree_r2": line_degree_r2.get(node_key, 0),
+            "graph_degree_r3": line_degree_r3.get(node_key, 0),
+            "graph_node_u_degree": degree.get(u, 0),
+            "graph_node_u_degree_r2": degree_r2.get(u, 0),
+            "graph_node_u_degree_r3": degree_r3.get(u, 0),
+            "graph_node_v_degree": degree.get(v, 0),
+            "graph_node_v_degree_r2": degree_r2.get(v, 0),
+            "graph_node_v_degree_r3": degree_r3.get(v, 0),
+            "graph_node_u_betweenness": betweenness.get(u, 0.0),
+            "graph_node_v_betweenness": betweenness.get(v, 0.0),
+            "chem_bond_idx": bond.GetIdx(),
+            "chem_bond_type": str(bond.GetBondType()),
+            "chem_atom_u": mol.GetAtomWithIdx(u).GetSymbol(),
+            "chem_atom_v": mol.GetAtomWithIdx(v).GetSymbol(),
+            "chem_is_aromatic": bond.GetIsAromatic(),
+            "chem_is_in_ring": bond.IsInRing(),
+            "chem_mol_smiles": protac_smiles,
+            "chem_mol_n_bits": n_bits,
+            "chem_mol_radius": radius,
+        }
+        # Add RDKit descriptors and Morgan fingerprint
+        features.update(fp)
+        features.update(descriptors)
+
+        # Add E3 and warhead split labels
+        if e3_split_pair is not None and wh_split_pair is not None:
+            features.update({
+                "label_is_split": frozenset([u, v]) in true_split_edges,
+                "label_e3_split": frozenset([u, v]) == frozenset(e3_split_pair),
+                "label_wh_split": frozenset([u, v]) == frozenset(wh_split_pair),
+            })
+
+        # Append the features to the list of edge features
+        edge_features.append(features)
+
+    df = pd.DataFrame(edge_features)
+
+    # Downcast int64 columns to int32 to save memory
+    int64_cols = df.select_dtypes(include=['int64']).columns
+    dtype_mapping = {col: np.int32 for col in int64_cols}
+    df = df.astype(dtype_mapping)
+
+    return df
+
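A minimal sketch of calling `extract_edge_features` on a small molecule; the SMILES below is an arbitrary short chain, not a real PROTAC, and no split pairs are passed, so no label columns are produced:

```python
# Illustration only: a small molecule stands in for a PROTAC.
df = extract_edge_features("CCOCCNC(=O)c1ccccc1", n_bits=64, radius=2)
print(df[["chem_bond_idx", "chem_bond_type", "graph_betweenness"]].head())
```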
+def get_edge_features(
+    protac_smiles: str | List[str],
+    wh_smiles: str | List[str],
+    lk_smiles: str | List[str],
+    e3_smiles: str | List[str],
+    n_bits: int = 512,
+    radius: int = 6,
+    descriptor_names: List[str] = None,
+    fp_as_string: bool = False,
+    verbose: int = 0,
+) -> pd.DataFrame:
+    """Get edge features for a given PROTAC molecule and its components.
+
+    Parameters:
+        protac_smiles (str | List[str]): SMILES representation of the PROTAC molecule(s).
+        wh_smiles (str | List[str]): SMILES representation of the warhead(s).
+        lk_smiles (str | List[str]): SMILES representation of the linker(s).
+        e3_smiles (str | List[str]): SMILES representation of the E3 binder(s).
+        n_bits (int): Number of bits for Morgan fingerprints.
+        radius (int): Radius for Morgan fingerprints.
+        descriptor_names (List[str]): List of RDKit descriptor names to compute.
+        fp_as_string (bool): Whether to store the fingerprint as a single bit string.
+        verbose (int): Verbosity level; > 0 shows a progress bar, > 1 also displays images of sampled edges.
+
+    Returns:
+        pd.DataFrame: DataFrame containing edge features.
+    """
+    if isinstance(protac_smiles, str):
+        protac_smiles = [protac_smiles]
+    if isinstance(wh_smiles, str):
+        wh_smiles = [wh_smiles]
+    if isinstance(lk_smiles, str):
+        lk_smiles = [lk_smiles]
+    if isinstance(e3_smiles, str):
+        e3_smiles = [e3_smiles]
+
+    iterables = zip(protac_smiles, wh_smiles, lk_smiles, e3_smiles)
+    iterables = tqdm(iterables, desc="Extracting edge features", total=len(protac_smiles), disable=verbose == 0)
+    features_list = []
+    for protac_smi, wh_smi, lk_smi, e3_smi in iterables:
+        if verbose > 1:
+            get_mapped_protac_img(protac_smi, wh_smi, lk_smi, e3_smi, w=1500, h=600, display_image=True, useSVG=True)
+
+        # Convert SMILES to RDKit molecules
+        protac = Chem.MolFromSmiles(protac_smi)
+        wh = Chem.MolFromSmiles(wh_smi)
+        lk = Chem.MolFromSmiles(lk_smi)
+        e3 = Chem.MolFromSmiles(e3_smi)
+        if protac is None or wh is None or lk is None or e3 is None:
+            raise ValueError(f"Invalid SMILES string: {protac_smi}, {wh_smi}, {lk_smi}, {e3_smi}")
+
+        # Get the attachment points
+        wh_edge = get_atom_idx_at_attachment(protac, wh, lk)
+        e3_edge = get_atom_idx_at_attachment(protac, e3, lk)
+
+        # Extract features
+        features = extract_edge_features(
+            protac_smi,
+            e3_split_pair=e3_edge,
+            wh_split_pair=wh_edge,
+            n_bits=n_bits,
+            radius=radius,
+            descriptor_names=descriptor_names,
+            fp_as_string=fp_as_string,
+        )
+
+        if verbose > 1:
+            # Randomly sample and display a few edges
+            sample_edges = features.sample(n=min(5, len(features)), random_state=42)
+            for _, row in sample_edges.iterrows():
+                bond = protac.GetBondWithIdx(row['chem_bond_idx'])
+                u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
+                safe_display(Draw.MolToImage(
+                    protac,
+                    size=(1500, 400),
+                    highlightColor=(1, 0, 1, 0.3),  # Light purple
+                    highlightAtoms=[u, v],  # Highlight the two atoms
+                    legend=f"Graph nodes: {u}, {v} (Betweenness centrality: {row['graph_betweenness']:.3f})",
+                ))
+                # print(row[[c for c in features.columns if c.startswith('graph_')] + ['chem_atom_u', 'chem_atom_v', 'chem_is_in_ring']])
+                print(row)
+
+        # Append the features to the list
+        features_list.append(features)
+
+    return pd.concat(features_list, ignore_index=True)
protac_splitter/graphs/splitting_algorithms.py ADDED
@@ -0,0 +1,512 @@
+import re
+from typing import Dict, Any, Optional, List, Union
+from pathlib import Path
+from joblib import Parallel, delayed
+
+import numpy as np
+import networkx as nx
+from rdkit import Chem, DataStructs
+from rdkit.Chem import rdFingerprintGenerator
+
+from .edge_classifier import GraphEdgeClassifier
+from .e3_clustering import get_representative_e3s_fp
+from .utils import average_tanimoto_distance
+from protac_splitter.data.curation.bond_adjustments import (
+    adjust_amide_bonds_in_substructs,
+    adjust_ester_bonds_in_substructs,
+)
+
+def bond_capacity(bond: Chem.Bond) -> int:
+    """Capacity of a bond: higher values make the bond less preferable to cut."""
+    if bond.GetIsAromatic() or bond.IsInRing():
+        return 1000  # very high capacity: avoid cutting aromatic/ring bonds
+    elif bond.GetBondType() == Chem.BondType.SINGLE:
+        return 1  # low capacity: prefer to cut here
+    elif bond.GetBondType() == Chem.BondType.DOUBLE:
+        return 10  # medium penalty
+    elif bond.GetBondType() == Chem.BondType.TRIPLE:
+        return 20  # stronger penalty
+    else:
+        return 50  # fallback for unknown/rare types
+
+def smiles_to_nx(smiles: str) -> nx.Graph:
+    """Convert a SMILES string to a capacity-annotated NetworkX graph."""
+    mol = Chem.MolFromSmiles(smiles)
+    if mol is None:
+        raise ValueError(f"Input SMILES could not be parsed: {smiles}")
+    G = nx.Graph()
+    for bond in mol.GetBonds():
+        capacity = bond_capacity(bond)
+        G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), capacity=capacity)
+    return G
+
+def extract_attachment_point(smiles):
+    """
+    Extracts the number X from the pattern [X*] in a SMILES string.
+
+    Parameters:
+        smiles (str): The SMILES string containing the attachment point.
+
+    Returns:
+        str or None: The extracted number as a string, or None if not found.
+    """
+    match = re.search(r'\[(\d+)\*\]', smiles)
+    return match.group(1) if match else None
+
+def split_protac_with_betweenness_centrality(
+    protac_smiles: str,
+    representative_e3s_fp: List[DataStructs.ExplicitBitVect] = None,
+    morgan_fp_generator: Optional[Any] = None,
+    use_capacity_weight: bool = False,
+    betweenness_threshold: float = 0.4,
+) -> Dict[str, str]:
+    """
+    Split the PROTAC molecule into its three substructures using the NetworkX library.
+
+    Parameters:
+        protac_smiles (str): The SMILES string of the PROTAC molecule.
+        representative_e3s_fp (list): List of representative E3 ligand fingerprints.
+        morgan_fp_generator: RDKit Morgan fingerprint generator (should be the same as the one that generated the E3 fingerprints).
+        use_capacity_weight (bool): Whether to use bond capacity as weight for the graph.
+        betweenness_threshold (float): Threshold for betweenness centrality to consider a node as a candidate for splitting.
+
+    Returns:
+        dict: A dictionary containing the E3 ligand, warhead, linker, top nodes, and centrality scores.
+    """
+    if morgan_fp_generator is None:
+        # Create a default Morgan fingerprint generator
+        morgan_fp_generator = rdFingerprintGenerator.GetMorganGenerator(
+            radius=16,
+            fpSize=1024,
+            useBondTypes=True,
+            includeChirality=True,
+        )
+
+    if representative_e3s_fp is None:
+        # Get the representative E3 ligands fingerprints
+        representative_e3s_fp = get_representative_e3s_fp(fp_generator=morgan_fp_generator)
+
+    # -----------------------------------
+    # Deterministic graph-based algorithm
+    # -----------------------------------
+    protac = Chem.MolFromSmiles(protac_smiles)
+    if protac is None:
+        raise ValueError(f"Invalid SMILES string: {protac_smiles}")
+
+    G = smiles_to_nx(protac_smiles)
+
+    # Compute betweenness centrality
+    weight = 'capacity' if use_capacity_weight else None
+    centrality = nx.betweenness_centrality(G, normalized=True, endpoints=True, weight=weight)
+
+    # Sort the nodes by decreasing betweenness centrality
+    sorted_nodes = sorted(centrality.items(), key=lambda x: x[1], reverse=True)
+
+    # Get the list of bridges in the graph and the set of atoms they connect
+    bridges = list(nx.bridges(G))
+    bridge_nodes = {n for edge in bridges for n in edge}
+
+    # Get the top two bridge nodes by betweenness centrality
+    top_nodes = [n for n, _ in sorted_nodes if n in bridge_nodes][:2]
+
+    # Get the top nodes with the highest betweenness centrality that are not in
+    # a ring, but are adjacent to the top nodes or have a high betweenness
+    for node, score in sorted_nodes:
+        # Check if the node is in a ring in the protac molecule
+        atom = protac.GetAtomWithIdx(node)
+        if not atom.IsInRing():
+            # Check if the atom is adjacent to any of the top nodes, if so, add it to the list
+            for neighbor in G.neighbors(node):
+                if neighbor in top_nodes:
+                    top_nodes.append(node)
+                    break
+            if score > betweenness_threshold:
+                top_nodes.append(node)
+
+    # If a node has only top nodes as neighbors, add it to the list
+    for node in G.nodes():
+        if node not in top_nodes:
+            neighbors = list(G.neighbors(node))
+            if all(neighbor in top_nodes for neighbor in neighbors):
+                top_nodes.append(node)
+
+    # Get all paths between the top nodes, e.g., rings
+    for i in range(len(top_nodes)):
+        for j in range(i + 1, len(top_nodes)):
+            node1 = top_nodes[i]
+            node2 = top_nodes[j]
+
+            for path in nx.all_simple_paths(G, node1, node2):
+                for node in path:
+                    if node not in top_nodes:
+                        top_nodes.append(node)
+
+    # Remove duplicates
+    top_nodes = list(set(top_nodes))
+
+    # Loop over the top nodes and collect the bonds that connect them to nodes
+    # outside the top set
+    edge_nodes = set()
+    for top_node in top_nodes:
+        for neighbor in G.neighbors(top_node):
+            if neighbor not in top_nodes:
+                edge_nodes.update([(top_node, neighbor)])
+                break
+
+    # Get molecule fragment from the top nodes
+    bonds = [protac.GetBondBetweenAtoms(i, j) for (i, j) in edge_nodes]
+    bonds_idx = [bond.GetIdx() for bond in bonds if bond is not None]
+
+    # Try every pair of indexes; if the number of resulting fragments is not 3,
+    # then do not consider them as candidates for splitting
+    candidate_bonds = []
+    for i in range(len(bonds_idx)):
+        for j in range(i + 1, len(bonds_idx)):
+            bond1 = bonds_idx[i]
+            bond2 = bonds_idx[j]
+
+            # Get the fragments
+            fragments = Chem.FragmentOnBonds(protac, [bond1, bond2])
+
+            # Check if there are 3 fragments
+            if Chem.MolToSmiles(fragments).count(".") == 2:
+                frag_lens = []
+                avg_len = 0
+                for frag in Chem.GetMolFrags(fragments, asMols=True):
+                    frag_len = frag.GetNumAtoms()
+                    frag_lens.append(frag_len)
+                    avg_len += frag_len
+                avg_len /= 3
+
+                # Calculate the standard deviation of the fragment lengths
+                len_std = 0
+                for frag_len in frag_lens:
+                    len_std += (frag_len - avg_len) ** 2
+                len_std = (len_std / 3) ** 0.5
+                candidate_bonds.append(((bond1, bond2), len_std))
+
+    # Sort the candidate bond pairs by the standard deviation of the fragment
+    # sizes (smallest first, i.e., most balanced splits first)
+    candidate_bonds = sorted(candidate_bonds, key=lambda x: x[1])
+
+    ligands = None
+    while ligands is None and len(candidate_bonds) > 0:
+        bonds_idx = candidate_bonds[0][0]
+        try:
+            ligands = Chem.FragmentOnBonds(protac, bonds_idx, addDummies=True, dummyLabels=[(1, 1), (2, 2)])
+        except Exception as e:
+            print(f"Error fragmenting the molecule: {e}")
+            candidate_bonds.pop(0)
+
+    # If no candidate bonds were found, return None for each substructure
+    if ligands is None:
+        print(f"No candidate bonds found for splitting PROTAC: {protac_smiles}")
+        return {'e3': None, 'poi': None, 'linker': None, 'top_nodes': None, 'centrality': None}
+
+    # Get the linker (the fragment with two attachment points)
+    linker_smiles = None
+    substructures = []
+    for ligand in Chem.GetMolFrags(ligands, asMols=True):
+        ligand_smiles = Chem.MolToSmiles(ligand, canonical=True)
+        if ligand_smiles.count("*") == 2:
+            linker_smiles = ligand_smiles
+        else:
+            substructures.append(ligand_smiles)
+
+    # The substructure closest (on average) to the representative E3 ligands is
+    # taken as the E3 binder; the other one is the warhead
+    sub1_dist = average_tanimoto_distance(substructures[0], representative_e3s_fp, morgan_fp_generator)
+    sub2_dist = average_tanimoto_distance(substructures[1], representative_e3s_fp, morgan_fp_generator)
+    if sub1_dist < sub2_dist:
+        e3_smiles = substructures[0]
+        wh_smiles = substructures[1]
+    else:
+        e3_smiles = substructures[1]
+        wh_smiles = substructures[0]
+
+    # Get the attachment point using a regex, e.g., should return 1 if [1*] is in the SMILES
+    e3_attach_point = extract_attachment_point(e3_smiles)
+    e3_smiles = e3_smiles.replace(f"[{e3_attach_point}*]", "[*:2]")
+    linker_smiles = linker_smiles.replace(f"[{e3_attach_point}*]", "[*:2]")
+
+    wh_attach_point = extract_attachment_point(wh_smiles)
+    wh_smiles = wh_smiles.replace(f"[{wh_attach_point}*]", "[*:1]")
+    linker_smiles = linker_smiles.replace(f"[{wh_attach_point}*]", "[*:1]")
+    return {'e3': e3_smiles, 'poi': wh_smiles, 'linker': linker_smiles, 'top_nodes': top_nodes, 'centrality': centrality}
+
+
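A minimal usage sketch of the deterministic splitter; it assumes `protac_smiles` holds a valid PROTAC SMILES string, and it loads the default representative E3 fingerprints internally via `get_representative_e3s_fp()`:

```python
# Assumes `protac_smiles` holds a valid PROTAC SMILES string.
result = split_protac_with_betweenness_centrality(protac_smiles)
if result["linker"] is not None:
    print("warhead:", result["poi"])
    print("linker: ", result["linker"])
    print("E3:     ", result["e3"])
```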
+def split_protac_with_edge_classifier(
+    protac_smiles: str,
+    pipeline: Union[str, Path, GraphEdgeClassifier],
+    representative_e3s_fp: Optional[List[np.array]] = None,
+    morgan_fp_generator: Optional[Any] = None,
+) -> Dict[str, str]:
+    """ Split the PROTAC molecule into its three substructures using the pretrained edge classifier.
+
+    Parameters:
+        protac_smiles (str): The SMILES string of the PROTAC molecule.
+        pipeline (Union[str, Path, GraphEdgeClassifier]): A trained GraphEdgeClassifier instance, or a path to load one from.
+        representative_e3s_fp (Optional[List[np.array]]): Precomputed fingerprints of representative E3 ligands.
+        morgan_fp_generator (Optional[Any]): RDKit Morgan fingerprint generator (should be the same as the one that generated the E3 fingerprints).
+
+    Returns:
+        dict: A dictionary containing the E3 ligand, warhead, linker, and bonds_idx.
+    """
+    if morgan_fp_generator is None:
+        # Create a default Morgan fingerprint generator
+        morgan_fp_generator = rdFingerprintGenerator.GetMorganGenerator(
+            radius=16,
+            fpSize=1024,
+            useBondTypes=True,
+            includeChirality=True,
+        )
+
+    if representative_e3s_fp is None:
+        # Get the representative E3 ligands fingerprints
+        representative_e3s_fp = get_representative_e3s_fp(fp_generator=morgan_fp_generator)
+
+    protac = Chem.MolFromSmiles(protac_smiles)
+    if protac is None:
+        raise ValueError(f"Invalid SMILES string: {protac_smiles}")
+
+    if isinstance(pipeline, (str, Path)):
+        pipeline = GraphEdgeClassifier.load(pipeline)
+
+    # TODO: Get the top-n bonds; if splitting results in more than 3 ligands,
+    # test other pairs of bonds, then repeat until we get exactly 3 ligands.
+    bonds_idx = pipeline.predict_from_smiles(
+        protac_smiles,
+        wh_smiles=None,
+        lk_smiles=None,
+        e3_smiles=None,
+        top_n=2,
+        return_array=True,
+    ).flatten().tolist()
+
+    if -1 in bonds_idx:
+        bonds_idx = [bond for bond in bonds_idx if bond != -1]
+        # Randomly select non-ring bond indices from the PROTAC molecule
+        # that are not among the predicted bonds
+        for _ in range(2 - len(bonds_idx)):
+            bond = np.random.choice([bond.GetIdx() for bond in protac.GetBonds() if bond.GetIdx() not in bonds_idx and not bond.IsInRing()])
+            bonds_idx.append(int(bond))
+
+    ligands = Chem.FragmentOnBonds(protac, bonds_idx, addDummies=True, dummyLabels=[(1, 1), (2, 2)])
+
+    # Get the linker (the fragment with two attachment points)
+    linker_smiles = None
+    substructures = []
+    for ligand in Chem.GetMolFrags(ligands, asMols=True):
+        ligand_smiles = Chem.MolToSmiles(ligand, canonical=True)
+        if ligand_smiles.count("*") == 2:
+            linker_smiles = ligand_smiles
+        else:
+            substructures.append(ligand_smiles)
+
+    if not pipeline.binary:
+        e3_smiles = substructures[0]
+        wh_smiles = substructures[1]
+        # NOTE: The classifier was trained on the following label assignment:
+        e3_attach_point = 1
+        wh_attach_point = 2
+    else:
+        if representative_e3s_fp is None or morgan_fp_generator is None:
+            raise ValueError("For a pipeline trained on binary classification, representative_e3s_fp and morgan_fp_generator must be provided.")
+        sub1_dist = average_tanimoto_distance(substructures[0], representative_e3s_fp, morgan_fp_generator)
+        sub2_dist = average_tanimoto_distance(substructures[1], representative_e3s_fp, morgan_fp_generator)
+        if sub1_dist < sub2_dist:
+            e3_smiles = substructures[0]
+            wh_smiles = substructures[1]
+        else:
+            e3_smiles = substructures[1]
+            wh_smiles = substructures[0]
+        # Get the attachment point using a regex, e.g., should return 1 if [1*] is in the SMILES
+        e3_attach_point = extract_attachment_point(e3_smiles)
+        wh_attach_point = extract_attachment_point(wh_smiles)
+
+    e3_smiles = e3_smiles.replace(f"[{e3_attach_point}*]", "[*:2]")
+    linker_smiles = linker_smiles.replace(f"[{e3_attach_point}*]", "[*:2]")
+
+    wh_smiles = wh_smiles.replace(f"[{wh_attach_point}*]", "[*:1]")
+    linker_smiles = linker_smiles.replace(f"[{wh_attach_point}*]", "[*:1]")
+    return {'e3': e3_smiles, 'poi': wh_smiles, 'linker': linker_smiles, "bonds_idx": bonds_idx}
+
+def split_protac_graph_based(
+    protac_smiles: str,
+    use_classifier: bool = False,
+    classifier: Optional['GraphEdgeClassifier'] = None,
+    representative_e3s_fp: Optional[List[Any]] = None,
+    morgan_fp_generator: Optional[Any] = None,
+    use_capacity_weight: bool = False,
+    betweenness_threshold: float = 0.4,
+) -> Dict[str, str]:
+    """
+    Splits a PROTAC molecule using either the ML edge classifier or the deterministic betweenness-centrality algorithm.
+    Returns a dictionary with the keys 'e3', 'poi', 'linker', and 'bonds_idx'.
+    """
+    if representative_e3s_fp is None:
+        if morgan_fp_generator is None:
+            # Create a default Morgan fingerprint generator
+            morgan_fp_generator = rdFingerprintGenerator.GetMorganGenerator(
+                radius=16,
+                fpSize=1024,
+                useBondTypes=True,
+                includeChirality=True,
+            )
+        # Get the representative E3 ligands fingerprints
+        representative_e3s_fp = get_representative_e3s_fp(fp_generator=morgan_fp_generator)
+
+    if use_classifier:
+        ret = split_protac_with_edge_classifier(
+            protac_smiles=protac_smiles,
+            pipeline=classifier,
+            representative_e3s_fp=representative_e3s_fp,
+            morgan_fp_generator=morgan_fp_generator,
+        )
+    else:
+        ret = split_protac_with_betweenness_centrality(
+            protac_smiles=protac_smiles,
+            representative_e3s_fp=representative_e3s_fp,
+            morgan_fp_generator=morgan_fp_generator,
+            use_capacity_weight=use_capacity_weight,
+            betweenness_threshold=betweenness_threshold,
+        )
+
+    substructs = {
+        "e3": ret["e3"],
+        "poi": ret["poi"],
+        "linker": ret["linker"],
+    }
+
+    # If none of the substructures is None, adjust the amide and ester bonds
+    if all(x is not None for x in substructs.values()):
+        substructs = adjust_amide_bonds_in_substructs(substructs, protac_smiles)
+        substructs = adjust_ester_bonds_in_substructs(substructs, protac_smiles)
+        ret["e3"] = substructs["e3"]
+        ret["poi"] = substructs["poi"]
+        ret["linker"] = substructs["linker"]
+
+    return ret
+
+def split_protac_with_graphs_wrapper(
+    protac_smiles: List[str],
+    use_classifier: bool = False,
+    classifier: Optional['GraphEdgeClassifier'] = None,
+    representative_e3s: Optional[List[Any]] = None,
+    representative_e3s_fp: Optional[List[Any]] = None,
+    morgan_fp_generator: Optional[Any] = None,
+    use_capacity_weight: bool = False,
+    betweenness_threshold: float = 0.4,
+) -> List[Dict[str, str]]:
+    """ Wrapper function to apply split_protac_graph_based over a list of PROTAC SMILES.
+
+    Parameters:
+        protac_smiles (List[str]): List of SMILES strings of PROTAC molecules.
+        use_classifier (bool): Whether to use a classifier for splitting.
+        classifier (Optional[GraphEdgeClassifier]): Classifier to use if use_classifier is True.
+        representative_e3s (Optional[List[Any]]): List of representative E3 ligands; converted to fingerprints if representative_e3s_fp is not given.
+        representative_e3s_fp (Optional[List[Any]]): Precomputed fingerprints of representative E3 ligands.
+        morgan_fp_generator (Optional[Any]): RDKit Morgan fingerprint generator.
+        use_capacity_weight (bool): Whether to use bond capacity as weight for the graph.
+        betweenness_threshold (float): Threshold for betweenness centrality to consider a node as a candidate for splitting.
+
+    Returns:
+        List[Dict[str, str]]: List of dictionaries containing the split results for each PROTAC molecule.
+    """
+    if morgan_fp_generator is None and (representative_e3s is None or representative_e3s_fp is None):
+        # Create a default Morgan fingerprint generator
+        morgan_fp_generator = rdFingerprintGenerator.GetMorganGenerator(
+            radius=16,
+            fpSize=1024,
+            useBondTypes=True,
+            includeChirality=True,
+        )
+
+    if representative_e3s is None and representative_e3s_fp is None:
+        # Get the representative E3 ligands fingerprints
+        representative_e3s_fp = get_representative_e3s_fp(fp_generator=morgan_fp_generator)
+    elif representative_e3s is not None and representative_e3s_fp is None:
+        # Convert representative E3 ligands to fingerprints
+        representative_e3s_fp = get_representative_e3s_fp(e3_list=representative_e3s, fp_generator=morgan_fp_generator)
+
+    # Load the classifier if it is a string or Path
+    if use_classifier and classifier is not None and isinstance(classifier, (str, Path)):
+        classifier = GraphEdgeClassifier.load(classifier)
+
+    return [
+        split_protac_graph_based(
+            protac_smiles=smi,
+            use_classifier=use_classifier,
+            classifier=classifier,
+            representative_e3s_fp=representative_e3s_fp,
+            morgan_fp_generator=morgan_fp_generator,
+            use_capacity_weight=use_capacity_weight,
+            betweenness_threshold=betweenness_threshold,
+        ) for smi in protac_smiles
+    ]
+
+
+def split_protac_with_graphs_parallel(
+    protac_smiles: List[str],
+    use_classifier: bool = False,
+    classifier: Optional['GraphEdgeClassifier'] = None,
+    representative_e3s: Optional[List[Any]] = None,
+    representative_e3s_fp: Optional[List[Any]] = None,
+    morgan_fp_generator: Optional[Any] = None,
+    use_capacity_weight: bool = False,
+    betweenness_threshold: float = 0.4,
+    n_jobs: int = 1,
+    batch_size: int = 1,
+) -> List[Dict[str, str]]:
+    """ Splits a list of PROTAC molecules using either the ML classifier or deterministic betweenness centrality.
+
+    Parameters:
+        protac_smiles (List[str]): List of SMILES strings of PROTAC molecules.
+        use_classifier (bool): Whether to use a classifier for splitting.
+        classifier (Optional[GraphEdgeClassifier]): Classifier to use if use_classifier is True.
+        representative_e3s (Optional[List[Any]]): List of representative E3 ligands. If None, uses precomputed fingerprints.
+        representative_e3s_fp (Optional[List[Any]]): Precomputed fingerprints of representative E3 ligands.
+        morgan_fp_generator (Optional[Any]): RDKit Morgan fingerprint generator.
+        use_capacity_weight (bool): Whether to use bond capacity as weight for the graph.
+        betweenness_threshold (float): Threshold for betweenness centrality to consider a node as a candidate for splitting.
+        n_jobs (int): Number of parallel jobs to run. If 1, runs sequentially.
+        batch_size (int): Size of each batch for parallel processing.
+
+    Returns:
+        List[Dict[str, str]]: List of dictionaries containing the split results for each PROTAC molecule.
+    """
+    # Load the classifier if it is a string or Path
+    if use_classifier and classifier is not None and isinstance(classifier, (str, Path)):
+        classifier = GraphEdgeClassifier.load(classifier)
+
+    if n_jobs < 1:
+        raise ValueError("n_jobs must be a positive integer.")
+    if n_jobs == 1:
+        # If n_jobs is 1, run the function sequentially
+        return split_protac_with_graphs_wrapper(
+            protac_smiles=protac_smiles,
+            use_classifier=use_classifier,
+            classifier=classifier,
+            representative_e3s=representative_e3s,
+            representative_e3s_fp=representative_e3s_fp,
+            morgan_fp_generator=morgan_fp_generator,
+            use_capacity_weight=use_capacity_weight,
+            betweenness_threshold=betweenness_threshold,
+        )
+
+    # Warn if a custom fingerprint generator is used with multiple workers
+    if morgan_fp_generator is not None:
+        print("Warning: a custom Morgan fingerprint generator may not be picklable when n_jobs > 1.")
+
+    # Split the SMILES list into batches (the last batch may be smaller than
+    # batch_size) and drop any empty batches for safety
+    smiles_batches = [protac_smiles[i:i + batch_size] for i in range(0, len(protac_smiles), batch_size)]
+    smiles_batches = [batch for batch in smiles_batches if batch]
+
+    # Run each batch in parallel
+    results = Parallel(n_jobs=n_jobs)(
+        delayed(split_protac_with_graphs_wrapper)(
+            protac_smiles=batch,
+            use_classifier=use_classifier,
+            classifier=classifier,
+            representative_e3s=representative_e3s,
+            representative_e3s_fp=representative_e3s_fp,
+            morgan_fp_generator=morgan_fp_generator,
+            use_capacity_weight=use_capacity_weight,
+            betweenness_threshold=betweenness_threshold,
+        ) for batch in smiles_batches
+    )
+
+    # Flatten the list of lists into a single list
+    return [item for batch_result in results for item in batch_result]
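A usage sketch of the parallel entry point; it assumes `smiles_list` is a list of PROTAC SMILES strings, and the worker/batch settings here are arbitrary examples:

```python
# Assumes `smiles_list` is a list of PROTAC SMILES strings.
results = split_protac_with_graphs_parallel(
    protac_smiles=smiles_list,
    n_jobs=4,        # four worker processes
    batch_size=32,   # 32 molecules per dispatched task
)
linkers = [r["linker"] for r in results]
```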
protac_splitter/graphs/utils.py ADDED
@@ -0,0 +1,67 @@
+from typing import Any, Optional, List
+
+import numpy as np
+from rdkit import Chem, DataStructs
+from rdkit.Chem import rdFingerprintGenerator
+
+def get_fp(
+    smiles: str,
+    fp_generator: Optional[Any] = None,
+    return_np: bool = True,
+) -> Optional[np.ndarray]:
+    """
+    Get the Morgan fingerprint of a molecule from its SMILES representation.
+
+    Parameters:
+        smiles (str): The SMILES string of the molecule.
+        fp_generator (Any, optional): The fingerprint generator to use. If None, a default generator is used.
+        return_np (bool): Whether to return the fingerprint as a NumPy array instead of an RDKit bit vector. Defaults to True.
+
+    Returns:
+        Optional[np.ndarray]: The Morgan fingerprint of the molecule, or None if the SMILES is invalid.
+    """
+    mol = Chem.MolFromSmiles(smiles)
+    if mol is None:
+        return None
+
+    if fp_generator is None:
+        fp_generator = rdFingerprintGenerator.GetMorganGenerator(
+            radius=16,
+            fpSize=1024,
+            useBondTypes=True,
+            includeChirality=True,
+        )
+
+    if return_np:
+        return fp_generator.GetFingerprintAsNumPy(mol)
+    else:
+        return fp_generator.GetFingerprint(mol)
+
+def average_tanimoto_distance(
+    smiles: str,
+    fingerprints: List[DataStructs.ExplicitBitVect],
+    morgan_fp_generator: Optional[Any] = None,
+) -> float:
+    """
+    Compute the average Tanimoto distance between a query SMILES and a list of RDKit fingerprints.
+
+    Parameters:
+        smiles (str): SMILES string of the query molecule.
+        fingerprints (list): List of RDKit fingerprint objects (e.g., ExplicitBitVect).
+        morgan_fp_generator: RDKit Morgan fingerprint generator.
+
+    Returns:
+        float: Average Tanimoto distance (1 - similarity) between the query and the fingerprints.
+    """
+    query_fp = get_fp(smiles, morgan_fp_generator, return_np=False)
+    if query_fp is None:
+        raise ValueError(f"Invalid SMILES string: {smiles}")
+    distances = DataStructs.BulkTanimotoSimilarity(query_fp, fingerprints, returnDistance=True)
+
+    return np.array(distances).mean()
+
+def numpy_to_rdkit_fp(arr: np.ndarray) -> DataStructs.ExplicitBitVect:
+    """
+    Convert a binary NumPy array to an RDKit ExplicitBitVect.
+    """
+    return DataStructs.CreateFromBitString(''.join(arr.astype(str)))
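A small sketch of how these helpers compose; the reference molecules below are arbitrary examples, not E3 ligands:

```python
# Arbitrary example molecules, illustration only.
ref_fps = [get_fp(s, return_np=False) for s in ["c1ccccc1O", "c1ccccc1N"]]
dist = average_tanimoto_distance("c1ccccc1C", ref_fps)
print(f"average Tanimoto distance: {dist:.3f}")  # lower = more similar
```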
protac_splitter/graphs_utils.py ADDED
@@ -0,0 +1,190 @@
+from typing import Optional
+
+from numba import njit
+import numpy as np
+import networkx as nx
+from rdkit import Chem
+
+
+def mol2graph(mol: Chem.Mol) -> nx.Graph:
+    """ Convert an RDKit molecule to a NetworkX graph.
+
+    Args:
+        mol (Chem.Mol): The RDKit molecule to convert.
+
+    Returns:
+        nx.Graph: The NetworkX graph representation of the molecule.
+    """
+    # NOTE: https://github.com/maxhodak/keras-molecules/pull/32/files
+    # TODO: Double check this implementation too: https://gist.github.com/jhjensen2/6450138cda3ab796a30850610843cfff
+    if mol is None:
+        return nx.empty_graph()
+    G = nx.Graph()
+    for atom in mol.GetAtoms():
+        # Skip dummy atoms (atomic number 0, e.g., attachment points)
+        if atom.GetAtomicNum() != 0:
+            G.add_node(atom.GetIdx(), label=atom.GetSymbol())
+    for bond in mol.GetBonds():
+        # Skip bonds involving dummy atoms
+        if bond.GetBeginAtom().GetAtomicNum() == 0 or bond.GetEndAtom().GetAtomicNum() == 0:
+            continue
+        G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), label=bond.GetBondType())
+    return G
+
+def smiles2graph(smiles: str) -> nx.Graph:
+    """ Convert a SMILES string to a NetworkX graph.
+
+    Args:
+        smiles (str): The SMILES string to convert.
+
+    Returns:
+        nx.Graph: The NetworkX graph representation of the molecule.
+    """
+    return mol2graph(Chem.MolFromSmiles(smiles))
+
+def get_smiles2graph_edit_distance(smi1: str, smi2: str, **kwargs) -> float:
+    """ Compute the graph edit distance between two SMILES strings.
+
+    Args:
+        smi1 (str): The first SMILES string.
+        smi2 (str): The second SMILES string.
+        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.
+
+    Returns:
+        float: The graph edit distance between the two SMILES strings.
+    """
+    ged = nx.graph_edit_distance(smiles2graph(smi1), smiles2graph(smi2), **kwargs)
+    return ged if ged is not None else np.inf
+
+def get_mol2graph_edit_distance(mol1: Chem.Mol, mol2: Chem.Mol, **kwargs) -> float:
+    """ Compute the graph edit distance between two RDKit molecules.
+
+    Args:
+        mol1 (Chem.Mol): The first RDKit molecule.
+        mol2 (Chem.Mol): The second RDKit molecule.
+        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.
+
+    Returns:
+        float: The graph edit distance between the two RDKit molecules.
+    """
+    ged = nx.graph_edit_distance(mol2graph(mol1), mol2graph(mol2), **kwargs)
+    return ged if ged is not None else np.inf
+
+def get_smiles2graph_edit_distance_norm(
+    smi1: str,
+    smi2: str,
+    ged_G1_G2: Optional[float] = None,
+    eps: float = 1e-9,
+    **kwargs,
+) -> float:
+    """ Compute the normalized graph edit distance between two SMILES strings.
+
+    Args:
+        smi1 (str): The first SMILES string.
+        smi2 (str): The second SMILES string.
+        ged_G1_G2 (Optional[float]): The graph edit distance between the two graphs. If None, it is computed with `nx.graph_edit_distance`.
+        eps (float): A small value to avoid division by zero.
+        **kwargs: Additional keyword arguments for `nx.graph_edit_distance`.
+
+    Returns:
+        float: The normalized graph edit distance between the two SMILES strings.
+    """
+    G1 = smiles2graph(smi1)
+    G2 = smiles2graph(smi2)
+    G0 = nx.empty_graph()
+    ged_G1_G2 = ged_G1_G2 if ged_G1_G2 is not None else nx.graph_edit_distance(G1, G2, **kwargs)
+    ged_G1_G0 = nx.graph_edit_distance(G1, G0, **kwargs)
+    ged_G2_G0 = nx.graph_edit_distance(G2, G0, **kwargs)
+    if None in [ged_G1_G2, ged_G1_G0, ged_G2_G0]:
+        return np.inf
+    return ged_G1_G2 / (ged_G1_G0 + ged_G2_G0 + eps)
+
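In other words, the edit distance is normalized by the cost of editing each graph down to the empty graph; by the triangle inequality the result lies in roughly [0, 1], with identical graphs giving 0. A quick sanity check (note that exact `nx.graph_edit_distance` gets very slow for large molecules):

```python
# Identical molecules give ~0; unrelated ones give a larger value.
print(get_smiles2graph_edit_distance_norm("CCO", "CCO"))       # 0.0
print(get_smiles2graph_edit_distance_norm("CCO", "c1ccccc1"))  # > 0
```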
+def smiles2adjacency_matrix(smiles: str) -> np.ndarray:
+    """ Return the dense adjacency matrix of the molecular graph of a SMILES string."""
+    return nx.adjacency_matrix(smiles2graph(smiles)).todense()
+
+def build_label_mapping(G1, G2):
+    """ Map every node label occurring in G1 or G2 to a unique integer."""
+    labels = set()
+    for G in [G1, G2]:
+        for node in G.nodes():
+            labels.add(G.nodes[node]['label'])
+    label_to_int = {label: idx for idx, label in enumerate(sorted(labels))}
+    return label_to_int
+
+def preprocess_graph(G, label_to_int):
+    """ Convert a labeled NetworkX graph to an adjacency matrix and an integer label array."""
+    n = G.number_of_nodes()
+    adj = np.zeros((n, n), dtype=np.int32)
+    labels = np.zeros(n, dtype=np.int32)
+    node_id_to_idx = {}
+    for idx, node in enumerate(G.nodes()):
+        node_id_to_idx[node] = idx
+        label = G.nodes[node]['label']
+        labels[idx] = label_to_int[label]
+    for u, v in G.edges():
+        idx_u = node_id_to_idx[u]
+        idx_v = node_id_to_idx[v]
+        adj[idx_u, idx_v] = 1
+        adj[idx_v, idx_u] = 1  # Assuming undirected graph
+    return adj, labels
+
+@njit
+def compute_cost_matrix(labels1, labels2, degrees1, degrees2):
+    """ Node substitution costs: label mismatch plus absolute degree difference."""
+    n1 = labels1.shape[0]
+    n2 = labels2.shape[0]
+    C = np.zeros((n1, n2), dtype=np.float64)
+    for i in range(n1):
+        for j in range(n2):
+            label_cost = 0.0 if labels1[i] == labels2[j] else 1.0
+            neighborhood_cost = abs(degrees1[i] - degrees2[j])
+            C[i, j] = label_cost + neighborhood_cost
+    return C
+
+@njit
+def greedy_assignment(C):
+    """ Greedily assign each row of the cost matrix to its cheapest unassigned column."""
+    n1, n2 = C.shape
+    assigned_cols = np.full(n2, False)
+    row_ind = np.full(n1, -1, dtype=np.int32)
+    for i in range(n1):
+        min_cost = np.inf
+        min_j = -1
+        for j in range(n2):
+            if not assigned_cols[j] and C[i, j] < min_cost:
+                min_cost = C[i, j]
+                min_j = j
+        if min_j != -1:
+            row_ind[i] = min_j
+            assigned_cols[min_j] = True
+    return row_ind
+
+@njit
+def compute_total_cost(C, row_ind, n1, n2, c_node_del, c_node_ins):
+    """ Total cost of an assignment, charging deletions and insertions for unmatched nodes."""
+    total_cost = 0.0
+    assigned_cols = np.full(n2, False)
+    for i in range(n1):
+        j = row_ind[i]
+        if j != -1:
+            total_cost += C[i, j]
+            assigned_cols[j] = True
+        else:
+            total_cost += c_node_del
+    for j in range(n2):
+        if not assigned_cols[j]:
+            total_cost += c_node_ins
+    return total_cost
+
+def approximate_graph_edit_distance(adj1, labels1, adj2, labels2, c_node_del=1.0, c_node_ins=1.0):
+    """ Approximate the graph edit distance via a greedy node assignment."""
+    degrees1 = adj1.sum(axis=1)
+    degrees2 = adj2.sum(axis=1)
+    C = compute_cost_matrix(labels1, labels2, degrees1, degrees2)
+    row_ind = greedy_assignment(C)
+    total_cost = compute_total_cost(C, row_ind, labels1.shape[0], labels2.shape[0], c_node_del, c_node_ins)
+    return total_cost
+
+def get_approximate_ged(G1, G2):
+    """ Approximate graph edit distance between two labeled NetworkX graphs."""
+    label_to_int = build_label_mapping(G1, G2)
+    adj1, labels1 = preprocess_graph(G1, label_to_int)
+    adj2, labels2 = preprocess_graph(G2, label_to_int)
+    cost = approximate_graph_edit_distance(adj1, labels1, adj2, labels2)
+    return cost
+
+def get_smiles2graph_edit_distance_approx(smi1: str, smi2: str) -> float:
+    """ Approximate graph edit distance between the molecular graphs of two SMILES strings."""
+    G1 = smiles2graph(smi1)
+    G2 = smiles2graph(smi2)
+    return get_approximate_ged(G1, G2)
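Since exact `nx.graph_edit_distance` is exponential in the worst case, the greedy approximation above is the practical choice for batch evaluation. It is only an estimate of the true GED, as the sketch below illustrates on toy molecules:

```python
# Toy molecules, illustration only.
print(get_smiles2graph_edit_distance_approx("CCO", "CCO"))  # 0.0 (identical)
print(get_smiles2graph_edit_distance_approx("CCO", "CCN"))  # 1.0 (one label substitution)
```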
protac_splitter/llms/__init__.py ADDED
File without changes
protac_splitter/llms/data_utils.py ADDED
@@ -0,0 +1,296 @@
+ import os
2
+ import random
3
+ import logging
4
+ from typing import Optional, Union
5
+
6
+ import torch
7
+ from datasets import load_dataset, concatenate_datasets, Dataset
8
+ from transformers import AutoTokenizer
9
+ from rdkit import Chem
10
+
11
+ from protac_splitter.evaluation import split_prediction
12
+
13
+
14
+ def randomize_smiles_dataset(
15
+ batch: dict,
16
+ repeat: int = 1,
17
+ prob: float = 0.5,
18
+ apply_to_text: bool = True,
19
+ apply_to_labels: bool = False,
20
+ ) -> dict:
21
+ """ Randomize SMILES in a batch of data.
22
+
23
+ Args:
24
+ batch (dict): Batch of data with "text" and "labels" keys.
25
+ repeat (int, optional): Number of times to repeat the randomization. Defaults to 1.
26
+ prob (float, optional): Probability of randomizing SMILES. Defaults to 0.5.
27
+ apply_to_text (bool, optional): Whether to apply randomization to text. Defaults to True.
28
+ apply_to_labels (bool, optional): Whether to apply randomization to labels. Defaults to False.
29
+
30
+ Returns:
31
+ dict: Randomized batch of data.
32
+ """
33
+ new_texts, new_labels = [], []
34
+ for text, label in zip(batch["text"], batch["labels"]):
35
+ try:
36
+ mol_text = Chem.MolFromSmiles(text)
37
+ mol_label = Chem.MolFromSmiles(label)
38
+ except Exception:
39
+ logging.error("Failed to convert SMILES to Mol!")
40
+ new_texts.append(text)
41
+ new_labels.append(label)
42
+ continue
43
+
44
+ if random.random() < prob:
45
+ if apply_to_text:
46
+ rand_texts = [Chem.MolToSmiles(mol_text, canonical=False, doRandom=True) for _ in range(repeat)]
47
+ else:
48
+ rand_texts = [text] * repeat
49
+
50
+ if apply_to_labels:
51
+ rand_labels = [Chem.MolToSmiles(mol_label, canonical=False, doRandom=True) for _ in range(repeat)]
52
+ else:
53
+ rand_labels = [label] * repeat
54
+
55
+ new_texts.extend(rand_texts)
56
+ new_labels.extend(rand_labels)
57
+ else:
58
+ new_texts.append(text)
59
+ new_labels.append(label)
60
+
61
+ return {"text": new_texts, "labels": new_labels}
62
+
63
+
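
To make the augmentation behavior concrete, here is a toy call (the SMILES and batch are illustrative, not from the dataset):

```python
# Toy batch; real rows hold PROTAC SMILES in "text" and substructure SMILES in "labels".
batch = {"text": ["c1ccccc1O", "CCO"], "labels": ["c1ccccc1O", "CCO"]}

out = randomize_smiles_dataset(batch, repeat=2, prob=1.0, apply_to_text=True)
# With prob=1.0 and repeat=2, every row is expanded into two randomized text
# variants, each paired with the (unchanged) label: len(out["text"]) == 4.
print(out["text"])
```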
+ def process_data_to_model_inputs(
+     batch,
+     tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
+     encoder_max_length: int = 512,
+     decoder_max_length: int = 512,
+ ):
+     if isinstance(tokenizer, str):
+         tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+     # Tokenize the inputs and labels
+     inputs = tokenizer(batch["text"], truncation=True, max_length=encoder_max_length)
+     outputs = tokenizer(batch["labels"], truncation=True, max_length=decoder_max_length)
+     batch["input_ids"] = inputs.input_ids
+     batch["attention_mask"] = inputs.attention_mask
+     batch["labels"] = outputs.input_ids.copy()
+
+     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     # batch["input_ids"] = batch["input_ids"].to(device)
+     # batch["attention_mask"] = batch["attention_mask"].to(device)
+     # batch["labels"] = batch["labels"].to(device)
+
+     # Because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`.
+     # We have to make sure that the PAD token is ignored when calculating the loss.
+     # NOTE: Check the `ignore_index` argument in nn.CrossEntropyLoss.
+     # NOTE: The following is already done in the DataCollatorForSeq2Seq:
+     # batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]
+     return batch
+
+
+ def get_fragments_in_labels(labels: str, linkers_only_as_labels: bool = True) -> Optional[str]:
+     """ Get the fragments in the labels.
+
+     Args:
+         labels (str): The labels, i.e., the concatenated substructure SMILES.
+         linkers_only_as_labels (bool, optional): Whether to keep only the linker in the labels. Defaults to True.
+
+     Returns:
+         Optional[str]: The fragments in the labels, or None if any fragment is missing.
+     """
+     ligands = split_prediction(labels)
+     if linkers_only_as_labels:
+         return ligands.get("linker", None)
+     if None in ligands.values():
+         return None
+     return f"{ligands['e3']}.{ligands['poi']}"
+
+
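
Schematically, `get_fragments_in_labels` reduces a three-part label to the supervision target. The placeholder strings and the exact fragment ordering below are illustrative assumptions; `split_prediction` does the actual parsing:

```python
# Schematic only; <E3>, <LINKER>, <POI> stand in for real substructure SMILES.
labels = "<E3>.<LINKER>.<POI>"
# get_fragments_in_labels(labels, linkers_only_as_labels=True)  -> "<LINKER>"
# get_fragments_in_labels(labels, linkers_only_as_labels=False) -> "<E3>.<POI>"
```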
+ def load_tokenized_dataset(
+     dataset_dir: str,
+     dataset_config: str = 'default',
+     tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
+     batch_size: int = 512,
+     encoder_max_length: int = 512,
+     decoder_max_length: int = 512,
+     token: Optional[str] = None,
+     num_proc_map: int = 1,
+     randomize_smiles: bool = False,
+     randomize_smiles_prob: float = 0.5,
+     randomize_smiles_repeat: int = 1,
+     randomize_text: bool = True,
+     randomize_labels: bool = False,
+     cache_dir: Optional[str] = None,
+     all_fragments_as_labels: bool = True,
+     linkers_only_as_labels: bool = False,
+     causal_language_modeling: bool = False,
+     train_size_ratio: float = 1.0,
+ ) -> Dataset:
+     """ Load a dataset and tokenize it.
+
+     Args:
+         dataset_dir (str): The directory of the dataset or the name of the dataset on the Hugging Face Hub.
+         dataset_config (str, optional): The configuration of the dataset. Defaults to 'default'.
+         tokenizer (AutoTokenizer | str, optional): The tokenizer to use for tokenization. If a string, the tokenizer will be loaded using `AutoTokenizer.from_pretrained(tokenizer)`. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
+         batch_size (int, optional): The batch size for tokenization. Defaults to 512.
+         encoder_max_length (int, optional): The maximum length of the encoder input sequence. Defaults to 512.
+         decoder_max_length (int, optional): The maximum length of the decoder input sequence. Defaults to 512.
+         token (Optional[str], optional): The Hugging Face API token. Defaults to None.
+         num_proc_map (int, optional): The number of processes to use for mapping. Defaults to 1.
+         randomize_smiles (bool, optional): Whether to randomize SMILES. Defaults to False.
+         randomize_smiles_prob (float, optional): The probability of randomizing SMILES. Defaults to 0.5.
+         randomize_smiles_repeat (int, optional): The number of times to repeat the randomization. Defaults to 1.
+         randomize_text (bool, optional): Whether to randomize the text. Defaults to True.
+         randomize_labels (bool, optional): Whether to randomize the labels. Defaults to False.
+         cache_dir (Optional[str], optional): The directory to cache the dataset. Defaults to None.
+         all_fragments_as_labels (bool, optional): Whether to keep all fragments in the labels. Defaults to True.
+         linkers_only_as_labels (bool, optional): Whether to keep only the linkers in the labels. Defaults to False.
+         causal_language_modeling (bool, optional): Whether to use causal language modeling. Defaults to False.
+         train_size_ratio (float, optional): The ratio of the training dataset to use. Defaults to 1.0.
+
+     Returns:
+         Dataset: The tokenized dataset.
+     """
+     if isinstance(tokenizer, str):
+         tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+     if os.path.exists(dataset_dir):
+         # NOTE: We need a different argument to load a dataset from disk:
+         dataset = load_dataset(
+             dataset_dir,
+             data_dir=dataset_config,
+         )
+         print(f"Dataset loaded from disk at: \"{dataset_dir}\". Length: {dataset.num_rows}")
+     else:
+         dataset = load_dataset(
+             dataset_dir,
+             dataset_config,
+             token=token,
+             cache_dir=cache_dir,
+         )
+         print(f"Dataset loaded from hub. Length: {dataset.num_rows}")
+
+     if train_size_ratio < 1.0 and train_size_ratio > 0:
+         # Reduce the size of the training dataset by selecting just a fraction of the samples
+         dataset["train"] = dataset["train"].select(range(int(train_size_ratio * dataset["train"].num_rows)))
+         print(f"Reduced training dataset size to {train_size_ratio}. Length: {dataset.num_rows}")
+     elif train_size_ratio > 1.0 or train_size_ratio < 0:
+         raise ValueError("train_size_ratio must be between 0 and 1.")
+
+     if not all_fragments_as_labels:
+         dataset = dataset.map(
+             lambda x: {
+                 "text": x["text"],
+                 "labels": get_fragments_in_labels(x["labels"], linkers_only_as_labels),
+             },
+             batched=False,
+             num_proc=num_proc_map,
+             load_from_cache_file=True,
+             desc="Getting fragments in labels",
+         )
+         # Filter out the samples with None labels
+         dataset = dataset.filter(lambda x: x["labels"] is not None)
+
+         if linkers_only_as_labels:
+             print(f"Set labels to linkers only. Length: {dataset.num_rows}")
+         else:
+             print(f"Set labels to E3 and WH only. Length: {dataset.num_rows}")
+
+     if randomize_smiles:
+         dataset["train"] = dataset["train"].map(
+             randomize_smiles_dataset,
+             batched=True,
+             batch_size=batch_size,
+             fn_kwargs={
+                 "repeat": randomize_smiles_repeat,
+                 "prob": randomize_smiles_prob,
+                 "apply_to_text": randomize_text,
+                 "apply_to_labels": randomize_labels,
+             },
+             num_proc=num_proc_map,
+             load_from_cache_file=True,
+             desc="Randomizing SMILES",
+         )
+         print(f"Randomized SMILES in dataset. Length: {dataset.num_rows}")
+
+     if causal_language_modeling:
+         dataset = dataset.map(
+             lambda x: {
+                 "text": x["text"] + "." + x["labels"],
+                 "labels": x["labels"],
+             },
+             batched=False,
+             num_proc=num_proc_map,
+             load_from_cache_file=True,
+             desc="Setting labels to text",
+         )
+         print(f"Appended labels to text. Length: {dataset.num_rows}")
+
+     # NOTE: Remove the "labels" column if causal language modeling, since the
+     # DataCollatorForLM will automatically set the labels to the input_ids.
+     dataset = dataset.map(
+         process_data_to_model_inputs,
+         batched=True,
+         batch_size=batch_size,
+         remove_columns=["text", "labels"] if causal_language_modeling else ["text"],
+         fn_kwargs={
+             "tokenizer": tokenizer,
+             "encoder_max_length": encoder_max_length,
+             "decoder_max_length": decoder_max_length,
+         },
+         num_proc=num_proc_map,
+         load_from_cache_file=True,
+         desc="Tokenizing dataset",
+     )
+     print(f"Tokenized dataset. Length: {dataset.num_rows}")
+
+     return dataset
+
+
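
Taken together, a typical call looks like the sketch below (the dataset id and config mirror the defaults used elsewhere in this commit; the column names printed at the end are indicative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
dataset = load_tokenized_dataset(
    "ailab-bio/PROTAC-Splitter-Dataset",  # not a local path, so it is fetched from the Hub
    dataset_config="standard",
    tokenizer=tokenizer,
    randomize_smiles=True,       # augmentation is applied to the train split only
    randomize_smiles_prob=0.5,
    num_proc_map=4,
)
print(dataset["train"].column_names)  # e.g., ['labels', 'input_ids', 'attention_mask', ...]
```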
+ def load_trl_dataset(
+     tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
+     token: Optional[str] = None,
+     max_length: int = 512,
+     dataset_name: str = "ailab-bio/PROTAC-Splitter-Dataset",
+     ds_config: str = "standard",
+     ds_unalabeled: Optional[str] = None,
+ ) -> Dataset:
+     if isinstance(tokenizer, str):
+         tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+     # Load training data
+     train_dataset = load_dataset(
+         dataset_name,
+         ds_config,
+         split="train",
+         token=token,
+     )
+     train_dataset = train_dataset.rename_column("text", "query")
+     train_dataset = train_dataset.remove_columns(["labels"])
+
+     if ds_unalabeled is not None:
+         # Load un-labelled data
+         unlabeled_dataset = load_dataset(
+             dataset_name,
+             ds_unalabeled,
+             split="train",
+             token=token,
+         )
+         unlabeled_dataset = unlabeled_dataset.rename_column("text", "query")
+         unlabeled_dataset = unlabeled_dataset.remove_columns(["labels"])
+         # Concatenate datasets row-wise
+         dataset = concatenate_datasets([train_dataset, unlabeled_dataset])
+     else:
+         dataset = train_dataset
+
+     def tokenize(sample, tokenizer, max_length=512):
+         input_ids = tokenizer.encode(sample["query"], padding="max_length", max_length=max_length)
+         return {"input_ids": input_ids, "query": sample["query"]}
+
+     return dataset.map(lambda x: tokenize(x, tokenizer, max_length), batched=False)
+
+
+ def data_collator_for_trl(batch):
+     return {
+         "input_ids": [torch.tensor(x["input_ids"]) for x in batch],
+         "query": [x["query"] for x in batch],
+     }
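
These two helpers target the RL fine-tuning scripts added later in this commit. A rough sketch of how they would plug into a `trl` PPO loop; the `PPOTrainer`/`PPOConfig` wiring and the value-head model are assumptions based on the `trl` library, not code from this repo:

```python
from transformers import AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForSeq2SeqLMWithValueHead

tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained("t5-small")  # placeholder model id

ds = load_trl_dataset(tokenizer=tokenizer)
ppo_trainer = PPOTrainer(
    config=PPOConfig(batch_size=16),
    model=model,
    tokenizer=tokenizer,
    dataset=ds,
    data_collator=data_collator_for_trl,  # keeps "query" strings next to the tensors
)
```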
protac_splitter/llms/evaluation.py ADDED
@@ -0,0 +1,169 @@
+ from typing import Union
+
+ from transformers import AutoTokenizer, EvalPrediction
+ import numpy as np
+ from rdkit import Chem, DataStructs
+ import evaluate
+ import multiprocessing as mp
+ import datetime
+
+ from protac_splitter.evaluation import (
+     # is_valid_smiles,
+     # has_three_substructures,
+     # has_all_attachment_points,
+     # check_substructs,
+     score_prediction,
+ )
+
+
+ def process_predictions(args) -> list:
+     """ Process one batch of the prediction scoring.
+
+     Args:
+         args (tuple): Tuple of arguments for the scoring function.
+
+     Returns:
+         list: The scores for each prediction in the batch.
+     """
+     pred_smiles, protac_smiles, label_smiles, fpgen, compute_rdkit_metrics, compute_graph_metrics = args
+     scores = []
+     for protac, pred, label in zip(protac_smiles, pred_smiles, label_smiles):
+         scores.append(score_prediction(
+             protac_smiles=protac,
+             label_smiles=label,
+             pred_smiles=pred,
+             fpgen=fpgen,
+             compute_rdkit_metrics=compute_rdkit_metrics,
+             compute_graph_metrics=compute_graph_metrics,
+             graph_edit_kwargs={"timeout": 0.05},
+         ))
+     return scores
+
+
+ def decode_and_get_metrics(
+     pred: EvalPrediction,
+     tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
+     rouge = None,  # Optional[evaluate.metrics.rouge.Rouge] = None
+     fpgen = None,  # Optional[Chem.rdFingerprintGenerator] = None
+     compute_rdkit_metrics: bool = False,
+     compute_graph_metrics: bool = True,
+     num_proc: int = 1,
+     batch_size: int = 128,
+     use_nan_for_missing: bool = True,
+     causal_language_modeling: bool = False,
+ ) -> dict[str, float]:
+     """ Compute metrics for tokenized PROTAC predictions.
+
+     Args:
+         pred (transformers.EvalPrediction): The predictions from the model.
+         tokenizer (AutoTokenizer | str): The tokenizer to use for decoding the predictions. If a string, the tokenizer will be loaded using `AutoTokenizer.from_pretrained(tokenizer)`. Default: "seyonec/ChemBERTa-zinc-base-v1"
+         rouge (Rouge): The Rouge object to use for scoring. Example: `rouge = evaluate.load("rouge")`
+         fpgen (Chem.rdFingerprintGenerator): The fingerprint generator to use for computing the Tanimoto similarity. Example: `Chem.rdFingerprintGenerator.GetMorganGenerator(radius=8, fpSize=2048)`
+
+     Returns:
+         dict[str, float]: A dictionary containing the scores for the predictions.
+     """
+     print(f"[{datetime.datetime.now()}] Starting decode_and_get_metrics (protac_splitter/llms/evaluation.py)")
+
+     if causal_language_modeling:
+         # NOTE: For causal language models, we only care about perplexity, so we
+         # only need the eval_loss, which is automatically added.
+         return {}
+
+     if isinstance(tokenizer, str):
+         tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+     labels_ids = pred.label_ids
+     pred_ids = pred.predictions
+     input_ids = pred.inputs
+
+     if causal_language_modeling:
+         # NOTE: Unreachable given the early return above; kept for when
+         # generation-based evaluation of causal models is re-enabled.
+         # The prediction logits will be of shape: (batch_size, sequence_length, vocabulary_size)
+         # So we need to get the argmax of the last dimension to get the
+         # predicted token IDs.
+         # NOTE: Not exactly the same as what would happen during generation, but
+         # hopefully it's close enough to assess model performance during
+         # training.
+         pred_ids = np.argmax(pred_ids, axis=-1)
+
+     # Replace -100 in the IDs with the tokenizer pad token id
+     # NOTE: Check the `ignore_index` argument in nn.CrossEntropyLoss.
+     # TODO: Understand why this needs to be done to the inputs as well
+     ignore_index = -100
+     labels_ids[labels_ids == ignore_index] = tokenizer.pad_token_id
+     pred_ids[pred_ids == ignore_index] = tokenizer.pad_token_id
+
+     # Get strings from IDs
+     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+     label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
+
+     if not causal_language_modeling:
+         input_ids[input_ids == ignore_index] = tokenizer.pad_token_id
+         input_str = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+     else:
+         # NOTE: For causal language models, i.e., decoder only, the input PROTAC
+         # is in the label. Therefore, we need to decode the label to get the
+         # input. The label looks something like "PROTAC.E3.Linker.WH", so we
+         # need to split it and get the last (three) parts.
+         input_str = [str(s.split('.')[0]) for s in label_str]
+         label_str = ['.'.join(s.split('.')[1:]) for s in label_str]
+         pred_str = ['.'.join(s.split('.')[1:]) if '.' in s else s for s in pred_str]
+
+     # Get scores
+     if num_proc == 1:
+         scores = process_predictions((
+             pred_str, input_str, label_str, fpgen, compute_rdkit_metrics, compute_graph_metrics
+         ))
+     else:
+         # Split the predictions into batches and score them in parallel.
+         # NOTE: Submitting all batches in a single `pool.map` call (instead of
+         # one call per batch) is what actually lets the workers run concurrently.
+         batches = [
+             (pred_str[i:i+batch_size], input_str[i:i+batch_size], label_str[i:i+batch_size], fpgen, compute_rdkit_metrics, compute_graph_metrics)
+             for i in range(0, len(pred_str), batch_size)
+         ]
+         with mp.Pool(processes=num_proc) as pool:
+             scores = pool.map(process_predictions, batches)
+         # Flatten the list of scores
+         scores = [s for ls in scores for s in ls]
+
+     # Aggregate scores
+     scores_labels = set()
+     for s in scores:
+         scores_labels.update(s.keys())
+
+     aggregated_scores = {}
+     for k in scores_labels:
+         values = np.array([s.get(k, np.nan) for s in scores], dtype=float)
+
+         # If values is all NaN, set the aggregated score to None and continue
+         if np.all(np.isnan(values)):
+             aggregated_scores[k] = None
+             continue
+
+         # Compute average, excluding `NaN` values if necessary
+         if use_nan_for_missing:
+             aggregated_scores[k] = np.nanmean(values)
+         else:
+             valid_values = values[~np.isnan(values)]
+             aggregated_scores[k] = np.mean(valid_values) if valid_values.size > 0 else float('nan')
+
+     # Get Rouge score
+     if rouge is not None:
+         rouge_output = rouge.compute(predictions=pred_str, references=label_str)
+         aggregated_scores.update({k: v for k, v in rouge_output.items()})
+
+     # TODO
+     # # Get tanimoto score
+     # pred_str = np.array(pred_str)[valid_smiles == 1]
+     # label_str = np.array(label_str)[valid_smiles == 1]
+     # if len(pred_str) == 0:
+     #     scores['tanimoto'] = 0.0
+     #     return scores
+     # pred_mols = [Chem.MolFromSmiles(s) for s in pred_str]
+     # label_mols = [Chem.MolFromSmiles(s) for s in label_str]
+     # pred_fps = [fpgen.GetFingerprint(m) for m in pred_mols]
+     # label_fps = [fpgen.GetFingerprint(m) for m in label_mols]
+     # tanimoto = [DataStructs.TanimotoSimilarity(l, p) for l, p in zip(label_fps, pred_fps)]
+     # scores['tanimoto'] = np.array(tanimoto).mean()
+
+     print(f"[{datetime.datetime.now()}] Done with decode_and_get_metrics (protac_splitter/llms/evaluation.py)")
+
+     return aggregated_scores
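
In practice this function is not called directly but bound with `functools.partial` and handed to a `Trainer`, as the hyperparameter search in `training.py` does; a condensed sketch:

```python
from functools import partial

import evaluate

compute_metrics = partial(
    decode_and_get_metrics,
    tokenizer="seyonec/ChemBERTa-zinc-base-v1",
    rouge=evaluate.load("rouge"),   # optional: adds ROUGE scores to the output dict
    compute_rdkit_metrics=False,
    compute_graph_metrics=False,
    num_proc=1,
)
# Then: Seq2SeqTrainer(..., compute_metrics=compute_metrics, ...)
```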
protac_splitter/llms/hf_utils.py ADDED
@@ -0,0 +1,36 @@
+ """ Hugging Face Hub utilities for repository management and file uploads. """
+ from typing import Optional
+
+ import huggingface_hub as hf
+ from huggingface_hub import repo_info
+ from huggingface_hub.utils import RepositoryNotFoundError
+
+
+ def repo_exists(repo_id: str, token: Optional[str] = None) -> bool:
+     """ Checks if a Hugging Face repository exists. """
+     try:
+         print(repo_info(repo_id, token=token))
+         return True
+     except RepositoryNotFoundError:
+         return False
+
+
+ def create_hf_repository(**kwargs):
+     """Creates a new Hugging Face repository."""
+     api = hf.HfApi()
+     return api.create_repo(**kwargs)
+
+
+ def delete_hf_repository(**kwargs):
+     """Deletes a Hugging Face repository."""
+     print(f'Deleting repository {kwargs["repo_id"]}.')
+     api = hf.HfApi()
+     return api.delete_repo(**kwargs)
+
+
+ def upload_single_file(**kwargs):
+     """Uploads a single file to a Hugging Face repository."""
+     try:
+         api = hf.HfApi()
+         api.upload_file(**kwargs)
+     except Exception as e:
+         print(e)
+         print("WARNING. Best parameters NOT pushed to the hub.")
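
A short usage sketch; the repo id and token are placeholders, and the kwargs are forwarded to the standard `huggingface_hub` `create_repo`/`upload_file` calls:

```python
token = "hf_..."             # placeholder; never hard-code real tokens
repo_id = "my-org/my-model"  # placeholder

if not repo_exists(repo_id, token=token):
    create_hf_repository(repo_id=repo_id, private=True, exist_ok=True, token=token)
upload_single_file(
    path_or_fileobj="best_hyperparameters.md",
    path_in_repo="best_hyperparameters.md",
    repo_id=repo_id,
    token=token,
)
```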
protac_splitter/llms/model_utils.py ADDED
@@ -0,0 +1,256 @@
+ """ Hugging Face utilities for model loading and pipeline creation. """
+ from typing import Optional, List, Dict, Union
+ from datasets import Dataset
+ from transformers import (
+     AutoTokenizer,
+     EncoderDecoderModel,
+     AutoModelForCausalLM,
+     pipeline,
+     GenerationConfig,
+ )
+ from transformers.pipelines.pt_utils import KeyDataset
+ from tqdm import tqdm
+ import torch
+
+
+ def get_encoder_decoder_model(
+     pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
+     pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
+     max_length: Optional[int] = 512,
+     tie_encoder_decoder: bool = False,
+ ) -> EncoderDecoderModel:
+     """ Get the EncoderDecoderModel model for the PROTAC splitter.
+
+     Args:
+         pretrained_encoder (str): The pretrained model to use for the encoder. Default: "seyonec/ChemBERTa-zinc-base-v1"
+         pretrained_decoder (str): The pretrained model to use for the decoder. Default: "seyonec/ChemBERTa-zinc-base-v1"
+         max_length (int): The maximum length of the input sequence. Default: 512
+         tie_encoder_decoder (bool): Whether to tie the encoder and decoder weights. Default: False
+
+     Returns:
+         EncoderDecoderModel: The EncoderDecoderModel model for the PROTAC splitter
+     """
+     bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
+         pretrained_encoder,
+         pretrained_decoder,
+         tie_encoder_decoder=tie_encoder_decoder,
+     )
+     print(f"Number of parameters: {bert2bert.num_parameters():,}")
+     tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
+     # Tokenizer-related configs
+     bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
+     bert2bert.config.eos_token_id = tokenizer.sep_token_id
+     bert2bert.config.pad_token_id = tokenizer.pad_token_id
+     bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
+     # Generation configs
+     # NOTE: The full list of generation configurations can be found here: https://huggingface.co/docs/transformers/v4.33.3/en/main_classes/text_generation#transformers.GenerationConfig
+     bert2bert.encoder.config.max_length = max_length
+     bert2bert.decoder.config.max_length = max_length
+
+     def setup_gen(config):
+         config.do_sample = True
+         config.num_beams = 5
+         config.top_k = 20
+         config.max_length = 512
+         # config.max_new_tokens = 512
+         return config
+
+     bert2bert.config = setup_gen(bert2bert.config)
+     bert2bert.encoder.config = setup_gen(bert2bert.encoder.config)
+     bert2bert.decoder.config = setup_gen(bert2bert.decoder.config)
+     bert2bert.decoder.config.is_decoder = True
+     bert2bert.generation_config = setup_gen(bert2bert.generation_config)
+
+     # bert2bert.config.do_sample = True
+     # bert2bert.config.num_beams = 5
+     # bert2bert.config.top_k = 20
+     # bert2bert.config.max_length = 512
+     # bert2bert.config.max_new_tokens = 512
+
+     # bert2bert.generation_config.max_new_tokens = 512
+     # bert2bert.generation_config.min_new_tokens = 512
+
+     # bert2bert.config.max_new_tokens = 514
+     # bert2bert.config.early_stopping = True
+     # bert2bert.config.length_penalty = 2.0
+     # # bert2bert.config.no_repeat_ngram_size = 3  # Default: 0
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     bert2bert.to(device)
+
+     return bert2bert
+
+
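
A quick smoke test of the encoder-decoder factory; the input SMILES is a placeholder, real inputs are full PROTAC SMILES:

```python
model = get_encoder_decoder_model()
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

inputs = tokenizer("CCO", return_tensors="pt").to(model.device)  # placeholder SMILES
output_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```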
+ def get_causal_model(
+     pretrained_model: str = "seyonec/ChemBERTa-zinc-base-v1",
+     max_length: Optional[int] = 512,
+ ) -> AutoModelForCausalLM:
+     """ Get the causal language model for the PROTAC splitter.
+
+     Args:
+         pretrained_model (str): The pretrained model to use for the causal language model. Default: "seyonec/ChemBERTa-zinc-base-v1"
+         max_length (int): The maximum length of the input sequence. Default: 512
+
+     Returns:
+         AutoModelForCausalLM: The causal language model for the PROTAC splitter
+     """
+     model = AutoModelForCausalLM.from_pretrained(pretrained_model, is_decoder=True)
+     # model.is_decoder = True  # It might not be necessary, but it's good to be explicit
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     return model
+
+
+ # REF: https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/generation/configuration_utils.py#L71
+ GENERATION_STRATEGY_PARAMS = {
+     "greedy": {"num_beams": 1, "do_sample": False},
+     "contrastive_search": {"penalty_alpha": 0.1, "top_k": 10},
+     "multinomial_sampling": {"num_beams": 1, "do_sample": True},
+     "beam_search_decoding": {"num_beams": 5, "do_sample": False, "num_return_sequences": 5},
+     "beam_search_multinomial_sampling": {"num_beams": 5, "do_sample": True, "num_return_sequences": 5},
+     "diverse_beam_search_decoding": {"num_beams": 5, "num_beam_groups": 5, "diversity_penalty": 1.0, "num_return_sequences": 5},
+ }
+
+
+ def avail_generation_strategies() -> List[str]:
+     """ Get the available generation strategies. """
+     return list(GENERATION_STRATEGY_PARAMS.keys())
+
+
+ def get_generation_config(generation_strategy: str) -> GenerationConfig:
+     """ Get the generation config for the given generation strategy. """
+     return GenerationConfig(
+         max_length=512,
+         max_new_tokens=512,
+         **GENERATION_STRATEGY_PARAMS[generation_strategy],
+     )
+
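
For reference, the strategy mapping expands into concrete `GenerationConfig` objects:

```python
cfg = get_generation_config("beam_search_decoding")
print(cfg.num_beams, cfg.do_sample, cfg.num_return_sequences)  # -> 5 False 5
```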
+ def get_pipeline(
+     model_name: str,
+     token: str,
+     is_causal_language_model: bool,
+     generation_strategy: Optional[str] = None,
+     num_return_sequences: int = 1,
+     device: Optional[Union[int, str]] = None,
+ ) -> pipeline:
+     """ Get the pipeline for the given model name and generation strategy.
+
+     Args:
+         model_name (str): The model name or path on the Hugging Face Hub.
+         token (str): The Hugging Face API token.
+         is_causal_language_model (bool): Whether the model is a causal (decoder-only) language model.
+         generation_strategy (Optional[str]): One of `avail_generation_strategies()`, or None to use the model defaults.
+         num_return_sequences (int): The number of sequences to return. Only used for causal models without an explicit generation strategy.
+         device (Optional[int | str]): The device to run the pipeline on. Defaults to CUDA if available.
+
+     Returns:
+         pipeline: The configured text-generation or text2text-generation pipeline.
+     """
+     device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
+     if is_causal_language_model:
+         print('Loading pipeline for causal language models...')
+         task = "text-generation"
+         tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, padding_side='left')
+     else:
+         print('Loading pipeline for sequence-to-sequence models...')
+         task = "text2text-generation"
+         tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
+
+     kwargs = {}
+     if generation_strategy is not None:
+         kwargs["generation_config"] = get_generation_config(generation_strategy)
+     elif is_causal_language_model:
+         kwargs["num_return_sequences"] = num_return_sequences
+
+     return pipeline(
+         task,
+         model=model_name,
+         tokenizer=tokenizer,
+         token=token,
+         device=device,
+         **kwargs,
+     )
+
+
+ def run_causal_pipeline(
+     pipe: pipeline,
+     test_ds: Dataset,
+     batch_size: int,
+     smiles_column: str = 'prompt',
+ ) -> List[Dict[str, str]]:
+     """ Run the pipeline for causal language models and return the predictions.
+
+     Args:
+         pipe (pipeline): The pipeline object to use for generating predictions.
+         test_ds (Dataset): The test dataset to generate predictions for.
+         batch_size (int): The batch size to use for generating predictions.
+
+     Returns:
+         List[Dict[str, str]]: A list of dictionaries containing the predictions.
+     """
+     preds = []
+     for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds) // batch_size):
+         generated_text = [p['generated_text'] for p in pred]
+         # Remove the prompt from the generated text
+         generated_text = ['.'.join(t.split('.')[1:]) for t in generated_text]
+         # Add the predictions to the list
+         p = {f'pred_n{i}': t for i, t in enumerate(generated_text)}
+         preds.append(p)
+     return preds
+
+
+ def run_seq2seq_pipeline(
+     pipe: pipeline,
+     test_ds: Dataset,
+     batch_size: int,
+     smiles_column: str = 'text',
+ ) -> List[Dict[str, str]]:
+     """ Run the pipeline for sequence-to-sequence models and return the predictions.
+
+     Args:
+         pipe (pipeline): The pipeline object to use for generating predictions.
+         test_ds (Dataset): The test dataset to generate predictions for.
+         batch_size (int): The batch size to use for generating predictions.
+
+     Returns:
+         List[Dict[str, str]]: A list of dictionaries containing the predictions.
+     """
+     preds = []
+     for pred in tqdm(pipe(KeyDataset(test_ds, smiles_column), batch_size=batch_size, max_length=512), total=len(test_ds) // batch_size):
+         p = {f'pred_n{i}': cand['generated_text'] for i, cand in enumerate(pred)}
+         preds.append(p)
+     return preds
+
+
+ def run_pipeline(
+     pipe: pipeline,
+     test_ds: Dataset,
+     batch_size: int,
+     is_causal_language_model: bool,
+     smiles_column: str = 'text',
+ ) -> List[Dict[str, str]]:
+     """ Run the pipeline and return the predictions.
+
+     Args:
+         pipe (pipeline): The pipeline object to use for generating predictions.
+         test_ds (Dataset): The test dataset to generate predictions for.
+         batch_size (int): The batch size to use for generating predictions.
+         is_causal_language_model (bool): Whether the model is a causal language model or not.
+         smiles_column (str): The column name in the dataset that contains the SMILES strings. Default: 'text'
+
+     Returns:
+         List[Dict[str, str]]: A list of dictionaries containing the beam-size predictions in the format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
+     """
+     if is_causal_language_model:
+         return run_causal_pipeline(pipe, test_ds, batch_size, smiles_column)
+     else:
+         return run_seq2seq_pipeline(pipe, test_ds, batch_size, smiles_column)
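
End-to-end, prediction over a test split looks roughly like this (model id, token, and the toy `test_ds` are placeholders):

```python
from datasets import Dataset

test_ds = Dataset.from_dict({"text": ["CCO"]})  # placeholder; real rows are PROTAC SMILES
pipe = get_pipeline(
    model_name="ailab-bio/some-protac-splitter-model",  # placeholder repo id
    token="hf_...",                                      # placeholder token
    is_causal_language_model=False,
    generation_strategy="beam_search_decoding",
)
preds = run_pipeline(pipe, test_ds, batch_size=8, is_causal_language_model=False)
# -> [{'pred_n0': ..., 'pred_n1': ..., ...}], one dict of beam candidates per sample
```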
protac_splitter/llms/training.py ADDED
@@ -0,0 +1,869 @@
+ import os
+ from typing import Optional, Dict, Any, Callable, Tuple, Union
+ from functools import partial
+ import subprocess
+ import copy
+ import datetime
+ import logging
+ import math
+ import json
+
+ import torch
+ import numpy as np
+ import huggingface_hub as hf
+ from transformers import (
+     Trainer,
+     TrainingArguments,
+     Seq2SeqTrainer,
+     Seq2SeqTrainingArguments,
+     DataCollatorForSeq2Seq,
+     DataCollatorForLanguageModeling,
+     AutoTokenizer,
+     GenerationConfig,
+     TrainerCallback,
+     set_seed,
+ )
+ from accelerate.utils import write_basic_config
+ from accelerate import Accelerator
+
+ import optuna
+ from optuna.samplers import QMCSampler
+ from optuna.pruners import (
+     BasePruner,
+     HyperbandPruner,
+     ThresholdPruner,
+     PatientPruner,
+     MedianPruner,
+ )
+ from optuna.study._study_direction import StudyDirection
+
+ from .data_utils import load_tokenized_dataset
+ from .evaluation import decode_and_get_metrics
+ from .hf_utils import (
+     create_hf_repository,
+     delete_hf_repository,
+     repo_exists,
+     upload_single_file,
+ )
+ from .model_utils import get_encoder_decoder_model, get_causal_model
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use GPU with index 0
+ # logging.basicConfig(level=logging.DEBUG)
+
+
+ class PrintStepCallback(TrainerCallback):
+
+     def on_init_end(self, args, state, control, **kwargs):
+         print(f"[{datetime.datetime.now()}] Initialization complete. Training is starting.")
+
+     def on_step_begin(self, args, state, control, **kwargs):
+         if state.global_step % args.logging_steps == 0:
+             print(f"[{datetime.datetime.now()}] Global step: {state.global_step:,}")
+
+
+ class ScoreMetric:
+
+     def __init__(self):
+         self.batch_scores = []
+
+     def update(self, scores):
+         self.batch_scores.append(scores)
+
+     def compute(self):
+         all_labels = set()
+         for scores in self.batch_scores:
+             all_labels.update(scores.keys())
+
+         aggregate_scores = {}
+         for k in all_labels:
+             scores = [s.get(k, np.nan) for s in self.batch_scores]
+             print(f"{k}: {np.nanmean(scores):.4f}")
+             aggregate_scores[k] = np.nanmean(scores)
+
+         self.batch_scores = []
+         return aggregate_scores
+
+
+ score_metric = ScoreMetric()
+ hp_score_metric = ScoreMetric()
+
+
+ class WrappedEarlyStoppingPruner(BasePruner):
+     """
+     Pruner that wraps another pruner and checks if the trial should be pruned.
+     It first evaluates the wrapped pruner and, if the wrapped pruner suggests
+     pruning, prunes. Otherwise, it evaluates an early-stopping criterion based
+     on a patience threshold with a tolerance (min_delta) and eventually prunes.
+
+     Args:
+         wrapped_pruner:
+             Wrapped pruner to check first. Pruning is only applied if this pruner recommends it.
+         patience:
+             Number of steps to wait for an improvement before pruning.
+         min_delta:
+             Minimum improvement required to reset patience.
+         n_warmup_steps:
+             Number of initial steps to skip the patience check.
+     """
+
+     def __init__(
+         self,
+         wrapped_pruner: BasePruner,
+         patience: int,
+         min_delta: float = 0.0,
+         n_warmup_steps: int = 0,
+     ) -> None:
+         if wrapped_pruner is None or not isinstance(wrapped_pruner, BasePruner):
+             raise ValueError(f"wrapped_pruner must be an instance of BasePruner but got {wrapped_pruner}.")
+         if patience < 0:
+             raise ValueError(f"patience cannot be negative but got {patience}.")
+         if min_delta < 0:
+             raise ValueError(f"min_delta cannot be negative but got {min_delta}.")
+         if n_warmup_steps < 0:
+             raise ValueError(f"n_warmup_steps cannot be negative but got {n_warmup_steps}.")
+
+         self._wrapped_pruner = wrapped_pruner
+         self._patience = patience
+         self._min_delta = min_delta
+         self._n_warmup_steps = n_warmup_steps
+
+     def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial") -> bool:
+         step = trial.last_step
+         if step is None:
+             return False
+
+         intermediate_values = trial.intermediate_values
+         steps = np.asarray(list(intermediate_values.keys()))
+
+         # If there are insufficient steps or we are still in the warmup phase, do not prune.
+         if steps.size <= self._patience + 1 or step < self._n_warmup_steps:
+             return False
+
+         # First, check the wrapped pruner. If it suggests pruning, prune.
+         if self._wrapped_pruner.prune(study, trial):
+             return True
+
+         steps.sort()
+
+         # These are the scores reported up to `patience` steps ago...
+         steps_before_patience = steps[: -self._patience - 1]
+         scores_before_patience = np.asarray(
+             list(intermediate_values[step] for step in steps_before_patience)
+         )
+
+         # ...and these are the scores reported since then
+         steps_after_patience = steps[-self._patience - 1 :]
+         scores_after_patience = np.asarray(
+             list(intermediate_values[step] for step in steps_after_patience)
+         )
+
+         direction = study.direction
+         if direction == StudyDirection.MINIMIZE:
+             should_prune = np.nanmin(scores_before_patience) + self._min_delta < np.nanmin(
+                 scores_after_patience
+             )
+         else:
+             should_prune = np.nanmax(scores_before_patience) - self._min_delta > np.nanmax(
+                 scores_after_patience
+             )
+
+         return should_prune
+
+
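
Outside of the `Trainer` integration below, the pruner composes with Optuna directly. A minimal toy study (the objective is illustrative only, not from this repo):

```python
import optuna
from optuna.pruners import MedianPruner

pruner = WrappedEarlyStoppingPruner(
    MedianPruner(n_startup_trials=0, n_warmup_steps=10),
    patience=5,
    min_delta=0.01,
    n_warmup_steps=10,
)

def objective(trial):
    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    score = 0.0
    for step in range(50):
        score += lr  # toy "metric" that improves each step
        trial.report(score, step)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return score

study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=5)
```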
+ def get_lr_scheduler_kwargs(lr_scheduler_type: str) -> Dict[str, Any]:
+     """ Returns the default learning rate scheduler kwargs for a given type.
+
+     Reference: https://huggingface.co/docs/timm/en/reference/schedulers
+
+     Args:
+         lr_scheduler_type (str): The type of the learning rate scheduler.
+
+     Returns:
+         Dict[str, Any]: The default learning rate scheduler kwargs.
+     """
+     if lr_scheduler_type == "cosine":
+         return {}
+     elif lr_scheduler_type == "cosine_with_restarts":
+         return {"num_cycles": 3}
+     elif lr_scheduler_type == "cosine_with_min_lr":
+         return {}
+     elif lr_scheduler_type == "polynomial":
+         return {"power": 1.0}
+     elif lr_scheduler_type == "reduce_lr_on_plateau":
+         return {"min_lr": 1e-6}
+     else:
+         raise ValueError(f"Unknown learning rate scheduler type: '{lr_scheduler_type}'")
+
+
+ def get_best_hyperparameters(
+     model_init: Callable,
+     tokenizer: AutoTokenizer,
+     data_collator: Union[DataCollatorForSeq2Seq, DataCollatorForLanguageModeling],
+     compute_metrics: Callable,
+     dataset_tokenized: Dict[str, Any],
+     training_args: Dict[str, Any],
+     num_optuna_trials: int,
+     lr_scheduler_type: Optional[str] = None,
+     causal_language_modeling: bool = False,
+     all_fragments_as_labels: bool = True,
+     linkers_only_as_labels: bool = False,
+ ) -> Tuple[Any, Dict[str, Any]]:
+     """Runs an Optuna hyperparameter search to find the best hyperparameters.
+
+     Args:
+         model_init (Callable): The model initialization function.
+         tokenizer (AutoTokenizer): The tokenizer.
+         data_collator (DataCollatorForSeq2Seq | DataCollatorForLanguageModeling): The data collator.
+         compute_metrics (Callable): The compute metrics function.
+         dataset_tokenized (Dict[str, Any]): The tokenized dataset.
+         training_args (Dict[str, Any]): The training arguments.
+         num_optuna_trials (int): The number of Optuna trials.
+
+     Returns:
+         Tuple[Any, Dict[str, Any]]: The best run (a `BestRun` object) and the training arguments used for the search.
+     """
+     def optuna_hp_space(trial):
+         # NOTE: Tuning the generation config is not implemented yet, please refer to this issue: https://github.com/huggingface/transformers/issues/33755
+         # Suggest hparams "shared" across all scheduler types
+         # learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
+         # warmup_ratio = trial.suggest_float("warmup_ratio", 0.01, 0.1, step=0.01)
+
+         # Restrict learning rate closer to best-performing values
+         learning_rate = trial.suggest_float("learning_rate", 5e-6, 2e-4, log=True)  # Previously 1e-6 to 1e-3
+
+         # Slightly adjust warmup ratio to avoid extreme values
+         warmup_ratio = trial.suggest_float("warmup_ratio", 0.02, 0.06, step=0.01)  # Previously 0.01 to 0.1
+
+         # NOTE: We might want to use QMCSampler instead of TPESampler, which
+         # doesn't support categorical parameters. Categories can be encoded as
+         # integers and then decoded back to the original categories.
+
+         # NOTE: According to the GitHub code, the number of training and warmup
+         # steps for the scheduler types are automatically set, so we don't need
+         # to pass them in the lr_scheduler_kwargs.
+
+         if lr_scheduler_type is None:
+             lr_scheduler_types = ["cosine", "cosine_with_restarts", "reduce_lr_on_plateau"]  # "cosine_with_min_lr", "polynomial"
+             suggested_lr_sched = trial.suggest_int("lr_scheduler_type", 0, len(lr_scheduler_types) - 1)
+             suggested_lr_sched = lr_scheduler_types[suggested_lr_sched]
+             # NOTE: Use the *suggested* type here; passing the (None) outer
+             # lr_scheduler_type would raise in get_lr_scheduler_kwargs.
+             lr_scheduler_kwargs = get_lr_scheduler_kwargs(suggested_lr_sched)
+         elif lr_scheduler_type == "cosine":
+             lr_scheduler_kwargs = {
+                 "num_cycles": trial.suggest_float("num_cycles", 0.5, 10, step=0.5),
+             }
+         elif lr_scheduler_type == "cosine_with_restarts":
+             lr_scheduler_kwargs = {
+                 "num_cycles": trial.suggest_int("num_cycles", 1, 10, step=1),
+             }
+         elif lr_scheduler_type == "reduce_lr_on_plateau":
+             lr_scheduler_kwargs = {
+                 "min_lr": trial.suggest_float("min_lr", 1e-10, 1e-8, log=True),  # Previously 1e-12 to 1e-9
+                 "factor": trial.suggest_float("factor", 0.8, 0.98, step=0.01),  # Previously 0.1 to 0.99
+             }
+         else:
+             lr_scheduler_kwargs = get_lr_scheduler_kwargs(lr_scheduler_type)
+
+         return {
+             "lr_scheduler_kwargs": lr_scheduler_kwargs,
+             "lr_scheduler_type": lr_scheduler_type if lr_scheduler_type is not None else suggested_lr_sched,
+             "learning_rate": learning_rate,
+             "warmup_ratio": warmup_ratio,
+         }
+
+     if causal_language_modeling:
+         def compute_objective(metrics: Dict[str, float]):
+             # NOTE: We want to minimize the model perplexity, which is the
+             # exponential of the negative log-likelihood loss. Optuna is set up
+             # to maximize the objective, so we return the negative perplexity.
+             return -math.exp(metrics["eval_loss"])
+     else:
+         if all_fragments_as_labels:
+             def compute_objective(metrics: Dict[str, float]):
+                 # NOTE: Having a higher eval_reassembly score should also correspond
+                 # to a low eval loss, so we just focus on the reassembly score.
+                 return metrics["eval_all_ligands_equal"]
+         else:
+             if linkers_only_as_labels:
+                 def compute_objective(metrics: Dict[str, float]):
+                     return metrics["eval_linker_equal"]
+             else:
+                 def compute_objective(metrics: Dict[str, float]):
+                     return metrics["eval_e3_equal"] + metrics["eval_poi_equal"]
+
+     def hp_name(trial: Any) -> str:
+         trial_name = f"trial-number={trial.number}"
+         for hparam, value in trial.params.items():
+             # Check if the value is a float and round it to 3 decimals
+             if hparam == "learning_rate":
+                 value = f"{value:.1e}"
+             elif isinstance(value, float):
+                 value = f"{value:.3f}"
+             trial_name += f"-{hparam}={value}"
+         return trial_name
+
+     # Override the training steps
+     hp_training_args = copy.deepcopy(training_args)
+     hp_training_args["num_train_epochs"] = -1
+     hp_training_args["max_steps"] = 10_000
+     hp_training_args["eval_steps"] = 2500
+     hp_training_args["eval_delay"] = 5000  # TODO: Double check if this is needed
+     hp_training_args["logging_steps"] = 500
+     hp_training_args["save_steps"] = 5000
+     if not causal_language_modeling:
+         # Use greedy decoding for the evaluation during the HP search
+         hp_training_args["generation_config"] = GenerationConfig(
+             max_length=512,
+             max_new_tokens=512,
+             do_sample=False,
+             num_beams=1,
+         )
+
+     print("Hyperparameter search training arguments:")
+     for k, v in hp_training_args.items():
+         if 'token' in k:
+             continue
+         print(f"  - {k}: {v}")
+
+     if causal_language_modeling:
+         TrainerClass = Trainer
+         TrainingArgumentsClass = TrainingArguments
+     else:
+         TrainerClass = Seq2SeqTrainer
+         TrainingArgumentsClass = Seq2SeqTrainingArguments
+
+     # Setup a "fake" Trainer for the hyperparameter search
+     trainer = TrainerClass(
+         model_init=model_init,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+         args=TrainingArgumentsClass(**hp_training_args),
+         compute_metrics=compute_metrics,
+         train_dataset=dataset_tokenized["train"],
+         eval_dataset=dataset_tokenized["validation"],
+         callbacks=[PrintStepCallback],
+     )
+
+     # Setup the Optuna pruner and sampler
+     max_warmup_ratio = 0.1
+     pruner = WrappedEarlyStoppingPruner(
+         MedianPruner(
+             n_startup_trials=0,
+             interval_steps=1,
+             n_warmup_steps=int(max_warmup_ratio * hp_training_args["max_steps"]),
+         ),
+         patience=5,  # Check every 5000 training steps
+         min_delta=0.01,
+         n_warmup_steps=int(max_warmup_ratio * hp_training_args["max_steps"]),
+     )
+     sampler = QMCSampler(scramble=True, seed=42)
+
+     # NOTE: The Trainer will return a BestRun object, not the Optuna trial
+     best_run = trainer.hyperparameter_search(
+         direction="maximize",
+         backend="optuna",
+         hp_space=optuna_hp_space,
+         hp_name=hp_name,
+         n_trials=num_optuna_trials,
+         compute_objective=compute_objective,  # Default: Will sum over all metrics but loss
+         sampler=sampler,
+         pruner=pruner,
+     )
+
+     # Report the best hyperparameters found
+     try:
+         print("-" * 80)
+         print(f"Best trial objective: {best_run.objective:.4f}. Summary: {best_run.run_summary}")
+     except Exception as e:
+         print(e)
+         print("WARNING. Best trial objective could not be printed.")
+
+     return best_run, hp_training_args
+
+
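
Note the sign convention in `compute_objective` for causal models: the study maximizes, so the objective is the negative perplexity. Concretely:

```python
import math

metrics = {"eval_loss": 0.69}                # mean cross-entropy in nats (illustrative value)
objective = -math.exp(metrics["eval_loss"])  # perplexity ≈ 1.99, objective ≈ -1.99
```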
381
+ def train_model(
382
+ model_id: str,
383
+ ds_name: str,
384
+ ds_config: str = 'default',
385
+ learning_rate: float = 5e-5,
386
+ max_steps: int = -1,
387
+ num_train_epochs: int = 40,
388
+ batch_size: int = 128,
389
+ batch_size_tokenizer: int = 512,
390
+ gradient_accumulation_steps: int = 4,
391
+ hub_token: Optional[str] = None,
392
+ organization: Optional[str] = None,
393
+ output_dir: str = "./models/",
394
+ tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
395
+ pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
396
+ pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
397
+ encoder_max_length: int = 512,
398
+ decoder_max_length: int = 512,
399
+ tie_encoder_decoder: bool = False,
400
+ delete_repo_if_exists: bool = False,
401
+ delete_local_repo_if_exists: bool = False,
402
+ training_args: Optional[Dict[str, Any]] = None,
403
+ resume_from_checkpoint: Optional[str] = None,
404
+ num_optuna_trials: int = 0,
405
+ num_proc_map: int = 1,
406
+ per_device_train_batch_size: Optional[int] = None,
407
+ per_device_eval_batch_size: Optional[int] = None,
408
+ lr_scheduler_type: Optional[str] = None,
409
+ cache_dir: Optional[str] = None,
410
+ randomize_smiles: bool = False,
411
+ randomize_smiles_prob: float = 0.0,
412
+ all_fragments_as_labels: bool = True,
413
+ linkers_only_as_labels: bool = False,
414
+ warmup_ratio: Optional[float] = None,
415
+ num_cycles: Optional[int] = None,
416
+ warmup_steps: Optional[int] = None,
417
+ causal_language_modeling: bool = False,
418
+ train_size_ratio: float = 1.0,
419
+ training_args_bin: Optional[str] = None,
420
+ ):
421
+ """Trains a model on a given dataset.
422
+
423
+ Args:
424
+ model_id (str): The name of the model to be trained.
425
+ ds_name (str): The name of the dataset to be used for training.
426
+ ds_config (str, optional): The name of the dataset configuration to be used for training. Defaults to 'default'.
427
+ learning_rate (float, optional): The learning rate. Defaults to 5e-5.
428
+ max_steps (int, optional): The maximum number of training steps. Defaults to -1.
429
+ num_train_epochs (int, optional): The number of training epochs. Defaults to 40.
430
+ batch_size (int, optional): The batch size. Defaults to 128.
431
+ batch_size_tokenizer (int, optional): The batch size for the tokenizer. Defaults to 512.
432
+ gradient_accumulation_steps (int, optional): The number of gradient accumulation steps. Defaults to 4.
433
+ hub_token (Optional[str], optional): The Hugging Face token. Defaults to None.
434
+ organization (Optional[str], optional): The Hugging Face organization. Defaults to None.
435
+ output_dir (str, optional): The output directory. Defaults to "./models/".
436
+ tokenizer (AutoTokenizer | str, optional): The tokenizer. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
437
+ pretrained_encoder (str, optional): The name of the pretrained encoder. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
438
+ pretrained_decoder (str, optional): The name of the pretrained decoder. Defaults to "seyonec/ChemBERTa-zinc-base-v1".
439
+ encoder_max_length (int, optional): The maximum length of the encoder. Defaults to 256.
440
+ decoder_max_length (int, optional): The maximum length of the decoder. Defaults to 256.
441
+ delete_repo_if_exists (bool, optional): Whether to delete the repository first. Defaults to False.
442
+ training_args (Optional[Seq2SeqTrainingArguments], optional): The training arguments. Defaults to None.
443
+ resume_from_checkpoint (Optional[str], optional): The checkpoint to resume training from. Defaults to None.
444
+ num_optuna_trials (int, optional): The number of Optuna trials. Defaults to 0, i.e., no Optuna hyperparameter search.
445
+ """
446
+ set_seed(42)
447
+
448
+ # if torch.cuda.is_available():
449
+ # write_basic_config(mixed_precision='fp16')
450
+ accelerator = Accelerator()
451
+ accelerator.print(f"Accelerator state from the current environment:\n{accelerator.state}")
452
+
453
+ # Check if resume_from_checkpoint exists and it's a file
454
+ if resume_from_checkpoint is not None:
455
+ # Check if the checkpoint exists: it can be either a file or a directory
456
+ if not os.path.exists(resume_from_checkpoint):
457
+ raise ValueError(f"Checkpoint file '{resume_from_checkpoint}' does not exist.")
458
+
459
+ if hub_token is not None:
460
+ hf.login(token=hub_token)
461
+
462
+ # Setup output directory and Hugging Face repository
463
+ output_dir += f"/{model_id}"
464
+ if organization is not None:
465
+ hub_model_id = f"{organization}/{model_id}"
466
+ if delete_local_repo_if_exists and os.path.exists(output_dir):
467
+ subprocess.run(["rm", "-rf", output_dir])
468
+ if not os.path.exists(output_dir):
469
+ print(f"Local repository '{output_dir}' deleted.")
470
+ else:
471
+ print(f"Local repository '{output_dir}' could not be deleted.")
472
+ return
473
+ if delete_repo_if_exists and repo_exists(hub_model_id, token=hub_token):
474
+ delete_hf_repository(repo_id=hub_model_id, token=hub_token, missing_ok=True)
475
+ print(f"Repository '{hub_model_id}' deleted.")
476
+
477
+ repo_url = create_hf_repository(
478
+ repo_id=hub_model_id,
479
+ repo_type="model",
480
+ exist_ok=True,
481
+ private=True,
482
+ token=hub_token,
483
+ )
484
+ print(f"Repository '{hub_model_id}' created at URL: {repo_url}")
485
+ else:
486
+ hub_model_id = None
487
+ print(f"Hub model ID: {hub_model_id}")
488
+
489
+ if isinstance(tokenizer, str):
490
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
491
+ elif tokenizer is None:
492
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
493
+
494
+ # Load the tokenized dataset
495
+ print("Loading tokenized dataset.")
496
+ dataset_tokenized = load_tokenized_dataset(
497
+ ds_name,
498
+ ds_config,
499
+ tokenizer,
500
+ batch_size_tokenizer,
501
+ encoder_max_length,
502
+ decoder_max_length,
503
+ token=hub_token,
504
+ num_proc_map=num_proc_map,
505
+ cache_dir=cache_dir,
506
+ randomize_smiles=randomize_smiles,
507
+ randomize_smiles_prob=randomize_smiles_prob,
508
+ all_fragments_as_labels=all_fragments_as_labels,
509
+ linkers_only_as_labels=linkers_only_as_labels,
510
+ causal_language_modeling=causal_language_modeling,
511
+ train_size_ratio=train_size_ratio,
512
+ )
513
+ print("Dataset loaded.")
514
+
515
+ if causal_language_modeling:
516
+ # Setup the model for `model_init` in the Trainer
517
+ model_lambda = lambda: get_causal_model(
518
+ pretrained_model=pretrained_decoder,
519
+ )
520
+
521
+ # Setup the data collator, which will efficiently pad the inputs and targets
522
+ data_collator = DataCollatorForLanguageModeling(
523
+ tokenizer,
524
+ mlm=False,
525
+ pad_to_multiple_of=8, # Default: None, Original: 8
526
+ )
527
+ else:
528
+ # Precompute a "length" column for the dataset using the map function
529
+ def add_length(x):
530
+ x["length"] = len(x["input_ids"])
531
+ return x
532
+ dataset_tokenized = dataset_tokenized.map(
533
+ add_length,
534
+ num_proc=num_proc_map,
535
+ )
536
+
537
+ # Setup the model for `model_init` in the Trainer
538
+ model_lambda = lambda: get_encoder_decoder_model(
539
+ pretrained_encoder=pretrained_encoder,
540
+ pretrained_decoder=pretrained_decoder,
541
+ max_length=encoder_max_length,
542
+ tie_encoder_decoder=tie_encoder_decoder,
543
+ )
544
+
545
+ # Setup the data collator, which will efficiently pad the inputs and targets
546
+ data_collator = DataCollatorForSeq2Seq(
547
+ tokenizer,
548
+ model=model_lambda(),
549
+ pad_to_multiple_of=32, # Default: None, Original: 8
550
+ )
551
+
552
+ # Setup the training arguments
553
+ if per_device_train_batch_size is None:
554
+ per_device_train_batch_size = batch_size // gradient_accumulation_steps
555
+ if per_device_eval_batch_size is None:
556
+ per_device_eval_batch_size = batch_size // gradient_accumulation_steps
557
+ if training_args is None:
558
+ training_args = {
559
+ "output_dir": output_dir,
560
+ # Optimizer-related configs
561
+ "learning_rate": learning_rate,
562
+ "optim": "adamw_torch",
563
+ "lr_scheduler_type": "cosine" if lr_scheduler_type is None else lr_scheduler_type,
564
+ "lr_scheduler_kwargs": get_lr_scheduler_kwargs(lr_scheduler_type),
565
+ # "warmup_steps": int(0.08 * 10_000), # NOTE: ChemFormer: 8000
566
+ # "warmup_ratio": warmup_ratio,
567
+ "adam_beta1": 0.9, # NOTE: ChemFormer: 0.9
568
+ "adam_beta2": 0.999, # NOTE: ChemFormer: 0.999
569
+ "adam_epsilon": 1e-8, # Default: 1e-8
570
+ # Batch size, device, and performance optimizations configs
571
+ "batch_eval_metrics": False, # Default: False
572
+ "group_by_length": True,
573
+ "per_device_train_batch_size": per_device_train_batch_size,
574
+ "per_device_eval_batch_size": per_device_eval_batch_size,
575
+ "gradient_accumulation_steps": gradient_accumulation_steps,
576
+ "auto_find_batch_size": True,
577
+ "fp16": True if torch.cuda.is_available() else False,
578
+ "fp16_full_eval" : True, # Enable full BF16 evaluation for efficiency
579
+ "half_precision_backend" : "auto", # Let Hugging Face decide the best backend. Default: "auto"
580
+ "use_cpu": False, # Default: False
581
+ "dataloader_num_workers": 8, # Default: 0 (main process only)
582
+ "dataloader_prefetch_factor": None, # Default: None
583
+ # Evaluation and checkpointing configs
584
+ "max_steps": max_steps,
585
+ "num_train_epochs": num_train_epochs,
586
+ "save_steps": 20_000, # NOTE: 200
587
+ "save_strategy": "steps",
588
+ "eval_steps": 20_000, # NOTE: 500
589
+ "eval_delay": max(int(max(max_steps, num_train_epochs) * 0.7), 0), # Default: None
590
+ "eval_strategy": "steps", # NOTE: "evaluation_strategy" is deprecated.
591
+ "save_total_limit": 2, # This will save both the best and the last trainer checkpoint
592
+ "load_best_model_at_end": True,
593
+ "metric_for_best_model": "all_ligands_equal",
594
+ "include_inputs_for_metrics": True,
595
+ "eval_on_start": False, # Default: False
596
+ # Logging configs
597
+ "log_level": "debug",
598
+ "logging_steps": 5000,
599
+ "disable_tqdm": True,
600
+ "report_to": ["tensorboard"],
601
+ "save_only_model": False, # Default: False
602
+ # Hub information configs
603
+ "push_to_hub": hub_model_id is not None, # NOTE: Also manually done further down
604
+ "push_to_hub_model_id": model_id,
605
+ "push_to_hub_organization": organization,
606
+ "hub_model_id": hub_model_id,
607
+ "hub_token": hub_token,
608
+ "hub_strategy": "checkpoint", # NOTE: Allows to resume training from last checkpoint
609
+ "hub_private_repo": True,
610
+ # Other configs
611
+ "seed": 42,
612
+ "data_seed": 42,
613
+ }
614
+ if 'num_cycles' in training_args["lr_scheduler_kwargs"] and num_cycles is not None:
615
+ training_args["lr_scheduler_kwargs"]["num_cycles"] = num_cycles
616
+ if warmup_ratio is not None:
617
+ training_args["warmup_ratio"] = warmup_ratio
618
+ if warmup_steps is not None:
619
+ training_args["warmup_steps"] = warmup_steps
620
+
621
+ # Add Generation configs
622
+ if causal_language_modeling:
623
+ training_args["metric_for_best_model"] = "eval_loss"
624
+ else:
625
+ generation_config = GenerationConfig(
626
+ max_length=512,
627
+ max_new_tokens=512,
628
+ do_sample=True,
629
+ num_beams=5,
630
+ temperature=1.0,
631
+ )
632
+ training_args["generation_config"] = generation_config
633
+ training_args["predict_with_generate"] = True
635
+ training_args["generation_max_length"] = 512
636
+
637
+ print("Training arguments:")
638
+ for k, v in training_args.items():
639
+ if 'token' in k:
640
+ continue
641
+ print(f" - {k}: {v}")
642
+
643
+ # Modify the training arguments with Optuna hyperparameter search
644
+ if num_optuna_trials > 0:
645
+ # Setup the compute_metrics function for the hyperparameter search
646
+ hp_compute_metrics = partial(
647
+ decode_and_get_metrics,
648
+ tokenizer=tokenizer,
649
+ compute_rdkit_metrics=False,
650
+ compute_graph_metrics=False,
651
+ num_proc=num_proc_map,
652
+ causal_language_modeling=causal_language_modeling,
653
+ )
654
+
655
+ # Run the HP search (and update the training_args accordingly)
656
+ best_run, hp_training_args = get_best_hyperparameters(
657
+ model_init=model_lambda,
658
+ tokenizer=tokenizer,
659
+ data_collator=data_collator,
660
+ compute_metrics=hp_compute_metrics,
661
+ dataset_tokenized=dataset_tokenized,
662
+ training_args=copy.deepcopy(training_args),
663
+ lr_scheduler_type=lr_scheduler_type,
664
+ num_optuna_trials=num_optuna_trials,
665
+ causal_language_modeling=causal_language_modeling,
666
+ all_fragments_as_labels=all_fragments_as_labels,
667
+ linkers_only_as_labels=linkers_only_as_labels,
668
+ )
669
+ best_objective = best_run.objective
670
+ best_trial_number = best_run.run_id
671
+ best_hparams = best_run.hyperparameters
672
+
673
+ # Save the best hyperparameters to the output directory
674
+ with open(f"{output_dir}/best_hyperparameters.md", "w") as f:
675
+ f.write(f"Number of Optuna trials: {num_optuna_trials}\n\n")
676
+ f.write(f"Best trial objective: {best_objective:.4f} (best trial number: {best_trial_number})\n\n")
677
+
678
+ f.write("Best hyperparameters:\n")
679
+ for hparam, value in best_hparams.items():
680
+ f.write(f"- {hparam}: {value}\n")
681
+ f.write("\n")
682
+
683
+ f.write("Training arguments:\n")
684
+ for hparam, value in hp_training_args.items():
685
+ if "token" in hparam:
686
+ continue
687
+ elif isinstance(value, str):
688
+ if 'hf_' in value:
689
+ continue
690
+ f.write(f"- {hparam}: {value}\n")
691
+
692
+ # Open the file and remove any line that might contain the token
693
+ with open(f"{output_dir}/best_hyperparameters.md", "r") as f:
694
+ lines = f.readlines()
695
+ with open(f"{output_dir}/best_hyperparameters.md", "w") as f:
696
+ for line in lines:
697
+ if "hf_" in line:
698
+ continue
699
+ f.write(line)
700
+ print(f"Best hyperparameters saved to '{output_dir}/best_hyperparameters.md'.")
701
+
702
+ if hub_model_id is not None:
703
+ upload_single_file(
704
+ path_or_fileobj=f"{output_dir}/best_hyperparameters.md",
705
+ path_in_repo="best_hyperparameters.md",
706
+ repo_id=hub_model_id,
707
+ token=hub_token,
708
+ )
709
+
710
+ # Save the best_hparams to a JSON file
711
+ with open(f"{output_dir}/best_hyperparameters.json", "w") as f:
712
+ json.dump(best_hparams, f, indent=4)
713
+ print(f"Best hyperparameters saved to '{output_dir}/best_hyperparameters.json'.")
714
+
715
+ if hub_model_id is not None:
716
+ upload_single_file(
717
+ path_or_fileobj=f"{output_dir}/best_hyperparameters.json",
718
+ path_in_repo="best_hyperparameters.json",
719
+ repo_id=hub_model_id,
720
+ token=hub_token,
721
+ )
722
+
723
+ # Update the training arguments with the best hyperparameters
724
+ hp_specific_args = [
725
+ "num_train_epochs",
726
+ "max_steps",
727
+ "eval_steps",
728
+ "eval_delay",
729
+ "logging_steps",
730
+ "save_steps",
731
+ "generation_config",
732
+ ]
733
+ for k, v in hp_training_args.items():
734
+ # Skip the specific arguments set/modified by the HP search
735
+ if k in hp_specific_args:
736
+ continue
737
+ training_args[k] = v
738
+
739
+ # Update the num_cycles according to the original max_steps
740
+ lr_scheduler_kwargs = hp_training_args["lr_scheduler_kwargs"]
741
+
742
+ if "num_cycles" in lr_scheduler_kwargs:
743
+ hp_num_cycles = lr_scheduler_kwargs["num_cycles"]
744
+ hp_max_steps = hp_training_args["max_steps"]
745
+
746
+ # Adjust/scale the max_cycles according to the number of steps
747
+ if hp_max_steps > 0:
748
+ hp_cycle_ratio = hp_num_cycles / hp_max_steps
749
+ num_cycles = int(hp_cycle_ratio * max_steps)
750
+ training_args["lr_scheduler_kwargs"]["num_cycles"] = num_cycles
751
+ print(f"Adjusted number of cycles: {num_cycles}")
752
+
753
+ # Adjust the warmup steps according to the original max_steps
754
+ if "warmup_ratio" in hp_training_args:
755
+ hp_warmup_ratio = hp_training_args["warmup_ratio"]
756
+ hp_max_steps = hp_training_args["max_steps"]
757
+ warmup_steps = int(hp_warmup_ratio * hp_max_steps)
758
+ warmup_ratio = warmup_steps / max_steps
759
+ training_args["warmup_steps"] = warmup_steps
760
+ training_args["warmup_ratio"] = warmup_ratio
761
+
762
+ print("Training arguments updated with the best hyperparameters:")
763
+ for k, v in training_args.items():
764
+ if 'token' in k:
765
+ continue
766
+ print(f" - {k}: {v}")
767
+ print("-" * 80)
768
+ print("Starting training with the best hyperparameters.")
769
+ print("-" * 80)
770
+
771
+ # rouge = evaluate.load("rouge") # , cache_dir="/mimer/NOBACKUP/groups/naiss2023-6-290/stefano/.cache/huggingface/evaluate/")
772
+ # fpgen = Chem.rdFingerprintGenerator.GetMorganGenerator(
773
+ # radius=11,
774
+ # fpSize=1024,
775
+ # )
776
+ rouge = None
777
+ fpgen = None
778
+ compute_metrics = partial(
779
+ decode_and_get_metrics,
780
+ tokenizer=tokenizer,
781
+ rouge=rouge,
782
+ fpgen=fpgen,
783
+ compute_rdkit_metrics=False,
784
+ compute_graph_metrics=True,
785
+ num_proc=max(1, num_proc_map - 2), # NOTE: Use 2 fewer processes for the metrics, since there is timeout logic
786
+ causal_language_modeling=causal_language_modeling,
787
+ )
788
+
789
+ if training_args_bin is not None:
790
+ print(f"Loading training arguments from: {training_args_bin}.")
791
+ # Load training arguments from a binary file and update model-specific arguments
792
+ args = torch.load(training_args_bin)
793
+ args.output_dir = output_dir
794
+ args.overwrite_output_dir = delete_local_repo_if_exists
795
+ args.push_to_hub_model_id = model_id
796
+ args.push_to_hub_organization = organization
797
+ args.hub_model_id = hub_model_id
798
+ args.hub_token = hub_token
799
+ # Print all the training arguments
800
+ print("Training arguments loaded:")
801
+ for k, v in args.__dict__.items():
802
+ if 'token' in k:
803
+ continue
804
+ print(f" - {k}: {v}")
805
+ else:
806
+ if causal_language_modeling:
807
+ args = TrainingArguments(**training_args)
808
+ else:
809
+ args = Seq2SeqTrainingArguments(**training_args)
810
+
811
+ if causal_language_modeling:
812
+ TrainerClass = Trainer
813
+ else:
814
+ TrainerClass = Seq2SeqTrainer
815
+
816
+ # Setup the Trainer and start training (no Optuna hyperparameter search)
817
+ trainer = TrainerClass(
818
+ model_init=model_lambda,
819
+ tokenizer=tokenizer,
820
+ data_collator=data_collator,
821
+ args=args,
822
+ compute_metrics=compute_metrics,
823
+ train_dataset=dataset_tokenized["train"],
824
+ eval_dataset=dataset_tokenized["test"],
825
+ )
826
+ if resume_from_checkpoint is not None:
827
+ trainer.train(
828
+ resume_from_checkpoint=resume_from_checkpoint,
829
+ )
830
+ else:
831
+ trainer.train()
832
+ print("-" * 80)
833
+ print("Training completed.")
834
+ print("-" * 80)
835
+
836
+ if causal_language_modeling:
837
+ tasks = ["Text Generation"]
838
+ else:
839
+ tasks = ["Text2Text Generation", "question-answering"]
840
+
841
+ tokenizer.save_pretrained(output_dir)
842
+
843
+ if hub_model_id is not None:
844
+ print("Pushing model to Hugging Face Hub.")
845
+ print("-" * 80)
846
+ trainer.push_to_hub(
847
+ commit_message="Initial version",
848
+ model_name=hub_model_id,
849
+ license="mit",
850
+ finetuned_from=f"{pretrained_encoder}",
851
+ tasks=tasks,
852
+ tags=["PROTAC", "cheminformatics"],
853
+ dataset=[ds_name],
854
+ dataset_args=[ds_config],
855
+ )
856
+ tokenizer.push_to_hub(
857
+ repo_id=hub_model_id,
858
+ commit_message="Upload tokenizer",
859
+ private=True,
860
+ token=hub_token,
861
+ tags=["PROTAC", "cheminformatics"],
862
+ )
863
+ else:
864
+ print("Pushing model to local directory.")
865
+ print("-" * 80)
866
+ trainer.save_model(output_dir)
867
+ tokenizer.save_pretrained(output_dir)
868
+ print(f"Model saved to '{output_dir}'.")
869
+ print("All done.")
protac_splitter/llms/training_causal_model.py ADDED
@@ -0,0 +1,87 @@
1
+
2
+ import os
3
+ from transformers import TrainingArguments
6
+ from trl import SFTTrainer
7
+ from rdkit import Chem
8
+
9
+ from protac_splitter.llms.data_utils import load_tokenized_dataset
10
+ from protac_splitter.llms.model_utils import get_model
11
+
12
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use GPU if available
13
+
14
+ # Placeholder for a scoring function that evaluates the generated SMILES
15
+ def score_function(smiles1, predicted_smiles):
16
+ """ Evaluates the generated SMILES sequence based on validity. """
17
+ mol = Chem.MolFromSmiles(predicted_smiles)
18
+ return 1 if mol else 0 # Returns 1 if valid, 0 if invalid
19
+
20
+ # Custom Trainer subclass to integrate SMILES evaluation
21
+ class CustomSFTTrainer(SFTTrainer):
22
+ def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
23
+ if eval_dataset is None:
24
+ eval_dataset = self.eval_dataset
25
+
26
+ # Generate predictions
27
+ predictions = self.predict(eval_dataset)
28
+ generated_texts = self.tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True)
29
+
30
+ total_score = 0
31
+ total_samples = len(generated_texts)
32
+
33
+ for i, example in enumerate(eval_dataset):
34
+ input_text = example["text"] # Full input: "Smiles1 Smiles2.Smiles3.Smiles4"
35
+ smiles1 = input_text.split(" ")[0] # Extract Smiles1 (the prompt)
36
+
37
+ # Remove the prompt from the generated text to get the predicted completion
38
+ predicted_completion = generated_texts[i].removeprefix(smiles1).strip()
39
+
40
+ # Compute custom score
41
+ score = score_function(smiles1, predicted_completion)
42
+ total_score += score
43
+
44
+ # Compute average score
45
+ average_score = total_score / total_samples if total_samples > 0 else 0
46
+
47
+ # Log metrics
48
+ metrics = {f"{metric_key_prefix}_average_score": average_score}
49
+ self.log(metrics)
50
+
51
+ return metrics
52
+
53
+ def train():
54
+ """ Main training function """
55
+ model = get_model() # Load the model
56
+ tokenizer = model.tokenizer # Get tokenizer from model
57
+
58
+ # Load dataset
59
+ dataset = load_tokenized_dataset()
60
+
61
+ # Training arguments
62
+ training_args = {
63
+ "output_dir": "./trained_model",
64
+ "evaluation_strategy": "steps",
65
+ "save_strategy": "steps",
66
+ "logging_steps": 100,
67
+ "save_steps": 500,
68
+ "num_train_epochs": 3,
69
+ "per_device_train_batch_size": 8,
70
+ "per_device_eval_batch_size": 8,
71
+ "learning_rate": 5e-5,
72
+ "save_total_limit": 2,
73
+ }
74
+
75
+ # Initialize custom trainer
76
+ trainer = CustomSFTTrainer(
77
+ model=model,
78
+ args=TrainingArguments(**training_args), # NOTE: newer TRL versions use an SFTConfig here
79
+ train_dataset=dataset["train"],
80
+ eval_dataset=dataset["validation"],
81
+ )
82
+
83
+ # Train model
84
+ trainer.train()
85
+
86
+ if __name__ == "__main__":
87
+ train()
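To make the custom evaluation above concrete, here is a small illustration of the prompt-stripping and scoring steps, using `score_function` from the script and arbitrary example SMILES (not real model output):

```python
prompt = "CCO"                        # hypothetical "Smiles1" query
generated = "CCO CC(=O)O.N.c1ccccc1"  # hypothetical full model output

# Strip the prompt to isolate the predicted completion (str.removeprefix needs Python >= 3.9)
completion = generated.removeprefix(prompt).strip()  # -> "CC(=O)O.N.c1ccccc1"

# 1 if RDKit parses the completion as valid SMILES, 0 otherwise
assert score_function(prompt, completion) == 1
assert score_function(prompt, "not-a-smiles") == 0
```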
protac_splitter/llms/training_mlm_model.py ADDED
@@ -0,0 +1,287 @@
1
+ """ Train a masked language model (MLM) using an encoder-decoder architecture. """
2
+ import os
3
+ from typing import Optional, Dict, Any, Union
4
+ import subprocess
5
+
6
+ import torch
7
+ import huggingface_hub as hf
8
+ from transformers import (
9
+ Trainer,
10
+ TrainingArguments,
11
+ DataCollatorForLanguageModeling,
12
+ AutoTokenizer,
13
+ )
14
+
15
+ from protac_splitter.llms.data_utils import load_tokenized_dataset
16
+ from protac_splitter.llms.hf_utils import (
17
+ create_hf_repository,
18
+ delete_hf_repository,
19
+ repo_exists,
20
+ )
21
+ from protac_splitter.llms.model_utils import get_encoder_decoder_model
22
+
23
+
24
+ def compute_metrics_for_mlm(pred) -> Dict[str, float]:
25
+ """Compute metrics for MLM predictions, i.e., perplexity."""
26
+ logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
27
+ labels = pred.label_ids
28
+
29
+ # Convert to torch tensors
30
+ logits = torch.tensor(logits)
31
+ labels = torch.tensor(labels)
32
+
33
+ # Compute masked loss
34
+ loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
35
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
36
+
37
+ return {
38
+ "perplexity": torch.exp(loss).item(),
39
+ "loss": loss.item()
40
+ }
41
+
42
+
43
+ def train_mlm_model(
44
+ model_id: str,
45
+ ds_name: str,
46
+ ds_config: str = 'default',
47
+ learning_rate: float = 5e-5,
48
+ max_steps: int = -1,
49
+ num_train_epochs: int = 40,
50
+ batch_size: int = 128,
51
+ batch_size_tokenizer: int = 512,
52
+ gradient_accumulation_steps: int = 4,
53
+ hub_token: Optional[str] = None,
54
+ organization: Optional[str] = None,
55
+ output_dir: str = "./models/",
56
+ tokenizer: Union[AutoTokenizer, str] = "seyonec/ChemBERTa-zinc-base-v1",
57
+ pretrained_encoder: str = "seyonec/ChemBERTa-zinc-base-v1",
58
+ pretrained_decoder: str = "seyonec/ChemBERTa-zinc-base-v1",
59
+ encoder_max_length: int = 512,
60
+ decoder_max_length: int = 512,
61
+ tie_encoder_decoder: bool = False,
62
+ delete_repo_if_exists: bool = False,
63
+ delete_local_repo_if_exists: bool = False,
64
+ training_args: Optional[Dict[str, Any]] = None,
65
+ resume_from_checkpoint: Optional[str] = None,
66
+ num_proc_map: int = 1,
67
+ per_device_batch_size: Optional[int] = None,
68
+ lr_scheduler_type: Optional[str] = None,
69
+ mlm_probability: float = 0.15,
70
+ randomize_smiles: bool = False,
71
+ randomize_smiles_prob: float = 0.5,
72
+ randomize_smiles_repeat: int = 1,
73
+ ):
74
+ """
75
+ Trains a masked language model (MLM) using an encoder-decoder architecture.
76
+
77
+ Args:
78
+ model_id (str): The name of the model to be trained.
79
+ ds_name (str): The name of the dataset to use for training.
80
+ ds_config (str): The configuration of the dataset to use. Default: 'default'.
81
+ learning_rate (float): The learning rate for training. Default: 5e-5.
82
+ max_steps (int): The maximum number of training steps. Default: -1.
83
+ num_train_epochs (int): The number of training epochs. Default: 40.
84
+ batch_size (int): The total batch size. Default: 128.
85
+ batch_size_tokenizer (int): The batch size for the tokenizer. Default: 512.
86
+ gradient_accumulation_steps (int): The number of gradient accumulation steps. Default: 4.
87
+ hub_token (str): The Hugging Face token for authentication. Default: None.
88
+ organization (str): The organization to push the model to. Default: None.
89
+ output_dir (str): The output directory for the model. Default: "./models/".
90
+ tokenizer (AutoTokenizer | str): The tokenizer to use for training. Default: "seyonec/ChemBERTa-zinc-base-v1".
91
+ pretrained_encoder (str): The pretrained encoder model to use. Default: "seyonec/ChemBERTa-zinc-base-v1".
92
+ pretrained_decoder (str): The pretrained decoder model to use. Default: "seyonec/ChemBERTa-zinc-base-v1".
93
+ encoder_max_length (int): The maximum length of the encoder input. Default: 512.
94
+ decoder_max_length (int): The maximum length of the decoder input. Default: 512.
95
+ tie_encoder_decoder (bool): Whether to tie the encoder and decoder weights. Default: False.
96
+ delete_repo_if_exists (bool): Whether to delete the repository if it already exists. Default: False.
97
+ delete_local_repo_if_exists (bool): Whether to delete the local repository if it already exists. Default: False.
98
+ training_args (Dict[str, Any]): The training arguments for the Trainer. Default: None.
99
+ resume_from_checkpoint (str): The checkpoint to resume training from. Default: None.
101
+ num_proc_map (int): The number of processes to use for mapping. Default: 1.
102
+ per_device_batch_size (int): The batch size per device. If defined, it will overwrite batch_size. Default: None.
103
+ lr_scheduler_type (str): The learning rate scheduler type. Default: None.
104
+ mlm_probability (float): The probability of masking tokens in the input. Default: 0.15.
105
+ randomize_smiles (bool): Whether to randomize SMILES strings. Default: False.
106
+ randomize_smiles_prob (float): The probability of randomizing SMILES strings. Default: 0.5.
107
+ randomize_smiles_repeat (int): The number of times to repeat randomizing SMILES strings. Default: 1.
108
+ """
109
+ # Check if resume_from_checkpoint exists and it's a file
110
+ if resume_from_checkpoint is not None:
111
+ # Check if the checkpoint exists: it can be either a file or a directory
112
+ if not os.path.exists(resume_from_checkpoint):
113
+ raise ValueError(f"Checkpoint file '{resume_from_checkpoint}' does not exist.")
114
+
115
+ if hub_token is not None:
116
+ hf.login(token=hub_token)
117
+
118
+ # Setup output directory and Hugging Face repository
119
+ output_dir += f"/{model_id}"
120
+ if organization is not None:
121
+ hub_model_id = f"{organization}/{model_id}"
122
+ if delete_repo_if_exists and repo_exists(hub_model_id, token=hub_token):
123
+ delete_hf_repository(repo_id=hub_model_id, token=hub_token)
124
+ if not repo_exists(hub_model_id, token=hub_token):
125
+ print(f"Repository '{hub_model_id}' deleted.")
126
+ else:
127
+ print(f"Repository '{hub_model_id}' could not be deleted.")
128
+ return
129
+ if delete_local_repo_if_exists and os.path.exists(output_dir):
130
+ subprocess.run(["rm", "-rf", output_dir])
131
+ if not os.path.exists(output_dir):
132
+ print(f"Local repository '{output_dir}' deleted.")
133
+ else:
134
+ print(f"Local repository '{output_dir}' could not be deleted.")
135
+ return
136
+ repo_url = create_hf_repository(
137
+ repo_id=hub_model_id,
138
+ repo_type="model",
139
+ exist_ok=True,
140
+ private=True,
141
+ token=hub_token,
142
+ )
143
+ print(f"Repository '{hub_model_id}' created at URL: {repo_url}")
144
+ else:
145
+ hub_model_id = None
146
+ print(f"Hub model ID: {hub_model_id}")
147
+
148
+ if isinstance(tokenizer, str):
149
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
150
+ elif tokenizer is None:
151
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_encoder)
152
+
153
+ # Use the end-of-sequence token as the pad token, required for MLM training
154
+ tokenizer.pad_token = tokenizer.eos_token
155
+
156
+ # Load the tokenized dataset
157
+ print("Loading tokenized dataset.")
158
+ dataset_tokenized = load_tokenized_dataset(
159
+ ds_name,
160
+ ds_config,
161
+ tokenizer,
162
+ batch_size_tokenizer,
163
+ encoder_max_length,
164
+ decoder_max_length,
165
+ token=hub_token,
166
+ num_proc_map=num_proc_map,
167
+ randomize_smiles=randomize_smiles,
168
+ randomize_smiles_prob=randomize_smiles_prob,
169
+ randomize_smiles_repeat=randomize_smiles_repeat,
170
+ randomize_text=True,
171
+ randomize_labels=False,
172
+ )
173
+ # Remove "labels" column from the dataset
174
+ dataset_tokenized = dataset_tokenized.remove_columns(["labels"])
175
+ print("Dataset loaded.")
176
+
177
+ # Setup the model for `model_init` in the Trainer
178
+ bert2bert = lambda: get_encoder_decoder_model(
179
+ pretrained_encoder=pretrained_encoder,
180
+ pretrained_decoder=pretrained_decoder,
181
+ max_length=encoder_max_length,
182
+ tie_encoder_decoder=tie_encoder_decoder,
183
+ )
184
+
185
+ # Setup the data collator
186
+ data_collator = DataCollatorForLanguageModeling(
187
+ tokenizer,
188
+ mlm=True,
189
+ mlm_probability=mlm_probability,
190
+ pad_to_multiple_of=8,
191
+ )
192
+
193
+ # Setup the training arguments
194
+ if per_device_batch_size is None:
195
+ per_device_batch_size = batch_size // gradient_accumulation_steps
196
+ if training_args is None:
197
+ training_args = {
198
+ "output_dir": output_dir,
199
+ # Optimizer-related configs
200
+ "learning_rate": learning_rate,
201
+ "optim": "adamw_torch",
202
+ "lr_scheduler_type": "cosine" if lr_scheduler_type is None else lr_scheduler_type,
203
+ "warmup_steps": 8000, # NOTE: ChemFormer: 8000
204
+ # "warmup_ratio": 0,
205
+ "adam_beta1": 0.9, # NOTE: ChemFormer: 0.9
206
+ "adam_beta2": 0.999, # NOTE: ChemFormer: 0.999
207
+ "adam_epsilon": 1e-8, # Default: 1e-8
208
+ # Batch size, device, and performance optimizations configs
209
+ # "torch_compile": True,
210
+ "group_by_length": True,
211
+ "per_device_train_batch_size": per_device_batch_size,
212
+ "per_device_eval_batch_size": per_device_batch_size,
213
+ "gradient_accumulation_steps": gradient_accumulation_steps,
214
+ "auto_find_batch_size": True,
215
+ "fp16": True if torch.cuda.is_available() else False,
216
+ # Evaluation and checkpointing configs
217
+ "max_steps": max_steps,
218
+ "num_train_epochs": num_train_epochs,
219
+ "save_steps": 1000, # NOTE: 200
220
+ "save_strategy": "steps",
221
+ "eval_steps": 1000, # NOTE: 500
222
+ "evaluation_strategy": "steps",
223
+ "save_total_limit": 1,
224
+ "load_best_model_at_end": True,
225
+ "metric_for_best_model": "perplexity",
226
+ "include_inputs_for_metrics": True,
227
+ # Logging configs
228
+ "log_level": "warning",
229
+ "logging_steps": 500,
230
+ "disable_tqdm": True,
231
+ "report_to": ["tensorboard"],
232
+ "save_only_model": False, # Default: False
233
+ # Hub information configs
234
+ "push_to_hub": True, # NOTE: Also manually done further down
235
+ "push_to_hub_model_id": model_id,
236
+ "push_to_hub_organization": organization,
237
+ "hub_model_id": hub_model_id,
238
+ "hub_token": hub_token,
239
+ "hub_strategy": "checkpoint", # NOTE: Allows to resume training from last checkpoint
240
+ "hub_private_repo": True,
241
+ # Other configs
242
+ "seed": 42,
243
+ "data_seed": 42,
244
+ }
245
+
246
+ # Setup the Trainer and start training (no Optuna hyperparameter search)
247
+ trainer = Trainer(
248
+ model_init=bert2bert,
249
+ tokenizer=tokenizer,
250
+ data_collator=data_collator,
251
+ args=TrainingArguments(**training_args),
252
+ compute_metrics=compute_metrics_for_mlm,
253
+ train_dataset=dataset_tokenized["train"],
254
+ eval_dataset=dataset_tokenized["validation"],
255
+ )
256
+ if resume_from_checkpoint is not None:
257
+ trainer.train(
258
+ resume_from_checkpoint=resume_from_checkpoint,
259
+ )
260
+ else:
261
+ trainer.train()
262
+ print("-" * 80)
263
+ print("Training completed.")
264
+ print("-" * 80)
265
+
266
+ if hub_model_id is not None:
267
+ print("Pushing model to Hugging Face Hub.")
268
+ print("-" * 80)
269
+ tokenizer.save_pretrained(output_dir)
270
+ trainer.push_to_hub(
271
+ commit_message="Initial version",
272
+ model_name=hub_model_id,
273
+ license="mit",
274
+ finetuned_from=f"{pretrained_encoder}",
275
+ tasks=["Text2Text Generation", "question-answering"],
276
+ tags=["PROTAC", "cheminformatics"],
277
+ dataset=[ds_name],
278
+ dataset_args=[ds_config],
279
+ )
280
+ tokenizer.push_to_hub(
281
+ repo_id=hub_model_id,
282
+ commit_message="Upload tokenizer",
283
+ private=True,
284
+ token=hub_token,
285
+ tags=["PROTAC", "cheminformatics"],
286
+ )
287
+ print("All done.")
protac_splitter/llms/training_rl_models.py ADDED
@@ -0,0 +1,406 @@
1
+ """ Train a PPO and DPO model for PROTAC-Splitter using Hugging Face
2
+ Transformers and TRL. This is work-in-progress code: it is neither tested nor
3
+ used in the package.
4
+ """
5
+ from typing import Optional, Literal
6
+ from functools import partial
7
+ import os
8
+ import subprocess
9
+
10
+ import torch
11
+ import evaluate
12
+ import huggingface_hub as hf
13
+ from tqdm import tqdm
14
+ from datasets import load_dataset
15
+ from rdkit import Chem
16
+ from transformers import (
17
+ AutoTokenizer,
18
+ TrainingArguments,
19
+ EncoderDecoderModel,
20
+ AutoConfig,
21
+ )
22
+ from trl import (
23
+ AutoModelForSeq2SeqLMWithValueHead,
24
+ PPOConfig,
25
+ PPOTrainer,
26
+ DPOTrainer,
27
+ )
28
+
29
+ from protac_splitter.llms.data_utils import (
30
+ load_trl_dataset,
31
+ data_collator_for_trl,
32
+ )
33
+
34
+ from protac_splitter.llms.hf_utils import (
35
+ create_hf_repository,
36
+ delete_hf_repository,
37
+ repo_exists,
38
+ )
39
+ from protac_splitter.llms.evaluation import decode_and_get_metrics
40
+ from protac_splitter.evaluation import check_substructs, split_prediction
41
+
42
+
43
+ def clean_text(text: str) -> str:
44
+ """ Cleans the text by removing special tokens. """
45
+ return text.replace("<s>", "").replace("</s>", "")
46
+
47
+
48
+ def reward_function(
49
+ query: str,
50
+ response: str,
51
+ ) -> float:
52
+ """ Reward function for the RL-based models.
53
+
54
+ Args:
55
+ query (str): The query SMILES string.
56
+ response (str): The response SMILES string.
57
+
58
+ Returns:
59
+ float: The reward value.
60
+ """
61
+
62
+ substructs = split_prediction(response)
63
+ if substructs is None:
64
+ return -1.0
65
+
66
+ if not check_substructs(
67
+ protac_smiles=query,
68
+ poi_smiles=substructs['poi'],
69
+ linker_smiles=substructs['linker'],
70
+ e3_smiles=substructs['e3'],
71
+ return_bond_types=False,
72
+ poi_attachment_id=1,
73
+ e3_attachment_id=2,
74
+ ):
75
+ return 0.0
76
+
77
+ return 1.0
78
+
79
+
80
+ def train_ppo_model(
81
+ model_id: str = "PROTAC-Splitter-PPO-standard_rand_recombined-ChemBERTa-zinc-base",
82
+ organization: str = 'ailab-bio',
83
+ output_dir: str = "./models/",
84
+ max_steps: int = 2000,
85
+ ppo_epochs: int = 5,
86
+ batch_size: int = 128,
87
+ hub_token: Optional[str] = None,
88
+ pretrained_model_name: str = "ailab-bio/PROTAC-Splitter-standard_rand_recombined-ChemBERTa-zinc-base",
89
+ max_length: int = 512,
90
+ delete_repo_if_exists: bool = False,
91
+ delete_local_repo_if_exists: bool = False,
92
+ ds_name: str = "ailab-bio/PROTAC-Splitter-Dataset",
93
+ ds_config: str = "standard",
94
+ ):
95
+ """ Trains a PPO model on a given dataset.
96
+
97
+ Args:
98
+ model_id (str, optional): The name of the model to be trained. Defaults to "PROTAC-Splitter-PPO-standard_rand_recombined-ChemBERTa-zinc-base".
99
+ organization (str, optional): The organization name. Defaults to 'ailab-bio'.
100
+ output_dir (str, optional): The output directory. Defaults to "./models/".
101
+ max_steps (int, optional): The maximum number of training steps. Defaults to 2000.
102
+ ppo_epochs (int, optional): The number of PPO epochs. Defaults to 5.
103
+ batch_size (int, optional): The batch size. Defaults to 128.
104
+ hub_token (Optional[str], optional): The Hugging Face token. Defaults to None.
105
+ pretrained_model_name (str, optional): The name of the pretrained model. Defaults to "ailab-bio/PROTAC-Splitter-standard_rand_recombined-ChemBERTa-zinc-base".
106
+ max_length (int, optional): The maximum length of the input sequence. Defaults to 512.
107
+ delete_repo_if_exists (bool, optional): Whether to delete the Hub repository if it already exists. Defaults to False.
+ delete_local_repo_if_exists (bool, optional): Whether to delete the local repository if it already exists. Defaults to False.
108
+ """
109
+ if ppo_epochs < 1:
110
+ raise ValueError(f"ppo_epochs must be >= 1, got {ppo_epochs}.")
111
+ if hub_token is not None:
112
+ hf.login(token=hub_token)
113
+
114
+ # Setup output directory and Hugging Face repository
115
+ output_dir += f"/{model_id}"
116
+ if organization is not None:
117
+ hub_model_id = f"{organization}/{model_id}"
118
+ if delete_repo_if_exists and repo_exists(hub_model_id, token=hub_token):
119
+ delete_hf_repository(repo_id=hub_model_id, token=hub_token)
120
+ if not repo_exists(hub_model_id, token=hub_token):
121
+ print(f"Repository '{hub_model_id}' deleted.")
122
+ else:
123
+ print(f"Repository '{hub_model_id}' could not be deleted.")
124
+ return
125
+ if delete_local_repo_if_exists and os.path.exists(output_dir):
126
+ subprocess.run(["rm", "-rf", output_dir])
127
+ if not os.path.exists(output_dir):
128
+ print(f"Local repository '{output_dir}' deleted.")
129
+ else:
130
+ print(f"Local repository '{output_dir}' could not be deleted.")
131
+ return
132
+ repo_url = create_hf_repository(
133
+ repo_id=hub_model_id,
134
+ repo_type="model",
135
+ exist_ok=True,
136
+ private=True,
137
+ token=hub_token,
138
+ )
139
+ print(f"Repository '{hub_model_id}' created at URL: {repo_url}")
140
+ else:
141
+ hub_model_id = None
142
+ print(f"Hub model ID: {hub_model_id}")
143
+
144
+ # Load pretrained model
145
+ model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
146
+ pretrained_model_name,
147
+ max_length=max_length,
148
+ )
149
+ ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(
150
+ pretrained_model_name,
151
+ max_length=max_length,
152
+ )
153
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
154
+ tokenizer.pad_token = tokenizer.eos_token
155
+
156
+ # Get dataset
157
+ train_dataset = load_trl_dataset(
158
+ tokenizer=tokenizer,
159
+ token=hub_token,
160
+ max_length=max_length,
161
+ dataset_name=ds_name,
162
+ ds_config=ds_config,
163
+ ).shuffle(seed=42).flatten_indices()
164
+
165
+ # Setup PPO trainer
166
+ hub_configs = {
167
+ "repo_id": hub_model_id,
168
+ "commit_message": "Initial version",
169
+ "private": True,
170
+ }
171
+ ppo_config = PPOConfig(
172
+ # Learning parameters
173
+ learning_rate=1e-5,
174
+ steps=max_steps, # Default: 20_000
175
+ ppo_epochs=ppo_epochs, # Default: 4
176
+ batch_size=batch_size, # Default: 256
177
+ gradient_accumulation_steps=1, # Default: 1
178
+ optimize_device_cache=True,
179
+ # PPO parameters
180
+ init_kl_coef=1.0,
181
+ adap_kl_ctrl=True,
182
+ target=0.5,
183
+ horizon=1000,
184
+ cliprange=0.1,
185
+ early_stopping=True,
186
+ target_kl=0.5,
187
+ max_grad_norm=1.0,
188
+ use_score_scaling=True,
189
+ use_score_norm=True,
190
+ whiten_rewards=True,
191
+ # Logging parameters
192
+ # NOTE: Check this guide for more information about the logged metrics:
193
+ # https://huggingface.co/docs/trl/v0.10.1/logging
194
+ model_name=hub_model_id,
195
+ push_to_hub_if_best_kwargs=hub_configs,
196
+ log_with="tensorboard", # ["wandb", LoggerType.TENSORBOARD],
197
+ project_kwargs={"logging_dir": output_dir},
198
+ seed=42,
199
+ )
200
+ ppo_trainer = PPOTrainer(
201
+ model=model,
202
+ ref_model=ref_model,
203
+ num_shared_layers=0,
204
+ config=ppo_config,
205
+ tokenizer=tokenizer,
206
+ dataset=train_dataset,
207
+ data_collator=data_collator_for_trl,
208
+ # lr_scheduler=torch.optim.lr_scheduler.LRScheduler, # NOTE: It must be that, CosineAnnealingLR is not supported
209
+ )
210
+
211
+ # Training Loop
212
+ generation_kwargs = {
213
+ "do_sample": True,
214
+ "num_beams": 5,
215
+ "top_k": 20,
216
+ "max_length": 512,
217
+ "pad_token_id": tokenizer.eos_token_id,
218
+ }
219
+
220
+ for step, batch in tqdm(enumerate(ppo_trainer.dataloader), total=len(ppo_trainer.dataloader)):
221
+ query_tensors = batch["input_ids"]
222
+
223
+ # Get response from SFTModel
224
+ response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
225
+ batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
226
+
227
+ # Compute reward score
228
+ rewards = [reward_function(clean_text(q), clean_text(r)) for q, r in zip(batch["query"], batch["response"])]
229
+ rewards = [torch.tensor(r) for r in rewards]
230
+
231
+ # Run PPO step
232
+ stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
233
+ ppo_trainer.log_stats(stats, batch, rewards)
234
+
235
+ # Save model and tokenizer
236
+ ppo_trainer.push_to_hub(**hub_configs)
237
+ tokenizer.push_to_hub(**hub_configs)
238
+
239
+
240
+ def train_dpo_model(
241
+ model_name: str = "ailab-bio/PROTAC-Splitter-DPO",
242
+ output_dir: str = "./models/",
243
+ beta: float = 0.1,
244
+ loss_type: Literal["sigmoid", "hinge"] = "sigmoid",
245
+ learning_rate: float = 5e-5,
246
+ max_steps: int = 2000,
247
+ num_train_epochs: int = -1,
248
+ batch_size: int = 128,
249
+ gradient_accumulation_steps: int = 4,
250
+ resume_from_checkpoint: bool = False,
251
+ hub_token: Optional[str] = None,
252
+ pretrained_model_name: str = "ailab-bio/PROTAC-Splitter_untied_80-20-split",
253
+ pretrained_ref_model_name: str = "ailab-bio/PROTAC-Splitter_untied_80-20-split",
254
+ max_length: Optional[int] = None,
255
+ delete_repo_first: bool = False,
256
+ optuna_search: bool = False,
257
+ ):
258
+ """ Trains a DPO model on a given dataset.
259
+
260
+ Args:
261
+ model_name (str, optional): The name of the model to be trained. Defaults to "ailab-bio/PROTAC-Splitter-DPO".
262
+ max_steps (int, optional): The maximum number of training steps. Defaults to 2000.
263
+ """
264
+ if hub_token is not None:
265
+ hf.login(token=hub_token)
266
+ if delete_repo_first and not resume_from_checkpoint:
267
+ delete_hf_repository(repo_id=model_name, token=hub_token)
268
+ tokenizer = AutoTokenizer.from_pretrained(
269
+ pretrained_model_name,
270
+ token=hub_token,
271
+ )
272
+ if tokenizer.pad_token is None:
273
+ tokenizer.pad_token = tokenizer.eos_token
274
+ # Get train and eval datasets
275
+ dataset = load_dataset(
276
+ "ailab-bio/PROTAC-Substructures-DPO",
277
+ token=hub_token,
278
+ )
279
+ # Setup models
280
+ def model_init():
281
+ return EncoderDecoderModel.from_pretrained(
282
+ pretrained_model_name,
283
+ token=hub_token,
284
+ )
285
+ model_ref = EncoderDecoderModel.from_pretrained(
286
+ pretrained_ref_model_name,
287
+ token=hub_token,
288
+ )
289
+ # Setup training arguments
290
+ per_device_batch_size = batch_size // gradient_accumulation_steps
291
+ training_args = TrainingArguments(
292
+ output_dir=output_dir,
293
+ # Optimizer-related configs
294
+ learning_rate=learning_rate,
295
+ optim="adamw_torch",
296
+ lr_scheduler_type="cosine", # Default: "linear"
297
+ # Batch size and device configs
298
+ per_device_train_batch_size=per_device_batch_size,
299
+ per_device_eval_batch_size=per_device_batch_size,
300
+ gradient_accumulation_steps=gradient_accumulation_steps,
301
+ auto_find_batch_size=True,
302
+ # torch_compile=True,
303
+ fp16=True,
304
+ # Evaluation and checkpointing configs
305
+ evaluation_strategy="steps", # TODO: Why is it not working? "steps",
306
+ max_steps=max_steps,
307
+ num_train_epochs=num_train_epochs,
308
+ eval_steps=100,
309
+ save_steps=200,
310
+ # eval_steps=7500,
311
+ # warmup_steps=2000,
312
+ save_strategy="steps",
313
+ save_total_limit=1,
314
+ load_best_model_at_end=True,
315
+ # metric_for_best_model="valid_smiles",
316
+ # Logging configs
317
+ log_level="info",
318
+ logging_steps=50,
319
+ disable_tqdm=True,
320
+ # Hub information configs
321
+ push_to_hub=True, # NOTE: Done manually further down
322
+ hub_token=hub_token,
323
+ hub_model_id=model_name,
324
+ hub_strategy="checkpoint", # NOTE: Allows to resume training from last checkpoint
325
+ hub_private_repo=True,
326
+ # Other configs
327
+ remove_unused_columns=False,
328
+ seed=42,
329
+ data_seed=42,
330
+ )
331
+ # Setup metrics
332
+ # TODO: The metric is not working because the predictions include rewards,
333
+ # or something like that, i.e., real values, which cannot be decoded by the
334
+ # tokenizer. Skipping for now and using the default one.
335
+ rouge = evaluate.load("rouge")
336
+ fpgen = Chem.rdFingerprintGenerator.GetMorganGenerator(
337
+ radius=8,
338
+ fpSize=2048,
339
+ )
340
+ metric = partial(
341
+ decode_and_get_metrics,
342
+ rouge=rouge,
343
+ tokenizer=tokenizer,
344
+ fpgen=fpgen,
345
+ )
346
+ # Setup trainer and start training
347
+ if max_length is None:
348
+ max_length = AutoConfig.from_pretrained(
349
+ pretrained_model_name,
350
+ token=hub_token,
351
+ ).max_length
352
+ # max_length = model.config.max_length
353
+ dpo_trainer = DPOTrainer(
354
+ model=model_init(),
355
+ ref_model=model_ref,
356
+ beta=beta,
357
+ loss_type=loss_type,
358
+ train_dataset=dataset["train"],
359
+ eval_dataset=dataset["test"],
360
+ tokenizer=tokenizer,
361
+ model_init=model_init if optuna_search else None,
362
+ # compute_metrics=metric,
363
+ max_length=max_length,
364
+ max_prompt_length=max_length,
365
+ max_target_length=max_length,
366
+ is_encoder_decoder=True,
367
+ padding_value=tokenizer.pad_token_id,
368
+ truncation_mode="keep_start",
369
+ args=training_args,
370
+ )
371
+ if optuna_search and False:
372
+ # TODO: This is not working because the training arguments do NOT
373
+ # include the beta parameter...
374
+ def optuna_hp_space(trial):
375
+ return {
376
+ "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
377
+ "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64, 128]),
378
+ "beta": trial.suggest_float("beta", 0.1, 0.5),
379
+ }
380
+ best_trials = dpo_trainer.hyperparameter_search(
381
+ direction=["minimize"],
382
+ backend="optuna",
383
+ hp_space=optuna_hp_space,
384
+ n_trials=20,
385
+ # compute_objective=compute_objective,
386
+ )
387
+ print("-" * 80)
388
+ print(f"Best trials:\n{best_trials}")
389
+ print("-" * 80)
390
+ else:
391
+ if resume_from_checkpoint:
392
+ resume_from_checkpoint = "last-checkpoint"
393
+ else:
394
+ resume_from_checkpoint = None
395
+ dpo_trainer.train(
396
+ resume_from_checkpoint=resume_from_checkpoint,
397
+ )
398
+ dpo_trainer.push_to_hub(
399
+ commit_message="Initial version",
400
+ model_name=model_name,
401
+ license="mit",
402
+ finetuned_from=pretrained_model_name,
403
+ tasks=["Text2Text Generation"],
404
+ tags=["PROTAC", "cheminformatics"],
405
+ dataset="ailab-bio/PROTAC-Substructures-DPO",
406
+ )
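The reward above is a three-level signal: -1.0 for a response that cannot be split, 0.0 for a split that fails the substructure check, and 1.0 for a correct split. A sketch of the expected behavior, assuming `split_prediction` returns `None` on unparsable input (as the guard in `reward_function` implies):

```python
protac = "CC(=O)Nc1ccc(O)cc1"  # placeholder query SMILES, not a real PROTAC

print(reward_function(protac, "not-a-valid-split"))  # -> -1.0 (split fails to parse)
```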
protac_splitter/protac_cheminformatics.py ADDED
@@ -0,0 +1,120 @@
1
+ import logging
2
+ import random
3
+ from typing import List, Tuple, Callable, Any, Union, Dict, Optional, Literal
4
+ from functools import lru_cache
5
+
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem
8
+ from rdkit.Chem import rdchem
9
+ from rdkit import RDLogger
10
+ from rdkit.Chem import CanonSmiles
11
+
12
+ from .chemoinformatics import (
13
+ canonize,
14
+ smiles2mol,
15
+ )
16
+
17
+ RDLogger.DisableLog("rdApp.*")
18
+
19
+
20
+ @lru_cache(maxsize=None)
21
+ def get_mol(smiles: str) -> rdchem.Mol:
22
+ return Chem.MolFromSmiles(smiles)
23
+
24
+
25
+ def find_atom_idx_of_map_atoms(
26
+ mol: rdchem.Mol,
27
+ find_poi: bool = True,
28
+ find_e3: bool = True,
29
+ poi_attachment_id: int = 1,
30
+ e3_attachment_id: int = 2,
31
+ ) -> Union[int, Tuple[int, int]]:
32
+ """ Find the indices of the attachment points in the given molecule.
33
+
34
+ Args:
35
+ mol (rdkit.Chem.rdchem.Mol): The molecule.
36
+ find_poi (bool): Whether to find the POI attachment point.
37
+ find_e3 (bool): Whether to find the E3 attachment point.
38
+ poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]".
39
+ e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]".
40
+
41
+ Returns:
42
+ int | Tuple[int, int]: The index of the attachment point for the POI ligand if find_poi is True, the index of the attachment point for the E3 binder if find_e3 is True, or a tuple containing POI and E3 indices (in this order) if both find_poi and find_e3 are True.
43
+ """
44
+ if find_poi and find_e3:
45
+ poi_idx = None
46
+ e3_idx = None
47
+ for atom in mol.GetAtoms():
48
+ if atom.GetAtomMapNum() == poi_attachment_id:
49
+ poi_idx = atom.GetIdx()
50
+ elif atom.GetAtomMapNum() == e3_attachment_id:
51
+ e3_idx = atom.GetIdx()
52
+ if poi_idx is not None and e3_idx is not None:
53
+ break
54
+ return poi_idx, e3_idx
55
+ elif find_poi:
56
+ for atom in mol.GetAtoms():
57
+ if atom.GetAtomMapNum() == poi_attachment_id:
58
+ return atom.GetIdx()
59
+ elif find_e3:
60
+ for atom in mol.GetAtoms():
61
+ if atom.GetAtomMapNum() == e3_attachment_id:
62
+ return atom.GetIdx()
63
+
64
+
65
+ def reassemble_protac(
66
+ ligands_smiles: Optional[str] = None,
67
+ poi_smiles: Optional[str] = None,
68
+ linker_smiles: Optional[str] = None,
69
+ e3_smiles: Optional[str] = None,
70
+ e3_bond_type: Literal['single', 'double', 'triple', 'rand_uniform'] = 'single',
71
+ poi_bond_type: Literal['single', 'double', 'triple', 'rand_uniform'] = 'single',
72
+ poi_attachment_id: int = 1,
73
+ e3_attachment_id: int = 2,
74
+ rand_generator = None,
75
+ ) -> Tuple[str, Chem.rdchem.Mol]:
76
+ """ Reassemble a PROTAC molecule from its substructures. The SMILES must contain attachment points.
77
+
78
+ In case the bond type cannot be formed an error will be raised.
79
+
80
+ Example of usage:
81
+
82
+ ```python
83
+ e3_smiles = '[*:2]NC(C(=O)N1CC(O)CC1C(=O)NCc1ccc(-c2scnc2C)cc1)C(C)(C)C'
84
+ linker_smiles = '[*:2]C(=O)CCCCCCCCCC[*:1]'
85
+ poi_smiles = '[*:1]CN1CCN(c2ccc(Nc3ncc4c(C)cc(=O)n(-c5cccc(NC(=O)C=C)c5)c4n3)c(OC)c2)CC1'
86
+
87
+ merged_smiles, _ = reassemble_protac(poi_smiles=poi_smiles, linker_smiles=linker_smiles, e3_smiles=e3_smiles)
88
+ print(merged_smiles)
89
+ ```
90
+
91
+ Args:
92
+ poi_smiles (str): The SMILES notation for the POI ligand.
93
+ linker_smiles (str): The SMILES notation for the linker.
94
+ e3_smiles (str): The SMILES notation for the E3 binder.
95
+ e3_bond_type (str): The type of bond to be added between the E3 binder and the linker. Can be 'single', 'double', 'triple', or 'rand_uniform'.
96
+ poi_bond_type (str): The type of bond to be added between the POI ligand and the linker. Can be 'single', 'double', 'triple', or 'rand_uniform'.
97
+ poi_attachment_id (int): The label of the attachment point for the POI ligand, i.e., "[*:{poi_attachment_id}]".
98
+ e3_attachment_id (int): The label of the attachment point for the E3 binder, i.e., "[*:{e3_attachment_id}]".
99
+ rand_generator: A random number generator for 'rand_uniform' bond types. Defaults to None, i.e., standard library random.
100
+
101
+ Returns:
102
+ Tuple[str, Chem.rdchem.Mol]: The SMILES notation and RDKit molecule object for the reassembled PROTAC molecule.
103
+ """
104
+ if ligands_smiles is None:
105
+ if None in [poi_smiles, linker_smiles, e3_smiles]:
106
+ raise ValueError("Missing substructures SMILES: either provide ligands_smiles or all of poi_smiles, linker_smiles, and e3_smiles")
107
+ ligands_smiles = f'{e3_smiles}.{linker_smiles}.{poi_smiles}'
111
+
112
+ ligands_mol = canonize(smiles2mol(ligands_smiles))
113
+ if ligands_mol is None:
114
+ return None, None
115
+
116
+ try:
117
+ protac_mol = Chem.molzip(ligands_mol)
118
+ except ValueError as e:
119
+ logging.error(f"Failed to reassemble PROTAC: {e}")
120
+ return None, None
+
+ return Chem.MolToSmiles(protac_mol), protac_mol
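For illustration, a small usage sketch of `find_atom_idx_of_map_atoms` on the linker fragment from the `reassemble_protac` docstring above:

```python
from rdkit import Chem

linker = Chem.MolFromSmiles("[*:2]C(=O)CCCCCCCCCC[*:1]")
poi_idx, e3_idx = find_atom_idx_of_map_atoms(linker, find_poi=True, find_e3=True)
# poi_idx is the index of the [*:1] dummy atom, e3_idx that of [*:2]
print(poi_idx, e3_idx)
```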
protac_splitter/protac_splitter.py ADDED
@@ -0,0 +1,370 @@
1
+ import os
2
+ import requests
3
+ from typing import Union, Optional, Dict, List
4
+ from pathlib import Path
5
+ import logging
6
+
7
+ from datasets import Dataset
8
+ import pandas as pd
9
+
10
+ from protac_splitter.chemoinformatics import canonize
11
+ from protac_splitter.fixing_functions import fix_prediction
12
+ from protac_splitter.llms.model_utils import get_pipeline, run_pipeline
13
+ from protac_splitter.graphs.e3_clustering import get_representative_e3s_fp
14
+ from protac_splitter.graphs.edge_classifier import GraphEdgeClassifier
15
+ from protac_splitter.graphs.splitting_algorithms import split_protac_graph_based
16
+
17
+
18
+ def load_graph_edge_classifier_from_cache(
19
+ cache_dir: Union[str, Path] = "~/.cache/protac_splitter",
20
+ model_filename: str = "PROTAC-Splitter-XGBoost.joblib",
21
+ download_url: str = "https://docs.google.com/uc?export=download&id=1bb9i5_L_-re3QYPc7tSiCtVNEEbNIzAC",
22
+ ) -> GraphEdgeClassifier:
23
+ """
24
+ Loads the GraphEdgeClassifier model from a local cache directory.
25
+ If the model file is not found, downloads it from the specified URL.
26
+
27
+ Args:
28
+ cache_dir (str or Path): Directory to cache the model file.
29
+ model_filename (str): Name of the model file.
30
+ download_url (str): URL to download the model if not present.
31
+
32
+ Returns:
33
+ GraphEdgeClassifier: Loaded classifier.
34
+ """
35
+ cache_dir = Path(os.path.expanduser(cache_dir))
36
+ cache_dir.mkdir(parents=True, exist_ok=True)
37
+ model_path = cache_dir / model_filename
38
+
39
+ if not model_path.exists():
40
+ response = requests.get(download_url, stream=True)
41
+ response.raise_for_status()
42
+ expected_size = int(response.headers.get("Content-Length", -1))
43
+
44
+ with open(model_path, "wb") as f:
45
+ for chunk in response.iter_content(chunk_size=1024*1024):
46
+ if chunk:
47
+ f.write(chunk)
48
+
49
+ if expected_size != -1:
50
+ actual = model_path.stat().st_size
51
+ if actual != expected_size:
52
+ raise RuntimeError(f"Download incomplete: got {actual}, expected {expected_size}")
53
+
54
+ # Optional checksum:
55
+ # NOTE: Uncomment the following for debugging
56
+ import hashlib
57
+ h = hashlib.sha256(model_path.read_bytes()).hexdigest()
58
+ h_orig = "513621f4dc2ff7ec819a222bc7311afb8b6e6e89d6d694dd2906e695a50086dd"
59
+ if h != h_orig:
60
+ raise RuntimeError(
61
+ f"Downloaded model checksum mismatch: got {h}, expected {h_orig}. "
62
+ "Please delete the model file and try again."
63
+ )
64
+
65
+ return GraphEdgeClassifier.load(model_path)
66
+
67
+
68
+ def split_protac(
69
+ protac_smiles: Union[str, List, pd.DataFrame],
70
+ use_transformer: bool = False,
71
+ use_xgboost: bool = True,
72
+ fix_predictions: bool = True,
73
+ protac_smiles_col: str = "text",
74
+ batch_size: int = 1,
75
+ beam_size: int = 5,
76
+ device: Optional[Union[int, str]] = None,
77
+ num_proc: int = 1,
78
+ verbose: int = 0,
79
+ ) -> Union[Dict[str, str], List[Dict[str, str]]]:
80
+ """ Split a PROTAC SMILES into the two ligands and the linker.
81
+
82
+ If `use_transformer` and `use_xgboost` are both True, the Transformer model
83
+ will run first, and XGBoost will be used as a fallback for predictions that
84
+ fail re-assembly and fixing. If both `use_transformer` and `use_xgboost`
85
+ are False, a fully heuristic-based algorithm will be used for splitting.
86
+
87
+ Args:
88
+ protac_smiles (str, list, or pd.DataFrame): The PROTAC SMILES to split.
89
+ If a DataFrame is provided, it must contain a column named `protac_smiles_col`.
90
+ use_transformer (bool): Whether to use the transformer model for splitting.
91
+ use_xgboost (bool): Whether to use the XGBoost model for splitting.
92
+ fix_predictions (bool): Whether to fix the predictions using deterministic cheminformatics rules. Only used if `use_transformer` is True.
93
+ protac_smiles_col (str): The name of the column containing the PROTAC SMILES in the DataFrame.
94
+ batch_size (int): Batch size for processing. Only used if `use_transformer` is True.
95
+ beam_size (int): Number of beam search predictions to generate. Only used if `use_transformer` is True. Higher values may yield better results but increase computation time.
96
+ device (int or str, optional): Device to run the Transformer model on. Defaults to None, which attempts to run on GPU if available, otherwise CPU.
97
+ num_proc (int): Number of processes to use for parallel processing. Useful for large datasets of PROTACs to split.
98
+ verbose (int): Verbosity level.
99
+
100
+ Returns:
101
+ Union[Dict[str, str], List[Dict[str, str]]]: Depending on the input type, returns:
102
+ - If a single string is provided, returns a dictionary with format: `{protac_smiles_col: protac_smiles, "default_pred_n0": e3l.linker.warhead, "model_name": Transformer|XGBoost|Heuristic}`.
103
+ - If a list of strings is provided, returns a list of dictionaries with the same format as above.
104
+ - If a DataFrame is provided, returns a DataFrame with columns: `protac_smiles_col`, `default_pred_n0`, and `model_name`. The `default_pred_n0` column contains the predicted split strings in the format `e3.linker.warhead`.
105
+ """
106
+ if use_xgboost:
107
+ representative_e3s_fp = get_representative_e3s_fp()
108
+ xgboost_model = load_graph_edge_classifier_from_cache()
109
+
110
+ # Generate a Dataset from the input PROTAC SMILES
111
+ if isinstance(protac_smiles, str):
112
+ protac_smiles_canon = canonize(protac_smiles)
113
+ if protac_smiles_canon is None:
114
+ raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
115
+ ds = Dataset.from_dict({protac_smiles_col: [protac_smiles_canon]})
116
+ elif isinstance(protac_smiles, list):
117
+ # Canonize and check if all PROTAC SMILES are valid
118
+ protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
119
+ if None in protac_smiles_canon:
120
+ wrong_protacs = [protac for protac, canon in zip(protac_smiles, protac_smiles_canon) if canon is None]
121
+ raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
122
+ ds = Dataset.from_dict({protac_smiles_col: protac_smiles_canon})
123
+ elif isinstance(protac_smiles, pd.DataFrame):
124
+ # Check if the DataFrame contains a columns named `protac_smiles_col`
125
+ if protac_smiles_col not in protac_smiles.columns:
126
+ raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
127
+ # Canonize and check if all PROTAC SMILES are valid
128
+ protac_smiles_canon = protac_smiles[protac_smiles_col].apply(canonize)
129
+ if protac_smiles_canon.isnull().any():
130
+ wrong_protacs = protac_smiles[protac_smiles_canon.isnull()]
131
+ raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
132
+ ds = Dataset.from_pandas(protac_smiles_canon.to_frame(name=protac_smiles_col))
133
+
134
+ if use_transformer:
135
+ pipe = get_pipeline(
136
+ model_name="ailab-bio/PROTAC-Splitter-EncoderDecoder-lr_reduce-rand-smiles",
137
+ token=os.environ.get("HF_TOKEN", None),
138
+ is_causal_language_model=False,
139
+ num_return_sequences=beam_size,
140
+ device=device,
141
+ )
142
+
143
+ # preds will be a list of dictionaries, each containing the
144
+ # beam-size predictions for each input PROTAC SMILES. Format: [{'pred_n0': 'prediction_0', 'pred_n1': 'prediction_1', ...}, ...]
145
+ preds = run_pipeline(
146
+ pipe,
147
+ ds,
148
+ batch_size,
149
+ is_causal_language_model=False,
150
+ smiles_column=protac_smiles_col,
151
+ )
152
+
153
+ # Turn the predictions into a DataFrame and then into a Dataset
154
+ preds_df = pd.DataFrame(preds)
155
+ preds_df[protac_smiles_col] = ds[protac_smiles_col]
156
+ preds_ds = Dataset.from_pandas(preds_df)
157
+
158
+ def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
159
+ """Fix the predictions for each row."""
160
+ protac = row[protac_smiles_col]
161
+ if fix_predictions:
162
+ preds = {k: fix_prediction(protac, v, verbose=verbose) for k, v in row.items() if k.startswith("pred_")}
163
+ else:
164
+ preds = {k: v for k, v in row.items() if k.startswith("pred_")}
165
+
166
+ # If all preds are None, we attempt to use the XGBoost model
167
+ if all(v is None for v in preds.values()):
168
+ if use_xgboost:
169
+ pred = split_protac_graph_based(
170
+ protac_smiles=protac,
171
+ use_classifier=True,
172
+ classifier=xgboost_model,
173
+ representative_e3s_fp=representative_e3s_fp,
174
+ )
+ # Guard against a failed graph-based split, mirroring the other branches
+ if all(v is None for v in pred.values()):
+ split = None
+ else:
+ split = f"{pred['e3']}.{pred['linker']}.{pred['poi']}"
+ return {
+ protac_smiles_col: protac,
+ "default_pred_n0": split,
178
+ "model_name": "XGBoost",
179
+ }
180
+ else:
181
+ # If no predictions are valid, we return None for the default prediction
182
+ return {
183
+ protac_smiles_col: protac,
184
+ "default_pred_n0": None,
185
+ "model_name": "Transformer",
186
+ }
187
+ else:
188
+ # Select the non-None prediction with the lowest beam index
189
+ # NOTE: The HF predictions comes in lists, with the first
190
+ # element being the one with the highest likelihood.
191
+ for i in range(beam_size):
192
+ key = f"pred_n{i}"
193
+ if preds[key] is not None:
194
+ return {
195
+ protac_smiles_col: protac,
196
+ "default_pred_n0": preds[key],
197
+ "model_name": "Transformer",
198
+ }
199
+
200
+ # Map the function over the Dataset to fix the predictions and/or
201
+ # replace them with the XGBoost fallback predictions if they fail.
202
+ if fix_predictions or use_xgboost:
203
+ preds_ds = preds_ds.map(
204
+ mapping_func,
205
+ num_proc=1 if use_xgboost else num_proc, # Using XGBoost in a map function might not be thread-safe
206
+ desc=f"{'Fixing predictions' if fix_predictions else ''}{' and ' if fix_predictions and use_xgboost else ''}{'Replacing predictions with XGBoost fallback' if use_xgboost else ''}",
207
+ )
208
+
209
+ elif use_xgboost:
210
+ # Use the XGBoost model only
211
+ def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
212
+ """Split the PROTAC SMILES using the XGBoost model."""
213
+ protac = row[protac_smiles_col]
214
+ pred = split_protac_graph_based(
215
+ protac_smiles=protac,
216
+ use_classifier=True,
217
+ classifier=xgboost_model,
218
+ representative_e3s_fp=representative_e3s_fp,
219
+ )
220
+ if all(v is None for v in pred.values()):
221
+ split = None
222
+ else:
223
+ split = f"{pred['e3']}.{pred['linker']}.{pred['poi']}"
224
+ return {
225
+ protac_smiles_col: protac,
226
+ "default_pred_n0": split,
227
+ "model_name": "XGBoost",
228
+ }
229
+ preds_ds = ds.map(
230
+ mapping_func,
231
+ num_proc=1,
232
+ desc="Splitting PROTAC SMILES using XGBoost model",
233
+ )
234
+ else:
235
+ # If neither transformer nor XGBoost is used, we use the heuristic-based
236
+ # algorithm, that does not require any model.
237
+ def mapping_func(row: Dict[str, str]) -> Dict[str, str]:
238
+ """Split the PROTAC SMILES using the heuristic-based algorithm."""
239
+ protac = row[protac_smiles_col]
240
+ pred = split_protac_graph_based(
241
+ protac_smiles=protac,
242
+ use_classifier=False,
243
+ )
244
+ if all(v is None for v in pred.values()):
245
+ split = None
246
+ else:
247
+ split = f"{pred['e3']}.{pred['linker']}.{pred['poi']}"
248
+ return {
249
+ protac_smiles_col: protac,
250
+ "default_pred_n0": split,
251
+ "model_name": "Heuristic",
252
+ }
253
+ preds_ds = ds.map(
254
+ mapping_func,
255
+ num_proc=num_proc,
256
+ desc="Splitting PROTAC SMILES using heuristic-based algorithm",
257
+ )
258
+
259
+ if isinstance(protac_smiles, str):
260
+ # If the input was a single string, we return the first prediction
261
+ return preds_ds[0]
262
+ elif isinstance(protac_smiles, pd.DataFrame):
263
+ # If the input was a DataFrame, we return a dataframe with the predictions
264
+ return preds_ds.to_pandas()
265
+ elif isinstance(protac_smiles, list):
266
+ # Convert the Dataset to a list of dictionaries
267
+ return [row for row in preds_ds]
268
+
269
+ # if tokenizer is None:
270
+ # if verbose:
271
+ # print(f"Loading tokenizer...")
272
+ # tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
273
+
274
+ # if pipe is None:
275
+ # if verbose:
276
+ # print("Loading pipeline for \"default\" predictions...")
277
+ # pipe = pipeline(
278
+ # "text2text-generation",
279
+ # model=model_name,
280
+ # tokenizer=tokenizer,
281
+ # device="cuda" if torch.cuda.is_available() else "cpu",
282
+ # token=hf_token,
283
+ # num_return_sequences=beam_size,
284
+ # )
285
+
286
+ # if isinstance(protac_smiles, str):
287
+ # protac_smiles_canon = canonize(protac_smiles)
288
+ # if protac_smiles_canon is None:
289
+ # raise ValueError(f"Invalid PROTAC SMILES: {protac_smiles}")
290
+ # pred = pipe(protac_smiles_canon)
291
+ # pred = {f"default_pred_n{i}": pred[i]["generated_text"] for i in range(len(pred))}
292
+ # if fix_predictions:
293
+ # p_fixed = {k: fix_prediction(protac_smiles_canon, v, verbose=verbose) for k, v in pred.items()}
294
+ # # For each prediction, if the fixed prediction is not None, we
295
+ # # replace the original prediction with the fixed one.
296
+ # for k, v in p_fixed.items():
297
+ # if v is not None:
298
+ # pred[k] = v
299
+ # preds = [pred]
300
+
301
+ # if isinstance(protac_smiles, list):
302
+ # # Canonize and check if all PROTAC SMILES are valid
303
+ # protac_smiles_canon = [canonize(protac) for protac in protac_smiles]
304
+ # if None in protac_smiles_canon:
305
+ # wrong_protacs = [protac for protac, canon in zip(protac_smiles, protac_smiles_canon) if canon is None]
306
+ # raise ValueError(f"Invalid PROTAC SMILES in list: {wrong_protacs}")
307
+
308
+ # # Get the predictions for all PROTAC SMILES
309
+ # preds = pipe(protac_smiles_canon, batch_size=batch_size)
310
+ # preds = [{f"default_pred_n{i}": p["generated_text"] for i, p in enumerate(pred)} for pred in preds]
311
+
312
+ # if fix_predictions:
313
+ # for i, (protac, pred) in enumerate(zip(protac_smiles_canon, preds)):
314
+ # p_fixed = {k: fix_prediction(protac, v, verbose=verbose) for k, v in pred.items()}
315
+ # # For each prediction, if the fixed prediction is not None, we
316
+ # # replace the original prediction with the fixed one.
317
+ # for k, v in p_fixed.items():
318
+ # if v is not None:
319
+ # preds[i][k] = v
320
+
321
+ # if isinstance(protac_smiles, pd.DataFrame):
322
+ # # Check if the DataFrame contains a columns named `protac_smiles_col`
323
+ # if protac_smiles_col not in protac_smiles.columns:
324
+ # raise ValueError(f"DataFrame must contain a column named \"{protac_smiles_col}\".")
325
+
326
+ # # Canonize and check if all PROTAC SMILES are valid
327
+ # protac_smiles_canon = protac_smiles.apply(lambda x: canonize(x[protac_smiles_col]), axis=1)
328
+
329
+ # # Check if there are invalid PROTAC SMILES
330
+ # if protac_smiles_canon.isnull().any():
331
+ # wrong_protacs = protac_smiles[protac_smiles_canon.isnull()]
332
+ # raise ValueError(f"Invalid PROTAC SMILES in DataFrame: {wrong_protacs}")
333
+
334
+ # # Convert the Series to a DataFrame
335
+ # protac_smiles_canon = pd.DataFrame(protac_smiles_canon, columns=[protac_smiles_col])
336
+
337
+ # # Convert the DataFrame to a Dataset
338
+ # dataset = Dataset.from_pandas(protac_smiles_canon)
339
+ # preds = []
340
+ # for pred in tqdm(pipe(KeyDataset(dataset, protac_smiles_col), batch_size=batch_size), total=len(dataset) // batch_size, desc="Generating predictions"):
341
+ # p = {f"default_pred_n{i}": pred[i]["generated_text"] for i in range(len(pred))}
342
+ # preds.append(p)
343
+
344
+ # if fix_predictions:
345
+ # for i, (protac, pred) in tqdm(enumerate(zip(protac_smiles_canon, preds)), desc="Fixing predictions", total=len(preds)):
346
+ # p_fixed = {k: fix_prediction(protac, v, verbose=verbose) for k, v in pred.items()}
347
+ # # For each prediction, if the fixed prediction is not None, we
348
+ # # replace the original prediction with the fixed one.
349
+ # for k, v in p_fixed.items():
350
+ # if v is not None:
351
+ # pred[k] = v
352
+
353
+ # if return_check_reassembly:
354
+ # if isinstance(protac_smiles_canon, str):
355
+ # protac_smiles_list = [protac_smiles_canon]
356
+ # elif isinstance(protac_smiles_canon, list):
357
+ # protac_smiles_list = protac_smiles_canon
358
+ # elif isinstance(protac_smiles_canon, pd.DataFrame):
359
+ # protac_smiles_list = protac_smiles_canon[protac_smiles_col].tolist()
360
+
361
+ # print("Checking re-assembly...")
362
+ # for protac, pred in zip(protac_smiles_list, preds):
363
+ # for i in range(beam_size):
364
+ # pred[f"reassembly_correct_n{i}"] = check_reassembly(protac, pred[f"default_pred_n{i}"])
365
+
366
+ # # Just take the first prediction if the input was a string
367
+ # if isinstance(protac_smiles, str):
368
+ # preds = preds[0]
369
+
370
+ # return preds
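
For context before the app code below, here is a minimal usage sketch of the `split_protac` cascade implemented above. It is illustrative only: it assumes the package and its model dependencies are installed locally, and it reuses the example PROTAC SMILES from the app's input placeholder.

```python
from protac_splitter import split_protac

# Example PROTAC taken from the app's input placeholder below.
protac = "CC(C)(C)S(=O)(=O)c1cc2c(Nc3ccc4scnc4c3)ccnc2cc1OCCOCCOCCOCCOCC(=O)Nc1cccc2c1CN(C1CCC(=O)NC1=O)C2=O"

# XGBoost-only split (the app's default): fast graph-based edge classification.
result = split_protac(
    protac,
    use_transformer=False,
    use_xgboost=True,
    fix_predictions=True,
)

# For a single input string, the result is a dict holding the input SMILES,
# the dot-separated fragment SMILES, and the model that produced the split.
print(result["model_name"])
print(result["default_pred_n0"])
```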
protac_splitter_app.py ADDED
@@ -0,0 +1,351 @@
+ """
+ PROTAC Splitter Web Application
+
+ This script provides a web interface for splitting PROTAC molecules into their
+ constituent parts: E3 ligase binder, linker, and protein-of-interest (POI)
+ ligand (warhead).
+
+ The app uses the protac_splitter library to perform the splitting and offers
+ two main modes of operation:
+ 1. Single SMILES processing
+ 2. Batch processing via CSV file upload
+
+ Users can select which models to use:
+ - XGBoost model (default): Fast graph-based edge classification model
+ - Transformer model: More accurate but slower deep learning model
+ - If neither is selected, a rule-based splitting algorithm is used
+
+ Author: Stefano Ribes
+ Date: 2025-06
+ """
+
+ import logging
+ import tempfile
+ from pathlib import Path
+ from typing import Union
+
+ from PIL import Image
+ import gradio as gr
+ import pandas as pd
+ from rdkit import Chem
+ from rdkit.Chem import Draw
+
+ from protac_splitter import split_protac
+ from protac_splitter.display_utils import get_mapped_protac_img
+
+ def save_svg_to_tempfile(svg_string: str, suffix: str = ".svg") -> Union[str, Path]:
+     """
+     Write an SVG string to a temporary file and return its filesystem path.
+     """
+     # Create a named temporary file that persists after closing
+     tmp_file = tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding="utf-8")
+     logging.debug(f"Saving SVG to temporary file: {tmp_file.name}")
+     try:
+         tmp_file.write(svg_string)
+         tmp_file.flush()
+         return Path(tmp_file.name)
+     finally:
+         tmp_file.close()
+
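+ # Illustrative usage of the helper above (hypothetical SVG string); the app
+ # only calls it when `use_svg` is enabled in `process_single_smiles` below:
+ #   svg_path = save_svg_to_tempfile("<svg xmlns='http://www.w3.org/2000/svg'></svg>")
+ #   logging.debug(f"SVG written to {svg_path}")
+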
+ def process_single_smiles(protac_smiles: str, use_transformer: bool = False, use_xgboost: bool = True, beam_size: int = 5) -> tuple:
+     """
+     Process a single SMILES string and generate PROTAC fragment predictions.
+
+     Args:
+         protac_smiles: The SMILES string of the PROTAC molecule
+         use_transformer: Whether to use the transformer model for prediction
+         use_xgboost: Whether to use the XGBoost model for prediction
+         beam_size: Beam search width used by the Transformer model
+
+     Returns:
+         Tuple containing the input image, the fragment images, and the fragment SMILES text
+     """
+     if not protac_smiles:
+         raise gr.Error("Please provide a valid PROTAC SMILES string.", duration=5)
+
+     try:
+         results = split_protac(
+             protac_smiles,
+             use_transformer=use_transformer,
+             use_xgboost=use_xgboost,
+             fix_predictions=True,  # Always apply fixes to predictions
+             beam_size=beam_size,  # Beam search width for the Transformer model
+             verbose=1
+         )
+     except Exception as e:
+         exception_message = str(e)
+         if exception_message.startswith("Invalid PROTAC SMILES"):
+             raise gr.Error("The input SMILES string is not valid (couldn't be parsed by RDKit).", duration=5)
+         else:
+             raise gr.Error(f"An error occurred while processing the input SMILES: {exception_message}", duration=10)
+
+     valid_molecules = []
+     pred_key = 'default_pred_n0'
+     valid_molecules.append(results[pred_key])
+
+     # Generate images and corresponding SMILES text
+     images = []
+     smiles_texts = []
+     input_mol = Chem.MolFromSmiles(protac_smiles)
+
+     if input_mol is not None:
+         input_img = Draw.MolToImage(input_mol, legend="", size=(1000, 200))
+     else:
+         input_img = Image.new('RGB', (1000, 1000))
+
+     splits = {}
+     for smiles in results[pred_key].split("."):
+         mol = Chem.MolFromSmiles(smiles)
+         if mol:
+             legend = "Fragment"  # Fallback label if no attachment points are found
+             if "[*:1]" in smiles and "[*:2]" in smiles:
+                 legend = "Linker"
+                 splits['linker'] = smiles
+             elif "[*:1]" in smiles:
+                 legend = "Warhead"
+                 splits['poi'] = smiles
+             elif "[*:2]" in smiles:
+                 legend = "E3 Ligase Ligand"
+                 splits['e3'] = smiles
+
+             img = Draw.MolToImage(mol, legend="", size=(1000, 1000))
+             images.append(img)
+             smiles_texts.append(f"{legend}: {smiles}")
+     smiles_texts = "\n".join(smiles_texts)
+
+     use_svg = False
+     input_img = get_mapped_protac_img(
+         protac_smiles=protac_smiles,
+         poi_smiles=splits.get('poi', ''),
+         linker_smiles=splits.get('linker', ''),
+         e3_smiles=splits.get('e3', ''),
+         w=1000,
+         h=500,
+         legend=None,
+         useSVG=use_svg,
+     )
+
+     if use_svg:
+         input_img = save_svg_to_tempfile(input_img)
+         logging.debug(f"Returning processed image path: {input_img}")
+
+     return input_img, list(images), smiles_texts
+
+ def process_csv(
+     file: gr.File,
+     smiles_col: str,
+     use_transformer: bool = False,
+     use_xgboost: bool = True,
+     beam_size: int = 5,
+     batch_size: int = 4,
+     num_proc: int = 2,
+     # NOTE: `pr` is a progress tracker. It is not used directly in this
+     # function, but Gradio needs it to track progress. Do not remove it.
+     pr: gr.Progress = gr.Progress(track_tqdm=True),
+ ) -> Path:
+     """
+     Process a CSV file containing PROTAC SMILES.
+
+     Args:
+         file: Uploaded CSV file
+         smiles_col: Name of the column containing SMILES strings
+         use_transformer: Whether to use the transformer model for prediction
+         use_xgboost: Whether to use the XGBoost model for prediction
+         beam_size: Beam search width used by the Transformer model
+         batch_size: Batch size used when generating predictions
+         num_proc: Number of processes used for parallel processing
+
+     Returns:
+         Path to the output CSV file with predictions
+     """
+     df = pd.read_csv(file.name)
+     if smiles_col not in df.columns:
+         # Use Gradio's error message instead of raising a plain exception
+         raise gr.Error(f"Column \"{smiles_col}\" is not in the provided CSV file.", duration=5)
+
+     try:
+         results = split_protac(
+             df,
+             use_transformer=use_transformer,
+             use_xgboost=use_xgboost,
+             protac_smiles_col=smiles_col,
+             fix_predictions=True,
+             batch_size=batch_size,
+             num_proc=num_proc,
+             beam_size=beam_size,  # Beam search width for the Transformer model
+             verbose=1
+         )
+     except Exception as e:
+         exception_message = str(e)
+         if exception_message.startswith("Invalid PROTAC SMILES"):
+             raise gr.Error("One or more of the input SMILES are not valid (couldn't be parsed by RDKit).", duration=5)
+         else:
+             raise gr.Error(f"An error occurred while processing: {exception_message}", duration=10)
+
+     output_df = pd.DataFrame(results)
+
+     # Create a temporary output file
+     output_file = Path(tempfile.gettempdir()) / "split_preds.csv"
+     logging.debug(f"Saving predictions to temporary file: {output_file}")
+     output_df.to_csv(output_file, index=False)
+     logging.debug(f"Output DataFrame saved to: {output_file}")
+
+     return output_file
+
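+ # Illustrative, non-Gradio use of the same batch path (hypothetical CSV and
+ # column name; `split_protac` also accepts a DataFrame plus `protac_smiles_col`):
+ #   df = pd.read_csv("protacs.csv")
+ #   results = split_protac(df, protac_smiles_col="PROTAC SMILES", use_xgboost=True)
+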
+ def create_interface():
+     """
+     Create and return the Gradio interface for the PROTAC Splitter app.
+
+     The interface includes two tabs:
+     1. Single SMILES Input - For processing individual PROTAC SMILES
+     2. CSV Upload - For batch processing of multiple PROTAC SMILES
+
+     Returns:
+         gr.Blocks: The Gradio interface
+     """
+     with gr.Blocks() as demo:
+         header = """# PROTAC-Splitter Web Application
+
+ Upload a CSV file or enter a single SMILES string to predict PROTAC substructures.
+
+ The connections of the warhead and the E3 ligase ligand to the linker are marked with dummy atoms, _i.e._, attachment points, as follows:
+
+ - Warhead: `[*:1]`
+ - E3 Ligase ligand: `[*:2]`
+
+ """
+         gr.Markdown(header)
+
+         # Model selection section - common to both tabs
+         model_selection = """## Model Selection
+
+ You can choose which model to use for splitting PROTAC molecules:
+
+ - **XGBoost model** (default): Fast graph-based edge classification model
+ - **Transformer model**: More accurate but slower deep learning model
+ - If both are selected, the Transformer model is used first; if it fails, the XGBoost model is used as a fallback.
+ - If no model is selected, splitting is done with graph-based heuristics, with no AI model involved.
+
+ For fast splitting, we recommend using the XGBoost model only, which is efficient in most cases. The Transformer model may be more accurate, but it is slower, especially when processing large CSV files.
+ """
+         gr.Markdown(model_selection)
+         with gr.Row():
+             with gr.Column(scale=2):
+                 with gr.Row():
+                     use_xgboost = gr.Checkbox(label="Use XGBoost model", value=True)
+                     use_transformer = gr.Checkbox(label="Use Transformer model", value=False)
+
+                 # Performance configuration section
+                 performance_configs = """### Performance Configurations
+
+ Change the following parameters to optimize performance based on your machine's capabilities. This is particularly useful when processing large CSV files or when using the Transformer model.
+ For single SMILES processing, the default values should work well in most cases.
+ """
+                 gr.Markdown(performance_configs)
+             with gr.Column(scale=1):
+                 # Add a num_proc input
+                 with gr.Row():
+                     num_proc = gr.Number(
+                         label="Number of Processes",
+                         value=2,
+                         minimum=1,
+                         maximum=8,
+                         step=1,
+                         info="Number of processes to use for parallel processing. Higher values may improve performance but require more memory."
+                     )
+
+                 # Add a number input for beam_size if the Transformer model is selected
+                 with gr.Row():
+                     # Only show the beam size input if the Transformer model is selected
+                     beam_size = gr.Number(
+                         label="Beam Search Width",
+                         value=5,
+                         minimum=1,
+                         maximum=10,
+                         step=1,
+                         info="Width of the beam search for the Transformer model. Higher values may improve accuracy but increase processing time.",
+                         visible=use_transformer.value  # Hidden unless the Transformer model is selected
+                     )
+                 # Dynamically show/hide beam_size based on the Transformer model selection
+                 use_transformer.change(
+                     lambda x: gr.update(visible=x),
+                     inputs=[use_transformer],
+                     outputs=[beam_size]
+                 )
+
+                 # Add a batch size input for the Transformer model if selected
+                 with gr.Row():
+                     batch_size = gr.Number(
+                         label="Batch Size",
+                         value=4,
+                         minimum=1,
+                         maximum=64,
+                         step=1,
+                         info="Batch size for processing. Higher values may improve performance, especially on GPU machines, but require more memory.",
+                         visible=use_transformer.value  # Hidden unless the Transformer model is selected
+                     )
+                 use_transformer.change(
+                     lambda x: gr.update(visible=x),
+                     inputs=[use_transformer],
+                     outputs=[batch_size]
+                 )
+
+         # Single SMILES Input tab
+         gr.Markdown("## Specify Inputs")
+         with gr.Tab("Single SMILES Input"):
+             # Input area
+             smiles_input = gr.Textbox(
+                 label="Enter SMILES String",
+                 placeholder="E.g., CC(C)(C)S(=O)(=O)c1cc2c(Nc3ccc4scnc4c3)ccnc2cc1OCCOCCOCCOCCOCC(=O)Nc1cccc2c1CN(C1CCC(=O)NC1=O)C2=O",
+                 # value="CC(C)(C)S(=O)(=O)c1cc2c(Nc3ccc4scnc4c3)ccnc2cc1OCCOCCOCCOCCOCC(=O)Nc1cccc2c1CN(C1CCC(=O)NC1=O)C2=O",
+             )
+
+             submit_smiles = gr.Button("Process SMILES")
+
+             # Output area
+             smiles_input_image = gr.Image(label="Input PROTAC", type="filepath")  # Use type=None to allow SVG input
+             smiles_output_images = gr.Gallery(label="Valid Splits", columns=3)
+             smiles_output_texts = gr.Textbox(label="SMILES of the Splits", interactive=False, lines=3)
+
+             # Connect the button click event to the processing function
+             submit_smiles.click(
+                 process_single_smiles,
+                 inputs=[smiles_input, use_transformer, use_xgboost, beam_size],
+                 outputs=[smiles_input_image, smiles_output_images, smiles_output_texts]
+             )
+
+         # CSV file processing tab
+         with gr.Tab("Upload CSV"):
+             # File upload area
+             file_input = gr.File(label="Upload CSV File")
+             smiles_column = gr.Textbox(
+                 label="Column Name for PROTAC SMILES",
+                 placeholder="E.g., \"PROTAC SMILES\"",
+                 # value="PROTAC SMILES",
+             )
+             submit_csv = gr.Button("Process CSV")
+
+             # Output file download area
+             download_output = gr.File(label="Download Predictions")
+
+             # Connect the button click event to the processing function
+             submit_csv.click(
+                 process_csv,
+                 inputs=[file_input, smiles_column, use_transformer, use_xgboost, beam_size, batch_size, num_proc],
+                 outputs=[download_output]
+             )
+
+             # NOTE: Use a plain string here. An f-string would interpolate the
+             # `smiles_column` component object itself, not the column name.
+             csv_notes = """**Note:** The output CSV will contain the following columns:
+
+ - The original PROTAC SMILES column, as named in the uploaded file
+ - `default_pred_n0`: The predicted SMILES strings for the splits
+ - `model_name`: The model used for the prediction
+ """
+             gr.Markdown(csv_notes)
+
+     return demo
+
+ # Create the Gradio interface
+ # NOTE: `demo` must be a global variable for Gradio's hot-reload system to work.
+ # NOTE: Launch the app with `gradio scripts/protac_splitter_app.py` to develop it.
+ demo = create_interface()
+
+ if __name__ == "__main__":
+     # Set logging level to DEBUG for detailed output
+     logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+     demo.launch()
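
As a companion to `process_single_smiles` above, the sketch below shows how a dot-separated split string can be classified by its attachment points (`[*:1]` for the warhead, `[*:2]` for the E3 ligase ligand, both for the linker). The example split string is hypothetical and for illustration only.

```python
from typing import Dict

from rdkit import Chem


def classify_fragments(split_smiles: str) -> Dict[str, str]:
    """Label each fragment of a dot-separated split by its attachment points."""
    labels = {}
    for smiles in split_smiles.split("."):
        if Chem.MolFromSmiles(smiles) is None:
            continue  # Skip fragments that RDKit cannot parse
        if "[*:1]" in smiles and "[*:2]" in smiles:
            labels["linker"] = smiles  # Bridges both attachment points
        elif "[*:1]" in smiles:
            labels["poi"] = smiles  # Warhead (POI ligand)
        elif "[*:2]" in smiles:
            labels["e3"] = smiles  # E3 ligase ligand
    return labels


# Hypothetical split string, for illustration only:
print(classify_fragments("[*:1]Nc1ccccc1.[*:1]CCOCC[*:2].[*:2]C1CCNC1"))
```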
requirements.txt ADDED
@@ -0,0 +1,138 @@
+ accelerate==1.3.0
+ aiofiles==24.1.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.12.13
+ aiosignal==1.3.2
+ alembic==1.16.2
+ annotated-types==0.7.0
+ anyio==4.9.0
+ asttokens==3.0.0
+ attrs==25.3.0
+ cairocffi==1.7.1
+ CairoSVG==2.8.2
+ certifi==2025.6.15
+ cffi==1.17.1
+ charset-normalizer==3.4.2
+ click==8.2.1
+ colorlog==6.9.0
+ contourpy==1.3.2
+ cssselect2==0.8.0
+ cycler==0.12.1
+ datasets==3.0.0
+ decorator==5.2.1
+ defusedxml==0.7.1
+ dill==0.3.8
+ docstring_parser==0.16
+ evaluate==0.4.3
+ executing==2.2.0
+ fastapi==0.115.14
+ ffmpy==0.6.0
+ filelock==3.18.0
+ fonttools==4.58.4
+ frozenlist==1.7.0
+ fsspec==2024.6.1
+ gradio==5.35.0
+ gradio_client==1.10.4
+ groovy==0.1.2
+ h11==0.16.0
+ hf-xet==1.1.5
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.33.1
+ idna==3.10
+ imbalanced-learn==0.13.0
+ imblearn==0.0
+ iniconfig==2.1.0
+ ipython==9.4.0
+ ipython_pygments_lexers==1.1.1
+ jedi==0.19.2
+ Jinja2==3.1.6
+ joblib==1.5.1
+ jsonargparse==4.40.0
+ kiwisolver==1.4.8
+ lightning-utilities==0.14.3
+ llvmlite==0.44.0
+ Mako==1.3.10
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ matplotlib==3.10.3
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.6.3
+ multiprocess==0.70.16
+ networkx==3.1
+ numba==0.61.0
+ numpy==1.26.4
+ optuna==4.2.0
+ ordered-set==4.1.0
+ orjson==3.10.18
+ packaging==25.0
+ pandas==2.2.2
+ parso==0.8.4
+ pexpect==4.9.0
+ pillow==11.3.0
+ pluggy==1.6.0
+ prompt_toolkit==3.0.51
+ propcache==0.3.2
+ psutil==7.0.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ pyarrow==20.0.0
+ pycparser==2.22
+ pydantic==2.11.7
+ pydantic_core==2.33.2
+ pydub==0.25.1
+ Pygments==2.19.2
+ PyLaTeX==1.4.2
+ pyparsing==3.2.3
+ pytest==8.4.1
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.20
+ pytz==2025.2
+ PyYAML==6.0.2
+ rdkit==2024.9.4
+ regex==2024.11.6
+ requests==2.32.4
+ rich==14.0.0
+ ruff==0.12.1
+ safehttpx==0.1.6
+ safetensors==0.5.3
+ scikit-learn==1.6.1
+ scipy==1.14.1
+ seaborn==0.13.2
+ semantic-version==2.10.0
+ setuptools==80.9.0
+ shellingham==1.5.4
+ shtab==1.7.2
+ six==1.17.0
+ sklearn-compat==0.1.3
+ sniffio==1.3.1
+ SQLAlchemy==2.0.41
+ stack-data==0.6.3
+ starlette==0.46.2
+ sympy==1.13.1
+ threadpoolctl==3.6.0
+ tinycss2==1.4.0
+ tokenizers==0.19.1
+ tomlkit==0.13.3
+ torch==2.6.0
+ torchmetrics==1.7.3
+ tqdm==4.67.1
+ traitlets==5.14.3
+ transformers==4.44.2
+ trl==0.10.1
+ typeguard==4.4.4
+ typer==0.16.0
+ typing-inspection==0.4.1
+ typing_extensions==4.14.0
+ tyro==0.9.25
+ tzdata==2025.2
+ urllib3==2.5.0
+ uvicorn==0.35.0
+ wcwidth==0.2.13
+ webencodings==0.5.1
+ websockets==15.0.1
+ xgboost==3.0.1
+ xxhash==3.5.0
+ yarl==1.20.1