| from __future__ import annotations |
|
|
| import os |
| from functools import lru_cache |
| from pathlib import Path |
|
|
| import gradio as gr |
| from PIL import Image, ImageDraw, ImageFont |
| from ultralytics import YOLO |
|
|
|
|
# Filesystem layout and inference limits for the app.
ROOT = Path(__file__).resolve().parent
MODEL_PATH = ROOT / "models" / "best.pt"  # fine-tuned YOLO checkpoint expected locally
MODEL_URL = "https://huggingface.co/DefendIntelligence/vessel-detection/resolve/main/models/best.pt"  # shown in the error when the checkpoint is missing
EXAMPLES_DIR = ROOT / "examples"  # bundled example PNGs surfaced in the UI
MAIN_EXAMPLE_PATH = EXAMPLES_DIR / "example-00-multi-vessel-patch.png"  # image pre-loaded into the input widget
MAX_TILES = 196  # hard cap on tiles per request, keeps CPU-only inference bounded
BATCH_SIZE = 8  # tiles per model.predict() call
|
|
|
|
@lru_cache(maxsize=1)
def load_model() -> YOLO:
    """Load the fine-tuned YOLO checkpoint once; subsequent calls reuse it.

    Raises:
        FileNotFoundError: If the checkpoint file is absent on disk.
    """
    if MODEL_PATH.exists():
        return YOLO(str(MODEL_PATH))
    raise FileNotFoundError(
        f"Model not found: {MODEL_PATH}. Run `python run_local.py` or download it from {MODEL_URL}."
    )
|
|
|
|
| def _tile_starts(length: int, tile_size: int, overlap: int) -> list[int]: |
| if length <= tile_size: |
| return [0] |
| stride = max(1, tile_size - overlap) |
| starts = list(range(0, max(1, length - tile_size + 1), stride)) |
| last = length - tile_size |
| if starts[-1] != last: |
| starts.append(last) |
| return starts |
|
|
|
|
def _iter_tiles(image: Image.Image, tile_size: int, overlap: int) -> list[tuple[Image.Image, int, int]]:
    """Crop *image* into overlapping square tiles.

    Returns a list of ``(tile, left, top)`` triples where ``left``/``top`` are
    the tile's offsets in the full image; edge tiles are clipped to the image
    bounds.
    """
    width, height = image.size
    lefts = _tile_starts(width, tile_size, overlap)
    tops = _tile_starts(height, tile_size, overlap)
    crops: list[tuple[Image.Image, int, int]] = []
    for top in tops:
        for left in lefts:
            region = (left, top, min(width, left + tile_size), min(height, top + tile_size))
            crops.append((image.crop(region), left, top))
    return crops
|
|
|
|
| def _box_iou(a: list[float], b: list[float]) -> float: |
| ax1, ay1, ax2, ay2 = a |
| bx1, by1, bx2, by2 = b |
| inter_x1 = max(ax1, bx1) |
| inter_y1 = max(ay1, by1) |
| inter_x2 = min(ax2, bx2) |
| inter_y2 = min(ay2, by2) |
| inter_w = max(0.0, inter_x2 - inter_x1) |
| inter_h = max(0.0, inter_y2 - inter_y1) |
| inter_area = inter_w * inter_h |
| if inter_area <= 0: |
| return 0.0 |
| area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1) |
| area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1) |
| union = area_a + area_b - inter_area |
| return inter_area / union if union > 0 else 0.0 |
|
|
|
|
def _nms(detections: list[dict], iou_threshold: float) -> list[dict]:
    """Greedy per-class non-maximum suppression.

    Detections are visited in descending confidence order; a detection is
    dropped when a kept detection of the same class overlaps it with IoU at or
    above *iou_threshold*.
    """
    candidates = sorted(detections, key=lambda det: float(det["confidence"]), reverse=True)
    survivors: list[dict] = []
    while candidates:
        best = candidates.pop(0)
        survivors.append(best)
        still_in_play = []
        for cand in candidates:
            same_class = cand["class_id"] == best["class_id"]
            if same_class and _box_iou(cand["box"], best["box"]) >= iou_threshold:
                continue  # suppressed by the higher-confidence box
            still_in_play.append(cand)
        candidates = still_in_play
    return survivors
|
|
|
|
| def _model_names(model: YOLO) -> dict[int, str]: |
| names = getattr(model, "names", None) or {} |
| if isinstance(names, dict): |
| return {int(key): str(value) for key, value in names.items()} |
| return {index: str(name) for index, name in enumerate(names)} |
|
|
|
|
def _predict_tiles(
    image: Image.Image,
    *,
    confidence: float,
    iou: float,
    tile_size: int,
    overlap: int,
    max_det: int,
) -> tuple[list[dict], int]:
    """Run tiled YOLO inference over *image* and merge detections across tiles.

    Args:
        image: Input image; converted to RGB before tiling.
        confidence: Minimum per-box confidence passed to the model.
        iou: IoU threshold used both by the model's own NMS and the cross-tile NMS.
        tile_size: Requested square tile edge in pixels; clamped to >= 320.
        overlap: Requested overlap between adjacent tiles; clamped to [0, tile_size - 32].
        max_det: Cap on detections, applied per tile and again after merging.

    Returns:
        Tuple of (detections with boxes in full-image coordinates, tile count).

    Raises:
        ValueError: If the image would produce more than MAX_TILES tiles.
        FileNotFoundError: Propagated from load_model() when the weights are missing.
    """
    model = load_model()
    names = _model_names(model)
    rgb_image = image.convert("RGB")
    # Clamp the user-supplied tiling parameters to safe ranges.
    safe_tile_size = max(320, int(tile_size))
    safe_overlap = max(0, min(int(overlap), safe_tile_size - 32))
    tiles = _iter_tiles(rgb_image, safe_tile_size, safe_overlap)

    # Refuse oversized inputs rather than letting CPU inference run unbounded.
    if len(tiles) > MAX_TILES:
        raise ValueError(
            f"Image too large for this CPU Space: {len(tiles)} tiles. "
            f"Resize the image or increase the tile size."
        )

    detections: list[dict] = []
    # Process tiles in small batches to bound memory use.
    for start in range(0, len(tiles), BATCH_SIZE):
        batch = tiles[start : start + BATCH_SIZE]
        batch_images = [tile for tile, _, _ in batch]
        results = model.predict(
            source=batch_images,
            conf=float(confidence),
            iou=float(iou),
            imgsz=safe_tile_size,
            max_det=int(max_det),
            verbose=False,
        )
        # Each result corresponds positionally to the tile it was run on.
        for result, (_, offset_x, offset_y) in zip(results, batch):
            boxes = getattr(result, "boxes", None)
            if boxes is None or len(boxes) == 0:
                continue
            xyxy = boxes.xyxy.cpu().numpy()
            confs = boxes.conf.cpu().numpy()
            classes = boxes.cls.cpu().numpy().astype(int)
            for box, score, class_id in zip(xyxy, confs, classes):
                x1, y1, x2, y2 = box.tolist()
                detections.append(
                    {
                        "label": names.get(int(class_id), f"class_{int(class_id)}"),
                        "class_id": int(class_id),
                        "confidence": float(score),
                        # Shift tile-local coordinates into full-image space.
                        "box": [
                            float(x1 + offset_x),
                            float(y1 + offset_y),
                            float(x2 + offset_x),
                            float(y2 + offset_y),
                        ],
                    }
                )

    # Cross-tile NMS removes duplicates from overlapping tiles, then the
    # confidence-sorted result is truncated to the requested cap.
    detections = _nms(detections, float(iou))
    detections = detections[: int(max_det)]
    return detections, len(tiles)
|
|
|
|
def _draw_detections(image: Image.Image, detections: list[dict]) -> Image.Image:
    """Return an RGB copy of *image* with boxes and confidence labels drawn."""
    canvas = image.convert("RGB").copy()
    painter = ImageDraw.Draw(canvas)
    font = ImageFont.load_default()
    # Scale the box stroke with image size so it stays visible on large inputs.
    stroke = max(2, round(max(canvas.size) / 420))

    for det in detections:
        x1, y1, x2, y2 = det["box"]
        caption = f"{det['label']} {det['confidence']:.2f}"
        painter.rectangle((x1, y1, x2, y2), outline=(255, 64, 48), width=stroke)
        left, top, right, bottom = painter.textbbox((x1, y1), caption, font=font)
        caption_w = right - left
        caption_h = bottom - top
        # Place the label above the box, clamped so it never leaves the canvas.
        anchor_y = max(0, y1 - caption_h - 6)
        painter.rectangle((x1, anchor_y, x1 + caption_w + 8, anchor_y + caption_h + 6), fill=(255, 64, 48))
        painter.text((x1 + 4, anchor_y + 3), caption, fill=(255, 255, 255), font=font)

    return canvas
|
|
|
|
| def _table_rows(detections: list[dict]) -> list[list[object]]: |
| rows: list[list[object]] = [] |
| for index, detection in enumerate(detections, start=1): |
| x1, y1, x2, y2 = detection["box"] |
| rows.append( |
| [ |
| index, |
| detection["label"], |
| round(float(detection["confidence"]), 4), |
| round(x1, 1), |
| round(y1, 1), |
| round(x2, 1), |
| round(y2, 1), |
| round(x2 - x1, 1), |
| round(y2 - y1, 1), |
| ] |
| ) |
| return rows |
|
|
|
|
def detect_boats(
    image: Image.Image | None,
    confidence: float,
    iou: float,
    tile_size: int,
    overlap: int,
    max_det: int,
) -> tuple[Image.Image | None, list[list[object]], str]:
    """Gradio callback: run tiled detection and build the annotated image, table, summary."""
    if image is None:
        return None, [], "Upload a satellite image to run detection."

    try:
        detections, tile_count = _predict_tiles(
            image,
            confidence=confidence,
            iou=iou,
            tile_size=tile_size,
            overlap=overlap,
            max_det=max_det,
        )
    # Broad catch is deliberate: surface any model/tiling failure in the UI
    # instead of crashing the app.
    except Exception as exc:
        return image, [], f"Inference error: {exc}"

    if detections:
        summary = f"{len(detections)} detection(s) above {confidence:.2f}. Tiles analyzed: {tile_count}."
    else:
        summary = f"No detections above {confidence:.2f}. Tiles analyzed: {tile_count}."
    return _draw_detections(image, detections), _table_rows(detections), summary
|
|
|
|
def _example_paths() -> list[list[str]]:
    """Collect up to 10 bundled example PNGs as single-element Gradio example rows."""
    found = sorted(EXAMPLES_DIR.glob("*.png"))[:10]
    return [[str(example)] for example in found]
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: input/controls on the left, annotated output and table on the
# right, with clickable example images below.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Vessel Detection") as demo:
    gr.Markdown(
        """
        # Vessel Detection

        Fine-tuned YOLOv8s model for detecting vessels in RGB satellite imagery.
        Upload a satellite image or select an example, then run detection.
        """
    )
    with gr.Row():
        # Left column: input image plus inference tuning controls.
        with gr.Column(scale=1):
            image_input = gr.Image(
                # Pre-load the demo image when it ships with the repo.
                value=str(MAIN_EXAMPLE_PATH) if MAIN_EXAMPLE_PATH.exists() else None,
                type="pil",
                label="Satellite image",
            )
            confidence_input = gr.Slider(0.01, 0.95, value=0.20, step=0.01, label="Confidence threshold")
            iou_input = gr.Slider(0.05, 0.90, value=0.45, step=0.05, label="IoU NMS")
            tile_size_input = gr.Slider(320, 1024, value=640, step=32, label="Tile size")
            overlap_input = gr.Slider(0, 256, value=96, step=16, label="Tile overlap")
            max_det_input = gr.Slider(1, 200, value=80, step=1, label="Max detections")
            run_button = gr.Button("Detect vessels", variant="primary")
        # Right column: annotated image, textual summary, and detection table.
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Annotated image")
            summary_output = gr.Markdown()
            table_output = gr.Dataframe(
                headers=["#", "label", "confidence", "x1", "y1", "x2", "y2", "width", "height"],
                datatype=["number", "str", "number", "number", "number", "number", "number", "number", "number"],
                label="Detections",
            )

    # Wire the button to the detection callback.
    run_button.click(
        fn=detect_boats,
        inputs=[image_input, confidence_input, iou_input, tile_size_input, overlap_input, max_det_input],
        outputs=[output_image, table_output, summary_output],
    )

    # Clicking an example loads it into the input image widget.
    gr.Examples(
        examples=_example_paths(),
        inputs=[image_input],
        label="Example images",
    )
|
|
|
|
if __name__ == "__main__":
    # Honor optional host/port overrides from the environment (e.g. Docker/Spaces).
    launch_kwargs: dict[str, object] = {}
    host = os.environ.get("GRADIO_SERVER_NAME")
    if host:
        launch_kwargs["server_name"] = host
    port = os.environ.get("GRADIO_SERVER_PORT")
    if port:
        launch_kwargs["server_port"] = int(port)
    demo.launch(**launch_kwargs)
|
|