| | |
| | import os, random |
| | import numpy as np |
| | import torch |
| | import onnxruntime as ort |
| | from PIL import Image, ImageOps |
| | import gradio as gr |
| | import gdown |
| |
|
| | |
# Google Drive file id of the exported ONNX model, plus the on-disk locations
# (relative to the current working directory) of the model and sample tensors.
MODEL_ID = "18HYScsRJuRmfzL0E0BW35uaA542Vd5M5"
_WORK_DIR = os.getcwd()
MODEL_PATH = os.path.join(_WORK_DIR, "bone_age_model.onnx")
SAMPLES_DIR = os.path.join(_WORK_DIR, "samples")
| |
|
# Per-channel normalization statistics, shaped (3, 1, 1) so they broadcast
# over a channel-first (C, H, W) array.
_CHANNEL_SHAPE = (3, 1, 1)
MEANS = np.asarray([0.485, 0.456, 0.406], dtype=np.float32).reshape(_CHANNEL_SHAPE)
STDS = np.asarray([0.229, 0.224, 0.225], dtype=np.float32).reshape(_CHANNEL_SHAPE)
| |
|
# Gradio theme for the demo: the "Soft" base theme with sky/slate hues and
# Inter as the preferred font, falling back to the generic UI sans-serif.
_FONT_STACK = [gr.themes.GoogleFont("Inter"), "ui-sans-serif"]
THEME = gr.themes.Soft(
    primary_hue="sky",
    neutral_hue="slate",
    font=_FONT_STACK,
)
| |
|
# Extra stylesheet handed to gr.Blocks(css=...). The selectors target the
# elem_id / elem_classes values assigned to components in the layout
# ("hero", "card").
CSS = """
#hero h1 { font-weight:800; letter-spacing:-0.02em; }
#hero p { color: var(--body-text-color-subdued); }
.card { border:1px solid var(--border-color-primary); border-radius:1.25rem; padding:1rem; background:var(--panel-background-fill); }
.metric .label { color: var(--body-text-color-subdued); font-size:0.85rem; }
footer { opacity:0.85 }
"""
| |
|
| | |
# Fetch the ONNX model with gdown on first start; later starts reuse the file
# that is already on disk.
if not os.path.exists(MODEL_PATH):
    gdown.download(id=MODEL_ID, output=MODEL_PATH, quiet=False)
    # Stop here with a readable message instead of letting the model loader
    # fail later on a file that was never written.
    if not os.path.exists(MODEL_PATH):
        raise RuntimeError(
            f"Model download failed: {MODEL_PATH} was not created "
            f"(download id {MODEL_ID})."
        )
| |
|
# Build the inference session, preferring CUDA and falling back to CPU.
# Only providers that this onnxruntime build actually offers are requested,
# and the CPU provider is always named explicitly.
# NOTE(review): some GPU builds of onnxruntime reject an InferenceSession
# created without a `providers` argument — confirm against the pinned version.
_CPU_ONLY = ["CPUExecutionProvider"]
_available = set(ort.get_available_providers())
providers = [
    name
    for name in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if name in _available
] or list(_CPU_ONLY)

try:
    session = ort.InferenceSession(MODEL_PATH, providers=providers)
except Exception as exc:
    # Retrying only makes sense if something other than plain CPU was tried;
    # otherwise the failure is about the model itself and must surface.
    if providers == _CPU_ONLY:
        raise
    print(f"ONNX Runtime init with {providers} failed ({exc}); retrying on CPU.")
    session = ort.InferenceSession(MODEL_PATH, providers=list(_CPU_ONLY))
| |
|
| | |
def list_samples(samples_dir=None):
    """Return the sorted names (without extension) of the ``.pth`` samples.

    Args:
        samples_dir: Directory to scan. Defaults to the module-level
            ``SAMPLES_DIR`` when omitted.

    Returns:
        A sorted list of file stems for every entry ending in ``.pth``.
        An empty list is returned when the directory does not exist, so a
        missing samples folder no longer crashes the app at import time.
    """
    directory = SAMPLES_DIR if samples_dir is None else samples_dir
    try:
        entries = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        return []
    return sorted(os.path.splitext(entry)[0] for entry in entries if entry.endswith(".pth"))
| |
|
# Sample names are discovered once, when the module is imported; files added
# to the samples directory afterwards are not picked up until a restart.
SAMPLE_NAMES = list_samples()
| |
|
def load_sample(name: str):
    """Load the sample stored as ``<SAMPLES_DIR>/<name>.pth`` onto the CPU.

    Args:
        name: Sample name without extension. Must be a plain file stem;
            anything containing a directory component is rejected so a
            caller-supplied name cannot escape ``SAMPLES_DIR``.

    Returns:
        Whatever object was serialized in the file (``predict`` indexes it
        with the keys ``"boneage"`` and ``"path"``).

    Raises:
        ValueError: If ``name`` is empty or is not a bare file stem.
    """
    if not name or name in (".", "..") or os.path.basename(name) != name:
        raise ValueError(f"Invalid sample name: {name!r}")

    path = os.path.join(SAMPLES_DIR, f"{name}.pth")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only ever point this at trusted, locally
    # shipped sample files.
    return torch.load(path, weights_only=False, map_location="cpu")
| |
|
def tensor_to_pil(img3chw: np.ndarray) -> Image.Image:
    """Convert a normalized channel-first float array into a PIL image.

    The per-channel normalization is undone with the module-level ``STDS``
    and ``MEANS``, the result is clamped to [0, 1], scaled to 8-bit and
    rearranged from CHW to HWC before being passed to ``Image.fromarray``.
    """
    restored = np.clip(img3chw * STDS + MEANS, 0.0, 1.0)
    hwc = np.transpose(restored, (1, 2, 0))
    pixels = np.rint(hwc * 255.0).astype(np.uint8)
    return Image.fromarray(pixels)
| |
|
| | |
# Linear mapping applied to the raw model output to obtain months; the same
# formula is shown to the user in the "Model & Notes" card.
# NOTE(review): presumably the std/mean of the training labels — confirm.
_AGE_SCALE = 41.172
_AGE_OFFSET = 127.329
# Bounding box the preview image is fitted into.
_PREVIEW_SIZE = (420, 420)


def predict(sample_name: str):
    """Run the ONNX model on a stored sample and report the result.

    Args:
        sample_name: Name of the sample to load via ``load_sample``. A falsy
            value short-circuits with a hint instead of running inference.

    Returns:
        A 5-tuple ``(image, true_age, predicted_age, abs_error, status)``.
        The three numbers are rounded to one decimal. On a missing selection
        or on any failure the first four items are ``None`` and ``status``
        carries the message. This function never raises, so the UI always
        receives a full set of outputs.
    """
    if not sample_name:
        return None, None, None, None, "Select a sample to run prediction."

    try:
        sample = load_sample(sample_name)
        true_age = float(sample["boneage"].item())

        # Despite its key name, sample["path"] holds the input tensor itself;
        # it is fed to the model as-is and its first element is previewed.
        batch = sample["path"]
        outputs = session.run(None, {"input": batch.numpy()})
        pred_age = float(outputs[0][0][0] * _AGE_SCALE + _AGE_OFFSET)

        img = tensor_to_pil(batch[0].numpy())
        img = ImageOps.contain(img, _PREVIEW_SIZE, method=Image.LANCZOS)

        abs_err = abs(pred_age - true_age)
        note = "✅ Prediction complete."
        return img, round(true_age, 1), round(pred_age, 1), round(abs_err, 1), note
    except Exception as e:
        # UI boundary: keep the app responsive, but leave a traceback in the
        # server log so failures can still be diagnosed.
        import logging

        logging.getLogger(__name__).exception(
            "Prediction failed for sample %r", sample_name
        )
        return None, None, None, None, f"❌ Error: {e}"
| |
|
def random_sample():
    """Pick one entry of ``SAMPLE_NAMES`` at random, or ``None`` if it is empty."""
    if not SAMPLE_NAMES:
        return None
    return random.choice(SAMPLE_NAMES)
| |
|
| | |
def _predict_with_delta(sample_name):
    """Run ``predict`` and add an over/under summary for the Delta label.

    Returns the five values produced by ``predict`` with the delta text
    inserted before the status message; the delta is ``None`` whenever
    ``predict`` did not produce both ages.
    """
    img, actual, predicted, error, note = predict(sample_name)
    if actual is None or predicted is None:
        delta = None
    else:
        direction = "Over" if predicted >= actual else "Under"
        delta = f"{direction} by {error:.1f} months"
    return img, actual, predicted, error, delta, note


# Page layout: hero header, a controls column, and a column holding the model
# notes and the results card, followed by the event wiring and a footer.
with gr.Blocks(theme=THEME, css=CSS) as demo:
    with gr.Column(elem_id="hero"):
        gr.Markdown(
            """
            # Bone Age Prediction
            A simple, elegant demo that loads a **pretrained ONNX model** and predicts bone age from a sample.
            """.strip()
        )

    with gr.Row():
        with gr.Column(scale=5):
            with gr.Group(elem_classes=["card"]):
                sample_dd = gr.Dropdown(
                    choices=SAMPLE_NAMES,
                    label="Sample",
                    value=SAMPLE_NAMES[0] if SAMPLE_NAMES else None,
                    allow_custom_value=False,
                    filterable=True,
                )
                with gr.Row():
                    run_btn = gr.Button("🔮 Predict", variant="primary")
                    rand_btn = gr.Button("🎲 Random sample")
                    clear_btn = gr.Button("🧹 Clear")

                status = gr.Markdown("", elem_id="status")
        with gr.Column(scale=4):
            with gr.Group(elem_classes=["card"]):
                gr.Markdown("### Model & Notes")
                gr.Markdown(
                    """
                    - Backend: **ONNX Runtime** (CUDA if available, fallback to CPU)
                    - Input: Pre-normalized 3×H×W tensor
                    - Output mapping: `y = 41.172 * onnx_out + 127.329`
                    - ⚠️ Educational demo — not for clinical use.
                    """.strip()
                )
            with gr.Group(elem_classes=["card"]):
                gr.Markdown("### Results")
                with gr.Row():
                    img_out = gr.Image(label="Input Image", type="pil", height=420, show_download_button=False)
                with gr.Row():
                    true_age = gr.Number(label="True Bone Age (months)", interactive=False)
                    pred_age = gr.Number(label="Predicted Age (months)", interactive=False)
                    abs_err = gr.Number(label="Absolute Error (months)", interactive=False)
                delta_lbl = gr.Label(label="Delta", show_label=True)

    # Every handler below fills the same components, in this order. The Delta
    # label is included so it is actually populated.
    result_outputs = [img_out, true_age, pred_age, abs_err, delta_lbl, status]

    # NOTE(review): .change presumably also fires when rand_btn updates the
    # dropdown, so the chained call below may run inference a second time —
    # confirm, and consider the dropdown's .input event if so.
    sample_dd.change(_predict_with_delta, inputs=sample_dd, outputs=result_outputs)
    run_btn.click(_predict_with_delta, inputs=sample_dd, outputs=result_outputs)
    rand_btn.click(random_sample, outputs=sample_dd).then(
        _predict_with_delta, inputs=sample_dd, outputs=result_outputs
    )
    clear_btn.click(
        lambda: (None, None, None, None, None, "Cleared."),
        outputs=result_outputs,
    )

    gr.Markdown("---")
    gr.Markdown("Built with ❤️ using Gradio.")
| |
|
| | |
# Start the Gradio server only when this file is executed as a script, so the
# module can be imported without launching the app.
if __name__ == "__main__":
    demo.launch()
| |
|