File size: 4,569 Bytes
db1f63e
 
9f50f1e
 
 
 
 
94b9de0
db1f63e
 
 
 
9f50f1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94b9de0
9f50f1e
 
94b9de0
 
 
 
9f50f1e
 
90fdaaf
 
9f50f1e
 
 
 
db1f63e
9f50f1e
 
 
727a4ff
9f50f1e
 
90fdaaf
 
 
 
 
 
 
 
 
9f50f1e
 
 
 
db1f63e
9f50f1e
 
8c5af3c
94b9de0
 
 
 
 
 
 
 
 
 
9f50f1e
94b9de0
8c5af3c
94b9de0
9f50f1e
94b9de0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f50f1e
8c5af3c
 
 
 
 
 
d414beb
 
 
8c5af3c
 
 
9f50f1e
6c46732
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os, json, time
import numpy as np, cv2, torch, gradio as gr
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer, ColorMode

# --- 1. CONFIGURATION & MODEL LOADING ---
LOAD_DIR = "./artifacts"
WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")
CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")
CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")

# Start from detectron2's defaults, merge the exported training config, then
# override the inference-time settings.
cfg = get_cfg()
cfg.merge_from_file(CFG_PATH)
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
cfg.MODEL.WEIGHTS = WEIGHTS
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # minimum score for a detection to be kept

# Optional class-name list. When present and non-empty it also fixes the number
# of classes the ROI heads are built with.
classes = None
if os.path.exists(CLASSES_PATH):
    # Explicit UTF-8 so class names decode identically on every platform.
    with open(CLASSES_PATH, encoding="utf-8") as f:
        classes = json.load(f)
    # An empty list would build a model with zero classes; keep the value from
    # config.yaml instead (consistent with the name fallback below).
    if classes:
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)

predictor = DefaultPredictor(cfg)
# Metadata consumed by the Visualizer and used to map predicted class indices
# to labels; generic "class_<i>" names are used when no class list was loaded.
meta = MetadataCatalog.get("inference_only")
meta.thing_classes = classes if classes else [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]

# Images whose longer side exceeds this many pixels are downscaled before inference.
MAX_SIDE = 1600

# --- 2. INFERENCE FUNCTION ---
def _ensure_rgb3(img: np.ndarray) -> np.ndarray:
    """Return *img* as an H x W x 3 array.

    2-D grayscale and single-channel arrays are replicated to three channels;
    RGBA arrays have their alpha channel dropped. Anything else is returned
    unchanged. The BGR flip and the Visualizer in ``segment`` both index a
    third (channel) axis, so they need exactly three channels.
    """
    if img.ndim == 2:
        return np.stack([img, img, img], axis=-1)
    channels = img.shape[2]
    if channels == 1:
        return np.repeat(img, 3, axis=2)
    if channels == 4:
        # Copy so OpenCV receives a contiguous buffer rather than a strided view.
        return np.ascontiguousarray(img[:, :, :3])
    return img


def segment(rgb: np.ndarray):
    """Run instance segmentation on one image.

    The image is downscaled (aspect ratio preserved) so that its longer side is
    at most ``MAX_SIDE`` pixels before inference. Detection boxes are mapped
    back to the pixel coordinates of the original image.

    Args:
        rgb: Image array in RGB channel order, as delivered by
            ``gr.Image(type="numpy")``. Grayscale and RGBA arrays are converted
            to 3-channel RGB first. ``None`` when no image was provided.

    Returns:
        ``(overlay, info)``.

        ``overlay`` is the predictions drawn by detectron2's ``Visualizer``
        (``ColorMode.IMAGE_BW``) on the possibly downscaled image.

        ``info`` is a JSON-serialisable dict with these keys:

        - ``instances``: a list of ``{"box": [x0, y0, x1, y1], "class": label,
          "score": float | None}``, with boxes in original-image pixels.
        - ``original_size``: ``[height, width]``.
        - ``latency_ms``: elapsed time in milliseconds.

        For ``rgb is None`` the result is ``(None, {"error": "No image uploaded"})``.
    """
    t0 = time.perf_counter()
    # The button can be clicked before any image has been loaded.
    if rgb is None:
        return None, {"error": "No image uploaded"}

    rgb = _ensure_rgb3(rgb)
    h0, w0 = rgb.shape[:2]

    if max(h0, w0) > MAX_SIDE:
        scale = MAX_SIDE / max(h0, w0)
        # cv2.resize takes (width, height); clamp to 1 px so an extremely
        # elongated image cannot collapse an axis to zero.
        new_size = (max(1, int(w0 * scale)), max(1, int(h0 * scale)))
        rgb_small = cv2.resize(rgb, new_size, interpolation=cv2.INTER_AREA)
    else:
        rgb_small = rgb

    # Per-axis factors back to original pixels. int() truncation above means the
    # effective x and y scales differ slightly from `scale`.
    h1, w1 = rgb_small.shape[:2]
    sx, sy = w0 / w1, h0 / h1

    outputs = predictor(rgb_small[:, :, ::-1])  # predictor expects BGR
    inst = outputs["instances"].to("cpu")

    vis = Visualizer(rgb_small, metadata=meta, scale=1.0, instance_mode=ColorMode.IMAGE_BW)
    overlay_rgb = vis.draw_instance_predictions(inst).get_image()

    dets = []
    if inst.has("pred_boxes"):
        boxes = inst.pred_boxes.tensor.numpy().tolist()  # rows are (x0, y0, x1, y1)
        n = len(boxes)
        scores = inst.scores.numpy().tolist() if inst.has("scores") else [None] * n
        class_ids = inst.pred_classes.numpy().tolist() if inst.has("pred_classes") else [0] * n
        names = meta.thing_classes
        for (x0, y0, x1, y1), s, c in zip(boxes, scores, class_ids):
            dets.append({
                "box": [float(x0 * sx), float(y0 * sy), float(x1 * sx), float(y1 * sy)],
                "class": names[c] if 0 <= c < len(names) else str(c),
                # float(None) would raise when the model produced no scores.
                "score": None if s is None else float(s),
            })

    return overlay_rgb, {
        "instances": dets,
        "original_size": [int(h0), int(w0)],
        "latency_ms": int((time.perf_counter() - t0) * 1000),
    }

# --- 3. GRADIO INTERFACE ---

# Candidate example radiographs. Only files that exist on disk are offered as
# clickable thumbnails; each inner list holds the value for the single
# Examples input (the image).
_EXAMPLE_PATHS = [
    "examples/1.jpg",
    "examples/2.jpg",
    "examples/3.jpg",
    "examples/4.jpg",
    "examples/5.jpg",
]
example_files = [[path] for path in _EXAMPLE_PATHS if os.path.isfile(path)]

with gr.Blocks(title="Panoramic Radiograph Segmentation") as demo:
    gr.Markdown("## Dental X-Ray Segmentation App")
    gr.Markdown("Upload a panoramic radiograph (or click an example below) to detect teeth.")

    with gr.Row():
        # --- Left Column: Input ---
        with gr.Column():
            img_in = gr.Image(type="numpy", label="Input Radiograph")

            # Thumbnail row; omitted entirely when no example files are available.
            if example_files:
                gr.Examples(
                    examples=example_files,
                    inputs=img_in,
                    label="Click an example to load it:",
                )

            submit_btn = gr.Button("Run Segmentation", variant="primary")

        # --- Right Column: Output ---
        with gr.Column():
            img_out = gr.Image(label="Overlay Result")
            json_out = gr.JSON(label="Detections Data")

    # Wire the button to the inference function. Gradio's api_name is given
    # without a leading slash; API clients call the endpoint as "/predict".
    submit_btn.click(fn=segment, inputs=img_in, outputs=[img_out, json_out], api_name="predict")

    # --- CITATIONS SECTION ---
    gr.Markdown("---")
    gr.Markdown(
        """
        ### Credits & Citations
        * **Brahmi, W., & Jdey, I. (2024).** Automatic tooth instance segmentation and identification from panoramic X-Ray images using deep CNN. *Multimedia Tools and Applications, 83*(18), 55565–55585.
        * **Brahmi, W., Jdey, I., & Drira, F. (2024).** Exploring the role of Convolutional Neural Networks (CNN) in dental radiography segmentation: A comprehensive Systematic Literature Review. *Engineering Applications of Artificial Intelligence, 133*, 108510.
        * **[Panoramic Dental X-rays (Mendeley Data)](https://data.mendeley.com/datasets/73n3kz2k4k/3)**
        """
    )

if __name__ == "__main__":
    # "0.0.0.0" listens on every network interface, not only localhost.
    host, port = "0.0.0.0", 7860
    demo.launch(server_name=host, server_port=port)