# SHIP Space — app.py
# Author: PUSHPENDAR — commit cab7fb3 ("Update app.py", verified)
# ── Legacy single-tab implementation (kept for reference; superseded by the tabbed app below) ──
# import gradio as gr
# import cv2
# import numpy as np
# from detectron2.config import get_cfg
# from detectron2.engine import DefaultPredictor
# from detectron2.utils.visualizer import Visualizer, ColorMode
# from detectron2.data import MetadataCatalog
# from huggingface_hub import hf_hub_download
# import os
# REPO_ID = os.getenv("MODEL_REPO_ID", "PUSHPENDAR/hrsid-ship-detection")
# os.makedirs("/app/hf_cache", exist_ok=True)
# print("Downloading model files...")
# MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="model_final.pth", cache_dir="/app/hf_cache")
# CONFIG_PATH = hf_hub_download(repo_id=REPO_ID, filename="config.yaml", cache_dir="/app/hf_cache")
# print(f"Model: {MODEL_PATH} βœ…")
# print(f"Config: {CONFIG_PATH} βœ…")
# print("Loading Faster R-CNN model...")
# cfg = get_cfg()
# cfg.merge_from_file(CONFIG_PATH)
# cfg.MODEL.WEIGHTS = MODEL_PATH
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# cfg.MODEL.DEVICE = "cpu"
# MetadataCatalog.get("__unused").set(thing_classes=["ship"])
# predictor = DefaultPredictor(cfg)
# print("Model loaded βœ…")
# def detect_ships(image, confidence_threshold):
# if image is None:
# return None, "Please upload an image."
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
# img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
# outputs = predictor(img_bgr)
# instances = outputs["instances"].to("cpu")
# keep = instances.scores >= confidence_threshold
# instances = instances[keep]
# metadata = MetadataCatalog.get("__unused")
# v = Visualizer(img_bgr[:, :, ::-1], metadata=metadata, scale=1.0, instance_mode=ColorMode.IMAGE)
# out = v.draw_instance_predictions(instances)
# result_img = out.get_image()
# num_ships = len(instances)
# scores = instances.scores.tolist()
# info = f"βœ… Detected {num_ships} ship(s)\n"
# if scores:
# info += "Confidence scores: " + ", ".join([f"{s:.2f}" for s in scores])
# if hasattr(instances, "pred_boxes"):
# boxes = instances.pred_boxes.tensor.tolist()
# info += "\n\nBounding boxes (x1,y1,x2,y2):\n"
# for i, (box, score) in enumerate(zip(boxes, scores)):
# x1, y1, x2, y2 = [int(v) for v in box]
# info += f" Ship {i+1}: [{x1},{y1},{x2},{y2}] conf={score:.2f}\n"
# else:
# info += "No ships detected above threshold."
# return result_img, info
# with gr.Blocks(title="🚒 HRSID Ship Detection") as demo:
# gr.Markdown("# 🚒 HRSID Ship Detection")
# gr.Markdown("Upload a SAR image to detect ships using Faster R-CNN with ResNet-101, trained on HRSID dataset.")
# with gr.Row():
# with gr.Column():
# image_input = gr.Image(type="pil", label="Upload SAR Image")
# threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold")
# btn = gr.Button("Detect Ships", variant="primary")
# with gr.Column():
# image_output = gr.Image(type="numpy", label="Detection Result")
# info_output = gr.Textbox(label="Detection Info", lines=10)
# btn.click(fn=detect_ships, inputs=[image_input, threshold], outputs=[image_output, info_output])
# if __name__ == "__main__":
# demo.launch(server_name="0.0.0.0", server_port=7860)
import os
import tempfile
from copy import deepcopy
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from huggingface_hub import hf_hub_download
# ── Model loading ────────────────────────────────────────────────────────────
REPO_ID = os.getenv("MODEL_REPO_ID", "PUSHPENDAR/hrsid-ship-detection")
os.makedirs("/app/hf_cache", exist_ok=True)
print("Downloading model files...")
MODEL_PATH = hf_hub_download(
repo_id=REPO_ID,
filename="model_final.pth",
cache_dir="/app/hf_cache",
token=os.getenv("HF_TOKEN"), # uses secret if set, else None (public repos)
)
CONFIG_PATH = hf_hub_download(
repo_id=REPO_ID,
filename="config.yaml",
cache_dir="/app/hf_cache",
token=os.getenv("HF_TOKEN"),
)
print(f"Model: {MODEL_PATH} βœ…")
print(f"Config: {CONFIG_PATH} βœ…")
print("Loading Faster R-CNN model...")
_base_cfg = get_cfg()
_base_cfg.merge_from_file(CONFIG_PATH)
_base_cfg.MODEL.WEIGHTS = MODEL_PATH
_base_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
_base_cfg.MODEL.DEVICE = "cpu"
_base_cfg.freeze() # make it immutable so we always deepcopy before mutating
MetadataCatalog.get("__unused").set(thing_classes=["ship"])
print("Model loaded βœ…")
# ── Helpers ──────────────────────────────────────────────────────────────────
@lru_cache(maxsize=1)
def get_predictor(confidence_threshold: float) -> DefaultPredictor:
    """Return a predictor configured with the requested score threshold.

    The result is memoized (maxsize=1): repeated calls with the same
    threshold — notably every frame of a video — reuse one model instance
    instead of re-instantiating DefaultPredictor (a full weight reload from
    disk) on every call, which dominated runtime on CPU. Changing the
    threshold evicts the single cached entry, so at most one model is held
    in memory at a time.

    deepcopy of the frozen base cfg avoids mutating shared global state
    across concurrent requests.
    """
    cfg = deepcopy(_base_cfg)
    cfg.defrost()
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    return DefaultPredictor(cfg)
def run_inference(img_bgr: np.ndarray, confidence_threshold: float):
    """Detect ships on a single BGR frame.

    Returns ``(annotated_bgr, detections)``: the frame with predictions
    drawn (BGR, directly writable via cv2/VideoWriter) and the CPU
    Instances filtered to scores >= threshold.
    """
    predictor = get_predictor(confidence_threshold)
    detections = predictor(img_bgr)["instances"].to("cpu")
    detections = detections[detections.scores >= confidence_threshold]
    viz = Visualizer(
        img_bgr[:, :, ::-1],  # Visualizer wants RGB; flip channel order
        metadata=MetadataCatalog.get("__unused"),
        scale=1.0,
        instance_mode=ColorMode.IMAGE,
    )
    annotated_rgb = viz.draw_instance_predictions(detections).get_image()
    annotated_bgr = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)
    return annotated_bgr, detections
def build_info(instances) -> str:
    """Render a human-readable summary of the detections.

    Lists the detection count, per-detection confidence scores, and — when
    box predictions are present — integer-truncated (x1,y1,x2,y2) corners.
    """
    scores = instances.scores.tolist()
    parts = [f"✅ Detected {len(instances)} ship(s)\n"]
    if not scores:
        parts.append("No ships detected above threshold.")
        return "".join(parts)
    parts.append("Confidence scores: " + ", ".join(f"{s:.2f}" for s in scores))
    if hasattr(instances, "pred_boxes"):
        parts.append("\n\nBounding boxes (x1,y1,x2,y2):\n")
        boxes = instances.pred_boxes.tensor.tolist()
        for idx, (box, score) in enumerate(zip(boxes, scores), start=1):
            x1, y1, x2, y2 = (int(c) for c in box)
            parts.append(f"  Ship {idx}: [{x1},{y1},{x2},{y2}] conf={score:.2f}\n")
    return "".join(parts)
# ── Image tab ────────────────────────────────────────────────────────────────
def detect_ships_image(image, confidence_threshold):
    """Gradio handler for the image tab: PIL image in, (RGB array, summary text) out."""
    if image is None:
        return None, "Please upload an image."
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    annotated_bgr, detections = run_inference(bgr, confidence_threshold)
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
    return annotated_rgb, build_info(detections)
# ── Video tab ────────────────────────────────────────────────────────────────
def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress()):
    """Gradio handler for the video tab.

    Reads the uploaded video frame by frame, runs detection on each frame,
    and writes an annotated copy to a temporary .mp4.

    Returns:
        (output_path, summary_text) on success, or (None, error_text) when
        no video was supplied or the file cannot be opened.

    Note: progress=gr.Progress() is gradio's documented injection pattern
    (inspected by the framework), not an accidental mutable default.
    """
    if video_path is None:
        return None, "Please upload a video."
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Could not open video file."
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 25  # some containers report 0 fps
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # delete=False: gradio needs the path to survive after we close/return it.
    out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    out_path = out_file.name
    out_file.close()
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
    frame_idx = 0
    total_ships = 0
    max_per_frame = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            result_bgr, inst = run_inference(frame, confidence_threshold)
            writer.write(result_bgr)
            n = len(inst)
            total_ships += n
            max_per_frame = max(max_per_frame, n)
            frame_idx += 1
            if total_frames > 0:
                progress(
                    frame_idx / total_frames,
                    desc=f"Processing frame {frame_idx}/{total_frames}",
                )
    finally:
        # Always release handles: without this, an exception mid-video leaks
        # the capture and leaves the writer's file unfinalized.
        cap.release()
        writer.release()
    info = (
        f"✅ Video processed: {frame_idx} frames\n"
        f"Total ship detections across all frames: {total_ships}\n"
        f"Peak ships in a single frame: {max_per_frame}\n"
        f"FPS: {fps:.1f} | Resolution: {w}×{h}"
    )
    return out_path, info
# ── UI ───────────────────────────────────────────────────────────────────────
# Declarative UI: two tabs (image / video) sharing the same detection backend.
with gr.Blocks(title="🚢 HRSID Ship Detection") as demo:
    gr.Markdown("# 🚢 HRSID Ship Detection")
    gr.Markdown(
        "Detect ships in SAR images **or videos** using "
        "Faster R-CNN with ResNet-101, trained on the HRSID dataset."
    )
    with gr.Tabs():
        # ── Tab 1: single-image detection ──
        with gr.Tab("🖼️ Image Detection"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil", label="Upload SAR Image")
                    # Slider range mirrors the post-filter in run_inference.
                    img_thresh = gr.Slider(
                        0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold"
                    )
                    img_btn = gr.Button("Detect Ships", variant="primary")
                with gr.Column():
                    img_output = gr.Image(type="numpy", label="Detection Result")
                    img_info = gr.Textbox(label="Detection Info", lines=10)
            img_btn.click(
                fn=detect_ships_image,
                inputs=[img_input, img_thresh],
                outputs=[img_output, img_info],
            )
        # ── Tab 2: frame-by-frame video detection ──
        with gr.Tab("🎥 Video Detection"):
            gr.Markdown(
                "> ⚠️ CPU inference is slow. Short clips (< 30 s) are recommended."
            )
            with gr.Row():
                with gr.Column():
                    vid_input = gr.Video(label="Upload SAR Video")
                    vid_thresh = gr.Slider(
                        0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold"
                    )
                    vid_btn = gr.Button("Detect Ships in Video", variant="primary")
                with gr.Column():
                    vid_output = gr.Video(label="Detection Result Video")
                    vid_info = gr.Textbox(label="Detection Summary", lines=8)
            vid_btn.click(
                fn=detect_ships_video,
                inputs=[vid_input, vid_thresh],
                outputs=[vid_output, vid_info],
            )
if __name__ == "__main__":
    # Queue requests so slow CPU inference jobs don't run concurrently.
    demo.queue()
    # 0.0.0.0:7860 is the standard HF Spaces binding.
    demo.launch(server_name="0.0.0.0", server_port=7860)  # NO share=True