Joey Callanan
committed on
Commit
·
a3863ea
1
Parent(s):
0f4a31d
minor changes
Browse files- .vscode/settings.json +5 -0
- Gen_PartialSMILES2.py +24 -7
- Join.py +100 -0
- src/molecules/generated_variations.py +52 -16
- src/ui/handlers.py +68 -106
.vscode/settings.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
| 3 |
+
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
| 4 |
+
"python-envs.pythonProjects": []
|
| 5 |
+
}
|
Gen_PartialSMILES2.py
CHANGED
|
@@ -222,9 +222,6 @@ def path_aligned_generation(
|
|
| 222 |
str_print += f" n_invalid {n_invalid:05d}"
|
| 223 |
# str_print += f" n_supressed_eos {n_supressed_eos:05d}"
|
| 224 |
print(str_print)
|
| 225 |
-
# logger.info(str_print)
|
| 226 |
-
# print(f"Iteration {iteration_counter:05d} step {step_idx:05d} merged total {total_merge_count:05d} current {count_merged:05d} dict_prefix {len(dict_path_inchikey):05d} dict_inch {len(dict_inchikey_merged_path):05d} eos {tensor_generation.shape[0]-n_eos_tokens:05d} current {tensor_generation.shape[0]:05d} generated {len(generated_smiles):08d} n_calls {n_calls:05d} n_repeated {n_repeated:05d}")
|
| 227 |
-
# get generated smiles and remove the merged prefixes
|
| 228 |
iteration_counter += 1
|
| 229 |
total_merge_count += count_merged
|
| 230 |
return generated_smiles, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated
|
|
@@ -250,10 +247,12 @@ parser.add_argument("--max_rotatable_bond", type=int, default=8)
|
|
| 250 |
parser.add_argument("--min_prefix_length", type=int, default=4)
|
| 251 |
parser.add_argument("--top_p", type=float, default=1.0)
|
| 252 |
parser.add_argument("--top_k", type=int, default=10)
|
|
|
|
|
|
|
| 253 |
# list of decode methods
|
| 254 |
parser.add_argument("--decode_methods", type=str, default="Structure-Aware_Decoding")
|
| 255 |
args = parser.parse_args()
|
| 256 |
-
|
| 257 |
pathlib.Path(args.save_dir).mkdir(parents=True, exist_ok=True)
|
| 258 |
# device = torch.device("cuda:0")
|
| 259 |
device = torch.device("cpu")
|
|
@@ -274,7 +273,10 @@ model.to(device)
|
|
| 274 |
model.eval()
|
| 275 |
budget_generation = 10
|
| 276 |
batch_size = 512
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
| 278 |
if len(scaf_smi) > 0:
|
| 279 |
if "[*]" not in scaf_smi:
|
| 280 |
raise ValueError("Scaffold does not contain attachment point")
|
|
@@ -298,10 +300,25 @@ torch.backends.cudnn.deterministic = True
|
|
| 298 |
torch.backends.cudnn.benchmark = False
|
| 299 |
|
| 300 |
n_to_gen = args.n_to_gen
|
| 301 |
-
generated_smiles_raw, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated = path_aligned_generation(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
generated_smiles = dict([(smiles.split("<can>")[-1], freq) for smiles, freq in generated_smiles_raw.items()])
|
| 303 |
|
| 304 |
pd.DataFrame({
|
| 305 |
"smiles": list(generated_smiles.keys()),
|
| 306 |
"count": list(generated_smiles.values())
|
| 307 |
-
}).to_csv("generated_molecules.csv", index=False)
|
|
|
|
| 222 |
str_print += f" n_invalid {n_invalid:05d}"
|
| 223 |
# str_print += f" n_supressed_eos {n_supressed_eos:05d}"
|
| 224 |
print(str_print)
|
|
|
|
|
|
|
|
|
|
| 225 |
iteration_counter += 1
|
| 226 |
total_merge_count += count_merged
|
| 227 |
return generated_smiles, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated
|
|
|
|
| 247 |
parser.add_argument("--min_prefix_length", type=int, default=4)
|
| 248 |
parser.add_argument("--top_p", type=float, default=1.0)
|
| 249 |
parser.add_argument("--top_k", type=int, default=10)
|
| 250 |
+
# NEW: scaffold passed from Gradio UI
|
| 251 |
+
parser.add_argument("--scaffold", type=str, default="[*]c1ccccc1")
|
| 252 |
# list of decode methods
|
| 253 |
parser.add_argument("--decode_methods", type=str, default="Structure-Aware_Decoding")
|
| 254 |
args = parser.parse_args()
|
| 255 |
+
|
| 256 |
pathlib.Path(args.save_dir).mkdir(parents=True, exist_ok=True)
|
| 257 |
# device = torch.device("cuda:0")
|
| 258 |
device = torch.device("cpu")
|
|
|
|
| 273 |
model.eval()
|
| 274 |
budget_generation = 10
|
| 275 |
batch_size = 512
|
| 276 |
+
|
| 277 |
+
# Use scaffold from CLI args
|
| 278 |
+
scaf_smi = args.scaffold
|
| 279 |
+
|
| 280 |
if len(scaf_smi) > 0:
|
| 281 |
if "[*]" not in scaf_smi:
|
| 282 |
raise ValueError("Scaffold does not contain attachment point")
|
|
|
|
| 300 |
torch.backends.cudnn.benchmark = False
|
| 301 |
|
| 302 |
n_to_gen = args.n_to_gen
|
| 303 |
+
generated_smiles_raw, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated = path_aligned_generation(
|
| 304 |
+
model,
|
| 305 |
+
tokenizer=tokenizer,
|
| 306 |
+
max_length=args.max_length,
|
| 307 |
+
n_generation=n_to_gen,
|
| 308 |
+
batch_size=batch_size,
|
| 309 |
+
device=device,
|
| 310 |
+
tensor_scaffold=tensor_scaffold,
|
| 311 |
+
boundary=boundary,
|
| 312 |
+
budget_generation=budget_generation,
|
| 313 |
+
max_molwt=args.max_molwt,
|
| 314 |
+
max_clogp=args.max_clogp,
|
| 315 |
+
max_rotatable_bond=args.max_rotatable_bond,
|
| 316 |
+
use_merge=True,
|
| 317 |
+
min_prefix_length=args.min_prefix_length
|
| 318 |
+
)
|
| 319 |
generated_smiles = dict([(smiles.split("<can>")[-1], freq) for smiles, freq in generated_smiles_raw.items()])
|
| 320 |
|
| 321 |
pd.DataFrame({
|
| 322 |
"smiles": list(generated_smiles.keys()),
|
| 323 |
"count": list(generated_smiles.values())
|
| 324 |
+
}).to_csv("generated_molecules.csv", index=False)
|
Join.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rdkit import Chem
|
| 2 |
+
import re
|
| 3 |
+
import random
|
| 4 |
+
# suppress Python-level warnings (intended to quiet RDKit)
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings("ignore")
|
| 7 |
+
|
| 8 |
+
ATTACHMENT_POINT_TOKEN = "*"
# Matches a numbered attachment point such as "[*:3]"; group 1 captures the number.
ATTACHMENT_POINT_NUM_REGEXP = r"\[{}:(\d+)\]".format(re.escape(ATTACHMENT_POINT_TOKEN))
# Matches any attachment point: a bare "*" or a bracketed one ("[*]", "[*:3]", ...).
ATTACHMENT_POINT_REGEXP = r"(?:{0}|\[{0}[^\]]*\])".format(re.escape(ATTACHMENT_POINT_TOKEN))
# Matches a "*" that is not immediately preceded by "[".
ATTACHMENT_POINT_NO_BRACKETS_REGEXP = r"(?<!\[){}".format(re.escape(ATTACHMENT_POINT_TOKEN))
# Example SMILES with an attachment point: [*][C@H]1C[C@@H](N)C1


def add_attachment_point_numbers(mol_or_smi, canonicalize=True):
    """
    Number every attachment point in a SMILES string, left to right.

    Each attachment point (bare ``*`` or any bracketed ``[*...]``) is rewritten
    as ``[*:N]`` where N starts at 0 and follows the order of appearance in the
    (optionally canonicalized) SMILES text. Any existing numbering is replaced.

    :param mol_or_smi: SMILES string to annotate.
    :param canonicalize: When True, the SMILES is first round-tripped through
        RDKit to obtain the canonical isomeric form, so numbering follows the
        canonical atom order.
    :return: The SMILES string with numbered attachment points.
    :raises ValueError: If ``canonicalize`` is True and RDKit cannot parse the input.
    """
    smi = mol_or_smi
    if canonicalize:
        mol = Chem.MolFromSmiles(mol_or_smi)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail with a
            # readable message instead of an opaque argument error in MolToSmiles.
            raise ValueError("Cannot parse SMILES: {!r}".format(mol_or_smi))
        smi = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)

    # Only add numbers ordered by the SMILES ordering.
    num = -1

    def _ap_callback(_):
        nonlocal num
        num += 1
        return "[{}:{}]".format(ATTACHMENT_POINT_TOKEN, num)

    return re.sub(ATTACHMENT_POINT_REGEXP, _ap_callback, smi)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def remove_attachment_point_numbers(smi):
    """Strip the numbering from attachment points, turning each ``[*:N]`` into ``[*]``."""
    unnumbered = "[{}]".format(ATTACHMENT_POINT_TOKEN)
    numbered_pattern = re.compile(ATTACHMENT_POINT_NUM_REGEXP)
    return numbered_pattern.sub(unnumbered, smi)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _add_attachment_point_num(atom, idx):
    """Label ``atom`` with attachment point number ``idx`` as its atom-map number."""
    atom.SetAtomMapNum(int(idx))


def join(scaffold_smi, decoration_smi, keep_label_on_atoms=False, invert_chiralty=False):
    """
    Join a scaffold and a decoration at their shared numbered attachment point.

    The decoration must carry exactly one attachment point (``*`` atom) with an
    atom-map number; the scaffold must carry an attachment point with the same
    number. Both dummy atoms are removed and their neighbours are bonded
    directly (double bond if either dummy was double-bonded, single otherwise).

    :param scaffold_smi: SMILES of the scaffold, with numbered attachment points.
    :param decoration_smi: SMILES of the decoration, with one numbered attachment point.
    :param keep_label_on_atoms: When True, the two newly bonded atoms are tagged
        with the attachment point number as their atom-map number.
    :param invert_chiralty: When True, the chirality of the second joined atom is
        inverted (presumably the scaffold-side atom, since the decoration comes
        first in the combined molecule -- confirm against RDKit's CombineMols).
    :return: The sanitized joined RDKit molecule, or None if either SMILES is
        invalid, the attachment points do not match up, or sanitization fails.
    """
    scaffold = Chem.MolFromSmiles(scaffold_smi)
    decoration = Chem.MolFromSmiles(decoration_smi)

    if not (scaffold and decoration):
        return None

    # Obtain the attachment point id carried by the decoration.
    try:
        attachment_points = [atom.GetProp("molAtomMapNumber") for atom in decoration.GetAtoms()
                             if atom.GetSymbol() == ATTACHMENT_POINT_TOKEN]
    except KeyError:
        # An attachment point without a map number cannot be matched.
        return None
    if len(attachment_points) != 1:
        return None  # the decoration must have exactly one attachment point
    attachment_point = attachment_points[0]

    combined_scaffold = Chem.RWMol(Chem.CombineMols(decoration, scaffold))
    attachments = [atom for atom in combined_scaffold.GetAtoms()
                   if atom.GetSymbol() == ATTACHMENT_POINT_TOKEN and
                   atom.HasProp("molAtomMapNumber") and
                   atom.GetProp("molAtomMapNumber") == attachment_point]
    if len(attachments) != 2:
        return None  # expected one matching dummy atom on each side

    neighbors = []
    for atom in attachments:
        if atom.GetDegree() != 1:
            return None  # the attachment is wrongly generated
        neighbors.append(atom.GetNeighbors()[0])

    # The new bond is double if either dummy atom was attached by a double bond.
    bonds = [atom.GetBonds()[0] for atom in attachments]
    bond_type = Chem.BondType.SINGLE
    if any(bond.GetBondType() == Chem.BondType.DOUBLE for bond in bonds):
        bond_type = Chem.BondType.DOUBLE

    combined_scaffold.AddBond(neighbors[0].GetIdx(), neighbors[1].GetIdx(), bond_type)
    # GetIdx() is re-read after the first removal because indices shift.
    combined_scaffold.RemoveAtom(attachments[0].GetIdx())
    combined_scaffold.RemoveAtom(attachments[1].GetIdx())

    if invert_chiralty:
        neighbors[1].InvertChirality()
    if keep_label_on_atoms:
        for neigh in neighbors:
            _add_attachment_point_num(neigh, attachment_point)

    scaffold = combined_scaffold.GetMol()
    try:
        Chem.SanitizeMol(scaffold)
    except ValueError:  # sanitization error
        return None

    return scaffold
|
| 80 |
+
|
| 81 |
+
def join_scaf_deco(scaffold='O=C1NN=C([*])c2c1cccc2',decorator='[*]N1CCN(C)CC1',Parameter_InvertChiralty=False):
    """
    Join a scaffold SMILES and a decorator SMILES into one canonical SMILES.

    Attachment points in both inputs are numbered, the two fragments are joined
    with ``join``, and any remaining attachment point numbers are stripped from
    the resulting canonical isomeric SMILES.

    :param scaffold: Scaffold SMILES containing an attachment point.
    :param decorator: Decorator SMILES containing an attachment point.
    :param Parameter_InvertChiralty: Forwarded to ``join`` as ``invert_chiralty``.
    :return: The joined SMILES, or an empty string if the inputs cannot be
        parsed or joined.
    """
    try:
        smiles_scaffold = add_attachment_point_numbers(scaffold)
        smiles_decorator = add_attachment_point_numbers(decorator)
        joined_mol = join(smiles_scaffold, smiles_decorator, invert_chiralty=Parameter_InvertChiralty)
        if joined_mol is None:
            # join() reports every failure mode by returning None.
            return ''
        smiles_joined = Chem.MolToSmiles(joined_mol, isomericSmiles=True, canonical=True)
        return remove_attachment_point_numbers(smiles_joined)
    except Exception:
        # Best-effort API: any parsing/joining error yields an empty result, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return ''
|
| 92 |
+
|
| 93 |
+
# print results to the terminal for testing
|
| 94 |
+
if __name__ == "__main__":
    # Manual smoke test: join the example scaffold/decorator pair and print the result.
    scaffold, decorator = 'O=C1NN=C([*])c2c1cccc2', '[*]N1CCN(C)CC1'
    print("Scaffold: ", scaffold)
    print("Decorator:", decorator)
    joined = join_scaf_deco(scaffold, decorator, Parameter_InvertChiralty=True)
    print("Joined: ", joined)
|
src/molecules/generated_variations.py
CHANGED
|
@@ -1,29 +1,65 @@
|
|
|
|
|
|
|
|
| 1 |
import subprocess
|
|
|
|
|
|
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
from rdkit import Chem
|
| 4 |
from rdkit.Chem import Draw
|
| 5 |
|
| 6 |
-
|
|
|
|
| 7 |
"""
|
| 8 |
-
|
| 9 |
-
and
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"--n_to_gen", str(n_to_gen)
|
| 17 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
for smi in df["smiles"].head(n_to_gen):
|
| 24 |
mol = Chem.MolFromSmiles(smi)
|
| 25 |
-
if mol:
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app/src/molecules/generated_variations.py
|
| 2 |
+
|
| 3 |
import subprocess
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
import pandas as pd
|
| 8 |
from rdkit import Chem
|
| 9 |
from rdkit.Chem import Draw
|
| 10 |
|
| 11 |
+
|
| 12 |
+
def generate_variations_from_partial_smiles(scaffold_smiles: str, n_to_gen: int = 12):
    """
    Generate molecule variations for a scaffold by running Gen_PartialSMILES2.py.

    The script is run as a subprocess from the project root with the scaffold
    and the requested count; it is expected to write ``generated_molecules.csv``
    there, which is then read back and turned into drawable variations.

    :param scaffold_smiles: Scaffold SMILES supplied by the user.
    :param n_to_gen: Maximum number of variations to return.
    :return: A list of dicts ``{"smiles": str, "image": PIL.Image, "style": str}``;
        empty if the scaffold is blank, the subprocess fails, or no CSV is produced.
    """
    if not scaffold_smiles or not scaffold_smiles.strip():
        return []
    scaffold_smiles = scaffold_smiles.strip()
    # UI widgets may hand over a float; the CLI flag and head() need an int.
    n_to_gen = int(n_to_gen)

    # Determine project root (where Gen_PartialSMILES2.py lives).
    # This file is app/src/molecules/generated_variations.py, so
    # parents[0] = .../molecules, parents[1] = .../src, parents[2] = .../app
    project_root = Path(__file__).resolve().parents[2]
    script_path = project_root / "Gen_PartialSMILES2.py"
    csv_path = project_root / "generated_molecules.csv"

    # Remove any stale CSV so a failed run cannot be mistaken for fresh output.
    if csv_path.exists():
        csv_path.unlink()

    cmd = [
        sys.executable,
        str(script_path),
        "--scaffold", scaffold_smiles,
        "--n_to_gen", str(n_to_gen),
    ]

    try:
        subprocess.run(cmd, cwd=project_root, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Error running Gen_PartialSMILES2.py: {e}")
        return []

    if not csv_path.exists():
        print("generated_molecules.csv not found after generation.")
        return []

    df = pd.read_csv(csv_path)

    variations = []
    # Blank SMILES cells are read back as NaN, which RDKit cannot accept; drop them.
    for smi in df["smiles"].dropna().astype(str).head(n_to_gen):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        img = Draw.MolToImage(mol, size=(250, 250))
        variations.append({
            "smiles": smi,
            "image": img,
            "style": "partial_smiles_gen"
        })

    return variations
|
src/ui/handlers.py
CHANGED
|
@@ -6,9 +6,9 @@ for the drug discovery application UI components.
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from ..molecules.analysis import analyze_molecule_image_only, validate_smiles_realtime, get_molecule_properties_for_hover
|
| 9 |
-
from ..molecules.variations import
|
|
|
|
| 10 |
from ..ai.services import respond, handle_structure_chat, parse_ai_structures
|
| 11 |
-
from ..molecules.generated_variations import generate_variations_from_model
|
| 12 |
|
| 13 |
|
| 14 |
class VariationHandlers:
|
|
@@ -20,93 +20,61 @@ class VariationHandlers:
|
|
| 20 |
self.variations_per_page = 12
|
| 21 |
|
| 22 |
def generate_variations_for_display(self, smiles, num_variations=12):
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
print(f"
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
print(f"Variation {i}: {var.get('style', 'Unknown')}, image type: {type(var.get('image', None))}")
|
| 38 |
-
gallery_items.append((var['image'], f"Style: {var['style']}"))
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
|
|
|
|
| 45 |
|
| 46 |
def select_variation(self, evt):
|
| 47 |
"""Handle selection of a variation from the grid."""
|
| 48 |
try:
|
| 49 |
-
print(
|
| 50 |
print(f"Event: {evt}, type: {type(evt)}")
|
| 51 |
print(f"Current variations count: {len(self.current_variations)}")
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
print("Event is None, trying to return first variation")
|
| 56 |
-
if self.current_variations:
|
| 57 |
-
selected_var = self.current_variations[0]
|
| 58 |
-
print(f"Using first variation: {selected_var.get('style', 'Unknown')}")
|
| 59 |
-
properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
|
| 60 |
-
return selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text
|
| 61 |
-
else:
|
| 62 |
-
print("No variations available, returning empty")
|
| 63 |
-
return None, "", "", ""
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
if
|
|
|
|
|
|
|
| 67 |
index = evt.index
|
| 68 |
elif isinstance(evt, (int, float)):
|
| 69 |
index = int(evt)
|
| 70 |
else:
|
| 71 |
-
|
| 72 |
-
# Try to return first variation as fallback
|
| 73 |
-
if self.current_variations:
|
| 74 |
-
selected_var = self.current_variations[0]
|
| 75 |
-
properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
|
| 76 |
-
return selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text
|
| 77 |
-
return None, "", "", ""
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
print(f"No variations available or index {index} out of range (total: {len(self.current_variations)})")
|
| 83 |
-
# Try to return first variation as fallback
|
| 84 |
-
if self.current_variations:
|
| 85 |
-
selected_var = self.current_variations[0]
|
| 86 |
-
properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
|
| 87 |
-
return selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text
|
| 88 |
-
return None, "", "", ""
|
| 89 |
|
| 90 |
selected_var = self.current_variations[index]
|
| 91 |
-
print(f"Selected variation {index}: {selected_var.get('style', 'Unknown')}")
|
| 92 |
-
print(f"Selected variation image type: {type(selected_var['image'])}")
|
| 93 |
-
print(f"Selected variation SMILES: {selected_var['smiles']}")
|
| 94 |
|
| 95 |
-
# Also update properties for the selected variation
|
| 96 |
-
print(f"Getting properties for SMILES: {selected_var['smiles']}")
|
| 97 |
properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
|
| 98 |
-
print(f"Properties text length: {len(properties_text) if properties_text else 'None'}")
|
| 99 |
-
print(f"Properties text preview: {properties_text[:100] if properties_text else 'None'}...")
|
| 100 |
-
|
| 101 |
-
result = (selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text)
|
| 102 |
-
print(f"Returning result: {len(result)} items")
|
| 103 |
-
print(f"Image type: {type(result[0])}")
|
| 104 |
-
print(f"SMILES: {result[1]}")
|
| 105 |
-
print(f"Style: {result[2]}")
|
| 106 |
-
print(f"Properties length: {len(result[3]) if result[3] else 'None'}")
|
| 107 |
-
print(f"=== SELECT_VARIATION COMPLETE ===")
|
| 108 |
|
| 109 |
-
return
|
| 110 |
except Exception as e:
|
| 111 |
print(f"Error in select_variation: {e}")
|
| 112 |
import traceback
|
|
@@ -136,17 +104,21 @@ class VariationHandlers:
|
|
| 136 |
end_idx = min(start_idx + self.variations_per_page, len(self.current_variations))
|
| 137 |
page_variations = self.current_variations[start_idx:end_idx]
|
| 138 |
|
| 139 |
-
|
| 140 |
-
gallery_items = []
|
| 141 |
-
for var in page_variations:
|
| 142 |
-
gallery_items.append((var['image'], f"Style: {var['style']}"))
|
| 143 |
-
|
| 144 |
page_info = f"Page {self.current_page + 1} of {total_pages}"
|
| 145 |
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
def update_variation_count(self, count):
|
| 149 |
-
"""Update the number of variations
|
| 150 |
self.variations_per_page = count
|
| 151 |
return count
|
| 152 |
|
|
@@ -154,9 +126,6 @@ class VariationHandlers:
|
|
| 154 |
"""Analyze molecule and return image with tooltip data."""
|
| 155 |
molecule_img = analyze_molecule_image_only(smiles)
|
| 156 |
tooltip_text = get_molecule_properties_for_hover(smiles)
|
| 157 |
-
|
| 158 |
-
# For now, we'll return the image and tooltip text separately
|
| 159 |
-
# The tooltip will be handled by JavaScript or CSS
|
| 160 |
return molecule_img, tooltip_text
|
| 161 |
|
| 162 |
|
|
@@ -171,48 +140,39 @@ class BookmarkHandlers:
|
|
| 171 |
from rdkit import Chem
|
| 172 |
from rdkit.Chem import Draw
|
| 173 |
|
| 174 |
-
# Validate SMILES first
|
| 175 |
mol = Chem.MolFromSmiles(smiles)
|
| 176 |
if not mol:
|
| 177 |
return "❌ Invalid SMILES string - cannot bookmark"
|
| 178 |
|
| 179 |
-
# Check if already bookmarked
|
| 180 |
if smiles in [bm['smiles'] for bm in self.bookmarked_molecules]:
|
| 181 |
return "⚠️ Molecule already bookmarked"
|
| 182 |
|
| 183 |
-
# Generate a name if not provided
|
| 184 |
if not molecule_name:
|
| 185 |
molecule_name = f"Bookmarked_{len(self.bookmarked_molecules) + 1}"
|
| 186 |
|
| 187 |
-
# Add to bookmarks
|
| 188 |
self.bookmarked_molecules.append({
|
| 189 |
'smiles': smiles,
|
| 190 |
'name': molecule_name,
|
| 191 |
-
'timestamp': len(self.bookmarked_molecules) + 1
|
| 192 |
})
|
| 193 |
|
| 194 |
return f"✅ Bookmarked: {molecule_name}"
|
| 195 |
|
| 196 |
def get_bookmarked_molecules(self):
|
| 197 |
-
"""Get all bookmarked molecules for display."""
|
| 198 |
return self.bookmarked_molecules
|
| 199 |
|
| 200 |
def remove_bookmark(self, smiles):
|
| 201 |
-
"""Remove a molecule from bookmarks."""
|
| 202 |
self.bookmarked_molecules = [bm for bm in self.bookmarked_molecules if bm['smiles'] != smiles]
|
| 203 |
return "🗑️ Removed from bookmarks"
|
| 204 |
|
| 205 |
def bookmark_current_molecule(self, smiles, name):
|
| 206 |
-
"""Bookmark current molecule and update gallery."""
|
| 207 |
from rdkit import Chem
|
| 208 |
from rdkit.Chem import Draw
|
| 209 |
|
| 210 |
result = self.bookmark_molecule(smiles, name)
|
| 211 |
-
# Update the bookmarked gallery
|
| 212 |
bookmarked_mols = self.get_bookmarked_molecules()
|
| 213 |
gallery_items = []
|
| 214 |
for mol in bookmarked_mols:
|
| 215 |
-
# Generate smaller images for gallery
|
| 216 |
mol_obj = Chem.MolFromSmiles(mol['smiles'])
|
| 217 |
if mol_obj:
|
| 218 |
img = Draw.MolToImage(mol_obj, size=(150, 150), kekulize=True)
|
|
@@ -231,37 +191,39 @@ class AIHandler:
|
|
| 231 |
if not message.strip() or not hf_token.strip():
|
| 232 |
return history, []
|
| 233 |
|
| 234 |
-
# Add user message to history
|
| 235 |
history.append({"role": "user", "content": message})
|
| 236 |
|
| 237 |
-
# Determine if this is a structure generation request
|
| 238 |
structure_keywords = ['generate', 'create', 'modify', 'derivative', 'variant', 'structure']
|
| 239 |
is_structure_request = any(keyword in message.lower() for keyword in structure_keywords)
|
| 240 |
|
| 241 |
if is_structure_request and selected_smiles:
|
| 242 |
-
# Handle structure generation
|
| 243 |
ai_response = ""
|
| 244 |
-
for chunk in respond(
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
ai_response = chunk
|
| 248 |
|
| 249 |
-
# Add AI response to history
|
| 250 |
history.append({"role": "assistant", "content": ai_response})
|
| 251 |
-
|
| 252 |
-
# Parse and generate structure images
|
| 253 |
structures = parse_ai_structures(ai_response, selected_smiles)
|
| 254 |
-
|
| 255 |
return history, structures
|
| 256 |
else:
|
| 257 |
-
# Handle general drug discovery questions
|
| 258 |
ai_response = ""
|
| 259 |
-
for chunk in respond(
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
ai_response = chunk
|
| 263 |
|
| 264 |
-
# Add AI response to history
|
| 265 |
history.append({"role": "assistant", "content": ai_response})
|
| 266 |
-
|
| 267 |
-
return history, []
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from ..molecules.analysis import analyze_molecule_image_only, validate_smiles_realtime, get_molecule_properties_for_hover
|
| 9 |
+
from ..molecules.variations import generate_molecule_images
|
| 10 |
+
from ..molecules.generated_variations import generate_variations_from_partial_smiles
|
| 11 |
from ..ai.services import respond, handle_structure_chat, parse_ai_structures
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class VariationHandlers:
|
|
|
|
| 20 |
self.variations_per_page = 12
|
| 21 |
|
| 22 |
def generate_variations_for_display(self, smiles, num_variations=12):
    """
    Produce variations for ``smiles`` with the subprocess-based generator and
    shape them for the gallery.

    The generated variations are stored on ``self.current_variations`` for
    later selection/navigation. Returns a 3-tuple: the gallery items as
    ``(image, smiles)`` pairs, the input SMILES, and the style of the first
    variation (``"None"`` when nothing was generated).
    """
    print("=== GENERATE_VARIATIONS_FOR_DISPLAY CALLED ===")
    print(f"SMILES input: {smiles}")
    print(f"Num variations requested: {num_variations}")

    # Delegate to the subprocess-based generator.
    generated = generate_variations_from_partial_smiles(smiles, n_to_gen=num_variations)
    print(f"Generated {len(generated)} variations from partial SMILES model")

    # Keep the results around for selection/navigation.
    self.current_variations = generated

    # Gradio Gallery expects [(image, caption), ...]
    gallery_items = []
    for entry in generated:
        gallery_items.append((entry["image"], entry["smiles"]))

    # Style to return (for hidden display)
    if generated:
        first_style = generated[0]["style"]
    else:
        first_style = "None"

    print("=== GENERATE_VARIATIONS_FOR_DISPLAY COMPLETE ===")

    # outputs: variations_grid, selected_smiles_display, selected_style_display
    return gallery_items, smiles, first_style
|
| 48 |
|
| 49 |
def select_variation(self, evt):
|
| 50 |
"""Handle selection of a variation from the grid."""
|
| 51 |
try:
|
| 52 |
+
print("=== SELECT_VARIATION CALLED ===")
|
| 53 |
print(f"Event: {evt}, type: {type(evt)}")
|
| 54 |
print(f"Current variations count: {len(self.current_variations)}")
|
| 55 |
|
| 56 |
+
if not self.current_variations:
|
| 57 |
+
return None, "", "", ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
# If event is None (e.g. change without select), default to first
|
| 60 |
+
if evt is None:
|
| 61 |
+
index = 0
|
| 62 |
+
elif hasattr(evt, 'index'):
|
| 63 |
index = evt.index
|
| 64 |
elif isinstance(evt, (int, float)):
|
| 65 |
index = int(evt)
|
| 66 |
else:
|
| 67 |
+
index = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
# Clamp index
|
| 70 |
+
if index < 0 or index >= len(self.current_variations):
|
| 71 |
+
index = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
selected_var = self.current_variations[index]
|
|
|
|
|
|
|
|
|
|
| 74 |
|
|
|
|
|
|
|
| 75 |
properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
return selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text
|
| 78 |
except Exception as e:
|
| 79 |
print(f"Error in select_variation: {e}")
|
| 80 |
import traceback
|
|
|
|
| 104 |
end_idx = min(start_idx + self.variations_per_page, len(self.current_variations))
|
| 105 |
page_variations = self.current_variations[start_idx:end_idx]
|
| 106 |
|
| 107 |
+
gallery_items = [(v["image"], v["smiles"]) for v in page_variations]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
page_info = f"Page {self.current_page + 1} of {total_pages}"
|
| 109 |
|
| 110 |
+
first = page_variations[0] if page_variations else None
|
| 111 |
+
|
| 112 |
+
return (
|
| 113 |
+
gallery_items,
|
| 114 |
+
page_info,
|
| 115 |
+
first['image'] if first else None,
|
| 116 |
+
first['smiles'] if first else "",
|
| 117 |
+
first['style'] if first else ""
|
| 118 |
+
)
|
| 119 |
|
| 120 |
def update_variation_count(self, count):
    """Store ``count`` as the number of variations shown per page and return it."""
    self.variations_per_page = count
    return count
|
| 124 |
|
|
|
|
| 126 |
"""Analyze molecule and return image with tooltip data."""
|
| 127 |
molecule_img = analyze_molecule_image_only(smiles)
|
| 128 |
tooltip_text = get_molecule_properties_for_hover(smiles)
|
|
|
|
|
|
|
|
|
|
| 129 |
return molecule_img, tooltip_text
|
| 130 |
|
| 131 |
|
|
|
|
| 140 |
from rdkit import Chem
|
| 141 |
from rdkit.Chem import Draw
|
| 142 |
|
|
|
|
| 143 |
mol = Chem.MolFromSmiles(smiles)
|
| 144 |
if not mol:
|
| 145 |
return "❌ Invalid SMILES string - cannot bookmark"
|
| 146 |
|
|
|
|
| 147 |
if smiles in [bm['smiles'] for bm in self.bookmarked_molecules]:
|
| 148 |
return "⚠️ Molecule already bookmarked"
|
| 149 |
|
|
|
|
| 150 |
if not molecule_name:
|
| 151 |
molecule_name = f"Bookmarked_{len(self.bookmarked_molecules) + 1}"
|
| 152 |
|
|
|
|
| 153 |
self.bookmarked_molecules.append({
|
| 154 |
'smiles': smiles,
|
| 155 |
'name': molecule_name,
|
| 156 |
+
'timestamp': len(self.bookmarked_molecules) + 1
|
| 157 |
})
|
| 158 |
|
| 159 |
return f"✅ Bookmarked: {molecule_name}"
|
| 160 |
|
| 161 |
def get_bookmarked_molecules(self):
    """Return the stored list of bookmarked molecules (the list itself, not a copy)."""
    bookmarks = self.bookmarked_molecules
    return bookmarks
|
| 163 |
|
| 164 |
def remove_bookmark(self, smiles):
    """Drop every bookmark whose SMILES equals ``smiles`` and return a status message."""
    kept = []
    for bookmark in self.bookmarked_molecules:
        if bookmark['smiles'] != smiles:
            kept.append(bookmark)
    self.bookmarked_molecules = kept
    return "🗑️ Removed from bookmarks"
|
| 167 |
|
| 168 |
def bookmark_current_molecule(self, smiles, name):
|
|
|
|
| 169 |
from rdkit import Chem
|
| 170 |
from rdkit.Chem import Draw
|
| 171 |
|
| 172 |
result = self.bookmark_molecule(smiles, name)
|
|
|
|
| 173 |
bookmarked_mols = self.get_bookmarked_molecules()
|
| 174 |
gallery_items = []
|
| 175 |
for mol in bookmarked_mols:
|
|
|
|
| 176 |
mol_obj = Chem.MolFromSmiles(mol['smiles'])
|
| 177 |
if mol_obj:
|
| 178 |
img = Draw.MolToImage(mol_obj, size=(150, 150), kekulize=True)
|
|
|
|
| 191 |
if not message.strip() or not hf_token.strip():
|
| 192 |
return history, []
|
| 193 |
|
|
|
|
| 194 |
history.append({"role": "user", "content": message})
|
| 195 |
|
|
|
|
| 196 |
structure_keywords = ['generate', 'create', 'modify', 'derivative', 'variant', 'structure']
|
| 197 |
is_structure_request = any(keyword in message.lower() for keyword in structure_keywords)
|
| 198 |
|
| 199 |
if is_structure_request and selected_smiles:
|
|
|
|
| 200 |
ai_response = ""
|
| 201 |
+
for chunk in respond(
|
| 202 |
+
message,
|
| 203 |
+
history[:-1],
|
| 204 |
+
"You are an expert medicinal chemist. Generate new chemical structures based on user requests.",
|
| 205 |
+
512,
|
| 206 |
+
temperature,
|
| 207 |
+
0.9,
|
| 208 |
+
hf_token
|
| 209 |
+
):
|
| 210 |
ai_response = chunk
|
| 211 |
|
|
|
|
| 212 |
history.append({"role": "assistant", "content": ai_response})
|
|
|
|
|
|
|
| 213 |
structures = parse_ai_structures(ai_response, selected_smiles)
|
|
|
|
| 214 |
return history, structures
|
| 215 |
else:
|
|
|
|
| 216 |
ai_response = ""
|
| 217 |
+
for chunk in respond(
|
| 218 |
+
message,
|
| 219 |
+
history[:-1],
|
| 220 |
+
"You are an expert medicinal chemist and drug discovery specialist. Help with molecular analysis, drug design, and medicinal chemistry questions.",
|
| 221 |
+
512,
|
| 222 |
+
temperature,
|
| 223 |
+
0.9,
|
| 224 |
+
hf_token
|
| 225 |
+
):
|
| 226 |
ai_response = chunk
|
| 227 |
|
|
|
|
| 228 |
history.append({"role": "assistant", "content": ai_response})
|
| 229 |
+
return history, []
|
|
|