PUSHPENDAR committed on
Commit
cab7fb3
·
verified ·
1 Parent(s): 11dcc48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -54
app.py CHANGED
@@ -80,62 +80,89 @@
80
 
81
  # if __name__ == "__main__":
82
  # demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
83
 
84
- import gradio as gr
85
  import cv2
 
86
  import numpy as np
87
- import tempfile
88
- import os
89
  from detectron2.config import get_cfg
90
- from detectron2.engine import DefaultPredictor
91
- from detectron2.utils.visualizer import Visualizer, ColorMode
92
  from detectron2.data import MetadataCatalog
 
 
93
  from huggingface_hub import hf_hub_download
94
 
 
 
95
  REPO_ID = os.getenv("MODEL_REPO_ID", "PUSHPENDAR/hrsid-ship-detection")
96
 
97
  os.makedirs("/app/hf_cache", exist_ok=True)
98
 
99
  print("Downloading model files...")
100
- MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="model_final.pth", cache_dir="/app/hf_cache")
101
- CONFIG_PATH = hf_hub_download(repo_id=REPO_ID, filename="config.yaml", cache_dir="/app/hf_cache")
102
- print(f"Model: {MODEL_PATH} βœ…")
 
 
 
 
 
 
 
 
 
 
103
  print(f"Config: {CONFIG_PATH} βœ…")
104
 
105
  print("Loading Faster R-CNN model...")
106
- cfg = get_cfg()
107
- cfg.merge_from_file(CONFIG_PATH)
108
- cfg.MODEL.WEIGHTS = MODEL_PATH
109
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
110
- cfg.MODEL.DEVICE = "cpu"
 
111
 
112
  MetadataCatalog.get("__unused").set(thing_classes=["ship"])
113
- predictor = DefaultPredictor(cfg)
114
  print("Model loaded βœ…")
115
 
116
 
117
- # ── helpers ──────────────────────────────────────────────────────────────────
118
 
119
- def run_inference(img_bgr, confidence_threshold):
120
- """Run detection on a single BGR frame. Returns (result_bgr, instances)."""
 
 
 
 
121
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
122
- outputs = predictor(img_bgr)
 
 
 
 
 
 
123
  instances = outputs["instances"].to("cpu")
124
  instances = instances[instances.scores >= confidence_threshold]
125
 
126
  metadata = MetadataCatalog.get("__unused")
127
- v = Visualizer(img_bgr[:, :, ::-1], metadata=metadata,
128
- scale=1.0, instance_mode=ColorMode.IMAGE)
 
 
 
 
129
  out = v.draw_instance_predictions(instances)
130
- result_rgb = out.get_image() # HΓ—WΓ—3 RGB
131
  result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
132
  return result_bgr, instances
133
 
134
 
135
- def build_info(instances):
136
- num = len(instances)
137
  scores = instances.scores.tolist()
138
- info = f"βœ… Detected {num} ship(s)\n"
139
  if scores:
140
  info += "Confidence scores: " + ", ".join([f"{s:.2f}" for s in scores])
141
  if hasattr(instances, "pred_boxes"):
@@ -143,24 +170,24 @@ def build_info(instances):
143
  info += "\n\nBounding boxes (x1,y1,x2,y2):\n"
144
  for i, (box, score) in enumerate(zip(boxes, scores)):
145
  x1, y1, x2, y2 = [int(c) for c in box]
146
- info += f" Ship {i+1}: [{x1},{y1},{x2},{y2}] conf={score:.2f}\n"
147
  else:
148
  info += "No ships detected above threshold."
149
  return info
150
 
151
 
152
- # ── image tab ────────────────────────────────────────────────────────────────
153
 
154
  def detect_ships_image(image, confidence_threshold):
155
  if image is None:
156
  return None, "Please upload an image."
157
- img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
158
  result_bgr, inst = run_inference(img_bgr, confidence_threshold)
159
- result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
160
  return result_rgb, build_info(inst)
161
 
162
 
163
- # ── video tab ────────────────────────────────────────────────────────────────
164
 
165
  def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress()):
166
  if video_path is None:
@@ -171,11 +198,10 @@ def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress())
171
  return None, "Could not open video file."
172
 
173
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
174
- fps = cap.get(cv2.CAP_PROP_FPS) or 25
175
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
176
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
177
 
178
- # Write output to a temp MP4 file
179
  out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
180
  out_path = out_file.name
181
  out_file.close()
@@ -183,9 +209,9 @@ def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress())
183
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
184
  writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
185
 
186
- frame_idx = 0
187
- total_ships = 0
188
- max_per_frame = 0
189
 
190
  while True:
191
  ret, frame = cap.read()
@@ -195,14 +221,16 @@ def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress())
195
  result_bgr, inst = run_inference(frame, confidence_threshold)
196
  writer.write(result_bgr)
197
 
198
- n = len(inst)
199
  total_ships += n
200
  max_per_frame = max(max_per_frame, n)
201
- frame_idx += 1
202
 
203
  if total_frames > 0:
204
- progress(frame_idx / total_frames,
205
- desc=f"Processing frame {frame_idx}/{total_frames}")
 
 
206
 
207
  cap.release()
208
  writer.release()
@@ -211,7 +239,7 @@ def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress())
211
  f"βœ… Video processed: {frame_idx} frames\n"
212
  f"Total ship detections across all frames: {total_ships}\n"
213
  f"Peak ships in a single frame: {max_per_frame}\n"
214
- f"FPS: {fps:.1f} | Resolution: {w}Γ—{h}"
215
  )
216
  return out_path, info
217
 
@@ -227,17 +255,17 @@ with gr.Blocks(title="🚒 HRSID Ship Detection") as demo:
227
 
228
  with gr.Tabs():
229
 
230
- # ── Image tab ────────────────────────────────────────────────────────
231
  with gr.Tab("πŸ–ΌοΈ Image Detection"):
232
  with gr.Row():
233
  with gr.Column():
234
- img_input = gr.Image(type="pil", label="Upload SAR Image")
235
- img_thresh = gr.Slider(0.1, 0.9, value=0.5, step=0.05,
236
- label="Confidence Threshold")
237
- img_btn = gr.Button("Detect Ships", variant="primary")
 
238
  with gr.Column():
239
  img_output = gr.Image(type="numpy", label="Detection Result")
240
- img_info = gr.Textbox(label="Detection Info", lines=10)
241
 
242
  img_btn.click(
243
  fn=detect_ships_image,
@@ -245,20 +273,20 @@ with gr.Blocks(title="🚒 HRSID Ship Detection") as demo:
245
  outputs=[img_output, img_info],
246
  )
247
 
248
- # ── Video tab ────────────────────────────────────────────────────────
249
  with gr.Tab("πŸŽ₯ Video Detection"):
250
  gr.Markdown(
251
  "> ⚠️ CPU inference is slow. Short clips (< 30 s) are recommended."
252
  )
253
  with gr.Row():
254
  with gr.Column():
255
- vid_input = gr.Video(label="Upload SAR Video")
256
- vid_thresh = gr.Slider(0.1, 0.9, value=0.5, step=0.05,
257
- label="Confidence Threshold")
258
- vid_btn = gr.Button("Detect Ships in Video", variant="primary")
 
259
  with gr.Column():
260
  vid_output = gr.Video(label="Detection Result Video")
261
- vid_info = gr.Textbox(label="Detection Summary", lines=8)
262
 
263
  vid_btn.click(
264
  fn=detect_ships_video,
@@ -268,4 +296,4 @@ with gr.Blocks(title="🚒 HRSID Ship Detection") as demo:
268
 
269
  if __name__ == "__main__":
270
  demo.queue()
271
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
80
 
81
  # if __name__ == "__main__":
82
  # demo.launch(server_name="0.0.0.0", server_port=7860)
83
+ import os
84
+ import tempfile
85
+ from copy import deepcopy
86
 
 
87
  import cv2
88
+ import gradio as gr
89
  import numpy as np
 
 
90
  from detectron2.config import get_cfg
 
 
91
  from detectron2.data import MetadataCatalog
92
+ from detectron2.engine import DefaultPredictor
93
+ from detectron2.utils.visualizer import ColorMode, Visualizer
94
  from huggingface_hub import hf_hub_download
95
 
96
+ # ── Model loading ────────────────────────────────────────────────────────────
97
+
98
  REPO_ID = os.getenv("MODEL_REPO_ID", "PUSHPENDAR/hrsid-ship-detection")
99
 
100
  os.makedirs("/app/hf_cache", exist_ok=True)
101
 
102
  print("Downloading model files...")
103
+ MODEL_PATH = hf_hub_download(
104
+ repo_id=REPO_ID,
105
+ filename="model_final.pth",
106
+ cache_dir="/app/hf_cache",
107
+ token=os.getenv("HF_TOKEN"), # uses secret if set, else None (public repos)
108
+ )
109
+ CONFIG_PATH = hf_hub_download(
110
+ repo_id=REPO_ID,
111
+ filename="config.yaml",
112
+ cache_dir="/app/hf_cache",
113
+ token=os.getenv("HF_TOKEN"),
114
+ )
115
+ print(f"Model: {MODEL_PATH} βœ…")
116
  print(f"Config: {CONFIG_PATH} βœ…")
117
 
118
  print("Loading Faster R-CNN model...")
119
+ _base_cfg = get_cfg()
120
+ _base_cfg.merge_from_file(CONFIG_PATH)
121
+ _base_cfg.MODEL.WEIGHTS = MODEL_PATH
122
+ _base_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
123
+ _base_cfg.MODEL.DEVICE = "cpu"
124
+ _base_cfg.freeze() # make it immutable so we always deepcopy before mutating
125
 
126
  MetadataCatalog.get("__unused").set(thing_classes=["ship"])
 
127
  print("Model loaded βœ…")
128
 
129
 
130
+ # ── Helpers ──────────────────────────────────────────────────────────────────
131
 
132
+ def get_predictor(confidence_threshold: float) -> DefaultPredictor:
133
+ """Return a fresh predictor with the requested threshold.
134
+ deepcopy avoids mutating the global frozen cfg across concurrent requests.
135
+ """
136
+ cfg = deepcopy(_base_cfg)
137
+ cfg.defrost()
138
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
139
+ return DefaultPredictor(cfg)
140
+
141
+
142
+ def run_inference(img_bgr: np.ndarray, confidence_threshold: float):
143
+ """Run detection on a single BGR frame. Returns (result_bgr, instances)."""
144
+ predictor = get_predictor(confidence_threshold)
145
+ outputs = predictor(img_bgr)
146
  instances = outputs["instances"].to("cpu")
147
  instances = instances[instances.scores >= confidence_threshold]
148
 
149
  metadata = MetadataCatalog.get("__unused")
150
+ v = Visualizer(
151
+ img_bgr[:, :, ::-1],
152
+ metadata=metadata,
153
+ scale=1.0,
154
+ instance_mode=ColorMode.IMAGE,
155
+ )
156
  out = v.draw_instance_predictions(instances)
157
+ result_rgb = out.get_image() # HΓ—WΓ—3 RGB
158
  result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
159
  return result_bgr, instances
160
 
161
 
162
+ def build_info(instances) -> str:
163
+ num = len(instances)
164
  scores = instances.scores.tolist()
165
+ info = f"βœ… Detected {num} ship(s)\n"
166
  if scores:
167
  info += "Confidence scores: " + ", ".join([f"{s:.2f}" for s in scores])
168
  if hasattr(instances, "pred_boxes"):
 
170
  info += "\n\nBounding boxes (x1,y1,x2,y2):\n"
171
  for i, (box, score) in enumerate(zip(boxes, scores)):
172
  x1, y1, x2, y2 = [int(c) for c in box]
173
+ info += f" Ship {i+1}: [{x1},{y1},{x2},{y2}] conf={score:.2f}\n"
174
  else:
175
  info += "No ships detected above threshold."
176
  return info
177
 
178
 
179
+ # ── Image tab ────────────────────────────────────────────────────────────────
180
 
181
  def detect_ships_image(image, confidence_threshold):
182
  if image is None:
183
  return None, "Please upload an image."
184
+ img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
185
  result_bgr, inst = run_inference(img_bgr, confidence_threshold)
186
+ result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)
187
  return result_rgb, build_info(inst)
188
 
189
 
190
+ # ── Video tab ────────────────────────────────────────────────────────────────
191
 
192
  def detect_ships_video(video_path, confidence_threshold, progress=gr.Progress()):
193
  if video_path is None:
 
198
  return None, "Could not open video file."
199
 
200
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
201
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25
202
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
203
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
204
 
 
205
  out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
206
  out_path = out_file.name
207
  out_file.close()
 
209
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
210
  writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
211
 
212
+ frame_idx = 0
213
+ total_ships = 0
214
+ max_per_frame = 0
215
 
216
  while True:
217
  ret, frame = cap.read()
 
221
  result_bgr, inst = run_inference(frame, confidence_threshold)
222
  writer.write(result_bgr)
223
 
224
+ n = len(inst)
225
  total_ships += n
226
  max_per_frame = max(max_per_frame, n)
227
+ frame_idx += 1
228
 
229
  if total_frames > 0:
230
+ progress(
231
+ frame_idx / total_frames,
232
+ desc=f"Processing frame {frame_idx}/{total_frames}",
233
+ )
234
 
235
  cap.release()
236
  writer.release()
 
239
  f"βœ… Video processed: {frame_idx} frames\n"
240
  f"Total ship detections across all frames: {total_ships}\n"
241
  f"Peak ships in a single frame: {max_per_frame}\n"
242
+ f"FPS: {fps:.1f} | Resolution: {w}Γ—{h}"
243
  )
244
  return out_path, info
245
 
 
255
 
256
  with gr.Tabs():
257
 
 
258
  with gr.Tab("πŸ–ΌοΈ Image Detection"):
259
  with gr.Row():
260
  with gr.Column():
261
+ img_input = gr.Image(type="pil", label="Upload SAR Image")
262
+ img_thresh = gr.Slider(
263
+ 0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold"
264
+ )
265
+ img_btn = gr.Button("Detect Ships", variant="primary")
266
  with gr.Column():
267
  img_output = gr.Image(type="numpy", label="Detection Result")
268
+ img_info = gr.Textbox(label="Detection Info", lines=10)
269
 
270
  img_btn.click(
271
  fn=detect_ships_image,
 
273
  outputs=[img_output, img_info],
274
  )
275
 
 
276
  with gr.Tab("πŸŽ₯ Video Detection"):
277
  gr.Markdown(
278
  "> ⚠️ CPU inference is slow. Short clips (< 30 s) are recommended."
279
  )
280
  with gr.Row():
281
  with gr.Column():
282
+ vid_input = gr.Video(label="Upload SAR Video")
283
+ vid_thresh = gr.Slider(
284
+ 0.1, 0.9, value=0.5, step=0.05, label="Confidence Threshold"
285
+ )
286
+ vid_btn = gr.Button("Detect Ships in Video", variant="primary")
287
  with gr.Column():
288
  vid_output = gr.Video(label="Detection Result Video")
289
+ vid_info = gr.Textbox(label="Detection Summary", lines=8)
290
 
291
  vid_btn.click(
292
  fn=detect_ships_video,
 
296
 
297
  if __name__ == "__main__":
298
  demo.queue()
299
+ demo.launch(server_name="0.0.0.0", server_port=7860) # NO share=True