"""Dental AI demo: synthetic CBCT-like volumes, DICOM loading, 3D/2D views and a
heuristic (Frangi-based) root-canal candidate detector served with Gradio."""

import os

import numpy as np
import gradio as gr
import plotly.graph_objects as go
from skimage.filters import frangi
from skimage import measure
import pydicom


def generate_synthetic_dental_volume(shape=(128, 192, 192), seed=None):
    """Build a synthetic CBCT-like volume containing a row of teeth with canals.

    The volume is indexed ``(z, y, x)``. It holds a noisy background around
    -700, eight elliptical "teeth" at 1200 spanning slices ``[0.2*depth,
    0.85*depth)``, and a small -250 "canal" at the centre of each tooth.

    Args:
        shape: ``(depth, height, width)`` of the generated volume.
        seed: Optional seed for the background noise so a volume can be
            reproduced; ``None`` (default) gives a fresh random volume.

    Returns:
        ``(volume, spacing)``: a float32 array clipped to ``[-1000, 2000]``
        and the voxel spacing ``(sx, sy, sz)`` (all 1.0).
    """
    depth, height, width = shape
    rng = np.random.default_rng(seed)
    vol = rng.normal(loc=-700.0, scale=60.0, size=shape).astype(np.float32)

    y_grid, x_grid = np.ogrid[:height, :width]
    n_teeth = 8
    y_center = int(height * 0.65)
    z0, z1 = int(depth * 0.2), int(depth * 0.85)
    # Guard against zero radii (and a division by zero) on small volumes.
    rx = max(1, int(width * 0.03))
    ry = max(1, int(height * 0.05))
    canal_rx = max(2, int(rx * 0.25))
    canal_ry = max(2, int(ry * 0.25))
    bone_val = 1200.0
    canal_val = -250.0

    for xc in np.linspace(int(width * 0.15), int(width * 0.85), n_teeth):
        xc = int(xc)
        tooth = ((y_grid - y_center) / float(ry)) ** 2 + ((x_grid - xc) / float(rx)) ** 2 <= 1.0
        canal = ((y_grid - y_center) / float(canal_ry)) ** 2 + ((x_grid - xc) / float(canal_rx)) ** 2 <= 1.0
        # Paint the same cross-section on every slice in [z0, z1); the canal is
        # painted after the tooth so it stays dark inside the bone.
        vol[z0:z1, tooth] = bone_val
        vol[z0:z1, canal] = canal_val

    vol = np.clip(vol, -1000.0, 2000.0)
    spacing = (1.0, 1.0, 1.0)  # (sx, sy, sz)
    return vol, spacing


def build_mesh_figure(volume: np.ndarray, threshold: float, spacing, points=None, step_size=2):
    """Render the iso-surface of ``volume`` at ``threshold`` as a Plotly mesh.

    Args:
        volume: 3D array indexed ``(z, y, x)``.
        threshold: Iso-value passed to marching cubes.
        spacing: ``(sx, sy, sz)`` voxel size used to scale the mesh.
        points: Optional iterable of ``(x, y, z)`` voxel indices drawn as
            markers (e.g. detector candidates).
        step_size: Marching-cubes step; larger is faster but coarser.

    Returns:
        A ``go.Figure``. When no surface can be extracted at ``threshold``
        (e.g. it lies outside the data range), the figure has no mesh.
    """
    sx, sy, sz = spacing
    fig = go.Figure()
    try:
        verts, faces, _normals, _values = measure.marching_cubes(
            volume, level=threshold, step_size=step_size)
    except (ValueError, RuntimeError):
        # Threshold outside the data range / no surface found: show an empty
        # scene instead of failing the UI callback.
        verts = faces = None

    if verts is not None:
        # marching_cubes returns vertices as (z, y, x); reorder and scale by spacing.
        i, j, k = faces.T
        fig.add_trace(go.Mesh3d(
            x=verts[:, 2] * sx, y=verts[:, 1] * sy, z=verts[:, 0] * sz,
            i=i, j=j, k=k, color='lightyellow', opacity=0.65, flatshading=False))

    if points:
        px, py, pz = (np.asarray(c, dtype=float) for c in zip(*points))
        fig.add_trace(go.Scatter3d(
            x=px * sx, y=py * sy, z=pz * sz, mode='markers',
            marker=dict(size=3, color='cyan'), name='Canal candidates'))

    fig.update_layout(scene=dict(aspectmode='data'))
    return fig


def axial_slice_image(volume: np.ndarray, z_idx: int, points=None, level=400.0, width=1500.0):
    """Render axial slice ``z_idx`` as an RGB uint8 image using a window/level.

    Args:
        volume: 3D array indexed ``(z, y, x)``.
        z_idx: Slice index; clipped to the valid range.
        points: Optional iterable of ``(x, y, z)`` voxel indices; points lying
            on this slice are marked with a small cyan cross.
        level: Window centre.
        width: Window width.

    Returns:
        ``(H, W, 3)`` uint8 array.
    """
    z_idx = int(np.clip(z_idx, 0, volume.shape[0] - 1))
    vmin = level - width / 2.0
    vmax = level + width / 2.0
    sl = np.clip(volume[z_idx], vmin, vmax)
    sl = (sl - vmin) / (vmax - vmin + 1e-6)
    sl_rgb = (np.stack([sl, sl, sl], axis=-1) * 255).astype(np.uint8)

    if points:
        h, w = sl_rgb.shape[:2]
        s = 2  # half-length of the cross arms, in pixels
        for x, y, z in points:
            if int(z) != z_idx:
                continue
            xr = int(np.clip(x, 0, w - 1))
            yr = int(np.clip(y, 0, h - 1))
            sl_rgb[max(0, yr - s):yr + s + 1, xr:xr + 1] = (0, 255, 255)
            sl_rgb[yr:yr + 1, max(0, xr - s):xr + s + 1] = (0, 255, 255)
    return sl_rgb


def _bone_footprint_bounds(vol, bone_threshold, margin):
    """Return inclusive ``((ymin, ymax), (xmin, xmax))`` of voxels above ``bone_threshold``.

    Voxels are projected over all slices, padded by ``margin``, and clipped to
    the volume. The full in-plane extent is returned when no voxel exceeds the
    threshold.
    """
    _, height, width = vol.shape
    footprint = (vol > bone_threshold).any(axis=0)
    rows = np.flatnonzero(footprint.any(axis=1))
    cols = np.flatnonzero(footprint.any(axis=0))
    if rows.size == 0:
        return (0, height - 1), (0, width - 1)
    return ((max(0, int(rows[0]) - margin), min(height - 1, int(rows[-1]) + margin)),
            (max(0, int(cols[0]) - margin), min(width - 1, int(cols[-1]) + margin)))


def _normalize_slice(sl):
    """Rescale a 2D slice to [0, 1] using its 5th/99.5th percentiles (min/max if degenerate)."""
    lo, hi = np.percentile(sl, [5, 99.5])
    if hi <= lo:
        lo, hi = float(sl.min()), float(sl.max())
    return (np.clip(sl, lo, hi) - lo) / (hi - lo + 1e-6)


def detect_root_canals_fast_axial(volume: np.ndarray, bone_threshold=200.0, downsample=3, top_n=40,
                                  center_index=None, slice_range=30, slice_step=3):
    """Find root-canal candidates with a 2.5D Frangi heuristic on axial slices.

    Slices within ``slice_range`` of ``center_index`` are sampled every
    ``slice_step`` (always including the centre slice). Each sampled slice is:
    cropped to the in-plane bone footprint, contrast-normalised, and strided by
    ``downsample``. Each slice is then filtered with Frangi tuned for *dark*
    tubular structures, because canals are darker than the surrounding bone.
    The strongest responses are kept.

    Args:
        volume: 3D array indexed ``(z, y, x)``.
        bone_threshold: Values above this define the bone footprint used for cropping.
        downsample: In-plane stride applied before filtering (>= 1).
        top_n: Maximum number of candidates returned.
        center_index: Central slice; defaults to the middle slice, clipped to range.
        slice_range: Half-width (in slices) of the analysed slab.
        slice_step: Distance between analysed slices (>= 1).

    Returns:
        Up to ``top_n`` unique ``(x, y, z)`` voxel-index tuples, strongest response first.
    """
    top_n = int(top_n)
    if top_n <= 0:
        return []
    vol = volume.astype(np.float32, copy=False)
    depth = vol.shape[0]
    ds = int(max(1, downsample))
    step = int(max(1, slice_step))
    half_range = max(0, int(slice_range))

    (ymin, ymax), (xmin, xmax) = _bone_footprint_bounds(vol, bone_threshold, margin=4 * ds)

    center = depth // 2 if center_index is None else int(np.clip(int(center_index), 0, depth - 1))
    start = max(0, center - half_range)
    end = min(depth - 1, center + half_range)
    # Sample outward from the centre so the centre slice (the one the UI shows)
    # is always analysed, whatever the range/step combination.
    zs = sorted(set(range(center, start - 1, -step)) | set(range(center, end + 1, step)))
    per_slice = max(1, top_n // len(zs))

    scored = []  # (response, (x, y, z))
    for z in zs:
        sl = _normalize_slice(vol[z, ymin:ymax + 1, xmin:xmax + 1])
        sl_ds = sl[::ds, ::ds]
        # Canals are dark in the normalised slice, so look for black ridges on
        # the non-inverted image (inverting first would target bright bone).
        resp = frangi(sl_ds, sigmas=np.array([0.6, 1.2]), alpha=0.5, beta=0.5, gamma=15,
                      black_ridges=True)
        flat = resp.ravel()
        k = min(per_slice, flat.size)  # argpartition requires k <= size
        if k == 0:
            continue
        for idx in np.argpartition(flat, -k)[-k:]:
            score = float(flat[idx])
            if score <= 0:
                continue  # no tubular response at all: not a candidate
            r, c = divmod(int(idx), resp.shape[1])
            scored.append((score, (int(xmin + c * ds), int(ymin + r * ds), int(z))))

    scored.sort(key=lambda item: item[0], reverse=True)
    unique = list(dict.fromkeys(point for _, point in scored))  # order-preserving dedup
    return unique[:top_n]


def load_dicom_series(files):
    """Load uploaded DICOM slices into a rescaled ``(n, rows, cols)`` float32 volume.

    Args:
        files: Iterable of uploaded files. Each is either a path (``str`` /
            ``os.PathLike``, as returned by Gradio 4) or an object with a
            ``.name`` path attribute (Gradio 3 tempfile wrappers).
            Files that cannot be read or decoded are skipped.

    Returns:
        ``(volume, (sx, sy, sz))``. Slices are ordered by ``SliceLocation``,
        falling back to ``InstanceNumber``. Values are ``pixel * RescaleSlope +
        RescaleIntercept``.

    Raises:
        ValueError: If no slice could be decoded, or the slices are not
            single-frame 2D images of one common shape.
    """
    slices = []  # (dataset, decoded pixel array)
    for f in files or []:
        path = f if isinstance(f, (str, os.PathLike)) else f.name
        try:
            ds = pydicom.dcmread(path, force=True)
            arr = ds.pixel_array
        except Exception:
            # Best effort: skip non-DICOM files and undecodable pixel data.
            continue
        slices.append((ds, arr))
    if not slices:
        raise ValueError('No valid DICOM slices uploaded')

    def _order_key(item):
        ds = item[0]
        if hasattr(ds, 'SliceLocation'):
            return float(ds.SliceLocation)
        if hasattr(ds, 'InstanceNumber'):
            return int(ds.InstanceNumber)
        return 0

    try:
        slices.sort(key=_order_key)
    except (TypeError, ValueError):
        pass  # malformed ordering tags: keep upload order

    first_shape = slices[0][1].shape
    if len(first_shape) != 2:
        raise ValueError(f'Expected single-frame 2D DICOM slices, got pixel data of shape {first_shape}')
    mismatched = sorted({arr.shape for _, arr in slices if arr.shape != first_shape})
    if mismatched:
        raise ValueError(f'DICOM slices have inconsistent shapes: {first_shape} vs {mismatched}')

    vol = np.stack([
        arr.astype(np.float32) * float(getattr(ds, 'RescaleSlope', 1.0))
        + float(getattr(ds, 'RescaleIntercept', 0.0))
        for ds, arr in slices
    ]).astype(np.float32, copy=False)

    # Spacing: PixelSpacing is (row spacing, column spacing) = (sy, sx). For
    # the slice axis, prefer SpacingBetweenSlices and fall back to
    # SliceThickness when the former is missing or invalid.
    ds0 = slices[0][0]
    sx = sy = sz = 1.0
    pixel_spacing = getattr(ds0, 'PixelSpacing', None)
    if pixel_spacing is not None and len(pixel_spacing) >= 2:
        sy = float(pixel_spacing[0])
        sx = float(pixel_spacing[1])
    for tag in ('SpacingBetweenSlices', 'SliceThickness'):
        value = getattr(ds0, tag, None)
        if value is None:
            continue
        try:
            candidate = float(value)
        except (TypeError, ValueError):
            continue
        if candidate > 0:
            sz = candidate
            break
    return vol, (sx, sy, sz)


# Gradio app
with gr.Blocks(title="Dental AI - Hugging Face Space") as demo:
    gr.Markdown(
        "# Dental AI Demo (Ethical, Heuristic)\n"
        "- Generate a synthetic CBCT-like volume and visualize in 3D and 2D.\n"
        "- Run a fast heuristic root canal candidate detector (Frangi) — no fake ML.\n"
        "- Or upload a DICOM series (multiple files) to visualize.\n"
    )
    # Per-session state: the volume, its (sx, sy, sz) spacing and the current
    # detector candidates as (x, y, z) voxel indices.
    vol_state = gr.State(None)
    spacing_state = gr.State((1.0, 1.0, 1.0))
    points_state = gr.State([])

    with gr.Row():
        gen_btn = gr.Button("🧪 Generate Synthetic Volume", variant="primary")
        files = gr.File(label="Upload DICOM slices (multiple)", file_count="multiple")
    with gr.Row():
        threshold = gr.Slider(0, 3000, value=200, step=10, label="Bone Threshold (HU)")
        slice_idx = gr.Slider(0, 1, value=0, step=1, label="Axial Slice")
    with gr.Row():
        detect_btn = gr.Button("🦷 Detect Root Canals (Fast 2.5D)")
        downsample = gr.Slider(1, 8, value=3, step=1, label="Downsample")
        topn = gr.Slider(5, 200, value=40, step=5, label="Top N Candidates")
        sl_range = gr.Slider(5, 120, value=30, step=5, label="Slice Range")
        sl_step = gr.Slider(1, 10, value=3, step=1, label="Slice Step")
    with gr.Row():
        fig3d = gr.Plot(label="3D Surface")
        img2d = gr.Image(label="Axial Slice", type="numpy")
    info = gr.Markdown(visible=True)

    def _update_view(vol, spacing, thr, z, points):
        """Re-render the 3D surface and the axial slice; leave both untouched without a volume."""
        if vol is None:
            return gr.update(), gr.update()
        fig = build_mesh_figure(vol, thr, spacing, points=points)
        img = axial_slice_image(vol, int(z), points, level=400.0, width=1500.0)
        return fig, img

    def _slice_slider_update(depth):
        """Resize the slice slider to a new volume depth and centre it."""
        return gr.update(minimum=0, maximum=depth - 1, value=depth // 2, step=1)

    def _unchanged(n):
        """Return ``n`` independent no-op updates."""
        return tuple(gr.update() for _ in range(n))

    def on_generate(thr):
        vol, spacing = generate_synthetic_dental_volume()
        depth = int(vol.shape[0])
        points = []
        fig, img = _update_view(vol, spacing, thr, depth // 2, points)
        msg = f"Generated synthetic volume: {vol.shape} with spacing {spacing}"
        return vol, spacing, points, _slice_slider_update(depth), fig, img, msg

    gen_btn.click(on_generate, inputs=[threshold],
                  outputs=[vol_state, spacing_state, points_state, slice_idx, fig3d, img2d, info])

    def on_files(thr, files_list):
        # Seven outputs: vol_state, spacing_state, points_state, slice_idx,
        # fig3d, img2d, info. On failure every component except `info` is left
        # unchanged, so the view stays consistent with the stored volume.
        if not files_list:
            return _unchanged(6) + ("No files uploaded",)
        try:
            vol, spacing = load_dicom_series(files_list)
            depth = int(vol.shape[0])
            points = []
            fig, img = _update_view(vol, spacing, thr, depth // 2, points)
        except Exception as e:  # UI boundary: report instead of crashing the callback
            return _unchanged(6) + (f"❌ Error: {e}",)
        msg = f"Loaded DICOM series: {vol.shape} with spacing {spacing} (showing middle slice)"
        return vol, spacing, points, _slice_slider_update(depth), fig, img, msg

    files.change(on_files, inputs=[threshold, files],
                 outputs=[vol_state, spacing_state, points_state, slice_idx, fig3d, img2d, info])

    def on_view_change(vol, spacing, thr, z, points):
        return _update_view(vol, spacing, thr, z, points)

    def on_slice_change(vol, z, points):
        """Redraw only the axial slice; the 3D surface does not depend on the slice index."""
        if vol is None:
            return gr.update()
        return axial_slice_image(vol, int(z), points, level=400.0, width=1500.0)

    threshold.release(on_view_change,
                      inputs=[vol_state, spacing_state, threshold, slice_idx, points_state],
                      outputs=[fig3d, img2d])
    slice_idx.release(on_slice_change, inputs=[vol_state, slice_idx, points_state], outputs=[img2d])

    def on_detect(vol, spacing, thr, z, ds, tn, rge, stp):
        if vol is None:
            return gr.update(), gr.update(), []
        points = detect_root_canals_fast_axial(
            vol, bone_threshold=thr, downsample=int(ds), top_n=int(tn),
            center_index=int(z), slice_range=int(rge), slice_step=int(stp))
        fig, img = _update_view(vol, spacing, thr, z, points)
        return fig, img, points

    detect_btn.click(on_detect,
                     inputs=[vol_state, spacing_state, threshold, slice_idx, downsample, topn, sl_range, sl_step],
                     outputs=[fig3d, img2d, points_state])


if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port)