# BMD / app.py
# Medvira's picture
# Update app.py
# abfc512 verified
import os
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import pydicom
from ultralytics import YOLO
# Constants
PREVIEW_SZ = 640        # side length (px) of the square preview image
RADIUS_PREVIEW = 10     # radius (preview px) of the circular measurement ROIs
NOISE_X_PREV = 600      # X-coordinate for noise circle in preview
NOISE_BOX_FACTOR = 3    # factor for noise box size (3×radius)
SNR_REF = 271.325       # Reference SNR (set value)
SD_REF = 50.15          # Reference SD (set value)

# Target intensity range the raw image is rescaled to before measuring.
new_min = 0.0
new_max = 255.0

# Model weights and bundled sample studies are looked up in the current
# working directory.
_WORK_DIR = os.getcwd()
MODEL_PATH = os.path.join(_WORK_DIR, "model.pt")
NORMAL_SAMPLES = [
    os.path.join(_WORK_DIR, f"sample_n_{i}.dcm") for i in range(1, 6)
]
OSTEO_SAMPLES = [
    os.path.join(_WORK_DIR, f"sample_o_{i}.dcm") for i in range(1, 6)
]

# Detector is loaded once at import time and shared by all requests.
model = YOLO(MODEL_PATH)
# ------------------------------------------------------------ helpers
def dcm_to_preview(path: str, size=(PREVIEW_SZ, PREVIEW_SZ)) -> Image.Image:
    """Read a DICOM file and return it as an 8-bit RGB preview image.

    The pixel data is min-max stretched to 0..255 before conversion so that
    images with a wider value range are not clipped when rendered.

    Args:
        path: Path of the DICOM file to read.
        size: (width, height) of the returned preview.

    Returns:
        An RGB ``PIL.Image.Image`` resized to ``size``.
    """
    ds = pydicom.dcmread(path)
    arr = ds.pixel_array.astype(float)

    lo = arr.min()
    span = arr.max() - lo
    if span > 0:
        scaled = (arr - lo) / span * 255
    else:
        # Constant image: avoid dividing by zero and render it as black.
        scaled = np.zeros_like(arr)

    # The original computed this normalised array but then rendered the raw
    # float array instead; render the normalised 8-bit data.
    return Image.fromarray(scaled.astype('uint8')).convert("RGB").resize(size)
def infer_and_prune(img, model, num_classes=6, conf=0.25, spatial_thresh=0.7):
    """
    Run detection, then keep a bottom-up chain of plausibly spaced vertebrae.

    Detections are visited from the largest to the smallest vertical centre.
    The first one is always accepted; each later one is accepted only when its
    centre-to-centre distance from the previously accepted box lies within
    [height_sum * (1 - spatial_thresh), height_sum * (1 + spatial_thresh)],
    where height_sum is the sum of the two box heights. An accepted box whose
    class differs from the class expected at that position is relabelled.

    Args:
        img: Image input for model inference.
        model: Object detection model with a .predict(source=..., conf=...) method.
        num_classes: Number of vertebrae to collect before stopping
            (default 6: L5, L4, L3, L2, L1, T12).
        conf: Confidence threshold passed to the model.
        spatial_thresh: Fractional tolerance for the spacing check (default 0.7).

    Returns:
        List of accepted box dicts (keys 'cls', 'xy', 'center_y', 'height',
        'conf'), ordered from the largest to the smallest vertical centre.
    """
    # 1. Inference
    prediction = model.predict(source=img, conf=conf)[0]
    candidates = []
    for det in prediction.boxes:
        label = int(det.cls)
        score = float(det.conf)
        left, top, right, bottom = (float(v) for v in det.xyxy[0])
        candidates.append({
            'cls': label,
            'xy': (left, top, right, bottom),
            'center_y': (top + bottom) / 2,
            'height': bottom - top,
            'conf': score,
        })

    # 2. Largest vertical centre first (bottom-up)
    candidates.sort(key=lambda d: d['center_y'], reverse=True)

    # 3. Walk the candidates, enforcing spacing and the expected label order
    desired_order = [4, 3, 2, 1, 0, 5]  # L5, L4, L3, L2, L1, T12
    accepted = []
    for cand in candidates:
        wanted = desired_order[len(accepted)]

        if accepted:
            last = accepted[-1]
            gap = abs(last['center_y'] - cand['center_y'])
            span = last['height'] + cand['height']
            lower = span - (span * spatial_thresh)
            upper = span + (span * spatial_thresh)
            if not (lower <= gap <= upper):
                print(f"Spatial check failed: {last['cls']}->{cand['cls']} (dist={gap:.3f}, "
                      f"expected in [{lower:.3f}, {upper:.3f}])")
                continue  # spacing is off, drop this candidate

        if cand['cls'] != wanted:
            # Position is plausible, so trust the ordering over the predicted label.
            print(f"Relabeling cls {cand['cls']} to expected {wanted} due to spatial match")
            cand['cls'] = wanted

        accepted.append(cand)
        if len(accepted) == num_classes:
            break

    # Warn if any vertebra missing
    if len(accepted) < num_classes:
        missing = set(desired_order) - {d['cls'] for d in accepted}
        print(f"Warning: missing classes after prune: {missing}")
    return accepted
def mean_circle(img, cx, cy, r):
    """Return the mean of the pixels inside a circle, clipped to the image.

    Args:
        img: Image array indexed as img[row, col].
        cx: Column of the circle centre.
        cy: Row of the circle centre.
        r: Circle radius in pixels.

    Returns:
        Mean of the pixels with (col - cx)**2 + (row - cy)**2 <= r**2 that lie
        inside the image, or 0 when the circle does not overlap the image.
    """
    height, width = img.shape[:2]
    y0, y1 = max(0, cy - r), min(height, cy + r + 1)
    x0, x1 = max(0, cx - r), min(width, cx + r + 1)
    if y0 >= y1 or x0 >= x1:
        return 0

    # Build the mask from offsets relative to the centre for exactly the
    # clipped window, so it stays aligned when the circle is cut off at the
    # top or left edge (slicing a full mask from its corner is not).
    dy, dx = np.ogrid[y0 - cy:y1 - cy, x0 - cx:x1 - cx]
    mask = dx**2 + dy**2 <= r**2
    values = img[y0:y1, x0:x1][mask]
    return values.mean() if values.size else 0
def sd_circle(img, cx, cy, r):
    """Return the standard deviation of the pixels inside a circle.

    Args:
        img: Image array indexed as img[row, col].
        cx: Column of the circle centre.
        cy: Row of the circle centre.
        r: Circle radius in pixels.

    Returns:
        Population standard deviation (numpy ``std``) of the pixels with
        (col - cx)**2 + (row - cy)**2 <= r**2 that lie inside the image, or 0
        when the circle does not overlap the image.
    """
    height, width = img.shape[:2]
    y0, y1 = max(0, cy - r), min(height, cy + r + 1)
    x0, x1 = max(0, cx - r), min(width, cx + r + 1)
    if y0 >= y1 or x0 >= x1:
        return 0  # circle lies completely outside the image

    # Offsets are taken relative to the centre for the clipped window only,
    # which keeps the mask aligned when the circle is cut off at the top or
    # left edge of the image.
    dy, dx = np.ogrid[y0 - cy:y1 - cy, x0 - cx:x1 - cx]
    mask = dx**2 + dy**2 <= r**2
    values = img[y0:y1, x0:x1][mask]
    return values.std() if values.size else 0
def to_five(vals, pad=np.nan):
    """Return a new list of exactly five items taken from ``vals``.

    Shorter inputs are right-padded with ``pad``; longer inputs are truncated.
    The input sequence is left untouched (the previous version appended the
    padding to the caller's list).

    Args:
        vals: Iterable of values.
        pad: Filler used when fewer than five values are supplied.

    Returns:
        A list of length five.
    """
    head = list(vals)[:5]
    return head + [pad] * (5 - len(head))
# ------------------------------------------------------------ core processing
def process_image(preview: Image.Image, raw_path: str):
    """Detect vertebrae, measure ROI intensities and classify bone density.

    Detection runs on the preview image; intensity measurements are taken on
    the raw image (rescaled to [new_min, new_max]) at the corresponding
    positions. The preview is annotated in place with boxes, ROIs and a text
    summary.

    Args:
        preview: Preview image; coordinates are mapped to the raw image on the
            assumption that it is PREVIEW_SZ x PREVIEW_SZ.
        raw_path: Path of the source image. ``.dcm`` files are read with
            pydicom, anything else is opened with PIL as greyscale.

    Returns:
        A 10-tuple matching the UI outputs:
        (annotated preview, L1 mean, L2 mean, L3 mean, L4 mean, L5 mean,
        noise mean, median SNR of L1-L4, M-score, result text).
        Means of vertebrae that were not detected are NaN.
    """
    if preview is None or raw_path is None:
        blank = Image.new("RGB", (PREVIEW_SZ, PREVIEW_SZ))
        # 8 numeric outputs (5 means, noise mean, SNR, M-score) so the tuple
        # has the same length as the normal return value.
        return blank, *([np.nan] * 8), "Error: No image uploaded"

    if raw_path.lower().endswith(".dcm"):
        raw = pydicom.dcmread(raw_path).pixel_array
    else:
        with Image.open(raw_path) as src:
            raw = np.array(src.convert("L"))
    raw = raw.astype(float)

    # Rescale the raw pixel values to [new_min, new_max].
    raw_min = raw.min()
    span = raw.max() - raw_min
    if span > 0:
        raw_normalized = (raw - raw_min) / span
    else:
        # Constant image: avoid dividing by zero.
        raw_normalized = np.zeros_like(raw)
    raw_normalized = raw_normalized * (new_max - new_min) + new_min

    # Preview -> raw coordinate scale factors.
    sy, sx = raw.shape[0] / PREVIEW_SZ, raw.shape[1] / PREVIEW_SZ
    roi_radius_raw = int(RADIUS_PREVIEW * sx)

    res = infer_and_prune(np.array(preview), model)
    draw = ImageDraw.Draw(preview)
    font = ImageFont.load_default()
    class_dict = {0: 'L1', 1: 'L2', 2: 'L3', 3: 'L4', 4: 'L5', 5: 'T12'}

    # One (class, mean, centre row, centre col) record per kept detection, so
    # the values of a vertebra can never get out of step with each other.
    detections = []
    for b in res:
        if float(b['conf']) < 0.5:
            continue
        v_cls = int(b['cls'])
        x1, y1, x2, y2 = map(int, b['xy'])
        cx_p, cy_p = (x1 + x2) // 2, (y1 + y2) // 2
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.text((x1 - 5, y1 + 5), class_dict[v_cls], fill="yellow", font=font)
        cx_r, cy_r = int(cx_p * sx), int(cy_p * sy)
        roi_mean = mean_circle(raw_normalized, cx_r, cy_r, roi_radius_raw)
        detections.append((v_cls, roi_mean, cy_p, cx_p))

    # Report means by vertebra label (class 0..4 = L1..L5) rather than by
    # detection order, so each value lands in the matching UI field.
    mean_by_cls = {c: m for c, m, _, _ in detections}
    means = to_five([mean_by_cls.get(c, np.nan) for c in range(5)])

    # Row of the noise ROI: the detected L3 if there is one, otherwise the
    # third detection from the top, otherwise the middle of the preview.
    l3_rows = [cy for c, _, cy, _ in detections if c == 2]
    if l3_rows:
        l3_cy_p = l3_rows[0]
    elif len(detections) >= 3:
        l3_cy_p = sorted(cy for _, _, cy, _ in detections)[2]
    else:
        l3_cy_p = PREVIEW_SZ // 2

    half_box = RADIUS_PREVIEW * NOISE_BOX_FACTOR
    draw.rectangle([NOISE_X_PREV - half_box, l3_cy_p - half_box,
                    NOISE_X_PREV + half_box, l3_cy_p + half_box],
                   outline="green", width=2)
    draw.ellipse([NOISE_X_PREV - RADIUS_PREVIEW, l3_cy_p - RADIUS_PREVIEW,
                  NOISE_X_PREV + RADIUS_PREVIEW, l3_cy_p + RADIUS_PREVIEW],
                 outline="blue", width=2)

    noise_cx_r, noise_cy_r = int(NOISE_X_PREV * sx), int(l3_cy_p * sy)
    noise_mean = mean_circle(raw_normalized, noise_cx_r, noise_cy_r, roi_radius_raw)
    noise_sd = sd_circle(raw_normalized, noise_cx_r, noise_cy_r, roi_radius_raw)

    # SNR is computed for L1-L4 only (classes 0..3).
    snr_l1_l4_list = []
    for v_cls, roi_mean, cy_p, cx_p in detections:
        if v_cls < 4:
            draw.ellipse([cx_p - RADIUS_PREVIEW, cy_p - RADIUS_PREVIEW,
                          cx_p + RADIUS_PREVIEW, cy_p + RADIUS_PREVIEW],
                         outline="blue", width=2)
            # Small offset keeps the division defined when the noise SD is 0.
            snr_l1_l4_list.append(roi_mean / (noise_sd + 0.005))

    if snr_l1_l4_list:
        snr_l1_l4 = np.median(snr_l1_l4_list)
        m_score = (snr_l1_l4 - SNR_REF) / SD_REF
    else:
        snr_l1_l4 = m_score = np.nan

    if not np.isfinite(m_score):
        # Without a usable score every comparison below would be False and
        # the image would be reported as "Osteoporosis".
        result = "Error: No L1-L4 vertebrae detected"
    elif m_score < 1.26:
        result = "Normal Density"
    elif m_score <= 2.05:
        result = "Osteopenia"
    else:
        result = "Osteoporosis"

    # Annotate result text and metrics
    info_text = f"Result: {result}\n"
    info_text += f"Means: {['{:.2f}'.format(m) for m in means]}\n"
    info_text += f"Noise Mean: {noise_mean:.2f}\n"
    info_text += f"SNR L1-L4: {snr_l1_l4:.2f}\n"
    info_text += f"Score: {m_score:.2f}"
    for i, line in enumerate(info_text.split('\n')):
        draw.text((10, 10 + i * 15), line, fill="yellow", font=font)

    return preview, *means, noise_mean, snr_l1_l4, m_score, result
# ------------------------------------------------------------ file→preview helpers
def choose_sample(type_choice: str, img_choice: str):
    """Return (preview image, file path) for one of the bundled sample studies.

    Args:
        type_choice: "Normal" selects NORMAL_SAMPLES; any other value selects
            OSTEO_SAMPLES.
        img_choice: Label whose second whitespace-separated token is the
            1-based sample number, e.g. "Image 3".
    """
    samples = NORMAL_SAMPLES if type_choice == "Normal" else OSTEO_SAMPLES
    sample_no = int(img_choice.split()[1])
    path = samples[sample_no - 1]
    return dcm_to_preview(path), path
def upload_file(file):
    """Build a preview for an uploaded file.

    Args:
        file: Value delivered by the upload component, or None when the
            upload was cleared. Either an object exposing the path as
            ``.name`` or a plain path string is accepted.

    Returns:
        (preview image, file path), or (None, None) when ``file`` is None.
    """
    if file is None:
        return None, None

    # Use .name when present, otherwise treat the value itself as the path.
    p = getattr(file, "name", file)

    if p.lower().endswith(".dcm"):
        preview = dcm_to_preview(p)
    else:
        # convert() returns a fully loaded copy, so the source file can be
        # closed as soon as the preview has been built.
        with Image.open(p) as src:
            preview = src.convert("RGB").resize((PREVIEW_SZ, PREVIEW_SZ))
    return preview, p
# ------------------------------------------------------------ UI
def update_final_result(result):
    """Return a component update that re-sends ``result`` with a status colour.

    "Normal Density" maps to green, "Osteopenia" to orange and every other
    value to red.
    """
    color = (
        "green" if result == "Normal Density"
        else "orange" if result == "Osteopenia"
        else "red"
    )
    # NOTE(review): `style` may not be a supported Textbox property in the
    # installed Gradio version - confirm the colour is actually applied.
    return gr.update(value=result, elem_id="final_result", style=f"background-color:{color};")
# NOTE(review): the nesting below is reconstructed from the statement order
# (input column on the left, results column on the right) - confirm it matches
# the intended layout.
with gr.Blocks() as demo:
    # Path of the currently selected/uploaded source file.
    raw_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input image")
            type_selector = gr.Dropdown(["Normal", "Osteoporosis"], value="Normal", label="Sample Type")
            img_selector = gr.Dropdown(["Image 1", "Image 2", "Image 3", "Image 4", "Image 5"], value="Image 1", label="Sample Image")
            uploader = gr.File(label="Upload .png / .dcm", file_types=[".png", ".dcm"])
            preview = gr.Image(type="pil", label="Preview", height=260)
            btn = gr.Button("▶ Process image")
        with gr.Column():
            gr.Markdown("### Results")
            with gr.Row():
                l1 = gr.Number(label="L1 mean", precision=2)
                l2 = gr.Number(label="L2 mean", precision=2)
                l3 = gr.Number(label="L3 mean", precision=2)
            with gr.Row():
                l4 = gr.Number(label="L4 mean", precision=2)
                l5 = gr.Number(label="L5 mean", precision=2)
                ln = gr.Number(label="Noise mean", precision=2)
            with gr.Row():
                snr_l1_l4 = gr.Number(label="SNR L1-L4", precision=2)
                mscore = gr.Number(label="M-Score", precision=2)
            final_result = gr.Textbox(label="Bone Density Status", interactive=False)

    # Picking a sample or uploading a file refreshes the preview and the
    # remembered source path.
    type_selector.change(choose_sample, [type_selector, img_selector], [preview, raw_state])
    img_selector.change(choose_sample, [type_selector, img_selector], [preview, raw_state])
    uploader.change(upload_file, uploader, [preview, raw_state])

    # The colour update must see the result written by process_image, so it
    # is chained after it instead of being a second listener on the same
    # click (which would read the previous value of final_result).
    btn.click(
        process_image,
        [preview, raw_state],
        [preview, l1, l2, l3, l4, l5, ln, snr_l1_l4, mscore, final_result],
    ).then(update_final_result, final_result, final_result)

demo.launch()