File size: 6,137 Bytes
63a8caf
21ea2eb
 
 
6f6a21b
130e656
21ea2eb
b629612
21ea2eb
 
130e656
21ea2eb
 
130e656
21ea2eb
 
 
 
 
 
 
 
6f6a21b
21ea2eb
 
 
 
 
 
 
 
 
 
 
 
 
 
130e656
63a8caf
 
21ea2eb
 
63a8caf
 
21ea2eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130e656
21ea2eb
 
 
130e656
 
21ea2eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130e656
21ea2eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c999c27
 
21ea2eb
 
 
 
 
 
 
 
 
 
 
 
 
6f6a21b
21ea2eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import logging
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO

# ------------ CONFIG ------------
# Trained weights are looked up in the same directory as this script.
MODEL_PATH = Path(__file__).parent.joinpath("best.pt")

# Refuse to start without weights rather than failing on the first request.
if not MODEL_PATH.exists():
    raise FileNotFoundError(
        f"Model not found at {MODEL_PATH}. Place your .pt next to app.py or change MODEL_PATH."
    )

# Loaded a single time at import; every prediction reuses this instance.
model = YOLO(str(MODEL_PATH))

# ------------ HELPERS ------------
def build_mask_from_result(res, H, W):
    """
    Rasterize one Ultralytics result into a single-channel uint8 mask.

    Pixels covered by a detection are 255, everything else is 0.
    Segmentation polygons (``res.masks.xy``) are preferred. If there are
    none, or none of them contains any points, the bounding boxes
    (``res.boxes.xyxy``) are filled instead.

    Args:
        res: one Ultralytics ``Results`` object (e.g. ``model.predict(...)[0]``).
        H: height of the output mask, normally the source image height.
        W: width of the output mask, normally the source image width.

    Returns:
        np.ndarray of dtype uint8 and shape (H, W).

    Rasterization is best-effort: failures are logged and whatever could be
    drawn so far is returned, so a malformed result never breaks a request.
    """
    log = logging.getLogger(__name__)
    mask = np.zeros((H, W), dtype=np.uint8)

    # 1) Segmentation polygons: one point list/array per instance.
    try:
        masks = getattr(res, "masks", None)
        polygons = list(masks.xy) if masks is not None and hasattr(masks, "xy") else []
        drawn = 0
        for poly in polygons:
            # poly may be an (N, 2) array or a flat [x1, y1, x2, y2, ...] sequence;
            # round instead of truncating so the polygon is not biased up-left.
            pts = np.round(np.asarray(poly, dtype=np.float64)).astype(np.int32).reshape(-1, 2)
            if len(pts) == 0:
                # Empty polygons can't be filled; skip them instead of aborting the pass.
                continue
            cv2.fillPoly(mask, [pts], 255)
            drawn += 1
        if drawn:
            return mask
    except Exception:
        log.warning("Polygon rasterization failed; falling back to boxes", exc_info=True)

    # 2) Fallback: filled bounding boxes.
    try:
        boxes = getattr(res, "boxes", None)
        if boxes is None or len(boxes) == 0:
            return mask
        for x1, y1, x2, y2 in boxes.xyxy.cpu().numpy().astype(int):
            # Clip to the image; the corners are inclusive pixel indices.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(W - 1, x2), min(H - 1, y2)
            if x2 > x1 and y2 > y1:
                # Same pixels as cv2.rectangle(..., thickness=-1): both corners included.
                mask[y1:y2 + 1, x1:x2 + 1] = 255
    except Exception:
        log.warning("Box rasterization failed; returning partial mask", exc_info=True)

    return mask

def detections_from_result(res, names=None):
    """
    Convert one Ultralytics result's boxes into JSON-serializable dicts.

    Each dict has the keys ``class_id`` (int), ``class_name`` (str),
    ``confidence`` (float) and ``bbox`` ([x1, y1, x2, y2], pixel coordinates
    truncated to int).

    Args:
        res: one Ultralytics ``Results`` object.
        names: optional class-id -> name mapping (dict, or list indexed by id).
            Defaults to ``res.names`` when present, else the global ``model.names``.

    Returns:
        list of detection dicts. It is empty when there are no boxes,
        including when ``res.boxes`` is None.
    """
    boxes = getattr(res, "boxes", None)
    if boxes is None or len(boxes) == 0:
        return []

    if names is None:
        names = getattr(res, "names", None)
        if names is None:
            names = model.names

    def class_name(cid):
        # With a list, `cid in names` would test values rather than indexes,
        # so look ids up explicitly for each container type.
        if isinstance(names, dict):
            return names.get(cid, str(cid))
        return names[cid] if 0 <= cid < len(names) else str(cid)

    xyxy = boxes.xyxy.cpu().numpy()
    confs = boxes.conf.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy().astype(int)
    return [
        {
            "class_id": int(cid),
            "class_name": class_name(int(cid)),
            "confidence": float(c),
            "bbox": [int(v) for v in box],
        }
        for box, c, cid in zip(xyxy, confs, class_ids)
    ]

def _draw_detections_fallback(res, img):
    """
    Draw boxes and "name conf" labels on a copy of ``img`` with OpenCV.

    Used only when ``res.plot()`` fails. Colours are pure green and black,
    so the drawing looks the same in either RGB or BGR channel order.
    """
    out = img.copy()
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1
    green, black = (0, 255, 0), (0, 0, 0)
    for det in detections_from_result(res):
        x1, y1, x2, y2 = det["bbox"]
        label = f"{det['class_name']} {det['confidence']:.2f}"
        cv2.rectangle(out, (x1, y1), (x2, y2), green, 2)
        (tw, th), _ = cv2.getTextSize(label, font, scale, thickness)
        # The label normally sits above the box. If that would leave the image,
        # tuck it inside the top of the box so it stays visible.
        top = y1 - th - 6
        if top < 0:
            top = y1
        cv2.rectangle(out, (x1, top), (x1 + tw + 6, top + th + 6), green, -1)
        cv2.putText(out, label, (x1 + 3, top + th + 2), font, scale, black, thickness, cv2.LINE_AA)
    return out


def annotate_image_from_result(res, img):
    """
    Return a uint8 copy of the input image with the detections drawn on it.

    ``res.plot()`` draws on the image the model was given (``res.orig_img``).
    The Ultralytics docs describe its output as a BGR-order numpy array when
    predict() was fed BGR, which is the numpy convention Ultralytics expects.
    No channel conversion is done here: the result comes back in the same
    order as the model input, and converting for display is the caller's job.

    If ``res.plot()`` raises, boxes and labels are drawn with OpenCV on a copy
    of ``img``. ``img`` should therefore use the same channel order as the
    model input.

    Args:
        res: one Ultralytics ``Results`` object.
        img: HxWx3 numpy image, used only by the fallback path.

    Returns:
        HxWx3 uint8 numpy array.
    """
    try:
        plotted = res.plot()
    except Exception:
        # Best effort: a plotting failure should not fail the whole request.
        logging.getLogger(__name__).warning(
            "res.plot() failed; drawing detections manually", exc_info=True
        )
        return _draw_detections_fallback(res, img)

    annotated = np.asarray(plotted)
    if annotated.dtype != np.uint8:
        # Scale normalized [0, 1] float images; otherwise assume a 0-255 range.
        if (
            np.issubdtype(annotated.dtype, np.floating)
            and annotated.size > 0
            and float(annotated.max()) <= 1.0
        ):
            annotated = annotated * 255.0
        # Clip instead of letting astype wrap out-of-range values modulo 256.
        annotated = np.clip(annotated, 0, 255).astype(np.uint8)
    return annotated

# ------------ PREDICT FUNCTION ------------
def predict_gradio(image, conf=0.25, imgsz=640, return_mask=False, return_probability=True):
    """
    Run the detector on one uploaded image and package the outputs for Gradio.

    Args:
        image: HxWx3 RGB numpy array from ``gr.Image(type="numpy")``, or None
            when nothing was uploaded. A 2-D or single-channel grayscale array
            is expanded to 3 channels, and an HxWx4 array has its 4th channel
            dropped.
        conf: confidence threshold passed to ``model.predict``.
        imgsz: inference image size passed to ``model.predict``.
        return_mask: when true, also build a 0/255 binary mask.
        return_probability: currently unused; the probability is always
            returned. Kept so existing callers and UI wiring keep working.

    Returns:
        A tuple ``(annotated, mask, detections, probability)``:
        annotated -- HxWx3 RGB image with detections drawn (None if no image);
        mask -- HxW uint8 0/255 mask, or None unless ``return_mask``;
        detections -- list of dicts from ``detections_from_result``;
        probability -- highest detection confidence, 0.0 when nothing was found.
    """
    if image is None:
        return None, None, [], 0.0

    rgb = np.asarray(image)
    if rgb.ndim == 2:
        rgb = np.stack([rgb, rgb, rgb], axis=-1)
    elif rgb.shape[-1] == 1:
        rgb = np.repeat(rgb, 3, axis=-1)
    elif rgb.shape[-1] == 4:
        rgb = rgb[..., :3]

    # Ultralytics interprets numpy sources as BGR (OpenCV convention), but Gradio
    # delivers RGB. Swap channels, or the model sees red and blue exchanged.
    bgr = np.ascontiguousarray(rgb[..., ::-1])

    results = model.predict(source=bgr, conf=float(conf), imgsz=int(imgsz), verbose=False)
    res = results[0]

    H, W = rgb.shape[:2]

    # Detections and the single "probability" (max confidence over all boxes).
    detections = detections_from_result(res)
    probability = max((d["confidence"] for d in detections), default=0.0)

    # The annotated image comes back in the model input's channel order (BGR);
    # flip it back to RGB for display in Gradio.
    annotated_bgr = annotate_image_from_result(res, bgr)
    annotated = np.ascontiguousarray(annotated_bgr[..., ::-1])

    # Single-channel HxW 0/255 mask; gr.Image displays 2-D arrays as grayscale.
    mask_out = build_mask_from_result(res, H, W) if return_mask else None

    # Ordering matches the Gradio outputs: (annotated, mask, detections, probability).
    return annotated, mask_out, detections, float(probability)

# ------------ GRADIO UI ------------
with gr.Blocks(title="Rockfall Detection (YOLOv8)") as demo:
    gr.Markdown("# 🪨 Rockfall Detection (YOLOv8)")

    with gr.Row():
        # Left column: the input image plus the inference controls.
        with gr.Column(scale=1):
            inp = gr.Image(label="Upload mine image (RGB)", type="numpy")
            conf = gr.Slider(
                minimum=0.05,
                maximum=0.9,
                value=0.25,
                step=0.01,
                label="Confidence threshold",
            )
            imgsz = gr.Slider(
                minimum=320,
                maximum=1280,
                value=640,
                step=32,
                label="Inference image size",
            )
            return_mask = gr.Checkbox(value=False, label="Return binary mask (0/255)")
            run_btn = gr.Button(value="Run")

        # Right column: one component per value predict_gradio returns, in order.
        with gr.Column(scale=1):
            out_img = gr.Image(label="Annotated image (RGB)", type="numpy")
            out_mask = gr.Image(label="Binary mask (if requested)", type="numpy")
            out_json = gr.JSON(label="Detections (class/conf/bbox)")
            out_prob = gr.Textbox(label="Max rockfall probability")

    # The trailing gr.State(True) supplies predict_gradio's return_probability argument.
    run_btn.click(
        predict_gradio,
        inputs=[inp, conf, imgsz, return_mask, gr.State(True)],
        outputs=[out_img, out_mask, out_json, out_prob],
    )

# launch
if __name__ == "__main__":
    # Listen on all network interfaces at port 7860, without a public share link.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)