danielhshi8224 commited on
Commit
9aaf8dc
·
1 Parent(s): 649fb36

obj detect only

Browse files
Files changed (1) hide show
  1. app.py +105 -216
app.py CHANGED
@@ -1,259 +1,165 @@
1
- #Main Gradio app ith image classification and object detection tabs
2
- import gradio as gr
3
- import torch
4
- import torch.nn.functional as F
5
- from transformers import AutoImageProcessor, AutoModelForImageClassification
6
- from PIL import Image
7
  import os
8
  import csv
9
  import tempfile
10
  from pathlib import Path
11
- from ultralytics import YOLO
12
- # ultralytics YOLO import (for object detection)
 
 
 
 
13
  try:
14
  from ultralytics import YOLO
15
  except Exception:
16
  YOLO = None
17
 
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
- MODEL_ID = "dshi01/convnext-tiny-224-7clss"
20
-
21
- print(f"Loading model from: {MODEL_ID}")
22
- processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
23
- model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
24
- model.eval()
25
-
26
- # (Optional) use model's own labels if present
27
- ID2LABEL = [
28
- model.config.id2label.get(str(i), model.config.id2label.get(i, f"Label_{i}"))
29
- for i in range(model.config.num_labels)
30
- ]
31
- def classify_image(image):
32
- if not isinstance(image, Image.Image):
33
- image = Image.fromarray(image).convert("RGB")
34
-
35
- inputs = processor(images=image, return_tensors="pt")
36
- with torch.no_grad():
37
- logits = model(**inputs).logits
38
- probs = F.softmax(logits, dim=1)[0].tolist()
39
-
40
- return {ID2LABEL[i]: float(p) for i, p in enumerate(probs)}
41
-
42
- # ---------- NEW: batch classify up to 10 images ----------
43
  MAX_BATCH = 10
44
 
45
- def classify_images_batch(files):
46
- """
47
- files: list of gradio UploadedFile (paths) or None
48
- Returns:
49
- - gallery: list of (image, caption)
50
- - table: list of rows for Dataframe
51
- """
52
- if not files:
53
- return [], [], None
54
-
55
- # Keep at most 10
56
- files = files[:MAX_BATCH]
57
-
58
- # Load as PIL
59
- pil_images, names = [], []
60
- for f in files:
61
- path = getattr(f, "name", None) or getattr(f, "path", None) or f
62
- try:
63
- img = Image.open(path).convert("RGB")
64
- pil_images.append(img)
65
- names.append(os.path.basename(path))
66
- except Exception:
67
- # Skip unreadable file
68
- continue
69
-
70
- if not pil_images:
71
- return [], [], None
72
-
73
- # Batch preprocess + forward
74
- inputs = processor(images=pil_images, return_tensors="pt")
75
- with torch.no_grad():
76
- logits = model(**inputs).logits
77
- probs = F.softmax(logits, dim=1)
78
-
79
- # Build outputs
80
- gallery = []
81
- table_rows = [] # [filename, top1_label, top1_conf, top3_labels, top3_confs]
82
-
83
- for idx, (img, fname) in enumerate(zip(pil_images, names)):
84
- p = probs[idx].tolist()
85
- top_idxs = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:3]
86
- top1 = top_idxs[0]
87
- caption = f"{ID2LABEL[top1]} ({p[top1]:.2%})"
88
 
89
- gallery.append((img, f"{fname}\n{caption}"))
 
 
 
 
90
 
91
- top3_labels = [ID2LABEL[i] for i in top_idxs]
92
- top3_scores = [round(p[i], 4) for i in top_idxs]
93
- table_rows.append([
94
- fname,
95
- ID2LABEL[top1],
96
- round(p[top1], 4),
97
- ", ".join(top3_labels),
98
- ", ".join(map(str, top3_scores)),
99
- ])
100
-
101
- # Create CSV for download
102
- csv_path = None
103
  try:
104
- # Write CSV into a temp file inside project dir so Gradio can serve it
105
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", prefix="predictions_", dir=BASE_DIR, mode="w", newline='', encoding='utf-8')
106
- writer = csv.writer(tmp)
107
- # headers
108
- writer.writerow(["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"])
109
- for row in table_rows:
110
- writer.writerow(row)
111
- tmp.flush()
112
- tmp.close()
113
- csv_path = tmp.name
114
- except Exception:
115
- # If CSV can't be created, return None for the file but keep other outputs
116
- csv_path = None
117
-
118
- return gallery, table_rows, csv_path
119
-
120
 
121
- # ---------- NEW: YOLO object detection for multi-image upload ----------
122
- YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
123
  _yolo_model = None
124
  def _load_yolo():
 
125
  global _yolo_model
126
  if _yolo_model is not None:
127
  return _yolo_model
128
  if YOLO is None:
129
- raise RuntimeError("ultralytics package not installed. Please install 'ultralytics'.")
130
- if not os.path.exists(YOLO_WEIGHTS):
131
- # Try current directory too
132
- alt = Path.cwd() / "yolo11_best.pt"
133
- if alt.exists():
134
- model_path = str(alt)
135
- else:
136
- raise FileNotFoundError(f"YOLO weights not found at {YOLO_WEIGHTS}. Place yolo11_best.pt in project root.")
137
- else:
138
  model_path = YOLO_WEIGHTS
 
 
 
 
 
 
 
 
 
 
139
 
140
  _yolo_model = YOLO(model_path)
141
  return _yolo_model
142
 
143
-
144
- def detect_objects_batch(files, iou=0.25, conf=0.25):
145
  """
146
- Run YOLO detection on multiple images.
147
- Returns: gallery of annotated images, dataframe rows, csv file path
148
  """
149
  if YOLO is None:
150
  return [], [], None
151
-
152
  if not files:
153
  return [], [], None
154
 
155
- # Load model
156
  try:
157
  ymodel = _load_yolo()
158
  except Exception as e:
159
  print("YOLO load error:", e)
160
  return [], [], None
161
 
162
- annotated_paths = []
163
- table_rows = []
164
- gallery = []
165
 
166
  for f in files[:MAX_BATCH]:
167
  path = getattr(f, "name", None) or getattr(f, "path", None) or f
168
  try:
169
- # Run predict; returns a Results object list
170
  results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
171
  except Exception as e:
172
  print(f"Detection failed for {path}:", e)
173
  continue
174
-
175
- # results is list-like; take first
176
  res = results[0]
177
 
178
- # Prepare annotation image using res.plot() so boxes+confidences are drawn
179
  ann_path = None
180
  try:
181
- ann_img = res.plot() # returns numpy array with annotations
182
- from PIL import Image as PILImage
183
- ann_pil = PILImage.fromarray(ann_img)
184
  out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
185
  os.makedirs(out_dir, exist_ok=True)
186
- ann_filename = os.path.splitext(os.path.basename(path))[0] + "_annotated.jpg"
187
  ann_path = os.path.join(out_dir, ann_filename)
188
  ann_pil.save(ann_path)
189
  except Exception:
190
- # Fallback to ultralytics save if plot() isn't available
191
  try:
192
  out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
193
  res.save(save_dir=out_dir)
194
- saved_files = res.files if hasattr(res, 'files') else []
195
  ann_path = saved_files[0] if saved_files else None
196
  except Exception:
197
  ann_path = None
198
 
199
- # Build table rows from detections
200
- boxes = res.boxes if hasattr(res, 'boxes') else None
201
  if boxes is None or len(boxes) == 0:
202
  table_rows.append([os.path.basename(path), 0, "", "", ""])
203
- if ann_path and os.path.exists(ann_path):
204
- gallery.append((Image.open(ann_path).convert('RGB'), f"{os.path.basename(path)}\nNo detections"))
205
- else:
206
- gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\nNo detections"))
207
  continue
208
 
209
- det_labels = []
210
- det_scores = []
211
- det_boxes = []
212
  for box in boxes:
213
- # box.cls, box.conf, box.xyxy
214
- cls = int(box.cls.cpu().item()) if hasattr(box, 'cls') else None
215
- # use .item() to extract scalar and avoid numpy deprecation warnings
216
- if hasattr(box, 'conf'):
 
217
  try:
218
- confscore = float(box.conf.cpu().item())
219
  except Exception:
220
- try:
221
- confscore = float(box.conf.item())
222
- except Exception:
223
- confscore = None
224
- else:
225
- confscore = None
226
-
227
- # extract xyxy coords; box.xyxy may be shape (1,4) -> nested list after .tolist()
228
  coords = []
229
- if hasattr(box, 'xyxy'):
230
  try:
231
  arr = box.xyxy.cpu().numpy()
232
- # handle nested shape (1,4) or (4,)
233
- if getattr(arr, 'ndim', None) == 2 and arr.shape[0] == 1:
234
  coords = arr[0].tolist()
235
- elif getattr(arr, 'ndim', None) == 1:
236
  coords = arr.tolist()
237
  else:
238
  coords = arr.reshape(-1).tolist()
239
  except Exception:
240
- # fallback: try to call tolist()
241
  try:
242
  coords = box.xyxy.tolist()
243
  except Exception:
244
  coords = []
245
 
246
- # append detection info
247
  det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
248
  det_scores.append(round(confscore, 4) if confscore is not None else "")
249
- # round and store coords
250
  try:
251
  det_boxes.append([round(float(x), 2) for x in coords])
252
  except Exception:
253
- # fallback: store raw repr
254
  det_boxes.append([str(coords)])
255
 
256
- # create readable label:confidence pairs
257
  label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
258
  boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
259
  table_rows.append([
@@ -261,28 +167,25 @@ def detect_objects_batch(files, iou=0.25, conf=0.25):
261
  len(det_labels),
262
  ", ".join(label_conf_pairs),
263
  ", ".join(boxes_repr),
264
- "; ".join([str(b) for b in det_boxes])
265
  ])
266
 
267
- # Use annotated image if exists
268
- if ann_path and os.path.exists(ann_path):
269
- try:
270
- gallery.append((Image.open(ann_path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
271
- except Exception:
272
- gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
273
- else:
274
- gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
275
 
276
  # write CSV
277
  csv_path = None
278
  try:
279
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR, mode="w", newline='', encoding='utf-8')
 
 
 
280
  writer = csv.writer(tmp)
281
  writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
282
  for r in table_rows:
283
  writer.writerow(r)
284
- tmp.flush()
285
- tmp.close()
286
  csv_path = tmp.name
287
  except Exception as e:
288
  print("Failed to write CSV:", e)
@@ -291,49 +194,35 @@ def detect_objects_batch(files, iou=0.25, conf=0.25):
291
  return gallery, table_rows, csv_path
292
 
293
  # ---------- UI ----------
294
- single = gr.Interface(
295
- fn=classify_image,
296
- inputs=gr.Image(type="pil", label="Upload Underwater Image"),
297
- outputs=gr.Label(num_top_classes=len(ID2LABEL), label="Species Classification"),
298
- title="🌊 BenthicAI - Single Image",
299
- description="Classify one image into one of 7 benthic species."
300
- )
301
-
302
- batch = gr.Interface(
303
- fn=classify_images_batch,
304
- inputs=gr.Files(label="Upload up to 10 images"),
305
- outputs=[
306
- gr.Gallery(label="Results (Top-1 in caption)", height=500, rows=3),
307
- gr.Dataframe(
308
- headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
309
- label="Predictions Table",
310
- wrap=True
311
- )
312
- , gr.File(label="Download CSV")
313
- ],
314
- title="🌊 BenthicAI - Batch (up to 10)",
315
- description="Upload multiple images (max 10). Outputs a gallery with captions and a table of top predictions.",
316
- )
317
-
318
- demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])
319
- print(YOLO==None, flush=True)
320
- # Add Object Detection tab if ultralytics available
321
- if YOLO is not None:
322
- detection_iface = gr.Interface(
323
  fn=detect_objects_batch,
324
- inputs=[gr.Files(label="Upload images for detection (max 10)"), gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="conf threshold"), gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="IOU threshold")],
 
 
 
 
325
  outputs=[
326
  gr.Gallery(label="Detections (annotated)", height=500, rows=3),
327
- gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"], label="Detection Table"),
328
- gr.File(label="Download CSV")
 
329
  ],
330
- title="🌊 BenthicAI - Object Detection",
331
- description="Run YOLO object detection on multiple images. Requires 'yolo11_best.pt' in project root."
 
 
 
 
332
  )
333
 
334
- # extend tabs
335
- demo = gr.TabbedInterface([single, batch, detection_iface], ["Single", "Batch", "Detection"])
336
-
337
  if __name__ == "__main__":
338
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
339
-
 
1
+ # app.py Object Detection only (multi-image YOLO, up to 10)
 
 
 
 
 
2
  import os
3
  import csv
4
  import tempfile
5
  from pathlib import Path
6
+ from typing import List, Tuple
7
+
8
+ import gradio as gr
9
+ from PIL import Image
10
+
11
+ # Try import ultralytics (ensure it's in requirements.txt)
12
  try:
13
  from ultralytics import YOLO
14
  except Exception:
15
  YOLO = None
16
 
17
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  MAX_BATCH = 10
19
 
20
+ # Option A: local file baked into Space (easiest if allowed)
21
+ YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Option B (optional): pull from a private HF model repo using a Space secret
24
+ # Set these env vars in your Space if you want auto-download:
25
+ # HF_TOKEN=<read token> YOLO_REPO_ID="yourname/yolo-detector"
26
+ HF_TOKEN = os.environ.get("HF_TOKEN")
27
+ YOLO_REPO_ID = os.environ.get("YOLO_REPO_ID")
28
 
29
+ def _download_from_hub_if_needed() -> str | None:
30
+ """If YOLO_REPO_ID is set, download weights with huggingface_hub; else return None."""
31
+ if not YOLO_REPO_ID:
32
+ return None
 
 
 
 
 
 
 
 
33
  try:
34
+ from huggingface_hub import snapshot_download
35
+ local_dir = snapshot_download(
36
+ repo_id=YOLO_REPO_ID, repo_type="model", token=HF_TOKEN
37
+ )
38
+ # try common filenames
39
+ for name in ("yolo11_best.pt", "best.pt", "yolo.pt", "weights.pt"):
40
+ cand = Path(local_dir) / name
41
+ if cand.exists():
42
+ return str(cand)
43
+ except Exception as e:
44
+ print("[YOLO] Hub download failed:", e)
45
+ return None
 
 
 
 
46
 
 
 
47
  _yolo_model = None
48
  def _load_yolo():
49
+ """Load YOLO weights either from local file or HF Hub."""
50
  global _yolo_model
51
  if _yolo_model is not None:
52
  return _yolo_model
53
  if YOLO is None:
54
+ raise RuntimeError("ultralytics package not installed. Add 'ultralytics' to requirements.txt")
55
+
56
+ model_path = None
57
+ if os.path.exists(YOLO_WEIGHTS):
 
 
 
 
 
58
  model_path = YOLO_WEIGHTS
59
+ else:
60
+ hub_path = _download_from_hub_if_needed()
61
+ if hub_path:
62
+ model_path = hub_path
63
+
64
+ if not model_path:
65
+ raise FileNotFoundError(
66
+ "YOLO weights not found. Either include 'yolo11_best.pt' in the repo root, "
67
+ "or set YOLO_REPO_ID (+ HF_TOKEN if private) to pull from the Hub."
68
+ )
69
 
70
  _yolo_model = YOLO(model_path)
71
  return _yolo_model
72
 
73
+ def detect_objects_batch(files, conf=0.25, iou=0.25):
 
74
  """
75
+ Run YOLO detection on multiple images (up to 10).
76
+ Returns: gallery of annotated images, rows table, csv filepath
77
  """
78
  if YOLO is None:
79
  return [], [], None
 
80
  if not files:
81
  return [], [], None
82
 
 
83
  try:
84
  ymodel = _load_yolo()
85
  except Exception as e:
86
  print("YOLO load error:", e)
87
  return [], [], None
88
 
89
+ gallery, table_rows = [], []
 
 
90
 
91
  for f in files[:MAX_BATCH]:
92
  path = getattr(f, "name", None) or getattr(f, "path", None) or f
93
  try:
 
94
  results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
95
  except Exception as e:
96
  print(f"Detection failed for {path}:", e)
97
  continue
 
 
98
  res = results[0]
99
 
100
+ # annotated image
101
  ann_path = None
102
  try:
103
+ ann_img = res.plot()
104
+ ann_pil = Image.fromarray(ann_img)
 
105
  out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
106
  os.makedirs(out_dir, exist_ok=True)
107
+ ann_filename = Path(path).stem + "_annotated.jpg"
108
  ann_path = os.path.join(out_dir, ann_filename)
109
  ann_pil.save(ann_path)
110
  except Exception:
 
111
  try:
112
  out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
113
  res.save(save_dir=out_dir)
114
+ saved_files = getattr(res, "files", [])
115
  ann_path = saved_files[0] if saved_files else None
116
  except Exception:
117
  ann_path = None
118
 
119
+ # extract detections
120
+ boxes = getattr(res, "boxes", None)
121
  if boxes is None or len(boxes) == 0:
122
  table_rows.append([os.path.basename(path), 0, "", "", ""])
123
+ img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
124
+ else Image.open(path).convert("RGB")
125
+ gallery.append((img_for_gallery, f"{os.path.basename(path)}\nNo detections"))
 
126
  continue
127
 
128
+ det_labels, det_scores, det_boxes = [], [], []
 
 
129
  for box in boxes:
130
+ cls = int(box.cls.cpu().item()) if hasattr(box, "cls") else None
131
+ # conf
132
+ try:
133
+ confscore = float(box.conf.cpu().item()) if hasattr(box, "conf") else None
134
+ except Exception:
135
  try:
136
+ confscore = float(box.conf.item())
137
  except Exception:
138
+ confscore = None
139
+ # xyxy
 
 
 
 
 
 
140
  coords = []
141
+ if hasattr(box, "xyxy"):
142
  try:
143
  arr = box.xyxy.cpu().numpy()
144
+ if getattr(arr, "ndim", None) == 2 and arr.shape[0] == 1:
 
145
  coords = arr[0].tolist()
146
+ elif getattr(arr, "ndim", None) == 1:
147
  coords = arr.tolist()
148
  else:
149
  coords = arr.reshape(-1).tolist()
150
  except Exception:
 
151
  try:
152
  coords = box.xyxy.tolist()
153
  except Exception:
154
  coords = []
155
 
 
156
  det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
157
  det_scores.append(round(confscore, 4) if confscore is not None else "")
 
158
  try:
159
  det_boxes.append([round(float(x), 2) for x in coords])
160
  except Exception:
 
161
  det_boxes.append([str(coords)])
162
 
 
163
  label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
164
  boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
165
  table_rows.append([
 
167
  len(det_labels),
168
  ", ".join(label_conf_pairs),
169
  ", ".join(boxes_repr),
170
+ "; ".join([str(b) for b in det_boxes]),
171
  ])
172
 
173
+ img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
174
+ else Image.open(path).convert("RGB")
175
+ gallery.append((img_for_gallery, f"{os.path.basename(path)}\n{len(det_labels)} detections"))
 
 
 
 
 
176
 
177
  # write CSV
178
  csv_path = None
179
  try:
180
+ tmp = tempfile.NamedTemporaryFile(
181
+ delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR,
182
+ mode="w", newline='', encoding='utf-8'
183
+ )
184
  writer = csv.writer(tmp)
185
  writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
186
  for r in table_rows:
187
  writer.writerow(r)
188
+ tmp.flush(); tmp.close()
 
189
  csv_path = tmp.name
190
  except Exception as e:
191
  print("Failed to write CSV:", e)
 
194
  return gallery, table_rows, csv_path
195
 
196
  # ---------- UI ----------
197
+ if YOLO is None:
198
+ demo = gr.Interface(
199
+ fn=lambda *a, **k: ("Ultralytics not installed; add 'ultralytics' to requirements.txt",),
200
+ inputs=[],
201
+ outputs="text",
202
+ title="🌊 BenthicAI Object Detection",
203
+ description="Ultralytics is not installed."
204
+ )
205
+ else:
206
+ demo = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  fn=detect_objects_batch,
208
+ inputs=[
209
+ gr.Files(label="Upload images (max 10)"),
210
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Conf threshold"),
211
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="IoU threshold"),
212
+ ],
213
  outputs=[
214
  gr.Gallery(label="Detections (annotated)", height=500, rows=3),
215
+ gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"],
216
+ label="Detection Table"),
217
+ gr.File(label="Download CSV"),
218
  ],
219
+ title="🌊 BenthicAI Object Detection",
220
+ description=(
221
+ "Run YOLO object detection on multiple images. "
222
+ "Place 'yolo11_best.pt' in the repo root, OR set YOLO_REPO_ID (+ HF_TOKEN if private) "
223
+ "to fetch from the Hub."
224
+ ),
225
  )
226
 
 
 
 
227
  if __name__ == "__main__":
228
+ demo.launch(server_name="0.0.0.0", server_port=7860)