import os
import json
import time

import numpy as np
import cv2
import torch
import gradio as gr
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer, ColorMode

# --- 1. CONFIGURATION & MODEL LOADING ---
# Model artifacts (weights, Detectron2 config, optional class names) are read
# from LOAD_DIR, which is resolved relative to the current working directory.
LOAD_DIR = "./artifacts"
WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")
CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")
CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")

cfg = get_cfg()
cfg.merge_from_file(CFG_PATH)
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
cfg.MODEL.WEIGHTS = WEIGHTS
# Minimum confidence for a detection to be kept at test time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

# Optional class-name list. When present, its length also sets the number of
# classes the ROI heads are built with (presumably a JSON list of strings that
# matches the trained model -- confirm against the training run).
classes = None
if os.path.exists(CLASSES_PATH):
    # JSON is UTF-8 by spec; don't rely on the platform's default encoding.
    with open(CLASSES_PATH, encoding="utf-8") as f:
        classes = json.load(f)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)

predictor = DefaultPredictor(cfg)

# Metadata used by the Visualizer for label names; falls back to generic
# "class_<i>" names when no classes.json is available.
meta = MetadataCatalog.get("inference_only")
meta.thing_classes = classes if classes else [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]

# Upper bound (pixels) on the longer image side used for inference.
MAX_SIDE = 1600
# --- 2. INFERENCE FUNCTION ---
def _ensure_rgb3(image: np.ndarray) -> np.ndarray:
    """Return *image* as a 3-channel array.

    Three-channel input is returned unchanged (same object). Single-channel
    input, shaped (H, W) or (H, W, 1), and two-channel input have their first
    channel replicated three times. Channels beyond the third (e.g. alpha) are
    dropped.
    """
    if image.ndim == 2:
        image = image[:, :, None]
    channels = image.shape[2]
    if channels == 3:
        return image
    if channels < 3:
        return np.repeat(image[:, :, :1], 3, axis=2)
    # Copy so downstream OpenCV calls receive a C-contiguous buffer.
    return np.ascontiguousarray(image[:, :, :3])


def _resize_to_max_side(rgb: np.ndarray, max_side: int):
    """Downscale *rgb* so that its longer side is at most *max_side* pixels.

    Returns ``(image, scale_x, scale_y)``. The scales map coordinates in the
    returned image back to *rgb* (``x_orig = x_small * scale_x``). An image
    that already fits is returned unchanged, with both scales equal to 1.0.
    """
    h0, w0 = rgb.shape[:2]
    longest = max(h0, w0)
    if longest <= max_side:
        return rgb, 1.0, 1.0
    factor = max_side / longest
    # Clamp each side to at least 1 px so extreme aspect ratios can't yield a
    # zero-sized target, which cv2.resize rejects.
    new_w = max(1, int(w0 * factor))
    new_h = max(1, int(h0 * factor))
    small = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
    # int() truncation makes the effective x/y factors differ slightly from
    # `factor`, so derive the back-projection ratios from the actual sizes.
    return small, w0 / new_w, h0 / new_h


def _instances_to_detections(inst, scale_x: float, scale_y: float, thing_classes) -> list:
    """Convert Detectron2 ``Instances`` into JSON-serialisable detection dicts.

    Boxes are read as ``[x0, y0, x1, y1]`` and rescaled per axis by
    *scale_x* / *scale_y*. Class ids are mapped through *thing_classes*; an id
    outside that list is rendered as its string form. ``score`` is ``None``
    when the instances carry no scores.
    """
    if not inst.has("pred_boxes"):
        return []
    boxes = inst.pred_boxes.tensor.numpy().tolist()
    count = len(boxes)
    scores = inst.scores.numpy().tolist() if inst.has("scores") else [None] * count
    class_ids = inst.pred_classes.numpy().tolist() if inst.has("pred_classes") else [0] * count

    detections = []
    for (x0, y0, x1, y1), score, class_id in zip(boxes, scores, class_ids):
        label = thing_classes[class_id] if 0 <= class_id < len(thing_classes) else str(class_id)
        detections.append({
            "box": [
                float(x0 * scale_x),
                float(y0 * scale_y),
                float(x1 * scale_x),
                float(y1 * scale_y),
            ],
            "class": label,
            # float(None) would raise; keep a missing score as JSON null.
            "score": None if score is None else float(score),
        })
    return detections


def segment(rgb: np.ndarray):
    """Run instance segmentation on one image and summarise the detections.

    Args:
        rgb: Image array from the Gradio ``Image`` component (``type="numpy"``),
            assumed to be in RGB channel order. Grayscale or 4-channel arrays
            are converted to 3 channels. ``None`` when no image was uploaded.

    Returns:
        A tuple ``(overlay, info)``.

        ``overlay`` is the prediction visualisation drawn on the (possibly
        downscaled) input, or ``None`` when there is no input.

        ``info`` is a JSON-serialisable dict with these keys:

        * ``instances``: list of ``{"box", "class", "score"}`` dicts, with
          boxes as ``[x0, y0, x1, y1]`` in original-image pixels.
        * ``original_size``: ``[height, width]``.
        * ``latency_ms``: elapsed time for the call.

        When there is no input, ``info`` is ``{"error": ...}`` instead.
    """
    t0 = time.perf_counter()
    # Gradio passes None when the button is clicked without an image.
    if rgb is None:
        return None, {"error": "No image uploaded"}

    rgb = _ensure_rgb3(rgb)
    h0, w0 = rgb.shape[:2]
    # Inputs larger than MAX_SIDE are downscaled; boxes are mapped back below.
    rgb_small, scale_x, scale_y = _resize_to_max_side(rgb, MAX_SIDE)

    # Channel flip RGB -> BGR. This assumes the loaded config keeps
    # Detectron2's default cfg.INPUT.FORMAT of "BGR" -- confirm if it changes.
    outputs = predictor(rgb_small[:, :, ::-1])
    inst = outputs["instances"].to("cpu")

    vis = Visualizer(rgb_small, metadata=meta, scale=1.0, instance_mode=ColorMode.IMAGE_BW)
    overlay_rgb = vis.draw_instance_predictions(inst).get_image()

    dets = _instances_to_detections(inst, scale_x, scale_y, meta.thing_classes)
    return overlay_rgb, {
        "instances": dets,
        "original_size": [int(h0), int(w0)],
        "latency_ms": int((time.perf_counter() - t0) * 1000),
    }
# --- 3. GRADIO INTERFACE ---
# Example images offered as clickable thumbnails. Paths are relative to the
# working directory; missing files are skipped so the app can still start.
example_files = [
    [path]
    for path in (
        "examples/1.jpg",
        "examples/2.jpg",
        "examples/3.jpg",
        "examples/4.jpg",
        "examples/5.jpg",
    )
    if os.path.exists(path)
]

# Credits footer. It is built line by line so the Markdown does not depend on
# how leading indentation inside a triple-quoted string is treated.
CITATIONS_MD = "\n".join([
    "### Credits & Citations",
    "",
    "* **Brahmi, W., & Jdey, I. (2024).** Automatic tooth instance segmentation and identification "
    "from panoramic X-Ray images using deep CNN. *Multimedia Tools and Applications, 83*(18), 55565–55585.",
    "* **Brahmi, W., Jdey, I., & Drira, F. (2024).** Exploring the role of Convolutional Neural Networks "
    "(CNN) in dental radiography segmentation: A comprehensive Systematic Literature Review. "
    "*Engineering Applications of Artificial Intelligence, 133*, 108510.",
    "* **[Panoramic Dental X-rays (Mendeley Data)](https://data.mendeley.com/datasets/73n3kz2k4k/3)**",
])

with gr.Blocks(title="Panoramic Radiograph Segmentation") as demo:
    gr.Markdown("## Dental X-Ray Segmentation App")
    gr.Markdown("Upload a panoramic radiograph (or click an example below) to detect teeth.")

    with gr.Row():
        # --- Left Column: Input ---
        with gr.Column():
            img_in = gr.Image(type="numpy", label="Input Radiograph")
            # Thumbnail row. No fn is attached, so clicking an example only
            # loads it into img_in.
            if example_files:
                gr.Examples(
                    examples=example_files,
                    inputs=img_in,
                    label="Click an example to load it:",
                )
            submit_btn = gr.Button("Run Segmentation", variant="primary")

        # --- Right Column: Output ---
        with gr.Column():
            img_out = gr.Image(label="Overlay Result")
            json_out = gr.JSON(label="Detections Data")

    # Wire the button to the inference function. api_name is written without a
    # leading slash (Gradio prefixes it); clients call the "/predict" endpoint.
    submit_btn.click(fn=segment, inputs=img_in, outputs=[img_out, json_out], api_name="predict")

    # --- CITATIONS SECTION ---
    gr.Markdown("---")
    gr.Markdown(CITATIONS_MD)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)