PoreGCN / visualize.py
catenate's picture
visualize.py: Fast/Full XAI toggle + skip +0.00 cavity labels + About Ensembles and XAI explainer
8ffab12 verified
Raw
History Blame Contribute Delete
38.3 kB
"""
visualize.py - 3Dmol.js 3D visualization and iRASPA CIF export for PoreGCN HF Space.
Frontend contract (app.py imports):
from visualize import create_3d_visualization, export_iraspa_cif
create_3d_visualization():
Returns an HTML string (with embedded 3Dmol.js) ready for gr.HTML().
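    - Per-atom attribution is carried in the B-factor channel
      (1 = most negative, 50 = neutral, 99 = most positive) and is shown,
      de-scaled back to the signed value, in the atom hover tooltip
    - Metal atoms rendered as large cornflower-blue spheres (radius 0.85)
      to visually mark SBU centers
    - Non-metal atoms rendered as ball-and-stick with Jmol colours
      (sphere radius 0.28, stick radius 0.12; hydrogens drawn smaller)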
- Pores shown as a Mercury-style translucent yellow void isosurface
(rolling 1.2 A probe + marching cubes; CCDC convention)
- Per-cavity attribution carried by small opaque orange/blue beads
at cavity centres, sized by |attribution|
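    - Unit cell wireframe drawn manually as 12 line segments with a/b/c
      labels at the axis tips (3Dmol's addUnitCell() is intentionally bypassed)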
export_iraspa_cif():
Writes a CIF with _atom_site_B_iso_or_equiv overwritten by normalised
attributions in [1, 99] (1=most negative, 50=neutral, 99=most positive).
Compatible with iRASPA's Temperature Factor colouring.
"""
from __future__ import annotations
import os
import re
from pathlib import Path
from typing import Optional
import numpy as np
# Bundled 3Dmol.js library, loaded once at module import.
# Avoids CDN dependency so the viewer works on networks that block cdnjs/jsdelivr.
_3DMOL_JS_PATH = Path(__file__).parent / "3Dmol-min.js"
_3DMOL_JS_INLINE: Optional[str] = None
# is_file() rather than exists(): a directory with this name must not be read.
if _3DMOL_JS_PATH.is_file():
    try:
        _3DMOL_JS_INLINE = _3DMOL_JS_PATH.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # An unreadable or mis-encoded bundle must not break importing this
        # module. Leaving the value as None selects the CDN code path in
        # _inline_3dmol_library.
        _3DMOL_JS_INLINE = None
def _inline_3dmol_library(viewer_html: str, width: int = 800, height: int = 600) -> str:
    """Wrap py3Dmol viewer HTML in an ``<iframe srcdoc=...>`` and return it.

    Why an iframe:
      1. Gradio's gr.HTML component sanitises <script> tags, so the py3Dmol
         JS would never run. A srcdoc iframe executes scripts in its own
         document context, outside the sanitiser.
      2. py3Dmol's generated HTML fetches 3Dmol.js from a CDN through
         loadScriptAsync(). When a bundled copy is available it is inlined
         into the iframe <head> and the first loadScriptAsync(<url>) call is
         swapped for an already-resolved promise, so no network is needed.

    When no bundled copy was loaded, the viewer HTML is wrapped unchanged
    and the CDN call is kept.

    Args:
        viewer_html: HTML produced by py3Dmol for one viewer.
        width: Accepted for signature symmetry with the caller; the iframe
            itself is always emitted with width="100%".
        height: Viewer height in pixels; the iframe is 20 px taller.

    Returns:
        The iframe markup as a string.
    """
    bundled_js = _3DMOL_JS_INLINE
    if bundled_js is not None:
        head_extra = f'<script type="text/javascript">{bundled_js}</script>'
        body_style = "margin:0;padding:0;overflow:hidden;background:black;"
        # Only the first CDN loader call is neutralised.
        body_markup = re.sub(
            r"loadScriptAsync\(['\"]https?://[^'\"]+['\"]\)",
            "Promise.resolve()",
            viewer_html,
            count=1,
        )
    else:
        head_extra = ""
        body_style = "margin:0;padding:0;"
        body_markup = viewer_html

    document = "".join([
        "<!DOCTYPE html><html><head><meta charset='utf-8'>",
        head_extra,
        "</head>",
        f"<body style='{body_style}'>",
        body_markup,
        "</body></html>",
    ])

    # The document sits inside a double-quoted attribute: escape ampersands
    # first, then double quotes.
    escaped = document
    for raw, entity in (("&", "&amp;"), ('"', "&quot;")):
        escaped = escaped.replace(raw, entity)

    frame_height = height + 20
    return "".join([
        f'<iframe srcdoc="{escaped}" ',
        f'width="100%" height="{frame_height}" ',
        'frameborder="0" style="border:1px solid #1f2937;border-radius:4px;background:black;"',
        '></iframe>',
    ])
# Optional third-party dependencies. Each import is attempted once at module
# load; the *_OK flag records the outcome so the public functions can report a
# missing package instead of failing with a NameError.
try:
    import py3Dmol
except ImportError:
    PY3DMOL_OK = False
else:
    PY3DMOL_OK = True

try:
    from pymatgen.core import Structure
    from pymatgen.io.cif import CifWriter
except ImportError:
    PYMATGEN_OK = False
else:
    PYMATGEN_OK = True
# Element symbols that receive the large-sphere "metal" style in the 3D viewer.
# The selection covers most MOF secondary building units (SBUs).
_METAL_ELEMENTS = set(
    "Cu Zn Co Zr Fe Mn Ni Mg Ca V "
    "Al Cr Ti Hf In Ga Y Ce Nd Tb "
    "Eu Gd La Mo W Ru Rh Pd Cd Sn "
    "Ba Sr Li Na K Rb Cs".split()
)
def _standard_cif_lattice(abc, angles_deg):
    """Return the lattice matrix in the standard CIF Cartesian orientation.

    Orientation produced (the convention viewers apply when a CIF gives only
    the cell scalars a, b, c, alpha, beta, gamma):
        a lies along +x
        b lies in the xy plane with positive y
        c has a non-negative z component

    Pymatgen's ``Lattice.matrix`` is not guaranteed to be in this orientation,
    while a CIF written from fractional coordinates plus cell scalars is
    re-expanded by the viewer in the standard one. Geometry that is drawn next
    to the parsed CIF therefore has to be expressed with this matrix.

    Args:
        abc: the three cell lengths (a, b, c).
        angles_deg: the three cell angles (alpha, beta, gamma) in degrees.

    Returns:
        3x3 float ndarray whose rows are the a, b and c vectors.
    """
    len_a, len_b, len_c = abc
    alpha, beta, gamma = (np.deg2rad(angle) for angle in angles_deg)

    cos_alpha = np.cos(alpha)
    cos_beta = np.cos(beta)
    cos_gamma = np.cos(gamma)
    sin_gamma = np.sin(gamma)

    vec_a = [len_a, 0.0, 0.0]
    vec_b = [len_b * cos_gamma, len_b * sin_gamma, 0.0]

    # Only the divisor is floored; b keeps the true sine.
    divisor = sin_gamma
    if not abs(divisor) > 1e-12:
        divisor = 1e-12

    c_x = len_c * cos_beta
    c_y = len_c * (cos_alpha - cos_beta * cos_gamma) / divisor
    # Clamp at zero so rounding noise cannot yield the sqrt of a negative.
    c_z = np.sqrt(max(len_c * len_c - c_x * c_x - c_y * c_y, 0.0))

    return np.array([vec_a, vec_b, [c_x, c_y, c_z]], dtype=float)
def _compute_void_mesh(
    structure,
    probe_radius: float = 1.2,
    target_grid_spacing: float = 0.7,
    n_per_axis_min: int = 16,
    n_per_axis_max: int = 60,
    max_vertices: int = 40000,
    vdw_eff: float = 1.5,
):
    """Build a Mercury-style contact-surface void mesh for one unit cell.

    The cell is sampled on a regular fractional grid. A grid point counts as
    void when its nearest atom (periodic images included) is farther than
    ``vdw_eff + probe_radius``. Marching cubes then extracts the boundary
    between void and non-void points. The idea mirrors Mercury's
    Display > Voids > Contact Surface (CCDC, see
    https://www.ccdc.cam.ac.uk/discover/blog/how-to-search-visualize-and-analyse-mof-structures).

    Periodicity is handled by replicating the atoms over the 3x3x3 block of
    neighbouring cells before the nearest-neighbour query. A single effective
    van der Waals radius (``vdw_eff``) is used for every element.

    Safeguards:
      * the number of samples per axis is clamped to
        [``n_per_axis_min``, ``n_per_axis_max``];
      * a surface with more than ``max_vertices`` vertices is rejected;
      * any failure (scikit-image or scipy missing, non-positive cell length,
        no atoms, grid that is entirely void or entirely solid, marching-cubes
        error, or any other exception) yields ``(None, None)`` so the caller
        can render without a mesh.

    Returns
    -------
    (vertices_cart, faces) where
        vertices_cart: np.ndarray [V, 3], Cartesian Angstroms in the standard
            CIF lattice orientation (see _standard_cif_lattice)
        faces: np.ndarray [F, 3] of integer vertex indices per triangle
    or (None, None).
    """
    no_mesh = (None, None)

    # Both packages are optional; without either one there is no mesh.
    try:
        from skimage.measure import marching_cubes
        from scipy.spatial import cKDTree
    except Exception:
        return no_mesh

    def _samples_along(edge_length):
        # Nominal count from the target spacing, clamped to the budget.
        wanted = np.ceil(edge_length / target_grid_spacing)
        return int(np.clip(wanted, n_per_axis_min, n_per_axis_max))

    try:
        cell = np.asarray(structure.lattice.matrix, dtype=float)
        len_a, len_b, len_c = structure.lattice.abc
        if min(len_a, len_b, len_c) <= 0.0:
            return no_mesh

        nx = _samples_along(len_a)
        ny = _samples_along(len_b)
        nz = _samples_along(len_c)

        # Fractional sample positions; endpoint=False keeps them in [0, 1).
        axes = [np.linspace(0.0, 1.0, count, endpoint=False)
                for count in (nx, ny, nz)]
        frac_grid = np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1)
        grid_points = frac_grid.reshape(-1, 3) @ cell

        atoms = np.array([site.coords for site in structure], dtype=float)
        if atoms.shape[0] == 0:
            return no_mesh

        # Translation vectors of the 27 neighbouring cells (self included).
        steps = (-1, 0, 1)
        translations = np.array([
            i * cell[0] + j * cell[1] + k * cell[2]
            for i in steps for j in steps for k in steps
        ])
        periodic_atoms = (
            translations[:, None, :] + atoms[None, :, :]
        ).reshape(-1, 3)

        nearest, _ = cKDTree(periodic_atoms).query(grid_points, k=1)
        exclusion = vdw_eff + probe_radius
        occupancy = (nearest > exclusion).reshape(nx, ny, nz).astype(np.float32)

        # A uniform grid has no boundary to extract.
        if occupancy.max() < 0.5 or occupancy.min() > 0.5:
            return no_mesh

        vertex_idx, triangles, _, _ = marching_cubes(
            occupancy, level=0.5, allow_degenerate=False
        )

        # Keep the payload shipped to the browser bounded.
        if vertex_idx.shape[0] > max_vertices:
            return no_mesh

        # Grid indices -> fractional coordinates, which do not depend on the
        # lattice orientation.
        vertex_frac = vertex_idx / np.array([nx, ny, nz], dtype=float)

        # Defensive: discard triangles touching a vertex outside the cell.
        tol = 1e-6
        in_cell = np.all(
            (vertex_frac >= -tol) & (vertex_frac <= 1.0 + tol), axis=1
        )
        if not in_cell.all():
            triangles = triangles[in_cell[triangles].all(axis=1)]
            if triangles.shape[0] == 0:
                return no_mesh

        # Fractional -> Cartesian in the standard CIF orientation, so the mesh
        # lines up with atoms the viewer positions from the CIF. Using `cell`
        # (pymatgen's matrix) here could put the mesh in a different frame.
        canonical = _standard_cif_lattice(structure.lattice.abc,
                                          structure.lattice.angles)
        return (vertex_frac @ canonical).astype(float), triangles.astype(int)
    except Exception:
        # Catch-all: a malformed structure must never break the rest of the
        # rendering pipeline; the caller renders without a mesh.
        return no_mesh
def _cluster_pore_vertices(
    positions: np.ndarray,
    radii: np.ndarray,
    attrs: np.ndarray,
) -> list:
    """Group Voronoi pore vertices into cavity centres.

    Greedy, non-maximum-suppression style pass: vertices are visited from the
    largest inscribed-sphere radius to the smallest. Each vertex that has not
    been absorbed yet becomes a cavity centre and absorbs every still-free
    vertex lying within its (floored) radius. Distances are plain Euclidean
    distances between the supplied positions.

    The one-sphere-per-cavity idiom follows Lisensky and Yaghi,
    J. Chem. Educ. 2022, 99, 1998-2004 (Figures 1-8).

    Args:
        positions: [N, 3] vertex coordinates.
        radii: [N] inscribed-sphere radii.
        attrs: [N] signed attributions.

    Returns:
        List of dicts, one per cavity, in visiting order, with keys
        'center' (coordinates of the seed vertex), 'radius' (seed radius,
        floored at 0.5), 'attr' (signed attribution of the member with the
        largest magnitude) and 'members' (indices of absorbed vertices).
        Empty list when there are no vertices.
    """
    total = len(positions)
    if total == 0:
        return []

    xyz = np.asarray(positions, dtype=float)
    sphere_r = np.asarray(radii, dtype=float)
    scores = np.asarray(attrs, dtype=float)

    taken = np.zeros(total, dtype=bool)
    cavities: list = []

    for seed in np.argsort(sphere_r)[::-1]:
        if taken[seed]:
            continue

        seed_xyz = xyz[seed]
        # Floor the radius so a tiny sphere still yields a visible cavity.
        reach = float(max(sphere_r[seed], 0.5))

        separation = np.linalg.norm(xyz - seed_xyz, axis=1)
        absorbed = np.flatnonzero((separation <= reach) & ~taken)
        if absorbed.size == 0:
            absorbed = np.array([seed])
        taken[absorbed] = True

        # Represent the cavity by the signed value of its strongest member;
        # a mean would let positive and negative contributions cancel.
        absorbed_scores = scores[absorbed]
        strongest = int(np.argmax(np.abs(absorbed_scores)))

        cavities.append({
            'center': seed_xyz,
            'radius': reach,
            'attr': float(absorbed_scores[strongest]),
            'members': absorbed,
        })

    return cavities
# =============================================================================
# Private helper: CIF string with B-factors injected
# =============================================================================
def _append_bfactor_column(cif_text: str, b_factors: np.ndarray, n_rows: int) -> str:
    """Append a ``_atom_site_B_iso_or_equiv`` column to the atom-site loop.

    The text is scanned line by line. After the run of ``_atom_site_*``
    column names, the B-factor column name is added (unless the loop already
    declares it, in which case the text is left as is). Every following data
    row with at least four whitespace-separated fields then receives one
    B-factor value, until a blank line ends the loop or ``n_rows`` rows have
    been written.

    Args:
        cif_text: CIF text to rewrite.
        b_factors: values to append, in row order. Rows beyond the end of
            this array receive the neutral value 50.00.
        n_rows: maximum number of data rows to extend.

    Returns:
        The rewritten CIF text.
    """
    out_lines = []
    header_names: list = []
    in_header = False      # inside the run of `_atom_site_*` column names
    in_rows = False        # inside the data rows that follow the names
    column_added = False
    row_index = 0

    for line in cif_text.split('\n'):
        stripped = line.strip()

        if stripped.startswith('_atom_site_'):
            in_header = True
            header_names.append(stripped)
            out_lines.append(line)
            continue

        if in_header:
            # First line after the column names: declare the new column.
            if '_atom_site_B_iso_or_equiv' not in '\n'.join(header_names):
                out_lines.append(' _atom_site_B_iso_or_equiv')
                column_added = True
            in_header = False
            in_rows = True

        is_atom_row = (
            in_rows
            and column_added
            and bool(stripped)
            and not stripped.startswith(('_', 'loop_', '#'))
            and len(stripped.split()) >= 4
            and row_index < n_rows
        )
        if is_atom_row:
            # Every row must carry a value or the CIF loop becomes ragged.
            bf = b_factors[row_index] if row_index < len(b_factors) else 50.0
            line = f'{line} {bf:.2f}'
            row_index += 1
        elif in_rows and not stripped:
            in_rows = False

        out_lines.append(line)

    return '\n'.join(out_lines)


def _cif_string_with_bfactors(structure, per_atom_attrs: np.ndarray) -> str:
    """
    Build a CIF text string from a pymatgen Structure with per-atom B-factors
    encoding XAI attributions on the iRASPA scale:
        1  = most negative attribution
        50 = neutral
        99 = most positive attribution

    The mapping is ``B = 50 + 49 * attr / max(|attr|)``, with the divisor
    floored at 1e-8. Both create_3d_visualization and export_iraspa_cif
    delegate here so the B-factor injection logic is defined exactly once.

    Robustness:
      * non-finite attributions (NaN/inf) are treated as 0 (neutral) so one
        bad value cannot turn every B-factor into "nan";
      * an empty attribution array is accepted;
      * atoms without a matching attribution entry get the neutral value 50,
        mirroring export_attribution_csv, which uses 0.0 for a missing entry.

    Atom fractional coordinates are wrapped to [0, 1) before writing the CIF
    so that no atoms appear outside the unit cell wireframe in the rendered
    viewer. If rebuilding the wrapped structure fails, the original structure
    is written unchanged.

    Args:
        structure: pymatgen Structure
        per_atom_attrs: np.ndarray [N_atoms], signed attributions

    Returns:
        CIF text string with _atom_site_B_iso_or_equiv column appended.
    """
    attrs = np.asarray(per_atom_attrs, dtype=float).ravel()
    attrs = np.where(np.isfinite(attrs), attrs, 0.0)
    abs_max = max(float(np.abs(attrs).max()), 1e-8) if attrs.size else 1e-8
    b_factors = 50.0 + 49.0 * (attrs / abs_max)

    # Wrap fractional coordinates to [0, 1) so 3Dmol does not render atoms
    # outside the unit cell wireframe. Atom order is preserved so that
    # per_atom_attrs indexing remains valid.
    try:
        wrapped_struct = Structure(
            lattice=structure.lattice,
            species=[site.species for site in structure],
            coords=[(np.asarray(site.frac_coords, dtype=float) % 1.0)
                    for site in structure],
            coords_are_cartesian=False,
        )
    except Exception:
        wrapped_struct = structure

    cif_text = str(CifWriter(wrapped_struct))
    return _append_bfactor_column(cif_text, b_factors, len(wrapped_struct))
# =============================================================================
# create_3d_visualization
# =============================================================================
def create_3d_visualization(
    structure,
    per_atom_attrs: np.ndarray,
    per_pore_attrs: np.ndarray,
    pore_positions: np.ndarray,
    pore_radii: np.ndarray,
    property_name: str,
    prediction_value: Optional[float] = None,
    scenario: Optional[str] = None,
    width: int = 800,
    height: int = 600,
) -> str:
    """
    Render the MOF structure in 3Dmol.js with chemistry-style aesthetics.

    Per-atom attribution travels in the CIF B-factor channel
    (B = 1 most negative, 50 neutral, 99 most positive) and is shown,
    de-scaled back to the signed value, in the atom hover tooltip.

    Rendering:
        Metal atoms:     large cornflower-blue spheres (radius 0.85) marking
                         SBU centres
        Non-metal atoms: ball-and-stick with Jmol colours; hydrogens smaller
        Void mesh:       translucent yellow surface from _compute_void_mesh
                         (silently omitted when no mesh can be built)
        Attribution:     opaque orange (attr >= 0) / blue (attr < 0) beads at
                         cavity centres, radius 0.5-0.9 A scaled by |attr|,
                         at most 12 beads, only cavities whose |attr| is at
                         least 10% of the largest cavity |attr|
        Unit cell:       light-gray wireframe with a/b/c labels

    Args:
        structure:        pymatgen Structure
        per_atom_attrs:   np.ndarray [N_atoms], signed attributions
        per_pore_attrs:   np.ndarray [N_pores], signed attributions
        pore_positions:   np.ndarray [N_pores, 3] Cartesian Angstroms
        pore_radii:       np.ndarray [N_pores] Angstroms
        property_name:    Property string for the info header (inserted into
                          the header HTML as given)
        prediction_value: Optional float; shown in header if given
        scenario:         Optional scenario letter ('A'/'B'/'C'/'D')
        width:            Canvas width in pixels (default 800)
        height:           Canvas height in pixels (default 600)

    Returns:
        HTML string with embedded 3Dmol.js, ready for gr.HTML(). When py3Dmol
        or pymatgen is missing, or CIF generation fails, a short error <div>
        is returned instead.
    """
    if not PY3DMOL_OK:
        return (
            '<div style="padding:20px;color:#991b1b;background:#fee2e2;border-radius:4px;">'
            '<b>py3Dmol not installed.</b> Run: pip install py3Dmol'
            '</div>'
        )
    if not PYMATGEN_OK:
        return (
            '<div style="padding:20px;color:#991b1b;background:#fee2e2;border-radius:4px;">'
            '<b>pymatgen not installed.</b> Run: pip install pymatgen'
            '</div>'
        )

    per_atom_attrs = np.asarray(per_atom_attrs, dtype=float)
    per_pore_attrs = np.asarray(per_pore_attrs, dtype=float)
    pore_positions = np.asarray(pore_positions, dtype=float)
    pore_radii = np.asarray(pore_radii, dtype=float)

    # Build CIF string with B-factors
    try:
        cif_str = _cif_string_with_bfactors(structure, per_atom_attrs)
    except Exception as exc:
        # Exception text may contain '<', '>' or '&'; escape it so it cannot
        # corrupt the surrounding markup.
        from html import escape as _escape_html
        return (
            '<div style="padding:20px;color:#991b1b;background:#fee2e2;border-radius:4px;">'
            f'<b>CIF generation failed:</b> {_escape_html(str(exc))}'
            '</div>'
        )

    # Create the 3Dmol viewer
    view = py3Dmol.view(width=width, height=height)
    view.addModel(cif_str, 'cif')

    # Metals are drawn as bold blue spheres (the SBU centres); every other
    # element keeps its standard Jmol colour.
    metal_style = {'color': '#60a5fa'}        # cornflower blue, visible on black
    nonmetal_style = {'colorscheme': 'Jmol'}  # standard chemistry colors

    # Determine which elements are present in the structure
    elements_present = set(str(site.specie) for site in structure)
    metals_present = sorted(elements_present & _METAL_ELEMENTS)
    nonmetals_present = sorted(elements_present - _METAL_ELEMENTS)

    # Default: ball-and-stick with Jmol coloring for non-metal framework atoms
    view.setStyle(
        {},
        {
            'sphere': {'radius': 0.28, **nonmetal_style},
            'stick': {'radius': 0.12, **nonmetal_style},
        },
    )
    # Metal SBU centers: large cornflower-blue spheres
    if metals_present:
        view.setStyle(
            {'elem': metals_present},
            {
                'sphere': {'radius': 0.85, **metal_style},
                'stick': {'radius': 0.16, **metal_style},
            },
        )
    # Hydrogens: tiny spheres so they do not dominate the view
    if 'H' in nonmetals_present:
        view.setStyle(
            {'elem': 'H'},
            {
                'sphere': {'radius': 0.14, **nonmetal_style},
                'stick': {'radius': 0.06, **nonmetal_style},
            },
        )

    # Mercury-style void surface: a translucent yellow mesh at the void
    # boundary. The mesh carries the geometric signal (where the pores are and
    # what shape they have); per-cavity attribution sign is carried by the
    # coloured beads added below.
    try:
        verts_cart, faces = _compute_void_mesh(structure)
    except Exception:
        verts_cart, faces = None, None
    if verts_cart is not None and faces is not None and len(faces) > 0:
        try:
            vertex_arr = [
                {'x': float(v[0]), 'y': float(v[1]), 'z': float(v[2])}
                for v in verts_cart
            ]
            face_arr = faces.flatten().tolist()
            view.addCustom({
                'vertexArr': vertex_arr,
                'faceArr': face_arr,
                'color': '#fbbf24',  # warm Mercury-style yellow
                'opacity': 0.55,
            })
        except Exception:
            # Mesh shape construction failed (e.g. py3Dmol API mismatch on
            # some build). Skip the mesh and continue with the rest of the
            # render so the user still sees atoms and attribution beads.
            pass

    # Per-cavity attribution glyphs: small opaque beads at cluster centres.
    # Size scales with |attribution| so the strongest contributing cavity
    # is visually largest.
    # NOTE(review): pore_positions are used as supplied, while the void mesh
    # and the cell wireframe are expressed in the standard CIF frame. Confirm
    # that callers provide pore positions in that same frame.
    n_attribution_beads = 0
    n_pores = len(per_pore_attrs)
    # All three per-pore arrays must agree in length; otherwise clustering
    # would index past the end of one of them.
    if (
        n_pores > 0
        and len(pore_positions) == n_pores
        and len(pore_radii) == n_pores
    ):
        clusters = _cluster_pore_vertices(pore_positions, pore_radii, per_pore_attrs)
        if clusters:
            cluster_abs = np.array([abs(c['attr']) for c in clusters])
            cluster_abs_max = float(cluster_abs.max()) if cluster_abs.max() > 0 else 1e-8
            threshold = 0.10 * cluster_abs_max
            ranked = sorted(clusters, key=lambda c: abs(c['attr']), reverse=True)
            shown = 0
            for c in ranked:
                if shown >= 12:
                    break
                # ranked is sorted by |attr|, so everything after the first
                # sub-threshold cavity is below the threshold too.
                if abs(c['attr']) < threshold:
                    break
                centre = c['center']
                color = '#ea580c' if c['attr'] >= 0 else '#1d4ed8'
                # Bead radius capped at 0.9 A so the markers read as small
                # attribution glyphs rather than rivalling the SBU spheres
                bead_radius = 0.5 + 0.4 * (abs(c['attr']) / cluster_abs_max)
                view.addSphere({
                    'center': {'x': float(centre[0]),
                               'y': float(centre[1]),
                               'z': float(centre[2])},
                    'radius': float(bead_radius),
                    'color': color,
                    'opacity': 1.0,
                })
                # Persistent label next to the bead so the user can read the
                # signed cavity attribution without hovering. Skipped when
                # the value would print as "+0.00" / "-0.00".
                if abs(c['attr']) >= 0.005:
                    view.addLabel(
                        f'{c["attr"]:+.2f}',
                        {
                            'position': {
                                'x': float(centre[0]) + float(bead_radius) + 0.4,
                                'y': float(centre[1]),
                                'z': float(centre[2]),
                            },
                            'backgroundColor': '#0f172a',
                            'backgroundOpacity': 0.6,
                            'fontColor': 'white',
                            'fontSize': 10,
                            'borderThickness': 0,
                            'showBackground': True,
                            'inFront': True,
                        },
                    )
                shown += 1
                n_attribution_beads += 1

    # Unit cell wireframe drawn manually as 12 line segments, plus a/b/c
    # labels at the axis tips. 3Dmol's addUnitCell() is intentionally
    # bypassed because its `astyle.hidden` flag for the colored a/b/c axis
    # arrows is not honoured reliably across versions, and the default
    # red/green/blue arrows competed with the attribution colour signal.
    cell_lat = _standard_cif_lattice(structure.lattice.abc,
                                     structure.lattice.angles)
    corners_frac = np.array([
        [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
        [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1],
    ], dtype=float)
    corners_cart = corners_frac @ cell_lat
    # Two corners share an edge when they differ in exactly one coordinate.
    edge_pairs = []
    for i_e in range(8):
        for j_e in range(i_e + 1, 8):
            if abs(corners_frac[j_e] - corners_frac[i_e]).sum() == 1.0:
                edge_pairs.append((i_e, j_e))
    for i_e, j_e in edge_pairs:
        view.addLine({
            'start': {'x': float(corners_cart[i_e, 0]),
                      'y': float(corners_cart[i_e, 1]),
                      'z': float(corners_cart[i_e, 2])},
            'end': {'x': float(corners_cart[j_e, 0]),
                    'y': float(corners_cart[j_e, 1]),
                    'z': float(corners_cart[j_e, 2])},
            'color': '#e2e8f0',
            'linewidth': 2.0,
        })
    # Axis tip labels using the crystallographic a/b/c convention
    for tip_idx, label_text in [(1, 'a'), (2, 'b'), (3, 'c')]:
        tip = corners_cart[tip_idx]
        view.addLabel(label_text, {
            'position': {'x': float(tip[0]),
                         'y': float(tip[1]),
                         'z': float(tip[2])},
            'fontColor': 'white',
            'fontSize': 14,
            'backgroundColor': 'black',
            'backgroundOpacity': 0.55,
            'borderThickness': 0,
            'inFront': True,
        })

    # Scale used by the hover label to turn a B-factor back into the signed
    # attribution. It uses the same 1e-8 floor as the B-factor encoding:
    #   B = 50 + 49 * (attr / abs_max)  ->  attr = (B - 50) / 49 * abs_max
    finite_attrs = per_atom_attrs[np.isfinite(per_atom_attrs)]
    abs_max = (
        max(float(np.abs(finite_attrs).max()), 1e-8)
        if finite_attrs.size else 1e-8
    )

    # Hover tooltip on atoms: element + index + position + signed attribution.
    # The scale is written in exponent notation; a fixed six-decimal format
    # would turn any scale below 5e-7 into 0 and blank out every tooltip.
    hover_callback = (
        'function(atom, viewer) {'
        ' if(atom.label) return;'
        f' var absMax = {abs_max:.12e};'
        ' var attr = ((atom.b - 50.0) / 49.0) * absMax;'
        ' var sign = attr >= 0 ? "+" : "";'
        ' var line1 = atom.elem + " #" + atom.serial;'
        ' var line2 = "Attribution: " + sign + attr.toFixed(4);'
        ' var line3 = "Coords: (" + atom.x.toFixed(2) + ", " + atom.y.toFixed(2) + ", " + atom.z.toFixed(2) + ") A";'
        ' atom.label = viewer.addLabel(line1 + "\\n" + line2 + "\\n" + line3, {'
        ' position: atom, backgroundColor: "#0f172a", fontColor: "#ffffff",'
        ' fontSize: 11, padding: 5, borderRadius: 3, alignment: "topLeft"'
        ' });'
        '}'
    )
    unhover_callback = (
        'function(atom, viewer) {'
        ' if(atom.label) { viewer.removeLabel(atom.label); delete atom.label; }'
        '}'
    )
    view.setHoverable({}, True, hover_callback, unhover_callback)

    # Finalize
    view.setBackgroundColor('black')
    view.zoomTo()
    viewer_html = _inline_3dmol_library(view._make_html(), width=width, height=height)

    # Build a property/scenario header to display above the 3Dmol canvas
    label_map = {'A': 'Trustworthy', 'B': 'Overconfident',
                 'C': 'Underconfident', 'D': 'Unreliable'}
    header_parts = [f'XAI &mdash; <b>{property_name}</b>']
    if prediction_value is not None:
        header_parts.append(f'pred = {prediction_value:.3g}')
    if scenario:
        sc_label = label_map.get(scenario, '')
        sc_colors = {
            'A': ('#16a34a', '#dcfce7'),
            'B': ('#d97706', '#fef3c7'),
            'C': ('#2563eb', '#dbeafe'),
            'D': ('#dc2626', '#fee2e2'),
        }
        fg, bg = sc_colors.get(scenario, ('#334155', '#f1f5f9'))
        header_parts.append(
            f'<span style="background:{bg};color:{fg};padding:2px 8px;'
            f'border-radius:3px;font-size:0.88em;">Scenario {scenario}: {sc_label}</span>'
        )
    header_html = (
        '<div style="font-family:\'IBM Plex Mono\',monospace;font-size:0.80em;'
        'color:#334155;padding:8px 12px;background:#f8fafc;border-left:3px solid #0891b2;'
        'margin-bottom:4px;line-height:1.7;">'
        + ' &ensp;&bull;&ensp; '.join(header_parts)
        + '<span style="float:right;color:#94a3b8;">'
        'Hover atoms for details &bull; Drag to rotate &bull; Scroll to zoom'
        '</span></div>'
    )

    # Legend: only show entries that actually appear in the rendered scene.
    # Metals are overridden to cornflower blue in the viewer (to be visible
    # on the black background), so their swatch colour is overridden here too
    # rather than using the Jmol palette default. Cavity attribution beads
    # have no Jmol equivalent and get their own legend block when drawn.
    JMOL_COLORS = {
        'H': '#FFFFFF', 'C': '#909090', 'N': '#3050F8', 'O': '#FF0D0D',
        'F': '#90E050', 'Cl': '#1FF01F', 'Br': '#A62929', 'I': '#940094',
        'P': '#FF8000', 'S': '#FFFF30', 'B': '#FFB5B5', 'Si': '#F0C8A0',
        'Cu': '#C88033', 'Zn': '#7D80B0', 'Co': '#F090A0', 'Fe': '#E06633',
        'Ni': '#50D050', 'Mn': '#9C7AC7', 'Cr': '#8A99C7', 'V': '#A6A6AB',
        'Ti': '#BFC2C7', 'Zr': '#94E0E0', 'Hf': '#4DC2FF', 'Mg': '#8AFF00',
        'Ca': '#3DFF00', 'Al': '#BFA6A6', 'Cd': '#FFD98F', 'La': '#70D4FF',
        'Ce': '#FFFFC7', 'Ag': '#C0C0C0', 'Au': '#FFD123', 'Pb': '#575961',
    }
    METAL_VIEWER_COLOR = '#60a5fa'  # must match metal_style above

    def _legend_chip(color: str, text: str) -> str:
        # White swatches get a border so they stay visible on the white legend.
        border = 'border:1px solid #cbd5e1;' if color.upper() == '#FFFFFF' else ''
        return (
            f'<span style="display:inline-flex;align-items:center;gap:4px;'
            f'margin:0 8px 4px 0;font-family:\'IBM Plex Sans\',sans-serif;'
            f'font-size:0.78em;color:#334155;">'
            f'<span style="display:inline-block;width:11px;height:11px;'
            f'border-radius:50%;background:{color};{border}"></span>{text}'
            f'</span>'
        )

    element_chips = []
    for elem in sorted(elements_present):
        if elem in _METAL_ELEMENTS:
            color = METAL_VIEWER_COLOR
        else:
            color = JMOL_COLORS.get(elem, '#888888')
        element_chips.append(_legend_chip(color, elem))

    cavity_chips_html = ''
    if n_attribution_beads > 0:
        cavity_chips_html = (
            '<span style="font-family:\'IBM Plex Sans\',sans-serif;font-size:0.72em;'
            'font-weight:500;color:#94a3b8;text-transform:uppercase;letter-spacing:0.05em;'
            'margin:0 10px 0 18px;">Cavity attribution:</span>'
            + _legend_chip('#ea580c', '+ contribution')
            + _legend_chip('#1d4ed8', '&minus; contribution')
        )

    legend_html = (
        '<div style="padding:6px 12px;background:#ffffff;border:1px solid #e2e8f0;'
        'border-top:none;border-radius:0 0 4px 4px;margin-top:-4px;line-height:1.6;">'
        '<span style="font-family:\'IBM Plex Sans\',sans-serif;font-size:0.72em;'
        'font-weight:500;color:#94a3b8;text-transform:uppercase;letter-spacing:0.05em;'
        'margin-right:10px;">Elements:</span>'
        + ''.join(element_chips)
        + cavity_chips_html
        + '</div>'
    )
    return header_html + viewer_html + legend_html
# =============================================================================
# export_attribution_csv
# =============================================================================
def export_attribution_csv(
    structure,
    per_atom_attrs: np.ndarray,
    pore_positions: np.ndarray,
    pore_radii: np.ndarray,
    per_pore_attrs: np.ndarray,
    property_name: str,
    output_path: str,
) -> str:
    """
    Write per-atom and per-pore signed attribution to a CSV for downstream
    analysis (e.g. ranking sites, identifying design hotspots, or feeding into
    follow-up MD/DFT calculations on high-attribution motifs).

    The file starts with three '#' comment lines (title, origin, and the
    shared abs_max scale), followed by the column header and one row per atom
    and per pore vertex.

    Columns:
        kind         'atom' or 'pore'
        idx          sequential index within kind
        element      chemical symbol (atoms only; 'pore' for pore vertices)
        x, y, z      Cartesian coordinates (Angstroms)
        radius       inscribed-sphere radius (pores only; blank for atoms)
        attribution  signed attribution to <property_name>
        b_factor     attribution mapped to 1-99 iRASPA convention
                     (1=most negative, 50=neutral, 99=most positive), scaled
                     by the largest |attribution| over atoms AND pores

    Array arguments may be any array-like (lists included). Atoms without a
    matching attribution entry are written with attribution 0.0; pores
    without a matching radius are written with radius 0.0. Pore rows are
    written only when pore_positions has one entry per pore attribution.

    Args:
        structure:       iterable of sites exposing ``specie`` and ``coords``
        per_atom_attrs:  [N_atoms] signed attributions
        pore_positions:  [N_pores, 3] Cartesian coordinates
        pore_radii:      [N_pores] radii
        per_pore_attrs:  [N_pores] signed attributions
        property_name:   property string used in the title comment
        output_path:     destination .csv path; parent directories are created

    Returns:
        output_path, unchanged.
    """
    # Coerce once so plain lists work the same as ndarrays (the pore rows
    # below use 2-D indexing, which a list of lists does not support).
    per_atom_attrs = np.asarray(per_atom_attrs, dtype=float).ravel()
    per_pore_attrs = np.asarray(per_pore_attrs, dtype=float).ravel()
    pore_positions = np.asarray(pore_positions, dtype=float)
    pore_radii = np.asarray(pore_radii, dtype=float).ravel()

    abs_max = float(np.abs(per_atom_attrs).max()) if len(per_atom_attrs) > 0 else 1.0
    if len(per_pore_attrs) > 0:
        abs_max = max(abs_max, float(np.abs(per_pore_attrs).max()))
    if abs_max < 1e-12:
        abs_max = 1.0

    def _b_factor(attr: float) -> float:
        # Symmetric 1-99 mapping around the neutral value 50, clamped.
        return max(1.0, min(99.0, 50.0 + 49.0 * (attr / abs_max)))

    rows = [
        f'# PoreGCN per-atom and per-pore attribution to {property_name}',
        f'# Generated by huggingface.co/spaces/catenate/PoreGCN',
        f'# abs_max = {abs_max:.6f}',
        'kind,idx,element,x,y,z,radius,attribution,b_factor',
    ]

    # Atom rows
    for i, site in enumerate(structure):
        attr = float(per_atom_attrs[i]) if i < len(per_atom_attrs) else 0.0
        rows.append(
            f'atom,{i},{site.specie},{site.coords[0]:.4f},'
            f'{site.coords[1]:.4f},{site.coords[2]:.4f},,'
            f'{attr:.6f},{_b_factor(attr):.2f}'
        )

    # Pore rows
    n_pores = len(per_pore_attrs)
    if n_pores > 0 and len(pore_positions) == n_pores:
        for j in range(n_pores):
            attr = float(per_pore_attrs[j])
            r = float(pore_radii[j]) if j < len(pore_radii) else 0.0
            rows.append(
                f'pore,{j},pore,{pore_positions[j, 0]:.4f},'
                f'{pore_positions[j, 1]:.4f},{pore_positions[j, 2]:.4f},'
                f'{r:.3f},{attr:.6f},{_b_factor(attr):.2f}'
            )

    # Create the destination directory, as export_iraspa_cif does.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(rows))
        f.write('\n')
    return output_path
# =============================================================================
# export_iraspa_cif
# =============================================================================
def export_iraspa_cif(
    structure,
    per_atom_attrs: np.ndarray,
    output_path: str,
) -> str:
    """
    Save a CIF whose _atom_site_B_iso_or_equiv column encodes XAI attributions.

    B-factor scale (symmetric around 50):
        1  = most negative attribution (atom drives property DOWN strongly)
        50 = neutral (attribution ~ 0)
        99 = most positive attribution (atom drives property UP strongly)

    Viewing in iRASPA:
        1. File > Open > <output_path>
        2. Appearance > Atoms > Color by > Temperature Factor
        3. Apply a diverging blue-white-red colormap

    The CIF text comes from _cif_string_with_bfactors(), the same routine
    create_3d_visualization uses, so both outputs share one encoding.
    Missing parent directories of the destination are created.

    Args:
        structure:      pymatgen Structure
        per_atom_attrs: np.ndarray [N_atoms] signed attributions
        output_path:    Destination .cif file path

    Returns:
        The destination path as a string, for a gr.File() download component.

    Raises:
        ImportError: if pymatgen is not installed.
    """
    if not PYMATGEN_OK:
        raise ImportError('pymatgen is required: pip install pymatgen')

    attributions = np.asarray(per_atom_attrs, dtype=float)
    cif_text = _cif_string_with_bfactors(structure, attributions)

    destination = str(output_path)
    # A bare file name has no directory part; fall back to the current one.
    parent_dir = os.path.dirname(destination) or '.'
    os.makedirs(parent_dir, exist_ok=True)

    with open(destination, 'w', encoding='utf-8') as handle:
        handle.write(cif_text)
    return destination