Update app.py
Browse files
app.py
CHANGED
|
@@ -40,9 +40,8 @@ LABEL_MAP = {
|
|
| 40 |
"автомобиль": "Car",
|
| 41 |
"машина": "Car",
|
| 42 |
"БПЛА самелет": "UAV Airplane",
|
| 43 |
-
"drone": "Drone",
|
| 44 |
}
|
| 45 |
-
|
| 46 |
THREAT_SET = {"drone", "uav", "airplane", "helicopter"}
|
| 47 |
|
| 48 |
def map_label(name: str) -> str:
|
|
@@ -59,12 +58,11 @@ def is_threat(label_en: str) -> bool:
|
|
| 59 |
return label_en and label_en.lower() in THREAT_SET
|
| 60 |
|
| 61 |
# =========================
|
| 62 |
-
#
|
| 63 |
-
# (you can tighten later in Space Secrets)
|
| 64 |
# =========================
|
| 65 |
-
MIN_CONF = float(os.getenv("MIN_CONF", 0.30))
|
| 66 |
-
MIN_AREA_PCT = float(os.getenv("MIN_AREA_PCT", 0.001))#
|
| 67 |
-
SKY_RATIO = float(os.getenv("SKY_RATIO", 0.95))
|
| 68 |
|
| 69 |
# =========================
|
| 70 |
# LAZY GLOBAL STATE
|
|
@@ -124,7 +122,11 @@ def _get_model(model_key: str, conf: float, iou: float):
|
|
| 124 |
try:
|
| 125 |
weights = _download_from_hf(repo, file)
|
| 126 |
m = YOLO(weights)
|
|
|
|
| 127 |
m.overrides["max_det"] = 300
|
|
|
|
|
|
|
|
|
|
| 128 |
_model = m
|
| 129 |
_loaded_repo, _loaded_file = repo, file
|
| 130 |
try:
|
|
@@ -138,8 +140,10 @@ def _get_model(model_key: str, conf: float, iou: float):
|
|
| 138 |
_model_err = f"Model load failed for {repo}/{file}. Error: {last_err}"
|
| 139 |
if _model_err:
|
| 140 |
raise RuntimeError(_model_err)
|
|
|
|
| 141 |
_model.overrides["conf"] = float(conf)
|
| 142 |
_model.overrides["iou"] = float(iou)
|
|
|
|
| 143 |
return _model
|
| 144 |
|
| 145 |
def _model_info_text():
|
|
@@ -181,7 +185,7 @@ def _results_to_rows(results) -> List[dict]:
|
|
| 181 |
def _filter_rows_by_geometry(r, rows: List[dict], model_key: str) -> List[dict]:
|
| 182 |
"""
|
| 183 |
Drop low-conf, tiny, ground-region boxes.
|
| 184 |
-
For drone-only model,
|
| 185 |
For multi-class, keep only classes we care about.
|
| 186 |
"""
|
| 187 |
if "Multi-class" in model_key:
|
|
@@ -201,9 +205,9 @@ def _filter_rows_by_geometry(r, rows: List[dict], model_key: str) -> List[dict]:
|
|
| 201 |
cls = map_label(str(row.get("class","")))
|
| 202 |
if allowed and cls not in allowed:
|
| 203 |
continue
|
| 204 |
-
if H and W:
|
| 205 |
area = row["width"] * row["height"]
|
| 206 |
-
if
|
| 207 |
continue
|
| 208 |
y_bottom = row["y2"]
|
| 209 |
horizon = H * SKY_RATIO
|
|
@@ -301,20 +305,20 @@ def _apply_english_overlay(r):
|
|
| 301 |
pass
|
| 302 |
|
| 303 |
# =========================
|
| 304 |
-
# INFERENCE (
|
| 305 |
# =========================
|
| 306 |
-
def detect_image_safe(model_key: str, image, conf: float, iou: float):
|
| 307 |
try:
|
| 308 |
if image is None:
|
| 309 |
return None, [], "⚠️ No image provided.", [], None, _model_info_text()
|
| 310 |
cv2 = _lazy_cv2()
|
| 311 |
model = _get_model(model_key, conf, iou)
|
| 312 |
-
results = model.predict(image, imgsz=
|
| 313 |
r = results[0]
|
| 314 |
_apply_english_overlay(r)
|
| 315 |
|
| 316 |
rows_raw = _results_to_rows(results)
|
| 317 |
-
rows = _filter_rows_by_geometry(r, rows_raw, model_key)
|
| 318 |
|
| 319 |
annotated_bgr = _draw_annotations_bgr(r.orig_img, rows)
|
| 320 |
now_utc = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
|
|
@@ -344,7 +348,7 @@ def detect_image_safe(model_key: str, image, conf: float, iou: float):
|
|
| 344 |
except Exception as e:
|
| 345 |
return None, [], f"❌ Error during image detection: {e}", [], None, _model_info_text()
|
| 346 |
|
| 347 |
-
def detect_video_safe(model_key: str, video_path: str, conf: float, iou: float, max_frames: int = 300):
|
| 348 |
try:
|
| 349 |
if not video_path:
|
| 350 |
return None, "{}", "⚠️ No video provided.", [], _model_info_text()
|
|
@@ -377,12 +381,12 @@ def detect_video_safe(model_key: str, video_path: str, conf: float, iou: float,
|
|
| 377 |
if frames > int(max_frames):
|
| 378 |
break
|
| 379 |
|
| 380 |
-
results = model.predict(frame, imgsz=
|
| 381 |
r = results[0]
|
| 382 |
_apply_english_overlay(r)
|
| 383 |
|
| 384 |
rows_raw = _results_to_rows(results)
|
| 385 |
-
rows = _filter_rows_by_geometry(r, rows_raw, model_key)
|
| 386 |
raw_total += len(rows_raw)
|
| 387 |
kept_total += len(rows)
|
| 388 |
|
|
@@ -438,7 +442,7 @@ def export_pdf_vid(det_records: List[dict], summary: str):
|
|
| 438 |
# =========================
|
| 439 |
NOTE = (
|
| 440 |
"Detections include timestamp, object, confidence, and Threat/Non-threat. "
|
| 441 |
-
"
|
| 442 |
)
|
| 443 |
|
| 444 |
with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
|
|
@@ -461,8 +465,9 @@ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
|
|
| 461 |
label="Input Image"
|
| 462 |
)
|
| 463 |
with gr.Column():
|
| 464 |
-
conf_img = gr.Slider(0.05, 0.9, 0.
|
| 465 |
iou_img = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
|
|
|
|
| 466 |
run_img = gr.Button("Run Detection")
|
| 467 |
gr.Markdown(NOTE)
|
| 468 |
|
|
@@ -474,12 +479,12 @@ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
|
|
| 474 |
annotated_tmp_img_path = gr.State(value=None)
|
| 475 |
image_det_state = gr.State(value=[])
|
| 476 |
|
| 477 |
-
def _run_img(mkey, image, conf, iou):
|
| 478 |
-
return detect_image_safe(mkey, image, conf, iou)
|
| 479 |
|
| 480 |
run_img.click(
|
| 481 |
fn=_run_img,
|
| 482 |
-
inputs=[model_key, image_in, conf_img, iou_img],
|
| 483 |
outputs=[image_out, table_out, msg_img, image_det_state, annotated_tmp_img_path, model_info_md],
|
| 484 |
)
|
| 485 |
|
|
@@ -497,9 +502,10 @@ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
|
|
| 497 |
label="Input Video"
|
| 498 |
)
|
| 499 |
with gr.Column():
|
| 500 |
-
conf_vid = gr.Slider(0.05, 0.9, 0.
|
| 501 |
iou_vid = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
|
| 502 |
max_frames = gr.Slider(60, 2000, 300, step=10, label="Max frames to process")
|
|
|
|
| 503 |
run_vid = gr.Button("Run Detection")
|
| 504 |
gr.Markdown(NOTE)
|
| 505 |
|
|
@@ -510,12 +516,12 @@ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
|
|
| 510 |
pdf_vid_path = gr.File(label="PDF Report", interactive=False)
|
| 511 |
video_det_state = gr.State(value=[])
|
| 512 |
|
| 513 |
-
def _run_vid(mkey, vpath, conf, iou, maxf):
|
| 514 |
-
return detect_video_safe(mkey, vpath, conf, iou, int(maxf))
|
| 515 |
|
| 516 |
run_vid.click(
|
| 517 |
fn=_run_vid,
|
| 518 |
-
inputs=[model_key, video_in, conf_vid, iou_vid, max_frames],
|
| 519 |
outputs=[video_out, detections_json_text, msg_vid, video_det_state, model_info_md],
|
| 520 |
)
|
| 521 |
|
|
|
|
| 40 |
"автомобиль": "Car",
|
| 41 |
"машина": "Car",
|
| 42 |
"БПЛА самелет": "UAV Airplane",
|
| 43 |
+
"drone": "Drone",
|
| 44 |
}
|
|
|
|
| 45 |
THREAT_SET = {"drone", "uav", "airplane", "helicopter"}
|
| 46 |
|
| 47 |
def map_label(name: str) -> str:
|
|
|
|
| 58 |
return label_en and label_en.lower() in THREAT_SET
|
| 59 |
|
| 60 |
# =========================
|
| 61 |
+
# FILTERS (relaxed defaults; can be tightened later)
|
|
|
|
| 62 |
# =========================
|
| 63 |
+
MIN_CONF = float(os.getenv("MIN_CONF", 0.30)) # detections with confidence below this are dropped by our post-filter (applied after the model's own conf threshold)
|
| 64 |
+
MIN_AREA_PCT = float(os.getenv("MIN_AREA_PCT", 0.001)) # drop boxes whose area is below this fraction of the frame area (a fraction, not a percent: 0.001 = 0.1%)
|
| 65 |
+
SKY_RATIO = float(os.getenv("SKY_RATIO", 0.95)) # keep boxes whose bottom edge (y2) is above the line at 95% of frame height; at 0.95 this filter is nearly off
|
| 66 |
|
| 67 |
# =========================
|
| 68 |
# LAZY GLOBAL STATE
|
|
|
|
| 122 |
try:
|
| 123 |
weights = _download_from_hf(repo, file)
|
| 124 |
m = YOLO(weights)
|
| 125 |
+
# core overrides
|
| 126 |
m.overrides["max_det"] = 300
|
| 127 |
+
m.overrides["conf"] = float(conf) # driven by UI
|
| 128 |
+
m.overrides["iou"] = float(iou) # driven by UI
|
| 129 |
+
m.overrides["agnostic_nms"] = True # class-agnostic NMS: overlapping boxes suppress each other regardless of predicted class
|
| 130 |
_model = m
|
| 131 |
_loaded_repo, _loaded_file = repo, file
|
| 132 |
try:
|
|
|
|
| 140 |
_model_err = f"Model load failed for {repo}/{file}. Error: {last_err}"
|
| 141 |
if _model_err:
|
| 142 |
raise RuntimeError(_model_err)
|
| 143 |
+
# also set at call time in case sliders change
|
| 144 |
_model.overrides["conf"] = float(conf)
|
| 145 |
_model.overrides["iou"] = float(iou)
|
| 146 |
+
_model.overrides["agnostic_nms"] = True
|
| 147 |
return _model
|
| 148 |
|
| 149 |
def _model_info_text():
|
|
|
|
| 185 |
def _filter_rows_by_geometry(r, rows: List[dict], model_key: str) -> List[dict]:
|
| 186 |
"""
|
| 187 |
Drop low-conf, tiny, ground-region boxes.
|
| 188 |
+
For the drone-only model, do NOT restrict classes (some checkpoints name their class 'UAV', 'drone', or a variant).
|
| 189 |
For multi-class, keep only classes we care about.
|
| 190 |
"""
|
| 191 |
if "Multi-class" in model_key:
|
|
|
|
| 205 |
cls = map_label(str(row.get("class","")))
|
| 206 |
if allowed and cls not in allowed:
|
| 207 |
continue
|
| 208 |
+
if H and W and (W * H) > 0:
|
| 209 |
area = row["width"] * row["height"]
|
| 210 |
+
if area / (W * H) < MIN_AREA_PCT:
|
| 211 |
continue
|
| 212 |
y_bottom = row["y2"]
|
| 213 |
horizon = H * SKY_RATIO
|
|
|
|
| 305 |
pass
|
| 306 |
|
| 307 |
# =========================
|
| 308 |
+
# INFERENCE (filters toggle + imgsz=1280 + debug)
|
| 309 |
# =========================
|
| 310 |
+
def detect_image_safe(model_key: str, image, conf: float, iou: float, bypass_filters: bool = True):
|
| 311 |
try:
|
| 312 |
if image is None:
|
| 313 |
return None, [], "⚠️ No image provided.", [], None, _model_info_text()
|
| 314 |
cv2 = _lazy_cv2()
|
| 315 |
model = _get_model(model_key, conf, iou)
|
| 316 |
+
results = model.predict(image, imgsz=1280, verbose=False) # larger input helps tiny drones
|
| 317 |
r = results[0]
|
| 318 |
_apply_english_overlay(r)
|
| 319 |
|
| 320 |
rows_raw = _results_to_rows(results)
|
| 321 |
+
rows = rows_raw if bypass_filters else _filter_rows_by_geometry(r, rows_raw, model_key)
|
| 322 |
|
| 323 |
annotated_bgr = _draw_annotations_bgr(r.orig_img, rows)
|
| 324 |
now_utc = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
|
|
|
|
| 348 |
except Exception as e:
|
| 349 |
return None, [], f"❌ Error during image detection: {e}", [], None, _model_info_text()
|
| 350 |
|
| 351 |
+
def detect_video_safe(model_key: str, video_path: str, conf: float, iou: float, max_frames: int = 300, bypass_filters: bool = True):
|
| 352 |
try:
|
| 353 |
if not video_path:
|
| 354 |
return None, "{}", "⚠️ No video provided.", [], _model_info_text()
|
|
|
|
| 381 |
if frames > int(max_frames):
|
| 382 |
break
|
| 383 |
|
| 384 |
+
results = model.predict(frame, imgsz=1280, verbose=False)
|
| 385 |
r = results[0]
|
| 386 |
_apply_english_overlay(r)
|
| 387 |
|
| 388 |
rows_raw = _results_to_rows(results)
|
| 389 |
+
rows = rows_raw if bypass_filters else _filter_rows_by_geometry(r, rows_raw, model_key)
|
| 390 |
raw_total += len(rows_raw)
|
| 391 |
kept_total += len(rows)
|
| 392 |
|
|
|
|
| 442 |
# =========================
|
| 443 |
NOTE = (
|
| 444 |
"Detections include timestamp, object, confidence, and Threat/Non-threat. "
|
| 445 |
+
"Use 'Bypass filters (debug)' to see raw model boxes; tighten filters after you confirm detections."
|
| 446 |
)
|
| 447 |
|
| 448 |
with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
|
|
|
|
| 465 |
label="Input Image"
|
| 466 |
)
|
| 467 |
with gr.Column():
|
| 468 |
+
conf_img = gr.Slider(0.05, 0.9, 0.25, step=0.05, label="Model Confidence")
|
| 469 |
iou_img = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
|
| 470 |
+
filters_off_img = gr.Checkbox(value=True, label="Bypass filters (debug)")
|
| 471 |
run_img = gr.Button("Run Detection")
|
| 472 |
gr.Markdown(NOTE)
|
| 473 |
|
|
|
|
| 479 |
annotated_tmp_img_path = gr.State(value=None)
|
| 480 |
image_det_state = gr.State(value=[])
|
| 481 |
|
| 482 |
+
def _run_img(mkey, image, conf, iou, bypass):
|
| 483 |
+
return detect_image_safe(mkey, image, conf, iou, bypass)
|
| 484 |
|
| 485 |
run_img.click(
|
| 486 |
fn=_run_img,
|
| 487 |
+
inputs=[model_key, image_in, conf_img, iou_img, filters_off_img],
|
| 488 |
outputs=[image_out, table_out, msg_img, image_det_state, annotated_tmp_img_path, model_info_md],
|
| 489 |
)
|
| 490 |
|
|
|
|
| 502 |
label="Input Video"
|
| 503 |
)
|
| 504 |
with gr.Column():
|
| 505 |
+
conf_vid = gr.Slider(0.05, 0.9, 0.25, step=0.05, label="Model Confidence")
|
| 506 |
iou_vid = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
|
| 507 |
max_frames = gr.Slider(60, 2000, 300, step=10, label="Max frames to process")
|
| 508 |
+
filters_off_vid = gr.Checkbox(value=True, label="Bypass filters (debug)")
|
| 509 |
run_vid = gr.Button("Run Detection")
|
| 510 |
gr.Markdown(NOTE)
|
| 511 |
|
|
|
|
| 516 |
pdf_vid_path = gr.File(label="PDF Report", interactive=False)
|
| 517 |
video_det_state = gr.State(value=[])
|
| 518 |
|
| 519 |
+
def _run_vid(mkey, vpath, conf, iou, maxf, bypass):
|
| 520 |
+
return detect_video_safe(mkey, vpath, conf, iou, int(maxf), bypass)
|
| 521 |
|
| 522 |
run_vid.click(
|
| 523 |
fn=_run_vid,
|
| 524 |
+
inputs=[model_key, video_in, conf_vid, iou_vid, max_frames, filters_off_vid],
|
| 525 |
outputs=[video_out, detections_json_text, msg_vid, video_det_state, model_info_md],
|
| 526 |
)
|
| 527 |
|