from __future__ import annotations import os from functools import lru_cache from pathlib import Path import gradio as gr from PIL import Image, ImageDraw, ImageFont from ultralytics import YOLO ROOT = Path(__file__).resolve().parent MODEL_PATH = ROOT / "models" / "best.pt" MODEL_URL = "https://huggingface.co/DefendIntelligence/vessel-detection/resolve/main/models/best.pt" EXAMPLES_DIR = ROOT / "examples" MAIN_EXAMPLE_PATH = EXAMPLES_DIR / "example-00-multi-vessel-patch.png" MAX_TILES = 196 BATCH_SIZE = 8 @lru_cache(maxsize=1) def load_model() -> YOLO: if not MODEL_PATH.exists(): raise FileNotFoundError( f"Model not found: {MODEL_PATH}. Run `python run_local.py` or download it from {MODEL_URL}." ) return YOLO(str(MODEL_PATH)) def _tile_starts(length: int, tile_size: int, overlap: int) -> list[int]: if length <= tile_size: return [0] stride = max(1, tile_size - overlap) starts = list(range(0, max(1, length - tile_size + 1), stride)) last = length - tile_size if starts[-1] != last: starts.append(last) return starts def _iter_tiles(image: Image.Image, tile_size: int, overlap: int) -> list[tuple[Image.Image, int, int]]: width, height = image.size x_starts = _tile_starts(width, tile_size, overlap) y_starts = _tile_starts(height, tile_size, overlap) tiles: list[tuple[Image.Image, int, int]] = [] for y in y_starts: for x in x_starts: right = min(width, x + tile_size) bottom = min(height, y + tile_size) tiles.append((image.crop((x, y, right, bottom)), x, y)) return tiles def _box_iou(a: list[float], b: list[float]) -> float: ax1, ay1, ax2, ay2 = a bx1, by1, bx2, by2 = b inter_x1 = max(ax1, bx1) inter_y1 = max(ay1, by1) inter_x2 = min(ax2, bx2) inter_y2 = min(ay2, by2) inter_w = max(0.0, inter_x2 - inter_x1) inter_h = max(0.0, inter_y2 - inter_y1) inter_area = inter_w * inter_h if inter_area <= 0: return 0.0 area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1) area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1) union = area_a + area_b - inter_area return inter_area / union if union > 0 else 0.0 def _nms(detections: list[dict], iou_threshold: float) -> list[dict]: remaining = sorted(detections, key=lambda item: float(item["confidence"]), reverse=True) kept: list[dict] = [] while remaining: current = remaining.pop(0) kept.append(current) remaining = [ item for item in remaining if item["class_id"] != current["class_id"] or _box_iou(item["box"], current["box"]) < iou_threshold ] return kept def _model_names(model: YOLO) -> dict[int, str]: names = getattr(model, "names", None) or {} if isinstance(names, dict): return {int(key): str(value) for key, value in names.items()} return {index: str(name) for index, name in enumerate(names)} def _predict_tiles( image: Image.Image, *, confidence: float, iou: float, tile_size: int, overlap: int, max_det: int, ) -> tuple[list[dict], int]: model = load_model() names = _model_names(model) rgb_image = image.convert("RGB") safe_tile_size = max(320, int(tile_size)) safe_overlap = max(0, min(int(overlap), safe_tile_size - 32)) tiles = _iter_tiles(rgb_image, safe_tile_size, safe_overlap) if len(tiles) > MAX_TILES: raise ValueError( f"Image too large for this CPU Space: {len(tiles)} tiles. " f"Resize the image or increase the tile size." ) detections: list[dict] = [] for start in range(0, len(tiles), BATCH_SIZE): batch = tiles[start : start + BATCH_SIZE] batch_images = [tile for tile, _, _ in batch] results = model.predict( source=batch_images, conf=float(confidence), iou=float(iou), imgsz=safe_tile_size, max_det=int(max_det), verbose=False, ) for result, (_, offset_x, offset_y) in zip(results, batch): boxes = getattr(result, "boxes", None) if boxes is None or len(boxes) == 0: continue xyxy = boxes.xyxy.cpu().numpy() confs = boxes.conf.cpu().numpy() classes = boxes.cls.cpu().numpy().astype(int) for box, score, class_id in zip(xyxy, confs, classes): x1, y1, x2, y2 = box.tolist() detections.append( { "label": names.get(int(class_id), f"class_{int(class_id)}"), "class_id": int(class_id), "confidence": float(score), "box": [ float(x1 + offset_x), float(y1 + offset_y), float(x2 + offset_x), float(y2 + offset_y), ], } ) detections = _nms(detections, float(iou)) detections = detections[: int(max_det)] return detections, len(tiles) def _draw_detections(image: Image.Image, detections: list[dict]) -> Image.Image: annotated = image.convert("RGB").copy() draw = ImageDraw.Draw(annotated) font = ImageFont.load_default() line_width = max(2, round(max(annotated.size) / 420)) for detection in detections: x1, y1, x2, y2 = detection["box"] label = f"{detection['label']} {detection['confidence']:.2f}" draw.rectangle((x1, y1, x2, y2), outline=(255, 64, 48), width=line_width) text_box = draw.textbbox((x1, y1), label, font=font) text_w = text_box[2] - text_box[0] text_h = text_box[3] - text_box[1] label_y = max(0, y1 - text_h - 6) draw.rectangle((x1, label_y, x1 + text_w + 8, label_y + text_h + 6), fill=(255, 64, 48)) draw.text((x1 + 4, label_y + 3), label, fill=(255, 255, 255), font=font) return annotated def _table_rows(detections: list[dict]) -> list[list[object]]: rows: list[list[object]] = [] for index, detection in enumerate(detections, start=1): x1, y1, x2, y2 = detection["box"] rows.append( [ index, detection["label"], round(float(detection["confidence"]), 4), round(x1, 1), round(y1, 1), round(x2, 1), round(y2, 1), round(x2 - x1, 1), round(y2 - y1, 1), ] ) return rows def detect_boats( image: Image.Image | None, confidence: float, iou: float, tile_size: int, overlap: int, max_det: int, ) -> tuple[Image.Image | None, list[list[object]], str]: if image is None: return None, [], "Upload a satellite image to run detection." try: detections, tile_count = _predict_tiles( image, confidence=confidence, iou=iou, tile_size=tile_size, overlap=overlap, max_det=max_det, ) except Exception as exc: return image, [], f"Inference error: {exc}" annotated = _draw_detections(image, detections) rows = _table_rows(detections) if detections: summary = f"{len(detections)} detection(s) above {confidence:.2f}. Tiles analyzed: {tile_count}." else: summary = f"No detections above {confidence:.2f}. Tiles analyzed: {tile_count}." return annotated, rows, summary def _example_paths() -> list[list[str]]: paths = sorted(EXAMPLES_DIR.glob("*.png")) return [[str(path)] for path in paths[:10]] with gr.Blocks(title="Vessel Detection") as demo: gr.Markdown( """ # Vessel Detection Fine-tuned YOLOv8s model for detecting vessels in RGB satellite imagery. Upload a satellite image or select an example, then run detection. """ ) with gr.Row(): with gr.Column(scale=1): image_input = gr.Image( value=str(MAIN_EXAMPLE_PATH) if MAIN_EXAMPLE_PATH.exists() else None, type="pil", label="Satellite image", ) confidence_input = gr.Slider(0.01, 0.95, value=0.20, step=0.01, label="Confidence threshold") iou_input = gr.Slider(0.05, 0.90, value=0.45, step=0.05, label="IoU NMS") tile_size_input = gr.Slider(320, 1024, value=640, step=32, label="Tile size") overlap_input = gr.Slider(0, 256, value=96, step=16, label="Tile overlap") max_det_input = gr.Slider(1, 200, value=80, step=1, label="Max detections") run_button = gr.Button("Detect vessels", variant="primary") with gr.Column(scale=1): output_image = gr.Image(type="pil", label="Annotated image") summary_output = gr.Markdown() table_output = gr.Dataframe( headers=["#", "label", "confidence", "x1", "y1", "x2", "y2", "width", "height"], datatype=["number", "str", "number", "number", "number", "number", "number", "number", "number"], label="Detections", ) run_button.click( fn=detect_boats, inputs=[image_input, confidence_input, iou_input, tile_size_input, overlap_input, max_det_input], outputs=[output_image, table_output, summary_output], ) gr.Examples( examples=_example_paths(), inputs=[image_input], label="Example images", ) if __name__ == "__main__": launch_kwargs = {} if os.environ.get("GRADIO_SERVER_NAME"): launch_kwargs["server_name"] = os.environ["GRADIO_SERVER_NAME"] if os.environ.get("GRADIO_SERVER_PORT"): launch_kwargs["server_port"] = int(os.environ["GRADIO_SERVER_PORT"]) demo.launch(**launch_kwargs)