# Capstone / app.py
# Hugging Face Space by WSLINMSAI — "Update app.py" (commit d414beb, verified)
import os, json, time
import numpy as np, cv2, torch, gradio as gr
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer, ColorMode
# --- 1. CONFIGURATION & MODEL LOADING ---
LOAD_DIR = "./artifacts"
WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")
CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")
CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")

# Build the Detectron2 config from the exported training config.
cfg = get_cfg()
cfg.merge_from_file(CFG_PATH)
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
cfg.MODEL.WEIGHTS = WEIGHTS
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # keep detections scoring >= 0.5

# Optional class-name list; when present it also fixes the ROI head size.
classes = None
if os.path.exists(CLASSES_PATH):
    with open(CLASSES_PATH) as f:
        classes = json.load(f)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)

predictor = DefaultPredictor(cfg)

# Register label names on a dedicated metadata entry for the visualizer.
meta = MetadataCatalog.get("inference_only")
if classes:
    meta.thing_classes = classes
else:
    meta.thing_classes = [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]

MAX_SIDE = 1600  # longest image side (px) fed to the model
# --- 2. INFERENCE FUNCTION ---
def segment(rgb: np.ndarray):
    """Run instance segmentation on an uploaded radiograph.

    Args:
        rgb: H x W x 3 RGB image from Gradio, or None when the user clicks
            run without uploading an image.

    Returns:
        (overlay, info): the visualized overlay image (RGB, possibly
        downscaled) and a JSON-serializable dict with per-instance boxes
        (mapped back to original-image coordinates), class labels, scores,
        the original size, and the latency in milliseconds. On missing
        input, returns (None, {"error": ...}).
    """
    t0 = time.time()
    # Handle potential None input if user clicks run without image.
    if rgb is None:
        return None, {"error": "No image uploaded"}

    h0, w0 = rgb.shape[:2]
    scale = 1.0
    if max(h0, w0) > MAX_SIDE:
        # Downscale so the longest side is MAX_SIDE to bound inference cost.
        scale = MAX_SIDE / max(h0, w0)
        rgb_small = cv2.resize(rgb, (int(w0 * scale), int(h0 * scale)),
                               interpolation=cv2.INTER_AREA)
    else:
        rgb_small = rgb

    outputs = predictor(rgb_small[:, :, ::-1])  # predictor expects BGR
    inst = outputs["instances"].to("cpu")

    # Draw the overlay on the (possibly downscaled) image.
    vis = Visualizer(rgb_small, metadata=meta, scale=1.0, instance_mode=ColorMode.IMAGE_BW)
    overlay_rgb = vis.draw_instance_predictions(inst).get_image()

    dets = []
    if inst.has("pred_boxes"):
        boxes = inst.pred_boxes.tensor.numpy().tolist()
        scores = inst.scores.numpy().tolist() if inst.has("scores") else [None] * len(boxes)
        classes_idx = inst.pred_classes.numpy().tolist() if inst.has("pred_classes") else [0] * len(boxes)
        inv = 1.0 / scale  # map boxes back to original-image coordinates
        for b, s, c in zip(boxes, scores, classes_idx):
            b = [float(x * inv) for x in b]
            label = meta.thing_classes[c] if 0 <= c < len(meta.thing_classes) else str(c)
            # BUG FIX: float(s) raised TypeError when the model produced no
            # scores (s is None in that fallback path); keep None as-is.
            dets.append({
                "box": b,
                "class": label,
                "score": float(s) if s is not None else None,
            })

    return overlay_rgb, {
        "instances": dets,
        "original_size": [int(h0), int(w0)],
        "latency_ms": int((time.time() - t0) * 1000),
    }
# --- 3. GRADIO INTERFACE ---
# Example images shipped with the app; each inner list is one Examples row.
example_files = [[f"examples/{n}.jpg"] for n in range(1, 6)]
with gr.Blocks(title="Panoramic Radiograph Segmentation") as demo:
    gr.Markdown("## Dental X-Ray Segmentation App")
    gr.Markdown("Upload a panoramic radiograph (or click an example below) to detect teeth.")
    with gr.Row():
        # --- Left Column: Input ---
        with gr.Column():
            img_in = gr.Image(type="numpy", label="Input Radiograph")
            # Thumbnail row of clickable example images.
            gr.Examples(
                examples=example_files,
                inputs=img_in,
                label="Click an example to load it:"
            )
            submit_btn = gr.Button("Run Segmentation", variant="primary")
        # --- Right Column: Output ---
        with gr.Column():
            img_out = gr.Image(label="Overlay Result")
            json_out = gr.JSON(label="Detections Data")
    # Wire the button to the inference function; api_name exposes /predict.
    submit_btn.click(fn=segment, inputs=img_in, outputs=[img_out, json_out], api_name="/predict")

    # --- CITATIONS SECTION ---
    # BUG FIX: the citation Markdown opened "**" bold on every bullet without
    # closing it, and duplicated the "Credits & Citations" heading line.
    gr.Markdown("---")
    gr.Markdown(
        """
        ### Credits & Citations
        * **Brahmi, W., & Jdey, I. (2024).** Automatic tooth instance segmentation and identification from panoramic X-Ray images using deep CNN. *Multimedia Tools and Applications, 83*(18), 55565–55585.
        * **Brahmi, W., Jdey, I., & Drira, F. (2024).** Exploring the role of Convolutional Neural Networks (CNN) in dental radiography segmentation: A comprehensive Systematic Literature Review. *Engineering Applications of Artificial Intelligence, 133*, 108510.
        * [Panoramic Dental X-rays (Mendeley Data)](https://data.mendeley.com/datasets/73n3kz2k4k/3)
        """
    )

if __name__ == "__main__":
    # Bind to all interfaces on the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)