Joey Callanan commited on
Commit
a3863ea
·
1 Parent(s): 0f4a31d

minor changes

Browse files
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "python-envs.defaultEnvManager": "ms-python.python:conda",
3
+ "python-envs.defaultPackageManager": "ms-python.python:conda",
4
+ "python-envs.pythonProjects": []
5
+ }
Gen_PartialSMILES2.py CHANGED
@@ -222,9 +222,6 @@ def path_aligned_generation(
222
  str_print += f" n_invalid {n_invalid:05d}"
223
  # str_print += f" n_supressed_eos {n_supressed_eos:05d}"
224
  print(str_print)
225
- # logger.info(str_print)
226
- # print(f"Iteration {iteration_counter:05d} step {step_idx:05d} merged total {total_merge_count:05d} current {count_merged:05d} dict_prefix {len(dict_path_inchikey):05d} dict_inch {len(dict_inchikey_merged_path):05d} eos {tensor_generation.shape[0]-n_eos_tokens:05d} current {tensor_generation.shape[0]:05d} generated {len(generated_smiles):08d} n_calls {n_calls:05d} n_repeated {n_repeated:05d}")
227
- # get generated smiles and remove the merged prefixes
228
  iteration_counter += 1
229
  total_merge_count += count_merged
230
  return generated_smiles, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated
@@ -250,10 +247,12 @@ parser.add_argument("--max_rotatable_bond", type=int, default=8)
250
  parser.add_argument("--min_prefix_length", type=int, default=4)
251
  parser.add_argument("--top_p", type=float, default=1.0)
252
  parser.add_argument("--top_k", type=int, default=10)
 
 
253
  # list of decode methods
254
  parser.add_argument("--decode_methods", type=str, default="Structure-Aware_Decoding")
255
  args = parser.parse_args()
256
- # example: python PTS_Generate.py --save_dir "entropy/gpt2_zinc_87m" --model_name "gpt2_zinc_87m" --generate_mode "scaffold_decorator" --filepath_scaffold "scaf_5.smi" --model_path "" --decode_methods "Structure-Aware_Decoding"
257
  pathlib.Path(args.save_dir).mkdir(parents=True, exist_ok=True)
258
  # device = torch.device("cuda:0")
259
  device = torch.device("cpu")
@@ -274,7 +273,10 @@ model.to(device)
274
  model.eval()
275
  budget_generation = 10
276
  batch_size = 512
277
- scaf_smi = "[*]c1ccccc1"
 
 
 
278
  if len(scaf_smi) > 0:
279
  if "[*]" not in scaf_smi:
280
  raise ValueError("Scaffold does not contain attachment point")
@@ -298,10 +300,25 @@ torch.backends.cudnn.deterministic = True
298
  torch.backends.cudnn.benchmark = False
299
 
300
  n_to_gen = args.n_to_gen
301
- generated_smiles_raw, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated = path_aligned_generation(model,tokenizer=tokenizer,max_length=args.max_length,n_generation=n_to_gen,batch_size=batch_size,device=device,tensor_scaffold=tensor_scaffold,boundary=boundary,budget_generation=budget_generation,max_molwt=args.max_molwt,max_clogp=args.max_clogp,max_rotatable_bond=args.max_rotatable_bond,use_merge=True,min_prefix_length=args.min_prefix_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  generated_smiles = dict([(smiles.split("<can>")[-1], freq) for smiles, freq in generated_smiles_raw.items()])
303
 
304
  pd.DataFrame({
305
  "smiles": list(generated_smiles.keys()),
306
  "count": list(generated_smiles.values())
307
- }).to_csv("generated_molecules.csv", index=False)
 
222
  str_print += f" n_invalid {n_invalid:05d}"
223
  # str_print += f" n_supressed_eos {n_supressed_eos:05d}"
224
  print(str_print)
 
 
 
225
  iteration_counter += 1
226
  total_merge_count += count_merged
227
  return generated_smiles, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated
 
247
  parser.add_argument("--min_prefix_length", type=int, default=4)
248
  parser.add_argument("--top_p", type=float, default=1.0)
249
  parser.add_argument("--top_k", type=int, default=10)
250
+ # NEW: scaffold passed from Gradio UI
251
+ parser.add_argument("--scaffold", type=str, default="[*]c1ccccc1")
252
  # list of decode methods
253
  parser.add_argument("--decode_methods", type=str, default="Structure-Aware_Decoding")
254
  args = parser.parse_args()
255
+
256
  pathlib.Path(args.save_dir).mkdir(parents=True, exist_ok=True)
257
  # device = torch.device("cuda:0")
258
  device = torch.device("cpu")
 
273
  model.eval()
274
  budget_generation = 10
275
  batch_size = 512
276
+
277
+ # Use scaffold from CLI args
278
+ scaf_smi = args.scaffold
279
+
280
  if len(scaf_smi) > 0:
281
  if "[*]" not in scaf_smi:
282
  raise ValueError("Scaffold does not contain attachment point")
 
300
  torch.backends.cudnn.benchmark = False
301
 
302
  n_to_gen = args.n_to_gen
303
+ generated_smiles_raw, dict_inchikey_merged_path, dict_inchikey_count, dict_path_inchikey, total_merge_count, n_calls, n_repeated = path_aligned_generation(
304
+ model,
305
+ tokenizer=tokenizer,
306
+ max_length=args.max_length,
307
+ n_generation=n_to_gen,
308
+ batch_size=batch_size,
309
+ device=device,
310
+ tensor_scaffold=tensor_scaffold,
311
+ boundary=boundary,
312
+ budget_generation=budget_generation,
313
+ max_molwt=args.max_molwt,
314
+ max_clogp=args.max_clogp,
315
+ max_rotatable_bond=args.max_rotatable_bond,
316
+ use_merge=True,
317
+ min_prefix_length=args.min_prefix_length
318
+ )
319
  generated_smiles = dict([(smiles.split("<can>")[-1], freq) for smiles, freq in generated_smiles_raw.items()])
320
 
321
  pd.DataFrame({
322
  "smiles": list(generated_smiles.keys()),
323
  "count": list(generated_smiles.values())
324
+ }).to_csv("generated_molecules.csv", index=False)
Join.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rdkit import Chem
2
+ import re
3
+ import random
4
+ # supress rdkit warnings
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ ATTACHMENT_POINT_TOKEN = "*"
9
+ ATTACHMENT_POINT_NUM_REGEXP = r"\[{}:(\d+)\]".format(re.escape(ATTACHMENT_POINT_TOKEN))
10
+ ATTACHMENT_POINT_REGEXP = r"(?:{0}|\[{0}[^\]]*\])".format(re.escape(ATTACHMENT_POINT_TOKEN))
11
+ ATTACHMENT_POINT_NO_BRACKETS_REGEXP = r"(?<!\[){}".format(re.escape(ATTACHMENT_POINT_TOKEN))
12
+ # "[*][C@H]1C[C@@H](N)C1
13
+
14
+ def add_attachment_point_numbers(mol_or_smi, canonicalize=True):
15
+ smi = mol_or_smi
16
+ if canonicalize:
17
+ smi = Chem.MolToSmiles(Chem.MolFromSmiles(mol_or_smi), isomericSmiles=True, canonical=True)
18
+ # only add numbers ordered by the SMILES ordering
19
+ num = -1
20
+ def _ap_callback(_):
21
+ nonlocal num
22
+ num += 1
23
+ return "[{}:{}]".format(ATTACHMENT_POINT_TOKEN, num)
24
+ return re.sub(ATTACHMENT_POINT_REGEXP, _ap_callback, smi)
25
+
26
+
27
+
28
+ def remove_attachment_point_numbers(smi):
29
+ return re.sub(ATTACHMENT_POINT_NUM_REGEXP, "[{}]".format(ATTACHMENT_POINT_TOKEN), smi)
30
+
31
+
32
+
33
+
34
+ def join(scaffold_smi, decoration_smi, keep_label_on_atoms=False,invert_chiralty=False):
35
+ scaffold = Chem.MolFromSmiles(scaffold_smi)
36
+ decoration = Chem.MolFromSmiles(decoration_smi)
37
+
38
+ if scaffold and decoration:
39
+ # obtain id in the decoration
40
+ try:
41
+ attachment_points = [atom.GetProp("molAtomMapNumber") for atom in decoration.GetAtoms()
42
+ if atom.GetSymbol() == ATTACHMENT_POINT_TOKEN]
43
+ if len(attachment_points) != 1:
44
+ return None # more than one attachment point...
45
+ attachment_point = attachment_points[0]
46
+ except KeyError:
47
+ return None
48
+ combined_scaffold = Chem.RWMol(Chem.CombineMols(decoration, scaffold))
49
+ attachments = [atom for atom in combined_scaffold.GetAtoms()
50
+ if atom.GetSymbol() == ATTACHMENT_POINT_TOKEN and
51
+ atom.HasProp("molAtomMapNumber") and atom.GetProp("molAtomMapNumber") == attachment_point]
52
+ if len(attachments) != 2:
53
+ return None # something weird
54
+ neighbors = []
55
+ for atom in attachments:
56
+ if atom.GetDegree() != 1:
57
+ return None # the attachment is wrongly generated
58
+ neighbors.append(atom.GetNeighbors()[0])
59
+ bonds = [atom.GetBonds()[0] for atom in attachments]
60
+ bond_type = Chem.BondType.SINGLE
61
+ if any(bond for bond in bonds if bond.GetBondType() == Chem.BondType.DOUBLE):
62
+ bond_type = Chem.BondType.DOUBLE
63
+ combined_scaffold.AddBond(neighbors[0].GetIdx(), neighbors[1].GetIdx(), bond_type)
64
+ combined_scaffold.RemoveAtom(attachments[0].GetIdx())
65
+ combined_scaffold.RemoveAtom(attachments[1].GetIdx())
66
+ if invert_chiralty:
67
+ neighbors[1].InvertChirality()
68
+ if keep_label_on_atoms:
69
+ for neigh in neighbors:
70
+ _add_attachment_point_num(neigh, attachment_point)
71
+
72
+ scaffold = combined_scaffold.GetMol()
73
+ try:
74
+ Chem.SanitizeMol(scaffold)
75
+ except ValueError: # sanitization error
76
+ return None
77
+ else:
78
+ return None
79
+ return scaffold
80
+
81
+ def join_scaf_deco(scaffold='O=C1NN=C([*])c2c1cccc2',decorator='[*]N1CCN(C)CC1',Parameter_InvertChiralty=False):
82
+ try:
83
+ # smiles_scaffold = remove_attachment_point_numbers(scaffold)
84
+ # smiles_decorator = remove_attachment_point_numbers(decorator)
85
+ smiles_scaffold = add_attachment_point_numbers(scaffold)
86
+ smiles_decorator = add_attachment_point_numbers(decorator)
87
+ smiles_joined = Chem.MolToSmiles(join(smiles_scaffold,smiles_decorator,invert_chiralty=Parameter_InvertChiralty), isomericSmiles=True, canonical=True)
88
+ smiles_joined = remove_attachment_point_numbers(smiles_joined)
89
+ return smiles_joined
90
+ except:
91
+ return ''
92
+
93
+ # print results to the terminal for testing
94
+ if __name__ == "__main__":
95
+ scaffold = 'O=C1NN=C([*])c2c1cccc2'
96
+ decorator = '[*]N1CCN(C)CC1'
97
+ print("Scaffold: ", scaffold)
98
+ print("Decorator:", decorator)
99
+ joined = join_scaf_deco(scaffold,decorator,Parameter_InvertChiralty=True)
100
+ print("Joined: ", joined)
src/molecules/generated_variations.py CHANGED
@@ -1,29 +1,65 @@
 
 
1
  import subprocess
 
 
 
2
  import pandas as pd
3
  from rdkit import Chem
4
  from rdkit.Chem import Draw
5
 
6
- def generate_variations_from_model(user_smiles, n_to_gen=12):
 
7
  """
8
- Runs Gen_SMILES2.py using the given user SMILES in place of scaf_smi
9
- and returns RDKit images.
 
10
  """
11
 
12
- # Run Gen_SMILES2.py with an argument
13
- subprocess.run([
14
- "python", "app/Gen_PartialSMILES2.py",
15
- "--scaffold", user_smiles,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  "--n_to_gen", str(n_to_gen)
17
- ], check=True)
 
 
 
 
 
 
18
 
19
- # Read the generated CSV
20
- df = pd.read_csv("generated_molecules.csv")
 
21
 
22
- images = []
 
 
23
  for smi in df["smiles"].head(n_to_gen):
24
  mol = Chem.MolFromSmiles(smi)
25
- if mol:
26
- img = Draw.MolToImage(mol, size=(250, 250))
27
- images.append({"smiles": smi, "image": img, "style": "generated"})
28
-
29
- return images
 
 
 
 
 
 
1
+ # app/src/molecules/generated_variations.py
2
+
3
  import subprocess
4
+ import sys
5
+ from pathlib import Path
6
+
7
  import pandas as pd
8
  from rdkit import Chem
9
  from rdkit.Chem import Draw
10
 
11
+
12
+ def generate_variations_from_partial_smiles(scaffold_smiles: str, n_to_gen: int = 12):
13
  """
14
+ Call Gen_PartialSMILES2.py as a subprocess, passing the user scaffold,
15
+ then read generated_molecules.csv and return a list of variations:
16
+ each item is a dict: {"smiles": str, "image": PIL.Image, "style": str}
17
  """
18
 
19
+ if not scaffold_smiles or scaffold_smiles.strip() == "":
20
+ return []
21
+
22
+ # Determine project root (where Gen_PartialSMILES2.py lives)
23
+ # This file is app/src/molecules/generated_variations.py
24
+ # parents[0] = .../molecules, parents[1] = .../src, parents[2] = .../app
25
+ project_root = Path(__file__).resolve().parents[2]
26
+ script_path = project_root / "Gen_PartialSMILES2.py"
27
+ csv_path = project_root / "generated_molecules.csv"
28
+
29
+ # Remove old CSV if it exists
30
+ if csv_path.exists():
31
+ csv_path.unlink()
32
+
33
+ # Build subprocess command
34
+ cmd = [
35
+ sys.executable,
36
+ str(script_path),
37
+ "--scaffold", scaffold_smiles,
38
  "--n_to_gen", str(n_to_gen)
39
+ ]
40
+
41
+ try:
42
+ subprocess.run(cmd, cwd=project_root, check=True)
43
+ except subprocess.CalledProcessError as e:
44
+ print(f"Error running Gen_PartialSMILES2.py: {e}")
45
+ return []
46
 
47
+ if not csv_path.exists():
48
+ print("generated_molecules.csv not found after generation.")
49
+ return []
50
 
51
+ df = pd.read_csv(csv_path)
52
+
53
+ variations = []
54
  for smi in df["smiles"].head(n_to_gen):
55
  mol = Chem.MolFromSmiles(smi)
56
+ if mol is None:
57
+ continue
58
+ img = Draw.MolToImage(mol, size=(250, 250))
59
+ variations.append({
60
+ "smiles": smi,
61
+ "image": img,
62
+ "style": "partial_smiles_gen"
63
+ })
64
+
65
+ return variations
src/ui/handlers.py CHANGED
@@ -6,9 +6,9 @@ for the drug discovery application UI components.
6
  """
7
 
8
  from ..molecules.analysis import analyze_molecule_image_only, validate_smiles_realtime, get_molecule_properties_for_hover
9
- from ..molecules.variations import generate_chemical_series_variations, generate_molecule_images
 
10
  from ..ai.services import respond, handle_structure_chat, parse_ai_structures
11
- from ..molecules.generated_variations import generate_variations_from_model
12
 
13
 
14
  class VariationHandlers:
@@ -20,93 +20,61 @@ class VariationHandlers:
20
  self.variations_per_page = 12
21
 
22
  def generate_variations_for_display(self, smiles, num_variations=12):
23
- """Generate variations and format for gallery display."""
24
- print(f"=== GENERATE_VARIATIONS_FOR_DISPLAY CALLED ===")
25
- print(f"SMILES: {smiles}")
26
- print(f"Num variations: {num_variations}")
27
-
28
- variations = generate_variations_from_model(smiles, num_variations)
29
- print(f"Generated {len(variations)} variations")
30
-
31
- self.current_variations = variations[:num_variations]
32
- print(f"Stored {len(self.current_variations)} variations in current_variations")
33
-
34
- # Format for gallery display
35
- gallery_items = []
36
- for i, var in enumerate(self.current_variations):
37
- print(f"Variation {i}: {var.get('style', 'Unknown')}, image type: {type(var.get('image', None))}")
38
- gallery_items.append((var['image'], f"Style: {var['style']}"))
39
 
40
- result = (gallery_items, smiles, self.current_variations[0]['style'] if self.current_variations else "None")
41
- print(f"Returning: {len(gallery_items)} gallery items, SMILES: {smiles}, style: {result[2]}")
42
- print(f"=== GENERATE_VARIATIONS_FOR_DISPLAY COMPLETE ===")
 
 
 
 
43
 
44
- return result
 
45
 
46
  def select_variation(self, evt):
47
  """Handle selection of a variation from the grid."""
48
  try:
49
- print(f"=== SELECT_VARIATION CALLED ===")
50
  print(f"Event: {evt}, type: {type(evt)}")
51
  print(f"Current variations count: {len(self.current_variations)}")
52
 
53
- # If event is None, try to get the first variation as default
54
- if evt is None:
55
- print("Event is None, trying to return first variation")
56
- if self.current_variations:
57
- selected_var = self.current_variations[0]
58
- print(f"Using first variation: {selected_var.get('style', 'Unknown')}")
59
- properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
60
- return selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text
61
- else:
62
- print("No variations available, returning empty")
63
- return None, "", "", ""
64
 
65
- # Handle both event object and direct index
66
- if hasattr(evt, 'index'):
 
 
67
  index = evt.index
68
  elif isinstance(evt, (int, float)):
69
  index = int(evt)
70
  else:
71
- print(f"Unexpected event type: {type(evt)}, value: {evt}")
72
- # Try to return first variation as fallback
73
- if self.current_variations:
74
- selected_var = self.current_variations[0]
75
- properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
76
- return selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text
77
- return None, "", "", ""
78
 
79
- print(f"Selected index: {index}")
80
-
81
- if not self.current_variations or index >= len(self.current_variations):
82
- print(f"No variations available or index {index} out of range (total: {len(self.current_variations)})")
83
- # Try to return first variation as fallback
84
- if self.current_variations:
85
- selected_var = self.current_variations[0]
86
- properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
87
- return selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text
88
- return None, "", "", ""
89
 
90
  selected_var = self.current_variations[index]
91
- print(f"Selected variation {index}: {selected_var.get('style', 'Unknown')}")
92
- print(f"Selected variation image type: {type(selected_var['image'])}")
93
- print(f"Selected variation SMILES: {selected_var['smiles']}")
94
 
95
- # Also update properties for the selected variation
96
- print(f"Getting properties for SMILES: {selected_var['smiles']}")
97
  properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
98
- print(f"Properties text length: {len(properties_text) if properties_text else 'None'}")
99
- print(f"Properties text preview: {properties_text[:100] if properties_text else 'None'}...")
100
-
101
- result = (selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text)
102
- print(f"Returning result: {len(result)} items")
103
- print(f"Image type: {type(result[0])}")
104
- print(f"SMILES: {result[1]}")
105
- print(f"Style: {result[2]}")
106
- print(f"Properties length: {len(result[3]) if result[3] else 'None'}")
107
- print(f"=== SELECT_VARIATION COMPLETE ===")
108
 
109
- return result
110
  except Exception as e:
111
  print(f"Error in select_variation: {e}")
112
  import traceback
@@ -136,17 +104,21 @@ class VariationHandlers:
136
  end_idx = min(start_idx + self.variations_per_page, len(self.current_variations))
137
  page_variations = self.current_variations[start_idx:end_idx]
138
 
139
- # Format for gallery display
140
- gallery_items = []
141
- for var in page_variations:
142
- gallery_items.append((var['image'], f"Style: {var['style']}"))
143
-
144
  page_info = f"Page {self.current_page + 1} of {total_pages}"
145
 
146
- return gallery_items, page_info, page_variations[0]['image'] if page_variations else None, page_variations[0]['smiles'] if page_variations else "", page_variations[0]['style'] if page_variations else ""
 
 
 
 
 
 
 
 
147
 
148
  def update_variation_count(self, count):
149
- """Update the number of variations to generate."""
150
  self.variations_per_page = count
151
  return count
152
 
@@ -154,9 +126,6 @@ class VariationHandlers:
154
  """Analyze molecule and return image with tooltip data."""
155
  molecule_img = analyze_molecule_image_only(smiles)
156
  tooltip_text = get_molecule_properties_for_hover(smiles)
157
-
158
- # For now, we'll return the image and tooltip text separately
159
- # The tooltip will be handled by JavaScript or CSS
160
  return molecule_img, tooltip_text
161
 
162
 
@@ -171,48 +140,39 @@ class BookmarkHandlers:
171
  from rdkit import Chem
172
  from rdkit.Chem import Draw
173
 
174
- # Validate SMILES first
175
  mol = Chem.MolFromSmiles(smiles)
176
  if not mol:
177
  return "❌ Invalid SMILES string - cannot bookmark"
178
 
179
- # Check if already bookmarked
180
  if smiles in [bm['smiles'] for bm in self.bookmarked_molecules]:
181
  return "⚠️ Molecule already bookmarked"
182
 
183
- # Generate a name if not provided
184
  if not molecule_name:
185
  molecule_name = f"Bookmarked_{len(self.bookmarked_molecules) + 1}"
186
 
187
- # Add to bookmarks
188
  self.bookmarked_molecules.append({
189
  'smiles': smiles,
190
  'name': molecule_name,
191
- 'timestamp': len(self.bookmarked_molecules) + 1 # Simple counter
192
  })
193
 
194
  return f"✅ Bookmarked: {molecule_name}"
195
 
196
  def get_bookmarked_molecules(self):
197
- """Get all bookmarked molecules for display."""
198
  return self.bookmarked_molecules
199
 
200
  def remove_bookmark(self, smiles):
201
- """Remove a molecule from bookmarks."""
202
  self.bookmarked_molecules = [bm for bm in self.bookmarked_molecules if bm['smiles'] != smiles]
203
  return "🗑️ Removed from bookmarks"
204
 
205
  def bookmark_current_molecule(self, smiles, name):
206
- """Bookmark current molecule and update gallery."""
207
  from rdkit import Chem
208
  from rdkit.Chem import Draw
209
 
210
  result = self.bookmark_molecule(smiles, name)
211
- # Update the bookmarked gallery
212
  bookmarked_mols = self.get_bookmarked_molecules()
213
  gallery_items = []
214
  for mol in bookmarked_mols:
215
- # Generate smaller images for gallery
216
  mol_obj = Chem.MolFromSmiles(mol['smiles'])
217
  if mol_obj:
218
  img = Draw.MolToImage(mol_obj, size=(150, 150), kekulize=True)
@@ -231,37 +191,39 @@ class AIHandler:
231
  if not message.strip() or not hf_token.strip():
232
  return history, []
233
 
234
- # Add user message to history
235
  history.append({"role": "user", "content": message})
236
 
237
- # Determine if this is a structure generation request
238
  structure_keywords = ['generate', 'create', 'modify', 'derivative', 'variant', 'structure']
239
  is_structure_request = any(keyword in message.lower() for keyword in structure_keywords)
240
 
241
  if is_structure_request and selected_smiles:
242
- # Handle structure generation
243
  ai_response = ""
244
- for chunk in respond(message, history[:-1],
245
- "You are an expert medicinal chemist. Generate new chemical structures based on user requests.",
246
- 512, temperature, 0.9, hf_token):
 
 
 
 
 
 
247
  ai_response = chunk
248
 
249
- # Add AI response to history
250
  history.append({"role": "assistant", "content": ai_response})
251
-
252
- # Parse and generate structure images
253
  structures = parse_ai_structures(ai_response, selected_smiles)
254
-
255
  return history, structures
256
  else:
257
- # Handle general drug discovery questions
258
  ai_response = ""
259
- for chunk in respond(message, history[:-1],
260
- "You are an expert medicinal chemist and drug discovery specialist. Help with molecular analysis, drug design, and medicinal chemistry questions.",
261
- 512, temperature, 0.9, hf_token):
 
 
 
 
 
 
262
  ai_response = chunk
263
 
264
- # Add AI response to history
265
  history.append({"role": "assistant", "content": ai_response})
266
-
267
- return history, []
 
6
  """
7
 
8
  from ..molecules.analysis import analyze_molecule_image_only, validate_smiles_realtime, get_molecule_properties_for_hover
9
+ from ..molecules.variations import generate_molecule_images
10
+ from ..molecules.generated_variations import generate_variations_from_partial_smiles
11
  from ..ai.services import respond, handle_structure_chat, parse_ai_structures
 
12
 
13
 
14
  class VariationHandlers:
 
20
  self.variations_per_page = 12
21
 
22
  def generate_variations_for_display(self, smiles, num_variations=12):
23
+ """
24
+ Generate variations using Gen_PartialSMILES2.py (via subprocess),
25
+ then format them for the gallery display.
26
+ """
27
+ print("=== GENERATE_VARIATIONS_FOR_DISPLAY CALLED ===")
28
+ print(f"SMILES input: {smiles}")
29
+ print(f"Num variations requested: {num_variations}")
30
+
31
+ # Call the subprocess-based generator
32
+ variations = generate_variations_from_partial_smiles(smiles, n_to_gen=num_variations)
33
+ print(f"Generated {len(variations)} variations from partial SMILES model")
34
+
35
+ # Store internally for selection/navigation
36
+ self.current_variations = variations
 
 
37
 
38
+ # Gradio Gallery expects [(image, caption), ...]
39
+ gallery_items = [(v["image"], v["smiles"]) for v in self.current_variations]
40
+
41
+ # Style to return (for hidden display)
42
+ first_style = self.current_variations[0]["style"] if self.current_variations else "None"
43
+
44
+ print("=== GENERATE_VARIATIONS_FOR_DISPLAY COMPLETE ===")
45
 
46
+ # outputs: variations_grid, selected_smiles_display, selected_style_display
47
+ return gallery_items, smiles, first_style
48
 
49
  def select_variation(self, evt):
50
  """Handle selection of a variation from the grid."""
51
  try:
52
+ print("=== SELECT_VARIATION CALLED ===")
53
  print(f"Event: {evt}, type: {type(evt)}")
54
  print(f"Current variations count: {len(self.current_variations)}")
55
 
56
+ if not self.current_variations:
57
+ return None, "", "", ""
 
 
 
 
 
 
 
 
 
58
 
59
+ # If event is None (e.g. change without select), default to first
60
+ if evt is None:
61
+ index = 0
62
+ elif hasattr(evt, 'index'):
63
  index = evt.index
64
  elif isinstance(evt, (int, float)):
65
  index = int(evt)
66
  else:
67
+ index = 0
 
 
 
 
 
 
68
 
69
+ # Clamp index
70
+ if index < 0 or index >= len(self.current_variations):
71
+ index = 0
 
 
 
 
 
 
 
72
 
73
  selected_var = self.current_variations[index]
 
 
 
74
 
 
 
75
  properties_text = get_molecule_properties_for_hover(selected_var['smiles'])
 
 
 
 
 
 
 
 
 
 
76
 
77
+ return selected_var['image'], selected_var['smiles'], selected_var['style'], properties_text
78
  except Exception as e:
79
  print(f"Error in select_variation: {e}")
80
  import traceback
 
104
  end_idx = min(start_idx + self.variations_per_page, len(self.current_variations))
105
  page_variations = self.current_variations[start_idx:end_idx]
106
 
107
+ gallery_items = [(v["image"], v["smiles"]) for v in page_variations]
 
 
 
 
108
  page_info = f"Page {self.current_page + 1} of {total_pages}"
109
 
110
+ first = page_variations[0] if page_variations else None
111
+
112
+ return (
113
+ gallery_items,
114
+ page_info,
115
+ first['image'] if first else None,
116
+ first['smiles'] if first else "",
117
+ first['style'] if first else ""
118
+ )
119
 
120
  def update_variation_count(self, count):
121
+ """Update the number of variations per page."""
122
  self.variations_per_page = count
123
  return count
124
 
 
126
  """Analyze molecule and return image with tooltip data."""
127
  molecule_img = analyze_molecule_image_only(smiles)
128
  tooltip_text = get_molecule_properties_for_hover(smiles)
 
 
 
129
  return molecule_img, tooltip_text
130
 
131
 
 
140
  from rdkit import Chem
141
  from rdkit.Chem import Draw
142
 
 
143
  mol = Chem.MolFromSmiles(smiles)
144
  if not mol:
145
  return "❌ Invalid SMILES string - cannot bookmark"
146
 
 
147
  if smiles in [bm['smiles'] for bm in self.bookmarked_molecules]:
148
  return "⚠️ Molecule already bookmarked"
149
 
 
150
  if not molecule_name:
151
  molecule_name = f"Bookmarked_{len(self.bookmarked_molecules) + 1}"
152
 
 
153
  self.bookmarked_molecules.append({
154
  'smiles': smiles,
155
  'name': molecule_name,
156
+ 'timestamp': len(self.bookmarked_molecules) + 1
157
  })
158
 
159
  return f"✅ Bookmarked: {molecule_name}"
160
 
161
  def get_bookmarked_molecules(self):
 
162
  return self.bookmarked_molecules
163
 
164
  def remove_bookmark(self, smiles):
 
165
  self.bookmarked_molecules = [bm for bm in self.bookmarked_molecules if bm['smiles'] != smiles]
166
  return "🗑️ Removed from bookmarks"
167
 
168
  def bookmark_current_molecule(self, smiles, name):
 
169
  from rdkit import Chem
170
  from rdkit.Chem import Draw
171
 
172
  result = self.bookmark_molecule(smiles, name)
 
173
  bookmarked_mols = self.get_bookmarked_molecules()
174
  gallery_items = []
175
  for mol in bookmarked_mols:
 
176
  mol_obj = Chem.MolFromSmiles(mol['smiles'])
177
  if mol_obj:
178
  img = Draw.MolToImage(mol_obj, size=(150, 150), kekulize=True)
 
191
  if not message.strip() or not hf_token.strip():
192
  return history, []
193
 
 
194
  history.append({"role": "user", "content": message})
195
 
 
196
  structure_keywords = ['generate', 'create', 'modify', 'derivative', 'variant', 'structure']
197
  is_structure_request = any(keyword in message.lower() for keyword in structure_keywords)
198
 
199
  if is_structure_request and selected_smiles:
 
200
  ai_response = ""
201
+ for chunk in respond(
202
+ message,
203
+ history[:-1],
204
+ "You are an expert medicinal chemist. Generate new chemical structures based on user requests.",
205
+ 512,
206
+ temperature,
207
+ 0.9,
208
+ hf_token
209
+ ):
210
  ai_response = chunk
211
 
 
212
  history.append({"role": "assistant", "content": ai_response})
 
 
213
  structures = parse_ai_structures(ai_response, selected_smiles)
 
214
  return history, structures
215
  else:
 
216
  ai_response = ""
217
+ for chunk in respond(
218
+ message,
219
+ history[:-1],
220
+ "You are an expert medicinal chemist and drug discovery specialist. Help with molecular analysis, drug design, and medicinal chemistry questions.",
221
+ 512,
222
+ temperature,
223
+ 0.9,
224
+ hf_token
225
+ ):
226
  ai_response = chunk
227
 
 
228
  history.append({"role": "assistant", "content": ai_response})
229
+ return history, []