Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import random | |
| from PIL import Image, ImageDraw, ImageFont | |
| from huggingface_hub import hf_hub_download | |
# Module-level logger, named after this module so its records can be filtered per file.
logger: logging.Logger = logging.getLogger(name=__name__)
class ObjectDetector:
    """Wrap an Ultralytics YOLO model whose weights come from the Hugging Face Hub.

    The weights file is resolved via ``hf_hub_download`` when the detector is
    constructed. The YOLO model itself is loaded lazily on first use, either
    through ``load_model`` or ``predict``.
    """

    # Supported model keys -> (Hugging Face repo id, weights filename).
    _HF_MAP = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        # NOTE(review): upstream YOLO11 checkpoints appear to be named like
        # "yolo11n.pt" (sizes n/s/m/l/x). Verify that "yolov11b.pt" really
        # exists in this repo; a missing file makes hf_hub_download raise.
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    def __init__(self, model_key="yolov8n", device="cpu"):
        """Resolve the weights file for *model_key* and record the target device.

        Args:
            model_key (str): One of the keys of ``_HF_MAP``. Matching is
                case-insensitive, and any ``".pt"`` substring is stripped first.
            device (str): ``"cpu"`` or ``"cuda"``; these are the two values
                ``load_model`` acts on.

        Raises:
            ValueError: If the normalised ``model_key`` is not supported.
        """
        self.device = device
        self.model = None
        self.model_key = model_key.lower().replace(".pt", "")
        if self.model_key not in self._HF_MAP:
            raise ValueError(f"Unsupported model key: {self.model_key}")
        repo_id, filename = self._HF_MAP[self.model_key]
        # Network/disk I/O in the constructor. With force_download=False, a copy
        # already present under cache_dir is reused instead of re-downloaded.
        self.weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir="models/detection/weights",
            force_download=False,
        )

    def load_model(self):
        """Instantiate the YOLO model on the first call; later calls are no-ops.

        Returns:
            ObjectDetector: ``self``, so the call can be chained.
        """
        if self.model is not None:
            return self

        logger.info("Loading model from path: %s", self.weights_path)

        if self.device == "cpu":
            # Hide GPUs *before* the deferred torch/ultralytics imports. This is
            # process-wide, and it has no effect if CUDA was already initialised
            # elsewhere in the process.
            os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

        import torch  # deferred: heavy import, only needed once a model is loaded
        from ultralytics import YOLO  # deferred for the same reason

        self.model = YOLO(self.weights_path)

        if self.device == "cuda":
            if torch.cuda.is_available():
                self.model.to("cuda")
            else:
                logger.warning(
                    "CUDA requested but not available; model left on its default device."
                )
        return self

    def predict(self, image: Image.Image, conf_threshold=0.25):
        """Run object detection on *image*.

        Args:
            image (PIL.Image.Image): Input image.
            conf_threshold (float): Minimum confidence a detection must have to
                be returned. It is forwarded to the model as its ``conf`` argument.

        Returns:
            List[Dict]: One dict per detected box, with keys ``"class_name"``
            (str), ``"confidence"`` (float) and ``"bbox"`` (``[x1, y1, x2, y2]``
            taken from the box's ``xyxy`` coordinates).

        Raises:
            RuntimeError: If the model could not be loaded.
        """
        self.load_model()
        if self.model is None:
            raise RuntimeError("YOLO model not loaded. Call load_model() first.")

        # Forward the caller's threshold; without this the argument was ignored.
        results = self.model(image, conf=conf_threshold)

        detections = []
        for r in results:
            for box in r.boxes:
                detections.append({
                    "class_name": r.names[int(box.cls)],
                    "confidence": float(box.conf),
                    "bbox": box.xyxy[0].tolist(),
                })
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """Draw thick, per-class-coloured bounding boxes and labels.

        Args:
            image (PIL.Image.Image): Original image. Modes other than RGB/RGBA
                are converted to RGB first, because the per-class colours are RGB
                tuples and ``Image.blend`` rejects palette images.
            detections (List[Dict]): Each dict has "bbox" (x1, y1, x2, y2),
                "class_name" and "confidence".
            alpha (float): Blend strength of the annotated overlay.

        Returns:
            PIL.Image.Image: The original image blended with the annotated copy.
        """
        if image.mode not in ("RGB", "RGBA"):
            image = image.convert("RGB")

        overlay = image.copy()
        painter = ImageDraw.Draw(overlay)

        # Prefer a TTF font. truetype raises OSError when the file cannot be
        # found or read; in that case fall back to Pillow's built-in font.
        try:
            font = ImageFont.truetype("arial.ttf", 18)
        except OSError:
            font = ImageFont.load_default()

        class_colors = {}

        def get_color(cls):
            """Return a stable RGB colour for class name *cls*."""
            if cls not in class_colors:
                # Seeding with the class name gives the same colour on every run.
                rnd = random.Random(cls)
                class_colors[cls] = (
                    rnd.randint(100, 255),
                    rnd.randint(100, 255),
                    rnd.randint(100, 255),
                )
            return class_colors[cls]

        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            cls_name = det["class_name"]
            label = f"{cls_name} {det['confidence']:.2f}"
            color = get_color(cls_name)

            # Thicker box: four 1px outlines, each offset one pixel further out.
            for t in range(4):
                painter.rectangle((x1 - t, y1 - t, x2 + t, y2 + t), outline=color)

            # textbbox can report a non-zero top offset, so measure at the origin
            # and compensate for it when drawing the text.
            left, top, right, bottom = painter.textbbox((0, 0), label, font=font)
            tb_w = right - left
            tb_h = bottom - top

            # Place the label above the box. If that would leave the image
            # (box touching the top edge), place it just inside the box instead.
            label_top = y1 - tb_h if y1 >= tb_h else y1

            painter.rectangle(
                (x1, label_top, x1 + tb_w + 6, label_top + tb_h), fill=color
            )
            # Small horizontal padding; subtract `top` so the glyphs sit inside
            # the background rectangle.
            painter.text(
                (x1 + 3, label_top - top), label, font=font, fill="black"
            )

        return Image.blend(image, overlay, alpha)