File size: 2,250 Bytes
7d462b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import cv2
# Point config/cache directories at the writable /tmp tree so that the
# libraries imported below (gdown, ultralytics) do not hit permission-denied
# errors on Hugging Face Spaces.
os.environ.update(
    {
        "YOLO_CONFIG_DIR": "/tmp/ultralytics",
        "MPLCONFIGDIR": "/tmp/matplotlib",
        "XDG_CACHE_HOME": "/tmp/fontconfig",
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
)

import gdown
from ultralytics import YOLO

# Local file the YOLO weights are downloaded to and later loaded from.
MODEL_PATH = "/tmp/best.pt"
# Google Drive file id used to build the download URL for the weights.
DRIVE_ID = "10IYZGOXIwp3AUKAf05f6sKb4JQJyBEaK"

def download_model():
    """Return the local path of the YOLO weights, downloading them if needed.

    When ``MODEL_PATH`` already exists on disk nothing is downloaded.
    Otherwise the file identified by ``DRIVE_ID`` is fetched from Google
    Drive with ``gdown`` and stored at ``MODEL_PATH``.

    Returns:
        ``MODEL_PATH``, which is guaranteed to exist when this returns.

    Raises:
        RuntimeError: If gdown reports no downloaded file, or if no file is
            present at ``MODEL_PATH`` after the download step.
    """
    # Guard clause: weights are already on disk, nothing to do.
    if os.path.exists(MODEL_PATH):
        return MODEL_PATH

    url = f"https://drive.google.com/uc?id={DRIVE_ID}"
    tmp_dir = "/tmp/gdown"
    os.makedirs(tmp_dir, exist_ok=True)
    # NOTE(review): presumably meant to keep gdown's own cache under /tmp;
    # confirm that the installed gdown version reads this variable.
    os.environ["GDOWN_CACHE_DIR"] = tmp_dir

    print("Downloading YOLO model...")
    downloaded_path = gdown.download(
        url,
        output=MODEL_PATH,
        quiet=False,
        fuzzy=True,
        use_cookies=False,
    )

    # A falsy return value means gdown did not produce a file. Fail here with
    # a clear message instead of handing a missing path to the model loader.
    if not downloaded_path:
        raise RuntimeError(
            f"Failed to download YOLO model from {url} to {MODEL_PATH}"
        )

    # If gdown renamed the file (e.g., to 'best (2).pt'), rename it back
    if downloaded_path != MODEL_PATH:
        os.rename(downloaded_path, MODEL_PATH)

    if not os.path.exists(MODEL_PATH):
        raise RuntimeError(
            f"YOLO model download finished but {MODEL_PATH} does not exist"
        )

    # Only report success once the file is verified to be in place.
    print("Download complete.")
    return MODEL_PATH

# Runs at import time: make sure the weights are on disk (downloading them if
# necessary), then build the module-level YOLO model used by predict_yolo().
model = YOLO(download_model())

def predict_yolo(image_path):
    """Run the module-level YOLO model on one image and summarise the result.

    Args:
        image_path: Image source handed unchanged to ``model.predict``.

    Returns:
        A ``(image, detections)`` tuple.

        ``image`` is the annotated picture from ``Results.plot`` with its
        first and third channels swapped. Ultralytics documents ``plot()`` as
        returning a BGR array, so the returned image is in RGB order.

        ``detections`` is a list with one dict per box, holding ``"class"``
        (int class index), ``"confidence"`` (float rounded to 3 decimals) and
        ``"box"`` (the box's ``xyxy`` values as a list).

    Raises:
        ValueError: If the model yields no result for ``image_path``.
    """
    # stream=True makes predict() return a generator of results.
    results = model.predict(source=image_path, conf=0.26, stream=True)

    # Only one image is processed per call, so take the first result. Use a
    # default so an empty generator surfaces as a clear error rather than a
    # bare StopIteration.
    r = next(results, None)
    if r is None:
        raise ValueError(f"YOLO produced no result for {image_path!r}")

    # Optional mask thresholding if masks exist
    if r.masks is not None:
        r.masks.data = (r.masks.data > 0.3).float()

    # Collect per-box prediction info; guard against results without boxes,
    # mirroring the masks check above.
    detections = []
    if r.boxes is not None:
        for box in r.boxes:
            detections.append({
                "class": int(box.cls[0]),
                "confidence": round(float(box.conf[0]), 3),
                "box": box.xyxy[0].tolist(),
            })

    # Annotated image with labels, boxes and masks drawn by YOLO itself.
    pred_img = r.plot(labels=True, conf=False, boxes=True)

    # plot() is documented to return BGR, so this swap yields RGB. The
    # operation is the same channel swap as the previous COLOR_RGB2BGR call,
    # so callers receive identical pixels; only the naming is corrected.
    pred_img_rgb = cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB)

    return pred_img_rgb, detections