WSLINMSAI commited on
Commit
90fdaaf
·
verified ·
1 Parent(s): 741aa4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -18
app.py CHANGED
@@ -1,20 +1,15 @@
1
- # app.py — Gradio app for panoramic radiograph segmentation (Detectron2, CPU/GPU)
2
-
3
  import os, json, time
4
  import numpy as np, cv2, torch, gradio as gr
5
-
6
  from detectron2.config import get_cfg
7
  from detectron2.engine import DefaultPredictor
8
  from detectron2.data import MetadataCatalog
9
  from detectron2.utils.visualizer import Visualizer, ColorMode
10
 
11
- # --------- Artifacts already in the repo ---------
12
  LOAD_DIR = "./artifacts"
13
  WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")
14
  CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")
15
  CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")
16
 
17
- # --------- Build cfg & load model ---------
18
  cfg = get_cfg()
19
  cfg.merge_from_file(CFG_PATH)
20
  cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -31,16 +26,14 @@ predictor = DefaultPredictor(cfg)
31
  meta = MetadataCatalog.get("inference_only")
32
  meta.thing_classes = classes if classes else [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]
33
 
34
- # --------- Inference (auto-downscale very wide panoramics) ---------
35
  MAX_SIDE = 1600
36
 
37
  def segment(rgb: np.ndarray):
38
  t0 = time.time()
39
  h0, w0 = rgb.shape[:2]
40
- long_side = max(h0, w0)
41
  scale = 1.0
42
- if long_side > MAX_SIDE:
43
- scale = MAX_SIDE / long_side
44
  rgb_small = cv2.resize(rgb, (int(w0*scale), int(h0*scale)), interpolation=cv2.INTER_AREA)
45
  else:
46
  rgb_small = rgb
@@ -52,14 +45,15 @@ def segment(rgb: np.ndarray):
52
  overlay_rgb = vis.draw_instance_predictions(inst).get_image()
53
 
54
  dets = []
55
- boxes_small = inst.pred_boxes.tensor.numpy().tolist() if inst.has("pred_boxes") else []
56
- scores = inst.scores.numpy().tolist() if inst.has("scores") else []
57
- classes_idx = inst.pred_classes.numpy().tolist() if inst.has("pred_classes") else []
58
- inv = (1.0/scale) if scale != 1.0 else 1.0
59
- for b, s, c in zip(boxes_small, scores, classes_idx):
60
- b = [float(x*inv) for x in b] # back to original image coords
61
- label = meta.thing_classes[c] if 0 <= c < len(meta.thing_classes) else str(c)
62
- dets.append({"box": b, "class": label, "score": float(s)})
 
63
 
64
  return overlay_rgb, {
65
  "instances": dets,
@@ -76,4 +70,4 @@ with gr.Blocks(title="Panoramic Radiograph Segmentation") as demo:
76
  gr.Button("Run").click(segment, inputs=img_in, outputs=[img_out, json_out], api_name="/predict")
77
 
78
  if __name__ == "__main__":
79
- demo.launch() # no share=True on Spaces
 
 
 
1
  import os, json, time
2
  import numpy as np, cv2, torch, gradio as gr
 
3
  from detectron2.config import get_cfg
4
  from detectron2.engine import DefaultPredictor
5
  from detectron2.data import MetadataCatalog
6
  from detectron2.utils.visualizer import Visualizer, ColorMode
7
 
 
8
  LOAD_DIR = "./artifacts"
9
  WEIGHTS = os.path.join(LOAD_DIR, "model_final.pth")
10
  CFG_PATH = os.path.join(LOAD_DIR, "config.yaml")
11
  CLASSES_PATH = os.path.join(LOAD_DIR, "classes.json")
12
 
 
13
  cfg = get_cfg()
14
  cfg.merge_from_file(CFG_PATH)
15
  cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
26
  meta = MetadataCatalog.get("inference_only")
27
  meta.thing_classes = classes if classes else [f"class_{i}" for i in range(cfg.MODEL.ROI_HEADS.NUM_CLASSES)]
28
 
 
29
  MAX_SIDE = 1600
30
 
31
  def segment(rgb: np.ndarray):
32
  t0 = time.time()
33
  h0, w0 = rgb.shape[:2]
 
34
  scale = 1.0
35
+ if max(h0, w0) > MAX_SIDE:
36
+ scale = MAX_SIDE / max(h0, w0)
37
  rgb_small = cv2.resize(rgb, (int(w0*scale), int(h0*scale)), interpolation=cv2.INTER_AREA)
38
  else:
39
  rgb_small = rgb
 
45
  overlay_rgb = vis.draw_instance_predictions(inst).get_image()
46
 
47
  dets = []
48
+ if inst.has("pred_boxes"):
49
+ boxes = inst.pred_boxes.tensor.numpy().tolist()
50
+ scores = inst.scores.numpy().tolist() if inst.has("scores") else [None]*len(boxes)
51
+ classes_idx = inst.pred_classes.numpy().tolist() if inst.has("pred_classes") else [0]*len(boxes)
52
+ inv = (1.0/scale) if scale != 1.0 else 1.0
53
+ for b, s, c in zip(boxes, scores, classes_idx):
54
+ b = [float(x*inv) for x in b]
55
+ label = meta.thing_classes[c] if 0 <= c < len(meta.thing_classes) else str(c)
56
+ dets.append({"box": b, "class": label, "score": float(s)})
57
 
58
  return overlay_rgb, {
59
  "instances": dets,
 
70
  gr.Button("Run").click(segment, inputs=img_in, outputs=[img_out, json_out], api_name="/predict")
71
 
72
  if __name__ == "__main__":
73
+ demo.launch() # Spaces provides the permanent URL