rba28 commited on
Commit
6f69d10
·
verified ·
1 Parent(s): 2ebd3aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -51
app.py CHANGED
@@ -5,23 +5,41 @@ from typing import List, Dict, Tuple, Optional
5
  import json
6
  import gradio as gr
7
 
8
- # -------------------
9
- # Config
10
- # -------------------
11
- REPO_ID = "mshamrai/yolov8s-visdrone"
12
- FILENAME = "weights/best.pt"
13
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  SAMPLES_DIR = "samples"
15
- EMBED_IMG = os.path.join(SAMPLES_DIR, "aerial_image.jpg")
16
- EMBED_VID = os.path.join(SAMPLES_DIR, "aerial_video.mp4")
 
17
 
18
- # -------------------
19
  # Lazy state
20
- # -------------------
21
  _model = None
22
  _model_err = None
23
  _model_names = None
24
  _ffmpeg_status = None
 
 
25
 
26
  def _lazy_cv2():
27
  import cv2
@@ -39,35 +57,71 @@ def _ffmpeg_ok() -> bool:
39
  _ffmpeg_status = False
40
  return _ffmpeg_status
41
 
42
- def _lazy_hf_download() -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  from huggingface_hub import hf_hub_download
44
- return hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
45
 
46
  def _get_model(conf: float, iou: float):
47
- """Load YOLO weights from HF on first use."""
48
- global _model, _model_err, _model_names
49
  if _model is None and _model_err is None:
50
- try:
51
- from ultralytics import YOLO
52
- weights_path = _lazy_hf_download()
53
- m = YOLO(weights_path)
54
- m.overrides["max_det"] = 300
55
- _model = m
56
  try:
57
- _model_names = m.model.names if hasattr(m, "model") else None
58
- except Exception:
59
- _model_names = None
60
- except Exception as e:
61
- _model_err = f"Model load failed: {e}"
 
 
 
 
 
 
 
 
 
 
62
  if _model_err:
63
  raise RuntimeError(_model_err)
64
  _model.overrides["conf"] = float(conf)
65
  _model.overrides["iou"] = float(iou)
66
  return _model
67
 
68
- # -------------------
69
  # Helpers
70
- # -------------------
71
  def _results_to_rows(results) -> List[dict]:
72
  rows: List[dict] = []
73
  if not results:
@@ -147,9 +201,9 @@ def _save_pdf(title: str, summary: str, counts: Dict[str, int], annotated_image_
147
  c.showPage(); c.save()
148
  return out_path
149
 
150
- # -------------------
151
  # Inference
152
- # -------------------
153
  def detect_image(image, conf: float, iou: float):
154
  if image is None:
155
  return None, [], "No image provided.", None, None
@@ -211,26 +265,26 @@ def detect_video(video_path: str, conf: float, iou: float, max_frames: int = 300
211
 
212
  def export_pdf_img(summary: str, table_rows: List[dict], annotated_tmp_jpg: Optional[str]):
213
  counts = _count_by_class(table_rows or [])
214
- return _save_pdf("Airspace Drone Detector — Image Report", summary or "No summary.", counts,
215
  annotated_tmp_jpg if annotated_tmp_jpg and os.path.exists(annotated_tmp_jpg) else None)
216
 
217
  def export_pdf_vid(summary: str, counts: dict):
218
- return _save_pdf("Airspace Drone Detector — Video Report", summary or "No summary.", counts or {}, None)
219
 
220
- # -------------------
221
- # UI (embedded-local samples + uploads)
222
- # -------------------
223
  NOTE = (
224
- "Model: VisDrone (aerial **cars/pedestrians/vehicles**). It does **not** include a 'drone' class. "
225
- "Use top‑down scenes with people/traffic for best results."
226
  )
227
 
228
- with gr.Blocks(title="Aerial Object Detector (VisDrone)") as demo:
229
  gr.Markdown(
230
  """
231
- # Aerial Object Detector (Pretrained on VisDrone)
232
- Use the **embedded samples** or your own uploads.
233
- Exports: **CSV** and **PDF** reports.
234
  """
235
  )
236
 
@@ -246,7 +300,7 @@ Exports: **CSV** and **PDF** reports.
246
  with gr.Column():
247
  conf_img = gr.Slider(0.05, 0.8, 0.35, step=0.05, label="Confidence")
248
  iou_img = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
249
- load_embed_img = gr.Button("Load Embedded Sample Image")
250
  run_img = gr.Button("Run Detection")
251
  gr.Markdown(NOTE)
252
 
@@ -262,9 +316,7 @@ Exports: **CSV** and **PDF** reports.
262
  annotated_tmp_img_path = gr.State(value=None)
263
 
264
  def _load_embed_img():
265
- if os.path.exists(EMBED_IMG):
266
- return EMBED_IMG
267
- return None
268
 
269
  load_embed_img.click(fn=_load_embed_img, outputs=[image_in])
270
 
@@ -294,7 +346,7 @@ Exports: **CSV** and **PDF** reports.
294
  conf_vid = gr.Slider(0.05, 0.8, 0.35, step=0.05, label="Confidence")
295
  iou_vid = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
296
  max_frames = gr.Slider(60, 2000, 300, step=10, label="Max frames to process")
297
- load_embed_vid = gr.Button("Load Embedded Sample Video")
298
  run_vid = gr.Button("Run Detection")
299
  gr.Markdown(NOTE)
300
 
@@ -308,9 +360,7 @@ Exports: **CSV** and **PDF** reports.
308
  pdf_vid_path = gr.File(label="PDF Report", interactive=False)
309
 
310
  def _load_embed_vid():
311
- if os.path.exists(EMBED_VID):
312
- return EMBED_VID
313
- return None
314
 
315
  load_embed_vid.click(fn=_load_embed_vid, outputs=[video_in])
316
 
@@ -338,11 +388,13 @@ Exports: **CSV** and **PDF** reports.
338
  outputs=[pdf_vid_path],
339
  )
340
 
 
 
341
  gr.Markdown(
342
  f"""
343
- **Weights:** `{REPO_ID}/{FILENAME}` (downloaded lazily)
344
- **Diagnostics** FFmpeg: {'Yes' if _ffmpeg_ok() else 'No'} • Python: 3.10
345
- **Tip:** For true *drone* detection, I can swap in a UAV‑specific model. Say the word and I’ll rewire it.
346
  """
347
  )
348
 
 
5
  import json
6
  import gradio as gr
7
 
8
+ # =========================================================
9
+ # Config — you can override these via Space Secrets / Env
10
+ # =========================================================
11
+ # If you know the exact HF repo + file you want, set:
12
+ # HF_MODEL_REPO = "owner/repo"
13
+ # HF_MODEL_FILE = "path/to/weights.pt"
14
+ HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "").strip()
15
+ HF_MODEL_FILE = os.getenv("HF_MODEL_FILE", "").strip()
16
+
17
+ # Fallback candidates (tried in order) — real drone/UAV detectors
18
+ MODEL_CANDIDATES = []
19
+ if HF_MODEL_REPO and HF_MODEL_FILE:
20
+ MODEL_CANDIDATES.append((HF_MODEL_REPO, HF_MODEL_FILE))
21
+
22
+ # A couple of known community models. If one is unavailable, the next is tried.
23
+ MODEL_CANDIDATES += [
24
+ ("keremberke/yolov8n-drone-detection", "best.pt"), # small, fast
25
+ ("keremberke/yolov8m-drone-detection", "best.pt"), # larger, more accurate
26
+ ]
27
+
28
+ # Embedded samples (we’ll download a short drone clip and auto‑extract a frame as the image)
29
  SAMPLES_DIR = "samples"
30
+ EMBED_VID = os.path.join(SAMPLES_DIR, "uav_sample.mp4")
31
+ EMBED_IMG = os.path.join(SAMPLES_DIR, "uav_sample_frame.jpg")
32
+ DRONE_VIDEO_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/drone.mp4"
33
 
34
+ # =========================================================
35
  # Lazy state
36
+ # =========================================================
37
  _model = None
38
  _model_err = None
39
  _model_names = None
40
  _ffmpeg_status = None
41
+ _loaded_repo = None
42
+ _loaded_file = None
43
 
44
  def _lazy_cv2():
45
  import cv2
 
57
  _ffmpeg_status = False
58
  return _ffmpeg_status
59
 
60
+ def _ensure_samples():
61
+ os.makedirs(SAMPLES_DIR, exist_ok=True)
62
+ # Download drone video if missing
63
+ if not os.path.exists(EMBED_VID):
64
+ try:
65
+ import requests
66
+ r = requests.get(DRONE_VIDEO_URL, timeout=30)
67
+ r.raise_for_status()
68
+ with open(EMBED_VID, "wb") as f:
69
+ f.write(r.content)
70
+ except Exception:
71
+ pass
72
+ # Extract one frame from the video as the image sample
73
+ if os.path.exists(EMBED_VID) and not os.path.exists(EMBED_IMG):
74
+ try:
75
+ cv2 = _lazy_cv2()
76
+ cap = cv2.VideoCapture(EMBED_VID)
77
+ # Skip a few frames so the drone is centered
78
+ frame_no = 15
79
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
80
+ ok, frame = cap.read()
81
+ cap.release()
82
+ if ok and frame is not None:
83
+ cv2.imwrite(EMBED_IMG, frame)
84
+ except Exception:
85
+ pass
86
+
87
+ _ensure_samples()
88
+
89
+ def _download_from_hf(repo_id: str, filename: str) -> str:
90
  from huggingface_hub import hf_hub_download
91
+ return hf_hub_download(repo_id=repo_id, filename=filename)
92
 
93
  def _get_model(conf: float, iou: float):
94
+ """Try to load a UAV-specific YOLO model from the candidate list."""
95
+ global _model, _model_err, _model_names, _loaded_repo, _loaded_file
96
  if _model is None and _model_err is None:
97
+ from ultralytics import YOLO
98
+ last_err = None
99
+ for repo, file in MODEL_CANDIDATES:
 
 
 
100
  try:
101
+ weights_path = _download_from_hf(repo, file)
102
+ m = YOLO(weights_path)
103
+ m.overrides["max_det"] = 300
104
+ _model = m
105
+ _loaded_repo, _loaded_file = repo, file
106
+ try:
107
+ _model_names = m.model.names if hasattr(m, "model") else None
108
+ except Exception:
109
+ _model_names = None
110
+ break
111
+ except Exception as e:
112
+ last_err = e
113
+ continue
114
+ if _model is None and last_err is not None:
115
+ _model_err = f"Model load failed. Tried: {MODEL_CANDIDATES}. Last error: {last_err}"
116
  if _model_err:
117
  raise RuntimeError(_model_err)
118
  _model.overrides["conf"] = float(conf)
119
  _model.overrides["iou"] = float(iou)
120
  return _model
121
 
122
+ # =========================================================
123
  # Helpers
124
+ # =========================================================
125
  def _results_to_rows(results) -> List[dict]:
126
  rows: List[dict] = []
127
  if not results:
 
201
  c.showPage(); c.save()
202
  return out_path
203
 
204
+ # =========================================================
205
  # Inference
206
+ # =========================================================
207
  def detect_image(image, conf: float, iou: float):
208
  if image is None:
209
  return None, [], "No image provided.", None, None
 
265
 
266
  def export_pdf_img(summary: str, table_rows: List[dict], annotated_tmp_jpg: Optional[str]):
267
  counts = _count_by_class(table_rows or [])
268
+ return _save_pdf("UAV Detector — Image Report", summary or "No summary.", counts,
269
  annotated_tmp_jpg if annotated_tmp_jpg and os.path.exists(annotated_tmp_jpg) else None)
270
 
271
  def export_pdf_vid(summary: str, counts: dict):
272
+ return _save_pdf("UAV Detector — Video Report", summary or "No summary.", counts or {}, None)
273
 
274
+ # =========================================================
275
+ # UI (embedded UAV samples + uploads)
276
+ # =========================================================
277
  NOTE = (
278
+ "UAV model: detects drones (class names vary per checkpoint, e.g., 'drone', 'uav'). "
279
+ "Use scenes where the drone occupies enough pixels (≥ 30–40 px on the short side)."
280
  )
281
 
282
+ with gr.Blocks(title="UAV / Drone Detector (YOLO)") as demo:
283
  gr.Markdown(
284
  """
285
+ # UAV / Drone Detector (Pretrained YOLO)
286
+ We embedded a **drone video** and auto‑extracted an **image frame** so you can test immediately.
287
+ Use your own uploads too. Exports: **CSV** and **PDF**.
288
  """
289
  )
290
 
 
300
  with gr.Column():
301
  conf_img = gr.Slider(0.05, 0.8, 0.35, step=0.05, label="Confidence")
302
  iou_img = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
303
+ load_embed_img = gr.Button("Load Embedded UAV Image")
304
  run_img = gr.Button("Run Detection")
305
  gr.Markdown(NOTE)
306
 
 
316
  annotated_tmp_img_path = gr.State(value=None)
317
 
318
  def _load_embed_img():
319
+ return EMBED_IMG if os.path.exists(EMBED_IMG) else None
 
 
320
 
321
  load_embed_img.click(fn=_load_embed_img, outputs=[image_in])
322
 
 
346
  conf_vid = gr.Slider(0.05, 0.8, 0.35, step=0.05, label="Confidence")
347
  iou_vid = gr.Slider(0.1, 0.9, 0.45, step=0.05, label="NMS IoU")
348
  max_frames = gr.Slider(60, 2000, 300, step=10, label="Max frames to process")
349
+ load_embed_vid = gr.Button("Load Embedded UAV Video")
350
  run_vid = gr.Button("Run Detection")
351
  gr.Markdown(NOTE)
352
 
 
360
  pdf_vid_path = gr.File(label="PDF Report", interactive=False)
361
 
362
  def _load_embed_vid():
363
+ return EMBED_VID if os.path.exists(EMBED_VID) else None
 
 
364
 
365
  load_embed_vid.click(fn=_load_embed_vid, outputs=[video_in])
366
 
 
388
  outputs=[pdf_vid_path],
389
  )
390
 
391
+ # Footer / diagnostics
392
+ model_str = f"{_loaded_repo}/{_loaded_file}" if _loaded_repo else "loading on first run"
393
  gr.Markdown(
394
  f"""
395
+ **Model:** {model_str}
396
+ **Diagnostics:** FFmpeg: {'Yes' if _ffmpeg_ok() else 'No'} • Python: 3.10
397
+ If loading fails, set Space Secrets `HF_MODEL_REPO` and `HF_MODEL_FILE` to a known drone checkpoint.
398
  """
399
  )
400