# New_Teeth_Seg_API / yolo_numbering.py
# Last commit (a06f359, verified) by Noursine: "Rename yolo_numbering to yolo_numbering.py"
import os
import cv2
# Point config/cache directories at writable locations under /tmp to avoid
# permission-denied errors on Hugging Face Spaces. This is deliberately placed
# before the ultralytics import below, presumably because these variables are
# read when that package (and its dependencies) are first imported.
os.environ.update(
    {
        "YOLO_CONFIG_DIR": "/tmp/ultralytics",
        "MPLCONFIGDIR": "/tmp/matplotlib",
        # Per the XDG Base Directory spec this is the generic cache root, not a
        # fontconfig-only setting: other XDG-aware libraries may also cache
        # under /tmp/fontconfig.
        "XDG_CACHE_HOME": "/tmp/fontconfig",
        # Hugging Face Hub flag; presumably suppresses its symlink warning.
        "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    }
)
import gdown
from ultralytics import YOLO
# Local destination for the YOLO weights (under /tmp, which is writable on Spaces).
MODEL_PATH = "/tmp/best.pt"
# Google Drive file ID of the weights file.
DRIVE_ID = "10IYZGOXIwp3AUKAf05f6sKb4JQJyBEaK"


def download_model(model_path=None, drive_id=None):
    """Ensure the YOLO weights exist locally, downloading them from Google Drive if absent.

    Args:
        model_path: Destination file for the weights. Defaults to ``MODEL_PATH``.
        drive_id: Google Drive file ID to fetch. Defaults to ``DRIVE_ID``.

    Returns:
        ``model_path`` — the path of the (now present) weights file.

    Raises:
        RuntimeError: If the download did not produce a file at ``model_path``.
    """
    # Resolved at call time (not as default values) so the module constants stay
    # the single source of truth and callers without arguments behave as before.
    if model_path is None:
        model_path = MODEL_PATH
    if drive_id is None:
        drive_id = DRIVE_ID

    # An existing file is trusted as-is; nothing is downloaded.
    if os.path.exists(model_path):
        return model_path

    url = f"https://drive.google.com/uc?id={drive_id}"
    tmp_dir = "/tmp/gdown"
    os.makedirs(tmp_dir, exist_ok=True)
    # NOTE(review): presumably meant to relocate gdown's cache to a writable
    # directory; confirm the installed gdown version actually reads this variable.
    os.environ["GDOWN_CACHE_DIR"] = tmp_dir

    print("Downloading YOLO model...")
    downloaded_path = gdown.download(
        url,
        output=model_path,
        quiet=False,
        fuzzy=True,
        use_cookies=False,
    )
    # gdown may report failure by returning None instead of raising; without this
    # check the missing path would be handed straight to YOLO().
    if not downloaded_path:
        raise RuntimeError(f"Failed to download YOLO model from {url}")

    # Guard against gdown saving under a different name (e.g. 'best (2).pt').
    # os.replace also overwrites an existing destination on every platform.
    if downloaded_path != model_path:
        os.replace(downloaded_path, model_path)

    if not os.path.exists(model_path):
        raise RuntimeError(f"YOLO model not found at {model_path} after download")

    print("Download complete.")
    return model_path
# Load the weights once at import time (fetching them first if they are not on
# disk yet); predict_yolo() below reuses this single module-level model.
_weights_file = download_model()
model = YOLO(_weights_file)
def predict_yolo(image_path, *, conf_threshold=0.26, mask_threshold=0.3):
    """Run the module-level YOLO model on one image and return a rendered image plus detections.

    Args:
        image_path: Source passed straight to ``model.predict``. Only the first
            result is used, so this is expected to be a single image.
        conf_threshold: Confidence threshold forwarded to ``model.predict``.
        mask_threshold: If the result has masks, they are binarized with
            ``> mask_threshold`` before plotting.

    Returns:
        ``(image, detections)``.
        ``image`` is the array from ``r.plot(labels=True, conf=False, boxes=True)``
        after ``cv2.cvtColor(..., cv2.COLOR_RGB2BGR)``.
        ``detections`` is a list of dicts with keys ``"class"`` (int class index),
        ``"confidence"`` (float rounded to 3 decimals) and ``"box"`` (the box's
        ``xyxy`` coordinates as a list).

    Raises:
        ValueError: If the model yields no result for ``image_path``.
    """
    # stream=True mirrors the original Colab usage: predict() returns a generator.
    # Close it explicitly rather than leaving it suspended until garbage collection.
    results = model.predict(source=image_path, conf=conf_threshold, stream=True)
    try:
        r = next(results, None)
    finally:
        results.close()
    if r is None:
        # Previously this surfaced as a bare StopIteration, which is easy to misread.
        raise ValueError(f"YOLO produced no result for {image_path!r}")

    if r.masks is not None:
        # Binarize the mask tensor so plot() draws hard-edged masks.
        r.masks.data = (r.masks.data > mask_threshold).float()

    boxes = r.boxes if r.boxes is not None else []
    detections = [
        {
            "class": int(box.cls[0]),
            "confidence": round(float(box.conf[0]), 3),
            "box": box.xyxy[0].tolist(),
        }
        for box in boxes
    ]

    # Rendered image with labels, boxes and masks drawn by ultralytics itself.
    pred_img = r.plot(labels=True, conf=False, boxes=True)
    # NOTE(review): ultralytics documents Results.plot() as returning a BGR array.
    # If so, this COLOR_RGB2BGR swap hands callers an RGB image, despite the
    # variable name. The channel swap is kept as-is because callers may depend on
    # the current output; confirm what they expect before removing it.
    pred_img_bgr = cv2.cvtColor(pred_img, cv2.COLOR_RGB2BGR)
    return pred_img_bgr, detections