# Hugging Face Space app (Space status at capture time: Paused)
| # import gradio as gr | |
| # import cv2 | |
| # import numpy as np | |
| # from detectron2.config import get_cfg | |
| # from detectron2.engine import DefaultPredictor | |
| # from detectron2.utils.visualizer import Visualizer, ColorMode | |
| # from detectron2.data import MetadataCatalog | |
| # from huggingface_hub import hf_hub_download | |
| # import os | |
| # REPO_ID = os.getenv("MODEL_REPO_ID", "PUSHPENDAR/hrsid-ship-detection") | |
| # os.makedirs("/app/hf_cache", exist_ok=True) | |
| # print("Downloading model files...") | |
| # MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="model_final.pth", cache_dir="/app/hf_cache") | |
| # CONFIG_PATH = hf_hub_download(repo_id=REPO_ID, filename="config.yaml", cache_dir="/app/hf_cache") | |
| # print(f"Model: {MODEL_PATH} β ") | |
| # print(f"Config: {CONFIG_PATH} β ") | |
| # print("Loading Faster R-CNN model...") | |
| # cfg = get_cfg() | |
| # cfg.merge_from_file(CONFIG_PATH) | |
| # cfg.MODEL.WEIGHTS = MODEL_PATH | |
| # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 | |
| # cfg.MODEL.DEVICE = "cpu" | |
| # MetadataCatalog.get("__unused").set(thing_classes=["ship"]) | |
| # predictor = DefaultPredictor(cfg) | |
| # print("Model loaded β ") | |
| # def detect_ships(image, confidence_threshold): | |
| # if image is None: | |
| # return None, "Please upload an image." | |
| # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold | |
| # img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) | |
| # outputs = predictor(img_bgr) | |
| # instances = outputs["instances"].to("cpu") | |
| # keep = instances.scores >= confidence_threshold | |
| # instances = instances[keep] | |
| # metadata = MetadataCatalog.get("__unused") | |
| # v = Visualizer(img_bgr[:, :, ::-1], metadata=metadata, scale=1.0, instance_mode=ColorMode.IMAGE) | |
| # out = v.draw_instance_predictions(instances) | |
| # result_img = out.get_image() | |
| # num_ships = len(instances) | |
| # scores = instances.scores.tolist() | |
| # info = f"β Detected {num_ships} ship(s)\n" | |
| # if scores: | |
| # info += "Confidence scores: " + ", ".join([f"{s:.2f}" for s in scores]) | |
| # if hasattr(instances, "pred_boxes"): | |
| # boxes = instances.pred_boxes.tensor.tolist() | |
| # info += "\n\nBounding boxes (x1,y1,x2,y2):\n" | |
| # for i, (box, score) in enumerate(zip(boxes, scores)): | |
| # x1, y1, x2, y2 = [int(v) for v in box] | |
| # info += f" Ship {i+1}: [{x1},{y1},{x2},{y2}] conf={score:.2f}\n" | |
| # else: | |
| # info += "No ships detected above threshold." | |
| # return result_img, info | |
| # with gr.Blocks(title="π’ HRSID Ship Detection") as demo: | |
| # gr.Markdown("# π’ HRSID Ship Detection") | |
| # gr.Markdown("Upload a SAR image to detect ships using Faster R-CNN with ResNet-101, trained on HRSID dataset.") | |
| # with gr.Row(): | |
| # with gr.Column(): | |
| # image_input = gr.Image(type="pil", label="Upload SAR Image") | |
| # threshold = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold") | |
| # btn = gr.Button("Detect Ships", variant="primary") | |
| # with gr.Column(): | |
| # image_output = gr.Image(type="numpy", label="Detection Result") | |
| # info_output = gr.Textbox(label="Detection Info", lines=10) | |
| # btn.click(fn=detect_ships, inputs=[image_input, threshold], outputs=[image_output, info_output]) | |
| # if __name__ == "__main__": | |
| # demo.launch(server_name="0.0.0.0", server_port=7860) | |
import os
import tempfile
from copy import deepcopy
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from huggingface_hub import hf_hub_download
# ── Model loading ────────────────────────────────────────────────────────────
# Runs once at import time: downloads trained weights + config from the HF Hub,
# builds a frozen base detectron2 config, and registers the single-class
# ("ship") metadata used by the Visualizer.
REPO_ID = os.getenv("MODEL_REPO_ID", "PUSHPENDAR/hrsid-ship-detection")
os.makedirs("/app/hf_cache", exist_ok=True)  # container-local HF cache dir
print("Downloading model files...")
MODEL_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="model_final.pth",
    cache_dir="/app/hf_cache",
    token=os.getenv("HF_TOKEN"),  # uses secret if set, else None (public repos)
)
CONFIG_PATH = hf_hub_download(
    repo_id=REPO_ID,
    filename="config.yaml",
    cache_dir="/app/hf_cache",
    token=os.getenv("HF_TOKEN"),
)
# NOTE(review): the "β " in these messages looks like a mojibake-mangled emoji
# (likely a check mark) — verify the file's encoding.
print(f"Model: {MODEL_PATH} β ")
print(f"Config: {CONFIG_PATH} β ")
print("Loading Faster R-CNN model...")
_base_cfg = get_cfg()
_base_cfg.merge_from_file(CONFIG_PATH)
_base_cfg.MODEL.WEIGHTS = MODEL_PATH
_base_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # default; overridden per request
_base_cfg.MODEL.DEVICE = "cpu"  # Spaces CPU hardware — no CUDA
_base_cfg.freeze()  # make it immutable so we always deepcopy before mutating
MetadataCatalog.get("__unused").set(thing_classes=["ship"])
# NOTE(review): despite the message, the predictor itself is built lazily in
# get_predictor(); only config/metadata are ready at this point.
print("Model loaded β ")
# ── Helpers ──────────────────────────────────────────────────────────────────
@lru_cache(maxsize=2)
def get_predictor(confidence_threshold: float) -> DefaultPredictor:
    """Return a predictor configured with the requested score threshold.

    Fix: constructing a ``DefaultPredictor`` loads the full ResNet-101
    weights from disk, which the previous version repeated on *every* call —
    i.e. once per request and once per video frame. Results are now memoised
    per threshold; ``maxsize=2`` bounds memory (one model is ~hundreds of MB)
    while still reusing the predictor across all frames of a video.

    The ``deepcopy`` of the frozen base cfg avoids mutating shared global
    state across concurrent requests.
    """
    cfg = deepcopy(_base_cfg)
    cfg.defrost()
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    return DefaultPredictor(cfg)
def run_inference(img_bgr: np.ndarray, confidence_threshold: float):
    """Detect ships on a single BGR frame.

    Returns ``(annotated_bgr, detections)``: the frame with predictions drawn
    on it (BGR, ready for cv2/VideoWriter) and the kept instances on CPU.
    """
    predictor = get_predictor(confidence_threshold)
    detections = predictor(img_bgr)["instances"].to("cpu")
    # Drop anything below the requested score, whatever the cfg threshold was.
    detections = detections[detections.scores >= confidence_threshold]
    visualizer = Visualizer(
        img_bgr[:, :, ::-1],  # Visualizer wants RGB; reverse the channel axis
        metadata=MetadataCatalog.get("__unused"),
        scale=1.0,
        instance_mode=ColorMode.IMAGE,
    )
    drawn = visualizer.draw_instance_predictions(detections)
    annotated_bgr = cv2.cvtColor(drawn.get_image(), cv2.COLOR_RGB2BGR)
    return annotated_bgr, detections
def build_info(instances) -> str:
    """Format a human-readable summary of the detections in *instances*."""
    scores = instances.scores.tolist()
    parts = [f"β Detected {len(instances)} ship(s)\n"]
    if not scores:
        parts.append("No ships detected above threshold.")
        return "".join(parts)
    parts.append("Confidence scores: " + ", ".join(f"{s:.2f}" for s in scores))
    if hasattr(instances, "pred_boxes"):
        parts.append("\n\nBounding boxes (x1,y1,x2,y2):\n")
        box_list = instances.pred_boxes.tensor.tolist()
        for i, (box, score) in enumerate(zip(box_list, scores)):
            x1, y1, x2, y2 = (int(c) for c in box)
            parts.append(f" Ship {i+1}: [{x1},{y1},{x2},{y2}] conf={score:.2f}\n")
    return "".join(parts)
# ── Image tab ────────────────────────────────────────────────────────────────
def detect_ships_image(image, confidence_threshold):
    """Gradio handler for the image tab: PIL image in, (RGB array, info) out."""
    if image is None:
        return None, "Please upload an image."
    frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    annotated_bgr, detections = run_inference(frame_bgr, confidence_threshold)
    # Gradio's Image output expects RGB.
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB), build_info(detections)
# ── Video tab ────────────────────────────────────────────────────────────────
def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress()):
    """Gradio handler for the video tab.

    Runs per-frame ship detection, writes an annotated mp4 to a temp file and
    returns ``(output_path, summary_text)``.

    Fix: the capture and writer are now released in a ``finally`` block, so an
    exception mid-frame (e.g. inference failure) no longer leaks the
    VideoCapture handle or leaves the output file unfinalised.
    """
    if video_path is None:
        return None, "Please upload a video."
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, "Could not open video file."
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 25  # some containers report 0 fps
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # delete=False: the file must outlive this handle so Gradio can serve it.
    out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    out_path = out_file.name
    out_file.close()
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
    frame_idx = 0
    total_ships = 0
    max_per_frame = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            result_bgr, inst = run_inference(frame, confidence_threshold)
            writer.write(result_bgr)
            n = len(inst)
            total_ships += n
            max_per_frame = max(max_per_frame, n)
            frame_idx += 1
            if total_frames > 0:  # guard: frame count can be unknown/0
                progress(
                    frame_idx / total_frames,
                    desc=f"Processing frame {frame_idx}/{total_frames}",
                )
    finally:
        cap.release()
        writer.release()
    # NOTE(review): "β" / "Γ" below look like mojibake-mangled symbols (check
    # mark / multiplication sign) — verify the file's encoding.
    info = (
        f"β Video processed: {frame_idx} frames\n"
        f"Total ship detections across all frames: {total_ships}\n"
        f"Peak ships in a single frame: {max_per_frame}\n"
        f"FPS: {fps:.1f} | Resolution: {w}Γ{h}"
    )
    return out_path, info
# ── UI ───────────────────────────────────────────────────────────────────────
# Two-tab Gradio interface: single-image detection and per-frame video
# detection. NOTE(review): labels below contain emoji that appear
# mojibake-mangled in this dump ("π’", "πΌοΈ", "π₯", "β οΈ") — verify encoding;
# they are runtime strings and are left byte-identical here.
with gr.Blocks(title="π’ HRSID Ship Detection") as demo:
    gr.Markdown("# π’ HRSID Ship Detection")
    gr.Markdown(
        "Detect ships in SAR images **or videos** using "
        "Faster R-CNN with ResNet-101, trained on the HRSID dataset."
    )
    with gr.Tabs():
        with gr.Tab("πΌοΈ Image Detection"):
            with gr.Row():
                with gr.Column():
                    # Left column: inputs
                    img_input = gr.Image(type="pil", label="Upload SAR Image")
                    img_thresh = gr.Slider(
                        0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold"
                    )
                    img_btn = gr.Button("Detect Ships", variant="primary")
                with gr.Column():
                    # Right column: annotated image + textual report
                    img_output = gr.Image(type="numpy", label="Detection Result")
                    img_info = gr.Textbox(label="Detection Info", lines=10)
            img_btn.click(
                fn=detect_ships_image,
                inputs=[img_input, img_thresh],
                outputs=[img_output, img_info],
            )
        with gr.Tab("π₯ Video Detection"):
            gr.Markdown(
                "> β οΈ CPU inference is slow. Short clips (< 30 s) are recommended."
            )
            with gr.Row():
                with gr.Column():
                    # Left column: inputs
                    vid_input = gr.Video(label="Upload SAR Video")
                    vid_thresh = gr.Slider(
                        0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold"
                    )
                    vid_btn = gr.Button("Detect Ships in Video", variant="primary")
                with gr.Column():
                    # Right column: annotated video + summary
                    vid_output = gr.Video(label="Detection Result Video")
                    vid_info = gr.Textbox(label="Detection Summary", lines=8)
            vid_btn.click(
                fn=detect_ships_video,
                inputs=[vid_input, vid_thresh],
                outputs=[vid_output, vid_info],
            )

if __name__ == "__main__":
    # queue() enables the progress callback and serialises requests — the
    # CPU-bound model cannot serve concurrent inferences efficiently anyway.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)  # NO share=True