# --- Hugging Face file-viewer header (page residue, not program code) ---
# MuhammadAdil63's picture
# deploy YOLACT+ fracture detection demo
# fcec417
# Raw
# History Blame Contribute Delete
# 8.28 kB
"""
FracAtlas YOLACT+ Demo
======================
Gradio app for fracture detection and segmentation on X-ray images.
Deployed on Hugging Face Spaces: https://huggingface.co/spaces/MuhammadAdil63/FracAtlas-YOLACT
"""
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model import YOLACTPlus
# ─── Config ───────────────────────────────────────────────────────────────────
# Side length, in pixels, of the square image handed to the network.
IMG_SIZE = 550

# Single-class detector: the only label is "fracture".
CLASS_NAMES = ["fracture"]
NUM_CLASSES = len(CLASS_NAMES)

# Detection thresholds (confidence score and NMS IoU).
SCORE_THRESH = 0.4
NMS_THRESH = 0.4

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# HuggingFace repo where best.pth is hosted (change if the model is hosted
# separately), and the checkpoint file name inside it.
HF_REPO_ID = "MuhammadAdil63/FracAtlas-YOLACT"
CKPT_FILE = "best.pth"

# Overlay colours as RGB triples (red); masks and boxes use the same shade.
MASK_COLOR = (220, 50, 50)
BOX_COLOR = (220, 50, 50)
# ─── Load model (cached after first load) ────────────────────────────────────
def load_model():
    """Build the YOLACT+ network and load its weights.

    ``CKPT_FILE`` is used directly when it exists as a file relative to the
    working directory; otherwise it is fetched from ``HF_REPO_ID`` on the
    Hugging Face Hub. The weights are mapped onto ``DEVICE``.

    Returns:
        The ``YOLACTPlus`` instance, moved to ``DEVICE`` and put in eval mode.
    """
    print(f"[INFO] Loading model on {DEVICE} ...")

    # Prefer a checkpoint that is already on disk; download only as a fallback.
    if not os.path.isfile(CKPT_FILE):
        print("[INFO] Downloading checkpoint from HF Hub ...")
        weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename=CKPT_FILE)
    else:
        weights_path = CKPT_FILE
        print(f"[INFO] Using local checkpoint: {weights_path}")

    net = YOLACTPlus(num_classes=NUM_CLASSES, img_size=IMG_SIZE, pretrained=False)

    # NOTE(review): torch.load unpickles the file; only point HF_REPO_ID /
    # CKPT_FILE at checkpoints you trust (consider weights_only=True once the
    # checkpoint contents are confirmed to be plain tensors).
    payload = torch.load(weights_path, map_location=DEVICE)

    # Use the "model" entry when the checkpoint has one; otherwise treat the
    # loaded object itself as the state dict.
    state_dict = payload.get("model", payload)
    net.load_state_dict(state_dict)

    net.to(DEVICE).eval()
    print("[INFO] Model loaded successfully!")
    return net
# Load once at startup: load_model() runs when this module is imported, and the
# resulting network is the one predict() calls for every request.
MODEL = load_model()
# ─── Preprocessing ────────────────────────────────────────────────────────────
def preprocess(image_rgb: np.ndarray):
    """Convert an RGB image array into a single-image batch tensor on ``DEVICE``.

    The image goes through, in order: ``LongestMaxSize`` (longest side set to
    ``IMG_SIZE``), ``PadIfNeeded`` (padded up to ``IMG_SIZE`` x ``IMG_SIZE``
    with fill value 0), ``Normalize`` (per-channel mean/std, the values
    commonly used with ImageNet-pretrained backbones) and ``ToTensorV2``.

    Args:
        image_rgb: image as a NumPy array, as passed in by ``predict``.

    Returns:
        The transformed image with a leading batch dimension, on ``DEVICE``.
    """
    steps = [
        A.LongestMaxSize(max_size=IMG_SIZE),
        A.PadIfNeeded(min_height=IMG_SIZE, min_width=IMG_SIZE, fill=0),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    pipeline = A.Compose(steps)

    image_tensor = pipeline(image=image_rgb)["image"]
    batch = image_tensor.unsqueeze(0)  # add the batch dimension
    return batch.to(DEVICE)
# ─── Draw predictions ─────────────────────────────────────────────────────────
def _letterbox_geometry(height: int, width: int, size: int) -> tuple:
    """Return ``(new_h, new_w, pad_top, pad_left)`` of an image inside the square network input.

    Mirrors ``preprocess``: the longest side is scaled to ``size`` and the
    result is padded up to ``size`` x ``size``.
    NOTE(review): relies on PadIfNeeded's default centred padding - confirm if
    the preprocessing pipeline ever sets a different ``position``.
    """
    scale = size / max(height, width)
    new_h = min(size, max(1, round(height * scale)))
    new_w = min(size, max(1, round(width * scale)))
    return new_h, new_w, (size - new_h) // 2, (size - new_w) // 2


def _unletterbox_mask(mask: np.ndarray, geometry: tuple, size: int,
                      out_h: int, out_w: int) -> np.ndarray:
    """Cut the padding off a network-frame mask and resize it to ``out_h`` x ``out_w``."""
    new_h, new_w, pad_top, pad_left = geometry
    mask_h, mask_w = mask.shape[:2]
    # Scale the crop window in case the mask resolution differs from `size`.
    ratio_y, ratio_x = mask_h / size, mask_w / size
    top = min(mask_h - 1, int(round(pad_top * ratio_y)))
    left = min(mask_w - 1, int(round(pad_left * ratio_x)))
    bottom = min(mask_h, max(top + 1, int(round((pad_top + new_h) * ratio_y))))
    right = min(mask_w, max(left + 1, int(round((pad_left + new_w) * ratio_x))))
    content = np.ascontiguousarray(mask[top:bottom, left:right])
    return cv2.resize(content, (out_w, out_h), interpolation=cv2.INTER_LINEAR)


def _unletterbox_box(box, geometry: tuple, size: int, out_h: int, out_w: int) -> tuple:
    """Map a normalised network-frame box to clamped pixel corners ``(x1, y1, x2, y2)``."""
    new_h, new_w, pad_top, pad_left = geometry

    def to_pixel(value, pad, content, limit):
        # network-frame pixels -> remove padding -> rescale to the original image
        pixel = (value * size - pad) * limit / content
        return int(min(max(pixel, 0), limit - 1))

    return (
        to_pixel(box[0], pad_left, new_w, out_w),
        to_pixel(box[1], pad_top, new_h, out_h),
        to_pixel(box[2], pad_left, new_w, out_w),
        to_pixel(box[3], pad_top, new_h, out_h),
    )


def draw_predictions(image_rgb: np.ndarray, result: dict) -> np.ndarray:
    """Overlay predicted masks, boxes and score labels on a copy of the image.

    ``preprocess`` letterboxes the image (aspect-preserving resize plus
    padding), so predictions are mapped back through that same geometry
    instead of being stretched over the original image; for square images the
    mapping is the identity.

    NOTE(review): assumes ``result["boxes"]`` is normalised to, and
    ``result["masks"]`` is rendered in, the padded ``IMG_SIZE`` x ``IMG_SIZE``
    network input - confirm against ``MODEL.predict``.

    Args:
        image_rgb: three-channel image the predictions belong to (not modified).
        result: mapping with ``"boxes"`` ([N, 4], x1 y1 x2 y2), ``"scores"``
            ([N]) and ``"masks"`` ([N, h, w]); entries may live on any device.

    Returns:
        ``uint8`` array with the same height/width as ``image_rgb``. All masks
        are blended first, then boxes and labels are drawn on top.
    """
    height, width = image_rgb.shape[:2]
    boxes, scores, masks = result["boxes"], result["scores"], result["masks"]

    canvas = image_rgb.astype(np.float32)  # astype returns a fresh copy
    tint = np.array(MASK_COLOR, dtype=np.float32)
    alpha = 0.45
    geometry = _letterbox_geometry(height, width, IMG_SIZE)

    # Pass 1: blend every mask in float so the canvas is quantised only once.
    detections = []
    for i in range(len(scores)):
        # detach/cpu so results produced on CUDA can be converted to NumPy
        mask_np = torch.as_tensor(masks[i]).detach().cpu().float().numpy()
        mask_full = _unletterbox_mask(mask_np, geometry, IMG_SIZE, height, width)
        inside = mask_full > 0.5
        canvas[inside] = canvas[inside] * (1 - alpha) + tint * alpha

        box = torch.as_tensor(boxes[i]).detach().cpu().tolist()
        corners = _unletterbox_box(box, geometry, IMG_SIZE, height, width)
        detections.append((corners, float(scores[i])))

    annotated = np.clip(canvas, 0, 255).astype(np.uint8)

    # Pass 2: boxes and labels go on top of the blended masks.
    for (x1, y1, x2, y2), score in detections:
        cv2.rectangle(annotated, (x1, y1), (x2, y2), BOX_COLOR, 2)

        tag = f"fracture: {score:.2f}"
        (tw, th), _ = cv2.getTextSize(tag, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
        # Put the label above the box; if that would leave the canvas, place
        # it just inside the box's top edge instead.
        label_bottom = y1 if y1 - th - 8 >= 0 else min(height - 1, y1 + th + 8)
        label_top = label_bottom - th - 8
        cv2.rectangle(annotated, (x1, label_top), (x1 + tw + 4, label_bottom),
                      BOX_COLOR, -1)
        cv2.putText(annotated, tag, (x1 + 2, label_bottom - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA)

    return annotated
# ─── Main inference function ──────────────────────────────────────────────────
def predict(image: np.ndarray, score_threshold: float, nms_threshold: float):
    """
    Run detection on one image and build the outputs shown in the UI.

    Args:
        image           : RGB numpy array from Gradio (``None`` when no image is set)
        score_threshold : confidence threshold slider value
        nms_threshold   : NMS IoU threshold slider value
    Returns:
        (annotated image, Markdown summary); ``(None, message)`` when there is
        no input image.
    """
    if image is None:
        return None, "No image provided."

    # Work on a copy so the array handed over by the caller is left untouched.
    source = image.copy()
    batch = preprocess(source)

    with torch.no_grad():
        per_image = MODEL.predict(batch, score_threshold, nms_threshold)
    detections = per_image[0]  # one image in the batch -> first entry
    count = len(detections["scores"])

    overlay = draw_predictions(source, detections)

    if count == 0:
        return overlay, "**No fractures detected.**\n\nTry lowering the Score Threshold."

    report = [f"**{count} fracture(s) detected:**\n"]
    for idx in range(count):
        confidence = detections["scores"][idx].item()
        x1, y1, x2, y2 = detections["boxes"][idx].tolist()[:4]
        report.append(
            f"- Detection {idx+1}: confidence **{confidence:.3f}** | "
            f"box `[{x1:.3f}, {y1:.3f}, {x2:.3f}, {y2:.3f}]`"
        )
    return overlay, "\n".join(report)
# ─── Gradio UI ────────────────────────────────────────────────────────────────
# Markdown rendered at the top of the page. The dashes are real em dashes; the
# previous text carried a mis-decoded UTF-8 sequence in their place.
DESCRIPTION = """
## FracAtlas Fracture Detection — YOLACT+ (ResNet-18)
Upload an X-ray image to detect and segment bone fractures.
**Model:** YOLACT+ with ResNet-18 backbone
**Dataset:** [FracAtlas](https://figshare.com/articles/dataset/The_dataset/22363012) — 717 fractured + 3366 non-fractured X-rays
**Training:** 200 epochs | AdamW | Cosine LR decay with warmup
**Val F1:** 0.537 | **Val Avg IoU:** 0.940
"""

# Example inputs, one row per example: [image path, score threshold, NMS threshold].
EXAMPLES = [
    ["examples/fractured_example.jpg", 0.4, 0.4],
    ["examples/nonfractured_example.jpg", 0.4, 0.4],
]
# Page layout: description on top, then one row with an input column and an
# output column, then a footer note.
# NOTE(review): the nesting below is reconstructed from the order of the
# components - confirm it matches the deployed layout.
with gr.Blocks(theme=gr.themes.Soft(), title="FracAtlas YOLACT+") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Input side: image, collapsible threshold sliders, run button.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input X-ray Image", type="numpy", image_mode="RGB")

            with gr.Accordion("Detection Settings", open=False):
                score_thresh = gr.Slider(
                    label="Score Threshold",
                    info="Higher = fewer but more confident detections",
                    minimum=0.1,
                    maximum=0.9,
                    value=0.4,
                    step=0.05,
                )
                nms_thresh = gr.Slider(
                    label="NMS Threshold",
                    info="Lower = suppress more overlapping boxes",
                    minimum=0.1,
                    maximum=0.9,
                    value=0.4,
                    step=0.05,
                )

            run_btn = gr.Button("Detect Fractures", variant="primary")

        # Output side: annotated image and the Markdown summary.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Predicted Segmentation", type="numpy")
            output_text = gr.Markdown(label="Detection Summary")

    # The same handler runs on an explicit button click and whenever the
    # input image changes (e.g. on upload).
    for register in (run_btn.click, input_image.change):
        register(
            fn=predict,
            inputs=[input_image, score_thresh, nms_thresh],
            outputs=[output_image, output_text],
        )

    gr.Markdown("""
---
**Note:** This is a research demo. Not intended for clinical use.
**Author:** Muhammad Adil | MS Data Science, ITU Lahore
**GitHub:** [Adil6312](https://github.com/Adil6312)
""")
# Start the Gradio server only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()