rdkit_api / app.py
Vaishnav14220
Remove atom numbers, show only element symbols and add orbital visualization toggle
e2cbfc6
raw
history blame
10.9 kB
import gradio as gr
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw, AllChem
import cirpy
# RDKit API with multiple endpoints
def _mol_from_smiles(smiles: str):
    """Parse *smiles* into an RDKit molecule or raise a user-facing error.

    Shared by every SMILES-based endpoint so they all report bad input the
    same way.

    Args:
        smiles: SMILES string; surrounding whitespace is ignored.

    Returns:
        The parsed RDKit ``Mol``.

    Raises:
        gr.Error: if *smiles* is missing or blank, or RDKit cannot parse it.
    """
    # RDKit parses "" into a zero-atom molecule instead of returning None,
    # which would make the descriptor endpoints silently report 0 / "".
    # None (possible through the API) would otherwise surface as a raw RDKit
    # argument error rather than the friendly message below.
    if smiles is None or not smiles.strip():
        raise gr.Error("Invalid SMILES string.")
    mol = Chem.MolFromSmiles(smiles.strip())
    if mol is None:
        raise gr.Error("Invalid SMILES string.")
    return mol
def smiles_to_canonical(smiles: str) -> str:
    """Return RDKit's canonical SMILES for *smiles*.

    Raises gr.Error (via _mol_from_smiles) when the input cannot be parsed.
    """
    parsed = _mol_from_smiles(smiles)
    canonical = Chem.MolToSmiles(parsed)
    return canonical
def molecular_weight(smiles: str) -> float:
    """Return the molecular weight (g/mol, via Descriptors.MolWt) of *smiles*.

    Raises gr.Error (via _mol_from_smiles) when the input cannot be parsed.
    """
    return float(Descriptors.MolWt(_mol_from_smiles(smiles)))
def logp(smiles: str) -> float:
    """Return the estimated octanol/water partition coefficient of *smiles*.

    Uses Descriptors.MolLogP; raises gr.Error (via _mol_from_smiles) when the
    input cannot be parsed.
    """
    return float(Descriptors.MolLogP(_mol_from_smiles(smiles)))
def tpsa(smiles: str) -> float:
    """Return the topological polar surface area of *smiles*.

    Uses Descriptors.TPSA; raises gr.Error (via _mol_from_smiles) when the
    input cannot be parsed.
    """
    return float(Descriptors.TPSA(_mol_from_smiles(smiles)))
def mol_image(smiles: str):
    """Return a 2D depiction of *smiles* as produced by Draw.MolToImage.

    Raises gr.Error (via _mol_from_smiles) when the input cannot be parsed.
    """
    return Draw.MolToImage(_mol_from_smiles(smiles))
def name_to_smiles(name: str) -> str:
    """Convert a chemical name to SMILES using the Chemical Identifier Resolver (CIR).

    Args:
        name: Chemical name such as "aspirin"; surrounding whitespace is ignored.

    Returns:
        The SMILES string returned by ``cirpy.resolve``.

    Raises:
        gr.Error: if *name* is blank, the resolver has no SMILES for it, or the
            lookup itself fails.
    """
    query = (name or "").strip()
    if not query:
        raise gr.Error("Please enter a chemical name.")
    try:
        smiles = cirpy.resolve(query, 'smiles')
    except Exception as e:
        # Only failures of the lookup itself are wrapped here; the "not found"
        # case is raised below, outside the try, so its message is not
        # prefixed a second time.
        raise gr.Error(f"Error converting name to SMILES: {str(e)}") from e
    if smiles is None:
        raise gr.Error(f"Could not find SMILES for chemical name: {query}")
    return smiles
# Jmol-style element colours used for atom markers and orbital spheres;
# elements not listed fall back to _FALLBACK_ATOM_COLOR.
_JMOL_COLORS = {
    'H': '#FFFFFF', 'C': '#909090', 'N': '#3050F8', 'O': '#FF0D0D',
    'F': '#90E050', 'Cl': '#1FF01F', 'Br': '#A62929', 'I': '#940094',
    'P': '#FF8000', 'S': '#FFFF30', 'B': '#FFB5B5', 'Si': '#F0C8A0',
}
_FALLBACK_ATOM_COLOR = '#FF1493'

# Fixed per-element radii (same units as the conformer coordinates) for the
# translucent "electron orbital" spheres. These are illustrative only and are
# not computed from any electronic-structure calculation.
_ORBITAL_RADII = {
    'H': 0.6, 'C': 0.9, 'N': 0.8, 'O': 0.75,
    'F': 0.7, 'Cl': 1.0, 'Br': 1.15, 'I': 1.4,
    'P': 1.1, 'S': 1.05, 'B': 0.95, 'Si': 1.2,
}
_FALLBACK_ORBITAL_RADIUS = 0.8


def _resolve_smiles_for_3d(name: str) -> str:
    """Resolve *name* to SMILES through CIR, raising gr.Error if not found."""
    smiles = cirpy.resolve(name, 'smiles')
    if smiles is None:
        raise gr.Error(f"Could not find SMILES for chemical name: {name}")
    return smiles


def _embed_3d_conformer(mol_2d, name: str):
    """Return an explicit-H copy of *mol_2d* with one force-field-optimized 3D conformer."""
    # Chem.AddHs returns a new molecule, so mol_2d stays hydrogen-free for
    # the 2D drawing and canonical SMILES.
    mol = Chem.AddHs(mol_2d)
    params = AllChem.ETKDG()
    if AllChem.EmbedMolecule(mol, params) == -1:
        # Common fallback when ETKDG fails (e.g. large or strained systems):
        # start the embedding from random coordinates instead.
        params.useRandomCoords = True
        if AllChem.EmbedMolecule(mol, params) == -1:
            raise gr.Error(f"Could not generate 3D coordinates for: {name}")
    # MMFFOptimizeMolecule returns -1 when the MMFF force field cannot be set
    # up for this molecule; relax with UFF instead of leaving it unoptimized.
    if AllChem.MMFFOptimizeMolecule(mol) == -1:
        AllChem.UFFOptimizeMolecule(mol)
    return mol


def _png_base64(mol_2d, size: int = 600) -> str:
    """Render *mol_2d* as a size x size PNG and return it base64-encoded."""
    import base64
    from rdkit.Chem.Draw import rdMolDraw2D

    drawer = rdMolDraw2D.MolDraw2DCairo(size, size)
    drawer.DrawMolecule(mol_2d)
    drawer.FinishDrawing()
    return base64.b64encode(drawer.GetDrawingText()).decode()


def _atom_trace(elements, xs, ys, zs):
    """Scatter trace drawing each atom as a coloured marker labelled only by its element symbol."""
    import plotly.graph_objects as go

    return go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='markers+text',
        marker=dict(
            size=10,
            color=[_JMOL_COLORS.get(e, _FALLBACK_ATOM_COLOR) for e in elements],
            line=dict(color='black', width=1),
        ),
        text=elements,
        textposition='top center',
        textfont=dict(size=10, color='black'),
        name='Atoms',
        hovertext=[f"{lbl}<br>Pos: ({x:.2f}, {y:.2f}, {z:.2f})"
                   for lbl, x, y, z in zip(elements, xs, ys, zs)],
        hoverinfo='text',
    )


def _bond_trace(mol, conf):
    """Line trace with one straight segment per bond of *mol*."""
    import plotly.graph_objects as go

    bond_x, bond_y, bond_z = [], [], []
    for bond in mol.GetBonds():
        start = conf.GetAtomPosition(bond.GetBeginAtomIdx())
        end = conf.GetAtomPosition(bond.GetEndAtomIdx())
        # The trailing None breaks the polyline so bonds are not joined.
        bond_x.extend([start.x, end.x, None])
        bond_y.extend([start.y, end.y, None])
        bond_z.extend([start.z, end.z, None])
    return go.Scatter3d(
        x=bond_x, y=bond_y, z=bond_z,
        mode='lines',
        line=dict(color='#888888', width=4),
        name='Bonds',
        hoverinfo='skip',
    )


def _orbital_traces(elements, xs, ys, zs):
    """One translucent sphere surface per atom, sized and coloured by element."""
    import numpy as np
    import plotly.graph_objects as go

    # Unit-sphere mesh, computed once and then scaled/translated per atom.
    u = np.linspace(0, 2 * np.pi, 20)
    v = np.linspace(0, np.pi, 15)
    unit_x = np.outer(np.cos(u), np.sin(v))
    unit_y = np.outer(np.sin(u), np.sin(v))
    unit_z = np.outer(np.ones(np.size(u)), np.cos(v))

    traces = []
    for x, y, z, elem in zip(xs, ys, zs, elements):
        radius = _ORBITAL_RADII.get(elem, _FALLBACK_ORBITAL_RADIUS)
        color = _JMOL_COLORS.get(elem, _FALLBACK_ATOM_COLOR)
        traces.append(go.Surface(
            x=x + radius * unit_x,
            y=y + radius * unit_y,
            z=z + radius * unit_z,
            colorscale=[[0, color], [1, color]],
            showscale=False,
            opacity=0.15,
            name=f'{elem} orbital',
            hoverinfo='skip',
        ))
    return traces


def name_to_3d_molecule(name: str, show_orbitals: bool = False) -> tuple:
    """Convert a chemical name to a 3D molecule (MolBlock), a 2D image and a 3D Plotly view.

    The name is resolved to SMILES through CIR. Hydrogens are added, a 3D
    conformer is embedded with ETKDG and force-field optimized (MMFF, or UFF
    when MMFF cannot be set up). A Plotly figure of the atoms (labelled by
    element symbol) and bonds is then built.

    Args:
        name: Chemical name to resolve; surrounding whitespace is ignored.
        show_orbitals: If true, overlay a translucent sphere of fixed
            per-element radius on every atom as a stylized electron cloud.

    Returns:
        ``(html_2d, fig, sdf_content)``, where:
        - html_2d is an HTML snippet embedding a PNG of the 2D structure;
        - fig is the Plotly 3D figure;
        - sdf_content is the MolBlock text of the optimized 3D structure,
          including hydrogens.

    Raises:
        gr.Error: for a blank name, an unresolvable name, unparsable SMILES,
            embedding failure, or any other failure while building outputs.
            Errors that are already gr.Error propagate with their original
            message.
    """
    import plotly.graph_objects as go

    query = (name or "").strip()
    if not query:
        raise gr.Error("Please enter a chemical name.")

    try:
        smiles = _resolve_smiles_for_3d(query)
        mol_2d = Chem.MolFromSmiles(smiles)
        if mol_2d is None:
            raise gr.Error(f"Could not create molecule from SMILES: {smiles}")

        mol = _embed_3d_conformer(mol_2d, query)
        sdf_content = Chem.MolToMolBlock(mol)
        img_2d_str = _png_base64(mol_2d)
        canonical_smiles = Chem.MolToSmiles(mol_2d)

        conf = mol.GetConformer()
        elements = [atom.GetSymbol() for atom in mol.GetAtoms()]
        points = [conf.GetAtomPosition(i) for i in range(len(elements))]
        xs = [p.x for p in points]
        ys = [p.y for p in points]
        zs = [p.z for p in points]

        data_traces = [_bond_trace(mol, conf), _atom_trace(elements, xs, ys, zs)]
        if show_orbitals:
            data_traces.extend(_orbital_traces(elements, xs, ys, zs))

        fig = go.Figure(data=data_traces)
        fig.update_layout(
            title=dict(text=f"{query}<br>SMILES: {canonical_smiles}", x=0.5, xanchor='center'),
            scene=dict(
                xaxis=dict(showbackground=False, showgrid=True, zeroline=False, showticklabels=False, title=''),
                yaxis=dict(showbackground=False, showgrid=True, zeroline=False, showticklabels=False, title=''),
                zaxis=dict(showbackground=False, showgrid=True, zeroline=False, showticklabels=False, title=''),
                bgcolor='white'
            ),
            showlegend=False,
            width=800,
            height=700,
            margin=dict(l=0, r=0, t=50, b=0)
        )
    except gr.Error:
        # Already user-facing; wrapping again would prefix the message twice.
        raise
    except Exception as e:
        raise gr.Error(f"Error creating molecule: {str(e)}") from e

    html_2d = f"""
    <div style="text-align: center; padding: 20px;">
    <h4>2D Structure</h4>
    <img src="data:image/png;base64,{img_2d_str}" style="max-width: 600px; width: 100%; border: 1px solid #ddd; border-radius: 8px; padding: 10px; background: white;">
    </div>
    """
    return html_2d, fig, sdf_content
# Tab: SMILES -> canonical SMILES.
# NOTE: api_name ("smiles_to_mol") differs from the function name; it names
# the public API endpoint, so renaming it would change the endpoint path.
smiles_interface = gr.Interface(
    fn=smiles_to_canonical,
    inputs=gr.Textbox(label="SMILES"),
    outputs=gr.Textbox(label="Canonical SMILES"),
    description="Convert an input SMILES string to its canonical form.",
    api_name="smiles_to_mol",
)
# Tab: chemical name -> SMILES, resolved by name_to_smiles.
name_interface = gr.Interface(
    fn=name_to_smiles,
    inputs=gr.Textbox(label="Chemical Name", placeholder="e.g., aspirin, caffeine, benzene"),
    outputs=gr.Textbox(label="SMILES"),
    examples=[["aspirin"], ["caffeine"], ["benzene"], ["ethanol"]],
    description="Convert a chemical name to SMILES notation.",
    api_name="name_to_smiles",
)
def _smiles_number_interface(fn, output_label, api_name, description):
    """Build an Interface mapping one SMILES textbox to one numeric output via *fn*."""
    return gr.Interface(
        fn=fn,
        inputs=gr.Textbox(label="SMILES"),
        outputs=gr.Number(label=output_label),
        api_name=api_name,
        description=description,
    )


# Descriptor tabs: each takes a SMILES string and returns a single number.
mw_interface = _smiles_number_interface(
    molecular_weight,
    "Molecular Weight (g/mol)",
    "molecular_weight",
    "Compute the molecular weight from a SMILES string.",
)
logp_interface = _smiles_number_interface(
    logp,
    "logP",
    "logp",
    "Calculate the octanol/water partition coefficient (logP).",
)
tpsa_interface = _smiles_number_interface(
    tpsa,
    "TPSA",
    "tpsa",
    "Calculate the topological polar surface area (TPSA).",
)
# Tab: name -> 2D image + interactive 3D viewer. The SDF/MolBlock textbox is
# hidden in the UI but is still part of the endpoint's outputs.
molecule_3d_interface = gr.Interface(
    fn=name_to_3d_molecule,
    inputs=[
        gr.Textbox(label="Chemical Name", placeholder="e.g., benzene, aspirin, caffeine, glucose"),
        gr.Checkbox(label="Show Electron Orbitals", value=False),
    ],
    outputs=[
        gr.HTML(label="2D Structure"),
        gr.Plot(label="3D Interactive Molecule Viewer - Rotate and Zoom!"),
        gr.Textbox(label="3D SDF Content (Optional - for external viewers)", lines=10, max_lines=20, visible=False),
    ],
    examples=[["benzene", False], ["aspirin", False], ["caffeine", True], ["glucose", False]],
    cache_examples=False,
    description="View 2D structure and interactive 3D molecule with atom labels. Click and drag to rotate, scroll to zoom! Toggle 'Show Electron Orbitals' to visualize electron clouds around atoms.",
    api_name="name_to_molecule",
)
# (interface, tab title) pairs in display order; keeping each interface next
# to its title prevents the two parallel lists from drifting apart.
_demo_tabs = [
    (name_interface, "Name to SMILES"),
    (molecule_3d_interface, "Molecule Viewer"),
    (smiles_interface, "SMILES to Canonical"),
    (mw_interface, "Molecular Weight"),
    (logp_interface, "LogP"),
    (tpsa_interface, "TPSA"),
]
demo = gr.TabbedInterface(
    [interface for interface, _ in _demo_tabs],
    [tab_title for _, tab_title in _demo_tabs],
    title="RDKit API",
    css=".gradio-container {max-width: 800px; margin: auto;}",
)
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    queued_demo = demo.queue()
    queued_demo.launch(server_name="0.0.0.0", server_port=7860)