danielhshi8224 commited on
Commit
649fb36
·
1 Parent(s): adcf9e4

obj detect app

Browse files
Files changed (2) hide show
  1. app.py +339 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Main Gradio app with image classification and object detection tabs
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import os
import csv
import tempfile
from pathlib import Path

# Optional ultralytics YOLO import (for the object detection tab).
# NOTE: importing it unconditionally here would crash the whole app when
# 'ultralytics' is not installed and defeat this try/except guard, so the
# import lives only inside the guard.
try:
    from ultralytics import YOLO
except Exception:
    YOLO = None

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_ID = "dshi01/convnext-tiny-224-7clss"

print(f"Loading model from: {MODEL_ID}")
# Processor comes from the base checkpoint; fine-tuned weights from MODEL_ID.
processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()

# Use the model's own labels when present; id2label keys may be str or int
# depending on how the config was serialized, so try both.
ID2LABEL = [
    model.config.id2label.get(str(i), model.config.id2label.get(i, f"Label_{i}"))
    for i in range(model.config.num_labels)
]
31
def classify_image(image):
    """Classify one image; return a {label: probability} mapping."""
    # Gradio may hand us a numpy array instead of a PIL image.
    pil = image if isinstance(image, Image.Image) else Image.fromarray(image).convert("RGB")

    batch = processor(images=pil, return_tensors="pt")
    with torch.no_grad():
        scores = F.softmax(model(**batch).logits, dim=1)[0].tolist()

    return dict(zip(ID2LABEL, (float(s) for s in scores)))
41
+
42
# ---------- Batch classification (up to 10 images) ----------
MAX_BATCH = 10


def classify_images_batch(files):
    """
    Classify up to MAX_BATCH uploaded images in a single forward pass.

    Args:
        files: list of Gradio uploaded files (objects exposing .name or
               .path, or plain path strings), or None.

    Returns:
        gallery: list of (PIL.Image, caption) pairs for gr.Gallery
        table_rows: rows [filename, top1_label, top1_conf, top3_labels, top3_confs]
        csv_path: path to a downloadable CSV of table_rows, or None on failure
    """
    if not files:
        return [], [], None

    # Keep at most MAX_BATCH uploads.
    files = files[:MAX_BATCH]

    # Load each upload as RGB; silently skip unreadable files.
    pil_images, names = [], []
    for f in files:
        path = getattr(f, "name", None) or getattr(f, "path", None) or f
        try:
            img = Image.open(path).convert("RGB")
        except Exception:
            continue
        pil_images.append(img)
        names.append(os.path.basename(path))

    if not pil_images:
        return [], [], None

    # One batched preprocess + forward pass for all images.
    inputs = processor(images=pil_images, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=1)

    gallery = []
    table_rows = []
    for idx, (img, fname) in enumerate(zip(pil_images, names)):
        p = probs[idx].tolist()
        top_idxs = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:3]
        top1 = top_idxs[0]
        caption = f"{ID2LABEL[top1]} ({p[top1]:.2%})"
        gallery.append((img, f"{fname}\n{caption}"))

        top3_labels = [ID2LABEL[i] for i in top_idxs]
        top3_scores = [round(p[i], 4) for i in top_idxs]
        table_rows.append([
            fname,
            ID2LABEL[top1],
            round(p[top1], 4),
            ", ".join(top3_labels),
            ", ".join(map(str, top3_scores)),
        ])

    # Write the CSV into the project dir so Gradio can serve it for download.
    # The context manager guarantees the handle is closed even on write errors
    # (the original leaked it on the exception path).
    csv_path = None
    try:
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".csv", prefix="predictions_", dir=BASE_DIR,
            mode="w", newline="", encoding="utf-8",
        ) as tmp:
            writer = csv.writer(tmp)
            writer.writerow(["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"])
            writer.writerows(table_rows)
            csv_path = tmp.name
    except Exception:
        # CSV download is best-effort; keep the other outputs on failure.
        csv_path = None

    return gallery, table_rows, csv_path
119
+
120
+
121
# ---------- YOLO object detection for multi-image upload ----------
YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
_yolo_model = None


def _load_yolo():
    """Lazily load the YOLO detector, caching it in module state."""
    global _yolo_model
    if _yolo_model is not None:
        return _yolo_model
    if YOLO is None:
        raise RuntimeError("ultralytics package not installed. Please install 'ultralytics'.")

    # Prefer weights next to this file; fall back to the working directory.
    model_path = YOLO_WEIGHTS
    if not os.path.exists(model_path):
        alt = Path.cwd() / "yolo11_best.pt"
        if not alt.exists():
            raise FileNotFoundError(f"YOLO weights not found at {YOLO_WEIGHTS}. Place yolo11_best.pt in project root.")
        model_path = str(alt)

    _yolo_model = YOLO(model_path)
    return _yolo_model
142
+
143
+
144
def detect_objects_batch(files, iou=0.25, conf=0.25):
    """
    Run YOLO object detection on up to MAX_BATCH uploaded images.

    Args:
        files: list of Gradio uploaded files (or path strings), or None.
        iou:   IoU threshold passed to YOLO NMS. NOTE: callers (gr.Interface)
               pass these positionally — keep the (files, iou, conf) order.
        conf:  confidence threshold passed to YOLO.

    Returns:
        gallery: list of (PIL.Image, caption), annotated when possible
        table_rows: rows [filename, num_detections, labels_with_conf, boxes, raw_boxes]
        csv_path: downloadable CSV path, or None on failure
    """
    if YOLO is None or not files:
        return [], [], None

    try:
        ymodel = _load_yolo()
    except Exception as e:
        print("YOLO load error:", e)
        return [], [], None

    table_rows = []
    gallery = []

    for f in files[:MAX_BATCH]:
        path = getattr(f, "name", None) or getattr(f, "path", None) or f
        fname = os.path.basename(path)
        try:
            # predict() returns a list-like of Results; one image in -> take [0].
            results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
        except Exception as e:
            print(f"Detection failed for {path}:", e)
            continue

        res = results[0]

        # Render an annotated copy (boxes + confidences) into a temp dir under
        # the project root so Gradio can serve it.
        ann_path = None
        try:
            ann_img = res.plot()  # numpy array with annotations drawn
            out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
            ann_filename = os.path.splitext(fname)[0] + "_annotated.jpg"
            ann_path = os.path.join(out_dir, ann_filename)
            Image.fromarray(ann_img).save(ann_path)
        except Exception:
            # Fallback when plot() isn't available.
            # NOTE(review): res.save(save_dir=...) / res.files may not exist on
            # all ultralytics versions — confirm against the pinned release.
            try:
                out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
                res.save(save_dir=out_dir)
                saved_files = res.files if hasattr(res, 'files') else []
                ann_path = saved_files[0] if saved_files else None
            except Exception:
                ann_path = None

        # Path actually shown in the gallery: annotated if it exists, else raw.
        shown_path = ann_path if ann_path and os.path.exists(ann_path) else path

        boxes = res.boxes if hasattr(res, 'boxes') else None
        if boxes is None or len(boxes) == 0:
            table_rows.append([fname, 0, "", "", ""])
            gallery.append((Image.open(shown_path).convert('RGB'), f"{fname}\nNo detections"))
            continue

        det_labels = []
        det_scores = []
        det_boxes = []
        for box in boxes:
            # Class id as a plain int (box.cls is a 1-element tensor).
            cls = int(box.cls.cpu().item()) if hasattr(box, 'cls') else None

            # Confidence: .item() extracts the scalar and avoids numpy
            # deprecation warnings; fall back progressively.
            confscore = None
            if hasattr(box, 'conf'):
                try:
                    confscore = float(box.conf.cpu().item())
                except Exception:
                    try:
                        confscore = float(box.conf.item())
                    except Exception:
                        confscore = None

            # xyxy coords may arrive shaped (1, 4) or (4,); flatten either way.
            coords = []
            if hasattr(box, 'xyxy'):
                try:
                    arr = box.xyxy.cpu().numpy()
                    if getattr(arr, 'ndim', None) == 2 and arr.shape[0] == 1:
                        coords = arr[0].tolist()
                    elif getattr(arr, 'ndim', None) == 1:
                        coords = arr.tolist()
                    else:
                        coords = arr.reshape(-1).tolist()
                except Exception:
                    try:
                        coords = box.xyxy.tolist()
                    except Exception:
                        coords = []

            det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
            det_scores.append(round(confscore, 4) if confscore is not None else "")
            try:
                det_boxes.append([round(float(x), 2) for x in coords])
            except Exception:
                # Fallback: store raw repr rather than dropping the box.
                det_boxes.append([str(coords)])

        # Human-readable "label:confidence" pairs and bracketed box strings.
        label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
        boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
        table_rows.append([
            fname,
            len(det_labels),
            ", ".join(label_conf_pairs),
            ", ".join(boxes_repr),
            "; ".join(str(b) for b in det_boxes),
        ])

        caption = f"{fname}\n{len(det_labels)} detections"
        if ann_path and os.path.exists(ann_path):
            try:
                gallery.append((Image.open(ann_path).convert('RGB'), caption))
            except Exception:
                gallery.append((Image.open(path).convert('RGB'), caption))
        else:
            gallery.append((Image.open(path).convert('RGB'), caption))

    # Write the CSV; the context manager closes the handle even on write
    # errors (the original leaked it on the exception path).
    csv_path = None
    try:
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR,
            mode="w", newline='', encoding='utf-8',
        ) as tmp:
            writer = csv.writer(tmp)
            writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
            writer.writerows(table_rows)
            csv_path = tmp.name
    except Exception as e:
        print("Failed to write CSV:", e)
        csv_path = None

    return gallery, table_rows, csv_path
292
+
293
# ---------- UI ----------
single = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Underwater Image"),
    outputs=gr.Label(num_top_classes=len(ID2LABEL), label="Species Classification"),
    title="🌊 BenthicAI - Single Image",
    description="Classify one image into one of 7 benthic species."
)

batch = gr.Interface(
    fn=classify_images_batch,
    inputs=gr.Files(label="Upload up to 10 images"),
    outputs=[
        gr.Gallery(label="Results (Top-1 in caption)", height=500, rows=3),
        gr.Dataframe(
            headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
            label="Predictions Table",
            wrap=True,
        ),
        gr.File(label="Download CSV"),
    ],
    title="🌊 BenthicAI - Batch (up to 10)",
    description="Upload multiple images (max 10). Outputs a gallery with captions and a table of top predictions.",
)

demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])

# Add the Object Detection tab only when ultralytics is available.
if YOLO is not None:
    detection_iface = gr.Interface(
        fn=detect_objects_batch,
        # gr.Interface passes inputs positionally into
        # detect_objects_batch(files, iou, conf), so the IoU slider MUST come
        # before the confidence slider (they were previously swapped, making
        # each slider control the other threshold).
        inputs=[
            gr.Files(label="Upload images for detection (max 10)"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="IOU threshold"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="conf threshold"),
        ],
        outputs=[
            gr.Gallery(label="Detections (annotated)", height=500, rows=3),
            gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"], label="Detection Table"),
            gr.File(label="Download CSV"),
        ],
        title="🌊 BenthicAI - Object Detection",
        description="Run YOLO object detection on multiple images. Requires 'yolo11_best.pt' in project root."
    )

    # Rebuild the tab set to include detection.
    demo = gr.TabbedInterface([single, batch, detection_iface], ["Single", "Batch", "Detection"])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
339
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ gradio
5
+ Pillow
6
+ ultralytics