File size: 2,361 Bytes
2e01643
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb55219
 
 
 
 
 
 
e3951d6
cb55219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch, subprocess, sys

print("CUDA available:", torch.cuda.is_available())

# Install detectron2 from its GitHub source on first start if it is not
# already importable in this environment.
try:
    import detectron2
except ImportError:  # only "not installed" should trigger an install
    print("Installing detectron2...")
    # check_call raises CalledProcessError if pip fails, aborting startup.
    subprocess.check_call([
        sys.executable, "-m", "pip", "install",
        "git+https://github.com/facebookresearch/detectron2.git"
    ])
    # The import system caches directory listings; invalidate them so the
    # package pip just installed is visible to this running interpreter.
    import importlib
    importlib.invalidate_caches()
    import detectron2

print("Detectron2 ready")
import gradio as gr
import cv2
import numpy as np
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog


# Load the fine-tuned model once at import time; every request reuses it.
cfg = get_cfg()
cfg.merge_from_file("config.yaml")
cfg.MODEL.WEIGHTS = "model_final.pth"
# DefaultPredictor clones cfg, and the ROI heads read this threshold only when
# the model is built, so it is a hard floor: detections scoring below it are
# dropped before detect_ships ever sees them. Keep it at the UI slider's
# minimum (0.1) so the slider can actually reveal low-confidence ships; the
# per-request threshold is applied afterwards by filtering in detect_ships.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.1
cfg.MODEL.DEVICE = "cpu"   # HF Spaces free tier is CPU

# Single-class label map used by the Visualizer when drawing predictions.
MetadataCatalog.get("__unused").set(thing_classes=["ship"])
predictor = DefaultPredictor(cfg)

def detect_ships(image, confidence_threshold):
    """Run ship detection on an uploaded image and render the results.

    Args:
        image: The uploaded image as delivered by ``gr.Image(type="pil")``,
            or ``None`` when the user submits without uploading anything.
        confidence_threshold: Minimum detection score for a box to be drawn
            and reported.

    Returns:
        A ``(result_img, info)`` tuple: the annotated image produced by the
        Visualizer (``None`` when no image was given) and a text summary of
        the detection count and confidence scores.
    """
    if image is None:
        return None, "Please upload an image."

    # NOTE: the threshold is applied by filtering below. Writing it into the
    # global cfg would have no effect (DefaultPredictor cloned cfg at build
    # time) and would mutate state shared by concurrent requests.

    # Normalise PIL modes such as "L" (grayscale SAR) or "RGBA" to 3-channel
    # RGB so the RGB -> BGR conversion below is valid.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    img_rgb = np.array(image)
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    instances = predictor(img_bgr)["instances"].to("cpu")
    # Only detections above the predictor's build-time floor can appear here.
    instances = instances[instances.scores >= confidence_threshold]

    visualizer = Visualizer(img_rgb, metadata=MetadataCatalog.get("__unused"),
                            scale=1.0, instance_mode=ColorMode.IMAGE)
    result_img = visualizer.draw_instance_predictions(instances).get_image()

    scores = instances.scores.tolist()
    info = f"Detected {len(scores)} ship(s)\n"
    if scores:
        info += "Confidence scores: " + ", ".join(f"{s:.2f}" for s in scores)

    return result_img, info

demo = gr.Interface(
    fn=detect_ships,
    # Passed positionally to detect_ships(image, confidence_threshold).
    inputs=[
        gr.Image(type="pil", label="Upload SAR Image"),
        gr.Slider(
            minimum=0.1,
            maximum=0.9,
            value=0.5,
            step=0.05,
            label="Confidence Threshold",
        ),
    ],
    # Filled from the (result_img, info) tuple that detect_ships returns.
    outputs=[
        gr.Image(type="numpy", label="Detection Result"),
        gr.Textbox(label="Detection Info"),
    ],
    examples=[],
    title="🚢 HRSID Ship Detection",
    description=(
        "Upload a SAR image to detect ships using Faster R-CNN "
        "trained on HRSID dataset."
    ),
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()