Spaces:
Running
Running
| """ | |
| FracAtlas YOLACT+ Demo | |
| ====================== | |
| Gradio app for fracture detection and segmentation on X-ray images. | |
| Deployed on Hugging Face Spaces: https://huggingface.co/spaces/MuhammadAdil63/FracAtlas-YOLACT | |
| """ | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| from model import YOLACTPlus | |
# ─── Config ───────────────────────────────────────────────────────────────────
# Side length of the square network input; preprocess() scales the longest
# image side to this value and pads the other side up to it.
IMG_SIZE = 550

# Class labels; the class count is derived from the list so the two cannot
# drift apart.
CLASS_NAMES = ["fracture"]
NUM_CLASSES = len(CLASS_NAMES)

# Default detection thresholds (same values as the UI slider defaults).
SCORE_THRESH = NMS_THRESH = 0.4

# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# HuggingFace repo where best.pth is hosted
HF_REPO_ID = "MuhammadAdil63/FracAtlas-YOLACT"  # change if model hosted separately
CKPT_FILE = "best.pth"

# Mask overlay and bounding-box colour (red in RGB); tuples are immutable, so
# sharing one object between the two names is safe.
MASK_COLOR = BOX_COLOR = (220, 50, 50)
# ─── Load model (cached after first load) ─────────────────────────────────────
def load_model():
    """Build the YOLACT+ network and load its checkpoint weights.

    A checkpoint file named CKPT_FILE in the working directory is used when
    present; otherwise the file is fetched from HF_REPO_ID on the Hugging
    Face Hub.

    Returns:
        The YOLACTPlus instance, moved to DEVICE and switched to eval mode.
    """
    print(f"[INFO] Loading model on {DEVICE} ...")

    # Try local first, then download from HF Hub
    if not os.path.isfile(CKPT_FILE):
        print("[INFO] Downloading checkpoint from HF Hub ...")
        weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename=CKPT_FILE)
    else:
        weights_path = CKPT_FILE
        print(f"[INFO] Using local checkpoint: {weights_path}")

    net = YOLACTPlus(num_classes=NUM_CLASSES, img_size=IMG_SIZE, pretrained=False)

    # NOTE(review): torch.load unpickles the file; only point HF_REPO_ID /
    # CKPT_FILE at checkpoints you trust (consider weights_only=True if the
    # checkpoint contains nothing but tensors).
    checkpoint = torch.load(weights_path, map_location=DEVICE)

    # Weights may sit under a "model" key or be the checkpoint mapping itself.
    state_dict = checkpoint.get("model", checkpoint)
    net.load_state_dict(state_dict)
    net.to(DEVICE).eval()

    print("[INFO] Model loaded successfully!")
    return net


# Load once at startup
MODEL = load_model()
# ─── Preprocessing ────────────────────────────────────────────────────────────
def preprocess(image_rgb: np.ndarray):
    """Turn an RGB image array into a batched network input on DEVICE.

    The image is scaled so its longest side equals IMG_SIZE, padded up to
    IMG_SIZE x IMG_SIZE, normalised with the mean/std values below and
    converted to a tensor; a leading batch dimension of size 1 is then added.

    Args:
        image_rgb: RGB image as a numpy array.

    Returns:
        The transformed image tensor with a batch dimension, placed on DEVICE.
    """
    steps = [
        A.LongestMaxSize(max_size=IMG_SIZE),
        A.PadIfNeeded(min_height=IMG_SIZE, min_width=IMG_SIZE, fill=0),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    pipeline = A.Compose(steps)
    single = pipeline(image=image_rgb)["image"]
    batched = single.unsqueeze(0)
    return batched.to(DEVICE)
# ─── Draw predictions ─────────────────────────────────────────────────────────
def _letterbox_region(height: int, width: int):
    """Locate the resized image inside the IMG_SIZE x IMG_SIZE network input.

    Mirrors the geometry of preprocess(): the longest side is scaled to
    IMG_SIZE and the shorter side is padded up to IMG_SIZE.

    Returns:
        (pad_left, pad_top, content_w, content_h) in network-input pixels.
    """
    scale = IMG_SIZE / float(max(height, width, 1))
    content_h = min(IMG_SIZE, max(1, int(height * scale + 0.5)))
    content_w = min(IMG_SIZE, max(1, int(width * scale + 0.5)))
    # NOTE(review): assumes A.PadIfNeeded splits the padding evenly between
    # both sides (its default centred position) — confirm against the
    # installed albumentations version.
    pad_top = (IMG_SIZE - content_h) // 2
    pad_left = (IMG_SIZE - content_w) // 2
    return pad_left, pad_top, content_w, content_h


def _to_source_px(norm_coord: float, pad: int, content: int, full: int) -> int:
    """Map a normalised network-input coordinate to a clamped source pixel."""
    px = (norm_coord * IMG_SIZE - pad) * full / content
    return int(min(max(px, 0), full - 1))


def draw_predictions(image_rgb: np.ndarray, result: dict) -> np.ndarray:
    """Overlay predicted masks, boxes and score labels on the original image.

    Predictions are produced on the letterboxed input built by preprocess(),
    so masks and boxes are mapped back through that padding/scaling before
    drawing; stretching them directly onto the original image misaligns the
    overlay whenever the image is not square.

    Args:
        image_rgb: Original RGB image, H x W x 3.
        result: Mapping with "boxes" ([N, 4]), "scores" ([N]) and "masks"
            ([N, h, w]) tensors. Boxes are assumed to be (x1, y1, x2, y2)
            normalised to the square network input and masks to cover that
            same square — TODO confirm against YOLACTPlus.predict.

    Returns:
        A uint8 copy of the image with the detections drawn on it.
    """
    H, W = image_rgb.shape[:2]
    boxes = result["boxes"]
    scores = result["scores"]
    masks = result["masks"]
    num_det = len(scores)
    pad_left, pad_top, content_w, content_h = _letterbox_region(H, W)

    # Pass 1: blend every mask in float so rounding happens only once and
    # labels drawn later are never tinted by a subsequent mask.
    blended = image_rgb.astype(np.float32)
    tint = np.array(MASK_COLOR, dtype=np.float32)
    for i in range(num_det):
        # detach/cpu so CUDA or grad-tracking tensors can be converted too.
        mask_np = masks[i].detach().cpu().numpy().astype(np.float32)
        mask_h, mask_w = mask_np.shape[:2]

        # Cut away the padded border, expressed in the mask's own resolution.
        col0 = int(round(pad_left * mask_w / IMG_SIZE))
        col1 = int(round((pad_left + content_w) * mask_w / IMG_SIZE))
        row0 = int(round(pad_top * mask_h / IMG_SIZE))
        row1 = int(round((pad_top + content_h) * mask_h / IMG_SIZE))
        crop = mask_np[row0:row1, col0:col1]
        if crop.size == 0:
            continue

        mask_rs = cv2.resize(
            np.ascontiguousarray(crop), (W, H), interpolation=cv2.INTER_LINEAR
        )
        inside = mask_rs > 0.5
        blended[inside] = blended[inside] * (1 - 0.45) + tint * 0.45

    # Pass 2: boxes and labels on the uint8 canvas.
    canvas = np.clip(blended, 0, 255).astype(np.uint8)
    for i in range(num_det):
        bx1, by1, bx2, by2 = boxes[i].tolist()[:4]
        x1 = _to_source_px(bx1, pad_left, content_w, W)
        y1 = _to_source_px(by1, pad_top, content_h, H)
        x2 = _to_source_px(bx2, pad_left, content_w, W)
        y2 = _to_source_px(by2, pad_top, content_h, H)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), BOX_COLOR, 2)

        # Label
        tag = f"fracture: {float(scores[i]):.2f}"
        (tw, th), _ = cv2.getTextSize(tag, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
        label_top = y1 - th - 8
        if label_top < 0:
            # No room above the box: place the label just inside its top edge.
            label_top = y1
        label_bottom = label_top + th + 8
        cv2.rectangle(
            canvas, (x1, label_top), (x1 + tw + 4, label_bottom), BOX_COLOR, -1
        )
        cv2.putText(
            canvas, tag, (x1 + 2, label_bottom - 4),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1, cv2.LINE_AA,
        )

    return canvas
# ─── Main inference function ──────────────────────────────────────────────────
def predict(image: np.ndarray, score_threshold: float, nms_threshold: float):
    """
    Gradio inference callback: detect fractures and describe the result.

    Args:
        image           : RGB numpy array from Gradio (None when no image is set)
        score_threshold : confidence threshold slider
        nms_threshold   : NMS IoU threshold slider
    Returns:
        annotated image, result summary text
    """
    if image is None:
        return None, "No image provided."

    source_rgb = image.copy()
    batch = preprocess(source_rgb)
    with torch.no_grad():
        detections = MODEL.predict(batch, score_threshold, nms_threshold)[0]

    count = len(detections["scores"])

    # Draw
    overlay = draw_predictions(source_rgb, detections)

    # Summary text
    if count == 0:
        return overlay, "**No fractures detected.**\n\nTry lowering the Score Threshold."

    report = [f"**{count} fracture(s) detected:**\n"]
    pairs = zip(detections["scores"], detections["boxes"])
    for rank, (score_t, box_t) in enumerate(pairs, start=1):
        confidence = score_t.item()
        bx1, by1, bx2, by2 = box_t.tolist()[:4]
        report.append(
            f"- Detection {rank}: confidence **{confidence:.3f}** | "
            f"box `[{bx1:.3f}, {by1:.3f}, {bx2:.3f}, {by2:.3f}]`"
        )
    return overlay, "\n".join(report)
# ─── Gradio UI ────────────────────────────────────────────────────────────────
# Markdown shown at the top of the demo page. The dashes below were mis-encoded
# (rendered as "β"); they are restored to em dashes.
DESCRIPTION = """
## FracAtlas Fracture Detection — YOLACT+ (ResNet-18)
Upload an X-ray image to detect and segment bone fractures.
**Model:** YOLACT+ with ResNet-18 backbone
**Dataset:** [FracAtlas](https://figshare.com/articles/dataset/The_dataset/22363012) — 717 fractured + 3366 non-fractured X-rays
**Training:** 200 epochs | AdamW | Cosine LR decay with warmup
**Val F1:** 0.537 | **Val Avg IoU:** 0.940
"""

# Example rows: [image path, score threshold, NMS threshold].
EXAMPLES = [
    ["examples/fractured_example.jpg", 0.4, 0.4],
    ["examples/nonfractured_example.jpg", 0.4, 0.4],
]
# Two-column layout: inputs and settings on the left, results on the right.
# NOTE(review): EXAMPLES is defined in this file but is not attached to the UI
# here (no gr.Examples component) — confirm whether that is intentional.
with gr.Blocks(theme=gr.themes.Soft(), title="FracAtlas YOLACT+") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input X-ray Image",
                type="numpy",
                image_mode="RGB",
            )
            with gr.Accordion("Detection Settings", open=False):
                # Slider defaults come from the module-level config so the
                # UI and the constants cannot drift apart.
                score_thresh = gr.Slider(
                    minimum=0.1, maximum=0.9, value=SCORE_THRESH, step=0.05,
                    label="Score Threshold",
                    info="Higher = fewer but more confident detections",
                )
                nms_thresh = gr.Slider(
                    minimum=0.1, maximum=0.9, value=NMS_THRESH, step=0.05,
                    label="NMS Threshold",
                    info="Lower = suppress more overlapping boxes",
                )
            run_btn = gr.Button("Detect Fractures", variant="primary")

        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Predicted Segmentation",
                type="numpy",
            )
            output_text = gr.Markdown(label="Detection Summary")

    # Both triggers share exactly the same wiring.
    predict_event = dict(
        fn=predict,
        inputs=[input_image, score_thresh, nms_thresh],
        outputs=[output_image, output_text],
    )
    run_btn.click(**predict_event)
    # Also run on image upload; when the image is cleared predict() receives
    # None and returns its "No image provided." message.
    input_image.change(**predict_event)

    gr.Markdown("""
---
**Note:** This is a research demo. Not intended for clinical use.
**Author:** Muhammad Adil | MS Data Science, ITU Lahore
**GitHub:** [Adil6312](https://github.com/Adil6312)
""")


if __name__ == "__main__":
    demo.launch()