Spaces:
Sleeping
Sleeping
# app.py — Object Detection only (multi-image YOLO, up to 10 images)
import csv
import os
import tempfile
import traceback
from io import BytesIO
from pathlib import Path
from typing import List, Tuple

import gradio as gr
from PIL import Image
# Optional dependency: ultralytics (ensure it's in requirements.txt).
# When the import fails, YOLO is set to None and the UI below falls back to an
# explanatory stub instead of crashing at startup.
try:
    from ultralytics import YOLO
except Exception as exc:  # broad on purpose: any import-time failure should degrade to the stub UI
    # Log the real reason; otherwise the UI only says "not installed" even
    # when the package is present but failed to import.
    print(f"[YOLO] could not import ultralytics: {exc!r}")
    YOLO = None
# Directory that contains this script; the bundled weights path is resolved from here.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Upper bound on how many uploaded images are processed per request.
MAX_BATCH = 10

# Weights source A: a file shipped next to this script in the Space.
YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_150_best.pt")

# Weights source B (optional): a Hugging Face model repo, possibly private.
# Enable it by defining these variables/secrets on the Space:
#   HF_TOKEN=<read token>   YOLO_REPO_ID="yourname/yolo-detector"
HF_TOKEN = os.getenv("HF_TOKEN")
YOLO_REPO_ID = os.getenv("YOLO_REPO_ID")
def _download_from_hub_if_needed() -> str | None:
    """Fetch YOLO weights from the Hugging Face Hub when a repo is configured.

    Does nothing unless ``YOLO_REPO_ID`` is set; ``HF_TOKEN`` is forwarded so
    private repos work. Only ``*.pt`` files are downloaded, since nothing else
    in the snapshot is used.

    Lookup order inside the snapshot: the common filenames
    ``yolo11_best.pt``, ``best.pt``, ``yolo.pt``, ``weights.pt``; then the same
    filename as the bundled ``YOLO_WEIGHTS``; finally the first ``*.pt`` file
    anywhere in the snapshot (sorted, so the choice is deterministic).

    Returns:
        Local path to the weights file, or ``None`` if no repo is configured,
        the download fails, or the snapshot contains no ``.pt`` file.
    """
    if not YOLO_REPO_ID:
        return None
    try:
        from huggingface_hub import snapshot_download
    except ImportError as exc:
        print("[YOLO] huggingface_hub not available:", exc)
        return None
    try:
        local_dir = Path(
            snapshot_download(
                repo_id=YOLO_REPO_ID,
                repo_type="model",
                token=HF_TOKEN,
                allow_patterns=["*.pt"],
            )
        )
    except Exception as exc:  # network/auth/repo errors: degrade to "not found"
        print("[YOLO] Hub download failed:", exc)
        return None

    candidates = ["yolo11_best.pt", "best.pt", "yolo.pt", "weights.pt"]
    bundled_name = os.path.basename(YOLO_WEIGHTS)
    if bundled_name not in candidates:
        candidates.append(bundled_name)
    for name in candidates:
        cand = local_dir / name
        if cand.is_file():
            return str(cand)

    found = sorted(local_dir.rglob("*.pt"))
    if found:
        print(f"[YOLO] using {found[0]} from {YOLO_REPO_ID}")
        return str(found[0])
    print(f"[YOLO] no .pt weights found in {YOLO_REPO_ID}")
    return None
# Module-level cache for the loaded model; populated on the first _load_yolo() call.
_yolo_model = None


def _load_yolo():
    """Return the cached YOLO model, loading it on first use.

    Weights come from ``YOLO_WEIGHTS`` when that file exists, otherwise from
    the Hub via ``_download_from_hub_if_needed()``.

    Raises:
        RuntimeError: ultralytics could not be imported (``YOLO`` is None).
        FileNotFoundError: no weights were found locally or on the Hub.
    """
    global _yolo_model
    if _yolo_model is not None:
        return _yolo_model
    if YOLO is None:
        raise RuntimeError("ultralytics package not installed. Add 'ultralytics' to requirements.txt")

    if os.path.exists(YOLO_WEIGHTS):
        model_path = YOLO_WEIGHTS
    else:
        model_path = _download_from_hub_if_needed()
    if not model_path:
        # Name the file actually looked for (derived from YOLO_WEIGHTS) so the
        # hint cannot drift from the code.
        raise FileNotFoundError(
            f"YOLO weights not found. Either include '{os.path.basename(YOLO_WEIGHTS)}' "
            "in the repo root, or set YOLO_REPO_ID (+ HF_TOKEN if private) to pull from the Hub."
        )
    print(f"[YOLO] loading weights from {model_path}")
    _yolo_model = YOLO(model_path)
    return _yolo_model
# Column order shared by the results-table rows and the exported CSV.
_CSV_HEADERS = ["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"]


def _resolve_upload_path(fileobj, created_temp_files):
    """Turn one uploaded item into a source string for ``YOLO.predict``.

    Resolution order:
      1. ``str`` / ``os.PathLike`` values are returned as path strings unchanged
         (a non-existent path is passed through so ``predict`` reports the error).
      2. Objects whose ``.path`` or ``.name`` attribute is an existing file
         (e.g. tempfile wrappers, Gradio file objects) resolve to that path.
      3. ``bytes``/``bytearray`` and file-like objects (with ``.read()``) are
         written to a temp file whose extension is guessed with PIL, so that
         extension-less "blob" uploads get a recognisable image suffix.

    Temp files created here are appended to ``created_temp_files`` so the
    caller can delete them.

    Returns:
        A path string, or ``None`` if the item cannot be turned into a source.
    """
    if isinstance(fileobj, (str, os.PathLike)):
        return os.fspath(fileobj)

    for attr in ("path", "name"):
        try:
            candidate = getattr(fileobj, attr, None)
        except Exception:  # exotic objects with failing properties
            candidate = None
        if isinstance(candidate, (str, os.PathLike)) and os.path.exists(candidate):
            return os.fspath(candidate)

    if isinstance(fileobj, (bytes, bytearray)):
        data = bytes(fileobj)
    elif callable(getattr(fileobj, "read", None)):
        try:
            data = fileobj.read()
        except Exception as exc:
            print(f"[DETECT] failed to read upload: {exc}")
            return None
    else:
        return None
    if not isinstance(data, (bytes, bytearray)) or not data:
        print("[DETECT] upload contained no binary image data")
        return None

    # Guess the suffix from the content; default to JPEG when PIL cannot tell.
    try:
        with Image.open(BytesIO(data)) as probe:
            fmt = (probe.format or "JPEG").lower()
    except Exception:
        fmt = "jpeg"

    try:
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix="." + fmt, prefix="gr_blob_")
    except OSError as exc:
        print("[DETECT] failed to write temp file from upload:", exc)
        return None
    created_temp_files.append(tmp.name)  # record before writing so cleanup never misses it
    try:
        with tmp:
            tmp.write(data)
    except OSError as exc:
        print("[DETECT] failed to write temp file from upload:", exc)
        return None
    print(f"[DETECT] wrote temp file: {tmp.name} (fmt={fmt})")
    return tmp.name


def _annotated_rgb_image(res, path):
    """Return an RGB ``PIL.Image`` for the gallery, or ``None`` if none can be made.

    ``Results.plot()`` returns a BGR numpy array, so the channel axis is
    reversed before conversion; without this, reds and blues are swapped.
    Falls back to the un-annotated source image if plotting fails. The image
    is fully decoded in memory, so ``path`` may be deleted afterwards.
    """
    try:
        bgr = res.plot()
        return Image.fromarray(bgr[..., ::-1]).convert("RGB")
    except Exception as exc:
        print(f"[DETECT] could not plot detections for {path}: {exc}")
    try:
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception as exc:
        print(f"[DETECT] could not open {path} for the gallery: {exc}")
        return None


def _class_name(names, cls_id):
    """Map a class index to its label via the model's ``names`` (dict or list)."""
    if cls_id is None:
        return ""
    try:
        return str(names[cls_id])
    except (KeyError, IndexError, TypeError):
        return str(cls_id)


def _extract_detections(boxes, names):
    """Return parallel lists ``(labels, scores, coords)`` for one result's boxes.

    ``boxes.cls`` / ``.conf`` / ``.xyxy`` are tensors or arrays; ``.tolist()``
    works for both (including GPU tensors). Scores are rounded to 4 decimals
    and box coordinates to 2.
    """
    labels = [_class_name(names, int(c)) for c in boxes.cls.tolist()]
    scores = [round(float(s), 4) for s in boxes.conf.tolist()]
    coords = [[round(float(v), 2) for v in row] for row in boxes.xyxy.tolist()]
    return labels, scores, coords


def _table_row(name, labels, scores, coords):
    """Build one table/CSV row in ``_CSV_HEADERS`` order.

    The ``boxes`` and ``raw_boxes`` columns hold the same boxes, joined by
    ", " and "; " respectively.
    """
    box_strs = [str(box) for box in coords]
    return [
        name,
        len(labels),
        ", ".join(f"{label}:{score}" for label, score in zip(labels, scores)),
        ", ".join(box_strs),
        "; ".join(box_strs),
    ]


def _write_predictions_csv(rows):
    """Write ``rows`` under a header to a new CSV in ``BASE_DIR``.

    The file is intentionally not deleted: its path is returned for download.
    Returns the path, or ``None`` if writing failed.
    """
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", newline="", encoding="utf-8",
            suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR, delete=False,
        ) as fh:
            writer = csv.writer(fh)
            writer.writerow(_CSV_HEADERS)
            writer.writerows(rows)
    except (OSError, csv.Error) as exc:
        print("Failed to write CSV:", exc)
        return None
    return fh.name


def detect_objects_batch(files, conf=0.25, iou=0.25, imgsz=640):
    """Run YOLO detection on up to ``MAX_BATCH`` uploaded images.

    Args:
        files: Uploaded items (paths, file wrappers, file-like objects or
            bytes); a single item is also accepted. Items beyond ``MAX_BATCH``
            are ignored.
        conf: Confidence threshold passed to ``predict``.
        iou: IoU threshold passed to ``predict``.
        imgsz: Inference image size passed to ``predict``.

    Returns:
        ``(gallery, table_rows, csv_path)``. ``gallery`` is a list of
        ``(PIL.Image, caption)`` pairs. ``table_rows`` are rows in
        ``_CSV_HEADERS`` order. ``csv_path`` is the written CSV, or ``None``
        if writing failed. Returns ``([], [], None)`` when there are no files,
        ultralytics is missing, or the model cannot be loaded. Images that
        fail to resolve or predict are logged and skipped.
    """
    if not files:
        return [], [], None
    if YOLO is None:
        return [], [], None
    if not isinstance(files, (list, tuple)):
        files = [files]

    batch = list(files[:MAX_BATCH])
    if len(files) > MAX_BATCH:
        print(f"[DETECT] received {len(files)} files; processing only the first {MAX_BATCH}")
    try:
        incoming = [getattr(f, "name", None) or getattr(f, "path", None) or str(f) for f in batch]
    except Exception:
        incoming = "(unreadable)"
    print("[DETECT] incoming files:", incoming)

    try:
        ymodel = _load_yolo()
    except Exception as e:
        print("YOLO load error:", e)
        return [], [], None

    gallery, table_rows = [], []
    created_temp_files = []
    try:
        for item in batch:
            path = _resolve_upload_path(item, created_temp_files)
            if path is None:
                print(f"[DETECT] skipping unusable upload of type {type(item).__name__}")
                continue
            print(f"[DETECT] resolved path={path!r}")
            try:
                results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=imgsz, verbose=False)
                res = results[0]
            except Exception as e:
                print(f"[DETECT] Detection failed for {path}: {e}")
                traceback.print_exc()
                continue

            boxes = getattr(res, "boxes", None)
            try:
                if boxes is None or len(boxes) == 0:
                    labels, scores, coords = [], [], []
                else:
                    labels, scores, coords = _extract_detections(boxes, ymodel.names)
            except Exception as e:
                print(f"[DETECT] could not read detections for {path}: {e}")
                continue

            name = os.path.basename(path)
            table_rows.append(_table_row(name, labels, scores, coords))
            caption = f"{name}\n{len(labels)} detections" if labels else f"{name}\nNo detections"
            image = _annotated_rgb_image(res, path)
            if image is not None:
                gallery.append((image, caption))
    finally:
        # Temp copies of blob uploads are no longer needed: gallery images are
        # already decoded into memory. Runs even if the loop raised.
        for tmp_path in created_temp_files:
            try:
                os.remove(tmp_path)
                print(f"[DETECT] removed temp file: {tmp_path}")
            except FileNotFoundError:
                pass
            except OSError as exc:
                print(f"[DETECT] could not remove temp file {tmp_path}: {exc}")

    return gallery, table_rows, _write_predictions_csv(table_rows)
# ---------- UI ----------
# NOTE(review): the title was mis-encoded in the source ("β" stood in for an
# em dash, and a leading "π" for what looks like a lost emoji). The em dash is
# restored; re-add the emoji if desired.
_TITLE = "BenthicAI — Object Detection"

if YOLO is None:
    def _ultralytics_missing():
        """Return the setup hint shown when ultralytics failed to import."""
        # Return a plain string: with a single "text" output, a tuple would be
        # displayed as its repr.
        return "Ultralytics not installed; add 'ultralytics' to requirements.txt"

    demo = gr.Interface(
        fn=_ultralytics_missing,
        inputs=[],
        outputs="text",
        title=_TITLE,
        description="Ultralytics is not installed.",
    )
else:
    demo = gr.Interface(
        fn=detect_objects_batch,
        inputs=[
            gr.Files(label=f"Upload images (max {MAX_BATCH})"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Conf threshold"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="IoU threshold"),
        ],
        outputs=[
            gr.Gallery(label="Detections (annotated)", height=500, rows=3),
            gr.Dataframe(
                headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"],
                label="Detection Table",
            ),
            gr.File(label="Download CSV"),
        ],
        title=_TITLE,
        description=(
            "Run YOLO object detection on multiple images. "
            f"Upload up to {MAX_BATCH} images at a time. The model detects various benthic species. "
            "Adjust the confidence and IoU thresholds as needed."
        ),
    )
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)