drone / app.py
rishab1090's picture
Update app.py
b629612 verified
import logging
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO
# ------------ CONFIG ------------
# The weights file is expected to sit in the same directory as this script.
MODEL_PATH = Path(__file__).with_name("best.pt")
if not MODEL_PATH.exists():
    raise FileNotFoundError(
        f"Model not found at {MODEL_PATH}. Place your .pt next to app.py or change MODEL_PATH."
    )

# Instantiate the detector a single time at import so every request reuses it.
model = YOLO(str(MODEL_PATH))
# ------------ HELPERS ------------
def build_mask_from_result(res, H, W):
    """
    Build a single-channel uint8 mask of shape (H, W): 0 = background, 255 = rockfall.

    Instance polygons from ``res.masks.xy`` are rasterized when at least one
    usable polygon is available; otherwise every bounding box in
    ``res.boxes.xyxy`` is filled instead.

    Args:
        res: prediction result. Only ``res.masks.xy`` and ``res.boxes.xyxy`` are
            read; either attribute may be missing or None.
        H: mask height in pixels.
        W: mask width in pixels.

    Returns:
        np.ndarray of shape (H, W) and dtype uint8. All zeros when nothing could
        be drawn. Failures are logged as warnings rather than raised, because
        the mask is a best-effort extra output.
    """
    log = logging.getLogger(__name__)
    mask = np.zeros((H, W), dtype=np.uint8)

    # Preferred path: segmentation polygons (one point array per instance).
    try:
        masks = getattr(res, "masks", None)
        polygons = getattr(masks, "xy", None) if masks is not None else None
        if polygons is not None and len(polygons) > 0:
            # Draw on a scratch buffer so a failure part-way through cannot leave
            # half-drawn polygons underneath the bounding-box fallback.
            poly_mask = np.zeros_like(mask)
            drawn = 0
            for poly in polygons:
                # poly may be a flat [x1, y1, x2, y2, ...] list or a list of points
                pts = np.array(poly, dtype=np.int32).reshape(-1, 2)
                if len(pts) < 3:
                    # Fewer than three vertices cannot enclose any area.
                    continue
                cv2.fillPoly(poly_mask, [pts], 255)
                drawn += 1
            if drawn:
                return poly_mask
    except Exception:
        log.warning(
            "Could not rasterize segmentation polygons; falling back to boxes",
            exc_info=True,
        )

    # Fallback: filled bounding boxes.
    boxes = getattr(res, "boxes", None)
    if boxes is None or len(boxes) == 0:
        return mask
    try:
        xyxy = boxes.xyxy.cpu().numpy()
        for box in xyxy:
            x1, y1, x2, y2 = map(int, box)
            # Clip to the image so the fill never addresses pixels outside the mask.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(W - 1, x2), min(H - 1, y2)
            if x2 > x1 and y2 > y1:
                cv2.rectangle(mask, (x1, y1), (x2, y2), 255, thickness=-1)
    except Exception:
        log.warning("Could not rasterize bounding boxes into the mask", exc_info=True)
    return mask
def detections_from_result(res):
    """
    Convert a prediction result into a list of plain-Python detection dicts.

    Args:
        res: prediction result. ``res.boxes`` may be missing, None or empty;
            otherwise its ``xyxy``, ``conf`` and ``cls`` tensors are read.

    Returns:
        list of dicts with keys ``class_id`` (int), ``class_name`` (str),
        ``confidence`` (float) and ``bbox`` ([x1, y1, x2, y2] as ints).
        An empty list when there are no boxes.
    """
    boxes = getattr(res, "boxes", None)
    # An attribute that exists but is None would make len() raise; treat it as
    # "no detections" instead.
    if boxes is None or len(boxes) == 0:
        return []

    xyxy = boxes.xyxy.cpu().numpy()
    confs = boxes.conf.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy().astype(int)

    detections = []
    for box, score, raw_id in zip(xyxy, confs, class_ids):
        class_id = int(raw_id)
        x1, y1, x2, y2 = map(int, box)
        detections.append({
            "class_id": class_id,
            # Fall back to the numeric id when the model has no name for this class.
            "class_name": model.names[class_id] if class_id in model.names else str(class_id),
            "confidence": float(score),
            "bbox": [x1, y1, x2, y2],
        })
    return detections
def annotate_image_from_result(res, img):
    """
    Return a uint8 image with the detections drawn on it.

    The primary path is ``res.plot()``. Its output is returned without any
    channel reordering, so it stays in the channel order of the array that was
    given to the model (Ultralytics documents ``plot()`` as returning BGR for
    ndarray sources - callers that need RGB must convert).
    If ``res.plot()`` raises, boxes and labels are drawn manually on a copy of
    ``img``.

    Args:
        res: prediction result providing ``plot()`` (and ``boxes`` for the fallback).
        img: HWC numpy image used as the canvas for the manual fallback.

    Returns:
        np.ndarray (uint8) annotated image.
    """
    try:
        annotated = res.plot()
    except Exception:
        logging.getLogger(__name__).warning(
            "res.plot() failed; drawing boxes manually", exc_info=True
        )
        annotated = None

    if annotated is not None:
        annotated = np.asarray(annotated)
        if annotated.dtype != np.uint8:
            # Only float data on a 0..1 scale needs rescaling; anything already on
            # a 0..255 scale is clipped and cast so values cannot wrap around.
            is_unit_float = (
                np.issubdtype(annotated.dtype, np.floating)
                and annotated.size > 0
                and annotated.max() <= 1.0
            )
            if is_unit_float:
                annotated = annotated * 255
            annotated = np.clip(annotated, 0, 255).astype(np.uint8)
        return annotated

    # Fallback: draw boxes and labels ourselves on a copy of the input image.
    out = img.copy()
    for det in detections_from_result(res):
        x1, y1, x2, y2 = det["bbox"]
        label = f"{det['class_name']} {det['confidence']:.2f}"
        cv2.rectangle(out, (x1, y1), (x2, y2), (0, 255, 0), 2)
        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
        # Keep the label on the canvas when the box touches the top edge.
        label_bottom = max(y1, th + 6)
        cv2.rectangle(out, (x1, label_bottom - th - 6), (x1 + tw + 6, label_bottom), (0, 255, 0), -1)
        cv2.putText(out, label, (x1 + 3, label_bottom - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
    return out
# ------------ PREDICT FUNCTION ------------
def predict_gradio(image, conf=0.25, imgsz=640, return_mask=False, return_probability=True):
    """
    Run rockfall detection on one image coming from the Gradio UI.

    Args:
        image: numpy image from the Gradio Image component (type="numpy"), i.e.
            HWC RGB; grayscale (HW or HW1) and RGBA inputs are also accepted.
            None when nothing was uploaded.
        conf: confidence threshold passed to the model.
        imgsz: inference image size passed to the model.
        return_mask: when True, also build the binary 0/255 mask.
        return_probability: accepted for interface compatibility; the
            probability is always returned regardless of this flag.

    Returns:
        Tuple ``(annotated, mask, detections, probability)``:
        annotated RGB image (or None), HxW uint8 mask (or None when not
        requested), list of detection dicts, and the highest detection
        confidence as a float (0.0 when there are no detections).
    """
    if image is None:
        return None, None, [], 0.0

    # Ultralytics interprets ndarray sources as BGR (OpenCV convention), while
    # Gradio delivers RGB. Convert before inference so the model sees the
    # channel order it was trained with.
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[..., None]
    if image.shape[2] == 1:
        bgr = np.repeat(image, 3, axis=2)
    else:
        # Channels 2, 1, 0: reverses RGB and silently drops an alpha channel.
        bgr = np.ascontiguousarray(image[..., 2::-1])
    height, width = bgr.shape[:2]

    results = model.predict(source=bgr, conf=float(conf), imgsz=int(imgsz), verbose=False)
    res = results[0]

    detections = detections_from_result(res)
    probability = max((d["confidence"] for d in detections), default=0.0)

    # The annotated image comes back in the model's channel order (BGR);
    # flip it to RGB for display in Gradio.
    annotated = annotate_image_from_result(res, bgr)
    if annotated is not None and annotated.ndim == 3 and annotated.shape[2] >= 3:
        annotated = np.ascontiguousarray(annotated[..., 2::-1])

    # Gradio Image accepts HxW grayscale, so the single-channel mask is returned as is.
    mask_out = build_mask_from_result(res, height, width) if return_mask else None

    # Order must match the Gradio outputs list: image, mask, JSON, probability.
    return annotated, mask_out, detections, float(probability)
# ------------ GRADIO UI ------------
with gr.Blocks(title="Rockfall Detection (YOLOv8)") as demo:
    gr.Markdown(value="# 🪨 Rockfall Detection (YOLOv8)")

    with gr.Row():
        # Left column: the uploaded image and the inference controls.
        with gr.Column(scale=1):
            inp = gr.Image(type="numpy", label="Upload mine image (RGB)")
            conf = gr.Slider(
                minimum=0.05,
                maximum=0.9,
                value=0.25,
                step=0.01,
                label="Confidence threshold",
            )
            imgsz = gr.Slider(
                minimum=320,
                maximum=1280,
                value=640,
                step=32,
                label="Inference image size",
            )
            return_mask = gr.Checkbox(value=False, label="Return binary mask (0/255)")
            run_btn = gr.Button(value="Run")

        # Right column: everything predict_gradio returns, in its return order.
        with gr.Column(scale=1):
            out_img = gr.Image(type="numpy", label="Annotated image (RGB)")
            out_mask = gr.Image(type="numpy", label="Binary mask (if requested)")
            out_json = gr.JSON(label="Detections (class/conf/bbox)")
            out_prob = gr.Textbox(label="Max rockfall probability")

    # The trailing State feeds a constant True into return_probability.
    click_inputs = [inp, conf, imgsz, return_mask, gr.State(True)]
    click_outputs = [out_img, out_mask, out_json, out_prob]
    run_btn.click(fn=predict_gradio, inputs=click_inputs, outputs=click_outputs)

# Start the server only when executed as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)