"""
visualize.py - 3Dmol.js 3D visualization and iRASPA CIF export for PoreGCN HF Space.

Frontend contract (app.py imports):
    from visualize import create_3d_visualization, export_iraspa_cif

create_3d_visualization():
    Returns an HTML string (with embedded 3Dmol.js) ready for gr.HTML().
    - Per-atom attribution travels in the CIF B-factor channel
      (1 = most negative, 50 = neutral, 99 = most positive) and is shown in
      the hover tooltip; the atoms themselves are drawn in Jmol element colours
    - Metal atoms rendered as larger, uniformly coloured spheres (radius 0.85)
      to visually mark SBU centers
    - Non-metal atoms rendered as ball-and-stick (sphere radius 0.28,
      stick radius 0.12; hydrogens smaller)
    - Pores shown as a Mercury-style translucent yellow void isosurface
      (rolling 1.2 A probe + marching cubes; CCDC convention)
    - Per-cavity attribution carried by small opaque orange/blue beads
      at cavity centres, sized by |attribution|
    - Unit cell wireframe drawn edge by edge with addLine(), using the
      standard CIF lattice orientation

export_iraspa_cif():
    Writes a CIF with _atom_site_B_iso_or_equiv overwritten by normalised
    attributions in [1, 99] (1=most negative, 50=neutral, 99=most positive).
    Compatible with iRASPA's Temperature Factor colouring.
"""
|
|
| from __future__ import annotations |
|
|
| import os |
| import re |
| from pathlib import Path |
| from typing import Optional |
|
|
| import numpy as np |
|
|
| |
| |
| _3DMOL_JS_PATH = Path(__file__).parent / "3Dmol-min.js" |
| _3DMOL_JS_INLINE: Optional[str] = None |
| if _3DMOL_JS_PATH.exists(): |
| _3DMOL_JS_INLINE = _3DMOL_JS_PATH.read_text(encoding="utf-8") |
|
|
|
|
def _inline_3dmol_library(viewer_html: str, width: int = 800, height: int = 600) -> str:
    """Wrap py3Dmol viewer HTML in an isolated ``<iframe srcdoc=...>``.

    Two reasons we need the iframe:
      1. Gradio's gr.HTML component sanitises <script> tags, so the original
         py3Dmol JS would never run. An iframe with srcdoc executes scripts
         in its own document context, bypassing the sanitiser.
      2. py3Dmol's _make_html() loads 3Dmol.js from a CDN via
         loadScriptAsync(). When a vendored copy is available
         (``_3DMOL_JS_INLINE``) that call is replaced with an inlined copy so
         the viewer works on networks that block CDN access.

    Args:
        viewer_html: HTML fragment produced by py3Dmol.
        width: Accepted for signature compatibility; the iframe is always
            rendered at ``width="100%"`` and this value is not used.
        height: Viewer height in pixels; the iframe is 20 px taller.

    Returns:
        An ``<iframe>`` element (as a string) whose ``srcdoc`` attribute holds
        the complete, attribute-escaped viewer document.
    """
    if _3DMOL_JS_INLINE is None:
        # No vendored library: leave the viewer's own script loader untouched.
        full_doc = (
            "<!DOCTYPE html><html><head><meta charset='utf-8'></head>"
            f"<body style='margin:0;padding:0;'>{viewer_html}</body></html>"
        )
    else:
        inline_script = f'<script type="text/javascript">{_3DMOL_JS_INLINE}</script>'
        # Only the first remote loadScriptAsync(...) call is neutralised; the
        # library is already defined by the inlined <script> in <head>.
        patched_viewer = re.sub(
            r"loadScriptAsync\(['\"]https?://[^'\"]+['\"]\)",
            "Promise.resolve()",
            viewer_html,
            count=1,
        )
        full_doc = (
            "<!DOCTYPE html><html><head><meta charset='utf-8'>"
            f"{inline_script}</head>"
            f"<body style='margin:0;padding:0;overflow:hidden;background:black;'>"
            f"{patched_viewer}</body></html>"
        )

    # The document is embedded in a double-quoted attribute, so ampersands
    # must be escaped first (otherwise the entities produced for the quotes
    # would themselves be re-escaped) and then every double quote.
    srcdoc = full_doc.replace("&", "&amp;").replace('"', "&quot;")
    return (
        f'<iframe srcdoc="{srcdoc}" '
        f'width="100%" height="{height + 20}" '
        f'frameborder="0" style="border:1px solid #1f2937;border-radius:4px;background:black;"'
        f'></iframe>'
    )
|
|
| try: |
| import py3Dmol |
| PY3DMOL_OK = True |
| except ImportError: |
| PY3DMOL_OK = False |
|
|
| try: |
| from pymatgen.core import Structure |
| from pymatgen.io.cif import CifWriter |
| PYMATGEN_OK = True |
| except ImportError: |
| PYMATGEN_OK = False |
|
|
| |
| |
| _METAL_ELEMENTS = { |
| 'Cu', 'Zn', 'Co', 'Zr', 'Fe', 'Mn', 'Ni', 'Mg', 'Ca', 'V', |
| 'Al', 'Cr', 'Ti', 'Hf', 'In', 'Ga', 'Y', 'Ce', 'Nd', 'Tb', |
| 'Eu', 'Gd', 'La', 'Mo', 'W', 'Ru', 'Rh', 'Pd', 'Cd', 'Sn', |
| 'Ba', 'Sr', 'Li', 'Na', 'K', 'Rb', 'Cs', |
| } |
|
|
|
|
| def _standard_cif_lattice(abc, angles_deg): |
| """Build the standard CIF Cartesian lattice matrix from a/b/c/alpha/beta/gamma. |
| |
| The standard convention (used by 3Dmol.js, Mercury, VESTA, iRASPA when |
| parsing a CIF that specifies only cell scalars): |
| a along +x |
| b in the +xy plane (positive y) |
| c chosen so c_z > 0 |
| |
| Pymatgen does NOT always store its `Lattice.matrix` in this orientation; |
| when a CIF is loaded, the matrix may be rotated to a different frame. |
| Atom positions written by pymatgen's CifWriter contain only fractional |
| coordinates plus the cell scalars, so the rendering viewer (3Dmol) |
| re-derives Cartesian positions in the standard convention. Any geometry |
| we compute alongside the structure must be transformed into the same |
| standard frame before being passed to the viewer. |
| |
| Returns a 3x3 ndarray with row vectors a, b, c. |
| """ |
| a, b, c = abc |
| al, be, ga = [np.deg2rad(x) for x in angles_deg] |
| ax, ay, az = a, 0.0, 0.0 |
| bx, by, bz = b * np.cos(ga), b * np.sin(ga), 0.0 |
| cx = c * np.cos(be) |
| sin_ga = np.sin(ga) if abs(np.sin(ga)) > 1e-12 else 1e-12 |
| cy = c * (np.cos(al) - np.cos(be) * np.cos(ga)) / sin_ga |
| cz = np.sqrt(max(c * c - cx * cx - cy * cy, 0.0)) |
| return np.array([[ax, ay, az], [bx, by, bz], [cx, cy, cz]], dtype=float) |
|
|
|
|
def _compute_void_mesh(
    structure,
    probe_radius: float = 1.2,
    target_grid_spacing: float = 0.7,
    n_per_axis_min: int = 16,
    n_per_axis_max: int = 60,
    max_vertices: int = 40000,
    vdw_eff: float = 1.5,
):
    """Compute a Mercury-style contact-surface void mesh for the unit cell.

    A regular fractional grid is laid over the cell, every grid point farther
    than ``vdw_eff + probe_radius`` from the nearest atom is flagged as void,
    and marching cubes extracts the boundary of that region. Mirrors
    Mercury's Display > Voids > Contact Surface (CCDC, see
    https://www.ccdc.cam.ac.uk/discover/blog/how-to-search-visualize-and-analyse-mof-structures).

    Atoms are replicated over the 27 neighbouring cell translations so the
    nearest-atom query sees periodic images. One effective van der Waals
    radius (``vdw_eff``) is used for every element.

    The per-axis grid count is ``ceil(length / target_grid_spacing)`` clamped
    to [``n_per_axis_min``, ``n_per_axis_max``]. Vertices are mapped back to
    Cartesian space with ``_standard_cif_lattice`` rather than the structure's
    own lattice matrix.

    Every failure path returns ``(None, None)``: missing scikit-image or
    scipy, a non-positive cell length, no atoms, an all-void or all-solid
    grid, a marching-cubes error, more than ``max_vertices`` vertices, no
    faces left after filtering, or any other exception.

    Returns
    -------
    (vertices_cart, faces) where
        vertices_cart: np.ndarray [V, 3] Cartesian coordinates
        faces:         np.ndarray [F, 3] of vertex indices per triangle
    or (None, None).
    """
    no_mesh = (None, None)

    # Both heavy dependencies are optional; without them there is no mesh.
    try:
        from skimage.measure import marching_cubes
    except Exception:
        return no_mesh
    try:
        from scipy.spatial import cKDTree
    except Exception:
        return no_mesh

    try:
        cell = np.asarray(structure.lattice.matrix, dtype=float)
        len_a, len_b, len_c = structure.lattice.abc
        if min(len_a, len_b, len_c) <= 0.0:
            return no_mesh

        def _axis_count(length):
            # Grid points along one axis, clamped to the allowed range.
            return int(np.clip(np.ceil(length / target_grid_spacing),
                               n_per_axis_min, n_per_axis_max))

        nx = _axis_count(len_a)
        ny = _axis_count(len_b)
        nz = _axis_count(len_c)

        # Fractional sample points (endpoint excluded), then Cartesian.
        axes = [np.linspace(0.0, 1.0, count, endpoint=False)
                for count in (nx, ny, nz)]
        grid_frac = np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1)
        grid_cart = grid_frac.reshape(-1, 3) @ cell

        atom_cart = np.array([site.coords for site in structure], dtype=float)
        if atom_cart.shape[0] == 0:
            return no_mesh

        # Copies of every atom in the 27 neighbouring cell translations.
        images = np.vstack([
            atom_cart + (da * cell[0] + db * cell[1] + dc * cell[2])
            for da in (-1, 0, 1)
            for db in (-1, 0, 1)
            for dc in (-1, 0, 1)
        ])

        cutoff = vdw_eff + probe_radius
        nearest = cKDTree(images).query(grid_cart, k=1)[0]
        void = (nearest > cutoff).reshape(nx, ny, nz).astype(np.float32)

        # A uniform grid has no boundary to extract.
        if void.max() < 0.5 or void.min() > 0.5:
            return no_mesh

        try:
            verts_idx, faces, _, _ = marching_cubes(
                void, level=0.5, allow_degenerate=False
            )
        except Exception:
            return no_mesh

        if verts_idx.shape[0] > max_vertices:
            return no_mesh

        # Grid-index coordinates -> fractional coordinates.
        verts_frac = verts_idx / np.array([nx, ny, nz], dtype=float)

        # Discard triangles that touch a vertex outside the unit cell.
        eps = 1e-6
        within_cell = np.all(
            (verts_frac >= -eps) & (verts_frac <= 1.0 + eps), axis=1
        )
        if not within_cell.all():
            faces = faces[within_cell[faces].all(axis=1)]
            if faces.shape[0] == 0:
                return no_mesh

        std_lat = _standard_cif_lattice(structure.lattice.abc,
                                        structure.lattice.angles)
        verts_cart = verts_frac @ std_lat
        return verts_cart.astype(float), faces.astype(int)
    except Exception:
        # Best effort: the caller renders without a mesh.
        return no_mesh
|
|
|
|
| def _cluster_pore_vertices( |
| positions: np.ndarray, |
| radii: np.ndarray, |
| attrs: np.ndarray, |
| ) -> list: |
| """Cluster Voronoi pore vertices into cavity centres. |
| |
| Greedy non-maximum-suppression style: walk vertices in descending order |
| of inscribed-sphere radius, take the largest as a cavity centre, absorb |
| all unassigned vertices within that radius. The largest inscribed sphere |
| in a cavity *is* the cavity centre by construction (Zeo++ semantics), and |
| nearby smaller-radius vertices are pinch points along channels feeding it. |
| |
| Returns a list of dicts with center, radius, and aggregated attribution |
| statistics over cluster members. The cavity-scale sphere idiom follows |
| Lisensky and Yaghi, J. Chem. Educ. 2022, 99, 1998-2004 (Figures 1-8), |
| where one translucent sphere per crystallographic pore is drawn at the |
| cavity centre, sized to the cavity diameter. |
| """ |
| n = len(positions) |
| if n == 0: |
| return [] |
|
|
| positions = np.asarray(positions, dtype=float) |
| radii = np.asarray(radii, dtype=float) |
| attrs = np.asarray(attrs, dtype=float) |
|
|
| order = np.argsort(radii)[::-1] |
| consumed = np.zeros(n, dtype=bool) |
| clusters: list = [] |
|
|
| for i in order: |
| if consumed[i]: |
| continue |
| centre = positions[i] |
| cavity_radius = float(max(radii[i], 0.5)) |
| d = np.linalg.norm(positions - centre, axis=1) |
| members = np.where((d <= cavity_radius) & (~consumed))[0] |
| if len(members) == 0: |
| members = np.array([i]) |
| consumed[members] = True |
|
|
| member_attrs = attrs[members] |
| |
| |
| |
| |
| peak_idx = int(np.argmax(np.abs(member_attrs))) |
| peak_attr = float(member_attrs[peak_idx]) |
|
|
| clusters.append({ |
| 'center': centre, |
| 'radius': cavity_radius, |
| 'attr': peak_attr, |
| 'members': members, |
| }) |
|
|
| return clusters |
|
|
|
|
| |
| |
| |
|
|
def _cif_string_with_bfactors(structure, per_atom_attrs: np.ndarray) -> str:
    """
    Serialise a structure to CIF text with attributions in the B-factor column.

    Attributions are scaled by their largest magnitude (floored at 1e-8) onto
        1  = most negative attribution
        50 = neutral
        99 = most positive attribution
    and appended, two decimals each, as an extra
    ``_atom_site_B_iso_or_equiv`` column of the atom-site loop.

    Before writing, the structure is rebuilt with fractional coordinates
    taken modulo 1 so every atom lies inside [0, 1). If rebuilding fails the
    original structure is written unchanged.

    Values are assigned to atom rows in the order the rows appear in the
    written CIF; rows beyond ``len(per_atom_attrs)`` receive no value.

    Args:
        structure:      pymatgen Structure
        per_atom_attrs: np.ndarray [N_atoms], signed attributions

    Returns:
        CIF text with the B-factor column appended.
    """
    b_tag = '_atom_site_B_iso_or_equiv'

    scale = max(float(np.max(np.abs(per_atom_attrs))), 1e-8)
    b_values = 50.0 + 49.0 * (per_atom_attrs / scale)

    try:
        cif_source = Structure(
            lattice=structure.lattice,
            species=[site.species for site in structure],
            coords=[(np.asarray(site.frac_coords, dtype=float) % 1.0)
                    for site in structure],
            coords_are_cartesian=False,
        )
    except Exception:
        cif_source = structure

    raw_lines = str(CifWriter(cif_source)).split('\n')

    output = []
    header_tags: list = []
    reading_headers = False   # inside the run of _atom_site_ column tags
    reading_rows = False      # inside the data rows following those tags
    column_added = False      # True once the B-factor tag has been injected
    next_atom = 0
    n_values = len(b_values)

    for raw in raw_lines:
        text = raw.strip()

        # Column tags are copied through and remembered.
        if text.startswith('_atom_site_'):
            reading_headers = True
            header_tags.append(text)
            output.append(raw)
            continue

        # First line after the tags: add the B-factor tag unless present.
        if reading_headers:
            if not any(b_tag in tag for tag in header_tags):
                output.append(f' {b_tag}')
                column_added = True
            reading_headers = False
            reading_rows = True

        if reading_rows:
            if not text:
                # A blank line ends the atom rows.
                reading_rows = False
            elif (
                column_added
                and not text.startswith(('_', 'loop_', '#'))
                and len(text.split()) >= 4
                and next_atom < n_values
            ):
                raw = f'{raw} {b_values[next_atom]:.2f}'
                next_atom += 1

        output.append(raw)

    return '\n'.join(output)
|
|
|
|
| |
| |
| |
|
|
# One colour for every metal node, shared by the viewer and the legend.
_METAL_VIEWER_COLOR = '#60a5fa'

# Element colours for the legend chips of non-metal elements (metals use
# _METAL_VIEWER_COLOR); elements missing here fall back to grey.
_JMOL_COLORS = {
    'H': '#FFFFFF', 'C': '#909090', 'N': '#3050F8', 'O': '#FF0D0D',
    'F': '#90E050', 'Cl': '#1FF01F', 'Br': '#A62929', 'I': '#940094',
    'P': '#FF8000', 'S': '#FFFF30', 'B': '#FFB5B5', 'Si': '#F0C8A0',
    'Cu': '#C88033', 'Zn': '#7D80B0', 'Co': '#F090A0', 'Fe': '#E06633',
    'Ni': '#50D050', 'Mn': '#9C7AC7', 'Cr': '#8A99C7', 'V': '#A6A6AB',
    'Ti': '#BFC2C7', 'Zr': '#94E0E0', 'Hf': '#4DC2FF', 'Mg': '#8AFF00',
    'Ca': '#3DFF00', 'Al': '#BFA6A6', 'Cd': '#FFD98F', 'La': '#70D4FF',
    'Ce': '#FFFFC7', 'Ag': '#C0C0C0', 'Au': '#FFD123', 'Pb': '#575961',
}

# Scenario letter -> badge text and (foreground, background) colours.
_SCENARIO_LABELS = {
    'A': 'Trustworthy',
    'B': 'Overconfident',
    'C': 'Underconfident',
    'D': 'Unreliable',
}
_SCENARIO_COLORS = {
    'A': ('#16a34a', '#dcfce7'),
    'B': ('#d97706', '#fef3c7'),
    'C': ('#2563eb', '#dbeafe'),
    'D': ('#dc2626', '#fee2e2'),
}


def _error_panel(headline: str, detail: str) -> str:
    """Return the red error box; ``detail`` is HTML-escaped, ``headline`` is not."""
    import html

    return (
        '<div style="padding:20px;color:#991b1b;background:#fee2e2;border-radius:4px;">'
        f'<b>{headline}</b> {html.escape(detail)}'
        '</div>'
    )


def _legend_chip(color: str, text: str) -> str:
    """Return one legend entry: a coloured dot followed by ``text``."""
    # A white dot would vanish on the white legend background.
    border = 'border:1px solid #cbd5e1;' if color.upper() == '#FFFFFF' else ''
    return (
        f'<span style="display:inline-flex;align-items:center;gap:4px;'
        f'margin:0 8px 4px 0;font-family:\'IBM Plex Sans\',sans-serif;'
        f'font-size:0.78em;color:#334155;">'
        f'<span style="display:inline-block;width:11px;height:11px;'
        f'border-radius:50%;background:{color};{border}"></span>{text}'
        f'</span>'
    )


def _add_void_surface(view, structure) -> None:
    """Add the translucent void isosurface to ``view``; do nothing on any failure."""
    import logging

    try:
        verts_cart, faces = _compute_void_mesh(structure)
    except Exception:
        return
    if verts_cart is None or faces is None or len(faces) == 0:
        return
    try:
        view.addCustom({
            'vertexArr': [
                {'x': float(v[0]), 'y': float(v[1]), 'z': float(v[2])}
                for v in verts_cart
            ],
            'faceArr': faces.flatten().tolist(),
            'color': '#fbbf24',
            'opacity': 0.55,
        })
    except Exception:
        # The surface is decorative: keep rendering, but leave a trace.
        logging.getLogger(__name__).debug(
            'void surface could not be added to the viewer', exc_info=True
        )


def _add_attribution_beads(
    view,
    pore_positions: np.ndarray,
    pore_radii: np.ndarray,
    per_pore_attrs: np.ndarray,
    max_beads: int = 12,
    min_fraction: float = 0.10,
) -> int:
    """Draw cavity attribution beads on ``view`` and return how many were drawn.

    Pore vertices are clustered with ``_cluster_pore_vertices``; clusters are
    ranked by |attribution| and at most ``max_beads`` of them, each at least
    ``min_fraction`` of the strongest, get an opaque bead (orange for
    attribution >= 0, blue otherwise) with a radius between 0.5 and 0.9.
    Nothing is drawn when the three pore arrays differ in length.
    """
    n_pores = len(per_pore_attrs)
    if n_pores == 0:
        return 0
    # Mismatched arrays cannot be clustered reliably; skip the beads.
    if len(pore_positions) != n_pores or len(pore_radii) != n_pores:
        return 0

    clusters = _cluster_pore_vertices(pore_positions, pore_radii, per_pore_attrs)
    if not clusters:
        return 0

    cluster_abs = np.array([abs(c['attr']) for c in clusters])
    cluster_abs_max = float(cluster_abs.max()) if cluster_abs.max() > 0 else 1e-8
    threshold = min_fraction * cluster_abs_max
    ranked = sorted(clusters, key=lambda c: abs(c['attr']), reverse=True)

    shown = 0
    for cluster in ranked[:max_beads]:
        magnitude = abs(cluster['attr'])
        if magnitude < threshold:
            break  # ranked descending, so every later cluster is weaker
        centre = cluster['center']
        bead_radius = 0.5 + 0.4 * (magnitude / cluster_abs_max)
        view.addSphere({
            'center': {'x': float(centre[0]),
                       'y': float(centre[1]),
                       'z': float(centre[2])},
            'radius': float(bead_radius),
            'color': '#ea580c' if cluster['attr'] >= 0 else '#1d4ed8',
            'opacity': 1.0,
        })
        # Values that would print as +0.00 / -0.00 get no label.
        if magnitude >= 0.005:
            view.addLabel(
                f'{cluster["attr"]:+.2f}',
                {
                    'position': {
                        'x': float(centre[0]) + float(bead_radius) + 0.4,
                        'y': float(centre[1]),
                        'z': float(centre[2]),
                    },
                    'backgroundColor': '#0f172a',
                    'backgroundOpacity': 0.6,
                    'fontColor': 'white',
                    'fontSize': 10,
                    'borderThickness': 0,
                    'showBackground': True,
                    'inFront': True,
                },
            )
        shown += 1
    return shown


def _add_unit_cell(view, structure) -> None:
    """Draw the 12 unit-cell edges and the a/b/c axis labels on ``view``."""
    cell_lat = _standard_cif_lattice(structure.lattice.abc,
                                     structure.lattice.angles)
    corners_frac = np.array([
        [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
        [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1],
    ], dtype=float)
    corners_cart = corners_frac @ cell_lat

    # Two corners share an edge when exactly one fractional coordinate differs.
    for i_e in range(8):
        for j_e in range(i_e + 1, 8):
            if abs(corners_frac[j_e] - corners_frac[i_e]).sum() != 1.0:
                continue
            view.addLine({
                'start': {'x': float(corners_cart[i_e, 0]),
                          'y': float(corners_cart[i_e, 1]),
                          'z': float(corners_cart[i_e, 2])},
                'end': {'x': float(corners_cart[j_e, 0]),
                        'y': float(corners_cart[j_e, 1]),
                        'z': float(corners_cart[j_e, 2])},
                'color': '#e2e8f0',
                'linewidth': 2.0,
            })

    # Corners 1, 2 and 3 are the tips of the a, b and c vectors.
    for tip_idx, label_text in [(1, 'a'), (2, 'b'), (3, 'c')]:
        tip = corners_cart[tip_idx]
        view.addLabel(label_text, {
            'position': {'x': float(tip[0]),
                         'y': float(tip[1]),
                         'z': float(tip[2])},
            'fontColor': 'white',
            'fontSize': 14,
            'backgroundColor': 'black',
            'backgroundOpacity': 0.55,
            'borderThickness': 0,
            'inFront': True,
        })


def _hover_callbacks(abs_max: float):
    """Return the (hover, unhover) JavaScript callbacks for atom tooltips.

    The hover callback converts the atom's B-factor back to a signed
    attribution using ``abs_max`` and shows it with the element, serial and
    coordinates.
    """
    hover_callback = (
        'function(atom, viewer) {'
        ' if(atom.label) return;'
        f' var absMax = {abs_max:.6f};'
        ' var attr = ((atom.b - 50.0) / 49.0) * absMax;'
        ' var sign = attr >= 0 ? "+" : "";'
        ' var line1 = atom.elem + " #" + atom.serial;'
        ' var line2 = "Attribution: " + sign + attr.toFixed(4);'
        ' var line3 = "Coords: (" + atom.x.toFixed(2) + ", " + atom.y.toFixed(2) + ", " + atom.z.toFixed(2) + ") A";'
        ' atom.label = viewer.addLabel(line1 + "\\n" + line2 + "\\n" + line3, {'
        ' position: atom, backgroundColor: "#0f172a", fontColor: "#ffffff",'
        ' fontSize: 11, padding: 5, borderRadius: 3, alignment: "topLeft"'
        ' });'
        '}'
    )
    unhover_callback = (
        'function(atom, viewer) {'
        ' if(atom.label) { viewer.removeLabel(atom.label); delete atom.label; }'
        '}'
    )
    return hover_callback, unhover_callback


def _build_header_html(
    property_name: str,
    prediction_value: Optional[float],
    scenario: Optional[str],
) -> str:
    """Return the info bar above the viewer (property, prediction, scenario badge)."""
    header_parts = [f'XAI — <b>{property_name}</b>']
    if prediction_value is not None:
        header_parts.append(f'pred = {prediction_value:.3g}')
    if scenario:
        sc_label = _SCENARIO_LABELS.get(scenario, '')
        fg, bg = _SCENARIO_COLORS.get(scenario, ('#334155', '#f1f5f9'))
        header_parts.append(
            f'<span style="background:{bg};color:{fg};padding:2px 8px;'
            f'border-radius:3px;font-size:0.88em;">Scenario {scenario}: {sc_label}</span>'
        )

    # Entities rather than literal spaces: HTML collapses runs of plain
    # spaces, which would squeeze the separator.
    separator = '&nbsp;&nbsp;&bull;&nbsp;&nbsp;'
    return (
        '<div style="font-family:\'IBM Plex Mono\',monospace;font-size:0.80em;'
        'color:#334155;padding:8px 12px;background:#f8fafc;border-left:3px solid #0891b2;'
        'margin-bottom:4px;line-height:1.7;">'
        + separator.join(header_parts)
        + '<span style="float:right;color:#94a3b8;">'
        'Hover atoms for details • Drag to rotate • Scroll to zoom'
        '</span></div>'
    )


def _build_legend_html(elements_present, show_cavity_chips: bool) -> str:
    """Return the legend below the viewer: one chip per element, plus bead colours."""
    element_chips = [
        _legend_chip(
            _METAL_VIEWER_COLOR if elem in _METAL_ELEMENTS
            else _JMOL_COLORS.get(elem, '#888888'),
            elem,
        )
        for elem in sorted(elements_present)
    ]

    cavity_chips_html = ''
    if show_cavity_chips:
        cavity_chips_html = (
            '<span style="font-family:\'IBM Plex Sans\',sans-serif;font-size:0.72em;'
            'font-weight:500;color:#94a3b8;text-transform:uppercase;letter-spacing:0.05em;'
            'margin:0 10px 0 18px;">Cavity attribution:</span>'
            + _legend_chip('#ea580c', '+ contribution')
            + _legend_chip('#1d4ed8', '− contribution')
        )

    return (
        '<div style="padding:6px 12px;background:#ffffff;border:1px solid #e2e8f0;'
        'border-top:none;border-radius:0 0 4px 4px;margin-top:-4px;line-height:1.6;">'
        '<span style="font-family:\'IBM Plex Sans\',sans-serif;font-size:0.72em;'
        'font-weight:500;color:#94a3b8;text-transform:uppercase;letter-spacing:0.05em;'
        'margin-right:10px;">Elements:</span>'
        + ''.join(element_chips)
        + cavity_chips_html
        + '</div>'
    )


def create_3d_visualization(
    structure,
    per_atom_attrs: np.ndarray,
    per_pore_attrs: np.ndarray,
    pore_positions: np.ndarray,
    pore_radii: np.ndarray,
    property_name: str,
    prediction_value: Optional[float] = None,
    scenario: Optional[str] = None,
    width: int = 800,
    height: int = 600,
) -> str:
    """
    Render the MOF structure in 3Dmol.js with chemistry-style aesthetics.

    Per-atom attributions are written into the CIF B-factor column
    (1 = most negative, 50 = neutral, 99 = most positive) and surface in the
    hover tooltip; atom colours themselves are element colours.

    Rendering:
        Metal atoms:     spheres of radius 0.85 and sticks of radius 0.16 in
                         one uniform colour
        Non-metal atoms: ball-and-stick with the Jmol colour scheme
                         (sphere 0.28 / stick 0.12; hydrogens 0.14 / 0.06)
        Void mesh:       translucent yellow isosurface (opacity 0.55) from
                         _compute_void_mesh; omitted when unavailable
        Attribution:     opaque orange/blue beads at cavity centres,
                         radius 0.5-0.9 scaled by |attr|, at most 12,
                         each at least 10% of the strongest cluster
        Unit cell:       light wireframe with a/b/c labels

    Args:
        structure:        pymatgen Structure
        per_atom_attrs:   np.ndarray [N_atoms], signed attributions
        per_pore_attrs:   np.ndarray [N_pores], signed attributions
        pore_positions:   np.ndarray [N_pores, 3] Cartesian Angstroms; passed
                          to the viewer unchanged
        pore_radii:       np.ndarray [N_pores] Angstroms
        property_name:    Property string for the info header; inserted into
                          the HTML as given
        prediction_value: Optional float; shown in header if given
        scenario:         Optional scenario letter ('A'/'B'/'C'/'D')
        width:            Canvas width in pixels (default 800)
        height:           Canvas height in pixels (default 600)

    Returns:
        HTML string (header + viewer iframe + legend) ready for gr.HTML(),
        or a red error box when py3Dmol or pymatgen is missing or the CIF
        cannot be generated.
    """
    if not PY3DMOL_OK:
        return _error_panel('py3Dmol not installed.', 'Run: pip install py3Dmol')
    if not PYMATGEN_OK:
        return _error_panel('pymatgen not installed.', 'Run: pip install pymatgen')

    per_atom_attrs = np.asarray(per_atom_attrs, dtype=float)
    per_pore_attrs = np.asarray(per_pore_attrs, dtype=float)
    pore_positions = np.asarray(pore_positions, dtype=float)
    pore_radii = np.asarray(pore_radii, dtype=float)

    try:
        cif_str = _cif_string_with_bfactors(structure, per_atom_attrs)
    except Exception as exc:
        return _error_panel('CIF generation failed:', str(exc))

    view = py3Dmol.view(width=width, height=height)
    view.addModel(cif_str, 'cif')

    elements_present = set(str(site.specie) for site in structure)
    metals_present = sorted(elements_present & _METAL_ELEMENTS)

    # Later setStyle calls override earlier ones for the atoms they select:
    # default ball-and-stick first, then metals, then hydrogens.
    view.setStyle(
        {},
        {
            'sphere': {'radius': 0.28, 'colorscheme': 'Jmol'},
            'stick': {'radius': 0.12, 'colorscheme': 'Jmol'},
        },
    )
    if metals_present:
        view.setStyle(
            {'elem': metals_present},
            {
                'sphere': {'radius': 0.85, 'color': _METAL_VIEWER_COLOR},
                'stick': {'radius': 0.16, 'color': _METAL_VIEWER_COLOR},
            },
        )
    if 'H' in elements_present:
        view.setStyle(
            {'elem': 'H'},
            {
                'sphere': {'radius': 0.14, 'colorscheme': 'Jmol'},
                'stick': {'radius': 0.06, 'colorscheme': 'Jmol'},
            },
        )

    _add_void_surface(view, structure)
    n_attribution_beads = _add_attribution_beads(
        view, pore_positions, pore_radii, per_pore_attrs
    )
    _add_unit_cell(view, structure)

    # Scale used by the tooltip to turn a B-factor back into an attribution.
    abs_max = float(np.abs(per_atom_attrs).max()) if len(per_atom_attrs) > 0 else 1.0
    if abs_max < 1e-12:
        abs_max = 1.0
    hover_callback, unhover_callback = _hover_callbacks(abs_max)
    view.setHoverable({}, True, hover_callback, unhover_callback)

    view.setBackgroundColor('black')
    view.zoomTo()

    viewer_html = _inline_3dmol_library(view._make_html(), width=width, height=height)
    header_html = _build_header_html(property_name, prediction_value, scenario)
    legend_html = _build_legend_html(elements_present, n_attribution_beads > 0)

    return header_html + viewer_html + legend_html
|
|
|
|
| |
| |
| |
|
|
def export_attribution_csv(
    structure,
    per_atom_attrs: np.ndarray,
    pore_positions: np.ndarray,
    pore_radii: np.ndarray,
    per_pore_attrs: np.ndarray,
    property_name: str,
    output_path: str,
) -> str:
    """
    Write per-atom and per-pore signed attribution to a CSV for downstream
    analysis (e.g. ranking sites, identifying design hotspots, or feeding into
    follow-up MD/DFT calculations on high-attribution motifs).

    The file starts with three '#' comment lines (title, generator, abs_max)
    followed by the column header and one row per atom, then one per pore.

    Columns:
        kind         'atom' or 'pore'
        idx          sequential index within kind
        element      chemical symbol (atoms only; 'pore' for pore vertices)
        x, y, z      Cartesian coordinates (Angstroms)
        radius       inscribed-sphere radius (pores only; blank for atoms)
        attribution  signed attribution to <property_name>
        b_factor     attribution mapped to the 1-99 scale
                     (1=most negative, 50=neutral, 99=most positive).
                     The scale is the largest |attribution| over atoms AND
                     pores together, so these values can differ from the
                     B-factors in the exported CIF, which is scaled by the
                     atoms alone.

    Array arguments may be any array-like (lists are accepted). Atoms without
    a matching attribution get 0.0; pores without a matching radius get 0.0.
    Pore rows are written only when pore_positions and per_pore_attrs have
    the same length. The parent directory of output_path is created if
    needed.

    Returns:
        output_path, unchanged.
    """
    # Coerce once so plain lists work with the 2-D indexing below.
    per_atom_attrs = np.asarray(per_atom_attrs, dtype=float)
    per_pore_attrs = np.asarray(per_pore_attrs, dtype=float)
    pore_positions = np.asarray(pore_positions, dtype=float)
    pore_radii = np.asarray(pore_radii, dtype=float)

    abs_max = float(np.abs(per_atom_attrs).max()) if len(per_atom_attrs) > 0 else 1.0
    if len(per_pore_attrs) > 0:
        abs_max = max(abs_max, float(np.abs(per_pore_attrs).max()))
    if abs_max < 1e-12:
        # All-zero attributions: avoid dividing by ~0.
        abs_max = 1.0

    def _to_b_factor(attr: float) -> float:
        return max(1.0, min(99.0, 50.0 + 49.0 * (attr / abs_max)))

    rows = [
        f'# PoreGCN per-atom and per-pore attribution to {property_name}',
        '# Generated by huggingface.co/spaces/catenate/PoreGCN',
        f'# abs_max = {abs_max:.6f}',
        'kind,idx,element,x,y,z,radius,attribution,b_factor',
    ]

    n_atom_attrs = len(per_atom_attrs)
    for i, site in enumerate(structure):
        attr = float(per_atom_attrs[i]) if i < n_atom_attrs else 0.0
        rows.append(
            f'atom,{i},{site.specie},{site.coords[0]:.4f},'
            f'{site.coords[1]:.4f},{site.coords[2]:.4f},,'
            f'{attr:.6f},{_to_b_factor(attr):.2f}'
        )

    n_pores = len(per_pore_attrs)
    if n_pores > 0 and len(pore_positions) == n_pores:
        n_radii = len(pore_radii)
        for j in range(n_pores):
            attr = float(per_pore_attrs[j])
            r = float(pore_radii[j]) if j < n_radii else 0.0
            rows.append(
                f'pore,{j},pore,{pore_positions[j, 0]:.4f},'
                f'{pore_positions[j, 1]:.4f},{pore_positions[j, 2]:.4f},'
                f'{r:.3f},{attr:.6f},{_to_b_factor(attr):.2f}'
            )

    # Same convenience as export_iraspa_cif: make sure the target folder exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(rows))
        f.write('\n')
    return output_path
|
|
|
|
| |
| |
| |
|
|
def export_iraspa_cif(
    structure,
    per_atom_attrs: np.ndarray,
    output_path: str,
) -> str:
    """
    Save a CIF whose _atom_site_B_iso_or_equiv column encodes XAI attributions.

    B-factor mapping (symmetric around 50):
        1  = most negative attribution (atom drives property DOWN strongly)
        50 = neutral (attribution ~ 0)
        99 = most positive attribution (atom drives property UP strongly)

    iRASPA usage:
        1. File > Open > <output_path>
        2. Appearance > Atoms > Color by > Temperature Factor
        3. Apply a diverging blue-white-red colormap

    The CIF text comes from _cif_string_with_bfactors(), the same routine
    create_3d_visualization uses. The destination's parent directory is
    created when it does not exist.

    Args:
        structure:      pymatgen Structure
        per_atom_attrs: np.ndarray [N_atoms] signed attributions
        output_path:    Destination .cif file path

    Returns:
        The destination path as a string, for a gr.File() download component.

    Raises:
        ImportError: when pymatgen is not installed.
    """
    if not PYMATGEN_OK:
        raise ImportError('pymatgen is required: pip install pymatgen')

    attributions = np.asarray(per_atom_attrs, dtype=float)
    cif_body = _cif_string_with_bfactors(structure, attributions)

    destination = str(output_path)
    # A bare filename has no directory part; fall back to the current one.
    parent_dir = os.path.dirname(destination) or '.'
    os.makedirs(parent_dir, exist_ok=True)
    Path(destination).write_text(cif_body, encoding='utf-8')

    return destination
|
|