Noursine commited on
Commit
7d462b8
·
verified ·
1 Parent(s): baa5379

Create yolo_infer

Browse files
Files changed (1) hide show
  1. yolo_infer +71 -0
yolo_infer ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ # Redirect config/cache dirs to writable /tmp to avoid permission denied errors on Hugging Face Spaces
4
+ os.environ["YOLO_CONFIG_DIR"] = "/tmp/ultralytics"
5
+ os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
6
+ os.environ["XDG_CACHE_HOME"] = "/tmp/fontconfig"
7
+ os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
8
+
9
+ import gdown
10
+ from ultralytics import YOLO
11
+
12
+ # Clean and safe model path
13
+ MODEL_PATH = "/tmp/best.pt"
14
+ DRIVE_ID = "10IYZGOXIwp3AUKAf05f6sKb4JQJyBEaK"
15
+
16
+ def download_model():
17
+ if not os.path.exists(MODEL_PATH):
18
+ url = f"https://drive.google.com/uc?id={DRIVE_ID}"
19
+ tmp_dir = "/tmp/gdown"
20
+ os.makedirs(tmp_dir, exist_ok=True)
21
+ os.environ["GDOWN_CACHE_DIR"] = tmp_dir
22
+
23
+ print("Downloading YOLO model...")
24
+ downloaded_path = gdown.download(
25
+ url,
26
+ output=MODEL_PATH,
27
+ quiet=False,
28
+ fuzzy=True,
29
+ use_cookies=False
30
+ )
31
+ print("Download complete.")
32
+
33
+ # If gdown renamed the file (e.g., to 'best (2).pt'), rename it back
34
+ if downloaded_path and downloaded_path != MODEL_PATH:
35
+ os.rename(downloaded_path, MODEL_PATH)
36
+
37
+ return MODEL_PATH
38
+
39
+ # Download model and load it
40
+ model = YOLO(download_model())
41
+
42
+ def predict_yolo(image_path):
43
+ # Use stream=True to mimic your colab behavior
44
+ results = model.predict(source=image_path, conf=0.26, stream=True)
45
+
46
+ # Get first result (only one image uploaded per call)
47
+ r = next(results)
48
+
49
+ # Optional mask thresholding if masks exist
50
+ if r.masks is not None:
51
+ r.masks.data = (r.masks.data > 0.3).float()
52
+
53
+ # Get predictions info
54
+ detections = []
55
+ for box in r.boxes:
56
+ cls = int(box.cls[0])
57
+ conf = float(box.conf[0])
58
+ xyxy = box.xyxy[0].tolist()
59
+ detections.append({
60
+ "class": cls,
61
+ "confidence": round(conf, 3),
62
+ "box": xyxy
63
+ })
64
+
65
+ # Get plotted image with labels, boxes, masks drawn by YOLO's internal method
66
+ pred_img = r.plot(labels=True, conf=False, boxes=True)
67
+
68
+ # Convert from RGB numpy array to BGR for OpenCV if needed later
69
+ pred_img_bgr = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)
70
+
71
+ return pred_img_bgr, detections