"""
EM Embedded - Geometry Module

Contains geometry preview builders and hole computation functions.
"""

import logging

import numpy as np
import plotly.graph_objects as go

from .state import state, ctrl
from .globals import DEFAULT_AXIS_TICKS, EXCITATION_SURFACE_COLORSCALE

__all__ = [
    "nearest_gridline",
    "compute_hole_edges",
    "build_geometry_placeholder",
    "build_square_domain_plot",
    "push_geometry_plot",
    "update_geometry_preview",
    "update_geometry_hole_preview",
]

logger = logging.getLogger(__name__)

# Absolute tolerance for grid-coordinate comparisons. Coordinates from
# ``np.linspace`` and edges from ``nearest_gridline`` describe the same grid
# line but are computed differently, so they may disagree in the last bits.
_GRID_TOL = 1e-9


def nearest_gridline(val: float, nx: int) -> float:
    """Snap a value to the nearest gridline.

    Grid lines sit at ``k / (nx - 1)``; ``nx`` below 2 is treated as a single
    interval, i.e. grid lines at 0 and 1. Exact halves follow Python's
    ``round`` (round-half-to-even on the grid index).
    """
    denom = float(max(int(nx) - 1, 1))
    return round(float(val) * denom) / denom


def _strictly_inside(xL: float, xR: float, yB: float, yT: float) -> bool:
    """Return True when the edges form a non-empty box strictly inside (0, 1)^2."""
    return 0.0 < xL < xR < 1.0 and 0.0 < yB < yT < 1.0


def _widen_collapsed(lo: float, hi: float, step: float):
    """Widen a span that snapped onto a single grid line by one cell, staying inside (0, 1)."""
    if lo >= hi:
        if lo - step > 0.0:
            lo -= step
        elif hi + step < 1.0:
            hi += step
    return lo, hi


def _snap_hole_edges(nx: int, xL: float, xR: float, yB: float, yT: float):
    """Snap hole edges to grid lines; return None if the result is degenerate or touches the boundary."""
    step = 1.0 / float(max(nx - 1, 1))
    xL_s, xR_s = nearest_gridline(xL, nx), nearest_gridline(xR, nx)
    yB_s, yT_s = nearest_gridline(yB, nx), nearest_gridline(yT, nx)

    # A hole narrower than about one cell can collapse onto one grid line;
    # make a minimal one-cell adjustment on whichever side stays interior.
    xL_s, xR_s = _widen_collapsed(xL_s, xR_s, step)
    yB_s, yT_s = _widen_collapsed(yB_s, yT_s, step)

    # Snapping can also move an edge onto the domain boundary (e.g. xL=0.2
    # with nx=3 snaps to 0.0), so the strict-interior rule is re-applied to
    # the snapped edges, not only to the raw ones.
    if not _strictly_inside(xL_s, xR_s, yB_s, yT_s):
        return None
    return (xL_s, xR_s, yB_s, yT_s)


def _on_gridlines(nx: int, *values: float) -> bool:
    """Return True when every value lies (within tolerance) on a grid line k/(nx-1)."""
    denom = float(max(nx - 1, 1))
    return all(abs(v * denom - round(v * denom)) < _GRID_TOL for v in values)


def compute_hole_edges(nx: int, cx: float, cy: float, a: float, snap: bool = True):
    """
    Compute square hole edges (xL, xR, yB, yT) in [0,1].

    Args:
        nx: points per direction; grid lines at k/(nx-1).
        cx, cy: center in (0,1).
        a: edge length; the hole must fit strictly inside the unit square.
        snap: if True, snap edges to nearest grid lines; else require exact alignment.

    Returns:
        tuple (xL, xR, yB, yT) with 0 < xL < xR < 1 and 0 < yB < yT < 1,
        or None when the input is non-numeric, invalid, out-of-bounds
        (before or after snapping), degenerate, or misaligned.
    """
    try:
        nx = int(nx)
        cx = float(cx)
        cy = float(cy)
        a = float(a)
    except (TypeError, ValueError, OverflowError):
        return None

    # Written as ``not (a > 0)`` so that NaN is rejected as well.
    if not (a > 0.0):
        return None

    half = a / 2.0
    xL, xR = cx - half, cx + half
    yB, yT = cy - half, cy + half

    # Must be strictly inside domain to allow removing interior cells safely
    if not _strictly_inside(xL, xR, yB, yT):
        return None

    if snap:
        return _snap_hole_edges(nx, xL, xR, yB, yT)

    # Without snapping, the edges must already lie on grid lines.
    if _on_gridlines(nx, xL, xR, yB, yT):
        return (xL, xR, yB, yT)
    return None


def build_geometry_placeholder(message: str) -> go.Figure:
    """Build a placeholder figure with a message."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=18, color="#5F259F"),
    )
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    fig.update_layout(
        template="plotly_white",
        margin=dict(l=20, r=20, t=40, b=20),
        paper_bgcolor="#ffffff",
        plot_bgcolor="#ffffff",
        height=460,
        showlegend=False,
    )
    return fig


def build_square_domain_plot(
    nx: int,
    title: str,
    hole_edges=None,
    *,
    show_edges: bool = True,
    dense_grid: bool = False,
) -> go.Figure:
    """Build a 3D square domain plot with optional hole.

    Args:
        nx: points per direction (clamped to at least 3).
        title: figure title.
        hole_edges: optional ``(xL, xR, yB, yT)`` tuple (e.g. from
            ``compute_hole_edges``). Grid nodes inside the closed box are set
            to NaN, and the box outline is drawn in white at z=0.
        show_edges: draw reference grid lines slightly below the surface.
        dense_grid: if True, draw one line per grid node; otherwise one line
            per ``DEFAULT_AXIS_TICKS`` value.

    Returns:
        The assembled plotly figure.
    """
    nx = max(int(nx), 3)
    grid = np.linspace(0.0, 1.0, nx)
    X, Y = np.meshgrid(grid, grid, indexing="xy")
    Z = np.zeros_like(X, dtype=float)
    color_field = np.full_like(Z, 0.85)

    if hole_edges is not None:
        xL, xR, yB, yT = hole_edges
        # The tolerance keeps nodes that sit exactly on a snapped edge inside
        # the mask even when linspace rounding puts them a few ULPs outside.
        mask = (
            (X >= xL - _GRID_TOL)
            & (X <= xR + _GRID_TOL)
            & (Y >= yB - _GRID_TOL)
            & (Y <= yT + _GRID_TOL)
        )
        Z = np.where(mask, np.nan, Z)
        color_field = np.where(mask, np.nan, color_field)

    fig = go.Figure()
    fig.add_trace(
        go.Surface(
            x=X,
            y=Y,
            z=Z,
            surfacecolor=np.where(np.isnan(color_field), 0.15, color_field),
            colorscale=EXCITATION_SURFACE_COLORSCALE,
            cmin=0.0,
            cmax=1.0,
            showscale=False,
            opacity=0.98,
            lighting=dict(ambient=0.85, diffuse=0.55, specular=0.1),
            hovertemplate="x=%{x:.3f}<br>y=%{y:.3f}",
        )
    )

    if show_edges:
        base_z = -0.012
        # ``grid`` is already linspace(0, 1, nx) with nx >= 3.
        grid_vals = grid if dense_grid else np.asarray(DEFAULT_AXIS_TICKS)
        lo, hi = float(grid[0]), float(grid[-1])
        line_x, line_y, line_z = [], [], []
        # NaN entries split one Scatter3d polyline into separate segments.
        # First the lines of constant x, then the lines of constant y.
        for val in grid_vals:
            v = float(val)
            line_x.extend([v, v, np.nan])
            line_y.extend([lo, hi, np.nan])
            line_z.extend([base_z, base_z, np.nan])
        for val in grid_vals:
            v = float(val)
            line_x.extend([lo, hi, np.nan])
            line_y.extend([v, v, np.nan])
            line_z.extend([base_z, base_z, np.nan])
        fig.add_trace(
            go.Scatter3d(
                x=line_x,
                y=line_y,
                z=line_z,
                mode="lines",
                line=dict(color="rgba(174,139,216,0.65)", width=1.6),
                showlegend=False,
                hoverinfo="skip",
            )
        )

    scale_ticks = list(DEFAULT_AXIS_TICKS)
    tick_text = [f"{t:.2f}" for t in scale_ticks]
    tick_plane = -0.02
    # Tick labels along the x edge, then along the y edge.
    fig.add_trace(
        go.Scatter3d(
            x=scale_ticks,
            y=[-0.018] * len(scale_ticks),
            z=[tick_plane] * len(scale_ticks),
            mode="text",
            text=tick_text,
            textfont=dict(color="#5F259F", size=12),
            showlegend=False,
            hoverinfo="skip",
        )
    )
    fig.add_trace(
        go.Scatter3d(
            x=[-0.018] * len(scale_ticks),
            y=scale_ticks,
            z=[tick_plane] * len(scale_ticks),
            mode="text",
            text=tick_text,
            textfont=dict(color="#5F259F", size=12),
            showlegend=False,
            hoverinfo="skip",
        )
    )

    if hole_edges is not None:
        xL, xR, yB, yT = hole_edges
        fig.add_trace(
            go.Scatter3d(
                x=[xL, xR, xR, xL, xL],
                y=[yB, yB, yT, yT, yB],
                z=[0.0] * 5,
                mode="lines",
                line=dict(color="#FFFFFF", width=5),
                hoverinfo="skip",
                showlegend=False,
            )
        )

    fig.update_layout(
        title=title,
        margin=dict(l=8, r=8, t=44, b=8),
        height=620,
        template="plotly_white",
        scene=dict(
            xaxis=dict(range=[-0.05, 1.05], visible=False, backgroundcolor="#f7f3ff"),
            yaxis=dict(range=[-0.05, 1.05], visible=False, backgroundcolor="#f7f3ff"),
            # Symmetric around z=0 so the surface (z=0), grid (z=-0.012) and
            # tick labels (z=-0.02) all fall inside the visible range.
            zaxis=dict(range=[-0.1, 0.1], visible=False, backgroundcolor="#f7f3ff"),
            aspectmode="cube",
            camera=dict(eye=dict(x=1.25, y=1.25, z=0.85)),
        ),
        dragmode="orbit",
        uirevision="geometry_surface",
    )
    return fig


def push_geometry_plot(fig: go.Figure):
    """Push a geometry plot to the UI via ``ctrl.geometry_preview_update``, if present.

    Best effort: a failure is logged rather than raised, so a broken preview
    does not abort the caller.
    """
    try:
        if hasattr(ctrl, "geometry_preview_update"):
            ctrl.geometry_preview_update(fig)
    except Exception:
        logger.warning("Failed to push geometry preview to the UI", exc_info=True)


def _state_or_default(value, default):
    """Return ``default`` when a state value is unset (None or a blank string).

    Unlike ``value or default``, legitimate falsy values such as ``0.0`` or
    ``False`` are kept, so invalid zero input is reported instead of hidden.
    """
    if value is None or (isinstance(value, str) and not value.strip()):
        return default
    return value


def _hole_edges_from_state(nx: int):
    """Compute hole edges from the ``hole_*`` state values; None if invalid."""
    # NOTE(review): getattr's default may never apply if the state object
    # returns None for unset keys, so None is also treated as "use default".
    snap = _state_or_default(getattr(state, "hole_snap", True), True)
    # Raw values are passed through: compute_hole_edges performs the numeric
    # conversion and returns None (rather than raising) on non-numeric input.
    return compute_hole_edges(
        nx,
        _state_or_default(state.hole_center_x, 0.5),
        _state_or_default(state.hole_center_y, 0.5),
        _state_or_default(state.hole_size_edge, 0.2),
        snap=snap,
    )


def _square_metallic_body_figure(nx: int) -> go.Figure:
    """Build the hole preview figure and publish the hole validation message to state."""
    edges = _hole_edges_from_state(nx)
    if edges is None:
        state.hole_error_message = "Invalid hole configuration (edges out of bounds or degenerate)"
    else:
        state.hole_error_message = ""
    return build_square_domain_plot(nx, "Square Metallic Body Preview", edges, show_edges=True)


def update_geometry_preview():
    """Update the geometry preview based on current state."""
    if not state.bound:
        return

    geo = state.geometry_selection
    if geo in (None, "None"):
        fig = build_geometry_placeholder("Select a geometry to preview.")
    elif geo == "Square Domain":
        nx = int(state.nx or 16)
        fig = build_square_domain_plot(nx, "Square Domain Preview", None, show_edges=True)
    elif geo == "Square Metallic Body":
        # Also refreshes hole_error_message, matching update_geometry_hole_preview.
        fig = _square_metallic_body_figure(int(state.nx or 16))
    else:
        fig = build_geometry_placeholder(f"Geometry: {geo}")
    push_geometry_plot(fig)


def update_geometry_hole_preview():
    """Update geometry preview with current hole settings (Square Metallic Body only)."""
    if not state.bound:
        return
    if state.geometry_selection != "Square Metallic Body":
        return
    push_geometry_plot(_square_metallic_body_figure(int(state.nx or 16)))