# bone_age / app.py
# Last commit by Medvira: "Update app.py" (17d78a7, verified)
# Beautiful Bone Age App (Gradio 4.x)
import os, random
import numpy as np
import torch
import onnxruntime as ort
from PIL import Image, ImageOps
import gradio as gr
import gdown
# -------------------- Config --------------------
# Resolve resources relative to this script, not the process working directory,
# so the model and the shipped samples are found however the app is launched.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_ID = "18HYScsRJuRmfzL0E0BW35uaA542Vd5M5"  # Google Drive file id
MODEL_PATH = os.path.join(BASE_DIR, "bone_age_model.onnx")
SAMPLES_DIR = os.path.join(BASE_DIR, "samples")
# ImageNet per-channel mean/std, shaped (3, 1, 1) so they broadcast over a CHW array.
MEANS = np.array([0.485, 0.456, 0.406], dtype=np.float32)[:, None, None]
STDS = np.array([0.229, 0.224, 0.225], dtype=np.float32)[:, None, None]
# Soft Gradio theme: sky accent colour, slate neutrals, Inter web font with a
# system sans-serif fallback.
THEME = gr.themes.Soft(
    primary_hue="sky",
    neutral_hue="slate",
    font=[
        gr.themes.GoogleFont("Inter"),
        "ui-sans-serif",
    ],
)
# Page-level CSS: hero header typography, bordered "card" groups, metric labels
# and a slightly faded footer.
CSS = """
#hero h1 { font-weight:800; letter-spacing:-0.02em; }
#hero p { color: var(--body-text-color-subdued); }
.card { border:1px solid var(--border-color-primary); border-radius:1.25rem; padding:1rem; background:var(--panel-background-fill); }
.metric .label { color: var(--body-text-color-subdued); font-size:0.85rem; }
footer { opacity:0.85 }
"""
# -------------------- Model init --------------------
if not os.path.exists(MODEL_PATH):
    # Download via the Drive file id (more robust than a "view" URL). Write to a
    # temporary ".part" file first so an interrupted download never leaves a
    # truncated model at MODEL_PATH that the exists() check above would accept.
    _partial_path = MODEL_PATH + ".part"
    _downloaded = gdown.download(id=MODEL_ID, output=_partial_path, quiet=False)
    if not _downloaded or not os.path.exists(_partial_path):
        raise RuntimeError(
            f"Could not download the ONNX model (Google Drive id {MODEL_ID}) to {MODEL_PATH}"
        )
    os.replace(_partial_path, MODEL_PATH)
# Prefer CUDA, but request only providers this onnxruntime build actually offers.
_available_providers = set(ort.get_available_providers())
providers = [
    p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in _available_providers
] or ["CPUExecutionProvider"]
try:
    session = ort.InferenceSession(MODEL_PATH, providers=providers)
except Exception:
    # e.g. CUDA is listed but its runtime libraries fail to load. Fall back to CPU
    # explicitly: since onnxruntime 1.9, GPU builds reject a session created
    # without a `providers` argument.
    session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
# -------------------- Data helpers --------------------
def list_samples(samples_dir=None) -> list:
    """Return the sorted names of the ``.pth`` sample files, without the extension.

    Args:
        samples_dir: Directory to scan; defaults to ``SAMPLES_DIR`` when ``None``.

    Returns:
        Sorted list of sample names. When the directory does not exist, returns
        an empty list so the app can still start (with an empty dropdown)
        instead of crashing at import time.
    """
    directory = SAMPLES_DIR if samples_dir is None else samples_dir
    if not os.path.isdir(directory):
        return []
    return sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(directory)
        # Only regular files: a directory that happens to end in ".pth" is not a sample.
        if entry.endswith(".pth") and os.path.isfile(os.path.join(directory, entry))
    )
# Sample names discovered once at import time; used as the dropdown choices and
# by random_sample(). Samples added later are not picked up until a restart.
SAMPLE_NAMES = list_samples()
def load_sample(name: str):
    """Load the object saved in ``SAMPLES_DIR/<name>.pth`` onto the CPU.

    Args:
        name: Sample name as returned by ``list_samples()`` (file stem, no ``.pth``).

    Returns:
        The object stored with ``torch.save``. ``predict`` reads its ``"boneage"``
        and ``"path"`` entries.

    Raises:
        ValueError: If ``name`` is empty or contains a path component, which would
            otherwise let a caller load files outside ``SAMPLES_DIR``.
    """
    # The name arrives from the UI / HTTP API, so accept only a bare file stem
    # before joining it onto SAMPLES_DIR (blocks "../" and absolute-path traversal).
    if not name or os.path.basename(name) != name or "\\" in name:
        raise ValueError(f"Invalid sample name: {name!r}")
    path = os.path.join(SAMPLES_DIR, f"{name}.pth")
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and can
    # execute code. This is acceptable only because these files ship with the app;
    # never point it at user-supplied files. Consider weights_only=True if the
    # samples turn out to contain only tensors and plain containers.
    return torch.load(path, weights_only=False, map_location="cpu")
def tensor_to_pil(img3chw: np.ndarray) -> Image.Image:
    """Undo ImageNet normalization on a CHW image and return it as a PIL image.

    Args:
        img3chw: Preprocessed image expected to be shaped (3, H, W) and normalized
            with the ImageNet stats (``MEANS`` / ``STDS``), so it is NOT in [0, 1].
            Anything ``np.asarray`` accepts, such as a CPU torch tensor, also works.

    Returns:
        An 8-bit PIL image of size W x H. Pixel values are clipped to [0, 1]
        after denormalization and then scaled to 0..255.
    """
    x = np.asarray(img3chw, dtype=np.float32)
    x = x * STDS + MEANS  # denormalize, still CHW
    x = np.clip(x, 0.0, 1.0)  # guard against values slightly outside the valid range
    x = (x.transpose(1, 2, 0) * 255.0).round().astype(np.uint8)  # CHW -> HWC, 8-bit
    return Image.fromarray(x)
# -------------------- Inference --------------------
def predict(sample_name: str):
    """Run the ONNX model on a stored sample and format the results for the UI.

    Args:
        sample_name: Name of a sample in ``SAMPLES_DIR`` (see ``list_samples``).

    Returns:
        A 5-tuple ``(image, true_age, pred_age, abs_error, status_markdown)``.
        Ages and error are in months, rounded to 0.1. On empty input, or if
        anything fails, the first four items are ``None`` and the status string
        explains why.
    """
    if not sample_name:
        return None, None, None, None, "Select a sample to run prediction."
    try:
        sample = load_sample(sample_name)
        true_age = float(sample["boneage"].item())  # months

        # The stored tensor is expected to be preprocessed NCHW ([1, 3, H, W]); an
        # unbatched CHW tensor is also accepted. ONNX Runtime wants float32 numpy,
        # and .detach() avoids the error .numpy() raises on tensors requiring grad.
        x = sample["path"].detach().cpu().numpy().astype(np.float32, copy=False)
        if x.ndim == 3:
            x = x[None, ...]
        # Use the graph's declared input name instead of hard-coding it.
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: x})

        # Linear mapping of the first model output back to months. The constants
        # presumably are the training set's bone-age std/mean; TODO confirm.
        pred_age = float(np.asarray(outputs[0]).reshape(-1)[0] * 41.172 + 127.329)

        # Visual: denormalized first image of the batch, fitted into 420x420.
        img = tensor_to_pil(x[0])
        img = ImageOps.contain(img, (420, 420), method=Image.LANCZOS)

        abs_err = abs(pred_age - true_age)
        note = "✅ Prediction complete."
        return img, round(true_age, 1), round(pred_age, 1), round(abs_err, 1), note
    except Exception as e:  # UI boundary: report the failure instead of crashing the app
        return None, None, None, None, f"❌ Error: {e}"
def random_sample():
    """Pick one sample name at random, or return ``None`` if no samples were found."""
    if not SAMPLE_NAMES:
        return None
    return random.choice(SAMPLE_NAMES)
# -------------------- UI --------------------
def _delta_text(true_months, pred_months):
    """Describe the signed prediction error, e.g. "Over by 3.2 months".

    Returns ``None`` (an empty label) while either value is missing, for example
    after "Clear" or when a prediction failed.
    """
    if true_months is None or pred_months is None:
        return None
    diff = pred_months - true_months
    return f"{'Over' if diff >= 0 else 'Under'} by {abs(diff):.1f} months"


with gr.Blocks(theme=THEME, css=CSS) as demo:
    with gr.Column(elem_id="hero"):
        gr.Markdown(
            """
# Bone Age Prediction
A simple, elegant demo that loads a **pretrained ONNX model** and predicts bone age from a sample.
""".strip()
        )
    with gr.Row():
        with gr.Column(scale=5):
            with gr.Group(elem_classes=["card"]):
                sample_dd = gr.Dropdown(
                    choices=SAMPLE_NAMES, label="Sample", value=SAMPLE_NAMES[0] if SAMPLE_NAMES else None,
                    allow_custom_value=False, filterable=True
                )
                with gr.Row():
                    run_btn = gr.Button("🔮 Predict", variant="primary")
                    rand_btn = gr.Button("🎲 Random sample")
                    clear_btn = gr.Button("🧹 Clear")
                status = gr.Markdown("", elem_id="status")
        with gr.Column(scale=4):
            with gr.Group(elem_classes=["card"]):
                gr.Markdown("### Model & Notes")
                gr.Markdown(
                    """
- Backend: **ONNX Runtime** (CUDA if available, fallback to CPU)
- Input: Pre-normalized 3×H×W tensor
- Output mapping: `y = 41.172 * onnx_out + 127.329`
- ⚠️ Educational demo — not for clinical use.
""".strip()
                )
    with gr.Group(elem_classes=["card"]):
        gr.Markdown("### Results")
        with gr.Row():
            img_out = gr.Image(label="Input Image", type="pil", height=420, show_download_button=False)
        with gr.Row():
            true_age = gr.Number(label="True Bone Age (months)", interactive=False)
            pred_age = gr.Number(label="Predicted Age (months)", interactive=False)
            abs_err = gr.Number(label="Absolute Error (months)", interactive=False)
        delta_lbl = gr.Label(label="Delta", show_label=True)

    # Events
    result_outputs = [img_out, true_age, pred_age, abs_err, status]
    delta_inputs = [true_age, pred_age]
    # `.input` fires only on user edits. `.change` also fires when rand_btn sets
    # the dropdown programmatically, which made the random flow predict twice.
    sample_dd.input(predict, inputs=sample_dd, outputs=result_outputs).then(
        _delta_text, inputs=delta_inputs, outputs=delta_lbl
    )
    run_btn.click(predict, inputs=sample_dd, outputs=result_outputs).then(
        _delta_text, inputs=delta_inputs, outputs=delta_lbl
    )
    rand_btn.click(random_sample, outputs=sample_dd).then(
        predict, inputs=sample_dd, outputs=result_outputs
    ).then(_delta_text, inputs=delta_inputs, outputs=delta_lbl)
    # Clear every result field, including the Delta label.
    clear_btn.click(
        lambda: (None, None, None, None, "Cleared.", None),
        outputs=result_outputs + [delta_lbl],
    )
    gr.Markdown("---")
    gr.Markdown("Built with ❤️ using Gradio.")
# Launch
# Start the Gradio server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()