Spaces:
Sleeping
Sleeping
import glob
import os
import shutil
import tempfile

# ------------------------------------------------------------------
# HF Spaces requirement: Ultralytics config must be writable.
# Ultralytics resolves YOLO_CONFIG_DIR once, when the package is first
# imported, so this MUST be set before `from ultralytics import YOLO`
# below; setting it afterwards has no effect.
# ------------------------------------------------------------------
os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"

import gradio as gr  # noqa: E402
import cv2  # noqa: E402
from PIL import Image  # noqa: E402
from ultralytics import YOLO  # noqa: E402  (must follow the env var above)

# ------------------------------------------------------------------
# Load YOLOv8 segmentation model (best.pt in the same folder as this file).
# Resolve against this file's directory instead of the process CWD so the
# app still finds the weights when launched from another directory.
# ------------------------------------------------------------------
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "best.pt")
model = YOLO(MODEL_PATH)
def unwrap_image(item):
    """
    Return the PIL.Image carried by a single Gradio Gallery entry.

    Accepted shapes:
      - a dict: the image is taken from its "data" key;
      - a tuple: the first element exposing a ``.save`` attribute is
        returned (covers both (image, caption) and (name, image));
      - anything else is assumed to already be an image and is returned
        unchanged.

    Raises:
        ValueError: if a tuple contains no element with ``.save``.
    """
    if isinstance(item, dict):
        return item["data"]

    if not isinstance(item, tuple):
        # Not a wrapper type: pass it straight through.
        return item

    # Duck-typing: PIL images (and only image-like things here) have .save().
    found = next((part for part in item if hasattr(part, "save")), None)
    if found is None:
        raise ValueError("Gallery tuple does not contain a PIL image")
    return found
def segment(gallery_items, conf=0.25, imgsz=640):
    """
    Run YOLOv8 segmentation on every image from a Gradio Gallery.

    Uses the same pipeline as the training notebook: the images are written
    to a temporary folder, passed to ``model.predict(save=True)``, and the
    overlay images Ultralytics writes to disk are read back.

    Args:
        gallery_items: Items from a ``gr.Gallery`` input (any shape accepted
            by ``unwrap_image``). ``None`` or an empty list yields ``[]``.
        conf: Confidence threshold forwarded to ``model.predict``.
        imgsz: Inference image size forwarded to ``model.predict``.

    Returns:
        A flat list of RGB PIL images (segmentation overlays), in the same
        order as ``gallery_items``.
    """
    if not gallery_items:
        return []
    items = list(gallery_items)

    # Layout: <workdir>/src holds the inputs; Ultralytics writes overlays to
    # <workdir>/preds because of project=workdir, name="preds" below.
    workdir = tempfile.mkdtemp(prefix="yolo_seg_")
    src_dir = os.path.join(workdir, "src")
    pred_dir = os.path.join(workdir, "preds")
    os.makedirs(src_dir, exist_ok=True)

    try:
        # Zero-pad the index so lexicographic file order (used both by the
        # predictor and when reading overlays back) equals upload order;
        # unpadded names would put img_10 before img_2.
        width = len(str(len(items) - 1))
        for i, item in enumerate(items):
            img = unwrap_image(item)
            # JPEG cannot store alpha/palette modes (RGBA, LA, P, ...) and
            # PIL raises OSError for them, so normalise to RGB first.
            if img.mode != "RGB":
                img = img.convert("RGB")
            img.save(os.path.join(src_dir, f"img_{i:0{width}d}.jpg"))

        # NOTEBOOK-EQUIVALENT prediction; results are consumed from disk.
        model.predict(
            source=src_dir,
            imgsz=int(imgsz),
            conf=float(conf),
            task="segment",
            save=True,
            project=workdir,
            name="preds",
            exist_ok=True,
        )

        # One sorted pass over all image files keeps input order regardless
        # of which extension the overlays were saved with.
        overlay_exts = (".jpg", ".jpeg", ".png", ".bmp")
        overlay_paths = sorted(
            path
            for path in glob.glob(os.path.join(pred_dir, "*"))
            if path.lower().endswith(overlay_exts)
        )

        outputs = []
        for path in overlay_paths:
            bgr = cv2.imread(path)
            if bgr is None:
                continue  # unreadable file: skip instead of failing the batch
            outputs.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))
        return outputs
    finally:
        # Always clean up temp files; overlays are already loaded in memory.
        shutil.rmtree(workdir, ignore_errors=True)
# ------------------------------------------------------------------
# Gradio Interface (HF-safe): multi-image gallery in, overlay gallery out
# ------------------------------------------------------------------
_input_components = [
    gr.Gallery(label="Upload one or multiple images", type="pil"),
    gr.Slider(0.01, 0.9, value=0.25, step=0.01, label="Confidence"),
    gr.Slider(320, 1280, value=640, step=32, label="Image size"),
]
_output_component = gr.Gallery(label="YOLOv8 Segmentation Overlays")

demo = gr.Interface(
    fn=segment,
    inputs=_input_components,
    outputs=_output_component,
    title="YOLOv8 Segmentation (Hugging Face Space)",
    description="Multi-image YOLOv8 segmentation using the same predict(save=True) pipeline as training.",
)

# SSR is disabled to avoid asyncio cleanup noise on HF Spaces.
_launch_options = dict(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
demo.launch(**_launch_options)