| import os |
| import cv2 |
| |
# Redirect library config/cache directories into /tmp (presumably the only
# writable location where this runs — confirm with the deployment target).
# This sits ahead of the ultralytics import below, presumably because those
# libraries read these variables when they are first imported.
os.environ.update({
    "YOLO_CONFIG_DIR": "/tmp/ultralytics",
    "MPLCONFIGDIR": "/tmp/matplotlib",
    "XDG_CACHE_HOME": "/tmp/fontconfig",
    "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
})
|
|
| import gdown |
| from ultralytics import YOLO |
|
|
| |
# Google Drive file id of the YOLO weights fetched by download_model().
DRIVE_ID: str = "10IYZGOXIwp3AUKAf05f6sKb4JQJyBEaK"
# Local file the weights are downloaded to and loaded from.
MODEL_PATH: str = "/tmp/best.pt"
|
|
def download_model(model_path=None, drive_id=None):
    """Ensure the YOLO weights file exists locally, downloading it if missing.

    If the weights file already exists it is returned untouched; otherwise it
    is fetched from Google Drive with ``gdown``.

    Args:
        model_path: Destination file for the weights. Defaults to the
            module-level ``MODEL_PATH``.
        drive_id: Google Drive file id to download. Defaults to the
            module-level ``DRIVE_ID``.

    Returns:
        The local path of the weights file (``model_path``).

    Raises:
        RuntimeError: If the download did not produce a file.
    """
    # Resolve defaults lazily so the module constants are only touched when used.
    model_path = MODEL_PATH if model_path is None else model_path
    if os.path.exists(model_path):
        return model_path

    drive_id = DRIVE_ID if drive_id is None else drive_id
    url = f"https://drive.google.com/uc?id={drive_id}"

    tmp_dir = "/tmp/gdown"
    os.makedirs(tmp_dir, exist_ok=True)
    # NOTE(review): intended to keep gdown's cache under /tmp; confirm the
    # installed gdown version actually reads this environment variable.
    os.environ["GDOWN_CACHE_DIR"] = tmp_dir
    # Make sure the destination directory exists for custom model paths.
    os.makedirs(os.path.dirname(model_path) or ".", exist_ok=True)

    print("Downloading YOLO model...")
    downloaded_path = gdown.download(
        url,
        output=model_path,
        quiet=False,
        fuzzy=True,
        use_cookies=False,
    )

    # Depending on the gdown version, a failed download may return None
    # instead of raising; fail here rather than hand callers a missing file.
    if not downloaded_path or not os.path.exists(downloaded_path):
        raise RuntimeError(
            f"Failed to download YOLO model from {url} to {model_path}"
        )
    if downloaded_path != model_path:
        # os.replace overwrites an existing destination on all platforms,
        # unlike os.rename on Windows.
        os.replace(downloaded_path, model_path)
    print("Download complete.")

    return model_path
|
|
| |
# Shared YOLO model used by predict_yolo(). Built at import time:
# download_model() fetches the weights first if they are not on disk yet,
# so importing this module may trigger a network download.
model = YOLO(download_model())
|
|
# Results.plot() keyword arguments for each supported visualisation mode.
_PLOT_OPTIONS = {
    "segmentation": {"labels": False, "conf": False, "boxes": False},
    "boxes_labels": {"labels": True, "conf": False, "boxes": True},
    "scores": {"labels": True, "conf": True, "boxes": True},
}


def predict_yolo(image_path, mode="segmentation", conf=0.26, mask_threshold=0.3):
    """Run the shared YOLO model on one image and render the predictions.

    Modes:
        - "segmentation": only masks
        - "boxes_labels": boxes + labels
        - "scores": boxes + labels + confidence scores
        Any other value falls back to the "boxes_labels" rendering (and masks
        are not re-thresholded).

    Args:
        image_path: Source passed to ``model.predict``. Only the first result
            it yields is used.
        mode: Visualisation mode (see above).
        conf: Minimum detection confidence passed to ``model.predict``.
        mask_threshold: Mask values above this become 1, the rest 0, before
            plotting.

    Returns:
        ``(image, detections)``. ``image`` is the annotated image from
        ``Results.plot()`` with its channel order reversed. ultralytics
        documents ``plot()`` output as BGR, so this is RGB. ``detections`` is
        a list of ``{"class": int, "confidence": float (3 d.p.), "box":
        [x1, y1, x2, y2]}`` dicts taken from ``box.xyxy``.

    Raises:
        ValueError: If the model yields no result for ``image_path``.
    """
    results = model.predict(source=image_path, conf=conf, stream=True)
    try:
        r = next(results, None)
    finally:
        # Close the streaming generator now instead of leaving the predictor
        # suspended (and its internal lock held) until garbage collection.
        results.close()
    if r is None:
        raise ValueError(f"YOLO produced no results for source {image_path!r}")

    # Binarise masks so the overlay drawn by plot() has hard edges.
    if r.masks is not None and mode in _PLOT_OPTIONS:
        r.masks.data = (r.masks.data > mask_threshold).float()

    detections = []
    if r.boxes is not None:
        detections = [
            {
                "class": int(box.cls[0]),
                "confidence": round(float(box.conf[0]), 3),
                "box": box.xyxy[0].tolist(),
            }
            for box in r.boxes
        ]

    plot_options = _PLOT_OPTIONS.get(mode, _PLOT_OPTIONS["boxes_labels"])
    pred_img = r.plot(**plot_options)

    # plot() returns BGR; the channel swap yields RGB (same numeric operation
    # as the original COLOR_RGB2BGR, named for what it actually does).
    pred_img_rgb = cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB)

    return pred_img_rgb, detections
|
|