iflp1908sl commited on
Commit
94974e9
·
1 Parent(s): 1eb0a26

Replace gradio_molecule3d with direct 3Dmol.js: full style/bg/label/H control

Browse files
Files changed (2) hide show
  1. app.py +170 -79
  2. requirements.txt +0 -1
app.py CHANGED
@@ -183,18 +183,7 @@ def create_xyz_zip(xyz_strings):
183
 
184
  return tmp_zip.name
185
 
186
- from gradio_molecule3d import Molecule3D
187
-
188
- # 3Dmol.js representation configs
189
- STYLE_REPS = {
190
- "Ball and Stick": [
191
- {"model": 0, "style": "stick", "color": "whiteCarbon", "radius": 0.2},
192
- {"model": 0, "style": "sphere", "color": "whiteCarbon", "scale": 0.3}
193
- ],
194
- "Licorice": [
195
- {"model": 0, "style": "stick", "color": "whiteCarbon", "radius": 0.4}
196
- ]
197
- }
198
 
199
  def xyz_to_pdb(xyz_str):
200
  """Convert XYZ string to PDB format for the viewer."""
@@ -288,12 +277,7 @@ def generate(num_molecules, size_mode, fixed_size, diffusion_steps, seed):
288
  xyz_strings.append(xyz_str)
289
  summary_rows.append(parse_composition(xyz_str))
290
 
291
- # 5. Output generation – save PDB files for the viewer (XYZ for download)
292
- file_paths = []
293
- for i, xyz_str in enumerate(xyz_strings):
294
- # Default to PDB for viewer as it's most compatible
295
- file_paths.append(save_to_format(xyz_str, i, "pdb"))
296
-
297
  zip_path = create_xyz_zip(xyz_strings)
298
 
299
  # Prepare table with "Name" column
@@ -304,21 +288,20 @@ def generate(num_molecules, size_mode, fixed_size, diffusion_steps, seed):
304
  cols = ["Name"] + [c for c in df_out.columns if c != "Name"]
305
  df_out = df_out[cols]
306
 
307
- choices = [f"Molecule {i+1}" for i in range(len(file_paths))]
308
 
309
  return (
310
- file_paths[0], # viewer
311
- gr.update(choices=choices, value=choices[0]), # selector
312
- file_paths, # state (paths)
313
- xyz_strings, # state (raw)
314
- zip_path, # download
315
- df_out, # table
316
  )
317
 
318
  except Exception as e:
319
  import traceback
320
  traceback.print_exc()
321
- return None, gr.update(choices=[], value=None), [], [], None, None
322
 
323
  finally:
324
  # Restore original T to be safe
@@ -328,8 +311,75 @@ def generate(num_molecules, size_mode, fixed_size, diffusion_steps, seed):
328
  elif hasattr(TASK, 'T'):
329
  TASK.T = original_T
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  # ── GRADIO UI ──────────────────────────────────────────────────
332
- with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft()) as demo:
333
  gr.Markdown(
334
  """
335
  # 🧪 MolCraftDiffusion - Unconditional Generation
@@ -369,20 +419,27 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft()) as demo:
369
  )
370
 
371
  seed = gr.Number(value=42, label="Random Seed", precision=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
- with gr.Accordion("🎨 Visual & Download Options", open=True):
374
- mol_style = gr.Radio(
375
- ["Ball and Stick", "Licorice"],
376
- value="Ball and Stick",
377
- label="3D Display Style"
378
- )
379
- dl_format = gr.Radio(
380
- ["PDB", "XYZ"],
381
- value="PDB",
382
- label="Individual Download Format"
383
- )
384
-
385
- btn = gr.Button("🚀 Generate Molecules", variant="primary", size="lg")
386
 
387
  with gr.Column(scale=2):
388
  gr.Markdown("### 🧬 Generated Molecules")
@@ -393,8 +450,13 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft()) as demo:
393
  )
394
  single_dl = gr.File(label="Download Selection", scale=2)
395
 
 
 
 
 
 
 
396
  with gr.Row():
397
- viewer = Molecule3D(label="3D Viewer", reps=STYLE_REPS["Ball and Stick"], scale=2)
398
  preview = gr.Image(label="2D Preview (PNG)", scale=1)
399
 
400
  with gr.Row():
@@ -402,14 +464,13 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft()) as demo:
402
  download_all = gr.File(label="Download All (.zip)")
403
 
404
  # Hidden states
405
- mol_files_state = gr.State([])
406
- raw_xyz_state = gr.State([])
407
 
408
  def mol_to_png(xyz_str):
409
  """Generate a 2D PNG preview using RDKit."""
410
  from rdkit.Chem import Draw
411
  try:
412
- # Simple XYZ -> RDKit Mol (no connectivity info in XYZ, so we guess)
413
  lines = xyz_str.strip().splitlines()
414
  if len(lines) < 3: return None
415
 
@@ -422,7 +483,6 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft()) as demo:
422
  mol.AddAtom(atom)
423
  pos.append([float(parts[1]), float(parts[2]), float(parts[3])])
424
 
425
- # Simple distance-based connectivity for drawing
426
  adj = build_adjacency_matrix(np.array(pos), [a.GetAtomicNum() for a in mol.GetAtoms()], scale=1.2)
427
  for i in range(len(pos)):
428
  for j in range(i+1, len(pos)):
@@ -430,11 +490,8 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft()) as demo:
430
  mol.AddBond(i, j, Chem.BondType.SINGLE)
431
 
432
  rd_mol = mol.GetMol()
433
- # Compute 2D coordinates for drawing
434
- from rdkit.Chem import AllChem
435
  AllChem.Compute2DCoords(rd_mol)
436
 
437
- # Draw to PNG
438
  img = Draw.MolToImage(rd_mol, size=(400, 400))
439
  tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
440
  img.save(tmp_img.name)
@@ -443,53 +500,87 @@ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft()) as demo:
443
  print(f"Drawing error: {e}")
444
  return None
445
 
446
- def update_view(choice, file_paths, raw_xyzs, style, fmt):
447
- """Update viewer, preview and individual download based on selection/style/format."""
448
- if not file_paths or not raw_xyzs or choice is None:
449
- return gr.update(value=None), None, None
450
-
451
  try:
452
  idx = int(choice.split()[-1]) - 1
453
- if not (0 <= idx < len(file_paths)):
454
- return gr.update(value=None), None, None
455
  except (ValueError, IndexError):
456
- return gr.update(value=None), None, None
457
-
458
- # 1. Update Viewer (always use PDB for viewer compatibility)
459
- new_reps = STYLE_REPS.get(style, STYLE_REPS["Ball and Stick"])
460
- viewer_update = gr.update(value=file_paths[idx], reps=new_reps)
461
 
462
- # 2. Update Download File (respect user's format preference)
463
  dl_path = save_to_format(raw_xyzs[idx], idx, fmt)
464
-
465
- # 3. Generate PNG Preview
466
  png_path = mol_to_png(raw_xyzs[idx])
467
-
468
- return viewer_update, dl_path, png_path
 
 
 
 
 
 
 
 
469
 
470
- # Listeners
 
 
 
 
 
 
 
 
471
  mol_selector.change(
472
- fn=update_view,
473
- inputs=[mol_selector, mol_files_state, raw_xyz_state, mol_style, dl_format],
474
- outputs=[viewer, single_dl, preview],
475
- )
476
-
477
- mol_style.change(
478
- fn=update_view,
479
- inputs=[mol_selector, mol_files_state, raw_xyz_state, mol_style, dl_format],
480
- outputs=[viewer, single_dl, preview],
481
  )
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  dl_format.change(
484
- fn=update_view,
485
- inputs=[mol_selector, mol_files_state, raw_xyz_state, mol_style, dl_format],
486
- outputs=[viewer, single_dl, preview],
487
  )
488
 
 
489
  btn.click(
490
  fn=generate,
491
  inputs=[num_mol, size_mode, fixed_size, diffusion_steps, seed],
492
- outputs=[viewer, mol_selector, mol_files_state, raw_xyz_state, download_all, table]
 
 
 
 
 
493
  )
494
 
495
  gr.Markdown("---")
 
183
 
184
  return tmp_zip.name
185
 
186
+ # ── 3Dmol.js SETUP (loaded via head tag) ──────────────────────
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  def xyz_to_pdb(xyz_str):
189
  """Convert XYZ string to PDB format for the viewer."""
 
277
  xyz_strings.append(xyz_str)
278
  summary_rows.append(parse_composition(xyz_str))
279
 
280
+ # 5. Output generation – save zip for bulk download
 
 
 
 
 
281
  zip_path = create_xyz_zip(xyz_strings)
282
 
283
  # Prepare table with "Name" column
 
288
  cols = ["Name"] + [c for c in df_out.columns if c != "Name"]
289
  df_out = df_out[cols]
290
 
291
+ choices = [f"Molecule {i+1}" for i in range(len(xyz_strings))]
292
 
293
  return (
294
+ xyz_strings[0], # current_xyz (triggers JS)
295
+ gr.update(choices=choices, value=choices[0]), # selector
296
+ xyz_strings, # raw_xyz_state
297
+ zip_path, # download_all
298
+ df_out, # table
 
299
  )
300
 
301
  except Exception as e:
302
  import traceback
303
  traceback.print_exc()
304
+ return "", gr.update(choices=[], value=None), [], None, None
305
 
306
  finally:
307
  # Restore original T to be safe
 
311
  elif hasattr(TASK, 'T'):
312
  TASK.T = original_T
313
 
314
+ # ── 3Dmol.js JavaScript ───────────────────────────────────────
315
+ THREEDMOL_HEAD = """
316
+ <script src="https://3Dmol.org/build/3Dmol-min.js"></script>
317
+ <style>
318
+ #mol3d-container { width: 100%; height: 480px; position: relative; border: 1px solid #ccc; border-radius: 8px; overflow: hidden; }
319
+ </style>
320
+ <script>
321
+ var _viewer = null;
322
+ var _currentXYZ = null;
323
+
324
+ function initViewer() {
325
+ var el = document.getElementById('mol3d-container');
326
+ if (!el) return;
327
+ if (_viewer) { _viewer.clear(); }
328
+ else { _viewer = $3Dmol.createViewer(el, {backgroundColor: '0xffffff'}); }
329
+ }
330
+
331
+ function loadMolecule(xyzStr, style, bg, showLabels, hideH) {
332
+ if (!_viewer) initViewer();
333
+ if (!_viewer) return;
334
+ _currentXYZ = xyzStr;
335
+ _viewer.clear();
336
+ if (!xyzStr || xyzStr.length < 5) return;
337
+ _viewer.addModel(xyzStr, 'xyz');
338
+ applyStyle(style, hideH);
339
+ if (showLabels) {
340
+ _viewer.addPropertyLabels('elem', {}, {fontSize: 12, fontColor: 'black', backgroundOpacity: 0.3, backgroundColor: 'white', alignment: 'center'});
341
+ }
342
+ _viewer.setBackgroundColor(bg);
343
+ _viewer.zoomTo();
344
+ _viewer.render();
345
+ }
346
+
347
+ function applyStyle(style, hideH) {
348
+ if (!_viewer) return;
349
+ _viewer.setStyle({}, {});
350
+ var sel = hideH ? {elem: ['C','N','O','S','P','F','Cl','Br','I','B','Si','Se','As','Al','Hg','Bi']} : {};
351
+ if (hideH) { _viewer.setStyle({elem: 'H'}, {}); }
352
+ switch(style) {
353
+ case 'Ball and Stick':
354
+ _viewer.setStyle(sel, {stick: {radius: 0.15, colorscheme: 'Jmol'}, sphere: {scale: 0.25, colorscheme: 'Jmol'}});
355
+ break;
356
+ case 'Licorice':
357
+ _viewer.setStyle(sel, {stick: {radius: 0.3, colorscheme: 'Jmol'}});
358
+ break;
359
+ case 'Sphere':
360
+ _viewer.setStyle(sel, {sphere: {colorscheme: 'Jmol'}});
361
+ break;
362
+ case 'Stick':
363
+ _viewer.setStyle(sel, {stick: {colorscheme: 'Jmol'}});
364
+ break;
365
+ default:
366
+ _viewer.setStyle(sel, {stick: {radius: 0.15, colorscheme: 'Jmol'}, sphere: {scale: 0.25, colorscheme: 'Jmol'}});
367
+ }
368
+ _viewer.render();
369
+ }
370
+
371
+ function refreshViewer(style, bg, showLabels, hideH) {
372
+ if (!_viewer || !_currentXYZ) return;
373
+ loadMolecule(_currentXYZ, style, bg, showLabels, hideH);
374
+ }
375
+ </script>
376
+ """
377
+
378
+ # Background color mapping
379
+ BG_COLORS = {"White": "0xffffff", "Black": "0x000000", "Grey": "0x333333"}
380
+
381
  # ── GRADIO UI ──────────────────────────────────────────────────
382
+ with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft(), head=THREEDMOL_HEAD) as demo:
383
  gr.Markdown(
384
  """
385
  # 🧪 MolCraftDiffusion - Unconditional Generation
 
419
  )
420
 
421
  seed = gr.Number(value=42, label="Random Seed", precision=0)
422
+
423
+ with gr.Accordion("🎨 Viewer Options", open=True):
424
+ mol_style = gr.Radio(
425
+ ["Ball and Stick", "Licorice", "Sphere", "Stick"],
426
+ value="Ball and Stick",
427
+ label="3D Style"
428
+ )
429
+ bg_color = gr.Radio(
430
+ ["White", "Black", "Grey"],
431
+ value="White",
432
+ label="Background"
433
+ )
434
+ show_labels = gr.Checkbox(label="Show atom labels", value=False)
435
+ hide_h = gr.Checkbox(label="Hide hydrogens", value=False)
436
+ dl_format = gr.Radio(
437
+ ["PDB", "XYZ"],
438
+ value="PDB",
439
+ label="Download Format"
440
+ )
441
 
442
+ btn = gr.Button("🚀 Generate Molecules", variant="primary", size="lg")
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
  with gr.Column(scale=2):
445
  gr.Markdown("### 🧬 Generated Molecules")
 
450
  )
451
  single_dl = gr.File(label="Download Selection", scale=2)
452
 
453
+ # 3Dmol.js viewer container
454
+ viewer_html = gr.HTML(
455
+ value='<div id="mol3d-container" style="width:100%;height:480px;"></div>',
456
+ label="3D Viewer"
457
+ )
458
+
459
  with gr.Row():
 
460
  preview = gr.Image(label="2D Preview (PNG)", scale=1)
461
 
462
  with gr.Row():
 
464
  download_all = gr.File(label="Download All (.zip)")
465
 
466
  # Hidden states
467
+ raw_xyz_state = gr.State([]) # List of raw XYZ strings
468
+ current_xyz = gr.Textbox(visible=False) # Current molecule XYZ for JS
469
 
470
  def mol_to_png(xyz_str):
471
  """Generate a 2D PNG preview using RDKit."""
472
  from rdkit.Chem import Draw
473
  try:
 
474
  lines = xyz_str.strip().splitlines()
475
  if len(lines) < 3: return None
476
 
 
483
  mol.AddAtom(atom)
484
  pos.append([float(parts[1]), float(parts[2]), float(parts[3])])
485
 
 
486
  adj = build_adjacency_matrix(np.array(pos), [a.GetAtomicNum() for a in mol.GetAtoms()], scale=1.2)
487
  for i in range(len(pos)):
488
  for j in range(i+1, len(pos)):
 
490
  mol.AddBond(i, j, Chem.BondType.SINGLE)
491
 
492
  rd_mol = mol.GetMol()
 
 
493
  AllChem.Compute2DCoords(rd_mol)
494
 
 
495
  img = Draw.MolToImage(rd_mol, size=(400, 400))
496
  tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
497
  img.save(tmp_img.name)
 
500
  print(f"Drawing error: {e}")
501
  return None
502
 
503
+ def select_molecule(choice, raw_xyzs, fmt):
504
+ """Update download file and current XYZ when molecule selection changes."""
505
+ if not raw_xyzs or choice is None:
506
+ return None, None, ""
 
507
  try:
508
  idx = int(choice.split()[-1]) - 1
509
+ if not (0 <= idx < len(raw_xyzs)):
510
+ return None, None, ""
511
  except (ValueError, IndexError):
512
+ return None, None, ""
 
 
 
 
513
 
 
514
  dl_path = save_to_format(raw_xyzs[idx], idx, fmt)
 
 
515
  png_path = mol_to_png(raw_xyzs[idx])
516
+ return dl_path, png_path, raw_xyzs[idx]
517
+
518
+ # JS to call loadMolecule when current_xyz changes
519
+ LOAD_JS = """
520
+ (xyz, style, bg, labels, hideH) => {
521
+ var bgHex = {'White':'0xffffff','Black':'0x000000','Grey':'0x333333'}[bg] || '0xffffff';
522
+ setTimeout(() => { loadMolecule(xyz, style, bgHex, labels, hideH); }, 100);
523
+ return [xyz, style, bg, labels, hideH];
524
+ }
525
+ """
526
 
527
+ REFRESH_JS = """
528
+ (style, bg, labels, hideH) => {
529
+ var bgHex = {'White':'0xffffff','Black':'0x000000','Grey':'0x333333'}[bg] || '0xffffff';
530
+ refreshViewer(style, bgHex, labels, hideH);
531
+ return [style, bg, labels, hideH];
532
+ }
533
+ """
534
+
535
+ # When molecule selection changes: update download + preview, then trigger JS
536
  mol_selector.change(
537
+ fn=select_molecule,
538
+ inputs=[mol_selector, raw_xyz_state, dl_format],
539
+ outputs=[single_dl, preview, current_xyz],
540
+ ).then(
541
+ fn=None,
542
+ inputs=[current_xyz, mol_style, bg_color, show_labels, hide_h],
543
+ outputs=None,
544
+ js=LOAD_JS,
 
545
  )
546
 
547
+ # When style/bg/labels/hideH changes: refresh viewer via JS only
548
+ for ctrl in [mol_style, bg_color, show_labels, hide_h]:
549
+ ctrl.change(
550
+ fn=None,
551
+ inputs=[mol_style, bg_color, show_labels, hide_h],
552
+ outputs=None,
553
+ js=REFRESH_JS,
554
+ )
555
+
556
+ # When format changes: just update the download file
557
+ def update_dl(choice, raw_xyzs, fmt):
558
+ if not raw_xyzs or choice is None:
559
+ return None
560
+ try:
561
+ idx = int(choice.split()[-1]) - 1
562
+ if 0 <= idx < len(raw_xyzs):
563
+ return save_to_format(raw_xyzs[idx], idx, fmt)
564
+ except (ValueError, IndexError):
565
+ pass
566
+ return None
567
+
568
  dl_format.change(
569
+ fn=update_dl,
570
+ inputs=[mol_selector, raw_xyz_state, dl_format],
571
+ outputs=[single_dl],
572
  )
573
 
574
+ # Generate button
575
  btn.click(
576
  fn=generate,
577
  inputs=[num_mol, size_mode, fixed_size, diffusion_steps, seed],
578
+ outputs=[current_xyz, mol_selector, raw_xyz_state, download_all, table]
579
+ ).then(
580
+ fn=None,
581
+ inputs=[current_xyz, mol_style, bg_color, show_labels, hide_h],
582
+ outputs=None,
583
+ js=LOAD_JS,
584
  )
585
 
586
  gr.Markdown("---")
requirements.txt CHANGED
@@ -41,5 +41,4 @@ psutil
41
  # App / Deployment
42
  gradio==4.42.0
43
  gradio_client==1.3.0
44
- gradio_molecule3d
45
  huggingface-hub<0.25.0
 
41
  # App / Deployment
42
  gradio==4.42.0
43
  gradio_client==1.3.0
 
44
  huggingface-hub<0.25.0