# anycoder-fbbd70d5 / models.py
# Author: kamcio1989
# Uploaded with huggingface_hub ("Upload folder using huggingface_hub")
# Commit 84125b2 (verified)
from ultralytics import YOLO
import numpy as np
import cv2
class ObjectDetector:
    """Wrap an Ultralytics YOLO model to annotate RGB frames (e.g. from Gradio).

    Ultralytics interprets raw ``numpy`` input as BGR (OpenCV convention), while
    callers of this class supply RGB. The class converts frames to BGR before
    inference and converts the annotated result back to RGB.
    """

    def __init__(self, model_name="yolov8n.pt"):
        """
        Initialize the YOLO model.

        Using yolov8n (nano) is recommended for CPU-based real-time inference.

        Args:
            model_name (str): Weights file or model name passed to ``YOLO``.

        Raises:
            Exception: Re-raises whatever ``YOLO(model_name)`` raises, after
                printing it.
        """
        print(f"Loading {model_name}...")
        try:
            self.model = YOLO(model_name)
        except Exception as e:
            print(f"Error loading model: {e}")
            raise  # bare raise keeps the original traceback intact

    @staticmethod
    def _to_bgr(image):
        """Return a C-contiguous 3-channel BGR array built from an RGB/RGBA/gray image.

        Args:
            image (numpy.ndarray): H x W (grayscale), H x W x 1, H x W x 3 (RGB)
                or H x W x 4 (RGBA; alpha is dropped).

        Returns:
            numpy.ndarray: H x W x 3 array in BGR channel order, same dtype.

        Raises:
            ValueError: If the array is not 2-D/3-D or has an unsupported
                channel count.
        """
        if image.ndim == 2:
            image = image[..., np.newaxis]
        if image.ndim != 3:
            raise ValueError(f"Expected a 2-D or 3-D image array, got shape {image.shape}")

        channels = image.shape[-1]
        if channels == 1:
            bgr = np.repeat(image, 3, axis=-1)
        elif channels in (3, 4):
            # Take channels 2, 1, 0 (B, G, R); for RGBA this also drops alpha.
            bgr = image[..., 2::-1]
        else:
            raise ValueError(f"Unsupported number of image channels: {channels}")
        # The reversed slice is a negative-stride view; OpenCV drawing code
        # inside Ultralytics needs a contiguous buffer.
        return np.ascontiguousarray(bgr)

    def _resolve_class_ids(self, filter_classes):
        """Map class names to this model's class IDs.

        Args:
            filter_classes (list[str] | None): Class names to keep.

        Returns:
            list[int] | None: ``None`` when no filter was requested (all classes
            are kept). Otherwise the IDs of the recognised names, in the order
            given; unknown names are silently dropped. If no names are
            recognised the list is empty.
            NOTE(review): an empty list is still passed to the model as
            ``classes=[]``, which presumably yields no detections. Confirm
            that this is the desired UX.
        """
        if not filter_classes:
            return None
        # model.names is a dict {0: 'person', 1: 'bicycle', ...}; invert it.
        name_to_id = {name: class_id for class_id, name in self.model.names.items()}
        return [name_to_id[name] for name in filter_classes if name in name_to_id]

    def detect_and_annotate(self, image, conf_threshold=0.4, filter_classes=None):
        """
        Perform inference on an RGB image and return the annotated RGB image.

        Args:
            image (numpy.ndarray | None): Input image in RGB order. Grayscale
                and RGBA arrays are also accepted and converted.
            conf_threshold (float): Confidence threshold for detections.
            filter_classes (list | None): Class names to keep
                (e.g. ['person', 'car']). ``None`` or empty keeps all classes.

        Returns:
            numpy.ndarray | None: Annotated image in RGB order (H x W x 3), or
            ``None`` when ``image`` is ``None``.

        Raises:
            ValueError: If ``image`` has an unsupported shape.
        """
        if image is None:
            return None

        classes_ids = self._resolve_class_ids(filter_classes)

        # Ultralytics treats numpy input as BGR, so give it BGR; feeding the
        # RGB frame directly would run inference on swapped colour channels.
        bgr_image = self._to_bgr(image)

        # verbose=False prevents cluttering the console
        results = self.model(
            bgr_image, conf=conf_threshold, classes=classes_ids, verbose=False
        )

        # plot() annotates a copy of the frame it was given, so the result is
        # BGR here; flip it back to RGB because Gradio expects RGB.
        annotated_bgr = results[0].plot()
        return np.ascontiguousarray(annotated_bgr[..., ::-1])