danielhshi8224
reset
ef39fc3
# app.py — Object Detection only (multi-image YOLO, up to 10 images)
import os
import csv
import tempfile
from pathlib import Path
from typing import List, Tuple
import gradio as gr
from PIL import Image
# Try import ultralytics (ensure it's in requirements.txt)
try:
from ultralytics import YOLO
except Exception:
YOLO = None
# Absolute directory of this script; bundled weights are resolved against it.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]

# Maximum number of uploaded images handled in a single request.
MAX_BATCH = 10

# Option A: weights file shipped next to this script (simplest when allowed).
YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_150_best.pt")

# Option B (optional): download weights from a (possibly private) HF model repo.
# Configure these as Space secrets / environment variables to enable it:
#   HF_TOKEN=<read token>   YOLO_REPO_ID="yourname/yolo-detector"
HF_TOKEN = os.getenv("HF_TOKEN")
YOLO_REPO_ID = os.getenv("YOLO_REPO_ID")
def _download_from_hub_if_needed() -> str | None:
    """Fetch YOLO weights from the Hugging Face Hub when configured.

    Does nothing unless ``YOLO_REPO_ID`` is set. Otherwise the model repo is
    downloaded with ``huggingface_hub.snapshot_download`` (authenticated with
    ``HF_TOKEN``, which may be ``None`` for public repos), and a weights file is
    chosen from it. The filename of the bundled ``YOLO_WEIGHTS`` is tried first,
    then a few common names. If none match, the snapshot's only ``*.pt`` file
    is used.

    Returns:
        Path to the chosen weights file, or ``None`` if the Hub is not
        configured, the download fails, or no weights file can be identified.
    """
    if not YOLO_REPO_ID:
        return None
    try:
        from huggingface_hub import snapshot_download

        local_dir = Path(
            snapshot_download(repo_id=YOLO_REPO_ID, repo_type="model", token=HF_TOKEN)
        )
    except Exception as e:  # missing package, network or auth errors: degrade gracefully
        print("[YOLO] Hub download failed:", e)
        return None

    # Prefer the same filename the app expects locally, then common defaults.
    candidates = (Path(YOLO_WEIGHTS).name, "yolo11_best.pt", "best.pt", "yolo.pt", "weights.pt")
    for name in candidates:
        cand = local_dir / name
        if cand.is_file():
            return str(cand)

    # Unambiguous fallback: the repo contains exactly one .pt file.
    pt_files = sorted(local_dir.rglob("*.pt"))
    if len(pt_files) == 1:
        return str(pt_files[0])

    print(
        f"[YOLO] No weights file found in Hub repo {YOLO_REPO_ID!r}; "
        f"tried {list(candidates)} and found {len(pt_files)} .pt file(s)."
    )
    return None
# Lazily-initialised model instance shared by all requests in this process.
_yolo_model = None


def _load_yolo():
    """Return the cached YOLO model, loading it on the first call.

    The bundled ``YOLO_WEIGHTS`` file is used when it exists. Otherwise
    ``_download_from_hub_if_needed`` is asked for a path.

    Raises:
        RuntimeError: if the ``ultralytics`` package could not be imported.
        FileNotFoundError: if no weights file could be located.
    """
    global _yolo_model
    if _yolo_model is not None:
        return _yolo_model
    if YOLO is None:
        raise RuntimeError("ultralytics package not installed. Add 'ultralytics' to requirements.txt")

    if os.path.exists(YOLO_WEIGHTS):
        model_path = YOLO_WEIGHTS
    else:
        model_path = _download_from_hub_if_needed()
    if not model_path:
        # Name the file actually looked for, so the hint matches YOLO_WEIGHTS.
        raise FileNotFoundError(
            f"YOLO weights not found. Either include '{os.path.basename(YOLO_WEIGHTS)}' "
            "in the repo root, or set YOLO_REPO_ID (+ HF_TOKEN if private) to pull from the Hub."
        )
    _yolo_model = YOLO(model_path)
    return _yolo_model
def _resolve_upload_path(fileobj, created_temp_files):
    """Return a filesystem path for one uploaded item that ``YOLO.predict`` can read.

    Handles, in order:
      - strings that are existing paths (returned unchanged);
      - objects whose ``.path`` or ``.name`` attribute is an existing path;
      - file-like objects (anything with ``.read()``) and ``bytes``/``bytearray``.
        Their data is written to a new file in the system temp dir. The file
        suffix matches the format PIL detects, or ``.jpeg`` if PIL cannot tell.

    Temp files created here are appended to ``created_temp_files`` so the
    caller can delete them. If no path can be produced, ``fileobj`` is returned
    unchanged. The caller's predict call then fails and logs that item.
    """
    from io import BytesIO

    if isinstance(fileobj, str) and os.path.exists(fileobj):
        return fileobj
    for attr in ("path", "name"):
        candidate = getattr(fileobj, attr, None)
        if isinstance(candidate, (str, os.PathLike)) and os.path.exists(candidate):
            return os.fspath(candidate)

    data = None
    try:
        if hasattr(fileobj, "read"):
            data = fileobj.read()
        elif isinstance(fileobj, (bytes, bytearray)):
            data = bytes(fileobj)
    except Exception as e:  # best-effort: an unreadable upload is reported, not fatal
        print(f"[DETECT] could not read upload of type {type(fileobj).__name__}: {e}")
    if not isinstance(data, (bytes, bytearray)) or not data:
        return fileobj

    # Give the temp file a real image extension; ultralytics picks loaders by suffix.
    try:
        with Image.open(BytesIO(data)) as probe:
            fmt = (probe.format or "JPEG").lower()
    except Exception:
        fmt = "jpeg"

    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix="." + fmt, prefix="gr_blob_") as tmp:
            # Register before writing so a failed write is still cleaned up.
            created_temp_files.append(tmp.name)
            tmp.write(data)
    except OSError as e:
        print("[DETECT] failed to write temp file from upload:", e)
        return fileobj
    print(f"[DETECT] wrote temp file: {tmp.name} (fmt={fmt})")
    return tmp.name


def _log_upload_diagnostics(path):
    """Log whether ``path`` exists and whether PIL can open it (for failed items)."""
    try:
        exists = os.path.exists(path)
        size = os.path.getsize(path) if exists else None
    except (OSError, TypeError, ValueError):
        exists, size = False, None
    print(f"[DETECT] resolved path={path!r}, exists={exists}, size={size}")
    try:
        with Image.open(path) as img:
            print(f"[DETECT] PIL can open file: format={img.format}, size={img.size}")
    except Exception as pil_e:
        print(f"[DETECT] PIL failed to open file: {pil_e}")


def _extract_detections(res, names):
    """Return ``(labels, scores, boxes)`` for one ultralytics ``Results`` object.

    ``labels`` are class names from ``names`` (falling back to the numeric id
    as a string). ``scores`` are confidences rounded to 4 decimals. ``boxes``
    are ``xyxy`` coordinate lists rounded to 2 decimals. All three are empty
    when the result has no boxes.
    """
    boxes = getattr(res, "boxes", None)
    if boxes is None or len(boxes) == 0:
        return [], [], []
    # .tolist() works for torch tensors on any device and for numpy arrays.
    class_ids = [int(c) for c in boxes.cls.tolist()]
    labels = [names.get(c, str(c)) for c in class_ids]
    scores = [round(float(s), 4) for s in boxes.conf.tolist()]
    coords = [[round(float(v), 2) for v in xyxy] for xyxy in boxes.xyxy.tolist()]
    return labels, scores, coords


def _annotated_image(res, path):
    """Return an RGB PIL image of ``res`` with detections drawn.

    Falls back to the original image at ``path`` if plotting fails.
    """
    try:
        # Results.plot() returns a BGR array (OpenCV channel order). Reverse the
        # channel axis so PIL, which assumes RGB, shows the correct colours.
        return Image.fromarray(res.plot()[..., ::-1])
    except Exception as e:
        print(f"[DETECT] could not plot detections for {path!r}: {e}")
    with Image.open(path) as img:
        return img.convert("RGB")


def _write_predictions_csv(rows):
    """Write ``rows`` plus a header row to a new CSV in the system temp dir.

    Returns the file path, or ``None`` if writing failed. The file is
    deliberately left in place so the caller can offer it as a download.
    """
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", newline="", encoding="utf-8",
            suffix=".csv", prefix="yolo_preds_", delete=False,
        ) as tmp:
            writer = csv.writer(tmp)
            writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
            writer.writerows(rows)
    except (OSError, csv.Error) as e:
        print("Failed to write CSV:", e)
        return None
    return tmp.name


def detect_objects_batch(files, conf=0.25, iou=0.25):
    """Run YOLO detection on up to ``MAX_BATCH`` uploaded images.

    Args:
        files: Uploaded items as passed by ``gr.Files``. These may be paths,
            upload wrappers, file-like objects, or bytes; a single item is also
            accepted. Items beyond ``MAX_BATCH`` are ignored, with a log line.
        conf: Confidence threshold forwarded to ``YOLO.predict``.
        iou: IoU threshold forwarded to ``YOLO.predict``.

    Returns:
        ``(gallery, rows, csv_path)``:
          - ``gallery``: a list of ``(PIL.Image, caption)`` pairs with
            detections drawn.
          - ``rows``: one table row per successfully processed image, with
            columns ``filename, num_detections, labels_with_conf, boxes,
            raw_boxes``.
          - ``csv_path``: the path of a CSV holding the same rows, or ``None``
            if it could not be written.
        ``([], [], None)`` is returned when ultralytics is missing, nothing was
        uploaded, or the model fails to load. An image that fails is logged
        and skipped; it does not abort the batch.
    """
    import traceback

    if YOLO is None or not files:
        return [], [], None
    # Accept a single upload as well as a list of them.
    if isinstance(files, (str, bytes, bytearray, os.PathLike)) or hasattr(files, "read"):
        files = [files]
    files = list(files)

    # Diagnostic: list incoming file objects/paths (useful when Gradio passes blob paths).
    try:
        incoming = [getattr(f, "name", None) or getattr(f, "path", None) or str(f) for f in files]
        print("[DETECT] incoming files:", incoming)
    except Exception:
        print("[DETECT] incoming files: (unreadable)")

    try:
        ymodel = _load_yolo()
    except Exception as e:
        print("YOLO load error:", e)
        return [], [], None

    if len(files) > MAX_BATCH:
        print(f"[DETECT] received {len(files)} files; only the first {MAX_BATCH} are processed")

    gallery, table_rows = [], []
    created_temp_files = []
    try:
        for f in files[:MAX_BATCH]:
            path = _resolve_upload_path(f, created_temp_files)
            name = os.path.basename(str(path))
            try:
                results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
                res = results[0]
                labels, scores, boxes = _extract_detections(res, ymodel.names)
                # Materialise the image now: temp uploads are deleted below.
                image = _annotated_image(res, path)
            except Exception as e:
                print(f"[DETECT] Detection failed for {path}: {e}")
                traceback.print_exc()
                _log_upload_diagnostics(path)
                continue

            table_rows.append([
                name,
                len(labels),
                ", ".join(f"{label}:{score}" for label, score in zip(labels, scores)),
                ", ".join("[" + ", ".join(map(str, box)) + "]" for box in boxes),
                "; ".join(str(box) for box in boxes),
            ])
            caption = f"{len(labels)} detections" if labels else "No detections"
            gallery.append((image, f"{name}\n{caption}"))
    finally:
        # Always remove temp copies of uploads, even if the loop raised.
        for tmp_path in created_temp_files:
            try:
                os.remove(tmp_path)
                print(f"[DETECT] removed temp file: {tmp_path}")
            except FileNotFoundError:
                pass
            except OSError as e:
                print(f"[DETECT] could not remove temp file {tmp_path}: {e}")

    return gallery, table_rows, _write_predictions_csv(table_rows)
# ---------- UI ----------
if YOLO is None:
    # Placeholder app so the Space still starts and explains what is missing.
    # A single text output must receive a plain string; a 1-tuple would be
    # rendered as its repr.
    demo = gr.Interface(
        fn=lambda *a, **k: "Ultralytics not installed; add 'ultralytics' to requirements.txt",
        inputs=[],
        outputs="text",
        title="🌊 BenthicAI — Object Detection",
        description="Ultralytics is not installed.",
    )
else:
    demo = gr.Interface(
        fn=detect_objects_batch,
        inputs=[
            gr.Files(label="Upload images (max 10)"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Conf threshold"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="IoU threshold"),
        ],
        outputs=[
            gr.Gallery(label="Detections (annotated)", height=500, rows=3),
            gr.Dataframe(
                headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"],
                label="Detection Table",
            ),
            gr.File(label="Download CSV"),
        ],
        title="🌊 BenthicAI — Object Detection",
        description=(
            "Run YOLO object detection on multiple images. "
            "Upload up to 10 images at a time. The model detects various benthic species. "
            "Adjust the confidence and IoU thresholds as needed."
        ),
    )

if __name__ == "__main__":
    # Bind on all interfaces at the port Hugging Face Spaces expects.
    demo.launch(server_name="0.0.0.0", server_port=7860)