|
|
import os, json, time |
|
|
import numpy as np, cv2, torch, gradio as gr |
|
|
from detectron2.config import get_cfg |
|
|
from detectron2.engine import DefaultPredictor |
|
|
from detectron2.data import MetadataCatalog |
|
|
from detectron2.utils.visualizer import Visualizer, ColorMode |
|
|
|
|
|
|
|
|
# Directory holding the exported model artifacts. The path is relative, so it
# is resolved against the process's current working directory.
LOAD_DIR = "./artifacts"

WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")      # assigned to cfg.MODEL.WEIGHTS
CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")         # merged into the detectron2 config
CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")    # optional; read only if it exists
|
|
|
|
|
# Build the inference configuration: start from detectron2's defaults, overlay
# the exported config file, then pin device, weights and score threshold.
cfg = get_cfg()
cfg.merge_from_file(CFG_PATH)
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
cfg.MODEL.WEIGHTS = WEIGHTS
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

# Optional class-name file. Assumes classes.json holds a JSON array of names
# ordered by class index -- TODO confirm against the training export.
classes = None
if os.path.exists(CLASSES_PATH):
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding when class names contain non-ASCII characters.
    with open(CLASSES_PATH, encoding="utf-8") as f:
        classes = json.load(f)
    # Only override the head size when names were actually provided: an empty
    # (or null) file must not force NUM_CLASSES to 0 or crash on len(None).
    if classes:
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)

predictor = DefaultPredictor(cfg)

# Metadata consumed by the visualizer and by label lookup in segment().
# Falls back to generic "class_<i>" names when no class list was loaded.
meta = MetadataCatalog.get("inference_only")
meta.thing_classes = classes if classes else [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]
|
|
|
|
|
# Longest image side (in pixels) accepted as-is; larger uploads are downscaled
# to this size before inference.
MAX_SIDE = 1600
|
|
|
|
|
|
|
|
def segment(rgb: np.ndarray):
    """Run the predictor on an uploaded image and summarize its detections.

    Images whose longer side exceeds ``MAX_SIDE`` are downscaled (aspect ratio
    preserved up to integer truncation) before inference. The overlay is drawn
    at that inference resolution, while the reported boxes are mapped back to
    the coordinate frame of the original upload.

    Args:
        rgb: Image array indexed as (height, width, channel), or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(overlay, info)`` tuple.

        * ``overlay`` -- image produced by the visualizer, or ``None`` when no
          image was supplied.
        * ``info`` -- ``{"error": "No image uploaded"}`` when ``rgb`` is
          ``None``; otherwise a dict with ``"instances"`` (one dict per
          detection with ``"box"``, ``"class"`` and ``"score"``),
          ``"original_size"`` (``[height, width]``) and ``"latency_ms"``.
          ``"score"`` is ``None`` when the predictor returned no scores.
    """
    # perf_counter is monotonic, so the reported latency cannot go negative or
    # jump if the wall clock is adjusted mid-request.
    t0 = time.perf_counter()

    if rgb is None:
        return None, {"error": "No image uploaded"}

    h0, w0 = rgb.shape[:2]
    longest = max(h0, w0)
    if longest > MAX_SIDE:
        scale = MAX_SIDE / longest
        # Clamp to 1 px: on a very elongated image the short side would
        # otherwise truncate to 0, which cv2.resize rejects.
        new_w = max(1, int(w0 * scale))
        new_h = max(1, int(h0 * scale))
        rgb_small = cv2.resize(rgb, (new_w, new_h), interpolation=cv2.INTER_AREA)
    else:
        new_w, new_h = w0, h0
        rgb_small = rgb

    # The channel axis is reversed before inference -- presumably the model
    # expects BGR input; confirm against INPUT.FORMAT in config.yaml.
    outputs = predictor(rgb_small[:, :, ::-1])
    inst = outputs["instances"].to("cpu")

    vis = Visualizer(rgb_small, metadata=meta, scale=1.0, instance_mode=ColorMode.IMAGE_BW)
    overlay_rgb = vis.draw_instance_predictions(inst).get_image()

    dets = []
    if inst.has("pred_boxes"):
        boxes = inst.pred_boxes.tensor.numpy().tolist()
        scores = inst.scores.numpy().tolist() if inst.has("scores") else [None] * len(boxes)
        classes_idx = inst.pred_classes.numpy().tolist() if inst.has("pred_classes") else [0] * len(boxes)

        # Per-axis factors back to original pixels. Integer truncation (and the
        # 1 px clamp) means the two axes are not scaled by exactly the same
        # ratio, so a single 1/scale would be slightly off. Both are 1.0 when
        # the image was not resized.
        sx = w0 / new_w
        sy = h0 / new_h
        names = meta.thing_classes

        # Rows are unpacked as (x1, y1, x2, y2), detectron2's Boxes layout.
        for (x1, y1, x2, y2), s, c in zip(boxes, scores, classes_idx):
            # Out-of-range class ids fall back to the raw index as text.
            label = names[c] if 0 <= c < len(names) else str(c)
            dets.append({
                "box": [float(x1 * sx), float(y1 * sy), float(x2 * sx), float(y2 * sy)],
                "class": label,
                # float(None) would raise when the scores field is absent.
                "score": None if s is None else float(s),
            })

    return overlay_rgb, {
        "instances": dets,
        "original_size": [int(h0), int(w0)],
        "latency_ms": int((time.perf_counter() - t0) * 1000),
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Example inputs offered in the UI. Each inner list is one example row with a
# single value (an image path) for the single input component. Paths are
# relative to the process's current working directory.
example_files = [
    ["examples/1.jpg"],
    ["examples/2.jpg"],
    ["examples/3.jpg"],
    ["examples/4.jpg"],
    ["examples/5.jpg"],
]
|
|
|
|
|
# UI definition: a two-column layout (input + examples + button, then results)
# followed by a credits section.
with gr.Blocks(title="Panoramic Radiograph Segmentation") as demo:
    gr.Markdown("## Dental X-Ray Segmentation App")
    gr.Markdown("Upload a panoramic radiograph (or click an example below) to detect teeth.")

    with gr.Row():
        # First column: image input, clickable examples and the run button.
        with gr.Column():
            img_in = gr.Image(type="numpy", label="Input Radiograph")

            gr.Examples(
                examples=example_files,
                inputs=img_in,
                label="Click an example to load it:",
            )

            submit_btn = gr.Button("Run Segmentation", variant="primary")

        # Second column: the two outputs returned by segment().
        with gr.Column():
            img_out = gr.Image(label="Overlay Result")
            json_out = gr.JSON(label="Detections Data")

    # NOTE(review): api_name is kept exactly as published ("/predict"). Gradio
    # examples usually pass the bare name and let clients add the leading
    # slash -- confirm the endpoint is reachable by name with the pinned
    # Gradio version before changing it.
    submit_btn.click(fn=segment, inputs=img_in, outputs=[img_out, json_out], api_name="/predict")

    gr.Markdown("---")
    # The citation bullets previously opened a bold span ("**") that was never
    # closed, which breaks Markdown emphasis; the dangling markers are removed.
    gr.Markdown(
        """
        ### Credits & Citations
        Credits & Citations:
        * Brahmi, W., & Jdey, I. (2024). Automatic tooth instance segmentation and identification from panoramic X-Ray images using deep CNN. *Multimedia Tools and Applications, 83*(18), 55565–55585.
        * Brahmi, W., Jdey, I., & Drira, F. (2024). Exploring the role of Convolutional Neural Networks (CNN) in dental radiography segmentation: A comprehensive Systematic Literature Review. *Engineering Applications of Artificial Intelligence, 133*, 108510.
        * [Panoramic Dental X-rays (Mendeley Data)](https://data.mendeley.com/datasets/73n3kz2k4k/3)
        """
    )
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 so the app is reachable from outside
    # the host/container, not only via localhost.
    demo.launch(server_name="0.0.0.0", server_port=7860)