# quantum/em/geometry.py
# Committed by: harishaseebat92
# Refactor: Migrate to modular embedded architecture from quantum_embedded
# Commit: 7f9a25d
"""
EM Embedded - Geometry Module
Contains geometry preview builders and hole computation functions.
"""
import logging

import numpy as np
import plotly.graph_objects as go

from .state import state, ctrl
from .globals import DEFAULT_AXIS_TICKS, EXCITATION_SURFACE_COLORSCALE
# Public API of the geometry module: grid snapping, hole computation,
# preview figure builders, and the state-driven preview refresh handlers.
__all__ = [
    "nearest_gridline",
    "compute_hole_edges",
    "build_geometry_placeholder",
    "build_square_domain_plot",
    "push_geometry_plot",
    "update_geometry_preview",
    "update_geometry_hole_preview",
]
def nearest_gridline(val: float, nx: int) -> float:
    """Return the gridline value k/(nx-1) closest to *val*.

    Rounding uses Python's built-in ``round`` (ties go to the even index).
    ``nx`` below 2 is treated as 2, i.e. unit spacing. *val* is not clamped
    to [0, 1].
    """
    spacing_count = float(max(int(nx) - 1, 1))
    index = round(float(val) * spacing_count)
    return index / spacing_count
def _snap_hole_span(lo: float, hi: float, n_cells: int):
    """Snap an interior interval [lo, hi] to gridline indices in 0..n_cells.

    Returns ``(i_lo, i_hi)`` with ``0 < i_lo < i_hi < n_cells``, or ``None``
    when no such pair can be obtained from the rounded edges.
    """
    i_lo = round(lo * n_cells)
    i_hi = round(hi * n_cells)
    if i_lo >= i_hi:
        # Both edges rounded onto the same gridline: widen by one cell,
        # preferring to move the low edge down, otherwise the high edge up.
        if i_lo - 1 > 0:
            i_lo -= 1
        elif i_hi + 1 < n_cells:
            i_hi += 1
        else:
            return None
    # Rounding can also land an edge on the domain boundary (index 0 or
    # n_cells); reject that so snapped edges obey the same strict-interior
    # rule that is applied to the unsnapped edges.
    if i_lo <= 0 or i_hi >= n_cells:
        return None
    return i_lo, i_hi


def compute_hole_edges(nx: int, cx: float, cy: float, a: float, snap: bool = True):
    """
    Compute square hole edges (xL, xR, yB, yT) in [0,1].

    Args:
        nx: points per direction; grid lines at k/(nx-1). Values below 2 are
            treated as 2, which leaves no interior gridline, so snapping fails.
        cx, cy: center in (0,1).
        a: edge length in (0,1].
        snap: if True, snap edges to nearest grid lines; else require exact
            alignment (within 1e-9 grid units).

    Returns:
        tuple (xL, xR, yB, yT), or None when the inputs are not numeric, the
        edge length is not positive (or NaN), the hole is not strictly inside
        the unit square, or the edges cannot be snapped/aligned to interior
        grid lines. Snapped edges are exact values k/(nx-1) with 0 < k < nx-1.
    """
    try:
        nx = int(nx)
        cx = float(cx)
        cy = float(cy)
        a = float(a)
    except (TypeError, ValueError, OverflowError):
        return None
    # `not (a > 0)` also rejects NaN.
    if not (a > 0.0):
        return None
    half = a / 2.0
    xL, xR = cx - half, cx + half
    yB, yT = cy - half, cy + half
    # Must be strictly inside domain to allow removing interior cells safely
    if not (0.0 < xL < xR < 1.0 and 0.0 < yB < yT < 1.0):
        return None

    n_cells = max(nx - 1, 1)
    if not snap:
        # Require edges to already lie on grid lines
        tol = 1e-9
        if all(abs(v * n_cells - round(v * n_cells)) < tol for v in (xL, xR, yB, yT)):
            return (xL, xR, yB, yT)
        return None

    # Snap on integer gridline indices so results are exact k/(nx-1) values.
    x_span = _snap_hole_span(xL, xR, n_cells)
    y_span = _snap_hole_span(yB, yT, n_cells)
    if x_span is None or y_span is None:
        return None
    denom = float(n_cells)
    return (x_span[0] / denom, x_span[1] / denom, y_span[0] / denom, y_span[1] / denom)
def build_geometry_placeholder(message: str) -> go.Figure:
    """Return an axis-free white figure showing *message* at its centre.

    The preview handlers use it when there is nothing to draw: no geometry
    selected, or a geometry without a dedicated builder.
    """
    annotation = dict(
        text=message,
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=18, color="#5F259F"),
    )
    layout = dict(
        template="plotly_white",
        margin=dict(l=20, r=20, t=40, b=20),
        paper_bgcolor="#ffffff",
        plot_bgcolor="#ffffff",
        height=460,
        showlegend=False,
    )
    placeholder = go.Figure()
    placeholder.add_annotation(**annotation)
    placeholder.update_xaxes(visible=False)
    placeholder.update_yaxes(visible=False)
    placeholder.update_layout(**layout)
    return placeholder
def build_square_domain_plot(
    nx: int,
    title: str,
    hole_edges=None,
    *,
    show_edges: bool = True,
    dense_grid: bool = False,
) -> go.Figure:
    """Build a 3D preview of the unit square [0, 1] x [0, 1], optionally with a square hole.

    The square is a flat ``go.Surface`` at z = 0, sampled on an ``nx`` x ``nx``
    grid. Grid nodes on or inside ``hole_edges`` get NaN heights so the surface
    is cut out there, and a white outline is drawn around the hole.

    Args:
        nx: Points per direction; values below 3 are raised to 3.
        title: Figure title.
        hole_edges: Optional ``(xL, xR, yB, yT)`` tuple, e.g. the result of
            ``compute_hole_edges``. ``None`` draws the full square.
        show_edges: If True, draw guide lines slightly below the surface plus
            tick labels along the x and y edges.
        dense_grid: If True, draw a guide line at every grid node; otherwise
            only at ``DEFAULT_AXIS_TICKS``.

    Returns:
        The assembled ``go.Figure``.
    """
    nx = max(int(nx), 3)
    grid = np.linspace(0.0, 1.0, nx)
    X, Y = np.meshgrid(grid, grid, indexing="xy")
    Z = np.zeros_like(X, dtype=float)
    color_field = np.full_like(Z, 0.85)
    if hole_edges is not None:
        xL, xR, yB, yT = hole_edges
        # Blank out every grid node on or inside the hole boundary.
        mask = (X >= xL) & (X <= xR) & (Y >= yB) & (Y <= yT)
        Z = np.where(mask, np.nan, Z)
        color_field = np.where(mask, np.nan, color_field)

    fig = go.Figure()
    fig.add_trace(
        go.Surface(
            x=X,
            y=Y,
            z=Z,
            # NaN colours are replaced with 0.15; those nodes also have NaN z.
            surfacecolor=np.where(np.isnan(color_field), 0.15, color_field),
            colorscale=EXCITATION_SURFACE_COLORSCALE,
            cmin=0.0,
            cmax=1.0,
            showscale=False,
            opacity=0.98,
            lighting=dict(ambient=0.85, diffuse=0.55, specular=0.1),
            hovertemplate="x=%{x:.3f}<br>y=%{y:.3f}<extra></extra>",
        )
    )

    if show_edges:
        base_z = -0.012  # guide lines sit just below the surface (z = 0)
        # `grid` already equals linspace(0, 1, nx) with nx >= 3.
        grid_vals = grid if dense_grid else np.asarray(DEFAULT_AXIS_TICKS)
        line_x, line_y, line_z = [], [], []
        lo, hi = float(grid[0]), float(grid[-1])
        # One Scatter3d trace for all lines: NaN entries split it into separate
        # segments. First lines of constant x, then lines of constant y.
        for val in grid_vals:
            val_f = float(val)
            line_x.extend([val_f, val_f, np.nan])
            line_y.extend([lo, hi, np.nan])
            line_z.extend([base_z, base_z, np.nan])
        for val in grid_vals:
            val_f = float(val)
            line_x.extend([lo, hi, np.nan])
            line_y.extend([val_f, val_f, np.nan])
            line_z.extend([base_z, base_z, np.nan])
        fig.add_trace(
            go.Scatter3d(
                x=line_x,
                y=line_y,
                z=line_z,
                mode="lines",
                line=dict(color="rgba(174,139,216,0.65)", width=1.6),
                showlegend=False,
                hoverinfo="skip",
            )
        )

        scale_ticks = list(DEFAULT_AXIS_TICKS)
        tick_text = [f"{t:.2f}" for t in scale_ticks]
        tick_plane = -0.02
        # Tick labels along the x edge (y slightly negative)...
        fig.add_trace(
            go.Scatter3d(
                x=scale_ticks,
                y=[-0.018] * len(scale_ticks),
                z=[tick_plane] * len(scale_ticks),
                mode="text",
                text=tick_text,
                textfont=dict(color="#5F259F", size=12),
                showlegend=False,
                hoverinfo="skip",
            )
        )
        # ...and along the y edge (x slightly negative).
        fig.add_trace(
            go.Scatter3d(
                x=[-0.018] * len(scale_ticks),
                y=scale_ticks,
                z=[tick_plane] * len(scale_ticks),
                mode="text",
                text=tick_text,
                textfont=dict(color="#5F259F", size=12),
                showlegend=False,
                hoverinfo="skip",
            )
        )

    if hole_edges is not None:
        xL, xR, yB, yT = hole_edges
        # Closed white outline of the hole at surface height.
        fig.add_trace(
            go.Scatter3d(
                x=[xL, xR, xR, xL, xL],
                y=[yB, yB, yT, yT, yB],
                z=[0.0] * 5,
                mode="lines",
                line=dict(color="#FFFFFF", width=5),
                hoverinfo="skip",
                showlegend=False,
            )
        )

    fig.update_layout(
        title=title,
        margin=dict(l=8, r=8, t=44, b=8),
        height=620,
        template="plotly_white",
        scene=dict(
            xaxis=dict(range=[-0.05, 1.05], visible=False, backgroundcolor="#f7f3ff"),
            yaxis=dict(range=[-0.05, 1.05], visible=False, backgroundcolor="#f7f3ff"),
            # The z range must span the surface (z = 0) and the guide lines and
            # labels just below it (z = -0.012, -0.02). The former [0.1, 0.1]
            # had zero width and excluded all of them.
            zaxis=dict(range=[-0.1, 0.1], visible=False, backgroundcolor="#f7f3ff"),
            aspectmode="cube",
            camera=dict(eye=dict(x=1.25, y=1.25, z=0.85)),
        ),
        dragmode="orbit",
        # Keep the user's camera/zoom across re-renders of this preview.
        uirevision="geometry_surface",
    )
    return fig
def push_geometry_plot(fig: go.Figure):
    """Send *fig* to the UI through ``ctrl.geometry_preview_update``, if present.

    Best-effort: nothing happens when the controller has no such hook. If the
    hook raises, the error is logged with its traceback instead of propagating,
    so a failed preview refresh does not break the handler that triggered it.
    """
    try:
        update = getattr(ctrl, "geometry_preview_update", None)
        if update is not None:
            update(fig)
    except Exception:
        # Previously swallowed silently; keep it non-fatal but leave a trace.
        logging.getLogger(__name__).exception("Failed to push geometry preview to the UI")
def update_geometry_preview():
    """Rebuild the preview for ``state.geometry_selection`` and push it to the UI.

    Does nothing while ``state.bound`` is falsy. No selection (``None`` or the
    string ``"None"``) and unknown geometries get a text placeholder.
    ``"Square Domain"`` gets the full square. ``"Square Metallic Body"`` is
    handed to ``update_geometry_hole_preview``.
    """
    if not state.bound:
        return
    geo = state.geometry_selection
    if geo == "Square Metallic Body":
        # Share one code path with the hole-parameter handler. The hole is then
        # computed the same way, and state.hole_error_message is refreshed when
        # switching to this geometry (it was previously left stale).
        update_geometry_hole_preview()
        return
    if geo in (None, "None"):
        fig = build_geometry_placeholder("Select a geometry to preview.")
    elif geo == "Square Domain":
        nx = int(state.nx or 16)
        fig = build_square_domain_plot(nx, "Square Domain Preview", None, show_edges=True)
    else:
        fig = build_geometry_placeholder(f"Geometry: {geo}")
    push_geometry_plot(fig)
def update_geometry_hole_preview():
    """Recompute the hole from state, update the error message, and push the preview.

    Only acts when ``state.bound`` is truthy and the selection is
    ``"Square Metallic Body"``. Sets ``state.hole_error_message`` to an error
    text when ``compute_hole_edges`` rejects the configuration, or to ``""``
    otherwise. The preview is pushed in both cases; an invalid hole is drawn
    as the uncut square.
    """
    if not state.bound:
        return
    if state.geometry_selection != "Square Metallic Body":
        return

    def _or_default(value, default):
        # Substitute the default only for genuinely missing input. The previous
        # `value or default` also replaced a legitimate 0 / 0.0, which turned an
        # invalid hole (e.g. edge length 0) into a valid-looking default one.
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    nx = int(state.nx or 16)
    # compute_hole_edges converts and validates these itself, so malformed
    # values surface as the error message instead of raising here.
    edges = compute_hole_edges(
        nx,
        _or_default(state.hole_center_x, 0.5),
        _or_default(state.hole_center_y, 0.5),
        _or_default(state.hole_size_edge, 0.2),
        snap=getattr(state, "hole_snap", True),
    )
    if edges is None:
        state.hole_error_message = "Invalid hole configuration (edges out of bounds or degenerate)"
    else:
        state.hole_error_message = ""
    fig = build_square_domain_plot(nx, "Square Metallic Body Preview", edges, show_edges=True)
    push_geometry_plot(fig)