import os

import numpy as np
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import pydicom
from ultralytics import YOLO

# Constants
PREVIEW_SZ = 640
RADIUS_PREVIEW = 10
NOISE_X_PREV = 600  # X-coordinate for noise circle in preview
NOISE_BOX_FACTOR = 3  # factor for noise box size (3×radius)
SNR_REF = 271.325  # Reference SNR (set value)
SD_REF = 50.15  # Reference SD (set value)
new_min = 0.0
new_max = 255.0
MODEL_PATH = os.path.join(os.getcwd(), "model.pt")
NORMAL_SAMPLES = [os.path.join(os.getcwd(), f"sample_n_{i}.dcm") for i in range(1, 6)]
OSTEO_SAMPLES = [os.path.join(os.getcwd(), f"sample_o_{i}.dcm") for i in range(1, 6)]

# Detector class id -> label drawn on the preview.
_CLASS_NAMES = {0: 'L1', 1: 'L2', 2: 'L3', 3: 'L4', 4: 'L5', 5: 'T12'}
_L3_CLASS = 2

# CSS hooks used by update_final_result() to colour the status textbox.
_STATUS_CSS = """
.status-green textarea, .status-green input {background-color: green !important;}
.status-orange textarea, .status-orange input {background-color: orange !important;}
.status-red textarea, .status-red input {background-color: red !important;}
"""

model = YOLO(MODEL_PATH)


# ------------------------------------------------------------ helpers
def _rescale(arr, lo=0.0, hi=255.0):
    """Linearly map *arr* from [arr.min(), arr.max()] onto [lo, hi] as float.

    A constant array has no dynamic range; it maps to ``lo`` everywhere
    instead of dividing by zero.
    """
    arr = np.asarray(arr, dtype=float)
    a_min = arr.min()
    span = arr.max() - a_min
    if span == 0:
        return np.full_like(arr, lo)
    return (arr - a_min) / span * (hi - lo) + lo


def dcm_to_preview(path: str, size=(PREVIEW_SZ, PREVIEW_SZ)) -> Image.Image:
    """Read a DICOM file and return an RGB preview resized to *size*.

    Pixel values are min-max rescaled to 0..255 before conversion. The
    previous code computed that rescaled array but then built the image
    from the raw float array, so anything above 255 saturated.
    """
    ds = pydicom.dcmread(path)
    scaled = _rescale(ds.pixel_array).astype('uint8')
    return Image.fromarray(scaled).convert("RGB").resize(size)


def infer_and_prune(img, model, num_classes=6, conf=0.25, spatial_thresh=0.7):
    """
    Performs inference and prunes detections while enforcing vertebral ordering
    and spatial consistency on-the-fly.

    Args:
        img: Image input for model inference.
        model: Object detection model with .predict method.
        num_classes: Expected number of vertebra classes (default 6: L5, L4, L3, L2, L1, T12).
            Values larger than the six known classes are capped at six.
        conf: Confidence threshold for predictions.
        spatial_thresh: Fractional tolerance for spatial check (default 0.7).
            Verifies that the vertical distance between adjacent vertebra centers
            lies within [height_sum * (1 - spatial_thresh), height_sum * (1 + spatial_thresh)].

    Returns:
        List of validated box dicts ordered bottom-up (L5 first). Each dict has
        the keys 'cls', 'xy', 'center_y', 'height' and 'conf'. A box whose
        predicted class disagrees with its position in the sequence is
        relabelled to the expected class.
    """
    # 1. Inference
    res = model.predict(source=img, conf=conf)[0]
    raw_boxes = []
    for b in res.boxes:
        x1, y1, x2, y2 = map(float, b.xyxy[0])
        raw_boxes.append({
            'cls': int(b.cls),
            'xy': (x1, y1, x2, y2),
            'center_y': (y1 + y2) / 2,
            'height': y2 - y1,
            # Separate name: do not shadow the `conf` threshold parameter.
            'conf': float(b.conf),
        })

    # 2. Sort by vertical position descending (bottom-up)
    sorted_boxes = sorted(raw_boxes, key=lambda x: x['center_y'], reverse=True)

    # 3. On-the-fly prune, order enforce & spatial check
    desired_order = [4, 3, 2, 1, 0, 5]  # L5, L4, L3, L2, L1, T12
    # Never index past the known classes, even if num_classes is larger.
    limit = min(num_classes, len(desired_order))
    ordered = []
    prev = None
    for box in sorted_boxes:
        if len(ordered) >= limit:
            break
        expected_cls = desired_order[len(ordered)]

        # Spatial check for non-first vertebra
        if prev is not None:
            dist = abs(prev['center_y'] - box['center_y'])
            height_sum = prev['height'] + box['height']
            lower = height_sum - (height_sum * spatial_thresh)  # lower bound
            upper = height_sum + (height_sum * spatial_thresh)  # upper bound
            if not (lower <= dist <= upper):
                print(f"Spatial check failed: {prev['cls']}->{box['cls']} (dist={dist:.3f}, "
                      f"expected in [{lower:.3f}, {upper:.3f}])")
                continue  # skip if spacing is off

        if box['cls'] != expected_cls:
            # Class mismatch, but spatial check passed, so relabel and accept
            print(f"Relabeling cls {box['cls']} to expected {expected_cls} due to spatial match")
            box['cls'] = expected_cls

        # Accept this box
        ordered.append(box)
        prev = box

    # Warn if any vertebra missing
    if len(ordered) < limit:
        missing = set(desired_order[:limit]) - {b['cls'] for b in ordered}
        print(f"Warning: missing classes after prune: {missing}")

    return ordered


def _circle_values(img, cx, cy, r):
    """Return the pixels of *img* inside the circle centred at (cx, cy), radius r.

    The circle is clipped to the image bounds. The mask is built from the
    clipped window's offsets relative to the centre, so it stays aligned with
    the pixel data on every edge. An empty array is returned when the circle
    does not intersect the image.
    """
    img = np.asarray(img)
    height, width = img.shape[:2]
    top, bottom = max(0, cy - r), min(height, cy + r + 1)
    left, right = max(0, cx - r), min(width, cx + r + 1)
    if r < 0 or top >= bottom or left >= right:
        return np.empty(0, dtype=img.dtype)
    yy, xx = np.ogrid[top - cy:bottom - cy, left - cx:right - cx]
    mask = xx ** 2 + yy ** 2 <= r ** 2
    return img[top:bottom, left:right][mask]


def mean_circle(img, cx, cy, r):
    """Mean of the pixels inside the circular ROI, or 0 if the ROI is empty."""
    values = _circle_values(img, cx, cy, r)
    return values.mean() if values.size else 0


def sd_circle(img, cx, cy, r):
    """Standard deviation of the pixels inside the circular ROI, or 0 if empty."""
    values = _circle_values(img, cx, cy, r)
    return values.std() if values.size else 0


def to_five(vals, pad=np.nan):
    """Return a new 5-element list: *vals* truncated or right-padded with *pad*.

    The input list is left untouched.
    """
    padded = list(vals[:5])
    padded.extend([pad] * (5 - len(padded)))
    return padded


# ------------------------------------------------------------ core processing
def process_image(preview: Image.Image, raw_path: str):
    """Detect vertebrae on *preview*, measure ROIs on the raw image and score it.

    Args:
        preview: PREVIEW_SZ x PREVIEW_SZ preview image used for detection.
        raw_path: Path of the original file (.dcm or any Pillow-readable image)
            on which the intensity measurements are taken.

    Returns:
        A 10-tuple matching the UI outputs:
        (annotated preview, L1 mean, L2 mean, L3 mean, L4 mean, L5 mean,
        noise mean, SNR L1-L4, M-score, status text). Means of vertebrae
        that were not detected are NaN.
    """
    if preview is None or raw_path is None:
        blank = Image.new("RGB", (PREVIEW_SZ, PREVIEW_SZ))
        # 1 image + 8 numbers + 1 text = the 10 output components.
        return blank, *([np.nan] * 8), "Error: No image uploaded"

    if raw_path.lower().endswith(".dcm"):
        raw = pydicom.dcmread(raw_path).pixel_array
    else:
        with Image.open(raw_path) as src:
            raw = np.array(src.convert("L"))
    raw = raw.astype(float)

    # Scale the raw pixel values to the range [new_min, new_max]
    raw_normalized = _rescale(raw, new_min, new_max)

    # Preview -> raw coordinate scale factors.
    sy, sx = raw.shape[0] / PREVIEW_SZ, raw.shape[1] / PREVIEW_SZ
    radius_raw = int(RADIUS_PREVIEW * sx)

    # NOTE(review): Ultralytics documents numpy inputs as BGR; this passes the
    # RGB array as the original did. Harmless for grey previews — confirm for
    # colour PNG uploads.
    detections = infer_and_prune(np.array(preview), model)

    # NOTE(review): the annotated image is written back to the same component
    # that feeds this function, so a second click re-runs detection on an
    # already annotated preview.
    annotated = preview.copy()  # do not draw on the caller's image
    draw = ImageDraw.Draw(annotated)
    font = ImageFont.load_default()

    # Keyed by class id so every measurement stays attached to its vertebra.
    roi_mean = {}
    roi_centre = {}  # class id -> (cy, cx) in preview coordinates
    for det in detections:
        if float(det['conf']) < 0.5:
            continue
        v_cls = int(det['cls'])
        x1, y1, x2, y2 = map(int, det['xy'])
        cx_p, cy_p = (x1 + x2) // 2, (y1 + y2) // 2
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.text((x1 - 5, y1 + 5), _CLASS_NAMES.get(v_cls, str(v_cls)),
                  fill="yellow", font=font)
        roi_mean[v_cls] = mean_circle(raw_normalized, int(cx_p * sx), int(cy_p * sy), radius_raw)
        roi_centre[v_cls] = (cy_p, cx_p)

    # Output order is L1..L5 (class ids 0..4), matching the UI field labels.
    means = [roi_mean.get(c, np.nan) for c in range(5)]

    # The noise ROI sits at the height of L3. If L3 was not detected, fall
    # back to the third detection from the top, then to the image centre.
    if _L3_CLASS in roi_centre:
        l3_cy_p = roi_centre[_L3_CLASS][0]
    else:
        rows = sorted(cy for cy, _ in roi_centre.values())
        l3_cy_p = rows[2] if len(rows) >= 3 else PREVIEW_SZ // 2

    half_box = RADIUS_PREVIEW * NOISE_BOX_FACTOR
    draw.rectangle([NOISE_X_PREV - half_box, l3_cy_p - half_box,
                    NOISE_X_PREV + half_box, l3_cy_p + half_box],
                   outline="green", width=2)
    draw.ellipse([NOISE_X_PREV - RADIUS_PREVIEW, l3_cy_p - RADIUS_PREVIEW,
                  NOISE_X_PREV + RADIUS_PREVIEW, l3_cy_p + RADIUS_PREVIEW],
                 outline="blue", width=2)

    noise_cx, noise_cy = int(NOISE_X_PREV * sx), int(l3_cy_p * sy)
    noise_mean = mean_circle(raw_normalized, noise_cx, noise_cy, radius_raw)
    noise_sd = sd_circle(raw_normalized, noise_cx, noise_cy, radius_raw)

    # SNR over L1..L4 only (class ids 0..3); L5 and T12 are excluded.
    snr_l1_l4_list = []
    for v_cls in range(4):
        if v_cls not in roi_mean:
            continue
        cy_p, cx_p = roi_centre[v_cls]
        draw.ellipse([cx_p - RADIUS_PREVIEW, cy_p - RADIUS_PREVIEW,
                      cx_p + RADIUS_PREVIEW, cy_p + RADIUS_PREVIEW],
                     outline="blue", width=2)
        # The small constant keeps the ratio finite when noise_sd is 0.
        snr_l1_l4_list.append(roi_mean[v_cls] / (noise_sd + 0.005))

    snr_l1_l4 = np.median(snr_l1_l4_list) if snr_l1_l4_list else np.nan
    m_score = (snr_l1_l4 - SNR_REF) / SD_REF

    if not np.isfinite(m_score):
        # Without a usable L1-L4 measurement the score is undefined; report
        # that instead of letting NaN fall through to a diagnosis.
        result = "Error: No L1-L4 vertebrae detected"
    elif m_score < 1.26:
        result = "Normal Density"
    elif m_score <= 2.05:
        result = "Osteopenia"
    else:
        result = "Osteoporosis"

    # Annotate result text and metrics
    info_text = f"Result: {result}\n"
    info_text += f"Means: {['{:.2f}'.format(m) for m in means]}\n"
    info_text += f"Noise Mean: {noise_mean:.2f}\n"
    info_text += f"SNR L1-L4: {snr_l1_l4:.2f}\n"
    info_text += f"Score: {m_score:.2f}"
    for i, line in enumerate(info_text.split('\n')):
        draw.text((10, 10 + i * 15), line, fill="yellow", font=font)

    return annotated, *means, noise_mean, snr_l1_l4, m_score, result


# ------------------------------------------------------------ file→preview helpers
def choose_sample(type_choice: str, img_choice: str):
    """Return (preview, path) for the bundled sample picked in the dropdowns.

    *img_choice* is expected to look like "Image 3"; the number is 1-based.
    """
    samples = NORMAL_SAMPLES if type_choice == "Normal" else OSTEO_SAMPLES
    path = samples[int(img_choice.split()[1]) - 1]
    return dcm_to_preview(path), path


def upload_file(file):
    """Return (preview, path) for an uploaded file, or (None, None) if cleared.

    Accepts either a path string or an object exposing the path as ``.name``.
    """
    if file is None:
        return None, None
    p = file if isinstance(file, str) else file.name
    if p.lower().endswith(".dcm"):
        return dcm_to_preview(p), p
    with Image.open(p) as src:  # close the file handle once pixels are copied
        preview = src.convert("RGB").resize((PREVIEW_SZ, PREVIEW_SZ))
    return preview, p


# ------------------------------------------------------------ UI
def update_final_result(result):
    """Return a textbox update that colours the status by diagnosis.

    The colour is applied through a CSS class defined in _STATUS_CSS;
    `style` is not a Textbox property and cannot be passed via gr.update.
    """
    if result == "Normal Density":
        color = "green"
    elif result == "Osteopenia":
        color = "orange"
    else:
        color = "red"
    return gr.update(value=result, elem_classes=[f"status-{color}"])


with gr.Blocks(css=_STATUS_CSS) as demo:
    raw_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input image")
            type_selector = gr.Dropdown(["Normal", "Osteoporosis"], value="Normal",
                                        label="Sample Type")
            img_selector = gr.Dropdown(["Image 1", "Image 2", "Image 3", "Image 4", "Image 5"],
                                       value="Image 1", label="Sample Image")
            uploader = gr.File(label="Upload .png / .dcm", file_types=[".png", ".dcm"])
            preview = gr.Image(type="pil", label="Preview", height=260)
            btn = gr.Button("▶ Process image")

        with gr.Column():
            gr.Markdown("### Results")
            with gr.Row():
                l1 = gr.Number(label="L1 mean", precision=2)
                l2 = gr.Number(label="L2 mean", precision=2)
                l3 = gr.Number(label="L3 mean", precision=2)
            with gr.Row():
                l4 = gr.Number(label="L4 mean", precision=2)
                l5 = gr.Number(label="L5 mean", precision=2)
                ln = gr.Number(label="Noise mean", precision=2)
            with gr.Row():
                snr_l1_l4 = gr.Number(label="SNR L1-L4", precision=2)
                mscore = gr.Number(label="M-Score", precision=2)
            final_result = gr.Textbox(label="Bone Density Status", interactive=False,
                                      elem_id="final_result")

    type_selector.change(choose_sample, [type_selector, img_selector], [preview, raw_state])
    img_selector.change(choose_sample, [type_selector, img_selector], [preview, raw_state])
    uploader.change(upload_file, uploader, [preview, raw_state])

    # Colour the status only after process_image has written it. Two separate
    # click listeners would let the colouring read the previous status.
    btn.click(
        process_image,
        [preview, raw_state],
        [preview, l1, l2, l3, l4, l5, ln, snr_l1_l4, mscore, final_result],
    ).then(update_final_result, final_result, final_result)

demo.launch()