Spaces:
Runtime error
Runtime error
| import logging | |
| from PIL import Image, ImageDraw | |
| from huggingface_hub import hf_hub_download | |
| from ultralytics import YOLO | |
| import os | |
| import shutil | |
# Configure root logging once, at import time, and obtain this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Development convenience: every import of this module wipes any previously
# downloaded detection weights so the next ObjectDetector() re-fetches them.
# Errors (e.g. the directory not existing yet) are deliberately ignored.
# NOTE(review): this is destructive and runs unconditionally — consider
# removing it (or gating it) for production deployments.
shutil.rmtree("models/detection/weights", ignore_errors=True)
class ObjectDetector:
    """Object detector backed by an Ultralytics YOLO checkpoint from the Hugging Face Hub.

    The weights file is fetched with ``hf_hub_download`` into
    ``models/detection/weights`` and loaded with ``ultralytics.YOLO``.
    ``predict`` returns detections as plain dicts; ``draw`` renders them onto
    a PIL image.
    """

    # Supported (lower-case) model keys -> (HF repo id, weights filename).
    # Built once per class instead of on every instantiation.
    # NOTE(review): Ultralytics' YOLO11 checkpoints appear to be named
    # ``yolo11{n,s,m,l,x}.pt`` with no "b" size, so ``yolov11b.pt`` may not
    # exist in that repo — confirm against the repo's file list.
    _HF_WEIGHTS = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    def __init__(self, model_key="yolov8n", device="cpu"):
        """Download the requested weights from the HF Hub and load a YOLO model.

        Args:
            model_key (str): Case-insensitive model key: 'yolov8n', 'yolov8s',
                'yolov8l' or 'yolov11b'.
            device (str): Device used for inference, e.g. 'cpu' or 'cuda'.
                It is forwarded to the model on every ``predict`` call.

        Raises:
            ValueError: If ``model_key`` is not a supported key.
        """
        resolved_key = model_key.lower()
        if resolved_key not in self._HF_WEIGHTS:
            supported = ", ".join(sorted(self._HF_WEIGHTS))
            raise ValueError(
                f"Unsupported model key: {resolved_key} (supported: {supported})"
            )
        repo_id, filename = self._HF_WEIGHTS[resolved_key]

        # Download from the HF Hub into a project-local cache directory.
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="models/detection/weights",
            force_download=True,  # always re-fetch; set False to reuse cached weights
        )
        self.device = device
        self.model = YOLO(weights_path)
        # Log only after the model has actually been constructed, so the
        # message is not emitted when loading fails.
        logger.info("Loaded YOLO model: %s from %s", resolved_key, weights_path)

    def predict(self, image: Image.Image, conf_threshold=0.25):
        """Run object detection on a single image.

        Args:
            image: Input PIL image.
            conf_threshold (float): Minimum confidence for a detection to be
                kept; passed to Ultralytics as ``conf``.

        Returns:
            list[dict]: One dict per detection with keys ``"class_name"`` (str),
            ``"confidence"`` (float) and ``"bbox"`` (list of the four
            ``xyxy`` corner coordinates: x1, y1, x2, y2).
        """
        logger.info("Running object detection")
        # Forward the threshold and the configured device so both parameters
        # actually take effect (the model would otherwise use its defaults).
        results = self.model(image, conf=conf_threshold, device=self.device)
        detections = [
            {
                "class_name": result.names[int(box.cls)],
                "confidence": float(box.conf),
                "bbox": box.xyxy[0].tolist(),
            }
            for result in results
            for box in result.boxes
        ]
        logger.info("Detected %d objects", len(detections))
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """Return a copy of ``image`` with detection boxes and labels overlaid.

        Boxes and "<class> <confidence>" labels are drawn in red on a copy of
        the image, which is then blended with the (unannotated) image so the
        annotations appear with opacity ``alpha`` (0 = invisible, 1 = opaque).

        Args:
            image: Image to annotate; it is not modified.
            detections: Iterable of dicts as returned by ``predict``.
            alpha (float): Blend factor passed to ``Image.blend``.

        Returns:
            PIL.Image.Image: The annotated image. Inputs that are not already
            RGB/RGBA are returned converted to RGB (or RGBA if they carry an
            alpha band).
        """
        if image.mode not in ("RGB", "RGBA"):
            # Image.blend rejects palette, bilevel and non-8-bit modes, and
            # "red" is only rendered as red in a colour mode; normalise first.
            image = image.convert("RGBA" if "A" in image.getbands() else "RGB")

        overlay = image.copy()
        canvas = ImageDraw.Draw(overlay)
        for det in detections:
            bbox = det["bbox"]
            label = f'{det["class_name"]} {det["confidence"]:.2f}'
            canvas.rectangle(bbox, outline="red", width=2)
            # Label anchored at the box's top-left corner.
            canvas.text((bbox[0], bbox[1]), label, fill="red")
        return Image.blend(image, overlay, alpha)