Update app.py
Browse files
app.py
CHANGED
|
@@ -1,20 +1,15 @@
|
|
| 1 |
-
# app.py — Gradio app for panoramic radiograph segmentation (Detectron2, CPU/GPU)
|
| 2 |
-
|
| 3 |
import os, json, time
|
| 4 |
import numpy as np, cv2, torch, gradio as gr
|
| 5 |
-
|
| 6 |
from detectron2.config import get_cfg
|
| 7 |
from detectron2.engine import DefaultPredictor
|
| 8 |
from detectron2.data import MetadataCatalog
|
| 9 |
from detectron2.utils.visualizer import Visualizer, ColorMode
|
| 10 |
|
| 11 |
-
# --------- Artifacts already in the repo ---------
|
| 12 |
LOAD_DIR = "./artifacts"
|
| 13 |
WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")
|
| 14 |
CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")
|
| 15 |
CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")
|
| 16 |
|
| 17 |
-
# --------- Build cfg & load model ---------
|
| 18 |
cfg = get_cfg()
|
| 19 |
cfg.merge_from_file(CFG_PATH)
|
| 20 |
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -31,16 +26,14 @@ predictor = DefaultPredictor(cfg)
|
|
| 31 |
meta = MetadataCatalog.get("inference_only")
|
| 32 |
meta.thing_classes = classes if classes else [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]
|
| 33 |
|
| 34 |
-
# --------- Inference (auto-downscale very wide panoramics) ---------
|
| 35 |
MAX_SIDE = 1600
|
| 36 |
|
| 37 |
def segment(rgb: np.ndarray):
|
| 38 |
t0 = time.time()
|
| 39 |
h0, w0 = rgb.shape[:2]
|
| 40 |
-
long_side = max(h0, w0)
|
| 41 |
scale = 1.0
|
| 42 |
-
if
|
| 43 |
-
scale = MAX_SIDE /
|
| 44 |
rgb_small = cv2.resize(rgb, (int(w0*scale), int(h0*scale)), interpolation=cv2.INTER_AREA)
|
| 45 |
else:
|
| 46 |
rgb_small = rgb
|
|
@@ -52,14 +45,15 @@ def segment(rgb: np.ndarray):
|
|
| 52 |
overlay_rgb = vis.draw_instance_predictions(inst).get_image()
|
| 53 |
|
| 54 |
dets = []
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
b
|
| 61 |
-
|
| 62 |
-
|
|
|
|
| 63 |
|
| 64 |
return overlay_rgb, {
|
| 65 |
"instances": dets,
|
|
@@ -76,4 +70,4 @@ with gr.Blocks(title="Panoramic Radiograph Segmentation") as demo:
|
|
| 76 |
gr.Button("Run").click(segment, inputs=img_in, outputs=[img_out, json_out], api_name="/predict")
|
| 77 |
|
| 78 |
if __name__ == "__main__":
|
| 79 |
-
demo.launch() #
|
|
|
|
|
|
|
|
|
|
| 1 |
import os, json, time
|
| 2 |
import numpy as np, cv2, torch, gradio as gr
|
|
|
|
| 3 |
from detectron2.config import get_cfg
|
| 4 |
from detectron2.engine import DefaultPredictor
|
| 5 |
from detectron2.data import MetadataCatalog
|
| 6 |
from detectron2.utils.visualizer import Visualizer, ColorMode
|
| 7 |
|
|
|
|
| 8 |
LOAD_DIR = "./artifacts"
|
| 9 |
WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")
|
| 10 |
CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")
|
| 11 |
CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")
|
| 12 |
|
|
|
|
| 13 |
cfg = get_cfg()
|
| 14 |
cfg.merge_from_file(CFG_PATH)
|
| 15 |
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 26 |
meta = MetadataCatalog.get("inference_only")
|
| 27 |
meta.thing_classes = classes if classes else [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]
|
| 28 |
|
|
|
|
| 29 |
MAX_SIDE = 1600
|
| 30 |
|
| 31 |
def segment(rgb: np.ndarray):
|
| 32 |
t0 = time.time()
|
| 33 |
h0, w0 = rgb.shape[:2]
|
|
|
|
| 34 |
scale = 1.0
|
| 35 |
+
if max(h0, w0) > MAX_SIDE:
|
| 36 |
+
scale = MAX_SIDE / max(h0, w0)
|
| 37 |
rgb_small = cv2.resize(rgb, (int(w0*scale), int(h0*scale)), interpolation=cv2.INTER_AREA)
|
| 38 |
else:
|
| 39 |
rgb_small = rgb
|
|
|
|
| 45 |
overlay_rgb = vis.draw_instance_predictions(inst).get_image()
|
| 46 |
|
| 47 |
dets = []
|
| 48 |
+
if inst.has("pred_boxes"):
|
| 49 |
+
boxes = inst.pred_boxes.tensor.numpy().tolist()
|
| 50 |
+
scores = inst.scores.numpy().tolist() if inst.has("scores") else [None]*len(boxes)
|
| 51 |
+
classes_idx = inst.pred_classes.numpy().tolist() if inst.has("pred_classes") else [0]*len(boxes)
|
| 52 |
+
inv = (1.0/scale) if scale != 1.0 else 1.0
|
| 53 |
+
for b, s, c in zip(boxes, scores, classes_idx):
|
| 54 |
+
b = [float(x*inv) for x in b]
|
| 55 |
+
label = meta.thing_classes[c] if 0 <= c < len(meta.thing_classes) else str(c)
|
| 56 |
+
dets.append({"box": b, "class": label, "score": float(s)})
|
| 57 |
|
| 58 |
return overlay_rgb, {
|
| 59 |
"instances": dets,
|
|
|
|
| 70 |
gr.Button("Run").click(segment, inputs=img_in, outputs=[img_out, json_out], api_name="/predict")
|
| 71 |
|
| 72 |
if __name__ == "__main__":
|
| 73 |
+
demo.launch() # Spaces provides the permanent URL
|