"""Gradio app: ship detection in SAR imagery (images and videos).

Runs a Detectron2 Faster R-CNN (ResNet-101) model trained on the HRSID
dataset. Model weights and config are downloaded from the Hugging Face Hub
at startup; a single predictor instance is shared by the image tab and the
video tab, with per-request confidence thresholds applied by filtering
detection scores (so the model is loaded exactly once).
"""

import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from huggingface_hub import hf_hub_download

# ── Model loading ────────────────────────────────────────────────────────────
REPO_ID = os.getenv("MODEL_REPO_ID", "PUSHPENDAR/hrsid-ship-detection")
CACHE_DIR = "/app/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Lowest threshold selectable in the UI sliders. The predictor is built once
# at this threshold; run_inference() then filters per request, which makes a
# per-request predictor (and a full model reload per call) unnecessary.
MIN_THRESHOLD = 0.1

print("Downloading model files...")
_HF_TOKEN = os.getenv("HF_TOKEN")  # uses secret if set, else None (public repos)
MODEL_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="model_final.pth",
    cache_dir=CACHE_DIR,
    token=_HF_TOKEN,
)
CONFIG_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="config.yaml",
    cache_dir=CACHE_DIR,
    token=_HF_TOKEN,
)
print(f"Model: {MODEL_PATH} ✅")
print(f"Config: {CONFIG_PATH} ✅")

print("Loading Faster R-CNN model...")
_cfg = get_cfg()
_cfg.merge_from_file(CONFIG_PATH)
_cfg.MODEL.WEIGHTS = MODEL_PATH
_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MIN_THRESHOLD
_cfg.MODEL.DEVICE = "cpu"
MetadataCatalog.get("__unused").set(thing_classes=["ship"])
# Built ONCE — constructing DefaultPredictor re-loads the full ResNet-101
# weights, which previously happened on every request and every video frame.
PREDICTOR = DefaultPredictor(_cfg)
print("Model loaded ✅")


# ── Helpers ──────────────────────────────────────────────────────────────────
def get_predictor(confidence_threshold: float) -> DefaultPredictor:
    """Backward-compatible accessor for the shared predictor.

    The previous implementation deep-copied the config and built a brand-new
    ``DefaultPredictor`` per call, re-loading the model weights each time —
    prohibitively slow when called per video frame. The shared ``PREDICTOR``
    is built once at ``MIN_THRESHOLD``; callers apply a stricter threshold by
    filtering ``instances.scores``, as :func:`run_inference` does.
    """
    return PREDICTOR


def run_inference(img_bgr: np.ndarray, confidence_threshold: float):
    """Run detection on a single BGR frame.

    Args:
        img_bgr: H×W×3 uint8 image in OpenCV BGR channel order.
        confidence_threshold: minimum score; weaker detections are dropped.

    Returns:
        Tuple ``(result_bgr, instances)`` — the annotated frame in BGR (ready
        for ``cv2.VideoWriter``) and the kept detectron2 ``Instances``.
    """
    outputs = PREDICTOR(img_bgr)
    instances = outputs["instances"].to("cpu")
    instances = instances[instances.scores >= confidence_threshold]

    metadata = MetadataCatalog.get("__unused")
    # Visualizer expects RGB, hence the channel flip on input.
    v = Visualizer(
        img_bgr[:, :, ::-1],
        metadata=metadata,
        scale=1.0,
        instance_mode=ColorMode.IMAGE,
    )
    out = v.draw_instance_predictions(instances)
    result_rgb = out.get_image()  # H×W×3 RGB
    result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
    return result_bgr, instances


def build_info(instances) -> str:
    """Format a human-readable summary (count, scores, boxes) for *instances*."""
    num = len(instances)
    scores = instances.scores.tolist()
    info = f"✅ Detected {num} ship(s)\n"
    if scores:
        info += "Confidence scores: " + ", ".join([f"{s:.2f}" for s in scores])
        if hasattr(instances, "pred_boxes"):
            boxes = instances.pred_boxes.tensor.tolist()
            info += "\n\nBounding boxes (x1,y1,x2,y2):\n"
            for i, (box, score) in enumerate(zip(boxes, scores)):
                x1, y1, x2, y2 = [int(c) for c in box]
                info += f"  Ship {i+1}: [{x1},{y1},{x2},{y2}] conf={score:.2f}\n"
    else:
        info += "No ships detected above threshold."
    return info


# ── Image tab ────────────────────────────────────────────────────────────────
def detect_ships_image(image, confidence_threshold):
    """Detect ships on a single PIL image; returns (RGB ndarray, info text)."""
    if image is None:
        return None, "Please upload an image."
    img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    result_bgr, inst = run_inference(img_bgr, confidence_threshold)
    # Gradio's Image component expects RGB.
    result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
    return result_rgb, build_info(inst)


# ── Video tab ────────────────────────────────────────────────────────────────
def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress()):
    """Detect ships on every frame of a video; returns (output path, summary)."""
    if video_path is None:
        return None, "Please upload a video."

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Could not open video file."

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Some containers report 0 or NaN for FPS (`fps or 25` misses NaN,
    # which is truthy) — fall back to a sane default.
    if not fps or not fps > 0:
        fps = 25.0
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # delete=False: Gradio serves the file after this function returns.
    out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    out_path = out_file.name
    out_file.close()

    # NOTE(review): mp4v may not play inline in every browser; H.264 ("avc1")
    # would be safer but requires an OpenCV build with that encoder.
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))

    frame_idx = 0
    total_ships = 0
    max_per_frame = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            result_bgr, inst = run_inference(frame, confidence_threshold)
            writer.write(result_bgr)
            n = len(inst)
            total_ships += n
            max_per_frame = max(max_per_frame, n)
            frame_idx += 1
            if total_frames > 0:
                progress(
                    frame_idx / total_frames,
                    desc=f"Processing frame {frame_idx}/{total_frames}",
                )
    finally:
        # Release even if inference fails mid-video so the capture/writer
        # handles don't leak and the partial file is flushed.
        cap.release()
        writer.release()

    info = (
        f"✅ Video processed: {frame_idx} frames\n"
        f"Total ship detections across all frames: {total_ships}\n"
        f"Peak ships in a single frame: {max_per_frame}\n"
        f"FPS: {fps:.1f} | Resolution: {w}×{h}"
    )
    return out_path, info


# ── UI ───────────────────────────────────────────────────────────────────────
with gr.Blocks(title="🚢 HRSID Ship Detection") as demo:
    gr.Markdown("# 🚢 HRSID Ship Detection")
    gr.Markdown(
        "Detect ships in SAR images **or videos** using "
        "Faster R-CNN with ResNet-101, trained on the HRSID dataset."
    )
    with gr.Tabs():
        with gr.Tab("🖼️ Image Detection"):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type="pil", label="Upload SAR Image")
                    img_thresh = gr.Slider(
                        0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold"
                    )
                    img_btn = gr.Button("Detect Ships", variant="primary")
                with gr.Column():
                    img_output = gr.Image(type="numpy", label="Detection Result")
                    img_info = gr.Textbox(label="Detection Info", lines=10)
            img_btn.click(
                fn=detect_ships_image,
                inputs=[img_input, img_thresh],
                outputs=[img_output, img_info],
            )
        with gr.Tab("🎥 Video Detection"):
            gr.Markdown(
                "> ⚠️ CPU inference is slow. Short clips (< 30 s) are recommended."
            )
            with gr.Row():
                with gr.Column():
                    vid_input = gr.Video(label="Upload SAR Video")
                    vid_thresh = gr.Slider(
                        0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold"
                    )
                    vid_btn = gr.Button("Detect Ships in Video", variant="primary")
                with gr.Column():
                    vid_output = gr.Video(label="Detection Result Video")
                    vid_info = gr.Textbox(label="Detection Summary", lines=8)
            vid_btn.click(
                fn=detect_ships_video,
                inputs=[vid_input, vid_thresh],
                outputs=[vid_output, vid_info],
            )

if __name__ == "__main__":
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)  # NO share=True