Vaishnav14220 commited on
Commit
1b3c9a7
·
1 Parent(s): e1f2b89

Implement real quantum orbitals using PySCF - Natural Atomic Orbitals with DFT/B3LYP

Browse files
Files changed (2) hide show
  1. app.py +108 -134
  2. requirements.txt +4 -1
app.py CHANGED
@@ -192,142 +192,116 @@ def name_to_3d_molecule(name: str, show_orbitals: bool = False) -> tuple:
192
  data_traces = [bonds_trace, atoms_trace]
193
 
194
  if show_orbitals:
195
- import numpy as np
196
-
197
- # Orbital size reference
198
- orbital_radius = {
199
- 'H': 0.6, 'C': 0.9, 'N': 0.8, 'O': 0.75,
200
- 'F': 0.7, 'Cl': 1.0, 'Br': 1.15, 'I': 1.4,
201
- 'P': 1.1, 'S': 1.05, 'B': 0.95, 'Si': 1.2
202
- }
203
-
204
- # Get electron configuration for each atom to determine orbital types
205
- def get_orbitals(atomic_num):
206
- """Return list of orbitals for an atom based on electron configuration"""
207
- # Simplified orbital visualization - show valence orbitals
208
- if atomic_num == 1: # H: 1s
209
- return ['s']
210
- elif atomic_num <= 2: # He: 1s
211
- return ['s']
212
- elif atomic_num <= 10: # Li-Ne: has s and p
213
- return ['s', 'p']
214
- elif atomic_num <= 18: # Na-Ar: has s and p
215
- return ['s', 'p']
216
- elif atomic_num <= 36: # K-Kr: has s, p, d
217
- return ['s', 'p', 'd']
218
- else: # Rb onwards: has s, p, d, f
219
- return ['s', 'p', 'd', 'f']
220
-
221
- # Create orbital shapes around each atom
222
- for idx, (x, y, z, elem) in enumerate(zip(x_coords, y_coords, z_coords, elements)):
223
- atom = atoms[idx]
224
- atomic_num = atom.GetAtomicNum()
225
- orbitals = get_orbitals(atomic_num)
226
 
227
- # Base radius for orbitals
228
- base_radius = orbital_radius.get(elem, 0.8)
229
- color = color_map.get(elem, '#FF1493')
 
 
 
230
 
231
- # S orbital (spherical)
232
- if 's' in orbitals:
233
- u = np.linspace(0, 2 * np.pi, 25)
234
- v = np.linspace(0, np.pi, 20)
235
- s_radius = base_radius * 0.6
236
- sphere_x = x + s_radius * np.outer(np.cos(u), np.sin(v))
237
- sphere_y = y + s_radius * np.outer(np.sin(u), np.sin(v))
238
- sphere_z = z + s_radius * np.outer(np.ones(np.size(u)), np.cos(v))
239
-
240
- s_orbital = go.Surface(
241
- x=sphere_x, y=sphere_y, z=sphere_z,
242
- colorscale=[[0, color], [1, color]],
243
- showscale=False,
244
- opacity=0.2,
245
- name=f'{elem}-s',
246
- hoverinfo='skip'
247
- )
248
- data_traces.append(s_orbital)
249
 
250
- # P orbitals (dumbbell shaped - 3 orientations: px, py, pz)
251
- if 'p' in orbitals:
252
- p_radius = base_radius * 1.0
253
- p_length = base_radius * 1.5
254
-
255
- # Create dumbbell shape for p orbitals
256
- u = np.linspace(0, 2 * np.pi, 20)
257
- v = np.linspace(-1, 1, 15)
258
-
259
- # Pz orbital (along z-axis)
260
- pz_r = p_radius * np.sqrt(1 - v**2)
261
- pz_z_vals = z + p_length * v
262
- pz_x = x + np.outer(pz_r, np.cos(u))
263
- pz_y = y + np.outer(pz_r, np.sin(u))
264
- pz_z = np.outer(pz_z_vals, np.ones(len(u)))
265
-
266
- pz_orbital = go.Surface(
267
- x=pz_x, y=pz_y, z=pz_z,
268
- colorscale=[[0, color], [0.5, color], [1, color]],
269
- showscale=False,
270
- opacity=0.15,
271
- name=f'{elem}-pz',
272
- hoverinfo='skip'
273
- )
274
- data_traces.append(pz_orbital)
275
-
276
- # Px orbital (along x-axis)
277
- px_r = p_radius * np.sqrt(1 - v**2)
278
- px_x_vals = x + p_length * v
279
- px_x = np.outer(px_x_vals, np.ones(len(u)))
280
- px_y = y + np.outer(px_r, np.cos(u))
281
- px_z = z + np.outer(px_r, np.sin(u))
282
-
283
- px_orbital = go.Surface(
284
- x=px_x, y=px_y, z=px_z,
285
- colorscale=[[0, color], [0.5, color], [1, color]],
286
- showscale=False,
287
- opacity=0.15,
288
- name=f'{elem}-px',
289
- hoverinfo='skip'
290
- )
291
- data_traces.append(px_orbital)
292
-
293
- # Py orbital (along y-axis)
294
- py_r = p_radius * np.sqrt(1 - v**2)
295
- py_y_vals = y + p_length * v
296
- py_x = x + np.outer(py_r, np.cos(u))
297
- py_y = np.outer(py_y_vals, np.ones(len(u)))
298
- py_z = z + np.outer(py_r, np.sin(u))
299
-
300
- py_orbital = go.Surface(
301
- x=py_x, y=py_y, z=py_z,
302
- colorscale=[[0, color], [0.5, color], [1, color]],
303
- showscale=False,
304
- opacity=0.15,
305
- name=f'{elem}-py',
306
- hoverinfo='skip'
307
- )
308
- data_traces.append(py_orbital)
309
 
310
- # D orbitals (for transition metals) - simplified cloverleaf shape
311
- if 'd' in orbitals and elem not in ['H', 'C', 'N', 'O', 'F', 'S', 'P', 'Cl']:
312
- d_radius = base_radius * 1.2
313
- theta = np.linspace(0, 2 * np.pi, 25)
314
- phi = np.linspace(0, np.pi, 15)
315
 
316
- # Simplified d-orbital (four-lobed)
317
- r = d_radius * np.abs(np.outer(np.sin(2 * phi), np.cos(2 * theta)))
318
- d_x = x + r * np.outer(np.sin(phi), np.cos(theta))
319
- d_y = y + r * np.outer(np.sin(phi), np.sin(theta))
320
- d_z = z + r * np.outer(np.cos(phi), np.ones(len(theta)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
- d_orbital = go.Surface(
323
- x=d_x, y=d_y, z=d_z,
324
- colorscale=[[0, color], [1, color]],
325
- showscale=False,
326
- opacity=0.12,
327
- name=f'{elem}-d',
328
- hoverinfo='skip'
329
- )
330
- data_traces.append(d_orbital)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
  # Create figure
333
  fig = go.Figure(data=data_traces)
@@ -383,8 +357,8 @@ name_interface = gr.Interface(
383
 
384
  # Create Blocks interface for molecule viewer with autocomplete
385
  with gr.Blocks() as molecule_3d_block:
386
- gr.Markdown("## 🔬 3D Molecule Viewer with Electron Orbitals")
387
- gr.Markdown("Enter a chemical name to view its 2D and 3D structure. Toggle orbitals to see s, p, and d orbital shapes!")
388
 
389
  with gr.Row():
390
  with gr.Column():
@@ -399,9 +373,9 @@ with gr.Blocks() as molecule_3d_block:
399
  filterable=True
400
  )
401
  orbital_checkbox = gr.Checkbox(
402
- label="Show Electron Orbitals (s, p, d shapes)",
403
  value=False,
404
- info="Toggle to see s-orbitals (spheres), p-orbitals (dumbbells), and d-orbitals (cloverleaf) around atoms"
405
  )
406
  submit_btn = gr.Button("Generate 3D Molecule", variant="primary")
407
 
 
192
  data_traces = [bonds_trace, atoms_trace]
193
 
194
  if show_orbitals:
195
+ try:
196
+ import numpy as np
197
+ from pyscf import gto, lo, dft
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
+ # Build PySCF molecule object
200
+ pyscf_elements = [atom.GetSymbol() for atom in mol.GetAtoms()]
201
+ pyscf_coords = []
202
+ for i in range(mol.GetNumAtoms()):
203
+ pos = conf.GetAtomPosition(i)
204
+ pyscf_coords.append([pos.x, pos.y, pos.z])
205
 
206
+ pyscf_atoms = [(elem, coord) for elem, coord in zip(pyscf_elements, pyscf_coords)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
+ # Create PySCF molecule - use small basis for speed
209
+ pyscf_mole = gto.Mole(basis="sto-3g")
210
+ pyscf_mole.atom = pyscf_atoms
211
+ pyscf_mole.build()
212
+
213
+ # Run fast DFT calculation
214
+ mf = dft.RKS(pyscf_mole)
215
+ mf.xc = 'b3lyp'
216
+ mf.verbose = 0 # Suppress output
217
+ mf.run()
218
+
219
+ # Calculate Natural Atomic Orbitals (pre-NAOs for more localized orbitals)
220
+ dm = mf.make_rdm1()
221
+ naos = lo.nao.prenao(pyscf_mole, dm)
222
+
223
+ # Create grid for orbital evaluation
224
+ grid_resolution = 25 # Lower for speed
225
+ margin = 3.0
226
+
227
+ # Determine grid bounds
228
+ all_x = [coord[0] for coord in pyscf_coords]
229
+ all_y = [coord[1] for coord in pyscf_coords]
230
+ all_z = [coord[2] for coord in pyscf_coords]
231
+
232
+ x_min, x_max = min(all_x) - margin, max(all_x) + margin
233
+ y_min, y_max = min(all_y) - margin, max(all_y) + margin
234
+ z_min, z_max = min(all_z) - margin, max(all_z) + margin
235
+
236
+ # Create meshgrid
237
+ xs = np.linspace(x_min, x_max, grid_resolution)
238
+ ys = np.linspace(y_min, y_max, grid_resolution)
239
+ zs = np.linspace(z_min, z_max, grid_resolution)
240
+ grid_x, grid_y, grid_z = np.meshgrid(xs, ys, zs, indexing='ij')
241
+
242
+ # Flatten grid for evaluation
243
+ grid_coords = np.column_stack([grid_x.ravel(), grid_y.ravel(), grid_z.ravel()])
244
+
245
+ # Evaluate a few representative valence orbitals
246
+ # Select orbitals around HOMO (Highest Occupied Molecular Orbital)
247
+ n_orbitals_to_show = min(3, naos.shape[1]) # Show up to 3 orbitals
248
+ start_orbital = max(0, naos.shape[1] - 5) # Start near valence orbitals
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
+ for orbital_idx in range(start_orbital, min(start_orbital + n_orbitals_to_show, naos.shape[1])):
251
+ # Evaluate orbital on grid
252
+ ao_values = pyscf_mole.eval_gto('GTOval_sph', grid_coords)
253
+ orbital_values = np.dot(ao_values, naos[:, orbital_idx])
254
+ orbital_grid = orbital_values.reshape(grid_x.shape)
255
 
256
+ # Create isosurface for positive lobe
257
+ isoval_positive = 0.02
258
+ try:
259
+ from skimage import measure
260
+ verts_pos, faces_pos, _, _ = measure.marching_cubes(orbital_grid, level=isoval_positive)
261
+
262
+ # Transform vertices to real coordinates
263
+ verts_pos[:, 0] = x_min + verts_pos[:, 0] * (x_max - x_min) / (grid_resolution - 1)
264
+ verts_pos[:, 1] = y_min + verts_pos[:, 1] * (y_max - y_min) / (grid_resolution - 1)
265
+ verts_pos[:, 2] = z_min + verts_pos[:, 2] * (z_max - z_min) / (grid_resolution - 1)
266
+
267
+ orbital_trace_pos = go.Mesh3d(
268
+ x=verts_pos[:, 0], y=verts_pos[:, 1], z=verts_pos[:, 2],
269
+ i=faces_pos[:, 0], j=faces_pos[:, 1], k=faces_pos[:, 2],
270
+ color='blue',
271
+ opacity=0.3,
272
+ name=f'Orbital {orbital_idx+1} (+)',
273
+ hoverinfo='skip'
274
+ )
275
+ data_traces.append(orbital_trace_pos)
276
+ except:
277
+ pass
278
 
279
+ # Create isosurface for negative lobe
280
+ isoval_negative = -0.02
281
+ try:
282
+ verts_neg, faces_neg, _, _ = measure.marching_cubes(orbital_grid, level=isoval_negative)
283
+
284
+ # Transform vertices to real coordinates
285
+ verts_neg[:, 0] = x_min + verts_neg[:, 0] * (x_max - x_min) / (grid_resolution - 1)
286
+ verts_neg[:, 1] = y_min + verts_neg[:, 1] * (y_max - y_min) / (grid_resolution - 1)
287
+ verts_neg[:, 2] = z_min + verts_neg[:, 2] * (z_max - z_min) / (grid_resolution - 1)
288
+
289
+ orbital_trace_neg = go.Mesh3d(
290
+ x=verts_neg[:, 0], y=verts_neg[:, 1], z=verts_neg[:, 2],
291
+ i=faces_neg[:, 0], j=faces_neg[:, 1], k=faces_neg[:, 2],
292
+ color='red',
293
+ opacity=0.3,
294
+ name=f'Orbital {orbital_idx+1} (-)',
295
+ hoverinfo='skip'
296
+ )
297
+ data_traces.append(orbital_trace_neg)
298
+ except:
299
+ pass
300
+
301
+ except Exception as e:
302
+ # If orbital calculation fails, add a note
303
+ print(f"Orbital calculation failed: {str(e)}")
304
+ # Fall back to simple message - orbitals are computationally expensive
305
 
306
  # Create figure
307
  fig = go.Figure(data=data_traces)
 
357
 
358
  # Create Blocks interface for molecule viewer with autocomplete
359
  with gr.Blocks() as molecule_3d_block:
360
+ gr.Markdown("## 🔬 3D Molecule Viewer with Quantum Orbitals")
361
+ gr.Markdown("Enter a chemical name to view its 2D and 3D structure. Toggle orbitals to see **real quantum mechanical Natural Atomic Orbitals (NAOs)** computed with PySCF!")
362
 
363
  with gr.Row():
364
  with gr.Column():
 
373
  filterable=True
374
  )
375
  orbital_checkbox = gr.Checkbox(
376
+ label="Show Quantum Orbitals (NAOs)",
377
  value=False,
378
+ info="Compute and display real Natural Atomic Orbitals using quantum chemistry (DFT/B3LYP)"
379
  )
380
  submit_btn = gr.Button("Generate 3D Molecule", variant="primary")
381
 
requirements.txt CHANGED
@@ -3,4 +3,7 @@ rdkit
3
  gradio==4.44.1
4
  huggingface_hub==0.19.4
5
  cirpy
6
- plotly
 
 
 
 
3
  gradio==4.44.1
4
  huggingface_hub==0.19.4
5
  cirpy
6
+ plotly
7
+ pyscf
8
+ numpy
9
+ scikit-image