# File size: 4,569 Bytes
# 13e00dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os, json, time
import numpy as np, cv2, torch, gradio as gr
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer, ColorMode

# --- 1. CONFIGURATION & MODEL LOADING ---
# Artifacts produced by the training run: final weights, the detectron2
# config used for training, and (optionally) the ordered class-name list.
LOAD_DIR = "./artifacts"
WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")
CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")
CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")

cfg = get_cfg()
cfg.merge_from_file(CFG_PATH)  # restore the training-time configuration
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # CPU fallback for GPU-less hosts
cfg.MODEL.WEIGHTS = WEIGHTS
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # discard detections below 50% confidence

# Class names are optional; when the file exists it also fixes the ROI head
# size.  NOTE(review): assumes classes.json holds a JSON list of names whose
# order matches the training label indices — confirm against the exporter.
classes = None
if os.path.exists(CLASSES_PATH):
    with open(CLASSES_PATH) as f:
        classes = json.load(f)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)

predictor = DefaultPredictor(cfg)
# Ad-hoc metadata entry so the Visualizer can print human-readable labels;
# falls back to synthetic "class_i" names when classes.json is absent.
meta = MetadataCatalog.get("inference_only")
meta.thing_classes = classes if classes else [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]

# Images whose longest side exceeds this are downscaled before inference.
MAX_SIDE = 1600

# --- 2. INFERENCE FUNCTION ---
def segment(rgb: np.ndarray):
    t0 = time.time()
    # Handle potential None input if user clicks run without image
    if rgb is None:
        return None, {"error": "No image uploaded"}
        
    h0, w0 = rgb.shape[:2]
    scale = 1.0
    if max(h0, w0) > MAX_SIDE:
        scale = MAX_SIDE / max(h0, w0)
        rgb_small = cv2.resize(rgb, (int(w0*scale), int(h0*scale)), interpolation=cv2.INTER_AREA)
    else:
        rgb_small = rgb

    outputs = predictor(rgb_small[:, :, ::-1])  # predictor expects BGR
    inst = outputs["instances"].to("cpu")

    vis = Visualizer(rgb_small, metadata=meta, scale=1.0, instance_mode=ColorMode.IMAGE_BW)
    overlay_rgb = vis.draw_instance_predictions(inst).get_image()

    dets = []
    if inst.has("pred_boxes"):
        boxes = inst.pred_boxes.tensor.numpy().tolist()
        scores = inst.scores.numpy().tolist() if inst.has("scores") else [None]*len(boxes)
        classes_idx = inst.pred_classes.numpy().tolist() if inst.has("pred_classes") else [0]*len(boxes)
        inv = (1.0/scale) if scale != 1.0 else 1.0
        for b, s, c in zip(boxes, scores, classes_idx):
            b = [float(x*inv) for x in b]
            label = meta.thing_classes[c] if 0 <= c < len(meta.thing_classes) else str(c)
            dets.append({"box": b, "class": label, "score": float(s)})

    return overlay_rgb, {
        "instances": dets,
        "original_size": [int(h0), int(w0)],
        "latency_ms": int((time.time()-t0)*1000),
    }

# --- 3. GRADIO INTERFACE ---

# Sample radiographs shown as clickable thumbnails; one path per example row.
example_files = [[f"examples/{idx}.jpg"] for idx in range(1, 6)]

# Build the two-column UI: input + examples on the left, overlay + JSON on
# the right.  The `demo` object is launched by the __main__ guard below.
with gr.Blocks(title="Panoramic Radiograph Segmentation") as demo:
    gr.Markdown("## Dental X-Ray Segmentation App")
    gr.Markdown("Upload a panoramic radiograph (or click an example below) to detect teeth.")

    with gr.Row():
        # --- Left Column: Input ---
        with gr.Column():
            img_in = gr.Image(type="numpy", label="Input Radiograph")

            # Clickable thumbnail row that loads a sample into img_in.
            gr.Examples(
                examples=example_files,
                inputs=img_in,
                label="Click an example to load it:"
            )

            submit_btn = gr.Button("Run Segmentation", variant="primary")

        # --- Right Column: Output ---
        with gr.Column():
            img_out = gr.Image(label="Overlay Result")
            json_out = gr.JSON(label="Detections Data")

    # Wire the button to the inference function; api_name exposes the
    # endpoint to API clients.
    submit_btn.click(fn=segment, inputs=img_in, outputs=[img_out, json_out], api_name="/predict")

    # --- CITATIONS SECTION ---
    gr.Markdown("---")
    # BUG FIX: the citation list had three unclosed "**" bold markers and a
    # duplicated "Credits & Citations" heading line, which rendered garbled.
    gr.Markdown(
        """
        ### Credits & Citations
        * Brahmi, W., & Jdey, I. (2024). Automatic tooth instance segmentation and identification from panoramic X-Ray images using deep CNN. *Multimedia Tools and Applications, 83*(18), 55565–55585.
        * Brahmi, W., Jdey, I., & Drira, F. (2024). Exploring the role of Convolutional Neural Networks (CNN) in dental radiography segmentation: A comprehensive Systematic Literature Review. *Engineering Applications of Artificial Intelligence, 133*, 108510.
        * [Panoramic Dental X-rays (Mendeley Data)](https://data.mendeley.com/datasets/73n3kz2k4k/3)
        """
    )

if __name__ == "__main__":
    # 7860 is Gradio's default port; binding 0.0.0.0 makes the app reachable
    # from outside a container (e.g. Docker / Hugging Face Spaces).
    demo.launch(server_name="0.0.0.0", server_port=7860)