Spaces:
Configuration error
Configuration error
| import os | |
| import numpy as np | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from skimage.filters import frangi | |
| from skimage import measure | |
| import pydicom | |
def generate_synthetic_dental_volume(shape=(128, 192, 192), seed=None):
    """Build a synthetic CBCT-like volume containing a row of eight "teeth".

    The volume is Gaussian noise (mean -700, std 60). Eight elliptical columns
    of value 1200 (tooth) are written into axial slices
    [int(0.2*depth), int(0.85*depth)). Each column has a smaller central ellipse
    of value -250 (canal). The result is clipped to [-1000, 2000].

    Args:
        shape: (depth, height, width); the volume is indexed as vol[z, y, x].
        seed: Optional seed for the noise. ``None`` (default) draws from NumPy's
            global random state, as before; an int gives a reproducible volume.

    Returns:
        (vol, spacing): ``vol`` is a float32 array of ``shape`` and ``spacing``
        is the voxel size (sx, sy, sz) = (1.0, 1.0, 1.0).
    """
    depth, height, width = shape
    if seed is None:
        noise = np.random.normal(loc=-700.0, scale=60.0, size=shape)
    else:
        noise = np.random.default_rng(seed).normal(loc=-700.0, scale=60.0, size=shape)
    vol = noise.astype(np.float32)

    y_grid, x_grid = np.ogrid[:height, :width]
    # Radii are identical for every tooth. Clamp them to >= 1 so small volumes do
    # not divide by zero (which produced warnings and empty teeth).
    rx = max(1, int(width * 0.03))
    ry = max(1, int(height * 0.05))
    canal_rx = max(2, int(rx * 0.25))
    canal_ry = max(2, int(ry * 0.25))
    y_center = int(height * 0.65)
    z0, z1 = int(depth * 0.2), int(depth * 0.85)
    bone_val = 1200.0
    canal_val = -250.0

    for xc_f in np.linspace(int(width * 0.15), int(width * 0.85), 8):
        xc = int(xc_f)
        tooth = ((y_grid - y_center) / float(ry)) ** 2 + ((x_grid - xc) / float(rx)) ** 2 <= 1.0
        canal = ((y_grid - y_center) / float(canal_ry)) ** 2 + ((x_grid - xc) / float(canal_rx)) ** 2 <= 1.0
        # Write the canal after the tooth so it carves the lumen out of it.
        vol[z0:z1, tooth] = bone_val
        vol[z0:z1, canal] = canal_val

    vol = np.clip(vol, -1000.0, 2000.0)
    spacing = (1.0, 1.0, 1.0)  # (sx, sy, sz)
    return vol, spacing
def build_mesh_figure(volume: np.ndarray, threshold: float, spacing, step_size: int = 2):
    """Return a Plotly figure showing the iso-surface of ``volume`` at ``threshold``.

    Args:
        volume: 3-D array indexed as volume[z, y, x].
        threshold: Iso-value passed to ``measure.marching_cubes`` as ``level``.
        spacing: Voxel size as (sx, sy, sz); vertex coordinates are scaled by it.
        step_size: Marching-cubes step in voxels (default 2). Larger values give
            a coarser surface in less time.

    Returns:
        A ``go.Figure`` containing one ``Mesh3d`` trace. If marching cubes rejects
        the input (e.g. threshold outside the data range or no surface found), an
        empty figure is returned instead of raising.
    """
    try:
        verts, faces, _normals, _values = measure.marching_cubes(
            volume, level=threshold, step_size=step_size)
    except (ValueError, RuntimeError):
        # Only the "invalid/empty level" failures are a deliberate fallback; other
        # errors (bad spacing, programming bugs) are no longer silently hidden.
        fig = go.Figure()
        fig.update_layout(scene=dict(aspectmode='data'))
        return fig

    # Vertices come back in array-index order (z, y, x): reorder to (x, y, z)
    # and scale by the voxel size.
    sx, sy, sz = spacing
    i, j, k = faces.T
    mesh = go.Mesh3d(x=verts[:, 2] * sx, y=verts[:, 1] * sy, z=verts[:, 0] * sz,
                     i=i, j=j, k=k,
                     color='lightyellow', opacity=0.65, flatshading=False)
    fig = go.Figure(data=[mesh])
    fig.update_layout(scene=dict(aspectmode='data'))
    return fig
def axial_slice_image(volume: np.ndarray, z_idx: int, points=None, level=400.0, width=1500.0):
    """Render one axial slice as an RGB uint8 image, optionally marking points.

    Args:
        volume: 3-D array indexed as volume[z, y, x].
        z_idx: Slice to render; clamped into [0, depth - 1].
        points: Optional iterable of (x, y, z) triples. Only points whose
            ``int(z)`` equals the rendered slice are drawn, as a small cyan cross.
        level: Centre of the display window.
        width: Width of the display window. Values are clipped to
            [level - width/2, level + width/2] and mapped linearly onto 0..255.

    Returns:
        Array of shape (rows, cols, 3), dtype uint8.
    """
    z_idx = int(np.clip(z_idx, 0, volume.shape[0] - 1))
    lo = level - width / 2.0
    hi = level + width / 2.0
    gray = (np.clip(volume[z_idx], lo, hi) - lo) / (hi - lo + 1e-6)
    rgb = (np.repeat(gray[..., np.newaxis], 3, axis=-1) * 255).astype(np.uint8)

    if not points:
        return rgb

    n_rows, n_cols = rgb.shape[0], rgb.shape[1]
    arm = 2  # half-length of each cross arm, in pixels
    cyan = [0, 255, 255]
    for px, py, pz in points:
        if int(pz) != int(z_idx):
            continue
        col = int(np.clip(px, 0, n_cols - 1))
        row = int(np.clip(py, 0, n_rows - 1))
        rgb[max(0, row - arm):row + arm + 1, col:col + 1] = cyan
        rgb[row:row + 1, max(0, col - arm):col + arm + 1] = cyan
    return rgb
def detect_root_canals_fast_axial(volume: np.ndarray, bone_threshold=200.0, downsample=3,
                                  top_n=40, center_index=None, slice_range=30, slice_step=3):
    """Heuristically locate root-canal candidates on a stack of axial slices.

    Each sampled slice goes through these steps:
      1. Crop it to the padded in-plane bounding box of voxels brighter than
         ``bone_threshold``.
      2. Contrast-stretch it (5th to 99.5th percentile).
      3. Stride it by ``downsample``.
      4. Filter it with a 2-D Frangi filter looking for *dark* ridges, since
         canals are darker than the surrounding tooth.
    The strongest responses of every slice are then pooled and ranked.

    Args:
        volume: 3-D array indexed as volume[z, y, x].
        bone_threshold: Intensity above which voxels count as bone; their y/x
            bounding box (plus a margin) limits the search area.
        downsample: In-plane stride applied before filtering (values < 1 act as 1).
        top_n: Maximum number of candidates returned.
        center_index: Central slice; defaults to the middle slice and is
            clamped into the volume.
        slice_range: Number of slices searched on each side of the centre.
        slice_step: Stride between searched slices. The centre slice is always
            included so its candidates can be displayed.

    Returns:
        Up to ``top_n`` (x, y, z) int tuples in full-resolution voxel indices,
        strongest Frangi response first.
    """
    vol = volume.astype(np.float32, copy=False)
    depth, height, width = vol.shape
    ds = max(1, int(downsample))
    top_n = int(top_n)

    # In-plane crop: bone bounding box over all slices, padded by a margin.
    # Axis projections avoid materialising the coordinates of every bone voxel.
    bone = vol > bone_threshold
    bone_rows = np.flatnonzero(bone.any(axis=(0, 2)))
    bone_cols = np.flatnonzero(bone.any(axis=(0, 1)))
    if bone_rows.size:
        margin = 4 * ds
        ymin = max(0, int(bone_rows[0]) - margin)
        ymax = min(height - 1, int(bone_rows[-1]) + margin)
        xmin = max(0, int(bone_cols[0]) - margin)
        xmax = min(width - 1, int(bone_cols[-1]) + margin)
    else:
        ymin, ymax, xmin, xmax = 0, height - 1, 0, width - 1

    if center_index is None:
        center = depth // 2
    else:
        center = int(np.clip(int(center_index), 0, depth - 1))
    reach = max(0, int(slice_range))
    step = max(1, int(slice_step))
    z_lo = max(0, center - reach)
    z_hi = min(depth - 1, center + reach)
    # Step outwards from the centre in both directions so the centre slice (the
    # one shown after detection) is always sampled, even when z_lo is clamped
    # or the range is not a multiple of the step.
    zs = list(range(center, z_lo - 1, -step))[::-1] + list(range(center + step, z_hi + 1, step))

    per_slice = max(1, top_n // len(zs))
    sigmas = np.array([0.6, 1.2])
    scored = []
    for z in zs:
        sl = vol[z, ymin:ymax + 1, xmin:xmax + 1]
        p5, p995 = np.percentile(sl, [5, 99.5])
        if p995 <= p5:
            p5, p995 = float(sl.min()), float(sl.max())
        norm = (np.clip(sl, p5, p995) - p5) / (p995 - p5 + 1e-6)
        # Canals are dark inside bright dentine, so filter the non-inverted slice
        # for black ridges. Inverting first AND using black_ridges=True (as the
        # earlier version did) selects bright structures instead.
        resp = frangi(norm[::ds, ::ds], sigmas=sigmas, alpha=0.5, beta=0.5, gamma=15,
                      black_ridges=True)
        flat = resp.ravel()
        k = min(per_slice, flat.size)  # argpartition requires k <= size
        if k == 0:
            continue
        for idx in np.argpartition(flat, -k)[-k:]:
            r, c = divmod(int(idx), resp.shape[1])
            point = (int(xmin + c * ds), int(ymin + r * ds), int(z))
            scored.append((float(flat[idx]), point))

    # Keep the strongest responses rather than an arbitrary subset.
    scored.sort(key=lambda item: item[0], reverse=True)
    return [point for _, point in scored[:top_n]]
def load_dicom_series(files):
    """Read uploaded DICOM slices into a (num_slices, rows, cols) float32 volume.

    Args:
        files: Iterable of uploads, or None (treated as empty). Each item may be
            a path (``str`` / ``os.PathLike``) or an object exposing the path as
            ``.name``. Unreadable files, or files without pixel data, are skipped.

    Returns:
        (vol, (sx, sy, sz)): ``vol`` holds the rescaled pixel values
        (``pixel * RescaleSlope + RescaleIntercept``) with slices ordered by
        SliceLocation, else InstanceNumber. Spacing comes from PixelSpacing
        (row, column) and from SpacingBetweenSlices or SliceThickness; any value
        that is missing defaults to 1.0.

    Raises:
        ValueError: if no valid slice was found, or slices differ in size.
    """
    datasets = []
    for f in files or []:
        try:
            path = f if isinstance(f, (str, os.PathLike)) else f.name
            ds = pydicom.dcmread(path, force=True)
            if hasattr(ds, 'pixel_array'):
                datasets.append(ds)
        except Exception:
            # Best-effort ingest: skip anything that is not a readable DICOM image.
            continue
    if not datasets:
        raise ValueError('No valid DICOM slices uploaded')

    def _order_key(ds):
        # Prefer SliceLocation, then InstanceNumber, then 0. Fall back per slice
        # so one malformed tag does not abandon sorting the whole series.
        for attr, conv in (('SliceLocation', float), ('InstanceNumber', int)):
            if hasattr(ds, attr):
                try:
                    return float(conv(getattr(ds, attr)))
                except (TypeError, ValueError):
                    pass
        return 0.0

    datasets.sort(key=_order_key)

    rows = int(datasets[0].Rows)
    cols = int(datasets[0].Columns)
    vol = np.zeros((len(datasets), rows, cols), dtype=np.float32)
    for i, ds in enumerate(datasets):
        arr = np.asarray(ds.pixel_array, dtype=np.float32)
        if arr.shape != (rows, cols):
            raise ValueError(
                f'Slice {i} has pixel shape {arr.shape}, expected {(rows, cols)}; '
                'only single-frame grayscale slices of equal size are supported')
        slope = float(getattr(ds, 'RescaleSlope', 1.0))
        intercept = float(getattr(ds, 'RescaleIntercept', 0.0))
        vol[i] = arr * slope + intercept

    # Spacing (sx, sy, sz): PixelSpacing is (row spacing, column spacing) = (sy, sx).
    sx = sy = sz = 1.0
    ds0 = datasets[0]
    if hasattr(ds0, 'PixelSpacing') and len(ds0.PixelSpacing) >= 2:
        sy = float(ds0.PixelSpacing[0])
        sx = float(ds0.PixelSpacing[1])
    for attr in ('SpacingBetweenSlices', 'SliceThickness'):
        if hasattr(ds0, attr):
            # Only the first tag that is present is consulted, as before; an
            # unparsable value leaves sz at 1.0.
            try:
                sz = float(getattr(ds0, attr))
            except (TypeError, ValueError):
                pass
            break
    return vol, (sx, sy, sz)
# Gradio app
with gr.Blocks(title="Dental AI - Hugging Face Space") as demo:
    gr.Markdown("""
# Dental AI Demo (Ethical, Heuristic)
- Generate a synthetic CBCT-like volume and visualize in 3D and 2D.
- Run a fast heuristic root canal candidate detector (Frangi) — no fake ML.
- Or upload a DICOM series (multiple files) to visualize.
""")
    # State: the current volume, its voxel spacing (sx, sy, sz), and the latest
    # detected candidates as (x, y, z) voxel indices.
    vol_state = gr.State(None)
    spacing_state = gr.State((1.0, 1.0, 1.0))
    points_state = gr.State([])
    with gr.Row():
        gen_btn = gr.Button("🧪 Generate Synthetic Volume", variant="primary")
        files = gr.File(label="Upload DICOM slices (multiple)", file_count="multiple")
    with gr.Row():
        threshold = gr.Slider(0, 3000, value=200, step=10, label="Bone Threshold (HU)")
        slice_idx = gr.Slider(0, 1, value=0, step=1, label="Axial Slice")
    with gr.Row():
        detect_btn = gr.Button("🦷 Detect Root Canals (Fast 2.5D)")
        downsample = gr.Slider(1, 8, value=3, step=1, label="Downsample")
        topn = gr.Slider(5, 200, value=40, step=5, label="Top N Candidates")
        sl_range = gr.Slider(5, 120, value=30, step=5, label="Slice Range")
        sl_step = gr.Slider(1, 10, value=3, step=1, label="Slice Step")
    with gr.Row():
        fig3d = gr.Plot(label="3D Surface")
        img2d = gr.Image(label="Axial Slice", type="numpy")
    info = gr.Markdown(visible=True)

    def _slider_for_depth(depth):
        """Slider update spanning every slice of a volume, positioned mid-volume."""
        return gr.update(minimum=0, maximum=depth - 1, value=depth // 2, step=1)

    def _render_slice(vol, z, points):
        """Axial slice with the app's fixed display window (level 400, width 1500)."""
        return axial_slice_image(vol, int(z), points, level=400.0, width=1500.0)

    def _update_view(vol, spacing, thr, z, points):
        """Rebuild the 3-D surface and the axial slice; no-op updates if no volume."""
        if vol is None:
            return gr.update(), gr.update()
        return build_mesh_figure(vol, thr, spacing), _render_slice(vol, z, points)

    def on_generate(thr):
        vol, spacing = generate_synthetic_dental_volume()
        depth = int(vol.shape[0])
        points = []
        fig, img = _update_view(vol, spacing, thr, depth // 2, points)
        msg = f"Generated synthetic volume: {vol.shape} with spacing {spacing}"
        return vol, spacing, points, _slider_for_depth(depth), fig, img, msg

    gen_btn.click(on_generate, inputs=[threshold],
                  outputs=[vol_state, spacing_state, points_state, slice_idx, fig3d, img2d, info])

    def on_files(thr, files_list):
        # Must return exactly one value per output (7). On failure, or when the
        # upload is cleared, leave state, slider and views untouched so they stay
        # consistent with each other; only the info message changes.
        unchanged = tuple(gr.update() for _ in range(6))
        if not files_list:
            return (*unchanged, "No files uploaded")
        try:
            vol, spacing = load_dicom_series(files_list)
            depth = int(vol.shape[0])
            points = []
            fig, img = _update_view(vol, spacing, thr, depth // 2, points)
        except Exception as e:
            return (*unchanged, f"❌ Error: {e}")
        msg = f"Loaded DICOM series: {vol.shape} with spacing {spacing} (showing middle slice)"
        return vol, spacing, points, _slider_for_depth(depth), fig, img, msg

    files.change(on_files, inputs=[threshold, files],
                 outputs=[vol_state, spacing_state, points_state, slice_idx, fig3d, img2d, info])

    def on_view_change(vol, spacing, thr, z, points):
        return _update_view(vol, spacing, thr, z, points)

    def on_slice_change(vol, z, points):
        # The 3-D surface does not depend on the slice index, so only redraw the
        # 2-D image instead of re-running marching cubes on every slider move.
        if vol is None:
            return gr.update()
        return _render_slice(vol, z, points)

    threshold.release(on_view_change, inputs=[vol_state, spacing_state, threshold, slice_idx, points_state], outputs=[fig3d, img2d])
    slice_idx.release(on_slice_change, inputs=[vol_state, slice_idx, points_state], outputs=[img2d])

    def on_detect(vol, spacing, thr, z, ds, tn, rge, stp):
        if vol is None:
            return gr.update(), gr.update(), []
        points = detect_root_canals_fast_axial(vol, bone_threshold=thr, downsample=ds,
                                               top_n=int(tn), center_index=int(z),
                                               slice_range=int(rge), slice_step=int(stp))
        fig, img = _update_view(vol, spacing, thr, z, points)
        return fig, img, points

    detect_btn.click(on_detect,
                     inputs=[vol_state, spacing_state, threshold, slice_idx, downsample, topn, sl_range, sl_step],
                     outputs=[fig3d, img2d, points_state])
if __name__ == "__main__":
    # Listen on all interfaces; the port is read from $PORT (default 7860).
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", "7860")),
    )