rba28 commited on
Commit
d19905b
·
verified ·
1 Parent(s): 1b5977c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -26
app.py CHANGED
@@ -40,9 +40,8 @@ LABEL_MAP = {
40
  "автомобиль": "Car",
41
  "машина": "Car",
42
  "БПЛА самелет": "UAV Airplane",
43
- "drone": "Drone", # some models lowercase
44
  }
45
-
46
  THREAT_SET = {"drone", "uav", "airplane", "helicopter"}
47
 
48
  def map_label(name: str) -> str:
@@ -59,12 +58,11 @@ def is_threat(label_en: str) -> bool:
59
  return label_en and label_en.lower() in THREAT_SET
60
 
61
  # =========================
62
- # FALSE-POSITIVE FILTERS relaxed defaults
63
- # (you can tighten later in Space Secrets)
64
  # =========================
65
- MIN_CONF = float(os.getenv("MIN_CONF", 0.30)) # was 0.60
66
- MIN_AREA_PCT = float(os.getenv("MIN_AREA_PCT", 0.001))# was 0.004
67
- SKY_RATIO = float(os.getenv("SKY_RATIO", 0.95)) # was 0.65 (almost no sky gating)
68
 
69
  # =========================
70
  # LAZY GLOBAL STATE
@@ -124,7 +122,11 @@ def _get_model(model_key: str, conf: float, iou: float):
124
  try:
125
  weights = _download_from_hf(repo, file)
126
  m = YOLO(weights)
 
127
  m.overrides["max_det"] = 300
 
 
 
128
  _model = m
129
  _loaded_repo, _loaded_file = repo, file
130
  try:
@@ -138,8 +140,10 @@ def _get_model(model_key: str, conf: float, iou: float):
138
  _model_err = f"Model load failed for {repo}/{file}. Error: {last_err}"
139
  if _model_err:
140
  raise RuntimeError(_model_err)
 
141
  _model.overrides["conf"] = float(conf)
142
  _model.overrides["iou"] = float(iou)
 
143
  return _model
144
 
145
  def _model_info_text():
@@ -181,7 +185,7 @@ def _results_to_rows(results) -> List[dict]:
181
  def _filter_rows_by_geometry(r, rows: List[dict], model_key: str) -> List[dict]:
182
  """
183
  Drop low-conf, tiny, ground-region boxes.
184
- For drone-only model, do NOT restrict classes (some checkpoints label as 'UAV'/'drone' variants).
185
  For multi-class, keep only classes we care about.
186
  """
187
  if "Multi-class" in model_key:
@@ -201,9 +205,9 @@ def _filter_rows_by_geometry(r, rows: List[dict], model_key: str) -> List[dict]:
201
  cls = map_label(str(row.get("class","")))
202
  if allowed and cls not in allowed:
203
  continue
204
- if H and W:
205
  area = row["width"] * row["height"]
206
- if (W * H) > 0 and area / (W * H) < MIN_AREA_PCT:
207
  continue
208
  y_bottom = row["y2"]
209
  horizon = H * SKY_RATIO
@@ -301,20 +305,20 @@ def _apply_english_overlay(r):
301
  pass
302
 
303
  # =========================
304
- # INFERENCE (with filtering + custom draw + debug)
305
  # =========================
306
- def detect_image_safe(model_key: str, image, conf: float, iou: float):
307
  try:
308
  if image is None:
309
  return None, [], "⚠️ No image provided.", [], None, _model_info_text()
310
  cv2 = _lazy_cv2()
311
  model = _get_model(model_key, conf, iou)
312
- results = model.predict(image, imgsz=960, verbose=False)
313
  r = results[0]
314
  _apply_english_overlay(r)
315
 
316
  rows_raw = _results_to_rows(results)
317
- rows = _filter_rows_by_geometry(r, rows_raw, model_key)
318
 
319
  annotated_bgr = _draw_annotations_bgr(r.orig_img, rows)
320
  now_utc = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
@@ -344,7 +348,7 @@ def detect_image_safe(model_key: str, image, conf: float, iou: float):
344
  except Exception as e:
345
  return None, [], f"❌ Error during image detection: {e}", [], None, _model_info_text()
346
 
347
- def detect_video_safe(model_key: str, video_path: str, conf: float, iou: float, max_frames: int = 300):
348
  try:
349
  if not video_path:
350
  return None, "{}", "⚠️ No video provided.", [], _model_info_text()
@@ -377,12 +381,12 @@ def detect_video_safe(model_key: str, video_path: str, conf: float, iou: float,
377
  if frames > int(max_frames):
378
  break
379
 
380
- results = model.predict(frame, imgsz=960, verbose=False)
381
  r = results[0]
382
  _apply_english_overlay(r)
383
 
384
  rows_raw = _results_to_rows(results)
385
- rows = _filter_rows_by_geometry(r, rows_raw, model_key)
386
  raw_total += len(rows_raw)
387
  kept_total += len(rows)
388
 
@@ -438,7 +442,7 @@ def export_pdf_vid(det_records: List[dict], summary: str):
438
  # =========================
439
  NOTE = (
440
  "Detections include timestamp, object, confidence, and Threat/Non-threat. "
441
- "Filters are relaxed (MIN_CONF=0.30, MIN_AREA_PCT=0.001, SKY_RATIO=0.95) so you see boxes; tighten later as needed."
442
  )
443
 
444
  with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
@@ -461,8 +465,9 @@ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
461
  label="Input Image"
462
  )
463
  with gr.Column():
464
- conf_img = gr.Slider(0.05, 0.9, 0.35, step=0.05, label="Model Confidence")
465
  iou_img = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
 
466
  run_img = gr.Button("Run Detection")
467
  gr.Markdown(NOTE)
468
 
@@ -474,12 +479,12 @@ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
474
  annotated_tmp_img_path = gr.State(value=None)
475
  image_det_state = gr.State(value=[])
476
 
477
- def _run_img(mkey, image, conf, iou):
478
- return detect_image_safe(mkey, image, conf, iou)
479
 
480
  run_img.click(
481
  fn=_run_img,
482
- inputs=[model_key, image_in, conf_img, iou_img],
483
  outputs=[image_out, table_out, msg_img, image_det_state, annotated_tmp_img_path, model_info_md],
484
  )
485
 
@@ -497,9 +502,10 @@ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
497
  label="Input Video"
498
  )
499
  with gr.Column():
500
- conf_vid = gr.Slider(0.05, 0.9, 0.35, step=0.05, label="Model Confidence")
501
  iou_vid = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
502
  max_frames = gr.Slider(60, 2000, 300, step=10, label="Max frames to process")
 
503
  run_vid = gr.Button("Run Detection")
504
  gr.Markdown(NOTE)
505
 
@@ -510,12 +516,12 @@ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
510
  pdf_vid_path = gr.File(label="PDF Report", interactive=False)
511
  video_det_state = gr.State(value=[])
512
 
513
- def _run_vid(mkey, vpath, conf, iou, maxf):
514
- return detect_video_safe(mkey, vpath, conf, iou, int(maxf))
515
 
516
  run_vid.click(
517
  fn=_run_vid,
518
- inputs=[model_key, video_in, conf_vid, iou_vid, max_frames],
519
  outputs=[video_out, detections_json_text, msg_vid, video_det_state, model_info_md],
520
  )
521
 
 
40
  "автомобиль": "Car",
41
  "машина": "Car",
42
  "БПЛА самелет": "UAV Airplane",
43
+ "drone": "Drone",
44
  }
 
45
  THREAT_SET = {"drone", "uav", "airplane", "helicopter"}
46
 
47
  def map_label(name: str) -> str:
 
58
  return label_en and label_en.lower() in THREAT_SET
59
 
60
  # =========================
61
+ # FILTERS (relaxed defaults; can be tightened later)
 
62
  # =========================
63
+ MIN_CONF = float(os.getenv("MIN_CONF", 0.30)) # model outputs below this are filtered (our post-filter)
64
+ MIN_AREA_PCT = float(os.getenv("MIN_AREA_PCT", 0.001)) # drop tiny boxes (fraction of frame)
65
+ SKY_RATIO = float(os.getenv("SKY_RATIO", 0.95)) # keep boxes whose bottoms are above 95% height (nearly off)
66
 
67
  # =========================
68
  # LAZY GLOBAL STATE
 
122
  try:
123
  weights = _download_from_hf(repo, file)
124
  m = YOLO(weights)
125
+ # core overrides
126
  m.overrides["max_det"] = 300
127
+ m.overrides["conf"] = float(conf) # driven by UI
128
+ m.overrides["iou"] = float(iou) # driven by UI
129
+ m.overrides["agnostic_nms"] = True # reduce class‑based NMS misses
130
  _model = m
131
  _loaded_repo, _loaded_file = repo, file
132
  try:
 
140
  _model_err = f"Model load failed for {repo}/{file}. Error: {last_err}"
141
  if _model_err:
142
  raise RuntimeError(_model_err)
143
+ # also set at call time in case sliders change
144
  _model.overrides["conf"] = float(conf)
145
  _model.overrides["iou"] = float(iou)
146
+ _model.overrides["agnostic_nms"] = True
147
  return _model
148
 
149
  def _model_info_text():
 
185
  def _filter_rows_by_geometry(r, rows: List[dict], model_key: str) -> List[dict]:
186
  """
187
  Drop low-conf, tiny, ground-region boxes.
188
+ For drone-only model, DO NOT restrict classes (some checkpoints label as 'UAV'/'drone' variants).
189
  For multi-class, keep only classes we care about.
190
  """
191
  if "Multi-class" in model_key:
 
205
  cls = map_label(str(row.get("class","")))
206
  if allowed and cls not in allowed:
207
  continue
208
+ if H and W and (W * H) > 0:
209
  area = row["width"] * row["height"]
210
+ if area / (W * H) < MIN_AREA_PCT:
211
  continue
212
  y_bottom = row["y2"]
213
  horizon = H * SKY_RATIO
 
305
  pass
306
 
307
  # =========================
308
+ # INFERENCE (filters toggle + imgsz=1280 + debug)
309
  # =========================
310
+ def detect_image_safe(model_key: str, image, conf: float, iou: float, bypass_filters: bool = True):
311
  try:
312
  if image is None:
313
  return None, [], "⚠️ No image provided.", [], None, _model_info_text()
314
  cv2 = _lazy_cv2()
315
  model = _get_model(model_key, conf, iou)
316
+ results = model.predict(image, imgsz=1280, verbose=False) # larger input helps tiny drones
317
  r = results[0]
318
  _apply_english_overlay(r)
319
 
320
  rows_raw = _results_to_rows(results)
321
+ rows = rows_raw if bypass_filters else _filter_rows_by_geometry(r, rows_raw, model_key)
322
 
323
  annotated_bgr = _draw_annotations_bgr(r.orig_img, rows)
324
  now_utc = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
 
348
  except Exception as e:
349
  return None, [], f"❌ Error during image detection: {e}", [], None, _model_info_text()
350
 
351
+ def detect_video_safe(model_key: str, video_path: str, conf: float, iou: float, max_frames: int = 300, bypass_filters: bool = True):
352
  try:
353
  if not video_path:
354
  return None, "{}", "⚠️ No video provided.", [], _model_info_text()
 
381
  if frames > int(max_frames):
382
  break
383
 
384
+ results = model.predict(frame, imgsz=1280, verbose=False)
385
  r = results[0]
386
  _apply_english_overlay(r)
387
 
388
  rows_raw = _results_to_rows(results)
389
+ rows = rows_raw if bypass_filters else _filter_rows_by_geometry(r, rows_raw, model_key)
390
  raw_total += len(rows_raw)
391
  kept_total += len(rows)
392
 
 
442
  # =========================
443
  NOTE = (
444
  "Detections include timestamp, object, confidence, and Threat/Non-threat. "
445
+ "Use 'Bypass filters (debug)' to see raw model boxes; tighten filters after you confirm detections."
446
  )
447
 
448
  with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
 
465
  label="Input Image"
466
  )
467
  with gr.Column():
468
+ conf_img = gr.Slider(0.05, 0.9, 0.25, step=0.05, label="Model Confidence")
469
  iou_img = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
470
+ filters_off_img = gr.Checkbox(value=True, label="Bypass filters (debug)")
471
  run_img = gr.Button("Run Detection")
472
  gr.Markdown(NOTE)
473
 
 
479
  annotated_tmp_img_path = gr.State(value=None)
480
  image_det_state = gr.State(value=[])
481
 
482
+ def _run_img(mkey, image, conf, iou, bypass):
483
+ return detect_image_safe(mkey, image, conf, iou, bypass)
484
 
485
  run_img.click(
486
  fn=_run_img,
487
+ inputs=[model_key, image_in, conf_img, iou_img, filters_off_img],
488
  outputs=[image_out, table_out, msg_img, image_det_state, annotated_tmp_img_path, model_info_md],
489
  )
490
 
 
502
  label="Input Video"
503
  )
504
  with gr.Column():
505
+ conf_vid = gr.Slider(0.05, 0.9, 0.25, step=0.05, label="Model Confidence")
506
  iou_vid = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
507
  max_frames = gr.Slider(60, 2000, 300, step=10, label="Max frames to process")
508
+ filters_off_vid = gr.Checkbox(value=True, label="Bypass filters (debug)")
509
  run_vid = gr.Button("Run Detection")
510
  gr.Markdown(NOTE)
511
 
 
516
  pdf_vid_path = gr.File(label="PDF Report", interactive=False)
517
  video_det_state = gr.State(value=[])
518
 
519
+ def _run_vid(mkey, vpath, conf, iou, maxf, bypass):
520
+ return detect_video_safe(mkey, vpath, conf, iou, int(maxf), bypass)
521
 
522
  run_vid.click(
523
  fn=_run_vid,
524
+ inputs=[model_key, video_in, conf_vid, iou_vid, max_frames, filters_off_vid],
525
  outputs=[video_out, detections_json_text, msg_vid, video_det_state, model_info_md],
526
  )
527