rba28 commited on
Commit
b7e601b
·
verified ·
1 Parent(s): 716dac8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -53
app.py CHANGED
@@ -3,26 +3,29 @@ import io
3
  import time
4
  import json
5
  import tempfile
6
- from typing import List, Dict, Tuple
7
 
8
  import cv2
9
  import gradio as gr
10
  import numpy as np
11
  import pandas as pd
 
12
 
 
13
  from ultralyticsplus import YOLO, render_result
14
 
15
  # =========================
16
  # CONFIG
17
  # =========================
18
  MODEL_ID = "mshamrai/yolov8s-visdrone"
 
19
  SAMPLES_DIR = "samples"
20
  SAMPLE_IMAGE = os.path.join(SAMPLES_DIR, "drone_sample.jpg")
21
  SAMPLE_VIDEO = os.path.join(SAMPLES_DIR, "airspace_sample.mp4")
22
 
 
23
  SAMPLE_URLS = {
24
  SAMPLE_IMAGE: "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/airplane.jpg",
25
- # This is a small demo clip just to validate pipeline; replace with your own short UAV/airspace clip if you prefer.
26
  SAMPLE_VIDEO: "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/short_harvard_bridge.mp4",
27
  }
28
 
@@ -35,35 +38,56 @@ def _ensure_samples():
35
  if os.path.exists(local_path):
36
  continue
37
  try:
38
- import requests
39
- r = requests.get(url, timeout=15)
40
  r.raise_for_status()
41
  with open(local_path, "wb") as f:
42
  f.write(r.content)
43
  except Exception:
44
- # If download fails (e.g., no internet policy), we just skip; UI still works with user uploads.
45
  pass
46
 
47
  _ensure_samples()
48
 
49
  # =========================
50
- # MODEL
 
 
 
 
 
 
 
 
 
 
51
  # =========================
52
- _model = None
53
- def load_model(conf: float, iou: float):
54
- global _model
55
- if _model is None:
56
- _model = YOLO(MODEL_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  _model.overrides["conf"] = float(conf)
58
  _model.overrides["iou"] = float(iou)
59
- _model.overrides["max_det"] = 300
60
  return _model
61
 
62
  # =========================
63
  # UTILS
64
  # =========================
65
  def results_to_rows(results) -> List[dict]:
66
- rows = []
67
  if not results:
68
  return rows
69
  r = results[0]
@@ -86,18 +110,28 @@ def results_to_rows(results) -> List[dict]:
86
  return rows
87
 
88
  def dict_count_by_class(rows: List[dict]) -> Dict[str, int]:
89
- tally = {}
90
  for r in rows:
91
  tally[r["class"]] = tally.get(r["class"], 0) + 1
92
  return tally
93
 
94
- def write_video(path: str, fps: float, w: int, h: int):
95
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
96
- return cv2.VideoWriter(path, fourcc, fps, (w, h))
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def save_dataframe_to_csv(rows: List[dict]) -> str:
99
  if not rows:
100
- # create an empty CSV for consistency
101
  df = pd.DataFrame(columns=["class","confidence","x1","y1","x2","y2","width","height"])
102
  else:
103
  df = pd.DataFrame(rows)
@@ -108,8 +142,8 @@ def save_dataframe_to_csv(rows: List[dict]) -> str:
108
  def save_pdf_report(title: str,
109
  summary_text: str,
110
  counts: Dict[str, int],
111
- annotated_image_path: str | None = None) -> str:
112
- # Light-weight PDF (no external dependencies besides reportlab)
113
  from reportlab.lib.pagesizes import A4
114
  from reportlab.pdfgen import canvas
115
  from reportlab.lib.units import cm
@@ -125,7 +159,7 @@ def save_pdf_report(title: str,
125
  y -= 1.2*cm
126
 
127
  c.setFont("Helvetica", 11)
128
- for line in summary_text.splitlines():
129
  c.drawString(2*cm, y, line[:110])
130
  y -= 0.7*cm
131
 
@@ -166,17 +200,21 @@ def detect_on_image(image: np.ndarray, conf: float, iou: float):
166
  model = load_model(conf, iou)
167
  results = model.predict(image, imgsz=960, verbose=False)
168
  rows = results_to_rows(results)
169
- annotated = render_result(image, results[0])
170
 
171
  counts = dict_count_by_class(rows)
172
  summary = "Detections: " + ", ".join(f"{k}: {v}" for k, v in counts.items()) if rows else "No objects detected."
173
 
174
- # Save a temp annotated image for PDF export convenience
175
  tmp_img = os.path.join(tempfile.gettempdir(), f"annotated_{int(time.time())}.jpg")
176
- cv2.imwrite(tmp_img, annotated[:, :, ::-1]) # BGR->RGB guard if needed
177
- csv_path = save_dataframe_to_csv(rows)
 
 
 
178
 
179
- return annotated, rows, summary, csv_path, tmp_img
 
180
 
181
  def detect_on_video(video_path: str, conf: float, iou: float, max_frames: int = 300):
182
  if not video_path:
@@ -191,37 +229,40 @@ def detect_on_video(video_path: str, conf: float, iou: float, max_frames: int =
191
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 1280)
192
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 720)
193
 
194
- out_path = os.path.join(tempfile.gettempdir(), f"annotated_{int(time.time())}.mp4")
195
- writer = write_video(out_path, fps, w, h)
 
 
196
 
197
- total_counts = {}
198
  frame_idx = 0
199
- while True:
200
- ok, frame = cap.read()
201
- if not ok:
202
- break
203
- frame_idx += 1
204
- if frame_idx > int(max_frames):
205
- break
206
-
207
- results = model.predict(frame, imgsz=960, verbose=False)
208
- for row in results_to_rows(results):
209
- total_counts[row["class"]] = total_counts.get(row["class"], 0) + 1
210
 
211
- annotated = render_result(frame, results[0])
212
- writer.write(annotated)
 
213
 
214
- cap.release()
215
- writer.release()
 
 
 
216
 
217
  summary = "Detections (frame-wise tallies): " + ", ".join(f"{k}: {v}" for k, v in total_counts.items()) if total_counts else "No objects detected."
218
- # For videos, CSV is a tally (not per-box) to keep file small
219
  rows = [{"class": k, "count": v} for k, v in sorted(total_counts.items())]
220
  csv_path = save_dataframe_to_csv(rows)
221
 
222
  return out_path, total_counts, summary, csv_path
223
 
224
- def export_pdf_image(summary: str, table_rows: List[dict], annotated_tmp_jpg: str):
225
  counts = dict_count_by_class(table_rows or [])
226
  pdf_path = save_pdf_report(
227
  title="Airspace Drone Detector — Image Report",
@@ -281,12 +322,12 @@ No dataset or training required — just run it.
281
  pdf_img_btn = gr.Button("Generate PDF Report")
282
  pdf_img_path = gr.File(label="PDF Report", interactive=False)
283
 
284
- # Hidden state for annotated path (for PDF embedding)
285
  annotated_tmp_img_path = gr.State(value=None)
286
 
287
  def _run_img(image, conf, iou):
288
- annotated, rows, summary, csv_path, tmp_img = detect_on_image(image, conf, iou)
289
- return annotated, rows, summary, csv_path, tmp_img
290
 
291
  run_img.click(
292
  fn=_run_img,
@@ -300,7 +341,6 @@ No dataset or training required — just run it.
300
  outputs=[pdf_img_path],
301
  )
302
 
303
- # Prefilled example (if sample exists)
304
  if os.path.exists(SAMPLE_IMAGE):
305
  gr.Examples(
306
  examples=[[SAMPLE_IMAGE]],
@@ -352,9 +392,13 @@ No dataset or training required — just run it.
352
  )
353
 
354
  gr.Markdown(
355
- """
356
- **Model:** `mshamrai/yolov8s-visdrone` (pretrained; pulled via `ultralyticsplus`)
357
- **Credits:** Ultralytics and VisDrone community weights.
 
 
 
 
358
  """
359
  )
360
 
 
3
  import time
4
  import json
5
  import tempfile
6
+ from typing import List, Dict, Tuple, Optional
7
 
8
  import cv2
9
  import gradio as gr
10
  import numpy as np
11
  import pandas as pd
12
+ import requests
13
 
14
+ # YOLO wrapper (pulls pretrained model from Hugging Face by ID)
15
  from ultralyticsplus import YOLO, render_result
16
 
17
  # =========================
18
  # CONFIG
19
  # =========================
20
  MODEL_ID = "mshamrai/yolov8s-visdrone"
21
+
22
  SAMPLES_DIR = "samples"
23
  SAMPLE_IMAGE = os.path.join(SAMPLES_DIR, "drone_sample.jpg")
24
  SAMPLE_VIDEO = os.path.join(SAMPLES_DIR, "airspace_sample.mp4")
25
 
26
+ # Small public files for smoke testing (replace with your own if desired)
27
  SAMPLE_URLS = {
28
  SAMPLE_IMAGE: "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/airplane.jpg",
 
29
  SAMPLE_VIDEO: "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/short_harvard_bridge.mp4",
30
  }
31
 
 
38
  if os.path.exists(local_path):
39
  continue
40
  try:
41
+ r = requests.get(url, timeout=20)
 
42
  r.raise_for_status()
43
  with open(local_path, "wb") as f:
44
  f.write(r.content)
45
  except Exception:
46
+ # If download fails (e.g., offline build), UI still works with user uploads
47
  pass
48
 
49
  _ensure_samples()
50
 
51
  # =========================
52
+ # DIAGNOSTICS
53
+ # =========================
54
+ def _ffmpeg_ok() -> bool:
55
+ try:
56
+ v = cv2.getBuildInformation()
57
+ return ("FFMPEG:YES" in v) or ("FFMPEG: YES" in v)
58
+ except Exception:
59
+ return False
60
+
61
+ # =========================
62
+ # MODEL (robust lazy loader)
63
  # =========================
64
+ _model: Optional[YOLO] = None
65
+ _model_error: Optional[str] = None
66
+
67
+ def load_model(conf: float, iou: float) -> YOLO:
68
+ """
69
+ Load the pretrained YOLO model once and set runtime thresholds.
70
+ Raises RuntimeError if loading previously failed.
71
+ """
72
+ global _model, _model_error
73
+ if _model is None and _model_error is None:
74
+ try:
75
+ m = YOLO(MODEL_ID) # pulls weights from HF
76
+ m.overrides["max_det"] = 300
77
+ _model = m
78
+ except Exception as e:
79
+ _model_error = f"Model load failed: {e}"
80
+ if _model_error:
81
+ raise RuntimeError(_model_error)
82
  _model.overrides["conf"] = float(conf)
83
  _model.overrides["iou"] = float(iou)
 
84
  return _model
85
 
86
  # =========================
87
  # UTILS
88
  # =========================
89
  def results_to_rows(results) -> List[dict]:
90
+ rows: List[dict] = []
91
  if not results:
92
  return rows
93
  r = results[0]
 
110
  return rows
111
 
112
  def dict_count_by_class(rows: List[dict]) -> Dict[str, int]:
113
+ tally: Dict[str, int] = {}
114
  for r in rows:
115
  tally[r["class"]] = tally.get(r["class"], 0) + 1
116
  return tally
117
 
118
+ def write_video(base_path: str, fps: float, w: int, h: int) -> Tuple[cv2.VideoWriter, str]:
119
+ """
120
+ Try MP4 first; if it fails (codec not available), fall back to AVI/MJPG.
121
+ Returns (writer, output_path).
122
+ """
123
+ # MP4
124
+ mp4_path = base_path if base_path.endswith(".mp4") else base_path + ".mp4"
125
+ writer = cv2.VideoWriter(mp4_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
126
+ if writer is not None and getattr(writer, "isOpened", lambda: False)():
127
+ return writer, mp4_path
128
+ # Fallback AVI
129
+ avi_path = os.path.splitext(mp4_path)[0] + ".avi"
130
+ writer = cv2.VideoWriter(avi_path, cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))
131
+ return writer, avi_path
132
 
133
  def save_dataframe_to_csv(rows: List[dict]) -> str:
134
  if not rows:
 
135
  df = pd.DataFrame(columns=["class","confidence","x1","y1","x2","y2","width","height"])
136
  else:
137
  df = pd.DataFrame(rows)
 
142
  def save_pdf_report(title: str,
143
  summary_text: str,
144
  counts: Dict[str, int],
145
+ annotated_image_path: Optional[str] = None) -> str:
146
+ # Lightweight PDF via reportlab
147
  from reportlab.lib.pagesizes import A4
148
  from reportlab.pdfgen import canvas
149
  from reportlab.lib.units import cm
 
159
  y -= 1.2*cm
160
 
161
  c.setFont("Helvetica", 11)
162
+ for line in (summary_text or "").splitlines():
163
  c.drawString(2*cm, y, line[:110])
164
  y -= 0.7*cm
165
 
 
200
  model = load_model(conf, iou)
201
  results = model.predict(image, imgsz=960, verbose=False)
202
  rows = results_to_rows(results)
203
+ annotated = render_result(image, results[0]) # returns np.ndarray in BGR
204
 
205
  counts = dict_count_by_class(rows)
206
  summary = "Detections: " + ", ".join(f"{k}: {v}" for k, v in counts.items()) if rows else "No objects detected."
207
 
208
+ # Save annotated image (ensure correct color order for disk write)
209
  tmp_img = os.path.join(tempfile.gettempdir(), f"annotated_{int(time.time())}.jpg")
210
+ try:
211
+ # render_result returns BGR; cv2.imwrite expects BGR, so write directly
212
+ cv2.imwrite(tmp_img, annotated)
213
+ except Exception:
214
+ tmp_img = None
215
 
216
+ csv_path = save_dataframe_to_csv(rows)
217
+ return annotated[:, :, ::-1], rows, summary, csv_path, tmp_img # Convert to RGB for Gradio Image
218
 
219
  def detect_on_video(video_path: str, conf: float, iou: float, max_frames: int = 300):
220
  if not video_path:
 
229
  w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 1280)
230
  h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 720)
231
 
232
+ writer, out_path = write_video(os.path.join(tempfile.gettempdir(), f"annotated_{int(time.time())}"), fps, w, h)
233
+ if writer is None or (hasattr(writer, "isOpened") and not writer.isOpened()):
234
+ cap.release()
235
+ return None, None, "Video writer could not open. Try another format/resolution.", None
236
 
237
+ total_counts: Dict[str, int] = {}
238
  frame_idx = 0
239
+ try:
240
+ while True:
241
+ ok, frame = cap.read()
242
+ if not ok:
243
+ break
244
+ frame_idx += 1
245
+ if frame_idx > int(max_frames):
246
+ break
 
 
 
247
 
248
+ results = model.predict(frame, imgsz=960, verbose=False)
249
+ for row in results_to_rows(results):
250
+ total_counts[row["class"]] = total_counts.get(row["class"], 0) + 1
251
 
252
+ annotated = render_result(frame, results[0])
253
+ writer.write(annotated)
254
+ finally:
255
+ cap.release()
256
+ writer.release()
257
 
258
  summary = "Detections (frame-wise tallies): " + ", ".join(f"{k}: {v}" for k, v in total_counts.items()) if total_counts else "No objects detected."
259
+ # For videos, export a compact CSV tally
260
  rows = [{"class": k, "count": v} for k, v in sorted(total_counts.items())]
261
  csv_path = save_dataframe_to_csv(rows)
262
 
263
  return out_path, total_counts, summary, csv_path
264
 
265
+ def export_pdf_image(summary: str, table_rows: List[dict], annotated_tmp_jpg: Optional[str]):
266
  counts = dict_count_by_class(table_rows or [])
267
  pdf_path = save_pdf_report(
268
  title="Airspace Drone Detector — Image Report",
 
322
  pdf_img_btn = gr.Button("Generate PDF Report")
323
  pdf_img_path = gr.File(label="PDF Report", interactive=False)
324
 
325
+ # Hidden state for annotated-image path (for PDF embedding)
326
  annotated_tmp_img_path = gr.State(value=None)
327
 
328
  def _run_img(image, conf, iou):
329
+ annotated_rgb, rows, summary, csv_path, tmp_img = detect_on_image(image, conf, iou)
330
+ return annotated_rgb, rows, summary, csv_path, tmp_img
331
 
332
  run_img.click(
333
  fn=_run_img,
 
341
  outputs=[pdf_img_path],
342
  )
343
 
 
344
  if os.path.exists(SAMPLE_IMAGE):
345
  gr.Examples(
346
  examples=[[SAMPLE_IMAGE]],
 
392
  )
393
 
394
  gr.Markdown(
395
+ f"""
396
+ **Model:** `{MODEL_ID}` (pretrained; pulled via `ultralyticsplus`)
397
+ **Diagnostics**
398
+ - FFmpeg available: {'Yes' if _ffmpeg_ok() else 'No'}
399
+ - Python: 3.10 (set via runtime.txt)
400
+ - Torch: 2.3.1 (pinned in requirements)
401
+ - Ultralytics: 8.3.x
402
  """
403
  )
404