# Uploaded by SHIKARICHACHA ("Upload 4 files", commit 8e0fabe, verified).
import os
import numpy as np
import gradio as gr
import plotly.graph_objects as go
from skimage.filters import frangi
from skimage import measure
import pydicom
def generate_synthetic_dental_volume(shape=(128, 192, 192), seed=None, n_teeth=8):
    """Build a synthetic CBCT-like volume with a row of teeth, each containing a canal.

    The volume is Gaussian noise around -700 (std 60). ``n_teeth`` elliptical
    "teeth" (value 1200) are placed evenly across the width at 65% of the height.
    Each has a smaller elliptical "canal" (value -250) at its centre. Both shapes
    are extruded from 20% to 85% of the depth axis.

    Args:
        shape: ``(depth, height, width)`` of the output volume.
        seed: Optional seed for a private ``numpy.random.Generator``. When ``None``
            (the default) the global NumPy random state is used, as before.
        n_teeth: Number of teeth laid out across the width.

    Returns:
        ``(volume, spacing)``. ``volume`` is a float32 array of ``shape`` clipped to
        [-1000, 2000]. ``spacing`` is ``(sx, sy, sz) = (1.0, 1.0, 1.0)``.
    """
    depth, height, width = shape
    rng = np.random if seed is None else np.random.default_rng(seed)
    vol = rng.normal(loc=-700.0, scale=60.0, size=shape).astype(np.float32)

    y_grid, x_grid = np.ogrid[:height, :width]
    # Radii are identical for every tooth. Clamp tooth radii to >= 1: a zero radius
    # divides by zero, giving a NaN/inf mask that silently drops the tooth.
    rx = max(1, int(width * 0.03))
    ry = max(1, int(height * 0.05))
    canal_rx = max(2, int(rx * 0.25))
    canal_ry = max(2, int(ry * 0.25))
    y_center = int(height * 0.65)
    z0, z1 = int(depth * 0.2), int(depth * 0.85)

    bone_val = 1200.0
    canal_val = -250.0
    for x_pos in np.linspace(int(width * 0.15), int(width * 0.85), n_teeth):
        xc = int(x_pos)
        tooth = ((y_grid - y_center) / ry) ** 2 + ((x_grid - xc) / rx) ** 2 <= 1.0
        canal = ((y_grid - y_center) / canal_ry) ** 2 + ((x_grid - xc) / canal_rx) ** 2 <= 1.0
        # Bone first, then canal, so the canal overwrites the tooth interior on
        # every slice in [z0, z1).
        vol[z0:z1, tooth] = bone_val
        vol[z0:z1, canal] = canal_val

    vol = np.clip(vol, -1000.0, 2000.0)
    spacing = (1.0, 1.0, 1.0)  # (sx, sy, sz)
    return vol, spacing
def build_mesh_figure(volume: np.ndarray, threshold: float, spacing):
    """Return a Plotly figure showing the iso-surface of ``volume`` at ``threshold``.

    The surface comes from marching cubes with ``step_size=2``. Vertices are
    returned in array order ``(z, y, x)``. They are reordered to x/y/z and scaled
    by ``spacing``, given as ``(sx, sy, sz)``. Building the figure is best-effort:
    any exception (e.g. a threshold outside the data range) yields an empty figure.
    Both outcomes use ``aspectmode='data'``.
    """
    def _data_aspect(fig):
        # Shared layout for both the real and the fallback figure.
        fig.update_layout(scene=dict(aspectmode='data'))
        return fig

    try:
        verts, faces, _normals, _values = measure.marching_cubes(volume, level=threshold, step_size=2)
        sx, sy, sz = spacing
        zz, yy, xx = verts[:, 0], verts[:, 1], verts[:, 2]
        tri_i, tri_j, tri_k = faces.T
        surface = go.Mesh3d(
            x=xx * sx,
            y=yy * sy,
            z=zz * sz,
            i=tri_i,
            j=tri_j,
            k=tri_k,
            color='lightyellow',
            opacity=0.65,
            flatshading=False,
        )
        return _data_aspect(go.Figure(data=[surface]))
    except Exception:
        # Deliberate fallback: an unusable threshold should not break the UI.
        return _data_aspect(go.Figure())
def axial_slice_image(volume: np.ndarray, z_idx: int, points=None, level=400.0, width=1500.0):
    """Render one axial slice of ``volume`` as an 8-bit RGB image.

    ``z_idx`` is clamped into ``[0, depth - 1]``. Intensities are windowed to
    ``[level - width/2, level + width/2]`` and mapped linearly onto 0..255 (grey).
    For every ``(x, y, z)`` in ``points`` whose ``z`` equals the displayed slice,
    a cyan "+" with arms of 2 pixels is drawn at ``(x, y)``. The coordinates are
    clamped into the image first.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.
    """
    z_idx = int(np.clip(z_idx, 0, volume.shape[0] - 1))
    lo = level - width / 2.0
    hi = level + width / 2.0
    windowed = (np.clip(volume[z_idx], lo, hi) - lo) / (hi - lo + 1e-6)
    rgb = (np.dstack((windowed,) * 3) * 255).astype(np.uint8)

    if not points:
        return rgb

    cyan = [0, 255, 255]
    arm = 2
    n_rows, n_cols = rgb.shape[0], rgb.shape[1]
    for px, py, pz in points:
        if int(pz) != z_idx:
            continue
        col = int(np.clip(px, 0, n_cols - 1))
        row = int(np.clip(py, 0, n_rows - 1))
        rgb[max(0, row - arm):row + arm + 1, col:col + 1] = cyan  # vertical stroke
        rgb[row:row + 1, max(0, col - arm):col + arm + 1] = cyan  # horizontal stroke
    return rgb
def _bone_bounds(vol, bone_threshold, margin):
    """Return per-axis inclusive ``(lo, hi)`` bounds of voxels above ``bone_threshold``, padded by ``margin``.

    Falls back to the full extent of every axis when no voxel exceeds the threshold.
    """
    mask = vol > bone_threshold
    if not mask.any():
        return [(0, n - 1) for n in vol.shape]
    bounds = []
    for axis, n in enumerate(vol.shape):
        # Project the mask onto this axis; first/last True index give its extent.
        others = tuple(a for a in range(vol.ndim) if a != axis)
        hits = np.flatnonzero(mask.any(axis=others))
        bounds.append((max(0, int(hits[0]) - margin), min(n - 1, int(hits[-1]) + margin)))
    return bounds


def _normalize_slice(sl):
    """Rescale a 2-D slice to [0, 1] using its 5th/99.5th percentiles (min/max if degenerate)."""
    lo, hi = np.percentile(sl, [5, 99.5])
    if hi <= lo:
        lo, hi = float(sl.min()), float(sl.max())
    return (np.clip(sl, lo, hi) - lo) / (hi - lo + 1e-6)


def detect_root_canals_fast_axial(volume: np.ndarray, bone_threshold=200.0, downsample=3,
                                  top_n=40, center_index=None, slice_range=30, slice_step=3):
    """Find root-canal candidate points with a per-slice (2.5-D) Frangi filter.

    Each axial slice is cropped to the padded bounding box of voxels above
    ``bone_threshold``. The crop is percentile-normalised, inverted so that dark
    (canal-like) structures become bright, downsampled, and filtered with Frangi.
    The strongest responses of every processed slice become candidates.

    Args:
        volume: 3-D array indexed ``[z, y, x]``.
        bone_threshold: Intensity above which voxels count as bone for the crop.
        downsample: In-plane stride applied before filtering (>= 1).
        top_n: Maximum number of candidates to return.
        center_index: Central slice of the window (defaults to the middle slice).
        slice_range: Half-width of the slice window, in slices.
        slice_step: Stride between processed slices.

    Returns:
        Up to ``top_n`` ``(x, y, z)`` integer voxel coordinates, strongest first.
    """
    vol = volume.astype(np.float32, copy=False)
    ds = int(max(1, downsample))
    top_n = int(top_n)
    # Only the in-plane extent of the crop is used; the slice window below
    # decides which z slices are processed.
    _, (ymin, ymax), (xmin, xmax) = _bone_bounds(vol, bone_threshold, 4 * ds)

    depth = vol.shape[0]
    if center_index is None:
        center_index = depth // 2
    start = max(0, int(center_index) - int(slice_range))
    end = min(depth - 1, int(center_index) + int(slice_range))
    zs = range(start, end + 1, int(max(1, slice_step)))
    if top_n <= 0 or len(zs) == 0:
        return []
    # Spread the candidate budget evenly over the slices (at least one each).
    per_slice = max(1, top_n // len(zs))

    scored = []
    for z in zs:
        inv = 1.0 - _normalize_slice(vol[z, ymin:ymax + 1, xmin:xmax + 1])
        inv_ds = inv[::ds, ::ds]
        # Canals are dark in the raw slice and therefore bright in `inv`, so ask
        # Frangi for white ridges. black_ridges=True would undo the inversion and
        # respond to structures that are bright in the raw slice instead.
        resp = frangi(inv_ds, sigmas=np.array([0.6, 1.2]), alpha=0.5, beta=0.5, gamma=15,
                      black_ridges=False)
        flat = np.nan_to_num(resp.ravel(), nan=0.0)
        if flat.size == 0:
            continue
        # argpartition requires k <= size; a small crop after downsampling can hold
        # fewer pixels than the per-slice budget.
        k = min(per_slice, flat.size)
        for idx in np.argpartition(flat, -k)[-k:]:
            r, c = divmod(int(idx), resp.shape[1])
            point = (int(xmin + c * ds), int(ymin + r * ds), int(z))
            scored.append((float(flat[idx]), point))

    # Keep the strongest responses. Candidates are already unique: argpartition
    # indices are distinct within a slice and z differs between slices.
    scored.sort(key=lambda item: item[0], reverse=True)
    return [point for _, point in scored[:top_n]]
def load_dicom_series(files):
    """Read an uploaded DICOM series into a float32 ``(slices, rows, cols)`` volume.

    Every readable file that exposes pixel data becomes one slice. Unreadable or
    non-image files are skipped. Slices are ordered by ``SliceLocation``, falling
    back to ``InstanceNumber``. Pixel values are converted with
    ``RescaleSlope``/``RescaleIntercept`` (defaults 1 and 0).

    Args:
        files: Iterable of uploads, or ``None`` for no files. Each item may be a path
            (``str``/``os.PathLike``, e.g. from ``gr.File(type="filepath")``) or a
            file-like object whose ``.name`` attribute is the path.

    Returns:
        ``(volume, (sx, sy, sz))``. In-plane spacing comes from ``PixelSpacing``.
        Slice spacing comes from ``SpacingBetweenSlices``, else ``SliceThickness``.
        Each spacing defaults to 1.0 when missing or unusable.

    Raises:
        ValueError: if no valid slice was read, or the slices differ in size.
    """
    datasets = []
    for upload in files or []:
        # Paths are used as-is. Other objects (temp-file wrappers) carry the path in
        # `.name`. Check PathLike first because pathlib's `.name` is only the basename.
        if isinstance(upload, (str, os.PathLike)):
            path = upload
        else:
            path = getattr(upload, 'name', upload)
        try:
            ds = pydicom.dcmread(path, force=True)
            if hasattr(ds, 'pixel_array'):
                datasets.append(ds)
        except Exception:
            # Best-effort: one bad file should not abort the whole upload.
            continue
    if not datasets:
        raise ValueError('No valid DICOM slices uploaded')

    def _slice_order(ds):
        if hasattr(ds, 'SliceLocation'):
            return float(ds.SliceLocation)
        if hasattr(ds, 'InstanceNumber'):
            return int(ds.InstanceNumber)
        return 0

    try:
        datasets.sort(key=_slice_order)
    except (TypeError, ValueError):
        # Malformed ordering metadata: keep upload order rather than fail.
        pass

    rows = int(datasets[0].Rows)
    cols = int(datasets[0].Columns)
    vol = np.zeros((len(datasets), rows, cols), dtype=np.float32)
    for i, ds in enumerate(datasets):
        arr = ds.pixel_array.astype(np.float32)
        if arr.shape != (rows, cols):
            raise ValueError(
                f'DICOM slice {i} has shape {arr.shape}, expected {(rows, cols)}; '
                'mixed series and multi-frame files are not supported')
        slope = float(getattr(ds, 'RescaleSlope', 1.0))
        intercept = float(getattr(ds, 'RescaleIntercept', 0.0))
        vol[i] = arr * slope + intercept

    sx = sy = sz = 1.0
    ds0 = datasets[0]
    pixel_spacing = getattr(ds0, 'PixelSpacing', None)
    if pixel_spacing is not None and len(pixel_spacing) >= 2:
        # First value is the row (y) spacing, second the column (x) spacing.
        sy = float(pixel_spacing[0])
        sx = float(pixel_spacing[1])
    # Prefer the explicit inter-slice spacing. Fall back to SliceThickness when the
    # tag is absent, empty, unparsable or non-positive.
    for tag in ('SpacingBetweenSlices', 'SliceThickness'):
        try:
            candidate = float(getattr(ds0, tag))
        except (AttributeError, TypeError, ValueError):
            continue
        if candidate > 0:
            sz = candidate
            break
    return vol, (sx, sy, sz)
# Gradio app
with gr.Blocks(title="Dental AI - Hugging Face Space") as demo:
    gr.Markdown("""
# Dental AI Demo (Ethical, Heuristic)
- Generate a synthetic CBCT-like volume and visualize in 3D and 2D.
- Run a fast heuristic root canal candidate detector (Frangi) — no fake ML.
- Or upload a DICOM series (multiple files) to visualize.
""")
    # Per-session state: current volume, its (sx, sy, sz) spacing, and the latest
    # list of (x, y, z) candidate points.
    vol_state = gr.State(None)
    spacing_state = gr.State((1.0, 1.0, 1.0))
    points_state = gr.State([])
    with gr.Row():
        gen_btn = gr.Button("🧪 Generate Synthetic Volume", variant="primary")
        files = gr.File(label="Upload DICOM slices (multiple)", file_count="multiple")
    with gr.Row():
        threshold = gr.Slider(0, 3000, value=200, step=10, label="Bone Threshold (HU)")
        # The range is reset to [0, depth - 1] whenever a volume is loaded.
        slice_idx = gr.Slider(0, 1, value=0, step=1, label="Axial Slice")
    with gr.Row():
        detect_btn = gr.Button("🦷 Detect Root Canals (Fast 2.5D)")
        downsample = gr.Slider(1, 8, value=3, step=1, label="Downsample")
        topn = gr.Slider(5, 200, value=40, step=5, label="Top N Candidates")
        sl_range = gr.Slider(5, 120, value=30, step=5, label="Slice Range")
        sl_step = gr.Slider(1, 10, value=3, step=1, label="Slice Step")
    with gr.Row():
        fig3d = gr.Plot(label="3D Surface")
        img2d = gr.Image(label="Axial Slice", type="numpy")
    info = gr.Markdown(visible=True)

    def _update_view(vol, spacing, thr, z, points):
        """Return (3-D figure, 2-D slice image), or no-op updates when no volume is loaded."""
        if vol is None:
            return gr.update(), gr.update()
        fig = build_mesh_figure(vol, thr, spacing)
        img = axial_slice_image(vol, int(z), points, level=400.0, width=1500.0)
        return fig, img

    def _show_loaded_volume(vol, spacing, thr, message):
        """Build the 7 outputs for a newly loaded volume.

        Points are cleared, the slice slider is re-ranged and centred, and the
        middle slice is rendered.
        """
        depth = int(vol.shape[0])
        mid = depth // 2
        points = []
        fig, img = _update_view(vol, spacing, thr, mid, points)
        slider = gr.update(minimum=0, maximum=depth - 1, value=mid, step=1)
        return vol, spacing, points, slider, fig, img, message

    def _upload_failed(message):
        """Build the 7 outputs for an empty/failed upload.

        State and slider are kept, both views are cleared, and the message is shown.
        """
        return gr.update(), gr.update(), gr.update(), gr.update(), None, None, message

    load_outputs = [vol_state, spacing_state, points_state, slice_idx, fig3d, img2d, info]

    def on_generate(thr):
        vol, spacing = generate_synthetic_dental_volume()
        return _show_loaded_volume(
            vol, spacing, thr, f"Generated synthetic volume: {vol.shape} with spacing {spacing}")

    gen_btn.click(on_generate, inputs=[threshold], outputs=load_outputs)

    def on_files(thr, files_list):
        # Every branch must return one value per entry in `load_outputs` (7).
        if not files_list:
            return _upload_failed("No files uploaded")
        try:
            vol, spacing = load_dicom_series(files_list)
            return _show_loaded_volume(
                vol, spacing, thr,
                f"Loaded DICOM series: {vol.shape} with spacing {spacing} (showing middle slice)")
        except Exception as e:
            return _upload_failed(f"❌ Error: {e}")

    files.change(on_files, inputs=[threshold, files], outputs=load_outputs)

    def on_view_change(vol, spacing, thr, z, points):
        return _update_view(vol, spacing, thr, z, points)

    threshold.release(on_view_change, inputs=[vol_state, spacing_state, threshold, slice_idx, points_state], outputs=[fig3d, img2d])

    def on_slice_change(vol, z, points):
        """Redraw only the 2-D slice; the 3-D surface does not depend on the slice index."""
        if vol is None:
            return gr.update()
        return axial_slice_image(vol, int(z), points, level=400.0, width=1500.0)

    slice_idx.release(on_slice_change, inputs=[vol_state, slice_idx, points_state], outputs=[img2d])

    def on_detect(vol, spacing, thr, z, ds, tn, rge, stp):
        if vol is None:
            return gr.update(), gr.update(), []
        points = detect_root_canals_fast_axial(vol, bone_threshold=thr, downsample=ds,
                                               top_n=int(tn), center_index=int(z),
                                               slice_range=int(rge), slice_step=int(stp))
        fig, img = _update_view(vol, spacing, thr, z, points)
        return fig, img, points

    detect_btn.click(on_detect,
                     inputs=[vol_state, spacing_state, threshold, slice_idx, downsample, topn, sl_range, sl_step],
                     outputs=[fig3d, img2d, points_state])
if __name__ == "__main__":
    # Serve on all interfaces. The port comes from $PORT when set, else 7860.
    server_port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=server_port)