"""
FracAtlas YOLACT+ Demo
======================

Gradio app for fracture detection and segmentation on X-ray images.
Deployed on Hugging Face Spaces:
https://huggingface.co/spaces/MuhammadAdil63/FracAtlas-YOLACT
"""

import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
import albumentations as A
from albumentations.pytorch import ToTensorV2

from model import YOLACTPlus

# ─── Config ───────────────────────────────────────────────────────────────────
IMG_SIZE = 550
NUM_CLASSES = 1
CLASS_NAMES = ["fracture"]
SCORE_THRESH = 0.4  # default value of the "Score Threshold" slider
NMS_THRESH = 0.4    # default value of the "NMS Threshold" slider
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# HuggingFace repo where best.pth is hosted
HF_REPO_ID = "MuhammadAdil63/FracAtlas-YOLACT"  # change if model hosted separately
CKPT_FILE = "best.pth"

# Mask overlay / box colour (red in RGB) and mask blending strength
MASK_COLOR = (220, 50, 50)
BOX_COLOR = (220, 50, 50)
MASK_ALPHA = 0.45

# Score-label text style
LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX
LABEL_SCALE = 0.6


# ─── Load model (cached after first load) ────────────────────────────────────
def load_model():
    """Build YOLACT+ and load its weights.

    A local ``CKPT_FILE`` is preferred; otherwise the file is downloaded from
    ``HF_REPO_ID`` on the Hugging Face Hub. The checkpoint may be a raw state
    dict or a dict holding the state dict under the ``"model"`` key.

    Returns:
        The model moved to ``DEVICE`` and switched to eval mode.
    """
    print(f"[INFO] Loading model on {DEVICE} ...")
    # Try local first, then download from HF Hub
    if os.path.isfile(CKPT_FILE):
        ckpt_path = CKPT_FILE
        print(f"[INFO] Using local checkpoint: {ckpt_path}")
    else:
        print("[INFO] Downloading checkpoint from HF Hub ...")
        ckpt_path = hf_hub_download(repo_id=HF_REPO_ID, filename=CKPT_FILE)

    model = YOLACTPlus(num_classes=NUM_CLASSES, img_size=IMG_SIZE, pretrained=False)
    # NOTE(review): on torch versions where `weights_only` defaults to False,
    # torch.load unpickles arbitrary objects — only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(ckpt.get("model", ckpt))
    model.to(DEVICE).eval()
    print("[INFO] Model loaded successfully!")
    return model


# Load once at startup
MODEL = load_model()


# ─── Preprocessing ────────────────────────────────────────────────────────────
# Built once at import: the pipeline is identical for every request.
_TRANSFORM = A.Compose([
    A.LongestMaxSize(max_size=IMG_SIZE),
    A.PadIfNeeded(min_height=IMG_SIZE, min_width=IMG_SIZE, fill=0),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])


def preprocess(image_rgb: np.ndarray) -> torch.Tensor:
    """Letterbox ``image_rgb`` onto an IMG_SIZE x IMG_SIZE canvas and normalise it.

    Args:
        image_rgb: H x W x 3 RGB image.

    Returns:
        A batch-of-one tensor (ToTensorV2 yields channels-first) on ``DEVICE``.
    """
    return _TRANSFORM(image=image_rgb)["image"].unsqueeze(0).to(DEVICE)


def _letterbox_geometry(height: int, width: int):
    """Return ``(pad_top, pad_left, new_h, new_w)`` used by ``_TRANSFORM``.

    Mirrors LongestMaxSize (longest side scaled to IMG_SIZE) followed by
    PadIfNeeded with its default centred position, so predictions made on the
    padded canvas can be mapped back onto the original image.
    """
    scale = IMG_SIZE / max(height, width)
    new_h = max(1, int(round(height * scale)))
    new_w = max(1, int(round(width * scale)))
    return (IMG_SIZE - new_h) // 2, (IMG_SIZE - new_w) // 2, new_h, new_w


def _to_numpy(values) -> np.ndarray:
    """Convert a tensor (possibly on GPU, any dtype) or array-like to float32."""
    if isinstance(values, torch.Tensor):
        return values.detach().float().cpu().numpy()
    return np.asarray(values, dtype=np.float32)


def _boxes_to_image(boxes, height: int, width: int) -> np.ndarray:
    """Map xyxy boxes from the padded canvas to pixel coords of the original image.

    Assumes ``MODEL.predict`` normalises boxes to the IMG_SIZE x IMG_SIZE input
    canvas (TODO confirm against model.py). Output is clipped to the image.
    """
    pad_top, pad_left, new_h, new_w = _letterbox_geometry(height, width)
    out = _to_numpy(boxes).reshape(-1, 4) * IMG_SIZE  # canvas pixels (new array)
    out[:, 0::2] = ((out[:, 0::2] - pad_left) * (width / new_w)).clip(0, width)
    out[:, 1::2] = ((out[:, 1::2] - pad_top) * (height / new_h)).clip(0, height)
    return out


def _mask_to_image(mask, height: int, width: int) -> np.ndarray:
    """Crop the letterbox padding off a canvas-space mask and resize it to the image.

    Works for any mask resolution, since the crop is computed proportionally to
    IMG_SIZE. Returns a float32 ``height`` x ``width`` map.
    """
    pad_top, pad_left, new_h, new_w = _letterbox_geometry(height, width)
    m = _to_numpy(mask)
    m = m.reshape(m.shape[-2:])
    mh, mw = m.shape
    y0 = min(mh - 1, int(round(pad_top * mh / IMG_SIZE)))
    x0 = min(mw - 1, int(round(pad_left * mw / IMG_SIZE)))
    y1 = max(y0 + 1, int(round((pad_top + new_h) * mh / IMG_SIZE)))
    x1 = max(x0 + 1, int(round((pad_left + new_w) * mw / IMG_SIZE)))
    crop = np.ascontiguousarray(m[y0:y1, x0:x1])
    return cv2.resize(crop, (width, height), interpolation=cv2.INTER_LINEAR)


def _ensure_rgb(image: np.ndarray) -> np.ndarray:
    """Return a 3-channel RGB copy of ``image`` (grayscale and RGBA are converted)."""
    if image.ndim == 2 or image.shape[2] == 1:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image.copy()


# ─── Draw predictions ─────────────────────────────────────────────────────────
def _draw_label(canvas: np.ndarray, text: str, x: int, y: int) -> None:
    """Draw ``text`` on a filled BOX_COLOR tag anchored at box corner (x, y), in place.

    The tag normally sits above the box; if that would fall off the top of the
    image, it is drawn just inside the box instead so it stays visible.
    """
    (tw, th), _ = cv2.getTextSize(text, LABEL_FONT, LABEL_SCALE, 1)
    if y - th - 8 >= 0:
        top, bottom, text_y = y - th - 8, y, y - 4
    else:
        top, bottom, text_y = y, y + th + 8, y + th + 4
    cv2.rectangle(canvas, (x, top), (x + tw + 4, bottom), BOX_COLOR, -1)
    cv2.putText(canvas, text, (x + 2, text_y), LABEL_FONT, LABEL_SCALE,
                (255, 255, 255), 1, cv2.LINE_AA)


def draw_predictions(image_rgb: np.ndarray, result: dict) -> np.ndarray:
    """Overlay predicted masks, boxes and score labels on the original image.

    Args:
        image_rgb: The H x W x 3 RGB image that was passed to ``preprocess``.
        result: One entry of ``MODEL.predict`` output with keys ``"boxes"``
            ([N, 4] xyxy, normalised to the padded canvas), ``"scores"`` ([N])
            and ``"masks"`` ([N, h, w] maps over the padded canvas).

    Returns:
        Annotated uint8 RGB image with the same size as ``image_rgb``.
    """
    height, width = image_rgb.shape[:2]
    scores = _to_numpy(result["scores"]).reshape(-1)
    boxes = _boxes_to_image(result["boxes"], height, width)
    masks = result["masks"]
    colour = np.array(MASK_COLOR, dtype=np.float32)

    # Blend every mask first so later overlays never tint earlier boxes/labels.
    vis = image_rgb.astype(np.float32)  # astype copies; the input is untouched
    for i in range(len(scores)):
        hit = _mask_to_image(masks[i], height, width) > 0.5
        vis[hit] = vis[hit] * (1.0 - MASK_ALPHA) + colour * MASK_ALPHA
    out = np.clip(vis, 0, 255).astype(np.uint8)

    # .tolist() gives plain ints, which every OpenCV version accepts as points.
    for score, (x1, y1, x2, y2) in zip(scores, boxes.astype(int).tolist()):
        cv2.rectangle(out, (x1, y1), (x2, y2), BOX_COLOR, 2)
        _draw_label(out, f"{CLASS_NAMES[0]}: {score:.2f}", x1, y1)
    return out


# ─── Main inference function ──────────────────────────────────────────────────
def predict(image: np.ndarray, score_threshold: float, nms_threshold: float):
    """
    Gradio inference function.

    Args:
        image           : RGB numpy array from Gradio (grayscale/RGBA are converted)
        score_threshold : confidence threshold slider
        nms_threshold   : NMS IoU threshold slider

    Returns:
        annotated image, Markdown result summary. Summary boxes are xyxy,
        normalised to the original image size.
    """
    if image is None:
        return None, "No image provided."

    orig_rgb = _ensure_rgb(image)
    tensor = preprocess(orig_rgb)
    with torch.no_grad():
        res = MODEL.predict(tensor, score_threshold, nms_threshold)[0]

    annotated = draw_predictions(orig_rgb, res)

    scores = _to_numpy(res["scores"]).reshape(-1)
    if scores.size == 0:
        return annotated, "**No fractures detected.**\n\nTry lowering the Score Threshold."

    height, width = orig_rgb.shape[:2]
    boxes = _boxes_to_image(res["boxes"], height, width) / [width, height, width, height]
    lines = [f"**{scores.size} fracture(s) detected:**\n"]
    for i, (score, box) in enumerate(zip(scores, boxes), start=1):
        lines.append(
            f"- Detection {i}: confidence **{score:.3f}** | "
            f"box `[{box[0]:.3f}, {box[1]:.3f}, {box[2]:.3f}, {box[3]:.3f}]`"
        )
    return annotated, "\n".join(lines)


# ─── Gradio UI ────────────────────────────────────────────────────────────────
DESCRIPTION = """ ## FracAtlas Fracture Detection — YOLACT+ (ResNet-18) Upload an X-ray image to detect and segment bone fractures. **Model:** YOLACT+ with ResNet-18 backbone **Dataset:** [FracAtlas](https://figshare.com/articles/dataset/The_dataset/22363012) — 717 fractured + 3366 non-fractured X-rays **Training:** 200 epochs | AdamW | Cosine LR decay with warmup **Val F1:** 0.537 | **Val Avg IoU:** 0.940 """

EXAMPLES = [
    ["examples/fractured_example.jpg", SCORE_THRESH, NMS_THRESH],
    ["examples/nonfractured_example.jpg", SCORE_THRESH, NMS_THRESH],
]

with gr.Blocks(theme=gr.themes.Soft(), title="FracAtlas YOLACT+") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input X-ray Image",
                type="numpy",
                image_mode="RGB",
            )
            with gr.Accordion("Detection Settings", open=False):
                score_thresh = gr.Slider(
                    minimum=0.1, maximum=0.9, value=SCORE_THRESH, step=0.05,
                    label="Score Threshold",
                    info="Higher = fewer but more confident detections",
                )
                nms_thresh = gr.Slider(
                    minimum=0.1, maximum=0.9, value=NMS_THRESH, step=0.05,
                    label="NMS Threshold",
                    info="Lower = suppress more overlapping boxes",
                )
            run_btn = gr.Button("Detect Fractures", variant="primary")

            # Only offer example files that actually exist, so a missing file
            # cannot break app startup.
            _available_examples = [ex for ex in EXAMPLES if os.path.isfile(ex[0])]
            if _available_examples:
                gr.Examples(
                    examples=_available_examples,
                    inputs=[input_image, score_thresh, nms_thresh],
                )

        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Predicted Segmentation",
                type="numpy",
            )
            output_text = gr.Markdown(label="Detection Summary")

    run_btn.click(
        fn=predict,
        inputs=[input_image, score_thresh, nms_thresh],
        outputs=[output_image, output_text],
    )
    # Also run on image upload (and when an example is selected)
    input_image.change(
        fn=predict,
        inputs=[input_image, score_thresh, nms_thresh],
        outputs=[output_image, output_text],
    )

    gr.Markdown("""
---
**Note:** This is a research demo. Not intended for clinical use.
**Author:** Muhammad Adil | MS Data Science, ITU Lahore
**GitHub:** [Adil6312](https://github.com/Adil6312)
""")

if __name__ == "__main__":
    demo.launch()