| | import os |
| | import numpy as np |
| | import gradio as gr |
| | from PIL import Image, ImageDraw, ImageFont |
| | import pydicom |
| | from ultralytics import YOLO |
| |
|
| | |
# --- Preview-canvas geometry (all values are in preview pixel coordinates) ---
PREVIEW_SZ = 640          # previews are resized to PREVIEW_SZ x PREVIEW_SZ
RADIUS_PREVIEW = 10       # radius of the vertebral and noise ROI circles
NOISE_X_PREV = 600        # x position of the background-noise ROI
NOISE_BOX_FACTOR = 3      # noise box half-size = RADIUS_PREVIEW * NOISE_BOX_FACTOR

# --- Reference statistics used for M-score = (SNR - SNR_REF) / SD_REF ---
# NOTE(review): origin of these reference values is not documented here — confirm.
SNR_REF = 271.325
SD_REF = 50.15

# --- Target intensity range the raw image is min-max rescaled into ---
new_min, new_max = 0.0, 255.0

# --- Files expected in the current working directory ---
_CWD = os.getcwd()
MODEL_PATH = os.path.join(_CWD, "model.pt")
NORMAL_SAMPLES = [os.path.join(_CWD, f"sample_n_{i}.dcm") for i in range(1, 6)]
OSTEO_SAMPLES = [os.path.join(_CWD, f"sample_o_{i}.dcm") for i in range(1, 6)]
| |
|
# Load the YOLO detector once at import time; process_image reuses this
# module-level instance for every request.
model = YOLO(MODEL_PATH)
| |
|
| | |
def dcm_to_preview(path: str, size=(PREVIEW_SZ, PREVIEW_SZ)) -> Image.Image:
    """Load a DICOM file and return an 8-bit RGB preview image.

    The pixel data is min-max scaled to 0..255 before conversion so the preview
    uses the full display range regardless of the source bit depth. (Previously
    the scaled array was computed but the raw float array was converted instead,
    which clipped high-bit-depth data to mostly white.)

    Args:
        path: Path to a file readable by ``pydicom.dcmread``.
        size: ``(width, height)`` of the returned preview.

    Returns:
        A PIL RGB image of the requested size.
    """
    arr = pydicom.dcmread(path).pixel_array.astype(float)
    lo, hi = arr.min(), arr.max()
    span = hi - lo
    # A constant image would otherwise divide by zero and produce NaNs.
    scaled = (arr - lo) / span * 255 if span > 0 else np.zeros_like(arr)
    return Image.fromarray(scaled.astype("uint8")).convert("RGB").resize(size)
| |
|
def infer_and_prune(img, model, num_classes=6, conf=0.25, spatial_thresh=0.7):
    """
    Run detection and keep a vertically consistent chain of vertebra boxes.

    Boxes are walked from the bottom of the image upward (largest center y
    first). Each accepted box is relabelled to the next class in the expected
    anatomical sequence L5, L4, L3, L2, L1, T12 (class ids 4, 3, 2, 1, 0, 5), so
    the model's own class prediction is only used for logging.

    Args:
        img: Image input passed to ``model.predict``.
        model: Detector exposing ``predict(source=..., conf=...)`` whose first
            result has ``boxes`` with ``cls``, ``conf`` and ``xyxy`` attributes.
        num_classes: Maximum number of vertebrae to keep (default 6). Values
            larger than the six-entry ordering table are capped at six.
        conf: Confidence threshold forwarded to ``model.predict``.
        spatial_thresh: Fractional tolerance for the spatial check (default 0.7).
            A box is accepted only if the vertical distance between its center
            and the previously accepted box's center lies within
            ``[height_sum * (1 - spatial_thresh), height_sum * (1 + spatial_thresh)]``,
            where ``height_sum`` is the sum of the two box heights.

    Returns:
        List of box dicts (keys ``cls``, ``xy``, ``center_y``, ``height``,
        ``conf``) ordered from the lowest vertebra upward.
    """
    desired_order = [4, 3, 2, 1, 0, 5]
    # Cap at the table length: indexing desired_order past its end used to
    # raise IndexError whenever num_classes > 6 and enough boxes were found.
    expected = desired_order[:max(0, min(num_classes, len(desired_order)))]

    res = model.predict(source=img, conf=conf)[0]
    raw_boxes = []
    for b in res.boxes:
        x1, y1, x2, y2 = map(float, b.xyxy[0])
        raw_boxes.append({
            'cls': int(b.cls),
            'xy': (x1, y1, x2, y2),
            'center_y': (y1 + y2) / 2,
            'height': y2 - y1,
            # Read into the dict directly so the ``conf`` argument is not shadowed.
            'conf': float(b.conf),
        })

    # Image y grows downward, so the largest center_y is the lowest vertebra.
    sorted_boxes = sorted(raw_boxes, key=lambda x: x['center_y'], reverse=True)

    ordered = []
    prev = None
    for box in sorted_boxes:
        if len(ordered) >= len(expected):
            break
        expected_cls = expected[len(ordered)]

        if prev is not None:
            dist = abs(prev['center_y'] - box['center_y'])
            height_sum = prev['height'] + box['height']
            lower = height_sum - (height_sum * spatial_thresh)
            upper = height_sum + (height_sum * spatial_thresh)
            if not (lower <= dist <= upper):
                print(f"Spatial check failed: {prev['cls']}->{box['cls']} (dist={dist:.3f}, "
                      f"expected in [{lower:.3f}, {upper:.3f}])")
                continue

        if box['cls'] != expected_cls:
            print(f"Relabeling cls {box['cls']} to expected {expected_cls} due to spatial match")
            box['cls'] = expected_cls

        ordered.append(box)
        prev = box

    if len(ordered) < len(expected):
        # Only report classes that were actually requested via num_classes.
        missing = set(expected) - {b['cls'] for b in ordered}
        print(f"Warning: missing classes after prune: {missing}")

    return ordered
| |
|
def mean_circle(img, cx, cy, r):
    """Return the mean of ``img`` pixels inside the disc of radius ``r`` at (cx, cy).

    ``cx`` is the column index and ``cy`` the row index. Parts of the disc that
    fall outside the image are ignored; 0 is returned when no pixel of the disc
    lies inside the image.
    """
    rows, cols = img.shape[:2]
    y0, y1 = max(0, cy - r), min(rows, cy + r + 1)
    x0, x1 = max(0, cx - r), min(cols, cx + r + 1)
    if y0 >= y1 or x0 >= x1:
        return 0
    # Build the mask in absolute image coordinates so it stays aligned with the
    # region even when the disc is clipped at the top or left edge.
    yy, xx = np.ogrid[y0:y1, x0:x1]
    inside = (yy - cy) ** 2 + (xx - cx) ** 2 <= r ** 2
    values = img[y0:y1, x0:x1][inside]
    return values.mean() if values.size else 0
| |
|
def sd_circle(img, cx, cy, r):
    """Return the standard deviation of ``img`` pixels inside a disc.

    The disc has radius ``r`` and is centred at column ``cx``, row ``cy``; numpy's
    default (population, ddof=0) standard deviation is used. Parts of the disc
    outside the image are ignored; 0 is returned when no pixel of the disc lies
    inside the image.
    """
    rows, cols = img.shape[:2]
    top, bottom = max(0, cy - r), min(rows, cy + r + 1)
    left, right = max(0, cx - r), min(cols, cx + r + 1)
    if top >= bottom or left >= right:
        return 0
    # Mask in absolute coordinates: slicing a centred mask from its origin
    # misaligns it whenever the disc is clipped at the top or left edge.
    yy, xx = np.ogrid[top:bottom, left:right]
    inside = (yy - cy) ** 2 + (xx - cx) ** 2 <= r ** 2
    masked_values = img[top:bottom, left:right][inside]
    return masked_values.std() if masked_values.size else 0
| |
|
def to_five(vals, pad=np.nan):
    """Return a new list of exactly five values from *vals*.

    Longer inputs are truncated and shorter ones are right-padded with *pad*.
    The input is left untouched (the previous version appended the padding to
    the caller's list in place).
    """
    out = list(vals)[:5]
    out.extend([pad] * (5 - len(out)))
    return out
| |
|
| | |
def _load_raw(raw_path):
    """Read the full-resolution image at *raw_path* (DICOM or PIL-readable) as floats."""
    if raw_path.lower().endswith(".dcm"):
        return pydicom.dcmread(raw_path).pixel_array.astype(float)
    with Image.open(raw_path) as im:
        return np.array(im.convert("L")).astype(float)


def _rescale_intensity(raw):
    """Min-max rescale *raw* into [new_min, new_max]; a constant image maps to new_min."""
    lo, hi = raw.min(), raw.max()
    if hi == lo:
        return np.full_like(raw, new_min)
    return (raw - lo) / (hi - lo) * (new_max - new_min) + new_min


def _density_category(m_score):
    """Map an M-score to a bone-density label, or an error when it is not finite."""
    if not np.isfinite(m_score):
        # NaN compares false everywhere and used to fall through to "Osteoporosis".
        return "Error: L1-L4 could not be measured"
    if m_score < 1.26:
        return "Normal Density"
    if m_score <= 2.05:
        return "Osteopenia"
    return "Osteoporosis"


def process_image(preview: Image.Image, raw_path: str):
    """Detect vertebrae, measure ROI intensities and classify bone density.

    Boxes come from ``infer_and_prune`` on the preview; ROI statistics are taken
    on the min-max rescaled full-resolution image, mapped from preview
    coordinates using the preview's actual size.

    Args:
        preview: RGB preview shown in the UI; it is annotated in place with
            boxes, ROI circles and a text summary.
        raw_path: Path to the full-resolution source (.dcm or PIL-readable image).

    Returns:
        A 10-tuple matching the UI outputs:
        ``(annotated_preview, L1_mean, L2_mean, L3_mean, L4_mean, L5_mean,
        noise_mean, snr_l1_l4, m_score, result_text)``. Means of vertebrae that
        were not detected are NaN.
    """
    if preview is None or raw_path is None:
        blank = Image.new("RGB", (PREVIEW_SZ, PREVIEW_SZ))
        # 8 numeric outputs: 5 vertebra means, noise mean, SNR, M-score.
        return blank, *([np.nan] * 8), "Error: No image uploaded"

    raw_normalized = _rescale_intensity(_load_raw(raw_path))
    prev_w, prev_h = preview.size
    sy, sx = raw_normalized.shape[0] / prev_h, raw_normalized.shape[1] / prev_w
    r_raw = int(RADIUS_PREVIEW * sx)

    res = infer_and_prune(np.array(preview), model)
    draw = ImageDraw.Draw(preview)
    font = ImageFont.load_default()
    class_dict = {0: 'L1', 1: 'L2', 2: 'L3', 3: 'L4', 4: 'L5', 5: 'T12'}

    # class id -> (cx_p, cy_p, ROI mean). Keeping centre and mean together avoids
    # mismatches between separately sorted/padded lists.
    found = {}
    for b in res:
        if float(b['conf']) < 0.5:
            continue
        v_cls = int(b['cls'])
        x1, y1, x2, y2 = map(int, b['xy'])
        cx_p, cy_p = (x1 + x2) // 2, (y1 + y2) // 2
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.text((x1 - 5, y1 + 5), class_dict[v_cls], fill="yellow", font=font)
        found[v_cls] = (cx_p, cy_p,
                        mean_circle(raw_normalized, int(cx_p * sx), int(cy_p * sy), r_raw))

    # UI slots are L1..L5 (class ids 0..4), independent of detection order.
    means = [found[c][2] if c in found else np.nan for c in range(5)]

    # Place the noise ROI at the height of L3 when it was detected.
    l3_cy_p = found[2][1] if 2 in found else prev_h // 2

    half_box = RADIUS_PREVIEW * NOISE_BOX_FACTOR
    draw.rectangle([NOISE_X_PREV - half_box, l3_cy_p - half_box,
                    NOISE_X_PREV + half_box, l3_cy_p + half_box],
                   outline="green", width=2)
    draw.ellipse([NOISE_X_PREV - RADIUS_PREVIEW, l3_cy_p - RADIUS_PREVIEW,
                  NOISE_X_PREV + RADIUS_PREVIEW, l3_cy_p + RADIUS_PREVIEW],
                 outline="blue", width=2)
    noise_cx, noise_cy = int(NOISE_X_PREV * sx), int(l3_cy_p * sy)
    noise_mean = mean_circle(raw_normalized, noise_cx, noise_cy, r_raw)
    noise_sd = sd_circle(raw_normalized, noise_cx, noise_cy, r_raw)

    # SNR uses L1-L4 only (class ids 0..3); L5 and T12 are excluded.
    snr_l1_l4_list = []
    for v_cls in range(4):
        if v_cls not in found:
            continue
        cx_p, cy_p, m = found[v_cls]
        draw.ellipse([cx_p - RADIUS_PREVIEW, cy_p - RADIUS_PREVIEW,
                      cx_p + RADIUS_PREVIEW, cy_p + RADIUS_PREVIEW],
                     outline="blue", width=2)
        # Small epsilon keeps the ratio finite when the noise ROI is flat.
        snr_l1_l4_list.append(m / (noise_sd + 0.005))
    # np.median([]) warns and returns NaN; make the "nothing measured" case explicit.
    snr_l1_l4 = np.median(snr_l1_l4_list) if snr_l1_l4_list else np.nan
    m_score = (snr_l1_l4 - SNR_REF) / SD_REF
    result = _density_category(m_score)

    info_lines = [
        f"Result: {result}",
        f"Means: {['{:.2f}'.format(m) for m in means]}",
        f"Noise Mean: {noise_mean:.2f}",
        f"SNR L1-L4: {snr_l1_l4:.2f}",
        f"Score: {m_score:.2f}",
    ]
    for i, line in enumerate(info_lines):
        draw.text((10, 10 + i * 15), line, fill="yellow", font=font)
    return preview, *means, noise_mean, snr_l1_l4, m_score, result
| |
|
| | |
def choose_sample(type_choice: str, img_choice: str):
    """Return ``(preview, path)`` for one of the bundled sample DICOM files.

    ``type_choice`` "Normal" selects from NORMAL_SAMPLES; any other value selects
    from OSTEO_SAMPLES. ``img_choice`` has the form "Image <k>" with a 1-based k.
    """
    pool = NORMAL_SAMPLES if type_choice == "Normal" else OSTEO_SAMPLES
    index = int(img_choice.split()[1]) - 1
    chosen = pool[index]
    return dcm_to_preview(chosen), chosen
| |
|
def upload_file(file):
    """Build a preview for an uploaded file and return ``(preview, path)``.

    Accepts either a plain path string or an object exposing ``.name`` (a
    tempfile wrapper), since the gr.File payload type differs between Gradio
    versions. Returns ``(None, None)`` when nothing is uploaded.
    """
    if file is None:
        return None, None
    p = file if isinstance(file, str) else file.name
    if p.lower().endswith(".dcm"):
        preview = dcm_to_preview(p)
    else:
        # Context manager closes the underlying file handle; convert() copies.
        with Image.open(p) as im:
            preview = im.convert("RGB").resize((PREVIEW_SZ, PREVIEW_SZ))
    return preview, p
| |
|
| | |
def update_final_result(result):
    """Return a ``gr.update`` for the status textbox, colour-coded by result.

    "Normal Density" -> green, "Osteopenia" -> orange, anything else -> red.
    NOTE(review): whether ``style=`` is honoured by ``gr.update`` depends on the
    installed Gradio version — confirm.
    """
    palette = {"Normal Density": "green", "Osteopenia": "orange"}
    color = palette.get(result, "red")
    return gr.update(value=result, elem_id="final_result", style=f"background-color:{color};")
| |
|
with gr.Blocks() as demo:
    # Path of the full-resolution file behind the current preview.
    raw_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input image")
            type_selector = gr.Dropdown(["Normal", "Osteoporosis"], value="Normal", label="Sample Type")
            img_selector = gr.Dropdown(["Image 1", "Image 2", "Image 3", "Image 4", "Image 5"], value="Image 1", label="Sample Image")
            uploader = gr.File(label="Upload .png / .dcm", file_types=[".png", ".dcm"])
            preview = gr.Image(type="pil", label="Preview", height=260)
            btn = gr.Button("▶ Process image")
        with gr.Column():
            gr.Markdown("### Results")
            with gr.Row():
                l1 = gr.Number(label="L1 mean", precision=2)
                l2 = gr.Number(label="L2 mean", precision=2)
                l3 = gr.Number(label="L3 mean", precision=2)
            with gr.Row():
                l4 = gr.Number(label="L4 mean", precision=2)
                l5 = gr.Number(label="L5 mean", precision=2)
                ln = gr.Number(label="Noise mean", precision=2)
            with gr.Row():
                snr_l1_l4 = gr.Number(label="SNR L1-L4", precision=2)
                mscore = gr.Number(label="M-Score", precision=2)
            final_result = gr.Textbox(label="Bone Density Status", interactive=False)

    type_selector.change(choose_sample, [type_selector, img_selector], [preview, raw_state])
    img_selector.change(choose_sample, [type_selector, img_selector], [preview, raw_state])
    uploader.change(upload_file, uploader, [preview, raw_state])
    # Colour the status box only after process_image has written it. Two
    # independent btn.click listeners both read their inputs at click time, so
    # update_final_result would otherwise colour the previous (stale) result.
    btn.click(process_image, [preview, raw_state],
              [preview, l1, l2, l3, l4, l5, ln, snr_l1_l4, mscore, final_result]
              ).then(update_final_result, final_result, final_result)
    # Show the sample selected by the dropdown defaults as soon as the page loads,
    # so "Process image" works without first touching a dropdown.
    demo.load(choose_sample, [type_selector, img_selector], [preview, raw_state])

demo.launch()
| |
|